
Commit e6bc889

allanrenucci authored and Google-ML-Automation committed

[Pallas:MGPU][NFC] Consistently use self.kernel instead of plgpu.kernel in tests.

`self.kernel` abstracts away the lowering semantics, which allows writing a single test for both Lane and Warpgroup lowering.

PiperOrigin-RevId: 816205635
1 parent 73095f0
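For context, the helper this commit standardizes on works roughly as sketched below: the test suite defines lowering-specific subclasses, and `self.kernel` forwards to `plgpu.kernel` with the class's lowering semantics filled in. This is a minimal sketch, not the actual harness from mosaic_gpu_test.py: the class names and the `LOWERING_SEMANTICS` attribute are assumptions for illustration (only `plgpu.kernel`, `self.kernel`, and `self.skip_if_wg_semantics` appear in the diff), and it assumes `plgpu.CompilerParams` exposes a `lowering_semantics` field and that `plgpu` re-exports the `LoweringSemantics` enum.

# Minimal sketch under the assumptions stated above.
import dataclasses
import unittest

import jax.experimental.pallas.mosaic_gpu as plgpu


class PallasTestBase(unittest.TestCase):  # hypothetical base class
  # Assumed class attribute; subclasses pick the lowering to test.
  LOWERING_SEMANTICS = plgpu.LoweringSemantics.Lane

  def skip_if_wg_semantics(self):
    # Called by tests that exercise features unsupported under Warpgroup lowering.
    if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Warpgroup:
      self.skipTest("Not supported under Warpgroup lowering semantics.")

  def kernel(self, body, **kwargs):
    # Forward to plgpu.kernel, threading the class-level lowering semantics
    # into the compiler params while preserving caller-supplied params.
    params = kwargs.pop("compiler_params", None) or plgpu.CompilerParams()
    params = dataclasses.replace(
        params, lowering_semantics=self.LOWERING_SEMANTICS)
    return plgpu.kernel(body, compiler_params=params, **kwargs)


class PallasLaneTest(PallasTestBase):
  LOWERING_SEMANTICS = plgpu.LoweringSemantics.Lane


class PallasWarpgroupTest(PallasTestBase):
  LOWERING_SEMANTICS = plgpu.LoweringSemantics.Warpgroup

With such a base class, a test body written once against `self.kernel` runs under both lowerings via the subclasses, which is what makes the mechanical `plgpu.kernel` → `self.kernel` substitution in the diff below an NFC change.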

File tree

1 file changed: 24 additions, 21 deletions

tests/pallas/mosaic_gpu_test.py
24 additions, 21 deletions
@@ -2119,6 +2119,8 @@ def test_tma_load_multicast(self, collective_dims, noncollective_dims, collectiv
     to test that the cluster axes are used correctly.
     """

+    self.skip_if_wg_semantics()  # User transforms are not supported.
+
     dtype = jnp.float16
     cluster = [1, 1, 1]
     for d in collective_dims:
@@ -2167,7 +2169,7 @@ def cluster_id(axes):
       plgpu.wait_smem_to_gmem(0)

     x = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
-    kernel = plgpu.kernel(
+    kernel = self.kernel(
         body,
         grid=cluster,
         grid_names=("grid_x", "grid_y", "grid_z"),
@@ -2380,7 +2382,7 @@ def test_discharge_comms_effect(self):
     def body(out, sem):
       pl.semaphore_signal(sem, device_id=jnp.asarray(2, jnp.int32))

-    f = plgpu.kernel(
+    f = self.kernel(
         body,
         out_shape=jax.ShapeDtypeStruct((), jnp.int32),
         scratch_shapes=[plgpu.SemaphoreType.REGULAR],
@@ -2416,7 +2418,7 @@ def kernel(dst, collective_barrier):
       plgpu.barrier_arrive(collective_barrier)
       plgpu.barrier_wait(collective_barrier)
       dst[...] = jnp.ones_like(dst)
-    y = plgpu.kernel(
+    y = self.kernel(
         kernel,
         out_shape=jax.ShapeDtypeStruct((128,), jnp.int32),
         scratch_shapes=[plgpu.ClusterBarrier(collective_axes=("x",), num_arrivals=4)],
@@ -2434,7 +2436,7 @@ def setUp(self):

   def test_axis_index(self):
     warp_mesh = plgpu.WarpMesh(axis_name="warp")
-    @functools.partial(plgpu.kernel,
+    @functools.partial(self.kernel,
                        out_shape=jax.ShapeDtypeStruct((2, 128), jnp.int32))
     def kernel(y_ref):
       def scope(ones_smem_ref, threes_smem_ref):
@@ -2471,7 +2473,7 @@ def _():
   )
   def test_scalar_binary_op(self, op):
     warp_mesh = plgpu.WarpMesh(axis_name="warp")
-    @functools.partial(plgpu.kernel,
+    @functools.partial(self.kernel,
                        out_shape=jax.ShapeDtypeStruct((), jnp.int32))
     def kernel(y_ref):
       @pl.core_map(warp_mesh)
@@ -2492,7 +2494,7 @@ def test_errors_when_closing_over_array(self):
     # a mesh, since we would need to present a view of the array local
     # to each warp.
     warp_mesh = plgpu.WarpMesh(axis_name="warp")
-    @functools.partial(plgpu.kernel,
+    @functools.partial(self.kernel,
                        out_shape=jax.ShapeDtypeStruct((32, 32), jnp.float32),
                        scratch_shapes=[plgpu.SMEM((32, 32), jnp.float32)])
     def kernel(out_ref, smem_ref):
@@ -2512,7 +2514,7 @@ def _():
   @parameterized.parameters(True, False)
   def test_single_warp_loop(self, force_while):
     warp_mesh = plgpu.WarpMesh(axis_name="warp")
-    @functools.partial(plgpu.kernel,
+    @functools.partial(self.kernel,
                        out_shape=jax.ShapeDtypeStruct((10, 128), jnp.int32))
     def kernel(y_ref):
       def scope(smem_ref):
@@ -2539,7 +2541,7 @@ def loop_body(i, _):
   def test_debug_print(self):
     warp_mesh = plgpu.WarpMesh(axis_name="warp")
     @functools.partial(
-        plgpu.kernel,
+        self.kernel,
         out_shape=jnp.zeros(128, np.int32),
     )
     def kernel(ref):
@@ -2566,7 +2568,7 @@ def test_copy_gmem_to_smem_from_different_warps(self,
                                                   wait_smem_to_gmem_in_warp):
     # In this test, we issue a copy from warp 0 and await it in warp 1.
     warp_mesh = plgpu.WarpMesh(axis_name="warp")
-    @functools.partial(plgpu.kernel,
+    @functools.partial(self.kernel,
                        out_shape=jax.ShapeDtypeStruct((32, 32), jnp.float32))
     def kernel(x_ref, y_ref):
       def scope(smem_ref, tma_barrier):
@@ -3574,7 +3576,7 @@ def kernel(a_gmem, b_gmem, out_gmem,
       plgpu.copy_smem_to_gmem(out_smem, out_gmem)
       plgpu.wait_smem_to_gmem(0)

-    f = plgpu.kernel(
+    f = self.kernel(
         kernel,
         out_shape=jax.ShapeDtypeStruct(shape, dtype),
         scratch_shapes=[
@@ -3787,7 +3789,7 @@ def kernel(a_gmem, b_gmem, out_gmem128, out_gmem64,
       plgpu.copy_smem_to_gmem(out_smem, out_gmem64)
       plgpu.wait_smem_to_gmem(0)

-    f = plgpu.kernel(
+    f = self.kernel(
         kernel,
         out_shape=[jax.ShapeDtypeStruct(shape, dtype),
                    jax.ShapeDtypeStruct(shape, dtype)],
@@ -3867,7 +3869,7 @@ def kernel(a_gmem, b_gmem, out_gmem128, out_gmem64,
       plgpu.copy_smem_to_gmem(out_smem, out_gmem64)
       plgpu.wait_smem_to_gmem(0)

-    f = plgpu.kernel(
+    f = self.kernel(
         kernel,
         out_shape=[jax.ShapeDtypeStruct(shape, dtype),
                    jax.ShapeDtypeStruct(shape, dtype)],
@@ -3939,6 +3941,7 @@ def kernel(a_smem, b_smem, out_ref, acc_tmem, scratch_smem, barrier_ref):
       squeezed_index=(True, False),
   )
   def test_copy_gmem_to_smem_partitioned(self, warp_level, squeezed_index):
+    self.skip_if_wg_semantics()  # `pl.core_map` not implemented for warpgroup.
     block_size = (128, 128)
     partitioned_block_size = (block_size[0] // 2, block_size[1])
     a = jax.random.uniform(
@@ -4000,7 +4003,7 @@ def _():
       out_smem[...] = a_smem[...] + b_smem[...]
       plgpu.copy_smem_to_gmem(out_smem, out_gmem.at[out_slice])
       plgpu.wait_smem_to_gmem(0)
-    f = plgpu.kernel(
+    f = self.kernel(
         kernel,
         out_shape=jax.ShapeDtypeStruct(block_size, jnp.float32),
         grid=(1,),
@@ -4028,7 +4031,7 @@ def kernel(out_ref, barrier):
       plgpu.barrier_wait(barrier)
       out_ref[...] = jnp.ones_like(out_ref)

-    f = plgpu.kernel(
+    f = self.kernel(
         kernel,
         out_shape=jax.ShapeDtypeStruct((128,), jnp.float32),
         scratch_shapes=(  # type: ignore
@@ -5615,7 +5618,7 @@ def body(o_ref, sem_ref):
       pl.semaphore_signal(sem_ref)
       o_ref[...] = jnp.ones_like(o_ref)
       pl.semaphore_wait(sem_ref)
-    kernel = plgpu.kernel(
+    kernel = self.kernel(
         body,
         out_shape=jax.ShapeDtypeStruct((128,), jnp.float32),
         scratch_shapes=[plgpu.SemaphoreType.REGULAR],
@@ -5638,7 +5641,7 @@ def body(o_ref, sem_ref):
       with jax.named_scope("output"):
         o_ref[...] = jnp.ones_like(o_ref)
     with tempfile.TemporaryDirectory() as tmp_dir:
-      kernel = plgpu.kernel(
+      kernel = self.kernel(
           body,
           out_shape=jax.ShapeDtypeStruct((128,), jnp.float32),
           scratch_shapes=[plgpu.SemaphoreType.REGULAR],
@@ -5670,7 +5673,7 @@ def _():
       def _():
         pl.semaphore_wait(sem_ref)
       out_ref[...] = jnp.ones_like(out_ref)
-    kernel = plgpu.kernel(
+    kernel = self.kernel(
         body,
         out_shape=jax.ShapeDtypeStruct((128,), jnp.float32),
         grid=(10,),
@@ -5693,7 +5696,7 @@ def _():
       def _():
         pl.semaphore_wait(sem_ref)
       out_ref[...] = jnp.ones_like(out_ref)
-    kernel = plgpu.kernel(
+    kernel = self.kernel(
         body,
         out_shape=jax.ShapeDtypeStruct((128,), jnp.float32),
         grid=(10,),
@@ -5725,7 +5728,7 @@ def _():
       pl.semaphore_wait(global_sem)
       out_ref[...] = jnp.ones_like(out_ref)
       pl.semaphore_wait(block_sem)
-    kernel = plgpu.kernel(
+    kernel = self.kernel(
         body,
         out_shape=jax.ShapeDtypeStruct((128,), jnp.float32),
         grid=(10,),
@@ -5910,7 +5913,7 @@ def body(out_gmem, _):
     def loop_body(loop_info: plgpu.NDLoopInfo):
       out_gmem[*loop_info.index, *cluster_idx] = sm_idx
     out_shape = (*grid, *cluster)
-    result = plgpu.kernel(body,
+    result = self.kernel(body,
                           out_shape=jax.ShapeDtypeStruct(out_shape, jnp.int32),
                           grid=grid,
                           grid_names=grid_names,
@@ -5958,7 +5961,7 @@ def loop_body(loop_info: plgpu.NDLoopInfo, carry: jax.Array):
      # All SMs wait until SM 0 has finished all blocks.
      pl.semaphore_wait(global_semaphore)

-    result = plgpu.kernel(body,
+    result = self.kernel(body,
                           out_shape=jax.ShapeDtypeStruct((1,), jnp.int32),
                           grid=(sm_count + blocks_to_steal,),
                           grid_names=("x",),

0 commit comments

Comments
 (0)