Skip to content

Commit 3d380e8

Browse files
committed
Tweaks to model_graph to play nice with XTensorVariables
* Use RV Op name when provided * More robust detection of observed data variables (after #7656 arbitrary graphs are allowed) * Remove self loops explicitly (closes #7722)
1 parent d11fae7 commit 3d380e8

File tree

2 files changed

+19
-17
lines changed

2 files changed

+19
-17
lines changed

pymc/model_graph.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@
2121
from typing import Any, cast
2222

2323
from pytensor import function
24-
from pytensor.graph.basic import ancestors, walk
24+
from pytensor.graph.basic import Variable, ancestors, walk
2525
from pytensor.tensor.shape import Shape
26-
from pytensor.tensor.variable import TensorVariable
2726

2827
from pymc.model.core import modelcontext
28+
from pymc.pytensorf import _cheap_eval_mode
2929
from pymc.util import VarName, get_default_varnames, get_var_name
3030

3131
__all__ = (
@@ -73,7 +73,7 @@ def create_plate_label_with_dim_length(
7373

7474

7575
def fast_eval(var):
76-
return function([], var, mode="FAST_COMPILE")()
76+
return function([], var, mode=_cheap_eval_mode)()
7777

7878

7979
class NodeType(str, Enum):
@@ -88,7 +88,7 @@ class NodeType(str, Enum):
8888

8989
@dataclass
9090
class NodeInfo:
91-
var: TensorVariable
91+
var: Variable
9292
node_type: NodeType
9393

9494
def __hash__(self):
@@ -108,10 +108,10 @@ def __eq__(self, other) -> bool:
108108

109109

110110
GraphvizNodeKwargs = dict[str, Any]
111-
NodeFormatter = Callable[[TensorVariable], GraphvizNodeKwargs]
111+
NodeFormatter = Callable[[Variable], GraphvizNodeKwargs]
112112

113113

114-
def default_potential(var: TensorVariable) -> GraphvizNodeKwargs:
114+
def default_potential(var: Variable) -> GraphvizNodeKwargs:
115115
"""Return default data for potential in the graph."""
116116
return {
117117
"shape": "octagon",
@@ -120,17 +120,19 @@ def default_potential(var: TensorVariable) -> GraphvizNodeKwargs:
120120
}
121121

122122

123-
def random_variable_symbol(var: TensorVariable) -> str:
123+
def random_variable_symbol(var: Variable) -> str:
124124
"""Get the symbol of the random variable."""
125-
symbol = var.owner.op.__class__.__name__
125+
op = var.owner.op
126126

127-
if symbol.endswith("RV"):
128-
symbol = symbol[:-2]
127+
if name := getattr(op, "name", None):
128+
symbol = name[0].upper() + name[1:]
129+
else:
130+
symbol = op.__class__.__name__.removesuffix("RV")
129131

130132
return symbol
131133

132134

133-
def default_free_rv(var: TensorVariable) -> GraphvizNodeKwargs:
135+
def default_free_rv(var: Variable) -> GraphvizNodeKwargs:
134136
"""Return default data for free RV in the graph."""
135137
symbol = random_variable_symbol(var)
136138

@@ -141,7 +143,7 @@ def default_free_rv(var: TensorVariable) -> GraphvizNodeKwargs:
141143
}
142144

143145

144-
def default_observed_rv(var: TensorVariable) -> GraphvizNodeKwargs:
146+
def default_observed_rv(var: Variable) -> GraphvizNodeKwargs:
145147
"""Return default data for observed RV in the graph."""
146148
symbol = random_variable_symbol(var)
147149

@@ -152,7 +154,7 @@ def default_observed_rv(var: TensorVariable) -> GraphvizNodeKwargs:
152154
}
153155

154156

155-
def default_deterministic(var: TensorVariable) -> GraphvizNodeKwargs:
157+
def default_deterministic(var: Variable) -> GraphvizNodeKwargs:
156158
"""Return default data for the deterministic in the graph."""
157159
return {
158160
"shape": "box",
@@ -161,7 +163,7 @@ def default_deterministic(var: TensorVariable) -> GraphvizNodeKwargs:
161163
}
162164

163165

164-
def default_data(var: TensorVariable) -> GraphvizNodeKwargs:
166+
def default_data(var: Variable) -> GraphvizNodeKwargs:
165167
"""Return default data for the data in the graph."""
166168
return {
167169
"shape": "box",
@@ -239,7 +241,7 @@ def __init__(self, model):
239241
self._all_vars = {model[var_name] for var_name in self._all_var_names}
240242
self.var_list = self.model.named_vars.values()
241243

242-
def get_parent_names(self, var: TensorVariable) -> set[VarName]:
244+
def get_parent_names(self, var: Variable) -> set[VarName]:
243245
if var.owner is None:
244246
return set()
245247

@@ -345,7 +347,7 @@ def get_plates(
345347
dim_name: fast_eval(value).item() for dim_name, value in self.model.dim_lengths.items()
346348
}
347349
var_shapes: dict[str, tuple[int, ...]] = {
348-
var_name: tuple(fast_eval(self.model[var_name].shape))
350+
var_name: tuple(map(int, fast_eval(self.model[var_name].shape)))
349351
for var_name in self.vars_to_plot(var_names)
350352
}
351353

tests/test_model_graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,7 @@ class TestVariableSelection:
470470
[
471471
(["c"], ["a", "b", "c"], {"c": {"a", "b"}, "a": set(), "b": set()}),
472472
(
473-
["L"],
473+
["L", "obs"],
474474
["pred", "obs", "L", "intermediate", "a", "b"],
475475
{
476476
"pred": {"intermediate"},

0 commit comments

Comments
 (0)