Safetensors conversion #2290
New file: keras_hub/src/utils/transformers/export_gemma_to_safetensor.py (@@ -0,0 +1,146 @@)
import json
import os
import shutil
import warnings

import torch
from safetensors.torch import save_file

# Set the Keras backend to jax/pytorch/tensorflow
os.environ["KERAS_BACKEND"] = "jax"


def convert_to_hf_config(keras_config):
    hf_config = {
        "vocab_size": keras_config.vocabulary_size,
        "num_hidden_layers": keras_config.num_layers,
        "num_attention_heads": keras_config.num_query_heads,
        "num_key_value_heads": keras_config.num_key_value_heads,
        "hidden_size": keras_config.hidden_dim,
        # KerasHub's intermediate_dim covers both branches of the gated FFW,
        # so the Hugging Face per-branch intermediate_size is half of it.
        "intermediate_size": keras_config.intermediate_dim // 2,
        "head_dim": keras_config.head_dim,
        "max_position_embeddings": 8192,
    }
    return hf_config


def export_to_hf(keras_model, path):
    """Export a Keras Gemma model to Hugging Face format.

    Args:
        keras_model: The Keras Gemma model (e.g., GemmaCausalLM) to convert.
        path (str): Path to save the model.safetensors, config, and tokenizer.

    This function converts a Keras Gemma model to Hugging Face format by:

    - Extracting and mapping weights from the Keras backbone to safetensors.
    - Saving the configuration as 'config.json'.
    - Saving weights in 'model.safetensors'.
    - Saving tokenizer assets.
    """
    backbone = keras_model.backbone
    hf_config = convert_to_hf_config(backbone)

    weights_dict = {}

    # Map token embedding
    token_embedding = backbone.get_layer("token_embedding").get_weights()[0]
    weights_dict["model.embed_tokens.weight"] = torch.from_numpy(
        token_embedding
    )

    for i in range(backbone.num_layers):
        decoder_layer = backbone.get_layer(f"decoder_block_{i}")

        # Pre-attention normalization
        pre_attn_norm = decoder_layer.pre_attention_norm.get_weights()[0]
        weights_dict[f"model.layers.{i}.input_layernorm.weight"] = (
            torch.from_numpy(pre_attn_norm)
        )

        # Attention query projection
        query_kernel = decoder_layer.attention.query_dense.get_weights()[0]
        query_kernel = (
            torch.from_numpy(query_kernel)
            .permute(1, 0, 2)
            .reshape(-1, backbone.hidden_dim)
            .T
        )
        weights_dict[f"model.layers.{i}.self_attn.q_proj.weight"] = query_kernel

        # Attention key projection
        key_kernel = decoder_layer.attention.key_dense.get_weights()[0][0]
        key_kernel = torch.from_numpy(key_kernel).T
        weights_dict[f"model.layers.{i}.self_attn.k_proj.weight"] = key_kernel

        # Attention value projection
        value_kernel = decoder_layer.attention.value_dense.get_weights()[0][0]
        value_kernel = torch.from_numpy(value_kernel).T
        weights_dict[f"model.layers.{i}.self_attn.v_proj.weight"] = value_kernel

        # Attention output projection
        out_kernel = decoder_layer.attention.output_dense.get_weights()[0]
        out_kernel = (
            torch.from_numpy(out_kernel)
            .permute(2, 0, 1)
            .reshape(backbone.hidden_dim, -1)
        )
        weights_dict[f"model.layers.{i}.self_attn.o_proj.weight"] = out_kernel

        # Post-attention normalization
        post_attn_norm = decoder_layer.pre_ffw_norm.get_weights()[0]
        weights_dict[f"model.layers.{i}.post_attention_layernorm.weight"] = (
            torch.from_numpy(post_attn_norm)
        )

        # MLP gate projection
        gate_kernel = decoder_layer.gating_ffw.get_weights()[0]
        gate_kernel = torch.from_numpy(gate_kernel).T
        weights_dict[f"model.layers.{i}.mlp.gate_proj.weight"] = gate_kernel

        # MLP up projection
        up_kernel = decoder_layer.gating_ffw_2.get_weights()[0]
        up_kernel = torch.from_numpy(up_kernel).T
        weights_dict[f"model.layers.{i}.mlp.up_proj.weight"] = up_kernel

        # MLP down projection
        down_kernel = decoder_layer.ffw_linear.get_weights()[0]
        down_kernel = torch.from_numpy(down_kernel).T
        weights_dict[f"model.layers.{i}.mlp.down_proj.weight"] = down_kernel

    # Map final normalization
    final_norm = backbone.get_layer("final_normalization").get_weights()[0]
    weights_dict["model.norm.weight"] = torch.from_numpy(final_norm)

    # Tie lm_head.weight to embedding weights
    weights_dict["lm_head.weight"] = weights_dict[
        "model.embed_tokens.weight"
    ].clone()

    # Save config
    os.makedirs(path, exist_ok=True)
    config_path = os.path.join(path, "config.json")
    with open(config_path, "w") as f:
        json.dump(hf_config, f)

    # Make tensors contiguous before saving
    weights_dict_contiguous = {
        k: v.contiguous() for k, v in weights_dict.items()
    }

    # Save weights
    weights_path = os.path.join(path, "model.safetensors")
    save_file(weights_dict_contiguous, weights_path)

    # Save tokenizer assets
    keras_model.preprocessor.tokenizer.save_assets(path)

    # Rename vocabulary file
    vocab_spm_path = os.path.join(path, "vocabulary.spm")
    tokenizer_model_path = os.path.join(path, "tokenizer.model")
    if os.path.exists(vocab_spm_path):
        shutil.move(vocab_spm_path, tokenizer_model_path)
    else:
        warnings.warn(
            f"{vocab_spm_path} not found. Tokenizer may not load correctly."
        )
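For reference, a minimal usage sketch of the exporter above. The preset name comes from the test below; the output directory is illustrative, and loading the result back with transformers mirrors what the test does with a temporary directory.

from transformers import GemmaForCausalLM
from transformers import GemmaTokenizer

from keras_hub.src.models.gemma.gemma_causal_lm import GemmaCausalLM
from keras_hub.src.utils.transformers.export_gemma_to_safetensor import (
    export_to_hf,
)

# Load a Gemma preset in Keras and export it to Hugging Face format.
keras_model = GemmaCausalLM.from_preset("gemma_2b_en")
export_to_hf(keras_model, "./gemma_2b_hf")  # output directory is illustrative

# The exported directory can then be loaded directly with transformers.
hf_model = GemmaForCausalLM.from_pretrained("./gemma_2b_hf")
hf_tokenizer = GemmaTokenizer.from_pretrained("./gemma_2b_hf")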
New file: accompanying test for the exporter (@@ -0,0 +1,48 @@)
import pytest
import torch
from transformers import GemmaForCausalLM
from transformers import GemmaTokenizer

from keras_hub.src.models.gemma.gemma_causal_lm import GemmaCausalLM
from keras_hub.src.tests.test_case import TestCase
from keras_hub.src.utils.transformers.export_gemma_to_safetensor import (
    export_to_hf,
)


class TestGemmaExport(TestCase):
    @pytest.fixture(autouse=True)
    def setup_tmp_path(self, tmp_path):
        """Set up the tmp_path fixture as an instance attribute
        before each test."""
        self.tmp_path = tmp_path

    @pytest.mark.large
    def test_export_to_hf(self):
        # Load Keras model
        keras_model = GemmaCausalLM.from_preset("gemma_2b_en")

        input_text = "All hail RCB"
        max_length = 25

        # Export to Hugging Face format using self.tmp_path
        export_path = self.tmp_path / "export_to_hf"
        export_to_hf(keras_model, export_path)

        # Load Hugging Face model and tokenizer
        hf_model = GemmaForCausalLM.from_pretrained(export_path)
        hf_tokenizer = GemmaTokenizer.from_pretrained(export_path)

        # Generate text with Keras model
        keras_output = keras_model.generate(input_text, max_length=max_length)

        # Generate text with Hugging Face model
        hf_inputs = hf_tokenizer(input_text, return_tensors="pt")
        with torch.no_grad():
            hf_outputs = hf_model.generate(
                **hf_inputs, max_length=max_length, do_sample=False
            )
        hf_output_text = hf_tokenizer.decode(
            hf_outputs[0], skip_special_tokens=True
        )

        self.assertEqual(keras_output, hf_output_text)
Modified requirements file, one line added (@@ -18,3 +18,4 @@):

sentencepiece
tensorflow-datasets
safetensors
pillow
+ transformers
Review comment: Let's make this a model-agnostic export utility. Rename the file to safetensor_exporter.py, add a dict to maintain the mapping, and add a user-facing API function for the export.
Review comment: Implement an exporter mapping for each model; for this PR's scope, just the Gemma model, which can serve as a prototype for other models.
Reply: Yeah, let's land this PR first and do this in a separate PR: #2290 (comment)
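For illustration, a rough sketch of what the suggested model-agnostic utility could look like. The file name safetensor_exporter.py comes from the review comment above; the MODEL_EXPORTERS dict and export_to_safetensors function are hypothetical names, not existing keras-hub APIs.

# safetensor_exporter.py (hypothetical sketch, not part of this PR)
from keras_hub.src.utils.transformers.export_gemma_to_safetensor import (
    export_to_hf,
)

# Mapping of backbone class names to per-model exporter functions.
# Only Gemma is registered for this PR's scope; it serves as the
# prototype other models can follow.
MODEL_EXPORTERS = {
    "GemmaBackbone": export_to_hf,
}


def export_to_safetensors(keras_model, path):
    """User-facing entry point that dispatches to the right exporter."""
    backbone_name = keras_model.backbone.__class__.__name__
    if backbone_name not in MODEL_EXPORTERS:
        raise ValueError(
            f"No safetensors exporter is registered for {backbone_name}."
        )
    MODEL_EXPORTERS[backbone_name](keras_model, path)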