@@ -1142,68 +1142,36 @@ def setUp(self):
1142
1142
(tcgen05 .fa_m64_collective_layout , tcgen05 .tmem_m64_collective_layout , 64 ),
1143
1143
],
1144
1144
)
1145
- def test_load_store_tmem_swizzle (self , jax_dtype_packing , reg_tmem_layout_m ):
1145
+ def test_load_store_tmem (self , jax_dtype_packing , reg_tmem_layout_m ):
1146
1146
jax_dtype , packing = jax_dtype_packing
1147
1147
reg_layout_f , tmem_layout_f , m = reg_tmem_layout_m
1148
- swizzle = 128
1149
- in_mlir_dtype = utils .dtype_to_ir_type (jax_dtype )
1150
- swizzle_elems = swizzle // bytewidth (in_mlir_dtype )
1151
- tiling = (8 , swizzle_elems )
1152
1148
n = 256
1153
1149
reg_layout = reg_layout_f (n )
1154
1150
1155
- def kernel (ctx , input , output , scratch ):
1156
- smem , barrier , tmem = scratch
1157
- ctx .async_copy (
1158
- src_ref = input ,
1159
- dst_ref = smem ,
1160
- swizzle = swizzle ,
1161
- gmem_transform = mgpu .TileTransform (tiling ),
1162
- barrier = barrier ,
1163
- )
1164
- barrier .wait ()
1165
- tmem .store (fa .FragmentedArray .load_tiled (smem , swizzle , layout = reg_layout ))
1151
+ def kernel (ctx , input , output , tmem ):
1152
+ del ctx
1153
+ tmem .store (fa .FragmentedArray .load_untiled (input , layout = reg_layout , optimized = False ))
1166
1154
tcgen05 .commit_tmem ()
1167
- tmem .load (reg_layout ).store_tiled (smem , swizzle )
1168
- mgpu .commit_shared ()
1169
- ctx .async_copy (
1170
- src_ref = smem , dst_ref = output , swizzle = swizzle , gmem_transform = mgpu .TileTransform (tiling ),
1171
- )
1172
- ctx .await_async_copy (0 )
1155
+ tmem .load (reg_layout ).store_untiled (output , optimized = False )
1173
1156
1174
1157
x = self .prng .uniform (- 1 , 1 , (m , n )).astype (jax_dtype )
1175
- scratch_shape = [
1176
- jax .ShapeDtypeStruct (tile_shape (x .shape , tiling ), jax_dtype ),
1177
- mgpu .TMABarrier (),
1178
- mgpu .TMEM (x .shape , jax_dtype , layout = tmem_layout_f (n , packing )),
1179
- ]
1180
1158
y = mgpu .as_gpu_kernel (
1181
- kernel , (1 , 1 , 1 ), (128 , 1 , 1 ), x , x , scratch_shape
1159
+ kernel , (1 , 1 , 1 ), (128 , 1 , 1 ), x , x , mgpu . TMEM ( x . shape , jax_dtype , layout = tmem_layout_f ( n , packing )),
1182
1160
)(x )
1183
1161
np .testing .assert_array_equal (x , y )
1184
1162
1185
1163
@parameterized .parameters ([(jnp .float32 , 1 ), (jnp .float16 , 1 ), (jnp .float16 , 2 )])
1186
1164
def test_load_store_tmem_native (self , jax_dtype , packing ):
1187
1165
1188
- def kernel (ctx , input , output , scratch ):
1189
- smem , barrier , tmem = scratch
1190
- ctx .async_copy (src_ref = input , dst_ref = smem , barrier = barrier )
1191
- barrier .wait ()
1192
- tmem .store (fa .FragmentedArray .load_untiled (smem , layout = tcgen05 .TMEM_NATIVE_LAYOUT , optimized = False ))
1166
+ def kernel (ctx , input , output , tmem ):
1167
+ del ctx
1168
+ tmem .store (fa .FragmentedArray .load_untiled (input , layout = tcgen05 .TMEM_NATIVE_LAYOUT , optimized = False ))
1193
1169
tcgen05 .commit_tmem ()
1194
- tmem .load (tcgen05 .TMEM_NATIVE_LAYOUT ).store_untiled (smem , optimized = False )
1195
- mgpu .commit_shared ()
1196
- ctx .async_copy (src_ref = smem , dst_ref = output )
1197
- ctx .await_async_copy (0 )
1170
+ tmem .load (tcgen05 .TMEM_NATIVE_LAYOUT ).store_untiled (output , optimized = False )
1198
1171
1199
1172
x = self .prng .uniform (- 1 , 1 , (128 , 128 )).astype (jax_dtype )
1200
- scratch_shape = [
1201
- jax .ShapeDtypeStruct (x .shape , jax_dtype ),
1202
- mgpu .TMABarrier (),
1203
- mgpu .TMEM (x .shape , jax_dtype , packing = packing ),
1204
- ]
1205
1173
y = mgpu .as_gpu_kernel (
1206
- kernel , (1 , 1 , 1 ), (128 , 1 , 1 ), x , x , scratch_shape
1174
+ kernel , (1 , 1 , 1 ), (128 , 1 , 1 ), x , x , mgpu . TMEM ( x . shape , jax_dtype , packing = packing )
1207
1175
)(x )
1208
1176
np .testing .assert_array_equal (x , y )
1209
1177
@@ -1214,34 +1182,16 @@ def kernel(ctx, input, output, scratch):
1214
1182
])
1215
1183
@jtu .thread_unsafe_test ()
1216
1184
def test_tmem_debug_print (self , jax_dtype , packing , expected ):
1217
- swizzle = 128
1218
- in_mlir_dtype = utils .dtype_to_ir_type (jax_dtype )
1219
- swizzle_elems = swizzle // bytewidth (in_mlir_dtype )
1220
- tiling = (8 , swizzle_elems )
1221
-
1222
- def kernel (ctx , input , output , scratch ):
1223
- smem , barrier , tmem = scratch
1224
- ctx .async_copy (
1225
- src_ref = input ,
1226
- dst_ref = smem ,
1227
- swizzle = swizzle ,
1228
- gmem_transform = mgpu .TileTransform (tiling ),
1229
- barrier = barrier ,
1230
- )
1231
- barrier .wait ()
1232
- tmem .store (fa .FragmentedArray .load_tiled (smem , swizzle , layout = tcgen05 .LAYOUT ))
1185
+ def kernel (ctx , input , output , tmem ):
1186
+ del ctx , output
1187
+ tmem .store (fa .FragmentedArray .load_untiled (input , layout = tcgen05 .LAYOUT , optimized = False ))
1233
1188
tcgen05 .commit_tmem ()
1234
1189
tmem .slice (slice (None ), slice (0 , 8 ))._debug_print ()
1235
1190
1236
1191
x = jnp .arange (128 * 128 , dtype = jax_dtype ).reshape (128 , 128 )
1237
- scratch_shape = [
1238
- jax .ShapeDtypeStruct (tile_shape (x .shape , tiling ), jax_dtype ),
1239
- mgpu .TMABarrier (),
1240
- mgpu .TMEM (x .shape , jax_dtype , packing = packing ),
1241
- ]
1242
1192
with self .capture_stdout () as stdout :
1243
1193
mgpu .as_gpu_kernel (
1244
- kernel , (1 , 1 , 1 ), (128 , 1 , 1 ), x , x , scratch_shape
1194
+ kernel , (1 , 1 , 1 ), (128 , 1 , 1 ), x , x , mgpu . TMEM ( x . shape , jax_dtype , packing = packing ),
1245
1195
)(x ).block_until_ready ()
1246
1196
self .assertIn ("[1, 2]: " + expected , stdout ())
1247
1197
@@ -1457,29 +1407,21 @@ def test_mma_lhs_tmem(self, m, n, in_jax_dtype, out_jax_dtype):
1457
1407
in_mlir_dtype = utils .dtype_to_ir_type (in_jax_dtype )
1458
1408
swizzle_elems = swizzle // bytewidth (in_mlir_dtype )
1459
1409
k = swizzle_elems * k_steps
1460
- lhs_tiling = rhs_tiling = (8 , swizzle_elems )
1410
+ rhs_tiling = (8 , swizzle_elems )
1461
1411
1462
1412
def kernel (ctx , lhs , rhs , out , scratch ):
1463
- lhs_smem , rhs_smem , barriers , mma_barrier , acc , lhs_tmem = scratch
1464
- ctx .async_copy (
1465
- src_ref = lhs ,
1466
- dst_ref = lhs_smem ,
1467
- swizzle = swizzle ,
1468
- gmem_transform = mgpu .TileTransform (lhs_tiling ),
1469
- barrier = barriers [0 ],
1470
- )
1413
+ rhs_smem , barrier , mma_barrier , acc , lhs_tmem = scratch
1471
1414
ctx .async_copy (
1472
1415
src_ref = rhs ,
1473
1416
dst_ref = rhs_smem ,
1474
1417
swizzle = swizzle ,
1475
1418
gmem_transform = mgpu .TileTransform (rhs_tiling ),
1476
- barrier = barriers [ 1 ] ,
1419
+ barrier = barrier ,
1477
1420
)
1478
- barriers [0 ].wait ()
1479
- barriers [1 ].wait ()
1421
+ barrier .wait ()
1480
1422
lhs_tmem .store (
1481
- fa .FragmentedArray .load_tiled (
1482
- lhs_smem , swizzle , layout = tcgen05 .LAYOUT
1423
+ fa .FragmentedArray .load_untiled (
1424
+ lhs , layout = tcgen05 .LAYOUT , optimized = False
1483
1425
)
1484
1426
)
1485
1427
tcgen05 .commit_tmem ()
@@ -1497,9 +1439,8 @@ def kernel(ctx, lhs, rhs, out, scratch):
1497
1439
y = self .prng .uniform (- 1 , 1 , y_shape ).astype (in_jax_dtype )
1498
1440
out_shape = jax .ShapeDtypeStruct ((m , n ), out_jax_dtype )
1499
1441
scratch_shape = [
1500
- jax .ShapeDtypeStruct (tile_shape (x_shape , lhs_tiling ), in_jax_dtype ),
1501
1442
jax .ShapeDtypeStruct (tile_shape (y_shape , rhs_tiling ), in_jax_dtype ),
1502
- mgpu .TMABarrier (2 ),
1443
+ mgpu .TMABarrier (),
1503
1444
mgpu .Barrier (1 ),
1504
1445
mgpu .TMEM ((128 , n ), out_jax_dtype ),
1505
1446
mgpu .TMEM ((128 , k ), in_jax_dtype , packing = 2 ),
@@ -3644,7 +3585,6 @@ def setUp(self):
3644
3585
def test_smem_registers_load_store (self , layout , dtype ):
3645
3586
def body (ctx , param : ir .Value , result : ir .Value , smem : list [ir .Value ]):
3646
3587
del ctx
3647
- [tmp_smem ] = smem
3648
3588
shape = ir .MemRefType (param .type ).shape
3649
3589
elt_type = ir .MemRefType (param .type ).element_type
3650
3590
@@ -3657,10 +3597,10 @@ def body(ctx, param: ir.Value, result: ir.Value, smem: list[ir.Value]):
3657
3597
reg = mgpu_dialect .layout_cast (reg , mgpu_layouts .to_layout_attr (layout ))
3658
3598
3659
3599
# Registers -> SMEM
3660
- vector .store (reg , tmp_smem , zero_vector_indices )
3600
+ vector .store (reg , smem , zero_vector_indices )
3661
3601
3662
3602
# SMEM -> Registers
3663
- reg = vector .load (vector_type , tmp_smem , zero_vector_indices )
3603
+ reg = vector .load (vector_type , smem , zero_vector_indices )
3664
3604
reg = mgpu_dialect .layout_cast (reg , mgpu_layouts .to_layout_attr (layout ))
3665
3605
3666
3606
# Registers -> GMEM
@@ -3674,7 +3614,7 @@ def body(ctx, param: ir.Value, result: ir.Value, smem: list[ir.Value]):
3674
3614
block = (128 , 1 , 1 ),
3675
3615
in_shape = jax_shape ,
3676
3616
out_shape = jax_shape ,
3677
- smem_scratch_shape = [ jax_shape ] ,
3617
+ smem_scratch_shape = jax_shape ,
3678
3618
thread_semantics = mgpu .LoweringSemantics .Warpgroup ,
3679
3619
)
3680
3620
@@ -4488,11 +4428,9 @@ def body(
4488
4428
):
4489
4429
# We need to have a result `x` otherwise the kernel will not be generated.
4490
4430
del ctx , x
4491
- [tmem ] = smem
4492
-
4493
4431
tmem_ref = mgpu_dialect .tmem_alloc (
4494
4432
result = tmem_type ,
4495
- smem_ptr = tmem ,
4433
+ smem_ptr = smem ,
4496
4434
collective = collective ,
4497
4435
packing = packing ,
4498
4436
)
@@ -4513,7 +4451,7 @@ def body(
4513
4451
block = (128 , 1 , 1 ),
4514
4452
in_shape = (),
4515
4453
out_shape = (jax .ShapeDtypeStruct ((), jnp .int32 ),),
4516
- smem_scratch_shape = [ jax .ShapeDtypeStruct ((), jnp .int32 )] ,
4454
+ smem_scratch_shape = jax .ShapeDtypeStruct ((), jnp .int32 ),
4517
4455
thread_semantics = mgpu .LoweringSemantics .Warpgroup ,
4518
4456
)()
4519
4457
0 commit comments