
Commit c103e75

Allow unifying with Op classes
1 parent c41e73e commit c103e75
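
In short, this commit lets unification patterns match an Op by its type and a subset of its parameters (the new OpInstance helper), and lets PatternNodeRewriter build its replacement through a callable. A condensed sketch of the headline capability, mirroring the doctest added to pytensor/graph/rewriting/unify.py below:

    from etuples import etuple
    from unification import unify, var

    import pytensor.tensor as pt
    from pytensor.graph.rewriting.unify import OpInstance
    from pytensor.tensor.blockwise import Blockwise
    from pytensor.tensor.slinalg import Solve

    A, b = var("A"), var("b")
    # Match any Blockwise whose core op is a Solve with assume_a="gen";
    # parameters not listed in the OpInstance are ignored during unification.
    pattern = etuple(
        OpInstance(Blockwise, core_op=OpInstance(Solve, assume_a="gen")), A, b
    )

    A_pt = pt.tensor3("A")
    b_pt = pt.tensor3("b")
    assert unify(pattern, pt.linalg.solve(A_pt, b_pt)) == {A: A_pt, b: b_pt}
    assert unify(pattern, pt.linalg.solve(A_pt, b_pt, assume_a="pos")) is False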

4 files changed: +285 -40 lines changed

pytensor/graph/rewriting/basic.py

Lines changed: 110 additions & 35 deletions
@@ -29,7 +29,7 @@
from pytensor.graph.features import AlreadyThere, Feature
from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.op import Op
-from pytensor.graph.rewriting.unify import Var, convert_strs_to_vars
+from pytensor.graph.rewriting.unify import OpInstance, Var, convert_strs_to_vars
from pytensor.graph.utils import AssocList, InconsistencyError
from pytensor.misc.ordered_set import OrderedSet
from pytensor.utils import flatten
@@ -1320,6 +1320,7 @@ class PatternNodeRewriter(NodeRewriter):
    The input and output patterns have the following syntax:

        input_pattern ::= (op, <sub_pattern1>, <sub_pattern2>, ...)
+        input_pattern ::= (OpInstance(type(op), {<param>: <value>, ...}), <sub_pattern1>, <sub_pattern2>, ...)
        input_pattern ::= dict(pattern = <input_pattern>,
                               constraint = <constraint>)
        sub_pattern ::= input_pattern
@@ -1333,6 +1334,7 @@ class PatternNodeRewriter(NodeRewriter):
        output_pattern ::= string
        output_pattern ::= int
        output_pattern ::= float
+        output_pattern ::= callable

    Each string in the input pattern is a variable that will be set to
    whatever expression is found in its place. If the same string is
@@ -1358,20 +1360,73 @@ class PatternNodeRewriter(NodeRewriter):
    Examples
    --------

-    PatternNodeRewriter((add, 'x', 'y'), (add, 'y', 'x'))
-    PatternNodeRewriter((multiply, 'x', 'x'), (square, 'x'))
-    PatternNodeRewriter((subtract, (add, 'x', 'y'), 'y'), 'x')
-    PatternNodeRewriter((power, 'x', Constant(double, 2.0)), (square, 'x'))
-    PatternNodeRewriter((boggle, {'pattern': 'x',
-                                  'constraint': lambda expr: expr.type == scrabble}),
-                        (scrabble, 'x'))
+    .. code-block:: python

+        from pytensor.graph.rewriting.basic import PatternNodeRewriter
+        from pytensor.tensor import add, mul, sub, pow, square
+
+        PatternNodeRewriter((add, "x", "y"), (add, "y", "x"))
+        PatternNodeRewriter((mul, "x", "x"), (square, "x"))
+        PatternNodeRewriter((sub, (add, "x", "y"), "y"), "x")
+        PatternNodeRewriter((pow, "x", 2.0), (square, "x"))
+        PatternNodeRewriter(
+            (mul, {"pattern": "x", "constraint": lambda expr: expr.ndim == 0}, "y"),
+            (mul, "y", "x"),
+        )
+
+    You can use OpInstance to match a subtype of an Op, with some parameter constraints
+    You can also specify a callable as the output pattern, which will be called with (fgraph, node, subs_dict) as arguments.
+
+
+    .. code-block:: python
+
+        from pytensor.graph.rewriting.basic import PatternNodeRewriter
+        from pytensor.graph.rewriting.unify import OpInstance
+        from pytensor.tensor.basic import Join
+        from pytensor.tensor.elemwise import CAReduce, Elemwise
+
+
+        def output_fn(fgraph, node, s):
+            reduce_op = node.op
+            reduced_a = reduce_op(s["a"])
+            reduced_b = reduce_op(s["b"])
+            return Elemwise(s["scalar_op"])(reduced_a, reduced_b)
+
+
+        PatternNodeRewriter(
+            (
+                OpInstance(CAReduce, scalar_op="scalar_op", axis=None),
+                (Join(), "join_axis", "a", "b"),
+            ),
+            output_fn,
+        )
+
+
+    If you want to test a string parameter, you must use LiteralString to avoid it being interpreted as a unification variable.
+
+    .. code-block:: python
+
+
+        from pytensor.graph.rewriting.basic import PatternNodeRewriter
+        from pytensor.graph.rewriting.unify import OpInstance, LiteralString
+        from pytensor.tensor.blockwise import Blockwise
+        from pytensor.tensor.slinalg import Solve
+
+        PatternNodeRewriter(
+            (
+                OpInstance(
+                    Blockwise, core_op=OpInstance(Solve, assume_a=LiteralString("gen"))
+                ),
+                "A",
+                "b",
+            )
+        )
    """

    def __init__(
        self,
-        in_pattern,
-        out_pattern,
+        in_pattern: tuple,
+        out_pattern: tuple | Callable,
        allow_multiple_clients: bool = False,
        name: str | None = None,
        tracks=(),
@@ -1386,7 +1441,7 @@ def __init__(
        in_pattern
            The input pattern that we want to replace.
        out_pattern
-            The replacement pattern.
+            The replacement pattern. Or a callable that takes (fgraph, node, subs_dict) as inputs
        allow_multiple_clients
            If ``False``, the pattern matching will fail if one of the subpatterns has
            more than one client.
@@ -1415,26 +1470,35 @@ def __init__(
        self.out_pattern = convert_strs_to_vars(out_pattern, var_map=var_map)
        self.values_eq_approx = values_eq_approx
        self.allow_cast = allow_cast
-        if isinstance(in_pattern, list | tuple):
-            self.op = self.in_pattern[0]
-        elif isinstance(in_pattern, dict):
-            self.op = self.in_pattern["pattern"][0]
-        else:
-            raise TypeError(
-                "The pattern to search for must start with a specific Op instance."
-            )
        self.allow_multiple_clients = allow_multiple_clients
        if name:
            self.__name__ = name
-        self._tracks = tracks
        self.get_nodes = get_nodes
        if tracks != ():
-            assert get_nodes
+            if not get_nodes:
+                raise ValueError("Custom `tracks` requires `get_nodes` to be provided.")
+            self._tracks = tracks
+        else:
+            if isinstance(in_pattern, list | tuple):
+                op = self.in_pattern[0]
+            elif isinstance(in_pattern, dict):
+                op = self.in_pattern["pattern"][0]
+            else:
+                raise TypeError(
+                    "The pattern to search for must start with a specific Op instance."
+                )
+            if isinstance(op, Op):
+                self._tracks = [op]
+            elif isinstance(op, OpInstance):
+                self._tracks = [op.op_type]
+            else:
+                raise ValueError(
+                    f"The pattern to search for must start with a specific Op instance or an OpInstance class. "
+                    f"Got {op}, with type {type(op)}."
+                )

    def tracks(self):
-        if self._tracks != ():
-            return self._tracks
-        return [self.op]
+        return self._tracks

    def transform(self, fgraph, node, get_nodes=True):
        """Check if the graph from node corresponds to ``in_pattern``.
@@ -1455,28 +1519,39 @@ def transform(self, fgraph, node, get_nodes=True):
            # PatternNodeRewriter doesn't support replacing multi-output nodes
            return False

-        s = unify(self.in_pattern, node.out)
+        s = unify(self.in_pattern, node.out, {})

        if s is False:
            return False

-        ret = reify(self.out_pattern, s)
-
-        if isinstance(ret, ExpressionTuple):
-            ret = ret.evaled_obj
-
-        if self.values_eq_approx:
-            ret.tag.values_eq_approx = self.values_eq_approx
-
        if not self.allow_multiple_clients:
-            input_vars = list(s.values())
+            input_vars = set(s.values())
+            clients = fgraph.clients
            if any(
-                len(fgraph.clients[v]) > 1
+                len(clients[v]) > 1
                for v in vars_between(input_vars, node.inputs)
                if v not in input_vars
            ):
                return False

+        if callable(self.out_pattern):
+            # token is the variable name used in the original pattern
+            ret = self.out_pattern(fgraph, node, {k.token: v for k, v in s.items()})
+            if ret is None or ret is False:
+                # The output function is still allowed to reject the rewrite
+                return False
+            if not isinstance(ret, Variable):
+                raise ValueError(
+                    f"The output of the PatternNodeRewriter callable must be a variable got {ret} of type {type(ret)}."
+                )
+        else:
+            ret = reify(self.out_pattern, s)
+            if isinstance(ret, ExpressionTuple):
+                ret = ret.evaled_obj
+
+            if self.values_eq_approx:
+                ret.tag.values_eq_approx = self.values_eq_approx
+
        [old_out] = node.outputs
        if not old_out.type.is_super(ret.type):
            from pytensor.tensor.type import TensorType
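
A minimal end-to-end sketch (not part of the commit) of how the new OpInstance tracking and callable output pattern could be exercised. It reuses the CAReduce/Join example from the docstring above; the FunctionGraph construction and the direct transform call are assumptions about the surrounding PyTensor API, and whether the match actually fires depends on pt.sum building a CAReduce node with axis=None:

    import pytensor.tensor as pt
    from pytensor.graph.fg import FunctionGraph
    from pytensor.graph.rewriting.basic import PatternNodeRewriter
    from pytensor.graph.rewriting.unify import OpInstance
    from pytensor.tensor.basic import Join
    from pytensor.tensor.elemwise import CAReduce, Elemwise


    def output_fn(fgraph, node, s):
        # s maps the pattern's variable names to the matched objects:
        # here "scalar_op", "join_axis", "a" and "b".
        reduce_op = node.op
        return Elemwise(s["scalar_op"])(reduce_op(s["a"]), reduce_op(s["b"]))


    # Rewrite reduce(join(axis, a, b)) -> elemwise(reduce(a), reduce(b))
    # for full (axis=None) CAReduce reductions.
    local_reduce_of_join = PatternNodeRewriter(
        (
            OpInstance(CAReduce, scalar_op="scalar_op", axis=None),
            (Join(), "join_axis", "a", "b"),
        ),
        output_fn,
        name="local_reduce_of_join",
    )

    a = pt.vector("a")
    b = pt.vector("b")
    out = pt.sum(pt.join(0, a, b))
    fgraph = FunctionGraph([a, b], [out], clone=False)
    # transform returns the replacement (or False if the pattern, or the
    # callable itself, rejects the node).
    replacement = local_reduce_of_join.transform(fgraph, out.owner)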

pytensor/graph/rewriting/unify.py

Lines changed: 111 additions & 4 deletions
@@ -10,8 +10,10 @@

"""

-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
+from dataclasses import dataclass
from numbers import Number
+from typing import Any

import numpy as np
from cons.core import ConsError, _car, _cdr
@@ -254,6 +256,103 @@ def _unify_ConstrainedVar_object(u, v, s):
_unify.add((object, ConstrainedVar, Mapping), _unify_ConstrainedVar_object)


+@dataclass(frozen=True)
+class LiteralString:
+    value: str
+
+
+class OpInstance:
+    """Class that can be unified with Op instances of a given type and parameters.
+
+    An op instance is unified as long as the parameters specified in the OpInstance can be unified as well.
+    Parameters that are not specified in the OpInstance are ignored during unification.
+
+    This is needed because some Ops can be complex to parametrize fully,
+    and not all parameters are relevant for a given pattern.
+
+    Examples
+    --------
+
+    .. testcode::
+
+        from unification import var, unify
+        from etuples import etuple
+
+        import pytensor.tensor as pt
+        from pytensor.graph.rewriting.unify import OpInstance
+        from pytensor.tensor.blockwise import Blockwise
+        from pytensor.tensor.slinalg import Solve
+
+        A = var("A")
+        b = var("b")
+        pattern = etuple(
+            OpInstance(Blockwise, core_op=OpInstance(Solve, assume_a="gen")), A, b
+        )
+
+        A_pt = pt.tensor3("A")
+        b_pt = pt.tensor3("b")
+        out1 = pt.linalg.solve(A_pt, b_pt)
+        out2 = pt.linalg.solve(A_pt, b_pt, assume_a="pos")
+
+        assert unify(pattern, out1) == {A: A_pt, b: b_pt}
+        assert unify(pattern, out2) is False
+
+        assume_a = var("assume_a")
+        pattern = etuple(
+            OpInstance(Blockwise, core_op=OpInstance(Solve, assume_a=assume_a)),
+            A,
+            b,
+        )
+        assert unify(pattern, out1) == {A: A_pt, b: b_pt, assume_a: "gen"}
+        assert unify(pattern, out2) == {A: A_pt, b: b_pt, assume_a: "pos"}
+
+
+    """
+
+    def __init__(
+        self,
+        op_type: type[Op],
+        parameters: dict[str, Any] | Sequence[tuple[str, Any]] | None = None,
+        **kwargs,
+    ):
+        if not (isinstance(op_type, type) and issubclass(op_type, Op)):
+            raise TypeError(f"Invalid op_type {op_type}. Expected type(Op)")
+
+        if kwargs:
+            if parameters is not None:
+                raise ValueError(
+                    "Cannot provide both parameters dict and keyword arguments"
+                )
+            parameters = kwargs
+        if isinstance(parameters, dict):
+            parameters = tuple(sorted(parameters.items()))
+        elif isinstance(parameters, list | tuple):
+            parameters = tuple(sorted(parameters))
+        elif parameters is None:
+            parameters = ()
+        self.op_type = op_type
+        self.parameters = parameters
+
+    def __str__(self):
+        return f"{self.op_type.__name__}({self.op_type}, {', '.join(f'{k}={v}' for k, v in self.parameters)})"
+
+
+def _unify_parametrized_op(v: Op, u: OpInstance, s: Mapping):
+    if not isinstance(v, u.op_type):
+        yield False
+        return
+    for parameter_key, parameter_pattern in u.parameters:
+        parameter_value = getattr(v, parameter_key)
+        s = yield _unify(parameter_value, parameter_pattern, s)
+        if s is False:
+            yield False
+            return
+    yield s
+
+
+_unify.add((Op, OpInstance, Mapping), _unify_parametrized_op)
+
+
def convert_strs_to_vars(
    x: tuple | str | dict, var_map: dict[str, Var] | None = None
) -> ExpressionTuple | Var:
@@ -266,11 +365,13 @@ def convert_strs_to_vars(
    if var_map is None:
        var_map = {}

-    def _convert(y):
+    def _convert(y, op_prop=False):
        if isinstance(y, str):
            v = var_map.get(y, var(y))
            var_map[y] = v
            return v
+        if isinstance(y, LiteralString):
+            return y.value
        elif isinstance(y, dict):
            pattern = y["pattern"]
            if not isinstance(pattern, str):
@@ -282,8 +383,14 @@ def _convert(y):
            var_map[pattern] = v
            return v
        elif isinstance(y, tuple):
-            return etuple(*(_convert(e) for e in y))
-        elif isinstance(y, Number | np.ndarray):
+            return etuple(*(_convert(e, op_prop=op_prop) for e in y))
+        elif isinstance(y, OpInstance):
+            return OpInstance(
+                y.op_type,
+                {k: _convert(v, op_prop=True) for k, v in y.parameters},
+            )
+        elif (not op_prop) and isinstance(y, Number | np.ndarray):
+            # If we are converting an Op property, we don't want to convert numbers to PyTensor constants
            from pytensor.tensor import as_tensor_variable

            return as_tensor_variable(y)
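
As a small aside on the LiteralString addition above, a sketch of why it exists: convert_strs_to_vars turns bare strings into unification variables, so a string-valued Op parameter such as assume_a="gen" has to be wrapped in LiteralString to stay a concrete value. The names below follow the code added in this commit:

    from pytensor.graph.rewriting.unify import (
        LiteralString,
        OpInstance,
        Var,
        convert_strs_to_vars,
    )
    from pytensor.tensor.slinalg import Solve

    pattern = convert_strs_to_vars(
        (OpInstance(Solve, assume_a=LiteralString("gen")), "x")
    )

    # The bare string "x" became a logic variable...
    assert isinstance(pattern[1], Var)
    # ...while the LiteralString parameter stays the plain string "gen",
    # so unification only succeeds for Solve ops with assume_a == "gen".
    assert pattern[0].parameters == (("assume_a", "gen"),)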
