diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index 8b6aa475e7..0941e63402 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -599,6 +599,18 @@ T5Preprocessor as T5Preprocessor, ) from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer as T5Tokenizer +from keras_hub.src.models.t5gemma.t5gemma_backbone import ( + T5GemmaBackbone as T5GemmaBackbone, +) +from keras_hub.src.models.t5gemma.t5gemma_seq_2_seq_lm import ( + T5GemmaSeq2SeqLM as T5GemmaSeq2SeqLM, +) +from keras_hub.src.models.t5gemma.t5gemma_seq_2_seq_lm_preprocessor import ( + T5GemmaSeq2SeqLMPreprocessor as T5GemmaSeq2SeqLMPreprocessor, +) +from keras_hub.src.models.t5gemma.t5gemma_tokenizer import ( + T5GemmaTokenizer as T5GemmaTokenizer, +) from keras_hub.src.models.task import Task as Task from keras_hub.src.models.text_classifier import TextClassifier as Classifier from keras_hub.src.models.text_classifier import ( diff --git a/keras_hub/api/tokenizers/__init__.py b/keras_hub/api/tokenizers/__init__.py index 082078184f..2677d89ee7 100644 --- a/keras_hub/api/tokenizers/__init__.py +++ b/keras_hub/api/tokenizers/__init__.py @@ -87,6 +87,9 @@ SigLIPTokenizer as SigLIPTokenizer, ) from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer as T5Tokenizer +from keras_hub.src.models.t5gemma.t5gemma_tokenizer import ( + T5GemmaTokenizer as T5GemmaTokenizer, +) from keras_hub.src.models.whisper.whisper_tokenizer import ( WhisperTokenizer as WhisperTokenizer, ) diff --git a/keras_hub/src/models/t5gemma/__init__.py b/keras_hub/src/models/t5gemma/__init__.py new file mode 100644 index 0000000000..e95c262c19 --- /dev/null +++ b/keras_hub/src/models/t5gemma/__init__.py @@ -0,0 +1,5 @@ +from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone +from keras_hub.src.models.t5gemma.t5gemma_presets import backbone_presets +from keras_hub.src.utils.preset_utils import register_presets + +register_presets(backbone_presets, T5GemmaBackbone) diff --git a/keras_hub/src/models/t5gemma/t5gemma_attention.py b/keras_hub/src/models/t5gemma/t5gemma_attention.py new file mode 100644 index 0000000000..51c5490496 --- /dev/null +++ b/keras_hub/src/models/t5gemma/t5gemma_attention.py @@ -0,0 +1,370 @@ +import inspect + +import keras + +from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding +from keras_hub.src.models.gemma.gemma_attention import CachedGemmaAttention +from keras_hub.src.models.t5gemma.t5gemma_layers import ( + t5gemma_kernel_initializer, +) +from keras_hub.src.utils.keras_utils import clone_initializer + + +def repeat_kv(hidden_states, n_rep): + """Repeats the key/value hidden states to match the number of query heads + for Grouped Query Attention (GQA). + + This function is used in `T5GemmaAttention` to broadcast key and value + states across multiple query heads when Grouped Query Attention (GQA) is + used (i.e., when `num_query_heads` > `num_key_value_heads`). + + Args: + hidden_states: Tensor, The key or value hidden states with shape + `(batch, sequence_length, num_key_value_heads, head_dim)`. + n_rep: int, The number of times to repeat the key/value heads. This is + typically `num_query_heads // num_key_value_heads`. + + Returns: + Tensor: The expanded key/value hidden states with shape + `(batch, sequence_length, num_query_heads, head_dim)`. 
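+
+    Example:
+
+    For illustration, with arbitrary shapes, 2 key/value heads repeated 4
+    times to match 8 query heads:
+    ```python
+    kv = keras.ops.ones((2, 16, 2, 64))
+    expanded = repeat_kv(kv, n_rep=4)
+    # expanded.shape == (2, 16, 8, 64)
+    ```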
+ """ + if n_rep == 1: + return hidden_states + batch, slen, num_key_value_heads, head_dim = keras.ops.shape(hidden_states) + hidden_states = keras.ops.expand_dims(hidden_states, 3) + hidden_states = keras.ops.tile(hidden_states, (1, 1, 1, n_rep, 1)) + return keras.ops.reshape( + hidden_states, (batch, slen, num_key_value_heads * n_rep, head_dim) + ) + + +class T5GemmaAttention(CachedGemmaAttention): + """A unified attention layer for T5Gemma that handles both self-attention + and cross-attention. + + This layer performs attention with optional Rotary Positional Embeddings + (RoPE) and supports Grouped Query Attention (GQA). It is used in + `T5GemmaEncoderLayer` and `T5GemmaDecoderLayer`. + + Args: + hidden_size: int, The dimensionality of the hidden states. + num_attention_heads: int, The number of attention heads. + num_key_value_heads: int, The number of key-value heads. For GQA, this + can be less than `num_attention_heads`. + query_pre_attn_scalar: float, Scalar to multiply queries by before + attention. + attention_bias: bool, Whether to include bias in the dense layers. + head_dim: int, The dimensionality of each attention head. + attention_type: str, The type of attention, either 'self' or 'cross'. + Defaults to 'self'. + cross_attention_hidden_size: int, optional, The dimensionality of + encoder hidden states for cross-attention. Defaults to `None`. + initializer_range: float, The range for the random normal initializer + for kernel weights. Defaults to `0.02`. + attention_dropout: float, The dropout rate applied to attention weights. + Defaults to `0.0`. + attn_logit_softcapping: float, optional, The softcapping value for + attention logits. Defaults to `None`. + rope_max_wavelength: float, The maximum wavelength for Rotary Positional + Embeddings. Defaults to `10000.0`. Only used for self-attention. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for model computations and weights. Defaults to `None`. + **kwargs: Additional keyword arguments passed to the parent class. 
+ """ + + def __init__( + self, + hidden_size, + num_attention_heads, + num_key_value_heads, + query_pre_attn_scalar, + attention_bias, + head_dim, + attention_type="self", + cross_attention_hidden_size=None, + initializer_range=0.02, + attention_dropout=0.0, + attn_logit_softcapping=None, + rope_max_wavelength=10000.0, + dtype=None, + **kwargs, + ): + super().__init__( + head_dim=head_dim, + num_query_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + kernel_initializer=t5gemma_kernel_initializer(initializer_range), + logit_soft_cap=attn_logit_softcapping, + dropout=attention_dropout, + query_head_dim_normalize=False, + use_sliding_window_attention=False, + dtype=dtype, + **kwargs, + ) + if attention_type not in ["self", "cross"]: + raise ValueError( + f"attention_type must be 'self' or 'cross', but got " + f"{attention_type}" + ) + self.attention_type = attention_type + self.hidden_size = hidden_size + self.cross_attention_hidden_size = ( + cross_attention_hidden_size or hidden_size + ) + self.query_pre_attn_scalar = query_pre_attn_scalar + self.attention_bias = attention_bias + self.initializer_range = initializer_range + self.attention_dropout = attention_dropout + self.rope_max_wavelength = rope_max_wavelength + self.num_key_value_groups = ( + self.num_query_heads // self.num_key_value_heads + ) + self.scaling = self.query_pre_attn_scalar**-0.5 + if self.attention_type == "self": + self.rotary_embedding = RotaryEmbedding( + max_wavelength=self.rope_max_wavelength, + sequence_axis=1, + feature_axis=3, + name="rotary_embedding", + dtype=self.dtype_policy, + ) + + def build(self, input_shape): + self._kernel_initializer = t5gemma_kernel_initializer( + self.initializer_range + ) + + if self.attention_type == "cross": + hidden_states_shape, kv_states_shape = input_shape + else: + hidden_states_shape = input_shape + kv_states_shape = input_shape + # Query projection layer. + self.hidden_dim = hidden_states_shape[-1] + self.query_dense = keras.layers.EinsumDense( + equation="btd,dnh->btnh", + output_shape=(None, self.num_query_heads, self.head_dim), + kernel_initializer=clone_initializer(self._kernel_initializer), + bias_axes="nh" if self.attention_bias else None, + dtype=self.dtype_policy, + name="query", + ) + self.query_dense.build(hidden_states_shape) + + # Key projection layer. + self.key_dense = keras.layers.EinsumDense( + equation="bsd,dkh->bskh", + output_shape=(None, self.num_key_value_heads, self.head_dim), + kernel_initializer=clone_initializer(self._kernel_initializer), + bias_axes="kh" if self.attention_bias else None, + dtype=self.dtype_policy, + name="key", + ) + self.key_dense.build(kv_states_shape) + + # Value projection layer. + self.value_dense = keras.layers.EinsumDense( + equation="bsd,dkh->bskh", + output_shape=(None, self.num_key_value_heads, self.head_dim), + kernel_initializer=clone_initializer(self._kernel_initializer), + bias_axes="kh" if self.attention_bias else None, + dtype=self.dtype_policy, + name="value", + ) + self.value_dense.build(kv_states_shape) + + # Output projection layer. 
+ self.output_dense = keras.layers.EinsumDense( + equation="btnh,nhd->btd", + output_shape=(None, self.hidden_dim), + kernel_initializer=clone_initializer(self._kernel_initializer), + bias_axes="d" if self.attention_bias else None, + dtype=self.dtype_policy, + name="attention_output", + ) + self.output_dense.build( + ( + hidden_states_shape[0], + hidden_states_shape[1], + self.num_query_heads, + self.head_dim, + ) + ) + self.dropout_layer = keras.layers.Dropout( + rate=self.attention_dropout, + dtype=self.dtype_policy, + ) + self.softmax = keras.layers.Softmax(axis=-1, dtype="float32") + self.built = True + + def _compute_attention_without_fused_op( + self, query_states, key_states, value_states, attention_mask, training + ): + attn_weights = keras.ops.einsum( + "btnh,bsnh->bnts", query_states, key_states + ) + attn_weights *= self.scaling + if self.logit_soft_cap is not None: + attn_weights = attn_weights / self.logit_soft_cap + attn_weights = keras.ops.tanh(attn_weights) + attn_weights = attn_weights * self.logit_soft_cap + if attention_mask is not None: + attn_weights += attention_mask + attn_weights = keras.ops.cast( + self.softmax(attn_weights), + query_states.dtype, + ) + attn_weights = self.dropout_layer(attn_weights, training=training) + attn_output = keras.ops.einsum( + "bnts,bsnh->btnh", attn_weights, value_states + ) + return attn_output + + def _compute_attention( + self, query_states, key_states, value_states, attention_mask, training + ): + if self._use_fused_attention_op(): + kwargs = {"bias": attention_mask} + if self.logit_soft_cap is not None: + sig = inspect.signature(keras.ops.dot_product_attention) + # This is only supported in JAX TPU backend. + # https://keras.io/api/ops/nn/#dot_product_attention-function + if "attn_logits_soft_cap" in sig.parameters: + kwargs["attn_logits_soft_cap"] = self.logit_soft_cap + return keras.ops.dot_product_attention( + query=query_states, + key=key_states, + value=value_states, + scale=self.scaling, + **kwargs, + ) + return self._compute_attention_without_fused_op( + query_states, + key_states, + value_states, + attention_mask, + training, + ) + + def call( + self, + inputs, + attention_mask=None, + cache=None, + cache_update_index=None, + training=None, + ): + if self.attention_type == "cross": + if not isinstance(inputs, (list, tuple)) or len(inputs) != 2: + raise ValueError( + "For cross-attention, `inputs` must be a list or tuple of " + "two tensors: `[hidden_states, encoder_hidden_states]`." + ) + hidden_states, kv_states = inputs + query_states = self.query_dense(hidden_states) + if cache is not None: + if cache_update_index is not None: + raise ValueError( + "`cache_update_index` should not be set for " + "cross-attention caching." + ) + key_states, value_states = cache[:, 0, ...], cache[:, 1, ...] + updated_cache = cache + else: + key_states = self.key_dense(kv_states) + value_states = self.value_dense(kv_states) + updated_cache = keras.ops.stack( + (key_states, value_states), axis=1 + ) + # Repeat key-value heads for GQA. 
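+            # Each key/value head is shared by `num_key_value_groups` query
+            # heads; `repeat_kv` tiles the key/value tensors up to
+            # `num_query_heads` so the attention einsum lines up.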
+ key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + attn_output = self._compute_attention( + query_states, key_states, value_states, attention_mask, training + ) + attn_output = self.output_dense(attn_output) + return attn_output, updated_cache + else: # Self-attention + hidden_states = inputs + kv_states = hidden_states + query_states = self.query_dense(hidden_states) + key_states = self.key_dense(kv_states) + value_states = self.value_dense(kv_states) + start_index = ( + 0 if cache_update_index is None else cache_update_index + ) + query_states = self.rotary_embedding( + query_states, start_index=start_index + ) + key_states = self.rotary_embedding( + key_states, start_index=start_index + ) + if cache is not None: + if cache_update_index is None: + raise ValueError( + "Both `cache` and `cache_update_index` must be passed " + "for self-attention caching." + ) + key_cache, value_cache = cache[:, 0, ...], cache[:, 1, ...] + start = [0, cache_update_index, 0, 0] + key_states = keras.ops.slice_update( + key_cache, start, key_states + ) + value_states = keras.ops.slice_update( + value_cache, start, value_states + ) + cache = keras.ops.stack((key_states, value_states), axis=1) + elif cache_update_index is not None: + raise ValueError( + "`cache_update_index` should not be set if `cache` is " + "`None`." + ) + else: + cache = keras.ops.stack((key_states, value_states), axis=1) + + # Repeat key-value heads for GQA. + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_output = self._compute_attention( + query_states, key_states, value_states, attention_mask, training + ) + attn_output = self.output_dense(attn_output) + return attn_output, cache + + def compute_output_shape(self, input_shape): + if self.attention_type == "cross": + hidden_states_shape, kv_states_shape = input_shape + else: + hidden_states_shape = input_shape + kv_states_shape = input_shape + attn_output_shape = hidden_states_shape + kv_len = kv_states_shape[1] + cache_shape = ( + hidden_states_shape[0], # batch + 2, # key and value + kv_len, + self.num_key_value_heads, + self.head_dim, + ) + return attn_output_shape, cache_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "head_dim": self.head_dim, + "num_attention_heads": self.num_query_heads, + "num_key_value_heads": self.num_key_value_heads, + "query_pre_attn_scalar": self.query_pre_attn_scalar, + "attention_bias": self.attention_bias, + "attention_type": self.attention_type, + "cross_attention_hidden_size": self.cross_attention_hidden_size, + "initializer_range": self.initializer_range, + "attention_dropout": self.attention_dropout, + "attn_logit_softcapping": self.logit_soft_cap, + "rope_max_wavelength": self.rope_max_wavelength, + } + ) + return config diff --git a/keras_hub/src/models/t5gemma/t5gemma_backbone.py b/keras_hub/src/models/t5gemma/t5gemma_backbone.py new file mode 100644 index 0000000000..61d6665746 --- /dev/null +++ b/keras_hub/src/models/t5gemma/t5gemma_backbone.py @@ -0,0 +1,366 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.modeling.reversible_embedding import ( + ReversibleEmbedding, +) +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.gemma.rms_normalization import RMSNormalization +from keras_hub.src.models.t5gemma.t5gemma_decoder import 
T5GemmaDecoderLayer +from keras_hub.src.models.t5gemma.t5gemma_encoder import T5GemmaEncoderLayer +from keras_hub.src.models.t5gemma.t5gemma_layers import ( + t5gemma_kernel_initializer, +) +from keras_hub.src.utils.keras_utils import clone_initializer + + +@keras_hub_export("keras_hub.models.T5GemmaBackbone") +class T5GemmaBackbone(Backbone): + """T5Gemma backbone model. + + This class implements the encoder-decoder backbone of the T5Gemma model, + consisting of an embedding layer, a stack of encoder layers, and a + stack of decoder layers. + + Args: + vocabulary_size: int, The size of the vocabulary. + encoder_hidden_dim: int, The hidden dimensionality of the encoder. + encoder_intermediate_dim: int, The intermediate size of the encoder's + feed-forward networks. + encoder_num_layers: int, The number of encoder layers. + encoder_num_attention_heads: int, The number of attention heads in the + encoder. + encoder_num_key_value_heads: int, The number of key-value heads in the + encoder. + encoder_head_dim: int, The dimensionality of each attention head in the + encoder. + encoder_layer_types: list of str, A list of strings specifying the type + of attention layer for each encoder layer. Each element can be + either `"sliding_attention"` or `"full_attention"`. For example, + `["full_attention", "sliding_attention", ...]`. + decoder_hidden_dim: int, The hidden dimensionality of the decoder. + decoder_intermediate_dim: int, The intermediate size of the decoder's + feed-forward networks. + decoder_num_layers: int, The number of decoder layers. + decoder_num_attention_heads: int, The number of attention heads in the + decoder. + decoder_num_key_value_heads: int, The number of key-value heads in the + decoder. + decoder_head_dim: int, The dimensionality of each attention head in the + decoder. + decoder_layer_types: list of str, A list of strings specifying the type + of attention layer for each decoder layer. Each element can be + either `"sliding_attention"` or `"full_attention"`. For example, + `["full_attention", "sliding_attention", ...]`. + dropout_rate: float, The dropout rate applied throughout the model. + Defaults to `0.0`. + rms_norm_eps: float, The epsilon value for RMS normalization. Defaults + to `1e-6`. + query_pre_attn_scalar: float, Scalar to multiply queries by before + attention. Defaults to `1.0`. + attention_bias: bool, Whether to include bias in attention computations. + Defaults to `False`. + hidden_activation: str, The activation function used in the feed-forward + networks. Defaults to `"gelu_approximate"`. + tie_word_embeddings: bool, Whether to tie input and output word + embeddings. Defaults to `True`. + initializer_range: float, The range for the random normal initializer. + Defaults to `0.02`. + attention_dropout: float, The dropout rate applied to attention weights. + Defaults to `0.0`. + sliding_window: int, optional, The window size for sliding attention. + Required if any `layer_type` is `"sliding_attention"`. Defaults to + `None`. + cross_attention_hidden_size: int, optional, The hidden size for + cross-attention in the decoder layers. If None, it defaults to + `encoder_hidden_dim`. Defaults to `None`. + attn_logit_softcapping: float, optional, The softcapping value for + attention logits. Defaults to `None`. + final_logit_softcapping: float, optional, The softcapping value for + final logits. Defaults to `None`. + rope_max_wavelength: float, The maximum wavelength for Rotary Positional + Embeddings. Defaults to `10000.0`. 
+ dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for model computations and weights. Note that some computations, + such as softmax and layer normalization, will always be done at + float32 precision regardless of dtype. Defaults to `None`. + **kwargs: Additional keyword arguments passed to the parent `Backbone` + class. + + Examples: + ```python + import numpy as np + from keras_hub.models import T5GemmaBackbone + + input_data = { + "encoder_token_ids": np.ones(shape=(1, 12), dtype="int32"), + "encoder_padding_mask": np.array( + [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]], dtype="int32" + ), + "decoder_token_ids": np.ones(shape=(1, 8), dtype="int32"), + "decoder_padding_mask": np.array( + [[1, 1, 1, 1, 1, 1, 1, 1]], dtype="int32" + ), + } + + # Randomly initialized T5Gemma backbone with custom config. + model = T5GemmaBackbone( + vocabulary_size=32000, + # Encoder parameters. + encoder_hidden_dim=256, + encoder_intermediate_dim=512, + encoder_num_layers=4, + encoder_num_attention_heads=4, + encoder_num_key_value_heads=2, + encoder_head_dim=64, + encoder_layer_types=["full_attention"] * 4, + # Decoder parameters. + decoder_hidden_dim=256, + decoder_intermediate_dim=512, + decoder_num_layers=4, + decoder_num_attention_heads=4, + decoder_num_key_value_heads=2, + decoder_head_dim=64, + decoder_layer_types=["full_attention"] * 4, + # Common parameters. + dropout_rate=0.1, + rms_norm_eps=1e-6, + query_pre_attn_scalar=1.0, + attention_bias=False, + hidden_activation="gelu_approximate", + ) + output = model(input_data) + ``` + """ + + def __init__( + self, + vocabulary_size, + encoder_hidden_dim, + encoder_intermediate_dim, + encoder_num_layers, + encoder_num_attention_heads, + encoder_num_key_value_heads, + encoder_head_dim, + encoder_layer_types, + decoder_hidden_dim, + decoder_intermediate_dim, + decoder_num_layers, + decoder_num_attention_heads, + decoder_num_key_value_heads, + decoder_head_dim, + decoder_layer_types, + dropout_rate=0.0, + rms_norm_eps=1e-6, + query_pre_attn_scalar=1.0, + attention_bias=False, + hidden_activation="gelu_approximate", + tie_word_embeddings=True, + initializer_range=0.02, + attention_dropout=0.0, + sliding_window=None, + cross_attention_hidden_size=None, + attn_logit_softcapping=None, + final_logit_softcapping=None, + rope_max_wavelength=10000.0, + dtype=None, + **kwargs, + ): + self.kernel_initializer = t5gemma_kernel_initializer(initializer_range) + + # === Layers === + self.token_embedding = keras.layers.Embedding( + input_dim=vocabulary_size, + output_dim=encoder_hidden_dim, + embeddings_initializer=clone_initializer(self.kernel_initializer), + dtype=dtype, + name="encoder_token_embedding", + ) + self.decoder_token_embedding = ReversibleEmbedding( + input_dim=vocabulary_size, + output_dim=decoder_hidden_dim, + tie_weights=tie_word_embeddings, + embeddings_initializer=clone_initializer(self.kernel_initializer), + dtype=dtype, + name="decoder_token_embedding", + ) + self.encoder_layers = [ + T5GemmaEncoderLayer( + hidden_size=encoder_hidden_dim, + rms_norm_eps=rms_norm_eps, + num_attention_heads=encoder_num_attention_heads, + num_key_value_heads=encoder_num_key_value_heads, + query_pre_attn_scalar=query_pre_attn_scalar, + attention_bias=attention_bias, + intermediate_size=encoder_intermediate_dim, + hidden_activation=hidden_activation, + head_dim=encoder_head_dim, + dropout_rate=dropout_rate, + initializer_range=initializer_range, + attention_dropout=attention_dropout, + layer_type=encoder_layer_types[i], + sliding_window=sliding_window, + 
attn_logit_softcapping=attn_logit_softcapping, + rope_max_wavelength=rope_max_wavelength, + name=f"encoder_layer_{i}", + dtype=dtype, + ) + for i in range(encoder_num_layers) + ] + self.encoder_norm = RMSNormalization(epsilon=rms_norm_eps, dtype=dtype) + self.encoder_dropout = keras.layers.Dropout(dropout_rate, dtype=dtype) + self.decoder_layers = [ + T5GemmaDecoderLayer( + hidden_size=decoder_hidden_dim, + rms_norm_eps=rms_norm_eps, + num_attention_heads=decoder_num_attention_heads, + num_key_value_heads=decoder_num_key_value_heads, + query_pre_attn_scalar=query_pre_attn_scalar, + attention_bias=attention_bias, + intermediate_size=decoder_intermediate_dim, + hidden_activation=hidden_activation, + dropout_rate=dropout_rate, + initializer_range=initializer_range, + head_dim=decoder_head_dim, + attention_dropout=attention_dropout, + layer_type=decoder_layer_types[i], + sliding_window=sliding_window, + cross_attention_hidden_size=( + cross_attention_hidden_size or encoder_hidden_dim + ), + attn_logit_softcapping=attn_logit_softcapping, + rope_max_wavelength=rope_max_wavelength, + name=f"decoder_layer_{i}", + dtype=dtype, + ) + for i in range(decoder_num_layers) + ] + self.decoder_norm = RMSNormalization(epsilon=rms_norm_eps, dtype=dtype) + self.decoder_dropout = keras.layers.Dropout(dropout_rate, dtype=dtype) + + # === Functional Model === + encoder_token_id_input = keras.Input( + shape=(None,), dtype="int32", name="encoder_token_ids" + ) + encoder_padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="encoder_padding_mask" + ) + decoder_token_id_input = keras.Input( + shape=(None,), dtype="int32", name="decoder_token_ids" + ) + decoder_padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="decoder_padding_mask" + ) + + # Encoder. + encoder_embeddings = self.token_embedding(encoder_token_id_input) + encoder_embeddings = encoder_embeddings * keras.ops.cast( + keras.ops.sqrt(encoder_hidden_dim), encoder_embeddings.dtype + ) + encoder_hidden_states = self.encoder_dropout(encoder_embeddings) + for layer in self.encoder_layers: + encoder_hidden_states = layer( + encoder_hidden_states, padding_mask=encoder_padding_mask_input + ) + encoder_output = self.encoder_norm(encoder_hidden_states) + encoder_output = self.encoder_dropout(encoder_output) + + # Decoder. 
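+        # Each decoder layer cross-attends to the final encoder output; the
+        # per-layer caches it returns are only used for incremental decoding
+        # and are discarded while building the functional graph.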
+ decoder_embeddings = self.decoder_token_embedding( + decoder_token_id_input + ) + decoder_embeddings = decoder_embeddings * keras.ops.cast( + keras.ops.sqrt(decoder_hidden_dim), decoder_embeddings.dtype + ) + decoder_hidden_states = self.decoder_dropout(decoder_embeddings) + for layer in self.decoder_layers: + decoder_hidden_states, _ = layer( + (decoder_hidden_states, encoder_output), + self_attention_padding_mask=decoder_padding_mask_input, + cross_attention_padding_mask=encoder_padding_mask_input, + ) + decoder_output = self.decoder_norm(decoder_hidden_states) + decoder_output = self.decoder_dropout(decoder_output) + + super().__init__( + inputs={ + "encoder_token_ids": encoder_token_id_input, + "encoder_padding_mask": encoder_padding_mask_input, + "decoder_token_ids": decoder_token_id_input, + "decoder_padding_mask": decoder_padding_mask_input, + }, + outputs={ + "encoder_sequence_output": encoder_output, + "decoder_sequence_output": decoder_output, + }, + dtype=dtype, + **kwargs, + ) + + # === Config === + self.encoder_hidden_dim = encoder_hidden_dim + self.encoder_intermediate_dim = encoder_intermediate_dim + self.encoder_num_layers = encoder_num_layers + self.encoder_num_attention_heads = encoder_num_attention_heads + self.encoder_num_key_value_heads = encoder_num_key_value_heads + self.encoder_head_dim = encoder_head_dim + self.encoder_layer_types = encoder_layer_types + self.decoder_hidden_dim = decoder_hidden_dim + self.decoder_intermediate_dim = decoder_intermediate_dim + self.decoder_num_layers = decoder_num_layers + self.decoder_num_attention_heads = decoder_num_attention_heads + self.decoder_num_key_value_heads = decoder_num_key_value_heads + self.decoder_head_dim = decoder_head_dim + self.decoder_layer_types = decoder_layer_types + self.vocabulary_size = vocabulary_size + self.dropout_rate = dropout_rate + self.rms_norm_eps = rms_norm_eps + self.tie_word_embeddings = tie_word_embeddings + self.query_pre_attn_scalar = query_pre_attn_scalar + self.attention_bias = attention_bias + self.hidden_activation = hidden_activation + self.initializer_range = initializer_range + self.attention_dropout = attention_dropout + self.sliding_window = sliding_window + self.cross_attention_hidden_size = ( + cross_attention_hidden_size or encoder_hidden_dim + ) + self.attn_logit_softcapping = attn_logit_softcapping + self.final_logit_softcapping = final_logit_softcapping + self.rope_max_wavelength = rope_max_wavelength + + def get_config(self): + config = super().get_config() + config.update( + { + "vocabulary_size": self.vocabulary_size, + "encoder_hidden_dim": self.encoder_hidden_dim, + "encoder_intermediate_dim": self.encoder_intermediate_dim, + "encoder_num_layers": self.encoder_num_layers, + "encoder_num_attention_heads": self.encoder_num_attention_heads, + "encoder_num_key_value_heads": self.encoder_num_key_value_heads, + "encoder_layer_types": self.encoder_layer_types, + "encoder_head_dim": self.encoder_head_dim, + "decoder_hidden_dim": self.decoder_hidden_dim, + "decoder_intermediate_dim": self.decoder_intermediate_dim, + "decoder_num_layers": self.decoder_num_layers, + "decoder_num_attention_heads": self.decoder_num_attention_heads, + "decoder_num_key_value_heads": self.decoder_num_key_value_heads, + "decoder_layer_types": self.decoder_layer_types, + "decoder_head_dim": self.decoder_head_dim, + "dropout_rate": self.dropout_rate, + "rms_norm_eps": self.rms_norm_eps, + "tie_word_embeddings": self.tie_word_embeddings, + "query_pre_attn_scalar": self.query_pre_attn_scalar, + 
"attention_bias": self.attention_bias, + "hidden_activation": self.hidden_activation, + "initializer_range": self.initializer_range, + "attention_dropout": self.attention_dropout, + "sliding_window": self.sliding_window, + "cross_attention_hidden_size": self.cross_attention_hidden_size, + "attn_logit_softcapping": self.attn_logit_softcapping, + "final_logit_softcapping": self.final_logit_softcapping, + "rope_max_wavelength": self.rope_max_wavelength, + } + ) + return config diff --git a/keras_hub/src/models/t5gemma/t5gemma_backbone_test.py b/keras_hub/src/models/t5gemma/t5gemma_backbone_test.py new file mode 100644 index 0000000000..5c5bfe8229 --- /dev/null +++ b/keras_hub/src/models/t5gemma/t5gemma_backbone_test.py @@ -0,0 +1,105 @@ +import keras +import pytest + +from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone +from keras_hub.src.tests.test_case import TestCase + + +class T5GemmaBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "vocabulary_size": 100, + "encoder_hidden_dim": 32, + "encoder_intermediate_dim": 64, + "encoder_num_layers": 2, + "encoder_num_attention_heads": 4, + "encoder_num_key_value_heads": 2, + "encoder_head_dim": 8, + "encoder_layer_types": ["sliding_attention", "full_attention"], + "decoder_hidden_dim": 32, + "decoder_intermediate_dim": 64, + "decoder_num_layers": 2, + "decoder_num_attention_heads": 4, + "decoder_num_key_value_heads": 2, + "decoder_head_dim": 8, + "decoder_layer_types": ["sliding_attention", "full_attention"], + "dropout_rate": 0.1, + "rms_norm_eps": 1e-6, + "tie_word_embeddings": True, + "query_pre_attn_scalar": 1.0, + "attention_bias": False, + "hidden_activation": "gelu_approximate", + "sliding_window": 16, + "cross_attention_hidden_size": 32, + "attn_logit_softcapping": 50.0, + "rope_max_wavelength": 10000.0, + "initializer_range": 0.04, + "attention_dropout": 0.1, + } + self.input_data = { + "encoder_token_ids": keras.ops.ones((2, 16), dtype="int32"), + "encoder_padding_mask": keras.ops.ones((2, 16), dtype="int32"), + "decoder_token_ids": keras.ops.ones((2, 16), dtype="int32"), + "decoder_padding_mask": keras.ops.ones((2, 16), dtype="int32"), + } + + def test_backbone_basics(self): + self.run_backbone_test( + cls=T5GemmaBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape={ + "encoder_sequence_output": (2, 16, 32), + "decoder_sequence_output": (2, 16, 32), + }, + ) + + def test_asymmetrical_backbone(self): + asym_kwargs = { + "vocabulary_size": 100, + "encoder_hidden_dim": 48, + "encoder_intermediate_dim": 96, + "encoder_num_layers": 3, + "encoder_num_attention_heads": 6, + "encoder_num_key_value_heads": 3, + "encoder_head_dim": 8, + "encoder_layer_types": ["full_attention"] * 3, + "decoder_hidden_dim": 32, + "decoder_intermediate_dim": 64, + "decoder_num_layers": 2, + "decoder_num_attention_heads": 4, + "decoder_num_key_value_heads": 2, + "decoder_head_dim": 8, + "decoder_layer_types": ["sliding_attention", "full_attention"], + "sliding_window": 16, + "dropout_rate": 0.1, + "rms_norm_eps": 1e-6, + "tie_word_embeddings": True, + "cross_attention_hidden_size": 48, + } + self.run_backbone_test( + cls=T5GemmaBackbone, + init_kwargs=asym_kwargs, + input_data=self.input_data, + expected_output_shape={ + "encoder_sequence_output": (2, 16, 48), + "decoder_sequence_output": (2, 16, 32), + }, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=T5GemmaBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + 
@pytest.mark.extra_large + def test_all_presets(self): + for preset in T5GemmaBackbone.presets: + self.run_preset_test( + cls=T5GemmaBackbone, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/t5gemma/t5gemma_decoder.py b/keras_hub/src/models/t5gemma/t5gemma_decoder.py new file mode 100644 index 0000000000..fb9fb6950d --- /dev/null +++ b/keras_hub/src/models/t5gemma/t5gemma_decoder.py @@ -0,0 +1,355 @@ +import keras + +from keras_hub.src.models.gemma.rms_normalization import RMSNormalization +from keras_hub.src.models.t5gemma.t5gemma_attention import T5GemmaAttention +from keras_hub.src.models.t5gemma.t5gemma_layers import T5GemmaMLP + + +class T5GemmaDecoderLayer(keras.layers.Layer): + """Decoder layer for the T5Gemma model. + + This layer implements a single decoder block in the T5Gemma architecture, + comprising self-attention, cross-attention, and a feed-forward network + (MLP). + + Args: + hidden_size: int, The dimensionality of the hidden states. + rms_norm_eps: float, The epsilon value for RMS normalization. + num_attention_heads: int, The number of attention heads in + self-attention and cross-attention. + num_key_value_heads: int, The number of key-value heads for grouped + query attention. + query_pre_attn_scalar: float, Scalar to multiply queries by before + attention. + attention_bias: bool, Whether to include bias in attention computations. + intermediate_size: int, The intermediate size of the feed-forward + network. + hidden_activation: str, The activation function used in the feed-forward + network. + dropout_rate: float, The dropout rate applied after attention and MLP. + head_dim: int, The dimensionality of each attention head. + initializer_range: float, The range for the random normal initializer. + attention_dropout: float, The dropout rate applied to attention weights. + layer_type: str, Type of attention layer, e.g., `"sliding_attention"`. + cross_attention_hidden_size: int, optional, The hidden size for + cross-attention. If None, it defaults to `hidden_size`. Defaults to + `None`. + attn_logit_softcapping: float, optional, The softcapping value for + attention logits. Defaults to `None`. + sliding_window: int, optional, The window size for sliding attention. + Required if `layer_type` is `"sliding_attention"`. Defaults to + `None`. + rope_max_wavelength: float, The maximum wavelength for Rotary + Positional Embeddings. Defaults to `10000.0`. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for model computations and weights. Defaults to `None`. + **kwargs: Additional keyword arguments passed to the parent class. 
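+
+    The layer's `call` expects an `(hidden_states, encoder_hidden_states)`
+    pair as `inputs` and returns a `(hidden_states, (self_attention_cache,
+    cross_attention_cache))` tuple.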
+ """ + + def __init__( + self, + hidden_size, + rms_norm_eps, + num_attention_heads, + num_key_value_heads, + query_pre_attn_scalar, + attention_bias, + intermediate_size, + hidden_activation, + dropout_rate, + head_dim, + initializer_range, + attention_dropout, + layer_type, + cross_attention_hidden_size=None, + attn_logit_softcapping=None, + sliding_window=None, + rope_max_wavelength=10000.0, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.head_dim = head_dim + self.hidden_size = hidden_size + self.rms_norm_eps = rms_norm_eps + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.query_pre_attn_scalar = query_pre_attn_scalar + self.attention_bias = attention_bias + self.intermediate_size = intermediate_size + self.hidden_activation = hidden_activation + self.dropout_rate = dropout_rate + self.initializer_range = initializer_range + self.attention_dropout = attention_dropout + self.layer_type = layer_type + self.sliding_window = sliding_window + self.rope_max_wavelength = rope_max_wavelength + self.cross_attention_hidden_size = cross_attention_hidden_size + self.attn_logit_softcapping = attn_logit_softcapping + if ( + self.layer_type == "sliding_attention" + and self.sliding_window is None + ): + raise ValueError( + "`sliding_window` must be set for `sliding_attention` layer " + "type." + ) + + # Self-attention. + self.self_attn = T5GemmaAttention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + query_pre_attn_scalar=query_pre_attn_scalar, + attention_bias=attention_bias, + head_dim=self.head_dim, + attention_type="self", + initializer_range=initializer_range, + attention_dropout=attention_dropout, + attn_logit_softcapping=attn_logit_softcapping, + rope_max_wavelength=self.rope_max_wavelength, + dtype=self.dtype_policy, + name="self_attention", + ) + self.pre_self_attn_layernorm = RMSNormalization( + epsilon=rms_norm_eps, + dtype=self.dtype_policy, + name="decoder_pre_self_attention_layernorm", + ) + self.post_self_attn_layernorm = RMSNormalization( + epsilon=rms_norm_eps, + dtype=self.dtype_policy, + name="decoder_post_self_attention_layernorm", + ) + + # Cross-attention. + self.cross_attn = T5GemmaAttention( + hidden_size=hidden_size, + cross_attention_hidden_size=cross_attention_hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + query_pre_attn_scalar=query_pre_attn_scalar, + attention_bias=attention_bias, + head_dim=self.head_dim, + attention_type="cross", + initializer_range=initializer_range, + attention_dropout=attention_dropout, + attn_logit_softcapping=attn_logit_softcapping, + dtype=self.dtype_policy, + name="cross_attention", + ) + self.pre_cross_attn_layernorm = RMSNormalization( + epsilon=rms_norm_eps, + dtype=self.dtype_policy, + name="decoder_pre_cross_attention_layernorm", + ) + self.post_cross_attn_layernorm = RMSNormalization( + epsilon=rms_norm_eps, + dtype=self.dtype_policy, + name="decoder_post_cross_attention_layernorm", + ) + + # MLP. 
+ self.mlp = T5GemmaMLP( + hidden_size, + intermediate_size, + hidden_activation, + dropout_rate, + initializer_range=initializer_range, + dtype=self.dtype_policy, + name="mlp", + ) + self.pre_feedforward_layernorm = RMSNormalization( + epsilon=rms_norm_eps, + dtype=self.dtype_policy, + name="decoder_pre_feedforward_layernorm", + ) + self.post_feedforward_layernorm = RMSNormalization( + epsilon=rms_norm_eps, + dtype=self.dtype_policy, + name="decoder_post_feedforward_layernorm", + ) + + self.dropout = keras.layers.Dropout( + dropout_rate, + dtype=self.dtype_policy, + name="decoder_residual_dropout", + ) + + def build(self, input_shape): + hidden_states_shape, encoder_hidden_states_shape = input_shape + self.pre_self_attn_layernorm.build(hidden_states_shape) + current_shape = hidden_states_shape + self.self_attn.build(current_shape) + attn_output_shape, _ = self.self_attn.compute_output_shape( + current_shape + ) + self.post_self_attn_layernorm.build(attn_output_shape) + current_shape = attn_output_shape + self.dropout.build(current_shape) + self.pre_cross_attn_layernorm.build(current_shape) + self.cross_attn.build([current_shape, encoder_hidden_states_shape]) + attn_output_shape, _ = self.cross_attn.compute_output_shape( + [current_shape, encoder_hidden_states_shape] + ) + self.post_cross_attn_layernorm.build(attn_output_shape) + current_shape = attn_output_shape + self.pre_feedforward_layernorm.build(current_shape) + self.mlp.build(current_shape) + mlp_output_shape = self.mlp.compute_output_shape(current_shape) + self.post_feedforward_layernorm.build(mlp_output_shape) + self.built = True + + def _make_self_attention_mask( + self, + hidden_states, + padding_mask, + cache=None, + cache_update_index=None, + ): + if cache is not None: + q_len = keras.ops.shape(hidden_states)[1] + kv_len = keras.ops.shape(cache)[2] + q_indices = ( + keras.ops.arange(0, q_len, dtype="int32") + cache_update_index + ) + kv_indices = keras.ops.arange(0, kv_len, dtype="int32") + else: + q_len = kv_len = keras.ops.shape(hidden_states)[1] + q_indices = keras.ops.arange(0, q_len, dtype="int32") + kv_indices = keras.ops.arange(0, kv_len, dtype="int32") + # Create the causal mask. + causal_mask = kv_indices[None, :] <= q_indices[:, None] + # Apply sliding window if applicable. + if self.layer_type == "sliding_attention": + sliding_mask = ( + q_indices[:, None] - self.sliding_window + ) <= kv_indices[None, :] + causal_mask = keras.ops.logical_and(causal_mask, sliding_mask) + # Combine with padding mask. + final_mask = causal_mask[None, None, :, :] + if padding_mask is not None: + padding_mask_slice = padding_mask[:, :kv_len] + padding_mask_4d = padding_mask_slice[:, None, None, :] + final_mask = keras.ops.logical_and(final_mask, padding_mask_4d) + return (1.0 - keras.ops.cast(final_mask, hidden_states.dtype)) * -1e9 + + def _make_cross_attention_mask(self, hidden_states, padding_mask): + if padding_mask is None: + return None + bidirectional_mask = padding_mask[:, None, None, :] + additive_bidirectional_mask = ( + 1.0 - keras.ops.cast(bidirectional_mask, hidden_states.dtype) + ) * -1e9 + return additive_bidirectional_mask + + def call( + self, + inputs, + self_attention_padding_mask=None, + cross_attention_padding_mask=None, + cache=None, + cache_update_index=None, + training=None, + ): + hidden_states, encoder_hidden_states = inputs + self_attention_cache, cross_attention_cache = ( + cache if cache is not None else (None, None) + ) + # Self Attention. 
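+        # Pre-norm residual block: build the causal (and, for
+        # `sliding_attention` layers, sliding-window) mask, normalize, run
+        # cached self-attention, then post-norm and dropout before the
+        # residual add.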
+ residual = hidden_states + self_attention_mask = self._make_self_attention_mask( + hidden_states, + self_attention_padding_mask, + cache=self_attention_cache, + cache_update_index=cache_update_index, + ) + hidden_states = self.pre_self_attn_layernorm(hidden_states) + hidden_states, updated_self_attention_cache = self.self_attn( + inputs=hidden_states, + attention_mask=self_attention_mask, + cache=self_attention_cache, + cache_update_index=cache_update_index, + training=training, + ) + hidden_states = self.post_self_attn_layernorm(hidden_states) + hidden_states = residual + self.dropout( + hidden_states, training=training + ) + + # Cross Attention. + residual = hidden_states + cross_attention_mask = self._make_cross_attention_mask( + encoder_hidden_states, cross_attention_padding_mask + ) + hidden_states = self.pre_cross_attn_layernorm(hidden_states) + hidden_states, updated_cross_attention_cache = self.cross_attn( + inputs=[hidden_states, encoder_hidden_states], + attention_mask=cross_attention_mask, + cache=cross_attention_cache, + training=training, + ) + + hidden_states = self.post_cross_attn_layernorm(hidden_states) + hidden_states = residual + self.dropout( + hidden_states, training=training + ) + + # MLP. + residual = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states, training=training) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + self.dropout( + hidden_states, training=training + ) + updated_cache = ( + updated_self_attention_cache, + updated_cross_attention_cache, + ) + return hidden_states, updated_cache + + def compute_output_shape(self, input_shape): + hidden_states_shape, encoder_hidden_states_shape = input_shape + batch_size, dec_seq_len, _ = hidden_states_shape + _, enc_seq_len, _ = encoder_hidden_states_shape + self_cache_shape = ( + batch_size, + 2, + dec_seq_len, + self.num_key_value_heads, + self.head_dim, + ) + cross_cache_shape = ( + batch_size, + 2, + enc_seq_len, + self.num_key_value_heads, + self.head_dim, + ) + return hidden_states_shape, (self_cache_shape, cross_cache_shape) + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "rms_norm_eps": self.rms_norm_eps, + "num_attention_heads": self.num_attention_heads, + "num_key_value_heads": self.num_key_value_heads, + "query_pre_attn_scalar": self.query_pre_attn_scalar, + "attention_bias": self.attention_bias, + "intermediate_size": self.intermediate_size, + "hidden_activation": self.hidden_activation, + "dropout_rate": self.dropout_rate, + "initializer_range": self.initializer_range, + "attention_dropout": self.attention_dropout, + "layer_type": self.layer_type, + "sliding_window": self.sliding_window, + "rope_max_wavelength": self.rope_max_wavelength, + "head_dim": self.head_dim, + "cross_attention_hidden_size": self.cross_attention_hidden_size, + "attn_logit_softcapping": self.attn_logit_softcapping, + } + ) + return config diff --git a/keras_hub/src/models/t5gemma/t5gemma_encoder.py b/keras_hub/src/models/t5gemma/t5gemma_encoder.py new file mode 100644 index 0000000000..d17a3e880c --- /dev/null +++ b/keras_hub/src/models/t5gemma/t5gemma_encoder.py @@ -0,0 +1,214 @@ +import keras + +from keras_hub.src.models.gemma.rms_normalization import RMSNormalization +from keras_hub.src.models.t5gemma.t5gemma_attention import T5GemmaAttention +from keras_hub.src.models.t5gemma.t5gemma_layers import T5GemmaMLP + + +class 
T5GemmaEncoderLayer(keras.layers.Layer): + """Encoder layer for the T5Gemma model. + + This layer implements a single encoder block in the T5Gemma architecture, + comprising self-attention and a feed-forward network (MLP). + + Args: + hidden_size: int, The dimensionality of the hidden states. + rms_norm_eps: float, The epsilon value for RMS normalization. + num_attention_heads: int, The number of attention heads in + self-attention. + num_key_value_heads: int, The number of key-value heads for grouped + query attention. + query_pre_attn_scalar: float, Scalar to multiply queries by before + attention. + attention_bias: bool, Whether to include bias in attention computations. + intermediate_size: int, The intermediate size of the feed-forward + network. + hidden_activation: str, The activation function used in the feed-forward + network. + dropout_rate: float, The dropout rate applied after attention and MLP. + initializer_range: float, The range for the random normal initializer. + attention_dropout: float, The dropout rate applied to attention weights. + layer_type: str, Type of attention layer, e.g., `"sliding_attention"`. + head_dim: int, The dimensionality of each attention head. + attn_logit_softcapping: float, optional, The softcapping value for + attention logits. Defaults to `None`. + sliding_window: int, optional, The window size for sliding attention. + Required if `layer_type` is `"sliding_attention"`. Defaults to + `None`. + rope_max_wavelength: float, The maximum wavelength for Rotary Positional + Embeddings. Defaults to `10000.0`. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for model computations and weights. Defaults to `None`. + **kwargs: Additional keyword arguments passed to the parent class. + """ + + def __init__( + self, + hidden_size, + rms_norm_eps, + num_attention_heads, + num_key_value_heads, + query_pre_attn_scalar, + attention_bias, + intermediate_size, + hidden_activation, + dropout_rate, + initializer_range, + attention_dropout, + layer_type, + head_dim, + attn_logit_softcapping=None, + sliding_window=None, + rope_max_wavelength=10000.0, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_size = hidden_size + self.rms_norm_eps = rms_norm_eps + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.query_pre_attn_scalar = query_pre_attn_scalar + self.attention_bias = attention_bias + self.intermediate_size = intermediate_size + self.hidden_activation = hidden_activation + self.dropout_rate = dropout_rate + self.initializer_range = initializer_range + self.attention_dropout = attention_dropout + self.layer_type = layer_type + self.sliding_window = sliding_window + self.rope_max_wavelength = rope_max_wavelength + self.head_dim = head_dim + self.attn_logit_softcapping = attn_logit_softcapping + if ( + self.layer_type == "sliding_attention" + and self.sliding_window is None + ): + raise ValueError( + "`sliding_window` must be set for `sliding_attention` layer " + "type." 
+ ) + self.self_attn = T5GemmaAttention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + query_pre_attn_scalar=query_pre_attn_scalar, + attention_bias=attention_bias, + head_dim=self.head_dim, + attention_type="self", + initializer_range=initializer_range, + attention_dropout=attention_dropout, + attn_logit_softcapping=attn_logit_softcapping, + rope_max_wavelength=self.rope_max_wavelength, + dtype=self.dtype_policy, + name="self_attention", + ) + self.pre_self_attn_layernorm = RMSNormalization( + epsilon=rms_norm_eps, + dtype=self.dtype_policy, + name="pre_self_attention_layernorm", + ) + self.post_self_attn_layernorm = RMSNormalization( + epsilon=rms_norm_eps, + dtype=self.dtype_policy, + name="post_self_attention_layernorm", + ) + + self.mlp = T5GemmaMLP( + hidden_size, + intermediate_size, + hidden_activation, + dropout_rate, + initializer_range=initializer_range, + dtype=self.dtype_policy, + name="mlp", + ) + self.pre_feedforward_layernorm = RMSNormalization( + epsilon=rms_norm_eps, + dtype=self.dtype_policy, + name="pre_feedforward_layernorm", + ) + self.post_feedforward_layernorm = RMSNormalization( + epsilon=rms_norm_eps, + dtype=self.dtype_policy, + name="post_feedforward_layernorm", + ) + self.dropout = keras.layers.Dropout( + dropout_rate, + dtype=self.dtype_policy, + name="residual_dropout", + ) + + def build(self, input_shape): + self.pre_self_attn_layernorm.build(input_shape) + self.self_attn.build(input_shape) + attn_output_shape, _ = self.self_attn.compute_output_shape(input_shape) + self.post_self_attn_layernorm.build(attn_output_shape) + self.dropout.build(attn_output_shape) + self.pre_feedforward_layernorm.build(attn_output_shape) + self.mlp.build(attn_output_shape) + mlp_output_shape = self.mlp.compute_output_shape(attn_output_shape) + self.post_feedforward_layernorm.build(mlp_output_shape) + self.built = True + + def _make_attention_mask(self, hidden_states, padding_mask): + attention_mask = padding_mask[:, None, None, :] + additive_mask = ( + 1.0 - keras.ops.cast(attention_mask, hidden_states.dtype) + ) * -1e9 + return additive_mask + + def call( + self, + hidden_states, + padding_mask=None, + training=None, + ): + residual = hidden_states + attention_mask = self._make_attention_mask(hidden_states, padding_mask) + hidden_states = self.pre_self_attn_layernorm(hidden_states) + hidden_states, _ = self.self_attn( + inputs=hidden_states, + attention_mask=attention_mask, + training=training, + ) + hidden_states = self.post_self_attn_layernorm(hidden_states) + hidden_states = residual + self.dropout( + hidden_states, training=training + ) + residual = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states, training=training) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + self.dropout( + hidden_states, training=training + ) + return hidden_states + + def compute_output_shape(self, input_shape): + # Isometric. 
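+        # The encoder layer preserves its input shape,
+        # `(batch, seq_len, hidden_size)`.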
+ return input_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "rms_norm_eps": self.rms_norm_eps, + "head_dim": self.head_dim, + "num_attention_heads": self.num_attention_heads, + "num_key_value_heads": self.num_key_value_heads, + "query_pre_attn_scalar": self.query_pre_attn_scalar, + "attention_bias": self.attention_bias, + "intermediate_size": self.intermediate_size, + "hidden_activation": self.hidden_activation, + "dropout_rate": self.dropout_rate, + "initializer_range": self.initializer_range, + "attention_dropout": self.attention_dropout, + "layer_type": self.layer_type, + "sliding_window": self.sliding_window, + "rope_max_wavelength": self.rope_max_wavelength, + "attn_logit_softcapping": self.attn_logit_softcapping, + } + ) + return config diff --git a/keras_hub/src/models/t5gemma/t5gemma_layers.py b/keras_hub/src/models/t5gemma/t5gemma_layers.py new file mode 100644 index 0000000000..1a9d18f186 --- /dev/null +++ b/keras_hub/src/models/t5gemma/t5gemma_layers.py @@ -0,0 +1,118 @@ +import keras + +from keras_hub.src.utils.keras_utils import clone_initializer + + +def t5gemma_kernel_initializer(initializer_range=0.01): + """Creates a RandomNormal initializer for T5Gemma kernels. + + Args: + initializer_range: float, The standard deviation of the normal + distribution. Defaults to `0.01`. + + Returns: + keras.initializers.RandomNormal: A Keras RandomNormal initializer. + """ + return keras.initializers.RandomNormal(mean=0.0, stddev=initializer_range) + + +class T5GemmaMLP(keras.layers.Layer): + """Multilayer Perceptron (MLP) block for the T5Gemma model. + + This layer implements the feed-forward part of a transformer block, + consisting of two dense layers with a GELU activation and dropout. + + Args: + hidden_size: int, The dimensionality of the input and output hidden + states. + intermediate_size: int, The dimensionality of the intermediate layer. + hidden_activation: str, The activation function to use, e.g., + "gelu_approximate". + dropout_rate: float, The dropout rate applied to the intermediate + hidden states. + initializer_range: float, The range for the random normal initializer + for kernel weights. Defaults to `0.02`. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for model computations and weights. Defaults to `None`. + **kwargs: Additional keyword arguments passed to the parent class. 
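+
+    Note that the block is gated: `gate_proj` and `up_proj` both project to
+    `intermediate_size`, and `down_proj` maps the gated product back to
+    `hidden_size`.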
+ """ + + def __init__( + self, + hidden_size, + intermediate_size, + hidden_activation, + dropout_rate, + initializer_range=0.02, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.hidden_activation = hidden_activation + self.dropout_rate = dropout_rate + self.initializer_range = initializer_range + self.kernel_initializer = t5gemma_kernel_initializer(initializer_range) + + self.gate_proj = keras.layers.Dense( + self.intermediate_size, + use_bias=False, + kernel_initializer=clone_initializer(self.kernel_initializer), + dtype=self.dtype_policy, + name="gate_proj", + ) + self.up_proj = keras.layers.Dense( + self.intermediate_size, + use_bias=False, + kernel_initializer=clone_initializer(self.kernel_initializer), + dtype=self.dtype_policy, + name="up_proj", + ) + self.down_proj = keras.layers.Dense( + self.hidden_size, + use_bias=False, + kernel_initializer=clone_initializer(self.kernel_initializer), + dtype=self.dtype_policy, + name="down_proj", + ) + if self.hidden_activation == "gelu_approximate": + # NOTE: `gelu_pytorch_tanh` is the same as `gelu(approximate=True)`. + self.act_fn = lambda x: keras.activations.gelu(x, approximate=True) + else: + self.act_fn = keras.activations.get(self.hidden_activation) + self.dropout = keras.layers.Dropout( + self.dropout_rate, + dtype=self.dtype_policy, + name="mlp_dropout", + ) + + def build(self, input_shape): + self.gate_proj.build(input_shape) + self.up_proj.build(input_shape) + intermediate_shape = self.gate_proj.compute_output_shape(input_shape) + self.dropout.build(intermediate_shape) + self.down_proj.build(intermediate_shape) + self.built = True + + def compute_output_shape(self, input_shape): + return input_shape + + def call(self, x, training=None): + hidden_states = self.act_fn(self.gate_proj(x)) * self.up_proj(x) + hidden_states = self.dropout(hidden_states, training=training) + down_proj = self.down_proj(hidden_states) + return down_proj + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_size": self.hidden_size, + "intermediate_size": self.intermediate_size, + "hidden_activation": self.hidden_activation, + "dropout_rate": self.dropout_rate, + "initializer_range": self.initializer_range, + } + ) + return config diff --git a/keras_hub/src/models/t5gemma/t5gemma_presets.py b/keras_hub/src/models/t5gemma/t5gemma_presets.py new file mode 100644 index 0000000000..25bfb8465b --- /dev/null +++ b/keras_hub/src/models/t5gemma/t5gemma_presets.py @@ -0,0 +1,15 @@ +# Metadata for loading pretrained model weights. +backbone_presets = { + "t5gemma_b_b_prefixlm_it": { + "metadata": { + "description": ( + "T5Gemma B/B model with a base encoder and base decoder, " + "adapted as a prefix language model and fine-tuned for " + "instruction following." 
+ ), + "params": 591490560, + "path": "t5gemma", + }, + "kaggle_handle": "kaggle://harshaljanjani/t5gemma/keras/t5gemma_b_b_prefixlm_it", + }, +} diff --git a/keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm.py b/keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm.py new file mode 100644 index 0000000000..799080c9dc --- /dev/null +++ b/keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm.py @@ -0,0 +1,442 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM +from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone +from keras_hub.src.models.t5gemma.t5gemma_seq_2_seq_lm_preprocessor import ( + T5GemmaSeq2SeqLMPreprocessor, +) +from keras_hub.src.utils.tensor_utils import any_equal + + +@keras_hub_export("keras_hub.models.T5GemmaSeq2SeqLM") +class T5GemmaSeq2SeqLM(Seq2SeqLM): + """An end-to-end T5Gemma model for seq2seq language modeling. + + A seq2seq language model (LM) is an encoder-decoder model which is used for + conditional text generation. The encoder is given a "context" text (fed to + the encoder), and the decoder predicts the next token based on both the + encoder inputs and the previous tokens. You can finetune `T5GemmaSeq2SeqLM` + to generate text for any seq2seq task (e.g., translation or summarization). + + This model has a `generate()` method, which generates text based on a + prompt. The generation strategy used is controlled by an additional + `sampler` argument on `compile()`. You can recompile the model with + different `keras_hub.samplers` objects to control the generation. By + default, `"greedy"` sampling will be used. + + This model can optionally be configured with a `preprocessor` layer, in + which case it will automatically apply preprocessing to string inputs during + `fit()`, `predict()`, `evaluate()` and `generate()`. This is done by default + when creating the model with `from_preset()`. + + Args: + backbone: A `keras_hub.models.T5GemmaBackbone` instance. + preprocessor: A `keras_hub.models.T5GemmaSeq2SeqLMPreprocessor` or + `None`. If `None`, this model will not apply preprocessing, and + inputs should be preprocessed before calling the model. Defaults + to `None`. + + Examples: + + Use `generate()` to do text generation. + ```python + import numpy as np + t5gemma_lm = keras_hub.models.T5GemmaSeq2SeqLM.from_preset( + "t5gemma_b_b_prefixlm_it" + ) + # Generate with encoder-only input. + t5gemma_lm.generate("The quick brown fox jumped.", max_length=30) + + # Generate with batched encoder-only inputs. + t5gemma_lm.generate( + ["The quick brown fox jumped.", "The whale."], + max_length=30 + ) + # Generate with encoder and decoder inputs. + t5gemma_lm.generate( + { + "encoder_text": "The quick brown fox jumped.", + "decoder_text": "A fast fox" + }, + max_length=30 + ) + ``` + + Compile the `generate()` function with a custom sampler. + ```python + t5gemma_lm = keras_hub.models.T5GemmaSeq2SeqLM.from_preset( + "t5gemma_b_b_prefixlm_it" + ) + t5gemma_lm.compile(sampler="top_k") + t5gemma_lm.generate("I want to say", max_length=30) + + t5gemma_lm.compile(sampler=keras_hub.samplers.BeamSampler(num_beams=2)) + t5gemma_lm.generate("I want to say", max_length=30) + ``` + + Use `generate()` without preprocessing. + ```python + # Preprocessed inputs, with encoder inputs corresponding to + # "The quick brown fox", and the decoder inputs to "A fast fox". + # Use `"padding_mask"` to indicate values that should not be overridden. 
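+    # In these masks, `1` marks prompt positions to keep as-is and `0` marks
+    # positions that `generate()` may fill in.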
+ prompt = { + "encoder_token_ids": np.array([[2, 10, 133, 2119, 6219, 23602, 1, 0]]), + "encoder_padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 0]]), + "decoder_token_ids": np.array([[2, 133, 1769, 1, 0, 0, 0]]), + "decoder_padding_mask": np.array([[1, 1, 1, 1, 0, 0, 0]]) + } + + t5gemma_lm = keras_hub.models.T5GemmaSeq2SeqLM.from_preset( + "t5gemma_b_b_prefixlm_it", + preprocessor=None, + ) + t5gemma_lm.generate(prompt) + ``` + + Call `fit()` on a single batch. + ```python + features = { + "encoder_text": ["The quick fox jumped.", "I forgot my homework."], + "decoder_text": ["The fast hazel fox leapt.", "I forgot my assignment."] + } + t5gemma_lm = keras_hub.models.T5GemmaSeq2SeqLM.from_preset( + "t5gemma_b_b_prefixlm_it" + ) + t5gemma_lm.fit(x=features, batch_size=2) + ``` + + Call `fit()` without preprocessing. + ```python + x = { + "encoder_token_ids": np.array([[2, 133, 2119, 1, 0]] * 2), + "encoder_padding_mask": np.array([[1, 1, 1, 1, 0]] * 2), + "decoder_token_ids": np.array([[2, 133, 1769, 1, 0]] * 2), + "decoder_padding_mask": np.array([[1, 1, 1, 1, 1]] * 2), + } + y = np.array([[133, 1769, 1, 0, 0]] * 2) + sw = np.array([[1, 1, 1, 0, 0]] * 2) + + t5gemma_lm = keras_hub.models.T5GemmaSeq2SeqLM.from_preset( + "t5gemma_b_b_prefixlm_it", + preprocessor=None, + ) + t5gemma_lm.fit(x=x, y=y, sample_weight=sw, batch_size=2) + ``` + + Custom backbone and vocabulary. + ```python + features = { + "encoder_text": ["The quick fox jumped.", "I forgot my homework."], + "decoder_text": ["The fast hazel fox leapt.", "I forgot my assignment."] + } + tokenizer = keras_hub.models.T5GemmaTokenizer( + proto="proto.spm", + ) + preprocessor = keras_hub.models.T5GemmaSeq2SeqLMPreprocessor( + tokenizer=tokenizer, + encoder_sequence_length=128, + decoder_sequence_length=128, + ) + backbone = keras_hub.models.T5GemmaBackbone( + vocabulary_size=32000, + # Encoder parameters. + encoder_hidden_dim=256, + encoder_intermediate_dim=512, + encoder_num_layers=4, + encoder_num_attention_heads=4, + encoder_num_key_value_heads=2, + encoder_head_dim=64, + encoder_layer_types=["full_attention"] * 4, + # Decoder parameters. + decoder_hidden_dim=256, + decoder_intermediate_dim=512, + decoder_num_layers=4, + decoder_num_attention_heads=4, + decoder_num_key_value_heads=2, + decoder_head_dim=64, + decoder_layer_types=["full_attention"] * 4, + # Common parameters. + dropout_rate=0.1, + rms_norm_eps=1e-6, + query_pre_attn_scalar=1.0, + attention_bias=False, + hidden_activation="gelu_approximate", + ) + t5gemma_lm = keras_hub.models.T5GemmaSeq2SeqLM( + backbone=backbone, + preprocessor=preprocessor, + ) + t5gemma_lm.fit(x=features, batch_size=2) + ``` + """ + + backbone_cls = T5GemmaBackbone + preprocessor_cls = T5GemmaSeq2SeqLMPreprocessor + + def __init__(self, backbone, preprocessor=None, **kwargs): + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + + # === Functional Model === + # This must be "backbone.input" i.e. the full input structure, + # rather than "backbone.inputs" which is the flattened list of inputs. 
+ inputs = backbone.input + sequence_output = backbone(inputs)["decoder_sequence_output"] + logits = backbone.decoder_token_embedding(sequence_output, reverse=True) + if self.backbone.final_logit_softcapping is not None: + logits = logits / self.backbone.final_logit_softcapping + logits = keras.ops.tanh(logits) + logits = logits * self.backbone.final_logit_softcapping + super().__init__( + inputs=inputs, + outputs=logits, + **kwargs, + ) + + def call_encoder(self, token_ids, padding_mask): + """Process inputs through the encoder stack.""" + encoder_embeddings = self.backbone.token_embedding(token_ids) + encoder_embeddings *= keras.ops.cast( + keras.ops.sqrt(self.backbone.encoder_hidden_dim), + encoder_embeddings.dtype, + ) + encoder_hidden_states = self.backbone.encoder_dropout( + encoder_embeddings, training=False + ) + for layer in self.backbone.encoder_layers: + encoder_hidden_states = layer( + encoder_hidden_states, padding_mask=padding_mask, training=False + ) + encoder_output = self.backbone.encoder_norm(encoder_hidden_states) + encoder_output = self.backbone.encoder_dropout( + encoder_output, training=False + ) + return encoder_output, padding_mask + + def call_decoder_with_cache( + self, + decoder_token_ids, + decoder_padding_mask, + cache, + cache_update_index, + encoder_output, + encoder_padding_mask, + ): + """Forward pass of `T5GemmaSeq2SeqLM`'s decoder with cache. + + `call_decoder_with_cache` adds an additional forward pass for the model + for autoregressive inference. Unlike calling the model directly, this + method allows caching previous key/value Tensors in the attention + layers, and avoids recomputing the outputs of seen tokens. + + Args: + decoder_token_ids: A dense int Tensor with shape + `(batch_size, max_length)`. The token ids for the decoder. + decoder_padding_mask: A dense int Tensor with shape `(batch_size, + max_length)`. The padding mask for the decoder. + cache: A dense float Tensor, the cache of key and value states. + cache_update_index: int, or int Tensor. The index of the current + token being processed in the whole sequence. + encoder_output: A dense float Tensor. The output of the encoder. + encoder_padding_mask: A dense int Tensor. The padding mask for + the encoder output. + + Returns: + A `(logits, hidden_states, cache)` tuple. Where `logits` is the + language model logits for the input token_ids, `hidden_states` is + the final hidden representation of the input tokens, and `cache` is + the updated decoding cache. + """ + self_attention_cache, cross_attention_cache = cache + hidden_states = self.backbone.decoder_token_embedding(decoder_token_ids) + hidden_states *= keras.ops.cast( + keras.ops.sqrt(self.backbone.decoder_hidden_dim), + hidden_states.dtype, + ) + hidden_states = self.backbone.decoder_dropout( + hidden_states, training=False + ) + # Every decoder layer has a separate cache for the self-attention layer + # and the cross-attention layer. We update all of them separately. + updated_self_attention_caches = [] + updated_cross_attention_caches = [] + for i, layer in enumerate(self.backbone.decoder_layers): + layer_self_cache = ( + self_attention_cache[:, i, ...] + if self_attention_cache is not None + else None + ) + layer_cross_cache = ( + cross_attention_cache[:, i, ...] 
+                if cross_attention_cache is not None
+                else None
+            )
+            layer_cache = (layer_self_cache, layer_cross_cache)
+            hidden_states, updated_layer_cache = layer(
+                (hidden_states, encoder_output),
+                self_attention_padding_mask=decoder_padding_mask,
+                cross_attention_padding_mask=encoder_padding_mask,
+                cache=layer_cache,
+                cache_update_index=cache_update_index,
+                training=False,
+            )
+            new_self_cache, new_cross_cache = updated_layer_cache
+            updated_self_attention_caches.append(new_self_cache)
+            updated_cross_attention_caches.append(new_cross_cache)
+        self_attention_cache = keras.ops.stack(
+            updated_self_attention_caches, axis=1
+        )
+        cross_attention_cache = keras.ops.stack(
+            updated_cross_attention_caches, axis=1
+        )
+        hidden_states = self.backbone.decoder_norm(hidden_states)
+        logits = self.backbone.decoder_token_embedding(
+            hidden_states, reverse=True
+        )
+        if self.backbone.final_logit_softcapping is not None:
+            logits = logits / self.backbone.final_logit_softcapping
+            logits = keras.ops.tanh(logits)
+            logits = logits * self.backbone.final_logit_softcapping
+        return (
+            logits,
+            hidden_states,
+            (self_attention_cache, cross_attention_cache),
+        )
+
+    def _build_cache(
+        self,
+        encoder_token_ids,
+        encoder_padding_mask,
+        decoder_token_ids,
+        decoder_padding_mask,
+    ):
+        """Build an empty cache for use with `call_with_cache()`."""
+        encoder_output, encoder_padding_mask = self.call_encoder(
+            encoder_token_ids, encoder_padding_mask
+        )
+        batch_size = keras.ops.shape(decoder_token_ids)[0]
+        num_layers = self.backbone.decoder_num_layers
+        num_kv_heads = self.backbone.decoder_num_key_value_heads
+        head_dim = self.backbone.decoder_head_dim
+        self_cache_shape = (
+            batch_size,
+            num_layers,
+            2,
+            keras.ops.shape(decoder_token_ids)[1],
+            num_kv_heads,
+            head_dim,
+        )
+        self_attention_cache = keras.ops.zeros(
+            self_cache_shape, dtype=self.compute_dtype
+        )
+        cross_attention_cache = None
+        _, hidden_states, cache = self.call_decoder_with_cache(
+            decoder_token_ids=decoder_token_ids,
+            decoder_padding_mask=decoder_padding_mask,
+            cache=(self_attention_cache, cross_attention_cache),
+            cache_update_index=0,
+            encoder_output=encoder_output,
+            encoder_padding_mask=encoder_padding_mask,
+        )
+        extra_cache_info = (encoder_output, encoder_padding_mask)
+        return hidden_states, cache, extra_cache_info
+
+    def generate_step(self, inputs, stop_token_ids=None):
+        """A compilable generation function for a single batch of inputs.
+
+        This function represents the inner, XLA-compilable, generation function
+        for a single batch of inputs. Inputs should have the same structure as
+        model inputs, a dictionary with keys `"encoder_token_ids"`,
+        `"encoder_padding_mask"`, `"decoder_token_ids"` and
+        `"decoder_padding_mask"`.
+
+        Args:
+            inputs: A dictionary with four keys - `"encoder_token_ids"`,
+                `"encoder_padding_mask"`, `"decoder_token_ids"` and
+                `"decoder_padding_mask"`, with batched tensor values.
+            stop_token_ids: Tuple of ids of end tokens to stop on. If all
+                sequences have produced a new stop token, generation
+                will stop.
+        """
+        encoder_token_ids = inputs["encoder_token_ids"]
+        encoder_padding_mask = inputs["encoder_padding_mask"]
+        decoder_token_ids = inputs["decoder_token_ids"]
+        decoder_padding_mask = inputs["decoder_padding_mask"]
+        # Create and seed cache with a single forward pass.
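+        # The self-attention cache has shape
+        # `(batch, num_layers, 2, max_length, num_kv_heads, head_dim)`; the
+        # cross-attention cache is built from `encoder_output` on this first
+        # pass and carried through all subsequent decoding steps.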
+ hidden_states, cache, extra_cache_info = self._build_cache( + encoder_token_ids=encoder_token_ids, + encoder_padding_mask=encoder_padding_mask, + decoder_token_ids=decoder_token_ids, + decoder_padding_mask=decoder_padding_mask, + ) + encoder_output, encoder_padding_mask = extra_cache_info + # Compute the lengths of all user inputted tokens ids. + row_lengths = keras.ops.sum( + keras.ops.cast(decoder_padding_mask, "int32"), axis=-1 + ) + # Start at the first index that has no user inputted id. + index = keras.ops.min(row_lengths) + + def next(prompt, cache, index): + # The cache index is the index of our previous token. + cache_update_index = index - 1 + batch_size = keras.ops.shape(prompt)[0] + prompt = keras.ops.slice( + prompt, [0, cache_update_index], [batch_size, 1] + ) + ( + logits, + _, + updated_cache, + ) = self.call_decoder_with_cache( + decoder_token_ids=prompt, + decoder_padding_mask=None, + cache_update_index=cache_update_index, + cache=cache, + encoder_output=encoder_output, + encoder_padding_mask=encoder_padding_mask, + ) + return keras.ops.squeeze(logits, axis=1), None, updated_cache + + decoder_token_ids = self.sampler( + next=next, + prompt=decoder_token_ids, + cache=cache, + index=index, + mask=decoder_padding_mask, + stop_token_ids=stop_token_ids, + hidden_states=hidden_states, + model=self, + ) + + # Compute an output padding mask with the token ids we updated. + if stop_token_ids is not None: + # Build a mask of `stop_token_ids` locations not in the original + # prompt (not in locations where `decoder_padding_mask` is True). + end_locations = any_equal( + decoder_token_ids, + stop_token_ids, + keras.ops.logical_not(decoder_padding_mask), + ) + # Use cumsum to get ones in all locations after end_locations. + end_locations = keras.ops.cast(end_locations, "int32") + cumsum = keras.ops.cast( + keras.ops.cumsum(end_locations, axis=-1), "int32" + ) + overflow = cumsum - end_locations + # Our padding mask is the inverse of these overflow locations. + decoder_padding_mask = keras.ops.logical_not( + keras.ops.cast(overflow, "bool") + ) + else: + # Without early stopping, all locations will have been updated. + decoder_padding_mask = keras.ops.ones_like( + decoder_token_ids, dtype="bool" + ) + + return { + "decoder_token_ids": decoder_token_ids, + "decoder_padding_mask": decoder_padding_mask, + } diff --git a/keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_preprocessor.py b/keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_preprocessor.py new file mode 100644 index 0000000000..1570d4796d --- /dev/null +++ b/keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_preprocessor.py @@ -0,0 +1,216 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor +from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone +from keras_hub.src.models.t5gemma.t5gemma_tokenizer import T5GemmaTokenizer +from keras_hub.src.utils.tensor_utils import preprocessing_function + +try: + import tensorflow as tf +except ImportError: + tf = None + + +@keras_hub_export("keras_hub.models.T5GemmaSeq2SeqLMPreprocessor") +class T5GemmaSeq2SeqLMPreprocessor(Seq2SeqLMPreprocessor): + """T5Gemma Seq2Seq LM preprocessor. + + This preprocessing layer is meant for use with + `keras_hub.models.T5GemmaSeq2SeqLM`. By default, it will take in batches of + strings, and return outputs in a `(x, y, sample_weight)` format, where the + `y` label is the next token id in the `x` sequence. 
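+    Labels are the decoder token ids shifted one step to the left, and the
+    shifted decoder padding mask is used as the sample weight.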
+ + For use with generation, the layer also exposes two methods + `generate_preprocess()` and `generate_postprocess()`. When this preprocessor + is attached to a `keras_hub.models.T5GemmaSeq2SeqLM` instance, these methods + will be called implicitly in `generate()`. They can also be called + standalone (e.g. to precompute preprocessing inputs for generation in a + separate process). + + Args: + tokenizer: A `keras_hub.models.T5GemmaTokenizer` instance. + encoder_sequence_length: The length of the packed encoder inputs. + decoder_sequence_length: The length of the packed decoder inputs. + add_start_token: If `True`, the preprocessor will prepend the + tokenizer start token to each input sequence. For T5Gemma models, + this should be `False`. Defaults to `False`. + add_end_token: If `True`, the preprocessor will append the tokenizer end + token to each input sequence. For T5Gemma models, this should be + `True`. Defaults to `True`. + + Call arguments: + x: A dictionary with two keys, `"encoder_text"` and `"decoder_text"`. + The values can be a string, a `tf.Tensor` or a list of python + strings. + y: Label data. Should always be `None` as the layer generates labels. + sample_weight: Label weights. Should always be `None` as the layer + generates label weights. + encoder_sequence_length: Pass to override the configured + `encoder_sequence_length` of the layer. + decoder_sequence_length: Pass to override the configured + `decoder_sequence_length` of the layer. + + Examples: + ```python + import tensorflow as tf + import numpy as np + + # Load the preprocessor from a preset. + preprocessor = keras_hub.models.T5GemmaSeq2SeqLMPreprocessor.from_preset( + "t5gemma_b_b_prefixlm_it" + ) + + # For example usage, see the dictionary example below which provides + # both encoder and decoder text. + # Tokenize a batch of sentences. + preprocessor(["The quick brown fox jumped.", "Call me Ishmael."]) + # Tokenize a dictionary with separate encoder and decoder inputs. + preprocessor({ + "encoder_text": "The quick brown fox jumped.", + "decoder_text": "The fast fox." + }) + + # Apply tokenization to a `tf.data.Dataset`. + encoder_features = tf.constant(["The quick brown fox.", "Call me Ishmael."]) + decoder_features = tf.constant(["The fast fox.", "I am Ishmael."]) + ds = tf.data.Dataset.from_tensor_slices( + {"encoder_text": encoder_features, "decoder_text": decoder_features} + ) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + + # Prepare tokens for generation. + preprocessor.generate_preprocess({ + "encoder_text": "The quick brown fox jumped.", + "decoder_text": "The fast fox." + }) + + # Map generation outputs back to strings. + preprocessor.generate_postprocess({ + 'decoder_token_ids': np.array([[2, 714, 4320, 8426, 25341, 1, 0, 0]]), + 'decoder_padding_mask': np.array([[1, 1, 1, 1, 1, 1, 0, 0]]), + }) + ``` + """ + + backbone_cls = T5GemmaBackbone + tokenizer_cls = T5GemmaTokenizer + + def __init__( + self, + tokenizer, + encoder_sequence_length=512, + decoder_sequence_length=512, + add_start_token=False, + add_end_token=True, + **kwargs, + ): + # Do not pass `add_start_token` and `add_end_token` to the base class. + super().__init__( + tokenizer=tokenizer, + encoder_sequence_length=encoder_sequence_length, + decoder_sequence_length=decoder_sequence_length, + **kwargs, + ) + # Store them directly on the subclass instance. 
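+        # They are applied in `call()` and `generate_preprocess()` when
+        # packing the encoder and decoder sequences.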
+ self.add_start_token = add_start_token + self.add_end_token = add_end_token + + @preprocessing_function + def call( + self, + x, + y=None, + sample_weight=None, + *, + encoder_sequence_length=None, + decoder_sequence_length=None, + sequence_length=None, + ): + if encoder_sequence_length is None: + encoder_sequence_length = self.encoder_sequence_length + decoder_sequence_length = decoder_sequence_length or sequence_length + if decoder_sequence_length is None: + decoder_sequence_length = self.decoder_sequence_length + + encoder_inputs = self.tokenizer(x["encoder_text"]) + encoder_token_ids, encoder_padding_mask = self.encoder_packer( + encoder_inputs, + sequence_length=encoder_sequence_length, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + decoder_inputs = self.tokenizer(x["decoder_text"]) + decoder_token_ids, decoder_padding_mask = self.decoder_packer( + decoder_inputs, + sequence_length=decoder_sequence_length + 1, + add_start_value=True, + add_end_value=self.add_end_token, + ) + x = { + "encoder_token_ids": encoder_token_ids, + "encoder_padding_mask": encoder_padding_mask, + "decoder_token_ids": decoder_token_ids[..., :-1], + "decoder_padding_mask": decoder_padding_mask[..., :-1], + } + y = decoder_token_ids[..., 1:] + sample_weight = decoder_padding_mask[..., 1:] + return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) + + @preprocessing_function + def generate_preprocess( + self, + x, + *, + encoder_sequence_length=None, + decoder_sequence_length=None, + sequence_length=None, + ): + if not self.built: + self.build(None) + + if isinstance(x, dict): + encoder_text = x["encoder_text"] + decoder_text = x["decoder_text"] + else: + encoder_text = x + decoder_text = tf.fill((tf.shape(encoder_text)[0],), "") + + if encoder_sequence_length is None: + encoder_sequence_length = self.encoder_sequence_length + decoder_sequence_length = decoder_sequence_length or sequence_length + if decoder_sequence_length is None: + decoder_sequence_length = self.decoder_sequence_length + + encoder_token_ids = self.tokenizer(encoder_text) + encoder_token_ids, encoder_padding_mask = self.encoder_packer( + encoder_token_ids, + sequence_length=None, + add_start_value=self.add_start_token, + add_end_value=False, + ) + + decoder_token_ids = self.tokenizer(decoder_text) + decoder_token_ids, decoder_padding_mask = self.decoder_packer( + decoder_token_ids, + sequence_length=decoder_sequence_length, + add_start_value=True, + add_end_value=False, + ) + + return { + "encoder_token_ids": encoder_token_ids, + "encoder_padding_mask": encoder_padding_mask, + "decoder_token_ids": decoder_token_ids, + "decoder_padding_mask": decoder_padding_mask, + } + + def get_config(self): + config = super().get_config() + config.update( + { + "add_start_token": self.add_start_token, + "add_end_token": self.add_end_token, + } + ) + return config diff --git a/keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_test.py b/keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_test.py new file mode 100644 index 0000000000..0a4cb0ef4e --- /dev/null +++ b/keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_test.py @@ -0,0 +1,166 @@ +import os +from unittest.mock import patch + +import keras +import pytest + +from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone +from keras_hub.src.models.t5gemma.t5gemma_seq_2_seq_lm import T5GemmaSeq2SeqLM +from keras_hub.src.models.t5gemma.t5gemma_seq_2_seq_lm_preprocessor import ( + T5GemmaSeq2SeqLMPreprocessor, +) +from 
keras_hub.src.models.t5gemma.t5gemma_tokenizer import T5GemmaTokenizer +from keras_hub.src.tests.test_case import TestCase + + +class T5GemmaSeq2SeqLMTest(TestCase): + def setUp(self): + self.tokenizer = T5GemmaTokenizer( + proto=os.path.join( + self.get_test_data_dir(), "gemma_test_vocab.spm" + ), + ) + self.preprocessor = T5GemmaSeq2SeqLMPreprocessor( + tokenizer=self.tokenizer, + encoder_sequence_length=8, + decoder_sequence_length=10, + ) + self.backbone = T5GemmaBackbone( + vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(), + encoder_hidden_dim=16, + encoder_intermediate_dim=32, + encoder_num_layers=2, + encoder_num_attention_heads=2, + encoder_num_key_value_heads=1, + encoder_head_dim=8, + encoder_layer_types=["sliding_attention", "full_attention"], + decoder_hidden_dim=16, + decoder_intermediate_dim=32, + decoder_num_layers=2, + decoder_num_attention_heads=2, + decoder_num_key_value_heads=1, + decoder_head_dim=8, + decoder_layer_types=["sliding_attention", "full_attention"], + dropout_rate=0.0, + rms_norm_eps=1e-6, + tie_word_embeddings=False, + query_pre_attn_scalar=1.0, + attention_bias=False, + hidden_activation="gelu_approximate", + initializer_range=0.02, + attention_dropout=0.0, + sliding_window=4, + final_logit_softcapping=30.0, + ) + self.init_kwargs = { + "preprocessor": self.preprocessor, + "backbone": self.backbone, + } + self.train_data = ( + { + "encoder_text": ["the quick brown fox", "the earth is round"], + "decoder_text": ["the quick brown fox", "the earth is round"], + }, + ) + self.input_data = self.preprocessor(*self.train_data)[0] + + def test_causal_lm_basics(self): + self.run_task_test( + cls=T5GemmaSeq2SeqLM, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=( + 2, + 10, + self.preprocessor.tokenizer.vocabulary_size(), + ), + ) + + def test_generate(self): + causal_lm = T5GemmaSeq2SeqLM(**self.init_kwargs) + # String inputs. + inputs = { + "encoder_text": "the quick brown fox", + "decoder_text": "the quick", + } + output = causal_lm.generate(inputs) + self.assertTrue("the quick" in output) + # Int tensor input. + prompt_ids = self.preprocessor.generate_preprocess(inputs) + causal_lm.preprocessor = None + outputs = causal_lm.generate(prompt_ids, stop_token_ids=None) + # Assert prompt is in output in token id space. + self.assertAllEqual( + outputs["decoder_token_ids"][:, :3], + prompt_ids["decoder_token_ids"][:, :3], + ) + self.assertAllEqual( + outputs["decoder_padding_mask"][:, :3], + prompt_ids["decoder_padding_mask"][:, :3], + ) + + def test_early_stopping(self): + causal_lm = T5GemmaSeq2SeqLM(**self.init_kwargs) + call_decoder_with_cache = causal_lm.call_decoder_with_cache + + def wrapper(*args, **kwargs): + """Modify output logits to always favor end_token_id""" + ( + logits, + hidden_states, + cache, + ) = call_decoder_with_cache(*args, **kwargs) + index = self.preprocessor.tokenizer.end_token_id + update = ( + keras.ops.ones( + (keras.ops.shape(logits)[0], 1, 1), dtype=logits.dtype + ) + * 1.0e9 + ) + logits = keras.ops.slice_update(logits, (0, 0, index), update) + return ( + logits, + hidden_states, + cache, + ) + + with patch.object(causal_lm, "call_decoder_with_cache", wraps=wrapper): + inputs = { + "encoder_text": [ + "the quick brown fox", + "the earth is round", + ], + "decoder_text": ["the quick", "the earth"], + } + output = causal_lm.generate(inputs) + # We should immediately abort and output the prompt. 
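+            # The forced end token is sampled first, so postprocessing
+            # strips everything after it and returns the prompts unchanged.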
+ self.assertEqual(inputs["decoder_text"], output) + + def test_generate_compilation(self): + causal_lm = T5GemmaSeq2SeqLM(**self.init_kwargs) + # Assert we do not recompile with successive calls. + causal_lm.generate("the quick brown fox") + first_fn = causal_lm.generate_function + causal_lm.generate("the quick brown fox") + second_fn = causal_lm.generate_function + self.assertEqual(first_fn, second_fn) + # Assert we do recompile after compile is called. + causal_lm.compile(sampler="greedy") + self.assertIsNone(causal_lm.generate_function) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=T5GemmaSeq2SeqLM, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in T5GemmaSeq2SeqLM.presets: + self.run_preset_test( + cls=T5GemmaSeq2SeqLM, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/t5gemma/t5gemma_tokenizer.py b/keras_hub/src/models/t5gemma/t5gemma_tokenizer.py new file mode 100644 index 0000000000..a3a6d27365 --- /dev/null +++ b/keras_hub/src/models/t5gemma/t5gemma_tokenizer.py @@ -0,0 +1,84 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone +from keras_hub.src.tokenizers.sentence_piece_tokenizer import ( + SentencePieceTokenizer, +) + + +@keras_hub_export( + [ + "keras_hub.tokenizers.T5GemmaTokenizer", + "keras_hub.models.T5GemmaTokenizer", + ] +) +class T5GemmaTokenizer(SentencePieceTokenizer): + """T5Gemma tokenizer layer based on SentencePiece. + + This tokenizer class will tokenize raw strings into integer sequences and + is based on `keras_hub.tokenizers.SentencePieceTokenizer`. Unlike the + underlying tokenizer, it will check for all special tokens needed by + T5Gemma models and provides a `from_preset()` method to automatically + download a matching vocabulary for a T5Gemma preset. + + If input is a batch of strings (rank > 0), the layer will output a + `tf.RaggedTensor` where the last dimension of the output is ragged. + + If input is a scalar string (rank == 0), the layer will output a dense + `tf.Tensor` with static shape `[None]`. + + Args: + proto: Either a `string` path to a SentencePiece proto file, or a + `bytes` object with a serialized SentencePiece proto. See the + [SentencePiece repository](https://github.com/google/sentencepiece) + for more details on the format. + + Examples: + + ```python + import io + import tensorflow as tf + import sentencepiece + + # Unbatched input. + tokenizer = keras_hub.models.T5GemmaTokenizer.from_preset( + "t5gemma_b_b_prefixlm_it" + ) + tokenizer("The quick brown fox jumped.") + + # Batched input. + tokenizer(["The quick brown fox jumped.", "The fox slept."]) + + # Detokenization. + tokenizer.detokenize(tokenizer("The quick brown fox jumped.")) + + # Custom vocabulary. 
+    bytes_io = io.BytesIO()
+    ds = tf.data.Dataset.from_tensor_slices(["The quick brown fox jumped."])
+    sentencepiece.SentencePieceTrainer.train(
+        sentence_iterator=ds.as_numpy_iterator(),
+        model_writer=bytes_io,
+        vocab_size=8,
+        model_type="WORD",
+        pad_id=0,
+        bos_id=1,
+        eos_id=2,
+        unk_id=3,
+        pad_piece="<pad>",
+        bos_piece="<bos>",
+        eos_piece="<eos>",
+        unk_piece="<unk>",
+    )
+    tokenizer = keras_hub.models.T5GemmaTokenizer(
+        proto=bytes_io.getvalue(),
+    )
+    tokenizer("The quick brown fox jumped.")
+    ```
+    """
+
+    backbone_cls = T5GemmaBackbone
+
+    def __init__(self, proto, **kwargs):
+        self._add_special_token("<bos>", "start_token")
+        self._add_special_token("<eos>", "end_token")
+        self._add_special_token("<pad>", "pad_token")
+        super().__init__(proto=proto, **kwargs)
diff --git a/keras_hub/src/utils/transformers/convert_t5gemma.py b/keras_hub/src/utils/transformers/convert_t5gemma.py
new file mode 100644
index 0000000000..f89242d944
--- /dev/null
+++ b/keras_hub/src/utils/transformers/convert_t5gemma.py
@@ -0,0 +1,229 @@
+from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone
+from keras_hub.src.utils.preset_utils import get_file
+
+backbone_cls = T5GemmaBackbone
+
+
+def convert_backbone_config(transformers_config):
+    """Convert a Hugging Face T5Gemma config to a KerasHub backbone config."""
+    encoder_config = transformers_config["encoder"]
+    decoder_config = transformers_config["decoder"]
+
+    if decoder_config.get("hidden_activation") == "gelu_pytorch_tanh":
+        decoder_config["hidden_activation"] = "gelu_approximate"
+    if encoder_config.get("hidden_activation") == "gelu_pytorch_tanh":
+        encoder_config["hidden_activation"] = "gelu_approximate"
+
+    backbone_config = {
+        "vocabulary_size": decoder_config["vocab_size"],
+        "encoder_hidden_dim": encoder_config["hidden_size"],
+        "encoder_intermediate_dim": encoder_config["intermediate_size"],
+        "encoder_num_layers": encoder_config["num_hidden_layers"],
+        "encoder_num_attention_heads": encoder_config["num_attention_heads"],
+        "encoder_num_key_value_heads": encoder_config["num_key_value_heads"],
+        "encoder_head_dim": encoder_config["head_dim"],
+        "encoder_layer_types": encoder_config["layer_types"],
+        "decoder_hidden_dim": decoder_config["hidden_size"],
+        "decoder_intermediate_dim": decoder_config["intermediate_size"],
+        "decoder_num_layers": decoder_config["num_hidden_layers"],
+        "decoder_num_attention_heads": decoder_config["num_attention_heads"],
+        "decoder_num_key_value_heads": decoder_config["num_key_value_heads"],
+        "decoder_head_dim": decoder_config["head_dim"],
+        "decoder_layer_types": decoder_config["layer_types"],
+        "dropout_rate": decoder_config["dropout_rate"],
+        "rms_norm_eps": decoder_config["rms_norm_eps"],
+        "query_pre_attn_scalar": decoder_config["query_pre_attn_scalar"],
+        "tie_word_embeddings": transformers_config.get(
+            "tie_word_embeddings", True
+        ),
+        "attention_bias": decoder_config["attention_bias"],
+        "hidden_activation": decoder_config["hidden_activation"],
+        "initializer_range": decoder_config["initializer_range"],
+        "attention_dropout": decoder_config["attention_dropout"],
+        "sliding_window": decoder_config["sliding_window"],
+        "cross_attention_hidden_size": encoder_config["hidden_size"],
+        "attn_logit_softcapping": decoder_config["attn_logit_softcapping"],
+        "final_logit_softcapping": decoder_config["final_logit_softcapping"],
+        "rope_max_wavelength": decoder_config["rope_theta"],
+    }
+    return backbone_config
+
+
+def convert_weights(backbone, loader, transformers_config):
+    """Convert T5Gemma from Hugging Face to
KerasHub.""" + # Token embeddings. + loader.port_weight( + keras_variable=backbone.token_embedding.embeddings, + hf_weight_key="encoder.embed_tokens.weight", + ) + loader.port_weight( + keras_variable=backbone.decoder_token_embedding.embeddings, + hf_weight_key="decoder.embed_tokens.weight", + ) + + # Encoder. + loader.port_weight( + keras_variable=backbone.encoder_norm.scale, + hf_weight_key="encoder.norm.weight", + ) + for i in range(backbone.encoder_num_layers): + layer = backbone.get_layer(f"encoder_layer_{i}") + hf_prefix = f"encoder.layers.{i}" + + # Self-attention. + loader.port_weight( + keras_variable=layer.self_attn.query_dense.kernel, + hf_weight_key=f"{hf_prefix}.self_attn.q_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + loader.port_weight( + keras_variable=layer.self_attn.key_dense.kernel, + hf_weight_key=f"{hf_prefix}.self_attn.k_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + loader.port_weight( + keras_variable=layer.self_attn.value_dense.kernel, + hf_weight_key=f"{hf_prefix}.self_attn.v_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + loader.port_weight( + keras_variable=layer.self_attn.output_dense.kernel, + hf_weight_key=f"{hf_prefix}.self_attn.o_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + + # MLP. + loader.port_weight( + keras_variable=layer.mlp.gate_proj.kernel, + hf_weight_key=f"{hf_prefix}.mlp.gate_proj.weight", + hook_fn=lambda w, s: w.T, + ) + loader.port_weight( + keras_variable=layer.mlp.up_proj.kernel, + hf_weight_key=f"{hf_prefix}.mlp.up_proj.weight", + hook_fn=lambda w, s: w.T, + ) + loader.port_weight( + keras_variable=layer.mlp.down_proj.kernel, + hf_weight_key=f"{hf_prefix}.mlp.down_proj.weight", + hook_fn=lambda w, s: w.T, + ) + + # Layer norm. + loader.port_weight( + keras_variable=layer.pre_self_attn_layernorm.scale, + hf_weight_key=f"{hf_prefix}.pre_self_attn_layernorm.weight", + ) + loader.port_weight( + keras_variable=layer.post_self_attn_layernorm.scale, + hf_weight_key=f"{hf_prefix}.post_self_attn_layernorm.weight", + ) + loader.port_weight( + keras_variable=layer.pre_feedforward_layernorm.scale, + hf_weight_key=f"{hf_prefix}.pre_feedforward_layernorm.weight", + ) + loader.port_weight( + keras_variable=layer.post_feedforward_layernorm.scale, + hf_weight_key=f"{hf_prefix}.post_feedforward_layernorm.weight", + ) + + # Decoder. + loader.port_weight( + keras_variable=backbone.decoder_norm.scale, + hf_weight_key="decoder.norm.weight", + ) + for i in range(backbone.decoder_num_layers): + layer = backbone.get_layer(f"decoder_layer_{i}") + hf_prefix = f"decoder.layers.{i}" + + # Self-attention. + loader.port_weight( + keras_variable=layer.self_attn.query_dense.kernel, + hf_weight_key=f"{hf_prefix}.self_attn.q_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + loader.port_weight( + keras_variable=layer.self_attn.key_dense.kernel, + hf_weight_key=f"{hf_prefix}.self_attn.k_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + loader.port_weight( + keras_variable=layer.self_attn.value_dense.kernel, + hf_weight_key=f"{hf_prefix}.self_attn.v_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + loader.port_weight( + keras_variable=layer.self_attn.output_dense.kernel, + hf_weight_key=f"{hf_prefix}.self_attn.o_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + + # Cross-attention. 
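+        # The key/value projections here map from the encoder hidden size
+        # (`cross_attention_hidden_size`), not the decoder hidden size.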
+ loader.port_weight( + keras_variable=layer.cross_attn.query_dense.kernel, + hf_weight_key=f"{hf_prefix}.cross_attn.q_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + loader.port_weight( + keras_variable=layer.cross_attn.key_dense.kernel, + hf_weight_key=f"{hf_prefix}.cross_attn.k_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + loader.port_weight( + keras_variable=layer.cross_attn.value_dense.kernel, + hf_weight_key=f"{hf_prefix}.cross_attn.v_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + loader.port_weight( + keras_variable=layer.cross_attn.output_dense.kernel, + hf_weight_key=f"{hf_prefix}.cross_attn.o_proj.weight", + hook_fn=lambda w, s: w.T.reshape(s), + ) + + # MLP. + loader.port_weight( + keras_variable=layer.mlp.gate_proj.kernel, + hf_weight_key=f"{hf_prefix}.mlp.gate_proj.weight", + hook_fn=lambda w, s: w.T, + ) + loader.port_weight( + keras_variable=layer.mlp.up_proj.kernel, + hf_weight_key=f"{hf_prefix}.mlp.up_proj.weight", + hook_fn=lambda w, s: w.T, + ) + loader.port_weight( + keras_variable=layer.mlp.down_proj.kernel, + hf_weight_key=f"{hf_prefix}.mlp.down_proj.weight", + hook_fn=lambda w, s: w.T, + ) + + # Layer norm. + loader.port_weight( + keras_variable=layer.pre_self_attn_layernorm.scale, + hf_weight_key=f"{hf_prefix}.pre_self_attn_layernorm.weight", + ) + loader.port_weight( + keras_variable=layer.post_self_attn_layernorm.scale, + hf_weight_key=f"{hf_prefix}.post_self_attn_layernorm.weight", + ) + loader.port_weight( + keras_variable=layer.pre_cross_attn_layernorm.scale, + hf_weight_key=f"{hf_prefix}.pre_cross_attn_layernorm.weight", + ) + loader.port_weight( + keras_variable=layer.post_cross_attn_layernorm.scale, + hf_weight_key=f"{hf_prefix}.post_cross_attn_layernorm.weight", + ) + loader.port_weight( + keras_variable=layer.pre_feedforward_layernorm.scale, + hf_weight_key=f"{hf_prefix}.pre_feedforward_layernorm.weight", + ) + loader.port_weight( + keras_variable=layer.post_feedforward_layernorm.scale, + hf_weight_key=f"{hf_prefix}.post_feedforward_layernorm.weight", + ) + + +def convert_tokenizer(cls, preset, **kwargs): + """Convert a T5Gemma tokenizer.""" + return cls(get_file(preset, "tokenizer.model"), **kwargs) diff --git a/keras_hub/src/utils/transformers/convert_t5gemma_test.py b/keras_hub/src/utils/transformers/convert_t5gemma_test.py new file mode 100644 index 0000000000..939984eba5 --- /dev/null +++ b/keras_hub/src/utils/transformers/convert_t5gemma_test.py @@ -0,0 +1,31 @@ +import pytest + +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM +from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone +from keras_hub.src.models.t5gemma.t5gemma_seq_2_seq_lm import T5GemmaSeq2SeqLM +from keras_hub.src.tests.test_case import TestCase + + +class TestTask(TestCase): + @pytest.mark.large + def test_convert_tiny_preset(self): + model = T5GemmaSeq2SeqLM.from_preset( + "hf://harshaljanjani/tiny-t5gemma-test" + ) + prompt = "What is the capital of France?" 
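+        # Smoke test only: verify the converted checkpoint can run
+        # generation end to end.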
+ model.generate([prompt], max_length=15) + + @pytest.mark.large + def test_class_detection(self): + preset_name = "hf://harshaljanjani/tiny-t5gemma-test" + model = Seq2SeqLM.from_preset( + preset_name, + load_weights=False, + ) + self.assertIsInstance(model, T5GemmaSeq2SeqLM) + model = Backbone.from_preset( + preset_name, + load_weights=False, + ) + self.assertIsInstance(model, T5GemmaBackbone) diff --git a/keras_hub/src/utils/transformers/preset_loader.py b/keras_hub/src/utils/transformers/preset_loader.py index 4accea67a1..bee392289f 100644 --- a/keras_hub/src/utils/transformers/preset_loader.py +++ b/keras_hub/src/utils/transformers/preset_loader.py @@ -18,6 +18,7 @@ from keras_hub.src.utils.transformers import convert_qwen from keras_hub.src.utils.transformers import convert_qwen3 from keras_hub.src.utils.transformers import convert_qwen_moe +from keras_hub.src.utils.transformers import convert_t5gemma from keras_hub.src.utils.transformers import convert_vit from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader @@ -59,6 +60,8 @@ def __init__(self, preset, config): self.converter = convert_qwen_moe elif model_type == "qwen3": self.converter = convert_qwen3 + elif model_type == "t5gemma": + self.converter = convert_t5gemma else: raise ValueError( "KerasHub has no converter for huggingface/transformers models " diff --git a/tools/checkpoint_conversion/convert_t5gemma_checkpoints.py b/tools/checkpoint_conversion/convert_t5gemma_checkpoints.py new file mode 100644 index 0000000000..e4cd550d9d --- /dev/null +++ b/tools/checkpoint_conversion/convert_t5gemma_checkpoints.py @@ -0,0 +1,468 @@ +import gc +import os +import random +import shutil + +import huggingface_hub +import keras +import numpy as np +import tensorflow as tf +import torch +import transformers +from absl import app +from absl import flags +from checkpoint_conversion_utils import get_md5_checksum + +import keras_hub +from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM +from keras_hub.src.models.t5gemma.t5gemma_seq_2_seq_lm import T5GemmaSeq2SeqLM +from keras_hub.src.models.t5gemma.t5gemma_seq_2_seq_lm_preprocessor import ( + T5GemmaSeq2SeqLMPreprocessor, +) + +random.seed(123) +torch.manual_seed(123) +device = torch.device("cpu") +# Force PyTorch to use CPU +torch.set_default_device(device) + +PRESET_MAP = { + "t5gemma_s_s_ul2": "google/t5gemma-s-s-ul2", + "t5gemma_s_s_prefixlm": "google/t5gemma-s-s-prefixlm", + "t5gemma_s_s_ul2_it": "google/t5gemma-s-s-ul2-it", + "t5gemma_s_s_prefixlm_it": "google/t5gemma-s-s-prefixlm-it", + "t5gemma_b_b_ul2": "google/t5gemma-b-b-ul2", + "t5gemma_b_b_prefixlm": "google/t5gemma-b-b-prefixlm", + "t5gemma_b_b_ul2_it": "google/t5gemma-b-b-ul2-it", + "t5gemma_b_b_prefixlm_it": "google/t5gemma-b-b-prefixlm-it", + "t5gemma_l_l_ul2": "google/t5gemma-l-l-ul2", + "t5gemma_l_l_prefixlm": "google/t5gemma-l-l-prefixlm", + "t5gemma_l_l_ul2_it": "google/t5gemma-l-l-ul2-it", + "t5gemma_l_l_prefixlm_it": "google/t5gemma-l-l-prefixlm-it", + "t5gemma_ml_ml_ul2": "google/t5gemma-ml-ml-ul2", + "t5gemma_ml_ml_prefixlm": "google/t5gemma-ml-ml-prefixlm", + "t5gemma_ml_ml_ul2_it": "google/t5gemma-ml-ml-ul2-it", + "t5gemma_ml_ml_prefixlm_it": "google/t5gemma-ml-ml-prefixlm-it", + "t5gemma_xl_xl_ul2": "google/t5gemma-xl-xl-ul2", + "t5gemma_xl_xl_prefixlm": "google/t5gemma-xl-xl-prefixlm", + "t5gemma_xl_xl_ul2_it": "google/t5gemma-xl-xl-ul2-it", + "t5gemma_xl_xl_prefixlm_it": "google/t5gemma-xl-xl-prefixlm-it", + "t5gemma_2b_2b_ul2": "google/t5gemma-2b-2b-ul2", + "t5gemma_2b_2b_prefixlm": 
"google/t5gemma-2b-2b-prefixlm", + "t5gemma_2b_2b_ul2_it": "google/t5gemma-2b-2b-ul2-it", + "t5gemma_2b_2b_prefixlm_it": "google/t5gemma-2b-2b-prefixlm-it", + "t5gemma_9b_2b_ul2": "google/t5gemma-9b-2b-ul2", + "t5gemma_9b_2b_prefixlm": "google/t5gemma-9b-2b-prefixlm", + "t5gemma_9b_2b_ul2_it": "google/t5gemma-9b-2b-ul2-it", + "t5gemma_9b_2b_prefixlm_it": "google/t5gemma-9b-2b-prefixlm-it", + "t5gemma_9b_9b_ul2": "google/t5gemma-9b-9b-ul2", + "t5gemma_9b_9b_prefixlm": "google/t5gemma-9b-9b-prefixlm", + "t5gemma_9b_9b_ul2_it": "google/t5gemma-9b-9b-ul2-it", + "t5gemma_9b_9b_prefixlm_it": "google/t5gemma-9b-9b-prefixlm-it", +} + + +FLAGS = flags.FLAGS +flags.DEFINE_string( + "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}" +) + + +def convert_checkpoints(hf_model): + """Convert Hugging Face weights to Keras Hub format.""" + print("\n-> Convert original weights to KerasHub format.") + + print("\n-> Load KerasHub model.") + encoder_config = hf_model.config.encoder + decoder_config = hf_model.config.decoder + if decoder_config.hidden_activation == "gelu_pytorch_tanh": + decoder_config.hidden_activation = "gelu_approximate" + if encoder_config.hidden_activation == "gelu_pytorch_tanh": + encoder_config.hidden_activation = "gelu_approximate" + keras.config.set_floatx("float32") + keras_hub_model = keras_hub.models.T5GemmaBackbone( + vocabulary_size=decoder_config.vocab_size, + encoder_hidden_dim=encoder_config.hidden_size, + encoder_intermediate_dim=encoder_config.intermediate_size, + encoder_num_layers=encoder_config.num_hidden_layers, + encoder_num_attention_heads=encoder_config.num_attention_heads, + encoder_num_key_value_heads=encoder_config.num_key_value_heads, + encoder_head_dim=encoder_config.head_dim, + encoder_layer_types=encoder_config.layer_types, + decoder_hidden_dim=decoder_config.hidden_size, + decoder_intermediate_dim=decoder_config.intermediate_size, + decoder_num_layers=decoder_config.num_hidden_layers, + decoder_num_attention_heads=decoder_config.num_attention_heads, + decoder_num_key_value_heads=decoder_config.num_key_value_heads, + decoder_head_dim=decoder_config.head_dim, + decoder_layer_types=decoder_config.layer_types, + dropout_rate=decoder_config.dropout_rate, + rms_norm_eps=decoder_config.rms_norm_eps, + query_pre_attn_scalar=decoder_config.query_pre_attn_scalar, + tie_word_embeddings=getattr( + hf_model.config, "tie_word_embeddings", True + ), + attention_bias=decoder_config.attention_bias, + hidden_activation=decoder_config.hidden_activation, + initializer_range=decoder_config.initializer_range, + attention_dropout=decoder_config.attention_dropout, + sliding_window=decoder_config.sliding_window, + cross_attention_hidden_size=encoder_config.hidden_size, + attn_logit_softcapping=decoder_config.attn_logit_softcapping, + final_logit_softcapping=decoder_config.final_logit_softcapping, + rope_max_wavelength=decoder_config.rope_theta, + dtype="float32", + ) + + hf_wts = hf_model.state_dict() + # Token embedding. + keras_hub_model.get_layer("encoder_token_embedding").embeddings.assign( + hf_wts["encoder.embed_tokens.weight"] + ) + keras_hub_model.get_layer("decoder_token_embedding").embeddings.assign( + hf_wts["decoder.embed_tokens.weight"] + ) + + # Encoder. 
+ encoder_hidden_dim = keras_hub_model.encoder_hidden_dim + encoder_num_attention_heads = keras_hub_model.encoder_num_attention_heads + encoder_num_key_value_heads = keras_hub_model.encoder_num_key_value_heads + encoder_head_dim = keras_hub_model.encoder_head_dim + keras_hub_model.encoder_norm.scale.assign(hf_wts["encoder.norm.weight"]) + + for i in range(keras_hub_model.encoder_num_layers): + encoder_layer = keras_hub_model.get_layer(f"encoder_layer_{i}") + hf_prefix = f"encoder.layers.{i}" + + # Self-attention. + q_w = hf_wts[f"{hf_prefix}.self_attn.q_proj.weight"] + k_w = hf_wts[f"{hf_prefix}.self_attn.k_proj.weight"] + v_w = hf_wts[f"{hf_prefix}.self_attn.v_proj.weight"] + o_w = hf_wts[f"{hf_prefix}.self_attn.o_proj.weight"] + + encoder_layer.self_attn.query_dense.kernel.assign( + q_w.T.reshape( + encoder_hidden_dim, + encoder_num_attention_heads, + encoder_head_dim, + ).numpy() + ) + encoder_layer.self_attn.key_dense.kernel.assign( + k_w.T.reshape( + encoder_hidden_dim, + encoder_num_key_value_heads, + encoder_head_dim, + ).numpy() + ) + encoder_layer.self_attn.value_dense.kernel.assign( + v_w.T.reshape( + encoder_hidden_dim, + encoder_num_key_value_heads, + encoder_head_dim, + ).numpy() + ) + encoder_layer.self_attn.output_dense.kernel.assign( + o_w.T.reshape( + encoder_num_attention_heads, + encoder_head_dim, + encoder_hidden_dim, + ).numpy() + ) + + # MLP. + encoder_layer.mlp.gate_proj.kernel.assign( + hf_wts[f"{hf_prefix}.mlp.gate_proj.weight"].T.numpy() + ) + encoder_layer.mlp.up_proj.kernel.assign( + hf_wts[f"{hf_prefix}.mlp.up_proj.weight"].T.numpy() + ) + encoder_layer.mlp.down_proj.kernel.assign( + hf_wts[f"{hf_prefix}.mlp.down_proj.weight"].T.numpy() + ) + + # Layer norm. + encoder_layer.pre_self_attn_layernorm.scale.assign( + hf_wts[f"{hf_prefix}.pre_self_attn_layernorm.weight"] + ) + encoder_layer.post_self_attn_layernorm.scale.assign( + hf_wts[f"{hf_prefix}.post_self_attn_layernorm.weight"] + ) + encoder_layer.pre_feedforward_layernorm.scale.assign( + hf_wts[f"{hf_prefix}.pre_feedforward_layernorm.weight"] + ) + encoder_layer.post_feedforward_layernorm.scale.assign( + hf_wts[f"{hf_prefix}.post_feedforward_layernorm.weight"] + ) + + # Decoder. + decoder_hidden_dim = keras_hub_model.decoder_hidden_dim + decoder_num_attention_heads = keras_hub_model.decoder_num_attention_heads + decoder_num_key_value_heads = keras_hub_model.decoder_num_key_value_heads + decoder_head_dim = keras_hub_model.decoder_head_dim + cross_attention_hidden_size = keras_hub_model.cross_attention_hidden_size + keras_hub_model.decoder_norm.scale.assign(hf_wts["decoder.norm.weight"]) + + for i in range(keras_hub_model.decoder_num_layers): + decoder_layer = keras_hub_model.get_layer(f"decoder_layer_{i}") + hf_prefix = f"decoder.layers.{i}" + + # Self-attention. 
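+        # Same transpose/reshape treatment as the encoder self-attention.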
+ q_w = hf_wts[f"{hf_prefix}.self_attn.q_proj.weight"] + k_w = hf_wts[f"{hf_prefix}.self_attn.k_proj.weight"] + v_w = hf_wts[f"{hf_prefix}.self_attn.v_proj.weight"] + o_w = hf_wts[f"{hf_prefix}.self_attn.o_proj.weight"] + decoder_layer.self_attn.query_dense.kernel.assign( + q_w.T.reshape( + decoder_hidden_dim, + decoder_num_attention_heads, + decoder_head_dim, + ).numpy() + ) + decoder_layer.self_attn.key_dense.kernel.assign( + k_w.T.reshape( + decoder_hidden_dim, + decoder_num_key_value_heads, + decoder_head_dim, + ).numpy() + ) + decoder_layer.self_attn.value_dense.kernel.assign( + v_w.T.reshape( + decoder_hidden_dim, + decoder_num_key_value_heads, + decoder_head_dim, + ).numpy() + ) + decoder_layer.self_attn.output_dense.kernel.assign( + o_w.T.reshape( + decoder_num_attention_heads, + decoder_head_dim, + decoder_hidden_dim, + ).numpy() + ) + + # Cross-attention. + q_w = hf_wts[f"{hf_prefix}.cross_attn.q_proj.weight"] + k_w = hf_wts[f"{hf_prefix}.cross_attn.k_proj.weight"] + v_w = hf_wts[f"{hf_prefix}.cross_attn.v_proj.weight"] + o_w = hf_wts[f"{hf_prefix}.cross_attn.o_proj.weight"] + decoder_layer.cross_attn.query_dense.kernel.assign( + q_w.T.reshape( + decoder_hidden_dim, + decoder_num_attention_heads, + decoder_head_dim, + ).numpy() + ) + decoder_layer.cross_attn.key_dense.kernel.assign( + k_w.T.reshape( + cross_attention_hidden_size, + decoder_num_key_value_heads, + decoder_head_dim, + ).numpy() + ) + decoder_layer.cross_attn.value_dense.kernel.assign( + v_w.T.reshape( + cross_attention_hidden_size, + decoder_num_key_value_heads, + decoder_head_dim, + ).numpy() + ) + decoder_layer.cross_attn.output_dense.kernel.assign( + o_w.T.reshape( + decoder_num_attention_heads, + decoder_head_dim, + decoder_hidden_dim, + ).numpy() + ) + + # MLP. + decoder_layer.mlp.gate_proj.kernel.assign( + hf_wts[f"{hf_prefix}.mlp.gate_proj.weight"].T.numpy() + ) + decoder_layer.mlp.up_proj.kernel.assign( + hf_wts[f"{hf_prefix}.mlp.up_proj.weight"].T.numpy() + ) + decoder_layer.mlp.down_proj.kernel.assign( + hf_wts[f"{hf_prefix}.mlp.down_proj.weight"].T.numpy() + ) + + # Layer norm. 
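+        # T5Gemma wraps each attention and feed-forward block with both a
+        # pre- and a post-RMSNorm, so six norm scales are ported per
+        # decoder layer.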
+ decoder_layer.pre_self_attn_layernorm.scale.assign( + hf_wts[f"{hf_prefix}.pre_self_attn_layernorm.weight"] + ) + decoder_layer.post_self_attn_layernorm.scale.assign( + hf_wts[f"{hf_prefix}.post_self_attn_layernorm.weight"] + ) + decoder_layer.pre_cross_attn_layernorm.scale.assign( + hf_wts[f"{hf_prefix}.pre_cross_attn_layernorm.weight"] + ) + decoder_layer.post_cross_attn_layernorm.scale.assign( + hf_wts[f"{hf_prefix}.post_cross_attn_layernorm.weight"] + ) + decoder_layer.pre_feedforward_layernorm.scale.assign( + hf_wts[f"{hf_prefix}.pre_feedforward_layernorm.weight"] + ) + decoder_layer.post_feedforward_layernorm.scale.assign( + hf_wts[f"{hf_prefix}.post_feedforward_layernorm.weight"] + ) + + return keras_hub_model + + +def extract_vocab(hf_model_dir): + """Extract vocabulary from the downloaded Hugging Face model directory.""" + source_path = os.path.join(hf_model_dir, "tokenizer.model") + vocabulary_path = os.path.join(FLAGS.preset, "tokenizer.model") + print(f"\n-> Save KerasHub vocab to `{vocabulary_path}`.") + + shutil.copyfile(source_path, vocabulary_path) + + keras_hub_tokenizer = keras_hub.models.T5GemmaTokenizer( + proto=vocabulary_path + ) + + print("-> Print MD5 checksum of the vocab file.") + print(f"`{vocabulary_path}` md5sum: ", get_md5_checksum(vocabulary_path)) + + return keras_hub_tokenizer + + +def check_output( + keras_hub_tokenizer, + keras_hub_model, + hf_tokenizer, + hf_model, +): + """Check the outputs of the Keras Hub and Hugging Face models.""" + print("\n-> Check the outputs.") + enc_sample_text = [ + "cricket is awesome, easily the best sport in the world!" + ] + dec_sample_text = [ + "football is good too, but nowhere near as good as cricket." + ] + + # KerasHub. + keras_hub_enc_token_ids = hf_tokenizer( + enc_sample_text, return_tensors="tf" + )["input_ids"] + keras_hub_dec_token_ids = hf_tokenizer( + dec_sample_text, return_tensors="tf" + )["input_ids"] + keras_hub_dec_token_ids = tf.concat( + [ + tf.constant([[keras_hub_tokenizer.start_token_id]]), + keras_hub_dec_token_ids, + ], + axis=-1, + ) + keras_hub_inputs = { + "encoder_token_ids": keras_hub_enc_token_ids, + "encoder_padding_mask": tf.ones_like(keras_hub_enc_token_ids), + "decoder_token_ids": keras_hub_dec_token_ids, + "decoder_padding_mask": tf.ones_like(keras_hub_dec_token_ids), + } + keras_hub_output = keras_hub_model.predict(keras_hub_inputs) + + # HF. 
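+    # Run the equivalent HF forward pass, prepending the BOS id to the
+    # decoder inputs to mirror the KerasHub inputs above.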
+ hf_enc_inputs = hf_tokenizer(enc_sample_text, return_tensors="pt") + hf_dec_inputs = hf_tokenizer(dec_sample_text, return_tensors="pt") + hf_decoder_input_ids = torch.cat( + [ + torch.tensor([[hf_tokenizer.bos_token_id]]), + hf_dec_inputs["input_ids"], + ], + dim=-1, + ) + hf_decoder_attention_mask = torch.cat( + [torch.ones(1, 1, dtype=torch.long), hf_dec_inputs["attention_mask"]], + dim=-1, + ) + + hf_output = hf_model( + **hf_enc_inputs, + decoder_input_ids=hf_decoder_input_ids, + decoder_attention_mask=hf_decoder_attention_mask, + ) + + print("Encoder Outputs:") + print( + "KerasHub output:", + keras_hub_output["encoder_sequence_output"][0, 0, :10], + ) + print("HF output:", hf_output.encoder_last_hidden_state[0, 0, :10]) + print( + "Difference:", + np.mean( + keras_hub_output["encoder_sequence_output"] + - hf_output.encoder_last_hidden_state.detach().numpy() + ), + ) + + print("Decoder Outputs:") + print( + "KerasHub output:", + keras_hub_output["decoder_sequence_output"][0, 0, :10], + ) + print("HF output:", hf_output.last_hidden_state[0, 0, :10]) + print( + "Difference:", + np.mean( + keras_hub_output["decoder_sequence_output"] + - hf_output.last_hidden_state.detach().numpy() + ), + ) + + +def main(_): + os.makedirs(FLAGS.preset, exist_ok=True) + + hf_model_name = PRESET_MAP[FLAGS.preset] + + print("\n-> Download HF model files.") + hf_model_dir = huggingface_hub.snapshot_download( + repo_id=hf_model_name, + allow_patterns=["*.json", "*.safetensors", "tokenizer.model"], + ) + + print("\n-> Load HF model and HF tokenizer.") + hf_model = transformers.AutoModel.from_pretrained(hf_model_dir) + hf_model.eval() + hf_tokenizer = transformers.AutoTokenizer.from_pretrained(hf_model_dir) + + keras_hub_model = convert_checkpoints(hf_model) + print("\n-> Load KerasHub tokenizer.") + keras_hub_tokenizer = extract_vocab(hf_model_dir) + + check_output( + keras_hub_tokenizer, + keras_hub_model, + hf_tokenizer, + hf_model, + ) + print("\n-> Releasing HF backbone from memory.") + del hf_model + gc.collect() + preprocessor = T5GemmaSeq2SeqLMPreprocessor( + tokenizer=keras_hub_tokenizer, + encoder_sequence_length=512, + decoder_sequence_length=512, + ) + keras_lm = T5GemmaSeq2SeqLM( + backbone=keras_hub_model, + preprocessor=preprocessor, + dtype=keras_hub_model.dtype, + ) + keras_lm.compile(sampler="greedy") + + print(f"\n-> Saving T5GemmaSeq2SeqLM preset to `{FLAGS.preset}`.") + keras_lm.save_to_preset(FLAGS.preset) + print("-> Preset saved successfully.") + + print("\n-> Testing preset loading.") + keras_lm = Seq2SeqLM.from_preset(FLAGS.preset) + print("-> Preset loading verified successfully.") + + +if __name__ == "__main__": + flags.mark_flag_as_required("preset") + app.run(main)
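+
+
+# Example usage (the preset name must be a key of `PRESET_MAP` above):
+# python tools/checkpoint_conversion/convert_t5gemma_checkpoints.py \
+#     --preset t5gemma_b_b_prefixlm_it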