Skip to content

Commit 051012c

Browse files
committed
Add a different struct that can pass kwargs on to Libtask
1 parent e100352 commit 051012c

File tree

3 files changed

+32
-5
lines changed

3 files changed

+32
-5
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "AdvancedPS"
22
uuid = "576499cb-2369-40b2-a588-c64705576edc"
33
authors = ["TuringLang"]
4-
version = "0.7.1"
4+
version = "0.7.2"
55

66
[deps]
77
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

ext/AdvancedPSLibtaskExt.jl

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,32 @@ State wrapper to hold `Libtask.CTask` model initiated from `f`.
3535
function AdvancedPS.LibtaskModel(
3636
f::AdvancedPS.AbstractGenericModel, rng::Random.AbstractRNG, args...
3737
) # Changed the API, need to take care of the RNG properly
38-
return AdvancedPS.LibtaskModel(f, Libtask.TapedTask(TapedGlobals(rng), f, args...))
38+
return AdvancedPS.LibtaskModel(
39+
f, Libtask.TapedTask(TapedGlobals(rng), f, args...)
40+
)
41+
end
42+
# TODO: Upstream this to Turing
43+
function AdvancedPS.LibtaskModel(
44+
f::AdvancedPS.AbstractTuringLibtaskModel, rng::Random.AbstractRNG
45+
)
46+
return AdvancedPS.LibtaskModel(
47+
f, Libtask.TapedTask(TapedGlobals(rng), f.fargs...; f.kwargs...)
48+
)
49+
end
50+
51+
const LibtaskTrace{R} = AdvancedPS.Trace{<:AdvancedPS.LibtaskModel,R}
52+
53+
function to_tapedtask(
54+
newf::AdvancedPS.AbstractGenericModel, trace::LibtaskTrace, rng::Random.AbstractRNG
55+
)
56+
return Libtask.TapedTask(TapedGlobals(rng, get_other_global(trace)), newf)
57+
end
58+
function to_tapedtask(
59+
newf::AdvancedPS.AbstractTuringLibtaskModel, trace::LibtaskTrace, rng::Random.AbstractRNG
60+
)
61+
return Libtask.TapedTask(
62+
TapedGlobals(rng, get_other_global(trace)), newf.fargs...; newf.kwargs...
63+
)
3964
end
4065

4166
"""
@@ -47,8 +72,6 @@ function Base.copy(model::AdvancedPS.LibtaskModel)
4772
return AdvancedPS.LibtaskModel(deepcopy(model.f), copy(model.ctask))
4873
end
4974

50-
const LibtaskTrace{R} = AdvancedPS.Trace{<:AdvancedPS.LibtaskModel,R}
51-
5275
function Base.copy(trace::LibtaskTrace)
5376
newtrace = AdvancedPS.Trace(copy(trace.model), deepcopy(trace.rng))
5477
set_other_global!(newtrace, newtrace)
@@ -114,7 +137,7 @@ function AdvancedPS.forkr(trace::LibtaskTrace)
114137
newf = AdvancedPS.reset_model(trace.model.f)
115138
Random123.set_counter!(rng, 1)
116139

117-
ctask = Libtask.TapedTask(TapedGlobals(rng, get_other_global(trace)), newf)
140+
ctask = to_tapedtask(newf, trace, rng)
118141
new_tapedmodel = AdvancedPS.LibtaskModel(newf, ctask)
119142

120143
# add backward reference

src/AdvancedPS.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ abstract type AbstractParticleSampler <: AbstractMCMC.AbstractSampler end
1616
abstract type AbstractStateSpaceModel <: AbstractParticleModel end
1717
abstract type AbstractGenericModel <: AbstractParticleModel end
1818

19+
# TODO(penelopeysm): This should be upstreamed to Turing together with anything that is
20+
# Turing-specific in LibtaskExt.
21+
abstract type AbstractTuringLibtaskModel <: AbstractGenericModel end
22+
1923
include("resampling.jl")
2024
include("rng.jl")
2125
include("model.jl")

0 commit comments

Comments
 (0)