Skip to content

Commit f3d83cd

Browse files
Davis Yoshida authored and Google-ML-Automation committed
Support Hijax types in emit_pipeline.
PiperOrigin-RevId: 843337588
1 parent 23fc5e8 commit f3d83cd

File tree

10 files changed

+494
-48
lines changed

10 files changed

+494
-48
lines changed

jax/_src/pallas/core.py

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1365,6 +1365,27 @@ def _get_sds(aval: jax_core.AbstractValue):
13651365
core_map_p = jax_core.Primitive("core_map")
13661366
core_map_p.multiple_results = True
13671367

1368+
def _core_map_is_high(*avals, jaxpr, **params):
1369+
del avals, params
1370+
return jaxpr.is_high
1371+
core_map_p.is_high = _core_map_is_high # type: ignore[method-assign]
1372+
1373+
def _core_map_to_lojax(*consts, jaxpr, mesh, **params):
1374+
closed_hi_jaxpr = jax_core.ClosedJaxpr(jaxpr, consts)
1375+
with (
1376+
tracing_grid_env(tuple(mesh.shape.values()), mapped_dims=()),
1377+
jax_core.extend_axis_env_nd(mesh.shape.items()),
1378+
):
1379+
closed_lo_jaxpr = pe.lower_jaxpr(closed_hi_jaxpr)
1380+
assert not closed_lo_jaxpr.is_high
1381+
return core_map_p.bind(
1382+
*closed_lo_jaxpr.consts,
1383+
jaxpr=closed_lo_jaxpr.jaxpr,
1384+
mesh=mesh,
1385+
**params,
1386+
)
1387+
core_map_p.to_lojax = _core_map_to_lojax
1388+
13681389

13691390
def core_map(
13701391
mesh,

jax/_src/pallas/mosaic/pipeline.py

Lines changed: 96 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -23,18 +23,19 @@
2323
from typing import Any, Union
2424

2525
import jax
26+
from jax import core as jax_core
2627
from jax import lax
2728
from jax import tree_util
2829
from jax._src import util as jax_util
2930
from jax._src.pallas import core as pallas_core
3031
from jax._src.pallas import primitives as primitives
3132
from jax._src.pallas.mosaic import core as tpu_core
3233
from jax._src.pallas.mosaic import helpers as tpu_helpers
33-
from jax._src.pallas.mosaic import tpu_info
3434
from jax._src.pallas.mosaic import primitives as tpu_primitives
35+
from jax._src.pallas.mosaic import tpu_info
36+
from jax._src.state import types as state_types
3537
from jax.experimental import pallas as pl
3638
import jax.numpy as jnp
37-
import numpy as np
3839

3940

4041
SMEM = tpu_core.MemorySpace.SMEM
@@ -79,17 +80,32 @@ def add_leaves(i, x):
7980
def _get_tpu_generation() -> int:
8081
return tpu_info.get_tpu_info().generation
8182

82-
def _make_tiling(shape: tuple[int, ...], dtype: np.dtype) -> tuple[int, ...]:
83-
# For a n-dimensional shape, returns (8, 128) for the last 2 dimensions
84-
# and 1 for the leading n - 2. For example, (256, 256) -> (8, 128) and
85-
# (2, 3, 128, 128) -> (1, 1, 8, 128).
83+
84+
def _make_tiling(
85+
shape: tuple[int, ...], ty: jax_core.AbstractValue
86+
) -> tuple[int | None, ...]:
87+
"""Compute a tiling for the given shape and type.
88+
89+
For an n-dimensional shape, returns (8, 128) for the last 2 dimensions
90+
and 1 for the leading n - 2. For example, (256, 256) -> (8, 128) and
91+
(2, 3, 128, 128) -> (1, 1, 8, 128).
92+
93+
Types are not required to have a dtype, so for such types we return None for
94+
all dimensions because their tiling is unknown.
95+
"""
96+
8697
if len(shape) < 2:
8798
raise ValueError(f"Shape must have at least 2 dimensions: {shape=}")
99+
100+
if not hasattr(ty, 'dtype'):
101+
return (None,) * len(shape)
102+
88103
leading_dims, final_dims = shape[:-2], shape[-2:]
89104
# We want to find the minimum power of 2 that fits the second-minor dimension
90105
# of shape, with maximum value 8.
91106
second_minor, _ = final_dims
92-
packing = 4 // dtype.itemsize
107+
108+
packing = 4 // ty.dtype.itemsize
93109
max_tiling = _TILING[0]
94110
second_minor_tiling = (1 + int(_get_tpu_generation() < 4)) * packing
95111
while second_minor_tiling < min(second_minor, max_tiling):
@@ -114,13 +130,18 @@ def _make_block_ds(
114130
assert isinstance(out, pl.Slice)
115131
return out
116132

117-
def _create_blocked_slice(block_index: jax.Array | int,
118-
block_size: int,
119-
dim_size: int,
120-
tiling: int):
133+
134+
def _create_blocked_slice(
135+
block_index: jax.Array | int,
136+
block_size: int,
137+
dim_size: int,
138+
tiling: int | None,
139+
):
121140
block_start = block_size * block_index
122141
if (dim_rem := dim_size % block_size) == 0:
123142
return pl.ds(block_start, block_size)
143+
if tiling is None:
144+
raise ValueError("If tiling is None, block_size must divide dim_size.")
124145
if block_size % tiling != 0:
125146
raise ValueError(f"Block size must divide tiling: {block_size=}, {tiling=}")
126147
num_blocks = pl.cdiv(dim_size, block_size)
@@ -137,12 +158,15 @@ def _create_bounded_slice(slice_start: jax.Array | int,
137158
slice_size: jax.Array | int,
138159
block_size: int,
139160
dim_size: int,
140-
tiling: int):
141-
if block_size % tiling != 0:
161+
tiling: int | None):
162+
if tiling is not None and block_size % tiling != 0:
142163
raise ValueError(f"Block size must divide tiling: {block_size=}, {tiling=}")
143164
# We assume by construction that slice_size <= block_size. We also assume
144165
# that the slice_start is already aligned to the tiling.
145166

167+
if tiling is None:
168+
return pl.ds(slice_start, slice_size)
169+
146170
# If we are out of bound, we need to round the slice size down to the nearest
147171
# multiple of the tiling.
148172
is_oob = slice_start + slice_size > dim_size
@@ -157,7 +181,7 @@ def _create_bounded_slice(slice_start: jax.Array | int,
157181

158182
def _make_block_slice(
159183
block_index: jax.Array, block_size: pl.BlockDim | int | None, size: int,
160-
tiling: int
184+
tiling: int | None
161185
) -> pl.Slice | slice | int | jax.Array:
162186
# Computes a slice given a block index and block size. In the default case,
163187
# we return slice(block_index * block_size, (block_index + 1) * block_size).
@@ -332,7 +356,7 @@ def block_shape(self) -> Sequence[pl.BlockDim | int | None] | None:
332356
def compute_index(self):
333357
return self.spec.index_map
334358

335-
def get_dma_slice(self, src_shape, src_dtype, grid_indices):
359+
def get_dma_slice(self, src_ty, grid_indices):
336360
# We need to handle blocks that might go OOB in the src array. An in bounds
337361
# block looks like this (for array shape (600, 600) and block shape
338362
# (256, 256)):
@@ -379,10 +403,14 @@ def get_dma_slice(self, src_shape, src_dtype, grid_indices):
379403
# Suppose A is now (601, 600), instead of picking a (88, 256)-sized block
380404
# for the last iteration on that dimension, we will pick the next highest
381405
# tile multiple, i.e. (96, 256).
406+
407+
if (src_shape := getattr(src_ty, "shape", None)) is None:
408+
raise ValueError(f'Type {src_ty} does not have a shape.')
409+
382410
if len(src_shape) < 2:
383411
raise NotImplementedError("Must use >1D values.")
384412

385-
tiling = _make_tiling(src_shape, src_dtype)
413+
tiling = _make_tiling(src_shape, src_ty)
386414
block_indices = self.compute_index(*grid_indices)
387415
return tuple(
388416
_make_block_slice(bi, bs, ss, t)
@@ -403,6 +431,14 @@ def with_spec(self, spec: pl.BlockSpec) -> BufferedRefBase:
403431
"""Returns a new BufferedRefBase with the given block spec."""
404432
raise NotImplementedError()
405433

434+
def _ref_to_value_aval(ref):
435+
"""Return the inner aval of a ref, or a ShapedArray for TransformedRefs."""
436+
return (
437+
jax_core.ShapedArray(shape=ref.shape, dtype=ref.dtype)
438+
if isinstance(ref, state_types.TransformedRef)
439+
else jax.typeof(ref).inner_aval
440+
)
441+
406442

407443
# TODO(justinfu): Refactor and rename slot fields to reflect cumulative values
408444
# instead of slot index.
@@ -413,7 +449,6 @@ class BufferedRef(BufferedRefBase):
413449
414450
Attributes:
415451
spec: pallas blockspec.
416-
dtype: dtype for buffers.
417452
buffer_type: enum indicating whether this is an input, output, or in/out
418453
accumulator buffered reference.
419454
window_ref: a multiple-buffer to hold the working and dirty buffers used
@@ -444,7 +479,6 @@ class BufferedRef(BufferedRefBase):
444479
copy.
445480
"""
446481
_spec: pl.BlockSpec = dataclasses.field(metadata=dict(static=True))
447-
dtype: Any = dataclasses.field(metadata=dict(static=True))
448482
_buffer_type: BufferType = dataclasses.field(metadata=dict(static=True))
449483
window_ref: ArrayRef | None
450484
accum_ref: ArrayRef | None
@@ -507,7 +541,7 @@ def buffer_types() -> type[BufferType]:
507541
return BufferType
508542

509543
@classmethod
510-
def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
544+
def create(cls, spec: pl.BlockSpec, dtype_or_type, buffer_type, buffer_count,
511545
needs_swap_ref=True,
512546
grid_rank=None,
513547
use_lookahead=False,
@@ -516,7 +550,8 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
516550
517551
Args:
518552
spec: pallas blockspec.
519-
dtype: dtype for buffers.
553+
dtype_or_type: dtype or aval for buffers. If an aval, the shape is
554+
ignored.
520555
buffer_type: enum indicating whether this is an input, output, or in/out
521556
accumulator buffered reference.
522557
needs_swap_ref: whether a swap slots tracker needs to be allocated.
@@ -527,9 +562,18 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
527562
Returns:
528563
Initialized BufferedRef
529564
"""
565+
566+
# (123, 456) is a dummy shape since we never use ty without
567+
# calling .update(shape=...) first.
568+
ty = (
569+
dtype_or_type
570+
if isinstance(dtype_or_type, jax_core.AbstractValue)
571+
else jax_core.ShapedArray((123, 456), dtype_or_type)
572+
)
573+
530574
block_shape = _get_block_shape(spec)
531575
if buffer_type is BufferType.ACCUMULATOR:
532-
accum_ref = VMEM(block_shape, dtype)
576+
accum_ref = VMEM.from_type(ty.update(shape=block_shape))
533577
else:
534578
accum_ref = None
535579
if source_memory_space == VMEM:
@@ -541,7 +585,6 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
541585
f"Cannot hold a non-buffered ref in {spec.memory_space=}")
542586
return cls(
543587
_spec=spec,
544-
dtype=dtype,
545588
_buffer_type=buffer_type,
546589
window_ref=None, # to be bound to existing ref by the pipeline routine
547590
accum_ref=accum_ref,
@@ -570,11 +613,12 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
570613
raise ValueError(
571614
"grid_rank must be specified when use_lookahead is True."
572615
)
616+
617+
buffer_ty = ty.update(shape=(buffer_count, *block_shape))
573618
return cls(
574619
_spec=spec,
575-
dtype=dtype,
576620
_buffer_type=buffer_type,
577-
window_ref=buffer_memory_space((buffer_count,) + block_shape, dtype),
621+
window_ref=buffer_memory_space.from_type(buffer_ty),
578622
accum_ref=accum_ref,
579623
copy_in_slot=SMEM((1,), jnp.uint32) if buffer_type.is_input else None,
580624
wait_in_slot=SMEM((1,), jnp.uint32) if buffer_type.is_input else None,
@@ -601,22 +645,28 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
601645
)
602646

603647
@classmethod
604-
def input(cls, spec, dtype, buffer_count=2, **kwargs):
605-
return cls.create(spec, dtype, BufferType.INPUT, buffer_count, **kwargs)
648+
def input(cls, spec, dtype_or_type, buffer_count=2, **kwargs):
649+
return cls.create(
650+
spec, dtype_or_type, BufferType.INPUT, buffer_count, **kwargs
651+
)
606652

607653
@classmethod
608-
def output(cls, spec, dtype, buffer_count=2, **kwargs):
609-
return cls.create(spec, dtype, BufferType.OUTPUT, buffer_count, **kwargs)
654+
def output(cls, spec, dtype_or_type, buffer_count=2, **kwargs):
655+
return cls.create(
656+
spec, dtype_or_type, BufferType.OUTPUT, buffer_count, **kwargs
657+
)
610658

611659
@classmethod
612-
def accumulator(cls, spec, dtype, buffer_count=2, **kwargs):
613-
return cls.create(spec, dtype, BufferType.ACCUMULATOR, buffer_count,
614-
**kwargs)
660+
def accumulator(cls, spec, dtype_or_type, buffer_count=2, **kwargs):
661+
return cls.create(
662+
spec, dtype_or_type, BufferType.ACCUMULATOR, buffer_count, **kwargs
663+
)
615664

616665
@classmethod
617-
def input_output(cls, spec, dtype, buffer_count=2, **kwargs):
618-
return cls.create(spec, dtype, BufferType.INPUT_OUTPUT, buffer_count,
619-
**kwargs)
666+
def input_output(cls, spec, dtype_or_type, buffer_count=2, **kwargs):
667+
return cls.create(
668+
spec, dtype_or_type, BufferType.INPUT_OUTPUT, buffer_count, **kwargs
669+
)
620670

621671
@property
622672
def block_shape(self):
@@ -923,7 +973,7 @@ def copy_in(self, src_ref, grid_indices):
923973
if self.swap is not None:
924974
self.swap[0] = True
925975
slot = self.current_copy_in_slot
926-
src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices)
976+
src_slice = self.get_dma_slice(_ref_to_value_aval(src_ref), grid_indices)
927977
dst_slice = tuple(
928978
pl.ds(0, s.size)
929979
for s, bd in zip(src_slice, self.block_shape)
@@ -944,7 +994,7 @@ def copy_out(self, dst_ref, grid_indices):
944994
if self.swap is not None:
945995
self.swap[0] = True
946996
slot = self.current_copy_out_slot
947-
dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices)
997+
dst_slice = self.get_dma_slice(_ref_to_value_aval(dst_ref), grid_indices)
948998
src_slice = tuple(
949999
pl.ds(0, s.size)
9501000
for s, bd in zip(dst_slice, self.block_shape)
@@ -962,7 +1012,7 @@ def wait_in(self, src_ref, grid_indices):
9621012
if not self.is_buffered: return
9631013
assert not (self.window_ref is None or isinstance(self.window_ref, REF))
9641014
assert self.sem_recvs is not None
965-
src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices)
1015+
src_slice = self.get_dma_slice(_ref_to_value_aval(src_ref), grid_indices)
9661016
dst_slice = tuple(
9671017
pl.ds(0, s.size)
9681018
for s, bd in zip(src_slice, self.block_shape)
@@ -984,7 +1034,7 @@ def wait_out(self, dst_ref, grid_indices):
9841034
assert not (self.window_ref is None or isinstance(self.window_ref, REF))
9851035
assert self.sem_sends is not None
9861036
wait_slot = self.current_wait_out_slot
987-
dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices)
1037+
dst_slice = self.get_dma_slice(_ref_to_value_aval(dst_ref), grid_indices)
9881038
src_slice = tuple(
9891039
pl.ds(0, s.size)
9901040
for s, bd in zip(dst_slice, self.block_shape)
@@ -1682,7 +1732,9 @@ def make_input_bref(in_spec, in_ref):
16821732
use_lookahead = in_spec.pipeline_mode.use_lookahead
16831733
if use_lookahead and grid is None:
16841734
raise ValueError("Grid must be specified when using lookahead.")
1685-
return BufferedRef.input(in_spec, in_ref.dtype, buffer_count,
1735+
1736+
in_aval = _ref_to_value_aval(in_ref)
1737+
return BufferedRef.input(in_spec, in_aval, buffer_count,
16861738
needs_swap_ref=needs_swap_ref,
16871739
grid_rank=len(grid),
16881740
use_lookahead=use_lookahead,
@@ -1695,11 +1747,13 @@ def make_output_bref(out_spec, out_ref, accumulate):
16951747
if out_spec.pipeline_mode.use_lookahead:
16961748
raise ValueError("Output buffering does not support lookahead.")
16971749

1750+
out_aval = _ref_to_value_aval(out_ref)
1751+
16981752
if accumulate:
1699-
return BufferedRef.accumulator(out_spec, out_ref.dtype, buffer_count,
1753+
return BufferedRef.accumulator(out_spec, out_aval, buffer_count,
17001754
needs_swap_ref=needs_swap_ref,
17011755
source_memory_space=out_ref.memory_space)
1702-
return BufferedRef.output(out_spec, out_ref.dtype, buffer_count,
1756+
return BufferedRef.output(out_spec, out_aval, buffer_count,
17031757
needs_swap_ref=needs_swap_ref,
17041758
source_memory_space=out_ref.memory_space)
17051759
out_brefs = jax.tree.map(
@@ -1817,7 +1871,7 @@ def sync_copy(src: REF | BufferedRef, dst: REF | BufferedRef, indices):
18171871
bref = dst
18181872
hbm_ref = src
18191873
copy_in = True
1820-
hbm_slice = bref.get_dma_slice(hbm_ref.shape, hbm_ref.dtype, indices)
1874+
hbm_slice = bref.get_dma_slice(_ref_to_value_aval(hbm_ref), indices)
18211875
bref_slice = tuple(
18221876
pl.ds(0, s.size)
18231877
for s, bd in zip(hbm_slice, bref.block_shape)

0 commit comments

Comments
 (0)