1616 using .. Libtask: Libtask
1717end
1818
19+ # In Libtask.TapedTask.taped_globals, this extension sometimes needs to store an RNG,
20+ # and sometimes both an RNG and other information. In Turing.jl the other information
21+ # is a VarInfo. This struct puts those in a single struct. Note the abstract type of
22+ # the second field. This is okay, because `get_taped_globals` needs a type assertion anyway.
23+ struct TapedGlobals{RngType}
24+ rng:: RngType
25+ other:: Any
26+ end
27+
28+ TapedGlobals (rng:: Random.AbstractRNG ) = TapedGlobals (rng, nothing )
29+
1930"""
2031 LibtaskModel{F}
2132
@@ -24,12 +35,7 @@ State wrapper to hold `Libtask.CTask` model initiated from `f`.
2435function AdvancedPS. LibtaskModel (
2536 f:: AdvancedPS.AbstractGenericModel , rng:: Random.AbstractRNG , args...
2637) # Changed the API, need to take care of the RNG properly
27- return AdvancedPS. LibtaskModel (
28- f,
29- Libtask. TapedTask (
30- f, rng, args... ; deepcopy_types= Union{AdvancedPS. TracedRNG,typeof (f)}
31- ),
32- )
38+ return AdvancedPS. LibtaskModel (f, Libtask. TapedTask (TapedGlobals (rng), f, args... ))
3339end
3440
3541"""
4349
4450const LibtaskTrace{R} = AdvancedPS. Trace{<: AdvancedPS.LibtaskModel ,R}
4551
46- function AdvancedPS . Trace (
47- model :: AdvancedPS.AbstractGenericModel , rng :: Random.AbstractRNG , args ...
48- )
49- return AdvancedPS . Trace (AdvancedPS . LibtaskModel (model, rng, args ... ), rng)
52+ function Base . copy (trace :: LibtaskTrace )
53+ newtrace = AdvancedPS. Trace ( copy (trace . model), deepcopy (trace . rng))
54+ set_other_global! (newtrace, newtrace )
55+ return newtrace
5056end
5157
52- # step to the next observe statement and
53- # return the log probability of the transition (or nothing if done)
54- function AdvancedPS. advance! (t:: LibtaskTrace , isref:: Bool = false )
55- isref ? AdvancedPS. load_state! (t. rng) : AdvancedPS. save_state! (t. rng)
56- AdvancedPS. inc_counter! (t. rng)
57-
58- # Move to next step
59- return Libtask. consume (t. model. ctask)
58+ """ Get the RNG from a `LibtaskTrace`."""
59+ function get_rng (trace:: LibtaskTrace )
60+ return trace. model. ctask. taped_globals. rng
6061end
6162
62- # create a backward reference in task_local_storage
63- function AdvancedPS. addreference! (task:: Task , trace:: LibtaskTrace )
64- if task. storage === nothing
65- task. storage = IdDict ()
66- end
67- task. storage[:__trace ] = trace
63+ """ Set the RNG for a `LibtaskTrace`."""
64+ function set_rng! (trace:: LibtaskTrace , rng:: Random.AbstractRNG )
65+ other = get_other_global (trace)
66+ Libtask. set_taped_globals! (trace. model. ctask, TapedGlobals (rng, other))
67+ trace. rng = rng
68+ return trace
69+ end
6870
69- return task
71+ """ Set the other "taped global" variable of a `LibtaskTrace`, other than the RNG."""
72+ function set_other_global! (trace:: LibtaskTrace , other)
73+ rng = get_rng (trace)
74+ Libtask. set_taped_globals! (trace. model. ctask, TapedGlobals (rng, other))
75+ return trace
7076end
7177
72- function AdvancedPS. update_rng! (trace:: LibtaskTrace )
73- rng, = trace. model. ctask. args
74- trace. rng = rng
78+ """ Get the other "taped global" variable of a `LibtaskTrace`, other than the RNG."""
79+ get_other_global (trace:: LibtaskTrace ) = trace. model. ctask. taped_globals. other
80+
81+ function AdvancedPS. Trace (
82+ model:: AdvancedPS.AbstractGenericModel , rng:: Random.AbstractRNG , args...
83+ )
84+ trace = AdvancedPS. Trace (AdvancedPS. LibtaskModel (model, rng, args... ), rng)
85+ # Set a backreference so that the TapedTask in `trace` stores the `trace` itself in its
86+ # taped globals.
87+ set_other_global! (trace, trace)
7588 return trace
7689end
7790
91+ # step to the next observe statement and
92+ # return the log probability of the transition (or nothing if done)
93+ function AdvancedPS. advance! (trace:: LibtaskTrace , isref:: Bool = false )
94+ rng = get_rng (trace)
95+ isref ? AdvancedPS. load_state! (rng) : AdvancedPS. save_state! (rng)
96+ AdvancedPS. inc_counter! (rng)
97+ # Move to next step
98+ return Libtask. consume (trace. model. ctask)
99+ end
100+
78101# Task copying version of fork for Trace.
79102function AdvancedPS. fork (trace:: LibtaskTrace , isref:: Bool = false )
80103 newtrace = copy (trace)
81- AdvancedPS . update_rng ! (newtrace)
104+ set_rng ! (newtrace, deepcopy ( get_rng (newtrace)) )
82105 isref && AdvancedPS. delete_retained! (newtrace. model. f)
83106 isref && delete_seeds! (newtrace)
84-
85- # add backward reference
86- AdvancedPS. addreference! (newtrace. model. ctask. task, newtrace)
87107 return newtrace
88108end
89109
90110# PG requires keeping all randomness for the reference particle
91111# Create new task and copy randomness
92112function AdvancedPS. forkr (trace:: LibtaskTrace )
113+ rng = get_rng (trace)
93114 newf = AdvancedPS. reset_model (trace. model. f)
94- Random123. set_counter! (trace . rng, 1 )
115+ Random123. set_counter! (rng, 1 )
95116
96- ctask = Libtask. TapedTask (
97- newf, trace. rng; deepcopy_types= Union{AdvancedPS. TracedRNG,typeof (trace. model. f)}
98- )
117+ ctask = Libtask. TapedTask (TapedGlobals (rng, get_other_global (trace)), newf)
99118 new_tapedmodel = AdvancedPS. LibtaskModel (newf, ctask)
100119
101120 # add backward reference
102- newtrace = AdvancedPS. Trace (new_tapedmodel, trace. rng)
103- AdvancedPS. addreference! (ctask. task, newtrace)
121+ newtrace = AdvancedPS. Trace (new_tapedmodel, rng)
104122 AdvancedPS. gen_refseed! (newtrace)
105123 return newtrace
106124end
@@ -113,7 +131,8 @@ AdvancedPS.update_ref!(::LibtaskTrace) = nothing
113131Observe sample `x` from distribution `dist` and yield its log-likelihood value.
114132"""
115133function AdvancedPS. observe (dist:: Distributions.Distribution , x)
116- return Libtask. produce (Distributions. loglikelihood (dist, x))
134+ Libtask. produce (Distributions. loglikelihood (dist, x))
135+ return nothing
117136end
118137
119138"""
@@ -138,7 +157,6 @@ function AbstractMCMC.step(
138157 else
139158 trng = AdvancedPS. TracedRNG ()
140159 trace = AdvancedPS. Trace (deepcopy (model), trng)
141- AdvancedPS. addreference! (trace. model. ctask. task, trace) # TODO : Do we need it here ?
142160 trace
143161 end
144162 end
@@ -153,8 +171,7 @@ function AbstractMCMC.step(
153171 newtrajectory = rand (rng, particles)
154172
155173 replayed = AdvancedPS. replay (newtrajectory)
156- return AdvancedPS. PGSample (replayed. model. f, logevidence),
157- AdvancedPS. PGState (newtrajectory)
174+ return AdvancedPS. PGSample (replayed. model. f, logevidence), AdvancedPS. PGState (replayed)
158175end
159176
160177function AbstractMCMC. sample (
@@ -176,7 +193,6 @@ function AbstractMCMC.sample(
176193 traces = map (1 : (sampler. nparticles)) do i
177194 trng = AdvancedPS. TracedRNG ()
178195 trace = AdvancedPS. Trace (deepcopy (model), trng)
179- AdvancedPS. addreference! (trace. model. ctask. task, trace) # Do we need it here ?
180196 trace
181197 end
182198
0 commit comments