@@ -21,7 +21,7 @@
 import pymc as pm
 import pytensor.tensor as pt
 from pymc.distributions.dist_math import check_parameters
-from pytensor.tensor.random.utils import params_broadcast_shapes
+from pytensor.npy_2_compat import normalize_axis_index


 class ConvMode(str, Enum):
@@ -103,58 +103,29 @@ def batched_convolution(
     """
     # We move the axis to the last dimension of the array so that it's easier to
     # reason about parameter broadcasting. We will move the axis back at the end
-    orig_ndim = x.ndim
-    axis = axis if axis >= 0 else orig_ndim + axis
+    x = pt.as_tensor(x)
     w = pt.as_tensor(w)
+
+    axis = normalize_axis_index(axis, x.ndim)
     x = pt.moveaxis(x, axis, -1)
-    l_max = w.type.shape[-1]
-    if l_max is None:
-        try:
-            l_max = w.shape[-1].eval()
-        except Exception:  # noqa: S110
-            pass
-    # Get the broadcast shapes of x and w but ignoring their last dimension.
-    # The last dimension of x is the "time" axis, which doesn't get broadcast
-    # The last dimension of w is the number of time steps that go into the convolution
-    x_shape, w_shape = params_broadcast_shapes([x.shape, w.shape], [1, 1])
-
-    x = pt.broadcast_to(x, x_shape)
-    w = pt.broadcast_to(w, w_shape)
-    x_time = x.shape[-1]
-    # Make a tensor with x at the different time lags needed for the convolution
-    x_shape = x.shape
-    # Add the size of the kernel to the time axis
-    shape = (*x_shape[:-1], x_shape[-1] + w.shape[-1] - 1, w.shape[-1])
-    padded_x = pt.zeros(shape, dtype=x.dtype)
-
-    if l_max is None:  # pragma: no cover
-        raise NotImplementedError(
-            "At the moment, convolving with weight arrays that don't have a concrete shape "
-            "at compile time is not supported."
-        )
-    # The window is the slice of the padded array that corresponds to the original x
-    if l_max <= 1:
-        window = slice(None)
+    x_batch_shape = tuple(x.shape)[:-1]
+    lags = w.shape[-1]
+
+    if mode == ConvMode.After:
+        zeros = pt.zeros((*x_batch_shape, lags - 1), dtype=x.dtype)
+        padded_x = pt.join(-1, zeros, x)
     elif mode == ConvMode.Before:
-        window = slice(l_max - 1, None)
-    elif mode == ConvMode.After:
-        window = slice(None, -l_max + 1)
+        zeros = pt.zeros((*x_batch_shape, lags - 1), dtype=x.dtype)
+        padded_x = pt.join(-1, x, zeros)
     elif mode == ConvMode.Overlap:
-        # Handle even and odd l_max differently if l_max is odd then we can split evenly otherwise we drop from the end
-        window = slice((l_max // 2) - (1 if l_max % 2 == 0 else 0), -(l_max // 2))
+        zeros_left = pt.zeros((*x_batch_shape, lags // 2), dtype=x.dtype)
+        zeros_right = pt.zeros((*x_batch_shape, (lags - 1) // 2), dtype=x.dtype)
+        padded_x = pt.join(-1, zeros_left, x, zeros_right)
     else:
-        raise ValueError(f"Wrong Mode: {mode}, expected of ConvMode")
-
-    for i in range(l_max):
-        padded_x = pt.set_subtensor(padded_x[..., i : x_time + i, i], x)
-
-    padded_x = padded_x[..., window, :]
+        raise ValueError(f"Wrong Mode: {mode}, expected one of {', '.join(ConvMode)}")

-    # The convolution is treated as an element-wise product, that then gets reduced
-    # along the dimension that represents the convolution time lags
-    conv = pt.sum(padded_x * w[..., None, :], axis=-1)
-    # Move the "time" axis back to where it was in the original x array
-    return pt.moveaxis(conv, -1, axis + conv.ndim - orig_ndim)
+    conv = pt.signal.convolve1d(padded_x, w, mode="valid")
+    return pt.moveaxis(conv, -1, axis + conv.ndim - x.ndim)


 def geometric_adstock(
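To see what the three padding modes do, here is a small sketch (not part of the diff) that applies the refactored function to an impulse. It assumes the function lives at the usual pymc_marketing.mmm.transformers path; the commented outputs are worked out by hand from the zero-padding plus "valid"-mode convolution, and they match what the old sliding-window code produced. Because each mode pads the time axis with lags - 1 zeros before the "valid" convolution, the output always has the same time length as the input.

# Sketch: impulse response of batched_convolution under each ConvMode.
import numpy as np
from pymc_marketing.mmm.transformers import ConvMode, batched_convolution

x = np.array([0.0, 0.0, 1.0, 0.0, 0.0])  # impulse at t=2
w = np.array([1.0, 0.5, 0.25])           # three lag weights

after = batched_convolution(x, w, mode=ConvMode.After).eval()
# -> [0., 0., 1., 0.5, 0.25]: the effect starts at the impulse and decays forward
before = batched_convolution(x, w, mode=ConvMode.Before).eval()
# -> [1., 0.5, 0.25, 0., 0.]: the effect leads up to the impulse
overlap = batched_convolution(x, w, mode=ConvMode.Overlap).eval()
# -> [0., 1., 0.5, 0.25, 0.]: the effect straddles the impulse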
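A second sketch of the payoff of the rewrite: pt.signal.convolve1d broadcasts leading batch dimensions of the padded series and the kernel against each other, which is why the manual params_broadcast_shapes / pt.broadcast_to step could be dropped. The shapes here are illustrative, and the module path is assumed as above.

# Sketch: batch dims of x and w broadcast without any explicit broadcasting step.
import numpy as np
from pymc_marketing.mmm.transformers import ConvMode, batched_convolution

x = np.random.rand(5, 100)   # 5 series, 100 time steps, time on the last axis
w = np.random.rand(3, 1, 7)  # 3 kernels of 7 lags each, broadcast over the series

out = batched_convolution(x, w, axis=-1, mode=ConvMode.After)
print(out.eval().shape)  # (3, 5, 100): batch dims broadcast, time length preserved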