File tree Expand file tree Collapse file tree 1 file changed +16
-0
lines changed
dpctl/tensor/libtensor/source/sorting Expand file tree Collapse file tree 1 file changed +16
-0
lines changed Original file line number Diff line number Diff line change 2323// ===--------------------------------------------------------------------===//
2424
2525#include < cstdint>
26+ #include < exception>
2627#include < utility>
2728#include < vector>
2829
@@ -105,6 +106,19 @@ void init_radix_sort_dispatch_vectors(void)
105106 dtv2.populate_dispatch_vector (descending_radix_sort_contig_dispatch_vector);
106107}
107108
109+ bool py_radix_sort_defined (int typenum)
110+ {
111+ const auto &array_types = td_ns::usm_ndarray_types ();
112+
113+ try {
114+ int type_id = array_types.typenum_to_lookup_id (typenum);
115+ return (nullptr !=
116+ ascending_radix_sort_contig_dispatch_vector[type_id]);
117+ } catch (const std::exception &e) {
118+ return false ;
119+ }
120+ }
121+
108122void init_radix_sort_functions (py::module_ m)
109123{
110124 dpctl::tensor::py_internal::init_radix_sort_dispatch_vectors ();
@@ -139,6 +153,8 @@ void init_radix_sort_functions(py::module_ m)
139153 py::arg (" trailing_dims_to_sort" ), py::arg (" dst" ),
140154 py::arg (" sycl_queue" ), py::arg (" depends" ) = py::list ());
141155
156+ m.def (" _radix_sort_dtype_supported" , py_radix_sort_defined);
157+
142158 return ;
143159}
144160
You can’t perform that action at this time.
0 commit comments