@@ -41,9 +41,8 @@
 from pytensor import tensor as pt
 from pytensor.graph.op import compute_test_value
 from pytensor.graph.rewriting.basic import node_rewriter
-from pytensor.tensor.basic import Join, MakeVector
+from pytensor.tensor.basic import Alloc, Join, MakeVector
 from pytensor.tensor.elemwise import DimShuffle
-from pytensor.tensor.extra_ops import BroadcastTo
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.random.rewriting import (
     local_dimshuffle_rv_lift,
@@ -59,9 +58,9 @@
 from pymc.logprob.utils import check_potential_measurability
 
 
-@node_rewriter([BroadcastTo])
+@node_rewriter([Alloc])
 def naive_bcast_rv_lift(fgraph, node):
-    """Lift a ``BroadcastTo`` through a ``RandomVariable`` ``Op``.
+    """Lift an ``Alloc`` through a ``RandomVariable`` ``Op``.
 
     XXX: This implementation simply broadcasts the ``RandomVariable``'s
     parameters, which won't always work (e.g. multivariate distributions).
@@ -73,7 +72,7 @@ def naive_bcast_rv_lift(fgraph, node):
     """
 
     if not (
-        isinstance(node.op, BroadcastTo)
+        isinstance(node.op, Alloc)
         and node.inputs[0].owner
         and isinstance(node.inputs[0].owner.op, RandomVariable)
     ):
@@ -93,7 +92,7 @@ def naive_bcast_rv_lift(fgraph, node):
         return None
 
     if not bcast_shape:
-        # The `BroadcastTo` is broadcasting a scalar to a scalar (i.e. doing nothing)
+        # The `Alloc` is broadcasting a scalar to a scalar (i.e. doing nothing)
         assert rv_var.ndim == 0
         return [rv_var]
 
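For intuition, here is a minimal sketch (not part of the commit) of what the lift does to a graph, written with plain PyTensor ops; the variable names are illustrative:

import pytensor.tensor as pt
import pytensor.tensor.random as ptr

mu = pt.vector("mu")       # e.g. shape (3,)
sigma = pt.scalar("sigma")

# Before the rewrite: an Alloc broadcasts the *output* of the
# RandomVariable to the target shape (2, 3).
rv = ptr.normal(mu, sigma)     # Normal RV with shape (3,)
bcast_rv = pt.alloc(rv, 2, 3)  # Alloc(normal(mu, sigma), 2, 3)

# After the rewrite (conceptually): the *parameters* are broadcast
# instead, so the RandomVariable itself already has the target shape
# and the Alloc is gone from the measurable part of the graph.
lifted_rv = ptr.normal(pt.alloc(mu, 2, 3), pt.alloc(sigma, 2, 3))

Note that the two graphs are not equivalent as samplers: Alloc repeats the same draws along the new dimension, while broadcasting the parameters draws independently at every position, and the lift does not handle multivariate distributions at all. Hence the "naive" in the rewrite's name and the XXX warning in its docstring.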