Commit a5f2a26

apaszke authored and Google-ML-Automation committed
[NFC] Refactor the collective matmul
1. Make sure out_smem is allocated for the duration of the kernel. The
   previous code didn't fully await its usage before its scoped allocations
   expired, which is UB.
2. Deduplicate the bodies of all N loop steps (we're peeling off the first
   step, since it's the only one that does comms).

PiperOrigin-RevId: 801825966
1 parent 4dfc924 · commit a5f2a26
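Point 2 of the commit message is the classic loop-peeling deduplication: the shared body becomes one function parameterized by a `send` flag, the first step is called directly with `send=True` (it is the only one that communicates), and the remaining steps loop with `send=False`. Below is a minimal sketch of the pattern in plain Python; `run_steps` and `compute` are hypothetical names used only for illustration, not taken from the kernel.

# A minimal sketch of the peeling/dedup pattern; names are hypothetical.
def run_steps(num_steps: int) -> list[str]:
  log = []

  def compute(step: int, send: bool):
    log.append(f"compute step {step}")  # work shared by every step
    if send:
      log.append("send to peer")  # comms happen only on the peeled first step

  compute(0, send=True)  # peeled first step: compute + comms
  for step in range(1, num_steps):  # remaining steps reuse the same body
    compute(step, send=False)
  return log

assert run_steps(3) == [
    "compute step 0", "send to peer", "compute step 1", "compute step 2",
]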

1 file changed: +30 -57 lines changed

jax/experimental/pallas/ops/gpu/collective_matmul_mgpu.py

@@ -82,7 +82,7 @@ def all_gather_lhs_matmul(
       plgpu.SwizzleTransform(swizzle),
   )
 
-  def kernel_body(lhs_ref, rhs_ref, out_ref, scratch_ref, received_sem):
+  def kernel_body(lhs_ref, rhs_ref, out_ref, scratch_ref, out_smem, received_sem):
     sm_m = lax.axis_index('sm_m')
     sm_n = lax.axis_index('sm_n')
     n_start = sm_n * n_shard_per_sm_n
@@ -111,66 +111,15 @@ def _device_loop(device_offset):
       device_m_slice = pl.ds(
           lax.rem(device_offset + dev_id, num_devices) * m_shard, block_m
       )
-      n_tile_slice = pl.ds(n_start, block_n)
 
       scratch_slot = device_offset
       next_scratch_slot = scratch_slot + 1
 
-      out_smem = plgpu.SMEM((block_m, block_n), dtype, transforms=transforms)
-
-      @functools.partial(
-          pl.run_scoped,
-          acc_ref=plgpu.ACC((block_m, block_n)),
-          out_smem=out_smem,
-      )
-      def _(acc_ref, out_smem):
-        @functools.partial(
-            plgpu.emit_pipeline,
-            grid=(k // block_k,),
-            in_specs=[
-                plgpu.BlockSpec((block_m, block_k), lambda k: (0, k), transforms=transforms),
-                plgpu.BlockSpec((block_k, block_n), lambda k: (k, 0), transforms=transforms),
-            ],
-            max_concurrent_steps=max_concurrent_steps,
-            delay_release=1,
-        )
-        def k_loop(idxs, lhs_smem, rhs_smem):
-          plgpu.wgmma(acc_ref, lhs_smem, rhs_smem)
-          # TODO(giorgioa): Send only for first sm_n.
-          @pl.when(next_scratch_slot <= num_devices - 1)
-          def _():
-            (ki,) = idxs
-            k_slice = pl.ds(ki * block_k, block_k)
-            plgpu.copy_smem_to_gmem(
-                lhs_smem, send_scratch_ref.at[next_scratch_slot, :, k_slice]
-            )
-            # We only delay release by 1 step, so we need to wait for the
-            # previous copies.
-            plgpu.wait_smem_to_gmem(1, wait_read_only=True)
-        k_loop(scratch_ref.at[scratch_slot], rhs_ref.at[..., n_tile_slice])
-        # Make sure the copy is fully done.
-        plgpu.wait_smem_to_gmem(0, wait_read_only=False)
-        pl.semaphore_signal(received_sem, device_id=send_dev_id)
-        # Make sure all TMAs have read SMEM before we overwrite it.
-        plgpu.wait_smem_to_gmem(0, wait_read_only=True)
-        out_smem[...] = acc_ref[...].astype(out_smem.dtype)
-        plgpu.commit_smem()
-        plgpu.copy_smem_to_gmem(
-            out_smem,
-            out_ref.at[device_m_slice, n_tile_slice].at[m_tile_slice],
-        )
-
-      @pl.loop(1, n_shard_per_sm_n // block_n)
-      def _n_loop(ni):
-        n_tile_slice = pl.ds(n_start + ni * block_n, block_n)
-
+      def compute(n_tile_slice, send: bool):
         @functools.partial(
-            pl.run_scoped,
-            acc_ref=plgpu.ACC((block_m, block_n)),
-            out_smem=out_smem,
+            pl.run_scoped, acc_ref=plgpu.ACC((block_m, block_n))
         )
-        def _(acc_ref, out_smem):
-
+        def _(acc_ref):
           @functools.partial(
               plgpu.emit_pipeline,
               grid=(k // block_k,),
@@ -190,17 +139,40 @@ def _(acc_ref, out_smem):
              ],
              max_concurrent_steps=max_concurrent_steps,
          )
-          def k_loop(_, lhs_smem, rhs_smem):
+          def k_loop(idxs, lhs_smem, rhs_smem):
            plgpu.wgmma(acc_ref, lhs_smem, rhs_smem)
+            if send:
+              # TODO(giorgioa): Send only for first sm_n.
+              @pl.when(next_scratch_slot <= num_devices - 1)
+              def _():
+                (ki,) = idxs
+                k_slice = pl.ds(ki * block_k, block_k)
+                plgpu.copy_smem_to_gmem(
+                    lhs_smem, send_scratch_ref.at[next_scratch_slot, :, k_slice]
+                )
+                # We only delay release by 1 step, so we need to wait for the
+                # previous copies.
+                plgpu.wait_smem_to_gmem(1, wait_read_only=True)
           k_loop(scratch_ref.at[scratch_slot], rhs_ref.at[..., n_tile_slice])
+          if send:
+            # Make sure the copy is done and signal the receiving device.
+            plgpu.wait_smem_to_gmem(0, wait_read_only=False)
+            pl.semaphore_signal(received_sem, device_id=send_dev_id)
           # Make sure all TMAs have read SMEM before we overwrite it.
           plgpu.wait_smem_to_gmem(0, wait_read_only=True)
           out_smem[...] = acc_ref[...].astype(out_smem.dtype)
           plgpu.commit_smem()
           plgpu.copy_smem_to_gmem(
-              out_smem, out_ref.at[device_m_slice, n_tile_slice].at[m_tile_slice]
+              out_smem,
+              out_ref.at[device_m_slice, n_tile_slice].at[m_tile_slice],
           )
 
+      compute(pl.ds(n_start, block_n), send=True)
+
+      @pl.loop(1, n_shard_per_sm_n // block_n)
+      def _n_loop(ni):
+        compute(pl.ds(n_start + ni * block_n, block_n), send=False)
+
       # Wait for the next scratch to arrive --- see the device loop invariant.
       pl.semaphore_wait(received_sem)
 
@@ -216,6 +188,7 @@ def k_loop(_, lhs_smem, rhs_smem):
           jax.ShapeDtypeStruct((num_sms_m, num_devices, block_m, k), dtype),
       ],
       scratch_shapes=[
+          plgpu.SMEM((block_m, block_n), dtype, transforms=transforms),
           plgpu.SemaphoreType.REGULAR,  # Received semaphore
       ],
       grid=(num_sms_m, sm_n_tile),
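On point 1: the fix moves out_smem from a per-iteration pl.run_scoped allocation into scratch_shapes, so the buffer now lives as long as the kernel, and the async SMEM-to-GMEM copies issued from it can no longer outlive their backing memory. Below is a much-simplified sketch of that allocation pattern, assuming the Mosaic GPU Pallas APIs that appear in the diff (plgpu.kernel, plgpu.SMEM, plgpu.copy_smem_to_gmem); the kernel itself is hypothetical and unrelated to the matmul.

import jax
import jax.numpy as jnp
from jax.experimental.pallas import mosaic_gpu as plgpu


def double(x_ref, o_ref, out_smem, barrier):
  # out_smem comes from scratch_shapes: allocated once for the whole kernel,
  # so a pending TMA read from it cannot dangle the way a pl.run_scoped
  # buffer can if the scope exits before the copy is awaited.
  plgpu.copy_gmem_to_smem(x_ref, out_smem, barrier)
  plgpu.barrier_wait(barrier)
  out_smem[...] = out_smem[...] * 2.0
  plgpu.commit_smem()
  plgpu.copy_smem_to_gmem(out_smem, o_ref)
  plgpu.wait_smem_to_gmem(0)  # await the copy while out_smem is still alive


run = plgpu.kernel(
    double,
    out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float32),
    scratch_shapes=[
        plgpu.SMEM((128, 128), jnp.float32),
        plgpu.Barrier(),  # arrival barrier for the GMEM->SMEM copy
    ],
)
x = jnp.ones((128, 128), jnp.float32)
y = run(x)  # requires a GPU supported by the Mosaic GPU backend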
