src/MaxText/inference/kvcache.py (13 changes: 12 additions & 1 deletion)
@@ -510,7 +510,7 @@ def kv_cache_chunked_prefill(

     assert not self.kv_quant, "Not support kv_quant now."
     if decoder_segment_ids is not None:
-      self.batch, segment_id_seq_len = decoder_segment_ids.shape
+      _, segment_id_seq_len = decoder_segment_ids.shape
       assert self.key_seq_len == segment_id_seq_len, f"{self.key_seq_len=}, {segment_id_seq_len=} should match."

     assert key.dtype == value.dtype, "Key and Value Dtypes should match."
@@ -694,6 +694,12 @@ def value_body(i, val):

     else:
       one_hot_indices = one_hot_indices.astype(int)
+
+    # Align batch size for cache with new token in decoding
+    if cached_key.value.shape[2] != one_token_key_shaped_for_cache.shape[2]:
+      cached_key.value = jnp.repeat(cached_key.value, one_token_key_shaped_for_cache.shape[2], axis=2)
+      cached_value.value = jnp.repeat(cached_value.value, one_token_value_shaped_for_cache.shape[2], axis=2)
+
     cached_key.value = jax.lax.dynamic_update_index_in_dim(
         cached_key.value, one_token_key_shaped_for_cache, ar_cache_update_idx, ar_cache_update_axis
     )
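Note: the hunk above first grows a batch-1 cache to match the incoming token's batch, then performs the usual in-place write. A minimal runnable sketch of the same two-step pattern, with illustrative shapes and the assumption (taken from the hunk) that the AR cache layout puts batch on axis 2:

import jax
import jax.numpy as jnp

cache = jnp.zeros((1, 8, 1, 4))     # [kv_heads, seq_len, batch=1, head_dim], toy layout
new_token = jnp.ones((1, 1, 3, 4))  # one decoded token, but the batch grew to 3

# Step 1: align the cache batch with the incoming token batch.
if cache.shape[2] != new_token.shape[2]:
  cache = jnp.repeat(cache, new_token.shape[2], axis=2)  # -> (1, 8, 3, 4)

# Step 2: write the token at the current decode position along the seq axis.
cache = jax.lax.dynamic_update_index_in_dim(cache, new_token, 5, 1)

Since jnp.repeat multiplies the existing batch dimension (1*N -> N), this alignment only behaves as a broadcast when the cached batch is 1; a cached batch of 2 repeated to match 3 would yield 6.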
@@ -773,6 +779,11 @@ def kv_cache_autoregressive(
         use_ragged_attention,
     )
     active_indicator = jnp.zeros((self.batch, 1), dtype=jnp.int32) + DECODING_ACTIVE_SEQUENCE_INDICATOR
+
+    # Align batch size for cached segment IDs with indicator in decoding
+    if cached_ar_segment_id_var.value.shape[0] != active_indicator.shape[0]:
+      cached_ar_segment_id_var.value = jnp.repeat(cached_ar_segment_id_var.value, active_indicator.shape[0], axis=0)
+
     cached_ar_segment_id_var.value = jax.lax.dynamic_update_index_in_dim(
         cached_ar_segment_id_var.value, active_indicator, jnp.squeeze(cache_ar_index_var.value), 1
     )
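The cached segment IDs above get the same alignment, with batch on axis 0 rather than axis 2. A sketch under the same assumptions:

import jax
import jax.numpy as jnp

seg_ids = jnp.zeros((1, 16), dtype=jnp.int32)  # cached segment IDs for batch=1
indicator = jnp.ones((3, 1), dtype=jnp.int32)  # active-sequence flags, decode batch=3

if seg_ids.shape[0] != indicator.shape[0]:
  seg_ids = jnp.repeat(seg_ids, indicator.shape[0], axis=0)  # -> (3, 16)

# Mark the current decode position (7 here) of every sequence as active.
seg_ids = jax.lax.dynamic_update_index_in_dim(seg_ids, indicator, 7, 1)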
src/MaxText/layers/decoders.py (46 changes: 29 additions & 17 deletions)
@@ -135,7 +135,7 @@ def __call__(
         ar_cache_axis_order=tuple(map(int, cfg.ar_cache_axis_order.split(","))),
         compute_axis_order=tuple(map(int, cfg.compute_axis_order.split(","))),
         reshape_q=cfg.reshape_q,
-        model_mode=self.model_mode,
+        model_mode=model_mode,
     )

     attention_lnx = attention_layer(
@@ -144,7 +144,7 @@
         decoder_positions,
         decoder_segment_ids=decoder_segment_ids,
         deterministic=deterministic,
-        model_mode=self.model_mode,
+        model_mode=model_mode,
     )

     if model_mode == MODEL_MODE_PREFILL:
@@ -161,7 +161,7 @@
         dtype=cfg.dtype,
         weight_dtype=cfg.weight_dtype,
         name="mlp",
-        model_mode=self.model_mode,
+        model_mode=model_mode,
         config=cfg,
         quant=self.quant,
     )(lnx, deterministic=deterministic)
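All of the decoders.py changes follow one refactor: model_mode stops being read from the construction-time attribute self.model_mode and is instead threaded through as a call argument, so a module built once can be invoked in different modes (train, prefill, autoregressive) without rebuilding it. A hypothetical Linen reduction of the pattern (TinyLayer and its mode handling are illustrative, not MaxText code):

import jax
import jax.numpy as jnp
from flax import linen as nn

MODEL_MODE_TRAIN = "train"
MODEL_MODE_PREFILL = "prefill"

class TinyLayer(nn.Module):
  @nn.compact
  def __call__(self, x, model_mode=MODEL_MODE_TRAIN):
    # Dispatch on the call argument, not on an attribute frozen at construction.
    scale = self.param("scale", nn.initializers.ones, (x.shape[-1],))
    y = x * scale
    return y if model_mode == MODEL_MODE_TRAIN else y.astype(jnp.bfloat16)

layer = TinyLayer()
params = layer.init(jax.random.PRNGKey(0), jnp.ones((2, 4)))
y_train = layer.apply(params, jnp.ones((2, 4)))  # defaults to train
y_prefill = layer.apply(params, jnp.ones((2, 4)), model_mode=MODEL_MODE_PREFILL)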
@@ -223,7 +223,7 @@ def __call__(
   ) -> jnp.ndarray:
     for lyr in range(self.num_decoder_layers):
       inputs = self.decoder_layer(
-          config=self.config, mesh=self.mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=self.model_mode
+          config=self.config, mesh=self.mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=model_mode
       )(
           inputs,
           decoder_segment_ids,
@@ -374,7 +374,7 @@ def get_decoder_layers(self):
       case DecoderBlockType.DEFAULT:
         return [DecoderLayer]
       case DecoderBlockType.LLAMA2:
-        return [llama2.LlamaDecoderLayer]
+        return [llama2.LlamaDecoderLayerToLinen]
       case DecoderBlockType.MISTRAL:
         # TODO(ranran): update to Mistral with sliding window attention
         return [mistral.MistralDecoderLayer]
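The LlamaDecoderLayer -> LlamaDecoderLayerToLinen swap reads like the Llama2 decoder layer has been ported to Flax NNX, with a to-Linen wrapper keeping it usable from this Linen decoder stack; that is an inference from the name alone, and the wrapper itself lives in the llama2 module.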
@@ -480,7 +480,7 @@ def scan_decoder_layers(self, cfg, decoder_layer, length, metadata_axis_name, mesh
         length=length,
         metadata_params={nn.PARTITION_NAME: metadata_axis_name},
     )
-    return scan_fn(config=cfg, mesh=mesh, name=metadata_axis_name, quant=self.quant, model_mode=self.model_mode, **kwargs)
+    return scan_fn(config=cfg, mesh=mesh, name=metadata_axis_name, quant=self.quant, **kwargs)

   def get_pipeline_stage_module(self, decoder_blocks):
     """get pipeline stage module"""
@@ -525,13 +525,14 @@ def _apply_embedding(
       decoder_input_tokens,
       decoder_positions,
       deterministic,
+      model_mode,
       image_embeddings=None,
       bidirectional_mask=None,
   ):
     """Applies token and positional embeddings to the input tokens."""
     cfg = self.config

-    y = shared_embedding(decoder_input_tokens.astype("int32"), model_mode=self.model_mode)
+    y = shared_embedding(decoder_input_tokens.astype("int32"), model_mode=model_mode)

     # Merge the image embeddings with the text embeddings for multimodal models
     if image_embeddings is not None and cfg.use_multimodal:
@@ -559,11 +560,11 @@
           embedding_init=nn.initializers.normal(stddev=1.0),
           name="position_embedder",
           config=cfg,
-      )(decoder_positions, model_mode=self.model_mode)
+      )(decoder_positions, model_mode=model_mode)
     return y

   @nn.compact
-  def _apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, deterministic):
+  def _apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, deterministic, model_mode):
     """Applies final normalization and projects hidden states to logits."""

     cfg = self.config
@@ -608,7 +609,7 @@ def _apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, deterministic
     )(
         y
     )  # We do not quantize the logits matmul.
-    if self.model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE):
+    if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE):
       logits = nn.with_logical_constraint(logits, (None, None, "activation_vocab"))
     else:
       logits = nn.with_logical_constraint(
@@ -628,6 +629,7 @@ def __call__(
       decoder_positions,
       decoder_segment_ids=None,
       deterministic=False,
+      model_mode=MODEL_MODE_TRAIN,
       previous_chunk=None,
       slot: None | int = None,
       page_state: None | page_manager.PageState = None,
@@ -644,6 +646,7 @@ def __call__(
         decoder_input_tokens,
         decoder_positions,
         deterministic,
+        model_mode,
         image_embeddings,
         bidirectional_mask
     )
@@ -655,12 +658,12 @@
         decoder_segment_ids,
         decoder_positions,
         deterministic,
-        self.model_mode,
+        model_mode,
     )
     if cfg.using_pipeline_parallelism:
       if cfg.pipeline_fsdp_ag_once:
         partition_spec = self.pipeline_module.get_weight_sharding(
-            y, decoder_segment_ids, decoder_positions, deterministic, self.model_mode
+            y, decoder_segment_ids, decoder_positions, deterministic, model_mode
         )
       else:
         partition_spec = None  # This partition spec is only used for the fsdp_ag_once feature.
@@ -680,6 +683,7 @@ def __call__(
             "dense_layers",
             mesh,
             in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
+            model_mode=model_mode,
         )(y, *broadcast_args)
         if num_moe_layers_outside_pp > 0:
           y, _ = self.scan_decoder_layers(
@@ -689,6 +693,7 @@
               "moe_layers",
               mesh,
               in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
+              model_mode=model_mode,
           )(y, *broadcast_args)
         y = self.pipeline_module(y, *broadcast_args, partition_spec=partition_spec)
       else:  # Not DeepSeek
@@ -704,6 +709,7 @@
             "layers_outside_pipeline",
             mesh,
             in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
+            model_mode=model_mode,
         )(y, *broadcast_args)
     else:
       if cfg.scan_layers:
@@ -723,6 +729,7 @@
             "dense_layers",
             mesh,
             in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
+            model_mode=model_mode,
         )(y, *broadcast_args)
         moe_layer = RemattedBlockLayers[1]
         moe_layer.__call__ = functools.partial(moe_layer.__call__, **layer_call_kwargs)
@@ -734,13 +741,15 @@
             "moe_layers",
             mesh,
             in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
+            model_mode=model_mode,
         )(y, *broadcast_args)
       elif cfg.decoder_block == DecoderBlockType.GEMMA3:
         y = self._apply_gemma3_scanned_blocks(
             y,
             decoder_segment_ids,
             decoder_positions,
             deterministic,
+            model_mode,
             bidirectional_mask,
             previous_chunk,
             page_state,
@@ -763,6 +772,7 @@
             "layers",
             mesh,
             in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
+            model_mode=model_mode,
             **layer_kwargs,
         )(y, *broadcast_args)
     else:
@@ -783,7 +793,7 @@
             decoder_segment_ids,
             decoder_positions,
             deterministic,
-            self.model_mode,
+            model_mode,
             previous_chunk=previous_chunk,
             page_state=page_state,
             slot=slot,
@@ -813,7 +823,7 @@ def __call__(
             decoder_segment_ids,
             decoder_positions,
             deterministic,
-            self.model_mode,
+            model_mode,
             previous_chunk=previous_chunk,
             page_state=page_state,
             slot=slot,
@@ -825,7 +835,7 @@
     # After the final transformer layer, `y` holds the raw, un-normalized hidden state.
     hidden_state = y

-    logits = self._apply_output_head(shared_embedding, hidden_state, deterministic)
+    logits = self._apply_output_head(shared_embedding, hidden_state, deterministic, model_mode)

     # The API of the Decoder is now a tuple, providing both the main output
     # and the raw hidden state needed for auxiliary tasks.
@@ -837,6 +847,7 @@ def _apply_gemma3_scanned_blocks(
       decoder_segment_ids,
       decoder_positions,
       deterministic,
+      model_mode,
       bidirectional_mask,
       previous_chunk,
       page_state,
@@ -863,7 +874,7 @@
         decoder_segment_ids,
         decoder_positions,
         deterministic,
-        self.model_mode,
+        model_mode,
     )
     y, _ = self.scan_decoder_layers(
         cfg,
@@ -872,6 +883,7 @@
         "layers",
         mesh,
         in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
+        model_mode=model_mode,
         **layer_kwargs,
     )(y, *broadcast_args, **layer_call_kwargs)

@@ -888,7 +900,7 @@
           decoder_segment_ids,
           decoder_positions,
           deterministic,
-          self.model_mode,
+          model_mode,
           previous_chunk=previous_chunk,
           page_state=page_state,
           slot=slot,