3030#include < type_traits>
3131
3232#include " dpctl_tensor_types.hpp"
33+ #include " utils/indexing_utils.hpp"
3334#include " utils/offset_utils.hpp"
3435#include " utils/type_utils.hpp"
3536
@@ -42,54 +43,10 @@ namespace kernels
4243namespace indexing
4344{
4445
45- using namespace dpctl ::tensor::offset_utils;
46-
47- template <typename ProjectorT,
48- typename OrthogStrider,
49- typename IndicesStrider,
50- typename AxesStrider,
51- typename T,
52- typename indT>
53- class take_kernel ;
5446template <typename ProjectorT,
55- typename OrthogStrider,
56- typename IndicesStrider,
57- typename AxesStrider,
58- typename T,
59- typename indT>
60- class put_kernel ;
61-
62- class WrapIndex
63- {
64- public:
65- WrapIndex () = default ;
66-
67- void operator ()(ssize_t max_item, ssize_t &ind) const
68- {
69- max_item = std::max<ssize_t >(max_item, 1 );
70- ind = sycl::clamp<ssize_t >(ind, -max_item, max_item - 1 );
71- ind = (ind < 0 ) ? ind + max_item : ind;
72- return ;
73- }
74- };
75-
76- class ClipIndex
77- {
78- public:
79- ClipIndex () = default ;
80-
81- void operator ()(ssize_t max_item, ssize_t &ind) const
82- {
83- max_item = std::max<ssize_t >(max_item, 1 );
84- ind = sycl::clamp<ssize_t >(ind, ssize_t (0 ), max_item - 1 );
85- return ;
86- }
87- };
88-
89- template <typename ProjectorT,
90- typename OrthogStrider,
91- typename IndicesStrider,
92- typename AxesStrider,
47+ typename OrthogIndexer,
48+ typename IndicesIndexer,
49+ typename AxesIndexer,
9350 typename T,
9451 typename indT>
9552class TakeFunctor
@@ -101,9 +58,9 @@ class TakeFunctor
10158 int k_ = 0 ;
10259 size_t ind_nelems_ = 0 ;
10360 const ssize_t *axes_shape_and_strides_ = nullptr ;
104- const OrthogStrider orthog_strider;
105- const IndicesStrider ind_strider;
106- const AxesStrider axes_strider;
61+ const OrthogIndexer orthog_strider;
62+ const IndicesIndexer ind_strider;
63+ const AxesIndexer axes_strider;
10764
10865public:
10966 TakeFunctor (const char *src_cp,
@@ -112,9 +69,9 @@ class TakeFunctor
11269 int k,
11370 size_t ind_nelems,
11471 const ssize_t *axes_shape_and_strides,
115- const OrthogStrider &orthog_strider_,
116- const IndicesStrider &ind_strider_,
117- const AxesStrider &axes_strider_)
72+ const OrthogIndexer &orthog_strider_,
73+ const IndicesIndexer &ind_strider_,
74+ const AxesIndexer &axes_strider_)
11875 : src_(src_cp), dst_(dst_cp), ind_(ind_cp), k_(k),
11976 ind_nelems_ (ind_nelems),
12077 axes_shape_and_strides_(axes_shape_and_strides),
@@ -136,16 +93,16 @@ class TakeFunctor
13693 ssize_t src_offset = orthog_offsets.get_first_offset ();
13794 ssize_t dst_offset = orthog_offsets.get_second_offset ();
13895
139- const ProjectorT proj{};
96+ constexpr ProjectorT proj{};
14097 for (int axis_idx = 0 ; axis_idx < k_; ++axis_idx) {
14198 indT *ind_data = reinterpret_cast <indT *>(ind_[axis_idx]);
14299
143100 ssize_t ind_offset = ind_strider (i_along, axis_idx);
144- ssize_t i = static_cast < ssize_t >(ind_data[ind_offset]);
145-
146- proj (axes_shape_and_strides_[axis_idx], i );
147-
148- src_offset += i * axes_shape_and_strides_[k_ + axis_idx];
101+ // proj produces an index in the range of the given axis
102+ ssize_t projected_idx =
103+ proj (axes_shape_and_strides_[axis_idx], ind_data[ind_offset] );
104+ src_offset +=
105+ projected_idx * axes_shape_and_strides_[k_ + axis_idx];
149106 }
150107
151108 dst_offset += axes_strider (i_along);
@@ -154,6 +111,14 @@ class TakeFunctor
154111 }
155112};
156113
114+ template <typename ProjectorT,
115+ typename OrthogIndexer,
116+ typename IndicesIndexer,
117+ typename AxesIndexer,
118+ typename T,
119+ typename indT>
120+ class take_kernel ;
121+
157122typedef sycl::event (*take_fn_ptr_t )(sycl::queue &,
158123 size_t ,
159124 size_t ,
@@ -194,21 +159,29 @@ sycl::event take_impl(sycl::queue &q,
194159 sycl::event take_ev = q.submit ([&](sycl::handler &cgh) {
195160 cgh.depends_on (depends);
196161
197- const TwoOffsets_StridedIndexer orthog_indexer{
198- nd, src_offset, dst_offset, orthog_shape_and_strides};
199- const NthStrideOffset indices_indexer{ind_nd, ind_offsets,
200- ind_shape_and_strides};
201- const StridedIndexer axes_indexer{ind_nd, 0 ,
202- axes_shape_and_strides + (2 * k)};
162+ using OrthogIndexerT =
163+ dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
164+ const OrthogIndexerT orthog_indexer{nd, src_offset, dst_offset,
165+ orthog_shape_and_strides};
166+
167+ using NthStrideIndexerT = dpctl::tensor::offset_utils::NthStrideOffset;
168+ const NthStrideIndexerT indices_indexer{ind_nd, ind_offsets,
169+ ind_shape_and_strides};
170+
171+ using AxesIndexerT = dpctl::tensor::offset_utils::StridedIndexer;
172+ const AxesIndexerT axes_indexer{ind_nd, 0 ,
173+ axes_shape_and_strides + (2 * k)};
174+
175+ using KernelName =
176+ take_kernel<ProjectorT, OrthogIndexerT, NthStrideIndexerT,
177+ AxesIndexerT, Ty, indT>;
203178
204179 const size_t gws = orthog_nelems * ind_nelems;
205180
206- cgh.parallel_for <
207- take_kernel<ProjectorT, TwoOffsets_StridedIndexer, NthStrideOffset,
208- StridedIndexer, Ty, indT>>(
181+ cgh.parallel_for <KernelName>(
209182 sycl::range<1 >(gws),
210- TakeFunctor<ProjectorT, TwoOffsets_StridedIndexer, NthStrideOffset ,
211- StridedIndexer , Ty, indT>(
183+ TakeFunctor<ProjectorT, OrthogIndexerT, NthStrideIndexerT ,
184+ AxesIndexerT , Ty, indT>(
212185 src_p, dst_p, ind_p, k, ind_nelems, axes_shape_and_strides,
213186 orthog_indexer, indices_indexer, axes_indexer));
214187 });
@@ -217,9 +190,9 @@ sycl::event take_impl(sycl::queue &q,
217190}
218191
219192template <typename ProjectorT,
220- typename OrthogStrider ,
221- typename IndicesStrider ,
222- typename AxesStrider ,
193+ typename OrthogIndexer ,
194+ typename IndicesIndexer ,
195+ typename AxesIndexer ,
223196 typename T,
224197 typename indT>
225198class PutFunctor
@@ -231,9 +204,9 @@ class PutFunctor
231204 int k_ = 0 ;
232205 size_t ind_nelems_ = 0 ;
233206 const ssize_t *axes_shape_and_strides_ = nullptr ;
234- const OrthogStrider orthog_strider;
235- const IndicesStrider ind_strider;
236- const AxesStrider axes_strider;
207+ const OrthogIndexer orthog_strider;
208+ const IndicesIndexer ind_strider;
209+ const AxesIndexer axes_strider;
237210
238211public:
239212 PutFunctor (char *dst_cp,
@@ -242,9 +215,9 @@ class PutFunctor
242215 int k,
243216 size_t ind_nelems,
244217 const ssize_t *axes_shape_and_strides,
245- const OrthogStrider &orthog_strider_,
246- const IndicesStrider &ind_strider_,
247- const AxesStrider &axes_strider_)
218+ const OrthogIndexer &orthog_strider_,
219+ const IndicesIndexer &ind_strider_,
220+ const AxesIndexer &axes_strider_)
248221 : dst_(dst_cp), val_(val_cp), ind_(ind_cp), k_(k),
249222 ind_nelems_ (ind_nelems),
250223 axes_shape_and_strides_(axes_shape_and_strides),
@@ -266,16 +239,17 @@ class PutFunctor
266239 ssize_t dst_offset = orthog_offsets.get_first_offset ();
267240 ssize_t val_offset = orthog_offsets.get_second_offset ();
268241
269- const ProjectorT proj{};
242+ constexpr ProjectorT proj{};
270243 for (int axis_idx = 0 ; axis_idx < k_; ++axis_idx) {
271244 indT *ind_data = reinterpret_cast <indT *>(ind_[axis_idx]);
272245
273246 ssize_t ind_offset = ind_strider (i_along, axis_idx);
274- ssize_t i = static_cast <ssize_t >(ind_data[ind_offset]);
275-
276- proj (axes_shape_and_strides_[axis_idx], i);
277247
278- dst_offset += i * axes_shape_and_strides_[k_ + axis_idx];
248+ // proj produces an index in the range of the given axis
249+ ssize_t projected_idx =
250+ proj (axes_shape_and_strides_[axis_idx], ind_data[ind_offset]);
251+ dst_offset +=
252+ projected_idx * axes_shape_and_strides_[k_ + axis_idx];
279253 }
280254
281255 val_offset += axes_strider (i_along);
@@ -284,6 +258,14 @@ class PutFunctor
284258 }
285259};
286260
261+ template <typename ProjectorT,
262+ typename OrthogIndexer,
263+ typename IndicesIndexer,
264+ typename AxesIndexer,
265+ typename T,
266+ typename indT>
267+ class put_kernel ;
268+
287269typedef sycl::event (*put_fn_ptr_t )(sycl::queue &,
288270 size_t ,
289271 size_t ,
@@ -324,20 +306,29 @@ sycl::event put_impl(sycl::queue &q,
324306 sycl::event put_ev = q.submit ([&](sycl::handler &cgh) {
325307 cgh.depends_on (depends);
326308
327- const TwoOffsets_StridedIndexer orthog_indexer{
328- nd, dst_offset, val_offset, orthog_shape_and_strides};
329- const NthStrideOffset indices_indexer{ind_nd, ind_offsets,
330- ind_shape_and_strides};
331- const StridedIndexer axes_indexer{ind_nd, 0 ,
332- axes_shape_and_strides + (2 * k)};
309+ using OrthogIndexerT =
310+ dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
311+ const OrthogIndexerT orthog_indexer{nd, dst_offset, val_offset,
312+ orthog_shape_and_strides};
313+
314+ using NthStrideIndexerT = dpctl::tensor::offset_utils::NthStrideOffset;
315+ const NthStrideIndexerT indices_indexer{ind_nd, ind_offsets,
316+ ind_shape_and_strides};
317+
318+ using AxesIndexerT = dpctl::tensor::offset_utils::StridedIndexer;
319+ const AxesIndexerT axes_indexer{ind_nd, 0 ,
320+ axes_shape_and_strides + (2 * k)};
321+
322+ using KernelName =
323+ put_kernel<ProjectorT, OrthogIndexerT, NthStrideIndexerT,
324+ AxesIndexerT, Ty, indT>;
333325
334326 const size_t gws = orthog_nelems * ind_nelems;
335327
336- cgh.parallel_for <put_kernel<ProjectorT, TwoOffsets_StridedIndexer,
337- NthStrideOffset, StridedIndexer, Ty, indT>>(
328+ cgh.parallel_for <KernelName>(
338329 sycl::range<1 >(gws),
339- PutFunctor<ProjectorT, TwoOffsets_StridedIndexer, NthStrideOffset ,
340- StridedIndexer , Ty, indT>(
330+ PutFunctor<ProjectorT, OrthogIndexerT, NthStrideIndexerT ,
331+ AxesIndexerT , Ty, indT>(
341332 dst_p, val_p, ind_p, k, ind_nelems, axes_shape_and_strides,
342333 orthog_indexer, indices_indexer, axes_indexer));
343334 });
@@ -352,7 +343,8 @@ template <typename fnT, typename T, typename indT> struct TakeWrapFactory
352343 if constexpr (std::is_integral<indT>::value &&
353344 !std::is_same<indT, bool >::value)
354345 {
355- fnT fn = take_impl<WrapIndex, T, indT>;
346+ using dpctl::tensor::indexing_utils::WrapIndex;
347+ fnT fn = take_impl<WrapIndex<indT>, T, indT>;
356348 return fn;
357349 }
358350 else {
@@ -369,7 +361,8 @@ template <typename fnT, typename T, typename indT> struct TakeClipFactory
369361 if constexpr (std::is_integral<indT>::value &&
370362 !std::is_same<indT, bool >::value)
371363 {
372- fnT fn = take_impl<ClipIndex, T, indT>;
364+ using dpctl::tensor::indexing_utils::ClipIndex;
365+ fnT fn = take_impl<ClipIndex<indT>, T, indT>;
373366 return fn;
374367 }
375368 else {
@@ -386,7 +379,8 @@ template <typename fnT, typename T, typename indT> struct PutWrapFactory
386379 if constexpr (std::is_integral<indT>::value &&
387380 !std::is_same<indT, bool >::value)
388381 {
389- fnT fn = put_impl<WrapIndex, T, indT>;
382+ using dpctl::tensor::indexing_utils::WrapIndex;
383+ fnT fn = put_impl<WrapIndex<indT>, T, indT>;
390384 return fn;
391385 }
392386 else {
@@ -403,7 +397,8 @@ template <typename fnT, typename T, typename indT> struct PutClipFactory
403397 if constexpr (std::is_integral<indT>::value &&
404398 !std::is_same<indT, bool >::value)
405399 {
406- fnT fn = put_impl<ClipIndex, T, indT>;
400+ using dpctl::tensor::indexing_utils::ClipIndex;
401+ fnT fn = put_impl<ClipIndex<indT>, T, indT>;
407402 return fn;
408403 }
409404 else {
0 commit comments