Skip to content

Commit b4fd1ce

Browse files
Merge pull request #680 from Abhishek-1Bhatt/sarray
Out of place QuadratureAdjoint for Working with StaticArrays
2 parents 885552a + 7340ca4 commit b4fd1ce

10 files changed

+481
-69
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
3434
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
3535
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
3636
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
37+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
3738
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
3839
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
3940
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

src/SciMLSensitivity.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import ZygoteRules, Zygote, ReverseDiff
1414
import ArrayInterfaceCore, ArrayInterfaceTracker
1515
import Enzyme
1616
import GPUArraysCore
17+
using StaticArrays
1718

1819
import PreallocationTools: dualcache, get_tmp, DiffCache
1920

@@ -24,7 +25,8 @@ using EllipsisNotation
2425
using Markdown
2526

2627
using Reexport
27-
import ChainRulesCore: unthunk, @thunk, NoTangent, @not_implemented
28+
import ChainRulesCore: unthunk, @thunk, NoTangent, @not_implemented, Tangent, ProjectTo,
29+
project_type, _eltype_projectto, rrule
2830
abstract type SensitivityFunction end
2931
abstract type TransformedFunction end
3032

@@ -45,6 +47,7 @@ include("concrete_solve.jl")
4547
include("second_order.jl")
4648
include("steadystate_adjoint.jl")
4749
include("sde_tools.jl")
50+
include("staticarrays.jl")
4851

4952
# AD Extensions
5053
include("reversediff.jl")

src/adjoint_common.jl

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,7 @@ inplace_sensitivity(S::SensitivityFunction) = isinplace(getprob(S))
401401

402402
struct ReverseLossCallback{λType, timeType, yType, RefType, FMType, AlgType, dg1Type,
403403
dg2Type,
404-
cacheType}
404+
cacheType, solType}
405405
isq::Bool
406406
λ::λType
407407
t::timeType
@@ -413,6 +413,7 @@ struct ReverseLossCallback{λType, timeType, yType, RefType, FMType, AlgType, dg
413413
dgdu::dg1Type
414414
dgdp::dg2Type
415415
diffcache::cacheType
416+
sol::solType
416417
end
417418

418419
function ReverseLossCallback(sensefun, λ, t, dgdu, dgdp, cur_time)
@@ -422,13 +423,17 @@ function ReverseLossCallback(sensefun, λ, t, dgdu, dgdp, cur_time)
422423
@unpack factorized_mass_matrix = sensefun.diffcache
423424
prob = getprob(sensefun)
424425
idx = length(prob.u0)
425-
426-
return ReverseLossCallback(isq, λ, t, y, cur_time, idx, factorized_mass_matrix,
427-
sensealg, dgdu, dgdp, sensefun.diffcache)
426+
if ArrayInterfaceCore.ismutable(y)
427+
return ReverseLossCallback(isq, λ, t, y, cur_time, idx, factorized_mass_matrix,
428+
sensealg, dgdu, dgdp, sensefun.diffcache, nothing)
429+
else
430+
return ReverseLossCallback(isq, λ, t, y, cur_time, idx, factorized_mass_matrix,
431+
sensealg, dgdu, dgdp, sensefun.diffcache, sensefun.sol)
432+
end
428433
end
429434

430435
function (f::ReverseLossCallback)(integrator)
431-
@unpack isq, λ, t, y, cur_time, idx, F, sensealg, dgdu, dgdp = f
436+
@unpack isq, λ, t, y, cur_time, idx, F, sensealg, dgdu, dgdp, sol = f
432437
@unpack diffvar_idxs, algevar_idxs, issemiexplicitdae, J, uf, f_cache, jac_config = f.diffcache
433438

434439
p, u = integrator.p, integrator.u
@@ -437,16 +442,23 @@ function (f::ReverseLossCallback)(integrator)
437442
copyto!(y, integrator.u[(end - idx + 1):end])
438443
end
439444

440-
# Warning: alias here! Be careful with λ
441-
gᵤ = isq ? λ : @view(λ[1:idx])
442-
if dgdu !== nothing
443-
dgdu(gᵤ, y, p, t[cur_time[]], cur_time[])
444-
# add discrete dgdp contribution
445-
if dgdp !== nothing && !isq
446-
gp = @view(λ[(idx + 1):end])
447-
dgdp(gp, y, p, t[cur_time[]], cur_time[])
448-
u[(idx + 1):length(λ)] .+= gp
445+
if ArrayInterfaceCore.ismutable(u)
446+
# Warning: alias here! Be careful with λ
447+
gᵤ = isq ? λ : @view(λ[1:idx])
448+
if dgdu !== nothing
449+
dgdu(gᵤ, y, p, t[cur_time[]], cur_time[])
450+
# add discrete dgdp contribution
451+
if dgdp !== nothing && !isq
452+
gp = @view(λ[(idx + 1):end])
453+
dgdp(gp, y, p, t[cur_time[]], cur_time[])
454+
u[(idx + 1):length(λ)] .+= gp
455+
end
449456
end
457+
else
458+
@assert sensealg isa QuadratureAdjoint
459+
outtype = DiffEqBase.parameterless_type(λ)
460+
y = sol(t[cur_time[]])
461+
gᵤ = dgdu(y, p, t[cur_time[]], cur_time[]; outtype = outtype)
450462
end
451463

452464
if issemiexplicitdae
@@ -468,7 +480,12 @@ function (f::ReverseLossCallback)(integrator)
468480
F !== I && F !== (I, I) && ldiv!(F, Δλd)
469481
end
470482

471-
u[diffvar_idxs] .+= Δλd
483+
if ArrayInterfaceCore.ismutable(u)
484+
u[diffvar_idxs] .+= Δλd
485+
else
486+
@assert sensealg isa QuadratureAdjoint
487+
integrator.u += Δλd
488+
end
472489
u_modified!(integrator, true)
473490
cur_time[] -= 1
474491
return nothing

src/concrete_solve.jl

Lines changed: 72 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ function DiffEqBase._concrete_solve_adjoint(prob::Union{SciMLBase.AbstractODEPro
343343
_save_idxs = save_idxs === nothing ? Colon() : save_idxs
344344

345345
function adjoint_sensitivity_backpass(Δ)
346-
function df(_out, u, p, t, i)
346+
function df_iip(_out, u, p, t, i)
347347
outtype = typeof(_out) <: SubArray ?
348348
DiffEqBase.parameterless_type(_out.parent) :
349349
DiffEqBase.parameterless_type(_out)
@@ -404,16 +404,82 @@ function DiffEqBase._concrete_solve_adjoint(prob::Union{SciMLBase.AbstractODEPro
404404
end
405405
end
406406

407+
function df_oop(u, p, t, i; outtype = nothing)
408+
if only_end
409+
eltype(Δ) <: NoTangent && return
410+
if typeof(Δ) <: AbstractArray{<:AbstractArray} && length(Δ) == 1 && i == 1
411+
# user did sol[end] on only_end
412+
if typeof(_save_idxs) <: Number
413+
x = vec(Δ[1])
414+
_out = adapt(outtype, @view(x[_save_idxs]))
415+
elseif _save_idxs isa Colon
416+
_out = adapt(outtype, vec(Δ[1]))
417+
else
418+
_out = adapt(outtype,
419+
vec(Δ[1])[_save_idxs])
420+
end
421+
else
422+
Δ isa NoTangent && return
423+
if typeof(_save_idxs) <: Number
424+
x = vec(Δ)
425+
_out = adapt(outtype, @view(x[_save_idxs]))
426+
elseif _save_idxs isa Colon
427+
_out = adapt(outtype, vec(Δ))
428+
else
429+
x = vec(Δ)
430+
_out = adapt(outtype, @view(x[_save_idxs]))
431+
end
432+
end
433+
else
434+
!Base.isconcretetype(eltype(Δ)) &&
435+
(Δ[i] isa NoTangent || eltype(Δ) <: NoTangent) && return
436+
if typeof(Δ) <: AbstractArray{<:AbstractArray} || typeof(Δ) <: DESolution
437+
x = Δ[i]
438+
if typeof(_save_idxs) <: Number
439+
_out = @view(x[_save_idxs])
440+
elseif _save_idxs isa Colon
441+
_out = vec(x)
442+
else
443+
_out = vec(@view(x[_save_idxs]))
444+
end
445+
else
446+
if typeof(_save_idxs) <: Number
447+
_out = adapt(outtype,
448+
reshape(Δ, prod(size(Δ)[1:(end - 1)]),
449+
size(Δ)[end])[_save_idxs, i])
450+
elseif _save_idxs isa Colon
451+
_out = vec(adapt(outtype,
452+
reshape(Δ, prod(size(Δ)[1:(end - 1)]),
453+
size(Δ)[end])[:, i]))
454+
else
455+
_out = vec(adapt(outtype,
456+
reshape(Δ,
457+
prod(size(Δ)[1:(end - 1)]),
458+
size(Δ)[end])[:, i]))
459+
end
460+
end
461+
end
462+
return _out
463+
end
464+
407465
if haskey(kwargs_adj, :callback_adj)
408466
cb2 = CallbackSet(cb, kwargs[:callback_adj])
409467
else
410468
cb2 = cb
411469
end
412-
413-
du0, dp = adjoint_sensitivities(sol, alg, args...; t = ts, dgdu_discrete = df,
414-
sensealg = sensealg,
415-
callback = cb2,
416-
kwargs_adj...)
470+
if ArrayInterfaceCore.ismutable(eltype(sol.u))
471+
du0, dp = adjoint_sensitivities(sol, alg, args...; t = ts,
472+
dgdu_discrete = df_iip,
473+
sensealg = sensealg,
474+
callback = cb2,
475+
kwargs_adj...)
476+
else
477+
du0, dp = adjoint_sensitivities(sol, alg, args...; t = ts,
478+
dgdu_discrete = df_oop,
479+
sensealg = sensealg,
480+
callback = cb2,
481+
kwargs_adj...)
482+
end
417483

418484
du0 = reshape(du0, size(u0))
419485
dp = p === nothing || p === DiffEqBase.NullParameters() ? nothing :

src/derivative_wrappers.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,12 @@ function vecjacobian!(dλ, y, λ, p, t, S::TS;
222222
return
223223
end
224224

225+
function vecjacobian(y, λ, p, t, S::TS;
226+
dgrad = nothing, dy = nothing,
227+
W = nothing) where {TS <: SensitivityFunction}
228+
return _vecjacobian(y, λ, p, t, S, S.sensealg.autojacvec, dgrad, dy, W)
229+
end
230+
225231
function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::Bool, dgrad, dy,
226232
W) where {TS <: SensitivityFunction}
227233
@unpack sensealg, f = S
@@ -588,6 +594,43 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::ZygoteVJP, dgrad,
588594
return
589595
end
590596

597+
function _vecjacobian(y, λ, p, t, S::TS, isautojacvec::ZygoteVJP, dgrad, dy,
598+
W) where {TS <: SensitivityFunction}
599+
@unpack sensealg, f = S
600+
prob = getprob(S)
601+
602+
isautojacvec = get_jacvec(sensealg)
603+
604+
if W === nothing
605+
_dy, back = Zygote.pullback(y, p) do u, p
606+
vec(f(u, p, t))
607+
end
608+
else
609+
_dy, back = Zygote.pullback(y, p) do u, p
610+
vec(f(u, p, t, W))
611+
end
612+
end
613+
614+
# Grab values from `_dy` before `back` in case mutated
615+
dy !== nothing && (dy[:] .= vec(_dy))
616+
617+
tmp1, tmp2 = back(λ)
618+
if tmp1 === nothing && !sensealg.autojacvec.allow_nothing
619+
throw(ZygoteVJPNothingError())
620+
elseif tmp1 !== nothing
621+
(dλ = vec(tmp1))
622+
end
623+
624+
if dgrad !== nothing
625+
if tmp2 === nothing && !sensealg.autojacvec.allow_nothing
626+
throw(ZygoteVJPNothingError())
627+
elseif tmp2 !== nothing
628+
(dgrad[:] .= vec(tmp2))
629+
end
630+
end
631+
return dy, dλ, dgrad
632+
end
633+
591634
function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad, dy,
592635
W) where {TS <: SensitivityFunction}
593636
@unpack sensealg = S
@@ -923,6 +966,19 @@ function accumulate_cost!(dλ, y, p, t, S::TS,
923966
return nothing
924967
end
925968

969+
function accumulate_cost(dλ, y, p, t, S::TS,
970+
dgrad = nothing) where {TS <: SensitivityFunction}
971+
@unpack dgdu, dgdp = S.diffcache
972+
973+
-= dgdu(y, p, t)
974+
if dgdp !== nothing
975+
if dgrad !== nothing
976+
dgrad -= dgdp(y, p, t)
977+
end
978+
end
979+
return dλ, dgrad
980+
end
981+
926982
function build_jac_config(alg, uf, u)
927983
if alg_autodiff(alg)
928984
jac_config = ForwardDiff.JacobianConfig(uf, u, u,

0 commit comments

Comments
 (0)