112 changes: 112 additions & 0 deletions docs/src/internals/varnamedtuple.md
@@ -0,0 +1,112 @@
# VarNamedTuple as the basis of VarInfo

This document collects thoughts and ideas for how to unify our multitude of AbstractVarInfo types using a VarNamedTuple type. It may eventually turn into a draft design document, but for now it is more raw than that.

## The current situation

We currently have the following AbstractVarInfo types:

- A: VarInfo with Metadata
- B: VarInfo with VarNamedVector
- C: VarInfo with NamedTuple, with values being Metadata
- D: VarInfo with NamedTuple, with values being VarNamedVector
- E: SimpleVarInfo with NamedTuples
- F: SimpleVarInfo with OrderedDict

A and C are the classic ones, and the defaults. C groups the Metadata objects by the lead Symbol of the VarName of a variable, e.g. `x` in `@varname(x.y[1].z)`, which allows different lead Symbols to have different element types while keeping the VarInfo type stable. B and D were created to simplify A and C, give them a nicer interface, and make them deal better with changing variable sizes, but according to recent (Oct 2025) benchmarks they are quite a lot slower, which needs work.

E and F are entirely distinct in implementation from the others. E is simply a mapping from Symbols to values, with each VarName being converted to a single symbol, e.g. `Symbol("a[1]")`. F is a mapping from VarNames to values as an OrderedDict, with VarName as the key type.

A-D carry within them values for variables, but also their bijectors/distributions, and store all values vectorised, using the bijectors to map to the original values. They also store for each variable a flag for whether the variable has been linked. E-F store only the raw values, and a global flag for the whole SimpleVarInfo for whether it's linked. The link transform itself is implicit.

TODO: Write a better summary of pros and cons of each approach.

## VarNamedTuple

VarNamedTuple has been discussed as a possible data structure to generalise the structure used in VarInfo to achieve type stability, i.e. grouping VarNames by their lead Symbol. The same NamedTuple structure has been used elsewhere, too, e.g. in Turing.GibbsContext. The idea was to encapsulate this structure into its own type, reducing code duplication and making the design more robust and powerful. See https://github.com/TuringLang/DynamicPPL.jl/issues/900 for the discussion.

An AbstractVarInfo type could be only one application of VarNamedTuple, but here I'll focus on it exclusively. If we can make VarNamedTuple work for an AbstractVarInfo, I bet we can make it work for other purposes (condition, fix, Gibbs) as well.

Without going into full detail, here's @mhauru's current proposal for what it would look like. This proposal remains in constant flux as I develop the code.

A VarNamedTuple is a mapping from VarNames to values. Values can be anything. In the case of using VarNamedTuple to implement an AbstractVarInfo, the values would be random samples for random variables. However, they could also carry extra information with them. For instance, we might use a value that is a tuple of a vectorised value, a bijector, and a flag for whether the variable is linked.

I sometimes shorten VarNamedTuple to VNT.

Internally, a VarNamedTuple consists of nested NamedTuples. For instance, the mapping `@varname(x) => 1, @varname(y.z) => 2` would be stored as

```
(; x=1, y=(; z=2))
```

(This is a slight simplification: really it would be nested VarNamedTuples rather than NamedTuples, but I omit this detail.)
This forms a tree, with each node being a NamedTuple, like so:

```
        NT
    x /    \ y
     1      NT
              \ z
               2
```

Each `NT` marks a NamedTuple, and the labels on the edges are its keys. Here the root node has the keys `x` and `y`. This is like the type stable VarInfo in our current design, except with possibly more levels (our current one only has the root node). Each nested `PropertyLens`, i.e. each `.` in a VarName like `@varname(a.b.c.e)`, creates a new layer of the tree.
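As a concrete (if toy) illustration of the layering, here is plain-Julia code using ordinary NamedTuples; everything here is a sketch of the idea, not DynamicPPL API:

```julia
# Each `.` in a VarName adds one layer of nesting.
nt = (; x = 1, y = (; z = 2))

# Reading @varname(y.z) walks one NamedTuple layer per PropertyLens:
nt.y.z          # == 2

# A deeper VarName like @varname(a.b.c.e) would occupy four layers:
deep = (; a = (; b = (; c = (; e = 1))))
deep.a.b.c.e    # == 1
```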

For simplicity, at least for now, we ban any VarNames where an `IndexLens` precedes a `PropertyLens`. That is, we ban any VarNames like `@varname(a.b[1].c)`. Recall that VarNames allow three different kinds of lenses: `PropertyLens`es, `IndexLens`es, and `identity` (the trivial lens). Thus the only allowed VarName types are `@varname(a.b.c.d)` and `@varname(a.b.c.d[i,j,k])`.

This means that we can add levels to the NamedTuple tree until all `PropertyLenses` have been covered. The leaves of the tree are then of two kinds: They are either the raw value itself if the last lens of the VarName is an `identity`, or otherwise they are something that can be indexed with an `IndexLens`, such as an `Array`.

Getting a value from a VarNamedTuple is simple: for `getindex(vnt::VNT, vn::VarName{S})` (`S` being the lead Symbol) you recurse into `getindex(vnt[S], unprefix(vn, S))`. If the last lens of `vn` is an `IndexLens`, we assume that the leaf of the NamedTuple tree we've reached contains something that can be indexed with it.
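A minimal sketch of that recursion over plain nested NamedTuples; the function name `getvalue` is hypothetical, and a Tuple of Symbols stands in for the chain of `PropertyLens`es:

```julia
# Toy model of getindex(vnt, vn) for a VarName with only PropertyLenses:
# peel off the lead symbol, recurse into the child NamedTuple, and stop
# when the path is exhausted.
function getvalue(nt::NamedTuple, path::Tuple)
    head, rest = first(path), Base.tail(path)
    child = getproperty(nt, head)
    return isempty(rest) ? child : getvalue(child, rest)
end

nt = (; x = 1, y = (; z = 2))
getvalue(nt, (:y, :z))  # == 2, analogous to vnt[@varname(y.z)]
```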

Setting values in a VNT is equally simple if there are no `IndexLenses`: For `setindex!!(vnt::VNT, value::Any, vn::VarName)` one simply finds the leaf of the `vnt` tree corresponding to `vn` and sets its value to `value`.

The tricky part is what to do when setting values with `IndexLenses`. There are three possible situations. Say one calls `setindex!!(vnt, 3.0, @varname(a.b[3]))`.

1. If `getindex(vnt, @varname(a.b))` is already a vector of length at least 3, this is easy: Just set the third element.
2. If `getindex(vnt, @varname(a.b))` is a vector of length less than 3, what should we do? Do we error? Do we extend that vector?
3. If `getindex(vnt, @varname(a.b))` isn't even set, what do we do? Say for instance that `vnt` is currently empty. We should set `vnt` to be something like `(; a=(; b=x))`, where `x` is such that `x[3] = 3.0`, but what exactly should `x` be? Is it a dictionary? A vector of length 3? If the latter, what are `x[2]` and `x[1]`? Or should this `setindex!!` call simply error?

A note at this point: VarNamedTuples must always use `setindex!!`, the `!!` version that may or may not operate in place. The NamedTuples can't be modified in place, but the values at the leaves may be. Always using a `!!` function makes type stability easier, and makes structures like the type unstable old VarInfo with Metadata unnecessary: Any value can be set into any VarNamedTuple. The type parameters of the VNT will simply expand as necessary.
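A toy `!!`-style update for the pure-`PropertyLens` case makes the point concrete; `setvalue` is a hypothetical name, and each NamedTuple layer is rebuilt with `merge` since NamedTuples are immutable:

```julia
# Toy setindex!!-style update: NamedTuples cannot be mutated in place,
# so every level on the path is rebuilt. The result may have wider type
# parameters than the input, which is exactly why the `!!` convention
# (return a possibly-new value) is needed.
function setvalue(nt::NamedTuple, value, path::Tuple)
    head, rest = first(path), Base.tail(path)
    child = if isempty(rest)
        value
    else
        # Recurse into the existing subtree, or start a fresh one.
        inner = haskey(nt, head) ? getproperty(nt, head) : (;)
        setvalue(inner, value, rest)
    end
    return merge(nt, NamedTuple{(head,)}((child,)))
end

nt = (; x = 1)
setvalue(nt, 2.0, (:y, :z))  # == (; x = 1, y = (; z = 2.0))
```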

To solve the problem of points 2. and 3. above I propose expanding the definition of VNT a bit. This will also help make VNT more flexible, which may help performance or allow more use cases. The modification is this:

Unlike I said above, let's say that VNT isn't just nested NamedTuples with some values at the leaves. Let's say it also has a field called `make_leaf`. `make_leaf(value, lens)` is a function that takes any value, and a lens that is either `identity` or an `IndexLens`, and returns the value wrapped in some suitable struct that can be stored in the leaf of the NamedTuple tree. The values should always be such that `make_leaf(value, lens)[lens] == value`.
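A deliberately naive `make_leaf` satisfying that contract keys a `Dict` on the lens itself (a sketch only; slow, but it works for `identity` and any `IndexLens` alike):

```julia
using Accessors: IndexLens

# Naive leaf constructor: store the value in a Dict keyed by its lens.
# Both `identity` and IndexLens values are valid keys, so the contract
# make_leaf(value, lens)[lens] == value holds trivially via Dict lookup.
make_leaf(value, lens) = Dict{Any,Any}(lens => value)

leaf = make_leaf(3.0, IndexLens((3,)))
leaf[IndexLens((3,))]  # == 3.0
```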

Our earlier example of `VarNamedTuple(@varname(x) => 1, @varname(y.z) => 2; make_leaf=f)` would be stored as a tree like

```
           ----NT----
        x /          \ y
 f(1, identity)       NT
                        \ z
                 f(2, identity)
```

The first draft of VNT above, which did not include `make_leaf`, is equivalent to the trivial choice `make_leaf(value, lens) = lens === identity ? value : error("Don't know how to deal with IndexLenses")`. Problems 2. and 3. above are "solved" by making it `make_leaf`'s job to figure out what to do. For instance, `make_leaf` can always return a `Dict` that maps lenses to values. This is probably slow, but works for any lens. Or it can initialise a vector type that can grow as needed when indexed into.
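One hedged sketch of the growable-vector idea: a hypothetical `GrowableLeaf` that pads with `missing` and extends itself on out-of-bounds writes (assumes one-dimensional integer indices; not DynamicPPL code):

```julia
using Accessors: IndexLens

# Hypothetical leaf that grows on demand, answering questions 2. and 3.
# above: out-of-bounds writes extend the vector instead of erroring,
# with `missing` marking never-set positions.
struct GrowableLeaf{T}
    data::Vector{Union{T,Missing}}
end

function make_leaf(value, lens::IndexLens)
    i = only(lens.indices)
    data = Vector{Union{typeof(value),Missing}}(missing, i)
    data[i] = value
    return GrowableLeaf(data)
end

Base.getindex(leaf::GrowableLeaf, lens::IndexLens) = leaf.data[only(lens.indices)]

function Base.setindex!(leaf::GrowableLeaf, value, lens::IndexLens)
    i = only(lens.indices)
    # Grow the backing vector if the index is past the end.
    i > length(leaf.data) && append!(leaf.data, fill(missing, i - length(leaf.data)))
    leaf.data[i] = value
    return leaf
end

leaf = make_leaf(3.0, IndexLens((3,)))
leaf[IndexLens((3,))]  # == 3.0; leaf.data[1] and leaf.data[2] are missing
```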

The idea would be to use `make_leaf` to try out different ways of implementing a VarInfo, find a good default, and, if necessary, leave the option for power users to customise behaviour. The first ones to implement would be

- `make_leaf` that returns a Metadata object. This would be a direct replacement for type stable VarInfo that uses Metadata, except now with more nested levels of NamedTuple.
- `make_leaf` that returns an `OrderedDict`. This would be a direct replacement for SimpleVarInfo with OrderedDict.

You may ask whether we have simply gone from too many VarInfo types to too many `make_leaf` functions. Yes, we have. But hopefully we have gained something in the process:

- The leaf types can be simpler: they no longer need to deal with VarNames, only with `identity` lenses and `IndexLens`es.
- All AbstractVarInfos are as type stable as their leaf types allow. There is no more notion of an untyped VarInfo being converted to a typed one.
- Type stability is maintained even with nested `PropertyLenses` like `@varname(a.b)`, which happens a lot with submodels.
- Many functions that are currently implemented individually for each AbstractVarInfo type would now have a single implementation for the VarNamedTuple-based AbstractVarInfo type, reducing code duplication. I would also hope to get rid of most of the generated functions in `varinfo.jl`.

My guess is that the eventual One AbstractVarInfo To Rule Them All would have a `make_leaf` function that stores the raw values when the lens is an `identity`, and uses a flexible Vector, a lot like VarNamedVector, when the lens is an IndexLens. However, I could be wrong on that being the best option. Implementing and benchmarking is the only way to know.

I think the two big questions are:

- Will we run into some big, unanticipated blockers when we start to implement this?
- Will the nesting of NamedTuples cause performance regressions, if the compiler either chokes or gives up?

I'll try to derisk these early on in this PR.

## Questions / issues

* People might really need IndexLenses in the middle of VarNames. The one place this comes up is submodels within a loop. I'm still inclined to keep designing without allowing for that, for now, but should keep in mind that that needs to be relaxed eventually. If it makes it easier, we can require that users explicitly tell us the size of any arrays for which this is done.
* When storing values for nested NamedTuples, the actual variable may be a struct. Do we need to be able to reconstruct the struct from the NamedTuple? If so, how do we do that?
* Do `Colon` indices cause any extra trouble for the leafnodes?
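On the struct question above, one possible (unverified-for-this-design) route is ConstructionBase.jl, which Accessors.jl already depends on: `constructorof(T)` recovers a constructor from a type, so a leaf NamedTuple could be splatted back into the struct, assuming its keys cover the struct's fields. A sketch with a hypothetical `reconstruct` helper:

```julia
using ConstructionBase: constructorof

struct Params
    μ::Float64
    σ::Float64
end

# Rebuild a struct from a NamedTuple of its fields. Assumes the
# NamedTuple has an entry for every field; order doesn't matter
# because lookup is by name.
reconstruct(::Type{T}, nt::NamedTuple) where {T} =
    constructorof(T)((nt[f] for f in fieldnames(T))...)

reconstruct(Params, (; σ = 2.0, μ = 0.5))  # == Params(0.5, 2.0)
```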
Comment on lines +110 to +112
[JuliaFormatter v1.0.62] reported by reviewdog 🐶

Suggested change
* People might really need IndexLenses in the middle of VarNames. The one place this comes up is submodels within a loop. I'm still inclined to keep designing without allowing for that, for now, but should keep in mind that that needs to be relaxed eventually. If it makes it easier, we can require that users explicitly tell us the size of any arrays for which this is done.
* When storing values for nested NamedTuples, the actual variable may be a struct. Do we need to be able to reconstruct the struct from the NamedTuple? If so, how do we do that?
* Do `Colon` indices cause any extra trouble for the leafnodes?
- People might really need IndexLenses in the middle of VarNames. The one place this comes up is submodels within a loop. I'm still inclined to keep designing without allowing for that, for now, but should keep in mind that that needs to be relaxed eventually. If it makes it easier, we can require that users explicitly tell us the size of any arrays for which this is done.
- When storing values for nested NamedTuples, the actual variable may be a struct. Do we need to be able to reconstruct the struct from the NamedTuple? If so, how do we do that?
- Do `Colon` indices cause any extra trouble for the leafnodes?

2 changes: 2 additions & 0 deletions src/DynamicPPL.jl
@@ -178,6 +178,8 @@ include("contexts/prefix.jl")
include("contexts/conditionfix.jl") # Must come after contexts/prefix.jl
include("model.jl")
include("varname.jl")
include("varnamedtuple.jl")
using .VarNamedTuples: VarNamedTuple
include("distribution_wrappers.jl")
include("submodel.jl")
include("varnamedvector.jl")
11 changes: 6 additions & 5 deletions src/contexts/transformation.jl
@@ -21,11 +21,12 @@ function tilde_assume!!(
# vi[vn, right] always provides the value in unlinked space.
x = vi[vn, right]

if is_transformed(vi, vn)
isinverse || @warn "Trying to link an already transformed variable ($vn)"
else
isinverse && @warn "Trying to invlink a non-transformed variable ($vn)"
end
# TODO(mhauru) Warnings disabled for benchmarking purposes
# if is_transformed(vi, vn)
# isinverse || @warn "Trying to link an already transformed variable ($vn)"
# else
# isinverse && @warn "Trying to invlink a non-transformed variable ($vn)"
# end

transform = isinverse ? identity : link_transform(right)
y, logjac = with_logabsdet_jacobian(transform, x)
144 changes: 140 additions & 4 deletions src/varinfo.jl
@@ -154,6 +154,7 @@ const NTVarInfo = VarInfo{<:NamedTuple}
const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{
VarInfo{Tmeta},ThreadSafeVarInfo{<:VarInfo{Tmeta}}
}
const TupleVarInfo = VarInfo{<:VarNamedTuple}

function Base.:(==)(vi1::VarInfo, vi2::VarInfo)
return (vi1.metadata == vi2.metadata && vi1.accs == vi2.accs)
@@ -356,6 +357,28 @@ function typed_vector_varinfo(
return typed_vector_varinfo(Random.default_rng(), model, init_strategy)
end

function make_leaf_metadata((r, dist), optic)
md = Metadata(Float64, VarName{:_})
vn = VarName{:_}(optic)
push!(md, vn, r, dist)
return md
end

function tuple_varinfo()
metadata = VarNamedTuple((;), make_leaf_metadata)
return VarInfo(metadata, copy(default_accumulators()))
end
function tuple_varinfo(
rng::Random.AbstractRNG,
model::Model,
init_strategy::AbstractInitStrategy=InitFromPrior(),
)
return last(init!!(rng, model, tuple_varinfo(), init_strategy))
end
function tuple_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior())
return tuple_varinfo(Random.default_rng(), model, init_strategy)
end

"""
vector_length(varinfo::VarInfo)

@@ -416,13 +439,13 @@ unflatten_metadata(vnv::VarNamedVector, x::AbstractVector) = unflatten(vnv, x)

Construct an empty type unstable instance of `Metadata`.
"""
function Metadata()
vals = Vector{Real}()
function Metadata(eltype=Real, vntype=VarName)
vals = Vector{eltype}()
is_transformed = BitVector()

return Metadata(
Dict{VarName,Int}(),
Vector{VarName}(),
Dict{vntype,Int}(),
Vector{vntype}(),
Vector{UnitRange{Int}}(),
vals,
Vector{Distribution}(),
@@ -639,6 +662,9 @@ Return the metadata in `vi` that belongs to `vn`.
"""
getmetadata(vi::VarInfo, vn::VarName) = vi.metadata
getmetadata(vi::NTVarInfo, vn::VarName) = getfield(vi.metadata, getsym(vn))
function getmetadata(vi::TupleVarInfo, vn::VarName)
return getindex(vi.metadata, remove_trailing_index(vn))
end

"""
getidx(vi::VarInfo, vn::VarName)
@@ -744,6 +770,10 @@ end
Return the distribution from which `vn` was sampled in `vi`.
"""
getdist(vi::VarInfo, vn::VarName) = getdist(getmetadata(vi, vn), vn)
function getdist(vi::TupleVarInfo, vn::VarName)
main_vn, optic = split_trailing_index(vn)
return getdist(getindex(vi.metadata, main_vn), VarName{:_}(optic))
end
getdist(md::Metadata, vn::VarName) = md.dists[getidx(md, vn)]
# TODO(mhauru) Remove this once the old Gibbs sampler stuff is gone.
function getdist(::VarNamedVector, ::VarName)
@@ -782,6 +812,10 @@ Set the value(s) of `vn` in the metadata of `vi` to `val`.
The values may or may not be transformed to Euclidean space.
"""
setval!(vi::VarInfo, val, vn::VarName) = setval!(getmetadata(vi, vn), val, vn)
function setval!(vi::TupleVarInfo, val, vn::VarName)
main_vn, optic = split_trailing_index(vn)
return setval!(getindex(vi.metadata, main_vn), val, VarName{:_}(optic))
end
function setval!(md::Metadata, val::AbstractVector, vn::VarName)
return md.vals[getrange(md, vn)] = val
end
@@ -1579,6 +1613,7 @@ function Base.haskey(vi::NTVarInfo, vn::VarName)
end
return any(md_haskey)
end
Base.haskey(vi::TupleVarInfo, vn::VarName) = haskey(vi.metadata, vn)

function Base.show(io::IO, ::MIME"text/plain", vi::UntypedVarInfo)
lines = Tuple{String,Any}[
@@ -1673,6 +1708,25 @@ function BangBang.push!!(vi::VarInfo, vn::VarName, r, dist::Distribution)
return vi
end

function BangBang.push!!(vi::TupleVarInfo, vn::VarName, r, dist::Distribution)
@assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to TupleVarInfo with dist=$dist"
return VarInfo(setindex!!(vi.metadata, (r, dist), vn), vi.accs)
end

# TODO(mhauru) Implement properly
function is_transformed(vi::TupleVarInfo, vn::VarName)
return false
end

function getindex(vi::TupleVarInfo, vn::VarName)
main_vn, optic = split_trailing_index(vn)
return getindex(getindex(vi.metadata, main_vn), VarName{:_}(optic))
end
function getindex_internal(vi::TupleVarInfo, vn::VarName)
main_vn, optic = split_trailing_index(vn)
return getindex_internal(getindex(vi.metadata, main_vn), VarName{:_}(optic))
end

function Base.push!(vi::UntypedVectorVarInfo, vn::VarName, val, args...)
push!(getmetadata(vi, vn), vn, val, args...)
return vi
@@ -1860,3 +1914,85 @@ end
function from_linked_internal_transform(::VarNamedVector, ::VarName, dist)
return from_linked_vec_transform(dist)
end

function link(vi::TupleVarInfo, model::Model)
metadata = link(vi.metadata, model)
return VarInfo(metadata, vi.accs)
end

function link(vnt::VarNamedTuple, model::Model)
new_vnt = map(value -> link(value, model), vnt)
return new_vnt
end

function link(metadata::Metadata, model::Model)
vns = metadata.vns
cumulative_logjac = zero(LogProbType)

# Construct the new transformed values, and keep track of their lengths.
vals_new = map(vns) do vn
        # Return early if we're already in unconstrained space.
if is_transformed(metadata, vn)
return metadata.vals[getrange(metadata, vn)]
end

# Transform to constrained space.
x = getindex_internal(metadata, vn)
dist = getdist(metadata, vn)
f_from_internal = from_internal_transform(metadata, vn, dist)
f_to_linked_internal = inverse(from_linked_internal_transform(metadata, vn, dist))
f = f_to_linked_internal ∘ f_from_internal
y, logjac = with_logabsdet_jacobian(f, x)
# Vectorize value.
yvec = tovec(y)
# Accumulate the log-abs-det jacobian correction.
cumulative_logjac += logjac
# Return the vectorized transformed value.
return yvec
end

# Determine new ranges.
ranges_new = similar(metadata.ranges)
offset = 0
for (i, v) in enumerate(vals_new)
r_start, r_end = offset + 1, length(v) + offset
offset = r_end
ranges_new[i] = r_start:r_end
end

# Now we just create a new metadata with the new `vals` and `ranges`.
return Metadata(
metadata.idcs,
metadata.vns,
ranges_new,
reduce(vcat, vals_new),
metadata.dists,
BitVector(fill(true, length(metadata.vns))),
)
end

function Base.haskey(vi::TupleVarInfo, vn::VarName)
# TODO(mhauru) Fix this to account for the index.
main_vn, optic = split_trailing_index(vn)
haskey(vi.metadata, main_vn) || return false
value = getindex(vi.metadata, main_vn)
if value isa Metadata
return haskey(value, VarName{:_}(optic))
else
error("TODO(mhauru) Implement me")
end
end

function BangBang.setindex!!(metadata::Metadata, val, optic)
return setindex!!(metadata, val, VarName{:_}(optic))
end

function BangBang.setindex!!(metadata::Metadata, (r, dist), vn::VarName)
if haskey(metadata, vn)
setval!(metadata, r, vn)
else
push!(metadata, vn, r, dist)
end
return metadata
end
23 changes: 23 additions & 0 deletions src/varname.jl
@@ -41,3 +41,26 @@ Possibly existing indices of `varname` are neglected.
) where {s,missings,_F,_a,_T}
return s in missings
end

function remove_trailing_index(vn::VarName{sym,Optic}) where {sym,Optic}
return if Optic === typeof(identity)
vn
elseif Optic <: Accessors.IndexLens
VarName{sym}()
else
AbstractPPL.prefix(
remove_trailing_index(AbstractPPL.unprefix(vn, VarName{sym}())), VarName{sym}()
)
end
end

function split_trailing_index(vn::VarName{sym,Optic}) where {sym,Optic}
return if Optic === typeof(identity)
(vn, identity)
elseif Optic <: Accessors.IndexLens
(VarName{sym}(), getoptic(vn))
else
(pref, index) = split_trailing_index(AbstractPPL.unprefix(vn, VarName{sym}()))
(AbstractPPL.prefix(pref, VarName{sym}()), index)
end
end