Skip to content

Commit 58de233

Browse files
author
Luca Citi
committed
Addressed feedback by ricardoV94
1 parent 700b0d8 commit 58de233

File tree

3 files changed

+57
-30
lines changed

3 files changed

+57
-30
lines changed

pytensor/graph/rewriting/basic.py

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1636,6 +1636,10 @@ def transform(self, fgraph, node, get_nodes=True):
16361636
if node.op != self.op:
16371637
return False
16381638

1639+
if len(node.outputs) != 1:
1640+
# PatternNodeRewriter doesn't support replacing multi-output nodes
1641+
return False
1642+
16391643
s = unify(self.in_pattern, node.out)
16401644

16411645
if s is False:
@@ -1658,27 +1662,20 @@ def transform(self, fgraph, node, get_nodes=True):
16581662
):
16591663
return False
16601664

1661-
if ret.owner:
1662-
if len(node.outputs) != len(ret.owner.outputs):
1663-
return
1664-
if len(node.outputs) > 1:
1665-
if not all(
1666-
o.type.is_super(new_o.type)
1667-
for o, new_o in zip(node.outputs, ret.owner.outputs, strict=True)
1668-
):
1669-
return False
1670-
else:
1671-
if self.allow_cast:
1672-
out_dtype = getattr(node.outputs[0].type, "dtype", None)
1673-
ret_dtype = getattr(ret.owner.outputs[0].type, "dtype", None)
1674-
if ret_dtype != out_dtype:
1675-
ret = pytensor.tensor.basic.cast(ret, out_dtype)
1676-
if not node.outputs[0].type.is_super(ret.owner.outputs[0].type):
1677-
return False
1678-
else:
1679-
# ret is just an input variable
1680-
assert len(node.outputs) == 1
1681-
if not node.outputs[0].type.is_super(ret.type):
1665+
[old_out] = node.outputs
1666+
if not old_out.type.is_super(ret.type):
1667+
# Type doesn't match
1668+
if not (
1669+
self.allow_cast
1670+
and isinstance(old_out.type, pytensor.tensor.TensorType)
1671+
and isinstance(ret.type, pytensor.tensor.TensorType)
1672+
):
1673+
return False
1674+
1675+
# Try to cast tensors
1676+
ret = ret.astype(old_out.type.dtype)
1677+
if not old_out.type.is_super(ret.type):
1678+
# Still doesn't match
16821679
return False
16831680

16841681
return [ret]

tests/tensor/rewriting/test_math.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4142,16 +4142,10 @@ def test_local_1msigmoid(self):
41424142
(np.float32(1 - 9e-6) - sigmoid(x), np.float32(1 - 9e-6) - sigmoid(x)),
41434143
(np.float64(1 - 1e-9) - sigmoid(xd), np.float64(1 - 1e-9) - sigmoid(xd)),
41444144
]:
4145-
f = pytensor.function([x, xd], out, m, on_unused_input="ignore")
4146-
f_outs = f.maker.fgraph.outputs
4147-
assert equal_computations(
4148-
f_outs, [expected]
4149-
), "Expression:\n{}rewritten as:\n{}expected:\n{}".format(
4150-
*(
4151-
pytensor.dprint(expr, print_type=True, file="str")
4152-
for expr in (out, f_outs, expected)
4153-
)
4145+
rewritten = rewrite_graph(
4146+
out, include=["canonicalize", "specialize", "stabilize"]
41544147
)
4148+
utt.assert_equal_computations([rewritten], [expected], original=out)
41554149

41564150
def test_local_sigm_times_exp(self):
41574151
"""

tests/unittest_tools.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from pytensor.compile.debugmode import str_diagnostic
1212
from pytensor.configdefaults import config
1313
from pytensor.gradient import verify_grad as orig_verify_grad
14+
from pytensor.graph.basic import equal_computations
1415
from pytensor.tensor.basic import as_tensor_variable
1516
from pytensor.tensor.math import _allclose
1617
from pytensor.tensor.math import add as pt_add
@@ -279,6 +280,41 @@ def assert_allclose(expected, value, rtol=None, atol=None):
279280
raise WrongValue(expected, value, rtol, atol)
280281

281282

283+
def assert_equal_computations(rewritten, expected, *args, original=None, **kwargs):
284+
"""
285+
Assert that `rewritten` computes the same as `expected`.
286+
287+
Parameters
288+
----------
289+
rewritten
290+
The expression after the rewrite pass.
291+
expected
292+
The reference expression to compare against.
293+
*args, **kwargs
294+
Extra arguments forwarded to equal_computations.
295+
original : optional
296+
If given, will be printed in the error message.
297+
"""
298+
__tracebackhide__ = True # Hide traceback for py.test
299+
300+
ok = equal_computations(rewritten, expected, *args, **kwargs)
301+
302+
if not ok:
303+
parts = []
304+
305+
def _dprint(expr):
306+
return pytensor.dprint(expr, print_type=True, file="str")
307+
308+
if original is not None:
309+
parts.append(f"\nOriginal:\n{_dprint(original)}")
310+
parts.append(f"\nRewritten:\n{_dprint(rewritten)}")
311+
parts.append(f"\nExpected:\n{_dprint(expected)}")
312+
313+
raise AssertionError("equal_computations failed\n" + "".join(parts))
314+
315+
return True
316+
317+
282318
class AttemptManyTimes:
283319
"""Decorator for unit tests that forces a unit test to be attempted
284320
multiple times. The test needs to pass a certain number of times for it to

0 commit comments

Comments
 (0)