
Commit a9ce042

Add HGNetV2 to KerasHub (#2293)

* init: Add initial project structure and files
* bug: Small bug related to weight loading in the conversion script
* finalizing: Add TIMM preprocessing layer
* incorporate reviews: Consolidate stage configurations and improve API consistency
* bug: Unexpected argument error in JAX with Keras 3.5
* small addition for the D-FINE to come: No changes to the existing HGNetV2
* D-FINE JIT compile: Remove non-essential conditional statement
* refactor: Address reviews and fix some nits

1 parent 2ad1406 · commit a9ce042

13 files changed (+2284, −0 lines)
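The commit exposes HGNetV2 through the standard KerasHub preset API. A minimal usage sketch of the new backbone export; the preset handle is the one referenced in the backbone docstring below, and its availability as a hosted preset is an assumption:

```python
import numpy as np
import keras_hub

# Load the pretrained HGNetV2 backbone by preset name (handle taken from
# the backbone docstring; treat its availability as an assumption).
backbone = keras_hub.models.HGNetV2Backbone.from_preset(
    "hgnetv2_b5_ssld_stage2_ft_in1k"
)
# The backbone returns a dict of per-stage feature maps.
features = backbone(np.ones((1, 224, 224, 3), dtype="float32"))
```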

keras_hub/api/layers/__init__.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -90,6 +90,9 @@
 from keras_hub.src.models.gemma3.gemma3_image_converter import (
     Gemma3ImageConverter as Gemma3ImageConverter,
 )
+from keras_hub.src.models.hgnetv2.hgnetv2_image_converter import (
+    HGNetV2ImageConverter as HGNetV2ImageConverter,
+)
 from keras_hub.src.models.mit.mit_image_converter import (
     MiTImageConverter as MiTImageConverter,
 )
```
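The new converter is the "TIMM preprocessing layer" from the commit message and follows the usual `keras_hub.layers.ImageConverter` pattern. A hedged sketch, assuming a converter config ships with the same preset:

```python
import numpy as np
import keras_hub

# Hypothetical usage: resize/rescale raw images to the preset's expected
# input format before calling the backbone.
converter = keras_hub.layers.HGNetV2ImageConverter.from_preset(
    "hgnetv2_b5_ssld_stage2_ft_in1k"
)
raw_images = np.random.uniform(0, 255, size=(2, 480, 640, 3)).astype("float32")
batch = converter(raw_images)
```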

keras_hub/api/models/__init__.py

Lines changed: 9 additions & 0 deletions

```diff
@@ -294,6 +294,15 @@
 from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import (
     GPTNeoXTokenizer as GPTNeoXTokenizer,
 )
+from keras_hub.src.models.hgnetv2.hgnetv2_backbone import (
+    HGNetV2Backbone as HGNetV2Backbone,
+)
+from keras_hub.src.models.hgnetv2.hgnetv2_image_classifier import (
+    HGNetV2ImageClassifier as HGNetV2ImageClassifier,
+)
+from keras_hub.src.models.hgnetv2.hgnetv2_image_classifier_preprocessor import (
+    HGNetV2ImageClassifierPreprocessor as HGNetV2ImageClassifierPreprocessor,
+)
 from keras_hub.src.models.image_classifier import (
     ImageClassifier as ImageClassifier,
 )
```
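These three `keras_hub.models` exports cover the backbone plus the image-classification task. A sketch of the task-level flow, assuming the standard KerasHub `ImageClassifier` API (the preset handle is the same assumption as above):

```python
import numpy as np
import keras_hub

# Task-level usage: the attached preprocessor handles image conversion,
# and the classifier wraps the backbone with a prediction head.
classifier = keras_hub.models.HGNetV2ImageClassifier.from_preset(
    "hgnetv2_b5_ssld_stage2_ft_in1k"
)
images = np.random.uniform(0, 255, size=(2, 224, 224, 3)).astype("float32")
logits = classifier.predict(images)
```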

keras_hub/src/models/hgnetv2/__init__.py

Whitespace-only changes.
keras_hub/src/models/hgnetv2/hgnetv2_backbone.py

Lines changed: 193 additions & 0 deletions

````python
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.hgnetv2.hgnetv2_encoder import HGNetV2Encoder
from keras_hub.src.models.hgnetv2.hgnetv2_layers import HGNetV2Embeddings
from keras_hub.src.utils.keras_utils import standardize_data_format


@keras_hub_export("keras_hub.models.HGNetV2Backbone")
class HGNetV2Backbone(Backbone):
    """This class represents a Keras Backbone of the HGNetV2 model.

    This class implements an HGNetV2 backbone architecture, a convolutional
    neural network (CNN) optimized for GPU efficiency. HGNetV2 is frequently
    used as a lightweight CNN backbone in object detection pipelines like
    RT-DETR and YOLO variants, delivering strong performance on classification
    and detection tasks, with speed-ups and accuracy gains compared to larger
    CNN backbones.

    Args:
        depths: list of ints, the number of blocks in each stage.
        embedding_size: int, the size of the embedding layer.
        hidden_sizes: list of ints, the sizes of the hidden layers.
        stem_channels: list of ints, the channels for the stem part.
        hidden_act: str, the activation function for hidden layers.
        use_learnable_affine_block: bool, whether to use learnable affine
            transformations.
        stackwise_stage_filters: list of tuples, where each tuple contains
            configuration for a stage: (stage_in_channels, stage_mid_channels,
            stage_out_channels, stage_num_blocks, stage_num_of_layers,
            stage_kernel_size).
            - stage_in_channels: int, input channels for the stage
            - stage_mid_channels: int, middle channels for the stage
            - stage_out_channels: int, output channels for the stage
            - stage_num_blocks: int, number of blocks in the stage
            - stage_num_of_layers: int, number of layers in each block
            - stage_kernel_size: int, kernel size for the stage
        apply_downsample: list of bools, whether to downsample in each stage.
        use_lightweight_conv_block: list of bools, whether to use HGNetV2
            lightweight convolutional blocks in each stage.
        image_shape: tuple, the shape of the input image without the batch
            size. Defaults to `(None, None, 3)`.
        data_format: `None` or str, the data format ('channels_last' or
            'channels_first'). If not specified, defaults to the
            `image_data_format` value in your Keras config.
        out_features: list of str or `None`, the names of the output features
            to return. If `None`, returns all available features from all
            stages. Defaults to `None`.
        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`, the data
            type for computations and weights.

    Examples:
    ```python
    import numpy as np
    import keras_hub
    from keras_hub.src.models.hgnetv2.hgnetv2_backbone import HGNetV2Backbone

    input_data = np.ones(shape=(8, 224, 224, 3))

    # Pretrained backbone.
    model = keras_hub.models.HGNetV2Backbone.from_preset(
        "hgnetv2_b5_ssld_stage2_ft_in1k"
    )
    model(input_data)

    # Randomly initialized backbone with a custom config.
    model = HGNetV2Backbone(
        depths=[1, 2, 4],
        embedding_size=32,
        hidden_sizes=[64, 128, 256],
        stem_channels=[3, 16, 32],
        hidden_act="relu",
        use_learnable_affine_block=False,
        stackwise_stage_filters=[
            (32, 16, 64, 1, 1, 3),  # Stage 0
            (64, 32, 128, 2, 1, 3),  # Stage 1
            (128, 64, 256, 4, 1, 3),  # Stage 2
        ],
        apply_downsample=[False, True, True],
        use_lightweight_conv_block=[False, False, False],
        image_shape=(224, 224, 3),
    )
    model(input_data)
    ```
    """

    def __init__(
        self,
        depths,
        embedding_size,
        hidden_sizes,
        stem_channels,
        hidden_act,
        use_learnable_affine_block,
        stackwise_stage_filters,
        apply_downsample,
        use_lightweight_conv_block,
        image_shape=(None, None, 3),
        data_format=None,
        out_features=None,
        dtype=None,
        **kwargs,
    ):
        name = kwargs.get("name", None)
        data_format = standardize_data_format(data_format)
        channel_axis = -1 if data_format == "channels_last" else 1
        self.image_shape = image_shape
        (
            stage_in_channels,
            stage_mid_channels,
            stage_out_filters,
            stage_num_blocks,
            stage_num_of_layers,
            stage_kernel_size,
        ) = zip(*stackwise_stage_filters)

        # === Layers ===
        self.embedder_layer = HGNetV2Embeddings(
            stem_channels=stem_channels,
            hidden_act=hidden_act,
            use_learnable_affine_block=use_learnable_affine_block,
            data_format=data_format,
            channel_axis=channel_axis,
            name=f"{name}_embedder" if name else "embedder",
            dtype=dtype,
        )
        self.encoder_layer = HGNetV2Encoder(
            stage_in_channels=stage_in_channels,
            stage_mid_channels=stage_mid_channels,
            stage_out_channels=stage_out_filters,
            stage_num_blocks=stage_num_blocks,
            stage_num_of_layers=stage_num_of_layers,
            apply_downsample=apply_downsample,
            use_lightweight_conv_block=use_lightweight_conv_block,
            stage_kernel_size=stage_kernel_size,
            use_learnable_affine_block=use_learnable_affine_block,
            data_format=data_format,
            channel_axis=channel_axis,
            name=f"{name}_encoder" if name else "encoder",
            dtype=dtype,
        )
        self.stage_names = ["stem"] + [
            f"stage{i + 1}" for i in range(len(stackwise_stage_filters))
        ]
        self.out_features = (
            out_features if out_features is not None else self.stage_names
        )

        # === Functional Model ===
        pixel_values = keras.layers.Input(
            shape=image_shape, name="pixel_values_input"
        )
        embedding_output = self.embedder_layer(pixel_values)
        all_encoder_hidden_states_tuple = self.encoder_layer(embedding_output)
        feature_maps_output = {
            stage_name: all_encoder_hidden_states_tuple[idx]
            for idx, stage_name in enumerate(self.stage_names)
            if stage_name in self.out_features
        }
        super().__init__(
            inputs=pixel_values, outputs=feature_maps_output, **kwargs
        )

        # === Config ===
        self.depths = depths
        self.embedding_size = embedding_size
        self.hidden_sizes = hidden_sizes
        self.stem_channels = stem_channels
        self.hidden_act = hidden_act
        self.use_learnable_affine_block = use_learnable_affine_block
        self.stackwise_stage_filters = stackwise_stage_filters
        self.apply_downsample = apply_downsample
        self.use_lightweight_conv_block = use_lightweight_conv_block
        self.data_format = data_format

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "depths": self.depths,
                "embedding_size": self.embedding_size,
                "hidden_sizes": self.hidden_sizes,
                "stem_channels": self.stem_channels,
                "hidden_act": self.hidden_act,
                "use_learnable_affine_block": self.use_learnable_affine_block,
                "stackwise_stage_filters": self.stackwise_stage_filters,
                "apply_downsample": self.apply_downsample,
                "use_lightweight_conv_block": self.use_lightweight_conv_block,
                "image_shape": self.image_shape,
                "out_features": self.out_features,
                "data_format": self.data_format,
            }
        )
        return config
````
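One detail worth calling out from the functional-model section above: the backbone's output is a dict keyed by stage name, filtered by `out_features`. A short sketch with a custom two-stage config (values reused from the docstring example):

```python
import numpy as np

from keras_hub.src.models.hgnetv2.hgnetv2_backbone import HGNetV2Backbone

backbone = HGNetV2Backbone(
    depths=[1, 2],
    embedding_size=32,
    hidden_sizes=[64, 128],
    stem_channels=[3, 16, 32],
    hidden_act="relu",
    use_learnable_affine_block=False,
    stackwise_stage_filters=[
        (32, 16, 64, 1, 1, 3),
        (64, 32, 128, 2, 1, 3),
    ],
    apply_downsample=[False, True],
    use_lightweight_conv_block=[False, False],
    image_shape=(224, 224, 3),
    out_features=["stage1", "stage2"],  # drop the stem feature map
)
outputs = backbone(np.ones((1, 224, 224, 3), dtype="float32"))
print(sorted(outputs.keys()))  # ['stage1', 'stage2']
```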
keras_hub/src/models/hgnetv2/hgnetv2_backbone_test.py

Lines changed: 133 additions & 0 deletions

```python
import keras
import numpy as np
import pytest
from absl.testing import parameterized

from keras_hub.src.models.hgnetv2.hgnetv2_backbone import HGNetV2Backbone
from keras_hub.src.tests.test_case import TestCase


class HGNetV2BackboneTest(TestCase):
    def setUp(self):
        self.default_input_shape = (64, 64, 3)
        self.num_channels = self.default_input_shape[-1]
        self.batch_size = 2
        self.stem_channels = [self.num_channels, 16, 32]
        self.default_stage_out_filters = [64, 128]
        self.default_num_stages = 2
        self.stackwise_stage_filters = [
            [32, 16, 64, 1, 1, 3],
            [64, 32, 128, 1, 1, 3],
        ]
        self.init_kwargs = {
            "embedding_size": self.stem_channels[-1],
            "stem_channels": self.stem_channels,
            "hidden_act": "relu",
            "use_learnable_affine_block": False,
            "image_shape": self.default_input_shape,
            "depths": [1] * self.default_num_stages,
            "hidden_sizes": [
                stage[2] for stage in self.stackwise_stage_filters
            ],
            "stackwise_stage_filters": self.stackwise_stage_filters,
            "apply_downsample": [False, True],
            "use_lightweight_conv_block": [False, False],
            # Explicitly pass the out_features arg to ensure comprehensive
            # test coverage for D-FINE.
            "out_features": ["stem", "stage1", "stage2"],
        }
        self.input_data = keras.ops.convert_to_tensor(
            np.random.rand(self.batch_size, *self.default_input_shape).astype(
                np.float32
            )
        )

    @parameterized.named_parameters(
        (
            "default",
            [False, True],
            [False, False],
            2,
            {
                "stem": (2, 16, 16, 32),
                "stage1": (2, 16, 16, 64),
                "stage2": (2, 8, 8, 128),
            },
        ),
        (
            "early_downsample_light_blocks",
            [True, True],
            [True, True],
            2,
            {
                "stem": (2, 16, 16, 32),
                "stage1": (2, 8, 8, 64),
                "stage2": (2, 4, 4, 128),
            },
        ),
        (
            "single_stage_no_downsample",
            [False],
            [False],
            1,
            {
                "stem": (2, 16, 16, 32),
                "stage1": (2, 16, 16, 64),
            },
        ),
        (
            "all_no_downsample",
            [False, False],
            [False, False],
            2,
            {
                "stem": (2, 16, 16, 32),
                "stage1": (2, 16, 16, 64),
                "stage2": (2, 16, 16, 128),
            },
        ),
    )
    def test_backbone_basics(
        self,
        apply_downsample,
        use_lightweight_conv_block,
        num_stages,
        expected_shapes,
    ):
        test_filters = self.stackwise_stage_filters[:num_stages]
        hidden_sizes = [stage[2] for stage in test_filters]
        test_kwargs = {
            **self.init_kwargs,
            "depths": [1] * num_stages,
            "hidden_sizes": hidden_sizes,
            "stackwise_stage_filters": test_filters,
            "apply_downsample": apply_downsample,
            "use_lightweight_conv_block": use_lightweight_conv_block,
            "out_features": ["stem"]
            + [f"stage{i + 1}" for i in range(num_stages)],
        }
        self.run_vision_backbone_test(
            cls=HGNetV2Backbone,
            init_kwargs=test_kwargs,
            input_data=self.input_data,
            expected_output_shape=expected_shapes,
            run_mixed_precision_check=False,
            run_data_format_check=False,
        )

    @pytest.mark.large
    def test_saved_model(self):
        self.run_model_saving_test(
            cls=HGNetV2Backbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
        )

    @pytest.mark.extra_large
    def test_all_presets(self):
        for preset in HGNetV2Backbone.presets:
            self.run_preset_test(
                cls=HGNetV2Backbone,
                preset=preset,
                input_data=self.input_data,
            )
```
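The expected shapes in the parameterized cases follow a simple rule: on the 64x64 test inputs the stem reduces spatial dims by 4x (to 16x16), and each stage with `apply_downsample=True` halves them again. A quick sketch of that arithmetic, derived from the assertions above (an illustration, not part of the test suite):

```python
def expected_spatial_sizes(input_size, apply_downsample, stem_stride=4):
    """Reproduce the spatial dims asserted in the test cases above."""
    size = input_size // stem_stride  # stem: 64 -> 16
    sizes = []
    for downsample in apply_downsample:
        if downsample:
            size //= 2  # each downsampling stage halves H and W
        sizes.append(size)
    return sizes

print(expected_spatial_sizes(64, [False, True]))  # [16, 8]
print(expected_spatial_sizes(64, [True, True]))   # [8, 4]
```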
