Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ RecursiveArrayTools = "3.27.2"
Reexport = "1.0"
ReverseDiff = "1.15.1"
SafeTestsets = "0.1.0"
SciMLBase = "2.103.1"
SciMLBase = "2.117.1"
SciMLJacobianOperators = "0.1"
SciMLStructures = "1.3"
SparseArrays = "1.10"
Expand Down
1 change: 1 addition & 0 deletions src/SciMLSensitivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ include("concrete_solve.jl")
include("second_order.jl")
include("steadystate_adjoint.jl")
include("sde_tools.jl")
include("enzyme_rules.jl")

export extract_local_sensitivities

Expand Down
14 changes: 14 additions & 0 deletions src/enzyme_rules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Enzyme rules for VJP choice types defined in SciMLSensitivity
#
# VJP choice types configure how jacobian-vector products are computed within
# sensitivity algorithms. They should be treated as inactive (constant) during
# Enzyme differentiation to prevent errors when they are stored in problem
# structures or other data that Enzyme differentiates through.
#
# Note: AbstractSensitivityAlgorithm inactive rule is handled in SciMLBase
# to avoid type piracy.

import Enzyme: EnzymeRules

# VJP choice types should be inactive since they configure computation methods
EnzymeRules.inactive_type(::Type{<:VJPChoice}) = true
72 changes: 72 additions & 0 deletions test/enzyme_vjp_inactive.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
using Test, SciMLSensitivity, Enzyme, OrdinaryDiffEq

# Test that VJP choice types are treated as inactive by Enzyme
# The AbstractSensitivityAlgorithm inactive rule is handled in SciMLBase
# This addresses issue #1225 where sensealg in ODEProblem constructor would fail

@testset "Enzyme VJP Choice Inactive Types" begin

# Test 1: Basic test that VJP objects can be stored in data structures during Enzyme differentiation
@testset "VJP types in data structures" begin
vjp = EnzymeVJP()

function test_func(x)
# Store the VJP in a data structure (this would fail without inactive rules)
data = (value = x[1] + x[2], vjp = vjp)
return data.value * 2.0
end

x = [1.0, 2.0]
dx = Enzyme.make_zero(x)

# This should not throw an error
@test_nowarn Enzyme.autodiff(Enzyme.Reverse, test_func, Enzyme.Active, Enzyme.Duplicated(x, dx))
@test dx ≈ [2.0, 2.0]
end

# Test 2: Test different VJP choice types are inactive
@testset "Different VJP types inactive" begin
vjp_types = [EnzymeVJP(), ZygoteVJP(), ReverseDiffVJP(), TrackerVJP()]

for vjp in vjp_types
function test_func(x)
data = (value = x[1] * x[2], vjp = vjp)
return data.value + 1.0
end

x = [2.0, 3.0]
dx = Enzyme.make_zero(x)

@test_nowarn Enzyme.autodiff(Enzyme.Reverse, test_func, Enzyme.Active, Enzyme.Duplicated(x, dx))
end
end

# Test 3: Test sensitivity algorithms with VJP choices (integration test)
# Note: This test also depends on SciMLBase having AbstractSensitivityAlgorithm as inactive
@testset "Sensitivity algorithms with VJP choices" begin
function f(du, u, p, t)
du[1] = -p[1] * u[1]
du[2] = p[2] * u[2]
end

function loss_func(p)
u0 = [1.0, 2.0]
# Both VJP choice and sensitivity algorithm should be inactive
prob = ODEProblem(
f, u0, (0.0, 0.1), p, sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()))
sol = solve(prob, Tsit5())
return sol.u[end][1] + sol.u[end][2]
end

p = [0.5, 1.5]
dp = Enzyme.make_zero(p)

# This should not throw the "Error handling recursive stores for String" error
# This is the original failing case from issue #1225
@test_nowarn Enzyme.autodiff(Enzyme.Reverse, loss_func, Enzyme.Active, Enzyme.Duplicated(p, dp))

# Verify the gradient is computed (non-zero and finite)
@test all(isfinite, dp)
@test any(x -> abs(x) > 1e-10, dp) # At least one component should be non-trivial
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ end
@time @safetestset "Scalar u0" include("scalar_u.jl")
@time @safetestset "Error Messages" include("error_messages.jl")
@time @safetestset "Autodiff Events" include("autodiff_events.jl")
@time @safetestset "Enzyme VJP Inactive" include("enzyme_vjp_inactive.jl")
end
end

Expand Down
Loading