Skip to content
Draft
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
1adca71
Scaffold
justinchuby Jun 29, 2025
f48537b
Add support for expr in SymbolicDim
justinchuby Jun 29, 2025
7464af1
wip
justinchuby Jun 29, 2025
27ae0ca
wip
justinchuby Jun 29, 2025
1afe64a
Create NodeInferencer
justinchuby Jun 29, 2025
78ad6e0
inference_common
justinchuby Jun 29, 2025
5aa2df7
Update shapes
justinchuby Jun 29, 2025
dbc3593
update
justinchuby Jun 29, 2025
b9f0528
Claude - add sympy import
justinchuby Jun 30, 2025
c9a35b7
Claude and lint
justinchuby Jun 30, 2025
65e3dd2
concat
justinchuby Jun 30, 2025
7960770
Update _maybe_convert_to_symbolic_dim
justinchuby Jun 30, 2025
a7704c5
reshape
justinchuby Jun 30, 2025
922a597
Update the way dim is set
justinchuby Jun 30, 2025
9183848
Simplify
justinchuby Jun 30, 2025
9300aba
Update
justinchuby Jun 30, 2025
8747a93
Handle unknown dims
justinchuby Jun 30, 2025
92049c4
Simplify
justinchuby Jun 30, 2025
720845e
Create inclusive range
justinchuby Jun 30, 2025
bae78ab
WIP inference engine
justinchuby Jun 30, 2025
a77f487
Create readme
justinchuby Jun 30, 2025
6686457
Result
justinchuby Jun 30, 2025
3207e84
Summary of Complete Refactoring
justinchuby Jun 30, 2025
a572145
lint
justinchuby Jun 30, 2025
11f8958
Removes unused shape inference code
justinchuby Jun 30, 2025
f3c70da
Summary of Shape Simplifications
justinchuby Jun 30, 2025
4b6d80d
Create factory
justinchuby Jun 30, 2025
e03733b
Use Enum
justinchuby Jun 30, 2025
5a34891
Update logging calls
justinchuby Jun 30, 2025
ab09107
Working on engine
justinchuby Jun 30, 2025
9256233
todo
justinchuby Jun 30, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ classifiers = [
"Programming Language :: Python :: 3.13",
"License :: OSI Approved :: Apache Software License",
]
dependencies = ["numpy", "onnx>=1.16", "typing_extensions>=4.10", "ml_dtypes"]
dependencies = ["numpy", "onnx>=1.16", "typing_extensions>=4.10", "ml_dtypes", "sympy"]

[project.urls]
Homepage = "https://onnx.ai/ir-py"
Expand Down
37 changes: 31 additions & 6 deletions src/onnx_ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@

import ml_dtypes
import numpy as np
import sympy
import sympy.utilities.misc
from typing_extensions import TypeIs

import onnx_ir
Expand Down Expand Up @@ -1115,13 +1117,14 @@
It is immutable and can be compared or hashed.
"""

__slots__ = ("_value",)
__slots__ = ("_expr", "_value")

def __init__(self, value: str | None) -> None:
def __init__(self, value: str | None, /, expr: sympy.Expr | None = None) -> None:
"""Initialize a symbolic dimension.

Args:
value: The value of the dimension. It should not be an int.
expr: An optional sympy expression representing the dimension.

Raises:
TypeError: If value is an int.
Expand All @@ -1132,6 +1135,7 @@
"If you are creating a Shape, use int directly instead of SymbolicDim."
)
self._value = value
self._expr: sympy.Expr | None = expr

def __eq__(self, other: object) -> bool:
"""Check equality with another SymbolicDim or string/None."""
Expand All @@ -1148,11 +1152,24 @@
"""The value of the symbolic dimension (string or None)."""
return self._value

@property
def expr(self) -> sympy.Expr | None:
"""The sympy expression representing the symbolic dimension."""
return self._expr

Check warning on line 1158 in src/onnx_ir/_core.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_core.py#L1158

Added line #L1158 was not covered by tests

def __str__(self) -> str:
return f"{self._value}"
if self._value is not None:
return str(self._value)

Check warning on line 1162 in src/onnx_ir/_core.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_core.py#L1162

Added line #L1162 was not covered by tests
if self._expr is not None:
return str(self._expr)
return "?"

Check warning on line 1165 in src/onnx_ir/_core.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_core.py#L1164-L1165

Added lines #L1164 - L1165 were not covered by tests

def __repr__(self) -> str:
return f"{self.__class__.__name__}({self._value})"
if self._expr is not None:
expr_text = f", expr={self._expr!r}"

Check warning on line 1169 in src/onnx_ir/_core.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_core.py#L1169

Added line #L1169 was not covered by tests
else:
expr_text = ""
return f"{self.__class__.__name__}({self._value}{expr_text})"

Check warning on line 1172 in src/onnx_ir/_core.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_core.py#L1171-L1172

Added lines #L1171 - L1172 were not covered by tests


def _is_int_compatible(value: object) -> TypeIs[SupportsInt]:
Expand Down Expand Up @@ -1190,10 +1207,16 @@
return SymbolicDim(dim)
if _is_int_compatible(dim):
return int(dim)
if isinstance(dim, sympy.Expr):
# If the dimension is a sympy expression, we create a SymbolicDim with it
expr = sympy.sympify(dim)

Check warning on line 1212 in src/onnx_ir/_core.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_core.py#L1212

Added line #L1212 was not covered by tests
if expr.is_integer:
return sympy.utilities.misc.as_int(expr)
return SymbolicDim(str(expr), expr=sympy.sympify(expr))

Check warning on line 1215 in src/onnx_ir/_core.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_core.py#L1214-L1215

Added lines #L1214 - L1215 were not covered by tests
if isinstance(dim, SymbolicDim):
return dim
raise TypeError(
f"Expected int, str, None or SymbolicDim, but value {dim!r} has type '{type(dim)}'"
f"Expected int, str, sympy.Expr, None or SymbolicDim, but value {dim!r} has type '{type(dim)}'"
)


Expand Down Expand Up @@ -1334,7 +1357,9 @@
def __getitem__(self, index):
return tuple(self._dims)[index]

def __setitem__(self, index: int, value: int | SymbolicDim | str | None) -> None:
def __setitem__(
self, index: int, value: int | SymbolicDim | str | sympy.Expr | None
) -> None:
"""Set the dimension at the index.

Args:
Expand Down
15 changes: 15 additions & 0 deletions src/onnx_ir/_shape_type_inference/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Symbolic shape and type inference for ONNX IR."""

__all__ = [
"SymbolicInferenceEngine",
"InferenceError",
"NodeInferrer",
"InferenceResult",
]


from onnx_ir._shape_type_inference._common import InferenceResult, NodeInferrer
from onnx_ir._shape_type_inference._engine import (
InferenceError,
SymbolicInferenceEngine,
)
159 changes: 159 additions & 0 deletions src/onnx_ir/_shape_type_inference/_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
"""Symbolic shape inference for ONNX IR."""

from __future__ import annotations

import abc
import dataclasses
import functools
from collections.abc import Collection, Sequence
from typing import Any, Callable

import sympy

import onnx_ir as ir


MAX_SUPPORTED_OPSET = 23


def get_expr(shape: ir.Shape, index: int) -> sympy.Expr:
"""Get the expression or value at a specific index in the shape.

Args:
shape: The shape to get the expression from.
index: The index of the dimension to get.

Returns:
The expression or value at the specified index.
"""
dim = shape[index]

Check warning on line 29 in src/onnx_ir/_shape_type_inference/_common.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_shape_type_inference/_common.py#L29

Added line #L29 was not covered by tests
if isinstance(dim, ir.SymbolicDim):
if dim.expr is not None:
return dim.expr

Check warning on line 32 in src/onnx_ir/_shape_type_inference/_common.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_shape_type_inference/_common.py#L32

Added line #L32 was not covered by tests
if dim.value is None:
return sympy.Symbol("__unknown__")
return sympy.Symbol(dim.value)
return sympy.Integer(dim)

Check warning on line 36 in src/onnx_ir/_shape_type_inference/_common.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_shape_type_inference/_common.py#L34-L36

Added lines #L34 - L36 were not covered by tests


@dataclasses.dataclass
class InferenceResult:
values: Sequence[ir.Value] | None = None
failure: str | None = None


class NodeInferrer(abc.ABC):
"""Base class for node inferrers.

This class provides a common interface for all node inferrers.
"""

def __init__(self, op_type: str, opsets: Collection[int], domain: str = "") -> None:
"""Initialize the node inferrer.

Args:
op_type: The type of the operation.
opsets: A collection of ONNX opset versions supported by this inferrer.
domain: The domain of the operation, default is an empty string.
"""
self.op_type = op_type
self.opsets = opsets
self.domain = domain

Check warning on line 61 in src/onnx_ir/_shape_type_inference/_common.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_shape_type_inference/_common.py#L59-L61

Added lines #L59 - L61 were not covered by tests

def __repr__(self) -> str:
"""Return a string representation of the node inferrer."""
return f"{self.__class__.__name__}(op_type={self.op_type}, opsets={self.opsets}, domain={self.domain})"

Check warning on line 65 in src/onnx_ir/_shape_type_inference/_common.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_shape_type_inference/_common.py#L65

Added line #L65 was not covered by tests

@abc.abstractmethod
def infer(self, node: ir.Node) -> InferenceResult:
"""Infer the shape for the node.

Args:
node: The ONNX node to infer the type and shape for.

Returns:
A sequence of ONNX values containing the inferred shapes.
"""
raise NotImplementedError

Check warning on line 77 in src/onnx_ir/_shape_type_inference/_common.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_shape_type_inference/_common.py#L77

Added line #L77 was not covered by tests


def requires_non_none_inputs(
count: int, /
) -> Callable[
[Callable[[Any, ir.Node], InferenceResult]], Callable[[Any, ir.Node], InferenceResult]
]:
"""Ensure that the node has a specific number of non-None inputs.

Args:
count: The exact number of non-None inputs required for the node.

Returns:
A decorator that checks the number of inputs and their non-None status.
"""

def decorator(
func: Callable[[Any, ir.Node], InferenceResult],
) -> Callable[[Any, ir.Node], InferenceResult]:
@functools.wraps(func)
def wrapper(self, node: ir.Node) -> InferenceResult:
if len(node.inputs) != count:
return InferenceResult(

Check warning on line 100 in src/onnx_ir/_shape_type_inference/_common.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_shape_type_inference/_common.py#L100

Added line #L100 was not covered by tests
failure=f"[{node.op_type} must have {count} inputs, got {len(node.inputs)}."
)
for i, inp in enumerate(node.inputs):
if inp is None:
return InferenceResult(failure=f"{node.op_type} input {i} cannot be None.")
return func(self, node)

Check warning on line 106 in src/onnx_ir/_shape_type_inference/_common.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_shape_type_inference/_common.py#L105-L106

Added lines #L105 - L106 were not covered by tests

return wrapper

return decorator


def requires_outputs(
count: int, /
) -> Callable[
[Callable[[Any, ir.Node], InferenceResult]], Callable[[Any, ir.Node], InferenceResult]
]:
"""Ensure that the node has a specific number of outputs.

Args:
count: The exact number of outputs required for the node.

Returns:
A decorator that checks the number of outputs.
"""

def decorator(
func: Callable[[Any, ir.Node], InferenceResult],
) -> Callable[[Any, ir.Node], InferenceResult]:
@functools.wraps(func)
def wrapper(self, node: ir.Node) -> InferenceResult:
if len(node.outputs) != count:
return InferenceResult(

Check warning on line 133 in src/onnx_ir/_shape_type_inference/_common.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_shape_type_inference/_common.py#L133

Added line #L133 was not covered by tests
failure=f"[{node.op_type} must have {count} outputs, got {len(node.outputs)}."
)
return func(self, node)

Check warning on line 136 in src/onnx_ir/_shape_type_inference/_common.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_shape_type_inference/_common.py#L136

Added line #L136 was not covered by tests

return wrapper

return decorator


def inclusive_range(start_or_end: int = 0, end: int | None = None) -> range:
"""Create an inclusive range from start to end with a given step.

Args:
start_or_end: The starting value of the range.
end: The ending value of the range (inclusive).

Returns:
A range object that includes both start and end.
"""
if end is None:
end = start_or_end
start = 0

Check warning on line 155 in src/onnx_ir/_shape_type_inference/_common.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_shape_type_inference/_common.py#L154-L155

Added lines #L154 - L155 were not covered by tests
else:
start = start_or_end

Check warning on line 157 in src/onnx_ir/_shape_type_inference/_common.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_shape_type_inference/_common.py#L157

Added line #L157 was not covered by tests

return range(start, end + 1)

Check warning on line 159 in src/onnx_ir/_shape_type_inference/_common.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_shape_type_inference/_common.py#L159

Added line #L159 was not covered by tests
Loading
Loading