diff --git a/Project.toml b/Project.toml index 7ac7a30..7f0208a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "BlockSparseArrays" uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" authors = ["ITensor developers and contributors"] -version = "0.8.0" +version = "0.8.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/BlockArraysExtensions/BlockArraysExtensions.jl b/src/BlockArraysExtensions/BlockArraysExtensions.jl index 4a194d2..5815540 100644 --- a/src/BlockArraysExtensions/BlockArraysExtensions.jl +++ b/src/BlockArraysExtensions/BlockArraysExtensions.jl @@ -88,6 +88,9 @@ function _blockslice(x, y::AbstractVector) return BlockIndices(x, y) end +# TODO: Constrain the type of `BlockIndices` more, this seems +# to assume that `S.blocks` is a list of blocks as opposed to +# a flat list of block indices like the definition below. function Base.getindex(S::BlockIndices, i::BlockSlice{<:Block{1}}) # TODO: Check that `i.indices` is consistent with `S.indices`. # It seems like this isn't handling the case where `i` is a @@ -96,6 +99,13 @@ function Base.getindex(S::BlockIndices, i::BlockSlice{<:Block{1}}) return _blockslice(S.blocks[Int(Block(i))], S.indices[Block(i)]) end +function Base.getindex( + S::BlockIndices{<:AbstractBlockVector{<:BlockIndex{1}}}, i::BlockSlice{<:Block{1}} +) + @assert length(S.indices[Block(i)]) == length(i.indices) + return _blockslice(S.blocks[Block(i)], S.indices[Block(i)]) +end + # This is used in slicing like: # a = BlockSparseArray{Float64}([2, 2, 2, 2], [2, 2, 2, 2]) # I = BlockedVector([Block(4), Block(3), Block(2), Block(1)], [2, 2]) @@ -185,6 +195,7 @@ const GenericBlockIndexVectorSlices = BlockIndices{ <:BlockVector{<:GenericBlockIndex{1},<:Vector{<:BlockIndexVector}} } const SubBlockSliceCollection = Union{ + Base.Slice, BlockIndexRangeSlice, BlockIndexRangeSlices, BlockIndexVectorSlices, diff --git a/src/BlockArraysExtensions/blockedunitrange.jl b/src/BlockArraysExtensions/blockedunitrange.jl index 604bef0..b111c83 100644 --- a/src/BlockArraysExtensions/blockedunitrange.jl +++ b/src/BlockArraysExtensions/blockedunitrange.jl @@ -349,6 +349,15 @@ BlockArrays.Block(b::BlockIndexVector) = b.block Base.copy(a::BlockIndexVector) = BlockIndexVector(a.block, copy.(a.indices)) +# Copied from BlockArrays.BlockIndexRange. +function Base.show(io::IO, B::BlockIndexVector) + show(io, Block(B)) + print(io, "[") + print_tuple_elements(io, B.indices) + print(io, "]") +end +Base.show(io::IO, ::MIME"text/plain", B::BlockIndexVector) = show(io, B) + function Base.getindex(b::AbstractBlockedUnitRange, Kkr::BlockIndexVector{1}) return b[block(Kkr)][Kkr.indices...] end diff --git a/src/BlockArraysExtensions/blockrange.jl b/src/BlockArraysExtensions/blockrange.jl index 7edd013..a1f291e 100644 --- a/src/BlockArraysExtensions/blockrange.jl +++ b/src/BlockArraysExtensions/blockrange.jl @@ -16,6 +16,12 @@ function Base.getindex(r::BlockUnitRange, I::Block{1}) return eachblockaxis(r)[Int(I)] .+ (first(r.r[I]) - 1) end +using BlockArrays: BlockedOneTo +const BlockOneTo{T<:Integer,B,CS,R<:BlockedOneTo{T,CS}} = BlockUnitRange{T,B,CS,R} +Base.axes(S::Base.Slice{<:BlockOneTo}) = (S.indices,) +Base.axes1(S::Base.Slice{<:BlockOneTo}) = S.indices +Base.unsafe_indices(S::Base.Slice{<:BlockOneTo}) = (S.indices,) + function BlockArrays.combine_blockaxes(r1::BlockUnitRange, r2::BlockUnitRange) if eachblockaxis(r1) ≠ eachblockaxis(r2) return throw(ArgumentError("BlockUnitRanges must have the same block axes")) diff --git a/src/abstractblocksparsearray/views.jl b/src/abstractblocksparsearray/views.jl index 084cb2c..19de8a3 100644 --- a/src/abstractblocksparsearray/views.jl +++ b/src/abstractblocksparsearray/views.jl @@ -190,6 +190,26 @@ function Base.view( return viewblock(a, block...) end +# Disambiguate between block reindexing of blockwise views +# (`BlockSliceCollection`) and subblockwise views (`SubBlockSliceCollection`), +# which both include `Base.Slice`. +function Base.view( + a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{Base.Slice,N}}}, + block::Block{N}, +) where {T,N} + return viewblock(a, block) +end + +# Block reindexing of blockwise views (`BlockSliceCollection`). +function viewblock_blockslice(a::SubArray{<:Any,N}, block::Vararg{Block{1},N}) where {N} + I = CartesianIndex(Int.(block)) + # TODO: Use `eachblockstoredindex`. + if I ∈ eachstoredindex(blocks(a)) + return blocks(a)[I] + end + return BlockView(parent(a), Block.(Base.reindex(parentindices(blocks(a)), Tuple(I)))) +end + # XXX: TODO: Distinguish if a sub-view of the block needs to be taken! # Define a new `SubBlockSlice` which is used in: # `@interface interface(a) to_indices(a, inds, I::Tuple{UnitRange{<:Integer},Vararg{Any}})` @@ -199,12 +219,17 @@ function BlockArrays.viewblock( a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockSliceCollection,N}}}, block::Vararg{Block{1},N}, ) where {T,N} - I = CartesianIndex(Int.(block)) - # TODO: Use `eachblockstoredindex`. - if I ∈ eachstoredindex(blocks(a)) - return blocks(a)[I] - end - return BlockView(parent(a), Block.(Base.reindex(parentindices(blocks(a)), Tuple(I)))) + return viewblock_blockslice(a, block...) +end + +# Disambiguate between block reindexing of blockwise views +# (`BlockSliceCollection`) and subblockwise views (`SubBlockSliceCollection`), +# which both include `Base.Slice`. +function BlockArrays.viewblock( + a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{Base.Slice,N}}}, + block::Vararg{Block{1},N}, +) where {T,N} + return viewblock_blockslice(a, block...) end function to_blockindexrange( @@ -291,6 +316,15 @@ function Base.view( ) where {T,N} return viewblock(a, block...) end +# Fix ambiguity error. +function Base.view( + a::SubArray{ + T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{SubBlockSliceCollection,N}} + }, + block::Vararg{Block{1},N}, +) where {T,N} + return viewblock(a, block...) +end function BlockArrays.viewblock( a::SubArray{ T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{SubBlockSliceCollection,N}} @@ -302,6 +336,14 @@ end blockedslice_blocks(x::BlockSlice) = x.block blockedslice_blocks(x::BlockIndices) = x.blocks +# Reinterpret the slice blockwise. +function blockedslice_blocks(x::Base.Slice) + return mortar( + map(BlockRange(x.indices)) do b + return BlockIndexRange(b, Base.Slice(Base.axes1(x.indices[b]))) + end, + ) +end # TODO: Define `@interface interface(a) viewblock`. function BlockArrays.viewblock( @@ -319,6 +361,7 @@ function BlockArrays.viewblock( end return @view parent(a)[brs...] end + # TODO: Define `@interface interface(a) viewblock`. function BlockArrays.viewblock( a::SubArray{ diff --git a/src/blocksparsearrayinterface/blocksparsearrayinterface.jl b/src/blocksparsearrayinterface/blocksparsearrayinterface.jl index 9ca44e5..30d75ba 100644 --- a/src/blocksparsearrayinterface/blocksparsearrayinterface.jl +++ b/src/blocksparsearrayinterface/blocksparsearrayinterface.jl @@ -222,19 +222,25 @@ end # a[mortar([Block(1)[1:2], Block(2)[1:3]]), mortar([Block(1)[1:2], Block(2)[1:3]])] # a[[Block(1)[1:2], Block(2)[1:3]], [Block(1)[1:2], Block(2)[1:3]]] -@interface ::AbstractBlockSparseArrayInterface function Base.to_indices( - a, inds, I::Tuple{BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}},Vararg{Any}} -) - I1 = BlockIndices(I[1], blockedunitrange_getindices(inds[1], I[1])) - return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...) -end - @interface ::AbstractBlockSparseArrayInterface function Base.to_indices( a, inds, - I::Tuple{BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexVector{1}}},Vararg{Any}}, + I::Tuple{ + BlockVector{<:BlockIndex{1},<:Vector{<:Union{BlockIndexRange{1},BlockIndexVector{1}}}}, + Vararg{Any}, + }, ) - I1 = BlockIndices(I[1], blockedunitrange_getindices(inds[1], I[1])) + # Index the `inds` by the `BlockIndexRange`/`BlockIndexVector`s on each block + # in order to canonicalize the indices and preserve metadata, + # such as sector data for symmetric tensors. + bs = mortar( + map(blocks(I[1])) do bi + b = Block(bi) + binds = only(bi.indices) + return BlockIndexVector(b, Base.axes1(inds[1][b])[binds]) + end, + ) + I1 = BlockIndices(bs, blockedunitrange_getindices(inds[1], I[1])) return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...) end @interface ::AbstractBlockSparseArrayInterface function Base.to_indices(