2222from ._tensor_sorting_impl import (
2323 _argsort_ascending ,
2424 _argsort_descending ,
25+ _radix_argsort_ascending ,
26+ _radix_argsort_descending ,
27+ _radix_sort_ascending ,
28+ _radix_sort_descending ,
29+ _radix_sort_dtype_supported ,
2530 _sort_ascending ,
2631 _sort_descending ,
2732)
2833
2934__all__ = ["sort" , "argsort" ]
3035
3136
32- def sort (x , / , * , axis = - 1 , descending = False , stable = True ):
37+ def _get_mergesort_impl_fn (descending ):
38+ return _sort_descending if descending else _sort_ascending
39+
40+
41+ def _get_radixsort_impl_fn (descending ):
42+ return _radix_sort_descending if descending else _radix_sort_ascending
43+
44+
45+ def sort (x , / , * , axis = - 1 , descending = False , stable = True , kind = None ):
3346 """sort(x, axis=-1, descending=False, stable=True)
3447
3548 Returns a sorted copy of an input array `x`.
@@ -49,7 +62,10 @@ def sort(x, /, *, axis=-1, descending=False, stable=True):
4962 relative order of `x` values which compare as equal. If `False`,
5063 the returned array may or may not maintain the relative order of
5164 `x` values which compare as equal. Default: `True`.
52-
65+ kind (Optional[Literal["stable", "mergesort", "radixsort"]]):
66+ Sorting algorithm. The default is `"stable"`, which uses parallel
67+ merge-sort or parallel radix-sort algorithms depending on the
68+ array data type.
5369 Returns:
5470 usm_ndarray:
5571 a sorted array. The returned array has the same data type and
@@ -74,10 +90,33 @@ def sort(x, /, *, axis=-1, descending=False, stable=True):
7490 axis ,
7591 ]
7692 arr = dpt .permute_dims (x , perm )
93+ if kind is None :
94+ kind = "stable"
95+ if not isinstance (kind , str ) or kind not in [
96+ "stable" ,
97+ "radixsort" ,
98+ "mergesort" ,
99+ ]:
100+ raise ValueError (
101+ "Unsupported kind value. Expected 'stable', 'mergesort', "
102+ f"or 'radixsort', but got '{ kind } '"
103+ )
104+ if kind == "mergesort" :
105+ impl_fn = _get_mergesort_impl_fn (descending )
106+ elif kind == "radixsort" :
107+ if _radix_sort_dtype_supported (x .dtype .num ):
108+ impl_fn = _get_radixsort_impl_fn (descending )
109+ else :
110+ raise ValueError (f"Radix sort is not supported for { x .dtype } " )
111+ else :
112+ dt = x .dtype
113+ if dt in [dpt .bool , dpt .uint8 , dpt .int8 , dpt .int16 , dpt .uint16 ]:
114+ impl_fn = _get_radixsort_impl_fn (descending )
115+ else :
116+ impl_fn = _get_mergesort_impl_fn (descending )
77117 exec_q = x .sycl_queue
78118 _manager = du .SequentialOrderManager [exec_q ]
79119 dep_evs = _manager .submitted_events
80- impl_fn = _sort_descending if descending else _sort_ascending
81120 if arr .flags .c_contiguous :
82121 res = dpt .empty_like (arr , order = "C" )
83122 ht_ev , impl_ev = impl_fn (
@@ -109,7 +148,15 @@ def sort(x, /, *, axis=-1, descending=False, stable=True):
109148 return res
110149
111150
112- def argsort (x , axis = - 1 , descending = False , stable = True ):
151+ def _get_mergeargsort_impl_fn (descending ):
152+ return _argsort_descending if descending else _argsort_ascending
153+
154+
155+ def _get_radixargsort_impl_fn (descending ):
156+ return _radix_argsort_descending if descending else _radix_argsort_ascending
157+
158+
159+ def argsort (x , axis = - 1 , descending = False , stable = True , kind = None ):
113160 """argsort(x, axis=-1, descending=False, stable=True)
114161
115162 Returns the indices that sort an array `x` along a specified axis.
@@ -129,6 +176,10 @@ def argsort(x, axis=-1, descending=False, stable=True):
129176 relative order of `x` values which compare as equal. If `False`,
130177 the returned array may or may not maintain the relative order of
131178 `x` values which compare as equal. Default: `True`.
179+ kind (Optional[Literal["stable", "mergesort", "radixsort"]]):
180+ Sorting algorithm. The default is `"stable"`, which uses parallel
181+ merge-sort or parallel radix-sort algorithms depending on the
182+ array data type.
132183
133184 Returns:
134185 usm_ndarray:
@@ -157,10 +208,33 @@ def argsort(x, axis=-1, descending=False, stable=True):
157208 axis ,
158209 ]
159210 arr = dpt .permute_dims (x , perm )
211+ if kind is None :
212+ kind = "stable"
213+ if not isinstance (kind , str ) or kind not in [
214+ "stable" ,
215+ "radixsort" ,
216+ "mergesort" ,
217+ ]:
218+ raise ValueError (
219+ "Unsupported kind value. Expected 'stable', 'mergesort', "
220+ f"or 'radixsort', but got '{ kind } '"
221+ )
222+ if kind == "mergesort" :
223+ impl_fn = _get_mergeargsort_impl_fn (descending )
224+ elif kind == "radixsort" :
225+ if _radix_sort_dtype_supported (x .dtype .num ):
226+ impl_fn = _get_radixargsort_impl_fn (descending )
227+ else :
228+ raise ValueError (f"Radix sort is not supported for { x .dtype } " )
229+ else :
230+ dt = x .dtype
231+ if dt in [dpt .bool , dpt .uint8 , dpt .int8 , dpt .int16 , dpt .uint16 ]:
232+ impl_fn = _get_radixargsort_impl_fn (descending )
233+ else :
234+ impl_fn = _get_mergeargsort_impl_fn (descending )
160235 exec_q = x .sycl_queue
161236 _manager = du .SequentialOrderManager [exec_q ]
162237 dep_evs = _manager .submitted_events
163- impl_fn = _argsort_descending if descending else _argsort_ascending
164238 index_dt = ti .default_device_index_type (exec_q )
165239 if arr .flags .c_contiguous :
166240 res = dpt .empty_like (arr , dtype = index_dt , order = "C" )
0 commit comments