Skip to content

Commit 0374568

Browse files
committed
BlockDiagIEB AD fixes
1 parent a4f959b commit 0374568

File tree

5 files changed

+16
-10
lines changed

5 files changed

+16
-10
lines changed

src/field_vectors.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,4 +65,3 @@ function pinv!(dst::FieldOrOpMatrix{<:Diagonal}, src::FieldOrOpMatrix{<:Diagonal
6565
end
6666

6767
promote_rule(::Type{F}, ::Type{<:Scalar}) where {F<:Field} = F
68-
arithmetic_closure(::F) where {F<:Field} = F

src/generic.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,7 @@ precompute!!(L::Adjoint, f) = precompute!!(parent(L),f)'
322322
# splatted into a giant matrix when doing [f f; f f] (which they would othewise
323323
# be since they're Arrays)
324324
hvcat(rows::Tuple{Vararg{Int}}, values::Field...) = hvcat(rows, ([x] for x in values)...)
325+
hvcat(rows::Tuple{Vararg{Int}}, values::DiagOp...) = hvcat(rows, ([x] for x in values)...)
325326
hcat(values::Field...) = hcat(([x] for x in values)...)
326327

327328
### printing

src/proj_lambert.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ function Cℓ_to_Cov(::Val{:P}, proj::ProjLambert, CℓEE::Cℓs, CℓBB::Cℓs;
366366
end
367367
function Cℓ_to_Cov(::Val{:IP}, proj::ProjLambert, CℓTT, CℓEE, CℓBB, CℓTE; kwargs...)
368368
ΣTT, ΣEE, ΣBB, ΣTE = [Cℓ_to_Cov(:I,proj,Cℓ; kwargs...) for Cℓ in (CℓTT,CℓEE,CℓBB,CℓTE)]
369-
BlockDiagIEB(@SMatrix([ΣTT ΣTE; ΣTE ΣEE]), ΣBB)
369+
BlockDiagIEB([ΣTT ΣTE; ΣTE ΣEE], ΣBB)
370370
end
371371

372372
## ParamDependentOp covariances scaled by amplitudes in different ℓ-bins

src/specialops.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,10 @@ end
5757
# call sqrt/inv on it, and the ΣBB block separately as ΣB. This type
5858
# is generic with regards to the field type, F.
5959
struct BlockDiagIEB{T,F} <: ImplicitOp{T}
60-
ΣTE :: SMatrix{2,2,Diagonal{T,F},4}
60+
ΣTE :: SizedMatrix{2,2,Diagonal{T,F},2,Matrix{Diagonal{T,F}}}
6161
ΣB :: Diagonal{T,F}
6262
end
63+
BlockDiagIEB(ΣTE::AbstractMatrix{Diagonal{T,F}}, ΣB::Diagonal{T,F}) where {T,F} = BlockDiagIEB{T,F}(ΣTE, ΣB)
6364
# applying
6465
*(L::BlockDiagIEB, f::BaseS02) = L * IEBFourier(f)
6566
\(L::BlockDiagIEB, f::BaseS02) = pinv(L) * IEBFourier(f)
@@ -77,13 +78,13 @@ similar(L::BlockDiagIEB) = BlockDiagIEB(similar.(L.ΣTE), similar(L.ΣB))
7778
get_storage(L::BlockDiagIEB) = get_storage(L.ΣB)
7879
adapt_structure(storage, L::BlockDiagIEB) = BlockDiagIEB(adapt.(Ref(storage), L.ΣTE), adapt(storage, L.ΣB))
7980
simulate(rng::AbstractRNG, L::BlockDiagIEB; Nbatch=()) = sqrt(L) * randn!(rng, similar(diag(L), Nbatch...))
80-
logdet(L::BlockDiagIEB) = logdet(L.ΣTE[1,1]*L.ΣTE[2,2]-L.ΣTE[1,2]*L.ΣTE[2,1]) + logdet(L.ΣB)
81+
logdet(L::BlockDiagIEB) = logdet(det(L.ΣTE)) + logdet(L.ΣB)
8182
# arithmetic
82-
*(L::BlockDiagIEB, D::DiagOp{<:BaseIEBFourier}) = BlockDiagIEB(SMatrix{2,2}(L.ΣTE * [[D[:I]] [0]; [0] [D[:E]]]), L.ΣB * D[:B])
83-
+(L::BlockDiagIEB, D::DiagOp{<:BaseIEBFourier}) = BlockDiagIEB(@SMatrix[L.ΣTE[1,1]+D[:I] L.ΣTE[1,2]; L.ΣTE[2,1] L.ΣTE[2,2]+D[:E]], L.ΣB + D[:B])
83+
*(L::BlockDiagIEB, D::DiagOp{<:BaseIEBFourier}) = BlockDiagIEB(L.ΣTE * [[D[:I]] [0]; [0] [D[:E]]], L.ΣB * D[:B])
84+
+(L::BlockDiagIEB, D::DiagOp{<:BaseIEBFourier}) = BlockDiagIEB([L.ΣTE[1,1]+D[:I] L.ΣTE[1,2]; L.ΣTE[2,1] L.ΣTE[2,2]+D[:E]], L.ΣB + D[:B])
8485
*(La::F, Lb::F) where {F<:BlockDiagIEB} = F(La.ΣTE * Lb.ΣTE, La.ΣB * Lb.ΣB)
8586
+(La::F, Lb::F) where {F<:BlockDiagIEB} = F(La.ΣTE + Lb.ΣTE, La.ΣB + Lb.ΣB)
86-
+(L::BlockDiagIEB, U::UniformScaling{<:Scalar}) = BlockDiagIEB(@SMatrix[(L.ΣTE[1,1]+U) L.ΣTE[1,2]; L.ΣTE[2,1] (L.ΣTE[2,2]+U)], L.ΣB+U)
87+
+(L::BlockDiagIEB, U::UniformScaling{<:Scalar}) = BlockDiagIEB([(L.ΣTE[1,1]+U) L.ΣTE[1,2]; L.ΣTE[2,1] (L.ΣTE[2,2]+U)], L.ΣB+U)
8788
*(L::BlockDiagIEB, λ::Scalar) = BlockDiagIEB(L.ΣTE * λ, L.ΣB * λ)
8889
*(D::DiagOp{<:BaseIEBFourier}, L::BlockDiagIEB) = L * D
8990
+(U::UniformScaling{<:Scalar}, L::BlockDiagIEB) = L + U

src/util.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,20 +100,25 @@ end
100100
# these allow pinv and sqrt of SMatrices of Diagonals to work correctly, which
101101
# we use for the T-E block of the covariance. hopefully some of this can be cut
102102
# down on in the futue with some PRs into StaticArrays.
103-
permutedims(A::SMatrix{2,2}) = @SMatrix[A[1] A[3]; A[2] A[4]]
104-
@auto_adjoint function sqrt(A::SMatrix{2,2,<:Diagonal})
103+
# permutedims(A::SMatrix{2,2}) = @SMatrix[A[1] A[3]; A[2] A[4]]
104+
@auto_adjoint function sqrt(A::SizedMatrix{2,2,<:Diagonal})
105105
# A = [a b; c d]
106106
a,c,b,d = A
107107
s = sqrt(a*d-b*c)
108108
t = pinv(sqrt(a+(d+2s)))
109109
@SMatrix[t*(a+s) t*b; t*c t*(d+s)]
110110
end
111-
@auto_adjoint function pinv(A::SMatrix{2,2,<:Diagonal})
111+
@auto_adjoint function pinv(A::SizedMatrix{2,2,<:Diagonal})
112112
# A = [a b; c d]
113113
a,c,b,d = A
114114
idet = pinv(a*d-b*c)
115115
@SMatrix[d*idet -(b*idet); -(c*idet) a*idet]
116116
end
117+
@auto_adjoint function det(A::SizedMatrix{2,2,<:Diagonal})
118+
# A = [a b; c d]
119+
a,c,b,d = A
120+
a*d-b*c
121+
end
117122

118123

119124
# some usefule tuple manipulation functions:

0 commit comments

Comments
 (0)