Conversation

@kharshith-k
Member

No description provided.

@github-actions bot added the Gemma (Gemma model specific issues) label on Oct 13, 2025
@gemini-code-assist
Contributor

Summary of Changes

Hello @kharshith-k, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the Gemma3 checkpoint conversion tool by introducing the capability to export Keras models into the Hugging Face safetensors format. This feature facilitates seamless integration of Gemma3 models with the Hugging Face ecosystem, providing a standardized and efficient way to share and utilize these models across different platforms. The changes include a new command-line option, comprehensive weight conversion logic, and a verification step to confirm the integrity of the exported models.

Highlights

  • Hugging Face Safetensors Export: Added functionality to convert Gemma3 Keras checkpoints to Hugging Face safetensors format, enabling broader interoperability.
  • New CLI Flag: Introduced the --export_safetensors command-line flag to activate the Hugging Face conversion process.
  • Comprehensive Weight Mapping: Implemented detailed logic for mapping and transforming Keras model weights (including token embeddings, attention layers, and MLP layers) to their corresponding Hugging Face safetensors structure, with support for bfloat16 data types.
  • Conversion Verification: Included a post-conversion step to load the exported Hugging Face model and perform a sample inference, ensuring the correctness and integrity of the conversion (a lighter-weight sanity check is sketched after this list).
  • Configuration Conversion: Added a utility function to convert Keras Gemma model configurations to the Hugging Face Gemma3TextConfig format.
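
As an illustration only, the sketch below shows a lighter-weight sanity check than the script's full reload-and-inference verification: it opens the exported model.safetensors file and prints each tensor's name, shape, and dtype. The export_dir path is a placeholder, and the bfloat16 assertion assumes every tensor was saved in bfloat16, as in the conversion code shown later in this thread.

import torch
from safetensors import safe_open

export_dir = "./gemma3_hf_export"  # placeholder path for the exported model

# List every exported tensor and confirm it was stored as bfloat16.
with safe_open(f"{export_dir}/model.safetensors", framework="pt") as f:
    for name in f.keys():
        tensor = f.get_tensor(name)
        print(name, tuple(tensor.shape), tensor.dtype)
        assert tensor.dtype == torch.bfloat16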


@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request adds a valuable feature to convert Keras Gemma3 models to the Hugging Face Safetensors format, enhancing interoperability. The implementation is comprehensive, covering configuration conversion, weight porting, and a validation step. I've provided a few suggestions to improve code clarity, maintainability, and adherence to the repository's style guide, primarily by improving docstrings, refactoring duplicated code, and ensuring deterministic validation.

Comment on lines 46 to 47
def convert_to_hf_config(keras_config):
    """Convert Keras Gemma config to Hugging Face GemmaConfig."""

Severity: medium

The docstring for this function is missing Args and Returns sections, which is inconsistent with the repository's style guide. Providing detailed docstrings improves code clarity and maintainability.1

def convert_to_hf_config(keras_config):
    """Convert Keras Gemma config to Hugging Face GemmaConfig.

    Args:
        keras_config: A Keras Gemma3 config object from the backbone.

    Returns:
        A `transformers.Gemma3TextConfig` instance.
    """

Style Guide References

Footnotes

  1. The style guide requires all public functions to have Google-style docstrings, including comprehensive documentation for all parameters and return values.

Comment on lines 61 to 62
def export_to_hf(backbone, keras_tokenizer, path):
    """Convert a Keras Gemma model to Hugging Face format and save to path."""

Severity: medium

The docstring for this function is missing Args and Returns sections, which is inconsistent with the repository's style guide. Providing detailed docstrings improves code clarity and maintainability.1

def export_to_hf(backbone, keras_tokenizer, path):
    """Convert a Keras Gemma model to Hugging Face format and save to path.

    Args:
        backbone: A `keras_hub.models.Gemma3Backbone` instance.
        keras_tokenizer: A `keras_hub.models.Gemma3Tokenizer` instance.
        path: str. The path to save the Hugging Face model to.
    """

Style Guide References

Footnotes

  1. The style guide requires all public functions to have Google-style docstrings, including comprehensive documentation for all parameters and return values.

Comment on lines +68 to +72
def to_torch(weight):
    # Convert bfloat16 to float32 first, then to torch, then to bfloat16
    if hasattr(weight.dtype, "name") and "bfloat16" in str(weight.dtype):
        weight = np.array(weight, dtype=np.float32)
    return torch.from_numpy(weight).to(torch.bfloat16)

Severity: medium

This helper function can be simplified to robustly handle various array types (like JAX arrays) and then used consistently throughout export_to_hf to reduce code duplication.

Currently, the conversion logic torch.from_numpy(np.array(weight, dtype=np.float32)).to(torch.bfloat16) is repeated for many weights. You can simplify to_torch to encapsulate this logic and improve maintainability.

With the suggested change, you can then refactor the rest of the function, for example:

q_kernel = block.attention.query_dense.get_weights()[0]
weights_dict[f"model.layers.{i}.self_attn.q_proj.weight"] = (
    to_torch(q_kernel)
    .permute(1, 0, 2)
    .reshape(backbone.hidden_dim, -1)
    .T
)
Suggested change

    def to_torch(weight):
        # Convert array-like weights (e.g., from JAX) to a float32 NumPy
        # array before creating a bfloat16 torch tensor for compatibility.
        np_weight = np.array(weight, dtype=np.float32)
        return torch.from_numpy(np_weight).to(torch.bfloat16)

@sachinprasadhs
Collaborator

Thanks for the PR! The export to safetensors should be made available here: https://github.com/keras-team/keras-hub/tree/master/keras_hub/src/utils/transformers/export.

  • Create a new file for Gemma3.
  • Add Gemma3 details here (a sketch of these registrations is shown after this list):

    MODEL_CONFIGS = {
        "GemmaBackbone": get_gemma_config,
        # Add for future models, e.g., "MistralBackbone": get_mistral_config
    }
    MODEL_EXPORTERS = {
        "GemmaBackbone": get_gemma_weights_map,
        # Add for future models, e.g., "MistralBackbone": get_mistral_weights_map
    }
    MODEL_TOKENIZER_CONFIGS = {
        "GemmaTokenizer": get_gemma_tokenizer_config,
        # Add for future models, e.g., "MistralTokenizer":
        # get_mistral_tokenizer_config
    }

  • Add a test file for Gemma3 export.
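
For illustration, a hedged sketch of what those Gemma3 registrations might look like, following the existing Gemma pattern above; get_gemma3_config, get_gemma3_weights_map, and get_gemma3_tokenizer_config are assumed names for the new Gemma3 helpers, not functions confirmed by this PR.

MODEL_CONFIGS = {
    "GemmaBackbone": get_gemma_config,
    "Gemma3Backbone": get_gemma3_config,  # assumed helper name
}
MODEL_EXPORTERS = {
    "GemmaBackbone": get_gemma_weights_map,
    "Gemma3Backbone": get_gemma3_weights_map,  # assumed helper name
}
MODEL_TOKENIZER_CONFIGS = {
    "GemmaTokenizer": get_gemma_tokenizer_config,
    "Gemma3Tokenizer": get_gemma3_tokenizer_config,  # assumed helper name
}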

@hertschuh added the kokoro:force-run (Runs Tests on GPU) label on Oct 13, 2025
@kokoro-team removed the kokoro:force-run (Runs Tests on GPU) label on Oct 13, 2025
@sachinprasadhs added the kokoro:force-run (Runs Tests on GPU) label on Oct 24, 2025
@kokoro-team removed the kokoro:force-run (Runs Tests on GPU) label on Oct 24, 2025
@sachinprasadhs added the kokoro:force-run (Runs Tests on GPU) label on Oct 24, 2025
@sachinprasadhs
Collaborator

/gemini review


@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces functionality to export Gemma3 models from KerasHub to the Hugging Face format. It adds the necessary export logic, corresponding tests, and integrates this into the checkpoint conversion script. The overall approach is good, with solid testing. However, there's a critical issue in tools/checkpoint_conversion/convert_gemma3_checkpoints.py where the export logic is duplicated instead of reusing the newly added library functions. This violates the DRY principle and the repository's style guide on backend-agnostic code. Additionally, there are some areas for improvement in the core export logic in keras_hub/src/utils/transformers/export/gemma3.py concerning code duplication and incorrect fallback logic for normalization layers.

Comment on lines +46 to +201
def convert_to_hf_config(keras_config):
    """Convert Keras Gemma config to Hugging Face GemmaConfig.

    Args:
        keras_config: A Keras Gemma3 config object from the backbone.

    Returns:
        A `transformers.Gemma3TextConfig` instance.
    """
    hf_config = transformers.Gemma3TextConfig(
        vocab_size=keras_config.vocabulary_size,
        num_hidden_layers=keras_config.num_layers,
        num_attention_heads=keras_config.num_query_heads,
        num_key_value_heads=keras_config.num_key_value_heads,
        hidden_size=keras_config.hidden_dim,
        intermediate_size=keras_config.intermediate_dim,
        head_dim=keras_config.head_dim,
        max_position_embeddings=32768,
    )
    return hf_config


def export_to_hf(backbone, keras_tokenizer, path):
    """Convert a Keras Gemma model to Hugging Face format and save to path.

    Args:
        backbone: A `keras_hub.models.Gemma3Backbone` instance.
        keras_tokenizer: A `keras_hub.models.Gemma3Tokenizer` instance.
        path: str. The path to save the Hugging Face model to.
    """
    hf_config = convert_to_hf_config(backbone)
    weights_dict = {}

    # Helper function to convert bfloat16 weights to torch tensors
    def to_torch(weight):
        # Convert bfloat16 to float32 first, then to torch, then to bfloat16
        if hasattr(weight.dtype, "name") and "bfloat16" in str(weight.dtype):
            weight = np.array(weight, dtype=np.float32)
        return torch.from_numpy(weight).to(torch.bfloat16)

    # Token embeddings
    token_embedding = backbone.get_layer("token_embedding").get_weights()[0]
    weights_dict["model.embed_tokens.weight"] = to_torch(token_embedding)

    for i in range(backbone.num_layers):
        block = backbone.get_layer(f"decoder_block_{i}")
        q_kernel = block.attention.query_dense.get_weights()[0]
        q_kernel = (
            torch.from_numpy(np.array(q_kernel, dtype=np.float32))
            .to(torch.bfloat16)
            .permute(1, 0, 2)
            .reshape(backbone.hidden_dim, -1)
            .T
        )
        weights_dict[f"model.layers.{i}.self_attn.q_proj.weight"] = q_kernel

        k_kernel = block.attention.key_dense.get_weights()[0]
        k_kernel = (
            torch.from_numpy(np.array(k_kernel, dtype=np.float32))
            .to(torch.bfloat16)
            .permute(1, 0, 2)
            .reshape(backbone.hidden_dim, -1)
            .T
        )
        weights_dict[f"model.layers.{i}.self_attn.k_proj.weight"] = k_kernel

        v_kernel = block.attention.value_dense.get_weights()[0]
        v_kernel = (
            torch.from_numpy(np.array(v_kernel, dtype=np.float32))
            .to(torch.bfloat16)
            .permute(1, 0, 2)
            .reshape(backbone.hidden_dim, -1)
            .T
        )
        weights_dict[f"model.layers.{i}.self_attn.v_proj.weight"] = v_kernel

        o_kernel = block.attention.output_dense.get_weights()[0]
        o_kernel = (
            torch.from_numpy(np.array(o_kernel, dtype=np.float32))
            .to(torch.bfloat16)
            .permute(2, 0, 1)
            .reshape(backbone.hidden_dim, -1)
        )
        weights_dict[f"model.layers.{i}.self_attn.o_proj.weight"] = o_kernel

        q_norm = block.attention.query_norm.get_weights()[0]
        weights_dict[f"model.layers.{i}.self_attn.q_norm.weight"] = to_torch(
            q_norm
        )

        k_norm = block.attention.key_norm.get_weights()[0]
        weights_dict[f"model.layers.{i}.self_attn.k_norm.weight"] = to_torch(
            k_norm
        )

        gate_kernel = block.gating_ffw.get_weights()[0]
        gate_kernel = (
            torch.from_numpy(np.array(gate_kernel, dtype=np.float32))
            .to(torch.bfloat16)
            .T
        )
        weights_dict[f"model.layers.{i}.mlp.gate_proj.weight"] = gate_kernel

        up_kernel = block.gating_ffw_2.get_weights()[0]
        up_kernel = (
            torch.from_numpy(np.array(up_kernel, dtype=np.float32))
            .to(torch.bfloat16)
            .T
        )
        weights_dict[f"model.layers.{i}.mlp.up_proj.weight"] = up_kernel

        down_kernel = block.ffw_linear.get_weights()[0]
        down_kernel = (
            torch.from_numpy(np.array(down_kernel, dtype=np.float32))
            .to(torch.bfloat16)
            .T
        )
        weights_dict[f"model.layers.{i}.mlp.down_proj.weight"] = down_kernel

        input_layer_norm = block.pre_attention_norm.get_weights()[0]
        weights_dict[f"model.layers.{i}.input_layernorm.weight"] = to_torch(
            input_layer_norm
        )

        post_attn_norm = block.post_attention_norm.get_weights()[0]
        weights_dict[f"model.layers.{i}.post_attention_layernorm.weight"] = (
            to_torch(post_attn_norm)
        )

        pre_feedforward_layernorm_weight = block.pre_ffw_norm.get_weights()[0]
        weights_dict[f"model.layers.{i}.pre_feedforward_layernorm.weight"] = (
            to_torch(pre_feedforward_layernorm_weight)
        )

        post_feedforward_layernorm_weight = block.post_ffw_norm.get_weights()[0]
        weights_dict[f"model.layers.{i}.post_feedforward_layernorm.weight"] = (
            to_torch(post_feedforward_layernorm_weight)
        )

    final_norm = backbone.get_layer("final_normalization").get_weights()[0]
    weights_dict["model.norm.weight"] = to_torch(final_norm)
    weights_dict["lm_head.weight"] = weights_dict[
        "model.embed_tokens.weight"
    ].clone()

    os.makedirs(path, exist_ok=True)
    with open(os.path.join(path, "config.json"), "w") as f:
        json.dump(hf_config.to_dict(), f)
    weights_dict = {k: v.contiguous() for k, v in weights_dict.items()}
    save_file(weights_dict, os.path.join(path, "model.safetensors"))
    keras_tokenizer.save_assets(path)
    vocab_spm = os.path.join(path, "vocabulary.spm")
    tokenizer_model = os.path.join(path, "tokenizer.model")
    if os.path.exists(vocab_spm):
        shutil.move(vocab_spm, tokenizer_model)
    print("Export complete! Files saved in:", path)

Severity: critical

The functions convert_to_hf_config and export_to_hf duplicate the Hugging Face export logic that is already being added in keras_hub/src/utils/transformers/export/. This introduces significant code duplication and makes future maintenance difficult.

This implementation also uses torch and numpy directly for tensor manipulations, which violates the repository's style guide principle of being backend-agnostic.1

Please remove these duplicated functions and instead use the export_to_transformers method available on the Keras model. The logic in the main function at line 780 should be updated to call this method. For example:

# In main()
# ...
preprocessor = keras_hub.models.Gemma3CausalLMPreprocessor(tokenizer=keras_tokenizer)
causal_lm = keras_hub.models.Gemma3CausalLM(
    backbone=keras_model,
    preprocessor=preprocessor,
)
causal_lm.export_to_transformers(export_dir)
# ...

Style Guide References

Footnotes

  1. All code must be backend agnostic. The duplicated code uses torch-specific operations, violating this principle.
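
To tie this suggestion back to the --export_safetensors flag described in the PR summary, here is a hedged sketch of how main() could gate the export on that flag. It assumes the script reads flags through absl (FLAGS.export_safetensors) and that keras_model, keras_tokenizer, and export_dir already exist in main(); the real script may differ.

# Sketch only: FLAGS, keras_model, keras_tokenizer, and export_dir are
# assumed to be defined elsewhere in the conversion script.
if FLAGS.export_safetensors:
    preprocessor = keras_hub.models.Gemma3CausalLMPreprocessor(
        tokenizer=keras_tokenizer
    )
    causal_lm = keras_hub.models.Gemma3CausalLM(
        backbone=keras_model,
        preprocessor=preprocessor,
    )
    causal_lm.export_to_transformers(export_dir)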

Comment on lines +105 to +128
if hasattr(block, "post_attention_norm"):
    post_attn_norm = block.post_attention_norm.weights[0]
else:
    # Fallback to pre_ffw_norm if post_attention_norm doesn't exist
    post_attn_norm = block.pre_ffw_norm.weights[0]
weights_dict[f"{prefix}layers.{i}.post_attention_layernorm.weight"] = (
    post_attn_norm
)

# Pre-feedforward normalization
pre_feedforward_layernorm = block.pre_ffw_norm.weights[0]
weights_dict[f"{prefix}layers.{i}.pre_feedforward_layernorm.weight"] = (
    pre_feedforward_layernorm
)

# Post-feedforward normalization
if hasattr(block, "post_ffw_norm"):
    post_feedforward_layernorm = block.post_ffw_norm.weights[0]
else:
    # Fallback to pre_ffw_norm if post_ffw_norm doesn't exist
    post_feedforward_layernorm = block.pre_ffw_norm.weights[0]
weights_dict[
    f"{prefix}layers.{i}.post_feedforward_layernorm.weight"
] = post_feedforward_layernorm

Severity: high

The fallback logic for post_attention_norm and post_ffw_norm appears to be incorrect. If these layers do not exist on the block (likely because the model was configured with use_post_attention_norm=False or use_post_ffw_norm=False), the Hugging Face model would not expect weights for the corresponding layernorms. Assigning weights from pre_ffw_norm in these cases could lead to a functionally incorrect model.

The weights should only be exported if the corresponding layers exist. Please remove the else blocks for both post_attention_layernorm and post_feedforward_layernorm.

Suggested change

Before:

if hasattr(block, "post_attention_norm"):
    post_attn_norm = block.post_attention_norm.weights[0]
else:
    # Fallback to pre_ffw_norm if post_attention_norm doesn't exist
    post_attn_norm = block.pre_ffw_norm.weights[0]
weights_dict[f"{prefix}layers.{i}.post_attention_layernorm.weight"] = (
    post_attn_norm
)
# Pre-feedforward normalization
pre_feedforward_layernorm = block.pre_ffw_norm.weights[0]
weights_dict[f"{prefix}layers.{i}.pre_feedforward_layernorm.weight"] = (
    pre_feedforward_layernorm
)
# Post-feedforward normalization
if hasattr(block, "post_ffw_norm"):
    post_feedforward_layernorm = block.post_ffw_norm.weights[0]
else:
    # Fallback to pre_ffw_norm if post_ffw_norm doesn't exist
    post_feedforward_layernorm = block.pre_ffw_norm.weights[0]
weights_dict[
    f"{prefix}layers.{i}.post_feedforward_layernorm.weight"
] = post_feedforward_layernorm

After:

if hasattr(block, "post_attention_norm"):
    post_attn_norm = block.post_attention_norm.weights[0]
    weights_dict[f"{prefix}layers.{i}.post_attention_layernorm.weight"] = (
        post_attn_norm
    )
# Pre-feedforward normalization
pre_feedforward_layernorm = block.pre_ffw_norm.weights[0]
weights_dict[f"{prefix}layers.{i}.pre_feedforward_layernorm.weight"] = (
    pre_feedforward_layernorm
)
# Post-feedforward normalization
if hasattr(block, "post_ffw_norm"):
    post_feedforward_layernorm = block.post_ffw_norm.weights[0]
    weights_dict[
        f"{prefix}layers.{i}.post_feedforward_layernorm.weight"
    ] = post_feedforward_layernorm

Comment on lines +50 to +68
q_kernel = block.attention.query_dense.weights[0]
q_kernel = ops.transpose(q_kernel, axes=(1, 0, 2)) # permute(1, 0, 2)
q_kernel = ops.reshape(q_kernel, (backbone.hidden_dim, -1))
q_kernel = ops.transpose(q_kernel) # .T
weights_dict[f"{prefix}layers.{i}.self_attn.q_proj.weight"] = q_kernel

# Attention key projection
k_kernel = block.attention.key_dense.weights[0]
k_kernel = ops.transpose(k_kernel, axes=(1, 0, 2)) # permute(1, 0, 2)
k_kernel = ops.reshape(k_kernel, (backbone.hidden_dim, -1))
k_kernel = ops.transpose(k_kernel) # .T
weights_dict[f"{prefix}layers.{i}.self_attn.k_proj.weight"] = k_kernel

# Attention value projection
v_kernel = block.attention.value_dense.weights[0]
v_kernel = ops.transpose(v_kernel, axes=(1, 0, 2)) # permute(1, 0, 2)
v_kernel = ops.reshape(v_kernel, (backbone.hidden_dim, -1))
v_kernel = ops.transpose(v_kernel) # .T
weights_dict[f"{prefix}layers.{i}.self_attn.v_proj.weight"] = v_kernel

Severity: medium

The logic for converting the query, key, and value projection kernels is identical across these blocks. This repetition can be refactored into a private helper function to improve code clarity and maintainability, adhering to the DRY (Don't Repeat Yourself) principle.

For example, you could define a helper like _convert_qkv_kernel(kernel, hidden_dim) and call it for each of the q_proj, k_proj, and v_proj weights.
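
A minimal sketch of such a helper, under the signature the reviewer suggests; it just factors out the transpose/reshape sequence from the snippet above and assumes ops refers to keras.ops as in that snippet.

def _convert_qkv_kernel(kernel, hidden_dim):
    # Flatten a per-head projection kernel into the 2D
    # (num_heads * head_dim, hidden_dim) layout expected by the HF checkpoint.
    kernel = ops.transpose(kernel, axes=(1, 0, 2))
    kernel = ops.reshape(kernel, (hidden_dim, -1))
    return ops.transpose(kernel)

Usage for the query projection would then reduce to:

q_kernel = _convert_qkv_kernel(
    block.attention.query_dense.weights[0], backbone.hidden_dim
)
weights_dict[f"{prefix}layers.{i}.self_attn.q_proj.weight"] = q_kernel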
