Skip to content

Commit d10cca7

Browse files
committed
Arviz don't fail hard on incompatible coordinates
1 parent 7f76f23 commit d10cca7

File tree

1 file changed

+31
-1
lines changed

1 file changed

+31
-1
lines changed

pymc/backends/arviz.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,40 @@
4949

5050
_log = logging.getLogger(__name__)
5151

52+
53+
RAISE_ON_INCOMPATIBLE_COORD_LENGTHS = False
54+
55+
5256
# random variable object ...
5357
Var = Any
5458

5559

60+
def dict_to_dataset_drop_incompatible_coords(vars_dict, *args, dims, coords, **kwargs):
61+
safe_coords = coords
62+
63+
if not RAISE_ON_INCOMPATIBLE_COORD_LENGTHS:
64+
coords_lengths = {k: len(v) for k, v in coords.items()}
65+
for var_name, var in vars_dict.items():
66+
# Iterate in reversed because of chain/draw batch dimensions
67+
for dim, dim_length in zip(reversed(dims.get(var_name, ())), reversed(var.shape)):
68+
coord_length = coords_lengths.get(dim, None)
69+
if (coord_length is not None) and (coord_length != dim_length):
70+
warnings.warn(
71+
f"Incompatible coordinate length of {coord_length} found for dimension {dim} of variable {var_name}.\n"
72+
"The originate coordinates for this dim will not be included in the returned dataset for any of the variables."
73+
"Instead they will default to `np.arange(var_length)`.\n"
74+
"To make this warning into an errror set `pymc.backends.arviz.RAISE_ON_INCOMPATIBLE_COORD_LENGTHS` to `True`",
75+
UserWarning,
76+
)
77+
if safe_coords is coords:
78+
safe_coords = coords.copy()
79+
safe_coords.pop(dim)
80+
coords_lengths.pop(dim)
81+
82+
# FIXME: Would be better to drop coordinates altogether, but arviz defaults to `np.arange(var_length)`
83+
return dict_to_dataset(vars_dict, *args, dims=dims, coords=safe_coords, **kwargs)
84+
85+
5686
def find_observations(model: "Model") -> dict[str, Var]:
5787
"""If there are observations available, return them as a dictionary."""
5888
observations = {}
@@ -365,7 +395,7 @@ def priors_to_xarray(self):
365395
priors_dict[group] = (
366396
None
367397
if var_names is None
368-
else dict_to_dataset(
398+
else dict_to_dataset_drop_incompatible_coords(
369399
{k: np.expand_dims(self.prior[k], 0) for k in var_names},
370400
library=pymc,
371401
coords=self.coords,

0 commit comments

Comments
 (0)