Skip to content

Commit fd8fa3f

Browse files
committed
Pass kwargs on to Libtask
1 parent e100352 commit fd8fa3f

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "AdvancedPS"
22
uuid = "576499cb-2369-40b2-a588-c64705576edc"
33
authors = ["TuringLang"]
4-
version = "0.7.1"
4+
version = "0.7.2"
55

66
[deps]
77
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

ext/AdvancedPSLibtaskExt.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,11 @@ TapedGlobals(rng::Random.AbstractRNG) = TapedGlobals(rng, nothing)
3333
State wrapper to hold `Libtask.CTask` model initiated from `f`.
3434
"""
3535
function AdvancedPS.LibtaskModel(
36-
f::AdvancedPS.AbstractGenericModel, rng::Random.AbstractRNG, args...
36+
f::AdvancedPS.AbstractGenericModel, rng::Random.AbstractRNG
3737
) # Changed the API, need to take care of the RNG properly
38-
return AdvancedPS.LibtaskModel(f, Libtask.TapedTask(TapedGlobals(rng), f, args...))
38+
return AdvancedPS.LibtaskModel(
39+
f, Libtask.TapedTask(TapedGlobals(rng), f.fargs...; f.kwargs...)
40+
)
3941
end
4042

4143
"""
@@ -114,7 +116,9 @@ function AdvancedPS.forkr(trace::LibtaskTrace)
114116
newf = AdvancedPS.reset_model(trace.model.f)
115117
Random123.set_counter!(rng, 1)
116118

117-
ctask = Libtask.TapedTask(TapedGlobals(rng, get_other_global(trace)), newf)
119+
ctask = Libtask.TapedTask(
120+
TapedGlobals(rng, get_other_global(trace)), newf.fargs...; newf.kwargs...
121+
)
118122
new_tapedmodel = AdvancedPS.LibtaskModel(newf, ctask)
119123

120124
# add backward reference

0 commit comments

Comments
 (0)