146 changes: 146 additions & 0 deletions keras_hub/src/utils/transformers/export_gemma_to_safetensor.py
@@ -0,0 +1,146 @@
import json
Collaborator:
Let's make this a model-agnostic export utility. Rename the file to safetensor_exporter.py and add a dict to maintain the mapping:

MODEL_EXPORTERS = {
    "GemmaBackbone": gemma_exporter.get_gemma_weights_map,
    "LlamaBackbone": llama_exporter.get_llama_weights_map,  # Future
}

and a user-facing API function for the export:

def export_to_safetensors(keras_model):
    ...
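
For illustration, a minimal sketch of how that user-facing dispatch could work. The MODEL_EXPORTERS registry, the per-model get_*_weights_map functions, and the extra path argument are assumptions taken from the comment above, not existing APIs; os and save_file are used as in the file below:

def export_to_safetensors(keras_model, path):
    # Look up the weight-map builder registered for this backbone class.
    backbone = keras_model.backbone
    exporter = MODEL_EXPORTERS.get(type(backbone).__name__)
    if exporter is None:
        raise ValueError(
            f"Safetensors export is not supported for "
            f"{type(backbone).__name__} yet."
        )
    weights_dict = exporter(backbone)
    # Reuse the save logic from this file.
    os.makedirs(path, exist_ok=True)
    save_file(
        {k: v.contiguous() for k, v in weights_dict.items()},
        os.path.join(path, "model.safetensors"),
    )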

Collaborator:
Implement an exporter mapping for each model. For this PR's scope, just the Gemma model, which can serve as a prototype for other models.

Collaborator:
Yeah, let's land this PR first and do this in a separate PR: #2290 (comment)

import os
import shutil
import warnings

import torch
from safetensors.torch import save_file

# Set the Keras backend to jax/pytorch/tensorflow
os.environ["KERAS_BACKEND"] = "jax"


def convert_to_hf_config(keras_config):
    hf_config = {
        "vocab_size": keras_config.vocabulary_size,
        "num_hidden_layers": keras_config.num_layers,
        "num_attention_heads": keras_config.num_query_heads,
        "num_key_value_heads": keras_config.num_key_value_heads,
        "hidden_size": keras_config.hidden_dim,
        "intermediate_size": keras_config.intermediate_dim // 2,
        "head_dim": keras_config.head_dim,
        "max_position_embeddings": 8192,
    }
    return hf_config


def export_to_hf(keras_model, path):
Collaborator @abheesht17 (Jun 19, 2025):
Also, do you think we should refactor some of the common code across models into a separate file? We can then expose that as the API.

So, this is how the directory keras_hub/src/utils/transformers/convert_to_safetensor/ would look:

  • export.py: this will have the common code. We will expose this as the API. It will also check whether we support safetensors conversion for a given model yet.
  • gemma.py: this will just have a way to create the weight dictionary for Gemma. Inside export.py, we will call the weight conversion function specific to the specified model.

Pinging @mattdangerw to confirm whether we should do this now or at a later point.
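
For illustration only, a rough sketch of the per-model piece (gemma.py) under this proposed layout; the file path and function name follow the suggestions above and are hypothetical, while the layer lookups mirror the Gemma mapping code later in this diff:

# keras_hub/src/utils/transformers/convert_to_safetensor/gemma.py (hypothetical)
import torch


def get_gemma_weights_map(backbone):
    # Build the Hugging Face parameter-name -> tensor dict for a Gemma backbone.
    weights_dict = {}
    token_embedding = backbone.get_layer("token_embedding").get_weights()[0]
    weights_dict["model.embed_tokens.weight"] = torch.from_numpy(token_embedding)
    final_norm = backbone.get_layer("final_normalization").get_weights()[0]
    weights_dict["model.norm.weight"] = torch.from_numpy(final_norm)
    # Per-layer attention and MLP weights would be added here, as in
    # export_to_hf below.
    return weights_dict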

Member:
I think we could land this and do the API bit at a later point, though I agree it's an important concern. I'm not sure whether we want a method like model.save_to_preset() or a function like some_export(model). Any thoughts?

Contributor (Author):
I think structuring the export logic with a utility function (export_to_hf) and model-specific mappings (gemma.py) will enhance scalability and maintainability. New models can be added by creating a new file, while existing tests only need an import update.

Collaborator:
+1 to Abheesht's comment. We need an API instead of a script; for Gemma, we already have such a script:
https://github.com/keras-team/keras-hub/blob/master/tools/gemma/export_gemma_to_hf.py

"""Export a Keras Gemma model to Hugging Face format.

Args:
keras_model: The Keras Gemma model (e.g., GemmaCausalLM) to convert.
path (str): Path to save the model.safetensors, config, and tokenizer.


This function converts a Keras Gemma model to Hugging Face format by:
- Extracting and mapping weights from the Keras backbone to safetensors.
- Saving the configuration as 'config.json'.
- Saving weights in 'model.safetensors'.
- Saving tokenizer assets.
"""
    backbone = keras_model.backbone
    hf_config = convert_to_hf_config(backbone)

    weights_dict = {}

    # Map token embedding
    token_embedding = backbone.get_layer("token_embedding").get_weights()[0]
    weights_dict["model.embed_tokens.weight"] = torch.from_numpy(
        token_embedding
    )

    for i in range(backbone.num_layers):
        decoder_layer = backbone.get_layer(f"decoder_block_{i}")

        # Pre-attention normalization
        pre_attn_norm = decoder_layer.pre_attention_norm.get_weights()[0]
        weights_dict[f"model.layers.{i}.input_layernorm.weight"] = (
            torch.from_numpy(pre_attn_norm)
        )

        # Attention query projection
        query_kernel = decoder_layer.attention.query_dense.get_weights()[0]
        query_kernel = (
            torch.from_numpy(query_kernel)
            .permute(1, 0, 2)
            .reshape(-1, backbone.hidden_dim)
            .T
        )
        weights_dict[f"model.layers.{i}.self_attn.q_proj.weight"] = query_kernel

        # Attention key projection
        key_kernel = decoder_layer.attention.key_dense.get_weights()[0][0]
        key_kernel = torch.from_numpy(key_kernel).T
        weights_dict[f"model.layers.{i}.self_attn.k_proj.weight"] = key_kernel

        # Attention value projection
        value_kernel = decoder_layer.attention.value_dense.get_weights()[0][0]
        value_kernel = torch.from_numpy(value_kernel).T
        weights_dict[f"model.layers.{i}.self_attn.v_proj.weight"] = value_kernel

        # Attention output projection
        out_kernel = decoder_layer.attention.output_dense.get_weights()[0]
        out_kernel = (
            torch.from_numpy(out_kernel)
            .permute(2, 0, 1)
            .reshape(backbone.hidden_dim, -1)
        )
        weights_dict[f"model.layers.{i}.self_attn.o_proj.weight"] = out_kernel

        # Post-attention normalization
        post_attn_norm = decoder_layer.pre_ffw_norm.get_weights()[0]
        weights_dict[f"model.layers.{i}.post_attention_layernorm.weight"] = (
            torch.from_numpy(post_attn_norm)
        )

        # MLP gate projection
        gate_kernel = decoder_layer.gating_ffw.get_weights()[0]
        gate_kernel = torch.from_numpy(gate_kernel).T
        weights_dict[f"model.layers.{i}.mlp.gate_proj.weight"] = gate_kernel

        # MLP up projection
        up_kernel = decoder_layer.gating_ffw_2.get_weights()[0]
        up_kernel = torch.from_numpy(up_kernel).T
        weights_dict[f"model.layers.{i}.mlp.up_proj.weight"] = up_kernel

        # MLP down projection
        down_kernel = decoder_layer.ffw_linear.get_weights()[0]
        down_kernel = torch.from_numpy(down_kernel).T
        weights_dict[f"model.layers.{i}.mlp.down_proj.weight"] = down_kernel

    # Map final normalization
    final_norm = backbone.get_layer("final_normalization").get_weights()[0]
    weights_dict["model.norm.weight"] = torch.from_numpy(final_norm)

    # Tie lm_head.weight to embedding weights
    weights_dict["lm_head.weight"] = weights_dict[
        "model.embed_tokens.weight"
    ].clone()

    # Save config
    os.makedirs(path, exist_ok=True)
    config_path = os.path.join(path, "config.json")
    with open(config_path, "w") as f:
        json.dump(hf_config, f)

    # Make tensors contiguous before saving
    weights_dict_contiguous = {
        k: v.contiguous() for k, v in weights_dict.items()
    }

    # Save weights
    weights_path = os.path.join(path, "model.safetensors")
    save_file(weights_dict_contiguous, weights_path)

    # Save tokenizer assets
    keras_model.preprocessor.tokenizer.save_assets(path)

    # Rename vocabulary file
    vocab_spm_path = os.path.join(path, "vocabulary.spm")
    tokenizer_model_path = os.path.join(path, "tokenizer.model")
    if os.path.exists(vocab_spm_path):
        shutil.move(vocab_spm_path, tokenizer_model_path)
    else:
        warnings.warn(
            f"{vocab_spm_path} not found. Tokenizer may not load correctly."
        )
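
For reference, a minimal usage sketch of export_to_hf, using the module path from this PR and the gemma_2b_en preset from the test below; the output directory name is illustrative:

from keras_hub.src.models.gemma.gemma_causal_lm import GemmaCausalLM
from keras_hub.src.utils.transformers.export_gemma_to_safetensor import (
    export_to_hf,
)

# Writes config.json, model.safetensors, and tokenizer.model to the directory.
keras_model = GemmaCausalLM.from_preset("gemma_2b_en")
export_to_hf(keras_model, "./gemma_2b_en_hf")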
@@ -0,0 +1,48 @@
import pytest
import torch
from transformers import GemmaForCausalLM
from transformers import GemmaTokenizer

from keras_hub.src.models.gemma.gemma_causal_lm import GemmaCausalLM
from keras_hub.src.tests.test_case import TestCase
from keras_hub.src.utils.transformers.export_gemma_to_safetensor import (
    export_to_hf,
)


class TestGemmaExport(TestCase):
    @pytest.fixture(autouse=True)
    def setup_tmp_path(self, tmp_path):
        """Set up the tmp_path fixture as an instance attribute
        before each test."""
        self.tmp_path = tmp_path

    @pytest.mark.large
    def test_export_to_hf(self):
        # Load Keras model
        keras_model = GemmaCausalLM.from_preset("gemma_2b_en")
        input_text = "All hail RCB"
        max_length = 25

        # Export to Hugging Face format using self.tmp_path
        export_path = self.tmp_path / "export_to_hf"
        export_to_hf(keras_model, export_path)

        # Load Hugging Face model and tokenizer
        hf_model = GemmaForCausalLM.from_pretrained(export_path)
        hf_tokenizer = GemmaTokenizer.from_pretrained(export_path)

        # Generate text with Keras model
        keras_output = keras_model.generate(input_text, max_length=max_length)

        # Generate text with Hugging Face model
        hf_inputs = hf_tokenizer(input_text, return_tensors="pt")
        with torch.no_grad():
            hf_outputs = hf_model.generate(
                **hf_inputs, max_length=max_length, do_sample=False
            )
        hf_output_text = hf_tokenizer.decode(
            hf_outputs[0], skip_special_tokens=True
        )

        self.assertEqual(keras_output, hf_output_text)
1 change: 1 addition & 0 deletions requirements-common.txt
@@ -18,3 +18,4 @@ sentencepiece
tensorflow-datasets
safetensors
pillow
transformers