Skip to content

Commit 8d4edb4

Browse files
committed
feat(huggingface): More robust conversion allowing different config formats for norms.
1 parent 9866790 commit 8d4edb4

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

src/modalities/conversion/gpt2/conversion_model.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ def convert_model_config(modalities_config: dict) -> GPT2Config:
4242
config = modalities_config["model_raw" if "model_raw" in modalities_config else "model"]["config"]
4343
_check_conversion_criteria(config)
4444

45+
ffn_norm_key = "ffn_norm" if "ffn_norm" in config else "ffn_norm_config"
46+
4547
return GPT2Config(
4648
vocab_size=config["vocab_size"],
4749
hidden_size=config["n_embd"],
@@ -53,9 +55,9 @@ def convert_model_config(modalities_config: dict) -> GPT2Config:
5355
attention_bias=config["bias"],
5456
mlp_bias=config["bias"],
5557
hidden_act="silu",
56-
layer_norm_eps=_get_layer_norm_value(config["ffn_norm_config"]["config"], "eps"),
57-
layer_norm_elementwise_affine=_get_layer_norm_value(config["ffn_norm_config"]["config"], "elementwise_affine"),
58-
layer_norm_bias=_get_layer_norm_value(config["ffn_norm_config"]["config"], "bias"),
58+
layer_norm_eps=_get_layer_norm_value(config[ffn_norm_key]["config"], "eps"),
59+
layer_norm_elementwise_affine=_get_layer_norm_value(config[ffn_norm_key]["config"], "elementwise_affine"),
60+
layer_norm_bias=_get_layer_norm_value(config[ffn_norm_key]["config"], "bias"),
5961
max_position_embeddings=config["sequence_length"],
6062
rope_theta=config["attention_config"]["qkv_transforms"][0]["config"]["base_freq"],
6163
_attn_implementation=_map_attention_type(config),
@@ -97,9 +99,14 @@ def _check_conversion_criteria(model_config: dict) -> None:
9799
assert model_config["activation_type"] == "swiglu"
98100
assert model_config["attention_implementation"] in ["pytorch_flash", "manual"]
99101

100-
norms = ["attention_norm_config", "ffn_norm_config", "lm_head_norm_config"]
102+
if "attention_norm" in model_config:
103+
norms = ["attention_norm", "ffn_norm", "lm_head_norm"]
104+
norm_type_key = "variant_key"
105+
else:
106+
norms = ["attention_norm_config", "ffn_norm_config", "lm_head_norm_config"]
107+
norm_type_key = "norm_type"
101108
for norm in norms:
102-
assert model_config[norm]["norm_type"] == "layer_norm"
109+
assert model_config[norm][norm_type_key] == "layer_norm"
103110

104111
assert (
105112
len(set(_get_layer_norm_value(model_config[norm]["config"], "bias") for norm in norms)) == 1

0 commit comments

Comments (0)