diff --git a/Project.toml b/Project.toml index 63674ffa3..84376470d 100644 --- a/Project.toml +++ b/Project.toml @@ -34,6 +34,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/src/SciMLSensitivity.jl b/src/SciMLSensitivity.jl index 60f61416f..cf2769680 100644 --- a/src/SciMLSensitivity.jl +++ b/src/SciMLSensitivity.jl @@ -14,6 +14,7 @@ import ZygoteRules, Zygote, ReverseDiff import ArrayInterfaceCore, ArrayInterfaceTracker import Enzyme import GPUArraysCore +using StaticArrays import PreallocationTools: dualcache, get_tmp, DiffCache @@ -24,7 +25,8 @@ using EllipsisNotation using Markdown using Reexport -import ChainRulesCore: unthunk, @thunk, NoTangent, @not_implemented +import ChainRulesCore: unthunk, @thunk, NoTangent, @not_implemented, Tangent, ProjectTo, + project_type, _eltype_projectto, rrule abstract type SensitivityFunction end abstract type TransformedFunction end @@ -45,6 +47,7 @@ include("concrete_solve.jl") include("second_order.jl") include("steadystate_adjoint.jl") include("sde_tools.jl") +include("staticarrays.jl") # AD Extensions include("reversediff.jl") diff --git a/src/adjoint_common.jl b/src/adjoint_common.jl index 6f4cde70b..9210895f0 100644 --- a/src/adjoint_common.jl +++ b/src/adjoint_common.jl @@ -401,7 +401,7 @@ inplace_sensitivity(S::SensitivityFunction) = isinplace(getprob(S)) struct ReverseLossCallback{λType, timeType, yType, RefType, FMType, AlgType, dg1Type, dg2Type, - cacheType} + cacheType, solType} isq::Bool λ::λType t::timeType @@ -413,6 +413,7 @@ struct ReverseLossCallback{λType, timeType, yType, RefType, FMType, AlgType, dg dgdu::dg1Type dgdp::dg2Type diffcache::cacheType + sol::solType end function ReverseLossCallback(sensefun, λ, t, dgdu, dgdp, cur_time) @@ -422,13 +423,17 @@ function ReverseLossCallback(sensefun, λ, t, dgdu, dgdp, cur_time) @unpack factorized_mass_matrix = sensefun.diffcache prob = getprob(sensefun) idx = length(prob.u0) - - return ReverseLossCallback(isq, λ, t, y, cur_time, idx, factorized_mass_matrix, - sensealg, dgdu, dgdp, sensefun.diffcache) + if ArrayInterfaceCore.ismutable(y) + return ReverseLossCallback(isq, λ, t, y, cur_time, idx, factorized_mass_matrix, + sensealg, dgdu, dgdp, sensefun.diffcache, nothing) + else + return ReverseLossCallback(isq, λ, t, y, cur_time, idx, factorized_mass_matrix, + sensealg, dgdu, dgdp, sensefun.diffcache, sensefun.sol) + end end function (f::ReverseLossCallback)(integrator) - @unpack isq, λ, t, y, cur_time, idx, F, sensealg, dgdu, dgdp = f + @unpack isq, λ, t, y, cur_time, idx, F, sensealg, dgdu, dgdp, sol = f @unpack diffvar_idxs, algevar_idxs, issemiexplicitdae, J, uf, f_cache, jac_config = f.diffcache p, u = integrator.p, integrator.u @@ -437,16 +442,23 @@ function (f::ReverseLossCallback)(integrator) copyto!(y, integrator.u[(end - idx + 1):end]) end - # Warning: alias here! Be careful with λ - gᵤ = isq ? λ : @view(λ[1:idx]) - if dgdu !== nothing - dgdu(gᵤ, y, p, t[cur_time[]], cur_time[]) - # add discrete dgdp contribution - if dgdp !== nothing && !isq - gp = @view(λ[(idx + 1):end]) - dgdp(gp, y, p, t[cur_time[]], cur_time[]) - u[(idx + 1):length(λ)] .+= gp + if ArrayInterfaceCore.ismutable(u) + # Warning: alias here! Be careful with λ + gᵤ = isq ? λ : @view(λ[1:idx]) + if dgdu !== nothing + dgdu(gᵤ, y, p, t[cur_time[]], cur_time[]) + # add discrete dgdp contribution + if dgdp !== nothing && !isq + gp = @view(λ[(idx + 1):end]) + dgdp(gp, y, p, t[cur_time[]], cur_time[]) + u[(idx + 1):length(λ)] .+= gp + end end + else + @assert sensealg isa QuadratureAdjoint + outtype = DiffEqBase.parameterless_type(λ) + y = sol(t[cur_time[]]) + gᵤ = dgdu(y, p, t[cur_time[]], cur_time[]; outtype = outtype) end if issemiexplicitdae @@ -468,7 +480,12 @@ function (f::ReverseLossCallback)(integrator) F !== I && F !== (I, I) && ldiv!(F, Δλd) end - u[diffvar_idxs] .+= Δλd + if ArrayInterfaceCore.ismutable(u) + u[diffvar_idxs] .+= Δλd + else + @assert sensealg isa QuadratureAdjoint + integrator.u += Δλd + end u_modified!(integrator, true) cur_time[] -= 1 return nothing diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 44faa84be..784f2902d 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -343,7 +343,7 @@ function DiffEqBase._concrete_solve_adjoint(prob::Union{SciMLBase.AbstractODEPro _save_idxs = save_idxs === nothing ? Colon() : save_idxs function adjoint_sensitivity_backpass(Δ) - function df(_out, u, p, t, i) + function df_iip(_out, u, p, t, i) outtype = typeof(_out) <: SubArray ? DiffEqBase.parameterless_type(_out.parent) : DiffEqBase.parameterless_type(_out) @@ -404,16 +404,82 @@ function DiffEqBase._concrete_solve_adjoint(prob::Union{SciMLBase.AbstractODEPro end end + function df_oop(u, p, t, i; outtype = nothing) + if only_end + eltype(Δ) <: NoTangent && return + if typeof(Δ) <: AbstractArray{<:AbstractArray} && length(Δ) == 1 && i == 1 + # user did sol[end] on only_end + if typeof(_save_idxs) <: Number + x = vec(Δ[1]) + _out = adapt(outtype, @view(x[_save_idxs])) + elseif _save_idxs isa Colon + _out = adapt(outtype, vec(Δ[1])) + else + _out = adapt(outtype, + vec(Δ[1])[_save_idxs]) + end + else + Δ isa NoTangent && return + if typeof(_save_idxs) <: Number + x = vec(Δ) + _out = adapt(outtype, @view(x[_save_idxs])) + elseif _save_idxs isa Colon + _out = adapt(outtype, vec(Δ)) + else + x = vec(Δ) + _out = adapt(outtype, @view(x[_save_idxs])) + end + end + else + !Base.isconcretetype(eltype(Δ)) && + (Δ[i] isa NoTangent || eltype(Δ) <: NoTangent) && return + if typeof(Δ) <: AbstractArray{<:AbstractArray} || typeof(Δ) <: DESolution + x = Δ[i] + if typeof(_save_idxs) <: Number + _out = @view(x[_save_idxs]) + elseif _save_idxs isa Colon + _out = vec(x) + else + _out = vec(@view(x[_save_idxs])) + end + else + if typeof(_save_idxs) <: Number + _out = adapt(outtype, + reshape(Δ, prod(size(Δ)[1:(end - 1)]), + size(Δ)[end])[_save_idxs, i]) + elseif _save_idxs isa Colon + _out = vec(adapt(outtype, + reshape(Δ, prod(size(Δ)[1:(end - 1)]), + size(Δ)[end])[:, i])) + else + _out = vec(adapt(outtype, + reshape(Δ, + prod(size(Δ)[1:(end - 1)]), + size(Δ)[end])[:, i])) + end + end + end + return _out + end + if haskey(kwargs_adj, :callback_adj) cb2 = CallbackSet(cb, kwargs[:callback_adj]) else cb2 = cb end - - du0, dp = adjoint_sensitivities(sol, alg, args...; t = ts, dgdu_discrete = df, - sensealg = sensealg, - callback = cb2, - kwargs_adj...) + if ArrayInterfaceCore.ismutable(eltype(sol.u)) + du0, dp = adjoint_sensitivities(sol, alg, args...; t = ts, + dgdu_discrete = df_iip, + sensealg = sensealg, + callback = cb2, + kwargs_adj...) + else + du0, dp = adjoint_sensitivities(sol, alg, args...; t = ts, + dgdu_discrete = df_oop, + sensealg = sensealg, + callback = cb2, + kwargs_adj...) + end du0 = reshape(du0, size(u0)) dp = p === nothing || p === DiffEqBase.NullParameters() ? nothing : diff --git a/src/derivative_wrappers.jl b/src/derivative_wrappers.jl index 4c8305e96..f75524555 100644 --- a/src/derivative_wrappers.jl +++ b/src/derivative_wrappers.jl @@ -222,6 +222,12 @@ function vecjacobian!(dλ, y, λ, p, t, S::TS; return end +function vecjacobian(y, λ, p, t, S::TS; + dgrad = nothing, dy = nothing, + W = nothing) where {TS <: SensitivityFunction} + return _vecjacobian(y, λ, p, t, S, S.sensealg.autojacvec, dgrad, dy, W) +end + function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::Bool, dgrad, dy, W) where {TS <: SensitivityFunction} @unpack sensealg, f = S @@ -588,6 +594,43 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::ZygoteVJP, dgrad, return end +function _vecjacobian(y, λ, p, t, S::TS, isautojacvec::ZygoteVJP, dgrad, dy, + W) where {TS <: SensitivityFunction} + @unpack sensealg, f = S + prob = getprob(S) + + isautojacvec = get_jacvec(sensealg) + + if W === nothing + _dy, back = Zygote.pullback(y, p) do u, p + vec(f(u, p, t)) + end + else + _dy, back = Zygote.pullback(y, p) do u, p + vec(f(u, p, t, W)) + end + end + + # Grab values from `_dy` before `back` in case mutated + dy !== nothing && (dy[:] .= vec(_dy)) + + tmp1, tmp2 = back(λ) + if tmp1 === nothing && !sensealg.autojacvec.allow_nothing + throw(ZygoteVJPNothingError()) + elseif tmp1 !== nothing + (dλ = vec(tmp1)) + end + + if dgrad !== nothing + if tmp2 === nothing && !sensealg.autojacvec.allow_nothing + throw(ZygoteVJPNothingError()) + elseif tmp2 !== nothing + (dgrad[:] .= vec(tmp2)) + end + end + return dy, dλ, dgrad +end + function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad, dy, W) where {TS <: SensitivityFunction} @unpack sensealg = S @@ -923,6 +966,19 @@ function accumulate_cost!(dλ, y, p, t, S::TS, return nothing end +function accumulate_cost(dλ, y, p, t, S::TS, + dgrad = nothing) where {TS <: SensitivityFunction} + @unpack dgdu, dgdp = S.diffcache + + dλ -= dgdu(y, p, t) + if dgdp !== nothing + if dgrad !== nothing + dgrad -= dgdp(y, p, t) + end + end + return dλ, dgrad +end + function build_jac_config(alg, uf, u) if alg_autodiff(alg) jac_config = ForwardDiff.JacobianConfig(uf, u, u, diff --git a/src/quadrature_adjoint.jl b/src/quadrature_adjoint.jl index 3707f6ed4..4391b8943 100644 --- a/src/quadrature_adjoint.jl +++ b/src/quadrature_adjoint.jl @@ -33,6 +33,21 @@ function (S::ODEQuadratureAdjointSensitivityFunction)(du, u, p, t) return nothing end +function (S::ODEQuadratureAdjointSensitivityFunction)(u, p, t) + @unpack sol, discrete = S + f = sol.prob.f + + λ, grad, y, dgrad, dy = split_states(u, t, S) + + dy, dλ, dgrad = vecjacobian(y, λ, p, t, S; dgrad = dgrad, dy = dy) + dλ *= (-one(eltype(λ))) + + if !discrete + dλ, dgrad = accumulate_cost(dλ, y, p, t, S, dgrad) + end + return dλ +end + function split_states(du, u, t, S::ODEQuadratureAdjointSensitivityFunction; update = true) @unpack y, sol = S @@ -50,6 +65,18 @@ function split_states(du, u, t, S::ODEQuadratureAdjointSensitivityFunction; upda λ, nothing, y, dλ, nothing, nothing end +function split_states(u, t, S::ODEQuadratureAdjointSensitivityFunction; update = true) + @unpack y, sol = S + + if update + y = sol(t, continuity = :right) + end + + λ = u + + λ, nothing, y, nothing, nothing +end + # g is either g(t,u,p) or discrete g(t,u,i) @noinline function ODEAdjointProblem(sol, sensealg::QuadratureAdjoint, alg, t = nothing, @@ -80,9 +107,13 @@ end (dgdu_continuous === nothing && dgdp_continuous === nothing || g !== nothing)) - len = length(u0) - λ = similar(u0, len) - λ .= false + if ArrayInterfaceCore.ismutable(u0) + len = length(u0) + λ = similar(u0, len) + λ .= false + else + λ = zero(u0) + end sense = ODEQuadratureAdjointSensitivityFunction(g, sensealg, discrete, sol, dgdu_continuous, dgdp_continuous, alg) @@ -103,7 +134,7 @@ end odefun = ODEFunction(sense, mass_matrix = sol.prob.f.mass_matrix', jac_prototype = adjoint_jac_prototype) end - return ODEProblem(odefun, z0, tspan, p, callback = cb) + return ODEProblem{ArrayInterfaceCore.ismutable(z0)}(odefun, z0, tspan, p, callback = cb) end struct AdjointSensitivityIntegrand{pType, uType, lType, rateType, S, AS, PF, PJC, PJT, DGP, @@ -130,7 +161,6 @@ function AdjointSensitivityIntegrand(sol, adj_sol, sensealg, dgdp = nothing) λ = zero(adj_sol.prob.u0) # we need to alias `y` f_cache = zero(y) - f_cache .= false isautojacvec = get_jacvec(sensealg) dgdp_cache = dgdp === nothing ? nothing : zero(p) @@ -198,8 +228,13 @@ end function (S::AdjointSensitivityIntegrand)(out, t) @unpack y, λ, pJ, pf, p, f_cache, dgdp_cache, paramjac_config, sensealg, sol, adj_sol = S f = sol.prob.f - sol(y, t) - adj_sol(λ, t) + if ArrayInterfaceCore.ismutable(y) + sol(y, t) + adj_sol(λ, t) + else + y = sol(t) + λ = adj_sol(t) + end isautojacvec = get_jacvec(sensealg) # y is aliased @@ -310,8 +345,13 @@ function _adjoint_sensitivities(sol, sensealg::QuadratureAdjoint, alg; t = nothi end for i in (length(t) - 1):-1:1 - res .+= quadgk(integrand, t[i], t[i + 1], - atol = abstol, rtol = reltol)[1] + if ArrayInterfaceCore.ismutable(res) + res .+= quadgk(integrand, t[i], t[i + 1], + atol = abstol, rtol = reltol)[1] + else + res += quadgk(integrand, t[i], t[i + 1], + atol = abstol, rtol = reltol)[1] + end if t[i] == t[i + 1] integrand = update_integrand_and_dgrad(res, sensealg, callback, integrand, diff --git a/src/staticarrays.jl b/src/staticarrays.jl new file mode 100644 index 000000000..11cc1aaad --- /dev/null +++ b/src/staticarrays.jl @@ -0,0 +1,23 @@ +### Projecting a tuple to SMatrix leads to ChainRulesCore._projection_mismatch by default, so overloaded here +function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::StaticArrays.SArray) + dy = reshape(dx, axes(project.elements)) # allows for dx::OffsetArray + dz = ntuple(i -> project.elements[i](dy[i]), length(project.elements)) + return project_type(project)(dz...) +end + +### Project SArray to SArray +function ProjectTo(x::StaticArrays.SArray{S, T}) where {S, T} + return ProjectTo{StaticArrays.SArray}(; element = _eltype_projectto(T), axes = S) +end + +function (project::ProjectTo{StaticArrays.SArray})(dx::AbstractArray{S, M}) where {S, M} + return StaticArrays.SArray{project.axes}(dx) +end + +### Adjoint for SArray constructor + +function rrule(::Type{T}, x::Tuple) where {T <: StaticArrays.SArray} + project_x = ProjectTo(x) + Array_pullback(ȳ) = (NoTangent(), project_x(ȳ)) + return T(x), Array_pullback +end diff --git a/test/adjoint_oop.jl b/test/adjoint_oop.jl new file mode 100644 index 000000000..c8e5883e6 --- /dev/null +++ b/test/adjoint_oop.jl @@ -0,0 +1,202 @@ +using SciMLSensitivity, OrdinaryDiffEq, StaticArrays, QuadGK, ForwardDiff, + Zygote +using Test + +##StaticArrays rrule +u0 = @SVector rand(2) +p = @SVector rand(4) + +function lotka(u, p, svec = true) + du1 = p[1] * u[1] - p[2] * u[1] * u[2] + du2 = -p[3] * u[2] + p[4] * u[1] * u[2] + if svec + @SVector [du1, du2] + else + @SMatrix [du1 du2 du1; du2 du1 du1] + end +end + +#SVector constructor adjoint +function loss(p) + u = lotka(u0, p) + sum(1 .- u) +end + +grad = Zygote.gradient(loss, p) +@test typeof(grad[1]) <: SArray +grad2 = ForwardDiff.gradient(loss, p) +@test grad[1]≈grad2 rtol=1e-12 + +#SMatrix constructor adjoint +function loss_mat(p) + u = lotka(u0, p, false) + sum(1 .- u) +end + +grad = Zygote.gradient(loss_mat, p) +@test typeof(grad[1]) <: SArray +grad2 = ForwardDiff.gradient(loss_mat, p) +@test grad[1]≈grad2 rtol=1e-12 + +##Adjoints of StaticArrays ODE + +u0 = @SVector [1.0, 1.0] +p = @SVector [1.5, 1.0, 3.0, 1.0] +tspan = (0.0, 5.0) +datasize = 15 +tsteps = range(tspan[1], tspan[2], length = datasize) + +function lotka(u, p, t) + du1 = p[1] * u[1] - p[2] * u[1] * u[2] + du2 = -p[3] * u[2] + p[4] * u[1] * u[2] + @SVector [du1, du2] +end + +prob = ODEProblem(lotka, u0, tspan, p) +sol = solve(prob, Tsit5(), saveat = tsteps, abstol = 1e-14, reltol = 1e-14) + +## Discrete Case +dg_disc(u, p, t, i; outtype = nothing) = u + +du0, dp = adjoint_sensitivities(sol, Tsit5(); t = tsteps, dgdu_discrete = dg_disc, + sensealg = QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, + autojacvec = ZygoteVJP())) + +@test !iszero(du0) +@test !iszero(dp) +# +adj_prob = ODEAdjointProblem(sol, + QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, + autojacvec = SciMLSensitivity.ZygoteVJP()), + Tsit5(), tsteps, dg_disc) +adj_sol = solve(adj_prob, Tsit5(), abstol = 1e-14, reltol = 1e-14) +integrand = AdjointSensitivityIntegrand(sol, adj_sol, + QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, + autojacvec = SciMLSensitivity.ZygoteVJP())) +res, err = quadgk(integrand, 0.0, 5.0, atol = 1e-14, rtol = 1e-14) + +@test adj_sol[end]≈du0 rtol=1e-12 +@test res≈dp rtol=1e-12 + +###Comparing with gradients of lotka volterra with normal arrays +u2 = [1.0, 1.0] +p2 = [1.5, 1.0, 3.0, 1.0] + +function f(u, p, t) + du1 = p[1] * u[1] - p[2] * u[1] * u[2] + du2 = -p[3] * u[2] + p[4] * u[1] * u[2] + [du1, du2] +end + +prob2 = ODEProblem(f, u2, tspan, p2) +sol2 = solve(prob, Tsit5(), saveat = tsteps, abstol = 1e-14, reltol = 1e-14) + +function dg_disc(du, u, p, t, i) + du .= u +end + +du1, dp1 = adjoint_sensitivities(sol, Tsit5(); t = tsteps, dgdu_discrete = dg_disc, + sensealg = QuadratureAdjoint(abstol = 1e-14, + reltol = 1e-14, + autojacvec = ZygoteVJP())) + +@test du0≈du1 rtol=1e-12 +@test dp≈dp1 rtol=1e-12 + +## with ForwardDiff and Zygote + +function G_p(p) + tmp_prob = remake(prob, u0 = convert.(eltype(p), prob.u0), p = p) + sol = solve(tmp_prob, Tsit5(), abstol = 1e-14, reltol = 1e-14, + sensealg = QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, + autojacvec = ZygoteVJP()), saveat = tsteps) + u = Array(sol) + return sum(((1 .- u) .^ 2) ./ 2) +end + +function G_u(u0) + tmp_prob = remake(prob, u0 = u0, p = prob.p) + sol = solve(tmp_prob, Tsit5(), saveat = tsteps, + sensealg = QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, + autojacvec = ZygoteVJP()), abstol = 1e-14, + reltol = 1e-14) + u = Array(sol) + + return sum(((1 .- u) .^ 2) ./ 2) +end + +G_p(p) +G_u(u0) +f_dp = ForwardDiff.gradient(G_p, p) +f_du0 = ForwardDiff.gradient(G_u, u0) + +z_dp = Zygote.gradient(G_p, p) +z_du0 = Zygote.gradient(G_u, u0) + +@test z_du0[1]≈f_du0 rtol=1e-12 +@test z_dp[1]≈f_dp rtol=1e-12 + +## Continuous Case + +g(u, p, t) = sum((u .^ 2) ./ 2) + +function dg(u, p, t) + u +end + +du0, dp = adjoint_sensitivities(sol, Tsit5(); dgdu_continuous = dg, g = g, + sensealg = QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, + autojacvec = ZygoteVJP())) + +@test !iszero(du0) +@test !iszero(dp) + +adj_prob = ODEAdjointProblem(sol, + QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, + autojacvec = SciMLSensitivity.ZygoteVJP()), + Tsit5(), nothing, nothing, nothing, dg, nothing, g) +adj_sol = solve(adj_prob, Tsit5(), abstol = 1e-14, reltol = 1e-14) +integrand = AdjointSensitivityIntegrand(sol, adj_sol, + QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, + autojacvec = SciMLSensitivity.ZygoteVJP())) +res, err = quadgk(integrand, 0.0, 5.0, atol = 1e-14, rtol = 1e-14) + +@test adj_sol[end]≈du0 rtol=1e-12 +@test res≈dp rtol=1e-12 + +##ForwardDiff + +function G_p(p) + tmp_prob = remake(prob, p = p) + sol = solve(tmp_prob, Tsit5(), abstol = 1e-12, reltol = 1e-12) + res, err = quadgk((t) -> (sum((sol(t) .^ 2) ./ 2)), 0.0, 5.0, atol = 1e-12, + rtol = 1e-12) + res +end + +function G_u(u0) + tmp_prob = remake(prob, u0 = u0) + sol = solve(tmp_prob, Tsit5(), abstol = 1e-12, reltol = 1e-12) + res, err = quadgk((t) -> (sum((sol(t) .^ 2) ./ 2)), 0.0, 5.0, atol = 1e-12, + rtol = 1e-12) + res +end + +f_du0 = ForwardDiff.gradient(G_u, u0) +f_dp = ForwardDiff.gradient(G_p, p) + +@test !iszero(f_du0) +@test !iszero(f_dp) + +## concrete solve + +du0, dp = Zygote.gradient((u0, p) -> sum(concrete_solve(prob, Tsit5(), u0, p, + abstol = 1e-10, reltol = 1e-10, + saveat = tsteps, + sensealg = QuadratureAdjoint(abstol = 1e-14, + reltol = 1e-14, + autojacvec = ZygoteVJP()))), + u0, p) + +@test !iszero(du0) +@test !iszero(dp) diff --git a/test/runtests.jl b/test/runtests.jl index 56227fb5b..ee77951da 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -45,6 +45,7 @@ end @time @safetestset "Adjoint Sensitivity" begin include("adjoint.jl") end @time @safetestset "Continuous adjoint params" begin include("adjoint_param.jl") end @time @safetestset "Continuous and discrete costs" begin include("mixed_costs.jl") end + @time @safetestset "Fully Out of Place adjoint sensitivity" begin include("adjoint_oop.jl") end end if GROUP == "All" || GROUP == "Core4" diff --git a/test/stiff_adjoints.jl b/test/stiff_adjoints.jl index 174bd903f..a866cdc53 100644 --- a/test/stiff_adjoints.jl +++ b/test/stiff_adjoints.jl @@ -175,7 +175,7 @@ if VERSION >= v"1.7-" ROCK4(), RKC(), # SERK2v2(), not defined? - ESERK5()]; + ESERK5()] p = rand(3) @@ -214,45 +214,48 @@ if VERSION >= v"1.7-" dp1 = @test_broken Zygote.gradient(p -> loss(p, ReverseDiffAdjoint()), p)[1] @test_broken dp≈dp1 rtol=1e-2 end -end -using SciMLSensitivity, OrdinaryDiffEq, ForwardDiff, Zygote, Test + # using SciMLSensitivity, OrdinaryDiffEq, ForwardDiff, Zygote, Test -function rober(du, u, p, t) - y₁, y₂, y₃ = u - k₁, k₂, k₃ = p[1], p[2], p[3] - du[1] = -k₁ * y₁ + k₃ * y₂ * y₃ - du[2] = k₁ * y₁ - k₂ * y₂^2 - k₃ * y₂ * y₃ - du[3] = k₂ * y₂^2 + sum(p) - nothing -end + function rober(du, u, p, t) + y₁, y₂, y₃ = u + k₁, k₂, k₃ = p[1], p[2], p[3] + du[1] = -k₁ * y₁ + k₃ * y₂ * y₃ + du[2] = k₁ * y₁ - k₂ * y₂^2 - k₃ * y₂ * y₃ + du[3] = k₂ * y₂^2 + sum(p) + nothing + end -function sum_of_solution_fwd(x) - _prob = ODEProblem(rober, x[1:3], (0.0, 1e4), x[4:end]) - sum(solve(_prob, Rodas5(), saveat = 1, reltol = 1e-12, abstol = 1e-12)) -end + function sum_of_solution_fwd(x) + _prob = ODEProblem(rober, x[1:3], (0.0, 1e4), x[4:end]) + sum(solve(_prob, Rodas5(), saveat = 1, reltol = 1e-12, abstol = 1e-12)) + end -function sum_of_solution_CASA(x; vjp = EnzymeVJP()) - sensealg = QuadratureAdjoint(autodiff = false, autojacvec = vjp) - _prob = ODEProblem(rober, x[1:3], (0.0, 1e4), x[4:end]) - sum(solve(_prob, Rodas5(), reltol = 1e-8, abstol = 1e-8, saveat = 1, - sensealg = sensealg)) -end + function sum_of_solution_CASA(x; vjp = EnzymeVJP()) + sensealg = QuadratureAdjoint(autodiff = false, autojacvec = vjp) + _prob = ODEProblem(rober, x[1:3], (0.0, 1e4), x[4:end]) + sum(solve(_prob, Rodas5(), reltol = 1e-8, abstol = 1e-8, saveat = 1, + sensealg = sensealg)) + end -u0 = [1.0, 0.0, 0.0] -p = ones(8) # change me, the number of parameters - -grad1 = ForwardDiff.gradient(sum_of_solution_fwd, [u0; p]) -grad2 = Zygote.gradient(sum_of_solution_CASA, [u0; p])[1] -grad3 = Zygote.gradient(x -> sum_of_solution_CASA(x, vjp = ReverseDiffVJP()), [u0; p])[1] -grad4 = Zygote.gradient(x -> sum_of_solution_CASA(x, vjp = ReverseDiffVJP(true)), [u0; p])[1] -@test_throws Any Zygote.gradient(x -> sum_of_solution_CASA(x, vjp = true), [u0; p])[1] -grad6 = Zygote.gradient(x -> sum_of_solution_CASA(x, vjp = false), [u0; p])[1] -@test_throws Any Zygote.gradient(x -> sum_of_solution_CASA(x, vjp = ZygoteVJP()), [u0; p])[1] -@test_throws Any Zygote.gradient(x -> sum_of_solution_CASA(x, vjp = TrackerVJP()), [u0; p])[1] - -@test grad1 ≈ grad2 -@test grad1 ≈ grad3 -@test grad1 ≈ grad4 -#@test grad1 ≈ grad5 -@test grad1 ≈ grad6 + u0 = [1.0, 0.0, 0.0] + p = ones(8) # change me, the number of parameters + + grad1 = ForwardDiff.gradient(sum_of_solution_fwd, [u0; p]) + grad2 = Zygote.gradient(sum_of_solution_CASA, [u0; p])[1] + grad3 = Zygote.gradient(x -> sum_of_solution_CASA(x, vjp = ReverseDiffVJP()), [u0; p])[1] + grad4 = Zygote.gradient(x -> sum_of_solution_CASA(x, vjp = ReverseDiffVJP(true)), + [u0; p])[1] + @test_throws Any Zygote.gradient(x -> sum_of_solution_CASA(x, vjp = true), [u0; p])[1] + grad6 = Zygote.gradient(x -> sum_of_solution_CASA(x, vjp = false), [u0; p])[1] + @test_throws Any Zygote.gradient(x -> sum_of_solution_CASA(x, vjp = ZygoteVJP()), + [u0; p])[1] + @test_throws Any Zygote.gradient(x -> sum_of_solution_CASA(x, vjp = TrackerVJP()), + [u0; p])[1] + + @test grad1 ≈ grad2 + @test grad1 ≈ grad3 + @test grad1 ≈ grad4 + #@test grad1 ≈ grad5 + @test grad1 ≈ grad6 +end