@@ -315,7 +315,7 @@ function untyped_vector_varinfo(
315315 model:: Model ,
316316 init_strategy:: AbstractInitStrategy = InitFromPrior (),
317317)
318- return untyped_vector_varinfo ( untyped_varinfo (rng, model, init_strategy))
318+ return last ( init!! (rng, model, VarInfo ( VarNamedVector ()) , init_strategy))
319319end
320320function untyped_vector_varinfo (
321321 model:: Model , init_strategy:: AbstractInitStrategy = InitFromPrior ()
@@ -789,18 +789,24 @@ function setval!(md::Metadata, val, vn::VarName)
789789 return md. vals[getrange (md, vn)] = tovec (val)
790790end
791791
792+ function set_transformed!! (vi:: NTVarInfo , val:: Bool , vn:: VarName )
793+ md = set_transformed!! (getmetadata (vi, vn), val, vn)
794+ return Accessors. @set vi. metadata[getsym (vn)] = md
795+ end
796+
792797function set_transformed!! (vi:: VarInfo , val:: Bool , vn:: VarName )
793- set_transformed!! (getmetadata (vi, vn), val, vn)
794- return vi
798+ md = set_transformed!! (getmetadata (vi, vn), val, vn)
799+ return VarInfo (md, vi . accs)
795800end
801+
796802function set_transformed!! (metadata:: Metadata , val:: Bool , vn:: VarName )
797803 metadata. is_transformed[getidx (metadata, vn)] = val
798804 return metadata
799805end
800806
801807function set_transformed!! (vi:: VarInfo , val:: Bool )
802808 for vn in keys (vi)
803- set_transformed!! (vi, val, vn)
809+ vi = set_transformed!! (vi, val, vn)
804810 end
805811
806812 return vi
@@ -977,7 +983,7 @@ function filter_subsumed(filter_vns, filtered_vns)
977983end
978984
979985@generated function _link!! (
980- :: NamedTuple{metadata_names} , vi, vns :: NamedTuple{vns_names}
986+ :: NamedTuple{metadata_names} , vi, varnames :: NamedTuple{vns_names}
981987) where {metadata_names,vns_names}
982988 expr = Expr (:block )
983989 for f in metadata_names
988994 expr. args,
989995 quote
990996 f_vns = vi. metadata.$ f. vns
991- f_vns = filter_subsumed (vns .$ f, f_vns)
997+ f_vns = filter_subsumed (varnames .$ f, f_vns)
992998 if ! isempty (f_vns)
993999 if ! is_transformed (vi, f_vns[1 ])
9941000 # Iterate over all `f_vns` and transform
@@ -1652,30 +1658,47 @@ end
16521658Push a new random variable `vn` with a sampled value `r` from a distribution `dist` to
16531659the `VarInfo` `vi`, mutating if it makes sense.
16541660"""
1655- function BangBang. push!! (vi:: VarInfo , vn:: VarName , r, dist:: Distribution )
1656- if vi isa UntypedVarInfo
1657- @assert ~ (vn in keys (vi)) " [push!!] attempt to add an existing variable $(getsym (vn)) ($(vn) ) to VarInfo (keys=$(keys (vi)) ) with dist=$dist "
1658- elseif vi isa NTVarInfo
1659- @assert ~ (haskey (vi, vn)) " [push!!] attempt to add an existing variable $(getsym (vn)) ($(vn) ) to NTVarInfo of syms $(syms (vi)) with dist=$dist "
1660- end
1661+ function BangBang. push!! (vi:: VarInfo , vn:: VarName , val, dist:: Distribution )
1662+ @assert ~ (vn in keys (vi)) " [push!!] attempt to add an existing variable $(getsym (vn)) ($(vn) ) to VarInfo (keys=$(keys (vi)) ) with dist=$dist "
1663+ md = push!! (getmetadata (vi, vn), vn, val, dist)
1664+ return VarInfo (md, vi. accs)
1665+ end
16611666
1667+ function BangBang. push!! (vi:: NTVarInfo , vn:: VarName , val, dist:: Distribution )
1668+ @assert ~ (haskey (vi, vn)) " [push!!] attempt to add an existing variable $(getsym (vn)) ($(vn) ) to NTVarInfo of syms $(syms (vi)) with dist=$dist "
16621669 sym = getsym (vn)
1663- if vi isa NTVarInfo && ~ haskey (vi. metadata, sym)
1670+ meta = if ~ haskey (vi. metadata, sym)
16641671 # The NamedTuple doesn't have an entry for this variable, let's add one.
1665- val = tovec (r)
1666- md = Metadata (Dict (vn => 1 ), [vn], [1 : length (val)], val, [dist], BitVector ([false ]))
1667- vi = Accessors. @set vi. metadata[sym] = md
1672+ _new_submetadata (vi, vn, val, dist)
16681673 else
1669- meta = getmetadata (vi, vn)
1670- push! (meta, vn, r, dist)
1674+ push!! (getmetadata (vi, vn), vn, val, dist)
16711675 end
1672-
1676+ vi = Accessors . @set vi . metadata[sym] = meta
16731677 return vi
16741678end
16751679
1676- function Base. push! (vi:: UntypedVectorVarInfo , vn:: VarName , val, args... )
1677- push! (getmetadata (vi, vn), vn, val, args... )
1678- return vi
1680+ """
1681+ _new_submetadata(vi::VarInfo{NamedTuple{Names,SubMetas}}, args...) where {Names,SubMetas}
1682+
1683+ Create a new sub-metadata for an NTVarInfo. The type is chosen by the types of existing
1684+ SubMetas.
1685+ """
1686+ @generated function _new_submetadata (
1687+ vi:: VarInfo{NamedTuple{Names,SubMetas}} , vn, r, dist
1688+ ) where {Names,SubMetas}
1689+ has_vnv = any (s -> s <: VarNamedVector , SubMetas. parameters)
1690+ return if has_vnv
1691+ :(return _new_vnv_submetadata (vn, r, dist))
1692+ else
1693+ :(return _new_metadata_submetadata (vn, r, dist))
1694+ end
1695+ end
1696+
1697+ _new_vnv_submetadata (vn, r, _) = VarNamedVector ([vn], [r])
1698+
1699+ function _new_metadata_submetadata (vn, r, dist)
1700+ val = tovec (r)
1701+ return Metadata (Dict (vn => 1 ), [vn], [1 : length (val)], val, [dist], BitVector ([false ]))
16791702end
16801703
16811704function Base. push! (vi:: UntypedVectorVarInfo , pair:: Pair , args... )
@@ -1700,6 +1723,11 @@ function Base.push!(meta::Metadata, vn, r, dist)
17001723 return meta
17011724end
17021725
1726+ function BangBang. push!! (meta:: Metadata , vn, r, dist)
1727+ push! (meta, vn, r, dist)
1728+ return meta
1729+ end
1730+
17031731function Base. delete! (vi:: VarInfo , vn:: VarName )
17041732 delete! (getmetadata (vi, vn), vn)
17051733 return vi
0 commit comments