Skip to content

VariableOrderAccumulator #940

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jul 18, 2025
4 changes: 2 additions & 2 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ get_num_produce
set_num_produce!!
increment_num_produce!!
reset_num_produce!!
setorder!
setorder!!
set_retained_vns_del!
```

Expand All @@ -368,7 +368,7 @@ DynamicPPL provides the following default accumulators.
```@docs
LogPriorAccumulator
LogLikelihoodAccumulator
NumProduceAccumulator
VariableOrderAccumulator
```

### Common API
Expand Down
4 changes: 2 additions & 2 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ export AbstractVarInfo,
AbstractAccumulator,
LogLikelihoodAccumulator,
LogPriorAccumulator,
NumProduceAccumulator,
VariableOrderAccumulator,
push!!,
empty!!,
subset,
Expand All @@ -73,7 +73,7 @@ export AbstractVarInfo,
is_flagged,
set_flag!,
unset_flag!,
setorder!,
setorder!!,
istrans,
link,
link!!,
Expand Down
34 changes: 30 additions & 4 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,24 @@ function resetlogp!!(vi::AbstractVarInfo)
return vi
end

"""
setorder!!(vi::AbstractVarInfo, vn::VarName, index::Integer)

Set the `order` of `vn` in `vi` to `index`, where `order` is the number of `observe`
statements run before sampling `vn`.
"""
function setorder!!(vi::AbstractVarInfo, vn::VarName, index::Integer)
    # Update the accumulator registered under `:VariableOrder`: store `index` for `vn` in
    # its `order` dictionary and hand the same accumulator back to `map_accumulator!!`.
    return map_accumulator!!(vi, Val(:VariableOrder)) do acc
        acc.order[vn] = index
        return acc
    end
end

"""
getorder(vi::AbstractVarInfo, vn::VarName)

Get the `order` of `vn` in `vi`, where `order` is the number of `observe` statements
run before sampling `vn`.
"""
function getorder(vi::AbstractVarInfo, vn::VarName)
    # The order is looked up in the `order` dictionary of the `:VariableOrder` accumulator.
    acc = getacc(vi, Val(:VariableOrder))
    return acc.order[vn]
end

# Variables and their realizations.
@doc """
keys(vi::AbstractVarInfo)
Expand Down Expand Up @@ -980,29 +998,37 @@ end

Return the `num_produce` of `vi`.
"""
get_num_produce(vi::AbstractVarInfo) = getacc(vi, Val(:NumProduce)).num
function get_num_produce(vi::AbstractVarInfo)
    # `num_produce` is a field of the accumulator registered under `:VariableOrder`.
    acc = getacc(vi, Val(:VariableOrder))
    return acc.num_produce
end

"""
set_num_produce!!(vi::AbstractVarInfo, n::Integer)

Set the `num_produce` field of `vi` to `n`.
"""
set_num_produce!!(vi::AbstractVarInfo, n::Int) = setacc!!(vi, NumProduceAccumulator(n))
function set_num_produce!!(vi::AbstractVarInfo, n::Integer)
    acc = if hasacc(vi, Val(:VariableOrder))
        # Keep the recorded variable orders and only replace the counter. The counter
        # field and the values of `order` share one type parameter in
        # `VariableOrderAccumulator`, so `n` is converted to that type; passing an integer
        # of another type would otherwise find no matching constructor.
        order = getacc(vi, Val(:VariableOrder)).order
        VariableOrderAccumulator(convert(valtype(order), n), order)
    else
        # No `:VariableOrder` accumulator yet: create one from `n` alone.
        VariableOrderAccumulator(n)
    end
    return setacc!!(vi, acc)
end

"""
increment_num_produce!!(vi::AbstractVarInfo)

Add 1 to `num_produce` in `vi`.
"""
increment_num_produce!!(vi::AbstractVarInfo) =
map_accumulator!!(increment, vi, Val(:NumProduce))
map_accumulator!!(increment, vi, Val(:VariableOrder))

"""
reset_num_produce!!(vi::AbstractVarInfo)

Reset the value of `num_produce` in `vi` to 0.
"""
reset_num_produce!!(vi::AbstractVarInfo) = map_accumulator!!(zero, vi, Val(:NumProduce))
function reset_num_produce!!(vi::AbstractVarInfo)
    # Use a zero of the same type as the current counter.
    current = get_num_produce(vi)
    return set_num_produce!!(vi, zero(current))
end

"""
from_internal_transform(varinfo::AbstractVarInfo, vn::VarName[, dist])
Expand Down
4 changes: 4 additions & 0 deletions src/accumulators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ An accumulator type `T <: AbstractAccumulator` must implement the following meth
- `accumulator_name(acc::T)` or `accumulator_name(::Type{T})`
- `accumulate_observe!!(acc::T, right, left, vn)`
- `accumulate_assume!!(acc::T, val, logjac, vn, right)`
- `Base.copy(acc::T)`

To be able to work with multi-threading, it should also implement:
- `split(acc::T)`
Expand Down Expand Up @@ -138,6 +139,9 @@ function Base.haskey(at::AccumulatorTuple, ::Val{accname}) where {accname}
@inline return haskey(at.nt, accname)
end
Base.keys(at::AccumulatorTuple) = keys(at.nt)
# Two `AccumulatorTuple`s compare equal when their wrapped named tuples do.
function Base.:(==)(at1::AccumulatorTuple, at2::AccumulatorTuple)
    return at1.nt == at2.nt
end

# Hash the type together with the wrapped named tuple.
function Base.hash(at::AccumulatorTuple, h::UInt)
    return Base.hash((AccumulatorTuple, at.nt), h)
end

# Copy every accumulator individually and wrap the results in a new `AccumulatorTuple`.
function Base.copy(at::AccumulatorTuple)
    return AccumulatorTuple(map(copy, at.nt))
end

function Base.convert(::Type{AccumulatorTuple{N,T}}, accs::AccumulatorTuple{N}) where {N,T}
return AccumulatorTuple(convert(T, accs.nt))
Expand Down
1 change: 0 additions & 1 deletion src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ function assume(
f = to_maybe_linked_internal_transform(vi, vn, dist)
# TODO(mhauru) This should probably call a function named setindex_internal!
vi = BangBang.setindex!!(vi, f(r), vn)
setorder!(vi, vn, get_num_produce(vi))
else
# Otherwise we just extract it.
r = vi[vn, dist]
Expand Down
6 changes: 1 addition & 5 deletions src/debug_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -410,11 +410,7 @@ function check_model_and_trace(
model::Model, varinfo::AbstractVarInfo; error_on_failure=false
)
# Add debug accumulator to the VarInfo.
# Need a NumProduceAccumulator as well or else get_num_produce may throw
# TODO(mhauru) Remove this once VariableOrderAccumulator stuff is done.
varinfo = DynamicPPL.setaccs!!(
deepcopy(varinfo), (DebugAccumulator(error_on_failure), NumProduceAccumulator())
)
varinfo = DynamicPPL.setaccs!!(deepcopy(varinfo), (DebugAccumulator(error_on_failure),))

# Perform checks before evaluating the model.
issuccess = check_model_pre_evaluation(model)
Expand Down
106 changes: 82 additions & 24 deletions src/default_accumulators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,52 +41,102 @@ LogLikelihoodAccumulator{T}() where {T<:Real} = LogLikelihoodAccumulator(zero(T)
LogLikelihoodAccumulator() = LogLikelihoodAccumulator{LogProbType}()

"""
NumProduceAccumulator{T} <: AbstractAccumulator
VariableOrderAccumulator{T} <: AbstractAccumulator

An accumulator that tracks the number of observations during model execution.
An accumulator that tracks the order of variables in a `VarInfo`.

This doesn't track the full ordering, but rather how many observations have taken place
before the assume statement for each variable. This is needed for particle methods, where
the model is segmented into parts by each observation, and we need to know which part each
assume statement is in.

# Fields
$(TYPEDFIELDS)
"""
struct NumProduceAccumulator{T<:Integer} <: AbstractAccumulator
struct VariableOrderAccumulator{Eltype<:Integer,VNType<:VarName} <: AbstractAccumulator
"the number of observations"
num::T
num_produce::Eltype
"mapping of variable names to their order in the model"
order::Dict{VNType,Eltype}
end

"""
NumProduceAccumulator{T<:Integer}()
VariableOrderAccumulator{T<:Integer}(n=zero(T))

Create a new `NumProduceAccumulator` accumulator with the number of observations initialized to zero.
Create a new `VariableOrderAccumulator` with the number of observations set to `n`.
"""
NumProduceAccumulator{T}() where {T<:Integer} = NumProduceAccumulator(zero(T))
NumProduceAccumulator() = NumProduceAccumulator{Int}()
function VariableOrderAccumulator{T}(n=zero(T)) where {T<:Integer}
    # `n` is converted to `T`; the accumulator starts with an empty `order` dictionary.
    return VariableOrderAccumulator(convert(T, n), Dict{VarName,T}())
end
# Take the integer type from `n` itself.
function VariableOrderAccumulator(n)
    return VariableOrderAccumulator{typeof(n)}(n)
end
# Without arguments, use `Int` and the default `n` of the method above.
function VariableOrderAccumulator()
    return VariableOrderAccumulator{Int}()
end

# The log-probability accumulators are returned as they are.
function Base.copy(acc::LogPriorAccumulator)
    return acc
end
function Base.copy(acc::LogLikelihoodAccumulator)
    return acc
end
# The `order` dictionary is copied, so the new accumulator gets its own dictionary.
Base.copy(acc::VariableOrderAccumulator) =
    VariableOrderAccumulator(acc.num_produce, copy(acc.order))

# Print each accumulator as its type name followed by the `repr` of its `logp` field.
function Base.show(io::IO, acc::LogPriorAccumulator)
    shown = "LogPriorAccumulator($(repr(acc.logp)))"
    return print(io, shown)
end
function Base.show(io::IO, acc::LogLikelihoodAccumulator)
    shown = "LogLikelihoodAccumulator($(repr(acc.logp)))"
    return print(io, shown)
end
function Base.show(io::IO, acc::NumProduceAccumulator)
return print(io, "NumProduceAccumulator($(repr(acc.num)))")
function Base.show(io::IO, acc::VariableOrderAccumulator)
    # Both fields are rendered with `repr`, counter first, then the `order` dictionary.
    shown = "VariableOrderAccumulator($(repr(acc.num_produce)), $(repr(acc.order)))"
    return print(io, shown)
end

# Note that == and isequal are different, and equality under the latter should imply
# equality of hashes. Both of the below implementations are also different from the default
# implementation for structs.
function Base.:(==)(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator)
    return acc1.logp == acc2.logp
end
Base.:(==)(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) =
    acc1.logp == acc2.logp
function Base.:(==)(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator)
    # The counter is compared first; `order` is only compared when the counters match.
    acc1.num_produce == acc2.num_produce || return false
    return acc1.order == acc2.order
end

Base.isequal(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator) =
    isequal(acc1.logp, acc2.logp)
Base.isequal(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) =
    isequal(acc1.logp, acc2.logp)
function Base.isequal(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator)
    # Counters first; the `order` dictionaries are only compared when the counters agree.
    isequal(acc1.num_produce, acc2.num_produce) || return false
    return isequal(acc1.order, acc2.order)
end

# Each hash combines the accumulator's type with the fields used by the equality methods.
function Base.hash(acc::LogPriorAccumulator, h::UInt)
    return hash((LogPriorAccumulator, acc.logp), h)
end
Base.hash(acc::LogLikelihoodAccumulator, h::UInt) =
    hash((LogLikelihoodAccumulator, acc.logp), h)
function Base.hash(acc::VariableOrderAccumulator, h::UInt)
    key = (VariableOrderAccumulator, acc.num_produce, acc.order)
    return hash(key, h)
end

# Names under which the log-probability accumulators are registered.
function accumulator_name(::Type{<:LogPriorAccumulator})
    return :LogPrior
end
function accumulator_name(::Type{<:LogLikelihoodAccumulator})
    return :LogLikelihood
end
accumulator_name(::Type{<:NumProduceAccumulator}) = :NumProduce
# Name under which `VariableOrderAccumulator` is registered.
function accumulator_name(::Type{<:VariableOrderAccumulator})
    return :VariableOrder
end

# Splitting a log-probability accumulator yields a fresh one holding `zero(T)`.
function split(::LogPriorAccumulator{T}) where {T}
    return LogPriorAccumulator(zero(T))
end
function split(::LogLikelihoodAccumulator{T}) where {T}
    return LogLikelihoodAccumulator(zero(T))
end
split(acc::NumProduceAccumulator) = acc
# Splitting hands out a copy of the accumulator.
function split(acc::VariableOrderAccumulator)
    return copy(acc)
end

# Combining two log-probability accumulators adds their `logp` values.
combine(acc::LogPriorAccumulator, acc2::LogPriorAccumulator) =
    LogPriorAccumulator(acc.logp + acc2.logp)
combine(acc::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) =
    LogLikelihoodAccumulator(acc.logp + acc2.logp)
function combine(acc::NumProduceAccumulator, acc2::NumProduceAccumulator)
return NumProduceAccumulator(max(acc.num, acc2.num))
function combine(acc::VariableOrderAccumulator, acc2::VariableOrderAccumulator)
    # Keep the larger of the two counters.
    num_produce = max(acc.num_produce, acc2.num_produce)
    # Note that assumptions are not allowed in parallelised blocks, and thus the
    # dictionaries should be identical.
    order = merge(acc.order, acc2.order)
    return VariableOrderAccumulator(num_produce, order)
end

function Base.:+(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator)
Expand All @@ -95,11 +145,12 @@ end
function Base.:+(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator)
return LogLikelihoodAccumulator(acc1.logp + acc2.logp)
end
increment(acc::NumProduceAccumulator) = NumProduceAccumulator(acc.num + oneunit(acc.num))
function increment(acc::VariableOrderAccumulator)
    # Add one unit of the counter's own type; the `order` dictionary is passed on as-is.
    next = acc.num_produce + oneunit(acc.num_produce)
    return VariableOrderAccumulator(next, acc.order)
end

# A zero accumulator wraps the zero of the current `logp` value.
function Base.zero(acc::LogPriorAccumulator)
    return LogPriorAccumulator(zero(acc.logp))
end
function Base.zero(acc::LogLikelihoodAccumulator)
    return LogLikelihoodAccumulator(zero(acc.logp))
end
Base.zero(acc::NumProduceAccumulator) = NumProduceAccumulator(zero(acc.num))

function accumulate_assume!!(acc::LogPriorAccumulator, val, logjac, vn, right)
return acc + LogPriorAccumulator(logpdf(right, val) + logjac)
Expand All @@ -114,8 +165,11 @@ function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn)
return acc + LogLikelihoodAccumulator(Distributions.loglikelihood(right, left))
end

accumulate_assume!!(acc::NumProduceAccumulator, val, logjac, vn, right) = acc
accumulate_observe!!(acc::NumProduceAccumulator, right, left, vn) = increment(acc)
function accumulate_assume!!(acc::VariableOrderAccumulator, val, logjac, vn, right)
    # Store the current counter value as the order of `vn`; `acc` itself is returned.
    setindex!(acc.order, acc.num_produce, vn)
    return acc
end
# Every observe statement advances the counter through `increment`.
function accumulate_observe!!(acc::VariableOrderAccumulator, right, left, vn)
    return increment(acc)
end

function Base.convert(::Type{LogPriorAccumulator{T}}, acc::LogPriorAccumulator) where {T}
return LogPriorAccumulator(convert(T, acc.logp))
Expand All @@ -126,15 +180,19 @@ function Base.convert(
return LogLikelihoodAccumulator(convert(T, acc.logp))
end
function Base.convert(
::Type{NumProduceAccumulator{T}}, acc::NumProduceAccumulator
) where {T}
return NumProduceAccumulator(convert(T, acc.num))
::Type{VariableOrderAccumulator{ElType,VnType}}, acc::VariableOrderAccumulator
) where {ElType,VnType}
order = Dict{VnType,ElType}()
for (k, v) in acc.order
order[convert(VnType, k)] = convert(ElType, v)
end
return VariableOrderAccumulator(convert(ElType, acc.num_produce), order)
end

# TODO(mhauru)
# We ignore the convert_eltype calls for NumProduceAccumulator, by letting them fallback on
# We ignore the convert_eltype calls for VariableOrderAccumulator, by letting them fall back on
# convert_eltype(::AbstractAccumulator, ::Type). This is because they are only used to
# deal with dual number types of AD backends, which shouldn't concern NumProduceAccumulator. This is
# deal with dual number types of AD backends, which shouldn't concern VariableOrderAccumulator. This is
# horribly hacky and should be fixed. See also comment in `unflatten` in `src/varinfo.jl`.
function convert_eltype(::Type{T}, acc::LogPriorAccumulator) where {T}
return LogPriorAccumulator(convert(T, acc.logp))
Expand All @@ -149,6 +207,6 @@ function default_accumulators(
return AccumulatorTuple(
LogPriorAccumulator{FloatT}(),
LogLikelihoodAccumulator{FloatT}(),
NumProduceAccumulator{IntT}(),
VariableOrderAccumulator{IntT}(),
)
end
16 changes: 6 additions & 10 deletions src/extract_priors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ end

PriorDistributionAccumulator() = PriorDistributionAccumulator(OrderedDict{VarName,Any}())

# The `priors` container is copied, so the new accumulator gets its own container.
Base.copy(acc::PriorDistributionAccumulator) =
    PriorDistributionAccumulator(copy(acc.priors))

# Name under which this accumulator is registered.
function accumulator_name(::PriorDistributionAccumulator)
    return :PriorDistributionAccumulator
end

# Splitting yields an accumulator with an emptied container of the same kind as `priors`.
function split(acc::PriorDistributionAccumulator)
    return PriorDistributionAccumulator(empty(acc.priors))
end
Expand Down Expand Up @@ -112,10 +116,7 @@ extract_priors(args::Union{Model,AbstractVarInfo}...) =
extract_priors(Random.default_rng(), args...)
function extract_priors(rng::Random.AbstractRNG, model::Model)
varinfo = VarInfo()
# TODO(mhauru) This doesn't actually need the NumProduceAccumulator, it's only a
# workaround for the fact that `order` is still hardcoded in VarInfo, and hence you
# can't push new variables without knowing the num_produce. Remove this when possible.
varinfo = setaccs!!(varinfo, (PriorDistributionAccumulator(), NumProduceAccumulator()))
varinfo = setaccs!!(varinfo, (PriorDistributionAccumulator(),))
varinfo = last(evaluate_and_sample!!(rng, model, varinfo))
return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors
end
Expand All @@ -129,12 +130,7 @@ This is done by evaluating the model at the values present in `varinfo`
and recording the distributions that are present at each tilde statement.
"""
function extract_priors(model::Model, varinfo::AbstractVarInfo)
# TODO(mhauru) This doesn't actually need the NumProduceAccumulator, it's only a
# workaround for the fact that `order` is still hardcoded in VarInfo, and hence you
# can't push new variables without knowing the num_produce. Remove this when possible.
varinfo = setaccs!!(
deepcopy(varinfo), (PriorDistributionAccumulator(), NumProduceAccumulator())
)
varinfo = setaccs!!(deepcopy(varinfo), (PriorDistributionAccumulator(),))
varinfo = last(evaluate!!(model, varinfo))
return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors
end
4 changes: 4 additions & 0 deletions src/pointwise_logdensities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ function PointwiseLogProbAccumulator{whichlogprob,KeyType}() where {whichlogprob
return PointwiseLogProbAccumulator{whichlogprob,KeyType,typeof(logps)}(logps)
end

function Base.copy(acc::PointwiseLogProbAccumulator{whichlogprob}) where {whichlogprob}
    # Copy the `logps` container and keep the same `whichlogprob` parameter.
    logps = copy(acc.logps)
    return PointwiseLogProbAccumulator{whichlogprob}(logps)
end

function Base.push!(acc::PointwiseLogProbAccumulator, vn, logp)
logps = acc.logps
# The last(fieldtypes(eltype(...))) gets the type of the values, rather than the keys.
Expand Down
Loading
Loading