29
29
from pytensor .graph .features import AlreadyThere , Feature
30
30
from pytensor .graph .fg import FunctionGraph , Output
31
31
from pytensor .graph .op import Op
32
- from pytensor .graph .rewriting .unify import Var , convert_strs_to_vars
32
+ from pytensor .graph .rewriting .unify import OpInstance , Var , convert_strs_to_vars
33
33
from pytensor .graph .utils import AssocList , InconsistencyError
34
34
from pytensor .misc .ordered_set import OrderedSet
35
35
from pytensor .utils import flatten
@@ -1320,6 +1320,7 @@ class PatternNodeRewriter(NodeRewriter):
1320
1320
The input and output patterns have the following syntax:
1321
1321
1322
1322
input_pattern ::= (op, <sub_pattern1>, <sub_pattern2>, ...)
1323
+ input_pattern ::= (OpInstance(type(op), {<param>: <value>, ...}), <sub_pattern1>, <sub_pattern2>, ...)
1323
1324
input_pattern ::= dict(pattern = <input_pattern>,
1324
1325
constraint = <constraint>)
1325
1326
sub_pattern ::= input_pattern
@@ -1333,6 +1334,7 @@ class PatternNodeRewriter(NodeRewriter):
1333
1334
output_pattern ::= string
1334
1335
output_pattern ::= int
1335
1336
output_pattern ::= float
1337
+ output_pattern ::= callable
1336
1338
1337
1339
Each string in the input pattern is a variable that will be set to
1338
1340
whatever expression is found in its place. If the same string is
@@ -1358,20 +1360,73 @@ class PatternNodeRewriter(NodeRewriter):
1358
1360
Examples
1359
1361
--------
1360
1362
1361
- PatternNodeRewriter((add, 'x', 'y'), (add, 'y', 'x'))
1362
- PatternNodeRewriter((multiply, 'x', 'x'), (square, 'x'))
1363
- PatternNodeRewriter((subtract, (add, 'x', 'y'), 'y'), 'x')
1364
- PatternNodeRewriter((power, 'x', Constant(double, 2.0)), (square, 'x'))
1365
- PatternNodeRewriter((boggle, {'pattern': 'x',
1366
- 'constraint': lambda expr: expr.type == scrabble}),
1367
- (scrabble, 'x'))
1363
+ .. code-block:: python
1368
1364
1365
+ from pytensor.graph.rewriting.basic import PatternNodeRewriter
1366
+ from pytensor.tensor import add, mul, sub, pow, square
1367
+
1368
+ PatternNodeRewriter((add, "x", "y"), (add, "y", "x"))
1369
+ PatternNodeRewriter((mul, "x", "x"), (square, "x"))
1370
+ PatternNodeRewriter((sub, (add, "x", "y"), "y"), "x")
1371
+ PatternNodeRewriter((pow, "x", 2.0), (square, "x"))
1372
+ PatternNodeRewriter(
1373
+ (mul, {"pattern": "x", "constraint": lambda expr: expr.ndim == 0}, "y"),
1374
+ (mul, "y", "x"),
1375
+ )
1376
+
1377
+ You can use OpInstance to match a subtype of an Op, with some parameter constraints
1378
+ You can also specify a callable as the output pattern, which will be called with (fgraph, node, subs_dict) as arguments.
1379
+
1380
+
1381
+ .. code-block:: python
1382
+
1383
+ from pytensor.graph.rewriting.basic import PatternNodeRewriter
1384
+ from pytensor.graph.rewriting.unify import OpInstance
1385
+ from pytensor.tensor.basic import Join
1386
+ from pytensor.tensor.elemwise import CAReduce, Elemwise
1387
+
1388
+
1389
+ def output_fn(fgraph, node, s):
1390
+ reduce_op = node.op
1391
+ reduced_a = reduce_op(s["a"])
1392
+ reduced_b = reduce_op(s["b"])
1393
+ return Elemwise(s["scalar_op"])(reduced_a, reduced_b)
1394
+
1395
+
1396
+ PatternNodeRewriter(
1397
+ (
1398
+ OpInstance(CAReduce, scalar_op="scalar_op", axis=None),
1399
+ (Join(), "join_axis", "a", "b"),
1400
+ ),
1401
+ output_fn,
1402
+ )
1403
+
1404
+
1405
+ If you want to test a string parameter, you must use LiteralString to avoid it being interpreted as a unification variable.
1406
+
1407
+ .. code-block:: python
1408
+
1409
+
1410
+ from pytensor.graph.rewriting.basic import PatternNodeRewriter
1411
+ from pytensor.graph.rewriting.unify import OpInstance, LiteralString
1412
+ from pytensor.tensor.blockwise import Blockwise
1413
+ from pytensor.tensor.slinalg import Solve
1414
+
1415
+ PatternNodeRewriter(
1416
+ (
1417
+ OpInstance(
1418
+ Blockwise, core_op=OpInstance(Solve, assume_a=LiteralString("gen"))
1419
+ ),
1420
+ "A",
1421
+ "b",
1422
+ )
1423
+ )
1369
1424
"""
1370
1425
1371
1426
def __init__ (
1372
1427
self ,
1373
- in_pattern ,
1374
- out_pattern ,
1428
+ in_pattern : tuple ,
1429
+ out_pattern : tuple | Callable ,
1375
1430
allow_multiple_clients : bool = False ,
1376
1431
name : str | None = None ,
1377
1432
tracks = (),
@@ -1386,7 +1441,7 @@ def __init__(
1386
1441
in_pattern
1387
1442
The input pattern that we want to replace.
1388
1443
out_pattern
1389
- The replacement pattern.
1444
+ The replacement pattern. Or a callable that takes (fgraph, node, subs_dict) as inputs
1390
1445
allow_multiple_clients
1391
1446
If ``False``, the pattern matching will fail if one of the subpatterns has
1392
1447
more than one client.
@@ -1415,26 +1470,35 @@ def __init__(
1415
1470
self .out_pattern = convert_strs_to_vars (out_pattern , var_map = var_map )
1416
1471
self .values_eq_approx = values_eq_approx
1417
1472
self .allow_cast = allow_cast
1418
- if isinstance (in_pattern , list | tuple ):
1419
- self .op = self .in_pattern [0 ]
1420
- elif isinstance (in_pattern , dict ):
1421
- self .op = self .in_pattern ["pattern" ][0 ]
1422
- else :
1423
- raise TypeError (
1424
- "The pattern to search for must start with a specific Op instance."
1425
- )
1426
1473
self .allow_multiple_clients = allow_multiple_clients
1427
1474
if name :
1428
1475
self .__name__ = name
1429
- self ._tracks = tracks
1430
1476
self .get_nodes = get_nodes
1431
1477
if tracks != ():
1432
- assert get_nodes
1478
+ if not get_nodes :
1479
+ raise ValueError ("Custom `tracks` requires `get_nodes` to be provided." )
1480
+ self ._tracks = tracks
1481
+ else :
1482
+ if isinstance (in_pattern , list | tuple ):
1483
+ op = self .in_pattern [0 ]
1484
+ elif isinstance (in_pattern , dict ):
1485
+ op = self .in_pattern ["pattern" ][0 ]
1486
+ else :
1487
+ raise TypeError (
1488
+ "The pattern to search for must start with a specific Op instance."
1489
+ )
1490
+ if isinstance (op , Op ):
1491
+ self ._tracks = [op ]
1492
+ elif isinstance (op , OpInstance ):
1493
+ self ._tracks = [op .op_type ]
1494
+ else :
1495
+ raise ValueError (
1496
+ f"The pattern to search for must start with a specific Op instance or an OpInstance class. "
1497
+ f"Got { op } , with type { type (op )} ."
1498
+ )
1433
1499
1434
1500
def tracks (self ):
1435
- if self ._tracks != ():
1436
- return self ._tracks
1437
- return [self .op ]
1501
+ return self ._tracks
1438
1502
1439
1503
def transform (self , fgraph , node , get_nodes = True ):
1440
1504
"""Check if the graph from node corresponds to ``in_pattern``.
@@ -1455,28 +1519,39 @@ def transform(self, fgraph, node, get_nodes=True):
1455
1519
# PatternNodeRewriter doesn't support replacing multi-output nodes
1456
1520
return False
1457
1521
1458
- s = unify (self .in_pattern , node .out )
1522
+ s = unify (self .in_pattern , node .out , {} )
1459
1523
1460
1524
if s is False :
1461
1525
return False
1462
1526
1463
- ret = reify (self .out_pattern , s )
1464
-
1465
- if isinstance (ret , ExpressionTuple ):
1466
- ret = ret .evaled_obj
1467
-
1468
- if self .values_eq_approx :
1469
- ret .tag .values_eq_approx = self .values_eq_approx
1470
-
1471
1527
if not self .allow_multiple_clients :
1472
- input_vars = list (s .values ())
1528
+ input_vars = set (s .values ())
1529
+ clients = fgraph .clients
1473
1530
if any (
1474
- len (fgraph . clients [v ]) > 1
1531
+ len (clients [v ]) > 1
1475
1532
for v in vars_between (input_vars , node .inputs )
1476
1533
if v not in input_vars
1477
1534
):
1478
1535
return False
1479
1536
1537
+ if callable (self .out_pattern ):
1538
+ # token is the variable name used in the original pattern
1539
+ ret = self .out_pattern (fgraph , node , {k .token : v for k , v in s .items ()})
1540
+ if ret is None or ret is False :
1541
+ # The output function is still allowed to reject the rewrite
1542
+ return False
1543
+ if not isinstance (ret , Variable ):
1544
+ raise ValueError (
1545
+ f"The output of the PatternNodeRewriter callable must be a variable got { ret } of type { type (ret )} ."
1546
+ )
1547
+ else :
1548
+ ret = reify (self .out_pattern , s )
1549
+ if isinstance (ret , ExpressionTuple ):
1550
+ ret = ret .evaled_obj
1551
+
1552
+ if self .values_eq_approx :
1553
+ ret .tag .values_eq_approx = self .values_eq_approx
1554
+
1480
1555
[old_out ] = node .outputs
1481
1556
if not old_out .type .is_super (ret .type ):
1482
1557
from pytensor .tensor .type import TensorType
0 commit comments