2323from typing import Any , Union
2424
2525import jax
26+ from jax import core as jax_core
2627from jax import lax
2728from jax import tree_util
2829from jax ._src import util as jax_util
2930from jax ._src .pallas import core as pallas_core
3031from jax ._src .pallas import primitives as primitives
3132from jax ._src .pallas .mosaic import core as tpu_core
3233from jax ._src .pallas .mosaic import helpers as tpu_helpers
33- from jax ._src .pallas .mosaic import tpu_info
3434from jax ._src .pallas .mosaic import primitives as tpu_primitives
35+ from jax ._src .pallas .mosaic import tpu_info
36+ from jax ._src .state import types as state_types
3537from jax .experimental import pallas as pl
3638import jax .numpy as jnp
37- import numpy as np
3839
3940
4041SMEM = tpu_core .MemorySpace .SMEM
@@ -79,17 +80,32 @@ def add_leaves(i, x):
7980def _get_tpu_generation () -> int :
8081 return tpu_info .get_tpu_info ().generation
8182
82- def _make_tiling (shape : tuple [int , ...], dtype : np .dtype ) -> tuple [int , ...]:
83- # For a n-dimensional shape, returns (8, 128) for the last 2 dimensions
84- # and 1 for the leading n - 2. For example, (256, 256) -> (8, 128) and
85- # (2, 3, 128, 128) -> (1, 1, 8, 128).
83+
84+ def _make_tiling (
85+ shape : tuple [int , ...], ty : jax_core .AbstractValue
86+ ) -> tuple [int | None , ...]:
87+ """Compute a tiling for the given shape and type.
88+
89+ For a n-dimensional shape, returns (8, 128) for the last 2 dimensions
90+ and 1 for the leading n - 2. For example, (256, 256) -> (8, 128) and
91+ (2, 3, 128, 128) -> (1, 1, 8, 128).
92+
93+ Types are not required to have a dtype, so for such types we return None for
94+ all dimensions because their tiling is unknown.
95+ """
96+
8697 if len (shape ) < 2 :
8798 raise ValueError (f"Shape must have at least 2 dimensions: { shape = } " )
99+
100+ if not hasattr (ty , 'dtype' ):
101+ return (None ,) * len (shape )
102+
88103 leading_dims , final_dims = shape [:- 2 ], shape [- 2 :]
89104 # We want to find the minimum power of 2 that fits the second-minor dimension
90105 # of shape, with maximum value 8.
91106 second_minor , _ = final_dims
92- packing = 4 // dtype .itemsize
107+
108+ packing = 4 // ty .dtype .itemsize
93109 max_tiling = _TILING [0 ]
94110 second_minor_tiling = (1 + int (_get_tpu_generation () < 4 )) * packing
95111 while second_minor_tiling < min (second_minor , max_tiling ):
@@ -114,13 +130,18 @@ def _make_block_ds(
114130 assert isinstance (out , pl .Slice )
115131 return out
116132
117- def _create_blocked_slice (block_index : jax .Array | int ,
118- block_size : int ,
119- dim_size : int ,
120- tiling : int ):
133+
134+ def _create_blocked_slice (
135+ block_index : jax .Array | int ,
136+ block_size : int ,
137+ dim_size : int ,
138+ tiling : int | None ,
139+ ):
121140 block_start = block_size * block_index
122141 if (dim_rem := dim_size % block_size ) == 0 :
123142 return pl .ds (block_start , block_size )
143+ if tiling is None :
144+ raise ValueError ("If tiling is None, block_size must divide dim_size." )
124145 if block_size % tiling != 0 :
125146 raise ValueError (f"Block size must divide tiling: { block_size = } , { tiling = } " )
126147 num_blocks = pl .cdiv (dim_size , block_size )
@@ -137,12 +158,15 @@ def _create_bounded_slice(slice_start: jax.Array | int,
137158 slice_size : jax .Array | int ,
138159 block_size : int ,
139160 dim_size : int ,
140- tiling : int ):
141- if block_size % tiling != 0 :
161+ tiling : int | None ):
162+ if tiling is not None and block_size % tiling != 0 :
142163 raise ValueError (f"Block size must divide tiling: { block_size = } , { tiling = } " )
143164 # We assume by construction that slice_size <= block_size. We also assume
144165 # that the slice_start is already aligned to the tiling.
145166
167+ if tiling is None :
168+ return pl .ds (slice_start , slice_size )
169+
146170 # If we are out of bound, we need to round the slice size down to the nearest
147171 # multiple of the tiling.
148172 is_oob = slice_start + slice_size > dim_size
@@ -157,7 +181,7 @@ def _create_bounded_slice(slice_start: jax.Array | int,
157181
158182def _make_block_slice (
159183 block_index : jax .Array , block_size : pl .BlockDim | int | None , size : int ,
160- tiling : int
184+ tiling : int | None
161185) -> pl .Slice | slice | int | jax .Array :
162186 # Computes a slice given a block index and block size. In the default case,
163187 # we return slice(block_index * block_size, (block_index + 1) * block_size).
@@ -332,7 +356,7 @@ def block_shape(self) -> Sequence[pl.BlockDim | int | None] | None:
332356 def compute_index (self ):
333357 return self .spec .index_map
334358
335- def get_dma_slice (self , src_shape , src_dtype , grid_indices ):
359+ def get_dma_slice (self , src_ty , grid_indices ):
336360 # We need to handle blocks that might go OOB in the src array. An in bounds
337361 # block looks like this (for array shape (600, 600) and block shape
338362 # (256, 256)):
@@ -379,10 +403,14 @@ def get_dma_slice(self, src_shape, src_dtype, grid_indices):
379403 # Suppose A is now (601, 600), instead of picking a (88, 256)-sized block
380404 # for the last iteration on that dimension, we will pick the next highest
381405 # tile multiple, i.e. (96, 256).
406+
407+ if (src_shape := getattr (src_ty , "shape" , None )) is None :
408+ raise ValueError (f'Type { src_ty } does not have a type.' )
409+
382410 if len (src_shape ) < 2 :
383411 raise NotImplementedError ("Must use >1D values." )
384412
385- tiling = _make_tiling (src_shape , src_dtype )
413+ tiling = _make_tiling (src_shape , src_ty )
386414 block_indices = self .compute_index (* grid_indices )
387415 return tuple (
388416 _make_block_slice (bi , bs , ss , t )
@@ -403,6 +431,14 @@ def with_spec(self, spec: pl.BlockSpec) -> BufferedRefBase:
403431 """Returns a new BufferedRefBase with the given block spec."""
404432 raise NotImplementedError ()
405433
434+ def _ref_to_value_aval (ref ):
435+ """Return the inner of a ref, or a ShapedArray for TransformedRefs."""
436+ return (
437+ jax_core .ShapedArray (shape = ref .shape , dtype = ref .dtype )
438+ if isinstance (ref , state_types .TransformedRef )
439+ else jax .typeof (ref ).inner_aval
440+ )
441+
406442
407443# TODO(justinfu): Refactor and rename slot fields to reflect cumulative values
408444# instead of slot index.
@@ -413,7 +449,6 @@ class BufferedRef(BufferedRefBase):
413449
414450 Attributes:
415451 spec: pallas blockspec.
416- dtype: dtype for buffers.
417452 buffer_type: enum indicating whether this is an input, output, or in/out
418453 accumulator buffered reference.
419454 window_ref: a multiple-buffer to hold the working and dirty buffers used
@@ -444,7 +479,6 @@ class BufferedRef(BufferedRefBase):
444479 copy.
445480 """
446481 _spec : pl .BlockSpec = dataclasses .field (metadata = dict (static = True ))
447- dtype : Any = dataclasses .field (metadata = dict (static = True ))
448482 _buffer_type : BufferType = dataclasses .field (metadata = dict (static = True ))
449483 window_ref : ArrayRef | None
450484 accum_ref : ArrayRef | None
@@ -507,7 +541,7 @@ def buffer_types() -> type[BufferType]:
507541 return BufferType
508542
509543 @classmethod
510- def create (cls , spec : pl .BlockSpec , dtype , buffer_type , buffer_count ,
544+ def create (cls , spec : pl .BlockSpec , dtype_or_type , buffer_type , buffer_count ,
511545 needs_swap_ref = True ,
512546 grid_rank = None ,
513547 use_lookahead = False ,
@@ -516,7 +550,8 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
516550
517551 Args:
518552 spec: pallas blockspec.
519- dtype: dtype for buffers.
553+ dtype_or_type: dtype or aval for buffers. If an aval, the shape is
554+ ignored.
520555 buffer_type: enum indicating whether this is an input, output, or in/out
521556 accumulator buffered reference.
522557 needs_swap_ref: whether a swap slots tracker needs to be allocated.
@@ -527,9 +562,18 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
527562 Returns:
528563 Initialized BufferedRef
529564 """
565+
566+ # (123, 456) is a dummy shape since we never use ty without
567+ # calling .update(shape=...) first.
568+ ty = (
569+ dtype_or_type
570+ if isinstance (dtype_or_type , jax_core .AbstractValue )
571+ else jax_core .ShapedArray ((123 , 456 ), dtype_or_type )
572+ )
573+
530574 block_shape = _get_block_shape (spec )
531575 if buffer_type is BufferType .ACCUMULATOR :
532- accum_ref = VMEM ( block_shape , dtype )
576+ accum_ref = VMEM . from_type ( ty . update ( shape = block_shape ) )
533577 else :
534578 accum_ref = None
535579 if source_memory_space == VMEM :
@@ -541,7 +585,6 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
541585 f"Cannot hold a non-buffered ref in { spec .memory_space = } " )
542586 return cls (
543587 _spec = spec ,
544- dtype = dtype ,
545588 _buffer_type = buffer_type ,
546589 window_ref = None , # to be bound to existing ref by the pipeline routine
547590 accum_ref = accum_ref ,
@@ -570,11 +613,12 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
570613 raise ValueError (
571614 "grid_rank must be specified when use_lookahead is True."
572615 )
616+
617+ buffer_ty = ty .update (shape = (buffer_count , * block_shape ))
573618 return cls (
574619 _spec = spec ,
575- dtype = dtype ,
576620 _buffer_type = buffer_type ,
577- window_ref = buffer_memory_space (( buffer_count ,) + block_shape , dtype ),
621+ window_ref = buffer_memory_space . from_type ( buffer_ty ),
578622 accum_ref = accum_ref ,
579623 copy_in_slot = SMEM ((1 ,), jnp .uint32 ) if buffer_type .is_input else None ,
580624 wait_in_slot = SMEM ((1 ,), jnp .uint32 ) if buffer_type .is_input else None ,
@@ -601,22 +645,28 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
601645 )
602646
603647 @classmethod
604- def input (cls , spec , dtype , buffer_count = 2 , ** kwargs ):
605- return cls .create (spec , dtype , BufferType .INPUT , buffer_count , ** kwargs )
648+ def input (cls , spec , dtype_or_type , buffer_count = 2 , ** kwargs ):
649+ return cls .create (
650+ spec , dtype_or_type , BufferType .INPUT , buffer_count , ** kwargs
651+ )
606652
607653 @classmethod
608- def output (cls , spec , dtype , buffer_count = 2 , ** kwargs ):
609- return cls .create (spec , dtype , BufferType .OUTPUT , buffer_count , ** kwargs )
654+ def output (cls , spec , dtype_or_type , buffer_count = 2 , ** kwargs ):
655+ return cls .create (
656+ spec , dtype_or_type , BufferType .OUTPUT , buffer_count , ** kwargs
657+ )
610658
611659 @classmethod
612- def accumulator (cls , spec , dtype , buffer_count = 2 , ** kwargs ):
613- return cls .create (spec , dtype , BufferType .ACCUMULATOR , buffer_count ,
614- ** kwargs )
660+ def accumulator (cls , spec , dtype_or_type , buffer_count = 2 , ** kwargs ):
661+ return cls .create (
662+ spec , dtype_or_type , BufferType .ACCUMULATOR , buffer_count , ** kwargs
663+ )
615664
616665 @classmethod
617- def input_output (cls , spec , dtype , buffer_count = 2 , ** kwargs ):
618- return cls .create (spec , dtype , BufferType .INPUT_OUTPUT , buffer_count ,
619- ** kwargs )
666+ def input_output (cls , spec , dtype_or_type , buffer_count = 2 , ** kwargs ):
667+ return cls .create (
668+ spec , dtype_or_type , BufferType .INPUT_OUTPUT , buffer_count , ** kwargs
669+ )
620670
621671 @property
622672 def block_shape (self ):
@@ -923,7 +973,7 @@ def copy_in(self, src_ref, grid_indices):
923973 if self .swap is not None :
924974 self .swap [0 ] = True
925975 slot = self .current_copy_in_slot
926- src_slice = self .get_dma_slice (src_ref . shape , src_ref . dtype , grid_indices )
976+ src_slice = self .get_dma_slice (_ref_to_value_aval ( src_ref ) , grid_indices )
927977 dst_slice = tuple (
928978 pl .ds (0 , s .size )
929979 for s , bd in zip (src_slice , self .block_shape )
@@ -944,7 +994,7 @@ def copy_out(self, dst_ref, grid_indices):
944994 if self .swap is not None :
945995 self .swap [0 ] = True
946996 slot = self .current_copy_out_slot
947- dst_slice = self .get_dma_slice (dst_ref . shape , dst_ref . dtype , grid_indices )
997+ dst_slice = self .get_dma_slice (_ref_to_value_aval ( dst_ref ) , grid_indices )
948998 src_slice = tuple (
949999 pl .ds (0 , s .size )
9501000 for s , bd in zip (dst_slice , self .block_shape )
@@ -962,7 +1012,7 @@ def wait_in(self, src_ref, grid_indices):
9621012 if not self .is_buffered : return
9631013 assert not (self .window_ref is None or isinstance (self .window_ref , REF ))
9641014 assert self .sem_recvs is not None
965- src_slice = self .get_dma_slice (src_ref . shape , src_ref . dtype , grid_indices )
1015+ src_slice = self .get_dma_slice (_ref_to_value_aval ( src_ref ) , grid_indices )
9661016 dst_slice = tuple (
9671017 pl .ds (0 , s .size )
9681018 for s , bd in zip (src_slice , self .block_shape )
@@ -984,7 +1034,7 @@ def wait_out(self, dst_ref, grid_indices):
9841034 assert not (self .window_ref is None or isinstance (self .window_ref , REF ))
9851035 assert self .sem_sends is not None
9861036 wait_slot = self .current_wait_out_slot
987- dst_slice = self .get_dma_slice (dst_ref . shape , dst_ref . dtype , grid_indices )
1037+ dst_slice = self .get_dma_slice (_ref_to_value_aval ( dst_ref ) , grid_indices )
9881038 src_slice = tuple (
9891039 pl .ds (0 , s .size )
9901040 for s , bd in zip (dst_slice , self .block_shape )
@@ -1682,7 +1732,9 @@ def make_input_bref(in_spec, in_ref):
16821732 use_lookahead = in_spec .pipeline_mode .use_lookahead
16831733 if use_lookahead and grid is None :
16841734 raise ValueError ("Grid must be specified when using lookahead." )
1685- return BufferedRef .input (in_spec , in_ref .dtype , buffer_count ,
1735+
1736+ in_aval = _ref_to_value_aval (in_ref )
1737+ return BufferedRef .input (in_spec , in_aval , buffer_count ,
16861738 needs_swap_ref = needs_swap_ref ,
16871739 grid_rank = len (grid ),
16881740 use_lookahead = use_lookahead ,
@@ -1695,11 +1747,13 @@ def make_output_bref(out_spec, out_ref, accumulate):
16951747 if out_spec .pipeline_mode .use_lookahead :
16961748 raise ValueError ("Output buffering does not support lookahead." )
16971749
1750+ out_aval = _ref_to_value_aval (out_ref )
1751+
16981752 if accumulate :
1699- return BufferedRef .accumulator (out_spec , out_ref . dtype , buffer_count ,
1753+ return BufferedRef .accumulator (out_spec , out_aval , buffer_count ,
17001754 needs_swap_ref = needs_swap_ref ,
17011755 source_memory_space = out_ref .memory_space )
1702- return BufferedRef .output (out_spec , out_ref . dtype , buffer_count ,
1756+ return BufferedRef .output (out_spec , out_aval , buffer_count ,
17031757 needs_swap_ref = needs_swap_ref ,
17041758 source_memory_space = out_ref .memory_space )
17051759 out_brefs = jax .tree .map (
@@ -1817,7 +1871,7 @@ def sync_copy(src: REF | BufferedRef, dst: REF | BufferedRef, indices):
18171871 bref = dst
18181872 hbm_ref = src
18191873 copy_in = True
1820- hbm_slice = bref .get_dma_slice (hbm_ref . shape , hbm_ref . dtype , indices )
1874+ hbm_slice = bref .get_dma_slice (_ref_to_value_aval ( hbm_ref ) , indices )
18211875 bref_slice = tuple (
18221876 pl .ds (0 , s .size )
18231877 for s , bd in zip (hbm_slice , bref .block_shape )
0 commit comments