
Commit e59d313

Merge branch 'master' into supporting_gemma_inference_with_ov_backend
Parent: fdcc24c


41 files changed (+3530, −148 lines)

.kokoro/github/ubuntu/gpu/build.sh
Lines changed: 2 additions & 2 deletions

@@ -57,8 +57,8 @@ pip install huggingface_hub
 if [ "${RUN_XLARGE:-0}" == "1" ]
 then
   pytest keras_hub --check_gpu --run_large --run_extra_large \
-    --cov=keras-hub
+    --cov=keras_hub
 else
   pytest keras_hub --check_gpu --run_large \
-    --cov=keras-hub
+    --cov=keras_hub
 fi

keras_hub/api/layers/__init__.py
Lines changed: 3 additions & 0 deletions

@@ -90,6 +90,9 @@
 from keras_hub.src.models.gemma3.gemma3_image_converter import (
     Gemma3ImageConverter as Gemma3ImageConverter,
 )
+from keras_hub.src.models.hgnetv2.hgnetv2_image_converter import (
+    HGNetV2ImageConverter as HGNetV2ImageConverter,
+)
 from keras_hub.src.models.mit.mit_image_converter import (
     MiTImageConverter as MiTImageConverter,
 )

keras_hub/api/models/__init__.py
Lines changed: 12 additions & 0 deletions

@@ -294,6 +294,15 @@
 from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import (
     GPTNeoXTokenizer as GPTNeoXTokenizer,
 )
+from keras_hub.src.models.hgnetv2.hgnetv2_backbone import (
+    HGNetV2Backbone as HGNetV2Backbone,
+)
+from keras_hub.src.models.hgnetv2.hgnetv2_image_classifier import (
+    HGNetV2ImageClassifier as HGNetV2ImageClassifier,
+)
+from keras_hub.src.models.hgnetv2.hgnetv2_image_classifier_preprocessor import (
+    HGNetV2ImageClassifierPreprocessor as HGNetV2ImageClassifierPreprocessor,
+)
 from keras_hub.src.models.image_classifier import (
     ImageClassifier as ImageClassifier,
 )

@@ -454,6 +463,9 @@
 from keras_hub.src.models.qwen3.qwen3_backbone import (
     Qwen3Backbone as Qwen3Backbone,
 )
+from keras_hub.src.models.qwen3.qwen3_causal_lm import (
+    Qwen3CausalLM as Qwen3CausalLM,
+)
 from keras_hub.src.models.qwen3.qwen3_causal_lm_preprocessor import (
     Qwen3CausalLMPreprocessor as Qwen3CausalLMPreprocessor,
 )
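These re-exports are what make the new classes reachable from the public API namespace. A quick sanity check, assuming `keras_hub` is installed from this commit:

```python
import keras_hub

# The newly wired symbols resolve via the public `keras_hub.models` namespace.
print(keras_hub.models.HGNetV2Backbone)
print(keras_hub.models.HGNetV2ImageClassifier)
print(keras_hub.models.HGNetV2ImageClassifierPreprocessor)
print(keras_hub.models.Qwen3CausalLM)
```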

keras_hub/src/layers/modeling/transformer_encoder.py
Lines changed: 6 additions & 3 deletions

@@ -16,9 +16,12 @@ class TransformerEncoder(keras.layers.Layer):
     paper [Attention is All You Need](https://arxiv.org/abs/1706.03762). Users
     can instantiate multiple instances of this class to stack up an encoder.

-    This layer will correctly compute an attention mask from an implicit
-    Keras padding mask (for example, by passing `mask_zero=True` to a
-    `keras.layers.Embedding` layer). See the Masking and Padding
+    This layer will compute an attention mask, prioritizing explicitly provided
+    masks (a `padding_mask` or a custom `attention_mask`) over an implicit Keras
+    padding mask (for example, by passing `mask_zero=True` to a
+    `keras.layers.Embedding` layer). If both a `padding_mask` and an
+    `attention_mask` are provided, they will be combined to determine the final
+    mask. See the Masking and Padding
     [guide](https://keras.io/guides/understanding_masking_and_padding/)
     for more details.
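A minimal sketch of the behavior the revised docstring describes, using the existing `padding_mask` and `attention_mask` call arguments of `keras_hub.layers.TransformerEncoder`; shapes and values below are illustrative:

```python
import numpy as np
import keras
import keras_hub

batch, seq_len, hidden = 2, 8, 32
token_ids = np.random.randint(1, 100, size=(batch, seq_len))
token_ids[:, -2:] = 0  # trailing pad tokens

x = keras.layers.Embedding(100, hidden)(token_ids)
encoder = keras_hub.layers.TransformerEncoder(intermediate_dim=64, num_heads=2)

# An explicit padding mask takes priority over any implicit Keras mask
# (such as one produced by `mask_zero=True` on the embedding layer).
padding_mask = token_ids != 0  # (batch, seq_len), True for real tokens
out = encoder(x, padding_mask=padding_mask)

# If a custom attention mask is also given, the two are combined into the
# final mask, e.g. padding plus a causal (lower-triangular) constraint.
attention_mask = np.tril(np.ones((seq_len, seq_len), dtype=bool))
attention_mask = np.broadcast_to(attention_mask, (batch, seq_len, seq_len))
out = encoder(x, padding_mask=padding_mask, attention_mask=attention_mask)
```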

keras_hub/src/models/gemma/gemma_attention.py
Lines changed: 1 addition & 1 deletion

@@ -152,7 +152,7 @@ def _compute_attention(
         attention_mask = ops.expand_dims(attention_mask, axis=1)
         attention_mask = ops.cast(attention_mask, dtype="bool")
         # Only pass soft cap if needed as not all keras versions support.
-        if self.logit_soft_cap:
+        if self.logit_soft_cap is not None:
             kwargs = {"attn_logits_soft_cap": self.logit_soft_cap}
         else:
             kwargs = {}
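The switch from a truthiness check to `is not None` matters because the attribute distinguishes "no soft cap configured" (`None`) from a numeric cap value: a plain truthiness test would also treat a cap of `0.0` as unset. A standalone illustration (not keras_hub code):

```python
# Standalone illustration of the difference between the two checks.
for logit_soft_cap in (None, 0.0, 50.0):
    old_check = bool(logit_soft_cap)        # truthiness: 0.0 -> False
    new_check = logit_soft_cap is not None  # explicit:   0.0 -> True
    kwargs = {"attn_logits_soft_cap": logit_soft_cap} if new_check else {}
    print(logit_soft_cap, old_check, new_check, kwargs)
```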

keras_hub/src/models/gemma3/gemma3_interleave_embeddings.py
Lines changed: 1 addition & 1 deletion

@@ -65,7 +65,7 @@ def call(self, image_embeddings, text_embeddings, vision_indices):
         to_add = ops.multiply(
             keras.ops.arange(batch_size, dtype="int32"), seq_length
         )
-        to_add = ops.expand_dims(to_add, axis=-1)
+        to_add = ops.cast(ops.expand_dims(to_add, axis=-1), "int32")
         vision_indices = ops.add(vision_indices, to_add)

         # indices should be of shape `(num_updates, 1)`. `num_updates` is
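A standalone sketch of the index arithmetic in this hunk (values illustrative, not the actual layer): the per-batch offsets come from an int32 `arange`, and the added cast keeps them in `"int32"` so they can be added to the int32 `vision_indices` regardless of how a backend promotes the result of the multiply.

```python
from keras import ops

# Offsets each batch's vision indices into a flattened (batch * seq) space.
batch_size, seq_length = 2, 6
vision_indices = ops.convert_to_tensor([[1, 2, 3], [0, 4, 5]], dtype="int32")

to_add = ops.multiply(ops.arange(batch_size, dtype="int32"), seq_length)
# The cast mirrors the change above: keep the offsets in int32 before adding.
to_add = ops.cast(ops.expand_dims(to_add, axis=-1), "int32")  # shape (2, 1)
flat_indices = ops.add(vision_indices, to_add)
print(ops.convert_to_numpy(flat_indices))  # [[ 1  2  3] [ 6 10 11]]
```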
keras_hub/src/models/hgnetv2/__init__.py
Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+from keras_hub.src.models.hgnetv2.hgnetv2_backbone import HGNetV2Backbone
+from keras_hub.src.models.hgnetv2.hgnetv2_presets import backbone_presets
+from keras_hub.src.utils.preset_utils import register_presets
+
+register_presets(backbone_presets, HGNetV2Backbone)
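With the presets registered, the backbone becomes constructible by preset name through the standard `from_preset` flow. The preset identifier below is the one referenced in the new backbone's docstring; actual availability depends on the published preset assets.

```python
import keras_hub

# Builds the architecture and loads weights for the named preset, assuming
# the preset assets are published for this release.
backbone = keras_hub.models.HGNetV2Backbone.from_preset(
    "hgnetv2_b5_ssld_stage2_ft_in1k"
)
```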
keras_hub/src/models/hgnetv2/hgnetv2_backbone.py
Lines changed: 193 additions & 0 deletions

import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.hgnetv2.hgnetv2_encoder import HGNetV2Encoder
from keras_hub.src.models.hgnetv2.hgnetv2_layers import HGNetV2Embeddings
from keras_hub.src.utils.keras_utils import standardize_data_format


@keras_hub_export("keras_hub.models.HGNetV2Backbone")
class HGNetV2Backbone(Backbone):
    """This class represents a Keras Backbone of the HGNetV2 model.

    This class implements an HGNetV2 backbone architecture, a convolutional
    neural network (CNN) optimized for GPU efficiency. HGNetV2 is frequently
    used as a lightweight CNN backbone in object detection pipelines like
    RT-DETR and YOLO variants, delivering strong performance on classification
    and detection tasks, with speed-ups and accuracy gains compared to larger
    CNN backbones.

    Args:
        depths: list of ints, the number of blocks in each stage.
        embedding_size: int, the size of the embedding layer.
        hidden_sizes: list of ints, the sizes of the hidden layers.
        stem_channels: list of ints, the channels for the stem part.
        hidden_act: str, the activation function for hidden layers.
        use_learnable_affine_block: bool, whether to use learnable affine
            transformations.
        stackwise_stage_filters: list of tuples, where each tuple contains
            configuration for a stage: (stage_in_channels, stage_mid_channels,
            stage_out_channels, stage_num_blocks, stage_num_of_layers,
            stage_kernel_size).
            - stage_in_channels: int, input channels for the stage
            - stage_mid_channels: int, middle channels for the stage
            - stage_out_channels: int, output channels for the stage
            - stage_num_blocks: int, number of blocks in the stage
            - stage_num_of_layers: int, number of layers in each block
            - stage_kernel_size: int, kernel size for the stage
        apply_downsample: list of bools, whether to downsample in each stage.
        use_lightweight_conv_block: list of bools, whether to use HGNetV2
            lightweight convolutional blocks in each stage.
        image_shape: tuple, the shape of the input image without the batch
            size. Defaults to `(None, None, 3)`.
        data_format: `None` or str, the data format ('channels_last' or
            'channels_first'). If not specified, defaults to the
            `image_data_format` value in your Keras config.
        out_features: list of str or `None`, the names of the output features
            to return. If `None`, returns all available features from all
            stages. Defaults to `None`.
        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`, the data
            type for computations and weights.

    Examples:
    ```python
    import numpy as np
    from keras_hub.src.models.hgnetv2.hgnetv2_backbone import HGNetV2Backbone
    input_data = np.ones(shape=(8, 224, 224, 3))

    # Pretrained backbone.
    model = keras_hub.models.HGNetV2Backbone.from_preset(
        "hgnetv2_b5_ssld_stage2_ft_in1k"
    )
    model(input_data)

    # Randomly initialized backbone with a custom config.
    model = HGNetV2Backbone(
        depths=[1, 2, 4],
        embedding_size=32,
        hidden_sizes=[64, 128, 256],
        stem_channels=[3, 16, 32],
        hidden_act="relu",
        use_learnable_affine_block=False,
        stackwise_stage_filters=[
            (32, 16, 64, 1, 1, 3),  # Stage 0
            (64, 32, 128, 2, 1, 3),  # Stage 1
            (128, 64, 256, 4, 1, 3),  # Stage 2
        ],
        apply_downsample=[False, True, True],
        use_lightweight_conv_block=[False, False, False],
        image_shape=(224, 224, 3),
    )
    model(input_data)
    ```
    """

    def __init__(
        self,
        depths,
        embedding_size,
        hidden_sizes,
        stem_channels,
        hidden_act,
        use_learnable_affine_block,
        stackwise_stage_filters,
        apply_downsample,
        use_lightweight_conv_block,
        image_shape=(None, None, 3),
        data_format=None,
        out_features=None,
        dtype=None,
        **kwargs,
    ):
        name = kwargs.get("name", None)
        data_format = standardize_data_format(data_format)
        channel_axis = -1 if data_format == "channels_last" else 1
        self.image_shape = image_shape
        (
            stage_in_channels,
            stage_mid_channels,
            stage_out_filters,
            stage_num_blocks,
            stage_num_of_layers,
            stage_kernel_size,
        ) = zip(*stackwise_stage_filters)

        # === Layers ===
        self.embedder_layer = HGNetV2Embeddings(
            stem_channels=stem_channels,
            hidden_act=hidden_act,
            use_learnable_affine_block=use_learnable_affine_block,
            data_format=data_format,
            channel_axis=channel_axis,
            name=f"{name}_embedder" if name else "embedder",
            dtype=dtype,
        )
        self.encoder_layer = HGNetV2Encoder(
            stage_in_channels=stage_in_channels,
            stage_mid_channels=stage_mid_channels,
            stage_out_channels=stage_out_filters,
            stage_num_blocks=stage_num_blocks,
            stage_num_of_layers=stage_num_of_layers,
            apply_downsample=apply_downsample,
            use_lightweight_conv_block=use_lightweight_conv_block,
            stage_kernel_size=stage_kernel_size,
            use_learnable_affine_block=use_learnable_affine_block,
            data_format=data_format,
            channel_axis=channel_axis,
            name=f"{name}_encoder" if name else "encoder",
            dtype=dtype,
        )
        self.stage_names = ["stem"] + [
            f"stage{i + 1}" for i in range(len(stackwise_stage_filters))
        ]
        self.out_features = (
            out_features if out_features is not None else self.stage_names
        )

        # === Functional Model ===
        pixel_values = keras.layers.Input(
            shape=image_shape, name="pixel_values_input"
        )
        embedding_output = self.embedder_layer(pixel_values)
        all_encoder_hidden_states_tuple = self.encoder_layer(embedding_output)
        feature_maps_output = {
            stage_name: all_encoder_hidden_states_tuple[idx]
            for idx, stage_name in enumerate(self.stage_names)
            if stage_name in self.out_features
        }
        super().__init__(
            inputs=pixel_values, outputs=feature_maps_output, **kwargs
        )

        # === Config ===
        self.depths = depths
        self.embedding_size = embedding_size
        self.hidden_sizes = hidden_sizes
        self.stem_channels = stem_channels
        self.hidden_act = hidden_act
        self.use_learnable_affine_block = use_learnable_affine_block
        self.stackwise_stage_filters = stackwise_stage_filters
        self.apply_downsample = apply_downsample
        self.use_lightweight_conv_block = use_lightweight_conv_block
        self.data_format = data_format

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "depths": self.depths,
                "embedding_size": self.embedding_size,
                "hidden_sizes": self.hidden_sizes,
                "stem_channels": self.stem_channels,
                "hidden_act": self.hidden_act,
                "use_learnable_affine_block": self.use_learnable_affine_block,
                "stackwise_stage_filters": self.stackwise_stage_filters,
                "apply_downsample": self.apply_downsample,
                "use_lightweight_conv_block": self.use_lightweight_conv_block,
                "image_shape": self.image_shape,
                "out_features": self.out_features,
                "data_format": self.data_format,
            }
        )
        return config
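Because the functional model's outputs are a dict keyed by stage name, `out_features` selects which feature maps the backbone returns. A small sketch reusing the illustrative configuration from the docstring example, assuming the other HGNetV2 layers from this commit are available:

```python
import numpy as np
from keras_hub.src.models.hgnetv2.hgnetv2_backbone import HGNetV2Backbone

backbone = HGNetV2Backbone(
    depths=[1, 2, 4],
    embedding_size=32,
    hidden_sizes=[64, 128, 256],
    stem_channels=[3, 16, 32],
    hidden_act="relu",
    use_learnable_affine_block=False,
    stackwise_stage_filters=[
        (32, 16, 64, 1, 1, 3),
        (64, 32, 128, 2, 1, 3),
        (128, 64, 256, 4, 1, 3),
    ],
    apply_downsample=[False, True, True],
    use_lightweight_conv_block=[False, False, False],
    image_shape=(224, 224, 3),
    out_features=["stage2", "stage3"],  # "stem" and "stage1" are also valid
)
# Returns only the requested stages, keyed by stage name.
features = backbone(np.ones((1, 224, 224, 3)))
print({name: tuple(f.shape) for name, f in features.items()})
```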
