diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index ecdebaa7f3..63ffedbcbd 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -108,6 +108,9 @@ from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import ( PaliGemmaImageConverter as PaliGemmaImageConverter, ) +from keras_hub.src.models.parseq.parseq_image_converter import ( + PARSeqImageConverter as PARSeqImageConverter, +) from keras_hub.src.models.resnet.resnet_image_converter import ( ResNetImageConverter as ResNetImageConverter, ) diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index 8b6aa475e7..abe76ae9f3 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -430,6 +430,18 @@ from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import ( PaliGemmaTokenizer as PaliGemmaTokenizer, ) +from keras_hub.src.models.parseq.parseq_backbone import ( + PARSeqBackbone as PARSeqBackbone, +) +from keras_hub.src.models.parseq.parseq_causal_lm import ( + PARSeqCausalLM as PARSeqCausalLM, +) +from keras_hub.src.models.parseq.parseq_causal_lm_preprocessor import ( + PARSeqCausalLMPreprocessor as PARSeqCausalLMPreprocessor, +) +from keras_hub.src.models.parseq.parseq_tokenizer import ( + PARSeqTokenizer as PARSeqTokenizer, +) from keras_hub.src.models.phi3.phi3_backbone import Phi3Backbone as Phi3Backbone from keras_hub.src.models.phi3.phi3_causal_lm import ( Phi3CausalLM as Phi3CausalLM, diff --git a/keras_hub/api/tokenizers/__init__.py b/keras_hub/api/tokenizers/__init__.py index 082078184f..fa7b3f6d96 100644 --- a/keras_hub/api/tokenizers/__init__.py +++ b/keras_hub/api/tokenizers/__init__.py @@ -65,6 +65,9 @@ from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import ( PaliGemmaTokenizer as PaliGemmaTokenizer, ) +from keras_hub.src.models.parseq.parseq_tokenizer import ( + PARSeqTokenizer as PARSeqTokenizer, +) from keras_hub.src.models.phi3.phi3_tokenizer import ( Phi3Tokenizer as Phi3Tokenizer, ) diff --git a/keras_hub/src/models/parseq/__init__.py b/keras_hub/src/models/parseq/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/keras_hub/src/models/parseq/parseq_backbone.py b/keras_hub/src/models/parseq/parseq_backbone.py new file mode 100644 index 0000000000..4c42552bed --- /dev/null +++ b/keras_hub/src/models/parseq/parseq_backbone.py @@ -0,0 +1,134 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.parseq.parseq_decoder import PARSeqDecoder + + +@keras_hub_export("keras_hub.models.PARSeqBackbone") +class PARSeqBackbone(Backbone): + """Scene Text Detection with PARSeq. + + Performs OCR in natural scenes using the PARSeq model described in [Scene + Text Recognition with Permuted Autoregressive Sequence Models]( + https://arxiv.org/abs/2207.06966). PARSeq is a ViT-based model that allows + iterative decoding by performing an autoregressive decoding phase, followed + by a refinement phase. + + Args: + image_encoder: keras.Model. The image encoder model. + vocabulary_size: int. The size of the vocabulary. + max_label_length: int. The maximum length of the label sequence. + decoder_hidden_dim: int. The dimension of the decoder hidden layers. + num_decoder_layers: int. The number of decoder layers. + num_decoder_heads: int. The number of attention heads in the decoder. + decoder_mlp_dim: int. The dimension of the decoder MLP hidden layer. + dropout_rate: float. 
The dropout rate for the decoder network. + Defaults to `0.1`. + attention_dropout: float. The dropout rate for the attention weights. + Defaults to `0.1`. + dtype: str. `None`, str, or `keras.mixed_precision.DTypePolicy`. The + dtype to use for the computations and weights. + **kwargs: Additional keyword arguments passed to the base + `keras.Model` constructor. + """ + + def __init__( + self, + image_encoder, + vocabulary_size, + max_label_length, + decoder_hidden_dim, + num_decoder_layers, + num_decoder_heads, + decoder_mlp_dim, + dropout_rate=0.1, + attention_dropout=0.1, + dtype=None, + **kwargs, + ): + # === Layers === + self.image_encoder = image_encoder + self.decoder = PARSeqDecoder( + vocabulary_size=vocabulary_size, + max_label_length=max_label_length, + num_layers=num_decoder_layers, + num_heads=num_decoder_heads, + hidden_dim=decoder_hidden_dim, + mlp_dim=decoder_mlp_dim, + dropout_rate=dropout_rate, + attention_dropout=attention_dropout, + name="decoder", + dtype=dtype, + ) + self.head = keras.layers.Dense( + vocabulary_size - 2, # We don't predict nor + dtype=dtype, + ) + + # === Functional Model === + image_input = self.image_encoder.input + + token_id_input = keras.Input( + shape=(None,), dtype="int32", name="token_ids" + ) + padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="padding_mask" + ) + + memory = self.image_encoder(image_input) + target_out = self.decoder( + token_id_input, memory, padding_mask=padding_mask_input + ) + logits = self.head(target_out) + + # === Config === + self.vocabulary_size = vocabulary_size + self.max_label_length = max_label_length + self.decoder_hidden_dim = decoder_hidden_dim + self.num_decoder_layers = num_decoder_layers + self.num_decoder_heads = num_decoder_heads + self.decoder_mlp_dim = decoder_mlp_dim + self.dropout_rate = dropout_rate + self.attention_dropout = attention_dropout + + super().__init__( + inputs={ + "images": image_input, + "token_ids": token_id_input, + "padding_mask": padding_mask_input, + }, + outputs=logits, + dtype=dtype, + **kwargs, + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "image_encoder": keras.layers.serialize(self.image_encoder), + "vocabulary_size": self.vocabulary_size, + "max_label_length": self.max_label_length, + "decoder_hidden_dim": self.decoder_hidden_dim, + "num_decoder_layers": self.num_decoder_layers, + "num_decoder_heads": self.num_decoder_heads, + "decoder_mlp_dim": self.decoder_mlp_dim, + "dropout_rate": self.dropout_rate, + "attention_dropout": self.attention_dropout, + } + ) + + return config + + @classmethod + def from_config(cls, config): + config.update( + { + "image_encoder": keras.layers.deserialize( + config["image_encoder"] + ), + } + ) + + return super().from_config(config) diff --git a/keras_hub/src/models/parseq/parseq_backbone_test.py b/keras_hub/src/models/parseq/parseq_backbone_test.py new file mode 100644 index 0000000000..4fbdaa8e6d --- /dev/null +++ b/keras_hub/src/models/parseq/parseq_backbone_test.py @@ -0,0 +1,107 @@ +import keras +import pytest +from keras import ops + +from keras_hub.src.models.parseq.parseq_backbone import PARSeqBackbone +from keras_hub.src.models.vit.vit_backbone import ViTBackbone +from keras_hub.src.tests.test_case import TestCase + + +class PARSeqBackboneTest(TestCase): + def setUp(self): + self.batch_size = 2 + self.image_height = 32 + self.image_width = 128 + self.num_channels = 3 + + # Image Encoder parameters (as per your example) + self.vit_patch_size = (4, 8) + self.vit_num_layers = 2 
+ self.vit_num_heads = 2 + self.vit_hidden_dim = 64 + self.vit_mlp_dim = self.vit_hidden_dim * 4 + + # PARSeq Backbone parameters + self.vocabulary_size = 97 + self.max_label_length = 25 + self.decoder_hidden_dim = self.vit_hidden_dim + self.num_decoder_layers = 1 + self.num_decoder_heads = 2 + self.decoder_mlp_dim = self.decoder_hidden_dim * 4 + + # Instantiate the actual ViTBackbone to be used as the image_encoder + self.image_encoder = ViTBackbone( + image_shape=( + self.image_height, + self.image_width, + self.num_channels, + ), + patch_size=self.vit_patch_size, + num_layers=self.vit_num_layers, + num_heads=self.vit_num_heads, + hidden_dim=self.vit_hidden_dim, + mlp_dim=self.vit_mlp_dim, + use_class_token=False, + name="image_encoder", + ) + + self.init_kwargs = { + "image_encoder": self.image_encoder, + "vocabulary_size": self.vocabulary_size, + "max_label_length": self.max_label_length, + "decoder_hidden_dim": self.decoder_hidden_dim, + "num_decoder_layers": self.num_decoder_layers, + "num_decoder_heads": self.num_decoder_heads, + "decoder_mlp_dim": self.decoder_mlp_dim, + "dropout_rate": 0.0, + "attention_dropout": 0.0, + } + + # Dummy input data + dummy_images = keras.random.normal( + shape=( + self.batch_size, + self.image_height, + self.image_width, + self.num_channels, + ), + ) + + dummy_token_ids = keras.random.randint( + minval=0, + maxval=self.vocabulary_size, + shape=(self.batch_size, self.max_label_length), + ) + dummy_padding_mask = ops.ones( + shape=(self.batch_size, self.max_label_length), dtype="int32" + ) + + self.input_data = { + "images": dummy_images, + "token_ids": dummy_token_ids, + "padding_mask": dummy_padding_mask, + } + + def test_backbone_basics(self): + expected_shape_full = ( + self.batch_size, + self.max_label_length, + self.vocabulary_size - 2, + ) + + self.run_backbone_test( + cls=PARSeqBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=expected_shape_full, + # we have image_encoder as init_kwargs which is also a backbone + run_quantization_check=False, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=PARSeqBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/parseq/parseq_causal_lm.py b/keras_hub/src/models/parseq/parseq_causal_lm.py new file mode 100644 index 0000000000..c518a4bca7 --- /dev/null +++ b/keras_hub/src/models/parseq/parseq_causal_lm.py @@ -0,0 +1,466 @@ +import math + +import keras +from keras import ops +from keras import random + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.causal_lm import CausalLM +from keras_hub.src.models.parseq.parseq_backbone import PARSeqBackbone +from keras_hub.src.models.parseq.parseq_causal_lm_preprocessor import ( + PARSeqCausalLMPreprocessor, +) +from keras_hub.src.utils.tensor_utils import any_equal + + +@keras_hub_export("keras_hub.models.PARSeqCausalLM") +class PARSeqCausalLM(CausalLM): + """Scene Text Recognition with PARSeq. + Performs OCR in natural scenes using the PARSeq model described in + [Scene Text Recognition with Permuted Autoregressive Sequence Models]( + https://arxiv.org/abs/2207.06966). PARSeq is a ViT-based model that allows + iterative decoding by performing an autoregressive decoding phase, followed + by a refinement phase. + Args: + preprocessor: A `keras_hub.models.Preprocessor` instance or a + `keras.Layer` instance. The preprocessor to use for the model. 
+ backbone: A `keras_hub.models.PARSeqBackbone` instance or a + `keras.Model`. The backbone model to use for the model. + num_perms: int. The number of permutations to generate for training. + Defaults to 6. + add_forward_perms: bool. Whether to add forward permutations to the + generated permutations. Defaults to `True`. + add_mirrored_perms: bool. Whether to add mirrored permutations to the + generated permutations. Defaults to `True`. + seed: int. The random seed to use for generating permutations. + Defaults to `None`, which means no seed is set. + **kwargs: Additional keyword arguments passed to the base + `keras_hub.models.CausalLM` constructor. + + Examples: + + Call `predict()` to run inference. + ```python + # Load preset and run inference + images = np.random.randint(0, 256, size=(2, 32, 128, 3)) + parseq = keras_hub.models.PARSeqCausalLM.from_preset( + "parseq_vit" + ) + parseq.generate(images) + + # Call `fit()` on a single batch. + images = np.random.randint(0, 256, size=(2, 32, 128, 3)) + token_ids = np.array([[1, 2, 3, 4], [1, 2, 3, 0]]) + padding_mask = np.array([[1, 1, 1, 1], [1, 1, 1, 0]]) + parseq = keras_hub.models.PARSeqCausalLM.from_preset( + "parseq_vit" + ) + parseq.fit( + x={ + "images": images, + "token_ids": token_ids, + "padding_mask": padding_mask + }, + batch_size=2, + ) + ``` + # Call `fit()` with custom loss, optimizer and image encoder. + ```python + # Initialize the image encoder, preprocessor and tokenizer + mean, std = 0.5, 0.5 + image_converter = PARSeqImageConverter( + image_size=(32, 128), + offset=-mean / std, + scale=1.0 / 255.0 / std, + interpolation="bicubic", + ) + tokenizer = PARSeqTokenizer(max_label_length=25) + preprocessor = keras_hub.models.PARSeqCausalLMPreprocessor( + image_converter=image_converter, + tokenizer=tokenizer, + ) + + # Create the backbone + image_encoder = ViTBackbone( + image_shape=(32, 128, 3), + patch_size=(4, 8), + num_layers=12, + num_heads=6, + hidden_dim=384, + mlp_dim=384 * 4, + use_class_token=False, + name="encoder", + ) + backbone = PARSeqBackbone( + vocabulary_size=97, + max_label_length=25, + image_encoder=image_encoder, + num_decoder_heads=12, + num_decoder_layers=1, + decoder_hidden_dim=384, + decoder_mlp_dim=4 * 384, + ) + # Create the PARSeq model + parseq = keras_hub.models.PARSeqCausalLM( + backbone=backbone, + preprocessor=preprocessor, + ) + parseq.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(5e-5), + ) + parseq.fit( + x={ + "images": images, + "token_ids": token_ids, + "padding_mask": padding_mask + }, + batch_size=2, + ) + ``` + """ + + backbone_cls = PARSeqBackbone + preprocessor_cls = PARSeqCausalLMPreprocessor + + def __init__( + self, + preprocessor, + backbone, + num_perms=6, + add_forward_perms=True, + add_mirrored_perms=True, + seed=None, + end_token_id=0, # default tokenizer.end_token_id + **kwargs, + ): + # === Layers === + self.preprocessor = preprocessor + self.backbone = backbone + + # === Functional Model === + # This must be "backbone.input" i.e. the full input structure, + # rather than "backbone.inputs" which is the flattened list of inputs. 
+ inputs = backbone.input + outputs = backbone(inputs=inputs) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + # === Config === + self.num_perms = num_perms + self.add_forward_perms = add_forward_perms + self.add_mirrored_perms = add_mirrored_perms + self.end_token_id = end_token_id + self.seed = seed + self.seed_generator = keras.random.SeedGenerator(seed) + + def get_config(self): + config = super().get_config() + config.update( + { + "num_perms": self.num_perms, + "add_forward_perms": self.add_forward_perms, + "add_mirrored_perms": self.add_mirrored_perms, + "seed": self.seed, + "end_token_id": self.end_token_id, + } + ) + + return config + + def compile( + self, + optimizer="auto", + loss="auto", + *, + weighted_metrics="auto", + sampler="greedy", + **kwargs, + ): + if loss == "auto": + loss = keras.losses.SparseCategoricalCrossentropy( + from_logits=True, + ignore_class=self.preprocessor.tokenizer.pad_token_id, + ) + super().compile( + optimizer=optimizer, + loss=loss, + weighted_metrics=weighted_metrics, + sampler=sampler, + **kwargs, + ) + + def compute_loss( + self, x, y, y_pred, sample_weight, training=True, *args, **kwargs + ): + # For keras we have fixed input for all batches, so in this case + # we permute 23 tokens excluding BOS and EOS tokens instead of max + # characters for current batch used in torch implementation + # -1 because we will be generating permutation mask for considering + # tokens before creating target label. + max_num_chars = self.backbone.max_label_length - 1 + perms = self.generate_training_permutations(max_num_chars) + max_label_length = self.backbone.max_label_length + memory = self.backbone.image_encoder(x["images"]) + batch_size = ops.shape(x["images"])[0] + losses = [] + for i in range(ops.shape(perms)[0]): + query_mask, content_mask = self.generate_attention_masks(perms[i]) + query_mask = ops.broadcast_to( + query_mask, (batch_size, max_label_length, max_label_length) + ) + content_mask = ops.broadcast_to( + content_mask, (batch_size, max_label_length, max_label_length) + ) + out = self.backbone.decoder( + x["token_ids"], + memory, + padding_mask=x["padding_mask"], + query_mask=query_mask, + content_mask=content_mask, + ) + y_pred = self.backbone.head(out) + loss = super().compute_loss( + x=x, y=y, y_pred=y_pred, sample_weight=sample_weight, **kwargs + ) + losses.append(loss) + if i == 1: + # Sample weights are set to zero for end-of-sequence (EOS) + # tokens to prevent them from affecting loss calculations. 
+ # reference: https://github.com/baudm/parseq/blob/1902db043c029a7e03a3818c616c06600af574be/strhub/models/parseq/system.py#L194 # noqa: E501 + sample_weight = ops.logical_and( + y != self.end_token_id, sample_weight + ) + + return ops.sum(losses) / ops.shape(perms)[0] + + def generate_training_permutations(self, max_num_chars): + max_gen_perms = ( + self.num_perms // 2 if self.add_mirrored_perms else self.num_perms + ) + + if max_num_chars == 1: + return ops.expand_dims(ops.arange(3), axis=0) + + perms = [ops.arange(max_num_chars)] if self.add_forward_perms else [] + max_num_perms = math.factorial(max_num_chars) + max_gen_perms = min(max_gen_perms, max_num_perms) + + for _ in range(max_gen_perms - len(perms)): + perm = random.shuffle( + ops.arange(max_num_chars), seed=self.seed_generator + ) + perms.append(perm) + + perms = ops.stack(perms) + comp = ops.flip(perms, axis=-1) + perms = ops.stack([perms, comp]) + perms = ops.reshape( + ops.transpose(perms, (1, 0, 2)), (-1, max_num_chars) + ) + + bos_idx = ops.zeros((ops.shape(perms)[0], 1), dtype="int32") + eos_idx = ops.full( + (ops.shape(perms)[0], 1), max_num_chars + 1, dtype="int32" + ) + perms = ops.concatenate([bos_idx, perms + 1, eos_idx], axis=1) + + if perms.shape[0] > 1: + perms = ops.scatter_update( + perms, + ops.concatenate( + [ + ops.ones((max_num_chars + 1, 1), dtype="int32"), + ops.expand_dims( + ops.arange(1, max_num_chars + 2, dtype="int32"), + axis=1, + ), + ], + axis=1, + ), + max_num_chars + 1 - ops.arange(max_num_chars + 1), + ) + + return perms + + def generate_attention_masks(self, perm): + """Generate attention masks given a sequence permutation + (includes pos. for BOS and EOS tokens)""" + input_length = ops.shape(perm)[0] + mask = ops.ones((input_length, input_length)) + for i in range(input_length - 1): + masked_keys = perm[i + 1 : input_length] + query_idx = ops.broadcast_to(perm[i], ops.shape(masked_keys)) + indices = ops.stack((query_idx, masked_keys), axis=1) + mask = keras.ops.scatter_update( + mask, indices, keras.ops.zeros(ops.shape(masked_keys)[0]) + ) + content_mask = mask[:-1, :-1] + mask = mask * (1 - ops.eye(input_length)) + query_mask = mask[1:, :-1] + return query_mask, content_mask + + def call_with_cache( + self, + token_ids, + cache, + cache_update_index, + img_embeddings, + padding_mask=None, + ): + bs = ops.shape(token_ids)[0] + # stands for the null context. We only supply position information + # for characters after . + content = ops.where( + cache_update_index == 0, + self.backbone.decoder_hidden_dim**0.5 + * self.backbone.decoder.token_embedding(token_ids), + ops.expand_dims( + self.backbone.decoder.pos_query_embeddings[ + :, cache_update_index - 1, : + ], + axis=0, + ) + + self.backbone.decoder_hidden_dim**0.5 + * self.backbone.decoder.token_embedding(token_ids), + ) + content = self.backbone.decoder.dropout(content) + + query = ops.ones((bs, 1, 1)) * ops.expand_dims( + self.backbone.decoder.pos_query_embeddings[ + :, cache_update_index, : + ], + axis=0, + ) + query = self.backbone.decoder.dropout(query) + + query_cache = [] + content_cache = [] + for i, decoder_layer in enumerate(self.backbone.decoder.decoder_layers): + last = i == self.backbone.num_decoder_layers - 1 + current_query_cache = cache[:, i, 0, ...] + current_content_cache = cache[:, i, 1, ...] 
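+            # Along axis 2 of the cache (see `_build_cache`), index 0 holds the
+            # query stream's self-attention cache and index 1 the content
+            # stream's, each as stacked key/value states for this layer.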
+ ( + query, + content, + query_self_attention_new_cache, + content_self_attention_cache, + ) = decoder_layer( + query=query, + content=content, + memory=img_embeddings, + padding_mask=padding_mask, + update_content=not last, + query_self_attention_cache=current_query_cache, + query_self_attention_cache_update_index=cache_update_index, + content_self_attention_cache=current_content_cache, + content_self_attention_cache_update_index=cache_update_index, + ) + query_cache.append(query_self_attention_new_cache) + content_cache.append(content_self_attention_cache) + + query_cache = ops.stack(query_cache, axis=1) + content_cache = ops.stack(content_cache, axis=1) + cache = ops.stack([query_cache, content_cache], axis=2) + hidden_states = self.backbone.decoder.layer_norm(query) + logits = self.backbone.head(hidden_states) + return logits, hidden_states, cache + + def _build_cache(self, token_ids, img_embeddings, padding_mask): + batch_size = ops.shape(token_ids)[0] + max_length = ops.shape(token_ids)[1] + num_layers = self.backbone.num_decoder_layers + head_dim = ( + self.backbone.decoder_hidden_dim // self.backbone.num_decoder_heads + ) + num_heads = self.backbone.num_decoder_heads + shape = [batch_size, num_layers, 2, 2, max_length, num_heads, head_dim] + cache = ops.zeros(shape) + + # Seed the cache. + logits, hidden_states, cache = self.call_with_cache( + token_ids=token_ids, + img_embeddings=img_embeddings, + cache=cache, + cache_update_index=0, + padding_mask=padding_mask, + ) + return hidden_states, cache + + def generate_step(self, inputs, stop_token_ids=None): + token_ids, padding_mask, images = ( + inputs["token_ids"], + inputs["padding_mask"], + inputs["images"], + ) + images_shape = ops.shape(images) + if len(images_shape) == 3: + # Handle an unbatched image. Unlike `token_ids` and `padding_mask` + # this will not automatically be upranked. + images = ops.expand_dims(images, axis=0) + + img_embeddings = self.backbone.image_encoder(images) + # Create and seed cache with a single forward pass. + hidden_states, cache = self._build_cache( + token_ids=token_ids, + img_embeddings=img_embeddings, + padding_mask=padding_mask, + ) + # Compute the lengths of all user inputted tokens ids. + row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) + # Start at the first index that has no user inputted id. + index = ops.min(row_lengths) + + def next(prompt, cache, index): + # The cache index is the index of our previous token. + cache_update_index = index - 1 + batch_size = ops.shape(prompt)[0] + prompt = ops.slice(prompt, [0, index - 1], [batch_size, 1]) + logits, hidden_states, cache = self.call_with_cache( + token_ids=prompt, + cache=cache, + cache_update_index=cache_update_index, + img_embeddings=img_embeddings, + ) + return ( + ops.squeeze(logits, axis=1), + ops.squeeze(hidden_states, axis=1), + cache, + ) + + token_ids = self.sampler( + next=next, + prompt=token_ids, + cache=cache, + index=index, + mask=padding_mask, + stop_token_ids=stop_token_ids, + hidden_states=hidden_states, + model=self, + ) + + # Compute an output padding mask with the token ids we updated. + if stop_token_ids is not None: + # Build a mask of `stop_token_ids` locations not in the original + # prompt (not in locations where `padding_mask` is True). + end_locations = any_equal( + token_ids, stop_token_ids, ops.logical_not(padding_mask) + ) + + end_locations = ops.cast(end_locations, "int32") + # Use cumsum to get ones in all locations after end_locations. 
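+            # e.g. end_locations = [0, 0, 1, 0] gives cumsum = [0, 0, 1, 1] and
+            # overflow = [0, 0, 0, 1], so only positions after the stop token
+            # are masked out of the returned padding mask.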
+ cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") + overflow = cumsum - end_locations + # Our padding mask is the inverse of these overflow locations. + padding_mask = ops.logical_not(ops.cast(overflow, "bool")) + else: + # Without early stopping, all locations will have been updated. + padding_mask = ops.ones_like(token_ids, dtype="bool") + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + "images": images, + } diff --git a/keras_hub/src/models/parseq/parseq_causal_lm_preprocessor.py b/keras_hub/src/models/parseq/parseq_causal_lm_preprocessor.py new file mode 100644 index 0000000000..61df7b9e38 --- /dev/null +++ b/keras_hub/src/models/parseq/parseq_causal_lm_preprocessor.py @@ -0,0 +1,168 @@ +import keras +from keras import ops + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker +from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor +from keras_hub.src.models.parseq.parseq_backbone import PARSeqBackbone +from keras_hub.src.models.parseq.parseq_image_converter import ( + PARSeqImageConverter, +) +from keras_hub.src.models.parseq.parseq_tokenizer import PARSeqTokenizer +from keras_hub.src.utils.tensor_utils import preprocessing_function +from keras_hub.src.utils.tensor_utils import strip_to_ragged + + +@keras_hub_export("keras_hub.models.PARSeqCausalLMPreprocessor") +class PARSeqCausalLMPreprocessor(CausalLMPreprocessor): + backbone_cls = PARSeqBackbone + tokenizer_cls = PARSeqTokenizer + image_converter_cls = PARSeqImageConverter + + def __init__( + self, + image_converter=None, + tokenizer=None, + sequence_length=25, + add_start_token=True, + add_end_token=True, + **kwargs, + ): + super().__init__( + tokenizer=tokenizer, + sequence_length=sequence_length, + add_start_token=add_start_token, + add_end_token=add_end_token, + **kwargs, + ) + self.image_converter = image_converter + + def build(self, input_shape): + # Defer packer creation to `build()` so that we can be sure tokenizer + # assets have loaded when restoring a saved model. + self.packer = StartEndPacker( + start_value=self.tokenizer.start_token_id, + end_value=self.tokenizer.end_token_id, + pad_value=self.tokenizer.pad_token_id, + sequence_length=self.sequence_length, + return_padding_mask=True, + ) + self.built = True + + @preprocessing_function + def call(self, x, y=None, sample_weight=None, sequence_length=None): + """Preprocesses the input data for training. + + This method takes a dictionary containing images and text responses, + and converts them into a format suitable for training a PARSeq model. + + Args: + x: dict. A dictionary containing the input data. Must have keys + "images" and "responses". + y: The target data. Defaults to None. + sample_weight: The sample weights. Defaults to None. + sequence_length: int. The maximum length of the input sequence. + Defaults to None, which uses the pre-defined sequence length. + + Returns: + A tuple containing the preprocessed input data, target data, and + sample weights. 
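+
+        Note that `x["token_ids"]` and `x["padding_mask"]` drop the last
+        packed position, while `y` and the sample weights are the same
+        sequences shifted one position ahead, so each position is trained to
+        predict the next character.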
+ """ + sequence_length = sequence_length or self.sequence_length + images, responses = x["images"], x["responses"] + if self.image_converter: + images = self.image_converter(images) + token_ids = self.tokenizer(responses) + token_ids, padding_mask = self.packer( + token_ids, + sequence_length=sequence_length + 1, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + x = { + "images": images, + "token_ids": token_ids[..., :-1], + "padding_mask": padding_mask[..., :-1], + } + # Target `y` will be the next token. + y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:] + return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) + + @preprocessing_function + def generate_preprocess( + self, + x, + sequence_length=None, + ): + """Convert strings to integer token input for generation. + + Similar to calling the layer for training, this method takes in strings + or tensor strings, tokenizes and packs the input, and computes a padding + mask masking all inputs not filled in with a padded value. + + Unlike calling the layer for training, this method does not compute + labels and will never append a `tokenizer.end_token_id` to the end of + the sequence (as generation is expected to continue at the end of the + inputted prompt). + """ + if not self.built: + self.build(None) + sequence_length = sequence_length or self.sequence_length + images = x + if self.image_converter: + images = self.image_converter(images) + + images_shape = keras.ops.shape(images) + if len(images_shape) == 3: + batch_size = 1 + else: + batch_size = images_shape[0] + + token_ids = ops.concatenate( + ( + ops.full([batch_size, 1], self.tokenizer.start_token_id), + ops.full( + [batch_size, sequence_length - 1], + self.tokenizer.pad_token_id, + ), + ), + axis=1, + ) + + padding_mask = ops.equal(token_ids, self.tokenizer.start_token_id) + + return { + "images": images, + "token_ids": token_ids, + "padding_mask": padding_mask, + } + + @preprocessing_function + def generate_postprocess( + self, + x, + ): + """Convert integer token output to strings for generation. + + This method reverses `generate_preprocess()`, by first removing all + padding and start/end tokens, and then converting the integer sequence + back to a string. 
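+
+        Note that detokenization yields one string per character; callers
+        typically join these to recover the full label.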
+ """ + if not self.built: + self.build(None) + + token_ids, padding_mask = x["token_ids"], x["padding_mask"] + ids_to_strip = self.tokenizer.special_token_ids + token_ids = strip_to_ragged(token_ids, padding_mask, ids_to_strip) + return self.tokenizer.detokenize(token_ids) + + def get_config(self): + config = super().get_config() + config.update( + { + "sequence_length": self.sequence_length, + "add_start_token": self.add_start_token, + "add_end_token": self.add_end_token, + } + ) + return config diff --git a/keras_hub/src/models/parseq/parseq_causal_lm_preprocessor_test.py b/keras_hub/src/models/parseq/parseq_causal_lm_preprocessor_test.py new file mode 100644 index 0000000000..ffab17c228 --- /dev/null +++ b/keras_hub/src/models/parseq/parseq_causal_lm_preprocessor_test.py @@ -0,0 +1,110 @@ +import numpy as np +import pytest + +from keras_hub.src.models.parseq.parseq_causal_lm_preprocessor import ( + PARSeqCausalLMPreprocessor, +) +from keras_hub.src.models.parseq.parseq_image_converter import ( + PARSeqImageConverter, +) +from keras_hub.src.models.parseq.parseq_tokenizer import PARSeqTokenizer +from keras_hub.src.tests.test_case import TestCase + + +class PARSeqCausalLMPreprocessorTest(TestCase): + def setUp(self): + self.tokenizer = PARSeqTokenizer() + self.image_converter = PARSeqImageConverter(image_size=(32, 128)) + + self.init_kwargs = { + "tokenizer": self.tokenizer, + "image_converter": self.image_converter, + "sequence_length": 9, + } + self.input_data = { + "images": [np.zeros([32, 128, 3])], + "responses": ["Google"], + } + + def test_preprocessor_basics(self): + self.run_preprocessing_layer_test( + cls=PARSeqCausalLMPreprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=( + { + "token_ids": [[95, 43, 25, 25, 17, 22, 15, 0, 96]], + "padding_mask": [ + [True, True, True, True, True, True, True, True, False] + ], + "images": np.zeros([1, 32, 128, 3]), + }, + [[43, 25, 25, 17, 22, 15, 0, 96, 96]], # Labels shifted. 
+ [[True, True, True, True, True, True, True, False, False]], + ), + ) + + def test_no_start_end_token(self): + input_data = { + "responses": ["Google"] * 4, + "images": [np.zeros([512, 512, 3])] * 4, + } + preprocessor = PARSeqCausalLMPreprocessor( + **self.init_kwargs, + add_start_token=False, + add_end_token=False, + ) + x, y, sw = preprocessor(input_data) + self.assertAllEqual( + x["token_ids"], [[43, 25, 25, 17, 22, 15, 96, 96, 96]] * 4 + ) + self.assertAllEqual( + x["padding_mask"], + [[True, True, True, True, True, True, False, False, False]] * 4, + ) + self.assertAllEqual(x["images"], np.zeros([4, 32, 128, 3])) + self.assertAllEqual(y, [[25, 25, 17, 22, 15, 96, 96, 96, 96]] * 4) + self.assertAllEqual( + sw, [[True, True, True, True, True, False, False, False, False]] * 4 + ) + + def test_generate_preprocess(self): + input_data = np.zeros([1, 32, 128, 3]) + preprocessor = PARSeqCausalLMPreprocessor(**self.init_kwargs) + x = preprocessor.generate_preprocess(input_data) + self.assertAllEqual( + x["token_ids"], [[95, 96, 96, 96, 96, 96, 96, 96, 96]] + ) + self.assertAllEqual( + x["padding_mask"], + [[True, False, False, False, False, False, False, False, False]], + ) + self.assertAllEqual(x["images"], np.zeros([1, 32, 128, 3])) + + def test_generate_postprocess(self): + input_data = { + "token_ids": [43, 25, 25, 17, 22, 15, 0, 96, 96], + "padding_mask": [ + True, + True, + True, + True, + True, + True, + True, + False, + False, + ], + } + preprocessor = PARSeqCausalLMPreprocessor(**self.init_kwargs) + x = preprocessor.generate_postprocess(input_data) + self.assertAllEqual(x, ["G", "o", "o", "g", "l", "e"]) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in PARSeqCausalLMPreprocessor.presets: + self.run_preset_test( + cls=PARSeqCausalLMPreprocessor, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/parseq/parseq_causal_lm_test.py b/keras_hub/src/models/parseq/parseq_causal_lm_test.py new file mode 100644 index 0000000000..177c596521 --- /dev/null +++ b/keras_hub/src/models/parseq/parseq_causal_lm_test.py @@ -0,0 +1,103 @@ +import numpy as np +import pytest + +from keras_hub.src.models.parseq.parseq_backbone import PARSeqBackbone +from keras_hub.src.models.parseq.parseq_causal_lm import PARSeqCausalLM +from keras_hub.src.models.parseq.parseq_causal_lm_preprocessor import ( + PARSeqCausalLMPreprocessor, +) +from keras_hub.src.models.parseq.parseq_image_converter import ( + PARSeqImageConverter, +) +from keras_hub.src.models.parseq.parseq_tokenizer import PARSeqTokenizer +from keras_hub.src.models.vit.vit_backbone import ViTBackbone +from keras_hub.src.tests.test_case import TestCase + + +class PARSeqCausalLMTest(TestCase): + def setUp(self): + self.batch_size = 2 + self.image_height = 32 + self.image_width = 128 + self.num_channels = 3 + + # Image Encoder parameters (as per your example) + self.vit_patch_size = (4, 8) + self.vit_num_layers = 2 + self.vit_num_heads = 2 + self.vit_hidden_dim = 64 + self.vit_mlp_dim = self.vit_hidden_dim * 4 + + # PARSeq Backbone parameters + self.vocabulary_size = 97 + self.max_label_length = 25 + self.decoder_hidden_dim = self.vit_hidden_dim + self.num_decoder_layers = 1 + self.num_decoder_heads = 2 + self.decoder_mlp_dim = self.decoder_hidden_dim * 4 + + image_converter = PARSeqImageConverter( + image_size=[32, 128], + offset=-1, + scale=1.0 / 255.0 / 0.5, + interpolation="bicubic", + ) + tokenizer = PARSeqTokenizer() + + preprocessor = PARSeqCausalLMPreprocessor( + 
image_converter=image_converter, tokenizer=tokenizer + ) + + image_encoder = ViTBackbone( + image_shape=( + self.image_height, + self.image_width, + self.num_channels, + ), + patch_size=self.vit_patch_size, + num_layers=self.vit_num_layers, + num_heads=self.vit_num_heads, + hidden_dim=self.vit_hidden_dim, + mlp_dim=self.vit_mlp_dim, + use_class_token=False, + name="image_encoder", + ) + + backbone = PARSeqBackbone( + image_encoder=image_encoder, + vocabulary_size=self.vocabulary_size, + max_label_length=self.max_label_length, + num_decoder_heads=self.num_decoder_heads, + num_decoder_layers=self.num_decoder_layers, + decoder_hidden_dim=self.decoder_hidden_dim, + decoder_mlp_dim=self.decoder_mlp_dim, + ) + + self.init_kwargs = {"preprocessor": preprocessor, "backbone": backbone} + + # Dummy input data + dummy_images = np.random.randn( + self.batch_size, + self.image_height, + self.image_width, + self.num_channels, + ) + + self.train_data = ( + {"images": dummy_images, "responses": ["abc", "xyz"]}, + ) + + @pytest.mark.large + def test_causal_lm_basics(self): + expected_shape_full = ( + self.batch_size, + self.max_label_length, + self.vocabulary_size - 2, + ) + + self.run_task_test( + cls=PARSeqCausalLM, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=expected_shape_full, + ) diff --git a/keras_hub/src/models/parseq/parseq_decoder.py b/keras_hub/src/models/parseq/parseq_decoder.py new file mode 100644 index 0000000000..69303c7e67 --- /dev/null +++ b/keras_hub/src/models/parseq/parseq_decoder.py @@ -0,0 +1,418 @@ +import keras +from keras import ops + +from keras_hub.src.layers.modeling.cached_multi_head_attention import ( + CachedMultiHeadAttention, +) +from keras_hub.src.layers.modeling.transformer_layer_utils import ( + compute_causal_mask, +) +from keras_hub.src.layers.modeling.transformer_layer_utils import ( + merge_padding_and_attention_mask, +) +from keras_hub.src.models.vit.vit_layers import MLP + + +class PARSeqDecoderBlock(keras.layers.Layer): + """A decoder block for the PARSeq model. + + This block consists of self-attention, cross-attention, and a multilayer + perceptron (MLP). It also includes layer normalization and dropout layers. + + Args: + hidden_dim: int. The dimension of the hidden layers. + num_heads: int. The number of attention heads. + mlp_dim: int. The dimension of the MLP hidden layer. + dropout_rate: float. The dropout rate used in the feedforward layers. + attention_dropout: float. The dropout rate for the attention weights. + layer_norm_epsilon: float. A small float added to the denominator for + numerical stability in layer normalization. + **kwargs: Additional keyword arguments passed to the base + `keras.layers.Layer` constructor. 
+ """ + + def __init__( + self, + hidden_dim, + num_heads, + mlp_dim, + dropout_rate=0.1, + attention_dropout=0.1, + layer_norm_epsilon=1e-5, + **kwargs, + ): + super().__init__(**kwargs) + + key_dim = hidden_dim // num_heads + + # === Config === + self.hidden_dim = hidden_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.key_dim = key_dim + self.dropout_rate = dropout_rate + self.attention_dropout = attention_dropout + self.layer_norm_epsilon = layer_norm_epsilon + + def build(self, input_shape): + self.query_layer_norm = keras.layers.LayerNormalization( + epsilon=self.layer_norm_epsilon, + name="query_layer_norm", + dtype=self.dtype_policy, + ) + self.query_layer_norm.build(input_shape) + self.content_layer_norm = keras.layers.LayerNormalization( + epsilon=self.layer_norm_epsilon, + name="content_layer_norm", + dtype=self.dtype_policy, + ) + self.content_layer_norm.build(input_shape) + self.self_attention = CachedMultiHeadAttention( + num_heads=self.num_heads, + key_dim=self.key_dim, + dropout=self.attention_dropout, + name="self_attention", + dtype=self.dtype_policy, + ) + self.self_attention.build(input_shape, input_shape) + self.cross_attention = CachedMultiHeadAttention( + num_heads=self.num_heads, + key_dim=self.key_dim, + dropout=self.attention_dropout, + name="cross_attention", + dtype=self.dtype_policy, + ) + self.cross_attention.build(input_shape, input_shape) + + self.layer_norm_1 = keras.layers.LayerNormalization( + epsilon=self.layer_norm_epsilon, + name="ln_1", + dtype=self.dtype_policy, + ) + self.layer_norm_1.build((None, None, self.hidden_dim)) + self.layer_norm_2 = keras.layers.LayerNormalization( + epsilon=self.layer_norm_epsilon, + name="ln_2", + dtype=self.dtype_policy, + ) + self.layer_norm_2.build((None, None, self.hidden_dim)) + self.mlp = MLP( + hidden_dim=self.hidden_dim, + mlp_dim=self.mlp_dim, + dropout_rate=self.dropout_rate, + name="mlp", + dtype=self.dtype_policy, + ) + self.mlp.build((None, None, self.hidden_dim)) + self.dropout = keras.layers.Dropout( + rate=self.dropout_rate, + dtype=self.dtype_policy, + name="decoder_block_dropout", + ) + + self.built = True + + def forward_stream( + self, + target, + target_norm, + target_kv, + memory, + padding_mask=None, + self_attention_cache=None, + self_attention_cache_update_index=0, + train_attention_mask=None, + ): + self_attention_new_cache = None + if train_attention_mask is None: + target_attention_mask = self._compute_attention_mask( + target_norm, + padding_mask, + self_attention_cache, + self_attention_cache_update_index, + ) + else: + target_attention_mask = merge_padding_and_attention_mask( + target_norm, padding_mask, attention_mask=train_attention_mask + ) + + if self_attention_cache is not None: + target2, self_attention_new_cache = self.self_attention( + target_norm, + target_kv, + target_kv, + attention_mask=target_attention_mask, + cache=self_attention_cache, + cache_update_index=self_attention_cache_update_index, + ) + else: + target2 = self.self_attention( + target_norm, + target_kv, + target_kv, + attention_mask=target_attention_mask, + ) + target = ops.add(target, self.dropout(target2)) + target2 = self.cross_attention( + self.layer_norm_1(target), + memory, + memory, + ) + target = ops.add(target, self.dropout(target2)) + + target2 = self.mlp(self.layer_norm_2(target)) + target = ops.add(target, target2) + + return target, self_attention_new_cache + + def call( + self, + query, + content, + memory, + padding_mask=None, + update_content=True, + query_self_attention_cache=None, + 
query_self_attention_cache_update_index=0, + content_self_attention_cache=None, + content_self_attention_cache_update_index=0, + query_mask=None, + content_mask=None, + ): + # position + token embeddings + query_norm = self.query_layer_norm(query) + # position embeddings + content_norm = self.content_layer_norm(content) + ( + query, + query_self_attention_new_cache, + ) = self.forward_stream( + query, + query_norm, + content_norm, + memory, + padding_mask=padding_mask, + train_attention_mask=query_mask, + self_attention_cache=query_self_attention_cache, + self_attention_cache_update_index=query_self_attention_cache_update_index, + ) + + if update_content: + ( + content, + content_self_attention_new_cache, + ) = self.forward_stream( + content, + content_norm, + content_norm, + memory, # image embeddings (encoder embeddings) + padding_mask=padding_mask, + train_attention_mask=content_mask, + self_attention_cache=content_self_attention_cache, + self_attention_cache_update_index=content_self_attention_cache_update_index, + ) + + return_values = [query, content] + + if query_self_attention_cache is not None: + return_values.append(query_self_attention_new_cache) + if update_content and content_self_attention_cache is not None: + return_values.append(content_self_attention_new_cache) + elif not update_content and content_self_attention_cache is not None: + return_values.append(content_self_attention_cache) + + return tuple(return_values) + + def _compute_attention_mask( + self, x, padding_mask, cache, cache_update_index + ): + decoder_mask = merge_padding_and_attention_mask( + inputs=x, padding_mask=padding_mask, attention_mask=None + ) + batch_size = ops.shape(x)[0] + input_length = output_length = ops.shape(x)[1] + if cache is not None: + input_length = ops.shape(cache)[2] + + causal_mask = compute_causal_mask( + batch_size=batch_size, + input_length=input_length, + output_length=output_length, + cache_index=cache_update_index, + ) + + return ( + ops.minimum(decoder_mask, causal_mask) + if decoder_mask is not None + else causal_mask + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "num_heads": self.num_heads, + "hidden_dim": self.hidden_dim, + "key_dim": self.key_dim, + "mlp_dim": self.mlp_dim, + "dropout_rate": self.dropout_rate, + "attention_dropout": self.attention_dropout, + "layer_norm_epsilon": self.layer_norm_epsilon, + } + ) + return config + + +class PARSeqDecoder(keras.layers.Layer): + """The PARSeq decoder. + + This decoder consists of multiple decoder blocks and a token embedding + layer. It takes token IDs and memory from the encoder as input and outputs a + sequence of hidden states. + + Args: + vocabulary_size: int. The size of the vocabulary. + max_label_length: int. The maximum length of the label sequence. + num_layers: int. The number of decoder layers. + hidden_dim: int. The dimension of the hidden layers. + mlp_dim: int. The dimension of the MLP hidden layer. + num_heads: int. The number of attention heads. + dropout_rate: float. The dropout rate. + attention_dropout: float. The dropout rate for the attention weights. + layer_norm_epsilon: float. A small float added to the denominator for + numerical stability in layer normalization. + **kwargs: Additional keyword arguments passed to the base + `keras.layers.Layer` constructor. 
+ """ + + def __init__( + self, + vocabulary_size, + max_label_length, + num_layers, + hidden_dim, + mlp_dim, + num_heads, + dropout_rate=0.1, + attention_dropout=0.1, + layer_norm_epsilon=1e-5, + **kwargs, + ): + super().__init__(**kwargs) + + # === Config === + self.vocabulary_size = vocabulary_size + self.max_label_length = max_label_length + self.hidden_dim = hidden_dim + self.mlp_dim = mlp_dim + self.num_heads = num_heads + self.dropout_rate = dropout_rate + self.attention_dropout = attention_dropout + self.layer_norm_epsilon = layer_norm_epsilon + self.num_layers = num_layers + + def build(self, input_shape): + self.token_embedding = keras.layers.Embedding( + input_dim=self.vocabulary_size, + output_dim=self.hidden_dim, + dtype=self.dtype_policy, + name="token_embedding", + ) + self.token_embedding.build((1, self.vocabulary_size)) + self.pos_query_embeddings = self.add_weight( + shape=(1, self.max_label_length + 1, self.hidden_dim), + name="pos_query_embeddings", + dtype=self.dtype, + ) + self.dropout = keras.layers.Dropout( + self.dropout_rate, dtype=self.dtype_policy, name="decoder_dropout" + ) + self.decoder_layers = [] + for i in range(self.num_layers): + decoder_layer = PARSeqDecoderBlock( + hidden_dim=self.hidden_dim, + num_heads=self.num_heads, + mlp_dim=self.mlp_dim, + dropout_rate=self.dropout_rate, + attention_dropout=self.attention_dropout, + layer_norm_epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name=f"decoder_layer_{i}", + ) + decoder_layer.build((None, None, self.hidden_dim)) + self.decoder_layers.append(decoder_layer) + + self.layer_norm = keras.layers.LayerNormalization( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="layer_norm", + ) + self.layer_norm.build((None, None, self.hidden_dim)) + self.built = True + + def call( + self, + token_ids, + memory, + padding_mask=None, + query_mask=None, + content_mask=None, + ): + bs, tokens_length = ops.shape(token_ids) + # stands for the null context. We only supply position information + # for characters after . 
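+        # (The first position, `token_ids[:, :1]`, is the [B] start token: it
+        # is embedded and scaled by sqrt(hidden_dim) but, unlike the positions
+        # after it, gets no positional query embedding added.)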
+ null_context = self.hidden_dim**0.5 * self.token_embedding( + token_ids[:, :1] + ) + if tokens_length > 1: + content = self.pos_query_embeddings[:, : tokens_length - 1, :] + content = content + self.hidden_dim**0.5 * self.token_embedding( + token_ids[:, 1:] + ) + content = ops.concatenate([null_context, content], axis=1) + else: + content = null_context + + content = self.dropout(content) + + query = ops.multiply( + ops.ones((bs, 1, 1), dtype=self.dtype), + self.pos_query_embeddings[:, :tokens_length, :], + ) + query = self.dropout(query) + + for i, decoder_layer in enumerate(self.decoder_layers): + last = i == self.num_layers - 1 + query, content = decoder_layer( + query=query, + content=content, + memory=memory, + padding_mask=padding_mask, + update_content=not last, + query_mask=query_mask, + content_mask=content_mask, + ) + + query = self.layer_norm(query) + + return query + + def compute_output_shape(self, input_shape): + return (None, None, self.hidden_dim) + + def get_config(self): + config = super().get_config() + config.update( + { + "vocabulary_size": self.vocabulary_size, + "max_label_length": self.max_label_length, + "num_layers": self.num_layers, + "num_heads": self.num_heads, + "hidden_dim": self.hidden_dim, + "mlp_dim": self.mlp_dim, + "dropout_rate": self.dropout_rate, + "attention_dropout": self.attention_dropout, + "layer_norm_epsilon": self.layer_norm_epsilon, + } + ) + return config diff --git a/keras_hub/src/models/parseq/parseq_image_converter.py b/keras_hub/src/models/parseq/parseq_image_converter.py new file mode 100644 index 0000000000..86f5f85435 --- /dev/null +++ b/keras_hub/src/models/parseq/parseq_image_converter.py @@ -0,0 +1,8 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter +from keras_hub.src.models.parseq.parseq_backbone import PARSeqBackbone + + +@keras_hub_export("keras_hub.layers.PARSeqImageConverter") +class PARSeqImageConverter(ImageConverter): + backbone_cls = PARSeqBackbone diff --git a/keras_hub/src/models/parseq/parseq_tokenizer.py b/keras_hub/src/models/parseq/parseq_tokenizer.py new file mode 100644 index 0000000000..6c12d61542 --- /dev/null +++ b/keras_hub/src/models/parseq/parseq_tokenizer.py @@ -0,0 +1,221 @@ +import os +import re +from typing import Iterable + +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.tokenizers import tokenizer +from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch +from keras_hub.src.utils.tensor_utils import is_int_dtype +from keras_hub.src.utils.tensor_utils import is_string_dtype +from keras_hub.src.utils.tensor_utils import preprocessing_function + +try: + import tensorflow as tf + import tensorflow_text as tf_text +except ImportError: + tf = None + tf_text = None + +PARSEQ_VOCAB = list( + "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!" + "\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~" +) + +VOCAB_FILENAME = "vocabulary.txt" + + +@keras_hub_export( + [ + "keras_hub.tokenizers.PARSeqTokenizer", + "keras_hub.models.PARSeqTokenizer", + ] +) +class PARSeqTokenizer(tokenizer.Tokenizer): + """A Tokenizer for PARSeq models, designed for OCR tasks. + + This tokenizer converts strings into sequences of integer IDs or string + tokens, and vice-versa. It supports various preprocessing steps such as + whitespace removal, Unicode normalization, and limiting the maximum label + length. It also provides functionality to save and load the vocabulary + from a file. 
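+
+    Character ids are offset by one relative to `vocabulary`: id 0 is reserved
+    for the `[E]` end token, while `[B]` (start) and `[P]` (pad) take the two
+    ids after the last character.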
+ + Args: + vocabulary: str. A string or iterable representing the vocabulary to + use. If a string, it's treated as the path to a vocabulary file. + If an iterable, it's treated as a list of characters forming + the vocabulary. Defaults to `PARSEQ_VOCAB`. + remove_whitespace: bool. Whether to remove whitespace characters from + the input. Defaults to `True`. + normalize_unicode: bool. Whether to normalize Unicode characters in the + input using NFKD normalization and remove non-ASCII characters. + Defaults to `True`. + max_label_length: int. The maximum length of the tokenized output. + Longer labels will be truncated. Defaults to `25`. + dtype: str. The data type of the tokenized output. Must be an integer + type (e.g., "int32") or a string type ("string"). + Defaults to `"int32"`. + **kwargs: Additional keyword arguments passed to the base + `keras.layers.Layer` constructor. + """ + + def __init__( + self, + vocabulary=PARSEQ_VOCAB, + remove_whitespace=True, + normalize_unicode=True, + max_label_length=25, + dtype="int32", + **kwargs, + ): + if not is_int_dtype(dtype) and not is_string_dtype(dtype): + raise ValueError( + "Output dtype must be an integer type or a string. " + f"Received: dtype={dtype}" + ) + super().__init__(dtype=dtype, **kwargs) + self.remove_whitespace = remove_whitespace + self.normalize_unicode = normalize_unicode + self.max_label_length = max_label_length + self.file_assets = [VOCAB_FILENAME] + + self.set_vocabulary(vocabulary) + + def save_assets(self, dir_path): + path = os.path.join(dir_path, VOCAB_FILENAME) + with open(path, "w", encoding="utf-8") as file: + for token in self.vocabulary: + file.write(f"{token}\n") + + def load_assets(self, dir_path): + path = os.path.join(dir_path, VOCAB_FILENAME) + self.set_vocabulary(path) + + def set_vocabulary(self, vocabulary): + """Set the tokenizer vocabulary to a file or list of strings.""" + if vocabulary is None: + self.vocabulary = None + return + + if isinstance(vocabulary, str): + with open(vocabulary, "r", encoding="utf-8") as file: + self.vocabulary = [line.rstrip() for line in file] + self.vocabulary = "".join(self.vocabulary) + elif isinstance(vocabulary, Iterable): + self.vocabulary = "".join(vocabulary) + else: + raise ValueError( + "Vocabulary must be an file path or list of terms. " + f"Received: vocabulary={vocabulary}" + ) + + self.lowercase_only = self.vocabulary == self.vocabulary.lower() + self.uppercase_only = self.vocabulary == self.vocabulary.upper() + escaped_charset = re.escape(self.vocabulary) # Escape for safe regex + self.unsupported_regex = f"[^{escaped_charset}]" + self._itos = ("[E]",) + tuple(self.vocabulary) + ("[B]", "[P]") + self._stoi = {s: i for i, s in enumerate(self._itos)} + + self._add_special_token("[B]", "start_token") + self._add_special_token("[E]", "end_token") + self._add_special_token("[P]", "pad_token") + # Create lookup tables. 
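+        # `char_to_id` maps characters to ids (unknown characters fall back to
+        # the `[E]` id) and `id_to_char` maps ids back to characters (unknown
+        # ids fall back to the `[P]` pad token).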
+ self.char_to_id = tf.lookup.StaticHashTable( + initializer=tf.lookup.KeyValueTensorInitializer( + keys=list(self._stoi.keys()), + values=list(self._stoi.values()), + key_dtype=tf.string, + value_dtype=tf.int32, + ), + default_value=self._stoi["[E]"], + ) + self.id_to_char = tf.lookup.StaticHashTable( + initializer=tf.lookup.KeyValueTensorInitializer( + keys=list(self._stoi.values()), + values=list(self._stoi.keys()), + key_dtype=tf.int32, + value_dtype=tf.string, + ), + default_value=self.pad_token, + ) + + def get_vocabulary(self): + """Get the tokenizer vocabulary as a list of strings tokens.""" + return list(self.vocabulary) + + def id_to_token(self, id): + if id >= self.vocabulary_size() or id < 0: + raise ValueError( + f"`id` must be in range [0, {self.vocabulary_size() - 1}]. " + f"Received: {id}" + ) + return self._itos[id] + + def token_to_id(self, token): + return self._stoi[token] + + def _preprocess(self, inputs): + """Performs preprocessing include only characters from ASCII.""" + if self.remove_whitespace: + inputs = tf.strings.regex_replace(inputs, r"\s+", "") + + if self.normalize_unicode: + inputs = tf_text.normalize_utf8(inputs, normalization_form="NFKD") + inputs = tf.strings.regex_replace(inputs, r"[^!-~]", "") + + if self.lowercase_only: + inputs = tf.strings.lower(inputs) + elif self.uppercase_only: + inputs = tf.strings.upper(inputs) + + inputs = tf.strings.regex_replace(inputs, self.unsupported_regex, "") + inputs = tf.strings.substr(inputs, 0, self.max_label_length) + + return inputs + + @preprocessing_function + def tokenize(self, inputs): + inputs = tf.convert_to_tensor(inputs) + unbatched = inputs.shape.rank == 0 + if unbatched: + inputs = tf.expand_dims(inputs, 0) + + inputs = tf.map_fn( + self._preprocess, inputs, fn_output_signature=tf.string + ) + + token_ids = tf.cond( + tf.size(inputs) > 0, + lambda: self.char_to_id.lookup( + tf.strings.unicode_split(inputs, "UTF-8") + ), + lambda: tf.RaggedTensor.from_row_splits( + values=tf.constant([], dtype=tf.int32), + row_splits=tf.constant([0], dtype=tf.int64), + ), + ) + if unbatched: + token_ids = tf.squeeze(token_ids, 0) + tf.ensure_shape(token_ids, shape=[self.max_label_length]) + return token_ids + + @preprocessing_function + def detokenize(self, inputs): + inputs, unbatched, rectangular = convert_to_ragged_batch(inputs) + # tf-text sentencepiece does not handle int64. 
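+        # The `id_to_char` lookup table is keyed on int32, so cast before the
+        # lookup.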
+ inputs = tf.cast(inputs, "int32") + outputs = self.id_to_char.lookup(inputs) + if unbatched: + outputs = tf.squeeze(outputs, 0) + return outputs + + def vocabulary_size(self): + """Get the integer size of the tokenizer vocabulary.""" + return len(self.vocabulary) + 3 + + def compute_output_spec(self, input_spec): + return keras.KerasTensor( + input_spec.shape + (self.max_label_length,), + dtype=self.compute_dtype, + ) diff --git a/keras_hub/src/utils/transformers/convert_vit.py b/keras_hub/src/utils/transformers/convert_vit.py index ae9a7aa19b..77bc241b94 100644 --- a/keras_hub/src/utils/transformers/convert_vit.py +++ b/keras_hub/src/utils/transformers/convert_vit.py @@ -9,7 +9,10 @@ def convert_backbone_config(transformers_config): image_size = transformers_config["image_size"] return { "image_shape": (image_size, image_size, 3), - "patch_size": transformers_config["patch_size"], + "patch_size": ( + transformers_config["patch_size"], + transformers_config["patch_size"], + ), "num_layers": transformers_config["num_hidden_layers"], "num_heads": transformers_config["num_attention_heads"], "hidden_dim": transformers_config["hidden_size"], diff --git a/tools/checkpoint_conversion/convert_parseq_checkpoints.py b/tools/checkpoint_conversion/convert_parseq_checkpoints.py new file mode 100644 index 0000000000..9be475fcbf --- /dev/null +++ b/tools/checkpoint_conversion/convert_parseq_checkpoints.py @@ -0,0 +1,349 @@ +"""Convert PARSeq checkpoints from https://github.com/baudm/parseq. + +Make sure to install `pip install pytorch_lighning` for checkpoint convertion. + +export KAGGLE_USERNAME=XXX +export KAGGLE_KEY=XXX + +python tools/checkpoint_conversion/convert_parseq_checkpoints.py \ + --preset parseq +""" + +import os +import shutil + +import keras +import numpy as np +import torch +from absl import app +from absl import flags +from PIL import Image + +import keras_hub +from keras_hub.src.models.parseq.parseq_backbone import PARSeqBackbone +from keras_hub.src.models.parseq.parseq_causal_lm import PARSeqCausalLM +from keras_hub.src.models.parseq.parseq_causal_lm_preprocessor import ( + PARSeqCausalLMPreprocessor, +) +from keras_hub.src.models.parseq.parseq_image_converter import ( + PARSeqImageConverter, +) +from keras_hub.src.models.parseq.parseq_tokenizer import PARSeqTokenizer +from keras_hub.src.models.vit.vit_backbone import ViTBackbone + +FLAGS = flags.FLAGS + +PRESET_MAP = {"parseq_vit": "baudm/parseq"} + +flags.DEFINE_string( + "preset", + None, + f"Must be one of {','.join(PRESET_MAP.keys())}", + required=True, +) + +flags.DEFINE_string( + "upload_uri", + None, + 'Could be "kaggle://keras/{variant}/keras/{preset}"', + required=False, +) + + +def get_keras_backbone(): + # Config ref: https://github.com/baudm/parseq/blob/main/configs/model/parseq.yaml # noqa: E501 + image_encoder = ViTBackbone( + image_shape=(32, 128, 3), + patch_size=(4, 8), + num_layers=12, + num_heads=6, + hidden_dim=384, + mlp_dim=384 * 4, + use_class_token=False, + name="encoder", + ) + backbone = PARSeqBackbone( + vocabulary_size=97, + max_label_length=25, + image_encoder=image_encoder, + num_decoder_heads=12, + num_decoder_layers=1, + decoder_hidden_dim=384, + decoder_mlp_dim=4 * 384, + ) + + return backbone + + +def convert_backbone_weights(backbone, torch_model): + state_dict = torch_model.state_dict() + state_dict.update(torch_model.named_buffers()) + + # Helper functions. 
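+    # `port_weights` copies a single torch tensor into a Keras variable (with
+    # an optional `hook_fn` to transpose/reshape it first); `port_ln` and
+    # `port_dense` wrap it for layer norm and dense layers, and `port_mha`
+    # splits torch's fused qkv projections into per-head query/key/value
+    # kernels and biases.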
+    def port_weights(keras_variable, weight_key, hook_fn=None):
+        torch_tensor = state_dict[weight_key].cpu().numpy()
+        if hook_fn:
+            torch_tensor = hook_fn(torch_tensor, list(keras_variable.shape))
+        keras_variable.assign(torch_tensor)
+
+    def port_ln(keras_variable, weight_key):
+        port_weights(keras_variable.gamma, f"{weight_key}.weight")
+        port_weights(keras_variable.beta, f"{weight_key}.bias")
+
+    def port_dense(keras_variable, weight_key):
+        port_weights(
+            keras_variable.kernel,
+            f"{weight_key}.weight",
+            hook_fn=lambda x, _: x.T,
+        )
+        if keras_variable.bias is not None:
+            port_weights(keras_variable.bias, f"{weight_key}.bias")
+
+    def port_mha(
+        keras_variable, weight_key, num_heads, hidden_dim, encoder=True
+    ):
+        # Attention layer.
+        if encoder:
+            fused_qkv_kernel = state_dict[f"{weight_key}.attn.qkv.weight"].t()
+            fused_qkv_bias = (
+                state_dict[f"{weight_key}.attn.qkv.bias"].cpu().numpy()
+            )
+        else:
+            fused_qkv_kernel = state_dict[f"{weight_key}.in_proj_weight"].t()
+            fused_qkv_bias = (
+                state_dict[f"{weight_key}.in_proj_bias"].cpu().numpy()
+            )
+
+        head_dim = hidden_dim // num_heads
+
+        # Kernel
+        query_kernel = fused_qkv_kernel[:, :hidden_dim]
+        query_kernel = query_kernel.reshape(hidden_dim, num_heads, head_dim)
+
+        key_kernel = fused_qkv_kernel[
+            :, hidden_dim : hidden_dim + num_heads * head_dim
+        ]
+        key_kernel = key_kernel.reshape(hidden_dim, num_heads, head_dim)
+
+        value_kernel = fused_qkv_kernel[:, hidden_dim + num_heads * head_dim :]
+        value_kernel = value_kernel.reshape(hidden_dim, num_heads, head_dim)
+
+        # Bias
+        query_bias = fused_qkv_bias[:hidden_dim]
+        query_bias = query_bias.reshape(num_heads, head_dim)
+
+        key_bias = fused_qkv_bias[
+            hidden_dim : hidden_dim + num_heads * head_dim
+        ]
+        key_bias = key_bias.reshape(num_heads, head_dim)
+
+        value_bias = fused_qkv_bias[hidden_dim + num_heads * head_dim :]
+        value_bias = value_bias.reshape(num_heads, head_dim)
+
+        keras_variable.query_dense.kernel.assign(query_kernel)
+        keras_variable.key_dense.kernel.assign(key_kernel)
+        keras_variable.value_dense.kernel.assign(value_kernel)
+
+        keras_variable.query_dense.bias.assign(query_bias)
+        keras_variable.key_dense.bias.assign(key_bias)
+        keras_variable.value_dense.bias.assign(value_bias)
+
+        if encoder:
+            keras_variable.output_dense.kernel.assign(
+                state_dict[f"{weight_key}.attn.proj.weight"]
+                .t()
+                .reshape(num_heads, head_dim, hidden_dim)
+            )
+            keras_variable.output_dense.bias.assign(
+                state_dict[f"{weight_key}.attn.proj.bias"].cpu().numpy()
+            )
+        else:
+            keras_variable.output_dense.kernel.assign(
+                state_dict[f"{weight_key}.out_proj.weight"]
+                .t()
+                .reshape(num_heads, head_dim, hidden_dim)
+            )
+            keras_variable.output_dense.bias.assign(
+                state_dict[f"{weight_key}.out_proj.bias"].cpu().numpy()
+            )
+
+    # Encoder weight transfer
+    port_weights(
+        backbone.image_encoder.layers[1].patch_embedding.kernel,
+        "model.encoder.patch_embed.proj.weight",
+        hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)),
+    )
+
+    port_weights(
+        backbone.image_encoder.layers[1].patch_embedding.bias,
+        "model.encoder.patch_embed.proj.bias",
+    )
+
+    port_weights(
+        backbone.image_encoder.layers[1].position_embedding.embeddings,
+        "model.encoder.pos_embed",
+        hook_fn=lambda x, _: x[0],
+    )
+    encoder_layers = backbone.image_encoder.layers[2].encoder_layers
+    for i, encoder_block in enumerate(encoder_layers):
+        prefix = "model.encoder.blocks"
+        num_heads = encoder_block.num_heads
+        hidden_dim = encoder_block.hidden_dim
+
+        # Decompose fused multihead attention layer from torch
+        port_mha(
+            encoder_block.mha,
+            f"{prefix}.{i}",
+            num_heads,
+            hidden_dim,
+        )
+
+        port_ln(encoder_block.layer_norm_1, f"{prefix}.{i}.norm1")
+        port_ln(encoder_block.layer_norm_2, f"{prefix}.{i}.norm2")
+
+        port_dense(encoder_block.mlp.dense_1, f"{prefix}.{i}.mlp.fc1")
+        port_dense(encoder_block.mlp.dense_2, f"{prefix}.{i}.mlp.fc2")
+    port_ln(backbone.image_encoder.layers[2].layer_norm, "model.encoder.norm")
+
+    # Decoder weight transfer
+    port_weights(
+        backbone.layers[4].pos_query_embeddings,
+        "model.pos_queries",
+    )
+    port_weights(
+        backbone.layers[4].token_embedding.embeddings,
+        "model.text_embed.embedding.weight",
+    )
+
+    decoder_layers = backbone.layers[4].decoder_layers
+    for i, decoder_block in enumerate(decoder_layers):
+        prefix = "model.decoder.layers"
+        num_heads = decoder_block.num_heads
+        hidden_dim = decoder_block.hidden_dim
+
+        port_mha(
+            decoder_block.self_attention,
+            f"{prefix}.{i}.self_attn",
+            num_heads,
+            hidden_dim,
+            encoder=False,
+        )
+        port_mha(
+            decoder_block.cross_attention,
+            f"{prefix}.{i}.cross_attn",
+            num_heads,
+            hidden_dim,
+            encoder=False,
+        )
+
+        port_ln(decoder_block.layer_norm_1, f"{prefix}.{i}.norm1")
+        port_ln(decoder_block.layer_norm_2, f"{prefix}.{i}.norm2")
+        port_ln(decoder_block.query_layer_norm, f"{prefix}.{i}.norm_q")
+        port_ln(decoder_block.content_layer_norm, f"{prefix}.{i}.norm_c")
+        port_dense(decoder_block.mlp.dense_1, f"{prefix}.{i}.linear1")
+        port_dense(decoder_block.mlp.dense_2, f"{prefix}.{i}.linear2")
+    port_ln(backbone.layers[4].layer_norm, "model.decoder.norm")
+    port_dense(backbone.layers[5], "model.head")
+
+
+def convert_image_converter():
+    # Basic image transformations, following the reference implementation.
+    # Ref: https://github.com/baudm/parseq/blob/1902db043c029a7e03a3818c616c06600af574be/strhub/data/module.py#L77 # noqa: E501
+    mean, std = 0.5, 0.5
+    return PARSeqImageConverter(
+        image_size=(32, 128),
+        offset=-mean / std,
+        scale=1.0 / 255.0 / std,
+        interpolation="bicubic",
+    )
+
+
+def validate_output(preprocessor, keras_model, torch_model):
+    file = keras.utils.get_file(
+        origin="https://upload.wikimedia.org/wikipedia/commons/thumb/2/2f/Google_2015_logo.svg/480px-Google_2015_logo.svg.png",  # noqa: E501
+        fname="google.png",
+    )
+    image = Image.open(file).convert("RGB")
+    images = np.expand_dims(np.array(image).astype("float32"), axis=0)
+
+    x, _, _ = preprocessor({"images": images, "responses": ["Google"]})
+
+    keras_output = keras_model(preprocessor.generate_preprocess(images))
+    tgt_in = torch.full((1, 1), torch_model.tokenizer.bos_id, dtype=torch.long)
+    torch_output = torch_model.model.head(
+        torch_model.model.decode(
+            tgt_in,
+            torch_model.model.encoder(
+                torch.from_numpy(
+                    keras.ops.convert_to_numpy(x["images"]).transpose(
+                        0, 3, 1, 2
+                    )
+                )
+            ),
+        )
+    )
+
+    keras_causal_output = [
+        "".join(output) for output in keras_model.generate(x["images"])
+    ]
+    torch_image_input = torch.from_numpy(
+        keras.ops.convert_to_numpy(x["images"])
+    )
+    torch_logits = torch_model(torch_image_input.permute(0, 3, 1, 2))
+    torch_causal_output, _ = torch_model.tokenizer.decode(torch_logits)
+
+    print("🔶 Keras Logits Output:", keras_output[0, 0, :10])
+    print("🔶 Torch Logits Output:", torch_output[0, 0, :10])
+    print("🔶 Keras Causal Output:", keras_causal_output)
+    print("🔶 Torch Causal Output:", torch_causal_output)
+    assert torch_causal_output[0] == keras_causal_output[0]
+
+
+def main(_):
+    if FLAGS.preset not in PRESET_MAP.keys():
+        raise ValueError(
+            f"Invalid preset {FLAGS.preset}. Must be one "
+            f"of {','.join(PRESET_MAP.keys())}"
+        )
+    preset = FLAGS.preset
+    hf_preset = PRESET_MAP[preset]
+    if os.path.exists(preset):
+        shutil.rmtree(preset)
+    os.makedirs(preset)
+
+    print(f"🏃 Converting {preset}")
+
+    # Load the reference torch model.
+    torch_model = torch.hub.load(hf_preset, preset, pretrained=True).eval()
+    keras_backbone = get_keras_backbone()
+    print("✅ KerasHub backbone loaded.")
+
+    convert_backbone_weights(keras_backbone, torch_model)
+    print("✅ Backbone weights converted.")
+
+    keras_image_converter = convert_image_converter()
+    keras_tokenizer = PARSeqTokenizer(max_label_length=25)
+
+    parseq_preprocessor = PARSeqCausalLMPreprocessor(
+        image_converter=keras_image_converter, tokenizer=keras_tokenizer
+    )
+
+    print("✅ Loaded preprocessor configuration.")
+
+    keras_model = PARSeqCausalLM(
+        preprocessor=parseq_preprocessor, backbone=keras_backbone
+    )
+
+    validate_output(parseq_preprocessor, keras_model, torch_model)
+    print("✅ Outputs Validated.")
+
+    # Save the converted preset locally before optionally uploading it.
+    keras_model.save_to_preset(f"./{preset}")
+    print(f"🏁 Preset saved to ./{preset}.")
+
+    upload_uri = FLAGS.upload_uri
+    if upload_uri:
+        keras_hub.upload_preset(uri=upload_uri, preset=f"./{preset}")
+        print(f"🏁 Preset uploaded to {upload_uri}")
+
+
+if __name__ == "__main__":
+    app.run(main)
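
A minimal usage sketch of the converted preset (hypothetical, not part of the diff): it assumes the `./parseq_vit` directory written by this script, loads it through the standard KerasHub `from_preset` API, and mirrors the `generate` call used in `validate_output` above; preprocessing of the input pixels is left to the preprocessor attached in this PR.

import numpy as np

import keras_hub

# Load the converted preset from the local directory written by the script.
parseq = keras_hub.models.PARSeqCausalLM.from_preset("./parseq_vit")

# Dummy batch of word-image crops; real inputs are RGB crops of text regions.
# The (32, 128, 3) shape matches the image converter configured above.
images = np.random.uniform(0.0, 255.0, size=(1, 32, 128, 3)).astype("float32")

# As in `validate_output`, `generate` yields per-sample character sequences
# that can be joined into strings.
texts = ["".join(chars) for chars in parseq.generate(images)]
print(texts)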