Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -2479,16 +2479,6 @@ def _gather_lowering_rule(

@register_lowering_rule(lax.transpose_p)
def _transpose_lowering_rule(ctx: LoweringRuleContext, x, *, permutation):
dim_size = len(ctx.avals_in[0].shape)
if (
permutation[-2:] != (dim_size - 1, dim_size - 2)
and permutation[-2:] != (dim_size - 2, dim_size - 1)
and len(permutation) != 3
and permutation[-3:] != (dim_size - 2, dim_size - 3, dim_size - 1)
):
raise NotImplementedError(
f"Unsupported transpose permutation: {permutation}"
)
out_type = aval_to_ir_type(
ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0]
)
Expand Down
151 changes: 117 additions & 34 deletions jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1371,54 +1371,137 @@ FailureOr<Value> canonicalize_transpose(const CanonicalizeContext &ctx,

auto create_or_decompose_transpose =
[&](Value input, ArrayRef<int64_t> permutation) -> Value {
auto create_transpose_op = [&builder](Value input, ArrayRef<int64_t> perm) {
int64_t dim_size = permutation.size();
// Keep track of the current permutation.
SmallVector<int64_t> curr_perm =
llvm::to_vector(llvm::seq<int64_t>(0, dim_size));

auto create_transpose_op = [&builder, &curr_perm](Value input,
ArrayRef<int64_t> perm) {
if (llvm::is_sorted(perm)) {
// Early exit if the permutation doesn't change the order.
return input;
}
auto input_vty = cast<VectorType>(input.getType());
SmallVector<int64_t> new_shape(input_vty.getShape().size());
SmallVector<int64_t> new_perm(perm.size());
for (int i = 0; i < input_vty.getShape().size(); ++i) {
new_shape[i] = input_vty.getShape()[perm[i]];
new_perm[i] = curr_perm[perm[i]];
}
curr_perm = new_perm;
return builder.create<tpu::TransposeOp>(
VectorType::get(new_shape, input_vty.getElementType()), input, perm);
};

// Returns a permutation that permutes from `from` to `to`.
auto get_perm = [&](ArrayRef<int64_t> from, ArrayRef<int64_t> to) {
SmallVector<int64_t> perm(from.size());
DenseMap<int64_t, int64_t> index;
for (int i = 0; i < from.size(); ++i) {
index[from[i]] = i;
}
for (int i = 0; i < to.size(); ++i) {
perm[i] = index[to[i]];
}
return perm;
};

// We support transpositions of the following dims in apply:
//
// (1) 3D transpose for second-minor > major.
// (2) ND transpose between majors and between second-minor and minormost.
// (1) ND transpose between 2nd minor and 3rd minor.
// (2) ND transpose between 2nd minor and minormost.
// (3) ND transpose between 3rd+ minors. (noop)
//
// Other cases can be decomposed into multiple transpositions of (1) and
// (2).
if (permutation == ArrayRef<int64_t>({1, 2, 0}) ||
permutation == ArrayRef<int64_t>({2, 0, 1}) ||
permutation == ArrayRef<int64_t>({2, 1, 0})) {
int64_t dim_size = permutation.size();
bool minormost_to_untiled = permutation[dim_size - 3] == dim_size - 1;
bool unchanged_second_minor = permutation[dim_size - 2] == dim_size - 2;
bool minormost_to_second_minor =
permutation[dim_size - 2] == dim_size - 1;
// For example, given the permutation (2, 1, 0), transpose order is:
// (0, 1, 2) > (0, 2, 1) > (2, 0, 1) > (2, 1, 0)
// | | |
// (2) (1) (2)
// TODO(b/419268277): The decomposition can be generalized to dim_size !=
// 3, but we don't support > 3D transposition between major and
// second-minor in apply.
Value res = input;
if (minormost_to_untiled) {
// Transpose minormost to second-minor to use (1).
res = create_transpose_op(res, ArrayRef<int64_t>({0, 2, 1}));
}
res = create_transpose_op(res, ArrayRef<int64_t>({1, 0, 2}));
if ((minormost_to_untiled && unchanged_second_minor) ||
minormost_to_second_minor) {
// When transposing minormost to major, we need an additional
// transposition between minormost and second-minor to reposition the
// original second-minor if it's unchanged.
res = create_transpose_op(res, ArrayRef<int64_t>({0, 2, 1}));
// Other cases can be decomposed into multiple of the above transpositions.
if (permutation.take_back(2) ==
ArrayRef<int64_t>{dim_size - 2, dim_size - 1} ||
permutation.take_back(2) ==
ArrayRef<int64_t>{dim_size - 1, dim_size - 2} ||
permutation.take_back(3) ==
ArrayRef<int64_t>{dim_size - 2, dim_size - 3, dim_size - 1}) {
return create_transpose_op(input, permutation);
}

// 2D case is supported in apply. Safely assume it's 3D+.
CHECK_GE(dim_size, 3);

SmallVector<int64_t> swap_2nd_minor_minormost =
llvm::to_vector(llvm::seq<int64_t>(0, dim_size));
swap_2nd_minor_minormost[dim_size - 2] = dim_size - 1;
swap_2nd_minor_minormost[dim_size - 1] = dim_size - 2;
SmallVector<int64_t> swap_3rd_minor_2nd_minor =
llvm::to_vector(llvm::seq<int64_t>(0, dim_size));
swap_3rd_minor_2nd_minor[dim_size - 3] = dim_size - 2;
swap_3rd_minor_2nd_minor[dim_size - 2] = dim_size - 3;

// Three stages for the decomposition.
//
// a. Permute minormost to the correct position.
// b. Permute 2nd minor to the correct position.
// c. Permute 3rd+ minors to the correct position.
//
// Given permutation = (3, 2, 1, 0) for example, starting with original
// permutation = (0, 1, 2, 3), the following stages will happen in order:
//
// (2) (3) (1) (2)
// a. (0, 1, 2, 3) --> (0, 1, 3, 2) --> (1, 0, 3, 2) --> (1, 3, 0, 2) -->
// (1, 3, 2, 0)
//
// (3) (1)
// b. (1, 3, 2, 0) --> (3, 1, 2, 0) --> (3, 2, 1, 0)
//
//
// c. n/a

Value res = input;
// Permute minormost to the correct position.
if (curr_perm.back() != permutation.back()) {
// (2).
res = create_transpose_op(res, swap_2nd_minor_minormost);
if (curr_perm.back() != permutation.back()) {
// If it's still not in the correct position, it must be in 3rd+ minors.
// Need (3) to put that dim at 3rd minor and use (1).
SmallVector<int64_t> perm =
llvm::to_vector(llvm::seq<int64_t>(0, dim_size));
std::swap(perm[dim_size - 3], perm[permutation.back()]);
// (3).
res = create_transpose_op(res, perm);
// (1).
res = create_transpose_op(res, swap_3rd_minor_2nd_minor);
// (2).
res = create_transpose_op(res, swap_2nd_minor_minormost);
}
return res;
}
return create_transpose_op(input, permutation);

// Minormost must be in the correct position now.
CHECK_EQ(permutation.back(), curr_perm.back());

// Permute second-minor to the correct position.
if (curr_perm[dim_size - 2] != permutation[dim_size - 2]) {
SmallVector<int64_t> required_perm = get_perm(curr_perm, permutation);
SmallVector<int64_t> perm =
llvm::to_vector(llvm::seq<int64_t>(0, dim_size));
// Need (3) to put the dim supposed at 2nd minor dim to 3rd minor and use
// (1).
std::swap(perm[dim_size - 3], perm[required_perm[dim_size - 2]]);
// (3).
res = create_transpose_op(res, perm);
// (1).
res = create_transpose_op(res, swap_3rd_minor_2nd_minor);
}

// 2nd minor must be in the correct position now.
CHECK_EQ(permutation[dim_size - 2], curr_perm[dim_size - 2]);

// Permute 3rd+ minors with (3).
res = create_transpose_op(res, get_perm(curr_perm, permutation));
for (int i = 0; i < dim_size; ++i) {
CHECK_EQ(permutation[i], curr_perm[i])
<< "permutation[" << i << "] = " << permutation[i] << ", curr_perm["
<< i << "] = " << curr_perm[i];
}
return res;
};

// TODO(mvoz): Even gen 7 support is spotty on all test targets.
Expand Down
13 changes: 13 additions & 0 deletions tests/pallas/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2514,6 +2514,12 @@ def kernel(x_ref, out_ref):
((1, 2, 8, 8, 1), (0, 1, 3, 2, 4)),
((3, 2, 8, 8, 8), (1, 0, 3, 2, 4)),
]
+ [
# Any 5D permutation. It might involve 3rd <-> 2nd minor swap, which
# has a stricter dim size requirement, so we use 8x8x8x8x8 here.
((8, 8, 8, 8, 8), perm)
for perm in itertools.permutations(range(5))
]
)
def test_transpose(self, shape_and_axes):
if jtu.test_device_matches(["gpu"]):
Expand All @@ -2535,6 +2541,13 @@ def test_transpose(self, shape_and_axes):
):
self.skipTest("Requires libtpu built after 2025-8-29")

if (
rank == 5
and in_shape == (8, 8, 8, 8, 8)
and not jtu.if_cloud_tpu_at_least(2025, 9, 1)
):
self.skipTest("Requires libtpu built after 2025-9-1")

x = jnp.arange(math.prod(in_shape), dtype=jnp.float32).reshape(in_shape)
expected = jnp.transpose(x, axes=transpose_axes)

Expand Down
Loading