3434#include " kernels/dpctl_tensor_types.hpp"
3535#include " kernels/sorting/search_sorted_detail.hpp"
3636#include " utils/offset_utils.hpp"
37+ #include " utils/rich_comparisons.hpp"
3738
3839namespace dpctl
3940{
@@ -47,8 +48,7 @@ using dpctl::tensor::ssize_t;
4748template <typename T,
4849 typename HayIndexerT,
4950 typename NeedlesIndexerT,
50- typename OutIndexerT,
51- typename Compare>
51+ typename OutIndexerT>
5252struct IsinFunctor
5353{
5454private:
@@ -78,6 +78,8 @@ struct IsinFunctor
7878
7979 void operator ()(sycl::id<1 > id) const
8080 {
81+ using Compare =
82+ typename dpctl::tensor::rich_comparisons::AscendingSorter<T>::type;
8183 static constexpr Compare comp{};
8284
8385 const std::size_t i = id[0 ];
@@ -115,7 +117,7 @@ typedef sycl::event (*isin_contig_impl_fp_ptr_t)(
115117
116118template <typename T> class isin_contig_impl_krn ;
117119
118- template <typename T, typename Compare >
120+ template <typename T>
119121sycl::event isin_contig_impl (sycl::queue &exec_q,
120122 const bool invert,
121123 const std::size_t hay_nelems,
@@ -148,9 +150,9 @@ sycl::event isin_contig_impl(sycl::queue &exec_q,
148150 static constexpr TrivialIndexerT out_indexer{};
149151
150152 const auto fnctr =
151- IsinFunctor<T, TrivialIndexerT, TrivialIndexerT, TrivialIndexerT,
152- Compare>( invert, hay_tp, needles_tp, out_tp, hay_nelems,
153- hay_indexer, needles_indexer, out_indexer);
153+ IsinFunctor<T, TrivialIndexerT, TrivialIndexerT, TrivialIndexerT>(
154+ invert, hay_tp, needles_tp, out_tp, hay_nelems, hay_indexer ,
155+ needles_indexer, out_indexer);
154156
155157 cgh.parallel_for <KernelName>(gRange , fnctr);
156158 });
@@ -176,7 +178,7 @@ typedef sycl::event (*isin_strided_impl_fp_ptr_t)(
176178
177179template <typename T> class isin_strided_impl_krn ;
178180
179- template <typename T, typename Compare >
181+ template <typename T>
180182sycl::event isin_strided_impl (
181183 sycl::queue &exec_q,
182184 const bool invert,
@@ -224,7 +226,7 @@ sycl::event isin_strided_impl(
224226 out_strides);
225227
226228 const auto fnctr =
227- IsinFunctor<T, HayIndexerT, NeedlesIndexerT, OutIndexerT, Compare >(
229+ IsinFunctor<T, HayIndexerT, NeedlesIndexerT, OutIndexerT>(
228230 invert, hay_tp, needles_tp, out_tp, hay_nelems, hay_indexer,
229231 needles_indexer, out_indexer);
230232 using KernelName = class isin_strided_impl_krn <T>;
0 commit comments