29
29
from pytensor .graph .features import AlreadyThere , Feature
30
30
from pytensor .graph .fg import FunctionGraph , Output
31
31
from pytensor .graph .op import Op
32
- from pytensor .graph .rewriting .unify import Var , convert_strs_to_vars
32
+ from pytensor .graph .rewriting .unify import OpPattern , Var , convert_strs_to_vars
33
33
from pytensor .graph .utils import AssocList , InconsistencyError
34
34
from pytensor .misc .ordered_set import OrderedSet
35
35
from pytensor .utils import flatten
@@ -1312,6 +1312,7 @@ class PatternNodeRewriter(NodeRewriter):
1312
1312
The input and output patterns have the following syntax:
1313
1313
1314
1314
input_pattern ::= (op, <sub_pattern1>, <sub_pattern2>, ...)
1315
+ input_pattern ::= (OpPattern(type(op), {<param>: <value>, ...}), <sub_pattern1>, <sub_pattern2>, ...)
1315
1316
input_pattern ::= dict(pattern = <input_pattern>,
1316
1317
constraint = <constraint>)
1317
1318
sub_pattern ::= input_pattern
@@ -1325,6 +1326,7 @@ class PatternNodeRewriter(NodeRewriter):
1325
1326
output_pattern ::= string
1326
1327
output_pattern ::= int
1327
1328
output_pattern ::= float
1329
+ output_pattern ::= callable
1328
1330
1329
1331
Each string in the input pattern is a variable that will be set to
1330
1332
whatever expression is found in its place. If the same string is
@@ -1350,20 +1352,73 @@ class PatternNodeRewriter(NodeRewriter):
1350
1352
Examples
1351
1353
--------
1352
1354
1353
- PatternNodeRewriter((add, 'x', 'y'), (add, 'y', 'x'))
1354
- PatternNodeRewriter((multiply, 'x', 'x'), (square, 'x'))
1355
- PatternNodeRewriter((subtract, (add, 'x', 'y'), 'y'), 'x')
1356
- PatternNodeRewriter((power, 'x', Constant(double, 2.0)), (square, 'x'))
1357
- PatternNodeRewriter((boggle, {'pattern': 'x',
1358
- 'constraint': lambda expr: expr.type == scrabble}),
1359
- (scrabble, 'x'))
1355
+ .. code-block:: python
1360
1356
1357
+ from pytensor.graph.rewriting.basic import PatternNodeRewriter
1358
+ from pytensor.tensor import add, mul, sub, pow, square
1359
+
1360
+ PatternNodeRewriter((add, "x", "y"), (add, "y", "x"))
1361
+ PatternNodeRewriter((mul, "x", "x"), (square, "x"))
1362
+ PatternNodeRewriter((sub, (add, "x", "y"), "y"), "x")
1363
+ PatternNodeRewriter((pow, "x", 2.0), (square, "x"))
1364
+ PatternNodeRewriter(
1365
+ (mul, {"pattern": "x", "constraint": lambda expr: expr.ndim == 0}, "y"),
1366
+ (mul, "y", "x"),
1367
+ )
1368
+
1369
+ You can use OpPattern to match a subtype of an Op, with some parameter constraints
1370
+ You can also specify a callable as the output pattern, which will be called with (fgraph, node, subs_dict) as arguments.
1371
+
1372
+
1373
+ .. code-block:: python
1374
+
1375
+ from pytensor.graph.rewriting.basic import PatternNodeRewriter
1376
+ from pytensor.graph.rewriting.unify import OpPattern
1377
+ from pytensor.tensor.basic import Join
1378
+ from pytensor.tensor.elemwise import CAReduce, Elemwise
1379
+
1380
+
1381
+ def output_fn(fgraph, node, s):
1382
+ reduce_op = node.op
1383
+ reduced_a = reduce_op(s["a"])
1384
+ reduced_b = reduce_op(s["b"])
1385
+ return Elemwise(s["scalar_op"])(reduced_a, reduced_b)
1386
+
1387
+
1388
+ PatternNodeRewriter(
1389
+ (
1390
+ OpPattern(CAReduce, scalar_op="scalar_op", axis=None),
1391
+ (Join(), "join_axis", "a", "b"),
1392
+ ),
1393
+ output_fn,
1394
+ )
1395
+
1396
+
1397
+ If you want to test a string parameter, you must use LiteralString to avoid it being interpreted as a unification variable.
1398
+
1399
+ .. code-block:: python
1400
+
1401
+
1402
+ from pytensor.graph.rewriting.basic import PatternNodeRewriter
1403
+ from pytensor.graph.rewriting.unify import OpPattern, LiteralString
1404
+ from pytensor.tensor.blockwise import Blockwise
1405
+ from pytensor.tensor.slinalg import Solve
1406
+
1407
+ PatternNodeRewriter(
1408
+ (
1409
+ OpPattern(
1410
+ Blockwise, core_op=OpPattern(Solve, assume_a=LiteralString("gen"))
1411
+ ),
1412
+ "A",
1413
+ "b",
1414
+ )
1415
+ )
1361
1416
"""
1362
1417
1363
1418
def __init__ (
1364
1419
self ,
1365
- in_pattern ,
1366
- out_pattern ,
1420
+ in_pattern : tuple ,
1421
+ out_pattern : tuple | Callable | str ,
1367
1422
allow_multiple_clients : bool = False ,
1368
1423
name : str | None = None ,
1369
1424
tracks = (),
@@ -1378,7 +1433,8 @@ def __init__(
1378
1433
in_pattern
1379
1434
The input pattern that we want to replace.
1380
1435
out_pattern
1381
- The replacement pattern.
1436
+ The replacement pattern. Or a callable that takes (fgraph, node, subs_dict) as inputs,
1437
+ and returns the replacement variable (or None/False to reject the rewrite).
1382
1438
allow_multiple_clients
1383
1439
If ``False``, the pattern matching will fail if one of the subpatterns has
1384
1440
more than one client.
@@ -1407,26 +1463,40 @@ def __init__(
1407
1463
self .out_pattern = convert_strs_to_vars (out_pattern , var_map = var_map )
1408
1464
self .values_eq_approx = values_eq_approx
1409
1465
self .allow_cast = allow_cast
1410
- if isinstance (in_pattern , list | tuple ):
1411
- self .op = self .in_pattern [0 ]
1412
- elif isinstance (in_pattern , dict ):
1413
- self .op = self .in_pattern ["pattern" ][0 ]
1414
- else :
1415
- raise TypeError (
1416
- "The pattern to search for must start with a specific Op instance."
1417
- )
1418
1466
self .allow_multiple_clients = allow_multiple_clients
1419
1467
if name :
1420
1468
self .__name__ = name
1421
- self ._tracks = tracks
1422
1469
self .get_nodes = get_nodes
1423
1470
if tracks != ():
1424
- assert get_nodes
1471
+ if not get_nodes :
1472
+ raise ValueError ("Custom `tracks` requires `get_nodes` to be provided." )
1473
+ self ._tracks = tracks
1474
+ else :
1475
+ if isinstance (in_pattern , list | tuple ):
1476
+ op = self .in_pattern [0 ]
1477
+ elif isinstance (in_pattern , dict ):
1478
+ op = self .in_pattern ["pattern" ][0 ]
1479
+ else :
1480
+ raise TypeError (
1481
+ f"The in_pattern must be a sequence or a dict, but got { in_pattern } of type { type (in_pattern )} "
1482
+ )
1483
+ if isinstance (op , Op ):
1484
+ self ._tracks = [op ]
1485
+ elif isinstance (op , type ) and issubclass (op , Op ):
1486
+ raise ValueError (
1487
+ f"The in_pattern starts with an Op class { op } , not an instance.\n "
1488
+ "You can use pytensor.graph.unify.OpPattern instead if you want to match instances of a class."
1489
+ )
1490
+ elif isinstance (op , OpPattern ):
1491
+ self ._tracks = [op .op_type ]
1492
+ else :
1493
+ raise ValueError (
1494
+ f"The in_pattern must start with a specific Op or an OpPattern instance. "
1495
+ f"Got { op } , with type { type (op )} ."
1496
+ )
1425
1497
1426
1498
def tracks (self ):
1427
- if self ._tracks != ():
1428
- return self ._tracks
1429
- return [self .op ]
1499
+ return self ._tracks
1430
1500
1431
1501
def transform (self , fgraph , node , get_nodes = True ):
1432
1502
"""Check if the graph from node corresponds to ``in_pattern``.
@@ -1447,28 +1517,39 @@ def transform(self, fgraph, node, get_nodes=True):
1447
1517
# PatternNodeRewriter doesn't support replacing multi-output nodes
1448
1518
return False
1449
1519
1450
- s = unify (self .in_pattern , node .out )
1520
+ s = unify (self .in_pattern , node .out , {} )
1451
1521
1452
1522
if s is False :
1453
1523
return False
1454
1524
1455
- ret = reify (self .out_pattern , s )
1456
-
1457
- if isinstance (ret , ExpressionTuple ):
1458
- ret = ret .evaled_obj
1459
-
1460
- if self .values_eq_approx :
1461
- ret .tag .values_eq_approx = self .values_eq_approx
1462
-
1463
1525
if not self .allow_multiple_clients :
1464
- input_vars = list (s .values ())
1526
+ input_vars = set (s .values ())
1527
+ clients = fgraph .clients
1465
1528
if any (
1466
- len (fgraph . clients [v ]) > 1
1529
+ len (clients [v ]) > 1
1467
1530
for v in vars_between (input_vars , node .inputs )
1468
1531
if v not in input_vars
1469
1532
):
1470
1533
return False
1471
1534
1535
+ if callable (self .out_pattern ):
1536
+ # token is the variable name used in the original pattern
1537
+ ret = self .out_pattern (fgraph , node , {k .token : v for k , v in s .items ()})
1538
+ if ret is None or ret is False :
1539
+ # The output function is still allowed to reject the rewrite
1540
+ return False
1541
+ if not isinstance (ret , Variable ):
1542
+ raise ValueError (
1543
+ f"The output of the PatternNodeRewriter callable must be a variable got { ret } of type { type (ret )} ."
1544
+ )
1545
+ else :
1546
+ ret = reify (self .out_pattern , s )
1547
+ if isinstance (ret , ExpressionTuple ):
1548
+ ret = ret .evaled_obj
1549
+
1550
+ if self .values_eq_approx :
1551
+ ret .tag .values_eq_approx = self .values_eq_approx
1552
+
1472
1553
[old_out ] = node .outputs
1473
1554
if not old_out .type .is_super (ret .type ):
1474
1555
from pytensor .tensor .type import TensorType
0 commit comments