Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
# 0.42.1

SMC and PG can now be used for models with keyword arguments, albeit with one requirement: the user must mark the model function as being able to produce.
For example, if the model is

```julia
@model foo(x; y) = a ~ Normal(x, y)
```

then before samping from this with SMC or PG, you will have to run

```julia
using Libtask;
Libtask.@might_produce(foo);
```

# 0.42.0

## DynamicPPL 0.39
Expand Down
9 changes: 6 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.42.0"
version = "0.42.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -52,7 +52,7 @@ AbstractPPL = "0.11, 0.12, 0.13"
Accessors = "0.1"
AdvancedHMC = "0.8.3"
AdvancedMH = "0.8.9"
AdvancedPS = "0.7"
AdvancedPS = "0.7.2"
AdvancedVI = "0.6"
BangBang = "0.4.2"
Bijectors = "0.14, 0.15"
Expand All @@ -65,7 +65,7 @@ DynamicHMC = "3.4"
DynamicPPL = "0.39.1"
EllipticalSliceSampling = "0.5, 1, 2"
ForwardDiff = "0.10.3, 1"
Libtask = "0.9.3"
Libtask = "0.9.5"
LinearAlgebra = "1"
LogDensityProblems = "2"
MCMCChains = "5, 6, 7"
Expand All @@ -86,3 +86,6 @@ julia = "1.10.8"

[extras]
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"

[sources]
AdvancedPS = {url = "https://github.com/TuringLang/AdvancedPS.jl", rev = "py/kwargs"}
Comment on lines +90 to +91
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will have to be removed pending release of TuringLang/AdvancedPS.jl#118

65 changes: 44 additions & 21 deletions src/mcmc/particle_mcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,22 @@ struct ParticleMCMCContext{R<:AbstractRNG} <: DynamicPPL.AbstractContext
rng::R
end

struct TracedModel{V<:AbstractVarInfo,M<:Model,E<:Tuple} <: AdvancedPS.AbstractGenericModel
struct TracedModel{V<:AbstractVarInfo,M<:Model,T<:Tuple,NT<:NamedTuple} <:
AdvancedPS.AbstractTuringLibtaskModel
model::M
varinfo::V
evaluator::E
resample::Bool
fargs::T
kwargs::NT
end

function TracedModel(
model::Model, varinfo::AbstractVarInfo, rng::Random.AbstractRNG, resample::Bool
)
model = DynamicPPL.setleafcontext(model, ParticleMCMCContext(rng))
args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo)
isempty(kwargs) || error(
"Particle sampling methods do not currently support models with keyword arguments.",
)
evaluator = (model.f, args...)
return TracedModel(model, varinfo, evaluator, resample)
fargs = (model.f, args...)
return TracedModel(model, varinfo, resample, fargs, kwargs)
end

function AdvancedPS.advance!(
Expand All @@ -53,16 +52,16 @@ function AdvancedPS.delete_retained!(trace::TracedModel)
# In such a case, we need to ensure that when we continue sampling (i.e.
# the next time we hit tilde_assume!!), we don't use the values in the
# reference particle but rather sample new values.
return TracedModel(trace.model, trace.varinfo, trace.evaluator, true)
return TracedModel(trace.model, trace.varinfo, true, trace.fargs, trace.kwargs)
end

function AdvancedPS.reset_model(trace::TracedModel)
return trace
end

function Libtask.TapedTask(taped_globals, model::TracedModel; kwargs...)
function Libtask.TapedTask(taped_globals, model::TracedModel)
return Libtask.TapedTask(
taped_globals, model.evaluator[1], model.evaluator[2:end]...; kwargs...
taped_globals, model.fargs[1], model.fargs[2:end]...; model.kwargs...
)
end

Expand Down Expand Up @@ -124,6 +123,7 @@ function AbstractMCMC.sample(
)
check_model && _check_model(model, sampler)
error_if_threadsafe_eval(model)
check_model_kwargs(model)
# need to add on the `nparticles` keyword argument for `initialstep` to make use of
return AbstractMCMC.mcmcsample(
rng,
Expand All @@ -138,6 +138,31 @@ function AbstractMCMC.sample(
)
end

function check_model_kwargs(model::DynamicPPL.Model)
if !isempty(model.defaults)
# If there are keyword arguments, we need to check that the user has
# accounted for this by overloading `might_produce`.
might_produce = Libtask.might_produce(typeof((Core.kwcall, NamedTuple(), model.f)))
if !might_produce
io = IOBuffer()
ctx = IOContext(io, :color => true)
print(
ctx,
"Models with keyword arguments need special treatment to be used" *
" with particle methods. Please run:\n\n",
)
printstyled(
ctx,
" using Libtask; Libtask.@might_produce($(model.f))";
Copy link
Member Author

@penelopeysm penelopeysm Aug 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we could/should reexport Libtask.@might_produce or something similar from Turing, to make it easier?

bold=true,
color=:blue,
)
print(ctx, "\n\nbefore sampling from this model with particle methods.\n")
error(String(take!(io)))
end
end
end

function Turing.Inference.initialstep(
rng::AbstractRNG,
model::DynamicPPL.Model,
Expand All @@ -146,6 +171,7 @@ function Turing.Inference.initialstep(
nparticles::Int,
kwargs...,
)
check_model_kwargs(model)
# Reset the VarInfo.
vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator())
vi = DynamicPPL.empty!!(vi)
Expand Down Expand Up @@ -254,6 +280,7 @@ function Turing.Inference.initialstep(
rng::AbstractRNG, model::DynamicPPL.Model, spl::PG, vi::AbstractVarInfo; kwargs...
)
error_if_threadsafe_eval(model)
check_model_kwargs(model)
vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator())

# Create a new set of particles
Expand Down Expand Up @@ -495,7 +522,7 @@ end
# details of the compiler, we set a bunch of methods as might_produce = true. We start with
# adding to ProduceLogLikelihoodAccumulator, which is what calls `produce`, and go up the
# call stack.
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.accloglikelihood!!),Vararg}}) = true
Libtask.@might_produce(DynamicPPL.accloglikelihood!!)
function Libtask.might_produce(
::Type{
<:Tuple{
Expand All @@ -507,15 +534,11 @@ function Libtask.might_produce(
)
return true
end
function Libtask.might_produce(
::Type{<:Tuple{typeof(DynamicPPL.accumulate_observe!!),Vararg}}
)
return true
end
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}}) = true
# Could the next two could have tighter type bounds on the arguments, namely a GibbsContext?
Libtask.@might_produce(DynamicPPL.accumulate_observe!!)
Libtask.@might_produce(DynamicPPL.tilde_observe!!)
# Could tilde_assume!! have tighter type bounds on the arguments, namely a GibbsContext?
# That's the only thing that makes tilde_assume calls result in tilde_observe calls.
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_assume!!),Vararg}}) = true
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}}) = true
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.init!!),Vararg}}) = true
Libtask.@might_produce(DynamicPPL.tilde_assume!!)
Libtask.@might_produce(DynamicPPL.evaluate!!)
Libtask.@might_produce(DynamicPPL.init!!)
Libtask.might_produce(::Type{<:Tuple{<:DynamicPPL.Model,Vararg}}) = true
5 changes: 4 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ ADTypes = "1"
AbstractMCMC = "5.9"
AbstractPPL = "0.11, 0.12, 0.13"
AdvancedMH = "0.8.9"
AdvancedPS = "0.7"
AdvancedPS = "0.7.2"
AdvancedVI = "0.6"
Aqua = "0.8"
BangBang = "0.4"
Expand Down Expand Up @@ -77,3 +77,6 @@ StatsBase = "0.33, 0.34"
StatsFuns = "0.9.5, 1"
TimerOutputs = "0.5"
julia = "1.10"

[sources]
AdvancedPS = {url = "https://github.com/TuringLang/AdvancedPS.jl", rev = "py/kwargs"}
Comment on lines +81 to +82
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Likewise

21 changes: 18 additions & 3 deletions test/mcmc/particle_mcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using ..Models: gdemo_default
using ..SamplerTestUtils: test_chain_logp_metadata
using AdvancedPS: ResampleWithESSThreshold, resample_systematic, resample_multinomial
using Distributions: Bernoulli, Beta, Gamma, Normal, sample
using Libtask: @might_produce
using Random: Random
using StableRNGs: StableRNG
using Test: @test, @test_throws, @testset
Expand Down Expand Up @@ -162,9 +163,23 @@ end
end

# https://github.com/TuringLang/Turing.jl/issues/2007
@testset "keyword arguments not supported" begin
@model kwarg_demo(; x=2) = return x
@test_throws ErrorException sample(kwarg_demo(), PG(1), 10)
@testset "keyword argument handling" begin
@model function kwarg_demo(y; n=0.0)
x ~ Normal(n)
return y ~ Normal(x)
end
@test_throws "Models with keyword arguments" sample(kwarg_demo(5.0), PG(20), 10)

# Check that enabling `might_produce` does allow sampling
@might_produce kwarg_demo
chain = sample(StableRNG(468), kwarg_demo(5.0), PG(20), 1000)
@test chain isa MCMCChains.Chains
@test mean(chain[:x]) ≈ 2.5 atol = 0.2

# Check that the keyword argument's value is respected
chain2 = sample(StableRNG(468), kwarg_demo(5.0; n=10.0), PG(20), 1000)
@test chain2 isa MCMCChains.Chains
@test mean(chain2[:x]) ≈ 7.5 atol = 0.2
end

@testset "refuses to run threadsafe eval" begin
Expand Down
Loading