diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py
index ef20ee735035..e279afaa7904 100644
--- a/jax/_src/pallas/mosaic_gpu/lowering.py
+++ b/jax/_src/pallas/mosaic_gpu/lowering.py
@@ -1753,12 +1753,16 @@ def _swap_lowering_rule(
           layout=value.layout,
       )
       value.store_tiled(x_smem, swizzle=swizzle)
-    case () | (gpu_core.TransposeRef((1, 0)),):
+    case () | (gpu_core.TransposeRef(),):
       transposed = bool(transforms)
       match value.layout:
         case mgpu.TiledLayout():
           if transposed:
-            x_smem = mgpu.memref_transpose(x_smem, (1, 0))
+            assert isinstance(
+                transforms[0], gpu_core.TransposeRef
+            )  # silence pytype
+            permutation = transforms[0].permutation
+            x_smem = mgpu.memref_transpose(x_smem, permutation)
           old_value = mgpu.FragmentedArray.load_untiled(
               x_smem,
               layout=value.layout,
diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py
index 158754fc697a..e0c241b27408 100644
--- a/tests/pallas/mosaic_gpu_test.py
+++ b/tests/pallas/mosaic_gpu_test.py
@@ -1078,31 +1078,31 @@ def kernel(x_ref_gmem, idx_ref, o_ref, barrier_ref):
     idx = jax.random.permutation(jax.random.key(1234), out_shape[0]).astype(jnp.uint32)
     np.testing.assert_array_equal(kernel(x, idx), x[idx, 64:])

-  @parameterized.parameters(
-      (plgpu.Layout.WGMMA, plgpu.Layout.WGMMA_TRANSPOSED),
-      (plgpu.Layout.WGMMA_TRANSPOSED, plgpu.Layout.WGMMA),
+  @parameterized.product(
+      src_transposed=(False, True), shape=((128, 128), (1, 128, 128))
   )
-  def test_transposed_load_store(self, src_layout, dst_layout):
-    def is_transposed(layout):
-      return layout == plgpu.Layout.WGMMA_TRANSPOSED
-
-    shape, dtype = (128, 128), jnp.float32
-
+  def test_transposed_load_store(self, src_transposed, shape):
+    dtype = jnp.float32
+    permutation = (0, 2, 1) if len(shape) == 3 else (1, 0)
     @functools.partial(
         self.kernel,
         out_shape=jax.ShapeDtypeStruct(shape, dtype),
     )
     def kernel(src_ref, dst_ref):
-      if is_transposed(src_layout):
-        src_ref = src_ref.T
-      if is_transposed(dst_layout):
-        dst_ref = dst_ref.T
+      if src_transposed:
+        src_ref = plgpu.transpose_ref(src_ref, permutation)
+        src_layout = plgpu.Layout.WGMMA_TRANSPOSED
+        dst_layout = plgpu.Layout.WGMMA
+      else:
+        dst_ref = plgpu.transpose_ref(dst_ref, permutation)
+        src_layout = plgpu.Layout.WGMMA
+        dst_layout = plgpu.Layout.WGMMA_TRANSPOSED
       src = plgpu.load(src_ref, (), layout=src_layout, optimized=False)
       dst = plgpu.layout_cast(src, dst_layout)
       dst_ref[...] = dst

     x = jnp.arange(math.prod(shape), dtype=dtype).reshape(shape)
-    np.testing.assert_array_equal(kernel(x), x.T)
+    np.testing.assert_array_equal(kernel(x), jnp.transpose(x, permutation))

   @parameterized.product(
       src_memory_space=[plgpu.SMEM, plgpu.GMEM],