From b6b0bb2dc9888453bb8e7097e7d9523cae39bd25 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 13 Sep 2021 11:13:11 -0700 Subject: [PATCH 1/8] Allow cub::DeviceRadixSort and cub::DeviceSegmentedRadixSort to use iterator as input --- cub/device/device_radix_sort.cuh | 165 ++++++++++------- cub/device/device_segmented_radix_sort.cuh | 140 ++++++++------- cub/device/dispatch/dispatch_radix_sort.cuh | 186 +++++++++++++------- 3 files changed, 299 insertions(+), 192 deletions(-) diff --git a/cub/device/device_radix_sort.cuh b/cub/device/device_radix_sort.cuh index 4d540568a1..0422b00ff6 100644 --- a/cub/device/device_radix_sort.cuh +++ b/cub/device/device_radix_sort.cuh @@ -179,23 +179,35 @@ struct DeviceRadixSort * * \endcode * - * \tparam KeyT [inferred] KeyT type - * \tparam ValueT [inferred] ValueT type + * \tparam KeyInputIteratorT is a model of Random Access Iterator, + * \p KeyInputIteratorT is mutable, and \p KeyInputIteratorT's \c value_type is + * a model of LessThan Comparable, + * and the ordering relation on \p KeyInputIteratorT's \c value_type is a strict weak ordering, as defined in the + * LessThan Comparable requirements. + * \tparam ValueInputIteratorT is a model of Random Access Iterator. + * \tparam KeyIteratorT is a model of Random Access Iterator, + * \p KeyIteratorT is mutable, and \p KeyIteratorT's \c value_type is + * a model of LessThan Comparable, + * and the ordering relation on \p KeyIteratorT's \c value_type is a strict weak ordering, as defined in the + * LessThan Comparable requirements. + * \tparam ValueIteratorT is a model of Random Access Iterator. */ - template < - typename KeyT, - typename ValueT> + template CUB_RUNTIME_FUNCTION static cudaError_t SortPairs( void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation - const KeyT *d_keys_in, ///< [in] Pointer to the input data of key data to sort - KeyT *d_keys_out, ///< [out] Pointer to the sorted output sequence of key data - const ValueT *d_values_in, ///< [in] Pointer to the corresponding input sequence of associated value items - ValueT *d_values_out, ///< [out] Pointer to the correspondingly-reordered output sequence of associated value items + KeyInputIteratorT d_keys_in, ///< [in] Iterator to the input data of key data to sort + KeyIteratorT d_keys_out, ///< [out] Iterator to the sorted output sequence of key data + ValueInputIteratorT d_values_in, ///< [in] Iterator to the corresponding input sequence of associated value items + ValueIteratorT d_values_out, ///< [out] Iterator to the correspondingly-reordered output sequence of associated value items int num_items, ///< [in] Number of items to sort int begin_bit = 0, ///< [in] [optional] The least-significant bit index (inclusive) needed for key comparison - int end_bit = sizeof(KeyT) * 8, ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8) + int end_bit ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8) + = sizeof(typename std::iterator_traits::value_type) * 8, cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. { @@ -207,14 +219,16 @@ struct DeviceRadixSort // create a new double-buffer internally when the `is_overwrite_ok` flag // is not set. constexpr bool is_overwrite_okay = false; - DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out); - DoubleBuffer d_values(const_cast(d_values_in), d_values_out); - return DispatchRadixSort::Dispatch( + return DispatchRadixSort< + false, KeyInputIteratorT, KeyIteratorT, ValueInputIteratorT, ValueIteratorT, OffsetT + >::Dispatch( d_temp_storage, temp_storage_bytes, - d_keys, - d_values, + d_keys_in, + d_keys_out, + d_values_in, + d_values_out, num_items, begin_bit, end_bit, @@ -314,11 +328,15 @@ struct DeviceRadixSort constexpr bool is_overwrite_okay = true; - return DispatchRadixSort::Dispatch( + return DispatchRadixSort< + false , const KeyT *, KeyT *, const ValueT *, ValueT *, OffsetT + >::Dispatch( d_temp_storage, temp_storage_bytes, - d_keys, - d_values, + d_keys.Current(), + d_keys.Alternate(), + d_values.Current(), + d_values.Alternate(), num_items, begin_bit, end_bit, @@ -384,20 +402,22 @@ struct DeviceRadixSort * \tparam KeyT [inferred] KeyT type * \tparam ValueT [inferred] ValueT type */ - template < - typename KeyT, - typename ValueT> + template CUB_RUNTIME_FUNCTION static cudaError_t SortPairsDescending( void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation - const KeyT *d_keys_in, ///< [in] Pointer to the input data of key data to sort - KeyT *d_keys_out, ///< [out] Pointer to the sorted output sequence of key data - const ValueT *d_values_in, ///< [in] Pointer to the corresponding input sequence of associated value items - ValueT *d_values_out, ///< [out] Pointer to the correspondingly-reordered output sequence of associated value items + KeyInputIteratorT d_keys_in, ///< [in] Iterator to the input data of key data to sort + KeyIteratorT d_keys_out, ///< [out] Iterator to the sorted output sequence of key data + ValueInputIteratorT d_values_in, ///< [in] Iterator to the corresponding input sequence of associated value items + ValueIteratorT d_values_out, ///< [out] Iterator to the correspondingly-reordered output sequence of associated value items int num_items, ///< [in] Number of items to sort int begin_bit = 0, ///< [in] [optional] The least-significant bit index (inclusive) needed for key comparison - int end_bit = sizeof(KeyT) * 8, ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8) + int end_bit ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8) + = sizeof(typename std::iterator_traits::value_type) * 8, cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. { @@ -409,14 +429,16 @@ struct DeviceRadixSort // create a new double-buffer internally when the `is_overwrite_ok` flag // is not set. constexpr bool is_overwrite_okay = false; - DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out); - DoubleBuffer d_values(const_cast(d_values_in), d_values_out); - return DispatchRadixSort::Dispatch( + return DispatchRadixSort< + true, KeyInputIteratorT, KeyIteratorT, ValueInputIteratorT, ValueIteratorT, OffsetT + >::Dispatch( d_temp_storage, temp_storage_bytes, - d_keys, - d_values, + d_keys_in, + d_keys_out, + d_values_in, + d_values_out, num_items, begin_bit, end_bit, @@ -511,11 +533,15 @@ struct DeviceRadixSort constexpr bool is_overwrite_okay = true; - return DispatchRadixSort::Dispatch( + return DispatchRadixSort< + true, const KeyT *, KeyT *, const ValueT *, ValueT *, OffsetT + >::Dispatch( d_temp_storage, temp_storage_bytes, - d_keys, - d_values, + d_keys.Current(), + d_keys.Alternate(), + d_values.Current(), + d_values.Alternate(), num_items, begin_bit, end_bit, @@ -583,16 +609,18 @@ struct DeviceRadixSort * * \tparam KeyT [inferred] KeyT type */ - template + template CUB_RUNTIME_FUNCTION static cudaError_t SortKeys( void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation - const KeyT *d_keys_in, ///< [in] Pointer to the input data of key data to sort - KeyT *d_keys_out, ///< [out] Pointer to the sorted output sequence of key data + KeyInputIteratorT d_keys_in, ///< [in] Iterator to the input data of key data to sort + KeyIteratorT d_keys_out, ///< [out] Iterator to the sorted output sequence of key data int num_items, ///< [in] Number of items to sort int begin_bit = 0, ///< [in] [optional] The least-significant bit index (inclusive) needed for key comparison - int end_bit = sizeof(KeyT) * 8, ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8) + int end_bit ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8) + = sizeof(typename std::iterator_traits::value_type) * 8, cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. { @@ -604,15 +632,16 @@ struct DeviceRadixSort // create a new double-buffer internally when the `is_overwrite_ok` flag // is not set. constexpr bool is_overwrite_okay = false; - DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out); - // Null value type - DoubleBuffer d_values; - return DispatchRadixSort::Dispatch( + return DispatchRadixSort< + false, KeyInputIteratorT, KeyIteratorT, const NullType *, NullType *, OffsetT + >::Dispatch( d_temp_storage, temp_storage_bytes, - d_keys, - d_values, + d_keys_in, + d_keys_out, + nullptr, + nullptr, num_items, begin_bit, end_bit, @@ -698,14 +727,15 @@ struct DeviceRadixSort constexpr bool is_overwrite_okay = true; - // Null value type - DoubleBuffer d_values; - - return DispatchRadixSort::Dispatch( + return DispatchRadixSort< + false, const KeyT *, KeyT *, const NullType *, NullType *, OffsetT + >::Dispatch( d_temp_storage, temp_storage_bytes, - d_keys, - d_values, + d_keys.Current(), + d_keys.Alternate(), + nullptr, + nullptr, num_items, begin_bit, end_bit, @@ -764,16 +794,18 @@ struct DeviceRadixSort * * \tparam KeyT [inferred] KeyT type */ - template + template CUB_RUNTIME_FUNCTION static cudaError_t SortKeysDescending( void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation - const KeyT *d_keys_in, ///< [in] Pointer to the input data of key data to sort - KeyT *d_keys_out, ///< [out] Pointer to the sorted output sequence of key data + KeyInputIteratorT d_keys_in, ///< [in] Iterator to the input data of key data to sort + KeyIteratorT d_keys_out, ///< [out] Iterator to the sorted output sequence of key data int num_items, ///< [in] Number of items to sort int begin_bit = 0, ///< [in] [optional] The least-significant bit index (inclusive) needed for key comparison - int end_bit = sizeof(KeyT) * 8, ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8) + int end_bit ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8) + = sizeof(typename std::iterator_traits::value_type) * 8, cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. { @@ -785,14 +817,16 @@ struct DeviceRadixSort // create a new double-buffer internally when the `is_overwrite_ok` flag // is not set. constexpr bool is_overwrite_okay = false; - DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out); - DoubleBuffer d_values; - return DispatchRadixSort::Dispatch( + return DispatchRadixSort< + true, KeyInputIteratorT, KeyIteratorT, const NullType *, NullType *, OffsetT + >::Dispatch( d_temp_storage, temp_storage_bytes, - d_keys, - d_values, + d_keys_in, + d_keys_out, + nullptr, + nullptr, num_items, begin_bit, end_bit, @@ -874,14 +908,15 @@ struct DeviceRadixSort constexpr bool is_overwrite_okay = true; - // Null value type - DoubleBuffer d_values; - - return DispatchRadixSort::Dispatch( + return DispatchRadixSort< + true, const KeyT *, KeyT *, const NullType *, NullType *, OffsetT + >::Dispatch( d_temp_storage, temp_storage_bytes, - d_keys, - d_values, + d_keys.Current(), + d_keys.Alternate(), + nullptr, + nullptr, num_items, begin_bit, end_bit, diff --git a/cub/device/device_segmented_radix_sort.cuh b/cub/device/device_segmented_radix_sort.cuh index 30d3028875..87a86a3fd3 100644 --- a/cub/device/device_segmented_radix_sort.cuh +++ b/cub/device/device_segmented_radix_sort.cuh @@ -128,18 +128,20 @@ struct DeviceSegmentedRadixSort * \tparam EndOffsetIteratorT [inferred] Random-access input iterator type for reading segment ending offsets \iterator */ template < - typename KeyT, - typename ValueT, + typename KeyInputIteratorT, + typename KeyIteratorT, + typename ValueInputIteratorT, + typename ValueIteratorT, typename BeginOffsetIteratorT, typename EndOffsetIteratorT> CUB_RUNTIME_FUNCTION static cudaError_t SortPairs( void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation - const KeyT *d_keys_in, ///< [in] %Device-accessible pointer to the input data of key data to sort - KeyT *d_keys_out, ///< [out] %Device-accessible pointer to the sorted output sequence of key data - const ValueT *d_values_in, ///< [in] %Device-accessible pointer to the corresponding input sequence of associated value items - ValueT *d_values_out, ///< [out] %Device-accessible pointer to the correspondingly-reordered output sequence of associated value items + KeyInputIteratorT d_keys_in, ///< [in] Iterator to the input data of key data to sort + KeyIteratorT d_keys_out, ///< [out] Iterator to the sorted output sequence of key data + ValueInputIteratorT d_values_in, ///< [in] Iterator to the corresponding input sequence of associated value items + ValueIteratorT d_values_out, ///< [out] Iterator to the correspondingly-reordered output sequence of associated value items int num_items, ///< [in] The total number of items to sort (across all segments) int num_segments, ///< [in] The number of segments that comprise the sorting data BeginOffsetIteratorT d_begin_offsets, ///< [in] Random-access input iterator to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_* @@ -155,11 +157,15 @@ struct DeviceSegmentedRadixSort DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out); DoubleBuffer d_values(const_cast(d_values_in), d_values_out); - return DispatchSegmentedRadixSort::Dispatch( + return DispatchSegmentedRadixSort< + false, KeyInputIteratorT, KeyIteratorT, ValueInputIteratorT, ValueIteratorT, BeginOffsetIteratorT, EndOffsetIteratorT, OffsetT + >::Dispatch( d_temp_storage, temp_storage_bytes, - d_keys, - d_values, + d_keys_in, + d_keys_out, + d_values_in, + d_values_out, num_items, num_segments, d_begin_offsets, @@ -261,11 +267,15 @@ struct DeviceSegmentedRadixSort // Signed integer type for global offsets typedef int OffsetT; - return DispatchSegmentedRadixSort::Dispatch( + return DispatchSegmentedRadixSort::Dispatch( d_temp_storage, temp_storage_bytes, - d_keys, - d_values, + d_keys.Current(), + d_keys.Alternate(), + d_values.Current(), + d_values.Alternate(), num_items, num_segments, d_begin_offsets, @@ -334,18 +344,20 @@ struct DeviceSegmentedRadixSort * \tparam EndOffsetIteratorT [inferred] Random-access input iterator type for reading segment ending offsets \iterator */ template < - typename KeyT, - typename ValueT, + typename KeyInputIteratorT, + typename KeyIteratorT, + typename ValueInputIteratorT, + typename ValueIteratorT, typename BeginOffsetIteratorT, typename EndOffsetIteratorT> CUB_RUNTIME_FUNCTION static cudaError_t SortPairsDescending( void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation - const KeyT *d_keys_in, ///< [in] %Device-accessible pointer to the input data of key data to sort - KeyT *d_keys_out, ///< [out] %Device-accessible pointer to the sorted output sequence of key data - const ValueT *d_values_in, ///< [in] %Device-accessible pointer to the corresponding input sequence of associated value items - ValueT *d_values_out, ///< [out] %Device-accessible pointer to the correspondingly-reordered output sequence of associated value items + KeyInputIteratorT d_keys_in, ///< [in] Iterator to the input data of key data to sort + KeyIteratorT d_keys_out, ///< [out] Iterator to the sorted output sequence of key data + ValueInputIteratorT d_values_in, ///< [in] Iterator to the corresponding input sequence of associated value items + ValueIteratorT d_values_out, ///< [out] Iterator to the correspondingly-reordered output sequence of associated value items int num_items, ///< [in] The total number of items to sort (across all segments) int num_segments, ///< [in] The number of segments that comprise the sorting data BeginOffsetIteratorT d_begin_offsets, ///< [in] Random-access input iterator to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_* @@ -358,14 +370,15 @@ struct DeviceSegmentedRadixSort // Signed integer type for global offsets typedef int OffsetT; - DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out); - DoubleBuffer d_values(const_cast(d_values_in), d_values_out); - - return DispatchSegmentedRadixSort::Dispatch( + return DispatchSegmentedRadixSort< + true, KeyInputIteratorT, KeyIteratorT, ValueIteratorT, ValueIteratorT, BeginOffsetIteratorT, EndOffsetIteratorT, OffsetT + >::Dispatch( d_temp_storage, temp_storage_bytes, - d_keys, - d_values, + d_keys_in, + d_keys_out, + d_values_in, + d_values_out, num_items, num_segments, d_begin_offsets, @@ -467,11 +480,15 @@ struct DeviceSegmentedRadixSort // Signed integer type for global offsets typedef int OffsetT; - return DispatchSegmentedRadixSort::Dispatch( + return DispatchSegmentedRadixSort< + true, const KeyT *, KeyT *, const ValueT *, ValueT *, BeginOffsetIteratorT, EndOffsetIteratorT, OffsetT + >::Dispatch( d_temp_storage, temp_storage_bytes, - d_keys, - d_values, + d_keys.Current(), + d_keys.Alternate(), + d_values.Current(), + d_values.Alternate(), num_items, num_segments, d_begin_offsets, @@ -540,15 +557,16 @@ struct DeviceSegmentedRadixSort * \tparam EndOffsetIteratorT [inferred] Random-access input iterator type for reading segment ending offsets \iterator */ template < - typename KeyT, + typename KeyInputIteratorT, + typename KeyIteratorT, typename BeginOffsetIteratorT, typename EndOffsetIteratorT> CUB_RUNTIME_FUNCTION static cudaError_t SortKeys( void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation - const KeyT *d_keys_in, ///< [in] %Device-accessible pointer to the input data of key data to sort - KeyT *d_keys_out, ///< [out] %Device-accessible pointer to the sorted output sequence of key data + KeyInputIteratorT d_keys_in, ///< [in] Iterator to the input data of key data to sort + KeyIteratorT d_keys_out, ///< [out] Iterator to the sorted output sequence of key data int num_items, ///< [in] The total number of items to sort (across all segments) int num_segments, ///< [in] The number of segments that comprise the sorting data BeginOffsetIteratorT d_begin_offsets, ///< [in] Random-access input iterator to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_* @@ -561,15 +579,15 @@ struct DeviceSegmentedRadixSort // Signed integer type for global offsets typedef int OffsetT; - // Null value type - DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out); - DoubleBuffer d_values; - - return DispatchSegmentedRadixSort::Dispatch( + return DispatchSegmentedRadixSort< + false, const KeyT *, KeyT *, const NullType *, NullType *, BeginOffsetIteratorT, EndOffsetIteratorT, OffsetT + >::Dispatch( d_temp_storage, temp_storage_bytes, - d_keys, - d_values, + d_keys_in, + d_keys_out, + nullptr, + nullptr, num_items, num_segments, d_begin_offsets, @@ -661,14 +679,15 @@ struct DeviceSegmentedRadixSort // Signed integer type for global offsets typedef int OffsetT; - // Null value type - DoubleBuffer d_values; - - return DispatchSegmentedRadixSort::Dispatch( + return DispatchSegmentedRadixSort< + false, const KeyT *, KeyT *, const NullType *, NullType *, BeginOffsetIteratorT, EndOffsetIteratorT, OffsetT + >::Dispatch( d_temp_storage, temp_storage_bytes, - d_keys, - d_values, + d_keys.Current(), + d_keys.Alternate(), + nullptr, + nullptr, num_items, num_segments, d_begin_offsets, @@ -732,15 +751,16 @@ struct DeviceSegmentedRadixSort * \tparam EndOffsetIteratorT [inferred] Random-access input iterator type for reading segment ending offsets \iterator */ template < - typename KeyT, + typename KeyInputIteratorT, + typename KeyIteratorT, typename BeginOffsetIteratorT, typename EndOffsetIteratorT> CUB_RUNTIME_FUNCTION static cudaError_t SortKeysDescending( void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation - const KeyT *d_keys_in, ///< [in] %Device-accessible pointer to the input data of key data to sort - KeyT *d_keys_out, ///< [out] %Device-accessible pointer to the sorted output sequence of key data + KeyInputIteratorT d_keys_in, ///< [in] Iterator to the input data of key data to sort + KeyIteratorT d_keys_out, ///< [out] Iterator to the sorted output sequence of key data int num_items, ///< [in] The total number of items to sort (across all segments) int num_segments, ///< [in] The number of segments that comprise the sorting data BeginOffsetIteratorT d_begin_offsets, ///< [in] Random-access input iterator to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_* @@ -753,14 +773,15 @@ struct DeviceSegmentedRadixSort // Signed integer type for global offsets typedef int OffsetT; - DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out); - DoubleBuffer d_values; - - return DispatchSegmentedRadixSort::Dispatch( + return DispatchSegmentedRadixSort< + true, KeyInputIteratorT, KeyIteratorT, const NullType *, NullType *, BeginOffsetIteratorT, EndOffsetIteratorT, OffsetT + >::Dispatch( d_temp_storage, temp_storage_bytes, - d_keys, - d_values, + d_keys_in, + d_keys_out, + nullptr, + nullptr, num_items, num_segments, d_begin_offsets, @@ -842,7 +863,7 @@ struct DeviceSegmentedRadixSort DoubleBuffer &d_keys, ///< [in,out] Reference to the double-buffer of keys whose "current" device-accessible buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys int num_items, ///< [in] The total number of items to sort (across all segments) int num_segments, ///< [in] The number of segments that comprise the sorting data - BeginOffsetIteratorT d_begin_offsets, ///< [in] Random-access input iterator to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_* + BeginOffsetIteratorT d_begin_offsets, ///< [in] Random-access input iterator to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_* EndOffsetIteratorT d_end_offsets, ///< [in] Random-access input iterator to the sequence of ending offsets of length \p num_segments, such that d_end_offsets[i]-1 is the last element of the ith data segment in d_keys_* and d_values_*. If d_end_offsets[i]-1 <= d_begin_offsets[i], the ith is considered empty. int begin_bit = 0, ///< [in] [optional] The least-significant bit index (inclusive) needed for key comparison int end_bit = sizeof(KeyT) * 8, ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8) @@ -852,14 +873,15 @@ struct DeviceSegmentedRadixSort // Signed integer type for global offsets typedef int OffsetT; - // Null value type - DoubleBuffer d_values; - - return DispatchSegmentedRadixSort::Dispatch( + return DispatchSegmentedRadixSort< + true, const KeyT *, KeyT *, const NullType *, NullType *, BeginOffsetIteratorT, EndOffsetIteratorT, OffsetT + >::Dispatch( d_temp_storage, temp_storage_bytes, - d_keys, - d_values, + d_keys.Current(), + d_keys.Alternate(), + nullptr, + nullptr, num_items, num_segments, d_begin_offsets, diff --git a/cub/device/dispatch/dispatch_radix_sort.cuh b/cub/device/dispatch/dispatch_radix_sort.cuh index f48371e7c8..62a64c6cda 100644 --- a/cub/device/dispatch/dispatch_radix_sort.cuh +++ b/cub/device/dispatch/dispatch_radix_sort.cuh @@ -73,19 +73,20 @@ template < typename ChainedPolicyT, ///< Chained tuning policy bool ALT_DIGIT_BITS, ///< Whether or not to use the alternate (lower-bits) policy bool IS_DESCENDING, ///< Whether or not the sorted-order is high-to-low - typename KeyT, ///< Key type + typename KeyInputIteratorT, ///< Input key iterator type typename OffsetT> ///< Signed integer type for global offsets __launch_bounds__ (int((ALT_DIGIT_BITS) ? int(ChainedPolicyT::ActivePolicy::AltUpsweepPolicy::BLOCK_THREADS) : int(ChainedPolicyT::ActivePolicy::UpsweepPolicy::BLOCK_THREADS))) __global__ void DeviceRadixSortUpsweepKernel( - const KeyT *d_keys, ///< [in] Input keys buffer + KeyInputIteratorT d_keys, ///< [in] Input keys iterator OffsetT *d_spine, ///< [out] Privatized (per block) digit histograms (striped, i.e., 0s counts from each block, then 1s counts from each block, etc.) OffsetT /*num_items*/, ///< [in] Total number of input data items int current_bit, ///< [in] Bit position of current radix digit int num_bits, ///< [in] Number of bits of current radix digit GridEvenShare even_share) ///< [in] Even-share descriptor for mapan equal number of tiles onto each thread block { + using KeyT = typename std::iterator_traits::value_type; typedef typename If< (ALT_DIGIT_BITS), typename ChainedPolicyT::ActivePolicy::AltUpsweepPolicy, @@ -180,23 +181,27 @@ template < typename ChainedPolicyT, ///< Chained tuning policy bool ALT_DIGIT_BITS, ///< Whether or not to use the alternate (lower-bits) policy bool IS_DESCENDING, ///< Whether or not the sorted-order is high-to-low - typename KeyT, ///< Key type - typename ValueT, ///< Value type + typename KeyInputIteratorT, ///< Input key iterator type + typename KeyIteratorT, ///< Key iterator type + typename ValueInputIteratorT, ///< Input value iterator type + typename ValueIteratorT, ///< Value iterator type typename OffsetT> ///< Signed integer type for global offsets __launch_bounds__ (int((ALT_DIGIT_BITS) ? int(ChainedPolicyT::ActivePolicy::AltDownsweepPolicy::BLOCK_THREADS) : int(ChainedPolicyT::ActivePolicy::DownsweepPolicy::BLOCK_THREADS))) __global__ void DeviceRadixSortDownsweepKernel( - const KeyT *d_keys_in, ///< [in] Input keys buffer - KeyT *d_keys_out, ///< [in] Output keys buffer - const ValueT *d_values_in, ///< [in] Input values buffer - ValueT *d_values_out, ///< [in] Output values buffer + KeyInputIteratorT d_keys_in, ///< [in] Input keys iterator + KeyIteratorT d_keys_out, ///< [in] Output keys iterator + ValueInputIteratorT d_values_in, ///< [in] Input values iterator + ValueIteratorT d_values_out, ///< [in] Output values iterator OffsetT *d_spine, ///< [in] Scan of privatized (per block) digit histograms (striped, i.e., 0s counts from each block, then 1s counts from each block, etc.) OffsetT num_items, ///< [in] Total number of input data items int current_bit, ///< [in] Bit position of current radix digit int num_bits, ///< [in] Number of bits of current radix digit GridEvenShare even_share) ///< [in] Even-share descriptor for mapan equal number of tiles onto each thread block { + using KeyT = typename std::iterator_traits::value_type; + using ValueT = typename std::iterator_traits::value_type; typedef typename If< (ALT_DIGIT_BITS), typename ChainedPolicyT::ActivePolicy::AltUpsweepPolicy, @@ -243,19 +248,23 @@ __global__ void DeviceRadixSortDownsweepKernel( template < typename ChainedPolicyT, ///< Chained tuning policy bool IS_DESCENDING, ///< Whether or not the sorted-order is high-to-low - typename KeyT, ///< Key type - typename ValueT, ///< Value type + typename KeyInputIteratorT, ///< Input key iterator type + typename KeyIteratorT, ///< Key iterator type + typename ValueInputIteratorT, ///< Input value iterator type + typename ValueIteratorT, ///< Value iterator type typename OffsetT> ///< Signed integer type for global offsets __launch_bounds__ (int(ChainedPolicyT::ActivePolicy::SingleTilePolicy::BLOCK_THREADS), 1) __global__ void DeviceRadixSortSingleTileKernel( - const KeyT *d_keys_in, ///< [in] Input keys buffer - KeyT *d_keys_out, ///< [in] Output keys buffer - const ValueT *d_values_in, ///< [in] Input values buffer - ValueT *d_values_out, ///< [in] Output values buffer + KeyInputIteratorT d_keys_in, ///< [in] Input keys iterator + KeyIteratorT d_keys_out, ///< [in] Output keys iterator + ValueInputIteratorT d_values_in, ///< [in] Input values iterator + ValueIteratorT d_values_out, ///< [in] Output values iterator OffsetT num_items, ///< [in] Total number of input data items int current_bit, ///< [in] Bit position of current radix digit int end_bit) ///< [in] The past-the-end (most-significant) bit index needed for key comparison { + using KeyT = typename std::iterator_traits::value_type; + using ValueT = typename std::iterator_traits::value_type; // Constants enum { @@ -357,8 +366,10 @@ template < typename ChainedPolicyT, ///< Chained tuning policy bool ALT_DIGIT_BITS, ///< Whether or not to use the alternate (lower-bits) policy bool IS_DESCENDING, ///< Whether or not the sorted-order is high-to-low - typename KeyT, ///< Key type - typename ValueT, ///< Value type + typename KeyInputIteratorT, ///< Input key iterator type + typename KeyIteratorT, ///< Key iterator type + typename ValueInputIteratorT, ///< Input value iterator type + typename ValueIteratorT, ///< Value iterator type typename BeginOffsetIteratorT, ///< Random-access input iterator type for reading segment beginning offsets \iterator typename EndOffsetIteratorT, ///< Random-access input iterator type for reading segment ending offsets \iterator typename OffsetT> ///< Signed integer type for global offsets @@ -366,16 +377,18 @@ __launch_bounds__ (int((ALT_DIGIT_BITS) ? ChainedPolicyT::ActivePolicy::AltSegmentedPolicy::BLOCK_THREADS : ChainedPolicyT::ActivePolicy::SegmentedPolicy::BLOCK_THREADS)) __global__ void DeviceSegmentedRadixSortKernel( - const KeyT *d_keys_in, ///< [in] Input keys buffer - KeyT *d_keys_out, ///< [in] Output keys buffer - const ValueT *d_values_in, ///< [in] Input values buffer - ValueT *d_values_out, ///< [in] Output values buffer + KeyInputIteratorT d_keys_in, ///< [in] Input keys iterator + KeyIteratorT d_keys_out, ///< [in] Output keys iterator + ValueInputIteratorT d_values_in, ///< [in] Input values iterator + ValueIteratorT d_values_out, ///< [in] Output values iterator BeginOffsetIteratorT d_begin_offsets, ///< [in] Random-access input iterator to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_* EndOffsetIteratorT d_end_offsets, ///< [in] Random-access input iterator to the sequence of ending offsets of length \p num_segments, such that d_end_offsets[i]-1 is the last element of the ith data segment in d_keys_* and d_values_*. If d_end_offsets[i]-1 <= d_begin_offsets[i], the ith is considered empty. int /*num_segments*/, ///< [in] The number of segments that comprise the sorting data int current_bit, ///< [in] Bit position of current radix digit int pass_bits) ///< [in] Number of bits of current radix digit { + using KeyT = typename std::iterator_traits::value_type; + using ValueT = typename std::iterator_traits::value_type; // // Constants // @@ -527,12 +540,13 @@ __global__ void DeviceSegmentedRadixSortKernel( template < typename ChainedPolicyT, bool IS_DESCENDING, - typename KeyT, + typename KeyInputIteratorT, typename OffsetT> __global__ void __launch_bounds__(ChainedPolicyT::ActivePolicy::HistogramPolicy::BLOCK_THREADS) DeviceRadixSortHistogramKernel - (OffsetT* d_bins_out, const KeyT* d_keys_in, OffsetT num_items, int start_bit, int end_bit) + (OffsetT* d_bins_out, KeyInputIteratorT d_keys_in, OffsetT num_items, int start_bit, int end_bit) { + using KeyT = typename std::iterator_traits::value_type; typedef typename ChainedPolicyT::ActivePolicy::HistogramPolicy HistogramPolicyT; typedef AgentRadixSortHistogram AgentT; __shared__ typename AgentT::TempStorage temp_storage; @@ -543,16 +557,20 @@ DeviceRadixSortHistogramKernel template < typename ChainedPolicyT, bool IS_DESCENDING, - typename KeyT, - typename ValueT, + typename KeyInputIteratorT, + typename KeyIteratorT, + typename ValueInputIteratorT, + typename ValueIteratorT, typename OffsetT, typename AtomicOffsetT = OffsetT> __global__ void __launch_bounds__(ChainedPolicyT::ActivePolicy::OnesweepPolicy::BLOCK_THREADS) DeviceRadixSortOnesweepKernel (AtomicOffsetT* d_lookback, AtomicOffsetT* d_ctrs, OffsetT* d_bins_out, - const OffsetT* d_bins_in, KeyT* d_keys_out, const KeyT* d_keys_in, ValueT* d_values_out, - const ValueT* d_values_in, OffsetT num_items, int current_bit, int num_bits) + const OffsetT* d_bins_in, KeyIteratorT d_keys_out, KeyInputIteratorT d_keys_in, ValueIteratorT d_values_out, + ValueInputIteratorT d_values_in, OffsetT num_items, int current_bit, int num_bits) { + using KeyT = typename std::iterator_traits::value_type; + using ValueT = typename std::iterator_traits::value_type; typedef typename ChainedPolicyT::ActivePolicy::OnesweepPolicy OnesweepPolicyT; typedef AgentRadixSortOnesweep AgentT; __shared__ typename AgentT::TempStorage s; @@ -954,14 +972,21 @@ struct DeviceRadixSortPolicy * Utility class for dispatching the appropriately-tuned kernels for device-wide radix sort */ template < - bool IS_DESCENDING, ///< Whether or not the sorted-order is high-to-low - typename KeyT, ///< Key type - typename ValueT, ///< Value type - typename OffsetT, ///< Signed integer type for global offsets - typename SelectedPolicy = DeviceRadixSortPolicy > + bool IS_DESCENDING, ///< Whether or not the sorted-order is high-to-low + typename KeyInputIteratorT, ///< Input key iterator type + typename KeyIteratorT, ///< Key iterator type + typename ValueInputIteratorT, ///< Input value iterator type + typename ValueIteratorT, ///< Value iterator type + typename OffsetT, ///< Signed integer type for global offsets + typename SelectedPolicy = DeviceRadixSortPolicy< + /*KeyT=*/ typename std::iterator_traits::value_type, + /*ValueT=*/ typename std::iterator_traits::value_type, + OffsetT> > struct DispatchRadixSort : SelectedPolicy { + using KeyT = typename std::iterator_traits::value_type; + using ValueT = typename std::iterator_traits::value_type; //------------------------------------------------------------------------------ // Constants //------------------------------------------------------------------------------ @@ -979,8 +1004,10 @@ struct DispatchRadixSort : void *d_temp_storage; ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. size_t &temp_storage_bytes; ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation - DoubleBuffer &d_keys; ///< [in,out] Double-buffer whose current buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys - DoubleBuffer &d_values; ///< [in,out] Double-buffer whose current buffer contains the unsorted input values and, upon return, is updated to point to the sorted output values + KeyInputIteratorT d_keys_in; ///< [in] Iterator for the unsorted input keys + KeyIteratorT d_keys_out; ///< [out] Iterator for the sorted output keys + ValueInputIteratorT d_values_in; ///< [in] Iterator for the unsorted input values + ValueIteratorT d_values_out; ///< [out] Iterator for the sorted output values OffsetT num_items; ///< [in] Number of items to sort int begin_bit; ///< [in] The beginning (least-significant) bit index needed for key comparison int end_bit; ///< [in] The past-the-end (most-significant) bit index needed for key comparison @@ -999,8 +1026,10 @@ struct DispatchRadixSort : DispatchRadixSort( void* d_temp_storage, size_t &temp_storage_bytes, - DoubleBuffer &d_keys, - DoubleBuffer &d_values, + KeyInputIteratorT d_keys_in, + KeyIteratorT d_keys_out, + ValueInputIteratorT d_values_in, + ValueIteratorT d_values_out, OffsetT num_items, int begin_bit, int end_bit, @@ -1011,8 +1040,10 @@ struct DispatchRadixSort : : d_temp_storage(d_temp_storage), temp_storage_bytes(temp_storage_bytes), - d_keys(d_keys), - d_values(d_values), + d_keys_in(d_keys_in), + d_keys_out(d_keys_out), + d_values_in(d_values_in), + d_values_out(d_values_out), num_items(num_items), begin_bit(begin_bit), end_bit(end_bit), @@ -1100,10 +1131,10 @@ struct DispatchRadixSort : template CUB_RUNTIME_FUNCTION __forceinline__ cudaError_t InvokePass( - const KeyT *d_keys_in, - KeyT *d_keys_out, - const ValueT *d_values_in, - ValueT *d_values_out, + KeyInputIteratorT d_keys_in, + KeyIteratorT d_keys_out, + ValueInputIteratorT d_values_in, + ValueIteratorT d_values_out, OffsetT *d_spine, int /*spine_length*/, int ¤t_bit, @@ -1353,7 +1384,8 @@ struct DispatchRadixSort : d_lookback, 0, num_blocks * RADIX_DIGITS * sizeof(AtomicOffsetT), stream))) break; auto onesweep_kernel = DeviceRadixSortOnesweepKernel< - MaxPolicyT, IS_DESCENDING, KeyT, ValueT, OffsetT>; + MaxPolicyT, IS_DESCENDING, KeyInputIteratorT, KeyIteratorT, + ValueInputIteratorT, ValueIteratorT, OffsetT>; onesweep_kernel<<>> (d_lookback, d_ctrs + part * num_passes + pass, part < num_parts - 1 ? @@ -1538,11 +1570,11 @@ struct DispatchRadixSort : // Invoke upsweep-downsweep typedef typename DispatchRadixSort::MaxPolicy MaxPolicyT; return InvokePasses( - DeviceRadixSortUpsweepKernel< MaxPolicyT, false, IS_DESCENDING, KeyT, OffsetT>, - DeviceRadixSortUpsweepKernel< MaxPolicyT, true, IS_DESCENDING, KeyT, OffsetT>, + DeviceRadixSortUpsweepKernel< MaxPolicyT, false, IS_DESCENDING, KeyInputIteratorT, OffsetT>, + DeviceRadixSortUpsweepKernel< MaxPolicyT, true, IS_DESCENDING, KeyInputIteratorT, OffsetT>, RadixSortScanBinsKernel< MaxPolicyT, OffsetT>, - DeviceRadixSortDownsweepKernel< MaxPolicyT, false, IS_DESCENDING, KeyT, ValueT, OffsetT>, - DeviceRadixSortDownsweepKernel< MaxPolicyT, true, IS_DESCENDING, KeyT, ValueT, OffsetT>); + DeviceRadixSortDownsweepKernel< MaxPolicyT, false, IS_DESCENDING, KeyInputIteratorT, KeyIteratorT, ValueInputIteratorT, ValueIteratorT, OffsetT>, + DeviceRadixSortDownsweepKernel< MaxPolicyT, true, IS_DESCENDING, KeyInputIteratorT, KeyIteratorT, ValueInputIteratorT, ValueIteratorT, OffsetT>); } template @@ -1566,7 +1598,8 @@ struct DispatchRadixSort : { // Small, single tile size return InvokeSingleTile( - DeviceRadixSortSingleTileKernel); + DeviceRadixSortSingleTileKernel); } else { @@ -1587,8 +1620,10 @@ struct DispatchRadixSort : static cudaError_t Dispatch( void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation - DoubleBuffer &d_keys, ///< [in,out] Double-buffer whose current buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys - DoubleBuffer &d_values, ///< [in,out] Double-buffer whose current buffer contains the unsorted input values and, upon return, is updated to point to the sorted output values + KeyInputIteratorT d_keys_in, ///< [in] Iterator for the unsorted input keys + KeyIteratorT d_keys_out, ///< [out] Iterator for the sorted output keys + ValueInputIteratorT d_values_in, ///< [in] Iterator for the unsorted input values + ValueIteratorT d_values_out, ///< [out] Iterator for the sorted output values OffsetT num_items, ///< [in] Number of items to sort int begin_bit, ///< [in] The beginning (least-significant) bit index needed for key comparison int end_bit, ///< [in] The past-the-end (most-significant) bit index needed for key comparison @@ -1607,7 +1642,7 @@ struct DispatchRadixSort : // Create dispatch functor DispatchRadixSort dispatch( d_temp_storage, temp_storage_bytes, - d_keys, d_values, + d_keys_in, d_keys_out, d_values_in, d_values_out, num_items, begin_bit, end_bit, is_overwrite_okay, stream, debug_synchronous, ptx_version); @@ -1632,15 +1667,22 @@ struct DispatchRadixSort : */ template < bool IS_DESCENDING, ///< Whether or not the sorted-order is high-to-low - typename KeyT, ///< Key type - typename ValueT, ///< Value type - typename BeginOffsetIteratorT, ///< Random-access input iterator type for reading segment beginning offsets \iterator + typename KeyInputIteratorT, ///< Input key iterator type + typename KeyIteratorT, ///< Key iterator type + typename ValueInputIteratorT, ///< Input value iterator type + typename ValueIteratorT, ///< Value iterator type + typename BeginOffsetIteratorT, ///< Random-access input iterator type for reading segment beginning offsets \iterator typename EndOffsetIteratorT, ///< Random-access input iterator type for reading segment ending offsets \iterator typename OffsetT, ///< Signed integer type for global offsets - typename SelectedPolicy = DeviceRadixSortPolicy > + typename SelectedPolicy = DeviceRadixSortPolicy< + /*KeyT=*/typename std::iterator_traits::value_type, + /*ValueT=*/typename std::iterator_traits::value_type, + OffsetT> > struct DispatchSegmentedRadixSort : SelectedPolicy { + using KeyT = typename std::iterator_traits::value_type; + using ValueT = typename std::iterator_traits::value_type; //------------------------------------------------------------------------------ // Constants //------------------------------------------------------------------------------ @@ -1658,8 +1700,10 @@ struct DispatchSegmentedRadixSort : void *d_temp_storage; ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. size_t &temp_storage_bytes; ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation - DoubleBuffer &d_keys; ///< [in,out] Double-buffer whose current buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys - DoubleBuffer &d_values; ///< [in,out] Double-buffer whose current buffer contains the unsorted input values and, upon return, is updated to point to the sorted output values + KeyInputIteratorT d_keys_in; ///< [in] Iterator for the unsorted input keys + KeyIteratorT d_keys_out; ///< [out] Iterator for the sorted output keys + ValueInputIteratorT d_values_in; ///< [in] Iterator for the unsorted input values + ValueIteratorT d_values_out; ///< [out] Iterator for the sorted output values OffsetT num_items; ///< [in] Number of items to sort OffsetT num_segments; ///< [in] The number of segments that comprise the sorting data BeginOffsetIteratorT d_begin_offsets; ///< [in] Random-access input iterator to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_* @@ -1681,8 +1725,10 @@ struct DispatchSegmentedRadixSort : DispatchSegmentedRadixSort( void* d_temp_storage, size_t &temp_storage_bytes, - DoubleBuffer &d_keys, - DoubleBuffer &d_values, + KeyInputIteratorT d_keys_in, + KeyIteratorT d_keys_out, + ValueInputIteratorT d_values_in, + ValueIteratorT d_values_out, OffsetT num_items, OffsetT num_segments, BeginOffsetIteratorT d_begin_offsets, @@ -1696,8 +1742,10 @@ struct DispatchSegmentedRadixSort : : d_temp_storage(d_temp_storage), temp_storage_bytes(temp_storage_bytes), - d_keys(d_keys), - d_values(d_values), + d_keys_in(d_keys_in), + d_keys_out(d_keys_out), + d_values_in(d_values_in), + d_values_out(d_values_out), num_items(num_items), num_segments(num_segments), d_begin_offsets(d_begin_offsets), @@ -1719,12 +1767,12 @@ struct DispatchSegmentedRadixSort : template CUB_RUNTIME_FUNCTION __forceinline__ cudaError_t InvokePass( - const KeyT *d_keys_in, - KeyT *d_keys_out, - const ValueT *d_values_in, - ValueT *d_values_out, - int ¤t_bit, - PassConfigT &pass_config) + KeyInputIteratorT d_keys_in, + KeyIteratorT d_keys_out, + ValueInputIteratorT d_values_in, + ValueIteratorT d_values_out, + int ¤t_bit, + PassConfigT &pass_config) { cudaError error = cudaSuccess; do @@ -1921,8 +1969,10 @@ struct DispatchSegmentedRadixSort : static cudaError_t Dispatch( void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation - DoubleBuffer &d_keys, ///< [in,out] Double-buffer whose current buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys - DoubleBuffer &d_values, ///< [in,out] Double-buffer whose current buffer contains the unsorted input values and, upon return, is updated to point to the sorted output values + KeyInputIteratorT d_keys_in, ///< [in] Iterator for the unsorted input keys + KeyIteratorT d_keys_out, ///< [out] Iterator for the sorted output keys + ValueInputIteratorT d_values_in, ///< [in] Iterator for the unsorted input values + ValueIteratorT d_values_out, ///< [out] Iterator for the sorted output values int num_items, ///< [in] Number of items to sort int num_segments, ///< [in] The number of segments that comprise the sorting data BeginOffsetIteratorT d_begin_offsets, ///< [in] Random-access input iterator to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_* From 967a5aebcec5861f973bb8ddcbf9f846e3f7256a Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 13 Sep 2021 11:23:22 -0700 Subject: [PATCH 2/8] save --- cub/device/device_segmented_radix_sort.cuh | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/cub/device/device_segmented_radix_sort.cuh b/cub/device/device_segmented_radix_sort.cuh index 87a86a3fd3..ab7271047a 100644 --- a/cub/device/device_segmented_radix_sort.cuh +++ b/cub/device/device_segmented_radix_sort.cuh @@ -147,16 +147,14 @@ struct DeviceSegmentedRadixSort BeginOffsetIteratorT d_begin_offsets, ///< [in] Random-access input iterator to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_* EndOffsetIteratorT d_end_offsets, ///< [in] Random-access input iterator to the sequence of ending offsets of length \p num_segments, such that d_end_offsets[i]-1 is the last element of the ith data segment in d_keys_* and d_values_*. If d_end_offsets[i]-1 <= d_begin_offsets[i], the ith is considered empty. int begin_bit = 0, ///< [in] [optional] The least-significant bit index (inclusive) needed for key comparison - int end_bit = sizeof(KeyT) * 8, ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8) + int end_bit ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8) + = sizeof(typename std::iterator_traits::value_type) * 8, cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. { // Signed integer type for global offsets typedef int OffsetT; - DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out); - DoubleBuffer d_values(const_cast(d_values_in), d_values_out); - return DispatchSegmentedRadixSort< false, KeyInputIteratorT, KeyIteratorT, ValueInputIteratorT, ValueIteratorT, BeginOffsetIteratorT, EndOffsetIteratorT, OffsetT >::Dispatch( @@ -363,7 +361,8 @@ struct DeviceSegmentedRadixSort BeginOffsetIteratorT d_begin_offsets, ///< [in] Random-access input iterator to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_* EndOffsetIteratorT d_end_offsets, ///< [in] Random-access input iterator to the sequence of ending offsets of length \p num_segments, such that d_end_offsets[i]-1 is the last element of the ith data segment in d_keys_* and d_values_*. If d_end_offsets[i]-1 <= d_begin_offsets[i], the ith is considered empty. int begin_bit = 0, ///< [in] [optional] The least-significant bit index (inclusive) needed for key comparison - int end_bit = sizeof(KeyT) * 8, ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8) + int end_bit ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8) + = sizeof(typename std::iterator_traits::value_type) * 8, cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. { @@ -572,7 +571,8 @@ struct DeviceSegmentedRadixSort BeginOffsetIteratorT d_begin_offsets, ///< [in] Random-access input iterator to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_* EndOffsetIteratorT d_end_offsets, ///< [in] Random-access input iterator to the sequence of ending offsets of length \p num_segments, such that d_end_offsets[i]-1 is the last element of the ith data segment in d_keys_* and d_values_*. If d_end_offsets[i]-1 <= d_begin_offsets[i], the ith is considered empty. int begin_bit = 0, ///< [in] [optional] The least-significant bit index (inclusive) needed for key comparison - int end_bit = sizeof(KeyT) * 8, ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8) + int end_bit ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8) + = sizeof(typename std::iterator_traits::value_type) * 8, cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. { @@ -580,7 +580,7 @@ struct DeviceSegmentedRadixSort typedef int OffsetT; return DispatchSegmentedRadixSort< - false, const KeyT *, KeyT *, const NullType *, NullType *, BeginOffsetIteratorT, EndOffsetIteratorT, OffsetT + false, KeyInputIteratorT, KeyIteratorT, const NullType *, NullType *, BeginOffsetIteratorT, EndOffsetIteratorT, OffsetT >::Dispatch( d_temp_storage, temp_storage_bytes, @@ -766,7 +766,8 @@ struct DeviceSegmentedRadixSort BeginOffsetIteratorT d_begin_offsets, ///< [in] Random-access input iterator to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_* EndOffsetIteratorT d_end_offsets, ///< [in] Random-access input iterator to the sequence of ending offsets of length \p num_segments, such that d_end_offsets[i]-1 is the last element of the ith data segment in d_keys_* and d_values_*. If d_end_offsets[i]-1 <= d_begin_offsets[i], the ith is considered empty. int begin_bit = 0, ///< [in] [optional] The least-significant bit index (inclusive) needed for key comparison - int end_bit = sizeof(KeyT) * 8, ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8) + int end_bit ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8) + = sizeof(typename std::iterator_traits::value_type) * 8, cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. { From ef97e6018c8b66aaadfdef61b25ecf413187ddc3 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 13 Sep 2021 16:01:33 -0700 Subject: [PATCH 3/8] fix --- cub/device/device_radix_sort.cuh | 6 ++++++ cub/device/device_segmented_radix_sort.cuh | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/cub/device/device_radix_sort.cuh b/cub/device/device_radix_sort.cuh index 0422b00ff6..3326a22902 100644 --- a/cub/device/device_radix_sort.cuh +++ b/cub/device/device_radix_sort.cuh @@ -343,6 +343,8 @@ struct DeviceRadixSort is_overwrite_okay, stream, debug_synchronous); + d_keys.selector ^= 1; + d_values.selector ^= 1; } @@ -548,6 +550,8 @@ struct DeviceRadixSort is_overwrite_okay, stream, debug_synchronous); + d_keys.selector ^= 1; + d_values.selector ^= 1; } @@ -742,6 +746,7 @@ struct DeviceRadixSort is_overwrite_okay, stream, debug_synchronous); + d_keys.selector ^= 1; } /** @@ -923,6 +928,7 @@ struct DeviceRadixSort is_overwrite_okay, stream, debug_synchronous); + d_keys.selector ^= 1; } diff --git a/cub/device/device_segmented_radix_sort.cuh b/cub/device/device_segmented_radix_sort.cuh index ab7271047a..3682709b87 100644 --- a/cub/device/device_segmented_radix_sort.cuh +++ b/cub/device/device_segmented_radix_sort.cuh @@ -283,6 +283,8 @@ struct DeviceSegmentedRadixSort true, stream, debug_synchronous); + d_keys.selector ^= 1; + d_values.selector ^= 1; } @@ -497,6 +499,8 @@ struct DeviceSegmentedRadixSort true, stream, debug_synchronous); + d_keys.selector ^= 1; + d_values.selector ^= 1; } @@ -697,6 +701,7 @@ struct DeviceSegmentedRadixSort true, stream, debug_synchronous); + d_keys.selector ^= 1; } /** @@ -892,6 +897,7 @@ struct DeviceSegmentedRadixSort true, stream, debug_synchronous); + d_keys.selector ^= 1; } From aa40169ebf9aa388cac64cec37ff7b5fa820b614 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 13 Sep 2021 16:32:31 -0700 Subject: [PATCH 4/8] fix --- cub/device/dispatch/dispatch_radix_sort.cuh | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/cub/device/dispatch/dispatch_radix_sort.cuh b/cub/device/dispatch/dispatch_radix_sort.cuh index 62a64c6cda..7b0b724ae1 100644 --- a/cub/device/dispatch/dispatch_radix_sort.cuh +++ b/cub/device/dispatch/dispatch_radix_sort.cuh @@ -1095,10 +1095,10 @@ struct DispatchRadixSort : THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron( 1, ActivePolicyT::SingleTilePolicy::BLOCK_THREADS, 0, stream ).doit(single_tile_kernel, - d_keys.Current(), - d_keys.Alternate(), - d_values.Current(), - d_values.Alternate(), + d_keys_in, + d_keys_out, + d_values_in, + d_values_out, num_items, begin_bit, end_bit); @@ -1108,10 +1108,6 @@ struct DispatchRadixSort : // Sync the stream if specified to flush runtime errors if (debug_synchronous && (CubDebug(error = SyncStream(stream)))) break; - - // Update selector - d_keys.selector ^= 1; - d_values.selector ^= 1; } while (0); @@ -1354,7 +1350,7 @@ struct DispatchRadixSort : if (CubDebug(error = cudaOccupancyMaxActiveBlocksPerMultiprocessor( &histo_blocks_per_sm, histogram_kernel, HISTO_BLOCK_THREADS, 0))) break; histogram_kernel<<>> - (d_bins, d_keys.Current(), num_items, begin_bit, end_bit); + (d_bins, d_keys_in, num_items, begin_bit, end_bit); if (CubDebug(error = cudaPeekAtLastError())) break; // exclusive sums to determine starts @@ -1364,8 +1360,8 @@ struct DispatchRadixSort : if (CubDebug(error = cudaPeekAtLastError())) break; // use the other buffer if no overwrite is allowed - KeyT* d_keys_tmp = d_keys.Alternate(); - ValueT* d_values_tmp = d_values.Alternate(); + KeyT* d_keys_tmp = d_keys_out; + ValueT* d_values_tmp = d_values_out; if (!is_overwrite_okay && num_passes % 2 == 0) { d_keys.d_buffers[1] = d_keys_tmp2; From a93034d3093c80df746130561b08ae766bcc465d Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 13 Sep 2021 17:03:21 -0700 Subject: [PATCH 5/8] no is_overwrite_okay --- cub/device/device_radix_sort.cuh | 40 --------------- cub/device/dispatch/dispatch_radix_sort.cuh | 54 ++++++++------------- 2 files changed, 21 insertions(+), 73 deletions(-) diff --git a/cub/device/device_radix_sort.cuh b/cub/device/device_radix_sort.cuh index 3326a22902..a2caf40cff 100644 --- a/cub/device/device_radix_sort.cuh +++ b/cub/device/device_radix_sort.cuh @@ -214,12 +214,6 @@ struct DeviceRadixSort // Signed integer type for global offsets typedef int OffsetT; - // We cast away const-ness, but will *not* write to these arrays. - // `DispatchRadixSort::Dispatch` will allocate temporary storage and - // create a new double-buffer internally when the `is_overwrite_ok` flag - // is not set. - constexpr bool is_overwrite_okay = false; - return DispatchRadixSort< false, KeyInputIteratorT, KeyIteratorT, ValueInputIteratorT, ValueIteratorT, OffsetT >::Dispatch( @@ -232,7 +226,6 @@ struct DeviceRadixSort num_items, begin_bit, end_bit, - is_overwrite_okay, stream, debug_synchronous); } @@ -326,8 +319,6 @@ struct DeviceRadixSort // Signed integer type for global offsets typedef int OffsetT; - constexpr bool is_overwrite_okay = true; - return DispatchRadixSort< false , const KeyT *, KeyT *, const ValueT *, ValueT *, OffsetT >::Dispatch( @@ -340,7 +331,6 @@ struct DeviceRadixSort num_items, begin_bit, end_bit, - is_overwrite_okay, stream, debug_synchronous); d_keys.selector ^= 1; @@ -426,12 +416,6 @@ struct DeviceRadixSort // Signed integer type for global offsets typedef int OffsetT; - // We cast away const-ness, but will *not* write to these arrays. - // `DispatchRadixSort::Dispatch` will allocate temporary storage and - // create a new double-buffer internally when the `is_overwrite_ok` flag - // is not set. - constexpr bool is_overwrite_okay = false; - return DispatchRadixSort< true, KeyInputIteratorT, KeyIteratorT, ValueInputIteratorT, ValueIteratorT, OffsetT >::Dispatch( @@ -444,7 +428,6 @@ struct DeviceRadixSort num_items, begin_bit, end_bit, - is_overwrite_okay, stream, debug_synchronous); } @@ -533,8 +516,6 @@ struct DeviceRadixSort // Signed integer type for global offsets typedef int OffsetT; - constexpr bool is_overwrite_okay = true; - return DispatchRadixSort< true, const KeyT *, KeyT *, const ValueT *, ValueT *, OffsetT >::Dispatch( @@ -547,7 +528,6 @@ struct DeviceRadixSort num_items, begin_bit, end_bit, - is_overwrite_okay, stream, debug_synchronous); d_keys.selector ^= 1; @@ -631,12 +611,6 @@ struct DeviceRadixSort // Signed integer type for global offsets typedef int OffsetT; - // We cast away const-ness, but will *not* write to these arrays. - // `DispatchRadixSort::Dispatch` will allocate temporary storage and - // create a new double-buffer internally when the `is_overwrite_ok` flag - // is not set. - constexpr bool is_overwrite_okay = false; - return DispatchRadixSort< false, KeyInputIteratorT, KeyIteratorT, const NullType *, NullType *, OffsetT >::Dispatch( @@ -649,7 +623,6 @@ struct DeviceRadixSort num_items, begin_bit, end_bit, - is_overwrite_okay, stream, debug_synchronous); } @@ -729,8 +702,6 @@ struct DeviceRadixSort // Signed integer type for global offsets typedef int OffsetT; - constexpr bool is_overwrite_okay = true; - return DispatchRadixSort< false, const KeyT *, KeyT *, const NullType *, NullType *, OffsetT >::Dispatch( @@ -743,7 +714,6 @@ struct DeviceRadixSort num_items, begin_bit, end_bit, - is_overwrite_okay, stream, debug_synchronous); d_keys.selector ^= 1; @@ -817,12 +787,6 @@ struct DeviceRadixSort // Signed integer type for global offsets typedef int OffsetT; - // We cast away const-ness, but will *not* write to these arrays. - // `DispatchRadixSort::Dispatch` will allocate temporary storage and - // create a new double-buffer internally when the `is_overwrite_ok` flag - // is not set. - constexpr bool is_overwrite_okay = false; - return DispatchRadixSort< true, KeyInputIteratorT, KeyIteratorT, const NullType *, NullType *, OffsetT >::Dispatch( @@ -835,7 +799,6 @@ struct DeviceRadixSort num_items, begin_bit, end_bit, - is_overwrite_okay, stream, debug_synchronous); } @@ -911,8 +874,6 @@ struct DeviceRadixSort // Signed integer type for global offsets typedef int OffsetT; - constexpr bool is_overwrite_okay = true; - return DispatchRadixSort< true, const KeyT *, KeyT *, const NullType *, NullType *, OffsetT >::Dispatch( @@ -925,7 +886,6 @@ struct DeviceRadixSort num_items, begin_bit, end_bit, - is_overwrite_okay, stream, debug_synchronous); d_keys.selector ^= 1; diff --git a/cub/device/dispatch/dispatch_radix_sort.cuh b/cub/device/dispatch/dispatch_radix_sort.cuh index 7b0b724ae1..fcece392ce 100644 --- a/cub/device/dispatch/dispatch_radix_sort.cuh +++ b/cub/device/dispatch/dispatch_radix_sort.cuh @@ -1014,7 +1014,6 @@ struct DispatchRadixSort : cudaStream_t stream; ///< [in] CUDA stream to launch kernels within. Default is stream0. bool debug_synchronous; ///< [in] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. int ptx_version; ///< [in] PTX version - bool is_overwrite_okay; ///< [in] Whether is okay to overwrite source buffers //------------------------------------------------------------------------------ @@ -1033,7 +1032,6 @@ struct DispatchRadixSort : OffsetT num_items, int begin_bit, int end_bit, - bool is_overwrite_okay, cudaStream_t stream, bool debug_synchronous, int ptx_version) @@ -1049,8 +1047,7 @@ struct DispatchRadixSort : end_bit(end_bit), stream(stream), debug_synchronous(debug_synchronous), - ptx_version(ptx_version), - is_overwrite_okay(is_overwrite_okay) + ptx_version(ptx_version) {} @@ -1308,9 +1305,9 @@ struct DispatchRadixSort : // lookback max_num_blocks * RADIX_DIGITS * sizeof(AtomicOffsetT), // extra key buffer - is_overwrite_okay || num_passes <= 1 ? 0 : num_items * sizeof(KeyT), + num_passes <= 1 ? 0 : num_items * sizeof(KeyT), // extra value buffer - is_overwrite_okay || num_passes <= 1 ? 0 : num_items * value_size, + num_passes <= 1 ? 0 : num_items * value_size, // counters num_parts * num_passes * sizeof(AtomicOffsetT), }; @@ -1362,7 +1359,7 @@ struct DispatchRadixSort : // use the other buffer if no overwrite is allowed KeyT* d_keys_tmp = d_keys_out; ValueT* d_values_tmp = d_values_out; - if (!is_overwrite_okay && num_passes % 2 == 0) + if (num_passes % 2 == 0) { d_keys.d_buffers[1] = d_keys_tmp2; d_values.d_buffers[1] = d_values_tmp2; @@ -1396,7 +1393,7 @@ struct DispatchRadixSort : } // use the temporary buffers if no overwrite is allowed - if (!is_overwrite_okay && pass == 0) + if (pass == 0) { d_keys = num_passes % 2 == 0 ? DoubleBuffer(d_keys_tmp, d_keys_tmp2) : @@ -1488,8 +1485,8 @@ struct DispatchRadixSort : size_t allocation_sizes[3] = { spine_length * sizeof(OffsetT), // bytes needed for privatized block digit histograms - (is_overwrite_okay) ? 0 : num_items * sizeof(KeyT), // bytes needed for 3rd keys buffer - (is_overwrite_okay || (KEYS_ONLY)) ? 0 : num_items * sizeof(ValueT), // bytes needed for 3rd values buffer + num_items * sizeof(KeyT), // bytes needed for 3rd keys buffer + (KEYS_ONLY) ? 0 : num_items * sizeof(ValueT), // bytes needed for 3rd values buffer }; // Alias the temporary allocations from the single storage blob (or compute the necessary size of the blob) @@ -1510,12 +1507,12 @@ struct DispatchRadixSort : OffsetT *d_spine = static_cast(allocations[0]); DoubleBuffer d_keys_remaining_passes( - (is_overwrite_okay || is_num_passes_odd) ? d_keys.Alternate() : static_cast(allocations[1]), - (is_overwrite_okay) ? d_keys.Current() : (is_num_passes_odd) ? static_cast(allocations[1]) : d_keys.Alternate()); + is_num_passes_odd ? d_keys.Alternate() : static_cast(allocations[1]), + is_num_passes_odd ? static_cast(allocations[1]) : d_keys.Alternate()); DoubleBuffer d_values_remaining_passes( - (is_overwrite_okay || is_num_passes_odd) ? d_values.Alternate() : static_cast(allocations[2]), - (is_overwrite_okay) ? d_values.Current() : (is_num_passes_odd) ? static_cast(allocations[2]) : d_values.Alternate()); + is_num_passes_odd ? d_values.Alternate() : static_cast(allocations[2]), + is_num_passes_odd ? static_cast(allocations[2]) : d_values.Alternate()); // Run first pass, consuming from the input's current buffers int current_bit = begin_bit; @@ -1540,9 +1537,7 @@ struct DispatchRadixSort : } // Update selector - if (!is_overwrite_okay) { - num_passes = 1; // Sorted data always ends up in the other vector - } + num_passes = 1; // Sorted data always ends up in the other vector d_keys.selector = (d_keys.selector + num_passes) & 1; d_values.selector = (d_values.selector + num_passes) & 1; @@ -1623,7 +1618,6 @@ struct DispatchRadixSort : OffsetT num_items, ///< [in] Number of items to sort int begin_bit, ///< [in] The beginning (least-significant) bit index needed for key comparison int end_bit, ///< [in] The past-the-end (most-significant) bit index needed for key comparison - bool is_overwrite_okay, ///< [in] Whether is okay to overwrite source buffers cudaStream_t stream, ///< [in] CUDA stream to launch kernels within. Default is stream0. bool debug_synchronous) ///< [in] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. { @@ -1639,7 +1633,7 @@ struct DispatchRadixSort : DispatchRadixSort dispatch( d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, d_values_in, d_values_out, - num_items, begin_bit, end_bit, is_overwrite_okay, + num_items, begin_bit, end_bit, stream, debug_synchronous, ptx_version); // Dispatch to chained policy @@ -1709,7 +1703,6 @@ struct DispatchSegmentedRadixSort : cudaStream_t stream; ///< [in] CUDA stream to launch kernels within. Default is stream0. bool debug_synchronous; ///< [in] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. int ptx_version; ///< [in] PTX version - bool is_overwrite_okay; ///< [in] Whether is okay to overwrite source buffers //------------------------------------------------------------------------------ @@ -1731,7 +1724,6 @@ struct DispatchSegmentedRadixSort : EndOffsetIteratorT d_end_offsets, int begin_bit, int end_bit, - bool is_overwrite_okay, cudaStream_t stream, bool debug_synchronous, int ptx_version) @@ -1748,7 +1740,6 @@ struct DispatchSegmentedRadixSort : d_end_offsets(d_end_offsets), begin_bit(begin_bit), end_bit(end_bit), - is_overwrite_okay(is_overwrite_okay), stream(stream), debug_synchronous(debug_synchronous), ptx_version(ptx_version) @@ -1866,8 +1857,8 @@ struct DispatchSegmentedRadixSort : void* allocations[2] = {}; size_t allocation_sizes[2] = { - (is_overwrite_okay) ? 0 : num_items * sizeof(KeyT), // bytes needed for 3rd keys buffer - (is_overwrite_okay || (KEYS_ONLY)) ? 0 : num_items * sizeof(ValueT), // bytes needed for 3rd values buffer + num_items * sizeof(KeyT), // bytes needed for 3rd keys buffer + (KEYS_ONLY) ? 0 : num_items * sizeof(ValueT), // bytes needed for 3rd values buffer }; // Alias the temporary allocations from the single storage blob (or compute the necessary size of the blob) @@ -1891,12 +1882,12 @@ struct DispatchSegmentedRadixSort : int alt_end_bit = CUB_MIN(end_bit, begin_bit + (max_alt_passes * alt_radix_bits)); DoubleBuffer d_keys_remaining_passes( - (is_overwrite_okay || is_num_passes_odd) ? d_keys.Alternate() : static_cast(allocations[0]), - (is_overwrite_okay) ? d_keys.Current() : (is_num_passes_odd) ? static_cast(allocations[0]) : d_keys.Alternate()); + is_num_passes_odd ? d_keys.Alternate() : static_cast(allocations[0]), + is_num_passes_odd ? static_cast(allocations[0]) : d_keys.Alternate()); DoubleBuffer d_values_remaining_passes( - (is_overwrite_okay || is_num_passes_odd) ? d_values.Alternate() : static_cast(allocations[1]), - (is_overwrite_okay) ? d_values.Current() : (is_num_passes_odd) ? static_cast(allocations[1]) : d_values.Alternate()); + is_num_passes_odd ? d_values.Alternate() : static_cast(allocations[1]), + is_num_passes_odd ? static_cast(allocations[1]) : d_values.Alternate()); // Run first pass, consuming from the input's current buffers int current_bit = begin_bit; @@ -1922,9 +1913,7 @@ struct DispatchSegmentedRadixSort : } // Update selector - if (!is_overwrite_okay) { - num_passes = 1; // Sorted data always ends up in the other vector - } + num_passes = 1; // Sorted data always ends up in the other vector d_keys.selector = (d_keys.selector + num_passes) & 1; d_values.selector = (d_values.selector + num_passes) & 1; @@ -1975,7 +1964,6 @@ struct DispatchSegmentedRadixSort : EndOffsetIteratorT d_end_offsets, ///< [in] Random-access input iterator to the sequence of ending offsets of length \p num_segments, such that d_end_offsets[i]-1 is the last element of the ith data segment in d_keys_* and d_values_*. If d_end_offsets[i]-1 <= d_begin_offsets[i], the ith is considered empty. int begin_bit, ///< [in] The beginning (least-significant) bit index needed for key comparison int end_bit, ///< [in] The past-the-end (most-significant) bit index needed for key comparison - bool is_overwrite_okay, ///< [in] Whether is okay to overwrite source buffers cudaStream_t stream, ///< [in] CUDA stream to launch kernels within. Default is stream0. bool debug_synchronous) ///< [in] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. { @@ -1992,7 +1980,7 @@ struct DispatchSegmentedRadixSort : d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items, num_segments, d_begin_offsets, d_end_offsets, - begin_bit, end_bit, is_overwrite_okay, + begin_bit, end_bit, stream, debug_synchronous, ptx_version); // Dispatch to chained policy From fd7dd1beca5d57986471c2b1bf709a51480caa9c Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 13 Sep 2021 17:36:08 -0700 Subject: [PATCH 6/8] save --- cub/device/dispatch/dispatch_radix_sort.cuh | 183 ++++++++++++++++---- 1 file changed, 146 insertions(+), 37 deletions(-) diff --git a/cub/device/dispatch/dispatch_radix_sort.cuh b/cub/device/dispatch/dispatch_radix_sort.cuh index fcece392ce..438047503c 100644 --- a/cub/device/dispatch/dispatch_radix_sort.cuh +++ b/cub/device/dispatch/dispatch_radix_sort.cuh @@ -1322,8 +1322,8 @@ struct DispatchRadixSort : OffsetT* d_bins = (OffsetT*)allocations[0]; AtomicOffsetT* d_lookback = (AtomicOffsetT*)allocations[1]; - KeyT* d_keys_tmp2 = (KeyT*)allocations[2]; - ValueT* d_values_tmp2 = (ValueT*)allocations[3]; + KeyT* d_keys_tmp = (KeyT*)allocations[2]; + ValueT* d_values_tmp = (ValueT*)allocations[3]; AtomicOffsetT* d_ctrs = (AtomicOffsetT*)allocations[4]; do { @@ -1355,15 +1355,13 @@ struct DispatchRadixSort : DeviceRadixSortExclusiveSumKernel <<>>(d_bins); if (CubDebug(error = cudaPeekAtLastError())) break; - - // use the other buffer if no overwrite is allowed - KeyT* d_keys_tmp = d_keys_out; - ValueT* d_values_tmp = d_values_out; - if (num_passes % 2 == 0) - { - d_keys.d_buffers[1] = d_keys_tmp2; - d_values.d_buffers[1] = d_values_tmp2; - } + + bool output_is_tmp = (num_passes % 2 == 0); + enum InputMode { + INPUT, + TMP_STORAGE, + OUTPUT + } input_mode = INPUT; for (int current_bit = begin_bit, pass = 0; current_bit < end_bit; current_bit += RADIX_BITS, ++pass) @@ -1376,34 +1374,145 @@ struct DispatchRadixSort : if (CubDebug(error = cudaMemsetAsync( d_lookback, 0, num_blocks * RADIX_DIGITS * sizeof(AtomicOffsetT), stream))) break; - auto onesweep_kernel = DeviceRadixSortOnesweepKernel< - MaxPolicyT, IS_DESCENDING, KeyInputIteratorT, KeyIteratorT, - ValueInputIteratorT, ValueIteratorT, OffsetT>; - onesweep_kernel<<>> - (d_lookback, d_ctrs + part * num_passes + pass, - part < num_parts - 1 ? - d_bins + ((part + 1) * num_passes + pass) * RADIX_DIGITS : NULL, - d_bins + (part * num_passes + pass) * RADIX_DIGITS, - d_keys.Alternate(), - d_keys.Current() + part * PART_SIZE, - d_values.Alternate(), - d_values.Current() + part * PART_SIZE, - part_num_items, current_bit, num_bits); + if (output_is_tmp) { + using KeyOutIterT = KeyT *; + using ValueOutIterT = ValueT *; + KeyOutIterT d_keys_out_ = d_keys_tmp; + ValueOutIterT d_values_out_ = d_values_tmp; + switch (input_mode) { + case INPUT: { + using KeyInIterT = KeyInputIteratorT; + using ValueInIterT = ValueInputIteratorT; + auto onesweep_kernel = DeviceRadixSortOnesweepKernel< + MaxPolicyT, IS_DESCENDING, KeyInIterT, KeyOutIterT, + ValueInIterT, ValueOutIterT, OffsetT>; + KeyInIterT d_keys_in_ = d_keys_in; + ValueInIterT d_values_in_ = d_values_in; + onesweep_kernel<<>> + (d_lookback, d_ctrs + part * num_passes + pass, + part < num_parts - 1 ? + d_bins + ((part + 1) * num_passes + pass) * RADIX_DIGITS : NULL, + d_bins + (part * num_passes + pass) * RADIX_DIGITS, + d_keys_out_, + d_keys_in_ + part * PART_SIZE, + d_values_out_, + d_values_in_ + part * PART_SIZE, + part_num_items, current_bit, num_bits); + break; + } + case TMP_STORAGE: { + using KeyInIterT = KeyT *; + using ValueInIterT = ValueT *; + auto onesweep_kernel = DeviceRadixSortOnesweepKernel< + MaxPolicyT, IS_DESCENDING, KeyInIterT, KeyOutIterT, + ValueInIterT, ValueOutIterT, OffsetT>; + KeyInIterT d_keys_in_ = d_keys_tmp; + ValueInIterT d_values_in_ = d_values_tmp; + onesweep_kernel<<>> + (d_lookback, d_ctrs + part * num_passes + pass, + part < num_parts - 1 ? + d_bins + ((part + 1) * num_passes + pass) * RADIX_DIGITS : NULL, + d_bins + (part * num_passes + pass) * RADIX_DIGITS, + d_keys_out_, + d_keys_in_ + part * PART_SIZE, + d_values_out_, + d_values_in_ + part * PART_SIZE, + part_num_items, current_bit, num_bits); + break; + } + case OUTPUT: { + using KeyInIterT = KeyIteratorT; + using ValueInIterT = ValueIteratorT; + auto onesweep_kernel = DeviceRadixSortOnesweepKernel< + MaxPolicyT, IS_DESCENDING, KeyInIterT, KeyOutIterT, + ValueInIterT, ValueOutIterT, OffsetT>; + KeyInIterT d_keys_in_ = d_keys_out; + ValueInIterT d_values_in_ = d_values_out; + onesweep_kernel<<>> + (d_lookback, d_ctrs + part * num_passes + pass, + part < num_parts - 1 ? + d_bins + ((part + 1) * num_passes + pass) * RADIX_DIGITS : NULL, + d_bins + (part * num_passes + pass) * RADIX_DIGITS, + d_keys_out_, + d_keys_in_ + part * PART_SIZE, + d_values_out_, + d_values_in_ + part * PART_SIZE, + part_num_items, current_bit, num_bits); + break; + } + } + } else { + using KeyOutIterT = KeyIteratorT; + using ValueOutIterT = ValueIteratorT; + KeyOutIterT d_keys_out_ = d_keys_out; + ValueOutIterT d_values_out_ = d_values_out; + switch (input_mode) { + case INPUT: { + using KeyInIterT = KeyInputIteratorT; + using ValueInIterT = ValueInputIteratorT; + auto onesweep_kernel = DeviceRadixSortOnesweepKernel< + MaxPolicyT, IS_DESCENDING, KeyInIterT, KeyOutIterT, + ValueInIterT, ValueOutIterT, OffsetT>; + KeyInIterT d_keys_in_ = d_keys_in; + ValueInIterT d_values_in_ = d_values_in; + onesweep_kernel<<>> + (d_lookback, d_ctrs + part * num_passes + pass, + part < num_parts - 1 ? + d_bins + ((part + 1) * num_passes + pass) * RADIX_DIGITS : NULL, + d_bins + (part * num_passes + pass) * RADIX_DIGITS, + d_keys_out_, + d_keys_in_ + part * PART_SIZE, + d_values_out_, + d_values_in_ + part * PART_SIZE, + part_num_items, current_bit, num_bits); + break; + } + case TMP_STORAGE: { + using KeyInIterT = KeyT *; + using ValueInIterT = ValueT *; + auto onesweep_kernel = DeviceRadixSortOnesweepKernel< + MaxPolicyT, IS_DESCENDING, KeyInIterT, KeyOutIterT, + ValueInIterT, ValueOutIterT, OffsetT>; + KeyInIterT d_keys_in_ = d_keys_tmp; + ValueInIterT d_values_in_ = d_values_tmp; + onesweep_kernel<<>> + (d_lookback, d_ctrs + part * num_passes + pass, + part < num_parts - 1 ? + d_bins + ((part + 1) * num_passes + pass) * RADIX_DIGITS : NULL, + d_bins + (part * num_passes + pass) * RADIX_DIGITS, + d_keys_out_, + d_keys_in_ + part * PART_SIZE, + d_values_out_, + d_values_in_ + part * PART_SIZE, + part_num_items, current_bit, num_bits); + break; + } + case OUTPUT: { + using KeyInIterT = KeyIteratorT; + using ValueInIterT = ValueIteratorT; + auto onesweep_kernel = DeviceRadixSortOnesweepKernel< + MaxPolicyT, IS_DESCENDING, KeyInIterT, KeyOutIterT, + ValueInIterT, ValueOutIterT, OffsetT>; + KeyInIterT d_keys_in_ = d_keys_out; + ValueInIterT d_values_in_ = d_values_out; + onesweep_kernel<<>> + (d_lookback, d_ctrs + part * num_passes + pass, + part < num_parts - 1 ? + d_bins + ((part + 1) * num_passes + pass) * RADIX_DIGITS : NULL, + d_bins + (part * num_passes + pass) * RADIX_DIGITS, + d_keys_out_, + d_keys_in_ + part * PART_SIZE, + d_values_out_, + d_values_in_ + part * PART_SIZE, + part_num_items, current_bit, num_bits); + break; + } + } + } if (CubDebug(error = cudaPeekAtLastError())) break; } - - // use the temporary buffers if no overwrite is allowed - if (pass == 0) - { - d_keys = num_passes % 2 == 0 ? - DoubleBuffer(d_keys_tmp, d_keys_tmp2) : - DoubleBuffer(d_keys_tmp2, d_keys_tmp); - d_values = num_passes % 2 == 0 ? - DoubleBuffer(d_values_tmp, d_values_tmp2) : - DoubleBuffer(d_values_tmp2, d_values_tmp); - } - d_keys.selector ^= 1; - d_values.selector ^= 1; + input_mode = output_is_tmp ? TMP_STORAGE : OUTPUT; + output_is_tmp = !output_is_tmp; } } while (0); From 2eee39c98e74129c6216ce2a569cd0da9107551a Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Tue, 14 Sep 2021 08:56:16 -0700 Subject: [PATCH 7/8] save --- cub/device/dispatch/dispatch_radix_sort.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cub/device/dispatch/dispatch_radix_sort.cuh b/cub/device/dispatch/dispatch_radix_sort.cuh index 438047503c..4ec8b12840 100644 --- a/cub/device/dispatch/dispatch_radix_sort.cuh +++ b/cub/device/dispatch/dispatch_radix_sort.cuh @@ -1343,7 +1343,7 @@ struct DispatchRadixSort : const int HISTO_BLOCK_THREADS = ActivePolicyT::HistogramPolicy::BLOCK_THREADS; int histo_blocks_per_sm = 1; auto histogram_kernel = DeviceRadixSortHistogramKernel< - MaxPolicyT, IS_DESCENDING, KeyT, OffsetT>; + MaxPolicyT, IS_DESCENDING, KeyInputIteratorT, OffsetT>; if (CubDebug(error = cudaOccupancyMaxActiveBlocksPerMultiprocessor( &histo_blocks_per_sm, histogram_kernel, HISTO_BLOCK_THREADS, 0))) break; histogram_kernel<<>> From 8f0073f35434f8950e5c4febf484b956d62082e3 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Tue, 14 Sep 2021 10:02:06 -0700 Subject: [PATCH 8/8] fix --- cub/device/dispatch/dispatch_radix_sort.cuh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cub/device/dispatch/dispatch_radix_sort.cuh b/cub/device/dispatch/dispatch_radix_sort.cuh index 4ec8b12840..70044b3fdc 100644 --- a/cub/device/dispatch/dispatch_radix_sort.cuh +++ b/cub/device/dispatch/dispatch_radix_sort.cuh @@ -1401,8 +1401,8 @@ struct DispatchRadixSort : break; } case TMP_STORAGE: { - using KeyInIterT = KeyT *; - using ValueInIterT = ValueT *; + using KeyInIterT = const KeyT *; + using ValueInIterT = const ValueT *; auto onesweep_kernel = DeviceRadixSortOnesweepKernel< MaxPolicyT, IS_DESCENDING, KeyInIterT, KeyOutIterT, ValueInIterT, ValueOutIterT, OffsetT>; @@ -1468,8 +1468,8 @@ struct DispatchRadixSort : break; } case TMP_STORAGE: { - using KeyInIterT = KeyT *; - using ValueInIterT = ValueT *; + using KeyInIterT = const KeyT *; + using ValueInIterT = const ValueT *; auto onesweep_kernel = DeviceRadixSortOnesweepKernel< MaxPolicyT, IS_DESCENDING, KeyInIterT, KeyOutIterT, ValueInIterT, ValueOutIterT, OffsetT>;