Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 91 additions & 96 deletions dpctl/tensor/libtensor/include/kernels/integer_advanced_indexing.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <type_traits>

#include "dpctl_tensor_types.hpp"
#include "utils/indexing_utils.hpp"
#include "utils/offset_utils.hpp"
#include "utils/type_utils.hpp"

Expand All @@ -42,54 +43,10 @@ namespace kernels
namespace indexing
{

using namespace dpctl::tensor::offset_utils;

// Kernel-name tag for the "take" kernel. It is only declared, never defined;
// the six type parameters make the name distinct for each combination of
// projector, the three strider types, the element type and the index type.
template <typename ProjectorT, typename OrthogStrider,
          typename IndicesStrider, typename AxesStrider,
          typename T, typename indT>
class take_kernel;
// Kernel-name tag for the "put" kernel. It is only declared, never defined;
// the six type parameters make the name distinct for each combination of
// projector, the three strider types, the element type and the index type.
template <typename ProjectorT, typename OrthogStrider,
          typename IndicesStrider, typename AxesStrider,
          typename T, typename indT>
class put_kernel;

// Projector that maps an index onto an axis, letting negative values count
// from the end of the axis.
//
// The axis length is first forced to be at least 1. The index is then clamped
// to [-length, length - 1], and a negative result is shifted up by length, so
// the value written back through `ind` always lies in [0, length - 1].
class WrapIndex
{
public:
    WrapIndex() = default;

    // max_item: axis length (values below 1 are treated as 1).
    // ind:      index to project; overwritten in place with the result.
    void operator()(ssize_t max_item, ssize_t &ind) const
    {
        const ssize_t extent = std::max<ssize_t>(max_item, 1);
        const ssize_t bounded =
            sycl::clamp<ssize_t>(ind, -extent, extent - 1);

        if (bounded < 0) {
            // negative indices address the axis from its end
            ind = bounded + extent;
        }
        else {
            ind = bounded;
        }
    }
};

// Projector that saturates an index at the ends of an axis.
//
// The axis length is first forced to be at least 1; the index is then clamped
// to [0, length - 1] and written back through `ind`.
class ClipIndex
{
public:
    ClipIndex() = default;

    // max_item: axis length (values below 1 are treated as 1).
    // ind:      index to project; overwritten in place with the result.
    void operator()(ssize_t max_item, ssize_t &ind) const
    {
        const ssize_t extent = std::max<ssize_t>(max_item, 1);
        const ssize_t last_valid = extent - 1;

        ind = sycl::clamp<ssize_t>(ind, ssize_t(0), last_valid);
    }
};

template <typename ProjectorT,
typename OrthogStrider,
typename IndicesStrider,
typename AxesStrider,
typename OrthogIndexer,
typename IndicesIndexer,
typename AxesIndexer,
typename T,
typename indT>
class TakeFunctor
Expand All @@ -101,9 +58,9 @@ class TakeFunctor
int k_ = 0;
size_t ind_nelems_ = 0;
const ssize_t *axes_shape_and_strides_ = nullptr;
const OrthogStrider orthog_strider;
const IndicesStrider ind_strider;
const AxesStrider axes_strider;
const OrthogIndexer orthog_strider;
const IndicesIndexer ind_strider;
const AxesIndexer axes_strider;

public:
TakeFunctor(const char *src_cp,
Expand All @@ -112,9 +69,9 @@ class TakeFunctor
int k,
size_t ind_nelems,
const ssize_t *axes_shape_and_strides,
const OrthogStrider &orthog_strider_,
const IndicesStrider &ind_strider_,
const AxesStrider &axes_strider_)
const OrthogIndexer &orthog_strider_,
const IndicesIndexer &ind_strider_,
const AxesIndexer &axes_strider_)
: src_(src_cp), dst_(dst_cp), ind_(ind_cp), k_(k),
ind_nelems_(ind_nelems),
axes_shape_and_strides_(axes_shape_and_strides),
Expand All @@ -141,11 +98,11 @@ class TakeFunctor
indT *ind_data = reinterpret_cast<indT *>(ind_[axis_idx]);

ssize_t ind_offset = ind_strider(i_along, axis_idx);
ssize_t i = static_cast<ssize_t>(ind_data[ind_offset]);

proj(axes_shape_and_strides_[axis_idx], i);

src_offset += i * axes_shape_and_strides_[k_ + axis_idx];
// proj produces an index in the range of the given axis
ssize_t projected_idx =
proj(axes_shape_and_strides_[axis_idx], ind_data[ind_offset]);
src_offset +=
projected_idx * axes_shape_and_strides_[k_ + axis_idx];
}

dst_offset += axes_strider(i_along);
Expand All @@ -154,6 +111,14 @@ class TakeFunctor
}
};

// Kernel-name tag for the "take" kernel. It is only declared, never defined;
// the six type parameters make the name distinct for each combination of
// projector, the three indexer types, the element type and the index type.
template <typename ProjectorT, typename OrthogIndexer,
          typename IndicesIndexer, typename AxesIndexer,
          typename T, typename indT>
class take_kernel;

typedef sycl::event (*take_fn_ptr_t)(sycl::queue &,
size_t,
size_t,
Expand Down Expand Up @@ -194,21 +159,29 @@ sycl::event take_impl(sycl::queue &q,
sycl::event take_ev = q.submit([&](sycl::handler &cgh) {
cgh.depends_on(depends);

const TwoOffsets_StridedIndexer orthog_indexer{
nd, src_offset, dst_offset, orthog_shape_and_strides};
const NthStrideOffset indices_indexer{ind_nd, ind_offsets,
ind_shape_and_strides};
const StridedIndexer axes_indexer{ind_nd, 0,
axes_shape_and_strides + (2 * k)};
using OrthogIndexerT =
dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
const OrthogIndexerT orthog_indexer{nd, src_offset, dst_offset,
orthog_shape_and_strides};

using NthStrideIndexerT = dpctl::tensor::offset_utils::NthStrideOffset;
const NthStrideIndexerT indices_indexer{ind_nd, ind_offsets,
ind_shape_and_strides};

using AxesIndexerT = dpctl::tensor::offset_utils::StridedIndexer;
const AxesIndexerT axes_indexer{ind_nd, 0,
axes_shape_and_strides + (2 * k)};

using KernelName =
take_kernel<ProjectorT, OrthogIndexerT, NthStrideIndexerT,
AxesIndexerT, Ty, indT>;

const size_t gws = orthog_nelems * ind_nelems;

cgh.parallel_for<
take_kernel<ProjectorT, TwoOffsets_StridedIndexer, NthStrideOffset,
StridedIndexer, Ty, indT>>(
cgh.parallel_for<KernelName>(
sycl::range<1>(gws),
TakeFunctor<ProjectorT, TwoOffsets_StridedIndexer, NthStrideOffset,
StridedIndexer, Ty, indT>(
TakeFunctor<ProjectorT, OrthogIndexerT, NthStrideIndexerT,
AxesIndexerT, Ty, indT>(
src_p, dst_p, ind_p, k, ind_nelems, axes_shape_and_strides,
orthog_indexer, indices_indexer, axes_indexer));
});
Expand All @@ -217,9 +190,9 @@ sycl::event take_impl(sycl::queue &q,
}

template <typename ProjectorT,
typename OrthogStrider,
typename IndicesStrider,
typename AxesStrider,
typename OrthogIndexer,
typename IndicesIndexer,
typename AxesIndexer,
typename T,
typename indT>
class PutFunctor
Expand All @@ -231,9 +204,9 @@ class PutFunctor
int k_ = 0;
size_t ind_nelems_ = 0;
const ssize_t *axes_shape_and_strides_ = nullptr;
const OrthogStrider orthog_strider;
const IndicesStrider ind_strider;
const AxesStrider axes_strider;
const OrthogIndexer orthog_strider;
const IndicesIndexer ind_strider;
const AxesIndexer axes_strider;

public:
PutFunctor(char *dst_cp,
Expand All @@ -242,9 +215,9 @@ class PutFunctor
int k,
size_t ind_nelems,
const ssize_t *axes_shape_and_strides,
const OrthogStrider &orthog_strider_,
const IndicesStrider &ind_strider_,
const AxesStrider &axes_strider_)
const OrthogIndexer &orthog_strider_,
const IndicesIndexer &ind_strider_,
const AxesIndexer &axes_strider_)
: dst_(dst_cp), val_(val_cp), ind_(ind_cp), k_(k),
ind_nelems_(ind_nelems),
axes_shape_and_strides_(axes_shape_and_strides),
Expand All @@ -271,11 +244,12 @@ class PutFunctor
indT *ind_data = reinterpret_cast<indT *>(ind_[axis_idx]);

ssize_t ind_offset = ind_strider(i_along, axis_idx);
ssize_t i = static_cast<ssize_t>(ind_data[ind_offset]);

proj(axes_shape_and_strides_[axis_idx], i);

dst_offset += i * axes_shape_and_strides_[k_ + axis_idx];
// proj produces an index in the range of the given axis
ssize_t projected_idx =
proj(axes_shape_and_strides_[axis_idx], ind_data[ind_offset]);
dst_offset +=
projected_idx * axes_shape_and_strides_[k_ + axis_idx];
}

val_offset += axes_strider(i_along);
Expand All @@ -284,6 +258,14 @@ class PutFunctor
}
};

// Kernel-name tag for the "put" kernel. It is only declared, never defined;
// the six type parameters make the name distinct for each combination of
// projector, the three indexer types, the element type and the index type.
template <typename ProjectorT, typename OrthogIndexer,
          typename IndicesIndexer, typename AxesIndexer,
          typename T, typename indT>
class put_kernel;

typedef sycl::event (*put_fn_ptr_t)(sycl::queue &,
size_t,
size_t,
Expand Down Expand Up @@ -324,20 +306,29 @@ sycl::event put_impl(sycl::queue &q,
sycl::event put_ev = q.submit([&](sycl::handler &cgh) {
cgh.depends_on(depends);

const TwoOffsets_StridedIndexer orthog_indexer{
nd, dst_offset, val_offset, orthog_shape_and_strides};
const NthStrideOffset indices_indexer{ind_nd, ind_offsets,
ind_shape_and_strides};
const StridedIndexer axes_indexer{ind_nd, 0,
axes_shape_and_strides + (2 * k)};
using OrthogIndexerT =
dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
const OrthogIndexerT orthog_indexer{nd, dst_offset, val_offset,
orthog_shape_and_strides};

using NthStrideIndexerT = dpctl::tensor::offset_utils::NthStrideOffset;
const NthStrideIndexerT indices_indexer{ind_nd, ind_offsets,
ind_shape_and_strides};

using AxesIndexerT = dpctl::tensor::offset_utils::StridedIndexer;
const AxesIndexerT axes_indexer{ind_nd, 0,
axes_shape_and_strides + (2 * k)};

using KernelName =
put_kernel<ProjectorT, OrthogIndexerT, NthStrideIndexerT,
AxesIndexerT, Ty, indT>;

const size_t gws = orthog_nelems * ind_nelems;

cgh.parallel_for<put_kernel<ProjectorT, TwoOffsets_StridedIndexer,
NthStrideOffset, StridedIndexer, Ty, indT>>(
cgh.parallel_for<KernelName>(
sycl::range<1>(gws),
PutFunctor<ProjectorT, TwoOffsets_StridedIndexer, NthStrideOffset,
StridedIndexer, Ty, indT>(
PutFunctor<ProjectorT, OrthogIndexerT, NthStrideIndexerT,
AxesIndexerT, Ty, indT>(
dst_p, val_p, ind_p, k, ind_nelems, axes_shape_and_strides,
orthog_indexer, indices_indexer, axes_indexer));
});
Expand All @@ -352,7 +343,8 @@ template <typename fnT, typename T, typename indT> struct TakeWrapFactory
if constexpr (std::is_integral<indT>::value &&
!std::is_same<indT, bool>::value)
{
fnT fn = take_impl<WrapIndex, T, indT>;
using dpctl::tensor::indexing_utils::WrapIndex;
fnT fn = take_impl<WrapIndex<indT>, T, indT>;
return fn;
}
else {
Expand All @@ -369,7 +361,8 @@ template <typename fnT, typename T, typename indT> struct TakeClipFactory
if constexpr (std::is_integral<indT>::value &&
!std::is_same<indT, bool>::value)
{
fnT fn = take_impl<ClipIndex, T, indT>;
using dpctl::tensor::indexing_utils::ClipIndex;
fnT fn = take_impl<ClipIndex<indT>, T, indT>;
return fn;
}
else {
Expand All @@ -386,7 +379,8 @@ template <typename fnT, typename T, typename indT> struct PutWrapFactory
if constexpr (std::is_integral<indT>::value &&
!std::is_same<indT, bool>::value)
{
fnT fn = put_impl<WrapIndex, T, indT>;
using dpctl::tensor::indexing_utils::WrapIndex;
fnT fn = put_impl<WrapIndex<indT>, T, indT>;
return fn;
}
else {
Expand All @@ -403,7 +397,8 @@ template <typename fnT, typename T, typename indT> struct PutClipFactory
if constexpr (std::is_integral<indT>::value &&
!std::is_same<indT, bool>::value)
{
fnT fn = put_impl<ClipIndex, T, indT>;
using dpctl::tensor::indexing_utils::ClipIndex;
fnT fn = put_impl<ClipIndex<indT>, T, indT>;
return fn;
}
else {
Expand Down
Loading
Loading