-
-
Notifications
You must be signed in to change notification settings - Fork 19k
ENH: Implemented MultiIndex.searchsorted method ( GH14833) #61435
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 11 commits
cffb863
1ba7ff8
ac70f3e
275b0e2
0e0b9b5
4747609
9ac62ab
e2c2c5e
e88da57
94f7c44
1f4a1c9
ffd99d8
5e2caa4
6b0d0ab
73e308b
1342657
68a3b81
a681c2d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ | |
Any, | ||
Literal, | ||
cast, | ||
overload, | ||
) | ||
import warnings | ||
|
||
|
@@ -44,6 +45,15 @@ | |
Shape, | ||
npt, | ||
) | ||
|
||
if TYPE_CHECKING: | ||
from pandas._typing import ( | ||
NumpySorter, | ||
NumpyValueArrayLike, | ||
ScalarLike_co, | ||
) | ||
|
||
|
||
from pandas.compat.numpy import function as nv | ||
from pandas.errors import ( | ||
InvalidIndexError, | ||
|
@@ -3778,6 +3788,99 @@ def _reorder_indexer( | |
ind = np.lexsort(keys) | ||
return indexer[ind] | ||
|
||
@overload | ||
def searchsorted( # type: ignore[overload-overlap] | ||
self, | ||
value: ScalarLike_co, | ||
side: Literal["left", "right"] = ..., | ||
sorter: NumpySorter = ..., | ||
) -> np.intp: ... | ||
|
||
@overload | ||
def searchsorted( | ||
self, | ||
value: npt.ArrayLike | ExtensionArray, | ||
side: Literal["left", "right"] = ..., | ||
sorter: NumpySorter = ..., | ||
) -> npt.NDArray[np.intp]: ... | ||
|
||
def searchsorted( | ||
self, | ||
value: NumpyValueArrayLike | ExtensionArray, | ||
side: Literal["left", "right"] = "left", | ||
sorter: npt.NDArray[np.intp] | None = None, | ||
) -> npt.NDArray[np.intp] | np.intp: | ||
""" | ||
Find the indices where elements should be inserted to maintain order. | ||
|
||
Parameters | ||
---------- | ||
value : Any | ||
The value(s) to search for in the MultiIndex. | ||
side : {'left', 'right'}, default 'left' | ||
If 'left', the index of the first suitable location found is given. | ||
If 'right', return the last such index. Note that if `value` is | ||
already present in the MultiIndex, the results will be different. | ||
sorter : 1-D array-like, optional | ||
Optional array of integer indices that sort the MultiIndex. | ||
|
||
Returns | ||
------- | ||
npt.NDArray[np.intp] or np.intp | ||
The index or indices where the value(s) should be inserted to | ||
maintain order. | ||
|
||
See Also | ||
-------- | ||
Index.searchsorted : Search for insertion point in a 1-D index. | ||
|
||
Examples | ||
-------- | ||
>>> mi = pd.MultiIndex.from_arrays([["a", "b", "c"], ["x", "y", "z"]]) | ||
>>> mi.searchsorted(("b", "y")) | ||
array([1]) | ||
""" | ||
|
||
if not value: | ||
raise ValueError("searchsorted requires a non-empty value") | ||
|
||
if not isinstance(value, (tuple, list)): | ||
|
||
raise TypeError("value must be a tuple or list") | ||
|
||
if isinstance(value, tuple): | ||
value = [value] | ||
|
||
if side not in ["left", "right"]: | ||
raise ValueError("side must be either 'left' or 'right'") | ||
|
||
indexer = self.get_indexer(value) | ||
result = [] | ||
|
||
for v, i in zip(value, indexer): | ||
if i != -1: | ||
val = i if side == "left" else i + 1 | ||
result.append(np.intp(val)) | ||
else: | ||
dtype = np.dtype( | ||
[ | ||
(f"level_{i}", np.asarray(level).dtype) | ||
for i, level in enumerate(self.levels) | ||
] | ||
) | ||
|
||
val_array = np.array([v], dtype=dtype) | ||
|
||
pos = np.searchsorted( | ||
np.asarray(self.values, dtype=dtype), | ||
val_array, | ||
side=side, | ||
sorter=sorter, | ||
) | ||
result.append(np.intp(pos[0])) | ||
|
||
if len(result) == 1: | ||
return result[0] | ||
return np.array(result, dtype=np.intp) | ||
|
||
def truncate(self, before=None, after=None) -> MultiIndex: | ||
""" | ||
Slice index between two labels / tuples, return new MultiIndex. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -147,14 +147,14 @@ def test_searchsorted(request, index_or_series_obj): | |
# See gh-12238 | ||
obj = index_or_series_obj | ||
|
||
if isinstance(obj, pd.MultiIndex): | ||
# See gh-14833 | ||
request.applymarker( | ||
pytest.mark.xfail( | ||
reason="np.searchsorted doesn't work on pd.MultiIndex: GH 14833" | ||
) | ||
) | ||
elif obj.dtype.kind == "c" and isinstance(obj, Index): | ||
# if isinstance(obj, pd.MultiIndex): | ||
# # See gh-14833 | ||
# request.applymarker( | ||
# pytest.mark.xfail( | ||
# reason="np.searchsorted doesn't work on pd.MultiIndex: GH 14833" | ||
# ) | ||
# ) | ||
if obj.dtype.kind == "c" and isinstance(obj, Index): | ||
|
||
# TODO: Should Series cases also raise? Looks like they use numpy | ||
# comparison semantics https://github.com/numpy/numpy/issues/15981 | ||
mark = pytest.mark.xfail(reason="complex objects are not comparable") | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1029,3 +1029,28 @@ def test_get_loc_namedtuple_behaves_like_tuple(): | |
assert idx.get_loc(("i1", "i2")) == 0 | ||
assert idx.get_loc(("i3", "i4")) == 1 | ||
assert idx.get_loc(("i5", "i6")) == 2 | ||
|
||
|
||
def test_searchsorted(): | ||
|
||
# GH14833 | ||
mi = MultiIndex.from_tuples([("a", 0), ("a", 1), ("b", 0), ("b", 1), ("c", 0)]) | ||
|
||
assert np.all(mi.searchsorted(("b", 0)) == 2) | ||
assert np.all(mi.searchsorted(("b", 0), side="right") == 3) | ||
assert np.all(mi.searchsorted(("a", 0)) == 0) | ||
assert np.all(mi.searchsorted(("a", -1)) == 0) | ||
assert np.all(mi.searchsorted(("c", 1)) == 5) | ||
|
||
result = mi.searchsorted([("a", 1), ("b", 0), ("c", 0)]) | ||
expected = np.array([1, 2, 4], dtype=np.intp) | ||
tm.assert_numpy_array_equal(result, expected) | ||
|
||
result = mi.searchsorted([("a", 1), ("b", 0), ("c", 0)], side="right") | ||
expected = np.array([2, 3, 5], dtype=np.intp) | ||
tm.assert_numpy_array_equal(result, expected) | ||
|
||
with pytest.raises(ValueError, match="side must be either 'left' or 'right'"): | ||
mi.searchsorted(("a", 1), side="middle") | ||
|
||
with pytest.raises(TypeError, match="value must be a tuple or list"): | ||
mi.searchsorted("a") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Better to remove this. We do have some notebooks in this repo iirc.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I will remove it.