|
| 1 | +module DynamicPPLUtilsTests |
| 2 | + |
| 3 | +using Bijectors: Bijectors |
| 4 | +using Distributions |
| 5 | +using DynamicPPL |
| 6 | +using LinearAlgebra: LinearAlgebra |
| 7 | +using Test |
| 8 | + |
| 9 | +isapprox_nested(a::Number, b::Number; kwargs...) = isapprox(a, b; kwargs...) |
| 10 | +isapprox_nested(a::AbstractArray, b::AbstractArray; kwargs...) = isapprox(a, b; kwargs...) |
| 11 | +function isapprox_nested(a::LinearAlgebra.Cholesky, b::LinearAlgebra.Cholesky; kwargs...) |
| 12 | + return isapprox(a.U, b.U; kwargs...) && isapprox(a.L, b.L; kwargs...) |
| 13 | +end |
| 14 | +function isapprox_nested(a::NamedTuple, b::NamedTuple; kwargs...) |
| 15 | + keys(a) == keys(b) || return false |
| 16 | + return all(k -> isapprox_nested(a[k], b[k]; kwargs...), keys(a)) |
| 17 | +end |
| 18 | + |
1 | 19 | @testset "utils.jl" begin |
2 | 20 | @testset "addlogprob!" begin |
3 | 21 | @model function testmodel() |
|
31 | 49 | end |
32 | 50 | end |
33 | 51 |
|
| 52 | + @testset "transformations" begin |
| 53 | + function test_transformation( |
| 54 | + dist::Distribution; test_bijector_type_stability::Bool=true |
| 55 | + ) |
| 56 | + unlinked = rand(dist) |
| 57 | + unlinked_vec = DynamicPPL.tovec(unlinked) |
| 58 | + @test unlinked_vec isa AbstractVector |
| 59 | + |
| 60 | + from_vec_trfm = DynamicPPL.from_vec_transform(dist) |
| 61 | + unlinked_again, logjac = Bijectors.with_logabsdet_jacobian( |
| 62 | + from_vec_trfm, unlinked_vec |
| 63 | + ) |
| 64 | + @test isapprox_nested(unlinked, unlinked_again) |
| 65 | + @test iszero(logjac) |
| 66 | + # Type stability |
| 67 | + @inferred DynamicPPL.from_vec_transform(dist) |
| 68 | + @inferred Bijectors.with_logabsdet_jacobian(from_vec_trfm, unlinked_vec) |
| 69 | + |
| 70 | + # Typically the same as `bijector(dist)`, but technically a different |
| 71 | + # function |
| 72 | + b = DynamicPPL.link_transform(dist) |
| 73 | + @test (b(unlinked); true) |
| 74 | + linked, logjac = Bijectors.with_logabsdet_jacobian(b, unlinked) |
| 75 | + @test logjac isa Real |
| 76 | + |
| 77 | + binv = DynamicPPL.invlink_transform(dist) |
| 78 | + unlinked_again, logjac_inv = Bijectors.with_logabsdet_jacobian(binv, linked) |
| 79 | + @test isapprox_nested(unlinked, unlinked_again) |
| 80 | + @test isapprox(logjac, -logjac_inv) |
| 81 | + |
| 82 | + linked_vec = DynamicPPL.tovec(linked) |
| 83 | + @test linked_vec isa AbstractVector |
| 84 | + from_linked_vec_trfm = DynamicPPL.from_linked_vec_transform(dist) |
| 85 | + unlinked_again_again = from_linked_vec_trfm(linked_vec) |
| 86 | + @test isapprox_nested(unlinked, unlinked_again_again) |
| 87 | + |
| 88 | + # Sometimes the bijector itself is not type stable. In this case there is not |
| 89 | + # much we can do in DynamicPPL except skip these tests (it has to be fixed |
| 90 | + # upstream in Bijectors.) |
| 91 | + if test_bijector_type_stability |
| 92 | + @inferred DynamicPPL.from_linked_vec_transform(dist) |
| 93 | + @inferred Bijectors.with_logabsdet_jacobian( |
| 94 | + from_linked_vec_trfm, linked_vec |
| 95 | + ) |
| 96 | + end |
| 97 | + |
| 98 | + # Create a model and check that we can evaluate it with both unlinked and linked |
| 99 | + # VarInfo. This relies on the transformations working correctly so is more of an |
| 100 | + # 'end to end' test |
| 101 | + @model test() = x ~ dist |
| 102 | + model = test() |
| 103 | + vi_unlinked = VarInfo(model) |
| 104 | + vi_linked = DynamicPPL.link!!(VarInfo(model), model) |
| 105 | + @test (DynamicPPL.evaluate!!(model, vi_unlinked); true) |
| 106 | + @test (DynamicPPL.evaluate!!(model, vi_linked); true) |
| 107 | + model_init = DynamicPPL.setleafcontext(model, DynamicPPL.InitContext()) |
| 108 | + @test (DynamicPPL.evaluate!!(model_init, vi_unlinked); true) |
| 109 | + @test (DynamicPPL.evaluate!!(model_init, vi_linked); true) |
| 110 | + end |
| 111 | + |
| 112 | + # Unconstrained univariate |
| 113 | + test_transformation(Normal()) |
| 114 | + # Constrained univariate |
| 115 | + test_transformation(LogNormal()) |
| 116 | + test_transformation(truncated(Normal(); lower=0)) |
| 117 | + test_transformation(Exponential(1.0)) |
| 118 | + test_transformation(Uniform(-2, 2)) |
| 119 | + test_transformation(Beta(2, 2)) |
| 120 | + test_transformation(InverseGamma(2, 3)) |
| 121 | + # Discrete univariate |
| 122 | + test_transformation(Poisson(3)) |
| 123 | + test_transformation(Binomial(10, 0.5)) |
| 124 | + # Multivariate |
| 125 | + test_transformation(MvNormal(zeros(3), LinearAlgebra.I)) |
| 126 | + test_transformation( |
| 127 | + product_distribution([Normal(), LogNormal()]); |
| 128 | + test_bijector_type_stability=false, |
| 129 | + ) |
| 130 | + test_transformation(product_distribution([LogNormal(), LogNormal()])) |
| 131 | + # Matrixvariate |
| 132 | + test_transformation(LKJ(3, 0.5)) |
| 133 | + test_transformation(Wishart(7, [1.0 0.0; 0.0 1.0])) |
| 134 | + # This is a pathological example: the linked representation is a matrix |
| 135 | + test_transformation(product_distribution(fill(Dirichlet(ones(4)), 2, 3))) |
| 136 | + # Cholesky |
| 137 | + test_transformation(LKJCholesky(3, 0.5)) |
| 138 | + # ProductNamedTupleDistribution |
| 139 | + d = product_distribution((a=Normal(), b=LogNormal())) |
| 140 | + test_transformation(d) |
| 141 | + d_nested = product_distribution((x=LKJCholesky(2, 0.5), y=d)) |
| 142 | + test_transformation(d_nested) |
| 143 | + end |
| 144 | + |
34 | 145 | @testset "getargs_dottilde" begin |
35 | 146 | # Some things that are not expressions. |
36 | | - @test getargs_dottilde(:x) === nothing |
37 | | - @test getargs_dottilde(1.0) === nothing |
38 | | - @test getargs_dottilde([1.0, 2.0, 4.0]) === nothing |
| 147 | + @test DynamicPPL.getargs_dottilde(:x) === nothing |
| 148 | + @test DynamicPPL.getargs_dottilde(1.0) === nothing |
| 149 | + @test DynamicPPL.getargs_dottilde([1.0, 2.0, 4.0]) === nothing |
39 | 150 |
|
40 | 151 | # Some expressions. |
41 | | - @test getargs_dottilde(:(x ~ Normal(μ, σ))) === nothing |
42 | | - @test getargs_dottilde(:((.~)(x, Normal(μ, σ)))) == (:x, :(Normal(μ, σ))) |
43 | | - @test getargs_dottilde(:((~).(x, Normal(μ, σ)))) == (:x, :(Normal(μ, σ))) |
44 | | - @test getargs_dottilde(:(x .~ Normal(μ, σ))) == (:x, :(Normal(μ, σ))) |
45 | | - @test getargs_dottilde(:(@. x ~ Normal(μ, σ))) === nothing |
46 | | - @test getargs_dottilde(:(@. x ~ Normal(μ, $(Expr(:$, :(sqrt(v))))))) === nothing |
47 | | - @test getargs_dottilde(:(@~ Normal.(μ, σ))) === nothing |
| 152 | + @test DynamicPPL.getargs_dottilde(:(x ~ Normal(μ, σ))) === nothing |
| 153 | + @test DynamicPPL.getargs_dottilde(:((.~)(x, Normal(μ, σ)))) == (:x, :(Normal(μ, σ))) |
| 154 | + @test DynamicPPL.getargs_dottilde(:((~).(x, Normal(μ, σ)))) == (:x, :(Normal(μ, σ))) |
| 155 | + @test DynamicPPL.getargs_dottilde(:(x .~ Normal(μ, σ))) == (:x, :(Normal(μ, σ))) |
| 156 | + @test DynamicPPL.getargs_dottilde(:(@. x ~ Normal(μ, σ))) === nothing |
| 157 | + @test DynamicPPL.getargs_dottilde(:(@. x ~ Normal(μ, $(Expr(:$, :(sqrt(v))))))) === |
| 158 | + nothing |
| 159 | + @test DynamicPPL.getargs_dottilde(:(@~ Normal.(μ, σ))) === nothing |
48 | 160 | end |
49 | 161 |
|
50 | 162 | @testset "getargs_tilde" begin |
51 | 163 | # Some things that are not expressions. |
52 | | - @test getargs_tilde(:x) === nothing |
53 | | - @test getargs_tilde(1.0) === nothing |
54 | | - @test getargs_tilde([1.0, 2.0, 4.0]) === nothing |
| 164 | + @test DynamicPPL.getargs_tilde(:x) === nothing |
| 165 | + @test DynamicPPL.getargs_tilde(1.0) === nothing |
| 166 | + @test DynamicPPL.getargs_tilde([1.0, 2.0, 4.0]) === nothing |
55 | 167 |
|
56 | 168 | # Some expressions. |
57 | | - @test getargs_tilde(:(x ~ Normal(μ, σ))) == (:x, :(Normal(μ, σ))) |
58 | | - @test getargs_tilde(:((.~)(x, Normal(μ, σ)))) === nothing |
59 | | - @test getargs_tilde(:((~).(x, Normal(μ, σ)))) === nothing |
60 | | - @test getargs_tilde(:(@. x ~ Normal(μ, σ))) === nothing |
61 | | - @test getargs_tilde(:(@. x ~ Normal(μ, $(Expr(:$, :(sqrt(v))))))) === nothing |
62 | | - @test getargs_tilde(:(@~ Normal.(μ, σ))) === nothing |
| 169 | + @test DynamicPPL.getargs_tilde(:(x ~ Normal(μ, σ))) == (:x, :(Normal(μ, σ))) |
| 170 | + @test DynamicPPL.getargs_tilde(:((.~)(x, Normal(μ, σ)))) === nothing |
| 171 | + @test DynamicPPL.getargs_tilde(:((~).(x, Normal(μ, σ)))) === nothing |
| 172 | + @test DynamicPPL.getargs_tilde(:(@. x ~ Normal(μ, σ))) === nothing |
| 173 | + @test DynamicPPL.getargs_tilde(:(@. x ~ Normal(μ, $(Expr(:$, :(sqrt(v))))))) === |
| 174 | + nothing |
| 175 | + @test DynamicPPL.getargs_tilde(:(@~ Normal.(μ, σ))) === nothing |
63 | 176 | end |
64 | 177 |
|
65 | 178 | @testset "tovec" begin |
|
97 | 210 | @test DynamicPPL.group_varnames_by_symbol(vns_tuple) == vns_nt |
98 | 211 | end |
99 | 212 | end |
| 213 | + |
| 214 | +end |
0 commit comments