Implement OpPattern for more flexible tracks and PatternNodeRewriter #1594
pytensor/graph/rewriting/unify.py:

@@ -10,8 +10,11 @@
 """

-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
+from dataclasses import dataclass
 from numbers import Number
+from types import UnionType
+from typing import Any, TypeAlias

 import numpy as np
 from cons.core import ConsError, _car, _cdr
@@ -254,6 +257,200 @@ def _unify_ConstrainedVar_object(u, v, s):
_unify.add((object, ConstrainedVar, Mapping), _unify_ConstrainedVar_object)


@dataclass(frozen=True)
class LiteralString:
    value: str


OpPatternOpTypeType: TypeAlias = type[Op] | tuple[type[Op], ...] | UnionType


@dataclass(unsafe_hash=True)
class OpPattern:
"""Class that can be unified with Op instances of a given type (or instance) and parameters. | ||
|
||
Parameters that are not specified in the OpPattern are ignored during unification. | ||
|
||
This is needed because some Ops can be complex to parametrize fully, | ||
and not all parameters are relevant for a given pattern. | ||
|
||
|
||
Examples | ||
-------- | ||
|
||
OpPattern can be used in the `tracks` functionality of `node_rewriter` to more flexible filter out nodes. | ||
For Ops that are parametrized by other Ops, it's possible to use nested OpPatterns. | ||
|
||
.. test-code:: | ||
        from pytensor.graph.rewriting.basic import node_rewriter
        from pytensor.graph.rewriting.unify import OpPattern
        from pytensor.tensor.elemwise import CAReduce
        from pytensor.tensor.blockwise import Blockwise
        from pytensor.tensor.slinalg import Solve

        @node_rewriter(tracks=[OpPattern(CAReduce, axis=None)])
        def local_car_reduce_all_rewriter(fgraph, node):
            # This will always be true!
            assert isinstance(node.op, CAReduce) and node.op.axis is None
            ...

        # Any Blockwise whose core_op is a Solve Op (or subclass) instance
        @node_rewriter(tracks=[OpPattern(Blockwise, core_op=OpPattern(Solve))])
        def local_blockwise_solve_triangular_rewriter(fgraph, node):
            # This will always be true!
            assert isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Solve)
            ...

        # Any Blockwise whose core_op is a Solve Op (or subclass) instance with b_ndim == 1
        @node_rewriter(tracks=[OpPattern(Blockwise, core_op=OpPattern(Solve, b_ndim=1))])
        def local_blockwise_vector_solve_rewriter(fgraph, node):
            # This will always be true!
            assert (
                isinstance(node.op, Blockwise)
                and isinstance(node.op.core_op, Solve)
                and node.op.core_op.b_ndim == 1
            )
            ...
    OpPattern can be used with `PatternNodeRewriter` to define graph rewrites that match Ops with specific parameters.
    The example below matches two nested CAReduce Ops with the same `scalar_op`,
    the outer one with `axis=None` (full reduction), and fuses them into a single CAReduce.
    Note that, because we did not specify it, the axis of the inner CAReduce can be anything.
    The same goes for other properties of the Op that are not specified in the OpPattern.

    .. testcode::

        from pytensor.graph.rewriting.basic import PatternNodeRewriter
        from pytensor.graph.rewriting.unify import OpPattern
        from pytensor.tensor.elemwise import CAReduce

        PatternNodeRewriter(
            in_pattern=(
                OpPattern(CAReduce, scalar_op="scalar_op", axis=None),
                (OpPattern(CAReduce, scalar_op="scalar_op"), "x"),
            ),
            out_pattern=(OpPattern(CAReduce, scalar_op="scalar_op", axis=None), "x"),
        )
    OpPattern can also be used with `unification.unify` to match Ops with specific parameters.
    This is used by PatternNodeRewriter, but it can also be used directly.

    .. testcode::

        from unification import var, unify
        from etuples import etuple

        import pytensor.tensor as pt
        from pytensor.graph.rewriting.unify import OpPattern
        from pytensor.tensor.blockwise import Blockwise
        from pytensor.tensor.slinalg import Solve

        A = var("A")
        b = var("b")
        pattern = etuple(
            OpPattern(Blockwise, core_op=OpPattern(Solve, assume_a="gen")),
            A,
            b,
        )

        A_pt = pt.tensor3("A")
        b_pt = pt.tensor3("b")
        out1 = pt.linalg.solve(A_pt, b_pt)
        out2 = pt.linalg.solve(A_pt, b_pt, assume_a="pos")

        assert unify(pattern, out1) == {A: A_pt, b: b_pt}
        assert unify(pattern, out2) is False

        assume_a = var("assume_a")
        pattern = etuple(
            OpPattern(Blockwise, core_op=OpPattern(Solve, assume_a=assume_a)),
            A,
            b,
        )
        assert unify(pattern, out1) == {A: A_pt, b: b_pt, assume_a: "gen"}
        assert unify(pattern, out2) == {A: A_pt, b: b_pt, assume_a: "pos"}
    """
    op_type: OpPatternOpTypeType
    parameters: tuple[tuple[str, Any], ...]

    def __init__(
        self,
        op_type: OpPatternOpTypeType,
        parameters: dict[str, Any] | Sequence[tuple[str, Any]] | None = None,
        **kwargs,
    ):
        if kwargs:
            if parameters is not None:
                raise ValueError(
                    "Cannot provide both parameters dict and keyword arguments"
                )
            parameters = kwargs
        if isinstance(parameters, dict):
            parameters = tuple(sorted(parameters.items()))
        elif isinstance(parameters, list | tuple):
            parameters = tuple(sorted(parameters))
        elif parameters is None:
            parameters = ()
        self.op_type = op_type
        self.parameters = parameters  # type: ignore[assignment]

[Review comment on the `parameters` argument: "parameters is not of type parameters 🤔"]
    def match_op(self, op: Op):
        if not isinstance(op, self.op_type):
            return False
        return self.match_parameters(op)

    def match_parameters(self, op):
        # This is used by methods that have already checked that op_type is satisfied.
        # Some methods may index on the op_type and know in advance that the op matches.
        # Recursive calls to OpPattern.match_parameters likewise do the op check outside,
        # so they can exit early (see below).
        for key, param in self.parameters:
            if isinstance(param, OpPattern):
                # Parameters can themselves be other OpPatterns.
                # We check the op_type first, to reject early without a nested call.
                sub_op = getattr(op, key)
                if not isinstance(sub_op, param.op_type):
                    return False
                # Match the pattern of the inner Op.
                # Skip if there are no parameters.
                if param.parameters and not param.match_parameters(sub_op):
                    return False
            elif getattr(op, key) != param:
                return False
        return True

    def __str__(self):
        return f"OpPattern({self.op_type}, {', '.join(f'{k}={v}' for k, v in self.parameters)})"


def _unify_parametrized_op(v: Op, u: OpPattern, s: Mapping):
    if not isinstance(v, u.op_type):
        yield False
        return
    for parameter_key, parameter_pattern in u.parameters:
        parameter_value = getattr(v, parameter_key)
        new_s = yield _unify(parameter_value, parameter_pattern, s)
        if new_s is False:
            yield False
            return
        s = new_s
    yield s


_unify.add((Op, OpPattern, Mapping), _unify_parametrized_op)


def convert_strs_to_vars(
    x: tuple | str | dict, var_map: dict[str, Var] | None = None
) -> ExpressionTuple | Var:
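As a quick illustration of the matching helpers above (a minimal sketch of my own, not part of the diff; it assumes `pt.linalg.solve` builds a Blockwise node with a Solve core_op and infers `b_ndim=1` from a vector right-hand side, as in the docstring examples):

    import pytensor.tensor as pt
    from pytensor.graph.rewriting.unify import OpPattern
    from pytensor.tensor.blockwise import Blockwise
    from pytensor.tensor.slinalg import Solve

    A = pt.matrix("A")
    b = pt.vector("b")
    node = pt.linalg.solve(A, b).owner  # Blockwise wrapping a Solve instance

    # Keyword and dict spellings normalize to the same sorted parameter tuple
    p1 = OpPattern(Solve, assume_a="gen", b_ndim=1)
    p2 = OpPattern(Solve, {"b_ndim": 1, "assume_a": "gen"})
    assert p1 == p2

    # match_op runs the isinstance check, then walks only the listed parameters;
    # the nested OpPattern is itself isinstance-checked first, for an early exit
    assert OpPattern(Blockwise, core_op=OpPattern(Solve, b_ndim=1)).match_op(node.op)
    # Parameters not named in the pattern (assume_a, lower, ...) are never inspected
    assert OpPattern(Blockwise, core_op=OpPattern(Solve)).match_op(node.op)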
@@ -266,11 +463,13 @@ def convert_strs_to_vars(
     if var_map is None:
         var_map = {}

-    def _convert(y):
+    def _convert(y, op_prop=False):
         if isinstance(y, str):
             v = var_map.get(y, var(y))
             var_map[y] = v
             return v
+        if isinstance(y, LiteralString):
+            return y.value
         elif isinstance(y, dict):
             pattern = y["pattern"]
             if not isinstance(pattern, str):

@@ -282,8 +481,14 @@ def _convert(y):
             var_map[pattern] = v
             return v
         elif isinstance(y, tuple):
-            return etuple(*(_convert(e) for e in y))
-        elif isinstance(y, Number | np.ndarray):
+            return etuple(*(_convert(e, op_prop=op_prop) for e in y))
+        elif isinstance(y, OpPattern):
+            return OpPattern(
+                y.op_type,
+                {k: _convert(v, op_prop=True) for k, v in y.parameters},
+            )
+        elif (not op_prop) and isinstance(y, Number | np.ndarray):
+            # If we are converting an Op property, we don't want to convert numbers to PyTensor constants
             from pytensor.tensor import as_tensor_variable

             return as_tensor_variable(y)
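Why LiteralString exists (again a sketch of my own, not from the diff): inside a rewrite pattern every bare string is converted into a logic variable, so a string-valued Op parameter has to be wrapped to survive the conversion as a plain string:

    from pytensor.graph.rewriting.unify import LiteralString, OpPattern, convert_strs_to_vars
    from pytensor.tensor.blockwise import Blockwise
    from pytensor.tensor.slinalg import Solve

    # "x" becomes a logic variable, while LiteralString("gen") is unwrapped back
    # into the plain string "gen" for the assume_a parameter of the inner pattern
    pattern = convert_strs_to_vars(
        (OpPattern(Blockwise, core_op=OpPattern(Solve, assume_a=LiteralString("gen"))), "x")
    )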
pytensor/scalar/basic.py:

@@ -1228,6 +1228,8 @@ def __init__(self, output_types_preference=None, name=None):
                     f"(got: {output_types_preference})"
                 )
             self.output_types_preference = output_types_preference
+        elif not hasattr(self, "output_types_preference"):
+            self.output_types_preference = None

     def make_node(self, *inputs):
         if self.nin >= 0:
@@ -1247,7 +1249,7 @@ def make_node(self, *inputs):
         return Apply(self, inputs, outputs)

     def output_types(self, types):
-        if hasattr(self, "output_types_preference"):
+        if self.output_types_preference is not None:
             variables = self.output_types_preference(*types)
             if not isinstance(variables, list | tuple) or any(
                 not isinstance(x, CType) for x in variables
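The point of this change, as I read it (sketched with hypothetical minimal classes, not the real ScalarOp hierarchy): `hasattr` was true even when the attribute was a plain method on a subclass, so "no preference configured" was indistinguishable from "preference set"; initializing the attribute in `__init__` makes the check explicit.

    # Hypothetical sketch of the before/after semantics
    class Before:
        def output_types_preference(self, *types):  # a method, not a configured function
            ...

        def has_preference(self):
            # Always True here, even though nothing was configured
            return hasattr(self, "output_types_preference")

    class After:
        output_types_preference = None  # base __init__ now guarantees the attribute exists

        def has_preference(self):
            # Explicit opt-in: only True when a preference function was actually set
            return self.output_types_preference is not None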
@@ -2696,7 +2698,7 @@ class Sign(UnaryScalarOp):
     nfunc_spec = ("sign", 1, 1)

     @staticmethod
-    def output_types_preference(x):
+    def _output_types_preference(x):
         if x == bool:
             raise TypeError(x)
         return same_out_nocomplex(x)

Review discussion on this rename:
- "Since it's a method in this case, not a member variable."
- "This is what init of the base class will look for if I don't pass a function."
- "I don't get it."
- "Check for …"
@@ -2737,7 +2739,7 @@ def c_code_cache_version(self):
         return s


-sign = Sign(name="sign")
+sign = Sign(name="sign", output_types_preference=Sign._output_types_preference)


class Ceil(UnaryScalarOp):
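To see the new tracks end to end, here is a small driver (my sketch; it assumes OpPattern tracks are consumed by the standard walking rewriters, which is what this PR wires up):

    import pytensor.tensor as pt
    from pytensor.graph import FunctionGraph
    from pytensor.graph.rewriting.basic import in2out, node_rewriter
    from pytensor.graph.rewriting.unify import OpPattern
    from pytensor.tensor.elemwise import CAReduce

    @node_rewriter(tracks=[OpPattern(CAReduce, axis=None)])
    def note_full_reductions(fgraph, node):
        # Only called for CAReduce nodes with axis=None, per the pattern
        print(f"full reduction: {node.op}")
        return None  # observe only, rewrite nothing

    x = pt.matrix("x")
    fg = FunctionGraph([x], [x.sum()], clone=False)  # x.sum() is a full CAReduce
    in2out(note_full_reductions).rewrite(fg)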