|
15 | 15 | from pytensor.graph.rewriting.basic import check_stack_trace, node_rewriter, out2in
|
16 | 16 | from pytensor.graph.rewriting.utils import rewrite_graph
|
17 | 17 | from pytensor.graph.type import Type
|
18 |
| -from pytensor.tensor.basic import as_tensor_variable |
| 18 | +from pytensor.tensor.basic import alloc, as_tensor_variable |
19 | 19 | from pytensor.tensor.elemwise import DimShuffle, Elemwise
|
20 | 20 | from pytensor.tensor.math import add, exp, maximum
|
21 | 21 | from pytensor.tensor.rewriting.basic import register_specialize
|
@@ -239,6 +239,25 @@ def test_no_shapeopt(self):
|
239 | 239 | # FIXME: This is not a good test.
|
240 | 240 | f([[1, 2], [2, 3]])
|
241 | 241 |
|
| 242 | + def test_shape_of_useless_alloc(self): |
| 243 | + """Test that local_shape_to_shape_i does not create circular graph. |
| 244 | +
|
| 245 | + Regression test for #565 |
| 246 | + """ |
| 247 | + alpha = vector(shape=(None,), dtype="float64") |
| 248 | + channel = vector(shape=(None,), dtype="float64") |
| 249 | + |
| 250 | + broadcast_channel = alloc( |
| 251 | + channel, |
| 252 | + maximum( |
| 253 | + shape(alpha)[0], |
| 254 | + shape(channel)[0], |
| 255 | + ), |
| 256 | + ) |
| 257 | + out = shape(broadcast_channel) |
| 258 | + fn = function([alpha, channel], out) |
| 259 | + assert fn([1.0, 2, 3], [1.0, 2, 3]) == (3,) |
| 260 | + |
242 | 261 |
|
243 | 262 | class TestReshape:
|
244 | 263 | def setup_method(self):
|
|
0 commit comments