@@ -41,52 +41,102 @@ LogLikelihoodAccumulator{T}() where {T<:Real} = LogLikelihoodAccumulator(zero(T)
4141LogLikelihoodAccumulator () = LogLikelihoodAccumulator {LogProbType} ()
4242
4343"""
44- NumProduceAccumulator {T} <: AbstractAccumulator
44+ VariableOrderAccumulator {T} <: AbstractAccumulator
4545
46- An accumulator that tracks the number of observations during model execution.
46+ An accumulator that tracks the order of variables in a `VarInfo`.
47+
48+ This doesn't track the full ordering, but rather how many observations have taken place
49+ before the assume statement for each variable. This is needed for particle methods, where
50+ the model is segmented into parts by each observation, and we need to know which part each
51+ assume statement is in.
4752
4853# Fields
4954$(TYPEDFIELDS)
5055"""
51- struct NumProduceAccumulator{T <: Integer } <: AbstractAccumulator
56+ struct VariableOrderAccumulator{Eltype <: Integer ,VNType <: VarName } <: AbstractAccumulator
5257 " the number of observations"
53- num:: T
58+ num_produce:: Eltype
59+ " mapping of variable names to their order in the model"
60+ order:: Dict{VNType,Eltype}
5461end
5562
5663"""
57- NumProduceAccumulator {T<:Integer}()
64+ VariableOrderAccumulator {T<:Integer}(n=zero(T) )
5865
59- Create a new `NumProduceAccumulator` accumulator with the number of observations initialized to zero .
66+ Create a new `VariableOrderAccumulator` with the number of observations set to `n` .
6067"""
61- NumProduceAccumulator {T} () where {T<: Integer } = NumProduceAccumulator (zero (T))
62- NumProduceAccumulator () = NumProduceAccumulator {Int} ()
68+ VariableOrderAccumulator {T} (n= zero (T)) where {T<: Integer } =
69+ VariableOrderAccumulator (convert (T, n), Dict {VarName,T} ())
70+ VariableOrderAccumulator (n) = VariableOrderAccumulator {typeof(n)} (n)
71+ VariableOrderAccumulator () = VariableOrderAccumulator {Int} ()
72+
73+ Base. copy (acc:: LogPriorAccumulator ) = acc
74+ Base. copy (acc:: LogLikelihoodAccumulator ) = acc
75+ function Base. copy (acc:: VariableOrderAccumulator )
76+ return VariableOrderAccumulator (acc. num_produce, copy (acc. order))
77+ end
6378
6479function Base. show (io:: IO , acc:: LogPriorAccumulator )
6580 return print (io, " LogPriorAccumulator($(repr (acc. logp)) )" )
6681end
6782function Base. show (io:: IO , acc:: LogLikelihoodAccumulator )
6883 return print (io, " LogLikelihoodAccumulator($(repr (acc. logp)) )" )
6984end
70- function Base. show (io:: IO , acc:: NumProduceAccumulator )
71- return print (io, " NumProduceAccumulator($(repr (acc. num)) )" )
85+ function Base. show (io:: IO , acc:: VariableOrderAccumulator )
86+ return print (
87+ io, " VariableOrderAccumulator($(repr (acc. num_produce)) , $(repr (acc. order)) )"
88+ )
89+ end
90+
91+ # Note that == and isequal are different, and equality under the latter should imply
92+ # equality of hashes. Both of the below implementations are also different from the default
93+ # implementation for structs.
94+ Base.:(== )(acc1:: LogPriorAccumulator , acc2:: LogPriorAccumulator ) = acc1. logp == acc2. logp
95+ function Base.:(== )(acc1:: LogLikelihoodAccumulator , acc2:: LogLikelihoodAccumulator )
96+ return acc1. logp == acc2. logp
97+ end
98+ function Base.:(== )(acc1:: VariableOrderAccumulator , acc2:: VariableOrderAccumulator )
99+ return acc1. num_produce == acc2. num_produce && acc1. order == acc2. order
100+ end
101+
102+ function Base. isequal (acc1:: LogPriorAccumulator , acc2:: LogPriorAccumulator )
103+ return isequal (acc1. logp, acc2. logp)
104+ end
105+ function Base. isequal (acc1:: LogLikelihoodAccumulator , acc2:: LogLikelihoodAccumulator )
106+ return isequal (acc1. logp, acc2. logp)
107+ end
108+ function Base. isequal (acc1:: VariableOrderAccumulator , acc2:: VariableOrderAccumulator )
109+ return isequal (acc1. num_produce, acc2. num_produce) && isequal (acc1. order, acc2. order)
110+ end
111+
112+ Base. hash (acc:: LogPriorAccumulator , h:: UInt ) = hash ((LogPriorAccumulator, acc. logp), h)
113+ function Base. hash (acc:: LogLikelihoodAccumulator , h:: UInt )
114+ return hash ((LogLikelihoodAccumulator, acc. logp), h)
115+ end
116+ function Base. hash (acc:: VariableOrderAccumulator , h:: UInt )
117+ return hash ((VariableOrderAccumulator, acc. num_produce, acc. order), h)
72118end
73119
74120accumulator_name (:: Type{<:LogPriorAccumulator} ) = :LogPrior
75121accumulator_name (:: Type{<:LogLikelihoodAccumulator} ) = :LogLikelihood
76- accumulator_name (:: Type{<:NumProduceAccumulator } ) = :NumProduce
122+ accumulator_name (:: Type{<:VariableOrderAccumulator } ) = :VariableOrder
77123
78124split (:: LogPriorAccumulator{T} ) where {T} = LogPriorAccumulator (zero (T))
79125split (:: LogLikelihoodAccumulator{T} ) where {T} = LogLikelihoodAccumulator (zero (T))
80- split (acc:: NumProduceAccumulator ) = acc
126+ split (acc:: VariableOrderAccumulator ) = copy ( acc)
81127
82128function combine (acc:: LogPriorAccumulator , acc2:: LogPriorAccumulator )
83129 return LogPriorAccumulator (acc. logp + acc2. logp)
84130end
85131function combine (acc:: LogLikelihoodAccumulator , acc2:: LogLikelihoodAccumulator )
86132 return LogLikelihoodAccumulator (acc. logp + acc2. logp)
87133end
88- function combine (acc:: NumProduceAccumulator , acc2:: NumProduceAccumulator )
89- return NumProduceAccumulator (max (acc. num, acc2. num))
134+ function combine (acc:: VariableOrderAccumulator , acc2:: VariableOrderAccumulator )
135+ # Note that assumptions are not allowed in parallelised blocks, and thus the
136+ # dictionaries should be identical.
137+ return VariableOrderAccumulator (
138+ max (acc. num_produce, acc2. num_produce), merge (acc. order, acc2. order)
139+ )
90140end
91141
92142function Base.:+ (acc1:: LogPriorAccumulator , acc2:: LogPriorAccumulator )
95145function Base.:+ (acc1:: LogLikelihoodAccumulator , acc2:: LogLikelihoodAccumulator )
96146 return LogLikelihoodAccumulator (acc1. logp + acc2. logp)
97147end
98- increment (acc:: NumProduceAccumulator ) = NumProduceAccumulator (acc. num + oneunit (acc. num))
148+ function increment (acc:: VariableOrderAccumulator )
149+ return VariableOrderAccumulator (acc. num_produce + oneunit (acc. num_produce), acc. order)
150+ end
99151
100152Base. zero (acc:: LogPriorAccumulator ) = LogPriorAccumulator (zero (acc. logp))
101153Base. zero (acc:: LogLikelihoodAccumulator ) = LogLikelihoodAccumulator (zero (acc. logp))
102- Base. zero (acc:: NumProduceAccumulator ) = NumProduceAccumulator (zero (acc. num))
103154
104155function accumulate_assume!! (acc:: LogPriorAccumulator , val, logjac, vn, right)
105156 return acc + LogPriorAccumulator (logpdf (right, val) + logjac)
@@ -114,8 +165,11 @@ function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn)
114165 return acc + LogLikelihoodAccumulator (Distributions. loglikelihood (right, left))
115166end
116167
117- accumulate_assume!! (acc:: NumProduceAccumulator , val, logjac, vn, right) = acc
118- accumulate_observe!! (acc:: NumProduceAccumulator , right, left, vn) = increment (acc)
168+ function accumulate_assume!! (acc:: VariableOrderAccumulator , val, logjac, vn, right)
169+ acc. order[vn] = acc. num_produce
170+ return acc
171+ end
172+ accumulate_observe!! (acc:: VariableOrderAccumulator , right, left, vn) = increment (acc)
119173
120174function Base. convert (:: Type{LogPriorAccumulator{T}} , acc:: LogPriorAccumulator ) where {T}
121175 return LogPriorAccumulator (convert (T, acc. logp))
@@ -126,15 +180,19 @@ function Base.convert(
126180 return LogLikelihoodAccumulator (convert (T, acc. logp))
127181end
128182function Base. convert (
129- :: Type{NumProduceAccumulator{T}} , acc:: NumProduceAccumulator
130- ) where {T}
131- return NumProduceAccumulator (convert (T, acc. num))
183+ :: Type{VariableOrderAccumulator{ElType,VnType}} , acc:: VariableOrderAccumulator
184+ ) where {ElType,VnType}
185+ order = Dict {VnType,ElType} ()
186+ for (k, v) in acc. order
187+ order[convert (VnType, k)] = convert (ElType, v)
188+ end
189+ return VariableOrderAccumulator (convert (ElType, acc. num_produce), order)
132190end
133191
134192# TODO (mhauru)
135- # We ignore the convert_eltype calls for NumProduceAccumulator , by letting them fallback on
193+ # We ignore the convert_eltype calls for VariableOrderAccumulator , by letting them fallback on
136194# convert_eltype(::AbstractAccumulator, ::Type). This is because they are only used to
137- # deal with dual number types of AD backends, which shouldn't concern NumProduceAccumulator . This is
195+ # deal with dual number types of AD backends, which shouldn't concern VariableOrderAccumulator . This is
138196# horribly hacky and should be fixed. See also comment in `unflatten` in `src/varinfo.jl`.
139197function convert_eltype (:: Type{T} , acc:: LogPriorAccumulator ) where {T}
140198 return LogPriorAccumulator (convert (T, acc. logp))
@@ -149,6 +207,6 @@ function default_accumulators(
149207 return AccumulatorTuple (
150208 LogPriorAccumulator {FloatT} (),
151209 LogLikelihoodAccumulator {FloatT} (),
152- NumProduceAccumulator {IntT} (),
210+ VariableOrderAccumulator {IntT} (),
153211 )
154212end
0 commit comments