Merged
Changes from all commits
Commits
25 commits
99cca4c
initial changes adding qk norm
CYHSM Oct 23, 2025
2f3598d
feat: added qk norm
CYHSM Oct 23, 2025
d918c87
added config files
CYHSM Oct 23, 2025
f182727
feat: added qk norm to attention block
CYHSM Oct 27, 2025
6a53a31
Merge: main into current branch
Oct 27, 2025
693c57d
fix: Fix spelling mistake
Oct 27, 2025
f48aa0c
Fix: Adapt configs to latest changes
Oct 27, 2025
941be96
fix: Reverse removal of RMSNorm as its imported in the test suite
CYHSM Oct 27, 2025
329beb0
Fix: Adapt configs to latest changes
Oct 27, 2025
f1152e5
fix: Compute peak memory correctly on cpu device
Oct 27, 2025
5ab3ecf
test: Require A100 GPUs for mfu test
Oct 27, 2025
4571ffe
test: Adapt tests to latest changes
Oct 27, 2025
e8e8126
test(parallelism): Adapted fsdp2+tp+tt test to recent changes.
BlueCrescent Oct 27, 2025
ce00ba3
test: add test for qk norm
CYHSM Oct 27, 2025
e692e39
Merge branch 'Modalities:main' into qk_norm
CYHSM Oct 28, 2025
61b4a2c
fixing rebase
CYHSM Oct 28, 2025
7ce6d50
fix: fix tests by making RMSNorm backward compatible
CYHSM Oct 28, 2025
28186f7
Update tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_b…
rrutmann Oct 28, 2025
f375c2a
Update src/modalities/models/parallelism/pipeline_parallelism.py
rrutmann Oct 28, 2025
0c4eee8
fix: merge pp branch into qk branch
CYHSM Oct 29, 2025
f3bfb58
added compiled model configs for throughput tests
CYHSM Oct 29, 2025
6c625cd
Merge remote-tracking branch 'upstream/main' into qk_norm
CYHSM Nov 11, 2025
f25e3d3
additional fixes after review
CYHSM Nov 11, 2025
36769ee
removed config files from PR
CYHSM Nov 11, 2025
5fed18e
fix: apply pre-commit
CYHSM Nov 11, 2025
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -5,7 +5,8 @@ requires-python = ">=3.10,<3.13"
description = "Modalities, a PyTorch-native framework for distributed and reproducible foundation model training."
readme = "README.md"
dependencies = [
"numpy<2.0",
"numpy",
"torch",
"packaging",
"tqdm",
"pyyaml",
19 changes: 13 additions & 6 deletions src/modalities/models/components/layer_norms.py
@@ -11,16 +11,13 @@ class RMSLayerNorm(nn.Module):
def __init__(self, ndim: int, bias: bool = True, epsilon: float = 1e-5):
"""
Initializes a LayerNorm module.

Args:
ndim (int): The number of dimensions of the input tensor.
bias (bool, optional): If True, adds a learnable bias to the normalized tensor. Defaults to True.
epsilon (float, optional): A small value added to the denominator for numerical stability. Defaults to 1e-5.

Note:
Original paper: https://arxiv.org/pdf/1910.07467.pdf
Source code adopted from https://github.com/facebookresearch/llama/blob/a0a4da8b497c566403941ceec47c2512ecf9dd20/llama/model.py#L34C1-L77C36

Returns:
None
"""
@@ -41,13 +38,10 @@ def _norm(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the layer normalization module.

Args:
x (torch.Tensor): Input tensor.

Returns:
torch.Tensor: Output tensor after applying layer normalization.

"""
output = self._norm(x.float()).type_as(x)
if self.bias is None:
@@ -97,3 +91,16 @@ class RMSLayerNormConfig(BaseModel):
ndim: Annotated[int, Field(strict=True, ge=1)]
epsilon: Annotated[float, Field(gt=0, default=1e-6)]
bias: Annotated[bool, Field(strict=True, default=True)]


class PytorchRMSLayerNormConfig(BaseModel):
"""
Configuration class for PyTorch's nn.RMSNorm.

Args:
normalized_shape (int): The size of the last input dimension over which RMS normalization is applied.
eps (float, optional): Small value added to the denominator for numerical stability. Defaults to 1e-5.
"""

normalized_shape: Annotated[int, Field(strict=True, ge=1)]
eps: Annotated[float, Field(strict=True, gt=0, default=1e-5)]
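
For context, a minimal sketch (not part of the diff) of what this config carries: `PytorchRMSLayerNormConfig` holds the keyword arguments for `torch.nn.RMSNorm` (available in PyTorch ≥ 2.4), which the attention block below instantiates via `norm_type.value(**dict(config))`. The `head_dim` value here is a hypothetical per-head dimension.

```python
# Sketch only, assuming PyTorch >= 2.4 (which provides torch.nn.RMSNorm).
import torch
from torch import nn

head_dim = 64  # hypothetical per-head dimension
config_kwargs = {"normalized_shape": head_dim, "eps": 1e-5}  # PytorchRMSLayerNormConfig fields

qk_norm = nn.RMSNorm(**config_kwargs)

q = torch.randn(2, 8, 16, head_dim)  # (batch, n_heads, seq_len, head_dim)
q_normed = qk_norm(q)                # RMS-normalizes over the last dimension
assert q_normed.shape == q.shape
```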
32 changes: 30 additions & 2 deletions src/modalities/models/gpt2/gpt2_model.py
@@ -10,7 +10,12 @@

from modalities.config.lookup_enum import LookupEnum
from modalities.config.utils import convert_base_model_config_to_dict
from modalities.models.components.layer_norms import LayerNormConfig, RMSLayerNorm, RMSLayerNormConfig
from modalities.models.components.layer_norms import (
LayerNormConfig,
PytorchRMSLayerNormConfig,
RMSLayerNorm,
RMSLayerNormConfig,
)
from modalities.models.model import ActivationType, NNModel, SwiGLU
from modalities.util import parse_enum_by_name

@@ -33,15 +38,17 @@ class LayerNorms(LookupEnum):
Attributes:
RMSNorm: RMSLayerNorm class.
LayerNorm: nn.LayerNorm class.
PyTorchRMSNorm: nn.RMSNorm class.
"""

rms_norm = RMSLayerNorm
layer_norm = nn.LayerNorm
pytorch_rms_norm = nn.RMSNorm


class LayerNormWrapperConfig(BaseModel):
norm_type: LayerNorms
config: LayerNormConfig | RMSLayerNormConfig
config: PytorchRMSLayerNormConfig | RMSLayerNormConfig | LayerNormConfig


class PositionTypes(str, Enum):
@@ -292,6 +299,7 @@ def parse_sharding_strategy_by_name(cls, name):
config: RotaryTransformConfig | IdentityTransformConfig

qkv_transforms: list[QueryKeyValueTransformConfig]
qk_norm_config: Optional[LayerNormWrapperConfig] = None


class GPT2LLMConfig(BaseModel):
@@ -461,6 +469,23 @@ def __init__(
for transform_config in attention_config.qkv_transforms
)

# QK norm - helpful for models >1B parameters to stabilize training.
# Baseline logits without QK norm: (Q @ K^T) / sqrt(d_h),
# or in the geometric form of the dot product: (||q_i|| * ||k_j|| * cos(θ_ij)) / sqrt(d_h).
# To spread the logits further apart, the model must either scale up q or k
# or adjust the angle between them. QK norm constrains the norms of q and k,
# so the model mostly adjusts the angle, which stabilizes training.
if attention_config.qk_norm_config is not None:
self.q_norm = attention_config.qk_norm_config.norm_type.value(

Member:
Would this be a use case for using other norms for q and k? If not, I would hardcode RMSNorm here.

Contributor (Author):
LayerNorm can also be used for QK norm, so I would leave it as is. It also keeps the experiments consistent where we compared LayerNorm vs. RMSNorm, since all layers were switched to the chosen norm. Only for an experiment that uses LayerNorm everywhere except for the QK values would we need to change this.

**dict(attention_config.qk_norm_config.config)
)
self.k_norm = attention_config.qk_norm_config.norm_type.value(
**dict(attention_config.qk_norm_config.config)
)
else:
self.q_norm = None
self.k_norm = None

def projection(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Applies projections to the input tensor to get queries, keys, and values.
@@ -632,6 +657,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

# q: (B, nh_q, T, hd), k: (B, nh_kv, T, hd), v: (B, nh_kv, T, hd)
q, k, v = CausalSelfAttention.execute_qkv_transforms(q, k, v, self.qkv_transforms, self.n_head_q)
if self.q_norm is not None and self.k_norm is not None:
q = self.q_norm(q)
k = self.k_norm(k)
y = CausalSelfAttention.execute_attention(q, k, v, self.dropout, self.attention_impl) # (B, T, nh_q, hd)
y = y.reshape(B, T, -1) # (B, T, n_embd), re-assemble all head outputs side by side
return self.resid_dropout(self.c_proj(y)) # (B, T, n_embd), output projection
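
To make the geometric argument in the comment above concrete, here is a self-contained sketch of the QK-norm idea (assuming PyTorch ≥ 2.4 for `nn.RMSNorm`; this is not the modalities code path, just the bare mechanism): q and k are normalized per head before the scaled dot product, so the attention logits are driven by the angle between q and k rather than by growing norms.

```python
# Standalone illustration of QK norm, not the modalities implementation.
import math

import torch
from torch import nn

B, n_heads, T, head_dim = 2, 4, 8, 16
q = torch.randn(B, n_heads, T, head_dim)
k = torch.randn(B, n_heads, T, head_dim)

q_norm = nn.RMSNorm(head_dim)
k_norm = nn.RMSNorm(head_dim)

# Baseline logits: (Q @ K^T) / sqrt(d_h); the norms of q and k are unconstrained.
logits_plain = q @ k.transpose(-2, -1) / math.sqrt(head_dim)

# With QK norm, ||q_i|| and ||k_j|| are fixed to roughly sqrt(head_dim),
# so the logit scale is bounded and mainly reflects cos(theta_ij).
logits_qk_norm = q_norm(q) @ k_norm(k).transpose(-2, -1) / math.sqrt(head_dim)

print(logits_plain.abs().max(), logits_qk_norm.abs().max())
```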
52 changes: 51 additions & 1 deletion tests/models/test_causal_self_attention.py
@@ -7,7 +7,13 @@
import pytest
import torch

from modalities.models.gpt2.gpt2_model import AttentionConfig, CausalSelfAttention
from modalities.models.gpt2.gpt2_model import (
AttentionConfig,
CausalSelfAttention,
LayerNorms,
LayerNormWrapperConfig,
PytorchRMSLayerNormConfig,
)

torch.manual_seed(0)

@@ -222,3 +228,47 @@ def test_attention_implementation_approximate_equality(
atol=2.5e-3, # default for bfloat16: 1e-5
rtol=0.016, # default for bfloat16: 0.016
)


@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="This test requires 1 GPU.")
@pytest.mark.parametrize(
"n_head_q, n_head_kv, n_embd, attention_impl",
[
(4, 4, 32, "manual"),
(8, 2, 32, "manual"),
(4, 4, 32, "pytorch_flash"),
(8, 2, 32, "pytorch_flash"),
(4, 4, 32, "dao_flash"),
(8, 2, 32, "dao_flash"),
],
)
def test_qk_norm(n_head_q, n_head_kv, n_embd, attention_impl):
batch_size = 2
block_size = 10
head_dim = n_embd // n_head_q
embedding_shape = (batch_size, block_size - 1, n_embd)
embedded_input_seq = _get_random_input_seq(embedding_shape)

attention_config_no_norm = AttentionConfig(qkv_transforms=[], use_qk_norm=False)
attention_config_with_norm = AttentionConfig(
qkv_transforms=[],
use_qk_norm=True,
qk_norm_config=LayerNormWrapperConfig(
norm_type=LayerNorms.pytorch_rms_norm, config=PytorchRMSLayerNormConfig(normalized_shape=head_dim)
),
)

# Create two separate layers with same initial weights
torch.manual_seed(0)
layer_no_norm = _get_random_attention_layer(n_head_q, n_head_kv, n_embd, attention_impl, attention_config_no_norm)

torch.manual_seed(0)
layer_with_norm = _get_random_attention_layer(
n_head_q, n_head_kv, n_embd, attention_impl, attention_config_with_norm
)

output_no_norm = layer_no_norm(embedded_input_seq)
output_with_norm = layer_with_norm(embedded_input_seq)

assert output_no_norm.shape == output_with_norm.shape == embedding_shape
assert not torch.allclose(output_no_norm, output_with_norm, atol=1e-6)
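
As a usage note, a hedged sketch of how the config objects from this PR fit together when enabling QK norm; it mirrors the test above but omits the remaining `CausalSelfAttention` constructor arguments, which are not shown in this diff. `normalized_shape` should equal the per-head dimension (`n_embd // n_head_q`), since the norm is applied to q and k after the heads are split.

```python
# Sketch based on the config classes shown in this PR; constructor usage beyond
# the config objects is omitted.
from modalities.models.gpt2.gpt2_model import (
    AttentionConfig,
    LayerNorms,
    LayerNormWrapperConfig,
    PytorchRMSLayerNormConfig,
)

n_embd, n_head_q = 32, 4
head_dim = n_embd // n_head_q  # QK norm acts per head, over the last dimension

attention_config = AttentionConfig(
    qkv_transforms=[],
    qk_norm_config=LayerNormWrapperConfig(
        norm_type=LayerNorms.pytorch_rms_norm,
        config=PytorchRMSLayerNormConfig(normalized_shape=head_dim),
    ),
)
```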