From 071d0df19a97a9d51c77486dafbd32f6fcfae54e Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Sat, 19 Jul 2025 16:12:33 +0400 Subject: [PATCH 01/19] init: Add initial project structure and files --- .../src/models/t5gemma/t5gemma_attention.py | 480 ++++++++++++++++++ .../src/models/t5gemma/t5gemma_backbone.py | 236 +++++++++ .../models/t5gemma/t5gemma_backbone_test.py | 52 ++ .../src/models/t5gemma/t5gemma_causal_lm.py | 334 ++++++++++++ .../t5gemma/t5gemma_causal_lm_preprocessor.py | 72 +++ .../models/t5gemma/t5gemma_causal_lm_test.py | 141 +++++ .../src/models/t5gemma/t5gemma_decoder.py | 272 ++++++++++ .../src/models/t5gemma/t5gemma_encoder.py | 199 ++++++++ .../src/models/t5gemma/t5gemma_layers.py | 106 ++++ .../src/models/t5gemma/t5gemma_presets.py | 2 + .../src/models/t5gemma/t5gemma_tokenizer.py | 73 +++ 11 files changed, 1967 insertions(+) create mode 100644 keras_hub/src/models/t5gemma/t5gemma_attention.py create mode 100644 keras_hub/src/models/t5gemma/t5gemma_backbone.py create mode 100644 keras_hub/src/models/t5gemma/t5gemma_backbone_test.py create mode 100644 keras_hub/src/models/t5gemma/t5gemma_causal_lm.py create mode 100644 keras_hub/src/models/t5gemma/t5gemma_causal_lm_preprocessor.py create mode 100644 keras_hub/src/models/t5gemma/t5gemma_causal_lm_test.py create mode 100644 keras_hub/src/models/t5gemma/t5gemma_decoder.py create mode 100644 keras_hub/src/models/t5gemma/t5gemma_encoder.py create mode 100644 keras_hub/src/models/t5gemma/t5gemma_layers.py create mode 100644 keras_hub/src/models/t5gemma/t5gemma_presets.py create mode 100644 keras_hub/src/models/t5gemma/t5gemma_tokenizer.py diff --git a/keras_hub/src/models/t5gemma/t5gemma_attention.py b/keras_hub/src/models/t5gemma/t5gemma_attention.py new file mode 100644 index 0000000000..f063644757 --- /dev/null +++ b/keras_hub/src/models/t5gemma/t5gemma_attention.py @@ -0,0 +1,480 @@ +import keras + +from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding +from keras_hub.src.models.gemma.gemma_attention import CachedGemmaAttention +from keras_hub.src.models.t5gemma.t5gemma_layers import ( + t5gemma_kernel_initializer, +) +from keras_hub.src.utils.keras_utils import clone_initializer + + +def repeat_kv(hidden_states, n_rep): + """Repeats the key/value hidden states to match the number of query heads + for Grouped Query Attention (GQA). + + This function is used in `T5GemmaSelfAttention` and `T5GemmaCrossAttention` + to broadcast key and value states across multiple query heads when Grouped + Query Attention (GQA) is used (i.e., when `num_query_heads` > + `num_key_value_heads`). + + Args: + hidden_states: Tensor, The key or value hidden states with shape + `(batch, num_key_value_heads, sequence_length, head_dim)`. + n_rep: int, The number of times to repeat the key/value heads. This is + typically `num_query_heads // num_key_value_heads`. + + Returns: + Tensor: The expanded key/value hidden states with shape + `(batch, num_query_heads, sequence_length, head_dim)`. + """ + if n_rep == 1: + return hidden_states + batch, num_key_value_heads, slen, head_dim = keras.ops.shape(hidden_states) + hidden_states = keras.ops.expand_dims(hidden_states, 2) + hidden_states = keras.ops.tile(hidden_states, (1, 1, n_rep, 1, 1)) + return keras.ops.reshape( + hidden_states, (batch, num_key_value_heads * n_rep, slen, head_dim) + ) + + +@keras.saving.register_keras_serializable(package="keras_hub") +class T5GemmaSelfAttention(CachedGemmaAttention): + """Self-attention block for the T5Gemma model. + + This layer performs self-attention with Rotary Positional Embeddings (RoPE) + and supports Grouped Query Attention (GQA). It is used in + `T5GemmaEncoderLayer` and `T5GemmaDecoderLayer`. + + Args: + hidden_size: int, The dimensionality of the hidden states. + num_attention_heads: int, The number of attention heads. + num_key_value_heads: int, The number of key-value heads. For GQA, this + can be less than `num_attention_heads`. + query_pre_attn_scalar: float, Scalar to multiply queries by before + attention. + attention_bias: bool, Whether to include bias in the query, key, value, + and output dense layers. + initializer_range: float, The range for the random normal initializer + for kernel weights. Default is `0.02`. + attention_dropout: float, The dropout rate applied to attention weights. + Default is `0.0`. + attn_logit_softcapping: float, optional, The softcapping value for + attention logits. + rope_max_wavelength: float, The maximum wavelength for Rotary Positional + Embeddings. Default is `10000.0`. + **kwargs: Additional keyword arguments passed to the parent class. + """ + + def __init__( + self, + hidden_size, + num_attention_heads, + num_key_value_heads, + query_pre_attn_scalar, + attention_bias, + initializer_range=0.02, + attention_dropout=0.0, + attn_logit_softcapping=None, + rope_max_wavelength=10000.0, + **kwargs, + ): + super().__init__( + head_dim=hidden_size // num_attention_heads, + num_query_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + kernel_initializer=t5gemma_kernel_initializer(initializer_range), + logit_soft_cap=attn_logit_softcapping, + dropout=attention_dropout, + query_head_dim_normalize=False, + use_sliding_window_attention=False, + **kwargs, + ) + self.attention_dropout = attention_dropout + self.hidden_size = hidden_size + self.query_pre_attn_scalar = query_pre_attn_scalar + self.initializer_range = initializer_range + self.attention_bias = attention_bias + self.num_key_value_groups = ( + self.num_query_heads // self.num_key_value_heads + ) + self.scaling = self.query_pre_attn_scalar**-0.5 + self.rope_max_wavelength = rope_max_wavelength + self.rotary_embedding = RotaryEmbedding( + max_wavelength=self.rope_max_wavelength, + sequence_axis=2, + feature_axis=3, + name="rotary_embedding", + ) + + def build(self, input_shape): + self._kernel_initializer = t5gemma_kernel_initializer( + self.initializer_range + ) + + # Query projection layer. + self.hidden_dim = input_shape[-1] + self.query_dense = keras.layers.EinsumDense( + equation="...a,abc->...bc", + output_shape=(self.num_query_heads, self.head_dim), + kernel_initializer=clone_initializer(self._kernel_initializer), + bias_axes="bc" if self.attention_bias else None, + dtype=self.dtype_policy, + name="query", + ) + self.query_dense.build(input_shape) + + # Key projection layer. + self.key_dense = keras.layers.EinsumDense( + equation="...a,abc->...bc", + output_shape=(self.num_key_value_heads, self.head_dim), + kernel_initializer=clone_initializer(self._kernel_initializer), + bias_axes="bc" if self.attention_bias else None, + dtype=self.dtype_policy, + name="key", + ) + self.key_dense.build(input_shape) + + # Value projection layer. + self.value_dense = keras.layers.EinsumDense( + equation="...a,abc->...bc", + output_shape=(self.num_key_value_heads, self.head_dim), + kernel_initializer=clone_initializer(self._kernel_initializer), + bias_axes="bc" if self.attention_bias else None, + dtype=self.dtype_policy, + name="value", + ) + self.value_dense.build(input_shape) + + # Output projection layer. + self.output_dense = keras.layers.EinsumDense( + equation="...a,ab->...b", + output_shape=(self.hidden_dim,), + kernel_initializer=clone_initializer(self._kernel_initializer), + bias_axes="b" if self.attention_bias else None, + dtype=self.dtype_policy, + name="attention_output", + ) + self.output_dense.build( + (*input_shape[:-1], self.num_query_heads * self.head_dim) + ) + self.dropout_layer = keras.layers.Dropout( + rate=self.attention_dropout, + dtype=self.dtype_policy, + ) + q_len = input_shape[1] + attn_weights_shape = (None, self.num_query_heads, q_len, q_len) + self.dropout_layer.build(attn_weights_shape) + self.softmax = keras.layers.Softmax(dtype="float32") + self.built = True + + def call( + self, + hidden_states, + attention_mask=None, + cache=None, + cache_update_index=None, + training=None, + ): + query_states = self.query_dense(hidden_states) + query_states = keras.ops.transpose(query_states, (0, 2, 1, 3)) + key_states = self.key_dense(hidden_states) + key_states = keras.ops.transpose(key_states, (0, 2, 1, 3)) + value_states = self.value_dense(hidden_states) + value_states = keras.ops.transpose(value_states, (0, 2, 1, 3)) + start_index = 0 if cache_update_index is None else cache_update_index + query_states = self.rotary_embedding( + query_states, start_index=start_index + ) + key_states = self.rotary_embedding(key_states, start_index=start_index) + current_pass_cache = keras.ops.stack((key_states, value_states), axis=1) + if cache is not None: + if cache_update_index is None: + raise ValueError( + "Both `cache` and `cache_update_index` must be " + "passed for caching." + ) + key_cache = cache[:, 0, ...] + value_cache = cache[:, 1, ...] + start = [0, 0, cache_update_index, 0] + key_states = keras.ops.slice_update(key_cache, start, key_states) + value_states = keras.ops.slice_update( + value_cache, start, value_states + ) + cache = keras.ops.stack((key_states, value_states), axis=1) + elif cache_update_index is not None: + raise ValueError( + "`cache_update_index` should not be set if `cache` is `None`." + ) + else: + cache = current_pass_cache + + # Repeat key-value heads for GQA. + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = ( + keras.ops.matmul( + query_states, keras.ops.transpose(key_states, (0, 1, 3, 2)) + ) + * self.scaling + ) + + if self.logit_soft_cap is not None: + attn_weights = attn_weights / self.logit_soft_cap + attn_weights = keras.ops.tanh(attn_weights) + attn_weights = attn_weights * self.logit_soft_cap + if attention_mask is not None: + attn_weights += attention_mask + + attn_weights = keras.ops.cast( + self.softmax(attn_weights), + query_states.dtype, + ) + attn_weights = self.dropout_layer(attn_weights, training=training) + attn_output = keras.ops.matmul(attn_weights, value_states) + attn_output = keras.ops.transpose(attn_output, (0, 2, 1, 3)) + attn_output = keras.ops.reshape( + attn_output, + ( + keras.ops.shape(hidden_states)[0], + -1, + self.num_query_heads * self.head_dim, + ), + ) + attn_output = self.output_dense(attn_output) + return (attn_output, attn_weights), cache + + def compute_output_shape(self, input_shape): + attn_output_shape = input_shape + q_len = input_shape[1] + attn_weights_shape = ( + input_shape[0], + self.num_query_heads, + q_len, + q_len, + ) + return attn_output_shape, attn_weights_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "num_attention_heads": self.num_query_heads, + "num_key_value_heads": self.num_key_value_heads, + "query_pre_attn_scalar": self.query_pre_attn_scalar, + "attention_bias": self.attention_bias, + "initializer_range": self.initializer_range, + "attention_dropout": self.attention_dropout, + "attn_logit_softcapping": self.logit_soft_cap, + "rope_max_wavelength": self.rope_max_wavelength, + } + ) + return config + + +@keras.saving.register_keras_serializable(package="keras_hub") +class T5GemmaCrossAttention(keras.layers.Layer): + """Cross-attention block for the T5Gemma model. + + This layer performs cross-attention, where queries are derived from the + decoder hidden states and keys/values are from the encoder hidden states. + It supports Grouped Query Attention (GQA). It is used in + `T5GemmaDecoderLayer`. + + Args: + hidden_size: int, The dimensionality of the hidden states for queries + and output. + cross_attention_hidden_size: int, The dimensionality of the hidden + states from the encoder for keys and values. + num_attention_heads: int, The number of attention heads for queries. + num_key_value_heads: int, The number of key-value heads. For GQA, this + can be less than `num_attention_heads`. + query_pre_attn_scalar: float, Scalar to multiply queries by before + attention. + attention_bias: bool, Whether to include bias in the query, key, value, + and output dense layers. + initializer_range: float, The range for the random normal initializer + for kernel weights. Default is `0.02`. + attention_dropout: float, The dropout rate applied to attention weights. + Default is `0.0`. + attn_logit_softcapping: float, optional, The softcapping value for + attention logits. + **kwargs: Additional keyword arguments passed to the parent class. + """ + + def __init__( + self, + hidden_size, + cross_attention_hidden_size, + num_attention_heads, + num_key_value_heads, + query_pre_attn_scalar, + attention_bias, + initializer_range=0.02, + attention_dropout=0.0, + attn_logit_softcapping=None, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.cross_attention_hidden_size = ( + cross_attention_hidden_size or hidden_size + ) + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.query_pre_attn_scalar = query_pre_attn_scalar + self.initializer_range = initializer_range + self.attention_dropout = attention_dropout + self.attention_bias = attention_bias + self.attn_logit_softcapping = attn_logit_softcapping + + self.head_dim = self.hidden_size // self.num_attention_heads + self.num_key_value_groups = ( + self.num_attention_heads // self.num_key_value_heads + ) + self.scaling = self.query_pre_attn_scalar**-0.5 + + def build(self, input_shape): + hidden_states_shape, encoder_hidden_states_shape = input_shape + self.kernel_initializer = t5gemma_kernel_initializer( + self.initializer_range + ) + + # Query projection layer. + self.query_dense = keras.layers.EinsumDense( + equation="...a,abc->...bc", + output_shape=(self.num_attention_heads, self.head_dim), + kernel_initializer=clone_initializer(self.kernel_initializer), + bias_axes="bc" if self.attention_bias else None, + name="query", + ) + self.query_dense.build(hidden_states_shape) + cross_attn_proj_shape = ( + *encoder_hidden_states_shape[:-1], + self.cross_attention_hidden_size, + ) + + # Key projection layer. + self.key_dense = keras.layers.EinsumDense( + equation="...a,abc->...bc", + output_shape=(self.num_key_value_heads, self.head_dim), + kernel_initializer=clone_initializer(self.kernel_initializer), + bias_axes="bc" if self.attention_bias else None, + name="key", + ) + self.key_dense.build(cross_attn_proj_shape) + + # Value projection layer. + self.value_dense = keras.layers.EinsumDense( + equation="...a,abc->...bc", + output_shape=(self.num_key_value_heads, self.head_dim), + kernel_initializer=clone_initializer(self.kernel_initializer), + bias_axes="bc" if self.attention_bias else None, + name="value", + ) + self.value_dense.build(cross_attn_proj_shape) + + # Output projection layer. + self.output_dense = keras.layers.EinsumDense( + equation="...a,ab->...b", + output_shape=(self.hidden_size,), + kernel_initializer=clone_initializer(self.kernel_initializer), + bias_axes="b" if self.attention_bias else None, + name="attention_output", + ) + o_proj_input_shape = (*hidden_states_shape[:-1], self.hidden_size) + self.output_dense.build(o_proj_input_shape) + self.dropout_layer = keras.layers.Dropout(self.attention_dropout) + q_len = hidden_states_shape[1] + kv_len = encoder_hidden_states_shape[1] + attn_weights_shape = (None, self.num_attention_heads, q_len, kv_len) + self.dropout_layer.build(attn_weights_shape) + self.built = True + + def call( + self, + inputs, + attention_mask=None, + cache=None, + training=None, + ): + hidden_states, encoder_hidden_states = inputs + batch_size, q_seq_len = keras.ops.shape(hidden_states)[:2] + query_states = self.query_dense(hidden_states) + query_states = keras.ops.transpose(query_states, (0, 2, 1, 3)) + if cache is not None: + key_states = cache[:, 0, ...] + value_states = cache[:, 1, ...] + else: + key_states = self.key_dense(encoder_hidden_states) + key_states = keras.ops.transpose(key_states, (0, 2, 1, 3)) + value_states = self.value_dense(encoder_hidden_states) + value_states = keras.ops.transpose(value_states, (0, 2, 1, 3)) + + # Repeat key-value heads for GQA. + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = ( + keras.ops.matmul( + query_states, keras.ops.transpose(key_states, (0, 1, 3, 2)) + ) + * self.scaling + ) + + if self.attn_logit_softcapping is not None: + attn_weights = attn_weights / self.attn_logit_softcapping + attn_weights = keras.ops.tanh(attn_weights) + attn_weights = attn_weights * self.attn_logit_softcapping + if attention_mask is not None: + attn_weights += attention_mask + + attn_weights = keras.ops.cast( + keras.activations.softmax( + keras.ops.cast(attn_weights, "float32"), axis=-1 + ), + query_states.dtype, + ) + attn_weights = self.dropout_layer(attn_weights, training=training) + attn_output = keras.ops.matmul(attn_weights, value_states) + attn_output = keras.ops.transpose(attn_output, (0, 2, 1, 3)) + attn_output = keras.ops.reshape( + attn_output, (batch_size, q_seq_len, -1) + ) + attn_output = self.output_dense(attn_output) + if cache is not None: + updated_cache = keras.ops.stack((key_states, value_states), axis=1) + return (attn_output, attn_weights), updated_cache + else: + return attn_output, attn_weights + + def compute_output_shape(self, input_shape): + hidden_states_shape, encoder_hidden_states_shape = input_shape + attn_output_shape = hidden_states_shape + q_len = hidden_states_shape[1] + kv_len = encoder_hidden_states_shape[1] + attn_weights_shape = ( + hidden_states_shape[0], + self.num_attention_heads, + q_len, + kv_len, + ) + return attn_output_shape, attn_weights_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "cross_attention_hidden_size": self.cross_attention_hidden_size, + "num_attention_heads": self.num_attention_heads, + "num_key_value_heads": self.num_key_value_heads, + "query_pre_attn_scalar": self.query_pre_attn_scalar, + "attention_bias": self.attention_bias, + "initializer_range": self.initializer_range, + "attention_dropout": self.attention_dropout, + "attn_logit_softcapping": self.attn_logit_softcapping, + } + ) + return config diff --git a/keras_hub/src/models/t5gemma/t5gemma_backbone.py b/keras_hub/src/models/t5gemma/t5gemma_backbone.py new file mode 100644 index 0000000000..963b3b4e71 --- /dev/null +++ b/keras_hub/src/models/t5gemma/t5gemma_backbone.py @@ -0,0 +1,236 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.modeling.reversible_embedding import ( + ReversibleEmbedding, +) +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.t5.t5_layer_norm import T5LayerNorm +from keras_hub.src.models.t5gemma.t5gemma_decoder import T5GemmaDecoderLayer +from keras_hub.src.models.t5gemma.t5gemma_encoder import T5GemmaEncoderLayer + + +@keras_hub_export("keras_hub.models.T5GemmaBackbone") +class T5GemmaBackbone(Backbone): + """T5Gemma backbone model. + + This class implements the encoder-decoder backbone of the T5Gemma model, + consisting of an embedding layer, a stack of encoder layers, and a + stack of decoder layers. + + Args: + vocabulary_size: int, The size of the vocabulary. + hidden_dim: int, The dimensionality of the hidden states throughout the + model. + intermediate_dim: int, The intermediate size of the feed-forward + networks in encoder and decoder layers. + num_layers: int, The number of encoder and decoder layers. + num_attention_heads: int, The number of attention heads in all attention + mechanisms. + num_key_value_heads: int, The number of key-value heads for grouped + query attention in all attention mechanisms. + dropout_rate: float, The dropout rate applied throughout the model. + rms_norm_eps: float, The epsilon value for RMS normalization. + query_pre_attn_scalar: float, Scalar to multiply queries by before + attention. + attention_bias: bool, Whether to include bias in attention computations. + hidden_activation: str, The activation function used in the feed-forward + networks. + layer_types: list of str, A list of strings specifying the type of + attention layer for each encoder/decoder layer. Each element can be + either `"sliding_attention"` or `"full_attention"`. For example, + `["full_attention", "sliding_attention", ...]`. + tie_word_embeddings: bool, Whether to tie input and output word + embeddings. Default is `True`. + initializer_range: float, The range for the random normal initializer. + Default is `0.02`. + attention_dropout: float, The dropout rate applied to attention weights. + Default is `0.0`. + sliding_window: int, optional, The window size for sliding attention. + Required if any `layer_type` is `"sliding_attention"`. + cross_attention_hidden_size: int, optional, The hidden size for + cross-attention in the decoder layers. If None, it defaults to + `hidden_dim`. + attn_logit_softcapping: float, optional, The softcapping value for + attention logits. + final_logit_softcapping: float, optional, The softcapping value for + final logits. + rope_max_wavelength: float, The maximum wavelength for Rotary Positional + Embeddings. Default is `10000.0`. + **kwargs: Additional keyword arguments passed to the parent `Backbone` + class. + """ + + def __init__( + self, + vocabulary_size, + hidden_dim, + intermediate_dim, + num_layers, + num_attention_heads, + num_key_value_heads, + dropout_rate, + rms_norm_eps, + query_pre_attn_scalar, + attention_bias, + hidden_activation, + layer_types, + tie_word_embeddings=True, + initializer_range=0.02, + attention_dropout=0.0, + sliding_window=None, + cross_attention_hidden_size=None, + attn_logit_softcapping=None, + final_logit_softcapping=None, + rope_max_wavelength=10000.0, + **kwargs, + ): + # === Layers === + self.token_embedding = ReversibleEmbedding( + input_dim=vocabulary_size, + output_dim=hidden_dim, + tie_weights=tie_word_embeddings, + ) + self.encoder_layers = [ + T5GemmaEncoderLayer( + hidden_size=hidden_dim, + rms_norm_eps=rms_norm_eps, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + query_pre_attn_scalar=query_pre_attn_scalar, + attention_bias=attention_bias, + intermediate_size=intermediate_dim, + hidden_activation=hidden_activation, + dropout_rate=dropout_rate, + initializer_range=initializer_range, + attention_dropout=attention_dropout, + layer_type=layer_types[i], + sliding_window=sliding_window, + attn_logit_softcapping=attn_logit_softcapping, + rope_max_wavelength=rope_max_wavelength, + name=f"encoder_layer_{i}", + ) + for i in range(num_layers) + ] + self.encoder_norm = T5LayerNorm(epsilon=rms_norm_eps) + self.encoder_dropout = keras.layers.Dropout(dropout_rate) + self.decoder_layers = [ + T5GemmaDecoderLayer( + hidden_size=hidden_dim, + rms_norm_eps=rms_norm_eps, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + query_pre_attn_scalar=query_pre_attn_scalar, + attention_bias=attention_bias, + intermediate_size=intermediate_dim, + hidden_activation=hidden_activation, + dropout_rate=dropout_rate, + initializer_range=initializer_range, + attention_dropout=attention_dropout, + layer_type=layer_types[i], + sliding_window=sliding_window, + cross_attention_hidden_size=cross_attention_hidden_size, + attn_logit_softcapping=attn_logit_softcapping, + rope_max_wavelength=rope_max_wavelength, + name=f"decoder_layer_{i}", + ) + for i in range(num_layers) + ] + self.decoder_norm = T5LayerNorm(epsilon=rms_norm_eps) + self.decoder_dropout = keras.layers.Dropout(dropout_rate) + + # === Functional Model === + token_id_input = keras.Input( + shape=(None,), dtype="int32", name="token_ids" + ) + padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="padding_mask" + ) + + # Encoder. + encoder_embeddings = self.token_embedding(token_id_input) + encoder_embeddings = encoder_embeddings * keras.ops.cast( + keras.ops.sqrt(hidden_dim), encoder_embeddings.dtype + ) + encoder_hidden_states = self.encoder_dropout(encoder_embeddings) + for layer in self.encoder_layers: + encoder_hidden_states = layer( + encoder_hidden_states, + padding_mask=padding_mask_input, + ) + encoder_output = self.encoder_norm(encoder_hidden_states) + encoder_output = self.encoder_dropout(encoder_output) + + # Decoder. + decoder_embeddings = self.token_embedding(token_id_input) + decoder_embeddings = decoder_embeddings * keras.ops.cast( + keras.ops.sqrt(hidden_dim), decoder_embeddings.dtype + ) + decoder_hidden_states = self.decoder_dropout(decoder_embeddings) + for layer in self.decoder_layers: + decoder_hidden_states, _ = layer( + (decoder_hidden_states, encoder_output), + self_attention_padding_mask=padding_mask_input, + cross_attention_padding_mask=padding_mask_input, + ) + decoder_output = self.decoder_norm(decoder_hidden_states) + decoder_output = self.decoder_dropout(decoder_output) + + super().__init__( + inputs={ + "token_ids": token_id_input, + "padding_mask": padding_mask_input, + }, + outputs=decoder_output, + **kwargs, + ) + + # === Config === + self.vocabulary_size = vocabulary_size + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.num_layers = num_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.dropout_rate = dropout_rate + self.rms_norm_eps = rms_norm_eps + self.tie_word_embeddings = tie_word_embeddings + self.query_pre_attn_scalar = query_pre_attn_scalar + self.attention_bias = attention_bias + self.hidden_activation = hidden_activation + self.layer_types = layer_types + self.initializer_range = initializer_range + self.attention_dropout = attention_dropout + self.sliding_window = sliding_window + self.cross_attention_hidden_size = cross_attention_hidden_size + self.attn_logit_softcapping = attn_logit_softcapping + self.final_logit_softcapping = final_logit_softcapping + self.rope_max_wavelength = rope_max_wavelength + + def get_config(self): + config = super().get_config() + config.update( + { + "vocabulary_size": self.vocabulary_size, + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "num_layers": self.num_layers, + "num_attention_heads": self.num_attention_heads, + "num_key_value_heads": self.num_key_value_heads, + "dropout_rate": self.dropout_rate, + "rms_norm_eps": self.rms_norm_eps, + "tie_word_embeddings": self.tie_word_embeddings, + "query_pre_attn_scalar": self.query_pre_attn_scalar, + "attention_bias": self.attention_bias, + "hidden_activation": self.hidden_activation, + "layer_types": self.layer_types, + "initializer_range": self.initializer_range, + "attention_dropout": self.attention_dropout, + "sliding_window": self.sliding_window, + "cross_attention_hidden_size": self.cross_attention_hidden_size, + "attn_logit_softcapping": self.attn_logit_softcapping, + "final_logit_softcapping": self.final_logit_softcapping, + "rope_max_wavelength": self.rope_max_wavelength, + } + ) + return config diff --git a/keras_hub/src/models/t5gemma/t5gemma_backbone_test.py b/keras_hub/src/models/t5gemma/t5gemma_backbone_test.py new file mode 100644 index 0000000000..84413cec85 --- /dev/null +++ b/keras_hub/src/models/t5gemma/t5gemma_backbone_test.py @@ -0,0 +1,52 @@ +import keras +import pytest + +from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone +from keras_hub.src.tests.test_case import TestCase + + +class T5GemmaBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "vocabulary_size": 100, + "hidden_dim": 32, + "intermediate_dim": 64, + "num_layers": 2, + "num_attention_heads": 4, + "num_key_value_heads": 2, + "dropout_rate": 0.1, + "rms_norm_eps": 1e-6, + "tie_word_embeddings": True, + "query_pre_attn_scalar": 1.0, + "attention_bias": False, + "hidden_activation": "gelu_approximate", + "layer_types": ["sliding_attention", "full_attention"], + "sliding_window": 16, + "cross_attention_hidden_size": 32, + "attn_logit_softcapping": 50.0, + "rope_max_wavelength": 10000.0, + "initializer_range": 0.02, + "attention_dropout": 0.0, + } + self.input_data = { + "token_ids": keras.ops.ones((2, 16), dtype="int32"), + "padding_mask": keras.ops.ones((2, 16), dtype="int32"), + } + + def test_backbone_basics(self): + self.run_backbone_test( + cls=T5GemmaBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 16, 32), + run_mixed_precision_check=False, + run_quantization_check=False, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=T5GemmaBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/t5gemma/t5gemma_causal_lm.py b/keras_hub/src/models/t5gemma/t5gemma_causal_lm.py new file mode 100644 index 0000000000..9c53bae637 --- /dev/null +++ b/keras_hub/src/models/t5gemma/t5gemma_causal_lm.py @@ -0,0 +1,334 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.causal_lm import CausalLM +from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone +from keras_hub.src.models.t5gemma.t5gemma_causal_lm_preprocessor import ( + T5GemmaCausalLMPreprocessor, +) +from keras_hub.src.utils.tensor_utils import any_equal + + +@keras_hub_export("keras_hub.models.T5GemmaCausalLM") +class T5GemmaCausalLM(CausalLM): + """An end-to-end T5Gemma model for causal language modeling. + + A causal language model (LM) predicts the next token based on previous + tokens. This task setup can be used to train the model unsupervised on + plain text input, or to autoregressively generate plain text similar to + the data used for training. This task can be used for pre-training or + fine-tuning a T5Gemma model, simply by calling `fit()`. + + This model has a `generate()` method, which generates text based on a + prompt. The generation strategy used is controlled by an additional + `sampler` argument on `compile()`. You can recompile the model with + different `keras_hub.samplers` objects to control the generation. By + default, `"greedy"` sampling will be used. + + This model can optionally be configured with a `preprocessor` layer, in + which case it will automatically apply preprocessing to string inputs during + `fit()`, `predict()`, `evaluate()` and `generate()`. This is done by default + when creating the model with `from_preset()`. + + Args: + backbone: A `keras_hub.models.T5GemmaBackbone` instance. + preprocessor: A `keras_hub.models.T5GemmaCausalLMPreprocessor` or + `None`. If `None`, this model will not apply preprocessing, and + inputs should be preprocessed before calling the model. + + Examples: + + Use `generate()` to do text generation. + ```python + t5gemma_lm = keras_hub.models.T5GemmaCausalLM.from_preset( + "t5gemma_b_b_prefixlm_it" + ) + t5gemma_lm.generate("I want to say", max_length=30) + + # Generate with batched prompts. + t5gemma_lm.generate(["This is a", "Where are you"], max_length=30) + ``` + + Compile the `generate()` function with a custom sampler. + ```python + t5gemma_lm = keras_hub.models.T5GemmaCausalLM.from_preset( + "t5gemma_b_b_prefixlm_it" + ) + t5gemma_lm.compile(sampler="top_k") + t5gemma_lm.generate("I want to say", max_length=30) + + t5gemma_lm.compile(sampler=keras_hub.samplers.BeamSampler(num_beams=2)) + t5gemma_lm.generate("I want to say", max_length=30) + ``` + + Use `generate()` without preprocessing. + ```python + # The preprocessor is responsible for creating a dictionary of tensors. + # If you are not using a preprocessor, you must format your inputs + # yourself. + prompt = { + # Token ids for " Keras is". + "token_ids": np.array([[2, 214064, 603, 0, 0, 0, 0]] * 2), + # Use `"padding_mask"` to indicate values that should not be overridden. + "padding_mask": np.array([[1, 1, 1, 0, 0, 0, 0]] * 2), + } + + t5gemma_lm = keras_hub.models.T5GemmaCausalLM.from_preset( + "t5gemma_b_b_prefixlm_it", + preprocessor=None, + ) + t5gemma_lm.generate(prompt) + ``` + + Call `fit()` on a single batch. + ```python + features = ["The quick brown fox jumped.", "I forgot my homework."] + t5gemma_lm = keras_hub.models.T5GemmaCausalLM.from_preset( + "t5gemma_b_b_prefixlm_it" + ) + t5gemma_lm.fit(x=features, batch_size=2) + ``` + + Call `fit()` without preprocessing. + ```python + x = { + # Token ids for " Keras is deep learning library" + "token_ids": np.array([[2, 214064, 603, 5271, 6044, 9581, 1, 0]] * 2), + "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 0]] * 2), + } + y = np.array([[214064, 603, 5271, 6044, 9581, 3, 0, 0]] * 2) + sw = np.array([[1, 1, 1, 1, 1, 1, 0, 0]] * 2) + + t5gemma_lm = keras_hub.models.T5GemmaCausalLM.from_preset( + "t5gemma_b_b_prefixlm_it", + preprocessor=None, + ) + t5gemma_lm.fit(x=x, y=y, sample_weight=sw, batch_size=2) + ``` + + Custom backbone and vocabulary. + ```python + tokenizer = keras_hub.models.T5GemmaTokenizer( + proto="proto.spm", + ) + preprocessor = keras_hub.models.T5GemmaCausalLMPreprocessor( + tokenizer=tokenizer, + sequence_length=128, + ) + backbone = keras_hub.models.T5GemmaBackbone( + vocabulary_size=32000, + num_layers=4, + num_attention_heads=4, + num_key_value_heads=2, + hidden_dim=256, + intermediate_dim=512, + dropout_rate=0.1, + rms_norm_eps=1e-6, + query_pre_attn_scalar=1.0, + attention_bias=False, + hidden_activation="gelu_approximate", + layer_types=["full_attention"] * 4 + ) + t5gemma_lm = keras_hub.models.T5GemmaCausalLM( + backbone=backbone, + preprocessor=preprocessor, + ) + t5gemma_lm.fit(x=features, batch_size=2) + ``` + """ + + backbone_cls = T5GemmaBackbone + preprocessor_cls = T5GemmaCausalLMPreprocessor + + def __init__(self, backbone, preprocessor=None, **kwargs): + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + + # === Functional Model === + # This must be "backbone.input" i.e. the full input structure, + # rather than "backbone.inputs" which is the flattened list of inputs. + inputs = backbone.input + sequence_output = backbone(inputs) + outputs = backbone.token_embedding(sequence_output, reverse=True) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + def call_with_cache( + self, token_ids, padding_mask, cache, cache_update_index + ): + """Forward pass of `T5GemmaCausalLM` with cache. + + `call_with_cache` adds an additional forward pass for the model for + autoregressive inference. Unlike calling the model directly, this method + allows caching previous key/value Tensors in the attention layers, + and avoids recomputing the outputs of seen tokens. + + Args: + token_ids: A dense int Tensor with shape `(batch_size, max_length)`. + padding_mask: A dense int Tensor with shape `(batch_size, + max_length)`. + cache: A dense float Tensor, the cache of key and value states. + cache_update_index: int, or int Tensor. The index of the current + token being processed in the whole sequence. + + Returns: + A `(logits, hidden_states, cache)` tuple. Where `logits` is the + language model logits for the input token_ids, `hidden_states` is + the final hidden representation of the input tokens, and `cache` is + the updated decoding cache. + """ + ( + encoder_output, + past_key_values, + encoder_padding_mask, + ) = cache + hidden_states = self.backbone.token_embedding(token_ids) + hidden_states *= keras.ops.cast( + keras.ops.sqrt(self.backbone.hidden_dim), hidden_states.dtype + ) + hidden_states = self.backbone.decoder_dropout(hidden_states) + updated_key_values = [] + for i, layer in enumerate(self.backbone.decoder_layers): + current_cache = past_key_values[:, i, ...] + hidden_states, current_cache = layer( + (hidden_states, encoder_output), + self_attention_padding_mask=padding_mask, + cross_attention_padding_mask=encoder_padding_mask, + self_attention_cache=current_cache, + cache_update_index=cache_update_index, + ) + updated_key_values.append(current_cache) + past_key_values = keras.ops.stack(updated_key_values, axis=1) + hidden_states = self.backbone.decoder_norm(hidden_states) + logits = self.backbone.token_embedding(hidden_states, reverse=True) + cache = ( + encoder_output, + past_key_values, + encoder_padding_mask, + ) + return logits, hidden_states, cache + + def _build_cache(self, token_ids, padding_mask): + """Build an empty cache for use with `call_with_cache()`.""" + # Encoder. + encoder_embeddings = self.backbone.token_embedding(token_ids) + encoder_embeddings *= keras.ops.cast( + keras.ops.sqrt(self.backbone.hidden_dim), encoder_embeddings.dtype + ) + encoder_hidden_states = self.backbone.encoder_dropout( + encoder_embeddings + ) + for layer in self.backbone.encoder_layers: + encoder_hidden_states = layer( + encoder_hidden_states, + padding_mask=padding_mask, + ) + encoder_output = self.backbone.encoder_norm(encoder_hidden_states) + hidden_states = self.backbone.token_embedding(token_ids) + hidden_states *= keras.ops.cast( + keras.ops.sqrt(self.backbone.hidden_dim), hidden_states.dtype + ) + hidden_states = self.backbone.decoder_dropout(hidden_states) + past_key_values = [] + for layer in self.backbone.decoder_layers: + hidden_states, kv_cache_for_layer = layer( + (hidden_states, encoder_output), + self_attention_padding_mask=padding_mask, + cross_attention_padding_mask=padding_mask, + ) + past_key_values.append(kv_cache_for_layer) + past_key_values = keras.ops.stack(past_key_values, axis=1) + hidden_states = self.backbone.decoder_norm(hidden_states) + cache = ( + encoder_output, + past_key_values, + padding_mask, + ) + return hidden_states, cache + + def generate_step(self, inputs, stop_token_ids=None): + """A compilable generation function for a single batch of inputs. + + This function represents the inner, XLA-compilable, generation function + for a single batch of inputs. Inputs should have the same structure as + model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`. + + Args: + inputs: A dictionary with two keys `"token_ids"` and + `"padding_mask"` and batched tensor values. + stop_token_ids: Tuple of id's of end token's to stop on. If all + sequences have produced a new stop token, generation + will stop. + """ + # Create and seed cache with a single forward pass. + hidden_states, cache = self._build_cache( + token_ids=inputs["token_ids"], + padding_mask=inputs["padding_mask"], + ) + token_ids = inputs["token_ids"] + padding_mask = inputs["padding_mask"] + # Compute the lengths of all user inputted tokens ids. + row_lengths = keras.ops.sum( + keras.ops.cast(padding_mask, "int32"), axis=-1 + ) + # Start at the first index that has no user inputted id. + index = keras.ops.min(row_lengths) + + def next(prompt, cache, index): + # The cache index is the index of our previous token. + cache_update_index = index - 1 + batch_size = keras.ops.shape(prompt)[0] + prompt = keras.ops.slice( + prompt, [0, cache_update_index], [batch_size, 1] + ) + prompt_padding_mask = keras.ops.ones_like(prompt, dtype="int32") + logits, _, cache = self.call_with_cache( + prompt, + prompt_padding_mask, + cache, + cache_update_index, + ) + return keras.ops.squeeze(logits, axis=1), None, cache + + token_ids = self.sampler( + next=next, + prompt=token_ids, + cache=cache, + index=index, + mask=padding_mask, + stop_token_ids=stop_token_ids, + hidden_states=hidden_states, + model=self, + ) + + # Compute an output padding mask with the token ids we updated. + if stop_token_ids is not None: + # Build a mask of `stop_token_ids` locations not in the original + # prompt (not in locations where `padding_mask` is True). + end_locations = any_equal( + token_ids, + stop_token_ids, + keras.ops.logical_not(padding_mask), + ) + # Use cumsum to get ones in all locations after end_locations. + end_locations = keras.ops.cast(end_locations, "int32") + cumsum = keras.ops.cast( + keras.ops.cumsum(end_locations, axis=-1), "int32" + ) + overflow = cumsum - end_locations + # Our padding mask is the inverse of these overflow locations. + padding_mask = keras.ops.logical_not( + keras.ops.cast(overflow, "bool") + ) + else: + # Without early stopping, all locations will have been updated. + padding_mask = keras.ops.ones_like(token_ids, dtype="bool") + + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + } diff --git a/keras_hub/src/models/t5gemma/t5gemma_causal_lm_preprocessor.py b/keras_hub/src/models/t5gemma/t5gemma_causal_lm_preprocessor.py new file mode 100644 index 0000000000..69dc9ec782 --- /dev/null +++ b/keras_hub/src/models/t5gemma/t5gemma_causal_lm_preprocessor.py @@ -0,0 +1,72 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor +from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone +from keras_hub.src.models.t5gemma.t5gemma_tokenizer import T5GemmaTokenizer + + +@keras_hub_export("keras_hub.models.T5GemmaCausalLMPreprocessor") +class T5GemmaCausalLMPreprocessor(CausalLMPreprocessor): + """T5Gemma Causal LM preprocessor. + + This preprocessing layer is meant for use with + `keras_hub.models.T5GemmaCausalLM`. By default, it will take in batches of + strings, and return outputs in a `(x, y, sample_weight)` format, where the + `y` label is the next token id in the `x` sequence. + + For use with generation, the layer also exposes two methods + `generate_preprocess()` and `generate_postprocess()`. When this preprocessor + is attached to a `keras_hub.models.T5GemmaCausalLM` instance, these methods + will be called implicitly in `generate()`. They can also be called + standalone (e.g. to precompute preprocessing inputs for generation in a + separate process). + + Args: + tokenizer: A `keras_hub.models.T5GemmaTokenizer` instance. + sequence_length: The length of the packed inputs. + add_start_token: If `True`, the preprocessor will prepend the tokenizer + start token to each input sequence. Default is `True`. + add_end_token: If `True`, the preprocessor will append the tokenizer + end token to each input sequence. Default is `False`. + + Call arguments: + x: A string, `tf.Tensor` or list of python strings. + y: Label data. Should always be `None` as the layer generates labels. + sample_weight: Label weights. Should always be `None` as the layer + generates label weights. + sequence_length: Pass to override the configured `sequence_length` of + the layer. + + Examples: + ```python + # Load the preprocessor from a preset. + preprocessor = keras_hub.models.T5GemmaCausalLMPreprocessor.from_preset( + "t5gemma_b_b_prefixlm_it" + ) + + # Tokenize and pack a single sentence. + sentence = tf.constant("The quick brown fox jumped.") + preprocessor(sentence) + # Same output. + preprocessor("The quick brown fox jumped.") + + # Tokenize a batch of sentences. + preprocessor(["The quick brown fox jumped.", "Call me Ishmael."]) + + # Apply tokenization to a `tf.data.Dataset`. + features = tf.constant(["The quick brown fox.", "Call me Ishmael."]) + ds = tf.data.Dataset.from_tensor_slices(features) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + + # Prepare tokens for generation (no end token). + preprocessor.generate_preprocess(["The quick brown fox jumped."]) + + # Map generation outputs back to strings. + preprocessor.generate_postprocess({ + 'token_ids': np.array([[2, 714, 4320, 8426, 25341, 32292, 235265, 0]]), + 'padding_mask': np.array([[ 1, 1, 1, 1, 1, 1, 1, 0]]), + }) + ``` + """ + + backbone_cls = T5GemmaBackbone + tokenizer_cls = T5GemmaTokenizer diff --git a/keras_hub/src/models/t5gemma/t5gemma_causal_lm_test.py b/keras_hub/src/models/t5gemma/t5gemma_causal_lm_test.py new file mode 100644 index 0000000000..7d27681b0a --- /dev/null +++ b/keras_hub/src/models/t5gemma/t5gemma_causal_lm_test.py @@ -0,0 +1,141 @@ +import os +from unittest.mock import patch + +import keras +import pytest + +from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone +from keras_hub.src.models.t5gemma.t5gemma_causal_lm import T5GemmaCausalLM +from keras_hub.src.models.t5gemma.t5gemma_causal_lm_preprocessor import ( + T5GemmaCausalLMPreprocessor, +) +from keras_hub.src.models.t5gemma.t5gemma_tokenizer import T5GemmaTokenizer +from keras_hub.src.tests.test_case import TestCase + + +class T5GemmaCausalLMTest(TestCase): + def setUp(self): + self.tokenizer = T5GemmaTokenizer( + proto=os.path.join( + self.get_test_data_dir(), "gemma_test_vocab.spm" + ), + ) + self.preprocessor = T5GemmaCausalLMPreprocessor( + self.tokenizer, + sequence_length=8, + ) + self.backbone = T5GemmaBackbone( + vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(), + hidden_dim=16, + intermediate_dim=32, + num_layers=2, + num_attention_heads=2, + num_key_value_heads=1, + dropout_rate=0.0, + rms_norm_eps=1e-6, + tie_word_embeddings=False, + query_pre_attn_scalar=1.0, + attention_bias=False, + hidden_activation="gelu_approximate", + layer_types=["sliding_attention", "full_attention"], + initializer_range=0.02, + attention_dropout=0.0, + sliding_window=4, + final_logit_softcapping=30.0, + ) + self.init_kwargs = { + "preprocessor": self.preprocessor, + "backbone": self.backbone, + } + self.train_data = (["the quick brown fox", "the earth is round"],) + self.input_data = self.preprocessor(*self.train_data)[0] + + def test_causal_lm_basics(self): + self.run_task_test( + cls=T5GemmaCausalLM, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=( + 2, + 8, + self.preprocessor.tokenizer.vocabulary_size(), + ), + ) + + def test_generate(self): + causal_lm = T5GemmaCausalLM(**self.init_kwargs) + # String input. + prompt = "the quick brown fox" + output = causal_lm.generate(prompt) + self.assertTrue(prompt in output) + # Int tensor input. + prompt_ids = self.preprocessor.generate_preprocess([prompt]) + causal_lm.preprocessor = None + outputs = causal_lm.generate(prompt_ids, stop_token_ids=None) + # Assert prompt is in output in token id space. + self.assertAllEqual( + outputs["token_ids"][:, :5], + prompt_ids["token_ids"][:, :5], + ) + self.assertAllEqual( + outputs["padding_mask"][:, :5], + prompt_ids["padding_mask"][:, :5], + ) + + def test_generate_strip_prompt(self): + causal_lm = T5GemmaCausalLM(**self.init_kwargs) + prompt = "the quick brown fox" + output = causal_lm.generate(prompt, strip_prompt=True) + self.assertFalse(output.startswith(prompt)) + + def test_early_stopping(self): + causal_lm = T5GemmaCausalLM(**self.init_kwargs) + call_with_cache = causal_lm.call_with_cache + + def wrapper(*args, **kwargs): + """Modify output logits to always favor end_token_id""" + logits, hidden_states, cache = call_with_cache(*args, **kwargs) + index = self.preprocessor.tokenizer.end_token_id + update = ( + keras.ops.ones( + (keras.ops.shape(logits)[0], 1, 1), dtype=logits.dtype + ) + * 1.0e9 + ) + logits = keras.ops.slice_update(logits, (0, 0, index), update) + return logits, hidden_states, cache + + with patch.object(causal_lm, "call_with_cache", wraps=wrapper): + prompt = ["the quick brown fox", "the earth is round"] + output = causal_lm.generate(prompt) + # We should immediately abort and output the prompt. + self.assertEqual(prompt, output) + + def test_generate_compilation(self): + causal_lm = T5GemmaCausalLM(**self.init_kwargs) + # Assert we do not recompile with successive calls. + causal_lm.generate("the quick brown fox") + first_fn = causal_lm.generate_function + causal_lm.generate("the quick brown fox") + second_fn = causal_lm.generate_function + self.assertEqual(first_fn, second_fn) + # Assert we do recompile after compile is called. + causal_lm.compile(sampler="greedy") + self.assertIsNone(causal_lm.generate_function) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=T5GemmaCausalLM, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in T5GemmaCausalLM.presets: + self.run_preset_test( + cls=T5GemmaCausalLM, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/t5gemma/t5gemma_decoder.py b/keras_hub/src/models/t5gemma/t5gemma_decoder.py new file mode 100644 index 0000000000..66d752162a --- /dev/null +++ b/keras_hub/src/models/t5gemma/t5gemma_decoder.py @@ -0,0 +1,272 @@ +import keras + +from keras_hub.src.models.t5.t5_layer_norm import T5LayerNorm +from keras_hub.src.models.t5gemma.t5gemma_attention import T5GemmaCrossAttention +from keras_hub.src.models.t5gemma.t5gemma_attention import T5GemmaSelfAttention +from keras_hub.src.models.t5gemma.t5gemma_layers import T5GemmaMLP + + +@keras.saving.register_keras_serializable(package="keras_hub") +class T5GemmaDecoderLayer(keras.layers.Layer): + """Decoder layer for the T5Gemma model. + + This layer implements a single decoder block in the T5Gemma architecture, + comprising self-attention, cross-attention, and a feed-forward network + (MLP). + + Args: + hidden_size: int, The dimensionality of the hidden states. + rms_norm_eps: float, The epsilon value for RMS normalization. + num_attention_heads: int, The number of attention heads in + self-attention and cross-attention. + num_key_value_heads: int, The number of key-value heads for grouped + query attention. + query_pre_attn_scalar: float, Scalar to multiply queries by before + attention. + attention_bias: bool, Whether to include bias in attention computations. + intermediate_size: int, The intermediate size of the feed-forward + network. + hidden_activation: str, The activation function used in the feed-forward + network. + dropout_rate: float, The dropout rate applied after attention and MLP. + initializer_range: float, The range for the random normal initializer. + attention_dropout: float, The dropout rate applied to attention weights. + layer_type: str, Type of attention layer, e.g., `"sliding_attention"`. + cross_attention_hidden_size: int, optional, The hidden size for + cross-attention. If None, it defaults to `hidden_size`. + attn_logit_softcapping: float, optional, The softcapping value for + attention logits. + sliding_window: int, optional, The window size for sliding attention. + Required if `layer_type` is `"sliding_attention"`. + rope_max_wavelength: float, The maximum wavelength for Rotary + Positional Embeddings. Default is `10000.0`. + **kwargs: Additional keyword arguments passed to the parent class. + """ + + def __init__( + self, + hidden_size, + rms_norm_eps, + num_attention_heads, + num_key_value_heads, + query_pre_attn_scalar, + attention_bias, + intermediate_size, + hidden_activation, + dropout_rate, + initializer_range, + attention_dropout, + layer_type, + cross_attention_hidden_size=None, + attn_logit_softcapping=None, + sliding_window=None, + rope_max_wavelength=10000.0, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.rms_norm_eps = rms_norm_eps + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.query_pre_attn_scalar = query_pre_attn_scalar + self.attention_bias = attention_bias + self.intermediate_size = intermediate_size + self.hidden_activation = hidden_activation + self.dropout_rate = dropout_rate + self.initializer_range = initializer_range + self.attention_dropout = attention_dropout + self.layer_type = layer_type + self.sliding_window = sliding_window + self.rope_max_wavelength = rope_max_wavelength + if ( + self.layer_type == "sliding_attention" + and self.sliding_window is None + ): + raise ValueError( + "`sliding_window` must be set for `sliding_attention` layer " + "type." + ) + + # Self-attention. + self.self_attn = T5GemmaSelfAttention( + hidden_size, + num_attention_heads, + num_key_value_heads, + query_pre_attn_scalar, + attention_bias, + initializer_range=initializer_range, + attention_dropout=attention_dropout, + attn_logit_softcapping=attn_logit_softcapping, + rope_max_wavelength=self.rope_max_wavelength, + ) + self.pre_self_attn_layernorm = T5LayerNorm(epsilon=rms_norm_eps) + self.post_self_attn_layernorm = T5LayerNorm(epsilon=rms_norm_eps) + + # Cross-attention. + self.cross_attn = T5GemmaCrossAttention( + hidden_size, + cross_attention_hidden_size, + num_attention_heads, + num_key_value_heads, + query_pre_attn_scalar, + attention_bias, + initializer_range=initializer_range, + attention_dropout=attention_dropout, + attn_logit_softcapping=attn_logit_softcapping, + ) + self.pre_cross_attn_layernorm = T5LayerNorm(epsilon=rms_norm_eps) + self.post_cross_attn_layernorm = T5LayerNorm(epsilon=rms_norm_eps) + + # MLP. + self.mlp = T5GemmaMLP( + hidden_size, + intermediate_size, + hidden_activation, + dropout_rate, + initializer_range=initializer_range, + ) + self.pre_feedforward_layernorm = T5LayerNorm(epsilon=rms_norm_eps) + self.post_feedforward_layernorm = T5LayerNorm(epsilon=rms_norm_eps) + + self.dropout = keras.layers.Dropout(dropout_rate) + + def build(self, input_shape): + hidden_states_shape, encoder_hidden_states_shape = input_shape + self.pre_self_attn_layernorm.build(hidden_states_shape) + current_shape = hidden_states_shape + self.self_attn.build(current_shape) + attn_output_shape = self.self_attn.compute_output_shape(current_shape)[ + 0 + ] + self.post_self_attn_layernorm.build(attn_output_shape) + current_shape = attn_output_shape + self.dropout.build(current_shape) + self.pre_cross_attn_layernorm.build(current_shape) + self.cross_attn.build([current_shape, encoder_hidden_states_shape]) + attn_output_shape = self.cross_attn.compute_output_shape( + [current_shape, encoder_hidden_states_shape] + )[0] + self.post_cross_attn_layernorm.build(attn_output_shape) + current_shape = attn_output_shape + self.pre_feedforward_layernorm.build(current_shape) + self.mlp.build(current_shape) + mlp_output_shape = self.mlp.compute_output_shape(current_shape) + self.post_feedforward_layernorm.build(mlp_output_shape) + self.built = True + + def _make_self_attention_mask(self, hidden_states, padding_mask): + seq_len = keras.ops.shape(hidden_states)[1] + q_indices = keras.ops.arange(0, seq_len, dtype="int32")[:, None] + kv_indices = keras.ops.arange(0, seq_len, dtype="int32")[None, :] + causal_mask = kv_indices <= q_indices + if self.layer_type == "sliding_attention": + sliding_mask = (q_indices - self.sliding_window) <= kv_indices + causal_mask = keras.ops.logical_and(causal_mask, sliding_mask) + final_mask = causal_mask[None, None, :, :] + if padding_mask is not None: + padding_mask_4d = padding_mask[:, None, None, :] + final_mask = keras.ops.logical_and(final_mask, padding_mask_4d) + return (1.0 - keras.ops.cast(final_mask, hidden_states.dtype)) * -1e9 + + def _make_cross_attention_mask(self, hidden_states, padding_mask): + bidirectional_mask = padding_mask[:, None, None, :] + additive_bidirectional_mask = ( + 1.0 - keras.ops.cast(bidirectional_mask, hidden_states.dtype) + ) * -1e9 + return additive_bidirectional_mask + + def call( + self, + inputs, + self_attention_padding_mask=None, + cross_attention_padding_mask=None, + self_attention_cache=None, + cross_attention_cache=None, + cache_update_index=None, + training=None, + ): + hidden_states, encoder_hidden_states = inputs + # Self Attention. + residual = hidden_states + self_attention_mask = self._make_self_attention_mask( + hidden_states, self_attention_padding_mask + ) + hidden_states = self.pre_self_attn_layernorm(hidden_states) + (hidden_states, _), updated_self_attention_cache = self.self_attn( + hidden_states=hidden_states, + attention_mask=self_attention_mask, + cache=self_attention_cache, + cache_update_index=cache_update_index, + training=training, + ) + hidden_states = self.post_self_attn_layernorm(hidden_states) + hidden_states = residual + self.dropout( + hidden_states, training=training + ) + + # Cross Attention. + residual = hidden_states + cross_attention_mask = self._make_cross_attention_mask( + hidden_states, cross_attention_padding_mask + ) + hidden_states = self.pre_cross_attn_layernorm(hidden_states) + cross_attn_output = self.cross_attn( + [hidden_states, encoder_hidden_states], + attention_mask=cross_attention_mask, + cache=cross_attention_cache, + training=training, + ) + if cross_attention_cache is not None: + (hidden_states, _), _ = cross_attn_output + else: + hidden_states, _ = cross_attn_output + + hidden_states = self.post_cross_attn_layernorm(hidden_states) + hidden_states = residual + self.dropout( + hidden_states, training=training + ) + + # MLP. + residual = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states, training=training) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + self.dropout( + hidden_states, training=training + ) + return hidden_states, updated_self_attention_cache + + def compute_output_shape(self, input_shape): + hidden_states_shape, _ = input_shape + batch_size, seq_len, _ = hidden_states_shape + head_dim = self.hidden_size // self.num_attention_heads + cache_shape = ( + batch_size, + 2, + self.num_key_value_heads, + seq_len, + head_dim, + ) + return hidden_states_shape, cache_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "rms_norm_eps": self.rms_norm_eps, + "num_attention_heads": self.num_attention_heads, + "num_key_value_heads": self.num_key_value_heads, + "query_pre_attn_scalar": self.query_pre_attn_scalar, + "attention_bias": self.attention_bias, + "intermediate_size": self.intermediate_size, + "hidden_activation": self.hidden_activation, + "dropout_rate": self.dropout_rate, + "initializer_range": self.initializer_range, + "attention_dropout": self.attention_dropout, + "layer_type": self.layer_type, + "sliding_window": self.sliding_window, + "rope_max_wavelength": self.rope_max_wavelength, + } + ) + return config diff --git a/keras_hub/src/models/t5gemma/t5gemma_encoder.py b/keras_hub/src/models/t5gemma/t5gemma_encoder.py new file mode 100644 index 0000000000..99c75d302d --- /dev/null +++ b/keras_hub/src/models/t5gemma/t5gemma_encoder.py @@ -0,0 +1,199 @@ +import keras + +from keras_hub.src.models.t5.t5_layer_norm import T5LayerNorm +from keras_hub.src.models.t5gemma.t5gemma_attention import T5GemmaSelfAttention +from keras_hub.src.models.t5gemma.t5gemma_layers import T5GemmaMLP + + +@keras.saving.register_keras_serializable(package="keras_hub") +class T5GemmaEncoderLayer(keras.layers.Layer): + """Encoder layer for the T5Gemma model. + + This layer implements a single encoder block in the T5Gemma architecture, + comprising self-attention and a feed-forward network (MLP). + + Args: + hidden_size: int, The dimensionality of the hidden states. + rms_norm_eps: float, The epsilon value for RMS normalization. + num_attention_heads: int, The number of attention heads in + self-attention. + num_key_value_heads: int, The number of key-value heads for grouped + query attention. + query_pre_attn_scalar: float, Scalar to multiply queries by before + attention. + attention_bias: bool, Whether to include bias in attention computations. + intermediate_size: int, The intermediate size of the feed-forward + network. + hidden_activation: str, The activation function used in the feed-forward + network. + dropout_rate: float, The dropout rate applied after attention and MLP. + initializer_range: float, The range for the random normal initializer. + attention_dropout: float, The dropout rate applied to attention weights. + layer_type: str, Type of attention layer, e.g., `"sliding_attention"`. + attn_logit_softcapping: float, optional, The softcapping value for + attention logits. + sliding_window: int, optional, The window size for sliding attention. + Required if `layer_type` is `"sliding_attention"`. + rope_max_wavelength: float, The maximum wavelength for Rotary Positional + Embeddings. Default is `10000.0`. + **kwargs: Additional keyword arguments passed to the parent class. + """ + + def __init__( + self, + hidden_size, + rms_norm_eps, + num_attention_heads, + num_key_value_heads, + query_pre_attn_scalar, + attention_bias, + intermediate_size, + hidden_activation, + dropout_rate, + initializer_range, + attention_dropout, + layer_type, + attn_logit_softcapping=None, + sliding_window=None, + rope_max_wavelength=10000.0, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.rms_norm_eps = rms_norm_eps + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.query_pre_attn_scalar = query_pre_attn_scalar + self.attention_bias = attention_bias + self.intermediate_size = intermediate_size + self.hidden_activation = hidden_activation + self.dropout_rate = dropout_rate + self.initializer_range = initializer_range + self.attention_dropout = attention_dropout + self.layer_type = layer_type + self.sliding_window = sliding_window + self.rope_max_wavelength = rope_max_wavelength + if ( + self.layer_type == "sliding_attention" + and self.sliding_window is None + ): + raise ValueError( + "`sliding_window` must be set for `sliding_attention` layer " + "type." + ) + self.self_attn = T5GemmaSelfAttention( + hidden_size, + num_attention_heads, + num_key_value_heads, + query_pre_attn_scalar, + attention_bias, + initializer_range=initializer_range, + attention_dropout=attention_dropout, + attn_logit_softcapping=attn_logit_softcapping, + rope_max_wavelength=self.rope_max_wavelength, + ) + self.pre_self_attn_layernorm = T5LayerNorm(epsilon=rms_norm_eps) + self.post_self_attn_layernorm = T5LayerNorm(epsilon=rms_norm_eps) + + self.mlp = T5GemmaMLP( + hidden_size, + intermediate_size, + hidden_activation, + dropout_rate, + initializer_range=initializer_range, + ) + self.pre_feedforward_layernorm = T5LayerNorm(epsilon=rms_norm_eps) + self.post_feedforward_layernorm = T5LayerNorm(epsilon=rms_norm_eps) + self.dropout = keras.layers.Dropout(dropout_rate) + + def build(self, input_shape): + self.pre_self_attn_layernorm.build(input_shape) + current_shape = input_shape + self.self_attn.build(current_shape) + attn_output_shape = self.self_attn.compute_output_shape(current_shape)[ + 0 + ] + self.post_self_attn_layernorm.build(attn_output_shape) + current_shape = attn_output_shape + self.dropout.build(current_shape) + self.pre_feedforward_layernorm.build(current_shape) + self.mlp.build(current_shape) + current_shape = self.mlp.compute_output_shape(current_shape) + self.post_feedforward_layernorm.build(current_shape) + self.built = True + + def _make_attention_mask(self, hidden_states, padding_mask): + seq_len = keras.ops.shape(hidden_states)[1] + attention_mask = padding_mask[:, None, None, :] + additive_mask = ( + 1.0 - keras.ops.cast(attention_mask, hidden_states.dtype) + ) * -1e9 + if self.layer_type == "sliding_attention": + q_indices = keras.ops.arange(0, seq_len, dtype="int32")[:, None] + kv_indices = keras.ops.arange(0, seq_len, dtype="int32")[None, :] + window_mask = (q_indices - self.sliding_window < kv_indices) & ( + kv_indices < q_indices + self.sliding_window + ) + window_mask = window_mask[None, None, :, :] + window_additive_mask = ( + 1.0 - keras.ops.cast(window_mask, hidden_states.dtype) + ) * -1e9 + additive_mask = additive_mask + window_additive_mask + return additive_mask + + def call( + self, + hidden_states, + padding_mask=None, + cache=None, + cache_update_index=None, + training=None, + ): + residual = hidden_states + attention_mask = self._make_attention_mask(hidden_states, padding_mask) + hidden_states = self.pre_self_attn_layernorm(hidden_states) + (hidden_states, _), _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + cache=cache, + cache_update_index=cache_update_index, + training=training, + ) + hidden_states = self.post_self_attn_layernorm(hidden_states) + hidden_states = residual + self.dropout( + hidden_states, training=training + ) + residual = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states, training=training) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + self.dropout( + hidden_states, training=training + ) + return hidden_states + + def compute_output_shape(self, input_shape): + # Isometric. + return input_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "rms_norm_eps": self.rms_norm_eps, + "num_attention_heads": self.num_attention_heads, + "num_key_value_heads": self.num_key_value_heads, + "query_pre_attn_scalar": self.query_pre_attn_scalar, + "attention_bias": self.attention_bias, + "intermediate_size": self.intermediate_size, + "hidden_activation": self.hidden_activation, + "dropout_rate": self.dropout_rate, + "initializer_range": self.initializer_range, + "attention_dropout": self.attention_dropout, + "layer_type": self.layer_type, + "sliding_window": self.sliding_window, + "rope_max_wavelength": self.rope_max_wavelength, + } + ) + return config diff --git a/keras_hub/src/models/t5gemma/t5gemma_layers.py b/keras_hub/src/models/t5gemma/t5gemma_layers.py new file mode 100644 index 0000000000..6650ef942f --- /dev/null +++ b/keras_hub/src/models/t5gemma/t5gemma_layers.py @@ -0,0 +1,106 @@ +import keras + +from keras_hub.src.utils.keras_utils import clone_initializer + + +def t5gemma_kernel_initializer(initializer_range=0.01): + """Creates a RandomNormal initializer for T5Gemma kernels. + + Args: + initializer_range: float, The standard deviation of the normal + distribution. Default is `0.01`. + + Returns: + keras.initializers.RandomNormal: A Keras RandomNormal initializer. + """ + return keras.initializers.RandomNormal(mean=0.0, stddev=initializer_range) + + +@keras.saving.register_keras_serializable(package="keras_hub") +class T5GemmaMLP(keras.layers.Layer): + """Multilayer Perceptron (MLP) block for the T5Gemma model. + + This layer implements the feed-forward part of a transformer block, + consisting of two dense layers with a GELU activation and dropout. + + Args: + hidden_size: int, The dimensionality of the input and output hidden + states. + intermediate_size: int, The dimensionality of the intermediate layer. + hidden_activation: str, The activation function to use, e.g., + "gelu_approximate". + dropout_rate: float, The dropout rate applied to the intermediate + hidden states. + initializer_range: float, The range for the random normal initializer + for kernel weights. Default is `0.02`. + **kwargs: Additional keyword arguments passed to the parent class. + """ + + def __init__( + self, + hidden_size, + intermediate_size, + hidden_activation, + dropout_rate, + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.hidden_activation = hidden_activation + self.dropout_rate = dropout_rate + self.initializer_range = initializer_range + self.kernel_initializer = t5gemma_kernel_initializer(initializer_range) + + self.gate_proj = keras.layers.Dense( + self.intermediate_size, + use_bias=False, + kernel_initializer=clone_initializer(self.kernel_initializer), + ) + self.up_proj = keras.layers.Dense( + self.intermediate_size, + use_bias=False, + kernel_initializer=clone_initializer(self.kernel_initializer), + ) + self.down_proj = keras.layers.Dense( + self.hidden_size, + use_bias=False, + kernel_initializer=clone_initializer(self.kernel_initializer), + ) + if self.hidden_activation == "gelu_approximate": + # NOTE: `gelu_pytorch_tanh` is the same as `gelu(approximate=True)`. + self.act_fn = lambda x: keras.activations.gelu(x, approximate=True) + else: + self.act_fn = keras.activations.get(self.hidden_activation) + self.dropout = keras.layers.Dropout(self.dropout_rate) + + def build(self, input_shape): + self.gate_proj.build(input_shape) + self.up_proj.build(input_shape) + intermediate_shape = self.gate_proj.compute_output_shape(input_shape) + self.dropout.build(intermediate_shape) + self.down_proj.build(intermediate_shape) + self.built = True + + def compute_output_shape(self, input_shape): + return input_shape + + def call(self, x, training=None): + hidden_states = self.act_fn(self.gate_proj(x)) * self.up_proj(x) + hidden_states = self.dropout(hidden_states, training=training) + down_proj = self.down_proj(hidden_states) + return down_proj + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "intermediate_size": self.intermediate_size, + "hidden_activation": self.hidden_activation, + "dropout_rate": self.dropout_rate, + "initializer_range": self.initializer_range, + } + ) + return config diff --git a/keras_hub/src/models/t5gemma/t5gemma_presets.py b/keras_hub/src/models/t5gemma/t5gemma_presets.py new file mode 100644 index 0000000000..d976272974 --- /dev/null +++ b/keras_hub/src/models/t5gemma/t5gemma_presets.py @@ -0,0 +1,2 @@ +# Metadata for loading pretrained model weights. +backbone_presets = {} diff --git a/keras_hub/src/models/t5gemma/t5gemma_tokenizer.py b/keras_hub/src/models/t5gemma/t5gemma_tokenizer.py new file mode 100644 index 0000000000..02f44c4220 --- /dev/null +++ b/keras_hub/src/models/t5gemma/t5gemma_tokenizer.py @@ -0,0 +1,73 @@ +from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone +from keras_hub.src.tokenizers.sentence_piece_tokenizer import ( + SentencePieceTokenizer, +) + + +class T5GemmaTokenizer(SentencePieceTokenizer): + """T5Gemma tokenizer layer based on SentencePiece. + + This tokenizer class will tokenize raw strings into integer sequences and + is based on `keras_hub.tokenizers.SentencePieceTokenizer`. Unlike the + underlying tokenizer, it will check for all special tokens needed by + T5Gemma models and provides a `from_preset()` method to automatically + download a matching vocabulary for a T5Gemma preset. + + If input is a batch of strings (rank > 0), the layer will output a + `tf.RaggedTensor` where the last dimension of the output is ragged. + + If input is a scalar string (rank == 0), the layer will output a dense + `tf.Tensor` with static shape `[None]`. + + Args: + proto: Either a `string` path to a SentencePiece proto file, or a + `bytes` object with a serialized SentencePiece proto. See the + [SentencePiece repository](https://github.com/google/sentencepiece) + for more details on the format. + + Examples: + + ```python + # Unbatched input. + tokenizer = keras_hub.models.T5GemmaTokenizer.from_preset( + "t5gemma_b_b_prefixlm_it" + ) + tokenizer("The quick brown fox jumped.") + + # Batched input. + tokenizer(["The quick brown fox jumped.", "The fox slept."]) + + # Detokenization. + tokenizer.detokenize(tokenizer("The quick brown fox jumped.")) + + # Custom vocabulary. + bytes_io = io.BytesIO() + ds = tf.data.Dataset.from_tensor_slices(["The quick brown fox jumped."]) + sentencepiece.SentencePieceTrainer.train( + sentence_iterator=ds.as_numpy_iterator(), + model_writer=bytes_io, + vocab_size=8, + model_type="WORD", + pad_id=0, + bos_id=1, + eos_id=2, + unk_id=3, + pad_piece="", + bos_piece="", + eos_piece="", + unk_piece="", + ) + tokenizer = keras_hub.models.T5GemmaTokenizer( + proto=bytes_io.getvalue(), + ) + tokenizer("The quick brown fox jumped.") + ``` + """ + + backbone_cls = T5GemmaBackbone + + def __init__(self, proto, **kwargs): + self._add_special_token("", "start_token") + self._add_special_token("", "end_token") + self._add_special_token("", "pad_token") + super().__init__(proto=proto, **kwargs) From 1c9ebbc0f0d4bfad9da39cd67b3012b0447dfa5d Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Sat, 19 Jul 2025 18:41:25 +0400 Subject: [PATCH 02/19] nit: Fix code format test; and cool AI-generated reviews --- keras_hub/api/models/__init__.py | 12 ++++++ keras_hub/api/tokenizers/__init__.py | 3 ++ .../src/models/t5gemma/t5gemma_attention.py | 5 +-- .../src/models/t5gemma/t5gemma_causal_lm.py | 43 +++++++++++++++---- .../src/models/t5gemma/t5gemma_tokenizer.py | 7 +++ 5 files changed, 58 insertions(+), 12 deletions(-) diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index 8b6aa475e7..d4ca90dddb 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -599,6 +599,18 @@ T5Preprocessor as T5Preprocessor, ) from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer as T5Tokenizer +from keras_hub.src.models.t5gemma.t5gemma_backbone import ( + T5GemmaBackbone as T5GemmaBackbone, +) +from keras_hub.src.models.t5gemma.t5gemma_causal_lm import ( + T5GemmaCausalLM as T5GemmaCausalLM, +) +from keras_hub.src.models.t5gemma.t5gemma_causal_lm_preprocessor import ( + T5GemmaCausalLMPreprocessor as T5GemmaCausalLMPreprocessor, +) +from keras_hub.src.models.t5gemma.t5gemma_tokenizer import ( + T5GemmaTokenizer as T5GemmaTokenizer, +) from keras_hub.src.models.task import Task as Task from keras_hub.src.models.text_classifier import TextClassifier as Classifier from keras_hub.src.models.text_classifier import ( diff --git a/keras_hub/api/tokenizers/__init__.py b/keras_hub/api/tokenizers/__init__.py index 082078184f..2677d89ee7 100644 --- a/keras_hub/api/tokenizers/__init__.py +++ b/keras_hub/api/tokenizers/__init__.py @@ -87,6 +87,9 @@ SigLIPTokenizer as SigLIPTokenizer, ) from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer as T5Tokenizer +from keras_hub.src.models.t5gemma.t5gemma_tokenizer import ( + T5GemmaTokenizer as T5GemmaTokenizer, +) from keras_hub.src.models.whisper.whisper_tokenizer import ( WhisperTokenizer as WhisperTokenizer, ) diff --git a/keras_hub/src/models/t5gemma/t5gemma_attention.py b/keras_hub/src/models/t5gemma/t5gemma_attention.py index f063644757..39ecfbe98d 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_attention.py +++ b/keras_hub/src/models/t5gemma/t5gemma_attention.py @@ -390,6 +390,7 @@ def build(self, input_shape): kv_len = encoder_hidden_states_shape[1] attn_weights_shape = (None, self.num_attention_heads, q_len, kv_len) self.dropout_layer.build(attn_weights_shape) + self.softmax = keras.layers.Softmax(dtype="float32") self.built = True def call( @@ -431,9 +432,7 @@ def call( attn_weights += attention_mask attn_weights = keras.ops.cast( - keras.activations.softmax( - keras.ops.cast(attn_weights, "float32"), axis=-1 - ), + self.softmax(attn_weights), query_states.dtype, ) attn_weights = self.dropout_layer(attn_weights, training=training) diff --git a/keras_hub/src/models/t5gemma/t5gemma_causal_lm.py b/keras_hub/src/models/t5gemma/t5gemma_causal_lm.py index 9c53bae637..aea82ece75 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_causal_lm.py +++ b/keras_hub/src/models/t5gemma/t5gemma_causal_lm.py @@ -183,7 +183,8 @@ def call_with_cache( """ ( encoder_output, - past_key_values, + self_attention_past_key_values, + cross_attention_past_key_values, encoder_padding_mask, ) = cache hidden_states = self.backbone.token_embedding(token_ids) @@ -193,21 +194,26 @@ def call_with_cache( hidden_states = self.backbone.decoder_dropout(hidden_states) updated_key_values = [] for i, layer in enumerate(self.backbone.decoder_layers): - current_cache = past_key_values[:, i, ...] + self_attention_cache = self_attention_past_key_values[:, i, ...] + cross_attention_cache = cross_attention_past_key_values[:, i, ...] hidden_states, current_cache = layer( (hidden_states, encoder_output), self_attention_padding_mask=padding_mask, cross_attention_padding_mask=encoder_padding_mask, - self_attention_cache=current_cache, + self_attention_cache=self_attention_cache, + cross_attention_cache=cross_attention_cache, cache_update_index=cache_update_index, ) updated_key_values.append(current_cache) - past_key_values = keras.ops.stack(updated_key_values, axis=1) + self_attention_past_key_values = keras.ops.stack( + updated_key_values, axis=1 + ) hidden_states = self.backbone.decoder_norm(hidden_states) logits = self.backbone.token_embedding(hidden_states, reverse=True) cache = ( encoder_output, - past_key_values, + self_attention_past_key_values, + cross_attention_past_key_values, encoder_padding_mask, ) return logits, hidden_states, cache @@ -233,19 +239,38 @@ def _build_cache(self, token_ids, padding_mask): keras.ops.sqrt(self.backbone.hidden_dim), hidden_states.dtype ) hidden_states = self.backbone.decoder_dropout(hidden_states) - past_key_values = [] + # Cross-attention cache. + cross_attention_past_key_values = [] for layer in self.backbone.decoder_layers: + key_states = layer.cross_attn.key_dense(encoder_output) + key_states = keras.ops.transpose(key_states, (0, 2, 1, 3)) + value_states = layer.cross_attn.value_dense(encoder_output) + value_states = keras.ops.transpose(value_states, (0, 2, 1, 3)) + cross_attention_past_key_values.append( + keras.ops.stack((key_states, value_states), axis=1) + ) + cross_attention_past_key_values = keras.ops.stack( + cross_attention_past_key_values, axis=1 + ) + # Self-attention cache. + self_attention_past_key_values = [] + for i, layer in enumerate(self.backbone.decoder_layers): + cross_attention_cache = cross_attention_past_key_values[:, i, ...] hidden_states, kv_cache_for_layer = layer( (hidden_states, encoder_output), self_attention_padding_mask=padding_mask, cross_attention_padding_mask=padding_mask, + cross_attention_cache=cross_attention_cache, ) - past_key_values.append(kv_cache_for_layer) - past_key_values = keras.ops.stack(past_key_values, axis=1) + self_attention_past_key_values.append(kv_cache_for_layer) + self_attention_past_key_values = keras.ops.stack( + self_attention_past_key_values, axis=1 + ) hidden_states = self.backbone.decoder_norm(hidden_states) cache = ( encoder_output, - past_key_values, + self_attention_past_key_values, + cross_attention_past_key_values, padding_mask, ) return hidden_states, cache diff --git a/keras_hub/src/models/t5gemma/t5gemma_tokenizer.py b/keras_hub/src/models/t5gemma/t5gemma_tokenizer.py index 02f44c4220..f63f617a91 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_tokenizer.py +++ b/keras_hub/src/models/t5gemma/t5gemma_tokenizer.py @@ -1,9 +1,16 @@ +from keras_hub.src.api_export import keras_hub_export from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone from keras_hub.src.tokenizers.sentence_piece_tokenizer import ( SentencePieceTokenizer, ) +@keras_hub_export( + [ + "keras_hub.tokenizers.T5GemmaTokenizer", + "keras_hub.models.T5GemmaTokenizer", + ] +) class T5GemmaTokenizer(SentencePieceTokenizer): """T5Gemma tokenizer layer based on SentencePiece. From 1c7dc13821e0d682d31549c33f6f4b06a3a3a9bb Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Mon, 21 Jul 2025 11:13:33 +0400 Subject: [PATCH 03/19] refactor: Cleanup and replace incorrect T5LayerNorm with RMSNormalization (Gemma) --- .../src/models/t5gemma/t5gemma_attention.py | 432 ++++++------------ .../src/models/t5gemma/t5gemma_backbone.py | 6 +- .../src/models/t5gemma/t5gemma_causal_lm.py | 14 +- .../src/models/t5gemma/t5gemma_decoder.py | 49 +- .../src/models/t5gemma/t5gemma_encoder.py | 27 +- 5 files changed, 191 insertions(+), 337 deletions(-) diff --git a/keras_hub/src/models/t5gemma/t5gemma_attention.py b/keras_hub/src/models/t5gemma/t5gemma_attention.py index 39ecfbe98d..084b4b6957 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_attention.py +++ b/keras_hub/src/models/t5gemma/t5gemma_attention.py @@ -12,10 +12,9 @@ def repeat_kv(hidden_states, n_rep): """Repeats the key/value hidden states to match the number of query heads for Grouped Query Attention (GQA). - This function is used in `T5GemmaSelfAttention` and `T5GemmaCrossAttention` - to broadcast key and value states across multiple query heads when Grouped - Query Attention (GQA) is used (i.e., when `num_query_heads` > - `num_key_value_heads`). + This function is used in `T5GemmaAttention` to broadcast key and value + states across multiple query heads when Grouped Query Attention (GQA) is + used (i.e., when `num_query_heads` > `num_key_value_heads`). Args: hidden_states: Tensor, The key or value hidden states with shape @@ -38,11 +37,12 @@ def repeat_kv(hidden_states, n_rep): @keras.saving.register_keras_serializable(package="keras_hub") -class T5GemmaSelfAttention(CachedGemmaAttention): - """Self-attention block for the T5Gemma model. +class T5GemmaAttention(CachedGemmaAttention): + """A unified attention layer for T5Gemma that handles both self-attention + and cross-attention. - This layer performs self-attention with Rotary Positional Embeddings (RoPE) - and supports Grouped Query Attention (GQA). It is used in + This layer performs attention with optional Rotary Positional Embeddings + (RoPE) and supports Grouped Query Attention (GQA). It is used in `T5GemmaEncoderLayer` and `T5GemmaDecoderLayer`. Args: @@ -52,8 +52,11 @@ class T5GemmaSelfAttention(CachedGemmaAttention): can be less than `num_attention_heads`. query_pre_attn_scalar: float, Scalar to multiply queries by before attention. - attention_bias: bool, Whether to include bias in the query, key, value, - and output dense layers. + attention_bias: bool, Whether to include bias in the dense layers. + attention_type: str, The type of attention, either 'self' or 'cross'. + Defaults to 'self'. + cross_attention_hidden_size: int, optional, The dimensionality of + encoder hidden states for cross-attention. initializer_range: float, The range for the random normal initializer for kernel weights. Default is `0.02`. attention_dropout: float, The dropout rate applied to attention weights. @@ -61,7 +64,7 @@ class T5GemmaSelfAttention(CachedGemmaAttention): attn_logit_softcapping: float, optional, The softcapping value for attention logits. rope_max_wavelength: float, The maximum wavelength for Rotary Positional - Embeddings. Default is `10000.0`. + Embeddings. Default is `10000.0`. Only used for self-attention. **kwargs: Additional keyword arguments passed to the parent class. """ @@ -72,6 +75,8 @@ def __init__( num_key_value_heads, query_pre_attn_scalar, attention_bias, + attention_type="self", + cross_attention_hidden_size=None, initializer_range=0.02, attention_dropout=0.0, attn_logit_softcapping=None, @@ -89,135 +94,173 @@ def __init__( use_sliding_window_attention=False, **kwargs, ) - self.attention_dropout = attention_dropout + if attention_type not in ["self", "cross"]: + raise ValueError( + f"attention_type must be 'self' or 'cross', but got " + f"{attention_type}" + ) + self.attention_type = attention_type self.hidden_size = hidden_size + self.cross_attention_hidden_size = ( + cross_attention_hidden_size or hidden_size + ) self.query_pre_attn_scalar = query_pre_attn_scalar - self.initializer_range = initializer_range self.attention_bias = attention_bias + self.initializer_range = initializer_range + self.attention_dropout = attention_dropout + self.rope_max_wavelength = rope_max_wavelength self.num_key_value_groups = ( self.num_query_heads // self.num_key_value_heads ) self.scaling = self.query_pre_attn_scalar**-0.5 - self.rope_max_wavelength = rope_max_wavelength - self.rotary_embedding = RotaryEmbedding( - max_wavelength=self.rope_max_wavelength, - sequence_axis=2, - feature_axis=3, - name="rotary_embedding", - ) + if self.attention_type == "self": + self.rotary_embedding = RotaryEmbedding( + max_wavelength=self.rope_max_wavelength, + sequence_axis=2, + feature_axis=3, + name="rotary_embedding", + ) def build(self, input_shape): self._kernel_initializer = t5gemma_kernel_initializer( self.initializer_range ) + if self.attention_type == "cross": + hidden_states_shape, kv_states_shape = input_shape + else: + hidden_states_shape = input_shape + kv_states_shape = input_shape # Query projection layer. - self.hidden_dim = input_shape[-1] + self.hidden_dim = hidden_states_shape[-1] self.query_dense = keras.layers.EinsumDense( - equation="...a,abc->...bc", - output_shape=(self.num_query_heads, self.head_dim), + equation="btd,dnh->bnth", + output_shape=(self.num_query_heads, None, self.head_dim), kernel_initializer=clone_initializer(self._kernel_initializer), - bias_axes="bc" if self.attention_bias else None, + bias_axes="nh" if self.attention_bias else None, dtype=self.dtype_policy, name="query", ) - self.query_dense.build(input_shape) + self.query_dense.build(hidden_states_shape) # Key projection layer. self.key_dense = keras.layers.EinsumDense( - equation="...a,abc->...bc", - output_shape=(self.num_key_value_heads, self.head_dim), + equation="bsd,dkh->bksh", + output_shape=(self.num_key_value_heads, None, self.head_dim), kernel_initializer=clone_initializer(self._kernel_initializer), - bias_axes="bc" if self.attention_bias else None, + bias_axes="kh" if self.attention_bias else None, dtype=self.dtype_policy, name="key", ) - self.key_dense.build(input_shape) + self.key_dense.build(kv_states_shape) # Value projection layer. self.value_dense = keras.layers.EinsumDense( - equation="...a,abc->...bc", - output_shape=(self.num_key_value_heads, self.head_dim), + equation="bsd,dkh->bksh", + output_shape=(self.num_key_value_heads, None, self.head_dim), kernel_initializer=clone_initializer(self._kernel_initializer), - bias_axes="bc" if self.attention_bias else None, + bias_axes="kh" if self.attention_bias else None, dtype=self.dtype_policy, name="value", ) - self.value_dense.build(input_shape) + self.value_dense.build(kv_states_shape) # Output projection layer. self.output_dense = keras.layers.EinsumDense( - equation="...a,ab->...b", - output_shape=(self.hidden_dim,), + equation="bnth,nhd->btd", + output_shape=(None, self.hidden_dim), kernel_initializer=clone_initializer(self._kernel_initializer), - bias_axes="b" if self.attention_bias else None, + bias_axes="d" if self.attention_bias else None, dtype=self.dtype_policy, name="attention_output", ) self.output_dense.build( - (*input_shape[:-1], self.num_query_heads * self.head_dim) + ( + hidden_states_shape[0], + self.num_query_heads, + hidden_states_shape[1], + self.head_dim, + ) ) self.dropout_layer = keras.layers.Dropout( rate=self.attention_dropout, dtype=self.dtype_policy, ) - q_len = input_shape[1] - attn_weights_shape = (None, self.num_query_heads, q_len, q_len) - self.dropout_layer.build(attn_weights_shape) self.softmax = keras.layers.Softmax(dtype="float32") self.built = True def call( self, - hidden_states, + inputs, attention_mask=None, cache=None, cache_update_index=None, training=None, ): - query_states = self.query_dense(hidden_states) - query_states = keras.ops.transpose(query_states, (0, 2, 1, 3)) - key_states = self.key_dense(hidden_states) - key_states = keras.ops.transpose(key_states, (0, 2, 1, 3)) - value_states = self.value_dense(hidden_states) - value_states = keras.ops.transpose(value_states, (0, 2, 1, 3)) - start_index = 0 if cache_update_index is None else cache_update_index - query_states = self.rotary_embedding( - query_states, start_index=start_index - ) - key_states = self.rotary_embedding(key_states, start_index=start_index) - current_pass_cache = keras.ops.stack((key_states, value_states), axis=1) - if cache is not None: - if cache_update_index is None: + if self.attention_type == "cross": + if not isinstance(inputs, (list, tuple)) or len(inputs) != 2: raise ValueError( - "Both `cache` and `cache_update_index` must be " - "passed for caching." + "For cross-attention, `inputs` must be a list or tuple of " + "two tensors: `[hidden_states, encoder_hidden_states]`." ) - key_cache = cache[:, 0, ...] - value_cache = cache[:, 1, ...] - start = [0, 0, cache_update_index, 0] - key_states = keras.ops.slice_update(key_cache, start, key_states) - value_states = keras.ops.slice_update( - value_cache, start, value_states + hidden_states, kv_states = inputs + else: + hidden_states = inputs + kv_states = hidden_states + query_states = self.query_dense(hidden_states) + if self.attention_type == "cross": + if cache is not None: + key_states = cache[:, 0, ...] + value_states = cache[:, 1, ...] + else: + key_states = self.key_dense(kv_states) + value_states = self.value_dense(kv_states) + else: # Self-attention + key_states = self.key_dense(kv_states) + value_states = self.value_dense(kv_states) + start_index = ( + 0 if cache_update_index is None else cache_update_index ) - cache = keras.ops.stack((key_states, value_states), axis=1) - elif cache_update_index is not None: - raise ValueError( - "`cache_update_index` should not be set if `cache` is `None`." + query_states = self.rotary_embedding( + query_states, start_index=start_index ) - else: - cache = current_pass_cache + key_states = self.rotary_embedding( + key_states, start_index=start_index + ) + current_pass_cache = keras.ops.stack( + (key_states, value_states), axis=1 + ) + if cache is not None: + if cache_update_index is None: + raise ValueError( + "Both `cache` and `cache_update_index` must be passed " + "for self-attention caching." + ) + key_cache, value_cache = cache[:, 0, ...], cache[:, 1, ...] + start = [0, 0, cache_update_index, 0] + key_states = keras.ops.slice_update( + key_cache, start, key_states + ) + value_states = keras.ops.slice_update( + value_cache, start, value_states + ) + cache = keras.ops.stack((key_states, value_states), axis=1) + elif cache_update_index is not None: + raise ValueError( + "`cache_update_index` should not be set if `cache` is " + "`None`." + ) + else: + cache = current_pass_cache # Repeat key-value heads for GQA. key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = ( - keras.ops.matmul( - query_states, keras.ops.transpose(key_states, (0, 1, 3, 2)) - ) - * self.scaling + attn_weights = keras.ops.einsum( + "bnth,bnsh->bnts", query_states, key_states ) + attn_weights *= self.scaling if self.logit_soft_cap is not None: attn_weights = attn_weights / self.logit_soft_cap @@ -231,231 +274,32 @@ def call( query_states.dtype, ) attn_weights = self.dropout_layer(attn_weights, training=training) - attn_output = keras.ops.matmul(attn_weights, value_states) - attn_output = keras.ops.transpose(attn_output, (0, 2, 1, 3)) - attn_output = keras.ops.reshape( - attn_output, - ( - keras.ops.shape(hidden_states)[0], - -1, - self.num_query_heads * self.head_dim, - ), + attn_output = keras.ops.einsum( + "bnts,bnsh->bnth", attn_weights, value_states ) attn_output = self.output_dense(attn_output) - return (attn_output, attn_weights), cache - - def compute_output_shape(self, input_shape): - attn_output_shape = input_shape - q_len = input_shape[1] - attn_weights_shape = ( - input_shape[0], - self.num_query_heads, - q_len, - q_len, - ) - return attn_output_shape, attn_weights_shape - - def get_config(self): - config = super().get_config() - config.update( - { - "hidden_size": self.hidden_size, - "num_attention_heads": self.num_query_heads, - "num_key_value_heads": self.num_key_value_heads, - "query_pre_attn_scalar": self.query_pre_attn_scalar, - "attention_bias": self.attention_bias, - "initializer_range": self.initializer_range, - "attention_dropout": self.attention_dropout, - "attn_logit_softcapping": self.logit_soft_cap, - "rope_max_wavelength": self.rope_max_wavelength, - } - ) - return config - - -@keras.saving.register_keras_serializable(package="keras_hub") -class T5GemmaCrossAttention(keras.layers.Layer): - """Cross-attention block for the T5Gemma model. - - This layer performs cross-attention, where queries are derived from the - decoder hidden states and keys/values are from the encoder hidden states. - It supports Grouped Query Attention (GQA). It is used in - `T5GemmaDecoderLayer`. - - Args: - hidden_size: int, The dimensionality of the hidden states for queries - and output. - cross_attention_hidden_size: int, The dimensionality of the hidden - states from the encoder for keys and values. - num_attention_heads: int, The number of attention heads for queries. - num_key_value_heads: int, The number of key-value heads. For GQA, this - can be less than `num_attention_heads`. - query_pre_attn_scalar: float, Scalar to multiply queries by before - attention. - attention_bias: bool, Whether to include bias in the query, key, value, - and output dense layers. - initializer_range: float, The range for the random normal initializer - for kernel weights. Default is `0.02`. - attention_dropout: float, The dropout rate applied to attention weights. - Default is `0.0`. - attn_logit_softcapping: float, optional, The softcapping value for - attention logits. - **kwargs: Additional keyword arguments passed to the parent class. - """ - - def __init__( - self, - hidden_size, - cross_attention_hidden_size, - num_attention_heads, - num_key_value_heads, - query_pre_attn_scalar, - attention_bias, - initializer_range=0.02, - attention_dropout=0.0, - attn_logit_softcapping=None, - **kwargs, - ): - super().__init__(**kwargs) - self.hidden_size = hidden_size - self.cross_attention_hidden_size = ( - cross_attention_hidden_size or hidden_size - ) - self.num_attention_heads = num_attention_heads - self.num_key_value_heads = num_key_value_heads - self.query_pre_attn_scalar = query_pre_attn_scalar - self.initializer_range = initializer_range - self.attention_dropout = attention_dropout - self.attention_bias = attention_bias - self.attn_logit_softcapping = attn_logit_softcapping - - self.head_dim = self.hidden_size // self.num_attention_heads - self.num_key_value_groups = ( - self.num_attention_heads // self.num_key_value_heads - ) - self.scaling = self.query_pre_attn_scalar**-0.5 - - def build(self, input_shape): - hidden_states_shape, encoder_hidden_states_shape = input_shape - self.kernel_initializer = t5gemma_kernel_initializer( - self.initializer_range - ) - - # Query projection layer. - self.query_dense = keras.layers.EinsumDense( - equation="...a,abc->...bc", - output_shape=(self.num_attention_heads, self.head_dim), - kernel_initializer=clone_initializer(self.kernel_initializer), - bias_axes="bc" if self.attention_bias else None, - name="query", - ) - self.query_dense.build(hidden_states_shape) - cross_attn_proj_shape = ( - *encoder_hidden_states_shape[:-1], - self.cross_attention_hidden_size, - ) - - # Key projection layer. - self.key_dense = keras.layers.EinsumDense( - equation="...a,abc->...bc", - output_shape=(self.num_key_value_heads, self.head_dim), - kernel_initializer=clone_initializer(self.kernel_initializer), - bias_axes="bc" if self.attention_bias else None, - name="key", - ) - self.key_dense.build(cross_attn_proj_shape) - - # Value projection layer. - self.value_dense = keras.layers.EinsumDense( - equation="...a,abc->...bc", - output_shape=(self.num_key_value_heads, self.head_dim), - kernel_initializer=clone_initializer(self.kernel_initializer), - bias_axes="bc" if self.attention_bias else None, - name="value", - ) - self.value_dense.build(cross_attn_proj_shape) - - # Output projection layer. - self.output_dense = keras.layers.EinsumDense( - equation="...a,ab->...b", - output_shape=(self.hidden_size,), - kernel_initializer=clone_initializer(self.kernel_initializer), - bias_axes="b" if self.attention_bias else None, - name="attention_output", - ) - o_proj_input_shape = (*hidden_states_shape[:-1], self.hidden_size) - self.output_dense.build(o_proj_input_shape) - self.dropout_layer = keras.layers.Dropout(self.attention_dropout) - q_len = hidden_states_shape[1] - kv_len = encoder_hidden_states_shape[1] - attn_weights_shape = (None, self.num_attention_heads, q_len, kv_len) - self.dropout_layer.build(attn_weights_shape) - self.softmax = keras.layers.Softmax(dtype="float32") - self.built = True - - def call( - self, - inputs, - attention_mask=None, - cache=None, - training=None, - ): - hidden_states, encoder_hidden_states = inputs - batch_size, q_seq_len = keras.ops.shape(hidden_states)[:2] - query_states = self.query_dense(hidden_states) - query_states = keras.ops.transpose(query_states, (0, 2, 1, 3)) - if cache is not None: - key_states = cache[:, 0, ...] - value_states = cache[:, 1, ...] - else: - key_states = self.key_dense(encoder_hidden_states) - key_states = keras.ops.transpose(key_states, (0, 2, 1, 3)) - value_states = self.value_dense(encoder_hidden_states) - value_states = keras.ops.transpose(value_states, (0, 2, 1, 3)) - - # Repeat key-value heads for GQA. - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = ( - keras.ops.matmul( - query_states, keras.ops.transpose(key_states, (0, 1, 3, 2)) - ) - * self.scaling - ) - - if self.attn_logit_softcapping is not None: - attn_weights = attn_weights / self.attn_logit_softcapping - attn_weights = keras.ops.tanh(attn_weights) - attn_weights = attn_weights * self.attn_logit_softcapping - if attention_mask is not None: - attn_weights += attention_mask - - attn_weights = keras.ops.cast( - self.softmax(attn_weights), - query_states.dtype, - ) - attn_weights = self.dropout_layer(attn_weights, training=training) - attn_output = keras.ops.matmul(attn_weights, value_states) - attn_output = keras.ops.transpose(attn_output, (0, 2, 1, 3)) - attn_output = keras.ops.reshape( - attn_output, (batch_size, q_seq_len, -1) - ) - attn_output = self.output_dense(attn_output) - if cache is not None: - updated_cache = keras.ops.stack((key_states, value_states), axis=1) - return (attn_output, attn_weights), updated_cache - else: + if self.attention_type == "cross": + if cache is not None: + updated_cache = keras.ops.stack( + (key_states, value_states), axis=1 + ) + return (attn_output, attn_weights), updated_cache return attn_output, attn_weights + else: # Self-attention + return (attn_output, attn_weights), cache def compute_output_shape(self, input_shape): - hidden_states_shape, encoder_hidden_states_shape = input_shape + if self.attention_type == "cross": + hidden_states_shape, kv_states_shape = input_shape + else: + hidden_states_shape = input_shape + kv_states_shape = input_shape attn_output_shape = hidden_states_shape q_len = hidden_states_shape[1] - kv_len = encoder_hidden_states_shape[1] + kv_len = kv_states_shape[1] attn_weights_shape = ( hidden_states_shape[0], - self.num_attention_heads, + self.num_query_heads, q_len, kv_len, ) @@ -466,14 +310,16 @@ def get_config(self): config.update( { "hidden_size": self.hidden_size, - "cross_attention_hidden_size": self.cross_attention_hidden_size, - "num_attention_heads": self.num_attention_heads, + "num_attention_heads": self.num_query_heads, "num_key_value_heads": self.num_key_value_heads, "query_pre_attn_scalar": self.query_pre_attn_scalar, "attention_bias": self.attention_bias, + "attention_type": self.attention_type, + "cross_attention_hidden_size": self.cross_attention_hidden_size, "initializer_range": self.initializer_range, "attention_dropout": self.attention_dropout, - "attn_logit_softcapping": self.attn_logit_softcapping, + "attn_logit_softcapping": self.logit_soft_cap, + "rope_max_wavelength": self.rope_max_wavelength, } ) return config diff --git a/keras_hub/src/models/t5gemma/t5gemma_backbone.py b/keras_hub/src/models/t5gemma/t5gemma_backbone.py index 963b3b4e71..80240abaf2 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_backbone.py +++ b/keras_hub/src/models/t5gemma/t5gemma_backbone.py @@ -5,7 +5,7 @@ ReversibleEmbedding, ) from keras_hub.src.models.backbone import Backbone -from keras_hub.src.models.t5.t5_layer_norm import T5LayerNorm +from keras_hub.src.models.gemma.rms_normalization import RMSNormalization from keras_hub.src.models.t5gemma.t5gemma_decoder import T5GemmaDecoderLayer from keras_hub.src.models.t5gemma.t5gemma_encoder import T5GemmaEncoderLayer @@ -112,7 +112,7 @@ def __init__( ) for i in range(num_layers) ] - self.encoder_norm = T5LayerNorm(epsilon=rms_norm_eps) + self.encoder_norm = RMSNormalization(epsilon=rms_norm_eps) self.encoder_dropout = keras.layers.Dropout(dropout_rate) self.decoder_layers = [ T5GemmaDecoderLayer( @@ -136,7 +136,7 @@ def __init__( ) for i in range(num_layers) ] - self.decoder_norm = T5LayerNorm(epsilon=rms_norm_eps) + self.decoder_norm = RMSNormalization(epsilon=rms_norm_eps) self.decoder_dropout = keras.layers.Dropout(dropout_rate) # === Functional Model === diff --git a/keras_hub/src/models/t5gemma/t5gemma_causal_lm.py b/keras_hub/src/models/t5gemma/t5gemma_causal_lm.py index aea82ece75..76134fab85 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_causal_lm.py +++ b/keras_hub/src/models/t5gemma/t5gemma_causal_lm.py @@ -150,10 +150,14 @@ def __init__(self, backbone, preprocessor=None, **kwargs): # rather than "backbone.inputs" which is the flattened list of inputs. inputs = backbone.input sequence_output = backbone(inputs) - outputs = backbone.token_embedding(sequence_output, reverse=True) + logits = backbone.token_embedding(sequence_output, reverse=True) + if self.backbone.final_logit_softcapping is not None: + logits = logits / self.backbone.final_logit_softcapping + logits = keras.ops.tanh(logits) + logits = logits * self.backbone.final_logit_softcapping super().__init__( inputs=inputs, - outputs=outputs, + outputs=logits, **kwargs, ) @@ -210,6 +214,10 @@ def call_with_cache( ) hidden_states = self.backbone.decoder_norm(hidden_states) logits = self.backbone.token_embedding(hidden_states, reverse=True) + if self.backbone.final_logit_softcapping is not None: + logits = logits / self.backbone.final_logit_softcapping + logits = keras.ops.tanh(logits) + logits = logits * self.backbone.final_logit_softcapping cache = ( encoder_output, self_attention_past_key_values, @@ -243,9 +251,7 @@ def _build_cache(self, token_ids, padding_mask): cross_attention_past_key_values = [] for layer in self.backbone.decoder_layers: key_states = layer.cross_attn.key_dense(encoder_output) - key_states = keras.ops.transpose(key_states, (0, 2, 1, 3)) value_states = layer.cross_attn.value_dense(encoder_output) - value_states = keras.ops.transpose(value_states, (0, 2, 1, 3)) cross_attention_past_key_values.append( keras.ops.stack((key_states, value_states), axis=1) ) diff --git a/keras_hub/src/models/t5gemma/t5gemma_decoder.py b/keras_hub/src/models/t5gemma/t5gemma_decoder.py index 66d752162a..738c391872 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_decoder.py +++ b/keras_hub/src/models/t5gemma/t5gemma_decoder.py @@ -1,8 +1,7 @@ import keras -from keras_hub.src.models.t5.t5_layer_norm import T5LayerNorm -from keras_hub.src.models.t5gemma.t5gemma_attention import T5GemmaCrossAttention -from keras_hub.src.models.t5gemma.t5gemma_attention import T5GemmaSelfAttention +from keras_hub.src.models.gemma.rms_normalization import RMSNormalization +from keras_hub.src.models.t5gemma.t5gemma_attention import T5GemmaAttention from keras_hub.src.models.t5gemma.t5gemma_layers import T5GemmaMLP @@ -88,34 +87,36 @@ def __init__( ) # Self-attention. - self.self_attn = T5GemmaSelfAttention( - hidden_size, - num_attention_heads, - num_key_value_heads, - query_pre_attn_scalar, - attention_bias, + self.self_attn = T5GemmaAttention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + query_pre_attn_scalar=query_pre_attn_scalar, + attention_bias=attention_bias, + attention_type="self", initializer_range=initializer_range, attention_dropout=attention_dropout, attn_logit_softcapping=attn_logit_softcapping, rope_max_wavelength=self.rope_max_wavelength, ) - self.pre_self_attn_layernorm = T5LayerNorm(epsilon=rms_norm_eps) - self.post_self_attn_layernorm = T5LayerNorm(epsilon=rms_norm_eps) + self.pre_self_attn_layernorm = RMSNormalization(epsilon=rms_norm_eps) + self.post_self_attn_layernorm = RMSNormalization(epsilon=rms_norm_eps) # Cross-attention. - self.cross_attn = T5GemmaCrossAttention( - hidden_size, - cross_attention_hidden_size, - num_attention_heads, - num_key_value_heads, - query_pre_attn_scalar, - attention_bias, + self.cross_attn = T5GemmaAttention( + hidden_size=hidden_size, + cross_attention_hidden_size=cross_attention_hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + query_pre_attn_scalar=query_pre_attn_scalar, + attention_bias=attention_bias, + attention_type="cross", initializer_range=initializer_range, attention_dropout=attention_dropout, attn_logit_softcapping=attn_logit_softcapping, ) - self.pre_cross_attn_layernorm = T5LayerNorm(epsilon=rms_norm_eps) - self.post_cross_attn_layernorm = T5LayerNorm(epsilon=rms_norm_eps) + self.pre_cross_attn_layernorm = RMSNormalization(epsilon=rms_norm_eps) + self.post_cross_attn_layernorm = RMSNormalization(epsilon=rms_norm_eps) # MLP. self.mlp = T5GemmaMLP( @@ -125,8 +126,8 @@ def __init__( dropout_rate, initializer_range=initializer_range, ) - self.pre_feedforward_layernorm = T5LayerNorm(epsilon=rms_norm_eps) - self.post_feedforward_layernorm = T5LayerNorm(epsilon=rms_norm_eps) + self.pre_feedforward_layernorm = RMSNormalization(epsilon=rms_norm_eps) + self.post_feedforward_layernorm = RMSNormalization(epsilon=rms_norm_eps) self.dropout = keras.layers.Dropout(dropout_rate) @@ -193,7 +194,7 @@ def call( ) hidden_states = self.pre_self_attn_layernorm(hidden_states) (hidden_states, _), updated_self_attention_cache = self.self_attn( - hidden_states=hidden_states, + inputs=hidden_states, attention_mask=self_attention_mask, cache=self_attention_cache, cache_update_index=cache_update_index, @@ -211,7 +212,7 @@ def call( ) hidden_states = self.pre_cross_attn_layernorm(hidden_states) cross_attn_output = self.cross_attn( - [hidden_states, encoder_hidden_states], + inputs=[hidden_states, encoder_hidden_states], attention_mask=cross_attention_mask, cache=cross_attention_cache, training=training, diff --git a/keras_hub/src/models/t5gemma/t5gemma_encoder.py b/keras_hub/src/models/t5gemma/t5gemma_encoder.py index 99c75d302d..dcb3c620a4 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_encoder.py +++ b/keras_hub/src/models/t5gemma/t5gemma_encoder.py @@ -1,7 +1,7 @@ import keras -from keras_hub.src.models.t5.t5_layer_norm import T5LayerNorm -from keras_hub.src.models.t5gemma.t5gemma_attention import T5GemmaSelfAttention +from keras_hub.src.models.gemma.rms_normalization import RMSNormalization +from keras_hub.src.models.t5gemma.t5gemma_attention import T5GemmaAttention from keras_hub.src.models.t5gemma.t5gemma_layers import T5GemmaMLP @@ -81,19 +81,20 @@ def __init__( "`sliding_window` must be set for `sliding_attention` layer " "type." ) - self.self_attn = T5GemmaSelfAttention( - hidden_size, - num_attention_heads, - num_key_value_heads, - query_pre_attn_scalar, - attention_bias, + self.self_attn = T5GemmaAttention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + query_pre_attn_scalar=query_pre_attn_scalar, + attention_bias=attention_bias, + attention_type="self", initializer_range=initializer_range, attention_dropout=attention_dropout, attn_logit_softcapping=attn_logit_softcapping, rope_max_wavelength=self.rope_max_wavelength, ) - self.pre_self_attn_layernorm = T5LayerNorm(epsilon=rms_norm_eps) - self.post_self_attn_layernorm = T5LayerNorm(epsilon=rms_norm_eps) + self.pre_self_attn_layernorm = RMSNormalization(epsilon=rms_norm_eps) + self.post_self_attn_layernorm = RMSNormalization(epsilon=rms_norm_eps) self.mlp = T5GemmaMLP( hidden_size, @@ -102,8 +103,8 @@ def __init__( dropout_rate, initializer_range=initializer_range, ) - self.pre_feedforward_layernorm = T5LayerNorm(epsilon=rms_norm_eps) - self.post_feedforward_layernorm = T5LayerNorm(epsilon=rms_norm_eps) + self.pre_feedforward_layernorm = RMSNormalization(epsilon=rms_norm_eps) + self.post_feedforward_layernorm = RMSNormalization(epsilon=rms_norm_eps) self.dropout = keras.layers.Dropout(dropout_rate) def build(self, input_shape): @@ -153,7 +154,7 @@ def call( attention_mask = self._make_attention_mask(hidden_states, padding_mask) hidden_states = self.pre_self_attn_layernorm(hidden_states) (hidden_states, _), _ = self.self_attn( - hidden_states=hidden_states, + inputs=hidden_states, attention_mask=attention_mask, cache=cache, cache_update_index=cache_update_index, From 41910d32c6f9f4c15a1ac6eb084763cb12c82a9d Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Tue, 22 Jul 2025 14:49:24 +0400 Subject: [PATCH 04/19] fix: Numerics @ atol=1e-4 --- .../src/models/t5gemma/t5gemma_attention.py | 96 +++++---- .../src/models/t5gemma/t5gemma_backbone.py | 16 +- .../src/models/t5gemma/t5gemma_causal_lm.py | 186 ++++++++++-------- .../models/t5gemma/t5gemma_causal_lm_test.py | 18 +- .../src/models/t5gemma/t5gemma_decoder.py | 49 +++-- .../src/models/t5gemma/t5gemma_encoder.py | 34 +--- 6 files changed, 236 insertions(+), 163 deletions(-) diff --git a/keras_hub/src/models/t5gemma/t5gemma_attention.py b/keras_hub/src/models/t5gemma/t5gemma_attention.py index 084b4b6957..dd18aea21c 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_attention.py +++ b/keras_hub/src/models/t5gemma/t5gemma_attention.py @@ -204,18 +204,48 @@ def call( "two tensors: `[hidden_states, encoder_hidden_states]`." ) hidden_states, kv_states = inputs - else: - hidden_states = inputs - kv_states = hidden_states - query_states = self.query_dense(hidden_states) - if self.attention_type == "cross": + query_states = self.query_dense(hidden_states) if cache is not None: - key_states = cache[:, 0, ...] - value_states = cache[:, 1, ...] + if cache_update_index is not None: + raise ValueError( + "`cache_update_index` should not be set for " + "cross-attention caching." + ) + key_states, value_states = cache[:, 0, ...], cache[:, 1, ...] + updated_cache = cache else: key_states = self.key_dense(kv_states) value_states = self.value_dense(kv_states) + updated_cache = keras.ops.stack( + (key_states, value_states), axis=1 + ) + # Repeat key-value heads for GQA. + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + attn_weights = keras.ops.einsum( + "bnth,bnsh->bnts", query_states, key_states + ) + attn_weights *= self.scaling + if self.logit_soft_cap is not None: + attn_weights = attn_weights / self.logit_soft_cap + attn_weights = keras.ops.tanh(attn_weights) + attn_weights = attn_weights * self.logit_soft_cap + if attention_mask is not None: + attn_weights += attention_mask + attn_weights = keras.ops.cast( + self.softmax(attn_weights), + query_states.dtype, + ) + attn_weights = self.dropout_layer(attn_weights, training=training) + attn_output = keras.ops.einsum( + "bnts,bnsh->bnth", attn_weights, value_states + ) + attn_output = self.output_dense(attn_output) + return (attn_output, attn_weights), updated_cache else: # Self-attention + hidden_states = inputs + kv_states = hidden_states + query_states = self.query_dense(hidden_states) key_states = self.key_dense(kv_states) value_states = self.value_dense(kv_states) start_index = ( @@ -253,39 +283,31 @@ def call( else: cache = current_pass_cache - # Repeat key-value heads for GQA. - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) + # Repeat key-value heads for GQA. + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = keras.ops.einsum( - "bnth,bnsh->bnts", query_states, key_states - ) - attn_weights *= self.scaling + attn_weights = keras.ops.einsum( + "bnth,bnsh->bnts", query_states, key_states + ) + attn_weights *= self.scaling - if self.logit_soft_cap is not None: - attn_weights = attn_weights / self.logit_soft_cap - attn_weights = keras.ops.tanh(attn_weights) - attn_weights = attn_weights * self.logit_soft_cap - if attention_mask is not None: - attn_weights += attention_mask + if self.logit_soft_cap is not None: + attn_weights = attn_weights / self.logit_soft_cap + attn_weights = keras.ops.tanh(attn_weights) + attn_weights = attn_weights * self.logit_soft_cap + if attention_mask is not None: + attn_weights += attention_mask - attn_weights = keras.ops.cast( - self.softmax(attn_weights), - query_states.dtype, - ) - attn_weights = self.dropout_layer(attn_weights, training=training) - attn_output = keras.ops.einsum( - "bnts,bnsh->bnth", attn_weights, value_states - ) - attn_output = self.output_dense(attn_output) - if self.attention_type == "cross": - if cache is not None: - updated_cache = keras.ops.stack( - (key_states, value_states), axis=1 - ) - return (attn_output, attn_weights), updated_cache - return attn_output, attn_weights - else: # Self-attention + attn_weights = keras.ops.cast( + self.softmax(attn_weights), + query_states.dtype, + ) + attn_weights = self.dropout_layer(attn_weights, training=training) + attn_output = keras.ops.einsum( + "bnts,bnsh->bnth", attn_weights, value_states + ) + attn_output = self.output_dense(attn_output) return (attn_output, attn_weights), cache def compute_output_shape(self, input_shape): diff --git a/keras_hub/src/models/t5gemma/t5gemma_backbone.py b/keras_hub/src/models/t5gemma/t5gemma_backbone.py index 80240abaf2..32624dc545 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_backbone.py +++ b/keras_hub/src/models/t5gemma/t5gemma_backbone.py @@ -8,6 +8,10 @@ from keras_hub.src.models.gemma.rms_normalization import RMSNormalization from keras_hub.src.models.t5gemma.t5gemma_decoder import T5GemmaDecoderLayer from keras_hub.src.models.t5gemma.t5gemma_encoder import T5GemmaEncoderLayer +from keras_hub.src.models.t5gemma.t5gemma_layers import ( + t5gemma_kernel_initializer, +) +from keras_hub.src.utils.keras_utils import clone_initializer @keras_hub_export("keras_hub.models.T5GemmaBackbone") @@ -85,11 +89,19 @@ def __init__( rope_max_wavelength=10000.0, **kwargs, ): + self.kernel_initializer = t5gemma_kernel_initializer(initializer_range) + # === Layers === - self.token_embedding = ReversibleEmbedding( + self.token_embedding = keras.layers.Embedding( + input_dim=vocabulary_size, + output_dim=hidden_dim, + embeddings_initializer=clone_initializer(self.kernel_initializer), + ) + self.decoder_token_embedding = ReversibleEmbedding( input_dim=vocabulary_size, output_dim=hidden_dim, tie_weights=tie_word_embeddings, + embeddings_initializer=clone_initializer(self.kernel_initializer), ) self.encoder_layers = [ T5GemmaEncoderLayer( @@ -162,7 +174,7 @@ def __init__( encoder_output = self.encoder_dropout(encoder_output) # Decoder. - decoder_embeddings = self.token_embedding(token_id_input) + decoder_embeddings = self.decoder_token_embedding(token_id_input) decoder_embeddings = decoder_embeddings * keras.ops.cast( keras.ops.sqrt(hidden_dim), decoder_embeddings.dtype ) diff --git a/keras_hub/src/models/t5gemma/t5gemma_causal_lm.py b/keras_hub/src/models/t5gemma/t5gemma_causal_lm.py index 76134fab85..45659870f9 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_causal_lm.py +++ b/keras_hub/src/models/t5gemma/t5gemma_causal_lm.py @@ -150,7 +150,7 @@ def __init__(self, backbone, preprocessor=None, **kwargs): # rather than "backbone.inputs" which is the flattened list of inputs. inputs = backbone.input sequence_output = backbone(inputs) - logits = backbone.token_embedding(sequence_output, reverse=True) + logits = backbone.decoder_token_embedding(sequence_output, reverse=True) if self.backbone.final_logit_softcapping is not None: logits = logits / self.backbone.final_logit_softcapping logits = keras.ops.tanh(logits) @@ -161,8 +161,34 @@ def __init__(self, backbone, preprocessor=None, **kwargs): **kwargs, ) - def call_with_cache( - self, token_ids, padding_mask, cache, cache_update_index + def call_encoder(self, token_ids, padding_mask): + """Process inputs through the encoder stack.""" + encoder_embeddings = self.backbone.token_embedding(token_ids) + encoder_embeddings *= keras.ops.cast( + keras.ops.sqrt(self.backbone.hidden_dim), encoder_embeddings.dtype + ) + encoder_hidden_states = self.backbone.encoder_dropout( + encoder_embeddings, training=False + ) + for layer in self.backbone.encoder_layers: + encoder_hidden_states = layer( + encoder_hidden_states, padding_mask=padding_mask, training=False + ) + encoder_output = self.backbone.encoder_norm(encoder_hidden_states) + encoder_output = self.backbone.encoder_dropout( + encoder_output, training=False + ) + return encoder_output, padding_mask + + def call_decoder_with_cache( + self, + decoder_token_ids, + decoder_padding_mask, + cache_update_index, + self_attention_cache, + cross_attention_cache, + encoder_output, + encoder_padding_mask, ): """Forward pass of `T5GemmaCausalLM` with cache. @@ -185,101 +211,90 @@ def call_with_cache( the final hidden representation of the input tokens, and `cache` is the updated decoding cache. """ - ( - encoder_output, - self_attention_past_key_values, - cross_attention_past_key_values, - encoder_padding_mask, - ) = cache - hidden_states = self.backbone.token_embedding(token_ids) + hidden_states = self.backbone.decoder_token_embedding(decoder_token_ids) hidden_states *= keras.ops.cast( keras.ops.sqrt(self.backbone.hidden_dim), hidden_states.dtype ) - hidden_states = self.backbone.decoder_dropout(hidden_states) - updated_key_values = [] + hidden_states = self.backbone.decoder_dropout( + hidden_states, training=False + ) + updated_self_attention_cache = [] for i, layer in enumerate(self.backbone.decoder_layers): - self_attention_cache = self_attention_past_key_values[:, i, ...] - cross_attention_cache = cross_attention_past_key_values[:, i, ...] - hidden_states, current_cache = layer( + current_self_attention_cache = self_attention_cache[:, i, ...] + current_cross_attention_cache = cross_attention_cache[:, i, ...] + hidden_states, new_self_attention_cache_for_layer = layer( (hidden_states, encoder_output), - self_attention_padding_mask=padding_mask, + self_attention_padding_mask=decoder_padding_mask, cross_attention_padding_mask=encoder_padding_mask, - self_attention_cache=self_attention_cache, - cross_attention_cache=cross_attention_cache, + self_attention_cache=current_self_attention_cache, + cross_attention_cache=current_cross_attention_cache, cache_update_index=cache_update_index, + training=False, + ) + updated_self_attention_cache.append( + new_self_attention_cache_for_layer ) - updated_key_values.append(current_cache) - self_attention_past_key_values = keras.ops.stack( - updated_key_values, axis=1 + updated_self_attention_cache = keras.ops.stack( + updated_self_attention_cache, axis=1 ) hidden_states = self.backbone.decoder_norm(hidden_states) - logits = self.backbone.token_embedding(hidden_states, reverse=True) + logits = self.backbone.decoder_token_embedding( + hidden_states, reverse=True + ) if self.backbone.final_logit_softcapping is not None: logits = logits / self.backbone.final_logit_softcapping logits = keras.ops.tanh(logits) logits = logits * self.backbone.final_logit_softcapping - cache = ( - encoder_output, - self_attention_past_key_values, - cross_attention_past_key_values, - encoder_padding_mask, + return ( + logits, + hidden_states, + updated_self_attention_cache, + cross_attention_cache, ) - return logits, hidden_states, cache def _build_cache(self, token_ids, padding_mask): """Build an empty cache for use with `call_with_cache()`.""" - # Encoder. - encoder_embeddings = self.backbone.token_embedding(token_ids) - encoder_embeddings *= keras.ops.cast( - keras.ops.sqrt(self.backbone.hidden_dim), encoder_embeddings.dtype - ) - encoder_hidden_states = self.backbone.encoder_dropout( - encoder_embeddings + encoder_output, encoder_padding_mask = self.call_encoder( + token_ids, padding_mask ) - for layer in self.backbone.encoder_layers: - encoder_hidden_states = layer( - encoder_hidden_states, - padding_mask=padding_mask, - ) - encoder_output = self.backbone.encoder_norm(encoder_hidden_states) - hidden_states = self.backbone.token_embedding(token_ids) - hidden_states *= keras.ops.cast( - keras.ops.sqrt(self.backbone.hidden_dim), hidden_states.dtype - ) - hidden_states = self.backbone.decoder_dropout(hidden_states) - # Cross-attention cache. - cross_attention_past_key_values = [] + # Pre-compute cross-attention cache. + cross_attention_cache = [] for layer in self.backbone.decoder_layers: key_states = layer.cross_attn.key_dense(encoder_output) value_states = layer.cross_attn.value_dense(encoder_output) - cross_attention_past_key_values.append( + cross_attention_cache.append( keras.ops.stack((key_states, value_states), axis=1) ) - cross_attention_past_key_values = keras.ops.stack( - cross_attention_past_key_values, axis=1 + cross_attention_cache = keras.ops.stack(cross_attention_cache, axis=1) + # Seed the self-attention cache. + hidden_states = self.backbone.decoder_token_embedding(token_ids) + hidden_states *= keras.ops.cast( + keras.ops.sqrt(self.backbone.hidden_dim), hidden_states.dtype ) - # Self-attention cache. - self_attention_past_key_values = [] + hidden_states = self.backbone.decoder_dropout( + hidden_states, training=False + ) + # Seed the cache by running a forward pass on the prompt. + updated_self_attention_cache = [] for i, layer in enumerate(self.backbone.decoder_layers): - cross_attention_cache = cross_attention_past_key_values[:, i, ...] - hidden_states, kv_cache_for_layer = layer( + current_cross_attention_cache = cross_attention_cache[:, i, ...] + hidden_states, new_self_cache = layer( (hidden_states, encoder_output), self_attention_padding_mask=padding_mask, - cross_attention_padding_mask=padding_mask, - cross_attention_cache=cross_attention_cache, + cross_attention_padding_mask=encoder_padding_mask, + self_attention_cache=None, + cross_attention_cache=current_cross_attention_cache, + cache_update_index=None, + training=False, ) - self_attention_past_key_values.append(kv_cache_for_layer) - self_attention_past_key_values = keras.ops.stack( - self_attention_past_key_values, axis=1 + updated_self_attention_cache.append(new_self_cache) + self_attention_cache = keras.ops.stack( + updated_self_attention_cache, axis=1 ) hidden_states = self.backbone.decoder_norm(hidden_states) - cache = ( - encoder_output, - self_attention_past_key_values, - cross_attention_past_key_values, - padding_mask, - ) - return hidden_states, cache + cache = (self_attention_cache, cross_attention_cache) + extra_cache_info = (encoder_output, encoder_padding_mask) + return hidden_states, cache, extra_cache_info def generate_step(self, inputs, stop_token_ids=None): """A compilable generation function for a single batch of inputs. @@ -295,13 +310,14 @@ def generate_step(self, inputs, stop_token_ids=None): sequences have produced a new stop token, generation will stop. """ - # Create and seed cache with a single forward pass. - hidden_states, cache = self._build_cache( - token_ids=inputs["token_ids"], - padding_mask=inputs["padding_mask"], - ) token_ids = inputs["token_ids"] padding_mask = inputs["padding_mask"] + # Create and seed cache with a single forward pass. + hidden_states, cache, extra_cache_info = self._build_cache( + token_ids=token_ids, padding_mask=padding_mask + ) + self_attention_cache, cross_attention_cache = cache + encoder_output, encoder_padding_mask = extra_cache_info # Compute the lengths of all user inputted tokens ids. row_lengths = keras.ops.sum( keras.ops.cast(padding_mask, "int32"), axis=-1 @@ -310,25 +326,37 @@ def generate_step(self, inputs, stop_token_ids=None): index = keras.ops.min(row_lengths) def next(prompt, cache, index): + self_attention_cache, cross_attention_cache = cache # The cache index is the index of our previous token. cache_update_index = index - 1 batch_size = keras.ops.shape(prompt)[0] prompt = keras.ops.slice( prompt, [0, cache_update_index], [batch_size, 1] ) - prompt_padding_mask = keras.ops.ones_like(prompt, dtype="int32") - logits, _, cache = self.call_with_cache( - prompt, - prompt_padding_mask, - cache, - cache_update_index, + ( + logits, + _, + updated_self_attention_cache, + updated_cross_attention_cache, + ) = self.call_decoder_with_cache( + decoder_token_ids=prompt, + decoder_padding_mask=None, + cache_update_index=cache_update_index, + self_attention_cache=self_attention_cache, + cross_attention_cache=cross_attention_cache, + encoder_output=encoder_output, + encoder_padding_mask=encoder_padding_mask, + ) + cache = ( + updated_self_attention_cache, + updated_cross_attention_cache, ) return keras.ops.squeeze(logits, axis=1), None, cache token_ids = self.sampler( next=next, prompt=token_ids, - cache=cache, + cache=(self_attention_cache, cross_attention_cache), index=index, mask=padding_mask, stop_token_ids=stop_token_ids, diff --git a/keras_hub/src/models/t5gemma/t5gemma_causal_lm_test.py b/keras_hub/src/models/t5gemma/t5gemma_causal_lm_test.py index 7d27681b0a..8441bf3cc4 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_causal_lm_test.py +++ b/keras_hub/src/models/t5gemma/t5gemma_causal_lm_test.py @@ -90,11 +90,16 @@ def test_generate_strip_prompt(self): def test_early_stopping(self): causal_lm = T5GemmaCausalLM(**self.init_kwargs) - call_with_cache = causal_lm.call_with_cache + call_decoder_with_cache = causal_lm.call_decoder_with_cache def wrapper(*args, **kwargs): """Modify output logits to always favor end_token_id""" - logits, hidden_states, cache = call_with_cache(*args, **kwargs) + ( + logits, + hidden_states, + self_attention_cache, + cross_attention_cache, + ) = call_decoder_with_cache(*args, **kwargs) index = self.preprocessor.tokenizer.end_token_id update = ( keras.ops.ones( @@ -103,9 +108,14 @@ def wrapper(*args, **kwargs): * 1.0e9 ) logits = keras.ops.slice_update(logits, (0, 0, index), update) - return logits, hidden_states, cache + return ( + logits, + hidden_states, + self_attention_cache, + cross_attention_cache, + ) - with patch.object(causal_lm, "call_with_cache", wraps=wrapper): + with patch.object(causal_lm, "call_decoder_with_cache", wraps=wrapper): prompt = ["the quick brown fox", "the earth is round"] output = causal_lm.generate(prompt) # We should immediately abort and output the prompt. diff --git a/keras_hub/src/models/t5gemma/t5gemma_decoder.py b/keras_hub/src/models/t5gemma/t5gemma_decoder.py index 738c391872..ec47115078 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_decoder.py +++ b/keras_hub/src/models/t5gemma/t5gemma_decoder.py @@ -155,21 +155,43 @@ def build(self, input_shape): self.post_feedforward_layernorm.build(mlp_output_shape) self.built = True - def _make_self_attention_mask(self, hidden_states, padding_mask): - seq_len = keras.ops.shape(hidden_states)[1] - q_indices = keras.ops.arange(0, seq_len, dtype="int32")[:, None] - kv_indices = keras.ops.arange(0, seq_len, dtype="int32")[None, :] - causal_mask = kv_indices <= q_indices + def _make_self_attention_mask( + self, + hidden_states, + padding_mask, + cache=None, + cache_update_index=None, + ): + if cache is not None: + q_len = keras.ops.shape(hidden_states)[1] + kv_len = keras.ops.shape(cache)[3] + q_indices = ( + keras.ops.arange(0, q_len, dtype="int32") + cache_update_index + ) + kv_indices = keras.ops.arange(0, kv_len, dtype="int32") + else: + q_len = kv_len = keras.ops.shape(hidden_states)[1] + q_indices = keras.ops.arange(0, q_len, dtype="int32") + kv_indices = keras.ops.arange(0, kv_len, dtype="int32") + # Create the causal mask. + causal_mask = kv_indices[None, :] <= q_indices[:, None] + # Apply sliding window if applicable. if self.layer_type == "sliding_attention": - sliding_mask = (q_indices - self.sliding_window) <= kv_indices + sliding_mask = ( + q_indices[:, None] - self.sliding_window + ) <= kv_indices[None, :] causal_mask = keras.ops.logical_and(causal_mask, sliding_mask) + # Combine with padding mask. final_mask = causal_mask[None, None, :, :] if padding_mask is not None: - padding_mask_4d = padding_mask[:, None, None, :] + padding_mask_slice = padding_mask[:, :kv_len] + padding_mask_4d = padding_mask_slice[:, None, None, :] final_mask = keras.ops.logical_and(final_mask, padding_mask_4d) return (1.0 - keras.ops.cast(final_mask, hidden_states.dtype)) * -1e9 def _make_cross_attention_mask(self, hidden_states, padding_mask): + if padding_mask is None: + return None bidirectional_mask = padding_mask[:, None, None, :] additive_bidirectional_mask = ( 1.0 - keras.ops.cast(bidirectional_mask, hidden_states.dtype) @@ -190,7 +212,10 @@ def call( # Self Attention. residual = hidden_states self_attention_mask = self._make_self_attention_mask( - hidden_states, self_attention_padding_mask + hidden_states, + self_attention_padding_mask, + cache=self_attention_cache, + cache_update_index=cache_update_index, ) hidden_states = self.pre_self_attn_layernorm(hidden_states) (hidden_states, _), updated_self_attention_cache = self.self_attn( @@ -208,19 +233,15 @@ def call( # Cross Attention. residual = hidden_states cross_attention_mask = self._make_cross_attention_mask( - hidden_states, cross_attention_padding_mask + encoder_hidden_states, cross_attention_padding_mask ) hidden_states = self.pre_cross_attn_layernorm(hidden_states) - cross_attn_output = self.cross_attn( + (hidden_states, _), _ = self.cross_attn( inputs=[hidden_states, encoder_hidden_states], attention_mask=cross_attention_mask, cache=cross_attention_cache, training=training, ) - if cross_attention_cache is not None: - (hidden_states, _), _ = cross_attn_output - else: - hidden_states, _ = cross_attn_output hidden_states = self.post_cross_attn_layernorm(hidden_states) hidden_states = residual + self.dropout( diff --git a/keras_hub/src/models/t5gemma/t5gemma_encoder.py b/keras_hub/src/models/t5gemma/t5gemma_encoder.py index dcb3c620a4..266acfac9d 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_encoder.py +++ b/keras_hub/src/models/t5gemma/t5gemma_encoder.py @@ -109,45 +109,27 @@ def __init__( def build(self, input_shape): self.pre_self_attn_layernorm.build(input_shape) - current_shape = input_shape - self.self_attn.build(current_shape) - attn_output_shape = self.self_attn.compute_output_shape(current_shape)[ - 0 - ] + self.self_attn.build(input_shape) + attn_output_shape = self.self_attn.compute_output_shape(input_shape)[0] self.post_self_attn_layernorm.build(attn_output_shape) - current_shape = attn_output_shape - self.dropout.build(current_shape) - self.pre_feedforward_layernorm.build(current_shape) - self.mlp.build(current_shape) - current_shape = self.mlp.compute_output_shape(current_shape) - self.post_feedforward_layernorm.build(current_shape) + self.dropout.build(attn_output_shape) + self.pre_feedforward_layernorm.build(attn_output_shape) + self.mlp.build(attn_output_shape) + mlp_output_shape = self.mlp.compute_output_shape(attn_output_shape) + self.post_feedforward_layernorm.build(mlp_output_shape) self.built = True def _make_attention_mask(self, hidden_states, padding_mask): - seq_len = keras.ops.shape(hidden_states)[1] attention_mask = padding_mask[:, None, None, :] additive_mask = ( 1.0 - keras.ops.cast(attention_mask, hidden_states.dtype) ) * -1e9 - if self.layer_type == "sliding_attention": - q_indices = keras.ops.arange(0, seq_len, dtype="int32")[:, None] - kv_indices = keras.ops.arange(0, seq_len, dtype="int32")[None, :] - window_mask = (q_indices - self.sliding_window < kv_indices) & ( - kv_indices < q_indices + self.sliding_window - ) - window_mask = window_mask[None, None, :, :] - window_additive_mask = ( - 1.0 - keras.ops.cast(window_mask, hidden_states.dtype) - ) * -1e9 - additive_mask = additive_mask + window_additive_mask return additive_mask def call( self, hidden_states, padding_mask=None, - cache=None, - cache_update_index=None, training=None, ): residual = hidden_states @@ -156,8 +138,6 @@ def call( (hidden_states, _), _ = self.self_attn( inputs=hidden_states, attention_mask=attention_mask, - cache=cache, - cache_update_index=cache_update_index, training=training, ) hidden_states = self.post_self_attn_layernorm(hidden_states) From a8eb53c6676704f888c753a85ca422eca5ed00e0 Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Wed, 23 Jul 2025 14:04:52 +0400 Subject: [PATCH 05/19] refactor: Refactor T5Gemma decoder cache handling --- .../src/models/t5gemma/t5gemma_backbone.py | 31 ++++ .../src/models/t5gemma/t5gemma_causal_lm.py | 133 +++++++++--------- .../models/t5gemma/t5gemma_causal_lm_test.py | 6 +- .../src/models/t5gemma/t5gemma_decoder.py | 32 +++-- 4 files changed, 121 insertions(+), 81 deletions(-) diff --git a/keras_hub/src/models/t5gemma/t5gemma_backbone.py b/keras_hub/src/models/t5gemma/t5gemma_backbone.py index 32624dc545..3fb41318f1 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_backbone.py +++ b/keras_hub/src/models/t5gemma/t5gemma_backbone.py @@ -63,6 +63,37 @@ class T5GemmaBackbone(Backbone): Embeddings. Default is `10000.0`. **kwargs: Additional keyword arguments passed to the parent `Backbone` class. + + Examples: + ```python + import numpy as np + import keras + from keras_hub.models import T5GemmaBackbone + + input_data = { + "token_ids": np.ones(shape=(1, 12), dtype="int32"), + "padding_mask": np.array( + [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]], dtype="int32" + ), + } + + # Randomly initialized T5Gemma backbone with custom config. + model = T5GemmaBackbone( + vocabulary_size=32000, + num_layers=4, + num_attention_heads=4, + num_key_value_heads=2, + hidden_dim=256, + intermediate_dim=512, + dropout_rate=0.1, + rms_norm_eps=1e-6, + query_pre_attn_scalar=1.0, + attention_bias=False, + hidden_activation="gelu_approximate", + layer_types=["full_attention"] * 4, + ) + output = model(input_data) + ``` """ def __init__( diff --git a/keras_hub/src/models/t5gemma/t5gemma_causal_lm.py b/keras_hub/src/models/t5gemma/t5gemma_causal_lm.py index 45659870f9..d6c20fdaee 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_causal_lm.py +++ b/keras_hub/src/models/t5gemma/t5gemma_causal_lm.py @@ -184,26 +184,29 @@ def call_decoder_with_cache( self, decoder_token_ids, decoder_padding_mask, + cache, cache_update_index, - self_attention_cache, - cross_attention_cache, encoder_output, encoder_padding_mask, ): - """Forward pass of `T5GemmaCausalLM` with cache. + """Forward pass of `T5GemmaCausalLM`'s decoder with cache. - `call_with_cache` adds an additional forward pass for the model for - autoregressive inference. Unlike calling the model directly, this method - allows caching previous key/value Tensors in the attention layers, - and avoids recomputing the outputs of seen tokens. + `call_decoder_with_cache` adds an additional forward pass for the model + for autoregressive inference. Unlike calling the model directly, this + method allows caching previous key/value Tensors in the attention + layers, and avoids recomputing the outputs of seen tokens. Args: - token_ids: A dense int Tensor with shape `(batch_size, max_length)`. - padding_mask: A dense int Tensor with shape `(batch_size, - max_length)`. + decoder_token_ids: A dense int Tensor with shape + `(batch_size, max_length)`. The token ids for the decoder. + decoder_padding_mask: A dense int Tensor with shape `(batch_size, + max_length)`. The padding mask for the decoder. cache: A dense float Tensor, the cache of key and value states. cache_update_index: int, or int Tensor. The index of the current token being processed in the whole sequence. + encoder_output: A dense float Tensor. The output of the encoder. + encoder_padding_mask: A dense int Tensor. The padding mask for + the encoder output. Returns: A `(logits, hidden_states, cache)` tuple. Where `logits` is the @@ -211,6 +214,7 @@ def call_decoder_with_cache( the final hidden representation of the input tokens, and `cache` is the updated decoding cache. """ + self_attention_cache, cross_attention_cache = cache hidden_states = self.backbone.decoder_token_embedding(decoder_token_ids) hidden_states *= keras.ops.cast( keras.ops.sqrt(self.backbone.hidden_dim), hidden_states.dtype @@ -218,24 +222,38 @@ def call_decoder_with_cache( hidden_states = self.backbone.decoder_dropout( hidden_states, training=False ) - updated_self_attention_cache = [] + # Every decoder layer has a separate cache for the self-attention layer + # and the cross-attention layer. We update all of them separately. + updated_self_attention_caches = [] + updated_cross_attention_caches = [] for i, layer in enumerate(self.backbone.decoder_layers): - current_self_attention_cache = self_attention_cache[:, i, ...] - current_cross_attention_cache = cross_attention_cache[:, i, ...] - hidden_states, new_self_attention_cache_for_layer = layer( + layer_self_cache = ( + self_attention_cache[:, i, ...] + if self_attention_cache is not None + else None + ) + layer_cross_cache = ( + cross_attention_cache[:, i, ...] + if cross_attention_cache is not None + else None + ) + layer_cache = (layer_self_cache, layer_cross_cache) + hidden_states, updated_layer_cache = layer( (hidden_states, encoder_output), self_attention_padding_mask=decoder_padding_mask, cross_attention_padding_mask=encoder_padding_mask, - self_attention_cache=current_self_attention_cache, - cross_attention_cache=current_cross_attention_cache, + cache=layer_cache, cache_update_index=cache_update_index, training=False, ) - updated_self_attention_cache.append( - new_self_attention_cache_for_layer - ) - updated_self_attention_cache = keras.ops.stack( - updated_self_attention_cache, axis=1 + new_self_cache, new_cross_cache = updated_layer_cache + updated_self_attention_caches.append(new_self_cache) + updated_cross_attention_caches.append(new_cross_cache) + self_attention_cache = keras.ops.stack( + updated_self_attention_caches, axis=1 + ) + cross_attention_cache = keras.ops.stack( + updated_cross_attention_caches, axis=1 ) hidden_states = self.backbone.decoder_norm(hidden_states) logits = self.backbone.decoder_token_embedding( @@ -248,8 +266,7 @@ def call_decoder_with_cache( return ( logits, hidden_states, - updated_self_attention_cache, - cross_attention_cache, + (self_attention_cache, cross_attention_cache), ) def _build_cache(self, token_ids, padding_mask): @@ -257,42 +274,30 @@ def _build_cache(self, token_ids, padding_mask): encoder_output, encoder_padding_mask = self.call_encoder( token_ids, padding_mask ) - # Pre-compute cross-attention cache. - cross_attention_cache = [] - for layer in self.backbone.decoder_layers: - key_states = layer.cross_attn.key_dense(encoder_output) - value_states = layer.cross_attn.value_dense(encoder_output) - cross_attention_cache.append( - keras.ops.stack((key_states, value_states), axis=1) - ) - cross_attention_cache = keras.ops.stack(cross_attention_cache, axis=1) - # Seed the self-attention cache. - hidden_states = self.backbone.decoder_token_embedding(token_ids) - hidden_states *= keras.ops.cast( - keras.ops.sqrt(self.backbone.hidden_dim), hidden_states.dtype + batch_size = keras.ops.shape(token_ids)[0] + num_layers = self.backbone.num_layers + num_kv_heads = self.backbone.num_key_value_heads + head_dim = self.backbone.hidden_dim // self.backbone.num_attention_heads + self_cache_shape = ( + batch_size, + num_layers, + 2, + num_kv_heads, + keras.ops.shape(token_ids)[1], + head_dim, ) - hidden_states = self.backbone.decoder_dropout( - hidden_states, training=False + self_attention_cache = keras.ops.zeros( + self_cache_shape, dtype=self.compute_dtype ) - # Seed the cache by running a forward pass on the prompt. - updated_self_attention_cache = [] - for i, layer in enumerate(self.backbone.decoder_layers): - current_cross_attention_cache = cross_attention_cache[:, i, ...] - hidden_states, new_self_cache = layer( - (hidden_states, encoder_output), - self_attention_padding_mask=padding_mask, - cross_attention_padding_mask=encoder_padding_mask, - self_attention_cache=None, - cross_attention_cache=current_cross_attention_cache, - cache_update_index=None, - training=False, - ) - updated_self_attention_cache.append(new_self_cache) - self_attention_cache = keras.ops.stack( - updated_self_attention_cache, axis=1 + cross_attention_cache = None + _, hidden_states, cache = self.call_decoder_with_cache( + decoder_token_ids=token_ids, + decoder_padding_mask=padding_mask, + cache=(self_attention_cache, cross_attention_cache), + cache_update_index=0, + encoder_output=encoder_output, + encoder_padding_mask=encoder_padding_mask, ) - hidden_states = self.backbone.decoder_norm(hidden_states) - cache = (self_attention_cache, cross_attention_cache) extra_cache_info = (encoder_output, encoder_padding_mask) return hidden_states, cache, extra_cache_info @@ -316,7 +321,6 @@ def generate_step(self, inputs, stop_token_ids=None): hidden_states, cache, extra_cache_info = self._build_cache( token_ids=token_ids, padding_mask=padding_mask ) - self_attention_cache, cross_attention_cache = cache encoder_output, encoder_padding_mask = extra_cache_info # Compute the lengths of all user inputted tokens ids. row_lengths = keras.ops.sum( @@ -326,7 +330,6 @@ def generate_step(self, inputs, stop_token_ids=None): index = keras.ops.min(row_lengths) def next(prompt, cache, index): - self_attention_cache, cross_attention_cache = cache # The cache index is the index of our previous token. cache_update_index = index - 1 batch_size = keras.ops.shape(prompt)[0] @@ -336,27 +339,21 @@ def next(prompt, cache, index): ( logits, _, - updated_self_attention_cache, - updated_cross_attention_cache, + updated_cache, ) = self.call_decoder_with_cache( decoder_token_ids=prompt, decoder_padding_mask=None, cache_update_index=cache_update_index, - self_attention_cache=self_attention_cache, - cross_attention_cache=cross_attention_cache, + cache=cache, encoder_output=encoder_output, encoder_padding_mask=encoder_padding_mask, ) - cache = ( - updated_self_attention_cache, - updated_cross_attention_cache, - ) - return keras.ops.squeeze(logits, axis=1), None, cache + return keras.ops.squeeze(logits, axis=1), None, updated_cache token_ids = self.sampler( next=next, prompt=token_ids, - cache=(self_attention_cache, cross_attention_cache), + cache=cache, index=index, mask=padding_mask, stop_token_ids=stop_token_ids, diff --git a/keras_hub/src/models/t5gemma/t5gemma_causal_lm_test.py b/keras_hub/src/models/t5gemma/t5gemma_causal_lm_test.py index 8441bf3cc4..efbc9b6b52 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_causal_lm_test.py +++ b/keras_hub/src/models/t5gemma/t5gemma_causal_lm_test.py @@ -97,8 +97,7 @@ def wrapper(*args, **kwargs): ( logits, hidden_states, - self_attention_cache, - cross_attention_cache, + cache, ) = call_decoder_with_cache(*args, **kwargs) index = self.preprocessor.tokenizer.end_token_id update = ( @@ -111,8 +110,7 @@ def wrapper(*args, **kwargs): return ( logits, hidden_states, - self_attention_cache, - cross_attention_cache, + cache, ) with patch.object(causal_lm, "call_decoder_with_cache", wraps=wrapper): diff --git a/keras_hub/src/models/t5gemma/t5gemma_decoder.py b/keras_hub/src/models/t5gemma/t5gemma_decoder.py index ec47115078..517e2d7622 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_decoder.py +++ b/keras_hub/src/models/t5gemma/t5gemma_decoder.py @@ -203,12 +203,14 @@ def call( inputs, self_attention_padding_mask=None, cross_attention_padding_mask=None, - self_attention_cache=None, - cross_attention_cache=None, + cache=None, cache_update_index=None, training=None, ): hidden_states, encoder_hidden_states = inputs + self_attention_cache, cross_attention_cache = ( + cache if cache is not None else (None, None) + ) # Self Attention. residual = hidden_states self_attention_mask = self._make_self_attention_mask( @@ -236,7 +238,7 @@ def call( encoder_hidden_states, cross_attention_padding_mask ) hidden_states = self.pre_cross_attn_layernorm(hidden_states) - (hidden_states, _), _ = self.cross_attn( + (hidden_states, _), updated_cross_attention_cache = self.cross_attn( inputs=[hidden_states, encoder_hidden_states], attention_mask=cross_attention_mask, cache=cross_attention_cache, @@ -256,20 +258,32 @@ def call( hidden_states = residual + self.dropout( hidden_states, training=training ) - return hidden_states, updated_self_attention_cache + updated_cache = ( + updated_self_attention_cache, + updated_cross_attention_cache, + ) + return hidden_states, updated_cache def compute_output_shape(self, input_shape): - hidden_states_shape, _ = input_shape - batch_size, seq_len, _ = hidden_states_shape + hidden_states_shape, encoder_hidden_states_shape = input_shape + batch_size, dec_seq_len, _ = hidden_states_shape + _, enc_seq_len, _ = encoder_hidden_states_shape head_dim = self.hidden_size // self.num_attention_heads - cache_shape = ( + self_cache_shape = ( + batch_size, + 2, + self.num_key_value_heads, + dec_seq_len, + head_dim, + ) + cross_cache_shape = ( batch_size, 2, self.num_key_value_heads, - seq_len, + enc_seq_len, head_dim, ) - return hidden_states_shape, cache_shape + return hidden_states_shape, (self_cache_shape, cross_cache_shape) def get_config(self): config = super().get_config() From 95f563b835b8dceece5521ee6163482c22cea731 Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Wed, 23 Jul 2025 19:22:59 +0400 Subject: [PATCH 06/19] feat: Add checkpoint conversion script --- .../src/models/t5gemma/t5gemma_attention.py | 4 +- .../src/models/t5gemma/t5gemma_backbone.py | 5 + .../models/t5gemma/t5gemma_backbone_test.py | 1 + .../models/t5gemma/t5gemma_causal_lm_test.py | 1 + .../src/models/t5gemma/t5gemma_decoder.py | 4 + .../src/models/t5gemma/t5gemma_encoder.py | 4 + .../convert_t5gemma_checkpoints.py | 335 ++++++++++++++++++ 7 files changed, 353 insertions(+), 1 deletion(-) create mode 100644 tools/checkpoint_conversion/convert_t5gemma_checkpoints.py diff --git a/keras_hub/src/models/t5gemma/t5gemma_attention.py b/keras_hub/src/models/t5gemma/t5gemma_attention.py index dd18aea21c..94116ce02a 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_attention.py +++ b/keras_hub/src/models/t5gemma/t5gemma_attention.py @@ -75,6 +75,7 @@ def __init__( num_key_value_heads, query_pre_attn_scalar, attention_bias, + head_dim, attention_type="self", cross_attention_hidden_size=None, initializer_range=0.02, @@ -84,7 +85,7 @@ def __init__( **kwargs, ): super().__init__( - head_dim=hidden_size // num_attention_heads, + head_dim=head_dim, num_query_heads=num_attention_heads, num_key_value_heads=num_key_value_heads, kernel_initializer=t5gemma_kernel_initializer(initializer_range), @@ -332,6 +333,7 @@ def get_config(self): config.update( { "hidden_size": self.hidden_size, + "head_dim": self.head_dim, "num_attention_heads": self.num_query_heads, "num_key_value_heads": self.num_key_value_heads, "query_pre_attn_scalar": self.query_pre_attn_scalar, diff --git a/keras_hub/src/models/t5gemma/t5gemma_backbone.py b/keras_hub/src/models/t5gemma/t5gemma_backbone.py index 3fb41318f1..d5a63f3bad 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_backbone.py +++ b/keras_hub/src/models/t5gemma/t5gemma_backbone.py @@ -110,6 +110,7 @@ def __init__( attention_bias, hidden_activation, layer_types, + head_dim, tie_word_embeddings=True, initializer_range=0.02, attention_dropout=0.0, @@ -144,6 +145,7 @@ def __init__( attention_bias=attention_bias, intermediate_size=intermediate_dim, hidden_activation=hidden_activation, + head_dim=head_dim, dropout_rate=dropout_rate, initializer_range=initializer_range, attention_dropout=attention_dropout, @@ -169,6 +171,7 @@ def __init__( hidden_activation=hidden_activation, dropout_rate=dropout_rate, initializer_range=initializer_range, + head_dim=head_dim, attention_dropout=attention_dropout, layer_type=layer_types[i], sliding_window=sliding_window, @@ -249,6 +252,7 @@ def __init__( self.attn_logit_softcapping = attn_logit_softcapping self.final_logit_softcapping = final_logit_softcapping self.rope_max_wavelength = rope_max_wavelength + self.head_dim = head_dim def get_config(self): config = super().get_config() @@ -265,6 +269,7 @@ def get_config(self): "tie_word_embeddings": self.tie_word_embeddings, "query_pre_attn_scalar": self.query_pre_attn_scalar, "attention_bias": self.attention_bias, + "head_dim": self.head_dim, "hidden_activation": self.hidden_activation, "layer_types": self.layer_types, "initializer_range": self.initializer_range, diff --git a/keras_hub/src/models/t5gemma/t5gemma_backbone_test.py b/keras_hub/src/models/t5gemma/t5gemma_backbone_test.py index 84413cec85..171a233b6f 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_backbone_test.py +++ b/keras_hub/src/models/t5gemma/t5gemma_backbone_test.py @@ -14,6 +14,7 @@ def setUp(self): "num_layers": 2, "num_attention_heads": 4, "num_key_value_heads": 2, + "head_dim": 8, "dropout_rate": 0.1, "rms_norm_eps": 1e-6, "tie_word_embeddings": True, diff --git a/keras_hub/src/models/t5gemma/t5gemma_causal_lm_test.py b/keras_hub/src/models/t5gemma/t5gemma_causal_lm_test.py index efbc9b6b52..003c6211f7 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_causal_lm_test.py +++ b/keras_hub/src/models/t5gemma/t5gemma_causal_lm_test.py @@ -32,6 +32,7 @@ def setUp(self): num_attention_heads=2, num_key_value_heads=1, dropout_rate=0.0, + head_dim=8, rms_norm_eps=1e-6, tie_word_embeddings=False, query_pre_attn_scalar=1.0, diff --git a/keras_hub/src/models/t5gemma/t5gemma_decoder.py b/keras_hub/src/models/t5gemma/t5gemma_decoder.py index 517e2d7622..f14c6dcb11 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_decoder.py +++ b/keras_hub/src/models/t5gemma/t5gemma_decoder.py @@ -53,6 +53,7 @@ def __init__( intermediate_size, hidden_activation, dropout_rate, + head_dim, initializer_range, attention_dropout, layer_type, @@ -63,6 +64,7 @@ def __init__( **kwargs, ): super().__init__(**kwargs) + self.head_dim = head_dim self.hidden_size = hidden_size self.rms_norm_eps = rms_norm_eps self.num_attention_heads = num_attention_heads @@ -93,6 +95,7 @@ def __init__( num_key_value_heads=num_key_value_heads, query_pre_attn_scalar=query_pre_attn_scalar, attention_bias=attention_bias, + head_dim=self.head_dim, attention_type="self", initializer_range=initializer_range, attention_dropout=attention_dropout, @@ -110,6 +113,7 @@ def __init__( num_key_value_heads=num_key_value_heads, query_pre_attn_scalar=query_pre_attn_scalar, attention_bias=attention_bias, + head_dim=self.head_dim, attention_type="cross", initializer_range=initializer_range, attention_dropout=attention_dropout, diff --git a/keras_hub/src/models/t5gemma/t5gemma_encoder.py b/keras_hub/src/models/t5gemma/t5gemma_encoder.py index 266acfac9d..5327fd5ce7 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_encoder.py +++ b/keras_hub/src/models/t5gemma/t5gemma_encoder.py @@ -53,6 +53,7 @@ def __init__( initializer_range, attention_dropout, layer_type, + head_dim, attn_logit_softcapping=None, sliding_window=None, rope_max_wavelength=10000.0, @@ -73,6 +74,7 @@ def __init__( self.layer_type = layer_type self.sliding_window = sliding_window self.rope_max_wavelength = rope_max_wavelength + self.head_dim = head_dim if ( self.layer_type == "sliding_attention" and self.sliding_window is None @@ -87,6 +89,7 @@ def __init__( num_key_value_heads=num_key_value_heads, query_pre_attn_scalar=query_pre_attn_scalar, attention_bias=attention_bias, + head_dim=self.head_dim, attention_type="self", initializer_range=initializer_range, attention_dropout=attention_dropout, @@ -163,6 +166,7 @@ def get_config(self): { "hidden_size": self.hidden_size, "rms_norm_eps": self.rms_norm_eps, + "head_dim": self.head_dim, "num_attention_heads": self.num_attention_heads, "num_key_value_heads": self.num_key_value_heads, "query_pre_attn_scalar": self.query_pre_attn_scalar, diff --git a/tools/checkpoint_conversion/convert_t5gemma_checkpoints.py b/tools/checkpoint_conversion/convert_t5gemma_checkpoints.py new file mode 100644 index 0000000000..6f52cfe0c7 --- /dev/null +++ b/tools/checkpoint_conversion/convert_t5gemma_checkpoints.py @@ -0,0 +1,335 @@ +""" +T5Gemma weight conversion script. + +This script converts checkpoints from a Hugging Face T5Gemma model to a +KerasHub T5Gemma model. + +To run, first install the dependencies: +``` +pip install keras-core keras-nlp tensorflow-text +pip install transformers huggingface-hub sentencepiece absl-py torch +``` + +Then, log in to Hugging Face: +``` +huggingface-cli login +``` + +Finally, run the script to convert the weights: +``` +python convert_t5gemma_checkpoints.py --preset t5gemma_b_b_prefixlm_it +``` +""" + +import os + +import absl +import huggingface_hub +import numpy as np +import torch +import transformers + +from keras_hub.src.models.t5gemma.t5gemma_causal_lm import T5GemmaCausalLM +from keras_hub.src.models.t5gemma.t5gemma_causal_lm_preprocessor import ( + T5GemmaCausalLMPreprocessor, +) +from keras_hub.src.models.t5gemma.t5gemma_tokenizer import T5GemmaTokenizer + +PRESET_MAP = { + "t5gemma_s_s_ul2": "google/t5gemma-s-s-ul2", + "t5gemma_s_s_prefixlm": "google/t5gemma-s-s-prefixlm", + "t5gemma_s_s_ul2_it": "google/t5gemma-s-s-ul2-it", + "t5gemma_s_s_prefixlm_it": "google/t5gemma-s-s-prefixlm-it", + "t5gemma_b_b_ul2": "google/t5gemma-b-b-ul2", + "t5gemma_b_b_prefixlm": "google/t5gemma-b-b-prefixlm", + "t5gemma_b_b_ul2_it": "google/t5gemma-b-b-ul2-it", + "t5gemma_b_b_prefixlm_it": "google/t5gemma-b-b-prefixlm-it", + "t5gemma_l_l_ul2": "google/t5gemma-l-l-ul2", + "t5gemma_l_l_prefixlm": "google/t5gemma-l-l-prefixlm", + "t5gemma_l_l_ul2_it": "google/t5gemma-l-l-ul2-it", + "t5gemma_l_l_prefixlm_it": "google/t5gemma-l-l-prefixlm-it", + "t5gemma_ml_ml_ul2": "google/t5gemma-ml-ml-ul2", + "t5gemma_ml_ml_prefixlm": "google/t5gemma-ml-ml-prefixlm", + "t5gemma_ml_ml_ul2_it": "google/t5gemma-ml-ml-ul2-it", + "t5gemma_ml_ml_prefixlm_it": "google/t5gemma-ml-ml-prefixlm-it", +} +EXTRACT_DIR = "./model_t5gemma" +FLAGS = absl.flags.FLAGS +absl.flags.DEFINE_string( + "preset", + "t5gemma_b_b_prefixlm_it", + f"Must be one of {','.join(PRESET_MAP.keys())}.", +) + + +def download_hf_model(hf_model_name): + print(f"ā¬‡ļø Downloading Hugging Face model '{hf_model_name}'...") + hf_model_dir = huggingface_hub.snapshot_download( + repo_id=hf_model_name, + allow_patterns=["*.json", "*.safetensors", "tokenizer.model"], + local_dir=EXTRACT_DIR, + local_dir_use_symlinks=False, + ) + print(f"āœ… Model downloaded to: {hf_model_dir}") + return hf_model_dir + + +def convert_model(hf_model, preprocessor): + # NOTE: The decoder and encoder are symmetrical in the config, so they are + # interchangeable. + decoder_config = hf_model.config.decoder + if decoder_config.hidden_activation == "gelu_pytorch_tanh": + decoder_config.hidden_activation = "gelu_approximate" + keras_backbone = T5GemmaCausalLM.backbone_cls( + vocabulary_size=decoder_config.vocab_size, + hidden_dim=decoder_config.hidden_size, + intermediate_dim=decoder_config.intermediate_size, + num_layers=decoder_config.num_hidden_layers, + num_attention_heads=decoder_config.num_attention_heads, + num_key_value_heads=decoder_config.num_key_value_heads, + dropout_rate=decoder_config.dropout_rate, + rms_norm_eps=decoder_config.rms_norm_eps, + query_pre_attn_scalar=decoder_config.query_pre_attn_scalar, + tie_word_embeddings=getattr( + hf_model.config, "tie_word_embeddings", True + ), + head_dim=decoder_config.head_dim, + attention_bias=decoder_config.attention_bias, + hidden_activation=decoder_config.hidden_activation, + layer_types=decoder_config.layer_types, + initializer_range=decoder_config.initializer_range, + attention_dropout=decoder_config.attention_dropout, + sliding_window=decoder_config.sliding_window, + cross_attention_hidden_size=decoder_config.cross_attention_hidden_size, + attn_logit_softcapping=decoder_config.attn_logit_softcapping, + final_logit_softcapping=decoder_config.final_logit_softcapping, + rope_max_wavelength=decoder_config.rope_theta, + ) + keras_model = T5GemmaCausalLM( + backbone=keras_backbone, preprocessor=preprocessor + ) + print("āœ… Keras model instantiated.") + return keras_model + + +def convert_tokenizer(hf_model_dir): + print("šŸ—£ļø Converting tokenizer...") + tokenizer_path = os.path.join(hf_model_dir, "tokenizer.model") + keras_tokenizer = T5GemmaTokenizer(proto=tokenizer_path) + print("āœ… Tokenizer converted.") + return keras_tokenizer + + +def convert_weights(keras_model, hf_model): + print("šŸ‹ļø Converting weights...") + hf_wts = hf_model.state_dict() + keras_backbone = keras_model.backbone + hidden_dim = keras_backbone.hidden_dim + num_attention_heads = keras_backbone.num_attention_heads + num_key_value_heads = keras_backbone.num_key_value_heads + head_dim = keras_backbone.head_dim + # Token Embeddings. + keras_backbone.token_embedding.embeddings.assign( + hf_wts["encoder.embed_tokens.weight"] + ) + keras_backbone.decoder_token_embedding.embeddings.assign( + hf_wts["decoder.embed_tokens.weight"] + ) + + # Encoder. + keras_backbone.encoder_norm.scale.assign(hf_wts["encoder.norm.weight"]) + for i in range(keras_backbone.num_layers): + encoder_layer = keras_backbone.get_layer(f"encoder_layer_{i}") + hf_prefix = f"encoder.layers.{i}" + + # Self-attention. + q_w = hf_wts[f"{hf_prefix}.self_attn.q_proj.weight"] + k_w = hf_wts[f"{hf_prefix}.self_attn.k_proj.weight"] + v_w = hf_wts[f"{hf_prefix}.self_attn.v_proj.weight"] + o_w = hf_wts[f"{hf_prefix}.self_attn.o_proj.weight"] + + encoder_layer.self_attn.query_dense.kernel.assign( + q_w.T.reshape(hidden_dim, num_attention_heads, head_dim).numpy() + ) + encoder_layer.self_attn.key_dense.kernel.assign( + k_w.T.reshape(hidden_dim, num_key_value_heads, head_dim).numpy() + ) + encoder_layer.self_attn.value_dense.kernel.assign( + v_w.T.reshape(hidden_dim, num_key_value_heads, head_dim).numpy() + ) + encoder_layer.self_attn.output_dense.kernel.assign( + o_w.T.reshape(num_attention_heads, head_dim, hidden_dim).numpy() + ) + + # MLP. + encoder_layer.mlp.gate_proj.kernel.assign( + hf_wts[f"{hf_prefix}.mlp.gate_proj.weight"].T.numpy() + ) + encoder_layer.mlp.up_proj.kernel.assign( + hf_wts[f"{hf_prefix}.mlp.up_proj.weight"].T.numpy() + ) + encoder_layer.mlp.down_proj.kernel.assign( + hf_wts[f"{hf_prefix}.mlp.down_proj.weight"].T.numpy() + ) + + # Layer norm. + encoder_layer.pre_self_attn_layernorm.scale.assign( + hf_wts[f"{hf_prefix}.pre_self_attn_layernorm.weight"] + ) + encoder_layer.post_self_attn_layernorm.scale.assign( + hf_wts[f"{hf_prefix}.post_self_attn_layernorm.weight"] + ) + encoder_layer.pre_feedforward_layernorm.scale.assign( + hf_wts[f"{hf_prefix}.pre_feedforward_layernorm.weight"] + ) + encoder_layer.post_feedforward_layernorm.scale.assign( + hf_wts[f"{hf_prefix}.post_feedforward_layernorm.weight"] + ) + + # Decoder. + keras_backbone.decoder_norm.scale.assign(hf_wts["decoder.norm.weight"]) + for i in range(keras_backbone.num_layers): + decoder_layer = keras_backbone.get_layer(f"decoder_layer_{i}") + hf_prefix = f"decoder.layers.{i}" + + # Self-attention. + q_w = hf_wts[f"{hf_prefix}.self_attn.q_proj.weight"] + k_w = hf_wts[f"{hf_prefix}.self_attn.k_proj.weight"] + v_w = hf_wts[f"{hf_prefix}.self_attn.v_proj.weight"] + o_w = hf_wts[f"{hf_prefix}.self_attn.o_proj.weight"] + decoder_layer.self_attn.query_dense.kernel.assign( + q_w.T.reshape(hidden_dim, num_attention_heads, head_dim).numpy() + ) + decoder_layer.self_attn.key_dense.kernel.assign( + k_w.T.reshape(hidden_dim, num_key_value_heads, head_dim).numpy() + ) + decoder_layer.self_attn.value_dense.kernel.assign( + v_w.T.reshape(hidden_dim, num_key_value_heads, head_dim).numpy() + ) + decoder_layer.self_attn.output_dense.kernel.assign( + o_w.T.reshape(num_attention_heads, head_dim, hidden_dim).numpy() + ) + + # Cross-attention. + q_w = hf_wts[f"{hf_prefix}.cross_attn.q_proj.weight"] + k_w = hf_wts[f"{hf_prefix}.cross_attn.k_proj.weight"] + v_w = hf_wts[f"{hf_prefix}.cross_attn.v_proj.weight"] + o_w = hf_wts[f"{hf_prefix}.cross_attn.o_proj.weight"] + decoder_layer.cross_attn.query_dense.kernel.assign( + q_w.T.reshape(hidden_dim, num_attention_heads, head_dim).numpy() + ) + decoder_layer.cross_attn.key_dense.kernel.assign( + k_w.T.reshape(hidden_dim, num_key_value_heads, head_dim).numpy() + ) + decoder_layer.cross_attn.value_dense.kernel.assign( + v_w.T.reshape(hidden_dim, num_key_value_heads, head_dim).numpy() + ) + decoder_layer.cross_attn.output_dense.kernel.assign( + o_w.T.reshape(num_attention_heads, head_dim, hidden_dim).numpy() + ) + + # MLP. + decoder_layer.mlp.gate_proj.kernel.assign( + hf_wts[f"{hf_prefix}.mlp.gate_proj.weight"].T.numpy() + ) + decoder_layer.mlp.up_proj.kernel.assign( + hf_wts[f"{hf_prefix}.mlp.up_proj.weight"].T.numpy() + ) + decoder_layer.mlp.down_proj.kernel.assign( + hf_wts[f"{hf_prefix}.mlp.down_proj.weight"].T.numpy() + ) + + # Layer norm. + decoder_layer.pre_self_attn_layernorm.scale.assign( + hf_wts[f"{hf_prefix}.pre_self_attn_layernorm.weight"] + ) + decoder_layer.post_self_attn_layernorm.scale.assign( + hf_wts[f"{hf_prefix}.post_self_attn_layernorm.weight"] + ) + decoder_layer.pre_cross_attn_layernorm.scale.assign( + hf_wts[f"{hf_prefix}.pre_cross_attn_layernorm.weight"] + ) + decoder_layer.post_cross_attn_layernorm.scale.assign( + hf_wts[f"{hf_prefix}.post_cross_attn_layernorm.weight"] + ) + decoder_layer.pre_feedforward_layernorm.scale.assign( + hf_wts[f"{hf_prefix}.pre_feedforward_layernorm.weight"] + ) + decoder_layer.post_feedforward_layernorm.scale.assign( + hf_wts[f"{hf_prefix}.post_feedforward_layernorm.weight"] + ) + print("āœ… Weights converted.") + + +def validate_output(hf_model, keras_model, hf_tokenizer, keras_tokenizer): + hf_model.eval() + print("šŸ”Ž Validating tokenizer outputs...") + # Example sentence. + test_sentence = "What is the fastest land animal?" + hf_tokens = hf_tokenizer(test_sentence, return_tensors="pt")["input_ids"][ + 0 + ].tolist() + keras_tokens = keras_tokenizer.tokenize(test_sentence).numpy().tolist() + print(f"šŸ”¶ Test Sentence: '{test_sentence}'") + print(f"šŸ”¶ Hugging Face Tokens: {hf_tokens}") + print(f"šŸ”¶ Keras Tokens: {keras_tokens}") + assert hf_tokens == keras_tokens, "Tokenizer outputs do not match!" + print("āœ… Tokenizer outputs are consistent.") + print("šŸ”Ž Validating numeric outputs...") + input_ids_np = np.ones((1, 10), dtype="int32") + attention_mask_np = np.ones((1, 10), dtype="int32") + keras_inputs = { + "token_ids": input_ids_np, + "padding_mask": attention_mask_np, + } + hf_input_ids = torch.from_numpy(input_ids_np) + hf_attention_mask = torch.from_numpy(attention_mask_np) + hf_decoder_input_ids = hf_input_ids.clone() + hf_outputs = hf_model( + input_ids=hf_input_ids, + attention_mask=hf_attention_mask, + decoder_input_ids=hf_decoder_input_ids, + ) + hf_final_hidden_states = hf_outputs.last_hidden_state.detach().numpy() + print("\nšŸ”Ž Validating final hidden states...") + keras_final_hidden_states = keras_model.backbone.predict(keras_inputs) + final_difference = np.mean( + np.abs(hf_final_hidden_states - keras_final_hidden_states) + ) + print(f"šŸ”¶ Keras final output shape: {keras_final_hidden_states.shape}") + print(f"šŸ”¶ HF final output shape: {hf_final_hidden_states.shape}") + print(f"šŸ”¶ Mean absolute difference: {final_difference:.6e}") + assert final_difference < 1e-4, "Final output difference is too high!" + print("āœ… Final hidden states are consistent.") + + +def main(_): + preset = FLAGS.preset + print(f"šŸš€ Starting conversion for preset: {preset}") + + hf_model_name = PRESET_MAP[preset] + hf_model_dir = download_hf_model(hf_model_name) + + print("🧩 Loading Hugging Face model and tokenizer...") + hf_model = transformers.T5GemmaModel.from_pretrained(hf_model_dir) + hf_tokenizer = transformers.AutoTokenizer.from_pretrained(hf_model_dir) + print("āœ… Hugging Face model and tokenizer loaded.") + + keras_tokenizer = convert_tokenizer(hf_model_dir) + + keras_preprocessor = T5GemmaCausalLMPreprocessor( + tokenizer=keras_tokenizer, + ) + keras_model = convert_model(hf_model, keras_preprocessor) + convert_weights(keras_model, hf_model) + validate_output(hf_model, keras_model, hf_tokenizer, keras_tokenizer) + + print(f"šŸ’¾ Saving Keras model and tokenizer to preset '{preset}'...") + keras_model.save_to_preset(preset) + keras_tokenizer.save_to_preset(preset) + print("āœ… Preset saved successfully.") + print("šŸŽ‰ Conversion complete!") + + +if __name__ == "__main__": + absl.app.run(main) From afb98451bcb3b8b8d7c8e95089d32374c6890d27 Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Thu, 24 Jul 2025 09:37:32 +0400 Subject: [PATCH 07/19] nit: Precise compute_output_shape methods; document head_dim --- keras_hub/src/models/t5gemma/t5gemma_attention.py | 10 +++++++++- keras_hub/src/models/t5gemma/t5gemma_backbone.py | 2 ++ keras_hub/src/models/t5gemma/t5gemma_causal_lm.py | 1 + keras_hub/src/models/t5gemma/t5gemma_decoder.py | 5 +++-- keras_hub/src/models/t5gemma/t5gemma_encoder.py | 5 ++++- 5 files changed, 19 insertions(+), 4 deletions(-) diff --git a/keras_hub/src/models/t5gemma/t5gemma_attention.py b/keras_hub/src/models/t5gemma/t5gemma_attention.py index 94116ce02a..6e0259919a 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_attention.py +++ b/keras_hub/src/models/t5gemma/t5gemma_attention.py @@ -53,6 +53,7 @@ class T5GemmaAttention(CachedGemmaAttention): query_pre_attn_scalar: float, Scalar to multiply queries by before attention. attention_bias: bool, Whether to include bias in the dense layers. + head_dim: int, The dimensionality of each attention head. attention_type: str, The type of attention, either 'self' or 'cross'. Defaults to 'self'. cross_attention_hidden_size: int, optional, The dimensionality of @@ -326,7 +327,14 @@ def compute_output_shape(self, input_shape): q_len, kv_len, ) - return attn_output_shape, attn_weights_shape + cache_shape = ( + hidden_states_shape[0], # batch + 2, # key and value + self.num_key_value_heads, + kv_len, + self.head_dim, + ) + return (attn_output_shape, attn_weights_shape), cache_shape def get_config(self): config = super().get_config() diff --git a/keras_hub/src/models/t5gemma/t5gemma_backbone.py b/keras_hub/src/models/t5gemma/t5gemma_backbone.py index d5a63f3bad..13a771f890 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_backbone.py +++ b/keras_hub/src/models/t5gemma/t5gemma_backbone.py @@ -44,6 +44,7 @@ class T5GemmaBackbone(Backbone): attention layer for each encoder/decoder layer. Each element can be either `"sliding_attention"` or `"full_attention"`. For example, `["full_attention", "sliding_attention", ...]`. + head_dim: int, The dimensionality of each attention head. tie_word_embeddings: bool, Whether to tie input and output word embeddings. Default is `True`. initializer_range: float, The range for the random normal initializer. @@ -85,6 +86,7 @@ class T5GemmaBackbone(Backbone): num_key_value_heads=2, hidden_dim=256, intermediate_dim=512, + head_dim=64, dropout_rate=0.1, rms_norm_eps=1e-6, query_pre_attn_scalar=1.0, diff --git a/keras_hub/src/models/t5gemma/t5gemma_causal_lm.py b/keras_hub/src/models/t5gemma/t5gemma_causal_lm.py index d6c20fdaee..a8adcc2d09 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_causal_lm.py +++ b/keras_hub/src/models/t5gemma/t5gemma_causal_lm.py @@ -122,6 +122,7 @@ class T5GemmaCausalLM(CausalLM): num_key_value_heads=2, hidden_dim=256, intermediate_dim=512, + head_dim=64, dropout_rate=0.1, rms_norm_eps=1e-6, query_pre_attn_scalar=1.0, diff --git a/keras_hub/src/models/t5gemma/t5gemma_decoder.py b/keras_hub/src/models/t5gemma/t5gemma_decoder.py index f14c6dcb11..b64c25ff26 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_decoder.py +++ b/keras_hub/src/models/t5gemma/t5gemma_decoder.py @@ -28,6 +28,7 @@ class T5GemmaDecoderLayer(keras.layers.Layer): hidden_activation: str, The activation function used in the feed-forward network. dropout_rate: float, The dropout rate applied after attention and MLP. + head_dim: int, The dimensionality of each attention head. initializer_range: float, The range for the random normal initializer. attention_dropout: float, The dropout rate applied to attention weights. layer_type: str, Type of attention layer, e.g., `"sliding_attention"`. @@ -142,7 +143,7 @@ def build(self, input_shape): self.self_attn.build(current_shape) attn_output_shape = self.self_attn.compute_output_shape(current_shape)[ 0 - ] + ][0] self.post_self_attn_layernorm.build(attn_output_shape) current_shape = attn_output_shape self.dropout.build(current_shape) @@ -150,7 +151,7 @@ def build(self, input_shape): self.cross_attn.build([current_shape, encoder_hidden_states_shape]) attn_output_shape = self.cross_attn.compute_output_shape( [current_shape, encoder_hidden_states_shape] - )[0] + )[0][0] self.post_cross_attn_layernorm.build(attn_output_shape) current_shape = attn_output_shape self.pre_feedforward_layernorm.build(current_shape) diff --git a/keras_hub/src/models/t5gemma/t5gemma_encoder.py b/keras_hub/src/models/t5gemma/t5gemma_encoder.py index 5327fd5ce7..cf28c37140 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_encoder.py +++ b/keras_hub/src/models/t5gemma/t5gemma_encoder.py @@ -30,6 +30,7 @@ class T5GemmaEncoderLayer(keras.layers.Layer): initializer_range: float, The range for the random normal initializer. attention_dropout: float, The dropout rate applied to attention weights. layer_type: str, Type of attention layer, e.g., `"sliding_attention"`. + head_dim: int, The dimensionality of each attention head. attn_logit_softcapping: float, optional, The softcapping value for attention logits. sliding_window: int, optional, The window size for sliding attention. @@ -113,7 +114,9 @@ def __init__( def build(self, input_shape): self.pre_self_attn_layernorm.build(input_shape) self.self_attn.build(input_shape) - attn_output_shape = self.self_attn.compute_output_shape(input_shape)[0] + attn_output_shape = self.self_attn.compute_output_shape(input_shape)[0][ + 0 + ] self.post_self_attn_layernorm.build(attn_output_shape) self.dropout.build(attn_output_shape) self.pre_feedforward_layernorm.build(attn_output_shape) From 5be64381d6ce54edef638498b7f19a15df3fd35a Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Thu, 24 Jul 2025 12:03:55 +0400 Subject: [PATCH 08/19] nit: Propagate dtypes --- .../src/models/t5gemma/t5gemma_attention.py | 5 +++ .../src/models/t5gemma/t5gemma_backbone.py | 18 +++++++--- .../src/models/t5gemma/t5gemma_decoder.py | 36 ++++++++++++++----- .../src/models/t5gemma/t5gemma_encoder.py | 27 ++++++++++---- .../src/models/t5gemma/t5gemma_layers.py | 12 +++++-- 5 files changed, 78 insertions(+), 20 deletions(-) diff --git a/keras_hub/src/models/t5gemma/t5gemma_attention.py b/keras_hub/src/models/t5gemma/t5gemma_attention.py index 6e0259919a..20a5321ef6 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_attention.py +++ b/keras_hub/src/models/t5gemma/t5gemma_attention.py @@ -66,6 +66,8 @@ class T5GemmaAttention(CachedGemmaAttention): attention logits. rope_max_wavelength: float, The maximum wavelength for Rotary Positional Embeddings. Default is `10000.0`. Only used for self-attention. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for model computations and weights. **kwargs: Additional keyword arguments passed to the parent class. """ @@ -83,6 +85,7 @@ def __init__( attention_dropout=0.0, attn_logit_softcapping=None, rope_max_wavelength=10000.0, + dtype=None, **kwargs, ): super().__init__( @@ -94,6 +97,7 @@ def __init__( dropout=attention_dropout, query_head_dim_normalize=False, use_sliding_window_attention=False, + dtype=dtype, **kwargs, ) if attention_type not in ["self", "cross"]: @@ -121,6 +125,7 @@ def __init__( sequence_axis=2, feature_axis=3, name="rotary_embedding", + dtype=self.dtype_policy, ) def build(self, input_shape): diff --git a/keras_hub/src/models/t5gemma/t5gemma_backbone.py b/keras_hub/src/models/t5gemma/t5gemma_backbone.py index 13a771f890..41f2cb0c8e 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_backbone.py +++ b/keras_hub/src/models/t5gemma/t5gemma_backbone.py @@ -62,6 +62,10 @@ class T5GemmaBackbone(Backbone): final logits. rope_max_wavelength: float, The maximum wavelength for Rotary Positional Embeddings. Default is `10000.0`. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for model computations and weights. Note that some computations, + such as softmax and layer normalization, will always be done at + float32 precision regardless of dtype. **kwargs: Additional keyword arguments passed to the parent `Backbone` class. @@ -121,6 +125,7 @@ def __init__( attn_logit_softcapping=None, final_logit_softcapping=None, rope_max_wavelength=10000.0, + dtype=None, **kwargs, ): self.kernel_initializer = t5gemma_kernel_initializer(initializer_range) @@ -130,12 +135,14 @@ def __init__( input_dim=vocabulary_size, output_dim=hidden_dim, embeddings_initializer=clone_initializer(self.kernel_initializer), + dtype=dtype, ) self.decoder_token_embedding = ReversibleEmbedding( input_dim=vocabulary_size, output_dim=hidden_dim, tie_weights=tie_word_embeddings, embeddings_initializer=clone_initializer(self.kernel_initializer), + dtype=dtype, ) self.encoder_layers = [ T5GemmaEncoderLayer( @@ -156,11 +163,12 @@ def __init__( attn_logit_softcapping=attn_logit_softcapping, rope_max_wavelength=rope_max_wavelength, name=f"encoder_layer_{i}", + dtype=dtype, ) for i in range(num_layers) ] - self.encoder_norm = RMSNormalization(epsilon=rms_norm_eps) - self.encoder_dropout = keras.layers.Dropout(dropout_rate) + self.encoder_norm = RMSNormalization(epsilon=rms_norm_eps, dtype=dtype) + self.encoder_dropout = keras.layers.Dropout(dropout_rate, dtype=dtype) self.decoder_layers = [ T5GemmaDecoderLayer( hidden_size=hidden_dim, @@ -181,11 +189,12 @@ def __init__( attn_logit_softcapping=attn_logit_softcapping, rope_max_wavelength=rope_max_wavelength, name=f"decoder_layer_{i}", + dtype=dtype, ) for i in range(num_layers) ] - self.decoder_norm = RMSNormalization(epsilon=rms_norm_eps) - self.decoder_dropout = keras.layers.Dropout(dropout_rate) + self.decoder_norm = RMSNormalization(epsilon=rms_norm_eps, dtype=dtype) + self.decoder_dropout = keras.layers.Dropout(dropout_rate, dtype=dtype) # === Functional Model === token_id_input = keras.Input( @@ -230,6 +239,7 @@ def __init__( "padding_mask": padding_mask_input, }, outputs=decoder_output, + dtype=dtype, **kwargs, ) diff --git a/keras_hub/src/models/t5gemma/t5gemma_decoder.py b/keras_hub/src/models/t5gemma/t5gemma_decoder.py index b64c25ff26..c127ca8c8a 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_decoder.py +++ b/keras_hub/src/models/t5gemma/t5gemma_decoder.py @@ -40,6 +40,8 @@ class T5GemmaDecoderLayer(keras.layers.Layer): Required if `layer_type` is `"sliding_attention"`. rope_max_wavelength: float, The maximum wavelength for Rotary Positional Embeddings. Default is `10000.0`. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for model computations and weights. **kwargs: Additional keyword arguments passed to the parent class. """ @@ -62,9 +64,10 @@ def __init__( attn_logit_softcapping=None, sliding_window=None, rope_max_wavelength=10000.0, + dtype=None, **kwargs, ): - super().__init__(**kwargs) + super().__init__(dtype=dtype, **kwargs) self.head_dim = head_dim self.hidden_size = hidden_size self.rms_norm_eps = rms_norm_eps @@ -102,9 +105,14 @@ def __init__( attention_dropout=attention_dropout, attn_logit_softcapping=attn_logit_softcapping, rope_max_wavelength=self.rope_max_wavelength, + dtype=self.dtype_policy, + ) + self.pre_self_attn_layernorm = RMSNormalization( + epsilon=rms_norm_eps, dtype=self.dtype_policy + ) + self.post_self_attn_layernorm = RMSNormalization( + epsilon=rms_norm_eps, dtype=self.dtype_policy ) - self.pre_self_attn_layernorm = RMSNormalization(epsilon=rms_norm_eps) - self.post_self_attn_layernorm = RMSNormalization(epsilon=rms_norm_eps) # Cross-attention. self.cross_attn = T5GemmaAttention( @@ -119,9 +127,14 @@ def __init__( initializer_range=initializer_range, attention_dropout=attention_dropout, attn_logit_softcapping=attn_logit_softcapping, + dtype=self.dtype_policy, + ) + self.pre_cross_attn_layernorm = RMSNormalization( + epsilon=rms_norm_eps, dtype=self.dtype_policy + ) + self.post_cross_attn_layernorm = RMSNormalization( + epsilon=rms_norm_eps, dtype=self.dtype_policy ) - self.pre_cross_attn_layernorm = RMSNormalization(epsilon=rms_norm_eps) - self.post_cross_attn_layernorm = RMSNormalization(epsilon=rms_norm_eps) # MLP. self.mlp = T5GemmaMLP( @@ -130,11 +143,18 @@ def __init__( hidden_activation, dropout_rate, initializer_range=initializer_range, + dtype=self.dtype_policy, + ) + self.pre_feedforward_layernorm = RMSNormalization( + epsilon=rms_norm_eps, dtype=self.dtype_policy + ) + self.post_feedforward_layernorm = RMSNormalization( + epsilon=rms_norm_eps, dtype=self.dtype_policy ) - self.pre_feedforward_layernorm = RMSNormalization(epsilon=rms_norm_eps) - self.post_feedforward_layernorm = RMSNormalization(epsilon=rms_norm_eps) - self.dropout = keras.layers.Dropout(dropout_rate) + self.dropout = keras.layers.Dropout( + dropout_rate, dtype=self.dtype_policy + ) def build(self, input_shape): hidden_states_shape, encoder_hidden_states_shape = input_shape diff --git a/keras_hub/src/models/t5gemma/t5gemma_encoder.py b/keras_hub/src/models/t5gemma/t5gemma_encoder.py index cf28c37140..d119388cd8 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_encoder.py +++ b/keras_hub/src/models/t5gemma/t5gemma_encoder.py @@ -37,6 +37,8 @@ class T5GemmaEncoderLayer(keras.layers.Layer): Required if `layer_type` is `"sliding_attention"`. rope_max_wavelength: float, The maximum wavelength for Rotary Positional Embeddings. Default is `10000.0`. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for model computations and weights. **kwargs: Additional keyword arguments passed to the parent class. """ @@ -58,9 +60,10 @@ def __init__( attn_logit_softcapping=None, sliding_window=None, rope_max_wavelength=10000.0, + dtype=None, **kwargs, ): - super().__init__(**kwargs) + super().__init__(dtype=dtype, **kwargs) self.hidden_size = hidden_size self.rms_norm_eps = rms_norm_eps self.num_attention_heads = num_attention_heads @@ -96,9 +99,14 @@ def __init__( attention_dropout=attention_dropout, attn_logit_softcapping=attn_logit_softcapping, rope_max_wavelength=self.rope_max_wavelength, + dtype=self.dtype_policy, + ) + self.pre_self_attn_layernorm = RMSNormalization( + epsilon=rms_norm_eps, dtype=self.dtype_policy + ) + self.post_self_attn_layernorm = RMSNormalization( + epsilon=rms_norm_eps, dtype=self.dtype_policy ) - self.pre_self_attn_layernorm = RMSNormalization(epsilon=rms_norm_eps) - self.post_self_attn_layernorm = RMSNormalization(epsilon=rms_norm_eps) self.mlp = T5GemmaMLP( hidden_size, @@ -106,10 +114,17 @@ def __init__( hidden_activation, dropout_rate, initializer_range=initializer_range, + dtype=self.dtype_policy, + ) + self.pre_feedforward_layernorm = RMSNormalization( + epsilon=rms_norm_eps, dtype=self.dtype_policy + ) + self.post_feedforward_layernorm = RMSNormalization( + epsilon=rms_norm_eps, dtype=self.dtype_policy + ) + self.dropout = keras.layers.Dropout( + dropout_rate, dtype=self.dtype_policy ) - self.pre_feedforward_layernorm = RMSNormalization(epsilon=rms_norm_eps) - self.post_feedforward_layernorm = RMSNormalization(epsilon=rms_norm_eps) - self.dropout = keras.layers.Dropout(dropout_rate) def build(self, input_shape): self.pre_self_attn_layernorm.build(input_shape) diff --git a/keras_hub/src/models/t5gemma/t5gemma_layers.py b/keras_hub/src/models/t5gemma/t5gemma_layers.py index 6650ef942f..a574b737e1 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_layers.py +++ b/keras_hub/src/models/t5gemma/t5gemma_layers.py @@ -33,6 +33,8 @@ class T5GemmaMLP(keras.layers.Layer): hidden states. initializer_range: float, The range for the random normal initializer for kernel weights. Default is `0.02`. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for model computations and weights. **kwargs: Additional keyword arguments passed to the parent class. """ @@ -43,9 +45,10 @@ def __init__( hidden_activation, dropout_rate, initializer_range=0.02, + dtype=None, **kwargs, ): - super().__init__(**kwargs) + super().__init__(dtype=dtype, **kwargs) self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.hidden_activation = hidden_activation @@ -57,23 +60,28 @@ def __init__( self.intermediate_size, use_bias=False, kernel_initializer=clone_initializer(self.kernel_initializer), + dtype=self.dtype_policy, ) self.up_proj = keras.layers.Dense( self.intermediate_size, use_bias=False, kernel_initializer=clone_initializer(self.kernel_initializer), + dtype=self.dtype_policy, ) self.down_proj = keras.layers.Dense( self.hidden_size, use_bias=False, kernel_initializer=clone_initializer(self.kernel_initializer), + dtype=self.dtype_policy, ) if self.hidden_activation == "gelu_approximate": # NOTE: `gelu_pytorch_tanh` is the same as `gelu(approximate=True)`. self.act_fn = lambda x: keras.activations.gelu(x, approximate=True) else: self.act_fn = keras.activations.get(self.hidden_activation) - self.dropout = keras.layers.Dropout(self.dropout_rate) + self.dropout = keras.layers.Dropout( + self.dropout_rate, dtype=self.dtype_policy + ) def build(self, input_shape): self.gate_proj.build(input_shape) From 3dbc0b7eca019c8292e87d32b7ea5233931a51c6 Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Thu, 24 Jul 2025 21:37:54 +0400 Subject: [PATCH 09/19] =?UTF-8?q?bug=20fix=20+=20minor=20cleanup:=20Fix=20?= =?UTF-8?q?head=5Fdim=20default=20=E2=86=92=20head=5Fdim=20from=20config?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/models/t5gemma/t5gemma_attention.py | 63 ++++++++----------- .../src/models/t5gemma/t5gemma_causal_lm.py | 2 +- .../src/models/t5gemma/t5gemma_decoder.py | 5 +- 3 files changed, 30 insertions(+), 40 deletions(-) diff --git a/keras_hub/src/models/t5gemma/t5gemma_attention.py b/keras_hub/src/models/t5gemma/t5gemma_attention.py index 20a5321ef6..8ead72e182 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_attention.py +++ b/keras_hub/src/models/t5gemma/t5gemma_attention.py @@ -196,6 +196,29 @@ def build(self, input_shape): self.softmax = keras.layers.Softmax(dtype="float32") self.built = True + def _compute_attention( + self, query_states, key_states, value_states, attention_mask, training + ): + attn_weights = keras.ops.einsum( + "bnth,bnsh->bnts", query_states, key_states + ) + attn_weights *= self.scaling + if self.logit_soft_cap is not None: + attn_weights = attn_weights / self.logit_soft_cap + attn_weights = keras.ops.tanh(attn_weights) + attn_weights = attn_weights * self.logit_soft_cap + if attention_mask is not None: + attn_weights += attention_mask + attn_weights = keras.ops.cast( + self.softmax(attn_weights), + query_states.dtype, + ) + attn_weights = self.dropout_layer(attn_weights, training=training) + attn_output = keras.ops.einsum( + "bnts,bnsh->bnth", attn_weights, value_states + ) + return attn_output, attn_weights + def call( self, inputs, @@ -229,23 +252,8 @@ def call( # Repeat key-value heads for GQA. key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = keras.ops.einsum( - "bnth,bnsh->bnts", query_states, key_states - ) - attn_weights *= self.scaling - if self.logit_soft_cap is not None: - attn_weights = attn_weights / self.logit_soft_cap - attn_weights = keras.ops.tanh(attn_weights) - attn_weights = attn_weights * self.logit_soft_cap - if attention_mask is not None: - attn_weights += attention_mask - attn_weights = keras.ops.cast( - self.softmax(attn_weights), - query_states.dtype, - ) - attn_weights = self.dropout_layer(attn_weights, training=training) - attn_output = keras.ops.einsum( - "bnts,bnsh->bnth", attn_weights, value_states + attn_output, attn_weights = self._compute_attention( + query_states, key_states, value_states, attention_mask, training ) attn_output = self.output_dense(attn_output) return (attn_output, attn_weights), updated_cache @@ -294,25 +302,8 @@ def call( key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = keras.ops.einsum( - "bnth,bnsh->bnts", query_states, key_states - ) - attn_weights *= self.scaling - - if self.logit_soft_cap is not None: - attn_weights = attn_weights / self.logit_soft_cap - attn_weights = keras.ops.tanh(attn_weights) - attn_weights = attn_weights * self.logit_soft_cap - if attention_mask is not None: - attn_weights += attention_mask - - attn_weights = keras.ops.cast( - self.softmax(attn_weights), - query_states.dtype, - ) - attn_weights = self.dropout_layer(attn_weights, training=training) - attn_output = keras.ops.einsum( - "bnts,bnsh->bnth", attn_weights, value_states + attn_output, attn_weights = self._compute_attention( + query_states, key_states, value_states, attention_mask, training ) attn_output = self.output_dense(attn_output) return (attn_output, attn_weights), cache diff --git a/keras_hub/src/models/t5gemma/t5gemma_causal_lm.py b/keras_hub/src/models/t5gemma/t5gemma_causal_lm.py index a8adcc2d09..ef351b9501 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_causal_lm.py +++ b/keras_hub/src/models/t5gemma/t5gemma_causal_lm.py @@ -278,7 +278,7 @@ def _build_cache(self, token_ids, padding_mask): batch_size = keras.ops.shape(token_ids)[0] num_layers = self.backbone.num_layers num_kv_heads = self.backbone.num_key_value_heads - head_dim = self.backbone.hidden_dim // self.backbone.num_attention_heads + head_dim = self.backbone.head_dim self_cache_shape = ( batch_size, num_layers, diff --git a/keras_hub/src/models/t5gemma/t5gemma_decoder.py b/keras_hub/src/models/t5gemma/t5gemma_decoder.py index c127ca8c8a..ee6ff962bd 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_decoder.py +++ b/keras_hub/src/models/t5gemma/t5gemma_decoder.py @@ -293,20 +293,19 @@ def compute_output_shape(self, input_shape): hidden_states_shape, encoder_hidden_states_shape = input_shape batch_size, dec_seq_len, _ = hidden_states_shape _, enc_seq_len, _ = encoder_hidden_states_shape - head_dim = self.hidden_size // self.num_attention_heads self_cache_shape = ( batch_size, 2, self.num_key_value_heads, dec_seq_len, - head_dim, + self.head_dim, ) cross_cache_shape = ( batch_size, 2, self.num_key_value_heads, enc_seq_len, - head_dim, + self.head_dim, ) return hidden_states_shape, (self_cache_shape, cross_cache_shape) From 291d8f1027028f118a2d8399c838b429dcdec826 Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Fri, 25 Jul 2025 11:30:43 +0400 Subject: [PATCH 10/19] perf(jax/tpu): Fused kernel optim for TPU backend + get_config() args --- .../src/models/t5gemma/t5gemma_attention.py | 82 +++++++++++-------- .../src/models/t5gemma/t5gemma_causal_lm.py | 2 +- .../src/models/t5gemma/t5gemma_decoder.py | 26 +++--- .../src/models/t5gemma/t5gemma_encoder.py | 9 +- .../src/models/t5gemma/t5gemma_layers.py | 1 - 5 files changed, 66 insertions(+), 54 deletions(-) diff --git a/keras_hub/src/models/t5gemma/t5gemma_attention.py b/keras_hub/src/models/t5gemma/t5gemma_attention.py index 8ead72e182..1e4d772fc8 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_attention.py +++ b/keras_hub/src/models/t5gemma/t5gemma_attention.py @@ -1,3 +1,5 @@ +import inspect + import keras from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding @@ -18,21 +20,21 @@ def repeat_kv(hidden_states, n_rep): Args: hidden_states: Tensor, The key or value hidden states with shape - `(batch, num_key_value_heads, sequence_length, head_dim)`. + `(batch, sequence_length, num_key_value_heads, head_dim)`. n_rep: int, The number of times to repeat the key/value heads. This is typically `num_query_heads // num_key_value_heads`. Returns: Tensor: The expanded key/value hidden states with shape - `(batch, num_query_heads, sequence_length, head_dim)`. + `(batch, sequence_length, num_query_heads, head_dim)`. """ if n_rep == 1: return hidden_states - batch, num_key_value_heads, slen, head_dim = keras.ops.shape(hidden_states) - hidden_states = keras.ops.expand_dims(hidden_states, 2) - hidden_states = keras.ops.tile(hidden_states, (1, 1, n_rep, 1, 1)) + batch, slen, num_key_value_heads, head_dim = keras.ops.shape(hidden_states) + hidden_states = keras.ops.expand_dims(hidden_states, 3) + hidden_states = keras.ops.tile(hidden_states, (1, 1, 1, n_rep, 1)) return keras.ops.reshape( - hidden_states, (batch, num_key_value_heads * n_rep, slen, head_dim) + hidden_states, (batch, slen, num_key_value_heads * n_rep, head_dim) ) @@ -122,7 +124,7 @@ def __init__( if self.attention_type == "self": self.rotary_embedding = RotaryEmbedding( max_wavelength=self.rope_max_wavelength, - sequence_axis=2, + sequence_axis=1, feature_axis=3, name="rotary_embedding", dtype=self.dtype_policy, @@ -141,8 +143,8 @@ def build(self, input_shape): # Query projection layer. self.hidden_dim = hidden_states_shape[-1] self.query_dense = keras.layers.EinsumDense( - equation="btd,dnh->bnth", - output_shape=(self.num_query_heads, None, self.head_dim), + equation="btd,dnh->btnh", + output_shape=(None, self.num_query_heads, self.head_dim), kernel_initializer=clone_initializer(self._kernel_initializer), bias_axes="nh" if self.attention_bias else None, dtype=self.dtype_policy, @@ -152,8 +154,8 @@ def build(self, input_shape): # Key projection layer. self.key_dense = keras.layers.EinsumDense( - equation="bsd,dkh->bksh", - output_shape=(self.num_key_value_heads, None, self.head_dim), + equation="bsd,dkh->bskh", + output_shape=(None, self.num_key_value_heads, self.head_dim), kernel_initializer=clone_initializer(self._kernel_initializer), bias_axes="kh" if self.attention_bias else None, dtype=self.dtype_policy, @@ -163,8 +165,8 @@ def build(self, input_shape): # Value projection layer. self.value_dense = keras.layers.EinsumDense( - equation="bsd,dkh->bksh", - output_shape=(self.num_key_value_heads, None, self.head_dim), + equation="bsd,dkh->bskh", + output_shape=(None, self.num_key_value_heads, self.head_dim), kernel_initializer=clone_initializer(self._kernel_initializer), bias_axes="kh" if self.attention_bias else None, dtype=self.dtype_policy, @@ -174,7 +176,7 @@ def build(self, input_shape): # Output projection layer. self.output_dense = keras.layers.EinsumDense( - equation="bnth,nhd->btd", + equation="btnh,nhd->btd", output_shape=(None, self.hidden_dim), kernel_initializer=clone_initializer(self._kernel_initializer), bias_axes="d" if self.attention_bias else None, @@ -184,8 +186,8 @@ def build(self, input_shape): self.output_dense.build( ( hidden_states_shape[0], - self.num_query_heads, hidden_states_shape[1], + self.num_query_heads, self.head_dim, ) ) @@ -193,14 +195,32 @@ def build(self, input_shape): rate=self.attention_dropout, dtype=self.dtype_policy, ) - self.softmax = keras.layers.Softmax(dtype="float32") + self.softmax = keras.layers.Softmax(axis=-1, dtype="float32") self.built = True def _compute_attention( self, query_states, key_states, value_states, attention_mask, training ): + if self._use_fused_attention_op(): + kwargs = {"bias": attention_mask} + if self.logit_soft_cap is not None: + sig = inspect.signature(keras.ops.dot_product_attention) + # This is only supported in JAX TPU backend. + # https://keras.io/api/ops/nn/#dot_product_attention-function + if "attn_logits_soft_cap" in sig.parameters: + kwargs["attn_logits_soft_cap"] = self.logit_soft_cap + return ( + keras.ops.dot_product_attention( + query=query_states, + key=key_states, + value=value_states, + scale=self.scaling, + **kwargs, + ), + None, + ) attn_weights = keras.ops.einsum( - "bnth,bnsh->bnts", query_states, key_states + "btnh,bsnh->bnts", query_states, key_states ) attn_weights *= self.scaling if self.logit_soft_cap is not None: @@ -215,7 +235,7 @@ def _compute_attention( ) attn_weights = self.dropout_layer(attn_weights, training=training) attn_output = keras.ops.einsum( - "bnts,bnsh->bnth", attn_weights, value_states + "bnts,bsnh->btnh", attn_weights, value_states ) return attn_output, attn_weights @@ -252,11 +272,11 @@ def call( # Repeat key-value heads for GQA. key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_output, attn_weights = self._compute_attention( + attn_output, _ = self._compute_attention( query_states, key_states, value_states, attention_mask, training ) attn_output = self.output_dense(attn_output) - return (attn_output, attn_weights), updated_cache + return attn_output, updated_cache else: # Self-attention hidden_states = inputs kv_states = hidden_states @@ -272,9 +292,6 @@ def call( key_states = self.rotary_embedding( key_states, start_index=start_index ) - current_pass_cache = keras.ops.stack( - (key_states, value_states), axis=1 - ) if cache is not None: if cache_update_index is None: raise ValueError( @@ -282,7 +299,7 @@ def call( "for self-attention caching." ) key_cache, value_cache = cache[:, 0, ...], cache[:, 1, ...] - start = [0, 0, cache_update_index, 0] + start = [0, cache_update_index, 0, 0] key_states = keras.ops.slice_update( key_cache, start, key_states ) @@ -296,17 +313,17 @@ def call( "`None`." ) else: - cache = current_pass_cache + cache = keras.ops.stack((key_states, value_states), axis=1) # Repeat key-value heads for GQA. key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_output, attn_weights = self._compute_attention( + attn_output, _ = self._compute_attention( query_states, key_states, value_states, attention_mask, training ) attn_output = self.output_dense(attn_output) - return (attn_output, attn_weights), cache + return attn_output, cache def compute_output_shape(self, input_shape): if self.attention_type == "cross": @@ -315,22 +332,15 @@ def compute_output_shape(self, input_shape): hidden_states_shape = input_shape kv_states_shape = input_shape attn_output_shape = hidden_states_shape - q_len = hidden_states_shape[1] kv_len = kv_states_shape[1] - attn_weights_shape = ( - hidden_states_shape[0], - self.num_query_heads, - q_len, - kv_len, - ) cache_shape = ( hidden_states_shape[0], # batch 2, # key and value - self.num_key_value_heads, kv_len, + self.num_key_value_heads, self.head_dim, ) - return (attn_output_shape, attn_weights_shape), cache_shape + return attn_output_shape, cache_shape def get_config(self): config = super().get_config() diff --git a/keras_hub/src/models/t5gemma/t5gemma_causal_lm.py b/keras_hub/src/models/t5gemma/t5gemma_causal_lm.py index ef351b9501..e34a8c1cc5 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_causal_lm.py +++ b/keras_hub/src/models/t5gemma/t5gemma_causal_lm.py @@ -283,8 +283,8 @@ def _build_cache(self, token_ids, padding_mask): batch_size, num_layers, 2, - num_kv_heads, keras.ops.shape(token_ids)[1], + num_kv_heads, head_dim, ) self_attention_cache = keras.ops.zeros( diff --git a/keras_hub/src/models/t5gemma/t5gemma_decoder.py b/keras_hub/src/models/t5gemma/t5gemma_decoder.py index ee6ff962bd..e0ab336a9e 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_decoder.py +++ b/keras_hub/src/models/t5gemma/t5gemma_decoder.py @@ -5,7 +5,6 @@ from keras_hub.src.models.t5gemma.t5gemma_layers import T5GemmaMLP -@keras.saving.register_keras_serializable(package="keras_hub") class T5GemmaDecoderLayer(keras.layers.Layer): """Decoder layer for the T5Gemma model. @@ -83,6 +82,8 @@ def __init__( self.layer_type = layer_type self.sliding_window = sliding_window self.rope_max_wavelength = rope_max_wavelength + self.cross_attention_hidden_size = cross_attention_hidden_size + self.attn_logit_softcapping = attn_logit_softcapping if ( self.layer_type == "sliding_attention" and self.sliding_window is None @@ -161,17 +162,17 @@ def build(self, input_shape): self.pre_self_attn_layernorm.build(hidden_states_shape) current_shape = hidden_states_shape self.self_attn.build(current_shape) - attn_output_shape = self.self_attn.compute_output_shape(current_shape)[ - 0 - ][0] + attn_output_shape, _ = self.self_attn.compute_output_shape( + current_shape + ) self.post_self_attn_layernorm.build(attn_output_shape) current_shape = attn_output_shape self.dropout.build(current_shape) self.pre_cross_attn_layernorm.build(current_shape) self.cross_attn.build([current_shape, encoder_hidden_states_shape]) - attn_output_shape = self.cross_attn.compute_output_shape( + attn_output_shape, _ = self.cross_attn.compute_output_shape( [current_shape, encoder_hidden_states_shape] - )[0][0] + ) self.post_cross_attn_layernorm.build(attn_output_shape) current_shape = attn_output_shape self.pre_feedforward_layernorm.build(current_shape) @@ -189,7 +190,7 @@ def _make_self_attention_mask( ): if cache is not None: q_len = keras.ops.shape(hidden_states)[1] - kv_len = keras.ops.shape(cache)[3] + kv_len = keras.ops.shape(cache)[2] q_indices = ( keras.ops.arange(0, q_len, dtype="int32") + cache_update_index ) @@ -245,7 +246,7 @@ def call( cache_update_index=cache_update_index, ) hidden_states = self.pre_self_attn_layernorm(hidden_states) - (hidden_states, _), updated_self_attention_cache = self.self_attn( + hidden_states, updated_self_attention_cache = self.self_attn( inputs=hidden_states, attention_mask=self_attention_mask, cache=self_attention_cache, @@ -263,7 +264,7 @@ def call( encoder_hidden_states, cross_attention_padding_mask ) hidden_states = self.pre_cross_attn_layernorm(hidden_states) - (hidden_states, _), updated_cross_attention_cache = self.cross_attn( + hidden_states, updated_cross_attention_cache = self.cross_attn( inputs=[hidden_states, encoder_hidden_states], attention_mask=cross_attention_mask, cache=cross_attention_cache, @@ -296,15 +297,15 @@ def compute_output_shape(self, input_shape): self_cache_shape = ( batch_size, 2, - self.num_key_value_heads, dec_seq_len, + self.num_key_value_heads, self.head_dim, ) cross_cache_shape = ( batch_size, 2, - self.num_key_value_heads, enc_seq_len, + self.num_key_value_heads, self.head_dim, ) return hidden_states_shape, (self_cache_shape, cross_cache_shape) @@ -327,6 +328,9 @@ def get_config(self): "layer_type": self.layer_type, "sliding_window": self.sliding_window, "rope_max_wavelength": self.rope_max_wavelength, + "head_dim": self.head_dim, + "cross_attention_hidden_size": self.cross_attention_hidden_size, + "attn_logit_softcapping": self.attn_logit_softcapping, } ) return config diff --git a/keras_hub/src/models/t5gemma/t5gemma_encoder.py b/keras_hub/src/models/t5gemma/t5gemma_encoder.py index d119388cd8..7cdac8ca7e 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_encoder.py +++ b/keras_hub/src/models/t5gemma/t5gemma_encoder.py @@ -5,7 +5,6 @@ from keras_hub.src.models.t5gemma.t5gemma_layers import T5GemmaMLP -@keras.saving.register_keras_serializable(package="keras_hub") class T5GemmaEncoderLayer(keras.layers.Layer): """Encoder layer for the T5Gemma model. @@ -79,6 +78,7 @@ def __init__( self.sliding_window = sliding_window self.rope_max_wavelength = rope_max_wavelength self.head_dim = head_dim + self.attn_logit_softcapping = attn_logit_softcapping if ( self.layer_type == "sliding_attention" and self.sliding_window is None @@ -129,9 +129,7 @@ def __init__( def build(self, input_shape): self.pre_self_attn_layernorm.build(input_shape) self.self_attn.build(input_shape) - attn_output_shape = self.self_attn.compute_output_shape(input_shape)[0][ - 0 - ] + attn_output_shape, _ = self.self_attn.compute_output_shape(input_shape) self.post_self_attn_layernorm.build(attn_output_shape) self.dropout.build(attn_output_shape) self.pre_feedforward_layernorm.build(attn_output_shape) @@ -156,7 +154,7 @@ def call( residual = hidden_states attention_mask = self._make_attention_mask(hidden_states, padding_mask) hidden_states = self.pre_self_attn_layernorm(hidden_states) - (hidden_states, _), _ = self.self_attn( + hidden_states, _ = self.self_attn( inputs=hidden_states, attention_mask=attention_mask, training=training, @@ -197,6 +195,7 @@ def get_config(self): "layer_type": self.layer_type, "sliding_window": self.sliding_window, "rope_max_wavelength": self.rope_max_wavelength, + "attn_logit_softcapping": self.attn_logit_softcapping, } ) return config diff --git a/keras_hub/src/models/t5gemma/t5gemma_layers.py b/keras_hub/src/models/t5gemma/t5gemma_layers.py index a574b737e1..1ec2d039af 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_layers.py +++ b/keras_hub/src/models/t5gemma/t5gemma_layers.py @@ -16,7 +16,6 @@ def t5gemma_kernel_initializer(initializer_range=0.01): return keras.initializers.RandomNormal(mean=0.0, stddev=initializer_range) -@keras.saving.register_keras_serializable(package="keras_hub") class T5GemmaMLP(keras.layers.Layer): """Multilayer Perceptron (MLP) block for the T5Gemma model. From 524aa3707ceada068fbc37ae6f5ecfcbb421390e Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Fri, 25 Jul 2025 12:15:48 +0400 Subject: [PATCH 11/19] cleanup: Slight refactor --- .../src/models/t5gemma/t5gemma_attention.py | 52 +++++++++++-------- 1 file changed, 30 insertions(+), 22 deletions(-) diff --git a/keras_hub/src/models/t5gemma/t5gemma_attention.py b/keras_hub/src/models/t5gemma/t5gemma_attention.py index 1e4d772fc8..0040e3801e 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_attention.py +++ b/keras_hub/src/models/t5gemma/t5gemma_attention.py @@ -198,27 +198,9 @@ def build(self, input_shape): self.softmax = keras.layers.Softmax(axis=-1, dtype="float32") self.built = True - def _compute_attention( + def _compute_attention_without_fused_op( self, query_states, key_states, value_states, attention_mask, training ): - if self._use_fused_attention_op(): - kwargs = {"bias": attention_mask} - if self.logit_soft_cap is not None: - sig = inspect.signature(keras.ops.dot_product_attention) - # This is only supported in JAX TPU backend. - # https://keras.io/api/ops/nn/#dot_product_attention-function - if "attn_logits_soft_cap" in sig.parameters: - kwargs["attn_logits_soft_cap"] = self.logit_soft_cap - return ( - keras.ops.dot_product_attention( - query=query_states, - key=key_states, - value=value_states, - scale=self.scaling, - **kwargs, - ), - None, - ) attn_weights = keras.ops.einsum( "btnh,bsnh->bnts", query_states, key_states ) @@ -237,7 +219,33 @@ def _compute_attention( attn_output = keras.ops.einsum( "bnts,bsnh->btnh", attn_weights, value_states ) - return attn_output, attn_weights + return attn_output + + def _compute_attention( + self, query_states, key_states, value_states, attention_mask, training + ): + if self._use_fused_attention_op(): + kwargs = {"bias": attention_mask} + if self.logit_soft_cap is not None: + sig = inspect.signature(keras.ops.dot_product_attention) + # This is only supported in JAX TPU backend. + # https://keras.io/api/ops/nn/#dot_product_attention-function + if "attn_logits_soft_cap" in sig.parameters: + kwargs["attn_logits_soft_cap"] = self.logit_soft_cap + return keras.ops.dot_product_attention( + query=query_states, + key=key_states, + value=value_states, + scale=self.scaling, + **kwargs, + ) + return self._compute_attention_without_fused_op( + query_states, + key_states, + value_states, + attention_mask, + training, + ) def call( self, @@ -272,7 +280,7 @@ def call( # Repeat key-value heads for GQA. key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_output, _ = self._compute_attention( + attn_output = self._compute_attention( query_states, key_states, value_states, attention_mask, training ) attn_output = self.output_dense(attn_output) @@ -319,7 +327,7 @@ def call( key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_output, _ = self._compute_attention( + attn_output = self._compute_attention( query_states, key_states, value_states, attention_mask, training ) attn_output = self.output_dense(attn_output) From 889e23bf446d455ce6cc5e327591ea6a1fc94479 Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Wed, 30 Jul 2025 13:39:00 +0400 Subject: [PATCH 12/19] fix: Enable mixed precision and quantization tests --- .../src/models/t5gemma/t5gemma_backbone.py | 2 ++ .../models/t5gemma/t5gemma_backbone_test.py | 2 -- .../src/models/t5gemma/t5gemma_decoder.py | 29 ++++++++++++++----- .../src/models/t5gemma/t5gemma_encoder.py | 22 ++++++++++---- .../src/models/t5gemma/t5gemma_layers.py | 7 ++++- 5 files changed, 47 insertions(+), 15 deletions(-) diff --git a/keras_hub/src/models/t5gemma/t5gemma_backbone.py b/keras_hub/src/models/t5gemma/t5gemma_backbone.py index 41f2cb0c8e..8b8c6316c8 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_backbone.py +++ b/keras_hub/src/models/t5gemma/t5gemma_backbone.py @@ -136,6 +136,7 @@ def __init__( output_dim=hidden_dim, embeddings_initializer=clone_initializer(self.kernel_initializer), dtype=dtype, + name="encoder_token_embedding", ) self.decoder_token_embedding = ReversibleEmbedding( input_dim=vocabulary_size, @@ -143,6 +144,7 @@ def __init__( tie_weights=tie_word_embeddings, embeddings_initializer=clone_initializer(self.kernel_initializer), dtype=dtype, + name="decoder_token_embedding", ) self.encoder_layers = [ T5GemmaEncoderLayer( diff --git a/keras_hub/src/models/t5gemma/t5gemma_backbone_test.py b/keras_hub/src/models/t5gemma/t5gemma_backbone_test.py index 171a233b6f..5aa7e7b46d 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_backbone_test.py +++ b/keras_hub/src/models/t5gemma/t5gemma_backbone_test.py @@ -40,8 +40,6 @@ def test_backbone_basics(self): init_kwargs=self.init_kwargs, input_data=self.input_data, expected_output_shape=(2, 16, 32), - run_mixed_precision_check=False, - run_quantization_check=False, ) @pytest.mark.large diff --git a/keras_hub/src/models/t5gemma/t5gemma_decoder.py b/keras_hub/src/models/t5gemma/t5gemma_decoder.py index e0ab336a9e..555daa8601 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_decoder.py +++ b/keras_hub/src/models/t5gemma/t5gemma_decoder.py @@ -107,12 +107,17 @@ def __init__( attn_logit_softcapping=attn_logit_softcapping, rope_max_wavelength=self.rope_max_wavelength, dtype=self.dtype_policy, + name="self_attention", ) self.pre_self_attn_layernorm = RMSNormalization( - epsilon=rms_norm_eps, dtype=self.dtype_policy + epsilon=rms_norm_eps, + dtype=self.dtype_policy, + name="pre_self_attention_layernorm", ) self.post_self_attn_layernorm = RMSNormalization( - epsilon=rms_norm_eps, dtype=self.dtype_policy + epsilon=rms_norm_eps, + dtype=self.dtype_policy, + name="post_self_attention_layernorm", ) # Cross-attention. @@ -129,12 +134,17 @@ def __init__( attention_dropout=attention_dropout, attn_logit_softcapping=attn_logit_softcapping, dtype=self.dtype_policy, + name="cross_attention", ) self.pre_cross_attn_layernorm = RMSNormalization( - epsilon=rms_norm_eps, dtype=self.dtype_policy + epsilon=rms_norm_eps, + dtype=self.dtype_policy, + name="pre_cross_attention_layernorm", ) self.post_cross_attn_layernorm = RMSNormalization( - epsilon=rms_norm_eps, dtype=self.dtype_policy + epsilon=rms_norm_eps, + dtype=self.dtype_policy, + name="post_cross_attention_layernorm", ) # MLP. @@ -145,16 +155,21 @@ def __init__( dropout_rate, initializer_range=initializer_range, dtype=self.dtype_policy, + name="mlp", ) self.pre_feedforward_layernorm = RMSNormalization( - epsilon=rms_norm_eps, dtype=self.dtype_policy + epsilon=rms_norm_eps, + dtype=self.dtype_policy, + name="pre_feedforward_layernorm", ) self.post_feedforward_layernorm = RMSNormalization( - epsilon=rms_norm_eps, dtype=self.dtype_policy + epsilon=rms_norm_eps, + dtype=self.dtype_policy, + name="post_feedforward_layernorm", ) self.dropout = keras.layers.Dropout( - dropout_rate, dtype=self.dtype_policy + dropout_rate, dtype=self.dtype_policy, name="dropout" ) def build(self, input_shape): diff --git a/keras_hub/src/models/t5gemma/t5gemma_encoder.py b/keras_hub/src/models/t5gemma/t5gemma_encoder.py index 7cdac8ca7e..cd42d767cd 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_encoder.py +++ b/keras_hub/src/models/t5gemma/t5gemma_encoder.py @@ -100,12 +100,17 @@ def __init__( attn_logit_softcapping=attn_logit_softcapping, rope_max_wavelength=self.rope_max_wavelength, dtype=self.dtype_policy, + name="self_attention", ) self.pre_self_attn_layernorm = RMSNormalization( - epsilon=rms_norm_eps, dtype=self.dtype_policy + epsilon=rms_norm_eps, + dtype=self.dtype_policy, + name="pre_self_attention_layernorm", ) self.post_self_attn_layernorm = RMSNormalization( - epsilon=rms_norm_eps, dtype=self.dtype_policy + epsilon=rms_norm_eps, + dtype=self.dtype_policy, + name="post_self_attention_layernorm", ) self.mlp = T5GemmaMLP( @@ -115,15 +120,22 @@ def __init__( dropout_rate, initializer_range=initializer_range, dtype=self.dtype_policy, + name="mlp", ) self.pre_feedforward_layernorm = RMSNormalization( - epsilon=rms_norm_eps, dtype=self.dtype_policy + epsilon=rms_norm_eps, + dtype=self.dtype_policy, + name="pre_feedforward_layernorm", ) self.post_feedforward_layernorm = RMSNormalization( - epsilon=rms_norm_eps, dtype=self.dtype_policy + epsilon=rms_norm_eps, + dtype=self.dtype_policy, + name="post_feedforward_layernorm", ) self.dropout = keras.layers.Dropout( - dropout_rate, dtype=self.dtype_policy + dropout_rate, + dtype=self.dtype_policy, + name="dropout", ) def build(self, input_shape): diff --git a/keras_hub/src/models/t5gemma/t5gemma_layers.py b/keras_hub/src/models/t5gemma/t5gemma_layers.py index 1ec2d039af..a282cadc3e 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_layers.py +++ b/keras_hub/src/models/t5gemma/t5gemma_layers.py @@ -60,18 +60,21 @@ def __init__( use_bias=False, kernel_initializer=clone_initializer(self.kernel_initializer), dtype=self.dtype_policy, + name="gate_proj", ) self.up_proj = keras.layers.Dense( self.intermediate_size, use_bias=False, kernel_initializer=clone_initializer(self.kernel_initializer), dtype=self.dtype_policy, + name="up_proj", ) self.down_proj = keras.layers.Dense( self.hidden_size, use_bias=False, kernel_initializer=clone_initializer(self.kernel_initializer), dtype=self.dtype_policy, + name="down_proj", ) if self.hidden_activation == "gelu_approximate": # NOTE: `gelu_pytorch_tanh` is the same as `gelu(approximate=True)`. @@ -79,7 +82,9 @@ def __init__( else: self.act_fn = keras.activations.get(self.hidden_activation) self.dropout = keras.layers.Dropout( - self.dropout_rate, dtype=self.dtype_policy + self.dropout_rate, + dtype=self.dtype_policy, + name="dropout", ) def build(self, input_shape): From 32a6912d0a87340b913e1644060c939b9c07147d Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Wed, 30 Jul 2025 18:31:48 +0400 Subject: [PATCH 13/19] feat: Add support for asymmetrical presets (only invariants included) --- .../src/models/t5gemma/t5gemma_backbone.py | 185 +++++++++++------- .../models/t5gemma/t5gemma_backbone_test.py | 51 ++++- .../src/models/t5gemma/t5gemma_causal_lm.py | 36 ++-- .../models/t5gemma/t5gemma_causal_lm_test.py | 21 +- .../convert_t5gemma_checkpoints.py | 129 +++++++++--- 5 files changed, 300 insertions(+), 122 deletions(-) diff --git a/keras_hub/src/models/t5gemma/t5gemma_backbone.py b/keras_hub/src/models/t5gemma/t5gemma_backbone.py index 8b8c6316c8..a166a63c51 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_backbone.py +++ b/keras_hub/src/models/t5gemma/t5gemma_backbone.py @@ -24,15 +24,34 @@ class T5GemmaBackbone(Backbone): Args: vocabulary_size: int, The size of the vocabulary. - hidden_dim: int, The dimensionality of the hidden states throughout the - model. - intermediate_dim: int, The intermediate size of the feed-forward - networks in encoder and decoder layers. - num_layers: int, The number of encoder and decoder layers. - num_attention_heads: int, The number of attention heads in all attention - mechanisms. - num_key_value_heads: int, The number of key-value heads for grouped - query attention in all attention mechanisms. + encoder_hidden_dim: int, The hidden dimensionality of the encoder. + encoder_intermediate_dim: int, The intermediate size of the encoder's + feed-forward networks. + encoder_num_layers: int, The number of encoder layers. + encoder_num_attention_heads: int, The number of attention heads in the + encoder. + encoder_num_key_value_heads: int, The number of key-value heads in the + encoder. + encoder_head_dim: int, The dimensionality of each attention head in the + encoder. + encoder_layer_types: list of str, A list of strings specifying the type + of attention layer for each encoder layer. Each element can be + either `"sliding_attention"` or `"full_attention"`. For example, + `["full_attention", "sliding_attention", ...]`. + decoder_hidden_dim: int, The hidden dimensionality of the decoder. + decoder_intermediate_dim: int, The intermediate size of the decoder's + feed-forward networks. + decoder_num_layers: int, The number of decoder layers. + decoder_num_attention_heads: int, The number of attention heads in the + decoder. + decoder_num_key_value_heads: int, The number of key-value heads in the + decoder. + decoder_head_dim: int, The dimensionality of each attention head in the + decoder. + decoder_layer_types: list of str, A list of strings specifying the type + of attention layer for each decoder layer. Each element can be + either `"sliding_attention"` or `"full_attention"`. For example, + `["full_attention", "sliding_attention", ...]`. dropout_rate: float, The dropout rate applied throughout the model. rms_norm_eps: float, The epsilon value for RMS normalization. query_pre_attn_scalar: float, Scalar to multiply queries by before @@ -40,11 +59,6 @@ class T5GemmaBackbone(Backbone): attention_bias: bool, Whether to include bias in attention computations. hidden_activation: str, The activation function used in the feed-forward networks. - layer_types: list of str, A list of strings specifying the type of - attention layer for each encoder/decoder layer. Each element can be - either `"sliding_attention"` or `"full_attention"`. For example, - `["full_attention", "sliding_attention", ...]`. - head_dim: int, The dimensionality of each attention head. tie_word_embeddings: bool, Whether to tie input and output word embeddings. Default is `True`. initializer_range: float, The range for the random normal initializer. @@ -55,7 +69,7 @@ class T5GemmaBackbone(Backbone): Required if any `layer_type` is `"sliding_attention"`. cross_attention_hidden_size: int, optional, The hidden size for cross-attention in the decoder layers. If None, it defaults to - `hidden_dim`. + `encoder_hidden_dim`. attn_logit_softcapping: float, optional, The softcapping value for attention logits. final_logit_softcapping: float, optional, The softcapping value for @@ -85,18 +99,28 @@ class T5GemmaBackbone(Backbone): # Randomly initialized T5Gemma backbone with custom config. model = T5GemmaBackbone( vocabulary_size=32000, - num_layers=4, - num_attention_heads=4, - num_key_value_heads=2, - hidden_dim=256, - intermediate_dim=512, - head_dim=64, + # Encoder parameters. + encoder_hidden_dim=256, + encoder_intermediate_dim=512, + encoder_num_layers=4, + encoder_num_attention_heads=4, + encoder_num_key_value_heads=2, + encoder_head_dim=64, + encoder_layer_types=["full_attention"] * 4, + # Decoder parameters. + decoder_hidden_dim=256, + decoder_intermediate_dim=512, + decoder_num_layers=4, + decoder_num_attention_heads=4, + decoder_num_key_value_heads=2, + decoder_head_dim=64, + decoder_layer_types=["full_attention"] * 4, + # Common parameters. dropout_rate=0.1, rms_norm_eps=1e-6, query_pre_attn_scalar=1.0, attention_bias=False, hidden_activation="gelu_approximate", - layer_types=["full_attention"] * 4, ) output = model(input_data) ``` @@ -105,18 +129,25 @@ class T5GemmaBackbone(Backbone): def __init__( self, vocabulary_size, - hidden_dim, - intermediate_dim, - num_layers, - num_attention_heads, - num_key_value_heads, - dropout_rate, - rms_norm_eps, - query_pre_attn_scalar, - attention_bias, - hidden_activation, - layer_types, - head_dim, + encoder_hidden_dim, + encoder_intermediate_dim, + encoder_num_layers, + encoder_num_attention_heads, + encoder_num_key_value_heads, + encoder_head_dim, + encoder_layer_types, + decoder_hidden_dim, + decoder_intermediate_dim, + decoder_num_layers, + decoder_num_attention_heads, + decoder_num_key_value_heads, + decoder_head_dim, + decoder_layer_types, + dropout_rate=0.0, + rms_norm_eps=1e-6, + query_pre_attn_scalar=1.0, + attention_bias=False, + hidden_activation="gelu_approximate", tie_word_embeddings=True, initializer_range=0.02, attention_dropout=0.0, @@ -133,14 +164,14 @@ def __init__( # === Layers === self.token_embedding = keras.layers.Embedding( input_dim=vocabulary_size, - output_dim=hidden_dim, + output_dim=encoder_hidden_dim, embeddings_initializer=clone_initializer(self.kernel_initializer), dtype=dtype, name="encoder_token_embedding", ) self.decoder_token_embedding = ReversibleEmbedding( input_dim=vocabulary_size, - output_dim=hidden_dim, + output_dim=decoder_hidden_dim, tie_weights=tie_word_embeddings, embeddings_initializer=clone_initializer(self.kernel_initializer), dtype=dtype, @@ -148,52 +179,54 @@ def __init__( ) self.encoder_layers = [ T5GemmaEncoderLayer( - hidden_size=hidden_dim, + hidden_size=encoder_hidden_dim, rms_norm_eps=rms_norm_eps, - num_attention_heads=num_attention_heads, - num_key_value_heads=num_key_value_heads, + num_attention_heads=encoder_num_attention_heads, + num_key_value_heads=encoder_num_key_value_heads, query_pre_attn_scalar=query_pre_attn_scalar, attention_bias=attention_bias, - intermediate_size=intermediate_dim, + intermediate_size=encoder_intermediate_dim, hidden_activation=hidden_activation, - head_dim=head_dim, + head_dim=encoder_head_dim, dropout_rate=dropout_rate, initializer_range=initializer_range, attention_dropout=attention_dropout, - layer_type=layer_types[i], + layer_type=encoder_layer_types[i], sliding_window=sliding_window, attn_logit_softcapping=attn_logit_softcapping, rope_max_wavelength=rope_max_wavelength, name=f"encoder_layer_{i}", dtype=dtype, ) - for i in range(num_layers) + for i in range(encoder_num_layers) ] self.encoder_norm = RMSNormalization(epsilon=rms_norm_eps, dtype=dtype) self.encoder_dropout = keras.layers.Dropout(dropout_rate, dtype=dtype) self.decoder_layers = [ T5GemmaDecoderLayer( - hidden_size=hidden_dim, + hidden_size=decoder_hidden_dim, rms_norm_eps=rms_norm_eps, - num_attention_heads=num_attention_heads, - num_key_value_heads=num_key_value_heads, + num_attention_heads=decoder_num_attention_heads, + num_key_value_heads=decoder_num_key_value_heads, query_pre_attn_scalar=query_pre_attn_scalar, attention_bias=attention_bias, - intermediate_size=intermediate_dim, + intermediate_size=decoder_intermediate_dim, hidden_activation=hidden_activation, dropout_rate=dropout_rate, initializer_range=initializer_range, - head_dim=head_dim, + head_dim=decoder_head_dim, attention_dropout=attention_dropout, - layer_type=layer_types[i], + layer_type=decoder_layer_types[i], sliding_window=sliding_window, - cross_attention_hidden_size=cross_attention_hidden_size, + cross_attention_hidden_size=( + cross_attention_hidden_size or encoder_hidden_dim + ), attn_logit_softcapping=attn_logit_softcapping, rope_max_wavelength=rope_max_wavelength, name=f"decoder_layer_{i}", dtype=dtype, ) - for i in range(num_layers) + for i in range(decoder_num_layers) ] self.decoder_norm = RMSNormalization(epsilon=rms_norm_eps, dtype=dtype) self.decoder_dropout = keras.layers.Dropout(dropout_rate, dtype=dtype) @@ -209,7 +242,7 @@ def __init__( # Encoder. encoder_embeddings = self.token_embedding(token_id_input) encoder_embeddings = encoder_embeddings * keras.ops.cast( - keras.ops.sqrt(hidden_dim), encoder_embeddings.dtype + keras.ops.sqrt(encoder_hidden_dim), encoder_embeddings.dtype ) encoder_hidden_states = self.encoder_dropout(encoder_embeddings) for layer in self.encoder_layers: @@ -223,7 +256,7 @@ def __init__( # Decoder. decoder_embeddings = self.decoder_token_embedding(token_id_input) decoder_embeddings = decoder_embeddings * keras.ops.cast( - keras.ops.sqrt(hidden_dim), decoder_embeddings.dtype + keras.ops.sqrt(decoder_hidden_dim), decoder_embeddings.dtype ) decoder_hidden_states = self.decoder_dropout(decoder_embeddings) for layer in self.decoder_layers: @@ -246,46 +279,62 @@ def __init__( ) # === Config === + self.encoder_hidden_dim = encoder_hidden_dim + self.encoder_intermediate_dim = encoder_intermediate_dim + self.encoder_num_layers = encoder_num_layers + self.encoder_num_attention_heads = encoder_num_attention_heads + self.encoder_num_key_value_heads = encoder_num_key_value_heads + self.encoder_head_dim = encoder_head_dim + self.encoder_layer_types = encoder_layer_types + self.decoder_hidden_dim = decoder_hidden_dim + self.decoder_intermediate_dim = decoder_intermediate_dim + self.decoder_num_layers = decoder_num_layers + self.decoder_num_attention_heads = decoder_num_attention_heads + self.decoder_num_key_value_heads = decoder_num_key_value_heads + self.decoder_head_dim = decoder_head_dim + self.decoder_layer_types = decoder_layer_types self.vocabulary_size = vocabulary_size - self.hidden_dim = hidden_dim - self.intermediate_dim = intermediate_dim - self.num_layers = num_layers - self.num_attention_heads = num_attention_heads - self.num_key_value_heads = num_key_value_heads self.dropout_rate = dropout_rate self.rms_norm_eps = rms_norm_eps self.tie_word_embeddings = tie_word_embeddings self.query_pre_attn_scalar = query_pre_attn_scalar self.attention_bias = attention_bias self.hidden_activation = hidden_activation - self.layer_types = layer_types self.initializer_range = initializer_range self.attention_dropout = attention_dropout self.sliding_window = sliding_window - self.cross_attention_hidden_size = cross_attention_hidden_size + self.cross_attention_hidden_size = ( + cross_attention_hidden_size or encoder_hidden_dim + ) self.attn_logit_softcapping = attn_logit_softcapping self.final_logit_softcapping = final_logit_softcapping self.rope_max_wavelength = rope_max_wavelength - self.head_dim = head_dim def get_config(self): config = super().get_config() config.update( { "vocabulary_size": self.vocabulary_size, - "hidden_dim": self.hidden_dim, - "intermediate_dim": self.intermediate_dim, - "num_layers": self.num_layers, - "num_attention_heads": self.num_attention_heads, - "num_key_value_heads": self.num_key_value_heads, + "encoder_hidden_dim": self.encoder_hidden_dim, + "encoder_intermediate_dim": self.encoder_intermediate_dim, + "encoder_num_layers": self.encoder_num_layers, + "encoder_num_attention_heads": self.encoder_num_attention_heads, + "encoder_num_key_value_heads": self.encoder_num_key_value_heads, + "encoder_layer_types": self.encoder_layer_types, + "encoder_head_dim": self.encoder_head_dim, + "decoder_hidden_dim": self.decoder_hidden_dim, + "decoder_intermediate_dim": self.decoder_intermediate_dim, + "decoder_num_layers": self.decoder_num_layers, + "decoder_num_attention_heads": self.decoder_num_attention_heads, + "decoder_num_key_value_heads": self.decoder_num_key_value_heads, + "decoder_layer_types": self.decoder_layer_types, + "decoder_head_dim": self.decoder_head_dim, "dropout_rate": self.dropout_rate, "rms_norm_eps": self.rms_norm_eps, "tie_word_embeddings": self.tie_word_embeddings, "query_pre_attn_scalar": self.query_pre_attn_scalar, "attention_bias": self.attention_bias, - "head_dim": self.head_dim, "hidden_activation": self.hidden_activation, - "layer_types": self.layer_types, "initializer_range": self.initializer_range, "attention_dropout": self.attention_dropout, "sliding_window": self.sliding_window, diff --git a/keras_hub/src/models/t5gemma/t5gemma_backbone_test.py b/keras_hub/src/models/t5gemma/t5gemma_backbone_test.py index 5aa7e7b46d..19ef39a214 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_backbone_test.py +++ b/keras_hub/src/models/t5gemma/t5gemma_backbone_test.py @@ -9,19 +9,26 @@ class T5GemmaBackboneTest(TestCase): def setUp(self): self.init_kwargs = { "vocabulary_size": 100, - "hidden_dim": 32, - "intermediate_dim": 64, - "num_layers": 2, - "num_attention_heads": 4, - "num_key_value_heads": 2, - "head_dim": 8, + "encoder_hidden_dim": 32, + "encoder_intermediate_dim": 64, + "encoder_num_layers": 2, + "encoder_num_attention_heads": 4, + "encoder_num_key_value_heads": 2, + "encoder_head_dim": 8, + "encoder_layer_types": ["sliding_attention", "full_attention"], + "decoder_hidden_dim": 32, + "decoder_intermediate_dim": 64, + "decoder_num_layers": 2, + "decoder_num_attention_heads": 4, + "decoder_num_key_value_heads": 2, + "decoder_head_dim": 8, + "decoder_layer_types": ["sliding_attention", "full_attention"], "dropout_rate": 0.1, "rms_norm_eps": 1e-6, "tie_word_embeddings": True, "query_pre_attn_scalar": 1.0, "attention_bias": False, "hidden_activation": "gelu_approximate", - "layer_types": ["sliding_attention", "full_attention"], "sliding_window": 16, "cross_attention_hidden_size": 32, "attn_logit_softcapping": 50.0, @@ -42,6 +49,36 @@ def test_backbone_basics(self): expected_output_shape=(2, 16, 32), ) + def test_asymmetrical_backbone(self): + asym_kwargs = { + "vocabulary_size": 100, + "encoder_hidden_dim": 48, + "encoder_intermediate_dim": 96, + "encoder_num_layers": 3, + "encoder_num_attention_heads": 6, + "encoder_num_key_value_heads": 3, + "encoder_head_dim": 8, + "encoder_layer_types": ["full_attention"] * 3, + "decoder_hidden_dim": 32, + "decoder_intermediate_dim": 64, + "decoder_num_layers": 2, + "decoder_num_attention_heads": 4, + "decoder_num_key_value_heads": 2, + "decoder_head_dim": 8, + "decoder_layer_types": ["sliding_attention", "full_attention"], + "sliding_window": 16, + "dropout_rate": 0.1, + "rms_norm_eps": 1e-6, + "tie_word_embeddings": True, + "cross_attention_hidden_size": 48, + } + self.run_backbone_test( + cls=T5GemmaBackbone, + init_kwargs=asym_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 16, 32), + ) + @pytest.mark.large def test_saved_model(self): self.run_model_saving_test( diff --git a/keras_hub/src/models/t5gemma/t5gemma_causal_lm.py b/keras_hub/src/models/t5gemma/t5gemma_causal_lm.py index e34a8c1cc5..d04b269437 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_causal_lm.py +++ b/keras_hub/src/models/t5gemma/t5gemma_causal_lm.py @@ -117,18 +117,28 @@ class T5GemmaCausalLM(CausalLM): ) backbone = keras_hub.models.T5GemmaBackbone( vocabulary_size=32000, - num_layers=4, - num_attention_heads=4, - num_key_value_heads=2, - hidden_dim=256, - intermediate_dim=512, - head_dim=64, + # Encoder parameters. + encoder_hidden_dim=256, + encoder_intermediate_dim=512, + encoder_num_layers=4, + encoder_num_attention_heads=4, + encoder_num_key_value_heads=2, + encoder_head_dim=64, + encoder_layer_types=["full_attention"] * 4, + # Decoder parameters. + decoder_hidden_dim=256, + decoder_intermediate_dim=512, + decoder_num_layers=4, + decoder_num_attention_heads=4, + decoder_num_key_value_heads=2, + decoder_head_dim=64, + decoder_layer_types=["full_attention"] * 4, + # Common parameters. dropout_rate=0.1, rms_norm_eps=1e-6, query_pre_attn_scalar=1.0, attention_bias=False, hidden_activation="gelu_approximate", - layer_types=["full_attention"] * 4 ) t5gemma_lm = keras_hub.models.T5GemmaCausalLM( backbone=backbone, @@ -166,7 +176,8 @@ def call_encoder(self, token_ids, padding_mask): """Process inputs through the encoder stack.""" encoder_embeddings = self.backbone.token_embedding(token_ids) encoder_embeddings *= keras.ops.cast( - keras.ops.sqrt(self.backbone.hidden_dim), encoder_embeddings.dtype + keras.ops.sqrt(self.backbone.encoder_hidden_dim), + encoder_embeddings.dtype, ) encoder_hidden_states = self.backbone.encoder_dropout( encoder_embeddings, training=False @@ -218,7 +229,8 @@ def call_decoder_with_cache( self_attention_cache, cross_attention_cache = cache hidden_states = self.backbone.decoder_token_embedding(decoder_token_ids) hidden_states *= keras.ops.cast( - keras.ops.sqrt(self.backbone.hidden_dim), hidden_states.dtype + keras.ops.sqrt(self.backbone.decoder_hidden_dim), + hidden_states.dtype, ) hidden_states = self.backbone.decoder_dropout( hidden_states, training=False @@ -276,9 +288,9 @@ def _build_cache(self, token_ids, padding_mask): token_ids, padding_mask ) batch_size = keras.ops.shape(token_ids)[0] - num_layers = self.backbone.num_layers - num_kv_heads = self.backbone.num_key_value_heads - head_dim = self.backbone.head_dim + num_layers = self.backbone.decoder_num_layers + num_kv_heads = self.backbone.decoder_num_key_value_heads + head_dim = self.backbone.decoder_head_dim self_cache_shape = ( batch_size, num_layers, diff --git a/keras_hub/src/models/t5gemma/t5gemma_causal_lm_test.py b/keras_hub/src/models/t5gemma/t5gemma_causal_lm_test.py index 003c6211f7..86f686a138 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_causal_lm_test.py +++ b/keras_hub/src/models/t5gemma/t5gemma_causal_lm_test.py @@ -26,19 +26,26 @@ def setUp(self): ) self.backbone = T5GemmaBackbone( vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(), - hidden_dim=16, - intermediate_dim=32, - num_layers=2, - num_attention_heads=2, - num_key_value_heads=1, + encoder_hidden_dim=16, + encoder_intermediate_dim=32, + encoder_num_layers=2, + encoder_num_attention_heads=2, + encoder_num_key_value_heads=1, + encoder_head_dim=8, + encoder_layer_types=["sliding_attention", "full_attention"], + decoder_hidden_dim=16, + decoder_intermediate_dim=32, + decoder_num_layers=2, + decoder_num_attention_heads=2, + decoder_num_key_value_heads=1, + decoder_head_dim=8, + decoder_layer_types=["sliding_attention", "full_attention"], dropout_rate=0.0, - head_dim=8, rms_norm_eps=1e-6, tie_word_embeddings=False, query_pre_attn_scalar=1.0, attention_bias=False, hidden_activation="gelu_approximate", - layer_types=["sliding_attention", "full_attention"], initializer_range=0.02, attention_dropout=0.0, sliding_window=4, diff --git a/tools/checkpoint_conversion/convert_t5gemma_checkpoints.py b/tools/checkpoint_conversion/convert_t5gemma_checkpoints.py index 6f52cfe0c7..4f9861fc1c 100644 --- a/tools/checkpoint_conversion/convert_t5gemma_checkpoints.py +++ b/tools/checkpoint_conversion/convert_t5gemma_checkpoints.py @@ -52,6 +52,18 @@ "t5gemma_ml_ml_prefixlm": "google/t5gemma-ml-ml-prefixlm", "t5gemma_ml_ml_ul2_it": "google/t5gemma-ml-ml-ul2-it", "t5gemma_ml_ml_prefixlm_it": "google/t5gemma-ml-ml-prefixlm-it", + "t5gemma_xl_xl_ul2": "google/t5gemma-xl-xl-ul2", + "t5gemma_xl_xl_prefixlm": "google/t5gemma-xl-xl-prefixlm", + "t5gemma_xl_xl_ul2_it": "google/t5gemma-xl-xl-ul2-it", + "t5gemma_xl_xl_prefixlm_it": "google/t5gemma-xl-xl-prefixlm-it", + "t5gemma_2b_2b_ul2": "google/t5gemma-2b-2b-ul2", + "t5gemma_2b_2b_prefixlm": "google/t5gemma-2b-2b-prefixlm", + "t5gemma_2b_2b_ul2_it": "google/t5gemma-2b-2b-ul2-it", + "t5gemma_2b_2b_prefixlm_it": "google/t5gemma-2b-2b-prefixlm-it", + "t5gemma_9b_9b_ul2": "google/t5gemma-9b-9b-ul2", + "t5gemma_9b_9b_prefixlm": "google/t5gemma-9b-9b-prefixlm", + "t5gemma_9b_9b_ul2_it": "google/t5gemma-9b-9b-ul2-it", + "t5gemma_9b_9b_prefixlm_it": "google/t5gemma-9b-9b-prefixlm-it", } EXTRACT_DIR = "./model_t5gemma" FLAGS = absl.flags.FLAGS @@ -75,32 +87,40 @@ def download_hf_model(hf_model_name): def convert_model(hf_model, preprocessor): - # NOTE: The decoder and encoder are symmetrical in the config, so they are - # interchangeable. + encoder_config = hf_model.config.encoder decoder_config = hf_model.config.decoder if decoder_config.hidden_activation == "gelu_pytorch_tanh": decoder_config.hidden_activation = "gelu_approximate" + if encoder_config.hidden_activation == "gelu_pytorch_tanh": + encoder_config.hidden_activation = "gelu_approximate" keras_backbone = T5GemmaCausalLM.backbone_cls( vocabulary_size=decoder_config.vocab_size, - hidden_dim=decoder_config.hidden_size, - intermediate_dim=decoder_config.intermediate_size, - num_layers=decoder_config.num_hidden_layers, - num_attention_heads=decoder_config.num_attention_heads, - num_key_value_heads=decoder_config.num_key_value_heads, + encoder_hidden_dim=encoder_config.hidden_size, + encoder_intermediate_dim=encoder_config.intermediate_size, + encoder_num_layers=encoder_config.num_hidden_layers, + encoder_num_attention_heads=encoder_config.num_attention_heads, + encoder_num_key_value_heads=encoder_config.num_key_value_heads, + encoder_head_dim=encoder_config.head_dim, + encoder_layer_types=encoder_config.layer_types, + decoder_hidden_dim=decoder_config.hidden_size, + decoder_intermediate_dim=decoder_config.intermediate_size, + decoder_num_layers=decoder_config.num_hidden_layers, + decoder_num_attention_heads=decoder_config.num_attention_heads, + decoder_num_key_value_heads=decoder_config.num_key_value_heads, + decoder_head_dim=decoder_config.head_dim, + decoder_layer_types=decoder_config.layer_types, dropout_rate=decoder_config.dropout_rate, rms_norm_eps=decoder_config.rms_norm_eps, query_pre_attn_scalar=decoder_config.query_pre_attn_scalar, tie_word_embeddings=getattr( hf_model.config, "tie_word_embeddings", True ), - head_dim=decoder_config.head_dim, attention_bias=decoder_config.attention_bias, hidden_activation=decoder_config.hidden_activation, - layer_types=decoder_config.layer_types, initializer_range=decoder_config.initializer_range, attention_dropout=decoder_config.attention_dropout, sliding_window=decoder_config.sliding_window, - cross_attention_hidden_size=decoder_config.cross_attention_hidden_size, + cross_attention_hidden_size=encoder_config.hidden_size, attn_logit_softcapping=decoder_config.attn_logit_softcapping, final_logit_softcapping=decoder_config.final_logit_softcapping, rope_max_wavelength=decoder_config.rope_theta, @@ -124,10 +144,6 @@ def convert_weights(keras_model, hf_model): print("šŸ‹ļø Converting weights...") hf_wts = hf_model.state_dict() keras_backbone = keras_model.backbone - hidden_dim = keras_backbone.hidden_dim - num_attention_heads = keras_backbone.num_attention_heads - num_key_value_heads = keras_backbone.num_key_value_heads - head_dim = keras_backbone.head_dim # Token Embeddings. keras_backbone.token_embedding.embeddings.assign( hf_wts["encoder.embed_tokens.weight"] @@ -137,8 +153,12 @@ def convert_weights(keras_model, hf_model): ) # Encoder. + encoder_hidden_dim = keras_backbone.encoder_hidden_dim + encoder_num_attention_heads = keras_backbone.encoder_num_attention_heads + encoder_num_key_value_heads = keras_backbone.encoder_num_key_value_heads + encoder_head_dim = keras_backbone.encoder_head_dim keras_backbone.encoder_norm.scale.assign(hf_wts["encoder.norm.weight"]) - for i in range(keras_backbone.num_layers): + for i in range(keras_backbone.encoder_num_layers): encoder_layer = keras_backbone.get_layer(f"encoder_layer_{i}") hf_prefix = f"encoder.layers.{i}" @@ -149,16 +169,32 @@ def convert_weights(keras_model, hf_model): o_w = hf_wts[f"{hf_prefix}.self_attn.o_proj.weight"] encoder_layer.self_attn.query_dense.kernel.assign( - q_w.T.reshape(hidden_dim, num_attention_heads, head_dim).numpy() + q_w.T.reshape( + encoder_hidden_dim, + encoder_num_attention_heads, + encoder_head_dim, + ).numpy() ) encoder_layer.self_attn.key_dense.kernel.assign( - k_w.T.reshape(hidden_dim, num_key_value_heads, head_dim).numpy() + k_w.T.reshape( + encoder_hidden_dim, + encoder_num_key_value_heads, + encoder_head_dim, + ).numpy() ) encoder_layer.self_attn.value_dense.kernel.assign( - v_w.T.reshape(hidden_dim, num_key_value_heads, head_dim).numpy() + v_w.T.reshape( + encoder_hidden_dim, + encoder_num_key_value_heads, + encoder_head_dim, + ).numpy() ) encoder_layer.self_attn.output_dense.kernel.assign( - o_w.T.reshape(num_attention_heads, head_dim, hidden_dim).numpy() + o_w.T.reshape( + encoder_num_attention_heads, + encoder_head_dim, + encoder_hidden_dim, + ).numpy() ) # MLP. @@ -187,8 +223,13 @@ def convert_weights(keras_model, hf_model): ) # Decoder. + decoder_hidden_dim = keras_backbone.decoder_hidden_dim + decoder_num_attention_heads = keras_backbone.decoder_num_attention_heads + decoder_num_key_value_heads = keras_backbone.decoder_num_key_value_heads + decoder_head_dim = keras_backbone.decoder_head_dim + cross_attention_hidden_size = keras_backbone.cross_attention_hidden_size keras_backbone.decoder_norm.scale.assign(hf_wts["decoder.norm.weight"]) - for i in range(keras_backbone.num_layers): + for i in range(keras_backbone.decoder_num_layers): decoder_layer = keras_backbone.get_layer(f"decoder_layer_{i}") hf_prefix = f"decoder.layers.{i}" @@ -198,16 +239,32 @@ def convert_weights(keras_model, hf_model): v_w = hf_wts[f"{hf_prefix}.self_attn.v_proj.weight"] o_w = hf_wts[f"{hf_prefix}.self_attn.o_proj.weight"] decoder_layer.self_attn.query_dense.kernel.assign( - q_w.T.reshape(hidden_dim, num_attention_heads, head_dim).numpy() + q_w.T.reshape( + decoder_hidden_dim, + decoder_num_attention_heads, + decoder_head_dim, + ).numpy() ) decoder_layer.self_attn.key_dense.kernel.assign( - k_w.T.reshape(hidden_dim, num_key_value_heads, head_dim).numpy() + k_w.T.reshape( + decoder_hidden_dim, + decoder_num_key_value_heads, + decoder_head_dim, + ).numpy() ) decoder_layer.self_attn.value_dense.kernel.assign( - v_w.T.reshape(hidden_dim, num_key_value_heads, head_dim).numpy() + v_w.T.reshape( + decoder_hidden_dim, + decoder_num_key_value_heads, + decoder_head_dim, + ).numpy() ) decoder_layer.self_attn.output_dense.kernel.assign( - o_w.T.reshape(num_attention_heads, head_dim, hidden_dim).numpy() + o_w.T.reshape( + decoder_num_attention_heads, + decoder_head_dim, + decoder_hidden_dim, + ).numpy() ) # Cross-attention. @@ -216,16 +273,32 @@ def convert_weights(keras_model, hf_model): v_w = hf_wts[f"{hf_prefix}.cross_attn.v_proj.weight"] o_w = hf_wts[f"{hf_prefix}.cross_attn.o_proj.weight"] decoder_layer.cross_attn.query_dense.kernel.assign( - q_w.T.reshape(hidden_dim, num_attention_heads, head_dim).numpy() + q_w.T.reshape( + decoder_hidden_dim, + decoder_num_attention_heads, + decoder_head_dim, + ).numpy() ) decoder_layer.cross_attn.key_dense.kernel.assign( - k_w.T.reshape(hidden_dim, num_key_value_heads, head_dim).numpy() + k_w.T.reshape( + cross_attention_hidden_size, + decoder_num_key_value_heads, + decoder_head_dim, + ).numpy() ) decoder_layer.cross_attn.value_dense.kernel.assign( - v_w.T.reshape(hidden_dim, num_key_value_heads, head_dim).numpy() + v_w.T.reshape( + cross_attention_hidden_size, + decoder_num_key_value_heads, + decoder_head_dim, + ).numpy() ) decoder_layer.cross_attn.output_dense.kernel.assign( - o_w.T.reshape(num_attention_heads, head_dim, hidden_dim).numpy() + o_w.T.reshape( + decoder_num_attention_heads, + decoder_head_dim, + decoder_hidden_dim, + ).numpy() ) # MLP. From 050910bc5b4d06d4b7b195f8a32b54c8bcf6d6eb Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Wed, 6 Aug 2025 12:04:09 +0400 Subject: [PATCH 14/19] refactor: Address reviews - presets will be handled post D-FINE --- keras_hub/api/models/__init__.py | 8 +- .../src/models/t5gemma/t5gemma_attention.py | 13 +- .../src/models/t5gemma/t5gemma_backbone.py | 75 +++++--- .../models/t5gemma/t5gemma_backbone_test.py | 29 +++- .../t5gemma/t5gemma_causal_lm_preprocessor.py | 72 -------- .../src/models/t5gemma/t5gemma_decoder.py | 12 +- .../src/models/t5gemma/t5gemma_encoder.py | 9 +- .../src/models/t5gemma/t5gemma_layers.py | 6 +- ...a_causal_lm.py => t5gemma_seq_2_seq_lm.py} | 163 +++++++++++------- .../t5gemma_seq_2_seq_lm_preprocessor.py | 88 ++++++++++ ...m_test.py => t5gemma_seq_2_seq_lm_test.py} | 77 +++++---- .../src/models/t5gemma/t5gemma_tokenizer.py | 4 + .../convert_t5gemma_checkpoints.py | 21 ++- 13 files changed, 343 insertions(+), 234 deletions(-) delete mode 100644 keras_hub/src/models/t5gemma/t5gemma_causal_lm_preprocessor.py rename keras_hub/src/models/t5gemma/{t5gemma_causal_lm.py => t5gemma_seq_2_seq_lm.py} (71%) create mode 100644 keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_preprocessor.py rename keras_hub/src/models/t5gemma/{t5gemma_causal_lm_test.py => t5gemma_seq_2_seq_lm_test.py} (69%) diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index d4ca90dddb..0941e63402 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -602,11 +602,11 @@ from keras_hub.src.models.t5gemma.t5gemma_backbone import ( T5GemmaBackbone as T5GemmaBackbone, ) -from keras_hub.src.models.t5gemma.t5gemma_causal_lm import ( - T5GemmaCausalLM as T5GemmaCausalLM, +from keras_hub.src.models.t5gemma.t5gemma_seq_2_seq_lm import ( + T5GemmaSeq2SeqLM as T5GemmaSeq2SeqLM, ) -from keras_hub.src.models.t5gemma.t5gemma_causal_lm_preprocessor import ( - T5GemmaCausalLMPreprocessor as T5GemmaCausalLMPreprocessor, +from keras_hub.src.models.t5gemma.t5gemma_seq_2_seq_lm_preprocessor import ( + T5GemmaSeq2SeqLMPreprocessor as T5GemmaSeq2SeqLMPreprocessor, ) from keras_hub.src.models.t5gemma.t5gemma_tokenizer import ( T5GemmaTokenizer as T5GemmaTokenizer, diff --git a/keras_hub/src/models/t5gemma/t5gemma_attention.py b/keras_hub/src/models/t5gemma/t5gemma_attention.py index 0040e3801e..51c5490496 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_attention.py +++ b/keras_hub/src/models/t5gemma/t5gemma_attention.py @@ -38,7 +38,6 @@ def repeat_kv(hidden_states, n_rep): ) -@keras.saving.register_keras_serializable(package="keras_hub") class T5GemmaAttention(CachedGemmaAttention): """A unified attention layer for T5Gemma that handles both self-attention and cross-attention. @@ -59,17 +58,17 @@ class T5GemmaAttention(CachedGemmaAttention): attention_type: str, The type of attention, either 'self' or 'cross'. Defaults to 'self'. cross_attention_hidden_size: int, optional, The dimensionality of - encoder hidden states for cross-attention. + encoder hidden states for cross-attention. Defaults to `None`. initializer_range: float, The range for the random normal initializer - for kernel weights. Default is `0.02`. + for kernel weights. Defaults to `0.02`. attention_dropout: float, The dropout rate applied to attention weights. - Default is `0.0`. + Defaults to `0.0`. attn_logit_softcapping: float, optional, The softcapping value for - attention logits. + attention logits. Defaults to `None`. rope_max_wavelength: float, The maximum wavelength for Rotary Positional - Embeddings. Default is `10000.0`. Only used for self-attention. + Embeddings. Defaults to `10000.0`. Only used for self-attention. dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use - for model computations and weights. + for model computations and weights. Defaults to `None`. **kwargs: Additional keyword arguments passed to the parent class. """ diff --git a/keras_hub/src/models/t5gemma/t5gemma_backbone.py b/keras_hub/src/models/t5gemma/t5gemma_backbone.py index a166a63c51..61d6665746 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_backbone.py +++ b/keras_hub/src/models/t5gemma/t5gemma_backbone.py @@ -53,47 +53,54 @@ class T5GemmaBackbone(Backbone): either `"sliding_attention"` or `"full_attention"`. For example, `["full_attention", "sliding_attention", ...]`. dropout_rate: float, The dropout rate applied throughout the model. - rms_norm_eps: float, The epsilon value for RMS normalization. + Defaults to `0.0`. + rms_norm_eps: float, The epsilon value for RMS normalization. Defaults + to `1e-6`. query_pre_attn_scalar: float, Scalar to multiply queries by before - attention. + attention. Defaults to `1.0`. attention_bias: bool, Whether to include bias in attention computations. + Defaults to `False`. hidden_activation: str, The activation function used in the feed-forward - networks. + networks. Defaults to `"gelu_approximate"`. tie_word_embeddings: bool, Whether to tie input and output word - embeddings. Default is `True`. + embeddings. Defaults to `True`. initializer_range: float, The range for the random normal initializer. - Default is `0.02`. + Defaults to `0.02`. attention_dropout: float, The dropout rate applied to attention weights. - Default is `0.0`. + Defaults to `0.0`. sliding_window: int, optional, The window size for sliding attention. - Required if any `layer_type` is `"sliding_attention"`. + Required if any `layer_type` is `"sliding_attention"`. Defaults to + `None`. cross_attention_hidden_size: int, optional, The hidden size for cross-attention in the decoder layers. If None, it defaults to - `encoder_hidden_dim`. + `encoder_hidden_dim`. Defaults to `None`. attn_logit_softcapping: float, optional, The softcapping value for - attention logits. + attention logits. Defaults to `None`. final_logit_softcapping: float, optional, The softcapping value for - final logits. + final logits. Defaults to `None`. rope_max_wavelength: float, The maximum wavelength for Rotary Positional - Embeddings. Default is `10000.0`. + Embeddings. Defaults to `10000.0`. dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use for model computations and weights. Note that some computations, such as softmax and layer normalization, will always be done at - float32 precision regardless of dtype. + float32 precision regardless of dtype. Defaults to `None`. **kwargs: Additional keyword arguments passed to the parent `Backbone` class. Examples: ```python import numpy as np - import keras from keras_hub.models import T5GemmaBackbone input_data = { - "token_ids": np.ones(shape=(1, 12), dtype="int32"), - "padding_mask": np.array( + "encoder_token_ids": np.ones(shape=(1, 12), dtype="int32"), + "encoder_padding_mask": np.array( [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]], dtype="int32" ), + "decoder_token_ids": np.ones(shape=(1, 8), dtype="int32"), + "decoder_padding_mask": np.array( + [[1, 1, 1, 1, 1, 1, 1, 1]], dtype="int32" + ), } # Randomly initialized T5Gemma backbone with custom config. @@ -232,29 +239,36 @@ def __init__( self.decoder_dropout = keras.layers.Dropout(dropout_rate, dtype=dtype) # === Functional Model === - token_id_input = keras.Input( - shape=(None,), dtype="int32", name="token_ids" + encoder_token_id_input = keras.Input( + shape=(None,), dtype="int32", name="encoder_token_ids" + ) + encoder_padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="encoder_padding_mask" + ) + decoder_token_id_input = keras.Input( + shape=(None,), dtype="int32", name="decoder_token_ids" ) - padding_mask_input = keras.Input( - shape=(None,), dtype="int32", name="padding_mask" + decoder_padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="decoder_padding_mask" ) # Encoder. - encoder_embeddings = self.token_embedding(token_id_input) + encoder_embeddings = self.token_embedding(encoder_token_id_input) encoder_embeddings = encoder_embeddings * keras.ops.cast( keras.ops.sqrt(encoder_hidden_dim), encoder_embeddings.dtype ) encoder_hidden_states = self.encoder_dropout(encoder_embeddings) for layer in self.encoder_layers: encoder_hidden_states = layer( - encoder_hidden_states, - padding_mask=padding_mask_input, + encoder_hidden_states, padding_mask=encoder_padding_mask_input ) encoder_output = self.encoder_norm(encoder_hidden_states) encoder_output = self.encoder_dropout(encoder_output) # Decoder. - decoder_embeddings = self.decoder_token_embedding(token_id_input) + decoder_embeddings = self.decoder_token_embedding( + decoder_token_id_input + ) decoder_embeddings = decoder_embeddings * keras.ops.cast( keras.ops.sqrt(decoder_hidden_dim), decoder_embeddings.dtype ) @@ -262,18 +276,23 @@ def __init__( for layer in self.decoder_layers: decoder_hidden_states, _ = layer( (decoder_hidden_states, encoder_output), - self_attention_padding_mask=padding_mask_input, - cross_attention_padding_mask=padding_mask_input, + self_attention_padding_mask=decoder_padding_mask_input, + cross_attention_padding_mask=encoder_padding_mask_input, ) decoder_output = self.decoder_norm(decoder_hidden_states) decoder_output = self.decoder_dropout(decoder_output) super().__init__( inputs={ - "token_ids": token_id_input, - "padding_mask": padding_mask_input, + "encoder_token_ids": encoder_token_id_input, + "encoder_padding_mask": encoder_padding_mask_input, + "decoder_token_ids": decoder_token_id_input, + "decoder_padding_mask": decoder_padding_mask_input, + }, + outputs={ + "encoder_sequence_output": encoder_output, + "decoder_sequence_output": decoder_output, }, - outputs=decoder_output, dtype=dtype, **kwargs, ) diff --git a/keras_hub/src/models/t5gemma/t5gemma_backbone_test.py b/keras_hub/src/models/t5gemma/t5gemma_backbone_test.py index 19ef39a214..5c5bfe8229 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_backbone_test.py +++ b/keras_hub/src/models/t5gemma/t5gemma_backbone_test.py @@ -33,12 +33,14 @@ def setUp(self): "cross_attention_hidden_size": 32, "attn_logit_softcapping": 50.0, "rope_max_wavelength": 10000.0, - "initializer_range": 0.02, - "attention_dropout": 0.0, + "initializer_range": 0.04, + "attention_dropout": 0.1, } self.input_data = { - "token_ids": keras.ops.ones((2, 16), dtype="int32"), - "padding_mask": keras.ops.ones((2, 16), dtype="int32"), + "encoder_token_ids": keras.ops.ones((2, 16), dtype="int32"), + "encoder_padding_mask": keras.ops.ones((2, 16), dtype="int32"), + "decoder_token_ids": keras.ops.ones((2, 16), dtype="int32"), + "decoder_padding_mask": keras.ops.ones((2, 16), dtype="int32"), } def test_backbone_basics(self): @@ -46,7 +48,10 @@ def test_backbone_basics(self): cls=T5GemmaBackbone, init_kwargs=self.init_kwargs, input_data=self.input_data, - expected_output_shape=(2, 16, 32), + expected_output_shape={ + "encoder_sequence_output": (2, 16, 32), + "decoder_sequence_output": (2, 16, 32), + }, ) def test_asymmetrical_backbone(self): @@ -76,7 +81,10 @@ def test_asymmetrical_backbone(self): cls=T5GemmaBackbone, init_kwargs=asym_kwargs, input_data=self.input_data, - expected_output_shape=(2, 16, 32), + expected_output_shape={ + "encoder_sequence_output": (2, 16, 48), + "decoder_sequence_output": (2, 16, 32), + }, ) @pytest.mark.large @@ -86,3 +94,12 @@ def test_saved_model(self): init_kwargs=self.init_kwargs, input_data=self.input_data, ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in T5GemmaBackbone.presets: + self.run_preset_test( + cls=T5GemmaBackbone, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/t5gemma/t5gemma_causal_lm_preprocessor.py b/keras_hub/src/models/t5gemma/t5gemma_causal_lm_preprocessor.py deleted file mode 100644 index 69dc9ec782..0000000000 --- a/keras_hub/src/models/t5gemma/t5gemma_causal_lm_preprocessor.py +++ /dev/null @@ -1,72 +0,0 @@ -from keras_hub.src.api_export import keras_hub_export -from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor -from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone -from keras_hub.src.models.t5gemma.t5gemma_tokenizer import T5GemmaTokenizer - - -@keras_hub_export("keras_hub.models.T5GemmaCausalLMPreprocessor") -class T5GemmaCausalLMPreprocessor(CausalLMPreprocessor): - """T5Gemma Causal LM preprocessor. - - This preprocessing layer is meant for use with - `keras_hub.models.T5GemmaCausalLM`. By default, it will take in batches of - strings, and return outputs in a `(x, y, sample_weight)` format, where the - `y` label is the next token id in the `x` sequence. - - For use with generation, the layer also exposes two methods - `generate_preprocess()` and `generate_postprocess()`. When this preprocessor - is attached to a `keras_hub.models.T5GemmaCausalLM` instance, these methods - will be called implicitly in `generate()`. They can also be called - standalone (e.g. to precompute preprocessing inputs for generation in a - separate process). - - Args: - tokenizer: A `keras_hub.models.T5GemmaTokenizer` instance. - sequence_length: The length of the packed inputs. - add_start_token: If `True`, the preprocessor will prepend the tokenizer - start token to each input sequence. Default is `True`. - add_end_token: If `True`, the preprocessor will append the tokenizer - end token to each input sequence. Default is `False`. - - Call arguments: - x: A string, `tf.Tensor` or list of python strings. - y: Label data. Should always be `None` as the layer generates labels. - sample_weight: Label weights. Should always be `None` as the layer - generates label weights. - sequence_length: Pass to override the configured `sequence_length` of - the layer. - - Examples: - ```python - # Load the preprocessor from a preset. - preprocessor = keras_hub.models.T5GemmaCausalLMPreprocessor.from_preset( - "t5gemma_b_b_prefixlm_it" - ) - - # Tokenize and pack a single sentence. - sentence = tf.constant("The quick brown fox jumped.") - preprocessor(sentence) - # Same output. - preprocessor("The quick brown fox jumped.") - - # Tokenize a batch of sentences. - preprocessor(["The quick brown fox jumped.", "Call me Ishmael."]) - - # Apply tokenization to a `tf.data.Dataset`. - features = tf.constant(["The quick brown fox.", "Call me Ishmael."]) - ds = tf.data.Dataset.from_tensor_slices(features) - ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) - - # Prepare tokens for generation (no end token). - preprocessor.generate_preprocess(["The quick brown fox jumped."]) - - # Map generation outputs back to strings. - preprocessor.generate_postprocess({ - 'token_ids': np.array([[2, 714, 4320, 8426, 25341, 32292, 235265, 0]]), - 'padding_mask': np.array([[ 1, 1, 1, 1, 1, 1, 1, 0]]), - }) - ``` - """ - - backbone_cls = T5GemmaBackbone - tokenizer_cls = T5GemmaTokenizer diff --git a/keras_hub/src/models/t5gemma/t5gemma_decoder.py b/keras_hub/src/models/t5gemma/t5gemma_decoder.py index 555daa8601..905d291abb 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_decoder.py +++ b/keras_hub/src/models/t5gemma/t5gemma_decoder.py @@ -32,15 +32,17 @@ class T5GemmaDecoderLayer(keras.layers.Layer): attention_dropout: float, The dropout rate applied to attention weights. layer_type: str, Type of attention layer, e.g., `"sliding_attention"`. cross_attention_hidden_size: int, optional, The hidden size for - cross-attention. If None, it defaults to `hidden_size`. + cross-attention. If None, it defaults to `hidden_size`. Defaults to + `None`. attn_logit_softcapping: float, optional, The softcapping value for - attention logits. + attention logits. Defaults to `None`. sliding_window: int, optional, The window size for sliding attention. - Required if `layer_type` is `"sliding_attention"`. + Required if `layer_type` is `"sliding_attention"`. Defaults to + `None`. rope_max_wavelength: float, The maximum wavelength for Rotary - Positional Embeddings. Default is `10000.0`. + Positional Embeddings. Defaults to `10000.0`. dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use - for model computations and weights. + for model computations and weights. Defaults to `None`. **kwargs: Additional keyword arguments passed to the parent class. """ diff --git a/keras_hub/src/models/t5gemma/t5gemma_encoder.py b/keras_hub/src/models/t5gemma/t5gemma_encoder.py index cd42d767cd..f6d1f095e2 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_encoder.py +++ b/keras_hub/src/models/t5gemma/t5gemma_encoder.py @@ -31,13 +31,14 @@ class T5GemmaEncoderLayer(keras.layers.Layer): layer_type: str, Type of attention layer, e.g., `"sliding_attention"`. head_dim: int, The dimensionality of each attention head. attn_logit_softcapping: float, optional, The softcapping value for - attention logits. + attention logits. Defaults to `None`. sliding_window: int, optional, The window size for sliding attention. - Required if `layer_type` is `"sliding_attention"`. + Required if `layer_type` is `"sliding_attention"`. Defaults to + `None`. rope_max_wavelength: float, The maximum wavelength for Rotary Positional - Embeddings. Default is `10000.0`. + Embeddings. Defaults to `10000.0`. dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use - for model computations and weights. + for model computations and weights. Defaults to `None`. **kwargs: Additional keyword arguments passed to the parent class. """ diff --git a/keras_hub/src/models/t5gemma/t5gemma_layers.py b/keras_hub/src/models/t5gemma/t5gemma_layers.py index a282cadc3e..ca4cd0c058 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_layers.py +++ b/keras_hub/src/models/t5gemma/t5gemma_layers.py @@ -8,7 +8,7 @@ def t5gemma_kernel_initializer(initializer_range=0.01): Args: initializer_range: float, The standard deviation of the normal - distribution. Default is `0.01`. + distribution. Defaults to `0.01`. Returns: keras.initializers.RandomNormal: A Keras RandomNormal initializer. @@ -31,9 +31,9 @@ class T5GemmaMLP(keras.layers.Layer): dropout_rate: float, The dropout rate applied to the intermediate hidden states. initializer_range: float, The range for the random normal initializer - for kernel weights. Default is `0.02`. + for kernel weights. Defaults to `0.02`. dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use - for model computations and weights. + for model computations and weights. Defaults to `None`. **kwargs: Additional keyword arguments passed to the parent class. """ diff --git a/keras_hub/src/models/t5gemma/t5gemma_causal_lm.py b/keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm.py similarity index 71% rename from keras_hub/src/models/t5gemma/t5gemma_causal_lm.py rename to keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm.py index d04b269437..799080c9dc 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_causal_lm.py +++ b/keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm.py @@ -1,23 +1,23 @@ import keras from keras_hub.src.api_export import keras_hub_export -from keras_hub.src.models.causal_lm import CausalLM +from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone -from keras_hub.src.models.t5gemma.t5gemma_causal_lm_preprocessor import ( - T5GemmaCausalLMPreprocessor, +from keras_hub.src.models.t5gemma.t5gemma_seq_2_seq_lm_preprocessor import ( + T5GemmaSeq2SeqLMPreprocessor, ) from keras_hub.src.utils.tensor_utils import any_equal -@keras_hub_export("keras_hub.models.T5GemmaCausalLM") -class T5GemmaCausalLM(CausalLM): - """An end-to-end T5Gemma model for causal language modeling. +@keras_hub_export("keras_hub.models.T5GemmaSeq2SeqLM") +class T5GemmaSeq2SeqLM(Seq2SeqLM): + """An end-to-end T5Gemma model for seq2seq language modeling. - A causal language model (LM) predicts the next token based on previous - tokens. This task setup can be used to train the model unsupervised on - plain text input, or to autoregressively generate plain text similar to - the data used for training. This task can be used for pre-training or - fine-tuning a T5Gemma model, simply by calling `fit()`. + A seq2seq language model (LM) is an encoder-decoder model which is used for + conditional text generation. The encoder is given a "context" text (fed to + the encoder), and the decoder predicts the next token based on both the + encoder inputs and the previous tokens. You can finetune `T5GemmaSeq2SeqLM` + to generate text for any seq2seq task (e.g., translation or summarization). This model has a `generate()` method, which generates text based on a prompt. The generation strategy used is controlled by an additional @@ -32,26 +32,40 @@ class T5GemmaCausalLM(CausalLM): Args: backbone: A `keras_hub.models.T5GemmaBackbone` instance. - preprocessor: A `keras_hub.models.T5GemmaCausalLMPreprocessor` or + preprocessor: A `keras_hub.models.T5GemmaSeq2SeqLMPreprocessor` or `None`. If `None`, this model will not apply preprocessing, and - inputs should be preprocessed before calling the model. + inputs should be preprocessed before calling the model. Defaults + to `None`. Examples: Use `generate()` to do text generation. ```python - t5gemma_lm = keras_hub.models.T5GemmaCausalLM.from_preset( + import numpy as np + t5gemma_lm = keras_hub.models.T5GemmaSeq2SeqLM.from_preset( "t5gemma_b_b_prefixlm_it" ) - t5gemma_lm.generate("I want to say", max_length=30) + # Generate with encoder-only input. + t5gemma_lm.generate("The quick brown fox jumped.", max_length=30) - # Generate with batched prompts. - t5gemma_lm.generate(["This is a", "Where are you"], max_length=30) + # Generate with batched encoder-only inputs. + t5gemma_lm.generate( + ["The quick brown fox jumped.", "The whale."], + max_length=30 + ) + # Generate with encoder and decoder inputs. + t5gemma_lm.generate( + { + "encoder_text": "The quick brown fox jumped.", + "decoder_text": "A fast fox" + }, + max_length=30 + ) ``` Compile the `generate()` function with a custom sampler. ```python - t5gemma_lm = keras_hub.models.T5GemmaCausalLM.from_preset( + t5gemma_lm = keras_hub.models.T5GemmaSeq2SeqLM.from_preset( "t5gemma_b_b_prefixlm_it" ) t5gemma_lm.compile(sampler="top_k") @@ -63,17 +77,17 @@ class T5GemmaCausalLM(CausalLM): Use `generate()` without preprocessing. ```python - # The preprocessor is responsible for creating a dictionary of tensors. - # If you are not using a preprocessor, you must format your inputs - # yourself. + # Preprocessed inputs, with encoder inputs corresponding to + # "The quick brown fox", and the decoder inputs to "A fast fox". + # Use `"padding_mask"` to indicate values that should not be overridden. prompt = { - # Token ids for " Keras is". - "token_ids": np.array([[2, 214064, 603, 0, 0, 0, 0]] * 2), - # Use `"padding_mask"` to indicate values that should not be overridden. - "padding_mask": np.array([[1, 1, 1, 0, 0, 0, 0]] * 2), + "encoder_token_ids": np.array([[2, 10, 133, 2119, 6219, 23602, 1, 0]]), + "encoder_padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 0]]), + "decoder_token_ids": np.array([[2, 133, 1769, 1, 0, 0, 0]]), + "decoder_padding_mask": np.array([[1, 1, 1, 1, 0, 0, 0]]) } - t5gemma_lm = keras_hub.models.T5GemmaCausalLM.from_preset( + t5gemma_lm = keras_hub.models.T5GemmaSeq2SeqLM.from_preset( "t5gemma_b_b_prefixlm_it", preprocessor=None, ) @@ -82,8 +96,11 @@ class T5GemmaCausalLM(CausalLM): Call `fit()` on a single batch. ```python - features = ["The quick brown fox jumped.", "I forgot my homework."] - t5gemma_lm = keras_hub.models.T5GemmaCausalLM.from_preset( + features = { + "encoder_text": ["The quick fox jumped.", "I forgot my homework."], + "decoder_text": ["The fast hazel fox leapt.", "I forgot my assignment."] + } + t5gemma_lm = keras_hub.models.T5GemmaSeq2SeqLM.from_preset( "t5gemma_b_b_prefixlm_it" ) t5gemma_lm.fit(x=features, batch_size=2) @@ -92,14 +109,15 @@ class T5GemmaCausalLM(CausalLM): Call `fit()` without preprocessing. ```python x = { - # Token ids for " Keras is deep learning library" - "token_ids": np.array([[2, 214064, 603, 5271, 6044, 9581, 1, 0]] * 2), - "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 0]] * 2), + "encoder_token_ids": np.array([[2, 133, 2119, 1, 0]] * 2), + "encoder_padding_mask": np.array([[1, 1, 1, 1, 0]] * 2), + "decoder_token_ids": np.array([[2, 133, 1769, 1, 0]] * 2), + "decoder_padding_mask": np.array([[1, 1, 1, 1, 1]] * 2), } - y = np.array([[214064, 603, 5271, 6044, 9581, 3, 0, 0]] * 2) - sw = np.array([[1, 1, 1, 1, 1, 1, 0, 0]] * 2) + y = np.array([[133, 1769, 1, 0, 0]] * 2) + sw = np.array([[1, 1, 1, 0, 0]] * 2) - t5gemma_lm = keras_hub.models.T5GemmaCausalLM.from_preset( + t5gemma_lm = keras_hub.models.T5GemmaSeq2SeqLM.from_preset( "t5gemma_b_b_prefixlm_it", preprocessor=None, ) @@ -108,12 +126,17 @@ class T5GemmaCausalLM(CausalLM): Custom backbone and vocabulary. ```python + features = { + "encoder_text": ["The quick fox jumped.", "I forgot my homework."], + "decoder_text": ["The fast hazel fox leapt.", "I forgot my assignment."] + } tokenizer = keras_hub.models.T5GemmaTokenizer( proto="proto.spm", ) - preprocessor = keras_hub.models.T5GemmaCausalLMPreprocessor( + preprocessor = keras_hub.models.T5GemmaSeq2SeqLMPreprocessor( tokenizer=tokenizer, - sequence_length=128, + encoder_sequence_length=128, + decoder_sequence_length=128, ) backbone = keras_hub.models.T5GemmaBackbone( vocabulary_size=32000, @@ -140,7 +163,7 @@ class T5GemmaCausalLM(CausalLM): attention_bias=False, hidden_activation="gelu_approximate", ) - t5gemma_lm = keras_hub.models.T5GemmaCausalLM( + t5gemma_lm = keras_hub.models.T5GemmaSeq2SeqLM( backbone=backbone, preprocessor=preprocessor, ) @@ -149,7 +172,7 @@ class T5GemmaCausalLM(CausalLM): """ backbone_cls = T5GemmaBackbone - preprocessor_cls = T5GemmaCausalLMPreprocessor + preprocessor_cls = T5GemmaSeq2SeqLMPreprocessor def __init__(self, backbone, preprocessor=None, **kwargs): # === Layers === @@ -160,7 +183,7 @@ def __init__(self, backbone, preprocessor=None, **kwargs): # This must be "backbone.input" i.e. the full input structure, # rather than "backbone.inputs" which is the flattened list of inputs. inputs = backbone.input - sequence_output = backbone(inputs) + sequence_output = backbone(inputs)["decoder_sequence_output"] logits = backbone.decoder_token_embedding(sequence_output, reverse=True) if self.backbone.final_logit_softcapping is not None: logits = logits / self.backbone.final_logit_softcapping @@ -201,7 +224,7 @@ def call_decoder_with_cache( encoder_output, encoder_padding_mask, ): - """Forward pass of `T5GemmaCausalLM`'s decoder with cache. + """Forward pass of `T5GemmaSeq2SeqLM`'s decoder with cache. `call_decoder_with_cache` adds an additional forward pass for the model for autoregressive inference. Unlike calling the model directly, this @@ -282,12 +305,18 @@ def call_decoder_with_cache( (self_attention_cache, cross_attention_cache), ) - def _build_cache(self, token_ids, padding_mask): + def _build_cache( + self, + encoder_token_ids, + encoder_padding_mask, + decoder_token_ids, + decoder_padding_mask, + ): """Build an empty cache for use with `call_with_cache()`.""" encoder_output, encoder_padding_mask = self.call_encoder( - token_ids, padding_mask + encoder_token_ids, encoder_padding_mask ) - batch_size = keras.ops.shape(token_ids)[0] + batch_size = keras.ops.shape(decoder_token_ids)[0] num_layers = self.backbone.decoder_num_layers num_kv_heads = self.backbone.decoder_num_key_value_heads head_dim = self.backbone.decoder_head_dim @@ -295,7 +324,7 @@ def _build_cache(self, token_ids, padding_mask): batch_size, num_layers, 2, - keras.ops.shape(token_ids)[1], + keras.ops.shape(decoder_token_ids)[1], num_kv_heads, head_dim, ) @@ -304,8 +333,8 @@ def _build_cache(self, token_ids, padding_mask): ) cross_attention_cache = None _, hidden_states, cache = self.call_decoder_with_cache( - decoder_token_ids=token_ids, - decoder_padding_mask=padding_mask, + decoder_token_ids=decoder_token_ids, + decoder_padding_mask=decoder_padding_mask, cache=(self_attention_cache, cross_attention_cache), cache_update_index=0, encoder_output=encoder_output, @@ -320,24 +349,32 @@ def generate_step(self, inputs, stop_token_ids=None): This function represents the inner, XLA-compilable, generation function for a single batch of inputs. Inputs should have the same structure as model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`. + `"encoder_token_ids"`, `"encoder_padding_mask"`, `"decoder_token_ids"` + and `"decoder_padding_mask"`. Args: - inputs: A dictionary with two keys `"token_ids"` and - `"padding_mask"` and batched tensor values. + inputs: A dictionary with four keys - `"encoder_token_ids"`, + `"encoder_padding_mask"`, `"decoder_token_ids"` and + `"decoder_padding_mask"`, with batched tensor values. stop_token_ids: Tuple of id's of end token's to stop on. If all sequences have produced a new stop token, generation will stop. """ - token_ids = inputs["token_ids"] - padding_mask = inputs["padding_mask"] + encoder_token_ids = inputs["encoder_token_ids"] + encoder_padding_mask = inputs["encoder_padding_mask"] + decoder_token_ids = inputs["decoder_token_ids"] + decoder_padding_mask = inputs["decoder_padding_mask"] # Create and seed cache with a single forward pass. hidden_states, cache, extra_cache_info = self._build_cache( - token_ids=token_ids, padding_mask=padding_mask + encoder_token_ids=encoder_token_ids, + encoder_padding_mask=encoder_padding_mask, + decoder_token_ids=decoder_token_ids, + decoder_padding_mask=decoder_padding_mask, ) encoder_output, encoder_padding_mask = extra_cache_info # Compute the lengths of all user inputted tokens ids. row_lengths = keras.ops.sum( - keras.ops.cast(padding_mask, "int32"), axis=-1 + keras.ops.cast(decoder_padding_mask, "int32"), axis=-1 ) # Start at the first index that has no user inputted id. index = keras.ops.min(row_lengths) @@ -363,12 +400,12 @@ def next(prompt, cache, index): ) return keras.ops.squeeze(logits, axis=1), None, updated_cache - token_ids = self.sampler( + decoder_token_ids = self.sampler( next=next, - prompt=token_ids, + prompt=decoder_token_ids, cache=cache, index=index, - mask=padding_mask, + mask=decoder_padding_mask, stop_token_ids=stop_token_ids, hidden_states=hidden_states, model=self, @@ -377,11 +414,11 @@ def next(prompt, cache, index): # Compute an output padding mask with the token ids we updated. if stop_token_ids is not None: # Build a mask of `stop_token_ids` locations not in the original - # prompt (not in locations where `padding_mask` is True). + # prompt (not in locations where `decoder_padding_mask` is True). end_locations = any_equal( - token_ids, + decoder_token_ids, stop_token_ids, - keras.ops.logical_not(padding_mask), + keras.ops.logical_not(decoder_padding_mask), ) # Use cumsum to get ones in all locations after end_locations. end_locations = keras.ops.cast(end_locations, "int32") @@ -390,14 +427,16 @@ def next(prompt, cache, index): ) overflow = cumsum - end_locations # Our padding mask is the inverse of these overflow locations. - padding_mask = keras.ops.logical_not( + decoder_padding_mask = keras.ops.logical_not( keras.ops.cast(overflow, "bool") ) else: # Without early stopping, all locations will have been updated. - padding_mask = keras.ops.ones_like(token_ids, dtype="bool") + decoder_padding_mask = keras.ops.ones_like( + decoder_token_ids, dtype="bool" + ) return { - "token_ids": token_ids, - "padding_mask": padding_mask, + "decoder_token_ids": decoder_token_ids, + "decoder_padding_mask": decoder_padding_mask, } diff --git a/keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_preprocessor.py b/keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_preprocessor.py new file mode 100644 index 0000000000..f50fe96f61 --- /dev/null +++ b/keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_preprocessor.py @@ -0,0 +1,88 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor +from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone +from keras_hub.src.models.t5gemma.t5gemma_tokenizer import T5GemmaTokenizer + + +@keras_hub_export("keras_hub.models.T5GemmaSeq2SeqLMPreprocessor") +class T5GemmaSeq2SeqLMPreprocessor(Seq2SeqLMPreprocessor): + """T5Gemma Seq2Seq LM preprocessor. + + This preprocessing layer is meant for use with + `keras_hub.models.T5GemmaSeq2SeqLM`. By default, it will take in batches of + strings, and return outputs in a `(x, y, sample_weight)` format, where the + `y` label is the next token id in the `x` sequence. + + For use with generation, the layer also exposes two methods + `generate_preprocess()` and `generate_postprocess()`. When this preprocessor + is attached to a `keras_hub.models.T5GemmaSeq2SeqLM` instance, these methods + will be called implicitly in `generate()`. They can also be called + standalone (e.g. to precompute preprocessing inputs for generation in a + separate process). + + Args: + tokenizer: A `keras_hub.models.T5GemmaTokenizer` instance. + encoder_sequence_length: The length of the packed encoder inputs. + decoder_sequence_length: The length of the packed decoder inputs. + add_start_token: If `True`, the preprocessor will prepend the tokenizer + start token to each input sequence. Defaults to `True`. + add_end_token: If `True`, the preprocessor will append the tokenizer + end token to each input sequence. Defaults to `False`. + + Call arguments: + x: A string, `tf.Tensor` or list of python strings. Can also be a + dictionary with `encoder_text` and `decoder_text` keys. + y: Label data. Should always be `None` as the layer generates labels. + sample_weight: Label weights. Should always be `None` as the layer + generates label weights. + encoder_sequence_length: Pass to override the configured + `encoder_sequence_length` of the layer. + decoder_sequence_length: Pass to override the configured + `decoder_sequence_length` of the layer. + + Examples: + ```python + import tensorflow as tf + import numpy as np + + # Load the preprocessor from a preset. + preprocessor = keras_hub.models.T5GemmaSeq2SeqLMPreprocessor.from_preset( + "t5gemma_b_b_prefixlm_it" + ) + + # Tokenize and pack a single sentence. + sentence = tf.constant("The quick brown fox jumped.") + preprocessor(sentence) + + # Tokenize a batch of sentences. + preprocessor(["The quick brown fox jumped.", "Call me Ishmael."]) + # Tokenize a dictionary with separate encoder and decoder inputs. + preprocessor({ + "encoder_text": "The quick brown fox jumped.", + "decoder_text": "The fast fox." + }) + + # Apply tokenization to a `tf.data.Dataset`. + encoder_features = tf.constant(["The quick brown fox.", "Call me Ishmael."]) + decoder_features = tf.constant(["The fast fox.", "I am Ishmael."]) + ds = tf.data.Dataset.from_tensor_slices( + {"encoder_text": encoder_features, "decoder_text": decoder_features} + ) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + + # Prepare tokens for generation. + preprocessor.generate_preprocess({ + "encoder_text": "The quick brown fox jumped.", + "decoder_text": "The fast fox." + }) + + # Map generation outputs back to strings. + preprocessor.generate_postprocess({ + 'decoder_token_ids': np.array([[2, 714, 4320, 8426, 25341, 1, 0, 0]]), + 'decoder_padding_mask': np.array([[1, 1, 1, 1, 1, 1, 0, 0]]), + }) + ``` + """ + + backbone_cls = T5GemmaBackbone + tokenizer_cls = T5GemmaTokenizer diff --git a/keras_hub/src/models/t5gemma/t5gemma_causal_lm_test.py b/keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_test.py similarity index 69% rename from keras_hub/src/models/t5gemma/t5gemma_causal_lm_test.py rename to keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_test.py index 86f686a138..0a4cb0ef4e 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_causal_lm_test.py +++ b/keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_test.py @@ -5,24 +5,25 @@ import pytest from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone -from keras_hub.src.models.t5gemma.t5gemma_causal_lm import T5GemmaCausalLM -from keras_hub.src.models.t5gemma.t5gemma_causal_lm_preprocessor import ( - T5GemmaCausalLMPreprocessor, +from keras_hub.src.models.t5gemma.t5gemma_seq_2_seq_lm import T5GemmaSeq2SeqLM +from keras_hub.src.models.t5gemma.t5gemma_seq_2_seq_lm_preprocessor import ( + T5GemmaSeq2SeqLMPreprocessor, ) from keras_hub.src.models.t5gemma.t5gemma_tokenizer import T5GemmaTokenizer from keras_hub.src.tests.test_case import TestCase -class T5GemmaCausalLMTest(TestCase): +class T5GemmaSeq2SeqLMTest(TestCase): def setUp(self): self.tokenizer = T5GemmaTokenizer( proto=os.path.join( self.get_test_data_dir(), "gemma_test_vocab.spm" ), ) - self.preprocessor = T5GemmaCausalLMPreprocessor( - self.tokenizer, - sequence_length=8, + self.preprocessor = T5GemmaSeq2SeqLMPreprocessor( + tokenizer=self.tokenizer, + encoder_sequence_length=8, + decoder_sequence_length=10, ) self.backbone = T5GemmaBackbone( vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(), @@ -55,49 +56,51 @@ def setUp(self): "preprocessor": self.preprocessor, "backbone": self.backbone, } - self.train_data = (["the quick brown fox", "the earth is round"],) + self.train_data = ( + { + "encoder_text": ["the quick brown fox", "the earth is round"], + "decoder_text": ["the quick brown fox", "the earth is round"], + }, + ) self.input_data = self.preprocessor(*self.train_data)[0] def test_causal_lm_basics(self): self.run_task_test( - cls=T5GemmaCausalLM, + cls=T5GemmaSeq2SeqLM, init_kwargs=self.init_kwargs, train_data=self.train_data, expected_output_shape=( 2, - 8, + 10, self.preprocessor.tokenizer.vocabulary_size(), ), ) def test_generate(self): - causal_lm = T5GemmaCausalLM(**self.init_kwargs) - # String input. - prompt = "the quick brown fox" - output = causal_lm.generate(prompt) - self.assertTrue(prompt in output) + causal_lm = T5GemmaSeq2SeqLM(**self.init_kwargs) + # String inputs. + inputs = { + "encoder_text": "the quick brown fox", + "decoder_text": "the quick", + } + output = causal_lm.generate(inputs) + self.assertTrue("the quick" in output) # Int tensor input. - prompt_ids = self.preprocessor.generate_preprocess([prompt]) + prompt_ids = self.preprocessor.generate_preprocess(inputs) causal_lm.preprocessor = None outputs = causal_lm.generate(prompt_ids, stop_token_ids=None) # Assert prompt is in output in token id space. self.assertAllEqual( - outputs["token_ids"][:, :5], - prompt_ids["token_ids"][:, :5], + outputs["decoder_token_ids"][:, :3], + prompt_ids["decoder_token_ids"][:, :3], ) self.assertAllEqual( - outputs["padding_mask"][:, :5], - prompt_ids["padding_mask"][:, :5], + outputs["decoder_padding_mask"][:, :3], + prompt_ids["decoder_padding_mask"][:, :3], ) - def test_generate_strip_prompt(self): - causal_lm = T5GemmaCausalLM(**self.init_kwargs) - prompt = "the quick brown fox" - output = causal_lm.generate(prompt, strip_prompt=True) - self.assertFalse(output.startswith(prompt)) - def test_early_stopping(self): - causal_lm = T5GemmaCausalLM(**self.init_kwargs) + causal_lm = T5GemmaSeq2SeqLM(**self.init_kwargs) call_decoder_with_cache = causal_lm.call_decoder_with_cache def wrapper(*args, **kwargs): @@ -122,13 +125,19 @@ def wrapper(*args, **kwargs): ) with patch.object(causal_lm, "call_decoder_with_cache", wraps=wrapper): - prompt = ["the quick brown fox", "the earth is round"] - output = causal_lm.generate(prompt) + inputs = { + "encoder_text": [ + "the quick brown fox", + "the earth is round", + ], + "decoder_text": ["the quick", "the earth"], + } + output = causal_lm.generate(inputs) # We should immediately abort and output the prompt. - self.assertEqual(prompt, output) + self.assertEqual(inputs["decoder_text"], output) def test_generate_compilation(self): - causal_lm = T5GemmaCausalLM(**self.init_kwargs) + causal_lm = T5GemmaSeq2SeqLM(**self.init_kwargs) # Assert we do not recompile with successive calls. causal_lm.generate("the quick brown fox") first_fn = causal_lm.generate_function @@ -142,16 +151,16 @@ def test_generate_compilation(self): @pytest.mark.large def test_saved_model(self): self.run_model_saving_test( - cls=T5GemmaCausalLM, + cls=T5GemmaSeq2SeqLM, init_kwargs=self.init_kwargs, input_data=self.input_data, ) @pytest.mark.extra_large def test_all_presets(self): - for preset in T5GemmaCausalLM.presets: + for preset in T5GemmaSeq2SeqLM.presets: self.run_preset_test( - cls=T5GemmaCausalLM, + cls=T5GemmaSeq2SeqLM, preset=preset, input_data=self.input_data, ) diff --git a/keras_hub/src/models/t5gemma/t5gemma_tokenizer.py b/keras_hub/src/models/t5gemma/t5gemma_tokenizer.py index f63f617a91..a3a6d27365 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_tokenizer.py +++ b/keras_hub/src/models/t5gemma/t5gemma_tokenizer.py @@ -35,6 +35,10 @@ class T5GemmaTokenizer(SentencePieceTokenizer): Examples: ```python + import io + import tensorflow as tf + import sentencepiece + # Unbatched input. tokenizer = keras_hub.models.T5GemmaTokenizer.from_preset( "t5gemma_b_b_prefixlm_it" diff --git a/tools/checkpoint_conversion/convert_t5gemma_checkpoints.py b/tools/checkpoint_conversion/convert_t5gemma_checkpoints.py index 4f9861fc1c..047814bd20 100644 --- a/tools/checkpoint_conversion/convert_t5gemma_checkpoints.py +++ b/tools/checkpoint_conversion/convert_t5gemma_checkpoints.py @@ -29,9 +29,9 @@ import torch import transformers -from keras_hub.src.models.t5gemma.t5gemma_causal_lm import T5GemmaCausalLM -from keras_hub.src.models.t5gemma.t5gemma_causal_lm_preprocessor import ( - T5GemmaCausalLMPreprocessor, +from keras_hub.src.models.t5gemma.t5gemma_seq_2_seq_lm import T5GemmaSeq2SeqLM +from keras_hub.src.models.t5gemma.t5gemma_seq_2_seq_lm_preprocessor import ( + T5GemmaSeq2SeqLMPreprocessor, ) from keras_hub.src.models.t5gemma.t5gemma_tokenizer import T5GemmaTokenizer @@ -93,7 +93,7 @@ def convert_model(hf_model, preprocessor): decoder_config.hidden_activation = "gelu_approximate" if encoder_config.hidden_activation == "gelu_pytorch_tanh": encoder_config.hidden_activation = "gelu_approximate" - keras_backbone = T5GemmaCausalLM.backbone_cls( + keras_backbone = T5GemmaSeq2SeqLM.backbone_cls( vocabulary_size=decoder_config.vocab_size, encoder_hidden_dim=encoder_config.hidden_size, encoder_intermediate_dim=encoder_config.intermediate_size, @@ -125,7 +125,7 @@ def convert_model(hf_model, preprocessor): final_logit_softcapping=decoder_config.final_logit_softcapping, rope_max_wavelength=decoder_config.rope_theta, ) - keras_model = T5GemmaCausalLM( + keras_model = T5GemmaSeq2SeqLM( backbone=keras_backbone, preprocessor=preprocessor ) print("āœ… Keras model instantiated.") @@ -352,8 +352,10 @@ def validate_output(hf_model, keras_model, hf_tokenizer, keras_tokenizer): input_ids_np = np.ones((1, 10), dtype="int32") attention_mask_np = np.ones((1, 10), dtype="int32") keras_inputs = { - "token_ids": input_ids_np, - "padding_mask": attention_mask_np, + "encoder_token_ids": input_ids_np, + "encoder_padding_mask": attention_mask_np, + "decoder_token_ids": input_ids_np, + "decoder_padding_mask": attention_mask_np, } hf_input_ids = torch.from_numpy(input_ids_np) hf_attention_mask = torch.from_numpy(attention_mask_np) @@ -365,7 +367,8 @@ def validate_output(hf_model, keras_model, hf_tokenizer, keras_tokenizer): ) hf_final_hidden_states = hf_outputs.last_hidden_state.detach().numpy() print("\nšŸ”Ž Validating final hidden states...") - keras_final_hidden_states = keras_model.backbone.predict(keras_inputs) + keras_output = keras_model.backbone.predict(keras_inputs) + keras_final_hidden_states = keras_output["decoder_sequence_output"] final_difference = np.mean( np.abs(hf_final_hidden_states - keras_final_hidden_states) ) @@ -390,7 +393,7 @@ def main(_): keras_tokenizer = convert_tokenizer(hf_model_dir) - keras_preprocessor = T5GemmaCausalLMPreprocessor( + keras_preprocessor = T5GemmaSeq2SeqLMPreprocessor( tokenizer=keras_tokenizer, ) keras_model = convert_model(hf_model, keras_preprocessor) From 6b320fa17a1760168e6a4e62b35ab60a984ddf74 Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Sun, 17 Aug 2025 22:46:50 +0400 Subject: [PATCH 15/19] feat: Support direct loading of Hugging Face checkpoints --- .../src/models/t5gemma/t5gemma_decoder.py | 16 +- .../src/models/t5gemma/t5gemma_encoder.py | 2 +- .../src/models/t5gemma/t5gemma_layers.py | 2 +- .../src/utils/transformers/convert_t5gemma.py | 229 ++++++++++++++++++ .../transformers/convert_t5gemma_test.py | 31 +++ .../src/utils/transformers/preset_loader.py | 3 + 6 files changed, 274 insertions(+), 9 deletions(-) create mode 100644 keras_hub/src/utils/transformers/convert_t5gemma.py create mode 100644 keras_hub/src/utils/transformers/convert_t5gemma_test.py diff --git a/keras_hub/src/models/t5gemma/t5gemma_decoder.py b/keras_hub/src/models/t5gemma/t5gemma_decoder.py index 905d291abb..fb9fb6950d 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_decoder.py +++ b/keras_hub/src/models/t5gemma/t5gemma_decoder.py @@ -114,12 +114,12 @@ def __init__( self.pre_self_attn_layernorm = RMSNormalization( epsilon=rms_norm_eps, dtype=self.dtype_policy, - name="pre_self_attention_layernorm", + name="decoder_pre_self_attention_layernorm", ) self.post_self_attn_layernorm = RMSNormalization( epsilon=rms_norm_eps, dtype=self.dtype_policy, - name="post_self_attention_layernorm", + name="decoder_post_self_attention_layernorm", ) # Cross-attention. @@ -141,12 +141,12 @@ def __init__( self.pre_cross_attn_layernorm = RMSNormalization( epsilon=rms_norm_eps, dtype=self.dtype_policy, - name="pre_cross_attention_layernorm", + name="decoder_pre_cross_attention_layernorm", ) self.post_cross_attn_layernorm = RMSNormalization( epsilon=rms_norm_eps, dtype=self.dtype_policy, - name="post_cross_attention_layernorm", + name="decoder_post_cross_attention_layernorm", ) # MLP. @@ -162,16 +162,18 @@ def __init__( self.pre_feedforward_layernorm = RMSNormalization( epsilon=rms_norm_eps, dtype=self.dtype_policy, - name="pre_feedforward_layernorm", + name="decoder_pre_feedforward_layernorm", ) self.post_feedforward_layernorm = RMSNormalization( epsilon=rms_norm_eps, dtype=self.dtype_policy, - name="post_feedforward_layernorm", + name="decoder_post_feedforward_layernorm", ) self.dropout = keras.layers.Dropout( - dropout_rate, dtype=self.dtype_policy, name="dropout" + dropout_rate, + dtype=self.dtype_policy, + name="decoder_residual_dropout", ) def build(self, input_shape): diff --git a/keras_hub/src/models/t5gemma/t5gemma_encoder.py b/keras_hub/src/models/t5gemma/t5gemma_encoder.py index f6d1f095e2..d17a3e880c 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_encoder.py +++ b/keras_hub/src/models/t5gemma/t5gemma_encoder.py @@ -136,7 +136,7 @@ def __init__( self.dropout = keras.layers.Dropout( dropout_rate, dtype=self.dtype_policy, - name="dropout", + name="residual_dropout", ) def build(self, input_shape): diff --git a/keras_hub/src/models/t5gemma/t5gemma_layers.py b/keras_hub/src/models/t5gemma/t5gemma_layers.py index ca4cd0c058..1a9d18f186 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_layers.py +++ b/keras_hub/src/models/t5gemma/t5gemma_layers.py @@ -84,7 +84,7 @@ def __init__( self.dropout = keras.layers.Dropout( self.dropout_rate, dtype=self.dtype_policy, - name="dropout", + name="mlp_dropout", ) def build(self, input_shape): diff --git a/keras_hub/src/utils/transformers/convert_t5gemma.py b/keras_hub/src/utils/transformers/convert_t5gemma.py new file mode 100644 index 0000000000..f89242d944 --- /dev/null +++ b/keras_hub/src/utils/transformers/convert_t5gemma.py @@ -0,0 +1,229 @@ +from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone +from keras_hub.src.utils.preset_utils import get_file + +backbone_cls = T5GemmaBackbone + + +def convert_backbone_config(transformers_config): + """Convert a Hugging Face T5Gemma config to a KerasHub backbone config.""" + encoder_config = transformers_config["encoder"] + decoder_config = transformers_config["decoder"] + + if decoder_config.get("hidden_activation") == "gelu_pytorch_tanh": + decoder_config["hidden_activation"] = "gelu_approximate" + if encoder_config.get("hidden_activation") == "gelu_pytorch_tanh": + encoder_config["hidden_activation"] = "gelu_approximate" + + backbone_config = { + "vocabulary_size": decoder_config["vocab_size"], + "encoder_hidden_dim": encoder_config["hidden_size"], + "encoder_intermediate_dim": encoder_config["intermediate_size"], + "encoder_num_layers": encoder_config["num_hidden_layers"], + "encoder_num_attention_heads": encoder_config["num_attention_heads"], + "encoder_num_key_value_heads": encoder_config["num_key_value_heads"], + "encoder_head_dim": encoder_config["head_dim"], + "encoder_layer_types": encoder_config["layer_types"], + "decoder_hidden_dim": decoder_config["hidden_size"], + "decoder_intermediate_dim": decoder_config["intermediate_size"], + "decoder_num_layers": decoder_config["num_hidden_layers"], + "decoder_num_attention_heads": decoder_config["num_attention_heads"], + "decoder_num_key_value_heads": decoder_config["num_key_value_heads"], + "decoder_head_dim": decoder_config["head_dim"], + "decoder_layer_types": decoder_config["layer_types"], + "dropout_rate": decoder_config["dropout_rate"], + "rms_norm_eps": decoder_config["rms_norm_eps"], + "query_pre_attn_scalar": decoder_config["query_pre_attn_scalar"], + "tie_word_embeddings": transformers_config.get( + "tie_word_embeddings", True + ), + "attention_bias": decoder_config["attention_bias"], + "hidden_activation": decoder_config["hidden_activation"], + "initializer_range": decoder_config["initializer_range"], + "attention_dropout": decoder_config["attention_dropout"], + "sliding_window": decoder_config["sliding_window"], + "cross_attention_hidden_size": encoder_config["hidden_size"], + "attn_logit_softcapping": decoder_config["attn_logit_softcapping"], + "final_logit_softcapping": decoder_config["final_logit_softcapping"], + "rope_max_wavelength": decoder_config["rope_theta"], + } + return backbone_config + + +def convert_weights(backbone, loader, transformers_config): + """Convert T5Gemma from Hugging Face to KerasHub.""" + # Token embeddings. + loader.port_weight( + keras_variable=backbone.token_embedding.embeddings, + hf_weight_key="encoder.embed_tokens.weight", + ) + loader.port_weight( + keras_variable=backbone.decoder_token_embedding.embeddings, + hf_weight_key="decoder.embed_tokens.weight", + ) + + # Encoder. + loader.port_weight( + keras_variable=backbone.encoder_norm.scale, + hf_weight_key="encoder.norm.weight", + ) + for i in range(backbone.encoder_num_layers): + layer = backbone.get_layer(f"encoder_layer_{i}") + hf_prefix = f"encoder.layers.{i}" + + # Self-attention. + loader.port_weight( + keras_variable=layer.self_attn.query_dense.kernel, + hf_weight_key=f"{hf_prefix}.self_attn.q_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + loader.port_weight( + keras_variable=layer.self_attn.key_dense.kernel, + hf_weight_key=f"{hf_prefix}.self_attn.k_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + loader.port_weight( + keras_variable=layer.self_attn.value_dense.kernel, + hf_weight_key=f"{hf_prefix}.self_attn.v_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + loader.port_weight( + keras_variable=layer.self_attn.output_dense.kernel, + hf_weight_key=f"{hf_prefix}.self_attn.o_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + + # MLP. + loader.port_weight( + keras_variable=layer.mlp.gate_proj.kernel, + hf_weight_key=f"{hf_prefix}.mlp.gate_proj.weight", + hook_fn=lambda w, s: w.T, + ) + loader.port_weight( + keras_variable=layer.mlp.up_proj.kernel, + hf_weight_key=f"{hf_prefix}.mlp.up_proj.weight", + hook_fn=lambda w, s: w.T, + ) + loader.port_weight( + keras_variable=layer.mlp.down_proj.kernel, + hf_weight_key=f"{hf_prefix}.mlp.down_proj.weight", + hook_fn=lambda w, s: w.T, + ) + + # Layer norm. + loader.port_weight( + keras_variable=layer.pre_self_attn_layernorm.scale, + hf_weight_key=f"{hf_prefix}.pre_self_attn_layernorm.weight", + ) + loader.port_weight( + keras_variable=layer.post_self_attn_layernorm.scale, + hf_weight_key=f"{hf_prefix}.post_self_attn_layernorm.weight", + ) + loader.port_weight( + keras_variable=layer.pre_feedforward_layernorm.scale, + hf_weight_key=f"{hf_prefix}.pre_feedforward_layernorm.weight", + ) + loader.port_weight( + keras_variable=layer.post_feedforward_layernorm.scale, + hf_weight_key=f"{hf_prefix}.post_feedforward_layernorm.weight", + ) + + # Decoder. + loader.port_weight( + keras_variable=backbone.decoder_norm.scale, + hf_weight_key="decoder.norm.weight", + ) + for i in range(backbone.decoder_num_layers): + layer = backbone.get_layer(f"decoder_layer_{i}") + hf_prefix = f"decoder.layers.{i}" + + # Self-attention. + loader.port_weight( + keras_variable=layer.self_attn.query_dense.kernel, + hf_weight_key=f"{hf_prefix}.self_attn.q_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + loader.port_weight( + keras_variable=layer.self_attn.key_dense.kernel, + hf_weight_key=f"{hf_prefix}.self_attn.k_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + loader.port_weight( + keras_variable=layer.self_attn.value_dense.kernel, + hf_weight_key=f"{hf_prefix}.self_attn.v_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + loader.port_weight( + keras_variable=layer.self_attn.output_dense.kernel, + hf_weight_key=f"{hf_prefix}.self_attn.o_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + + # Cross-attention. + loader.port_weight( + keras_variable=layer.cross_attn.query_dense.kernel, + hf_weight_key=f"{hf_prefix}.cross_attn.q_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + loader.port_weight( + keras_variable=layer.cross_attn.key_dense.kernel, + hf_weight_key=f"{hf_prefix}.cross_attn.k_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + loader.port_weight( + keras_variable=layer.cross_attn.value_dense.kernel, + hf_weight_key=f"{hf_prefix}.cross_attn.v_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + loader.port_weight( + keras_variable=layer.cross_attn.output_dense.kernel, + hf_weight_key=f"{hf_prefix}.cross_attn.o_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + + # MLP. + loader.port_weight( + keras_variable=layer.mlp.gate_proj.kernel, + hf_weight_key=f"{hf_prefix}.mlp.gate_proj.weight", + hook_fn=lambda w, s: w.T, + ) + loader.port_weight( + keras_variable=layer.mlp.up_proj.kernel, + hf_weight_key=f"{hf_prefix}.mlp.up_proj.weight", + hook_fn=lambda w, s: w.T, + ) + loader.port_weight( + keras_variable=layer.mlp.down_proj.kernel, + hf_weight_key=f"{hf_prefix}.mlp.down_proj.weight", + hook_fn=lambda w, s: w.T, + ) + + # Layer norm. + loader.port_weight( + keras_variable=layer.pre_self_attn_layernorm.scale, + hf_weight_key=f"{hf_prefix}.pre_self_attn_layernorm.weight", + ) + loader.port_weight( + keras_variable=layer.post_self_attn_layernorm.scale, + hf_weight_key=f"{hf_prefix}.post_self_attn_layernorm.weight", + ) + loader.port_weight( + keras_variable=layer.pre_cross_attn_layernorm.scale, + hf_weight_key=f"{hf_prefix}.pre_cross_attn_layernorm.weight", + ) + loader.port_weight( + keras_variable=layer.post_cross_attn_layernorm.scale, + hf_weight_key=f"{hf_prefix}.post_cross_attn_layernorm.weight", + ) + loader.port_weight( + keras_variable=layer.pre_feedforward_layernorm.scale, + hf_weight_key=f"{hf_prefix}.pre_feedforward_layernorm.weight", + ) + loader.port_weight( + keras_variable=layer.post_feedforward_layernorm.scale, + hf_weight_key=f"{hf_prefix}.post_feedforward_layernorm.weight", + ) + + +def convert_tokenizer(cls, preset, **kwargs): + """Convert a T5Gemma tokenizer.""" + return cls(get_file(preset, "tokenizer.model"), **kwargs) diff --git a/keras_hub/src/utils/transformers/convert_t5gemma_test.py b/keras_hub/src/utils/transformers/convert_t5gemma_test.py new file mode 100644 index 0000000000..dfa8f6f0ad --- /dev/null +++ b/keras_hub/src/utils/transformers/convert_t5gemma_test.py @@ -0,0 +1,31 @@ +import pytest + +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM +from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone +from keras_hub.src.models.t5gemma.t5gemma_seq_2_seq_lm import T5GemmaSeq2SeqLM +from keras_hub.src.tests.test_case import TestCase + + +class TestTask(TestCase): + @pytest.mark.large + def test_convert_tiny_preset(self): + model = T5GemmaSeq2SeqLM.from_preset( + "hf://google/t5gemma-b-b-prefixlm-it" + ) + prompt = "What is the capital of France?" + model.generate([prompt], max_length=15) + + @pytest.mark.large + def test_class_detection(self): + preset_name = "hf://google/t5gemma-b-b-prefixlm-it" + model = Seq2SeqLM.from_preset( + preset_name, + load_weights=False, + ) + self.assertIsInstance(model, T5GemmaSeq2SeqLM) + model = Backbone.from_preset( + preset_name, + load_weights=False, + ) + self.assertIsInstance(model, T5GemmaBackbone) diff --git a/keras_hub/src/utils/transformers/preset_loader.py b/keras_hub/src/utils/transformers/preset_loader.py index 4accea67a1..bee392289f 100644 --- a/keras_hub/src/utils/transformers/preset_loader.py +++ b/keras_hub/src/utils/transformers/preset_loader.py @@ -18,6 +18,7 @@ from keras_hub.src.utils.transformers import convert_qwen from keras_hub.src.utils.transformers import convert_qwen3 from keras_hub.src.utils.transformers import convert_qwen_moe +from keras_hub.src.utils.transformers import convert_t5gemma from keras_hub.src.utils.transformers import convert_vit from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader @@ -59,6 +60,8 @@ def __init__(self, preset, config): self.converter = convert_qwen_moe elif model_type == "qwen3": self.converter = convert_qwen3 + elif model_type == "t5gemma": + self.converter = convert_t5gemma else: raise ValueError( "KerasHub has no converter for huggingface/transformers models " From 26db4d1748f781d4684933dc7ef97f3d27de7b53 Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Thu, 21 Aug 2025 14:41:05 +0400 Subject: [PATCH 16/19] =?UTF-8?q?=E2=9C=85=20Yayy:=20Generate=20outputs=20?= =?UTF-8?q?identical,=20hidden=20states=20match=20within=201e-3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../t5gemma_seq_2_seq_lm_preprocessor.py | 137 ++++++- .../convert_t5gemma_checkpoints.py | 334 ++++++++++-------- 2 files changed, 329 insertions(+), 142 deletions(-) diff --git a/keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_preprocessor.py b/keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_preprocessor.py index f50fe96f61..9ce57a7d60 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_preprocessor.py +++ b/keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_preprocessor.py @@ -1,7 +1,15 @@ +import keras + from keras_hub.src.api_export import keras_hub_export from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone from keras_hub.src.models.t5gemma.t5gemma_tokenizer import T5GemmaTokenizer +from keras_hub.src.utils.tensor_utils import preprocessing_function + +try: + import tensorflow as tf +except ImportError: + tf = None @keras_hub_export("keras_hub.models.T5GemmaSeq2SeqLMPreprocessor") @@ -24,10 +32,12 @@ class T5GemmaSeq2SeqLMPreprocessor(Seq2SeqLMPreprocessor): tokenizer: A `keras_hub.models.T5GemmaTokenizer` instance. encoder_sequence_length: The length of the packed encoder inputs. decoder_sequence_length: The length of the packed decoder inputs. - add_start_token: If `True`, the preprocessor will prepend the tokenizer - start token to each input sequence. Defaults to `True`. - add_end_token: If `True`, the preprocessor will append the tokenizer - end token to each input sequence. Defaults to `False`. + add_start_token: If `True`, the preprocessor will prepend the + tokenizer start token to each input sequence. For T5Gemma models, + this should be `False`. Defaults to `False`. + add_end_token: If `True`, the preprocessor will append the tokenizer end + token to each input sequence. For T5Gemma models, this should be + `True`. Defaults to `True`. Call arguments: x: A string, `tf.Tensor` or list of python strings. Can also be a @@ -86,3 +96,122 @@ class T5GemmaSeq2SeqLMPreprocessor(Seq2SeqLMPreprocessor): backbone_cls = T5GemmaBackbone tokenizer_cls = T5GemmaTokenizer + + def __init__( + self, + tokenizer, + encoder_sequence_length=512, + decoder_sequence_length=512, + add_start_token=False, + add_end_token=True, + **kwargs, + ): + # Do not pass `add_start_token` and `add_end_token` to the base class. + super().__init__( + tokenizer=tokenizer, + encoder_sequence_length=encoder_sequence_length, + decoder_sequence_length=decoder_sequence_length, + **kwargs, + ) + # Store them directly on the subclass instance. + self.add_start_token = add_start_token + self.add_end_token = add_end_token + + @preprocessing_function + def call( + self, + x, + y=None, + sample_weight=None, + *, + encoder_sequence_length=None, + decoder_sequence_length=None, + sequence_length=None, + ): + if encoder_sequence_length is None: + encoder_sequence_length = self.encoder_sequence_length + decoder_sequence_length = decoder_sequence_length or sequence_length + if decoder_sequence_length is None: + decoder_sequence_length = self.decoder_sequence_length + + encoder_inputs = self.tokenizer(x["encoder_text"]) + encoder_token_ids, encoder_padding_mask = self.encoder_packer( + encoder_inputs, + sequence_length=encoder_sequence_length, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + decoder_inputs = self.tokenizer(x["decoder_text"]) + decoder_token_ids, decoder_padding_mask = self.decoder_packer( + decoder_inputs, + sequence_length=decoder_sequence_length + 1, + add_start_value=True, + add_end_value=self.add_end_token, + ) + x = { + "encoder_token_ids": encoder_token_ids, + "encoder_padding_mask": encoder_padding_mask, + "decoder_token_ids": decoder_token_ids[..., :-1], + "decoder_padding_mask": decoder_padding_mask[..., :-1], + } + y = decoder_token_ids[..., 1:] + sample_weight = decoder_padding_mask[..., 1:] + return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) + + @preprocessing_function + def generate_preprocess( + self, + x, + *, + encoder_sequence_length=None, + decoder_sequence_length=None, + sequence_length=None, + ): + if not self.built: + self.build(None) + + if isinstance(x, dict): + encoder_text = x["encoder_text"] + decoder_text = x["decoder_text"] + else: + encoder_text = x + decoder_text = tf.fill((tf.shape(encoder_text)[0],), "") + + if encoder_sequence_length is None: + encoder_sequence_length = self.encoder_sequence_length + decoder_sequence_length = decoder_sequence_length or sequence_length + if decoder_sequence_length is None: + decoder_sequence_length = self.decoder_sequence_length + + encoder_token_ids = self.tokenizer(encoder_text) + encoder_token_ids, encoder_padding_mask = self.encoder_packer( + encoder_token_ids, + sequence_length=None, + add_start_value=self.add_start_token, + add_end_value=False, + ) + + decoder_token_ids = self.tokenizer(decoder_text) + decoder_token_ids, decoder_padding_mask = self.decoder_packer( + decoder_token_ids, + sequence_length=decoder_sequence_length, + add_start_value=True, + add_end_value=False, + ) + + return { + "encoder_token_ids": encoder_token_ids, + "encoder_padding_mask": encoder_padding_mask, + "decoder_token_ids": decoder_token_ids, + "decoder_padding_mask": decoder_padding_mask, + } + + def get_config(self): + config = super().get_config() + config.update( + { + "add_start_token": self.add_start_token, + "add_end_token": self.add_end_token, + } + ) + return config diff --git a/tools/checkpoint_conversion/convert_t5gemma_checkpoints.py b/tools/checkpoint_conversion/convert_t5gemma_checkpoints.py index 047814bd20..8213bd3dba 100644 --- a/tools/checkpoint_conversion/convert_t5gemma_checkpoints.py +++ b/tools/checkpoint_conversion/convert_t5gemma_checkpoints.py @@ -1,39 +1,30 @@ -""" -T5Gemma weight conversion script. - -This script converts checkpoints from a Hugging Face T5Gemma model to a -KerasHub T5Gemma model. - -To run, first install the dependencies: -``` -pip install keras-core keras-nlp tensorflow-text -pip install transformers huggingface-hub sentencepiece absl-py torch -``` - -Then, log in to Hugging Face: -``` -huggingface-cli login -``` - -Finally, run the script to convert the weights: -``` -python convert_t5gemma_checkpoints.py --preset t5gemma_b_b_prefixlm_it -``` -""" - +import gc import os +import random +import shutil -import absl import huggingface_hub +import keras import numpy as np +import tensorflow as tf import torch import transformers +from absl import app +from absl import flags +from checkpoint_conversion_utils import get_md5_checksum +import keras_hub +from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM from keras_hub.src.models.t5gemma.t5gemma_seq_2_seq_lm import T5GemmaSeq2SeqLM from keras_hub.src.models.t5gemma.t5gemma_seq_2_seq_lm_preprocessor import ( T5GemmaSeq2SeqLMPreprocessor, ) -from keras_hub.src.models.t5gemma.t5gemma_tokenizer import T5GemmaTokenizer + +random.seed(123) +torch.manual_seed(123) +device = torch.device("cpu") +# Force PyTorch to use CPU +torch.set_default_device(device) PRESET_MAP = { "t5gemma_s_s_ul2": "google/t5gemma-s-s-ul2", @@ -65,35 +56,27 @@ "t5gemma_9b_9b_ul2_it": "google/t5gemma-9b-9b-ul2-it", "t5gemma_9b_9b_prefixlm_it": "google/t5gemma-9b-9b-prefixlm-it", } -EXTRACT_DIR = "./model_t5gemma" -FLAGS = absl.flags.FLAGS -absl.flags.DEFINE_string( - "preset", - "t5gemma_b_b_prefixlm_it", - f"Must be one of {','.join(PRESET_MAP.keys())}.", -) -def download_hf_model(hf_model_name): - print(f"ā¬‡ļø Downloading Hugging Face model '{hf_model_name}'...") - hf_model_dir = huggingface_hub.snapshot_download( - repo_id=hf_model_name, - allow_patterns=["*.json", "*.safetensors", "tokenizer.model"], - local_dir=EXTRACT_DIR, - local_dir_use_symlinks=False, - ) - print(f"āœ… Model downloaded to: {hf_model_dir}") - return hf_model_dir +FLAGS = flags.FLAGS +flags.DEFINE_string( + "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}" +) + +def convert_checkpoints(hf_model): + """Convert Hugging Face weights to Keras Hub format.""" + print("\n-> Convert original weights to KerasHub format.") -def convert_model(hf_model, preprocessor): + print("\n-> Load KerasHub model.") encoder_config = hf_model.config.encoder decoder_config = hf_model.config.decoder if decoder_config.hidden_activation == "gelu_pytorch_tanh": decoder_config.hidden_activation = "gelu_approximate" if encoder_config.hidden_activation == "gelu_pytorch_tanh": encoder_config.hidden_activation = "gelu_approximate" - keras_backbone = T5GemmaSeq2SeqLM.backbone_cls( + keras.config.set_floatx("float32") + keras_hub_model = keras_hub.models.T5GemmaBackbone( vocabulary_size=decoder_config.vocab_size, encoder_hidden_dim=encoder_config.hidden_size, encoder_intermediate_dim=encoder_config.intermediate_size, @@ -124,42 +107,27 @@ def convert_model(hf_model, preprocessor): attn_logit_softcapping=decoder_config.attn_logit_softcapping, final_logit_softcapping=decoder_config.final_logit_softcapping, rope_max_wavelength=decoder_config.rope_theta, + dtype="float32", ) - keras_model = T5GemmaSeq2SeqLM( - backbone=keras_backbone, preprocessor=preprocessor - ) - print("āœ… Keras model instantiated.") - return keras_model - - -def convert_tokenizer(hf_model_dir): - print("šŸ—£ļø Converting tokenizer...") - tokenizer_path = os.path.join(hf_model_dir, "tokenizer.model") - keras_tokenizer = T5GemmaTokenizer(proto=tokenizer_path) - print("āœ… Tokenizer converted.") - return keras_tokenizer - -def convert_weights(keras_model, hf_model): - print("šŸ‹ļø Converting weights...") hf_wts = hf_model.state_dict() - keras_backbone = keras_model.backbone - # Token Embeddings. - keras_backbone.token_embedding.embeddings.assign( + # Token embedding. + keras_hub_model.get_layer("encoder_token_embedding").embeddings.assign( hf_wts["encoder.embed_tokens.weight"] ) - keras_backbone.decoder_token_embedding.embeddings.assign( + keras_hub_model.get_layer("decoder_token_embedding").embeddings.assign( hf_wts["decoder.embed_tokens.weight"] ) # Encoder. - encoder_hidden_dim = keras_backbone.encoder_hidden_dim - encoder_num_attention_heads = keras_backbone.encoder_num_attention_heads - encoder_num_key_value_heads = keras_backbone.encoder_num_key_value_heads - encoder_head_dim = keras_backbone.encoder_head_dim - keras_backbone.encoder_norm.scale.assign(hf_wts["encoder.norm.weight"]) - for i in range(keras_backbone.encoder_num_layers): - encoder_layer = keras_backbone.get_layer(f"encoder_layer_{i}") + encoder_hidden_dim = keras_hub_model.encoder_hidden_dim + encoder_num_attention_heads = keras_hub_model.encoder_num_attention_heads + encoder_num_key_value_heads = keras_hub_model.encoder_num_key_value_heads + encoder_head_dim = keras_hub_model.encoder_head_dim + keras_hub_model.encoder_norm.scale.assign(hf_wts["encoder.norm.weight"]) + + for i in range(keras_hub_model.encoder_num_layers): + encoder_layer = keras_hub_model.get_layer(f"encoder_layer_{i}") hf_prefix = f"encoder.layers.{i}" # Self-attention. @@ -223,14 +191,15 @@ def convert_weights(keras_model, hf_model): ) # Decoder. - decoder_hidden_dim = keras_backbone.decoder_hidden_dim - decoder_num_attention_heads = keras_backbone.decoder_num_attention_heads - decoder_num_key_value_heads = keras_backbone.decoder_num_key_value_heads - decoder_head_dim = keras_backbone.decoder_head_dim - cross_attention_hidden_size = keras_backbone.cross_attention_hidden_size - keras_backbone.decoder_norm.scale.assign(hf_wts["decoder.norm.weight"]) - for i in range(keras_backbone.decoder_num_layers): - decoder_layer = keras_backbone.get_layer(f"decoder_layer_{i}") + decoder_hidden_dim = keras_hub_model.decoder_hidden_dim + decoder_num_attention_heads = keras_hub_model.decoder_num_attention_heads + decoder_num_key_value_heads = keras_hub_model.decoder_num_key_value_heads + decoder_head_dim = keras_hub_model.decoder_head_dim + cross_attention_hidden_size = keras_hub_model.cross_attention_hidden_size + keras_hub_model.decoder_norm.scale.assign(hf_wts["decoder.norm.weight"]) + + for i in range(keras_hub_model.decoder_num_layers): + decoder_layer = keras_hub_model.get_layer(f"decoder_layer_{i}") hf_prefix = f"decoder.layers.{i}" # Self-attention. @@ -331,81 +300,170 @@ def convert_weights(keras_model, hf_model): decoder_layer.post_feedforward_layernorm.scale.assign( hf_wts[f"{hf_prefix}.post_feedforward_layernorm.weight"] ) - print("āœ… Weights converted.") + return keras_hub_model -def validate_output(hf_model, keras_model, hf_tokenizer, keras_tokenizer): - hf_model.eval() - print("šŸ”Ž Validating tokenizer outputs...") - # Example sentence. - test_sentence = "What is the fastest land animal?" - hf_tokens = hf_tokenizer(test_sentence, return_tensors="pt")["input_ids"][ - 0 - ].tolist() - keras_tokens = keras_tokenizer.tokenize(test_sentence).numpy().tolist() - print(f"šŸ”¶ Test Sentence: '{test_sentence}'") - print(f"šŸ”¶ Hugging Face Tokens: {hf_tokens}") - print(f"šŸ”¶ Keras Tokens: {keras_tokens}") - assert hf_tokens == keras_tokens, "Tokenizer outputs do not match!" - print("āœ… Tokenizer outputs are consistent.") - print("šŸ”Ž Validating numeric outputs...") - input_ids_np = np.ones((1, 10), dtype="int32") - attention_mask_np = np.ones((1, 10), dtype="int32") - keras_inputs = { - "encoder_token_ids": input_ids_np, - "encoder_padding_mask": attention_mask_np, - "decoder_token_ids": input_ids_np, - "decoder_padding_mask": attention_mask_np, + +def extract_vocab(hf_model_dir): + """Extract vocabulary from the downloaded Hugging Face model directory.""" + source_path = os.path.join(hf_model_dir, "tokenizer.model") + vocabulary_path = os.path.join(FLAGS.preset, "tokenizer.model") + print(f"\n-> Save KerasHub vocab to `{vocabulary_path}`.") + + shutil.copyfile(source_path, vocabulary_path) + + keras_hub_tokenizer = keras_hub.models.T5GemmaTokenizer( + proto=vocabulary_path + ) + + print("-> Print MD5 checksum of the vocab file.") + print(f"`{vocabulary_path}` md5sum: ", get_md5_checksum(vocabulary_path)) + + return keras_hub_tokenizer + + +def check_output( + keras_hub_tokenizer, + keras_hub_model, + hf_tokenizer, + hf_model, +): + """Check the outputs of the Keras Hub and Hugging Face models.""" + print("\n-> Check the outputs.") + enc_sample_text = [ + "cricket is awesome, easily the best sport in the world!" + ] + dec_sample_text = [ + "football is good too, but nowhere near as good as cricket." + ] + + # KerasHub. + keras_hub_enc_token_ids = hf_tokenizer( + enc_sample_text, return_tensors="tf" + )["input_ids"] + keras_hub_dec_token_ids = hf_tokenizer( + dec_sample_text, return_tensors="tf" + )["input_ids"] + keras_hub_dec_token_ids = tf.concat( + [ + tf.constant([[keras_hub_tokenizer.start_token_id]]), + keras_hub_dec_token_ids, + ], + axis=-1, + ) + keras_hub_inputs = { + "encoder_token_ids": keras_hub_enc_token_ids, + "encoder_padding_mask": tf.ones_like(keras_hub_enc_token_ids), + "decoder_token_ids": keras_hub_dec_token_ids, + "decoder_padding_mask": tf.ones_like(keras_hub_dec_token_ids), } - hf_input_ids = torch.from_numpy(input_ids_np) - hf_attention_mask = torch.from_numpy(attention_mask_np) - hf_decoder_input_ids = hf_input_ids.clone() - hf_outputs = hf_model( - input_ids=hf_input_ids, - attention_mask=hf_attention_mask, + keras_hub_output = keras_hub_model.predict(keras_hub_inputs) + + # HF. + hf_enc_inputs = hf_tokenizer(enc_sample_text, return_tensors="pt") + hf_dec_inputs = hf_tokenizer(dec_sample_text, return_tensors="pt") + hf_decoder_input_ids = torch.cat( + [ + torch.tensor([[hf_tokenizer.bos_token_id]]), + hf_dec_inputs["input_ids"], + ], + dim=-1, + ) + hf_decoder_attention_mask = torch.cat( + [torch.ones(1, 1, dtype=torch.long), hf_dec_inputs["attention_mask"]], + dim=-1, + ) + + hf_output = hf_model( + **hf_enc_inputs, decoder_input_ids=hf_decoder_input_ids, + decoder_attention_mask=hf_decoder_attention_mask, + ) + + print("Encoder Outputs:") + print( + "KerasHub output:", + keras_hub_output["encoder_sequence_output"][0, 0, :10], + ) + print("HF output:", hf_output.encoder_last_hidden_state[0, 0, :10]) + print( + "Difference:", + np.mean( + keras_hub_output["encoder_sequence_output"] + - hf_output.encoder_last_hidden_state.detach().numpy() + ), + ) + + print("Decoder Outputs:") + print( + "KerasHub output:", + keras_hub_output["decoder_sequence_output"][0, 0, :10], ) - hf_final_hidden_states = hf_outputs.last_hidden_state.detach().numpy() - print("\nšŸ”Ž Validating final hidden states...") - keras_output = keras_model.backbone.predict(keras_inputs) - keras_final_hidden_states = keras_output["decoder_sequence_output"] - final_difference = np.mean( - np.abs(hf_final_hidden_states - keras_final_hidden_states) + print("HF output:", hf_output.last_hidden_state[0, 0, :10]) + print( + "Difference:", + np.mean( + keras_hub_output["decoder_sequence_output"] + - hf_output.last_hidden_state.detach().numpy() + ), ) - print(f"šŸ”¶ Keras final output shape: {keras_final_hidden_states.shape}") - print(f"šŸ”¶ HF final output shape: {hf_final_hidden_states.shape}") - print(f"šŸ”¶ Mean absolute difference: {final_difference:.6e}") - assert final_difference < 1e-4, "Final output difference is too high!" - print("āœ… Final hidden states are consistent.") def main(_): - preset = FLAGS.preset - print(f"šŸš€ Starting conversion for preset: {preset}") + os.makedirs(FLAGS.preset, exist_ok=True) + + hf_model_name = PRESET_MAP[FLAGS.preset] - hf_model_name = PRESET_MAP[preset] - hf_model_dir = download_hf_model(hf_model_name) + print("\n-> Download HF model files.") + hf_model_dir = huggingface_hub.snapshot_download( + repo_id=hf_model_name, + allow_patterns=["*.json", "*.safetensors", "tokenizer.model"], + ) - print("🧩 Loading Hugging Face model and tokenizer...") - hf_model = transformers.T5GemmaModel.from_pretrained(hf_model_dir) + print("\n-> Load HF model and HF tokenizer.") + hf_model = transformers.AutoModel.from_pretrained(hf_model_dir) + hf_model.eval() hf_tokenizer = transformers.AutoTokenizer.from_pretrained(hf_model_dir) - print("āœ… Hugging Face model and tokenizer loaded.") - keras_tokenizer = convert_tokenizer(hf_model_dir) + keras_hub_model = convert_checkpoints(hf_model) + print("\n-> Load KerasHub tokenizer.") + keras_hub_tokenizer = extract_vocab(hf_model_dir) - keras_preprocessor = T5GemmaSeq2SeqLMPreprocessor( - tokenizer=keras_tokenizer, + check_output( + keras_hub_tokenizer, + keras_hub_model, + hf_tokenizer, + hf_model, + ) + print("\n-> Releasing HF backbone from memory.") + del hf_model + gc.collect() + preprocessor = T5GemmaSeq2SeqLMPreprocessor( + tokenizer=keras_hub_tokenizer, + encoder_sequence_length=512, + decoder_sequence_length=512, + ) + keras_lm = T5GemmaSeq2SeqLM( + backbone=keras_hub_model, + preprocessor=preprocessor, + dtype=keras_hub_model.dtype, ) - keras_model = convert_model(hf_model, keras_preprocessor) - convert_weights(keras_model, hf_model) - validate_output(hf_model, keras_model, hf_tokenizer, keras_tokenizer) + keras_lm.compile(sampler="greedy") + + print(f"\n-> Saving T5GemmaSeq2SeqLM preset to `{FLAGS.preset}`.") + keras_lm.save_to_preset(FLAGS.preset) + print("-> Preset saved successfully.") + + print("\n-> Testing preset loading.") + keras_lm = Seq2SeqLM.from_preset("t5gemma_b_b_prefixlm_it") + print("-> Preset loading verified successfully.") - print(f"šŸ’¾ Saving Keras model and tokenizer to preset '{preset}'...") - keras_model.save_to_preset(preset) - keras_tokenizer.save_to_preset(preset) - print("āœ… Preset saved successfully.") - print("šŸŽ‰ Conversion complete!") + # Show the MD5 checksum of the model weights after saving. + print("\n-> Print MD5 checksum of the model weights.") + weights_path = os.path.join(FLAGS.preset, "model.weights.h5") + print(f"`{weights_path}` md5sum: ", get_md5_checksum(weights_path)) if __name__ == "__main__": - absl.app.run(main) + flags.mark_flag_as_required("preset") + app.run(main) From 87a221d07b3ba458ddc517a65d12a09b689c65e3 Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Fri, 22 Aug 2025 14:15:37 +0400 Subject: [PATCH 17/19] preset test: Register and test a preset (to be replaced later by the team with the full set) --- keras_hub/src/models/t5gemma/__init__.py | 5 +++++ keras_hub/src/models/t5gemma/t5gemma_presets.py | 15 ++++++++++++++- .../convert_t5gemma_checkpoints.py | 4 ++++ 3 files changed, 23 insertions(+), 1 deletion(-) create mode 100644 keras_hub/src/models/t5gemma/__init__.py diff --git a/keras_hub/src/models/t5gemma/__init__.py b/keras_hub/src/models/t5gemma/__init__.py new file mode 100644 index 0000000000..e95c262c19 --- /dev/null +++ b/keras_hub/src/models/t5gemma/__init__.py @@ -0,0 +1,5 @@ +from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone +from keras_hub.src.models.t5gemma.t5gemma_presets import backbone_presets +from keras_hub.src.utils.preset_utils import register_presets + +register_presets(backbone_presets, T5GemmaBackbone) diff --git a/keras_hub/src/models/t5gemma/t5gemma_presets.py b/keras_hub/src/models/t5gemma/t5gemma_presets.py index d976272974..25bfb8465b 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_presets.py +++ b/keras_hub/src/models/t5gemma/t5gemma_presets.py @@ -1,2 +1,15 @@ # Metadata for loading pretrained model weights. -backbone_presets = {} +backbone_presets = { + "t5gemma_b_b_prefixlm_it": { + "metadata": { + "description": ( + "T5Gemma B/B model with a base encoder and base decoder, " + "adapted as a prefix language model and fine-tuned for " + "instruction following." + ), + "params": 591490560, + "path": "t5gemma", + }, + "kaggle_handle": "kaggle://harshaljanjani/t5gemma/keras/t5gemma_b_b_prefixlm_it", + }, +} diff --git a/tools/checkpoint_conversion/convert_t5gemma_checkpoints.py b/tools/checkpoint_conversion/convert_t5gemma_checkpoints.py index 8213bd3dba..90ed208de0 100644 --- a/tools/checkpoint_conversion/convert_t5gemma_checkpoints.py +++ b/tools/checkpoint_conversion/convert_t5gemma_checkpoints.py @@ -51,6 +51,10 @@ "t5gemma_2b_2b_prefixlm": "google/t5gemma-2b-2b-prefixlm", "t5gemma_2b_2b_ul2_it": "google/t5gemma-2b-2b-ul2-it", "t5gemma_2b_2b_prefixlm_it": "google/t5gemma-2b-2b-prefixlm-it", + "t5gemma_9b_2b_ul2": "google/t5gemma-9b-2b-ul2", + "t5gemma_9b_2b_prefixlm": "google/t5gemma-9b-2b-prefixlm", + "t5gemma_9b_2b_ul2_it": "google/t5gemma-9b-2b-ul2-it", + "t5gemma_9b_2b_prefixlm_it": "google/t5gemma-9b-2b-prefixlm-it", "t5gemma_9b_9b_ul2": "google/t5gemma-9b-9b-ul2", "t5gemma_9b_9b_prefixlm": "google/t5gemma-9b-9b-prefixlm", "t5gemma_9b_9b_ul2_it": "google/t5gemma-9b-9b-ul2-it", From 9c7905852094a60b63e5c22a5e5535def9a18a41 Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Sun, 24 Aug 2025 14:41:19 +0400 Subject: [PATCH 18/19] =?UTF-8?q?nit:=20Sharded=20weights=20don=E2=80=99t?= =?UTF-8?q?=20include=20`model.weights.h5`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tools/checkpoint_conversion/convert_t5gemma_checkpoints.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tools/checkpoint_conversion/convert_t5gemma_checkpoints.py b/tools/checkpoint_conversion/convert_t5gemma_checkpoints.py index 90ed208de0..f0c70f8265 100644 --- a/tools/checkpoint_conversion/convert_t5gemma_checkpoints.py +++ b/tools/checkpoint_conversion/convert_t5gemma_checkpoints.py @@ -462,11 +462,6 @@ def main(_): keras_lm = Seq2SeqLM.from_preset("t5gemma_b_b_prefixlm_it") print("-> Preset loading verified successfully.") - # Show the MD5 checksum of the model weights after saving. - print("\n-> Print MD5 checksum of the model weights.") - weights_path = os.path.join(FLAGS.preset, "model.weights.h5") - print(f"`{weights_path}` md5sum: ", get_md5_checksum(weights_path)) - if __name__ == "__main__": flags.mark_flag_as_required("preset") From f7e356f0feeca21b26862aebd4d3ef3c5a435ae1 Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Mon, 25 Aug 2025 23:30:00 +0400 Subject: [PATCH 19/19] nits: Address reviews + replace gated model --- .../t5gemma/t5gemma_seq_2_seq_lm_preprocessor.py | 11 +++++------ .../src/utils/transformers/convert_t5gemma_test.py | 4 ++-- .../convert_t5gemma_checkpoints.py | 2 +- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_preprocessor.py b/keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_preprocessor.py index 9ce57a7d60..1570d4796d 100644 --- a/keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_preprocessor.py +++ b/keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_preprocessor.py @@ -40,8 +40,9 @@ class T5GemmaSeq2SeqLMPreprocessor(Seq2SeqLMPreprocessor): `True`. Defaults to `True`. Call arguments: - x: A string, `tf.Tensor` or list of python strings. Can also be a - dictionary with `encoder_text` and `decoder_text` keys. + x: A dictionary with two keys, `"encoder_text"` and `"decoder_text"`. + The values can be a string, a `tf.Tensor` or a list of python + strings. y: Label data. Should always be `None` as the layer generates labels. sample_weight: Label weights. Should always be `None` as the layer generates label weights. @@ -60,10 +61,8 @@ class T5GemmaSeq2SeqLMPreprocessor(Seq2SeqLMPreprocessor): "t5gemma_b_b_prefixlm_it" ) - # Tokenize and pack a single sentence. - sentence = tf.constant("The quick brown fox jumped.") - preprocessor(sentence) - + # For example usage, see the dictionary example below which provides + # both encoder and decoder text. # Tokenize a batch of sentences. preprocessor(["The quick brown fox jumped.", "Call me Ishmael."]) # Tokenize a dictionary with separate encoder and decoder inputs. diff --git a/keras_hub/src/utils/transformers/convert_t5gemma_test.py b/keras_hub/src/utils/transformers/convert_t5gemma_test.py index dfa8f6f0ad..939984eba5 100644 --- a/keras_hub/src/utils/transformers/convert_t5gemma_test.py +++ b/keras_hub/src/utils/transformers/convert_t5gemma_test.py @@ -11,14 +11,14 @@ class TestTask(TestCase): @pytest.mark.large def test_convert_tiny_preset(self): model = T5GemmaSeq2SeqLM.from_preset( - "hf://google/t5gemma-b-b-prefixlm-it" + "hf://harshaljanjani/tiny-t5gemma-test" ) prompt = "What is the capital of France?" model.generate([prompt], max_length=15) @pytest.mark.large def test_class_detection(self): - preset_name = "hf://google/t5gemma-b-b-prefixlm-it" + preset_name = "hf://harshaljanjani/tiny-t5gemma-test" model = Seq2SeqLM.from_preset( preset_name, load_weights=False, diff --git a/tools/checkpoint_conversion/convert_t5gemma_checkpoints.py b/tools/checkpoint_conversion/convert_t5gemma_checkpoints.py index f0c70f8265..e4cd550d9d 100644 --- a/tools/checkpoint_conversion/convert_t5gemma_checkpoints.py +++ b/tools/checkpoint_conversion/convert_t5gemma_checkpoints.py @@ -459,7 +459,7 @@ def main(_): print("-> Preset saved successfully.") print("\n-> Testing preset loading.") - keras_lm = Seq2SeqLM.from_preset("t5gemma_b_b_prefixlm_it") + keras_lm = Seq2SeqLM.from_preset(FLAGS.preset) print("-> Preset loading verified successfully.")