@@ -401,7 +401,7 @@ inplace_sensitivity(S::SensitivityFunction) = isinplace(getprob(S))
401
401
402
402
struct ReverseLossCallback{λType, timeType, yType, RefType, FMType, AlgType, dg1Type,
403
403
dg2Type,
404
- cacheType}
404
+ cacheType, solType }
405
405
isq:: Bool
406
406
λ:: λType
407
407
t:: timeType
@@ -413,6 +413,7 @@ struct ReverseLossCallback{λType, timeType, yType, RefType, FMType, AlgType, dg
413
413
dgdu:: dg1Type
414
414
dgdp:: dg2Type
415
415
diffcache:: cacheType
416
+ sol:: solType
416
417
end
417
418
418
419
function ReverseLossCallback (sensefun, λ, t, dgdu, dgdp, cur_time)
@@ -422,13 +423,17 @@ function ReverseLossCallback(sensefun, λ, t, dgdu, dgdp, cur_time)
422
423
@unpack factorized_mass_matrix = sensefun. diffcache
423
424
prob = getprob (sensefun)
424
425
idx = length (prob. u0)
425
-
426
- return ReverseLossCallback (isq, λ, t, y, cur_time, idx, factorized_mass_matrix,
427
- sensealg, dgdu, dgdp, sensefun. diffcache)
426
+ if ArrayInterfaceCore. ismutable (y)
427
+ return ReverseLossCallback (isq, λ, t, y, cur_time, idx, factorized_mass_matrix,
428
+ sensealg, dgdu, dgdp, sensefun. diffcache, nothing )
429
+ else
430
+ return ReverseLossCallback (isq, λ, t, y, cur_time, idx, factorized_mass_matrix,
431
+ sensealg, dgdu, dgdp, sensefun. diffcache, sensefun. sol)
432
+ end
428
433
end
429
434
430
435
function (f:: ReverseLossCallback )(integrator)
431
- @unpack isq, λ, t, y, cur_time, idx, F, sensealg, dgdu, dgdp = f
436
+ @unpack isq, λ, t, y, cur_time, idx, F, sensealg, dgdu, dgdp, sol = f
432
437
@unpack diffvar_idxs, algevar_idxs, issemiexplicitdae, J, uf, f_cache, jac_config = f. diffcache
433
438
434
439
p, u = integrator. p, integrator. u
@@ -437,16 +442,23 @@ function (f::ReverseLossCallback)(integrator)
437
442
copyto! (y, integrator. u[(end - idx + 1 ): end ])
438
443
end
439
444
440
- # Warning: alias here! Be careful with λ
441
- gᵤ = isq ? λ : @view (λ[1 : idx])
442
- if dgdu != = nothing
443
- dgdu (gᵤ, y, p, t[cur_time[]], cur_time[])
444
- # add discrete dgdp contribution
445
- if dgdp != = nothing && ! isq
446
- gp = @view (λ[(idx + 1 ): end ])
447
- dgdp (gp, y, p, t[cur_time[]], cur_time[])
448
- u[(idx + 1 ): length (λ)] .+ = gp
445
+ if ArrayInterfaceCore. ismutable (u)
446
+ # Warning: alias here! Be careful with λ
447
+ gᵤ = isq ? λ : @view (λ[1 : idx])
448
+ if dgdu != = nothing
449
+ dgdu (gᵤ, y, p, t[cur_time[]], cur_time[])
450
+ # add discrete dgdp contribution
451
+ if dgdp != = nothing && ! isq
452
+ gp = @view (λ[(idx + 1 ): end ])
453
+ dgdp (gp, y, p, t[cur_time[]], cur_time[])
454
+ u[(idx + 1 ): length (λ)] .+ = gp
455
+ end
449
456
end
457
+ else
458
+ @assert sensealg isa QuadratureAdjoint
459
+ outtype = DiffEqBase. parameterless_type (λ)
460
+ y = sol (t[cur_time[]])
461
+ gᵤ = dgdu (y, p, t[cur_time[]], cur_time[]; outtype = outtype)
450
462
end
451
463
452
464
if issemiexplicitdae
@@ -468,7 +480,12 @@ function (f::ReverseLossCallback)(integrator)
468
480
F != = I && F != = (I, I) && ldiv! (F, Δλd)
469
481
end
470
482
471
- u[diffvar_idxs] .+ = Δλd
483
+ if ArrayInterfaceCore. ismutable (u)
484
+ u[diffvar_idxs] .+ = Δλd
485
+ else
486
+ @assert sensealg isa QuadratureAdjoint
487
+ integrator. u += Δλd
488
+ end
472
489
u_modified! (integrator, true )
473
490
cur_time[] -= 1
474
491
return nothing
0 commit comments