From a73f0bd9ca31b7fdf9429a7b32a093383777ac18 Mon Sep 17 00:00:00 2001 From: Mashed Potato <38517644+potatomashed@users.noreply.github.com> Date: Thu, 17 Jul 2025 15:26:24 -0700 Subject: [PATCH] feat(dataclasses): Unify `c_class` more with `py_class` --- pyproject.toml | 1 + python/mlc/_cython/base.py | 17 +- python/mlc/dataclasses/c_class.py | 77 ++++----- python/mlc/dataclasses/py_class.py | 153 +++++------------- python/mlc/dataclasses/utils.py | 85 ++++++++-- python/mlc/printer/ast.py | 99 ++++++------ python/mlc/sym/_internal.py | 10 +- python/mlc/sym/analyzer.py | 23 ++- .../test_sym_analyzer_const_int_bound.py | 6 +- tests/python/test_sym_analyzer_modular_set.py | 6 +- 10 files changed, 239 insertions(+), 238 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index cbd83e1f..b81492f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,6 +6,7 @@ dependencies = [ 'ml-dtypes >= 0.1', 'Pygments>=2.4.0', 'colorama', + 'typing-extensions >= 4.9.0', 'setuptools ; platform_system == "Windows"', ] description = "Python-first Development for AI Compilers" diff --git a/python/mlc/_cython/base.py b/python/mlc/_cython/base.py index 874579c6..e427fe74 100644 --- a/python/mlc/_cython/base.py +++ b/python/mlc/_cython/base.py @@ -366,13 +366,13 @@ def __new__( return super().__new__(cls, name, bases, dict) -def attach_field( +def make_field( cls: type, name: str, getter: typing.Callable[[typing.Any], typing.Any] | None, setter: typing.Callable[[typing.Any, typing.Any], None] | None, frozen: bool, -) -> None: +) -> property: def fget(this: typing.Any, _name: str = name) -> typing.Any: return getter(this) # type: ignore[misc] @@ -383,12 +383,21 @@ def fset(this: typing.Any, value: typing.Any, _name: str = name) -> None: fget.__module__ = fset.__module__ = cls.__module__ fget.__qualname__ = fset.__qualname__ = f"{cls.__qualname__}.{name}" # type: ignore[attr-defined] fget.__doc__ = fset.__doc__ = f"Property `{name}` of class `{cls.__qualname__}`" # type: ignore[attr-defined] - prop = property( + return property( fget=fget if getter else None, fset=fset if (not frozen) and setter else None, doc=f"{cls.__module__}.{cls.__qualname__}.{name}", ) - setattr(cls, name, prop) + + +def attach_field( + cls: type, + name: str, + getter: typing.Callable[[typing.Any], typing.Any] | None, + setter: typing.Callable[[typing.Any, typing.Any], None] | None, + frozen: bool, +) -> None: + setattr(cls, name, make_field(cls, name, getter, setter, frozen)) # type: ignore[call-arg] def attach_method( diff --git a/python/mlc/dataclasses/c_class.py b/python/mlc/dataclasses/c_class.py index 0bc70c1f..a5cdd665 100644 --- a/python/mlc/dataclasses/c_class.py +++ b/python/mlc/dataclasses/c_class.py @@ -1,79 +1,68 @@ -import functools import typing import warnings from collections.abc import Callable +try: + from typing import dataclass_transform +except ImportError: + from typing_extensions import dataclass_transform + from mlc._cython import ( TypeInfo, TypeMethod, - attach_field, - attach_method, type_index2type_methods, type_key2py_type_info, ) from mlc.core import typing as mlc_typing -from .utils import ( - add_vtable_methods_for_type_cls, - get_parent_type, - inspect_dataclass_fields, - method_init, - prototype, -) +from . import utils -ClsType = typing.TypeVar("ClsType") +InputClsType = typing.TypeVar("InputClsType") +@dataclass_transform(field_specifiers=(utils.field, utils.Field)) def c_class( type_key: str, init: bool = True, -) -> Callable[[type[ClsType]], type[ClsType]]: - def decorator(super_type_cls: type[ClsType]) -> type[ClsType]: - @functools.wraps(super_type_cls, updated=()) - class type_cls(super_type_cls): # type: ignore[valid-type,misc] - __slots__ = () - +) -> Callable[[type[InputClsType]], type[InputClsType]]: + def decorator(super_type_cls: type[InputClsType]) -> type[InputClsType]: # Step 1. Retrieve `type_info` from registry + parent_type_info: TypeInfo = utils.get_parent_type(super_type_cls)._mlc_type_info # type: ignore[attr-defined] type_info: TypeInfo = type_key2py_type_info(type_key) - parent_type_info: TypeInfo = get_parent_type(super_type_cls)._mlc_type_info # type: ignore[attr-defined] - if type_info.type_cls is not None: raise ValueError(f"Type is already registered: {type_key}") - _, d_fields = inspect_dataclass_fields(type_key, type_cls, parent_type_info, frozen=False) - type_info.type_cls = type_cls - type_info.d_fields = tuple(d_fields) - # Step 2. Check if all fields are exposed as type annotations + # Step 2. Reflect all the fields of the type + _, d_fields, _ = utils.inspect_dataclass_fields( + super_type_cls, + parent_type_info, + frozen=False, + ) + type_info.d_fields = tuple(d_fields) + # Check if all fields are exposed as type annotations _check_c_class(super_type_cls, type_info) - # Step 3. Attach fields - setattr(type_cls, "_mlc_type_info", type_info) - for field in type_info.fields: - attach_field( - cls=type_cls, - name=field.name, - getter=field.getter, - setter=field.setter, - frozen=field.frozen, - ) - # Step 4. Attach methods + fn_init: Callable[..., None] | None = None if init: - attach_method( - parent_cls=super_type_cls, - cls=type_cls, - name="__init__", - method=method_init(super_type_cls, d_fields), - check_exists=True, - ) - add_vtable_methods_for_type_cls(super_type_cls, type_index=type_info.type_index) + fn_init = utils.method_init(super_type_cls, d_fields) + else: + fn_init = None + # Step 5. Create the proxy class with the fields as properties + type_cls: type[InputClsType] = utils.create_type_class( + cls=super_type_cls, + type_info=type_info, + methods={ + "__init__": fn_init, + }, + ) return type_cls return decorator def _check_c_class( - type_cls: type[ClsType], + type_cls: type[InputClsType], type_info: TypeInfo, ) -> None: type_hints = typing.get_type_hints(type_cls) @@ -117,5 +106,5 @@ def _check_c_class( if warned: warnings.warn( f"One or multiple warnings in `{type_cls.__module__}.{type_cls.__qualname__}`. Its prototype is:\n" - + prototype(type_info, lang="py") + + utils.prototype(type_info, lang="py") ) diff --git a/python/mlc/dataclasses/py_class.py b/python/mlc/dataclasses/py_class.py index 1d0c3c87..296914c2 100644 --- a/python/mlc/dataclasses/py_class.py +++ b/python/mlc/dataclasses/py_class.py @@ -5,43 +5,26 @@ except ImportError: from typing_extensions import dataclass_transform -import ctypes -import functools import typing from collections.abc import Callable from mlc._cython import ( - MLCHeader, TypeField, TypeInfo, - attach_field, - attach_method, make_mlc_init, type_add_method, type_create, type_create_instance, - type_field_get_accessor, type_register_fields, type_register_structure, ) -from mlc.core import Object - -from .utils import Field as _Field -from .utils import ( - Structure, - add_vtable_methods_for_type_cls, - get_parent_type, - inspect_dataclass_fields, - method_init, - structure_parse, - structure_to_c, -) -from .utils import field as _field + +from . import utils InputClsType = typing.TypeVar("InputClsType") -@dataclass_transform(field_specifiers=(_field, _Field)) +@dataclass_transform(field_specifiers=(utils.field, utils.Field)) def py_class( type_key: str | type | None = None, *, @@ -69,56 +52,36 @@ def decorator(super_type_cls: type[InputClsType]) -> type[InputClsType]: type_key = f"{super_type_cls.__module__}.{super_type_cls.__qualname__}" assert isinstance(type_key, str) - # Step 1. Create the type according to its parent type - parent_type_info: TypeInfo = get_parent_type(super_type_cls)._mlc_type_info # type: ignore[attr-defined] + # Step 1. Create `type_info` + parent_type_info: TypeInfo = utils.get_parent_type(super_type_cls)._mlc_type_info # type: ignore[attr-defined] type_info: TypeInfo = type_create(parent_type_info.type_index, type_key) type_index = type_info.type_index # Step 2. Reflect all the fields of the type - fields, d_fields = inspect_dataclass_fields( - type_key, + fields, d_fields, num_bytes = utils.inspect_dataclass_fields( super_type_cls, parent_type_info, frozen=frozen, + py_mode=True, ) - num_bytes = _add_field_properties(fields) type_info.fields = tuple(fields) type_info.d_fields = tuple(d_fields) type_register_fields(type_index, fields) - mlc_init = make_mlc_init(fields) - - # Step 3. Create the proxy class with the fields as properties - type_cls: type[InputClsType] = _create_cls( - cls=super_type_cls, - mlc_init=mlc_init, - mlc_new=lambda cls, *args, **kwargs: type_create_instance(cls, type_index, num_bytes), - ) - type_info.type_cls = type_cls - setattr(type_cls, "_mlc_type_info", type_info) - for field in fields: - attach_field( - type_cls, - name=field.name, - getter=field.getter, - setter=field.setter, - frozen=field.frozen, - ) - - # Step 4. Register the structure of the class - struct: Structure + # Step 3. Register the structure of the class + struct: utils.Structure struct_kind: int sub_structure_indices: list[int] sub_structure_kinds: list[int] if (struct := vars(super_type_cls).get("_mlc_structure", None)) is not None: - assert isinstance(struct, Structure) + assert isinstance(struct, utils.Structure) else: - struct = structure_parse(structure, d_fields) + struct = utils.structure_parse(structure, d_fields) ( struct_kind, sub_structure_indices, sub_structure_kinds, - ) = structure_to_c(struct, fields) + ) = utils.structure_to_c(struct, fields) if struct.kind is None: assert struct_kind == 0 assert not sub_structure_indices @@ -129,45 +92,46 @@ def decorator(super_type_cls: type[InputClsType]) -> type[InputClsType]: sub_structure_indices=tuple(sub_structure_indices), sub_structure_kinds=tuple(sub_structure_kinds), ) - setattr(type_cls, "_mlc_structure", struct) - # Step 5. Add `__init__` method - type_add_method(type_index, "__init__", _method_new(type_cls), 1) # static - # Step 6. Attach methods - fn: Callable[..., typing.Any] + # Step 4. Attach methods + # Step 4.1. Method `__init__` + fn_init: Callable[..., None] | None = None if init: - fn = method_init(super_type_cls, d_fields) - attach_method(super_type_cls, type_cls, "__init__", fn, check_exists=True) + fn_init = utils.method_init(super_type_cls, d_fields) + else: + fn_init = None + # Step 4.2. Method `__repr__` and `__str__` + fn_repr: Callable[[InputClsType], str] | None = None if repr: - fn = _method_repr(type_key, fields) - type_add_method(type_index, "__str__", fn, 1) # static - attach_method(super_type_cls, type_cls, "__repr__", fn, check_exists=True) - attach_method(super_type_cls, type_cls, "__str__", fn, check_exists=True) - elif (fn := vars(super_type_cls).get("__str__", None)) is not None: - assert callable(fn) - type_add_method(type_index, "__str__", fn, 1) - add_vtable_methods_for_type_cls(super_type_cls, type_index=type_index) + fn_repr = _method_repr(type_key, fields) + type_add_method(type_index, "__str__", fn_repr, 1) # static + elif (fn_repr := vars(super_type_cls).get("__str__", None)) is not None: + assert callable(fn_repr) + type_add_method(type_index, "__str__", fn_repr, 1) + else: + fn_repr = None + + # Step 5. Create the proxy class with the fields as properties + type_cls: type[InputClsType] = utils.create_type_class( + cls=super_type_cls, + type_info=type_info, + methods={ + "_mlc_init": make_mlc_init(fields), + "__new__": lambda cls, *args, **kwargs: type_create_instance( + cls, type_index, num_bytes + ), + "__init__": fn_init, + "__repr__": fn_repr, + "__str__": fn_repr, + }, + ) + type_add_method(type_index, "__init__", _method_new(type_cls), 1) # static + setattr(type_cls, "_mlc_structure", struct) return type_cls return decorator -def _add_field_properties(type_fields: list[TypeField]) -> int: - c_fields = [("_mlc_header", MLCHeader)] - for type_field in type_fields: - field_name = type_field.name - field_ty_c = type_field.ty._ctype() - c_fields.append((field_name, field_ty_c)) - - class CType(ctypes.Structure): - _fields_ = c_fields - - for field in type_fields: - field.offset = getattr(CType, field.name).offset - field.getter, field.setter = type_field_get_accessor(field) - return ctypes.sizeof(CType) - - def _method_repr( type_key: str, fields: list[TypeField], @@ -188,32 +152,3 @@ def method(*args: typing.Any) -> InputClsType: return obj return method - - -def _create_cls( - cls: type, - mlc_init: Callable[..., None], - mlc_new: Callable[..., None], -) -> type[InputClsType]: - cls_name = cls.__name__ - cls_bases = cls.__bases__ - attrs = dict(cls.__dict__) - if cls_bases == (object,): - cls_bases = (Object,) - - def _add_method(fn: Callable, fn_name: str) -> None: - attrs[fn_name] = fn - fn.__module__ = cls.__module__ - fn.__name__ = fn_name - fn.__qualname__ = f"{cls_name}.{fn_name}" - - attrs["__slots__"] = () - attrs.pop("__dict__", None) - attrs.pop("__weakref__", None) - _add_method(mlc_init, "_mlc_init") - _add_method(mlc_new, "__new__") - - new_cls = type(cls_name, cls_bases, attrs) - new_cls.__module__ = cls.__module__ - new_cls = functools.wraps(cls, updated=())(new_cls) # type: ignore - return new_cls diff --git a/python/mlc/dataclasses/utils.py b/python/mlc/dataclasses/utils.py index cf9bc450..d6c32ff1 100644 --- a/python/mlc/dataclasses/utils.py +++ b/python/mlc/dataclasses/utils.py @@ -13,18 +13,21 @@ from mlc._cython import ( MISSING, Field, + MLCHeader, TypeField, TypeInfo, TypeMethod, type_add_method, + type_field_get_accessor, type_index2type_methods, type_table, ) -from mlc.core import Object -from mlc.core import typing as mlc_typing +from mlc._cython.base import make_field KIND_MAP = {None: 0, "nobind": 1, "bind": 2, "var": 3} SUB_STRUCTURES = {"nobind": 0, "bind": 1} +InputClsType = typing.TypeVar("InputClsType") +FuncType = TypeVar("FuncType", bound=Callable[..., Any]) class DefaultFactory(typing.NamedTuple): @@ -122,17 +125,20 @@ def field( def inspect_dataclass_fields( # noqa: PLR0912 - type_key: str, type_cls: type, parent_type_info: TypeInfo, frozen: bool, -) -> tuple[list[TypeField], list[Field]]: + py_mode: bool = False, +) -> tuple[list[TypeField], list[Field], int]: + from mlc.core import typing as mlc_typing # noqa: PLC0415 + def _get_num_bytes(field_ty: Any) -> int: if hasattr(field_ty, "_ctype"): return ctypes.sizeof(field_ty._ctype()) return 0 # Step 1. Inspect and extract all the `TypeField`s + c_fields = [("_mlc_header", MLCHeader)] type_hints = get_type_hints(type_cls) type_fields: list[TypeField] = [] for type_field in parent_type_info.fields: @@ -141,7 +147,7 @@ def _get_num_bytes(field_ty: Any) -> int: field_frozen = type_field.frozen if type_hints.pop(field_name, None) is None: raise ValueError( - f"Missing field `{type_key}::{field_name}`, " + f"Missing field `{field_name}` from `{type_cls}`, " f"which appears in its parent class `{parent_type_info.type_key}`." ) type_fields.append( @@ -153,6 +159,9 @@ def _get_num_bytes(field_ty: Any) -> int: ty=field_ty, ) ) + if py_mode: + c_fields.append((field_name, field_ty._ctype())) + for field_name, field_ty_py in type_hints.items(): if field_name.startswith("_mlc_"): continue @@ -166,6 +175,8 @@ def _get_num_bytes(field_ty: Any) -> int: ty=field_ty, ) ) + if py_mode: + c_fields.append((field_name, field_ty._ctype())) # Step 2. Convert `TypeField`s to dataclass `Field`s d_fields: list[Field] = [] for type_field in type_fields: @@ -188,7 +199,62 @@ def _get_num_bytes(field_ty: Any) -> int: raise ValueError(f"Cannot recognize field: {type_field.name}: {rhs}") d_field.name = type_field.name d_fields.append(d_field) - return type_fields, d_fields + + if py_mode: + + class CType(ctypes.Structure): + _fields_ = c_fields + + for f in type_fields: + f.offset = getattr(CType, f.name).offset + f.getter, f.setter = type_field_get_accessor(f) + num_bytes = ctypes.sizeof(CType) + else: + num_bytes = -1 + + return type_fields, d_fields, num_bytes + + +def create_type_class( + cls: type, + type_info: TypeInfo, + methods: dict[str, Callable[..., typing.Any] | None], +) -> type[InputClsType]: + cls_name = cls.__name__ + cls_bases = cls.__bases__ + attrs = dict(cls.__dict__) + if cls_bases == (object,): + # If the class inherits from `object`, we need to set the base class to `Object` + from mlc.core.object import Object # noqa: PLC0415 + + cls_bases = (Object,) + + attrs.pop("__dict__", None) + attrs.pop("__weakref__", None) + attrs["__slots__"] = () + attrs["_mlc_type_info"] = type_info + for name, method in methods.items(): + if method is not None: + method.__module__ = cls.__module__ + method.__name__ = name + method.__qualname__ = f"{cls.__qualname__}.{name}" + method.__doc__ = f"Method `{name}` of class `{cls.__qualname__}`" + attrs[name] = method + for field in type_info.fields: + attrs[field.name] = make_field( + cls=cls, + name=field.name, + getter=field.getter, + setter=field.setter, + frozen=field.frozen, + ) + + new_cls = type(cls_name, cls_bases, attrs) + new_cls.__module__ = cls.__module__ + new_cls = functools.wraps(cls, updated=())(new_cls) # type: ignore + type_info.type_cls = new_cls + add_vtable_methods_for_type_cls(new_cls, type_index=type_info.type_index) + return new_cls @dataclasses.dataclass @@ -285,6 +351,8 @@ def get_parent_type(type_cls: type) -> type: for base in type_cls.__bases__: if hasattr(base, "_mlc_type_info"): return base + from mlc.core.object import Object # noqa: PLC0415 + return Object @@ -305,9 +373,6 @@ def add_vtable_method( ) -FuncType = TypeVar("FuncType", bound=Callable[..., Any]) - - def vtable_method(is_static: bool) -> Callable[[FuncType], FuncType]: def decorator(method: FuncType) -> FuncType: method._mlc_is_static_func = is_static # type: ignore[attr-defined] @@ -355,6 +420,8 @@ def _prototype_cxx( type_info: TypeInfo, export_macro: str = "_EXPORTS", ) -> str: + from mlc.core import typing as mlc_typing # noqa: PLC0415 + assert isinstance(type_info, TypeInfo) parent_type_info = type_info.get_parent() namespaces = type_info.type_key.split(".") diff --git a/python/mlc/printer/ast.py b/python/mlc/printer/ast.py index aaab8e3a..c2604ae7 100644 --- a/python/mlc/printer/ast.py +++ b/python/mlc/printer/ast.py @@ -8,7 +8,7 @@ @mlcd.c_class("mlc.printer.PrinterConfig") -class PrinterConfig(Object): +class PrinterConfig: def_free_var: bool = True indent_spaces: int = 2 print_line_numbers: int = 0 @@ -150,37 +150,38 @@ class Stmt(Node): @mlcd.c_class("mlc.printer.ast.StmtBlock") class StmtBlock(Stmt): - stmts: list[Stmt] + stmts: list[Stmt] # type: ignore[misc] @mlcd.c_class("mlc.printer.ast.Literal") class Literal(Expr): - value: Any # int, str, float, bool, None + # value can be: int, str, float, bool, None + value: Any # type: ignore[misc] @mlcd.c_class("mlc.printer.ast.Id") class Id(Expr): - name: str + name: str # type: ignore[misc] @mlcd.c_class("mlc.printer.ast.Attr") class Attr(Expr): - obj: Expr - name: str + obj: Expr # type: ignore[misc] + name: str # type: ignore[misc] @mlcd.c_class("mlc.printer.ast.Index") class Index(Expr): - obj: Expr - idx: list[Expr] + obj: Expr # type: ignore[misc] + idx: list[Expr] # type: ignore[misc] @mlcd.c_class("mlc.printer.ast.Call") class Call(Expr): - callee: Expr - args: list[Expr] - kwargs_keys: list[str] - kwargs_values: list[Expr] + callee: Expr # type: ignore[misc] + args: list[Expr] # type: ignore[misc] + kwargs_keys: list[str] # type: ignore[misc] + kwargs_values: list[Expr] # type: ignore[misc] class OperationKind: @@ -221,37 +222,37 @@ class OperationKind: @mlcd.c_class("mlc.printer.ast.Operation") class Operation(Expr): - op: int # OperationKind - operands: list[Expr] + op: int # type: ignore[misc] + operands: list[Expr] # type: ignore[misc] @mlcd.c_class("mlc.printer.ast.Lambda") class Lambda(Expr): - args: list[Id] - body: Expr + args: list[Id] # type: ignore[misc] + body: Expr # type: ignore[misc] @mlcd.c_class("mlc.printer.ast.Tuple") class Tuple(Expr): - values: list[Expr] + values: list[Expr] # type: ignore[misc] @mlcd.c_class("mlc.printer.ast.List") class List(Expr): - values: list[Expr] + values: list[Expr] # type: ignore[misc] @mlcd.c_class("mlc.printer.ast.Dict") class Dict(Expr): - keys: list[Expr] - values: list[Expr] + keys: list[Expr] # type: ignore[misc] + values: list[Expr] # type: ignore[misc] @mlcd.c_class("mlc.printer.ast.Slice", init=False) class Slice(Expr): - start: Optional[Expr] - stop: Optional[Expr] - step: Optional[Expr] + start: Optional[Expr] # type: ignore[misc] + stop: Optional[Expr] # type: ignore[misc] + step: Optional[Expr] # type: ignore[misc] def __init__( self, @@ -264,9 +265,9 @@ def __init__( @mlcd.c_class("mlc.printer.ast.Assign", init=False) class Assign(Stmt): - lhs: Expr - rhs: Optional[Expr] - annotation: Optional[Expr] + lhs: Expr # type: ignore[misc] + rhs: Optional[Expr] # type: ignore[misc] + annotation: Optional[Expr] # type: ignore[misc] def __init__( # noqa: PLR0913, RUF100 self, @@ -283,40 +284,40 @@ def __init__( # noqa: PLR0913, RUF100 @mlcd.c_class("mlc.printer.ast.If") class If(Stmt): - cond: Expr - then_branch: list[Stmt] - else_branch: list[Stmt] + cond: Expr # type: ignore[misc] + then_branch: list[Stmt] # type: ignore[misc] + else_branch: list[Stmt] # type: ignore[misc] @mlcd.c_class("mlc.printer.ast.While") class While(Stmt): - cond: Expr - body: list[Stmt] + cond: Expr # type: ignore[misc] + body: list[Stmt] # type: ignore[misc] @mlcd.c_class("mlc.printer.ast.For") class For(Stmt): - lhs: Expr - rhs: Expr - body: list[Stmt] + lhs: Expr # type: ignore[misc] + rhs: Expr # type: ignore[misc] + body: list[Stmt] # type: ignore[misc] @mlcd.c_class("mlc.printer.ast.With") class With(Stmt): - lhs: Optional[Expr] - rhs: Expr - body: list[Stmt] + lhs: Optional[Expr] # type: ignore[misc] + rhs: Expr # type: ignore[misc] + body: list[Stmt] # type: ignore[misc] @mlcd.c_class("mlc.printer.ast.ExprStmt") class ExprStmt(Stmt): - expr: Expr + expr: Expr # type: ignore[misc] @mlcd.c_class("mlc.printer.ast.Assert", init=False) class Assert(Stmt): - cond: Expr - msg: Optional[Expr] + cond: Expr # type: ignore[misc] + msg: Optional[Expr] # type: ignore[misc] def __init__( self, @@ -332,23 +333,23 @@ def __init__( @mlcd.c_class("mlc.printer.ast.Return") class Return(Stmt): - value: Optional[Expr] + value: Optional[Expr] # type: ignore[misc] @mlcd.c_class("mlc.printer.ast.Function") class Function(Stmt): - name: Id - args: list[Assign] - decorators: list[Expr] - return_type: Optional[Expr] - body: list[Stmt] + name: Id # type: ignore[misc] + args: list[Assign] # type: ignore[misc] + decorators: list[Expr] # type: ignore[misc] + return_type: Optional[Expr] # type: ignore[misc] + body: list[Stmt] # type: ignore[misc] @mlcd.c_class("mlc.printer.ast.Class") class Class(Stmt): - name: Id - decorators: list[Expr] - body: list[Stmt] + name: Id # type: ignore[misc] + decorators: list[Expr] # type: ignore[misc] + body: list[Stmt] # type: ignore[misc] @mlcd.c_class("mlc.printer.ast.Comment", init=False) diff --git a/python/mlc/sym/_internal.py b/python/mlc/sym/_internal.py index 6e7816e1..d4ed782e 100644 --- a/python/mlc/sym/_internal.py +++ b/python/mlc/sym/_internal.py @@ -6,20 +6,20 @@ from collections.abc import Generator import mlc.dataclasses as mlcd -from mlc.core import Func, Object +from mlc.core import Func from .analyzer import Analyzer from .expr import Expr, Var, const @mlcd.c_class("mlc.sym.ConstIntBound") -class ConstIntBound(Object): +class ConstIntBound: min_value: int max_value: int @mlcd.c_class("mlc.sym.IntervalSet", init=False) -class IntervalSet(Object): +class IntervalSet: min_value: Expr max_value: Expr @@ -33,11 +33,11 @@ def __init__(self, min_value: Expr | int, max_value: Expr | int) -> None: if isinstance(max_value, int): assert isinstance(min_value, Expr) max_value = const(min_value.dtype, max_value) - self._mlc_init(min_value, max_value) + self._mlc_init(min_value, max_value) # type: ignore[attr-defined] @mlcd.c_class("mlc.sym.ModularSet") -class ModularSet(Object): +class ModularSet: coeff: int base: int diff --git a/python/mlc/sym/analyzer.py b/python/mlc/sym/analyzer.py index 1e1fe122..8a9a299e 100644 --- a/python/mlc/sym/analyzer.py +++ b/python/mlc/sym/analyzer.py @@ -3,16 +3,15 @@ from typing import TYPE_CHECKING, Literal import mlc.dataclasses as mlcd -from mlc.core import Object if TYPE_CHECKING: from .expr import Expr, Range, Var @mlcd.c_class("mlc.sym.Analyzer") -class Analyzer(Object): +class Analyzer: def mark_global_non_neg_value(self, v: Expr) -> None: - Analyzer._C(b"_mark_global_non_neg_value", self, v) + Analyzer._C(b"_mark_global_non_neg_value", self, v) # type: ignore[attr-defined] def bind( self, @@ -23,21 +22,21 @@ def bind( from .expr import Expr, Range, const # noqa: PLC0415 if isinstance(bound, Range): - Analyzer._C(b"_bind_range", self, v, bound, allow_override) + Analyzer._C(b"_bind_range", self, v, bound, allow_override) # type: ignore[attr-defined] elif isinstance(bound, Expr): - Analyzer._C(b"_bind_expr", self, v, bound, allow_override) + Analyzer._C(b"_bind_expr", self, v, bound, allow_override) # type: ignore[attr-defined] elif isinstance(bound, (int, float)): - Analyzer._C(b"_bind_expr", self, v, const(v.dtype, bound), allow_override) + Analyzer._C(b"_bind_expr", self, v, const(v.dtype, bound), allow_override) # type: ignore[attr-defined] else: raise TypeError(f"Unsupported type for bound: {type(bound)}") def can_prove_greater_equal(self, a: Expr, b: int) -> bool: assert isinstance(b, int) - return Analyzer._C(b"can_prove_greater_equal", self, a, b) + return Analyzer._C(b"can_prove_greater_equal", self, a, b) # type: ignore[attr-defined] def can_prove_less(self, a: Expr, b: int) -> bool: assert isinstance(b, int) - return Analyzer._C(b"can_prove_less", self, a, b) + return Analyzer._C(b"can_prove_less", self, a, b) # type: ignore[attr-defined] def can_prove_equal(self, a: Expr, b: Expr | int) -> bool: from .expr import Expr, const # noqa: PLC0415 @@ -45,10 +44,10 @@ def can_prove_equal(self, a: Expr, b: Expr | int) -> bool: assert isinstance(a, Expr) if isinstance(b, int): b = const(a.dtype, b) - return Analyzer._C(b"can_prove_equal", self, a, b) + return Analyzer._C(b"can_prove_equal", self, a, b) # type: ignore[attr-defined] def can_prove_less_equal_than_symbolic_shape_value(self, a: Expr, b: Expr) -> bool: - return Analyzer._C(b"can_prove_less_equal_than_symbolic_shape_value", self, a, b) + return Analyzer._C(b"can_prove_less_equal_than_symbolic_shape_value", self, a, b) # type: ignore[attr-defined] def can_prove( self, @@ -56,10 +55,10 @@ def can_prove( *, strength: Literal["default", "symbolic_bound"] = "default", ) -> bool: - return Analyzer._C(b"can_prove", self, cond, _STRENGTH[strength]) + return Analyzer._C(b"can_prove", self, cond, _STRENGTH[strength]) # type: ignore[attr-defined] def simplify(self, expr: Expr, *, steps: int = 2) -> Expr: - return Analyzer._C(b"simplify", self, expr, steps) + return Analyzer._C(b"simplify", self, expr, steps) # type: ignore[attr-defined] _STRENGTH = { diff --git a/tests/python/test_sym_analyzer_const_int_bound.py b/tests/python/test_sym_analyzer_const_int_bound.py index 2e537e2e..261f0613 100644 --- a/tests/python/test_sym_analyzer_const_int_bound.py +++ b/tests/python/test_sym_analyzer_const_int_bound.py @@ -41,12 +41,12 @@ def test_body(self, param: Param) -> None: for var, bounds in param.bounds.items(): const_int_bound_update(analyzer, var, ConstIntBound(*bounds)) with enter_constraint(analyzer, param.constraint): - bounds = const_int_bound(analyzer, param.expr) + actual = const_int_bound(analyzer, param.expr) expected_min_value, expected_max_value = param.expected if expected_min_value is not None: - assert bounds.min_value == expected_min_value # type: ignore[attr-defined] + assert actual.min_value == expected_min_value # type: ignore[attr-defined] if expected_max_value is not None: - assert bounds.max_value == expected_max_value # type: ignore[attr-defined] + assert actual.max_value == expected_max_value # type: ignore[attr-defined] class TestDataType(_Test): diff --git a/tests/python/test_sym_analyzer_modular_set.py b/tests/python/test_sym_analyzer_modular_set.py index 443493c8..1afea0fe 100644 --- a/tests/python/test_sym_analyzer_modular_set.py +++ b/tests/python/test_sym_analyzer_modular_set.py @@ -45,12 +45,12 @@ def test_body(self, param: Param) -> None: for constraint in param.constraints: assert isinstance(constraint, S.Expr) exit_stack.enter_context(enter_constraint(analyzer, constraint)) - bounds = modular_set(analyzer, param.expr) + actual = modular_set(analyzer, param.expr) expected_coeff, expected_base = param.expected if expected_coeff is not None: - assert bounds.coeff == expected_coeff + assert actual.coeff == expected_coeff if expected_base is not None: - assert bounds.base == expected_base + assert actual.base == expected_base class TestCast(_Test):