@@ -82,7 +82,7 @@ def all_gather_lhs_matmul(
       plgpu.SwizzleTransform(swizzle),
   )
 
-  def kernel_body(lhs_ref, rhs_ref, out_ref, scratch_ref, received_sem):
+  def kernel_body(lhs_ref, rhs_ref, out_ref, scratch_ref, out_smem, received_sem):
     sm_m = lax.axis_index('sm_m')
     sm_n = lax.axis_index('sm_n')
     n_start = sm_n * n_shard_per_sm_n
@@ -111,66 +111,15 @@ def _device_loop(device_offset):
       device_m_slice = pl.ds(
           lax.rem(device_offset + dev_id, num_devices) * m_shard, block_m
       )
-      n_tile_slice = pl.ds(n_start, block_n)
 
       scratch_slot = device_offset
       next_scratch_slot = scratch_slot + 1
 
-      out_smem = plgpu.SMEM((block_m, block_n), dtype, transforms=transforms)
-
-      @functools.partial(
-          pl.run_scoped,
-          acc_ref=plgpu.ACC((block_m, block_n)),
-          out_smem=out_smem,
-      )
-      def _(acc_ref, out_smem):
-        @functools.partial(
-            plgpu.emit_pipeline,
-            grid=(k // block_k,),
-            in_specs=[
-                plgpu.BlockSpec((block_m, block_k), lambda k: (0, k), transforms=transforms),
-                plgpu.BlockSpec((block_k, block_n), lambda k: (k, 0), transforms=transforms),
-            ],
-            max_concurrent_steps=max_concurrent_steps,
-            delay_release=1,
-        )
-        def k_loop(idxs, lhs_smem, rhs_smem):
-          plgpu.wgmma(acc_ref, lhs_smem, rhs_smem)
-          # TODO(giorgioa): Send only for first sm_n.
-          @pl.when(next_scratch_slot <= num_devices - 1)
-          def _():
-            (ki,) = idxs
-            k_slice = pl.ds(ki * block_k, block_k)
-            plgpu.copy_smem_to_gmem(
-                lhs_smem, send_scratch_ref.at[next_scratch_slot, :, k_slice]
-            )
-            # We only delay release by 1 step, so we need to wait for the
-            # previous copies.
-            plgpu.wait_smem_to_gmem(1, wait_read_only=True)
-        k_loop(scratch_ref.at[scratch_slot], rhs_ref.at[..., n_tile_slice])
-        # Make sure the copy is fully done.
-        plgpu.wait_smem_to_gmem(0, wait_read_only=False)
-        pl.semaphore_signal(received_sem, device_id=send_dev_id)
-        # Make sure all TMAs have read SMEM before we overwrite it.
-        plgpu.wait_smem_to_gmem(0, wait_read_only=True)
-        out_smem[...] = acc_ref[...].astype(out_smem.dtype)
-        plgpu.commit_smem()
-        plgpu.copy_smem_to_gmem(
-            out_smem,
-            out_ref.at[device_m_slice, n_tile_slice].at[m_tile_slice],
-        )
-
-      @pl.loop(1, n_shard_per_sm_n // block_n)
-      def _n_loop(ni):
-        n_tile_slice = pl.ds(n_start + ni * block_n, block_n)
-
+      def compute(n_tile_slice, send: bool):
         @functools.partial(
-            pl.run_scoped,
-            acc_ref=plgpu.ACC((block_m, block_n)),
-            out_smem=out_smem,
+            pl.run_scoped, acc_ref=plgpu.ACC((block_m, block_n))
         )
-        def _(acc_ref, out_smem):
-
+        def _(acc_ref):
           @functools.partial(
               plgpu.emit_pipeline,
               grid=(k // block_k,),
@@ -190,17 +139,40 @@ def _(acc_ref, out_smem):
               ],
               max_concurrent_steps=max_concurrent_steps,
           )
-          def k_loop(_, lhs_smem, rhs_smem):
+          def k_loop(idxs, lhs_smem, rhs_smem):
             plgpu.wgmma(acc_ref, lhs_smem, rhs_smem)
+            if send:
+              # TODO(giorgioa): Send only for first sm_n.
+              @pl.when(next_scratch_slot <= num_devices - 1)
+              def _():
+                (ki,) = idxs
+                k_slice = pl.ds(ki * block_k, block_k)
+                plgpu.copy_smem_to_gmem(
+                    lhs_smem, send_scratch_ref.at[next_scratch_slot, :, k_slice]
+                )
+                # We only delay release by 1 step, so we need to wait for the
+                # previous copies.
+                plgpu.wait_smem_to_gmem(1, wait_read_only=True)
           k_loop(scratch_ref.at[scratch_slot], rhs_ref.at[..., n_tile_slice])
+          if send:
+            # Make sure the copy is done and signal the receiving device.
+            plgpu.wait_smem_to_gmem(0, wait_read_only=False)
+            pl.semaphore_signal(received_sem, device_id=send_dev_id)
           # Make sure all TMAs have read SMEM before we overwrite it.
           plgpu.wait_smem_to_gmem(0, wait_read_only=True)
           out_smem[...] = acc_ref[...].astype(out_smem.dtype)
           plgpu.commit_smem()
           plgpu.copy_smem_to_gmem(
-              out_smem, out_ref.at[device_m_slice, n_tile_slice].at[m_tile_slice]
+              out_smem,
+              out_ref.at[device_m_slice, n_tile_slice].at[m_tile_slice],
           )
 
+      compute(pl.ds(n_start, block_n), send=True)
+
+      @pl.loop(1, n_shard_per_sm_n // block_n)
+      def _n_loop(ni):
+        compute(pl.ds(n_start + ni * block_n, block_n), send=False)
+
       # Wait for the next scratch to arrive --- see the device loop invariant.
       pl.semaphore_wait(received_sem)
 
@@ -216,6 +188,7 @@ def k_loop(_, lhs_smem, rhs_smem):
           jax.ShapeDtypeStruct((num_sms_m, num_devices, block_m, k), dtype),
       ],
       scratch_shapes=[
+          plgpu.SMEM((block_m, block_n), dtype, transforms=transforms),
           plgpu.SemaphoreType.REGULAR,  # Received semaphore
       ],
       grid=(num_sms_m, sm_n_tile),
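
For context on the scratch-passing convention the new code relies on: every entry in `scratch_shapes` is handed to the kernel body as an extra trailing reference, after the input and output refs, which is why `kernel_body` gains an `out_smem` parameter once the SMEM buffer moves out of `pl.run_scoped`. Below is a minimal, untested sketch of that pattern. The `plgpu.kernel` entry point, the toy fill body, and the shapes are illustrative assumptions and are not taken from this PR; only the `plgpu.SMEM`, `commit_smem`, `copy_smem_to_gmem`, and `wait_smem_to_gmem` calls mirror the diff above.

```python
# Hedged sketch, not the kernel from this PR: an SMEM buffer declared in
# scratch_shapes arrives as a trailing ref in the body, is written, and is
# then copied out to GMEM, the same staging role out_smem plays above.
import jax
import jax.numpy as jnp
from jax.experimental.pallas import mosaic_gpu as plgpu


def body(o_ref, smem):  # output ref first, then scratch refs in scratch_shapes order
  smem[...] = jnp.full(smem.shape, 2.0, smem.dtype)
  plgpu.commit_smem()                   # make the SMEM writes visible to the copy engine
  plgpu.copy_smem_to_gmem(smem, o_ref)  # asynchronous SMEM -> GMEM copy
  plgpu.wait_smem_to_gmem(0, wait_read_only=False)  # wait until the copy has landed


fill = plgpu.kernel(
    body,
    out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float16),
    scratch_shapes=[plgpu.SMEM((128, 128), jnp.float16)],
)

y = fill()  # (128, 128) array filled with 2.0
```

The point of hoisting the buffer in this PR appears to be reuse: a single `(block_m, block_n)` staging buffer is allocated once for the kernel and shared across the n tiles and device-loop steps, instead of being re-scoped for every tile inside the device loop.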