@@ -72,6 +72,10 @@ def all_gather_lhs_matmul(
         f"{n_shard_per_sm_n=} must be divisible by {block_n=}"
     )
   num_sms_m = max_num_sms // sm_n_tile
+  if num_sms_m < (m_shard // block_m) and sm_n_tile > 1:
+    # We never synchronize the N SMs across the different steps of the M
+    # loop, so they can start overwriting each other's data.
+    raise NotImplementedError("The kernel has races when M is large and sm_n_tile > 1")
 
   swizzle = min(
       plgpu.find_swizzle(block_k * jnp.finfo(element_type).bits, "lhs"),
@@ -99,22 +103,12 @@ def _m_loop(idx):
       (mi,) = idx
       m_tile_slice = pl.ds(mi * block_m, block_m)
 
-      # For some reason ptxas spills if we unroll the loop over k
-      copy_block = 32
-      @pl.loop(0, k, step=copy_block)
-      def _k_copy_loop(ki):
-        k_slice = pl.ds(ki, copy_block)
-        scratch_ref[0, :, k_slice] = lhs_ref[m_tile_slice, k_slice]
-
-      @pl.loop(0, num_devices)
-      def _device_loop(device_offset):
+      def device_step(lhs_source_ref, next_scratch_slot, device_offset):
+        # Loop invariant: lhs_source_ref is ready to be used
         device_m_slice = pl.ds(
             lax.rem(device_offset + dev_id, num_devices) * m_shard, block_m
         )
 
-        scratch_slot = device_offset
-        next_scratch_slot = scratch_slot + 1
-
         def compute(n_tile_slice, send: bool):
           @functools.partial(
               pl.run_scoped, acc_ref=plgpu.ACC((block_m, block_n))
@@ -143,7 +137,7 @@ def k_loop(idxs, lhs_smem, rhs_smem):
             plgpu.wgmma(acc_ref, lhs_smem, rhs_smem)
           if send:
             # TODO(giorgioa): Send only for first sm_n.
-            @pl.when(next_scratch_slot <= num_devices - 1)
+            @pl.when(next_scratch_slot < num_devices - 1)
             def _():
               (ki,) = idxs
               k_slice = pl.ds(ki * block_k, block_k)
@@ -153,7 +147,7 @@ def _():
           # We only delay release by 1 step, so we need to wait for the
           # previous copies.
           plgpu.wait_smem_to_gmem(1, wait_read_only=True)
-          k_loop(scratch_ref.at[scratch_slot], rhs_ref.at[..., n_tile_slice])
+          k_loop(lhs_source_ref, rhs_ref.at[..., n_tile_slice])
           if send:
             # Make sure the copy is done and signal the receiving device.
             plgpu.wait_smem_to_gmem(0, wait_read_only=False)
@@ -176,6 +170,11 @@ def _n_loop(ni):
         # Wait for the next scratch to arrive --- see the device loop invariant.
         pl.semaphore_wait(received_sem)
 
+      device_step(lhs_ref.at[m_tile_slice], 0, 0)
+      @pl.loop(1, num_devices)
+      def _device_loop(device_offset):
+        device_step(scratch_ref.at[device_offset - 1], device_offset, device_offset)
+
       # Make sure all copies are fully done.
       plgpu.wait_smem_to_gmem(0, wait_read_only=True)
 
@@ -185,7 +184,7 @@ def _n_loop(ni):
         # The output, with its M dimension all-gathered.
         jax.ShapeDtypeStruct((axis_size * m_shard, n_shard), dtype),
         # The scratch buffer used for the all-gather.
-        jax.ShapeDtypeStruct((num_sms_m, num_devices, block_m, k), dtype),
+        jax.ShapeDtypeStruct((num_sms_m, num_devices - 1, block_m, k), dtype),
       ],
       scratch_shapes=[
         plgpu.SMEM((block_m, block_n), dtype, transforms=transforms),
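
For intuition on the new schedule: step 0 of device_step consumes the local shard straight from lhs_ref, and step i > 0 consumes the shard the previous step deposited in scratch slot i - 1. That is why the scratch buffer shrinks to num_devices - 1 slots and the send guard tightens to next_scratch_slot < num_devices - 1 (the last step has nothing left to forward). Below is a minimal plain-Python model of that ring schedule, not the kernel itself: NumPy arrays stand in for the GMEM refs and list indexing stands in for the remote copies; only device_step's shape mirrors the diff, everything else is a hypothetical stand-in.

# Plain-Python model of the ring all-gather schedule in the patch.
# Hypothetical stand-ins: NumPy arrays replace GMEM refs, list indexing
# replaces the remote DMA; only device_step's signature mirrors the kernel.
import numpy as np

num_devices = 4
m_shard, k = 2, 3
# Each device starts with only its own LHS shard (filled with its id).
lhs_shards = [np.full((m_shard, k), d, np.float32) for d in range(num_devices)]
# num_devices - 1 scratch slots per device, as in the patched ShapeDtypeStruct.
scratch = [np.zeros((num_devices - 1, m_shard, k), np.float32)
           for _ in range(num_devices)]
outputs = [np.zeros((num_devices * m_shard, k), np.float32)
           for _ in range(num_devices)]

def device_step(dev_id, lhs_source, next_scratch_slot, device_offset):
  # Loop invariant: lhs_source already holds the shard for this step.
  src_dev = (device_offset + dev_id) % num_devices  # lax.rem(...) in the kernel
  outputs[dev_id][src_dev * m_shard:(src_dev + 1) * m_shard] = lhs_source
  # Forward the shard one hop around the ring, except on the last step
  # (the patched guard: next_scratch_slot < num_devices - 1).
  if next_scratch_slot < num_devices - 1:
    scratch[(dev_id - 1) % num_devices][next_scratch_slot] = lhs_source

for step in range(num_devices):        # all devices advance in lockstep
  for dev_id in range(num_devices):
    source = lhs_shards[dev_id] if step == 0 else scratch[dev_id][step - 1]
    device_step(dev_id, source, step, step)

for out in outputs:                    # every device gathered every shard
  for d in range(num_devices):
    assert (out[d * m_shard:(d + 1) * m_shard] == d).all()

Running the model confirms that num_devices - 1 slots suffice for every device to gather all shards, which is exactly the diff's loop invariant: the source ref is ready to be used when device_step starts.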