
Commit 8263fc4

apaszke authored and Google-ML-Automation committed
Avoid the unnecessary initial copy in the MGPU collective matmul
It seems to be quite expensive and is completely unnecessary. Ideally we'd have a way to select between scratch_ref and lhs_ref, but we have to peel the first iteration of the device loop until Pallas can represent it.

PiperOrigin-RevId: 802067411
1 parent 59789e6 commit 8263fc4
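
To make the new control flow easier to follow, here is a minimal, self-contained sketch of the loop-peeling pattern this change applies: the first device step reads the local shard directly from lhs_ref, and each later step reads the scratch slot filled by the previous step, so the upfront copy (and one scratch slot) can be dropped. All names below (device_loop_peeled, scratch_slots, step) are illustrative only, not the actual Pallas kernel API.

# Illustrative Python-only sketch of peeling the first device-loop iteration;
# the names here are hypothetical and do not come from the Pallas kernel.
def device_loop_peeled(lhs_ref, scratch_slots, num_devices, step):
  # Peeled first iteration: read the local shard straight from lhs_ref
  # instead of copying it into the scratch buffer first.
  step(source=lhs_ref, next_scratch_slot=0, device_offset=0)
  # Remaining iterations: each step consumes the scratch slot that the
  # previous step's send filled, so only num_devices - 1 slots are needed.
  for device_offset in range(1, num_devices):
    step(
        source=scratch_slots[device_offset - 1],
        next_scratch_slot=device_offset,
        device_offset=device_offset,
    )

# Example: record which source each step reads from.
trace = []
device_loop_peeled(
    lhs_ref="lhs_ref",
    scratch_slots=[f"scratch[{i}]" for i in range(3)],
    num_devices=4,
    step=lambda source, next_scratch_slot, device_offset: trace.append(source),
)
assert trace == ["lhs_ref", "scratch[0]", "scratch[1]", "scratch[2]"]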

2 files changed (+17 -16 lines)

jax/experimental/pallas/ops/gpu/collective_matmul_mgpu.py
Lines changed: 14 additions & 15 deletions

@@ -72,6 +72,10 @@ def all_gather_lhs_matmul(
         f"{n_shard_per_sm_n=} must be divisible by {block_n=}"
     )
   num_sms_m = max_num_sms // sm_n_tile
+  if num_sms_m < (m_shard // block_m) and sm_n_tile > 1:
+    # We never synchronize the N SMs across the different steps of the M
+    # loop, so they can start overwriting each other's data.
+    raise NotImplementedError("The kernel has races when M is large and sm_n_tile > 1")
 
   swizzle = min(
       plgpu.find_swizzle(block_k * jnp.finfo(element_type).bits, "lhs"),
@@ -99,22 +103,12 @@ def _m_loop(idx):
       (mi,) = idx
       m_tile_slice = pl.ds(mi * block_m, block_m)
 
-      # For some reason ptxas spills if we unroll the loop over k
-      copy_block = 32
-      @pl.loop(0, k, step=copy_block)
-      def _k_copy_loop(ki):
-        k_slice = pl.ds(ki, copy_block)
-        scratch_ref[0, :, k_slice] = lhs_ref[m_tile_slice, k_slice]
-
-      @pl.loop(0, num_devices)
-      def _device_loop(device_offset):
+      def device_step(lhs_source_ref, next_scratch_slot, device_offset):
+        # Loop invariant: lhs_source_ref is ready to be used
         device_m_slice = pl.ds(
             lax.rem(device_offset + dev_id, num_devices) * m_shard, block_m
         )
 
-        scratch_slot = device_offset
-        next_scratch_slot = scratch_slot + 1
-
         def compute(n_tile_slice, send: bool):
           @functools.partial(
               pl.run_scoped, acc_ref=plgpu.ACC((block_m, block_n))
@@ -143,7 +137,7 @@ def k_loop(idxs, lhs_smem, rhs_smem):
              plgpu.wgmma(acc_ref, lhs_smem, rhs_smem)
              if send:
                # TODO(giorgioa): Send only for first sm_n.
-               @pl.when(next_scratch_slot <= num_devices - 1)
+               @pl.when(next_scratch_slot < num_devices - 1)
                def _():
                  (ki,) = idxs
                  k_slice = pl.ds(ki * block_k, block_k)
@@ -153,7 +147,7 @@ def _():
            # We only delay release by 1 step, so we need to wait for the
            # previous copies.
            plgpu.wait_smem_to_gmem(1, wait_read_only=True)
-           k_loop(scratch_ref.at[scratch_slot], rhs_ref.at[..., n_tile_slice])
+           k_loop(lhs_source_ref, rhs_ref.at[..., n_tile_slice])
            if send:
              # Make sure the copy is done and signal the receiving device.
              plgpu.wait_smem_to_gmem(0, wait_read_only=False)
@@ -176,6 +170,11 @@ def _n_loop(ni):
          # Wait for the next scratch to arrive --- see the device loop invariant.
          pl.semaphore_wait(received_sem)
 
+      device_step(lhs_ref.at[m_tile_slice], 0, 0)
+      @pl.loop(1, num_devices)
+      def _device_loop(device_offset):
+        device_step(scratch_ref.at[device_offset - 1], device_offset, device_offset)
+
       # Make sure all copies are fully done.
       plgpu.wait_smem_to_gmem(0, wait_read_only=True)
 
@@ -185,7 +184,7 @@ def _n_loop(ni):
           # The output, with its M dimension all-gathered.
           jax.ShapeDtypeStruct((axis_size * m_shard, n_shard), dtype),
           # The scratch buffer used for the all-gather.
-          jax.ShapeDtypeStruct((num_sms_m, num_devices, block_m, k), dtype),
+          jax.ShapeDtypeStruct((num_sms_m, num_devices - 1, block_m, k), dtype),
       ],
       scratch_shapes=[
           plgpu.SMEM((block_m, block_n), dtype, transforms=transforms),
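
A note on the two numeric changes visible in this file: the scratch buffer now needs only num_devices - 1 slots because the local shard is read straight from lhs_ref, and the new guard rejects configurations where the M loop has more tiles than M SMs while sm_n_tile > 1 (the same condition the test below now skips, using 132 as its SM count). Here is a small illustrative restatement of that guard; the example shapes are not from the source.

# Illustrative restatement of the race guard added above. The value 132 matches
# the constant used in the test and is assumed to be the target GPU's SM count.
def has_race(max_num_sms, sm_n_tile, m_shard, block_m):
  num_sms_m = max_num_sms // sm_n_tile
  return num_sms_m < (m_shard // block_m) and sm_n_tile > 1

# With 132 SMs split across 2 N tiles, only 66 SMs work on M tiles, so an
# M shard spanning 67 block_m tiles would race; with sm_n_tile == 1 it is fine.
assert has_race(max_num_sms=132, sm_n_tile=2, m_shard=128 * 67, block_m=128)
assert not has_race(max_num_sms=132, sm_n_tile=1, m_shard=128 * 67, block_m=128)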

tests/pallas/mgpu_collective_matmul_test.py
Lines changed: 3 additions & 1 deletion

@@ -100,6 +100,8 @@ def test_all_gather_lhs_matmul(
     )
     if m_shard % block_m:
       self.skipTest("m_shard must be divisible by block_m for now.")
+    if (132 // sm_n_tile) < m_shard // block_m and sm_n_tile > 1:
+      self.skipTest("The kernel has races when M is large and sm_n_tile > 1")
 
     k1, k2 = random.split(random.key(1234), num=2)
     lhs = random.normal(k1, (num_devices * m_shard, k), dtype)
@@ -118,6 +120,7 @@ def run(body):
       )(out)
       return out
 
+    ref_out = run(lambda x, y: lax.all_gather(x, "x", axis=0, tiled=True) @ y)
     out = run(
         functools.partial(
             collective_matmul_mgpu.all_gather_lhs_matmul,
@@ -130,7 +133,6 @@ def run(body):
             dtype=dtype,
         )
     )
-    ref_out = run(lambda x, y: lax.all_gather(x, "x", axis=0, tiled=True) @ y)
     np.testing.assert_allclose(out, ref_out)
