1+ # pyright: reportPrivateUsage=false
2+ # pyright: reportUnknownArgumentType=false
3+ # pyright: reportUnknownMemberType=false
4+ # pyright: reportUnknownVariableType=false
5+
16from __future__ import annotations
27
3- from typing import Callable , Optional , Union
8+ from builtins import bool as py_bool
9+ from collections .abc import Callable
10+ from typing import TYPE_CHECKING , Any
11+
12+ if TYPE_CHECKING :
13+ from typing_extensions import TypeIs
414
15+ import dask .array as da
516import numpy as np
17+ from numpy import bool_ as bool
618from numpy import (
7- # dtypes
8- bool_ as bool ,
19+ can_cast ,
20+ complex64 ,
21+ complex128 ,
922 float32 ,
1023 float64 ,
1124 int8 ,
1225 int16 ,
1326 int32 ,
1427 int64 ,
28+ result_type ,
1529 uint8 ,
1630 uint16 ,
1731 uint32 ,
1832 uint64 ,
19- complex64 ,
20- complex128 ,
21- can_cast ,
22- result_type ,
2333)
24- import dask .array as da
2534
35+ from ..._internal import get_xp
2636from ...common import _aliases , _helpers , array_namespace
2737from ...common ._typing import (
2838 Array ,
3141 NestedSequence ,
3242 SupportsBufferProtocol ,
3343)
34- from ..._internal import get_xp
3544from ._info import __array_namespace_info__
3645
3746isdtype = get_xp (np )(_aliases .isdtype )
@@ -44,8 +53,8 @@ def astype(
4453 dtype : DType ,
4554 / ,
4655 * ,
47- copy : bool = True ,
48- device : Optional [ Device ] = None ,
56+ copy : py_bool = True ,
57+ device : Device | None = None ,
4958) -> Array :
5059 """
5160 Array API compatibility wrapper for astype().
@@ -69,14 +78,14 @@ def astype(
6978# not pass stop/step as keyword arguments, which will cause
7079# an error with dask
7180def arange (
72- start : Union [ int , float ] ,
81+ start : float ,
7382 / ,
74- stop : Optional [ Union [ int , float ]] = None ,
75- step : Union [ int , float ] = 1 ,
83+ stop : float | None = None ,
84+ step : float = 1 ,
7685 * ,
77- dtype : Optional [ DType ] = None ,
78- device : Optional [ Device ] = None ,
79- ** kwargs ,
86+ dtype : DType | None = None ,
87+ device : Device | None = None ,
88+ ** kwargs : object ,
8089) -> Array :
8190 """
8291 Array API compatibility wrapper for arange().
@@ -87,7 +96,7 @@ def arange(
8796 # TODO: respect device keyword?
8897 _helpers ._check_device (da , device )
8998
90- args = [start ]
99+ args : list [ Any ] = [start ]
91100 if stop is not None :
92101 args .append (stop )
93102 else :
@@ -137,18 +146,13 @@ def arange(
137146
138147# asarray also adds the copy keyword, which is not present in numpy 1.0.
139148def asarray (
140- obj : (
141- Array
142- | bool | int | float | complex
143- | NestedSequence [bool | int | float | complex ]
144- | SupportsBufferProtocol
145- ),
149+ obj : complex | NestedSequence [complex ] | Array | SupportsBufferProtocol ,
146150 / ,
147151 * ,
148- dtype : Optional [ DType ] = None ,
149- device : Optional [ Device ] = None ,
150- copy : Optional [ bool ] = None ,
151- ** kwargs ,
152+ dtype : DType | None = None ,
153+ device : Device | None = None ,
154+ copy : py_bool | None = None ,
155+ ** kwargs : object ,
152156) -> Array :
153157 """
154158 Array API compatibility wrapper for asarray().
@@ -164,7 +168,7 @@ def asarray(
164168 if copy is False :
165169 raise ValueError ("Unable to avoid copy when changing dtype" )
166170 obj = obj .astype (dtype )
167- return obj .copy () if copy else obj
171+ return obj .copy () if copy else obj # pyright: ignore[reportAttributeAccessIssue]
168172
169173 if copy is False :
170174 raise NotImplementedError (
@@ -177,22 +181,21 @@ def asarray(
177181 return da .from_array (obj )
178182
179183
180- from dask .array import (
181- # Element wise aliases
182- arccos as acos ,
183- arccosh as acosh ,
184- arcsin as asin ,
185- arcsinh as asinh ,
186- arctan as atan ,
187- arctan2 as atan2 ,
188- arctanh as atanh ,
189- left_shift as bitwise_left_shift ,
190- right_shift as bitwise_right_shift ,
191- invert as bitwise_invert ,
192- power as pow ,
193- # Other
194- concatenate as concat ,
195- )
184+ # Element wise aliases
185+ from dask .array import arccos as acos
186+ from dask .array import arccosh as acosh
187+ from dask .array import arcsin as asin
188+ from dask .array import arcsinh as asinh
189+ from dask .array import arctan as atan
190+ from dask .array import arctan2 as atan2
191+ from dask .array import arctanh as atanh
192+
193+ # Other
194+ from dask .array import concatenate as concat
195+ from dask .array import invert as bitwise_invert
196+ from dask .array import left_shift as bitwise_left_shift
197+ from dask .array import power as pow
198+ from dask .array import right_shift as bitwise_right_shift
196199
197200
198201# dask.array.clip does not work unless all three arguments are provided.
@@ -202,8 +205,8 @@ def asarray(
202205def clip (
203206 x : Array ,
204207 / ,
205- min : Optional [ Union [ int , float , Array ]] = None ,
206- max : Optional [ Union [ int , float , Array ]] = None ,
208+ min : float | Array | None = None ,
209+ max : float | Array | None = None ,
207210) -> Array :
208211 """
209212 Array API compatibility wrapper for clip().
@@ -212,8 +215,8 @@ def clip(
212215 specification for more details.
213216 """
214217
215- def _isscalar (a ) :
216- return isinstance (a , (int , float , type ( None ) ))
218+ def _isscalar (a : float | Array | None , / ) -> TypeIs [ float | None ] :
219+ return a is None or isinstance (a , (int , float ))
217220
218221 min_shape = () if _isscalar (min ) else min .shape
219222 max_shape = () if _isscalar (max ) else max .shape
@@ -266,7 +269,12 @@ def _ensure_single_chunk(x: Array, axis: int) -> tuple[Array, Callable[[Array],
266269
267270
268271def sort (
269- x : Array , / , * , axis : int = - 1 , descending : bool = False , stable : bool = True
272+ x : Array ,
273+ / ,
274+ * ,
275+ axis : int = - 1 ,
276+ descending : py_bool = False ,
277+ stable : py_bool = True ,
270278) -> Array :
271279 """
272280 Array API compatibility layer around the lack of sort() in Dask.
@@ -296,7 +304,12 @@ def sort(
296304
297305
298306def argsort (
299- x : Array , / , * , axis : int = - 1 , descending : bool = False , stable : bool = True
307+ x : Array ,
308+ / ,
309+ * ,
310+ axis : int = - 1 ,
311+ descending : py_bool = False ,
312+ stable : py_bool = True ,
300313) -> Array :
301314 """
302315 Array API compatibility layer around the lack of argsort() in Dask.
@@ -330,25 +343,34 @@ def argsort(
330343# dask.array.count_nonzero does not have keepdims
331344def count_nonzero (
332345 x : Array ,
333- axis = None ,
334- keepdims = False
346+ axis : int | None = None ,
347+ keepdims : py_bool = False ,
335348) -> Array :
336- result = da .count_nonzero (x , axis )
337- if keepdims :
338- if axis is None :
339- return da .reshape (result , [1 ]* x .ndim )
340- return da .expand_dims (result , axis )
341- return result
342-
343-
349+ result = da .count_nonzero (x , axis )
350+ if keepdims :
351+ if axis is None :
352+ return da .reshape (result , [1 ] * x .ndim )
353+ return da .expand_dims (result , axis )
354+ return result
355+
356+
357+ __all__ = [
358+ "__array_namespace_info__" ,
359+ "count_nonzero" ,
360+ "bool" ,
361+ "int8" , "int16" , "int32" , "int64" ,
362+ "uint8" , "uint16" , "uint32" , "uint64" ,
363+ "float32" , "float64" ,
364+ "complex64" , "complex128" ,
365+ "asarray" , "astype" , "can_cast" , "result_type" ,
366+ "pow" ,
367+ "concat" ,
368+ "acos" , "acosh" , "asin" , "asinh" , "atan" , "atan2" , "atanh" ,
369+ "bitwise_left_shift" , "bitwise_right_shift" , "bitwise_invert" ,
370+ ] # fmt: skip
371+ __all__ += _aliases .__all__
372+ _all_ignore = ["array_namespace" , "get_xp" , "da" , "np" ]
344373
345- __all__ = _aliases .__all__ + [
346- '__array_namespace_info__' , 'asarray' , 'astype' , 'acos' ,
347- 'acosh' , 'asin' , 'asinh' , 'atan' , 'atan2' ,
348- 'atanh' , 'bitwise_left_shift' , 'bitwise_invert' ,
349- 'bitwise_right_shift' , 'concat' , 'pow' , 'can_cast' ,
350- 'result_type' , 'bool' , 'float32' , 'float64' , 'int8' , 'int16' , 'int32' , 'int64' ,
351- 'uint8' , 'uint16' , 'uint32' , 'uint64' , 'complex64' , 'complex128' ,
352- 'can_cast' , 'count_nonzero' , 'result_type' ]
353374
354- _all_ignore = ["array_namespace" , "get_xp" , "da" , "np" ]
375+ def __dir__ () -> list [str ]:
376+ return __all__
0 commit comments