|
41 | 41 | op_y,
|
42 | 42 | op_z,
|
43 | 43 | )
|
| 44 | +from tests.unittest_tools import assert_equal_computations |
44 | 45 |
|
45 | 46 |
|
46 | 47 | class AssertNoChanges(Feature):
|
@@ -725,22 +726,35 @@ def test_patternsub_invalid_dtype(out_pattern):
|
725 | 726 | assert e.type.is_super(fg.outputs[0].type)
|
726 | 727 |
|
727 | 728 |
|
728 |
| -def test_patternsub_different_output_lengths(): |
729 |
| - # Test that PatternNodeRewriter won't replace nodes with different numbers of outputs |
730 |
| - ps = PatternNodeRewriter( |
731 |
| - (op1, "x"), |
732 |
| - ("x"), |
| 729 | +def test_patternsub_multi_output_nodes(): |
| 730 | + # Test that PatternNodeRewriter won't attempt to replace multi-output nodes |
| 731 | + multiple_op_ps = PatternNodeRewriter( |
| 732 | + (op_multiple_outputs, "x"), |
| 733 | + "x", |
733 | 734 | name="ps",
|
734 | 735 | )
|
735 |
| - rewriter = in2out(ps) |
| 736 | + |
| 737 | + single_op_ps = PatternNodeRewriter( |
| 738 | + (op_y, "x"), |
| 739 | + "x", |
| 740 | + name="ps", |
| 741 | + ) |
| 742 | + |
| 743 | + rewriter = in2out(multiple_op_ps, single_op_ps) |
736 | 744 |
|
737 | 745 | x = MyVariable("x")
|
738 | 746 | e1, e2 = op_multiple_outputs(x)
|
739 |
| - o = op1(e1) |
| 747 | + o1, o2 = op_y(e1), op_y(e2) |
| 748 | + |
| 749 | + fgraph = FunctionGraph(inputs=[x], outputs=[e2, e1], copy_inputs=False) |
| 750 | + rewriter.rewrite(fgraph) |
| 751 | + # This shouldn't rewrite because PatternNodeRewriter has no way of specifying which output(s) are being matched |
| 752 | + assert_equal_computations(fgraph.outputs, [e2, e1]) |
740 | 753 |
|
741 |
| - fgraph = FunctionGraph(inputs=[x], outputs=[o]) |
| 754 | + fgraph = FunctionGraph(inputs=[x], outputs=[o2, o1], copy_inputs=False) |
742 | 755 | rewriter.rewrite(fgraph)
|
743 |
| - assert fgraph.outputs[0].owner.op == op1 |
| 756 | + # Having a variable that comes out of a multi-output node should be fine |
| 757 | + assert_equal_computations(fgraph.outputs, [e2, e1]) |
744 | 758 |
|
745 | 759 |
|
746 | 760 | class TestSequentialNodeRewriter:
|
|
0 commit comments