Skip to content

Commit

Permalink
fixed typo in Q and llh calculation and exchanged L2 for L1 regulariz…
Browse files Browse the repository at this point in the history
…ation in combo model
  • Loading branch information
john-waczak committed Jul 22, 2024
1 parent 725938a commit 8fa283c
Show file tree
Hide file tree
Showing 13 changed files with 75 additions and 61 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
/Manifest.toml
/docs/build/
.DS_Store
*.key
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "GenerativeTopographicMapping"
uuid = "110c1e60-17ba-4aeb-8cee-444277a6d160"
authors = ["John Waczak <[email protected]>"]
version = "0.7.9"
version = "0.7.10"

[deps]
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
Expand Down
19 changes: 0 additions & 19 deletions src/GenerativeTopographicMapping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,25 +47,6 @@ include("gsm-big-combo-base.jl")
include("gsm-big-combo-mlj.jl")


# # GSMMultupLinear
# include("gsm-multup-linear-base.jl")
# include("gsm-multup-linear-mlj.jl")


# # GSMMultupNoninear
# include("gsm-multup-nonlinear-base.jl")
# include("gsm-multup-nonlinear-mlj.jl")

# # # GSMMultupBigLinear
# include("gsm-multup-big-linear-base.jl")
# include("gsm-multup-big-linear-mlj.jl")

# # GSMMultupBigNonlinear
# include("gsm-multup-big-nonlinear-base.jl")
# include("gsm-multup-big-nonlinear-mlj.jl")



export GTM

export GSMLinear
Expand Down
6 changes: 3 additions & 3 deletions src/gsm-big-combo-mlj.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ end



function GSMBigCombo(; n_nodes=1000, n_rbfs=500, s=0.05, Nv=3, λe=0.01, λw=0.1, nepochs=100, niters=100, tol=1e-3, nconverged=4, rng=123,)
function GSMBigCombo(; n_nodes=1000, n_rbfs=500, s=0.05, Nv=3, λe=0.01, λw=0.1, nepochs=100, niters=10, tol=1e-3, nconverged=4, rng=123,)
model = GSMBigCombo(n_nodes, n_rbfs, s, Nv, λe, λw, nepochs, niters, tol, nconverged, mk_rng(rng),)
message = MLJModelInterface.clean!(model)
isempty(message) || @warn message
Expand Down Expand Up @@ -61,8 +61,8 @@ function MLJModelInterface.clean!(m::GSMBigCombo)
end

if m.niters 0
warning *= "Parameter `niters` expected to be positive, resetting to 100\n"
m.niters = 100
warning *= "Parameter `niters` expected to be positive, resetting to 10\n"
m.niters = 10
end

if m.tol 0
Expand Down
6 changes: 3 additions & 3 deletions src/gsm-big-linear-mlj.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ mutable struct GSMBigLinear<: MLJModelInterface.Unsupervised
rng::Any
end

function GSMBigLinear(; n_nodes=1000, Nv=3, λ=0.1, nepochs=100, niters=100, tol=1e-3, nconverged=4, rng=123)
function GSMBigLinear(; n_nodes=1000, Nv=3, λ=0.1, nepochs=100, niters=10, tol=1e-3, nconverged=4, rng=123)
model = GSMBigLinear(n_nodes, Nv, λ, nepochs, niters, tol, nconverged, mk_rng(rng))
message = MLJModelInterface.clean!(model)
isempty(message) || @warn message
Expand Down Expand Up @@ -42,8 +42,8 @@ function MLJModelInterface.clean!(m::GSMBigLinear)
end

if m.niters 0
warning *= "Parameter `niters` expected to be positive, resetting to 100\n"
m.niters = 100
warning *= "Parameter `niters` expected to be positive, resetting to 10\n"
m.niters = 10
end

if m.tol 0
Expand Down
6 changes: 3 additions & 3 deletions src/gsm-big-nonlinear-mlj.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ mutable struct GSMBigNonlinear<: MLJModelInterface.Unsupervised
rng::Any
end

function GSMBigNonlinear(; n_nodes=1000, n_rbfs=500, s=0.25, Nv=3, λ=0.1, nepochs=100, niters=100, tol=1e-3, nconverged=4, rng=123)
function GSMBigNonlinear(; n_nodes=1000, n_rbfs=500, s=0.25, Nv=3, λ=0.1, nepochs=100, niters=10, tol=1e-3, nconverged=4, rng=123)
model = GSMBigNonlinear(n_nodes, n_rbfs, s, Nv, λ, nepochs, niters, tol, nconverged, mk_rng(rng))
message = MLJModelInterface.clean!(model)
isempty(message) || @warn message
Expand Down Expand Up @@ -55,8 +55,8 @@ function MLJModelInterface.clean!(m::GSMBigNonlinear)
end

if m.niters 0
warning *= "Parameter `niters` expected to be positive, resetting to 100\n"
m.niters = 100
warning *= "Parameter `niters` expected to be positive, resetting to 10\n"
m.niters = 10
end

if m.tol 0
Expand Down
30 changes: 15 additions & 15 deletions src/gsm-combo-base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ end



function fit!(gsm::GSMComboBase, Nv, X; λe = 0.01, λw=0.1, nepochs=100, tol=1e-3, nconverged=5, verbose=false, n_steps =100)
function fit!(gsm::GSMComboBase, Nv, X; λe = 0.01, λw=0.1, nepochs=10, tol=1e-3, nconverged=5, verbose=false, n_steps =100)
# get the needed dimensions
N,D = size(X)

Expand All @@ -93,7 +93,9 @@ function fit!(gsm::GSMComboBase, Nv, X; λe = 0.01, λw=0.1, nepochs=100, tol=1e
end

# this replaces the diagonal matrix in M-step
Λ = Diagonal(vcat(λe * ones(Nv), λw * ones(M-Nv)))
# Λ = Diagonal(vcat(λe * ones(Nv), λw * ones(M-Nv)))
Λ = λw .* ones(size(gsm.W))
Λ[:, 1:Nv] .= λe .* gsm.W[:, 1:Nv]

Q = 0.0
Q_prev = 0.0
Expand All @@ -111,7 +113,6 @@ function fit!(gsm::GSMComboBase, Nv, X; λe = 0.01, λw=0.1, nepochs=100, tol=1e
XtRt = X'*gsm.R'
WΦt = gsm.W*gsm.Φ'
= G*gsm.Φ
= gsm.W*Λ

@assert all(gsm.W .≥ 0.0)

Expand All @@ -129,18 +130,16 @@ function fit!(gsm::GSMComboBase, Nv, X; λe = 0.01, λw=0.1, nepochs=100, tol=1e
# MAXIMIZATION

# 1. Update the πk values
# gsm.πk .= (1/N) .* sum(gsm.R, dims=2)
gsm.πk .= max.((1/N) .* sum(gsm.R, dims=2), eps(eltype(gsm.πk)))
for n axes(LnΠ,2)
LnΠ[:,n] .= log.(gsm.πk)
end


# 2. update weight matrix
# gsm.W = ((gsm.Φ'*G*gsm.Φ + gsm.β⁻¹*Λ)\(gsm.Φ'*gsm.R*X))'

for step 1:n_steps
# gsm.W .*= max.((X' * gsm.R' * gsm.Φ ./ gsm.β⁻¹), 0.0) ./ max.((gsm.W * gsm.Φ' * G * gsm.Φ ./ gsm.β⁻¹ + (gsm.W * Λ)), eps(eltype(gsm.W)))
# update Λ matrix
Λ[:, 1:Nv] .= λe .* gsm.W[:, 1:Nv]

# update numerator
mul!(XtRt, X', gsm.R')
Expand All @@ -150,16 +149,15 @@ function fit!(gsm::GSMComboBase, Nv, X; λe = 0.01, λw=0.1, nepochs=100, tol=1e
# update denominator
mul!(WΦt, gsm.W, gsm.Φ')
mul!(GΦ, G, gsm.Φ)
mul!(WΛ, gsm.W, Λ)

mul!(Denom, WΦt, GΦ)
Denom ./= gsm.β⁻¹
Denom .+=
Denom .+= Λ

# update weights
gsm.W .*= max.(Numer, 0.0) ./ max.(Denom, eps(eltype(gsm.W)))
end


@assert all(gsm.W .≥ 0.0)

# 3. update precision β
Expand All @@ -169,21 +167,23 @@ function fit!(gsm::GSMComboBase, Nv, X; λe = 0.01, λw=0.1, nepochs=100, tol=1e

# UPDATE LOG-LIKELIHOOD
prefac = (N*D/2)*log(1/(2* gsm.β⁻¹* π))
gsm.Δ² .*= -(1/(2*gsm.β⁻¹))

if i == 1
l = max(prefac + sum(logsumexp(gsm.Δ² .* LnΠ, dims=1)), nextfloat(typemin(1.0)))
# l = max(prefac + sum(logsumexp(gsm.Δ² .* LnΠ, dims=1)), nextfloat(typemin(1.0)))
l = max(prefac + sum(logsumexp(gsm.Δ² .+ LnΠ, dims=1)), nextfloat(typemin(1.0)))

Q = (N*D/2)*log(1/(2* gsm.β⁻¹* π)) + sum(gsm.R .* LnΠ) - sum(gsm.R .* gsm.Δ²) + (D*(M-Nv)/2)*log(λw/(2π)) + ((D*Nv)/2)*log(λe/(2π)) - (λe/2)*sum(gsm.W[:,1:Nv]) - (λe/2)*sum(gsm.W[:,Nv+1:end])
Q = (N*D/2)*log(1/(2* gsm.β⁻¹* π)) + sum(gsm.R .* (LnΠ .- gsm.Δ²)) + ((D*Nv)/2)*log(λe/(2π)) - (λe/2)*sum(gsm.W[:,1:Nv].^2) + D*(M-Nv)*log(λw/2) - λw*sum(gsm.W[:,Nv+1:end])

push!(llhs, l)
push!(Qs, Q)
else
Q_prev = Q

l = max(prefac + sum(logsumexp(gsm.Δ² .* LnΠ, dims=1)), nextfloat(typemin(1.0)))

Q = (N*D/2)*log(1/(2* gsm.β⁻¹* π)) + sum(gsm.R .* LnΠ) - sum(gsm.R .* gsm.Δ²) + (D*(M-Nv)/2)*log(λw/(2π)) + ((D*Nv)/2)*log(λe/(2π)) - (λe/2)*sum(gsm.W[:,1:Nv]) - (λe/2)*sum(gsm.W[:,Nv+1:end])
# l = max(prefac + sum(logsumexp(gsm.Δ² .* LnΠ, dims=1)), nextfloat(typemin(1.0)))
l = max(prefac + sum(logsumexp(gsm.Δ² .+ LnΠ, dims=1)), nextfloat(typemin(1.0)))

Q = (N*D/2)*log(1/(2* gsm.β⁻¹* π)) + sum(gsm.R .* (LnΠ .- gsm.Δ²)) + ((D*Nv)/2)*log(λe/(2π)) - (λe/2)*sum(gsm.W[:,1:Nv].^2) + D*(M-Nv)*log(λw/2) - λw*sum(gsm.W[:,Nv+1:end])

push!(llhs, l)
push!(Qs, Q)
Expand Down
6 changes: 3 additions & 3 deletions src/gsm-combo-mlj.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ end



function GSMCombo(; k=10, m=5, Nv=3, λe=0.01, λw=0.1, nepochs=100, niters=100, tol=1e-3, nconverged=4, rng=123)
function GSMCombo(; k=10, m=5, Nv=3, λe=0.01, λw=0.1, nepochs=10, niters=100, tol=1e-3, nconverged=4, rng=123)
model = GSMCombo(k, m, Nv, λe, λw, nepochs, niters, tol, nconverged, mk_rng(rng))
message = MLJModelInterface.clean!(model)
isempty(message) || @warn message
Expand Down Expand Up @@ -55,8 +55,8 @@ function MLJModelInterface.clean!(m::GSMCombo)
end

if m.niters 0
warning *= "Parameter `niters` expected to be positive, resetting to 100\n"
m.niters = 100
warning *= "Parameter `niters` expected to be positive, resetting to 10\n"
m.niters = 10
end

if m.tol 0
Expand Down
13 changes: 9 additions & 4 deletions src/gsm-linear-base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,19 +138,24 @@ function fit!(gsm::GSMLinearBase, X; λ = 0.1, nepochs=100, tol=1e-3, nconverged

# UPDATE Q and LOG-LIKELIHOOD
prefac = (N*D/2)*log(1/(2* gsm.β⁻¹* π))
gsm.Δ² .*= -(1/(2*gsm.β⁻¹))

if i == 1
l = max(prefac + sum(logsumexp(gsm.Δ² .* LnΠ, dims=1)), nextfloat(typemin(1.0)))
# l = max(prefac + sum(logsumexp(gsm.Δ² .* LnΠ, dims=1)), nextfloat(typemin(1.0)))
l = max(prefac + sum(logsumexp(gsm.Δ² .+ LnΠ, dims=1)), nextfloat(typemin(1.0)))

Q = (N*D/2)*log(1/(2* gsm.β⁻¹* π)) + sum(gsm.R .* LnΠ) + (length(gsm.W)/2)*log/(2π)) - sum(gsm.R .* gsm.Δ²) -/2)*sum(gsm.W)
# Q = (N*D/2)*log(1/(2* gsm.β⁻¹* π)) + sum(gsm.R .* LnΠ) + (length(gsm.W)/2)*log(λ/(2π)) - sum(gsm.R .* gsm.Δ²) - (λ/2)*sum(gsm.W)
Q = (N*D/2)*log(1/(2* gsm.β⁻¹* π)) + sum(gsm.R .* (LnΠ .- gsm.Δ²)) + (length(gsm.W)/2)*log/(2π)) -/2)*sum(gsm.W .^ 2)

push!(llhs, l)
push!(Qs, Q)
else
Q_prev = Q

l = max(prefac + sum(logsumexp(gsm.Δ² .* LnΠ, dims=1)), nextfloat(typemin(1.0)))
Q = (N*D/2)*log(1/(2* gsm.β⁻¹* π)) + sum(gsm.R .* LnΠ) + (length(gsm.W)/2)*log/(2π)) - sum(gsm.R .* gsm.Δ²) -/2)*sum(gsm.W)
# l = max(prefac + sum(logsumexp(gsm.Δ² .* LnΠ, dims=1)), nextfloat(typemin(1.0)))
l = max(prefac + sum(logsumexp(gsm.Δ² .+ LnΠ, dims=1)), nextfloat(typemin(1.0)))
# Q = (N*D/2)*log(1/(2* gsm.β⁻¹* π)) + sum(gsm.R .* LnΠ) + (length(gsm.W)/2)*log(λ/(2π)) - sum(gsm.R .* gsm.Δ²) - (λ/2)*sum(gsm.W)
Q = (N*D/2)*log(1/(2* gsm.β⁻¹* π)) + sum(gsm.R .* (LnΠ .- gsm.Δ²)) + (length(gsm.W)/2)*log/(2π)) -/2)*sum(gsm.W .^ 2)

push!(llhs, l)
push!(Qs, Q)
Expand Down
6 changes: 3 additions & 3 deletions src/gsm-linear-mlj.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ end



function GSMLinear(; k=10, Nv=3, λ=0.1, nepochs=100, niters=100, tol=1e-3, nconverged=4, rng=123)
function GSMLinear(; k=10, Nv=3, λ=0.1, nepochs=100, niters=10, tol=1e-3, nconverged=4, rng=123)
model = GSMLinear(k, Nv, λ, nepochs, niters, tol, nconverged, mk_rng(rng))
message = MLJModelInterface.clean!(model)
isempty(message) || @warn message
Expand Down Expand Up @@ -43,8 +43,8 @@ function MLJModelInterface.clean!(m::GSMLinear)
end

if m.niters 0
warning *= "Parameter `niters` expected to be positive, resetting to 100\n"
m.niters = 100
warning *= "Parameter `niters` expected to be positive, resetting to 10\n"
m.niters = 10
end

if m.tol 0
Expand Down
14 changes: 10 additions & 4 deletions src/gsm-nonlinear-base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,19 +150,25 @@ function fit!(gsm::GSMNonlinearBase, X; λ = 0.1, nepochs=100, tol=1e-3, nconver

# UPDATE LOG-LIKELIHOOD
prefac = (N*D/2)*log(1/(2* gsm.β⁻¹* π))
gsm.Δ² .*= -(1/(2*gsm.β⁻¹))

if i == 1
l = max(prefac + sum(logsumexp(gsm.Δ² .* LnΠ, dims=1)), nextfloat(typemin(1.0)))
# l = max(prefac + sum(logsumexp(gsm.Δ² .* LnΠ, dims=1)), nextfloat(typemin(1.0)))
l = max(prefac + sum(logsumexp(gsm.Δ² .+ LnΠ, dims=1)), nextfloat(typemin(1.0)))

Q = (N*D/2)*log(1/(2* gsm.β⁻¹* π)) + sum(gsm.R .* LnΠ) + (length(gsm.W)/2)*log/(2π)) - sum(gsm.R .* gsm.Δ²) -/2)*sum(gsm.W)
# Q = (N*D/2)*log(1/(2* gsm.β⁻¹* π)) + sum(gsm.R .* LnΠ) + (length(gsm.W)/2)*log(λ/(2π)) - sum(gsm.R .* gsm.Δ²) - (λ/2)*sum(gsm.W)
Q = (N*D/2)*log(1/(2* gsm.β⁻¹* π)) + sum(gsm.R .* (LnΠ .- gsm.Δ²)) + (length(gsm.W)/2)*log/(2π)) -/2)*sum(gsm.W .^ 2)

push!(llhs, l)
push!(Qs, Q)
else
Q_prev = Q

l = max(prefac + sum(logsumexp(gsm.Δ² .* LnΠ, dims=1)), nextfloat(typemin(1.0)))
Q = (N*D/2)*log(1/(2* gsm.β⁻¹* π)) + sum(gsm.R .* LnΠ) + (length(gsm.W)/2)*log/(2π)) - sum(gsm.R .* gsm.Δ²) -/2)*sum(gsm.W)
# l = max(prefac + sum(logsumexp(gsm.Δ² .* LnΠ, dims=1)), nextfloat(typemin(1.0)))
l = max(prefac + sum(logsumexp(gsm.Δ² .+ LnΠ, dims=1)), nextfloat(typemin(1.0)))

# Q = (N*D/2)*log(1/(2* gsm.β⁻¹* π)) + sum(gsm.R .* LnΠ) + (length(gsm.W)/2)*log(λ/(2π)) - sum(gsm.R .* gsm.Δ²) - (λ/2)*sum(gsm.W)
Q = (N*D/2)*log(1/(2* gsm.β⁻¹* π)) + sum(gsm.R .* (LnΠ .- gsm.Δ²)) + (length(gsm.W)/2)*log/(2π)) -/2)*sum(gsm.W .^ 2)

push!(llhs, l)
push!(Qs, Q)
Expand Down
6 changes: 3 additions & 3 deletions src/gsm-nonlinear-mlj.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ end



function GSMNonlinear(; k=10, m=5, s=1.0, Nv=3, λ=0.1, nepochs=100, niters=100, tol=1e-3, nconverged=4, rng=123)
function GSMNonlinear(; k=10, m=5, s=1.0, Nv=3, λ=0.1, nepochs=100, niters=10, tol=1e-3, nconverged=4, rng=123)
model = GSMNonlinear(k, m, s, Nv, λ, nepochs, niters, tol, nconverged, mk_rng(rng))
message = MLJModelInterface.clean!(model)
isempty(message) || @warn message
Expand Down Expand Up @@ -55,8 +55,8 @@ function MLJModelInterface.clean!(m::GSMNonlinear)
end

if m.niters 0
warning *= "Parameter `niters` expected to be positive, resetting to 100\n"
m.niters = 100
warning *= "Parameter `niters` expected to be positive, resetting to 10\n"
m.niters = 10
end

if m.tol 0
Expand Down
21 changes: 21 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,27 @@ end

@test all(rpt[:W] .≥ 0.0)


# fit again with large λw and check sparsity
model = GSMCombo(k=k, m=m, Nv=Nv, λw=0.001, nepochs=100, rng=rng, tol=1e-6)
mach = machine(model, X)
fit!(mach, verbosity=0)

rpt = report(mach)
W = rpt[:W]
μ1 = mean(W[:, Nv+1:end])

model = GSMCombo(k=k, m=m, Nv=Nv, λw=1.0, nepochs=100, rng=rng, tol=1e-6)
mach = machine(model, X)
fit!(mach, verbosity=0)

rpt = report(mach)
W = rpt[:W]
μ2 = mean(W[:, Nv+1:end])

println(μ1, "\t", μ2)
@test μ2 < μ1
@test isapprox(μ2, 0.0, atol=1e-8)
end


Expand Down

0 comments on commit 8fa283c

Please sign in to comment.