Merged
20 commits
071d0df
init: Add initial project structure and files
harshaljanjani Jul 19, 2025
1c9ebbc
nit: Fix code format test; and cool AI-generated reviews
harshaljanjani Jul 19, 2025
1c7dc13
refactor: Cleanup and replace incorrect T5LayerNorm with RMSNormaliza…
harshaljanjani Jul 21, 2025
41910d3
fix: Numerics @ atol=1e-4
harshaljanjani Jul 22, 2025
a8eb53c
refactor: Refactor T5Gemma decoder cache handling
harshaljanjani Jul 23, 2025
95f563b
feat: Add checkpoint conversion script
harshaljanjani Jul 23, 2025
afb9845
nit: Precise compute_output_shape methods; document head_dim
harshaljanjani Jul 24, 2025
5be6438
nit: Propagate dtypes
harshaljanjani Jul 24, 2025
3dbc0b7
bug fix + minor cleanup: Fix head_dim default → head_dim from config
harshaljanjani Jul 24, 2025
291d8f1
perf(jax/tpu): Fused kernel optim for TPU backend + get_config() args
harshaljanjani Jul 25, 2025
524aa37
cleanup: Slight refactor
harshaljanjani Jul 25, 2025
c1af495
Merge branch 'keras-team:master' into t5gemma
harshaljanjani Jul 26, 2025
889e23b
fix: Enable mixed precision and quantization tests
harshaljanjani Jul 30, 2025
32a6912
feat: Add support for asymmetrical presets (only invariants included)
harshaljanjani Jul 30, 2025
050910b
refactor: Address reviews - presets will be handled post D-FINE
harshaljanjani Aug 6, 2025
6b320fa
feat: Support direct loading of Hugging Face checkpoints
harshaljanjani Aug 17, 2025
26db4d1
✅ Yayy: Generate outputs identical, hidden states match within 1e-3
harshaljanjani Aug 21, 2025
87a221d
preset test: Register and test a preset (to be replaced later by the …
harshaljanjani Aug 22, 2025
9c79058
nit: Sharded weights don’t include `model.weights.h5`
harshaljanjani Aug 24, 2025
f7e356f
nits: Address reviews + replace gated model
harshaljanjani Aug 25, 2025
480 changes: 480 additions & 0 deletions keras_hub/src/models/t5gemma/t5gemma_attention.py

Large diffs are not rendered by default.

236 changes: 236 additions & 0 deletions keras_hub/src/models/t5gemma/t5gemma_backbone.py
@@ -0,0 +1,236 @@
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.modeling.reversible_embedding import (
    ReversibleEmbedding,
)
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.t5.t5_layer_norm import T5LayerNorm
from keras_hub.src.models.t5gemma.t5gemma_decoder import T5GemmaDecoderLayer
from keras_hub.src.models.t5gemma.t5gemma_encoder import T5GemmaEncoderLayer


@keras_hub_export("keras_hub.models.T5GemmaBackbone")
class T5GemmaBackbone(Backbone):
    """T5Gemma backbone model.

    This class implements the encoder-decoder backbone of the T5Gemma model,
    consisting of an embedding layer, a stack of encoder layers, and a
    stack of decoder layers.

    Args:
        vocabulary_size: int, The size of the vocabulary.
        hidden_dim: int, The dimensionality of the hidden states throughout
            the model.
        intermediate_dim: int, The intermediate size of the feed-forward
            networks in encoder and decoder layers.
        num_layers: int, The number of encoder and decoder layers.
        num_attention_heads: int, The number of attention heads in all
            attention mechanisms.
        num_key_value_heads: int, The number of key-value heads for grouped
            query attention in all attention mechanisms.
        dropout_rate: float, The dropout rate applied throughout the model.
        rms_norm_eps: float, The epsilon value for RMS normalization.
        query_pre_attn_scalar: float, Scalar to multiply queries by before
            attention.
        attention_bias: bool, Whether to include bias in attention
            computations.
        hidden_activation: str, The activation function used in the
            feed-forward networks.
        layer_types: list of str, A list of strings specifying the type of
            attention layer for each encoder/decoder layer. Each element can
            be either `"sliding_attention"` or `"full_attention"`. For
            example, `["full_attention", "sliding_attention", ...]`.
        tie_word_embeddings: bool, Whether to tie input and output word
            embeddings. Default is `True`.
        initializer_range: float, The range for the random normal initializer.
            Default is `0.02`.
        attention_dropout: float, The dropout rate applied to attention
            weights. Default is `0.0`.
        sliding_window: int, optional, The window size for sliding attention.
            Required if any `layer_type` is `"sliding_attention"`.
        cross_attention_hidden_size: int, optional, The hidden size for
            cross-attention in the decoder layers. If None, it defaults to
            `hidden_dim`.
        attn_logit_softcapping: float, optional, The softcapping value for
            attention logits.
        final_logit_softcapping: float, optional, The softcapping value for
            final logits.
        rope_max_wavelength: float, The maximum wavelength for Rotary
            Positional Embeddings. Default is `10000.0`.
        **kwargs: Additional keyword arguments passed to the parent `Backbone`
            class.
    """

    def __init__(
        self,
        vocabulary_size,
        hidden_dim,
        intermediate_dim,
        num_layers,
        num_attention_heads,
        num_key_value_heads,
        dropout_rate,
        rms_norm_eps,
        query_pre_attn_scalar,
        attention_bias,
        hidden_activation,
        layer_types,
        tie_word_embeddings=True,
        initializer_range=0.02,
        attention_dropout=0.0,
        sliding_window=None,
        cross_attention_hidden_size=None,
        attn_logit_softcapping=None,
        final_logit_softcapping=None,
        rope_max_wavelength=10000.0,
        **kwargs,
    ):
        # === Layers ===
        self.token_embedding = ReversibleEmbedding(
            input_dim=vocabulary_size,
            output_dim=hidden_dim,
            tie_weights=tie_word_embeddings,
        )
        self.encoder_layers = [
            T5GemmaEncoderLayer(
                hidden_size=hidden_dim,
                rms_norm_eps=rms_norm_eps,
                num_attention_heads=num_attention_heads,
                num_key_value_heads=num_key_value_heads,
                query_pre_attn_scalar=query_pre_attn_scalar,
                attention_bias=attention_bias,
                intermediate_size=intermediate_dim,
                hidden_activation=hidden_activation,
                dropout_rate=dropout_rate,
                initializer_range=initializer_range,
                attention_dropout=attention_dropout,
                layer_type=layer_types[i],
                sliding_window=sliding_window,
                attn_logit_softcapping=attn_logit_softcapping,
                rope_max_wavelength=rope_max_wavelength,
                name=f"encoder_layer_{i}",
            )
            for i in range(num_layers)
        ]
        self.encoder_norm = T5LayerNorm(epsilon=rms_norm_eps)
        self.encoder_dropout = keras.layers.Dropout(dropout_rate)
        self.decoder_layers = [
            T5GemmaDecoderLayer(
                hidden_size=hidden_dim,
                rms_norm_eps=rms_norm_eps,
                num_attention_heads=num_attention_heads,
                num_key_value_heads=num_key_value_heads,
                query_pre_attn_scalar=query_pre_attn_scalar,
                attention_bias=attention_bias,
                intermediate_size=intermediate_dim,
                hidden_activation=hidden_activation,
                dropout_rate=dropout_rate,
                initializer_range=initializer_range,
                attention_dropout=attention_dropout,
                layer_type=layer_types[i],
                sliding_window=sliding_window,
                cross_attention_hidden_size=cross_attention_hidden_size,
                attn_logit_softcapping=attn_logit_softcapping,
                rope_max_wavelength=rope_max_wavelength,
                name=f"decoder_layer_{i}",
            )
            for i in range(num_layers)
        ]
        self.decoder_norm = T5LayerNorm(epsilon=rms_norm_eps)
        self.decoder_dropout = keras.layers.Dropout(dropout_rate)

        # === Functional Model ===
        token_id_input = keras.Input(
            shape=(None,), dtype="int32", name="token_ids"
        )
        padding_mask_input = keras.Input(
            shape=(None,), dtype="int32", name="padding_mask"
        )

        # Encoder.
        encoder_embeddings = self.token_embedding(token_id_input)
        encoder_embeddings = encoder_embeddings * keras.ops.cast(
            keras.ops.sqrt(hidden_dim), encoder_embeddings.dtype
        )
        encoder_hidden_states = self.encoder_dropout(encoder_embeddings)
        for layer in self.encoder_layers:
            encoder_hidden_states = layer(
                encoder_hidden_states,
                padding_mask=padding_mask_input,
            )
        encoder_output = self.encoder_norm(encoder_hidden_states)
        encoder_output = self.encoder_dropout(encoder_output)

        # Decoder.
        decoder_embeddings = self.token_embedding(token_id_input)
        decoder_embeddings = decoder_embeddings * keras.ops.cast(
            keras.ops.sqrt(hidden_dim), decoder_embeddings.dtype
        )
        decoder_hidden_states = self.decoder_dropout(decoder_embeddings)
        for layer in self.decoder_layers:
            decoder_hidden_states, _ = layer(
                (decoder_hidden_states, encoder_output),
                self_attention_padding_mask=padding_mask_input,
                cross_attention_padding_mask=padding_mask_input,
            )
        decoder_output = self.decoder_norm(decoder_hidden_states)
        decoder_output = self.decoder_dropout(decoder_output)

        super().__init__(
            inputs={
                "token_ids": token_id_input,
                "padding_mask": padding_mask_input,
            },
            outputs=decoder_output,
            **kwargs,
        )

        # === Config ===
        self.vocabulary_size = vocabulary_size
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.num_layers = num_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.dropout_rate = dropout_rate
        self.rms_norm_eps = rms_norm_eps
        self.tie_word_embeddings = tie_word_embeddings
        self.query_pre_attn_scalar = query_pre_attn_scalar
        self.attention_bias = attention_bias
        self.hidden_activation = hidden_activation
        self.layer_types = layer_types
        self.initializer_range = initializer_range
        self.attention_dropout = attention_dropout
        self.sliding_window = sliding_window
        self.cross_attention_hidden_size = cross_attention_hidden_size
        self.attn_logit_softcapping = attn_logit_softcapping
        self.final_logit_softcapping = final_logit_softcapping
        self.rope_max_wavelength = rope_max_wavelength

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vocabulary_size": self.vocabulary_size,
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "num_layers": self.num_layers,
                "num_attention_heads": self.num_attention_heads,
                "num_key_value_heads": self.num_key_value_heads,
                "dropout_rate": self.dropout_rate,
                "rms_norm_eps": self.rms_norm_eps,
                "tie_word_embeddings": self.tie_word_embeddings,
                "query_pre_attn_scalar": self.query_pre_attn_scalar,
                "attention_bias": self.attention_bias,
                "hidden_activation": self.hidden_activation,
                "layer_types": self.layer_types,
                "initializer_range": self.initializer_range,
                "attention_dropout": self.attention_dropout,
                "cross_attention_hidden_size": self.cross_attention_hidden_size,
                "sliding_window": self.sliding_window,
                "attn_logit_softcapping": self.attn_logit_softcapping,
                "final_logit_softcapping": self.final_logit_softcapping,
                "rope_max_wavelength": self.rope_max_wavelength,
            }
        )
        return config
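For reference, a minimal usage sketch of the backbone defined above. It reuses the toy configuration from the unit test that follows; the values are illustrative only and do not correspond to a released preset.

import keras

from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone

# Toy configuration mirroring the test setup below; real checkpoints use
# much larger values.
backbone = T5GemmaBackbone(
    vocabulary_size=100,
    hidden_dim=32,
    intermediate_dim=64,
    num_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    dropout_rate=0.1,
    rms_norm_eps=1e-6,
    query_pre_attn_scalar=1.0,
    attention_bias=False,
    hidden_activation="gelu_approximate",
    layer_types=["sliding_attention", "full_attention"],
    sliding_window=16,
)
inputs = {
    "token_ids": keras.ops.ones((2, 16), dtype="int32"),
    "padding_mask": keras.ops.ones((2, 16), dtype="int32"),
}
# The functional model returns the decoder's final hidden states with shape
# (batch_size, sequence_length, hidden_dim).
outputs = backbone(inputs)
print(outputs.shape)  # (2, 16, 32)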
52 changes: 52 additions & 0 deletions keras_hub/src/models/t5gemma/t5gemma_backbone_test.py
@@ -0,0 +1,52 @@
import keras
import pytest

from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone
from keras_hub.src.tests.test_case import TestCase


class T5GemmaBackboneTest(TestCase):
    def setUp(self):
        self.init_kwargs = {
            "vocabulary_size": 100,
            "hidden_dim": 32,
            "intermediate_dim": 64,
            "num_layers": 2,
            "num_attention_heads": 4,
            "num_key_value_heads": 2,
            "dropout_rate": 0.1,
            "rms_norm_eps": 1e-6,
            "tie_word_embeddings": True,
            "query_pre_attn_scalar": 1.0,
            "attention_bias": False,
            "hidden_activation": "gelu_approximate",
            "layer_types": ["sliding_attention", "full_attention"],
            "sliding_window": 16,
            "cross_attention_hidden_size": 32,
            "attn_logit_softcapping": 50.0,
            "rope_max_wavelength": 10000.0,
            "initializer_range": 0.02,
            "attention_dropout": 0.0,
        }
        self.input_data = {
            "token_ids": keras.ops.ones((2, 16), dtype="int32"),
            "padding_mask": keras.ops.ones((2, 16), dtype="int32"),
        }

    def test_backbone_basics(self):
        self.run_backbone_test(
            cls=T5GemmaBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output_shape=(2, 16, 32),
            run_mixed_precision_check=False,
            run_quantization_check=False,
        )

    @pytest.mark.large
    def test_saved_model(self):
        self.run_model_saving_test(
            cls=T5GemmaBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
        )
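The `@pytest.mark.large` saving test above amounts to a standard Keras save/load round trip. A rough sketch of what it exercises, assuming the `backbone` and `inputs` objects from the earlier usage sketch and a hypothetical local path:

# Save to the Keras v3 format and reload; T5GemmaBackbone is registered as a
# serializable class via its keras_hub export, so no custom_objects are needed.
backbone.save("t5gemma_backbone.keras")  # hypothetical path, for illustration
restored = keras.models.load_model("t5gemma_backbone.keras")
# The restored model should produce the same outputs as the original.
restored_outputs = restored(inputs)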