
Commit 9989fda

Bond099 and abheesht17 authored
Safetensors conversion (#2290)
* Safetensors conversion
* Reformatted
* corrected and formatted into a util file
* test cases wip
* unit tests for safetensors conversion
* rename vocab.spm
* reformatted
* address comments
* minor changes
* backend agnostic
* address comments
* convert_to_safetensor
* Compatible with all backends
* Cosmetic changes
* Cosmetic changes (1)
* Cosmetic changes (2)
* Cosmetic changes (3)
* Address comments

Co-authored-by: Abheesht Sharma <[email protected]>
1 parent 6729eaf commit 9989fda

File tree: 4 files changed, +331 −0 lines changed
keras_hub/src/utils/transformers/export/gemma.py (new file, 89 additions & 0 deletions)
@@ -0,0 +1,89 @@
import keras.ops as ops


def get_gemma_config(backbone):
    hf_config = {
        "vocab_size": backbone.vocabulary_size,
        "num_hidden_layers": backbone.num_layers,
        "num_attention_heads": backbone.num_query_heads,
        "num_key_value_heads": backbone.num_key_value_heads,
        "hidden_size": backbone.hidden_dim,
        "intermediate_size": backbone.intermediate_dim // 2,
        "head_dim": backbone.head_dim,
        "max_position_embeddings": 8192,
    }
    return hf_config


def get_gemma_weights_map(backbone):
    weights_dict = {}

    # Map token embedding
    token_embedding_layer = backbone.get_layer("token_embedding")
    weights_dict["model.embed_tokens.weight"] = token_embedding_layer.weights[0]

    for i in range(backbone.num_layers):
        decoder_layer = backbone.get_layer(f"decoder_block_{i}")

        # Pre-attention normalization
        weights_dict[f"model.layers.{i}.input_layernorm.weight"] = (
            decoder_layer.pre_attention_norm.weights[0]
        )

        # Attention query projection
        query_kernel = decoder_layer.attention.query_dense.weights[0]
        query_kernel = ops.transpose(query_kernel, axes=(1, 0, 2))
        query_kernel = ops.reshape(query_kernel, (-1, backbone.hidden_dim))
        query_kernel = ops.transpose(query_kernel)
        weights_dict[f"model.layers.{i}.self_attn.q_proj.weight"] = query_kernel

        # Attention key projection
        key_kernel = decoder_layer.attention.key_dense.weights[0][0]
        weights_dict[f"model.layers.{i}.self_attn.k_proj.weight"] = (
            ops.transpose(key_kernel)
        )

        # Attention value projection
        value_kernel = decoder_layer.attention.value_dense.weights[0][0]
        weights_dict[f"model.layers.{i}.self_attn.v_proj.weight"] = (
            ops.transpose(value_kernel)
        )

        # Attention output projection
        out_kernel = decoder_layer.attention.output_dense.weights[0]
        out_kernel = ops.transpose(out_kernel, axes=(2, 0, 1))
        out_kernel = ops.reshape(out_kernel, (backbone.hidden_dim, -1))
        weights_dict[f"model.layers.{i}.self_attn.o_proj.weight"] = out_kernel

        # Post-attention normalization
        weights_dict[f"model.layers.{i}.post_attention_layernorm.weight"] = (
            decoder_layer.pre_ffw_norm.weights[0]
        )

        # MLP gate projection
        gate_kernel = decoder_layer.gating_ffw.weights[0]
        weights_dict[f"model.layers.{i}.mlp.gate_proj.weight"] = ops.transpose(
            gate_kernel
        )

        # MLP up projection
        up_kernel = decoder_layer.gating_ffw_2.weights[0]
        weights_dict[f"model.layers.{i}.mlp.up_proj.weight"] = ops.transpose(
            up_kernel
        )

        # MLP down projection
        down_kernel = decoder_layer.ffw_linear.weights[0]
        weights_dict[f"model.layers.{i}.mlp.down_proj.weight"] = ops.transpose(
            down_kernel
        )

    # Map final normalization
    weights_dict["model.norm.weight"] = backbone.get_layer(
        "final_normalization"
    ).weights[0]

    # Tie weights, but clone to avoid sharing memory issues
    weights_dict["lm_head.weight"] = ops.copy(token_embedding_layer.weights[0])

    return weights_dict
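The config dictionary above mirrors Hugging Face GemmaConfig field names, and the weight-map keys follow the GemmaForCausalLM state-dict naming (model.embed_tokens.weight, model.layers.{i}.self_attn.q_proj.weight, and so on). A quick sanity check of the field names is to round-trip the dictionary through transformers. This is a minimal sketch, not part of the commit, assuming transformers is installed and a GemmaBackbone instance named backbone exists:

from transformers import GemmaConfig

# Build an HF config from the exported dictionary and confirm a couple of
# fields land where the exporter expects them to (illustrative check only).
hf_config_dict = get_gemma_config(backbone)
hf_config = GemmaConfig(**hf_config_dict)
assert hf_config.hidden_size == backbone.hidden_dim
assert hf_config.num_attention_heads == backbone.num_query_heads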
Gemma export test (new file, 143 additions & 0 deletions)
@@ -0,0 +1,143 @@
import os

import numpy as np
import torch
from sentencepiece import SentencePieceTrainer
from transformers import GemmaForCausalLM
from transformers import GemmaTokenizer as HFGemmaTokenizer

from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone
from keras_hub.src.models.gemma.gemma_causal_lm import GemmaCausalLM
from keras_hub.src.models.gemma.gemma_causal_lm_preprocessor import (
    GemmaCausalLMPreprocessor,
)
from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer
from keras_hub.src.tests.test_case import TestCase
from keras_hub.src.utils.transformers.export.hf_exporter import (
    export_to_safetensors,
)


class TestGemmaExport(TestCase):
    def test_export_to_hf(self):
        # Create a dummy tokenizer
        train_sentences = [
            "The quick brown fox jumped.",
            "I like pizza.",
            "This is a test.",
        ]
        # TODO: Consider using keras_hub/src/tests/test_data/gemma_test_vocab.spm
        # instead of retraining a new vocab here. Will be faster.
        proto_prefix = os.path.join(self.get_temp_dir(), "dummy_vocab")
        SentencePieceTrainer.train(
            sentence_iterator=iter(train_sentences),
            model_prefix=proto_prefix,
            vocab_size=290,
            model_type="unigram",
            pad_id=0,
            bos_id=1,
            eos_id=2,
            unk_id=3,
            byte_fallback=True,
            pad_piece="<pad>",
            bos_piece="<bos>",
            eos_piece="<eos>",
            unk_piece="<unk>",
            user_defined_symbols=["<start_of_turn>", "<end_of_turn>"],
        )
        tokenizer = GemmaTokenizer(proto=f"{proto_prefix}.model")

        # Create a small backbone
        backbone = GemmaBackbone(
            vocabulary_size=tokenizer.vocabulary_size(),
            num_layers=2,
            num_query_heads=4,
            num_key_value_heads=1,
            hidden_dim=512,
            intermediate_dim=1028,
            head_dim=128,
        )
        # Create preprocessor
        preprocessor = GemmaCausalLMPreprocessor(tokenizer=tokenizer)

        # Create the causal LM model
        keras_model = GemmaCausalLM(
            backbone=backbone, preprocessor=preprocessor
        )

        # Set all weights to random values
        rng = np.random.default_rng(42)
        weights = keras_model.get_weights()
        for i in range(len(weights)):
            weights[i] = rng.random(weights[i].shape).astype(weights[i].dtype)
        keras_model.set_weights(weights)

        # Export to Hugging Face format
        export_path = os.path.join(self.get_temp_dir(), "export_small_model")
        export_to_safetensors(keras_model, export_path)
        # Load Hugging Face model and tokenizer
        hf_model = GemmaForCausalLM.from_pretrained(export_path)
        hf_tokenizer = HFGemmaTokenizer.from_pretrained(export_path)

        # Verify configuration
        hf_config = hf_model.config
        self.assertEqual(
            hf_config.vocab_size,
            backbone.vocabulary_size,
            "Vocabulary sizes do not match",
        )
        self.assertEqual(
            hf_config.num_hidden_layers,
            backbone.num_layers,
            "Number of layers do not match",
        )
        self.assertEqual(
            hf_config.num_attention_heads,
            backbone.num_query_heads,
            "Number of query heads do not match",
        )
        self.assertEqual(
            hf_config.num_key_value_heads,
            backbone.num_key_value_heads,
            "Number of key value heads do not match",
        )
        self.assertEqual(
            hf_config.hidden_size,
            backbone.hidden_dim,
            "Hidden dimensions do not match",
        )
        self.assertEqual(
            hf_config.intermediate_size,
            backbone.intermediate_dim // 2,
            "Intermediate sizes do not match",
        )
        self.assertEqual(
            hf_config.head_dim,
            backbone.head_dim,
            "Head dimensions do not match",
        )
        self.assertEqual(
            hf_config.max_position_embeddings,
            8192,
            "Max position embeddings do not match",
        )

        # Verify tokenizer compatibility
        self.assertEqual(
            hf_tokenizer.vocab_size,
            tokenizer.vocabulary_size(),
            "Tokenizer vocabulary sizes do not match",
        )

        # Compare generated outputs
        prompt = "the quick"
        keras_output = keras_model.generate(prompt, max_length=20)
        input_ids = hf_tokenizer.encode(prompt, return_tensors="pt")
        with torch.no_grad():
            output_ids = hf_model.generate(
                input_ids, max_length=20, do_sample=False
            )
        hf_output = hf_tokenizer.decode(output_ids[0], skip_special_tokens=True)
        self.assertEqual(
            keras_output, hf_output, "Generated outputs do not match"
        )
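The test above checks the exported model end to end: config fields, tokenizer vocabulary, and greedy generation. A lighter-weight check during development is to open the written model.safetensors directly and compare tensor names against the weight map. A minimal sketch, not part of the commit, assuming an already exported directory at export_path:

from safetensors import safe_open

# List the tensor names the exporter wrote and spot-check one shape.
with safe_open(f"{export_path}/model.safetensors", framework="np") as f:
    names = sorted(f.keys())
    print(names[:5])
    print(f.get_tensor("model.embed_tokens.weight").shape)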
keras_hub/src/utils/transformers/export/hf_exporter.py (new file, 98 additions & 0 deletions)
@@ -0,0 +1,98 @@
import json
import os
import shutil
import warnings

import keras

from keras_hub.src.utils.transformers.export.gemma import get_gemma_config
from keras_hub.src.utils.transformers.export.gemma import get_gemma_weights_map

MODEL_CONFIGS = {
    "GemmaBackbone": get_gemma_config,
    # Add future models here, e.g., "LlamaBackbone": get_llama_config,
}

MODEL_EXPORTERS = {
    "GemmaBackbone": get_gemma_weights_map,
    # Add future models here, e.g., "LlamaBackbone": get_llama_weights_map,
}


def export_to_safetensors(keras_model, path):
    """Converts a Keras model to Hugging Face safetensors format.

    It does the following:
    - Extracts and maps weights from the Keras backbone to safetensors.
    - Saves the configuration as 'config.json'.
    - Saves weights in 'model.safetensors'.
    - Saves tokenizer assets.

    Args:
        keras_model: The Keras model to convert.
        path: str. Path of the directory to which the safetensors file,
            config and tokenizer will be saved.
    """
    backend = keras.config.backend()
    backbone = keras_model.backbone
    model_type = backbone.__class__.__name__

    if model_type not in MODEL_CONFIGS:
        raise ValueError(f"Config not implemented for {model_type}")

    if model_type not in MODEL_EXPORTERS:
        raise ValueError(f"Exporter not implemented for {model_type}")

    get_config_fn = MODEL_CONFIGS[model_type]
    hf_config = get_config_fn(backbone)

    get_weights_fn = MODEL_EXPORTERS[model_type]
    weights_dict = get_weights_fn(backbone)

    if not weights_dict:
        raise ValueError("No weights to save.")

    # Save config
    os.makedirs(path, exist_ok=True)
    config_path = os.path.join(path, "config.json")
    with open(config_path, "w") as f:
        json.dump(hf_config, f)

    # Save weights based on backend
    weights_path = os.path.join(path, "model.safetensors")
    if backend == "torch":
        from safetensors.torch import save_file

        weights_dict_contiguous = {
            k: v.value.contiguous() if hasattr(v, "value") else v.contiguous()
            for k, v in weights_dict.items()
        }
        save_file(
            weights_dict_contiguous, weights_path, metadata={"format": "pt"}
        )
    elif backend == "tensorflow":
        from safetensors.tensorflow import save_file

        save_file(weights_dict, weights_path, metadata={"format": "pt"})
    elif backend == "jax":
        from safetensors.flax import save_file

        save_file(weights_dict, weights_path, metadata={"format": "pt"})
    else:
        raise ValueError(f"Unsupported backend: {backend}")

    # Save tokenizer assets
    keras_model.preprocessor.tokenizer.save_assets(path)

    # Rename vocabulary file
    vocab_spm_path = os.path.join(path, "vocabulary.spm")
    tokenizer_model_path = os.path.join(path, "tokenizer.model")
    if os.path.exists(vocab_spm_path):
        shutil.move(vocab_spm_path, tokenizer_model_path)
    else:
        warnings.warn(
            f"{vocab_spm_path} not found. Tokenizer may not load "
            "correctly. Ensure that the tokenizer configuration "
            "is correct and that the vocabulary file is present "
            "in the original model."
        )
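In practice the exporter is called on a full GemmaCausalLM (backbone plus preprocessor), and the resulting directory is loaded back with transformers, exactly as the test does. A minimal usage sketch, not part of the commit, assuming keras_hub and transformers are installed; the "gemma_2b_en" preset name and output directory are illustrative:

import keras_hub
from transformers import GemmaForCausalLM, GemmaTokenizer

from keras_hub.src.utils.transformers.export.hf_exporter import (
    export_to_safetensors,
)

# Load a Keras Gemma model, export it, then reload it as a Hugging Face model.
keras_model = keras_hub.models.GemmaCausalLM.from_preset("gemma_2b_en")
export_to_safetensors(keras_model, "./gemma_hf_export")

hf_model = GemmaForCausalLM.from_pretrained("./gemma_hf_export")
hf_tokenizer = GemmaTokenizer.from_pretrained("./gemma_hf_export")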

requirements-common.txt (1 addition & 0 deletions)
@@ -18,3 +18,4 @@ sentencepiece
 tensorflow-datasets
 safetensors
 pillow
+transformers
