From 23783846589a5ec20663d43e7a9ff8bf90a7c8b2 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Mon, 24 Mar 2025 13:29:13 +0100
Subject: [PATCH] Replace batched_convolution by pytensor native impl

---
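Note for reviewers (comment area, not part of the commit): the rewrite below
replaces the manual lag-matrix construction with padding followed by a "valid"
pt.signal.convolve1d, which is presumably why the pymc/pytensor pins are
bumped. A minimal sketch of the idea for the causal ConvMode.After case,
checked against numpy's reference convolution; the values and variable names
are illustrative only, not taken from the patch:

    import numpy as np
    import pytensor.tensor as pt

    x_val = np.array([0.0, 1.0, 2.0, 3.0, 4.0])  # a single time series
    w_val = np.array([1.0, 0.5, 0.25])           # a 3-lag kernel
    lags = w_val.shape[-1]

    # ConvMode.After: left-pad with lags - 1 zeros so that output[t] only
    # depends on x[0..t] (the usual carryover/adstock behaviour). A "valid"
    # convolution of the padded series then has the same length as x.
    padded = pt.join(-1, pt.zeros((lags - 1,)), pt.as_tensor(x_val))
    out = pt.signal.convolve1d(padded, pt.as_tensor(w_val), mode="valid")

    np.testing.assert_allclose(
        out.eval(),
        np.convolve(x_val, w_val)[: x_val.shape[-1]],  # [0., 1., 2.5, 4.25, 6.]
    )

The Before and Overlap branches are the same trick with the padding placed on
the right, or split between both sides, respectively.
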
 environment.yml                    |  3 +-
 pymc_marketing/mmm/transformers.py | 65 +++++++++---------------------
 pyproject.toml                     |  3 +-
 3 files changed, 22 insertions(+), 49 deletions(-)

diff --git a/environment.yml b/environment.yml
index 5a31aa771..9c63f5ada 100644
--- a/environment.yml
+++ b/environment.yml
@@ -16,7 +16,8 @@ dependencies:
 - preliz
 - pyprojroot
 # NOTE: Keep minimum pymc version in sync with ci.yml `OLDEST_PYMC_VERSION`
-- pymc>=5.21.1
+- pymc>=5.22.0
+- pytensor>=2.30.3
 - pymc-extras>=0.2.1
 - blackjax>=1.2.4
 - scikit-learn>=1.1.1
diff --git a/pymc_marketing/mmm/transformers.py b/pymc_marketing/mmm/transformers.py
index 1e480a2bb..6cd4bd19a 100644
--- a/pymc_marketing/mmm/transformers.py
+++ b/pymc_marketing/mmm/transformers.py
@@ -21,7 +21,7 @@
 import pymc as pm
 import pytensor.tensor as pt
 from pymc.distributions.dist_math import check_parameters
-from pytensor.tensor.random.utils import params_broadcast_shapes
+from pytensor.npy_2_compat import normalize_axis_index
 
 
 class ConvMode(str, Enum):
@@ -103,58 +103,29 @@ def batched_convolution(
     """
     # We move the axis to the last dimension of the array so that it's easier to
     # reason about parameter broadcasting. We will move the axis back at the end
-    orig_ndim = x.ndim
-    axis = axis if axis >= 0 else orig_ndim + axis
+    x = pt.as_tensor(x)
     w = pt.as_tensor(w)
+
+    axis = normalize_axis_index(axis, x.ndim)
     x = pt.moveaxis(x, axis, -1)
-    l_max = w.type.shape[-1]
-    if l_max is None:
-        try:
-            l_max = w.shape[-1].eval()
-        except Exception:  # noqa: S110
-            pass
-    # Get the broadcast shapes of x and w but ignoring their last dimension.
-    # The last dimension of x is the "time" axis, which doesn't get broadcast
-    # The last dimension of w is the number of time steps that go into the convolution
-    x_shape, w_shape = params_broadcast_shapes([x.shape, w.shape], [1, 1])
-
-    x = pt.broadcast_to(x, x_shape)
-    w = pt.broadcast_to(w, w_shape)
-    x_time = x.shape[-1]
-    # Make a tensor with x at the different time lags needed for the convolution
-    x_shape = x.shape
-    # Add the size of the kernel to the time axis
-    shape = (*x_shape[:-1], x_shape[-1] + w.shape[-1] - 1, w.shape[-1])
-    padded_x = pt.zeros(shape, dtype=x.dtype)
-
-    if l_max is None:  # pragma: no cover
-        raise NotImplementedError(
-            "At the moment, convolving with weight arrays that don't have a concrete shape "
-            "at compile time is not supported."
-        )
-    # The window is the slice of the padded array that corresponds to the original x
-    if l_max <= 1:
-        window = slice(None)
+    x_batch_shape = tuple(x.shape)[:-1]
+    lags = w.shape[-1]
+
+    if mode == ConvMode.After:
+        zeros = pt.zeros((*x_batch_shape, lags - 1), dtype=x.dtype)
+        padded_x = pt.join(-1, zeros, x)
     elif mode == ConvMode.Before:
-        window = slice(l_max - 1, None)
-    elif mode == ConvMode.After:
-        window = slice(None, -l_max + 1)
+        zeros = pt.zeros((*x_batch_shape, lags - 1), dtype=x.dtype)
+        padded_x = pt.join(-1, x, zeros)
     elif mode == ConvMode.Overlap:
-        # Handle even and odd l_max differently if l_max is odd then we can split evenly otherwise we drop from the end
-        window = slice((l_max // 2) - (1 if l_max % 2 == 0 else 0), -(l_max // 2))
+        zeros_left = pt.zeros((*x_batch_shape, lags // 2), dtype=x.dtype)
+        zeros_right = pt.zeros((*x_batch_shape, (lags - 1) // 2), dtype=x.dtype)
+        padded_x = pt.join(-1, zeros_left, x, zeros_right)
     else:
-        raise ValueError(f"Wrong Mode: {mode}, expected of ConvMode")
-
-    for i in range(l_max):
-        padded_x = pt.set_subtensor(padded_x[..., i : x_time + i, i], x)
-
-    padded_x = padded_x[..., window, :]
+        raise ValueError(f"Wrong Mode: {mode}, expected one of {', '.join(ConvMode)}")
 
-    # The convolution is treated as an element-wise product, that then gets reduced
-    # along the dimension that represents the convolution time lags
-    conv = pt.sum(padded_x * w[..., None, :], axis=-1)
-    # Move the "time" axis back to where it was in the original x array
-    return pt.moveaxis(conv, -1, axis + conv.ndim - orig_ndim)
+    conv = pt.signal.convolve1d(padded_x, w, mode="valid")
+    return pt.moveaxis(conv, -1, axis + conv.ndim - x.ndim)
 
 
 def geometric_adstock(
diff --git a/pyproject.toml b/pyproject.toml
index 22acc1877..78054be09 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,7 +31,8 @@ dependencies = [
     "pandas",
     "pydantic>=2.1.0",
     # NOTE: Used as minimum pymc version with test.yml `OLDEST_PYMC_VERSION`
-    "pymc>=5.21.1",
+    "pymc>=5.22.0",
+    "pytensor>=2.30.3",
     "scikit-learn>=1.1.1",
     "seaborn>=0.12.2",
     "xarray>=2024.1.0",
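
Trailing note for reviewers: the public batched_convolution API is unchanged;
the broadcasting previously done via params_broadcast_shapes now comes from
the batched convolve1d itself. A quick smoke test, with made-up shapes and
seeds chosen only for illustration, might look like:

    import numpy as np
    from pymc_marketing.mmm.transformers import ConvMode, batched_convolution

    x = np.random.default_rng(0).random((3, 10))  # 3 channels, 10 time steps
    w = np.random.default_rng(1).random((3, 4))   # one 4-lag kernel per channel

    # time is the last axis here; output keeps the input's shape
    out = batched_convolution(x, w, axis=-1, mode=ConvMode.After)
    assert out.eval().shape == (3, 10)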