@@ -42,6 +42,8 @@ def convert_model_config(modalities_config: dict) -> GPT2Config:
42
42
config = modalities_config ["model_raw" if "model_raw" in modalities_config else "model" ]["config" ]
43
43
_check_conversion_criteria (config )
44
44
45
+ ffn_norm_key = "ffn_norm" if "ffn_norm" in config else "ffn_norm_config"
46
+
45
47
return GPT2Config (
46
48
vocab_size = config ["vocab_size" ],
47
49
hidden_size = config ["n_embd" ],
@@ -53,9 +55,9 @@ def convert_model_config(modalities_config: dict) -> GPT2Config:
53
55
attention_bias = config ["bias" ],
54
56
mlp_bias = config ["bias" ],
55
57
hidden_act = "silu" ,
56
- layer_norm_eps = _get_layer_norm_value (config ["ffn_norm_config" ]["config" ], "eps" ),
57
- layer_norm_elementwise_affine = _get_layer_norm_value (config ["ffn_norm_config" ]["config" ], "elementwise_affine" ),
58
- layer_norm_bias = _get_layer_norm_value (config ["ffn_norm_config" ]["config" ], "bias" ),
58
+ layer_norm_eps = _get_layer_norm_value (config [ffn_norm_key ]["config" ], "eps" ),
59
+ layer_norm_elementwise_affine = _get_layer_norm_value (config [ffn_norm_key ]["config" ], "elementwise_affine" ),
60
+ layer_norm_bias = _get_layer_norm_value (config [ffn_norm_key ]["config" ], "bias" ),
59
61
max_position_embeddings = config ["sequence_length" ],
60
62
rope_theta = config ["attention_config" ]["qkv_transforms" ][0 ]["config" ]["base_freq" ],
61
63
_attn_implementation = _map_attention_type (config ),
@@ -97,9 +99,14 @@ def _check_conversion_criteria(model_config: dict) -> None:
97
99
assert model_config ["activation_type" ] == "swiglu"
98
100
assert model_config ["attention_implementation" ] in ["pytorch_flash" , "manual" ]
99
101
100
- norms = ["attention_norm_config" , "ffn_norm_config" , "lm_head_norm_config" ]
102
+ if "attention_norm" in model_config :
103
+ norms = ["attention_norm" , "ffn_norm" , "lm_head_norm" ]
104
+ norm_type_key = "variant_key"
105
+ else :
106
+ norms = ["attention_norm_config" , "ffn_norm_config" , "lm_head_norm_config" ]
107
+ norm_type_key = "norm_type"
101
108
for norm in norms :
102
- assert model_config [norm ]["norm_type" ] == "layer_norm"
109
+ assert model_config [norm ][norm_type_key ] == "layer_norm"
103
110
104
111
assert (
105
112
len (set (_get_layer_norm_value (model_config [norm ]["config" ], "bias" ) for norm in norms )) == 1
0 commit comments