Add T5Gemma to KerasHub #2339
Merged: divyashreepathihalli merged 20 commits into keras-team:master from harshaljanjani:t5gemma on Aug 25, 2025.
Commits (20, all authored by harshaljanjani):
- `071d0df` init: Add initial project structure and files
- `1c9ebbc` nit: Fix code format test; and cool AI-generated reviews
- `1c7dc13` refactor: Cleanup and replace incorrect T5LayerNorm with RMSNormaliza…
- `41910d3` fix: Numerics @ atol=1e-4
- `a8eb53c` refactor: Refactor T5Gemma decoder cache handling
- `95f563b` feat: Add checkpoint conversion script
- `afb9845` nit: Precise compute_output_shape methods; document head_dim
- `5be6438` nit: Propagate dtypes
- `3dbc0b7` bug fix + minor cleanup: Fix head_dim default → head_dim from config
- `291d8f1` perf(jax/tpu): Fused kernel optim for TPU backend + get_config() args
- `524aa37` cleanup: Slight refactor
- `c1af495` Merge branch 'keras-team:master' into t5gemma
- `889e23b` fix: Enable mixed precision and quantization tests
- `32a6912` feat: Add support for asymmetrical presets (only invariants included)
- `050910b` refactor: Address reviews - presets will be handled post D-FINE
- `6b320fa` feat: Support direct loading of Hugging Face checkpoints (see the sketch after this list)
- `26db4d1` ✅ Yayy: Generate outputs identical, hidden states match within 1e-3
- `87a221d` preset test: Register and test a preset (to be replaced later by the …)
- `9c79058` nit: Sharded weights don’t include `model.weights.h5`
- `f7e356f` nits: Address reviews + replace gated model
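One of the commits above adds direct loading of Hugging Face checkpoints. As a rough orientation rather than the PR's documented usage, KerasHub backbones generally expose this through `from_preset` with an `hf://` handle; the repository id below is a placeholder, not a real checkpoint name:

```python
import keras_hub

# Placeholder repository id; substitute a real T5Gemma checkpoint on the
# Hugging Face Hub. The hf:// handle routes through KerasHub's checkpoint
# conversion machinery, which this PR extends for T5Gemma.
backbone = keras_hub.models.T5GemmaBackbone.from_preset(
    "hf://google/t5gemma-placeholder"
)
```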
New file (236 lines added): the `T5GemmaBackbone` implementation.

```python
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.modeling.reversible_embedding import (
    ReversibleEmbedding,
)
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.t5.t5_layer_norm import T5LayerNorm
from keras_hub.src.models.t5gemma.t5gemma_decoder import T5GemmaDecoderLayer
from keras_hub.src.models.t5gemma.t5gemma_encoder import T5GemmaEncoderLayer


@keras_hub_export("keras_hub.models.T5GemmaBackbone")
class T5GemmaBackbone(Backbone):
    """T5Gemma backbone model.

    This class implements the encoder-decoder backbone of the T5Gemma model,
    consisting of an embedding layer, a stack of encoder layers, and a
    stack of decoder layers.

    Args:
        vocabulary_size: int, The size of the vocabulary.
        hidden_dim: int, The dimensionality of the hidden states throughout
            the model.
        intermediate_dim: int, The intermediate size of the feed-forward
            networks in encoder and decoder layers.
        num_layers: int, The number of encoder and decoder layers.
        num_attention_heads: int, The number of attention heads in all
            attention mechanisms.
        num_key_value_heads: int, The number of key-value heads for grouped
            query attention in all attention mechanisms.
        dropout_rate: float, The dropout rate applied throughout the model.
        rms_norm_eps: float, The epsilon value for RMS normalization.
        query_pre_attn_scalar: float, Scalar to multiply queries by before
            attention.
        attention_bias: bool, Whether to include bias in attention
            computations.
        hidden_activation: str, The activation function used in the
            feed-forward networks.
        layer_types: list of str, A list of strings specifying the type of
            attention layer for each encoder/decoder layer. Each element can
            be either `"sliding_attention"` or `"full_attention"`. For
            example, `["full_attention", "sliding_attention", ...]`.
        tie_word_embeddings: bool, Whether to tie input and output word
            embeddings. Default is `True`.
        initializer_range: float, The range for the random normal
            initializer. Default is `0.02`.
        attention_dropout: float, The dropout rate applied to attention
            weights. Default is `0.0`.
        sliding_window: int, optional, The window size for sliding attention.
            Required if any `layer_type` is `"sliding_attention"`.
        cross_attention_hidden_size: int, optional, The hidden size for
            cross-attention in the decoder layers. If None, it defaults to
            `hidden_dim`.
        attn_logit_softcapping: float, optional, The softcapping value for
            attention logits.
        final_logit_softcapping: float, optional, The softcapping value for
            final logits.
        rope_max_wavelength: float, The maximum wavelength for Rotary
            Positional Embeddings. Default is `10000.0`.
        **kwargs: Additional keyword arguments passed to the parent
            `Backbone` class.
    """

    def __init__(
        self,
        vocabulary_size,
        hidden_dim,
        intermediate_dim,
        num_layers,
        num_attention_heads,
        num_key_value_heads,
        dropout_rate,
        rms_norm_eps,
        query_pre_attn_scalar,
        attention_bias,
        hidden_activation,
        layer_types,
        tie_word_embeddings=True,
        initializer_range=0.02,
        attention_dropout=0.0,
        sliding_window=None,
        cross_attention_hidden_size=None,
        attn_logit_softcapping=None,
        final_logit_softcapping=None,
        rope_max_wavelength=10000.0,
        **kwargs,
    ):
        # === Layers ===
        self.token_embedding = ReversibleEmbedding(
            input_dim=vocabulary_size,
            output_dim=hidden_dim,
            tie_weights=tie_word_embeddings,
        )
        self.encoder_layers = [
            T5GemmaEncoderLayer(
                hidden_size=hidden_dim,
                rms_norm_eps=rms_norm_eps,
                num_attention_heads=num_attention_heads,
                num_key_value_heads=num_key_value_heads,
                query_pre_attn_scalar=query_pre_attn_scalar,
                attention_bias=attention_bias,
                intermediate_size=intermediate_dim,
                hidden_activation=hidden_activation,
                dropout_rate=dropout_rate,
                initializer_range=initializer_range,
                attention_dropout=attention_dropout,
                layer_type=layer_types[i],
                sliding_window=sliding_window,
                attn_logit_softcapping=attn_logit_softcapping,
                rope_max_wavelength=rope_max_wavelength,
                name=f"encoder_layer_{i}",
            )
            for i in range(num_layers)
        ]
        self.encoder_norm = T5LayerNorm(epsilon=rms_norm_eps)
        self.encoder_dropout = keras.layers.Dropout(dropout_rate)
        self.decoder_layers = [
            T5GemmaDecoderLayer(
                hidden_size=hidden_dim,
                rms_norm_eps=rms_norm_eps,
                num_attention_heads=num_attention_heads,
                num_key_value_heads=num_key_value_heads,
                query_pre_attn_scalar=query_pre_attn_scalar,
                attention_bias=attention_bias,
                intermediate_size=intermediate_dim,
                hidden_activation=hidden_activation,
                dropout_rate=dropout_rate,
                initializer_range=initializer_range,
                attention_dropout=attention_dropout,
                layer_type=layer_types[i],
                sliding_window=sliding_window,
                cross_attention_hidden_size=cross_attention_hidden_size,
                attn_logit_softcapping=attn_logit_softcapping,
                rope_max_wavelength=rope_max_wavelength,
                name=f"decoder_layer_{i}",
            )
            for i in range(num_layers)
        ]
        self.decoder_norm = T5LayerNorm(epsilon=rms_norm_eps)
        self.decoder_dropout = keras.layers.Dropout(dropout_rate)

        # === Functional Model ===
        token_id_input = keras.Input(
            shape=(None,), dtype="int32", name="token_ids"
        )
        padding_mask_input = keras.Input(
            shape=(None,), dtype="int32", name="padding_mask"
        )

        # Encoder.
        encoder_embeddings = self.token_embedding(token_id_input)
        encoder_embeddings = encoder_embeddings * keras.ops.cast(
            keras.ops.sqrt(hidden_dim), encoder_embeddings.dtype
        )
        encoder_hidden_states = self.encoder_dropout(encoder_embeddings)
        for layer in self.encoder_layers:
            encoder_hidden_states = layer(
                encoder_hidden_states,
                padding_mask=padding_mask_input,
            )
        encoder_output = self.encoder_norm(encoder_hidden_states)
        encoder_output = self.encoder_dropout(encoder_output)

        # Decoder.
        decoder_embeddings = self.token_embedding(token_id_input)
        decoder_embeddings = decoder_embeddings * keras.ops.cast(
            keras.ops.sqrt(hidden_dim), decoder_embeddings.dtype
        )
        decoder_hidden_states = self.decoder_dropout(decoder_embeddings)
        for layer in self.decoder_layers:
            decoder_hidden_states, _ = layer(
                (decoder_hidden_states, encoder_output),
                self_attention_padding_mask=padding_mask_input,
                cross_attention_padding_mask=padding_mask_input,
            )
        decoder_output = self.decoder_norm(decoder_hidden_states)
        decoder_output = self.decoder_dropout(decoder_output)

        super().__init__(
            inputs={
                "token_ids": token_id_input,
                "padding_mask": padding_mask_input,
            },
            outputs=decoder_output,
            **kwargs,
        )

        # === Config ===
        self.vocabulary_size = vocabulary_size
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.num_layers = num_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.dropout_rate = dropout_rate
        self.rms_norm_eps = rms_norm_eps
        self.tie_word_embeddings = tie_word_embeddings
        self.query_pre_attn_scalar = query_pre_attn_scalar
        self.attention_bias = attention_bias
        self.hidden_activation = hidden_activation
        self.layer_types = layer_types
        self.initializer_range = initializer_range
        self.attention_dropout = attention_dropout
        self.sliding_window = sliding_window
        self.cross_attention_hidden_size = cross_attention_hidden_size
        self.attn_logit_softcapping = attn_logit_softcapping
        self.final_logit_softcapping = final_logit_softcapping
        self.rope_max_wavelength = rope_max_wavelength

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vocabulary_size": self.vocabulary_size,
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "num_layers": self.num_layers,
                "num_attention_heads": self.num_attention_heads,
                "num_key_value_heads": self.num_key_value_heads,
                "dropout_rate": self.dropout_rate,
                "rms_norm_eps": self.rms_norm_eps,
                "tie_word_embeddings": self.tie_word_embeddings,
                "query_pre_attn_scalar": self.query_pre_attn_scalar,
                "attention_bias": self.attention_bias,
                "hidden_activation": self.hidden_activation,
                "layer_types": self.layer_types,
                "initializer_range": self.initializer_range,
                "attention_dropout": self.attention_dropout,
                "sliding_window": self.sliding_window,
                "cross_attention_hidden_size": self.cross_attention_hidden_size,
                "attn_logit_softcapping": self.attn_logit_softcapping,
                "final_logit_softcapping": self.final_logit_softcapping,
                "rope_max_wavelength": self.rope_max_wavelength,
            }
        )
        return config
```
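For orientation, here is a minimal usage sketch (not part of the PR) that builds the backbone above with a tiny, illustrative configuration mirroring the unit test further below, then runs a forward pass:

```python
import keras

from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone

# Tiny configuration for demonstration only; mirrors the unit test below.
backbone = T5GemmaBackbone(
    vocabulary_size=100,
    hidden_dim=32,
    intermediate_dim=64,
    num_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    dropout_rate=0.1,
    rms_norm_eps=1e-6,
    query_pre_attn_scalar=1.0,
    attention_bias=False,
    hidden_activation="gelu_approximate",
    layer_types=["sliding_attention", "full_attention"],
    sliding_window=16,
)

# The functional model takes a dict of token ids plus a padding mask and
# returns the decoder's final hidden states.
outputs = backbone(
    {
        "token_ids": keras.ops.ones((2, 16), dtype="int32"),
        "padding_mask": keras.ops.ones((2, 16), dtype="int32"),
    }
)
print(outputs.shape)  # (2, 16, 32): (batch, sequence_length, hidden_dim)
```

Note that, in this initial version of the file, the same `token_ids` input feeds both the encoder and the decoder stacks.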
  
    
New file (52 lines added): unit tests for `T5GemmaBackbone`.

```python
import keras
import pytest

from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone
from keras_hub.src.tests.test_case import TestCase


class T5GemmaBackboneTest(TestCase):
    def setUp(self):
        self.init_kwargs = {
            "vocabulary_size": 100,
            "hidden_dim": 32,
            "intermediate_dim": 64,
            "num_layers": 2,
            "num_attention_heads": 4,
            "num_key_value_heads": 2,
            "dropout_rate": 0.1,
            "rms_norm_eps": 1e-6,
            "tie_word_embeddings": True,
            "query_pre_attn_scalar": 1.0,
            "attention_bias": False,
            "hidden_activation": "gelu_approximate",
            "layer_types": ["sliding_attention", "full_attention"],
            "sliding_window": 16,
            "cross_attention_hidden_size": 32,
            "attn_logit_softcapping": 50.0,
            "rope_max_wavelength": 10000.0,
            "initializer_range": 0.02,
            "attention_dropout": 0.0,
        }
        self.input_data = {
            "token_ids": keras.ops.ones((2, 16), dtype="int32"),
            "padding_mask": keras.ops.ones((2, 16), dtype="int32"),
        }

    def test_backbone_basics(self):
        self.run_backbone_test(
            cls=T5GemmaBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output_shape=(2, 16, 32),
            run_mixed_precision_check=False,
            run_quantization_check=False,
        )

    @pytest.mark.large
    def test_saved_model(self):
        self.run_model_saving_test(
            cls=T5GemmaBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
        )
```
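The `get_config()` override in the backbone is what the `test_saved_model` case exercises. The sketch below is generic Keras serialization, not code from the PR, and reuses the `backbone` instance from the earlier sketch:

```python
import keras

# Rebuild an identical architecture (with fresh weights) from the config.
config = backbone.get_config()
restored = T5GemmaBackbone.from_config(config)

# Full save/load round trip, roughly what `run_model_saving_test` checks.
backbone.save("t5gemma_backbone.keras")
reloaded = keras.saving.load_model("t5gemma_backbone.keras")
```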
      