diff --git a/docs/src/examples/optimal_control/optimal_control.md b/docs/src/examples/optimal_control/optimal_control.md index 53bca29a7..af4bbad85 100644 --- a/docs/src/examples/optimal_control/optimal_control.md +++ b/docs/src/examples/optimal_control/optimal_control.md @@ -134,8 +134,9 @@ Now let's see what we received: ```@example neuraloptimalcontrol l = loss_adjoint(res3.u) cb(res3, l) -p = Plots.plot(ODE.solve(ODE.remake(prob, p = res3.u), ODE.Tsit5(), saveat = 0.01), ylim = ( - -6, 6), lw = 3) +p = Plots.plot( + ODE.solve(ODE.remake(prob, p = res3.u), ODE.Tsit5(), saveat = 0.01), ylim = ( + -6, 6), lw = 3) Plots.plot!(p, ts, [first(first(ann([t], CA.ComponentArray(res3.u, ax), st))) for t in ts], label = "u(t)", lw = 3) ``` diff --git a/src/adjoint_common.jl b/src/adjoint_common.jl index ac76e8727..5bec802d2 100644 --- a/src/adjoint_common.jl +++ b/src/adjoint_common.jl @@ -230,9 +230,20 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f if isinplace && !(p === nothing || p === SciMLBase.NullParameters()) if !isRODE - pf = SciMLBase.ParamJacobianWrapper(unwrappedf, _t, y) + if isscimlstructure(p) + pf = SciMLBase.ParamJacobianWrapper( + ( + du, u, p, t)->unwrappedf(du, u, repack(p), t), _t, y) + else + pf = SciMLBase.ParamJacobianWrapper(unwrappedf, _t, y) + end else - pf = RODEParamJacobianWrapper(unwrappedf, _t, y, _W) + if isscimlstructure(p) + pf = RODEParamJacobianWrapper( + (du, u, p, t, W)->unwrappedf(du, u, repack(p), t, W), _t, y, _W) + else + pf = RODEParamJacobianWrapper(unwrappedf, _t, y, _W) + end end paramjac_config = build_param_jac_config( sensealg, pf, y, SciMLStructures.replace(Tunable(), p, tunables)) @@ -317,7 +328,12 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f elseif autojacvec isa Bool if isinplace if SciMLBase.is_diagonal_noise(prob) - pf = SciMLBase.ParamJacobianWrapper(unwrappedf, _t, y) + if isscimlstructure(p) + pf = SciMLBase.ParamJacobianWrapper( + (du, u, p, t)->unwrappedf(du, u, repack(p), t), _t, y) + else + pf = SciMLBase.ParamJacobianWrapper(unwrappedf, _t, y) + end if isnoisemixing(sensealg) uf = SciMLBase.UJacobianWrapper(unwrappedf, _t, p) jac_noise_config = build_jac_config(sensealg, uf, u0) diff --git a/src/backsolve_adjoint.jl b/src/backsolve_adjoint.jl index bc1d87b4a..b97f6731d 100644 --- a/src/backsolve_adjoint.jl +++ b/src/backsolve_adjoint.jl @@ -139,8 +139,10 @@ end u0 = state_values(sol.prob) if p === nothing || p isa SciMLBase.NullParameters tunables, repack = p, identity - else + elseif isscimlstructure(p) tunables, repack, _ = canonicalize(Tunable(), p) + else + throw(SciMLStructuresCompatibilityError()) end ## Force recompile mode until vjps are specialized to handle this!!! @@ -263,7 +265,13 @@ end (; f, tspan) = sol.prob p = parameter_values(sol) u0 = state_values(sol.prob) - tunables, repack, _ = canonicalize(Tunable(), p) + if p === nothing || p isa SciMLBase.NullParameters + tunables, repack = p, identity + elseif isscimlstructure(p) + tunables, repack, _ = canonicalize(Tunable(), p) + else + throw(SciMLStructuresCompatibilityError()) + end # check if solution was terminated, then use reduced time span terminated = false @@ -283,7 +291,7 @@ end error("Your model does not have parameters, and thus it is impossible to calculate the derivative of the solution with respect to the parameters. Your model must have parameters to use parameter sensitivity calculations!") numstates = length(u0) - numparams = length(tunables) + numparams = p === nothing || p === SciMLBase.NullParameters() ? 0 : length(tunables) len = length(u0) + numparams λ = one(eltype(u0)) .* similar(tunables, len) @@ -386,7 +394,13 @@ end (; f, tspan) = sol.prob p = parameter_values(sol) u0 = state_values(sol.prob) - tunables, repack, _ = canonicalize(Tunable(), p) + if p === nothing || p isa SciMLBase.NullParameters + tunables, repack = p, identity + elseif isscimlstructure(p) + tunables, repack, _ = canonicalize(Tunable(), p) + else + throw(SciMLStructuresCompatibilityError()) + end # check if solution was terminated, then use reduced time span terminated = false if hasfield(typeof(sol), :retcode) @@ -404,7 +418,7 @@ end error("Your model does not have parameters, and thus it is impossible to calculate the derivative of the solution with respect to the parameters. Your model must have parameters to use parameter sensitivity calculations!") numstates = length(u0) - numparams = length(tunables) + numparams = p === nothing || p === SciMLBase.NullParameters() ? 0 : length(tunables) len = length(u0) + numparams λ = one(eltype(u0)) .* similar(tunables, len) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 113220664..0652dd339 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -376,8 +376,12 @@ function DiffEqBase._concrete_solve_adjoint( save_idxs = nothing, initializealg_default = SciMLBase.OverrideInit(; abstol = 1e-6, reltol = 1e-3), kwargs...) - if !(sensealg isa GaussAdjoint) && - !(p isa Union{Nothing, SciMLBase.NullParameters, AbstractArray}) || + # Check parameter compatibility for adjoint methods + if !((p isa Union{Nothing, SciMLBase.NullParameters, AbstractArray}) || + (sensealg isa + Union{GaussAdjoint, BacksolveAdjoint, InterpolatingAdjoint, QuadratureAdjoint} && + isscimlstructure(p)) || + (sensealg isa Union{GaussAdjoint, QuadratureAdjoint} && isfunctor(p))) || (p isa AbstractArray && !Base.isconcretetype(eltype(p))) throw(AdjointSensitivityParameterCompatibilityError()) end diff --git a/src/derivative_wrappers.jl b/src/derivative_wrappers.jl index 3e1cd782e..36893a9fe 100644 --- a/src/derivative_wrappers.jl +++ b/src/derivative_wrappers.jl @@ -304,9 +304,19 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::Bool, dgrad, dy, pf.t = t pf.u = y if inplace_sensitivity(S) - jacobian!(pJ, pf, p, f_cache, sensealg, paramjac_config) + if isscimlstructure(p) + tunables, _, _ = canonicalize(Tunable(), p) + jacobian!(pJ, pf, tunables, f_cache, sensealg, paramjac_config) + else + jacobian!(pJ, pf, p, f_cache, sensealg, paramjac_config) + end else - temp = jacobian(pf, p, sensealg) + if isscimlstructure(p) + tunables, _, _ = canonicalize(Tunable(), p) + temp = jacobian(pf, tunables, sensealg) + else + temp = jacobian(pf, p, sensealg) + end pJ .= temp end end @@ -319,9 +329,19 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::Bool, dgrad, dy, pf.u = y pf.W = W if inplace_sensitivity(S) - jacobian!(pJ, pf, p, f_cache, sensealg, paramjac_config) + if isscimlstructure(p) + tunables, _, _ = canonicalize(Tunable(), p) + jacobian!(pJ, pf, tunables, f_cache, sensealg, paramjac_config) + else + jacobian!(pJ, pf, p, f_cache, sensealg, paramjac_config) + end else - temp = jacobian(pf, p, sensealg) + if isscimlstructure(p) + tunables, _, _ = canonicalize(Tunable(), p) + temp = jacobian(pf, tunables, sensealg) + else + temp = jacobian(pf, p, sensealg) + end pJ .= temp end end @@ -814,10 +834,20 @@ function _jacNoise!(λ, y, p, t, S::TS, isnoise::Bool, dgrad, dλ, pf.t = t pf.u = y if inplace_sensitivity(S) - jacobian!(pJ, pf, p, nothing, sensealg, nothing) + if isscimlstructure(p) + tunables, _, _ = canonicalize(Tunable(), p) + jacobian!(pJ, pf, tunables, nothing, sensealg, nothing) + else + jacobian!(pJ, pf, p, nothing, sensealg, nothing) + end #jacobian!(pJ, pf, p, f_cache, sensealg, paramjac_noise_config) else - temp = jacobian(pf, p, sensealg) + if isscimlstructure(p) + tunables, _, _ = canonicalize(Tunable(), p) + temp = jacobian(pf, tunables, sensealg) + else + temp = jacobian(pf, p, sensealg) + end pJ .= temp end end diff --git a/src/interpolating_adjoint.jl b/src/interpolating_adjoint.jl index 584eef65d..567b38c41 100644 --- a/src/interpolating_adjoint.jl +++ b/src/interpolating_adjoint.jl @@ -288,8 +288,10 @@ end if p === nothing || p isa SciMLBase.NullParameters tunables, repack = p, identity - else + elseif isscimlstructure(p) tunables, repack, _ = canonicalize(Tunable(), p) + else + throw(SciMLStructuresCompatibilityError()) end ## Force recompile mode until vjps are specialized to handle this!!! diff --git a/src/sensitivity_interface.jl b/src/sensitivity_interface.jl index 0e1f44aea..1faef7272 100644 --- a/src/sensitivity_interface.jl +++ b/src/sensitivity_interface.jl @@ -423,7 +423,8 @@ function _adjoint_sensitivities(sol, sensealg, alg; callback = nothing, kwargs...) mtkp = SymbolicIndexingInterface.parameter_values(sol) - if !(mtkp isa Union{Nothing, SciMLBase.NullParameters, AbstractArray}) || + if !((mtkp isa Union{Nothing, SciMLBase.NullParameters, AbstractArray}) || + isscimlstructure(mtkp) || isfunctor(mtkp)) || (mtkp isa AbstractArray && !Base.isconcretetype(eltype(mtkp))) throw(AdjointSensitivityParameterCompatibilityError()) end diff --git a/test/scimlstructures_interface.jl b/test/scimlstructures_interface.jl index 303fa7a9c..5b28148f3 100644 --- a/test/scimlstructures_interface.jl +++ b/test/scimlstructures_interface.jl @@ -1,6 +1,7 @@ # taken from https://github.com/SciML/SciMLStructures.jl/pull/28 using OrdinaryDiffEq, SciMLSensitivity, Zygote using LinearAlgebra +using Test import SciMLStructures as SS mutable struct SubproblemParameters{P, Q, R} @@ -87,6 +88,7 @@ import SciMLStructures as SS using Zygote using ADTypes using Test +using Tracker, ReverseDiff mutable struct myparam{M, P, S} model::M @@ -156,6 +158,24 @@ function run_diff(ps, sensealg) return sol.u |> last |> sum end +## Test all adjoints with SciMLStructures + +# Test basic functionality run_diff(initialize()) -@test !iszero(Zygote.gradient(run_diff, initialize(), GaussAdjoint())[1].ps) -@test !iszero(Zygote.gradient(run_diff, initialize(), GaussAdjoint(autojacvec = false))[1].ps) + +@testset "SciMLStructures Support for All Adjoints" begin + # Test GaussAdjoint (already working) + @test !iszero(Zygote.gradient(run_diff, initialize(), GaussAdjoint())[1].ps) + + # Test newly fixed BacksolveAdjoint and InterpolatingAdjoint - these are the main fixes in this PR + @test !iszero(Zygote.gradient(run_diff, initialize(), BacksolveAdjoint())[1].ps) + @test !iszero(Zygote.gradient(run_diff, initialize(), InterpolatingAdjoint())[1].ps) + + # Test QuadratureAdjoint (already working) + @test !iszero(Zygote.gradient(run_diff, initialize(), QuadratureAdjoint())[1].ps) + + # Test with different AD backends + @test !iszero(Zygote.gradient(run_diff, initialize(), ReverseDiffAdjoint())[1].ps) + @test !iszero(Zygote.gradient(run_diff, initialize(), TrackerAdjoint())[1].ps) + @test !iszero(Zygote.gradient(run_diff, initialize(), ZygoteAdjoint())[1].ps) +end