diff --git a/Project.toml b/Project.toml index 491eebe..28a5155 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "AdvancedPS" uuid = "576499cb-2369-40b2-a588-c64705576edc" authors = ["TuringLang"] -version = "0.7.1" +version = "0.7.2" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/ext/AdvancedPSLibtaskExt.jl b/ext/AdvancedPSLibtaskExt.jl index d3ba9e0..e85f583 100644 --- a/ext/AdvancedPSLibtaskExt.jl +++ b/ext/AdvancedPSLibtaskExt.jl @@ -37,6 +37,31 @@ function AdvancedPS.LibtaskModel( ) # Changed the API, need to take care of the RNG properly return AdvancedPS.LibtaskModel(f, Libtask.TapedTask(TapedGlobals(rng), f, args...)) end +# TODO: Upstream this to Turing +function AdvancedPS.LibtaskModel( + f::AdvancedPS.AbstractTuringLibtaskModel, rng::Random.AbstractRNG +) + return AdvancedPS.LibtaskModel( + f, Libtask.TapedTask(TapedGlobals(rng), f.fargs...; f.kwargs...) + ) +end + +const LibtaskTrace{R} = AdvancedPS.Trace{<:AdvancedPS.LibtaskModel,R} + +function to_tapedtask( + newf::AdvancedPS.AbstractGenericModel, trace::LibtaskTrace, rng::Random.AbstractRNG +) + return Libtask.TapedTask(TapedGlobals(rng, get_other_global(trace)), newf) +end +function to_tapedtask( + newf::AdvancedPS.AbstractTuringLibtaskModel, + trace::LibtaskTrace, + rng::Random.AbstractRNG, +) + return Libtask.TapedTask( + TapedGlobals(rng, get_other_global(trace)), newf.fargs...; newf.kwargs... + ) +end """ copy(model::AdvancedPS.LibtaskModel) @@ -47,8 +72,6 @@ function Base.copy(model::AdvancedPS.LibtaskModel) return AdvancedPS.LibtaskModel(deepcopy(model.f), copy(model.ctask)) end -const LibtaskTrace{R} = AdvancedPS.Trace{<:AdvancedPS.LibtaskModel,R} - function Base.copy(trace::LibtaskTrace) newtrace = AdvancedPS.Trace(copy(trace.model), deepcopy(trace.rng)) set_other_global!(newtrace, newtrace) @@ -114,7 +137,7 @@ function AdvancedPS.forkr(trace::LibtaskTrace) newf = AdvancedPS.reset_model(trace.model.f) Random123.set_counter!(rng, 1) - ctask = Libtask.TapedTask(TapedGlobals(rng, get_other_global(trace)), newf) + ctask = to_tapedtask(newf, trace, rng) new_tapedmodel = AdvancedPS.LibtaskModel(newf, ctask) # add backward reference diff --git a/src/AdvancedPS.jl b/src/AdvancedPS.jl index faa673e..93254f9 100644 --- a/src/AdvancedPS.jl +++ b/src/AdvancedPS.jl @@ -16,6 +16,10 @@ abstract type AbstractParticleSampler <: AbstractMCMC.AbstractSampler end abstract type AbstractStateSpaceModel <: AbstractParticleModel end abstract type AbstractGenericModel <: AbstractParticleModel end +# TODO(penelopeysm): This should be upstreamed to Turing together with anything that is +# Turing-specific in LibtaskExt. +abstract type AbstractTuringLibtaskModel <: AbstractGenericModel end + include("resampling.jl") include("rng.jl") include("model.jl")