|
49 | 49 |
|
50 | 50 | _log = logging.getLogger(__name__)
|
51 | 51 |
|
| 52 | + |
| 53 | +RAISE_ON_INCOMPATIBLE_COORD_LENGTHS = False |
| 54 | + |
| 55 | + |
52 | 56 | # random variable object ...
|
53 | 57 | Var = Any
|
54 | 58 |
|
55 | 59 |
|
| 60 | +def dict_to_dataset_drop_incompatible_coords(vars_dict, *args, dims, coords, **kwargs): |
| 61 | + safe_coords = coords |
| 62 | + |
| 63 | + if not RAISE_ON_INCOMPATIBLE_COORD_LENGTHS: |
| 64 | + coords_lengths = {k: len(v) for k, v in coords.items()} |
| 65 | + for var_name, var in vars_dict.items(): |
| 66 | + # Iterate in reversed because of chain/draw batch dimensions |
| 67 | + for dim, dim_length in zip(reversed(dims.get(var_name, ())), reversed(var.shape)): |
| 68 | + coord_length = coords_lengths.get(dim, None) |
| 69 | + if (coord_length is not None) and (coord_length != dim_length): |
| 70 | + warnings.warn( |
| 71 | + f"Incompatible coordinate length of {coord_length} for dimension '{dim}' of variable '{var_name}'.\n" |
| 72 | + "This usually happens when a sliced or concatenated variable is wrapped as a `pymc.dims.Deterministic`." |
| 73 | + "The originate coordinates for this dim will not be included in the returned dataset for any of the variables. " |
| 74 | + "Instead they will default to `np.arange(var_length)` and the shorter variables will be right-padded with nan.\n" |
| 75 | + "To make this warning into an error set `pymc.backends.arviz.RAISE_ON_INCOMPATIBLE_COORD_LENGTHS` to `True`", |
| 76 | + UserWarning, |
| 77 | + ) |
| 78 | + if safe_coords is coords: |
| 79 | + safe_coords = coords.copy() |
| 80 | + safe_coords.pop(dim) |
| 81 | + coords_lengths.pop(dim) |
| 82 | + |
| 83 | + # FIXME: Would be better to drop coordinates altogether, but arviz defaults to `np.arange(var_length)` |
| 84 | + return dict_to_dataset(vars_dict, *args, dims=dims, coords=safe_coords, **kwargs) |
| 85 | + |
| 86 | + |
56 | 87 | def find_observations(model: "Model") -> dict[str, Var]:
|
57 | 88 | """If there are observations available, return them as a dictionary."""
|
58 | 89 | observations = {}
|
@@ -365,7 +396,7 @@ def priors_to_xarray(self):
|
365 | 396 | priors_dict[group] = (
|
366 | 397 | None
|
367 | 398 | if var_names is None
|
368 |
| - else dict_to_dataset( |
| 399 | + else dict_to_dataset_drop_incompatible_coords( |
369 | 400 | {k: np.expand_dims(self.prior[k], 0) for k in var_names},
|
370 | 401 | library=pymc,
|
371 | 402 | coords=self.coords,
|
|
0 commit comments