Skip to content

Use NoCache to improve set_to_zero!! performance with Mooncake #975

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 11 commits into from
4 changes: 4 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# DynamicPPL Changelog

## 0.36.16

Improved performance for some models with Mooncake.jl by using `NoCache` with `Mooncake.set_to_zero!!` for DynamicPPL types.

## 0.36.15

Bumped minimum Julia version to 1.10.8 to avoid potential crashes with `Core.Compiler.widenconst` (which Mooncake uses).
Expand Down
26 changes: 26 additions & 0 deletions ext/DynamicPPLMooncakeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,30 @@ using Mooncake: Mooncake
# This is purely an optimisation.
Mooncake.@zero_adjoint Mooncake.DefaultCtx Tuple{typeof(istrans),Vararg}

@static if isdefined(Mooncake, :requires_cache)
import Mooncake: requires_cache

function Mooncake.requires_cache(::Type{<:DynamicPPL.Metadata})
return Val(false)
end

function Mooncake.requires_cache(::Type{<:DynamicPPL.TypedVarInfo})
return Val(false)
end

function Mooncake.requires_cache(::Type{<:DynamicPPL.Model})
# Model has f (function/closure), args, defaults, context
# Closures can have circular references
return Val(false)
end

function Mooncake.requires_cache(::Type{<:DynamicPPL.LogDensityFunction})
return Val(false)
end

function Mooncake.requires_cache(::Type{<:DynamicPPL.AbstractContext})
return Val(false)
end
end

end # module
168 changes: 165 additions & 3 deletions test/ext/DynamicPPLMooncakeExt.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,167 @@
using DynamicPPL
using Distributions
using Random
using Test
using StableRNGs
using Mooncake: Mooncake, NoCache, set_to_zero!!, set_to_zero_internal!!, zero_tangent
using DynamicPPL.TestUtils.AD: @be, median

# Define models globally to avoid closure issues
@model function test_model1(x)
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
return x .~ Normal(m, sqrt(s))
end

@model function test_model2(x, y)
τ ~ Gamma(1, 1)
σ ~ InverseGamma(2, 3)
μ ~ Normal(0, τ)
x .~ Normal(μ, σ)
return y .~ Normal(μ, σ)
end

@testset "DynamicPPLMooncakeExt" begin
Mooncake.TestUtils.test_rule(
StableRNG(123456), istrans, VarInfo(); unsafe_perturb=true, interface_only=true
)
@testset "istrans rule" begin
Mooncake.TestUtils.test_rule(
StableRNG(123456), istrans, VarInfo(); unsafe_perturb=true, interface_only=true
)
end

@testset "set_to_zero!! correctness" begin
# Test that set_to_zero!! works correctly for DynamicPPL types
model = test_model1([1.0, 2.0, 3.0])
vi = VarInfo(Random.default_rng(), model)
ldf = LogDensityFunction(model, vi, DefaultContext())
tangent = zero_tangent(ldf)

# Modify some values
if hasfield(typeof(tangent.fields.model.fields), :args) &&
hasfield(typeof(tangent.fields.model.fields.args), :x)
x_tangent = tangent.fields.model.fields.args.x
if !isempty(x_tangent)
x_tangent[1] = 5.0
end
end

# Call set_to_zero!! and verify it works
result = set_to_zero!!(tangent)
@test result isa typeof(tangent)

# Check that values are zeroed
if hasfield(typeof(tangent.fields.model.fields), :args) &&
hasfield(typeof(tangent.fields.model.fields.args), :x)
x_tangent = tangent.fields.model.fields.args.x
if !isempty(x_tangent)
@test x_tangent[1] == 0.0
end
end
end

@testset "Performance improvement" begin
model = DynamicPPL.TestUtils.DEMO_MODELS[1]
vi = VarInfo(Random.default_rng(), model)
ldf = LogDensityFunction(model, vi, DefaultContext())
tangent = zero_tangent(ldf)

# Run benchmarks
result_iddict = @be begin
cache = IdDict{Any,Bool}()
set_to_zero_internal!!(cache, tangent)
end

result_nocache = @be set_to_zero!!(tangent)

# Extract median times
time_iddict = median(result_iddict).time
time_nocache = median(result_nocache).time

# We expect NoCache to be faster
speedup = time_iddict / time_nocache
@test speedup > 1.5 # Conservative expectation - should be ~4x

# Sanity check
@info "Performance improvement" speedup time_iddict_μs = time_iddict / 1000 time_nocache_μs =
time_nocache / 1000
end

@testset "Aliasing safety" begin
# Test with aliased data
shared_data = [1.0, 2.0, 3.0]
model = test_model2(shared_data, shared_data) # x and y are the same array
vi = VarInfo(Random.default_rng(), model)
ldf = LogDensityFunction(model, vi, DefaultContext())
tangent = zero_tangent(ldf)

# Check that aliasing is preserved in tangent
if hasfield(typeof(tangent.fields.model.fields), :args)
args = tangent.fields.model.fields.args
if hasfield(typeof(args), :x) && hasfield(typeof(args), :y)
@test args.x === args.y # Aliasing should be preserved

# Modify via x
if !isempty(args.x)
args.x[1] = 10.0
@test args.y[1] == 10.0 # Should also change y
end

# Zero and check both are zeroed
# Since x and y are aliased, zeroing one zeros both
set_to_zero!!(tangent)
if !isempty(args.x)
@test args.x[1] == 0.0
@test args.y[1] == 0.0
end
end
end
end

@testset "Closure handling" begin
# Test that closure models are correctly handled

# Create closure model (captures environment, has circular references)
function create_closure_model()
local_var = 42
@model function closure_model(x)
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
return x .~ Normal(m, sqrt(s))
end
return closure_model
end

closure_fn = create_closure_model()
model_closure = closure_fn([1.0, 2.0, 3.0])
vi_closure = VarInfo(Random.default_rng(), model_closure)
ldf_closure = LogDensityFunction(model_closure, vi_closure, DefaultContext())
tangent_closure = zero_tangent(ldf_closure)

# Test that it works without stack overflow
@test_nowarn set_to_zero!!(deepcopy(tangent_closure))

# Compare with global model (no closure)
model_global = test_model1([1.0, 2.0, 3.0])
vi_global = VarInfo(Random.default_rng(), model_global)
ldf_global = LogDensityFunction(model_global, vi_global, DefaultContext())
tangent_global = zero_tangent(ldf_global)

# Verify model.f tangent types differ
f_tangent_closure = tangent_closure.fields.model.fields.f
f_tangent_global = tangent_global.fields.model.fields.f

@test f_tangent_global isa Mooncake.NoTangent # Global function
@test f_tangent_closure isa Mooncake.Tangent # Closure function

# Performance comparison
time_global = @elapsed for _ in 1:100
set_to_zero!!(tangent_global)
end

time_closure = @elapsed for _ in 1:100
set_to_zero!!(tangent_closure)
end

# Global should be faster (uses NoCache)
@test time_global < time_closure
end
end
Loading