Skip to content

Commit 1178809

Browse files
justinjfuGoogle-ML-Automation
authored andcommitted
[Mosaic GPU] Account for squeezed dims in partitioned axis of copy_smem_to_gmem.
PiperOrigin-RevId: 800628443
1 parent 1b24617 commit 1178809

File tree

3 files changed

+15
-4
lines changed

3 files changed

+15
-4
lines changed

jax/experimental/mosaic/gpu/launch_context.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,9 @@ def async_copy(
788788
if gmem_ref is dst_ref:
789789
raise ValueError("Only GMEM -> SMEM copies can be collective")
790790
if partitioned is not None:
791+
# Increment partitioned by the number of preceding squeezed dimensions.
792+
partitioned = np.where(
793+
np.cumsum(~np.array(is_squeezed)) == partitioned+1)[0][0]
791794
# Partitioning happens on the logical slice we extract from GMEM, so we do
792795
# it before we apply transforms.
793796
if collective is None: # This implies non-gather TMA already.

jax/experimental/pallas/ops/gpu/blackwell_ragged_dot_mgpu.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,7 @@ def _():
174174
b_gmem.at[slice_k, slice_n],
175175
b_smem.at[slot],
176176
b_tma_barrier.at[slot],
177-
# TODO: partitioned_axis doesn't account for squeezed dims so we have 2 instead of 1 here.
178-
partitioned_axis=2 if collective else None,
177+
partitioned_axis=1 if collective else None,
179178
collective_axes=collective_axis,
180179
)
181180
lax.fori_loop(0, k_iters, _loop_body, None)

tests/pallas/mosaic_gpu_test.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3947,18 +3947,25 @@ def kernel(a_smem, b_smem, out_ref, acc_tmem, scratch_smem, barrier_ref):
39473947
expected = x @ y
39483948
np.testing.assert_allclose(result, expected, rtol=1e-3)
39493949

3950-
@parameterized.parameters((True,), (False,))
3951-
def test_copy_gmem_to_smem_partitioned(self, warp_level):
3950+
@parameterized.product(
3951+
warp_level=(True, False),
3952+
squeezed_index=(True, False),
3953+
)
3954+
def test_copy_gmem_to_smem_partitioned(self, warp_level, squeezed_index):
39523955
self.skip_if_wg_semantics()
39533956
block_size = (128, 128)
39543957
partitioned_block_size = (block_size[0] // 2, block_size[1])
39553958
a = jax.random.uniform(
39563959
jax.random.key(0), shape=block_size, dtype=jnp.float32)
3960+
if squeezed_index:
3961+
a = a.reshape(1, *block_size)
39573962
b = jax.random.uniform(
39583963
jax.random.key(1), shape=block_size, dtype=jnp.float32)
39593964
def kernel(a_gmem, b_gmem, out_gmem,
39603965
a_smem, b_smem, out_smem,
39613966
a_tma_barrier, b_tma_barrier, cluster_barrier):
3967+
if squeezed_index:
3968+
a_gmem = a_gmem.at[0]
39623969
cluster_idx = lax.axis_index("x")
39633970
out_slice = pl.ds(cluster_idx * partitioned_block_size[0],
39643971
partitioned_block_size[0])
@@ -4024,6 +4031,8 @@ def _():
40244031
),
40254032
)
40264033
result = f(a, b)
4034+
if squeezed_index:
4035+
a = a[0]
40274036
np.testing.assert_array_equal(result, a + b)
40284037

40294038
def test_arrive_wait_on_tc_barrier(self):

0 commit comments

Comments
 (0)