@@ -70,6 +70,7 @@ struct UnaryContigFunctor
7070
7171 void operator ()(sycl::nd_item<1 > ndit) const
7272 {
73+ constexpr std::uint32_t elems_per_wi = n_vecs * vec_sz;
7374 UnaryOperatorT op{};
7475 /* Each work-item processes elems_per_wi = n_vecs * vec_sz elements; the elements handled by one sub-group are contiguous in memory */
7576 /* NOTE: work-group size must be divisible by sub-group size */
@@ -80,14 +81,15 @@ struct UnaryContigFunctor
8081 constexpr resT const_val = UnaryOperatorT::constant_value;
8182
8283 auto sg = ndit.get_sub_group ();
83- std::uint8_t sgSize = sg.get_max_local_range ()[0 ];
84- size_t base = n_vecs * vec_sz *
84+ std::uint32_t sgSize = sg.get_max_local_range ()[0 ];
85+
86+ size_t base = static_cast <size_t >(elems_per_wi) *
8587 (ndit.get_group (0 ) * ndit.get_local_range (0 ) +
8688 sg.get_group_id ()[0 ] * sgSize);
87- if (base + n_vecs * vec_sz * sgSize < nelems_) {
89+ if (base + static_cast < size_t >(elems_per_wi * sgSize) < nelems_) {
8890 sycl::vec<resT, vec_sz> res_vec (const_val);
8991#pragma unroll
90- for (std::uint8_t it = 0 ; it < n_vecs * vec_sz ; it += vec_sz) {
92+ for (std::uint32_t it = 0 ; it < elems_per_wi ; it += vec_sz) {
9193 size_t offset = base + static_cast <size_t >(it) *
9294 static_cast <size_t >(sgSize);
9395 auto out_multi_ptr = sycl::address_space_cast<
@@ -98,9 +100,8 @@ struct UnaryContigFunctor
98100 }
99101 }
100102 else {
101- for (size_t k = base + sg.get_local_id ()[0 ]; k < nelems_;
102- k += sgSize)
103- {
103+ const size_t lane_id = sg.get_local_id ()[0 ];
104+ for (size_t k = base + lane_id; k < nelems_; k += sgSize) {
104105 out[k] = const_val;
105106 }
106107 }
@@ -110,15 +111,16 @@ struct UnaryContigFunctor
110111 UnaryOperatorT::supports_vec::value)
111112 {
112113 auto sg = ndit.get_sub_group ();
113- std::uint16_t sgSize = sg.get_max_local_range ()[0 ];
114- size_t base = n_vecs * vec_sz *
114+ std::uint32_t sgSize = sg.get_max_local_range ()[0 ];
115+
116+ size_t base = static_cast <size_t >(elems_per_wi) *
115117 (ndit.get_group (0 ) * ndit.get_local_range (0 ) +
116118 sg.get_group_id ()[0 ] * sgSize);
117- if (base + n_vecs * vec_sz * sgSize < nelems_) {
119+ if (base + static_cast < size_t >(elems_per_wi * sgSize) < nelems_) {
118120 sycl::vec<argT, vec_sz> x;
119121
120122#pragma unroll
121- for (std::uint16_t it = 0 ; it < n_vecs * vec_sz ; it += vec_sz) {
123+ for (std::uint32_t it = 0 ; it < elems_per_wi ; it += vec_sz) {
122124 size_t offset = base + static_cast <size_t >(it) *
123125 static_cast <size_t >(sgSize);
124126 auto in_multi_ptr = sycl::address_space_cast<
@@ -134,9 +136,8 @@ struct UnaryContigFunctor
134136 }
135137 }
136138 else {
137- for (size_t k = base + sg.get_local_id ()[0 ]; k < nelems_;
138- k += sgSize)
139- {
139+ const size_t lane_id = sg.get_local_id ()[0 ];
140+ for (size_t k = base + lane_id; k < nelems_; k += sgSize) {
140141 // scalar call
141142 out[k] = op (in[k]);
142143 }
@@ -149,16 +150,16 @@ struct UnaryContigFunctor
149150 // default: use scalar-value function
150151
151152 auto sg = ndit.get_sub_group ();
152- std::uint8_t sgSize = sg.get_max_local_range ()[0 ];
153- size_t base = n_vecs * vec_sz *
153+ std::uint32_t sgSize = sg.get_max_local_range ()[0 ];
154+ size_t base = static_cast < size_t >(elems_per_wi) *
154155 (ndit.get_group (0 ) * ndit.get_local_range (0 ) +
155156 sg.get_group_id ()[0 ] * sgSize);
156157
157- if (base + n_vecs * vec_sz * sgSize < nelems_) {
158+ if (base + static_cast < size_t >(elems_per_wi * sgSize) < nelems_) {
158159 sycl::vec<argT, vec_sz> arg_vec;
159160
160161#pragma unroll
161- for (std::uint8_t it = 0 ; it < n_vecs * vec_sz ; it += vec_sz) {
162+ for (std::uint32_t it = 0 ; it < elems_per_wi ; it += vec_sz) {
162163 size_t offset = base + static_cast <size_t >(it) *
163164 static_cast <size_t >(sgSize);
164165 auto in_multi_ptr = sycl::address_space_cast<
@@ -170,16 +171,15 @@ struct UnaryContigFunctor
170171
171172 arg_vec = sg.load <vec_sz>(in_multi_ptr);
172173#pragma unroll
173- for (std::uint8_t k = 0 ; k < vec_sz; ++k) {
174+ for (std::uint32_t k = 0 ; k < vec_sz; ++k) {
174175 arg_vec[k] = op (arg_vec[k]);
175176 }
176177 sg.store <vec_sz>(out_multi_ptr, arg_vec);
177178 }
178179 }
179180 else {
180- for (size_t k = base + sg.get_local_id ()[0 ]; k < nelems_;
181- k += sgSize)
182- {
181+ const size_t lane_id = sg.get_local_id ()[0 ];
182+ for (size_t k = base + lane_id; k < nelems_; k += sgSize) {
183183 out[k] = op (in[k]);
184184 }
185185 }
@@ -190,17 +190,17 @@ struct UnaryContigFunctor
190190 // default: use scalar-value function
191191
192192 auto sg = ndit.get_sub_group ();
193- std::uint8_t sgSize = sg.get_max_local_range ()[0 ];
194- size_t base = n_vecs * vec_sz *
193+ std::uint32_t sgSize = sg.get_max_local_range ()[0 ];
194+ size_t base = static_cast < size_t >(elems_per_wi) *
195195 (ndit.get_group (0 ) * ndit.get_local_range (0 ) +
196196 sg.get_group_id ()[0 ] * sgSize);
197197
198- if (base + n_vecs * vec_sz * sgSize < nelems_) {
198+ if (base + static_cast < size_t >(elems_per_wi * sgSize) < nelems_) {
199199 sycl::vec<argT, vec_sz> arg_vec;
200200 sycl::vec<resT, vec_sz> res_vec;
201201
202202#pragma unroll
203- for (std::uint8_t it = 0 ; it < n_vecs * vec_sz ; it += vec_sz) {
203+ for (std::uint32_t it = 0 ; it < elems_per_wi ; it += vec_sz) {
204204 size_t offset = base + static_cast <size_t >(it) *
205205 static_cast <size_t >(sgSize);
206206 auto in_multi_ptr = sycl::address_space_cast<
@@ -212,27 +212,27 @@ struct UnaryContigFunctor
212212
213213 arg_vec = sg.load <vec_sz>(in_multi_ptr);
214214#pragma unroll
215- for (std::uint8_t k = 0 ; k < vec_sz; ++k) {
215+ for (std::uint32_t k = 0 ; k < vec_sz; ++k) {
216216 res_vec[k] = op (arg_vec[k]);
217217 }
218218 sg.store <vec_sz>(out_multi_ptr, res_vec);
219219 }
220220 }
221221 else {
222- for (size_t k = base + sg.get_local_id ()[0 ]; k < nelems_;
223- k += sgSize)
224- {
222+ const size_t lane_id = sg.get_local_id ()[0 ];
223+ for (size_t k = base + lane_id; k < nelems_; k += sgSize) {
225224 out[k] = op (in[k]);
226225 }
227226 }
228227 }
229228 else {
230- std:: uint8_t sgSize = ndit.get_sub_group ().get_local_range ()[0 ];
229+ size_t sgSize = ndit.get_sub_group ().get_local_range ()[0 ];
231230 size_t base = ndit.get_global_linear_id ();
231+ const size_t elems_per_sg = sgSize * elems_per_wi;
232232
233- base = (base / sgSize) * sgSize * n_vecs * vec_sz + (base % sgSize);
233+ base = (base / sgSize) * elems_per_sg + (base % sgSize);
234234 for (size_t offset = base;
235- offset < std::min (nelems_, base + sgSize * (n_vecs * vec_sz) );
235+ offset < std::min (nelems_, base + elems_per_sg );
236236 offset += sgSize)
237237 {
238238 out[offset] = op (in[offset]);
@@ -392,6 +392,7 @@ struct BinaryContigFunctor
392392
393393 void operator ()(sycl::nd_item<1 > ndit) const
394394 {
395+ constexpr std::uint32_t elems_per_wi = n_vecs * vec_sz;
395396 BinaryOperatorT op{};
396397 /* Each work-item processes elems_per_wi = n_vecs * vec_sz elements; the elements handled by one sub-group are contiguous in memory */
397398 /* NOTE: work-group size must be divisible by sub-group size */
@@ -401,19 +402,19 @@ struct BinaryContigFunctor
401402 BinaryOperatorT::supports_vec::value)
402403 {
403404 auto sg = ndit.get_sub_group ();
404- std::uint8_t sgSize = sg.get_max_local_range ()[0 ];
405+ std::uint16_t sgSize = sg.get_max_local_range ()[0 ];
405406
406- size_t base = n_vecs * vec_sz *
407+ size_t base = static_cast < size_t >(elems_per_wi) *
407408 (ndit.get_group (0 ) * ndit.get_local_range (0 ) +
408409 sg.get_group_id ()[0 ] * sgSize);
409410
410- if (base + n_vecs * vec_sz * sgSize < nelems_) {
411+ if (base + static_cast < size_t >(elems_per_wi * sgSize) < nelems_) {
411412 sycl::vec<argT1, vec_sz> arg1_vec;
412413 sycl::vec<argT2, vec_sz> arg2_vec;
413414 sycl::vec<resT, vec_sz> res_vec;
414415
415416#pragma unroll
416- for (std::uint8_t it = 0 ; it < n_vecs * vec_sz ; it += vec_sz) {
417+ for (std::uint32_t it = 0 ; it < elems_per_wi ; it += vec_sz) {
417418 size_t offset = base + static_cast <size_t >(it) *
418419 static_cast <size_t >(sgSize);
419420 auto in1_multi_ptr = sycl::address_space_cast<
@@ -433,9 +434,8 @@ struct BinaryContigFunctor
433434 }
434435 }
435436 else {
436- for (size_t k = base + sg.get_local_id ()[0 ]; k < nelems_;
437- k += sgSize)
438- {
437+ const std::size_t lane_id = sg.get_local_id ()[0 ];
438+ for (size_t k = base + lane_id; k < nelems_; k += sgSize) {
439439 out[k] = op (in1[k], in2[k]);
440440 }
441441 }
@@ -446,17 +446,17 @@ struct BinaryContigFunctor
446446 auto sg = ndit.get_sub_group ();
447447 std::uint8_t sgSize = sg.get_max_local_range ()[0 ];
448448
449- size_t base = n_vecs * vec_sz *
449+ size_t base = static_cast < size_t >(elems_per_wi) *
450450 (ndit.get_group (0 ) * ndit.get_local_range (0 ) +
451451 sg.get_group_id ()[0 ] * sgSize);
452452
453- if (base + n_vecs * vec_sz * sgSize < nelems_) {
453+ if (base + static_cast < size_t >(elems_per_wi * sgSize) < nelems_) {
454454 sycl::vec<argT1, vec_sz> arg1_vec;
455455 sycl::vec<argT2, vec_sz> arg2_vec;
456456 sycl::vec<resT, vec_sz> res_vec;
457457
458458#pragma unroll
459- for (std::uint8_t it = 0 ; it < n_vecs * vec_sz ; it += vec_sz) {
459+ for (std::uint32_t it = 0 ; it < elems_per_wi ; it += vec_sz) {
460460 size_t offset = base + static_cast <size_t >(it) *
461461 static_cast <size_t >(sgSize);
462462 auto in1_multi_ptr = sycl::address_space_cast<
@@ -480,20 +480,20 @@ struct BinaryContigFunctor
480480 }
481481 }
482482 else {
483- for (size_t k = base + sg.get_local_id ()[0 ]; k < nelems_;
484- k += sgSize)
485- {
483+ const std::size_t lane_id = sg.get_local_id ()[0 ];
484+ for (size_t k = base + lane_id; k < nelems_; k += sgSize) {
486485 out[k] = op (in1[k], in2[k]);
487486 }
488487 }
489488 }
490489 else {
491- std::uint8_t sgSize = ndit.get_sub_group ().get_local_range ()[0 ];
492- size_t base = ndit.get_global_linear_id ();
490+ const size_t sgSize = ndit.get_sub_group ().get_local_range ()[0 ];
491+ const size_t gid = ndit.get_global_linear_id ();
492+ const size_t elems_per_sg = sgSize * elems_per_wi;
493493
494- base = (base / sgSize) * sgSize * n_vecs * vec_sz + (base % sgSize);
494+ const size_t base = (gid / sgSize) * elems_per_sg + (gid % sgSize);
495495 for (size_t offset = base;
496- offset < std::min (nelems_, base + sgSize * (n_vecs * vec_sz) );
496+ offset < std::min (nelems_, base + elems_per_sg );
497497 offset += sgSize)
498498 {
499499 out[offset] = op (in1[offset], in2[offset]);
0 commit comments