3333#include " kernels/alignment.hpp"
3434#include " utils/math_utils.hpp"
3535#include " utils/offset_utils.hpp"
36+ #include " utils/sycl_utils.hpp"
3637#include " utils/type_utils.hpp"
3738
3839namespace dpctl
@@ -51,6 +52,9 @@ using dpctl::tensor::kernels::alignment_utils::
5152using dpctl::tensor::kernels::alignment_utils::is_aligned;
5253using dpctl::tensor::kernels::alignment_utils::required_alignment;
5354
55+ using dpctl::tensor::sycl_utils::sub_group_load;
56+ using dpctl::tensor::sycl_utils::sub_group_store;
57+
5458template <typename T> T clip (const T &x, const T &min, const T &max)
5559{
5660 using dpctl::tensor::type_utils::is_complex;
@@ -75,8 +79,8 @@ template <typename T> T clip(const T &x, const T &min, const T &max)
7579}
7680
7781template <typename T,
78- int vec_sz = 4 ,
79- int n_vecs = 2 ,
82+ std:: uint8_t vec_sz = 4 ,
83+ std:: uint8_t n_vecs = 2 ,
8084 bool enable_sg_loadstore = true >
8185class ClipContigFunctor
8286{
@@ -100,37 +104,36 @@ class ClipContigFunctor
100104
101105 void operator ()(sycl::nd_item<1 > ndit) const
102106 {
107+ constexpr std::uint8_t nelems_per_wi = n_vecs * vec_sz;
108+
103109 using dpctl::tensor::type_utils::is_complex;
104110 if constexpr (is_complex<T>::value || !enable_sg_loadstore) {
105- std::uint8_t sgSize = ndit.get_sub_group ().get_local_range ()[0 ];
106- size_t base = ndit.get_global_linear_id ();
107-
108- base = (base / sgSize) * sgSize * n_vecs * vec_sz + (base % sgSize);
109- for (size_t offset = base;
110- offset < std::min (nelems, base + sgSize * (n_vecs * vec_sz));
111- offset += sgSize)
112- {
111+ const std::uint16_t sgSize =
112+ ndit.get_sub_group ().get_local_range ()[0 ];
113+ const size_t gid = ndit.get_global_linear_id ();
114+ const uint16_t nelems_per_sg = sgSize * nelems_per_wi;
115+
116+ const size_t start =
117+ (gid / sgSize) * (nelems_per_sg - sgSize) + gid;
118+ const size_t end = std::min (nelems, start + nelems_per_sg);
119+
120+ for (size_t offset = start; offset < end; offset += sgSize) {
113121 dst_p[offset] = clip (x_p[offset], min_p[offset], max_p[offset]);
114122 }
115123 }
116124 else {
117125 auto sg = ndit.get_sub_group ();
118- std::uint8_t sgSize = sg.get_local_range ()[0 ];
119- std::uint8_t max_sgSize = sg.get_max_local_range ()[0 ];
120- size_t base = n_vecs * vec_sz *
121- (ndit.get_group (0 ) * ndit.get_local_range (0 ) +
122- sg.get_group_id ()[0 ] * max_sgSize);
123-
124- if (base + n_vecs * vec_sz * sgSize < nelems &&
125- sgSize == max_sgSize)
126- {
127- sycl::vec<T, vec_sz> x_vec;
128- sycl::vec<T, vec_sz> min_vec;
129- sycl::vec<T, vec_sz> max_vec;
126+ const std::uint16_t sgSize = sg.get_max_local_range ()[0 ];
127+
128+ const size_t base =
129+ nelems_per_wi * (ndit.get_group (0 ) * ndit.get_local_range (0 ) +
130+ sg.get_group_id ()[0 ] * sgSize);
131+
132+ if (base + nelems_per_wi * sgSize < nelems) {
130133 sycl::vec<T, vec_sz> dst_vec;
131134#pragma unroll
132135 for (std::uint8_t it = 0 ; it < n_vecs * vec_sz; it += vec_sz) {
133- auto idx = base + it * sgSize;
136+ const size_t idx = base + it * sgSize;
134137 auto x_multi_ptr = sycl::address_space_cast<
135138 sycl::access::address_space::global_space,
136139 sycl::access::decorated::yes>(&x_p[idx]);
@@ -144,21 +147,23 @@ class ClipContigFunctor
144147 sycl::access::address_space::global_space,
145148 sycl::access::decorated::yes>(&dst_p[idx]);
146149
147- x_vec = sg.load <vec_sz>(x_multi_ptr);
148- min_vec = sg.load <vec_sz>(min_multi_ptr);
149- max_vec = sg.load <vec_sz>(max_multi_ptr);
150+ const sycl::vec<T, vec_sz> x_vec =
151+ sub_group_load<vec_sz>(sg, x_multi_ptr);
152+ const sycl::vec<T, vec_sz> min_vec =
153+ sub_group_load<vec_sz>(sg, min_multi_ptr);
154+ const sycl::vec<T, vec_sz> max_vec =
155+ sub_group_load<vec_sz>(sg, max_multi_ptr);
150156#pragma unroll
151157 for (std::uint8_t vec_id = 0 ; vec_id < vec_sz; ++vec_id) {
152158 dst_vec[vec_id] = clip (x_vec[vec_id], min_vec[vec_id],
153159 max_vec[vec_id]);
154160 }
155- sg. store <vec_sz>(dst_multi_ptr , dst_vec);
161+ sub_group_store <vec_sz>(sg , dst_vec, dst_multi_ptr );
156162 }
157163 }
158164 else {
159- for (size_t k = base + sg.get_local_id ()[0 ]; k < nelems;
160- k += sgSize)
161- {
165+ const size_t lane_id = sg.get_local_id ()[0 ];
166+ for (size_t k = base + lane_id; k < nelems; k += sgSize) {
162167 dst_p[k] = clip (x_p[k], min_p[k], max_p[k]);
163168 }
164169 }
@@ -195,8 +200,8 @@ sycl::event clip_contig_impl(sycl::queue &q,
195200 cgh.depends_on (depends);
196201
197202 size_t lws = 64 ;
198- constexpr unsigned int vec_sz = 4 ;
199- constexpr unsigned int n_vecs = 2 ;
203+ constexpr std:: uint8_t vec_sz = 4 ;
204+ constexpr std:: uint8_t n_vecs = 2 ;
200205 const size_t n_groups =
201206 ((nelems + lws * n_vecs * vec_sz - 1 ) / (lws * n_vecs * vec_sz));
202207 const auto gws_range = sycl::range<1 >(n_groups * lws);
0 commit comments