Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 33 additions & 6 deletions base/reinterpretarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,18 @@ check_ptr_indexable(a::AbstractArray, sz) = false
@propagate_inbounds isassigned(a::ReinterpretArray, inds::SCartesianIndex2) = isassigned(a.parent, inds.j)
@propagate_inbounds _isassigned_ra(a::ReinterpretArray, inds...) = true # that is not entirely true, but computing exactly which indexes will be accessed in the parent requires a lot of duplication from the _getindex_ra code

@inline @propagate_inbounds function getindex(a::ReinterpretArray{T,1,S}, i::Int) where {T,S}
check_readable(a)
if check_ptr_indexable(a)
@boundscheck checkbounds(a, i)
li = _to_linear_index(a, i)
GC.@preserve a begin
return unsafe_load(pointer(a), li)
end
end
_getindex_ra(a, i, ())
end

@propagate_inbounds function getindex(a::ReinterpretArray{T,N,S}, inds::Vararg{Int, N}) where {T,N,S}
check_readable(a)
check_ptr_indexable(a) && return _getindex_ptr(a, inds...)
Expand Down Expand Up @@ -440,9 +452,9 @@ end
@inline function _getindex_ptr(a::ReinterpretArray{T}, inds...) where {T}
@boundscheck checkbounds(a, inds...)
li = _to_linear_index(a, inds...)
ap = cconvert(Ptr{T}, a)
p = unsafe_convert(Ptr{T}, ap) + sizeof(T) * (li - 1)
GC.@preserve ap return unsafe_load(p)
GC.@preserve a begin
return unsafe_load(pointer(a), li)
end
end

@propagate_inbounds function _getindex_ra(a::NonReshapedReinterpretArray{T,N,S}, i1::Int, tailinds::TT) where {T,N,S,TT}
Expand Down Expand Up @@ -558,6 +570,21 @@ end

@propagate_inbounds setindex!(a::ReshapedReinterpretArray{T,0}, v) where {T} = setindex!(a, v, firstindex(a))

# Specialized 1D version for SIMD optimization
# This avoids Vararg dispatch and enables better vectorization
@inline @propagate_inbounds function setindex!(a::ReinterpretArray{T,1,S}, v, i::Int) where {T,S}
check_writable(a)
if check_ptr_indexable(a)
@boundscheck checkbounds(a, i)
li = _to_linear_index(a, i)
GC.@preserve a begin
unsafe_store!(pointer(a), v, li)
end
return a
end
_setindex_ra!(a, v, i, ())
end

@propagate_inbounds function setindex!(a::ReinterpretArray{T,N,S}, v, inds::Vararg{Int, N}) where {T,N,S}
check_writable(a)
check_ptr_indexable(a) && return _setindex_ptr!(a, v, inds...)
Expand Down Expand Up @@ -589,9 +616,9 @@ end
@inline function _setindex_ptr!(a::ReinterpretArray{T}, v, inds...) where {T}
@boundscheck checkbounds(a, inds...)
li = _to_linear_index(a, inds...)
ap = cconvert(Ptr{T}, a)
p = unsafe_convert(Ptr{T}, ap) + sizeof(T) * (li - 1)
GC.@preserve ap unsafe_store!(p, v)
GC.@preserve a begin
unsafe_store!(pointer(a), v, li)
end
return a
end

Expand Down