Skip to content

Commit d1cd99a

Browse files
committed
chore: wip
1 parent 8471bc3 commit d1cd99a

File tree

1 file changed

+55
-1
lines changed

1 file changed

+55
-1
lines changed

src/array_api/cli/_main.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def _function_to_protocol(stmt: ast.FunctionDef, typevars: Sequence[TypeVarInfo]
5656
Returns
5757
-------
5858
ProtocolData
59-
A ProtocolData object.
59+
A ProtocolData object containing the converted function definition.
6060
6161
"""
6262
stmt = deepcopy(stmt)
@@ -89,6 +89,22 @@ def _function_to_protocol(stmt: ast.FunctionDef, typevars: Sequence[TypeVarInfo]
8989

9090

9191
def _class_to_protocol(stmt: ast.ClassDef, typevars: Sequence[TypeVarInfo]) -> ProtocolData:
92+
"""
93+
Convert a class definition to a Protocol class.
94+
95+
Parameters
96+
----------
97+
stmt : ast.ClassDef
98+
The class definition to convert.
99+
typevars : Sequence[TypeVarInfo]
100+
The type variables used in the class.
101+
102+
Returns
103+
-------
104+
ProtocolData
105+
The ProtocolData object containing the converted class definition.
106+
107+
"""
92108
unp = ast.unparse(stmt)
93109
typevars = [typevar for typevar in typevars if typevar.name in unp]
94110
stmt.bases = [
@@ -111,6 +127,22 @@ def _class_to_protocol(stmt: ast.ClassDef, typevars: Sequence[TypeVarInfo]) -> P
111127

112128

113129
def _attributes_to_protocol(name: str, attributes: Sequence[ModuleAttributes]) -> ProtocolData:
130+
"""
131+
Convert a list of module attributes to a Protocol class.
132+
133+
Parameters
134+
----------
135+
name : str
136+
The name of the Protocol class.
137+
attributes : Sequence[ModuleAttributes]
138+
The attributes to include in the Protocol class.
139+
140+
Returns
141+
-------
142+
ProtocolData
143+
The ProtocolData object containing the converted attributes.
144+
145+
"""
114146
body: list[ast.stmt] = []
115147
for a in attributes:
116148
body.append(
@@ -140,6 +172,17 @@ def _attributes_to_protocol(name: str, attributes: Sequence[ModuleAttributes]) -
140172

141173

142174
def generate(body_module: dict[str, list[ast.stmt]], out_path: Path) -> None:
175+
"""
176+
Generate Protocol classes from the given module body.
177+
178+
Parameters
179+
----------
180+
body_module : dict[str, list[ast.stmt]]
181+
The module body containing the AST statements for each submodule.
182+
out_path : Path
183+
The output path where the generated Protocol classes will be saved.
184+
185+
"""
143186
body_typevars = body_module["_types"]
144187
del body_module["__init__"]
145188

@@ -264,6 +307,17 @@ def generate_all(
264307
cache_dir: Path | str = ".cache",
265308
out_path: Path | str = "src/array_api",
266309
) -> None:
310+
"""
311+
Clone the array-api repository and generate Protocol classes for all versions.
312+
313+
Parameters
314+
----------
315+
cache_dir : Path | str, optional
316+
The directory where the array-api repository will be cloned, by default ".cache"
317+
out_path : Path | str, optional
318+
The output path where the generated Protocol classes will be saved, by default "src/array_api"
319+
320+
"""
267321
import subprocess as sp
268322

269323
Path(cache_dir).mkdir(exist_ok=True)

0 commit comments

Comments
 (0)