
Commit b48c899

Migrate LlamaDecoderLayer to NNX
1 parent c9039d1 commit b48c899

6 files changed: +194 additions, -149 deletions


src/MaxText/inference/kvcache.py

Lines changed: 5 additions & 2 deletions
@@ -310,8 +310,11 @@ def __init__(
     self.model_mode = model_mode
     self.use_chunked_prefill = use_chunked_prefill

-    self._initialize_prefill_caches(model_mode)
-    self._initialize_ar_cache_vars(model_mode)
+    if self.model_mode in (MODEL_MODE_PREFILL):
+      self._initialize_prefill_caches(model_mode)
+    if self.model_mode in (MODEL_MODE_AUTOREGRESSIVE):
+      self._initialize_prefill_caches(model_mode)
+      self._initialize_ar_cache_vars(model_mode)

   @property
   def prefill_key_vars(self):
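
The change above gates cache construction on the runtime mode, so an instance built for one phase no longer allocates the other phase's variables (autoregressive mode still sets up the prefill caches before its own). A rough, self-contained sketch of that pattern; the constants and dict containers below are stand-ins, not the MaxText KVCache API:

# Sketch only: hypothetical stand-ins for the real cache initializers.
MODEL_MODE_PREFILL = "prefill"
MODEL_MODE_AUTOREGRESSIVE = "autoregressive"


class ToyKVCache:
  """Allocates only the cache variables the current mode needs."""

  def __init__(self, model_mode: str):
    self.model_mode = model_mode
    self.prefill_cache = None
    self.ar_cache = None
    if model_mode == MODEL_MODE_PREFILL:
      self.prefill_cache = {}   # prefill-only instances skip the AR cache
    if model_mode == MODEL_MODE_AUTOREGRESSIVE:
      self.prefill_cache = {}   # AR mode also initializes prefill caches, as in the diff above
      self.ar_cache = {}        # plus its own autoregressive cache variables


prefill_only = ToyKVCache(MODEL_MODE_PREFILL)
assert prefill_only.ar_cache is None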

src/MaxText/layers/decoders.py

Lines changed: 1 addition & 1 deletion
@@ -374,7 +374,7 @@ def get_decoder_layers(self):
       case DecoderBlockType.DEFAULT:
         return [DecoderLayer]
       case DecoderBlockType.LLAMA2:
-        return [llama2.LlamaDecoderLayer]
+        return [llama2.LlamaDecoderLayerToLinen]
       case DecoderBlockType.MISTRAL:
         # TODO(ranran): update to Mistral with sliding window attention
         return [mistral.MistralDecoderLayer]
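
The only functional change here is which class the LLAMA2 branch returns: the Linen-facing wrapper defined at the bottom of llama2.py rather than the raw NNX module, presumably so the Linen-based decoder stack can keep instantiating it like any other layer class. A toy sketch of the dispatch, with placeholder classes standing in for the real layers:

# Sketch only: placeholder enum and classes illustrating the class dispatch.
from enum import Enum


class DecoderBlockType(Enum):
  DEFAULT = "default"
  LLAMA2 = "llama2"


class DecoderLayer:              # placeholder for the default Linen layer
  pass


class LlamaDecoderLayerToLinen:  # placeholder for the NNX layer's Linen wrapper
  pass


def get_decoder_layers(block_type: DecoderBlockType) -> list[type]:
  """Return the layer classes the decoder stack should instantiate."""
  match block_type:
    case DecoderBlockType.DEFAULT:
      return [DecoderLayer]
    case DecoderBlockType.LLAMA2:
      return [LlamaDecoderLayerToLinen]
    case _:
      raise ValueError(f"unsupported decoder block: {block_type}")


assert get_decoder_layers(DecoderBlockType.LLAMA2) == [LlamaDecoderLayerToLinen]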

src/MaxText/layers/llama2.py

Lines changed: 115 additions & 84 deletions
@@ -19,34 +19,121 @@
 import jax.numpy as jnp
 from jax.ad_checkpoint import checkpoint_name
 from jax.sharding import Mesh
-# from jax.experimental.pallas.ops.tpu import flash_attention

 from flax import linen as nn
+from flax import nnx

 from MaxText.inference import page_manager
 from MaxText.common_types import Config
-from MaxText.layers.linears import mlp_block
+from MaxText.layers.linears import MlpBlock
+from MaxText.layers import initializers
+from MaxText.layers import nnx_wrappers
 from MaxText.layers import quantizations
-from MaxText.layers.attentions import attention_as_linen
+from MaxText.layers.attentions import Attention
 from MaxText.layers.quantizations import AqtQuantization as Quant
-from MaxText.layers.normalizations import rms_norm
-from MaxText.common_types import MODEL_MODE_PREFILL
+from MaxText.layers.normalizations import RMSNorm
+from MaxText.common_types import MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE


 # -----------------------------------------
 # The Decoder Layer specific for Llama2
 # -----------------------------------------


-class LlamaDecoderLayer(nn.Module):
+class LlamaDecoderLayer(nnx.Module):
   """Transformer decoder layer that attends to the encoder."""

-  config: Config
-  mesh: Mesh
-  model_mode: str
-  quant: None | Quant = None
+  def __init__(
+      self,
+      config: Config,
+      model_mode: str,
+      mesh: Mesh,
+      rngs: nnx.Rngs,
+      quant: None | Quant = None,
+  ):
+
+    self.config = config
+    self.mesh = mesh
+    self.quant = quant
+
+    batch_size = 1 if model_mode == MODEL_MODE_PREFILL else config.micro_batch_size_to_train_on
+
+    if model_mode == MODEL_MODE_PREFILL:
+      seq_len = config.max_prefill_predict_length
+    elif model_mode == MODEL_MODE_AUTOREGRESSIVE:
+      seq_len = 1
+    else:
+      seq_len = config.max_target_length
+
+    dummy_inputs_shape = (batch_size, seq_len, config.emb_dim)
+
+    self.pre_self_attention_layer_norm = RMSNorm(
+        num_features=config.emb_dim,
+        dtype=config.dtype,
+        weight_dtype=config.weight_dtype,
+        kernel_axes=("norm",),
+        epsilon=config.normalization_layer_epsilon,
+        rngs=rngs,
+    )
+
+    self.self_attention = Attention(
+        config=config,
+        num_query_heads=config.num_query_heads,
+        num_kv_heads=config.num_kv_heads,
+        head_dim=config.head_dim,
+        max_target_length=config.max_target_length,
+        max_prefill_predict_length=config.max_prefill_predict_length,
+        attention_kernel=config.attention,
+        inputs_q_shape=dummy_inputs_shape,
+        inputs_kv_shape=dummy_inputs_shape,
+        mesh=mesh,
+        dtype=config.dtype,
+        weight_dtype=config.weight_dtype,
+        dropout_rate=config.dropout_rate,
+        float32_qk_product=config.float32_qk_product,
+        float32_logits=config.float32_logits,
+        quant=self.quant,
+        kv_quant=quantizations.configure_kv_quant(config),
+        prefill_cache_axis_order=tuple(map(int, config.prefill_cache_axis_order.split(","))),
+        ar_cache_axis_order=tuple(map(int, config.ar_cache_axis_order.split(","))),
+        compute_axis_order=tuple(map(int, config.compute_axis_order.split(","))),
+        reshape_q=config.reshape_q,
+        use_ragged_attention=config.use_ragged_attention,
+        ragged_block_size=config.ragged_block_size,
+        model_mode=model_mode,
+        rngs=rngs,
+    )
+
+    self.post_self_attention_layer_norm = RMSNorm(
+        num_features=config.emb_dim,
+        dtype=config.dtype,
+        weight_dtype=config.weight_dtype,
+        kernel_axes=("norm",),
+        epsilon=config.normalization_layer_epsilon,
+        rngs=rngs,
+    )
+
+    self.mlp = MlpBlock(
+        in_features=config.emb_dim,
+        intermediate_dim=config.mlp_dim,
+        activations=config.mlp_activations,
+        intermediate_dropout_rate=config.dropout_rate,
+        dtype=config.dtype,
+        weight_dtype=config.weight_dtype,
+        config=config,
+        quant=self.quant,
+        model_mode=model_mode,
+        rngs=rngs,
+    )
+
+    self.dropout = nnx.Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=rngs)
+
+    if model_mode == MODEL_MODE_PREFILL:
+      self.activation_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_embed")
+    else:
+      self.activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed")
+

-  @nn.compact
   def __call__(
       self,
       inputs,
@@ -59,57 +146,15 @@ def __call__(
       previous_chunk=None,
   ):
     cfg = self.config
-    mesh = self.mesh

-    if model_mode == MODEL_MODE_PREFILL:
-      activation_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_embed")
-    else:
-      activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed")
-
-    inputs = nn.with_logical_constraint(inputs, activation_axis_names)
+    inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
-    lnx_rms = rms_norm(
-        num_features=inputs.shape[-1],
-        dtype=cfg.dtype,
-        weight_dtype=cfg.weight_dtype,
-        name="pre_self_attention_layer_norm",
-        kernel_axes=("norm",),
-        epsilon=cfg.normalization_layer_epsilon,
-    )
-    lnx = lnx_rms(inputs)
+    lnx = self.pre_self_attention_layer_norm(inputs)

-    lnx = nn.with_logical_constraint(lnx, activation_axis_names)
+    lnx = nn.with_logical_constraint(lnx, self.activation_axis_names)

     # Self-attention block
-    attention_layer = attention_as_linen(
-        config=cfg,
-        num_query_heads=cfg.num_query_heads,
-        num_kv_heads=cfg.num_kv_heads,
-        head_dim=cfg.head_dim,
-        max_target_length=cfg.max_target_length,
-        max_prefill_predict_length=cfg.max_prefill_predict_length,
-        attention_kernel=cfg.attention,
-        inputs_q_shape=lnx.shape,
-        inputs_kv_shape=lnx.shape,
-        mesh=mesh,
-        dtype=cfg.dtype,
-        weight_dtype=cfg.weight_dtype,
-        dropout_rate=cfg.dropout_rate,
-        name="self_attention",
-        float32_qk_product=cfg.float32_qk_product,
-        float32_logits=cfg.float32_logits,
-        quant=self.quant,
-        kv_quant=quantizations.configure_kv_quant(cfg),
-        prefill_cache_axis_order=tuple(map(int, cfg.prefill_cache_axis_order.split(","))),
-        ar_cache_axis_order=tuple(map(int, cfg.ar_cache_axis_order.split(","))),
-        compute_axis_order=tuple(map(int, cfg.compute_axis_order.split(","))),
-        reshape_q=cfg.reshape_q,
-        use_ragged_attention=cfg.use_ragged_attention,
-        ragged_block_size=cfg.ragged_block_size,
-        model_mode=model_mode,
-    )
-
-    attention_lnx = attention_layer(
+    attention_lnx = self.self_attention(
         lnx,
         lnx,
         decoder_positions,
@@ -121,40 +166,20 @@ def __call__(
         previous_chunk=previous_chunk,
     )

-    attention_lnx = nn.with_logical_constraint(attention_lnx, activation_axis_names)
+    attention_lnx = nn.with_logical_constraint(attention_lnx, self.activation_axis_names)
     intermediate_inputs = inputs + attention_lnx

     # Fully Connected
-    hidden_states = rms_norm(
-        num_features=intermediate_inputs.shape[-1],
-        dtype=cfg.dtype,
-        weight_dtype=cfg.weight_dtype,
-        name="post_self_attention_layer_norm",
-        kernel_axes=("norm",),
-        epsilon=cfg.normalization_layer_epsilon,
-    )(intermediate_inputs)
-    hidden_states = nn.with_logical_constraint(hidden_states, activation_axis_names)
+    hidden_states = self.post_self_attention_layer_norm(intermediate_inputs)
+    hidden_states = nn.with_logical_constraint(hidden_states, self.activation_axis_names)

     # MLP block.
-    mlp_lnx = mlp_block(
-        in_features=hidden_states.shape[-1],
-        intermediate_dim=cfg.mlp_dim,
-        activations=cfg.mlp_activations,
-        intermediate_dropout_rate=cfg.dropout_rate,
-        dtype=cfg.dtype,
-        weight_dtype=cfg.weight_dtype,
-        name="mlp",
-        config=cfg,
-        quant=self.quant,
-        model_mode=model_mode,
-    )(hidden_states, deterministic=deterministic)
-    mlp_lnx = nn.with_logical_constraint(mlp_lnx, activation_axis_names)
+    mlp_lnx = self.mlp(hidden_states, deterministic=deterministic)
+    mlp_lnx = nn.with_logical_constraint(mlp_lnx, self.activation_axis_names)

     layer_output = mlp_lnx + intermediate_inputs
-
-    layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic)
-
-    layer_output = nn.with_logical_constraint(layer_output, activation_axis_names)
+    layer_output = self.dropout(layer_output, deterministic=deterministic)
+    layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)

     if cfg.record_internal_nn_metrics:
       self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
@@ -169,3 +194,9 @@ def __call__(
       return layer_output, None
     else:
       return layer_output
+
+
+LlamaDecoderLayerToLinen = nnx_wrappers.to_linen_class(
+    LlamaDecoderLayer,
+    base_metadata_fn=initializers.variable_to_logically_partitioned,
+)
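
This file carries the core of the migration. The Linen version built its submodules inside a @nn.compact __call__; the NNX version constructs RMSNorm, Attention, MlpBlock, and Dropout eagerly in __init__ with an explicit nnx.Rngs (using mode-dependent dummy shapes for the attention layer), leaving __call__ as plain wiring. The class is then handed back to Linen call sites through nnx_wrappers.to_linen_class. A minimal, self-contained sketch of the same eager-construction pattern using stock Flax NNX layers; toy dimensions, not the MaxText modules:

# Sketch only: a pre-norm residual block in plain Flax NNX, assuming flax>=0.10.
import jax
import jax.numpy as jnp
from flax import nnx


class ToyPreNormBlock(nnx.Module):
  """Pre-norm residual block: norm -> dense -> dropout -> residual add."""

  def __init__(self, features: int, *, dropout_rate: float, rngs: nnx.Rngs):
    # Submodules are built once, up front, with explicit RNG streams.
    self.norm = nnx.RMSNorm(num_features=features, rngs=rngs)
    self.dense = nnx.Linear(in_features=features, out_features=features, rngs=rngs)
    self.dropout = nnx.Dropout(rate=dropout_rate, rngs=rngs)

  def __call__(self, x: jax.Array, *, deterministic: bool = True) -> jax.Array:
    # __call__ only wires the prebuilt submodules together.
    y = self.dropout(self.dense(self.norm(x)), deterministic=deterministic)
    return x + y


block = ToyPreNormBlock(features=16, dropout_rate=0.1, rngs=nnx.Rngs(0))
out = block(jnp.ones((2, 8, 16)))  # (batch, seq, features)

Note that the migrated __call__ still applies nn.with_logical_constraint and checkpoint_name to the activations, so sharding annotations and remat checkpoint names are unchanged by the move to NNX.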

src/MaxText/layers/models.py

Lines changed: 1 addition & 1 deletion
@@ -257,7 +257,7 @@ def __init__(self, config: Config, mesh: Mesh, quant: Quant, *, model_mode: str
     else:
       seq_len = cfg.max_target_length

-    batch_size = cfg.micro_batch_size_to_train_on
+    batch_size = 1 if self.model_mode == MODEL_MODE_PREFILL else cfg.micro_batch_size_to_train_on
     dummy_decoder_input_tokens = jnp.ones((batch_size, seq_len), dtype=jnp.int32)
     dummy_decoder_positions = jnp.ones((batch_size, seq_len), dtype=jnp.int32)
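
With the decoder layer now built eagerly, the dummy inputs used for initialization have to match the runtime mode: batch size 1 for prefill, the training micro-batch otherwise, mirroring the seq_len selection a few lines above and the same logic added to LlamaDecoderLayer.__init__. A small standalone sketch of that shape selection; the mode strings and config fields are taken from the diff, the helper itself is hypothetical:

# Sketch only: hypothetical helper mirroring the batch/seq_len selection above.
MODEL_MODE_PREFILL = "prefill"
MODEL_MODE_AUTOREGRESSIVE = "autoregressive"


def dummy_decoder_shape(config, model_mode: str) -> tuple[int, int]:
  """(batch, seq_len) for the dummy tokens used during eager initialization."""
  if model_mode == MODEL_MODE_PREFILL:
    return 1, config.max_prefill_predict_length    # single-request prefill trace
  if model_mode == MODEL_MODE_AUTOREGRESSIVE:
    return config.micro_batch_size_to_train_on, 1  # one token per decode step
  return config.micro_batch_size_to_train_on, config.max_target_length  # training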
