Skip to content

Commit

Permalink
Bump DynamicPPL to v0.25 (TuringLang#2197)
Browse files Browse the repository at this point in the history
* reeclpad Setfield with Accessors to bump up to DPPL v0.25

* bump DPPL version

* use Accessors

* replaced usages of `@set!` with `BangBang.@set!!`

* fixed Project.toml

* reverted accidental change

* import BangBang in Turing

* replace `BangBang.@set!!` with `Accessors.@set`

* bump minor version since this is a breaking change

* makke failing test conditional on Julia version >1.7

* fixed references to Setfield.jl in Experimental module

* disabled another test due to the same issue

---------

Co-authored-by: Xianda Sun <sunxdt@gmail.com>
Co-authored-by: Xianda Sun <5433119+sunxd3@users.noreply.github.com>
  • Loading branch information
3 people committed Apr 23, 2024
1 parent a022dc6 commit 9be6b79
Show file tree
Hide file tree
Showing 11 changed files with 60 additions and 50 deletions.
10 changes: 5 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.30.9"
version = "0.31.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
AdvancedPS = "576499cb-2369-40b2-a588-c64705576edc"
Expand All @@ -29,7 +30,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
Expand All @@ -47,18 +47,19 @@ TuringOptimExt = "Optim"
[compat]
ADTypes = "0.2"
AbstractMCMC = "5.2"
Accessors = "0.1"
AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6"
AdvancedMH = "0.8"
AdvancedPS = "0.5.4"
AdvancedVI = "0.2"
BangBang = "0.3"
BangBang = "0.4"
Bijectors = "0.13.6"
DataStructures = "0.18"
Distributions = "0.23.3, 0.24, 0.25"
DistributionsAD = "0.6"
DocStringExtensions = "0.8, 0.9"
DynamicHMC = "3.4"
DynamicPPL = "0.24.10"
DynamicPPL = "0.25.1"
EllipticalSliceSampling = "0.5, 1, 2"
ForwardDiff = "0.10.3"
Libtask = "0.7, 0.8"
Expand All @@ -70,7 +71,6 @@ Optim = "1"
Reexport = "0.2, 1"
Requires = "0.5, 1.0"
SciMLBase = "1.37.1, 2"
Setfield = "0.8, 1"
SpecialFunctions = "0.7.2, 0.8, 0.9, 0.10, 1, 2"
Statistics = "1.6"
StatsAPI = "1.6"
Expand Down
18 changes: 9 additions & 9 deletions ext/TuringOptimExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ module TuringOptimExt

if isdefined(Base, :get_extension)
import Turing
import Turing: Distributions, DynamicPPL, ForwardDiff, NamedArrays, Printf, Setfield, Statistics, StatsAPI, StatsBase
import Turing: Distributions, DynamicPPL, ForwardDiff, NamedArrays, Printf, Accessors, Statistics, StatsAPI, StatsBase
import Optim
else
import ..Turing
import ..Turing: Distributions, DynamicPPL, ForwardDiff, NamedArrays, Printf, Setfield, Statistics, StatsAPI, StatsBase
import ..Turing: Distributions, DynamicPPL, ForwardDiff, NamedArrays, Printf, Accessors, Statistics, StatsAPI, StatsBase
import ..Optim
end

Expand Down Expand Up @@ -80,7 +80,7 @@ function StatsBase.informationmatrix(m::ModeResult; hessian_function=ForwardDiff
# Hessian is computed with respect to the untransformed parameters.
linked = DynamicPPL.istrans(m.f.varinfo)
if linked
Setfield.@set! m.f.varinfo = DynamicPPL.invlink!!(m.f.varinfo, m.f.model)
m = Accessors.@set m.f.varinfo = DynamicPPL.invlink!!(m.f.varinfo, m.f.model)
end

# Calculate the Hessian, which is the information matrix because the negative of the log likelihood was optimized
Expand All @@ -89,7 +89,7 @@ function StatsBase.informationmatrix(m::ModeResult; hessian_function=ForwardDiff

# Link it back if we invlinked it.
if linked
Setfield.@set! m.f.varinfo = DynamicPPL.link!!(m.f.varinfo, m.f.model)
m = Accessors.@set m.f.varinfo = DynamicPPL.link!!(m.f.varinfo, m.f.model)
end

return NamedArrays.NamedArray(info, (varnames, varnames))
Expand Down Expand Up @@ -227,8 +227,8 @@ function _optimize(
)
# Convert the initial values, since it is assumed that users provide them
# in the constrained space.
Setfield.@set! f.varinfo = DynamicPPL.unflatten(f.varinfo, init_vals)
Setfield.@set! f.varinfo = DynamicPPL.link(f.varinfo, model)
f = Accessors.@set f.varinfo = DynamicPPL.unflatten(f.varinfo, init_vals)
f = Accessors.@set f.varinfo = DynamicPPL.link(f.varinfo, model)
init_vals = DynamicPPL.getparams(f)

# Optimize!
Expand All @@ -241,10 +241,10 @@ function _optimize(

# Get the VarInfo at the MLE/MAP point, and run the model to ensure
# correct dimensionality.
Setfield.@set! f.varinfo = DynamicPPL.unflatten(f.varinfo, M.minimizer)
Setfield.@set! f.varinfo = DynamicPPL.invlink(f.varinfo, model)
f = Accessors.@set f.varinfo = DynamicPPL.unflatten(f.varinfo, M.minimizer)
f = Accessors.@set f.varinfo = DynamicPPL.invlink(f.varinfo, model)
vals = DynamicPPL.getparams(f)
Setfield.@set! f.varinfo = DynamicPPL.link(f.varinfo, model)
f = Accessors.@set f.varinfo = DynamicPPL.link(f.varinfo, model)

# Make one transition to get the parameter names.
ts = [Turing.Inference.Transition(
Expand Down
4 changes: 3 additions & 1 deletion src/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@ using DynamicPPL: DynamicPPL, LogDensityFunction
import DynamicPPL: getspace, NoDist, NamedDist
import LogDensityProblems
import NamedArrays
import Setfield
import Accessors
import StatsAPI
import StatsBase

using Accessors: Accessors

import Printf
import Random

Expand Down
2 changes: 1 addition & 1 deletion src/experimental/Experimental.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module Experimental
using Random: Random
using AbstractMCMC: AbstractMCMC
using DynamicPPL: DynamicPPL, VarName
using Setfield: Setfield
using Accessors: Accessors

using DocStringExtensions: TYPEDFIELDS
using Distributions
Expand Down
10 changes: 5 additions & 5 deletions src/experimental/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ Returns the preferred value type for a variable with the given `varinfo`.
preferred_value_type(::DynamicPPL.AbstractVarInfo) = DynamicPPL.OrderedDict
preferred_value_type(::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = NamedTuple
function preferred_value_type(varinfo::DynamicPPL.TypedVarInfo)
# We can only do this in the scenario where all the varnames are `Setfield.IdentityLens`.
# We can only do this in the scenario where all the varnames are `Accessors.IdentityLens`.
namedtuple_compatible = all(varinfo.metadata) do md
eltype(md.vns) <: VarName{<:Any,Setfield.IdentityLens}
eltype(md.vns) <: VarName{<:Any,typeof(identity)}
end
return namedtuple_compatible ? NamedTuple : DynamicPPL.OrderedDict
end
Expand Down Expand Up @@ -321,8 +321,8 @@ function AbstractMCMC.step(
)

# Update the `states` and `varinfos`.
states = Setfield.setindex(states, new_state_local, index)
varinfos = Setfield.setindex(varinfos, new_varinfo_local, index)
states = Accessors.setindex(states, new_state_local, index)
varinfos = Accessors.setindex(varinfos, new_varinfo_local, index)
end

# Combine the resulting varinfo objects.
Expand All @@ -349,7 +349,7 @@ function make_rerun_sampler(model::DynamicPPL.Model, sampler::DynamicPPL.Sampler
# NOTE: This is different from the implementation used in the old `Gibbs` sampler, where we specifically provide
# a `gid`. Here, because `model` only contains random variables to be sampled by `sampler`, we just use the exact
# same `selector` as before but now with `rerun` set to `true` if needed.
return Setfield.@set sampler.selector.rerun = true
return Accessors.@set sampler.selector.rerun = true
end

# Interface we need a sampler to implement to work as a component in a Gibbs sampler.
Expand Down
4 changes: 2 additions & 2 deletions src/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ using DynamicPPL
using AbstractMCMC: AbstractModel, AbstractSampler
using DocStringExtensions: TYPEDEF, TYPEDFIELDS
using DataStructures: OrderedSet
using Setfield: Setfield
using Accessors: Accessors

import ADTypes
import AbstractMCMC
import AdvancedHMC; const AHMC = AdvancedHMC
import AdvancedMH; const AMH = AdvancedMH
import AdvancedPS
import BangBang
import Accessors
import EllipticalSliceSampling
import LogDensityProblems
import LogDensityProblemsAD
Expand Down
2 changes: 1 addition & 1 deletion src/mcmc/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ getparams(::DynamicPPL.Model, transition::AdvancedMH.Transition) = transition.pa
getvarinfo(f::DynamicPPL.LogDensityFunction) = f.varinfo
getvarinfo(f::LogDensityProblemsAD.ADGradientWrapper) = getvarinfo(parent(f))

setvarinfo(f::DynamicPPL.LogDensityFunction, varinfo) = Setfield.@set f.varinfo = varinfo
setvarinfo(f::DynamicPPL.LogDensityFunction, varinfo) = Accessors.@set f.varinfo = varinfo
setvarinfo(f::LogDensityProblemsAD.ADGradientWrapper, varinfo) = setvarinfo(parent(f), varinfo)

# TODO: Do we also support `resume`, etc?
Expand Down
4 changes: 2 additions & 2 deletions src/optimisation/Optimisation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using Bijectors
using Random
using SciMLBase: OptimizationFunction, OptimizationProblem, AbstractADType, NoAD

using Setfield
using Accessors: Accessors
using DynamicPPL
using DynamicPPL: Model, AbstractContext, VarInfo, VarName,
_getindex, getsym, getfield, setorder!,
Expand Down Expand Up @@ -150,7 +150,7 @@ function transform!!(f::OptimLogDensity)
linked = DynamicPPL.istrans(f.varinfo)

## transform into constrained or unconstrained space depending on current state of vi
@set! f.varinfo = if !linked
f = Accessors.@set f.varinfo = if !linked
DynamicPPL.link!!(f.varinfo, f.model)
else
DynamicPPL.invlink!!(f.varinfo, f.model)
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ Clustering = "0.14, 0.15"
Distributions = "0.25"
DistributionsAD = "0.6.3"
DynamicHMC = "2.1.6, 3.0"
DynamicPPL = "0.24"
DynamicPPL = "0.25.1"
FiniteDifferences = "0.10.8, 0.11, 0.12"
ForwardDiff = "0.10.12 - 0.10.32, 0.10"
LogDensityProblems = "2"
Expand Down
20 changes: 12 additions & 8 deletions test/mcmc/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,14 +247,18 @@
end
end

@turing_testset "(partially) issue: #2095" begin
@model function vector_of_dirichlet(::Type{TV}=Vector{Float64}) where {TV}
xs = Vector{TV}(undef, 2)
xs[1] ~ Dirichlet(ones(5))
xs[2] ~ Dirichlet(ones(5))
# Disable on Julia <1.8 due to https://github.com/TuringLang/Turing.jl/pull/2197.
# TODO: Remove this block once https://github.com/JuliaFolds2/BangBang.jl/pull/22 has been released.
if VERSION v"1.8"
@turing_testset "(partially) issue: #2095" begin
@model function vector_of_dirichlet(::Type{TV}=Vector{Float64}) where {TV}
xs = Vector{TV}(undef, 2)
xs[1] ~ Dirichlet(ones(5))
xs[2] ~ Dirichlet(ones(5))
end
model = vector_of_dirichlet()
chain = sample(model, NUTS(), 1000)
@test mean(Array(chain)) 0.2
end
model = vector_of_dirichlet()
chain = sample(model, NUTS(), 1000)
@test mean(Array(chain)) 0.2
end
end
34 changes: 19 additions & 15 deletions test/mcmc/mh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,24 +162,28 @@
# @test v1 < v2
end

@turing_testset "vector of multivariate distributions" begin
@model function test(k)
T = Vector{Vector{Float64}}(undef, k)
for i in 1:k
T[i] ~ Dirichlet(5, 1.0)
# Disable on Julia <1.8 due to https://github.com/TuringLang/Turing.jl/pull/2197.
# TODO: Remove this block once https://github.com/JuliaFolds2/BangBang.jl/pull/22 has been released.
if VERSION v"1.8"
@turing_testset "vector of multivariate distributions" begin
@model function test(k)
T = Vector{Vector{Float64}}(undef, k)
for i in 1:k
T[i] ~ Dirichlet(5, 1.0)
end
end
end

Random.seed!(100)
chain = sample(test(1), MH(), 5_000)
for i in 1:5
@test mean(chain, "T[1][$i]") 0.2 atol=0.01
end
Random.seed!(100)
chain = sample(test(1), MH(), 5_000)
for i in 1:5
@test mean(chain, "T[1][$i]") 0.2 atol = 0.01
end

Random.seed!(100)
chain = sample(test(10), MH(), 5_000)
for j in 1:10, i in 1:5
@test mean(chain, "T[$j][$i]") 0.2 atol=0.01
Random.seed!(100)
chain = sample(test(10), MH(), 5_000)
for j in 1:10, i in 1:5
@test mean(chain, "T[$j][$i]") 0.2 atol = 0.01
end
end
end

Expand Down

0 comments on commit 9be6b79

Please sign in to comment.