-
-
Notifications
You must be signed in to change notification settings - Fork 163
Closed
JuliaDiff/ChainRules.jl
#758Description
DiffEqFlux.jl/src/DiffEqFlux.jl
Lines 24 to 40 in e32422d
# ForwardDiff integration | |
ZygoteRules.@adjoint function ForwardDiff.Dual{T}(x, ẋ::Tuple) where T | |
@assert length(ẋ) == 1 | |
ForwardDiff.Dual{T}(x, ẋ), ḋ -> (ḋ.partials[1], (ḋ.value,)) | |
end | |
ZygoteRules.@adjoint ZygoteRules.literal_getproperty(d::ForwardDiff.Dual{T}, ::Val{:partials}) where T = | |
d.partials, ṗ -> (ForwardDiff.Dual{T}(ṗ[1], 0),) | |
ZygoteRules.@adjoint ZygoteRules.literal_getproperty(d::ForwardDiff.Dual{T}, ::Val{:value}) where T = | |
d.value, ẋ -> (ForwardDiff.Dual{T}(0, ẋ),) | |
ZygoteRules.@adjoint ZygoteRules.literal_getproperty(A::Tridiagonal, ::Val{:dl}) = A.dl,y -> Tridiagonal(dl,zeros(length(d)),zeros(length(du)),) | |
ZygoteRules.@adjoint ZygoteRules.literal_getproperty(A::Tridiagonal, ::Val{:d}) = A.d,y -> Tridiagonal(zeros(length(dl)),d,zeros(length(du)),) | |
ZygoteRules.@adjoint ZygoteRules.literal_getproperty(A::Tridiagonal, ::Val{:du}) = A.dl,y -> Tridiagonal(zeros(length(dl)),zeros(length(d),du),) | |
ZygoteRules.@adjoint Tridiagonal(dl, d, du) = Tridiagonal(dl, d, du), p̄ -> (diag(p̄[2:end,1:end-1]),diag(p̄),diag(p̄[1:end-1,2:end])) |
Related to SciML/SciMLSensitivity.jl#582
Metadata
Metadata
Assignees
Labels
No labels