@@ -82,19 +82,14 @@ def all_gather_lhs_matmul(
       plgpu.SwizzleTransform(swizzle),
   )

-  def kernel_body(lhs_ref, rhs_ref, out_ref, scratch_ref, capacity_sem, received_sem):
+  def kernel_body(lhs_ref, rhs_ref, out_ref, scratch_ref, received_sem):
     sm_m = lax.axis_index('sm_m')
     sm_n = lax.axis_index('sm_n')
     n_start = sm_n * n_shard_per_sm_n
     scratch_ref = scratch_ref.at[sm_m]

     dev_id = lax.axis_index(axis_name)
     send_dev_id = lax.rem(dev_id + axis_size - 1, axis_size)
-    recv_dev_id = lax.rem(dev_id + 1, axis_size)
-    # NOTE: Technically we should signal the recv_dev_id (and our signal would
-    # be received from send_dev_id), but if everyone signals in a ring after a
-    # barrier then it's equivalent to a local signal.
-    pl.semaphore_signal(capacity_sem)
     send_scratch_ref = plgpu.remote_ref(
         scratch_ref, send_dev_id, device_id_type=pl.DeviceIdType.LOGICAL
     )
@@ -118,13 +113,8 @@ def _device_loop(device_offset):
       )
       n_tile_slice = pl.ds(n_start, block_n)

-      # Loop invariant: scratch_ref.at[scratch_slot] is ready to be used
-      # We're double buffering the scratch space. At each step, we read from
-      # scratch_ref.at[scratch_slot] and write to scratch_ref.at[next_scratch_slot]
-      # located on the send_dev_id. We swap the slots after completing a step,
-      # which lets us overlap the copy with compute.
-      scratch_slot = lax.rem(device_offset, 2)
-      next_scratch_slot = 1 - scratch_slot
+      scratch_slot = device_offset
+      next_scratch_slot = scratch_slot + 1

       out_smem = plgpu.SMEM((block_m, block_n), dtype, transforms=transforms)

@@ -134,7 +124,6 @@ def _device_loop(device_offset):
           out_smem=out_smem,
       )
       def _(acc_ref, out_smem):
-        pl.semaphore_wait(capacity_sem)
         @functools.partial(
             plgpu.emit_pipeline,
             grid=(k // block_k,),
@@ -148,7 +137,7 @@ def _(acc_ref, out_smem):
         def k_loop(idxs, lhs_smem, rhs_smem):
           plgpu.wgmma(acc_ref, lhs_smem, rhs_smem)
           # TODO(giorgioa): Send only for first sm_n.
-          @pl.when(device_offset < num_devices - 1)
+          @pl.when(next_scratch_slot <= num_devices - 1)
           def _():
             (ki,) = idxs
             k_slice = pl.ds(ki * block_k, block_k)
@@ -161,11 +150,7 @@ def _():
         k_loop(scratch_ref.at[scratch_slot], rhs_ref.at[..., n_tile_slice])
         # Make sure the copy is fully done.
         plgpu.wait_smem_to_gmem(0, wait_read_only=False)
-        # The order of signals doesn't matter here.
-        plgpu.semaphore_signal_parallel(
-            plgpu.SemaphoreSignal(capacity_sem, device_id=recv_dev_id),
-            plgpu.SemaphoreSignal(received_sem, device_id=send_dev_id),
-        )
+        pl.semaphore_signal(received_sem, device_id=send_dev_id)
         # Make sure all TMAs have read SMEM before we overwrite it.
         plgpu.wait_smem_to_gmem(0, wait_read_only=True)
         out_smem[...] = acc_ref[...].astype(out_smem.dtype)
@@ -214,15 +199,12 @@ def k_loop(_, lhs_smem, rhs_smem):
   result, _ = plgpu.kernel(
       kernel_body,
       out_shape=[
-          # Out_ref. Stores full M computed in a collective way across devices.
+          # The output, with its M dimension all-gathered.
           jax.ShapeDtypeStruct((axis_size * m_shard, n_shard), dtype),
-          # Scratch_ref. Used to buffer (2 * `block_m`) rows (because of double
-          # buffering) of the lhs per sm_m. Accessible remotely by previous and
-          # next devices.
-          jax.ShapeDtypeStruct((num_sms_m, 2, block_m, k), dtype),
+          # The scratch buffer used for the all-gather.
+          jax.ShapeDtypeStruct((num_sms_m, num_devices, block_m, k), dtype),
       ],
       scratch_shapes=[
-          plgpu.SemaphoreType.REGULAR,  # Capacity semaphore
           plgpu.SemaphoreType.REGULAR,  # Received semaphore
       ],
       grid=(num_sms_m, sm_n_tile),
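The diff replaces the double-buffered scratch (two slots guarded by a capacity semaphore) with one scratch slot per device, so each slot is written exactly once and only the `received_sem` handshake remains. Below is a minimal, serial NumPy sketch of that ring schedule, using hypothetical names (`num_devices`, `shard_len`, `shards`, `scratch`); it mirrors the slot indexing and the `send_dev_id` neighbor computation from the kernel, but it is not Pallas code and elides all synchronization.

```python
import numpy as np

# Serial illustration of the ring all-gather slot schedule from the kernel:
# each device starts with its own LHS shard in slot 0 and, at step
# `device_offset`, forwards the block it just consumed into slot
# `device_offset + 1` on the previous device in the ring.
num_devices = 4   # plays the role of axis_size
shard_len = 2
shards = [np.full(shard_len, d) for d in range(num_devices)]

# scratch[dev][slot] stands in for scratch_ref.at[sm_m, slot] on device `dev`.
scratch = [[None] * num_devices for _ in range(num_devices)]
for dev in range(num_devices):
  scratch[dev][0] = shards[dev]

for device_offset in range(num_devices):
  scratch_slot = device_offset          # slot read by this step's matmul
  next_scratch_slot = scratch_slot + 1  # slot written on the neighbor
  for dev in range(num_devices):
    block = scratch[dev][scratch_slot]
    if next_scratch_slot <= num_devices - 1:  # same guard as the pl.when above
      send_dev = (dev + num_devices - 1) % num_devices  # send_dev_id
      scratch[send_dev][next_scratch_slot] = block      # the "remote" copy

# Slot s on device d now holds the shard originally owned by device
# (d + s) % num_devices, i.e. every device has gathered every shard.
for dev in range(num_devices):
  for slot in range(num_devices):
    assert scratch[dev][slot][0] == (dev + slot) % num_devices
```

Since no slot is ever reused, there is nothing for a capacity semaphore to guard; the apparent trade-off is a larger scratch allocation (`num_devices` rather than 2 slots of `block_m` x `k` per SM) in exchange for removing the capacity-semaphore wait from the compute path.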