
Commit 1563b9e

Merge pull request #390 from Modalities/tp_swiglu_hidden_dim_fix
Fix of SwiGLU hidden not being multiple of world size
2 parents 491dac5 + 9bfa264 commit 1563b9e

3 files changed: +43 additions, -9 deletions


src/modalities/models/gpt2/gpt2_model.py

Lines changed: 21 additions & 2 deletions
@@ -319,7 +319,10 @@ class GPT2LLMConfig(BaseModel):
         ffn_norm_config (LayerNormWrapperConfig): Config for normalization of the feed-forward network.
         lm_head_norm_config (LayerNormWrapperConfig): Config for normalization of the language model head.
         use_weight_tying (bool): Whether to use weight tying.
-
+        seed: int = None: The random seed for reproducibility.
+        enforce_swiglu_hidden_dim_multiple_of (Optional[int]): If specified, enforces the hidden dimension
+            in the SwiGLU layer to be a multiple of this value. Note that this is only relevant if the
+            activation_type is SwiGLU. Defaults to None.
     """

     sample_key: str
@@ -344,6 +347,8 @@ class GPT2LLMConfig(BaseModel):
     ffn_norm_config: LayerNormWrapperConfig
     lm_head_norm_config: LayerNormWrapperConfig
     use_weight_tying: bool
+    seed: Optional[int] = None
+    enforce_swiglu_hidden_dim_multiple_of: Optional[int] = None

     @model_validator(mode="after")
     def check_divisibility(self) -> "GPT2LLMConfig":
@@ -695,6 +700,7 @@ def __init__(
         ffn_hidden: int,
         attention_norm: nn.Module,
         ffn_norm: nn.Module,
+        enforce_swiglu_hidden_dim_multiple_of: Optional[int] = None,
     ):
         """
         Initializes the GPT2Block.
@@ -711,6 +717,9 @@ def __init__(
             ffn_hidden (int): The size of the hidden layer in the feed-forward network.
             attention_norm (nn.Module): The normalization layer for attention.
             ffn_norm (nn.Module): The normalization layer for feed-forward network.
+            enforce_swiglu_hidden_dim_multiple_of (Optional[int]): If specified, enforces the
+                hidden dimension in the SwiGLU layer to be a multiple of this value. Note that this
+                is only relevant if the activation_type is SwiGLU. Defaults to None.
         """
         super().__init__()
         self.attention_norm = attention_norm
@@ -728,7 +737,12 @@ def __init__(
         if activation_type == ActivationType.GELU:
             self.mlp = TransformerMLP(n_embd=n_embd, ffn_hidden=ffn_hidden, bias=bias, dropout=dropout)
         elif activation_type == ActivationType.SWIGLU:
-            self.mlp = SwiGLU(n_embd=n_embd, ffn_hidden=ffn_hidden, bias=bias)
+            self.mlp = SwiGLU(
+                n_embd=n_embd,
+                ffn_hidden=ffn_hidden,
+                bias=bias,
+                enforce_swiglu_hidden_dim_multiple_of=enforce_swiglu_hidden_dim_multiple_of,
+            )
         else:
             raise NotImplementedError("unimplemented activation")

@@ -781,6 +795,7 @@ def __init__(
         lm_head_norm_config: LayerNormWrapperConfig,
         use_weight_tying: bool,
         seed: int = None,
+        enforce_swiglu_hidden_dim_multiple_of: Optional[int] = None,
     ):
         """
         Initializes the GPT2LLM object.
@@ -806,6 +821,9 @@ def __init__(
             lm_head_norm_config (LayerNormWrapperConfig): Config for the language model head normalization module.
             seed (int, optional): The random seed. Defaults to None.
             use_weight_tying (bool): Whether to use weight tying.
+            enforce_swiglu_hidden_dim_multiple_of (Optional[int]): If specified, enforces
+                the hidden dimension in the SwiGLU layer to be a multiple of this value.
+                Note that this is only relevant if the activation_type is SwiGLU. Defaults to None.
         """
         weight_decay_groups = {
             "linear": [".attn", ".mlp", ".lm_head.weight"],
@@ -861,6 +879,7 @@ def __init__(
                             # a meta device!
                             attention_norm=attention_norm_config.norm_type.value(**dict(attention_norm_config.config)),
                             ffn_norm=ffn_norm_config.norm_type.value(**dict(ffn_norm_config.config)),
+                            enforce_swiglu_hidden_dim_multiple_of=enforce_swiglu_hidden_dim_multiple_of,
                         )
                         for _ in range(n_layer)
                     ]
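
Taken together, the changes in this file thread enforce_swiglu_hidden_dim_multiple_of from GPT2LLMConfig through GPT2Block down to the SwiGLU module. A minimal sketch of the call the block now makes; the dimensions and the value 768 are illustrative assumptions, not taken from the diff:

from modalities.models.model import SwiGLU

# Illustrative GPT-2-small-like sizes (assumptions, not from the diff).
n_embd, ffn_hidden = 768, 3072

# Default behaviour (None): the hidden dim is rounded up to the next multiple of 256,
# exactly as before this change.
mlp_default = SwiGLU(n_embd=n_embd, ffn_hidden=ffn_hidden, bias=False)

# Under FSDP + TP, pass a value that is a multiple of the world size so the hidden
# dim shards evenly, e.g. 768 = lcm(256, 6) for a hypothetical 6-rank setup.
mlp_sharded = SwiGLU(
    n_embd=n_embd,
    ffn_hidden=ffn_hidden,
    bias=False,
    enforce_swiglu_hidden_dim_multiple_of=768,
)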

src/modalities/models/model.py

Lines changed: 19 additions & 6 deletions
@@ -75,7 +75,9 @@ def get_parameters(self) -> dict[str, torch.Tensor]:
 class SwiGLU(nn.Module):
     """SwiGLU class to define the SwiGLU activation function."""

-    def __init__(self, n_embd: int, ffn_hidden: int, bias: bool):
+    def __init__(
+        self, n_embd: int, ffn_hidden: int, bias: bool, enforce_swiglu_hidden_dim_multiple_of: Optional[int] = None
+    ):
         """
         Initializes the SwiGLU object.

@@ -84,11 +86,17 @@ def __init__(self, n_embd: int, ffn_hidden: int, bias: bool):
             ffn_hidden (int): The number of hidden dimensions in the feed-forward network.
                 Best practice: 4 * n_embd (https://arxiv.org/pdf/1706.03762)
             bias (bool): Whether to include bias terms in the linear layers.
+            enforce_swiglu_hidden_dim_multiple_of (int): The multiple of which the hidden dimension should be enforced.
+                This is required for FSDP + TP as the combincation does not support uneven sharding (yet).
+                Defaults to 256 if not provided.
         """

         super().__init__()
-
-        hidden_dim = SwiGLU._get_hidden_dim(ffn_hidden=ffn_hidden)
+        if enforce_swiglu_hidden_dim_multiple_of is None:
+            enforce_swiglu_hidden_dim_multiple_of = 256
+        hidden_dim = SwiGLU._get_hidden_dim(
+            ffn_hidden=ffn_hidden, enforce_swiglu_hidden_dim_multiple_of=enforce_swiglu_hidden_dim_multiple_of
+        )

         self.W = nn.Linear(
             in_features=n_embd,
@@ -108,16 +116,21 @@ def __init__(self, n_embd: int, ffn_hidden: int, bias: bool):
         )

     @staticmethod
-    def _get_hidden_dim(ffn_hidden: int) -> int:
+    def _get_hidden_dim(ffn_hidden: int, enforce_swiglu_hidden_dim_multiple_of: int) -> int:
         # Calculate the hidden dimension for the SwiGLU module based on the provided embedding dimension.

         # Best practice: 4 * n_embd (https://arxiv.org/pdf/1706.03762)
         # To ensure that the number of parameters in the SwiGLU module with its additional
         # linear layer are equivalent to the TransformerMLP, we need to adapt the SwiGLU hidden dimension as follows:
         # 2 * (n_embd * hidden_dim) == 3 * (n_embd * 2/3 * hidden_dim)
         # Besides, we ensure that hidden_dim is the smallest multiple of
-        # 256 that is greater than or equal the provided hidden_dim
-        return 256 * ((int(2 * ffn_hidden / 3) + 256 - 1) // 256)
+        # `enforce_swiglu_hidden_dim_multiple_of` that is greater than or equal the provided hidden_dim.
+        # In case of TP we must set this to be at least of world size as FSDP + TP does not uneven sharding.
+        # FSDP itself without TP support it already however.
+        return enforce_swiglu_hidden_dim_multiple_of * (
+            (int(2 * ffn_hidden / 3) + enforce_swiglu_hidden_dim_multiple_of - 1)
+            // enforce_swiglu_hidden_dim_multiple_of
+        )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """

src/modalities/models/model_factory.py

Lines changed: 3 additions & 1 deletion
@@ -567,7 +567,8 @@ def get_gpt2_model(
         lm_head_norm_config: LayerNormWrapperConfig,
         use_weight_tying: bool,
         use_meta_device: Optional[bool] = False,
-        seed: int = None,
+        seed: Optional[int] = None,
+        enforce_swiglu_hidden_dim_multiple_of: Optional[int] = None,
     ) -> GPT2LLM:
         config = dict(
             sample_key=sample_key,
@@ -590,6 +591,7 @@ def get_gpt2_model(
             lm_head_norm_config=lm_head_norm_config,
             seed=seed,
             use_weight_tying=use_weight_tying,
+            enforce_swiglu_hidden_dim_multiple_of=enforce_swiglu_hidden_dim_multiple_of,
         )
         if use_meta_device and use_weight_tying:
             raise ValueError(
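
The factory change is plumbing: seed gets the Optional[int] annotation it already behaved as, and the new keyword is accepted and forwarded into the config dict. How a caller picks the value is left open; one possible choice, sketched here under the assumption that torch.distributed is initialized, is the least common multiple of the old 256 alignment and the world size:

import math

import torch.distributed as dist

# Keep the original 256 alignment and additionally make the SwiGLU hidden dim
# divisible by the number of ranks, so FSDP + TP can shard it without remainder.
world_size = dist.get_world_size() if dist.is_initialized() else 1
enforce_swiglu_hidden_dim_multiple_of = math.lcm(256, world_size)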
