@@ -508,11 +508,33 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
508508 end
509509 end
510510 end
511+ function test_link_status_respected (strategy:: AbstractInitStrategy )
512+ @testset " check that varinfo linking is preserved: $(typeof (strategy)) " begin
513+ @model logn () = a ~ LogNormal ()
514+ model = logn ()
515+ vi = VarInfo (model)
516+ linked_vi = DynamicPPL. link!! (vi, model)
517+ _, new_vi = DynamicPPL. init!! (model, linked_vi, strategy)
518+ @test DynamicPPL. istrans (new_vi)
519+ # this is the unlinked value, since it uses `getindex`
520+ a = new_vi[@varname (a)]
521+ # internal logjoint should correspond to the transformed value
522+ @test isapprox (
523+ DynamicPPL. getlogjoint_internal (new_vi), logpdf (Normal (), log (a))
524+ )
525+ # user logjoint should correspond to the transformed value
526+ @test isapprox (DynamicPPL. getlogjoint (new_vi), logpdf (LogNormal (), a))
527+ @test isapprox (
528+ only (DynamicPPL. getindex_internal (new_vi, @varname (a))), log (a)
529+ )
530+ end
531+ end
511532
512533 @testset " PriorInit" begin
513534 test_generating_new_values (PriorInit ())
514535 test_replacing_values (PriorInit ())
515536 test_rng_respected (PriorInit ())
537+ test_link_status_respected (PriorInit ())
516538
517539 @testset " check that values are within support" begin
518540 # Not many other sensible checks we can do for priors.
@@ -529,6 +551,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
529551 test_generating_new_values (UniformInit ())
530552 test_replacing_values (UniformInit ())
531553 test_rng_respected (UniformInit ())
554+ test_link_status_respected (UniformInit ())
532555
533556 @testset " check that bounds are respected" begin
534557 @testset " unconstrained" begin
@@ -559,6 +582,9 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
559582 end
560583
561584 @testset " ParamsInit" begin
585+ test_link_status_respected (ParamsInit ((; a= 1.0 )))
586+ test_link_status_respected (ParamsInit (Dict (@varname (a) => 1.0 )))
587+
562588 @testset " given full set of parameters" begin
563589 # test_init_model has x ~ Normal() and y ~ MvNormal(zeros(2), I)
564590 my_x, my_y = 1.0 , [2.0 , 3.0 ]
0 commit comments