@@ -25,6 +25,7 @@ class ModelConfig:
     rmsnorm_epsilon: float = 1e-6
     use_residual_scaling: bool = True
     tie_embeddings: bool = True  # Whether to tie input and output embed
+    qknorm_epsilon: float = 1e-6
 
     dtype: jnp.dtype = jnp.float32
     attention_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02)
@@ -116,6 +117,7 @@ def setup(self):
         cfg = self.cfg
         assert cfg.model_dim % cfg.num_heads == 0, f'D {cfg.model_dim} not divisible by H {cfg.num_heads}'
         self.Dh = cfg.model_dim // cfg.num_heads
+        self.eps = cfg.qknorm_epsilon
 
         # Initialize rotary embeddings
         self.freqs_cis = init_rope(cfg.model_dim, cfg.seq_len, cfg.num_heads)
@@ -129,10 +131,13 @@ def setup(self):
             use_bias=False,
             dtype=cfg.dtype,
         )
-
         self.multilinear_query = self.multilinear(name='query')
         self.multilinear_key = self.multilinear(name='key')
         self.multilinear_value = self.multilinear(name='value')
+        # See Henry et al. (2020) "Query-Key Normalization for Transformers"
+        seq_len = cfg.seq_len
+        attn_scale0 = jnp.log2(seq_len ** 2 - seq_len)
+        self.attn_scale = self.param('attn_scale', nn.initializers.constant(attn_scale0), ())
         self.output_projection = nn.DenseGeneral(
             features=cfg.model_dim,
             name='attn_out_proj',
@@ -153,8 +158,9 @@ def __call__(self, x_BxLxD: jax.Array):
         # Apply rotary embeddings to Q and K
         q_BxLxHxDh, k_BxLxHxDh = apply_rope(q_BxLxHxDh, k_BxLxHxDh, self.freqs_cis)
 
-        # Scale queries
-        q_BxLxHxDh /= self.Dh ** 0.5
+        # Apply QK normalization
+        q_BxLxHxDh /= jnp.linalg.norm(q_BxLxHxDh, axis=-1, keepdims=True) + self.eps
+        k_BxLxHxDh /= jnp.linalg.norm(k_BxLxHxDh, axis=-1, keepdims=True) + self.eps
 
         # Compute attention scores
         att_BxHxLxL = jnp.einsum('...qhd,...khd->...hqk', q_BxLxHxDh, k_BxLxHxDh)
@@ -166,6 +172,7 @@ def __call__(self, x_BxLxD: jax.Array):
         # Apply mask and softmax
         _NEG_INF = jnp.finfo(cfg.dtype).min
         att_BxHxLxL = jnp.where(mask_1x1xLxL, att_BxHxLxL, _NEG_INF)
+        att_BxHxLxL = self.attn_scale * att_BxHxLxL  # Learned scaling factor for QK norm
         att_BxHxLxL = jax.nn.softmax(att_BxHxLxL, axis=-1)
         att_BxHxLxL = att_BxHxLxL.astype(cfg.dtype)