Skip to content

Commit 2378384

Browse files
committed
Replace batched_convolution with pytensor native impl
1 parent a891405 commit 2378384

File tree

3 files changed

+22
-49
lines changed

3 files changed

+22
-49
lines changed

environment.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ dependencies:
1616
- preliz
1717
- pyprojroot
1818
# NOTE: Keep minimum pymc version in sync with ci.yml `OLDEST_PYMC_VERSION`
19-
- pymc>=5.21.1
19+
- pymc>=5.22.0
20+
- pytensor>=2.30.3
2021
- pymc-extras>=0.2.1
2122
- blackjax>=1.2.4
2223
- scikit-learn>=1.1.1

pymc_marketing/mmm/transformers.py

Lines changed: 18 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import pymc as pm
2222
import pytensor.tensor as pt
2323
from pymc.distributions.dist_math import check_parameters
24-
from pytensor.tensor.random.utils import params_broadcast_shapes
24+
from pytensor.npy_2_compat import normalize_axis_index
2525

2626

2727
class ConvMode(str, Enum):
@@ -103,58 +103,29 @@ def batched_convolution(
103103
"""
104104
# We move the axis to the last dimension of the array so that it's easier to
105105
# reason about parameter broadcasting. We will move the axis back at the end
106-
orig_ndim = x.ndim
107-
axis = axis if axis >= 0 else orig_ndim + axis
106+
x = pt.as_tensor(x)
108107
w = pt.as_tensor(w)
108+
109+
axis = normalize_axis_index(axis, x.ndim)
109110
x = pt.moveaxis(x, axis, -1)
110-
l_max = w.type.shape[-1]
111-
if l_max is None:
112-
try:
113-
l_max = w.shape[-1].eval()
114-
except Exception: # noqa: S110
115-
pass
116-
# Get the broadcast shapes of x and w but ignoring their last dimension.
117-
# The last dimension of x is the "time" axis, which doesn't get broadcast
118-
# The last dimension of w is the number of time steps that go into the convolution
119-
x_shape, w_shape = params_broadcast_shapes([x.shape, w.shape], [1, 1])
120-
121-
x = pt.broadcast_to(x, x_shape)
122-
w = pt.broadcast_to(w, w_shape)
123-
x_time = x.shape[-1]
124-
# Make a tensor with x at the different time lags needed for the convolution
125-
x_shape = x.shape
126-
# Add the size of the kernel to the time axis
127-
shape = (*x_shape[:-1], x_shape[-1] + w.shape[-1] - 1, w.shape[-1])
128-
padded_x = pt.zeros(shape, dtype=x.dtype)
129-
130-
if l_max is None: # pragma: no cover
131-
raise NotImplementedError(
132-
"At the moment, convolving with weight arrays that don't have a concrete shape "
133-
"at compile time is not supported."
134-
)
135-
# The window is the slice of the padded array that corresponds to the original x
136-
if l_max <= 1:
137-
window = slice(None)
111+
x_batch_shape = tuple(x.shape)[:-1]
112+
lags = w.shape[-1]
113+
114+
if mode == ConvMode.After:
115+
zeros = pt.zeros((*x_batch_shape, lags - 1), dtype=x.dtype)
116+
padded_x = pt.join(-1, zeros, x)
138117
elif mode == ConvMode.Before:
139-
window = slice(l_max - 1, None)
140-
elif mode == ConvMode.After:
141-
window = slice(None, -l_max + 1)
118+
zeros = pt.zeros((*x_batch_shape, lags - 1), dtype=x.dtype)
119+
padded_x = pt.join(-1, x, zeros)
142120
elif mode == ConvMode.Overlap:
143-
# Handle even and odd l_max differently if l_max is odd then we can split evenly otherwise we drop from the end
144-
window = slice((l_max // 2) - (1 if l_max % 2 == 0 else 0), -(l_max // 2))
121+
zeros_left = pt.zeros((*x_batch_shape, lags // 2), dtype=x.dtype)
122+
zeros_right = pt.zeros((*x_batch_shape, (lags - 1) // 2), dtype=x.dtype)
123+
padded_x = pt.join(-1, zeros_left, x, zeros_right)
145124
else:
146-
raise ValueError(f"Wrong Mode: {mode}, expected of ConvMode")
147-
148-
for i in range(l_max):
149-
padded_x = pt.set_subtensor(padded_x[..., i : x_time + i, i], x)
150-
151-
padded_x = padded_x[..., window, :]
125+
raise ValueError(f"Wrong Mode: {mode}, expected one of {', '.join(ConvMode)}")
152126

153-
# The convolution is treated as an element-wise product, that then gets reduced
154-
# along the dimension that represents the convolution time lags
155-
conv = pt.sum(padded_x * w[..., None, :], axis=-1)
156-
# Move the "time" axis back to where it was in the original x array
157-
return pt.moveaxis(conv, -1, axis + conv.ndim - orig_ndim)
127+
conv = pt.signal.convolve1d(padded_x, w, mode="valid")
128+
return pt.moveaxis(conv, -1, axis + conv.ndim - x.ndim)
158129

159130

160131
def geometric_adstock(

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ dependencies = [
3131
"pandas",
3232
"pydantic>=2.1.0",
3333
# NOTE: Used as minimum pymc version with test.yml `OLDEST_PYMC_VERSION`
34-
"pymc>=5.21.1",
34+
"pymc>=5.22.0",
35+
"pytensor>=2.30.3",
3536
"scikit-learn>=1.1.1",
3637
"seaborn>=0.12.2",
3738
"xarray>=2024.1.0",

0 commit comments

Comments
 (0)