From e100813bda7b98cd2d16983246a591388c36034d Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Mon, 9 Jun 2025 17:07:36 +0400 Subject: [PATCH 1/8] init: Add initial project structure and files --- keras_hub/api/layers/__init__.py | 3 + keras_hub/api/models/__init__.py | 9 + keras_hub/src/models/hgnetv2/__init__.py | 0 .../src/models/hgnetv2/hgnetv2_backbone.py | 194 ++++ .../models/hgnetv2/hgnetv2_backbone_test.py | 133 +++ .../src/models/hgnetv2/hgnetv2_encoder.py | 147 +++ .../hgnetv2/hgnetv2_image_classifier.py | 127 +++ .../hgnetv2_image_classifier_preprocessor.py | 14 + .../hgnetv2/hgnetv2_image_classifier_test.py | 115 +++ .../models/hgnetv2/hgnetv2_image_converter.py | 8 + .../src/models/hgnetv2/hgnetv2_layers.py | 926 ++++++++++++++++++ .../src/models/hgnetv2/hgnetv2_presets.py | 133 +++ .../convert_hgnetv2_checkpoints.py | 399 ++++++++ 13 files changed, 2208 insertions(+) create mode 100644 keras_hub/src/models/hgnetv2/__init__.py create mode 100644 keras_hub/src/models/hgnetv2/hgnetv2_backbone.py create mode 100644 keras_hub/src/models/hgnetv2/hgnetv2_backbone_test.py create mode 100644 keras_hub/src/models/hgnetv2/hgnetv2_encoder.py create mode 100644 keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py create mode 100644 keras_hub/src/models/hgnetv2/hgnetv2_image_classifier_preprocessor.py create mode 100644 keras_hub/src/models/hgnetv2/hgnetv2_image_classifier_test.py create mode 100644 keras_hub/src/models/hgnetv2/hgnetv2_image_converter.py create mode 100644 keras_hub/src/models/hgnetv2/hgnetv2_layers.py create mode 100644 keras_hub/src/models/hgnetv2/hgnetv2_presets.py create mode 100644 tools/checkpoint_conversion/convert_hgnetv2_checkpoints.py diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index 4536cd7f66..51a1341477 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -90,6 +90,9 @@ from keras_hub.src.models.gemma3.gemma3_image_converter import ( Gemma3ImageConverter as Gemma3ImageConverter, ) +from keras_hub.src.models.hgnetv2.hgnetv2_image_converter import ( + HGNetV2ImageConverter as HGNetV2ImageConverter, +) from keras_hub.src.models.mit.mit_image_converter import ( MiTImageConverter as MiTImageConverter, ) diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index 4abcaf0fbc..378e121569 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -294,6 +294,15 @@ from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import ( GPTNeoXTokenizer as GPTNeoXTokenizer, ) +from keras_hub.src.models.hgnetv2.hgnetv2_backbone import ( + HGNetV2Backbone as HGNetV2Backbone, +) +from keras_hub.src.models.hgnetv2.hgnetv2_image_classifier import ( + HGNetV2ImageClassifier as HGNetV2ImageClassifier, +) +from keras_hub.src.models.hgnetv2.hgnetv2_image_classifier_preprocessor import ( + HGNetV2ImageClassifierPreprocessor as HGNetV2ImageClassifierPreprocessor, +) from keras_hub.src.models.image_classifier import ( ImageClassifier as ImageClassifier, ) diff --git a/keras_hub/src/models/hgnetv2/__init__.py b/keras_hub/src/models/hgnetv2/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_backbone.py b/keras_hub/src/models/hgnetv2/hgnetv2_backbone.py new file mode 100644 index 0000000000..4bf91cda9a --- /dev/null +++ b/keras_hub/src/models/hgnetv2/hgnetv2_backbone.py @@ -0,0 +1,194 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.backbone import 
Backbone +from keras_hub.src.models.hgnetv2.hgnetv2_encoder import HGNetV2Encoder +from keras_hub.src.models.hgnetv2.hgnetv2_layers import HGNetV2Embeddings +from keras_hub.src.utils.keras_utils import standardize_data_format + + +@keras_hub_export("keras_hub.models.HGNetV2Backbone") +class HGNetV2Backbone(Backbone): + """This class represents a Keras Backbone of the HGNetV2 model. + + This class implements an HGNetV2 backbone architecture. + + Args: + initializer_range: float, the range for initializing weights. + depths: list of ints, the number of blocks in each stage. + embedding_size: int, the size of the embedding layer. + hidden_sizes: list of ints, the sizes of the hidden layers. + stem_channels: list of ints, the channels for the stem part. + hidden_act: str, the activation function for hidden layers. + use_learnable_affine_block: bool, whether to use learnable affine + transformations. + num_channels: int, the number of channels in the input image. + stage_in_channels: list of ints, the input channels for each stage. + stage_mid_channels: list of ints, the middle channels for each stage. + stage_out_channels: list of ints, the output channels for each stage. + stage_num_blocks: list of ints, the number of blocks in each stage. + stage_numb_of_layers: list of ints, the number of layers in each block. + stage_downsample: list of bools, whether to downsample in each stage. + stage_light_block: list of bools, whether to use light blocks in each + stage. + stage_kernel_size: list of ints, the kernel sizes for each stage. + image_shape: tuple, the shape of the input image without the batch size. + Defaults to `(None, None, 3)`. + data_format: `None` or str, the data format ('channels_last' or + 'channels_first'). If not specified, defaults to the + `image_data_format` value in your Keras config. + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`, the data + type for computations and weights. + + Examples: + ```python + import numpy as np + from keras_hub.src.models.hgnetv2.hgnetv2_backbone import HGNetV2Backbone + input_data = np.ones(shape=(8, 224, 224, 3)) + + # Pretrained backbone. + model = keras_hub.models.HGNetV2Backbone.from_preset( + "hgnetv2_b5.ssld_stage2_ft_in1k" + ) + model(input_data) + + # Randomly initialized backbone with a custom config. 
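+    # Note on wiring (based on how the stages chain in this example):
+    # `stage_in_channels[0]` matches `stem_channels[-1]`, and each later
+    # stage's input width matches the previous stage's `stage_out_channels`.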
+ model = HGNetV2Backbone( + initializer_range=0.02, + depths=[1, 2, 4], + embedding_size=32, + hidden_sizes=[64, 128, 256], + stem_channels=[3, 16, 32], + hidden_act="relu", + use_learnable_affine_block=False, + num_channels=3, + stage_in_channels=[32, 64, 128], + stage_mid_channels=[16, 32, 64], + stage_out_channels=[64, 128, 256], + stage_num_blocks=[1, 2, 4], + stage_numb_of_layers=[1, 1, 1], + stage_downsample=[False, True, True], + stage_light_block=[False, False, False], + stage_kernel_size=[3, 3, 3], + image_shape=(224, 224, 3), + ) + model(input_data) + ``` + """ + + def __init__( + self, + initializer_range, + depths, + embedding_size, + hidden_sizes, + stem_channels, + hidden_act, + use_learnable_affine_block, + num_channels, + stage_in_channels, + stage_mid_channels, + stage_out_channels, + stage_num_blocks, + stage_numb_of_layers, + stage_downsample, + stage_light_block, + stage_kernel_size, + image_shape=(None, None, 3), + data_format=None, + dtype=None, + **kwargs, + ): + name = kwargs.get("name", None) + data_format = standardize_data_format(data_format) + channel_axis = -1 if data_format == "channels_last" else 1 + self.image_shape = image_shape + + # === Layers === + self.embedder_layer = HGNetV2Embeddings( + stem_channels=stem_channels, + hidden_act=hidden_act, + use_learnable_affine_block=use_learnable_affine_block, + num_channels=num_channels, + data_format=data_format, + channel_axis=channel_axis, + name=f"{name}_embedder" if name else "embedder", + dtype=dtype, + ) + self.encoder_layer = HGNetV2Encoder( + stage_in_channels=stage_in_channels, + stage_mid_channels=stage_mid_channels, + stage_out_channels=stage_out_channels, + stage_num_blocks=stage_num_blocks, + stage_numb_of_layers=stage_numb_of_layers, + stage_downsample=stage_downsample, + stage_light_block=stage_light_block, + stage_kernel_size=stage_kernel_size, + use_learnable_affine_block=use_learnable_affine_block, + data_format=data_format, + channel_axis=channel_axis, + name=f"{name}_encoder" if name else "encoder", + dtype=dtype, + ) + self.stage_names = [f"stage{i}" for i in range(len(stage_in_channels))] + self.out_features = self.stage_names + + # === Functional Model === + pixel_values = keras.layers.Input( + shape=image_shape, name="pixel_values_input" + ) + embedding_output = self.embedder_layer(pixel_values) + all_encoder_hidden_states_tuple = self.encoder_layer(embedding_output) + feature_maps_output = { + stage_name: all_encoder_hidden_states_tuple[idx + 1] + for idx, stage_name in enumerate(self.stage_names) + if stage_name in self.out_features + } + super().__init__( + inputs=pixel_values, outputs=feature_maps_output, **kwargs + ) + + # === Config === + self.initializer_range = initializer_range + self.depths = depths + self.embedding_size = embedding_size + self.hidden_sizes = hidden_sizes + self.stem_channels = stem_channels + self.hidden_act = hidden_act + self.use_learnable_affine_block = use_learnable_affine_block + self.num_channels = num_channels + self.stage_in_channels = stage_in_channels + self.stage_mid_channels = stage_mid_channels + self.stage_out_channels = stage_out_channels + self.stage_num_blocks = stage_num_blocks + self.stage_numb_of_layers = stage_numb_of_layers + self.stage_downsample = stage_downsample + self.stage_light_block = stage_light_block + self.stage_kernel_size = stage_kernel_size + self.data_format = data_format + + def get_config(self): + config = super().get_config() + config.update( + { + "initializer_range": self.initializer_range, + "depths": self.depths, + 
"embedding_size": self.embedding_size, + "hidden_sizes": self.hidden_sizes, + "stem_channels": self.stem_channels, + "hidden_act": self.hidden_act, + "use_learnable_affine_block": self.use_learnable_affine_block, + "num_channels": self.num_channels, + "stage_in_channels": self.stage_in_channels, + "stage_mid_channels": self.stage_mid_channels, + "stage_out_channels": self.stage_out_channels, + "stage_num_blocks": self.stage_num_blocks, + "stage_numb_of_layers": self.stage_numb_of_layers, + "stage_downsample": self.stage_downsample, + "stage_light_block": self.stage_light_block, + "stage_kernel_size": self.stage_kernel_size, + "image_shape": self.image_shape, + "data_format": self.data_format, + } + ) + return config diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_backbone_test.py b/keras_hub/src/models/hgnetv2/hgnetv2_backbone_test.py new file mode 100644 index 0000000000..71889a57c5 --- /dev/null +++ b/keras_hub/src/models/hgnetv2/hgnetv2_backbone_test.py @@ -0,0 +1,133 @@ +import keras +import numpy as np +import pytest +from absl.testing import parameterized + +from keras_hub.src.models.hgnetv2.hgnetv2_backbone import HGNetV2Backbone +from keras_hub.src.tests.test_case import TestCase + + +class HGNetV2BackboneTest(TestCase): + def setUp(self): + self.default_input_shape = (64, 64, 3) + self.num_channels = self.default_input_shape[-1] + self.stem_channels = [self.num_channels, 16, 32] + self.default_stage_in_channels = [self.stem_channels[-1], 64] + self.default_stage_mid_channels = [16, 32] + self.default_stage_out_channels = [64, 128] + self.default_num_stages = len(self.default_stage_in_channels) + + self.init_kwargs = { + "initializer_range": 0.02, + "depths": [1] * self.default_num_stages, + "embedding_size": self.stem_channels[-1], + "hidden_sizes": self.default_stage_out_channels, + "stem_channels": self.stem_channels, + "hidden_act": "relu", + "use_learnable_affine_block": False, + "num_channels": self.num_channels, + "stage_in_channels": self.default_stage_in_channels, + "stage_mid_channels": self.default_stage_mid_channels, + "stage_out_channels": self.default_stage_out_channels, + "stage_num_blocks": [1] * self.default_num_stages, + "stage_numb_of_layers": [1] * self.default_num_stages, + "stage_downsample": [False, True], + "stage_light_block": [False, False], + "stage_kernel_size": [3] * self.default_num_stages, + "image_shape": self.default_input_shape, + } + self.input_size = self.default_input_shape[:2] + self.batch_size = 2 + self.input_data = keras.ops.convert_to_tensor( + np.random.rand(self.batch_size, *self.default_input_shape).astype( + np.float32 + ) + ) + + @parameterized.named_parameters( + ( + "default_config", + [False, True], + [False, False], + 2, + {"stage0": (2, 16, 16, 64), "stage1": (2, 8, 8, 128)}, + ), + ( + "early_downsample_light_blocks", + [True, True], + [True, True], + 2, + {"stage0": (2, 8, 8, 64), "stage1": (2, 4, 4, 128)}, + ), + ( + "single_stage_no_downsample", + [False], + [False], + 1, + {"stage0": (2, 16, 16, 64)}, + ), + ( + "all_no_downsample", + [False, False], + [False, False], + 2, + {"stage0": (2, 16, 16, 64), "stage1": (2, 16, 16, 128)}, + ), + ) + def test_backbone_basics( + self, + stage_downsample_config, + stage_light_block_config, + num_stages, + expected_shapes, + ): + current_init_kwargs = self.init_kwargs.copy() + current_init_kwargs["depths"] = [1] * num_stages + current_init_kwargs["hidden_sizes"] = self.default_stage_out_channels[ + :num_stages + ] + current_init_kwargs["stage_in_channels"] = ( + 
self.default_stage_in_channels[:num_stages] + ) + current_init_kwargs["stage_mid_channels"] = ( + self.default_stage_mid_channels[:num_stages] + ) + current_init_kwargs["stage_out_channels"] = ( + self.default_stage_out_channels[:num_stages] + ) + current_init_kwargs["stage_num_blocks"] = [1] * num_stages + current_init_kwargs["stage_numb_of_layers"] = [1] * num_stages + current_init_kwargs["stage_kernel_size"] = [3] * num_stages + current_init_kwargs["stage_downsample"] = stage_downsample_config + current_init_kwargs["stage_light_block"] = stage_light_block_config + if num_stages > 0: + current_init_kwargs["stage_in_channels"][0] = self.stem_channels[-1] + for i in range(1, num_stages): + current_init_kwargs["stage_in_channels"][i] = ( + current_init_kwargs["stage_out_channels"][i - 1] + ) + self.run_vision_backbone_test( + cls=HGNetV2Backbone, + init_kwargs=current_init_kwargs, + input_data=self.input_data, + expected_output_shape=expected_shapes, + run_mixed_precision_check=False, + run_data_format_check=False, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=HGNetV2Backbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in HGNetV2Backbone.presets: + self.run_preset_test( + cls=HGNetV2Backbone, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_encoder.py b/keras_hub/src/models/hgnetv2/hgnetv2_encoder.py new file mode 100644 index 0000000000..b7108ed87d --- /dev/null +++ b/keras_hub/src/models/hgnetv2/hgnetv2_encoder.py @@ -0,0 +1,147 @@ +import keras + +from keras_hub.src.models.hgnetv2.hgnetv2_layers import HGNetV2Stage + + +@keras.saving.register_keras_serializable(package="keras_hub") +class HGNetV2Encoder(keras.layers.Layer): + """This class represents the encoder of the HGNetV2 model. + + This class implements the encoder part of the HGNetV2 architecture, which + consists of multiple stages. Each stage is an instance of `HGNetV2Stage`, + and the encoder processes the input through these stages sequentially, + collecting the hidden states at each stage. + + Args: + stage_in_channels: A list of integers, specifying the input channels + for each stage. + stage_mid_channels: A list of integers, specifying the mid channels for + each stage. + stage_out_channels: A list of integers, specifying the output channels + for each stage. + stage_num_blocks: A list of integers, specifying the number of blocks + in each stage. + stage_numb_of_layers: A list of integers, specifying the number of + layers in each block of each stage. + stage_downsample: A list of booleans or integers, indicating whether to + downsample in each stage. + stage_light_block: A list of booleans, indicating whether to use light + blocks in each stage. + stage_kernel_size: A list of integers or tuples, specifying the kernel + size for each stage. + use_learnable_affine_block: A boolean, indicating whether to use + learnable affine transformations in the blocks. + data_format: `None` or str. If specified, either `"channels_last"` or + `"channels_first"`. The ordering of the dimensions in the inputs. + `"channels_last"` corresponds to inputs with shape + `(batch_size, height, width, channels)` while `"channels_first"` + corresponds to inputs with shape `(batch_size, channels, height, + width)`. It defaults to the `image_data_format` value found in your + Keras config file at `~/.keras/keras.json`. 
If you never set it, + then it will be `"channels_last"`. + channel_axis: int, the axis that represents the channels. + **kwargs: Additional keyword arguments passed to the parent class. + """ + + def __init__( + self, + stage_in_channels, + stage_mid_channels, + stage_out_channels, + stage_num_blocks, + stage_numb_of_layers, + stage_downsample, + stage_light_block, + stage_kernel_size, + use_learnable_affine_block, + data_format=None, + channel_axis=None, + **kwargs, + ): + super().__init__(**kwargs) + self.stage_in_channels = stage_in_channels + self.stage_mid_channels = stage_mid_channels + self.stage_out_channels = stage_out_channels + self.stage_num_blocks = stage_num_blocks + self.stage_numb_of_layers = stage_numb_of_layers + self.stage_downsample = stage_downsample + self.stage_light_block = stage_light_block + self.stage_kernel_size = stage_kernel_size + self.use_learnable_affine_block = use_learnable_affine_block + self.data_format = data_format + self.channel_axis = channel_axis + + self.stages_list = [] + for stage_idx in range(len(self.stage_in_channels)): + stage_layer = HGNetV2Stage( + stage_in_channels=self.stage_in_channels, + stage_mid_channels=self.stage_mid_channels, + stage_out_channels=self.stage_out_channels, + stage_num_blocks=self.stage_num_blocks, + stage_numb_of_layers=self.stage_numb_of_layers, + stage_downsample=self.stage_downsample, + stage_light_block=self.stage_light_block, + stage_kernel_size=self.stage_kernel_size, + use_learnable_affine_block=self.use_learnable_affine_block, + stage_index=stage_idx, + data_format=self.data_format, + channel_axis=self.channel_axis, + drop_path=0.0, + name=f"{self.name}_stage_{stage_idx}" + if self.name + else f"stage_{stage_idx}", + dtype=self.dtype, + ) + self.stages_list.append(stage_layer) + + def build(self, input_shape): + super().build(input_shape) + current_input_shape = input_shape + for stage_keras_layer in self.stages_list: + stage_keras_layer.build(current_input_shape) + current_input_shape = stage_keras_layer.compute_output_shape( + current_input_shape + ) + + def call( + self, + hidden_state, + training=None, + ): + all_hidden_states_list = [] + current_hidden_state = hidden_state + for stage_keras_layer in self.stages_list: + all_hidden_states_list.append(current_hidden_state) + current_hidden_state = stage_keras_layer( + current_hidden_state, training=training + ) + all_hidden_states_list.append(current_hidden_state) + return tuple(all_hidden_states_list) + + def compute_output_shape(self, input_shape): + current_shape = input_shape + all_hidden_shapes = [input_shape] + for stage_keras_layer in self.stages_list: + current_shape = stage_keras_layer.compute_output_shape( + current_shape + ) + all_hidden_shapes.append(current_shape) + return tuple(all_hidden_shapes) + + def get_config(self): + config = super().get_config() + config.update( + { + "stage_in_channels": self.stage_in_channels, + "stage_mid_channels": self.stage_mid_channels, + "stage_out_channels": self.stage_out_channels, + "stage_num_blocks": self.stage_num_blocks, + "stage_numb_of_layers": self.stage_numb_of_layers, + "stage_downsample": self.stage_downsample, + "stage_light_block": self.stage_light_block, + "stage_kernel_size": self.stage_kernel_size, + "use_learnable_affine_block": self.use_learnable_affine_block, + "data_format": self.data_format, + } + ) + return config diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py b/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py new file mode 100644 index 
0000000000..fb3a12c90e
--- /dev/null
+++ b/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py
@@ -0,0 +1,127 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.hgnetv2.hgnetv2_backbone import HGNetV2Backbone
+from keras_hub.src.models.hgnetv2.hgnetv2_image_classifier_preprocessor import (
+    HGNetV2ImageClassifierPreprocessor,
+)
+from keras_hub.src.models.hgnetv2.hgnetv2_layers import HGNetV2ConvLayer
+from keras_hub.src.models.image_classifier import ImageClassifier
+from keras_hub.src.models.task import Task
+
+
+@keras_hub_export("keras_hub.models.HGNetV2ImageClassifier")
+class HGNetV2ImageClassifier(ImageClassifier):
+    backbone_cls = HGNetV2Backbone
+    preprocessor_cls = HGNetV2ImageClassifierPreprocessor
+
+    def __init__(
+        self,
+        backbone,
+        preprocessor,
+        num_classes,
+        head_filters,
+        pooling="avg",
+        activation=None,
+        dropout=0.0,
+        head_dtype=None,
+        use_learnable_affine_block_head=False,
+        **kwargs,
+    ):
+        name = kwargs.get("name", "hgnetv2_image_classifier")
+        head_dtype = head_dtype or backbone.dtype_policy
+        data_format = getattr(backbone, "data_format", "channels_last")
+        channel_axis = -1 if data_format == "channels_last" else 1
+
+        # NOTE: This isn't in the usual order because the config is needed
+        # before layer initialization and the functional model.
+        # === Config ===
+        self.num_classes = num_classes
+        self.pooling = pooling
+        self.activation = activation
+        self.dropout = dropout
+        self.head_filters = head_filters
+
+        # === Layers ===
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+        self.last_conv = HGNetV2ConvLayer(
+            in_channels=backbone.hidden_sizes[-1],
+            out_channels=self.head_filters,
+            kernel_size=1,
+            stride=1,
+            groups=1,
+            activation="relu",
+            use_learnable_affine_block=use_learnable_affine_block_head,
+            data_format=data_format,
+            channel_axis=channel_axis,
+            name="head_last",
+            dtype=head_dtype,
+        )
+        if self.pooling == "avg":
+            self.pooler = keras.layers.GlobalAveragePooling2D(
+                data_format=data_format,
+                dtype=head_dtype,
+                name=f"{name}_avg_pool" if name else "avg_pool",
+            )
+        elif self.pooling == "max":
+            self.pooler = keras.layers.GlobalMaxPooling2D(
+                data_format=data_format,
+                dtype=head_dtype,
+                name=f"{name}_max_pool" if name else "max_pool",
+            )
+        # Check valid pooling.
+        else:
+            raise ValueError(
+                "Unknown `pooling` type. Pooling should be either `'avg'` or "
+                f"`'max'`. Received: pooling={pooling}."
+ ) + + self.flatten_layer = keras.layers.Flatten( + dtype=head_dtype, + name=f"{name}_flatten" if name else "flatten", + ) + self.output_dropout = keras.layers.Dropout( + rate=self.dropout, + dtype=head_dtype, + name=f"{name}_output_dropout" if name else "output_dropout", + ) + if self.num_classes > 0: + self.output_dense = keras.layers.Dense( + units=num_classes, + activation=activation, + dtype=head_dtype, + name="predictions", + ) + else: + self.output_dense = keras.layers.Identity(name="predictions") + + # === Functional Model === + inputs = backbone.input + feature_maps = backbone(inputs) + last_stage_name = backbone.stage_names[-1] + last_hidden_state_for_pooling = feature_maps[last_stage_name] + x = self.last_conv(last_hidden_state_for_pooling) + x = self.pooler(x) + x = self.flatten_layer(x) + x = self.output_dropout(x) + outputs = self.output_dense(x) + Task.__init__( + self, + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + def get_config(self): + config = Task.get_config(self) + config.update( + { + "num_classes": self.num_classes, + "pooling": self.pooling, + "activation": self.activation, + "dropout": self.dropout, + "head_filters": self.head_filters, + } + ) + return config diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier_preprocessor.py b/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier_preprocessor.py new file mode 100644 index 0000000000..75a47d5b7e --- /dev/null +++ b/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier_preprocessor.py @@ -0,0 +1,14 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.hgnetv2.hgnetv2_backbone import HGNetV2Backbone +from keras_hub.src.models.hgnetv2.hgnetv2_image_converter import ( + HGNetV2ImageConverter, +) +from keras_hub.src.models.image_classifier_preprocessor import ( + ImageClassifierPreprocessor, +) + + +@keras_hub_export("keras_hub.models.HGNetV2ImageClassifierPreprocessor") +class HGNetV2ImageClassifierPreprocessor(ImageClassifierPreprocessor): + backbone_cls = HGNetV2Backbone + image_converter_cls = HGNetV2ImageConverter diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier_test.py b/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier_test.py new file mode 100644 index 0000000000..3e00378246 --- /dev/null +++ b/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier_test.py @@ -0,0 +1,115 @@ +import numpy as np +import pytest + +from keras_hub.src.models.hgnetv2.hgnetv2_backbone import HGNetV2Backbone +from keras_hub.src.models.hgnetv2.hgnetv2_image_classifier import ( + HGNetV2ImageClassifier, +) +from keras_hub.src.models.hgnetv2.hgnetv2_image_classifier_preprocessor import ( + HGNetV2ImageClassifierPreprocessor, +) +from keras_hub.src.models.hgnetv2.hgnetv2_image_converter import ( + HGNetV2ImageConverter, +) +from keras_hub.src.tests.test_case import TestCase + + +class HGNetV2ImageClassifierTest(TestCase): + def setUp(self): + self.batch_size = 2 + self.height = 64 + self.width = 64 + self.num_channels = 3 + self.image_input_shape = (self.height, self.width, self.num_channels) + self.num_classes = 3 + self.images = np.ones( + (self.batch_size, *self.image_input_shape), dtype="float32" + ) + self.labels = np.random.randint(0, self.num_classes, self.batch_size) + num_stages = 2 + # Setup model. 
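+        # The two-stage config below chains the widths together: stage 0
+        # consumes the stem output (`stem_channels[-1]`) and stage 1 consumes
+        # stage 0's output channels, mirroring the wiring enforced in the
+        # backbone test.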
+ stem_channels = [self.num_channels, 16, 32] + stage_in_channels = [stem_channels[-1], 64][:num_stages] + stage_mid_channels = [16, 32][:num_stages] + stage_out_channels = [64, 128][:num_stages] + stage_num_blocks = [1] * num_stages + stage_numb_of_layers = [1] * num_stages + stage_downsample = [False, True][:num_stages] + stage_light_block = [False, False][:num_stages] + stage_kernel_size = [3] * num_stages + + self.backbone = HGNetV2Backbone( + initializer_range=0.02, + depths=stage_num_blocks, + embedding_size=stem_channels[-1], + hidden_sizes=stage_out_channels, + stem_channels=stem_channels, + hidden_act="relu", + use_learnable_affine_block=False, + num_channels=self.num_channels, + stage_in_channels=stage_in_channels, + stage_mid_channels=stage_mid_channels, + stage_out_channels=stage_out_channels, + stage_num_blocks=stage_num_blocks, + stage_numb_of_layers=stage_numb_of_layers, + stage_downsample=stage_downsample, + stage_light_block=stage_light_block, + stage_kernel_size=stage_kernel_size, + image_shape=self.image_input_shape, + ) + self.image_converter = HGNetV2ImageConverter( + height=self.height, width=self.width + ) + self.preprocessor = HGNetV2ImageClassifierPreprocessor( + image_converter=self.image_converter + ) + self.init_kwargs = { + "backbone": self.backbone, + "preprocessor": self.preprocessor, + "num_classes": self.num_classes, + "head_filters": stage_out_channels[-1], + } + self.train_data = ( + self.images, + self.labels, + ) + self.expected_backbone_output_shapes = { + "stage0": (self.batch_size, 16, 16, 64), + "stage1": (self.batch_size, 8, 8, 128), + } + self.preset_image_size = 224 + self.images_for_presets = np.ones( + ( + self.batch_size, + self.preset_image_size, + self.preset_image_size, + self.num_channels, + ), + dtype="float32", + ) + + def test_classifier_basics(self): + self.run_task_test( + cls=HGNetV2ImageClassifier, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(self.batch_size, self.num_classes), + ) + + @pytest.mark.large + def test_all_presets(self): + for preset in HGNetV2ImageClassifier.presets: + self.run_preset_test( + cls=HGNetV2ImageClassifier, + preset=preset, + input_data=self.images_for_presets, + expected_output_shape=(self.batch_size, 1000), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=HGNetV2ImageClassifier, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_image_converter.py b/keras_hub/src/models/hgnetv2/hgnetv2_image_converter.py new file mode 100644 index 0000000000..b815ee0bfd --- /dev/null +++ b/keras_hub/src/models/hgnetv2/hgnetv2_image_converter.py @@ -0,0 +1,8 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter +from keras_hub.src.models.hgnetv2.hgnetv2_backbone import HGNetV2Backbone + + +@keras_hub_export("keras_hub.layers.HGNetV2ImageConverter") +class HGNetV2ImageConverter(ImageConverter): + backbone_cls = HGNetV2Backbone diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_layers.py b/keras_hub/src/models/hgnetv2/hgnetv2_layers.py new file mode 100644 index 0000000000..7e23cb2e1c --- /dev/null +++ b/keras_hub/src/models/hgnetv2/hgnetv2_layers.py @@ -0,0 +1,926 @@ +import keras + + +@keras.saving.register_keras_serializable(package="keras_hub") +class HGNetV2LearnableAffineBlock(keras.layers.Layer): + """ + HGNetV2 learnable affine block. 
+ + Applies a learnable scale and bias to the input tensor, implementing a + simple affine transformation with trainable parameters. + + Args: + scale_value: float, optional. Initial value for the scale parameter. + Defaults to 1.0. + bias_value: float, optional. Initial value for the bias parameter. + Defaults to 0.0. + **kwargs: Additional keyword arguments passed to the parent class. + """ + + def __init__(self, scale_value=1.0, bias_value=0.0, **kwargs): + super().__init__(**kwargs) + self.scale_value = scale_value + self.bias_value = bias_value + + def build(self, input_shape): + self.scale = self.add_weight( + name="scale", + shape=(), + initializer=keras.initializers.Constant(self.scale_value), + trainable=True, + dtype=self.dtype_policy.name + if isinstance(self.dtype_policy, keras.mixed_precision.DTypePolicy) + else self.dtype_policy, + ) + self.bias = self.add_weight( + name="bias", + shape=(), + initializer=keras.initializers.Constant(self.bias_value), + trainable=True, + dtype=self.dtype_policy.name + if isinstance(self.dtype_policy, keras.mixed_precision.DTypePolicy) + else self.dtype_policy, + ) + super().build(input_shape) + + def call(self, hidden_state): + return self.scale * hidden_state + self.bias + + def get_config(self): + config = super().get_config() + config.update( + {"scale_value": self.scale_value, "bias_value": self.bias_value} + ) + return config + + +@keras.saving.register_keras_serializable(package="keras_hub") +class HGNetV2ConvLayer(keras.layers.Layer): + """ + HGNetV2 convolutional layer. + + Performs a 2D convolution followed by batch normalization and an activation + function. Includes zero-padding to maintain spatial dimensions and + optionally applies a learnable affine block. + + Args: + in_channels: int. Number of input channels. + out_channels: int. Number of output channels. + kernel_size: int. Size of the convolutional kernel. + stride: int. Stride of the convolution. + groups: int. Number of groups for group convolution. + activation: string, optional. Activation function to use ('relu', + 'gelu', 'tanh', or None). Defaults to 'relu'. + use_learnable_affine_block: bool, optional. Whether to include a + learnable affine block after activation. Defaults to False. + data_format: string, optional. Data format of the input ('channels_last' + or 'channels_first'). Defaults to None. + channel_axis: int, optional. Axis of the channel dimension. Defaults to + None. + **kwargs: Additional keyword arguments passed to the parent class. 
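+
+    Note: the input is zero-padded by `(kernel_size - 1) // 2` on each side
+    before a "valid" convolution, so odd kernel sizes with `stride=1`
+    preserve the spatial dimensions.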
+ """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + groups, + activation="relu", + use_learnable_affine_block=False, + data_format=None, + channel_axis=None, + **kwargs, + ): + super().__init__(**kwargs) + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.groups = groups + self.activation_name = activation + self.use_learnable_affine_block = use_learnable_affine_block + self.data_format = data_format + self.channel_axis = channel_axis + pad = (self.kernel_size - 1) // 2 + self.padding = keras.layers.ZeroPadding2D( + padding=((pad, pad), (pad, pad)), + data_format=self.data_format, + name=f"{self.name}_pad" if self.name else None, + ) + self.convolution = keras.layers.Conv2D( + filters=self.out_channels, + kernel_size=self.kernel_size, + strides=self.stride, + groups=self.groups, + padding="valid", + use_bias=False, + data_format=self.data_format, + name=f"{self.name}_conv" if self.name else None, + dtype=self.dtype_policy, + ) + self.normalization = keras.layers.BatchNormalization( + axis=self.channel_axis, + epsilon=1e-5, + momentum=0.9, + name=f"{self.name}_bn" if self.name else None, + dtype=self.dtype_policy, + ) + + if self.activation_name == "relu": + self.activation_layer = keras.layers.ReLU( + name=f"{self.name}_relu" if self.name else None, + dtype=self.dtype_policy, + ) + elif self.activation_name == "gelu": + self.activation_layer = keras.layers.Activation( + "gelu", + name=f"{self.name}_gelu" if self.name else None, + dtype=self.dtype_policy, + ) + elif self.activation_name == "tanh": + self.activation_layer = keras.layers.Activation( + "tanh", + name=f"{self.name}_tanh" if self.name else None, + dtype=self.dtype_policy, + ) + elif self.activation_name is None: + self.activation_layer = keras.layers.Identity( + name=f"{self.name}_identity_activation" if self.name else None, + dtype=self.dtype_policy, + ) + else: + raise ValueError(f"Unsupported activation: {self.activation_name}") + + if self.use_learnable_affine_block: + self.lab = HGNetV2LearnableAffineBlock( + name=f"{self.name}_lab" if self.name else None, + dtype=self.dtype_policy, + ) + else: + self.lab = keras.layers.Identity( + name=f"{self.name}_identity_lab" if self.name else None + ) + + def build(self, input_shape): + super().build(input_shape) + self.padding.build(input_shape) + padded_shape = self.padding.compute_output_shape(input_shape) + self.convolution.build(padded_shape) + conv_output_shape = self.convolution.compute_output_shape(padded_shape) + self.normalization.build(conv_output_shape) + self.lab.build(conv_output_shape) + + def call(self, inputs, training=None): + hidden_state = self.padding(inputs) + hidden_state = self.convolution(hidden_state) + hidden_state = self.normalization(hidden_state, training=training) + hidden_state = self.activation_layer(hidden_state) + hidden_state = self.lab(hidden_state) + return hidden_state + + def compute_output_shape(self, input_shape): + padded_shape = self.padding.compute_output_shape(input_shape) + shape = self.convolution.compute_output_shape(padded_shape) + return shape + + def get_config(self): + config = super().get_config() + config.update( + { + "in_channels": self.in_channels, + "out_channels": self.out_channels, + "kernel_size": self.kernel_size, + "stride": self.stride, + "groups": self.groups, + "activation": self.activation_name, + "use_learnable_affine_block": self.use_learnable_affine_block, + "data_format": self.data_format, + } + ) + 
return config + + +@keras.saving.register_keras_serializable(package="keras_hub") +class HGNetV2ConvLayerLight(keras.layers.Layer): + """ + HGNetV2 lightweight convolutional layer. + + Composes two convolutional layers: a 1x1 convolution followed by a depthwise + convolution with the specified kernel size. Optionally includes a learnable + affine block in the second convolution. + + Args: + in_channels: int. Number of input channels. + out_channels: int. Number of output channels. + kernel_size: int. Size of the convolutional kernel for the depthwise + convolution. + use_learnable_affine_block: bool, optional. Whether to include a + learnable affine block in the second convolution. Defaults to False. + data_format: string, optional. Data format of the input. Defaults to + None. + channel_axis: int, optional. Axis of the channel dimension. Defaults to + None. + **kwargs: Additional keyword arguments passed to the parent class. + """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + use_learnable_affine_block=False, + data_format=None, + channel_axis=None, + **kwargs, + ): + super().__init__(**kwargs) + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.use_learnable_affine_block = use_learnable_affine_block + self.data_format = data_format + self.channel_axis = channel_axis + + self.conv1_layer = HGNetV2ConvLayer( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=1, + stride=1, + groups=1, + activation=None, + use_learnable_affine_block=False, + data_format=self.data_format, + channel_axis=self.channel_axis, + name=f"{self.name}_conv1" if self.name else "conv1", + dtype=self.dtype_policy, + ) + self.conv2_layer = HGNetV2ConvLayer( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=self.kernel_size, + stride=1, + groups=self.out_channels, + activation="relu", + use_learnable_affine_block=self.use_learnable_affine_block, + data_format=self.data_format, + channel_axis=self.channel_axis, + name=f"{self.name}_conv2" if self.name else "conv2", + dtype=self.dtype_policy, + ) + + def build(self, input_shape): + super().build(input_shape) + self.conv1_layer.build(input_shape) + conv1_output_shape = self.conv1_layer.compute_output_shape(input_shape) + self.conv2_layer.build(conv1_output_shape) + + def call(self, hidden_state, training=None): + hidden_state = self.conv1_layer(hidden_state, training=training) + hidden_state = self.conv2_layer(hidden_state, training=training) + return hidden_state + + def get_config(self): + config = super().get_config() + config.update( + { + "in_channels": self.in_channels, + "out_channels": self.out_channels, + "kernel_size": self.kernel_size, + "use_learnable_affine_block": self.use_learnable_affine_block, + "data_format": self.data_format, + } + ) + return config + + def compute_output_shape(self, input_shape): + shape = self.conv1_layer.compute_output_shape(input_shape) + shape = self.conv2_layer.compute_output_shape(shape) + return shape + + +@keras.saving.register_keras_serializable(package="keras_hub") +class HGNetV2Embeddings(keras.layers.Layer): + """ + HGNetV2 embedding layer. + + Processes input images through a series of convolutional and pooling + operations to produce feature maps. Includes multiple convolutional layers + with specific configurations, padding, and concatenation. + + Args: + stem_channels: list of int. Channels for the stem layers. + hidden_act: string. 
Activation function to use in the convolutional + layers. + use_learnable_affine_block: bool. Whether to use learnable affine blocks + in the convolutional layers. + num_channels: int. Number of input channels (e.g., 3 for RGB images). + data_format: string, optional. Data format of the input. Defaults to + None. + channel_axis: int, optional. Axis of the channel dimension. Defaults to + None. + **kwargs: Additional keyword arguments passed to the parent class. + """ + + def __init__( + self, + stem_channels, + hidden_act, + use_learnable_affine_block, + num_channels, + data_format=None, + channel_axis=None, + **kwargs, + ): + super().__init__(**kwargs) + self.stem_channels = stem_channels + self.hidden_act = hidden_act + self.use_learnable_affine_block = use_learnable_affine_block + self.num_channels = num_channels + self.data_format = data_format + self.channel_axis = channel_axis + self.stem1_layer = HGNetV2ConvLayer( + in_channels=self.stem_channels[0], + out_channels=self.stem_channels[1], + kernel_size=3, + stride=2, + groups=1, + activation=self.hidden_act, + use_learnable_affine_block=self.use_learnable_affine_block, + data_format=self.data_format, + channel_axis=self.channel_axis, + name=f"{self.name}_stem1" if self.name else "stem1", + dtype=self.dtype_policy, + ) + self.padding1 = keras.layers.ZeroPadding2D( + padding=((0, 1), (0, 1)), + data_format=self.data_format, + name=f"{self.name}_padding1" if self.name else "padding1", + ) + self.stem2a_layer = HGNetV2ConvLayer( + in_channels=self.stem_channels[1], + out_channels=self.stem_channels[1] // 2, + kernel_size=2, + stride=1, + groups=1, + activation=self.hidden_act, + use_learnable_affine_block=self.use_learnable_affine_block, + data_format=self.data_format, + channel_axis=self.channel_axis, + name=f"{self.name}_stem2a" if self.name else "stem2a", + dtype=self.dtype_policy, + ) + self.padding2 = keras.layers.ZeroPadding2D( + padding=((0, 1), (0, 1)), + data_format=self.data_format, + name=f"{self.name}_padding2" if self.name else "padding2", + ) + self.stem2b_layer = HGNetV2ConvLayer( + in_channels=self.stem_channels[1] // 2, + out_channels=self.stem_channels[1], + kernel_size=2, + stride=1, + groups=1, + activation=self.hidden_act, + use_learnable_affine_block=self.use_learnable_affine_block, + data_format=self.data_format, + channel_axis=self.channel_axis, + name=f"{self.name}_stem2b" if self.name else "stem2b", + dtype=self.dtype_policy, + ) + self.pool_layer = keras.layers.MaxPool2D( + pool_size=2, + strides=1, + padding="valid", + data_format=self.data_format, + name=f"{self.name}_pool" if self.name else "pool", + ) + self.concatenate_layer = keras.layers.Concatenate( + axis=self.channel_axis, + name=f"{self.name}_concat" if self.name else "concat", + ) + self.stem3_layer = HGNetV2ConvLayer( + in_channels=self.stem_channels[1] * 2, + out_channels=self.stem_channels[1], + kernel_size=3, + stride=2, + groups=1, + activation=self.hidden_act, + use_learnable_affine_block=self.use_learnable_affine_block, + data_format=self.data_format, + channel_axis=self.channel_axis, + name=f"{self.name}_stem3" if self.name else "stem3", + dtype=self.dtype_policy, + ) + self.stem4_layer = HGNetV2ConvLayer( + in_channels=self.stem_channels[1], + out_channels=self.stem_channels[2], + kernel_size=1, + stride=1, + groups=1, + activation=self.hidden_act, + use_learnable_affine_block=self.use_learnable_affine_block, + data_format=self.data_format, + channel_axis=self.channel_axis, + name=f"{self.name}_stem4" if self.name else "stem4", + 
dtype=self.dtype_policy, + ) + + def build(self, input_shape): + super().build(input_shape) + current_shape = input_shape + self.stem1_layer.build(current_shape) + current_shape = self.stem1_layer.compute_output_shape(current_shape) + padded_shape1 = self.padding1.compute_output_shape(current_shape) + self.stem2a_layer.build(padded_shape1) + shape_after_stem2a = self.stem2a_layer.compute_output_shape( + padded_shape1 + ) + padded_shape2 = self.padding2.compute_output_shape(shape_after_stem2a) + self.stem2b_layer.build(padded_shape2) + shape_after_stem2b = self.stem2b_layer.compute_output_shape( + padded_shape2 + ) + shape_after_pool = self.pool_layer.compute_output_shape(padded_shape1) + concat_input_shapes = [shape_after_pool, shape_after_stem2b] + shape_after_concat = self.concatenate_layer.compute_output_shape( + concat_input_shapes + ) + self.stem3_layer.build(shape_after_concat) + shape_after_stem3 = self.stem3_layer.compute_output_shape( + shape_after_concat + ) + self.stem4_layer.build(shape_after_stem3) + + def compute_output_shape(self, input_shape): + current_shape = self.stem1_layer.compute_output_shape(input_shape) + padded_shape1 = self.padding1.compute_output_shape(current_shape) + shape_after_stem2a = self.stem2a_layer.compute_output_shape( + padded_shape1 + ) + padded_shape2 = self.padding2.compute_output_shape(shape_after_stem2a) + shape_after_stem2b = self.stem2b_layer.compute_output_shape( + padded_shape2 + ) + shape_after_pool = self.pool_layer.compute_output_shape(padded_shape1) + concat_input_shapes = [shape_after_pool, shape_after_stem2b] + shape_after_concat = self.concatenate_layer.compute_output_shape( + concat_input_shapes + ) + shape_after_stem3 = self.stem3_layer.compute_output_shape( + shape_after_concat + ) + final_shape = self.stem4_layer.compute_output_shape(shape_after_stem3) + return final_shape + + def call(self, pixel_values, training=None): + num_channels_check = keras.ops.shape(pixel_values)[self.channel_axis] + if num_channels_check != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values " + "match with the one set in the configuration. Expected " + f"{self.num_channels} but got {num_channels_check}." + ) + embedding = self.stem1_layer(pixel_values, training=training) + embedding_padded_for_2a_and_pool = self.padding1(embedding) + emb_stem_2a = self.stem2a_layer( + embedding_padded_for_2a_and_pool, training=training + ) + emb_stem_2a_padded = self.padding2(emb_stem_2a) + emb_stem_2a_processed = self.stem2b_layer( + emb_stem_2a_padded, training=training + ) + pooled_emb = self.pool_layer(embedding_padded_for_2a_and_pool) + embedding_concatenated = self.concatenate_layer( + [pooled_emb, emb_stem_2a_processed] + ) + embedding_after_stem3 = self.stem3_layer( + embedding_concatenated, training=training + ) + final_embedding = self.stem4_layer( + embedding_after_stem3, training=training + ) + return final_embedding + + def get_config(self): + config = super().get_config() + config.update( + { + "stem_channels": self.stem_channels, + "hidden_act": self.hidden_act, + "use_learnable_affine_block": self.use_learnable_affine_block, + "num_channels": self.num_channels, + "data_format": self.data_format, + } + ) + return config + + +@keras.saving.register_keras_serializable(package="keras_hub") +class HGNetV2BasicLayer(keras.layers.Layer): + """ + HGNetV2 basic layer. + + Consists of multiple convolutional blocks followed by aggregation through + concatenation and convolutional layers. 
Supports residual connections and + drop path for regularization. + + Args: + in_channels: int. Number of input channels. + middle_channels: int. Number of channels in the intermediate + convolutional blocks. + out_channels: int. Number of output channels. + layer_num: int. Number of convolutional blocks in the layer. + kernel_size: int, optional. Kernel size for the convolutional blocks. + Defaults to 3. + residual: bool, optional. Whether to include a residual connection. + Defaults to False. + light_block: bool, optional. Whether to use lightweight convolutional + blocks. Defaults to False. + drop_path: float, optional. Drop path rate for regularization. Defaults + to 0.0. + use_learnable_affine_block: bool, optional. Whether to use learnable + affine blocks in the convolutional blocks. Defaults to False. + data_format: string, optional. Data format of the input. Defaults to + None. + channel_axis: int, optional. Axis of the channel dimension. Defaults to + None. + **kwargs: Additional keyword arguments passed to the parent class. + """ + + def __init__( + self, + in_channels, + middle_channels, + out_channels, + layer_num, + kernel_size=3, + residual=False, + light_block=False, + drop_path=0.0, + use_learnable_affine_block=False, + data_format=None, + channel_axis=None, + **kwargs, + ): + super().__init__(**kwargs) + self.in_channels_arg = in_channels + self.middle_channels = middle_channels + self.out_channels = out_channels + self.layer_num = layer_num + self.kernel_size = kernel_size + self.residual = residual + self.light_block = light_block + self.drop_path_rate = drop_path + self.use_learnable_affine_block = use_learnable_affine_block + self.data_format = data_format + self.channel_axis = channel_axis + + self.layer_list = [] + for i in range(self.layer_num): + block_input_channels = ( + self.in_channels_arg if i == 0 else self.middle_channels + ) + if self.light_block: + block = HGNetV2ConvLayerLight( + in_channels=block_input_channels, + out_channels=self.middle_channels, + kernel_size=self.kernel_size, + use_learnable_affine_block=self.use_learnable_affine_block, + data_format=self.data_format, + channel_axis=self.channel_axis, + name=f"{self.name}_light_block_{i}" + if self.name + else f"light_block_{i}", + dtype=self.dtype_policy, + ) + else: + block = HGNetV2ConvLayer( + in_channels=block_input_channels, + out_channels=self.middle_channels, + kernel_size=self.kernel_size, + stride=1, + groups=1, + activation="relu", + use_learnable_affine_block=self.use_learnable_affine_block, + data_format=self.data_format, + channel_axis=self.channel_axis, + name=f"{self.name}_conv_block_{i}" + if self.name + else f"conv_block_{i}", + dtype=self.dtype_policy, + ) + self.layer_list.append(block) + self.total_channels_for_aggregation = ( + self.in_channels_arg + self.layer_num * self.middle_channels + ) + self.aggregation_squeeze_conv = HGNetV2ConvLayer( + in_channels=self.total_channels_for_aggregation, + out_channels=self.out_channels // 2, + kernel_size=1, + stride=1, + groups=1, + activation="relu", + use_learnable_affine_block=self.use_learnable_affine_block, + data_format=self.data_format, + channel_axis=self.channel_axis, + name=f"{self.name}_agg_squeeze" if self.name else "agg_squeeze", + dtype=self.dtype_policy, + ) + self.aggregation_excitation_conv = HGNetV2ConvLayer( + in_channels=self.out_channels // 2, + out_channels=self.out_channels, + kernel_size=1, + stride=1, + groups=1, + activation="relu", + use_learnable_affine_block=self.use_learnable_affine_block, + 
data_format=self.data_format, + channel_axis=self.channel_axis, + name=f"{self.name}_agg_excite" if self.name else "agg_excite", + dtype=self.dtype_policy, + ) + + if self.drop_path_rate > 0.0: + self.drop_path_layer = keras.layers.Dropout( + self.drop_path_rate, + noise_shape=(None, 1, 1, 1), + name=f"{self.name}_drop_path" if self.name else "drop_path", + ) + else: + self.drop_path_layer = keras.layers.Identity( + name=f"{self.name}_identity_drop_path" + if self.name + else "identity_drop_path" + ) + + self.concatenate_layer = keras.layers.Concatenate( + axis=self.channel_axis, + name=f"{self.name}_concat" if self.name else "concat", + ) + if self.residual: + self.add_layer = keras.layers.Add( + name=f"{self.name}_add_residual" + if self.name + else "add_residual" + ) + + def build(self, input_shape): + super().build(input_shape) + current_block_input_shape = input_shape + output_shapes_for_concat = [input_shape] + for i, layer_block in enumerate(self.layer_list): + layer_block.build(current_block_input_shape) + current_block_output_shape = layer_block.compute_output_shape( + current_block_input_shape + ) + output_shapes_for_concat.append(current_block_output_shape) + current_block_input_shape = current_block_output_shape + concatenated_shape = self.concatenate_layer.compute_output_shape( + output_shapes_for_concat + ) + self.aggregation_squeeze_conv.build(concatenated_shape) + agg_squeeze_output_shape = ( + self.aggregation_squeeze_conv.compute_output_shape( + concatenated_shape + ) + ) + self.aggregation_excitation_conv.build(agg_squeeze_output_shape) + + def compute_output_shape(self, input_shape): + output_tensors_shapes = [input_shape] + current_block_input_shape = input_shape + for layer_block in self.layer_list: + current_block_output_shape = layer_block.compute_output_shape( + current_block_input_shape + ) + output_tensors_shapes.append(current_block_output_shape) + current_block_input_shape = current_block_output_shape + concatenated_features_shape = ( + self.concatenate_layer.compute_output_shape(output_tensors_shapes) + ) + aggregated_features_shape = ( + self.aggregation_squeeze_conv.compute_output_shape( + concatenated_features_shape + ) + ) + final_output_shape = ( + self.aggregation_excitation_conv.compute_output_shape( + aggregated_features_shape + ) + ) + + return final_output_shape + + def call(self, hidden_state, training=None): + identity = hidden_state + output_tensors = [hidden_state] + + current_feature_map = hidden_state + for layer_block in self.layer_list: + current_feature_map = layer_block( + current_feature_map, training=training + ) + output_tensors.append(current_feature_map) + concatenated_features = self.concatenate_layer(output_tensors) + aggregated_features = self.aggregation_squeeze_conv( + concatenated_features, training=training + ) + aggregated_features = self.aggregation_excitation_conv( + aggregated_features, training=training + ) + if self.residual: + dropped_features = self.drop_path_layer( + aggregated_features, training=training + ) + final_output = self.add_layer([dropped_features, identity]) + else: + final_output = aggregated_features + return final_output + + def get_config(self): + config = super().get_config() + config.update( + { + "in_channels": self.in_channels_arg, + "middle_channels": self.middle_channels, + "out_channels": self.out_channels, + "layer_num": self.layer_num, + "kernel_size": self.kernel_size, + "residual": self.residual, + "light_block": self.light_block, + "drop_path": self.drop_path_rate, + 
"use_learnable_affine_block": self.use_learnable_affine_block, + "data_format": self.data_format, + } + ) + return config + + +@keras.saving.register_keras_serializable(package="keras_hub") +class HGNetV2Stage(keras.layers.Layer): + """ + HGNetV2 stage layer. + + Represents a stage in the HGNetV2 model, which may include downsampling + followed by a series of basic layers. Each stage can have different + configurations for the number of blocks, channels, etc. + + Args: + stage_in_channels: list of int. Input channels for each stage. + stage_mid_channels: list of int. Middle channels for each stage. + stage_out_channels: list of int. Output channels for each stage. + stage_num_blocks: list of int. Number of basic layers in each stage. + stage_numb_of_layers: list of int. Number of convolutional blocks in + each basic layer. + stage_downsample: list of bool. Whether to downsample at the beginning + of each stage. + stage_light_block: list of bool. Whether to use lightweight blocks in + each stage. + stage_kernel_size: list of int. Kernel sizes for each stage. + use_learnable_affine_block: bool. Whether to use learnable affine + blocks. + stage_index: int. The index of the current stage. + drop_path: float, optional. Drop path rate. Defaults to 0.0. + data_format: string, optional. Data format of the input. Defaults to + None. + channel_axis: int, optional. Axis of the channel dimension. Defaults to + None. + **kwargs: Additional keyword arguments passed to the parent class. + """ + + def __init__( + self, + stage_in_channels, + stage_mid_channels, + stage_out_channels, + stage_num_blocks, + stage_numb_of_layers, + stage_downsample, + stage_light_block, + stage_kernel_size, + use_learnable_affine_block, + stage_index: int, + drop_path: float = 0.0, + data_format=None, + channel_axis=None, + **kwargs, + ): + super().__init__(**kwargs) + self.stage_in_channels = stage_in_channels + self.stage_mid_channels = stage_mid_channels + self.stage_out_channels = stage_out_channels + self.stage_num_blocks = stage_num_blocks + self.stage_numb_of_layers = stage_numb_of_layers + self.stage_downsample = stage_downsample + self.stage_light_block = stage_light_block + self.stage_kernel_size = stage_kernel_size + self.use_learnable_affine_block = use_learnable_affine_block + self.stage_index = stage_index + self.drop_path = drop_path + self.data_format = data_format + self.channel_axis = channel_axis + self.current_stage_in_channels = stage_in_channels[stage_index] + self.current_stage_mid_channels = stage_mid_channels[stage_index] + self.current_stage_out_channels = stage_out_channels[stage_index] + self.current_stage_num_blocks = stage_num_blocks[stage_index] + self.current_stage_num_layers_per_block = stage_numb_of_layers[ + stage_index + ] + self.current_stage_is_downsample_active = stage_downsample[stage_index] + self.current_stage_is_light_block = stage_light_block[stage_index] + self.current_stage_kernel_size = stage_kernel_size[stage_index] + self.current_stage_use_lab = use_learnable_affine_block + self.current_stage_drop_path = drop_path + if self.current_stage_is_downsample_active: + self.downsample_layer = HGNetV2ConvLayer( + in_channels=self.current_stage_in_channels, + out_channels=self.current_stage_in_channels, + kernel_size=3, + stride=2, + groups=self.current_stage_in_channels, + activation=None, + use_learnable_affine_block=False, + data_format=self.data_format, + channel_axis=self.channel_axis, + name=f"{self.name}_downsample" if self.name else "downsample", + dtype=self.dtype_policy, + ) + 
else: + self.downsample_layer = keras.layers.Identity( + name=f"{self.name}_identity_downsample" + if self.name + else "identity_downsample" + ) + + self.blocks_list = [] + for i in range(self.current_stage_num_blocks): + basic_layer_input_channels = ( + self.current_stage_in_channels + if i == 0 + else self.current_stage_out_channels + ) + + block = HGNetV2BasicLayer( + in_channels=basic_layer_input_channels, + middle_channels=self.current_stage_mid_channels, + out_channels=self.current_stage_out_channels, + layer_num=self.current_stage_num_layers_per_block, + residual=(False if i == 0 else True), + kernel_size=self.current_stage_kernel_size, + light_block=self.current_stage_is_light_block, + drop_path=self.current_stage_drop_path, + use_learnable_affine_block=self.current_stage_use_lab, + data_format=self.data_format, + channel_axis=self.channel_axis, + name=f"{self.name}_block_{i}" if self.name else f"block_{i}", + dtype=self.dtype_policy, + ) + self.blocks_list.append(block) + + def build(self, input_shape): + super().build(input_shape) + current_input_shape = input_shape + self.downsample_layer.build(current_input_shape) + current_input_shape = self.downsample_layer.compute_output_shape( + current_input_shape + ) + + for block_item in self.blocks_list: + block_item.build(current_input_shape) + current_input_shape = block_item.compute_output_shape( + current_input_shape + ) + + def compute_output_shape(self, input_shape): + current_shape = self.downsample_layer.compute_output_shape(input_shape) + for block_item in self.blocks_list: + current_shape = block_item.compute_output_shape(current_shape) + return current_shape + + def call(self, hidden_state, training=None): + hidden_state = self.downsample_layer(hidden_state, training=training) + for block_item in self.blocks_list: + hidden_state = block_item(hidden_state, training=training) + return hidden_state + + def get_config(self): + config = super().get_config() + config.update( + { + "stage_in_channels": self.stage_in_channels, + "stage_mid_channels": self.stage_mid_channels, + "stage_out_channels": self.stage_out_channels, + "stage_num_blocks": self.stage_num_blocks, + "stage_numb_of_layers": self.stage_numb_of_layers, + "stage_downsample": self.stage_downsample, + "stage_light_block": self.stage_light_block, + "stage_kernel_size": self.stage_kernel_size, + "use_learnable_affine_block": self.use_learnable_affine_block, + "stage_index": self.stage_index, + "drop_path": self.drop_path, + "data_format": self.data_format, + } + ) + return config diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_presets.py b/keras_hub/src/models/hgnetv2/hgnetv2_presets.py new file mode 100644 index 0000000000..a74de81c39 --- /dev/null +++ b/keras_hub/src/models/hgnetv2/hgnetv2_presets.py @@ -0,0 +1,133 @@ +# Metadata for loading pretrained model weights. +backbone_presets = { + "hgnetv2_b0.ssld_stage1_in22k_in1k": { + "metadata": { + "description": ( + "HGNetV2 B0 model with 1-stage SSLD training, pre-trained on ImageNet-22K and fine-tuned on ImageNet-1K." # noqa: E501 + ), + "params": 5996550, + "path": "hgnetv2", + }, + "kaggle_handle": "", + }, + "hgnetv2_b0.ssld_stage2_ft_in1k": { + "metadata": { + "description": ( + "HGNetV2 B0 model with 2-stage SSLD training, fine-tuned on ImageNet-1K." 
# noqa: E501 + ), + "params": 5996550, + "path": "hgnetv2", + }, + "kaggle_handle": "", + }, + "hgnetv2_b1.ssld_stage1_in22k_in1k": { + "metadata": { + "description": ( + "HGNetV2 B1 model with 1-stage SSLD training, pre-trained on ImageNet-22K and fine-tuned on ImageNet-1K." # noqa: E501 + ), + "params": 6343158, + "path": "hgnetv2", + }, + "kaggle_handle": "", + }, + "hgnetv2_b1.ssld_stage2_ft_in1k": { + "metadata": { + "description": ( + "HGNetV2 B1 model with 2-stage SSLD training, fine-tuned on ImageNet-1K." # noqa: E501 + ), + "params": 6343158, + "path": "hgnetv2", + }, + "kaggle_handle": "", + }, + "hgnetv2_b2.ssld_stage1_in22k_in1k": { + "metadata": { + "description": ( + "HGNetV2 B2 model with 1-stage SSLD training, pre-trained on ImageNet-22K and fine-tuned on ImageNet-1K." # noqa: E501 + ), + "params": 11221356, + "path": "hgnetv2", + }, + "kaggle_handle": "", + }, + "hgnetv2_b2.ssld_stage2_ft_in1k": { + "metadata": { + "description": ( + "HGNetV2 B2 model with 2-stage SSLD training, fine-tuned on ImageNet-1K." # noqa: E501 + ), + "params": 11221356, + "path": "hgnetv2", + }, + "kaggle_handle": "", + }, + "hgnetv2_b3.ssld_stage1_in22k_in1k": { + "metadata": { + "description": ( + "HGNetV2 B3 model with 1-stage SSLD training, pre-trained on ImageNet-22K and fine-tuned on ImageNet-1K." # noqa: E501 + ), + "params": 16292216, + "path": "hgnetv2", + }, + "kaggle_handle": "", + }, + "hgnetv2_b3.ssld_stage2_ft_in1k": { + "metadata": { + "description": ( + "HGNetV2 B3 model with 2-stage SSLD training, fine-tuned on ImageNet-1K." # noqa: E501 + ), + "params": 16292216, + "path": "hgnetv2", + }, + "kaggle_handle": "", + }, + "hgnetv2_b4.ssld_stage2_ft_in1k": { + "metadata": { + "description": ( + "HGNetV2 B4 model with 2-stage SSLD training, fine-tuned on ImageNet-1K." # noqa: E501 + ), + "params": 19796680, + "path": "hgnetv2", + }, + "kaggle_handle": "", + }, + "hgnetv2_b5.ssld_stage1_in22k_in1k": { + "metadata": { + "description": ( + "HGNetV2 B5 model with 1-stage SSLD training, pre-trained on ImageNet-22K and fine-tuned on ImageNet-1K." # noqa: E501 + ), + "params": 39569064, + "path": "hgnetv2", + }, + "kaggle_handle": "", + }, + "hgnetv2_b5.ssld_stage2_ft_in1k": { + "metadata": { + "description": ( + "HGNetV2 B5 model with 2-stage SSLD training, fine-tuned on ImageNet-1K." # noqa: E501 + ), + "params": 39569064, + "path": "hgnetv2", + }, + "kaggle_handle": "", + }, + "hgnetv2_b6.ssld_stage1_in22k_in1k": { + "metadata": { + "description": ( + "HGNetV2 B6 model with 1-stage SSLD training, pre-trained on ImageNet-22K and fine-tuned on ImageNet-1K." # noqa: E501 + ), + "params": 75256776, + "path": "hgnetv2", + }, + "kaggle_handle": "", + }, + "hgnetv2_b6.ssld_stage2_ft_in1k": { + "metadata": { + "description": ( + "HGNetV2 B6 model with 2-stage SSLD training, fine-tuned on ImageNet-1K." # noqa: E501 + ), + "params": 75256776, + "path": "hgnetv2", + }, + "kaggle_handle": "", + }, +} diff --git a/tools/checkpoint_conversion/convert_hgnetv2_checkpoints.py b/tools/checkpoint_conversion/convert_hgnetv2_checkpoints.py new file mode 100644 index 0000000000..929e82da41 --- /dev/null +++ b/tools/checkpoint_conversion/convert_hgnetv2_checkpoints.py @@ -0,0 +1,399 @@ +"""Convert HGNetV2 checkpoints from Hugging Face to Keras. 
+ +Usage: + export KAGGLE_USERNAME=XXX + export KAGGLE_KEY=XXX + + python tools/checkpoint_conversion/convert_hgnetv2_checkpoints.py +""" + +import json +import os +import shutil + +import keras +import numpy as np +import safetensors.torch +import torch +from absl import app +from absl import flags +from PIL import Image +from timm import create_model + +import keras_hub +from keras_hub.src.models.hgnetv2.hgnetv2_backbone import HGNetV2Backbone + +FLAGS = flags.FLAGS + +PRESET_MAP = { + "hgnetv2_b6.ssld_stage2_ft_in1k": "timm/hgnetv2_b6.ssld_stage2_ft_in1k", + "hgnetv2_b6.ssld_stage1_in22k_in1k": "timm/hgnetv2_b6.ssld_stage1_in22k_in1k", # noqa: E501 + "hgnetv2_b5.ssld_stage2_ft_in1k": "timm/hgnetv2_b5.ssld_stage2_ft_in1k", + "hgnetv2_b5.ssld_stage1_in22k_in1k": "timm/hgnetv2_b5.ssld_stage1_in22k_in1k", # noqa: E501 + "hgnetv2_b4.ssld_stage2_ft_in1k": "timm/hgnetv2_b4.ssld_stage2_ft_in1k", + "hgnetv2_b3.ssld_stage2_ft_in1k": "timm/hgnetv2_b3.ssld_stage2_ft_in1k", + "hgnetv2_b3.ssld_stage1_in22k_in1k": "timm/hgnetv2_b3.ssld_stage1_in22k_in1k", # noqa: E501 + "hgnetv2_b2.ssld_stage2_ft_in1k": "timm/hgnetv2_b2.ssld_stage2_ft_in1k", + "hgnetv2_b2.ssld_stage1_in22k_in1k": "timm/hgnetv2_b2.ssld_stage1_in22k_in1k", # noqa: E501 + "hgnetv2_b1.ssld_stage2_ft_in1k": "timm/hgnetv2_b1.ssld_stage2_ft_in1k", + "hgnetv2_b1.ssld_stage1_in22k_in1k": "timm/hgnetv2_b1.ssld_stage1_in22k_in1k", # noqa: E501 + "hgnetv2_b0.ssld_stage2_ft_in1k": "timm/hgnetv2_b0.ssld_stage2_ft_in1k", + "hgnetv2_b0.ssld_stage1_in22k_in1k": "timm/hgnetv2_b0.ssld_stage1_in22k_in1k", # noqa: E501 +} +LAB_FALSE_PRESETS = [ + "hgnetv2_b6.ssld_stage2_ft_in1k", + "hgnetv2_b6.ssld_stage1_in22k_in1k", + "hgnetv2_b5.ssld_stage2_ft_in1k", + "hgnetv2_b5.ssld_stage1_in22k_in1k", + "hgnetv2_b4.ssld_stage2_ft_in1k", +] + +flags.DEFINE_string( + "upload_uri", + None, + 'Could be "kaggle://keras/hgnetv2/keras/{preset}"', + required=False, +) +HGNETV2_CONFIGS = { + "hgnetv2_b0": { + "stem_channels": [3, 16, 16], + "stage_in_channels": [16, 64, 256, 512], + "stage_mid_channels": [16, 32, 64, 128], + "stage_out_channels": [64, 256, 512, 1024], + "stage_num_blocks": [1, 1, 2, 1], + "stage_numb_of_layers": [3, 3, 3, 3], + "stage_downsample": [False, True, True, True], + "stage_light_block": [False, False, True, True], + "stage_kernel_size": [3, 3, 5, 5], + "embedding_size": 16, + "hidden_sizes": [64, 256, 512, 1024], + "depths": [1, 1, 2, 1], + }, + "hgnetv2_b1": { + "stem_channels": [3, 24, 32], + "stage_in_channels": [32, 64, 256, 512], + "stage_mid_channels": [32, 48, 96, 192], + "stage_out_channels": [64, 256, 512, 1024], + "stage_num_blocks": [1, 1, 2, 1], + "stage_numb_of_layers": [3, 3, 3, 3], + "stage_downsample": [False, True, True, True], + "stage_light_block": [False, False, True, True], + "stage_kernel_size": [3, 3, 5, 5], + "embedding_size": 32, + "hidden_sizes": [64, 256, 512, 1024], + "depths": [1, 1, 2, 1], + }, + "hgnetv2_b2": { + "stem_channels": [3, 24, 32], + "stage_in_channels": [32, 96, 384, 768], + "stage_mid_channels": [32, 64, 128, 256], + "stage_out_channels": [96, 384, 768, 1536], + "stage_num_blocks": [1, 1, 3, 1], + "stage_numb_of_layers": [4, 4, 4, 4], + "stage_downsample": [False, True, True, True], + "stage_light_block": [False, False, True, True], + "stage_kernel_size": [3, 3, 5, 5], + "embedding_size": 32, + "hidden_sizes": [96, 384, 768, 1536], + "depths": [1, 1, 3, 1], + }, + "hgnetv2_b3": { + "stem_channels": [3, 24, 32], + "stage_in_channels": [32, 128, 512, 1024], + "stage_mid_channels": [32, 64, 128, 256], + 
"stage_out_channels": [128, 512, 1024, 2048], + "stage_num_blocks": [1, 1, 3, 1], + "stage_numb_of_layers": [5, 5, 5, 5], + "stage_downsample": [False, True, True, True], + "stage_light_block": [False, False, True, True], + "stage_kernel_size": [3, 3, 5, 5], + "embedding_size": 32, + "hidden_sizes": [128, 512, 1024, 2048], + "depths": [1, 1, 3, 1], + }, + "hgnetv2_b4": { + "stem_channels": [3, 32, 48], + "stage_in_channels": [48, 128, 512, 1024], + "stage_mid_channels": [48, 96, 192, 384], + "stage_out_channels": [128, 512, 1024, 2048], + "stage_num_blocks": [1, 1, 3, 1], + "stage_numb_of_layers": [6, 6, 6, 6], + "stage_downsample": [False, True, True, True], + "stage_light_block": [False, False, True, True], + "stage_kernel_size": [3, 3, 5, 5], + "embedding_size": 48, + "hidden_sizes": [128, 512, 1024, 2048], + "depths": [1, 1, 3, 1], + }, + "hgnetv2_b5": { + "stem_channels": [3, 32, 64], + "stage_in_channels": [64, 128, 512, 1024], + "stage_mid_channels": [64, 128, 256, 512], + "stage_out_channels": [128, 512, 1024, 2048], + "stage_num_blocks": [1, 2, 5, 2], + "stage_numb_of_layers": [6, 6, 6, 6], + "stage_downsample": [False, True, True, True], + "stage_light_block": [False, False, True, True], + "stage_kernel_size": [3, 3, 5, 5], + "embedding_size": 64, + "hidden_sizes": [128, 512, 1024, 2048], + "depths": [1, 2, 5, 2], + }, + "hgnetv2_b6": { + "stem_channels": [3, 48, 96], + "stage_in_channels": [96, 192, 512, 1024], + "stage_mid_channels": [96, 192, 384, 768], + "stage_out_channels": [192, 512, 1024, 2048], + "stage_num_blocks": [2, 3, 6, 3], + "stage_numb_of_layers": [6, 6, 6, 6], + "stage_downsample": [False, True, True, True], + "stage_light_block": [False, False, True, True], + "stage_kernel_size": [3, 3, 5, 5], + "embedding_size": 96, + "hidden_sizes": [192, 512, 1024, 2048], + "depths": [2, 3, 6, 3], + }, +} + + +def load_hf_config(hf_preset): + config_path = keras.utils.get_file( + origin=f"https://huggingface.co/{hf_preset}/raw/main/config.json", + cache_subdir=f"hf_models/{hf_preset}", + ) + with open(config_path, "r") as f: + config = json.load(f) + config["pretrained_cfg"] = hf_model.pretrained_cfg + return config + + +def convert_model(hf_config, architecture, preset_name): + config = HGNETV2_CONFIGS[architecture] + image_size = hf_config["pretrained_cfg"]["input_size"][1] + use_lab = preset_name not in LAB_FALSE_PRESETS + + backbone = HGNetV2Backbone( + image_shape=(image_size, image_size, 3), + initializer_range=0.02, + depths=config["depths"], + embedding_size=config["embedding_size"], + hidden_sizes=config["hidden_sizes"], + stem_channels=config["stem_channels"], + hidden_act="relu", + use_learnable_affine_block=use_lab, + num_channels=3, + stage_in_channels=config["stage_in_channels"], + stage_mid_channels=config["stage_mid_channels"], + stage_out_channels=config["stage_out_channels"], + stage_num_blocks=config["stage_num_blocks"], + stage_numb_of_layers=config["stage_numb_of_layers"], + stage_downsample=config["stage_downsample"], + stage_light_block=config["stage_light_block"], + stage_kernel_size=config["stage_kernel_size"], + ) + return backbone, config, image_size + + +def convert_weights(keras_model, hf_model): + state_dict = hf_model.state_dict() + classifier_keys = [ + key for key in state_dict.keys() if key.startswith("head") + ] + for key in classifier_keys: + state_dict.pop(key) + + def port_weights(keras_variable, weight_key, hook_fn=None): + if weight_key not in state_dict: + raise KeyError(f"Weight key '{weight_key}' not found in state_dict") + 
torch_tensor = state_dict[weight_key].cpu().numpy() + if hook_fn: + torch_tensor = hook_fn(torch_tensor, list(keras_variable.shape)) + if ( + keras_variable.shape == () + and isinstance(torch_tensor, np.ndarray) + and torch_tensor.shape == (1,) + ): + torch_tensor = torch_tensor[0] + keras_variable.assign(torch_tensor) + + def port_conv(keras_conv_layer, weight_key_prefix): + port_weights( + keras_conv_layer.convolution.kernel, + f"{weight_key_prefix}.conv.weight", + hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)), + ) + port_weights( + keras_conv_layer.normalization.gamma, + f"{weight_key_prefix}.bn.weight", + ) + port_weights( + keras_conv_layer.normalization.beta, + f"{weight_key_prefix}.bn.bias", + ) + port_weights( + keras_conv_layer.normalization.moving_mean, + f"{weight_key_prefix}.bn.running_mean", + ) + port_weights( + keras_conv_layer.normalization.moving_variance, + f"{weight_key_prefix}.bn.running_var", + ) + if not isinstance(keras_conv_layer.lab, keras.layers.Identity): + port_weights( + keras_conv_layer.lab.scale, f"{weight_key_prefix}.lab.scale" + ) + port_weights( + keras_conv_layer.lab.bias, f"{weight_key_prefix}.lab.bias" + ) + + def port_conv_light(keras_conv_light_layer, weight_key_prefix): + port_conv( + keras_conv_light_layer.conv1_layer, f"{weight_key_prefix}.conv1" + ) + port_conv( + keras_conv_light_layer.conv2_layer, f"{weight_key_prefix}.conv2" + ) + + def port_embeddings(keras_embeddings, weight_key_prefix): + port_conv(keras_embeddings.stem1_layer, f"{weight_key_prefix}.stem1") + port_conv(keras_embeddings.stem2a_layer, f"{weight_key_prefix}.stem2a") + port_conv(keras_embeddings.stem2b_layer, f"{weight_key_prefix}.stem2b") + port_conv(keras_embeddings.stem3_layer, f"{weight_key_prefix}.stem3") + port_conv(keras_embeddings.stem4_layer, f"{weight_key_prefix}.stem4") + + def port_basic_layer(keras_basic_layer, weight_key_prefix): + from keras_hub.src.models.hgnetv2.hgnetv2_layers import ( + HGNetV2ConvLayerLight, + ) + + for i, layer in enumerate(keras_basic_layer.layer_list): + layer_prefix = f"{weight_key_prefix}.layers.{i}" + if isinstance(layer, HGNetV2ConvLayerLight): + port_conv_light(layer, layer_prefix) + else: + port_conv(layer, layer_prefix) + port_conv( + keras_basic_layer.aggregation_squeeze_conv, + f"{weight_key_prefix}.aggregation.0", + ) + port_conv( + keras_basic_layer.aggregation_excitation_conv, + f"{weight_key_prefix}.aggregation.1", + ) + + def port_stage(keras_stage, weight_key_prefix): + if not isinstance(keras_stage.downsample_layer, keras.layers.Identity): + port_conv( + keras_stage.downsample_layer, f"{weight_key_prefix}.downsample" + ) + for block_idx, block in enumerate(keras_stage.blocks_list): + port_basic_layer(block, f"{weight_key_prefix}.blocks.{block_idx}") + + def port_encoder(keras_encoder, weight_key_prefix): + for i, stage in enumerate(keras_encoder.stages_list): + port_stage(stage, f"{weight_key_prefix}.{i}") + + port_embeddings(keras_model.embedder_layer, "stem") + port_encoder(keras_model.encoder_layer, "stages") + + +def convert_image_converter(hf_config): + pretrained_cfg = hf_config["pretrained_cfg"] + image_size = ( + pretrained_cfg["input_size"][1], + pretrained_cfg["input_size"][2], + ) + mean = pretrained_cfg["mean"] + std = pretrained_cfg["std"] + interpolation = pretrained_cfg["interpolation"] + return ( + keras.layers.Lambda( + lambda x: keras.preprocessing.image.smart_resize( + x, image_size, interpolation=interpolation + ) + ), + mean, + std, + ) + + +def validate_output(keras_model, keras_image_converter, 
hf_model, mean, std): + file = keras.utils.get_file( + origin="http://images.cocodataset.org/val2017/000000039769.jpg" + ) + image = Image.open(file) + images = np.expand_dims(np.array(image).astype("float32"), axis=0) + images = np.concatenate([images, images], axis=0) + images = keras_image_converter(images) + images = keras.ops.convert_to_tensor(images, dtype="float32") + mean_tensor = keras.ops.convert_to_tensor(mean, dtype="float32") + std_tensor = keras.ops.convert_to_tensor(std, dtype="float32") + images = (images - mean_tensor) / std_tensor + keras_preprocessed = images + hf_inputs = torch.from_numpy( + keras.ops.convert_to_numpy( + keras.ops.transpose(keras_preprocessed, (0, 3, 1, 2)) + ) + ) + keras_backbone_output_dict = keras_model(keras_preprocessed, training=False) + last_stage_name = keras_model.stage_names[-1] + keras_last_stage_tensor = keras_backbone_output_dict[last_stage_name] + hf_backbone = torch.nn.Sequential(*list(hf_model.children())[:-1]) + hf_backbone_output = hf_backbone(hf_inputs) + keras_output_np = keras.ops.convert_to_numpy(keras_last_stage_tensor) + hf_output_np = hf_backbone_output.detach().cpu().numpy() + hf_output_np = np.transpose(hf_output_np, (0, 2, 3, 1)) + modeling_diff = np.mean(np.abs(keras_output_np - hf_output_np)) + print("šŸ”¶ Modeling difference:", modeling_diff) + + +def main(_): + for preset in PRESET_MAP.keys(): + hf_preset = PRESET_MAP[preset] + if os.path.exists(preset): + shutil.rmtree(preset) + os.makedirs(preset) + + print(f"\nšŸƒ Converting {preset}") + global hf_model + hf_model = create_model(hf_preset, pretrained=False) + safetensors_file = keras.utils.get_file( + origin=f"https://huggingface.co/{hf_preset}/resolve/main/model.safetensors", # noqa: E501 + cache_subdir=f"hf_models/{hf_preset}", + ) + try: + state_dict = safetensors.torch.load_file(safetensors_file) + hf_model.load_state_dict(state_dict) + except Exception as e: + print(f"Error loading Safetensors file for {preset}: {e}") + print("Clearing cache and retrying download...") + cache_dir = os.path.dirname(safetensors_file) + if os.path.exists(cache_dir): + shutil.rmtree(cache_dir) + safetensors_file = keras.utils.get_file( + origin=f"https://huggingface.co/{hf_preset}/resolve/main/model.safetensors", # noqa: E501 + cache_subdir=f"hf_models/{hf_preset}", + ) + state_dict = safetensors.torch.load_file(safetensors_file) + hf_model.eval() + hf_config = load_hf_config(hf_preset) + architecture = hf_config["architecture"] + keras_model, _, _ = convert_model(hf_config, architecture, preset) + print("āœ… KerasHub model loaded.") + convert_weights(keras_model, hf_model) + print("āœ… Weights converted.") + keras_image_converter, mean, std = convert_image_converter(hf_config) + validate_output(keras_model, keras_image_converter, hf_model, mean, std) + print("āœ… Output validated.") + keras_model.save_to_preset(f"./{preset}") + print(f"šŸ Preset saved to ./{preset}.") + upload_uri = FLAGS.upload_uri + if upload_uri: + keras_hub.upload_preset(uri=upload_uri, preset=f"./{preset}") + print(f"šŸ Preset {preset} uploaded to {upload_uri}") + + print("\nšŸšŸ All presets validated!") + + +if __name__ == "__main__": + app.run(main) From d4c78c1aa55c4dcffa29384c3eb10980f63034bf Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Mon, 9 Jun 2025 18:10:38 +0400 Subject: [PATCH 2/8] bug: Small bug related to weight loading in the conversion script --- .../convert_hgnetv2_checkpoints.py | 117 ++++++++++++++++-- 1 file changed, 105 insertions(+), 12 deletions(-) diff --git 
a/tools/checkpoint_conversion/convert_hgnetv2_checkpoints.py b/tools/checkpoint_conversion/convert_hgnetv2_checkpoints.py index 929e82da41..f58f66f457 100644 --- a/tools/checkpoint_conversion/convert_hgnetv2_checkpoints.py +++ b/tools/checkpoint_conversion/convert_hgnetv2_checkpoints.py @@ -22,6 +22,18 @@ import keras_hub from keras_hub.src.models.hgnetv2.hgnetv2_backbone import HGNetV2Backbone +from keras_hub.src.models.hgnetv2.hgnetv2_image_classifier import ( + HGNetV2ImageClassifier, +) +from keras_hub.src.models.hgnetv2.hgnetv2_image_classifier_preprocessor import ( + HGNetV2ImageClassifierPreprocessor, +) +from keras_hub.src.models.hgnetv2.hgnetv2_image_converter import ( + HGNetV2ImageConverter, +) +from keras_hub.src.models.hgnetv2.hgnetv2_layers import ( + HGNetV2LearnableAffineBlock, +) FLAGS = flags.FLAGS @@ -191,16 +203,23 @@ def convert_model(hf_config, architecture, preset_name): stage_light_block=config["stage_light_block"], stage_kernel_size=config["stage_kernel_size"], ) - return backbone, config, image_size + image_converter = HGNetV2ImageConverter() + preprocessor = HGNetV2ImageClassifierPreprocessor( + image_converter=image_converter + ) + keras_model = HGNetV2ImageClassifier( + backbone=backbone, + preprocessor=preprocessor, + num_classes=hf_config["num_classes"], + initializer_range=0.02, + head_filters=hf_model.head_hidden_size, + use_learnable_affine_block_head=use_lab, + ) + return keras_model, config, image_size def convert_weights(keras_model, hf_model): state_dict = hf_model.state_dict() - classifier_keys = [ - key for key in state_dict.keys() if key.startswith("head") - ] - for key in classifier_keys: - state_dict.pop(key) def port_weights(keras_variable, weight_key, hook_fn=None): if weight_key not in state_dict: @@ -216,6 +235,71 @@ def port_weights(keras_variable, weight_key, hook_fn=None): torch_tensor = torch_tensor[0] keras_variable.assign(torch_tensor) + def port_last_conv(keras_conv_layer, state_dict, prefix="head.last_conv"): + port_weights( + keras_conv_layer.convolution.kernel, + f"{prefix}.0.weight", + hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)), + ) + if f"{prefix}.1.weight" in state_dict: + port_weights( + keras_conv_layer.normalization.gamma, f"{prefix}.1.weight" + ) + port_weights( + keras_conv_layer.normalization.beta, f"{prefix}.1.bias" + ) + port_weights( + keras_conv_layer.normalization.moving_mean, + f"{prefix}.1.running_mean", + ) + port_weights( + keras_conv_layer.normalization.moving_variance, + f"{prefix}.1.running_var", + ) + if isinstance(keras_conv_layer.lab, HGNetV2LearnableAffineBlock): + lab_scale_key = f"{prefix}.2.scale" + lab_bias_key = f"{prefix}.2.bias" + if lab_scale_key in state_dict: + port_weights(keras_conv_layer.lab.scale, lab_scale_key) + port_weights(keras_conv_layer.lab.bias, lab_bias_key) + else: + gamma_dtype = keras_conv_layer.normalization.gamma.dtype + if isinstance(gamma_dtype, keras.DTypePolicy): + gamma_dtype = gamma_dtype.name + gamma_identity_value = np.sqrt( + 1.0 + keras_conv_layer.normalization.epsilon + ).astype(gamma_dtype) + keras_conv_layer.normalization.gamma.assign( + np.full_like( + keras_conv_layer.normalization.gamma.numpy(), + gamma_identity_value, + ) + ) + keras_conv_layer.normalization.beta.assign( + np.zeros_like(keras_conv_layer.normalization.beta.numpy()) + ) + keras_conv_layer.normalization.moving_mean.assign( + np.zeros_like( + keras_conv_layer.normalization.moving_mean.numpy() + ) + ) + keras_conv_layer.normalization.moving_variance.assign( + np.ones_like( + 
keras_conv_layer.normalization.moving_variance.numpy() + ) + ) + if isinstance(keras_conv_layer.lab, HGNetV2LearnableAffineBlock): + lab_scale_key_idx1 = f"{prefix}.1.scale" + lab_bias_key_idx1 = f"{prefix}.1.bias" + lab_scale_key_idx2 = f"{prefix}.2.scale" + lab_bias_key_idx2 = f"{prefix}.2.bias" + if lab_scale_key_idx1 in state_dict: + port_weights(keras_conv_layer.lab.scale, lab_scale_key_idx1) + port_weights(keras_conv_layer.lab.bias, lab_bias_key_idx1) + elif lab_scale_key_idx2 in state_dict: + port_weights(keras_conv_layer.lab.scale, lab_scale_key_idx2) + port_weights(keras_conv_layer.lab.bias, lab_bias_key_idx2) + def port_conv(keras_conv_layer, weight_key_prefix): port_weights( keras_conv_layer.convolution.kernel, @@ -293,8 +377,15 @@ def port_encoder(keras_encoder, weight_key_prefix): for i, stage in enumerate(keras_encoder.stages_list): port_stage(stage, f"{weight_key_prefix}.{i}") - port_embeddings(keras_model.embedder_layer, "stem") - port_encoder(keras_model.encoder_layer, "stages") + port_embeddings(keras_model.backbone.embedder_layer, "stem") + port_encoder(keras_model.backbone.encoder_layer, "stages") + port_last_conv(keras_model.last_conv, state_dict) + port_weights( + keras_model.output_dense.kernel, + "head.fc.weight", + hook_fn=lambda x, _: np.transpose(x, (1, 0)), + ) + port_weights(keras_model.output_dense.bias, "head.fc.bias") def convert_image_converter(hf_config): @@ -335,8 +426,10 @@ def validate_output(keras_model, keras_image_converter, hf_model, mean, std): keras.ops.transpose(keras_preprocessed, (0, 3, 1, 2)) ) ) - keras_backbone_output_dict = keras_model(keras_preprocessed, training=False) - last_stage_name = keras_model.stage_names[-1] + keras_backbone_output_dict = keras_model.backbone( + keras_preprocessed, training=False + ) + last_stage_name = keras_model.backbone.stage_names[-1] keras_last_stage_tensor = keras_backbone_output_dict[last_stage_name] hf_backbone = torch.nn.Sequential(*list(hf_model.children())[:-1]) hf_backbone_output = hf_backbone(hf_inputs) @@ -358,7 +451,7 @@ def main(_): global hf_model hf_model = create_model(hf_preset, pretrained=False) safetensors_file = keras.utils.get_file( - origin=f"https://huggingface.co/{hf_preset}/resolve/main/model.safetensors", # noqa: E501 + origin=f"https://huggingface.co/{hf_preset}/resolve/main/model.safetensors", cache_subdir=f"hf_models/{hf_preset}", ) try: @@ -371,7 +464,7 @@ def main(_): if os.path.exists(cache_dir): shutil.rmtree(cache_dir) safetensors_file = keras.utils.get_file( - origin=f"https://huggingface.co/{hf_preset}/resolve/main/model.safetensors", # noqa: E501 + origin=f"https://huggingface.co/{hf_preset}/resolve/main/model.safetensors", cache_subdir=f"hf_models/{hf_preset}", ) state_dict = safetensors.torch.load_file(safetensors_file) From 5b2039497b79ae615a918b1a06a30616263d510a Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Mon, 23 Jun 2025 19:48:07 +0400 Subject: [PATCH 3/8] finalizing: Add TIMM preprocessing layer --- .../hgnetv2/hgnetv2_image_classifier_test.py | 6 +- .../models/hgnetv2/hgnetv2_image_converter.py | 395 +++++++++++++++++- .../src/models/hgnetv2/hgnetv2_presets.py | 105 +---- .../convert_hgnetv2_checkpoints.py | 77 ++-- 4 files changed, 459 insertions(+), 124 deletions(-) diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier_test.py b/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier_test.py index 3e00378246..037ebdee51 100644 --- a/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier_test.py +++ 
b/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier_test.py @@ -58,7 +58,11 @@ def setUp(self): image_shape=self.image_input_shape, ) self.image_converter = HGNetV2ImageConverter( - height=self.height, width=self.width + image_size=(self.height, self.width), + crop_pct=0.875, + mean=[0.5, 0.5, 0.5], + std=[0.5, 0.5, 0.5], + interpolation="bilinear", ) self.preprocessor = HGNetV2ImageClassifierPreprocessor( image_converter=self.image_converter diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_image_converter.py b/keras_hub/src/models/hgnetv2/hgnetv2_image_converter.py index b815ee0bfd..0257724020 100644 --- a/keras_hub/src/models/hgnetv2/hgnetv2_image_converter.py +++ b/keras_hub/src/models/hgnetv2/hgnetv2_image_converter.py @@ -1,8 +1,399 @@ +import math + +import keras +import numpy as np +import tensorflow as tf + from keras_hub.src.api_export import keras_hub_export -from keras_hub.src.layers.preprocessing.image_converter import ImageConverter +from keras_hub.src.layers.preprocessing.preprocessing_layer import ( + PreprocessingLayer, +) from keras_hub.src.models.hgnetv2.hgnetv2_backbone import HGNetV2Backbone +from keras_hub.src.utils.keras_utils import standardize_data_format +from keras_hub.src.utils.preset_utils import builtin_presets +from keras_hub.src.utils.preset_utils import find_subclass +from keras_hub.src.utils.preset_utils import get_preset_loader +from keras_hub.src.utils.preset_utils import get_preset_saver +from keras_hub.src.utils.python_utils import classproperty +from keras_hub.src.utils.tensor_utils import in_tf_function +from keras_hub.src.utils.tensor_utils import preprocessing_function + + +@keras.saving.register_keras_serializable(package="keras_hub") +class ResizeThenCrop(keras.layers.Layer): + """Resize and crop images to a target size while preserving aspect ratio. + + This layer resizes an input image to an intermediate size based on the + `crop_pct` parameter, preserving the aspect ratio, and then performs a + center crop to achieve the specified `target_size`. This preprocessing step + is commonly used in image classification models to prepare inputs for + neural networks. + + The preprocessing follows these steps: + 1. Compute an intermediate size by dividing the target height by `crop_pct`. + 2. Resize the image to the intermediate size, maintaining the aspect ratio. + 3. Crop the resized image from the center to the specified `target_size`. + + The layer accepts batched or unbatched image tensors with shape + `(..., height, width, channels)` (`channels_last` format). + + Args: + target_size: `(int, int)` tuple. The desired output size (height, width) + of the image, excluding the channels dimension. + crop_pct: float. The cropping percentage, typically between 0.0 and 1.0, + used to compute the intermediate resize size. For example, a + `crop_pct` of 0.875 means the intermediate height is + `target_height / 0.875`. + interpolation: String, the interpolation method for resizing. + Supports `"bilinear"`, `"nearest"`, `"bicubic"`, `"lanczos3"`, + `"lanczos5"`. Defaults to `"bilinear"`. + **kwargs: Additional keyword arguments passed to the parent + `keras.layers.Layer` class, such as `name` or `dtype`. 
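+
+    Example (a minimal sketch; the sizes below are illustrative and not tied
+    to any preset):
+
+    ```python
+    import numpy as np
+
+    # With target_size=(224, 224) and crop_pct=0.875, the shorter side is
+    # resized to floor(224 / 0.875) = 256 (aspect ratio preserved), then a
+    # 224x224 center crop is taken.
+    layer = ResizeThenCrop(target_size=(224, 224), crop_pct=0.875)
+    images = np.random.rand(2, 512, 384, 3).astype("float32")
+    outputs = layer(images)  # shape: (2, 224, 224, 3)
+    ```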
+ """ + + def __init__( + self, + target_size, + crop_pct, + interpolation="bilinear", + antialias=False, + **kwargs, + ): + super().__init__(**kwargs) + self.target_size = target_size + self.crop_pct = crop_pct + self.interpolation = interpolation + self.antialias = antialias + + def call(self, inputs): + target_height, target_width = self.target_size + intermediate_size = int(math.floor(target_height / self.crop_pct)) + input_shape = keras.ops.shape(inputs) + height = input_shape[-3] + width = input_shape[-2] + aspect_ratio = keras.ops.cast(width, "float32") / keras.ops.cast( + height, "float32" + ) + if keras.ops.is_tensor(aspect_ratio): + resize_height = keras.ops.cond( + aspect_ratio > 1, + lambda: intermediate_size, + lambda: keras.ops.cast( + intermediate_size / aspect_ratio, "int32" + ), + ) + resize_width = keras.ops.cond( + aspect_ratio > 1, + lambda: keras.ops.cast( + intermediate_size * aspect_ratio, "int32" + ), + lambda: intermediate_size, + ) + else: + if aspect_ratio > 1: + resize_height = intermediate_size + resize_width = int(intermediate_size * aspect_ratio) + else: + resize_width = intermediate_size + resize_height = int(intermediate_size / aspect_ratio) + resized = keras.ops.image.resize( + inputs, + (resize_height, resize_width), + interpolation=self.interpolation, + antialias=self.antialias, + ) + top = (resize_height - target_height) // 2 + left = (resize_width - target_width) // 2 + cropped = resized[ + :, top : top + target_height, left : left + target_width, : + ] + return cropped + + def get_config(self): + config = super().get_config() + config.update( + { + "target_size": self.target_size, + "crop_pct": self.crop_pct, + "antialias": self.antialias, + "interpolation": self.interpolation, + } + ) + return config @keras_hub_export("keras_hub.layers.HGNetV2ImageConverter") -class HGNetV2ImageConverter(ImageConverter): +class HGNetV2ImageConverter(PreprocessingLayer): + """Preprocess raw images into model-ready inputs for HGNetV2 models. + + This layer converts raw images into inputs suitable for HGNetV2 models. + The preprocessing includes resizing, cropping, scaling, and normalization + steps tailored for HGNetV2 architectures. The conversion proceeds in the + following steps: + + 1. Resize and crop the image to `image_size` using a `ResizeThenCrop` layer + with the specified `crop_pct` if `image_size` is provided. If + `image_size` is `None`, this step is skipped. + 2. Scale the image by dividing pixel values by 255.0 to normalize to [0, 1]. + 3. Normalize the image by subtracting the `mean` and dividing by the `std` + (per channel) if both are provided. If `mean` or `std` is `None`, this + step is skipped. + + The layer accepts batched or unbatched image tensors in channels_last or + channels_first format, with shape `(..., height, width, channels)` or + `(..., channels, height, width)`, respectively. It can also handle + dictionary inputs with an `"images"` key for compatibility with bounding box + preprocessing. + + This layer can be instantiated using the `from_preset()` constructor to load + preprocessing configurations for specific HGNetV2 presets, ensuring + compatibility with pretrained models. + + Args: + image_size: `(int, int)` tuple or `None`. The output size of the image + (height, width), excluding the channels axis. If `None`, resizing + and cropping are skipped. + crop_pct: float. The cropping percentage used in the `ResizeThenCrop` + layer to compute the intermediate resize size. Defaults to 0.875. + mean: list or tuple of floats, or `None`. 
Per-channel mean values for + normalization. If provided, these are subtracted from the image + after scaling. If `None`, this step is skipped. + std: list or tuple of floats, or `None`. Per-channel standard deviation + values for normalization. If provided, the image is divided by these + after mean subtraction. If `None`, this step is skipped. + interpolation: String, the interpolation method for resizing. + Supports `"bilinear"`, `"nearest"`, `"bicubic"`, `"lanczos3"`, + `"lanczos5"`. Defaults to `"bilinear"`. + antialias: Whether to use an antialiasing filter when downsampling an + image. Defaults to `False`. + bounding_box_format: A string specifying the format of the bounding + boxes, one of `"xyxy"`, `"rel_xyxy"`, `"xywh"`, `"center_xywh"`, + `"yxyx"`, `"rel_yxyx"`. Specifies the format of the bounding boxes + which will be resized to `image_size` along with the image. To pass + bounding boxes to this layer, pass a dict with keys `"images"` and + `"bounding_boxes"` when calling the layer. + data_format: String, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape `(batch, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + + Examples: + ```python + import keras + import numpy as np + + # Create an HGNetV2ImageConverter for a specific image size. + converter = keras_hub.layers.HGNetV2ImageConverter( + image_size=(224, 224), + crop_pct=0.965, + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + interpolation="bicubic", + ) + images = np.random.randint(0, 256, size=(2, 512, 512, 3)).astype("float32") + processed_images = converter(images) + + # Load an HGNetV2ImageConverter from a preset. + converter = keras_hub.layers.HGNetV2ImageConverter.from_preset( + "hgnetv2_b5.ssld_stage1_in22k_in1k" + ) + processed_images = converter(images) + """ + backbone_cls = HGNetV2Backbone + + def __init__( + self, + image_size=None, + crop_pct=0.875, + mean=None, + std=None, + scale=None, + offset=None, + crop_to_aspect_ratio=True, + pad_to_aspect_ratio=False, + interpolation="bilinear", + antialias=False, + bounding_box_format="yxyx", + data_format=None, + **kwargs, + ): + super().__init__(**kwargs) + if crop_to_aspect_ratio and pad_to_aspect_ratio: + raise ValueError( + "Only one of 'crop_to_aspect_ratio' or 'pad_to_aspect_ratio' " + "can be True." 
+ ) + + self.image_size_val = image_size + self.crop_pct = crop_pct + self.mean = mean + self.std = std + self.scale = scale + self.offset = offset + self.crop_to_aspect_ratio = crop_to_aspect_ratio + self.pad_to_aspect_ratio = pad_to_aspect_ratio + self.interpolation = interpolation + self.antialias = antialias + self.bounding_box_format = bounding_box_format + self.data_format = standardize_data_format(data_format) + + self.custom_resizing = None + if image_size is not None: + self.custom_resizing = ResizeThenCrop( + target_size=image_size, + crop_pct=crop_pct, + interpolation=interpolation, + antialias=antialias, + dtype=self.dtype_policy, + name="custom_resizing", + ) + self.built = True + + @property + def image_size(self): + return self.image_size_val + + @image_size.setter + def image_size(self, value): + self.image_size_val = value + if value is not None: + self.custom_resizing = ResizeThenCrop( + target_size=value, + crop_pct=self.crop_pct, + interpolation=self.interpolation, + antialias=self.antialias, + dtype=self.dtype_policy, + name="custom_resizing", + ) + else: + self.custom_resizing = None + + @preprocessing_function + def call(self, inputs): + if self.image_size is not None and self.custom_resizing is not None: + if in_tf_function(): + target_height, target_width = self.image_size + intermediate_size = tf.cast( + tf.math.floor(target_height / self.crop_pct), tf.int32 + ) + input_shape = tf.shape(inputs) + height = input_shape[-3] + width = input_shape[-2] + aspect_ratio = tf.cast(width, tf.float32) / tf.cast( + height, tf.float32 + ) + resize_height = tf.cond( + aspect_ratio > 1, + lambda: intermediate_size, + lambda: tf.cast( + tf.cast(intermediate_size, tf.float32) / aspect_ratio, + tf.int32, + ), + ) + resize_width = tf.cond( + aspect_ratio > 1, + lambda: tf.cast( + tf.cast(intermediate_size, tf.float32) * aspect_ratio, + tf.int32, + ), + lambda: intermediate_size, + ) + resized = tf.image.resize( + inputs, + [resize_height, resize_width], + method=self.interpolation, + antialias=self.antialias, + ) + top = (resize_height - target_height) // 2 + left = (resize_width - target_width) // 2 + cropped = resized[ + :, top : top + target_height, left : left + target_width, : + ] + inputs = cropped + else: + inputs = self.custom_resizing(inputs) + if isinstance(inputs, dict): + x = inputs["images"] + else: + x = inputs + if in_tf_function(): + x = tf.cast(x, self.compute_dtype) / 255.0 + else: + x = keras.ops.cast(x, self.compute_dtype) / 255.0 + if self.mean is not None and self.std is not None: + mean = self._expand_non_channel_dims(self.mean, x) + std = self._expand_non_channel_dims(self.std, x) + x, mean = self._convert_types(x, mean, self.compute_dtype) + x, std = self._convert_types(x, std, self.compute_dtype) + x = (x - mean) / std + if self.scale is not None: + scale = self._expand_non_channel_dims(self.scale, x) + x, scale = self._convert_types(x, scale, self.compute_dtype) + x = x * scale + if self.offset is not None: + offset = self._expand_non_channel_dims(self.offset, x) + x, offset = self._convert_types(x, offset, x.dtype) + x = x + offset + if isinstance(inputs, dict): + inputs["images"] = x + else: + inputs = x + return inputs + + def _expand_non_channel_dims(self, value, inputs): + unbatched = len(keras.ops.shape(inputs)) == 3 + channels_first = self.data_format == "channels_first" + if unbatched: + broadcast_dims = (1, 2) if channels_first else (0, 1) + else: + broadcast_dims = (0, 2, 3) if channels_first else (0, 1, 2) + return np.expand_dims(value, 
broadcast_dims) + + def _convert_types(self, x, y, dtype): + if in_tf_function(): + return tf.cast(x, dtype), tf.cast(y, dtype) + x = keras.ops.cast(x, dtype) + y = keras.ops.cast(y, dtype) + if keras.backend.backend() == "torch": + y = y.to(x.device) + return x, y + + def get_config(self): + config = super().get_config() + config.update( + { + "image_size": self.image_size, + "crop_pct": self.crop_pct, + "mean": self.mean, + "std": self.std, + "scale": self.scale, + "offset": self.offset, + "crop_to_aspect_ratio": self.crop_to_aspect_ratio, + "pad_to_aspect_ratio": self.pad_to_aspect_ratio, + "interpolation": self.interpolation, + "antialias": self.antialias, + "bounding_box_format": self.bounding_box_format, + } + ) + return config + + @classproperty + def presets(cls): + return builtin_presets(cls) + + @classmethod + def from_preset(cls, preset, **kwargs): + loader = get_preset_loader(preset) + backbone_cls = loader.check_backbone_class() + if cls.backbone_cls != backbone_cls: + cls = find_subclass(preset, cls, backbone_cls) + return loader.load_image_converter(cls, **kwargs) + + def save_to_preset(self, preset_dir): + saver = get_preset_saver(preset_dir) + saver.save_image_converter(self) diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_presets.py b/keras_hub/src/models/hgnetv2/hgnetv2_presets.py index a74de81c39..e869a31cbb 100644 --- a/keras_hub/src/models/hgnetv2/hgnetv2_presets.py +++ b/keras_hub/src/models/hgnetv2/hgnetv2_presets.py @@ -1,91 +1,12 @@ # Metadata for loading pretrained model weights. backbone_presets = { - "hgnetv2_b0.ssld_stage1_in22k_in1k": { - "metadata": { - "description": ( - "HGNetV2 B0 model with 1-stage SSLD training, pre-trained on ImageNet-22K and fine-tuned on ImageNet-1K." # noqa: E501 - ), - "params": 5996550, - "path": "hgnetv2", - }, - "kaggle_handle": "", - }, - "hgnetv2_b0.ssld_stage2_ft_in1k": { - "metadata": { - "description": ( - "HGNetV2 B0 model with 2-stage SSLD training, fine-tuned on ImageNet-1K." # noqa: E501 - ), - "params": 5996550, - "path": "hgnetv2", - }, - "kaggle_handle": "", - }, - "hgnetv2_b1.ssld_stage1_in22k_in1k": { - "metadata": { - "description": ( - "HGNetV2 B1 model with 1-stage SSLD training, pre-trained on ImageNet-22K and fine-tuned on ImageNet-1K." # noqa: E501 - ), - "params": 6343158, - "path": "hgnetv2", - }, - "kaggle_handle": "", - }, - "hgnetv2_b1.ssld_stage2_ft_in1k": { - "metadata": { - "description": ( - "HGNetV2 B1 model with 2-stage SSLD training, fine-tuned on ImageNet-1K." # noqa: E501 - ), - "params": 6343158, - "path": "hgnetv2", - }, - "kaggle_handle": "", - }, - "hgnetv2_b2.ssld_stage1_in22k_in1k": { - "metadata": { - "description": ( - "HGNetV2 B2 model with 1-stage SSLD training, pre-trained on ImageNet-22K and fine-tuned on ImageNet-1K." # noqa: E501 - ), - "params": 11221356, - "path": "hgnetv2", - }, - "kaggle_handle": "", - }, - "hgnetv2_b2.ssld_stage2_ft_in1k": { - "metadata": { - "description": ( - "HGNetV2 B2 model with 2-stage SSLD training, fine-tuned on ImageNet-1K." # noqa: E501 - ), - "params": 11221356, - "path": "hgnetv2", - }, - "kaggle_handle": "", - }, - "hgnetv2_b3.ssld_stage1_in22k_in1k": { - "metadata": { - "description": ( - "HGNetV2 B3 model with 1-stage SSLD training, pre-trained on ImageNet-22K and fine-tuned on ImageNet-1K." # noqa: E501 - ), - "params": 16292216, - "path": "hgnetv2", - }, - "kaggle_handle": "", - }, - "hgnetv2_b3.ssld_stage2_ft_in1k": { - "metadata": { - "description": ( - "HGNetV2 B3 model with 2-stage SSLD training, fine-tuned on ImageNet-1K." 
# noqa: E501 - ), - "params": 16292216, - "path": "hgnetv2", - }, - "kaggle_handle": "", - }, "hgnetv2_b4.ssld_stage2_ft_in1k": { "metadata": { "description": ( - "HGNetV2 B4 model with 2-stage SSLD training, fine-tuned on ImageNet-1K." # noqa: E501 + "HGNetV2 B4 model with 2-stage SSLD training, fine-tuned on " + "ImageNet-1K." ), - "params": 19796680, + "params": 13599072, "path": "hgnetv2", }, "kaggle_handle": "", @@ -93,9 +14,10 @@ "hgnetv2_b5.ssld_stage1_in22k_in1k": { "metadata": { "description": ( - "HGNetV2 B5 model with 1-stage SSLD training, pre-trained on ImageNet-22K and fine-tuned on ImageNet-1K." # noqa: E501 + "HGNetV2 B5 model with 1-stage SSLD training, pre-trained on " + "ImageNet-22K and fine-tuned on ImageNet-1K." ), - "params": 39569064, + "params": 33419680, "path": "hgnetv2", }, "kaggle_handle": "", @@ -103,9 +25,10 @@ "hgnetv2_b5.ssld_stage2_ft_in1k": { "metadata": { "description": ( - "HGNetV2 B5 model with 2-stage SSLD training, fine-tuned on ImageNet-1K." # noqa: E501 + "HGNetV2 B5 model with 2-stage SSLD training, fine-tuned on " + "ImageNet-1K." ), - "params": 39569064, + "params": 33419680, "path": "hgnetv2", }, "kaggle_handle": "", @@ -113,9 +36,10 @@ "hgnetv2_b6.ssld_stage1_in22k_in1k": { "metadata": { "description": ( - "HGNetV2 B6 model with 1-stage SSLD training, pre-trained on ImageNet-22K and fine-tuned on ImageNet-1K." # noqa: E501 + "HGNetV2 B6 model with 1-stage SSLD training, pre-trained on " + "ImageNet-22K and fine-tuned on ImageNet-1K." ), - "params": 75256776, + "params": 69179888, "path": "hgnetv2", }, "kaggle_handle": "", @@ -123,9 +47,10 @@ "hgnetv2_b6.ssld_stage2_ft_in1k": { "metadata": { "description": ( - "HGNetV2 B6 model with 2-stage SSLD training, fine-tuned on ImageNet-1K." # noqa: E501 + "HGNetV2 B6 model with 2-stage SSLD training, fine-tuned on " + "ImageNet-1K." 
), - "params": 75256776, + "params": 69179888, "path": "hgnetv2", }, "kaggle_handle": "", diff --git a/tools/checkpoint_conversion/convert_hgnetv2_checkpoints.py b/tools/checkpoint_conversion/convert_hgnetv2_checkpoints.py index f58f66f457..0c23694b8f 100644 --- a/tools/checkpoint_conversion/convert_hgnetv2_checkpoints.py +++ b/tools/checkpoint_conversion/convert_hgnetv2_checkpoints.py @@ -34,6 +34,9 @@ from keras_hub.src.models.hgnetv2.hgnetv2_layers import ( HGNetV2LearnableAffineBlock, ) +from keras_hub.src.utils.imagenet.imagenet_utils import ( + decode_imagenet_predictions, +) FLAGS = flags.FLAGS @@ -43,14 +46,6 @@ "hgnetv2_b5.ssld_stage2_ft_in1k": "timm/hgnetv2_b5.ssld_stage2_ft_in1k", "hgnetv2_b5.ssld_stage1_in22k_in1k": "timm/hgnetv2_b5.ssld_stage1_in22k_in1k", # noqa: E501 "hgnetv2_b4.ssld_stage2_ft_in1k": "timm/hgnetv2_b4.ssld_stage2_ft_in1k", - "hgnetv2_b3.ssld_stage2_ft_in1k": "timm/hgnetv2_b3.ssld_stage2_ft_in1k", - "hgnetv2_b3.ssld_stage1_in22k_in1k": "timm/hgnetv2_b3.ssld_stage1_in22k_in1k", # noqa: E501 - "hgnetv2_b2.ssld_stage2_ft_in1k": "timm/hgnetv2_b2.ssld_stage2_ft_in1k", - "hgnetv2_b2.ssld_stage1_in22k_in1k": "timm/hgnetv2_b2.ssld_stage1_in22k_in1k", # noqa: E501 - "hgnetv2_b1.ssld_stage2_ft_in1k": "timm/hgnetv2_b1.ssld_stage2_ft_in1k", - "hgnetv2_b1.ssld_stage1_in22k_in1k": "timm/hgnetv2_b1.ssld_stage1_in22k_in1k", # noqa: E501 - "hgnetv2_b0.ssld_stage2_ft_in1k": "timm/hgnetv2_b0.ssld_stage2_ft_in1k", - "hgnetv2_b0.ssld_stage1_in22k_in1k": "timm/hgnetv2_b0.ssld_stage1_in22k_in1k", # noqa: E501 } LAB_FALSE_PRESETS = [ "hgnetv2_b6.ssld_stage2_ft_in1k", @@ -203,7 +198,23 @@ def convert_model(hf_config, architecture, preset_name): stage_light_block=config["stage_light_block"], stage_kernel_size=config["stage_kernel_size"], ) - image_converter = HGNetV2ImageConverter() + pretrained_cfg = hf_config["pretrained_cfg"] + image_size = ( + pretrained_cfg["input_size"][1], + pretrained_cfg["input_size"][2], + ) + mean = pretrained_cfg["mean"] + std = pretrained_cfg["std"] + crop_pct = pretrained_cfg.get("crop_pct", 0.875) + interpolation = pretrained_cfg["interpolation"] + image_converter = HGNetV2ImageConverter( + image_size=image_size, + crop_pct=crop_pct, + mean=mean, + std=std, + interpolation=interpolation, + antialias=True if interpolation == "bicubic" else False, + ) preprocessor = HGNetV2ImageClassifierPreprocessor( image_converter=image_converter ) @@ -396,31 +407,30 @@ def convert_image_converter(hf_config): ) mean = pretrained_cfg["mean"] std = pretrained_cfg["std"] + crop_pct = pretrained_cfg.get("crop_pct", 0.875) interpolation = pretrained_cfg["interpolation"] - return ( - keras.layers.Lambda( - lambda x: keras.preprocessing.image.smart_resize( - x, image_size, interpolation=interpolation - ) - ), - mean, - std, + image_converter = HGNetV2ImageConverter( + image_size=image_size, + crop_pct=crop_pct, + mean=mean, + std=std, + interpolation=interpolation, + antialias=True if interpolation == "bicubic" else False, ) + return image_converter, mean, std def validate_output(keras_model, keras_image_converter, hf_model, mean, std): file = keras.utils.get_file( - origin="http://images.cocodataset.org/val2017/000000039769.jpg" + origin="https://upload.wikimedia.org/wikipedia/commons/4/43/Cute_dog.jpg" ) image = Image.open(file) images = np.expand_dims(np.array(image).astype("float32"), axis=0) - images = np.concatenate([images, images], axis=0) - images = keras_image_converter(images) - images = keras.ops.convert_to_tensor(images, dtype="float32") - mean_tensor = 
keras.ops.convert_to_tensor(mean, dtype="float32") - std_tensor = keras.ops.convert_to_tensor(std, dtype="float32") - images = (images - mean_tensor) / std_tensor - keras_preprocessed = images + keras_preprocessed = keras_model.preprocessor(images) + keras_preprocessed = keras.ops.convert_to_tensor( + keras_preprocessed, dtype="float32" + ) + keras_logits = keras_model.predict(keras_preprocessed) hf_inputs = torch.from_numpy( keras.ops.convert_to_numpy( keras.ops.transpose(keras_preprocessed, (0, 3, 1, 2)) @@ -436,8 +446,14 @@ def validate_output(keras_model, keras_image_converter, hf_model, mean, std): keras_output_np = keras.ops.convert_to_numpy(keras_last_stage_tensor) hf_output_np = hf_backbone_output.detach().cpu().numpy() hf_output_np = np.transpose(hf_output_np, (0, 2, 3, 1)) - modeling_diff = np.mean(np.abs(keras_output_np - hf_output_np)) - print("šŸ”¶ Modeling difference:", modeling_diff) + print( + "šŸ”¶ Modeling Difference (Mean Absolute):", + np.mean(np.abs(keras_output_np - hf_output_np)), + ) + print( + "šŸ”¬ Keras top 5 ImageNet predictions:", + decode_imagenet_predictions(keras_logits, top=5), + ) def main(_): @@ -449,7 +465,7 @@ def main(_): print(f"\nšŸƒ Converting {preset}") global hf_model - hf_model = create_model(hf_preset, pretrained=False) + hf_model = create_model(hf_preset, pretrained=True) safetensors_file = keras.utils.get_file( origin=f"https://huggingface.co/{hf_preset}/resolve/main/model.safetensors", cache_subdir=f"hf_models/{hf_preset}", @@ -483,9 +499,8 @@ def main(_): upload_uri = FLAGS.upload_uri if upload_uri: keras_hub.upload_preset(uri=upload_uri, preset=f"./{preset}") - print(f"šŸ Preset {preset} uploaded to {upload_uri}") - - print("\nšŸšŸ All presets validated!") + print(f"šŸ Successfully uploaded {preset} to {upload_uri}") + print("\nšŸ All presets validated and saved successfully!") if __name__ == "__main__": From df903dd3f17d48081d0eada155224aa0cc5d3c1e Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Sat, 28 Jun 2025 17:30:28 +0400 Subject: [PATCH 4/8] incorporate reviews: Consolidate stage configurations and improve API consistency --- .../src/models/hgnetv2/hgnetv2_backbone.py | 102 +++-- .../models/hgnetv2/hgnetv2_backbone_test.py | 75 ++-- .../src/models/hgnetv2/hgnetv2_encoder.py | 32 +- .../hgnetv2/hgnetv2_image_classifier.py | 99 ++++- .../hgnetv2/hgnetv2_image_classifier_test.py | 72 ++-- .../models/hgnetv2/hgnetv2_image_converter.py | 395 +----------------- .../src/models/hgnetv2/hgnetv2_layers.py | 34 +- .../src/models/hgnetv2/hgnetv2_presets.py | 10 +- .../convert_hgnetv2_checkpoints.py | 163 ++++---- 9 files changed, 315 insertions(+), 667 deletions(-) diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_backbone.py b/keras_hub/src/models/hgnetv2/hgnetv2_backbone.py index 4bf91cda9a..7127cc2d6d 100644 --- a/keras_hub/src/models/hgnetv2/hgnetv2_backbone.py +++ b/keras_hub/src/models/hgnetv2/hgnetv2_backbone.py @@ -11,10 +11,14 @@ class HGNetV2Backbone(Backbone): """This class represents a Keras Backbone of the HGNetV2 model. - This class implements an HGNetV2 backbone architecture. + This class implements an HGNetV2 backbone architecture, a convolutional + neural network (CNN) optimized for GPU efficiency. HGNetV2 is frequently + used as a lightweight CNN backbone in object detection pipelines like + RT-DETR and YOLO variants, delivering strong performance on classification + and detection tasks, with speed-ups and accuracy gains compared to larger + CNN backbones. 
Args: - initializer_range: float, the range for initializing weights. depths: list of ints, the number of blocks in each stage. embedding_size: int, the size of the embedding layer. hidden_sizes: list of ints, the sizes of the hidden layers. @@ -23,15 +27,19 @@ class HGNetV2Backbone(Backbone): use_learnable_affine_block: bool, whether to use learnable affine transformations. num_channels: int, the number of channels in the input image. - stage_in_channels: list of ints, the input channels for each stage. - stage_mid_channels: list of ints, the middle channels for each stage. - stage_out_channels: list of ints, the output channels for each stage. - stage_num_blocks: list of ints, the number of blocks in each stage. - stage_numb_of_layers: list of ints, the number of layers in each block. - stage_downsample: list of bools, whether to downsample in each stage. - stage_light_block: list of bools, whether to use light blocks in each - stage. - stage_kernel_size: list of ints, the kernel sizes for each stage. + stackwise_stage_filters: list of tuples, where each tuple contains + configuration for a stage: (stage_in_channels, stage_mid_channels, + stage_out_channels, stage_num_blocks, stage_num_of_layers, + stage_kernel_size). + - stage_in_channels: int, input channels for the stage + - stage_mid_channels: int middle channels for the stage + - stage_out_channels: int, output channels for the stage + - stage_num_blocks: int, number of blocks in the stage + - stage_num_of_layers: int, number of layers in each block + - stage_kernel_size: int, kernel size for the stage + apply_downsample: list of bools, whether to downsample in each stage. + use_lightweight_conv_block: list of bools, whether to use HGNetV2 + lightweight convolutional blocks in each stage. image_shape: tuple, the shape of the input image without the batch size. Defaults to `(None, None, 3)`. data_format: `None` or str, the data format ('channels_last' or @@ -48,13 +56,12 @@ class HGNetV2Backbone(Backbone): # Pretrained backbone. model = keras_hub.models.HGNetV2Backbone.from_preset( - "hgnetv2_b5.ssld_stage2_ft_in1k" + "hgnetv2_b5_ssld_stage2_ft_in1k" ) model(input_data) # Randomly initialized backbone with a custom config. 
model = HGNetV2Backbone( - initializer_range=0.02, depths=[1, 2, 4], embedding_size=32, hidden_sizes=[64, 128, 256], @@ -62,14 +69,13 @@ class HGNetV2Backbone(Backbone): hidden_act="relu", use_learnable_affine_block=False, num_channels=3, - stage_in_channels=[32, 64, 128], - stage_mid_channels=[16, 32, 64], - stage_out_channels=[64, 128, 256], - stage_num_blocks=[1, 2, 4], - stage_numb_of_layers=[1, 1, 1], - stage_downsample=[False, True, True], - stage_light_block=[False, False, False], - stage_kernel_size=[3, 3, 3], + stackwise_stage_filters=[ + (32, 16, 64, 1, 1, 3), # Stage 0 + (64, 32, 128, 2, 1, 3), # Stage 1 + (128, 64, 256, 4, 1, 3), # Stage 2 + ], + apply_downsample=[False, True, True], + use_lightweight_conv_block=[False, False, False], image_shape=(224, 224, 3), ) model(input_data) @@ -78,7 +84,6 @@ class HGNetV2Backbone(Backbone): def __init__( self, - initializer_range, depths, embedding_size, hidden_sizes, @@ -86,14 +91,9 @@ def __init__( hidden_act, use_learnable_affine_block, num_channels, - stage_in_channels, - stage_mid_channels, - stage_out_channels, - stage_num_blocks, - stage_numb_of_layers, - stage_downsample, - stage_light_block, - stage_kernel_size, + stackwise_stage_filters, + apply_downsample, + use_lightweight_conv_block, image_shape=(None, None, 3), data_format=None, dtype=None, @@ -103,6 +103,12 @@ def __init__( data_format = standardize_data_format(data_format) channel_axis = -1 if data_format == "channels_last" else 1 self.image_shape = image_shape + stage_in_channels = [stage[0] for stage in stackwise_stage_filters] + stage_mid_channels = [stage[1] for stage in stackwise_stage_filters] + stage_out_filters = [stage[2] for stage in stackwise_stage_filters] + stage_num_blocks = [stage[3] for stage in stackwise_stage_filters] + stage_num_of_layers = [stage[4] for stage in stackwise_stage_filters] + stage_kernel_size = [stage[5] for stage in stackwise_stage_filters] # === Layers === self.embedder_layer = HGNetV2Embeddings( @@ -118,11 +124,11 @@ def __init__( self.encoder_layer = HGNetV2Encoder( stage_in_channels=stage_in_channels, stage_mid_channels=stage_mid_channels, - stage_out_channels=stage_out_channels, + stage_out_channels=stage_out_filters, stage_num_blocks=stage_num_blocks, - stage_numb_of_layers=stage_numb_of_layers, - stage_downsample=stage_downsample, - stage_light_block=stage_light_block, + stage_num_of_layers=stage_num_of_layers, + apply_downsample=apply_downsample, + use_lightweight_conv_block=use_lightweight_conv_block, stage_kernel_size=stage_kernel_size, use_learnable_affine_block=use_learnable_affine_block, data_format=data_format, @@ -130,7 +136,9 @@ def __init__( name=f"{name}_encoder" if name else "encoder", dtype=dtype, ) - self.stage_names = [f"stage{i}" for i in range(len(stage_in_channels))] + self.stage_names = [ + f"stage{i}" for i in range(len(stackwise_stage_filters)) + ] self.out_features = self.stage_names # === Functional Model === @@ -149,7 +157,6 @@ def __init__( ) # === Config === - self.initializer_range = initializer_range self.depths = depths self.embedding_size = embedding_size self.hidden_sizes = hidden_sizes @@ -157,21 +164,15 @@ def __init__( self.hidden_act = hidden_act self.use_learnable_affine_block = use_learnable_affine_block self.num_channels = num_channels - self.stage_in_channels = stage_in_channels - self.stage_mid_channels = stage_mid_channels - self.stage_out_channels = stage_out_channels - self.stage_num_blocks = stage_num_blocks - self.stage_numb_of_layers = stage_numb_of_layers - self.stage_downsample = 
stage_downsample - self.stage_light_block = stage_light_block - self.stage_kernel_size = stage_kernel_size + self.stackwise_stage_filters = stackwise_stage_filters + self.apply_downsample = apply_downsample + self.use_lightweight_conv_block = use_lightweight_conv_block self.data_format = data_format def get_config(self): config = super().get_config() config.update( { - "initializer_range": self.initializer_range, "depths": self.depths, "embedding_size": self.embedding_size, "hidden_sizes": self.hidden_sizes, @@ -179,14 +180,9 @@ def get_config(self): "hidden_act": self.hidden_act, "use_learnable_affine_block": self.use_learnable_affine_block, "num_channels": self.num_channels, - "stage_in_channels": self.stage_in_channels, - "stage_mid_channels": self.stage_mid_channels, - "stage_out_channels": self.stage_out_channels, - "stage_num_blocks": self.stage_num_blocks, - "stage_numb_of_layers": self.stage_numb_of_layers, - "stage_downsample": self.stage_downsample, - "stage_light_block": self.stage_light_block, - "stage_kernel_size": self.stage_kernel_size, + "stackwise_stage_filters": self.stackwise_stage_filters, + "apply_downsample": self.apply_downsample, + "use_lightweight_conv_block": self.use_lightweight_conv_block, "image_shape": self.image_shape, "data_format": self.data_format, } diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_backbone_test.py b/keras_hub/src/models/hgnetv2/hgnetv2_backbone_test.py index 71889a57c5..abeb6bc6fe 100644 --- a/keras_hub/src/models/hgnetv2/hgnetv2_backbone_test.py +++ b/keras_hub/src/models/hgnetv2/hgnetv2_backbone_test.py @@ -11,33 +11,29 @@ class HGNetV2BackboneTest(TestCase): def setUp(self): self.default_input_shape = (64, 64, 3) self.num_channels = self.default_input_shape[-1] + self.batch_size = 2 self.stem_channels = [self.num_channels, 16, 32] - self.default_stage_in_channels = [self.stem_channels[-1], 64] - self.default_stage_mid_channels = [16, 32] - self.default_stage_out_channels = [64, 128] - self.default_num_stages = len(self.default_stage_in_channels) - + self.default_stage_out_filters = [64, 128] + self.default_num_stages = 2 + self.stackwise_stage_filters = [ + [32, 16, 64, 1, 1, 3], + [64, 32, 128, 1, 1, 3], + ] self.init_kwargs = { - "initializer_range": 0.02, - "depths": [1] * self.default_num_stages, "embedding_size": self.stem_channels[-1], - "hidden_sizes": self.default_stage_out_channels, "stem_channels": self.stem_channels, "hidden_act": "relu", "use_learnable_affine_block": False, "num_channels": self.num_channels, - "stage_in_channels": self.default_stage_in_channels, - "stage_mid_channels": self.default_stage_mid_channels, - "stage_out_channels": self.default_stage_out_channels, - "stage_num_blocks": [1] * self.default_num_stages, - "stage_numb_of_layers": [1] * self.default_num_stages, - "stage_downsample": [False, True], - "stage_light_block": [False, False], - "stage_kernel_size": [3] * self.default_num_stages, "image_shape": self.default_input_shape, + "depths": [1] * self.default_num_stages, + "hidden_sizes": [ + stage[2] for stage in self.stackwise_stage_filters + ], + "stackwise_stage_filters": self.stackwise_stage_filters, + "apply_downsample": [False, True], + "use_lightweight_conv_block": [False, False], } - self.input_size = self.default_input_shape[:2] - self.batch_size = 2 self.input_data = keras.ops.convert_to_tensor( np.random.rand(self.batch_size, *self.default_input_shape).astype( np.float32 @@ -46,7 +42,7 @@ def setUp(self): @parameterized.named_parameters( ( - "default_config", + "default", [False, True], 
[False, False], 2, @@ -76,39 +72,24 @@ def setUp(self): ) def test_backbone_basics( self, - stage_downsample_config, - stage_light_block_config, + apply_downsample, + use_lightweight_conv_block, num_stages, expected_shapes, ): - current_init_kwargs = self.init_kwargs.copy() - current_init_kwargs["depths"] = [1] * num_stages - current_init_kwargs["hidden_sizes"] = self.default_stage_out_channels[ - :num_stages - ] - current_init_kwargs["stage_in_channels"] = ( - self.default_stage_in_channels[:num_stages] - ) - current_init_kwargs["stage_mid_channels"] = ( - self.default_stage_mid_channels[:num_stages] - ) - current_init_kwargs["stage_out_channels"] = ( - self.default_stage_out_channels[:num_stages] - ) - current_init_kwargs["stage_num_blocks"] = [1] * num_stages - current_init_kwargs["stage_numb_of_layers"] = [1] * num_stages - current_init_kwargs["stage_kernel_size"] = [3] * num_stages - current_init_kwargs["stage_downsample"] = stage_downsample_config - current_init_kwargs["stage_light_block"] = stage_light_block_config - if num_stages > 0: - current_init_kwargs["stage_in_channels"][0] = self.stem_channels[-1] - for i in range(1, num_stages): - current_init_kwargs["stage_in_channels"][i] = ( - current_init_kwargs["stage_out_channels"][i - 1] - ) + test_filters = self.stackwise_stage_filters[:num_stages] + hidden_sizes = [stage[2] for stage in test_filters] + test_kwargs = { + **self.init_kwargs, + "depths": [1] * num_stages, + "hidden_sizes": hidden_sizes, + "stackwise_stage_filters": test_filters, + "apply_downsample": apply_downsample, + "use_lightweight_conv_block": use_lightweight_conv_block, + } self.run_vision_backbone_test( cls=HGNetV2Backbone, - init_kwargs=current_init_kwargs, + init_kwargs=test_kwargs, input_data=self.input_data, expected_output_shape=expected_shapes, run_mixed_precision_check=False, diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_encoder.py b/keras_hub/src/models/hgnetv2/hgnetv2_encoder.py index b7108ed87d..201cc72b7e 100644 --- a/keras_hub/src/models/hgnetv2/hgnetv2_encoder.py +++ b/keras_hub/src/models/hgnetv2/hgnetv2_encoder.py @@ -21,12 +21,12 @@ class HGNetV2Encoder(keras.layers.Layer): for each stage. stage_num_blocks: A list of integers, specifying the number of blocks in each stage. - stage_numb_of_layers: A list of integers, specifying the number of + stage_num_of_layers: A list of integers, specifying the number of layers in each block of each stage. - stage_downsample: A list of booleans or integers, indicating whether to + apply_downsample: A list of booleans or integers, indicating whether to downsample in each stage. - stage_light_block: A list of booleans, indicating whether to use light - blocks in each stage. + use_lightweight_conv_block: A list of booleans, indicating whether to + use HGNetV2 lightweight convolutional blocks in each stage. stage_kernel_size: A list of integers or tuples, specifying the kernel size for each stage. 
use_learnable_affine_block: A boolean, indicating whether to use @@ -49,9 +49,9 @@ def __init__( stage_mid_channels, stage_out_channels, stage_num_blocks, - stage_numb_of_layers, - stage_downsample, - stage_light_block, + stage_num_of_layers, + apply_downsample, + use_lightweight_conv_block, stage_kernel_size, use_learnable_affine_block, data_format=None, @@ -63,9 +63,9 @@ def __init__( self.stage_mid_channels = stage_mid_channels self.stage_out_channels = stage_out_channels self.stage_num_blocks = stage_num_blocks - self.stage_numb_of_layers = stage_numb_of_layers - self.stage_downsample = stage_downsample - self.stage_light_block = stage_light_block + self.stage_num_of_layers = stage_num_of_layers + self.apply_downsample = apply_downsample + self.use_lightweight_conv_block = use_lightweight_conv_block self.stage_kernel_size = stage_kernel_size self.use_learnable_affine_block = use_learnable_affine_block self.data_format = data_format @@ -78,9 +78,9 @@ def __init__( stage_mid_channels=self.stage_mid_channels, stage_out_channels=self.stage_out_channels, stage_num_blocks=self.stage_num_blocks, - stage_numb_of_layers=self.stage_numb_of_layers, - stage_downsample=self.stage_downsample, - stage_light_block=self.stage_light_block, + stage_num_of_layers=self.stage_num_of_layers, + apply_downsample=self.apply_downsample, + use_lightweight_conv_block=self.use_lightweight_conv_block, stage_kernel_size=self.stage_kernel_size, use_learnable_affine_block=self.use_learnable_affine_block, stage_index=stage_idx, @@ -136,9 +136,9 @@ def get_config(self): "stage_mid_channels": self.stage_mid_channels, "stage_out_channels": self.stage_out_channels, "stage_num_blocks": self.stage_num_blocks, - "stage_numb_of_layers": self.stage_numb_of_layers, - "stage_downsample": self.stage_downsample, - "stage_light_block": self.stage_light_block, + "stage_num_of_layers": self.stage_num_of_layers, + "apply_downsample": self.apply_downsample, + "use_lightweight_conv_block": self.use_lightweight_conv_block, "stage_kernel_size": self.stage_kernel_size, "use_learnable_affine_block": self.use_learnable_affine_block, "data_format": self.data_format, diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py b/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py index fb3a12c90e..df60e03ef7 100644 --- a/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py +++ b/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py @@ -12,6 +12,94 @@ @keras_hub_export("keras_hub.models.HGNetV2ImageClassifier") class HGNetV2ImageClassifier(ImageClassifier): + """HGNetV2 image classification model. + + `HGNetV2ImageClassifier` wraps a `HGNetV2Backbone` and + a `HGNetV2ImageClassifierPreprocessor` to create a model that can be used + for image classification tasks. This model implements the HGNetV2 + architecture with an additional classification head including a 1x1 + convolution layer, global pooling, and a dense output layer. + + The model takes an additional `num_classes` argument, controlling the number + of predicted output classes, and optionally, a `head_filters` argument to + specify the number of filters in the classification head's convolution + layer. To fine-tune with `fit()`, pass a dataset containing tuples of + `(x, y)` labels where `x` is an image tensor and `y` is an integer from + `[0, num_classes)`. + + Args: + backbone: A `HGNetV2Backbone` instance. + preprocessor: A `HGNetV2ImageClassifierPreprocessor` instance, + a `keras.Layer` instance, or a callable. 
If `None` no preprocessing + will be applied to the inputs. + num_classes: int. The number of classes to predict. + head_filters: int, optional. The number of filters in the + classification head's 1x1 convolution layer. If `None`, it defaults + to the last value of `hidden_sizes` from the backbone. + pooling: `"avg"` or `"max"`. The type of global pooling to apply after + the head convolution. Defaults to `"avg"`. + activation: `None`, str, or callable. The activation function to use on + the final `Dense` layer. Set `activation=None` to return the output + logits. Defaults to `None`. + dropout: float. Dropout rate applied before the final dense layer. + Defaults to 0.0. + head_dtype: `None`, str, or `keras.mixed_precision.DTypePolicy`. The + dtype to use for the classification head's computations and weights. + + Examples: + + Call `predict()` to run inference. + ```python + # Load preset and predict. + images = np.random.randint(0, 256, size=(2, 224, 224, 3)) + classifier = keras_hub.models.HGNetV2ImageClassifier.from_preset( + "hgnetv2_b5_ssld_stage2_ft_in1k" + ) + classifier.predict(images) + ``` + + Call `fit()` on a single batch. + ```python + # Load preset and train. + images = np.random.randint(0, 256, size=(2, 224, 224, 3)) + labels = [0, 3] + classifier = keras_hub.models.HGNetV2ImageClassifier.from_preset( + "hgnetv2_b5_ssld_stage2_ft_in1k" + ) + classifier.fit(x=images, y=labels, batch_size=2) + ``` + + Call `fit()` with custom loss, optimizer and frozen backbone. + ```python + classifier = keras_hub.models.HGNetV2ImageClassifier.from_preset( + "hgnetv2_b5_ssld_stage2_ft_in1k" + ) + classifier.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(5e-5), + ) + classifier.backbone.trainable = False + classifier.fit(x=images, y=labels, batch_size=2) + ``` + + Create a custom HGNetV2 classifier with specific head configuration. 
+ ```python + backbone = keras_hub.models.HGNetV2Backbone.from_preset( + "hgnetv2_b5_ssld_stage2_ft_in1k" + ) + preproc = keras_hub.models.HGNetV2ImageClassifierPreprocessor.from_preset( + "hgnetv2_b5_ssld_stage2_ft_in1k" + ) + classifier = keras_hub.models.HGNetV2ImageClassifier( + backbone=backbone, + preprocessor=preproc, + num_classes=10, + pooling="avg", + dropout=0.2, + ) + ``` + """ + backbone_cls = HGNetV2Backbone preprocessor_cls = HGNetV2ImageClassifierPreprocessor @@ -20,12 +108,11 @@ def __init__( backbone, preprocessor, num_classes, - head_filters, + head_filters=None, pooling="avg", activation=None, dropout=0.0, head_dtype=None, - use_learnable_affine_block_head=False, **kwargs, ): name = kwargs.get("name", "hgnetv2_image_classifier") @@ -40,7 +127,11 @@ def __init__( self.pooling = pooling self.activation = activation self.dropout = dropout - self.head_filters = head_filters + self.head_filters = ( + head_filters + if head_filters is not None + else backbone.hidden_sizes[-1] + ) # === Layers === self.backbone = backbone @@ -52,7 +143,7 @@ def __init__( stride=1, groups=1, activation="relu", - use_learnable_affine_block=use_learnable_affine_block_head, + use_learnable_affine_block=self.backbone.use_learnable_affine_block, data_format=data_format, channel_axis=channel_axis, name="head_last", diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier_test.py b/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier_test.py index 037ebdee51..907902f73d 100644 --- a/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier_test.py +++ b/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier_test.py @@ -17,52 +17,41 @@ class HGNetV2ImageClassifierTest(TestCase): def setUp(self): self.batch_size = 2 - self.height = 64 - self.width = 64 - self.num_channels = 3 - self.image_input_shape = (self.height, self.width, self.num_channels) + self.image_shape = (64, 64, 3) self.num_classes = 3 self.images = np.ones( - (self.batch_size, *self.image_input_shape), dtype="float32" + (self.batch_size, *self.image_shape), dtype="float32" ) self.labels = np.random.randint(0, self.num_classes, self.batch_size) - num_stages = 2 + self.train_data = (self.images, self.labels) # Setup model. 
- stem_channels = [self.num_channels, 16, 32] - stage_in_channels = [stem_channels[-1], 64][:num_stages] - stage_mid_channels = [16, 32][:num_stages] - stage_out_channels = [64, 128][:num_stages] - stage_num_blocks = [1] * num_stages - stage_numb_of_layers = [1] * num_stages - stage_downsample = [False, True][:num_stages] - stage_light_block = [False, False][:num_stages] - stage_kernel_size = [3] * num_stages - self.backbone = HGNetV2Backbone( initializer_range=0.02, - depths=stage_num_blocks, - embedding_size=stem_channels[-1], - hidden_sizes=stage_out_channels, - stem_channels=stem_channels, + depths=[1, 1], + embedding_size=32, + hidden_sizes=[64, 128], + stem_channels=[self.image_shape[-1], 16, 32], hidden_act="relu", use_learnable_affine_block=False, - num_channels=self.num_channels, - stage_in_channels=stage_in_channels, - stage_mid_channels=stage_mid_channels, - stage_out_channels=stage_out_channels, - stage_num_blocks=stage_num_blocks, - stage_numb_of_layers=stage_numb_of_layers, - stage_downsample=stage_downsample, - stage_light_block=stage_light_block, - stage_kernel_size=stage_kernel_size, - image_shape=self.image_input_shape, + num_channels=self.image_shape[-1], + stackwise_stage_filters=[ + [32, 16, 64, 1, 1, 3], + [64, 32, 128, 1, 1, 3], + ], + apply_downsample=[False, True], + use_lightweight_conv_block=[False, False], + image_shape=self.image_shape, ) + mean = [0.5, 0.5, 0.5] + std = [0.5, 0.5, 0.5] + scale = [1 / (255.0 * s) for s in std] + offset = [-m / s for m, s in zip(mean, std)] self.image_converter = HGNetV2ImageConverter( - image_size=(self.height, self.width), - crop_pct=0.875, - mean=[0.5, 0.5, 0.5], - std=[0.5, 0.5, 0.5], + image_size=self.image_shape[:2], + scale=scale, + offset=offset, interpolation="bilinear", + antialias=False, ) self.preprocessor = HGNetV2ImageClassifierPreprocessor( image_converter=self.image_converter @@ -71,25 +60,14 @@ def setUp(self): "backbone": self.backbone, "preprocessor": self.preprocessor, "num_classes": self.num_classes, - "head_filters": stage_out_channels[-1], } - self.train_data = ( - self.images, - self.labels, - ) self.expected_backbone_output_shapes = { "stage0": (self.batch_size, 16, 16, 64), "stage1": (self.batch_size, 8, 8, 128), } - self.preset_image_size = 224 + self.preset_image_shape = (224, 224, 3) self.images_for_presets = np.ones( - ( - self.batch_size, - self.preset_image_size, - self.preset_image_size, - self.num_channels, - ), - dtype="float32", + (self.batch_size, *self.preset_image_shape), dtype="float32" ) def test_classifier_basics(self): diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_image_converter.py b/keras_hub/src/models/hgnetv2/hgnetv2_image_converter.py index 0257724020..b815ee0bfd 100644 --- a/keras_hub/src/models/hgnetv2/hgnetv2_image_converter.py +++ b/keras_hub/src/models/hgnetv2/hgnetv2_image_converter.py @@ -1,399 +1,8 @@ -import math - -import keras -import numpy as np -import tensorflow as tf - from keras_hub.src.api_export import keras_hub_export -from keras_hub.src.layers.preprocessing.preprocessing_layer import ( - PreprocessingLayer, -) +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter from keras_hub.src.models.hgnetv2.hgnetv2_backbone import HGNetV2Backbone -from keras_hub.src.utils.keras_utils import standardize_data_format -from keras_hub.src.utils.preset_utils import builtin_presets -from keras_hub.src.utils.preset_utils import find_subclass -from keras_hub.src.utils.preset_utils import get_preset_loader -from keras_hub.src.utils.preset_utils import 
get_preset_saver -from keras_hub.src.utils.python_utils import classproperty -from keras_hub.src.utils.tensor_utils import in_tf_function -from keras_hub.src.utils.tensor_utils import preprocessing_function - - -@keras.saving.register_keras_serializable(package="keras_hub") -class ResizeThenCrop(keras.layers.Layer): - """Resize and crop images to a target size while preserving aspect ratio. - - This layer resizes an input image to an intermediate size based on the - `crop_pct` parameter, preserving the aspect ratio, and then performs a - center crop to achieve the specified `target_size`. This preprocessing step - is commonly used in image classification models to prepare inputs for - neural networks. - - The preprocessing follows these steps: - 1. Compute an intermediate size by dividing the target height by `crop_pct`. - 2. Resize the image to the intermediate size, maintaining the aspect ratio. - 3. Crop the resized image from the center to the specified `target_size`. - - The layer accepts batched or unbatched image tensors with shape - `(..., height, width, channels)` (`channels_last` format). - - Args: - target_size: `(int, int)` tuple. The desired output size (height, width) - of the image, excluding the channels dimension. - crop_pct: float. The cropping percentage, typically between 0.0 and 1.0, - used to compute the intermediate resize size. For example, a - `crop_pct` of 0.875 means the intermediate height is - `target_height / 0.875`. - interpolation: String, the interpolation method for resizing. - Supports `"bilinear"`, `"nearest"`, `"bicubic"`, `"lanczos3"`, - `"lanczos5"`. Defaults to `"bilinear"`. - **kwargs: Additional keyword arguments passed to the parent - `keras.layers.Layer` class, such as `name` or `dtype`. - """ - - def __init__( - self, - target_size, - crop_pct, - interpolation="bilinear", - antialias=False, - **kwargs, - ): - super().__init__(**kwargs) - self.target_size = target_size - self.crop_pct = crop_pct - self.interpolation = interpolation - self.antialias = antialias - - def call(self, inputs): - target_height, target_width = self.target_size - intermediate_size = int(math.floor(target_height / self.crop_pct)) - input_shape = keras.ops.shape(inputs) - height = input_shape[-3] - width = input_shape[-2] - aspect_ratio = keras.ops.cast(width, "float32") / keras.ops.cast( - height, "float32" - ) - if keras.ops.is_tensor(aspect_ratio): - resize_height = keras.ops.cond( - aspect_ratio > 1, - lambda: intermediate_size, - lambda: keras.ops.cast( - intermediate_size / aspect_ratio, "int32" - ), - ) - resize_width = keras.ops.cond( - aspect_ratio > 1, - lambda: keras.ops.cast( - intermediate_size * aspect_ratio, "int32" - ), - lambda: intermediate_size, - ) - else: - if aspect_ratio > 1: - resize_height = intermediate_size - resize_width = int(intermediate_size * aspect_ratio) - else: - resize_width = intermediate_size - resize_height = int(intermediate_size / aspect_ratio) - resized = keras.ops.image.resize( - inputs, - (resize_height, resize_width), - interpolation=self.interpolation, - antialias=self.antialias, - ) - top = (resize_height - target_height) // 2 - left = (resize_width - target_width) // 2 - cropped = resized[ - :, top : top + target_height, left : left + target_width, : - ] - return cropped - - def get_config(self): - config = super().get_config() - config.update( - { - "target_size": self.target_size, - "crop_pct": self.crop_pct, - "antialias": self.antialias, - "interpolation": self.interpolation, - } - ) - return config 
@keras_hub_export("keras_hub.layers.HGNetV2ImageConverter") -class HGNetV2ImageConverter(PreprocessingLayer): - """Preprocess raw images into model-ready inputs for HGNetV2 models. - - This layer converts raw images into inputs suitable for HGNetV2 models. - The preprocessing includes resizing, cropping, scaling, and normalization - steps tailored for HGNetV2 architectures. The conversion proceeds in the - following steps: - - 1. Resize and crop the image to `image_size` using a `ResizeThenCrop` layer - with the specified `crop_pct` if `image_size` is provided. If - `image_size` is `None`, this step is skipped. - 2. Scale the image by dividing pixel values by 255.0 to normalize to [0, 1]. - 3. Normalize the image by subtracting the `mean` and dividing by the `std` - (per channel) if both are provided. If `mean` or `std` is `None`, this - step is skipped. - - The layer accepts batched or unbatched image tensors in channels_last or - channels_first format, with shape `(..., height, width, channels)` or - `(..., channels, height, width)`, respectively. It can also handle - dictionary inputs with an `"images"` key for compatibility with bounding box - preprocessing. - - This layer can be instantiated using the `from_preset()` constructor to load - preprocessing configurations for specific HGNetV2 presets, ensuring - compatibility with pretrained models. - - Args: - image_size: `(int, int)` tuple or `None`. The output size of the image - (height, width), excluding the channels axis. If `None`, resizing - and cropping are skipped. - crop_pct: float. The cropping percentage used in the `ResizeThenCrop` - layer to compute the intermediate resize size. Defaults to 0.875. - mean: list or tuple of floats, or `None`. Per-channel mean values for - normalization. If provided, these are subtracted from the image - after scaling. If `None`, this step is skipped. - std: list or tuple of floats, or `None`. Per-channel standard deviation - values for normalization. If provided, the image is divided by these - after mean subtraction. If `None`, this step is skipped. - interpolation: String, the interpolation method for resizing. - Supports `"bilinear"`, `"nearest"`, `"bicubic"`, `"lanczos3"`, - `"lanczos5"`. Defaults to `"bilinear"`. - antialias: Whether to use an antialiasing filter when downsampling an - image. Defaults to `False`. - bounding_box_format: A string specifying the format of the bounding - boxes, one of `"xyxy"`, `"rel_xyxy"`, `"xywh"`, `"center_xywh"`, - `"yxyx"`, `"rel_yxyx"`. Specifies the format of the bounding boxes - which will be resized to `image_size` along with the image. To pass - bounding boxes to this layer, pass a dict with keys `"images"` and - `"bounding_boxes"` when calling the layer. - data_format: String, either `"channels_last"` or `"channels_first"`. - The ordering of the dimensions in the inputs. `"channels_last"` - corresponds to inputs with shape `(batch, height, width, channels)` - while `"channels_first"` corresponds to inputs with shape - `(batch, channels, height, width)`. It defaults to the - `image_data_format` value found in your Keras config file at - `~/.keras/keras.json`. If you never set it, then it will be - `"channels_last"`. - - Examples: - ```python - import keras - import numpy as np - - # Create an HGNetV2ImageConverter for a specific image size. 
- converter = keras_hub.layers.HGNetV2ImageConverter( - image_size=(224, 224), - crop_pct=0.965, - mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225], - interpolation="bicubic", - ) - images = np.random.randint(0, 256, size=(2, 512, 512, 3)).astype("float32") - processed_images = converter(images) - - # Load an HGNetV2ImageConverter from a preset. - converter = keras_hub.layers.HGNetV2ImageConverter.from_preset( - "hgnetv2_b5.ssld_stage1_in22k_in1k" - ) - processed_images = converter(images) - """ - +class HGNetV2ImageConverter(ImageConverter): backbone_cls = HGNetV2Backbone - - def __init__( - self, - image_size=None, - crop_pct=0.875, - mean=None, - std=None, - scale=None, - offset=None, - crop_to_aspect_ratio=True, - pad_to_aspect_ratio=False, - interpolation="bilinear", - antialias=False, - bounding_box_format="yxyx", - data_format=None, - **kwargs, - ): - super().__init__(**kwargs) - if crop_to_aspect_ratio and pad_to_aspect_ratio: - raise ValueError( - "Only one of 'crop_to_aspect_ratio' or 'pad_to_aspect_ratio' " - "can be True." - ) - - self.image_size_val = image_size - self.crop_pct = crop_pct - self.mean = mean - self.std = std - self.scale = scale - self.offset = offset - self.crop_to_aspect_ratio = crop_to_aspect_ratio - self.pad_to_aspect_ratio = pad_to_aspect_ratio - self.interpolation = interpolation - self.antialias = antialias - self.bounding_box_format = bounding_box_format - self.data_format = standardize_data_format(data_format) - - self.custom_resizing = None - if image_size is not None: - self.custom_resizing = ResizeThenCrop( - target_size=image_size, - crop_pct=crop_pct, - interpolation=interpolation, - antialias=antialias, - dtype=self.dtype_policy, - name="custom_resizing", - ) - self.built = True - - @property - def image_size(self): - return self.image_size_val - - @image_size.setter - def image_size(self, value): - self.image_size_val = value - if value is not None: - self.custom_resizing = ResizeThenCrop( - target_size=value, - crop_pct=self.crop_pct, - interpolation=self.interpolation, - antialias=self.antialias, - dtype=self.dtype_policy, - name="custom_resizing", - ) - else: - self.custom_resizing = None - - @preprocessing_function - def call(self, inputs): - if self.image_size is not None and self.custom_resizing is not None: - if in_tf_function(): - target_height, target_width = self.image_size - intermediate_size = tf.cast( - tf.math.floor(target_height / self.crop_pct), tf.int32 - ) - input_shape = tf.shape(inputs) - height = input_shape[-3] - width = input_shape[-2] - aspect_ratio = tf.cast(width, tf.float32) / tf.cast( - height, tf.float32 - ) - resize_height = tf.cond( - aspect_ratio > 1, - lambda: intermediate_size, - lambda: tf.cast( - tf.cast(intermediate_size, tf.float32) / aspect_ratio, - tf.int32, - ), - ) - resize_width = tf.cond( - aspect_ratio > 1, - lambda: tf.cast( - tf.cast(intermediate_size, tf.float32) * aspect_ratio, - tf.int32, - ), - lambda: intermediate_size, - ) - resized = tf.image.resize( - inputs, - [resize_height, resize_width], - method=self.interpolation, - antialias=self.antialias, - ) - top = (resize_height - target_height) // 2 - left = (resize_width - target_width) // 2 - cropped = resized[ - :, top : top + target_height, left : left + target_width, : - ] - inputs = cropped - else: - inputs = self.custom_resizing(inputs) - if isinstance(inputs, dict): - x = inputs["images"] - else: - x = inputs - if in_tf_function(): - x = tf.cast(x, self.compute_dtype) / 255.0 - else: - x = keras.ops.cast(x, self.compute_dtype) / 
255.0 - if self.mean is not None and self.std is not None: - mean = self._expand_non_channel_dims(self.mean, x) - std = self._expand_non_channel_dims(self.std, x) - x, mean = self._convert_types(x, mean, self.compute_dtype) - x, std = self._convert_types(x, std, self.compute_dtype) - x = (x - mean) / std - if self.scale is not None: - scale = self._expand_non_channel_dims(self.scale, x) - x, scale = self._convert_types(x, scale, self.compute_dtype) - x = x * scale - if self.offset is not None: - offset = self._expand_non_channel_dims(self.offset, x) - x, offset = self._convert_types(x, offset, x.dtype) - x = x + offset - if isinstance(inputs, dict): - inputs["images"] = x - else: - inputs = x - return inputs - - def _expand_non_channel_dims(self, value, inputs): - unbatched = len(keras.ops.shape(inputs)) == 3 - channels_first = self.data_format == "channels_first" - if unbatched: - broadcast_dims = (1, 2) if channels_first else (0, 1) - else: - broadcast_dims = (0, 2, 3) if channels_first else (0, 1, 2) - return np.expand_dims(value, broadcast_dims) - - def _convert_types(self, x, y, dtype): - if in_tf_function(): - return tf.cast(x, dtype), tf.cast(y, dtype) - x = keras.ops.cast(x, dtype) - y = keras.ops.cast(y, dtype) - if keras.backend.backend() == "torch": - y = y.to(x.device) - return x, y - - def get_config(self): - config = super().get_config() - config.update( - { - "image_size": self.image_size, - "crop_pct": self.crop_pct, - "mean": self.mean, - "std": self.std, - "scale": self.scale, - "offset": self.offset, - "crop_to_aspect_ratio": self.crop_to_aspect_ratio, - "pad_to_aspect_ratio": self.pad_to_aspect_ratio, - "interpolation": self.interpolation, - "antialias": self.antialias, - "bounding_box_format": self.bounding_box_format, - } - ) - return config - - @classproperty - def presets(cls): - return builtin_presets(cls) - - @classmethod - def from_preset(cls, preset, **kwargs): - loader = get_preset_loader(preset) - backbone_cls = loader.check_backbone_class() - if cls.backbone_cls != backbone_cls: - cls = find_subclass(preset, cls, backbone_cls) - return loader.load_image_converter(cls, **kwargs) - - def save_to_preset(self, preset_dir): - saver = get_preset_saver(preset_dir) - saver.save_image_converter(self) diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_layers.py b/keras_hub/src/models/hgnetv2/hgnetv2_layers.py index 7e23cb2e1c..1a9eb1fd12 100644 --- a/keras_hub/src/models/hgnetv2/hgnetv2_layers.py +++ b/keras_hub/src/models/hgnetv2/hgnetv2_layers.py @@ -772,12 +772,12 @@ class HGNetV2Stage(keras.layers.Layer): stage_mid_channels: list of int. Middle channels for each stage. stage_out_channels: list of int. Output channels for each stage. stage_num_blocks: list of int. Number of basic layers in each stage. - stage_numb_of_layers: list of int. Number of convolutional blocks in + stage_num_of_layers: list of int. Number of convolutional blocks in each basic layer. - stage_downsample: list of bool. Whether to downsample at the beginning + apply_downsample: list of bools. Whether to downsample at the beginning of each stage. - stage_light_block: list of bool. Whether to use lightweight blocks in - each stage. + use_lightweight_conv_block: list of bools. Whether to use HGNetV2 + lightweight convolutional block in the stage. stage_kernel_size: list of int. Kernel sizes for each stage. use_learnable_affine_block: bool. Whether to use learnable affine blocks. 
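The stage-level lists documented above are consumed by index: each `HGNetV2Stage` receives every per-stage list plus its own `stage_index` and reads only its slot. A minimal sketch of that lookup, using the two-stage test configuration from this patch (plain Python, no Keras required; the real layer builds convolution blocks from these values):

```python
# Sketch of the per-stage lookup pattern used by HGNetV2Stage. The tuples
# follow (in_channels, mid_channels, out_channels, num_blocks, num_layers,
# kernel_size), matching `stackwise_stage_filters` in this patch.
stackwise_stage_filters = [
    (32, 16, 64, 1, 1, 3),  # stage 0
    (64, 32, 128, 1, 1, 3),  # stage 1
]
stage_in_channels = [stage[0] for stage in stackwise_stage_filters]
stage_mid_channels = [stage[1] for stage in stackwise_stage_filters]
stage_out_channels = [stage[2] for stage in stackwise_stage_filters]
stage_num_blocks = [stage[3] for stage in stackwise_stage_filters]
stage_num_of_layers = [stage[4] for stage in stackwise_stage_filters]
stage_kernel_size = [stage[5] for stage in stackwise_stage_filters]

for stage_index in range(len(stackwise_stage_filters)):
    # Each stage keeps only its own entry from every list.
    config = {
        "in_channels": stage_in_channels[stage_index],
        "mid_channels": stage_mid_channels[stage_index],
        "out_channels": stage_out_channels[stage_index],
        "num_blocks": stage_num_blocks[stage_index],
        "num_layers_per_block": stage_num_of_layers[stage_index],
        "kernel_size": stage_kernel_size[stage_index],
    }
    print(stage_index, config)
```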
@@ -796,9 +796,9 @@ def __init__( stage_mid_channels, stage_out_channels, stage_num_blocks, - stage_numb_of_layers, - stage_downsample, - stage_light_block, + stage_num_of_layers, + apply_downsample, + use_lightweight_conv_block, stage_kernel_size, use_learnable_affine_block, stage_index: int, @@ -812,9 +812,9 @@ def __init__( self.stage_mid_channels = stage_mid_channels self.stage_out_channels = stage_out_channels self.stage_num_blocks = stage_num_blocks - self.stage_numb_of_layers = stage_numb_of_layers - self.stage_downsample = stage_downsample - self.stage_light_block = stage_light_block + self.stage_num_of_layers = stage_num_of_layers + self.apply_downsample = apply_downsample + self.use_lightweight_conv_block = use_lightweight_conv_block self.stage_kernel_size = stage_kernel_size self.use_learnable_affine_block = use_learnable_affine_block self.stage_index = stage_index @@ -825,11 +825,13 @@ def __init__( self.current_stage_mid_channels = stage_mid_channels[stage_index] self.current_stage_out_channels = stage_out_channels[stage_index] self.current_stage_num_blocks = stage_num_blocks[stage_index] - self.current_stage_num_layers_per_block = stage_numb_of_layers[ + self.current_stage_num_layers_per_block = stage_num_of_layers[ + stage_index + ] + self.current_stage_is_downsample_active = apply_downsample[stage_index] + self.current_stage_is_light_block = use_lightweight_conv_block[ stage_index ] - self.current_stage_is_downsample_active = stage_downsample[stage_index] - self.current_stage_is_light_block = stage_light_block[stage_index] self.current_stage_kernel_size = stage_kernel_size[stage_index] self.current_stage_use_lab = use_learnable_affine_block self.current_stage_drop_path = drop_path @@ -913,9 +915,9 @@ def get_config(self): "stage_mid_channels": self.stage_mid_channels, "stage_out_channels": self.stage_out_channels, "stage_num_blocks": self.stage_num_blocks, - "stage_numb_of_layers": self.stage_numb_of_layers, - "stage_downsample": self.stage_downsample, - "stage_light_block": self.stage_light_block, + "stage_num_of_layers": self.stage_num_of_layers, + "apply_downsample": self.apply_downsample, + "use_lightweight_conv_block": self.use_lightweight_conv_block, "stage_kernel_size": self.stage_kernel_size, "use_learnable_affine_block": self.use_learnable_affine_block, "stage_index": self.stage_index, diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_presets.py b/keras_hub/src/models/hgnetv2/hgnetv2_presets.py index e869a31cbb..80443d8cf5 100644 --- a/keras_hub/src/models/hgnetv2/hgnetv2_presets.py +++ b/keras_hub/src/models/hgnetv2/hgnetv2_presets.py @@ -1,6 +1,6 @@ # Metadata for loading pretrained model weights. 
backbone_presets = { - "hgnetv2_b4.ssld_stage2_ft_in1k": { + "hgnetv2_b4_ssld_stage2_ft_in1k": { "metadata": { "description": ( "HGNetV2 B4 model with 2-stage SSLD training, fine-tuned on " @@ -11,7 +11,7 @@ }, "kaggle_handle": "", }, - "hgnetv2_b5.ssld_stage1_in22k_in1k": { + "hgnetv2_b5_ssld_stage1_in22k_in1k": { "metadata": { "description": ( "HGNetV2 B5 model with 1-stage SSLD training, pre-trained on " @@ -22,7 +22,7 @@ }, "kaggle_handle": "", }, - "hgnetv2_b5.ssld_stage2_ft_in1k": { + "hgnetv2_b5_ssld_stage2_ft_in1k": { "metadata": { "description": ( "HGNetV2 B5 model with 2-stage SSLD training, fine-tuned on " @@ -33,7 +33,7 @@ }, "kaggle_handle": "", }, - "hgnetv2_b6.ssld_stage1_in22k_in1k": { + "hgnetv2_b6_ssld_stage1_in22k_in1k": { "metadata": { "description": ( "HGNetV2 B6 model with 1-stage SSLD training, pre-trained on " @@ -44,7 +44,7 @@ }, "kaggle_handle": "", }, - "hgnetv2_b6.ssld_stage2_ft_in1k": { + "hgnetv2_b6_ssld_stage2_ft_in1k": { "metadata": { "description": ( "HGNetV2 B6 model with 2-stage SSLD training, fine-tuned on " diff --git a/tools/checkpoint_conversion/convert_hgnetv2_checkpoints.py b/tools/checkpoint_conversion/convert_hgnetv2_checkpoints.py index 0c23694b8f..3a2837fc51 100644 --- a/tools/checkpoint_conversion/convert_hgnetv2_checkpoints.py +++ b/tools/checkpoint_conversion/convert_hgnetv2_checkpoints.py @@ -41,18 +41,18 @@ FLAGS = flags.FLAGS PRESET_MAP = { - "hgnetv2_b6.ssld_stage2_ft_in1k": "timm/hgnetv2_b6.ssld_stage2_ft_in1k", - "hgnetv2_b6.ssld_stage1_in22k_in1k": "timm/hgnetv2_b6.ssld_stage1_in22k_in1k", # noqa: E501 - "hgnetv2_b5.ssld_stage2_ft_in1k": "timm/hgnetv2_b5.ssld_stage2_ft_in1k", - "hgnetv2_b5.ssld_stage1_in22k_in1k": "timm/hgnetv2_b5.ssld_stage1_in22k_in1k", # noqa: E501 - "hgnetv2_b4.ssld_stage2_ft_in1k": "timm/hgnetv2_b4.ssld_stage2_ft_in1k", + "hgnetv2_b6_ssld_stage2_ft_in1k": "timm/hgnetv2_b6.ssld_stage2_ft_in1k", + "hgnetv2_b6_ssld_stage1_in22k_in1k": "timm/hgnetv2_b6.ssld_stage1_in22k_in1k", # noqa: E501 + "hgnetv2_b5_ssld_stage2_ft_in1k": "timm/hgnetv2_b5.ssld_stage2_ft_in1k", + "hgnetv2_b5_ssld_stage1_in22k_in1k": "timm/hgnetv2_b5.ssld_stage1_in22k_in1k", # noqa: E501 + "hgnetv2_b4_ssld_stage2_ft_in1k": "timm/hgnetv2_b4.ssld_stage2_ft_in1k", } LAB_FALSE_PRESETS = [ - "hgnetv2_b6.ssld_stage2_ft_in1k", - "hgnetv2_b6.ssld_stage1_in22k_in1k", - "hgnetv2_b5.ssld_stage2_ft_in1k", - "hgnetv2_b5.ssld_stage1_in22k_in1k", - "hgnetv2_b4.ssld_stage2_ft_in1k", + "hgnetv2_b6_ssld_stage2_ft_in1k", + "hgnetv2_b6_ssld_stage1_in22k_in1k", + "hgnetv2_b5_ssld_stage2_ft_in1k", + "hgnetv2_b5_ssld_stage1_in22k_in1k", + "hgnetv2_b4_ssld_stage2_ft_in1k", ] flags.DEFINE_string( @@ -64,98 +64,98 @@ HGNETV2_CONFIGS = { "hgnetv2_b0": { "stem_channels": [3, 16, 16], - "stage_in_channels": [16, 64, 256, 512], - "stage_mid_channels": [16, 32, 64, 128], - "stage_out_channels": [64, 256, 512, 1024], - "stage_num_blocks": [1, 1, 2, 1], - "stage_numb_of_layers": [3, 3, 3, 3], - "stage_downsample": [False, True, True, True], - "stage_light_block": [False, False, True, True], - "stage_kernel_size": [3, 3, 5, 5], + "stackwise_stage_filters": [ + [16, 16, 64, 1, 3, 3], + [64, 32, 256, 1, 3, 3], + [256, 64, 512, 2, 3, 5], + [512, 128, 1024, 1, 3, 5], + ], + "apply_downsample": [False, True, True, True], + "use_lightweight_conv_block": [False, False, True, True], "embedding_size": 16, "hidden_sizes": [64, 256, 512, 1024], "depths": [1, 1, 2, 1], }, "hgnetv2_b1": { "stem_channels": [3, 24, 32], - "stage_in_channels": [32, 64, 256, 512], - "stage_mid_channels": [32, 48, 96, 
192], - "stage_out_channels": [64, 256, 512, 1024], - "stage_num_blocks": [1, 1, 2, 1], - "stage_numb_of_layers": [3, 3, 3, 3], - "stage_downsample": [False, True, True, True], - "stage_light_block": [False, False, True, True], - "stage_kernel_size": [3, 3, 5, 5], + "stackwise_stage_filters": [ + [32, 32, 64, 1, 3, 3], + [64, 48, 256, 1, 3, 3], + [256, 96, 512, 2, 3, 5], + [512, 192, 1024, 1, 3, 5], + ], + "apply_downsample": [False, True, True, True], + "use_lightweight_conv_block": [False, False, True, True], "embedding_size": 32, "hidden_sizes": [64, 256, 512, 1024], "depths": [1, 1, 2, 1], }, "hgnetv2_b2": { "stem_channels": [3, 24, 32], - "stage_in_channels": [32, 96, 384, 768], - "stage_mid_channels": [32, 64, 128, 256], - "stage_out_channels": [96, 384, 768, 1536], - "stage_num_blocks": [1, 1, 3, 1], - "stage_numb_of_layers": [4, 4, 4, 4], - "stage_downsample": [False, True, True, True], - "stage_light_block": [False, False, True, True], - "stage_kernel_size": [3, 3, 5, 5], + "stackwise_stage_filters": [ + [32, 32, 96, 1, 4, 3], + [96, 64, 384, 1, 4, 3], + [384, 128, 768, 3, 4, 5], + [768, 256, 1536, 1, 4, 5], + ], + "apply_downsample": [False, True, True, True], + "use_lightweight_conv_block": [False, False, True, True], "embedding_size": 32, "hidden_sizes": [96, 384, 768, 1536], "depths": [1, 1, 3, 1], }, "hgnetv2_b3": { "stem_channels": [3, 24, 32], - "stage_in_channels": [32, 128, 512, 1024], - "stage_mid_channels": [32, 64, 128, 256], - "stage_out_channels": [128, 512, 1024, 2048], - "stage_num_blocks": [1, 1, 3, 1], - "stage_numb_of_layers": [5, 5, 5, 5], - "stage_downsample": [False, True, True, True], - "stage_light_block": [False, False, True, True], - "stage_kernel_size": [3, 3, 5, 5], + "stackwise_stage_filters": [ + [32, 32, 128, 1, 5, 3], + [128, 64, 512, 1, 5, 3], + [512, 128, 1024, 3, 5, 5], + [1024, 256, 2048, 1, 5, 5], + ], + "apply_downsample": [False, True, True, True], + "use_lightweight_conv_block": [False, False, True, True], "embedding_size": 32, "hidden_sizes": [128, 512, 1024, 2048], "depths": [1, 1, 3, 1], }, "hgnetv2_b4": { "stem_channels": [3, 32, 48], - "stage_in_channels": [48, 128, 512, 1024], - "stage_mid_channels": [48, 96, 192, 384], - "stage_out_channels": [128, 512, 1024, 2048], - "stage_num_blocks": [1, 1, 3, 1], - "stage_numb_of_layers": [6, 6, 6, 6], - "stage_downsample": [False, True, True, True], - "stage_light_block": [False, False, True, True], - "stage_kernel_size": [3, 3, 5, 5], + "stackwise_stage_filters": [ + [48, 48, 128, 1, 6, 3], + [128, 96, 512, 1, 6, 3], + [512, 192, 1024, 3, 6, 5], + [1024, 384, 2048, 1, 6, 5], + ], + "apply_downsample": [False, True, True, True], + "use_lightweight_conv_block": [False, False, True, True], "embedding_size": 48, "hidden_sizes": [128, 512, 1024, 2048], "depths": [1, 1, 3, 1], }, "hgnetv2_b5": { "stem_channels": [3, 32, 64], - "stage_in_channels": [64, 128, 512, 1024], - "stage_mid_channels": [64, 128, 256, 512], - "stage_out_channels": [128, 512, 1024, 2048], - "stage_num_blocks": [1, 2, 5, 2], - "stage_numb_of_layers": [6, 6, 6, 6], - "stage_downsample": [False, True, True, True], - "stage_light_block": [False, False, True, True], - "stage_kernel_size": [3, 3, 5, 5], + "stackwise_stage_filters": [ + [64, 64, 128, 1, 6, 3], + [128, 128, 512, 2, 6, 3], + [512, 256, 1024, 5, 6, 5], + [1024, 512, 2048, 2, 6, 5], + ], + "apply_downsample": [False, True, True, True], + "use_lightweight_conv_block": [False, False, True, True], "embedding_size": 64, "hidden_sizes": [128, 512, 1024, 2048], "depths": [1, 2, 
5, 2], }, "hgnetv2_b6": { "stem_channels": [3, 48, 96], - "stage_in_channels": [96, 192, 512, 1024], - "stage_mid_channels": [96, 192, 384, 768], - "stage_out_channels": [192, 512, 1024, 2048], - "stage_num_blocks": [2, 3, 6, 3], - "stage_numb_of_layers": [6, 6, 6, 6], - "stage_downsample": [False, True, True, True], - "stage_light_block": [False, False, True, True], - "stage_kernel_size": [3, 3, 5, 5], + "stackwise_stage_filters": [ + [96, 96, 192, 2, 6, 3], + [192, 192, 512, 3, 6, 3], + [512, 384, 1024, 6, 6, 5], + [1024, 768, 2048, 3, 6, 5], + ], + "apply_downsample": [False, True, True, True], + "use_lightweight_conv_block": [False, False, True, True], "embedding_size": 96, "hidden_sizes": [192, 512, 1024, 2048], "depths": [2, 3, 6, 3], @@ -181,7 +181,6 @@ def convert_model(hf_config, architecture, preset_name): backbone = HGNetV2Backbone( image_shape=(image_size, image_size, 3), - initializer_range=0.02, depths=config["depths"], embedding_size=config["embedding_size"], hidden_sizes=config["hidden_sizes"], @@ -189,14 +188,9 @@ def convert_model(hf_config, architecture, preset_name): hidden_act="relu", use_learnable_affine_block=use_lab, num_channels=3, - stage_in_channels=config["stage_in_channels"], - stage_mid_channels=config["stage_mid_channels"], - stage_out_channels=config["stage_out_channels"], - stage_num_blocks=config["stage_num_blocks"], - stage_numb_of_layers=config["stage_numb_of_layers"], - stage_downsample=config["stage_downsample"], - stage_light_block=config["stage_light_block"], - stage_kernel_size=config["stage_kernel_size"], + stackwise_stage_filters=config["stackwise_stage_filters"], + apply_downsample=config["apply_downsample"], + use_lightweight_conv_block=config["use_lightweight_conv_block"], ) pretrained_cfg = hf_config["pretrained_cfg"] image_size = ( @@ -205,13 +199,13 @@ def convert_model(hf_config, architecture, preset_name): ) mean = pretrained_cfg["mean"] std = pretrained_cfg["std"] - crop_pct = pretrained_cfg.get("crop_pct", 0.875) interpolation = pretrained_cfg["interpolation"] + scale = [1 / (255.0 * s) for s in std] if std else 1 / 255.0 + offset = [-m / s for m, s in zip(mean, std)] if mean and std else 0 image_converter = HGNetV2ImageConverter( image_size=image_size, - crop_pct=crop_pct, - mean=mean, - std=std, + scale=scale, + offset=offset, interpolation=interpolation, antialias=True if interpolation == "bicubic" else False, ) @@ -222,9 +216,6 @@ def convert_model(hf_config, architecture, preset_name): backbone=backbone, preprocessor=preprocessor, num_classes=hf_config["num_classes"], - initializer_range=0.02, - head_filters=hf_model.head_hidden_size, - use_learnable_affine_block_head=use_lab, ) return keras_model, config, image_size @@ -407,13 +398,13 @@ def convert_image_converter(hf_config): ) mean = pretrained_cfg["mean"] std = pretrained_cfg["std"] - crop_pct = pretrained_cfg.get("crop_pct", 0.875) interpolation = pretrained_cfg["interpolation"] + scale = [1 / (255.0 * s) for s in std] if std else 1 / 255.0 + offset = [-m / s for m, s in zip(mean, std)] if mean and std else 0 image_converter = HGNetV2ImageConverter( image_size=image_size, - crop_pct=crop_pct, - mean=mean, - std=std, + scale=scale, + offset=offset, interpolation=interpolation, antialias=True if interpolation == "bicubic" else False, ) From e6aa9e4744adb27bf333131b5ea499ac4acabe63 Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Sat, 28 Jun 2025 18:29:39 +0400 Subject: [PATCH 5/8] bug: Unexpected argument error in JAX with Keras 3.5 --- 
keras_hub/src/models/hgnetv2/hgnetv2_image_classifier_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier_test.py b/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier_test.py index 907902f73d..343a4cb8b9 100644 --- a/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier_test.py +++ b/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier_test.py @@ -26,7 +26,6 @@ def setUp(self): self.train_data = (self.images, self.labels) # Setup model. self.backbone = HGNetV2Backbone( - initializer_range=0.02, depths=[1, 1], embedding_size=32, hidden_sizes=[64, 128], From 68452ab14fa57b5099650979ee128c38e7b2a4a5 Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Fri, 4 Jul 2025 13:33:09 +0400 Subject: [PATCH 6/8] small addition for the D-FINE to come: No changes to the existing HGNetV2 --- .../src/models/hgnetv2/hgnetv2_backbone.py | 17 +++++++---- .../models/hgnetv2/hgnetv2_backbone_test.py | 28 ++++++++++++++++--- 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_backbone.py b/keras_hub/src/models/hgnetv2/hgnetv2_backbone.py index 7127cc2d6d..49f79ff0d9 100644 --- a/keras_hub/src/models/hgnetv2/hgnetv2_backbone.py +++ b/keras_hub/src/models/hgnetv2/hgnetv2_backbone.py @@ -32,7 +32,7 @@ class HGNetV2Backbone(Backbone): stage_out_channels, stage_num_blocks, stage_num_of_layers, stage_kernel_size). - stage_in_channels: int, input channels for the stage - - stage_mid_channels: int middle channels for the stage + - stage_mid_channels: int, middle channels for the stage - stage_out_channels: int, output channels for the stage - stage_num_blocks: int, number of blocks in the stage - stage_num_of_layers: int, number of layers in each block @@ -45,6 +45,9 @@ class HGNetV2Backbone(Backbone): data_format: `None` or str, the data format ('channels_last' or 'channels_first'). If not specified, defaults to the `image_data_format` value in your Keras config. + out_features: list of str or `None`, the names of the output features to + return. If `None`, returns all available features from all stages. + Defaults to `None`. dtype: `None` or str or `keras.mixed_precision.DTypePolicy`, the data type for computations and weights. 
@@ -96,6 +99,7 @@ def __init__( use_lightweight_conv_block, image_shape=(None, None, 3), data_format=None, + out_features=None, dtype=None, **kwargs, ): @@ -136,10 +140,12 @@ def __init__( name=f"{name}_encoder" if name else "encoder", dtype=dtype, ) - self.stage_names = [ - f"stage{i}" for i in range(len(stackwise_stage_filters)) + self.stage_names = ["stem"] + [ + f"stage{i + 1}" for i in range(len(stackwise_stage_filters)) ] - self.out_features = self.stage_names + self.out_features = ( + out_features if out_features is not None else self.stage_names + ) # === Functional Model === pixel_values = keras.layers.Input( @@ -148,7 +154,7 @@ def __init__( embedding_output = self.embedder_layer(pixel_values) all_encoder_hidden_states_tuple = self.encoder_layer(embedding_output) feature_maps_output = { - stage_name: all_encoder_hidden_states_tuple[idx + 1] + stage_name: all_encoder_hidden_states_tuple[idx] for idx, stage_name in enumerate(self.stage_names) if stage_name in self.out_features } @@ -184,6 +190,7 @@ def get_config(self): "apply_downsample": self.apply_downsample, "use_lightweight_conv_block": self.use_lightweight_conv_block, "image_shape": self.image_shape, + "out_features": self.out_features, "data_format": self.data_format, } ) diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_backbone_test.py b/keras_hub/src/models/hgnetv2/hgnetv2_backbone_test.py index abeb6bc6fe..448e36f1c7 100644 --- a/keras_hub/src/models/hgnetv2/hgnetv2_backbone_test.py +++ b/keras_hub/src/models/hgnetv2/hgnetv2_backbone_test.py @@ -33,6 +33,9 @@ def setUp(self): "stackwise_stage_filters": self.stackwise_stage_filters, "apply_downsample": [False, True], "use_lightweight_conv_block": [False, False], + # Explicitly pass the out_features arg to ensure comprehensive + # test coverage for D-FINE. 
+ "out_features": ["stem", "stage1", "stage2"], } self.input_data = keras.ops.convert_to_tensor( np.random.rand(self.batch_size, *self.default_input_shape).astype( @@ -46,28 +49,43 @@ def setUp(self): [False, True], [False, False], 2, - {"stage0": (2, 16, 16, 64), "stage1": (2, 8, 8, 128)}, + { + "stem": (2, 16, 16, 32), + "stage1": (2, 16, 16, 64), + "stage2": (2, 8, 8, 128), + }, ), ( "early_downsample_light_blocks", [True, True], [True, True], 2, - {"stage0": (2, 8, 8, 64), "stage1": (2, 4, 4, 128)}, + { + "stem": (2, 16, 16, 32), + "stage1": (2, 8, 8, 64), + "stage2": (2, 4, 4, 128), + }, ), ( "single_stage_no_downsample", [False], [False], 1, - {"stage0": (2, 16, 16, 64)}, + { + "stem": (2, 16, 16, 32), + "stage1": (2, 16, 16, 64), + }, ), ( "all_no_downsample", [False, False], [False, False], 2, - {"stage0": (2, 16, 16, 64), "stage1": (2, 16, 16, 128)}, + { + "stem": (2, 16, 16, 32), + "stage1": (2, 16, 16, 64), + "stage2": (2, 16, 16, 128), + }, ), ) def test_backbone_basics( @@ -86,6 +104,8 @@ def test_backbone_basics( "stackwise_stage_filters": test_filters, "apply_downsample": apply_downsample, "use_lightweight_conv_block": use_lightweight_conv_block, + "out_features": ["stem"] + + [f"stage{i + 1}" for i in range(num_stages)], } self.run_vision_backbone_test( cls=HGNetV2Backbone, From 1c4be5fe4ba9af928b217cb7f5ce4eb70e874ea0 Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Tue, 8 Jul 2025 14:31:24 +0400 Subject: [PATCH 7/8] D-FINE JIT compile: Remove non-essential conditional statement --- keras_hub/src/models/hgnetv2/hgnetv2_layers.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_layers.py b/keras_hub/src/models/hgnetv2/hgnetv2_layers.py index 1a9eb1fd12..3c08a57829 100644 --- a/keras_hub/src/models/hgnetv2/hgnetv2_layers.py +++ b/keras_hub/src/models/hgnetv2/hgnetv2_layers.py @@ -475,13 +475,6 @@ def compute_output_shape(self, input_shape): return final_shape def call(self, pixel_values, training=None): - num_channels_check = keras.ops.shape(pixel_values)[self.channel_axis] - if num_channels_check != self.num_channels: - raise ValueError( - "Make sure that the channel dimension of the pixel values " - "match with the one set in the configuration. Expected " - f"{self.num_channels} but got {num_channels_check}." - ) embedding = self.stem1_layer(pixel_values, training=training) embedding_padded_for_2a_and_pool = self.padding1(embedding) emb_stem_2a = self.stem2a_layer( From ce9b0787e662d33a944bae58ced8c33ba5961fcd Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Fri, 11 Jul 2025 12:46:00 +0400 Subject: [PATCH 8/8] refactor: Address reviews and fix some nits --- .../src/models/hgnetv2/hgnetv2_backbone.py | 20 +++++++---------- .../models/hgnetv2/hgnetv2_backbone_test.py | 1 - .../src/models/hgnetv2/hgnetv2_encoder.py | 1 + .../hgnetv2/hgnetv2_image_classifier.py | 22 +++++++++---------- .../hgnetv2/hgnetv2_image_classifier_test.py | 5 ----- .../src/models/hgnetv2/hgnetv2_layers.py | 17 ++++++-------- .../convert_hgnetv2_checkpoints.py | 11 +++------- 7 files changed, 29 insertions(+), 48 deletions(-) diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_backbone.py b/keras_hub/src/models/hgnetv2/hgnetv2_backbone.py index 49f79ff0d9..12407b0f75 100644 --- a/keras_hub/src/models/hgnetv2/hgnetv2_backbone.py +++ b/keras_hub/src/models/hgnetv2/hgnetv2_backbone.py @@ -26,7 +26,6 @@ class HGNetV2Backbone(Backbone): hidden_act: str, the activation function for hidden layers. 
use_learnable_affine_block: bool, whether to use learnable affine transformations. - num_channels: int, the number of channels in the input image. stackwise_stage_filters: list of tuples, where each tuple contains configuration for a stage: (stage_in_channels, stage_mid_channels, stage_out_channels, stage_num_blocks, stage_num_of_layers, @@ -71,7 +70,6 @@ class HGNetV2Backbone(Backbone): stem_channels=[3, 16, 32], hidden_act="relu", use_learnable_affine_block=False, - num_channels=3, stackwise_stage_filters=[ (32, 16, 64, 1, 1, 3), # Stage 0 (64, 32, 128, 2, 1, 3), # Stage 1 @@ -93,7 +91,6 @@ def __init__( stem_channels, hidden_act, use_learnable_affine_block, - num_channels, stackwise_stage_filters, apply_downsample, use_lightweight_conv_block, @@ -107,19 +104,20 @@ def __init__( data_format = standardize_data_format(data_format) channel_axis = -1 if data_format == "channels_last" else 1 self.image_shape = image_shape - stage_in_channels = [stage[0] for stage in stackwise_stage_filters] - stage_mid_channels = [stage[1] for stage in stackwise_stage_filters] - stage_out_filters = [stage[2] for stage in stackwise_stage_filters] - stage_num_blocks = [stage[3] for stage in stackwise_stage_filters] - stage_num_of_layers = [stage[4] for stage in stackwise_stage_filters] - stage_kernel_size = [stage[5] for stage in stackwise_stage_filters] + ( + stage_in_channels, + stage_mid_channels, + stage_out_filters, + stage_num_blocks, + stage_num_of_layers, + stage_kernel_size, + ) = zip(*stackwise_stage_filters) # === Layers === self.embedder_layer = HGNetV2Embeddings( stem_channels=stem_channels, hidden_act=hidden_act, use_learnable_affine_block=use_learnable_affine_block, - num_channels=num_channels, data_format=data_format, channel_axis=channel_axis, name=f"{name}_embedder" if name else "embedder", @@ -169,7 +167,6 @@ def __init__( self.stem_channels = stem_channels self.hidden_act = hidden_act self.use_learnable_affine_block = use_learnable_affine_block - self.num_channels = num_channels self.stackwise_stage_filters = stackwise_stage_filters self.apply_downsample = apply_downsample self.use_lightweight_conv_block = use_lightweight_conv_block @@ -185,7 +182,6 @@ def get_config(self): "stem_channels": self.stem_channels, "hidden_act": self.hidden_act, "use_learnable_affine_block": self.use_learnable_affine_block, - "num_channels": self.num_channels, "stackwise_stage_filters": self.stackwise_stage_filters, "apply_downsample": self.apply_downsample, "use_lightweight_conv_block": self.use_lightweight_conv_block, diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_backbone_test.py b/keras_hub/src/models/hgnetv2/hgnetv2_backbone_test.py index 448e36f1c7..31c48d2c29 100644 --- a/keras_hub/src/models/hgnetv2/hgnetv2_backbone_test.py +++ b/keras_hub/src/models/hgnetv2/hgnetv2_backbone_test.py @@ -24,7 +24,6 @@ def setUp(self): "stem_channels": self.stem_channels, "hidden_act": "relu", "use_learnable_affine_block": False, - "num_channels": self.num_channels, "image_shape": self.default_input_shape, "depths": [1] * self.default_num_stages, "hidden_sizes": [ diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_encoder.py b/keras_hub/src/models/hgnetv2/hgnetv2_encoder.py index 201cc72b7e..9b4f2c98bf 100644 --- a/keras_hub/src/models/hgnetv2/hgnetv2_encoder.py +++ b/keras_hub/src/models/hgnetv2/hgnetv2_encoder.py @@ -142,6 +142,7 @@ def get_config(self): "stage_kernel_size": self.stage_kernel_size, "use_learnable_affine_block": self.use_learnable_affine_block, "data_format": self.data_format, + "channel_axis": 
self.channel_axis, } ) return config diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py b/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py index df60e03ef7..bed9831563 100644 --- a/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py +++ b/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py @@ -119,19 +119,12 @@ def __init__( head_dtype = head_dtype or backbone.dtype_policy data_format = getattr(backbone, "data_format", "channels_last") channel_axis = -1 if data_format == "channels_last" else 1 - - # NOTE: This isn't in the usual order because the config is needed - # before layer initialization and the functional model. - # === Config === - self.num_classes = num_classes - self.pooling = pooling - self.activation = activation - self.dropout = dropout self.head_filters = ( head_filters if head_filters is not None else backbone.hidden_sizes[-1] ) + self.activation = activation # === Layers === self.backbone = backbone @@ -149,13 +142,13 @@ def __init__( name="head_last", dtype=head_dtype, ) - if self.pooling == "avg": + if pooling == "avg": self.pooler = keras.layers.GlobalAveragePooling2D( data_format=data_format, dtype=head_dtype, name=f"{name}_avg_pool" if name else "avg_pool", ) - elif self.pooling == "max": + elif pooling == "max": self.pooler = keras.layers.GlobalMaxPooling2D( data_format=data_format, dtype=head_dtype, @@ -173,11 +166,11 @@ def __init__( name=f"{name}_flatten" if name else "flatten", ) self.output_dropout = keras.layers.Dropout( - rate=self.dropout, + rate=dropout, dtype=head_dtype, name=f"{name}_output_dropout" if name else "output_dropout", ) - if self.num_classes > 0: + if num_classes > 0: self.output_dense = keras.layers.Dense( units=num_classes, activation=activation, @@ -204,6 +197,11 @@ def __init__( **kwargs, ) + # === Config === + self.pooling = pooling + self.dropout = dropout + self.num_classes = num_classes + def get_config(self): config = Task.get_config(self) config.update( diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier_test.py b/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier_test.py index 343a4cb8b9..f294a23b72 100644 --- a/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier_test.py +++ b/keras_hub/src/models/hgnetv2/hgnetv2_image_classifier_test.py @@ -32,7 +32,6 @@ def setUp(self): stem_channels=[self.image_shape[-1], 16, 32], hidden_act="relu", use_learnable_affine_block=False, - num_channels=self.image_shape[-1], stackwise_stage_filters=[ [32, 16, 64, 1, 1, 3], [64, 32, 128, 1, 1, 3], @@ -60,10 +59,6 @@ def setUp(self): "preprocessor": self.preprocessor, "num_classes": self.num_classes, } - self.expected_backbone_output_shapes = { - "stage0": (self.batch_size, 16, 16, 64), - "stage1": (self.batch_size, 8, 8, 128), - } self.preset_image_shape = (224, 224, 3) self.images_for_presets = np.ones( (self.batch_size, *self.preset_image_shape), dtype="float32" diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_layers.py b/keras_hub/src/models/hgnetv2/hgnetv2_layers.py index 3c08a57829..4424e45283 100644 --- a/keras_hub/src/models/hgnetv2/hgnetv2_layers.py +++ b/keras_hub/src/models/hgnetv2/hgnetv2_layers.py @@ -28,18 +28,14 @@ def build(self, input_shape): shape=(), initializer=keras.initializers.Constant(self.scale_value), trainable=True, - dtype=self.dtype_policy.name - if isinstance(self.dtype_policy, keras.mixed_precision.DTypePolicy) - else self.dtype_policy, + dtype=self.dtype, ) self.bias = self.add_weight( name="bias", shape=(), initializer=keras.initializers.Constant(self.bias_value), 
trainable=True, - dtype=self.dtype_policy.name - if isinstance(self.dtype_policy, keras.mixed_precision.DTypePolicy) - else self.dtype_policy, + dtype=self.dtype, ) super().build(input_shape) @@ -197,6 +193,7 @@ def get_config(self): "activation": self.activation_name, "use_learnable_affine_block": self.use_learnable_affine_block, "data_format": self.data_format, + "channel_axis": self.channel_axis, } ) return config @@ -290,6 +287,7 @@ def get_config(self): "kernel_size": self.kernel_size, "use_learnable_affine_block": self.use_learnable_affine_block, "data_format": self.data_format, + "channel_axis": self.channel_axis, } ) return config @@ -315,7 +313,6 @@ class HGNetV2Embeddings(keras.layers.Layer): layers. use_learnable_affine_block: bool. Whether to use learnable affine blocks in the convolutional layers. - num_channels: int. Number of input channels (e.g., 3 for RGB images). data_format: string, optional. Data format of the input. Defaults to None. channel_axis: int, optional. Axis of the channel dimension. Defaults to @@ -328,7 +325,6 @@ def __init__( stem_channels, hidden_act, use_learnable_affine_block, - num_channels, data_format=None, channel_axis=None, **kwargs, @@ -337,7 +333,6 @@ def __init__( self.stem_channels = stem_channels self.hidden_act = hidden_act self.use_learnable_affine_block = use_learnable_affine_block - self.num_channels = num_channels self.data_format = data_format self.channel_axis = channel_axis self.stem1_layer = HGNetV2ConvLayer( @@ -503,8 +498,8 @@ def get_config(self): "stem_channels": self.stem_channels, "hidden_act": self.hidden_act, "use_learnable_affine_block": self.use_learnable_affine_block, - "num_channels": self.num_channels, "data_format": self.data_format, + "channel_axis": self.channel_axis, } ) return config @@ -746,6 +741,7 @@ def get_config(self): "drop_path": self.drop_path_rate, "use_learnable_affine_block": self.use_learnable_affine_block, "data_format": self.data_format, + "channel_axis": self.channel_axis, } ) return config @@ -916,6 +912,7 @@ def get_config(self): "stage_index": self.stage_index, "drop_path": self.drop_path, "data_format": self.data_format, + "channel_axis": self.channel_axis, } ) return config diff --git a/tools/checkpoint_conversion/convert_hgnetv2_checkpoints.py b/tools/checkpoint_conversion/convert_hgnetv2_checkpoints.py index 3a2837fc51..cca794e5ee 100644 --- a/tools/checkpoint_conversion/convert_hgnetv2_checkpoints.py +++ b/tools/checkpoint_conversion/convert_hgnetv2_checkpoints.py @@ -31,6 +31,7 @@ from keras_hub.src.models.hgnetv2.hgnetv2_image_converter import ( HGNetV2ImageConverter, ) +from keras_hub.src.models.hgnetv2.hgnetv2_layers import HGNetV2ConvLayerLight from keras_hub.src.models.hgnetv2.hgnetv2_layers import ( HGNetV2LearnableAffineBlock, ) @@ -163,7 +164,7 @@ } -def load_hf_config(hf_preset): +def load_hf_config(hf_preset, hf_model): config_path = keras.utils.get_file( origin=f"https://huggingface.co/{hf_preset}/raw/main/config.json", cache_subdir=f"hf_models/{hf_preset}", @@ -187,7 +188,6 @@ def convert_model(hf_config, architecture, preset_name): stem_channels=config["stem_channels"], hidden_act="relu", use_learnable_affine_block=use_lab, - num_channels=3, stackwise_stage_filters=config["stackwise_stage_filters"], apply_downsample=config["apply_downsample"], use_lightweight_conv_block=config["use_lightweight_conv_block"], @@ -348,10 +348,6 @@ def port_embeddings(keras_embeddings, weight_key_prefix): port_conv(keras_embeddings.stem4_layer, f"{weight_key_prefix}.stem4") def 
port_basic_layer(keras_basic_layer, weight_key_prefix): - from keras_hub.src.models.hgnetv2.hgnetv2_layers import ( - HGNetV2ConvLayerLight, - ) - for i, layer in enumerate(keras_basic_layer.layer_list): layer_prefix = f"{weight_key_prefix}.layers.{i}" if isinstance(layer, HGNetV2ConvLayerLight): @@ -455,7 +451,6 @@ def main(_): os.makedirs(preset) print(f"\nšŸƒ Converting {preset}") - global hf_model hf_model = create_model(hf_preset, pretrained=True) safetensors_file = keras.utils.get_file( origin=f"https://huggingface.co/{hf_preset}/resolve/main/model.safetensors", @@ -476,7 +471,7 @@ def main(_): ) state_dict = safetensors.torch.load_file(safetensors_file) hf_model.eval() - hf_config = load_hf_config(hf_preset) + hf_config = load_hf_config(hf_preset, hf_model) architecture = hf_config["architecture"] keras_model, _, _ = convert_model(hf_config, architecture, preset) print("āœ… KerasHub model loaded.")
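The scale/offset derivation used in `convert_model` and `convert_image_converter` above replaces the old `crop_pct`/`mean`/`std` arguments. A standalone sketch checking the algebra behind that rewrite (assumed ImageNet-style statistics, NumPy only; the script itself reads `mean`/`std` from the timm `pretrained_cfg`):

```python
import numpy as np

# Assumed example statistics; the conversion script derives scale/offset
# from the checkpoint config in the same way.
mean = np.array([0.485, 0.456, 0.406], dtype="float32")
std = np.array([0.229, 0.224, 0.225], dtype="float32")
scale = 1.0 / (255.0 * std)
offset = -mean / std

images = np.random.randint(0, 256, size=(2, 8, 8, 3)).astype("float32")
# timm-style normalization and the scale/offset form give the same result:
# (x / 255 - mean) / std == x * (1 / (255 * std)) + (-mean / std).
timm_style = (images / 255.0 - mean) / std
converter_style = images * scale + offset
np.testing.assert_allclose(timm_style, converter_style, rtol=1e-5)
```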