Skip to content

Commit 0e03123

Browse files
committed
Implement xarray like semantics in dims module
1 parent f4bdc6c commit 0e03123

22 files changed

+1145
-134
lines changed

pymc/data.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@
1313
# limitations under the License.
1414

1515
import io
16+
import typing
1617
import urllib.request
1718

1819
from collections.abc import Sequence
1920
from copy import copy
20-
from typing import cast
21+
from typing import Union, cast
2122

2223
import numpy as np
2324
import pandas as pd
@@ -32,12 +33,13 @@
3233
from pytensor.tensor.random.basic import IntegersRV
3334
from pytensor.tensor.variable import TensorConstant, TensorVariable
3435

35-
import pymc as pm
36-
37-
from pymc.logprob.utils import rvs_in_graph
38-
from pymc.pytensorf import convert_data
36+
from pymc.exceptions import ShapeError
37+
from pymc.pytensorf import convert_data, rvs_in_graph
3938
from pymc.vartypes import isgenerator
4039

40+
if typing.TYPE_CHECKING:
41+
from pymc.model.core import Model
42+
4143
__all__ = [
4244
"Data",
4345
"Minibatch",
@@ -197,7 +199,7 @@ def determine_coords(
197199

198200
if isinstance(value, np.ndarray) and dims is not None:
199201
if len(dims) != value.ndim:
200-
raise pm.exceptions.ShapeError(
202+
raise ShapeError(
201203
"Invalid data shape. The rank of the dataset must match the length of `dims`.",
202204
actual=value.shape,
203205
expected=value.ndim,
@@ -222,6 +224,7 @@ def Data(
222224
dims: Sequence[str] | None = None,
223225
coords: dict[str, Sequence | np.ndarray] | None = None,
224226
infer_dims_and_coords=False,
227+
model: Union["Model", None] = None,
225228
**kwargs,
226229
) -> SharedVariable | TensorConstant:
227230
"""Create a data container that registers a data variable with the model.
@@ -286,15 +289,18 @@ def Data(
286289
... model.set_data("data", data_vals)
287290
... idatas.append(pm.sample())
288291
"""
292+
from pymc.model.core import modelcontext
293+
289294
if coords is None:
290295
coords = {}
291296

292297
if isinstance(value, list):
293298
value = np.array(value)
294299

295300
# Add data container to the named variables of the model.
296-
model = pm.Model.get_context(error_if_none=False)
297-
if model is None:
301+
try:
302+
model = modelcontext(model)
303+
except TypeError:
298304
raise TypeError(
299305
"No model on context stack, which is needed to instantiate a data container. "
300306
"Add variable inside a 'with model:' block."
@@ -321,7 +327,7 @@ def Data(
321327
if isinstance(dims, str):
322328
dims = (dims,)
323329
if not (dims is None or len(dims) == x.ndim):
324-
raise pm.exceptions.ShapeError(
330+
raise ShapeError(
325331
"Length of `dims` must match the dimensions of the dataset.",
326332
actual=len(dims),
327333
expected=x.ndim,

pymc/dims/__init__.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright 2025 - present The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
def __init__():
17+
"""Make PyMC aware of the xtensor functionality.
18+
19+
This should be done eagerly once development matures.
20+
"""
21+
import datetime
22+
import warnings
23+
24+
from pytensor.compile import optdb
25+
26+
from pymc.initial_point import initial_point_rewrites_db
27+
from pymc.logprob.abstract import MeasurableOp
28+
from pymc.logprob.rewriting import logprob_rewrites_db
29+
30+
# Filter PyTensor xtensor warning, we emmit our own warning
31+
with warnings.catch_warnings():
32+
warnings.simplefilter("ignore", UserWarning)
33+
import pytensor.xtensor
34+
35+
from pytensor.xtensor.vectorization import XRV
36+
37+
# Make PyMC aware of xtensor functionality
38+
MeasurableOp.register(XRV)
39+
lower_xtensor_query = optdb.query("+lower_xtensor")
40+
logprob_rewrites_db.register("lower_xtensor", lower_xtensor_query, "basic", position=0.1)
41+
initial_point_rewrites_db.register("lower_xtensor", lower_xtensor_query, "basic", position=0.1)
42+
43+
# TODO: Better model of probability of bugs
44+
day_of_conception = datetime.date(2025, 6, 17)
45+
day_of_last_bug = datetime.date(2025, 6, 20)
46+
today = datetime.date.today()
47+
days_with_bugs = (day_of_last_bug - day_of_conception).days
48+
days_without_bugs = (today - day_of_last_bug).days
49+
p = 1 - (days_without_bugs / (days_without_bugs + days_with_bugs + 10))
50+
if p > 0.05:
51+
warnings.warn(
52+
f"The `pymc.dims` module is experimental and may contain critical bugs (p={p:.3f}).\n"
53+
"Please report any issues you encounter at https://github.com/pymc-devs/pymc/issues.\n"
54+
"Disclaimer: This an experimental API and may change at any time.",
55+
UserWarning,
56+
stacklevel=2,
57+
)
58+
59+
60+
__init__()
61+
del __init__
62+
63+
from pymc.dims import math
64+
from pymc.dims.distributions import *
65+
from pymc.dims.model import Data, with_dims

pymc/dims/distribution_core.py

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
# Copyright 2025 - present The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from collections.abc import Callable, Sequence
15+
from itertools import chain
16+
17+
from pytensor.tensor.elemwise import DimShuffle
18+
from pytensor.xtensor import as_xtensor
19+
from pytensor.xtensor.type import XTensorVariable
20+
21+
from pymc import modelcontext
22+
from pymc.dims.model import with_dims
23+
from pymc.distributions import transforms
24+
from pymc.distributions.distribution import _support_point, support_point
25+
from pymc.distributions.shape_utils import DimsWithEllipsis, convert_dims
26+
from pymc.util import UNSET
27+
28+
29+
@_support_point.register(DimShuffle)
30+
def dimshuffle_support_point(ds_op, _, rv):
31+
# We implement support point for DimShuffle because
32+
# DimDistribution can register a transposed version of a variable.
33+
34+
return ds_op(support_point(rv))
35+
36+
37+
class DimDistribution:
38+
"""Base class for PyMC distribution that wrap pytensor.xtensor.random operations, and follow xarray-like semantics."""
39+
40+
xrv_op: Callable
41+
default_transform: Callable | None = None
42+
43+
@staticmethod
44+
def _as_xtensor(x):
45+
try:
46+
return as_xtensor(x)
47+
except TypeError:
48+
try:
49+
return with_dims(x)
50+
except ValueError:
51+
raise ValueError(
52+
f"Variable {x} must have dims associated with it.\n"
53+
"To avoid subtle bugs, PyMC does not make any assumptions about the dims of the parameters.\n"
54+
"Convert parameters to an xarray.DataArray, pymc.dims.Data or pytensor.xtensor.as_xtensor with explicit dims."
55+
)
56+
57+
def __new__(
58+
cls,
59+
name: str,
60+
*dist_params,
61+
dims: DimsWithEllipsis | None = None,
62+
initval=None,
63+
observed=None,
64+
total_size=None,
65+
transform=UNSET,
66+
default_transform=UNSET,
67+
model=None,
68+
**kwargs,
69+
) -> XTensorVariable:
70+
try:
71+
model = modelcontext(model)
72+
except TypeError:
73+
raise TypeError(
74+
"No model on context stack, which is needed to instantiate distributions. "
75+
"Add variable inside a 'with model:' block, or use the '.dist' syntax for a standalone distribution."
76+
)
77+
78+
if not isinstance(name, str):
79+
raise TypeError(f"Name needs to be a string but got: {name}")
80+
81+
if dims is None:
82+
dims_dict = {}
83+
else:
84+
dims = convert_dims(dims)
85+
try:
86+
dims_dict = {dim: model.dim_lengths[dim] for dim in dims if dim is not Ellipsis}
87+
except KeyError:
88+
raise ValueError(
89+
f"Not all dims {dims} are part of the model coords. "
90+
f"Add them at initialization time or use `model.add_coord` before defining the distribution."
91+
)
92+
93+
if observed is not None:
94+
observed = cls._as_xtensor(observed)
95+
96+
# Propagate observed dims to dims_dict
97+
for observed_dim in observed.type.dims:
98+
if observed_dim not in dims_dict:
99+
dims_dict[observed_dim] = model.dim_lengths[observed_dim]
100+
101+
rv = cls.dist(*dist_params, dims_dict=dims_dict, **kwargs)
102+
103+
# User provided dims must specify all dims or use ellipsis
104+
if dims is not None:
105+
if (... not in dims) and (set(dims) != set(rv.type.dims)):
106+
raise ValueError(
107+
f"Provided dims {dims} do not match the distribution's output dims {rv.type.dims}. "
108+
"Use ellipsis to specify all other dimensions."
109+
)
110+
# Use provided dims to transpose the output to the desired order
111+
rv = rv.transpose(*dims)
112+
113+
rv_dims = rv.type.dims
114+
if observed is None:
115+
if default_transform is UNSET:
116+
default_transform = cls.default_transform
117+
else:
118+
# Align observed dims with those of the RV
119+
# TODO: If this fails give a more informative error message
120+
observed = observed.transpose(*rv_dims).values
121+
122+
rv = model.register_rv(
123+
rv.values,
124+
name=name,
125+
observed=observed,
126+
total_size=total_size,
127+
dims=rv_dims,
128+
transform=transform,
129+
default_transform=default_transform,
130+
initval=initval,
131+
)
132+
133+
return as_xtensor(rv, dims=rv_dims)
134+
135+
@classmethod
136+
def dist(
137+
cls,
138+
dist_params,
139+
*,
140+
dims_dict: dict[str, int] | None = None,
141+
core_dims: str | Sequence[str] | None = None,
142+
**kwargs,
143+
) -> XTensorVariable:
144+
for invalid_kwarg in ("size", "shape", "dims"):
145+
if invalid_kwarg in kwargs:
146+
raise TypeError(f"DimDistribution does not accept {invalid_kwarg} argument.")
147+
148+
# XRV requires only extra_dims, not dims
149+
dist_params = [cls._as_xtensor(param) for param in dist_params]
150+
151+
if dims_dict is None:
152+
extra_dims = None
153+
else:
154+
parameter_implied_dims = set(
155+
chain.from_iterable(param.type.dims for param in dist_params)
156+
)
157+
extra_dims = {
158+
dim: length
159+
for dim, length in dims_dict.items()
160+
if dim not in parameter_implied_dims
161+
}
162+
return cls.xrv_op(*dist_params, extra_dims=extra_dims, core_dims=core_dims, **kwargs)
163+
164+
165+
class MultivariateDimDistribution(DimDistribution):
166+
@classmethod
167+
def dist(self, *args, core_dims: str | Sequence[str] | None = None, **kwargs):
168+
# Add a helpful error message if core_dims is not provided
169+
if core_dims is None:
170+
raise ValueError(
171+
f"{self.__name__} requires core_dims to be specified, as it is a multivariate distribution."
172+
"Check the documentation of the distribution for details."
173+
)
174+
return super().dist(*args, core_dims=core_dims, **kwargs)
175+
176+
177+
class PositiveDimDistribution(DimDistribution):
178+
"""Base class for positive continuous distributions."""
179+
180+
default_transform = transforms.log
181+
182+
183+
class UnitDimDistribution(DimDistribution):
184+
"""Base class for unit-valued distributions."""
185+
186+
default_transform = transforms.logodds

0 commit comments

Comments
 (0)