 from pytensor import function
 from pytensor.graph import Apply
-from pytensor.graph.basic import ancestors, walk
-from pytensor.scalar.basic import Cast
-from pytensor.tensor.elemwise import Elemwise
+from pytensor.graph.basic import Variable, ancestors, walk
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.shape import Shape
 from pytensor.tensor.variable import TensorVariable

 from pymc.model.core import modelcontext
+from pymc.pytensorf import _cheap_eval_mode
 from pymc.util import VarName, get_default_varnames, get_var_name

 __all__ = (
@@ -77,7 +76,7 @@ def create_plate_label_with_dim_length(


 def fast_eval(var):
-    return function([], var, mode="FAST_COMPILE")()
+    return function([], var, mode=_cheap_eval_mode)()


 class NodeType(str, Enum):
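`fast_eval` compiles a zero-input PyTensor function and calls it once; this change only swaps the compilation mode. A minimal usage sketch, using the string mode "FAST_COMPILE" as a stand-in since `_cheap_eval_mode` is a PyMC-internal mode object:

    import pytensor.tensor as pt
    from pytensor import function

    def fast_eval(var):
        # "FAST_COMPILE" stands in for pymc.pytensorf._cheap_eval_mode here
        return function([], var, mode="FAST_COMPILE")()

    x = pt.arange(12).reshape((3, 4))
    print(fast_eval(x.shape))  # [3 4]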
@@ -124,12 +123,14 @@ def default_potential(var: TensorVariable) -> GraphvizNodeKwargs:
     }


-def random_variable_symbol(var: TensorVariable) -> str:
+def random_variable_symbol(var: Variable) -> str:
     """Get the symbol of the random variable."""
-    symbol = var.owner.op.__class__.__name__
+    op = var.owner.op

-    if symbol.endswith("RV"):
-        symbol = symbol[:-2]
+    if name := getattr(op, "name", None):
+        symbol = name[0].upper() + name[1:]
+    else:
+        symbol = op.__class__.__name__.removesuffix("RV")

     return symbol

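The new rule prefers the op's `name` attribute (capitalized) and only falls back to stripping the `RV` suffix from the class name. A standalone sketch of that logic; `FakeOp` is illustrative, not part of the codebase:

    class FakeOp:
        name = "normal"  # RandomVariable ops expose a lowercase distribution name

    op = FakeOp()
    if name := getattr(op, "name", None):
        symbol = name[0].upper() + name[1:]  # "normal" -> "Normal"
    else:
        symbol = op.__class__.__name__.removesuffix("RV")
    print(symbol)  # Normal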
@@ -319,28 +320,21 @@ def make_compute_graph(
             input_map[var_name] = input_map[var_name].union(parent_name)

             if var in self.model.observed_RVs:
-                obs_node = self.model.rvs_to_values[var]
-
-                # loop created so that the elif block can go through this again
-                # and remove any intermediate ops, notably dtype casting, to observations
-                while True:
-                    obs_name = obs_node.name
-                    if obs_name and obs_name != var_name:
+                # Make observed `Data` variables flow from the observed RV, and not the other way around
+                # (In the generative graph they usually inform shape of the observed RV)
+                # We have to iterate over the ancestors of the observed values because there can be
+                # deterministic operations in between the `Data` variable and the observed value.
+                obs_var = self.model.rvs_to_values[var]
+                for ancestor in ancestors([obs_var]):
+                    if (obs_name := ancestor.name) in input_map:
                         input_map[var_name] = input_map[var_name].difference({obs_name})
                         input_map[obs_name] = input_map[obs_name].union({var_name})
-                        break
-                    elif (
-                        # for cases where observations are cast to a certain dtype
-                        # see issue 5795: https://github.com/pymc-devs/pymc/issues/5795
-                        obs_node.owner
-                        and isinstance(obs_node.owner.op, Elemwise)
-                        and isinstance(obs_node.owner.op.scalar_op, Cast)
-                    ):
-                        # we can retrieve the observation node by going up the graph
-                        obs_node = obs_node.owner.inputs[0]
-                    else:
+                        # break assumes observed values can depend on only one `Data` variable
                         break

+        # Remove self references
+        for var_name in input_map:
+            input_map[var_name] = input_map[var_name].difference({var_name})
         return input_map

     def get_plates(
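The replacement relies on `ancestors` visiting every variable the observed value depends on, so a named `Data` variable is still found when a dtype cast (the issue 5795 case handled by the removed `while` loop) sits in between. A small sketch of that traversal, independent of the model machinery; the variable names are made up:

    import pytensor.tensor as pt
    from pytensor.graph.basic import ancestors

    data = pt.vector("obs_data")        # stands in for a named `Data` variable
    obs_value = data.astype("float32")  # deterministic cast between Data and observed value

    named = {a.name for a in ancestors([obs_value]) if a.name is not None}
    print(named)  # {'obs_data'}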
@@ -360,13 +354,13 @@ def get_plates(
         plates = defaultdict(set)

         # TODO: Evaluate all RV shapes at once
-        # This should help find discrepencies, and
+        # This should help find discrepancies, and
         # avoids unnecessary function compiles for determining labels.
         dim_lengths: dict[str, int] = {
             dim_name: fast_eval(value).item() for dim_name, value in self.model.dim_lengths.items()
         }
         var_shapes: dict[str, tuple[int, ...]] = {
-            var_name: tuple(fast_eval(self.model[var_name].shape))
+            var_name: tuple(map(int, fast_eval(self.model[var_name].shape)))
             for var_name in self.vars_to_plot(var_names)
         }

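The `map(int, ...)` wrapper only changes the element type: `fast_eval` returns a NumPy array, so the previous `tuple(...)` held NumPy integer scalars rather than plain Python ints. For example:

    import numpy as np

    shape = np.array([3, 4])   # what fast_eval(...) returns for a shape graph
    tuple(shape)               # (np.int64(3), np.int64(4)) -- NumPy scalars
    tuple(map(int, shape))     # (3, 4) -- plain Python ints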