@@ -69,7 +69,7 @@ def _get_block_shape(spec: pallas_core.BlockSpec):
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class BufferedRef:
-  spec: pallas_core.BlockSpec = dataclasses.field(metadata={"static": True})
+  spec: gpu_core.BlockSpec = dataclasses.field(metadata={"static": True})
  is_index_invariant: bool = dataclasses.field(metadata={"static": True})
  gmem_ref: state.AbstractRef
  # ``None`` if the ref is pinned to GMEM; otherwise, has shape
@@ -184,6 +184,20 @@ def __eq__(self, other: _Slice) -> jax.Array: # type: ignore
    )


+def _downcast_spec(
+    spec: gpu_core.BlockSpec | pallas_core.BlockSpec,
+) -> gpu_core.BlockSpec:
+  if isinstance(spec, gpu_core.BlockSpec):
+    return spec
+
+  return gpu_core.BlockSpec(
+      block_shape=spec.block_shape,
+      index_map=spec.index_map,
+      memory_space=spec.memory_space,
+      pipeline_mode=spec.pipeline_mode,
+  )
+
+
def emit_pipeline(
    body: Callable[..., T],
    *,
@@ -220,25 +234,16 @@ def emit_pipeline(
    pipeline and returns the final carry value (if ``init_carry`` was used),
    otherwise it returns None.
  """
+
+  in_specs = tuple(map(_downcast_spec, in_specs))
+  out_specs = tuple(map(_downcast_spec, out_specs))
  # TODO(justinfu): Factor out common code between warp-specialized and
  # normal pipelines.
-  delay_release = None
-  for in_spec in in_specs:
-    if not isinstance(in_spec, gpu_core.BlockSpec):
-      delay_release = 0
-      continue
-    delay_release = in_spec.delay_release
-    if in_spec.delay_release != delay_release:
-      raise NotImplementedError(
-          "All inputs must have the same delay_release, but"
-          f" {in_spec.delay_release=} != {delay_release=}"
-      )
-
-  delay_release = delay_release or 0
-  if max_concurrent_steps <= delay_release:
+  delay_release_levels = sorted({s.delay_release for s in in_specs}) or [0]
+  if delay_release_levels and max_concurrent_steps <= delay_release_levels[0]:
    raise ValueError(
-        "max_concurrent_steps must be greater than delay_release, but "
-        f" {max_concurrent_steps=}, {delay_release=}"
+        "max_concurrent_steps must be greater than all delay_release values, "
+        f" but {max_concurrent_steps=} and {delay_release_levels=}. "
    )

  num_steps = math.prod(grid)
@@ -316,7 +321,7 @@ def scoped_pipeline(

    def loop_body(step, carry):
      slot = lax.rem(step, max_concurrent_steps)
-      indices, fetch_indices, last_store_slices, prev_body_carry = carry
+      indices, fetch_index_levels, last_store_slices, prev_body_carry = carry

      if barrier_ref is not None:
        # Wait for the current GMEM->SMEM copy to complete, if any.
@@ -368,31 +373,43 @@ def loop_body(step, carry):
      if copies_out_in_loop:
        gpu_primitives.commit_smem_to_gmem_group()

-      fetch_step = step + (max_concurrent_steps - delay_release)
-      fetch_slot = lax.rem(fetch_step, max_concurrent_steps)
-
-      def do_fetch():
-        for bref in in_brefs:
-          bref.copy_in(fetch_slot, fetch_indices, barrier_ref)
-
-      jax.lax.cond(
-          lax.bitwise_and(step >= delay_release, fetch_step < num_steps),
-          do_fetch,
-          lambda: None,
-      )
+      for delay_release, fetch_indices in zip(
+          delay_release_levels, fetch_index_levels
+      ):
+        fetch_step = step + (max_concurrent_steps - delay_release)
+        fetch_slot = lax.rem(fetch_step, max_concurrent_steps)
+
+        # pylint: disable=cell-var-from-loop
+        def do_fetch():
+          for bref in in_brefs:
+            if bref.spec.delay_release == delay_release:
+              bref.copy_in(fetch_slot, fetch_indices, barrier_ref)
+        # pylint: enable=cell-var-from-loop
+
+        jax.lax.cond(
+            lax.bitwise_and(step >= delay_release, fetch_step < num_steps),
+            do_fetch,
+            lambda: None,
+        )

+      next_fetch_indices_levels = [
+          _inc_grid_by_1(fetch_indices, grid)
+          for fetch_indices in fetch_index_levels
+      ]
      return (
          _inc_grid_by_1(indices, grid),
-          _inc_grid_by_1(fetch_indices, grid),
+          next_fetch_indices_levels,
          new_store_slices,
          next_body_carry if init_carry is not None else None,
      )

-    # Invariant: ``indices`` and ``fetch_indices`` are always
-    # ``max_concurrent_steps-delay_release`` apart.
-    fetch_indices = indices
-    for _ in range(max_concurrent_steps - delay_release):
-      fetch_indices = _inc_grid_by_1(fetch_indices, grid)
+    fetch_index_levels = []
+    for delay_release in delay_release_levels:
+      fetch_indices = indices
+      for _ in range(max_concurrent_steps - delay_release):
+        fetch_indices = _inc_grid_by_1(fetch_indices, grid)
+      fetch_index_levels.append(fetch_indices)
+
    # TODO(justinfu): Only store base pointer instead of all indices.
    last_store_slices = [
        None
@@ -404,7 +421,7 @@ def do_fetch():
        0,
        num_steps,
        loop_body,
-        (indices, fetch_indices, last_store_slices, init_carry),
+        (indices, fetch_index_levels, last_store_slices, init_carry),
    )

    # Outputs invariant to the sequential axis are never written from inside the
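
After this change, inputs to emit_pipeline no longer need to share a single delay_release: the loop prefetches each input on the schedule of its own spec. A minimal caller sketch, under assumptions: plgpu.BlockSpec, its delay_release keyword, and the emit_pipeline keyword arguments are inferred from the hunks above, not taken verbatim from this commit, and the surrounding kernel plumbing is elided.

# Hypothetical sketch: mixing per-input delay_release values.
# Assumes plgpu.BlockSpec accepts delay_release and that emit_pipeline takes
# the keyword arguments visible in the diff.
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

def add_bodies(a_smem, b_smem, o_smem):
  # Body call convention elided; the pipeline passes SMEM refs per spec.
  o_smem[...] = a_smem[...] + b_smem[...]

pipeline = plgpu.emit_pipeline(
    add_bodies,
    grid=(8,),
    in_specs=[
        # A plain pl.BlockSpec is normalized by _downcast_spec; its
        # delay_release defaults to 0.
        pl.BlockSpec((128, 128), lambda i: (i, 0)),
        # A GPU spec whose SMEM buffer is released one step later.
        plgpu.BlockSpec((128, 128), lambda i: (i, 0), delay_release=1),
    ],
    out_specs=[pl.BlockSpec((128, 128), lambda i: (i, 0))],
    max_concurrent_steps=2,  # must exceed every delay_release value
)
# ``pipeline`` is then invoked inside a Pallas Mosaic GPU kernel with the
# corresponding GMEM refs, e.g. pipeline(a_gmem, b_gmem, o_gmem).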