Skip to content

Commit 1ef3aa0

Browse files
committed
chore: wip
1 parent d1cd99a commit 1ef3aa0

File tree

1 file changed

+29
-9
lines changed

1 file changed

+29
-9
lines changed

src/array_api/cli/_main.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -219,33 +219,43 @@ def generate(body_module: dict[str, list[ast.stmt]], out_path: Path) -> None:
219219
if isinstance(b, (ast.Import, ast.ImportFrom)):
220220
pass
221221
elif isinstance(b, ast.FunctionDef):
222+
# implemented in object rather than Namespace
222223
if b.name == "__eq__":
223224
continue
225+
# info.py conntains functions which are not part of the Namespace (but Info class)
224226
if submodule == "info" and b.name != "__array_namespace_info__":
225227
continue
226228
data = _function_to_protocol(b, typevars)
229+
# add to module attributes
227230
module_attributes[submodule].append(ModuleAttributes(b.name, data.name, None, data.typevars_used))
231+
# some functions are duplicated in linalg and fft, skip them
232+
# their docstrings are unhelpful, e.g. "Alias for ..."
228233
if "Alias" in (ast.get_docstring(b) or ""):
229234
continue
235+
# add to output
230236
out.body.append(data.stmt)
231237
elif isinstance(b, ast.Assign):
238+
# _types.py contains Assigns which are not part of the Namespace
232239
if submodule == "_types":
233240
continue
234241
if not isinstance(b.targets[0], ast.Name):
235242
continue
236243
id = b.targets[0].id
244+
# __init__.py
237245
if id == "__all__":
238-
pass
239-
else:
240-
docstring = None
241-
if i != len(body) - 1:
242-
docstring_expr = body[i + 1]
243-
if isinstance(docstring_expr, ast.Expr):
244-
if isinstance(docstring_expr.value, ast.Constant):
245-
docstring = docstring_expr.value.value
246-
module_attributes[submodule].append(ModuleAttributes(id, ast.Name(id="float"), docstring, []))
246+
continue
247+
# get docstring
248+
docstring = None
249+
if i != len(body) - 1:
250+
docstring_expr = body[i + 1]
251+
if isinstance(docstring_expr, ast.Expr):
252+
if isinstance(docstring_expr.value, ast.Constant):
253+
docstring = docstring_expr.value.value
254+
# add to module attributes
255+
module_attributes[submodule].append(ModuleAttributes(id, ast.Name(id="float"), docstring, []))
247256
elif isinstance(b, ast.ClassDef):
248257
data = _class_to_protocol(b, typevars)
258+
# add to output, do not add to module attributes
249259
out.body.append(data.stmt)
250260
elif isinstance(b, ast.Expr):
251261
pass
@@ -267,6 +277,7 @@ def generate(body_module: dict[str, list[ast.stmt]], out_path: Path) -> None:
267277
attributes = [attribute for submodule, attributes in module_attributes.items() for attribute in attributes if submodule not in OPTIONAL_SUBMODULES] + submodules
268278
out.body.append(_attributes_to_protocol("ArrayNamespace", attributes).stmt)
269279

280+
# Replace TypeVars because of the name conflicts like "array: array"
270281
for node in ast.walk(out):
271282
for child in ast.iter_child_nodes(node):
272283
if isinstance(child, ast.Name) and child.id in {t.name for t in typevars}:
@@ -277,7 +288,10 @@ def generate(body_module: dict[str, list[ast.stmt]], out_path: Path) -> None:
277288
elif isinstance(child, ast.TypeVar) and child.name in {t.name for t in typevars}:
278289
child.name = "T" + child.name.capitalize()
279290

291+
# Manual modifications (easier than AST manipulations)
280292
text = ast.unparse(ast.fix_missing_locations(out))
293+
294+
# Add imports
281295
text = (
282296
""""Auto generated Protocol classes (Do not edit)"
283297
from enum import Enum
@@ -297,8 +311,12 @@ def generate(body_module: dict[str, list[ast.stmt]], out_path: Path) -> None:
297311
"""
298312
+ text
299313
)
314+
315+
# Fix self-references in typing
300316
ns = "Union[T_t_co, NestedSequence[T_t_co]]"
301317
text = text.replace(ns, f'"{ns}"')
318+
319+
# write to the output path
302320
out_path.parent.mkdir(parents=True, exist_ok=True)
303321
out_path.write_text(text, "utf-8")
304322

@@ -324,6 +342,7 @@ def generate_all(
324342
sp.run(["git", "clone", "https://github.com/data-apis/array-api", ".cache"])
325343

326344
for dir_path in (Path(cache_dir) / Path("src") / "array_api_stubs").iterdir():
345+
# skip non-directory entries
327346
if not dir_path.is_dir():
328347
continue
329348
# 2021 is broken (no self keyword in `_array`` methods)
@@ -335,5 +354,6 @@ def generate_all(
335354

336355
import sys
337356

357+
# run ssort, otherwise it is broken
338358
sys.argv = ["ssort", "src/array_api"]
339359
runpy.run_module("ssort")

0 commit comments

Comments
 (0)