@@ -44,8 +44,8 @@ namespace copy_as_contig
 
 template <typename T,
           typename IndexerT,
-          int vec_sz = 4,
-          int n_vecs = 2,
+          std::uint32_t vec_sz = 4u,
+          std::uint32_t n_vecs = 2u,
           bool enable_sg_loadstore = true>
 class CopyAsCContigFunctor
 {
@@ -66,53 +66,63 @@ class CopyAsCContigFunctor
 
     void operator()(sycl::nd_item<1> ndit) const
     {
+        static_assert(vec_sz > 0);
+        static_assert(n_vecs > 0);
+        static_assert(vec_sz * n_vecs < (std::uint32_t(1) << 8));
+
+        constexpr std::uint8_t elems_per_wi =
+            static_cast<std::uint8_t>(vec_sz * n_vecs);
+
         using dpctl::tensor::type_utils::is_complex;
         if constexpr (!enable_sg_loadstore || is_complex<T>::value) {
-            const std::uint32_t sgSize =
+            const std::uint16_t sgSize =
                 ndit.get_sub_group().get_local_range()[0];
             const std::size_t gid = ndit.get_global_linear_id();
 
-            const std::size_t base =
-                (gid / sgSize) * sgSize * n_vecs * vec_sz + (gid % sgSize);
-            for (size_t offset = base;
-                 offset < std::min(nelems, base + sgSize * (n_vecs * vec_sz));
-                 offset += sgSize)
-            {
+            // base = (gid / sgSize) * sgSize * elems_per_wi + (gid % sgSize)
+            // gid % sgSize == gid - (gid / sgSize) * sgSize
+            const std::size_t elems_per_sg = sgSize * (elems_per_wi - 1);
+            const std::size_t base = (gid / sgSize) * elems_per_sg + gid;
+            const std::size_t offset_max =
+                std::min(nelems, base + sgSize * elems_per_wi);
+
+            for (size_t offset = base; offset < offset_max; offset += sgSize) {
                 auto src_offset = src_indexer(offset);
                 dst_p[offset] = src_p[src_offset];
             }
         }
         else {
             auto sg = ndit.get_sub_group();
-            const std::uint32_t sgSize = sg.get_local_range()[0];
-            const size_t base = n_vecs * vec_sz *
-                                (ndit.get_group(0) * ndit.get_local_range(0) +
-                                 sg.get_group_id()[0] * sgSize);
+            const std::uint16_t sgSize = sg.get_max_local_range()[0];
+            const size_t base =
+                elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
+                                sg.get_group_id()[0] * sgSize);
 
-            if (base + n_vecs * vec_sz * sgSize < nelems) {
+            if (base + elems_per_wi * sgSize < nelems) {
                 sycl::vec<T, vec_sz> dst_vec;
 
 #pragma unroll
-                for (std::uint32_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
+                for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
+                    const size_t block_start_id = base + it * sgSize;
                     auto dst_multi_ptr = sycl::address_space_cast<
                         sycl::access::address_space::global_space,
-                        sycl::access::decorated::yes>(
-                        &dst_p[base + it * sgSize]);
+                        sycl::access::decorated::yes>(&dst_p[block_start_id]);
 
+                    const size_t elem_id0 = block_start_id + sg.get_local_id();
 #pragma unroll
-                    for (std::uint32_t k = 0; k < vec_sz; k++) {
-                        ssize_t src_offset = src_indexer(
-                            base + (it + k) * sgSize + sg.get_local_id());
+                    for (std::uint8_t k = 0; k < vec_sz; k++) {
+                        const size_t elem_id = elem_id0 + k * sgSize;
+                        const ssize_t src_offset = src_indexer(elem_id);
                         dst_vec[k] = src_p[src_offset];
                     }
                     sg.store<vec_sz>(dst_multi_ptr, dst_vec);
                 }
             }
             else {
-                for (size_t k = base + sg.get_local_id()[0]; k < nelems;
-                     k += sgSize)
-                {
-                    ssize_t src_offset = src_indexer(k);
+                const size_t lane_id = sg.get_local_id()[0];
+                const size_t k0 = base + lane_id;
+                for (size_t k = k0; k < nelems; k += sgSize) {
+                    const ssize_t src_offset = src_indexer(k);
                     dst_p[k] = src_p[src_offset];
                 }
             }
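The new comments above the refactored branch rely on a simple identity: since gid % sgSize == gid - (gid / sgSize) * sgSize, the old expression (gid / sgSize) * sgSize * elems_per_wi + (gid % sgSize) is equal to (gid / sgSize) * elems_per_sg + gid with elems_per_sg = sgSize * (elems_per_wi - 1). A minimal host-only sketch that checks this identity for the default vec_sz = 4, n_vecs = 2; the file name and loop bounds are illustrative, not part of this change:

// check_base_identity.cpp -- standalone sketch, not part of the dpctl sources.
#include <cassert>
#include <cstddef>
#include <cstdint>

int main()
{
    // default template parameters: vec_sz = 4, n_vecs = 2
    constexpr std::uint8_t elems_per_wi = 4 * 2;

    for (std::size_t sgSize : {8, 16, 32}) {
        const std::size_t elems_per_sg = sgSize * (elems_per_wi - 1);
        for (std::size_t gid = 0; gid < 4 * sgSize; ++gid) {
            // original formulation of the per-work-item base offset
            const std::size_t base_old =
                (gid / sgSize) * sgSize * elems_per_wi + (gid % sgSize);
            // refactored formulation used in the patch
            const std::size_t base_new = (gid / sgSize) * elems_per_sg + gid;
            assert(base_old == base_new);
        }
    }
    return 0;
}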
@@ -121,36 +131,23 @@ class CopyAsCContigFunctor
 };
 
 template <typename T,
-          typename IndexT,
-          int vec_sz,
-          int n_vecs,
-          bool enable_sgload>
-class as_contig_krn;
-
-template <typename T>
-sycl::event
-as_c_contiguous_array_generic_impl(sycl::queue &exec_q,
-                                   size_t nelems,
-                                   int nd,
-                                   const ssize_t *shape_and_strides,
-                                   const char *src_p,
-                                   char *dst_p,
-                                   const std::vector<sycl::event> &depends)
+          typename IndexerT,
+          std::uint32_t vec_sz,
+          std::uint32_t n_vecs,
+          bool enable_sg_load,
+          typename KernelName>
+sycl::event submit_c_contiguous_copy(sycl::queue &exec_q,
+                                     size_t nelems,
+                                     const T *src,
+                                     T *dst,
+                                     const IndexerT &src_indexer,
+                                     const std::vector<sycl::event> &depends)
 {
-    dpctl::tensor::type_utils::validate_type_for_device<T>(exec_q);
-
-    const T *src_tp = reinterpret_cast<const T *>(src_p);
-    T *dst_tp = reinterpret_cast<T *>(dst_p);
-
-    using IndexerT = dpctl::tensor::offset_utils::StridedIndexer;
-    const IndexerT src_indexer(nd, ssize_t(0), shape_and_strides);
+    static_assert(vec_sz > 0);
+    static_assert(n_vecs > 0);
+    static_assert(vec_sz * n_vecs < (std::uint32_t(1) << 8));
 
     constexpr std::size_t preferred_lws = 256;
-    constexpr std::uint32_t n_vecs = 2;
-    constexpr std::uint32_t vec_sz = 4;
-    constexpr bool enable_sg_load = true;
-    using KernelName =
-        as_contig_krn<T, IndexerT, vec_sz, n_vecs, enable_sg_load>;
 
     const auto &kernel_id = sycl::get_kernel_id<KernelName>();
 
@@ -167,9 +164,11 @@ as_c_contiguous_array_generic_impl(sycl::queue &exec_q,
     const std::size_t lws =
         ((preferred_lws + max_sg_size - 1) / max_sg_size) * max_sg_size;
 
-    constexpr std::uint32_t nelems_per_wi = n_vecs * vec_sz;
-    size_t n_groups =
-        (nelems + nelems_per_wi * lws - 1) / (nelems_per_wi * lws);
+    constexpr std::uint8_t nelems_per_wi = n_vecs * vec_sz;
+
+    const size_t nelems_per_group = nelems_per_wi * lws;
+    const size_t n_groups =
+        (nelems + nelems_per_group - 1) / (nelems_per_group);
 
     sycl::event copy_ev = exec_q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(depends);
@@ -181,8 +180,62 @@ as_c_contiguous_array_generic_impl(sycl::queue &exec_q,
         cgh.parallel_for<KernelName>(
             sycl::nd_range<1>(gRange, lRange),
             CopyAsCContigFunctor<T, IndexerT, vec_sz, n_vecs, enable_sg_load>(
-                nelems, src_tp, dst_tp, src_indexer));
+                nelems, src, dst, src_indexer));
     });
+    return copy_ev;
+}
+
+template <typename T,
+          typename IndexT,
+          std::uint32_t vec_sz,
+          std::uint32_t n_vecs,
+          bool enable_sgload>
+class as_contig_krn;
+
+template <typename T>
+sycl::event
+as_c_contiguous_array_generic_impl(sycl::queue &exec_q,
+                                   size_t nelems,
+                                   int nd,
+                                   const ssize_t *shape_and_strides,
+                                   const char *src_p,
+                                   char *dst_p,
+                                   const std::vector<sycl::event> &depends)
+{
+    dpctl::tensor::type_utils::validate_type_for_device<T>(exec_q);
+
+    const T *src_tp = reinterpret_cast<const T *>(src_p);
+    T *dst_tp = reinterpret_cast<T *>(dst_p);
+
+    using IndexerT = dpctl::tensor::offset_utils::StridedIndexer;
+    const IndexerT src_indexer(nd, ssize_t(0), shape_and_strides);
+
+    constexpr std::uint32_t vec_sz = 4u;
+    constexpr std::uint32_t n_vecs = 2u;
+
+    using dpctl::tensor::kernels::alignment_utils::
+        disabled_sg_loadstore_wrapper_krn;
+    using dpctl::tensor::kernels::alignment_utils::is_aligned;
+    using dpctl::tensor::kernels::alignment_utils::required_alignment;
+
+    sycl::event copy_ev;
+    if (is_aligned<required_alignment>(dst_p)) {
+        constexpr bool enable_sg_load = true;
+        using KernelName =
+            as_contig_krn<T, IndexerT, vec_sz, n_vecs, enable_sg_load>;
+        copy_ev = submit_c_contiguous_copy<T, IndexerT, vec_sz, n_vecs,
+                                           enable_sg_load, KernelName>(
+            exec_q, nelems, src_tp, dst_tp, src_indexer, depends);
+    }
+    else {
+        constexpr bool disable_sg_load = false;
+        using InnerKernelName =
+            as_contig_krn<T, IndexerT, vec_sz, n_vecs, disable_sg_load>;
+        using KernelName = disabled_sg_loadstore_wrapper_krn<InnerKernelName>;
+        copy_ev = submit_c_contiguous_copy<T, IndexerT, vec_sz, n_vecs,
+                                           disable_sg_load, KernelName>(
+            exec_q, nelems, src_tp, dst_tp, src_indexer, depends);
+    }
 
     return copy_ev;
 }
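For reference, the launch geometry computed in submit_c_contiguous_copy rounds the preferred work-group size up to a multiple of the kernel's maximal sub-group size and then ceil-divides the element count by the per-group workload. A minimal sketch of that host-side arithmetic, assuming a maximal sub-group size of 32 and the default vec_sz = 4, n_vecs = 2 (the concrete numbers are illustrative only):

// launch_geometry_sketch.cpp -- standalone sketch of the launch arithmetic.
#include <cstddef>
#include <cstdint>
#include <iostream>

int main()
{
    constexpr std::size_t preferred_lws = 256;
    constexpr std::uint32_t vec_sz = 4u;
    constexpr std::uint32_t n_vecs = 2u;
    constexpr std::uint8_t nelems_per_wi = vec_sz * n_vecs;

    const std::size_t max_sg_size = 32; // assumed maximal sub-group size
    const std::size_t nelems = 1000003; // arbitrary example size

    // round preferred_lws up to a multiple of max_sg_size
    const std::size_t lws =
        ((preferred_lws + max_sg_size - 1) / max_sg_size) * max_sg_size;

    // each work-group copies lws * nelems_per_wi elements
    const std::size_t nelems_per_group = nelems_per_wi * lws;
    const std::size_t n_groups =
        (nelems + nelems_per_group - 1) / nelems_per_group;

    std::cout << "lws = " << lws << ", n_groups = " << n_groups << "\n";
    return 0;
}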