@CYHSM CYHSM commented Oct 27, 2025

What does this PR do?

Adds QK-Norm to the self-attention block. Without QK-Norm, the attention logits are computed as (Q @ K^T) / sqrt(d_h), which is equivalent to (||q_i|| * ||k_j|| * cos(θ_ij)) / sqrt(d_h) via the geometric form of the dot product. This means the model can increase the spread between logits either by scaling the q or k vectors (magnitude) or by adjusting the angle between them (direction). QK-Norm constrains the magnitude updates and steers the model towards directional updates, which improves training stability (see this paper for more details).
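For reference, a minimal sketch of the idea, assuming a generic PyTorch attention module (the class and argument names are illustrative, not the repo's actual ones; only use_qk_norm and the switch to nn.RMSNorm come from this PR):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    """Illustrative self-attention block with optional QK-Norm."""

    def __init__(self, d_model: int, n_heads: int, use_qk_norm: bool = False):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj = nn.Linear(d_model, d_model, bias=False)
        # QK-Norm: normalize queries and keys per head before the dot product,
        # so logit magnitude is bounded and updates are pushed towards direction.
        self.q_norm = nn.RMSNorm(self.d_head) if use_qk_norm else nn.Identity()
        self.k_norm = nn.RMSNorm(self.d_head) if use_qk_norm else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, t, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # (b, t, d_model) -> (b, n_heads, t, d_head)
        q = q.view(b, t, self.n_heads, self.d_head).transpose(1, 2)
        k = k.view(b, t, self.n_heads, self.d_head).transpose(1, 2)
        v = v.view(b, t, self.n_heads, self.d_head).transpose(1, 2)
        q, k = self.q_norm(q), self.k_norm(k)  # no-op when use_qk_norm=False
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.proj(out.transpose(1, 2).reshape(b, t, -1))
```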

Here are the results for the runs with and without QK-Norm (r2 denotes a second run; s = slow and f = fast RMSNorm variants; the g_* columns are gradient statistics). The first entry runs at around 25 samples/s with torch > 2.9.0.

| run | lr | qk-norm | norm | loss | g_mean | g_max | g_std | samples/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| qk_rms | 1.5e-4 | X | RMS(s) | 3.219 | 0.726 | 15.75 | 0.336 | 15.12 |
| qk_ln_r2 | 1.5e-4 | X | LN | 3.229 | 0.777 | 79.09 | 1.085 | 19.11 |
| base_ln_r2 | 1.5e-4 |  | LN | 3.266 | 0.820 | 64.74 | 1.067 | 22.90 |
| qk_ln | 1.5e-4 | X | LN | 3.328 | 0.883 | 95.37 | 1.348 | 20.56 |
| base_ln | 1.5e-4 |  | LN | 3.349 | 0.979 | 127.36 | 1.745 | 23.51 |
| qk_rms_lre3 | 1.5e-3 | X | RMS(f) | 3.381 | 0.273 | 40.39 | 0.621 | 24.80 |
| qk_rms_lre2 | 1.5e-2 | X | RMS(f) | 3.788 | 0.375 | 113.46 | 1.241 | 24.86 |
| base_lre3 | 1.5e-3 |  | RMS(f) | 5.906 | 42.87 | 15521.78 | 190.23 | 27.80 |
| base_lre2 | 1.5e-2 |  | RMS(f) | 6.723 | 3.612 | 281.83 | 7.516 | 28.15 |

Compiled:

| run | lr | qk-norm | norm | loss | g_mean | g_max | g_std | samples/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| base_ln | 1.5e-4 |  | LN | - | - | - | - | 29.1 |
| base_rms | 1.5e-4 |  | RMS(f) | - | - | - | - | 30.8 |
| qk_ln | 1.5e-4 | X | LN | - | - | - | - | 26.3 |
| qk_rms | 1.5e-4 | X | RMS(f) | - | - | - | - | 28.2 |

And the loss curves for the extreme LR values:
[figure: loss curves for the 1.5e-3 and 1.5e-2 runs]

General Changes

  • Add QK norm to the config parameters
  • Add the QK norm calculation to the attention blocks
  • Remove the manual RMSNorm implementation and replace it with PyTorch's nn.RMSNorm
  • Add a test that checks whether the output with and without QK norm differs (see the sketch after this list)
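A sketch of what such a test could look like, reusing the hypothetical SelfAttention module from above (the actual test in the repo may be structured differently):

```python
import torch

def test_qk_norm_changes_output():
    torch.manual_seed(0)
    x = torch.randn(2, 8, 64)
    base = SelfAttention(d_model=64, n_heads=4, use_qk_norm=False)
    qk = SelfAttention(d_model=64, n_heads=4, use_qk_norm=True)
    # Copy the shared projection weights so the normalization is the only difference.
    qk.load_state_dict(base.state_dict(), strict=False)
    with torch.no_grad():
        assert not torch.allclose(base(x), qk(x))
```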

Breaking Changes

  • Configs need to be updated to include the new use_qk_norm option, although it defaults to false, so existing configs keep their current behavior
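For illustration, the new option slots into the attention config roughly like this (the surrounding field names are hypothetical; only use_qk_norm comes from this PR):

```python
from dataclasses import dataclass

@dataclass
class AttentionConfig:
    # Hypothetical config shape; the real keys may differ apart from use_qk_norm.
    n_heads: int = 8
    d_model: int = 512
    use_qk_norm: bool = False  # defaults to False, so old configs keep their behavior
```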

Checklist before submitting final PR

  • My PR is minimal and addresses one issue in isolation
  • I have merged the latest version of the target branch into this feature branch
  • I have reviewed my own code w.r.t. correct implementation, missing type hints, proper documentation, etc.
  • I have run a sample config for model training
  • I have checked that all tests run through (python tests/tests.py) - some still fail, possibly due to the torch nightly build (checking now)
  • I have updated the internal changelog (CHANGELOG_DEV.md)
