Skip to content
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BlockSparseArrays"
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.5.0"
version = "0.5.1"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
1 change: 1 addition & 0 deletions src/BlockSparseArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,6 @@ include("BlockArraysSparseArraysBaseExt/BlockArraysSparseArraysBaseExt.jl")

# factorizations
include("factorizations/svd.jl")
include("factorizations/truncation.jl")

end
3 changes: 2 additions & 1 deletion src/factorizations/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ using MatrixAlgebraKit: MatrixAlgebraKit, svd_compact!, svd_full!
BlockPermutedDiagonalAlgorithm(A::MatrixAlgebraKit.AbstractAlgorithm)

A wrapper for `MatrixAlgebraKit.AbstractAlgorithm` that implements the wrapped algorithm on
a block-by-block basis, which is possible if the input matrix is a block-diagonal matrix or a block permuted block-diagonal matrix.
a block-by-block basis, which is possible if the input matrix is a block-diagonal matrix or
a block permuted block-diagonal matrix.
"""
struct BlockPermutedDiagonalAlgorithm{A<:MatrixAlgebraKit.AbstractAlgorithm} <:
MatrixAlgebraKit.AbstractAlgorithm
Expand Down
102 changes: 102 additions & 0 deletions src/factorizations/truncation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
using MatrixAlgebraKit: TruncationStrategy, diagview, svd_trunc!

function MatrixAlgebraKit.diagview(A::BlockSparseMatrix{T,Diagonal{T,Vector{T}}}) where {T}
D = BlockSparseVector{T}(undef, axes(A, 1))
for I in eachblockstoredindex(A)
if ==(Int.(Tuple(I))...)
D[Tuple(I)[1]] = diagview(A[I])
end
end
return D
end

"""
BlockPermutedDiagonalTruncationStrategy(strategy::TruncationStrategy)

A wrapper for `TruncationStrategy` that implements the wrapped strategy on a block-by-block
basis, which is possible if the input matrix is a block-diagonal matrix or a block permuted
block-diagonal matrix.
"""
struct BlockPermutedDiagonalTruncationStrategy{T<:TruncationStrategy} <: TruncationStrategy
strategy::T
end

const TBlockUSVᴴ = Tuple{
<:AbstractBlockSparseMatrix,<:AbstractBlockSparseMatrix,<:AbstractBlockSparseMatrix
}

function MatrixAlgebraKit.truncate!(
::typeof(svd_trunc!), (U, S, Vᴴ)::TBlockUSVᴴ, strategy::TruncationStrategy
)
# TODO assert blockdiagonal
return MatrixAlgebraKit.truncate!(
svd_trunc!, (U, S, Vᴴ), BlockPermutedDiagonalTruncationStrategy(strategy)
)
end

# cannot use regular slicing here: I want to slice without altering blockstructure
# solution: use boolean indexing and slice the mask, effectively cheaply inverting the map
function MatrixAlgebraKit.findtruncated(
values::AbstractVector, strategy::BlockPermutedDiagonalTruncationStrategy
)
ind = MatrixAlgebraKit.findtruncated(values, strategy.strategy)
indexmask = falses(length(values))
indexmask[ind] .= true
return indexmask
end

function MatrixAlgebraKit.truncate!(
::typeof(svd_trunc!),
(U, S, Vᴴ)::TBlockUSVᴴ,
strategy::BlockPermutedDiagonalTruncationStrategy,
)
indexmask = MatrixAlgebraKit.findtruncated(diagview(S), strategy)

# first determine the block structure of the output to avoid having assumptions on the
# data structures
ax = axes(S, 1)
counter = Base.Fix1(count, Base.Fix1(getindex, indexmask))
Slengths = filter!(>(0), map(counter, blocks(ax)))
Sax = blockedrange(Slengths)
Ũ = similar(U, axes(U, 1), Sax)
S̃ = similar(S, Sax, Sax)
Ṽᴴ = similar(Vᴴ, Sax, axes(Vᴴ, 2))

# then loop over the blocks and assign the data
# TODO: figure out if we can presort and loop over the blocks -
# for now this has issues with missing blocks
bI_Us = collect(eachblockstoredindex(U))
bI_Ss = collect(eachblockstoredindex(S))
bI_Vᴴs = collect(eachblockstoredindex(Vᴴ))

I′ = 0 # number of skipped blocks that got fully truncated
for I in 1:blocksize(ax, 1)
b = ax[Block(I)]
mask = indexmask[b]

if !any(mask)
I′ += 1
continue
end

bU_id = @something findfirst(x -> last(Tuple(x)) == Block(I), bI_Us) error(
"No U-block found for $I"
)
bU = Tuple(bI_Us[bU_id])
Ũ[bU[1], bU[2] - Block(I′)] = view(U, bU...)[:, mask]

bVᴴ_id = @something findfirst(x -> first(Tuple(x)) == Block(I), bI_Vᴴs) error(
"No Vᴴ-block found for $I"
)
bVᴴ = Tuple(bI_Vᴴs[bVᴴ_id])
Ṽᴴ[bVᴴ[1] - Block(I′), bVᴴ[2]] = view(Vᴴ, bVᴴ...)[mask, :]

bS_id = findfirst(x -> last(Tuple(x)) == Block(I), bI_Ss)
if !isnothing(bS_id)
bS = Tuple(bI_Ss[bS_id])
S̃[(bS .- Block(I′))...] = Diagonal(diagview(view(S, bS...))[mask])
end
end

return Ũ, S̃, Ṽᴴ
end
73 changes: 72 additions & 1 deletion test/test_factorizations.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using BlockArrays: Block, BlockedMatrix, BlockedVector, blocks, mortar
using BlockSparseArrays: BlockSparseArray, BlockDiagonal, eachblockstoredindex
using MatrixAlgebraKit: svd_compact, svd_full
using MatrixAlgebraKit: svd_compact, svd_full, svd_trunc, truncrank, trunctol
using LinearAlgebra: LinearAlgebra
using Random: Random
using Test: @inferred, @testset, @test
Expand Down Expand Up @@ -83,3 +83,74 @@ end
usv = svd_full(c)
@test test_svd(c, usv; full=true)
end

# svd_trunc!
# ----------

@testset "svd_trunc ($m, $n) BlockSparseMatri{$T}" for ((m, n), T) in test_params
a = BlockSparseArray{T}(undef, m, n)

# test blockdiagonal
for i in LinearAlgebra.diagind(blocks(a))
I = CartesianIndices(blocks(a))[i]
a[Block(I.I...)] = rand(T, size(blocks(a)[i]))
end

minmn = min(size(a)...)
r = max(1, minmn - 2)
trunc = truncrank(r)

U1, S1, V1ᴴ = svd_trunc(a; trunc)
U2, S2, V2ᴴ = svd_trunc(Matrix(a); trunc)
@test size(U1) == size(U2)
@test size(S1) == size(S2)
@test size(V1ᴴ) == size(V2ᴴ)
@test Matrix(U1 * S1 * V1ᴴ) ≈ U2 * S2 * V2ᴴ

@test (U1' * U1 ≈ LinearAlgebra.I)
@test (V1ᴴ * V1ᴴ' ≈ LinearAlgebra.I)

atol = minimum(LinearAlgebra.diag(S1)) + 10 * eps(real(T))
trunc = trunctol(atol)

U1, S1, V1ᴴ = svd_trunc(a; trunc)
U2, S2, V2ᴴ = svd_trunc(Matrix(a); trunc)
@test size(U1) == size(U2)
@test size(S1) == size(S2)
@test size(V1ᴴ) == size(V2ᴴ)
@test Matrix(U1 * S1 * V1ᴴ) ≈ U2 * S2 * V2ᴴ

@test (U1' * U1 ≈ LinearAlgebra.I)
@test (V1ᴴ * V1ᴴ' ≈ LinearAlgebra.I)

# test permuted blockdiagonal
perm = Random.randperm(length(m))
b = a[Block.(perm), Block.(1:length(n))]
for trunc in (truncrank(r), trunctol(atol))
U1, S1, V1ᴴ = svd_trunc(b; trunc)
U2, S2, V2ᴴ = svd_trunc(Matrix(b); trunc)
@test size(U1) == size(U2)
@test size(S1) == size(S2)
@test size(V1ᴴ) == size(V2ᴴ)
@test Matrix(U1 * S1 * V1ᴴ) ≈ U2 * S2 * V2ᴴ

@test (U1' * U1 ≈ LinearAlgebra.I)
@test (V1ᴴ * V1ᴴ' ≈ LinearAlgebra.I)
end

# test permuted blockdiagonal with missing row/col
I_removed = rand(eachblockstoredindex(b))
c = copy(b)
delete!(blocks(c).storage, CartesianIndex(Int.(Tuple(I_removed))))
for trunc in (truncrank(r), trunctol(atol))
U1, S1, V1ᴴ = svd_trunc(c; trunc)
U2, S2, V2ᴴ = svd_trunc(Matrix(c); trunc)
@test size(U1) == size(U2)
@test size(S1) == size(S2)
@test size(V1ᴴ) == size(V2ᴴ)
@test Matrix(U1 * S1 * V1ᴴ) ≈ U2 * S2 * V2ᴴ

@test (U1' * U1 ≈ LinearAlgebra.I)
@test (V1ᴴ * V1ᴴ' ≈ LinearAlgebra.I)
end
end
Loading