Skip to content

Commit d625954

Browse files
authored
PARSeq Model (#2089)
* Base for parseq model * make it vit compatiable with diff height and width sizes * correct vit conv scripts * make class token optional in backbone by default its included * add flags to adjust vit network * add test case for without class_token * decoder file * parseq tokenizer base * add api for parseq tokenizer * Add missing arg max_label_length. * nit * add missing normalization step using tf_text * add missing config for preprocessor * add default start, pad and end tokens * nit * correct special token order * return padding mask as well * use proper keras ops * nit * add decoder for parseq * Build unbuilt layers for model validation * fix forward pass and decoder * add missing mlp forward pass * add generate prprocess and generate step * nit * add generate_step to parseq causal lm * minor fixes for jax backend and config fix * update decoder layer with caching mechanism which is used for generate step * modify generate step including cache * re structure code to make jax backend compatiable * add postprocess step into preprocessor * test only forward pass * nit * test build cache * test generate step only build cache * correct class name * correct dropout * remove slicing in forward pass * nit * use python style slicing * support jax for generate step * compute attention mask for permutation at decoder block level * correct syntax error * nit * Add method for geenrating attention masks during train & permutations method * update end token after 2 perms * minor bug fix * add save assets and load assets methods * fix conflict issue * nit * fix minor issues while loading preset * fix jax dynamic shape issues * try to fix jax backend concretization error * fix mask broadcast error * fix repeat for mismatch output length * ignore permutation based training * fix dtype and add test case for parseq * fix input format and add causal lm testing * use numpy random images * fix jax backend issue when reduction set to "mean_with_sample_weight" * remove redudant classes and use causal lm base calsses itself. * nit * fix decoder_head_dim usage * fix preprocessing issues * add checkpoint convertion script * add missing flag * validate convertion outputs * nit * fix training for permutation logic * add example usage for backbone and causal lm * nit * fix minor issues * use default params from args as we preprocessor can be None * during pre compile self variables not available * nit * nit
1 parent 3bfa89f commit d625954

15 files changed

+2106
-1
lines changed

keras_hub/api/layers/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,9 @@
108108
from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import (
109109
PaliGemmaImageConverter as PaliGemmaImageConverter,
110110
)
111+
from keras_hub.src.models.parseq.parseq_image_converter import (
112+
PARSeqImageConverter as PARSeqImageConverter,
113+
)
111114
from keras_hub.src.models.resnet.resnet_image_converter import (
112115
ResNetImageConverter as ResNetImageConverter,
113116
)

keras_hub/api/models/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,18 @@
446446
from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import (
447447
PaliGemmaTokenizer as PaliGemmaTokenizer,
448448
)
449+
from keras_hub.src.models.parseq.parseq_backbone import (
450+
PARSeqBackbone as PARSeqBackbone,
451+
)
452+
from keras_hub.src.models.parseq.parseq_causal_lm import (
453+
PARSeqCausalLM as PARSeqCausalLM,
454+
)
455+
from keras_hub.src.models.parseq.parseq_causal_lm_preprocessor import (
456+
PARSeqCausalLMPreprocessor as PARSeqCausalLMPreprocessor,
457+
)
458+
from keras_hub.src.models.parseq.parseq_tokenizer import (
459+
PARSeqTokenizer as PARSeqTokenizer,
460+
)
449461
from keras_hub.src.models.phi3.phi3_backbone import Phi3Backbone as Phi3Backbone
450462
from keras_hub.src.models.phi3.phi3_causal_lm import (
451463
Phi3CausalLM as Phi3CausalLM,

keras_hub/api/tokenizers/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@
6666
from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import (
6767
PaliGemmaTokenizer as PaliGemmaTokenizer,
6868
)
69+
from keras_hub.src.models.parseq.parseq_tokenizer import (
70+
PARSeqTokenizer as PARSeqTokenizer,
71+
)
6972
from keras_hub.src.models.phi3.phi3_tokenizer import (
7073
Phi3Tokenizer as Phi3Tokenizer,
7174
)

keras_hub/src/models/parseq/__init__.py

Whitespace-only changes.
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
import keras
2+
3+
from keras_hub.src.api_export import keras_hub_export
4+
from keras_hub.src.models.backbone import Backbone
5+
from keras_hub.src.models.parseq.parseq_decoder import PARSeqDecoder
6+
7+
8+
@keras_hub_export("keras_hub.models.PARSeqBackbone")
9+
class PARSeqBackbone(Backbone):
10+
"""Scene Text Detection with PARSeq.
11+
12+
Performs OCR in natural scenes using the PARSeq model described in [Scene
13+
Text Recognition with Permuted Autoregressive Sequence Models](
14+
https://arxiv.org/abs/2207.06966). PARSeq is a ViT-based model that allows
15+
iterative decoding by performing an autoregressive decoding phase, followed
16+
by a refinement phase.
17+
18+
Args:
19+
image_encoder: keras.Model. The image encoder model.
20+
vocabulary_size: int. The size of the vocabulary.
21+
max_label_length: int. The maximum length of the label sequence.
22+
decoder_hidden_dim: int. The dimension of the decoder hidden layers.
23+
num_decoder_layers: int. The number of decoder layers.
24+
num_decoder_heads: int. The number of attention heads in the decoder.
25+
decoder_mlp_dim: int. The dimension of the decoder MLP hidden layer.
26+
dropout_rate: float. The dropout rate for the decoder network.
27+
Defaults to `0.1`.
28+
attention_dropout: float. The dropout rate for the attention weights.
29+
Defaults to `0.1`.
30+
dtype: str. `None`, str, or `keras.mixed_precision.DTypePolicy`. The
31+
dtype to use for the computations and weights.
32+
**kwargs: Additional keyword arguments passed to the base
33+
`keras.Model` constructor.
34+
"""
35+
36+
def __init__(
37+
self,
38+
image_encoder,
39+
vocabulary_size,
40+
max_label_length,
41+
decoder_hidden_dim,
42+
num_decoder_layers,
43+
num_decoder_heads,
44+
decoder_mlp_dim,
45+
dropout_rate=0.1,
46+
attention_dropout=0.1,
47+
dtype=None,
48+
**kwargs,
49+
):
50+
# === Layers ===
51+
self.image_encoder = image_encoder
52+
self.decoder = PARSeqDecoder(
53+
vocabulary_size=vocabulary_size,
54+
max_label_length=max_label_length,
55+
num_layers=num_decoder_layers,
56+
num_heads=num_decoder_heads,
57+
hidden_dim=decoder_hidden_dim,
58+
mlp_dim=decoder_mlp_dim,
59+
dropout_rate=dropout_rate,
60+
attention_dropout=attention_dropout,
61+
name="decoder",
62+
dtype=dtype,
63+
)
64+
self.head = keras.layers.Dense(
65+
vocabulary_size - 2, # We don't predict <bos> nor <pad>
66+
dtype=dtype,
67+
)
68+
69+
# === Functional Model ===
70+
image_input = self.image_encoder.input
71+
72+
token_id_input = keras.Input(
73+
shape=(None,), dtype="int32", name="token_ids"
74+
)
75+
padding_mask_input = keras.Input(
76+
shape=(None,), dtype="int32", name="padding_mask"
77+
)
78+
79+
memory = self.image_encoder(image_input)
80+
target_out = self.decoder(
81+
token_id_input, memory, padding_mask=padding_mask_input
82+
)
83+
logits = self.head(target_out)
84+
85+
# === Config ===
86+
self.vocabulary_size = vocabulary_size
87+
self.max_label_length = max_label_length
88+
self.decoder_hidden_dim = decoder_hidden_dim
89+
self.num_decoder_layers = num_decoder_layers
90+
self.num_decoder_heads = num_decoder_heads
91+
self.decoder_mlp_dim = decoder_mlp_dim
92+
self.dropout_rate = dropout_rate
93+
self.attention_dropout = attention_dropout
94+
95+
super().__init__(
96+
inputs={
97+
"images": image_input,
98+
"token_ids": token_id_input,
99+
"padding_mask": padding_mask_input,
100+
},
101+
outputs=logits,
102+
dtype=dtype,
103+
**kwargs,
104+
)
105+
106+
def get_config(self):
107+
config = super().get_config()
108+
config.update(
109+
{
110+
"image_encoder": keras.layers.serialize(self.image_encoder),
111+
"vocabulary_size": self.vocabulary_size,
112+
"max_label_length": self.max_label_length,
113+
"decoder_hidden_dim": self.decoder_hidden_dim,
114+
"num_decoder_layers": self.num_decoder_layers,
115+
"num_decoder_heads": self.num_decoder_heads,
116+
"decoder_mlp_dim": self.decoder_mlp_dim,
117+
"dropout_rate": self.dropout_rate,
118+
"attention_dropout": self.attention_dropout,
119+
}
120+
)
121+
122+
return config
123+
124+
@classmethod
125+
def from_config(cls, config):
126+
config.update(
127+
{
128+
"image_encoder": keras.layers.deserialize(
129+
config["image_encoder"]
130+
),
131+
}
132+
)
133+
134+
return super().from_config(config)
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import keras
2+
import pytest
3+
from keras import ops
4+
5+
from keras_hub.src.models.parseq.parseq_backbone import PARSeqBackbone
6+
from keras_hub.src.models.vit.vit_backbone import ViTBackbone
7+
from keras_hub.src.tests.test_case import TestCase
8+
9+
10+
class PARSeqBackboneTest(TestCase):
11+
def setUp(self):
12+
self.batch_size = 2
13+
self.image_height = 32
14+
self.image_width = 128
15+
self.num_channels = 3
16+
17+
# Image Encoder parameters (as per your example)
18+
self.vit_patch_size = (4, 8)
19+
self.vit_num_layers = 2
20+
self.vit_num_heads = 2
21+
self.vit_hidden_dim = 64
22+
self.vit_mlp_dim = self.vit_hidden_dim * 4
23+
24+
# PARSeq Backbone parameters
25+
self.vocabulary_size = 97
26+
self.max_label_length = 25
27+
self.decoder_hidden_dim = self.vit_hidden_dim
28+
self.num_decoder_layers = 1
29+
self.num_decoder_heads = 2
30+
self.decoder_mlp_dim = self.decoder_hidden_dim * 4
31+
32+
# Instantiate the actual ViTBackbone to be used as the image_encoder
33+
self.image_encoder = ViTBackbone(
34+
image_shape=(
35+
self.image_height,
36+
self.image_width,
37+
self.num_channels,
38+
),
39+
patch_size=self.vit_patch_size,
40+
num_layers=self.vit_num_layers,
41+
num_heads=self.vit_num_heads,
42+
hidden_dim=self.vit_hidden_dim,
43+
mlp_dim=self.vit_mlp_dim,
44+
use_class_token=False,
45+
name="image_encoder",
46+
)
47+
48+
self.init_kwargs = {
49+
"image_encoder": self.image_encoder,
50+
"vocabulary_size": self.vocabulary_size,
51+
"max_label_length": self.max_label_length,
52+
"decoder_hidden_dim": self.decoder_hidden_dim,
53+
"num_decoder_layers": self.num_decoder_layers,
54+
"num_decoder_heads": self.num_decoder_heads,
55+
"decoder_mlp_dim": self.decoder_mlp_dim,
56+
"dropout_rate": 0.0,
57+
"attention_dropout": 0.0,
58+
}
59+
60+
# Dummy input data
61+
dummy_images = keras.random.normal(
62+
shape=(
63+
self.batch_size,
64+
self.image_height,
65+
self.image_width,
66+
self.num_channels,
67+
),
68+
)
69+
70+
dummy_token_ids = keras.random.randint(
71+
minval=0,
72+
maxval=self.vocabulary_size,
73+
shape=(self.batch_size, self.max_label_length),
74+
)
75+
dummy_padding_mask = ops.ones(
76+
shape=(self.batch_size, self.max_label_length), dtype="int32"
77+
)
78+
79+
self.input_data = {
80+
"images": dummy_images,
81+
"token_ids": dummy_token_ids,
82+
"padding_mask": dummy_padding_mask,
83+
}
84+
85+
def test_backbone_basics(self):
86+
expected_shape_full = (
87+
self.batch_size,
88+
self.max_label_length,
89+
self.vocabulary_size - 2,
90+
)
91+
92+
self.run_backbone_test(
93+
cls=PARSeqBackbone,
94+
init_kwargs=self.init_kwargs,
95+
input_data=self.input_data,
96+
expected_output_shape=expected_shape_full,
97+
# we have image_encoder as init_kwargs which is also a backbone
98+
run_quantization_check=False,
99+
)
100+
101+
@pytest.mark.large
102+
def test_saved_model(self):
103+
self.run_model_saving_test(
104+
cls=PARSeqBackbone,
105+
init_kwargs=self.init_kwargs,
106+
input_data=self.input_data,
107+
)

0 commit comments

Comments
 (0)