@@ -2119,6 +2119,8 @@ def test_tma_load_multicast(self, collective_dims, noncollective_dims, collectiv
21192119 to test that the cluster axes are used correctly.
21202120 """
21212121
2122+ self .skip_if_wg_semantics () # User transforms are not supported.
2123+
21222124 dtype = jnp .float16
21232125 cluster = [1 , 1 , 1 ]
21242126 for d in collective_dims :
@@ -2167,7 +2169,7 @@ def cluster_id(axes):
21672169 plgpu .wait_smem_to_gmem (0 )
21682170
21692171 x = np .arange (np .prod (shape ), dtype = dtype ).reshape (shape )
2170- kernel = plgpu .kernel (
2172+ kernel = self .kernel (
21712173 body ,
21722174 grid = cluster ,
21732175 grid_names = ("grid_x" , "grid_y" , "grid_z" ),
@@ -2380,7 +2382,7 @@ def test_discharge_comms_effect(self):
23802382 def body (out , sem ):
23812383 pl .semaphore_signal (sem , device_id = jnp .asarray (2 , jnp .int32 ))
23822384
2383- f = plgpu .kernel (
2385+ f = self .kernel (
23842386 body ,
23852387 out_shape = jax .ShapeDtypeStruct ((), jnp .int32 ),
23862388 scratch_shapes = [plgpu .SemaphoreType .REGULAR ],
@@ -2416,7 +2418,7 @@ def kernel(dst, collective_barrier):
24162418 plgpu .barrier_arrive (collective_barrier )
24172419 plgpu .barrier_wait (collective_barrier )
24182420 dst [...] = jnp .ones_like (dst )
2419- y = plgpu .kernel (
2421+ y = self .kernel (
24202422 kernel ,
24212423 out_shape = jax .ShapeDtypeStruct ((128 ,), jnp .int32 ),
24222424 scratch_shapes = [plgpu .ClusterBarrier (collective_axes = ("x" ,), num_arrivals = 4 )],
@@ -2434,7 +2436,7 @@ def setUp(self):
24342436
24352437 def test_axis_index (self ):
24362438 warp_mesh = plgpu .WarpMesh (axis_name = "warp" )
2437- @functools .partial (plgpu .kernel ,
2439+ @functools .partial (self .kernel ,
24382440 out_shape = jax .ShapeDtypeStruct ((2 , 128 ), jnp .int32 ))
24392441 def kernel (y_ref ):
24402442 def scope (ones_smem_ref , threes_smem_ref ):
@@ -2471,7 +2473,7 @@ def _():
24712473 )
24722474 def test_scalar_binary_op (self , op ):
24732475 warp_mesh = plgpu .WarpMesh (axis_name = "warp" )
2474- @functools .partial (plgpu .kernel ,
2476+ @functools .partial (self .kernel ,
24752477 out_shape = jax .ShapeDtypeStruct ((), jnp .int32 ))
24762478 def kernel (y_ref ):
24772479 @pl .core_map (warp_mesh )
@@ -2492,7 +2494,7 @@ def test_errors_when_closing_over_array(self):
24922494 # a mesh, since we would need to present a view of the array local
24932495 # to each warp.
24942496 warp_mesh = plgpu .WarpMesh (axis_name = "warp" )
2495- @functools .partial (plgpu .kernel ,
2497+ @functools .partial (self .kernel ,
24962498 out_shape = jax .ShapeDtypeStruct ((32 , 32 ), jnp .float32 ),
24972499 scratch_shapes = [plgpu .SMEM ((32 , 32 ), jnp .float32 )])
24982500 def kernel (out_ref , smem_ref ):
@@ -2512,7 +2514,7 @@ def _():
25122514 @parameterized .parameters (True , False )
25132515 def test_single_warp_loop (self , force_while ):
25142516 warp_mesh = plgpu .WarpMesh (axis_name = "warp" )
2515- @functools .partial (plgpu .kernel ,
2517+ @functools .partial (self .kernel ,
25162518 out_shape = jax .ShapeDtypeStruct ((10 , 128 ), jnp .int32 ))
25172519 def kernel (y_ref ):
25182520 def scope (smem_ref ):
@@ -2539,7 +2541,7 @@ def loop_body(i, _):
25392541 def test_debug_print (self ):
25402542 warp_mesh = plgpu .WarpMesh (axis_name = "warp" )
25412543 @functools .partial (
2542- plgpu .kernel ,
2544+ self .kernel ,
25432545 out_shape = jnp .zeros (128 , np .int32 ),
25442546 )
25452547 def kernel (ref ):
@@ -2566,7 +2568,7 @@ def test_copy_gmem_to_smem_from_different_warps(self,
25662568 wait_smem_to_gmem_in_warp ):
25672569 # In this test, we issue a copy from from warp 0 and await it in warp 1.
25682570 warp_mesh = plgpu .WarpMesh (axis_name = "warp" )
2569- @functools .partial (plgpu .kernel ,
2571+ @functools .partial (self .kernel ,
25702572 out_shape = jax .ShapeDtypeStruct ((32 , 32 ), jnp .float32 ))
25712573 def kernel (x_ref , y_ref ):
25722574 def scope (smem_ref , tma_barrier ):
@@ -3574,7 +3576,7 @@ def kernel(a_gmem, b_gmem, out_gmem,
35743576 plgpu .copy_smem_to_gmem (out_smem , out_gmem )
35753577 plgpu .wait_smem_to_gmem (0 )
35763578
3577- f = plgpu .kernel (
3579+ f = self .kernel (
35783580 kernel ,
35793581 out_shape = jax .ShapeDtypeStruct (shape , dtype ),
35803582 scratch_shapes = [
@@ -3787,7 +3789,7 @@ def kernel(a_gmem, b_gmem, out_gmem128, out_gmem64,
37873789 plgpu .copy_smem_to_gmem (out_smem , out_gmem64 )
37883790 plgpu .wait_smem_to_gmem (0 )
37893791
3790- f = plgpu .kernel (
3792+ f = self .kernel (
37913793 kernel ,
37923794 out_shape = [jax .ShapeDtypeStruct (shape , dtype ),
37933795 jax .ShapeDtypeStruct (shape , dtype )],
@@ -3867,7 +3869,7 @@ def kernel(a_gmem, b_gmem, out_gmem128, out_gmem64,
38673869 plgpu .copy_smem_to_gmem (out_smem , out_gmem64 )
38683870 plgpu .wait_smem_to_gmem (0 )
38693871
3870- f = plgpu .kernel (
3872+ f = self .kernel (
38713873 kernel ,
38723874 out_shape = [jax .ShapeDtypeStruct (shape , dtype ),
38733875 jax .ShapeDtypeStruct (shape , dtype )],
@@ -3939,6 +3941,7 @@ def kernel(a_smem, b_smem, out_ref, acc_tmem, scratch_smem, barrier_ref):
39393941 squeezed_index = (True , False ),
39403942 )
39413943 def test_copy_gmem_to_smem_partitioned (self , warp_level , squeezed_index ):
3944+ self .skip_if_wg_semantics () # `pl.core_map` not implemented for warpgroup.
39423945 block_size = (128 , 128 )
39433946 partitioned_block_size = (block_size [0 ] // 2 , block_size [1 ])
39443947 a = jax .random .uniform (
@@ -4000,7 +4003,7 @@ def _():
40004003 out_smem [...] = a_smem [...] + b_smem [...]
40014004 plgpu .copy_smem_to_gmem (out_smem , out_gmem .at [out_slice ])
40024005 plgpu .wait_smem_to_gmem (0 )
4003- f = plgpu .kernel (
4006+ f = self .kernel (
40044007 kernel ,
40054008 out_shape = jax .ShapeDtypeStruct (block_size , jnp .float32 ),
40064009 grid = (1 ,),
@@ -4028,7 +4031,7 @@ def kernel(out_ref, barrier):
40284031 plgpu .barrier_wait (barrier )
40294032 out_ref [...] = jnp .ones_like (out_ref )
40304033
4031- f = plgpu .kernel (
4034+ f = self .kernel (
40324035 kernel ,
40334036 out_shape = jax .ShapeDtypeStruct ((128 ,), jnp .float32 ),
40344037 scratch_shapes = ( # type: ignore
@@ -5615,7 +5618,7 @@ def body(o_ref, sem_ref):
56155618 pl .semaphore_signal (sem_ref )
56165619 o_ref [...] = jnp .ones_like (o_ref )
56175620 pl .semaphore_wait (sem_ref )
5618- kernel = plgpu .kernel (
5621+ kernel = self .kernel (
56195622 body ,
56205623 out_shape = jax .ShapeDtypeStruct ((128 ,), jnp .float32 ),
56215624 scratch_shapes = [plgpu .SemaphoreType .REGULAR ],
@@ -5638,7 +5641,7 @@ def body(o_ref, sem_ref):
56385641 with jax .named_scope ("output" ):
56395642 o_ref [...] = jnp .ones_like (o_ref )
56405643 with tempfile .TemporaryDirectory () as tmp_dir :
5641- kernel = plgpu .kernel (
5644+ kernel = self .kernel (
56425645 body ,
56435646 out_shape = jax .ShapeDtypeStruct ((128 ,), jnp .float32 ),
56445647 scratch_shapes = [plgpu .SemaphoreType .REGULAR ],
@@ -5670,7 +5673,7 @@ def _():
56705673 def _ ():
56715674 pl .semaphore_wait (sem_ref )
56725675 out_ref [...] = jnp .ones_like (out_ref )
5673- kernel = plgpu .kernel (
5676+ kernel = self .kernel (
56745677 body ,
56755678 out_shape = jax .ShapeDtypeStruct ((128 ,), jnp .float32 ),
56765679 grid = (10 ,),
@@ -5693,7 +5696,7 @@ def _():
56935696 def _ ():
56945697 pl .semaphore_wait (sem_ref )
56955698 out_ref [...] = jnp .ones_like (out_ref )
5696- kernel = plgpu .kernel (
5699+ kernel = self .kernel (
56975700 body ,
56985701 out_shape = jax .ShapeDtypeStruct ((128 ,), jnp .float32 ),
56995702 grid = (10 ,),
@@ -5725,7 +5728,7 @@ def _():
57255728 pl .semaphore_wait (global_sem )
57265729 out_ref [...] = jnp .ones_like (out_ref )
57275730 pl .semaphore_wait (block_sem )
5728- kernel = plgpu .kernel (
5731+ kernel = self .kernel (
57295732 body ,
57305733 out_shape = jax .ShapeDtypeStruct ((128 ,), jnp .float32 ),
57315734 grid = (10 ,),
@@ -5910,7 +5913,7 @@ def body(out_gmem, _):
59105913 def loop_body (loop_info : plgpu .NDLoopInfo ):
59115914 out_gmem [* loop_info .index , * cluster_idx ] = sm_idx
59125915 out_shape = (* grid , * cluster )
5913- result = plgpu .kernel (body ,
5916+ result = self .kernel (body ,
59145917 out_shape = jax .ShapeDtypeStruct (out_shape , jnp .int32 ),
59155918 grid = grid ,
59165919 grid_names = grid_names ,
@@ -5958,7 +5961,7 @@ def loop_body(loop_info: plgpu.NDLoopInfo, carry: jax.Array):
59585961 # All SMs wait until SM 0 has finished all blocks.
59595962 pl .semaphore_wait (global_semaphore )
59605963
5961- result = plgpu .kernel (body ,
5964+ result = self .kernel (body ,
59625965 out_shape = jax .ShapeDtypeStruct ((1 ,), jnp .int32 ),
59635966 grid = (sm_count + blocks_to_steal ,),
59645967 grid_names = ("x" ,),
0 commit comments