diff --git a/HISTORY.md b/HISTORY.md index 374946f4c..ff0af9377 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,19 @@ +# 0.42.1 + +SMC and PG can now be used for models with keyword arguments, albeit with one requirement: the user must mark the model function as being able to produce. +For example, if the model is + +```julia +@model foo(x; y) = a ~ Normal(x, y) +``` + +then before samping from this with SMC or PG, you will have to run + +```julia +using Libtask; +Libtask.@might_produce(foo); +``` + # 0.42.0 ## DynamicPPL 0.39 diff --git a/Project.toml b/Project.toml index 2831cd331..26f870b17 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.42.0" +version = "0.42.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -52,7 +52,7 @@ AbstractPPL = "0.11, 0.12, 0.13" Accessors = "0.1" AdvancedHMC = "0.8.3" AdvancedMH = "0.8.9" -AdvancedPS = "0.7" +AdvancedPS = "0.7.2" AdvancedVI = "0.6" BangBang = "0.4.2" Bijectors = "0.14, 0.15" @@ -65,7 +65,7 @@ DynamicHMC = "3.4" DynamicPPL = "0.39.1" EllipticalSliceSampling = "0.5, 1, 2" ForwardDiff = "0.10.3, 1" -Libtask = "0.9.3" +Libtask = "0.9.5" LinearAlgebra = "1" LogDensityProblems = "2" MCMCChains = "5, 6, 7" @@ -86,3 +86,6 @@ julia = "1.10.8" [extras] DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb" + +[sources] +AdvancedPS = {url = "https://github.com/TuringLang/AdvancedPS.jl", rev = "py/kwargs"} diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index 585d906cb..aff64a707 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -19,11 +19,13 @@ struct ParticleMCMCContext{R<:AbstractRNG} <: DynamicPPL.AbstractContext rng::R end -struct TracedModel{V<:AbstractVarInfo,M<:Model,E<:Tuple} <: AdvancedPS.AbstractGenericModel +struct TracedModel{V<:AbstractVarInfo,M<:Model,T<:Tuple,NT<:NamedTuple} <: + AdvancedPS.AbstractTuringLibtaskModel model::M varinfo::V - evaluator::E resample::Bool + fargs::T + kwargs::NT end function TracedModel( @@ -31,11 +33,8 @@ function TracedModel( ) model = DynamicPPL.setleafcontext(model, ParticleMCMCContext(rng)) args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo) - isempty(kwargs) || error( - "Particle sampling methods do not currently support models with keyword arguments.", - ) - evaluator = (model.f, args...) - return TracedModel(model, varinfo, evaluator, resample) + fargs = (model.f, args...) + return TracedModel(model, varinfo, resample, fargs, kwargs) end function AdvancedPS.advance!( @@ -53,16 +52,16 @@ function AdvancedPS.delete_retained!(trace::TracedModel) # In such a case, we need to ensure that when we continue sampling (i.e. # the next time we hit tilde_assume!!), we don't use the values in the # reference particle but rather sample new values. - return TracedModel(trace.model, trace.varinfo, trace.evaluator, true) + return TracedModel(trace.model, trace.varinfo, true, trace.fargs, trace.kwargs) end function AdvancedPS.reset_model(trace::TracedModel) return trace end -function Libtask.TapedTask(taped_globals, model::TracedModel; kwargs...) +function Libtask.TapedTask(taped_globals, model::TracedModel) return Libtask.TapedTask( - taped_globals, model.evaluator[1], model.evaluator[2:end]...; kwargs... + taped_globals, model.fargs[1], model.fargs[2:end]...; model.kwargs... ) end @@ -124,6 +123,7 @@ function AbstractMCMC.sample( ) check_model && _check_model(model, sampler) error_if_threadsafe_eval(model) + check_model_kwargs(model) # need to add on the `nparticles` keyword argument for `initialstep` to make use of return AbstractMCMC.mcmcsample( rng, @@ -138,6 +138,31 @@ function AbstractMCMC.sample( ) end +function check_model_kwargs(model::DynamicPPL.Model) + if !isempty(model.defaults) + # If there are keyword arguments, we need to check that the user has + # accounted for this by overloading `might_produce`. + might_produce = Libtask.might_produce(typeof((Core.kwcall, NamedTuple(), model.f))) + if !might_produce + io = IOBuffer() + ctx = IOContext(io, :color => true) + print( + ctx, + "Models with keyword arguments need special treatment to be used" * + " with particle methods. Please run:\n\n", + ) + printstyled( + ctx, + " using Libtask; Libtask.@might_produce($(model.f))"; + bold=true, + color=:blue, + ) + print(ctx, "\n\nbefore sampling from this model with particle methods.\n") + error(String(take!(io))) + end + end +end + function Turing.Inference.initialstep( rng::AbstractRNG, model::DynamicPPL.Model, @@ -146,6 +171,7 @@ function Turing.Inference.initialstep( nparticles::Int, kwargs..., ) + check_model_kwargs(model) # Reset the VarInfo. vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator()) vi = DynamicPPL.empty!!(vi) @@ -254,6 +280,7 @@ function Turing.Inference.initialstep( rng::AbstractRNG, model::DynamicPPL.Model, spl::PG, vi::AbstractVarInfo; kwargs... ) error_if_threadsafe_eval(model) + check_model_kwargs(model) vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator()) # Create a new set of particles @@ -495,7 +522,7 @@ end # details of the compiler, we set a bunch of methods as might_produce = true. We start with # adding to ProduceLogLikelihoodAccumulator, which is what calls `produce`, and go up the # call stack. -Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.accloglikelihood!!),Vararg}}) = true +Libtask.@might_produce(DynamicPPL.accloglikelihood!!) function Libtask.might_produce( ::Type{ <:Tuple{ @@ -507,15 +534,11 @@ function Libtask.might_produce( ) return true end -function Libtask.might_produce( - ::Type{<:Tuple{typeof(DynamicPPL.accumulate_observe!!),Vararg}} -) - return true -end -Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}}) = true -# Could the next two could have tighter type bounds on the arguments, namely a GibbsContext? +Libtask.@might_produce(DynamicPPL.accumulate_observe!!) +Libtask.@might_produce(DynamicPPL.tilde_observe!!) +# Could tilde_assume!! have tighter type bounds on the arguments, namely a GibbsContext? # That's the only thing that makes tilde_assume calls result in tilde_observe calls. -Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_assume!!),Vararg}}) = true -Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}}) = true -Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.init!!),Vararg}}) = true +Libtask.@might_produce(DynamicPPL.tilde_assume!!) +Libtask.@might_produce(DynamicPPL.evaluate!!) +Libtask.@might_produce(DynamicPPL.init!!) Libtask.might_produce(::Type{<:Tuple{<:DynamicPPL.Model,Vararg}}) = true diff --git a/test/Project.toml b/test/Project.toml index 70bf16fcf..96bda0824 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -43,7 +43,7 @@ ADTypes = "1" AbstractMCMC = "5.9" AbstractPPL = "0.11, 0.12, 0.13" AdvancedMH = "0.8.9" -AdvancedPS = "0.7" +AdvancedPS = "0.7.2" AdvancedVI = "0.6" Aqua = "0.8" BangBang = "0.4" @@ -77,3 +77,6 @@ StatsBase = "0.33, 0.34" StatsFuns = "0.9.5, 1" TimerOutputs = "0.5" julia = "1.10" + +[sources] +AdvancedPS = {url = "https://github.com/TuringLang/AdvancedPS.jl", rev = "py/kwargs"} diff --git a/test/mcmc/particle_mcmc.jl b/test/mcmc/particle_mcmc.jl index f816fc43a..0f3353319 100644 --- a/test/mcmc/particle_mcmc.jl +++ b/test/mcmc/particle_mcmc.jl @@ -4,6 +4,7 @@ using ..Models: gdemo_default using ..SamplerTestUtils: test_chain_logp_metadata using AdvancedPS: ResampleWithESSThreshold, resample_systematic, resample_multinomial using Distributions: Bernoulli, Beta, Gamma, Normal, sample +using Libtask: @might_produce using Random: Random using StableRNGs: StableRNG using Test: @test, @test_throws, @testset @@ -162,9 +163,23 @@ end end # https://github.com/TuringLang/Turing.jl/issues/2007 - @testset "keyword arguments not supported" begin - @model kwarg_demo(; x=2) = return x - @test_throws ErrorException sample(kwarg_demo(), PG(1), 10) + @testset "keyword argument handling" begin + @model function kwarg_demo(y; n=0.0) + x ~ Normal(n) + return y ~ Normal(x) + end + @test_throws "Models with keyword arguments" sample(kwarg_demo(5.0), PG(20), 10) + + # Check that enabling `might_produce` does allow sampling + @might_produce kwarg_demo + chain = sample(StableRNG(468), kwarg_demo(5.0), PG(20), 1000) + @test chain isa MCMCChains.Chains + @test mean(chain[:x]) ≈ 2.5 atol = 0.2 + + # Check that the keyword argument's value is respected + chain2 = sample(StableRNG(468), kwarg_demo(5.0; n=10.0), PG(20), 1000) + @test chain2 isa MCMCChains.Chains + @test mean(chain2[:x]) ≈ 7.5 atol = 0.2 end @testset "refuses to run threadsafe eval" begin