diff --git a/HISTORY.md b/HISTORY.md index 54b40b7e9..4b8f5980e 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,9 @@ # DynamicPPL Changelog +## 0.38.4 + +Improve performance of VarNamedVector. It should now be very nearly on par with Metadata for all models we've benchmarked on. + ## 0.38.3 Add an implementation of `returned(::Model, ::AbstractDict{<:VarName})`. diff --git a/Project.toml b/Project.toml index d54f9d1da..0773bbe04 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.38.3" +version = "0.38.4" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/benchmarks/Project.toml b/benchmarks/Project.toml index 22fb89267..0d4e9a654 100644 --- a/benchmarks/Project.toml +++ b/benchmarks/Project.toml @@ -30,4 +30,4 @@ LogDensityProblems = "2.1.2" Mooncake = "0.4" PrettyTables = "3" ReverseDiff = "1.15.3" -StableRNGs = "1" \ No newline at end of file +StableRNGs = "1" diff --git a/benchmarks/benchmarks.jl b/benchmarks/benchmarks.jl index 0cf958cc1..035d8ff49 100644 --- a/benchmarks/benchmarks.jl +++ b/benchmarks/benchmarks.jl @@ -62,6 +62,8 @@ chosen_combinations = [ ("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true), ("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true), ("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true), + ("Smorgasbord", smorgasbord_instance, :typed_vector, :forwarddiff, true), + ("Smorgasbord", smorgasbord_instance, :untyped_vector, :forwarddiff, true), ("Smorgasbord", smorgasbord_instance, :typed, :reversediff, true), ("Smorgasbord", smorgasbord_instance, :typed, :mooncake, true), ("Smorgasbord", smorgasbord_instance, :typed, :enzyme, true), diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index dcb6ac9be..225e40cd8 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -80,6 +80,10 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked:: retvals = model(rng) vns = [VarName{k}() for k in keys(retvals)] SimpleVarInfo{Float64}(Dict(zip(vns, values(retvals)))) + elseif varinfo_choice == :typed_vector + DynamicPPL.typed_vector_varinfo(rng, model) + elseif varinfo_choice == :untyped_vector + DynamicPPL.untyped_vector_varinfo(rng, model) else error("Unknown varinfo choice: $varinfo_choice") end diff --git a/docs/src/api.md b/docs/src/api.md index 31b7d07da..b04bd445d 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -414,7 +414,7 @@ DynamicPPL.reset! DynamicPPL.update! DynamicPPL.insert! DynamicPPL.loosen_types!! -DynamicPPL.tighten_types +DynamicPPL.tighten_types!! ``` ```@docs diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 396e1463f..44dbc5508 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -180,7 +180,9 @@ function tilde_assume!!( end # Neither of these set the `trans` flag so we have to do it manually if # necessary. - insert_transformed_value && set_transformed!!(vi, true, vn) + if insert_transformed_value + vi = set_transformed!!(vi, true, vn) + end # `accumulate_assume!!` wants untransformed values as the second argument. vi = accumulate_assume!!(vi, x, logjac, vn, dist) # We always return the untransformed value here, as that will determine diff --git a/src/debug_utils.jl b/src/debug_utils.jl index 13124e3a7..e8b50a0b7 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -27,7 +27,7 @@ add_io_context(io::IO) = IOContext(io, :compact => true, :limit => true) show_varname(io::IO, varname::VarName) = print(io, varname) function show_varname(io::IO, varname::Array{<:VarName,N}) where {N} # Attempt to make the type concrete in case the symbol is shared. - return _show_varname(io, map(identity, varname)) + return _show_varname(io, [vn for vn in varname]) end function _show_varname(io::IO, varname::Array{<:VarName,N}) where {N} # Print the first and last element of the array. @@ -407,7 +407,7 @@ julia> @model function demo_incorrect() end demo_incorrect (generic function with 2 methods) -julia> # Notice that VarInfo(model_incorrect) evaluates the model, but doesn't actually +julia> # Notice that VarInfo(model_incorrect) evaluates the model, but doesn't actually # alert us to the issue of `x` being sampled twice. model = demo_incorrect(); varinfo = VarInfo(model); diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index e5e6a6dae..3b7b84953 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -49,7 +49,7 @@ box: - [`getlogprior`](@ref): calculate the log prior in the model space, ignoring any effects of linking - [`getloglikelihood`](@ref): calculate the log likelihood (this is unaffected - by linking, since transforms are only applied to random variables) + by linking, since transforms are only applied to random variables) !!! note By default, `LogDensityFunction` uses `getlogjoint_internal`, i.e., the @@ -146,7 +146,7 @@ struct LogDensityFunction{ is_supported(adtype) || @warn "The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed." # Get a set of dummy params to use for prep - x = map(identity, varinfo[:]) + x = [val for val in varinfo[:]] if use_closure(adtype) prep = DI.prepare_gradient( LogDensityAt(model, getlogdensity, varinfo), adtype, x @@ -282,7 +282,7 @@ function LogDensityProblems.logdensity_and_gradient( ) where {M,F,V,AD<:ADTypes.AbstractADType} f.prep === nothing && error("Gradient preparation not available; this should not happen") - x = map(identity, x) # Concretise type + x = [val for val in x] # Concretise type # Make branching statically inferrable, i.e. type-stable (even if the two # branches happen to return different types) return if use_closure(f.adtype) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 2ba25f142..434480be6 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -484,6 +484,7 @@ function set_transformed!!(vi::SimpleOrThreadSafeSimple, trans::Bool, ::VarName) "Individual variables in SimpleVarInfo cannot have different `set_transformed` statuses.", ) end + return vi end is_transformed(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation) diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index 79442fccf..a49ffd18b 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -197,7 +197,7 @@ Everything else is optional, and can be categorised into several groups: 1. _How to specify the results to compare against._ Once logp and its gradient has been calculated with the specified `adtype`, - it can optionally be tested for correctness. The exact way this is tested + it can optionally be tested for correctness. The exact way this is tested is specified in the `test` parameter. There are several options for this: @@ -260,7 +260,7 @@ function run_ad( if isnothing(params) params = varinfo[:] end - params = map(identity, params) # Concretise + params = [p for p in params] # Concretise # Calculate log-density and gradient with the backend of interest verbose && @info "Running AD on $(model.f) with $(adtype)\n" diff --git a/src/varinfo.jl b/src/varinfo.jl index 734bf3db5..a90b81488 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -315,7 +315,7 @@ function untyped_vector_varinfo( model::Model, init_strategy::AbstractInitStrategy=InitFromPrior(), ) - return untyped_vector_varinfo(untyped_varinfo(rng, model, init_strategy)) + return last(init!!(rng, model, VarInfo(VarNamedVector()), init_strategy)) end function untyped_vector_varinfo( model::Model, init_strategy::AbstractInitStrategy=InitFromPrior() @@ -789,10 +789,16 @@ function setval!(md::Metadata, val, vn::VarName) return md.vals[getrange(md, vn)] = tovec(val) end +function set_transformed!!(vi::NTVarInfo, val::Bool, vn::VarName) + md = set_transformed!!(getmetadata(vi, vn), val, vn) + return Accessors.@set vi.metadata[getsym(vn)] = md +end + function set_transformed!!(vi::VarInfo, val::Bool, vn::VarName) - set_transformed!!(getmetadata(vi, vn), val, vn) - return vi + md = set_transformed!!(getmetadata(vi, vn), val, vn) + return VarInfo(md, vi.accs) end + function set_transformed!!(metadata::Metadata, val::Bool, vn::VarName) metadata.is_transformed[getidx(metadata, vn)] = val return metadata @@ -800,7 +806,7 @@ end function set_transformed!!(vi::VarInfo, val::Bool) for vn in keys(vi) - set_transformed!!(vi, val, vn) + vi = set_transformed!!(vi, val, vn) end return vi @@ -977,7 +983,7 @@ function filter_subsumed(filter_vns, filtered_vns) end @generated function _link!!( - ::NamedTuple{metadata_names}, vi, vns::NamedTuple{vns_names} + ::NamedTuple{metadata_names}, vi, varnames::NamedTuple{vns_names} ) where {metadata_names,vns_names} expr = Expr(:block) for f in metadata_names @@ -988,7 +994,7 @@ end expr.args, quote f_vns = vi.metadata.$f.vns - f_vns = filter_subsumed(vns.$f, f_vns) + f_vns = filter_subsumed(varnames.$f, f_vns) if !isempty(f_vns) if !is_transformed(vi, f_vns[1]) # Iterate over all `f_vns` and transform @@ -1652,30 +1658,47 @@ end Push a new random variable `vn` with a sampled value `r` from a distribution `dist` to the `VarInfo` `vi`, mutating if it makes sense. """ -function BangBang.push!!(vi::VarInfo, vn::VarName, r, dist::Distribution) - if vi isa UntypedVarInfo - @assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist" - elseif vi isa NTVarInfo - @assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to NTVarInfo of syms $(syms(vi)) with dist=$dist" - end +function BangBang.push!!(vi::VarInfo, vn::VarName, val, dist::Distribution) + @assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist" + md = push!!(getmetadata(vi, vn), vn, val, dist) + return VarInfo(md, vi.accs) +end +function BangBang.push!!(vi::NTVarInfo, vn::VarName, val, dist::Distribution) + @assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to NTVarInfo of syms $(syms(vi)) with dist=$dist" sym = getsym(vn) - if vi isa NTVarInfo && ~haskey(vi.metadata, sym) + meta = if ~haskey(vi.metadata, sym) # The NamedTuple doesn't have an entry for this variable, let's add one. - val = tovec(r) - md = Metadata(Dict(vn => 1), [vn], [1:length(val)], val, [dist], BitVector([false])) - vi = Accessors.@set vi.metadata[sym] = md + _new_submetadata(vi, vn, val, dist) else - meta = getmetadata(vi, vn) - push!(meta, vn, r, dist) + push!!(getmetadata(vi, vn), vn, val, dist) end - + vi = Accessors.@set vi.metadata[sym] = meta return vi end -function Base.push!(vi::UntypedVectorVarInfo, vn::VarName, val, args...) - push!(getmetadata(vi, vn), vn, val, args...) - return vi +""" + _new_submetadata(vi::VarInfo{NamedTuple{Names,SubMetas}}, args...) where {Names,SubMetas} + +Create a new sub-metadata for an NTVarInfo. The type is chosen by the types of existing +SubMetas. +""" +@generated function _new_submetadata( + vi::VarInfo{NamedTuple{Names,SubMetas}}, vn, r, dist +) where {Names,SubMetas} + has_vnv = any(s -> s <: VarNamedVector, SubMetas.parameters) + return if has_vnv + :(return _new_vnv_submetadata(vn, r, dist)) + else + :(return _new_metadata_submetadata(vn, r, dist)) + end +end + +_new_vnv_submetadata(vn, r, _) = VarNamedVector([vn], [r]) + +function _new_metadata_submetadata(vn, r, dist) + val = tovec(r) + return Metadata(Dict(vn => 1), [vn], [1:length(val)], val, [dist], BitVector([false])) end function Base.push!(vi::UntypedVectorVarInfo, pair::Pair, args...) @@ -1700,6 +1723,11 @@ function Base.push!(meta::Metadata, vn, r, dist) return meta end +function BangBang.push!!(meta::Metadata, vn, r, dist) + push!(meta, vn, r, dist) + return meta +end + function Base.delete!(vi::VarInfo, vn::VarName) delete!(getmetadata(vi, vn), vn) return vi diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index 4b2791d19..2c66e1245 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -1,3 +1,5 @@ +const CHECK_CONSISTENCY_DEFAULT = true + """ VarNamedVector @@ -40,6 +42,11 @@ contents of the internal storage quickly with `getindex_internal(vnv, :)`. The o of `VarNamedVector` are mostly used to keep track of which part of the internal storage belongs to which `VarName`. +All constructors accept a keyword argument `check_consistency::Bool=true` that controls +whether to run checks like the number of values matching the number of variables. Some of +these checks can be expensive, so if you are confident in the input, you may want to turn +`check_consistency` off for performance. + # Fields $(FIELDS) @@ -49,13 +56,13 @@ $(FIELDS) The values for different variables are internally all stored in a single vector. For instance, ```jldoctest varnamedvector-struct -julia> using DynamicPPL: ReshapeTransform, VarNamedVector, @varname, setindex!, update!, getindex_internal +julia> using DynamicPPL: ReshapeTransform, VarNamedVector, @varname, setindex!!, update!!, getindex_internal julia> vnv = VarNamedVector(); -julia> setindex!(vnv, [0.0, 0.0, 0.0, 0.0], @varname(x)); +julia> vnv = setindex!!(vnv, [0.0, 0.0, 0.0, 0.0], @varname(x)); -julia> setindex!(vnv, reshape(1:6, (2,3)), @varname(y)); +julia> vnv = setindex!!(vnv, reshape(1:6, (2,3)), @varname(y)); julia> vnv.vals 10-element Vector{Real}: @@ -84,7 +91,7 @@ If a variable is updated with a new value that is of a smaller dimension than th value, rather than resizing `vnv.vals`, some elements in `vnv.vals` are marked as inactive. ```jldoctest varnamedvector-struct -julia> update!(vnv, [46.0, 48.0], @varname(x)) +julia> vnv = update!!(vnv, [46.0, 48.0], @varname(x)); julia> vnv.vals 10-element Vector{Real}: @@ -100,7 +107,7 @@ julia> vnv.vals 6 julia> println(vnv.num_inactive); -OrderedDict(1 => 2) +Dict(1 => 2) ``` This helps avoid unnecessary memory allocations for values that repeatedly change dimension. @@ -126,17 +133,17 @@ julia> getindex_internal(vnv, :) ``` """ struct VarNamedVector{ - K<:VarName,V,TVN<:AbstractVector{K},TVal<:AbstractVector{V},TTrans<:AbstractVector + K<:VarName,V,T,KVec<:AbstractVector{K},VVec<:AbstractVector{V},TVec<:AbstractVector{T} } """ mapping from a `VarName` to its integer index in `varnames`, `ranges` and `transforms` """ - varname_to_index::OrderedDict{K,Int} + varname_to_index::Dict{K,Int} """ vector of `VarNames` for the variables, where `varnames[varname_to_index[vn]] == vn` """ - varnames::TVN # AbstractVector{<:VarName} + varnames::KVec """ vector of index ranges in `vals` corresponding to `varnames`; each `VarName` `vn` has @@ -149,14 +156,14 @@ struct VarNamedVector{ vector of values of all variables; the value(s) of `vn` is/are `vals[ranges[varname_to_index[vn]]]` """ - vals::TVal # AbstractVector{<:Real} + vals::VVec """ vector of transformations, so that `transforms[varname_to_index[vn]]` is a callable that transforms the value of `vn` back to its original space, undoing any linking and vectorisation """ - transforms::TTrans + transforms::TVec """ vector of booleans indicating whether a variable has been explicitly transformed to @@ -175,79 +182,82 @@ struct VarNamedVector{ Inactive entries always come after the last active entry for the given variable. See the extended help with `??VarNamedVector` for more details. """ - num_inactive::OrderedDict{Int,Int} + num_inactive::Dict{Int,Int} function VarNamedVector( varname_to_index, - varnames::TVN, + varnames::KVec, ranges, - vals::TVal, - transforms::TTrans, + vals::VVec, + transforms::TVec, is_unconstrained=fill!(BitVector(undef, length(varnames)), 0), - num_inactive=OrderedDict{Int,Int}(), - ) where {K,V,TVN<:AbstractVector{K},TVal<:AbstractVector{V},TTrans<:AbstractVector} - if length(varnames) != length(ranges) || - length(varnames) != length(transforms) || - length(varnames) != length(is_unconstrained) || - length(varnames) != length(varname_to_index) - msg = ( - "Inputs to VarNamedVector have inconsistent lengths. Got lengths varnames: " * - "$(length(varnames)), ranges: " * - "$(length(ranges)), " * - "transforms: $(length(transforms)), " * - "is_unconstrained: $(length(is_unconstrained)), " * - "varname_to_index: $(length(varname_to_index))." - ) - throw(ArgumentError(msg)) - end + num_inactive=Dict{Int,Int}(); + check_consistency::Bool=CHECK_CONSISTENCY_DEFAULT, + ) where {K,V,T,KVec<:AbstractVector{K},VVec<:AbstractVector{V},TVec<:AbstractVector{T}} + if check_consistency + if length(varnames) != length(ranges) || + length(varnames) != length(transforms) || + length(varnames) != length(is_unconstrained) || + length(varnames) != length(varname_to_index) + msg = ( + "Inputs to VarNamedVector have inconsistent lengths. " * + "Got lengths varnames: $(length(varnames)), " * + "ranges: $(length(ranges)), " * + "transforms: $(length(transforms)), " * + "is_unconstrained: $(length(is_unconstrained)), " * + "varname_to_index: $(length(varname_to_index))." + ) + throw(ArgumentError(msg)) + end - num_vals = mapreduce(length, (+), ranges; init=0) + sum(values(num_inactive)) - if num_vals != length(vals) - msg = ( - "The total number of elements in `vals` ($(length(vals))) does not match " * - "the sum of the lengths of the ranges and the number of inactive entries " * - "($(num_vals))." - ) - throw(ArgumentError(msg)) - end + num_vals = mapreduce(length, (+), ranges; init=0) + sum(values(num_inactive)) + if num_vals != length(vals) + msg = ( + "The total number of elements in `vals` ($(length(vals))) does not " * + "match the sum of the lengths of the ranges and the number of " * + "inactive entries ($(num_vals))." + ) + throw(ArgumentError(msg)) + end - if Set(values(varname_to_index)) != Set(axes(varnames, 1)) - msg = ( - "The set of values of `varname_to_index` is not the set of valid indices " * - "for `varnames`." - ) - throw(ArgumentError(msg)) - end + if Set(values(varname_to_index)) != Set(axes(varnames, 1)) + msg = ( + "The set of values of `varname_to_index` is not the set of valid " * + "indices for `varnames`." + ) + throw(ArgumentError(msg)) + end - if !issubset(Set(keys(num_inactive)), Set(values(varname_to_index))) - msg = ( - "The keys of `num_inactive` are not a subset of the values of " * - "`varname_to_index`." - ) - throw(ArgumentError(msg)) - end + if !issubset(Set(keys(num_inactive)), Set(values(varname_to_index))) + msg = ( + "The keys of `num_inactive` are not a subset of the values of " * + "`varname_to_index`." + ) + throw(ArgumentError(msg)) + end - # Check that the varnames don't overlap. The time cost is quadratic in number of - # variables. If this ever becomes an issue, we should be able to go down to at least - # N log N by sorting based on subsumes-order. - for vn1 in keys(varname_to_index) - for vn2 in keys(varname_to_index) - vn1 === vn2 && continue - if subsumes(vn1, vn2) - msg = ( - "Variables in a VarNamedVector should not subsume each other, " * - "but $vn1 subsumes $vn2, i.e. $vn2 describes a subrange of $vn1." - ) - throw(ArgumentError(msg)) + # Check that the varnames don't overlap. The time cost is quadratic in number of + # variables. If this ever becomes an issue, we should be able to go down to at + # least N log N by sorting based on subsumes-order. + for vn1 in keys(varname_to_index) + for vn2 in keys(varname_to_index) + vn1 === vn2 && continue + if subsumes(vn1, vn2) + msg = ( + "Variables in a VarNamedVector should not subsume each " * + "other, but $vn1 subsumes $vn2." + ) + throw(ArgumentError(msg)) + end end end - end - # We could also have a test to check that the ranges don't overlap, but that sounds - # unlikely to occur, and implementing it in linear time would require a tiny bit of - # thought. + # We could also have a test to check that the ranges don't overlap, but that + # sounds unlikely to occur, and implementing it in linear time would require a + # tiny bit of thought. + end - return new{K,V,TVN,TVal,TTrans}( + return new{K,V,T,KVec,VVec,TVec}( varname_to_index, varnames, ranges, @@ -259,36 +269,41 @@ struct VarNamedVector{ end end -function VarNamedVector{K,V}() where {K,V} - return VarNamedVector(OrderedDict{K,Int}(), K[], UnitRange{Int}[], V[], Any[]) +function VarNamedVector{K,V,T}() where {K,V,T} + return VarNamedVector( + Dict{K,Int}(), K[], UnitRange{Int}[], V[], T[]; check_consistency=false + ) end -# TODO(mhauru) I would like for this to be VarNamedVector(Union{}, Union{}). Simlarly the -# transform vector type above could then be Union{}[]. This would allow expanding the -# VarName and element types only as necessary, which would help keep them concrete. However, -# making that change here opens some other cans of worms related to how VarInfo uses -# BangBang, that I don't want to deal with right now. -VarNamedVector() = VarNamedVector{VarName,Real}() -VarNamedVector(xs::Pair...) = VarNamedVector(OrderedDict(xs...)) -VarNamedVector(x::AbstractDict) = VarNamedVector(keys(x), values(x)) -function VarNamedVector(varnames, vals) - return VarNamedVector(collect_maybe(varnames), collect_maybe(vals)) +VarNamedVector() = VarNamedVector{Union{},Union{},Union{}}() +function VarNamedVector(xs::Pair...; check_consistency=CHECK_CONSISTENCY_DEFAULT) + return VarNamedVector(OrderedDict(xs...); check_consistency=check_consistency) +end +function VarNamedVector(x::AbstractDict; check_consistency=CHECK_CONSISTENCY_DEFAULT) + return VarNamedVector(keys(x), values(x); check_consistency=check_consistency) +end +function VarNamedVector(varnames, vals; check_consistency=CHECK_CONSISTENCY_DEFAULT) + return VarNamedVector( + collect_maybe(varnames), collect_maybe(vals); check_consistency=check_consistency + ) end function VarNamedVector( varnames::AbstractVector, orig_vals::AbstractVector, - transforms=fill(identity, length(varnames)), + transforms=fill(identity, length(varnames)); + check_consistency=CHECK_CONSISTENCY_DEFAULT, ) + if isempty(varnames) && isempty(orig_vals) && isempty(transforms) + return VarNamedVector{eltype(varnames),eltype(orig_vals),eltype(transforms)}() + end # Convert `vals` into a vector of vectors. vals_vecs = map(tovec, orig_vals) transforms = map( (t, val) -> _compose_no_identity(t, from_vec_transform(val)), transforms, orig_vals ) - # TODO: Is this really the way to do this? - if !(eltype(varnames) <: VarName) - varnames = convert(Vector{VarName}, varnames) - end - varname_to_index = OrderedDict{eltype(varnames),Int}( + # Make `varnames` have as concrete an element type as possible. + varnames = [v for v in varnames] + varname_to_index = Dict{eltype(varnames),Int}( vn => i for (i, vn) in enumerate(varnames) ) vals = reduce(vcat, vals_vecs) @@ -301,7 +316,19 @@ function VarNamedVector( offset = r[end] end - return VarNamedVector(varname_to_index, varnames, ranges, vals, transforms) + # Passing on check_consistency here seems wasteful. Wouldn't it be faster to do a + # lightweight check of the arguments of this function, and rely on the correctness + # of what this function does? However, the expensive check is whether any variable + # subsumes another, and that's the same regardless of where it's done, so the + # optimisation would be quite pointless. + return VarNamedVector( + varname_to_index, + varnames, + ranges, + vals, + transforms; + check_consistency=check_consistency, + ) end function ==(vnv_left::VarNamedVector, vnv_right::VarNamedVector) @@ -314,6 +341,12 @@ function ==(vnv_left::VarNamedVector, vnv_right::VarNamedVector) vnv_left.num_inactive == vnv_right.num_inactive end +function is_concretely_typed(vnv::VarNamedVector) + return isconcretetype(eltype(vnv.varnames)) && + isconcretetype(eltype(vnv.vals)) && + isconcretetype(eltype(vnv.transforms)) +end + getidx(vnv::VarNamedVector, vn::VarName) = vnv.varname_to_index[vn] getrange(vnv::VarNamedVector, idx::Int) = vnv.ranges[idx] @@ -531,7 +564,7 @@ to be the default vectorisation transform. This undoes any possible linking. ```jldoctest varnamedvector-reset julia> using DynamicPPL: VarNamedVector, @varname, reset! -julia> vnv = VarNamedVector(); +julia> vnv = VarNamedVector{VarName,Any,Any}(); julia> vnv[@varname(x)] = reshape(1:9, (3, 3)); @@ -766,11 +799,16 @@ function update_internal!( return nothing end -function BangBang.push!(vnv::VarNamedVector, vn, val, dist) +function Base.push!(vnv::VarNamedVector, vn, val, dist) f = from_vec_transform(dist) return setindex_internal!(vnv, tovec(val), vn, f) end +function BangBang.push!!(vnv::VarNamedVector, vn, val, dist) + f = from_vec_transform(dist) + return setindex_internal!!(vnv, tovec(val), vn, f) +end + # BangBang versions of the above functions. # The only difference is that update_internal!! and insert_internal!! check whether the # container types of the VarNamedVector vector need to be expanded to accommodate the new @@ -779,7 +817,7 @@ end # with every ! call replaced with a !! call. """ - loosen_types!!(vnv::VarNamedVector{K,V,TVN,TVal,TTrans}, ::Type{KNew}, ::Type{TransNew}) + loosen_types!!(vnv::VarNamedVector, ::Type{KNew}, ::Type{VNew}, ::Type{TNew}) Loosen the types of `vnv` to allow varname type `KNew` and transformation type `TransNew`. @@ -790,7 +828,7 @@ transformations of type `TransNew` can be pushed to it. Some of the underlying s shared between `vnv` and the return value, and thus mutating one may affect the other. # See also -[`tighten_types`](@ref) +[`tighten_types!!`](@ref) # Examples @@ -805,7 +843,9 @@ julia> setindex_internal!(vnv, collect(1:4), @varname(y), y_trans) ERROR: MethodError: Cannot `convert` an object of type [...] -julia> vnv_loose = DynamicPPL.loosen_types!!(vnv, typeof(@varname(y)), typeof(y_trans)); +julia> vnv_loose = DynamicPPL.loosen_types!!( + vnv, typeof(@varname(y)), Float64, typeof(y_trans) + ); julia> setindex_internal!(vnv_loose, collect(1:4), @varname(y), y_trans) @@ -816,39 +856,63 @@ julia> vnv_loose[@varname(y)] ``` """ function loosen_types!!( - vnv::VarNamedVector, ::Type{KNew}, ::Type{TransNew} -) where {KNew,TransNew} + vnv::VarNamedVector, ::Type{KNew}, ::Type{VNew}, ::Type{TNew} +) where {KNew,VNew,TNew} K = eltype(vnv.varnames) - Trans = eltype(vnv.transforms) - if KNew <: K && TransNew <: Trans + V = eltype(vnv.vals) + T = eltype(vnv.transforms) + if KNew <: K && VNew <: V && TNew <: T return vnv else - vn_type = promote_type(K, KNew) - transform_type = promote_type(Trans, TransNew) - return VarNamedVector( - OrderedDict{vn_type,Int}(vnv.varname_to_index), - Vector{vn_type}(vnv.varnames), - vnv.ranges, - vnv.vals, - Vector{transform_type}(vnv.transforms), - vnv.is_unconstrained, - vnv.num_inactive, - ) + # We could use promote_type here, instead of typejoin. However, that would e.g. + # cause Ints to be converted to Float64s, since + # promote_type(Int, Float64) == Float64, which can cause problems. See + # https://github.com/TuringLang/DynamicPPL.jl/pull/1098#discussion_r2472636188. + # Base.promote_typejoin would be like typejoin, but creates Unions out of Nothing + # and Missing, rather than falling back on Any. However, it's not exported. + vn_type = typejoin(K, KNew) + val_type = typejoin(V, VNew) + transform_type = typejoin(T, TNew) + # This function would work the same way if the first if statement a few lines above + # was skipped, and we only checked for the below condition. However, the first one + # is constant propagated away at compile time (at least on Julia v1.11.7), whereas + # this one isn't. Hence we keep both for performance. + return if vn_type == K && val_type == V && transform_type == T + vnv + elseif isempty(vnv) + VarNamedVector(vn_type[], val_type[], transform_type[]) + else + # TODO(mhauru) We allow a `vnv` to have any AbstractVector type as its vals, but + # then here always revert to Vector. + VarNamedVector( + Dict{vn_type,Int}(vnv.varname_to_index), + Vector{vn_type}(vnv.varnames), + vnv.ranges, + Vector{val_type}(vnv.vals), + Vector{transform_type}(vnv.transforms), + vnv.is_unconstrained, + vnv.num_inactive; + check_consistency=false, + ) + end end end """ - tighten_types(vnv::VarNamedVector) + tighten_types!!(vnv::VarNamedVector) -Return a copy of `vnv` with the most concrete types possible. +Return a `VarNamedVector` like `vnv` with the most concrete types possible. + +This function either returns `vnv` itself or new `VarNamedVector` with the same values in +it, but with the element types of various containers made as concrete as possible. For instance, if `vnv` has its vector of transforms have eltype `Any`, but all the transforms are actually identity transformations, this function will return a new `VarNamedVector` with the transforms vector having eltype `typeof(identity)`. -This is a lot like the reverse of [`loosen_types!!`](@ref), but with two notable -differences: Unlike `loosen_types!!`, this function does not mutate `vnv`; it also changes -not only the key and transform eltypes, but also the values eltype. +This is a lot like the reverse of [`loosen_types!!`](@ref). Like with `loosen_types!!`, the +return value may share some of its underlying storage with `vnv`, and thus mutating one may +affect the other. # See also [`loosen_types!!`](@ref) @@ -858,9 +922,9 @@ not only the key and transform eltypes, but also the values eltype. ```jldoctest varnamedvector-tighten-types julia> using DynamicPPL: VarNamedVector, @varname, loosen_types!!, setindex_internal! -julia> vnv = VarNamedVector(); +julia> vnv = VarNamedVector(@varname(x) => Real[23], @varname(y) => randn(2,2)); -julia> setindex!(vnv, [23], @varname(x)) +julia> vnv = delete!(vnv, @varname(y)); julia> eltype(vnv) Real @@ -869,7 +933,7 @@ julia> vnv.transforms 1-element Vector{Any}: identity (generic function with 1 method) -julia> vnv_tight = DynamicPPL.tighten_types(vnv); +julia> vnv_tight = DynamicPPL.tighten_types!!(vnv); julia> eltype(vnv_tight) == Int true @@ -879,16 +943,24 @@ julia> vnv_tight.transforms identity (generic function with 1 method) ``` """ -function tighten_types(vnv::VarNamedVector) - return VarNamedVector( - OrderedDict(vnv.varname_to_index...), - map(identity, vnv.varnames), - copy(vnv.ranges), - map(identity, vnv.vals), - map(identity, vnv.transforms), - copy(vnv.is_unconstrained), - copy(vnv.num_inactive), - ) +function tighten_types!!(vnv::VarNamedVector) + return if is_concretely_typed(vnv) + # There can not be anything to tighten, so short-circuit. + vnv + elseif isempty(vnv) + VarNamedVector() + else + VarNamedVector( + Dict(vnv.varname_to_index...), + [x for x in vnv.varnames], + vnv.ranges, + [x for x in vnv.vals], + [x for x in vnv.transforms], + vnv.is_unconstrained, + vnv.num_inactive; + check_consistency=false, + ) + end end function BangBang.setindex!!(vnv::VarNamedVector, val, vn::VarName) @@ -940,18 +1012,22 @@ function setindex_internal!!( end end -function insert_internal!!(vnv::VarNamedVector, val, vn::VarName, transform=nothing) +function insert_internal!!( + vnv::VarNamedVector, val::AbstractVector, vn::VarName, transform=nothing +) if transform === nothing transform = identity end - vnv = loosen_types!!(vnv, typeof(vn), typeof(transform)) + vnv = loosen_types!!(vnv, typeof(vn), eltype(val), typeof(transform)) insert_internal!(vnv, val, vn, transform) return vnv end -function update_internal!!(vnv::VarNamedVector, val, vn::VarName, transform=nothing) +function update_internal!!( + vnv::VarNamedVector, val::AbstractVector, vn::VarName, transform=nothing +) transform_resolved = transform === nothing ? gettransform(vnv, vn) : transform - vnv = loosen_types!!(vnv, typeof(vn), typeof(transform_resolved)) + vnv = loosen_types!!(vnv, typeof(vn), eltype(val), typeof(transform_resolved)) update_internal!(vnv, val, vn, transform) return vnv end @@ -1041,6 +1117,14 @@ julia> unflatten(vnv, vnv[:]) == vnv true """ function unflatten(vnv::VarNamedVector, vals::AbstractVector) + if length(vals) != vector_length(vnv) + throw( + ArgumentError( + "Length of `vals` ($(length(vals))) does not match the length of " * + "`vnv` ($(vector_length(vnv))).", + ), + ) + end new_ranges = deepcopy(vnv.ranges) recontiguify_ranges!(new_ranges) return VarNamedVector( @@ -1049,7 +1133,8 @@ function unflatten(vnv::VarNamedVector, vals::AbstractVector) new_ranges, vals, vnv.transforms, - vnv.is_unconstrained, + vnv.is_unconstrained; + check_consistency=false, ) end @@ -1063,15 +1148,41 @@ function Base.merge(left_vnv::VarNamedVector, right_vnv::VarNamedVector) vns_right = right_vnv.varnames vns_both = union(vns_left, vns_right) + # Check that varnames do not subsume each other. + for vn_left in vns_left + for vn_right in vns_right + vn_left == vn_right && continue + # TODO(mhauru) Subsumation doesn't actually need to be a showstopper. For + # instance, if right has a value for `x` and left has a value for `x[1]`, then + # right will take precedence anyway, and we could merge. However, that requires + # some extra logic that hasn't been done yet. + if subsumes(vn_left, vn_right) + throw( + ArgumentError( + "Cannot merge VarNamedVectors: variable name $vn_left " * + "subsumes $vn_right.", + ), + ) + elseif subsumes(vn_right, vn_left) + throw( + ArgumentError( + "Cannot merge VarNamedVectors: variable name $vn_right " * + "subsumes $vn_left.", + ), + ) + end + end + end + # Determine `eltype` of `vals`. T_left = eltype(left_vnv.vals) T_right = eltype(right_vnv.vals) - T = promote_type(T_left, T_right) + T = typejoin(T_left, T_right) # Determine `eltype` of `varnames`. V_left = eltype(left_vnv.varnames) V_right = eltype(right_vnv.varnames) - V = promote_type(V_left, V_right) + V = typejoin(V_left, V_right) if !(V <: VarName) V = VarName end @@ -1079,10 +1190,10 @@ function Base.merge(left_vnv::VarNamedVector, right_vnv::VarNamedVector) # Determine `eltype` of `transforms`. F_left = eltype(left_vnv.transforms) F_right = eltype(right_vnv.transforms) - F = promote_type(F_left, F_right) + F = typejoin(F_left, F_right) # Allocate. - varname_to_index = OrderedDict{V,Int}() + varname_to_index = Dict{V,Int}() ranges = UnitRange{Int}[] vals = T[] transforms = F[] @@ -1117,7 +1228,13 @@ function Base.merge(left_vnv::VarNamedVector, right_vnv::VarNamedVector) end return VarNamedVector( - varname_to_index, vns_both, ranges, vals, transforms, is_unconstrained + varname_to_index, + vns_both, + ranges, + vals, + transforms, + is_unconstrained; + check_consistency=false, ) end @@ -1145,7 +1262,6 @@ julia> subset(vnv, [@varname(x[2])]) == VarNamedVector(@varname(x[2]) => [2.0]) true """ function subset(vnv::VarNamedVector, vns_given::AbstractVector{<:VarName}) - # NOTE: This does not specialize types when possible. vnv_new = similar(vnv) # Return early if possible. isempty(vnv) && return vnv_new @@ -1157,7 +1273,7 @@ function subset(vnv::VarNamedVector, vns_given::AbstractVector{<:VarName}) end end - return vnv_new + return tighten_types!!(vnv_new) end """ @@ -1193,7 +1309,8 @@ function Base.similar(vnv::VarNamedVector) similar(vnv.vals, 0), similar(vnv.transforms, 0), BitVector(), - empty(vnv.num_inactive), + empty(vnv.num_inactive); + check_consistency=false, ) end @@ -1355,7 +1472,7 @@ true """ function group_by_symbol(vnv::VarNamedVector) symbols = unique(map(getsym, vnv.varnames)) - nt_vals = map(s -> tighten_types(subset(vnv, [VarName{s}()])), symbols) + nt_vals = map(s -> tighten_types!!(subset(vnv, [VarName{s}()])), symbols) return OrderedDict(zip(symbols, nt_vals)) end @@ -1433,6 +1550,16 @@ function Base.delete!(vnv::VarNamedVector, vn::VarName) return vnv end +""" + delete!!(vnv::VarNamedVector, vn::VarName) + +Like `delete!!`, but tightens the element types of the returned `VarNamedVector`. + +# See also: +[`tighten_types!!`](@ref) +""" +BangBang.delete!!(vnv::VarNamedVector, vn::VarName) = tighten_types!!(delete!(vnv, vn)) + """ values_as(vnv::VarNamedVector[, T]) diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl index 3fd76ffe2..b764d517b 100644 --- a/test/varnamedvector.jl +++ b/test/varnamedvector.jl @@ -148,10 +148,10 @@ end # Empty. vnv = DynamicPPL.VarNamedVector() @test isempty(vnv) - @test eltype(vnv) == Real + @test eltype(vnv) == Union{} # Empty with types. - vnv = DynamicPPL.VarNamedVector{VarName,Float64}() + vnv = DynamicPPL.VarNamedVector{VarName,Float64,typeof(identity)}() @test isempty(vnv) @test eltype(vnv) == Float64 end @@ -369,13 +369,17 @@ end # Explicitly setting the transformation. increment(x) = x .+ 10 vnv = deepcopy(vnv_base) - vnv = DynamicPPL.loosen_types!!(vnv, typeof(vn_left), typeof(increment)) + vnv = DynamicPPL.loosen_types!!( + vnv, typeof(vn_left), eltype(vnv), typeof(increment) + ) DynamicPPL.setindex_internal!( vnv, to_vec_left(val_left .+ 100), vn_left, increment ) @test vnv[vn_left] == to_vec_left(val_left .+ 110) - vnv = DynamicPPL.loosen_types!!(vnv, typeof(vn_right), typeof(increment)) + vnv = DynamicPPL.loosen_types!!( + vnv, typeof(vn_right), eltype(vnv), typeof(increment) + ) DynamicPPL.setindex_internal!( vnv, to_vec_right(val_right .+ 100), vn_right, increment )