Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements to externalsampler #2204

Merged
merged 23 commits into from
Apr 28, 2024
Merged

Conversation

torfjelde
Copy link
Member

Currently there's no way of choosing an ad backend for externalsampler, nor to specify whether it requires constrained or unconstrained parameters. This PR adds those options.

In addition, the PR adds a default_adtype method to be more consistent across the codebase.

src/Turing.jl Outdated Show resolved Hide resolved
src/mcmc/Inference.jl Outdated Show resolved Hide resolved
@yebai
Copy link
Member

yebai commented Apr 20, 2024

Nice improvements to externalsampler!

@coveralls
Copy link

coveralls commented Apr 20, 2024

Pull Request Test Coverage Report for Build 8808209405

Details

  • 0 of 29 (0.0%) changed or added relevant lines in 4 files are covered.
  • 1 unchanged line in 1 file lost coverage.
  • Overall coverage remained the same at 0.0%

Changes Missing Coverage Covered Lines Changed/Added Lines %
src/deprecated.jl 0 2 0.0%
src/mcmc/hmc.jl 0 2 0.0%
src/mcmc/Inference.jl 0 9 0.0%
src/mcmc/abstractmcmc.jl 0 16 0.0%
Files with Coverage Reduction New Missed Lines %
src/mcmc/abstractmcmc.jl 1 0.0%
Totals Coverage Status
Change from base Build 8803479078: 0.0%
Covered Lines: 0
Relevant Lines: 1529

💛 - Coveralls

src/Turing.jl Outdated
@@ -18,6 +18,10 @@ import StatsBase
import Printf
import Random

using ADTypes: ADTypes

default_adtype() = ADTypes.AutoForwardDiff()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't a constant be simpler?

Suggested change
default_adtype() = ADTypes.AutoForwardDiff()
const DEFAULT_ADTYPE = ADTypes.AutoForwardDiff()

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually did that first 😅

Buuuut using a method does allow an advanced user to actuallly overload the default AD backend that is used 🤷

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not completely convinced that this should be supported since it has such widereaching consequences. Some limited flexibility could be allowed by making the default a Preferences-constant that could be modified by users.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made it a const as desired:)

Comment on lines 102 to 104
function ExternalSampler(sampler::S, adtype::AD, unconstrained::Bool=true) where {S<:AbstractSampler,AD<:ADTypes.AbstractADType}
return new{S,AD,unconstrained}(sampler, adtype)
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally, the return type can't be inferred here. Maybe make unconstrained the first type parameter and only support

ExternalSampler{unconstrained}(sampler, adtype) where unconstrained = ExternalSampler{unconstrained,typeof(sampler),typeof(adtype)}(sampler, adtype)
ExternalSampler(sampler, adtype) = ExternalSampler{true}(sampler, adtype)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Am aware, and am happy to make this change 👍 Buuuut the user will be using externalsampler to construct it anyways, which will never be inferrable either 😕

Buuut I agree with you;; we should at least have the option of using a type-inferrable constructor so will make the change 👍

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe use Val(true)/Val(false) in externalsampler to make it fully inferrable?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe use Val(true)/Val(false) in externalsampler to make it fully inferrable?

Don't like this as it requires users to be familiar with what Val is, which doesn't seem worth it given that externalsampler will never be used in "hot" code (it's really only going to be used before sample(...), which is not type-stable anyways) 😕

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made the changes you wanted to the constructor (though leeft externalsampler as is)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Making use of Val for constructor ExternalSampler, but leaving externalsampler as is I think

f = DynamicPPL.LogDensityFunction(model)
d = LogDensityProblems.dimension(f)
return AdvancedMH.RWMH(MvNormal(Zeros(d), 0.1 * I))
end

# TODO: Should this go somewhere else?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably DynamicPPL I'd say?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was thinking the same. Only thing holding me back is that it's unclear to me if this is indeed how we want to implement something like this 😕 It was a quick--and-dirty way of doing it, hence why I figured I'd just leave it here for now.

@torfjelde
Copy link
Member Author

You happy with the current version @devmotion ?

@torfjelde
Copy link
Member Author

I'll add one final thing to this PR: fix issues with initial params when working with unconstrained models.

@torfjelde
Copy link
Member Author

Added initial params handling + testing. Should pass as usual, at which point we can merge:)

@torfjelde
Copy link
Member Author

torfjelde commented Apr 23, 2024

I'll merge this once tests have passed unless anyone has any immediate concerns 👍 Should go in the same release as #2197 as this is also technically breaking

@torfjelde
Copy link
Member Author

Actually, nvm, this is not breaking 🤦 Will check that tests run after merge with master, and then will merge (unless there are objections)

src/Turing.jl Outdated Show resolved Hide resolved
src/mcmc/Inference.jl Outdated Show resolved Hide resolved
src/mcmc/Inference.jl Outdated Show resolved Hide resolved
Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
@torfjelde
Copy link
Member Author

Accepted your changes @devmotion ; thanks!

@yebai yebai merged commit 17ab13d into master Apr 28, 2024
11 checks passed
@yebai yebai deleted the torfjelde/external-sampler-improvements branch April 28, 2024 10:36
@yebai
Copy link
Member

yebai commented Apr 28, 2024

@torfjelde, the PR passes all CI tests but seems to have introduced several CI errors after merging into master.

@devmotion
Copy link
Member

I assume the errors are unrelated to this PR and caused by the latest release of AliasTables: LilithHafner/AliasTables.jl#48

@devmotion
Copy link
Member

Should be fixed by AliasTables 1.1.1.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants