Skip to content

Commit 83ed436

Browse files
authored
More general truncation and slicing (#159)
1 parent 97a3c14 commit 83ed436

File tree

6 files changed

+91
-16
lines changed

6 files changed

+91
-16
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.8.0"
4+
version = "0.8.1"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/BlockArraysExtensions/BlockArraysExtensions.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,9 @@ function _blockslice(x, y::AbstractVector)
8888
return BlockIndices(x, y)
8989
end
9090

91+
# TODO: Constrain the type of `BlockIndices` more, this seems
92+
# to assume that `S.blocks` is a list of blocks as opposed to
93+
# a flat list of block indices like the definition below.
9194
function Base.getindex(S::BlockIndices, i::BlockSlice{<:Block{1}})
9295
# TODO: Check that `i.indices` is consistent with `S.indices`.
9396
# It seems like this isn't handling the case where `i` is a
@@ -96,6 +99,13 @@ function Base.getindex(S::BlockIndices, i::BlockSlice{<:Block{1}})
9699
return _blockslice(S.blocks[Int(Block(i))], S.indices[Block(i)])
97100
end
98101

102+
function Base.getindex(
103+
S::BlockIndices{<:AbstractBlockVector{<:BlockIndex{1}}}, i::BlockSlice{<:Block{1}}
104+
)
105+
@assert length(S.indices[Block(i)]) == length(i.indices)
106+
return _blockslice(S.blocks[Block(i)], S.indices[Block(i)])
107+
end
108+
99109
# This is used in slicing like:
100110
# a = BlockSparseArray{Float64}([2, 2, 2, 2], [2, 2, 2, 2])
101111
# I = BlockedVector([Block(4), Block(3), Block(2), Block(1)], [2, 2])
@@ -185,6 +195,7 @@ const GenericBlockIndexVectorSlices = BlockIndices{
185195
<:BlockVector{<:GenericBlockIndex{1},<:Vector{<:BlockIndexVector}}
186196
}
187197
const SubBlockSliceCollection = Union{
198+
Base.Slice,
188199
BlockIndexRangeSlice,
189200
BlockIndexRangeSlices,
190201
BlockIndexVectorSlices,

src/BlockArraysExtensions/blockedunitrange.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,15 @@ BlockArrays.Block(b::BlockIndexVector) = b.block
349349

350350
Base.copy(a::BlockIndexVector) = BlockIndexVector(a.block, copy.(a.indices))
351351

352+
# Copied from BlockArrays.BlockIndexRange.
353+
function Base.show(io::IO, B::BlockIndexVector)
354+
show(io, Block(B))
355+
print(io, "[")
356+
print_tuple_elements(io, B.indices)
357+
print(io, "]")
358+
end
359+
Base.show(io::IO, ::MIME"text/plain", B::BlockIndexVector) = show(io, B)
360+
352361
function Base.getindex(b::AbstractBlockedUnitRange, Kkr::BlockIndexVector{1})
353362
return b[block(Kkr)][Kkr.indices...]
354363
end

src/BlockArraysExtensions/blockrange.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@ function Base.getindex(r::BlockUnitRange, I::Block{1})
1616
return eachblockaxis(r)[Int(I)] .+ (first(r.r[I]) - 1)
1717
end
1818

19+
using BlockArrays: BlockedOneTo
20+
const BlockOneTo{T<:Integer,B,CS,R<:BlockedOneTo{T,CS}} = BlockUnitRange{T,B,CS,R}
21+
Base.axes(S::Base.Slice{<:BlockOneTo}) = (S.indices,)
22+
Base.axes1(S::Base.Slice{<:BlockOneTo}) = S.indices
23+
Base.unsafe_indices(S::Base.Slice{<:BlockOneTo}) = (S.indices,)
24+
1925
function BlockArrays.combine_blockaxes(r1::BlockUnitRange, r2::BlockUnitRange)
2026
if eachblockaxis(r1) eachblockaxis(r2)
2127
return throw(ArgumentError("BlockUnitRanges must have the same block axes"))

src/abstractblocksparsearray/views.jl

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,26 @@ function Base.view(
190190
return viewblock(a, block...)
191191
end
192192

193+
# Disambiguate between block reindexing of blockwise views
194+
# (`BlockSliceCollection`) and subblockwise views (`SubBlockSliceCollection`),
195+
# which both include `Base.Slice`.
196+
function Base.view(
197+
a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{Base.Slice,N}}},
198+
block::Block{N},
199+
) where {T,N}
200+
return viewblock(a, block)
201+
end
202+
203+
# Block reindexing of blockwise views (`BlockSliceCollection`).
204+
function viewblock_blockslice(a::SubArray{<:Any,N}, block::Vararg{Block{1},N}) where {N}
205+
I = CartesianIndex(Int.(block))
206+
# TODO: Use `eachblockstoredindex`.
207+
if I eachstoredindex(blocks(a))
208+
return blocks(a)[I]
209+
end
210+
return BlockView(parent(a), Block.(Base.reindex(parentindices(blocks(a)), Tuple(I))))
211+
end
212+
193213
# XXX: TODO: Distinguish if a sub-view of the block needs to be taken!
194214
# Define a new `SubBlockSlice` which is used in:
195215
# `@interface interface(a) to_indices(a, inds, I::Tuple{UnitRange{<:Integer},Vararg{Any}})`
@@ -199,12 +219,17 @@ function BlockArrays.viewblock(
199219
a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockSliceCollection,N}}},
200220
block::Vararg{Block{1},N},
201221
) where {T,N}
202-
I = CartesianIndex(Int.(block))
203-
# TODO: Use `eachblockstoredindex`.
204-
if I eachstoredindex(blocks(a))
205-
return blocks(a)[I]
206-
end
207-
return BlockView(parent(a), Block.(Base.reindex(parentindices(blocks(a)), Tuple(I))))
222+
return viewblock_blockslice(a, block...)
223+
end
224+
225+
# Disambiguate between block reindexing of blockwise views
226+
# (`BlockSliceCollection`) and subblockwise views (`SubBlockSliceCollection`),
227+
# which both include `Base.Slice`.
228+
function BlockArrays.viewblock(
229+
a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{Base.Slice,N}}},
230+
block::Vararg{Block{1},N},
231+
) where {T,N}
232+
return viewblock_blockslice(a, block...)
208233
end
209234

210235
function to_blockindexrange(
@@ -291,6 +316,15 @@ function Base.view(
291316
) where {T,N}
292317
return viewblock(a, block...)
293318
end
319+
# Fix ambiguity error.
320+
function Base.view(
321+
a::SubArray{
322+
T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{SubBlockSliceCollection,N}}
323+
},
324+
block::Vararg{Block{1},N},
325+
) where {T,N}
326+
return viewblock(a, block...)
327+
end
294328
function BlockArrays.viewblock(
295329
a::SubArray{
296330
T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{SubBlockSliceCollection,N}}
@@ -302,6 +336,14 @@ end
302336

303337
blockedslice_blocks(x::BlockSlice) = x.block
304338
blockedslice_blocks(x::BlockIndices) = x.blocks
339+
# Reinterpret the slice blockwise.
340+
function blockedslice_blocks(x::Base.Slice)
341+
return mortar(
342+
map(BlockRange(x.indices)) do b
343+
return BlockIndexRange(b, Base.Slice(Base.axes1(x.indices[b])))
344+
end,
345+
)
346+
end
305347

306348
# TODO: Define `@interface interface(a) viewblock`.
307349
function BlockArrays.viewblock(
@@ -319,6 +361,7 @@ function BlockArrays.viewblock(
319361
end
320362
return @view parent(a)[brs...]
321363
end
364+
322365
# TODO: Define `@interface interface(a) viewblock`.
323366
function BlockArrays.viewblock(
324367
a::SubArray{

src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -222,19 +222,25 @@ end
222222

223223
# a[mortar([Block(1)[1:2], Block(2)[1:3]]), mortar([Block(1)[1:2], Block(2)[1:3]])]
224224
# a[[Block(1)[1:2], Block(2)[1:3]], [Block(1)[1:2], Block(2)[1:3]]]
225-
@interface ::AbstractBlockSparseArrayInterface function Base.to_indices(
226-
a, inds, I::Tuple{BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}},Vararg{Any}}
227-
)
228-
I1 = BlockIndices(I[1], blockedunitrange_getindices(inds[1], I[1]))
229-
return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...)
230-
end
231-
232225
@interface ::AbstractBlockSparseArrayInterface function Base.to_indices(
233226
a,
234227
inds,
235-
I::Tuple{BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexVector{1}}},Vararg{Any}},
228+
I::Tuple{
229+
BlockVector{<:BlockIndex{1},<:Vector{<:Union{BlockIndexRange{1},BlockIndexVector{1}}}},
230+
Vararg{Any},
231+
},
236232
)
237-
I1 = BlockIndices(I[1], blockedunitrange_getindices(inds[1], I[1]))
233+
# Index the `inds` by the `BlockIndexRange`/`BlockIndexVector`s on each block
234+
# in order to canonicalize the indices and preserve metadata,
235+
# such as sector data for symmetric tensors.
236+
bs = mortar(
237+
map(blocks(I[1])) do bi
238+
b = Block(bi)
239+
binds = only(bi.indices)
240+
return BlockIndexVector(b, Base.axes1(inds[1][b])[binds])
241+
end,
242+
)
243+
I1 = BlockIndices(bs, blockedunitrange_getindices(inds[1], I[1]))
238244
return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...)
239245
end
240246
@interface ::AbstractBlockSparseArrayInterface function Base.to_indices(

0 commit comments

Comments
 (0)