Skip to content

Commit 7c6378d

Browse files
allanrenucciGoogle-ML-Automation
authored andcommitted
[Mosaic GPU] Remove intermediate copies to SMEM.
We rewrite `GMEM -> SMEM -> Registers` copies to `GMEM -> Registers` copies when the purpose of the test is not to test the copy to/from SMEM. PiperOrigin-RevId: 801796030
1 parent b0c9f61 commit 7c6378d

File tree

1 file changed

+27
-89
lines changed

1 file changed

+27
-89
lines changed

tests/mosaic/gpu_test.py

Lines changed: 27 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,68 +1142,36 @@ def setUp(self):
11421142
(tcgen05.fa_m64_collective_layout, tcgen05.tmem_m64_collective_layout, 64),
11431143
],
11441144
)
1145-
def test_load_store_tmem_swizzle(self, jax_dtype_packing, reg_tmem_layout_m):
1145+
def test_load_store_tmem(self, jax_dtype_packing, reg_tmem_layout_m):
11461146
jax_dtype, packing = jax_dtype_packing
11471147
reg_layout_f, tmem_layout_f, m = reg_tmem_layout_m
1148-
swizzle = 128
1149-
in_mlir_dtype = utils.dtype_to_ir_type(jax_dtype)
1150-
swizzle_elems = swizzle // bytewidth(in_mlir_dtype)
1151-
tiling = (8, swizzle_elems)
11521148
n = 256
11531149
reg_layout = reg_layout_f(n)
11541150

1155-
def kernel(ctx, input, output, scratch):
1156-
smem, barrier, tmem = scratch
1157-
ctx.async_copy(
1158-
src_ref=input,
1159-
dst_ref=smem,
1160-
swizzle=swizzle,
1161-
gmem_transform=mgpu.TileTransform(tiling),
1162-
barrier=barrier,
1163-
)
1164-
barrier.wait()
1165-
tmem.store(fa.FragmentedArray.load_tiled(smem, swizzle, layout=reg_layout))
1151+
def kernel(ctx, input, output, tmem):
1152+
del ctx
1153+
tmem.store(fa.FragmentedArray.load_untiled(input, layout=reg_layout, optimized=False))
11661154
tcgen05.commit_tmem()
1167-
tmem.load(reg_layout).store_tiled(smem, swizzle)
1168-
mgpu.commit_shared()
1169-
ctx.async_copy(
1170-
src_ref=smem, dst_ref=output, swizzle=swizzle, gmem_transform=mgpu.TileTransform(tiling),
1171-
)
1172-
ctx.await_async_copy(0)
1155+
tmem.load(reg_layout).store_untiled(output, optimized=False)
11731156

11741157
x = self.prng.uniform(-1, 1, (m, n)).astype(jax_dtype)
1175-
scratch_shape = [
1176-
jax.ShapeDtypeStruct(tile_shape(x.shape, tiling), jax_dtype),
1177-
mgpu.TMABarrier(),
1178-
mgpu.TMEM(x.shape, jax_dtype, layout=tmem_layout_f(n, packing)),
1179-
]
11801158
y = mgpu.as_gpu_kernel(
1181-
kernel, (1, 1, 1), (128, 1, 1), x, x, scratch_shape
1159+
kernel, (1, 1, 1), (128, 1, 1), x, x, mgpu.TMEM(x.shape, jax_dtype, layout=tmem_layout_f(n, packing)),
11821160
)(x)
11831161
np.testing.assert_array_equal(x, y)
11841162

11851163
@parameterized.parameters([(jnp.float32, 1), (jnp.float16, 1), (jnp.float16, 2)])
11861164
def test_load_store_tmem_native(self, jax_dtype, packing):
11871165

1188-
def kernel(ctx, input, output, scratch):
1189-
smem, barrier, tmem = scratch
1190-
ctx.async_copy(src_ref=input, dst_ref=smem, barrier=barrier)
1191-
barrier.wait()
1192-
tmem.store(fa.FragmentedArray.load_untiled(smem, layout=tcgen05.TMEM_NATIVE_LAYOUT, optimized=False))
1166+
def kernel(ctx, input, output, tmem):
1167+
del ctx
1168+
tmem.store(fa.FragmentedArray.load_untiled(input, layout=tcgen05.TMEM_NATIVE_LAYOUT, optimized=False))
11931169
tcgen05.commit_tmem()
1194-
tmem.load(tcgen05.TMEM_NATIVE_LAYOUT).store_untiled(smem, optimized=False)
1195-
mgpu.commit_shared()
1196-
ctx.async_copy(src_ref=smem, dst_ref=output)
1197-
ctx.await_async_copy(0)
1170+
tmem.load(tcgen05.TMEM_NATIVE_LAYOUT).store_untiled(output, optimized=False)
11981171

11991172
x = self.prng.uniform(-1, 1, (128, 128)).astype(jax_dtype)
1200-
scratch_shape = [
1201-
jax.ShapeDtypeStruct(x.shape, jax_dtype),
1202-
mgpu.TMABarrier(),
1203-
mgpu.TMEM(x.shape, jax_dtype, packing=packing),
1204-
]
12051173
y = mgpu.as_gpu_kernel(
1206-
kernel, (1, 1, 1), (128, 1, 1), x, x, scratch_shape
1174+
kernel, (1, 1, 1), (128, 1, 1), x, x, mgpu.TMEM(x.shape, jax_dtype, packing=packing)
12071175
)(x)
12081176
np.testing.assert_array_equal(x, y)
12091177

@@ -1214,34 +1182,16 @@ def kernel(ctx, input, output, scratch):
12141182
])
12151183
@jtu.thread_unsafe_test()
12161184
def test_tmem_debug_print(self, jax_dtype, packing, expected):
1217-
swizzle = 128
1218-
in_mlir_dtype = utils.dtype_to_ir_type(jax_dtype)
1219-
swizzle_elems = swizzle // bytewidth(in_mlir_dtype)
1220-
tiling = (8, swizzle_elems)
1221-
1222-
def kernel(ctx, input, output, scratch):
1223-
smem, barrier, tmem = scratch
1224-
ctx.async_copy(
1225-
src_ref=input,
1226-
dst_ref=smem,
1227-
swizzle=swizzle,
1228-
gmem_transform=mgpu.TileTransform(tiling),
1229-
barrier=barrier,
1230-
)
1231-
barrier.wait()
1232-
tmem.store(fa.FragmentedArray.load_tiled(smem, swizzle, layout=tcgen05.LAYOUT))
1185+
def kernel(ctx, input, output, tmem):
1186+
del ctx, output
1187+
tmem.store(fa.FragmentedArray.load_untiled(input, layout=tcgen05.LAYOUT, optimized=False))
12331188
tcgen05.commit_tmem()
12341189
tmem.slice(slice(None), slice(0, 8))._debug_print()
12351190

12361191
x = jnp.arange(128 * 128, dtype=jax_dtype).reshape(128, 128)
1237-
scratch_shape = [
1238-
jax.ShapeDtypeStruct(tile_shape(x.shape, tiling), jax_dtype),
1239-
mgpu.TMABarrier(),
1240-
mgpu.TMEM(x.shape, jax_dtype, packing=packing),
1241-
]
12421192
with self.capture_stdout() as stdout:
12431193
mgpu.as_gpu_kernel(
1244-
kernel, (1, 1, 1), (128, 1, 1), x, x, scratch_shape
1194+
kernel, (1, 1, 1), (128, 1, 1), x, x, mgpu.TMEM(x.shape, jax_dtype, packing=packing),
12451195
)(x).block_until_ready()
12461196
self.assertIn("[1, 2]: " + expected, stdout())
12471197

@@ -1457,29 +1407,21 @@ def test_mma_lhs_tmem(self, m, n, in_jax_dtype, out_jax_dtype):
14571407
in_mlir_dtype = utils.dtype_to_ir_type(in_jax_dtype)
14581408
swizzle_elems = swizzle // bytewidth(in_mlir_dtype)
14591409
k = swizzle_elems * k_steps
1460-
lhs_tiling = rhs_tiling = (8, swizzle_elems)
1410+
rhs_tiling = (8, swizzle_elems)
14611411

14621412
def kernel(ctx, lhs, rhs, out, scratch):
1463-
lhs_smem, rhs_smem, barriers, mma_barrier, acc, lhs_tmem = scratch
1464-
ctx.async_copy(
1465-
src_ref=lhs,
1466-
dst_ref=lhs_smem,
1467-
swizzle=swizzle,
1468-
gmem_transform=mgpu.TileTransform(lhs_tiling),
1469-
barrier=barriers[0],
1470-
)
1413+
rhs_smem, barrier, mma_barrier, acc, lhs_tmem = scratch
14711414
ctx.async_copy(
14721415
src_ref=rhs,
14731416
dst_ref=rhs_smem,
14741417
swizzle=swizzle,
14751418
gmem_transform=mgpu.TileTransform(rhs_tiling),
1476-
barrier=barriers[1],
1419+
barrier=barrier,
14771420
)
1478-
barriers[0].wait()
1479-
barriers[1].wait()
1421+
barrier.wait()
14801422
lhs_tmem.store(
1481-
fa.FragmentedArray.load_tiled(
1482-
lhs_smem, swizzle, layout=tcgen05.LAYOUT
1423+
fa.FragmentedArray.load_untiled(
1424+
lhs, layout=tcgen05.LAYOUT, optimized=False
14831425
)
14841426
)
14851427
tcgen05.commit_tmem()
@@ -1497,9 +1439,8 @@ def kernel(ctx, lhs, rhs, out, scratch):
14971439
y = self.prng.uniform(-1, 1, y_shape).astype(in_jax_dtype)
14981440
out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype)
14991441
scratch_shape = [
1500-
jax.ShapeDtypeStruct(tile_shape(x_shape, lhs_tiling), in_jax_dtype),
15011442
jax.ShapeDtypeStruct(tile_shape(y_shape, rhs_tiling), in_jax_dtype),
1502-
mgpu.TMABarrier(2),
1443+
mgpu.TMABarrier(),
15031444
mgpu.Barrier(1),
15041445
mgpu.TMEM((128, n), out_jax_dtype),
15051446
mgpu.TMEM((128, k), in_jax_dtype, packing=2),
@@ -3644,7 +3585,6 @@ def setUp(self):
36443585
def test_smem_registers_load_store(self, layout, dtype):
36453586
def body(ctx, param: ir.Value, result: ir.Value, smem: list[ir.Value]):
36463587
del ctx
3647-
[tmp_smem] = smem
36483588
shape = ir.MemRefType(param.type).shape
36493589
elt_type = ir.MemRefType(param.type).element_type
36503590

@@ -3657,10 +3597,10 @@ def body(ctx, param: ir.Value, result: ir.Value, smem: list[ir.Value]):
36573597
reg = mgpu_dialect.layout_cast(reg, mgpu_layouts.to_layout_attr(layout))
36583598

36593599
# Registers -> SMEM
3660-
vector.store(reg, tmp_smem, zero_vector_indices)
3600+
vector.store(reg, smem, zero_vector_indices)
36613601

36623602
# SMEM -> Registers
3663-
reg = vector.load(vector_type, tmp_smem, zero_vector_indices)
3603+
reg = vector.load(vector_type, smem, zero_vector_indices)
36643604
reg = mgpu_dialect.layout_cast(reg, mgpu_layouts.to_layout_attr(layout))
36653605

36663606
# Registers -> GMEM
@@ -3674,7 +3614,7 @@ def body(ctx, param: ir.Value, result: ir.Value, smem: list[ir.Value]):
36743614
block=(128, 1, 1),
36753615
in_shape=jax_shape,
36763616
out_shape=jax_shape,
3677-
smem_scratch_shape=[jax_shape],
3617+
smem_scratch_shape=jax_shape,
36783618
thread_semantics=mgpu.LoweringSemantics.Warpgroup,
36793619
)
36803620

@@ -4488,11 +4428,9 @@ def body(
44884428
):
44894429
# We need to have a result `x` otherwise the kernel will not be generated.
44904430
del ctx, x
4491-
[tmem] = smem
4492-
44934431
tmem_ref = mgpu_dialect.tmem_alloc(
44944432
result=tmem_type,
4495-
smem_ptr=tmem,
4433+
smem_ptr=smem,
44964434
collective=collective,
44974435
packing=packing,
44984436
)
@@ -4513,7 +4451,7 @@ def body(
45134451
block=(128, 1, 1),
45144452
in_shape=(),
45154453
out_shape=(jax.ShapeDtypeStruct((), jnp.int32),),
4516-
smem_scratch_shape=[jax.ShapeDtypeStruct((), jnp.int32)],
4454+
smem_scratch_shape=jax.ShapeDtypeStruct((), jnp.int32),
45174455
thread_semantics=mgpu.LoweringSemantics.Warpgroup,
45184456
)()
45194457

0 commit comments

Comments
 (0)