-
Notifications
You must be signed in to change notification settings - Fork 14
[WIP] Create symbolic type/shape inference logic #117
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
justinchuby
wants to merge
31
commits into
main
Choose a base branch
from
justinchu/symbolic-inference-claude
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from 20 commits
Commits
Show all changes
31 commits
Select commit
Hold shift + click to select a range
1adca71
Scaffold
justinchuby f48537b
Add support for expr in SymbolicDim
justinchuby 7464af1
wip
justinchuby 27ae0ca
wip
justinchuby 1afe64a
Create NodeInferencer
justinchuby 78ad6e0
inference_common
justinchuby 5aa2df7
Update shapes
justinchuby dbc3593
update
justinchuby b9f0528
Claude - add sympy import
justinchuby c9a35b7
Claude and lint
justinchuby 65e3dd2
concat
justinchuby 7960770
Update _maybe_convert_to_symbolic_dim
justinchuby a7704c5
reshape
justinchuby 922a597
Update the way dim is set
justinchuby 9183848
Simplify
justinchuby 9300aba
Update
justinchuby 8747a93
Handle unknown dims
justinchuby 92049c4
Simplify
justinchuby 720845e
Create inclusive range
justinchuby bae78ab
WIP inference engine
justinchuby a77f487
Create readme
justinchuby 6686457
Result
justinchuby 3207e84
Summary of Complete Refactoring
justinchuby a572145
lint
justinchuby 11f8958
Removes unused shape inference code
justinchuby f3c70da
Summary of Shape Simplifications
justinchuby 4b6d80d
Create factory
justinchuby e03733b
Use Enum
justinchuby 5a34891
Update logging calls
justinchuby ab09107
Working on engine
justinchuby 9256233
todo
justinchuby File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
"""Symbolic shape and type inference for ONNX IR.""" | ||
|
||
__all__ = [ | ||
"SymbolicInferenceEngine", | ||
"InferenceError", | ||
"NodeInferrer", | ||
"InferenceResult", | ||
] | ||
|
||
|
||
from onnx_ir._shape_type_inference._common import InferenceResult, NodeInferrer | ||
from onnx_ir._shape_type_inference._engine import ( | ||
InferenceError, | ||
SymbolicInferenceEngine, | ||
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
"""Symbolic shape inference for ONNX IR.""" | ||
|
||
|
||
from __future__ import annotations | ||
|
||
|
||
import abc | ||
import dataclasses | ||
import functools | ||
from collections.abc import Collection, Sequence | ||
from typing import Any, Callable | ||
|
||
import sympy | ||
|
||
import onnx_ir as ir | ||
|
||
|
||
MAX_SUPPORTED_OPSET = 23 | ||
|
||
|
||
def get_expr(shape: ir.Shape, index: int) -> sympy.Expr: | ||
"""Get the expression or value at a specific index in the shape. | ||
|
||
Args: | ||
shape: The shape to get the expression from. | ||
index: The index of the dimension to get. | ||
|
||
Returns: | ||
The expression or value at the specified index. | ||
""" | ||
dim = shape[index] | ||
if isinstance(dim, ir.SymbolicDim): | ||
if dim.expr is not None: | ||
return dim.expr | ||
if dim.value is None: | ||
return sympy.Symbol("__unknown__") | ||
return sympy.Symbol(dim.value) | ||
return sympy.Integer(dim) | ||
|
||
|
||
@dataclasses.dataclass | ||
class InferenceResult: | ||
values: Sequence[ir.Value] | None = None | ||
failure: str | None = None | ||
|
||
|
||
class NodeInferrer(abc.ABC): | ||
"""Base class for node inferrers. | ||
|
||
This class provides a common interface for all node inferrers. | ||
""" | ||
|
||
def __init__(self, op_type: str, opsets: Collection[int], domain: str = "") -> None: | ||
"""Initialize the node inferrer. | ||
|
||
Args: | ||
op_type: The type of the operation. | ||
opsets: A collection of ONNX opset versions supported by this inferrer. | ||
domain: The domain of the operation, default is an empty string. | ||
""" | ||
self.op_type = op_type | ||
self.opsets = opsets | ||
self.domain = domain | ||
|
||
def __repr__(self) -> str: | ||
"""Return a string representation of the node inferrer.""" | ||
return f"{self.__class__.__name__}(op_type={self.op_type}, opsets={self.opsets}, domain={self.domain})" | ||
|
||
@abc.abstractmethod | ||
def infer(self, node: ir.Node) -> InferenceResult: | ||
"""Infer the shape for the node. | ||
|
||
Args: | ||
node: The ONNX node to infer the type and shape for. | ||
|
||
Returns: | ||
A sequence of ONNX values containing the inferred shapes. | ||
""" | ||
raise NotImplementedError | ||
|
||
|
||
def requires_non_none_inputs( | ||
count: int, / | ||
) -> Callable[ | ||
[Callable[[Any, ir.Node], InferenceResult]], Callable[[Any, ir.Node], InferenceResult] | ||
]: | ||
"""Ensure that the node has a specific number of non-None inputs. | ||
|
||
Args: | ||
count: The exact number of non-None inputs required for the node. | ||
|
||
Returns: | ||
A decorator that checks the number of inputs and their non-None status. | ||
""" | ||
|
||
def decorator( | ||
func: Callable[[Any, ir.Node], InferenceResult], | ||
) -> Callable[[Any, ir.Node], InferenceResult]: | ||
@functools.wraps(func) | ||
def wrapper(self, node: ir.Node) -> InferenceResult: | ||
if len(node.inputs) != count: | ||
return InferenceResult( | ||
failure=f"[{node.op_type} must have {count} inputs, got {len(node.inputs)}." | ||
) | ||
for i, inp in enumerate(node.inputs): | ||
if inp is None: | ||
return InferenceResult(failure=f"{node.op_type} input {i} cannot be None.") | ||
return func(self, node) | ||
|
||
return wrapper | ||
|
||
return decorator | ||
|
||
|
||
def requires_outputs( | ||
count: int, / | ||
) -> Callable[ | ||
[Callable[[Any, ir.Node], InferenceResult]], Callable[[Any, ir.Node], InferenceResult] | ||
]: | ||
"""Ensure that the node has a specific number of outputs. | ||
|
||
Args: | ||
count: The exact number of outputs required for the node. | ||
|
||
Returns: | ||
A decorator that checks the number of outputs. | ||
""" | ||
|
||
def decorator( | ||
func: Callable[[Any, ir.Node], InferenceResult], | ||
) -> Callable[[Any, ir.Node], InferenceResult]: | ||
@functools.wraps(func) | ||
def wrapper(self, node: ir.Node) -> InferenceResult: | ||
if len(node.outputs) != count: | ||
return InferenceResult( | ||
failure=f"[{node.op_type} must have {count} outputs, got {len(node.outputs)}." | ||
) | ||
return func(self, node) | ||
|
||
return wrapper | ||
|
||
return decorator | ||
|
||
|
||
def inclusive_range(start_or_end: int = 0, end: int | None = None) -> range: | ||
"""Create an inclusive range from start to end with a given step. | ||
|
||
Args: | ||
start_or_end: The starting value of the range. | ||
end: The ending value of the range (inclusive). | ||
|
||
Returns: | ||
A range object that includes both start and end. | ||
""" | ||
if end is None: | ||
end = start_or_end | ||
start = 0 | ||
else: | ||
start = start_or_end | ||
|
||
return range(start, end + 1) | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.