Skip to content
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BlockSparseArrays"
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.7.21"
version = "0.7.23"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
7 changes: 7 additions & 0 deletions src/BlockArraysExtensions/BlockArraysExtensions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ for f in (:axes, :unsafe_indices, :axes1, :first, :last, :size, :length, :unsafe
end
Base.getindex(S::BlockIndices, i::Integer) = getindex(S.indices, i)

# TODO: Move this to a `BlockArraysExtensions` library.
function blockedunitrange_getindices(a::AbstractBlockedUnitRange, indices::BlockIndices)
# TODO: Is this a good definition? It ignores `indices.indices`.
return a[indices.blocks]
end

# Generalization of to `BlockArrays._blockslice`:
# https://github.com/JuliaArrays/BlockArrays.jl/blob/v1.6.3/src/views.jl#L13-L14
# Used by `BlockArrays.unblock`, which is used in `to_indices`
Expand Down Expand Up @@ -179,6 +185,7 @@ const GenericBlockIndexVectorSlices = BlockIndices{
<:BlockVector{<:GenericBlockIndex{1},<:Vector{<:BlockIndexVector}}
}
const SubBlockSliceCollection = Union{
Base.Slice,
BlockIndexRangeSlice,
BlockIndexRangeSlices,
BlockIndexVectorSlices,
Expand Down
23 changes: 21 additions & 2 deletions src/BlockArraysExtensions/blockedunitrange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,8 @@ end
# `Base.getindex(a::Block, b...)`.
_getindex(a::Block{N}, b::Vararg{Any,N}) where {N} = GenericBlockIndex(a, b)
_getindex(a::Block{N}, b::Vararg{Integer,N}) where {N} = a[b...]
_getindex(a::Block{N}, b::Vararg{AbstractUnitRange{<:Integer},N}) where {N} = a[b...]
_getindex(a::Block{N}, b::Vararg{AbstractVector,N}) where {N} = BlockIndexVector(a, b)
# Fix ambiguity.
_getindex(a::Block{0}) = a[]

Expand Down Expand Up @@ -347,6 +349,15 @@ BlockArrays.Block(b::BlockIndexVector) = b.block

Base.copy(a::BlockIndexVector) = BlockIndexVector(a.block, copy.(a.indices))

# Copied from BlockArrays.BlockIndexRange.
function Base.show(io::IO, B::BlockIndexVector)
show(io, Block(B))
print(io, "[")
print_tuple_elements(io, B.indices)
print(io, "]")
end
Base.show(io::IO, ::MIME"text/plain", B::BlockIndexVector) = show(io, B)

function Base.getindex(b::AbstractBlockedUnitRange, Kkr::BlockIndexVector{1})
return b[block(Kkr)][Kkr.indices...]
end
Expand All @@ -366,13 +377,21 @@ function blockedunitrange_getindices(
a::AbstractBlockedUnitRange,
indices::BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexVector{1}}},
)
return mortar(map(b -> a[b], blocks(indices)))
blks = map(b -> a[b], blocks(indices))
# Preserve any extra structure in the axes, like a
# Kronecker structure, symmetry sectors, etc.
ax = mortar_axis(map(b -> axis(a[b]), blocks(indices)))
return mortar(blks, (ax,))
end
function blockedunitrange_getindices(
a::AbstractBlockedUnitRange,
indices::BlockVector{<:GenericBlockIndex{1},<:Vector{<:BlockIndexVector{1}}},
)
return mortar(map(b -> a[b], blocks(indices)))
blks = map(b -> a[b], blocks(indices))
# Preserve any extra structure in the axes, like a
# Kronecker structure, symmetry sectors, etc.
ax = mortar_axis(map(b -> axis(a[b]), blocks(indices)))
return mortar(blks, (ax,))
end

# This is a specialization of `BlockArrays.unblock`:
Expand Down
5 changes: 5 additions & 0 deletions src/BlockArraysExtensions/blockrange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ function Base.getindex(r::BlockUnitRange, I::Block{1})
return eachblockaxis(r)[Int(I)] .+ (first(r.r[I]) - 1)
end

const BlockOneTo{T<:Integer,B,CS,R<:BlockedOneTo{T,CS}} = BlockUnitRange{T,B,CS,R}
Base.axes(S::Base.Slice{<:BlockOneTo}) = (S.indices,)
Base.axes1(S::Base.Slice{<:BlockOneTo}) = S.indices
Base.unsafe_indices(S::Base.Slice{<:BlockOneTo}) = (S.indices,)

function BlockArrays.combine_blockaxes(r1::BlockUnitRange, r2::BlockUnitRange)
if eachblockaxis(r1) ≠ eachblockaxis(r2)
return throw(ArgumentError("BlockUnitRanges must have the same block axes"))
Expand Down
20 changes: 19 additions & 1 deletion src/abstractblocksparsearray/linearalgebra.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using LinearAlgebra: LinearAlgebra, Adjoint, Transpose, norm, tr
using LinearAlgebra: LinearAlgebra, Adjoint, Transpose, diag, norm, tr

# Like: https://github.com/JuliaLang/julia/blob/v1.11.1/stdlib/LinearAlgebra/src/transpose.jl#L184
# but also takes the dual of the axes.
Expand Down Expand Up @@ -33,6 +33,24 @@ function LinearAlgebra.tr(a::AnyAbstractBlockSparseMatrix)
return tr_a
end

# TODO: Define in DiagonalArrays.jl.
function diagaxis(a::AbstractArray)
LinearAlgebra.checksquare(a)
return axes(a, 1)
end
function LinearAlgebra.diag(a::AnyAbstractBlockSparseMatrix)
# TODO: Add `checkblocksquare` to also check it is square blockwise.
LinearAlgebra.checksquare(a)
diagaxes = map(blockdiagindices(a)) do I
return diagaxis(@view(a[I]))
end
r = blockrange(diagaxes)
stored_blocks = Dict((
Tuple(I)[1] => diag(@view!(a[I])) for I in eachstoredblockdiagindex(a)
))
return blocksparse(stored_blocks, (r,))
end

# TODO: Define `SparseArraysBase.isdiag`, define as
# `isdiag(blocks(a))`.
function blockisdiag(a::AbstractArray)
Expand Down
8 changes: 8 additions & 0 deletions src/abstractblocksparsearray/views.jl
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,14 @@ end

blockedslice_blocks(x::BlockSlice) = x.block
blockedslice_blocks(x::BlockIndices) = x.blocks
# Reinterpret the slice blockwise.
function blockedslice_blocks(x::Base.Slice)
return mortar(
map(BlockRange(x.indices)) do b
return BlockIndexRange(b, Base.Slice(Base.axes1(x.indices[b])))
end,
)
end

# TODO: Define `@interface interface(a) viewblock`.
function BlockArrays.viewblock(
Expand Down
45 changes: 26 additions & 19 deletions src/factorizations/truncation.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
using MatrixAlgebraKit: TruncationStrategy, diagview, eig_trunc!, eigh_trunc!, svd_trunc!

function MatrixAlgebraKit.diagview(A::BlockSparseMatrix{T,Diagonal{T,Vector{T}}}) where {T}
D = BlockSparseVector{T}(undef, axes(A, 1))
for I in eachblockstoredindex(A)
if ==(Int.(Tuple(I))...)
D[Tuple(I)[1]] = diagview(A[I])
end
end
return D
end
using MatrixAlgebraKit:
TruncationStrategy,
diagview,
eig_trunc!,
eigh_trunc!,
findtruncated,
svd_trunc!,
truncate!

"""
BlockPermutedDiagonalTruncationStrategy(strategy::TruncationStrategy)
Expand All @@ -27,7 +24,7 @@ function MatrixAlgebraKit.truncate!(
strategy::TruncationStrategy,
)
# TODO assert blockdiagonal
return MatrixAlgebraKit.truncate!(
return truncate!(
svd_trunc!, (U, S, Vᴴ), BlockPermutedDiagonalTruncationStrategy(strategy)
)
end
Expand All @@ -38,9 +35,7 @@ for f in [:eig_trunc!, :eigh_trunc!]
(D, V)::NTuple{2,AbstractBlockSparseMatrix},
strategy::TruncationStrategy,
)
return MatrixAlgebraKit.truncate!(
$f, (D, V), BlockPermutedDiagonalTruncationStrategy(strategy)
)
return truncate!($f, (D, V), BlockPermutedDiagonalTruncationStrategy(strategy))
end
end
end
Expand All @@ -50,18 +45,30 @@ end
function MatrixAlgebraKit.findtruncated(
values::AbstractVector, strategy::BlockPermutedDiagonalTruncationStrategy
)
ind = MatrixAlgebraKit.findtruncated(values, strategy.strategy)
ind = findtruncated(Vector(values), strategy.strategy)
indexmask = falses(length(values))
indexmask[ind] .= true
return indexmask
return to_truncated_indices(values, indexmask)
end

# Allow customizing the indices output by `findtruncated`
# based on the type of `values`, for example to preserve
# a block or Kronecker structure.
to_truncated_indices(values, I) = I
function to_truncated_indices(values::AbstractBlockVector, I::AbstractVector{Bool})
I′ = BlockedVector(I, blocklengths(axis(values)))
blocks = map(BlockRange(values)) do b
return _getindex(b, to_truncated_indices(values[b], I′[b]))
end
return blocks
end

function MatrixAlgebraKit.truncate!(
::typeof(svd_trunc!),
(U, S, Vᴴ)::NTuple{3,AbstractBlockSparseMatrix},
strategy::BlockPermutedDiagonalTruncationStrategy,
)
I = MatrixAlgebraKit.findtruncated(diagview(S), strategy)
I = findtruncated(diag(S), strategy)
return (U[:, I], S[I, I], Vᴴ[I, :])
end
for f in [:eig_trunc!, :eigh_trunc!]
Expand All @@ -71,7 +78,7 @@ for f in [:eig_trunc!, :eigh_trunc!]
(D, V)::NTuple{2,AbstractBlockSparseMatrix},
strategy::BlockPermutedDiagonalTruncationStrategy,
)
I = MatrixAlgebraKit.findtruncated(diagview(D), strategy)
I = findtruncated(diag(D), strategy)
return (D[I, I], V[:, I])
end
end
Expand Down
64 changes: 40 additions & 24 deletions test/test_factorizations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,20 +146,23 @@ test_params = Iterators.product(blockszs, eltypes)
@test test_svd(a, usv_empty)

# test blockdiagonal
rng = StableRNG(123)
for i in LinearAlgebra.diagind(blocks(a))
I = CartesianIndices(blocks(a))[i]
a[Block(I.I...)] = rand(T, size(blocks(a)[i]))
a[Block(I.I...)] = rand(rng, T, size(blocks(a)[i]))
end
usv = svd_compact(a)
@test test_svd(a, usv)

perm = Random.randperm(length(m))
rng = StableRNG(123)
perm = Random.randperm(rng, length(m))
b = a[Block.(perm), Block.(1:length(n))]
usv = svd_compact(b)
@test test_svd(b, usv)

# test permuted blockdiagonal with missing row/col
I_removed = rand(eachblockstoredindex(b))
rng = StableRNG(123)
I_removed = rand(rng, eachblockstoredindex(b))
c = copy(b)
delete!(blocks(c).storage, CartesianIndex(Int.(Tuple(I_removed))))
usv = svd_compact(c)
Expand All @@ -176,20 +179,23 @@ end
@test test_svd(a, usv_empty; full=true)

# test blockdiagonal
rng = StableRNG(123)
for i in LinearAlgebra.diagind(blocks(a))
I = CartesianIndices(blocks(a))[i]
a[Block(I.I...)] = rand(T, size(blocks(a)[i]))
a[Block(I.I...)] = rand(rng, T, size(blocks(a)[i]))
end
usv = svd_full(a)
@test test_svd(a, usv; full=true)

perm = Random.randperm(length(m))
rng = StableRNG(123)
perm = Random.randperm(rng, length(m))
b = a[Block.(perm), Block.(1:length(n))]
usv = svd_full(b)
@test test_svd(b, usv; full=true)

# test permuted blockdiagonal with missing row/col
I_removed = rand(eachblockstoredindex(b))
rng = StableRNG(123)
I_removed = rand(rng, eachblockstoredindex(b))
c = copy(b)
delete!(blocks(c).storage, CartesianIndex(Int.(Tuple(I_removed))))
usv = svd_full(c)
Expand All @@ -203,9 +209,10 @@ end
a = BlockSparseArray{T}(undef, m, n)

# test blockdiagonal
rng = StableRNG(123)
for i in LinearAlgebra.diagind(blocks(a))
I = CartesianIndices(blocks(a))[i]
a[Block(I.I...)] = rand(T, size(blocks(a)[i]))
a[Block(I.I...)] = rand(rng, T, size(blocks(a)[i]))
end

minmn = min(size(a)...)
Expand Down Expand Up @@ -236,7 +243,8 @@ end
@test (V1ᴴ * V1ᴴ' ≈ LinearAlgebra.I)

# test permuted blockdiagonal
perm = Random.randperm(length(m))
rng = StableRNG(123)
perm = Random.randperm(rng, length(m))
b = a[Block.(perm), Block.(1:length(n))]
for trunc in (truncrank(r), trunctol(atol))
U1, S1, V1ᴴ = svd_trunc(b; trunc)
Expand Down Expand Up @@ -270,8 +278,9 @@ end
@testset "qr_compact (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
for i in [2, 3], j in [2, 3], k in [2, 3], l in [2, 3]
A = BlockSparseArray{T}(undef, ([i, j], [k, l]))
A[Block(1, 1)] = randn(T, i, k)
A[Block(2, 2)] = randn(T, j, l)
rng = StableRNG(123)
A[Block(1, 1)] = randn(rng, T, i, k)
A[Block(2, 2)] = randn(rng, T, j, l)
Q, R = qr_compact(A)
@test Matrix(Q'Q) ≈ LinearAlgebra.I
@test A ≈ Q * R
Expand All @@ -281,8 +290,9 @@ end
@testset "qr_full (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
for i in [2, 3], j in [2, 3], k in [2, 3], l in [2, 3]
A = BlockSparseArray{T}(undef, ([i, j], [k, l]))
A[Block(1, 1)] = randn(T, i, k)
A[Block(2, 2)] = randn(T, j, l)
rng = StableRNG(123)
A[Block(1, 1)] = randn(rng, T, i, k)
A[Block(2, 2)] = randn(rng, T, j, l)
Q, R = qr_full(A)
Q′, R′ = qr_full(Matrix(A))
@test size(Q) == size(Q′)
Expand All @@ -296,8 +306,9 @@ end
@testset "lq_compact" for T in (Float32, Float64, ComplexF32, ComplexF64)
for i in [2, 3], j in [2, 3], k in [2, 3], l in [2, 3]
A = BlockSparseArray{T}(undef, ([i, j], [k, l]))
A[Block(1, 1)] = randn(T, i, k)
A[Block(2, 2)] = randn(T, j, l)
rng = StableRNG(123)
A[Block(1, 1)] = randn(rng, T, i, k)
A[Block(2, 2)] = randn(rng, T, j, l)
L, Q = lq_compact(A)
@test Matrix(Q * Q') ≈ LinearAlgebra.I
@test A ≈ L * Q
Expand All @@ -307,8 +318,9 @@ end
@testset "lq_full" for T in (Float32, Float64, ComplexF32, ComplexF64)
for i in [2, 3], j in [2, 3], k in [2, 3], l in [2, 3]
A = BlockSparseArray{T}(undef, ([i, j], [k, l]))
A[Block(1, 1)] = randn(T, i, k)
A[Block(2, 2)] = randn(T, j, l)
rng = StableRNG(123)
A[Block(1, 1)] = randn(rng, T, i, k)
A[Block(2, 2)] = randn(rng, T, j, l)
L, Q = lq_full(A)
L′, Q′ = lq_full(Matrix(A))
@test size(L) == size(L′)
Expand All @@ -321,8 +333,9 @@ end

@testset "left_polar (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
A = BlockSparseArray{T}(undef, ([3, 4], [2, 3]))
A[Block(1, 1)] = randn(T, 3, 2)
A[Block(2, 2)] = randn(T, 4, 3)
rng = StableRNG(123)
A[Block(1, 1)] = randn(rng, T, 3, 2)
A[Block(2, 2)] = randn(rng, T, 4, 3)

U, C = left_polar(A)
@test U * C ≈ A
Expand All @@ -331,8 +344,9 @@ end

@testset "right_polar (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
A = BlockSparseArray{T}(undef, ([2, 3], [3, 4]))
A[Block(1, 1)] = randn(T, 2, 3)
A[Block(2, 2)] = randn(T, 3, 4)
rng = StableRNG(123)
A[Block(1, 1)] = randn(rng, T, 2, 3)
A[Block(2, 2)] = randn(rng, T, 3, 4)

C, U = right_polar(A)
@test C * U ≈ A
Expand All @@ -341,8 +355,9 @@ end

@testset "left_orth (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
A = BlockSparseArray{T}(undef, ([3, 4], [2, 3]))
A[Block(1, 1)] = randn(T, 3, 2)
A[Block(2, 2)] = randn(T, 4, 3)
rng = StableRNG(123)
A[Block(1, 1)] = randn(rng, T, 3, 2)
A[Block(2, 2)] = randn(rng, T, 4, 3)

for kind in (:polar, :qr, :svd)
U, C = left_orth(A; kind)
Expand All @@ -358,8 +373,9 @@ end

@testset "right_orth (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
A = BlockSparseArray{T}(undef, ([2, 3], [3, 4]))
A[Block(1, 1)] = randn(T, 2, 3)
A[Block(2, 2)] = randn(T, 3, 4)
rng = StableRNG(123)
A[Block(1, 1)] = randn(rng, T, 2, 3)
A[Block(2, 2)] = randn(rng, T, 3, 4)

for kind in (:lq, :polar, :svd)
C, U = right_orth(A; kind)
Expand Down