@@ -219,33 +219,43 @@ def generate(body_module: dict[str, list[ast.stmt]], out_path: Path) -> None:
219219 if isinstance (b , (ast .Import , ast .ImportFrom )):
220220 pass
221221 elif isinstance (b , ast .FunctionDef ):
222+ # implemented in object rather than Namespace
222223 if b .name == "__eq__" :
223224 continue
225+ # info.py conntains functions which are not part of the Namespace (but Info class)
224226 if submodule == "info" and b .name != "__array_namespace_info__" :
225227 continue
226228 data = _function_to_protocol (b , typevars )
229+ # add to module attributes
227230 module_attributes [submodule ].append (ModuleAttributes (b .name , data .name , None , data .typevars_used ))
231+ # some functions are duplicated in linalg and fft, skip them
232+ # their docstrings are unhelpful, e.g. "Alias for ..."
228233 if "Alias" in (ast .get_docstring (b ) or "" ):
229234 continue
235+ # add to output
230236 out .body .append (data .stmt )
231237 elif isinstance (b , ast .Assign ):
238+ # _types.py contains Assigns which are not part of the Namespace
232239 if submodule == "_types" :
233240 continue
234241 if not isinstance (b .targets [0 ], ast .Name ):
235242 continue
236243 id = b .targets [0 ].id
244+ # __init__.py
237245 if id == "__all__" :
238- pass
239- else :
240- docstring = None
241- if i != len (body ) - 1 :
242- docstring_expr = body [i + 1 ]
243- if isinstance (docstring_expr , ast .Expr ):
244- if isinstance (docstring_expr .value , ast .Constant ):
245- docstring = docstring_expr .value .value
246- module_attributes [submodule ].append (ModuleAttributes (id , ast .Name (id = "float" ), docstring , []))
246+ continue
247+ # get docstring
248+ docstring = None
249+ if i != len (body ) - 1 :
250+ docstring_expr = body [i + 1 ]
251+ if isinstance (docstring_expr , ast .Expr ):
252+ if isinstance (docstring_expr .value , ast .Constant ):
253+ docstring = docstring_expr .value .value
254+ # add to module attributes
255+ module_attributes [submodule ].append (ModuleAttributes (id , ast .Name (id = "float" ), docstring , []))
247256 elif isinstance (b , ast .ClassDef ):
248257 data = _class_to_protocol (b , typevars )
258+ # add to output, do not add to module attributes
249259 out .body .append (data .stmt )
250260 elif isinstance (b , ast .Expr ):
251261 pass
@@ -267,6 +277,7 @@ def generate(body_module: dict[str, list[ast.stmt]], out_path: Path) -> None:
267277 attributes = [attribute for submodule , attributes in module_attributes .items () for attribute in attributes if submodule not in OPTIONAL_SUBMODULES ] + submodules
268278 out .body .append (_attributes_to_protocol ("ArrayNamespace" , attributes ).stmt )
269279
280+ # Replace TypeVars because of the name conflicts like "array: array"
270281 for node in ast .walk (out ):
271282 for child in ast .iter_child_nodes (node ):
272283 if isinstance (child , ast .Name ) and child .id in {t .name for t in typevars }:
@@ -277,7 +288,10 @@ def generate(body_module: dict[str, list[ast.stmt]], out_path: Path) -> None:
277288 elif isinstance (child , ast .TypeVar ) and child .name in {t .name for t in typevars }:
278289 child .name = "T" + child .name .capitalize ()
279290
291+ # Manual modifications (easier than AST manipulations)
280292 text = ast .unparse (ast .fix_missing_locations (out ))
293+
294+ # Add imports
281295 text = (
282296 """"Auto generated Protocol classes (Do not edit)"
283297from enum import Enum
@@ -297,8 +311,12 @@ def generate(body_module: dict[str, list[ast.stmt]], out_path: Path) -> None:
297311"""
298312 + text
299313 )
314+
315+ # Fix self-references in typing
300316 ns = "Union[T_t_co, NestedSequence[T_t_co]]"
301317 text = text .replace (ns , f'"{ ns } "' )
318+
319+ # write to the output path
302320 out_path .parent .mkdir (parents = True , exist_ok = True )
303321 out_path .write_text (text , "utf-8" )
304322
@@ -324,6 +342,7 @@ def generate_all(
324342 sp .run (["git" , "clone" , "https://github.com/data-apis/array-api" , ".cache" ])
325343
326344 for dir_path in (Path (cache_dir ) / Path ("src" ) / "array_api_stubs" ).iterdir ():
345+ # skip non-directory entries
327346 if not dir_path .is_dir ():
328347 continue
329348 # 2021 is broken (no self keyword in `_array`` methods)
@@ -335,5 +354,6 @@ def generate_all(
335354
336355 import sys
337356
357+ # run ssort, otherwise it is broken
338358 sys .argv = ["ssort" , "src/array_api" ]
339359 runpy .run_module ("ssort" )
0 commit comments