Safetensors conversion #2290

Merged
merged 24 commits into from
Jul 16, 2025
145 changes: 145 additions & 0 deletions keras_hub/src/utils/transformers/export_gemma_to_safetensor.py
@@ -0,0 +1,145 @@
import json
Collaborator

Let's make this a model-agnostic export utility. Rename the file to safetensor_exporter.py and add a dict to maintain the mapping:

MODEL_EXPORTERS = {
    "GemmaBackbone": gemma_exporter.get_gemma_weights_map,
    "LlamaBackbone": llama_exporter.get_llama_weights_map,  # Future
}

and a user-facing API function for the export:

def export_to_safetensors(keras_model):
    ...
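
A minimal sketch of how that dispatch could look; the error handling and the get_gemma_weights_map helper are illustrative assumptions, not existing KerasHub APIs:

def export_to_safetensors(keras_model, path):
    # Dispatch on the backbone class name; only Gemma is wired up in this PR.
    backbone = keras_model.backbone
    name = backbone.__class__.__name__
    if name not in MODEL_EXPORTERS:
        raise ValueError(f"Safetensors export is not supported for {name}.")
    weights_dict = MODEL_EXPORTERS[name](backbone)
    # Backend-specific saving (torch/tensorflow/jax) would follow here,
    # as in export_to_hf below.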

Collaborator

Implement an exporter mapping for each model; for this PR's scope, just the Gemma model, which can serve as a prototype for other models.

Collaborator

Yeah, let's land this PR first and do this in a separate PR: #2290 (comment)

import os
import shutil
import warnings

import jax.numpy as jnp
import keras
import keras.ops as ops
from safetensors.flax import save_file as flax_save_file
from safetensors.tensorflow import save_file as tf_save_file
from safetensors.torch import save_file as torch_save_file


def convert_to_hf_config(keras_config):
    hf_config = {
        "vocab_size": keras_config.vocabulary_size,
        "num_hidden_layers": keras_config.num_layers,
        "num_attention_heads": keras_config.num_query_heads,
        "num_key_value_heads": keras_config.num_key_value_heads,
        "hidden_size": keras_config.hidden_dim,
        "intermediate_size": keras_config.intermediate_dim // 2,
        "head_dim": keras_config.head_dim,
        "max_position_embeddings": 8192,
    }
    return hf_config
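
# For illustration (assumed values from the gemma_2b_en preset):
# vocabulary_size=256000, num_layers=18, num_query_heads=8,
# num_key_value_heads=1, hidden_dim=2048, intermediate_dim=32768 and
# head_dim=256 map to intermediate_size=16384 on the Hugging Face side;
# the `// 2` reflects that the Keras Gemma backbone's intermediate_dim
# covers the combined gate/up feedforward width.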


def export_to_hf(keras_model, path):
Collaborator @abheesht17, Jun 19, 2025

Also, do you think we should refactor some of the common code across models into a separate file? We can then expose that as the API.

This is how the directory keras_hub/src/utils/transformers/convert_to_safetensor/ would look:

  • export.py: this will hold the common code, which we will expose as the API. It will also check whether we support safetensors conversion for the passed model yet.
  • gemma.py: this will just have a way to create the weight dictionary for Gemma. Inside export.py, we will call the weight-conversion function specific to the given model.

Pinging @mattdangerw to confirm whether we should do this now or at a later point.
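
A rough sketch of what the gemma.py side could look like; the get_gemma_weights_map name and signature are illustrative and mirror the mapping done in export_to_hf below:

def get_gemma_weights_map(backbone):
    # Build the Hugging Face-style weight dictionary for a Gemma backbone.
    weights_dict = {}
    token_embedding = backbone.get_layer("token_embedding")
    weights_dict["model.embed_tokens.weight"] = token_embedding.weights[0]
    # ... per-layer attention, MLP and norm mappings, as in export_to_hf ...
    weights_dict["lm_head.weight"] = token_embedding.weights[0]
    return weights_dict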

Member

I think we could land this and do the API bit at a later point, though I agree it's an important concern. I'm not sure whether we want a method like model.save_to_preset() or a function like some_export(model). Any thoughts?

Contributor Author

I think structuring the export logic with a utility function (export_to_hf) and model-specific mappings (gemma.py) will enhance scalability and maintainability. New models can be added by creating a new file, while existing tests only need an import update.

Collaborator

+1 to Abheesht's comment that we need an API instead of a script. For Gemma, we already have a script:
https://github.com/keras-team/keras-hub/blob/master/tools/gemma/export_gemma_to_hf.py

"""This function converts a Keras Gemma model to Hugging Face format by:
- Extracting and mapping weights from the Keras backbone to safetensors.
- Saving the configuration as 'config.json'.
- Saving weights in 'model.safetensors'.
- Saving tokenizer assets.
Args:
keras_model: The Keras Gemma model (e.g., GemmaCausalLM) to convert.
path: str. Path of the directory to which the safetensors file,
config and tokenizer will be saved.
"""
backend = keras.config.backend()
backbone = keras_model.backbone
hf_config = convert_to_hf_config(backbone)

weights_dict = {}

# Map token embedding
token_embedding_layer = backbone.get_layer("token_embedding")
weights_dict["model.embed_tokens.weight"] = token_embedding_layer.weights[0]

for i in range(backbone.num_layers):
decoder_layer = backbone.get_layer(f"decoder_block_{i}")

# Pre-attention normalization
weights_dict[f"model.layers.{i}.input_layernorm.weight"] = (
decoder_layer.pre_attention_norm.weights[0]
)

# Attention query projection
query_kernel = decoder_layer.attention.query_dense.weights[0]
query_kernel = ops.transpose(query_kernel, axes=(1, 0, 2))
query_kernel = ops.reshape(query_kernel, (-1, backbone.hidden_dim))
query_kernel = ops.transpose(query_kernel)
weights_dict[f"model.layers.{i}.self_attn.q_proj.weight"] = query_kernel

# Attention key projection
key_kernel = decoder_layer.attention.key_dense.weights[0][0]
weights_dict[f"model.layers.{i}.self_attn.k_proj.weight"] = (
ops.transpose(key_kernel)
)

# Attention value projection
value_kernel = decoder_layer.attention.value_dense.weights[0][0]
weights_dict[f"model.layers.{i}.self_attn.v_proj.weight"] = (
ops.transpose(value_kernel)
)

# Attention output projection
out_kernel = decoder_layer.attention.output_dense.weights[0]
out_kernel = ops.transpose(out_kernel, axes=(2, 0, 1))
out_kernel = ops.reshape(out_kernel, (backbone.hidden_dim, -1))
weights_dict[f"model.layers.{i}.self_attn.o_proj.weight"] = out_kernel

# Post-attention normalization
weights_dict[f"model.layers.{i}.post_attention_layernorm.weight"] = (
decoder_layer.pre_ffw_norm.weights[0]
)

# MLP gate projection
gate_kernel = decoder_layer.gating_ffw.weights[0]
weights_dict[f"model.layers.{i}.mlp.gate_proj.weight"] = ops.transpose(
gate_kernel
)

# MLP up projection
up_kernel = decoder_layer.gating_ffw_2.weights[0]
weights_dict[f"model.layers.{i}.mlp.up_proj.weight"] = ops.transpose(
up_kernel
)

# MLP down projection
down_kernel = decoder_layer.ffw_linear.weights[0]
weights_dict[f"model.layers.{i}.mlp.down_proj.weight"] = ops.transpose(
down_kernel
)

# Map final normalization
weights_dict["model.norm.weight"] = backbone.get_layer(
"final_normalization"
).weights[0]

# Tie lm_head.weight to embedding weights
weights_dict["lm_head.weight"] = token_embedding_layer.weights[0]

# Save config
os.makedirs(path, exist_ok=True)
config_path = os.path.join(path, "config.json")
with open(config_path, "w") as f:
json.dump(hf_config, f)

# Save weights based on backend
weights_path = os.path.join(path, "model.safetensors")
if backend == "torch":
weights_dict_contiguous = {
k: v.contiguous() for k, v in weights_dict.items()
}
torch_save_file(weights_dict_contiguous, weights_path)
elif backend == "tensorflow":
tf_save_file(weights_dict, weights_path)
elif backend == "jax":
weights_dict_contiguous = {
k: jnp.ascontiguousarray(v) for k, v in weights_dict.items()
}
flax_save_file(weights_dict_contiguous, weights_path)

# Save tokenizer assets
keras_model.preprocessor.tokenizer.save_assets(path)

# Rename vocabulary file
vocab_spm_path = os.path.join(path, "vocabulary.spm")
tokenizer_model_path = os.path.join(path, "tokenizer.model")
if os.path.exists(vocab_spm_path):
shutil.move(vocab_spm_path, tokenizer_model_path)
else:
warnings.warn(
f"{vocab_spm_path} not found. Tokenizer may not load correctly."
)
@@ -0,0 +1,44 @@
import os

import pytest
import torch
from transformers import GemmaForCausalLM
from transformers import GemmaTokenizer

from keras_hub.src.models.gemma.gemma_causal_lm import GemmaCausalLM
from keras_hub.src.tests.test_case import TestCase
from keras_hub.src.utils.transformers.export_gemma_to_safetensor import (
    export_to_hf,
)


class TestGemmaExport(TestCase):
    @pytest.mark.large
    def test_export_to_hf(self):
        # Load Keras model
        keras_model = GemmaCausalLM.from_preset("gemma_2b_en")
        input_text = "All hail RCB"
        max_length = 25

        # Export to Hugging Face format in a temporary directory
        export_path = os.path.join(self.get_temp_dir(), "export_to_hf")
        export_to_hf(keras_model, export_path)

        # Load Hugging Face model and tokenizer
        hf_model = GemmaForCausalLM.from_pretrained(export_path)
        hf_tokenizer = GemmaTokenizer.from_pretrained(export_path)

        # Generate text with Keras model
        keras_output = keras_model.generate(input_text, max_length=max_length)

        # Generate text with Hugging Face model
        hf_inputs = hf_tokenizer(input_text, return_tensors="pt")
        with torch.no_grad():
            hf_outputs = hf_model.generate(
                **hf_inputs, max_length=max_length, do_sample=False
            )
        hf_output_text = hf_tokenizer.decode(
            hf_outputs[0], skip_special_tokens=True
        )

        self.assertEqual(keras_output, hf_output_text)
1 change: 1 addition & 0 deletions requirements-common.txt
@@ -18,3 +18,4 @@ sentencepiece
tensorflow-datasets
safetensors
pillow
transformers