Skip to content

Commit 8a1df28

Browse files
committed
type IndexOpsMixin
1 parent 58d59f2 commit 8a1df28

File tree

2 files changed

+43
-15
lines changed

2 files changed

+43
-15
lines changed

pandas-stubs/core/base.pyi

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,13 @@ from typing_extensions import Self
2222

2323
from pandas._typing import (
2424
S1,
25+
AnyArrayLike,
2526
AxisIndex,
2627
DropKeep,
2728
DTypeLike,
2829
GenericT,
2930
GenericT_co,
31+
ListLike,
3032
NDFrameT,
3133
Scalar,
3234
SupportsDType,
@@ -51,7 +53,7 @@ class IndexOpsMixin(OpsMixin, Generic[S1, GenericT_co]):
5153
@property
5254
def T(self) -> Self: ...
5355
@property
54-
def shape(self) -> tuple: ...
56+
def shape(self) -> tuple[int, ...]: ...
5557
@property
5658
def ndim(self) -> int: ...
5759
def item(self) -> S1: ...
@@ -67,41 +69,45 @@ class IndexOpsMixin(OpsMixin, Generic[S1, GenericT_co]):
6769
dtype: None = None,
6870
copy: bool = False,
6971
na_value: Scalar = ...,
70-
**kwargs,
72+
**kwargs: Any,
7173
) -> np_1darray[GenericT_co]: ...
7274
@overload
7375
def to_numpy(
7476
self,
7577
dtype: np.dtype[GenericT] | SupportsDType[GenericT] | type[GenericT],
7678
copy: bool = False,
7779
na_value: Scalar = ...,
78-
**kwargs,
80+
**kwargs: Any,
7981
) -> np_1darray[GenericT]: ...
8082
@overload
8183
def to_numpy(
8284
self,
8385
dtype: DTypeLike,
8486
copy: bool = False,
8587
na_value: Scalar = ...,
86-
**kwargs,
88+
**kwargs: Any,
8789
) -> np_1darray: ...
8890
@property
8991
def empty(self) -> bool: ...
90-
def max(self, axis=..., skipna: bool = ..., **kwargs): ...
91-
def min(self, axis=..., skipna: bool = ..., **kwargs): ...
92+
def max(
93+
self, axis: AxisIndex | None = ..., skipna: bool = ..., **kwargs: Any
94+
) -> S1: ...
95+
def min(
96+
self, axis: AxisIndex | None = ..., skipna: bool = ..., **kwargs: Any
97+
) -> S1: ...
9298
def argmax(
9399
self,
94100
axis: AxisIndex | None = ...,
95101
skipna: bool = True,
96-
*args,
97-
**kwargs,
102+
*args: Any,
103+
**kwargs: Any,
98104
) -> np.int64: ...
99105
def argmin(
100106
self,
101107
axis: AxisIndex | None = ...,
102108
skipna: bool = True,
103-
*args,
104-
**kwargs,
109+
*args: Any,
110+
**kwargs: Any,
105111
) -> np.int64: ...
106112
def tolist(self) -> list[S1]: ...
107113
def to_list(self) -> list[S1]: ...
@@ -114,7 +120,7 @@ class IndexOpsMixin(OpsMixin, Generic[S1, GenericT_co]):
114120
normalize: Literal[False] = ...,
115121
sort: bool = ...,
116122
ascending: bool = ...,
117-
bins=...,
123+
bins: int | None = ...,
118124
dropna: bool = ...,
119125
) -> Series[int]: ...
120126
@overload
@@ -123,7 +129,7 @@ class IndexOpsMixin(OpsMixin, Generic[S1, GenericT_co]):
123129
normalize: Literal[True],
124130
sort: bool = ...,
125131
ascending: bool = ...,
126-
bins=...,
132+
bins: int | None = ...,
127133
dropna: bool = ...,
128134
) -> Series[float]: ...
129135
def nunique(self, dropna: bool = True) -> int: ...
@@ -136,7 +142,18 @@ class IndexOpsMixin(OpsMixin, Generic[S1, GenericT_co]):
136142
def factorize(
137143
self, sort: bool = False, use_na_sentinel: bool = True
138144
) -> tuple[np_1darray, np_1darray | Index | Categorical]: ...
145+
@overload
139146
def searchsorted(
140-
self, value, side: Literal["left", "right"] = ..., sorter=...
141-
) -> int | list[int]: ...
147+
self,
148+
value: Scalar,
149+
side: Literal["left", "right"] = ...,
150+
sorter: AnyArrayLike = ...,
151+
) -> np.intp: ...
152+
@overload
153+
def searchsorted(
154+
self,
155+
value: ListLike,
156+
side: Literal["left", "right"] = ...,
157+
sorter: AnyArrayLike = ...,
158+
) -> np_1darray[np.intp]: ...
142159
def drop_duplicates(self, *, keep: DropKeep = ...) -> Self: ...

tests/indexes/test_indexes.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,15 @@
1919
assert_type,
2020
)
2121

22+
from tests import np_1darray
23+
2224
if TYPE_CHECKING:
2325
from tests import Dtype # noqa: F401
2426

2527
from tests import (
2628
PD_LTE_23,
2729
TYPE_CHECKING_INVALID_USAGE,
2830
check,
29-
np_1darray,
3031
pytest_warns_bounded,
3132
)
3233

@@ -1489,6 +1490,16 @@ def test_index_naming() -> None:
14891490
check(assert_type(df.index.names, list[Hashable | None]), list)
14901491

14911492

1493+
def test_index_searchsorted() -> None:
1494+
idx = pd.Index([1, 2, 3])
1495+
check(assert_type(idx.searchsorted(1), np.intp), np.intp)
1496+
check(assert_type(idx.searchsorted([1]), "np_1darray[np.intp]"), np_1darray)
1497+
check(assert_type(idx.searchsorted(1, side="left"), np.intp), np.intp)
1498+
check(
1499+
assert_type(idx.searchsorted(1, sorter=pd.Series([1, 0, 2])), np.intp), np.intp
1500+
)
1501+
1502+
14921503
def test_period_index_constructor() -> None:
14931504
check(
14941505
assert_type(pd.PeriodIndex(["2000"], dtype="period[D]"), pd.PeriodIndex),

0 commit comments

Comments
 (0)