Skip to content

Commit d95288d

Browse files
committed
Tweaks to model_graph to play nice with XTensorVariables
* Use RV Op name when provided
* More robust detection of observed data variables (after #7656, arbitrary graphs are allowed)
* Remove self-loops explicitly (closes #7722)
1 parent 6d29c79 commit d95288d

File tree

1 file changed

+22
-28
lines changed

1 file changed

+22
-28
lines changed

pymc/model_graph.py

Lines changed: 22 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,13 @@
2222

2323
from pytensor import function
2424
from pytensor.graph import Apply
25-
from pytensor.graph.basic import ancestors, walk
26-
from pytensor.scalar.basic import Cast
27-
from pytensor.tensor.elemwise import Elemwise
25+
from pytensor.graph.basic import Variable, ancestors, walk
2826
from pytensor.tensor.random.op import RandomVariable
2927
from pytensor.tensor.shape import Shape
3028
from pytensor.tensor.variable import TensorVariable
3129

3230
from pymc.model.core import modelcontext
31+
from pymc.pytensorf import _cheap_eval_mode
3332
from pymc.util import VarName, get_default_varnames, get_var_name
3433

3534
__all__ = (
@@ -77,7 +76,7 @@ def create_plate_label_with_dim_length(
7776

7877

7978
def fast_eval(var):
    """Evaluate a constant PyTensor variable by compiling a zero-argument function.

    Uses ``_cheap_eval_mode`` (a lightweight compile mode from ``pymc.pytensorf``)
    instead of the heavier ``"FAST_COMPILE"`` mode, since these evaluations are
    only used to resolve static shapes and dim lengths for graph labeling.

    Parameters
    ----------
    var : Variable
        A PyTensor variable with no required inputs.

    Returns
    -------
    The computed value (typically a numpy array or scalar).
    """
    return function([], var, mode=_cheap_eval_mode)()
8180

8281

8382
class NodeType(str, Enum):
@@ -124,12 +123,14 @@ def default_potential(var: TensorVariable) -> GraphvizNodeKwargs:
124123
}
125124

126125

127-
def random_variable_symbol(var: Variable) -> str:
    """Get the display symbol for a random variable's Op.

    Prefers the Op's own ``name`` attribute when it provides one (capitalized
    for display); otherwise falls back to the Op's class name with any
    trailing ``"RV"`` suffix stripped (e.g. ``NormalRV`` -> ``Normal``).

    Parameters
    ----------
    var : Variable
        A variable that is the output of an Op (``var.owner`` must not be None).

    Returns
    -------
    str
        Human-readable symbol used to label the node in the model graph.
    """
    op = var.owner.op

    # Some Ops (e.g. XTensor-wrapped RVs) expose a lowercase `name`;
    # capitalize only the first character so camelCase names are preserved.
    if name := getattr(op, "name", None):
        symbol = name[0].upper() + name[1:]
    else:
        symbol = op.__class__.__name__.removesuffix("RV")

    return symbol
135136

@@ -319,28 +320,21 @@ def make_compute_graph(
319320
input_map[var_name] = input_map[var_name].union(parent_name)
320321

321322
if var in self.model.observed_RVs:
322-
obs_node = self.model.rvs_to_values[var]
323-
324-
# loop created so that the elif block can go through this again
325-
# and remove any intermediate ops, notably dtype casting, to observations
326-
while True:
327-
obs_name = obs_node.name
328-
if obs_name and obs_name != var_name:
323+
# Make observed `Data` variables flow from the observed RV, and not the other way around
324+
# (In the generative graph they usually inform shape of the observed RV)
325+
# We have to iterate over the ancestors of the observed values because there can be
326+
# deterministic operations in between the `Data` variable and the observed value.
327+
obs_var = self.model.rvs_to_values[var]
328+
for ancestor in ancestors([obs_var]):
329+
if (obs_name := ancestor.name) in input_map:
329330
input_map[var_name] = input_map[var_name].difference({obs_name})
330331
input_map[obs_name] = input_map[obs_name].union({var_name})
331-
break
332-
elif (
333-
# for cases where observations are cast to a certain dtype
334-
# see issue 5795: https://github.com/pymc-devs/pymc/issues/5795
335-
obs_node.owner
336-
and isinstance(obs_node.owner.op, Elemwise)
337-
and isinstance(obs_node.owner.op.scalar_op, Cast)
338-
):
339-
# we can retrieve the observation node by going up the graph
340-
obs_node = obs_node.owner.inputs[0]
341-
else:
332+
# break assumes observed values can depend on only one `Data` variable
342333
break
343334

335+
# Remove self references
336+
for var_name in input_map:
337+
input_map[var_name] = input_map[var_name].difference({var_name})
344338
return input_map
345339

346340
def get_plates(
@@ -360,13 +354,13 @@ def get_plates(
360354
plates = defaultdict(set)
361355

362356
# TODO: Evaluate all RV shapes at once
363-
# This should help find discrepencies, and
357+
# This should help find discrepancies, and
364358
# avoids unnecessary function compiles for determining labels.
365359
dim_lengths: dict[str, int] = {
366360
dim_name: fast_eval(value).item() for dim_name, value in self.model.dim_lengths.items()
367361
}
368362
var_shapes: dict[str, tuple[int, ...]] = {
369-
var_name: tuple(fast_eval(self.model[var_name].shape))
363+
var_name: tuple(map(int, fast_eval(self.model[var_name].shape)))
370364
for var_name in self.vars_to_plot(var_names)
371365
}
372366

0 commit comments

Comments
 (0)