diff --git a/Project.toml b/Project.toml index 74342aa99..3241a9533 100644 --- a/Project.toml +++ b/Project.toml @@ -95,7 +95,7 @@ RecursiveArrayTools = "3.27.2" Reexport = "1.0" ReverseDiff = "1.15.1" SafeTestsets = "0.1.0" -SciMLBase = "2.103.1" +SciMLBase = "2.117.1" SciMLJacobianOperators = "0.1" SciMLStructures = "1.3" SparseArrays = "1.10" diff --git a/src/SciMLSensitivity.jl b/src/SciMLSensitivity.jl index 7a003b4d3..619425ad4 100644 --- a/src/SciMLSensitivity.jl +++ b/src/SciMLSensitivity.jl @@ -83,6 +83,7 @@ include("concrete_solve.jl") include("second_order.jl") include("steadystate_adjoint.jl") include("sde_tools.jl") +include("enzyme_rules.jl") export extract_local_sensitivities diff --git a/src/enzyme_rules.jl b/src/enzyme_rules.jl new file mode 100644 index 000000000..e11b6acd4 --- /dev/null +++ b/src/enzyme_rules.jl @@ -0,0 +1,14 @@ +# Enzyme rules for VJP choice types defined in SciMLSensitivity +# +# VJP choice types configure how jacobian-vector products are computed within +# sensitivity algorithms. They should be treated as inactive (constant) during +# Enzyme differentiation to prevent errors when they are stored in problem +# structures or other data that Enzyme differentiates through. +# +# Note: AbstractSensitivityAlgorithm inactive rule is handled in SciMLBase +# to avoid type piracy. + +import Enzyme: EnzymeRules + +# VJP choice types should be inactive since they configure computation methods +EnzymeRules.inactive_type(::Type{<:VJPChoice}) = true diff --git a/test/enzyme_vjp_inactive.jl b/test/enzyme_vjp_inactive.jl new file mode 100644 index 000000000..d99025896 --- /dev/null +++ b/test/enzyme_vjp_inactive.jl @@ -0,0 +1,72 @@ +using Test, SciMLSensitivity, Enzyme, OrdinaryDiffEq + +# Test that VJP choice types are treated as inactive by Enzyme +# The AbstractSensitivityAlgorithm inactive rule is handled in SciMLBase +# This addresses issue #1225 where sensealg in ODEProblem constructor would fail + +@testset "Enzyme VJP Choice Inactive Types" begin + + # Test 1: Basic test that VJP objects can be stored in data structures during Enzyme differentiation + @testset "VJP types in data structures" begin + vjp = EnzymeVJP() + + function test_func(x) + # Store the VJP in a data structure (this would fail without inactive rules) + data = (value = x[1] + x[2], vjp = vjp) + return data.value * 2.0 + end + + x = [1.0, 2.0] + dx = Enzyme.make_zero(x) + + # This should not throw an error + @test_nowarn Enzyme.autodiff(Enzyme.Reverse, test_func, Enzyme.Active, Enzyme.Duplicated(x, dx)) + @test dx ≈ [2.0, 2.0] + end + + # Test 2: Test different VJP choice types are inactive + @testset "Different VJP types inactive" begin + vjp_types = [EnzymeVJP(), ZygoteVJP(), ReverseDiffVJP(), TrackerVJP()] + + for vjp in vjp_types + function test_func(x) + data = (value = x[1] * x[2], vjp = vjp) + return data.value + 1.0 + end + + x = [2.0, 3.0] + dx = Enzyme.make_zero(x) + + @test_nowarn Enzyme.autodiff(Enzyme.Reverse, test_func, Enzyme.Active, Enzyme.Duplicated(x, dx)) + end + end + + # Test 3: Test sensitivity algorithms with VJP choices (integration test) + # Note: This test also depends on SciMLBase having AbstractSensitivityAlgorithm as inactive + @testset "Sensitivity algorithms with VJP choices" begin + function f(du, u, p, t) + du[1] = -p[1] * u[1] + du[2] = p[2] * u[2] + end + + function loss_func(p) + u0 = [1.0, 2.0] + # Both VJP choice and sensitivity algorithm should be inactive + prob = ODEProblem( + f, u0, (0.0, 0.1), p, sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP())) + sol = solve(prob, Tsit5()) + return sol.u[end][1] + sol.u[end][2] + end + + p = [0.5, 1.5] + dp = Enzyme.make_zero(p) + + # This should not throw the "Error handling recursive stores for String" error + # This is the original failing case from issue #1225 + @test_nowarn Enzyme.autodiff(Enzyme.Reverse, loss_func, Enzyme.Active, Enzyme.Duplicated(p, dp)) + + # Verify the gradient is computed (non-zero and finite) + @test all(isfinite, dp) + @test any(x -> abs(x) > 1e-10, dp) # At least one component should be non-trivial + end +end diff --git a/test/runtests.jl b/test/runtests.jl index cfb90641e..5ddde73c7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -40,6 +40,7 @@ end @time @safetestset "Scalar u0" include("scalar_u.jl") @time @safetestset "Error Messages" include("error_messages.jl") @time @safetestset "Autodiff Events" include("autodiff_events.jl") + @time @safetestset "Enzyme VJP Inactive" include("enzyme_vjp_inactive.jl") end end