Add T5Gemma to KerasHub #2339
Merged: divyashreepathihalli merged 20 commits into keras-team:master from harshaljanjani:t5gemma on Aug 25, 2025.
Changes shown are from 6 of the 20 commits.
Commits (all by harshaljanjani):

- 071d0df  init: Add initial project structure and files
- 1c9ebbc  nit: Fix code format test; and cool AI-generated reviews
- 1c7dc13  refactor: Cleanup and replace incorrect T5LayerNorm with RMSNormaliza…
- 41910d3  fix: Numerics @ atol=1e-4
- a8eb53c  refactor: Refactor T5Gemma decoder cache handling
- 95f563b  feat: Add checkpoint conversion script
- afb9845  nit: Precise compute_output_shape methods; document head_dim
- 5be6438  nit: Propagate dtypes
- 3dbc0b7  bug fix + minor cleanup: Fix head_dim default → head_dim from config
- 291d8f1  perf(jax/tpu): Fused kernel optim for TPU backend + get_config() args
- 524aa37  cleanup: Slight refactor
- c1af495  Merge branch 'keras-team:master' into t5gemma
- 889e23b  fix: Enable mixed precision and quantization tests
- 32a6912  feat: Add support for asymmetrical presets (only invariants included)
- 050910b  refactor: Address reviews - presets will be handled post D-FINE
- 6b320fa  feat: Support direct loading of Hugging Face checkpoints
- 26db4d1  ✅ Yayy: Generate outputs identical, hidden states match within 1e-3
- 87a221d  preset test: Register and test a preset (to be replaced later by the …
- 9c79058  nit: Sharded weights don’t include `model.weights.h5`
- f7e356f  nits: Address reviews + replace gated model
Diff: new file, 349 lines (@@ -0,0 +1,349 @@).
import keras

from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
from keras_hub.src.models.gemma.gemma_attention import CachedGemmaAttention
from keras_hub.src.models.t5gemma.t5gemma_layers import (
    t5gemma_kernel_initializer,
)
from keras_hub.src.utils.keras_utils import clone_initializer


def repeat_kv(hidden_states, n_rep):
    """Repeats the key/value hidden states to match the number of query heads
    for Grouped Query Attention (GQA).

    This function is used in `T5GemmaAttention` to broadcast key and value
    states across multiple query heads when Grouped Query Attention (GQA) is
    used (i.e., when `num_query_heads` > `num_key_value_heads`).

    Args:
        hidden_states: Tensor, The key or value hidden states with shape
            `(batch, num_key_value_heads, sequence_length, head_dim)`.
        n_rep: int, The number of times to repeat the key/value heads. This is
            typically `num_query_heads // num_key_value_heads`.

    Returns:
        Tensor: The expanded key/value hidden states with shape
            `(batch, num_query_heads, sequence_length, head_dim)`.
    """
    if n_rep == 1:
        return hidden_states
    batch, num_key_value_heads, slen, head_dim = keras.ops.shape(hidden_states)
    hidden_states = keras.ops.expand_dims(hidden_states, 2)
    hidden_states = keras.ops.tile(hidden_states, (1, 1, n_rep, 1, 1))
    return keras.ops.reshape(
        hidden_states, (batch, num_key_value_heads * n_rep, slen, head_dim)
    )
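

# Reviewer sketch (illustration only, not part of the PR diff): with
# `num_query_heads=8` and `num_key_value_heads=2`, GQA tiles each key/value
# head `n_rep = 8 // 2 = 4` times so the attention einsum sees matching head
# counts. `_repeat_kv_example` is a hypothetical helper with made-up shapes.
def _repeat_kv_example():
    # (batch, num_key_value_heads, sequence_length, head_dim)
    kv = keras.ops.zeros((1, 2, 16, 64))
    expanded = repeat_kv(kv, n_rep=4)
    # `expanded` now has 2 * 4 = 8 heads: shape (1, 8, 16, 64).
    return keras.ops.shape(expanded)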


@keras.saving.register_keras_serializable(package="keras_hub")
class T5GemmaAttention(CachedGemmaAttention):
    """A unified attention layer for T5Gemma that handles both self-attention
    and cross-attention.

    This layer performs attention with optional Rotary Positional Embeddings
    (RoPE) and supports Grouped Query Attention (GQA). It is used in
    `T5GemmaEncoderLayer` and `T5GemmaDecoderLayer`.

    Args:
        hidden_size: int, The dimensionality of the hidden states.
        num_attention_heads: int, The number of attention heads.
        num_key_value_heads: int, The number of key-value heads. For GQA, this
            can be less than `num_attention_heads`.
        query_pre_attn_scalar: float, Scalar to multiply queries by before
            attention.
        attention_bias: bool, Whether to include bias in the dense layers.
        head_dim: int, The dimensionality of each attention head.
        attention_type: str, The type of attention, either 'self' or 'cross'.
            Defaults to 'self'.
        cross_attention_hidden_size: int, optional, The dimensionality of
            encoder hidden states for cross-attention.
        initializer_range: float, The range for the random normal initializer
            for kernel weights. Default is `0.02`.
        attention_dropout: float, The dropout rate applied to attention weights.
            Default is `0.0`.
        attn_logit_softcapping: float, optional, The softcapping value for
            attention logits.
        rope_max_wavelength: float, The maximum wavelength for Rotary Positional
            Embeddings. Default is `10000.0`. Only used for self-attention.
        **kwargs: Additional keyword arguments passed to the parent class.
    """

    def __init__(
        self,
        hidden_size,
        num_attention_heads,
        num_key_value_heads,
        query_pre_attn_scalar,
        attention_bias,
        head_dim,
        attention_type="self",
        cross_attention_hidden_size=None,
        initializer_range=0.02,
        attention_dropout=0.0,
        attn_logit_softcapping=None,
        rope_max_wavelength=10000.0,
        **kwargs,
    ):
        super().__init__(
            head_dim=head_dim,
            num_query_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            kernel_initializer=t5gemma_kernel_initializer(initializer_range),
            logit_soft_cap=attn_logit_softcapping,
            dropout=attention_dropout,
            query_head_dim_normalize=False,
            use_sliding_window_attention=False,
            **kwargs,
        )
        if attention_type not in ["self", "cross"]:
            raise ValueError(
                f"attention_type must be 'self' or 'cross', but got "
                f"{attention_type}"
            )
        self.attention_type = attention_type
        self.hidden_size = hidden_size
        self.cross_attention_hidden_size = (
            cross_attention_hidden_size or hidden_size
        )
        self.query_pre_attn_scalar = query_pre_attn_scalar
        self.attention_bias = attention_bias
        self.initializer_range = initializer_range
        self.attention_dropout = attention_dropout
        self.rope_max_wavelength = rope_max_wavelength
        self.num_key_value_groups = (
            self.num_query_heads // self.num_key_value_heads
        )
        self.scaling = self.query_pre_attn_scalar**-0.5
        if self.attention_type == "self":
            self.rotary_embedding = RotaryEmbedding(
                max_wavelength=self.rope_max_wavelength,
                sequence_axis=2,
                feature_axis=3,
                name="rotary_embedding",
            )

    def build(self, input_shape):
        self._kernel_initializer = t5gemma_kernel_initializer(
            self.initializer_range
        )

        if self.attention_type == "cross":
            hidden_states_shape, kv_states_shape = input_shape
        else:
            hidden_states_shape = input_shape
            kv_states_shape = input_shape
        # Query projection layer.
        self.hidden_dim = hidden_states_shape[-1]
        self.query_dense = keras.layers.EinsumDense(
            equation="btd,dnh->bnth",
            output_shape=(self.num_query_heads, None, self.head_dim),
            kernel_initializer=clone_initializer(self._kernel_initializer),
            bias_axes="nh" if self.attention_bias else None,
            dtype=self.dtype_policy,
            name="query",
        )
        self.query_dense.build(hidden_states_shape)

        # Key projection layer.
        self.key_dense = keras.layers.EinsumDense(
            equation="bsd,dkh->bksh",
            output_shape=(self.num_key_value_heads, None, self.head_dim),
            kernel_initializer=clone_initializer(self._kernel_initializer),
            bias_axes="kh" if self.attention_bias else None,
            dtype=self.dtype_policy,
            name="key",
        )
        self.key_dense.build(kv_states_shape)

        # Value projection layer.
        self.value_dense = keras.layers.EinsumDense(
            equation="bsd,dkh->bksh",
            output_shape=(self.num_key_value_heads, None, self.head_dim),
            kernel_initializer=clone_initializer(self._kernel_initializer),
            bias_axes="kh" if self.attention_bias else None,
            dtype=self.dtype_policy,
            name="value",
        )
        self.value_dense.build(kv_states_shape)

        # Output projection layer.
        self.output_dense = keras.layers.EinsumDense(
            equation="bnth,nhd->btd",
            output_shape=(None, self.hidden_dim),
            kernel_initializer=clone_initializer(self._kernel_initializer),
            bias_axes="d" if self.attention_bias else None,
            dtype=self.dtype_policy,
            name="attention_output",
        )
        self.output_dense.build(
            (
                hidden_states_shape[0],
                self.num_query_heads,
                hidden_states_shape[1],
                self.head_dim,
            )
        )
        self.dropout_layer = keras.layers.Dropout(
            rate=self.attention_dropout,
            dtype=self.dtype_policy,
        )
        self.softmax = keras.layers.Softmax(dtype="float32")
        self.built = True
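        # Reviewer note (descriptive comments, not in the diff): the einsum
        # equations above read as follows.
        #   "btd,dnh->bnth": (batch, q_len, hidden) x (hidden, n_heads, head_dim)
        #       -> (batch, n_heads, q_len, head_dim) for queries.
        #   "bsd,dkh->bksh": (batch, kv_len, hidden) x (hidden, kv_heads, head_dim)
        #       -> (batch, kv_heads, kv_len, head_dim) for keys and values.
        #   "bnth,nhd->btd": folds the per-head outputs back into `hidden_dim`.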

    def call(
        self,
        inputs,
        attention_mask=None,
        cache=None,
        cache_update_index=None,
        training=None,
    ):
        if self.attention_type == "cross":
            if not isinstance(inputs, (list, tuple)) or len(inputs) != 2:
                raise ValueError(
                    "For cross-attention, `inputs` must be a list or tuple of "
                    "two tensors: `[hidden_states, encoder_hidden_states]`."
                )
            hidden_states, kv_states = inputs
            query_states = self.query_dense(hidden_states)
            if cache is not None:
                if cache_update_index is not None:
                    raise ValueError(
                        "`cache_update_index` should not be set for "
                        "cross-attention caching."
                    )
                key_states, value_states = cache[:, 0, ...], cache[:, 1, ...]
                updated_cache = cache
            else:
                key_states = self.key_dense(kv_states)
                value_states = self.value_dense(kv_states)
                updated_cache = keras.ops.stack(
                    (key_states, value_states), axis=1
                )
            # Repeat key-value heads for GQA.
            key_states = repeat_kv(key_states, self.num_key_value_groups)
            value_states = repeat_kv(value_states, self.num_key_value_groups)
            attn_weights = keras.ops.einsum(
                "bnth,bnsh->bnts", query_states, key_states
            )
            attn_weights *= self.scaling
            if self.logit_soft_cap is not None:
                attn_weights = attn_weights / self.logit_soft_cap
                attn_weights = keras.ops.tanh(attn_weights)
                attn_weights = attn_weights * self.logit_soft_cap
            if attention_mask is not None:
                attn_weights += attention_mask
            attn_weights = keras.ops.cast(
                self.softmax(attn_weights),
                query_states.dtype,
            )
            attn_weights = self.dropout_layer(attn_weights, training=training)
            attn_output = keras.ops.einsum(
                "bnts,bnsh->bnth", attn_weights, value_states
            )
            attn_output = self.output_dense(attn_output)
            return (attn_output, attn_weights), updated_cache
        else:  # Self-attention
            hidden_states = inputs
            kv_states = hidden_states
            query_states = self.query_dense(hidden_states)
            key_states = self.key_dense(kv_states)
            value_states = self.value_dense(kv_states)
            start_index = (
                0 if cache_update_index is None else cache_update_index
            )
            query_states = self.rotary_embedding(
                query_states, start_index=start_index
            )
            key_states = self.rotary_embedding(
                key_states, start_index=start_index
            )
            current_pass_cache = keras.ops.stack(
                (key_states, value_states), axis=1
            )
            if cache is not None:
                if cache_update_index is None:
                    raise ValueError(
                        "Both `cache` and `cache_update_index` must be passed "
                        "for self-attention caching."
                    )
                key_cache, value_cache = cache[:, 0, ...], cache[:, 1, ...]
                start = [0, 0, cache_update_index, 0]
                key_states = keras.ops.slice_update(
                    key_cache, start, key_states
                )
                value_states = keras.ops.slice_update(
                    value_cache, start, value_states
                )
                cache = keras.ops.stack((key_states, value_states), axis=1)
            elif cache_update_index is not None:
                raise ValueError(
                    "`cache_update_index` should not be set if `cache` is "
                    "`None`."
                )
            else:
                cache = current_pass_cache

            # Repeat key-value heads for GQA.
            key_states = repeat_kv(key_states, self.num_key_value_groups)
            value_states = repeat_kv(value_states, self.num_key_value_groups)

            attn_weights = keras.ops.einsum(
                "bnth,bnsh->bnts", query_states, key_states
            )
            attn_weights *= self.scaling

            if self.logit_soft_cap is not None:
                attn_weights = attn_weights / self.logit_soft_cap
                attn_weights = keras.ops.tanh(attn_weights)
                attn_weights = attn_weights * self.logit_soft_cap
            if attention_mask is not None:
                attn_weights += attention_mask

            attn_weights = keras.ops.cast(
                self.softmax(attn_weights),
                query_states.dtype,
            )
            attn_weights = self.dropout_layer(attn_weights, training=training)
            attn_output = keras.ops.einsum(
                "bnts,bnsh->bnth", attn_weights, value_states
            )
            attn_output = self.output_dense(attn_output)
            return (attn_output, attn_weights), cache
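        # Reviewer note (descriptive comments, not in the diff): both branches
        # return `(attn_output, attn_weights)` plus a cache tensor of shape
        # (batch, 2, num_key_value_heads, kv_len, head_dim); index 0 along
        # axis 1 holds the keys and index 1 the values. Cross-attention reuses
        # the cached keys/values as-is, while self-attention writes new entries
        # at `cache_update_index` via `keras.ops.slice_update`.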

    def compute_output_shape(self, input_shape):
        if self.attention_type == "cross":
            hidden_states_shape, kv_states_shape = input_shape
        else:
            hidden_states_shape = input_shape
            kv_states_shape = input_shape
        attn_output_shape = hidden_states_shape
        q_len = hidden_states_shape[1]
        kv_len = kv_states_shape[1]
        attn_weights_shape = (
            hidden_states_shape[0],
            self.num_query_heads,
            q_len,
            kv_len,
        )
        return attn_output_shape, attn_weights_shape

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_size": self.hidden_size,
                "head_dim": self.head_dim,
                "num_attention_heads": self.num_query_heads,
                "num_key_value_heads": self.num_key_value_heads,
                "query_pre_attn_scalar": self.query_pre_attn_scalar,
                "attention_bias": self.attention_bias,
                "attention_type": self.attention_type,
                "cross_attention_hidden_size": self.cross_attention_hidden_size,
                "initializer_range": self.initializer_range,
                "attention_dropout": self.attention_dropout,
                "attn_logit_softcapping": self.logit_soft_cap,
                "rope_max_wavelength": self.rope_max_wavelength,
            }
        )
        return config
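
To make the call contract above concrete, here is a minimal sketch that instantiates the layer for self-attention and runs one forward pass. The hyperparameter values, the import path (`t5gemma_attention`), and the eager call are illustrative assumptions, not part of the PR.

import keras

# Assumed module path for the file added in this PR.
from keras_hub.src.models.t5gemma.t5gemma_attention import T5GemmaAttention

# Hypothetical, small configuration; values are illustrative only.
layer = T5GemmaAttention(
    hidden_size=64,
    num_attention_heads=4,
    num_key_value_heads=2,
    query_pre_attn_scalar=16.0,
    attention_bias=False,
    head_dim=16,
)
x = keras.random.normal((2, 8, 64))  # (batch, seq_len, hidden_size)
(attn_output, attn_weights), cache = layer(x)
# attn_output: (2, 8, 64); attn_weights: (2, 4, 8, 8);
# cache: (2, 2, 2, 8, 16), holding keys and values for incremental decoding.

For cross-attention, the same layer would instead be constructed with `attention_type="cross"` and called on a two-element list `[hidden_states, encoder_hidden_states]`, as enforced by the validation in `call`.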