Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ dependencies = [
'ml-dtypes >= 0.1',
'Pygments>=2.4.0',
'colorama',
'typing-extensions >= 4.9.0',
'setuptools ; platform_system == "Windows"',
]
description = "Python-first Development for AI Compilers"
Expand Down
17 changes: 13 additions & 4 deletions python/mlc/_cython/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,13 +366,13 @@ def __new__(
return super().__new__(cls, name, bases, dict)


def attach_field(
def make_field(
cls: type,
name: str,
getter: typing.Callable[[typing.Any], typing.Any] | None,
setter: typing.Callable[[typing.Any, typing.Any], None] | None,
frozen: bool,
) -> None:
) -> property:
def fget(this: typing.Any, _name: str = name) -> typing.Any:
return getter(this) # type: ignore[misc]

Expand All @@ -383,12 +383,21 @@ def fset(this: typing.Any, value: typing.Any, _name: str = name) -> None:
fget.__module__ = fset.__module__ = cls.__module__
fget.__qualname__ = fset.__qualname__ = f"{cls.__qualname__}.{name}" # type: ignore[attr-defined]
fget.__doc__ = fset.__doc__ = f"Property `{name}` of class `{cls.__qualname__}`" # type: ignore[attr-defined]
prop = property(
return property(
fget=fget if getter else None,
fset=fset if (not frozen) and setter else None,
doc=f"{cls.__module__}.{cls.__qualname__}.{name}",
)
setattr(cls, name, prop)


def attach_field(
cls: type,
name: str,
getter: typing.Callable[[typing.Any], typing.Any] | None,
setter: typing.Callable[[typing.Any, typing.Any], None] | None,
frozen: bool,
) -> None:
setattr(cls, name, make_field(cls, name, getter, setter, frozen)) # type: ignore[call-arg]


def attach_method(
Expand Down
77 changes: 33 additions & 44 deletions python/mlc/dataclasses/c_class.py
Original file line number Diff line number Diff line change
@@ -1,79 +1,68 @@
import functools
import typing
import warnings
from collections.abc import Callable

try:
from typing import dataclass_transform
except ImportError:
from typing_extensions import dataclass_transform

from mlc._cython import (
TypeInfo,
TypeMethod,
attach_field,
attach_method,
type_index2type_methods,
type_key2py_type_info,
)
from mlc.core import typing as mlc_typing

from .utils import (
add_vtable_methods_for_type_cls,
get_parent_type,
inspect_dataclass_fields,
method_init,
prototype,
)
from . import utils

ClsType = typing.TypeVar("ClsType")
InputClsType = typing.TypeVar("InputClsType")


@dataclass_transform(field_specifiers=(utils.field, utils.Field))
def c_class(
type_key: str,
init: bool = True,
) -> Callable[[type[ClsType]], type[ClsType]]:
def decorator(super_type_cls: type[ClsType]) -> type[ClsType]:
@functools.wraps(super_type_cls, updated=())
class type_cls(super_type_cls): # type: ignore[valid-type,misc]
__slots__ = ()

) -> Callable[[type[InputClsType]], type[InputClsType]]:
def decorator(super_type_cls: type[InputClsType]) -> type[InputClsType]:
# Step 1. Retrieve `type_info` from registry
parent_type_info: TypeInfo = utils.get_parent_type(super_type_cls)._mlc_type_info # type: ignore[attr-defined]
type_info: TypeInfo = type_key2py_type_info(type_key)
parent_type_info: TypeInfo = get_parent_type(super_type_cls)._mlc_type_info # type: ignore[attr-defined]

if type_info.type_cls is not None:
raise ValueError(f"Type is already registered: {type_key}")
_, d_fields = inspect_dataclass_fields(type_key, type_cls, parent_type_info, frozen=False)
type_info.type_cls = type_cls
type_info.d_fields = tuple(d_fields)

# Step 2. Check if all fields are exposed as type annotations
# Step 2. Reflect all the fields of the type
_, d_fields, _ = utils.inspect_dataclass_fields(
super_type_cls,
parent_type_info,
frozen=False,
)
type_info.d_fields = tuple(d_fields)
# Check if all fields are exposed as type annotations
_check_c_class(super_type_cls, type_info)

# Step 3. Attach fields
setattr(type_cls, "_mlc_type_info", type_info)
for field in type_info.fields:
attach_field(
cls=type_cls,
name=field.name,
getter=field.getter,
setter=field.setter,
frozen=field.frozen,
)

# Step 4. Attach methods
fn_init: Callable[..., None] | None = None
if init:
attach_method(
parent_cls=super_type_cls,
cls=type_cls,
name="__init__",
method=method_init(super_type_cls, d_fields),
check_exists=True,
)
add_vtable_methods_for_type_cls(super_type_cls, type_index=type_info.type_index)
fn_init = utils.method_init(super_type_cls, d_fields)
else:
fn_init = None
# Step 5. Create the proxy class with the fields as properties
type_cls: type[InputClsType] = utils.create_type_class(
cls=super_type_cls,
type_info=type_info,
methods={
"__init__": fn_init,
},
)
return type_cls

return decorator


def _check_c_class(
type_cls: type[ClsType],
type_cls: type[InputClsType],
type_info: TypeInfo,
) -> None:
type_hints = typing.get_type_hints(type_cls)
Expand Down Expand Up @@ -117,5 +106,5 @@ def _check_c_class(
if warned:
warnings.warn(
f"One or multiple warnings in `{type_cls.__module__}.{type_cls.__qualname__}`. Its prototype is:\n"
+ prototype(type_info, lang="py")
+ utils.prototype(type_info, lang="py")
)
153 changes: 44 additions & 109 deletions python/mlc/dataclasses/py_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,43 +5,26 @@
except ImportError:
from typing_extensions import dataclass_transform

import ctypes
import functools
import typing
from collections.abc import Callable

from mlc._cython import (
MLCHeader,
TypeField,
TypeInfo,
attach_field,
attach_method,
make_mlc_init,
type_add_method,
type_create,
type_create_instance,
type_field_get_accessor,
type_register_fields,
type_register_structure,
)
from mlc.core import Object

from .utils import Field as _Field
from .utils import (
Structure,
add_vtable_methods_for_type_cls,
get_parent_type,
inspect_dataclass_fields,
method_init,
structure_parse,
structure_to_c,
)
from .utils import field as _field

from . import utils

InputClsType = typing.TypeVar("InputClsType")


@dataclass_transform(field_specifiers=(_field, _Field))
@dataclass_transform(field_specifiers=(utils.field, utils.Field))
def py_class(
type_key: str | type | None = None,
*,
Expand Down Expand Up @@ -69,56 +52,36 @@ def decorator(super_type_cls: type[InputClsType]) -> type[InputClsType]:
type_key = f"{super_type_cls.__module__}.{super_type_cls.__qualname__}"
assert isinstance(type_key, str)

# Step 1. Create the type according to its parent type
parent_type_info: TypeInfo = get_parent_type(super_type_cls)._mlc_type_info # type: ignore[attr-defined]
# Step 1. Create `type_info`
parent_type_info: TypeInfo = utils.get_parent_type(super_type_cls)._mlc_type_info # type: ignore[attr-defined]
type_info: TypeInfo = type_create(parent_type_info.type_index, type_key)
type_index = type_info.type_index

# Step 2. Reflect all the fields of the type
fields, d_fields = inspect_dataclass_fields(
type_key,
fields, d_fields, num_bytes = utils.inspect_dataclass_fields(
super_type_cls,
parent_type_info,
frozen=frozen,
py_mode=True,
)
num_bytes = _add_field_properties(fields)
type_info.fields = tuple(fields)
type_info.d_fields = tuple(d_fields)
type_register_fields(type_index, fields)
mlc_init = make_mlc_init(fields)

# Step 3. Create the proxy class with the fields as properties
type_cls: type[InputClsType] = _create_cls(
cls=super_type_cls,
mlc_init=mlc_init,
mlc_new=lambda cls, *args, **kwargs: type_create_instance(cls, type_index, num_bytes),
)

type_info.type_cls = type_cls
setattr(type_cls, "_mlc_type_info", type_info)
for field in fields:
attach_field(
type_cls,
name=field.name,
getter=field.getter,
setter=field.setter,
frozen=field.frozen,
)

# Step 4. Register the structure of the class
struct: Structure
# Step 3. Register the structure of the class
struct: utils.Structure
struct_kind: int
sub_structure_indices: list[int]
sub_structure_kinds: list[int]
if (struct := vars(super_type_cls).get("_mlc_structure", None)) is not None:
assert isinstance(struct, Structure)
assert isinstance(struct, utils.Structure)
else:
struct = structure_parse(structure, d_fields)
struct = utils.structure_parse(structure, d_fields)
(
struct_kind,
sub_structure_indices,
sub_structure_kinds,
) = structure_to_c(struct, fields)
) = utils.structure_to_c(struct, fields)
if struct.kind is None:
assert struct_kind == 0
assert not sub_structure_indices
Expand All @@ -129,45 +92,46 @@ def decorator(super_type_cls: type[InputClsType]) -> type[InputClsType]:
sub_structure_indices=tuple(sub_structure_indices),
sub_structure_kinds=tuple(sub_structure_kinds),
)
setattr(type_cls, "_mlc_structure", struct)

# Step 5. Add `__init__` method
type_add_method(type_index, "__init__", _method_new(type_cls), 1) # static
# Step 6. Attach methods
fn: Callable[..., typing.Any]
# Step 4. Attach methods
# Step 4.1. Method `__init__`
fn_init: Callable[..., None] | None = None
if init:
fn = method_init(super_type_cls, d_fields)
attach_method(super_type_cls, type_cls, "__init__", fn, check_exists=True)
fn_init = utils.method_init(super_type_cls, d_fields)
else:
fn_init = None
# Step 4.2. Method `__repr__` and `__str__`
fn_repr: Callable[[InputClsType], str] | None = None
if repr:
fn = _method_repr(type_key, fields)
type_add_method(type_index, "__str__", fn, 1) # static
attach_method(super_type_cls, type_cls, "__repr__", fn, check_exists=True)
attach_method(super_type_cls, type_cls, "__str__", fn, check_exists=True)
elif (fn := vars(super_type_cls).get("__str__", None)) is not None:
assert callable(fn)
type_add_method(type_index, "__str__", fn, 1)
add_vtable_methods_for_type_cls(super_type_cls, type_index=type_index)
fn_repr = _method_repr(type_key, fields)
type_add_method(type_index, "__str__", fn_repr, 1) # static
elif (fn_repr := vars(super_type_cls).get("__str__", None)) is not None:
assert callable(fn_repr)
type_add_method(type_index, "__str__", fn_repr, 1)
else:
fn_repr = None

# Step 5. Create the proxy class with the fields as properties
type_cls: type[InputClsType] = utils.create_type_class(
cls=super_type_cls,
type_info=type_info,
methods={
"_mlc_init": make_mlc_init(fields),
"__new__": lambda cls, *args, **kwargs: type_create_instance(
cls, type_index, num_bytes
),
"__init__": fn_init,
"__repr__": fn_repr,
"__str__": fn_repr,
},
)
type_add_method(type_index, "__init__", _method_new(type_cls), 1) # static
setattr(type_cls, "_mlc_structure", struct)
return type_cls

return decorator


def _add_field_properties(type_fields: list[TypeField]) -> int:
c_fields = [("_mlc_header", MLCHeader)]
for type_field in type_fields:
field_name = type_field.name
field_ty_c = type_field.ty._ctype()
c_fields.append((field_name, field_ty_c))

class CType(ctypes.Structure):
_fields_ = c_fields

for field in type_fields:
field.offset = getattr(CType, field.name).offset
field.getter, field.setter = type_field_get_accessor(field)
return ctypes.sizeof(CType)


def _method_repr(
type_key: str,
fields: list[TypeField],
Expand All @@ -188,32 +152,3 @@ def method(*args: typing.Any) -> InputClsType:
return obj

return method


def _create_cls(
cls: type,
mlc_init: Callable[..., None],
mlc_new: Callable[..., None],
) -> type[InputClsType]:
cls_name = cls.__name__
cls_bases = cls.__bases__
attrs = dict(cls.__dict__)
if cls_bases == (object,):
cls_bases = (Object,)

def _add_method(fn: Callable, fn_name: str) -> None:
attrs[fn_name] = fn
fn.__module__ = cls.__module__
fn.__name__ = fn_name
fn.__qualname__ = f"{cls_name}.{fn_name}"

attrs["__slots__"] = ()
attrs.pop("__dict__", None)
attrs.pop("__weakref__", None)
_add_method(mlc_init, "_mlc_init")
_add_method(mlc_new, "__new__")

new_cls = type(cls_name, cls_bases, attrs)
new_cls.__module__ = cls.__module__
new_cls = functools.wraps(cls, updated=())(new_cls) # type: ignore
return new_cls
Loading
Loading