Skip to content

Commit 487d216

Browse files
committed
allow 3d arrays for Diagonal
1 parent c20360a commit 487d216

File tree

2 files changed

+5
-0
lines changed

2 files changed

+5
-0
lines changed

src/projection.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,10 @@ end
435435
ProjectTo(x::Diagonal) = ProjectTo{Diagonal}(; diag=ProjectTo(x.diag))
436436
(project::ProjectTo{Diagonal})(dx::AbstractMatrix) = Diagonal(project.diag(diag(dx)))
437437
(project::ProjectTo{Diagonal})(dx::Diagonal) = Diagonal(project.diag(dx.diag))
438+
function (project::ProjectTo{Diagonal})(dx::AbstractArray)
439+
ind = diagind(size(dx,1), size(dx,2), 0)
440+
return Diagonal(project.diag(dx[ind]))
441+
end
438442
function (project::ProjectTo{Diagonal})(dx::Tangent{<:Diagonal}) # structural => natural
439443
return dx.diag isa ArrayOrZero ? Diagonal(project.diag(dx.diag)) : dx
440444
end

test/projection.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,7 @@ struct NoSuperType end
307307
@testset "LinearAlgebra: sparse structured matrices" begin
308308
pdiag = ProjectTo(Diagonal(1:3))
309309
@test pdiag(reshape(1:9, 3, 3)) == Diagonal([1, 5, 9])
310+
@test pdiag(reshape(1:9, 3, 3, 1)) == Diagonal([1, 5, 9])
310311
@test pdiag(pdiag(reshape(1:9, 3, 3))) == pdiag(reshape(1:9, 3, 3))
311312
@test pdiag(rand(ComplexF32, 3, 3)) isa Diagonal{Float64}
312313
@test pdiag(Diagonal(1.0:3.0)) === Diagonal(1.0:3.0)

0 commit comments

Comments
 (0)