Safetensors conversion #2290
Changes from 16 commits
keras_hub/src/utils/transformers/export_gemma_to_safetensor.py (new file)
@@ -0,0 +1,145 @@
import json
import os
import shutil
import warnings

import jax.numpy as jnp
import keras
import keras.ops as ops
from safetensors.flax import save_file as flax_save_file
from safetensors.tensorflow import save_file as tf_save_file
from safetensors.torch import save_file as torch_save_file


def convert_to_hf_config(keras_config):
    hf_config = {
        "vocab_size": keras_config.vocabulary_size,
        "num_hidden_layers": keras_config.num_layers,
        "num_attention_heads": keras_config.num_query_heads,
        "num_key_value_heads": keras_config.num_key_value_heads,
        "hidden_size": keras_config.hidden_dim,
        "intermediate_size": keras_config.intermediate_dim // 2,
        "head_dim": keras_config.head_dim,
        "max_position_embeddings": 8192,
    }
    return hf_config
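
# (Illustration, not part of the diff: for the gemma_2b_en preset this mapping
# yields vocab_size=256000, num_hidden_layers=18, num_attention_heads=8,
# num_key_value_heads=1, hidden_size=2048, intermediate_size=16384,
# head_dim=256. The // 2 on intermediate_size reflects that keras-hub's
# intermediate_dim counts both the gate and up branches of the MLP.)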


def export_to_hf(keras_model, path):

Review comments on this line:

- We should add the API export decorator here, similar to this: https://github.com/keras-team/keras-hub/blob/master/keras_hub/src/models/bloom/bloom_backbone.py#L15-L16

- Also, do you think we should refactor some of the common code across models to a separate file? We can then expose that as the API. So, this is how the directory [...] Pinging @mattdangerw to confirm if we should do this now or at a later point.

- I think we could land and do the API bit at a later point. Though agree it's an important concern. I'm not sure if we want a method like [...]

- I think structuring the export logic with a utility function (export_to_hf) and model-specific mappings (gemma.py) will enhance scalability and maintainability. New models can be added by creating a new file, while existing tests only need an import update.

- +1 to Abheesht's comment: we need an API instead of a script for Gemma, we already have that [...]
"""This function converts a Keras Gemma model to Hugging Face format by: | ||
- Extracting and mapping weights from the Keras backbone to safetensors. | ||
- Saving the configuration as 'config.json'. | ||
- Saving weights in 'model.safetensors'. | ||
- Saving tokenizer assets. | ||
Args: | ||
keras_model: The Keras Gemma model (e.g., GemmaCausalLM) to convert. | ||
path: str. Path of the directory to which the safetensors file, | ||
config and tokenizer will be saved. | ||
Bond099 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
    backend = keras.config.backend()
    backbone = keras_model.backbone
    hf_config = convert_to_hf_config(backbone)

    weights_dict = {}

    # Map token embedding
    token_embedding_layer = backbone.get_layer("token_embedding")
    weights_dict["model.embed_tokens.weight"] = token_embedding_layer.weights[0]

    for i in range(backbone.num_layers):
        decoder_layer = backbone.get_layer(f"decoder_block_{i}")

        # Pre-attention normalization
        weights_dict[f"model.layers.{i}.input_layernorm.weight"] = (
            decoder_layer.pre_attention_norm.weights[0]
        )

        # Attention query projection
        query_kernel = decoder_layer.attention.query_dense.weights[0]
        query_kernel = ops.transpose(query_kernel, axes=(1, 0, 2))
        query_kernel = ops.reshape(query_kernel, (-1, backbone.hidden_dim))
        query_kernel = ops.transpose(query_kernel)
        weights_dict[f"model.layers.{i}.self_attn.q_proj.weight"] = query_kernel

        # Attention key projection
        key_kernel = decoder_layer.attention.key_dense.weights[0][0]
        weights_dict[f"model.layers.{i}.self_attn.k_proj.weight"] = (
            ops.transpose(key_kernel)
        )
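        # (The extra [0] drops the leading key/value-head axis; this assumes a
        # (num_key_value_heads, hidden_dim, head_dim) kernel with a single KV
        # head, as in Gemma 2B.)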

        # Attention value projection
        value_kernel = decoder_layer.attention.value_dense.weights[0][0]
        weights_dict[f"model.layers.{i}.self_attn.v_proj.weight"] = (
            ops.transpose(value_kernel)
        )

        # Attention output projection
        out_kernel = decoder_layer.attention.output_dense.weights[0]
        out_kernel = ops.transpose(out_kernel, axes=(2, 0, 1))
        out_kernel = ops.reshape(out_kernel, (backbone.hidden_dim, -1))
        weights_dict[f"model.layers.{i}.self_attn.o_proj.weight"] = out_kernel

        # Post-attention normalization
        weights_dict[f"model.layers.{i}.post_attention_layernorm.weight"] = (
            decoder_layer.pre_ffw_norm.weights[0]
        )

        # MLP gate projection
        gate_kernel = decoder_layer.gating_ffw.weights[0]
        weights_dict[f"model.layers.{i}.mlp.gate_proj.weight"] = ops.transpose(
            gate_kernel
        )

        # MLP up projection
        up_kernel = decoder_layer.gating_ffw_2.weights[0]
        weights_dict[f"model.layers.{i}.mlp.up_proj.weight"] = ops.transpose(
            up_kernel
        )

        # MLP down projection
        down_kernel = decoder_layer.ffw_linear.weights[0]
        weights_dict[f"model.layers.{i}.mlp.down_proj.weight"] = ops.transpose(
            down_kernel
        )

    # Map final normalization
    weights_dict["model.norm.weight"] = backbone.get_layer(
        "final_normalization"
    ).weights[0]

    # Tie lm_head.weight to embedding weights
    weights_dict["lm_head.weight"] = token_embedding_layer.weights[0]

    # Save config
    os.makedirs(path, exist_ok=True)
    config_path = os.path.join(path, "config.json")
    with open(config_path, "w") as f:
        json.dump(hf_config, f)

    # Save weights based on backend
    weights_path = os.path.join(path, "model.safetensors")
    if backend == "torch":
        weights_dict_contiguous = {
            k: v.contiguous() for k, v in weights_dict.items()
        }
        torch_save_file(weights_dict_contiguous, weights_path)
    elif backend == "tensorflow":
        tf_save_file(weights_dict, weights_path)
    elif backend == "jax":
        weights_dict_contiguous = {
            k: jnp.ascontiguousarray(v) for k, v in weights_dict.items()
        }
        flax_save_file(weights_dict_contiguous, weights_path)
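    # (The contiguous copies above are needed because safetensors refuses to
    # serialize non-contiguous torch tensors; the jax copy likewise guarantees
    # a dense buffer before saving.)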

    # Save tokenizer assets
    keras_model.preprocessor.tokenizer.save_assets(path)

    # Rename vocabulary file
    vocab_spm_path = os.path.join(path, "vocabulary.spm")
    tokenizer_model_path = os.path.join(path, "tokenizer.model")
    if os.path.exists(vocab_spm_path):
        shutil.move(vocab_spm_path, tokenizer_model_path)
    else:
        warnings.warn(
            f"{vocab_spm_path} not found. Tokenizer may not load correctly."
        )
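
Usage sketch (not part of the diff): assuming the public gemma_2b_en preset and this PR's module path, the utility can be exercised like this:

    import keras_hub
    from keras_hub.src.utils.transformers.export_gemma_to_safetensor import (
        export_to_hf,
    )

    # Load a Keras Gemma task model and export it to a Hugging Face directory.
    model = keras_hub.models.GemmaCausalLM.from_preset("gemma_2b_en")
    export_to_hf(model, "./gemma_2b_hf")
    # The directory now holds config.json, model.safetensors and
    # tokenizer.model, loadable via
    # transformers.GemmaForCausalLM.from_pretrained("./gemma_2b_hf").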
Test for the Gemma export utility (new file)
@@ -0,0 +1,44 @@
import os

import pytest
import torch
from transformers import GemmaForCausalLM
from transformers import GemmaTokenizer

from keras_hub.src.models.gemma.gemma_causal_lm import GemmaCausalLM
from keras_hub.src.tests.test_case import TestCase
from keras_hub.src.utils.transformers.export_gemma_to_safetensor import (
    export_to_hf,
)
class TestGemmaExport(TestCase):
    @pytest.mark.large
    def test_export_to_hf(self):
        # Load Keras model
        keras_model = GemmaCausalLM.from_preset("gemma_2b_en")

        input_text = "All hail RCB"
        max_length = 25

        # Export to Hugging Face format in a temporary directory
        export_path = os.path.join(self.get_temp_dir(), "export_to_hf")
        export_to_hf(keras_model, export_path)

        # Load Hugging Face model and tokenizer
        hf_model = GemmaForCausalLM.from_pretrained(export_path)
        hf_tokenizer = GemmaTokenizer.from_pretrained(export_path)

        # Generate text with Keras model
        keras_output = keras_model.generate(input_text, max_length=max_length)

        # Generate text with Hugging Face model
        hf_inputs = hf_tokenizer(input_text, return_tensors="pt")
        with torch.no_grad():
            hf_outputs = hf_model.generate(
                **hf_inputs, max_length=max_length, do_sample=False
            )
        hf_output_text = hf_tokenizer.decode(
            hf_outputs[0], skip_special_tokens=True
        )

        self.assertEqual(keras_output, hf_output_text)
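
A complementary numerical spot-check (a sketch, not in this PR) could compare a mapped tensor directly instead of relying on generated text alone; get_input_embeddings is standard transformers API, and the Keras weight is assumed to be numpy-convertible:

    import numpy as np

    # Compare the exported token embedding against the reloaded HF copy.
    keras_embed = keras_model.backbone.get_layer("token_embedding").weights[0]
    hf_embed = hf_model.get_input_embeddings().weight.detach().numpy()
    np.testing.assert_allclose(np.asarray(keras_embed), hf_embed, atol=1e-6)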
Requirements file
@@ -18,3 +18,4 @@ sentencepiece
 tensorflow-datasets
 safetensors
 pillow
+transformers
Review comments:

- Let's make this a model-agnostic export utility: rename the file to safetensor_exporter.py, add a dict to maintain the mapping, and add a user-facing API function for the export.

- Implement an exporter mapping for each model; for this PR's scope, just the Gemma model, which can serve as a prototype for other models.

- Yeah, let's land this PR first and do this in a separate PR: #2290 (comment)
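
To make the reviewers' proposal concrete, here is a minimal sketch of what safetensor_exporter.py could look like; the registry name MODEL_EXPORTERS and the function export_to_safetensors are hypothetical, not code from this PR:

    from keras_hub.src.utils.transformers.export_gemma_to_safetensor import (
        export_to_hf as export_gemma,
    )

    # One entry per supported architecture, keyed by backbone class name.
    MODEL_EXPORTERS = {
        "GemmaBackbone": export_gemma,
    }

    def export_to_safetensors(keras_model, path):
        """User-facing entry point: dispatch on the model's backbone class."""
        name = type(keras_model.backbone).__name__
        if name not in MODEL_EXPORTERS:
            raise ValueError(
                f"No safetensors exporter registered for {name!r}. "
                f"Supported: {sorted(MODEL_EXPORTERS)}"
            )
        MODEL_EXPORTERS[name](keras_model, path)

New models would then register a mapping function in the dict, matching the review note that existing tests only need an import update.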