Skip to content

Commit 7a1699b

Browse files
committed
Cache is_matrix_transpose at DimShuffle Op level
1 parent 8d10764 commit 7a1699b

File tree

2 files changed

+11
-20
lines changed

2 files changed

+11
-20
lines changed

pytensor/tensor/elemwise.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,7 @@ def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]):
159159
# This is the list of the original dimensions that we keep
160160
self.shuffle = [x for x in new_order if x != "x"]
161161
self.transposition = self.shuffle + drop
162-
# List of dimensions of the output that are broadcastable and were not
163-
# in the original input
162+
# List of dimensions of the output that are broadcastable and were not in the original input
164163
self.augment = augment = sorted(i for i, x in enumerate(new_order) if x == "x")
165164
self.drop = drop
166165

@@ -175,6 +174,12 @@ def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]):
175174
self.is_right_expand_dims = self.is_expand_dims and new_order[
176175
:input_ndim
177176
] == list(range(input_ndim))
177+
self.is_matrix_transpose = False
178+
if dims_are_shuffled and (not drop) and input_ndim >= 2:
179+
# We consider a matrix transpose if we only flip the last two dims
180+
# Regardless of whethre there's an expand_dims or not
181+
mt_pattern = [*range(input_ndim - 2), input_ndim - 1, input_ndim - 2]
182+
self.is_matrix_transpose = new_order[len(augment) :] == mt_pattern
178183

179184
def __setstate__(self, state):
180185
self.__dict__.update(state)

pytensor/tensor/rewriting/linalg.py

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -66,25 +66,11 @@
6666
def is_matrix_transpose(x: TensorVariable) -> bool:
6767
"""Check if a variable corresponds to a transpose of the last two axes"""
6868
node = x.owner
69-
if (
70-
node
69+
return (
70+
node is not None
7171
and isinstance(node.op, DimShuffle)
72-
and not (node.op.drop or node.op.augment)
73-
):
74-
[inp] = node.inputs
75-
ndims = inp.type.ndim
76-
if ndims < 2:
77-
return False
78-
transpose_order = (*range(ndims - 2), ndims - 1, ndims - 2)
79-
80-
# Allow expand_dims on the left of the transpose
81-
if (diff := len(transpose_order) - len(node.op.new_order)) > 0:
82-
transpose_order = (
83-
*(["x"] * diff),
84-
*transpose_order,
85-
)
86-
return node.op.new_order == transpose_order
87-
return False
72+
and node.op.is_matrix_transpose
73+
)
8874

8975

9076
@register_canonicalize

0 commit comments

Comments
 (0)