Skip to content

Commit 937905c

Browse files
committed
Arviz don't fail hard on incompatible coordinate lengths
1 parent 40977bf commit 937905c

File tree

2 files changed

+57
-1
lines changed

2 files changed

+57
-1
lines changed

pymc/backends/arviz.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,41 @@
4949

5050
_log = logging.getLogger(__name__)
5151

52+
53+
RAISE_ON_INCOMPATIBLE_COORD_LENGTHS = False
54+
55+
5256
# random variable object ...
5357
Var = Any
5458

5559

60+
def dict_to_dataset_drop_incompatible_coords(vars_dict, *args, dims, coords, **kwargs):
61+
safe_coords = coords
62+
63+
if not RAISE_ON_INCOMPATIBLE_COORD_LENGTHS:
64+
coords_lengths = {k: len(v) for k, v in coords.items()}
65+
for var_name, var in vars_dict.items():
66+
# Iterate in reversed because of chain/draw batch dimensions
67+
for dim, dim_length in zip(reversed(dims.get(var_name, ())), reversed(var.shape)):
68+
coord_length = coords_lengths.get(dim, None)
69+
if (coord_length is not None) and (coord_length != dim_length):
70+
warnings.warn(
71+
f"Incompatible coordinate length of {coord_length} for dimension '{dim}' of variable '{var_name}'.\n"
72+
"This usually happens when a sliced or concatenated variable is wrapped as a `pymc.dims.Deterministic`."
73+
"The originate coordinates for this dim will not be included in the returned dataset for any of the variables. "
74+
"Instead they will default to `np.arange(var_length)` and the shorter variables will be right-padded with nan.\n"
75+
"To make this warning into an error set `pymc.backends.arviz.RAISE_ON_INCOMPATIBLE_COORD_LENGTHS` to `True`",
76+
UserWarning,
77+
)
78+
if safe_coords is coords:
79+
safe_coords = coords.copy()
80+
safe_coords.pop(dim)
81+
coords_lengths.pop(dim)
82+
83+
# FIXME: Would be better to drop coordinates altogether, but arviz defaults to `np.arange(var_length)`
84+
return dict_to_dataset(vars_dict, *args, dims=dims, coords=safe_coords, **kwargs)
85+
86+
5687
def find_observations(model: "Model") -> dict[str, Var]:
5788
"""If there are observations available, return them as a dictionary."""
5889
observations = {}
@@ -365,7 +396,7 @@ def priors_to_xarray(self):
365396
priors_dict[group] = (
366397
None
367398
if var_names is None
368-
else dict_to_dataset(
399+
else dict_to_dataset_drop_incompatible_coords(
369400
{k: np.expand_dims(self.prior[k], 0) for k in var_names},
370401
library=pymc,
371402
coords=self.coords,

tests/backends/test_arviz.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import re
1415
import warnings
1516

1617
import numpy as np
@@ -837,3 +838,27 @@ def test_dataset_to_point_list_str_key(self):
837838
ds[3] = xarray.DataArray([1, 2, 3])
838839
with pytest.raises(ValueError, match="must be str"):
839840
dataset_to_point_list(ds, sample_dims=["chain", "draw"])
841+
842+
843+
def test_incompatible_coordinate_lengths():
844+
with pm.Model(coords={"a": [-1, -2, -3]}) as m:
845+
x = pm.Normal("x", dims="a")
846+
y = pm.Deterministic("y", x[1:], dims=("a",))
847+
848+
with pytest.warns(
849+
UserWarning,
850+
match=re.escape(
851+
"Incompatible coordinate length of 3 for dimension 'a' of variable 'y'"
852+
),
853+
):
854+
prior = pm.sample_prior_predictive(draws=1).prior.squeeze(("chain", "draw"))
855+
assert prior.x.dims == prior.y.dims == ("a",)
856+
assert prior.x.shape == prior.y.shape == (3,)
857+
assert np.isnan(prior.y.values[-1])
858+
assert list(prior.coords["a"]) == [0, 1, 2]
859+
860+
pm.backends.arviz.RAISE_ON_INCOMPATIBLE_COORD_LENGTHS = True
861+
with pytest.raises(ValueError):
862+
pm.sample_prior_predictive(draws=1)
863+
864+
pm.backends.arviz.RAISE_ON_INCOMPATIBLE_COORD_LENGTHS = False

0 commit comments

Comments
 (0)