Skip to content

Commit aef0010

Browse files
authored
fix: fix TypeVars for Array, fix __array_namespace__ not typed, fix TypeVars for Info, add ShapedArray (#30)
1 parent af30c2e commit aef0010

File tree

8 files changed

+322
-157
lines changed

8 files changed

+322
-157
lines changed

README.md

Lines changed: 61 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ pip install types-array-api
5252

5353
### Type stubs
5454

55-
Provices type stubs for [`array-api-compat`](https://data-apis.org/array-api-compat/).
55+
Autocompletion for [`array-api-compat`](https://data-apis.org/array-api-compat/) is available in your IDE **just by installing** this package.
5656

5757
```python
5858
import array_api_compat
@@ -63,19 +63,31 @@ xp = array_api_compat.array_namespace(x)
6363
![Screenshot 1](https://raw.githubusercontent.com/34j/array-api/main/docs/_static/screenshot1.png)
6464
![Screenshot 2](https://raw.githubusercontent.com/34j/array-api/main/docs/_static/screenshot2.png)
6565

66-
### Array Type
66+
### Typing functions using `Array`
6767

68-
```python
69-
from array_api._2024_12 import Array
68+
There are multiple ways to type functions:
7069

70+
- ```python
71+
from array_api._2024_12 import Array
7172

72-
def my_function[TArray: Array](x: TArray) -> TArray:
73-
return x + 1
74-
```
73+
def simple(x: Array) -> Array:
74+
return x + 1
75+
```
76+
77+
The simplest way to enjoy autocompletion for `Array`. This should be enough for most use cases.
78+
79+
- To make sure that the same type of array is returned (`ndarray``ndarray`, `Tensor``Tensor`), a `TypeVar` bound to `Array` can be used:
80+
81+
```python
82+
def generic[TArray: Array](x: TArray) -> TArray:
83+
return x + 1
84+
```
85+
86+
## Advanced Usage
7587

7688
### Namespace Type
7789

78-
You can test if an object matches the Protocol by:
90+
You can test if an object matches the Protocol as they are [`runtime-checkable`](https://docs.python.org/3/library/typing.html#typing.runtime_checkable):
7991

8092
```python
8193
import array_api_strict
@@ -89,6 +101,47 @@ assert isinstance(array_api_strict, ArrayNamespace)
89101
assert not isinstance(array_api_strict, ArrayNamespaceFull)
90102
```
91103

104+
### Shape Typing
105+
106+
- To clarify the input and output shapes, `ShapedArray` and `ShapedAnyArray` can be used:
107+
108+
```python
109+
from array_api._2024_12 import ShapedAnyArray as Array
110+
111+
def sum_last_axis[*TShape](x: Array[*TShape, Any]) -> Array[*TShape]:
112+
return xp.sum(x, axis=-1)
113+
```
114+
115+
More complex example using [NewType](https://docs.python.org/3/library/typing.html#newtype) or [type aliases](https://docs.python.org/3/library/typing.html#type-aliases):
116+
117+
```python
118+
RTheta = NewType("RTheta", int)
119+
XY = NewType("XY", int)
120+
def polar_coordinates[*TShape](randtheta: Array[*TShape, RTheta]) -> Array[*TShape, XY]:
121+
"""Convert polar coordinates to Cartesian coordinates."""
122+
r = randtheta[..., 0]
123+
theta = randtheta[..., 1]
124+
x = r * xp.cos(theta)
125+
y = r * xp.sin(theta)
126+
return xp.stack((x, y), axis=-1)
127+
```
128+
129+
Note that `ShapedAnyArray` exists only for **documentation purposes** and internally it is treated as `Array`.
130+
Using both generic and shaped are impossible due to [python/typing#548](https://github.com/python/typing/issues/548).
131+
132+
- Note that the below example is ideal but impossible due to Python specification.
133+
134+
```python
135+
def impossible[
136+
TDtype,
137+
TDevice,
138+
*TShapeFormer: int,
139+
*TShapeLatter: int,
140+
TArray: Array
141+
](x: TArray[*TShapeFormer, *TShapeLatter | Literal[1], TDtype, TDevice], y: TArray[*TShapeLatter | Literal[1], TDtype, TDevice]) -> TArray[*TShapeFormer, *TShapeLatter, TDtype, TDevice]:
142+
return x + y # broadcasting
143+
```
144+
92145
## Contributors ✨
93146

94147
Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/docs/en/emoji-key)):

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ urls.Changelog = "https://github.com/34j/types-array-api/blob/main/CHANGELOG.md"
3838
urls.documentation = "https://array-api.readthedocs.io"
3939
urls.repository = "https://github.com/34j/types-array-api"
4040
scripts.array-api = "array_api.cli:app"
41+
scripts.types-array-api = "array_api.cli:app"
4142

4243
[dependency-groups]
4344
dev = [

0 commit comments

Comments
 (0)