Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BlockSparseArrays"
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.5.1"
version = "0.5.2"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
21 changes: 21 additions & 0 deletions src/BlockArraysExtensions/blockedunitrange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,32 @@ using BlockArrays:
BlockSlice,
BlockVector,
block,
blockedrange,
blockindex,
blocklengths,
findblock,
findblockindex,
mortar

# Get the axes of each block of a block array.
function eachblockaxes(a::AbstractArray)
return map(axes, blocks(a))
end

axis(a::AbstractVector) = axes(a, 1)

# Get the axis of each block of a blocked unit
# range.
function eachblockaxis(a::AbstractVector)
return map(axis, blocks(a))
end

# Take a collection of axes and mortar them
# into a single blocked axis.
function mortar_axis(axs)
return blockedrange(length.(axs))
end

# Custom `BlockedUnitRange` constructor that takes a unit range
# and a set of block lengths, similar to `BlockArray(::AbstractArray, blocklengths...)`.
function blockedunitrange(a::AbstractUnitRange, blocklengths)
Expand Down
59 changes: 36 additions & 23 deletions src/factorizations/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,19 @@ function MatrixAlgebraKit.default_svd_algorithm(A::AbstractBlockSparseMatrix; kw
return BlockPermutedDiagonalAlgorithm(alg)
end

# TODO: this should be replaced with a more general similar function that can handle setting
# the blocktype and element type - something like S = similar(A, BlockType(...))
function _similar_S(A::AbstractBlockSparseMatrix, s_axis)
function similar_output(
::typeof(svd_compact!),
A,
s_axis::AbstractUnitRange,
alg::MatrixAlgebraKit.AbstractAlgorithm,
)
U = similar(A, axes(A, 1), s_axis)
T = real(eltype(A))
return BlockSparseArray{T,2,Diagonal{T,Vector{T}}}(undef, (s_axis, s_axis))
# TODO: this should be replaced with a more general similar function that can handle setting
# the blocktype and element type - something like S = similar(A, BlockType(...))
S = BlockSparseMatrix{T,Diagonal{T,Vector{T}}}(undef, (s_axis, s_axis))
Vt = similar(A, s_axis, axes(A, 2))
return U, S, Vt
end

function MatrixAlgebraKit.initialize_output(
Expand All @@ -34,9 +42,9 @@ function MatrixAlgebraKit.initialize_output(
bm, bn = blocksize(A)
bmn = min(bm, bn)

brows = blocklengths(axes(A, 1))
bcols = blocklengths(axes(A, 2))
slengths = Vector{Int}(undef, bmn)
brows = eachblockaxis(axes(A, 1))
bcols = eachblockaxis(axes(A, 2))
s_axeses = Vector{eltype(brows)}(undef, bmn)

# fill in values for blocks that are present
bIs = collect(eachblockstoredindex(A))
Expand All @@ -46,21 +54,19 @@ function MatrixAlgebraKit.initialize_output(
row, col = Int.(Tuple(bI))
nrows = brows[row]
ncols = bcols[col]
slengths[col] = min(nrows, ncols)
s_axeses[col] = min(nrows, ncols)
end

# fill in values for blocks that aren't present, pairing them in order of occurence
# this is a convention, which at least gives the expected results for blockdiagonal
emptyrows = setdiff(1:bm, browIs)
emptycols = setdiff(1:bn, bcolIs)
for (row, col) in zip(emptyrows, emptycols)
slengths[col] = min(brows[row], bcols[col])
s_axeses[col] = min(brows[row], bcols[col])
end

s_axis = blockedrange(slengths)
U = similar(A, axes(A, 1), s_axis)
S = _similar_S(A, s_axis)
Vt = similar(A, s_axis, axes(A, 2))
s_axis = mortar_axis(s_axeses)
U, S, Vt = similar_output(svd_compact!, A, s_axis, alg)

# allocate output
for bI in eachblockstoredindex(A)
Expand All @@ -79,13 +85,23 @@ function MatrixAlgebraKit.initialize_output(
return U, S, Vt
end

function similar_output(
::typeof(svd_full!), A, s_axis::AbstractUnitRange, alg::MatrixAlgebraKit.AbstractAlgorithm
)
U = similar(A, axes(A, 1), s_axis)
T = real(eltype(A))
S = similar(A, T, (s_axis, axes(A, 2)))
Vt = similar(A, axes(A, 2), axes(A, 2))
return U, S, Vt
end

function MatrixAlgebraKit.initialize_output(
::typeof(svd_full!), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm
)
bm, bn = blocksize(A)

brows = blocklengths(axes(A, 1))
slengths = copy(brows)
brows = eachblockaxis(axes(A, 1))
s_axes = copy(brows)

# fill in values for blocks that are present
bIs = collect(eachblockstoredindex(A))
Expand All @@ -94,25 +110,22 @@ function MatrixAlgebraKit.initialize_output(
for bI in eachblockstoredindex(A)
row, col = Int.(Tuple(bI))
nrows = brows[row]
slengths[col] = nrows
s_axes[col] = nrows
end

# fill in values for blocks that aren't present, pairing them in order of occurence
# this is a convention, which at least gives the expected results for blockdiagonal
emptyrows = setdiff(1:bm, browIs)
emptycols = setdiff(1:bn, bcolIs)
for (row, col) in zip(emptyrows, emptycols)
slengths[col] = brows[row]
s_axes[col] = brows[row]
end
for (i, k) in enumerate((length(emptycols) + 1):length(emptyrows))
slengths[bn + i] = brows[emptyrows[k]]
s_axes[bn + i] = brows[emptyrows[k]]
end

s_axis = blockedrange(slengths)
U = similar(A, axes(A, 1), s_axis)
Tr = real(eltype(A))
S = similar(A, Tr, (s_axis, axes(A, 2)))
Vt = similar(A, axes(A, 2), axes(A, 2))
s_axis = mortar_axis(s_axes)
U, S, Vt = similar_output(svd_full!, A, s_axis, alg)

# allocate output
for bI in eachblockstoredindex(A)
Expand Down
25 changes: 18 additions & 7 deletions src/factorizations/truncation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,22 @@ function MatrixAlgebraKit.findtruncated(
return indexmask
end

function similar_truncate(
::typeof(svd_trunc!),
(U, S, Vᴴ)::TBlockUSVᴴ,
strategy::BlockPermutedDiagonalTruncationStrategy,
indexmask=MatrixAlgebraKit.findtruncated(diagview(S), strategy),
)
ax = axes(S, 1)
counter = Base.Fix1(count, Base.Fix1(getindex, indexmask))
s_lengths = filter!(>(0), map(counter, blocks(ax)))
s_axis = blockedrange(s_lengths)
Ũ = similar(U, axes(U, 1), s_axis)
S̃ = similar(S, s_axis, s_axis)
Ṽᴴ = similar(Vᴴ, s_axis, axes(Vᴴ, 2))
return Ũ, S̃, Ṽᴴ
end

function MatrixAlgebraKit.truncate!(
::typeof(svd_trunc!),
(U, S, Vᴴ)::TBlockUSVᴴ,
Expand All @@ -54,13 +70,7 @@ function MatrixAlgebraKit.truncate!(

# first determine the block structure of the output to avoid having assumptions on the
# data structures
ax = axes(S, 1)
counter = Base.Fix1(count, Base.Fix1(getindex, indexmask))
Slengths = filter!(>(0), map(counter, blocks(ax)))
Sax = blockedrange(Slengths)
Ũ = similar(U, axes(U, 1), Sax)
S̃ = similar(S, Sax, Sax)
Ṽᴴ = similar(Vᴴ, Sax, axes(Vᴴ, 2))
Ũ, S̃, Ṽᴴ = similar_truncate(svd_trunc!, (U, S, Vᴴ), strategy, indexmask)

# then loop over the blocks and assign the data
# TODO: figure out if we can presort and loop over the blocks -
Expand All @@ -70,6 +80,7 @@ function MatrixAlgebraKit.truncate!(
bI_Vᴴs = collect(eachblockstoredindex(Vᴴ))

I′ = 0 # number of skipped blocks that got fully truncated
ax = axes(S, 1)
for I in 1:blocksize(ax, 1)
b = ax[Block(I)]
mask = indexmask[b]
Expand Down
Loading