Skip to content
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BlockSparseArrays"
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.8.0"
version = "0.8.1"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
11 changes: 11 additions & 0 deletions src/BlockArraysExtensions/BlockArraysExtensions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ function _blockslice(x, y::AbstractVector)
return BlockIndices(x, y)
end

# TODO: Constrain the type of `BlockIndices` more, this seems
# to assume that `S.blocks` is a list of blocks as opposed to
# a flat list of block indices like the definition below.
function Base.getindex(S::BlockIndices, i::BlockSlice{<:Block{1}})
# TODO: Check that `i.indices` is consistent with `S.indices`.
# It seems like this isn't handling the case where `i` is a
Expand All @@ -96,6 +99,13 @@ function Base.getindex(S::BlockIndices, i::BlockSlice{<:Block{1}})
return _blockslice(S.blocks[Int(Block(i))], S.indices[Block(i)])
end

function Base.getindex(
S::BlockIndices{<:AbstractBlockVector{<:BlockIndex{1}}}, i::BlockSlice{<:Block{1}}
)
@assert length(S.indices[Block(i)]) == length(i.indices)
return _blockslice(S.blocks[Block(i)], S.indices[Block(i)])
end

# This is used in slicing like:
# a = BlockSparseArray{Float64}([2, 2, 2, 2], [2, 2, 2, 2])
# I = BlockedVector([Block(4), Block(3), Block(2), Block(1)], [2, 2])
Expand Down Expand Up @@ -185,6 +195,7 @@ const GenericBlockIndexVectorSlices = BlockIndices{
<:BlockVector{<:GenericBlockIndex{1},<:Vector{<:BlockIndexVector}}
}
const SubBlockSliceCollection = Union{
Base.Slice,
BlockIndexRangeSlice,
BlockIndexRangeSlices,
BlockIndexVectorSlices,
Expand Down
9 changes: 9 additions & 0 deletions src/BlockArraysExtensions/blockedunitrange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,15 @@ BlockArrays.Block(b::BlockIndexVector) = b.block

Base.copy(a::BlockIndexVector) = BlockIndexVector(a.block, copy.(a.indices))

# Copied from BlockArrays.BlockIndexRange.
function Base.show(io::IO, B::BlockIndexVector)
show(io, Block(B))
print(io, "[")
print_tuple_elements(io, B.indices)
print(io, "]")
end
Base.show(io::IO, ::MIME"text/plain", B::BlockIndexVector) = show(io, B)

function Base.getindex(b::AbstractBlockedUnitRange, Kkr::BlockIndexVector{1})
return b[block(Kkr)][Kkr.indices...]
end
Expand Down
6 changes: 6 additions & 0 deletions src/BlockArraysExtensions/blockrange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ function Base.getindex(r::BlockUnitRange, I::Block{1})
return eachblockaxis(r)[Int(I)] .+ (first(r.r[I]) - 1)
end

using BlockArrays: BlockedOneTo
const BlockOneTo{T<:Integer,B,CS,R<:BlockedOneTo{T,CS}} = BlockUnitRange{T,B,CS,R}
Base.axes(S::Base.Slice{<:BlockOneTo}) = (S.indices,)
Base.axes1(S::Base.Slice{<:BlockOneTo}) = S.indices
Base.unsafe_indices(S::Base.Slice{<:BlockOneTo}) = (S.indices,)

function BlockArrays.combine_blockaxes(r1::BlockUnitRange, r2::BlockUnitRange)
if eachblockaxis(r1) ≠ eachblockaxis(r2)
return throw(ArgumentError("BlockUnitRanges must have the same block axes"))
Expand Down
55 changes: 49 additions & 6 deletions src/abstractblocksparsearray/views.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,26 @@ function Base.view(
return viewblock(a, block...)
end

# Disambiguate between block reindexing of blockwise views
# (`BlockSliceCollection`) and subblockwise views (`SubBlockSliceCollection`),
# which both include `Base.Slice`.
function Base.view(
a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{Base.Slice,N}}},
block::Block{N},
) where {T,N}
return viewblock(a, block)
end

# Block reindexing of blockwise views (`BlockSliceCollection`).
function viewblock_blockslice(a::SubArray{<:Any,N}, block::Vararg{Block{1},N}) where {N}
I = CartesianIndex(Int.(block))
# TODO: Use `eachblockstoredindex`.
if I ∈ eachstoredindex(blocks(a))
return blocks(a)[I]
end
return BlockView(parent(a), Block.(Base.reindex(parentindices(blocks(a)), Tuple(I))))
end

# XXX: TODO: Distinguish if a sub-view of the block needs to be taken!
# Define a new `SubBlockSlice` which is used in:
# `@interface interface(a) to_indices(a, inds, I::Tuple{UnitRange{<:Integer},Vararg{Any}})`
Expand All @@ -199,12 +219,17 @@ function BlockArrays.viewblock(
a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockSliceCollection,N}}},
block::Vararg{Block{1},N},
) where {T,N}
I = CartesianIndex(Int.(block))
# TODO: Use `eachblockstoredindex`.
if I ∈ eachstoredindex(blocks(a))
return blocks(a)[I]
end
return BlockView(parent(a), Block.(Base.reindex(parentindices(blocks(a)), Tuple(I))))
return viewblock_blockslice(a, block...)
end

# Disambiguate between block reindexing of blockwise views
# (`BlockSliceCollection`) and subblockwise views (`SubBlockSliceCollection`),
# which both include `Base.Slice`.
function BlockArrays.viewblock(
a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{Base.Slice,N}}},
block::Vararg{Block{1},N},
) where {T,N}
return viewblock_blockslice(a, block...)
end

function to_blockindexrange(
Expand Down Expand Up @@ -291,6 +316,15 @@ function Base.view(
) where {T,N}
return viewblock(a, block...)
end
# Fix ambiguity error.
function Base.view(
a::SubArray{
T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{SubBlockSliceCollection,N}}
},
block::Vararg{Block{1},N},
) where {T,N}
return viewblock(a, block...)
end
function BlockArrays.viewblock(
a::SubArray{
T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{SubBlockSliceCollection,N}}
Expand All @@ -302,6 +336,14 @@ end

blockedslice_blocks(x::BlockSlice) = x.block
blockedslice_blocks(x::BlockIndices) = x.blocks
# Reinterpret the slice blockwise.
function blockedslice_blocks(x::Base.Slice)
return mortar(
map(BlockRange(x.indices)) do b
return BlockIndexRange(b, Base.Slice(Base.axes1(x.indices[b])))
end,
)
end

# TODO: Define `@interface interface(a) viewblock`.
function BlockArrays.viewblock(
Expand All @@ -319,6 +361,7 @@ function BlockArrays.viewblock(
end
return @view parent(a)[brs...]
end

# TODO: Define `@interface interface(a) viewblock`.
function BlockArrays.viewblock(
a::SubArray{
Expand Down
24 changes: 15 additions & 9 deletions src/blocksparsearrayinterface/blocksparsearrayinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -222,19 +222,25 @@ end

# a[mortar([Block(1)[1:2], Block(2)[1:3]]), mortar([Block(1)[1:2], Block(2)[1:3]])]
# a[[Block(1)[1:2], Block(2)[1:3]], [Block(1)[1:2], Block(2)[1:3]]]
@interface ::AbstractBlockSparseArrayInterface function Base.to_indices(
a, inds, I::Tuple{BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}},Vararg{Any}}
)
I1 = BlockIndices(I[1], blockedunitrange_getindices(inds[1], I[1]))
return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...)
end

@interface ::AbstractBlockSparseArrayInterface function Base.to_indices(
a,
inds,
I::Tuple{BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexVector{1}}},Vararg{Any}},
I::Tuple{
BlockVector{<:BlockIndex{1},<:Vector{<:Union{BlockIndexRange{1},BlockIndexVector{1}}}},
Vararg{Any},
},
)
I1 = BlockIndices(I[1], blockedunitrange_getindices(inds[1], I[1]))
# Index the `inds` by the `BlockIndexRange`/`BlockIndexVector`s on each block
# in order to canonicalize the indices and preserve metadata,
# such as sector data for symmetric tensors.
bs = mortar(
map(blocks(I[1])) do bi
b = Block(bi)
binds = only(bi.indices)
return BlockIndexVector(b, Base.axes1(inds[1][b])[binds])
end,
)
I1 = BlockIndices(bs, blockedunitrange_getindices(inds[1], I[1]))
return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...)
end
@interface ::AbstractBlockSparseArrayInterface function Base.to_indices(
Expand Down
Loading