Skip to content

Commit 1b99b08

Browse files
committed
Allow unifying with OpPattern
1 parent 7a1699b commit 1b99b08

File tree

4 files changed

+350
-40
lines changed

4 files changed

+350
-40
lines changed

pytensor/graph/rewriting/basic.py

Lines changed: 116 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from pytensor.graph.features import AlreadyThere, Feature
3030
from pytensor.graph.fg import FunctionGraph, Output
3131
from pytensor.graph.op import Op
32-
from pytensor.graph.rewriting.unify import Var, convert_strs_to_vars
32+
from pytensor.graph.rewriting.unify import OpPattern, Var, convert_strs_to_vars
3333
from pytensor.graph.utils import AssocList, InconsistencyError
3434
from pytensor.misc.ordered_set import OrderedSet
3535
from pytensor.utils import flatten
@@ -1312,6 +1312,7 @@ class PatternNodeRewriter(NodeRewriter):
13121312
The input and output patterns have the following syntax:
13131313
13141314
input_pattern ::= (op, <sub_pattern1>, <sub_pattern2>, ...)
1315+
input_pattern ::= (OpPattern(type(op), {<param>: <value>, ...}), <sub_pattern1>, <sub_pattern2>, ...)
13151316
input_pattern ::= dict(pattern = <input_pattern>,
13161317
constraint = <constraint>)
13171318
sub_pattern ::= input_pattern
@@ -1325,6 +1326,7 @@ class PatternNodeRewriter(NodeRewriter):
13251326
output_pattern ::= string
13261327
output_pattern ::= int
13271328
output_pattern ::= float
1329+
output_pattern ::= callable
13281330
13291331
Each string in the input pattern is a variable that will be set to
13301332
whatever expression is found in its place. If the same string is
@@ -1350,20 +1352,73 @@ class PatternNodeRewriter(NodeRewriter):
13501352
Examples
13511353
--------
13521354
1353-
PatternNodeRewriter((add, 'x', 'y'), (add, 'y', 'x'))
1354-
PatternNodeRewriter((multiply, 'x', 'x'), (square, 'x'))
1355-
PatternNodeRewriter((subtract, (add, 'x', 'y'), 'y'), 'x')
1356-
PatternNodeRewriter((power, 'x', Constant(double, 2.0)), (square, 'x'))
1357-
PatternNodeRewriter((boggle, {'pattern': 'x',
1358-
'constraint': lambda expr: expr.type == scrabble}),
1359-
(scrabble, 'x'))
1355+
.. code-block:: python
13601356
1357+
from pytensor.graph.rewriting.basic import PatternNodeRewriter
1358+
from pytensor.tensor import add, mul, sub, pow, square
1359+
1360+
PatternNodeRewriter((add, "x", "y"), (add, "y", "x"))
1361+
PatternNodeRewriter((mul, "x", "x"), (square, "x"))
1362+
PatternNodeRewriter((sub, (add, "x", "y"), "y"), "x")
1363+
PatternNodeRewriter((pow, "x", 2.0), (square, "x"))
1364+
PatternNodeRewriter(
1365+
(mul, {"pattern": "x", "constraint": lambda expr: expr.ndim == 0}, "y"),
1366+
(mul, "y", "x"),
1367+
)
1368+
1369+
You can use OpPattern to match a subtype of an Op, with some parameter constraints.
1370+
You can also specify a callable as the output pattern, which will be called with (fgraph, node, subs_dict) as arguments.
1371+
1372+
1373+
.. code-block:: python
1374+
1375+
from pytensor.graph.rewriting.basic import PatternNodeRewriter
1376+
from pytensor.graph.rewriting.unify import OpPattern
1377+
from pytensor.tensor.basic import Join
1378+
from pytensor.tensor.elemwise import CAReduce, Elemwise
1379+
1380+
1381+
def output_fn(fgraph, node, s):
1382+
reduce_op = node.op
1383+
reduced_a = reduce_op(s["a"])
1384+
reduced_b = reduce_op(s["b"])
1385+
return Elemwise(s["scalar_op"])(reduced_a, reduced_b)
1386+
1387+
1388+
PatternNodeRewriter(
1389+
(
1390+
OpPattern(CAReduce, scalar_op="scalar_op", axis=None),
1391+
(Join(), "join_axis", "a", "b"),
1392+
),
1393+
output_fn,
1394+
)
1395+
1396+
1397+
If you want to match a parameter against a literal string value, you must wrap it in LiteralString to avoid it being interpreted as a unification variable.
1398+
1399+
.. code-block:: python
1400+
1401+
1402+
from pytensor.graph.rewriting.basic import PatternNodeRewriter
1403+
from pytensor.graph.rewriting.unify import OpPattern, LiteralString
1404+
from pytensor.tensor.blockwise import Blockwise
1405+
from pytensor.tensor.slinalg import Solve
1406+
1407+
PatternNodeRewriter(
1408+
(
1409+
OpPattern(
1410+
Blockwise, core_op=OpPattern(Solve, assume_a=LiteralString("gen"))
1411+
),
1412+
"A",
1413+
"b",
1414+
)
1415+
)
13611416
"""
13621417

13631418
def __init__(
13641419
self,
1365-
in_pattern,
1366-
out_pattern,
1420+
in_pattern: tuple,
1421+
out_pattern: tuple | Callable | str,
13671422
allow_multiple_clients: bool = False,
13681423
name: str | None = None,
13691424
tracks=(),
@@ -1378,7 +1433,8 @@ def __init__(
13781433
in_pattern
13791434
The input pattern that we want to replace.
13801435
out_pattern
1381-
The replacement pattern.
1436+
The replacement pattern, or a callable that takes (fgraph, node, subs_dict) as inputs,
1437+
and returns the replacement variable (or None/False to reject the rewrite).
13821438
allow_multiple_clients
13831439
If ``False``, the pattern matching will fail if one of the subpatterns has
13841440
more than one client.
@@ -1407,26 +1463,40 @@ def __init__(
14071463
self.out_pattern = convert_strs_to_vars(out_pattern, var_map=var_map)
14081464
self.values_eq_approx = values_eq_approx
14091465
self.allow_cast = allow_cast
1410-
if isinstance(in_pattern, list | tuple):
1411-
self.op = self.in_pattern[0]
1412-
elif isinstance(in_pattern, dict):
1413-
self.op = self.in_pattern["pattern"][0]
1414-
else:
1415-
raise TypeError(
1416-
"The pattern to search for must start with a specific Op instance."
1417-
)
14181466
self.allow_multiple_clients = allow_multiple_clients
14191467
if name:
14201468
self.__name__ = name
1421-
self._tracks = tracks
14221469
self.get_nodes = get_nodes
14231470
if tracks != ():
1424-
assert get_nodes
1471+
if not get_nodes:
1472+
raise ValueError("Custom `tracks` requires `get_nodes` to be provided.")
1473+
self._tracks = tracks
1474+
else:
1475+
if isinstance(in_pattern, list | tuple):
1476+
op = self.in_pattern[0]
1477+
elif isinstance(in_pattern, dict):
1478+
op = self.in_pattern["pattern"][0]
1479+
else:
1480+
raise TypeError(
1481+
f"The in_pattern must be a sequence or a dict, but got {in_pattern} of type {type(in_pattern)}"
1482+
)
1483+
if isinstance(op, Op):
1484+
self._tracks = [op]
1485+
elif isinstance(op, type) and issubclass(op, Op):
1486+
raise ValueError(
1487+
f"The in_pattern starts with an Op class {op}, not an instance.\n"
1488+
"You can use pytensor.graph.unify.OpPattern instead if you want to match instances of a class."
1489+
)
1490+
elif isinstance(op, OpPattern):
1491+
self._tracks = [op.op_type]
1492+
else:
1493+
raise ValueError(
1494+
f"The in_pattern must start with a specific Op or an OpPattern instance. "
1495+
f"Got {op}, with type {type(op)}."
1496+
)
14251497

14261498
def tracks(self):
1427-
if self._tracks != ():
1428-
return self._tracks
1429-
return [self.op]
1499+
return self._tracks
14301500

14311501
def transform(self, fgraph, node, get_nodes=True):
14321502
"""Check if the graph from node corresponds to ``in_pattern``.
@@ -1447,28 +1517,39 @@ def transform(self, fgraph, node, get_nodes=True):
14471517
# PatternNodeRewriter doesn't support replacing multi-output nodes
14481518
return False
14491519

1450-
s = unify(self.in_pattern, node.out)
1520+
s = unify(self.in_pattern, node.out, {})
14511521

14521522
if s is False:
14531523
return False
14541524

1455-
ret = reify(self.out_pattern, s)
1456-
1457-
if isinstance(ret, ExpressionTuple):
1458-
ret = ret.evaled_obj
1459-
1460-
if self.values_eq_approx:
1461-
ret.tag.values_eq_approx = self.values_eq_approx
1462-
14631525
if not self.allow_multiple_clients:
1464-
input_vars = list(s.values())
1526+
input_vars = set(s.values())
1527+
clients = fgraph.clients
14651528
if any(
1466-
len(fgraph.clients[v]) > 1
1529+
len(clients[v]) > 1
14671530
for v in vars_between(input_vars, node.inputs)
14681531
if v not in input_vars
14691532
):
14701533
return False
14711534

1535+
if callable(self.out_pattern):
1536+
# token is the variable name used in the original pattern
1537+
ret = self.out_pattern(fgraph, node, {k.token: v for k, v in s.items()})
1538+
if ret is None or ret is False:
1539+
# The output function is still allowed to reject the rewrite
1540+
return False
1541+
if not isinstance(ret, Variable):
1542+
raise ValueError(
1543+
f"The output of the PatternNodeRewriter callable must be a variable got {ret} of type {type(ret)}."
1544+
)
1545+
else:
1546+
ret = reify(self.out_pattern, s)
1547+
if isinstance(ret, ExpressionTuple):
1548+
ret = ret.evaled_obj
1549+
1550+
if self.values_eq_approx:
1551+
ret.tag.values_eq_approx = self.values_eq_approx
1552+
14721553
[old_out] = node.outputs
14731554
if not old_out.type.is_super(ret.type):
14741555
from pytensor.tensor.type import TensorType

0 commit comments

Comments
 (0)