Skip to content

Implement tensor.isin #2098

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 21 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
30f254d
Implement tensor.isin
ndgrigorian Jun 6, 2025
fe22626
factor common utilities for scalar arguments to a new file
ndgrigorian Jun 8, 2025
6a51a7a
Make constexpr variables in `isin` static
ndgrigorian Jun 11, 2025
d6392a4
Update implementation of `isin`
ndgrigorian Jun 11, 2025
3fc487a
Update per review comments
ndgrigorian Jun 16, 2025
5de6bdd
Allow x to be a scalar in isin and remove assume_unique
ndgrigorian Jun 16, 2025
2d2ec9d
Make comparator static constexpr
ndgrigorian Jun 16, 2025
ca99f47
add basic tests for isin functionality
ndgrigorian Jun 17, 2025
28eb70f
add a fast-path for size == 1 arrays in sort
ndgrigorian Jun 17, 2025
401bc76
Remove unused import of dpctl in _set_functions.py
ndgrigorian Jun 17, 2025
68b488e
Add type hints to isin
ndgrigorian Jun 17, 2025
c8e511c
Add fast-path for size == 1 arrays to argsort
ndgrigorian Jun 17, 2025
32fe34e
Add usm_type to test_buf in isin
ndgrigorian Jun 18, 2025
c10798a
Address review comments for isin tests
ndgrigorian Jun 18, 2025
2689bc9
Add test covering nans and +/- 0 in isin
ndgrigorian Jun 18, 2025
f7f62c8
Add test for isin with Python scalar args
ndgrigorian Jun 23, 2025
0bb7932
Add test for combinations of dtypes as inputs to isin
ndgrigorian Jun 23, 2025
2137176
Add compute follows data test for isin
ndgrigorian Jun 23, 2025
7e5c4e0
Add isin to rendered docs
ndgrigorian Jun 23, 2025
bc4ced5
Test that isin output is C-contiguous when input is strided
ndgrigorian Jun 23, 2025
1fc9a25
improve formatting of isin docstring
ndgrigorian Jun 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Set Functions
.. autosummary::
:toctree: generated

isin
unique_all
unique_counts
unique_inverse
Expand Down
1 change: 1 addition & 0 deletions dpctl/tensor/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ set(_reduction_sources
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/sum.cpp
)
set(_sorting_sources
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/isin.cpp
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't seem to be related to the sorting routines.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It shares common utilities with searchsorted (namely, those from rich_comparisons.hpp), which is why it lives there.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the code from rich_comparisons gets factored out, I can go ahead and move it elsewhere; I guess to _tensor_impl for now.

${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/merge_sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/merge_argsort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/radix_sort.cpp
Expand Down
2 changes: 2 additions & 0 deletions dpctl/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@
)
from ._searchsorted import searchsorted
from ._set_functions import (
isin,
unique_all,
unique_counts,
unique_inverse,
Expand Down Expand Up @@ -394,4 +395,5 @@
"top_k",
"dldevice_to_sycl_device",
"sycl_device_to_dldevice",
"isin",
]
10 changes: 5 additions & 5 deletions dpctl/tensor/_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,16 @@
_empty_like_pair_orderK,
_empty_like_triple_orderK,
)
from dpctl.tensor._elementwise_common import (
from dpctl.tensor._manipulation_functions import _broadcast_shape_impl
from dpctl.tensor._type_utils import _can_cast
from dpctl.utils import ExecutionPlacementError, SequentialOrderManager

from ._scalar_utils import (
_get_dtype,
_get_queue_usm_type,
_get_shape,
_validate_dtype,
)
from dpctl.tensor._manipulation_functions import _broadcast_shape_impl
from dpctl.tensor._type_utils import _can_cast
from dpctl.utils import ExecutionPlacementError, SequentialOrderManager

from ._type_utils import (
_resolve_one_strong_one_weak_types,
_resolve_one_strong_two_weak_types,
Expand Down
89 changes: 6 additions & 83 deletions dpctl/tensor/_elementwise_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numbers

import numpy as np

import dpctl
import dpctl.memory as dpm
import dpctl.tensor as dpt
import dpctl.tensor._tensor_impl as ti
from dpctl.tensor._manipulation_functions import _broadcast_shape_impl
from dpctl.tensor._usmarray import _is_object_with_buffer_protocol as _is_buffer
from dpctl.utils import ExecutionPlacementError, SequentialOrderManager

from ._copy_utils import _empty_like_orderK, _empty_like_pair_orderK
from ._scalar_utils import (
_get_dtype,
_get_queue_usm_type,
_get_shape,
_validate_dtype,
)
from ._type_utils import (
WeakBooleanType,
WeakComplexType,
WeakFloatingType,
WeakIntegralType,
_acceptance_fn_default_binary,
_acceptance_fn_default_unary,
_all_data_types,
_find_buf_dtype,
_find_buf_dtype2,
_find_buf_dtype_in_place_op,
_resolve_weak_types,
_to_device_supported_dtype,
)


Expand Down Expand Up @@ -289,78 +284,6 @@ def __call__(self, x, /, *, out=None, order="K"):
return out


def _get_queue_usm_type(o):
    """Return the ``(sycl_queue, usm_type)`` pair associated with `o`.

    For a ``usm_ndarray`` the pair is read directly from the array. For any
    other object exposing ``__sycl_usm_array_interface__`` the pair is taken
    from the USM memory object built by ``dpm.as_usm_memory``. In every
    other case, including when that conversion raises, ``(None, None)`` is
    returned.
    """
    not_usm = (None, None)
    if isinstance(o, dpt.usm_ndarray):
        return o.sycl_queue, o.usm_type
    if not hasattr(o, "__sycl_usm_array_interface__"):
        return not_usm
    try:
        usm_mem = dpm.as_usm_memory(o)
        queue_and_type = (usm_mem.sycl_queue, usm_mem.get_usm_type())
    except Exception:
        # Best effort: an object whose interface cannot be interpreted is
        # treated the same as an object without USM allocation.
        return not_usm
    return queue_and_type


def _get_dtype(o, dev):
    """Return the data type (or weak-type wrapper) describing `o`.

    Resolution order:

    1. ``usm_ndarray``: its own ``dtype``.
    2. Objects with ``__sycl_usm_array_interface__``: the ``dtype`` of
       ``dpt.asarray(o)``.
    3. Objects supporting the buffer protocol: the dtype NumPy infers for
       them, passed through ``_to_device_supported_dtype`` for `dev`.
    4. Objects with a ``dtype`` attribute: that dtype, passed through
       ``_to_device_supported_dtype`` for `dev`.
    5. Python ``bool``/``int``/``float``/``complex`` scalars: the matching
       ``Weak*Type`` wrapper constructed from the value.
    6. Anything else: ``np.object_``.
    """
    if isinstance(o, dpt.usm_ndarray):
        return o.dtype
    if hasattr(o, "__sycl_usm_array_interface__"):
        return dpt.asarray(o).dtype
    if _is_buffer(o):
        return _to_device_supported_dtype(np.array(o).dtype, dev)
    if hasattr(o, "dtype"):
        return _to_device_supported_dtype(o.dtype, dev)
    # Order matters: bool is a subclass of int, so it must be tested first.
    weak_wrappers = (
        (bool, WeakBooleanType),
        (int, WeakIntegralType),
        (float, WeakFloatingType),
        (complex, WeakComplexType),
    )
    for scalar_type, weak_type in weak_wrappers:
        if isinstance(o, scalar_type):
            return weak_type(o)
    return np.object_


def _validate_dtype(dt) -> bool:
    """Tell whether `dt` is an accepted data type descriptor.

    Returns ``True`` when `dt` is an instance of one of the ``Weak*Type``
    wrappers, or when it is a ``dpt.dtype`` instance equal to one of the
    boolean, integer, floating-point or complex types listed below.
    Returns ``False`` otherwise.
    """
    weak_types = (
        WeakBooleanType,
        WeakIntegralType,
        WeakFloatingType,
        WeakComplexType,
    )
    if isinstance(dt, weak_types):
        return True
    if not isinstance(dt, dpt.dtype):
        return False
    accepted = (
        dpt.bool,
        dpt.int8,
        dpt.uint8,
        dpt.int16,
        dpt.uint16,
        dpt.int32,
        dpt.uint32,
        dpt.int64,
        dpt.uint64,
        dpt.float16,
        dpt.float32,
        dpt.float64,
        dpt.complex64,
        dpt.complex128,
    )
    return dt in accepted


def _get_shape(o):
    """Return the shape of `o` as reported by the object itself.

    ``usm_ndarray`` instances report their ``shape``; objects supporting the
    buffer protocol report the shape of a ``memoryview`` over them; numbers
    are treated as zero-dimensional and yield ``()``. Any other object
    yields its ``shape`` attribute, or ``()`` when it has none.
    """
    if isinstance(o, dpt.usm_ndarray):
        return o.shape
    if _is_buffer(o):
        buffer_view = memoryview(o)
        return buffer_view.shape
    scalar_shape = ()
    if isinstance(o, numbers.Number):
        return scalar_shape
    return getattr(o, "shape", scalar_shape)


class BinaryElementwiseFunc:
"""
Class that implements binary element-wise functions.
Expand Down
111 changes: 111 additions & 0 deletions dpctl/tensor/_scalar_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Data Parallel Control (dpctl)
#
# Copyright 2020-2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numbers

import numpy as np

import dpctl.memory as dpm
import dpctl.tensor as dpt
from dpctl.tensor._usmarray import _is_object_with_buffer_protocol as _is_buffer

from ._type_utils import (
WeakBooleanType,
WeakComplexType,
WeakFloatingType,
WeakIntegralType,
_to_device_supported_dtype,
)


def _get_queue_usm_type(o):
    """Return the ``(sycl_queue, usm_type)`` pair associated with `o`.

    For a ``usm_ndarray`` the pair is read directly from the array. For any
    other object exposing ``__sycl_usm_array_interface__`` the pair is taken
    from the USM memory object built by ``dpm.as_usm_memory``. In every
    other case, including when that conversion raises, ``(None, None)`` is
    returned.
    """
    not_usm = (None, None)
    if isinstance(o, dpt.usm_ndarray):
        return o.sycl_queue, o.usm_type
    if not hasattr(o, "__sycl_usm_array_interface__"):
        return not_usm
    try:
        usm_mem = dpm.as_usm_memory(o)
        queue_and_type = (usm_mem.sycl_queue, usm_mem.get_usm_type())
    except Exception:
        # Best effort: an object whose interface cannot be interpreted is
        # treated the same as an object without USM allocation.
        return not_usm
    return queue_and_type


def _get_dtype(o, dev):
    """Return the data type (or weak-type wrapper) describing `o`.

    Resolution order:

    1. ``usm_ndarray``: its own ``dtype``.
    2. Objects with ``__sycl_usm_array_interface__``: the ``dtype`` of
       ``dpt.asarray(o)``.
    3. Objects supporting the buffer protocol: the dtype NumPy infers for
       them, passed through ``_to_device_supported_dtype`` for `dev`.
    4. Objects with a ``dtype`` attribute: that dtype, passed through
       ``_to_device_supported_dtype`` for `dev`.
    5. Python ``bool``/``int``/``float``/``complex`` scalars: the matching
       ``Weak*Type`` wrapper constructed from the value.
    6. Anything else: ``np.object_``.
    """
    if isinstance(o, dpt.usm_ndarray):
        return o.dtype
    if hasattr(o, "__sycl_usm_array_interface__"):
        return dpt.asarray(o).dtype
    if _is_buffer(o):
        return _to_device_supported_dtype(np.array(o).dtype, dev)
    if hasattr(o, "dtype"):
        return _to_device_supported_dtype(o.dtype, dev)
    # Order matters: bool is a subclass of int, so it must be tested first.
    weak_wrappers = (
        (bool, WeakBooleanType),
        (int, WeakIntegralType),
        (float, WeakFloatingType),
        (complex, WeakComplexType),
    )
    for scalar_type, weak_type in weak_wrappers:
        if isinstance(o, scalar_type):
            return weak_type(o)
    return np.object_


def _validate_dtype(dt) -> bool:
    """Tell whether `dt` is an accepted data type descriptor.

    Returns ``True`` when `dt` is an instance of one of the ``Weak*Type``
    wrappers, or when it is a ``dpt.dtype`` instance equal to one of the
    boolean, integer, floating-point or complex types listed below.
    Returns ``False`` otherwise.
    """
    weak_types = (
        WeakBooleanType,
        WeakIntegralType,
        WeakFloatingType,
        WeakComplexType,
    )
    if isinstance(dt, weak_types):
        return True
    if not isinstance(dt, dpt.dtype):
        return False
    accepted = (
        dpt.bool,
        dpt.int8,
        dpt.uint8,
        dpt.int16,
        dpt.uint16,
        dpt.int32,
        dpt.uint32,
        dpt.int64,
        dpt.uint64,
        dpt.float16,
        dpt.float32,
        dpt.float64,
        dpt.complex64,
        dpt.complex128,
    )
    return dt in accepted


def _get_shape(o):
    """Return the shape of `o` as reported by the object itself.

    ``usm_ndarray`` instances report their ``shape``; objects supporting the
    buffer protocol report the shape of a ``memoryview`` over them; numbers
    are treated as zero-dimensional and yield ``()``. Any other object
    yields its ``shape`` attribute, or ``()`` when it has none.
    """
    if isinstance(o, dpt.usm_ndarray):
        return o.shape
    if _is_buffer(o):
        buffer_view = memoryview(o)
        return buffer_view.shape
    scalar_shape = ()
    if isinstance(o, numbers.Number):
        return scalar_shape
    return getattr(o, "shape", scalar_shape)


# Names exported by this module: the four scalar-argument helper functions
# defined in this file. They are underscore-prefixed, so each must be listed
# explicitly here to be picked up by a star-import.
__all__ = [
"_get_dtype",
"_get_queue_usm_type",
"_get_shape",
"_validate_dtype",
]
10 changes: 5 additions & 5 deletions dpctl/tensor/_search_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@
import dpctl
import dpctl.tensor as dpt
import dpctl.tensor._tensor_impl as ti
from dpctl.tensor._elementwise_common import (
from dpctl.tensor._manipulation_functions import _broadcast_shape_impl
from dpctl.utils import ExecutionPlacementError, SequentialOrderManager

from ._copy_utils import _empty_like_orderK, _empty_like_triple_orderK
from ._scalar_utils import (
_get_dtype,
_get_queue_usm_type,
_get_shape,
_validate_dtype,
)
from dpctl.tensor._manipulation_functions import _broadcast_shape_impl
from dpctl.utils import ExecutionPlacementError, SequentialOrderManager

from ._copy_utils import _empty_like_orderK, _empty_like_triple_orderK
from ._type_utils import (
WeakBooleanType,
WeakComplexType,
Expand Down
Loading
Loading