diff --git a/cub/device/device_radix_sort.cuh b/cub/device/device_radix_sort.cuh
index 4d540568a1..a2caf40cff 100644
--- a/cub/device/device_radix_sort.cuh
+++ b/cub/device/device_radix_sort.cuh
@@ -179,46 +179,53 @@ struct DeviceRadixSort
*
* \endcode
*
- * \tparam KeyT [inferred] KeyT type
- * \tparam ValueT [inferred] ValueT type
+ * \tparam KeyInputIteratorT is a model of Random Access Iterator,
+ * \p KeyInputIteratorT is mutable, and \p KeyInputIteratorT's \c value_type is
+ * a model of LessThan Comparable,
+ * and the ordering relation on \p KeyInputIteratorT's \c value_type is a strict weak ordering, as defined in the
+ * LessThan Comparable requirements.
+ * \tparam ValueInputIteratorT is a model of Random Access Iterator.
+ * \tparam KeyIteratorT is a model of Random Access Iterator,
+ * \p KeyIteratorT is mutable, and \p KeyIteratorT's \c value_type is
+ * a model of LessThan Comparable,
+ * and the ordering relation on \p KeyIteratorT's \c value_type is a strict weak ordering, as defined in the
+ * LessThan Comparable requirements.
+ * \tparam ValueIteratorT is a model of Random Access Iterator.
*/
- template <
- typename KeyT,
- typename ValueT>
+ template <typename KeyInputIteratorT,
+ typename KeyIteratorT,
+ typename ValueInputIteratorT,
+ typename ValueIteratorT>
CUB_RUNTIME_FUNCTION
static cudaError_t SortPairs(
void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done.
size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation
- const KeyT *d_keys_in, ///< [in] Pointer to the input data of key data to sort
- KeyT *d_keys_out, ///< [out] Pointer to the sorted output sequence of key data
- const ValueT *d_values_in, ///< [in] Pointer to the corresponding input sequence of associated value items
- ValueT *d_values_out, ///< [out] Pointer to the correspondingly-reordered output sequence of associated value items
+ KeyInputIteratorT d_keys_in, ///< [in] Iterator to the input data of key data to sort
+ KeyIteratorT d_keys_out, ///< [out] Iterator to the sorted output sequence of key data
+ ValueInputIteratorT d_values_in, ///< [in] Iterator to the corresponding input sequence of associated value items
+ ValueIteratorT d_values_out, ///< [out] Iterator to the correspondingly-reordered output sequence of associated value items
int num_items, ///< [in] Number of items to sort
int begin_bit = 0, ///< [in] [optional] The least-significant bit index (inclusive) needed for key comparison
- int end_bit = sizeof(KeyT) * 8, ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8)
+ int end_bit ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8)
+ = sizeof(typename std::iterator_traits<KeyInputIteratorT>::value_type) * 8,
cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0.
bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false.
{
// Signed integer type for global offsets
typedef int OffsetT;
- // We cast away const-ness, but will *not* write to these arrays.
- // `DispatchRadixSort::Dispatch` will allocate temporary storage and
- // create a new double-buffer internally when the `is_overwrite_ok` flag
- // is not set.
- constexpr bool is_overwrite_okay = false;
- DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out);
- DoubleBuffer d_values(const_cast(d_values_in), d_values_out);
-
- return DispatchRadixSort::Dispatch(
+ return DispatchRadixSort<
+ false, KeyInputIteratorT, KeyIteratorT, ValueInputIteratorT, ValueIteratorT, OffsetT
+ >::Dispatch(
d_temp_storage,
temp_storage_bytes,
- d_keys,
- d_values,
+ d_keys_in,
+ d_keys_out,
+ d_values_in,
+ d_values_out,
num_items,
begin_bit,
end_bit,
- is_overwrite_okay,
stream,
debug_synchronous);
}
@@ -312,19 +319,22 @@ struct DeviceRadixSort
// Signed integer type for global offsets
typedef int OffsetT;
- constexpr bool is_overwrite_okay = true;
-
- return DispatchRadixSort::Dispatch(
+ return DispatchRadixSort<
+ false , const KeyT *, KeyT *, const ValueT *, ValueT *, OffsetT
+ >::Dispatch(
d_temp_storage,
temp_storage_bytes,
- d_keys,
- d_values,
+ d_keys.Current(),
+ d_keys.Alternate(),
+ d_values.Current(),
+ d_values.Alternate(),
num_items,
begin_bit,
end_bit,
- is_overwrite_okay,
stream,
debug_synchronous);
+ d_keys.selector ^= 1;
+ d_values.selector ^= 1;
}
@@ -384,43 +394,40 @@ struct DeviceRadixSort
* \tparam KeyT [inferred] KeyT type
* \tparam ValueT [inferred] ValueT type
*/
- template <
- typename KeyT,
- typename ValueT>
+ template <typename KeyInputIteratorT,
+ typename KeyIteratorT,
+ typename ValueInputIteratorT,
+ typename ValueIteratorT>
CUB_RUNTIME_FUNCTION
static cudaError_t SortPairsDescending(
void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done.
size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation
- const KeyT *d_keys_in, ///< [in] Pointer to the input data of key data to sort
- KeyT *d_keys_out, ///< [out] Pointer to the sorted output sequence of key data
- const ValueT *d_values_in, ///< [in] Pointer to the corresponding input sequence of associated value items
- ValueT *d_values_out, ///< [out] Pointer to the correspondingly-reordered output sequence of associated value items
+ KeyInputIteratorT d_keys_in, ///< [in] Iterator to the input data of key data to sort
+ KeyIteratorT d_keys_out, ///< [out] Iterator to the sorted output sequence of key data
+ ValueInputIteratorT d_values_in, ///< [in] Iterator to the corresponding input sequence of associated value items
+ ValueIteratorT d_values_out, ///< [out] Iterator to the correspondingly-reordered output sequence of associated value items
int num_items, ///< [in] Number of items to sort
int begin_bit = 0, ///< [in] [optional] The least-significant bit index (inclusive) needed for key comparison
- int end_bit = sizeof(KeyT) * 8, ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8)
+ int end_bit ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8)
+ = sizeof(typename std::iterator_traits<KeyInputIteratorT>::value_type) * 8,
cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0.
bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false.
{
// Signed integer type for global offsets
typedef int OffsetT;
- // We cast away const-ness, but will *not* write to these arrays.
- // `DispatchRadixSort::Dispatch` will allocate temporary storage and
- // create a new double-buffer internally when the `is_overwrite_ok` flag
- // is not set.
- constexpr bool is_overwrite_okay = false;
- DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out);
- DoubleBuffer d_values(const_cast(d_values_in), d_values_out);
-
- return DispatchRadixSort::Dispatch(
+ return DispatchRadixSort<
+ true, KeyInputIteratorT, KeyIteratorT, ValueInputIteratorT, ValueIteratorT, OffsetT
+ >::Dispatch(
d_temp_storage,
temp_storage_bytes,
- d_keys,
- d_values,
+ d_keys_in,
+ d_keys_out,
+ d_values_in,
+ d_values_out,
num_items,
begin_bit,
end_bit,
- is_overwrite_okay,
stream,
debug_synchronous);
}
@@ -509,19 +516,22 @@ struct DeviceRadixSort
// Signed integer type for global offsets
typedef int OffsetT;
- constexpr bool is_overwrite_okay = true;
-
- return DispatchRadixSort::Dispatch(
+ return DispatchRadixSort<
+ true, const KeyT *, KeyT *, const ValueT *, ValueT *, OffsetT
+ >::Dispatch(
d_temp_storage,
temp_storage_bytes,
- d_keys,
- d_values,
+ d_keys.Current(),
+ d_keys.Alternate(),
+ d_values.Current(),
+ d_values.Alternate(),
num_items,
begin_bit,
end_bit,
- is_overwrite_okay,
stream,
debug_synchronous);
+ d_keys.selector ^= 1;
+ d_values.selector ^= 1;
}
@@ -583,40 +593,36 @@ struct DeviceRadixSort
*
* \tparam KeyT [inferred] KeyT type
*/
- template <typename KeyT>
+ template <typename KeyInputIteratorT,
+ typename KeyIteratorT>
CUB_RUNTIME_FUNCTION
static cudaError_t SortKeys(
void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done.
size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation
- const KeyT *d_keys_in, ///< [in] Pointer to the input data of key data to sort
- KeyT *d_keys_out, ///< [out] Pointer to the sorted output sequence of key data
+ KeyInputIteratorT d_keys_in, ///< [in] Iterator to the input data of key data to sort
+ KeyIteratorT d_keys_out, ///< [out] Iterator to the sorted output sequence of key data
int num_items, ///< [in] Number of items to sort
int begin_bit = 0, ///< [in] [optional] The least-significant bit index (inclusive) needed for key comparison
- int end_bit = sizeof(KeyT) * 8, ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8)
+ int end_bit ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8)
+ = sizeof(typename std::iterator_traits<KeyInputIteratorT>::value_type) * 8,
cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0.
bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false.
{
// Signed integer type for global offsets
typedef int OffsetT;
- // We cast away const-ness, but will *not* write to these arrays.
- // `DispatchRadixSort::Dispatch` will allocate temporary storage and
- // create a new double-buffer internally when the `is_overwrite_ok` flag
- // is not set.
- constexpr bool is_overwrite_okay = false;
- DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out);
- // Null value type
- DoubleBuffer d_values;
-
- return DispatchRadixSort::Dispatch(
+ return DispatchRadixSort<
+ false, KeyInputIteratorT, KeyIteratorT, const NullType *, NullType *, OffsetT
+ >::Dispatch(
d_temp_storage,
temp_storage_bytes,
- d_keys,
- d_values,
+ d_keys_in,
+ d_keys_out,
+ nullptr,
+ nullptr,
num_items,
begin_bit,
end_bit,
- is_overwrite_okay,
stream,
debug_synchronous);
}
@@ -696,22 +702,21 @@ struct DeviceRadixSort
// Signed integer type for global offsets
typedef int OffsetT;
- constexpr bool is_overwrite_okay = true;
-
- // Null value type
- DoubleBuffer d_values;
-
- return DispatchRadixSort::Dispatch(
+ return DispatchRadixSort<
+ false, const KeyT *, KeyT *, const NullType *, NullType *, OffsetT
+ >::Dispatch(
d_temp_storage,
temp_storage_bytes,
- d_keys,
- d_values,
+ d_keys.Current(),
+ d_keys.Alternate(),
+ nullptr,
+ nullptr,
num_items,
begin_bit,
end_bit,
- is_overwrite_okay,
stream,
debug_synchronous);
+ d_keys.selector ^= 1;
}
/**
@@ -764,39 +769,36 @@ struct DeviceRadixSort
*
* \tparam KeyT [inferred] KeyT type
*/
- template <typename KeyT>
+ template <typename KeyInputIteratorT,
+ typename KeyIteratorT>
CUB_RUNTIME_FUNCTION
static cudaError_t SortKeysDescending(
void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done.
size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation
- const KeyT *d_keys_in, ///< [in] Pointer to the input data of key data to sort
- KeyT *d_keys_out, ///< [out] Pointer to the sorted output sequence of key data
+ KeyInputIteratorT d_keys_in, ///< [in] Iterator to the input data of key data to sort
+ KeyIteratorT d_keys_out, ///< [out] Iterator to the sorted output sequence of key data
int num_items, ///< [in] Number of items to sort
int begin_bit = 0, ///< [in] [optional] The least-significant bit index (inclusive) needed for key comparison
- int end_bit = sizeof(KeyT) * 8, ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8)
+ int end_bit ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8)
+ = sizeof(typename std::iterator_traits<KeyInputIteratorT>::value_type) * 8,
cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0.
bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false.
{
// Signed integer type for global offsets
typedef int OffsetT;
- // We cast away const-ness, but will *not* write to these arrays.
- // `DispatchRadixSort::Dispatch` will allocate temporary storage and
- // create a new double-buffer internally when the `is_overwrite_ok` flag
- // is not set.
- constexpr bool is_overwrite_okay = false;
- DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out);
- DoubleBuffer d_values;
-
- return DispatchRadixSort::Dispatch(
+ return DispatchRadixSort<
+ true, KeyInputIteratorT, KeyIteratorT, const NullType *, NullType *, OffsetT
+ >::Dispatch(
d_temp_storage,
temp_storage_bytes,
- d_keys,
- d_values,
+ d_keys_in,
+ d_keys_out,
+ nullptr,
+ nullptr,
num_items,
begin_bit,
end_bit,
- is_overwrite_okay,
stream,
debug_synchronous);
}
@@ -872,22 +874,21 @@ struct DeviceRadixSort
// Signed integer type for global offsets
typedef int OffsetT;
- constexpr bool is_overwrite_okay = true;
-
- // Null value type
- DoubleBuffer d_values;
-
- return DispatchRadixSort::Dispatch(
+ return DispatchRadixSort<
+ true, const KeyT *, KeyT *, const NullType *, NullType *, OffsetT
+ >::Dispatch(
d_temp_storage,
temp_storage_bytes,
- d_keys,
- d_values,
+ d_keys.Current(),
+ d_keys.Alternate(),
+ nullptr,
+ nullptr,
num_items,
begin_bit,
end_bit,
- is_overwrite_okay,
stream,
debug_synchronous);
+ d_keys.selector ^= 1;
}
diff --git a/cub/device/device_segmented_radix_sort.cuh b/cub/device/device_segmented_radix_sort.cuh
index 30d3028875..3682709b87 100644
--- a/cub/device/device_segmented_radix_sort.cuh
+++ b/cub/device/device_segmented_radix_sort.cuh
@@ -128,38 +128,42 @@ struct DeviceSegmentedRadixSort
* \tparam EndOffsetIteratorT [inferred] Random-access input iterator type for reading segment ending offsets \iterator
*/
template <
- typename KeyT,
- typename ValueT,
+ typename KeyInputIteratorT,
+ typename KeyIteratorT,
+ typename ValueInputIteratorT,
+ typename ValueIteratorT,
typename BeginOffsetIteratorT,
typename EndOffsetIteratorT>
CUB_RUNTIME_FUNCTION
static cudaError_t SortPairs(
void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done.
size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation
- const KeyT *d_keys_in, ///< [in] %Device-accessible pointer to the input data of key data to sort
- KeyT *d_keys_out, ///< [out] %Device-accessible pointer to the sorted output sequence of key data
- const ValueT *d_values_in, ///< [in] %Device-accessible pointer to the corresponding input sequence of associated value items
- ValueT *d_values_out, ///< [out] %Device-accessible pointer to the correspondingly-reordered output sequence of associated value items
+ KeyInputIteratorT d_keys_in, ///< [in] Iterator to the input data of key data to sort
+ KeyIteratorT d_keys_out, ///< [out] Iterator to the sorted output sequence of key data
+ ValueInputIteratorT d_values_in, ///< [in] Iterator to the corresponding input sequence of associated value items
+ ValueIteratorT d_values_out, ///< [out] Iterator to the correspondingly-reordered output sequence of associated value items
int num_items, ///< [in] The total number of items to sort (across all segments)
int num_segments, ///< [in] The number of segments that comprise the sorting data
BeginOffsetIteratorT d_begin_offsets, ///< [in] Random-access input iterator to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_*
EndOffsetIteratorT d_end_offsets, ///< [in] Random-access input iterator to the sequence of ending offsets of length \p num_segments, such that d_end_offsets[i]-1 is the last element of the ith data segment in d_keys_* and d_values_*. If d_end_offsets[i]-1 <= d_begin_offsets[i], the ith is considered empty.
int begin_bit = 0, ///< [in] [optional] The least-significant bit index (inclusive) needed for key comparison
- int end_bit = sizeof(KeyT) * 8, ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8)
+ int end_bit ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8)
+ = sizeof(typename std::iterator_traits<KeyInputIteratorT>::value_type) * 8,
cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0.
bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false.
{
// Signed integer type for global offsets
typedef int OffsetT;
- DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out);
- DoubleBuffer d_values(const_cast(d_values_in), d_values_out);
-
- return DispatchSegmentedRadixSort::Dispatch(
+ return DispatchSegmentedRadixSort<
+ false, KeyInputIteratorT, KeyIteratorT, ValueInputIteratorT, ValueIteratorT, BeginOffsetIteratorT, EndOffsetIteratorT, OffsetT
+ >::Dispatch(
d_temp_storage,
temp_storage_bytes,
- d_keys,
- d_values,
+ d_keys_in,
+ d_keys_out,
+ d_values_in,
+ d_values_out,
num_items,
num_segments,
d_begin_offsets,
@@ -261,11 +265,15 @@ struct DeviceSegmentedRadixSort
// Signed integer type for global offsets
typedef int OffsetT;
- return DispatchSegmentedRadixSort::Dispatch(
+ return DispatchSegmentedRadixSort::Dispatch(
d_temp_storage,
temp_storage_bytes,
- d_keys,
- d_values,
+ d_keys.Current(),
+ d_keys.Alternate(),
+ d_values.Current(),
+ d_values.Alternate(),
num_items,
num_segments,
d_begin_offsets,
@@ -275,6 +283,8 @@ struct DeviceSegmentedRadixSort
true,
stream,
debug_synchronous);
+ d_keys.selector ^= 1;
+ d_values.selector ^= 1;
}
@@ -334,38 +344,42 @@ struct DeviceSegmentedRadixSort
* \tparam EndOffsetIteratorT [inferred] Random-access input iterator type for reading segment ending offsets \iterator
*/
template <
- typename KeyT,
- typename ValueT,
+ typename KeyInputIteratorT,
+ typename KeyIteratorT,
+ typename ValueInputIteratorT,
+ typename ValueIteratorT,
typename BeginOffsetIteratorT,
typename EndOffsetIteratorT>
CUB_RUNTIME_FUNCTION
static cudaError_t SortPairsDescending(
void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done.
size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation
- const KeyT *d_keys_in, ///< [in] %Device-accessible pointer to the input data of key data to sort
- KeyT *d_keys_out, ///< [out] %Device-accessible pointer to the sorted output sequence of key data
- const ValueT *d_values_in, ///< [in] %Device-accessible pointer to the corresponding input sequence of associated value items
- ValueT *d_values_out, ///< [out] %Device-accessible pointer to the correspondingly-reordered output sequence of associated value items
+ KeyInputIteratorT d_keys_in, ///< [in] Iterator to the input data of key data to sort
+ KeyIteratorT d_keys_out, ///< [out] Iterator to the sorted output sequence of key data
+ ValueInputIteratorT d_values_in, ///< [in] Iterator to the corresponding input sequence of associated value items
+ ValueIteratorT d_values_out, ///< [out] Iterator to the correspondingly-reordered output sequence of associated value items
int num_items, ///< [in] The total number of items to sort (across all segments)
int num_segments, ///< [in] The number of segments that comprise the sorting data
BeginOffsetIteratorT d_begin_offsets, ///< [in] Random-access input iterator to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_*
EndOffsetIteratorT d_end_offsets, ///< [in] Random-access input iterator to the sequence of ending offsets of length \p num_segments, such that d_end_offsets[i]-1 is the last element of the ith data segment in d_keys_* and d_values_*. If d_end_offsets[i]-1 <= d_begin_offsets[i], the ith is considered empty.
int begin_bit = 0, ///< [in] [optional] The least-significant bit index (inclusive) needed for key comparison
- int end_bit = sizeof(KeyT) * 8, ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8)
+ int end_bit ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8)
+ = sizeof(typename std::iterator_traits<KeyInputIteratorT>::value_type) * 8,
cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0.
bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false.
{
// Signed integer type for global offsets
typedef int OffsetT;
- DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out);
- DoubleBuffer d_values(const_cast(d_values_in), d_values_out);
-
- return DispatchSegmentedRadixSort::Dispatch(
+ return DispatchSegmentedRadixSort<
+ true, KeyInputIteratorT, KeyIteratorT, ValueInputIteratorT, ValueIteratorT, BeginOffsetIteratorT, EndOffsetIteratorT, OffsetT
+ >::Dispatch(
d_temp_storage,
temp_storage_bytes,
- d_keys,
- d_values,
+ d_keys_in,
+ d_keys_out,
+ d_values_in,
+ d_values_out,
num_items,
num_segments,
d_begin_offsets,
@@ -467,11 +481,15 @@ struct DeviceSegmentedRadixSort
// Signed integer type for global offsets
typedef int OffsetT;
- return DispatchSegmentedRadixSort::Dispatch(
+ return DispatchSegmentedRadixSort<
+ true, const KeyT *, KeyT *, const ValueT *, ValueT *, BeginOffsetIteratorT, EndOffsetIteratorT, OffsetT
+ >::Dispatch(
d_temp_storage,
temp_storage_bytes,
- d_keys,
- d_values,
+ d_keys.Current(),
+ d_keys.Alternate(),
+ d_values.Current(),
+ d_values.Alternate(),
num_items,
num_segments,
d_begin_offsets,
@@ -481,6 +499,8 @@ struct DeviceSegmentedRadixSort
true,
stream,
debug_synchronous);
+ d_keys.selector ^= 1;
+ d_values.selector ^= 1;
}
@@ -540,36 +560,38 @@ struct DeviceSegmentedRadixSort
* \tparam EndOffsetIteratorT [inferred] Random-access input iterator type for reading segment ending offsets \iterator
*/
template <
- typename KeyT,
+ typename KeyInputIteratorT,
+ typename KeyIteratorT,
typename BeginOffsetIteratorT,
typename EndOffsetIteratorT>
CUB_RUNTIME_FUNCTION
static cudaError_t SortKeys(
void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done.
size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation
- const KeyT *d_keys_in, ///< [in] %Device-accessible pointer to the input data of key data to sort
- KeyT *d_keys_out, ///< [out] %Device-accessible pointer to the sorted output sequence of key data
+ KeyInputIteratorT d_keys_in, ///< [in] Iterator to the input data of key data to sort
+ KeyIteratorT d_keys_out, ///< [out] Iterator to the sorted output sequence of key data
int num_items, ///< [in] The total number of items to sort (across all segments)
int num_segments, ///< [in] The number of segments that comprise the sorting data
BeginOffsetIteratorT d_begin_offsets, ///< [in] Random-access input iterator to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_*
EndOffsetIteratorT d_end_offsets, ///< [in] Random-access input iterator to the sequence of ending offsets of length \p num_segments, such that d_end_offsets[i]-1 is the last element of the ith data segment in d_keys_* and d_values_*. If d_end_offsets[i]-1 <= d_begin_offsets[i], the ith is considered empty.
int begin_bit = 0, ///< [in] [optional] The least-significant bit index (inclusive) needed for key comparison
- int end_bit = sizeof(KeyT) * 8, ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8)
+ int end_bit ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8)
+ = sizeof(typename std::iterator_traits<KeyInputIteratorT>::value_type) * 8,
cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0.
bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false.
{
// Signed integer type for global offsets
typedef int OffsetT;
- // Null value type
- DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out);
- DoubleBuffer d_values;
-
- return DispatchSegmentedRadixSort::Dispatch(
+ return DispatchSegmentedRadixSort<
+ false, KeyInputIteratorT, KeyIteratorT, const NullType *, NullType *, BeginOffsetIteratorT, EndOffsetIteratorT, OffsetT
+ >::Dispatch(
d_temp_storage,
temp_storage_bytes,
- d_keys,
- d_values,
+ d_keys_in,
+ d_keys_out,
+ nullptr,
+ nullptr,
num_items,
num_segments,
d_begin_offsets,
@@ -661,14 +683,15 @@ struct DeviceSegmentedRadixSort
// Signed integer type for global offsets
typedef int OffsetT;
- // Null value type
- DoubleBuffer d_values;
-
- return DispatchSegmentedRadixSort::Dispatch(
+ return DispatchSegmentedRadixSort<
+ false, const KeyT *, KeyT *, const NullType *, NullType *, BeginOffsetIteratorT, EndOffsetIteratorT, OffsetT
+ >::Dispatch(
d_temp_storage,
temp_storage_bytes,
- d_keys,
- d_values,
+ d_keys.Current(),
+ d_keys.Alternate(),
+ nullptr,
+ nullptr,
num_items,
num_segments,
d_begin_offsets,
@@ -678,6 +701,7 @@ struct DeviceSegmentedRadixSort
true,
stream,
debug_synchronous);
+ d_keys.selector ^= 1;
}
/**
@@ -732,35 +756,38 @@ struct DeviceSegmentedRadixSort
* \tparam EndOffsetIteratorT [inferred] Random-access input iterator type for reading segment ending offsets \iterator
*/
template <
- typename KeyT,
+ typename KeyInputIteratorT,
+ typename KeyIteratorT,
typename BeginOffsetIteratorT,
typename EndOffsetIteratorT>
CUB_RUNTIME_FUNCTION
static cudaError_t SortKeysDescending(
void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done.
size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation
- const KeyT *d_keys_in, ///< [in] %Device-accessible pointer to the input data of key data to sort
- KeyT *d_keys_out, ///< [out] %Device-accessible pointer to the sorted output sequence of key data
+ KeyInputIteratorT d_keys_in, ///< [in] Iterator to the input data of key data to sort
+ KeyIteratorT d_keys_out, ///< [out] Iterator to the sorted output sequence of key data
int num_items, ///< [in] The total number of items to sort (across all segments)
int num_segments, ///< [in] The number of segments that comprise the sorting data
BeginOffsetIteratorT d_begin_offsets, ///< [in] Random-access input iterator to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_*
EndOffsetIteratorT d_end_offsets, ///< [in] Random-access input iterator to the sequence of ending offsets of length \p num_segments, such that d_end_offsets[i]-1 is the last element of the ith data segment in d_keys_* and d_values_*. If d_end_offsets[i]-1 <= d_begin_offsets[i], the ith is considered empty.
int begin_bit = 0, ///< [in] [optional] The least-significant bit index (inclusive) needed for key comparison
- int end_bit = sizeof(KeyT) * 8, ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8)
+ int end_bit ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8)
+ = sizeof(typename std::iterator_traits<KeyInputIteratorT>::value_type) * 8,
cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0.
bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false.
{
// Signed integer type for global offsets
typedef int OffsetT;
- DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out);
- DoubleBuffer d_values;
-
- return DispatchSegmentedRadixSort::Dispatch(
+ return DispatchSegmentedRadixSort<
+ true, KeyInputIteratorT, KeyIteratorT, const NullType *, NullType *, BeginOffsetIteratorT, EndOffsetIteratorT, OffsetT
+ >::Dispatch(
d_temp_storage,
temp_storage_bytes,
- d_keys,
- d_values,
+ d_keys_in,
+ d_keys_out,
+ nullptr,
+ nullptr,
num_items,
num_segments,
d_begin_offsets,
@@ -842,7 +869,7 @@ struct DeviceSegmentedRadixSort
DoubleBuffer &d_keys, ///< [in,out] Reference to the double-buffer of keys whose "current" device-accessible buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys
int num_items, ///< [in] The total number of items to sort (across all segments)
int num_segments, ///< [in] The number of segments that comprise the sorting data
- BeginOffsetIteratorT d_begin_offsets, ///< [in] Random-access input iterator to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_*
+ BeginOffsetIteratorT d_begin_offsets, ///< [in] Random-access input iterator to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_*
EndOffsetIteratorT d_end_offsets, ///< [in] Random-access input iterator to the sequence of ending offsets of length \p num_segments, such that d_end_offsets[i]-1 is the last element of the ith data segment in d_keys_* and d_values_*. If d_end_offsets[i]-1 <= d_begin_offsets[i], the ith is considered empty.
int begin_bit = 0, ///< [in] [optional] The least-significant bit index (inclusive) needed for key comparison
int end_bit = sizeof(KeyT) * 8, ///< [in] [optional] The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8)
@@ -852,14 +879,15 @@ struct DeviceSegmentedRadixSort
// Signed integer type for global offsets
typedef int OffsetT;
- // Null value type
- DoubleBuffer d_values;
-
- return DispatchSegmentedRadixSort::Dispatch(
+ return DispatchSegmentedRadixSort<
+ true, const KeyT *, KeyT *, const NullType *, NullType *, BeginOffsetIteratorT, EndOffsetIteratorT, OffsetT
+ >::Dispatch(
d_temp_storage,
temp_storage_bytes,
- d_keys,
- d_values,
+ d_keys.Current(),
+ d_keys.Alternate(),
+ nullptr,
+ nullptr,
num_items,
num_segments,
d_begin_offsets,
@@ -869,6 +897,7 @@ struct DeviceSegmentedRadixSort
true,
stream,
debug_synchronous);
+ d_keys.selector ^= 1;
}
diff --git a/cub/device/dispatch/dispatch_radix_sort.cuh b/cub/device/dispatch/dispatch_radix_sort.cuh
index f48371e7c8..70044b3fdc 100644
--- a/cub/device/dispatch/dispatch_radix_sort.cuh
+++ b/cub/device/dispatch/dispatch_radix_sort.cuh
@@ -73,19 +73,20 @@ template <
typename ChainedPolicyT, ///< Chained tuning policy
bool ALT_DIGIT_BITS, ///< Whether or not to use the alternate (lower-bits) policy
bool IS_DESCENDING, ///< Whether or not the sorted-order is high-to-low
- typename KeyT, ///< Key type
+ typename KeyInputIteratorT, ///< Input key iterator type
typename OffsetT> ///< Signed integer type for global offsets
__launch_bounds__ (int((ALT_DIGIT_BITS) ?
int(ChainedPolicyT::ActivePolicy::AltUpsweepPolicy::BLOCK_THREADS) :
int(ChainedPolicyT::ActivePolicy::UpsweepPolicy::BLOCK_THREADS)))
__global__ void DeviceRadixSortUpsweepKernel(
- const KeyT *d_keys, ///< [in] Input keys buffer
+ KeyInputIteratorT d_keys, ///< [in] Input keys iterator
OffsetT *d_spine, ///< [out] Privatized (per block) digit histograms (striped, i.e., 0s counts from each block, then 1s counts from each block, etc.)
OffsetT /*num_items*/, ///< [in] Total number of input data items
int current_bit, ///< [in] Bit position of current radix digit
int num_bits, ///< [in] Number of bits of current radix digit
GridEvenShare even_share) ///< [in] Even-share descriptor for mapan equal number of tiles onto each thread block
{
+ using KeyT = typename std::iterator_traits::value_type;
typedef typename If<
(ALT_DIGIT_BITS),
typename ChainedPolicyT::ActivePolicy::AltUpsweepPolicy,
@@ -180,23 +181,27 @@ template <
typename ChainedPolicyT, ///< Chained tuning policy
bool ALT_DIGIT_BITS, ///< Whether or not to use the alternate (lower-bits) policy
bool IS_DESCENDING, ///< Whether or not the sorted-order is high-to-low
- typename KeyT, ///< Key type
- typename ValueT, ///< Value type
+ typename KeyInputIteratorT, ///< Input key iterator type
+ typename KeyIteratorT, ///< Key iterator type
+ typename ValueInputIteratorT, ///< Input value iterator type
+ typename ValueIteratorT, ///< Value iterator type
typename OffsetT> ///< Signed integer type for global offsets
__launch_bounds__ (int((ALT_DIGIT_BITS) ?
int(ChainedPolicyT::ActivePolicy::AltDownsweepPolicy::BLOCK_THREADS) :
int(ChainedPolicyT::ActivePolicy::DownsweepPolicy::BLOCK_THREADS)))
__global__ void DeviceRadixSortDownsweepKernel(
- const KeyT *d_keys_in, ///< [in] Input keys buffer
- KeyT *d_keys_out, ///< [in] Output keys buffer
- const ValueT *d_values_in, ///< [in] Input values buffer
- ValueT *d_values_out, ///< [in] Output values buffer
+ KeyInputIteratorT d_keys_in, ///< [in] Input keys iterator
+ KeyIteratorT d_keys_out, ///< [out] Output keys iterator
+ ValueInputIteratorT d_values_in, ///< [in] Input values iterator
+ ValueIteratorT d_values_out, ///< [out] Output values iterator
OffsetT *d_spine, ///< [in] Scan of privatized (per block) digit histograms (striped, i.e., 0s counts from each block, then 1s counts from each block, etc.)
OffsetT num_items, ///< [in] Total number of input data items
int current_bit, ///< [in] Bit position of current radix digit
int num_bits, ///< [in] Number of bits of current radix digit
GridEvenShare even_share) ///< [in] Even-share descriptor for mapan equal number of tiles onto each thread block
{
+ using KeyT = typename std::iterator_traits::value_type;
+ using ValueT = typename std::iterator_traits::value_type;
typedef typename If<
(ALT_DIGIT_BITS),
typename ChainedPolicyT::ActivePolicy::AltUpsweepPolicy,
@@ -243,19 +248,23 @@ __global__ void DeviceRadixSortDownsweepKernel(
template <
typename ChainedPolicyT, ///< Chained tuning policy
bool IS_DESCENDING, ///< Whether or not the sorted-order is high-to-low
- typename KeyT, ///< Key type
- typename ValueT, ///< Value type
+ typename KeyInputIteratorT, ///< Input key iterator type
+ typename KeyIteratorT, ///< Key iterator type
+ typename ValueInputIteratorT, ///< Input value iterator type
+ typename ValueIteratorT, ///< Value iterator type
typename OffsetT> ///< Signed integer type for global offsets
__launch_bounds__ (int(ChainedPolicyT::ActivePolicy::SingleTilePolicy::BLOCK_THREADS), 1)
__global__ void DeviceRadixSortSingleTileKernel(
- const KeyT *d_keys_in, ///< [in] Input keys buffer
- KeyT *d_keys_out, ///< [in] Output keys buffer
- const ValueT *d_values_in, ///< [in] Input values buffer
- ValueT *d_values_out, ///< [in] Output values buffer
+ KeyInputIteratorT d_keys_in, ///< [in] Input keys iterator
+ KeyIteratorT d_keys_out, ///< [out] Output keys iterator
+ ValueInputIteratorT d_values_in, ///< [in] Input values iterator
+ ValueIteratorT d_values_out, ///< [out] Output values iterator
OffsetT num_items, ///< [in] Total number of input data items
int current_bit, ///< [in] Bit position of current radix digit
int end_bit) ///< [in] The past-the-end (most-significant) bit index needed for key comparison
{
+ using KeyT = typename std::iterator_traits::value_type;
+ using ValueT = typename std::iterator_traits::value_type;
// Constants
enum
{
@@ -357,8 +366,10 @@ template <
typename ChainedPolicyT, ///< Chained tuning policy
bool ALT_DIGIT_BITS, ///< Whether or not to use the alternate (lower-bits) policy
bool IS_DESCENDING, ///< Whether or not the sorted-order is high-to-low
- typename KeyT, ///< Key type
- typename ValueT, ///< Value type
+ typename KeyInputIteratorT, ///< Input key iterator type
+ typename KeyIteratorT, ///< Key iterator type
+ typename ValueInputIteratorT, ///< Input value iterator type
+ typename ValueIteratorT, ///< Value iterator type
typename BeginOffsetIteratorT, ///< Random-access input iterator type for reading segment beginning offsets \iterator
typename EndOffsetIteratorT, ///< Random-access input iterator type for reading segment ending offsets \iterator
typename OffsetT> ///< Signed integer type for global offsets
@@ -366,16 +377,18 @@ __launch_bounds__ (int((ALT_DIGIT_BITS) ?
ChainedPolicyT::ActivePolicy::AltSegmentedPolicy::BLOCK_THREADS :
ChainedPolicyT::ActivePolicy::SegmentedPolicy::BLOCK_THREADS))
__global__ void DeviceSegmentedRadixSortKernel(
- const KeyT *d_keys_in, ///< [in] Input keys buffer
- KeyT *d_keys_out, ///< [in] Output keys buffer
- const ValueT *d_values_in, ///< [in] Input values buffer
- ValueT *d_values_out, ///< [in] Output values buffer
+ KeyInputIteratorT d_keys_in, ///< [in] Input keys iterator
+ KeyIteratorT d_keys_out, ///< [out] Output keys iterator
+ ValueInputIteratorT d_values_in, ///< [in] Input values iterator
+ ValueIteratorT d_values_out, ///< [out] Output values iterator
BeginOffsetIteratorT d_begin_offsets, ///< [in] Random-access input iterator to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_*
EndOffsetIteratorT d_end_offsets, ///< [in] Random-access input iterator to the sequence of ending offsets of length \p num_segments, such that d_end_offsets[i]-1 is the last element of the ith data segment in d_keys_* and d_values_*. If d_end_offsets[i]-1 <= d_begin_offsets[i], the ith is considered empty.
int /*num_segments*/, ///< [in] The number of segments that comprise the sorting data
int current_bit, ///< [in] Bit position of current radix digit
int pass_bits) ///< [in] Number of bits of current radix digit
{
+ using KeyT = typename std::iterator_traits::value_type;
+ using ValueT = typename std::iterator_traits::value_type;
//
// Constants
//
@@ -527,12 +540,13 @@ __global__ void DeviceSegmentedRadixSortKernel(
template <
typename ChainedPolicyT,
bool IS_DESCENDING,
- typename KeyT,
+ typename KeyInputIteratorT,
typename OffsetT>
__global__ void __launch_bounds__(ChainedPolicyT::ActivePolicy::HistogramPolicy::BLOCK_THREADS)
DeviceRadixSortHistogramKernel
- (OffsetT* d_bins_out, const KeyT* d_keys_in, OffsetT num_items, int start_bit, int end_bit)
+ (OffsetT* d_bins_out, KeyInputIteratorT d_keys_in, OffsetT num_items, int start_bit, int end_bit)
{
+ using KeyT = typename std::iterator_traits::value_type;
typedef typename ChainedPolicyT::ActivePolicy::HistogramPolicy HistogramPolicyT;
typedef AgentRadixSortHistogram AgentT;
__shared__ typename AgentT::TempStorage temp_storage;
@@ -543,16 +557,20 @@ DeviceRadixSortHistogramKernel
template <
typename ChainedPolicyT,
bool IS_DESCENDING,
- typename KeyT,
- typename ValueT,
+ typename KeyInputIteratorT,
+ typename KeyIteratorT,
+ typename ValueInputIteratorT,
+ typename ValueIteratorT,
typename OffsetT,
typename AtomicOffsetT = OffsetT>
__global__ void __launch_bounds__(ChainedPolicyT::ActivePolicy::OnesweepPolicy::BLOCK_THREADS)
DeviceRadixSortOnesweepKernel
(AtomicOffsetT* d_lookback, AtomicOffsetT* d_ctrs, OffsetT* d_bins_out,
- const OffsetT* d_bins_in, KeyT* d_keys_out, const KeyT* d_keys_in, ValueT* d_values_out,
- const ValueT* d_values_in, OffsetT num_items, int current_bit, int num_bits)
+ const OffsetT* d_bins_in, KeyIteratorT d_keys_out, KeyInputIteratorT d_keys_in, ValueIteratorT d_values_out,
+ ValueInputIteratorT d_values_in, OffsetT num_items, int current_bit, int num_bits)
{
+ using KeyT = typename std::iterator_traits::value_type;
+ using ValueT = typename std::iterator_traits::value_type;
typedef typename ChainedPolicyT::ActivePolicy::OnesweepPolicy OnesweepPolicyT;
typedef AgentRadixSortOnesweep AgentT;
__shared__ typename AgentT::TempStorage s;
@@ -954,14 +972,21 @@ struct DeviceRadixSortPolicy
* Utility class for dispatching the appropriately-tuned kernels for device-wide radix sort
*/
template <
- bool IS_DESCENDING, ///< Whether or not the sorted-order is high-to-low
- typename KeyT, ///< Key type
- typename ValueT, ///< Value type
- typename OffsetT, ///< Signed integer type for global offsets
- typename SelectedPolicy = DeviceRadixSortPolicy >
+ bool IS_DESCENDING, ///< Whether or not the sorted-order is high-to-low
+ typename KeyInputIteratorT, ///< Input key iterator type
+ typename KeyIteratorT, ///< Key iterator type
+ typename ValueInputIteratorT, ///< Input value iterator type
+ typename ValueIteratorT, ///< Value iterator type
+ typename OffsetT, ///< Signed integer type for global offsets
+ typename SelectedPolicy = DeviceRadixSortPolicy<
+ /*KeyT=*/ typename std::iterator_traits::value_type,
+ /*ValueT=*/ typename std::iterator_traits::value_type,
+ OffsetT> >
struct DispatchRadixSort :
SelectedPolicy
{
+ using KeyT = typename std::iterator_traits::value_type;
+ using ValueT = typename std::iterator_traits::value_type;
//------------------------------------------------------------------------------
// Constants
//------------------------------------------------------------------------------
@@ -979,15 +1004,16 @@ struct DispatchRadixSort :
void *d_temp_storage; ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done.
size_t &temp_storage_bytes; ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation
- DoubleBuffer &d_keys; ///< [in,out] Double-buffer whose current buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys
- DoubleBuffer &d_values; ///< [in,out] Double-buffer whose current buffer contains the unsorted input values and, upon return, is updated to point to the sorted output values
+ KeyInputIteratorT d_keys_in; ///< [in] Iterator for the unsorted input keys
+ KeyIteratorT d_keys_out; ///< [out] Iterator for the sorted output keys
+ ValueInputIteratorT d_values_in; ///< [in] Iterator for the unsorted input values
+ ValueIteratorT d_values_out; ///< [out] Iterator for the sorted output values
OffsetT num_items; ///< [in] Number of items to sort
int begin_bit; ///< [in] The beginning (least-significant) bit index needed for key comparison
int end_bit; ///< [in] The past-the-end (most-significant) bit index needed for key comparison
cudaStream_t stream; ///< [in] CUDA stream to launch kernels within. Default is stream0.
bool debug_synchronous; ///< [in] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false.
int ptx_version; ///< [in] PTX version
- bool is_overwrite_okay; ///< [in] Whether is okay to overwrite source buffers
//------------------------------------------------------------------------------
@@ -999,27 +1025,29 @@ struct DispatchRadixSort :
DispatchRadixSort(
void* d_temp_storage,
size_t &temp_storage_bytes,
- DoubleBuffer &d_keys,
- DoubleBuffer &d_values,
+ KeyInputIteratorT d_keys_in,
+ KeyIteratorT d_keys_out,
+ ValueInputIteratorT d_values_in,
+ ValueIteratorT d_values_out,
OffsetT num_items,
int begin_bit,
int end_bit,
- bool is_overwrite_okay,
cudaStream_t stream,
bool debug_synchronous,
int ptx_version)
:
d_temp_storage(d_temp_storage),
temp_storage_bytes(temp_storage_bytes),
- d_keys(d_keys),
- d_values(d_values),
+ d_keys_in(d_keys_in),
+ d_keys_out(d_keys_out),
+ d_values_in(d_values_in),
+ d_values_out(d_values_out),
num_items(num_items),
begin_bit(begin_bit),
end_bit(end_bit),
stream(stream),
debug_synchronous(debug_synchronous),
- ptx_version(ptx_version),
- is_overwrite_okay(is_overwrite_okay)
+ ptx_version(ptx_version)
{}
@@ -1064,10 +1092,10 @@ struct DispatchRadixSort :
THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron(
1, ActivePolicyT::SingleTilePolicy::BLOCK_THREADS, 0, stream
).doit(single_tile_kernel,
- d_keys.Current(),
- d_keys.Alternate(),
- d_values.Current(),
- d_values.Alternate(),
+ d_keys_in,
+ d_keys_out,
+ d_values_in,
+ d_values_out,
num_items,
begin_bit,
end_bit);
@@ -1077,10 +1105,6 @@ struct DispatchRadixSort :
// Sync the stream if specified to flush runtime errors
if (debug_synchronous && (CubDebug(error = SyncStream(stream)))) break;
-
- // Update selector
- d_keys.selector ^= 1;
- d_values.selector ^= 1;
}
while (0);
@@ -1100,10 +1124,10 @@ struct DispatchRadixSort :
template
CUB_RUNTIME_FUNCTION __forceinline__
cudaError_t InvokePass(
- const KeyT *d_keys_in,
- KeyT *d_keys_out,
- const ValueT *d_values_in,
- ValueT *d_values_out,
+ KeyInputIteratorT d_keys_in,
+ KeyIteratorT d_keys_out,
+ ValueInputIteratorT d_values_in,
+ ValueIteratorT d_values_out,
OffsetT *d_spine,
int /*spine_length*/,
int &current_bit,
@@ -1281,9 +1305,9 @@ struct DispatchRadixSort :
// lookback
max_num_blocks * RADIX_DIGITS * sizeof(AtomicOffsetT),
// extra key buffer
- is_overwrite_okay || num_passes <= 1 ? 0 : num_items * sizeof(KeyT),
+ num_passes <= 1 ? 0 : num_items * sizeof(KeyT),
// extra value buffer
- is_overwrite_okay || num_passes <= 1 ? 0 : num_items * value_size,
+ num_passes <= 1 ? 0 : num_items * value_size,
// counters
num_parts * num_passes * sizeof(AtomicOffsetT),
};
@@ -1298,8 +1322,8 @@ struct DispatchRadixSort :
OffsetT* d_bins = (OffsetT*)allocations[0];
AtomicOffsetT* d_lookback = (AtomicOffsetT*)allocations[1];
- KeyT* d_keys_tmp2 = (KeyT*)allocations[2];
- ValueT* d_values_tmp2 = (ValueT*)allocations[3];
+ KeyT* d_keys_tmp = (KeyT*)allocations[2];
+ ValueT* d_values_tmp = (ValueT*)allocations[3];
AtomicOffsetT* d_ctrs = (AtomicOffsetT*)allocations[4];
do {
@@ -1319,11 +1343,11 @@ struct DispatchRadixSort :
const int HISTO_BLOCK_THREADS = ActivePolicyT::HistogramPolicy::BLOCK_THREADS;
int histo_blocks_per_sm = 1;
auto histogram_kernel = DeviceRadixSortHistogramKernel<
- MaxPolicyT, IS_DESCENDING, KeyT, OffsetT>;
+ MaxPolicyT, IS_DESCENDING, KeyInputIteratorT, OffsetT>;
if (CubDebug(error = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&histo_blocks_per_sm, histogram_kernel, HISTO_BLOCK_THREADS, 0))) break;
histogram_kernel<<>>
- (d_bins, d_keys.Current(), num_items, begin_bit, end_bit);
+ (d_bins, d_keys_in, num_items, begin_bit, end_bit);
if (CubDebug(error = cudaPeekAtLastError())) break;
// exclusive sums to determine starts
@@ -1331,15 +1355,13 @@ struct DispatchRadixSort :
DeviceRadixSortExclusiveSumKernel
<<>>(d_bins);
if (CubDebug(error = cudaPeekAtLastError())) break;
-
- // use the other buffer if no overwrite is allowed
- KeyT* d_keys_tmp = d_keys.Alternate();
- ValueT* d_values_tmp = d_values.Alternate();
- if (!is_overwrite_okay && num_passes % 2 == 0)
- {
- d_keys.d_buffers[1] = d_keys_tmp2;
- d_values.d_buffers[1] = d_values_tmp2;
- }
+
+ bool output_is_tmp = (num_passes % 2 == 0);
+ enum InputMode {
+ INPUT,
+ TMP_STORAGE,
+ OUTPUT
+ } input_mode = INPUT;
for (int current_bit = begin_bit, pass = 0; current_bit < end_bit;
current_bit += RADIX_BITS, ++pass)
@@ -1352,33 +1374,145 @@ struct DispatchRadixSort :
if (CubDebug(error = cudaMemsetAsync(
d_lookback, 0, num_blocks * RADIX_DIGITS * sizeof(AtomicOffsetT),
stream))) break;
- auto onesweep_kernel = DeviceRadixSortOnesweepKernel<
- MaxPolicyT, IS_DESCENDING, KeyT, ValueT, OffsetT>;
- onesweep_kernel<<>>
- (d_lookback, d_ctrs + part * num_passes + pass,
- part < num_parts - 1 ?
- d_bins + ((part + 1) * num_passes + pass) * RADIX_DIGITS : NULL,
- d_bins + (part * num_passes + pass) * RADIX_DIGITS,
- d_keys.Alternate(),
- d_keys.Current() + part * PART_SIZE,
- d_values.Alternate(),
- d_values.Current() + part * PART_SIZE,
- part_num_items, current_bit, num_bits);
+ if (output_is_tmp) {
+ using KeyOutIterT = KeyT *;
+ using ValueOutIterT = ValueT *;
+ KeyOutIterT d_keys_out_ = d_keys_tmp;
+ ValueOutIterT d_values_out_ = d_values_tmp;
+ switch (input_mode) {
+ case INPUT: {
+ using KeyInIterT = KeyInputIteratorT;
+ using ValueInIterT = ValueInputIteratorT;
+ auto onesweep_kernel = DeviceRadixSortOnesweepKernel<
+ MaxPolicyT, IS_DESCENDING, KeyInIterT, KeyOutIterT,
+ ValueInIterT, ValueOutIterT, OffsetT>;
+ KeyInIterT d_keys_in_ = d_keys_in;
+ ValueInIterT d_values_in_ = d_values_in;
+ onesweep_kernel<<>>
+ (d_lookback, d_ctrs + part * num_passes + pass,
+ part < num_parts - 1 ?
+ d_bins + ((part + 1) * num_passes + pass) * RADIX_DIGITS : NULL,
+ d_bins + (part * num_passes + pass) * RADIX_DIGITS,
+ d_keys_out_,
+ d_keys_in_ + part * PART_SIZE,
+ d_values_out_,
+ d_values_in_ + part * PART_SIZE,
+ part_num_items, current_bit, num_bits);
+ break;
+ }
+ case TMP_STORAGE: {
+ using KeyInIterT = const KeyT *;
+ using ValueInIterT = const ValueT *;
+ auto onesweep_kernel = DeviceRadixSortOnesweepKernel<
+ MaxPolicyT, IS_DESCENDING, KeyInIterT, KeyOutIterT,
+ ValueInIterT, ValueOutIterT, OffsetT>;
+ KeyInIterT d_keys_in_ = d_keys_tmp;
+ ValueInIterT d_values_in_ = d_values_tmp;
+ onesweep_kernel<<>>
+ (d_lookback, d_ctrs + part * num_passes + pass,
+ part < num_parts - 1 ?
+ d_bins + ((part + 1) * num_passes + pass) * RADIX_DIGITS : NULL,
+ d_bins + (part * num_passes + pass) * RADIX_DIGITS,
+ d_keys_out_,
+ d_keys_in_ + part * PART_SIZE,
+ d_values_out_,
+ d_values_in_ + part * PART_SIZE,
+ part_num_items, current_bit, num_bits);
+ break;
+ }
+ case OUTPUT: {
+ using KeyInIterT = KeyIteratorT;
+ using ValueInIterT = ValueIteratorT;
+ auto onesweep_kernel = DeviceRadixSortOnesweepKernel<
+ MaxPolicyT, IS_DESCENDING, KeyInIterT, KeyOutIterT,
+ ValueInIterT, ValueOutIterT, OffsetT>;
+ KeyInIterT d_keys_in_ = d_keys_out;
+ ValueInIterT d_values_in_ = d_values_out;
+ onesweep_kernel<<>>
+ (d_lookback, d_ctrs + part * num_passes + pass,
+ part < num_parts - 1 ?
+ d_bins + ((part + 1) * num_passes + pass) * RADIX_DIGITS : NULL,
+ d_bins + (part * num_passes + pass) * RADIX_DIGITS,
+ d_keys_out_,
+ d_keys_in_ + part * PART_SIZE,
+ d_values_out_,
+ d_values_in_ + part * PART_SIZE,
+ part_num_items, current_bit, num_bits);
+ break;
+ }
+ }
+ } else {
+ using KeyOutIterT = KeyIteratorT;
+ using ValueOutIterT = ValueIteratorT;
+ KeyOutIterT d_keys_out_ = d_keys_out;
+ ValueOutIterT d_values_out_ = d_values_out;
+ switch (input_mode) {
+ case INPUT: {
+ using KeyInIterT = KeyInputIteratorT;
+ using ValueInIterT = ValueInputIteratorT;
+ auto onesweep_kernel = DeviceRadixSortOnesweepKernel<
+ MaxPolicyT, IS_DESCENDING, KeyInIterT, KeyOutIterT,
+ ValueInIterT, ValueOutIterT, OffsetT>;
+ KeyInIterT d_keys_in_ = d_keys_in;
+ ValueInIterT d_values_in_ = d_values_in;
+ onesweep_kernel<<>>
+ (d_lookback, d_ctrs + part * num_passes + pass,
+ part < num_parts - 1 ?
+ d_bins + ((part + 1) * num_passes + pass) * RADIX_DIGITS : NULL,
+ d_bins + (part * num_passes + pass) * RADIX_DIGITS,
+ d_keys_out_,
+ d_keys_in_ + part * PART_SIZE,
+ d_values_out_,
+ d_values_in_ + part * PART_SIZE,
+ part_num_items, current_bit, num_bits);
+ break;
+ }
+ case TMP_STORAGE: {
+ using KeyInIterT = const KeyT *;
+ using ValueInIterT = const ValueT *;
+ auto onesweep_kernel = DeviceRadixSortOnesweepKernel<
+ MaxPolicyT, IS_DESCENDING, KeyInIterT, KeyOutIterT,
+ ValueInIterT, ValueOutIterT, OffsetT>;
+ KeyInIterT d_keys_in_ = d_keys_tmp;
+ ValueInIterT d_values_in_ = d_values_tmp;
+ onesweep_kernel<<>>
+ (d_lookback, d_ctrs + part * num_passes + pass,
+ part < num_parts - 1 ?
+ d_bins + ((part + 1) * num_passes + pass) * RADIX_DIGITS : NULL,
+ d_bins + (part * num_passes + pass) * RADIX_DIGITS,
+ d_keys_out_,
+ d_keys_in_ + part * PART_SIZE,
+ d_values_out_,
+ d_values_in_ + part * PART_SIZE,
+ part_num_items, current_bit, num_bits);
+ break;
+ }
+ case OUTPUT: {
+ using KeyInIterT = KeyIteratorT;
+ using ValueInIterT = ValueIteratorT;
+ auto onesweep_kernel = DeviceRadixSortOnesweepKernel<
+ MaxPolicyT, IS_DESCENDING, KeyInIterT, KeyOutIterT,
+ ValueInIterT, ValueOutIterT, OffsetT>;
+ KeyInIterT d_keys_in_ = d_keys_out;
+ ValueInIterT d_values_in_ = d_values_out;
+ onesweep_kernel<<>>
+ (d_lookback, d_ctrs + part * num_passes + pass,
+ part < num_parts - 1 ?
+ d_bins + ((part + 1) * num_passes + pass) * RADIX_DIGITS : NULL,
+ d_bins + (part * num_passes + pass) * RADIX_DIGITS,
+ d_keys_out_,
+ d_keys_in_ + part * PART_SIZE,
+ d_values_out_,
+ d_values_in_ + part * PART_SIZE,
+ part_num_items, current_bit, num_bits);
+ break;
+ }
+ }
+ }
if (CubDebug(error = cudaPeekAtLastError())) break;
}
-
- // use the temporary buffers if no overwrite is allowed
- if (!is_overwrite_okay && pass == 0)
- {
- d_keys = num_passes % 2 == 0 ?
- DoubleBuffer(d_keys_tmp, d_keys_tmp2) :
- DoubleBuffer(d_keys_tmp2, d_keys_tmp);
- d_values = num_passes % 2 == 0 ?
- DoubleBuffer(d_values_tmp, d_values_tmp2) :
- DoubleBuffer(d_values_tmp2, d_values_tmp);
- }
- d_keys.selector ^= 1;
- d_values.selector ^= 1;
+ input_mode = output_is_tmp ? TMP_STORAGE : OUTPUT;
+ output_is_tmp = !output_is_tmp;
}
} while (0);
@@ -1460,8 +1594,8 @@ struct DispatchRadixSort :
size_t allocation_sizes[3] =
{
spine_length * sizeof(OffsetT), // bytes needed for privatized block digit histograms
- (is_overwrite_okay) ? 0 : num_items * sizeof(KeyT), // bytes needed for 3rd keys buffer
- (is_overwrite_okay || (KEYS_ONLY)) ? 0 : num_items * sizeof(ValueT), // bytes needed for 3rd values buffer
+ num_items * sizeof(KeyT), // bytes needed for 3rd keys buffer
+ (KEYS_ONLY) ? 0 : num_items * sizeof(ValueT), // bytes needed for 3rd values buffer
};
// Alias the temporary allocations from the single storage blob (or compute the necessary size of the blob)
@@ -1482,12 +1616,12 @@ struct DispatchRadixSort :
OffsetT *d_spine = static_cast(allocations[0]);
DoubleBuffer d_keys_remaining_passes(
- (is_overwrite_okay || is_num_passes_odd) ? d_keys.Alternate() : static_cast(allocations[1]),
- (is_overwrite_okay) ? d_keys.Current() : (is_num_passes_odd) ? static_cast(allocations[1]) : d_keys.Alternate());
+ is_num_passes_odd ? d_keys.Alternate() : static_cast(allocations[1]),
+ is_num_passes_odd ? static_cast(allocations[1]) : d_keys.Alternate());
DoubleBuffer d_values_remaining_passes(
- (is_overwrite_okay || is_num_passes_odd) ? d_values.Alternate() : static_cast(allocations[2]),
- (is_overwrite_okay) ? d_values.Current() : (is_num_passes_odd) ? static_cast(allocations[2]) : d_values.Alternate());
+ is_num_passes_odd ? d_values.Alternate() : static_cast(allocations[2]),
+ is_num_passes_odd ? static_cast(allocations[2]) : d_values.Alternate());
// Run first pass, consuming from the input's current buffers
int current_bit = begin_bit;
@@ -1512,9 +1646,7 @@ struct DispatchRadixSort :
}
// Update selector
- if (!is_overwrite_okay) {
- num_passes = 1; // Sorted data always ends up in the other vector
- }
+ num_passes = 1; // Sorted data always ends up in the other vector
d_keys.selector = (d_keys.selector + num_passes) & 1;
d_values.selector = (d_values.selector + num_passes) & 1;
@@ -1538,11 +1670,11 @@ struct DispatchRadixSort :
// Invoke upsweep-downsweep
typedef typename DispatchRadixSort::MaxPolicy MaxPolicyT;
return InvokePasses(
- DeviceRadixSortUpsweepKernel< MaxPolicyT, false, IS_DESCENDING, KeyT, OffsetT>,
- DeviceRadixSortUpsweepKernel< MaxPolicyT, true, IS_DESCENDING, KeyT, OffsetT>,
+ DeviceRadixSortUpsweepKernel< MaxPolicyT, false, IS_DESCENDING, KeyInputIteratorT, OffsetT>,
+ DeviceRadixSortUpsweepKernel< MaxPolicyT, true, IS_DESCENDING, KeyInputIteratorT, OffsetT>,
RadixSortScanBinsKernel< MaxPolicyT, OffsetT>,
- DeviceRadixSortDownsweepKernel< MaxPolicyT, false, IS_DESCENDING, KeyT, ValueT, OffsetT>,
- DeviceRadixSortDownsweepKernel< MaxPolicyT, true, IS_DESCENDING, KeyT, ValueT, OffsetT>);
+ DeviceRadixSortDownsweepKernel< MaxPolicyT, false, IS_DESCENDING, KeyInputIteratorT, KeyIteratorT, ValueInputIteratorT, ValueIteratorT, OffsetT>,
+ DeviceRadixSortDownsweepKernel< MaxPolicyT, true, IS_DESCENDING, KeyInputIteratorT, KeyIteratorT, ValueInputIteratorT, ValueIteratorT, OffsetT>);
}
template
@@ -1566,7 +1698,8 @@ struct DispatchRadixSort :
{
// Small, single tile size
return InvokeSingleTile(
- DeviceRadixSortSingleTileKernel);
+ DeviceRadixSortSingleTileKernel);
}
else
{
@@ -1587,12 +1720,13 @@ struct DispatchRadixSort :
static cudaError_t Dispatch(
void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done.
size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation
- DoubleBuffer &d_keys, ///< [in,out] Double-buffer whose current buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys
- DoubleBuffer &d_values, ///< [in,out] Double-buffer whose current buffer contains the unsorted input values and, upon return, is updated to point to the sorted output values
+ KeyInputIteratorT d_keys_in, ///< [in] Iterator for the unsorted input keys
+ KeyIteratorT d_keys_out, ///< [out] Iterator for the sorted output keys
+ ValueInputIteratorT d_values_in, ///< [in] Iterator for the unsorted input values
+ ValueIteratorT d_values_out, ///< [out] Iterator for the sorted output values
OffsetT num_items, ///< [in] Number of items to sort
int begin_bit, ///< [in] The beginning (least-significant) bit index needed for key comparison
int end_bit, ///< [in] The past-the-end (most-significant) bit index needed for key comparison
- bool is_overwrite_okay, ///< [in] Whether is okay to overwrite source buffers
cudaStream_t stream, ///< [in] CUDA stream to launch kernels within. Default is stream0.
bool debug_synchronous) ///< [in] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false.
{
@@ -1607,8 +1741,8 @@ struct DispatchRadixSort :
// Create dispatch functor
DispatchRadixSort dispatch(
d_temp_storage, temp_storage_bytes,
- d_keys, d_values,
- num_items, begin_bit, end_bit, is_overwrite_okay,
+ d_keys_in, d_keys_out, d_values_in, d_values_out,
+ num_items, begin_bit, end_bit,
stream, debug_synchronous, ptx_version);
// Dispatch to chained policy
@@ -1632,15 +1766,22 @@ struct DispatchRadixSort :
*/
template <
bool IS_DESCENDING, ///< Whether or not the sorted-order is high-to-low
- typename KeyT, ///< Key type
- typename ValueT, ///< Value type
- typename BeginOffsetIteratorT, ///< Random-access input iterator type for reading segment beginning offsets \iterator
+ typename KeyInputIteratorT, ///< Input key iterator type
+ typename KeyIteratorT, ///< Key iterator type
+ typename ValueInputIteratorT, ///< Input value iterator type
+ typename ValueIteratorT, ///< Value iterator type
+ typename BeginOffsetIteratorT, ///< Random-access input iterator type for reading segment beginning offsets \iterator
typename EndOffsetIteratorT, ///< Random-access input iterator type for reading segment ending offsets \iterator
typename OffsetT, ///< Signed integer type for global offsets
- typename SelectedPolicy = DeviceRadixSortPolicy >
+ typename SelectedPolicy = DeviceRadixSortPolicy<
+ /*KeyT=*/typename std::iterator_traits::value_type,
+ /*ValueT=*/typename std::iterator_traits::value_type,
+ OffsetT> >
struct DispatchSegmentedRadixSort :
SelectedPolicy
{
+ using KeyT = typename std::iterator_traits::value_type;
+ using ValueT = typename std::iterator_traits::value_type;
//------------------------------------------------------------------------------
// Constants
//------------------------------------------------------------------------------
@@ -1658,8 +1799,10 @@ struct DispatchSegmentedRadixSort :
void *d_temp_storage; ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done.
size_t &temp_storage_bytes; ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation
- DoubleBuffer &d_keys; ///< [in,out] Double-buffer whose current buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys
- DoubleBuffer &d_values; ///< [in,out] Double-buffer whose current buffer contains the unsorted input values and, upon return, is updated to point to the sorted output values
+ KeyInputIteratorT d_keys_in; ///< [in] Iterator for the unsorted input keys
+ KeyIteratorT d_keys_out; ///< [out] Iterator for the sorted output keys
+ ValueInputIteratorT d_values_in; ///< [in] Iterator for the unsorted input values
+ ValueIteratorT d_values_out; ///< [out] Iterator for the sorted output values
OffsetT num_items; ///< [in] Number of items to sort
OffsetT num_segments; ///< [in] The number of segments that comprise the sorting data
BeginOffsetIteratorT d_begin_offsets; ///< [in] Random-access input iterator to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_*
@@ -1669,7 +1812,6 @@ struct DispatchSegmentedRadixSort :
cudaStream_t stream; ///< [in] CUDA stream to launch kernels within. Default is stream0.
bool debug_synchronous; ///< [in] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false.
int ptx_version; ///< [in] PTX version
- bool is_overwrite_okay; ///< [in] Whether is okay to overwrite source buffers
//------------------------------------------------------------------------------
@@ -1681,30 +1823,32 @@ struct DispatchSegmentedRadixSort :
DispatchSegmentedRadixSort(
void* d_temp_storage,
size_t &temp_storage_bytes,
- DoubleBuffer &d_keys,
- DoubleBuffer &d_values,
+ KeyInputIteratorT d_keys_in,
+ KeyIteratorT d_keys_out,
+ ValueInputIteratorT d_values_in,
+ ValueIteratorT d_values_out,
OffsetT num_items,
OffsetT num_segments,
BeginOffsetIteratorT d_begin_offsets,
EndOffsetIteratorT d_end_offsets,
int begin_bit,
int end_bit,
- bool is_overwrite_okay,
cudaStream_t stream,
bool debug_synchronous,
int ptx_version)
:
d_temp_storage(d_temp_storage),
temp_storage_bytes(temp_storage_bytes),
- d_keys(d_keys),
- d_values(d_values),
+ d_keys_in(d_keys_in),
+ d_keys_out(d_keys_out),
+ d_values_in(d_values_in),
+ d_values_out(d_values_out),
num_items(num_items),
num_segments(num_segments),
d_begin_offsets(d_begin_offsets),
d_end_offsets(d_end_offsets),
begin_bit(begin_bit),
end_bit(end_bit),
- is_overwrite_okay(is_overwrite_okay),
stream(stream),
debug_synchronous(debug_synchronous),
ptx_version(ptx_version)
@@ -1719,12 +1863,12 @@ struct DispatchSegmentedRadixSort :
template
CUB_RUNTIME_FUNCTION __forceinline__
cudaError_t InvokePass(
- const KeyT *d_keys_in,
- KeyT *d_keys_out,
- const ValueT *d_values_in,
- ValueT *d_values_out,
- int &current_bit,
- PassConfigT &pass_config)
+ KeyInputIteratorT d_keys_in,
+ KeyIteratorT d_keys_out,
+ ValueInputIteratorT d_values_in,
+ ValueIteratorT d_values_out,
+ int &current_bit,
+ PassConfigT &pass_config)
{
cudaError error = cudaSuccess;
do
@@ -1822,8 +1966,8 @@ struct DispatchSegmentedRadixSort :
void* allocations[2] = {};
size_t allocation_sizes[2] =
{
- (is_overwrite_okay) ? 0 : num_items * sizeof(KeyT), // bytes needed for 3rd keys buffer
- (is_overwrite_okay || (KEYS_ONLY)) ? 0 : num_items * sizeof(ValueT), // bytes needed for 3rd values buffer
+ num_items * sizeof(KeyT), // bytes needed for 3rd keys buffer
+ (KEYS_ONLY) ? 0 : num_items * sizeof(ValueT), // bytes needed for 3rd values buffer
};
// Alias the temporary allocations from the single storage blob (or compute the necessary size of the blob)
@@ -1847,12 +1991,12 @@ struct DispatchSegmentedRadixSort :
int alt_end_bit = CUB_MIN(end_bit, begin_bit + (max_alt_passes * alt_radix_bits));
DoubleBuffer<KeyT> d_keys_remaining_passes(
- (is_overwrite_okay || is_num_passes_odd) ? d_keys.Alternate() : static_cast<KeyT*>(allocations[0]),
- (is_overwrite_okay) ? d_keys.Current() : (is_num_passes_odd) ? static_cast<KeyT*>(allocations[0]) : d_keys.Alternate());
+ is_num_passes_odd ? d_keys.Alternate() : static_cast<KeyT*>(allocations[0]),
+ is_num_passes_odd ? static_cast<KeyT*>(allocations[0]) : d_keys.Alternate());
DoubleBuffer<ValueT> d_values_remaining_passes(
- (is_overwrite_okay || is_num_passes_odd) ? d_values.Alternate() : static_cast<ValueT*>(allocations[1]),
- (is_overwrite_okay) ? d_values.Current() : (is_num_passes_odd) ? static_cast<ValueT*>(allocations[1]) : d_values.Alternate());
+ is_num_passes_odd ? d_values.Alternate() : static_cast<ValueT*>(allocations[1]),
+ is_num_passes_odd ? static_cast<ValueT*>(allocations[1]) : d_values.Alternate());
// Run first pass, consuming from the input's current buffers
int current_bit = begin_bit;
@@ -1878,9 +2022,7 @@ struct DispatchSegmentedRadixSort :
}
// Update selector
- if (!is_overwrite_okay) {
- num_passes = 1; // Sorted data always ends up in the other vector
- }
+ num_passes = 1; // Sorted data always ends up in the other vector
d_keys.selector = (d_keys.selector + num_passes) & 1;
d_values.selector = (d_values.selector + num_passes) & 1;
@@ -1921,15 +2063,16 @@ struct DispatchSegmentedRadixSort :
static cudaError_t Dispatch(
void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done.
size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation
- DoubleBuffer<KeyT> &d_keys, ///< [in,out] Double-buffer whose current buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys
- DoubleBuffer<ValueT> &d_values, ///< [in,out] Double-buffer whose current buffer contains the unsorted input values and, upon return, is updated to point to the sorted output values
+ KeyInputIteratorT d_keys_in, ///< [in] Iterator for the unsorted input keys
+ KeyIteratorT d_keys_out, ///< [out] Iterator for the sorted output keys
+ ValueInputIteratorT d_values_in, ///< [in] Iterator for the unsorted input values
+ ValueIteratorT d_values_out, ///< [out] Iterator for the sorted output values
int num_items, ///< [in] Number of items to sort
int num_segments, ///< [in] The number of segments that comprise the sorting data
BeginOffsetIteratorT d_begin_offsets, ///< [in] Random-access input iterator to the sequence of beginning offsets of length \p num_segments, such that d_begin_offsets[i] is the first element of the ith data segment in d_keys_* and d_values_*
EndOffsetIteratorT d_end_offsets, ///< [in] Random-access input iterator to the sequence of ending offsets of length \p num_segments, such that d_end_offsets[i]-1 is the last element of the ith data segment in d_keys_* and d_values_*. If d_end_offsets[i]-1 <= d_begin_offsets[i], the ith is considered empty.
int begin_bit, ///< [in] The beginning (least-significant) bit index needed for key comparison
int end_bit, ///< [in] The past-the-end (most-significant) bit index needed for key comparison
- bool is_overwrite_okay, ///< [in] Whether is okay to overwrite source buffers
cudaStream_t stream, ///< [in] CUDA stream to launch kernels within. Default is stream0.
bool debug_synchronous) ///< [in] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false.
{
@@ -1946,7 +2089,7 @@ struct DispatchSegmentedRadixSort :
d_temp_storage, temp_storage_bytes,
d_keys, d_values,
num_items, num_segments, d_begin_offsets, d_end_offsets,
- begin_bit, end_bit, is_overwrite_okay,
+ begin_bit, end_bit,
stream, debug_synchronous, ptx_version);
// Dispatch to chained policy