Skip to content

Commit cddb59e

Browse files
cperivol authored and Google-ML-Automation committed
[pallas:mgpu] Support per-input delay_release in emit_pipeline
This change allows specifying different `delay_release` values for each input in `emit_pipeline` (`emit_pipeline_warp_specialized` left for a followup cl). This is achieved by stratifying the inputs based on their `delay_release` values and adjusting the pipeline logic to handle the different effective number of stages for each stratum. PiperOrigin-RevId: 801093645
1 parent c56b8bc commit cddb59e

File tree

2 files changed

+95
-37
lines changed

2 files changed

+95
-37
lines changed

jax/_src/pallas/mosaic_gpu/pipeline.py

Lines changed: 54 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def _get_block_shape(spec: pallas_core.BlockSpec):
6969
@jax.tree_util.register_dataclass
7070
@dataclasses.dataclass(frozen=True)
7171
class BufferedRef:
72-
spec: pallas_core.BlockSpec = dataclasses.field(metadata={"static": True})
72+
spec: gpu_core.BlockSpec = dataclasses.field(metadata={"static": True})
7373
is_index_invariant: bool = dataclasses.field(metadata={"static": True})
7474
gmem_ref: state.AbstractRef
7575
# ``None`` if the ref is pinned to GMEM; otherwise, has shape
@@ -184,6 +184,20 @@ def __eq__(self, other: _Slice) -> jax.Array: # type: ignore
184184
)
185185

186186

187+
def _downcast_spec(
    spec: gpu_core.BlockSpec | pallas_core.BlockSpec,
) -> gpu_core.BlockSpec:
  """Coerces a generic Pallas ``BlockSpec`` into a GPU ``BlockSpec``.

  Specs that are already ``gpu_core.BlockSpec`` are returned as-is. Plain
  ``pallas_core.BlockSpec``s are re-wrapped, carrying over the fields the two
  classes share; GPU-only fields (presumably e.g. ``delay_release``) take
  their defaults.
  """
  if not isinstance(spec, gpu_core.BlockSpec):
    spec = gpu_core.BlockSpec(
        block_shape=spec.block_shape,
        index_map=spec.index_map,
        memory_space=spec.memory_space,
        pipeline_mode=spec.pipeline_mode,
    )
  return spec
200+
187201
def emit_pipeline(
188202
body: Callable[..., T],
189203
*,
@@ -220,25 +234,16 @@ def emit_pipeline(
220234
pipeline and returns the final carry value (if ``init_carry`` was used),
221235
otherwise it returns None.
222236
"""
237+
238+
in_specs = tuple(map(_downcast_spec, in_specs))
239+
out_specs = tuple(map(_downcast_spec, out_specs))
223240
# TODO(justinfu): Factor out common code between warp-specialized and
224241
# normal pipelines.
225-
delay_release = None
226-
for in_spec in in_specs:
227-
if not isinstance(in_spec, gpu_core.BlockSpec):
228-
delay_release = 0
229-
continue
230-
delay_release = in_spec.delay_release
231-
if in_spec.delay_release != delay_release:
232-
raise NotImplementedError(
233-
"All inputs must have the same delay_release, but"
234-
f" {in_spec.delay_release=} != {delay_release=}"
235-
)
236-
237-
delay_release = delay_release or 0
238-
if max_concurrent_steps <= delay_release:
242+
delay_release_levels = sorted({s.delay_release for s in in_specs}) or [0]
243+
if delay_release_levels and max_concurrent_steps <= delay_release_levels[0]:
239244
raise ValueError(
240-
"max_concurrent_steps must be greater than delay_release, but"
241-
f" {max_concurrent_steps=}, {delay_release=}"
245+
"max_concurrent_steps must be greater than all delay_release values,"
246+
f" but {max_concurrent_steps=} and {delay_release_levels=}."
242247
)
243248

244249
num_steps = math.prod(grid)
@@ -316,7 +321,7 @@ def scoped_pipeline(
316321

317322
def loop_body(step, carry):
318323
slot = lax.rem(step, max_concurrent_steps)
319-
indices, fetch_indices, last_store_slices, prev_body_carry = carry
324+
indices, fetch_index_levels, last_store_slices, prev_body_carry = carry
320325

321326
if barrier_ref is not None:
322327
# Wait for the current GMEM->SMEM copy to complete, if any.
@@ -368,31 +373,43 @@ def loop_body(step, carry):
368373
if copies_out_in_loop:
369374
gpu_primitives.commit_smem_to_gmem_group()
370375

371-
fetch_step = step + (max_concurrent_steps - delay_release)
372-
fetch_slot = lax.rem(fetch_step, max_concurrent_steps)
373-
374-
def do_fetch():
375-
for bref in in_brefs:
376-
bref.copy_in(fetch_slot, fetch_indices, barrier_ref)
377-
378-
jax.lax.cond(
379-
lax.bitwise_and(step >= delay_release, fetch_step < num_steps),
380-
do_fetch,
381-
lambda: None,
382-
)
376+
for delay_release, fetch_indices in zip(
377+
delay_release_levels, fetch_index_levels
378+
):
379+
fetch_step = step + (max_concurrent_steps - delay_release)
380+
fetch_slot = lax.rem(fetch_step, max_concurrent_steps)
381+
382+
# pylint: disable=cell-var-from-loop
383+
def do_fetch():
384+
for bref in in_brefs:
385+
if bref.spec.delay_release == delay_release:
386+
bref.copy_in(fetch_slot, fetch_indices, barrier_ref)
387+
# pylint: enable=cell-var-from-loop
388+
389+
jax.lax.cond(
390+
lax.bitwise_and(step >= delay_release, fetch_step < num_steps),
391+
do_fetch,
392+
lambda: None,
393+
)
383394

395+
next_fetch_indices_levels = [
396+
_inc_grid_by_1(fetch_indices, grid)
397+
for fetch_indices in fetch_index_levels
398+
]
384399
return (
385400
_inc_grid_by_1(indices, grid),
386-
_inc_grid_by_1(fetch_indices, grid),
401+
next_fetch_indices_levels,
387402
new_store_slices,
388403
next_body_carry if init_carry is not None else None,
389404
)
390405

391-
# Invariant: ``indices`` and ``fetch_indices`` are always
392-
# ``max_concurrent_steps-delay_release`` apart.
393-
fetch_indices = indices
394-
for _ in range(max_concurrent_steps-delay_release):
395-
fetch_indices = _inc_grid_by_1(fetch_indices, grid)
406+
fetch_index_levels = []
407+
for delay_release in delay_release_levels:
408+
fetch_indices = indices
409+
for _ in range(max_concurrent_steps - delay_release):
410+
fetch_indices = _inc_grid_by_1(fetch_indices, grid)
411+
fetch_index_levels.append(fetch_indices)
412+
396413
# TODO(justinfu): Only store base pointer instead of all indices.
397414
last_store_slices = [
398415
None
@@ -404,7 +421,7 @@ def do_fetch():
404421
0,
405422
num_steps,
406423
loop_body,
407-
(indices, fetch_indices, last_store_slices, init_carry),
424+
(indices, fetch_index_levels, last_store_slices, init_carry),
408425
)
409426

410427
# Outputs invariant to the sequential axis are never written from inside the

tests/pallas/mosaic_gpu_test.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4711,6 +4711,47 @@ def pipeline(*gmem_refs):
47114711
y = jax.random.uniform(jax.random.key(1), (m, n), dtype=jnp.float32)
47124712
np.testing.assert_allclose(kernel(x, y), x + y, atol=1e-4)
47134713

4714+
def test_different_delay_release(self):
  """Pipelines three inputs whose specs carry distinct ``delay_release``s."""
  self.skip_if_wg_semantics()  # Crashes!
  m, n = 128, 64
  blk_m, blk_n = 32, 64
  # One input per delay_release level (0, 1, 2) forces the pipeline to
  # stratify its fetches rather than using a single shared delay.
  index_map = lambda i, j: (i, j)
  in_specs = []
  for delay in range(3):
    in_specs.append(
        plgpu.BlockSpec(
            block_shape=(blk_m, blk_n),
            index_map=index_map,
            delay_release=delay,
        )
    )
  out_spec = pl.BlockSpec(block_shape=(blk_m, blk_n), index_map=index_map)

  def add3_kernel(_, x_smem, y_smem, z_smem, o_smem):
    o_smem[...] = x_smem[...] + y_smem[...] + z_smem[...]

  def run_pipeline(*gmem_refs):
    return mgpu_pipeline.emit_pipeline(
        add3_kernel,
        grid=(m // blk_m, n // blk_n),
        max_concurrent_steps=4,
        in_specs=in_specs,
        out_specs=[out_spec],
    )(*gmem_refs)

  kernel = self.kernel(
      run_pipeline,
      out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32),
      grid=(1,),
      grid_names=("_",),
  )
  x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float32)
  y = jax.random.uniform(jax.random.key(1), (m, n), dtype=jnp.float32)
  z = jax.random.uniform(jax.random.key(3), (m, n), dtype=jnp.float32)
  np.testing.assert_allclose(kernel(x, y, z), x + y + z)
4754+
47144755
@parameterized.product(
47154756
delay_release=[0, 1],
47164757
)

0 commit comments

Comments (0)