@@ -319,7 +319,10 @@ class GPT2LLMConfig(BaseModel):
         ffn_norm_config (LayerNormWrapperConfig): Config for normalization of the feed-forward network.
         lm_head_norm_config (LayerNormWrapperConfig): Config for normalization of the language model head.
         use_weight_tying (bool): Whether to use weight tying.
-
+        seed (int, optional): The random seed for reproducibility. Defaults to None.
+        enforce_swiglu_hidden_dim_multiple_of (Optional[int]): If specified, enforces the hidden dimension
+            in the SwiGLU layer to be a multiple of this value. Note that this is only relevant if the
+            activation_type is SwiGLU. Defaults to None.
     """
 
     sample_key: str
@@ -344,6 +347,8 @@ class GPT2LLMConfig(BaseModel):
     ffn_norm_config: LayerNormWrapperConfig
     lm_head_norm_config: LayerNormWrapperConfig
     use_weight_tying: bool
+    seed: Optional[int] = None
+    enforce_swiglu_hidden_dim_multiple_of: Optional[int] = None
 
     @model_validator(mode="after")
     def check_divisibility(self) -> "GPT2LLMConfig":
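Both new config fields are optional and default to None, so existing configs keep validating unchanged. As a hedged illustration only (the mini-model and its padded_hidden_dim helper below are hypothetical, not part of this PR), the usual reading of enforce_swiglu_hidden_dim_multiple_of is "round the SwiGLU hidden size up to the next multiple of the given value":

from typing import Optional

from pydantic import BaseModel


class SwiGLUSizingSketch(BaseModel):
    # Hypothetical mini-config mirroring only the two new GPT2LLMConfig fields.
    seed: Optional[int] = None
    enforce_swiglu_hidden_dim_multiple_of: Optional[int] = None

    def padded_hidden_dim(self, ffn_hidden: int) -> int:
        # Round ffn_hidden up to the enforced multiple; no-op when the field is unset.
        m = self.enforce_swiglu_hidden_dim_multiple_of
        if m is None:
            return ffn_hidden
        return m * ((ffn_hidden + m - 1) // m)


cfg = SwiGLUSizingSketch(enforce_swiglu_hidden_dim_multiple_of=256)
assert cfg.padded_hidden_dim(2730) == 2816  # 2730 padded up to the next multiple of 256
assert SwiGLUSizingSketch().padded_hidden_dim(2730) == 2730  # unchanged when unset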
@@ -695,6 +700,7 @@ def __init__(
         ffn_hidden: int,
         attention_norm: nn.Module,
         ffn_norm: nn.Module,
+        enforce_swiglu_hidden_dim_multiple_of: Optional[int] = None,
     ):
         """
         Initializes the GPT2Block.
@@ -711,6 +717,9 @@ def __init__(
             ffn_hidden (int): The size of the hidden layer in the feed-forward network.
             attention_norm (nn.Module): The normalization layer for attention.
             ffn_norm (nn.Module): The normalization layer for feed-forward network.
+            enforce_swiglu_hidden_dim_multiple_of (Optional[int]): If specified, enforces the
+                hidden dimension in the SwiGLU layer to be a multiple of this value. Note that this
+                is only relevant if the activation_type is SwiGLU. Defaults to None.
         """
         super().__init__()
         self.attention_norm = attention_norm
@@ -728,7 +737,12 @@ def __init__(
         if activation_type == ActivationType.GELU:
             self.mlp = TransformerMLP(n_embd=n_embd, ffn_hidden=ffn_hidden, bias=bias, dropout=dropout)
         elif activation_type == ActivationType.SWIGLU:
-            self.mlp = SwiGLU(n_embd=n_embd, ffn_hidden=ffn_hidden, bias=bias)
+            self.mlp = SwiGLU(
+                n_embd=n_embd,
+                ffn_hidden=ffn_hidden,
+                bias=bias,
+                enforce_swiglu_hidden_dim_multiple_of=enforce_swiglu_hidden_dim_multiple_of,
+            )
         else:
             raise NotImplementedError("unimplemented activation")
 
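For context, here is a minimal sketch of a SwiGLU feed-forward module that accepts such a keyword, under the assumption that the flag simply pads ffn_hidden up to the requested multiple before the projection layers are created. The class and weight names are illustrative; this is not the SwiGLU implementation the diff actually calls into.

from typing import Optional

import torch
from torch import nn
from torch.nn import functional as F


class SwiGLUSketch(nn.Module):
    # Hypothetical SwiGLU block: W_down(silu(W_gate(x)) * W_up(x)).
    def __init__(
        self,
        n_embd: int,
        ffn_hidden: int,
        bias: bool,
        enforce_swiglu_hidden_dim_multiple_of: Optional[int] = None,
    ):
        super().__init__()
        if enforce_swiglu_hidden_dim_multiple_of is not None:
            m = enforce_swiglu_hidden_dim_multiple_of
            ffn_hidden = m * ((ffn_hidden + m - 1) // m)  # pad up to the next multiple
        self.W_gate = nn.Linear(n_embd, ffn_hidden, bias=bias)
        self.W_up = nn.Linear(n_embd, ffn_hidden, bias=bias)
        self.W_down = nn.Linear(ffn_hidden, n_embd, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.W_down(F.silu(self.W_gate(x)) * self.W_up(x))


# With ffn_hidden=2730 and a multiple of 256, the projections are built with width 2816.
mlp = SwiGLUSketch(n_embd=1024, ffn_hidden=2730, bias=False, enforce_swiglu_hidden_dim_multiple_of=256)
assert mlp.W_gate.out_features == 2816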
@@ -781,6 +795,7 @@ def __init__(
         lm_head_norm_config: LayerNormWrapperConfig,
         use_weight_tying: bool,
         seed: int = None,
+        enforce_swiglu_hidden_dim_multiple_of: Optional[int] = None,
     ):
         """
         Initializes the GPT2LLM object.
@@ -806,6 +821,9 @@ def __init__(
             lm_head_norm_config (LayerNormWrapperConfig): Config for the language model head normalization module.
             seed (int, optional): The random seed. Defaults to None.
             use_weight_tying (bool): Whether to use weight tying.
+            enforce_swiglu_hidden_dim_multiple_of (Optional[int]): If specified, enforces
+                the hidden dimension in the SwiGLU layer to be a multiple of this value.
+                Note that this is only relevant if the activation_type is SwiGLU. Defaults to None.
         """
         weight_decay_groups = {
             "linear": [".attn", ".mlp", ".lm_head.weight"],
@@ -861,6 +879,7 @@ def __init__(
                             # a meta device!
                             attention_norm=attention_norm_config.norm_type.value(**dict(attention_norm_config.config)),
                             ffn_norm=ffn_norm_config.norm_type.value(**dict(ffn_norm_config.config)),
+                            enforce_swiglu_hidden_dim_multiple_of=enforce_swiglu_hidden_dim_multiple_of,
                         )
                         for _ in range(n_layer)
                     ]