import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name
from jax.sharding import Mesh
- # from jax.experimental.pallas.ops.tpu import flash_attention

from flax import linen as nn
+ from flax import nnx

from MaxText.inference import page_manager
from MaxText.common_types import Config
- from MaxText.layers.linears import mlp_block
+ from MaxText.layers.linears import MlpBlock
+ from MaxText.layers import initializers
+ from MaxText.layers import nnx_wrappers
from MaxText.layers import quantizations
- from MaxText.layers.attentions import attention_as_linen
+ from MaxText.layers.attentions import Attention
from MaxText.layers.quantizations import AqtQuantization as Quant
- from MaxText.layers.normalizations import rms_norm
- from MaxText.common_types import MODEL_MODE_PREFILL
+ from MaxText.layers.normalizations import RMSNorm
+ from MaxText.common_types import MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE

# -----------------------------------------
# The Decoder Layer specific for Llama2
# -----------------------------------------


- class LlamaDecoderLayer(nn.Module):
+ class LlamaDecoderLayer(nnx.Module):
  """Transformer decoder layer that attends to the encoder."""

-   config: Config
-   mesh: Mesh
-   model_mode: str
-   quant: None | Quant = None
+   def __init__(
+       self,
+       config: Config,
+       model_mode: str,
+       mesh: Mesh,
+       rngs: nnx.Rngs,
+       quant: None | Quant = None,
+   ):
+
+     self.config = config
+     self.mesh = mesh
+     self.quant = quant
+
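+     # NNX constructs submodules eagerly in __init__, so compute a dummy activation
+     # shape up front to size the attention module's query/key-value inputs.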
+     batch_size = 1 if model_mode == MODEL_MODE_PREFILL else config.micro_batch_size_to_train_on
+
+     if model_mode == MODEL_MODE_PREFILL:
+       seq_len = config.max_prefill_predict_length
+     elif model_mode == MODEL_MODE_AUTOREGRESSIVE:
+       seq_len = 1
+     else:
+       seq_len = config.max_target_length
+
+     dummy_inputs_shape = (batch_size, seq_len, config.emb_dim)
+
+     self.pre_self_attention_layer_norm = RMSNorm(
+         num_features=config.emb_dim,
+         dtype=config.dtype,
+         weight_dtype=config.weight_dtype,
+         kernel_axes=("norm",),
+         epsilon=config.normalization_layer_epsilon,
+         rngs=rngs,
+     )
+
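+     # Self-attention submodule; the cache/compute axis orders are parsed from
+     # comma-separated strings in the config.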
+     self.self_attention = Attention(
+         config=config,
+         num_query_heads=config.num_query_heads,
+         num_kv_heads=config.num_kv_heads,
+         head_dim=config.head_dim,
+         max_target_length=config.max_target_length,
+         max_prefill_predict_length=config.max_prefill_predict_length,
+         attention_kernel=config.attention,
+         inputs_q_shape=dummy_inputs_shape,
+         inputs_kv_shape=dummy_inputs_shape,
+         mesh=mesh,
+         dtype=config.dtype,
+         weight_dtype=config.weight_dtype,
+         dropout_rate=config.dropout_rate,
+         float32_qk_product=config.float32_qk_product,
+         float32_logits=config.float32_logits,
+         quant=self.quant,
+         kv_quant=quantizations.configure_kv_quant(config),
+         prefill_cache_axis_order=tuple(map(int, config.prefill_cache_axis_order.split(","))),
+         ar_cache_axis_order=tuple(map(int, config.ar_cache_axis_order.split(","))),
+         compute_axis_order=tuple(map(int, config.compute_axis_order.split(","))),
+         reshape_q=config.reshape_q,
+         use_ragged_attention=config.use_ragged_attention,
+         ragged_block_size=config.ragged_block_size,
+         model_mode=model_mode,
+         rngs=rngs,
+     )
+
+     self.post_self_attention_layer_norm = RMSNorm(
+         num_features=config.emb_dim,
+         dtype=config.dtype,
+         weight_dtype=config.weight_dtype,
+         kernel_axes=("norm",),
+         epsilon=config.normalization_layer_epsilon,
+         rngs=rngs,
+     )
+
+     self.mlp = MlpBlock(
+         in_features=config.emb_dim,
+         intermediate_dim=config.mlp_dim,
+         activations=config.mlp_activations,
+         intermediate_dropout_rate=config.dropout_rate,
+         dtype=config.dtype,
+         weight_dtype=config.weight_dtype,
+         config=config,
+         quant=self.quant,
+         model_mode=model_mode,
+         rngs=rngs,
+     )
+
+     self.dropout = nnx.Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=rngs)
+
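+     # Logical axis names used to constrain activation sharding; prefill uses a
+     # distinct sequence-length axis name.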
+     if model_mode == MODEL_MODE_PREFILL:
+       self.activation_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_embed")
+     else:
+       self.activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed")
+

-   @nn.compact
  def __call__(
      self,
      inputs,
@@ -59,57 +146,15 @@ def __call__(
      previous_chunk=None,
  ):
    cfg = self.config
-     mesh = self.mesh

-     if model_mode == MODEL_MODE_PREFILL:
-       activation_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_embed")
-     else:
-       activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed")
-
-     inputs = nn.with_logical_constraint(inputs, activation_axis_names)
+     inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
    inputs = checkpoint_name(inputs, "decoder_layer_input")
-     lnx_rms = rms_norm(
-         num_features=inputs.shape[-1],
-         dtype=cfg.dtype,
-         weight_dtype=cfg.weight_dtype,
-         name="pre_self_attention_layer_norm",
-         kernel_axes=("norm",),
-         epsilon=cfg.normalization_layer_epsilon,
-     )
-     lnx = lnx_rms(inputs)
+     lnx = self.pre_self_attention_layer_norm(inputs)

-     lnx = nn.with_logical_constraint(lnx, activation_axis_names)
+     lnx = nn.with_logical_constraint(lnx, self.activation_axis_names)

    # Self-attention block
-     attention_layer = attention_as_linen(
-         config=cfg,
-         num_query_heads=cfg.num_query_heads,
-         num_kv_heads=cfg.num_kv_heads,
-         head_dim=cfg.head_dim,
-         max_target_length=cfg.max_target_length,
-         max_prefill_predict_length=cfg.max_prefill_predict_length,
-         attention_kernel=cfg.attention,
-         inputs_q_shape=lnx.shape,
-         inputs_kv_shape=lnx.shape,
-         mesh=mesh,
-         dtype=cfg.dtype,
-         weight_dtype=cfg.weight_dtype,
-         dropout_rate=cfg.dropout_rate,
-         name="self_attention",
-         float32_qk_product=cfg.float32_qk_product,
-         float32_logits=cfg.float32_logits,
-         quant=self.quant,
-         kv_quant=quantizations.configure_kv_quant(cfg),
-         prefill_cache_axis_order=tuple(map(int, cfg.prefill_cache_axis_order.split(","))),
-         ar_cache_axis_order=tuple(map(int, cfg.ar_cache_axis_order.split(","))),
-         compute_axis_order=tuple(map(int, cfg.compute_axis_order.split(","))),
-         reshape_q=cfg.reshape_q,
-         use_ragged_attention=cfg.use_ragged_attention,
-         ragged_block_size=cfg.ragged_block_size,
-         model_mode=model_mode,
-     )
-
-     attention_lnx = attention_layer(
+     attention_lnx = self.self_attention(
        lnx,
        lnx,
        decoder_positions,
@@ -121,40 +166,20 @@ def __call__(
        previous_chunk=previous_chunk,
    )

-     attention_lnx = nn.with_logical_constraint(attention_lnx, activation_axis_names)
+     attention_lnx = nn.with_logical_constraint(attention_lnx, self.activation_axis_names)
    intermediate_inputs = inputs + attention_lnx

    # Fully Connected
-     hidden_states = rms_norm(
-         num_features=intermediate_inputs.shape[-1],
-         dtype=cfg.dtype,
-         weight_dtype=cfg.weight_dtype,
-         name="post_self_attention_layer_norm",
-         kernel_axes=("norm",),
-         epsilon=cfg.normalization_layer_epsilon,
-     )(intermediate_inputs)
-     hidden_states = nn.with_logical_constraint(hidden_states, activation_axis_names)
+     hidden_states = self.post_self_attention_layer_norm(intermediate_inputs)
+     hidden_states = nn.with_logical_constraint(hidden_states, self.activation_axis_names)

    # MLP block.
-     mlp_lnx = mlp_block(
-         in_features=hidden_states.shape[-1],
-         intermediate_dim=cfg.mlp_dim,
-         activations=cfg.mlp_activations,
-         intermediate_dropout_rate=cfg.dropout_rate,
-         dtype=cfg.dtype,
-         weight_dtype=cfg.weight_dtype,
-         name="mlp",
-         config=cfg,
-         quant=self.quant,
-         model_mode=model_mode,
-     )(hidden_states, deterministic=deterministic)
-     mlp_lnx = nn.with_logical_constraint(mlp_lnx, activation_axis_names)
+     mlp_lnx = self.mlp(hidden_states, deterministic=deterministic)
+     mlp_lnx = nn.with_logical_constraint(mlp_lnx, self.activation_axis_names)

    layer_output = mlp_lnx + intermediate_inputs
-
-     layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic)
-
-     layer_output = nn.with_logical_constraint(layer_output, activation_axis_names)
+     layer_output = self.dropout(layer_output, deterministic=deterministic)
+     layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)

    if cfg.record_internal_nn_metrics:
      self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
@@ -169,3 +194,9 @@ def __call__(
      return layer_output, None
    else:
      return layer_output
+
+
+ LlamaDecoderLayerToLinen = nnx_wrappers.to_linen_class(
+     LlamaDecoderLayer,
+     base_metadata_fn=initializers.variable_to_logically_partitioned,
+ )
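# A minimal usage sketch (illustrative, not from the PR): constructing the NNX layer
# directly, assuming a MaxText `config` and a `jax.sharding.Mesh` named `mesh` are
# already available.
#
#   layer = LlamaDecoderLayer(
#       config=config,
#       model_mode=MODEL_MODE_PREFILL,
#       mesh=mesh,
#       rngs=nnx.Rngs(0),
#   )
#
# `LlamaDecoderLayerToLinen` is the Linen-compatible wrapper of the same layer, so code
# that still builds Linen module trees can presumably instantiate it with the same
# keyword arguments minus `rngs`.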