Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pytensor/tensor/signal/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ def convolve1d(
if mode == "same":
# We implement "same" as "valid" with padded `in1`.
in1_batch_shape = tuple(in1.shape)[:-1]
zeros_left = in2.shape[0] // 2
zeros_right = (in2.shape[0] - 1) // 2
zeros_left = in2.shape[-1] // 2
zeros_right = (in2.shape[-1] - 1) // 2
in1 = join(
-1,
zeros((*in1_batch_shape, zeros_left), dtype=in2.dtype),
Expand Down
13 changes: 13 additions & 0 deletions tests/tensor/signal/test_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,16 @@ def test_convolve1d_batch():
res_np = np.convolve(x_test[0], y_test[0])
np.testing.assert_allclose(res[0], res_np, rtol=rtol)
np.testing.assert_allclose(res[1], res_np, rtol=rtol)


def test_convolve1d_same():
x = matrix("data")
y = matrix("kernel")
out = convolve1d(x, y, mode="same")

rng = np.random.default_rng(38)
x_test = rng.normal(size=(2, 8)).astype(x.dtype)
y_test = rng.normal(size=(2, 8)).astype(x.dtype)

res = out.eval({x: x_test, y: y_test})
assert res.shape == (2, 8)