Skip to content

Commit 8ddcb4a

Browse files
Merge pull request #327 from AayushSabharwal/as/checkbounds
fix: fix checkbounds, view methods, indexing, and add tests
2 parents 33dd4ea + 9afae5a commit 8ddcb4a

File tree

2 files changed

+39
-1
lines changed

2 files changed

+39
-1
lines changed

src/vector_of_array.jl

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,9 @@ Base.IndexStyle(::Type{<:AbstractVectorOfArray}) = IndexCartesian()
220220
@inline function Base.eachindex(VA::AbstractVectorOfArray)
221221
return eachindex(VA.u)
222222
end
223+
@inline function Base.eachindex(::IndexLinear, VA::AbstractVectorOfArray{T,N,<:AbstractVector{T}}) where {T, N}
224+
return eachindex(IndexLinear(), VA.u)
225+
end
223226
@inline Base.IteratorSize(::Type{<:AbstractVectorOfArray}) = Base.HasLength()
224227
@inline Base.first(VA::AbstractVectorOfArray) = first(VA.u)
225228
@inline Base.last(VA::AbstractVectorOfArray) = last(VA.u)
@@ -245,7 +248,11 @@ __parameterless_type(T) = Base.typename(T).wrapper
245248
Base.@propagate_inbounds function _getindex(A::AbstractVectorOfArray{T, N},
246249
::NotSymbolic, I::Colon...) where {T, N}
247250
@assert length(I) == ndims(A.u[1]) + 1
248-
vecs = vec.(A.u)
251+
vecs = if N == 1
252+
A.u
253+
else
254+
vec.(A.u)
255+
end
249256
return Adapt.adapt(__parameterless_type(T),
250257
reshape(reduce(hcat, vecs), size(A.u[1])..., length(A.u)))
251258
end
@@ -496,6 +503,16 @@ function Base.stack(VA::AbstractVectorOfArray; dims = :)
496503
end
497504

498505
# AbstractArray methods
506+
function Base.view(A::AbstractVectorOfArray{T,N,<:AbstractVector{T}}, I::Vararg{Any, M}) where {T,N,M}
507+
@inline
508+
if length(I) == 1
509+
J = map(i->Base.unalias(A,i), to_indices(A, I))
510+
elseif length(I) == 2 && (I[1] == Colon() || I[1] == 1)
511+
J = map(i->Base.unalias(A,i), to_indices(A, Base.tail(I)))
512+
end
513+
@boundscheck checkbounds(A, J...)
514+
SubArray(IndexStyle(A), A, J, Base.index_dimsum(J...))
515+
end
499516
function Base.view(A::AbstractVectorOfArray, I::Vararg{Any,M}) where {M}
500517
@inline
501518
J = map(i->Base.unalias(A,i), to_indices(A, I))
@@ -509,6 +526,13 @@ end
509526
Base.isassigned(VA::AbstractVectorOfArray, idxs...) = checkbounds(Bool, VA, idxs...)
510527
Base.check_parent_index_match(::RecursiveArrayTools.AbstractVectorOfArray{T,N}, ::NTuple{N,Bool}) where {T,N} = nothing
511528
Base.ndims(::AbstractVectorOfArray{T, N}) where {T, N} = N
529+
530+
function Base.checkbounds(::Type{Bool}, VA::AbstractVectorOfArray{T, N, <:AbstractVector{T}}, idxs...) where {T, N}
531+
if length(idxs) == 2 && (idxs[1] == Colon() || idxs[1] == 1)
532+
return checkbounds(Bool, VA.u, idxs[2])
533+
end
534+
return checkbounds(Bool, VA.u, idxs...)
535+
end
512536
function Base.checkbounds(::Type{Bool}, VA::AbstractVectorOfArray, idx...)
513537
if checkbounds(Bool, VA.u, last(idx))
514538
if last(idx) isa Integer

test/interface_tests.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,20 @@ for idxs in [(2, 2, :), (2, :, 2), (:, 2, 2), (:, :, 2), (:, 2, :), (2, : ,:), (
9191
@test all(arr_view .== voa_view)
9292
end
9393

94+
testvc = VectorOfArray(collect(1:10))
95+
arrvc = Array(testvc)
96+
for (voaidx, arridx) in [
97+
((:,), (:,)),
98+
((3:5,), (3:5,)),
99+
((:, 3:5), (3:5,)),
100+
((1, 3:5), (3:5,)),
101+
]
102+
arr_view = view(arrvc, arridx...)
103+
voa_view = view(testvc, voaidx...)
104+
@test size(arr_view) == size(voa_view)
105+
@test all(arr_view .== voa_view)
106+
end
107+
94108
# test stack
95109
@test stack(testva) == [1 4 7; 2 5 8; 3 6 9]
96110
@test stack(testva; dims = 1) == [1 2 3; 4 5 6; 7 8 9]

0 commit comments

Comments
 (0)