37
37
vector ,
38
38
)
39
39
from pytensor .tensor .basic import MakeVector , concatenate , expand_dims , make_vector
40
+ from pytensor .tensor .blockwise import Blockwise
40
41
from pytensor .tensor .elemwise import DimShuffle , Elemwise
41
42
from pytensor .tensor .math import sum as pt_sum
42
43
from pytensor .tensor .rewriting .subtensor_lift import (
43
44
local_subtensor_make_vector ,
44
- local_subtensor_of_elemwise ,
45
+ local_subtensor_of_batch_dims ,
45
46
local_subtensor_shape_constant ,
46
47
)
47
48
from pytensor .tensor .shape import SpecifyShape , _shape
49
+ from pytensor .tensor .signal import convolve1d
48
50
from pytensor .tensor .special import softmax
49
51
from pytensor .tensor .subtensor import AdvancedSubtensor , Subtensor
50
52
58
60
NO_OPTIMIZATION_MODE = Mode (linker = "py" , optimizer = None )
59
61
60
62
61
- class TestLocalSubtensorOfElemwise :
63
+ class TestLocalSubtensorOfBatchDims :
62
64
def test_unary_multiple_clients (self ):
63
65
# as test0, but we reuse the output of the elemwise
64
66
# So we should not lift the subtensor
@@ -144,7 +146,7 @@ def test_multinary_multiple_clients(self):
144
146
),
145
147
],
146
148
)
147
- def test_local_subtensor_of_elemwise (self , original_fn , expected_fn ):
149
+ def test_elemwise (self , original_fn , expected_fn ):
148
150
rng = np .random .default_rng (257 )
149
151
x = pt .matrix ("x" , shape = (5 , 3 ))
150
152
y = pt .matrix ("y" , shape = (5 , 3 ))
@@ -163,19 +165,50 @@ def test_local_subtensor_of_elemwise(self, original_fn, expected_fn):
163
165
out .eval ({x : x_test , y : y_test }, ** eval_kwargs ),
164
166
)
165
167
166
- def test_local_subtensor_of_elemwise_multiple_clients (self ):
168
+ def test_elemwise_multiple_clients (self ):
167
169
x = pt .matrix ("x" , shape = (5 , 3 ))
168
170
y = pt .matrix ("y" , shape = (5 , 3 ))
169
171
out1 = add (x , y )
170
172
out2 = out1 [0 ]
171
173
172
174
# Rewrite should fail when another node uses out1 directly (in this case it's an extra output)
173
175
fgraph = FunctionGraph ([x , y ], [out1 , out2 ], clone = False )
174
- assert local_subtensor_of_elemwise .transform (fgraph , out2 .owner ) is None
176
+ assert local_subtensor_of_batch_dims .transform (fgraph , out2 .owner ) is None
175
177
176
178
# Otherwise it should work
177
179
fgraph .remove_output (0 )
178
- assert local_subtensor_of_elemwise .transform (fgraph , out2 .owner ) is not None
180
+ assert local_subtensor_of_batch_dims .transform (fgraph , out2 .owner ) is not None
181
+
182
+ def test_blockwise (self ):
183
+ x = tensor3 ("x" , shape = (7 , 5 , 11 ))
184
+ y = tensor ("y" , shape = (7 , 33 ))
185
+ out = convolve1d (x , y [:, None , :])
186
+ assert isinstance (out .owner .op , Blockwise )
187
+
188
+ out_sliced = out [2 :][:, 3 :]
189
+ rewritten_out_sliced = rewrite_graph (out_sliced )
190
+ assert equal_computations (
191
+ [rewritten_out_sliced ], [convolve1d (x [2 :, 3 :], y [2 :][:, None , :])]
192
+ )
193
+
194
+ rng = np .random .default_rng (191 )
195
+ x_test = rng .normal (size = x .type .shape ).astype (x .type .dtype )
196
+ y_test = rng .normal (size = y .type .shape ).astype (y .type .dtype )
197
+ np .testing .assert_allclose (
198
+ rewritten_out_sliced .eval (
199
+ {x : x_test , y : y_test }, mode = NO_OPTIMIZATION_MODE
200
+ ),
201
+ out_sliced .eval ({x : x_test , y : y_test }, mode = NO_OPTIMIZATION_MODE ),
202
+ )
203
+
204
+ # Check slice on core dims
205
+ # Note: if we implement a rewrite on the core dims, this test should be changed for another Blockwise
206
+ # that has no such rewrite or one created just for testing purposes
207
+ out_sliced = out [2 :][:, 0 ][:, 4 :]
208
+ rewritten_out_sliced = rewrite_graph (out_sliced )
209
+ assert equal_computations (
210
+ [rewritten_out_sliced ], [convolve1d (x [2 :, 0 ], y [2 :])[:, 4 :]]
211
+ )
179
212
180
213
181
214
@pytest .mark .parametrize (
0 commit comments