Skip to content

Commit 87c3acb

Browse files
committed
[unified] refactor type mapping
1 parent 5a4a5fd commit 87c3acb

File tree

5 files changed

+125
-75
lines changed

5 files changed

+125
-75
lines changed

common/cuda_hip/base/kernel_launch_solver.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

@@ -27,8 +27,7 @@ __global__ __launch_bounds__(default_block_size) void generic_kernel_2d_solver(
2727
if (row >= rows) {
2828
return;
2929
}
30-
fn(row, col,
31-
device_unpack_solver_impl<KernelArgs>::unpack(args, default_stride)...);
30+
fn(row, col, device_unpack(args, default_stride)...);
3231
}
3332

3433

common/unified/base/kernel_launch.hpp

Lines changed: 106 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,11 @@ using aliased_ptr = T*;
175175
* objects.
176176
*
177177
* @tparam ValueType the value type of the underlying matrix.
178+
* @tparam PtrWrapper the pointer wrapper template. By default, it's aliased_ptr (a plain `T*`), but it may
179+
* be set to restricted_ptr.
178180
*/
179-
template <typename ValueType, template <typename> typename PtrWrapper>
181+
template <typename ValueType,
182+
template <typename> typename PtrWrapper = aliased_ptr>
180183
struct matrix_accessor {
181184
PtrWrapper<ValueType> data;
182185
int64 stride;
@@ -202,9 +205,24 @@ struct matrix_accessor {
202205
};
203206

204207

208+
/**
209+
* Tag to signal that pointers should be annotated with `__restrict`
210+
*/
205211
struct restrict_tag {};
206212

207213

214+
/**
215+
* @internal
216+
* Adds a restrict annotation to an object.
217+
*
218+
* @note Can't be used for run_kernel_solver.
219+
*
220+
* @tparam T Type that should be annotated
221+
*
222+
* @param orig Original object
223+
*
224+
* @return The original object and a restrict_tag
225+
*/
208226
template <typename T>
209227
auto as_restrict(T&& orig) -> std::pair<T&&, restrict_tag>
210228
{
@@ -220,105 +238,139 @@ auto as_restrict(T&& orig) -> std::pair<T&&, restrict_tag>
220238
*
221239
* By default, it only maps std::complex to the corresponding device
222240
* representation of the complex type. There are specializations for dealing
223-
* with gko::array and gko::matrix::Dense (both const and mutable) that map them
241+
* with gko::array and gko::matrix::Dense that map them
224242
* to plain pointers or matrix_accessor objects.
225243
*
226-
* @tparam T the type being mapped. It will be used based on a
227-
* forwarding-reference, i.e. preserve references in the input
228-
* parameter, so special care must be taken to only return types that
229-
* can be passed to the device, i.e. (structs containing) device
230-
* pointers or values. This means that T will be either a r-value or
231-
* l-value reference.
244+
* @tparam T the underlying type being mapped. Any references or const
245+
* qualifiers have to be resolved before passing the type.
246+
* The distinction between const/mutable objects is done by
247+
* overloading the map_to_device function.
248+
* @tparam PtrWrapper the pointer wrapper template. By default, it's aliased_ptr (a plain `T*`), but it may
249+
* be set to restricted_ptr.
232250
*/
233251
template <typename T, template <typename> typename PtrWrapper = aliased_ptr>
234252
struct to_device_type_impl {
235-
using type = std::decay_t<device_type<T>>;
236-
static type map_to_device(T in) { return as_device_type(in); }
253+
static auto map_to_device(T in) -> device_type<T>
254+
{
255+
return as_device_type(in);
256+
}
237257
};
238258

239-
template <typename ValueType, template <typename> typename PtrWrapper>
240-
struct to_device_type_impl<matrix::Dense<ValueType>*&, PtrWrapper> {
241-
using type = matrix_accessor<device_type<ValueType>, PtrWrapper>;
242-
static type map_to_device(matrix::Dense<ValueType>* mtx)
259+
template <typename T, template <typename> typename PtrWrapper>
260+
struct to_device_type_impl<T*, PtrWrapper> {
261+
static auto map_to_device(T* in) -> PtrWrapper<device_type<T>>
262+
{
263+
return {as_device_type(in)};
264+
}
265+
static auto map_to_device(const T* in) -> PtrWrapper<const device_type<T>>
243266
{
244-
return to_device_type_impl<matrix::Dense<ValueType>* const&,
245-
PtrWrapper>::map_to_device(mtx);
267+
return {as_device_type(in)};
246268
}
247269
};
248270

249271
template <typename ValueType, template <typename> typename PtrWrapper>
250-
struct to_device_type_impl<matrix::Dense<ValueType>* const&, PtrWrapper> {
251-
using type = matrix_accessor<device_type<ValueType>, PtrWrapper>;
252-
static type map_to_device(matrix::Dense<ValueType>* mtx)
272+
struct to_device_type_impl<matrix::Dense<ValueType>*, PtrWrapper> {
273+
static auto map_to_device(matrix::Dense<ValueType>* mtx)
274+
-> matrix_accessor<device_type<ValueType>, PtrWrapper>
253275
{
254276
return {as_device_type(mtx->get_values()),
255277
static_cast<int64>(mtx->get_stride())};
256278
}
257-
};
258279

259-
template <typename ValueType, template <typename> typename PtrWrapper>
260-
struct to_device_type_impl<const matrix::Dense<ValueType>*&, PtrWrapper> {
261-
using type = matrix_accessor<const device_type<ValueType>, PtrWrapper>;
262-
static type map_to_device(const matrix::Dense<ValueType>* mtx)
280+
static auto map_to_device(const matrix::Dense<ValueType>* mtx)
281+
-> matrix_accessor<const device_type<ValueType>, PtrWrapper>
263282
{
264283
return {as_device_type(mtx->get_const_values()),
265284
static_cast<int64>(mtx->get_stride())};
266285
}
267286
};
268287

269288
template <typename ValueType, template <typename> typename PtrWrapper>
270-
struct to_device_type_impl<array<ValueType>&, PtrWrapper> {
271-
using type = PtrWrapper<device_type<ValueType>>;
272-
static type map_to_device(array<ValueType>& array)
289+
struct to_device_type_impl<array<ValueType>, PtrWrapper> {
290+
static auto map_to_device(array<ValueType>& array)
291+
-> PtrWrapper<device_type<ValueType>>
273292
{
274293
return {as_device_type(array.get_data())};
275294
}
276-
};
277295

278-
template <typename ValueType, template <typename> typename PtrWrapper>
279-
struct to_device_type_impl<const array<ValueType>&, PtrWrapper> {
280-
using type = PtrWrapper<const device_type<ValueType>>;
281-
static type map_to_device(const array<ValueType>& array)
296+
static auto map_to_device(const array<ValueType>& array)
297+
-> PtrWrapper<const device_type<ValueType>>
282298
{
283299
return {as_device_type(array.get_const_data())};
284300
}
285301
};
286302

287-
template <typename T, template <typename> typename PtrWrapper>
288-
struct to_device_type_impl<T*, PtrWrapper> {
289-
using type = PtrWrapper<device_type<T>>;
290-
static type map_to_device(T in) { return {as_device_type(in)}; }
303+
/**
304+
* Specialization for handling objects annotated by as_restrict.
305+
* It changes the pointer wrapper type to restricted_ptr.
306+
*/
307+
template <typename T>
308+
struct to_device_type_impl<std::pair<T, restrict_tag>, aliased_ptr> {
309+
template <typename U>
310+
static auto map_to_device(U&& in)
311+
{
312+
return to_device_type_impl<T, restricted_ptr>::map_to_device(in.first);
313+
}
291314
};
292315

293-
template <typename T, template <typename> typename PtrWrapper>
294-
struct to_device_type_impl<const T*, PtrWrapper> {
295-
using type = PtrWrapper<const device_type<T>>;
296-
static type map_to_device(T in) { return {as_device_type(in)}; }
316+
317+
namespace detail {
318+
319+
320+
/**
321+
* Similar to std::remove_cv_t except that it also removes the const from pointers,
322+
* i.e. `const T*` -> `T*`.
323+
*/
324+
template <typename T>
325+
struct aggressive_remove_const {
326+
using type = std::remove_cv_t<T>;
297327
};
298328

299329
template <typename T>
300-
struct to_device_type_impl<std::pair<T&, restrict_tag>&, aliased_ptr> {
301-
using type = typename to_device_type_impl<T&, restricted_ptr>::type;
302-
static type map_to_device(std::pair<T&, restrict_tag> in)
303-
{
304-
return to_device_type_impl<T&, restricted_ptr>::map_to_device(in.first);
305-
}
330+
struct aggressive_remove_const<const T*> {
331+
using type = T*;
306332
};
307333

334+
/**
335+
* Similar to std::decay, except that it also applies std::decay on the first
336+
* nesting of types, e.g. `std::pair<U&, Tag>` -> `std::pair<U, Tag>`.
337+
* This only resolves a single level of nesting.
338+
*/
308339
template <typename T>
309-
struct to_device_type_impl<std::pair<T, restrict_tag>&, aliased_ptr> {
310-
using type = typename to_device_type_impl<T, restricted_ptr>::type;
311-
static type map_to_device(std::pair<T, restrict_tag> in)
312-
{
313-
return to_device_type_impl<T, restricted_ptr>::map_to_device(in.first);
314-
}
340+
struct nested_decay;
341+
342+
/**
343+
* Helper type for nested_decay.
344+
* This is necessary, since the references in the top-level type have to be
345+
* removed, before std::decay may be applied to the nested type.
346+
*/
347+
template <typename T>
348+
struct nested_decay_inner {
349+
using type = typename aggressive_remove_const<std::decay_t<T>>::type;
350+
};
351+
352+
template <typename T, typename Tag>
353+
struct nested_decay_inner<std::pair<T, Tag>> {
354+
using type = std::pair<typename nested_decay_inner<T>::type, Tag>;
315355
};
316356

357+
template <typename T>
358+
struct nested_decay {
359+
using type = typename nested_decay_inner<std::decay_t<T>>::type;
360+
};
361+
362+
363+
} // namespace detail
364+
365+
366+
template <typename T>
367+
using to_device_type =
368+
to_device_type_impl<typename detail::nested_decay<T>::type>;
317369

318370
template <typename T>
319-
typename to_device_type_impl<T>::type map_to_device(T&& param)
371+
auto map_to_device(T&& param)
320372
{
321-
return to_device_type_impl<T>::map_to_device(param);
373+
return to_device_type<T>::map_to_device(std::forward<T>(param));
322374
}
323375

324376

common/unified/base/kernel_launch_solver.hpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ struct device_unpack_solver_impl {
4343

4444
template <typename ValueType>
4545
struct device_unpack_solver_impl<default_stride_dense_wrapper<ValueType>> {
46-
using type = matrix_accessor<ValueType, aliased_ptr>;
46+
using type = matrix_accessor<ValueType>;
4747
static GKO_INLINE GKO_ATTRIBUTES type
4848
unpack(default_stride_dense_wrapper<ValueType> param, int64 default_stride)
4949
{
@@ -52,6 +52,16 @@ struct device_unpack_solver_impl<default_stride_dense_wrapper<ValueType>> {
5252
};
5353

5454

55+
template <typename T>
56+
GKO_INLINE GKO_ATTRIBUTES auto device_unpack(T&& param, int64 default_stride)
57+
{
58+
using device_type =
59+
std::decay_t<decltype(map_to_device(std::forward<T>(param)))>;
60+
return device_unpack_solver_impl<device_type>::unpack(
61+
std::forward<T>(param), default_stride);
62+
}
63+
64+
5565
/**
5666
* @internal
5767
* Wraps the given matrix in a wrapper signifying that it has the default stride

dpcpp/base/kernel_launch_solver.dp.hpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

@@ -22,9 +22,7 @@ void generic_kernel_2d_solver(sycl::handler& cgh, int64 rows, int64 cols,
2222
[=](sycl::id<1> idx) {
2323
auto row = static_cast<int64>(idx[0] / cols);
2424
auto col = static_cast<int64>(idx[0] % cols);
25-
fn(row, col,
26-
device_unpack_solver_impl<KernelArgs>::unpack(
27-
args, default_stride)...);
25+
fn(row, col, device_unpack(args, default_stride)...);
2826
});
2927
}
3028

omp/base/kernel_launch_solver.hpp

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

@@ -13,23 +13,14 @@ namespace kernels {
1313
namespace omp {
1414

1515

16-
template <typename T>
17-
typename device_unpack_solver_impl<typename to_device_type_impl<T>::type>::type
18-
map_to_device_solver(T&& param, int64 default_stride)
19-
{
20-
return device_unpack_solver_impl<typename to_device_type_impl<T>::type>::
21-
unpack(to_device_type_impl<T>::map_to_device(param), default_stride);
22-
}
23-
24-
2516
template <typename KernelFunction, typename... KernelArgs>
2617
void run_kernel_solver(std::shared_ptr<const OmpExecutor> exec,
2718
KernelFunction fn, dim<2> size, size_type default_stride,
2819
KernelArgs&&... args)
2920
{
30-
run_kernel_impl(
31-
exec, fn, size,
32-
map_to_device_solver(args, static_cast<int64>(default_stride))...);
21+
run_kernel_impl(exec, fn, size,
22+
device_unpack(map_to_device(std::forward<KernelArgs>(args)),
23+
static_cast<int64>(default_stride))...);
3324
}
3425

3526

0 commit comments

Comments
 (0)