Skip to content

Commit

Permalink
Bugfix for Optim.jl on models with different linked dimensionality (T…
Browse files Browse the repository at this point in the history
…uringLang#2196)

* fixed bug with optim interface

* bump patch version

* fixed test

* dirichlet onles has a unique mode for alpha > 1...
  • Loading branch information
torfjelde committed Apr 19, 2024
1 parent c29d36e commit fa6f30a
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.30.7"
version = "0.30.8"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
6 changes: 3 additions & 3 deletions ext/TuringOptimExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ function _optimize(
# Convert the initial values, since it is assumed that users provide them
# in the constrained space.
Setfield.@set! f.varinfo = DynamicPPL.unflatten(f.varinfo, init_vals)
Setfield.@set! f.varinfo = DynamicPPL.link!!(f.varinfo, model)
Setfield.@set! f.varinfo = DynamicPPL.link(f.varinfo, model)
init_vals = DynamicPPL.getparams(f)

# Optimize!
Expand All @@ -242,9 +242,9 @@ function _optimize(
# Get the VarInfo at the MLE/MAP point, and run the model to ensure
# correct dimensionality.
Setfield.@set! f.varinfo = DynamicPPL.unflatten(f.varinfo, M.minimizer)
Setfield.@set! f.varinfo = DynamicPPL.invlink!!(f.varinfo, model)
Setfield.@set! f.varinfo = DynamicPPL.invlink(f.varinfo, model)
vals = DynamicPPL.getparams(f)
Setfield.@set! f.varinfo = DynamicPPL.link!!(f.varinfo, model)
Setfield.@set! f.varinfo = DynamicPPL.link(f.varinfo, model)

# Make one transition to get the parameter names.
ts = [Turing.Inference.Transition(
Expand Down
8 changes: 8 additions & 0 deletions test/optimisation/OptimInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,4 +225,12 @@ end
@test Turing.OptimLogDensity(m1, ctx)(w) == Turing.OptimLogDensity(m2, ctx)(w)
end
end

# Issue: https://discourse.julialang.org/t/turing-mixture-models-with-dirichlet-weightings/112910
@testset "with different linked dimensionality" begin
@model demo_dirichlet() = x ~ Dirichlet(2 * ones(3))
model = demo_dirichlet()
result = optimize(model, MAP())
@test result.values mode(Dirichlet(2 * ones(3))) atol=0.2
end
end

0 comments on commit fa6f30a

Please sign in to comment.