@@ -175,8 +175,11 @@ using aliased_ptr = T*;
175175 * objects.
176176 *
177177 * @tparam ValueType the value type of the underlying matrix.
178+ * @tparam PtrWrapper the pointer type. By default, it's just `T*`, but it may
179+ * be set to restricted_ptr.
178180 */
179- template <typename ValueType, template <typename > typename PtrWrapper>
181+ template <typename ValueType,
182+ template <typename > typename PtrWrapper = aliased_ptr>
180183struct matrix_accessor {
181184 PtrWrapper<ValueType> data;
182185 int64 stride;
@@ -202,9 +205,24 @@ struct matrix_accessor {
202205};
203206
204207
208+ /**
209+ * Tag to signal that pointers should be annotated with `__restrict`
210+ */
205211struct restrict_tag {};
206212
207213
214+ /**
215+ * @internal
216+ * Adds a restrict annotation to an object.
217+ *
218+ * @note Can't be used for run_kernel_solver.
219+ *
220+ * @tparam T Type that should be annotated
221+ *
222+ * @param orig Original object
223+ *
224+ * @return The original object and a restrict_tag
225+ */
208226template <typename T>
209227auto as_restrict (T&& orig) -> std::pair<T&&, restrict_tag>
210228{
@@ -220,105 +238,139 @@ auto as_restrict(T&& orig) -> std::pair<T&&, restrict_tag>
220238 *
221239 * By default, it only maps std::complex to the corresponding device
222240 * representation of the complex type. There are specializations for dealing
223- * with gko::array and gko::matrix::Dense (both const and mutable) that map them
241+ * with gko::array and gko::matrix::Dense that map them
224242 * to plain pointers or matrix_accessor objects.
225243 *
226- * @tparam T the type being mapped. It will be used based on a
227- * forwarding-reference, i.e. preserve references in the input
228- * parameter, so special care must be taken to only return types that
229- * can be passed to the device, i.e. (structs containing) device
230- * pointers or values. This means that T will be either a r-value or
231- * l-value reference .
244+ * @tparam T the underlying type being mapped. Any references or const
245+ * qualifiers have to be resolved before passing the type.
246+ * The distinction between const/mutable objects is done by
247+ * overloading the map_to_device function.
248+ * @tparam PtrWrapper the pointer type. By default, it's just `T*`, but it may
249+ * be set to restricted_ptr.
232250 */
233251template <typename T, template <typename > typename PtrWrapper = aliased_ptr>
234252struct to_device_type_impl {
235- using type = std::decay_t <device_type<T>>;
236- static type map_to_device (T in) { return as_device_type (in); }
253+ static auto map_to_device (T in) -> device_type<T>
254+ {
255+ return as_device_type (in);
256+ }
237257};
238258
239- template <typename ValueType, template <typename > typename PtrWrapper>
240- struct to_device_type_impl <matrix::Dense<ValueType>*&, PtrWrapper> {
241- using type = matrix_accessor<device_type<ValueType>, PtrWrapper>;
242- static type map_to_device (matrix::Dense<ValueType>* mtx)
259+ template <typename T, template <typename > typename PtrWrapper>
260+ struct to_device_type_impl <T*, PtrWrapper> {
261+ static auto map_to_device (T* in) -> PtrWrapper<device_type<T>>
262+ {
263+ return {as_device_type (in)};
264+ }
265+ static auto map_to_device (const T* in) -> PtrWrapper<const device_type<T>>
243266 {
244- return to_device_type_impl<matrix::Dense<ValueType>* const &,
245- PtrWrapper>::map_to_device (mtx);
267+ return {as_device_type (in)};
246268 }
247269};
248270
249271template <typename ValueType, template <typename > typename PtrWrapper>
250- struct to_device_type_impl <matrix::Dense<ValueType>* const & , PtrWrapper> {
251- using type = matrix_accessor<device_type< ValueType>, PtrWrapper>;
252- static type map_to_device (matrix::Dense< ValueType>* mtx)
272+ struct to_device_type_impl <matrix::Dense<ValueType>*, PtrWrapper> {
273+ static auto map_to_device (matrix::Dense< ValueType>* mtx)
274+ -> matrix_accessor<device_type< ValueType>, PtrWrapper>
253275 {
254276 return {as_device_type (mtx->get_values ()),
255277 static_cast <int64>(mtx->get_stride ())};
256278 }
257- };
258279
259- template <typename ValueType, template <typename > typename PtrWrapper>
260- struct to_device_type_impl <const matrix::Dense<ValueType>*&, PtrWrapper> {
261- using type = matrix_accessor<const device_type<ValueType>, PtrWrapper>;
262- static type map_to_device (const matrix::Dense<ValueType>* mtx)
280+ static auto map_to_device (const matrix::Dense<ValueType>* mtx)
281+ -> matrix_accessor<const device_type<ValueType>, PtrWrapper>
263282 {
264283 return {as_device_type (mtx->get_const_values ()),
265284 static_cast <int64>(mtx->get_stride ())};
266285 }
267286};
268287
269288template <typename ValueType, template <typename > typename PtrWrapper>
270- struct to_device_type_impl <array<ValueType>& , PtrWrapper> {
271- using type = PtrWrapper<device_type< ValueType>>;
272- static type map_to_device (array< ValueType>& array)
289+ struct to_device_type_impl <array<ValueType>, PtrWrapper> {
290+ static auto map_to_device (array< ValueType>& array)
291+ -> PtrWrapper<device_type< ValueType>>
273292 {
274293 return {as_device_type (array.get_data ())};
275294 }
276- };
277295
278- template <typename ValueType, template <typename > typename PtrWrapper>
279- struct to_device_type_impl <const array<ValueType>&, PtrWrapper> {
280- using type = PtrWrapper<const device_type<ValueType>>;
281- static type map_to_device (const array<ValueType>& array)
296+ static auto map_to_device (const array<ValueType>& array)
297+ -> PtrWrapper<const device_type<ValueType>>
282298 {
283299 return {as_device_type (array.get_const_data ())};
284300 }
285301};
286302
287- template <typename T, template <typename > typename PtrWrapper>
288- struct to_device_type_impl <T*, PtrWrapper> {
289- using type = PtrWrapper<device_type<T>>;
290- static type map_to_device (T in) { return {as_device_type (in)}; }
303+ /**
304+ * Specialization for handling objects annotated by as_restrict.
305+ * It changes the pointer wrapper type to restricted_ptr.
306+ */
307+ template <typename T>
308+ struct to_device_type_impl <std::pair<T, restrict_tag>, aliased_ptr> {
309+ template <typename U>
310+ static auto map_to_device (U&& in)
311+ {
312+ return to_device_type_impl<T, restricted_ptr>::map_to_device (in.first );
313+ }
291314};
292315
293- template <typename T, template <typename > typename PtrWrapper>
294- struct to_device_type_impl <const T*, PtrWrapper> {
295- using type = PtrWrapper<const device_type<T>>;
296- static type map_to_device (T in) { return {as_device_type (in)}; }
316+
317+ namespace detail {
318+
319+
320+ /**
321+ * Similar to std::remove_cv_t, except that it also removes the const from the
322+ * pointee of a pointer, i.e. `const T*` -> `T*`.
323+ */
324+ template <typename T>
325+ struct aggressive_remove_const {
326+ using type = std::remove_cv_t <T>;
297327};
298328
299329template <typename T>
300- struct to_device_type_impl <std::pair<T&, restrict_tag>&, aliased_ptr> {
301- using type = typename to_device_type_impl<T&, restricted_ptr>::type;
302- static type map_to_device (std::pair<T&, restrict_tag> in)
303- {
304- return to_device_type_impl<T&, restricted_ptr>::map_to_device (in.first );
305- }
330+ struct aggressive_remove_const <const T*> {
331+ using type = T*;
306332};
307333
334+ /**
335+ * Similar to std::decay, except that it also applies std::decay on the first
336+ * nesting of types, i.e. `T<U&>` -> `T<U>`.
337+ * This only resolves a single level of nesting.
338+ */
308339template <typename T>
309- struct to_device_type_impl <std::pair<T, restrict_tag>&, aliased_ptr> {
310- using type = typename to_device_type_impl<T, restricted_ptr>::type;
311- static type map_to_device (std::pair<T, restrict_tag> in)
312- {
313- return to_device_type_impl<T, restricted_ptr>::map_to_device (in.first );
314- }
340+ struct nested_decay ;
341+
342+ /**
343+ * Helper type for nested_decay.
344+ * This is necessary since the references in the top-level type have to be
345+ * removed before std::decay can be applied to the nested type.
346+ */
347+ template <typename T>
348+ struct nested_decay_inner {
349+ using type = typename aggressive_remove_const<std::decay_t <T>>::type;
350+ };
351+
352+ template <typename T, typename Tag>
353+ struct nested_decay_inner <std::pair<T, Tag>> {
354+ using type = std::pair<typename nested_decay_inner<T>::type, Tag>;
315355};
316356
357+ template <typename T>
358+ struct nested_decay {
359+ using type = typename nested_decay_inner<std::decay_t <T>>::type;
360+ };
361+
362+
363+ } // namespace detail
364+
365+
366+ template <typename T>
367+ using to_device_type =
368+ to_device_type_impl<typename detail::nested_decay<T>::type>;
317369
318370template <typename T>
319- typename to_device_type_impl<T>::type map_to_device (T&& param)
371+ auto map_to_device (T&& param)
320372{
321- return to_device_type_impl <T>::map_to_device (param);
373+ return to_device_type <T>::map_to_device (std::forward<T>( param) );
322374}
323375
324376
0 commit comments