New file: keras_hub/src/utils/transformers/export_gemma_to_safetensor.py (134 additions, 0 deletions)
import json
Collaborator:
Let's make this a model-agnostic export utility. Rename the file to safetensor_exporter.py and add a dict to maintain the mapping:

MODEL_EXPORTERS = {
    "GemmaBackbone": gemma_exporter.get_gemma_weights_map,
    "LlamaBackbone": llama_exporter.get_llama_weights_map,  # Future
}

and a user-facing API function for the export:

def export_to_safetensors(keras_model):
    ...
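A minimal sketch of what that dispatch could look like, assuming the hypothetical gemma_exporter module from the comment (nothing here exists in KerasHub yet):

# Hypothetical sketch of the suggested API; `gemma_exporter` and its
# `get_gemma_weights_map` function are assumptions from the review comment.
import os

from safetensors.torch import save_file

MODEL_EXPORTERS = {
    "GemmaBackbone": gemma_exporter.get_gemma_weights_map,
}


def export_to_safetensors(keras_model, path):
    # Dispatch on the backbone class name; fail loudly when unsupported.
    backbone_cls = type(keras_model.backbone).__name__
    if backbone_cls not in MODEL_EXPORTERS:
        raise ValueError(f"No safetensors exporter for {backbone_cls}.")
    weights_dict = MODEL_EXPORTERS[backbone_cls](keras_model.backbone)
    save_file(weights_dict, os.path.join(path, "model.safetensors"))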

Collaborator:
Implement an exporter mapping for each model. For this PR's scope, just the Gemma model, which can serve as a prototype for other models.

Collaborator:
Yeah, let's land this PR first and do this in a separate PR: #2290 (comment)

import os

import torch
from safetensors.torch import save_file

# Set the Keras backend to JAX. This must run before Keras itself is
# imported anywhere in the process, or the setting has no effect.
os.environ["KERAS_BACKEND"] = "jax"


def convert_to_hf_config(keras_config):
    """Convert a Keras Gemma backbone config to a Hugging Face config dict."""
    hf_config = {
        "vocab_size": keras_config.vocabulary_size,
        "num_hidden_layers": keras_config.num_layers,
        "num_attention_heads": keras_config.num_query_heads,
        "num_key_value_heads": keras_config.num_key_value_heads,
        "hidden_size": keras_config.hidden_dim,
        # Keras `intermediate_dim` counts the concatenated gate and up
        # projections; HF `intermediate_size` counts a single projection.
        "intermediate_size": keras_config.intermediate_dim // 2,
        "head_dim": keras_config.head_dim,
        "max_position_embeddings": 8192,
    }
    return hf_config
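For example, this is roughly how the mapping is used (a sketch only; the preset name is an assumption and the output is illustrative):

# Illustrative usage; assumes this module's convert_to_hf_config is in scope.
import keras_hub

backbone = keras_hub.models.GemmaBackbone.from_preset("gemma_2b_en")
print(convert_to_hf_config(backbone))
# Expect keys like "vocab_size" and "hidden_size", with "intermediate_size"
# equal to half of the Keras backbone's intermediate_dim.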


def export_to_hf(keras_model, path):
abheesht17 (Collaborator), Jun 19, 2025:
Also, do you think we should refactor some of the common code across models into a separate file? We can then expose that as the API.

So, this is how the directory keras_hub/src/utils/transformers/convert_to_safetensor/ would look:

  • export.py: this will have the common code, which we will expose as the API. It will also check whether we support safetensors conversion for a given model yet.
  • gemma.py: this will just have a way to create the weight dictionary for Gemma. Inside export.py, we will call the weight conversion function specific to the given model.

Pinging @mattdangerw to confirm whether we should do this now or at a later point.

Member:
I think we could land this and do the API bit at a later point, though I agree it's an important concern. I'm not sure if we want a method like model.save_to_preset() or a function like some_export(model). Any thoughts?
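For concreteness, the two call shapes being weighed (both names are hypothetical; neither exists in KerasHub yet):

# Both are hypothetical API shapes, not existing KerasHub functions.
model.export_to_safetensors("./export_dir")             # method on the model
keras_hub.export_to_safetensors(model, "./export_dir")  # standalone function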

Contributor (Author):
I think structuring the export logic with a utility function (export_to_hf) and model-specific mappings (gemma.py) will enhance scalability and maintainability. New models can be added by creating a new file, while existing tests only need an import update.

Collaborator:
+1 to Abheesht's comment that we need an API instead of a script. For Gemma, we already have a script:
https://github.com/keras-team/keras-hub/blob/master/tools/gemma/export_gemma_to_hf.py

"""Export a Keras Gemma model to Hugging Face format.

Args:
keras_model: The Keras Gemma model (e.g., GemmaCausalLM) to convert.
path (str): Path to save the model.safetensors, config, and tokenizer.


This function converts a Keras Gemma model to Hugging Face format by:
- Extracting and mapping weights from the Keras backbone to safetensors.
- Saving the configuration as 'config.json'.
- Saving weights in 'model.safetensors'.
- Saving tokenizer assets.
"""
    backbone = keras_model.backbone
    hf_config = convert_to_hf_config(backbone)

    weights_dict = {}

    # Map token embedding
    token_embedding = backbone.get_layer("token_embedding").get_weights()[0]
    weights_dict["model.embed_tokens.weight"] = torch.from_numpy(
        token_embedding
    )

    for i in range(backbone.num_layers):
        decoder_layer = backbone.get_layer(f"decoder_block_{i}")

        # Pre-attention normalization
        pre_attn_norm = decoder_layer.pre_attention_norm.get_weights()[0]
        weights_dict[f"model.layers.{i}.input_layernorm.weight"] = (
            torch.from_numpy(pre_attn_norm)
        )

        # Attention query projection
        query_kernel = decoder_layer.attention.query_dense.get_weights()[0]
        query_kernel = (
            torch.from_numpy(query_kernel)
            .permute(1, 0, 2)
            .reshape(-1, backbone.hidden_dim)
            .T
        )
        weights_dict[f"model.layers.{i}.self_attn.q_proj.weight"] = query_kernel

        # Attention key projection
        key_kernel = decoder_layer.attention.key_dense.get_weights()[0][0]
        key_kernel = torch.from_numpy(key_kernel).T
        weights_dict[f"model.layers.{i}.self_attn.k_proj.weight"] = key_kernel

        # Attention value projection
        value_kernel = decoder_layer.attention.value_dense.get_weights()[0][0]
        value_kernel = torch.from_numpy(value_kernel).T
        weights_dict[f"model.layers.{i}.self_attn.v_proj.weight"] = value_kernel

        # Attention output projection
        out_kernel = decoder_layer.attention.output_dense.get_weights()[0]
        out_kernel = (
            torch.from_numpy(out_kernel)
            .permute(2, 0, 1)
            .reshape(backbone.hidden_dim, -1)
        )
        weights_dict[f"model.layers.{i}.self_attn.o_proj.weight"] = out_kernel

        # Post-attention normalization
        post_attn_norm = decoder_layer.pre_ffw_norm.get_weights()[0]
        weights_dict[f"model.layers.{i}.post_attention_layernorm.weight"] = (
            torch.from_numpy(post_attn_norm)
        )

        # MLP gate projection
        gate_kernel = decoder_layer.gating_ffw.get_weights()[0]
        gate_kernel = torch.from_numpy(gate_kernel).T
        weights_dict[f"model.layers.{i}.mlp.gate_proj.weight"] = gate_kernel

        # MLP up projection
        up_kernel = decoder_layer.gating_ffw_2.get_weights()[0]
        up_kernel = torch.from_numpy(up_kernel).T
        weights_dict[f"model.layers.{i}.mlp.up_proj.weight"] = up_kernel

        # MLP down projection
        down_kernel = decoder_layer.ffw_linear.get_weights()[0]
        down_kernel = torch.from_numpy(down_kernel).T
        weights_dict[f"model.layers.{i}.mlp.down_proj.weight"] = down_kernel

    # Map final normalization
    final_norm = backbone.get_layer("final_normalization").get_weights()[0]
    weights_dict["model.norm.weight"] = torch.from_numpy(final_norm)

    # Tie lm_head.weight to embedding weights
    weights_dict["lm_head.weight"] = weights_dict[
        "model.embed_tokens.weight"
    ].clone()

    # Save config
    os.makedirs(path, exist_ok=True)
    config_path = os.path.join(path, "config.json")
    with open(config_path, "w") as f:
        json.dump(hf_config, f)

    # Make tensors contiguous before saving
    weights_dict_contiguous = {
        k: v.contiguous() for k, v in weights_dict.items()
    }

    # Save weights
    weights_path = os.path.join(path, "model.safetensors")
    save_file(weights_dict_contiguous, weights_path)

    # Save tokenizer assets
    keras_model.preprocessor.tokenizer.save_assets(path)
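A hedged end-to-end sketch of using the exporter (the preset name and output directory are assumptions; tokenizer reloading is left out because save_assets writes raw tokenizer assets, which may not be directly loadable by Hugging Face AutoTokenizer):

# Illustrative only; preset name and paths are assumptions.
import keras_hub
from transformers import GemmaForCausalLM

model = keras_hub.models.GemmaCausalLM.from_preset("gemma_2b_en")
export_to_hf(model, "./gemma_hf")

# Reload with transformers to sanity-check the export. GemmaForCausalLM is
# used directly because the config written above has no "model_type" field
# for the Auto classes to dispatch on.
hf_model = GemmaForCausalLM.from_pretrained("./gemma_hf")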