Skip to content
13 changes: 13 additions & 0 deletions src/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
Expand Up @@ -268,3 +268,16 @@ function rrule(::typeof(logdet), X::Union{Diagonal, AbstractTriangular})
end
return y, logdet_pullback
end

#####
##### Tridiagonal
#####

function rrule(::Type{Tridiagonal}, dl, d, du)
y = Tridiagonal(dl, d, du)
@views function ∇Tridiagonal(∂y)
return (NoTangent(), diag(∂y[2:end, 1:(end - 1)]), diag(∂y),
diag(∂y[1:(end - 1), 2:end]))
end
return y, ∇Tridiagonal
end