11from __future__ import annotations
22
3+ from collections .abc import Mapping
34from types import ModuleType as Namespace
4- from typing import Any , Protocol , TypeAlias , TypedDict , TypeVar
5+ from typing import TYPE_CHECKING , Literal , Protocol , TypeAlias , TypedDict , TypeVar
6+
7+ if TYPE_CHECKING :
8+ from _typeshed import Incomplete
9+
10+ SupportsBufferProtocol : TypeAlias = Incomplete
11+ Array : TypeAlias = Incomplete
12+ Device : TypeAlias = Incomplete
13+ DType : TypeAlias = Incomplete
14+ else :
15+ SupportsBufferProtocol = object
16+ Array = object
17+ Device = object
18+ DType = object
19+
520
621_T_co = TypeVar ("_T_co" , covariant = True )
722
@@ -20,6 +35,7 @@ class HasShape(Protocol[_T_co]):
2035 def shape (self , / ) -> _T_co : ...
2136
2237
38+ # Return type of `__array_namespace_info__.default_dtypes`
2339Capabilities = TypedDict (
2440 "Capabilities" ,
2541 {
@@ -29,17 +45,98 @@ def shape(self, /) -> _T_co: ...
2945 },
3046)
3147
48+ # Return type of `__array_namespace_info__.default_dtypes`
49+ DefaultDTypes = TypedDict (
50+ "DefaultDTypes" ,
51+ {
52+ "real floating" : DType ,
53+ "complex floating" : DType ,
54+ "integral" : DType ,
55+ "indexing" : DType ,
56+ },
57+ )
58+
59+
60+ _DTypeKind : TypeAlias = Literal [
61+ "bool" ,
62+ "signed integer" ,
63+ "unsigned integer" ,
64+ "integral" ,
65+ "real floating" ,
66+ "complex floating" ,
67+ "numeric" ,
68+ ]
69+ # Type of the `kind` parameter in `__array_namespace_info__.dtypes`
70+ DTypeKind : TypeAlias = _DTypeKind | tuple [_DTypeKind , ...]
71+
72+
73+ # `__array_namespace_info__.dtypes(kind="bool")`
74+ class DTypesBool (TypedDict ):
75+ bool : DType
76+
77+
78+ # `__array_namespace_info__.dtypes(kind="signed integer")`
79+ class DTypesSigned (TypedDict ):
80+ int8 : DType
81+ int16 : DType
82+ int32 : DType
83+ int64 : DType
84+
85+
86+ # `__array_namespace_info__.dtypes(kind="unsigned integer")`
87+ class DTypesUnsigned (TypedDict ):
88+ uint8 : DType
89+ uint16 : DType
90+ uint32 : DType
91+ uint64 : DType
92+
93+
94+ # `__array_namespace_info__.dtypes(kind="integral")`
95+ class DTypesIntegral (DTypesSigned , DTypesUnsigned ):
96+ pass
97+
98+
99+ # `__array_namespace_info__.dtypes(kind="real floating")`
100+ class DTypesReal (TypedDict ):
101+ float32 : DType
102+ float64 : DType
103+
104+
105+ # `__array_namespace_info__.dtypes(kind="complex floating")`
106+ class DTypesComplex (TypedDict ):
107+ complex64 : DType
108+ complex128 : DType
109+
110+
111+ # `__array_namespace_info__.dtypes(kind="numeric")`
112+ class DTypesNumeric (DTypesIntegral , DTypesReal , DTypesComplex ):
113+ pass
114+
115+
116+ # `__array_namespace_info__.dtypes(kind=None)` (default)
117+ class DTypesAll (DTypesBool , DTypesNumeric ):
118+ pass
119+
32120
33- SupportsBufferProtocol : TypeAlias = Any
34- Array : TypeAlias = Any
35- Device : TypeAlias = Any
36- DType : TypeAlias = Any
121+ # `__array_namespace_info__.dtypes(kind=?)` (fallback)
122+ DTypesAny : TypeAlias = Mapping [str , DType ]
37123
38124
39125__all__ = [
40126 "Array" ,
41127 "Capabilities" ,
42128 "DType" ,
129+ "DTypeKind" ,
130+ "DTypesAny" ,
131+ "DTypesAll" ,
132+ "DTypesBool" ,
133+ "DTypesNumeric" ,
134+ "DTypesIntegral" ,
135+ "DTypesSigned" ,
136+ "DTypesUnsigned" ,
137+ "DTypesReal" ,
138+ "DTypesComplex" ,
139+ "DefaultDTypes" ,
43140 "Device" ,
44141 "HasShape" ,
45142 "Namespace" ,
0 commit comments