From 97cd9e17942599ae2e386c5dae870842e5a700a0 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Sun, 29 Jun 2025 23:54:18 +0200 Subject: [PATCH 01/10] Convert mu to TensorVariable in Normal Also remove commented out code --- pymc/distributions/continuous.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index ca4fece1ba..fb4fd5eb3a 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -487,11 +487,7 @@ class Normal(Continuous): def dist(cls, mu=0, sigma=None, tau=None, **kwargs): tau, sigma = get_tau_sigma(tau=tau, sigma=sigma) sigma = pt.as_tensor_variable(sigma) - - # tau = pt.as_tensor_variable(tau) - # mean = median = mode = mu = pt.as_tensor_variable(floatX(mu)) - # variance = 1.0 / self.tau - + mu = pt.as_tensor_variable(mu) return super().dist([mu, sigma], **kwargs) def support_point(rv, size, mu, sigma): From d82330e59dd20b01faea6f08dfa69d6209f9251c Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 3 Jul 2025 20:27:26 +0200 Subject: [PATCH 02/10] Bump PyTensor dependency --- conda-envs/environment-alternative-backends.yml | 2 +- conda-envs/environment-dev.yml | 2 +- conda-envs/environment-docs.yml | 2 +- conda-envs/environment-test.yml | 2 +- conda-envs/windows-environment-dev.yml | 2 +- conda-envs/windows-environment-test.yml | 2 +- requirements-dev.txt | 2 +- requirements.txt | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/conda-envs/environment-alternative-backends.yml b/conda-envs/environment-alternative-backends.yml index fcf78c6991..d2cfd439a5 100644 --- a/conda-envs/environment-alternative-backends.yml +++ b/conda-envs/environment-alternative-backends.yml @@ -22,7 +22,7 @@ dependencies: - numpyro>=0.8.0 - pandas>=0.24.0 - pip -- pytensor>=2.31.2,<2.32 +- pytensor>=2.31.7,<2.32 - python-graphviz - networkx - rich>=13.7.1 diff --git a/conda-envs/environment-dev.yml b/conda-envs/environment-dev.yml index a49e0568ae..a1545e60d3 100644 --- a/conda-envs/environment-dev.yml +++ b/conda-envs/environment-dev.yml @@ -12,7 +12,7 @@ dependencies: - numpy>=1.25.0 - pandas>=0.24.0 - pip -- pytensor>=2.31.2,<2.32 +- pytensor>=2.31.7,<2.32 - python-graphviz - networkx - scipy>=1.4.1 diff --git a/conda-envs/environment-docs.yml b/conda-envs/environment-docs.yml index 67e6673af2..fc0541f46f 100644 --- a/conda-envs/environment-docs.yml +++ b/conda-envs/environment-docs.yml @@ -11,7 +11,7 @@ dependencies: - numpy>=1.25.0 - pandas>=0.24.0 - pip -- pytensor>=2.31.2,<2.32 +- pytensor>=2.31.7,<2.32 - python-graphviz - rich>=13.7.1 - scipy>=1.4.1 diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index 2230d08e77..d2978f581b 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -14,7 +14,7 @@ dependencies: - pandas>=0.24.0 - pip - polyagamma -- pytensor>=2.31.2,<2.32 +- pytensor>=2.31.7,<2.32 - python-graphviz - networkx - rich>=13.7.1 diff --git a/conda-envs/windows-environment-dev.yml b/conda-envs/windows-environment-dev.yml index daf0398147..58547c4d49 100644 --- a/conda-envs/windows-environment-dev.yml +++ b/conda-envs/windows-environment-dev.yml @@ -12,7 +12,7 @@ dependencies: - numpy>=1.25.0 - pandas>=0.24.0 - pip -- pytensor>=2.31.2,<2.32 +- pytensor>=2.31.7,<2.32 - python-graphviz - networkx - rich>=13.7.1 diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml index 7c9c28b70f..b9a1a36a92 100644 --- a/conda-envs/windows-environment-test.yml +++ 
b/conda-envs/windows-environment-test.yml @@ -15,7 +15,7 @@ dependencies: - pandas>=0.24.0 - pip - polyagamma -- pytensor>=2.31.2,<2.32 +- pytensor>=2.31.7,<2.32 - python-graphviz - networkx - rich>=13.7.1 diff --git a/requirements-dev.txt b/requirements-dev.txt index cabfc740af..fd00bfd03d 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -16,7 +16,7 @@ pandas>=0.24.0 polyagamma pre-commit>=2.8.0 pymc-sphinx-theme>=0.16.0 -pytensor>=2.31.2,<2.32 +pytensor>=2.31.7,<2.32 pytest-cov>=2.5 pytest>=3.0 rich>=13.7.1 diff --git a/requirements.txt b/requirements.txt index c1ca979dd1..621303906c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ cachetools>=4.2.1 cloudpickle numpy>=1.25.0 pandas>=0.24.0 -pytensor>=2.31.2,<2.32 +pytensor>=2.31.7,<2.32 rich>=13.7.1 scipy>=1.4.1 threadpoolctl>=3.1.0,<4.0.0 From 80a9fef6ed232443b5ec2a8e4e0d8d36779a36fb Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 14 Jul 2025 12:55:38 +0200 Subject: [PATCH 03/10] Implement xarray like semantics in dims module --- .github/workflows/tests.yml | 6 + pymc/data.py | 24 ++- pymc/dims/__init__.py | 71 ++++++++ pymc/dims/distributions/__init__.py | 15 ++ pymc/dims/distributions/core.py | 191 ++++++++++++++++++++ pymc/dims/distributions/scalar.py | 174 ++++++++++++++++++ pymc/dims/distributions/vector.py | 116 ++++++++++++ pymc/dims/math.py | 15 ++ pymc/dims/model.py | 128 +++++++++++++ pymc/distributions/continuous.py | 16 +- pymc/distributions/distribution.py | 9 +- pymc/distributions/multivariate.py | 4 +- pymc/distributions/shape_utils.py | 24 ++- pymc/initial_point.py | 52 +++--- pymc/logprob/basic.py | 15 +- pymc/logprob/rewriting.py | 26 ++- pymc/logprob/utils.py | 38 ++-- pymc/math.py | 4 + pymc/model/transform/conditioning.py | 3 +- pymc/pytensorf.py | 112 +++++++----- pymc/sampling/forward.py | 3 +- pymc/testing.py | 35 +++- pyproject.toml | 9 + tests/dims/__init__.py | 13 ++ tests/dims/distributions/__init__.py | 13 ++ tests/dims/distributions/test_core.py | 191 ++++++++++++++++++++ tests/dims/distributions/test_scalar.py | 217 +++++++++++++++++++++++ tests/dims/distributions/test_vector.py | 62 +++++++ tests/dims/test_model.py | 174 ++++++++++++++++++ tests/dims/utils.py | 64 +++++++ tests/distributions/test_continuous.py | 3 +- tests/distributions/test_distribution.py | 2 +- tests/distributions/test_multivariate.py | 3 +- tests/distributions/test_shape_utils.py | 9 + tests/sampling/test_forward.py | 3 +- tests/test_initial_point.py | 15 +- tests/test_pytensorf.py | 2 +- 37 files changed, 1711 insertions(+), 150 deletions(-) create mode 100644 pymc/dims/__init__.py create mode 100644 pymc/dims/distributions/__init__.py create mode 100644 pymc/dims/distributions/core.py create mode 100644 pymc/dims/distributions/scalar.py create mode 100644 pymc/dims/distributions/vector.py create mode 100644 pymc/dims/math.py create mode 100644 pymc/dims/model.py create mode 100644 tests/dims/__init__.py create mode 100644 tests/dims/distributions/__init__.py create mode 100644 tests/dims/distributions/test_core.py create mode 100644 tests/dims/distributions/test_scalar.py create mode 100644 tests/dims/distributions/test_vector.py create mode 100644 tests/dims/test_model.py create mode 100644 tests/dims/utils.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d9b4b000c7..4ef5add464 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -135,6 +135,12 @@ jobs: tests/logprob/test_transforms.py tests/logprob/test_utils.py + - | + 
tests/dims/distributions/test_core.py + tests/dims/distributions/test_scalar.py + tests/dims/distributions/test_vector.py + tests/dims/test_model.py + fail-fast: false runs-on: ${{ matrix.os }} env: diff --git a/pymc/data.py b/pymc/data.py index 507f547e5b..cfade37910 100644 --- a/pymc/data.py +++ b/pymc/data.py @@ -13,11 +13,12 @@ # limitations under the License. import io +import typing import urllib.request from collections.abc import Sequence from copy import copy -from typing import cast +from typing import Union, cast import numpy as np import pandas as pd @@ -32,12 +33,13 @@ from pytensor.tensor.random.basic import IntegersRV from pytensor.tensor.variable import TensorConstant, TensorVariable -import pymc as pm - -from pymc.logprob.utils import rvs_in_graph -from pymc.pytensorf import convert_data +from pymc.exceptions import ShapeError +from pymc.pytensorf import convert_data, rvs_in_graph from pymc.vartypes import isgenerator +if typing.TYPE_CHECKING: + from pymc.model.core import Model + __all__ = [ "Data", "Minibatch", @@ -197,7 +199,7 @@ def determine_coords( if isinstance(value, np.ndarray) and dims is not None: if len(dims) != value.ndim: - raise pm.exceptions.ShapeError( + raise ShapeError( "Invalid data shape. The rank of the dataset must match the length of `dims`.", actual=value.shape, expected=value.ndim, @@ -222,6 +224,7 @@ def Data( dims: Sequence[str] | None = None, coords: dict[str, Sequence | np.ndarray] | None = None, infer_dims_and_coords=False, + model: Union["Model", None] = None, **kwargs, ) -> SharedVariable | TensorConstant: """Create a data container that registers a data variable with the model. @@ -286,6 +289,8 @@ def Data( ... model.set_data("data", data_vals) ... idatas.append(pm.sample()) """ + from pymc.model.core import modelcontext + if coords is None: coords = {} @@ -293,8 +298,9 @@ def Data( value = np.array(value) # Add data container to the named variables of the model. - model = pm.Model.get_context(error_if_none=False) - if model is None: + try: + model = modelcontext(model) + except TypeError: raise TypeError( "No model on context stack, which is needed to instantiate a data container. " "Add variable inside a 'with model:' block." @@ -321,7 +327,7 @@ def Data( if isinstance(dims, str): dims = (dims,) if not (dims is None or len(dims) == x.ndim): - raise pm.exceptions.ShapeError( + raise ShapeError( "Length of `dims` must match the dimensions of the dataset.", actual=len(dims), expected=x.ndim, diff --git a/pymc/dims/__init__.py b/pymc/dims/__init__.py new file mode 100644 index 0000000000..d1a050638c --- /dev/null +++ b/pymc/dims/__init__.py @@ -0,0 +1,71 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def __init__(): + """Make PyMC aware of the xtensor functionality. + + This should be done eagerly once development matures. 
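+
+    Importing ``pymc.dims`` triggers the call at the bottom of this module,
+    which registers PyTensor's xtensor machinery with PyMC.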
+    """
+    import datetime
+    import warnings
+
+    from pytensor.compile import optdb
+
+    from pymc.initial_point import initial_point_rewrites_db
+    from pymc.logprob.abstract import MeasurableOp
+    from pymc.logprob.rewriting import logprob_rewrites_db
+
+    # Filter PyTensor xtensor warning, we emit our own warning
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore", UserWarning)
+        import pytensor.xtensor
+
+        from pytensor.xtensor.vectorization import XRV
+
+    # Make PyMC aware of xtensor functionality
+    MeasurableOp.register(XRV)
+    logprob_rewrites_db.register(
+        "pre_lower_xtensor", optdb.query("+lower_xtensor"), "basic", position=0.1
+    )
+    initial_point_rewrites_db.register(
+        "lower_xtensor", optdb.query("+lower_xtensor"), "basic", position=0.1
+    )
+
+    # TODO: Better model of probability of bugs
+    day_of_conception = datetime.date(2025, 6, 17)
+    day_of_last_bug = datetime.date(2025, 6, 30)
+    today = datetime.date.today()
+    days_with_bugs = (day_of_last_bug - day_of_conception).days
+    days_without_bugs = (today - day_of_last_bug).days
+    p = 1 - (days_without_bugs / (days_without_bugs + days_with_bugs + 10))
+    if p > 0.05:
+        warnings.warn(
+            f"The `pymc.dims` module is experimental and may contain critical bugs (p={p:.3f}).\n"
+            "Please report any issues you encounter at https://github.com/pymc-devs/pymc/issues.\n"
+            "Disclaimer: This is an experimental API and may change at any time.",
+            UserWarning,
+            stacklevel=2,
+        )
+
+
+__init__()
+del __init__
+
+from pytensor.xtensor import as_xtensor, broadcast, concat, dot, full_like, ones_like, zeros_like
+from pytensor.xtensor.basic import tensor_from_xtensor
+
+from pymc.dims import math
+from pymc.dims.distributions import *
+from pymc.dims.model import Data, Deterministic, Potential
diff --git a/pymc/dims/distributions/__init__.py b/pymc/dims/distributions/__init__.py
new file mode 100644
index 0000000000..da85bc1463
--- /dev/null
+++ b/pymc/dims/distributions/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2025 - present The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pymc.dims.distributions.scalar import *
+from pymc.dims.distributions.vector import *
diff --git a/pymc/dims/distributions/core.py b/pymc/dims/distributions/core.py
new file mode 100644
index 0000000000..bd48db0ec5
--- /dev/null
+++ b/pymc/dims/distributions/core.py
@@ -0,0 +1,191 @@
+# Copyright 2025 - present The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
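+#
+# A minimal usage sketch of the dims API defined below (the model, coords and
+# variable names are illustrative, not part of this module):
+#
+#     import pymc as pm
+#     import pymc.dims as pmd
+#
+#     with pm.Model(coords={"group": ["a", "b", "c"], "obs": range(100)}):
+#         mu = pmd.Normal("mu", dims=("group",))
+#         sigma = pmd.HalfNormal("sigma")
+#         y = pmd.Normal("y", mu=mu, sigma=sigma, dims=("obs", "group"))
+#
+# Parameters broadcast by dim name (xarray semantics), so `mu` aligns with the
+# "group" dim of `y` regardless of axis order.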
+from collections.abc import Callable, Sequence
+from itertools import chain
+
+from pytensor.graph.basic import Variable
+from pytensor.tensor.elemwise import DimShuffle
+from pytensor.xtensor import as_xtensor
+from pytensor.xtensor.type import XTensorVariable
+
+from pymc import modelcontext
+from pymc.dims.model import with_dims
+from pymc.distributions import transforms
+from pymc.distributions.distribution import _support_point, support_point
+from pymc.distributions.shape_utils import DimsWithEllipsis, convert_dims_with_ellipsis
+from pymc.logprob.transforms import Transform
+from pymc.util import UNSET
+
+
+@_support_point.register(DimShuffle)
+def dimshuffle_support_point(ds_op, _, rv):
+    # We implement support point for DimShuffle because
+    # DimDistribution can register a transposed version of a variable.
+
+    return ds_op(support_point(rv))
+
+
+class DimDistribution:
+    """Base class for PyMC distributions that wrap pytensor.xtensor.random operations and follow xarray-like semantics."""
+
+    xrv_op: Callable
+    default_transform: Transform | None = None
+
+    @staticmethod
+    def _as_xtensor(x):
+        try:
+            return as_xtensor(x)
+        except TypeError:
+            try:
+                return with_dims(x)
+            except ValueError:
+                raise ValueError(
+                    f"Variable {x} must have dims associated with it.\n"
+                    "To avoid subtle bugs, PyMC does not make any assumptions about the dims of parameters.\n"
+                    "Use `as_xtensor` with the `dims` keyword argument to specify the dims explicitly."
+                )
+
+    def __new__(
+        cls,
+        name: str,
+        *dist_params,
+        dims: DimsWithEllipsis | None = None,
+        initval=None,
+        observed=None,
+        total_size=None,
+        transform=UNSET,
+        default_transform=UNSET,
+        model=None,
+        **kwargs,
+    ):
+        try:
+            model = modelcontext(model)
+        except TypeError:
+            raise TypeError(
+                "No model on context stack, which is needed to instantiate distributions. "
+                "Add variable inside a 'with model:' block, or use the '.dist' syntax for a standalone distribution."
+            )
+
+        if not isinstance(name, str):
+            raise TypeError(f"Name needs to be a string but got: {name}")
+
+        dims = convert_dims_with_ellipsis(dims)
+        if dims is None:
+            dim_lengths = {}
+        else:
+            try:
+                dim_lengths = {dim: model.dim_lengths[dim] for dim in dims if dim is not Ellipsis}
+            except KeyError:
+                raise ValueError(
+                    f"Not all dims {dims} are part of the model coords. "
+                    f"Add them at initialization time or use `model.add_coord` before defining the distribution."
+                )
+
+        if observed is not None:
+            observed = cls._as_xtensor(observed)
+
+            # Propagate observed dims to dim_lengths
+            for observed_dim in observed.type.dims:
+                if observed_dim not in dim_lengths:
+                    dim_lengths[observed_dim] = model.dim_lengths[observed_dim]
+
+        rv = cls.dist(*dist_params, dim_lengths=dim_lengths, **kwargs)
+
+        # User provided dims must specify all dims or use ellipsis
+        if dims is not None:
+            if (... not in dims) and (set(dims) != set(rv.type.dims)):
+                raise ValueError(
+                    f"Provided dims {dims} do not match the distribution's output dims {rv.type.dims}. "
+                    "Use ellipsis to specify all other dimensions."
+ ) + # Use provided dims to transpose the output to the desired order + rv = rv.transpose(*dims) + + rv_dims = rv.type.dims + if observed is None: + if default_transform is UNSET: + default_transform = cls.default_transform + else: + # Align observed dims with those of the RV + # TODO: If this fails give a more informative error message + observed = observed.transpose(*rv_dims).values + + rv = model.register_rv( + rv.values, + name=name, + observed=observed, + total_size=total_size, + dims=rv_dims, + transform=transform, + default_transform=default_transform, + initval=initval, + ) + + return as_xtensor(rv, dims=rv_dims) + + @classmethod + def dist( + cls, + dist_params, + *, + dim_lengths: dict[str, Variable | int] | None = None, + core_dims: str | Sequence[str] | None = None, + **kwargs, + ) -> XTensorVariable: + for invalid_kwarg in ("size", "shape", "dims"): + if invalid_kwarg in kwargs: + raise TypeError(f"DimDistribution does not accept {invalid_kwarg} argument.") + + # XRV requires only extra_dims, not dims + dist_params = [cls._as_xtensor(param) for param in dist_params] + + if dim_lengths is None: + extra_dims = None + else: + # Exclude dims that are implied by the parameters or core_dims + implied_dims = set(chain.from_iterable(param.type.dims for param in dist_params)) + if core_dims is not None: + if isinstance(core_dims, str): + implied_dims.add(core_dims) + else: + implied_dims.update(core_dims) + + extra_dims = { + dim: length for dim, length in dim_lengths.items() if dim not in implied_dims + } + return cls.xrv_op(*dist_params, extra_dims=extra_dims, core_dims=core_dims, **kwargs) + + +class VectorDimDistribution(DimDistribution): + @classmethod + def dist(self, *args, core_dims: str | Sequence[str] | None = None, **kwargs): + # Add a helpful error message if core_dims is not provided + if core_dims is None: + raise ValueError( + f"{self.__name__} requires core_dims to be specified, as it involves non-scalar inputs or outputs." + "Check the documentation of the distribution for details." + ) + return super().dist(*args, core_dims=core_dims, **kwargs) + + +class PositiveDimDistribution(DimDistribution): + """Base class for positive continuous distributions.""" + + default_transform = transforms.log + + +class UnitDimDistribution(DimDistribution): + """Base class for unit-valued distributions.""" + + default_transform = transforms.logodds diff --git a/pymc/dims/distributions/scalar.py b/pymc/dims/distributions/scalar.py new file mode 100644 index 0000000000..cb5fd5e1f4 --- /dev/null +++ b/pymc/dims/distributions/scalar.py @@ -0,0 +1,174 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import pytensor.xtensor as ptx +import pytensor.xtensor.random as pxr + +from pymc.dims.distributions.core import ( + DimDistribution, + PositiveDimDistribution, + UnitDimDistribution, +) +from pymc.distributions.continuous import Beta as RegularBeta +from pymc.distributions.continuous import Gamma as RegularGamma +from pymc.distributions.continuous import flat, halfflat + + +def _get_sigma_from_either_sigma_or_tau(*, sigma, tau): + if sigma is not None and tau is not None: + raise ValueError("Can't pass both tau and sigma") + + if sigma is None and tau is None: + return 1.0 + + if sigma is not None: + return sigma + + return ptx.math.reciprocal(ptx.math.sqrt(tau)) + + +class Flat(DimDistribution): + xrv_op = pxr.as_xrv(flat) + + @classmethod + def dist(cls, **kwargs): + return super().dist([], **kwargs) + + +class HalfFlat(PositiveDimDistribution): + xrv_op = pxr.as_xrv(halfflat, [], ()) + + @classmethod + def dist(cls, **kwargs): + return super().dist([], **kwargs) + + +class Normal(DimDistribution): + xrv_op = pxr.normal + + @classmethod + def dist(cls, mu=0, sigma=None, *, tau=None, **kwargs): + sigma = _get_sigma_from_either_sigma_or_tau(sigma=sigma, tau=tau) + return super().dist([mu, sigma], **kwargs) + + +class HalfNormal(PositiveDimDistribution): + xrv_op = pxr.halfnormal + + @classmethod + def dist(cls, sigma=None, *, tau=None, **kwargs): + sigma = _get_sigma_from_either_sigma_or_tau(sigma=sigma, tau=tau) + return super().dist([0.0, sigma], **kwargs) + + +class LogNormal(PositiveDimDistribution): + xrv_op = pxr.lognormal + + @classmethod + def dist(cls, mu=0, sigma=None, *, tau=None, **kwargs): + sigma = _get_sigma_from_either_sigma_or_tau(sigma=sigma, tau=tau) + return super().dist([mu, sigma], **kwargs) + + +class StudentT(DimDistribution): + xrv_op = pxr.t + + @classmethod + def dist(cls, nu, mu=0, sigma=None, *, lam=None, **kwargs): + sigma = _get_sigma_from_either_sigma_or_tau(sigma=sigma, tau=lam) + return super().dist([nu, mu, sigma], **kwargs) + + +class Cauchy(DimDistribution): + xrv_op = pxr.cauchy + + @classmethod + def dist(cls, alpha, beta, **kwargs): + return super().dist([alpha, beta], **kwargs) + + +class HalfCauchy(PositiveDimDistribution): + xrv_op = pxr.halfcauchy + + @classmethod + def dist(cls, beta, **kwargs): + return super().dist([0.0, beta], **kwargs) + + +class Beta(UnitDimDistribution): + xrv_op = pxr.beta + + @classmethod + def dist(cls, alpha=None, beta=None, *, mu=None, sigma=None, nu=None, **kwargs): + alpha, beta = RegularBeta.get_alpha_beta(alpha=alpha, beta=beta, mu=mu, sigma=sigma, nu=nu) + return super().dist([alpha, beta], **kwargs) + + +class Laplace(DimDistribution): + xrv_op = pxr.laplace + + @classmethod + def dist(cls, mu=0, b=1, **kwargs): + return super().dist([mu, b], **kwargs) + + +class Exponential(PositiveDimDistribution): + xrv_op = pxr.exponential + + @classmethod + def dist(cls, lam=None, *, scale=None, **kwargs): + if lam is None and scale is None: + scale = 1.0 + elif lam is not None and scale is not None: + raise ValueError("Cannot pass both 'lam' and 'scale'. 
Use one of them.") + elif lam is not None: + scale = 1 / lam + return super().dist([scale], **kwargs) + + +class Gamma(PositiveDimDistribution): + xrv_op = pxr.gamma + + @classmethod + def dist(cls, alpha=None, beta=None, *, mu=None, sigma=None, **kwargs): + if (alpha is not None) and (beta is not None): + pass + elif (mu is not None) and (sigma is not None): + # Use sign of sigma to not let negative sigma fly by + alpha = (mu**2 / sigma**2) * ptx.math.sign(sigma) + beta = mu / sigma**2 + else: + raise ValueError( + "Incompatible parameterization. Either use alpha and beta, or mu and sigma." + ) + alpha, beta = RegularGamma.get_alpha_beta(alpha=alpha, beta=beta, mu=mu, sigma=sigma) + return super().dist([alpha, ptx.math.reciprocal(beta)], **kwargs) + + +class InverseGamma(PositiveDimDistribution): + xrv_op = pxr.invgamma + + @classmethod + def dist(cls, alpha=None, beta=None, *, mu=None, sigma=None, **kwargs): + if alpha is not None: + if beta is None: + beta = 1.0 + elif (mu is not None) and (sigma is not None): + # Use sign of sigma to not let negative sigma fly by + alpha = ((2 * sigma**2 + mu**2) / sigma**2) * ptx.math.sign(sigma) + beta = mu * (mu**2 + sigma**2) / sigma**2 + else: + raise ValueError( + "Incompatible parameterization. Either use alpha and (optionally) beta, or mu and sigma" + ) + return super().dist([alpha, beta], **kwargs) diff --git a/pymc/dims/distributions/vector.py b/pymc/dims/distributions/vector.py new file mode 100644 index 0000000000..c67b69fba9 --- /dev/null +++ b/pymc/dims/distributions/vector.py @@ -0,0 +1,116 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytensor.xtensor as ptx +import pytensor.xtensor.random as ptxr + +from pytensor.xtensor import random as pxr + +from pymc.dims.distributions.core import VectorDimDistribution + + +class Categorical(VectorDimDistribution): + """Categorical distribution. + + Parameters + ---------- + p : xtensor_like, optional + Probabilities of each category. Must sum to 1 along the core dimension. + Must be provided if `logit_p` is not specified. + logit_p : xtensor_like, optional + Alternative parametrization using logits. Must be provided if `p` is not specified. + core_dims : str + The core dimension of the distribution, which represents the categories. + The dimension must be present in `p` or `logit_p`. + **kwargs + Other keyword arguments used to define the distribution. + + Returns + ------- + XTensorVariable + An xtensor variable representing the categorical distribution. + The output does not contain the core dimension, as it is absorbed into the distribution. + + + """ + + xrv_op = ptxr.categorical + + @classmethod + def dist(cls, p=None, *, logit_p=None, core_dims=None, **kwargs): + if p is not None and logit_p is not None: + raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.") + elif p is None and logit_p is None: + raise ValueError("Incompatible parametrization. 
Must specify either p or logit_p.") + + if logit_p is not None: + p = ptx.math.softmax(logit_p, dim=core_dims) + return super().dist([p], core_dims=core_dims, **kwargs) + + +class MvNormal(VectorDimDistribution): + """Multivariate Normal distribution. + + Parameters + ---------- + mu : xtensor_like + Mean vector of the distribution. + cov : xtensor_like, optional + Covariance matrix of the distribution. Only one of `cov` or `chol` must be provided. + chol : xtensor_like, optional + Cholesky decomposition of the covariance matrix. only one of `cov` or `chol` must be provided. + lower : bool, default True + If True, the Cholesky decomposition is assumed to be lower triangular. + If False, it is assumed to be upper triangular. + core_dims: Sequence of string + Sequence of two strings representing the core dimensions of the distribution. + The two dimensions must be present in `cov` or `chol`, and exactly one must also be present in `mu`. + **kwargs + Additional keyword arguments used to define the distribution. + + Returns + ------- + XTensorVariable + An xtensor variable representing the multivariate normal distribution. + The output contains the core dimension that is shared between `mu` and `cov` or `chol`. + + """ + + xrv_op = pxr.multivariate_normal + + @classmethod + def dist(cls, mu, cov=None, *, chol=None, lower=True, core_dims=None, **kwargs): + if "tau" in kwargs: + raise NotImplementedError("MvNormal does not support 'tau' parameter.") + + if not (isinstance(core_dims, tuple | list) and len(core_dims) == 2): + raise ValueError("MvNormal requires 2 core_dims") + + if cov is None and chol is None: + raise ValueError("Either 'cov' or 'chol' must be provided.") + + if chol is not None: + d0, d1 = core_dims + if not lower: + # By logical symmetry this must be the only correct way to implement lower + # We refuse to test it because it is not useful + d1, d0 = d0, d1 + + chol = cls._as_xtensor(chol) + # chol @ chol.T in xarray semantics requires a rename + safe_name = "_" + if "_" in chol.type.dims: + safe_name *= max(map(len, chol.type.dims)) + 1 + cov = chol.dot(chol.rename({d0: safe_name}), dim=d1).rename({safe_name: d1}) + + return super().dist([mu, cov], core_dims=core_dims, **kwargs) diff --git a/pymc/dims/math.py b/pymc/dims/math.py new file mode 100644 index 0000000000..5e1c8ba17a --- /dev/null +++ b/pymc/dims/math.py @@ -0,0 +1,15 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytensor.xtensor import linalg +from pytensor.xtensor.math import * diff --git a/pymc/dims/model.py b/pymc/dims/model.py new file mode 100644 index 0000000000..67ab5fb562 --- /dev/null +++ b/pymc/dims/model.py @@ -0,0 +1,128 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Callable + +from pytensor.tensor import TensorVariable +from pytensor.xtensor import as_xtensor +from pytensor.xtensor.basic import TensorFromXTensor +from pytensor.xtensor.type import XTensorVariable + +from pymc.data import Data as RegularData +from pymc.distributions.shape_utils import ( + Dims, + DimsWithEllipsis, + convert_dims, + convert_dims_with_ellipsis, +) +from pymc.model.core import Deterministic as RegularDeterministic +from pymc.model.core import Model, modelcontext +from pymc.model.core import Potential as RegularPotential + + +def with_dims(x: TensorVariable | XTensorVariable, model: Model | None = None) -> XTensorVariable: + """Recover the dims of a variable that was registered in the Model.""" + if isinstance(x, XTensorVariable): + return x + + if (x.owner is not None) and isinstance(x.owner.op, TensorFromXTensor): + dims = x.owner.inputs[0].type.dims + return as_xtensor(x, dims=dims, name=x.name) + + # Try accessing the model context to get dims + try: + model = modelcontext(model) + if ( + model.named_vars.get(x.name, None) is x + and (dims := model.named_vars_to_dims.get(x.name, None)) is not None + ): + return as_xtensor(x, dims=dims, name=x.name) + except TypeError: + pass + + raise ValueError(f"variable {x} doesn't have dims associated with it") + + +def Data( + name: str, value, dims: Dims = None, model: Model | None = None, **kwargs +) -> XTensorVariable: + """Wrapper around pymc.Data that returns an XtensorVariable. + + Dimensions are required if the input is not a scalar. + These are always forwarded to the model object. + + The respective TensorVariable is registered in the model + """ + model = modelcontext(model) + dims = convert_dims(dims) # type: ignore[assignment] + + with model: + value = RegularData(name, value, dims=dims, **kwargs) # type: ignore[arg-type] + + dims = model.named_vars_to_dims[value.name] + if dims is None and value.ndim > 0: + raise ValueError("pymc.dims.Data requires dims to be specified for non-scalar data.") + + return as_xtensor(value, dims=dims, name=name) # type: ignore[arg-type] + + +def _register_and_return_xtensor_variable( + name: str, + value: TensorVariable | XTensorVariable, + dims: DimsWithEllipsis | None, + model: Model | None, + registration_func: Callable, +) -> XTensorVariable: + if isinstance(value, XTensorVariable): + dims = convert_dims_with_ellipsis(dims) + if dims is not None: + # If dims are provided, apply a transpose to align with the user expectation + value = value.transpose(*dims) + # Regardless of whether dims are provided, we now have them + dims = value.type.dims + # Register the equivalent TensorVariable with the model so it doesn't see XTensorVariables directly. + value = value.values # type: ignore[union-attr] + + value = registration_func(name, value, dims=dims, model=model) + + return as_xtensor(value, dims=dims, name=name) # type: ignore[arg-type] + + +def Deterministic( + name: str, value, dims: DimsWithEllipsis | None = None, model: Model | None = None +) -> XTensorVariable: + """Wrapper around pymc.Deterministic that returns an XtensorVariable. 
+ + If the input is already an XTensorVariable, dims are optional. If dims are provided, the variable is aligned with them with a transpose. + If the input is not an XTensorVariable, it is converted to one using `as_xtensor`. Dims are required if the input is not a scalar. + + The dimensions of the resulting XTensorVariable are always forwarded to the model object. + + The respective TensorVariable is registered in the model + """ + return _register_and_return_xtensor_variable(name, value, dims, model, RegularDeterministic) + + +def Potential( + name: str, value, dims: DimsWithEllipsis | None = None, model: Model | None = None +) -> XTensorVariable: + """Wrapper around pymc.Potential that returns an XtensorVariable. + + If the input is already an XTensorVariable, dims are optional. If dims are provided, the variable is aligned with them with a transpose. + If the input is not an XTensorVariable, it is converted to one using `as_xtensor`. Dims are required if the input is not a scalar. + + The dimensions of the resulting XTensorVariable are always forwarded to the model object. + + The respective TensorVariable is registered in the model. + """ + return _register_and_return_xtensor_variable(name, value, dims, model, RegularPotential) diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index fb4fd5eb3a..1ccc8c06a2 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -2370,11 +2370,8 @@ def get_alpha_beta(cls, alpha=None, beta=None, mu=None, sigma=None): if (alpha is not None) and (beta is not None): pass elif (mu is not None) and (sigma is not None): - if isinstance(sigma, Variable): - sigma = check_parameters(sigma, sigma > 0, msg="sigma > 0") - else: - assert np.all(np.asarray(sigma) > 0) - alpha = mu**2 / sigma**2 + # Use sign of sigma to not let negative sigma fly by + alpha = (mu**2 / sigma**2) * pt.sign(sigma) beta = mu / sigma**2 else: raise ValueError( @@ -2496,13 +2493,10 @@ def _get_alpha_beta(cls, alpha, beta, mu, sigma): if beta is not None: pass else: - beta = 1 + beta = 1.0 elif (mu is not None) and (sigma is not None): - if isinstance(sigma, Variable): - sigma = check_parameters(sigma, sigma > 0, msg="sigma > 0") - else: - assert np.all(np.asarray(sigma) > 0) - alpha = (2 * sigma**2 + mu**2) / sigma**2 + # Use sign of sigma to not let negative sigma fly by + alpha = ((2 * sigma**2 + mu**2) / sigma**2) * pt.sign(sigma) beta = mu * (mu**2 + sigma**2) / sigma**2 else: raise ValueError( diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index f8712c51e5..2a46929522 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -32,7 +32,7 @@ from pytensor.graph.rewriting.basic import in2out from pytensor.graph.utils import MetaType from pytensor.tensor.basic import as_tensor_variable -from pytensor.tensor.random.op import RandomVariable +from pytensor.tensor.random.op import RandomVariable, RNGConsumerOp from pytensor.tensor.random.rewriting import local_subtensor_rv_lift from pytensor.tensor.random.utils import normalize_size_param from pytensor.tensor.rewriting.shape import ShapeFeature @@ -207,7 +207,7 @@ def __get__(self, owner_self, owner_cls): return self.fget(owner_self if owner_self is not None else owner_cls) -class SymbolicRandomVariable(MeasurableOp, OpFromGraph): +class SymbolicRandomVariable(MeasurableOp, RNGConsumerOp, OpFromGraph): """Symbolic Random Variable. 
This is a subclasse of `OpFromGraph` which is used to encapsulate the symbolic @@ -294,7 +294,10 @@ def default_output(cls_or_self) -> int | None: @staticmethod def get_input_output_type_idxs( extended_signature: str | None, - ) -> tuple[tuple[tuple[int], int | None, tuple[int]], tuple[tuple[int], tuple[int]]]: + ) -> tuple[ + tuple[tuple[int, ...], int | None, tuple[int, ...]], + tuple[tuple[int, ...], tuple[int, ...]], + ]: """Parse extended_signature and return indexes for *[rng], [size] and parameters as well as outputs.""" if extended_signature is None: raise ValueError("extended_signature must be provided") diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index ac2a13431a..f5a5506bba 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -149,11 +149,11 @@ def quaddist_matrix(cov=None, chol=None, tau=None, lower=True, *args, **kwargs): raise ValueError("chol must be at least two dimensional.") if not lower: - chol = pt.swapaxes(chol, -1, -2) + chol = chol.mT # tag as lower triangular to enable pytensor rewrites of chol(l.l') -> l chol.tag.lower_triangular = True - cov = pt.matmul(chol, pt.swapaxes(chol, -1, -2)) + cov = chol @ chol.mT return cov diff --git a/pymc/distributions/shape_utils.py b/pymc/distributions/shape_utils.py index 8ef378b76b..85177369ae 100644 --- a/pymc/distributions/shape_utils.py +++ b/pymc/distributions/shape_utils.py @@ -18,6 +18,7 @@ from collections.abc import Sequence from functools import singledispatch +from types import EllipsisType from typing import Any, TypeAlias, cast import numpy as np @@ -87,11 +88,13 @@ def _check_shape_type(shape): # User-provided can be lazily specified as scalars Shape: TypeAlias = int | TensorVariable | Sequence[int | Variable] Dims: TypeAlias = str | Sequence[str | None] +DimsWithEllipsis: TypeAlias = str | EllipsisType | Sequence[str | None | EllipsisType] Size: TypeAlias = int | TensorVariable | Sequence[int | Variable] # After conversion to vectors StrongShape: TypeAlias = TensorVariable | tuple[int | Variable, ...] -StrongDims: TypeAlias = Sequence[str | None] +StrongDims: TypeAlias = Sequence[str] +StrongDimsWithEllipsis: TypeAlias = Sequence[str | EllipsisType] StrongSize: TypeAlias = TensorVariable | tuple[int | Variable, ...] @@ -107,7 +110,24 @@ def convert_dims(dims: Dims | None) -> StrongDims | None: else: raise ValueError(f"The `dims` parameter must be a tuple, str or list. Actual: {type(dims)}") - return dims + return dims # type: ignore[return-value] + + +def convert_dims_with_ellipsis(dims: DimsWithEllipsis | None) -> StrongDimsWithEllipsis | None: + """Process a user-provided dims variable into None or a valid dims tuple with ellipsis.""" + if dims is None: + return None + + if isinstance(dims, str | EllipsisType): + dims = (dims,) + elif isinstance(dims, list | tuple): + dims = tuple(dims) + else: + raise ValueError( + f"The `dims` parameter must be a tuple, list, str or Ellipsis. 
Actual: {type(dims)}" + ) + + return dims # type: ignore[return-value] def convert_shape(shape: Shape) -> StrongShape | None: diff --git a/pymc/initial_point.py b/pymc/initial_point.py index df9419c744..b6b8dba3c9 100644 --- a/pymc/initial_point.py +++ b/pymc/initial_point.py @@ -20,8 +20,9 @@ import pytensor import pytensor.tensor as pt -from pytensor.graph.basic import Variable +from pytensor.graph.basic import Variable, ancestors from pytensor.graph.fg import FunctionGraph +from pytensor.graph.rewriting.db import RewriteDatabaseQuery, SequenceDB from pytensor.tensor.variable import TensorVariable from pymc.logprob.transforms import Transform @@ -37,6 +38,8 @@ StartDict = dict[Variable | str, np.ndarray | Variable | str] PointType = dict[str, np.ndarray] +initial_point_rewrites_db = SequenceDB() +initial_point_basic_query = RewriteDatabaseQuery(include=["basic"]) def convert_str_to_rv_dict( @@ -230,11 +233,20 @@ def make_initial_point_expression( if jitter_rvs is None: jitter_rvs = set() + # Clone free_rvs so we don't modify the original graph + initial_point_fgraph = FunctionGraph(outputs=free_rvs, clone=True) + + # Apply any rewrites necessary to compute the initial points. + initial_point_rewriter = initial_point_rewrites_db.query(initial_point_basic_query) + if initial_point_rewriter: + initial_point_rewriter.rewrite(initial_point_fgraph) + + free_rvs_clone = initial_point_fgraph.outputs + initial_values = [] initial_values_transformed = [] - - for variable in free_rvs: - strategy = initval_strategies.get(variable, None) + for original_variable, variable in zip(free_rvs, free_rvs_clone): + strategy = initval_strategies.get(original_variable) if strategy is None: strategy = default_strategy @@ -245,7 +257,7 @@ def make_initial_point_expression( value = support_point(variable) except NotImplementedError: warnings.warn( - f"Moment not defined for variable {variable} of type " + f"Support point not defined for variable {variable} of type " f"{variable.owner.op.__class__.__name__}, defaulting to " f"a draw from the prior. This can lead to difficulties " f"during tuning. You can manually define an initval or " @@ -261,14 +273,18 @@ def make_initial_point_expression( f'Invalid string strategy: {strategy}. It must be one of ["support_point", "prior"]' ) else: - value = pt.as_tensor(strategy, dtype=variable.dtype).astype(variable.dtype) + if isinstance(strategy, Variable) and (set(free_rvs) & set(ancestors([strategy]))): + raise ValueError( + f"Initial value of {original_variable} depends on other random variables. This is not supported anymore." 
+ ) + value = pt.as_tensor(strategy, variable.dtype).astype(variable.dtype) - transform = rvs_to_transforms.get(variable, None) + transform = rvs_to_transforms.get(original_variable, None) if transform is not None: value = transform.forward(value, *variable.owner.inputs) - if variable in jitter_rvs: + if original_variable in jitter_rvs: jitter = pt.random.uniform(-1, 1, size=value.shape) jitter.name = f"{variable.name}_jitter" value = value + jitter @@ -281,28 +297,16 @@ def make_initial_point_expression( initial_values.append(value) - all_outputs: list[TensorVariable] = [] - all_outputs.extend(free_rvs) - all_outputs.extend(initial_values) - all_outputs.extend(initial_values_transformed) - - copy_graph = FunctionGraph(outputs=all_outputs, clone=True) - - n_variables = len(free_rvs) - free_rvs_clone = copy_graph.outputs[:n_variables] - initial_values_clone = copy_graph.outputs[n_variables:-n_variables] - initial_values_transformed_clone = copy_graph.outputs[-n_variables:] - # We now replace all rvs by the respective initial_point expressions # in the constrained (untransformed) space. We do this in reverse topological # order, so that later nodes do not reintroduce expressions with earlier # rvs that would need to once again be replaced by their initial_points - graph = FunctionGraph(outputs=free_rvs_clone, clone=False) - toposort_replace(graph, tuple(zip(free_rvs_clone, initial_values_clone)), reverse=True) + toposort_replace(initial_point_fgraph, tuple(zip(free_rvs_clone, initial_values)), reverse=True) if not return_transformed: - return graph.outputs + return initial_point_fgraph.outputs + # Because the unconstrained (transformed) expressions are a subgraph of the # constrained initial point they were also automatically updated inplace # when calling graph.replace_all above, so we don't need to do anything else - return initial_values_transformed_clone + return initial_values_transformed diff --git a/pymc/logprob/basic.py b/pymc/logprob/basic.py index 85d74aab3b..9fb76df169 100644 --- a/pymc/logprob/basic.py +++ b/pymc/logprob/basic.py @@ -46,6 +46,7 @@ Constant, Variable, ancestors, + walk, ) from pytensor.graph.rewriting.basic import GraphRewriter, NodeRewriter from pytensor.tensor.variable import TensorVariable @@ -60,8 +61,8 @@ from pymc.logprob.rewriting import cleanup_ir, construct_ir_fgraph from pymc.logprob.transform_value import TransformValuesRewrite from pymc.logprob.transforms import Transform -from pymc.logprob.utils import get_related_valued_nodes, rvs_in_graph -from pymc.pytensorf import replace_vars_in_graphs +from pymc.logprob.utils import get_related_valued_nodes +from pymc.pytensorf import expand_inner_graph, replace_vars_in_graphs TensorLike: TypeAlias = Variable | float | np.ndarray @@ -71,9 +72,13 @@ def _find_unallowed_rvs_in_graph(graph): from pymc.distributions.simulator import SimulatorRV return { - rv - for rv in rvs_in_graph(graph) - if not isinstance(rv.owner.op, SimulatorRV | MinibatchIndexRV) + var + for var in walk(graph, expand_inner_graph, False) + if ( + var.owner + and isinstance(var.owner.op, MeasurableOp) + and not isinstance(var.owner.op, SimulatorRV | MinibatchIndexRV) + ) } diff --git a/pymc/logprob/rewriting.py b/pymc/logprob/rewriting.py index b5a6b23a09..ea0202f00c 100644 --- a/pymc/logprob/rewriting.py +++ b/pymc/logprob/rewriting.py @@ -132,6 +132,7 @@ def remove_DiracDelta(fgraph, node): return [dd_val] +logprob_rewrites_basic_query = RewriteDatabaseQuery(include=["basic"]) logprob_rewrites_db = SequenceDB() logprob_rewrites_db.name = 
"logprob_rewrites_db" @@ -146,16 +147,21 @@ def remove_DiracDelta(fgraph, node): failure_callback=None, ), "basic", + position=0, ) # Introduce sigmoid. We do it before canonicalization so that useless mul are removed next logprob_rewrites_db.register( - "local_exp_over_1_plus_exp", out2in(local_exp_over_1_plus_exp), "basic" + "local_exp_over_1_plus_exp", + out2in(local_exp_over_1_plus_exp), + "basic", + position=0.9, ) logprob_rewrites_db.register( "pre-canonicalize", optdb.query("+canonicalize", "-local_eager_useless_unbatched_blockwise"), "basic", + position=1, ) # These rewrites convert un-measurable variables into their measurable forms, @@ -164,18 +170,26 @@ def remove_DiracDelta(fgraph, node): measurable_ir_rewrites_db = EquilibriumDB() measurable_ir_rewrites_db.name = "measurable_ir_rewrites_db" -logprob_rewrites_db.register("measurable_ir_rewrites", measurable_ir_rewrites_db, "basic") +logprob_rewrites_db.register( + "measurable_ir_rewrites", + measurable_ir_rewrites_db, + "basic", + position=2, +) # These rewrites push random/measurable variables "down", making them closer to # (or eventually) the graph outputs. Often this is done by lifting other `Op`s # "up" through the random/measurable variables and into their inputs. measurable_ir_rewrites_db.register("subtensor_lift", local_subtensor_rv_lift, "basic") -# These rewrites are used to introduce specalized operations with better logprob graphs +# These rewrites are used to introduce specialized operations with better logprob graphs specialization_ir_rewrites_db = EquilibriumDB() specialization_ir_rewrites_db.name = "specialization_ir_rewrites_db" logprob_rewrites_db.register( - "specialization_ir_rewrites_db", specialization_ir_rewrites_db, "basic" + "specialization_ir_rewrites_db", + specialization_ir_rewrites_db, + "basic", + position=3, ) @@ -183,6 +197,7 @@ def remove_DiracDelta(fgraph, node): "post-canonicalize", optdb.query("+canonicalize", "-local_eager_useless_unbatched_blockwise"), "basic", + position=4, ) # Rewrites that remove IR Ops @@ -192,6 +207,7 @@ def remove_DiracDelta(fgraph, node): "cleanup_ir_rewrites", TopoDB(cleanup_ir_rewrites_db, order="out_to_in", ignore_newtrees=True, failure_callback=None), "cleanup", + position=5, ) cleanup_ir_rewrites_db.register("remove_DiracDelta", remove_DiracDelta, "cleanup") @@ -250,7 +266,7 @@ def construct_ir_fgraph( toposort_replace(fgraph, replacements, reverse=True) if ir_rewriter is None: - ir_rewriter = logprob_rewrites_db.query(RewriteDatabaseQuery(include=["basic"])) + ir_rewriter = logprob_rewrites_db.query(logprob_rewrites_basic_query) ir_rewriter.rewrite(fgraph) # Reintroduce original value variables diff --git a/pymc/logprob/utils.py b/pymc/logprob/utils.py index a028341032..0f521393d4 100644 --- a/pymc/logprob/utils.py +++ b/pymc/logprob/utils.py @@ -43,21 +43,18 @@ from pytensor import tensor as pt from pytensor.graph import Apply, Op, node_rewriter -from pytensor.graph.basic import Constant, Variable, clone_get_equiv, graph_inputs, walk +from pytensor.graph.basic import Constant, clone_get_equiv, graph_inputs, walk from pytensor.graph.fg import FunctionGraph -from pytensor.graph.op import HasInnerGraph from pytensor.link.c.type import CType from pytensor.raise_op import CheckAndRaise from pytensor.scalar.basic import Mul from pytensor.tensor.basic import get_underlying_scalar_constant_value from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.exceptions import NotScalarConstantError -from pytensor.tensor.random.op import RandomVariable from 
pytensor.tensor.variable import TensorVariable
 
 from pymc.logprob.abstract import MeasurableOp, ValuedRV, _logprob
 from pymc.pytensorf import replace_vars_in_graphs
-from pymc.util import makeiter
 
 if typing.TYPE_CHECKING:
     from pymc.logprob.transforms import Transform
@@ -130,26 +127,6 @@ def populate_replacements(var):
     return replace_vars_in_graphs(graphs, replacements)
 
 
-def rvs_in_graph(vars: Variable | Sequence[Variable]) -> set[Variable]:
-    """Assert that there are no `MeasurableOp` nodes in a graph."""
-
-    def expand(r):
-        owner = r.owner
-        if owner:
-            inputs = list(reversed(owner.inputs))
-
-            if isinstance(owner.op, HasInnerGraph):
-                inputs += owner.op.inner_outputs
-
-            return inputs
-
-    return {
-        node
-        for node in walk(makeiter(vars), expand, False)
-        if node.owner and isinstance(node.owner.op, RandomVariable | MeasurableOp)
-    }
-
-
 def convert_indices(indices, entry):
     if indices and isinstance(entry, CType):
         rval = indices.pop(0)
@@ -334,3 +311,16 @@ def get_related_valued_nodes(fgraph: FunctionGraph, node: Apply) -> list[Apply]:
         for client, _ in clients[out]
         if isinstance(client.op, ValuedRV)
     ]
+
+
+def __getattr__(name):
+    if name == "rvs_in_graph":
+        warnings.warn(
+            f"{name} has been moved to `pymc.pytensorf`. Importing from `pymc.logprob.utils` will fail in a future release.",
+            FutureWarning,
+        )
+        from pymc.pytensorf import rvs_in_graph
+
+        return rvs_in_graph
+
+    raise AttributeError(f"module {__name__} has no attribute {name}")
diff --git a/pymc/math.py b/pymc/math.py
index 13655f5345..65ddacfb95 100644
--- a/pymc/math.py
+++ b/pymc/math.py
@@ -33,6 +33,7 @@
     arcsinh,
     arctan,
     arctanh,
+    as_tensor,
     broadcast_to,
     ceil,
     clip,
@@ -42,6 +43,7 @@
     cosh,
     cumprod,
     cumsum,
+    diff,
     dot,
     eq,
     erf,
@@ -103,6 +105,7 @@
     "arcsinh",
     "arctan",
     "arctanh",
+    "as_tensor",
     "batched_diag",
     "block_diagonal",
     "broadcast_to",
@@ -115,6 +118,7 @@
     "cosh",
     "cumprod",
     "cumsum",
+    "diff",
     "dot",
     "eq",
     "erf",
diff --git a/pymc/model/transform/conditioning.py b/pymc/model/transform/conditioning.py
index edcf5862b1..6c40ab5636 100644
--- a/pymc/model/transform/conditioning.py
+++ b/pymc/model/transform/conditioning.py
@@ -22,7 +22,6 @@
 from pytensor.tensor import TensorVariable
 
 from pymc.logprob.transforms import Transform
-from pymc.logprob.utils import rvs_in_graph
 from pymc.model.core import Model
 from pymc.model.fgraph import (
     ModelDeterministic,
@@ -41,7 +40,7 @@
     parse_vars,
     prune_vars_detached_from_observed,
 )
-from pymc.pytensorf import replace_vars_in_graphs, toposort_replace
+from pymc.pytensorf import replace_vars_in_graphs, rvs_in_graph, toposort_replace
 from pymc.util import get_transformed_name, get_untransformed_name
diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py
index f1d69c9282..3ce3081477 100644
--- a/pymc/pytensorf.py
+++ b/pymc/pytensorf.py
@@ -33,13 +33,15 @@
     clone_get_equiv,
     equal_computations,
     graph_inputs,
+    walk,
 )
 from pytensor.graph.fg import FunctionGraph, Output
+from pytensor.graph.op import HasInnerGraph
 from pytensor.scalar.basic import Cast
 from pytensor.scan.op import Scan
 from pytensor.tensor.basic import _as_tensor_variable
 from pytensor.tensor.elemwise import Elemwise
-from pytensor.tensor.random.op import RandomVariable
+from pytensor.tensor.random.op import RandomVariable, RNGConsumerOp
 from pytensor.tensor.random.type import RandomType
 from pytensor.tensor.random.var import RandomGeneratorSharedVariable
 from pytensor.tensor.rewriting.basic import topo_unconditional_constant_folding
@@ -133,6 +135,9 @@ def dataframe_to_tensor_variable(df: pd.DataFrame, *args, 
**kwargs) -> TensorVar return pt.as_tensor_variable(df.to_numpy(), *args, **kwargs) +_cheap_eval_mode = Mode(linker="py", optimizer="minimum_compile") + + def extract_obs_data(x: TensorVariable) -> np.ndarray: """Extract data from observed symbolic variables. @@ -161,15 +166,31 @@ def extract_obs_data(x: TensorVariable) -> np.ndarray: mask[mask_idx] = 1 return np.ma.MaskedArray(array_data, mask) - from pymc.logprob.utils import rvs_in_graph - if not inputvars(x) and not rvs_in_graph(x): - cheap_eval_mode = Mode(linker="py", optimizer=None) - return x.eval(mode=cheap_eval_mode) + return x.eval(mode=_cheap_eval_mode) raise TypeError(f"Data cannot be extracted from {x}") +def expand_inner_graph(r): + if (node := r.owner) is not None: + inputs = list(reversed(node.inputs)) + + if isinstance(node.op, HasInnerGraph): + inputs += node.op.inner_outputs + + return inputs + + +def rvs_in_graph(vars: Variable | Sequence[Variable], rv_ops=None) -> set[Variable]: + """Assert that there are no random nodes in a graph.""" + return { + var + for var in walk(makeiter(vars), expand_inner_graph, False) + if (var.owner and isinstance(var.owner.op, RNGConsumerOp)) + } + + def replace_vars_in_graphs( graphs: Iterable[Variable], replacements: dict[Variable, Variable], @@ -720,8 +741,6 @@ def scan_step(xtm1): xs_draws = pm.draw(xs, draws=10) """ - # Avoid circular import - from pymc.distributions.distribution import SymbolicRandomVariable def find_default_update(clients, rng: Variable) -> None | Variable: rng_clients = clients.get(rng, None) @@ -764,48 +783,47 @@ def find_default_update(clients, rng: Variable) -> None | Variable: [client, _] = rng_clients[0] # RNG is an output of the function, this is not a problem - if isinstance(client.op, Output): - return None + client_op = client.op - # RNG is used by another operator, which should output an update for the RNG - if isinstance(client.op, RandomVariable): - # RandomVariable first output is always the update of the input RNG - next_rng = client.outputs[0] - - elif isinstance(client.op, SymbolicRandomVariable): - # SymbolicRandomVariable have an explicit method that returns an - # update mapping for their RNG(s) - next_rng = client.op.update(client).get(rng) - if next_rng is None: - raise ValueError( - f"No update found for at least one RNG used in SymbolicRandomVariable Op {client.op}" - ) - elif isinstance(client.op, Scan): - # Check if any shared output corresponds to the RNG - rng_idx = client.inputs.index(rng) - io_map = client.op.get_oinp_iinp_iout_oout_mappings()["outer_out_from_outer_inp"] - out_idx = io_map.get(rng_idx, -1) - if out_idx != -1: - next_rng = client.outputs[out_idx] - else: # No break - raise ValueError( - f"No update found for at least one RNG used in Scan Op {client.op}.\n" - "You can use `pytensorf.collect_default_updates` inside the Scan function to return updates automatically." 
- ) - elif isinstance(client.op, OpFromGraph): - try: - next_rng = collect_default_updates_inner_fgraph(client).get(rng) + match client_op: + case Output(): + return None + # Otherwise, RNG is used by another operator, which should output an update for the RNG + case RandomVariable(): + # RandomVariable first output is always the update of the input RNG + next_rng = client.outputs[0] + case RNGConsumerOp(): + # RNGConsumerOp have an explicit method that returns an update mapping for their RNG(s) + # RandomVariable is a subclass of RNGConsumerOp, but we specialize above for speedup + next_rng = client_op.update(client).get(rng) if next_rng is None: - # OFG either does not make use of this RNG or inconsistent use that will have emitted a warning - return None - except ValueError as exc: - raise ValueError( - f"No update found for at least one RNG used in OpFromGraph Op {client.op}.\n" - "You can use `pytensorf.collect_default_updates` and include those updates as outputs." - ) from exc - else: - # We don't know how this RNG should be updated. The user should provide an update manually - return None + raise ValueError(f"No update found for at least one RNG used in {client_op}") + case Scan(): + # Check if any shared output corresponds to the RNG + rng_idx = client.inputs.index(rng) + io_map = client_op.get_oinp_iinp_iout_oout_mappings()["outer_out_from_outer_inp"] + out_idx = io_map.get(rng_idx, -1) + if out_idx != -1: + next_rng = client.outputs[out_idx] + else: # No break + raise ValueError( + f"No update found for at least one RNG used in Scan Op {client_op}.\n" + "You can use `pytensorf.collect_default_updates` inside the Scan function to return updates automatically." + ) + case OpFromGraph(): + try: + next_rng = collect_default_updates_inner_fgraph(client).get(rng) + if next_rng is None: + # OFG either does not make use of this RNG or inconsistent use that will have emitted a warning + return None + except ValueError as exc: + raise ValueError( + f"No update found for at least one RNG used in OpFromGraph Op {client_op}.\n" + "You can use `pytensorf.collect_default_updates` and include those updates as outputs." + ) from exc + case _: + # We don't know how this RNG should be updated. 
The user should provide an update manually + return None # Recurse until we find final update for RNG nested_next_rng = find_default_update(clients, next_rng) diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index 1be14f77f3..4a794f8f71 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -55,9 +55,8 @@ from pymc.backends.base import MultiTrace from pymc.blocking import PointType from pymc.distributions.shape_utils import change_dist_size -from pymc.logprob.utils import rvs_in_graph from pymc.model import Model, modelcontext -from pymc.pytensorf import compile +from pymc.pytensorf import compile, rvs_in_graph from pymc.util import ( CustomProgress, RandomState, diff --git a/pymc/testing.py b/pymc/testing.py index b016c25ad1..886177ef02 100644 --- a/pymc/testing.py +++ b/pymc/testing.py @@ -24,11 +24,13 @@ from arviz import InferenceData from numpy import random as nr from numpy import testing as npt +from pytensor.compile import SharedVariable from pytensor.compile.mode import Mode -from pytensor.graph.basic import Variable +from pytensor.graph.basic import Constant, Variable, equal_computations, graph_inputs from pytensor.graph.rewriting.basic import in2out from pytensor.tensor import TensorVariable from pytensor.tensor.random.op import RandomVariable +from pytensor.tensor.random.type import RandomType from scipy import special as sp from scipy import stats as st @@ -41,9 +43,8 @@ from pymc.logprob.utils import ( ParameterValueError, local_check_parameter_to_ninf_switch, - rvs_in_graph, ) -from pymc.pytensorf import compile, floatX, inputvars +from pymc.pytensorf import compile, floatX, inputvars, rvs_in_graph # This mode can be used for tests where model compilations takes the bulk of the runtime # AND where we don't care about posterior numerical or sampling stability (e.g., when @@ -971,8 +972,7 @@ def seeded_numpy_distribution_builder(dist_name: str) -> Callable: def assert_no_rvs(vars: Sequence[Variable]) -> None: """Assert that there are no `MeasurableOp` nodes in a graph.""" - rvs = rvs_in_graph(vars) - if rvs: + if rvs := rvs_in_graph(vars): raise AssertionError(f"RV found in graph: {rvs}") @@ -1086,3 +1086,28 @@ def test_model_inference(mock_pymc_sample): pm.sample = original_sample pm.Flat = original_flat pm.HalfFlat = original_half_flat + + +def equal_computations_up_to_root( + xs: Sequence[Variable], ys: Sequence[Variable], ignore_rng_values=True +) -> bool: + # Check if graphs are equivalent even if root variables have distinct identities + + x_graph_inputs = [var for var in graph_inputs(xs) if not isinstance(var, Constant)] + y_graph_inputs = [var for var in graph_inputs(ys) if not isinstance(var, Constant)] + if len(x_graph_inputs) != len(y_graph_inputs): + return False + for x, y in zip(x_graph_inputs, y_graph_inputs): + if x.type != y.type: + return False + if x.name != y.name: + return False + if isinstance(x, SharedVariable): + if not isinstance(y, SharedVariable): + return False + if isinstance(x.type, RandomType) and ignore_rng_values: + continue + if not x.type.values_eq(x.get_value(), y.get_value()): + return False + + return equal_computations(xs, ys, in_xs=x_graph_inputs, in_ys=y_graph_inputs) # type: ignore[arg-type] diff --git a/pyproject.toml b/pyproject.toml index a8ffb06eed..e999a19bb6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,9 @@ ignore = [ "D101", # Missing docstring in public class "D102", # Missing docstring in public method "D103", # Missing docstring in public function + "D104", # Missing docstring 
in public package "D105", # Missing docstring in magic method + "D401", # Ignore Umbridge level of control ] [tool.ruff.lint.pydocstyle] @@ -66,6 +68,13 @@ lines-between-types = 1 "pymc/__init__.py" = [ "E402", # Module level import not at top of file ] +"pymc/dims/__init__.py" = [ + "E402", # Module level import not at top of file +] +"pymc/dims/math.py" = [ + "F401", # Module imported but unused + "F403", # 'from module import *' used; unable to detect undefined names +] "pymc/stats/__init__.py" = [ "E402", # Module level import not at top of file ] diff --git a/tests/dims/__init__.py b/tests/dims/__init__.py new file mode 100644 index 0000000000..00e50af6b0 --- /dev/null +++ b/tests/dims/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/dims/distributions/__init__.py b/tests/dims/distributions/__init__.py new file mode 100644 index 0000000000..00e50af6b0 --- /dev/null +++ b/tests/dims/distributions/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/dims/distributions/test_core.py b/tests/dims/distributions/test_core.py new file mode 100644 index 0000000000..3f33a0ce18 --- /dev/null +++ b/tests/dims/distributions/test_core.py @@ -0,0 +1,191 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
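A minimal, illustrative sketch of how the new `equal_computations_up_to_root` helper added to `pymc/testing.py` above is meant to be used: two graphs built from distinct root variables of the same type and name compare equal, which plain `equal_computations` does not allow on its own. The toy graphs below are not part of the patch.

```python
import pytensor.tensor as pt

from pymc.testing import equal_computations_up_to_root

# Two distinct root variables with the same type and name
x1 = pt.vector("x")
x2 = pt.vector("x")

# Same computation over equivalent roots -> considered equal
assert equal_computations_up_to_root([pt.exp(x1) + 1], [pt.exp(x2) + 1])
# Different computations -> not equal
assert not equal_computations_up_to_root([pt.exp(x1)], [pt.log(x2)])
```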
+import re + +import numpy as np +import pytest + +import pymc as pm + +from pymc import dims as pmx + +pytestmark = pytest.mark.filterwarnings("error") + + +def test_distribution_dims(): + coords = { + "a": range(2), + "b": range(3), + "c": range(5), + "d": range(7), + } + with pm.Model(coords=coords) as model: + x = pmx.Data("x", np.random.randn(2, 3, 5), dims=("a", "b", "c")) + y1 = pmx.Normal("y1", mu=x) + assert y1.type.dims == ("a", "b", "c") + assert y1.eval().shape == (2, 3, 5) + + y2 = pmx.Normal("y2", mu=x, dims=("a", "b", "c")) # redundant + assert y2.type.dims == ("a", "b", "c") + assert y2.eval().shape == (2, 3, 5) + + y3 = pmx.Normal("y3", mu=x, dims=("b", "a", "c")) # Implies a transpose + assert y3.type.dims == ("b", "a", "c") + assert y3.eval().shape == (3, 2, 5) + + y4 = pmx.Normal("y4", mu=x, dims=("a", ...)) + assert y4.type.dims == ("a", "b", "c") + assert y4.eval().shape == (2, 3, 5) + + y5 = pmx.Normal("y5", mu=x, dims=("b", ...)) # Implies a transpose + assert y5.type.dims == ("b", "a", "c") + assert y5.eval().shape == (3, 2, 5) + + y6 = pmx.Normal("y6", mu=x, dims=("b", ..., "a")) # Implies a transpose + assert y6.type.dims == ("b", "c", "a") + assert y6.eval().shape == (3, 5, 2) + + y7 = pmx.Normal("y7", mu=x, dims=(..., "b")) # Implies a transpose + assert y7.type.dims == ("a", "c", "b") + assert y7.eval().shape == (2, 5, 3) + + y8 = pmx.Normal("y8", mu=x, dims=("d", "a", "b", "c")) # Adds an extra dimension + assert y8.type.dims == ("d", "a", "b", "c") + assert y8.eval().shape == (7, 2, 3, 5) + + y9 = pmx.Normal("y9", mu=x, dims=("d", ...)) # Adds an extra dimension + assert y9.type.dims == ("d", "a", "b", "c") + assert y9.eval().shape == (7, 2, 3, 5) + + y10 = pmx.Normal( + "y10", mu=x, dims=("b", "a", "c", "d") + ) # Adds an extra dimension and implies a transpose + assert y10.type.dims == ("b", "a", "c", "d") + assert y10.eval().shape == (3, 2, 5, 7) + + y11 = pmx.Normal( + "y11", mu=x, dims=("c", ..., "d") + ) # Adds an extra dimension and implies a transpose + assert y11.type.dims == ("c", "a", "b", "d") + assert y11.eval().shape == (5, 2, 3, 7) + + # Invalid cases + err_msg = "Provided dims ('a', 'b') do not match the distribution's output dims ('a', 'b', 'c'). Use ellipsis to specify all other dimensions." + with pytest.raises(ValueError, match=re.escape(err_msg)): + # Missing a dimension + pmx.Normal("y_bad", mu=x, dims=("a", "b")) + + err_msg = "Provided dims ('d',) do not match the distribution's output dims ('d', 'a', 'b', 'c'). Use ellipsis to specify all other dimensions." + with pytest.raises(ValueError, match=re.escape(err_msg)): + # Only specifies the extra dimension + pmx.Normal("y_bad", mu=x, dims=("d",)) + + err_msg = "Not all dims ('a', 'b', 'c', 'e') are part of the model coords. 
Add them at initialization time or use `model.add_coord` before defining the distribution" + with pytest.raises(ValueError, match=re.escape(err_msg)): + # Extra dimension not in coords + pmx.Normal("y_bad", mu=x, dims=("a", "b", "c", "e")) + + +def test_multivariate_distribution_dims(): + coords = { + "batch": range(2), + "core1": range(3), + "core2": range(3), + } + with pm.Model(coords=coords) as m: + mu = pmx.Normal("mu", dims=("batch", "core1")) + chol, _, _ = pm.LKJCholeskyCov( + "chol", + eta=1, + n=3, + sd_dist=pm.Exponential.dist(1), + ) + chol_xr = pmx.math.as_xtensor(chol, dims=("core1", "core2")) + + x1 = pmx.MvNormal( + "x1", + mu, + chol=chol_xr, + core_dims=("core1", "core2"), + ) + assert x1.type.dims == ("batch", "core1") + assert x1.eval().shape == (2, 3) + + x2 = pmx.MvNormal( + "x2", + mu, + chol=chol_xr, + core_dims=("core1", "core2"), + dims=("batch", "core1"), + ) + assert x2.type.dims == ("batch", "core1") + assert x2.eval().shape == (2, 3) + + x3 = pmx.MvNormal( + "x3", + mu, + chol=chol_xr, + core_dims=("core2", "core1"), + dims=("batch", ...), + ) + assert x3.type.dims == ("batch", "core1") + assert x3.eval().shape == (2, 3) + + x4 = pmx.MvNormal( + "x4", + mu, + chol=chol_xr, + core_dims=("core1", "core2"), + # Implies transposition + dims=("core1", ...), + ) + assert x4.type.dims == ("core1", "batch") + assert x4.eval().shape == (3, 2) + + # Errors + err_msg = "MvNormal requires 2 core_dims" + with pytest.raises(ValueError, match=re.escape(err_msg)): + # Missing core_dims + pmx.MvNormal( + "x_bad", + mu, + chol=chol_xr, + ) + + with pytest.raises(ValueError, match="Dimension batch not found in either input"): + pmx.MvNormal( + "x_bad", + mu, + chol=chol_xr, + # Invalid because batch is not on chol_xr + core_dims=("core1", "batch"), + ) + + with pytest.raises(ValueError, match="Parameter mu_renamed has invalid core dimensions"): + mu_renamed = mu.rename({"batch": "core2"}) + mu_renamed.name = "mu_renamed" + pmx.MvNormal( + "x_bad", + mu_renamed, + chol=chol_xr, + # Invalid because mu has both core dimensions (after renaming) + core_dims=("core1", "core2"), + ) + + # Invalid because core2 is not a core output dimension + err_msg = "Dimensions {'core2'} do not exist. Expected one or more of: ('batch', 'core1')" + with pytest.raises(ValueError, match=re.escape(err_msg)): + pmx.MvNormal( + "x_bad", mu, chol=chol_xr, core_dims=("core1", "core2"), dims=("core2", ...) + ) diff --git a/tests/dims/distributions/test_scalar.py b/tests/dims/distributions/test_scalar.py new file mode 100644 index 0000000000..a487591c01 --- /dev/null +++ b/tests/dims/distributions/test_scalar.py @@ -0,0 +1,217 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
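The `core_dims` argument exercised above deserves a brief illustration: for a dims-aware multivariate distribution, `core_dims` names the dimensions consumed by the distribution itself, while any remaining dims of the parameters act as batch dims. A minimal sketch mirroring the tests above (the identity covariance is illustrative, not from the patch):

```python
import numpy as np

import pymc as pm

from pymc import dims as pmx

coords = {"batch": range(2), "core1": range(3), "core2": range(3)}
with pm.Model(coords=coords):
    mu = pmx.Normal("mu", dims=("batch", "core1"))
    cov = pmx.math.as_xtensor(np.eye(3), dims=("core1", "core2"))
    # "core1"/"core2" are consumed as the support dims; "batch" broadcasts
    x = pmx.MvNormal("x", mu, cov=cov, core_dims=("core1", "core2"))
    assert x.type.dims == ("batch", "core1")
```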
+from pymc import Model +from pymc import distributions as regular_distributions +from pymc.dims import ( + Beta, + Cauchy, + Exponential, + Flat, + Gamma, + HalfCauchy, + HalfFlat, + HalfNormal, + InverseGamma, + Laplace, + LogNormal, + Normal, + StudentT, +) +from tests.dims.utils import assert_equivalent_logp_graph, assert_equivalent_random_graph + + +def test_flat(): + coords = {"a": range(3)} + with Model(coords=coords) as model: + Flat("x", dims="a") + + with Model(coords=coords) as reference_model: + regular_distributions.Flat("x", dims="a") + + assert_equivalent_random_graph(model, reference_model) + assert_equivalent_logp_graph(model, reference_model) + + +def test_halfflat(): + coords = {"a": range(3)} + with Model(coords=coords) as model: + HalfFlat("x", dims="a") + + with Model(coords=coords) as reference_model: + regular_distributions.HalfFlat("x", dims="a") + + assert_equivalent_random_graph(model, reference_model) + assert_equivalent_logp_graph(model, reference_model) + + +def test_normal(): + coords = {"a": range(3)} + with Model(coords=coords) as model: + Normal("x", dims="a") + Normal("y", mu=2, sigma=3, dims="a") + Normal("z", mu=-2, tau=3, dims="a") + + with Model(coords=coords) as reference_model: + regular_distributions.Normal("x", dims="a") + regular_distributions.Normal("y", mu=2, sigma=3, dims="a") + regular_distributions.Normal("z", mu=-2, tau=3, dims="a") + + assert_equivalent_random_graph(model, reference_model) + assert_equivalent_logp_graph(model, reference_model) + + +def test_halfnormal(): + coords = {"a": range(3)} + with Model(coords=coords) as model: + HalfNormal("x", dims="a") + HalfNormal("y", sigma=3, dims="a") + HalfNormal("z", tau=3, dims="a") + + with Model(coords=coords) as reference_model: + regular_distributions.HalfNormal("x", dims="a") + regular_distributions.HalfNormal("y", sigma=3, dims="a") + regular_distributions.HalfNormal("z", tau=3, dims="a") + + assert_equivalent_random_graph(model, reference_model) + assert_equivalent_logp_graph(model, reference_model) + + +def test_lognormal(): + coords = {"a": range(3)} + with Model(coords=coords) as model: + LogNormal("x", dims="a") + LogNormal("y", mu=2, sigma=3, dims="a") + LogNormal("z", mu=-2, tau=3, dims="a") + + with Model(coords=coords) as reference_model: + regular_distributions.LogNormal("x", dims="a") + regular_distributions.LogNormal("y", mu=2, sigma=3, dims="a") + regular_distributions.LogNormal("z", mu=-2, tau=3, dims="a") + + assert_equivalent_random_graph(model, reference_model) + assert_equivalent_logp_graph(model, reference_model) + + +def test_studentt(): + coords = {"a": range(3)} + with Model(coords=coords) as model: + StudentT("x", nu=1, dims="a") + StudentT("y", nu=1, mu=2, sigma=3, dims="a") + StudentT("z", nu=1, mu=-2, lam=3, dims="a") + + with Model(coords=coords) as reference_model: + regular_distributions.StudentT("x", nu=1, dims="a") + regular_distributions.StudentT("y", nu=1, mu=2, sigma=3, dims="a") + regular_distributions.StudentT("z", nu=1, mu=-2, lam=3, dims="a") + + assert_equivalent_random_graph(model, reference_model) + assert_equivalent_logp_graph(model, reference_model) + + +def test_cauchy(): + coords = {"a": range(3)} + with Model(coords=coords) as model: + Cauchy("x", alpha=1, beta=2, dims="a") + + with Model(coords=coords) as reference_model: + regular_distributions.Cauchy("x", alpha=1, beta=2, dims="a") + + assert_equivalent_random_graph(model, reference_model) + assert_equivalent_logp_graph(model, reference_model) + + +def test_halfcauchy(): + coords = 
{"a": range(3)} + with Model(coords=coords) as model: + HalfCauchy("x", beta=2, dims="a") + + with Model(coords=coords) as reference_model: + regular_distributions.HalfCauchy("x", beta=2, dims="a") + + assert_equivalent_random_graph(model, reference_model) + assert_equivalent_logp_graph(model, reference_model) + + +def test_beta(): + coords = {"a": range(3)} + with Model(coords=coords) as model: + Beta("w", alpha=1, beta=1, dims="a") + Beta("x", mu=0.5, sigma=0.1, dims="a") + Beta("y", mu=0.5, nu=10, dims="a") + + with Model(coords=coords) as reference_model: + regular_distributions.Beta("w", alpha=1, beta=1, dims="a") + regular_distributions.Beta("x", mu=0.5, sigma=0.1, dims="a") + regular_distributions.Beta("y", mu=0.5, nu=10, dims="a") + + assert_equivalent_random_graph(model, reference_model) + assert_equivalent_logp_graph(model, reference_model) + + +def test_laplace(): + coords = {"a": range(3)} + with Model(coords=coords) as model: + Laplace("x", dims="a") + Laplace("y", mu=1, b=2, dims="a") + + with Model(coords=coords) as reference_model: + regular_distributions.Laplace("x", mu=0, b=1, dims="a") + regular_distributions.Laplace("y", mu=1, b=2, dims="a") + + assert_equivalent_random_graph(model, reference_model) + assert_equivalent_logp_graph(model, reference_model) + + +def test_exponential(): + coords = {"a": range(3)} + with Model(coords=coords) as model: + Exponential("x", dims="a") + Exponential("y", lam=2, dims="a") + Exponential("z", scale=3, dims="a") + + with Model(coords=coords) as reference_model: + regular_distributions.Exponential("x", dims="a") + regular_distributions.Exponential("y", lam=2, dims="a") + regular_distributions.Exponential("z", scale=3, dims="a") + + assert_equivalent_random_graph(model, reference_model) + assert_equivalent_logp_graph(model, reference_model) + + +def test_gamma(): + coords = {"a": range(3)} + with Model(coords=coords) as model: + # Gamma("w", alpha=1, beta=1, dims="a") + Gamma("x", mu=2, sigma=3, dims="a") + + with Model(coords=coords) as reference_model: + # regular_distributions.Gamma("w", alpha=1, beta=1, dims="a") + regular_distributions.Gamma("x", mu=2, sigma=3, dims="a") + + assert_equivalent_random_graph(model, reference_model) + assert_equivalent_logp_graph(model, reference_model) + + +def test_inverse_gamma(): + coords = {"a": range(3)} + with Model(coords=coords) as model: + InverseGamma("w", alpha=1, beta=1, dims="a") + InverseGamma("x", mu=2, sigma=3, dims="a") + + with Model(coords=coords) as reference_model: + regular_distributions.InverseGamma("w", alpha=1, beta=1, dims="a") + regular_distributions.InverseGamma("x", mu=2, sigma=3, dims="a") + + assert_equivalent_random_graph(model, reference_model) + assert_equivalent_logp_graph(model, reference_model) diff --git a/tests/dims/distributions/test_vector.py b/tests/dims/distributions/test_vector.py new file mode 100644 index 0000000000..0f08505dba --- /dev/null +++ b/tests/dims/distributions/test_vector.py @@ -0,0 +1,62 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pytensor.tensor as pt + +from pytensor.xtensor import as_xtensor + +import pymc.distributions as regular_distributions + +from pymc import Model +from pymc.dims import Categorical, MvNormal +from tests.dims.utils import assert_equivalent_logp_graph, assert_equivalent_random_graph + + +def test_categorical(): + coords = {"a": range(3), "b": range(4)} + p = pt.as_tensor([0.1, 0.2, 0.3, 0.4]) + p_xr = as_xtensor(p, dims=("b",)) + + with Model(coords=coords) as model: + Categorical("x", p=p_xr, core_dims="b", dims=("a",)) + Categorical("y", logit_p=p_xr, core_dims="b", dims=("a",)) + + with Model(coords=coords) as reference_model: + regular_distributions.Categorical("x", p=p, dims=("a",)) + regular_distributions.Categorical("y", logit_p=p, dims=("a",)) + + assert_equivalent_random_graph(model, reference_model) + assert_equivalent_logp_graph(model, reference_model) + + +def test_mvnormal(): + coords = {"a": range(3), "b": range(2)} + mu = pt.as_tensor([1, 2]) + cov = pt.as_tensor([[1, 0.5], [0.5, 2]]) + chol = pt.as_tensor([[1, 0], [0.5, np.sqrt(1.75)]]) + + mu_xr = as_xtensor(mu, dims=("b",)) + cov_xr = as_xtensor(cov, dims=("b", "b'")) + chol_xr = as_xtensor(chol, dims=("b", "b'")) + + with Model(coords=coords) as model: + MvNormal("x", mu=mu_xr, cov=cov_xr, core_dims=("b", "b'"), dims=("a", "b")) + MvNormal("y", mu=mu_xr, chol=chol_xr, core_dims=("b", "b'"), dims=("a", "b")) + + with Model(coords=coords) as reference_model: + regular_distributions.MvNormal("x", mu=mu, cov=cov, dims=("a", "b")) + regular_distributions.MvNormal("y", mu=mu, chol=chol, dims=("a", "b")) + + assert_equivalent_random_graph(model, reference_model) + assert_equivalent_logp_graph(model, reference_model) diff --git a/tests/dims/test_model.py b/tests/dims/test_model.py new file mode 100644 index 0000000000..bf27bf482e --- /dev/null +++ b/tests/dims/test_model.py @@ -0,0 +1,174 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
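The equivalence tests above check graph equality; a complementary way to see the same guarantee is that the dims API is expected to produce the same seeded draws as the regular API once lowered (as `test_complex_model` below also asserts). A small illustrative sketch, not part of the patch:

```python
import numpy as np

import pymc as pm

from pymc import dims as pmx

coords = {"a": range(3)}
with pm.Model(coords=coords):
    x = pmx.Normal("x", mu=1, sigma=2, dims="a")
with pm.Model(coords=coords):
    y = pm.Normal("y", mu=1, sigma=2, dims="a")

# Same seed -> same draws, since both lower to the same random graph
np.testing.assert_allclose(pm.draw(x, random_seed=1), pm.draw(y, random_seed=1))
```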
+import numpy as np +import pytest + +from pytensor.xtensor.type import XTensorType +from xarray import DataArray + +import pymc as pm + +from pymc import dims as pmd +from pymc import observe + +pytestmark = pytest.mark.filterwarnings("error") + + +def test_data(): + x_np = np.random.randn(10, 2, 3) + + with pm.Model() as m: + x1 = pmd.Data("x", x_np, dims=("a", "b", "c")) + assert isinstance(x1.type, XTensorType) + assert x1.type.dims == ("a", "b", "c") + + +def test_simple_model(): + coords = {"a": range(3), "b": range(5)} + + with pm.Model(coords=coords) as model: + x = pmd.Normal("x", mu=1, dims=("a", "b")) + sigma = pmd.HalfNormal("sigma", dims=("a",)) + y = pmd.Normal("y", mu=x.T * 2, sigma=sigma, dims=("b", "a")) + + with pm.Model(coords=coords) as xmodel: + x = pmd.Normal("x", mu=1, dims=("a", "b")) + sigma = pmd.HalfNormal("sigma", dims=("a",)) + # Imply a transposition + y = pmd.Normal("y", mu=x * 2, sigma=sigma, dims=("b", "a")) + + assert x.type.dims == ("a", "b") + assert sigma.type.dims == ("a",) + assert y.type.dims == ("b", "a") + + ip = model.initial_point() + xip = xmodel.initial_point() + assert ip.keys() == xip.keys() + for value, xvalue in zip(ip.values(), xip.values()): + np.testing.assert_allclose(value, xvalue) + + logp = model.compile_logp()(ip) + xlogp = xmodel.compile_logp()(xip) + np.testing.assert_allclose(logp, xlogp) + + dlogp = model.compile_dlogp()(ip) + xdlogp = xmodel.compile_dlogp()(xip) + np.testing.assert_allclose(dlogp, xdlogp) + + draw = pm.draw(xmodel["y"], random_seed=1) + draw_same = pm.draw(xmodel["y"], random_seed=1) + draw_diff = pm.draw(xmodel["y"], random_seed=2) + assert draw.shape == (5, 3) + np.testing.assert_allclose(draw, draw_same) + assert not np.allclose(draw, draw_diff) + + observed_values = DataArray(np.ones((3, 5)), dims=("a", "b")).transpose() + with observe(xmodel, {"y": observed_values}): + pm.sample_prior_predictive() + idata = pm.sample( + tune=200, chains=2, draws=50, compute_convergence_checks=False, progressbar=False + ) + pm.sample_posterior_predictive(idata, progressbar=False) + + +def test_complex_model(): + N = 100 + rng = np.random.default_rng(4) + x_np = np.linspace(0, 10, N) + y_np = np.piecewise( + x_np, + [x_np <= 3, (x_np > 3) & (x_np <= 7), x_np > 7], + [lambda x: 0.5 * x, lambda x: 1.5 + 0.2 * (x - 3), lambda x: 2.3 - 0.1 * (x - 7)], + ) + y_np += rng.normal(0, 0.2, size=N) + group_idx_np = rng.choice(3, size=N) + N_knots = 13 + knots_np = np.linspace(0, 10, num=N_knots) + + coords = { + "group": range(3), + "knots": range(N_knots), + "obs": range(N), + } + + with pm.Model(coords=coords) as model: + x = pm.Data("x", x_np, dims="obs") + knots = pm.Data("knots", knots_np, dims="knot") + + sigma = pm.HalfCauchy("sigma", beta=1) + sigma_beta0 = pm.HalfNormal("sigma_beta0", sigma=10) + beta0 = pm.HalfNormal("beta_0", sigma=sigma_beta0, dims="group") + z = pm.Normal("z", dims=("group", "knot")) + + delta_factors = pm.math.softmax(z, axis=-1) # (groups, knot) + slope_factors = 1 - delta_factors[:, :-1].cumsum(axis=-1) # (groups, knot-1) + spline_slopes = pm.math.concatenate( + [beta0[:, None], beta0[:, None] * slope_factors], axis=-1 + ) # (groups, knot-1) + beta = pm.math.concatenate( + [beta0[:, None], pm.math.diff(spline_slopes)], axis=-1 + ) # (groups, knot) + + beta = pm.Deterministic("beta", beta, dims=("group", "knot")) + + X = pm.math.maximum(0, x[:, None] - knots[None, :]) # (n, knot) + mu = (X * beta[group_idx_np]).sum(-1) # ((n, knots) * (n, knots)).sum(-1) = (n,) + y_obs = pm.Normal("y_obs", mu=mu, 
sigma=sigma, observed=y_np, dims="obs") + + with pm.Model(coords=coords) as xmodel: + x = pmd.Data("x", x_np, dims="obs") + y = pmd.Data("y", y_np, dims="obs") + knots = pmd.Data("knots", knots_np, dims=("knot",)) + group_idx = pmd.math.as_xtensor(group_idx_np, dims=("obs",)) + + sigma = pmd.HalfCauchy("sigma", beta=1) + sigma_beta0 = pmd.HalfNormal("sigma_beta0", sigma=10) + beta0 = pmd.HalfNormal("beta_0", sigma=sigma_beta0, dims=("group",)) + z = pmd.Normal("z", dims=("group", "knot")) + + delta_factors = pmd.math.softmax(z, dim="knot") + slope_factors = 1 - delta_factors.isel(knot=slice(None, -1)).cumsum("knot") + spline_slopes = pmd.concat([beta0, beta0 * slope_factors], dim="knot") + beta = pmd.concat([beta0, spline_slopes.diff("knot")], dim="knot") + + beta = pm.Deterministic("beta", beta) + + X = pmd.math.maximum(0, x - knots) + mu = (X * beta.isel(group=group_idx)).sum("knot") + y_obs = pmd.Normal("y_obs", mu=mu, sigma=sigma, observed=y) + + # Test initial point + model_ip = model.initial_point() + xmodel_ip = xmodel.initial_point() + assert model_ip.keys() == xmodel_ip.keys() + for value, xvalue in zip(model_ip.values(), xmodel_ip.values()): + np.testing.assert_allclose(value, xvalue) + + # Test logp + model_logp = model.compile_logp()(model_ip) + xmodel_logp = xmodel.compile_logp()(xmodel_ip) + np.testing.assert_allclose(model_logp, xmodel_logp) + + # Test random draws + model_draw = pm.draw(model["y_obs"], random_seed=1) + xmodel_draw = pm.draw(xmodel["y_obs"], random_seed=1) + np.testing.assert_allclose(model_draw, xmodel_draw) + np.testing.assert_allclose(model_draw, xmodel_draw) + + with xmodel: + pm.sample_prior_predictive() + idata = pm.sample( + tune=200, chains=2, draws=50, compute_convergence_checks=False, progressbar=False + ) + pm.sample_posterior_predictive(idata, progressbar=False) diff --git a/tests/dims/utils.py b/tests/dims/utils.py new file mode 100644 index 0000000000..e84dba4a55 --- /dev/null +++ b/tests/dims/utils.py @@ -0,0 +1,64 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
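The utilities that follow compare models by first lowering xtensor graphs to plain tensor graphs through named rewrite passes. A minimal sketch of that lowering step in isolation (the toy graph is illustrative, not from the patch):

```python
import pytensor.tensor as pt

from pytensor.graph import rewrite_graph
from pytensor.xtensor import as_xtensor

x = as_xtensor(pt.zeros((2, 3)), dims=("a", "b"))
# Lower the xtensor operations to ordinary tensor operations
[lowered] = rewrite_graph([(x + 1).values], include=("lower_xtensor", "canonicalize"))
```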
+from pytensor.graph import rewrite_graph +from pytensor.printing import debugprint + +from pymc import Model +from pymc.testing import equal_computations_up_to_root + + +def assert_equivalent_random_graph(model: Model, reference_model: Model) -> bool: + """Check if the random graph of a model with xtensor variables is equivalent.""" + lowered_model = rewrite_graph( + model.basic_RVs + model.deterministics + model.potentials, + include=( + "lower_xtensor", + "inline_ofg_expansion_xtensor", + "canonicalize", + "local_remove_all_assert", + ), + ) + reference_lowered_model = rewrite_graph( + reference_model.basic_RVs + reference_model.deterministics + reference_model.potentials, + include=( + "inline_ofg_expansion", + "canonicalize", + "local_remove_all_assert", + ), + ) + assert equal_computations_up_to_root( + lowered_model, + reference_lowered_model, + ignore_rng_values=True, + ), debugprint(lowered_model + reference_lowered_model, print_type=True) + + +def assert_equivalent_logp_graph(model: Model, reference_model: Model) -> bool: + """Check if the logp graph of a model with xtensor variables is equivalent.""" + lowered_model_logp = rewrite_graph( + [model.logp()], + include=("lower_xtensor", "canonicalize", "local_remove_all_assert"), + ) + reference_lowered_model_logp = rewrite_graph( + [reference_model.logp()], + include=("canonicalize", "local_remove_all_assert"), + ) + assert equal_computations_up_to_root( + lowered_model_logp, + reference_lowered_model_logp, + ignore_rng_values=False, + ), debugprint( + lowered_model_logp + reference_lowered_model_logp, + print_type=True, + ) diff --git a/tests/distributions/test_continuous.py b/tests/distributions/test_continuous.py index 9c108d2035..7209382666 100644 --- a/tests/distributions/test_continuous.py +++ b/tests/distributions/test_continuous.py @@ -693,7 +693,8 @@ def test_inverse_gamma_logcdf(self): ) def test_inverse_gamma_alt_params(self): def test_fun(value, mu, sigma): - alpha, beta = pm.InverseGamma._get_alpha_beta(None, None, mu, sigma) + alpha = (2 * sigma**2 + mu**2) / sigma**2 + beta = mu * (mu**2 + sigma**2) / sigma**2 return st.invgamma.logpdf(value, alpha, scale=beta) check_logp( diff --git a/tests/distributions/test_distribution.py b/tests/distributions/test_distribution.py index f84fb8f869..a41df3481c 100644 --- a/tests/distributions/test_distribution.py +++ b/tests/distributions/test_distribution.py @@ -186,7 +186,7 @@ def update(self, node): )(rng) with pytest.raises( ValueError, - match="No update found for at least one RNG used in SymbolicRandomVariable Op SymbolicRVCustomUpdates", + match="No update found for at least one RNG used in SymbolicRVCustomUpdates", ): compile(inputs=[], outputs=x, random_seed=431) diff --git a/tests/distributions/test_multivariate.py b/tests/distributions/test_multivariate.py index cb4b8520b9..2597605bd1 100644 --- a/tests/distributions/test_multivariate.py +++ b/tests/distributions/test_multivariate.py @@ -1617,9 +1617,10 @@ def test_zsn_shape(self, n_zerosum_axes): def test_zsn_fail_axis(self, error, match, shape, support_shape, n_zerosum_axes): with pytest.raises(error, match=match): with pm.Model() as m: - _ = pm.ZeroSumNormal( + v = pm.ZeroSumNormal( "v", shape=shape, support_shape=support_shape, n_zerosum_axes=n_zerosum_axes ) + v.eval() @pytest.mark.parametrize( "shape, support_shape", diff --git a/tests/distributions/test_shape_utils.py b/tests/distributions/test_shape_utils.py index d0f3f1b432..8d787df611 100644 --- a/tests/distributions/test_shape_utils.py +++ 
b/tests/distributions/test_shape_utils.py @@ -31,6 +31,7 @@ from pymc.distributions.shape_utils import ( change_dist_size, convert_dims, + convert_dims_with_ellipsis, convert_shape, convert_size, get_support_shape, @@ -297,6 +298,14 @@ def test_convert_dims(self): assert convert_dims(dims="town") == ("town",) with pytest.raises(ValueError, match="must be a tuple, str or list"): convert_dims(3) + with pytest.raises(ValueError, match="must be a tuple, str or list"): + convert_dims(...) + + def test_convert_dims_with_ellipsis(self): + assert convert_dims_with_ellipsis(dims="town") == ("town",) + assert convert_dims_with_ellipsis(...) == (...,) + with pytest.raises(ValueError, match="must be a tuple, list, str or Ellipsis"): + convert_dims_with_ellipsis(3) def test_convert_shape(self): assert convert_shape(5) == (5,) diff --git a/tests/sampling/test_forward.py b/tests/sampling/test_forward.py index df8bb2dbf2..3dd30e14f7 100644 --- a/tests/sampling/test_forward.py +++ b/tests/sampling/test_forward.py @@ -35,9 +35,8 @@ import pymc as pm from pymc.backends.base import MultiTrace -from pymc.logprob.utils import rvs_in_graph from pymc.model.transform.optimization import freeze_dims_and_data -from pymc.pytensorf import compile +from pymc.pytensorf import compile, rvs_in_graph from pymc.sampling.forward import ( compile_forward_sampling_function, get_constant_coords, diff --git a/tests/test_initial_point.py b/tests/test_initial_point.py index e13ed5279c..b04542fde9 100644 --- a/tests/test_initial_point.py +++ b/tests/test_initial_point.py @@ -54,7 +54,7 @@ def test_dependent_initvals(self, reverse_rvs): L = pm.Uniform("L", 0, 1, initval=0.5) U = pm.Uniform("U", lower=9, upper=10, initval=9.5) B1 = pm.Uniform("B1", lower=L, upper=U, initval=5) - B2 = pm.Uniform("B2", lower=L, upper=U, initval=(L + U) / 2) + B2 = pm.Uniform("B2", lower=L, upper=U, initval=(0.5 + 9.5) / 2) if reverse_rvs: pmodel.free_RVs = pmodel.free_RVs[::-1] @@ -69,7 +69,16 @@ def test_dependent_initvals(self, reverse_rvs): pmodel.rvs_to_initial_values[U] = 9.9 ip = pmodel.initial_point(random_seed=0) assert ip["B1_interval__"] < 0 - assert ip["B2_interval__"] == 0 + assert ip["B2_interval__"] < 0 + + def test_symbolic_initval_not_supported(self): + with pm.Model() as pmodel: + L = pm.Uniform("L", 0, 1, initval=0.5) + U = pm.Uniform("U", lower=L, upper=1.5, initval=L * 2) + with pytest.raises( + ValueError, match="Initial value of U depends on other random variables" + ): + pmodel.initial_point(random_seed=0) pass def test_nested_initvals(self): @@ -288,7 +297,7 @@ class MyNormalDistribution(pm.Normal): x = MyNormalDistribution("x", 0, 1, initval="support_point") with pytest.warns( - UserWarning, match="Moment not defined for variable x of type MyNormalRV" + UserWarning, match="Support point not defined for variable x of type MyNormalRV" ): res = m.initial_point() diff --git a/tests/test_pytensorf.py b/tests/test_pytensorf.py index f7efd5f6d4..87659c2f0e 100644 --- a/tests/test_pytensorf.py +++ b/tests/test_pytensorf.py @@ -415,7 +415,7 @@ def update(self, node): ], )(rng1, rng2) with pytest.raises( - ValueError, match="No update found for at least one RNG used in SymbolicRandomVariable" + ValueError, match="No update found for at least one RNG used in SymbolicRV" ): compile(inputs=[], outputs=[dummy_x1, dummy_x2]) From 21647e4cb035c9043ccb30474581692611d801f8 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 30 Jun 2025 12:47:41 +0200 Subject: [PATCH 04/10] Allow registering XTensorVariables directly in model --- 
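In rough terms, this patch makes the model accept dims variables as-is: registration no longer unwraps XTensorVariables into plain tensors, so named dimensions survive it. An illustrative sketch of the user-facing behavior, grounded in the tests touched below:

```python
import pymc as pm

from pymc import dims as pmx

with pm.Model(coords={"a": range(3)}):
    x = pmx.Normal("x", dims="a")

# The registered variable keeps its named dims
assert x.type.dims == ("a",)
```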
pymc/dims/__init__.py | 3 + pymc/dims/distributions/core.py | 127 ++++++++++++++++++++++---- pymc/dims/distributions/transforms.py | 53 +++++++++++ pymc/dims/model.py | 39 +------- pymc/initial_point.py | 58 +++++++++++- pymc/logprob/basic.py | 18 ++-- pymc/logprob/rewriting.py | 7 +- pymc/model/core.py | 14 ++- pymc/pytensorf.py | 62 +++++-------- pymc/step_methods/metropolis.py | 8 +- tests/dims/test_model.py | 2 +- tests/dims/utils.py | 11 ++- 12 files changed, 284 insertions(+), 118 deletions(-) create mode 100644 pymc/dims/distributions/transforms.py diff --git a/pymc/dims/__init__.py b/pymc/dims/__init__.py index d1a050638c..34226133bc 100644 --- a/pymc/dims/__init__.py +++ b/pymc/dims/__init__.py @@ -39,6 +39,9 @@ def __init__(): logprob_rewrites_db.register( "pre_lower_xtensor", optdb.query("+lower_xtensor"), "basic", position=0.1 ) + logprob_rewrites_db.register( + "post_lower_xtensor", optdb.query("+lower_xtensor"), "cleanup", position=5.1 + ) initial_point_rewrites_db.register( "lower_xtensor", optdb.query("+lower_xtensor"), "basic", position=0.1 ) diff --git a/pymc/dims/distributions/core.py b/pymc/dims/distributions/core.py index bd48db0ec5..fb06457dc3 100644 --- a/pymc/dims/distributions/core.py +++ b/pymc/dims/distributions/core.py @@ -13,18 +13,26 @@ # limitations under the License. from collections.abc import Callable, Sequence from itertools import chain +from typing import cast +import numpy as np + +from pytensor.graph import node_rewriter from pytensor.graph.basic import Variable from pytensor.tensor.elemwise import DimShuffle +from pytensor.tensor.random.op import RandomVariable from pytensor.xtensor import as_xtensor +from pytensor.xtensor.basic import XTensorFromTensor, xtensor_from_tensor from pytensor.xtensor.type import XTensorVariable -from pymc import modelcontext -from pymc.dims.model import with_dims -from pymc.distributions import transforms +from pymc import SymbolicRandomVariable, modelcontext +from pymc.dims.distributions.transforms import DimTransform, log_odds_transform, log_transform from pymc.distributions.distribution import _support_point, support_point from pymc.distributions.shape_utils import DimsWithEllipsis, convert_dims_with_ellipsis -from pymc.logprob.transforms import Transform +from pymc.logprob.abstract import MeasurableOp, _logprob +from pymc.logprob.rewriting import measurable_ir_rewrites_db +from pymc.logprob.tensor import MeasurableDimShuffle +from pymc.logprob.utils import filter_measurable_variables from pymc.util import UNSET @@ -36,25 +44,98 @@ def dimshuffle_support_point(ds_op, _, rv): return ds_op(support_point(rv)) +@_support_point.register(XTensorFromTensor) +def xtensor_from_tensor_support_point(xtensor_op, _, rv): + # We remove the xtensor_from_tensor operation, so initial_point doesn't have to do a further lowering + return xtensor_op(support_point(rv)) + + +class MeasurableXTensorFromTensor(MeasurableOp, XTensorFromTensor): + __props__ = ("dims", "core_dims") # type: ignore[assignment] + + def __init__(self, dims, core_dims): + super().__init__(dims=dims) + self.core_dims = tuple(core_dims) if core_dims is not None else None + + +@node_rewriter([XTensorFromTensor]) +def find_measurable_xtensor_from_tensor(fgraph, node) -> list[XTensorVariable] | None: + if isinstance(node.op, MeasurableXTensorFromTensor): + return None + + xs = filter_measurable_variables(node.inputs) + + if not xs: + # Check if we have a transposition instead + # The rewrite that introduces measurable tranpsoses refuses to apply to multivariate RVs + # So 
we have a chance of inferring the core dims! + [ds] = node.inputs + ds_node = ds.owner + if not ( + ds_node is not None + and isinstance(ds_node.op, DimShuffle) + and ds_node.op.is_transpose + and filter_measurable_variables(ds_node.inputs) + ): + return None + [x] = ds_node.inputs + if not ( + x.owner is not None and isinstance(x.owner.op, RandomVariable | SymbolicRandomVariable) + ): + return None + + measurable_x = MeasurableDimShuffle(**ds_node.op._props_dict())(x) # type: ignore[attr-defined] + + ndim_supp = x.owner.op.ndim_supp + if ndim_supp: + inverse_transpose = np.argsort(ds_node.op.shuffle) + dims = node.op.dims + dims_before_transpose = tuple(dims[i] for i in inverse_transpose) + core_dims = dims_before_transpose[-ndim_supp:] + else: + core_dims = () + + new_out = MeasurableXTensorFromTensor(dims=node.op.dims, core_dims=core_dims)(measurable_x) + else: + # If this happens we know there's no measurable transpose in between and we can + # safely infer the core_dims positionally when the inner logp is returned + new_out = MeasurableXTensorFromTensor(dims=node.op.dims, core_dims=None)(*node.inputs) + return [cast(XTensorVariable, new_out)] + + +@_logprob.register(MeasurableXTensorFromTensor) +def measurable_xtensor_from_tensor(op, values, rv, **kwargs): + rv_logp = _logprob(rv.owner.op, tuple(v.values for v in values), *rv.owner.inputs, **kwargs) + if op.core_dims is None: + # The core_dims of the inner rv are on the right + dims = op.dims[: rv_logp.ndim] + else: + # We inferred where the core_dims are! + dims = [d for d in op.dims if d not in op.core_dims] + return xtensor_from_tensor(rv_logp, dims=dims) + + +measurable_ir_rewrites_db.register( + "measurable_xtensor_from_tensor", find_measurable_xtensor_from_tensor, "basic", "xtensor" +) + + class DimDistribution: """Base class for PyMC distribution that wrap pytensor.xtensor.random operations, and follow xarray-like semantics.""" xrv_op: Callable - default_transform: Transform | None = None + default_transform: DimTransform | None = None @staticmethod def _as_xtensor(x): try: return as_xtensor(x) except TypeError: - try: - return with_dims(x) - except ValueError: - raise ValueError( - f"Variable {x} must have dims associated with it.\n" - "To avoid subtle bugs, PyMC does not make any assumptions about the dims of parameters.\n" - "Use `as_xtensor` with the `dims` keyword argument to specify the dims explicitly." - ) + raise ValueError( + f"Variable {x} must have dims associated with it.\n" + "To avoid subtle bugs, PyMC does not make any assumptions about the dims of parameters.\n" + "Use `as_xtensor` with the `dims` keyword argument to specify the dims explicitly." + ) def __new__( cls, @@ -119,10 +200,22 @@ def __new__( else: # Align observed dims with those of the RV # TODO: If this fails give a more informative error message - observed = observed.transpose(*rv_dims).values + observed = observed.transpose(*rv_dims) + + # Check user didn't pass regular transforms + if transform not in (UNSET, None): + if not isinstance(transform, DimTransform): + raise TypeError( + f"Transform must be a DimTransform, form pymc.dims.transforms, but got {type(transform)}." + ) + if default_transform not in (UNSET, None): + if not isinstance(default_transform, DimTransform): + raise TypeError( + f"default_transform must be a DimTransform, from pymc.dims.transforms, but got {type(default_transform)}." 
+ ) rv = model.register_rv( - rv.values, + rv, name=name, observed=observed, total_size=total_size, @@ -182,10 +275,10 @@ def dist(self, *args, core_dims: str | Sequence[str] | None = None, **kwargs): class PositiveDimDistribution(DimDistribution): """Base class for positive continuous distributions.""" - default_transform = transforms.log + default_transform = log_transform class UnitDimDistribution(DimDistribution): """Base class for unit-valued distributions.""" - default_transform = transforms.logodds + default_transform = log_odds_transform diff --git a/pymc/dims/distributions/transforms.py b/pymc/dims/distributions/transforms.py new file mode 100644 index 0000000000..6805d1b5c1 --- /dev/null +++ b/pymc/dims/distributions/transforms.py @@ -0,0 +1,53 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytensor.xtensor as ptx + +from pymc.logprob.transforms import Transform + + +class DimTransform(Transform): + """Base class for transforms that are applied to dim distriubtions.""" + + +class LogTransform(DimTransform): + name = "log" + + def forward(self, value, *inputs): + return ptx.math.log(value) + + def backward(self, value, *inputs): + return ptx.math.exp(value) + + def log_jac_det(self, value, *inputs): + return value + + +log_transform = LogTransform() + + +class LogOddsTransform(DimTransform): + name = "logodds" + + def backward(self, value, *inputs): + return ptx.math.expit(value) + + def forward(self, value, *inputs): + return ptx.math.log(value / (1 - value)) + + def log_jac_det(self, value, *inputs): + sigmoid_value = ptx.math.sigmoid(value) + return ptx.math.log(sigmoid_value) + ptx.math.log1p(-sigmoid_value) + + +log_odds_transform = LogOddsTransform() diff --git a/pymc/dims/model.py b/pymc/dims/model.py index 67ab5fb562..4edd8df04b 100644 --- a/pymc/dims/model.py +++ b/pymc/dims/model.py @@ -15,7 +15,6 @@ from pytensor.tensor import TensorVariable from pytensor.xtensor import as_xtensor -from pytensor.xtensor.basic import TensorFromXTensor from pytensor.xtensor.type import XTensorVariable from pymc.data import Data as RegularData @@ -30,29 +29,6 @@ from pymc.model.core import Potential as RegularPotential -def with_dims(x: TensorVariable | XTensorVariable, model: Model | None = None) -> XTensorVariable: - """Recover the dims of a variable that was registered in the Model.""" - if isinstance(x, XTensorVariable): - return x - - if (x.owner is not None) and isinstance(x.owner.op, TensorFromXTensor): - dims = x.owner.inputs[0].type.dims - return as_xtensor(x, dims=dims, name=x.name) - - # Try accessing the model context to get dims - try: - model = modelcontext(model) - if ( - model.named_vars.get(x.name, None) is x - and (dims := model.named_vars_to_dims.get(x.name, None)) is not None - ): - return as_xtensor(x, dims=dims, name=x.name) - except TypeError: - pass - - raise ValueError(f"variable {x} doesn't have dims associated with it") - - def Data( name: str, value, dims: Dims = None, model: Model | None = None, **kwargs ) 
-> XTensorVariable: @@ -60,8 +36,6 @@ def Data( Dimensions are required if the input is not a scalar. These are always forwarded to the model object. - - The respective TensorVariable is registered in the model """ model = modelcontext(model) dims = convert_dims(dims) # type: ignore[assignment] @@ -90,12 +64,9 @@ def _register_and_return_xtensor_variable( value = value.transpose(*dims) # Regardless of whether dims are provided, we now have them dims = value.type.dims - # Register the equivalent TensorVariable with the model so it doesn't see XTensorVariables directly. - value = value.values # type: ignore[union-attr] - - value = registration_func(name, value, dims=dims, model=model) - - return as_xtensor(value, dims=dims, name=name) # type: ignore[arg-type] + else: + value = as_xtensor(value, dims=dims, name=name) # type: ignore[arg-type] + return registration_func(name, value, dims=dims, model=model) def Deterministic( @@ -107,8 +78,6 @@ def Deterministic( If the input is not an XTensorVariable, it is converted to one using `as_xtensor`. Dims are required if the input is not a scalar. The dimensions of the resulting XTensorVariable are always forwarded to the model object. - - The respective TensorVariable is registered in the model """ return _register_and_return_xtensor_variable(name, value, dims, model, RegularDeterministic) @@ -122,7 +91,5 @@ def Potential( If the input is not an XTensorVariable, it is converted to one using `as_xtensor`. Dims are required if the input is not a scalar. The dimensions of the resulting XTensorVariable are always forwarded to the model object. - - The respective TensorVariable is registered in the model. """ return _register_and_return_xtensor_variable(name, value, dims, model, RegularPotential) diff --git a/pymc/initial_point.py b/pymc/initial_point.py index b6b8dba3c9..4edf52f6b9 100644 --- a/pymc/initial_point.py +++ b/pymc/initial_point.py @@ -20,7 +20,9 @@ import pytensor import pytensor.tensor as pt -from pytensor.graph.basic import Variable, ancestors +from pytensor import graph_replace +from pytensor.compile.ops import TypeCastingOp +from pytensor.graph.basic import Apply, Variable, ancestors, walk from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.db import RewriteDatabaseQuery, SequenceDB from pytensor.tensor.variable import TensorVariable @@ -195,6 +197,25 @@ def inner(seed, *args, **kwargs): return make_seeded_function(func) +class InitialPoint(TypeCastingOp): + def make_node(self, var): + return Apply(self, [var], [var.type()]) + + +def non_support_point_ancestors(value): + def expand(r: Variable): + node = r.owner + if node is not None and not isinstance(node.op, InitialPoint): + # Stop graph traversal at InitialPoint ops + return node.inputs + return None + + yield from walk([value], expand, bfs=False) + + +initial_point_op = InitialPoint() + + def make_initial_point_expression( *, free_rvs: Sequence[TensorVariable], @@ -235,13 +256,18 @@ def make_initial_point_expression( # Clone free_rvs so we don't modify the original graph initial_point_fgraph = FunctionGraph(outputs=free_rvs, clone=True) + # Wrap each rv in an initial_point Operation to avoid losing dependency between the RVs + replacements = tuple((rv, initial_point_op(rv)) for rv in initial_point_fgraph.outputs) + toposort_replace(initial_point_fgraph, replacements, reverse=True) # Apply any rewrites necessary to compute the initial points. 
initial_point_rewriter = initial_point_rewrites_db.query(initial_point_basic_query) if initial_point_rewriter: initial_point_rewriter.rewrite(initial_point_fgraph) - free_rvs_clone = initial_point_fgraph.outputs + ip_variables = initial_point_fgraph.outputs.copy() + free_rvs_clone = [ip.owner.inputs[0] for ip in ip_variables] + n_rvs = len(free_rvs_clone) initial_values = [] initial_values_transformed = [] @@ -255,6 +281,20 @@ def make_initial_point_expression( if strategy == "support_point": try: value = support_point(variable) + + # If a support point expression depends on other free_RVs that are not + # wrapped in InitialPoint, we need to replace them with their wrapped versions + # This can only happen for multi-output distributions, where the initial point + # of some outputs depends on the initial point of other outputs from the same node. + other_free_rvs = set(free_rvs_clone) - {variable} + support_point_replacements = { + ancestor: ip_variables[free_rvs_clone.index(ancestor)] + for ancestor in non_support_point_ancestors(value) + if ancestor in other_free_rvs + } + if support_point_replacements: + value = graph_replace(value, support_point_replacements) + except NotImplementedError: warnings.warn( f"Support point not defined for variable {variable} of type " @@ -286,7 +326,10 @@ def make_initial_point_expression( if original_variable in jitter_rvs: jitter = pt.random.uniform(-1, 1, size=value.shape) + # Hack to allow xtensor value to be added to tensor jitter + jitter = value.type.filter_variable(jitter) jitter.name = f"{variable.name}_jitter" + # Hack to allow xtensor value to be added to tensor jitter value = value + jitter value = value.astype(variable.dtype) @@ -297,14 +340,21 @@ def make_initial_point_expression( initial_values.append(value) + for initial_value in initial_values: + # FIXME: This is a hack so that interdependent replacements that can't + # be sorted topologically from the initial point graph come out correctly. + # This happens for multi-output nodes where the replacements depend on each other. + # From the original graph perspective, their ordering is equivalent. + initial_point_fgraph.add_output(initial_value, import_missing=True) + # We now replace all rvs by the respective initial_point expressions # in the constrained (untransformed) space. 
We do this in reverse topological # order, so that later nodes do not reintroduce expressions with earlier # rvs that would need to once again be replaced by their initial_points - toposort_replace(initial_point_fgraph, tuple(zip(free_rvs_clone, initial_values)), reverse=True) + toposort_replace(initial_point_fgraph, tuple(zip(ip_variables, initial_values)), reverse=True) if not return_transformed: - return initial_point_fgraph.outputs + return initial_point_fgraph.outputs[:n_rvs] # Because the unconstrained (transformed) expressions are a subgraph of the # constrained initial point they were also automatically updated inplace diff --git a/pymc/logprob/basic.py b/pymc/logprob/basic.py index 9fb76df169..409f09bc3b 100644 --- a/pymc/logprob/basic.py +++ b/pymc/logprob/basic.py @@ -197,9 +197,9 @@ def normal_logp(value, mu, sigma): [ir_valued_var] = fgraph.outputs [ir_rv, ir_value] = ir_valued_var.owner.inputs expr = _logprob_helper(ir_rv, ir_value, **kwargs) - cleanup_ir([expr]) + [expr] = cleanup_ir([expr]) if warn_rvs: - _warn_rvs_in_inferred_graph(expr) + _warn_rvs_in_inferred_graph([expr]) return expr @@ -297,9 +297,9 @@ def normal_logcdf(value, mu, sigma): [ir_valued_rv] = fgraph.outputs [ir_rv, ir_value] = ir_valued_rv.owner.inputs expr = _logcdf_helper(ir_rv, ir_value, **kwargs) - cleanup_ir([expr]) + [expr] = cleanup_ir([expr]) if warn_rvs: - _warn_rvs_in_inferred_graph(expr) + _warn_rvs_in_inferred_graph([expr]) return expr @@ -379,9 +379,9 @@ def icdf(rv: TensorVariable, value: TensorLike, warn_rvs=True, **kwargs) -> Tens [ir_valued_rv] = fgraph.outputs [ir_rv, ir_value] = ir_valued_rv.owner.inputs expr = _icdf_helper(ir_rv, ir_value, **kwargs) - cleanup_ir([expr]) + [expr] = cleanup_ir([expr]) if warn_rvs: - _warn_rvs_in_inferred_graph(expr) + _warn_rvs_in_inferred_graph([expr]) return expr @@ -540,15 +540,15 @@ def conditional_logp( f"The logprob terms of the following value variables could not be derived: {missing_value_terms}" ) - logprobs = list(values_to_logprobs.values()) - cleanup_ir(logprobs) + values, logprobs = zip(*values_to_logprobs.items()) + logprobs = cleanup_ir(logprobs) if warn_rvs: rvs_in_logp_expressions = _find_unallowed_rvs_in_graph(logprobs) if rvs_in_logp_expressions: warnings.warn(RVS_IN_JOINT_LOGP_GRAPH_MSG % rvs_in_logp_expressions, UserWarning) - return values_to_logprobs + return dict(zip(values, logprobs)) def transformed_conditional_logp( diff --git a/pymc/logprob/rewriting.py b/pymc/logprob/rewriting.py index ea0202f00c..af0d8d01e1 100644 --- a/pymc/logprob/rewriting.py +++ b/pymc/logprob/rewriting.py @@ -133,6 +133,8 @@ def remove_DiracDelta(fgraph, node): logprob_rewrites_basic_query = RewriteDatabaseQuery(include=["basic"]) +logprob_rewrites_cleanup_query = RewriteDatabaseQuery(include=["cleanup"]) + logprob_rewrites_db = SequenceDB() logprob_rewrites_db.name = "logprob_rewrites_db" @@ -276,10 +278,11 @@ def construct_ir_fgraph( return fgraph -def cleanup_ir(vars: Sequence[Variable]) -> None: +def cleanup_ir(vars: Sequence[Variable]) -> Sequence[Variable]: fgraph = FunctionGraph(outputs=vars, clone=False) - ir_rewriter = logprob_rewrites_db.query(RewriteDatabaseQuery(include=["cleanup"])) + ir_rewriter = logprob_rewrites_db.query(logprob_rewrites_cleanup_query) ir_rewriter.rewrite(fgraph) + return fgraph.outputs def assume_valued_outputs(outputs: Sequence[TensorVariable]) -> Sequence[TensorVariable]: diff --git a/pymc/model/core.py b/pymc/model/core.py index 66e633e15e..5ec5c0ec34 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -35,6 
+35,8 @@ from pytensor.compile import DeepCopyOp, Function, ProfileStats, get_mode from pytensor.compile.sharedvalue import SharedVariable from pytensor.graph.basic import Constant, Variable, ancestors, graph_inputs +from pytensor.tensor import as_tensor +from pytensor.tensor.math import variadic_add from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.random.type import RandomType from pytensor.tensor.variable import TensorConstant, TensorVariable @@ -231,7 +233,9 @@ def __init__( grads = pytensor.grad(cost, grad_vars, disconnected_inputs="ignore") for grad_wrt, var in zip(grads, grad_vars): grad_wrt.name = f"{var.name}_grad" - grads = pt.join(0, *[pt.atleast_1d(grad.ravel()) for grad in grads]) + grads = pt.join( + 0, *[as_tensor(grad, allow_xtensor_conversion=True).ravel() for grad in grads] + ) outputs = [cost, grads] else: outputs = [cost] @@ -708,7 +712,9 @@ def logp( if not sum: return logp_factors - logp_scalar = pt.sum([pt.sum(factor) for factor in logp_factors]) + logp_scalar = variadic_add( + *(as_tensor(factor, allow_xtensor_conversion=True).sum() for factor in logp_factors) + ) logp_scalar_name = "__logp" if jacobian else "__logp_nojac" if self.name: logp_scalar_name = f"{logp_scalar_name}_{self.name}" @@ -1328,7 +1334,7 @@ def make_obs_var( else: if sps.issparse(data): data = sparse.basic.as_sparse(data, name=name) - else: + elif not isinstance(data, Variable): data = pt.as_tensor_variable(data, name=name) if total_size: @@ -1781,7 +1787,7 @@ def point_logps(self, point=None, round_vals=2, **kwargs): point = self.initial_point() factors = self.basic_RVs + self.potentials - factor_logps_fn = [pt.sum(factor) for factor in self.logp(factors, sum=False)] + factor_logps_fn = [factor.sum() for factor in self.logp(factors, sum=False)] return { factor.name: np.round(np.asarray(factor_logp), round_vals) for factor, factor_logp in zip( diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py index 3ce3081477..e0bb23b966 100644 --- a/pymc/pytensorf.py +++ b/pymc/pytensorf.py @@ -46,7 +46,7 @@ from pytensor.tensor.random.var import RandomGeneratorSharedVariable from pytensor.tensor.rewriting.basic import topo_unconditional_constant_folding from pytensor.tensor.rewriting.shape import ShapeFeature -from pytensor.tensor.sharedvar import SharedVariable, TensorSharedVariable +from pytensor.tensor.sharedvar import SharedVariable from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1 from pytensor.tensor.variable import TensorVariable @@ -300,7 +300,9 @@ def smarttypeX(x): def gradient1(f, v): """Flat gradient of f wrt v.""" - return pt.flatten(grad(f, v, disconnected_inputs="warn")) + return pt.as_tensor( + grad(f, v, disconnected_inputs="warn"), allow_xtensor_conversion=True + ).ravel() empty_gradient = pt.zeros(0, dtype="float32") @@ -419,11 +421,11 @@ def make_shared_replacements(point, vars, model): def join_nonshared_inputs( point: dict[str, np.ndarray], - outputs: list[TensorVariable], - inputs: list[TensorVariable], - shared_inputs: dict[TensorVariable, TensorSharedVariable] | None = None, + outputs: Sequence[Variable], + inputs: Sequence[Variable], + shared_inputs: dict[Variable, Variable] | None = None, make_inputs_shared: bool = False, -) -> tuple[list[TensorVariable], TensorVariable]: +) -> tuple[Sequence[Variable], TensorVariable]: """ Create new outputs and input TensorVariables where the non-shared inputs are joined in a single raveled vector input. 
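`join_nonshared_inputs`, whose signature is generalized in the hunk above, replaces several model inputs by views into one raveled vector; the widened type hints additionally admit XTensorVariable inputs. A minimal usage sketch with plain tensors (illustrative, not from the patch):

```python
import numpy as np
import pytensor.tensor as pt

from pymc.pytensorf import join_nonshared_inputs

a = pt.vector("a")
b = pt.matrix("b")
# The point dict only provides the shapes used to slice the joined vector
point = {"a": np.zeros(2), "b": np.zeros((2, 3))}

[new_out], joined = join_nonshared_inputs(
    point=point, outputs=[a.sum() + b.sum()], inputs=[a, b]
)
# `new_out` now depends on the single length-8 vector input `joined`
```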
@@ -548,7 +550,9 @@ def join_nonshared_inputs( if not inputs: raise ValueError("Empty list of input variables.") - raveled_inputs = pt.concatenate([var.ravel() for var in inputs]) + raveled_inputs = pt.concatenate( + [pt.as_tensor(var, allow_xtensor_conversion=True).ravel() for var in inputs] + ) if not make_inputs_shared: tensor_type = raveled_inputs.type @@ -560,12 +564,15 @@ def join_nonshared_inputs( if pytensor.config.compute_test_value != "off": joined_inputs.tag.test_value = raveled_inputs.tag.test_value - replace: dict[TensorVariable, TensorVariable] = {} + replace: dict[Variable, Variable] = {} last_idx = 0 for var in inputs: shape = point[var.name].shape arr_len = np.prod(shape, dtype=int) - replace[var] = joined_inputs[last_idx : last_idx + arr_len].reshape(shape).astype(var.dtype) + replacement_var = ( + joined_inputs[last_idx : last_idx + arr_len].reshape(shape).astype(var.dtype) + ) + replace[var] = var.type.filter_variable(replacement_var) last_idx += arr_len if shared_inputs is not None: @@ -1012,43 +1019,16 @@ def as_symbolic_string(x, **kwargs): def toposort_replace( - fgraph: FunctionGraph, replacements: Sequence[tuple[Variable, Variable]], reverse: bool = False + fgraph: FunctionGraph, + replacements: Sequence[tuple[Variable, Variable]], + reverse: bool = False, ) -> None: """Replace multiple variables in place in topological order.""" fgraph_toposort = {node: i for i, node in enumerate(fgraph.toposort())} - _inner_fgraph_toposorts = {} # Cache inner toposorts - - def _nested_toposort_index(var, fgraph_toposort) -> tuple[int]: - """Compute position of variable in fgraph toposort. - - When a variable is an OpFromGraph output, extend output with the toposort index of the inner graph(s). - - This allows ordering variables that come from the same OpFromGraph. - """ - if not var.owner: - return (-1,) - - index = fgraph_toposort[var.owner] - - # Recurse into OpFromGraphs - # TODO: Could also recurse into Scans - if isinstance(var.owner.op, OpFromGraph): - inner_fgraph = var.owner.op.fgraph - - if inner_fgraph not in _inner_fgraph_toposorts: - _inner_fgraph_toposorts[inner_fgraph] = { - node: i for i, node in enumerate(inner_fgraph.toposort()) - } - - inner_fgraph_toposort = _inner_fgraph_toposorts[inner_fgraph] - inner_var = inner_fgraph.outputs[var.owner.outputs.index(var)] - return (index, *_nested_toposort_index(inner_var, inner_fgraph_toposort)) - else: - return (index,) - + fgraph_toposort[None] = -1 # Variables without owner are not in the toposort sorted_replacements = sorted( replacements, - key=lambda pair: _nested_toposort_index(pair[0], fgraph_toposort), + key=lambda pair: fgraph_toposort[pair[0].owner], reverse=reverse, ) fgraph.replace_all(sorted_replacements, import_missing=True) diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py index 70c650653d..6bf3d6b9ed 100644 --- a/pymc/step_methods/metropolis.py +++ b/pymc/step_methods/metropolis.py @@ -13,7 +13,7 @@ # limitations under the License. 
from collections.abc import Callable from dataclasses import field -from typing import Any +from typing import Any, cast import numpy as np import numpy.random as nr @@ -22,6 +22,7 @@ import scipy.special from pytensor import tensor as pt +from pytensor.graph.basic import Variable from pytensor.graph.fg import MissingInputError from pytensor.tensor.random.basic import BernoulliRV, CategoricalRV from rich.progress import TextColumn @@ -1263,7 +1264,10 @@ def delta_logp( compile_kwargs: dict | None, ) -> pytensor.compile.Function: [logp0], inarray0 = join_nonshared_inputs( - point=point, outputs=[logp], inputs=vars, shared_inputs=shared + point=point, + outputs=[logp], + inputs=vars, + shared_inputs=cast(dict[Variable, Variable], shared), ) tensor_type = inarray0.type diff --git a/tests/dims/test_model.py b/tests/dims/test_model.py index bf27bf482e..c67f629b96 100644 --- a/tests/dims/test_model.py +++ b/tests/dims/test_model.py @@ -73,7 +73,7 @@ def test_simple_model(): np.testing.assert_allclose(draw, draw_same) assert not np.allclose(draw, draw_diff) - observed_values = DataArray(np.ones((3, 5)), dims=("a", "b")).transpose() + observed_values = DataArray(np.ones((3, 5)), dims=("a", "b")) with observe(xmodel, {"y": observed_values}): pm.sample_prior_predictive() idata = pm.sample( diff --git a/tests/dims/utils.py b/tests/dims/utils.py index e84dba4a55..07a340695b 100644 --- a/tests/dims/utils.py +++ b/tests/dims/utils.py @@ -11,8 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from pytensor import graph_replace from pytensor.graph import rewrite_graph from pytensor.printing import debugprint +from pytensor.xtensor import as_xtensor from pymc import Model from pymc.testing import equal_computations_up_to_root @@ -21,7 +23,7 @@ def assert_equivalent_random_graph(model: Model, reference_model: Model) -> bool: """Check if the random graph of a model with xtensor variables is equivalent.""" lowered_model = rewrite_graph( - model.basic_RVs + model.deterministics + model.potentials, + [var.values for var in model.basic_RVs + model.deterministics + model.potentials], include=( "lower_xtensor", "inline_ofg_expansion_xtensor", @@ -46,8 +48,13 @@ def assert_equivalent_random_graph(model: Model, reference_model: Model) -> bool def assert_equivalent_logp_graph(model: Model, reference_model: Model) -> bool: """Check if the logp graph of a model with xtensor variables is equivalent.""" + # Replace xtensor value variables by tensor value variables + replacements = { + var: as_xtensor(var.values.clone(name=var.name), dims=var.dims) for var in model.value_vars + } + model_logp = graph_replace(model.logp(), replacements) lowered_model_logp = rewrite_graph( - [model.logp()], + [model_logp], include=("lower_xtensor", "canonicalize", "local_remove_all_assert"), ) reference_lowered_model_logp = rewrite_graph( From 38697234ed5a1392b8a4cc0d88f85cb9187d5563 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 30 Jun 2025 12:51:01 +0200 Subject: [PATCH 05/10] Allow Dim version of simple SymbolicRandomVariables --- pymc/dims/distributions/scalar.py | 19 ++++++++++++++++++- pymc/distributions/distribution.py | 21 +++++++++++++++++++++ tests/dims/distributions/test_scalar.py | 17 +++++++++++++++++ 3 files changed, 56 insertions(+), 1 deletion(-) diff --git a/pymc/dims/distributions/scalar.py b/pymc/dims/distributions/scalar.py index cb5fd5e1f4..6cdb3ccc73 100644 --- 
a/pymc/dims/distributions/scalar.py +++ b/pymc/dims/distributions/scalar.py @@ -14,6 +14,8 @@ import pytensor.xtensor as ptx import pytensor.xtensor.random as pxr +from pytensor.xtensor import as_xtensor + from pymc.dims.distributions.core import ( DimDistribution, PositiveDimDistribution, @@ -21,7 +23,7 @@ ) from pymc.distributions.continuous import Beta as RegularBeta from pymc.distributions.continuous import Gamma as RegularGamma -from pymc.distributions.continuous import flat, halfflat +from pymc.distributions.continuous import HalfStudentTRV, flat, halfflat def _get_sigma_from_either_sigma_or_tau(*, sigma, tau): @@ -89,6 +91,21 @@ def dist(cls, nu, mu=0, sigma=None, *, lam=None, **kwargs): return super().dist([nu, mu, sigma], **kwargs) +class HalfStudentT(PositiveDimDistribution): + @classmethod + def dist(cls, nu, sigma=None, *, lam=None, **kwargs): + sigma = _get_sigma_from_either_sigma_or_tau(sigma=sigma, tau=lam) + return super().dist([nu, sigma], **kwargs) + + @classmethod + def xrv_op(self, nu, sigma, core_dims=None, extra_dims=None, rng=None): + nu = as_xtensor(nu) + sigma = as_xtensor(sigma) + core_rv = HalfStudentTRV.rv_op(nu=nu.values, sigma=sigma.values).owner.op + xop = pxr.as_xrv(core_rv) + return xop(nu, sigma, core_dims=core_dims, extra_dims=extra_dims, rng=rng) + + class Cauchy(DimDistribution): xrv_op = pxr.cauchy diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 2a46929522..27d53c8687 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -370,8 +370,29 @@ def __init__( kwargs.setdefault("inline", True) kwargs.setdefault("strict", True) + # Many RVS have a size argument, even when this is `None` and is therefore unused + kwargs.setdefault("on_unused_input", "ignore") + if hasattr(self, "name"): + kwargs.setdefault("name", self.name) super().__init__(*args, **kwargs) + def make_node(self, *inputs): + # If we try to build the RV with a different size type (vector -> None or None -> vector) + # We need to rebuild the Op with new size type in the inner graph + if self.extended_signature is not None: + (rng_arg_idxs, size_arg_idx, param_idxs), _ = self.get_input_output_type_idxs( + self.extended_signature + ) + if size_arg_idx is not None and len(rng_arg_idxs) == 1: + new_size_type = normalize_size_param(inputs[size_arg_idx]).type + if not self.input_types[size_arg_idx].in_same_class(new_size_type): + params = [inputs[idx] for idx in param_idxs] + size = inputs[size_arg_idx] + rng = inputs[rng_arg_idxs[0]] + return self.rebuild_rv(*params, size=size, rng=rng).owner + + return super().make_node(*inputs) + def update(self, node: Apply) -> dict[Variable, Variable]: """Symbolic update expression for input random state variables. 
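# Illustrative sketch (commentary, not part of the patch): the `make_node`
# override added above matters when an existing Op instance is handed a `size`
# input of a different type (no size vs. a size vector). One public path that is
# assumed to exercise this rebuild, via the real `change_dist_size` helper:
import pymc as pm
from pymc.distributions.shape_utils import change_dist_size

zsn = pm.ZeroSumNormal.dist(shape=(3,))  # built without an explicit batch size
resized = change_dist_size(zsn, new_size=(10,))  # new size type -> Op is rebuilt
print(resized.type.shape)  # expected: (10, 3), under the assumptions above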
diff --git a/tests/dims/distributions/test_scalar.py b/tests/dims/distributions/test_scalar.py index a487591c01..df47422dea 100644 --- a/tests/dims/distributions/test_scalar.py +++ b/tests/dims/distributions/test_scalar.py @@ -22,6 +22,7 @@ HalfCauchy, HalfFlat, HalfNormal, + HalfStudentT, InverseGamma, Laplace, LogNormal, @@ -119,6 +120,22 @@ def test_studentt(): assert_equivalent_logp_graph(model, reference_model) +def test_halfstudentt(): + coords = {"a": range(3)} + with Model(coords=coords) as model: + HalfStudentT("x", nu=1, dims="a") + HalfStudentT("y", nu=1, sigma=3, dims="a") + HalfStudentT("z", nu=1, lam=3, dims="a") + + with Model(coords=coords) as reference_model: + regular_distributions.HalfStudentT("x", nu=1, dims="a") + regular_distributions.HalfStudentT("y", nu=1, sigma=3, dims="a") + regular_distributions.HalfStudentT("z", nu=1, lam=3, dims="a") + + assert_equivalent_random_graph(model, reference_model) + assert_equivalent_logp_graph(model, reference_model) + + def test_cauchy(): coords = {"a": range(3)} with Model(coords=coords) as model: From 954c33c4a505396d6130e7282e4a83f277950863 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Tue, 24 Jun 2025 19:10:18 +0200 Subject: [PATCH 06/10] Implement Dim ZeroSumNormal --- pymc/dims/distributions/transforms.py | 42 +++++++++++++ pymc/dims/distributions/vector.py | 82 +++++++++++++++++++++++++ pymc/distributions/multivariate.py | 11 ++-- tests/dims/distributions/test_vector.py | 20 +++++- tests/dims/test_model.py | 53 ++++++++++++++++ 5 files changed, 202 insertions(+), 6 deletions(-) diff --git a/pymc/dims/distributions/transforms.py b/pymc/dims/distributions/transforms.py index 6805d1b5c1..8f49d2a16f 100644 --- a/pymc/dims/distributions/transforms.py +++ b/pymc/dims/distributions/transforms.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import pytensor.tensor as pt import pytensor.xtensor as ptx from pymc.logprob.transforms import Transform @@ -51,3 +52,44 @@ def log_jac_det(self, value, *inputs): log_odds_transform = LogOddsTransform() + + +class ZeroSumTransform(DimTransform): + name = "zerosum" + + def __init__(self, dims: tuple[str, ...]): + self.dims = dims + + @staticmethod + def extend_dim(array, dim): + n = (array.sizes[dim] + 1).astype("floatX") + sum_vals = array.sum(dim) + norm = sum_vals / (pt.sqrt(n) + n) + fill_val = norm - sum_vals / pt.sqrt(n) + + out = ptx.concat([array, fill_val], dim=dim) + return out - norm + + @staticmethod + def reduce_dim(array, dim): + n = array.sizes[dim].astype("floatX") + last = array.isel({dim: -1}) + + sum_vals = -last * pt.sqrt(n) + norm = sum_vals / (pt.sqrt(n) + n) + return array.isel({dim: slice(None, -1)}) + norm + + def forward(self, value, *rv_inputs): + for dim in self.dims: + value = self.reduce_dim(value, dim=dim) + return value + + def backward(self, value, *rv_inputs): + for dim in self.dims: + value = self.extend_dim(value, dim=dim) + return value + + def log_jac_det(self, value, *rv_inputs): + # Use following once broadcast_like is implemented + # as_xtensor(0).broadcast_like(value, exclude=self.dims)` + return value.sum(self.dims) * 0 diff --git a/pymc/dims/distributions/vector.py b/pymc/dims/distributions/vector.py index c67b69fba9..0ad834c8a8 100644 --- a/pymc/dims/distributions/vector.py +++ b/pymc/dims/distributions/vector.py @@ -14,9 +14,14 @@ import pytensor.xtensor as ptx import pytensor.xtensor.random as ptxr +from pytensor.tensor import as_tensor +from pytensor.xtensor import as_xtensor from pytensor.xtensor import random as pxr from pymc.dims.distributions.core import VectorDimDistribution +from pymc.dims.distributions.transforms import ZeroSumTransform +from pymc.distributions.multivariate import ZeroSumNormalRV +from pymc.util import UNSET class Categorical(VectorDimDistribution): @@ -114,3 +119,80 @@ def dist(cls, mu, cov=None, *, chol=None, lower=True, core_dims=None, **kwargs): cov = chol.dot(chol.rename({d0: safe_name}), dim=d1).rename({safe_name: d1}) return super().dist([mu, cov], core_dims=core_dims, **kwargs) + + +class ZeroSumNormal(VectorDimDistribution): + """Zero-sum multivariate normal distribution. + + Parameters + ---------- + sigma : xtensor_like, optional + The standard deviation of the underlying unconstrained normal distribution. + Defaults to 1.0. It cannot have core dimensions. + core_dims : Sequence of str, optional + The axes along which the zero-sum constraint is applied. + **kwargs + Additional keyword arguments used to define the distribution. + + Returns + ------- + XTensorVariable + An xtensor variable representing the zero-sum multivariate normal distribution. 
+ """ + + @classmethod + def __new__( + cls, *args, core_dims=None, dims=None, default_transform=UNSET, observed=None, **kwargs + ): + if core_dims is not None: + if isinstance(core_dims, str): + core_dims = (core_dims,) + + # Create default_transform + if observed is None and default_transform is UNSET: + default_transform = ZeroSumTransform(dims=core_dims) + + # If the user didn't specify dims, take it from core_dims + # We need them to be forwarded to dist in the `dim_lenghts` argument + if dims is None and core_dims is not None: + dims = (..., *core_dims) + + return super().__new__( + *args, + core_dims=core_dims, + dims=dims, + default_transform=default_transform, + observed=observed, + **kwargs, + ) + + @classmethod + def dist(cls, sigma=1.0, *, core_dims=None, dim_lengths, **kwargs): + if isinstance(core_dims, str): + core_dims = (core_dims,) + if core_dims is None or len(core_dims) == 0: + raise ValueError("ZeroSumNormal requires atleast 1 core_dims") + + support_dims = as_xtensor( + as_tensor([dim_lengths[core_dim] for core_dim in core_dims]), dims=("_",) + ) + sigma = cls._as_xtensor(sigma) + + return super().dist( + [sigma, support_dims], core_dims=core_dims, dim_lengths=dim_lengths, **kwargs + ) + + @classmethod + def xrv_op(self, sigma, support_dims, core_dims, extra_dims=None, rng=None): + sigma = as_xtensor(sigma) + support_dims = as_xtensor(support_dims, dims=("_",)) + support_shape = support_dims.values + core_rv = ZeroSumNormalRV.rv_op(sigma=sigma.values, support_shape=support_shape).owner.op + xop = pxr.as_xrv( + core_rv, + core_inps_dims_map=[(), (0,)], + core_out_dims_map=tuple(range(1, len(core_dims) + 1)), + ) + # Dummy "_" core dim to absorb the support_shape vector + # If ZeroSumNormal expected a scalar per support dim, this wouldn't be needed + return xop(sigma, support_dims, core_dims=("_", *core_dims), extra_dims=extra_dims, rng=rng) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index f5a5506bba..5dd2509eff 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -2664,6 +2664,7 @@ def logp(value, alpha, K): class ZeroSumNormalRV(SymbolicRandomVariable): """ZeroSumNormal random variable.""" + name = "ZeroSumNormal" _print_name = ("ZeroSumNormal", "\\operatorname{ZeroSumNormal}") @classmethod @@ -2687,12 +2688,12 @@ def rv_op(cls, sigma, support_shape, *, size=None, rng=None): zerosum_rv -= zerosum_rv.mean(axis=-axis - 1, keepdims=True) support_str = ",".join([f"d{i}" for i in range(n_zerosum_axes)]) - extended_signature = f"[rng],(),(s),[size]->[rng],({support_str})" - return ZeroSumNormalRV( - inputs=[rng, sigma, support_shape, size], + extended_signature = f"[rng],[size],(),(s)->[rng],({support_str})" + return cls( + inputs=[rng, size, sigma, support_shape], outputs=[next_rng, zerosum_rv], extended_signature=extended_signature, - )(rng, sigma, support_shape, size) + )(rng, size, sigma, support_shape) class ZeroSumNormal(Distribution): @@ -2828,7 +2829,7 @@ def zerosum_default_transform(op, rv): @_logprob.register(ZeroSumNormalRV) -def zerosumnormal_logp(op, values, rng, sigma, support_shape, size, **kwargs): +def zerosumnormal_logp(op, values, rng, size, sigma, support_shape, **kwargs): (value,) = values shape = value.shape n_zerosum_axes = op.ndim_supp diff --git a/tests/dims/distributions/test_vector.py b/tests/dims/distributions/test_vector.py index 0f08505dba..3a57453b48 100644 --- a/tests/dims/distributions/test_vector.py +++ b/tests/dims/distributions/test_vector.py @@ -19,7 +19,7 
@@ import pymc.distributions as regular_distributions from pymc import Model -from pymc.dims import Categorical, MvNormal +from pymc.dims import Categorical, MvNormal, ZeroSumNormal from tests.dims.utils import assert_equivalent_logp_graph, assert_equivalent_random_graph @@ -60,3 +60,21 @@ def test_mvnormal(): assert_equivalent_random_graph(model, reference_model) assert_equivalent_logp_graph(model, reference_model) + + +def test_zerosumnormal(): + coords = {"a": range(3), "b": range(2)} + with Model(coords=coords) as model: + ZeroSumNormal("x", core_dims=("b",), dims=("a", "b")) + ZeroSumNormal("y", sigma=3, core_dims=("b",), dims=("a", "b")) + ZeroSumNormal("z", core_dims=("a", "b"), dims=("a", "b")) + + with Model(coords=coords) as reference_model: + regular_distributions.ZeroSumNormal("x", dims=("a", "b")) + regular_distributions.ZeroSumNormal("y", sigma=3, n_zerosum_axes=1, dims=("a", "b")) + regular_distributions.ZeroSumNormal("z", n_zerosum_axes=2, dims=("a", "b")) + + assert_equivalent_random_graph(model, reference_model) + # Logp is correct, but we have join(..., -1) and join(..., 1), that don't get canonicalized to the same + # Should work once https://github.com/pymc-devs/pytensor/issues/1505 is fixed + # assert_equivalent_logp_graph(model, reference_model) diff --git a/tests/dims/test_model.py b/tests/dims/test_model.py index c67f629b96..68446c1a62 100644 --- a/tests/dims/test_model.py +++ b/tests/dims/test_model.py @@ -172,3 +172,56 @@ def test_complex_model(): tune=200, chains=2, draws=50, compute_convergence_checks=False, progressbar=False ) pm.sample_posterior_predictive(idata, progressbar=False) + + +def test_zerosumnormal_model(): + coords = {"time": range(5), "item": range(3)} + + with pm.Model(coords=coords) as model: + zsn_item = pmd.ZeroSumNormal("zsn_item", core_dims="item", dims=("time", "item")) + zsn_time = pmd.ZeroSumNormal("zsn_time", core_dims="time", dims=("time", "item")) + zsn_item_time = pmd.ZeroSumNormal("zsn_item_time", core_dims=("item", "time")) + assert zsn_item.type.dims == ("time", "item") + assert zsn_time.type.dims == ("time", "item") + assert zsn_item_time.type.dims == ("item", "time") + + zsn_item_draw, zsn_time_draw, zsn_item_time_draw = pm.draw( + [zsn_item, zsn_time, zsn_item_time], random_seed=1 + ) + assert zsn_item_draw.shape == (5, 3) + np.testing.assert_allclose(zsn_item_draw.mean(-1), 0, atol=1e-13) + assert not np.allclose(zsn_item_draw.mean(0), 0, atol=1e-13) + + assert zsn_time_draw.shape == (5, 3) + np.testing.assert_allclose(zsn_time_draw.mean(0), 0, atol=1e-13) + assert not np.allclose(zsn_time_draw.mean(-1), 0, atol=1e-13) + + assert zsn_item_time_draw.shape == (3, 5) + np.testing.assert_allclose(zsn_item_time_draw.mean(), 0, atol=1e-13) + + with pm.Model(coords=coords) as ref_model: + # Check that the ZeroSumNormal can be used in a model + pm.ZeroSumNormal("zsn_item", dims=("time", "item")) + pm.ZeroSumNormal("zsn_time", dims=("item", "time")) + pm.ZeroSumNormal("zsn_item_time", n_zerosum_axes=2, dims=("item", "time")) + + # Check initial_point and logp + ip = model.initial_point() + ref_ip = ref_model.initial_point() + assert ip.keys() == ref_ip.keys() + for i, (ip_value, ref_ip_value) in enumerate(zip(ip.values(), ref_ip.values())): + if i == 1: + # zsn_time is actually transposed in the original model + ip_value = ip_value.T + np.testing.assert_allclose(ip_value, ref_ip_value) + + logp_fn = model.compile_logp() + ref_logp_fn = ref_model.compile_logp() + np.testing.assert_allclose(logp_fn(ip), ref_logp_fn(ref_ip)) + + # Test a 
new point
+        rng = np.random.default_rng(68)
+        new_ip = ip.copy()
+        for key in new_ip:
+            new_ip[key] += rng.uniform(size=new_ip[key].shape)
+        np.testing.assert_allclose(logp_fn(new_ip), ref_logp_fn(new_ip))

From 48175075a43ff6e9f0df9c162e194513712b0095 Mon Sep 17 00:00:00 2001
From: ricardoV94
Date: Sun, 29 Jun 2025 23:54:06 +0200
Subject: [PATCH 07/10] Arviz: don't fail hard on incompatible coordinate
 lengths

---
 pymc/backends/arviz.py       | 33 ++++++++++++++++++++++++++++++++-
 tests/backends/test_arviz.py | 25 +++++++++++++++++++++++++
 2 files changed, 57 insertions(+), 1 deletion(-)

diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py
index f0f0eec963..b900b50019 100644
--- a/pymc/backends/arviz.py
+++ b/pymc/backends/arviz.py
@@ -49,10 +49,41 @@
 
 _log = logging.getLogger(__name__)
 
+
+RAISE_ON_INCOMPATIBLE_COORD_LENGTHS = False
+
+
 # random variable object ...
 Var = Any
 
 
+def dict_to_dataset_drop_incompatible_coords(vars_dict, *args, dims, coords, **kwargs):
+    safe_coords = coords
+
+    if not RAISE_ON_INCOMPATIBLE_COORD_LENGTHS:
+        coords_lengths = {k: len(v) for k, v in coords.items()}
+        for var_name, var in vars_dict.items():
+            # Iterate in reversed because of chain/draw batch dimensions
+            for dim, dim_length in zip(reversed(dims.get(var_name, ())), reversed(var.shape)):
+                coord_length = coords_lengths.get(dim, None)
+                if (coord_length is not None) and (coord_length != dim_length):
+                    warnings.warn(
+                        f"Incompatible coordinate length of {coord_length} for dimension '{dim}' of variable '{var_name}'.\n"
+                        "This usually happens when a sliced or concatenated variable is wrapped as a `pymc.dims.Deterministic`. "
+                        "The original coordinates for this dim will not be included in the returned dataset for any of the variables. "
+                        "Instead they will default to `np.arange(var_length)` and the shorter variables will be right-padded with nan.\n"
+                        "To make this warning into an error, set `pymc.backends.arviz.RAISE_ON_INCOMPATIBLE_COORD_LENGTHS` to `True`",
+                        UserWarning,
+                    )
+                    if safe_coords is coords:
+                        safe_coords = coords.copy()
+                    safe_coords.pop(dim)
+                    coords_lengths.pop(dim)
+
+    # FIXME: Would be better to drop coordinates altogether, but arviz defaults to `np.arange(var_length)`
+    return dict_to_dataset(vars_dict, *args, dims=dims, coords=safe_coords, **kwargs)
+
+
 def find_observations(model: "Model") -> dict[str, Var]:
     """If there are observations available, return them as a dictionary."""
     observations = {}
@@ -365,7 +396,7 @@ def priors_to_xarray(self):
             priors_dict[group] = (
                 None
                 if var_names is None
-                else dict_to_dataset(
+                else dict_to_dataset_drop_incompatible_coords(
                     {k: np.expand_dims(self.prior[k], 0) for k in var_names},
                     library=pymc,
                     coords=self.coords,
diff --git a/tests/backends/test_arviz.py b/tests/backends/test_arviz.py
index 3c06288b35..1b06d28ae3 100644
--- a/tests/backends/test_arviz.py
+++ b/tests/backends/test_arviz.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import re import warnings import numpy as np @@ -837,3 +838,27 @@ def test_dataset_to_point_list_str_key(self): ds[3] = xarray.DataArray([1, 2, 3]) with pytest.raises(ValueError, match="must be str"): dataset_to_point_list(ds, sample_dims=["chain", "draw"]) + + +def test_incompatible_coordinate_lengths(): + with pm.Model(coords={"a": [-1, -2, -3]}) as m: + x = pm.Normal("x", dims="a") + y = pm.Deterministic("y", x[1:], dims=("a",)) + + with pytest.warns( + UserWarning, + match=re.escape( + "Incompatible coordinate length of 3 for dimension 'a' of variable 'y'" + ), + ): + prior = pm.sample_prior_predictive(draws=1).prior.squeeze(("chain", "draw")) + assert prior.x.dims == prior.y.dims == ("a",) + assert prior.x.shape == prior.y.shape == (3,) + assert np.isnan(prior.y.values[-1]) + assert list(prior.coords["a"]) == [0, 1, 2] + + pm.backends.arviz.RAISE_ON_INCOMPATIBLE_COORD_LENGTHS = True + with pytest.raises(ValueError): + pm.sample_prior_predictive(draws=1) + + pm.backends.arviz.RAISE_ON_INCOMPATIBLE_COORD_LENGTHS = False From a1496444a76518b8087a6b407a8b877281694274 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Sun, 29 Jun 2025 23:55:10 +0200 Subject: [PATCH 08/10] Tweaks to model_graph to play nice with XTensorVariables * Use RV Op name when provided * More robust detection of observed data variables (after https://github.com/pymc-devs/pymc/pull/7656 arbitrary graphs are allowed) * Remove self loops explicitly (closes https://github.com/pymc-devs/pymc/issues/7722) --- pymc/model_graph.py | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/pymc/model_graph.py b/pymc/model_graph.py index 50fd5227d6..df228d638a 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -21,11 +21,11 @@ from typing import Any, cast from pytensor import function -from pytensor.graph.basic import ancestors, walk +from pytensor.graph.basic import Variable, ancestors, walk from pytensor.tensor.shape import Shape -from pytensor.tensor.variable import TensorVariable from pymc.model.core import modelcontext +from pymc.pytensorf import _cheap_eval_mode from pymc.util import VarName, get_default_varnames, get_var_name __all__ = ( @@ -73,7 +73,7 @@ def create_plate_label_with_dim_length( def fast_eval(var): - return function([], var, mode="FAST_COMPILE")() + return function([], var, mode=_cheap_eval_mode)() class NodeType(str, Enum): @@ -88,7 +88,7 @@ class NodeType(str, Enum): @dataclass class NodeInfo: - var: TensorVariable + var: Variable node_type: NodeType def __hash__(self): @@ -108,10 +108,10 @@ def __eq__(self, other) -> bool: GraphvizNodeKwargs = dict[str, Any] -NodeFormatter = Callable[[TensorVariable], GraphvizNodeKwargs] +NodeFormatter = Callable[[Variable], GraphvizNodeKwargs] -def default_potential(var: TensorVariable) -> GraphvizNodeKwargs: +def default_potential(var: Variable) -> GraphvizNodeKwargs: """Return default data for potential in the graph.""" return { "shape": "octagon", @@ -120,17 +120,19 @@ def default_potential(var: TensorVariable) -> GraphvizNodeKwargs: } -def random_variable_symbol(var: TensorVariable) -> str: +def random_variable_symbol(var: Variable) -> str: """Get the symbol of the random variable.""" - symbol = var.owner.op.__class__.__name__ + op = var.owner.op - if symbol.endswith("RV"): - symbol = symbol[:-2] + if name := getattr(op, "name", None): + symbol = name[0].upper() + name[1:] + else: + symbol = op.__class__.__name__.removesuffix("RV") return symbol -def default_free_rv(var: TensorVariable) -> GraphvizNodeKwargs: 
+def default_free_rv(var: Variable) -> GraphvizNodeKwargs: """Return default data for free RV in the graph.""" symbol = random_variable_symbol(var) @@ -141,7 +143,7 @@ def default_free_rv(var: TensorVariable) -> GraphvizNodeKwargs: } -def default_observed_rv(var: TensorVariable) -> GraphvizNodeKwargs: +def default_observed_rv(var: Variable) -> GraphvizNodeKwargs: """Return default data for observed RV in the graph.""" symbol = random_variable_symbol(var) @@ -152,7 +154,7 @@ def default_observed_rv(var: TensorVariable) -> GraphvizNodeKwargs: } -def default_deterministic(var: TensorVariable) -> GraphvizNodeKwargs: +def default_deterministic(var: Variable) -> GraphvizNodeKwargs: """Return default data for the deterministic in the graph.""" return { "shape": "box", @@ -161,7 +163,7 @@ def default_deterministic(var: TensorVariable) -> GraphvizNodeKwargs: } -def default_data(var: TensorVariable) -> GraphvizNodeKwargs: +def default_data(var: Variable) -> GraphvizNodeKwargs: """Return default data for the data in the graph.""" return { "shape": "box", @@ -239,7 +241,7 @@ def __init__(self, model): self._all_vars = {model[var_name] for var_name in self._all_var_names} self.var_list = self.model.named_vars.values() - def get_parent_names(self, var: TensorVariable) -> set[VarName]: + def get_parent_names(self, var: Variable) -> set[VarName]: if var.owner is None: return set() @@ -345,7 +347,7 @@ def get_plates( dim_name: fast_eval(value).item() for dim_name, value in self.model.dim_lengths.items() } var_shapes: dict[str, tuple[int, ...]] = { - var_name: tuple(fast_eval(self.model[var_name].shape)) + var_name: tuple(map(int, fast_eval(self.model[var_name].shape))) for var_name in self.vars_to_plot(var_names) } From a4032a13fec2041561688af71cb1910381e99b82 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Sun, 29 Jun 2025 23:54:40 +0200 Subject: [PATCH 09/10] Core notebook on dims module Co-authored-by: Oriol Abril-Pla Co-authored-by: Allen Downey --- .../learn/core_notebooks/dims_module.ipynb | 3322 +++++++++++++++++ 1 file changed, 3322 insertions(+) create mode 100644 docs/source/learn/core_notebooks/dims_module.ipynb diff --git a/docs/source/learn/core_notebooks/dims_module.ipynb b/docs/source/learn/core_notebooks/dims_module.ipynb new file mode 100644 index 0000000000..b3fab1e7a8 --- /dev/null +++ b/docs/source/learn/core_notebooks/dims_module.ipynb @@ -0,0 +1,3322 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "17e37649edaa8d0d", + "metadata": {}, + "source": [ + "# PyMC dims module" + ] + }, + { + "cell_type": "markdown", + "id": "79268b908558209c", + "metadata": {}, + "source": [ + "## A short history of dims in PyMC\n", + "\n", + "PyMC introduced the ability to specify model variable `dims` in version 3.9 in June 2020 (5 years as of the time of writing). In the release notes, it was mentioned only after [14 other new features](https://github.com/pymc-devs/pymc/blob/1d00f3eb81723523968f3610e81a0c42fd96326f/RELEASE-NOTES.md?plain=1#L236), but over time it became a foundation of the library.\n", + "\n", + "It allows users to more naturally specify the dimensions of model variables with string names, and provides a \"seamless\" conversion to arviz {doc}`InferenceData ` objects, which have become the standard for storing and investigating results from probabilistic programming languages.\n", + "\n", + "However, the behavior of dims is rather limited. It can only be used to specify the shape of new random variables and label existing dimensions (e.g., in {func}`~pymc.Deterministic`). 
Otherwise it has no effect on the computation, unlike operations done with {class}`~arviz.InferenceData` variables, which are based on {lib}`xarray` and where dims inform array selection, alignment, and broadcasting behavior.\n",
+    "\n",
+    "As a result, in PyMC models users have to write computations that follow NumPy semantics, which often requires transpositions, reshapes, new axes (`None`) and numerical axis arguments sprinkled everywhere. These are easy to get wrong, and in the end it is often hard to make sense of the written model.\n",
+    "\n",
+    "### Expanding the role of dims\n",
+    "\n",
+    "Now we are introducing an experimental {mod}`pymc.dims` module that allows users to define data, distributions, and math operations that respect dim semantics, following {mod}`xarray` operations **without coordinates** as closely as possible.\n",
+    "\n",
+    ":::{warning}The `dims` module is experimental and not exhaustively tested; its API is still being iterated on and is therefore subject to change between any two PyMC releases. We welcome users to test it and provide feedback, but we don't yet endorse its use for production.:::\n",
+    "\n",
+    "\n",
+    "## A simple example\n",
+    "\n",
+    "We'll start with a model written in current PyMC style, using synthetic data."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "20399d0df364361e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "\n",
+    "import pymc as pm\n",
+    "\n",
+    "seed = sum(map(ord, \"dims module\"))\n",
+    "rng = np.random.default_rng(seed)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "313be4498ce33680",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Very realistic looking data!\n",
+    "observed_response_np = np.ones((5, 20), dtype=int)\n",
+    "coords = {\n",
+    "    \"participant\": range(5),\n",
+    "    \"trial\": range(20),\n",
+    "    \"item\": range(3),\n",
+    "}"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "12822088a0040760",
+   "metadata": {},
+   "source": [
+    "This model predicts participants' categorical responses across trials by combining participant-specific item preferences (constrained to sum to zero) with shared time-varying effects, then applying a softmax to model response probabilities.\n",
+    "\n",
+    "Notice the need to identify axes by number rather than dimension name, and the need to use `None` to create new axes in order to specify broadcast requirements."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "4c97fc1f0b8eeec",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "with pm.Model(coords=coords) as model:\n",
+    "    observed_response = pm.Data(\n",
+    "        \"observed_response\", observed_response_np, dims=(\"participant\", \"trial\")\n",
+    "    )\n",
+    "    # Use ZeroSumNormal to avoid identifiability issues\n",
+    "    participant_preference = pm.ZeroSumNormal(\n",
+    "        \"participant_preference\", n_zerosum_axes=1, dims=(\"participant\", \"item\")\n",
+    "    )\n",
+    "\n",
+    "    # Shared time effects across all participants\n",
+    "    time_effects = pm.Normal(\"time_effects\", dims=(\"trial\", \"item\"))\n",
+    "\n",
+    "    trial_preference = pm.Deterministic(\n",
+    "        \"trial_preference\",\n",
+    "        participant_preference[:, None, :] + time_effects[None, :, :],\n",
+    "        dims=(\"participant\", \"trial\", \"item\"),\n",
+    "    )\n",
+    "\n",
+    "    response = pm.Categorical(\n",
+    "        \"response\",\n",
+    "        p=pm.math.softmax(trial_preference, axis=-1),\n",
+    "        observed=observed_response,\n",
+    "        dims=(\"participant\", \"trial\"),\n",
+    "    )"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "2efa25b6d8713c0",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-30T15:53:29.819511Z",
+     "start_time": "2025-06-30T15:53:25.547610Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "image/svg+xml": [
+       "(SVG model graph, markup elided: plates 'participant (5) x trial (20)', 'participant (5) x item (3)', 'trial (20) x item (3)', 'participant (5) x trial (20) x item (3)'; nodes observed_response ~ Data, response ~ Categorical, participant_preference ~ ZeroSumNormal, time_effects ~ Normal, trial_preference ~ Deterministic; edges response->observed_response, participant_preference->trial_preference, time_effects->trial_preference, trial_preference->response)"
+      ],
+      "text/plain": [
+       ""
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "model.to_graphviz()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "41bccc7c213e1be3",
+   "metadata": {},
+   "source": [
+    "And here's the equivalent model using the {mod}`pymc.dims` module."
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "3ff29a7a93236486", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/ricardo/Documents/pymc/pymc/dims/__init__.py:66: UserWarning: The `pymc.dims` module is experimental and may contain critical bugs (p=0.676).\n", + "Please report any issues you encounter at https://github.com/pymc-devs/pymc/issues.\n", + "Disclaimer: This an experimental API and may change at any time.\n", + " __init__()\n" + ] + } + ], + "source": [ + "import pymc.dims as pmd" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "c020e450cc165e46", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:29.997156Z", + "start_time": "2025-06-30T15:53:29.935025Z" + } + }, + "outputs": [], + "source": [ + "with pm.Model(coords=coords) as dmodel:\n", + " observed_response = pmd.Data(\n", + " \"observed_response\", observed_response_np, dims=(\"participant\", \"trial\")\n", + " )\n", + " participant_preference = pmd.ZeroSumNormal(\n", + " \"participant_preference\", core_dims=\"item\", dims=(\"participant\", \"item\")\n", + " )\n", + "\n", + " # Shared time effects across all participants\n", + " time_effects = pmd.Normal(\"time_effects\", dims=(\"item\", \"trial\"))\n", + "\n", + " trial_preference = pmd.Deterministic(\n", + " \"trial_preference\",\n", + " participant_preference + time_effects,\n", + " )\n", + "\n", + " response = pmd.Categorical(\n", + " \"response\",\n", + " p=pmd.math.softmax(trial_preference, dim=\"item\"),\n", + " core_dims=\"item\",\n", + " observed=observed_response,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "3985ee1b52708c0c", + "metadata": {}, + "source": [ + "Note we still use the same {class}`~pymc.Model` constructor, but everything else was now defined with an equivalent function or class defined in the {mod}`pymc.dims` module.\n", + "\n", + "There are some notable differences:\n", + "\n", + "1. `ZeroSumNormal` takes a `core_dims` argument instead of `n_zerosum_axes`. This tells PyMC which of the `dims` that define the distribution are constrained to be zero-summed. All distributions that take non-scalar parameters now require a `core_dims` argument. Previously, they were assumed to be right-aligned by the user (see more in {doc}`dimensionality`). Now you don't have to worry about the order of the dimensions in your model, just their meaning!\n", + "\n", + "2. The `trial_preference` computation aligns dimensions for broadcasting automatically. Note we use {func}`pymc.dims.Deterministic` and not {func}`pymc.Deterministic`, which automatically propagates the `dims` to the model object.\n", + "\n", + "3. The `softmax` operation specifies the `dim` argument, not the positional axis. Note: The parameter is called `dim` and not `core_dims` because we try to stay as close as possible to the Xarray API, which uses `dim` throughout. But we make an exception for distributions because they already have the `dims` argument.\n", + "\n", + "4. The `Categorical` observed variable, like `ZeroSumNormal`, requires a `core_dims` argument to specify which dimension corresponds to the probability vector. Previously, it was necessary to place this dimension explicitly on the rightmost axis -- not any more!\n", + "\n", + "5. Even though dims were not specified for either `trial_preference` or `response`, PyMC automatically infers them.\n", + "\n", + "The graphviz representation looks the same as before." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "994058c6fe2920b3",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/svg+xml": [
+       "(SVG model graph, markup elided: plates 'participant (5) x trial (20)', 'participant (5) x item (3)', 'item (3) x trial (20)', 'participant (5) x item (3) x trial (20)'; nodes observed_response ~ Data, response ~ Categorical, participant_preference ~ ZeroSumNormal, time_effects ~ Normal, trial_preference ~ Deterministic; edges response->observed_response, participant_preference->trial_preference, time_effects->trial_preference, trial_preference->response)"
+      ],
+      "text/plain": [
+       ""
+      ]
+     },
+     "execution_count": 7,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "dmodel.to_graphviz()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "24f2cafa1500f08a",
+   "metadata": {},
+   "source": [
+    "We can also check that the models are equivalent by comparing the `logp` of each variable evaluated at the initial_point."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "11073d9b67a72f72",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-30T15:53:34.534342Z",
+     "start_time": "2025-06-30T15:53:30.457494Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'participant_preference': np.float64(-9.19), 'time_effects': np.float64(-55.14), 'response': np.float64(-109.86)}\n",
+      "{'participant_preference': np.float64(-9.19), 'time_effects': np.float64(-55.14), 'response': np.float64(-109.86)}\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(model.point_logps())\n",
+    "print(dmodel.point_logps())"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "ee981114897c7ee0",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-26T21:31:08.349480Z",
+     "start_time": "2025-06-26T21:31:07.387385Z"
+    }
+   },
+   "source": [
+    "## A brief look under the hood"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e360b5f5e9e8ca1e",
+   "metadata": {},
+   "source": [
+    "The {mod}`pymc.dims` module functionality is built on top of the experimental {mod}`pytensor.xtensor` module in PyTensor, which is the {lib}`xarray` analogue of the {mod}`pytensor.tensor` module you may be familiar with (see {doc}`pymc_and_pytensor`).\n",
+    "\n",
+    "Whereas regular distributions and math operations return {class}`~pytensor.tensor.TensorVariable` objects, the corresponding functions in the {mod}`pymc.dims` module return {class}`~pytensor.xtensor.type.XTensorVariable` objects. These are very similar to {class}`~pytensor.tensor.TensorVariable`, but they have a `dims` attribute that determines their behavior."
+ ] + }, + { + "cell_type": "markdown", + "id": "5a1fa8ecb7309782", + "metadata": {}, + "source": [ + "As an example, we'll create a regular {class}`~pymc.Normal` random variable with 3 elements, and perform an outer addition on them using NumPy syntax." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "9bece958e432c369", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:34.611874Z", + "start_time": "2025-06-30T15:53:34.598883Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "TensorType(float64, shape=(3,))" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "regular_normal = pm.Normal.dist(mu=pm.math.as_tensor([0, 1, 2]), sigma=1, shape=(3,))\n", + "regular_normal.type" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "5c0aa77e31170c54", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:34.700430Z", + "start_time": "2025-06-30T15:53:34.693544Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "pytensor.tensor.variable.TensorVariable" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "type(regular_normal)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "d77168a8c6e89de9", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:34.805130Z", + "start_time": "2025-06-30T15:53:34.792870Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "TensorType(float64, shape=(3, 3))" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "outer_addition = regular_normal[:, None] + regular_normal[None, :]\n", + "outer_addition.type" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "ac95fe88c1877fbe", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:34.930130Z", + "start_time": "2025-06-30T15:53:34.887799Z" + }, + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[0.61284312, 1.68384684, 1.72225487],\n", + " [1.68384684, 2.75485056, 2.79325859],\n", + " [1.72225487, 2.79325859, 2.83166662]])" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pm.draw(outer_addition, random_seed=rng)" + ] + }, + { + "cell_type": "markdown", + "id": "b7d78e2f3984adbc", + "metadata": {}, + "source": [ + "Here's the same operation with a dimmed `Normal` variable. It requires the use of `rename`." 
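Before showing it, a quick sketch of why the rename is needed (dim names are our own): when both operands share the same dim, addition aligns them element-wise instead of producing an outer product.

import numpy as np

import pymc.dims as pmd

v = pmd.math.as_xtensor(np.arange(3), dims=("a",))
print((v + v).type)  # both operands have dim "a": element-wise sum, no new dim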
+ ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "4d9e8417789d4c18", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:35.055530Z", + "start_time": "2025-06-30T15:53:35.045241Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "XTensorType(float64, shape=(3,), dims=('a',))" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dims_normal = pmd.Normal.dist(mu=pmd.math.as_xtensor([0, 1, 2], dims=(\"a\",)), sigma=1)\n", + "dims_normal.type" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "1d0c30011c910370", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:35.264315Z", + "start_time": "2025-06-30T15:53:35.256071Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "pytensor.xtensor.type.XTensorVariable" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "type(dims_normal)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "923e75f235dcf219", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:35.472551Z", + "start_time": "2025-06-30T15:53:35.465427Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "XTensorType(float64, shape=(3, 3), dims=('a', 'b'))" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "outer_addition = dims_normal + dims_normal.rename({\"a\": \"b\"})\n", + "outer_addition.type" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "cd9e2a2634a0e898", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:35.758679Z", + "start_time": "2025-06-30T15:53:35.649256Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[ 3.76355516, 0.31059132, 6.5420105 ],\n", + " [ 0.31059132, -3.14237253, 3.08904666],\n", + " [ 6.5420105 , 3.08904666, 9.32046584]])" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pm.draw(outer_addition, random_seed=rng)" + ] + }, + { + "cell_type": "markdown", + "id": "90d65f41-2279-4205-bc2e-c5446a176c61", + "metadata": {}, + "source": [ + ":::{tip} Note that there are no coordinates anywhere in the graph. {class}`~pytensor.xtensor.type.XTensorVariable`s behave like xarray DataArrays **without** coords. Dims determine the dimension meaning and alignment, but no extra work can be done to reason within a dim. We discuss this limitation in more detail at the end.:::" + ] + }, + { + "cell_type": "markdown", + "id": "e12c7cda782cac69", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-26T21:31:08.365035095Z", + "start_time": "2025-06-26T11:42:19.021166Z" + } + }, + "source": [ + "## Redundant (or implicit) dims" + ] + }, + { + "cell_type": "markdown", + "id": "4ac4877e24859dfe", + "metadata": {}, + "source": [ + "When defining deterministic operations or creating variables whose dimension are all implied by the parameters, there's no need to specify the `dims` argument, as PyMC will automatically know them.\n", + "\n", + "However, some users might want to specify `dims` anyway, to check that the dimensions of the variables are as expected, or to provide type hints for someone reading the model.\n", + "\n", + "PyMC allows specifying dimensions in these cases. To reduce confusion, the output will always be transposed to be aligned with the user-specified dims." 
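To make the alignment rules concrete, here is a minimal sketch (dim names are our own): operands align by dim name, so transposing one of them does not change the result, and positional selection with `isel` works even though label-based `sel` has no coordinates to act on.

import numpy as np

import pymc.dims as pmd

xt = pmd.math.as_xtensor(np.arange(6).reshape(2, 3), dims=("a", "b"))
yt = xt.transpose("b", "a")  # same data, axes reordered
print((xt + yt).type)  # aligned by dim name: shape (2, 3), equal to 2 * xt
print(xt.isel({"a": 0}).type)  # positional selection needs no coords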
+ ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "cd6031d682b88e51", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:35.989010Z", + "start_time": "2025-06-30T15:53:35.965369Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "det_implicit_dims.dims=('a', 'b')\n", + "det_explicit_dims.dims=('a', 'b')\n", + "det_transposed_dims.dims=('b', 'a')\n" + ] + } + ], + "source": [ + "with pm.Model(coords={\"a\": range(2), \"b\": range(5)}) as example:\n", + " x = pmd.Normal(\"x\", dims=(\"a\", \"b\"))\n", + " det_implicit_dims = pmd.Deterministic(\"det1\", x + 1)\n", + " det_explicit_dims = pmd.Deterministic(\"det2\", x + 1, dims=(\"a\", \"b\"))\n", + " det_transposed_dims = pmd.Deterministic(\"y\", x + 1, dims=(\"b\", \"a\"))\n", + "\n", + "print(f\"{det_implicit_dims.dims=}\")\n", + "print(f\"{det_explicit_dims.dims=}\")\n", + "print(f\"{det_transposed_dims.dims=}\")" + ] + }, + { + "cell_type": "markdown", + "id": "db657b20b447c56a", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-26T21:31:08.365174897Z", + "start_time": "2025-06-26T11:42:19.389561Z" + } + }, + "source": [ + "This happens with `Deterministic`, `Potential` and every distribution in the `dims` module.\n", + "Any time you specify `dims`, you will get back a variable with dimensions in the same order." + ] + }, + { + "cell_type": "markdown", + "id": "e5f73575a51d316a", + "metadata": {}, + "source": [ + "Furthermore -- and unlike regular PyMC objects -- it is now valid to use ellipsis in the `dims` argument.\n", + "As in an Xarray `transpose`, it means all the other dimensions should stay in the same order." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "a1e3597de136004f", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:36.397445Z", + "start_time": "2025-06-30T15:53:36.371682Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "det_ellipsis1.dims=('a', 'b')\n", + "det_ellipsis2.dims=('b', 'a')\n", + "det_ellipsis3.dims=('b', 'a')\n" + ] + } + ], + "source": [ + "with pm.Model(coords={\"a\": range(2), \"b\": range(5)}) as example:\n", + " x = pmd.Normal(\"x\", dims=(\"a\", \"b\"))\n", + " det_ellipsis1 = pmd.Deterministic(\"det1\", x + 1, dims=(...,))\n", + " det_ellipsis2 = pmd.Deterministic(\"det2\", x + 1, dims=(..., \"a\"))\n", + " det_ellipsis3 = pmd.Deterministic(\"det3\", x + 1, dims=(\"b\", ...))\n", + "\n", + "print(f\"{det_ellipsis1.dims=}\")\n", + "print(f\"{det_ellipsis2.dims=}\")\n", + "print(f\"{det_ellipsis3.dims=}\")" + ] + }, + { + "cell_type": "markdown", + "id": "318c563fb2eb106b", + "metadata": {}, + "source": [ + "## What functionality is supported?" + ] + }, + { + "cell_type": "markdown", + "id": "88acbfae651fe294", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "The documentation is still a work in progress, and there is no complete list of distributions and operations that are supported just yet. \n", + "\n", + "#### Model constructors\n", + "\n", + "The following PyMC model constructors are available in the `dims` module.\n", + "\n", + " * {func}`~pymc.dims.Data`\n", + " * {func}`~pymc.dims.Deterministic`\n", + " * {func}`~pymc.dims.Potential`\n", + "\n", + "They all return {class}`~pytensor.xtensor.type.XTensorVariable` objects, and either infer `dims` from the input or require the user to specify them explicitly. 
If they can be inferred, it is possible to transpose and use ellipsis in the `dims` argument, as described above.\n",
+    "\n",
+    "#### Distributions\n",
+    "\n",
+    "We want to offer all the existing distributions and parametrizations under the {mod}`pymc.dims` module, with the following expected API differences:\n",
+    "\n",
+    " * All vector arguments (and observed values) must have known dims. An error is raised otherwise.\n",
+    "\n",
+    " * Distributions with non-scalar inputs will require a `core_dims` argument. The meaning of the `core_dims` argument is described in the docstring of each distribution. For example, for the MvNormal, the `core_dims` are the two dimensions of the covariance matrix, one (and only one) of which must also be present in the mean parameter. The shared `core_dim` is the one that persists in the output. Sometimes the order of `core_dims` will be important!\n",
+    "\n",
+    " * `dims` accepts ellipsis, and variables are transposed to match the user-specified `dims` argument.\n",
+    "\n",
+    " * `shape` and `size` cannot be provided.\n",
+    "\n",
+    " * The {meth}`~pymc.distributions.core.DimDistribution.dist` method accepts a `dim_lengths` argument, of the form `{dim_name: dim_length}`.\n",
+    "\n",
+    " * Only transforms defined in {mod}`pymc.dims.transforms` can be used with distributions from the module.\n",
+    "\n",
+    "#### Operations on variables\n",
+    "\n",
+    "Calling a PyMC distribution from the {mod}`pymc.dims` module returns an {class}`~pytensor.xtensor.type.XTensorVariable`.\n",
+    "\n",
+    "The expectation is that every {class}`xarray.DataArray` method in Xarray should have an equivalent version for XTensorVariables. So if you can do `x.diff(dim=\"a\")` in Xarray, you should be able to do `x.diff(dim=\"a\")` with XTensorVariables as well.\n",
+    "\n",
+    "In addition, many numerical operations are available in the {mod}`pymc.dims.math` module, which provides a superset of the `ufuncs` found in Xarray (like `exp`). It also includes submodules such as `linalg` that provide counterparts to libraries like {lib}`xarray_einstats` (e.g., `linalg.solve`).\n",
+    "\n",
+    "Finally, functions that are available at the module level in Xarray (like `concat`) are also available in the {mod}`pymc.dims` namespace.\n",
+    "\n",
+    "You can find the original documentation of these functions and methods in the {doc}`PyTensor Xtensor docs `\n",
+    "\n",
+    "To facilitate adoption of these functions and methods, we try to follow the same API used by Xarray and related packages. However, some methods or keyword arguments won't be supported explicitly (like {meth}`~pytensor.xtensor.XTensorVariable.sel`, more on that at the end), in which case an informative error or warning will be raised.\n",
+    "\n",
+    "If you find an API difference or some missing functionality, and no reason is provided, please [open an issue](https://github.com/pymc-devs/pymc/issues) to let us know (after checking nobody has reported it already).\n",
+    "\n",
+    "In the meantime, the next section provides some hints on how to make use of pre-existing functionality in PyMC/PyTensor."
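A quick sketch of this parity (dim names are our own; we assume the operations shown are implemented, as the text above suggests):

import numpy as np

import pymc.dims as pmd

x = pmd.math.as_xtensor(np.ones((4, 3)), dims=("time", "item"))
print(x.sum("item").type)  # reduce a named dimension
print(x.isel({"time": slice(1, None)}).type)  # xarray-style positional slicing
print(pmd.concat([x, x], dim="time").type)  # module-level function, as in xarray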
+ ] + }, + { + "cell_type": "markdown", + "id": "98a94f5939872e75", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-26T21:31:08.365484979Z", + "start_time": "2025-06-26T11:42:19.577389Z" + } + }, + "source": [ + "## Combining dims module with the old API" + ] + }, + { + "cell_type": "markdown", + "id": "3d305a2bb8313704", + "metadata": {}, + "source": [ + "Because the {mod}`pymc.dims` module is more recent, it does not offer all the functionality of the old API.\n", + "\n", + "You can always combine the two APIs by converting the variables explicitly. To obtain a regular non-dimmed variable from a dimmed variable, you can use {attr}`~pytensor.xtensor.type.XTensorVariable.values` (like in Xarray) or the more verbose {func}`pymc.dims.tensor_from_xtensor`.\n", + "\n", + "Otherwise, if you try to pass an {class}`~pytensor.xtensor.type.XTensorVariable` to a function or distribution that does not support it, you will usually see an error like this:" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "df3d1f1ce895d02e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "TypeError: To avoid subtle bugs, PyTensor forbids automatic conversion of XTensorVariable to TensorVariable.\n", + "You can convert explicitly using `x.values` or pass `allow_xtensor_conversion=True`.\n" + ] + } + ], + "source": [ + "mu = pmd.math.as_xtensor([0, 1, 2], dims=(\"a\",))\n", + "try:\n", + " pm.Normal.dist(mu=mu)\n", + "except TypeError as e:\n", + " print(f\"{e.__class__.__name__}: {e}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "335b094a6cdd0bd8", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:36.767188Z", + "start_time": "2025-06-30T15:53:36.745075Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "TensorType(float64, shape=(None, None))" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pm.Normal.dist(mu=x.values).type" + ] + }, + { + "cell_type": "markdown", + "id": "eca88323e2529a57", + "metadata": {}, + "source": [ + "The order of the dimensions follows that specified in the {attr}`~pytensor.xtensor.type.XTensorVariable.dims` property. To be sure this matches the expectation you can use a {meth}`~pytensor.xtensor.type.XTensorVariable.transpose` operation to reorder the dimensions before converting to a regular variable." 
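For example (a sketch with made-up dims):

import numpy as np

import pymc.dims as pmd

x = pmd.as_xtensor(np.zeros((2, 3)), dims=("a", "b"))
print(x.values.type)  # TensorType(float64, shape=(2, 3)): axes follow x.dims
print(x.transpose("b", "a").values.type)  # axes reordered first: shape (3, 2)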
+ ] + }, + { + "cell_type": "markdown", + "id": "9e256f247d3bd9e1", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-26T21:31:08.365732804Z", + "start_time": "2025-06-26T11:42:19.755931Z" + } + }, + "source": [ + "Conversely, if you try to pass a regular variable to a function or distribution that expects an XTensorVariable, you will see an error like this:" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "380f7161ca2d22e4", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:36.879320Z", + "start_time": "2025-06-30T15:53:36.868282Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ValueError: Variable mu_x{[0 1 2]} must have dims associated with it.\n", + "To avoid subtle bugs, PyMC does not make any assumptions about the dims of parameters.\n", + "Use `as_xtensor` with the `dims` keyword argument to specify the dims explicitly.\n" + ] + } + ], + "source": [ + "mu = pm.math.as_tensor([0, 1, 2], name=\"mu_x\")\n", + "try:\n", + " x = pmd.Normal.dist(mu=mu)\n", + "except Exception as e:\n", + " print(f\"{e.__class__.__name__}: {e}\")" + ] + }, + { + "cell_type": "markdown", + "id": "e507c0a8e76447fd", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-26T21:31:08.365847399Z", + "start_time": "2025-06-26T11:42:19.826864Z" + } + }, + "source": [ + "Which you can avoid by explicitly converting the variable to a dimmed variable:" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "b50363261fe81df6", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:37.226477Z", + "start_time": "2025-06-30T15:53:37.204756Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "XTensorType(float64, shape=(3,), dims=('a',))" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pmd.Normal.dist(mu=pmd.as_xtensor(mu, dims=(\"a\",))).type" + ] + }, + { + "cell_type": "markdown", + "id": "3f2202b55cbc8c4", + "metadata": {}, + "source": [ + "#### Applied example\n", + "\n", + "To put this to practice, let's write a model that uses the {class}`~pymc.LKJCholeskyCov` distribution, which at the time of writing is not yet available in the {mod}`pymc.dims` module." + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "195f6f509e797f9a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "chol_xr.dims=('core1', 'core2')\n", + "mu.dims=('batch', 'core1')\n", + "y.dims=('batch', 'core1')\n" + ] + } + ], + "source": [ + "with pm.Model(coords={\"core1\": range(3), \"core2\": range(3), \"batch\": range(5)}) as mixed_api_model:\n", + " chol, _, _ = pm.LKJCholeskyCov(\n", + " \"chol\",\n", + " eta=1,\n", + " n=3,\n", + " sd_dist=pm.Exponential.dist(1),\n", + " )\n", + " chol_xr = pmd.as_xtensor(chol, dims=(\"core1\", \"core2\"))\n", + "\n", + " mu = pmd.Normal(\"mu\", dims=(\"batch\", \"core1\"))\n", + " y = pmd.MvNormal(\n", + " \"y\",\n", + " mu,\n", + " chol=chol_xr,\n", + " core_dims=(\"core1\", \"core2\"),\n", + " )\n", + "\n", + "print(f\"{chol_xr.dims=}\")\n", + "print(f\"{mu.dims=}\")\n", + "print(f\"{y.dims=}\")" + ] + }, + { + "cell_type": "markdown", + "id": "3f0dd0fea61f8862", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-26T21:31:08.366068103Z", + "start_time": "2025-06-26T11:42:19.946909Z" + } + }, + "source": [ + "Note that we had to pass a \"regular\" {class}`~pymc.Exponential` distribution to the {class}`~pymc.LKJCholeskyCov` constructor. 
In general, distribution \"factories\" which are parametrized by unnamed distributions created with {meth}`~pymc.distributions.distribution.Distribution.dist` won't work with variables created with the {mod}`pymc.dims` module." + ] + }, + { + "cell_type": "markdown", + "id": "6464ae49ff629660", + "metadata": {}, + "source": [ + "## Case study: a splines model begging for vectorization" + ] + }, + { + "cell_type": "markdown", + "id": "c7c622bbde1c675", + "metadata": {}, + "source": [ + "The model below was presented by a user in a bug report. While granting that the author might have written the model this way to provide a reproducible example, we can say that the model is written in a highly suboptimal form. Specifically it misses (or actively breaks) many opportunities for vectorization. Seasoned Python programmers will know that python loops are SLOW, and that tools like numpy provide a way to escape from this handicap.\n", + "\n", + "PyMC code is not exactly numpy, for starters it uses a lazy symbolic computation library (PyTensor) that generates compiled code on demand. But it very much likes to be given numpy-like code. To begin with these graphs are much smaller and therefore easier to reason about (in fact the original bug could only be triggered for graphs with more than 500 nodes). Secondly, numpy-like graphs naturally translate to vectorized CPU and GPU code, which you want at the end of the day." + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "726f696e6bf17d44", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:39.748768Z", + "start_time": "2025-06-30T15:53:39.740157Z" + } + }, + "outputs": [], + "source": [ + "# Simulated data of some spline\n", + "N = 500\n", + "x_np = np.linspace(0, 10, N)\n", + "y_obs_np = np.piecewise(\n", + " x_np,\n", + " [x_np <= 3, (x_np > 3) & (x_np <= 7), x_np > 7],\n", + " [lambda x: 0.5 * x, lambda x: 1.5 + 0.2 * (x - 3), lambda x: 2.3 - 0.1 * (x - 7)],\n", + ")\n", + "y_obs_np += rng.normal(0, 0.2, size=N) # Add noise\n", + "\n", + "# Artificial groups\n", + "groups = [0, 1, 2]\n", + "group_idx_np = np.random.choice(groups, size=N)\n", + "\n", + "n_knots = 50\n", + "knots_np = np.linspace(0, 10, num=n_knots)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "2e8519125d17f0c8", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:40.367851Z", + "start_time": "2025-06-30T15:53:39.854962Z" + } + }, + "outputs": [], + "source": [ + "with pm.Model() as non_vectorized_splines_model:\n", + " sigma_beta0 = pm.HalfNormal(\"sigma_beta0\", sigma=10)\n", + " sigma = pm.HalfCauchy(\"sigma\", beta=1)\n", + "\n", + " # Create likelihood per group\n", + " for gr in groups:\n", + " idx = group_idx_np == gr\n", + "\n", + " beta0 = pm.HalfNormal(f\"beta0_{gr}\", sigma=sigma_beta0)\n", + " z = pm.Normal(f\"z_{gr}\", mu=0, sigma=2, shape=n_knots)\n", + "\n", + " delta_factors = pm.math.softmax(z)\n", + " slope_factors = 1 - pm.math.cumsum(delta_factors[:-1])\n", + " spline_slopes = pm.math.stack(\n", + " [beta0] + [beta0 * slope_factors[i] for i in range(n_knots - 1)]\n", + " )\n", + " beta = pm.Deterministic(\n", + " f\"beta_{gr}\",\n", + " pm.math.concatenate(([beta0], pm.math.diff(spline_slopes))),\n", + " )\n", + "\n", + " hinge_terms = [pm.math.maximum(0, x_np[idx] - knot) for knot in knots_np]\n", + " X = pm.math.stack([hinge_terms[i] for i in range(n_knots)], axis=1)\n", + "\n", + " mu = pm.math.dot(X, beta)\n", + "\n", + " pm.Normal(f\"y_{gr}\", mu=mu, sigma=sigma, observed=y_obs_np[idx])" 
+ ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "c56dcb51a07a5b54", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:40.700094Z", + "start_time": "2025-06-30T15:53:40.463806Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[graphviz figure: one beta0_g -> beta_g -> y_g chain per group (with z_g feeding beta_g), all sharing the sigma and sigma_beta0 parents]" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "non_vectorized_splines_model.to_graphviz()" + ] + }, + { + "cell_type": "markdown", + "id": "6be0b8afd4990567", + "metadata": {}, + "source": [ + "This version of the model has several `for` loops and list comprehensions.\n", + "Every Python iteration extends the computational graph (basically unrolling the loop), which quickly becomes infeasible for PyMC to handle.\n", + "\n", + "The use of 3 likelihoods and sets of priors is otherwise fine and can make more sense in some cases.\n", + "\n", + "We can improve the performance of this model by replacing the loops with vectorized operations.\n", + "We'll demonstrate first with current PyMC features, and then with dims."
+ ] + }, + { + "cell_type": "markdown", + "id": "33d16a1bef908f3f", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-26T21:31:52.686637Z", + "start_time": "2025-06-26T21:31:52.509222Z" + } + }, + "source": [ + "### Old style vectorization" + ] + }, + { + "cell_type": "markdown", + "id": "d8add515e2e56179", + "metadata": {}, + "source": [ + "First, we'll introduce coords and {func}`~pymc.Data` to explicitly define the coordinate system, making the model's dimensionality clear and facilitating posterior analysis with labeled dimensions." + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "6ee0f6e24fa81b02", + "metadata": {}, + "outputs": [], + "source": [ + "coords = {\n", + " \"group\": range(3),\n", + " \"knot\": range(n_knots),\n", + " \"obs\": range(N),\n", + "}\n", + "with pm.Model(coords=coords) as vectorized_splines_model:\n", + " x = pm.Data(\"x\", x_np, dims=\"obs\")\n", + " y_obs = pm.Data(\"y_obs\", y_obs_np, dims=\"obs\")\n", + "\n", + " knots = pm.Data(\"knots\", knots_np, dims=\"knot\")\n", + "\n", + " sigma = pm.HalfCauchy(\"sigma\", beta=1)\n", + " sigma_beta0 = pm.HalfNormal(\"sigma_beta0\", sigma=10)\n", + " beta0 = pm.HalfNormal(\"beta_0\", sigma=sigma_beta0, dims=\"group\")\n", + " z = pm.Normal(\"z\", dims=(\"group\", \"knot\"))\n", + "\n", + " delta_factors = pm.math.softmax(z, axis=-1) # (group, knot)\n", + " slope_factors = 1 - pm.math.cumsum(delta_factors[:, :-1], axis=-1) # (group, knot-1)\n", + " spline_slopes = pm.math.concatenate(\n", + " [beta0[:, None], beta0[:, None] * slope_factors], axis=-1\n", + " ) # (group, knot)\n", + " beta = pm.math.concatenate(\n", + " [beta0[:, None], pm.math.diff(spline_slopes, axis=-1)], axis=-1\n", + " ) # (group, knot)\n", + "\n", + " beta = pm.Deterministic(\"beta\", beta, dims=(\"group\", \"knot\"))\n", + "\n", + " X = pm.math.maximum(0, x[:, None] - knots[None, :]) # (obs, knot)\n", + " mu = (X * beta[group_idx_np]).sum(-1) # ((obs, knot) * (obs, knot)).sum(-1) = (obs,)\n", + " y = pm.Normal(\"y\", mu=mu, sigma=sigma, observed=y_obs, dims=\"obs\")" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "bd94fa06a1190a05", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:41.063986Z", + "start_time": "2025-06-30T15:53:40.980450Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[graphviz figure: plates obs (500), knot (50), group (3) and group (3) x knot (50), with single x, y_obs, knots, sigma, sigma_beta0, beta_0, z, beta and y nodes]" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "vectorized_splines_model.to_graphviz()" + ] + }, + { + "cell_type": "markdown", + "id": "a98bbe7f2e17783a", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-26T21:32:15.059906Z", + "start_time": "2025-06-26T21:32:15.024281Z" + } + }, + "source": [ + "Looking only at the graphviz rendering does not reveal much difference. The real complexity is hidden in the underlying computational graph. Let us measure it." + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "90f8f4ee3fd9f2a3", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:41.192185Z", + "start_time": "2025-06-30T15:53:41.162347Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Non-vectorized model has 806 nodes\n", + "Vectorized model has 38 nodes\n" + ] + } + ], + "source": [ + "from pytensor.graph import ancestors\n", + "\n", + "\n", + "def count_nodes(model):\n", + " return len({v for v in ancestors(model.basic_RVs) if v.owner})\n", + "\n", + "\n", + "print(f\"Non-vectorized model has {count_nodes(non_vectorized_splines_model)} nodes\")\n", + "print(f\"Vectorized model has {count_nodes(vectorized_splines_model)} nodes\")" + ] + }, + { + "cell_type": "markdown", + "id": "daa39293278cbb05", + "metadata": {}, + "source": [ + "This version of the model is much more succinct and performant.\n", + "\n", + "However, it is considerably harder to write, especially for beginners. It requires familiarity with NumPy's unusual broadcasting syntax and mental gymnastics to keep track of the dimensions, as the short reminder below shows."
+ ] + }, + { + "cell_type": "markdown", + "id": "a532b2ec365bfecf", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-26T21:32:17.432572Z", + "start_time": "2025-06-26T21:32:17.415088Z" + } + }, + "source": [ + "### Vectorization with dims" + ] + }, + { + "cell_type": "markdown", + "id": "c24ce399f0eeae4c", + "metadata": {}, + "source": [ + "Here is a version of the model using {mod}`pymc.dims`, which we hope you'll agree is more intuitive" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "1bf2c6b99378b2db", + "metadata": {}, + "outputs": [], + "source": [ + "with pm.Model(coords=coords) as dims_splines_model:\n", + " x = pmd.Data(\"x\", x_np, dims=\"obs\")\n", + " y_obs = pmd.Data(\"y_obs\", y_obs_np, dims=\"obs\")\n", + " knots = pmd.Data(\"knots\", knots_np, dims=(\"knot\",))\n", + " group_idx = pmd.math.as_xtensor(group_idx_np, dims=(\"obs\",))\n", + "\n", + " sigma = pmd.HalfCauchy(\"sigma\", beta=1)\n", + " sigma_beta0 = pmd.HalfNormal(\"sigma_beta0\", sigma=10)\n", + " beta0 = pmd.HalfNormal(\"beta_0\", sigma=sigma_beta0, dims=(\"group\",))\n", + " z = pmd.Normal(\"z\", dims=(\"group\", \"knot\"))\n", + "\n", + " delta_factors = pmd.math.softmax(z, dim=\"knot\")\n", + " slope_factors = 1 - delta_factors.isel(knot=slice(None, -1)).cumsum(\"knot\")\n", + " spline_slopes = pmd.concat([beta0, beta0 * slope_factors], dim=\"knot\")\n", + " beta = pmd.Deterministic(\"beta\", pmd.concat([beta0, spline_slopes.diff(\"knot\")], dim=\"knot\"))\n", + "\n", + " X = pmd.math.maximum(0, x - knots)\n", + " mu = (X * beta.isel(group=group_idx)).sum(\"knot\")\n", + " y = pmd.Normal(\"y\", mu=mu, sigma=sigma, observed=y_obs)" + ] + }, + { + "cell_type": "markdown", + "id": "91593e5b3dbc885f", + "metadata": {}, + "source": [ + "Let us confirm it is identical to the previous version." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "5f55cf90d737dcf", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:42.362928Z", + "start_time": "2025-06-30T15:53:42.151483Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[graphviz figure: identical structure to the vectorized model above]" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dims_splines_model.to_graphviz()" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "3a9dd4669342ff9e", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:42.797760Z", + "start_time": "2025-06-30T15:53:42.794464Z" + } + }, + "outputs": [], + "source": [ + "# Uncomment if you want to wait a long while for the results\n", + "# non_vectorized_splines_model.point_logps()" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "86701f276baaa7c", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:44.309696Z", + "start_time": "2025-06-30T15:53:43.273978Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'sigma': np.float64(-1.14),\n", + " 'sigma_beta0': np.float64(-0.73),\n", + " 'beta_0': np.float64(-2.18),\n", + " 'z': np.float64(-137.84),\n", + " 'y': np.float64(-319962.47)}" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "vectorized_splines_model.point_logps()" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "f34c353131915cef", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:46.279261Z", + "start_time": "2025-06-30T15:53:45.344064Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'sigma': np.float64(-1.14),\n", + " 'sigma_beta0': np.float64(-0.73),\n", + " 'beta_0': np.float64(-2.18),\n", + " 'z': np.float64(-137.84),\n", + " 'y': np.float64(-319962.47)}" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ +
"dims_splines_model.point_logps()" + ] + }, + { + "cell_type": "markdown", + "id": "1bb865850774d0b2", + "metadata": {}, + "source": [ + "We have just two more nodes in the computation graph, with no impact on performance" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "ae2afa7841251f6c", + "metadata": { + "ExecuteTime": { + "end_time": "2025-07-08T17:06:40.561788Z", + "start_time": "2025-07-08T17:06:40.554136Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Vectorized model has 38 nodes\n", + "Vectorized model with dims has 40 nodes\n" + ] + } + ], + "source": [ + "print(f\"Vectorized model has {count_nodes(vectorized_splines_model)} nodes\")\n", + "print(f\"Vectorized model with dims has {count_nodes(dims_splines_model)} nodes\")" + ] + }, + { + "cell_type": "markdown", + "id": "014c701b-1e23-4fc4-ae55-4279263b2ed9", + "metadata": {}, + "source": [ + "This wraps the presentation of the {mod}`pymc.dims` module. We invite you to try it out and report any problems you may find. If you wish to contribute missing functionality get involved in our [Github repository](https://github.com/pymc-devs/pymc)." + ] + }, + { + "cell_type": "markdown", + "id": "2059b00d-6d6a-41af-8e71-3709e6a33dcc", + "metadata": {}, + "source": [ + "## Questions and answers" + ] + }, + { + "cell_type": "markdown", + "id": "8d1a93d8-6fc5-42db-bae0-d4a3fd68abf7", + "metadata": {}, + "source": [ + "### How to get xarray out of XTensorVariables?" + ] + }, + { + "cell_type": "markdown", + "id": "0b8d9835-1bcd-4a20-a925-2ac88045113c", + "metadata": {}, + "source": [ + "Usually PyMC routines take care of converting outputs to Xarray for the user. \n", + "\n", + "If you want to evaluate an expression yourself you'll likely get a numpy array back. \n", + "To convert to a {class}`~xarray.DataArray` simply reuse the `dims` attribute from the PyMC variable." + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "7ea9dc87-77e5-40f1-95ec-233984912443", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataArray (a: 3, b: 3)> Size: 72B\n",
+       "array([[-3.28099006, -1.57426294,  1.21146118],\n",
+       "       [-1.57426294,  0.13246418,  2.9181883 ],\n",
+       "       [ 1.21146118,  2.9181883 ,  5.70391243]])\n",
+       "Dimensions without coordinates: a, b
" + ], + "text/plain": [ + " Size: 72B\n", + "array([[-3.28099006, -1.57426294, 1.21146118],\n", + " [-1.57426294, 0.13246418, 2.9181883 ],\n", + " [ 1.21146118, 2.9181883 , 5.70391243]])\n", + "Dimensions without coordinates: a, b" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from xarray import DataArray\n", + "\n", + "x = pm.dims.Normal.dist(dim_lengths={\"a\": 3})\n", + "outer_x = x + x.rename({\"a\": \"b\"})\n", + "res_numpy = pm.draw(outer_addition)\n", + "res_xr = DataArray(res_numpy, dims=outer_addition.dims)\n", + "res_xr" + ] + }, + { + "cell_type": "markdown", + "id": "93ec3533ec122baa", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-26T21:32:29.344857Z", + "start_time": "2025-06-26T21:32:28.811467Z" + } + }, + "source": [ + "### Why aren't coordinate-based operations supported?" + ] + }, + { + "cell_type": "markdown", + "id": "b20324c9f4f65d4b", + "metadata": {}, + "source": [ + "The new xtensor variable and operations do not propagate information about coordinates, which means you cannot perform coordinate-related operations like `sel`, `loc`, `drop`.\n", + "\n", + "This limitation might be disappointing for someone used to Xarray, it is a necessary trade-off to allow PyMC to evaluate the model in a performant way.\n", + "PyMC uses PyTensor under the hood, which compiles functions to C (or numba or JAX) backends. None of these backends support dims or coordinates.\n", + "To work with these backends, PyTensor rewrites xtensor operations using equivalent tensor operations, which are pretty much abstract NumPy code. This is relatively easy to do, because it's mostly about aligning dimensions for broadcasting or indexing correctly.\n", + "\n", + "Rewriting coordinate-related operations into NumPy-like code is a different matter. Many such operations don't have straightforward equivalency, they are more like querying or joining a database than performing array operations.\n", + "\n", + "PyMC models will keep supporting the `coords` argument as a way to specify dimensions of model variables. But for modeling purposes, only the dimension names and their lengths play a role." + ] + }, + { + "cell_type": "markdown", + "id": "f963a53c229c04b9", + "metadata": {}, + "source": [ + "### One final note of caution on coordinates" + ] + }, + { + "cell_type": "markdown", + "id": "626c004fa7abdccf", + "metadata": {}, + "source": [ + "When you provide coords to a PyMC model, they are attached to any functions that returns Xarray or InferenceData objects.\n", + "\n", + "This creates a potential problem.\n", + "\n", + "Suppose we have multiple arrays with the same dims but different shapes.\n", + "This is legal in PyMC, as in Xarray, and some operations, like indexing or concatenating, can handle it.\n", + "\n", + "However, after sampling, PyMC tries to reattach the coordinates to any computed variables, and these might not have the right shape, or they might not be correctly aligned.\n", + "\n", + "When PyMC tries to convert the results of sampling to InferenceData, it will issue a warning and refuse to propagate the original coordinates.\n", + "\n", + "Here's an example where we have two variables with the `a` dim but different shapes, and only one matches the shape of the coordinates specified in the model." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "2c6f3a2b70a4d77d", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:46.850715Z", + "start_time": "2025-06-30T15:53:46.750208Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Sampling: [x]\n", + "/home/ricardo/Documents/pymc/pymc/backends/arviz.py:70: UserWarning: Incompatible coordinate length of 3 for dimension 'a' of variable 'y'.\n", + "This usually happens when a sliced or concatenated variable is wrapped as a `pymc.dims.Deterministic`.The originate coordinates for this dim will not be included in the returned dataset for any of the variables. Instead they will default to `np.arange(var_length)` and the shorter variables will be right-padded with nan.\n", + "To make this warning into an error set `pymc.backends.arviz.RAISE_ON_INCOMPATIBLE_COORD_LENGTHS` to `True`\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "text/plain": [ + "Coordinates:\n", + " * chain (chain) int64 8B 0\n", + " * draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n", + " * a (a) int64 24B 0 1 2" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "with pm.Model(coords={\"a\": [-1, 0, 1]}) as m:\n", + " x = pmd.Normal(\"x\", dims=(\"a\",))\n", + " y = pmd.Deterministic(\"y\", x.isel(a=slice(1, None)))\n", + " assert y.dims == (\"a\",)\n", + "\n", + " idata = pm.sample_prior_predictive()\n", + "idata.prior[\"y\"].coords" + ] + }, + { + "cell_type": "markdown", + "id": "a20137348aa4f5bf", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-26T21:31:08.371358319Z", + "start_time": "2025-06-26T11:42:20.084741Z" + } + }, + "source": [ + "One way to work around this limitation is to rename the dimensions to avoid the conflict." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "466f152cb249d1d2", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:47.572296Z", + "start_time": "2025-06-30T15:53:47.486879Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Sampling: [x]\n" + ] + }, + { + "data": { + "text/plain": [ + "Coordinates:\n", + " * chain (chain) int64 8B 0\n", + " * draw (draw) int64 8B 0\n", + " * a* (a*) int64 16B -2 -1" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "with pm.Model(coords={\"a\": [-3, -2, -1], \"a*\": [-2, -1]}) as m:\n", + " x = pmd.Normal(\"x\", dims=(\"a\",))\n", + " y = pmd.Deterministic(\"y\", x.isel(a=slice(1, None)).rename({\"a\": \"a*\"}))\n", + " assert y.dims == (\"a*\",)\n", + " # You can rename back to the original name if you need it for further operations\n", + " y = y.rename({\"a*\": \"a\"})\n", + "\n", + " idata = pm.sample_prior_predictive(draws=1)\n", + "idata.prior[\"y\"].coords" + ] + }, + { + "cell_type": "markdown", + "id": "9a051144b1259fc8", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-26T21:31:08.371499143Z", + "start_time": "2025-06-26T11:42:20.159762Z" + } + }, + "source": [ + "An alternative is to manually specify the coordinates after sampling.\n", + "\n", + "Note that when doing advanced indexing the name of the indexed dimension can be controlled by the name of the indexing xtensor" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "87726153a7d1fd2", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:48.186501Z", + "start_time": "2025-06-30T15:53:48.181799Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "('a*',)" + ] + }, + "execution_count": 39, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x.isel(a=pmd.math.as_xtensor([0, 1, 2], dims=(\"a*\",))).dims" + ] + }, + { + "cell_type": "markdown", + "id": "38a8684f-4849-4f69-a7c7-7433182358c2", + "metadata": {}, + "source": [ + "However, silent bugs can still happen if the shapes are compatible but the coordinates are not correct.\n", + "For example, in this model the coordinates are reversed, but the shape matches, so the error is not detected." + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "8443e99c-8c90-40e1-8189-943e374c1387", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:48.819667Z", + "start_time": "2025-06-30T15:53:48.747097Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Sampling: [x]\n" + ] + }, + { + "data": { + "text/plain": [ + "array([[[-1.53640397, 0. , 1.53640397]]])" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "with pm.Model(coords={\"a\": [1, 2, 3]}):\n", + " x = pmd.Normal(\"x\", dims=(\"a\",))\n", + " pmd.Deterministic(\"x_reversed\", x[::-1])\n", + " idata = pm.sample_prior_predictive(draws=1)\n", + "(idata.prior[\"x\"] - idata.prior[\"x_reversed\"]).values" + ] + }, + { + "cell_type": "markdown", + "id": "1fd902d9-33fd-40d6-a66d-fc15f1064b5c", + "metadata": {}, + "source": [ + "In Xarray the results would be correct because it is aware of the coordinates, not just the shape." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "2d282c1a-8a30-476d-8d92-94ea9f9a6971", + "metadata": { + "ExecuteTime": { + "end_time": "2025-06-30T15:53:49.345188Z", + "start_time": "2025-06-30T15:53:49.337999Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[[0., 0., 0.]]])" + ] + }, + "execution_count": 41, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "(idata.prior[\"x\"] - idata.prior[\"x\"].isel(a=slice(None, None, -1))).values" + ] + }, + { + "cell_type": "markdown", + "id": "fa8552be-16b8-4dd0-9bdc-38df8d5609b1", + "metadata": {}, + "source": [ + "This is not a new problem with the {mod}`pymc.dims` module; it is a consequence of PyMC's inability to reason about coords symbolically.\n", + "But this kind of error is made more likely because functions from the {mod}`pymc.dims` module always propagate dimension names to the model object.\n", + "\n", + "We remind users that {func}`pymc.dims.Deterministic` variables are never required in a model; they are just a way to calculate and store the results of intermediate operations. If you use them, pay extra attention to whether the model coordinates are appropriate for the variable stored in the {func}`pymc.dims.Deterministic` (not just their length but their ordering as well).\n", + "\n", + "Alternatively, you can use the regular {func}`pymc.Deterministic` without specifying `dims`, which will not propagate the coordinates to the model. Keep in mind that the respective dimensions will be considered unique after sampling, and operations between variables that had shared dims in the original model will broadcast orthogonally in the returned InferenceData variables.\n", + "\n", + "If you need variables to have the same dims but different coords, you can always fix them manually." + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "id": "f994c42a-8468-4e98-96e2-eec1086cc7d4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
<xarray.DataArray (chain: 1, draw: 1, a: 3)> Size: 24B\n",
+       "array([[[0., 0., 0.]]])\n",
+       "Coordinates:\n",
+       "  * chain    (chain) int64 8B 0\n",
+       "  * draw     (draw) int64 8B 0\n",
+       "  * a        (a) int64 24B 1 2 3
" + ], + "text/plain": [ + " Size: 24B\n", + "array([[[0., 0., 0.]]])\n", + "Coordinates:\n", + " * chain (chain) int64 8B 0\n", + " * draw (draw) int64 8B 0\n", + " * a (a) int64 24B 1 2 3" + ] + }, + "execution_count": 42, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "idata.prior[\"x_reversed\"] = idata.prior[\"x_reversed\"].assign_coords({\"a\": [3, 2, 1]})\n", + "idata.prior[\"x\"] - idata.prior[\"x_reversed\"]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b1a5481d-085e-4c89-b0fa-3b4c6317c36b", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pymc", + "language": "python", + "name": "pymc" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 341ffca6bb2c6d565a89fd4d32674c43baa2330e Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 14 Jul 2025 12:54:25 +0200 Subject: [PATCH 10/10] List dims module functionality in the docs --- docs/source/api.rst | 1 + docs/source/api/dims.rst | 18 ++++++++++ docs/source/api/dims/distributions.rst | 36 +++++++++++++++++++ docs/source/api/dims/math.rst | 9 +++++ docs/source/api/dims/model.rst | 11 ++++++ docs/source/api/dims/transforms.rst | 11 ++++++ .../learn/core_notebooks/dims_module.ipynb | 2 ++ docs/source/learn/core_notebooks/index.md | 3 +- 8 files changed, 90 insertions(+), 1 deletion(-) create mode 100644 docs/source/api/dims.rst create mode 100644 docs/source/api/dims/distributions.rst create mode 100644 docs/source/api/dims/math.rst create mode 100644 docs/source/api/dims/model.rst create mode 100644 docs/source/api/dims/transforms.rst diff --git a/docs/source/api.rst b/docs/source/api.rst index d80c0984ff..72e4ca5e58 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -23,6 +23,7 @@ API api/backends api/misc api/testing + api/dims ------------------ Dimensionality diff --git a/docs/source/api/dims.rst b/docs/source/api/dims.rst new file mode 100644 index 0000000000..60a0138164 --- /dev/null +++ b/docs/source/api/dims.rst @@ -0,0 +1,18 @@ +.. _api_dims: + +Dims +==== + +This submodule contains functions for defining distributions and operations that use explicit dimensions. + +The module is presented in :ref:`dims_module`. + +.. currentmodule:: pymc.dims + +.. autosummary:: + :toctree: generated/ + + dims/model + dims/math + dims/distributions + dims/transforms diff --git a/docs/source/api/dims/distributions.rst b/docs/source/api/dims/distributions.rst new file mode 100644 index 0000000000..8048d0a20f --- /dev/null +++ b/docs/source/api/dims/distributions.rst @@ -0,0 +1,36 @@ +******************** +Scalar distributions +******************** + +.. currentmodule:: pymc.dims +.. autosummary:: + :toctree: generated/ + :template: distribution.rst + + Flat + HalfFlat + Normal + HalfNormal + LogNormal + StudentT + HalfStudentT + Cauchy + HalfCauchy + Beta + Laplace + Gamma + InverseGamma + + +******************** +Vector distributions +******************** + +.. currentmodule:: pymc.dims +.. 
autosummary:: +    :toctree: generated/ +    :template: distribution.rst + +    Categorical +    MvNormal +    ZeroSumNormal diff --git a/docs/source/api/dims/math.rst b/docs/source/api/dims/math.rst new file mode 100644 index 0000000000..78cac9ca50 --- /dev/null +++ b/docs/source/api/dims/math.rst @@ -0,0 +1,9 @@ +*************************************** +Mathematical operations with dimensions +*************************************** + +This module wraps all the mathematical operations defined in :doc:`pytensor.xtensor.math `. + +It includes a ``linalg`` submodule that wraps all the operations defined in :doc:`pytensor.xtensor.linalg `. + +Operations defined at the module level in :doc:`pytensor.xtensor ` are available directly in the parent module ``pymc.dims`` instead of here. diff --git a/docs/source/api/dims/model.rst b/docs/source/api/dims/model.rst new file mode 100644 index 0000000000..15e64758c5 --- /dev/null +++ b/docs/source/api/dims/model.rst @@ -0,0 +1,11 @@ +****************** +Model constructors +****************** + +.. currentmodule:: pymc.dims +.. autosummary:: +    :toctree: generated/ + +    Data +    Deterministic +    Potential diff --git a/docs/source/api/dims/transforms.rst b/docs/source/api/dims/transforms.rst new file mode 100644 index 0000000000..c1fbae9b65 --- /dev/null +++ b/docs/source/api/dims/transforms.rst @@ -0,0 +1,11 @@ +*********************** +Distribution Transforms +*********************** + +.. currentmodule:: pymc.dims.transforms +.. autosummary:: +    :toctree: generated/ + +    LogTransform +    LogOddsTransform +    ZeroSumTransform diff --git a/docs/source/learn/core_notebooks/dims_module.ipynb b/docs/source/learn/core_notebooks/dims_module.ipynb index b3fab1e7a8..995674bd01 100644 --- a/docs/source/learn/core_notebooks/dims_module.ipynb +++ b/docs/source/learn/core_notebooks/dims_module.ipynb @@ -5,6 +5,8 @@ "id": "17e37649edaa8d0d", "metadata": {}, "source": [ + "(dims_module)=\n", + "\n", "# PyMC dims module" ] }, diff --git a/docs/source/learn/core_notebooks/index.md b/docs/source/learn/core_notebooks/index.md index a493d9e9ca..4e7d7353b1 100644 --- a/docs/source/learn/core_notebooks/index.md +++ b/docs/source/learn/core_notebooks/index.md @@ -5,11 +5,12 @@ :maxdepth: 1 pymc_overview -GLM_linear model_comparison posterior_predictive dimensionality pymc_pytensor +dims_module +GLM_linear Gaussian_Processes :::