
Commit 07d01b4

apaszke authored and Google-ML-Automation committed
[Pallas:MGPU] Allocate a unique buffer per device in LHS all gather matmul
This simplifies the kernel and removes synchronization points that seem to be hurting performance. The primary benefit we want from collective matmuls is reduced runtime, not reduced memory usage.

PiperOrigin-RevId: 800060504
1 parent 0bad246 commit 07d01b4
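To make the change concrete, here is a minimal sketch of the two scratch-slot indexing schemes in plain Python. This is an illustration of the idea only, not the actual Pallas kernel; `device_offset` counts steps around the ring, as in the diff below.

# Old scheme: two scratch slots reused around the ring. A slot may still
# hold live data when the next remote copy wants to overwrite it, so a
# capacity semaphore had to be waited on before each copy.
def old_slots(device_offset):
    scratch_slot = device_offset % 2      # slot we read/compute from
    next_scratch_slot = 1 - scratch_slot  # slot we write on the next device
    return scratch_slot, next_scratch_slot

# New scheme: one scratch slot per device. Every slot is written at most
# once, so the capacity semaphore and its waits disappear entirely.
def new_slots(device_offset):
    scratch_slot = device_offset           # unique per step
    next_scratch_slot = device_offset + 1  # also unique
    return scratch_slot, next_scratch_slot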

File tree

1 file changed (+8, -26 lines)

jax/experimental/pallas/ops/gpu/collective_matmul_mgpu.py

Lines changed: 8 additions & 26 deletions
@@ -82,19 +82,14 @@ def all_gather_lhs_matmul(
       plgpu.SwizzleTransform(swizzle),
   )
 
-  def kernel_body(lhs_ref, rhs_ref, out_ref, scratch_ref, capacity_sem, received_sem):
+  def kernel_body(lhs_ref, rhs_ref, out_ref, scratch_ref, received_sem):
     sm_m = lax.axis_index('sm_m')
     sm_n = lax.axis_index('sm_n')
     n_start = sm_n * n_shard_per_sm_n
     scratch_ref = scratch_ref.at[sm_m]
 
     dev_id = lax.axis_index(axis_name)
     send_dev_id = lax.rem(dev_id + axis_size - 1, axis_size)
-    recv_dev_id = lax.rem(dev_id + 1, axis_size)
-    # NOTE: Technically we should signal the recv_dev_id (and our signal would
-    # be received from send_dev_id), but if everyone signals in a ring after a
-    # barrier then it's equivalent to a local signal.
-    pl.semaphore_signal(capacity_sem)
     send_scratch_ref = plgpu.remote_ref(
         scratch_ref, send_dev_id, device_id_type=pl.DeviceIdType.LOGICAL
     )
@@ -118,13 +113,8 @@ def _device_loop(device_offset):
       )
       n_tile_slice = pl.ds(n_start, block_n)
 
-      # Loop invariant: scratch_ref.at[scratch_slot] is ready to be used
-      # We're double buffering the scratch space. At each step, we read from
-      # scratch_ref.at[scratch_slot] and write to scratch_ref.at[next_scratch_slot]
-      # located on the send_dev_id. We swap the slots after completing a step,
-      # which lets us overlap the copy with compute.
-      scratch_slot = lax.rem(device_offset, 2)
-      next_scratch_slot = 1 - scratch_slot
+      scratch_slot = device_offset
+      next_scratch_slot = scratch_slot + 1
 
       out_smem = plgpu.SMEM((block_m, block_n), dtype, transforms=transforms)
 
@@ -134,7 +124,6 @@ def _device_loop(device_offset):
           out_smem=out_smem,
       )
       def _(acc_ref, out_smem):
-        pl.semaphore_wait(capacity_sem)
         @functools.partial(
             plgpu.emit_pipeline,
             grid=(k // block_k,),
@@ -148,7 +137,7 @@ def _(acc_ref, out_smem):
         def k_loop(idxs, lhs_smem, rhs_smem):
           plgpu.wgmma(acc_ref, lhs_smem, rhs_smem)
           # TODO(giorgioa): Send only for first sm_n.
-          @pl.when(device_offset < num_devices - 1)
+          @pl.when(next_scratch_slot <= num_devices - 1)
           def _():
             (ki,) = idxs
             k_slice = pl.ds(ki * block_k, block_k)
@@ -161,11 +150,7 @@ def _():
         k_loop(scratch_ref.at[scratch_slot], rhs_ref.at[..., n_tile_slice])
         # Make sure the copy is fully done.
         plgpu.wait_smem_to_gmem(0, wait_read_only=False)
-        # The order of signals doesn't matter here.
-        plgpu.semaphore_signal_parallel(
-            plgpu.SemaphoreSignal(capacity_sem, device_id=recv_dev_id),
-            plgpu.SemaphoreSignal(received_sem, device_id=send_dev_id),
-        )
+        pl.semaphore_signal(received_sem, device_id=send_dev_id)
         # Make sure all TMAs have read SMEM before we overwrite it.
         plgpu.wait_smem_to_gmem(0, wait_read_only=True)
         out_smem[...] = acc_ref[...].astype(out_smem.dtype)
@@ -214,15 +199,12 @@ def k_loop(_, lhs_smem, rhs_smem):
   result, _ = plgpu.kernel(
       kernel_body,
       out_shape=[
-          # Out_ref. Stores full M computed in a collective way across devices.
+          # The output, with its M dimension all-gathered.
           jax.ShapeDtypeStruct((axis_size * m_shard, n_shard), dtype),
-          # Scratch_ref. Used to buffer (2 * `block_m`) rows (because of double
-          # buffering) of the lhs per sm_m. Accessible remotely by previous and
-          # next devices.
-          jax.ShapeDtypeStruct((num_sms_m, 2, block_m, k), dtype),
+          # The scratch buffer used for the all-gather.
+          jax.ShapeDtypeStruct((num_sms_m, num_devices, block_m, k), dtype),
       ],
       scratch_shapes=[
-          plgpu.SemaphoreType.REGULAR,  # Capacity semaphore
           plgpu.SemaphoreType.REGULAR,  # Received semaphore
       ],
       grid=(num_sms_m, sm_n_tile),