Skip to content

Commit 262d732

Browse files
committed
Merge remote-tracking branch 'origin/main' into breaking
2 parents 8c3d30f + 1b159a6 commit 262d732

File tree

5 files changed

+201
-27
lines changed

5 files changed

+201
-27
lines changed

HISTORY.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22

33
## 0.39.0
44

5+
## 0.38.1
6+
7+
Added `from_linked_vec_transform` and `from_vec_transform` methods for `ProductNamedTupleDistribution`.
8+
This patch allows sampling from `ProductNamedTupleDistribution` in DynamicPPL models.
9+
510
## 0.38.0
611

712
### Breaking changes

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ AbstractMCMC = "5"
5151
AbstractPPL = "0.13.1"
5252
Accessors = "0.1"
5353
BangBang = "0.4.1"
54-
Bijectors = "0.13.18, 0.14, 0.15"
54+
Bijectors = "0.15.11"
5555
ChainRulesCore = "1"
5656
Chairmarks = "1.3.1"
5757
Compat = "4"

src/contexts/init.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ struct InitFromUniform{T<:AbstractFloat} <: AbstractInitStrategy
6161
end
6262
function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::InitFromUniform)
6363
b = Bijectors.bijector(dist)
64-
sz = Bijectors.output_size(b, size(dist))
64+
sz = Bijectors.output_size(b, dist)
6565
y = u.lower .+ ((u.upper - u.lower) .* rand(rng, sz...))
6666
b_inv = Bijectors.inverse(b)
6767
x = b_inv(y)
@@ -166,12 +166,11 @@ function tilde_assume!!(
166166
# is_transformed(vi) returns true if vi is nonempty and all variables in vi
167167
# are linked.
168168
insert_transformed_value = in_varinfo ? is_transformed(vi, vn) : is_transformed(vi)
169-
f = if insert_transformed_value
170-
link_transform(dist)
169+
y, logjac = if insert_transformed_value
170+
with_logabsdet_jacobian(link_transform(dist), x)
171171
else
172-
identity
172+
x, zero(LogProbType)
173173
end
174-
y, logjac = with_logabsdet_jacobian(f, x)
175174
# Add the new value to the VarInfo. `push!!` errors if the value already
176175
# exists, hence the need for setindex!!.
177176
if in_varinfo

src/utils.jl

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,6 @@ Return the transformation from the vector representation of a realization of siz
354354
original representation.
355355
"""
356356
from_vec_transform_for_size(sz::Tuple) = ReshapeTransform(sz)
357-
# TODO(mhauru) Is the below used? If not, this function can be removed.
358357
from_vec_transform_for_size(::Tuple{<:Any}) = identity
359358

360359
"""
@@ -367,6 +366,60 @@ from_vec_transform(dist::Distribution) = from_vec_transform_for_size(size(dist))
367366
from_vec_transform(::UnivariateDistribution) = UnwrapSingletonTransform()
368367
from_vec_transform(dist::LKJCholesky) = ToChol(dist.uplo) ReshapeTransform(size(dist))
369368

369+
struct ProductNamedTupleUnvecTransform{names,T<:NamedTuple{names}}
370+
dists::T
371+
# The `i`-th input range corresponds to the segment of the input vector
372+
# that belongs to the `i`-th distribution.
373+
input_ranges::Vector{UnitRange}
374+
function ProductNamedTupleUnvecTransform(
375+
d::Distributions.ProductNamedTupleDistribution{names}
376+
) where {names}
377+
offset = 1
378+
input_ranges = UnitRange[]
379+
for name in names
380+
this_dist = d.dists[name]
381+
this_name_size = _input_length(from_vec_transform(this_dist))
382+
push!(input_ranges, offset:(offset + this_name_size - 1))
383+
offset += this_name_size
384+
end
385+
return new{names,typeof(d.dists)}(d.dists, input_ranges)
386+
end
387+
end
388+
389+
@generated function (trf::ProductNamedTupleUnvecTransform{names})(
390+
x::AbstractVector
391+
) where {names}
392+
expr = Expr(:tuple)
393+
for (i, name) in enumerate(names)
394+
push!(
395+
expr.args,
396+
:($name = from_vec_transform(trf.dists.$name)(x[trf.input_ranges[$i]])),
397+
)
398+
end
399+
return expr
400+
end
401+
402+
function from_vec_transform(dist::Distributions.ProductNamedTupleDistribution)
403+
return ProductNamedTupleUnvecTransform(dist)
404+
end
405+
function Bijectors.with_logabsdet_jacobian(f::ProductNamedTupleUnvecTransform, x)
406+
return f(x), zero(LogProbType)
407+
end
408+
409+
# This function returns the length of the vector that the function from_vec_transform
410+
# expects. This helps us determine which segment of a concatenated vector belongs to which
411+
# variable.
412+
_input_length(from_vec_trfm::UnwrapSingletonTransform) = 1
413+
_input_length(from_vec_trfm::ReshapeTransform) = prod(from_vec_trfm.output_size)
414+
function _input_length(trfm::ProductNamedTupleUnvecTransform)
415+
return sum(_input_length from_vec_transform, values(trfm.dists))
416+
end
417+
function _input_length(
418+
c::ComposedFunction{<:DynamicPPL.ToChol,<:DynamicPPL.ReshapeTransform}
419+
)
420+
return _input_length(c.inner)
421+
end
422+
370423
"""
371424
from_vec_transform(f, size::Tuple)
372425
@@ -405,7 +458,9 @@ function from_linked_vec_transform(dist::UnivariateDistribution)
405458
sz = Bijectors.output_size(f_combined, size(dist))
406459
return UnwrapSingletonTransform(sz) f_combined
407460
end
408-
461+
function from_linked_vec_transform(dist::Distributions.ProductNamedTupleDistribution)
462+
return invlink_transform(dist)
463+
end
409464
# Specializations that circumvent the `from_vec_transform` machinery.
410465
function from_linked_vec_transform(dist::LKJCholesky)
411466
return inverse(Bijectors.VecCholeskyBijector(dist.uplo))

test/utils.jl

Lines changed: 134 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,21 @@
1+
module DynamicPPLUtilsTests
2+
3+
using Bijectors: Bijectors
4+
using Distributions
5+
using DynamicPPL
6+
using LinearAlgebra: LinearAlgebra
7+
using Test
8+
9+
isapprox_nested(a::Number, b::Number; kwargs...) = isapprox(a, b; kwargs...)
10+
isapprox_nested(a::AbstractArray, b::AbstractArray; kwargs...) = isapprox(a, b; kwargs...)
11+
function isapprox_nested(a::LinearAlgebra.Cholesky, b::LinearAlgebra.Cholesky; kwargs...)
12+
return isapprox(a.U, b.U; kwargs...) && isapprox(a.L, b.L; kwargs...)
13+
end
14+
function isapprox_nested(a::NamedTuple, b::NamedTuple; kwargs...)
15+
keys(a) == keys(b) || return false
16+
return all(k -> isapprox_nested(a[k], b[k]; kwargs...), keys(a))
17+
end
18+
119
@testset "utils.jl" begin
220
@testset "addlogprob!" begin
321
@model function testmodel()
@@ -31,35 +49,130 @@
3149
end
3250
end
3351

52+
@testset "transformations" begin
53+
function test_transformation(
54+
dist::Distribution; test_bijector_type_stability::Bool=true
55+
)
56+
unlinked = rand(dist)
57+
unlinked_vec = DynamicPPL.tovec(unlinked)
58+
@test unlinked_vec isa AbstractVector
59+
60+
from_vec_trfm = DynamicPPL.from_vec_transform(dist)
61+
unlinked_again, logjac = Bijectors.with_logabsdet_jacobian(
62+
from_vec_trfm, unlinked_vec
63+
)
64+
@test isapprox_nested(unlinked, unlinked_again)
65+
@test iszero(logjac)
66+
# Type stability
67+
@inferred DynamicPPL.from_vec_transform(dist)
68+
@inferred Bijectors.with_logabsdet_jacobian(from_vec_trfm, unlinked_vec)
69+
70+
# Typically the same as `bijector(dist)`, but technically a different
71+
# function
72+
b = DynamicPPL.link_transform(dist)
73+
@test (b(unlinked); true)
74+
linked, logjac = Bijectors.with_logabsdet_jacobian(b, unlinked)
75+
@test logjac isa Real
76+
77+
binv = DynamicPPL.invlink_transform(dist)
78+
unlinked_again, logjac_inv = Bijectors.with_logabsdet_jacobian(binv, linked)
79+
@test isapprox_nested(unlinked, unlinked_again)
80+
@test isapprox(logjac, -logjac_inv)
81+
82+
linked_vec = DynamicPPL.tovec(linked)
83+
@test linked_vec isa AbstractVector
84+
from_linked_vec_trfm = DynamicPPL.from_linked_vec_transform(dist)
85+
unlinked_again_again = from_linked_vec_trfm(linked_vec)
86+
@test isapprox_nested(unlinked, unlinked_again_again)
87+
88+
# Sometimes the bijector itself is not type stable. In this case there is not
89+
# much we can do in DynamicPPL except skip these tests (it has to be fixed
90+
# upstream in Bijectors.)
91+
if test_bijector_type_stability
92+
@inferred DynamicPPL.from_linked_vec_transform(dist)
93+
@inferred Bijectors.with_logabsdet_jacobian(
94+
from_linked_vec_trfm, linked_vec
95+
)
96+
end
97+
98+
# Create a model and check that we can evaluate it with both unlinked and linked
99+
# VarInfo. This relies on the transformations working correctly so is more of an
100+
# 'end to end' test
101+
@model test() = x ~ dist
102+
model = test()
103+
vi_unlinked = VarInfo(model)
104+
vi_linked = DynamicPPL.link!!(VarInfo(model), model)
105+
@test (DynamicPPL.evaluate!!(model, vi_unlinked); true)
106+
@test (DynamicPPL.evaluate!!(model, vi_linked); true)
107+
model_init = DynamicPPL.setleafcontext(model, DynamicPPL.InitContext())
108+
@test (DynamicPPL.evaluate!!(model_init, vi_unlinked); true)
109+
@test (DynamicPPL.evaluate!!(model_init, vi_linked); true)
110+
end
111+
112+
# Unconstrained univariate
113+
test_transformation(Normal())
114+
# Constrained univariate
115+
test_transformation(LogNormal())
116+
test_transformation(truncated(Normal(); lower=0))
117+
test_transformation(Exponential(1.0))
118+
test_transformation(Uniform(-2, 2))
119+
test_transformation(Beta(2, 2))
120+
test_transformation(InverseGamma(2, 3))
121+
# Discrete univariate
122+
test_transformation(Poisson(3))
123+
test_transformation(Binomial(10, 0.5))
124+
# Multivariate
125+
test_transformation(MvNormal(zeros(3), LinearAlgebra.I))
126+
test_transformation(
127+
product_distribution([Normal(), LogNormal()]);
128+
test_bijector_type_stability=false,
129+
)
130+
test_transformation(product_distribution([LogNormal(), LogNormal()]))
131+
# Matrixvariate
132+
test_transformation(LKJ(3, 0.5))
133+
test_transformation(Wishart(7, [1.0 0.0; 0.0 1.0]))
134+
# This is a pathological example: the linked representation is a matrix
135+
test_transformation(product_distribution(fill(Dirichlet(ones(4)), 2, 3)))
136+
# Cholesky
137+
test_transformation(LKJCholesky(3, 0.5))
138+
# ProductNamedTupleDistribution
139+
d = product_distribution((a=Normal(), b=LogNormal()))
140+
test_transformation(d)
141+
d_nested = product_distribution((x=LKJCholesky(2, 0.5), y=d))
142+
test_transformation(d_nested)
143+
end
144+
34145
@testset "getargs_dottilde" begin
35146
# Some things that are not expressions.
36-
@test getargs_dottilde(:x) === nothing
37-
@test getargs_dottilde(1.0) === nothing
38-
@test getargs_dottilde([1.0, 2.0, 4.0]) === nothing
147+
@test DynamicPPL.getargs_dottilde(:x) === nothing
148+
@test DynamicPPL.getargs_dottilde(1.0) === nothing
149+
@test DynamicPPL.getargs_dottilde([1.0, 2.0, 4.0]) === nothing
39150

40151
# Some expressions.
41-
@test getargs_dottilde(:(x ~ Normal(μ, σ))) === nothing
42-
@test getargs_dottilde(:((.~)(x, Normal(μ, σ)))) == (:x, :(Normal(μ, σ)))
43-
@test getargs_dottilde(:((~).(x, Normal(μ, σ)))) == (:x, :(Normal(μ, σ)))
44-
@test getargs_dottilde(:(x .~ Normal(μ, σ))) == (:x, :(Normal(μ, σ)))
45-
@test getargs_dottilde(:(@. x ~ Normal(μ, σ))) === nothing
46-
@test getargs_dottilde(:(@. x ~ Normal(μ, $(Expr(:$, :(sqrt(v))))))) === nothing
47-
@test getargs_dottilde(:(@~ Normal.(μ, σ))) === nothing
152+
@test DynamicPPL.getargs_dottilde(:(x ~ Normal(μ, σ))) === nothing
153+
@test DynamicPPL.getargs_dottilde(:((.~)(x, Normal(μ, σ)))) == (:x, :(Normal(μ, σ)))
154+
@test DynamicPPL.getargs_dottilde(:((~).(x, Normal(μ, σ)))) == (:x, :(Normal(μ, σ)))
155+
@test DynamicPPL.getargs_dottilde(:(x .~ Normal(μ, σ))) == (:x, :(Normal(μ, σ)))
156+
@test DynamicPPL.getargs_dottilde(:(@. x ~ Normal(μ, σ))) === nothing
157+
@test DynamicPPL.getargs_dottilde(:(@. x ~ Normal(μ, $(Expr(:$, :(sqrt(v))))))) ===
158+
nothing
159+
@test DynamicPPL.getargs_dottilde(:(@~ Normal.(μ, σ))) === nothing
48160
end
49161

50162
@testset "getargs_tilde" begin
51163
# Some things that are not expressions.
52-
@test getargs_tilde(:x) === nothing
53-
@test getargs_tilde(1.0) === nothing
54-
@test getargs_tilde([1.0, 2.0, 4.0]) === nothing
164+
@test DynamicPPL.getargs_tilde(:x) === nothing
165+
@test DynamicPPL.getargs_tilde(1.0) === nothing
166+
@test DynamicPPL.getargs_tilde([1.0, 2.0, 4.0]) === nothing
55167

56168
# Some expressions.
57-
@test getargs_tilde(:(x ~ Normal(μ, σ))) == (:x, :(Normal(μ, σ)))
58-
@test getargs_tilde(:((.~)(x, Normal(μ, σ)))) === nothing
59-
@test getargs_tilde(:((~).(x, Normal(μ, σ)))) === nothing
60-
@test getargs_tilde(:(@. x ~ Normal(μ, σ))) === nothing
61-
@test getargs_tilde(:(@. x ~ Normal(μ, $(Expr(:$, :(sqrt(v))))))) === nothing
62-
@test getargs_tilde(:(@~ Normal.(μ, σ))) === nothing
169+
@test DynamicPPL.getargs_tilde(:(x ~ Normal(μ, σ))) == (:x, :(Normal(μ, σ)))
170+
@test DynamicPPL.getargs_tilde(:((.~)(x, Normal(μ, σ)))) === nothing
171+
@test DynamicPPL.getargs_tilde(:((~).(x, Normal(μ, σ)))) === nothing
172+
@test DynamicPPL.getargs_tilde(:(@. x ~ Normal(μ, σ))) === nothing
173+
@test DynamicPPL.getargs_tilde(:(@. x ~ Normal(μ, $(Expr(:$, :(sqrt(v))))))) ===
174+
nothing
175+
@test DynamicPPL.getargs_tilde(:(@~ Normal.(μ, σ))) === nothing
63176
end
64177

65178
@testset "tovec" begin
@@ -97,3 +210,5 @@
97210
@test DynamicPPL.group_varnames_by_symbol(vns_tuple) == vns_nt
98211
end
99212
end
213+
214+
end

0 commit comments

Comments
 (0)