@@ -35,7 +35,32 @@ State wrapper to hold `Libtask.CTask` model initiated from `f`.
3535function AdvancedPS. LibtaskModel (
3636 f:: AdvancedPS.AbstractGenericModel , rng:: Random.AbstractRNG , args...
3737) # Changed the API, need to take care of the RNG properly
38- return AdvancedPS. LibtaskModel (f, Libtask. TapedTask (TapedGlobals (rng), f, args... ))
38+ return AdvancedPS. LibtaskModel (
39+ f, Libtask. TapedTask (TapedGlobals (rng), f, args... )
40+ )
41+ end
42+ # TODO : Upstream this to Turing
43+ function AdvancedPS. LibtaskModel (
44+ f:: AdvancedPS.AbstractTuringLibtaskModel , rng:: Random.AbstractRNG
45+ )
46+ return AdvancedPS. LibtaskModel (
47+ f, Libtask. TapedTask (TapedGlobals (rng), f. fargs... ; f. kwargs... )
48+ )
49+ end
50+
51+ const LibtaskTrace{R} = AdvancedPS. Trace{<: AdvancedPS.LibtaskModel ,R}
52+
53+ function to_tapedtask (
54+ newf:: AdvancedPS.AbstractGenericModel , trace:: LibtaskTrace , rng:: Random.AbstractRNG
55+ )
56+ return Libtask. TapedTask (TapedGlobals (rng, get_other_global (trace)), newf)
57+ end
58+ function to_tapedtask (
59+ newf:: AdvancedPS.AbstractTuringLibtaskModel , trace:: LibtaskTrace , rng:: Random.AbstractRNG
60+ )
61+ return Libtask. TapedTask (
62+ TapedGlobals (rng, get_other_global (trace)), newf. fargs... ; newf. kwargs...
63+ )
3964end
4065
4166"""
@@ -47,8 +72,6 @@ function Base.copy(model::AdvancedPS.LibtaskModel)
4772 return AdvancedPS. LibtaskModel (deepcopy (model. f), copy (model. ctask))
4873end
4974
50- const LibtaskTrace{R} = AdvancedPS. Trace{<: AdvancedPS.LibtaskModel ,R}
51-
5275function Base. copy (trace:: LibtaskTrace )
5376 newtrace = AdvancedPS. Trace (copy (trace. model), deepcopy (trace. rng))
5477 set_other_global! (newtrace, newtrace)
@@ -114,7 +137,7 @@ function AdvancedPS.forkr(trace::LibtaskTrace)
114137 newf = AdvancedPS. reset_model (trace. model. f)
115138 Random123. set_counter! (rng, 1 )
116139
117- ctask = Libtask . TapedTask ( TapedGlobals (rng, get_other_global ( trace)), newf )
140+ ctask = to_tapedtask (newf, trace, rng )
118141 new_tapedmodel = AdvancedPS. LibtaskModel (newf, ctask)
119142
120143 # add backward reference
0 commit comments