Skip to content

Commit 17fbc90

Browse files
committed
Test PatternNodeRewriter doesn't support multi-output nodes in pattern
But it's fine if they're just root inputs
1 parent 58de233 commit 17fbc90

File tree

2 files changed

+26
-9
lines changed

2 files changed

+26
-9
lines changed

tests/graph/rewriting/test_basic.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
op_y,
4242
op_z,
4343
)
44+
from tests.unittest_tools import assert_equal_computations
4445

4546

4647
class AssertNoChanges(Feature):
@@ -725,22 +726,35 @@ def test_patternsub_invalid_dtype(out_pattern):
725726
assert e.type.is_super(fg.outputs[0].type)
726727

727728

728-
def test_patternsub_different_output_lengths():
729-
# Test that PatternNodeRewriter won't replace nodes with different numbers of outputs
730-
ps = PatternNodeRewriter(
731-
(op1, "x"),
732-
("x"),
729+
def test_patternsub_multi_output_nodes():
730+
# Test that PatternNodeRewriter won't attempt to replace multi-output nodes
731+
multiple_op_ps = PatternNodeRewriter(
732+
(op_multiple_outputs, "x"),
733+
"x",
733734
name="ps",
734735
)
735-
rewriter = in2out(ps)
736+
737+
single_op_ps = PatternNodeRewriter(
738+
(op_y, "x"),
739+
"x",
740+
name="ps",
741+
)
742+
743+
rewriter = in2out(multiple_op_ps, single_op_ps)
736744

737745
x = MyVariable("x")
738746
e1, e2 = op_multiple_outputs(x)
739-
o = op1(e1)
747+
o1, o2 = op_y(e1), op_y(e2)
748+
749+
fgraph = FunctionGraph(inputs=[x], outputs=[e2, e1], copy_inputs=False)
750+
rewriter.rewrite(fgraph)
751+
# This shouldn't rewrite because PatternNodeRewriter has no way of specifying which output(s) are being matched
752+
assert_equal_computations(fgraph.outputs, [e2, e1])
740753

741-
fgraph = FunctionGraph(inputs=[x], outputs=[o])
754+
fgraph = FunctionGraph(inputs=[x], outputs=[o2, o1], copy_inputs=False)
742755
rewriter.rewrite(fgraph)
743-
assert fgraph.outputs[0].owner.op == op1
756+
# Having a variable that comes out of a multi-output node should be fine
757+
assert_equal_computations(fgraph.outputs, [e2, e1])
744758

745759

746760
class TestSequentialNodeRewriter:

tests/graph/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,9 @@ def make_node(self, *inputs):
107107

108108

109109
class MyOpMultipleOutputs(MyOp):
110+
def __init__(self, name, dmap=None, x=None):
111+
super().__init__(name=name, dmap=dmap, x=x, n_outs=2)
112+
110113
def make_node(self, input):
111114
outputs = [input.type(), input.type()]
112115
return Apply(self, [input], outputs)

0 commit comments

Comments
 (0)