@@ -72,22 +72,19 @@ struct UnaryContigFunctor
7272 {
7373 UnaryOperatorT op{};
7474 /* Each work-item processes vec_sz elements, contiguous in memory */
75- /* NOTE: vec_sz must divide sg.max_local_range()[0] */
75+ /* NOTE: work-group size must be divisible by sub-group size */
7676
7777 if constexpr (enable_sg_loadstore && UnaryOperatorT::is_constant::value)
7878 {
7979 // value of operator is known to be a known constant
8080 constexpr resT const_val = UnaryOperatorT::constant_value;
8181
8282 auto sg = ndit.get_sub_group ();
83- std::uint8_t sgSize = sg.get_local_range ()[0 ];
84- std::uint8_t max_sgSize = sg.get_max_local_range ()[0 ];
83+ std::uint8_t sgSize = sg.get_max_local_range ()[0 ];
8584 size_t base = n_vecs * vec_sz *
8685 (ndit.get_group (0 ) * ndit.get_local_range (0 ) +
8786 sg.get_group_id ()[0 ] * sgSize);
88- if (base + n_vecs * vec_sz * sgSize < nelems_ &&
89- max_sgSize == sgSize)
90- {
87+ if (base + n_vecs * vec_sz * sgSize < nelems_) {
9188 sycl::vec<resT, vec_sz> res_vec (const_val);
9289#pragma unroll
9390 for (std::uint8_t it = 0 ; it < n_vecs * vec_sz; it += vec_sz) {
@@ -113,14 +110,11 @@ struct UnaryContigFunctor
113110 UnaryOperatorT::supports_vec::value)
114111 {
115112 auto sg = ndit.get_sub_group ();
116- std::uint16_t sgSize = sg.get_local_range ()[0 ];
117- std::uint16_t max_sgSize = sg.get_max_local_range ()[0 ];
113+ std::uint16_t sgSize = sg.get_max_local_range ()[0 ];
118114 size_t base = n_vecs * vec_sz *
119115 (ndit.get_group (0 ) * ndit.get_local_range (0 ) +
120- sg.get_group_id ()[0 ] * max_sgSize);
121- if (base + n_vecs * vec_sz * sgSize < nelems_ &&
122- sgSize == max_sgSize)
123- {
116+ sg.get_group_id ()[0 ] * sgSize);
117+ if (base + n_vecs * vec_sz * sgSize < nelems_) {
124118 sycl::vec<argT, vec_sz> x;
125119
126120#pragma unroll
@@ -155,15 +149,12 @@ struct UnaryContigFunctor
155149 // default: use scalar-value function
156150
157151 auto sg = ndit.get_sub_group ();
158- std::uint8_t sgSize = sg.get_local_range ()[0 ];
159- std::uint8_t maxsgSize = sg.get_max_local_range ()[0 ];
152+ std::uint8_t sgSize = sg.get_max_local_range ()[0 ];
160153 size_t base = n_vecs * vec_sz *
161154 (ndit.get_group (0 ) * ndit.get_local_range (0 ) +
162- sg.get_group_id ()[0 ] * maxsgSize );
155+ sg.get_group_id ()[0 ] * sgSize );
163156
164- if ((base + n_vecs * vec_sz * sgSize < nelems_) &&
165- (maxsgSize == sgSize))
166- {
157+ if (base + n_vecs * vec_sz * sgSize < nelems_) {
167158 sycl::vec<argT, vec_sz> arg_vec;
168159
169160#pragma unroll
@@ -199,15 +190,12 @@ struct UnaryContigFunctor
199190 // default: use scalar-value function
200191
201192 auto sg = ndit.get_sub_group ();
202- std::uint8_t sgSize = sg.get_local_range ()[0 ];
203- std::uint8_t maxsgSize = sg.get_max_local_range ()[0 ];
193+ std::uint8_t sgSize = sg.get_max_local_range ()[0 ];
204194 size_t base = n_vecs * vec_sz *
205195 (ndit.get_group (0 ) * ndit.get_local_range (0 ) +
206- sg.get_group_id ()[0 ] * maxsgSize );
196+ sg.get_group_id ()[0 ] * sgSize );
207197
208- if ((base + n_vecs * vec_sz * sgSize < nelems_) &&
209- (maxsgSize == sgSize))
210- {
198+ if (base + n_vecs * vec_sz * sgSize < nelems_) {
211199 sycl::vec<argT, vec_sz> arg_vec;
212200 sycl::vec<resT, vec_sz> res_vec;
213201
@@ -406,22 +394,20 @@ struct BinaryContigFunctor
406394 {
407395 BinaryOperatorT op{};
408396 /* Each work-item processes vec_sz elements, contiguous in memory */
397+ /* NOTE: work-group size must be divisible by sub-group size */
409398
410399 if constexpr (enable_sg_loadstore &&
411400 BinaryOperatorT::supports_sg_loadstore::value &&
412401 BinaryOperatorT::supports_vec::value)
413402 {
414403 auto sg = ndit.get_sub_group ();
415- std::uint8_t sgSize = sg.get_local_range ()[0 ];
416- std::uint8_t maxsgSize = sg.get_max_local_range ()[0 ];
404+ std::uint8_t sgSize = sg.get_max_local_range ()[0 ];
417405
418406 size_t base = n_vecs * vec_sz *
419407 (ndit.get_group (0 ) * ndit.get_local_range (0 ) +
420408 sg.get_group_id ()[0 ] * sgSize);
421409
422- if ((base + n_vecs * vec_sz * sgSize < nelems_) &&
423- (sgSize == maxsgSize))
424- {
410+ if (base + n_vecs * vec_sz * sgSize < nelems_) {
425411 sycl::vec<argT1, vec_sz> arg1_vec;
426412 sycl::vec<argT2, vec_sz> arg2_vec;
427413 sycl::vec<resT, vec_sz> res_vec;
@@ -458,16 +444,13 @@ struct BinaryContigFunctor
458444 BinaryOperatorT::supports_sg_loadstore::value)
459445 {
460446 auto sg = ndit.get_sub_group ();
461- std::uint8_t sgSize = sg.get_local_range ()[0 ];
462- std::uint8_t maxsgSize = sg.get_max_local_range ()[0 ];
447+ std::uint8_t sgSize = sg.get_max_local_range ()[0 ];
463448
464449 size_t base = n_vecs * vec_sz *
465450 (ndit.get_group (0 ) * ndit.get_local_range (0 ) +
466451 sg.get_group_id ()[0 ] * sgSize);
467452
468- if ((base + n_vecs * vec_sz * sgSize < nelems_) &&
469- (sgSize == maxsgSize))
470- {
453+ if (base + n_vecs * vec_sz * sgSize < nelems_) {
471454 sycl::vec<argT1, vec_sz> arg1_vec;
472455 sycl::vec<argT2, vec_sz> arg2_vec;
473456 sycl::vec<resT, vec_sz> res_vec;
@@ -582,13 +565,15 @@ struct BinaryContigMatrixContigRowBroadcastingFunctor
582565
583566 void operator ()(sycl::nd_item<1 > ndit) const
584567 {
568+ /* NOTE: work-group size must be divisible by sub-group size */
569+
585570 BinaryOperatorT op{};
586571 static_assert (BinaryOperatorT::supports_sg_loadstore::value);
587572
588573 auto sg = ndit.get_sub_group ();
589574 size_t gid = ndit.get_global_linear_id ();
590575
591- std::uint8_t sgSize = sg.get_local_range ()[0 ];
576+ std::uint8_t sgSize = sg.get_max_local_range ()[0 ];
592577 size_t base = gid - sg.get_local_id ()[0 ];
593578
594579 if (base + sgSize < n_elems) {
@@ -647,13 +632,14 @@ struct BinaryContigRowContigMatrixBroadcastingFunctor
647632
648633 void operator ()(sycl::nd_item<1 > ndit) const
649634 {
635+ /* NOTE: work-group size must be divisible by sub-group size */
650636 BinaryOperatorT op{};
651637 static_assert (BinaryOperatorT::supports_sg_loadstore::value);
652638
653639 auto sg = ndit.get_sub_group ();
654640 size_t gid = ndit.get_global_linear_id ();
655641
656- std::uint8_t sgSize = sg.get_local_range ()[0 ];
642+ std::uint8_t sgSize = sg.get_max_local_range ()[0 ];
657643 size_t base = gid - sg.get_local_id ()[0 ];
658644
659645 if (base + sgSize < n_elems) {
0 commit comments