Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 251 additions & 0 deletions src/onnx_ir/_schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import dataclasses
import logging
from collections.abc import Iterator, Mapping, Sequence
from typing import Any

import onnx

import onnx_ir as ir

logger = logging.getLogger(__name__)


# A special value to indicate that the default value is not specified
class _Empty:
def __repr__(self):
return "_EMPTY_DEFAULT"

Check warning on line 20 in src/onnx_ir/_schemas.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_schemas.py#L20

Added line #L20 was not covered by tests


_EMPTY_DEFAULT = _Empty()

_ALL_VALUE_TYPES = (
{ir.TensorType(dtype) for dtype in ir.DataType}
| {ir.SequenceType(ir.TensorType(dtype)) for dtype in ir.DataType}
| {ir.OptionalType(ir.TensorType(dtype)) for dtype in ir.DataType}
)


@dataclasses.dataclass(frozen=True)
class TypeConstraintParam:
"""Type constraint for a parameter.

Attributes:
name: Name of the parameter. E.g. "TFloat"
allowed_types: Allowed types for the parameter.
"""

name: str
allowed_types: set[ir.TypeProtocol]
description: str = ""

def __hash__(self) -> int:
return hash((self.name, tuple(self.allowed_types)))

Check warning on line 46 in src/onnx_ir/_schemas.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_schemas.py#L46

Added line #L46 was not covered by tests

def __str__(self) -> str:
allowed_types_str = " | ".join(str(t) for t in self.allowed_types)
return f"{self.name}={allowed_types_str}"

Check warning on line 50 in src/onnx_ir/_schemas.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_schemas.py#L49-L50

Added lines #L49 - L50 were not covered by tests

@classmethod
def any_tensor(cls, name: str, description: str = "") -> TypeConstraintParam:
return cls(name, {ir.TensorType(dtype) for dtype in ir.DataType}, description)

Check warning on line 54 in src/onnx_ir/_schemas.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_schemas.py#L54

Added line #L54 was not covered by tests

@classmethod
def any_value(cls, name: str, description: str = "") -> TypeConstraintParam:
return cls(name, _ALL_VALUE_TYPES, description) # type: ignore[arg-type]

Check warning on line 58 in src/onnx_ir/_schemas.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_schemas.py#L58

Added line #L58 was not covered by tests


@dataclasses.dataclass(frozen=True)
class Parameter:
"""A formal parameter of an operator."""

name: str
type_constraint: TypeConstraintParam
required: bool
variadic: bool
default: Any = _EMPTY_DEFAULT
# TODO: Add other properties too

def __str__(self) -> str:
type_str = self.type_constraint.name

Check warning on line 73 in src/onnx_ir/_schemas.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_schemas.py#L73

Added line #L73 was not covered by tests
if self.has_default():
return f"{self.name}: {type_str} = {self.default}"
return f"{self.name}: {type_str}"

Check warning on line 76 in src/onnx_ir/_schemas.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_schemas.py#L75-L76

Added lines #L75 - L76 were not covered by tests

def has_default(self) -> bool:
return self.default is not _EMPTY_DEFAULT

Check warning on line 79 in src/onnx_ir/_schemas.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_schemas.py#L79

Added line #L79 was not covered by tests


@dataclasses.dataclass(frozen=True)
class AttributeParameter:
"""A parameter in the function signature that represents an ONNX attribute."""

name: str
type: ir.AttributeType
required: bool
default: ir.Attr | None = None

def __str__(self) -> str:
type_str = self.type.name

Check warning on line 92 in src/onnx_ir/_schemas.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_schemas.py#L92

Added line #L92 was not covered by tests
if self.has_default():
return f"{self.name}: {type_str} = {self.default}"
return f"{self.name}: {type_str}"

Check warning on line 95 in src/onnx_ir/_schemas.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_schemas.py#L94-L95

Added lines #L94 - L95 were not covered by tests

def has_default(self) -> bool:
return self.default is not None

Check warning on line 98 in src/onnx_ir/_schemas.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_schemas.py#L98

Added line #L98 was not covered by tests


def _get_type_from_str(
type_str: str,
) -> ir.TensorType | ir.SequenceType | ir.OptionalType:
"""Converter a type_str from ONNX OpSchema to ir.TypeProtocol.

A type str has the form of "tensor(float)" or composite type like "seq(tensor(float))".
"""
# Split the type_str a sequence types and dtypes
# 1. Remove the ending ")"
striped = type_str.rstrip(")")

Check warning on line 110 in src/onnx_ir/_schemas.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_schemas.py#L110

Added line #L110 was not covered by tests
# 2. Split the type_str by "("
type_parts = striped.split("(")

Check warning on line 112 in src/onnx_ir/_schemas.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_schemas.py#L112

Added line #L112 was not covered by tests

# Convert the dtype to ir.DataType
dtype = ir.DataType[type_parts[-1].upper()]

Check warning on line 115 in src/onnx_ir/_schemas.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_schemas.py#L115

Added line #L115 was not covered by tests

# Create a place holder type first
type_: ir.TypeProtocol = ir.TensorType(ir.DataType.UNDEFINED)

Check warning on line 118 in src/onnx_ir/_schemas.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_schemas.py#L118

Added line #L118 was not covered by tests

# Construct the type
for type_part in reversed(type_parts[:-1]):
if type_part == "tensor":
type_ = ir.TensorType(dtype)

Check warning on line 123 in src/onnx_ir/_schemas.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_schemas.py#L123

Added line #L123 was not covered by tests
elif type_part == "seq":
type_ = ir.SequenceType(type_)

Check warning on line 125 in src/onnx_ir/_schemas.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_schemas.py#L125

Added line #L125 was not covered by tests
elif type_part == "optional":
type_ = ir.OptionalType(type_)

Check warning on line 127 in src/onnx_ir/_schemas.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_schemas.py#L127

Added line #L127 was not covered by tests
else:
raise ValueError(f"Unknown type part: '{type_part}' in type '{type_str}'")
return type_ # type: ignore[return-value]

Check warning on line 130 in src/onnx_ir/_schemas.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_schemas.py#L129-L130

Added lines #L129 - L130 were not covered by tests


def _convert_formal_parameter(
param: onnx.defs.OpSchema.FormalParameter,
type_constraints: Mapping[str, TypeConstraintParam],
) -> Parameter:
"""Convert a formal parameter from ONNX OpSchema to Parameter."""
if param.type_str in type_constraints:
type_constraint = type_constraints[param.type_str]

Check warning on line 139 in src/onnx_ir/_schemas.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_schemas.py#L139

Added line #L139 was not covered by tests
else:
# param.type_str can be a plain type like 'int64'.
type_constraint = TypeConstraintParam(

Check warning on line 142 in src/onnx_ir/_schemas.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_schemas.py#L142

Added line #L142 was not covered by tests
name=param.name,
allowed_types={_get_type_from_str(param.type_str)},
)
return Parameter(

Check warning on line 146 in src/onnx_ir/_schemas.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_schemas.py#L146

Added line #L146 was not covered by tests
name=param.name,
type_constraint=type_constraint,
required=param.option != onnx.defs.OpSchema.FormalParameterOption.Optional,
variadic=param.option == onnx.defs.OpSchema.FormalParameterOption.Variadic,
)


@dataclasses.dataclass
class OpSignature:
"""Schema for an operator.

Attributes:
domain: Domain of the operator. E.g. "".
name: Name of the operator. E.g. "Add".
overload: Overload name of the operator.
params: Input parameters. When the op is an ONNX function definition,
the order is according to the function signature. This mean we can
interleave ONNX inputs and ONNX attributes in the list.
outputs: Output parameters.
"""

domain: str
name: str
overload: str
params: Sequence[Parameter | AttributeParameter]
outputs: Sequence[Parameter]
params_map: Mapping[str, Parameter | AttributeParameter] = dataclasses.field(
init=False, repr=False
)

def __post_init__(self):
self.params_map = {param.name: param for param in self.params}

Check warning on line 178 in src/onnx_ir/_schemas.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_schemas.py#L178

Added line #L178 was not covered by tests

def get(self, name: str) -> Parameter | AttributeParameter:
return self.params_map[name]

Check warning on line 181 in src/onnx_ir/_schemas.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_schemas.py#L181

Added line #L181 was not covered by tests

def __contains__(self, name: str) -> bool:
return name in self.params_map

Check warning on line 184 in src/onnx_ir/_schemas.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_schemas.py#L184

Added line #L184 was not covered by tests

def __iter__(self) -> Iterator[Parameter | AttributeParameter]:
return iter(self.params)

Check warning on line 187 in src/onnx_ir/_schemas.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_schemas.py#L187

Added line #L187 was not covered by tests

def __str__(self) -> str:
domain = self.domain or "''"
overload = f"::{self.overload}" if self.overload else ""
params = ", ".join(str(param) for param in self.params)
outputs = ", ".join(str(param.type_constraint.name) for param in self.outputs)
type_constraints = {}

Check warning on line 194 in src/onnx_ir/_schemas.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_schemas.py#L190-L194

Added lines #L190 - L194 were not covered by tests
for param in self.params:
if isinstance(param, Parameter):
type_constraints[param.type_constraint.name] = param.type_constraint

Check warning on line 197 in src/onnx_ir/_schemas.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_schemas.py#L197

Added line #L197 was not covered by tests
for param in self.outputs:
type_constraints[param.type_constraint.name] = param.type_constraint
type_constraints_str = ", ".join(

Check warning on line 200 in src/onnx_ir/_schemas.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_schemas.py#L199-L200

Added lines #L199 - L200 were not covered by tests
str(type_constraint) for type_constraint in type_constraints.values()
)
return f"{domain}::{self.name}{overload}({params}) -> ({outputs}) where {type_constraints_str}"

Check warning on line 203 in src/onnx_ir/_schemas.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_schemas.py#L203

Added line #L203 was not covered by tests

@classmethod
def from_op_schema(cls, op_schema: onnx.defs.OpSchema) -> OpSignature:
"""Produce an OpSignature from an ONNX OpSchema."""
type_constraints = {

Check warning on line 208 in src/onnx_ir/_schemas.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_schemas.py#L208

Added line #L208 was not covered by tests
constraint.type_param_str: TypeConstraintParam(
name=constraint.type_param_str,
allowed_types={
_get_type_from_str(type_str) for type_str in constraint.allowed_type_strs
},
description=constraint.description,
)
for constraint in op_schema.type_constraints
}

params = [

Check warning on line 219 in src/onnx_ir/_schemas.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_schemas.py#L219

Added line #L219 was not covered by tests
_convert_formal_parameter(param, type_constraints) for param in op_schema.inputs
]

for param in op_schema.attributes.values():
default_attr = (

Check warning on line 224 in src/onnx_ir/_schemas.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_schemas.py#L224

Added line #L224 was not covered by tests
ir.serde.deserialize_attribute(param.default_value)
if param.default_value is not None
else None
)
if default_attr is not None:
# Set the name of the default attribute because it may have a different name from the parameter
default_attr.name = param.name
params.append(

Check warning on line 232 in src/onnx_ir/_schemas.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_schemas.py#L231-L232

Added lines #L231 - L232 were not covered by tests
AttributeParameter(
name=param.name,
type=ir.AttributeType(param.type), # type: ignore[arg-type]
required=param.required,
default=default_attr, # type: ignore[arg-type]
)
)

outputs = [

Check warning on line 241 in src/onnx_ir/_schemas.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_schemas.py#L241

Added line #L241 was not covered by tests
_convert_formal_parameter(param, type_constraints) for param in op_schema.outputs
]

return cls(

Check warning on line 245 in src/onnx_ir/_schemas.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_schemas.py#L245

Added line #L245 was not covered by tests
domain=op_schema.domain,
name=op_schema.name,
overload="",
params=params,
outputs=outputs,
)