21
21
from typing import Any , cast
22
22
23
23
from pytensor import function
24
- from pytensor .graph .basic import ancestors , walk
24
+ from pytensor .graph .basic import Variable , ancestors , walk
25
25
from pytensor .tensor .shape import Shape
26
- from pytensor .tensor .variable import TensorVariable
27
26
28
27
from pymc .model .core import modelcontext
28
+ from pymc .pytensorf import _cheap_eval_mode
29
29
from pymc .util import VarName , get_default_varnames , get_var_name
30
30
31
31
__all__ = (
@@ -73,7 +73,7 @@ def create_plate_label_with_dim_length(
73
73
74
74
75
75
def fast_eval (var ):
76
- return function ([], var , mode = "FAST_COMPILE" )()
76
+ return function ([], var , mode = _cheap_eval_mode )()
77
77
78
78
79
79
class NodeType (str , Enum ):
@@ -88,7 +88,7 @@ class NodeType(str, Enum):
88
88
89
89
@dataclass
90
90
class NodeInfo :
91
- var : TensorVariable
91
+ var : Variable
92
92
node_type : NodeType
93
93
94
94
def __hash__ (self ):
@@ -108,10 +108,10 @@ def __eq__(self, other) -> bool:
108
108
109
109
110
110
GraphvizNodeKwargs = dict [str , Any ]
111
- NodeFormatter = Callable [[TensorVariable ], GraphvizNodeKwargs ]
111
+ NodeFormatter = Callable [[Variable ], GraphvizNodeKwargs ]
112
112
113
113
114
- def default_potential (var : TensorVariable ) -> GraphvizNodeKwargs :
114
+ def default_potential (var : Variable ) -> GraphvizNodeKwargs :
115
115
"""Return default data for potential in the graph."""
116
116
return {
117
117
"shape" : "octagon" ,
@@ -120,17 +120,19 @@ def default_potential(var: TensorVariable) -> GraphvizNodeKwargs:
120
120
}
121
121
122
122
123
- def random_variable_symbol (var : TensorVariable ) -> str :
123
+ def random_variable_symbol (var : Variable ) -> str :
124
124
"""Get the symbol of the random variable."""
125
- symbol = var .owner .op . __class__ . __name__
125
+ op = var .owner .op
126
126
127
- if symbol .endswith ("RV" ):
128
- symbol = symbol [:- 2 ]
127
+ if name := getattr (op , "name" , None ):
128
+ symbol = name [0 ].upper () + name [1 :]
129
+ else :
130
+ symbol = op .__class__ .__name__ .removesuffix ("RV" )
129
131
130
132
return symbol
131
133
132
134
133
- def default_free_rv (var : TensorVariable ) -> GraphvizNodeKwargs :
135
+ def default_free_rv (var : Variable ) -> GraphvizNodeKwargs :
134
136
"""Return default data for free RV in the graph."""
135
137
symbol = random_variable_symbol (var )
136
138
@@ -141,7 +143,7 @@ def default_free_rv(var: TensorVariable) -> GraphvizNodeKwargs:
141
143
}
142
144
143
145
144
- def default_observed_rv (var : TensorVariable ) -> GraphvizNodeKwargs :
146
+ def default_observed_rv (var : Variable ) -> GraphvizNodeKwargs :
145
147
"""Return default data for observed RV in the graph."""
146
148
symbol = random_variable_symbol (var )
147
149
@@ -152,7 +154,7 @@ def default_observed_rv(var: TensorVariable) -> GraphvizNodeKwargs:
152
154
}
153
155
154
156
155
- def default_deterministic (var : TensorVariable ) -> GraphvizNodeKwargs :
157
+ def default_deterministic (var : Variable ) -> GraphvizNodeKwargs :
156
158
"""Return default data for the deterministic in the graph."""
157
159
return {
158
160
"shape" : "box" ,
@@ -161,7 +163,7 @@ def default_deterministic(var: TensorVariable) -> GraphvizNodeKwargs:
161
163
}
162
164
163
165
164
- def default_data (var : TensorVariable ) -> GraphvizNodeKwargs :
166
+ def default_data (var : Variable ) -> GraphvizNodeKwargs :
165
167
"""Return default data for the data in the graph."""
166
168
return {
167
169
"shape" : "box" ,
@@ -239,7 +241,7 @@ def __init__(self, model):
239
241
self ._all_vars = {model [var_name ] for var_name in self ._all_var_names }
240
242
self .var_list = self .model .named_vars .values ()
241
243
242
- def get_parent_names (self , var : TensorVariable ) -> set [VarName ]:
244
+ def get_parent_names (self , var : Variable ) -> set [VarName ]:
243
245
if var .owner is None :
244
246
return set ()
245
247
@@ -345,7 +347,7 @@ def get_plates(
345
347
dim_name : fast_eval (value ).item () for dim_name , value in self .model .dim_lengths .items ()
346
348
}
347
349
var_shapes : dict [str , tuple [int , ...]] = {
348
- var_name : tuple (fast_eval (self .model [var_name ].shape ))
350
+ var_name : tuple (map ( int , fast_eval (self .model [var_name ].shape ) ))
349
351
for var_name in self .vars_to_plot (var_names )
350
352
}
351
353
0 commit comments