Skip to content

Commit 998e6ba

Browse files
allanrenucci authored and Google-ML-Automation committed
[Mosaic GPU][NFC] Always set exact to False in tmem_alloc and tmem_dealloc lowerings.
`exact` is a validation mechanism; it does not affect lowering. The instruction is already validated by the op verifier.

PiperOrigin-RevId: 800827106
1 parent 9495be7 commit 998e6ba

File tree

2 files changed

+7
-11
lines changed

2 files changed

+7
-11
lines changed

jax/experimental/mosaic/gpu/dialect_lowering.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1520,15 +1520,14 @@ def _tmem_alloc_op_lowering_rule(
15201520
ncols = output_shape[1] // op.packing.value
15211521

15221522
with mgpu_utils.when(ctx.single_warp_per_block_predicate):
1523-
tcgen05.tmem_alloc(op.smem_ptr, ncols, op.collective, op.exact)
1523+
tcgen05.tmem_alloc(op.smem_ptr, ncols, op.collective, exact=False)
15241524
gpu.barrier()
15251525
tmem_addr = memref.load(op.smem_ptr, [])
15261526

15271527
cast_op = builtin.UnrealizedConversionCastOp(
15281528
[op.result.type], [tmem_addr]
15291529
)
15301530
cast_op.attributes["collective"] = op.collective
1531-
cast_op.attributes["exact"] = op.exact
15321531
cast_op.attributes["packing"] = op.packing
15331532

15341533
return [cast_op.result]
@@ -1552,14 +1551,13 @@ def _tmem_dealloc_op_lowering_rule(
15521551
i32 = ir.IntegerType.get_signless(32)
15531552
conversion_cast, [tmem_addr] = _undo_conversion_cast(op.tmem_ref, [i32])
15541553
collective = ir.BoolAttr(conversion_cast.attributes["collective"]).value
1555-
exact = ir.BoolAttr(conversion_cast.attributes["exact"]).value
15561554
packing = ir.IntegerAttr(conversion_cast.attributes["packing"]).value
15571555

15581556
output_shape = ir.MemRefType(op.tmem_ref.type).shape
15591557
ncols = output_shape[1] // packing
15601558

15611559
with mgpu_utils.when(ctx.single_warp_per_block_predicate):
1562-
tcgen05.tmem_dealloc(tmem_addr, ncols, collective, exact)
1560+
tcgen05.tmem_dealloc(tmem_addr, ncols, collective, exact=False)
15631561

15641562
return []
15651563

tests/mosaic/gpu_test.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4466,15 +4466,13 @@ def setUp(self):
44664466
self.skipTest("Only works on GPU with capability sm_100a or sm_101a")
44674467

44684468
@parameterized.named_parameters(
4469-
("exact", (128, 64), jnp.bfloat16, 1, True, False, 64),
4470-
("non-exact", (128, 77), jnp.bfloat16, 1, False, False, 128),
4471-
("exact-packed", (128, 128), jnp.bfloat16, 2, True, False, 64),
4472-
("non-exact-packed", (128, 120), jnp.bfloat16, 2, False, False, 64),
4473-
("collective-exact", (128, 64), jnp.bfloat16, 1, True, True, 64),
4469+
("unpacked", (128, 77), jnp.bfloat16, 1, False, 128),
4470+
("packed", (128, 128), jnp.bfloat16, 2, False, 64),
4471+
("collective", (128, 64), jnp.bfloat16, 1, True, 64),
44744472
)
44754473
@unittest.skip("Layout inference fails for trivial load/store kernels.")
44764474
def test_tmem_alloc_dealloc(
4477-
self, shape, dtype, packing, exact, collective, expected_allocated_columns
4475+
self, shape, dtype, packing, collective, expected_allocated_columns
44784476
):
44794477
tmem_type = ir.MemRefType.get(
44804478
shape,
@@ -4493,7 +4491,7 @@ def body(
44934491
result=tmem_type,
44944492
smem_ptr=tmem,
44954493
collective=collective,
4496-
exact=exact,
4494+
exact=False,
44974495
packing=packing,
44984496
)
44994497

0 commit comments

Comments
 (0)