Commit 6ad51ac

Generalize local_subtensor_of_elemwise to Blockwise
1 parent ba63aa1 · commit 6ad51ac

2 files changed: +71 -12 lines changed


pytensor/tensor/rewriting/subtensor_lift.py

Lines changed: 32 additions & 6 deletions
@@ -20,6 +20,7 @@
     join,
     register_infer_shape,
 )
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.extra_ops import squeeze
@@ -169,16 +170,16 @@ def local_subtensor_of_dot(fgraph, node):
 @register_canonicalize("shape_unsafe")
 @register_specialize("shape_unsafe")
 @node_rewriter([Subtensor])
-def local_subtensor_of_elemwise(fgraph, node):
-    """Lift a Subtensor through an Elemwise and its implicit broadcasting behavior.
+def local_subtensor_of_batch_dims(fgraph, node):
+    """Lift a Subtensor through the batch dims of an (Elemwise or Blockwise) operation and its implicit broadcasting behavior.
 
     exp(x)[:, 0] -> exp(x[:, 0])
     add(x, y)[0] -> add(x[0], y[0])
     add(x[None], y)[2] -> add(x, y[2])
     """
     elem, *idx = node.inputs
 
-    if not (elem.owner and isinstance(elem.owner.op, Elemwise)):
+    if not (elem.owner and isinstance(elem.owner.op, Elemwise | Blockwise)):
         return None
 
     if len(fgraph.clients[elem]) > 1:
@@ -188,9 +189,34 @@ def local_subtensor_of_elemwise(fgraph, node):
 
     idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
 
+    batch_ndim = (
+        elem.owner.op.batch_ndim(elem.owner)
+        if isinstance(elem.owner.op, Blockwise)
+        else elem.ndim
+    )
+
+    if len(idx_tuple) > batch_ndim:
+        # Indexing on core dimensions of Blockwise. We split the indices and lift the batch ones only
+        batch_indices, core_indices = idx_tuple[:batch_ndim], idx_tuple[batch_ndim:]
+        if all(is_full_slice(idx) for idx in batch_indices):
+            # No batch indices, nothing to do
+            return None
+        elem_with_batch_indices = elem[batch_indices]
+        [elem_with_batch_indices_lifted] = local_subtensor_of_batch_dims.transform(
+            fgraph, elem_with_batch_indices.owner
+        )
+        # Reapply the core_indices
+        core_ndim = elem.type.ndim - batch_ndim
+        # Number of batch dims may have changed with the lifting of indices, so we recompute
+        new_batch_ndim = elem_with_batch_indices_lifted.type.ndim - core_ndim
+        new_indices = (*(slice(None),) * new_batch_ndim, *core_indices)
+        new_elem = elem_with_batch_indices_lifted[new_indices]
+        copy_stack_trace(node.outputs[0], new_elem)
+        return [new_elem]
+
     elem_inputs = elem.owner.inputs
-    elem_bcast = elem.type.broadcastable
-    if all(inp.type.broadcastable == elem_bcast for inp in elem_inputs):
+    elem_bcast = elem.type.broadcastable[:batch_ndim]
+    if all(inp.type.broadcastable[:batch_ndim] == elem_bcast for inp in elem_inputs):
         # No need to worry about implicit broadcasting.
         indexed_inputs = [inp[idx_tuple] for inp in elem_inputs]
 
@@ -201,7 +227,7 @@ def local_subtensor_of_elemwise(fgraph, node):
         zip(
             idx_tuple,
             elem_bcast,
-            *(inp.type.broadcastable for inp in elem_inputs),
+            *(inp.type.broadcastable[:batch_ndim] for inp in elem_inputs),
             # Indices can be shorter than input ndims
             strict=False,
         )
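The trickiest part of the new branch is the index bookkeeping when a Blockwise output is indexed past its batch dimensions. The sketch below is not part of the commit; it only mimics that bookkeeping with plain tuples, and the helper names (split_indices, reapply_core_indices) are made up for illustration.

# Illustrative sketch (not from the commit): how the rewrite splits an index
# tuple between batch and core dimensions and re-applies the core part.

def split_indices(idx_tuple, batch_ndim):
    """Split a full index tuple into (batch_indices, core_indices)."""
    return idx_tuple[:batch_ndim], idx_tuple[batch_ndim:]


def reapply_core_indices(core_indices, new_batch_ndim):
    """Prefix the core indices with full slices over whatever batch
    dimensions remain after the batch indices have been lifted."""
    return (*(slice(None),) * new_batch_ndim, *core_indices)


# Example: a Blockwise output with 2 batch dims and 1 core dim,
# indexed as out[2:, 0, 4:].
idx_tuple = (slice(2, None), 0, slice(4, None))
core_ndim = 1
batch_ndim = 3 - core_ndim

batch_idx, core_idx = split_indices(idx_tuple, batch_ndim)
# batch_idx == (slice(2, None), 0)  -> lifted into the Blockwise inputs
# core_idx  == (slice(4, None),)    -> stays outside the Blockwise

# The integer index 0 consumed one batch dimension, so only one remains.
new_batch_ndim = 1
assert reapply_core_indices(core_idx, new_batch_ndim) == (slice(None), slice(4, None))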

tests/tensor/rewriting/test_subtensor_lift.py

Lines changed: 39 additions & 6 deletions
@@ -37,14 +37,16 @@
     vector,
 )
 from pytensor.tensor.basic import MakeVector, concatenate, expand_dims, make_vector
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.math import sum as pt_sum
 from pytensor.tensor.rewriting.subtensor_lift import (
     local_subtensor_make_vector,
-    local_subtensor_of_elemwise,
+    local_subtensor_of_batch_dims,
     local_subtensor_shape_constant,
 )
 from pytensor.tensor.shape import SpecifyShape, _shape
+from pytensor.tensor.signal import convolve1d
 from pytensor.tensor.special import softmax
 from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor
 
@@ -58,7 +60,7 @@
 NO_OPTIMIZATION_MODE = Mode(linker="py", optimizer=None)
 
 
-class TestLocalSubtensorOfElemwise:
+class TestLocalSubtensorOfBatchDims:
     def test_unary_multiple_clients(self):
         # as test0, but we reuse the output of the elemwise
         # So we should not lift the subtensor
@@ -144,7 +146,7 @@ def test_multinary_multiple_clients(self):
             ),
         ],
     )
-    def test_local_subtensor_of_elemwise(self, original_fn, expected_fn):
+    def test_elemwise(self, original_fn, expected_fn):
        rng = np.random.default_rng(257)
        x = pt.matrix("x", shape=(5, 3))
        y = pt.matrix("y", shape=(5, 3))
@@ -163,19 +165,50 @@ def test_local_subtensor_of_elemwise(self, original_fn, expected_fn):
             out.eval({x: x_test, y: y_test}, **eval_kwargs),
         )
 
-    def test_local_subtensor_of_elemwise_multiple_clients(self):
+    def test_elemwise_multiple_clients(self):
         x = pt.matrix("x", shape=(5, 3))
         y = pt.matrix("y", shape=(5, 3))
         out1 = add(x, y)
         out2 = out1[0]
 
         # Rewrite should fail when another node uses out1 directly (in this case it's an extra output)
         fgraph = FunctionGraph([x, y], [out1, out2], clone=False)
-        assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is None
+        assert local_subtensor_of_batch_dims.transform(fgraph, out2.owner) is None
 
         # Otherwise it should work
         fgraph.remove_output(0)
-        assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is not None
+        assert local_subtensor_of_batch_dims.transform(fgraph, out2.owner) is not None
+
+    def test_blockwise(self):
+        x = tensor3("x", shape=(7, 5, 11))
+        y = tensor("y", shape=(7, 33))
+        out = convolve1d(x, y[:, None, :])
+        assert isinstance(out.owner.op, Blockwise)
+
+        out_sliced = out[2:][:, 3:]
+        rewritten_out_sliced = rewrite_graph(out_sliced)
+        assert equal_computations(
+            [rewritten_out_sliced], [convolve1d(x[2:, 3:], y[2:][:, None, :])]
+        )
+
+        rng = np.random.default_rng(191)
+        x_test = rng.normal(size=x.type.shape).astype(x.type.dtype)
+        y_test = rng.normal(size=y.type.shape).astype(y.type.dtype)
+        np.testing.assert_allclose(
+            rewritten_out_sliced.eval(
+                {x: x_test, y: y_test}, mode=NO_OPTIMIZATION_MODE
+            ),
+            out_sliced.eval({x: x_test, y: y_test}, mode=NO_OPTIMIZATION_MODE),
+        )
+
+        # Check slice on core dims
+        # Note: if we implement a rewrite on the core dims, this test should be changed for another Blockwise
+        # that has no such rewrite or one created just for testing purposes
+        out_sliced = out[2:][:, 0][:, 4:]
+        rewritten_out_sliced = rewrite_graph(out_sliced)
+        assert equal_computations(
+            [rewritten_out_sliced], [convolve1d(x[2:, 0], y[2:])[:, 4:]]
+        )
 
 
 @pytest.mark.parametrize(
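Outside the test suite, the renamed rewrite can be exercised in isolation the same way the new tests do. Below is a minimal sketch, not a definitive usage recipe, assuming the entry points used in this diff (FunctionGraph, equal_computations, local_subtensor_of_batch_dims.transform) are importable from their usual PyTensor locations.

import pytensor.tensor as pt
from pytensor.graph import FunctionGraph
from pytensor.graph.basic import equal_computations
from pytensor.tensor.rewriting.subtensor_lift import local_subtensor_of_batch_dims

x = pt.matrix("x", shape=(5, 3))
y = pt.matrix("y", shape=(5, 3))
out = pt.add(x, y)[0]  # Subtensor applied on top of an Elemwise

# The rewrite only fires when the sliced output is the sole client of the
# Elemwise, so this graph contains nothing else that uses add(x, y).
fgraph = FunctionGraph([x, y], [out], clone=False)
[lifted] = local_subtensor_of_batch_dims.transform(fgraph, out.owner)

# The index is pushed onto both inputs: add(x, y)[0] -> add(x[0], y[0])
assert equal_computations([lifted], [pt.add(x[0], y[0])])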
