Skip to content

Commit 28c9e2e

Browse files
committed
Add a different struct that can pass kwargs on to Libtask
1 parent e100352 commit 28c9e2e

File tree

3 files changed

+30
-3
lines changed

3 files changed

+30
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "AdvancedPS"
22
uuid = "576499cb-2369-40b2-a588-c64705576edc"
33
authors = ["TuringLang"]
4-
version = "0.7.1"
4+
version = "0.7.2"
55

66
[deps]
77
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

ext/AdvancedPSLibtaskExt.jl

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,30 @@ State wrapper to hold `Libtask.CTask` model initiated from `f`.
3535
function AdvancedPS.LibtaskModel(
3636
f::AdvancedPS.AbstractGenericModel, rng::Random.AbstractRNG, args...
3737
) # Changed the API, need to take care of the RNG properly
38-
return AdvancedPS.LibtaskModel(f, Libtask.TapedTask(TapedGlobals(rng), f, args...))
38+
return AdvancedPS.LibtaskModel(
39+
f, Libtask.TapedTask(TapedGlobals(rng), f, args...))
40+
)
41+
end
42+
# TODO: Upstream this to Turing
43+
function AdvancedPS.LibtaskModel(
44+
f::AdvancedPS.AbstractTuringLibtaskModel, rng::Random.AbstractRNG
45+
)
46+
return AdvancedPS.LibtaskModel(
47+
f, Libtask.TapedTask(TapedGlobals(rng), f.fargs...; f.kwargs...)
48+
)
49+
end
50+
51+
function to_tapedtask(
52+
newf::AdvancedPS.AbstractGenericModel, trace::LibtaskTrace, rng::Random.AbstractRNG
53+
)
54+
return Libtask.TapedTask(TapedGlobals(rng, get_other_global(trace)), newf)
55+
end
56+
function to_tapedtask(
57+
newf::AdvancedPS.AbstractTuringLibtaskModel, trace::LibtaskTrace, rng::Random.AbstractRNG
58+
)
59+
return Libtask.TapedTask(
60+
TapedGlobals(rng, get_other_global(trace)), newf.fargs...; newf.kwargs...
61+
)
3962
end
4063

4164
"""
@@ -114,7 +137,7 @@ function AdvancedPS.forkr(trace::LibtaskTrace)
114137
newf = AdvancedPS.reset_model(trace.model.f)
115138
Random123.set_counter!(rng, 1)
116139

117-
ctask = Libtask.TapedTask(TapedGlobals(rng, get_other_global(trace)), newf)
140+
ctask = to_tapedtask(newf, trace, rng)
118141
new_tapedmodel = AdvancedPS.LibtaskModel(newf, ctask)
119142

120143
# add backward reference

src/AdvancedPS.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ abstract type AbstractParticleSampler <: AbstractMCMC.AbstractSampler end
1616
abstract type AbstractStateSpaceModel <: AbstractParticleModel end
1717
abstract type AbstractGenericModel <: AbstractParticleModel end
1818

19+
# TODO(penelopeysm): This should be upstreamed to Turing together with anything that is
20+
# Turing-specific in LibtaskExt.
21+
abstract type AbstractTuringLibtaskModel <: AbstractGenericModel end
22+
1923
include("resampling.jl")
2024
include("rng.jl")
2125
include("model.jl")

0 commit comments

Comments
 (0)