Skip to content

Commit 3b8b4a8

Browse files
authored
Improvements to VarNamedVector (#1098)
* Change VNV to use Dict rather than OrderedDict * Change concretisation from map(identity, x) to a comprehension * Improve tighten_types!! and loosen_types!! * Fix use of set_transformed!! * Fix push!! for VarInfos * Change the default element types in VNV to be Union{} * In untyped_vector_varinfo, don't rely on Metadata * Code style * Run formatter * In VNV, use typejoin rather than promote_type * Bump patch version to 0.38.4
1 parent 1fa3109 commit 3b8b4a8

File tree

11 files changed

+208
-117
lines changed

11 files changed

+208
-117
lines changed

HISTORY.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# DynamicPPL Changelog
22

3+
## 0.38.4
4+
5+
Improve performance of VarNamedVector. It should now be very nearly on par with Metadata for all models we've benchmarked on.
6+
37
## 0.38.3
48

59
Add an implementation of `returned(::Model, ::AbstractDict{<:VarName})`.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.38.3"
3+
version = "0.38.4"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

docs/src/api.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,7 @@ DynamicPPL.reset!
414414
DynamicPPL.update!
415415
DynamicPPL.insert!
416416
DynamicPPL.loosen_types!!
417-
DynamicPPL.tighten_types
417+
DynamicPPL.tighten_types!!
418418
```
419419

420420
```@docs

src/contexts/init.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,9 @@ function tilde_assume!!(
180180
end
181181
# Neither of these set the `trans` flag so we have to do it manually if
182182
# necessary.
183-
insert_transformed_value && set_transformed!!(vi, true, vn)
183+
if insert_transformed_value
184+
vi = set_transformed!!(vi, true, vn)
185+
end
184186
# `accumulate_assume!!` wants untransformed values as the second argument.
185187
vi = accumulate_assume!!(vi, x, logjac, vn, dist)
186188
# We always return the untransformed value here, as that will determine

src/debug_utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ add_io_context(io::IO) = IOContext(io, :compact => true, :limit => true)
2727
show_varname(io::IO, varname::VarName) = print(io, varname)
2828
function show_varname(io::IO, varname::Array{<:VarName,N}) where {N}
2929
# Attempt to make the type concrete in case the symbol is shared.
30-
return _show_varname(io, map(identity, varname))
30+
return _show_varname(io, [vn for vn in varname])
3131
end
3232
function _show_varname(io::IO, varname::Array{<:VarName,N}) where {N}
3333
# Print the first and last element of the array.
@@ -407,7 +407,7 @@ julia> @model function demo_incorrect()
407407
end
408408
demo_incorrect (generic function with 2 methods)
409409
410-
julia> # Notice that VarInfo(model_incorrect) evaluates the model, but doesn't actually
410+
julia> # Notice that VarInfo(model_incorrect) evaluates the model, but doesn't actually
411411
# alert us to the issue of `x` being sampled twice.
412412
model = demo_incorrect(); varinfo = VarInfo(model);
413413

src/logdensityfunction.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ box:
4949
- [`getlogprior`](@ref): calculate the log prior in the model space, ignoring
5050
any effects of linking
5151
- [`getloglikelihood`](@ref): calculate the log likelihood (this is unaffected
52-
by linking, since transforms are only applied to random variables)
52+
by linking, since transforms are only applied to random variables)
5353
5454
!!! note
5555
By default, `LogDensityFunction` uses `getlogjoint_internal`, i.e., the
@@ -146,7 +146,7 @@ struct LogDensityFunction{
146146
is_supported(adtype) ||
147147
@warn "The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed."
148148
# Get a set of dummy params to use for prep
149-
x = map(identity, varinfo[:])
149+
x = [val for val in varinfo[:]]
150150
if use_closure(adtype)
151151
prep = DI.prepare_gradient(
152152
LogDensityAt(model, getlogdensity, varinfo), adtype, x
@@ -282,7 +282,7 @@ function LogDensityProblems.logdensity_and_gradient(
282282
) where {M,F,V,AD<:ADTypes.AbstractADType}
283283
f.prep === nothing &&
284284
error("Gradient preparation not available; this should not happen")
285-
x = map(identity, x) # Concretise type
285+
x = [val for val in x] # Concretise type
286286
# Make branching statically inferrable, i.e. type-stable (even if the two
287287
# branches happen to return different types)
288288
return if use_closure(f.adtype)

src/simple_varinfo.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,7 @@ function set_transformed!!(vi::SimpleOrThreadSafeSimple, trans::Bool, ::VarName)
484484
"Individual variables in SimpleVarInfo cannot have different `set_transformed` statuses.",
485485
)
486486
end
487+
return vi
487488
end
488489

489490
is_transformed(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation)

src/test_utils/ad.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ Everything else is optional, and can be categorised into several groups:
197197
1. _How to specify the results to compare against._
198198
199199
Once logp and its gradient has been calculated with the specified `adtype`,
200-
it can optionally be tested for correctness. The exact way this is tested
200+
it can optionally be tested for correctness. The exact way this is tested
201201
is specified in the `test` parameter.
202202
203203
There are several options for this:
@@ -260,7 +260,7 @@ function run_ad(
260260
if isnothing(params)
261261
params = varinfo[:]
262262
end
263-
params = map(identity, params) # Concretise
263+
params = [p for p in params] # Concretise
264264

265265
# Calculate log-density and gradient with the backend of interest
266266
verbose && @info "Running AD on $(model.f) with $(adtype)\n"

src/varinfo.jl

Lines changed: 50 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ function untyped_vector_varinfo(
315315
model::Model,
316316
init_strategy::AbstractInitStrategy=InitFromPrior(),
317317
)
318-
return untyped_vector_varinfo(untyped_varinfo(rng, model, init_strategy))
318+
return last(init!!(rng, model, VarInfo(VarNamedVector()), init_strategy))
319319
end
320320
function untyped_vector_varinfo(
321321
model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()
@@ -789,18 +789,24 @@ function setval!(md::Metadata, val, vn::VarName)
789789
return md.vals[getrange(md, vn)] = tovec(val)
790790
end
791791

792+
function set_transformed!!(vi::NTVarInfo, val::Bool, vn::VarName)
793+
md = set_transformed!!(getmetadata(vi, vn), val, vn)
794+
return Accessors.@set vi.metadata[getsym(vn)] = md
795+
end
796+
792797
function set_transformed!!(vi::VarInfo, val::Bool, vn::VarName)
793-
set_transformed!!(getmetadata(vi, vn), val, vn)
794-
return vi
798+
md = set_transformed!!(getmetadata(vi, vn), val, vn)
799+
return VarInfo(md, vi.accs)
795800
end
801+
796802
function set_transformed!!(metadata::Metadata, val::Bool, vn::VarName)
797803
metadata.is_transformed[getidx(metadata, vn)] = val
798804
return metadata
799805
end
800806

801807
function set_transformed!!(vi::VarInfo, val::Bool)
802808
for vn in keys(vi)
803-
set_transformed!!(vi, val, vn)
809+
vi = set_transformed!!(vi, val, vn)
804810
end
805811

806812
return vi
@@ -977,7 +983,7 @@ function filter_subsumed(filter_vns, filtered_vns)
977983
end
978984

979985
@generated function _link!!(
980-
::NamedTuple{metadata_names}, vi, vns::NamedTuple{vns_names}
986+
::NamedTuple{metadata_names}, vi, varnames::NamedTuple{vns_names}
981987
) where {metadata_names,vns_names}
982988
expr = Expr(:block)
983989
for f in metadata_names
@@ -988,7 +994,7 @@ end
988994
expr.args,
989995
quote
990996
f_vns = vi.metadata.$f.vns
991-
f_vns = filter_subsumed(vns.$f, f_vns)
997+
f_vns = filter_subsumed(varnames.$f, f_vns)
992998
if !isempty(f_vns)
993999
if !is_transformed(vi, f_vns[1])
9941000
# Iterate over all `f_vns` and transform
@@ -1652,30 +1658,47 @@ end
16521658
Push a new random variable `vn` with a sampled value `r` from a distribution `dist` to
16531659
the `VarInfo` `vi`, mutating if it makes sense.
16541660
"""
1655-
function BangBang.push!!(vi::VarInfo, vn::VarName, r, dist::Distribution)
1656-
if vi isa UntypedVarInfo
1657-
@assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist"
1658-
elseif vi isa NTVarInfo
1659-
@assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to NTVarInfo of syms $(syms(vi)) with dist=$dist"
1660-
end
1661+
function BangBang.push!!(vi::VarInfo, vn::VarName, val, dist::Distribution)
1662+
@assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist"
1663+
md = push!!(getmetadata(vi, vn), vn, val, dist)
1664+
return VarInfo(md, vi.accs)
1665+
end
16611666

1667+
function BangBang.push!!(vi::NTVarInfo, vn::VarName, val, dist::Distribution)
1668+
@assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to NTVarInfo of syms $(syms(vi)) with dist=$dist"
16621669
sym = getsym(vn)
1663-
if vi isa NTVarInfo && ~haskey(vi.metadata, sym)
1670+
meta = if ~haskey(vi.metadata, sym)
16641671
# The NamedTuple doesn't have an entry for this variable, let's add one.
1665-
val = tovec(r)
1666-
md = Metadata(Dict(vn => 1), [vn], [1:length(val)], val, [dist], BitVector([false]))
1667-
vi = Accessors.@set vi.metadata[sym] = md
1672+
_new_submetadata(vi, vn, val, dist)
16681673
else
1669-
meta = getmetadata(vi, vn)
1670-
push!(meta, vn, r, dist)
1674+
push!!(getmetadata(vi, vn), vn, val, dist)
16711675
end
1672-
1676+
vi = Accessors.@set vi.metadata[sym] = meta
16731677
return vi
16741678
end
16751679

1676-
function Base.push!(vi::UntypedVectorVarInfo, vn::VarName, val, args...)
1677-
push!(getmetadata(vi, vn), vn, val, args...)
1678-
return vi
1680+
"""
1681+
_new_submetadata(vi::VarInfo{NamedTuple{Names,SubMetas}}, args...) where {Names,SubMetas}
1682+
1683+
Create a new sub-metadata for an NTVarInfo. The type is chosen by the types of existing
1684+
SubMetas.
1685+
"""
1686+
@generated function _new_submetadata(
1687+
vi::VarInfo{NamedTuple{Names,SubMetas}}, vn, r, dist
1688+
) where {Names,SubMetas}
1689+
has_vnv = any(s -> s <: VarNamedVector, SubMetas.parameters)
1690+
return if has_vnv
1691+
:(return _new_vnv_submetadata(vn, r, dist))
1692+
else
1693+
:(return _new_metadata_submetadata(vn, r, dist))
1694+
end
1695+
end
1696+
1697+
_new_vnv_submetadata(vn, r, _) = VarNamedVector([vn], [r])
1698+
1699+
function _new_metadata_submetadata(vn, r, dist)
1700+
val = tovec(r)
1701+
return Metadata(Dict(vn => 1), [vn], [1:length(val)], val, [dist], BitVector([false]))
16791702
end
16801703

16811704
function Base.push!(vi::UntypedVectorVarInfo, pair::Pair, args...)
@@ -1700,6 +1723,11 @@ function Base.push!(meta::Metadata, vn, r, dist)
17001723
return meta
17011724
end
17021725

1726+
function BangBang.push!!(meta::Metadata, vn, r, dist)
1727+
push!(meta, vn, r, dist)
1728+
return meta
1729+
end
1730+
17031731
function Base.delete!(vi::VarInfo, vn::VarName)
17041732
delete!(getmetadata(vi, vn), vn)
17051733
return vi

0 commit comments

Comments
 (0)