Skip to content
5 changes: 2 additions & 3 deletions pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@
)
from pymc.util import (
UNSET,
VarName,
WithMemoization,
_UnsetType,
get_transformed_name,
Expand Down Expand Up @@ -1945,7 +1944,7 @@ def debug_parameters(rv):
def to_graphviz(
self,
*,
var_names: Iterable[VarName] | None = None,
var_names: Iterable[str] | None = None,
formatting: str = "plain",
save: str | None = None,
figsize: tuple[int, int] | None = None,
Expand Down Expand Up @@ -2149,7 +2148,7 @@ def compile_fn(
)


def Point(*args, filter_model_vars=False, **kwargs) -> dict[VarName, np.ndarray]:
def Point(*args, filter_model_vars=False, **kwargs) -> dict[str, np.ndarray]:
"""Build a point.

Uses same args as dict() does.
Expand Down
36 changes: 17 additions & 19 deletions pymc/model_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from pymc.model.core import modelcontext
from pymc.pytensorf import _cheap_eval_mode
from pymc.util import VarName, get_default_varnames, get_var_name
from pymc.util import get_default_varnames, get_var_name

__all__ = (
"ModelGraph",
Expand Down Expand Up @@ -172,7 +172,7 @@ def default_data(var: Variable) -> GraphvizNodeKwargs:
}


def get_node_type(var_name: VarName, model) -> NodeType:
def get_node_type(var_name: str, model) -> NodeType:
"""Return the node type of the variable in the model."""
v = model[var_name]

Expand Down Expand Up @@ -241,7 +241,7 @@ def __init__(self, model):
self._all_vars = {model[var_name] for var_name in self._all_var_names}
self.var_list = self.model.named_vars.values()

def get_parent_names(self, var: Variable) -> set[VarName]:
def get_parent_names(self, var: Variable) -> set[str]:
if var.owner is None:
return set()

Expand All @@ -260,12 +260,12 @@ def _expand(x):
return x.owner.inputs

return {
cast(VarName, ancestor.name) # type: ignore[union-attr]
cast(str, ancestor.name) # type: ignore[union-attr]
for ancestor in walk(nodes=var.owner.inputs, expand=_expand)
if ancestor in named_vars
}

def vars_to_plot(self, var_names: Iterable[VarName] | None = None) -> list[VarName]:
def vars_to_plot(self, var_names: Iterable[str] | None = None) -> list[str]:
if var_names is None:
return self._all_var_names

Expand Down Expand Up @@ -295,13 +295,11 @@ def vars_to_plot(self, var_names: Iterable[VarName] | None = None) -> list[VarNa
# ordering of self._all_var_names is important
return [get_var_name(var) for var in selected_ancestors]

def make_compute_graph(
self, var_names: Iterable[VarName] | None = None
) -> dict[VarName, set[VarName]]:
def make_compute_graph(self, var_names: Iterable[str] | None = None) -> dict[str, set[str]]:
"""Get map of var_name -> set(input var names) for the model."""
model = self.model
named_vars = self._all_vars
input_map: dict[VarName, set[VarName]] = defaultdict(set)
input_map: dict[str, set[str]] = defaultdict(set)

var_names_to_plot = self.vars_to_plot(var_names)
for var_name in var_names_to_plot:
Expand All @@ -318,15 +316,15 @@ def make_compute_graph(
for ancestor in ancestors([obs_var]):
if ancestor not in named_vars:
continue
obs_name = cast(VarName, ancestor.name)
obs_name = cast(str, ancestor.name)
input_map[var_name].discard(obs_name)
input_map[obs_name].add(var_name)

return input_map

def get_plates(
self,
var_names: Iterable[VarName] | None = None,
var_names: Iterable[str] | None = None,
) -> list[Plate]:
"""Rough but surprisingly accurate plate detection.

Expand All @@ -336,7 +334,7 @@ def get_plates(
Returns
-------
dict
Maps plate labels to the set of ``VarName``s inside the plate.
Maps plate labels to the set of strings inside the plate.
"""
plates = defaultdict(set)

Expand Down Expand Up @@ -388,8 +386,8 @@ def get_plates(

def edges(
self,
var_names: Iterable[VarName] | None = None,
) -> list[tuple[VarName, VarName]]:
var_names: Iterable[str] | None = None,
) -> list[tuple[str, str]]:
"""Get edges between the variables in the model.

Parameters
Expand All @@ -404,7 +402,7 @@ def edges(

"""
return [
(VarName(child.replace(":", "&")), VarName(parent.replace(":", "&")))
(str(child.replace(":", "&")), str(parent.replace(":", "&")))
for child, parents in self.make_compute_graph(var_names=var_names).items()
for parent in parents
]
Expand All @@ -421,7 +419,7 @@ def nodes(self, plates: list[Plate] | None = None) -> list[NodeInfo]:
def make_graph(
name: str,
plates: list[Plate],
edges: list[tuple[VarName, VarName]],
edges: list[tuple[str, str]],
formatting: str = "plain",
save=None,
figsize=None,
Expand Down Expand Up @@ -495,7 +493,7 @@ def make_graph(
def make_networkx(
name: str,
plates: list[Plate],
edges: list[tuple[VarName, VarName]],
edges: list[tuple[str, str]],
formatting: str = "plain",
node_formatters: NodeTypeFormatterMapping | None = None,
create_plate_label: PlateLabelFunc = create_plate_label_with_dim_length,
Expand Down Expand Up @@ -565,7 +563,7 @@ def make_networkx(
def model_to_networkx(
model=None,
*,
var_names: Iterable[VarName] | None = None,
var_names: Iterable[str] | None = None,
formatting: str = "plain",
node_formatters: NodeTypeFormatterMapping | None = None,
include_dim_lengths: bool = True,
Expand Down Expand Up @@ -659,7 +657,7 @@ def model_to_networkx(
def model_to_graphviz(
model=None,
*,
var_names: Iterable[VarName] | None = None,
var_names: Iterable[str] | None = None,
formatting: str = "plain",
save: str | None = None,
figsize: tuple[int, int] | None = None,
Expand Down
8 changes: 3 additions & 5 deletions pymc/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from collections import namedtuple
from collections.abc import Sequence
from copy import deepcopy
from typing import NewType, cast
from typing import cast

import arviz
import cloudpickle
Expand All @@ -31,8 +31,6 @@

from pymc.exceptions import BlockModelAccessError

VarName = NewType("VarName", str)


class _UnsetType:
"""Type for the `UNSET` object to make it look nice in `help(...)` outputs."""
Expand Down Expand Up @@ -214,9 +212,9 @@ def get_default_varnames(var_iterator, include_transformed):
return [var for var in var_iterator if not is_transformed_name(get_var_name(var))]


def get_var_name(var) -> VarName:
def get_var_name(var) -> str:
"""Get an appropriate, plain variable name for a variable."""
return VarName(str(getattr(var, "name", var)))
return var if isinstance(var, str) else str(var.name)


def get_transformed(z):
Expand Down