
Commit 480823b

Generalize local_subtensor_of_dot to work with batch dims
1 parent 6ad51ac commit 480823b
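
For context: local_subtensor_of_dot lifts a Subtensor through a dot product so that indexing is applied to the operands before the multiplication. With this commit the same lifting also fires when the dot is wrapped in a Blockwise, i.e. when it has leading batch dimensions. A minimal sketch of the kind of graph this now covers, assuming pt.matmul lowers batched inputs to a Blockwise around Dot as in current PyTensor (the names and shapes are illustrative, not taken from the commit):

    import pytensor
    import pytensor.tensor as pt

    # The leading dimension of size 5 is a batch dimension handled by Blockwise
    a = pt.tensor("a", shape=(5, 3, 4))
    b = pt.tensor("b", shape=(5, 4, 2))

    # Index into the core (non-batch) dimensions of the batched dot
    out = (a @ b)[:, 1:]

    # After rewriting, the slice is expected to be applied to `a` before the
    # dot, roughly (a[:, 1:] @ b), so fewer rows enter the multiplication
    fn = pytensor.function([a, b], out)
    pytensor.dprint(fn)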

File tree

1 file changed (+41 −23 lines)

pytensor/tensor/rewriting/subtensor_lift.py

Lines changed: 41 additions & 23 deletions
@@ -5,7 +5,7 @@
 
 from pytensor import Variable
 from pytensor.compile import optdb
-from pytensor.graph import Constant, FunctionGraph, node_rewriter
+from pytensor.graph import Constant, FunctionGraph, node_rewriter, vectorize_graph
 from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace
 from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple
 from pytensor.scalar import basic as ps
@@ -119,21 +119,43 @@ def local_subtensor_of_dot(fgraph, node):
     the remaining entries of ``idxs`` (if any), modified to skip the
     second-to-last dimension of ``B`` (because dot sums over this dimension).
     """
-    if not isinstance(node.op, Subtensor):
-        return
-    if not (node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Dot)):
+    x, *idx_vars = node.inputs
+    if not (
+        x.owner is not None
+        and (
+            isinstance(x.owner.op, Dot)
+            or (
+                isinstance(x.owner.op, Blockwise)
+                and isinstance(x.owner.op.core_op, Dot)
+            )
+        )
+    ):
         return
     # If there is other node that use the outputs of the dot
     # We don't want to compute twice the sub part.
-    if len(fgraph.clients[node.inputs[0]]) > 1:
+    if len(fgraph.clients[x]) > 1:
         return
 
-    a = node.inputs[0].owner.inputs[0]
-    b = node.inputs[0].owner.inputs[1]
+    a = x.owner.inputs[0]
+    b = x.owner.inputs[1]
+    idx_list = indices_from_subtensor(idx_vars, node.op.idx_list)
 
-    idx_list = get_idx_list(node.inputs, node.op.idx_list)
+    batch_ndim = (
+        x.owner.op.batch_ndim(x.owner) if isinstance(x.owner.op, Blockwise) else 0
+    )
+
+    if batch_ndim:
+        batch_idx_list, idx_list = idx_list[:batch_ndim], idx_list[batch_ndim:]
+        if not idx_list:
+            # Indexing only over batch dimensions of Blockwise, that can be handled by another rewrite
+            return None
+        # We perform the rest of the rewrite on dummy a, b that correspond to the core case
+        a = a.type.clone(shape=a.type.shape[batch_ndim:])()
+        b = b.type.clone(shape=b.type.shape[batch_ndim:])()
 
-    num_a_indices = min(a.ndim - 1, len(idx_list))
+    a_ndim = a.ndim
+    b_ndim = b.ndim
+    num_a_indices = min(a_ndim - 1, len(idx_list))
     a_indices = idx_list[:num_a_indices]
     b_indices = idx_list[num_a_indices:]
 
@@ -142,26 +164,22 @@ def local_subtensor_of_dot(fgraph, node):
     # This wasn't necessary for a, because we just omitted the last index.
     # We skip this if b.ndim = 1, since then we just want b_sub = b, not b_sub = b[:]
     # (dot also handles b.ndim < 2 as a special case)
-    if b.ndim > 1 and len(b_indices) >= b.ndim - 1:
+    if b_ndim > 1 and len(b_indices) >= b_ndim - 1:
         b_indices = (
-            b_indices[: b.ndim - 2]
+            b_indices[: b_ndim - 2]
             + (slice(None, None, None),)
-            + b_indices[b.ndim - 2 :]
+            + b_indices[b_ndim - 2 :]
         )
 
-    a_sub = a.__getitem__(tuple(a_indices))
-    b_sub = b.__getitem__(tuple(b_indices)) if b_indices else b
+    a_sub = a[tuple(a_indices)]
+    b_sub = b[tuple(b_indices)] if b_indices else b
+    r = dot(a_sub, b_sub)
 
-    # Copy over previous output stacktrace to a_sub and b_sub,
-    # because an error in the subtensor operation (e.g. an index error)
-    # on either a or b must correspond to an error in the
-    # subtensor operation on their dot product.
-    copy_stack_trace(node.outputs[0], [a_sub, b_sub])
+    if batch_ndim:
+        # Replace dummy inputs by the original batch ones
+        r = vectorize_graph(r, replace={a: x.owner.inputs[0], b: x.owner.inputs[1]})
+        r = r[tuple(batch_idx_list)]
 
-    # Copy over previous output stacktrace and previous dot product stacktrace,
-    # because an error here may correspond to an either in either the original
-    # dot product, or in the dot product after the subtensor operation.
-    r = dot(a_sub, b_sub)
     copy_stack_trace([node.outputs[0], node.inputs[0]], r)
 
     return [r]
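
The batched branch works by running the original core rewrite on dummy inputs with the batch dimensions stripped, then re-batching the result with vectorize_graph before applying the batch-dimension indices. A standalone sketch of that pattern, with illustrative shapes that are not part of the commit:

    import pytensor.tensor as pt
    from pytensor.graph import vectorize_graph

    # Batched operands; the leading 5 is the batch dimension
    a = pt.tensor("a", shape=(5, 3, 4))
    b = pt.tensor("b", shape=(5, 4, 2))

    # Dummy "core" variables with the batch dimension stripped
    a_core = a.type.clone(shape=a.type.shape[1:])()
    b_core = b.type.clone(shape=b.type.shape[1:])()

    # Build the rewritten core graph on the dummies: index before the dot
    core_out = pt.dot(a_core[1:], b_core)

    # Map the core graph back over the batch dimension of the real inputs,
    # giving the equivalent of a[:, 1:] @ b
    batched_out = vectorize_graph(core_out, replace={a_core: a, b_core: b})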
