From e93458c98f6b6da061e635b0360b0c14f733bf49 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Tue, 8 Jul 2025 08:25:35 +0100 Subject: [PATCH 1/9] use `NoCache` to improve `set_to_zero!!` performance with Mooncake --- HISTORY.md | 4 + ext/DynamicPPLMooncakeExt.jl | 101 +++++++++++++++ test/ext/DynamicPPLMooncakeExt.jl | 196 +++++++++++++++++++++++++++++- 3 files changed, 298 insertions(+), 3 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 0d2a56606..9aa9d0303 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,9 @@ # DynamicPPL Changelog +## Unreleased + +Improved performance for some models with Mooncake.jl by using `NoCache` with `Mooncake.set_to_zero!!` for DynamicPPL types. + ## 0.36.14 Added compatibility with AbstractPPL@0.12. diff --git a/ext/DynamicPPLMooncakeExt.jl b/ext/DynamicPPLMooncakeExt.jl index b86d807bc..e4f85c3eb 100644 --- a/ext/DynamicPPLMooncakeExt.jl +++ b/ext/DynamicPPLMooncakeExt.jl @@ -1,9 +1,110 @@ module DynamicPPLMooncakeExt +__precompile__(false) + using DynamicPPL: DynamicPPL, istrans using Mooncake: Mooncake +import Mooncake: set_to_zero!! +using Mooncake: NoTangent, Tangent, MutableTangent, NoCache, set_to_zero_internal!! # This is purely an optimisation. Mooncake.@zero_adjoint Mooncake.DefaultCtx Tuple{typeof(istrans),Vararg} +""" +Check if a tangent has the expected structure for a given type. +""" +function has_expected_structure( + x, expected_type::Type{<:Union{Tangent,MutableTangent}}, expected_fields +) + x isa expected_type || return false + hasfield(typeof(x), :fields) || return false + + fields = x.fields + if expected_fields isa Tuple + # Exact match required + propertynames(fields) == expected_fields || return false + else + # All expected fields must be present + all(f in propertynames(fields) for f in expected_fields) || return false + end + + return true +end + +""" +Check if a tangent corresponds to a DynamicPPL.LogDensityFunction +""" +function is_dppl_ldf_tangent(x) + has_expected_structure(x, Tangent, (:model, :varinfo, :context, :adtype, :prep)) || + return false + + fields = x.fields + is_dppl_varinfo_tangent(fields.varinfo) || return false + is_dppl_model_tangent(fields.model) || return false + + return true +end + +""" +Check if a tangent corresponds to a DynamicPPL.VarInfo +""" +function is_dppl_varinfo_tangent(x) + return has_expected_structure(x, Tangent, (:metadata, :logp, :num_produce)) +end + +""" +Check if a tangent corresponds to a DynamicPPL.Model +""" +function is_dppl_model_tangent(x) + return has_expected_structure(x, Tangent, (:f, :args, :defaults, :context)) +end + +""" +Check if a MutableTangent corresponds to DynamicPPL.Metadata +""" +function is_dppl_metadata_tangent(x) + return has_expected_structure( + x, MutableTangent, (:idcs, :vns, :ranges, :vals, :dists, :orders, :flags) + ) +end + +""" +Check if a model function tangent represents a closure. +""" +function is_closure_model(model_f_tangent) + model_f_tangent isa MutableTangent && return true + + if model_f_tangent isa Tangent && hasfield(typeof(model_f_tangent), :fields) + # Check if any field is a MutableTangent with PossiblyUninitTangent{Any} + for (_, fval) in pairs(model_f_tangent.fields) + if fval isa MutableTangent && + hasfield(typeof(fval), :fields) && + hasfield(typeof(fval.fields), :contents) && + fval.fields.contents isa Mooncake.PossiblyUninitTangent{Any} + return true + end + end + end + + return false +end + +function Mooncake.set_to_zero!!(x) + # Check for DynamicPPL types and use NoCache for better performance + if is_dppl_ldf_tangent(x) + # Special handling for LogDensityFunction to detect closures + model_f_tangent = x.fields.model.fields.f + cache = is_closure_model(model_f_tangent) ? IdDict{Any,Bool}() : NoCache() + return set_to_zero_internal!!(cache, x) + elseif is_dppl_varinfo_tangent(x) || + is_dppl_model_tangent(x) || + is_dppl_metadata_tangent(x) + # These types can always use NoCache + return set_to_zero_internal!!(NoCache(), x) + else + # Use the original implementation with IdDict for all other types + return set_to_zero_internal!!(IdDict{Any,Bool}(), x) + end +end + end # module diff --git a/test/ext/DynamicPPLMooncakeExt.jl b/test/ext/DynamicPPLMooncakeExt.jl index 986057da0..4c8fe2c7e 100644 --- a/test/ext/DynamicPPLMooncakeExt.jl +++ b/test/ext/DynamicPPLMooncakeExt.jl @@ -1,5 +1,195 @@ +using DynamicPPL +using Distributions +using Random +using Test +using StableRNGs +using Mooncake: NoCache, set_to_zero!!, set_to_zero_internal!!, zero_tangent +using DynamicPPL.TestUtils.AD: @be +using Statistics: median + +# Define models globally to avoid closure issues +@model function test_model1(x) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + return x .~ Normal(m, sqrt(s)) +end + +@model function test_model2(x, y) + τ ~ Gamma(1, 1) + σ ~ InverseGamma(2, 3) + μ ~ Normal(0, τ) + x .~ Normal(μ, σ) + return y .~ Normal(μ, σ) +end + @testset "DynamicPPLMooncakeExt" begin - Mooncake.TestUtils.test_rule( - StableRNG(123456), istrans, VarInfo(); unsafe_perturb=true, interface_only=true - ) + @testset "istrans rule" begin + Mooncake.TestUtils.test_rule( + StableRNG(123456), istrans, VarInfo(); unsafe_perturb=true, interface_only=true + ) + end + + @testset "set_to_zero!! optimization" begin + # Test with a real DynamicPPL model + model = test_model1([1.0, 2.0, 3.0]) + vi = VarInfo(Random.default_rng(), model) + ldf = LogDensityFunction(model, vi, DefaultContext()) + tangent = zero_tangent(ldf) + + # Test that set_to_zero!! works correctly + result = set_to_zero!!(deepcopy(tangent)) + @test result isa typeof(tangent) + + # Test with metadata - verify structure exists + if hasfield(typeof(tangent.fields.varinfo.fields), :metadata) + metadata = tangent.fields.varinfo.fields.metadata + @test !isnothing(metadata) + end + end + + @testset "NoCache optimization correctness" begin + # Test that set_to_zero!! uses NoCache for DynamicPPL types + model = test_model1([1.0, 2.0, 3.0]) + vi = VarInfo(Random.default_rng(), model) + ldf = LogDensityFunction(model, vi, DefaultContext()) + tangent = zero_tangent(ldf) + + # Modify some values + if hasfield(typeof(tangent.fields.model.fields), :args) && + hasfield(typeof(tangent.fields.model.fields.args), :x) + x_tangent = tangent.fields.model.fields.args.x + if !isempty(x_tangent) + x_tangent[1] = 5.0 + end + end + + # Call set_to_zero!! and verify it works + set_to_zero!!(tangent) + + # Check that values are zeroed + if hasfield(typeof(tangent.fields.model.fields), :args) && + hasfield(typeof(tangent.fields.model.fields.args), :x) + x_tangent = tangent.fields.model.fields.args.x + if !isempty(x_tangent) + @test x_tangent[1] == 0.0 + end + end + end + + @testset "Performance improvement" begin + # Test with DEMO_MODELS if available + if isdefined(DynamicPPL.TestUtils, :DEMO_MODELS) && + !isempty(DynamicPPL.TestUtils.DEMO_MODELS) + model = DynamicPPL.TestUtils.DEMO_MODELS[1] + else + # Fallback to our test model + model = test_model1([1.0, 2.0, 3.0, 4.0]) + end + + vi = VarInfo(Random.default_rng(), model) + ldf = LogDensityFunction(model, vi, DefaultContext()) + tangent = zero_tangent(ldf) + + # Run benchmarks + result_iddict = @be begin + cache = IdDict{Any,Bool}() + set_to_zero_internal!!(cache, tangent) + end + + result_nocache = @be set_to_zero!!(tangent) + + # Extract median times + time_iddict = median(result_iddict).time + time_nocache = median(result_nocache).time + + # We expect NoCache to be faster + speedup = time_iddict / time_nocache + @test speedup > 1.5 # Conservative expectation - should be ~4x + + @info "Performance improvement" speedup time_iddict_μs = time_iddict / 1000 time_nocache_μs = + time_nocache / 1000 + end + + @testset "Aliasing safety" begin + # Test with aliased data + shared_data = [1.0, 2.0, 3.0] + model = test_model2(shared_data, shared_data) # x and y are the same array + vi = VarInfo(Random.default_rng(), model) + ldf = LogDensityFunction(model, vi, DefaultContext()) + tangent = zero_tangent(ldf) + + # Check that aliasing is preserved in tangent + if hasfield(typeof(tangent.fields.model.fields), :args) + args = tangent.fields.model.fields.args + if hasfield(typeof(args), :x) && hasfield(typeof(args), :y) + @test args.x === args.y # Aliasing should be preserved + + # Modify via x + if !isempty(args.x) + args.x[1] = 10.0 + @test args.y[1] == 10.0 # Should also change y + end + + # Zero and check both are zeroed + # Since x and y are aliased, zeroing one zeros both + set_to_zero!!(tangent) + if !isempty(args.x) + @test args.x[1] == 0.0 + @test args.y[1] == 0.0 + end + end + end + end + + @testset "Closure handling" begin + # Test that closure models are correctly handled + + # Create closure model (captures environment, has circular references) + function create_closure_model() + local_var = 42 + @model function closure_model(x) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + return x .~ Normal(m, sqrt(s)) + end + return closure_model + end + + closure_fn = create_closure_model() + model_closure = closure_fn([1.0, 2.0, 3.0]) + vi_closure = VarInfo(Random.default_rng(), model_closure) + ldf_closure = LogDensityFunction(model_closure, vi_closure, DefaultContext()) + tangent_closure = zero_tangent(ldf_closure) + + # Test that it works without stack overflow + @test_nowarn set_to_zero!!(deepcopy(tangent_closure)) + + # Compare with global model (no closure) + model_global = test_model1([1.0, 2.0, 3.0]) + vi_global = VarInfo(Random.default_rng(), model_global) + ldf_global = LogDensityFunction(model_global, vi_global, DefaultContext()) + tangent_global = zero_tangent(ldf_global) + + # Verify model.f tangent types differ + f_tangent_closure = tangent_closure.fields.model.fields.f + f_tangent_global = tangent_global.fields.model.fields.f + + @test f_tangent_global isa Mooncake.NoTangent # Global function + @test f_tangent_closure isa Mooncake.Tangent # Closure function + + # Performance comparison + time_global = @elapsed for _ in 1:100 + set_to_zero!!(tangent_global) + end + + time_closure = @elapsed for _ in 1:100 + set_to_zero!!(tangent_closure) + end + + # Global should be faster (uses NoCache) + @test time_global < time_closure + + @info "Closure handling" time_global_ms = time_global * 1000 time_closure_ms = + time_closure * 1000 + end end From 18f4c731685edf668a1b300b80ccef08cae4cbf3 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Tue, 8 Jul 2025 08:26:29 +0100 Subject: [PATCH 2/9] use concrete version number for history note --- HISTORY.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/HISTORY.md b/HISTORY.md index 9aa9d0303..aea3f4262 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,6 +1,6 @@ # DynamicPPL Changelog -## Unreleased +## 0.36.15 Improved performance for some models with Mooncake.jl by using `NoCache` with `Mooncake.set_to_zero!!` for DynamicPPL types. From 5c79686da93fa1ed342b726df5f63bc2596fe774 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Wed, 9 Jul 2025 07:28:24 +0100 Subject: [PATCH 3/9] fix test errors --- test/Project.toml | 2 ++ test/ext/DynamicPPLMooncakeExt.jl | 9 +++------ 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index afecba1c4..0ed45a54e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -18,6 +18,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -45,6 +46,7 @@ JET = "0.9, 0.10" LogDensityProblems = "2" MCMCChains = "6.0.4, 7" MacroTools = "0.5.6" +Mooncake = "0.4.137" OrderedCollections = "1" ReverseDiff = "1" StableRNGs = "1" diff --git a/test/ext/DynamicPPLMooncakeExt.jl b/test/ext/DynamicPPLMooncakeExt.jl index 4c8fe2c7e..65ae864b3 100644 --- a/test/ext/DynamicPPLMooncakeExt.jl +++ b/test/ext/DynamicPPLMooncakeExt.jl @@ -3,9 +3,8 @@ using Distributions using Random using Test using StableRNGs -using Mooncake: NoCache, set_to_zero!!, set_to_zero_internal!!, zero_tangent -using DynamicPPL.TestUtils.AD: @be -using Statistics: median +using Mooncake: Mooncake, NoCache, set_to_zero!!, set_to_zero_internal!!, zero_tangent +using DynamicPPL.TestUtils.AD: @be, median # Define models globally to avoid closure issues @model function test_model1(x) @@ -106,6 +105,7 @@ end speedup = time_iddict / time_nocache @test speedup > 1.5 # Conservative expectation - should be ~4x + # Sanity check @info "Performance improvement" speedup time_iddict_μs = time_iddict / 1000 time_nocache_μs = time_nocache / 1000 end @@ -188,8 +188,5 @@ end # Global should be faster (uses NoCache) @test time_global < time_closure - - @info "Closure handling" time_global_ms = time_global * 1000 time_closure_ms = - time_closure * 1000 end end From 66f453c47fefe1cb5398b58ceb65625ad782620a Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Mon, 21 Jul 2025 10:08:16 +0100 Subject: [PATCH 4/9] resolve CI error --- ext/DynamicPPLMooncakeExt.jl | 42 ++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/ext/DynamicPPLMooncakeExt.jl b/ext/DynamicPPLMooncakeExt.jl index e4f85c3eb..8c166b8c7 100644 --- a/ext/DynamicPPLMooncakeExt.jl +++ b/ext/DynamicPPLMooncakeExt.jl @@ -89,13 +89,55 @@ function is_closure_model(model_f_tangent) return false end +""" +Check if a VarInfo tangent needs caching due to circular references (e.g., Ref fields). +""" +function needs_caching_for_varinfo(x) + # Check if it's a VarInfo tangent + is_dppl_varinfo_tangent(x) || return false + + # Check if the logp field contains a Ref-like tangent structure + hasfield(typeof(x.fields), :logp) || return false + logp_tangent = x.fields.logp + + # Ref types in tangents often appear as MutableTangent with circular references + return logp_tangent isa MutableTangent +end + +""" +Check if a tangent contains PossiblyUninitTangent{Any} which can cause infinite recursion. +""" +function contains_possibly_uninit_any(x) + x isa Mooncake.PossiblyUninitTangent{Any} && return true + + if x isa Tangent && hasfield(typeof(x), :fields) + for (_, fval) in pairs(x.fields) + contains_possibly_uninit_any(fval) && return true + end + elseif x isa MutableTangent && hasfield(typeof(x), :fields) + hasfield(typeof(x.fields), :contents) && + x.fields.contents isa Mooncake.PossiblyUninitTangent{Any} && + return true + end + + return false +end + function Mooncake.set_to_zero!!(x) + # Always use caching if we detect PossiblyUninitTangent{Any} anywhere + if contains_possibly_uninit_any(x) + return set_to_zero_internal!!(IdDict{Any,Bool}(), x) + end + # Check for DynamicPPL types and use NoCache for better performance if is_dppl_ldf_tangent(x) # Special handling for LogDensityFunction to detect closures model_f_tangent = x.fields.model.fields.f cache = is_closure_model(model_f_tangent) ? IdDict{Any,Bool}() : NoCache() return set_to_zero_internal!!(cache, x) + elseif is_dppl_varinfo_tangent(x) && needs_caching_for_varinfo(x) + # Use IdDict for SimpleVarInfo with Ref fields to avoid circular references + return set_to_zero_internal!!(IdDict{Any,Bool}(), x) elseif is_dppl_varinfo_tangent(x) || is_dppl_model_tangent(x) || is_dppl_metadata_tangent(x) From aabc8441f2ef52d82f6b2107155b7960c8e96b7f Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Mon, 21 Jul 2025 10:28:05 +0100 Subject: [PATCH 5/9] refactor --- ext/DynamicPPLMooncakeExt.jl | 129 +++++++++++++++++++++++------------ 1 file changed, 87 insertions(+), 42 deletions(-) diff --git a/ext/DynamicPPLMooncakeExt.jl b/ext/DynamicPPLMooncakeExt.jl index 8c166b8c7..76a06099d 100644 --- a/ext/DynamicPPLMooncakeExt.jl +++ b/ext/DynamicPPLMooncakeExt.jl @@ -10,6 +10,36 @@ using Mooncake: NoTangent, Tangent, MutableTangent, NoCache, set_to_zero_interna # This is purely an optimisation. Mooncake.@zero_adjoint Mooncake.DefaultCtx Tuple{typeof(istrans),Vararg} +# ======================= +# Cache Strategy System +# ======================= + +""" + determine_cache_strategy(x) + +Determines the appropriate caching strategy for a given tangent. +Returns either `NoCache()` for safe types or `IdDict{Any,Bool}()` for types with circular reference risk. +""" +function determine_cache_strategy(x) + # Fast path: check for known circular reference patterns + has_circular_reference_risk(x) && return IdDict{Any,Bool}() + + # Check for DynamicPPL types that can safely use NoCache + is_safe_dppl_type(x) && return NoCache() + + # Special case: LogDensityFunction without problematic patterns can use NoCache + if is_dppl_ldf_tangent(x) + return NoCache() + end + + # Default to safe caching for unknown types + return IdDict{Any,Bool}() +end + +# ======================= +# Type Recognition +# ======================= + """ Check if a tangent has the expected structure for a given type. """ @@ -68,15 +98,46 @@ function is_dppl_metadata_tangent(x) ) end +# ======================= +# Circular Reference Detection +# ======================= + """ -Check if a model function tangent represents a closure. + has_circular_reference_risk(x) + +Main entry point for detecting circular reference patterns that require caching. +Optimized for performance with targeted checks instead of recursive traversal. """ -function is_closure_model(model_f_tangent) - model_f_tangent isa MutableTangent && return true +function has_circular_reference_risk(x) + # Type-specific targeted checks only + if is_dppl_ldf_tangent(x) + # Check model function for closure patterns with circular refs + model_f = x.fields.model.fields.f + return is_closure_with_circular_refs(model_f) + elseif is_dppl_varinfo_tangent(x) + # Check for Ref fields in VarInfo + return check_for_ref_fields(x) + end - if model_f_tangent isa Tangent && hasfield(typeof(model_f_tangent), :fields) - # Check if any field is a MutableTangent with PossiblyUninitTangent{Any} - for (_, fval) in pairs(model_f_tangent.fields) + # For unknown types, do a shallow check for PossiblyUninitTangent{Any} + return x isa Mooncake.PossiblyUninitTangent{Any} +end + +""" +Check if a tangent represents a closure with circular reference patterns. +Only returns true for actual problematic patterns, not all MutableTangents. +""" +function is_closure_with_circular_refs(x) + # Check if MutableTangent contains PossiblyUninitTangent{Any} + if x isa MutableTangent && hasfield(typeof(x), :fields) + hasfield(typeof(x.fields), :contents) && + x.fields.contents isa Mooncake.PossiblyUninitTangent{Any} && + return true + end + + # For Tangent, only check immediate fields (no deep recursion) + if x isa Tangent && hasfield(typeof(x), :fields) + for (_, fval) in pairs(x.fields) if fval isa MutableTangent && hasfield(typeof(fval), :fields) && hasfield(typeof(fval.fields), :contents) && @@ -90,9 +151,9 @@ function is_closure_model(model_f_tangent) end """ -Check if a VarInfo tangent needs caching due to circular references (e.g., Ref fields). +Check if a VarInfo tangent has Ref fields that need caching. """ -function needs_caching_for_varinfo(x) +function check_for_ref_fields(x) # Check if it's a VarInfo tangent is_dppl_varinfo_tangent(x) || return false @@ -105,48 +166,32 @@ function needs_caching_for_varinfo(x) end """ -Check if a tangent contains PossiblyUninitTangent{Any} which can cause infinite recursion. +Check if a tangent is a safe DynamicPPL type that can use NoCache. """ -function contains_possibly_uninit_any(x) - x isa Mooncake.PossiblyUninitTangent{Any} && return true +function is_safe_dppl_type(x) + # Metadata is always safe + is_dppl_metadata_tangent(x) && return true - if x isa Tangent && hasfield(typeof(x), :fields) - for (_, fval) in pairs(x.fields) - contains_possibly_uninit_any(fval) && return true - end - elseif x isa MutableTangent && hasfield(typeof(x), :fields) - hasfield(typeof(x.fields), :contents) && - x.fields.contents isa Mooncake.PossiblyUninitTangent{Any} && - return true + # Model tangents without closures are safe + if is_dppl_model_tangent(x) + !is_closure_with_circular_refs(x.fields.f) && return true + end + + # VarInfo without Ref fields is safe + if is_dppl_varinfo_tangent(x) + !check_for_ref_fields(x) && return true end return false end -function Mooncake.set_to_zero!!(x) - # Always use caching if we detect PossiblyUninitTangent{Any} anywhere - if contains_possibly_uninit_any(x) - return set_to_zero_internal!!(IdDict{Any,Bool}(), x) - end +# ======================= +# Main Entry Point +# ======================= - # Check for DynamicPPL types and use NoCache for better performance - if is_dppl_ldf_tangent(x) - # Special handling for LogDensityFunction to detect closures - model_f_tangent = x.fields.model.fields.f - cache = is_closure_model(model_f_tangent) ? IdDict{Any,Bool}() : NoCache() - return set_to_zero_internal!!(cache, x) - elseif is_dppl_varinfo_tangent(x) && needs_caching_for_varinfo(x) - # Use IdDict for SimpleVarInfo with Ref fields to avoid circular references - return set_to_zero_internal!!(IdDict{Any,Bool}(), x) - elseif is_dppl_varinfo_tangent(x) || - is_dppl_model_tangent(x) || - is_dppl_metadata_tangent(x) - # These types can always use NoCache - return set_to_zero_internal!!(NoCache(), x) - else - # Use the original implementation with IdDict for all other types - return set_to_zero_internal!!(IdDict{Any,Bool}(), x) - end +function Mooncake.set_to_zero!!(x) + cache = determine_cache_strategy(x) + return set_to_zero_internal!!(cache, x) end end # module From 92f935d2a193ced417d98c08f2ecf5a54aab69d3 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Mon, 21 Jul 2025 10:59:57 +0100 Subject: [PATCH 6/9] refactor more; add additional test --- ext/DynamicPPLMooncakeExt.jl | 107 ++++++++++++--------------- test/ext/DynamicPPLMooncakeExt.jl | 118 ++++++++++++++++++++++-------- 2 files changed, 135 insertions(+), 90 deletions(-) diff --git a/ext/DynamicPPLMooncakeExt.jl b/ext/DynamicPPLMooncakeExt.jl index 76a06099d..51a149e75 100644 --- a/ext/DynamicPPLMooncakeExt.jl +++ b/ext/DynamicPPLMooncakeExt.jl @@ -11,33 +11,7 @@ using Mooncake: NoTangent, Tangent, MutableTangent, NoCache, set_to_zero_interna Mooncake.@zero_adjoint Mooncake.DefaultCtx Tuple{typeof(istrans),Vararg} # ======================= -# Cache Strategy System -# ======================= - -""" - determine_cache_strategy(x) - -Determines the appropriate caching strategy for a given tangent. -Returns either `NoCache()` for safe types or `IdDict{Any,Bool}()` for types with circular reference risk. -""" -function determine_cache_strategy(x) - # Fast path: check for known circular reference patterns - has_circular_reference_risk(x) && return IdDict{Any,Bool}() - - # Check for DynamicPPL types that can safely use NoCache - is_safe_dppl_type(x) && return NoCache() - - # Special case: LogDensityFunction without problematic patterns can use NoCache - if is_dppl_ldf_tangent(x) - return NoCache() - end - - # Default to safe caching for unknown types - return IdDict{Any,Bool}() -end - -# ======================= -# Type Recognition +# `Mooncake.set_to_zero!!` optimization with `NoCache` # ======================= """ @@ -61,9 +35,6 @@ function has_expected_structure( return true end -""" -Check if a tangent corresponds to a DynamicPPL.LogDensityFunction -""" function is_dppl_ldf_tangent(x) has_expected_structure(x, Tangent, (:model, :varinfo, :context, :adtype, :prep)) || return false @@ -75,47 +46,55 @@ function is_dppl_ldf_tangent(x) return true end -""" -Check if a tangent corresponds to a DynamicPPL.VarInfo -""" function is_dppl_varinfo_tangent(x) return has_expected_structure(x, Tangent, (:metadata, :logp, :num_produce)) end -""" -Check if a tangent corresponds to a DynamicPPL.Model -""" function is_dppl_model_tangent(x) return has_expected_structure(x, Tangent, (:f, :args, :defaults, :context)) end -""" -Check if a MutableTangent corresponds to DynamicPPL.Metadata -""" function is_dppl_metadata_tangent(x) - return has_expected_structure( + # Metadata can be either: + # 1. A MutableTangent with the expected fields (for single metadata) + # 2. A NamedTuple where each value is a Tangent with the expected fields + + # Check for MutableTangent case + if has_expected_structure( x, MutableTangent, (:idcs, :vns, :ranges, :vals, :dists, :orders, :flags) ) -end + return true + end -# ======================= -# Circular Reference Detection -# ======================= + # Check for NamedTuple case (multiple metadata) + if x isa NamedTuple + # Each value should be a Tangent with metadata fields + for var_metadata in values(x) + if !has_expected_structure( + var_metadata, + Tangent, + (:idcs, :vns, :ranges, :vals, :dists, :orders, :flags), + ) + return false + end + end + return true + end + + return false +end """ has_circular_reference_risk(x) Main entry point for detecting circular reference patterns that require caching. -Optimized for performance with targeted checks instead of recursive traversal. """ function has_circular_reference_risk(x) - # Type-specific targeted checks only if is_dppl_ldf_tangent(x) # Check model function for closure patterns with circular refs model_f = x.fields.model.fields.f return is_closure_with_circular_refs(model_f) elseif is_dppl_varinfo_tangent(x) - # Check for Ref fields in VarInfo return check_for_ref_fields(x) end @@ -123,10 +102,6 @@ function has_circular_reference_risk(x) return x isa Mooncake.PossiblyUninitTangent{Any} end -""" -Check if a tangent represents a closure with circular reference patterns. -Only returns true for actual problematic patterns, not all MutableTangents. -""" function is_closure_with_circular_refs(x) # Check if MutableTangent contains PossiblyUninitTangent{Any} if x isa MutableTangent && hasfield(typeof(x), :fields) @@ -150,9 +125,6 @@ function is_closure_with_circular_refs(x) return false end -""" -Check if a VarInfo tangent has Ref fields that need caching. -""" function check_for_ref_fields(x) # Check if it's a VarInfo tangent is_dppl_varinfo_tangent(x) || return false @@ -165,9 +137,6 @@ function check_for_ref_fields(x) return logp_tangent isa MutableTangent end -""" -Check if a tangent is a safe DynamicPPL type that can use NoCache. -""" function is_safe_dppl_type(x) # Metadata is always safe is_dppl_metadata_tangent(x) && return true @@ -185,9 +154,27 @@ function is_safe_dppl_type(x) return false end -# ======================= -# Main Entry Point -# ======================= +""" + determine_cache_strategy(x) + +Determines the appropriate caching strategy for a given tangent. +Returns either `NoCache()` for safe types or `IdDict{Any,Bool}()` for types with circular reference risk. +""" +function determine_cache_strategy(x) + # Fast path: check for known circular reference patterns + has_circular_reference_risk(x) && return IdDict{Any,Bool}() + + # Check for DynamicPPL types that can safely use NoCache + is_safe_dppl_type(x) && return NoCache() + + # Special case: LogDensityFunction without problematic patterns can use NoCache + if is_dppl_ldf_tangent(x) + return NoCache() + end + + # Default to safe caching for unknown types + return IdDict{Any,Bool}() +end function Mooncake.set_to_zero!!(x) cache = determine_cache_strategy(x) diff --git a/test/ext/DynamicPPLMooncakeExt.jl b/test/ext/DynamicPPLMooncakeExt.jl index 65ae864b3..f16bd458a 100644 --- a/test/ext/DynamicPPLMooncakeExt.jl +++ b/test/ext/DynamicPPLMooncakeExt.jl @@ -28,26 +28,8 @@ end ) end - @testset "set_to_zero!! optimization" begin - # Test with a real DynamicPPL model - model = test_model1([1.0, 2.0, 3.0]) - vi = VarInfo(Random.default_rng(), model) - ldf = LogDensityFunction(model, vi, DefaultContext()) - tangent = zero_tangent(ldf) - - # Test that set_to_zero!! works correctly - result = set_to_zero!!(deepcopy(tangent)) - @test result isa typeof(tangent) - - # Test with metadata - verify structure exists - if hasfield(typeof(tangent.fields.varinfo.fields), :metadata) - metadata = tangent.fields.varinfo.fields.metadata - @test !isnothing(metadata) - end - end - - @testset "NoCache optimization correctness" begin - # Test that set_to_zero!! uses NoCache for DynamicPPL types + @testset "set_to_zero!! correctness" begin + # Test that set_to_zero!! works correctly for DynamicPPL types model = test_model1([1.0, 2.0, 3.0]) vi = VarInfo(Random.default_rng(), model) ldf = LogDensityFunction(model, vi, DefaultContext()) @@ -63,7 +45,8 @@ end end # Call set_to_zero!! and verify it works - set_to_zero!!(tangent) + result = set_to_zero!!(tangent) + @test result isa typeof(tangent) # Check that values are zeroed if hasfield(typeof(tangent.fields.model.fields), :args) && @@ -76,15 +59,7 @@ end end @testset "Performance improvement" begin - # Test with DEMO_MODELS if available - if isdefined(DynamicPPL.TestUtils, :DEMO_MODELS) && - !isempty(DynamicPPL.TestUtils.DEMO_MODELS) - model = DynamicPPL.TestUtils.DEMO_MODELS[1] - else - # Fallback to our test model - model = test_model1([1.0, 2.0, 3.0, 4.0]) - end - + model = DynamicPPL.TestUtils.DEMO_MODELS[1] vi = VarInfo(Random.default_rng(), model) ldf = LogDensityFunction(model, vi, DefaultContext()) tangent = zero_tangent(ldf) @@ -189,4 +164,87 @@ end # Global should be faster (uses NoCache) @test time_global < time_closure end + + @testset "Struct field assumptions" begin + # Test that our assumptions about DynamicPPL struct fields are correct + # These tests will fail if DynamicPPL changes its internal structure + + @testset "LogDensityFunction tangent structure" begin + model = test_model1([1.0, 2.0, 3.0]) + vi = VarInfo(Random.default_rng(), model) + ldf = LogDensityFunction(model, vi, DefaultContext()) + tangent = zero_tangent(ldf) + + # Test expected fields exist + @test hasfield(typeof(tangent), :fields) + @test hasfield(typeof(tangent.fields), :model) + @test hasfield(typeof(tangent.fields), :varinfo) + @test hasfield(typeof(tangent.fields), :context) + @test hasfield(typeof(tangent.fields), :adtype) + @test hasfield(typeof(tangent.fields), :prep) + + # Test exact field names match + @test propertynames(tangent.fields) == + (:model, :varinfo, :context, :adtype, :prep) + end + + @testset "VarInfo tangent structure" begin + model = test_model1([1.0, 2.0, 3.0]) + vi = VarInfo(Random.default_rng(), model) + tangent_vi = zero_tangent(vi) + + # Test expected fields exist + @test hasfield(typeof(tangent_vi), :fields) + @test hasfield(typeof(tangent_vi.fields), :metadata) + @test hasfield(typeof(tangent_vi.fields), :logp) + @test hasfield(typeof(tangent_vi.fields), :num_produce) + + # Test exact field names match + @test propertynames(tangent_vi.fields) == (:metadata, :logp, :num_produce) + end + + @testset "Model tangent structure" begin + model = test_model1([1.0, 2.0, 3.0]) + tangent_model = zero_tangent(model) + + # Test expected fields exist + @test hasfield(typeof(tangent_model), :fields) + @test hasfield(typeof(tangent_model.fields), :f) + @test hasfield(typeof(tangent_model.fields), :args) + @test hasfield(typeof(tangent_model.fields), :defaults) + @test hasfield(typeof(tangent_model.fields), :context) + + # Test exact field names match + @test propertynames(tangent_model.fields) == (:f, :args, :defaults, :context) + end + + @testset "Metadata tangent structure" begin + model = test_model1([1.0, 2.0, 3.0]) + vi = VarInfo(Random.default_rng(), model) + tangent_vi = zero_tangent(vi) + metadata = tangent_vi.fields.metadata + + # Metadata is a NamedTuple with variable names as keys + @test metadata isa NamedTuple + + # Each variable's metadata should be a Tangent with the expected fields + for (varname, var_metadata) in pairs(metadata) + @test var_metadata isa Mooncake.Tangent + @test hasfield(typeof(var_metadata), :fields) + + # Test expected fields exist + @test hasfield(typeof(var_metadata.fields), :idcs) + @test hasfield(typeof(var_metadata.fields), :vns) + @test hasfield(typeof(var_metadata.fields), :ranges) + @test hasfield(typeof(var_metadata.fields), :vals) + @test hasfield(typeof(var_metadata.fields), :dists) + @test hasfield(typeof(var_metadata.fields), :orders) + @test hasfield(typeof(var_metadata.fields), :flags) + + # Test exact field names match + @test propertynames(var_metadata.fields) == + (:idcs, :vns, :ranges, :vals, :dists, :orders, :flags) + end + end + end end From de57eddead375c1f4c89aba48e542f834140a278 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Mon, 21 Jul 2025 11:26:38 +0100 Subject: [PATCH 7/9] remove Mooncake from test project --- test/Project.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index 0ed45a54e..afecba1c4 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -18,7 +18,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -46,7 +45,6 @@ JET = "0.9, 0.10" LogDensityProblems = "2" MCMCChains = "6.0.4, 7" MacroTools = "0.5.6" -Mooncake = "0.4.137" OrderedCollections = "1" ReverseDiff = "1" StableRNGs = "1" From 76622ae007fef77446c2ee469f852b8f000404ca Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Wed, 6 Aug 2025 18:26:45 +0100 Subject: [PATCH 8/9] use `Mooncake.requires_cache` function --- ext/DynamicPPLMooncakeExt.jl | 185 +++--------------------------- test/ext/DynamicPPLMooncakeExt.jl | 83 -------------- 2 files changed, 18 insertions(+), 250 deletions(-) diff --git a/ext/DynamicPPLMooncakeExt.jl b/ext/DynamicPPLMooncakeExt.jl index 51a149e75..649db792a 100644 --- a/ext/DynamicPPLMooncakeExt.jl +++ b/ext/DynamicPPLMooncakeExt.jl @@ -1,184 +1,35 @@ module DynamicPPLMooncakeExt -__precompile__(false) - using DynamicPPL: DynamicPPL, istrans using Mooncake: Mooncake -import Mooncake: set_to_zero!! -using Mooncake: NoTangent, Tangent, MutableTangent, NoCache, set_to_zero_internal!! # This is purely an optimisation. Mooncake.@zero_adjoint Mooncake.DefaultCtx Tuple{typeof(istrans),Vararg} -# ======================= -# `Mooncake.set_to_zero!!` optimization with `NoCache` -# ======================= - -""" -Check if a tangent has the expected structure for a given type. -""" -function has_expected_structure( - x, expected_type::Type{<:Union{Tangent,MutableTangent}}, expected_fields -) - x isa expected_type || return false - hasfield(typeof(x), :fields) || return false +@static if isdefined(Mooncake, :requires_cache) + import Mooncake: requires_cache - fields = x.fields - if expected_fields isa Tuple - # Exact match required - propertynames(fields) == expected_fields || return false - else - # All expected fields must be present - all(f in propertynames(fields) for f in expected_fields) || return false + function Mooncake.requires_cache(::Type{<:DynamicPPL.Metadata}) + return Val(false) end - - return true -end - -function is_dppl_ldf_tangent(x) - has_expected_structure(x, Tangent, (:model, :varinfo, :context, :adtype, :prep)) || - return false - - fields = x.fields - is_dppl_varinfo_tangent(fields.varinfo) || return false - is_dppl_model_tangent(fields.model) || return false - - return true -end - -function is_dppl_varinfo_tangent(x) - return has_expected_structure(x, Tangent, (:metadata, :logp, :num_produce)) -end - -function is_dppl_model_tangent(x) - return has_expected_structure(x, Tangent, (:f, :args, :defaults, :context)) -end - -function is_dppl_metadata_tangent(x) - # Metadata can be either: - # 1. A MutableTangent with the expected fields (for single metadata) - # 2. A NamedTuple where each value is a Tangent with the expected fields - - # Check for MutableTangent case - if has_expected_structure( - x, MutableTangent, (:idcs, :vns, :ranges, :vals, :dists, :orders, :flags) - ) - return true + + function Mooncake.requires_cache(::Type{<:DynamicPPL.TypedVarInfo}) + return Val(false) end - - # Check for NamedTuple case (multiple metadata) - if x isa NamedTuple - # Each value should be a Tangent with metadata fields - for var_metadata in values(x) - if !has_expected_structure( - var_metadata, - Tangent, - (:idcs, :vns, :ranges, :vals, :dists, :orders, :flags), - ) - return false - end - end - return true + + function Mooncake.requires_cache(::Type{<:DynamicPPL.Model}) + # Model has f (function/closure), args, defaults, context + # Closures can have circular references + return Val(false) end - - return false -end - -""" - has_circular_reference_risk(x) - -Main entry point for detecting circular reference patterns that require caching. -""" -function has_circular_reference_risk(x) - if is_dppl_ldf_tangent(x) - # Check model function for closure patterns with circular refs - model_f = x.fields.model.fields.f - return is_closure_with_circular_refs(model_f) - elseif is_dppl_varinfo_tangent(x) - return check_for_ref_fields(x) + + function Mooncake.requires_cache(::Type{<:DynamicPPL.LogDensityFunction}) + return Val(false) end - - # For unknown types, do a shallow check for PossiblyUninitTangent{Any} - return x isa Mooncake.PossiblyUninitTangent{Any} -end - -function is_closure_with_circular_refs(x) - # Check if MutableTangent contains PossiblyUninitTangent{Any} - if x isa MutableTangent && hasfield(typeof(x), :fields) - hasfield(typeof(x.fields), :contents) && - x.fields.contents isa Mooncake.PossiblyUninitTangent{Any} && - return true + + function Mooncake.requires_cache(::Type{<:DynamicPPL.AbstractContext}) + return Val(false) end - - # For Tangent, only check immediate fields (no deep recursion) - if x isa Tangent && hasfield(typeof(x), :fields) - for (_, fval) in pairs(x.fields) - if fval isa MutableTangent && - hasfield(typeof(fval), :fields) && - hasfield(typeof(fval.fields), :contents) && - fval.fields.contents isa Mooncake.PossiblyUninitTangent{Any} - return true - end - end - end - - return false -end - -function check_for_ref_fields(x) - # Check if it's a VarInfo tangent - is_dppl_varinfo_tangent(x) || return false - - # Check if the logp field contains a Ref-like tangent structure - hasfield(typeof(x.fields), :logp) || return false - logp_tangent = x.fields.logp - - # Ref types in tangents often appear as MutableTangent with circular references - return logp_tangent isa MutableTangent -end - -function is_safe_dppl_type(x) - # Metadata is always safe - is_dppl_metadata_tangent(x) && return true - - # Model tangents without closures are safe - if is_dppl_model_tangent(x) - !is_closure_with_circular_refs(x.fields.f) && return true - end - - # VarInfo without Ref fields is safe - if is_dppl_varinfo_tangent(x) - !check_for_ref_fields(x) && return true - end - - return false -end - -""" - determine_cache_strategy(x) - -Determines the appropriate caching strategy for a given tangent. -Returns either `NoCache()` for safe types or `IdDict{Any,Bool}()` for types with circular reference risk. -""" -function determine_cache_strategy(x) - # Fast path: check for known circular reference patterns - has_circular_reference_risk(x) && return IdDict{Any,Bool}() - - # Check for DynamicPPL types that can safely use NoCache - is_safe_dppl_type(x) && return NoCache() - - # Special case: LogDensityFunction without problematic patterns can use NoCache - if is_dppl_ldf_tangent(x) - return NoCache() - end - - # Default to safe caching for unknown types - return IdDict{Any,Bool}() -end - -function Mooncake.set_to_zero!!(x) - cache = determine_cache_strategy(x) - return set_to_zero_internal!!(cache, x) end end # module diff --git a/test/ext/DynamicPPLMooncakeExt.jl b/test/ext/DynamicPPLMooncakeExt.jl index f16bd458a..06804c277 100644 --- a/test/ext/DynamicPPLMooncakeExt.jl +++ b/test/ext/DynamicPPLMooncakeExt.jl @@ -164,87 +164,4 @@ end # Global should be faster (uses NoCache) @test time_global < time_closure end - - @testset "Struct field assumptions" begin - # Test that our assumptions about DynamicPPL struct fields are correct - # These tests will fail if DynamicPPL changes its internal structure - - @testset "LogDensityFunction tangent structure" begin - model = test_model1([1.0, 2.0, 3.0]) - vi = VarInfo(Random.default_rng(), model) - ldf = LogDensityFunction(model, vi, DefaultContext()) - tangent = zero_tangent(ldf) - - # Test expected fields exist - @test hasfield(typeof(tangent), :fields) - @test hasfield(typeof(tangent.fields), :model) - @test hasfield(typeof(tangent.fields), :varinfo) - @test hasfield(typeof(tangent.fields), :context) - @test hasfield(typeof(tangent.fields), :adtype) - @test hasfield(typeof(tangent.fields), :prep) - - # Test exact field names match - @test propertynames(tangent.fields) == - (:model, :varinfo, :context, :adtype, :prep) - end - - @testset "VarInfo tangent structure" begin - model = test_model1([1.0, 2.0, 3.0]) - vi = VarInfo(Random.default_rng(), model) - tangent_vi = zero_tangent(vi) - - # Test expected fields exist - @test hasfield(typeof(tangent_vi), :fields) - @test hasfield(typeof(tangent_vi.fields), :metadata) - @test hasfield(typeof(tangent_vi.fields), :logp) - @test hasfield(typeof(tangent_vi.fields), :num_produce) - - # Test exact field names match - @test propertynames(tangent_vi.fields) == (:metadata, :logp, :num_produce) - end - - @testset "Model tangent structure" begin - model = test_model1([1.0, 2.0, 3.0]) - tangent_model = zero_tangent(model) - - # Test expected fields exist - @test hasfield(typeof(tangent_model), :fields) - @test hasfield(typeof(tangent_model.fields), :f) - @test hasfield(typeof(tangent_model.fields), :args) - @test hasfield(typeof(tangent_model.fields), :defaults) - @test hasfield(typeof(tangent_model.fields), :context) - - # Test exact field names match - @test propertynames(tangent_model.fields) == (:f, :args, :defaults, :context) - end - - @testset "Metadata tangent structure" begin - model = test_model1([1.0, 2.0, 3.0]) - vi = VarInfo(Random.default_rng(), model) - tangent_vi = zero_tangent(vi) - metadata = tangent_vi.fields.metadata - - # Metadata is a NamedTuple with variable names as keys - @test metadata isa NamedTuple - - # Each variable's metadata should be a Tangent with the expected fields - for (varname, var_metadata) in pairs(metadata) - @test var_metadata isa Mooncake.Tangent - @test hasfield(typeof(var_metadata), :fields) - - # Test expected fields exist - @test hasfield(typeof(var_metadata.fields), :idcs) - @test hasfield(typeof(var_metadata.fields), :vns) - @test hasfield(typeof(var_metadata.fields), :ranges) - @test hasfield(typeof(var_metadata.fields), :vals) - @test hasfield(typeof(var_metadata.fields), :dists) - @test hasfield(typeof(var_metadata.fields), :orders) - @test hasfield(typeof(var_metadata.fields), :flags) - - # Test exact field names match - @test propertynames(var_metadata.fields) == - (:idcs, :vns, :ranges, :vals, :dists, :orders, :flags) - end - end - end end From 7532b62369a48e38c730ba3ccb98fa680e8f7b61 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Wed, 6 Aug 2025 19:03:48 +0100 Subject: [PATCH 9/9] formatting --- ext/DynamicPPLMooncakeExt.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ext/DynamicPPLMooncakeExt.jl b/ext/DynamicPPLMooncakeExt.jl index 649db792a..3eafac00c 100644 --- a/ext/DynamicPPLMooncakeExt.jl +++ b/ext/DynamicPPLMooncakeExt.jl @@ -12,21 +12,21 @@ Mooncake.@zero_adjoint Mooncake.DefaultCtx Tuple{typeof(istrans),Vararg} function Mooncake.requires_cache(::Type{<:DynamicPPL.Metadata}) return Val(false) end - + function Mooncake.requires_cache(::Type{<:DynamicPPL.TypedVarInfo}) return Val(false) end - + function Mooncake.requires_cache(::Type{<:DynamicPPL.Model}) # Model has f (function/closure), args, defaults, context # Closures can have circular references return Val(false) end - + function Mooncake.requires_cache(::Type{<:DynamicPPL.LogDensityFunction}) return Val(false) end - + function Mooncake.requires_cache(::Type{<:DynamicPPL.AbstractContext}) return Val(false) end