5959import xarray
6060
6161from pytensor .graph .basic import Variable
62+ from pytensor .graph .replace import graph_replace
63+ from pytensor .tensor .shape import unbroadcast
6264
6365import pymc as pm
6466
@@ -1002,7 +1004,7 @@ def set_size_and_deterministic(
10021004 """
10031005
10041006 flat2rand = self .make_size_and_deterministic_replacements (s , d , more_replacements )
1005- node_out = pytensor . clone_replace (node , flat2rand )
1007+ node_out = graph_replace (node , flat2rand , strict = False )
10061008 assert not (
10071009 set (makeiter (self .input )) & set (pytensor .graph .graph_inputs (makeiter (node_out )))
10081010 )
@@ -1012,7 +1014,7 @@ def set_size_and_deterministic(
10121014
10131015 def to_flat_input (self , node ):
10141016 """*Dev* - replace vars with flattened view stored in `self.inputs`"""
1015- return pytensor . clone_replace (node , self .replacements )
1017+ return graph_replace (node , self .replacements , strict = False )
10161018
10171019 def symbolic_sample_over_posterior (self , node ):
10181020 """*Dev* - performs sampling of node applying independent samples from posterior each time.
@@ -1023,7 +1025,7 @@ def symbolic_sample_over_posterior(self, node):
10231025 random = pt .specify_shape (random , self .symbolic_initial .type .shape )
10241026
10251027 def sample (post , * _ ):
1026- return pytensor . clone_replace (node , {self .input : post })
1028+ return graph_replace (node , {self .input : post }, strict = False )
10271029
10281030 nodes , _ = pytensor .scan (
10291031 sample , random , non_sequences = _known_scan_ignored_inputs (makeiter (random ))
@@ -1038,7 +1040,7 @@ def symbolic_single_sample(self, node):
10381040 """
10391041 node = self .to_flat_input (node )
10401042 random = self .symbolic_random .astype (self .symbolic_initial .dtype )
1041- return pytensor . clone_replace (node , {self .input : random [0 ]})
1043+ return graph_replace (node , {self .input : random [0 ]}, strict = False )
10421044
10431045 def make_size_and_deterministic_replacements (self , s , d , more_replacements = None ):
10441046 """*Dev* - creates correct replacements for initial depending on
@@ -1059,8 +1061,15 @@ def make_size_and_deterministic_replacements(self, s, d, more_replacements=None)
10591061 """
10601062 initial = self ._new_initial (s , d , more_replacements )
10611063 initial = pt .specify_shape (initial , self .symbolic_initial .type .shape )
1064+ # The static shape of initial may be more precise than self.symbolic_initial,
1065+ # and reveal previously unknown broadcastable dimensions. We have to mask those again.
1066+ if initial .type .broadcastable != self .symbolic_initial .type .broadcastable :
1067+ unbroadcast_axes = (
1068+ i for i , b in enumerate (self .symbolic_initial .type .broadcastable ) if not b
1069+ )
1070+ initial = unbroadcast (initial , * unbroadcast_axes )
10621071 if more_replacements :
1063- initial = pytensor . clone_replace (initial , more_replacements )
1072+ initial = graph_replace (initial , more_replacements , strict = False )
10641073 return {self .symbolic_initial : initial }
10651074
10661075 @node_property
@@ -1394,17 +1403,17 @@ def set_size_and_deterministic(self, node, s, d, more_replacements=None):
13941403 _node = node
13951404 optimizations = self .get_optimization_replacements (s , d )
13961405 flat2rand = self .make_size_and_deterministic_replacements (s , d , more_replacements )
1397- node = pytensor . clone_replace (node , optimizations )
1398- node = pytensor . clone_replace (node , flat2rand )
1406+ node = graph_replace (node , optimizations , strict = False )
1407+ node = graph_replace (node , flat2rand , strict = False )
13991408 assert not (set (self .symbolic_randoms ) & set (pytensor .graph .graph_inputs (makeiter (node ))))
14001409 try_to_set_test_value (_node , node , s )
14011410 return node
14021411
14031412 def to_flat_input (self , node , more_replacements = None ):
14041413 """*Dev* - replace vars with flattened view stored in `self.inputs`"""
14051414 more_replacements = more_replacements or {}
1406- node = pytensor . clone_replace (node , more_replacements )
1407- return pytensor . clone_replace (node , self .replacements )
1415+ node = graph_replace (node , more_replacements , strict = False )
1416+ return graph_replace (node , self .replacements , strict = False )
14081417
14091418 def symbolic_sample_over_posterior (self , node , more_replacements = None ):
14101419 """*Dev* - performs sampling of node applying independent samples from posterior each time.
@@ -1413,7 +1422,7 @@ def symbolic_sample_over_posterior(self, node, more_replacements=None):
14131422 node = self .to_flat_input (node )
14141423
14151424 def sample (* post ):
1416- return pytensor . clone_replace (node , dict (zip (self .inputs , post )))
1425+ return graph_replace (node , dict (zip (self .inputs , post )), strict = False )
14171426
14181427 nodes , _ = pytensor .scan (
14191428 sample , self .symbolic_randoms , non_sequences = _known_scan_ignored_inputs (makeiter (node ))
@@ -1429,7 +1438,7 @@ def symbolic_single_sample(self, node, more_replacements=None):
14291438 node = self .to_flat_input (node , more_replacements = more_replacements )
14301439 post = [v [0 ] for v in self .symbolic_randoms ]
14311440 inp = self .inputs
1432- return pytensor . clone_replace (node , dict (zip (inp , post )))
1441+ return graph_replace (node , dict (zip (inp , post )), strict = False )
14331442
14341443 def get_optimization_replacements (self , s , d ):
14351444 """*Dev* - optimizations for logP. If sample size is static and equal to 1:
@@ -1463,7 +1472,7 @@ def sample_node(self, node, size=None, deterministic=False, more_replacements=No
14631472 """
14641473 node_in = node
14651474 if more_replacements :
1466- node = pytensor . clone_replace (node , more_replacements )
1475+ node = graph_replace (node , more_replacements , strict = False )
14671476 if not isinstance (node , (list , tuple )):
14681477 node = [node ]
14691478 node = self .model .replace_rvs_by_values (node )
0 commit comments