11# assume
2- """
3- tilde_assume(context::SamplingContext, right, vn, vi)
4-
5- Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
6- accumulate the log probability, and return the sampled value with a context associated
7- with a sampler.
8-
9- Falls back to
10- ```julia
11- tilde_assume(context.rng, context.context, context.sampler, right, vn, vi)
12- ```
13- """
14- function tilde_assume (context:: SamplingContext , right, vn, vi)
15- return tilde_assume (context. rng, context. context, context. sampler, right, vn, vi)
16- end
17-
182function tilde_assume (context:: AbstractContext , args... )
193 return tilde_assume (childcontext (context), args... )
204end
215function tilde_assume (:: DefaultContext , right, vn, vi)
22- return assume (right, vn, vi)
23- end
24-
25- function tilde_assume (rng:: Random.AbstractRNG , context:: AbstractContext , args... )
26- return tilde_assume (rng, childcontext (context), args... )
27- end
28- function tilde_assume (rng:: Random.AbstractRNG , :: DefaultContext , sampler, right, vn, vi)
29- return assume (rng, sampler, right, vn, vi)
30- end
31- function tilde_assume (:: Random.AbstractRNG , :: InitContext , sampler, right, vn, vi)
32- return error (
33- " Encountered SamplingContext->InitContext. This method will be removed in the next PR." ,
34- )
35- end
36- function tilde_assume (:: DefaultContext , sampler, right, vn, vi)
37- # same as above but no rng
38- return assume (Random. default_rng (), sampler, right, vn, vi)
6+ y = getindex_internal (vi, vn)
7+ f = from_maybe_linked_internal_transform (vi, vn, right)
8+ x, logjac = with_logabsdet_jacobian (f, y)
9+ vi = accumulate_assume!! (vi, x, logjac, vn, right)
10+ return x, vi
3911end
40-
4112function tilde_assume (context:: PrefixContext , right, vn, vi)
4213 # Note that we can't use something like this here:
4314 # new_vn = prefix(context, vn)
@@ -51,12 +22,6 @@ function tilde_assume(context::PrefixContext, right, vn, vi)
5122 new_vn, new_context = prefix_and_strip_contexts (context, vn)
5223 return tilde_assume (new_context, right, new_vn, vi)
5324end
54- function tilde_assume (
55- rng:: Random.AbstractRNG , context:: PrefixContext , sampler, right, vn, vi
56- )
57- new_vn, new_context = prefix_and_strip_contexts (context, vn)
58- return tilde_assume (rng, new_context, sampler, right, new_vn, vi)
59- end
6025
6126"""
6227 tilde_assume!!(context, right, vn, vi)
@@ -76,17 +41,6 @@ function tilde_assume!!(context, right, vn, vi)
7641end
7742
7843# observe
79- """
80- tilde_observe!!(context::SamplingContext, right, left, vi)
81-
82- Handle observed constants with a `context` associated with a sampler.
83-
84- Falls back to `tilde_observe!!(context.context, right, left, vi)`.
85- """
86- function tilde_observe!! (context:: SamplingContext , right, left, vn, vi)
87- return tilde_observe!! (context. context, right, left, vn, vi)
88- end
89-
9044function tilde_observe!! (context:: AbstractContext , right, left, vn, vi)
9145 return tilde_observe!! (childcontext (context), right, left, vn, vi)
9246end
@@ -119,59 +73,3 @@ function tilde_observe!!(::DefaultContext, right, left, vn, vi)
11973 vi = accumulate_observe!! (vi, right, left, vn)
12074 return left, vi
12175end
122-
123- function assume (:: Random.AbstractRNG , spl:: Sampler , dist)
124- return error (" DynamicPPL.assume: unmanaged inference algorithm: $(typeof (spl)) " )
125- end
126-
127- # fallback without sampler
128- function assume (dist:: Distribution , vn:: VarName , vi)
129- y = getindex_internal (vi, vn)
130- f = from_maybe_linked_internal_transform (vi, vn, dist)
131- x, logjac = with_logabsdet_jacobian (f, y)
132- vi = accumulate_assume!! (vi, x, logjac, vn, dist)
133- return x, vi
134- end
135-
136- # TODO : Remove this thing.
137- # SampleFromPrior and SampleFromUniform
138- function assume (
139- rng:: Random.AbstractRNG ,
140- sampler:: Union{SampleFromPrior,SampleFromUniform} ,
141- dist:: Distribution ,
142- vn:: VarName ,
143- vi:: VarInfoOrThreadSafeVarInfo ,
144- )
145- if haskey (vi, vn)
146- # Always overwrite the parameters with new ones for `SampleFromUniform`.
147- if sampler isa SampleFromUniform || is_flagged (vi, vn, " del" )
148- # TODO (mhauru) Is it important to unset the flag here? The `true` allows us
149- # to ignore the fact that for VarNamedVector this does nothing, but I'm unsure
150- # if that's okay.
151- unset_flag! (vi, vn, " del" , true )
152- r = init (rng, dist, sampler)
153- f = to_maybe_linked_internal_transform (vi, vn, dist)
154- # TODO (mhauru) This should probably be call a function called setindex_internal!
155- vi = BangBang. setindex!! (vi, f (r), vn)
156- setorder! (vi, vn, get_num_produce (vi))
157- else
158- # Otherwise we just extract it.
159- r = vi[vn, dist]
160- end
161- else
162- r = init (rng, dist, sampler)
163- if istrans (vi)
164- f = to_linked_internal_transform (vi, vn, dist)
165- vi = push!! (vi, vn, f (r), dist)
166- # By default `push!!` sets the transformed flag to `false`.
167- vi = settrans!! (vi, true , vn)
168- else
169- vi = push!! (vi, vn, r, dist)
170- end
171- end
172-
173- # HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct.
174- logjac = logabsdetjac (istrans (vi, vn) ? link_transform (dist) : identity, r)
175- vi = accumulate_assume!! (vi, r, - logjac, vn, dist)
176- return r, vi
177- end
0 commit comments