diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index ecdebaa7f3..de998d758c 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -75,6 +75,9 @@ from keras_hub.src.models.cspnet.cspnet_image_converter import ( CSPNetImageConverter as CSPNetImageConverter, ) +from keras_hub.src.models.d_fine.d_fine_image_converter import ( + DFineImageConverter as DFineImageConverter, +) from keras_hub.src.models.deeplab_v3.deeplab_v3_image_converter import ( DeepLabV3ImageConverter as DeepLabV3ImageConverter, ) diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index 741c0febf0..19d4fce905 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -108,6 +108,15 @@ from keras_hub.src.models.cspnet.cspnet_image_classifier_preprocessor import ( CSPNetImageClassifierPreprocessor as CSPNetImageClassifierPreprocessor, ) +from keras_hub.src.models.d_fine.d_fine_backbone import ( + DFineBackbone as DFineBackbone, +) +from keras_hub.src.models.d_fine.d_fine_object_detector import ( + DFineObjectDetector as DFineObjectDetector, +) +from keras_hub.src.models.d_fine.d_fine_object_detector_preprocessor import ( + DFineObjectDetectorPreprocessor as DFineObjectDetectorPreprocessor, +) from keras_hub.src.models.deberta_v3.deberta_v3_backbone import ( DebertaV3Backbone as DebertaV3Backbone, ) diff --git a/keras_hub/src/models/d_fine/__init__.py b/keras_hub/src/models/d_fine/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/keras_hub/src/models/d_fine/d_fine_attention.py b/keras_hub/src/models/d_fine/d_fine_attention.py new file mode 100644 index 0000000000..9de071642c --- /dev/null +++ b/keras_hub/src/models/d_fine/d_fine_attention.py @@ -0,0 +1,461 @@ +import math + +import keras + +from keras_hub.src.models.d_fine.d_fine_utils import ( + multi_scale_deformable_attention_v2, +) + + +class DFineMultiscaleDeformableAttention(keras.layers.Layer): + """Multi-scale deformable attention layer for D-FINE models. + + This layer implements the multi-scale deformable attention mechanism, which + is the core of the cross-attention in each `DFineDecoderLayer`. It allows + the model to attend to a small set of key sampling points around a reference + point across multiple feature levels from the encoder. + + The layer computes sampling locations and attention weights based on the + input queries, enabling the model to focus on relevant features across + multiple feature levels and spatial positions. + + Args: + hidden_dim: int, Hidden dimension size for the attention mechanism. + decoder_attention_heads: int, Number of attention heads. + num_feature_levels: int, Number of feature levels to attend to. + decoder_offset_scale: float, Scaling factor for sampling offsets. + decoder_method: str, Method used for deformable attention computation. + decoder_n_points: int or list, Number of sampling points per level. + If int, the same number of points is used for all levels. + If list, specifies points for each level individually. + num_queries: int, Number of queries in the attention mechanism. + spatial_shapes: list, List of spatial shapes for different + feature levels. + **kwargs: Additional keyword arguments passed to the parent class. 
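+
+    Example:
+
+    A minimal construction-only sketch; the argument values below are
+    illustrative assumptions rather than recommended settings.
+    ```python
+    attention = DFineMultiscaleDeformableAttention(
+        hidden_dim=256,
+        decoder_attention_heads=8,
+        num_feature_levels=2,
+        decoder_offset_scale=0.5,
+        decoder_method="default",
+        decoder_n_points=[4, 4],
+        num_queries=300,
+        spatial_shapes=[(32, 32), (16, 16)],
+    )
+    ```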
+ """ + + def __init__( + self, + hidden_dim, + decoder_attention_heads, + num_feature_levels, + decoder_offset_scale, + decoder_method, + decoder_n_points, + num_queries, + spatial_shapes, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_dim = hidden_dim + self.num_queries = num_queries + self.n_heads = decoder_attention_heads + self.n_levels = num_feature_levels + self.offset_scale = decoder_offset_scale + self.decoder_method = decoder_method + self.decoder_n_points = decoder_n_points + self.spatial_shapes = spatial_shapes + if isinstance(self.decoder_n_points, list): + self.num_points = self.decoder_n_points + else: + self.num_points = [ + self.decoder_n_points for _ in range(self.n_levels) + ] + self._num_points_scale = [ + 1.0 / n_points_at_level + for n_points_at_level in self.num_points + for _ in range(n_points_at_level) + ] + self.total_points = self.n_heads * sum(self.num_points) + self.ms_deformable_attn_core = multi_scale_deformable_attention_v2 + + def build(self, input_shape): + sampling_offsets_output_shape = ( + input_shape[1], + self.n_heads, + sum(self.num_points), + 2, + ) + self.sampling_offsets = keras.layers.EinsumDense( + "abc,cdef->abdef", + output_shape=sampling_offsets_output_shape, + bias_axes="def", + kernel_initializer="zeros", + bias_initializer="zeros", + name="sampling_offsets", + dtype=self.dtype_policy, + ) + self.sampling_offsets.build(input_shape) + attention_weights_output_shape = ( + input_shape[1], + self.n_heads, + sum(self.num_points), + ) + self.attention_weights = keras.layers.EinsumDense( + "abc,cde->abde", + output_shape=attention_weights_output_shape, + bias_axes="de", + kernel_initializer="zeros", + bias_initializer="zeros", + name="attention_weights", + dtype=self.dtype_policy, + ) + self.attention_weights.build(input_shape) + if self.sampling_offsets.bias is not None: + thetas = keras.ops.arange( + self.n_heads, dtype=self.variable_dtype + ) * (2.0 * math.pi / self.n_heads) + grid_init = keras.ops.stack( + [keras.ops.cos(thetas), keras.ops.sin(thetas)], axis=-1 + ) + grid_init = grid_init / keras.ops.max( + keras.ops.abs(grid_init), axis=-1, keepdims=True + ) + grid_init = keras.ops.reshape(grid_init, (self.n_heads, 1, 2)) + grid_init = keras.ops.tile(grid_init, [1, sum(self.num_points), 1]) + scaling = [] + for n in self.num_points: + scaling.append( + keras.ops.arange(1, n + 1, dtype=self.variable_dtype) + ) + scaling = keras.ops.concatenate(scaling, axis=0) + scaling = keras.ops.reshape(scaling, (1, -1, 1)) + grid_init *= scaling + self.sampling_offsets.bias.assign(grid_init) + self.num_points_scale = self.add_weight( + name="num_points_scale", + shape=(len(self._num_points_scale),), + initializer=keras.initializers.Constant(self._num_points_scale), + trainable=False, + ) + super().build(input_shape) + + def compute_attention( + self, hidden_states, reference_points, spatial_shapes + ): + batch_size = keras.ops.shape(hidden_states)[0] + num_queries = keras.ops.shape(hidden_states)[1] + sampling_offsets = self.sampling_offsets(hidden_states) + attention_weights = self.attention_weights(hidden_states) + attention_weights = keras.ops.softmax(attention_weights, axis=-1) + + if keras.ops.shape(reference_points)[-1] == 2: + offset_normalizer = keras.ops.cast( + spatial_shapes, dtype=hidden_states.dtype + ) + offset_normalizer = keras.ops.flip(offset_normalizer, axis=1) + offset_normalizer = keras.ops.reshape( + offset_normalizer, (1, 1, 1, self.n_levels, 1, 2) + ) + sampling_locations = ( + 
keras.ops.reshape( + reference_points, + (batch_size, num_queries, 1, self.n_levels, 1, 2), + ) + + sampling_offsets / offset_normalizer + ) + elif keras.ops.shape(reference_points)[-1] == 4: + num_points_scale_t = keras.ops.cast( + self.num_points_scale, dtype=hidden_states.dtype + ) + num_points_scale_t = keras.ops.expand_dims( + num_points_scale_t, axis=-1 + ) + offset = ( + sampling_offsets + * num_points_scale_t + * keras.ops.expand_dims(reference_points[..., 2:], axis=-2) + * self.offset_scale + ) + sampling_locations = ( + keras.ops.expand_dims(reference_points[..., :2], axis=-2) + + offset + ) + else: + raise ValueError( + f"Last dim of reference_points must be 2 or 4, but get " + f"{keras.ops.shape(reference_points)[-1]} instead." + ) + return sampling_locations, attention_weights + + def call( + self, + hidden_states, + encoder_hidden_states, + reference_points, + spatial_shapes, + ): + batch_size = keras.ops.shape(hidden_states)[0] + num_queries = keras.ops.shape(hidden_states)[1] + sequence_length = keras.ops.shape(encoder_hidden_states)[1] + value = keras.ops.reshape( + encoder_hidden_states, + ( + batch_size, + sequence_length, + self.n_heads, + self.hidden_dim // self.n_heads, + ), + ) + sampling_locations, attention_weights = self.compute_attention( + hidden_states, reference_points, spatial_shapes + ) + + # NOTE: slice_sizes_values passed down to ms_deformable_attn_core + # since JAX tracing doesn't support dynamic shapes. + slice_sizes = [h * w for h, w in self.spatial_shapes] + output = self.ms_deformable_attn_core( + value, + spatial_shapes, + sampling_locations, + attention_weights, + self.num_points, + slice_sizes, + self.spatial_shapes, + self.n_levels, + num_queries, + self.decoder_method, + ) + return output, attention_weights + + def compute_output_spec( + self, + hidden_states, + encoder_hidden_states, + reference_points, + spatial_shapes, + ): + input_shape = hidden_states.shape + batch_size = input_shape[0] if len(input_shape) > 0 else None + num_queries = input_shape[1] if len(input_shape) > 1 else None + output_shape = (batch_size, num_queries, self.hidden_dim) + output_spec = keras.KerasTensor(output_shape, dtype=self.compute_dtype) + attention_weights_shape = ( + batch_size, + num_queries, + self.n_heads, + sum(self.num_points), + ) + attention_weights_spec = keras.KerasTensor( + attention_weights_shape, dtype=self.compute_dtype + ) + return output_spec, attention_weights_spec + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + "decoder_attention_heads": self.n_heads, + "num_feature_levels": self.n_levels, + "decoder_offset_scale": self.offset_scale, + "decoder_method": self.decoder_method, + "decoder_n_points": self.decoder_n_points, + "num_queries": self.num_queries, + "spatial_shapes": self.spatial_shapes, + } + ) + return config + + +class DFineMultiheadAttention(keras.layers.Layer): + """Multi-head attention layer for D-FINE models. + + This layer implements a standard multi-head attention mechanism. It is used + in two key places within the D-FINE architecture: + 1. In `DFineEncoderLayer` as the self-attention mechanism to process the + sequence of image features from the `HGNetV2Backbone` class. + 2. In `DFineDecoderLayer` as the self-attention mechanism to allow object + queries to interact with each other. + + It supports position embeddings to incorporate positional information and + attention masking to prevent attending to certain positions. 
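+
+    A minimal usage sketch (the shapes and argument values below are
+    illustrative assumptions, not recommended settings):
+    ```python
+    import keras
+
+    attention = DFineMultiheadAttention(embedding_dim=64, num_heads=8)
+    # A batch of 2 sequences, each with 10 query embeddings of size 64.
+    hidden_states = keras.random.uniform((2, 10, 64))
+    attn_output = attention(hidden_states)  # Shape: (2, 10, 64).
+    ```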
+ + Args: + embedding_dim: int, Embedding dimension size. + num_heads: int, Number of attention heads. + dropout: float, optional, Dropout probability for attention weights. + Defaults to `0.0`. + bias: bool, optional, Whether to include bias in projection layers. + Defaults to `True`. + kernel_initializer: str or initializer, optional, Initializer for + kernel weights. Defaults to `"glorot_uniform"`. + bias_initializer: str or initializer, optional, Initializer for + bias weights. Defaults to `"zeros"`. + **kwargs: Additional keyword arguments passed to the parent class. + """ + + def __init__( + self, + embedding_dim, + num_heads, + dropout=0.0, + bias=True, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.dropout_rate = dropout + self.head_dim = embedding_dim // num_heads + if self.head_dim * self.num_heads != self.embedding_dim: + raise ValueError( + f"embedding_dim must be divisible by num_heads (got " + f"`embedding_dim`: {self.embedding_dim} and `num_heads`: " + f"{self.num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.bias = bias + self.kernel_initializer = keras.initializers.get(kernel_initializer) + self.bias_initializer = keras.initializers.get(bias_initializer) + self.dropout = keras.layers.Dropout( + self.dropout_rate, dtype=self.dtype_policy + ) + + def build(self, input_shape): + embedding_dim = self.embedding_dim + proj_equation = "abc,cde->abde" + proj_bias_axes = "de" + proj_output_shape = (None, self.num_heads, self.head_dim) + proj_input_shape = (None, None, embedding_dim) + self.q_proj = keras.layers.EinsumDense( + proj_equation, + output_shape=proj_output_shape, + bias_axes=proj_bias_axes if self.bias else None, + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer if self.bias else None, + dtype=self.dtype_policy, + name="q_proj", + ) + self.q_proj.build(proj_input_shape) + self.k_proj = keras.layers.EinsumDense( + proj_equation, + output_shape=proj_output_shape, + bias_axes=proj_bias_axes if self.bias else None, + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer if self.bias else None, + dtype=self.dtype_policy, + name="k_proj", + ) + self.k_proj.build(proj_input_shape) + self.v_proj = keras.layers.EinsumDense( + proj_equation, + output_shape=proj_output_shape, + bias_axes=proj_bias_axes if self.bias else None, + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer if self.bias else None, + dtype=self.dtype_policy, + name="v_proj", + ) + self.v_proj.build(proj_input_shape) + out_proj_input_shape = (None, None, self.num_heads * self.head_dim) + out_proj_output_shape = (None, self.embedding_dim) + self.out_proj = keras.layers.EinsumDense( + "abc,cd->abd", + output_shape=out_proj_output_shape, + bias_axes="d" if self.bias else None, + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer if self.bias else None, + dtype=self.dtype_policy, + name="out_proj", + ) + self.out_proj.build(out_proj_input_shape) + super().build(input_shape) + + def call( + self, + hidden_states, + position_embeddings=None, + attention_mask=None, + output_attentions=False, + training=None, + ): + batch_size = keras.ops.shape(hidden_states)[0] + target_len = keras.ops.shape(hidden_states)[1] + + def with_pos_embed(tensor, position_embeddings_k): + return ( + tensor + if position_embeddings_k is 
None + else tensor + position_embeddings_k + ) + + hidden_states_with_pos = with_pos_embed( + hidden_states, position_embeddings + ) + query_states = self.q_proj(hidden_states_with_pos) + key_states = self.k_proj(hidden_states_with_pos) + value_states = self.v_proj(hidden_states) + attn_weights = keras.ops.einsum( + "bthd,bshd->bhts", query_states * self.scaling, key_states + ) + if attention_mask is not None: + if keras.ops.ndim(attention_mask) == 2: + attention_mask = keras.ops.expand_dims(attention_mask, axis=0) + attention_mask = keras.ops.expand_dims(attention_mask, axis=1) + attn_weights = attn_weights + attention_mask + attn_weights = keras.ops.softmax(attn_weights, axis=-1) + attn_weights_for_output = attn_weights if output_attentions else None + attn_probs = self.dropout(attn_weights, training=training) + attn_output = keras.ops.einsum( + "bhts,bshd->bthd", attn_probs, value_states + ) + attn_output = keras.ops.reshape( + attn_output, (batch_size, target_len, self.embedding_dim) + ) + attn_output = self.out_proj(attn_output) + if output_attentions: + return attn_output, attn_weights_for_output + else: + return attn_output + + def compute_output_spec( + self, + hidden_states, + position_embeddings=None, + attention_mask=None, + output_attentions=False, + training=None, + ): + input_shape = hidden_states.shape + batch_size = input_shape[0] if len(input_shape) > 0 else None + target_len = input_shape[1] if len(input_shape) > 1 else None + source_len = target_len + attn_output_shape = (batch_size, target_len, self.embedding_dim) + attn_output_spec = keras.KerasTensor( + attn_output_shape, dtype=self.compute_dtype + ) + if output_attentions: + attn_weights_shape = ( + batch_size, + self.num_heads, + target_len, + source_len, + ) + attn_weights_spec = keras.KerasTensor( + attn_weights_shape, dtype=self.compute_dtype + ) + return attn_output_spec, attn_weights_spec + return attn_output_spec + + def get_config(self): + config = super().get_config() + config.update( + { + "embedding_dim": self.embedding_dim, + "num_heads": self.num_heads, + "dropout": self.dropout_rate, + "bias": self.bias, + "kernel_initializer": keras.initializers.serialize( + self.kernel_initializer + ), + "bias_initializer": keras.initializers.serialize( + self.bias_initializer + ), + } + ) + return config diff --git a/keras_hub/src/models/d_fine/d_fine_backbone.py b/keras_hub/src/models/d_fine/d_fine_backbone.py new file mode 100644 index 0000000000..e74d72ae3c --- /dev/null +++ b/keras_hub/src/models/d_fine/d_fine_backbone.py @@ -0,0 +1,891 @@ +import math + +import keras +import numpy as np + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.d_fine.d_fine_decoder import DFineDecoder +from keras_hub.src.models.d_fine.d_fine_hybrid_encoder import DFineHybridEncoder +from keras_hub.src.models.d_fine.d_fine_layers import DFineAnchorGenerator +from keras_hub.src.models.d_fine.d_fine_layers import ( + DFineContrastiveDenoisingGroupGenerator, +) +from keras_hub.src.models.d_fine.d_fine_layers import ( + DFineInitialQueryAndReferenceGenerator, +) +from keras_hub.src.models.d_fine.d_fine_layers import DFineMLPPredictionHead +from keras_hub.src.models.d_fine.d_fine_layers import DFineSourceFlattener +from keras_hub.src.models.d_fine.d_fine_layers import ( + DFineSpatialShapesExtractor, +) +from keras_hub.src.models.d_fine.d_fine_utils import d_fine_kernel_initializer +from keras_hub.src.utils.keras_utils import standardize_data_format + + 
+class DFineDenoisingPreprocessorLayer(keras.layers.Layer): + """Processes and prepares tensors for contrastive denoising. + + This layer is a helper used within the `DFineBackbone`'s functional model + definition. Its primary role is to take the outputs from the + `DFineContrastiveDenoisingGroupGenerator` and prepare them for the dynamic, + per-batch forward pass, mostly since this functionality cannot be integrated + directly into the `DFineBackbone` in the symbolic forward pass. + + The layer takes a tuple of `(pixel_values, input_query_class, + denoising_bbox_unact, attention_mask)` and an optional + `denoising_meta_values` dictionary as input to its `call` method. + """ + + def __init__(self, dtype=None, **kwargs): + super().__init__(dtype=dtype, **kwargs) + + def call(self, inputs, denoising_meta_values=None): + ( + pixel_values, + input_query_class, + denoising_bbox_unact, + attention_mask, + ) = inputs + input_query_class_tensor = keras.ops.convert_to_tensor( + input_query_class, dtype="int32" + ) + denoising_bbox_unact_tensor = keras.ops.convert_to_tensor( + denoising_bbox_unact, dtype=self.compute_dtype + ) + attention_mask_tensor = keras.ops.convert_to_tensor( + attention_mask, dtype=self.compute_dtype + ) + outputs = { + "input_query_class": input_query_class_tensor, + "denoising_bbox_unact": denoising_bbox_unact_tensor, + "attention_mask": attention_mask_tensor, + } + + if denoising_meta_values is not None: + batch_size = keras.ops.shape(pixel_values)[0] + dn_positive_idx = denoising_meta_values["dn_positive_idx"] + c_batch_size = keras.ops.shape(dn_positive_idx)[0] + if c_batch_size == 0: + outputs["dn_positive_idx"] = keras.ops.zeros( + (batch_size,) + keras.ops.shape(dn_positive_idx)[1:], + dtype=dn_positive_idx.dtype, + ) + else: + num_repeats = (batch_size + c_batch_size - 1) // c_batch_size + dn_positive_idx_tiled = keras.ops.tile( + dn_positive_idx, + (num_repeats,) + + (1,) * (keras.ops.ndim(dn_positive_idx) - 1), + ) + outputs["dn_positive_idx"] = dn_positive_idx_tiled[:batch_size] + dn_num_group = denoising_meta_values["dn_num_group"] + outputs["dn_num_group"] = keras.ops.tile( + keras.ops.expand_dims(dn_num_group, 0), (batch_size,) + ) + dn_num_split = denoising_meta_values["dn_num_split"] + outputs["dn_num_split"] = keras.ops.tile( + keras.ops.expand_dims(dn_num_split, 0), (batch_size, 1) + ) + + return outputs + + +@keras_hub_export("keras_hub.models.DFineBackbone") +class DFineBackbone(Backbone): + """D-FINE Backbone for Object Detection. + + This class implements the core D-FINE architecture, which serves as the + backbone for `DFineObjectDetector`. It integrates a `HGNetV2Backbone` for + initial feature extraction, a `DFineHybridEncoder` for multi-scale feature + fusion using FPN/PAN pathways, and a `DFineDecoder` for refining object + queries. + + The backbone orchestrates the entire forward pass, from processing raw + pixels to generating intermediate predictions. Key steps include: + 1. Extracting multi-scale feature maps using the HGNetV2 backbone. + 2. Fusing these features with the hybrid encoder. + 3. Generating anchor proposals and selecting the top-k to initialize + decoder queries and reference points. + 4. Generating noisy queries for contrastive denoising (if the `labels` + argument is provided). + 5. Passing the queries and fused features through the transformer decoder + to produce iterative predictions for bounding boxes and class logits. + + Args: + backbone: A `keras.Model` instance that serves as the feature extractor. 
+ While any `keras.Model` can be used, we highly recommend using a + `keras_hub.models.HGNetV2Backbone` instance, as this architecture is + optimized for its outputs. If a custom backbone is provided, it + must have a `stage_names` attribute, or the `out_features` argument + for this model must be specified. This requirement helps prevent + hard-to-debug downstream dimensionality errors. + decoder_in_channels: list, Channel dimensions of the multi-scale + features from the hybrid encoder. This should typically be a list + of `encoder_hidden_dim` repeated for each feature level. + encoder_hidden_dim: int, Hidden dimension size for the encoder layers. + num_labels: int, Number of object classes for detection. + num_denoising: int, Number of denoising queries for contrastive + denoising training. Set to `0` to disable denoising. + learn_initial_query: bool, Whether to learn initial query embeddings. + num_queries: int, Number of object queries for detection. + anchor_image_size: tuple, Size of the anchor image as `(height, width)`. + feat_strides: list, List of feature stride values for different pyramid + levels. + num_feature_levels: int, Number of feature pyramid levels to use. + hidden_dim: int, Hidden dimension size for the model. + encoder_in_channels: list, Channel dimensions of the feature maps from + the backbone (`HGNetV2Backbone`) that are fed into the hybrid + encoder. + encode_proj_layers: list, List specifying projection layer + configurations. + num_attention_heads: int, Number of attention heads in encoder layers. + encoder_ffn_dim: int, Feed-forward network dimension in encoder. + num_encoder_layers: int, Number of encoder layers. + hidden_expansion: float, Hidden dimension expansion factor. + depth_multiplier: float, Depth multiplier for the backbone. + eval_idx: int, Index for evaluation. Defaults to `-1` for the last + layer. + num_decoder_layers: int, Number of decoder layers. + decoder_attention_heads: int, Number of attention heads in decoder + layers. + decoder_ffn_dim: int, Feed-forward network dimension in decoder. + decoder_method: str, Decoder method. Can be either `"default"` or + `"discrete"`. Defaults to `"default"`. + decoder_n_points: list, Number of sampling points for deformable + attention. + lqe_hidden_dim: int, Hidden dimension for learned query embedding. + num_lqe_layers: int, Number of layers in learned query embedding. + label_noise_ratio: float, Ratio of label noise for denoising + training. Defaults to `0.5`. + box_noise_scale: float, Scale factor for box noise in denoising + training. Defaults to `1.0`. + labels: list or None, Ground truth labels for denoising training. This + is passed during model initialization to construct the training + graph for contrastive denoising. Each element should be a + dictionary with `"boxes"` (numpy array of shape `[N, 4]` with + normalized coordinates) and `"labels"` (numpy array of shape `[N]` + with class indices). Required when `num_denoising > 0`. Defaults to + `None`. + seed: int or None, Random seed for reproducibility. Defaults to `None`. + image_shape: tuple, Shape of input images as `(height, width, + channels)`. Height and width can be `None` for variable input sizes. + Defaults to `(None, None, 3)`. + out_features: list or None, List of feature names to output from + backbone. If `None`, uses the last `len(decoder_in_channels)` + features from the backbone's `stage_names`. Defaults to `None`. + data_format: str, The data format of the image channels. 
Can be either + `"channels_first"` or `"channels_last"`. If `None` is specified, + it will use the `image_data_format` value found in your Keras + config file at `~/.keras/keras.json`. Defaults to `None`. + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the model's computations and weights. Defaults to `None`. + **kwargs: Additional keyword arguments passed to the base class. + + Example: + ```python + import keras + import numpy as np + from keras_hub.models import DFineBackbone + from keras_hub.models import HGNetV2Backbone + + # Example 1: Basic usage without denoising. + # First, build the `HGNetV2Backbone` instance. + hgnetv2 = HGNetV2Backbone( + stem_channels=[3, 16, 16], + stackwise_stage_filters=[ + [16, 16, 64, 1, 3, 3], + [64, 32, 256, 1, 3, 3], + [256, 64, 512, 2, 3, 5], + [512, 128, 1024, 1, 3, 5], + ], + apply_downsample=[False, True, True, True], + use_lightweight_conv_block=[False, False, True, True], + depths=[1, 1, 2, 1], + hidden_sizes=[64, 256, 512, 1024], + embedding_size=16, + use_learnable_affine_block=True, + hidden_act="relu", + image_shape=(None, None, 3), + out_features=["stage3", "stage4"], + data_format="channels_last", + ) + + # Then, pass the backbone instance to `DFineBackbone`. + backbone = DFineBackbone( + backbone=hgnetv2, + decoder_in_channels=[128, 128], + encoder_hidden_dim=128, + num_denoising=0, # Disable denoising + num_labels=80, + hidden_dim=128, + learn_initial_query=False, + num_queries=300, + anchor_image_size=(256, 256), + feat_strides=[16, 32], + num_feature_levels=2, + encoder_in_channels=[512, 1024], + encode_proj_layers=[1], + num_attention_heads=8, + encoder_ffn_dim=512, + num_encoder_layers=1, + hidden_expansion=0.34, + depth_multiplier=0.5, + eval_idx=-1, + num_decoder_layers=3, + decoder_attention_heads=8, + decoder_ffn_dim=512, + decoder_n_points=[6, 6], + lqe_hidden_dim=64, + num_lqe_layers=2, + out_features=["stage3", "stage4"], + image_shape=(None, None, 3), + data_format="channels_last", + seed=0, + ) + + # Prepare input data. + input_data = keras.random.uniform((2, 256, 256, 3)) + + # Forward pass. + outputs = backbone(input_data) + + # Example 2: With contrastive denoising training. + labels = [ + { + "boxes": np.array([[0.5, 0.5, 0.2, 0.2], [0.4, 0.4, 0.1, 0.1]]), + "labels": np.array([1, 10]), + }, + { + "boxes": np.array([[0.6, 0.6, 0.3, 0.3]]), + "labels": np.array([20]), + }, + ] + + # Pass the `HGNetV2Backbone` instance to `DFineBackbone`. + backbone_with_denoising = DFineBackbone( + backbone=hgnetv2, + decoder_in_channels=[128, 128], + encoder_hidden_dim=128, + num_denoising=100, # Enable denoising + num_labels=80, + hidden_dim=128, + learn_initial_query=False, + num_queries=300, + anchor_image_size=(256, 256), + feat_strides=[16, 32], + num_feature_levels=2, + encoder_in_channels=[512, 1024], + encode_proj_layers=[1], + num_attention_heads=8, + encoder_ffn_dim=512, + num_encoder_layers=1, + hidden_expansion=0.34, + depth_multiplier=0.5, + eval_idx=-1, + num_decoder_layers=3, + decoder_attention_heads=8, + decoder_ffn_dim=512, + decoder_n_points=[6, 6], + lqe_hidden_dim=64, + num_lqe_layers=2, + out_features=["stage3", "stage4"], + image_shape=(None, None, 3), + seed=0, + labels=labels, + ) + + # Forward pass with denoising. 
+ outputs_with_denoising = backbone_with_denoising(input_data) + ``` + """ + + def __init__( + self, + backbone, + decoder_in_channels, + encoder_hidden_dim, + num_labels, + num_denoising, + learn_initial_query, + num_queries, + anchor_image_size, + feat_strides, + num_feature_levels, + hidden_dim, + encoder_in_channels, + encode_proj_layers, + num_attention_heads, + encoder_ffn_dim, + num_encoder_layers, + hidden_expansion, + depth_multiplier, + eval_idx, + num_decoder_layers, + decoder_attention_heads, + decoder_ffn_dim, + decoder_n_points, + lqe_hidden_dim, + num_lqe_layers, + decoder_method="default", + label_noise_ratio=0.5, + box_noise_scale=1.0, + labels=None, + seed=None, + image_shape=(None, None, 3), + out_features=None, + data_format=None, + dtype=None, + **kwargs, + ): + if decoder_method not in ["default", "discrete"]: + decoder_method = "default" + data_format = standardize_data_format(data_format) + channel_axis = -1 if data_format == "channels_last" else 1 + self.backbone = backbone + # Re-instantiate the backbone if its data_format mismatches the parents. + if ( + hasattr(self.backbone, "data_format") + and self.backbone.data_format != data_format + ): + backbone_config = self.backbone.get_config() + backbone_config["data_format"] = data_format + if ( + "image_shape" in backbone_config + and backbone_config["image_shape"] is not None + and len(backbone_config["image_shape"]) == 3 + ): + backbone_config["image_shape"] = tuple( + reversed(backbone_config["image_shape"]) + ) + self.backbone = self.backbone.__class__.from_config(backbone_config) + spatial_shapes = [] + for s in feat_strides: + h = anchor_image_size[0] // s + w = anchor_image_size[1] // s + spatial_shapes.append((h, w)) + # NOTE: While `HGNetV2Backbone` is handled automatically, `out_features` + # must be specified for custom backbones. This design choice prevents + # hard-to-debug dimension mismatches by placing the onus on the user for + # ensuring compatibility. + if not hasattr(self.backbone, "stage_names") and out_features is None: + raise ValueError( + "`out_features` must be specified when using a custom " + "backbone that does not have a `stage_names` attribute." 
+ ) + stage_names = getattr(self.backbone, "stage_names", out_features) + out_features = ( + out_features + if out_features is not None + else stage_names[-len(decoder_in_channels) :] + ) + initializer = d_fine_kernel_initializer( + initializer_range=0.01, + ) + + # === Layers === + self.encoder = DFineHybridEncoder( + encoder_in_channels=encoder_in_channels, + feat_strides=feat_strides, + encoder_hidden_dim=encoder_hidden_dim, + encode_proj_layers=encode_proj_layers, + positional_encoding_temperature=10000, + eval_size=None, + normalize_before=False, + num_attention_heads=num_attention_heads, + dropout=0.0, + layer_norm_eps=1e-5, + encoder_activation_function="gelu", + activation_dropout=0.0, + encoder_ffn_dim=encoder_ffn_dim, + num_encoder_layers=num_encoder_layers, + batch_norm_eps=1e-5, + hidden_expansion=hidden_expansion, + depth_multiplier=depth_multiplier, + kernel_initializer=initializer, + bias_initializer="zeros", + channel_axis=channel_axis, + data_format=data_format, + dtype=dtype, + name="hybrid_encoder", + ) + self.decoder = DFineDecoder( + layer_scale=1.0, + eval_idx=eval_idx, + num_decoder_layers=num_decoder_layers, + dropout=0.0, + hidden_dim=hidden_dim, + reg_scale=4.0, + max_num_bins=32, + upsampling_factor=0.5, + decoder_attention_heads=decoder_attention_heads, + attention_dropout=0.0, + decoder_activation_function="relu", + activation_dropout=0.0, + layer_norm_eps=1e-5, + decoder_ffn_dim=decoder_ffn_dim, + num_feature_levels=num_feature_levels, + decoder_offset_scale=0.5, + decoder_method=decoder_method, + decoder_n_points=decoder_n_points, + top_prob_values=4, + lqe_hidden_dim=lqe_hidden_dim, + num_lqe_layers=num_lqe_layers, + num_labels=num_labels, + spatial_shapes=spatial_shapes, + dtype=dtype, + initializer_bias_prior_prob=None, + num_queries=num_queries, + name="decoder", + ) + self.anchor_generator = DFineAnchorGenerator( + anchor_image_size=anchor_image_size, + feat_strides=feat_strides, + data_format=data_format, + dtype=dtype, + name="anchor_generator", + ) + self.contrastive_denoising_group_generator = ( + DFineContrastiveDenoisingGroupGenerator( + num_labels=num_labels, + num_denoising=num_denoising, + label_noise_ratio=label_noise_ratio, + box_noise_scale=box_noise_scale, + seed=seed, + dtype=dtype, + name="contrastive_denoising_group_generator", + ) + ) + if num_denoising > 0: + self.denoising_class_embed = keras.layers.Embedding( + input_dim=num_labels + 1, + output_dim=hidden_dim, + embeddings_initializer="glorot_uniform", + name="denoising_class_embed", + dtype=dtype, + ) + self.denoising_class_embed.build(None) + else: + self.denoising_class_embed = None + + self.source_flattener = DFineSourceFlattener( + dtype=dtype, + name="source_flattener", + channel_axis=channel_axis, + data_format=data_format, + ) + self.initial_query_reference_generator = ( + DFineInitialQueryAndReferenceGenerator( + num_queries=num_queries, + learn_initial_query=learn_initial_query, + hidden_dim=hidden_dim, + dtype=dtype, + name="initial_query_reference_generator", + ) + ) + self.spatial_shapes_extractor = DFineSpatialShapesExtractor( + dtype=dtype, + data_format=data_format, + name="spatial_shapes_extractor", + ) + num_backbone_outs = len(decoder_in_channels) + self.encoder_input_proj_layers = [] + for i in range(num_backbone_outs): + self.encoder_input_proj_layers.append( + [ + keras.layers.Conv2D( + filters=encoder_hidden_dim, + kernel_size=1, + use_bias=False, + kernel_initializer=initializer, + bias_initializer="zeros", + data_format=data_format, + 
name=f"encoder_input_proj_conv_{i}", + dtype=dtype, + ), + keras.layers.BatchNormalization( + epsilon=1e-5, + axis=channel_axis, + name=f"encoder_input_proj_bn_{i}", + dtype=dtype, + ), + ] + ) + self.enc_output_layers = [ + keras.layers.Dense( + hidden_dim, + name="enc_output_dense", + dtype=dtype, + ), + keras.layers.LayerNormalization( + epsilon=1e-5, + name="enc_output_ln", + dtype=dtype, + ), + ] + prior_prob = 1 / (num_labels + 1) + enc_score_head_bias = float(-math.log((1 - prior_prob) / prior_prob)) + self.enc_score_head = keras.layers.Dense( + num_labels, + name="enc_score_head", + dtype=dtype, + kernel_initializer="glorot_uniform", + bias_initializer=keras.initializers.Constant(enc_score_head_bias), + ) + self.enc_bbox_head = DFineMLPPredictionHead( + input_dim=hidden_dim, + hidden_dim=hidden_dim, + output_dim=4, + num_layers=3, + name="enc_bbox_head", + dtype=dtype, + kernel_initializer=initializer, + last_layer_initializer="zeros", + ) + self.decoder_input_proj_layers = [] + for i in range(num_backbone_outs): + if hidden_dim == decoder_in_channels[-1]: + proj_layer = keras.layers.Identity( + name=f"decoder_input_proj_identity_{i}", + dtype=dtype, + ) + self.decoder_input_proj_layers.append(proj_layer) + else: + self.decoder_input_proj_layers.append( + [ + keras.layers.Conv2D( + filters=hidden_dim, + kernel_size=1, + use_bias=False, + kernel_initializer=initializer, + bias_initializer="zeros", + data_format=data_format, + name=f"decoder_input_proj_conv1_{i}", + dtype=dtype, + ), + keras.layers.BatchNormalization( + epsilon=1e-5, + axis=channel_axis, + name=f"decoder_input_proj_bn1_{i}", + dtype=dtype, + ), + ] + ) + for i in range(num_feature_levels - num_backbone_outs): + idx = num_backbone_outs + i + if hidden_dim == decoder_in_channels[-1]: + proj_layer = keras.layers.Identity( + name=f"decoder_input_proj_identity_{idx}", + dtype=dtype, + ) + self.decoder_input_proj_layers.append(proj_layer) + else: + self.decoder_input_proj_layers.append( + [ + keras.layers.Conv2D( + filters=hidden_dim, + kernel_size=3, + strides=2, + padding="same", + use_bias=False, + kernel_initializer=initializer, + bias_initializer="zeros", + data_format=data_format, + name=f"decoder_input_proj_conv3_{idx}", + dtype=dtype, + ), + keras.layers.BatchNormalization( + epsilon=1e-5, + axis=channel_axis, + name=f"decoder_input_proj_bn3_{idx}", + dtype=dtype, + ), + ] + ) + self.dn_split_point = None + + # === Functional Model === + pixel_values = keras.Input( + shape=image_shape, name="pixel_values", dtype="float32" + ) + feature_maps_output = self.backbone(pixel_values) + feature_maps = [feature_maps_output[stage] for stage in out_features] + feature_maps_output_tuple = tuple(feature_maps) + proj_feats = [] + for level, feature_map in enumerate(feature_maps_output_tuple): + x = self.encoder_input_proj_layers[level][0](feature_map) + x = self.encoder_input_proj_layers[level][1](x) + proj_feats.append(x) + encoder_outputs = self.encoder( + inputs_embeds=proj_feats, + output_hidden_states=True, + output_attentions=True, + ) + encoder_last_hidden_state = encoder_outputs[0] + encoder_hidden_states = ( + encoder_outputs[1] if len(encoder_outputs) > 1 else None + ) + encoder_attentions = ( + encoder_outputs[2] if len(encoder_outputs) > 2 else None + ) + last_hidden_state = encoder_outputs[0] + sources = [] + # NOTE: Handle both no-op (identity mapping) and an actual projection + # using Conv2D and BatchNorm with `isinstance(proj, list)`. 
+ for level, source in enumerate(last_hidden_state): + proj = self.decoder_input_proj_layers[level] + if isinstance(proj, list): + x = proj[0](source) + x = proj[1](x) + sources.append(x) + else: + sources.append(proj(source)) + if num_feature_levels > len(sources): + len_sources = len(sources) + proj = self.decoder_input_proj_layers[len_sources] + if isinstance(proj, list): + x = proj[0](last_hidden_state[-1]) + x = proj[1](x) + sources.append(x) + else: + sources.append(proj(last_hidden_state[-1])) + for i in range(len_sources + 1, num_feature_levels): + proj = self.decoder_input_proj_layers[i] + if isinstance(proj, list): + x = proj[0](sources[-1]) + x = proj[1](x) + sources.append(x) + else: + sources.append(proj(sources[-1])) + spatial_shapes_tensor = self.spatial_shapes_extractor(sources) + source_flatten = self.source_flattener(sources) + if num_denoising > 0 and labels is not None: + ( + input_query_class, + denoising_bbox_unact, + attention_mask, + denoising_meta_values, + ) = self.contrastive_denoising_group_generator( + targets=labels, + num_queries=num_queries, + ) + self.dn_split_point = int(denoising_meta_values["dn_num_split"][0]) + else: + ( + denoising_class, + denoising_bbox_unact, + attention_mask, + denoising_meta_values, + ) = None, None, None, None + + if num_denoising > 0 and labels is not None: + denoising_processor = DFineDenoisingPreprocessorLayer( + name="denoising_processor", dtype=dtype + ) + denoising_tensors = denoising_processor( + [ + pixel_values, + input_query_class, + denoising_bbox_unact, + attention_mask, + ], + denoising_meta_values=denoising_meta_values, + ) + input_query_class_tensor = denoising_tensors["input_query_class"] + denoising_bbox_unact = denoising_tensors["denoising_bbox_unact"] + attention_mask = denoising_tensors["attention_mask"] + denoising_class = self.denoising_class_embed( + input_query_class_tensor + ) + + anchors, valid_mask = self.anchor_generator(sources) + memory = keras.ops.where(valid_mask, source_flatten, 0.0) + output_memory = self.enc_output_layers[0](memory) + output_memory = self.enc_output_layers[1](output_memory) + enc_outputs_class = self.enc_score_head(output_memory) + enc_outputs_coord_logits = self.enc_bbox_head(output_memory) + enc_outputs_coord_logits_plus_anchors = ( + enc_outputs_coord_logits + anchors + ) + init_reference_points, target, enc_topk_logits, enc_topk_bboxes = ( + self.initial_query_reference_generator( + ( + enc_outputs_class, + enc_outputs_coord_logits_plus_anchors, + output_memory, + sources[-1], + ), + denoising_bbox_unact=denoising_bbox_unact, + denoising_class=denoising_class, + ) + ) + decoder_outputs = self.decoder( + inputs_embeds=target, + encoder_hidden_states=source_flatten, + reference_points=init_reference_points, + spatial_shapes=spatial_shapes_tensor, + attention_mask=attention_mask, + output_hidden_states=True, + output_attentions=True, + ) + last_hidden_state = decoder_outputs[0] + intermediate_hidden_states = decoder_outputs[1] + intermediate_logits = decoder_outputs[2] + intermediate_reference_points = decoder_outputs[3] + intermediate_predicted_corners = decoder_outputs[4] + initial_reference_points = decoder_outputs[5] + decoder_hidden_states = ( + decoder_outputs[6] if len(decoder_outputs) > 6 else None + ) + decoder_attentions = ( + decoder_outputs[7] if len(decoder_outputs) > 7 else None + ) + cross_attentions = ( + decoder_outputs[8] if len(decoder_outputs) > 8 else None + ) + outputs = { + "last_hidden_state": last_hidden_state, + "intermediate_hidden_states": 
intermediate_hidden_states, + "intermediate_logits": intermediate_logits, + "intermediate_reference_points": intermediate_reference_points, + "intermediate_predicted_corners": intermediate_predicted_corners, + "initial_reference_points": initial_reference_points, + "decoder_hidden_states": decoder_hidden_states, + "decoder_attentions": decoder_attentions, + "cross_attentions": cross_attentions, + "encoder_last_hidden_state": encoder_last_hidden_state[0], + "encoder_hidden_states": encoder_hidden_states, + "encoder_attentions": encoder_attentions, + "init_reference_points": init_reference_points, + "enc_topk_logits": enc_topk_logits, + "enc_topk_bboxes": enc_topk_bboxes, + "enc_outputs_class": enc_outputs_class, + "enc_outputs_coord_logits": enc_outputs_coord_logits, + } + + if num_denoising > 0 and labels is not None: + outputs["dn_positive_idx"] = denoising_tensors["dn_positive_idx"] + outputs["dn_num_group"] = denoising_tensors["dn_num_group"] + outputs["dn_num_split"] = denoising_tensors["dn_num_split"] + + outputs = {k: v for k, v in outputs.items() if v is not None} + super().__init__( + inputs=pixel_values, + outputs=outputs, + dtype=dtype, + **kwargs, + ) + + # === Config === + self.decoder_in_channels = decoder_in_channels + self.encoder_hidden_dim = encoder_hidden_dim + self.num_labels = num_labels + self.num_denoising = num_denoising + self.learn_initial_query = learn_initial_query + self.num_queries = num_queries + self.anchor_image_size = anchor_image_size + self.feat_strides = feat_strides + self.num_feature_levels = num_feature_levels + self.hidden_dim = hidden_dim + self.encoder_in_channels = encoder_in_channels + self.encode_proj_layers = encode_proj_layers + self.num_attention_heads = num_attention_heads + self.encoder_ffn_dim = encoder_ffn_dim + self.num_encoder_layers = num_encoder_layers + self.hidden_expansion = hidden_expansion + self.depth_multiplier = depth_multiplier + self.eval_idx = eval_idx + self.box_noise_scale = box_noise_scale + self.labels = labels + self.label_noise_ratio = label_noise_ratio + self.num_decoder_layers = num_decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_method = decoder_method + self.decoder_n_points = decoder_n_points + self.lqe_hidden_dim = lqe_hidden_dim + self.num_lqe_layers = num_lqe_layers + self.data_format = data_format + self.seed = seed + self.image_shape = image_shape + self.channel_axis = channel_axis + self.spatial_shapes = spatial_shapes + self.stage_names = stage_names + self.out_features = out_features + self.initializer = initializer + + def get_config(self): + config = super().get_config() + serializable_labels = None + if self.labels is not None: + serializable_labels = [] + for target in self.labels: + serializable_target = {} + for key, value in target.items(): + if hasattr(value, "tolist"): + serializable_target[key] = value.tolist() + else: + serializable_target[key] = value + serializable_labels.append(serializable_target) + config.update( + { + "backbone": keras.layers.serialize(self.backbone), + "decoder_in_channels": self.decoder_in_channels, + "encoder_hidden_dim": self.encoder_hidden_dim, + "num_labels": self.num_labels, + "num_denoising": self.num_denoising, + "learn_initial_query": self.learn_initial_query, + "num_queries": self.num_queries, + "anchor_image_size": self.anchor_image_size, + "feat_strides": self.feat_strides, + "num_feature_levels": self.num_feature_levels, + "hidden_dim": self.hidden_dim, + "encoder_in_channels": 
self.encoder_in_channels, + "encode_proj_layers": self.encode_proj_layers, + "num_attention_heads": self.num_attention_heads, + "encoder_ffn_dim": self.encoder_ffn_dim, + "num_encoder_layers": self.num_encoder_layers, + "hidden_expansion": self.hidden_expansion, + "depth_multiplier": self.depth_multiplier, + "eval_idx": self.eval_idx, + "box_noise_scale": self.box_noise_scale, + "label_noise_ratio": self.label_noise_ratio, + "labels": serializable_labels, + "num_decoder_layers": self.num_decoder_layers, + "decoder_attention_heads": self.decoder_attention_heads, + "decoder_ffn_dim": self.decoder_ffn_dim, + "decoder_method": self.decoder_method, + "decoder_n_points": self.decoder_n_points, + "lqe_hidden_dim": self.lqe_hidden_dim, + "num_lqe_layers": self.num_lqe_layers, + "seed": self.seed, + "image_shape": self.image_shape, + "data_format": self.data_format, + "out_features": self.out_features, + } + ) + return config + + @classmethod + def from_config(cls, config, custom_objects=None): + config = config.copy() + if "labels" in config and config["labels"] is not None: + labels = config["labels"] + deserialized_labels = [] + for target in labels: + deserialized_target = {} + for key, value in target.items(): + if isinstance(value, list): + deserialized_target[key] = np.array(value) + else: + deserialized_target[key] = value + deserialized_labels.append(deserialized_target) + config["labels"] = deserialized_labels + if "dtype" in config and config["dtype"] is not None: + dtype_config = config["dtype"] + if "dtype" not in config["backbone"]["config"]: + config["backbone"]["config"]["dtype"] = dtype_config + config["backbone"] = keras.layers.deserialize( + config["backbone"], custom_objects=custom_objects + ) + return cls(**config) diff --git a/keras_hub/src/models/d_fine/d_fine_backbone_test.py b/keras_hub/src/models/d_fine/d_fine_backbone_test.py new file mode 100644 index 0000000000..822e2c09c3 --- /dev/null +++ b/keras_hub/src/models/d_fine/d_fine_backbone_test.py @@ -0,0 +1,146 @@ +import keras +import numpy as np +import pytest +from absl.testing import parameterized + +from keras_hub.src.models.d_fine.d_fine_backbone import DFineBackbone +from keras_hub.src.models.hgnetv2.hgnetv2_backbone import HGNetV2Backbone +from keras_hub.src.tests.test_case import TestCase + + +class DFineBackboneTest(TestCase): + def setUp(self): + self.labels = [ + { + "boxes": np.array([[0.5, 0.5, 0.2, 0.2]]), + "labels": np.array([1]), + }, + { + "boxes": np.array([[0.6, 0.6, 0.3, 0.3]]), + "labels": np.array([2]), + }, + ] + hgnetv2_backbone = HGNetV2Backbone( + stem_channels=[3, 8, 8], + stackwise_stage_filters=[ + [8, 8, 16, 1, 1, 3], + [16, 8, 32, 1, 1, 3], + ], + apply_downsample=[False, True], + use_lightweight_conv_block=[False, False], + depths=[1, 1], + hidden_sizes=[16, 32], + embedding_size=8, + use_learnable_affine_block=True, + hidden_act="relu", + image_shape=(None, None, 3), + out_features=["stage1", "stage2"], + ) + self.base_init_kwargs = { + "backbone": hgnetv2_backbone, + "decoder_in_channels": [16, 16], + "encoder_hidden_dim": 16, + "num_denoising": 10, + "num_labels": 4, + "hidden_dim": 16, + "learn_initial_query": False, + "num_queries": 10, + "anchor_image_size": (32, 32), + "feat_strides": [4, 8], + "num_feature_levels": 2, + "encoder_in_channels": [16, 32], + "encode_proj_layers": [1], + "num_attention_heads": 2, + "encoder_ffn_dim": 32, + "num_encoder_layers": 1, + "hidden_expansion": 0.5, + "depth_multiplier": 0.5, + "eval_idx": -1, + "num_decoder_layers": 3, + 
"decoder_attention_heads": 2, + "decoder_ffn_dim": 32, + "decoder_n_points": [2, 2], + "lqe_hidden_dim": 16, + "num_lqe_layers": 2, + "out_features": ["stage1", "stage2"], + "image_shape": (None, None, 3), + "seed": 0, + } + self.input_data = keras.random.uniform((2, 32, 32, 3)) + + @parameterized.named_parameters( + ("default_eval_last", False, 10, -1, 4), + ("denoising_eval_last", True, 30, -1, 4), + ("default_eval_first", False, 10, 0, 4), + ("denoising_eval_first", True, 30, 0, 4), + ("default_eval_middle", False, 10, 1, 4), + ("denoising_eval_middle", True, 30, 1, 4), + ) + def test_backbone_basics( + self, use_noise_and_labels, total_queries, eval_idx, num_logit_layers + ): + init_kwargs = self.base_init_kwargs.copy() + init_kwargs["eval_idx"] = eval_idx + if use_noise_and_labels: + init_kwargs["box_noise_scale"] = 1.0 + init_kwargs["label_noise_ratio"] = 0.5 + init_kwargs["labels"] = self.labels + expected_output_shape = { + "last_hidden_state": (2, total_queries, 16), + "intermediate_hidden_states": (2, 3, total_queries, 16), + "intermediate_logits": (2, num_logit_layers, total_queries, 4), + "intermediate_reference_points": ( + 2, + num_logit_layers, + total_queries, + 4, + ), + "intermediate_predicted_corners": ( + 2, + num_logit_layers, + total_queries, + 132, + ), + "initial_reference_points": ( + 2, + num_logit_layers, + total_queries, + 4, + ), + "encoder_last_hidden_state": (2, 8, 8, 16), + "init_reference_points": (2, total_queries, 4), + "enc_topk_logits": (2, 10, 4), + "enc_topk_bboxes": (2, 10, 4), + "enc_outputs_class": (2, 80, 4), + "enc_outputs_coord_logits": (2, 80, 4), + } + # NOTE: The `run_vision_backbone_test` helper's `channels_first` + # check transposes all 3D / 4D outputs by default, which is incorrect + # for `DFineBackbone` non-spatial outputs like + # `intermediate_hidden_states` (shape: `(batch_size, num_decoder_layers, + # num_queries, hidden_dim)`). Use `spatial_output_keys` to specify + # spatial outputs (e.g., `encoder_last_hidden_state`) for transposition, + # ensuring congruence with reference outputs. + # https://github.com/huggingface/transformers/blob/d37f7517972f67e3f2194c000ed0f87f064e5099/src/transformers/models/d_fine/modeling_d_fine.py#L1595-L1614 + # NOTE: `last_hidden_state`, `intermediate_hidden_states`, and + # `decoder_hidden_state` are non-spatial object query embeddings, + # despite their names, and should not be transposed. Other outputs + # not listed are visibly non-spatial. 
+ self.run_vision_backbone_test( + cls=DFineBackbone, + init_kwargs=init_kwargs, + input_data=self.input_data, + expected_output_shape=expected_output_shape, + spatial_output_keys=[ + "encoder_last_hidden_state", + "encoder_hidden_states", + ], + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=DFineBackbone, + init_kwargs=self.base_init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/d_fine/d_fine_decoder.py b/keras_hub/src/models/d_fine/d_fine_decoder.py new file mode 100644 index 0000000000..20759d30ec --- /dev/null +++ b/keras_hub/src/models/d_fine/d_fine_decoder.py @@ -0,0 +1,944 @@ +import math + +import keras +import numpy as np + +from keras_hub.src.models.d_fine.d_fine_attention import DFineMultiheadAttention +from keras_hub.src.models.d_fine.d_fine_attention import ( + DFineMultiscaleDeformableAttention, +) +from keras_hub.src.models.d_fine.d_fine_layers import DFineGate +from keras_hub.src.models.d_fine.d_fine_layers import DFineIntegral +from keras_hub.src.models.d_fine.d_fine_layers import DFineLQE +from keras_hub.src.models.d_fine.d_fine_layers import DFineMLP +from keras_hub.src.models.d_fine.d_fine_layers import DFineMLPPredictionHead +from keras_hub.src.models.d_fine.d_fine_utils import d_fine_kernel_initializer +from keras_hub.src.models.d_fine.d_fine_utils import distance2bbox +from keras_hub.src.models.d_fine.d_fine_utils import inverse_sigmoid +from keras_hub.src.models.d_fine.d_fine_utils import weighting_function +from keras_hub.src.utils.keras_utils import clone_initializer + + +class DFineDecoderLayer(keras.layers.Layer): + """Single decoder layer for D-FINE models. + + This layer is the fundamental building block of the `DFineDecoder`. It + refines a set of object queries by first allowing them to interact with + each other via self-attention (`DFineMultiheadAttention`), and then + attending to the image features from the encoder via cross-attention + (`DFineMultiscaleDeformableAttention`). A feed-forward network with a + gating mechanism (`DFineGate`) further processes the queries. + + Args: + hidden_dim: int, Hidden dimension size for all attention and + feed-forward layers. + decoder_attention_heads: int, Number of attention heads for both + self-attention and cross-attention mechanisms. + attention_dropout: float, Dropout probability for attention weights. + decoder_activation_function: str, Activation function name for the + feed-forward network (e.g., `"relu"`, `"gelu"`, etc). + dropout: float, General dropout probability applied to layer outputs. + activation_dropout: float, Dropout probability applied after activation + in the feed-forward network. + layer_norm_eps: float, Epsilon value for layer normalization to prevent + division by zero. + decoder_ffn_dim: int, Hidden dimension size for the feed-forward + network. + num_feature_levels: int, Number of feature pyramid levels to attend to. + decoder_offset_scale: float, Scaling factor for deformable attention + offsets. + decoder_method: str, Method used for deformable attention computation. + decoder_n_points: int or list, Number of sampling points per feature + level. + If int, same number for all levels. If list, specific count per + level. + spatial_shapes: list, List of spatial dimensions `(height, width)` + for each feature level. + num_queries: int, Number of object queries processed by the decoder. + kernel_initializer: str or Initializer, optional, Initializer for + the kernel weights. Defaults to `"glorot_uniform"`. 
+ bias_initializer: str or Initializer, optional, Initializer for + the bias weights. Defaults to `"zeros"`. + **kwargs: Additional keyword arguments passed to the parent class. + """ + + def __init__( + self, + hidden_dim, + decoder_attention_heads, + attention_dropout, + decoder_activation_function, + dropout, + activation_dropout, + layer_norm_eps, + decoder_ffn_dim, + num_feature_levels, + decoder_offset_scale, + decoder_method, + decoder_n_points, + spatial_shapes, + num_queries, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_dim = hidden_dim + self.num_queries = num_queries + self.decoder_attention_heads = decoder_attention_heads + self.attention_dropout_rate = attention_dropout + self.decoder_activation_function = decoder_activation_function + self.layer_norm_eps = layer_norm_eps + self.decoder_ffn_dim = decoder_ffn_dim + self.num_feature_levels = num_feature_levels + self.decoder_offset_scale = decoder_offset_scale + self.decoder_method = decoder_method + self.decoder_n_points = decoder_n_points + self.spatial_shapes = spatial_shapes + self.kernel_initializer = keras.initializers.get(kernel_initializer) + self.bias_initializer = keras.initializers.get(bias_initializer) + + self.self_attn = DFineMultiheadAttention( + embedding_dim=self.hidden_dim, + num_heads=self.decoder_attention_heads, + dropout=self.attention_dropout_rate, + kernel_initializer=clone_initializer(self.kernel_initializer), + bias_initializer=clone_initializer(self.bias_initializer), + dtype=self.dtype_policy, + name="self_attn", + ) + self.dropout_layer = keras.layers.Dropout( + rate=dropout, name="dropout_layer", dtype=self.dtype_policy + ) + self.activation_dropout_layer = keras.layers.Dropout( + rate=activation_dropout, + name="activation_dropout_layer", + dtype=self.dtype_policy, + ) + self.activation_fn = keras.layers.Activation( + self.decoder_activation_function, + name="activation_fn", + dtype=self.dtype_policy, + ) + self.self_attn_layer_norm = keras.layers.LayerNormalization( + epsilon=self.layer_norm_eps, + name="self_attn_layer_norm", + dtype=self.dtype_policy, + ) + self.encoder_attn = DFineMultiscaleDeformableAttention( + hidden_dim=self.hidden_dim, + decoder_attention_heads=self.decoder_attention_heads, + num_feature_levels=self.num_feature_levels, + decoder_offset_scale=self.decoder_offset_scale, + dtype=self.dtype_policy, + decoder_method=self.decoder_method, + decoder_n_points=self.decoder_n_points, + spatial_shapes=self.spatial_shapes, + num_queries=self.num_queries, + name="encoder_attn", + ) + self.fc1 = keras.layers.Dense( + self.decoder_ffn_dim, + name="fc1", + dtype=self.dtype_policy, + kernel_initializer=clone_initializer(self.kernel_initializer), + bias_initializer=clone_initializer(self.bias_initializer), + ) + self.fc2 = keras.layers.Dense( + self.hidden_dim, + name="fc2", + dtype=self.dtype_policy, + kernel_initializer=clone_initializer(self.kernel_initializer), + bias_initializer=clone_initializer(self.bias_initializer), + ) + self.final_layer_norm = keras.layers.LayerNormalization( + epsilon=self.layer_norm_eps, + name="final_layer_norm", + dtype=self.dtype_policy, + ) + self.gateway = DFineGate( + self.hidden_dim, name="gateway", dtype=self.dtype_policy + ) + + def build(self, input_shape): + batch_size = input_shape[0] + num_queries = input_shape[1] + hidden_dim = self.hidden_dim + attention_input_shape = (batch_size, num_queries, hidden_dim) + 
self.self_attn.build(attention_input_shape) + self.encoder_attn.build(attention_input_shape) + self.fc1.build(attention_input_shape) + self.fc2.build((batch_size, num_queries, self.decoder_ffn_dim)) + self.gateway.build(attention_input_shape) + self.self_attn_layer_norm.build(attention_input_shape) + self.final_layer_norm.build(attention_input_shape) + super().build(input_shape) + + def call( + self, + hidden_states, + position_embeddings=None, + reference_points=None, + spatial_shapes=None, + encoder_hidden_states=None, + attention_mask=None, + output_attentions=False, + training=None, + ): + self_attn_output, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + output_attentions=output_attentions, + training=training, + ) + hidden_states_2 = self_attn_output + hidden_states_2 = self.dropout_layer(hidden_states_2, training=training) + hidden_states = hidden_states + hidden_states_2 + hidden_states = self.self_attn_layer_norm( + hidden_states, training=training + ) + residual = hidden_states + query_for_cross_attn = residual + if position_embeddings is not None: + query_for_cross_attn = query_for_cross_attn + position_embeddings + encoder_attn_output_tensor, cross_attn_weights_tensor = ( + self.encoder_attn( + hidden_states=query_for_cross_attn, + encoder_hidden_states=encoder_hidden_states, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + training=training, + ) + ) + hidden_states_2 = encoder_attn_output_tensor + current_cross_attn_weights = ( + cross_attn_weights_tensor if output_attentions else None + ) + hidden_states_2 = self.dropout_layer(hidden_states_2, training=training) + hidden_states = self.gateway( + residual, hidden_states_2, training=training + ) + hidden_states_ffn = self.fc1(hidden_states) + hidden_states_2 = self.activation_fn( + hidden_states_ffn, training=training + ) + hidden_states_2 = self.activation_dropout_layer( + hidden_states_2, training=training + ) + hidden_states_2 = self.fc2(hidden_states_2) + hidden_states_2 = self.dropout_layer(hidden_states_2, training=training) + hidden_states = hidden_states + hidden_states_2 + dtype_name = keras.backend.standardize_dtype(self.compute_dtype) + if dtype_name == "float16": + clamp_value = np.finfo(np.float16).max - 1000.0 + else: # float32, bfloat16 + clamp_value = np.finfo(np.float32).max - 1000.0 + hidden_states_clamped = keras.ops.clip( + hidden_states, x_min=-clamp_value, x_max=clamp_value + ) + hidden_states = self.final_layer_norm( + hidden_states_clamped, training=training + ) + return hidden_states, self_attn_weights, current_cross_attn_weights + + def compute_output_spec( + self, + hidden_states, + position_embeddings=None, + reference_points=None, + spatial_shapes=None, + encoder_hidden_states=None, + attention_mask=None, + output_attentions=False, + training=None, + ): + hidden_states_output_spec = keras.KerasTensor( + shape=hidden_states.shape, dtype=self.compute_dtype + ) + self_attn_output_spec = self.self_attn.compute_output_spec( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + output_attentions=True, + ) + _, self_attn_weights_spec = self_attn_output_spec + _, cross_attn_weights_spec = self.encoder_attn.compute_output_spec( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + ) + if not output_attentions: + self_attn_weights_spec = None + 
cross_attn_weights_spec = None + return ( + hidden_states_output_spec, + self_attn_weights_spec, + cross_attn_weights_spec, + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + "decoder_attention_heads": self.decoder_attention_heads, + "attention_dropout": self.attention_dropout_rate, + "decoder_activation_function": self.decoder_activation_function, + "dropout": self.dropout_layer.rate, + "activation_dropout": self.activation_dropout_layer.rate, + "layer_norm_eps": self.layer_norm_eps, + "decoder_ffn_dim": self.decoder_ffn_dim, + "num_feature_levels": self.num_feature_levels, + "decoder_offset_scale": self.decoder_offset_scale, + "decoder_method": self.decoder_method, + "decoder_n_points": self.decoder_n_points, + "spatial_shapes": self.spatial_shapes, + "num_queries": self.num_queries, + "kernel_initializer": keras.initializers.serialize( + self.kernel_initializer + ), + "bias_initializer": keras.initializers.serialize( + self.bias_initializer + ), + } + ) + return config + + +class DFineDecoder(keras.layers.Layer): + """Complete decoder module for D-FINE object detection models. + + This class implements the full D-FINE decoder, which is responsible for + transforming a set of object queries into final bounding box and class + predictions. It consists of a stack of `DFineDecoderLayer` instances that + iteratively refine the queries. At each layer, prediction heads + (`class_embed`, `bbox_embed`) generate intermediate outputs, which are used + for auxiliary loss calculation during training. The final layer's output + represents the model's predictions. + + Args: + eval_idx: int, Index of decoder layer used for evaluation. Negative + values count from the end (e.g., -1 for last layer). + num_decoder_layers: int, Number of decoder layers in the stack. + dropout: float, General dropout probability applied throughout the + decoder. + hidden_dim: int, Hidden dimension size for all components. + reg_scale: float, Scaling factor for regression loss and coordinate + prediction. + max_num_bins: int, Maximum number of bins for integral-based coordinate + prediction. + upsampling_factor: float, Upsampling factor used in coordinate + prediction weighting. + decoder_attention_heads: int, Number of attention heads in each decoder + layer. + attention_dropout: float, Dropout probability for attention mechanisms. + decoder_activation_function: str, Activation function for feed-forward + networks. + activation_dropout: float, Dropout probability after activation + functions. + layer_norm_eps: float, Epsilon for layer normalization stability. + decoder_ffn_dim: int, Hidden dimension for feed-forward networks. + num_feature_levels: int, Number of feature pyramid levels. + decoder_offset_scale: float, Scaling factor for deformable attention + offsets. + decoder_method: str, Method for deformable attention computation, + either `"default"` or `"discrete"`. + decoder_n_points: int or list, Number of sampling points per feature + level. + top_prob_values: int, Number of top probability values used in LQE. + lqe_hidden_dim: int, Hidden dimension for LQE networks. + num_lqe_layers: int, Number of layers in LQE networks. + num_labels: int, Number of object classes for classification. + spatial_shapes: list, Spatial dimensions for each feature level. + layer_scale: float, Scaling factor for layer-wise feature dimensions. + num_queries: int, Number of object queries processed by the decoder. 
+ initializer_bias_prior_prob: float, optional, Prior probability for + the bias of the classification head. Used to initialize the bias + of the `class_embed` layers. Defaults to `None`. + **kwargs: Additional keyword arguments passed to the parent class. + """ + + def __init__( + self, + eval_idx, + num_decoder_layers, + dropout, + hidden_dim, + reg_scale, + max_num_bins, + upsampling_factor, + decoder_attention_heads, + attention_dropout, + decoder_activation_function, + activation_dropout, + layer_norm_eps, + decoder_ffn_dim, + num_feature_levels, + decoder_offset_scale, + decoder_method, + decoder_n_points, + top_prob_values, + lqe_hidden_dim, + num_lqe_layers, + num_labels, + spatial_shapes, + layer_scale, + num_queries, + initializer_bias_prior_prob=None, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.eval_idx = ( + eval_idx if eval_idx >= 0 else num_decoder_layers + eval_idx + ) + self.dropout_rate = dropout + self.num_queries = num_queries + self.hidden_dim = hidden_dim + self.num_decoder_layers = num_decoder_layers + self.reg_scale_val = reg_scale + self.max_num_bins = max_num_bins + self.upsampling_factor = upsampling_factor + self.decoder_attention_heads = decoder_attention_heads + self.attention_dropout_rate = attention_dropout + self.decoder_activation_function = decoder_activation_function + self.activation_dropout_rate = activation_dropout + self.layer_norm_eps = layer_norm_eps + self.decoder_ffn_dim = decoder_ffn_dim + self.num_feature_levels = num_feature_levels + self.decoder_offset_scale = decoder_offset_scale + self.decoder_method = decoder_method + self.decoder_n_points = decoder_n_points + self.top_prob_values = top_prob_values + self.lqe_hidden_dim = lqe_hidden_dim + self.num_lqe_layers = num_lqe_layers + self.num_labels = num_labels + self.spatial_shapes = spatial_shapes + self.layer_scale = layer_scale + self.initializer_bias_prior_prob = initializer_bias_prior_prob + self.initializer = d_fine_kernel_initializer() + self.decoder_layers = [] + for i in range(self.num_decoder_layers): + self.decoder_layers.append( + DFineDecoderLayer( + self.hidden_dim, + self.decoder_attention_heads, + self.attention_dropout_rate, + self.decoder_activation_function, + self.dropout_rate, + self.activation_dropout_rate, + self.layer_norm_eps, + self.decoder_ffn_dim, + self.num_feature_levels, + self.decoder_offset_scale, + self.decoder_method, + self.decoder_n_points, + self.spatial_shapes, + num_queries=self.num_queries, + kernel_initializer=clone_initializer(self.initializer), + bias_initializer="zeros", + dtype=self.dtype_policy, + name=f"decoder_layer_{i}", + ) + ) + + self.query_pos_head = DFineMLPPredictionHead( + input_dim=4, + hidden_dim=(2 * self.hidden_dim), + output_dim=self.hidden_dim, + num_layers=2, + dtype=self.dtype_policy, + kernel_initializer=clone_initializer(self.initializer), + bias_initializer="zeros", + name="query_pos_head", + ) + + num_pred = self.num_decoder_layers + scaled_dim = round(self.hidden_dim * self.layer_scale) + if initializer_bias_prior_prob is None: + prior_prob = 1 / (self.num_labels + 1) + else: + prior_prob = initializer_bias_prior_prob + class_embed_bias = float(-math.log((1 - prior_prob) / prior_prob)) + self.class_embed = [ + keras.layers.Dense( + self.num_labels, + name=f"class_embed_{i}", + dtype=self.dtype_policy, + kernel_initializer="glorot_uniform", + bias_initializer=keras.initializers.Constant(class_embed_bias), + ) + for i in range(num_pred) + ] + self.bbox_embed = [ + DFineMLPPredictionHead( + 
input_dim=self.hidden_dim, + hidden_dim=self.hidden_dim, + output_dim=4 * (self.max_num_bins + 1), + num_layers=3, + name=f"bbox_embed_{i}", + dtype=self.dtype_policy, + kernel_initializer=clone_initializer(self.initializer), + bias_initializer="zeros", + last_layer_initializer="zeros", + ) + for i in range(self.eval_idx + 1) + ] + [ + DFineMLPPredictionHead( + input_dim=scaled_dim, + hidden_dim=scaled_dim, + output_dim=4 * (self.max_num_bins + 1), + num_layers=3, + name=f"bbox_embed_{i + self.eval_idx + 1}", + dtype=self.dtype_policy, + kernel_initializer=clone_initializer(self.initializer), + bias_initializer="zeros", + last_layer_initializer="zeros", + ) + for i in range(self.num_decoder_layers - self.eval_idx - 1) + ] + self.pre_bbox_head = DFineMLP( + input_dim=self.hidden_dim, + hidden_dim=self.hidden_dim, + output_dim=4, + num_layers=3, + activation_function="relu", + dtype=self.dtype_policy, + kernel_initializer=clone_initializer(self.initializer), + bias_initializer="zeros", + name="pre_bbox_head", + ) + + self.integral = DFineIntegral( + max_num_bins=self.max_num_bins, + name="integral", + dtype=self.dtype_policy, + ) + + self.num_head = self.decoder_attention_heads + + self.lqe_layers = [] + for i in range(self.num_decoder_layers): + self.lqe_layers.append( + DFineLQE( + top_prob_values=self.top_prob_values, + max_num_bins=self.max_num_bins, + lqe_hidden_dim=self.lqe_hidden_dim, + num_lqe_layers=self.num_lqe_layers, + dtype=self.dtype_policy, + name=f"lqe_layer_{i}", + ) + ) + + def build(self, input_shape): + if isinstance(input_shape, dict): + if "inputs_embeds" not in input_shape: + raise ValueError( + "DFineDecoder.build() received a dict input_shape " + "missing 'inputs_embeds' key. Please ensure 'inputs_embeds'" + " is passed correctly." + ) + inputs_embeds_shape = input_shape["inputs_embeds"] + elif ( + isinstance(input_shape, (list, tuple)) + and len(input_shape) > 0 + and isinstance(input_shape[0], (list, tuple)) + ): + inputs_embeds_shape = input_shape[0] + else: + inputs_embeds_shape = input_shape + if not isinstance(inputs_embeds_shape, tuple): + raise TypeError( + f"Internal error: inputs_embeds_shape was expected to be a " + f"tuple, but got {type(inputs_embeds_shape)} with value " + f"{inputs_embeds_shape}. 
Original input_shape: {input_shape}" + ) + + batch_size_ph = ( + inputs_embeds_shape[0] + if inputs_embeds_shape + and len(inputs_embeds_shape) > 0 + and inputs_embeds_shape[0] is not None + else None + ) + num_queries_ph = ( + inputs_embeds_shape[1] + if inputs_embeds_shape + and len(inputs_embeds_shape) > 1 + and inputs_embeds_shape[1] is not None + else None + ) + current_decoder_layer_input_shape = inputs_embeds_shape + for decoder_layer_instance in self.decoder_layers: + decoder_layer_instance.build(current_decoder_layer_input_shape) + qph_input_shape = (batch_size_ph, num_queries_ph, 4) + self.query_pos_head.build(qph_input_shape) + pre_bbox_head_input_shape = ( + batch_size_ph, + num_queries_ph, + self.hidden_dim, + ) + self.pre_bbox_head.build(pre_bbox_head_input_shape) + lqe_scores_shape = (batch_size_ph, num_queries_ph, 1) + lqe_pred_corners_dim = 4 * (self.max_num_bins + 1) + lqe_pred_corners_shape = ( + batch_size_ph, + num_queries_ph, + lqe_pred_corners_dim, + ) + lqe_build_input_shape_tuple = (lqe_scores_shape, lqe_pred_corners_shape) + for lqe_layer in self.lqe_layers: + lqe_layer.build(lqe_build_input_shape_tuple) + self.reg_scale = self.add_weight( + name="reg_scale", + shape=(1,), + initializer=keras.initializers.Constant(self.reg_scale_val), + trainable=False, + ) + self.upsampling_factor = self.add_weight( + name="upsampling_factor", + shape=(1,), + initializer=keras.initializers.Constant(self.upsampling_factor), + trainable=False, + ) + input_shape_for_class_embed = ( + batch_size_ph, + num_queries_ph, + self.hidden_dim, + ) + for class_embed_layer in self.class_embed: + class_embed_layer.build(input_shape_for_class_embed) + input_shape_for_bbox_embed = ( + batch_size_ph, + num_queries_ph, + self.hidden_dim, + ) + for bbox_embed_layer in self.bbox_embed: + bbox_embed_layer.build(input_shape_for_bbox_embed) + super().build(input_shape) + + def compute_output_spec( + self, + inputs_embeds, + encoder_hidden_states, + reference_points, + spatial_shapes, + attention_mask=None, + output_hidden_states=None, + output_attentions=None, + training=None, + ): + output_attentions = ( + False if output_attentions is None else output_attentions + ) + output_hidden_states = ( + False if output_hidden_states is None else output_hidden_states + ) + batch_size = inputs_embeds.shape[0] + num_queries = inputs_embeds.shape[1] + hidden_dim = inputs_embeds.shape[2] + last_hidden_state_spec = keras.KerasTensor( + shape=(batch_size, num_queries, hidden_dim), + dtype=self.compute_dtype, + ) + intermediate_hidden_states_spec = None + if output_hidden_states: + intermediate_hidden_states_spec = keras.KerasTensor( + shape=( + batch_size, + self.num_decoder_layers, + num_queries, + hidden_dim, + ), + dtype=self.compute_dtype, + ) + num_layers_with_logits = self.num_decoder_layers + 1 + intermediate_logits_spec = keras.KerasTensor( + shape=( + batch_size, + num_layers_with_logits, + num_queries, + self.num_labels, + ), + dtype=self.compute_dtype, + ) + intermediate_reference_points_spec = keras.KerasTensor( + shape=(batch_size, num_layers_with_logits, num_queries, 4), + dtype=self.compute_dtype, + ) + intermediate_predicted_corners_spec = keras.KerasTensor( + shape=( + batch_size, + num_layers_with_logits, + num_queries, + 4 * (self.max_num_bins + 1), + ), + dtype=self.compute_dtype, + ) + initial_reference_points_spec = keras.KerasTensor( + shape=(batch_size, num_layers_with_logits, num_queries, 4), + dtype=self.compute_dtype, + ) + all_hidden_states_spec = None + all_self_attns_spec = None + 
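+ # Cross-attention weight specs are only populated further below when
+ # `encoder_hidden_states` is provided and attentions are requested.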
all_cross_attentions_spec = None + if output_hidden_states: + all_hidden_states_spec = tuple( + [last_hidden_state_spec] * (self.num_decoder_layers + 1) + ) + if output_attentions: + ( + _, + self_attn_spec, + cross_attn_spec, + ) = self.decoder_layers[0].compute_output_spec( + hidden_states=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + output_attentions=True, + ) + all_self_attns_spec = tuple( + [self_attn_spec] * self.num_decoder_layers + ) + if encoder_hidden_states is not None: + all_cross_attentions_spec = tuple( + [cross_attn_spec] * self.num_decoder_layers + ) + outputs_tuple = [ + last_hidden_state_spec, + intermediate_hidden_states_spec, + intermediate_logits_spec, + intermediate_reference_points_spec, + intermediate_predicted_corners_spec, + initial_reference_points_spec, + all_hidden_states_spec, + all_self_attns_spec, + all_cross_attentions_spec, + ] + return tuple(v for v in outputs_tuple if v is not None) + + def call( + self, + inputs_embeds, + encoder_hidden_states, + reference_points, + spatial_shapes, + attention_mask=None, + output_hidden_states=None, + output_attentions=None, + training=None, + ): + output_attentions = ( + False if output_attentions is None else output_attentions + ) + output_hidden_states = ( + False if output_hidden_states is None else output_hidden_states + ) + + hidden_states = inputs_embeds + + all_hidden_states = [] if output_hidden_states else None + all_self_attns = [] if output_attentions else None + all_cross_attentions = ( + [] + if (output_attentions and encoder_hidden_states is not None) + else None + ) + + intermediate_hidden_states = [] + intermediate_reference_points = [] + intermediate_logits = [] + intermediate_predicted_corners = [] + initial_reference_points = [] + + output_detach = ( + keras.ops.zeros_like(hidden_states) + if hidden_states is not None + else 0 + ) + pred_corners_undetach = 0 + + project_flat = weighting_function( + self.max_num_bins, self.upsampling_factor, self.reg_scale + ) + project = keras.ops.expand_dims(project_flat, axis=0) + + ref_points_detach = keras.ops.sigmoid(reference_points) + + for i, decoder_layer_instance in enumerate(self.decoder_layers): + ref_points_input = keras.ops.expand_dims(ref_points_detach, axis=2) + query_pos_embed = self.query_pos_head( + ref_points_detach, training=training + ) + query_pos_embed = keras.ops.clip(query_pos_embed, -10.0, 10.0) + + if output_hidden_states: + all_hidden_states.append(hidden_states) + + output_tuple = decoder_layer_instance( + hidden_states=hidden_states, + position_embeddings=query_pos_embed, + reference_points=ref_points_input, + spatial_shapes=spatial_shapes, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + training=training, + ) + hidden_states = output_tuple[0] + self_attn_weights_from_layer = output_tuple[1] + cross_attn_weights_from_layer = output_tuple[2] + + if i == 0: + pre_bbox_head_output = self.pre_bbox_head( + hidden_states, training=training + ) + new_reference_points = keras.ops.sigmoid( + pre_bbox_head_output + inverse_sigmoid(ref_points_detach) + ) + ref_points_initial = keras.ops.stop_gradient( + new_reference_points + ) + + if self.bbox_embed is not None: + bbox_embed_input = hidden_states + output_detach + pred_corners = ( + self.bbox_embed[i](bbox_embed_input, training=training) + + pred_corners_undetach + ) + integral_output = self.integral( + pred_corners, project, training=training + ) + inter_ref_bbox = distance2bbox( + ref_points_initial, 
integral_output, self.reg_scale + ) + pred_corners_undetach = pred_corners + ref_points_detach = keras.ops.stop_gradient(inter_ref_bbox) + + output_detach = keras.ops.stop_gradient(hidden_states) + + intermediate_hidden_states.append(hidden_states) + + if self.class_embed is not None and self.bbox_embed is not None: + class_scores = self.class_embed[i](hidden_states) + refined_scores = self.lqe_layers[i]( + class_scores, pred_corners, training=training + ) + if i == 0: + # NOTE: For first layer, output both, pre-LQE and post-LQE + # predictions, to provide an initial estimate. In the orig. + # implementation, the `torch.stack()` op would've thrown + # an error due to mismatched lengths. + intermediate_logits.append(class_scores) + intermediate_reference_points.append(new_reference_points) + initial_reference_points.append(ref_points_initial) + intermediate_predicted_corners.append(pred_corners) + intermediate_logits.append(refined_scores) + intermediate_reference_points.append(inter_ref_bbox) + initial_reference_points.append(ref_points_initial) + intermediate_predicted_corners.append(pred_corners) + + if output_attentions: + if self_attn_weights_from_layer is not None: + all_self_attns.append(self_attn_weights_from_layer) + if ( + encoder_hidden_states is not None + and cross_attn_weights_from_layer is not None + ): + all_cross_attentions.append(cross_attn_weights_from_layer) + + intermediate_stacked = ( + keras.ops.stack(intermediate_hidden_states, axis=1) + if intermediate_hidden_states + else None + ) + + if self.class_embed is not None and self.bbox_embed is not None: + intermediate_logits_stacked = ( + keras.ops.stack(intermediate_logits, axis=1) + if intermediate_logits + else None + ) + intermediate_predicted_corners_stacked = ( + keras.ops.stack(intermediate_predicted_corners, axis=1) + if intermediate_predicted_corners + else None + ) + initial_reference_points_stacked = ( + keras.ops.stack(initial_reference_points, axis=1) + if initial_reference_points + else None + ) + intermediate_reference_points_stacked = ( + keras.ops.stack(intermediate_reference_points, axis=1) + if intermediate_reference_points + else None + ) + else: + intermediate_logits_stacked = None + intermediate_predicted_corners_stacked = None + initial_reference_points_stacked = None + intermediate_reference_points_stacked = None + + if output_hidden_states: + all_hidden_states.append(hidden_states) + + all_hidden_states_tuple = ( + tuple(all_hidden_states) if output_hidden_states else None + ) + all_self_attns_tuple = ( + tuple(all_self_attns) if output_attentions else None + ) + all_cross_attentions_tuple = ( + tuple(all_cross_attentions) + if (output_attentions and encoder_hidden_states is not None) + else None + ) + + outputs_tuple = [ + hidden_states, + intermediate_stacked, + intermediate_logits_stacked, + intermediate_reference_points_stacked, + intermediate_predicted_corners_stacked, + initial_reference_points_stacked, + all_hidden_states_tuple, + all_self_attns_tuple, + all_cross_attentions_tuple, + ] + return tuple(v for v in outputs_tuple if v is not None) + + def get_config(self): + config = super().get_config() + config.update( + { + "eval_idx": self.eval_idx, + "num_decoder_layers": self.num_decoder_layers, + "dropout": self.dropout_rate, + "hidden_dim": self.hidden_dim, + "reg_scale": self.reg_scale_val, + "max_num_bins": self.max_num_bins, + "upsampling_factor": self.upsampling_factor, + "decoder_attention_heads": self.decoder_attention_heads, + "attention_dropout": self.attention_dropout_rate, + 
"decoder_activation_function": self.decoder_activation_function, + "activation_dropout": self.activation_dropout_rate, + "layer_norm_eps": self.layer_norm_eps, + "decoder_ffn_dim": self.decoder_ffn_dim, + "num_feature_levels": self.num_feature_levels, + "decoder_offset_scale": self.decoder_offset_scale, + "decoder_method": self.decoder_method, + "decoder_n_points": self.decoder_n_points, + "top_prob_values": self.top_prob_values, + "lqe_hidden_dim": self.lqe_hidden_dim, + "num_lqe_layers": self.num_lqe_layers, + "num_labels": self.num_labels, + "spatial_shapes": self.spatial_shapes, + "layer_scale": self.layer_scale, + "num_queries": self.num_queries, + "initializer_bias_prior_prob": self.initializer_bias_prior_prob, + } + ) + return config diff --git a/keras_hub/src/models/d_fine/d_fine_encoder.py b/keras_hub/src/models/d_fine/d_fine_encoder.py new file mode 100644 index 0000000000..dc9ba40aa3 --- /dev/null +++ b/keras_hub/src/models/d_fine/d_fine_encoder.py @@ -0,0 +1,365 @@ +import keras +import numpy as np + +from keras_hub.src.models.d_fine.d_fine_attention import DFineMultiheadAttention +from keras_hub.src.utils.keras_utils import clone_initializer + + +class DFineEncoderLayer(keras.layers.Layer): + """Single encoder layer for D-FINE models. + + This layer is the fundamental building block of the `DFineEncoder`. It + implements a standard transformer encoder layer with multi-head + self-attention (`DFineMultiheadAttention`) and a feed-forward network. It is + used to process and refine the feature sequences from the CNN backbone. + + Args: + normalize_before: bool, Whether to apply layer normalization before + the attention and feed-forward sub-layers (pre-norm) or after + (post-norm). + encoder_hidden_dim: int, Hidden dimension size of the encoder. + num_attention_heads: int, Number of attention heads in multi-head + attention. + dropout: float, Dropout probability applied to attention outputs and + feed-forward outputs. + layer_norm_eps: float, Small constant added to the denominator for + numerical stability in layer normalization. + encoder_activation_function: str, Activation function used in the + feed-forward network. + activation_dropout: float, Dropout probability applied after the + activation function in the feed-forward network. + encoder_ffn_dim: int, Hidden dimension size of the feed-forward network. + **kwargs: Additional keyword arguments passed to the parent class. + kernel_initializer: str or Initializer, optional, Initializer for + the kernel weights. Defaults to `"glorot_uniform"`. + bias_initializer: str or Initializer, optional, Initializer for + the bias weights. Defaults to `"zeros"`. 
+ """ + + def __init__( + self, + normalize_before, + encoder_hidden_dim, + num_attention_heads, + dropout, + layer_norm_eps, + encoder_activation_function, + activation_dropout, + encoder_ffn_dim, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.normalize_before = normalize_before + self.encoder_hidden_dim = encoder_hidden_dim + self.num_attention_heads = num_attention_heads + self.dropout_rate = dropout + self.layer_norm_eps = layer_norm_eps + self.encoder_activation_function = encoder_activation_function + self.activation_dropout_rate = activation_dropout + self.encoder_ffn_dim = encoder_ffn_dim + self.kernel_initializer = keras.initializers.get(kernel_initializer) + self.bias_initializer = keras.initializers.get(bias_initializer) + self.self_attn = DFineMultiheadAttention( + embedding_dim=self.encoder_hidden_dim, + num_heads=self.num_attention_heads, + dropout=self.dropout_rate, + dtype=self.dtype_policy, + kernel_initializer=clone_initializer(self.kernel_initializer), + bias_initializer=clone_initializer(self.bias_initializer), + name="self_attn", + ) + self.self_attn_layer_norm = keras.layers.LayerNormalization( + epsilon=self.layer_norm_eps, + name="self_attn_layer_norm", + dtype=self.dtype_policy, + ) + self.dropout_layer = keras.layers.Dropout( + rate=self.dropout_rate, + name="dropout_layer", + dtype=self.dtype_policy, + ) + self.activation_fn_layer = keras.layers.Activation( + self.encoder_activation_function, + name="activation_fn_layer", + dtype=self.dtype_policy, + ) + self.activation_dropout_layer = keras.layers.Dropout( + rate=self.activation_dropout_rate, + name="activation_dropout_layer", + dtype=self.dtype_policy, + ) + self.fc1 = keras.layers.Dense( + self.encoder_ffn_dim, + name="fc1", + dtype=self.dtype_policy, + kernel_initializer=clone_initializer(self.kernel_initializer), + bias_initializer=clone_initializer(self.bias_initializer), + ) + self.fc2 = keras.layers.Dense( + self.encoder_hidden_dim, + name="fc2", + dtype=self.dtype_policy, + kernel_initializer=clone_initializer(self.kernel_initializer), + bias_initializer=clone_initializer(self.bias_initializer), + ) + self.final_layer_norm = keras.layers.LayerNormalization( + epsilon=self.layer_norm_eps, + name="final_layer_norm", + dtype=self.dtype_policy, + ) + + def build(self, input_shape): + self.self_attn.build(input_shape) + self.self_attn_layer_norm.build(input_shape) + self.fc1.build(input_shape) + self.fc2.build((input_shape[0], input_shape[1], self.encoder_ffn_dim)) + self.final_layer_norm.build(input_shape) + super().build(input_shape) + + def call( + self, + hidden_states, + attention_mask=None, + position_embeddings=None, + output_attentions=False, + training=None, + ): + residual = hidden_states + if self.normalize_before: + hidden_states = self.self_attn_layer_norm( + hidden_states, training=training + ) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + output_attentions=output_attentions, + training=training, + ) + hidden_states = self.dropout_layer(hidden_states, training=training) + hidden_states = residual + hidden_states + if not self.normalize_before: + hidden_states = self.self_attn_layer_norm( + hidden_states, training=training + ) + if self.normalize_before: + hidden_states = self.final_layer_norm( + hidden_states, training=training + ) + residual_ffn = hidden_states + hidden_states = 
self.fc1(hidden_states) + hidden_states = self.activation_fn_layer(hidden_states) + hidden_states = self.activation_dropout_layer( + hidden_states, training=training + ) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout_layer(hidden_states, training=training) + hidden_states = residual_ffn + hidden_states + if not self.normalize_before: + hidden_states = self.final_layer_norm( + hidden_states, training=training + ) + if training: + dtype_name = keras.backend.standardize_dtype(self.compute_dtype) + if dtype_name == "float16": + clamp_value = np.finfo(np.float16).max - 1000.0 + else: # float32, bfloat16 + clamp_value = np.finfo(np.float32).max - 1000.0 + hidden_states = keras.ops.clip( + hidden_states, x_min=-clamp_value, x_max=clamp_value + ) + if output_attentions: + return hidden_states, attn_weights + return hidden_states + + def compute_output_spec( + self, + hidden_states, + attention_mask=None, + position_embeddings=None, + output_attentions=False, + training=None, + ): + attn_output_spec = self.self_attn.compute_output_spec( + hidden_states, + position_embeddings, + attention_mask, + output_attentions, + ) + if output_attentions: + hidden_states_output_spec, self_attn_weights_spec = attn_output_spec + return hidden_states_output_spec, self_attn_weights_spec + return attn_output_spec + + def get_config(self): + config = super().get_config() + config.update( + { + "normalize_before": self.normalize_before, + "encoder_hidden_dim": self.encoder_hidden_dim, + "num_attention_heads": self.num_attention_heads, + "dropout": self.dropout_rate, + "layer_norm_eps": self.layer_norm_eps, + "encoder_activation_function": self.encoder_activation_function, + "activation_dropout": self.activation_dropout_rate, + "encoder_ffn_dim": self.encoder_ffn_dim, + "kernel_initializer": keras.initializers.serialize( + self.kernel_initializer + ), + "bias_initializer": keras.initializers.serialize( + self.bias_initializer + ), + } + ) + return config + + +class DFineEncoder(keras.layers.Layer): + """Multi-layer encoder for D-FINE models. + + This layer implements a stack of `DFineEncoderLayer` instances. It is used + within the `DFineHybridEncoder` to apply transformer-based processing to + the feature maps from the CNN backbone, creating rich contextual + representations before they are passed to the FPN/PAN pathways. + + Args: + normalize_before: bool, Whether to apply layer normalization before + the attention and feed-forward sub-layers (pre-norm) or after + (post-norm) in each encoder layer. + encoder_hidden_dim: int, Hidden dimension size of the encoder layers. + num_attention_heads: int, Number of attention heads in multi-head + attention for each layer. + dropout: float, Dropout probability applied to attention outputs and + feed-forward outputs in each layer. + layer_norm_eps: float, Small constant added to the denominator for + numerical stability in layer normalization. + encoder_activation_function: str, Activation function used in the + feed-forward networks of each layer. + activation_dropout: float, Dropout probability applied after the + activation function in the feed-forward networks. + encoder_ffn_dim: int, Hidden dimension size of the feed-forward + networks in each layer. + num_encoder_layers: int, Number of encoder layers in the stack. + kernel_initializer: str or Initializer, optional, Initializer for + the kernel weights of each layer. Defaults to + `"glorot_uniform"`. + bias_initializer: str or Initializer, optional, Initializer for + the bias weights of each layer. 
Defaults to + `"zeros"`. + **kwargs: Additional keyword arguments passed to the parent class. + """ + + def __init__( + self, + normalize_before, + encoder_hidden_dim, + num_attention_heads, + dropout, + layer_norm_eps, + encoder_activation_function, + activation_dropout, + encoder_ffn_dim, + num_encoder_layers, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.normalize_before = normalize_before + self.encoder_hidden_dim = encoder_hidden_dim + self.num_attention_heads = num_attention_heads + self.dropout_rate = dropout + self.layer_norm_eps = layer_norm_eps + self.encoder_activation_function = encoder_activation_function + self.activation_dropout_rate = activation_dropout + self.encoder_ffn_dim = encoder_ffn_dim + self.num_encoder_layers = num_encoder_layers + self.kernel_initializer = kernel_initializer + self.bias_initializer = bias_initializer + self.encoder_layer = [] + for i in range(self.num_encoder_layers): + layer = DFineEncoderLayer( + normalize_before=self.normalize_before, + encoder_hidden_dim=self.encoder_hidden_dim, + num_attention_heads=self.num_attention_heads, + dropout=self.dropout_rate, + layer_norm_eps=self.layer_norm_eps, + encoder_activation_function=self.encoder_activation_function, + activation_dropout=self.activation_dropout_rate, + encoder_ffn_dim=self.encoder_ffn_dim, + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer, + dtype=self.dtype_policy, + name=f"encoder_layer_{i}", + ) + self.encoder_layer.append(layer) + + def build(self, input_shape): + current_input_shape_for_layer = input_shape + for encoder_layer_instance in self.encoder_layer: + encoder_layer_instance.build(current_input_shape_for_layer) + super().build(input_shape) + + def compute_output_spec( + self, src, src_mask=None, pos_embed=None, output_attentions=False + ): + if not self.encoder_layer: + if output_attentions: + return src, None + return src + encoder_layer_output_spec = self.encoder_layer[0].compute_output_spec( + hidden_states=src, + attention_mask=src_mask, + position_embeddings=pos_embed, + output_attentions=output_attentions, + ) + if output_attentions: + return encoder_layer_output_spec + return encoder_layer_output_spec + + def call( + self, + src, + src_mask=None, + pos_embed=None, + output_attentions=False, + training=None, + ): + current_hidden_tensor = src + last_layer_attn_weights = None + + for encoder_layer_instance in self.encoder_layer: + current_hidden_tensor, layer_attn_weights = encoder_layer_instance( + hidden_states=current_hidden_tensor, + attention_mask=src_mask, + position_embeddings=pos_embed, + output_attentions=output_attentions, + training=training, + ) + if output_attentions: + last_layer_attn_weights = layer_attn_weights + + return current_hidden_tensor, last_layer_attn_weights + + def get_config(self): + config = super().get_config() + config.update( + { + "normalize_before": self.normalize_before, + "encoder_hidden_dim": self.encoder_hidden_dim, + "num_attention_heads": self.num_attention_heads, + "dropout": self.dropout_rate, + "layer_norm_eps": self.layer_norm_eps, + "encoder_activation_function": self.encoder_activation_function, + "activation_dropout": self.activation_dropout_rate, + "encoder_ffn_dim": self.encoder_ffn_dim, + "num_encoder_layers": self.num_encoder_layers, + "kernel_initializer": self.kernel_initializer, + "bias_initializer": self.bias_initializer, + } + ) + return config diff --git 
a/keras_hub/src/models/d_fine/d_fine_hybrid_encoder.py b/keras_hub/src/models/d_fine/d_fine_hybrid_encoder.py new file mode 100644 index 0000000000..2c5ddd9596 --- /dev/null +++ b/keras_hub/src/models/d_fine/d_fine_hybrid_encoder.py @@ -0,0 +1,642 @@ +import keras + +from keras_hub.src.models.d_fine.d_fine_encoder import DFineEncoder +from keras_hub.src.models.d_fine.d_fine_layers import DFineConvNormLayer +from keras_hub.src.models.d_fine.d_fine_layers import ( + DFineFeatureAggregationBlock, +) +from keras_hub.src.models.d_fine.d_fine_layers import DFineSCDown + + +class DFineHybridEncoder(keras.layers.Layer): + """Hybrid encoder for the D-FINE model. + + This layer sits between the HGNetV2 backbone (`HGNetV2Backbone`) and the + main `DFineDecoder`. It takes multi-scale feature maps from the backbone, + optionally refines them with transformer-based `DFineEncoder` layers, and + then fuses them using a Feature Pyramid Network (FPN) top-down pathway and a + Path Aggregation Network (PAN) bottom-up pathway. The resulting enriched + feature maps serve as the key and value inputs for the decoder's + cross-attention mechanism. + + Args: + encoder_in_channels: list of int, Input channel dimensions for each + feature level from the backbone. + feat_strides: list of int, Stride values for each feature level, + indicating the downsampling factor relative to the input image. + encoder_hidden_dim: int, Hidden dimension size used throughout the + encoder for feature projection and attention computation. + encode_proj_layers: list of int, Indices of feature levels to apply + transformer encoding to. Not all levels need transformer + processing. + positional_encoding_temperature: float, Temperature parameter for + sinusoidal positional embeddings used in transformer attention. + eval_size: tuple or None, Fixed evaluation size `(height, width)` for + consistent positional embeddings during inference. If `None`, + dynamic sizing is used. + normalize_before: bool, Whether to apply layer normalization before + attention and feed-forward operations in transformer layers. + num_attention_heads: int, Number of attention heads in multi-head + attention mechanisms within transformer layers. + dropout: float, Dropout probability applied to attention weights and + feed-forward networks for regularization. + layer_norm_eps: float, Small epsilon value for numerical stability in + layer normalization operations. + encoder_activation_function: str, Activation function used in + transformer feed-forward networks (e.g., `"relu"`, `"gelu"`). + activation_dropout: float, Dropout probability specifically applied to + activation functions in feed-forward networks. + encoder_ffn_dim: int, Hidden dimension size for feed-forward networks + within transformer layers. + num_encoder_layers: int, Number of transformer encoder layers to apply + at each selected feature level. + batch_norm_eps: float, Small epsilon value for numerical stability in + batch normalization operations used in components. + hidden_expansion: float, Expansion factor for hidden dimensions in + `DFineFeatureAggregationBlock` blocks used in FPN and PAN pathways. + depth_multiplier: float, Depth multiplier for scaling the number of + blocks in `DFineFeatureAggregationBlock` modules. + kernel_initializer: str or Initializer, optional, Initializer for + the kernel weights of each layer. Defaults to + `"glorot_uniform"`. + bias_initializer: str or Initializer, optional, Initializer for + the bias weights of each layer. Defaults to + `"zeros"`. 
+ channel_axis: int, optional, The channel axis. Defaults to `None`. + data_format: str, optional, The data format. Defaults to `None`. + **kwargs: Additional keyword arguments passed to the parent class. + """ + + def __init__( + self, + encoder_in_channels, + feat_strides, + encoder_hidden_dim, + encode_proj_layers, + positional_encoding_temperature, + eval_size, + normalize_before, + num_attention_heads, + dropout, + layer_norm_eps, + encoder_activation_function, + activation_dropout, + encoder_ffn_dim, + num_encoder_layers, + batch_norm_eps, + hidden_expansion, + depth_multiplier, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + channel_axis=None, + data_format=None, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + + self.encoder_in_channels = encoder_in_channels + self.num_fpn_stages = len(self.encoder_in_channels) - 1 + self.feat_strides = feat_strides + self.encoder_hidden_dim = encoder_hidden_dim + self.encode_proj_layers = encode_proj_layers + self.positional_encoding_temperature = positional_encoding_temperature + self.eval_size = eval_size + self.out_channels = [ + self.encoder_hidden_dim for _ in self.encoder_in_channels + ] + self.out_strides = self.feat_strides + self.depth_multiplier = depth_multiplier + self.num_encoder_layers = num_encoder_layers + self.normalize_before = normalize_before + self.num_attention_heads = num_attention_heads + self.dropout_rate = dropout + self.layer_norm_eps = layer_norm_eps + self.encoder_activation_function = encoder_activation_function + self.activation_dropout_rate = activation_dropout + self.encoder_ffn_dim = encoder_ffn_dim + self.batch_norm_eps = batch_norm_eps + self.hidden_expansion = hidden_expansion + self.kernel_initializer = kernel_initializer + self.bias_initializer = bias_initializer + self.channel_axis = channel_axis + self.data_format = data_format + + self.encoder = [ + DFineEncoder( + normalize_before=self.normalize_before, + encoder_hidden_dim=self.encoder_hidden_dim, + num_attention_heads=self.num_attention_heads, + dropout=self.dropout_rate, + layer_norm_eps=self.layer_norm_eps, + encoder_activation_function=self.encoder_activation_function, + activation_dropout=self.activation_dropout_rate, + encoder_ffn_dim=self.encoder_ffn_dim, + dtype=self.dtype_policy, + num_encoder_layers=self.num_encoder_layers, + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer, + name=f"d_fine_encoder_{i}", + ) + for i in range(len(self.encode_proj_layers)) + ] + + self.lateral_convs = [] + self.fpn_blocks = [] + for i in range(len(self.encoder_in_channels) - 1, 0, -1): + lateral_layer = DFineConvNormLayer( + filters=self.encoder_hidden_dim, + kernel_size=1, + batch_norm_eps=self.batch_norm_eps, + stride=1, + groups=1, + padding=0, + activation_function=None, + dtype=self.dtype_policy, + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer, + channel_axis=self.channel_axis, + name=f"lateral_conv_{i}", + ) + self.lateral_convs.append(lateral_layer) + num_blocks = round(3 * self.depth_multiplier) + fpn_layer = DFineFeatureAggregationBlock( + encoder_hidden_dim=self.encoder_hidden_dim, + hidden_expansion=self.hidden_expansion, + batch_norm_eps=self.batch_norm_eps, + activation_function="silu", + num_blocks=num_blocks, + dtype=self.dtype_policy, + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer, + channel_axis=self.channel_axis, + name=f"fpn_block_{i}", + ) + self.fpn_blocks.append(fpn_layer) + + 
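+ # PAN (bottom-up) components: each stage downsamples the running feature
+ # map with a stride-2 `DFineSCDown`, concatenates it with the matching FPN
+ # output, and fuses the result through a `DFineFeatureAggregationBlock`.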
self.downsample_convs = [] + self.pan_blocks = [] + for i in range(len(self.encoder_in_channels) - 1): + num_blocks = round(3 * self.depth_multiplier) + self.downsample_convs.append( + DFineSCDown( + encoder_hidden_dim=self.encoder_hidden_dim, + batch_norm_eps=self.batch_norm_eps, + kernel_size=3, + stride=2, + dtype=self.dtype_policy, + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer, + channel_axis=self.channel_axis, + name=f"downsample_conv_{i}", + ) + ) + self.pan_blocks.append( + DFineFeatureAggregationBlock( + encoder_hidden_dim=self.encoder_hidden_dim, + hidden_expansion=self.hidden_expansion, + batch_norm_eps=self.batch_norm_eps, + activation_function="silu", + num_blocks=num_blocks, + dtype=self.dtype_policy, + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer, + channel_axis=self.channel_axis, + name=f"pan_block_{i}", + ) + ) + + self.upsample = keras.layers.UpSampling2D( + size=(2, 2), + interpolation="nearest", + dtype=self.dtype_policy, + data_format=self.data_format, + name="upsample", + ) + self.identity = keras.layers.Identity( + dtype=self.dtype_policy, name="identity" + ) + + def build(self, input_shape): + inputs_embeds_shapes = input_shape + # Encoder layers. + if self.num_encoder_layers > 0: + for i, enc_ind in enumerate(self.encode_proj_layers): + feature_map_shape = inputs_embeds_shapes[enc_ind] + if self.data_format == "channels_last": + batch_s, h_s, w_s, c_s = feature_map_shape + else: # channels_first + batch_s, c_s, h_s, w_s = feature_map_shape + if h_s is not None and w_s is not None: + seq_len_for_this_encoder = h_s * w_s + else: + seq_len_for_this_encoder = None + encoder_input_shape = (batch_s, seq_len_for_this_encoder, c_s) + self.encoder[i].build(encoder_input_shape) + # FPN and PAN pathways. + # FPN (Top-down pathway). + fpn_feature_maps_shapes = [inputs_embeds_shapes[-1]] + for idx, (lateral_conv, fpn_block) in enumerate( + zip(self.lateral_convs, self.fpn_blocks) + ): + lateral_conv.build(fpn_feature_maps_shapes[-1]) + shape_after_lateral_conv = lateral_conv.compute_output_shape( + fpn_feature_maps_shapes[-1] + ) + if self.data_format == "channels_last": + batch_s, orig_h, orig_w, c = shape_after_lateral_conv + target_h = orig_h * 2 if orig_h is not None else None + target_w = orig_w * 2 if orig_w is not None else None + shape_after_resize = (batch_s, target_h, target_w, c) + else: + batch_s, c, orig_h, orig_w = shape_after_lateral_conv + target_h = orig_h * 2 if orig_h is not None else None + target_w = orig_w * 2 if orig_w is not None else None + shape_after_resize = (batch_s, c, target_h, target_w) + backbone_feature_map_k_shape = inputs_embeds_shapes[ + self.num_fpn_stages - idx - 1 + ] + shape_after_concat_fpn = list(shape_after_resize) + shape_after_concat_fpn[self.channel_axis] += ( + backbone_feature_map_k_shape[self.channel_axis] + ) + shape_after_concat_fpn = tuple(shape_after_concat_fpn) + fpn_block.build(shape_after_concat_fpn) + fpn_feature_maps_shapes.append( + fpn_block.compute_output_shape(shape_after_concat_fpn) + ) + # PAN (Bottom-up pathway). 
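+ # Mirrors the bottom-up pass in `call()`: downsample the running shape,
+ # grow the channel axis by the matching FPN map's channels, then record
+ # the PAN block's output shape.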
+ reversed_fpn_feature_maps_shapes = fpn_feature_maps_shapes[::-1] + pan_feature_maps_shapes = [reversed_fpn_feature_maps_shapes[0]] + for idx, (downsample_conv, pan_block) in enumerate( + zip(self.downsample_convs, self.pan_blocks) + ): + downsample_conv.build(pan_feature_maps_shapes[-1]) + shape_after_downsample = downsample_conv.compute_output_shape( + pan_feature_maps_shapes[-1] + ) + fpn_shape = reversed_fpn_feature_maps_shapes[idx + 1] + concat_shape = list(shape_after_downsample) + concat_shape[self.channel_axis] += fpn_shape[self.channel_axis] + pan_block.build(tuple(concat_shape)) + pan_feature_maps_shapes.append( + pan_block.compute_output_shape(tuple(concat_shape)) + ) + super().build(input_shape) + + def call( + self, + inputs_embeds, + attention_mask=None, + output_attentions=None, + output_hidden_states=None, + training=None, + ): + hidden_states = [keras.ops.convert_to_tensor(t) for t in inputs_embeds] + + output_attentions = ( + output_attentions if output_attentions is not None else False + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else False + ) + + encoder_states_tuple = () if output_hidden_states else None + all_attentions_tuple = () if output_attentions else None + + processed_maps = {} + if self.num_encoder_layers > 0: + for i, enc_ind in enumerate(self.encode_proj_layers): + current_feature_map = hidden_states[enc_ind] + if output_hidden_states: + encoder_states_tuple = encoder_states_tuple + ( + self.identity(current_feature_map), + ) + + batch_size = keras.ops.shape(current_feature_map)[0] + if self.data_format == "channels_last": + height = keras.ops.shape(current_feature_map)[1] + width = keras.ops.shape(current_feature_map)[2] + channels = keras.ops.shape(current_feature_map)[-1] + src_flatten = keras.ops.reshape( + current_feature_map, + (batch_size, height * width, channels), + ) + else: + channels = keras.ops.shape(current_feature_map)[1] + height = keras.ops.shape(current_feature_map)[2] + width = keras.ops.shape(current_feature_map)[3] + + transposed_map = keras.ops.transpose( + current_feature_map, (0, 2, 3, 1) + ) + src_flatten = keras.ops.reshape( + transposed_map, + (batch_size, height * width, channels), + ) + + pos_embed = None + if training or self.eval_size is None: + pos_embed = self.build_2d_sincos_position_embedding( + width, + height, + self.encoder_hidden_dim, + self.positional_encoding_temperature, + dtype=self.compute_dtype, + ) + encoder_output = self.encoder[i]( + src=src_flatten, + src_mask=attention_mask, + pos_embed=pos_embed, + output_attentions=output_attentions, + training=training, + ) + if output_attentions: + processed_feature_map, layer_attentions = encoder_output + else: + processed_feature_map, layer_attentions = ( + encoder_output, + None, + ) + + if self.data_format == "channels_last": + processed_maps[enc_ind] = keras.ops.reshape( + processed_feature_map, + (batch_size, height, width, self.encoder_hidden_dim), + ) + else: + reshaped_map = keras.ops.reshape( + processed_feature_map, + (batch_size, height, width, self.encoder_hidden_dim), + ) + processed_maps[enc_ind] = keras.ops.transpose( + reshaped_map, (0, 3, 1, 2) + ) + + if output_attentions and layer_attentions is not None: + all_attentions_tuple = all_attentions_tuple + ( + layer_attentions, + ) + + processed_hidden_states = [] + for i in range(len(hidden_states)): + if i in processed_maps: + processed_hidden_states.append(processed_maps[i]) + else: + processed_hidden_states.append(hidden_states[i]) + if 
self.num_encoder_layers > 0: + if output_hidden_states: + encoder_states_tuple = encoder_states_tuple + ( + self.identity( + processed_hidden_states[self.encode_proj_layers[-1]] + ), + ) + else: + processed_hidden_states = hidden_states + fpn_inter_outputs = [] + y = processed_hidden_states[-1] + for idx, (lateral_conv, fpn_block) in enumerate( + zip(self.lateral_convs, self.fpn_blocks) + ): + backbone_feature_map_k = processed_hidden_states[ + self.num_fpn_stages - idx - 1 + ] + y_lateral = lateral_conv(y, training=training) + fpn_inter_outputs.append(y_lateral) + y_upsampled = self.upsample(y_lateral, training=training) + fused_feature_map_k = keras.ops.concatenate( + [y_upsampled, backbone_feature_map_k], + axis=self.channel_axis, + ) + y = fpn_block(fused_feature_map_k, training=training) + fpn_feature_maps = fpn_inter_outputs + [y] + + fpn_feature_maps = fpn_feature_maps[::-1] + + pan_feature_maps = [fpn_feature_maps[0]] + for idx, (downsample_conv, pan_block) in enumerate( + zip(self.downsample_convs, self.pan_blocks) + ): + top_pan_feature_map_k = pan_feature_maps[-1] + fpn_feature_map_k = fpn_feature_maps[idx + 1] + + downsampled_feature_map_k = downsample_conv( + top_pan_feature_map_k, training=training + ) + fused_feature_map_k = keras.ops.concatenate( + [downsampled_feature_map_k, fpn_feature_map_k], + axis=self.channel_axis, + ) + new_pan_feature_map_k = pan_block( + fused_feature_map_k, training=training + ) + pan_feature_maps.append(new_pan_feature_map_k) + + return tuple( + v + for v in [ + pan_feature_maps, + encoder_states_tuple if output_hidden_states else None, + all_attentions_tuple if output_attentions else None, + ] + if v is not None + ) + + @staticmethod + def build_2d_sincos_position_embedding( + width, + height, + embedding_dim=256, + temperature=10000.0, + dtype="float32", + ): + grid_w = keras.ops.arange(width, dtype=dtype) + grid_h = keras.ops.arange(height, dtype=dtype) + grid_w, grid_h = keras.ops.meshgrid(grid_w, grid_h, indexing="ij") + if embedding_dim % 4 != 0: + raise ValueError( + "Embed dimension must be divisible by 4 for 2D sin-cos position" + " embedding" + ) + pos_dim = embedding_dim // 4 + omega = keras.ops.arange(pos_dim, dtype=dtype) / pos_dim + omega = 1.0 / (temperature**omega) + + out_w = keras.ops.matmul( + keras.ops.reshape(grid_w, (-1, 1)), + keras.ops.reshape(omega, (1, -1)), + ) + out_h = keras.ops.matmul( + keras.ops.reshape(grid_h, (-1, 1)), + keras.ops.reshape(omega, (1, -1)), + ) + + concatenated_embeds = keras.ops.concatenate( + [ + keras.ops.sin(out_w), + keras.ops.cos(out_w), + keras.ops.sin(out_h), + keras.ops.cos(out_h), + ], + axis=1, + ) + return keras.ops.expand_dims(concatenated_embeds, axis=0) + + def get_config(self): + config = super().get_config() + config.update( + { + "encoder_in_channels": self.encoder_in_channels, + "feat_strides": self.feat_strides, + "encoder_hidden_dim": self.encoder_hidden_dim, + "encode_proj_layers": self.encode_proj_layers, + "positional_encoding_temperature": self.positional_encoding_temperature, # noqa: E501 + "eval_size": self.eval_size, + "normalize_before": self.normalize_before, + "num_attention_heads": self.num_attention_heads, + "dropout": self.dropout_rate, + "layer_norm_eps": self.layer_norm_eps, + "encoder_activation_function": self.encoder_activation_function, + "activation_dropout": self.activation_dropout_rate, + "encoder_ffn_dim": self.encoder_ffn_dim, + "num_encoder_layers": self.num_encoder_layers, + "batch_norm_eps": self.batch_norm_eps, + "hidden_expansion": 
self.hidden_expansion, + "depth_multiplier": self.depth_multiplier, + "kernel_initializer": self.kernel_initializer, + "bias_initializer": self.bias_initializer, + "channel_axis": self.channel_axis, + "data_format": self.data_format, + } + ) + return config + + def compute_output_spec( + self, + inputs_embeds, + attention_mask_spec=None, + output_attentions=None, + output_hidden_states=None, + training=None, + ): + output_attentions = ( + output_attentions if output_attentions is not None else False + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else False + ) + hidden_states_specs = list(inputs_embeds) + encoder_states_tuple_specs = () if output_hidden_states else None + all_attentions_tuple_specs = () if output_attentions else None + processed_maps_specs = {} + if self.num_encoder_layers > 0: + for i, enc_ind in enumerate(self.encode_proj_layers): + current_feature_map_spec = hidden_states_specs[enc_ind] + if output_hidden_states: + encoder_states_tuple_specs += ( + self.identity(current_feature_map_spec), + ) + if self.data_format == "channels_last": + batch_size, h, w, c = current_feature_map_spec.shape + else: + batch_size, c, h, w = current_feature_map_spec.shape + seq_len = h * w if h is not None and w is not None else None + src_flatten_spec = keras.KerasTensor( + (batch_size, seq_len, c), dtype=self.compute_dtype + ) + pos_embed_spec = keras.KerasTensor( + (batch_size, seq_len, self.encoder_hidden_dim), + dtype=self.compute_dtype, + ) + encoder_output_spec = self.encoder[i].compute_output_spec( + src=src_flatten_spec, + src_mask=attention_mask_spec, + pos_embed=pos_embed_spec, + output_attentions=output_attentions, + ) + if output_attentions: + _, layer_attentions_spec = encoder_output_spec + all_attentions_tuple_specs += (layer_attentions_spec,) + if self.data_format == "channels_last": + processed_maps_specs[enc_ind] = keras.KerasTensor( + (batch_size, h, w, self.encoder_hidden_dim), + dtype=self.compute_dtype, + ) + else: + processed_maps_specs[enc_ind] = keras.KerasTensor( + (batch_size, self.encoder_hidden_dim, h, w), + dtype=self.compute_dtype, + ) + processed_hidden_states_specs = [] + for i in range(len(hidden_states_specs)): + if i in processed_maps_specs: + processed_hidden_states_specs.append(processed_maps_specs[i]) + else: + processed_hidden_states_specs.append(hidden_states_specs[i]) + if self.num_encoder_layers > 0: + if output_hidden_states: + encoder_states_tuple_specs += ( + self.identity( + processed_hidden_states_specs[ + self.encode_proj_layers[-1] + ] + ), + ) + else: + processed_hidden_states_specs = hidden_states_specs + fpn_inter_outputs_specs = [] + y_spec = processed_hidden_states_specs[-1] + for idx, (lateral_conv, fpn_block) in enumerate( + zip(self.lateral_convs, self.fpn_blocks) + ): + backbone_feature_map_k_spec = processed_hidden_states_specs[ + self.num_fpn_stages - idx - 1 + ] + y_lateral_spec = keras.KerasTensor( + lateral_conv.compute_output_shape(y_spec.shape), + dtype=self.compute_dtype, + ) + fpn_inter_outputs_specs.append(y_lateral_spec) + y_upsampled_spec = keras.KerasTensor( + self.upsample.compute_output_shape(y_lateral_spec.shape), + dtype=self.compute_dtype, + ) + concat_shape = list(y_upsampled_spec.shape) + concat_shape[self.channel_axis] += ( + backbone_feature_map_k_spec.shape[self.channel_axis] + ) + y_spec = keras.KerasTensor( + fpn_block.compute_output_shape(tuple(concat_shape)), + dtype=self.compute_dtype, + ) + fpn_feature_maps_specs = fpn_inter_outputs_specs + [y_spec] + 
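+ # Retrace the PAN pathway of `call()` purely with `compute_output_shape()`,
+ # so output specs are derived without materializing any tensors.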
fpn_feature_maps_specs = fpn_feature_maps_specs[::-1] + pan_feature_maps_specs = [fpn_feature_maps_specs[0]] + for idx, (downsample_conv, pan_block) in enumerate( + zip(self.downsample_convs, self.pan_blocks) + ): + top_pan_feature_map_k_spec = pan_feature_maps_specs[-1] + fpn_feature_map_k_spec = fpn_feature_maps_specs[idx + 1] + downsampled_feature_map_k_spec = keras.KerasTensor( + downsample_conv.compute_output_shape( + top_pan_feature_map_k_spec.shape + ), + dtype=self.compute_dtype, + ) + concat_shape = list(downsampled_feature_map_k_spec.shape) + concat_shape[self.channel_axis] += fpn_feature_map_k_spec.shape[ + self.channel_axis + ] + new_pan_feature_map_k_spec = keras.KerasTensor( + pan_block.compute_output_shape(tuple(concat_shape)), + dtype=self.compute_dtype, + ) + pan_feature_maps_specs.append(new_pan_feature_map_k_spec) + outputs = [ + tuple(pan_feature_maps_specs), + ] + if output_hidden_states: + outputs.append(encoder_states_tuple_specs) + if output_attentions: + outputs.append(all_attentions_tuple_specs) + return tuple(outputs) if len(outputs) > 1 else outputs[0] diff --git a/keras_hub/src/models/d_fine/d_fine_image_converter.py b/keras_hub/src/models/d_fine/d_fine_image_converter.py new file mode 100644 index 0000000000..7f17e3411c --- /dev/null +++ b/keras_hub/src/models/d_fine/d_fine_image_converter.py @@ -0,0 +1,8 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter +from keras_hub.src.models.d_fine.d_fine_backbone import DFineBackbone + + +@keras_hub_export("keras_hub.layers.DFineImageConverter") +class DFineImageConverter(ImageConverter): + backbone_cls = DFineBackbone diff --git a/keras_hub/src/models/d_fine/d_fine_layers.py b/keras_hub/src/models/d_fine/d_fine_layers.py new file mode 100644 index 0000000000..0a843b43ed --- /dev/null +++ b/keras_hub/src/models/d_fine/d_fine_layers.py @@ -0,0 +1,1828 @@ +import keras +import numpy as np + +from keras_hub.src.models.d_fine.d_fine_utils import inverse_sigmoid + + +class DFineGate(keras.layers.Layer): + """Gating layer for combining two input tensors using learnable gates. + + This layer is used within the `DFineDecoderLayer` to merge the output of + the self-attention mechanism (residual) with the output of the + cross-attention mechanism (`hidden_states`). It computes a weighted sum of + the two inputs, where the weights are learned gates. The result is + normalized using layer normalization. + + Args: + hidden_dim: int, The hidden dimension size for the gate computation. + **kwargs: Additional keyword arguments passed to the parent class. 
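As an illustration of how `DFineGate` combines the two attention branches, here is a small eager sketch (editorial addition, not part of this diff; the import path follows the file added above and all shapes are invented):

import numpy as np

from keras_hub.src.models.d_fine.d_fine_layers import DFineGate

gate = DFineGate(hidden_dim=256)
residual = np.random.rand(2, 300, 256).astype("float32")  # self-attention output
cross_out = np.random.rand(2, 300, 256).astype("float32")  # cross-attention output
merged = gate(residual, cross_out)  # sigmoid-gated sum, then layer normalization
print(merged.shape)  # (2, 300, 256)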
+ """ + + def __init__(self, hidden_dim, dtype=None, **kwargs): + super().__init__(dtype=dtype, **kwargs) + self.hidden_dim = hidden_dim + self.norm = keras.layers.LayerNormalization( + epsilon=1e-5, name="norm", dtype=self.dtype_policy + ) + self.gate = keras.layers.Dense( + 2 * self.hidden_dim, + name="gate", + dtype=self.dtype_policy, + kernel_initializer="zeros", + bias_initializer="zeros", + ) + + def build(self, input_shape): + batch_dim, seq_len_dim = None, None + if input_shape and len(input_shape) == 3: + batch_dim = input_shape[0] + seq_len_dim = input_shape[1] + gate_build_shape = (batch_dim, seq_len_dim, 2 * self.hidden_dim) + self.gate.build(gate_build_shape) + norm_build_shape = (batch_dim, seq_len_dim, self.hidden_dim) + self.norm.build(norm_build_shape) + super().build(input_shape) + + def call(self, second_residual, hidden_states, training=None): + gate_input = keras.ops.concatenate( + [second_residual, hidden_states], axis=-1 + ) + gates_linear_output = self.gate(gate_input) + gates = keras.ops.sigmoid(gates_linear_output) + gate_chunks = keras.ops.split(gates, 2, axis=-1) + gate1 = gate_chunks[0] + gate2 = gate_chunks[1] + gated_sum = gate1 * second_residual + gate2 * hidden_states + hidden_states = self.norm(gated_sum, training=training) + return hidden_states + + def get_config(self): + config = super().get_config() + config.update({"hidden_dim": self.hidden_dim}) + return config + + +class DFineMLP(keras.layers.Layer): + """Multi-layer perceptron (MLP) layer. + + This layer implements a standard MLP. It is used in several places within + the D-FINE model, such as the `reg_conf` head inside `DFineLQE` for + predicting quality scores and the `pre_bbox_head` in `DFineDecoder` for + initial bounding box predictions. + + Args: + input_dim: int, The input dimension. + hidden_dim: int, The hidden dimension for intermediate layers. + output_dim: int, The output dimension. + num_layers: int, The number of layers in the MLP. + activation_function: str, The activation function to use between layers. + kernel_initializer: str or Initializer, optional, Initializer for + the kernel weights. Defaults to `"glorot_uniform"`. + bias_initializer: str or Initializer, optional, Initializer for + the bias weights. Defaults to `"zeros"`. + last_layer_initializer: str or Initializer, optional, Special + initializer for the final layer's weights and biases. If `None`, + uses `kernel_initializer` and `bias_initializer`. Defaults to + `None`. + **kwargs: Additional keyword arguments passed to the parent class. + """ + + def __init__( + self, + input_dim, + hidden_dim, + output_dim, + num_layers, + activation_function="relu", + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + last_layer_initializer=None, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.num_layers = num_layers + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.output_dim = output_dim + self.activation_function = activation_function + self.kernel_initializer = keras.initializers.get(kernel_initializer) + self.bias_initializer = keras.initializers.get(bias_initializer) + # NOTE: In the original code, this is done by searching the modules for + # specific last layers, instead, we find the last layer in each of the + # specific modules with `num_layers - 1`. 
+ self.last_layer_initializer = keras.initializers.get( + last_layer_initializer + ) + h = [hidden_dim] * (num_layers - 1) + input_dims = [input_dim] + h + output_dims = h + [output_dim] + self.dense_layers = [] + for i, (_, out_dim) in enumerate(zip(input_dims, output_dims)): + # NOTE: Req. for handling the case of initializing the final layers' + # weights and biases to zero when required (for ex: `bbox_embed` or + # `reg_conf`). + is_last_layer = i == num_layers - 1 + current_kernel_init = self.kernel_initializer + current_bias_init = self.bias_initializer + if is_last_layer and self.last_layer_initializer is not None: + current_kernel_init = self.last_layer_initializer + current_bias_init = self.last_layer_initializer + self.dense_layers.append( + keras.layers.Dense( + units=out_dim, + name=f"mlp_dense_layer_{i}", + dtype=self.dtype_policy, + kernel_initializer=current_kernel_init, + bias_initializer=current_bias_init, + ) + ) + self.activation_layer = keras.layers.Activation( + activation_function, + name="mlp_activation_layer", + dtype=self.dtype_policy, + ) + + def build(self, input_shape): + if self.dense_layers: + current_build_shape = input_shape + for i, dense_layer in enumerate(self.dense_layers): + dense_layer.build(current_build_shape) + current_build_shape = dense_layer.compute_output_shape( + current_build_shape + ) + super().build(input_shape) + + def call(self, stat_features, training=None): + x = stat_features + for i in range(self.num_layers): + dense_layer = self.dense_layers[i] + x = dense_layer(x) + if i < self.num_layers - 1: + x = self.activation_layer(x) + return x + + def compute_output_spec(self, stat_features_spec): + output_shape = list(stat_features_spec.shape) + output_shape[-1] = self.output_dim + return keras.KerasTensor( + shape=tuple(output_shape), dtype=self.compute_dtype + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "input_dim": self.input_dim, + "hidden_dim": self.hidden_dim, + "output_dim": self.output_dim, + "num_layers": self.num_layers, + "activation_function": self.activation_function, + "kernel_initializer": keras.initializers.serialize( + self.kernel_initializer + ), + "bias_initializer": keras.initializers.serialize( + self.bias_initializer + ), + "last_layer_initializer": keras.initializers.serialize( + self.last_layer_initializer + ), + } + ) + return config + + +class DFineSourceFlattener(keras.layers.Layer): + """Layer to flatten and concatenate a list of source tensors. + + This layer is used in `DFineBackbone` to process feature maps from the + `DFineHybridEncoder`. It takes a list of multi-scale feature maps, + flattens each along its spatial dimensions, and concatenates them + along the sequence dimension. + + Args: + channel_axis: int, optional, The channel axis. Defaults to `None`. + data_format: str, optional, The data format. Defaults to `None`. + **kwargs: Additional keyword arguments passed to the parent class. 
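For reference, a tiny eager sketch of the flattening behavior (editorial, not from this diff; shapes are assumed stand-ins for two pyramid levels):

import numpy as np

from keras_hub.src.models.d_fine.d_fine_layers import DFineSourceFlattener

flattener = DFineSourceFlattener(channel_axis=-1, data_format="channels_last")
p4 = np.zeros((2, 32, 32, 256), dtype="float32")  # stride-16 feature map
p5 = np.zeros((2, 16, 16, 256), dtype="float32")  # stride-32 feature map
flat = flattener([p4, p5])  # each map becomes (batch, H * W, C), then concatenated
print(flat.shape)  # (2, 1280, 256), i.e. 32 * 32 + 16 * 16 locations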
+ """ + + def __init__( + self, channel_axis=None, data_format=None, dtype=None, **kwargs + ): + super().__init__(dtype=dtype, **kwargs) + self.channel_axis = channel_axis + self.data_format = data_format + + def call(self, sources, training=None): + source_flatten = [] + for i, source_item in enumerate(sources): + if self.data_format == "channels_first": + source_item = keras.ops.transpose(source_item, [0, 2, 3, 1]) + batch_size = keras.ops.shape(source_item)[0] + channels = keras.ops.shape(source_item)[-1] + source_reshaped = keras.ops.reshape( + source_item, (batch_size, -1, channels) + ) + source_flatten.append(source_reshaped) + source_flatten_concatenated = keras.ops.concatenate( + source_flatten, axis=1 + ) + return source_flatten_concatenated + + def compute_output_shape(self, sources_shape): + if not sources_shape or not isinstance(sources_shape, list): + return tuple() + if not all(isinstance(s, tuple) and len(s) == 4 for s in sources_shape): + return tuple() + batch_size = sources_shape[0][0] + if self.data_format == "channels_first": + channels = sources_shape[0][1] + else: + channels = sources_shape[0][-1] + calculated_spatial_elements = [] + for s_shape in sources_shape: + if self.data_format == "channels_first": + h, w = s_shape[2], s_shape[3] + else: + h, w = s_shape[1], s_shape[2] + if h is None or w is None: + calculated_spatial_elements.append(None) + else: + calculated_spatial_elements.append(h * w) + if any(elem is None for elem in calculated_spatial_elements): + total_spatial_elements = None + else: + total_spatial_elements = sum(calculated_spatial_elements) + return (batch_size, total_spatial_elements, channels) + + def get_config(self): + config = super().get_config() + config.update( + { + "channel_axis": self.channel_axis, + "data_format": self.data_format, + } + ) + return config + + +class DFineContrastiveDenoisingGroupGenerator(keras.layers.Layer): + """Layer to generate denoising groups for contrastive learning. + + This layer, used in `DFineBackbone`, implements the core logic for + contrastive denoising, a key training strategy in D-FINE. It takes ground + truth `targets`, adds controlled noise to labels and boxes, and generates + the necessary attention masks, queries, and reference points for the + decoder. Due to functional model constraints, noise is generated once at + model initialization. + + Args: + num_labels: int, The number of object classes. + num_denoising: int, The number of denoising queries. + label_noise_ratio: float, The ratio of label noise to apply. + box_noise_scale: float, The scale of box noise to apply. + seed: int, optional, The random seed for noise generation. + **kwargs: Additional keyword arguments passed to the parent class. 
+ """ + + def __init__( + self, + num_labels, + num_denoising, + label_noise_ratio, + box_noise_scale, + seed=None, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.num_labels = num_labels + self.num_denoising = num_denoising + self.label_noise_ratio = label_noise_ratio + self.box_noise_scale = box_noise_scale + self.seed = seed + self.seed_generator = keras.random.SeedGenerator(seed) + + def build(self, input_shape): + super().build(input_shape) + + def call(self, targets, num_queries): + if self.num_denoising <= 0: + return None, None, None, None + num_ground_truths = [len(t["labels"]) for t in targets] + max_gt_num = 0 + if num_ground_truths: + max_gt_num = max(num_ground_truths) + if max_gt_num == 0: + return None, None, None, None + num_groups_denoising_queries = self.num_denoising // max_gt_num + num_groups_denoising_queries = ( + 1 + if num_groups_denoising_queries == 0 + else num_groups_denoising_queries + ) + batch_size = len(num_ground_truths) + input_query_class = [] + input_query_bbox = [] + pad_gt_mask = [] + for i in range(batch_size): + num_gt = num_ground_truths[i] + if num_gt > 0: + labels = targets[i]["labels"] + boxes = targets[i]["boxes"] + padded_class_labels = keras.ops.pad( + labels, + [[0, max_gt_num - num_gt]], + constant_values=self.num_labels, + ) + padded_boxes = keras.ops.pad( + keras.ops.cast(boxes, dtype=self.compute_dtype), + [[0, max_gt_num - num_gt], [0, 0]], + constant_values=0.0, + ) + mask = keras.ops.concatenate( + [ + keras.ops.ones([num_gt], dtype="bool"), + keras.ops.zeros([max_gt_num - num_gt], dtype="bool"), + ] + ) + else: + padded_class_labels = keras.ops.full( + [max_gt_num], self.num_labels, dtype="int32" + ) + padded_boxes = keras.ops.zeros( + [max_gt_num, 4], dtype=self.compute_dtype + ) + mask = keras.ops.zeros([max_gt_num], dtype="bool") + input_query_class.append(padded_class_labels) + input_query_bbox.append(padded_boxes) + pad_gt_mask.append(mask) + input_query_class = keras.ops.stack(input_query_class, axis=0) + input_query_bbox = keras.ops.stack(input_query_bbox, axis=0) + pad_gt_mask = keras.ops.stack(pad_gt_mask, axis=0) + input_query_class = keras.ops.tile( + input_query_class, [1, 2 * num_groups_denoising_queries] + ) + input_query_bbox = keras.ops.tile( + input_query_bbox, [1, 2 * num_groups_denoising_queries, 1] + ) + pad_gt_mask = keras.ops.tile( + pad_gt_mask, [1, 2 * num_groups_denoising_queries] + ) + negative_gt_mask = keras.ops.zeros( + [batch_size, max_gt_num * 2, 1], dtype=self.compute_dtype + ) + updates_neg = keras.ops.ones( + [batch_size, max_gt_num, 1], dtype=negative_gt_mask.dtype + ) + negative_gt_mask = keras.ops.slice_update( + negative_gt_mask, [0, max_gt_num, 0], updates_neg + ) + negative_gt_mask = keras.ops.tile( + negative_gt_mask, [1, num_groups_denoising_queries, 1] + ) + positive_gt_mask_float = 1.0 - negative_gt_mask + squeezed_positive_gt_mask = keras.ops.squeeze( + positive_gt_mask_float, axis=-1 + ) + positive_gt_mask = squeezed_positive_gt_mask * keras.ops.cast( + pad_gt_mask, dtype=squeezed_positive_gt_mask.dtype + ) + denoise_positive_idx = [] + for i in range(batch_size): + mask_i = positive_gt_mask[i] + idx = keras.ops.nonzero(mask_i)[0] + denoise_positive_idx.append(idx) + if self.label_noise_ratio > 0: + noise_mask = keras.random.uniform( + keras.ops.shape(input_query_class), + dtype=self.compute_dtype, + seed=self.seed_generator, + ) < (self.label_noise_ratio * 0.5) + max_len = 0 + for idx in denoise_positive_idx: + current_len = keras.ops.shape(idx)[0] + if 
current_len > max_len: + max_len = current_len + padded_indices = [] + for idx in denoise_positive_idx: + current_len = keras.ops.shape(idx)[0] + pad_len = max_len - current_len + padded = keras.ops.pad(idx, [[0, pad_len]], constant_values=-1) + padded_indices.append(padded) + dn_positive_idx = ( + keras.ops.stack(padded_indices, axis=0) if padded_indices else None + ) + if self.label_noise_ratio > 0: + noise_mask = keras.ops.cast(noise_mask, "bool") + new_label = keras.random.randint( + keras.ops.shape(input_query_class), + 0, + self.num_labels, + seed=self.seed_generator, + dtype="int32", + ) + input_query_class = keras.ops.where( + noise_mask & pad_gt_mask, + new_label, + input_query_class, + ) + if self.box_noise_scale > 0: + known_bbox = keras.utils.bounding_boxes.convert_format( + input_query_bbox, + source="center_xywh", + target="xyxy", + dtype=self.compute_dtype, + ) + width_height = input_query_bbox[..., 2:] + diff = ( + keras.ops.tile(width_height, [1, 1, 2]) + * 0.5 + * self.box_noise_scale + ) + rand_int_sign = keras.random.randint( + keras.ops.shape(input_query_bbox), + 0, + 2, + seed=self.seed_generator, + ) + rand_sign = ( + keras.ops.cast(rand_int_sign, dtype=diff.dtype) * 2.0 - 1.0 + ) + rand_part = keras.random.uniform( + keras.ops.shape(input_query_bbox), + seed=self.seed_generator, + dtype=self.compute_dtype, + ) + rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * ( + 1 - negative_gt_mask + ) + rand_part = rand_part * rand_sign + known_bbox = known_bbox + rand_part * diff + known_bbox = keras.ops.clip(known_bbox, 0.0, 1.0) + input_query_bbox = keras.utils.bounding_boxes.convert_format( + known_bbox, + source="xyxy", + target="center_xywh", + dtype=self.compute_dtype, + ) + input_query_bbox = inverse_sigmoid(input_query_bbox) + num_denoising_total = max_gt_num * 2 * num_groups_denoising_queries + target_size = num_denoising_total + num_queries + attn_mask = keras.ops.zeros( + [target_size, target_size], dtype=self.compute_dtype + ) + updates_attn1 = keras.ops.ones( + [ + target_size - num_denoising_total, + num_denoising_total, + ], + dtype=attn_mask.dtype, + ) + attn_mask = keras.ops.slice_update( + attn_mask, [num_denoising_total, 0], updates_attn1 + ) + for i in range(num_groups_denoising_queries): + start = max_gt_num * 2 * i + end = max_gt_num * 2 * (i + 1) + updates_attn2 = keras.ops.ones( + [end - start, start], dtype=attn_mask.dtype + ) + attn_mask = keras.ops.slice_update( + attn_mask, [start, 0], updates_attn2 + ) + updates_attn3 = keras.ops.ones( + [end - start, num_denoising_total - end], + dtype=attn_mask.dtype, + ) + attn_mask = keras.ops.slice_update( + attn_mask, [start, end], updates_attn3 + ) + if dn_positive_idx is not None: + denoising_meta_values = { + "dn_positive_idx": dn_positive_idx, + "dn_num_group": keras.ops.convert_to_tensor( + num_groups_denoising_queries, dtype="int32" + ), + "dn_num_split": keras.ops.convert_to_tensor( + [num_denoising_total, num_queries], dtype="int32" + ), + } + return ( + input_query_class, + input_query_bbox, + attn_mask, + denoising_meta_values, + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "num_labels": self.num_labels, + "num_denoising": self.num_denoising, + "label_noise_ratio": self.label_noise_ratio, + "box_noise_scale": self.box_noise_scale, + "seed": self.seed, + } + ) + return config + + +class DFineAnchorGenerator(keras.layers.Layer): + """Layer to generate anchor boxes for object detection. 
+ + This layer is used in `DFineBackbone` to generate anchor proposals. These + anchors are combined with the output of the encoder's bounding box head + (`enc_bbox_head`) to create initial reference points for the decoder's + queries. + + Args: + anchor_image_size: tuple, The size of the input image `(height, width)`. + feat_strides: list, The strides of the feature maps. + data_format: str, The data format of the image channels. Can be either + `"channels_first"` or `"channels_last"`. If `None` is specified, + it will use the `image_data_format` value found in your Keras + config file at `~/.keras/keras.json`. Defaults to `None`. + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the model's computations and weights. Defaults to `None`. + **kwargs: Additional keyword arguments passed to the parent class. + """ + + def __init__( + self, + anchor_image_size, + feat_strides, + data_format=None, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.anchor_image_size = anchor_image_size + self.feat_strides = feat_strides + self.data_format = data_format + + def call(self, sources_for_shape_derivation=None, grid_size=0.05): + spatial_shapes = None + if sources_for_shape_derivation is not None: + if self.data_format == "channels_first": + spatial_shapes = [ + (keras.ops.shape(s)[2], keras.ops.shape(s)[3]) + for s in sources_for_shape_derivation + ] + else: + spatial_shapes = [ + (keras.ops.shape(s)[1], keras.ops.shape(s)[2]) + for s in sources_for_shape_derivation + ] + + if spatial_shapes is None: + spatial_shapes = [ + ( + keras.ops.cast(self.anchor_image_size[0] / s, "int32"), + keras.ops.cast(self.anchor_image_size[1] / s, "int32"), + ) + for s in self.feat_strides + ] + + anchors = [] + for level, (height, width) in enumerate(spatial_shapes): + grid_y, grid_x = keras.ops.meshgrid( + keras.ops.arange(height, dtype=self.compute_dtype), + keras.ops.arange(width, dtype=self.compute_dtype), + indexing="ij", + ) + grid_xy = keras.ops.stack([grid_x, grid_y], axis=-1) + grid_xy = keras.ops.expand_dims(grid_xy, axis=0) + 0.5 + grid_xy = grid_xy / keras.ops.array( + [width, height], dtype=self.compute_dtype + ) + wh = keras.ops.ones_like(grid_xy) * grid_size * (2.0**level) + level_anchors = keras.ops.concatenate([grid_xy, wh], axis=-1) + level_anchors = keras.ops.reshape( + level_anchors, (-1, height * width, 4) + ) + anchors.append(level_anchors) + + eps = 1e-2 + anchors = keras.ops.concatenate(anchors, axis=1) + valid_mask = keras.ops.all( + (anchors > eps) & (anchors < 1 - eps), axis=-1, keepdims=True + ) + anchors_transformed = keras.ops.log(anchors / (1 - anchors)) + dtype_name = keras.backend.standardize_dtype(self.compute_dtype) + if dtype_name == "float16": + finfo_dtype = np.float16 + else: + finfo_dtype = np.float32 + max_float = keras.ops.array( + np.finfo(finfo_dtype).max, dtype=self.compute_dtype + ) + anchors = keras.ops.where(valid_mask, anchors_transformed, max_float) + + return anchors, valid_mask + + def compute_output_shape( + self, sources_for_shape_derivation_shape=None, grid_size_shape=None + ): + num_total_anchors_dim = None + + if sources_for_shape_derivation_shape is None: + num_total_anchors_calc = 0 + for s_stride in self.feat_strides: + h = self.anchor_image_size[0] // s_stride + w = self.anchor_image_size[1] // s_stride + num_total_anchors_calc += h * w + num_total_anchors_dim = num_total_anchors_calc + else: + calculated_spatial_elements = [] + for s_shape in sources_for_shape_derivation_shape: + if 
self.data_format == "channels_first": + h, w = s_shape[2], s_shape[3] + else: + h, w = s_shape[1], s_shape[2] + if h is None or w is None: + calculated_spatial_elements.append(None) + else: + calculated_spatial_elements.append(h * w) + if any(elem is None for elem in calculated_spatial_elements): + num_total_anchors_dim = None + else: + num_total_anchors_dim = sum(calculated_spatial_elements) + + anchors_shape = (1, num_total_anchors_dim, 4) + valid_mask_shape = (1, num_total_anchors_dim, 1) + return anchors_shape, valid_mask_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "anchor_image_size": self.anchor_image_size, + "feat_strides": self.feat_strides, + "data_format": self.data_format, + } + ) + return config + + +class DFineSpatialShapesExtractor(keras.layers.Layer): + """Layer to extract spatial shapes from input tensors. + + This layer is used in `DFineBackbone` to extract the spatial dimensions + (height, width) from the multi-scale feature maps. The resulting shape + tensor is passed to the `DFineDecoder` for use in deformable attention. + + Args: + data_format: str, optional, The data format of the input tensors. + **kwargs: Additional keyword arguments passed to the parent class. + """ + + def __init__(self, data_format=None, dtype=None, **kwargs): + super().__init__(dtype=dtype, **kwargs) + self.data_format = data_format + + def call(self, sources): + if self.data_format == "channels_first": + spatial_shapes = [ + (keras.ops.shape(s)[2], keras.ops.shape(s)[3]) for s in sources + ] + else: + spatial_shapes = [ + (keras.ops.shape(s)[1], keras.ops.shape(s)[2]) for s in sources + ] + spatial_shapes_tensor = keras.ops.array(spatial_shapes, dtype="int32") + return spatial_shapes_tensor + + def compute_output_shape(self, input_shape): + if not isinstance(input_shape, list): + raise ValueError("Expected a list of shape tuples") + num_sources = len(input_shape) + return (num_sources, 2) + + def get_config(self): + config = super().get_config() + config.update({"data_format": self.data_format}) + return config + + +class DFineInitialQueryAndReferenceGenerator(keras.layers.Layer): + """Layer to generate initial queries and reference points for the decoder. + + This layer is a crucial component in `DFineBackbone` that bridges the + encoder and decoder. It selects the top-k predictions from the encoder's + output heads and uses them to generate the initial `target` (queries) and + `reference_points` that are fed into the `DFineDecoder`. + + Args: + num_queries: int, The number of queries to generate. + hidden_dim: int, The hidden dimension of the model. + learn_initial_query: bool, Whether to learn the initial query + embeddings. + **kwargs: Additional keyword arguments passed to the parent class. 
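A shape-level sketch of how the encoder outputs feed this layer (editorial, not part of this diff; the anchor count, class count, and hidden size are invented):

import numpy as np

from keras_hub.src.models.d_fine.d_fine_layers import (
    DFineInitialQueryAndReferenceGenerator,
)

generator = DFineInitialQueryAndReferenceGenerator(
    num_queries=300, hidden_dim=256, learn_initial_query=False
)
enc_class = np.random.rand(2, 1280, 80).astype("float32")  # per-anchor class logits
enc_coords = np.random.rand(2, 1280, 4).astype("float32")  # box logits + anchors
memory = np.random.rand(2, 1280, 256).astype("float32")  # encoder memory
last_source = np.zeros((2, 16, 16, 256), dtype="float32")  # used for batch size only
refs, target, topk_logits, topk_boxes = generator(
    [enc_class, enc_coords, memory, last_source]
)
# refs: (2, 300, 4), target: (2, 300, 256),
# topk_logits: (2, 300, 80), topk_boxes: (2, 300, 4)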
+ """ + + def __init__( + self, + num_queries, + hidden_dim, + learn_initial_query, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.num_queries = num_queries + self.hidden_dim = hidden_dim + self.learn_initial_query = learn_initial_query + if self.learn_initial_query: + self.query_indices_base = keras.ops.expand_dims( + keras.ops.arange(self.num_queries, dtype="int32"), axis=0 + ) + self.weight_embedding = keras.layers.Embedding( + input_dim=num_queries, + output_dim=hidden_dim, + name="weight_embedding", + dtype=self.dtype_policy, + embeddings_initializer="glorot_uniform", + ) + else: + self.weight_embedding = None + + def call( + self, + inputs, + denoising_bbox_unact=None, + denoising_class=None, + training=None, + ): + ( + enc_outputs_class, + enc_outputs_coord_logits_plus_anchors, + output_memory, + sources_last_element, + ) = inputs + enc_outputs_class_max = keras.ops.max(enc_outputs_class, axis=-1) + topk_ind = keras.ops.top_k( + enc_outputs_class_max, k=self.num_queries, sorted=True + )[1] + + def gather_batch(elems): + data, indices = elems + return keras.ops.take(data, indices, axis=0) + + reference_points_unact = keras.ops.map( + gather_batch, (enc_outputs_coord_logits_plus_anchors, topk_ind) + ) + enc_topk_logits = keras.ops.map( + gather_batch, (enc_outputs_class, topk_ind) + ) + enc_topk_bboxes = keras.ops.sigmoid(reference_points_unact) + + if denoising_bbox_unact is not None: + current_batch_size = keras.ops.shape(reference_points_unact)[0] + denoising_bbox_unact = denoising_bbox_unact[:current_batch_size] + if denoising_class is not None: + denoising_class = denoising_class[:current_batch_size] + reference_points_unact = keras.ops.concatenate( + [denoising_bbox_unact, reference_points_unact], axis=1 + ) + if self.learn_initial_query: + query_indices = self.query_indices_base + target_embedding_val = self.weight_embedding( + query_indices, training=training + ) + batch_size = keras.ops.shape(sources_last_element)[0] + target = keras.ops.tile(target_embedding_val, [batch_size, 1, 1]) + else: + target = keras.ops.map(gather_batch, (output_memory, topk_ind)) + target = keras.ops.stop_gradient(target) + + if denoising_class is not None: + target = keras.ops.concatenate([denoising_class, target], axis=1) + init_reference_points = keras.ops.stop_gradient(reference_points_unact) + return init_reference_points, target, enc_topk_logits, enc_topk_bboxes + + def get_config(self): + config = super().get_config() + config.update( + { + "num_queries": self.num_queries, + "hidden_dim": self.hidden_dim, + "learn_initial_query": self.learn_initial_query, + } + ) + return config + + def compute_output_spec( + self, + inputs, + denoising_bbox_unact=None, + denoising_class=None, + training=None, + ): + ( + enc_outputs_class_spec, + _, + output_memory_spec, + _, + ) = inputs + batch_size = enc_outputs_class_spec.shape[0] + d_model_dim = output_memory_spec.shape[-1] + num_labels_dim = enc_outputs_class_spec.shape[-1] + num_queries_for_ref_points = self.num_queries + if denoising_bbox_unact is not None: + if len(denoising_bbox_unact.shape) > 1: + if denoising_bbox_unact.shape[1] is not None: + num_queries_for_ref_points = ( + denoising_bbox_unact.shape[1] + self.num_queries + ) + else: + num_queries_for_ref_points = None + num_queries_for_target = self.num_queries + if denoising_class is not None: + if len(denoising_class.shape) > 1: + if denoising_class.shape[1] is not None: + num_queries_for_target = ( + denoising_class.shape[1] + self.num_queries + ) + else: + 
num_queries_for_target = None + init_reference_points_spec = keras.KerasTensor( + shape=(batch_size, num_queries_for_ref_points, 4), + dtype=self.compute_dtype, + ) + target_spec = keras.KerasTensor( + shape=(batch_size, num_queries_for_target, d_model_dim), + dtype=self.compute_dtype, + ) + enc_topk_logits_spec = keras.KerasTensor( + shape=(batch_size, self.num_queries, num_labels_dim), + dtype=self.compute_dtype, + ) + enc_topk_bboxes_spec = keras.KerasTensor( + shape=(batch_size, self.num_queries, 4), dtype=self.compute_dtype + ) + + return ( + init_reference_points_spec, + target_spec, + enc_topk_logits_spec, + enc_topk_bboxes_spec, + ) + + +class DFineIntegral(keras.layers.Layer): + """Layer to compute integrated values from predicted corner probabilities. + + This layer implements the integral regression technique for bounding box + prediction. It is used in `DFineDecoder` to transform the predicted + distribution over bins (from `bbox_embed`) into continuous distance values, + which are then used to calculate the final box coordinates. + + Args: + max_num_bins: int, The maximum number of bins for the predictions. + **kwargs: Additional keyword arguments passed to the parent class. + """ + + def __init__(self, max_num_bins, dtype=None, **kwargs): + super().__init__(dtype=dtype, **kwargs) + self.max_num_bins = max_num_bins + + def build(self, input_shape): + super().build(input_shape) + + def call(self, pred_corners, project, training=None): + original_shape = keras.ops.shape(pred_corners) + batch_size = original_shape[0] + num_queries = original_shape[1] + reshaped_pred_corners = keras.ops.reshape( + pred_corners, (-1, self.max_num_bins + 1) + ) + softmax_output = keras.ops.softmax(reshaped_pred_corners, axis=1) + linear_output = keras.ops.matmul( + softmax_output, keras.ops.transpose(project) + ) + squeezed_output = keras.ops.squeeze(linear_output, axis=-1) + output_grouped_by_4 = keras.ops.reshape(squeezed_output, (-1, 4)) + final_output = keras.ops.reshape( + output_grouped_by_4, (batch_size, num_queries, -1) + ) + return final_output + + def get_config(self): + config = super().get_config() + config.update( + { + "max_num_bins": self.max_num_bins, + } + ) + return config + + +class DFineLQE(keras.layers.Layer): + """Layer to compute quality scores for predictions. + + This layer, used within `DFineDecoder`, implements the Localization Quality + Estimation (LQE) head. It computes a quality score from the distribution of + predicted bounding box corners and adds this score to the classification + logits, enhancing prediction confidence. + + Args: + top_prob_values: int, The number of top probabilities to consider. + max_num_bins: int, The maximum number of bins for the predictions. + lqe_hidden_dim: int, The hidden dimension for the MLP. + num_lqe_layers: int, The number of layers in the MLP. + **kwargs: Additional keyword arguments passed to the parent class. 
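An eager sketch of the LQE head (editorial, not part of this diff). Its `build` reads the shapes of both inputs, so the sketch builds the layer explicitly before calling it; all sizes are invented:

import numpy as np

from keras_hub.src.models.d_fine.d_fine_layers import DFineLQE

lqe = DFineLQE(
    top_prob_values=4, max_num_bins=32, lqe_hidden_dim=64, num_lqe_layers=2
)
scores = np.random.rand(2, 300, 80).astype("float32")  # class logits
corners = np.random.rand(2, 300, 4 * 33).astype("float32")  # 4 sides x (bins + 1)
lqe.build([scores.shape, corners.shape])
adjusted = lqe(scores, corners)  # logits plus the predicted quality score
print(adjusted.shape)  # (2, 300, 80)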
+ """ + + def __init__( + self, + top_prob_values, + max_num_bins, + lqe_hidden_dim, + num_lqe_layers, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.top_prob_values = top_prob_values + self.max_num_bins = max_num_bins + self.reg_conf = DFineMLP( + input_dim=4 * (self.top_prob_values + 1), + hidden_dim=lqe_hidden_dim, + output_dim=1, + num_layers=num_lqe_layers, + dtype=self.dtype_policy, + last_layer_initializer="zeros", + name="reg_conf", + ) + + def build(self, input_shape): + reg_conf_input_shape = ( + input_shape[0][0], + input_shape[0][1], + 4 * (self.top_prob_values + 1), + ) + self.reg_conf.build(reg_conf_input_shape) + super().build(input_shape) + + def call(self, scores, pred_corners, training=None): + original_shape = keras.ops.shape(pred_corners) + batch_size = original_shape[0] + length = original_shape[1] + reshaped_pred_corners = keras.ops.reshape( + pred_corners, (batch_size, length, 4, self.max_num_bins + 1) + ) + prob = keras.ops.softmax(reshaped_pred_corners, axis=-1) + prob_topk, _ = keras.ops.top_k( + prob, k=self.top_prob_values, sorted=True + ) + stat = keras.ops.concatenate( + [prob_topk, keras.ops.mean(prob_topk, axis=-1, keepdims=True)], + axis=-1, + ) + reshaped_stat = keras.ops.reshape(stat, (batch_size, length, -1)) + quality_score = self.reg_conf(reshaped_stat, training=training) + return scores + quality_score + + def get_config(self): + config = super().get_config() + config.update( + { + "top_prob_values": self.top_prob_values, + "max_num_bins": self.max_num_bins, + "lqe_hidden_dim": self.reg_conf.hidden_dim, + "num_lqe_layers": self.reg_conf.num_layers, + } + ) + return config + + +class DFineConvNormLayer(keras.layers.Layer): + """Convolutional layer with normalization and optional activation. + + This is a fundamental building block used in the CNN parts of D-FINE. It + combines a `Conv2D` layer with `BatchNormalization` and an optional + activation. It is used extensively in layers like `DFineRepVggBlock`, + `DFineCSPRepLayer`, and within the `DFineHybridEncoder`. + + Args: + filters: int, The number of output channels. + kernel_size: int, The size of the convolutional kernel. + batch_norm_eps: float, The epsilon value for batch normalization. + stride: int, The stride of the convolution. + groups: int, The number of groups for grouped convolution. + padding: int or None, The padding to apply. + activation_function: str or None, The activation function to use. + kernel_initializer: str or Initializer, optional, Initializer for + the kernel weights. Defaults to `"glorot_uniform"`. + bias_initializer: str or Initializer, optional, Initializer for + the bias weights. Defaults to `"zeros"`. + channel_axis: int, optional, The channel axis. Defaults to `None`. + **kwargs: Additional keyword arguments passed to the parent class. 
+ """ + + def __init__( + self, + filters, + kernel_size, + batch_norm_eps, + stride, + groups, + padding, + activation_function, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + channel_axis=None, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.filters = filters + self.kernel_size = kernel_size + self.batch_norm_eps = batch_norm_eps + self.stride = stride + self.groups = groups + self.padding_arg = padding + self.activation_function = activation_function + self.kernel_initializer = keras.initializers.get(kernel_initializer) + self.bias_initializer = keras.initializers.get(bias_initializer) + self.channel_axis = channel_axis + if self.padding_arg is None: + keras_conv_padding_mode = "same" + self.explicit_padding_layer = None + else: + keras_conv_padding_mode = "valid" + self.explicit_padding_layer = keras.layers.ZeroPadding2D( + padding=self.padding_arg, + name=f"{self.name}_explicit_padding", + dtype=self.dtype_policy, + ) + + self.convolution = keras.layers.Conv2D( + filters=self.filters, + kernel_size=self.kernel_size, + strides=self.stride, + padding=keras_conv_padding_mode, + groups=self.groups, + use_bias=False, + dtype=self.dtype_policy, + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer, + name=f"{self.name}_convolution", + ) + self.normalization = keras.layers.BatchNormalization( + epsilon=self.batch_norm_eps, + name=f"{self.name}_normalization", + axis=self.channel_axis, + dtype=self.dtype_policy, + ) + self.activation_layer = ( + keras.layers.Activation( + self.activation_function, + name=f"{self.name}_activation", + dtype=self.dtype_policy, + ) + if self.activation_function + else keras.layers.Identity( + name=f"{self.name}_identity_activation", dtype=self.dtype_policy + ) + ) + + def build(self, input_shape): + if self.explicit_padding_layer: + self.explicit_padding_layer.build(input_shape) + shape = self.explicit_padding_layer.compute_output_shape( + input_shape + ) + else: + shape = input_shape + self.convolution.build(shape) + conv_output_shape = self.convolution.compute_output_shape(shape) + self.normalization.build(conv_output_shape) + self.activation_layer.build(conv_output_shape) + super().build(input_shape) + + def call(self, hidden_state, training=None): + if self.explicit_padding_layer: + hidden_state = self.explicit_padding_layer(hidden_state) + hidden_state = self.convolution(hidden_state) + hidden_state = self.normalization(hidden_state, training=training) + hidden_state = self.activation_layer(hidden_state) + return hidden_state + + def compute_output_shape(self, input_shape): + shape = input_shape + if self.explicit_padding_layer: + shape = self.explicit_padding_layer.compute_output_shape(shape) + return self.convolution.compute_output_shape(shape) + + def get_config(self): + config = super().get_config() + config.update( + { + "filters": self.filters, + "kernel_size": self.kernel_size, + "batch_norm_eps": self.batch_norm_eps, + "stride": self.stride, + "groups": self.groups, + "padding": self.padding_arg, + "activation_function": self.activation_function, + "kernel_initializer": keras.initializers.serialize( + self.kernel_initializer + ), + "bias_initializer": keras.initializers.serialize( + self.bias_initializer + ), + "channel_axis": self.channel_axis, + } + ) + return config + + +class DFineRepVggBlock(keras.layers.Layer): + """RepVGG-style block with two parallel convolutional paths. 
+ + This layer implements a block inspired by the RepVGG architecture, featuring + two parallel convolutional paths (3x3 and 1x1) that are summed. It serves + as the core bottleneck block within the `DFineCSPRepLayer`. + + Args: + activation_function: str, The activation function to use. + filters: int, The number of output channels. + batch_norm_eps: float, The epsilon value for batch normalization. + kernel_initializer: str or Initializer, optional, Initializer for + the kernel weights. Defaults to `"glorot_uniform"`. + bias_initializer: str or Initializer, optional, Initializer for + the bias weights. Defaults to `"zeros"`. + channel_axis: int, optional, The channel axis. Defaults to `None`. + **kwargs: Additional keyword arguments passed to the parent class. + """ + + def __init__( + self, + activation_function, + filters, + batch_norm_eps=1e-5, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + channel_axis=None, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.activation_function = activation_function + self.filters = filters + self.batch_norm_eps = batch_norm_eps + self.kernel_initializer = keras.initializers.get(kernel_initializer) + self.bias_initializer = keras.initializers.get(bias_initializer) + self.channel_axis = channel_axis + self.conv1_layer = DFineConvNormLayer( + filters=self.filters, + kernel_size=3, + batch_norm_eps=self.batch_norm_eps, + stride=1, + groups=1, + padding=1, + activation_function=None, + dtype=self.dtype_policy, + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer, + channel_axis=self.channel_axis, + name="conv1", + ) + self.conv2_layer = DFineConvNormLayer( + filters=self.filters, + kernel_size=1, + batch_norm_eps=self.batch_norm_eps, + stride=1, + groups=1, + padding=0, + activation_function=None, + dtype=self.dtype_policy, + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer, + channel_axis=self.channel_axis, + name="conv2", + ) + self.activation_layer = ( + keras.layers.Activation( + self.activation_function, + name="block_activation", + dtype=self.dtype_policy, + ) + if self.activation_function + else keras.layers.Identity( + name="identity_activation", dtype=self.dtype_policy + ) + ) + + def build(self, input_shape): + self.conv1_layer.build(input_shape) + self.conv2_layer.build(input_shape) + self.activation_layer.build(input_shape) + super().build(input_shape) + + def call(self, x, training=None): + y1 = self.conv1_layer(x, training=training) + y2 = self.conv2_layer(x, training=training) + y = y1 + y2 + return self.activation_layer(y) + + def compute_output_shape(self, input_shape): + return self.conv1_layer.compute_output_shape(input_shape) + + def get_config(self): + config = super().get_config() + config.update( + { + "activation_function": self.activation_function, + "filters": self.filters, + "batch_norm_eps": self.batch_norm_eps, + "kernel_initializer": keras.initializers.serialize( + self.kernel_initializer + ), + "bias_initializer": keras.initializers.serialize( + self.bias_initializer + ), + "channel_axis": self.channel_axis, + } + ) + return config + + +class DFineCSPRepLayer(keras.layers.Layer): + """CSP (Cross Stage Partial) layer with repeated bottleneck blocks. + + This layer implements a Cross Stage Partial (CSP) block using + `DFineRepVggBlock` as its bottleneck. It is a key component of the + `DFineFeatureAggregationBlock` block, which forms the FPN/PAN structure in + the `DFineHybridEncoder`. 
+ + Args: + activation_function: str, The activation function to use. + batch_norm_eps: float, The epsilon value for batch normalization. + filters: int, The number of output channels. + num_blocks: int, The number of bottleneck blocks. + expansion: float, The expansion factor for hidden channels. Defaults to + `1.0`. + kernel_initializer: str or Initializer, optional, Initializer for + the kernel weights. Defaults to `"glorot_uniform"`. + bias_initializer: str or Initializer, optional, Initializer for + the bias weights. Defaults to `"zeros"`. + channel_axis: int, optional, The channel axis. Defaults to `None`. + **kwargs: Additional keyword arguments passed to the parent class. + """ + + def __init__( + self, + activation_function, + batch_norm_eps, + filters, + num_blocks, + expansion=1.0, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + channel_axis=None, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.activation_function = activation_function + self.batch_norm_eps = batch_norm_eps + self.filters = filters + self.num_blocks = num_blocks + self.expansion = expansion + self.kernel_initializer = keras.initializers.get(kernel_initializer) + self.bias_initializer = keras.initializers.get(bias_initializer) + self.channel_axis = channel_axis + hidden_channels = int(self.filters * self.expansion) + self.conv1 = DFineConvNormLayer( + filters=hidden_channels, + kernel_size=1, + batch_norm_eps=self.batch_norm_eps, + stride=1, + groups=1, + padding=0, + activation_function=self.activation_function, + dtype=self.dtype_policy, + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer, + channel_axis=self.channel_axis, + name="conv1", + ) + self.conv2 = DFineConvNormLayer( + filters=hidden_channels, + kernel_size=1, + batch_norm_eps=self.batch_norm_eps, + stride=1, + groups=1, + padding=0, + activation_function=self.activation_function, + dtype=self.dtype_policy, + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer, + channel_axis=self.channel_axis, + name="conv2", + ) + self.bottleneck_layers = [ + DFineRepVggBlock( + activation_function=self.activation_function, + filters=hidden_channels, + batch_norm_eps=self.batch_norm_eps, + dtype=self.dtype_policy, + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer, + channel_axis=self.channel_axis, + name=f"bottleneck_{i}", + ) + for i in range(self.num_blocks) + ] + if hidden_channels != self.filters: + self.conv3 = DFineConvNormLayer( + filters=self.filters, + kernel_size=1, + batch_norm_eps=self.batch_norm_eps, + stride=1, + groups=1, + padding=0, + activation_function=self.activation_function, + dtype=self.dtype_policy, + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer, + channel_axis=self.channel_axis, + name="conv3", + ) + else: + self.conv3 = keras.layers.Identity( + name="conv3_identity", dtype=self.dtype_policy + ) + + def build(self, input_shape): + self.conv1.build(input_shape) + self.conv2.build(input_shape) + bottleneck_input_shape = self.conv1.compute_output_shape(input_shape) + for bottleneck_layer in self.bottleneck_layers: + bottleneck_layer.build(bottleneck_input_shape) + self.conv3.build(bottleneck_input_shape) + super().build(input_shape) + + def call(self, hidden_state, training=None): + hidden_state_1 = self.conv1(hidden_state, training=training) + for bottleneck_layer in self.bottleneck_layers: + hidden_state_1 = bottleneck_layer(hidden_state_1, 
training=training) + hidden_state_2 = self.conv2(hidden_state, training=training) + summed_hidden_states = hidden_state_1 + hidden_state_2 + if isinstance(self.conv3, keras.layers.Identity): + hidden_state_3 = self.conv3(summed_hidden_states) + else: + hidden_state_3 = self.conv3(summed_hidden_states, training=training) + return hidden_state_3 + + def compute_output_shape(self, input_shape): + shape_after_conv1 = self.conv1.compute_output_shape(input_shape) + return self.conv3.compute_output_shape(shape_after_conv1) + + def get_config(self): + config = super().get_config() + config.update( + { + "activation_function": self.activation_function, + "batch_norm_eps": self.batch_norm_eps, + "filters": self.filters, + "num_blocks": self.num_blocks, + "expansion": self.expansion, + "kernel_initializer": keras.initializers.serialize( + self.kernel_initializer + ), + "bias_initializer": keras.initializers.serialize( + self.bias_initializer + ), + "channel_axis": self.channel_axis, + } + ) + return config + + +class DFineFeatureAggregationBlock(keras.layers.Layer): + """Complex block combining convolutional and CSP layers. + + This layer implements a complex feature extraction block combining multiple + convolutional and `DFineCSPRepLayer` layers. It is the main building block + for the Feature Pyramid Network (FPN) and Path Aggregation Network (PAN) + pathways within the `DFineHybridEncoder`. + + Args: + encoder_hidden_dim: int, The hidden dimension of the encoder. + hidden_expansion: float, The expansion factor for hidden channels. + batch_norm_eps: float, The epsilon value for batch normalization. + activation_function: str, The activation function to use. + num_blocks: int, The number of blocks in the CSP layers. + kernel_initializer: str or Initializer, optional, Initializer for + the kernel weights. Defaults to `"glorot_uniform"`. + bias_initializer: str or Initializer, optional, Initializer for + the bias weights. Defaults to `"zeros"`. + channel_axis: int, optional, The channel axis. Defaults to `None`. + **kwargs: Additional keyword arguments passed to the parent class. 
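A sketch of the block applied to a fused FPN input (editorial, not part of this diff; assumes `channels_last` and a 512-channel input, i.e. two concatenated 256-channel maps as in the hybrid encoder):

import numpy as np

from keras_hub.src.models.d_fine.d_fine_layers import (
    DFineFeatureAggregationBlock,
)

block = DFineFeatureAggregationBlock(
    encoder_hidden_dim=256,
    hidden_expansion=1.0,
    batch_norm_eps=1e-5,
    activation_function="silu",
    num_blocks=3,
    channel_axis=-1,
)
fused = np.zeros((2, 32, 32, 512), dtype="float32")  # upsampled + lateral concat
out = block(fused)
print(out.shape)  # (2, 32, 32, 256)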
+ """ + + def __init__( + self, + encoder_hidden_dim, + hidden_expansion, + batch_norm_eps, + activation_function, + num_blocks, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + channel_axis=None, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.encoder_hidden_dim = encoder_hidden_dim + self.hidden_expansion = hidden_expansion + self.batch_norm_eps = batch_norm_eps + self.activation_function = activation_function + self.num_blocks = num_blocks + self.kernel_initializer = keras.initializers.get(kernel_initializer) + self.bias_initializer = keras.initializers.get(bias_initializer) + self.channel_axis = channel_axis + + conv3_dim = self.encoder_hidden_dim * 2 + self.conv4_dim = int( + self.hidden_expansion * self.encoder_hidden_dim / 2 + ) + self.conv_dim = conv3_dim // 2 + self.conv1 = DFineConvNormLayer( + filters=conv3_dim, + kernel_size=1, + batch_norm_eps=self.batch_norm_eps, + stride=1, + groups=1, + padding=0, + activation_function=self.activation_function, + dtype=self.dtype_policy, + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer, + channel_axis=self.channel_axis, + name="conv1", + ) + self.csp_rep1 = DFineCSPRepLayer( + activation_function=self.activation_function, + batch_norm_eps=self.batch_norm_eps, + filters=self.conv4_dim, + num_blocks=self.num_blocks, + dtype=self.dtype_policy, + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer, + channel_axis=self.channel_axis, + name="csp_rep1", + ) + self.conv2 = DFineConvNormLayer( + filters=self.conv4_dim, + kernel_size=3, + batch_norm_eps=self.batch_norm_eps, + stride=1, + groups=1, + padding=1, + activation_function=self.activation_function, + dtype=self.dtype_policy, + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer, + channel_axis=self.channel_axis, + name="conv2", + ) + self.csp_rep2 = DFineCSPRepLayer( + activation_function=self.activation_function, + batch_norm_eps=self.batch_norm_eps, + filters=self.conv4_dim, + num_blocks=self.num_blocks, + dtype=self.dtype_policy, + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer, + channel_axis=self.channel_axis, + name="csp_rep2", + ) + self.conv3 = DFineConvNormLayer( + filters=self.conv4_dim, + kernel_size=3, + batch_norm_eps=self.batch_norm_eps, + stride=1, + groups=1, + padding=1, + activation_function=self.activation_function, + dtype=self.dtype_policy, + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer, + channel_axis=self.channel_axis, + name="conv3", + ) + self.conv4 = DFineConvNormLayer( + filters=self.encoder_hidden_dim, + kernel_size=1, + batch_norm_eps=self.batch_norm_eps, + stride=1, + groups=1, + padding=0, + activation_function=self.activation_function, + dtype=self.dtype_policy, + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer, + channel_axis=self.channel_axis, + name="conv4", + ) + + def build(self, input_shape): + self.conv1.build(input_shape) + shape_after_conv1 = self.conv1.compute_output_shape(input_shape) + csp_rep_input_shape_list = list(shape_after_conv1) + csp_rep_input_shape_list[self.channel_axis] = self.conv_dim + csp_rep_input_shape = tuple(csp_rep_input_shape_list) + self.csp_rep1.build(csp_rep_input_shape) + shape_after_csp_rep1 = self.csp_rep1.compute_output_shape( + csp_rep_input_shape + ) + self.conv2.build(shape_after_csp_rep1) + shape_after_conv2 = self.conv2.compute_output_shape( + 
shape_after_csp_rep1 + ) + self.csp_rep2.build(shape_after_conv2) + shape_after_csp_rep2 = self.csp_rep2.compute_output_shape( + shape_after_conv2 + ) + self.conv3.build(shape_after_csp_rep2) + shape_for_concat_list = list(shape_after_conv1) + shape_for_concat_list[self.channel_axis] = ( + self.conv_dim * 2 + self.conv4_dim * 2 + ) + shape_for_concat = tuple(shape_for_concat_list) + self.conv4.build(shape_for_concat) + super().build(input_shape) + + def call(self, input_features, training=None): + conv1_out = self.conv1(input_features, training=training) + split_features_tensor = keras.ops.split( + conv1_out, [self.conv_dim, self.conv_dim], axis=self.channel_axis + ) + split_features = list(split_features_tensor) + branch1 = self.csp_rep1(split_features[-1], training=training) + branch1 = self.conv2(branch1, training=training) + branch2 = self.csp_rep2(branch1, training=training) + branch2 = self.conv3(branch2, training=training) + split_features.extend([branch1, branch2]) + merged_features = keras.ops.concatenate( + split_features, axis=self.channel_axis + ) + merged_features = self.conv4(merged_features, training=training) + return merged_features + + def compute_output_shape(self, input_shape): + shape_after_conv1 = self.conv1.compute_output_shape(input_shape) + shape_for_concat_list = list(shape_after_conv1) + shape_for_concat_list[self.channel_axis] = ( + self.conv_dim * 2 + self.conv4_dim * 2 + ) + shape_for_concat = tuple(shape_for_concat_list) + return self.conv4.compute_output_shape(shape_for_concat) + + def get_config(self): + config = super().get_config() + config.update( + { + "encoder_hidden_dim": self.encoder_hidden_dim, + "hidden_expansion": self.hidden_expansion, + "batch_norm_eps": self.batch_norm_eps, + "activation_function": self.activation_function, + "num_blocks": self.num_blocks, + "kernel_initializer": keras.initializers.serialize( + self.kernel_initializer + ), + "bias_initializer": keras.initializers.serialize( + self.bias_initializer + ), + "channel_axis": self.channel_axis, + } + ) + return config + + +class DFineSCDown(keras.layers.Layer): + """Downsampling layer using convolutions. + + This layer is used in the `DFineHybridEncoder` to perform downsampling. + Specifically, it is part of the Path Aggregation Network (PAN) bottom-up + pathway, reducing the spatial resolution of feature maps. + + Args: + encoder_hidden_dim: int, The hidden dimension of the encoder. + batch_norm_eps: float, The epsilon value for batch normalization. + kernel_size: int, The kernel size for the second convolution. + stride: int, The stride for the second convolution. + kernel_initializer: str or Initializer, optional, Initializer for + the kernel weights. Defaults to `"glorot_uniform"`. + bias_initializer: str or Initializer, optional, Initializer for + the bias weights. Defaults to `"zeros"`. + channel_axis: int, optional, The channel axis. Defaults to `None`. + **kwargs: Additional keyword arguments passed to the parent class. 
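A sketch of the downsampling behavior (editorial, not part of this diff; assumes `channels_last` and invented sizes):

import numpy as np

from keras_hub.src.models.d_fine.d_fine_layers import DFineSCDown

down = DFineSCDown(
    encoder_hidden_dim=256,
    batch_norm_eps=1e-5,
    kernel_size=3,
    stride=2,
    channel_axis=-1,
)
x = np.zeros((2, 32, 32, 256), dtype="float32")
y = down(x)  # 1x1 pointwise conv, then strided grouped 3x3 conv
print(y.shape)  # (2, 16, 16, 256)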
+ """ + + def __init__( + self, + encoder_hidden_dim, + batch_norm_eps, + kernel_size, + stride, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + channel_axis=None, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.encoder_hidden_dim = encoder_hidden_dim + self.batch_norm_eps = batch_norm_eps + self.conv2_kernel_size = kernel_size + self.conv2_stride = stride + self.kernel_initializer = keras.initializers.get(kernel_initializer) + self.bias_initializer = keras.initializers.get(bias_initializer) + self.channel_axis = channel_axis + self.conv1 = DFineConvNormLayer( + filters=self.encoder_hidden_dim, + kernel_size=1, + batch_norm_eps=self.batch_norm_eps, + stride=1, + groups=1, + padding=0, + activation_function=None, + dtype=self.dtype_policy, + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer, + channel_axis=self.channel_axis, + name="conv1", + ) + self.conv2 = DFineConvNormLayer( + filters=self.encoder_hidden_dim, + kernel_size=self.conv2_kernel_size, + batch_norm_eps=self.batch_norm_eps, + stride=self.conv2_stride, + groups=self.encoder_hidden_dim, + padding=(self.conv2_kernel_size - 1) // 2, + activation_function=None, + dtype=self.dtype_policy, + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer, + channel_axis=self.channel_axis, + name="conv2", + ) + + def build(self, input_shape): + self.conv1.build(input_shape) + shape_after_conv1 = self.conv1.compute_output_shape(input_shape) + self.conv2.build(shape_after_conv1) + super().build(input_shape) + + def call(self, input_features, training=None): + x = self.conv1(input_features, training=training) + x = self.conv2(x, training=training) + return x + + def compute_output_shape(self, input_shape): + shape_after_conv1 = self.conv1.compute_output_shape(input_shape) + return self.conv2.compute_output_shape(shape_after_conv1) + + def get_config(self): + config = super().get_config() + config.update( + { + "encoder_hidden_dim": self.encoder_hidden_dim, + "batch_norm_eps": self.batch_norm_eps, + "kernel_size": self.conv2_kernel_size, + "stride": self.conv2_stride, + "kernel_initializer": keras.initializers.serialize( + self.kernel_initializer + ), + "bias_initializer": keras.initializers.serialize( + self.bias_initializer + ), + "channel_axis": self.channel_axis, + } + ) + return config + + +class DFineMLPPredictionHead(keras.layers.Layer): + """MLP head for making predictions from feature vectors. + + This layer is a generic MLP used for various prediction tasks in D-FINE. + It is used for the encoder's bounding box head (`enc_bbox_head` in + `DFineBackbone`), the decoder's bounding box embedding (`bbox_embed` in + `DFineDecoder`), and the query position head (`query_pos_head` in + `DFineDecoder`). + + Args: + input_dim: int, The input dimension. + hidden_dim: int, The hidden dimension for intermediate layers. + output_dim: int, The output dimension. + num_layers: int, The number of layers in the MLP. + kernel_initializer: str or Initializer, optional, Initializer for + the kernel weights. Defaults to `"glorot_uniform"`. + bias_initializer: str or Initializer, optional, Initializer for + the bias weights. Defaults to `"zeros"`. + last_layer_initializer: str or Initializer, optional, Special + initializer for the final layer's weights and biases. If `None`, + uses `kernel_initializer` and `bias_initializer`. Defaults to + `None`. + **kwargs: Additional keyword arguments passed to the parent class. 
+ """ + + def __init__( + self, + input_dim, + hidden_dim, + output_dim, + num_layers, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + last_layer_initializer=None, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.output_dim = output_dim + self.num_layers = num_layers + self.kernel_initializer = keras.initializers.get(kernel_initializer) + self.bias_initializer = keras.initializers.get(bias_initializer) + self.last_layer_initializer = keras.initializers.get( + last_layer_initializer + ) + + h = [self.hidden_dim] * (self.num_layers - 1) + input_dims = [self.input_dim] + h + output_dims = h + [self.output_dim] + + self.dense_layers = [] + for i, (_, out_dim) in enumerate(zip(input_dims, output_dims)): + is_last_layer = i == self.num_layers - 1 + current_kernel_init = self.kernel_initializer + current_bias_init = self.bias_initializer + if is_last_layer and self.last_layer_initializer is not None: + current_kernel_init = self.last_layer_initializer + current_bias_init = self.last_layer_initializer + self.dense_layers.append( + keras.layers.Dense( + units=out_dim, + name=f"linear_{i}", + dtype=self.dtype_policy, + kernel_initializer=current_kernel_init, + bias_initializer=current_bias_init, + ) + ) + + def build(self, input_shape): + if self.dense_layers: + current_build_shape = input_shape + for i, dense_layer in enumerate(self.dense_layers): + dense_layer.build(current_build_shape) + current_build_shape = dense_layer.compute_output_shape( + current_build_shape + ) + super().build(input_shape) + + def call(self, x, training=None): + current_x = x + for i, layer in enumerate(self.dense_layers): + current_x = layer(current_x) + if i < self.num_layers - 1: + current_x = keras.ops.relu(current_x) + return current_x + + def compute_output_spec(self, x_spec): + output_shape = list(x_spec.shape) + output_shape[-1] = self.output_dim + return keras.KerasTensor( + shape=tuple(output_shape), dtype=self.compute_dtype + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "input_dim": self.input_dim, + "hidden_dim": self.hidden_dim, + "output_dim": self.output_dim, + "num_layers": self.num_layers, + "kernel_initializer": keras.initializers.serialize( + self.kernel_initializer + ), + "bias_initializer": keras.initializers.serialize( + self.bias_initializer + ), + "last_layer_initializer": keras.initializers.serialize( + self.last_layer_initializer + ), + } + ) + return config diff --git a/keras_hub/src/models/d_fine/d_fine_loss.py b/keras_hub/src/models/d_fine/d_fine_loss.py new file mode 100644 index 0000000000..d53e722a77 --- /dev/null +++ b/keras_hub/src/models/d_fine/d_fine_loss.py @@ -0,0 +1,938 @@ +import keras + +from keras_hub.src.models.d_fine.d_fine_utils import hungarian_assignment +from keras_hub.src.models.d_fine.d_fine_utils import weighting_function + + +def gather_along_first_two_dims(tensor, batch_idx, src_idx): + batch_size, num_queries, *feature_dims = keras.ops.shape(tensor) + batch_size = keras.ops.cast(batch_size, dtype=batch_idx.dtype) + num_queries = keras.ops.cast(num_queries, dtype=batch_idx.dtype) + linear_idx = batch_idx * num_queries + src_idx + flat_tensor = keras.ops.reshape(tensor, (-1, *feature_dims)) + gathered = keras.ops.take(flat_tensor, linear_idx, axis=0) + return gathered + + +def hungarian_matcher( + outputs, + targets, + num_targets_per_image, + use_focal_loss, + matcher_alpha, + matcher_gamma, + matcher_bbox_cost, + 
matcher_class_cost, + matcher_ciou_cost, + backbone, +): + """Performs bipartite matching between predictions and ground truths. + + This method implements the Hungarian matching algorithm to find the + optimal one-to-one assignment between the model's predictions (queries) + and the ground truth objects. The cost matrix for the assignment is a + weighted sum of three components: + 1. **Class Cost:** The cost of classifying a query into the wrong + class. + 2. **Bounding Box Cost:** The L1 distance between the predicted and + ground truth bounding boxes. + 3. **CIoU Cost:** The Complete Intersection over Union (CIoU) loss. + + Args: + outputs: dict, A dictionary containing predicted `"logits"` and + `"pred_boxes"`. + targets: list of dict, A list of dictionaries, each containing + the ground truth `"labels"` and `"boxes"`. + num_targets_per_image: A tensor of shape `(batch_size,)` indicating + the number of ground truth objects in each image. + + Returns: + tuple: A tuple of three tensors `(row_indices, col_indices, + valid_masks)`. `row_indices` and `col_indices` contain the indices + of matched predictions and ground truths, while `valid_masks` + indicates which matches are valid. + """ + batch_size = keras.ops.shape(outputs["logits"])[0] + num_queries = keras.ops.shape(outputs["logits"])[1] + out_logits = outputs["logits"] + out_bbox = outputs["pred_boxes"] + target_ids_all = keras.ops.cast(targets[0]["labels"], dtype="int32") + target_bbox_all = targets[0]["boxes"] + target_offsets = keras.ops.concatenate( + [ + keras.ops.zeros((1,), dtype="int32"), + keras.ops.cumsum(num_targets_per_image), + ] + ) + max_matches = num_queries + row_indices_init = keras.ops.zeros((batch_size, max_matches), dtype="int32") + col_indices_init = keras.ops.zeros((batch_size, max_matches), dtype="int32") + valid_masks_init = keras.ops.zeros((batch_size, max_matches), dtype="bool") + + def loop_body(i, loop_vars): + row_indices, col_indices, valid_masks = loop_vars + out_logits_i = out_logits[i] + out_bbox_i = out_bbox[i] + start = target_offsets[i] + end = target_offsets[i + 1] + num_targets_i = end - start + k = keras.ops.arange(0, num_queries) + is_valid_target_mask = k < num_targets_i + target_indices = start + k + safe_target_indices = keras.ops.minimum( + target_indices, keras.ops.shape(target_ids_all)[0] - 1 + ) + target_ids_i = keras.ops.take( + target_ids_all, safe_target_indices, axis=0 + ) + target_bbox_i = keras.ops.take( + target_bbox_all, safe_target_indices, axis=0 + ) + + def compute_cost_matrix(): + if use_focal_loss: + out_prob_i = keras.ops.sigmoid(out_logits_i) + safe_ids_for_take = keras.ops.maximum(target_ids_i, 0) + prob_for_target_classes = keras.ops.take( + out_prob_i, safe_ids_for_take, axis=1 + ) + p = prob_for_target_classes + pos_cost = ( + matcher_alpha + * keras.ops.power(1 - p, matcher_gamma) + * (-keras.ops.log(p + 1e-8)) + ) + neg_cost = ( + (1 - matcher_alpha) + * keras.ops.power(p, matcher_gamma) + * (-keras.ops.log(1 - p + 1e-8)) + ) + class_cost_i = pos_cost - neg_cost + else: + out_prob_softmax_i = keras.ops.softmax(out_logits_i, axis=-1) + safe_ids_for_take = keras.ops.maximum(target_ids_i, 0) + prob_for_target_classes = keras.ops.take( + out_prob_softmax_i, safe_ids_for_take, axis=1 + ) + class_cost_i = -prob_for_target_classes + + bbox_cost_i = keras.ops.sum( + keras.ops.abs( + keras.ops.expand_dims(out_bbox_i, 1) + - keras.ops.expand_dims(target_bbox_i, 0) + ), + axis=2, + ) + out_bbox_corners_i = keras.utils.bounding_boxes.convert_format( + out_bbox_i, + 
source="center_xywh", + target="xyxy", + ) + target_bbox_corners_i = keras.utils.bounding_boxes.convert_format( + target_bbox_i, + source="center_xywh", + target="xyxy", + ) + ciou_cost_i = -keras.utils.bounding_boxes.compute_ciou( + keras.ops.expand_dims(out_bbox_corners_i, 1), + keras.ops.expand_dims(target_bbox_corners_i, 0), + bounding_box_format="xyxy", + ) + + cost_matrix_i = ( + matcher_bbox_cost * bbox_cost_i + + matcher_class_cost * class_cost_i + + matcher_ciou_cost * ciou_cost_i + ) + cost_matrix_i = keras.ops.where( + keras.ops.expand_dims(is_valid_target_mask, 0), + cost_matrix_i, + 1e9, + ) + return cost_matrix_i + + def perform_assignment(): + cost_matrix_i = compute_cost_matrix() + row_idx, col_idx, valid_mask = hungarian_assignment( + cost_matrix_i, backbone.num_queries + ) + valid_mask = keras.ops.logical_and( + valid_mask, col_idx < num_targets_i + ) + return row_idx, col_idx, valid_mask + + def skip_assignment(): + return ( + keras.ops.zeros((num_queries,), dtype="int32"), + keras.ops.zeros((num_queries,), dtype="int32"), + keras.ops.zeros((num_queries,), dtype="bool"), + ) + + row_idx, col_idx, valid_mask = keras.ops.cond( + keras.ops.greater(num_targets_i, 0), + perform_assignment, + skip_assignment, + ) + row_indices = keras.ops.scatter_update( + row_indices, [[i]], keras.ops.expand_dims(row_idx, axis=0) + ) + col_indices = keras.ops.scatter_update( + col_indices, [[i]], keras.ops.expand_dims(col_idx, axis=0) + ) + valid_masks = keras.ops.scatter_update( + valid_masks, [[i]], keras.ops.expand_dims(valid_mask, axis=0) + ) + return row_indices, col_indices, valid_masks + + row_indices, col_indices, valid_masks = keras.ops.fori_loop( + 0, + batch_size, + loop_body, + (row_indices_init, col_indices_init, valid_masks_init), + ) + return (row_indices, col_indices, valid_masks) + + +def compute_vfl_loss( + outputs, + targets, + indices, + num_boxes, + num_classes, + matcher_alpha, + matcher_gamma, +): + """Computes the Varifocal Loss (VFL) for classification. + + VFL is an asymmetric focal loss variant designed for dense object + detection. It treats the Intersection over Union (IoU) between a + predicted box and its matched ground truth box as the target score for + positive examples while down-weighting the loss for negative examples. + This helps the model focus on high-quality localizations. + + Args: + outputs: dict, A dictionary containing the model's predictions, + including `"logits"` and `"pred_boxes"`. + targets: list of dict, A list of dictionaries containing ground + truth `"labels"` and `"boxes"`. + indices: tuple, `(row_ind, col_ind, valid_mask)` from the + Hungarian matcher, indicating the assignments between + predictions and targets. + num_boxes: int, The total number of ground truth boxes in the batch, + used for normalization. + + Returns: + Dictionary: The computed VFL loss. 
+ """ + _, col_indices, valid_masks = indices + batch_idx, src_idx = _get_source_permutation_idx(indices) + src_boxes = gather_along_first_two_dims( + outputs["pred_boxes"], batch_idx, src_idx + ) + flat_col_indices = keras.ops.reshape(col_indices, (-1,)) + flat_valid_masks = keras.ops.reshape(valid_masks, (-1,)) + src_logits = outputs["logits"] + target_classes_init = keras.ops.full( + shape=keras.ops.shape(src_logits)[:2], + fill_value=num_classes, + dtype="int32", + ) + target_score_original = keras.ops.zeros_like( + target_classes_init, dtype=src_logits.dtype + ) + update_indices = keras.ops.stack([batch_idx, src_idx], axis=-1) + + def process_targets(): + target_labels_tensor = keras.ops.stack( + [t["labels"] for t in targets], axis=0 + ) + target_boxes_tensor = keras.ops.stack( + [t["boxes"] for t in targets], axis=0 + ) + if keras.ops.ndim(target_labels_tensor) == 3: + target_labels_tensor = keras.ops.squeeze( + target_labels_tensor, axis=1 + ) + if keras.ops.ndim(target_boxes_tensor) == 4: + target_boxes_tensor = keras.ops.squeeze(target_boxes_tensor, axis=1) + flat_target_labels = keras.ops.reshape(target_labels_tensor, (-1,)) + flat_target_boxes = keras.ops.reshape(target_boxes_tensor, (-1, 4)) + num_targets = keras.ops.shape(flat_target_labels)[0] + num_targets = keras.ops.cast(num_targets, dtype=flat_col_indices.dtype) + safe_flat_col_indices = keras.ops.where( + (flat_col_indices >= 0) & (flat_col_indices < num_targets), + flat_col_indices, + 0, + ) + target_classes_flat = keras.ops.take( + flat_target_labels, safe_flat_col_indices, axis=0 + ) + target_boxes_flat = keras.ops.take( + flat_target_boxes, safe_flat_col_indices, axis=0 + ) + target_classes_flat = keras.ops.where( + flat_valid_masks, target_classes_flat, num_classes + ) + target_boxes_flat = keras.ops.where( + keras.ops.expand_dims(flat_valid_masks, axis=-1), + target_boxes_flat, + 0.0, + ) + src_boxes_corners = keras.utils.bounding_boxes.convert_format( + keras.ops.stop_gradient(src_boxes), + source="center_xywh", + target="xyxy", + ) + target_boxes_corners = keras.utils.bounding_boxes.convert_format( + target_boxes_flat, + source="center_xywh", + target="xyxy", + ) + ious_matrix = keras.utils.bounding_boxes.compute_iou( + src_boxes_corners, + target_boxes_corners, + bounding_box_format="xyxy", + ) + ious = keras.ops.diagonal(ious_matrix) + ious = ious * keras.ops.cast(flat_valid_masks, dtype=ious.dtype) + target_classes_flat = keras.ops.cast(target_classes_flat, dtype="int32") + ious = keras.ops.cast(ious, dtype=src_logits.dtype) + target_classes_updated = keras.ops.scatter_update( + target_classes_init, update_indices, target_classes_flat + ) + target_score_updated = keras.ops.scatter_update( + target_score_original, update_indices, ious + ) + return target_classes_updated, target_score_updated + + target_classes, target_score_original = process_targets() + target_one_hot = keras.ops.one_hot( + target_classes, num_classes=num_classes + 1 + )[..., :-1] + target_score = ( + keras.ops.expand_dims(target_score_original, axis=-1) * target_one_hot + ) + pred_score_sigmoid = keras.ops.sigmoid(keras.ops.stop_gradient(src_logits)) + weight = ( + matcher_alpha + * keras.ops.power(pred_score_sigmoid, matcher_gamma) + * (1 - target_one_hot) + + target_score + ) + loss_vfl = keras.ops.binary_crossentropy( + target_score, src_logits, from_logits=True + ) + loss_vfl = loss_vfl * weight + loss_vfl = ( + keras.ops.sum(keras.ops.mean(loss_vfl, axis=1)) + * keras.ops.cast(keras.ops.shape(src_logits)[1], dtype=loss_vfl.dtype) + / 
num_boxes + ) + return {"loss_vfl": loss_vfl} + + +def compute_box_losses(outputs, targets, indices, num_boxes): + """Computes the bounding box regression losses. + + This function calculates two losses for the bounding boxes that were + successfully matched to ground truth objects by the Hungarian matcher: + 1. **L1 Loss (`loss_bbox`):** A regression loss that measures the + absolute difference between the predicted and ground truth box + coordinates. + 2. **Complete IoU Loss (`loss_ciou`):** A scale-invariant loss that + accounts for the shape and orientation of the boxes, providing a + better gradient signal than the standard IoU, especially for + non-overlapping boxes. + + Args: + outputs: dict, A dictionary containing predicted `"pred_boxes"`. + targets: list of dict, A list of dictionaries containing ground + truth `"boxes"`. + indices: tuple, The assignments from the Hungarian matcher. + num_boxes: int, The total number of ground truth boxes for + normalization. + + Returns: + Dictionary: A dictionary containing the L1 and CIoU losses. + """ + _, col_indices, valid_masks = indices + batch_idx, src_idx = _get_source_permutation_idx(indices) + src_boxes = gather_along_first_two_dims( + outputs["pred_boxes"], batch_idx, src_idx + ) + target_boxes_all = targets[0]["boxes"] + if keras.ops.ndim(target_boxes_all) == 3: + target_boxes_all = keras.ops.squeeze(target_boxes_all, axis=0) + col_indices_flat = keras.ops.reshape(col_indices, [-1]) + valid_masks_flat = keras.ops.reshape(valid_masks, [-1]) + max_box_idx = keras.ops.maximum(keras.ops.shape(target_boxes_all)[0] - 1, 0) + max_box_idx = keras.ops.cast(max_box_idx, dtype=col_indices_flat.dtype) + safe_col_indices = keras.ops.clip(col_indices_flat, 0, max_box_idx) + target_boxes = keras.ops.take(target_boxes_all, safe_col_indices, axis=0) + valid_masks_expanded = keras.ops.expand_dims(valid_masks_flat, axis=-1) + valid_masks_expanded = keras.ops.cast( + valid_masks_expanded, target_boxes.dtype + ) + target_boxes = target_boxes * valid_masks_expanded + l1_loss = keras.ops.sum( + keras.ops.abs(src_boxes - target_boxes) + * keras.ops.cast(valid_masks_expanded, src_boxes.dtype) + ) + src_boxes_xyxy = keras.utils.bounding_boxes.convert_format( + src_boxes, + source="center_xywh", + target="xyxy", + ) + target_boxes_xyxy = keras.utils.bounding_boxes.convert_format( + target_boxes, + source="center_xywh", + target="xyxy", + ) + ciou = keras.utils.bounding_boxes.compute_ciou( + src_boxes_xyxy, + target_boxes_xyxy, + bounding_box_format="xyxy", + ) + ciou_loss = keras.ops.sum( + (1.0 - ciou) * keras.ops.cast(valid_masks_flat, src_boxes.dtype) + ) + return { + "loss_bbox": l1_loss / num_boxes, + "loss_ciou": ciou_loss / num_boxes, + } + + +def compute_local_losses( + outputs, + targets, + indices, + num_boxes, + backbone, + ddf_temperature, + compute_ddf=None, +): + """Computes local refinement losses (FGL and DDF). + + This function calculates two advanced losses for fine-grained box + and feature refinement: + 1. **Focal Grid Loss (`loss_fgl`):** This loss operates on the + integral-based representation of the bounding box corners. It is a + focal loss applied to the distribution over discrete bins, + encouraging the model to produce sharp, unimodal distributions + around the true corner locations. + 2. **Distribution-guided Denoising Focal Loss (`loss_ddf`):** This is + a knowledge distillation loss used for auxiliary decoder layers. 
It + minimizes the KL-divergence between the corner prediction + distribution of an intermediate layer (student) and that of the + final decoder layer (teacher). This guides the intermediate layers + to learn features that are consistent with the final, most refined + predictions. + + Args: + outputs: dict, A dictionary of model predictions, including + `"pred_corners"`, `"ref_points"`, and potentially teacher + predictions like `"teacher_corners"` and `"teacher_logits"`. + targets: list of dict, A list of dictionaries with ground truth + `"boxes"`. + indices: tuple of Tensors, The assignments from the Hungarian + matcher. + num_boxes: scalar Tensor, The total number of ground truth boxes for + normalization. + compute_ddf: bool, Indicates whether to compute the DDF loss. + + Returns: + Dictionary: A dictionary containing the computed FGL and DDF losses. + """ + losses = {} + if ( + "pred_corners" not in outputs + or outputs["pred_corners"] is None + or "ref_points" not in outputs + or outputs["ref_points"] is None + ): + losses["loss_fgl"] = keras.ops.convert_to_tensor( + 0.0, dtype=keras.backend.floatx() + ) + losses["loss_ddf"] = keras.ops.convert_to_tensor( + 0.0, dtype=keras.backend.floatx() + ) + return losses + + if compute_ddf is None: + compute_ddf = ( + "teacher_corners" in outputs + and outputs["teacher_corners"] is not None + and "teacher_logits" in outputs + ) + + _, col_indices, valid_masks = indices + batch_idx, src_idx = _get_source_permutation_idx(indices) + col_indices_flat = keras.ops.reshape(col_indices, [-1]) + valid_masks_flat = keras.ops.reshape(valid_masks, [-1]) + target_boxes_all = targets[0]["boxes"] + if keras.ops.ndim(target_boxes_all) == 3: + target_boxes_all = keras.ops.squeeze(target_boxes_all, axis=0) + max_box_idx = keras.ops.maximum(keras.ops.shape(target_boxes_all)[0] - 1, 0) + max_box_idx = keras.ops.cast(max_box_idx, dtype=col_indices_flat.dtype) + safe_col_indices = keras.ops.clip(col_indices_flat, 0, max_box_idx) + target_boxes_matched_center = keras.ops.take( + target_boxes_all, safe_col_indices, axis=0 + ) + valid_masks_expanded = keras.ops.expand_dims(valid_masks_flat, axis=-1) + valid_masks_expanded = keras.ops.cast( + valid_masks_expanded, target_boxes_matched_center.dtype + ) + target_boxes_matched_center = ( + target_boxes_matched_center * valid_masks_expanded + ) + + pred_corners_matched_flat = gather_along_first_two_dims( + outputs["pred_corners"], batch_idx, src_idx + ) + pred_corners_matched = keras.ops.reshape( + pred_corners_matched_flat, + (-1, backbone.decoder.max_num_bins + 1), + ) + ref_points_matched = gather_along_first_two_dims( + outputs["ref_points"], batch_idx, src_idx + ) + ref_points_matched = keras.ops.stop_gradient(ref_points_matched) + target_boxes_corners_matched = keras.utils.bounding_boxes.convert_format( + target_boxes_matched_center, + source="center_xywh", + target="xyxy", + ) + reg_scale_tensor = backbone.decoder.reg_scale + up_tensor = backbone.decoder.upsampling_factor + target_corners_dist, weight_right, weight_left = bbox2distance( + ref_points_matched, + target_boxes_corners_matched, + backbone.decoder.max_num_bins, + reg_scale_tensor, + up_tensor, + ) + pred_boxes_matched_center = gather_along_first_two_dims( + outputs["pred_boxes"], batch_idx, src_idx + ) + pred_boxes_corners_matched = keras.utils.bounding_boxes.convert_format( + pred_boxes_matched_center, + source="center_xywh", + target="xyxy", + ) + ious_pairwise = keras.utils.bounding_boxes.compute_iou( + pred_boxes_corners_matched, + 
target_boxes_corners_matched, + bounding_box_format="xyxy", + ) + ious = keras.ops.diagonal(ious_pairwise) + ious = ious * keras.ops.cast(valid_masks_flat, dtype=ious.dtype) + weight_targets_fgl = keras.ops.reshape( + keras.ops.tile(keras.ops.expand_dims(ious, 1), [1, 4]), + [-1], + ) + weight_targets_fgl = keras.ops.stop_gradient(weight_targets_fgl) + losses["loss_fgl"] = unimodal_distribution_focal_loss( + pred_corners_matched, + target_corners_dist, + weight_right, + weight_left, + weight=weight_targets_fgl, + avg_factor=num_boxes, + ) + + def ddf_true_fn(): + pred_corners_all = keras.ops.reshape( + outputs["pred_corners"], + (-1, backbone.decoder.max_num_bins + 1), + ) + target_corners_all = keras.ops.reshape( + keras.ops.stop_gradient(outputs["teacher_corners"]), + (-1, backbone.decoder.max_num_bins + 1), + ) + + def compute_ddf_loss_fn(): + weight_targets_local = keras.ops.max( + keras.ops.sigmoid(outputs["teacher_logits"]), axis=-1 + ) + num_queries = keras.ops.cast( + keras.ops.shape(weight_targets_local)[1], + dtype=batch_idx.dtype, + ) + flat_update_indices = batch_idx * num_queries + src_idx + flat_update_indices = keras.ops.expand_dims( + flat_update_indices, axis=-1 + ) + mask = keras.ops.zeros_like(weight_targets_local, dtype="bool") + mask_flat = keras.ops.scatter_update( + keras.ops.reshape(mask, (-1,)), + flat_update_indices, + keras.ops.ones_like(batch_idx, dtype="bool"), + ) + mask = keras.ops.reshape( + mask_flat, keras.ops.shape(weight_targets_local) + ) + weight_targets_local_flat = keras.ops.reshape( + weight_targets_local, (-1,) + ) + weight_targets_local_matched_flat = keras.ops.scatter_update( + weight_targets_local_flat, + flat_update_indices, + ious, + ) + weight_targets_local = keras.ops.reshape( + weight_targets_local_matched_flat, + keras.ops.shape(weight_targets_local), + ) + weight_targets_local_expanded = keras.ops.reshape( + keras.ops.tile( + keras.ops.expand_dims(weight_targets_local, axis=-1), + [1, 1, 4], + ), + [-1], + ) + weight_targets_local_expanded = keras.ops.stop_gradient( + weight_targets_local_expanded + ) + # NOTE: Original impl hardcodes `ddf_temperature` to 5.0 for + # DDFL. + # KerasHub lets users configure it if needed. 
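+            # A larger temperature softens both corner distributions before
+            # the KL term is taken; the `ddf_temperature ** 2` factor below
+            # follows the usual knowledge-distillation scaling.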
+ # Ref: https://github.com/huggingface/transformers/blob/b374c3d12e8a42014b7911d1bddf598aeada1154/src/transformers/loss/loss_d_fine.py#L238 + pred_softmax = keras.ops.softmax( + pred_corners_all / ddf_temperature, axis=-1 + ) + target_softmax = keras.ops.softmax( + target_corners_all / ddf_temperature, axis=-1 + ) + kl_div = keras.ops.sum( + target_softmax + * ( + keras.ops.log(target_softmax + 1e-8) + - keras.ops.log(pred_softmax + 1e-8) + ), + axis=-1, + ) + loss_match_local = ( + weight_targets_local_expanded * (ddf_temperature**2) * kl_div + ) + mask_expanded = keras.ops.expand_dims(mask, axis=-1) + mask_expanded = keras.ops.tile(mask_expanded, [1, 1, 4]) + mask_flat = keras.ops.reshape(mask_expanded, (-1,)) + loss_match_local1 = keras.ops.cond( + keras.ops.any(mask_flat), + lambda: keras.ops.sum( + loss_match_local + * keras.ops.cast(mask_flat, loss_match_local.dtype) + ) + / keras.ops.sum( + keras.ops.cast(mask_flat, loss_match_local.dtype) + ), + lambda: keras.ops.convert_to_tensor( + 0.0, dtype=loss_match_local.dtype + ), + ) + neg_mask_flat = keras.ops.logical_not(mask_flat) + loss_match_local2 = keras.ops.cond( + keras.ops.any(neg_mask_flat), + lambda: keras.ops.sum( + loss_match_local + * keras.ops.cast(neg_mask_flat, loss_match_local.dtype) + ) + / keras.ops.sum( + keras.ops.cast(neg_mask_flat, loss_match_local.dtype) + ), + lambda: keras.ops.convert_to_tensor( + 0.0, dtype=loss_match_local.dtype + ), + ) + batch_scale = 1.0 / keras.ops.cast( + keras.ops.shape(outputs["pred_boxes"])[0], + dtype="float32", + ) + num_pos = keras.ops.sqrt( + keras.ops.sum(keras.ops.cast(mask, dtype="float32")) + * batch_scale + ) + num_neg = keras.ops.sqrt( + keras.ops.sum(keras.ops.cast(~mask, dtype="float32")) + * batch_scale + ) + return ( + loss_match_local1 * num_pos + loss_match_local2 * num_neg + ) / (num_pos + num_neg + 1e-8) + + all_equal = keras.ops.all( + keras.ops.equal(pred_corners_all, target_corners_all) + ) + return keras.ops.cond( + all_equal, + lambda: keras.ops.sum(pred_corners_all) * 0.0, + compute_ddf_loss_fn, + ) + + def ddf_false_fn(): + return keras.ops.convert_to_tensor(0.0, dtype=keras.backend.floatx()) + + losses["loss_ddf"] = keras.ops.cond(compute_ddf, ddf_true_fn, ddf_false_fn) + return losses + + +def _translate_gt_valid_case( + gt_flat, valid_idx_mask, function_values, max_num_bins, mask +): + closest_left_indices = ( + keras.ops.sum(keras.ops.cast(mask, "int32"), axis=1) - 1 + ) + indices_float = keras.ops.cast(closest_left_indices, dtype=gt_flat.dtype) + weight_right = keras.ops.zeros_like(indices_float) + weight_left = keras.ops.zeros_like(indices_float) + valid_indices_int = keras.ops.arange(keras.ops.shape(valid_idx_mask)[0]) + valid_indices_int = keras.ops.where(valid_idx_mask, valid_indices_int, -1) + valid_indices_int = keras.ops.where( + valid_indices_int >= 0, valid_indices_int, 0 + ) + valid_indices_long = keras.ops.cast( + keras.ops.where( + valid_idx_mask, + keras.ops.take(indices_float, valid_indices_int, axis=0), + 0.0, + ), + "int32", + ) + gt_valid = keras.ops.where( + valid_idx_mask, + keras.ops.take(gt_flat, valid_indices_int, axis=0), + 0.0, + ) + left_values = keras.ops.take(function_values, valid_indices_long, axis=0) + right_values = keras.ops.take( + function_values, + keras.ops.clip( + valid_indices_long + 1, + 0, + keras.ops.shape(function_values)[0] - 1, + ), + axis=0, + ) + left_diffs = keras.ops.abs(gt_valid - left_values) + right_diffs = keras.ops.abs(right_values - gt_valid) + wr_valid = left_diffs / (left_diffs + right_diffs + 1e-8) 
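+    # The left weight is the complement of the right weight, so each target
+    # value is linearly interpolated between its two neighboring bins.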
+ wl_valid = 1.0 - wr_valid + weight_right = keras.ops.where( + keras.ops.expand_dims(valid_idx_mask, axis=-1), + keras.ops.expand_dims(wr_valid, axis=-1), + keras.ops.expand_dims(weight_right, axis=-1), + ) + weight_right = keras.ops.squeeze(weight_right, axis=-1) + weight_left = keras.ops.where( + keras.ops.expand_dims(valid_idx_mask, axis=-1), + keras.ops.expand_dims(wl_valid, axis=-1), + keras.ops.expand_dims(weight_left, axis=-1), + ) + weight_left = keras.ops.squeeze(weight_left, axis=-1) + indices_float = keras.ops.where( + indices_float < 0, + keras.ops.zeros_like(indices_float), + indices_float, + ) + weight_right = keras.ops.where( + indices_float < 0, keras.ops.zeros_like(weight_right), weight_right + ) + weight_left = keras.ops.where( + indices_float < 0, keras.ops.ones_like(weight_left), weight_left + ) + indices_float = keras.ops.where( + indices_float >= max_num_bins, + keras.ops.cast(max_num_bins - 0.1, dtype=indices_float.dtype), + indices_float, + ) + weight_right = keras.ops.where( + indices_float >= max_num_bins, + keras.ops.ones_like(weight_right), + weight_right, + ) + weight_left = keras.ops.where( + indices_float >= max_num_bins, + keras.ops.zeros_like(weight_left), + weight_left, + ) + return indices_float, weight_right, weight_left + + +def translate_gt(gt, max_num_bins, reg_scale, up): + gt_flat = keras.ops.reshape(gt, [-1]) + function_values = weighting_function(max_num_bins, up, reg_scale) + diffs = keras.ops.expand_dims( + function_values, axis=0 + ) - keras.ops.expand_dims(gt_flat, axis=1) + mask = diffs <= 0 + closest_left_indices = ( + keras.ops.sum(keras.ops.cast(mask, "int32"), axis=1) - 1 + ) + indices_float = keras.ops.cast(closest_left_indices, dtype=gt_flat.dtype) + weight_right = keras.ops.zeros_like(indices_float) + weight_left = keras.ops.zeros_like(indices_float) + valid_idx_mask = (indices_float >= 0) & (indices_float < max_num_bins) + return keras.ops.cond( + keras.ops.any(valid_idx_mask), + lambda: _translate_gt_valid_case( + gt_flat, valid_idx_mask, function_values, max_num_bins, mask + ), + lambda: ( + keras.ops.zeros_like(indices_float), + keras.ops.zeros_like(weight_right), + keras.ops.ones_like(weight_left), + ), + ) + + +def _compute_bbox2distance(points, bbox, max_num_bins, reg_scale, up, eps=0.1): + reg_scale_abs = keras.ops.abs(reg_scale) + left = (points[..., 0] - bbox[..., 0]) / ( + points[..., 2] / reg_scale_abs + 1e-16 + ) - 0.5 * reg_scale_abs + top = (points[..., 1] - bbox[..., 1]) / ( + points[..., 3] / reg_scale_abs + 1e-16 + ) - 0.5 * reg_scale_abs + right = (bbox[..., 2] - points[..., 0]) / ( + points[..., 2] / reg_scale_abs + 1e-16 + ) - 0.5 * reg_scale_abs + bottom = (bbox[..., 3] - points[..., 1]) / ( + points[..., 3] / reg_scale_abs + 1e-16 + ) - 0.5 * reg_scale_abs + four_lens = keras.ops.stack([left, top, right, bottom], axis=-1) + up_tensor = ( + keras.ops.convert_to_tensor(up) + if not isinstance(up, (keras.KerasTensor)) + else up + ) + four_lens_translated, weight_right, weight_left = translate_gt( + four_lens, max_num_bins, reg_scale_abs, up_tensor + ) + four_lens_translated = keras.ops.clip( + four_lens_translated, 0, max_num_bins - eps + ) + return ( + keras.ops.stop_gradient(four_lens_translated), + keras.ops.stop_gradient(weight_right), + keras.ops.stop_gradient(weight_left), + ) + + +def bbox2distance(points, bbox, max_num_bins, reg_scale, up, eps=0.1): + expected_flat_size = keras.ops.shape(points)[0] * 4 + return keras.ops.cond( + keras.ops.equal(keras.ops.shape(points)[0], 0), + lambda: ( + keras.ops.zeros( 
+ (expected_flat_size,), dtype=keras.backend.floatx() + ), + keras.ops.zeros( + (expected_flat_size,), dtype=keras.backend.floatx() + ), + keras.ops.zeros( + (expected_flat_size,), dtype=keras.backend.floatx() + ), + ), + lambda: _compute_bbox2distance( + points, bbox, max_num_bins, reg_scale, up, eps + ), + ) + + +def unimodal_distribution_focal_loss( + pred, + label, + weight_right, + weight_left, + weight=None, + reduction="sum", + avg_factor=None, +): + label_flat = keras.ops.reshape(label, [-1]) + weight_right_flat = keras.ops.reshape(weight_right, [-1]) + weight_left_flat = keras.ops.reshape(weight_left, [-1]) + dis_left = keras.ops.cast(label_flat, "int32") + dis_right = dis_left + 1 + loss_left = ( + keras.ops.sparse_categorical_crossentropy( + dis_left, pred, from_logits=True + ) + * weight_left_flat + ) + loss_right = ( + keras.ops.sparse_categorical_crossentropy( + dis_right, pred, from_logits=True + ) + * weight_right_flat + ) + loss = loss_left + loss_right + if weight is not None: + loss = loss * keras.ops.cast(weight, dtype=loss.dtype) + if avg_factor is not None: + loss = keras.ops.sum(loss) / avg_factor + elif reduction == "mean": + loss = keras.ops.mean(loss) + elif reduction == "sum": + loss = keras.ops.sum(loss) + return loss + + +def _get_source_permutation_idx(indices): + """Gathers the batch and source indices for matched predictions. + + This method is a JAX-compatible adaptation of the author's approach, + which creates dynamically sized tensors by concatenating indices from a + list, which is not traceable by a JIT compiler. + + To ensure JAX compatibility, this implementation uses a masking + strategy. It returns fixed-size tensors where invalid positions are + padded with `0`. The downstream loss functions then use the + `valid_masks` tensor to ignore these padded entries during loss + computation. + """ + row_indices, _, valid_masks = indices + batch_size = keras.ops.shape(row_indices)[0] + max_matches = keras.ops.shape(row_indices)[1] + batch_indices = keras.ops.arange(batch_size, dtype="int32") + batch_indices = keras.ops.expand_dims(batch_indices, axis=1) + batch_indices = keras.ops.tile(batch_indices, [1, max_matches]) + batch_indices_flat = keras.ops.reshape(batch_indices, (-1,)) + row_indices_flat = keras.ops.reshape(row_indices, (-1,)) + valid_masks_flat = keras.ops.reshape(valid_masks, (-1,)) + batch_idx = keras.ops.where( + valid_masks_flat, + keras.ops.cast(batch_indices_flat, "int64"), + 0, + ) + src_idx = keras.ops.where( + valid_masks_flat, + keras.ops.cast(row_indices_flat, dtype="int64"), + 0, + ) + return batch_idx, src_idx + + +def get_cdn_matched_indices(dn_meta): + """Generates matched indices for contrastive denoising (CDN) training. + + This method is a JAX-compatible adaptation of the author's approach, + which iterates through the batch to build a list of dynamically sized + index tensors, which is not traceable by a JIT compiler. + + To ensure JAX compatibility, this implementation operates on the entire + batch as a single tensor operation. It uses the pre-padded + `dn_positive_idx` tensor (where -1 indicates padding) to generate + fixed-size `row_indices`, `col_indices`, and a `valid_masks` tensor. 
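+
+    Example:
+    For `dn_positive_idx = [[0, 2, -1]]` (one image with two real denoising
+    targets and one padded slot), this returns `row_indices = [[0, 1, 2]]`,
+    `col_indices = [[0, 2, -1]]`, and `valid_masks = [[True, True, False]]`.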
+ """ + dn_positive_idx = dn_meta["dn_positive_idx"] + batch_size = keras.ops.shape(dn_positive_idx)[0] + num_denoising_queries = keras.ops.shape(dn_positive_idx)[1] + row_indices = keras.ops.tile( + keras.ops.expand_dims( + keras.ops.arange(num_denoising_queries, dtype="int64"), 0 + ), + [batch_size, 1], + ) + col_indices = dn_positive_idx + valid_masks = keras.ops.not_equal(col_indices, -1) + return (row_indices, col_indices, valid_masks) diff --git a/keras_hub/src/models/d_fine/d_fine_object_detector.py b/keras_hub/src/models/d_fine/d_fine_object_detector.py new file mode 100644 index 0000000000..e062f4a2d0 --- /dev/null +++ b/keras_hub/src/models/d_fine/d_fine_object_detector.py @@ -0,0 +1,875 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.modeling.non_max_supression import NonMaxSuppression +from keras_hub.src.models.d_fine.d_fine_backbone import DFineBackbone +from keras_hub.src.models.d_fine.d_fine_loss import compute_box_losses +from keras_hub.src.models.d_fine.d_fine_loss import compute_local_losses +from keras_hub.src.models.d_fine.d_fine_loss import compute_vfl_loss +from keras_hub.src.models.d_fine.d_fine_loss import get_cdn_matched_indices +from keras_hub.src.models.d_fine.d_fine_loss import hungarian_matcher +from keras_hub.src.models.d_fine.d_fine_object_detector_preprocessor import ( + DFineObjectDetectorPreprocessor, +) +from keras_hub.src.models.object_detector import ObjectDetector +from keras_hub.src.utils.tensor_utils import assert_bounding_box_support + + +@keras_hub_export("keras_hub.models.DFineObjectDetector") +class DFineObjectDetector(ObjectDetector): + """D-FINE Object Detector model. + + This class wraps the `DFineBackbone` and adds the final prediction and loss + computation logic for end-to-end object detection. It is responsible for: + 1. Defining the functional model that connects the `DFineBackbone` to the + input layers. + 2. Implementing the `compute_loss` method, which uses a Hungarian matcher + to assign predictions to ground truth targets and calculates a weighted + sum of multiple loss components (classification, bounding box, etc.). + 3. Post-processing the raw outputs from the backbone into final, decoded + predictions (boxes, labels, confidence scores) during inference. + + Args: + backbone: A `keras_hub.models.Backbone` instance, specifically a + `DFineBackbone`, serving as the feature extractor for the object + detector. + num_classes: An integer representing the number of object classes to + detect. + bounding_box_format: A string specifying the format of the bounding + boxes. Defaults to `"yxyx"`. Must be a supported format (e.g., + `"yxyx"`, `"xyxy"`). + preprocessor: Optional. An instance of `DFineObjectDetectorPreprocessor` + for input data preprocessing. + matcher_class_cost: A float representing the cost for class mismatch in + the Hungarian matcher. Defaults to `2.0`. + matcher_bbox_cost: A float representing the cost for bounding box + mismatch in the Hungarian matcher. Defaults to `5.0`. + matcher_ciou_cost: A float representing the cost for complete IoU + mismatch in the Hungarian matcher. Defaults to `2.0`. + use_focal_loss: A boolean indicating whether to use focal loss for + classification. Defaults to `True`. + matcher_alpha: A float parameter for the focal loss alpha. Defaults to + `0.25`. + matcher_gamma: A float parameter for the focal loss gamma. Defaults to + `2.0`. + weight_loss_vfl: Weight for the classification loss. Defaults to `1.0`. 
+ weight_loss_bbox: Weight for the bounding box regression loss. Default + is `5.0`. + weight_loss_ciou: Weight for the complete IoU loss. Defaults to + `2.0`. + weight_loss_fgl: Weight for the focal grid loss. Defaults to `0.15`. + weight_loss_ddf: Weight for the DDF loss. Defaults to `1.5`. + ddf_temperature: A float temperature scaling factor for the DDF loss. + Defaults to `5.0`. + prediction_decoder: Optional. A `keras.layers.Layer` instance that + decodes raw predictions. If not provided, a `NonMaxSuppression` + layer is used. + activation: Optional. The activation function to apply to the logits + before decoding. Defaults to `None`. + + Examples: + + **Creating a DFineObjectDetector without labels:** + + ```python + import numpy as np + from keras_hub.src.models.d_fine.d_fine_object_detector import ( + DFineObjectDetector + ) + from keras_hub.src.models.d_fine.d_fine_backbone import DFineBackbone + from keras_hub.src.models.hgnetv2.hgnetv2_backbone import HGNetV2Backbone + + # Initialize the backbone without labels. + hgnetv2_backbone = HGNetV2Backbone( + stem_channels=[3, 16, 16], + stackwise_stage_filters=[ + [16, 16, 64, 1, 3, 3], + [64, 32, 256, 1, 3, 3], + [256, 64, 512, 2, 3, 5], + [512, 128, 1024, 1, 3, 5], + ], + apply_downsample=[False, True, True, True], + use_lightweight_conv_block=[False, False, True, True], + depths=[1, 1, 2, 1], + hidden_sizes=[64, 256, 512, 1024], + embedding_size=16, + use_learnable_affine_block=True, + hidden_act="relu", + image_shape=(256, 256, 3), + out_features=["stage3", "stage4"], + ) + + # Initialize the backbone without labels. + backbone = DFineBackbone( + backbone=hgnetv2_backbone, + decoder_in_channels=[128, 128], + encoder_hidden_dim=128, + num_denoising=100, + num_labels=80, + hidden_dim=128, + learn_initial_query=False, + num_queries=300, + anchor_image_size=(256, 256), + feat_strides=[16, 32], + num_feature_levels=2, + encoder_in_channels=[512, 1024], + encode_proj_layers=[1], + num_attention_heads=8, + encoder_ffn_dim=512, + num_encoder_layers=1, + hidden_expansion=0.34, + depth_multiplier=0.5, + eval_idx=-1, + num_decoder_layers=3, + decoder_attention_heads=8, + decoder_ffn_dim=512, + decoder_n_points=[6, 6], + lqe_hidden_dim=64, + num_lqe_layers=2, + out_features=["stage3", "stage4"], + image_shape=(256, 256, 3), + ) + + # Create the detector. + detector = DFineObjectDetector( + backbone=backbone, + num_classes=80, + bounding_box_format="yxyx", + ) + ``` + + **Creating a DFineObjectDetector with labels for the backbone:** + + ```python + import numpy as np + from keras_hub.src.models.d_fine.d_fine_object_detector import ( + DFineObjectDetector + ) + from keras_hub.src.models.d_fine.d_fine_backbone import DFineBackbone + from keras_hub.src.models.hgnetv2.hgnetv2_backbone import HGNetV2Backbone + + # Define labels for the backbone. 
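+    # Boxes are normalized `(center_x, center_y, width, height)` values and
+    # labels are integer class ids.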
+ labels = [ + { + "boxes": np.array([[0.5, 0.5, 0.2, 0.2], [0.4, 0.4, 0.1, 0.1]]), + "labels": np.array([1, 10]) + }, + {"boxes": np.array([[0.6, 0.6, 0.3, 0.3]]), "labels": np.array([20])}, + ] + + hgnetv2_backbone = HGNetV2Backbone( + stem_channels=[3, 16, 16], + stackwise_stage_filters=[ + [16, 16, 64, 1, 3, 3], + [64, 32, 256, 1, 3, 3], + [256, 64, 512, 2, 3, 5], + [512, 128, 1024, 1, 3, 5], + ], + apply_downsample=[False, True, True, True], + use_lightweight_conv_block=[False, False, True, True], + depths=[1, 1, 2, 1], + hidden_sizes=[64, 256, 512, 1024], + embedding_size=16, + use_learnable_affine_block=True, + hidden_act="relu", + image_shape=(256, 256, 3), + out_features=["stage3", "stage4"], + ) + + # Backbone is initialized with labels. + backbone = DFineBackbone( + backbone=hgnetv2_backbone, + decoder_in_channels=[128, 128], + encoder_hidden_dim=128, + num_denoising=100, + num_labels=80, + hidden_dim=128, + learn_initial_query=False, + num_queries=300, + anchor_image_size=(256, 256), + feat_strides=[16, 32], + num_feature_levels=2, + encoder_in_channels=[512, 1024], + encode_proj_layers=[1], + num_attention_heads=8, + encoder_ffn_dim=512, + num_encoder_layers=1, + hidden_expansion=0.34, + depth_multiplier=0.5, + eval_idx=-1, + num_decoder_layers=3, + decoder_attention_heads=8, + decoder_ffn_dim=512, + decoder_n_points=[6, 6], + lqe_hidden_dim=64, + num_lqe_layers=2, + out_features=["stage3", "stage4"], + image_shape=(256, 256, 3), + labels=labels, + box_noise_scale=1.0, + label_noise_ratio=0.5, + ) + + # Create the detector. + detector = DFineObjectDetector( + backbone=backbone, + num_classes=80, + bounding_box_format="yxyx", + ) + ``` + + **Using the detector for training:** + + ```python + import numpy as np + from keras_hub.src.models.d_fine.d_fine_object_detector import ( + DFineObjectDetector + ) + from keras_hub.src.models.d_fine.d_fine_backbone import DFineBackbone + from keras_hub.src.models.hgnetv2.hgnetv2_backbone import HGNetV2Backbone + + # Initialize backbone and detector. + hgnetv2_backbone = HGNetV2Backbone( + stem_channels=[3, 16, 16], + stackwise_stage_filters=[ + [16, 16, 64, 1, 3, 3], + [64, 32, 256, 1, 3, 3], + [256, 64, 512, 2, 3, 5], + [512, 128, 1024, 1, 3, 5], + ], + apply_downsample=[False, True, True, True], + use_lightweight_conv_block=[False, False, True, True], + depths=[1, 1, 2, 1], + hidden_sizes=[64, 256, 512, 1024], + embedding_size=16, + use_learnable_affine_block=True, + hidden_act="relu", + image_shape=(256, 256, 3), + out_features=["stage3", "stage4"], + ) + backbone = DFineBackbone( + backbone=hgnetv2_backbone, + decoder_in_channels=[128, 128], + encoder_hidden_dim=128, + num_denoising=100, + num_labels=80, + hidden_dim=128, + learn_initial_query=False, + num_queries=300, + anchor_image_size=(256, 256), + feat_strides=[16, 32], + num_feature_levels=2, + encoder_in_channels=[512, 1024], + encode_proj_layers=[1], + num_attention_heads=8, + encoder_ffn_dim=512, + num_encoder_layers=1, + hidden_expansion=0.34, + depth_multiplier=0.5, + eval_idx=-1, + num_decoder_layers=3, + decoder_attention_heads=8, + decoder_ffn_dim=512, + decoder_n_points=[6, 6], + lqe_hidden_dim=64, + num_lqe_layers=2, + out_features=["stage3", "stage4"], + image_shape=(256, 256, 3), + ) + detector = DFineObjectDetector( + backbone=backbone, + num_classes=80, + bounding_box_format="yxyx", + ) + + # Sample training data. 
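+    # Box coordinates follow the detector's `bounding_box_format`
+    # ("yxyx" here), with one boxes/labels array per image.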
+ images = np.random.uniform( + low=0, high=255, size=(2, 256, 256, 3) + ).astype("float32") + bounding_boxes = { + "boxes": [ + np.array([[10.0, 20.0, 20.0, 30.0], [20.0, 30.0, 30.0, 40.0]]), + np.array([[15.0, 25.0, 25.0, 35.0]]), + ], + "labels": [ + np.array([0, 2]), np.array([1]) + ], + } + + # Compile the model. + detector.compile( + optimizer="adam", + loss=detector.compute_loss, + ) + + # Train the model. + detector.fit(x=images, y=bounding_boxes, epochs=1, batch_size=1) + ``` + + **Making predictions:** + + ```python + import numpy as np + from keras_hub.src.models.d_fine.d_fine_object_detector import ( + DFineObjectDetector + ) + from keras_hub.src.models.d_fine.d_fine_backbone import DFineBackbone + from keras_hub.src.models.hgnetv2.hgnetv2_backbone import HGNetV2Backbone + + # Initialize backbone and detector. + hgnetv2_backbone = HGNetV2Backbone( + stem_channels=[3, 16, 16], + stackwise_stage_filters=[ + [16, 16, 64, 1, 3, 3], + [64, 32, 256, 1, 3, 3], + [256, 64, 512, 2, 3, 5], + [512, 128, 1024, 1, 3, 5], + ], + apply_downsample=[False, True, True, True], + use_lightweight_conv_block=[False, False, True, True], + depths=[1, 1, 2, 1], + hidden_sizes=[64, 256, 512, 1024], + embedding_size=16, + use_learnable_affine_block=True, + hidden_act="relu", + image_shape=(256, 256, 3), + out_features=["stage3", "stage4"], + ) + backbone = DFineBackbone( + backbone=hgnetv2_backbone, + decoder_in_channels=[128, 128], + encoder_hidden_dim=128, + num_denoising=100, + num_labels=80, + hidden_dim=128, + learn_initial_query=False, + num_queries=300, + anchor_image_size=(256, 256), + feat_strides=[16, 32], + num_feature_levels=2, + encoder_in_channels=[512, 1024], + encode_proj_layers=[1], + num_attention_heads=8, + encoder_ffn_dim=512, + num_encoder_layers=1, + hidden_expansion=0.34, + depth_multiplier=0.5, + eval_idx=-1, + num_decoder_layers=3, + decoder_attention_heads=8, + decoder_ffn_dim=512, + decoder_n_points=[6, 6], + lqe_hidden_dim=64, + num_lqe_layers=2, + out_features=["stage3", "stage4"], + image_shape=(256, 256, 3), + ) + detector = DFineObjectDetector( + backbone=backbone, + num_classes=80, + bounding_box_format="yxyx", + ) + + # Sample test image. + test_image = np.random.uniform( + low=0, high=255, size=(1, 256, 256, 3) + ).astype("float32") + + # Make predictions. + predictions = detector.predict(test_image) + + # Access predictions. 
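+    # Decoded boxes are returned in the detector's `bounding_box_format`
+    # ("yxyx" here), alongside per-box labels and confidence scores.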
+ boxes = predictions["boxes"] # Shape: (1, 100, 4) + labels = predictions["labels"] # Shape: (1, 100) + confidence = predictions["confidence"] # Shape: (1, 100) + num_detections = predictions["num_detections"] # Shape: (1,) + ``` + """ + + backbone_cls = DFineBackbone + preprocessor_cls = DFineObjectDetectorPreprocessor + + def __init__( + self, + backbone, + num_classes, + bounding_box_format="yxyx", + preprocessor=None, + matcher_class_cost=2.0, + matcher_bbox_cost=5.0, + matcher_ciou_cost=2.0, + use_focal_loss=True, + matcher_alpha=0.25, + matcher_gamma=2.0, + weight_loss_vfl=1.0, + weight_loss_bbox=5.0, + weight_loss_ciou=2.0, + weight_loss_fgl=0.15, + weight_loss_ddf=1.5, + ddf_temperature=5.0, + prediction_decoder=None, + activation=None, + **kwargs, + ): + assert_bounding_box_support(self.__class__.__name__) + + # === Functional Model === + image_input = keras.layers.Input( + shape=backbone.image_shape, name="images" + ) + outputs = backbone(image_input) + intermediate_logits = outputs["intermediate_logits"] + intermediate_reference_points = outputs["intermediate_reference_points"] + intermediate_predicted_corners = outputs[ + "intermediate_predicted_corners" + ] + initial_reference_points = outputs["initial_reference_points"] + logits = intermediate_logits[:, -1, :, :] + pred_boxes = intermediate_reference_points[:, -1, :, :] + model_outputs = { + "logits": logits, + "pred_boxes": pred_boxes, + "intermediate_logits": intermediate_logits, + "intermediate_reference_points": intermediate_reference_points, + "intermediate_predicted_corners": intermediate_predicted_corners, + "initial_reference_points": initial_reference_points, + "enc_topk_logits": outputs["enc_topk_logits"], + "enc_topk_bboxes": outputs["enc_topk_bboxes"], + } + if "dn_num_group" in outputs: + model_outputs["dn_positive_idx"] = outputs["dn_positive_idx"] + model_outputs["dn_num_group"] = outputs["dn_num_group"] + model_outputs["dn_num_split"] = outputs["dn_num_split"] + super().__init__( + inputs=image_input, + outputs=model_outputs, + **kwargs, + ) + + # === Config === + self.backbone = backbone + self.num_classes = num_classes + self.bounding_box_format = bounding_box_format + self.preprocessor = preprocessor + self.matcher_class_cost = matcher_class_cost + self.matcher_bbox_cost = matcher_bbox_cost + self.matcher_ciou_cost = matcher_ciou_cost + self.use_focal_loss = use_focal_loss + self.matcher_alpha = matcher_alpha + self.matcher_gamma = matcher_gamma + self.weight_dict = { + "loss_vfl": weight_loss_vfl, + "loss_bbox": weight_loss_bbox, + "loss_ciou": weight_loss_ciou, + "loss_fgl": weight_loss_fgl, + "loss_ddf": weight_loss_ddf, + } + self.ddf_temperature = ddf_temperature + self.activation = activation + self._prediction_decoder = prediction_decoder or NonMaxSuppression( + from_logits=(self.activation != keras.activations.sigmoid), + bounding_box_format=self.bounding_box_format, + max_detections=backbone.num_queries, + ) + + def compute_loss(self, x, y, y_pred, sample_weight, **kwargs): + gt_boxes = y["boxes"] + gt_labels = y["labels"] + batch_size = keras.ops.shape(gt_labels)[0] + num_objects = keras.ops.shape(gt_labels)[1] + num_targets_per_image = keras.ops.tile( + keras.ops.expand_dims(num_objects, 0), [batch_size] + ) + labels_for_item = keras.ops.reshape(gt_labels, [-1]) + boxes_for_item = keras.ops.reshape(gt_boxes, [-1, 4]) + targets = {"labels": labels_for_item, "boxes": boxes_for_item} + + intermediate_logits_all = y_pred["intermediate_logits"] + intermediate_ref_points_all = 
y_pred["intermediate_reference_points"] + predicted_corners_all = y_pred["intermediate_predicted_corners"] + initial_ref_points_all = y_pred["initial_reference_points"] + enc_topk_logits = y_pred["enc_topk_logits"] + enc_topk_bboxes = y_pred["enc_topk_bboxes"] + if "dn_num_group" in y_pred: + denoising_meta_values = { + "dn_positive_idx": y_pred["dn_positive_idx"], + "dn_num_group": y_pred["dn_num_group"], + "dn_num_split": y_pred["dn_num_split"], + } + dn_split_point = self.backbone.dn_split_point + ( + dn_intermediate_logits, + matching_intermediate_logits, + ) = keras.ops.split( + intermediate_logits_all, [dn_split_point], axis=2 + ) + ( + dn_intermediate_ref_points, + matching_intermediate_ref_points, + ) = keras.ops.split( + intermediate_ref_points_all, [dn_split_point], axis=2 + ) + ( + dn_predicted_corners, + matching_predicted_corners, + ) = keras.ops.split(predicted_corners_all, [dn_split_point], axis=2) + ( + dn_initial_ref_points, + matching_initial_ref_points, + ) = keras.ops.split( + initial_ref_points_all, [dn_split_point], axis=2 + ) + else: + denoising_meta_values = None + matching_intermediate_logits = intermediate_logits_all + matching_intermediate_ref_points = intermediate_ref_points_all + matching_predicted_corners = predicted_corners_all + matching_initial_ref_points = initial_ref_points_all + matching_logits = matching_intermediate_logits[:, -1, :, :] + matching_pred_boxes = matching_intermediate_ref_points[:, -1, :, :] + outputs_without_aux = { + "logits": matching_logits, + "pred_boxes": keras.ops.clip(matching_pred_boxes, 0, 1), + } + indices = hungarian_matcher( + outputs_without_aux, + [targets], + num_targets_per_image, + self.use_focal_loss, + self.matcher_alpha, + self.matcher_gamma, + self.matcher_bbox_cost, + self.matcher_class_cost, + self.matcher_ciou_cost, + self.backbone, + ) + num_boxes = keras.ops.shape(labels_for_item)[0] + num_boxes = keras.ops.convert_to_tensor(num_boxes, dtype="float32") + num_boxes = keras.ops.maximum(num_boxes, 1.0) + losses = {} + vfl_loss = compute_vfl_loss( + outputs_without_aux, + [targets], + indices, + num_boxes, + self.num_classes, + self.matcher_alpha, + self.matcher_gamma, + ) + losses.update( + { + k: v * self.weight_dict[k] + for k, v in vfl_loss.items() + if k in self.weight_dict + } + ) + box_losses = compute_box_losses( + outputs_without_aux, [targets], indices, num_boxes + ) + losses.update( + { + k: v * self.weight_dict[k] + for k, v in box_losses.items() + if k in self.weight_dict + } + ) + local_losses = compute_local_losses( + { + **outputs_without_aux, + "pred_corners": matching_predicted_corners[:, -1, :, :], + "ref_points": matching_initial_ref_points[:, -1, :, :], + "teacher_corners": keras.ops.zeros_like( + matching_predicted_corners[:, -1, :, :] + ), + "teacher_logits": keras.ops.zeros_like(matching_logits), + }, + [targets], + indices, + num_boxes, + self.backbone, + self.ddf_temperature, + compute_ddf=False, + ) + losses.update( + { + k: v * self.weight_dict[k] + for k, v in local_losses.items() + if k in self.weight_dict + } + ) + + num_aux_layers = self.backbone.num_decoder_layers + auxiliary_outputs_list = [ + { + "logits": matching_intermediate_logits[:, i, :, :], + "pred_boxes": keras.ops.clip( + matching_intermediate_ref_points[:, i, :, :], 0, 1 + ), + "pred_corners": matching_predicted_corners[:, i, :, :], + "ref_points": matching_initial_ref_points[:, i, :, :], + "teacher_corners": matching_predicted_corners[:, -1, :, :], + "teacher_logits": matching_intermediate_logits[:, -1, :, :], + } + for i 
in range(num_aux_layers) + ] + for i, aux_output in enumerate(auxiliary_outputs_list): + aux_indices = hungarian_matcher( + aux_output, + [targets], + num_targets_per_image, + self.use_focal_loss, + self.matcher_alpha, + self.matcher_gamma, + self.matcher_bbox_cost, + self.matcher_class_cost, + self.matcher_ciou_cost, + self.backbone, + ) + aux_vfl_loss = compute_vfl_loss( + aux_output, + [targets], + aux_indices, + num_boxes, + self.num_classes, + self.matcher_alpha, + self.matcher_gamma, + ) + aux_box_losses = compute_box_losses( + aux_output, [targets], aux_indices, num_boxes + ) + is_not_last_aux_layer = i < len(auxiliary_outputs_list) - 1 + aux_local_losses = compute_local_losses( + aux_output, + [targets], + aux_indices, + num_boxes, + self.backbone, + self.ddf_temperature, + compute_ddf=is_not_last_aux_layer, + ) + aux_losses = {**aux_vfl_loss, **aux_box_losses, **aux_local_losses} + weighted_aux_losses = { + k + f"_aux_{i}": v * self.weight_dict[k] + for k, v in aux_losses.items() + if k in self.weight_dict + } + losses.update(weighted_aux_losses) + # Add encoder loss. + enc_output = { + "logits": enc_topk_logits, + "pred_boxes": keras.ops.clip(enc_topk_bboxes, 0, 1), + } + enc_indices = hungarian_matcher( + enc_output, + [targets], + num_targets_per_image, + self.use_focal_loss, + self.matcher_alpha, + self.matcher_gamma, + self.matcher_bbox_cost, + self.matcher_class_cost, + self.matcher_ciou_cost, + self.backbone, + ) + enc_vfl_loss = compute_vfl_loss( + enc_output, + [targets], + enc_indices, + num_boxes, + self.num_classes, + self.matcher_alpha, + self.matcher_gamma, + ) + enc_box_losses = compute_box_losses( + enc_output, [targets], enc_indices, num_boxes + ) + enc_losses = {**enc_vfl_loss, **enc_box_losses} + weighted_enc_losses = { + k + "_enc": v * self.weight_dict[k] + for k, v in enc_losses.items() + if k in self.weight_dict + } + losses.update(weighted_enc_losses) + + if denoising_meta_values is not None: + dn_indices = get_cdn_matched_indices(denoising_meta_values) + dn_num_group = denoising_meta_values["dn_num_group"][0] + num_boxes_dn = num_boxes * keras.ops.cast(dn_num_group, "float32") + num_dn_layers = self.backbone.num_decoder_layers + 1 + for i in range(num_dn_layers): + is_not_last_layer = keras.ops.less(i, num_dn_layers - 1) + teacher_idx = num_dn_layers - 1 + dn_aux_output = { + "logits": dn_intermediate_logits[:, i, :, :], + "pred_boxes": keras.ops.clip( + dn_intermediate_ref_points[:, i, :, :], 0, 1 + ), + "pred_corners": dn_predicted_corners[:, i, :, :], + "ref_points": dn_initial_ref_points[:, i, :, :], + "teacher_corners": dn_predicted_corners[ + :, teacher_idx, :, : + ], + "teacher_logits": dn_intermediate_logits[ + :, teacher_idx, :, : + ], + } + vfl_loss = compute_vfl_loss( + dn_aux_output, + [targets], + dn_indices, + num_boxes_dn, + self.num_classes, + self.matcher_alpha, + self.matcher_gamma, + ) + box_losses = compute_box_losses( + dn_aux_output, [targets], dn_indices, num_boxes_dn + ) + local_losses = compute_local_losses( + dn_aux_output, + [targets], + dn_indices, + num_boxes_dn, + self.backbone, + self.ddf_temperature, + compute_ddf=is_not_last_layer, + ) + all_losses = {**vfl_loss, **box_losses, **local_losses} + weighted_losses = { + k + f"_dn_{i}": v * self.weight_dict[k] + for k, v in all_losses.items() + if k in self.weight_dict + } + losses.update(weighted_losses) + total_loss = keras.ops.sum([v for v in losses.values()]) + return total_loss + + @property + def prediction_decoder(self): + return self._prediction_decoder + + 
@prediction_decoder.setter + def prediction_decoder(self, prediction_decoder): + if prediction_decoder.bounding_box_format != self.bounding_box_format: + raise ValueError( + "Expected `prediction_decoder` and `DFineObjectDetector` to " + "use the same `bounding_box_format`, but got " + "`prediction_decoder.bounding_box_format=" + f"{prediction_decoder.bounding_box_format}`, and " + "`self.bounding_box_format=" + f"{self.bounding_box_format}`." + ) + self._prediction_decoder = prediction_decoder + self.make_predict_function(force=True) + self.make_train_function(force=True) + self.make_test_function(force=True) + + def decode_predictions(self, predictions, data): + """Decodes raw model predictions into final bounding boxes. + + This method takes the raw output from the model (logits and normalized + bounding boxes in center format) and converts them into the final + detection format. The process involves: + 1. Denormalizing the bounding box coordinates to the original image + dimensions. + 2. Converting boxes from center format `(cx, cy, w, h)` to corner + format `(ymin, xmin, ymax, xmax)`. + 3. Applying non-maximum suppression to filter out overlapping boxes + and keep only the most confident detections. + + Args: + predictions: dict, A dictionary of tensors from the model, + containing `"logits"` and `"pred_boxes"`. + data: tuple, The input data tuple, from which the original images + are extracted to obtain their dimensions for denormalization. + + Returns: + Dictionary: Final predictions, containing `"boxes"`, `"labels"`, + `"confidence"`, and `"num_detections"`. + """ + if isinstance(data, (list, tuple)): + images, _ = data + else: + images = data + logits = predictions["logits"] + pred_boxes = predictions["pred_boxes"] + height, width, _ = keras.ops.shape(images)[1:] + denormalized_boxes = keras.ops.stack( + [ + pred_boxes[..., 0] * width, # center_x + pred_boxes[..., 1] * height, # center_y + pred_boxes[..., 2] * width, # width + pred_boxes[..., 3] * height, # height + ], + axis=-1, + ) + pred_boxes_xyxy = keras.utils.bounding_boxes.convert_format( + denormalized_boxes, + source="center_xywh", + target="xyxy", + ) + pred_boxes_yxyx = keras.ops.stack( + [ + pred_boxes_xyxy[..., 1], # y_min + pred_boxes_xyxy[..., 0], # x_min + pred_boxes_xyxy[..., 3], # y_max + pred_boxes_xyxy[..., 2], # x_max + ], + axis=-1, + ) + y_pred = self.prediction_decoder(pred_boxes_yxyx, logits, images=images) + return y_pred + + def get_config(self): + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + "bounding_box_format": self.bounding_box_format, + "matcher_class_cost": self.matcher_class_cost, + "matcher_bbox_cost": self.matcher_bbox_cost, + "matcher_ciou_cost": self.matcher_ciou_cost, + "use_focal_loss": self.use_focal_loss, + "matcher_alpha": self.matcher_alpha, + "matcher_gamma": self.matcher_gamma, + "weight_loss_vfl": self.weight_dict["loss_vfl"], + "weight_loss_bbox": self.weight_dict["loss_bbox"], + "weight_loss_ciou": self.weight_dict["loss_ciou"], + "weight_loss_fgl": self.weight_dict["loss_fgl"], + "weight_loss_ddf": self.weight_dict["loss_ddf"], + "ddf_temperature": self.ddf_temperature, + "prediction_decoder": keras.saving.serialize_keras_object( + self._prediction_decoder + ), + } + ) + return config + + def predict_step(self, *args): + outputs = super().predict_step(*args) + if isinstance(outputs, tuple): + return self.decode_predictions(outputs[0], args[-1]), outputs[1] + return self.decode_predictions(outputs, *args) + + @classmethod + def 
from_config(cls, config): + config = config.copy() + if "backbone" in config and isinstance(config["backbone"], dict): + config["backbone"] = keras.saving.deserialize_keras_object( + config["backbone"] + ) + if "preprocessor" in config and isinstance( + config["preprocessor"], dict + ): + config["preprocessor"] = keras.saving.deserialize_keras_object( + config["preprocessor"] + ) + if "prediction_decoder" in config and isinstance( + config["prediction_decoder"], dict + ): + config["prediction_decoder"] = ( + keras.saving.deserialize_keras_object( + config["prediction_decoder"] + ) + ) + return cls(**config) diff --git a/keras_hub/src/models/d_fine/d_fine_object_detector_preprocessor.py b/keras_hub/src/models/d_fine/d_fine_object_detector_preprocessor.py new file mode 100644 index 0000000000..2708229898 --- /dev/null +++ b/keras_hub/src/models/d_fine/d_fine_object_detector_preprocessor.py @@ -0,0 +1,14 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.d_fine.d_fine_backbone import DFineBackbone +from keras_hub.src.models.d_fine.d_fine_image_converter import ( + DFineImageConverter, +) +from keras_hub.src.models.object_detector_preprocessor import ( + ObjectDetectorPreprocessor, +) + + +@keras_hub_export("keras_hub.models.DFineObjectDetectorPreprocessor") +class DFineObjectDetectorPreprocessor(ObjectDetectorPreprocessor): + backbone_cls = DFineBackbone + image_converter_cls = DFineImageConverter diff --git a/keras_hub/src/models/d_fine/d_fine_object_detector_test.py b/keras_hub/src/models/d_fine/d_fine_object_detector_test.py new file mode 100644 index 0000000000..3b3bfe14c0 --- /dev/null +++ b/keras_hub/src/models/d_fine/d_fine_object_detector_test.py @@ -0,0 +1,154 @@ +import keras +import numpy as np +import pytest +from absl.testing import parameterized +from packaging import version + +from keras_hub.src.models.d_fine.d_fine_backbone import DFineBackbone +from keras_hub.src.models.d_fine.d_fine_image_converter import ( + DFineImageConverter, +) +from keras_hub.src.models.d_fine.d_fine_object_detector import ( + DFineObjectDetector, +) +from keras_hub.src.models.d_fine.d_fine_object_detector_preprocessor import ( + DFineObjectDetectorPreprocessor, +) +from keras_hub.src.models.hgnetv2.hgnetv2_backbone import HGNetV2Backbone +from keras_hub.src.tests.test_case import TestCase + + +@pytest.mark.skipif( + version.parse(keras.__version__) < version.parse("3.8.0"), + reason="Bbox utils are not supported before Keras < 3.8.0", +) +class DFineObjectDetectorTest(TestCase): + def setUp(self): + self.labels = [ + { + "boxes": np.array([[0.5, 0.5, 0.2, 0.2]]), + "labels": np.array([1]), + }, + { + "boxes": np.array([[0.6, 0.6, 0.3, 0.3]]), + "labels": np.array([2]), + }, + ] + self.stackwise_stage_filters = [ + [8, 8, 16, 1, 1, 3], + [16, 8, 32, 1, 1, 3], + ] + self.apply_downsample = [False, True] + self.use_lightweight_conv_block = [False, False] + self.input_size = 32 + self.bounding_box_format = "yxyx" + + image_converter = DFineImageConverter( + bounding_box_format=self.bounding_box_format, + image_size=(self.input_size, self.input_size), + ) + preprocessor = DFineObjectDetectorPreprocessor( + image_converter=image_converter, + ) + self.preprocessor = preprocessor + self.images = np.random.uniform( + low=0, high=255, size=(1, self.input_size, self.input_size, 3) + ).astype("float32") + self.bounding_boxes = { + "boxes": np.array([[[10.0, 10.0, 20.0, 20.0]]]), + "labels": np.array([[0]]), + } + self.train_data = ( + self.images, + self.bounding_boxes, + ) + 
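+        # Illustrative end-to-end sketch of how this data is consumed
+        # (hypothetical `detector` variable; `run_task_test` below exercises
+        # the same path):
+        #     detector = DFineObjectDetector(
+        #         backbone=DFineBackbone(**self.base_backbone_kwargs),
+        #         num_classes=4,
+        #         bounding_box_format=self.bounding_box_format,
+        #         preprocessor=self.preprocessor,
+        #     )
+        #     detector.fit(*self.train_data, epochs=1)
+        #     preds = detector.predict(self.images)
+        #     # `preds` is a dict with "boxes", "labels", "confidence" and
+        #     # "num_detections".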
hgnetv2_backbone = HGNetV2Backbone( + stem_channels=[3, 8, 8], + stackwise_stage_filters=self.stackwise_stage_filters, + apply_downsample=self.apply_downsample, + use_lightweight_conv_block=self.use_lightweight_conv_block, + depths=[1, 1], + hidden_sizes=[16, 32], + embedding_size=8, + use_learnable_affine_block=True, + hidden_act="relu", + image_shape=(None, None, 3), + out_features=["stage1", "stage2"], + data_format="channels_last", + ) + self.base_backbone_kwargs = { + "backbone": hgnetv2_backbone, + "decoder_in_channels": [16, 16], + "encoder_hidden_dim": 16, + "num_denoising": 10, + "num_labels": 4, + "hidden_dim": 16, + "learn_initial_query": False, + "num_queries": 10, + "anchor_image_size": (self.input_size, self.input_size), + "feat_strides": [4, 8], + "num_feature_levels": 2, + "encoder_in_channels": [16, 32], + "encode_proj_layers": [1], + "num_attention_heads": 2, + "encoder_ffn_dim": 32, + "num_encoder_layers": 1, + "hidden_expansion": 0.5, + "depth_multiplier": 0.5, + "eval_idx": -1, + "num_decoder_layers": 1, + "decoder_attention_heads": 2, + "decoder_ffn_dim": 32, + "decoder_method": "default", + "decoder_n_points": [2, 2], + "lqe_hidden_dim": 16, + "num_lqe_layers": 1, + "out_features": ["stage1", "stage2"], + "image_shape": (None, None, 3), + "data_format": "channels_last", + "seed": 0, + } + + @parameterized.named_parameters( + ("default", False), + ("denoising", True), + ) + def test_detection_basics(self, use_noise_and_labels): + backbone_kwargs = self.base_backbone_kwargs.copy() + if use_noise_and_labels: + backbone_kwargs["box_noise_scale"] = 1.0 + backbone_kwargs["label_noise_ratio"] = 0.5 + backbone_kwargs["labels"] = self.labels + backbone = DFineBackbone(**backbone_kwargs) + init_kwargs = { + "backbone": backbone, + "num_classes": 4, + "bounding_box_format": self.bounding_box_format, + "preprocessor": self.preprocessor, + } + self.run_task_test( + cls=DFineObjectDetector, + init_kwargs=init_kwargs, + train_data=self.train_data, + expected_output_shape={ + "boxes": (1, 10, 4), + "labels": (1, 10), + "confidence": (1, 10), + "num_detections": (1,), + }, + ) + + @pytest.mark.large + def test_saved_model(self): + backbone = DFineBackbone(**self.base_backbone_kwargs) + init_kwargs = { + "backbone": backbone, + "num_classes": 4, + "bounding_box_format": self.bounding_box_format, + "preprocessor": self.preprocessor, + } + self.run_model_saving_test( + cls=DFineObjectDetector, + init_kwargs=init_kwargs, + input_data=self.images, + ) diff --git a/keras_hub/src/models/d_fine/d_fine_presets.py b/keras_hub/src/models/d_fine/d_fine_presets.py new file mode 100644 index 0000000000..d976272974 --- /dev/null +++ b/keras_hub/src/models/d_fine/d_fine_presets.py @@ -0,0 +1,2 @@ +# Metadata for loading pretrained model weights. +backbone_presets = {} diff --git a/keras_hub/src/models/d_fine/d_fine_utils.py b/keras_hub/src/models/d_fine/d_fine_utils.py new file mode 100644 index 0000000000..770d6fc7f8 --- /dev/null +++ b/keras_hub/src/models/d_fine/d_fine_utils.py @@ -0,0 +1,827 @@ +import keras + + +def d_fine_kernel_initializer(initializer_range=0.01, name="random_normal"): + if name == "random_normal": + return keras.initializers.RandomNormal( + mean=0.0, stddev=initializer_range + ) + elif name == "glorot_uniform": + return keras.initializers.GlorotUniform() + elif name == "zeros": + return keras.initializers.Zeros() + + +def inverse_sigmoid(x, eps=1e-5): + """Computes the inverse sigmoid (logit) function. 
+ + This function computes the inverse of the sigmoid function, also known as + the logit function. It is used in D-FINE to transform bounding box + coordinates from the `[0, 1]` range back to logits, for example in + `DFineContrastiveDenoisingGroupGenerator` and `DFineDecoder`. + + Args: + x: Tensor, Input tensor with values in `[0, 1]`. + eps: float, Small epsilon value to prevent numerical instability + at the boundaries. Default is `1e-5`. + + Returns: + Tensor: The inverse sigmoid of the input tensor. + """ + x = keras.ops.clip(x, 0, 1) + x1 = keras.ops.maximum(x, eps) + x2 = keras.ops.maximum(1 - x, eps) + return keras.ops.log(x1 / x2) + + +def grid_sample(data, grid, align_corners=False, height=None, width=None): + """Samples data at specified grid locations using bilinear interpolation. + + This function performs bilinear interpolation to sample data at arbitrary + grid locations. It is a core component of the deformable attention + mechanism, used within `multi_scale_deformable_attention_v2`. + This is a Keras-native implementation (polyfill) for + `torch.nn.functional.grid_sample`. + + Args: + data: Tensor, Input data tensor of shape `[batch, channels, height, + width]`. + grid: Tensor, Grid coordinates of shape `[batch, out_height, out_width, + 2]`. The last dimension contains `(x, y)` coordinates normalized to + `[-1, 1]`. + align_corners: bool, If `True`, align corners for coordinate mapping. + Default is `False`. + height: int, optional, Override height for coordinate normalization. + width: int, optional, Override width for coordinate normalization. + + Returns: + Tensor: Sampled data of shape `[batch, channels, out_height, + out_width]`. + """ + num_batch, _, data_height, data_width = keras.ops.shape(data) + _, out_height, out_width, _ = keras.ops.shape(grid) + dtype = data.dtype + grid_x_norm = grid[..., 0] + grid_y_norm = grid[..., 1] + h_in = height if height is not None else data_height + w_in = width if width is not None else data_width + height_f = keras.ops.cast(h_in, dtype=dtype) + width_f = keras.ops.cast(w_in, dtype=dtype) + if align_corners: + x_unnorm = (grid_x_norm + 1) / 2 * (width_f - 1) + y_unnorm = (grid_y_norm + 1) / 2 * (height_f - 1) + else: + x_unnorm = ((grid_x_norm + 1) / 2 * width_f) - 0.5 + y_unnorm = ((grid_y_norm + 1) / 2 * height_f) - 0.5 + x0 = keras.ops.floor(x_unnorm) + y0 = keras.ops.floor(y_unnorm) + x1 = x0 + 1 + y1 = y0 + 1 + w_y0_val = y1 - y_unnorm + w_y1_val = y_unnorm - y0 + w_x0_val = x1 - x_unnorm + w_x1_val = x_unnorm - x0 + data_permuted = keras.ops.transpose(data, (0, 2, 3, 1)) + + def gather_padded( + data_p, + y_coords, + x_coords, + actual_data_height, + actual_data_width, + override_height=None, + override_width=None, + ): + y_coords_int = keras.ops.cast(y_coords, "int32") + x_coords_int = keras.ops.cast(x_coords, "int32") + + y_oob = keras.ops.logical_or( + y_coords_int < 0, y_coords_int >= actual_data_height + ) + x_oob = keras.ops.logical_or( + x_coords_int < 0, x_coords_int >= actual_data_width + ) + oob_mask = keras.ops.logical_or(y_oob, x_oob) + + y_coords_clipped = keras.ops.clip( + y_coords_int, 0, actual_data_height - 1 + ) + x_coords_clipped = keras.ops.clip( + x_coords_int, 0, actual_data_width - 1 + ) + + width_for_indexing = ( + override_width if override_width is not None else actual_data_width + ) + + if override_height is not None and override_width is not None: + data_flat = keras.ops.reshape( + data_p, + ( + num_batch, + override_height * override_width, + keras.ops.shape(data_p)[-1], + ), + ) + else: + 
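+            # No static override shape was given, so flatten whatever
+            # spatial extent `data_p` actually has at runtime.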
data_flat = keras.ops.reshape( + data_p, (num_batch, -1, keras.ops.shape(data_p)[-1]) + ) + y_coords_flat = keras.ops.reshape( + y_coords_clipped, (num_batch, out_height * out_width) + ) + x_coords_flat = keras.ops.reshape( + x_coords_clipped, (num_batch, out_height * out_width) + ) + indices = y_coords_flat * width_for_indexing + x_coords_flat + + num_elements_per_batch = keras.ops.shape(data_flat)[1] + batch_offsets = ( + keras.ops.arange(num_batch, dtype=indices.dtype) + * num_elements_per_batch + ) + batch_offsets = keras.ops.reshape(batch_offsets, (num_batch, 1)) + absolute_indices = indices + batch_offsets + data_reshaped_for_gather = keras.ops.reshape( + data_flat, (-1, keras.ops.shape(data_flat)[-1]) + ) + gathered = keras.ops.take( + data_reshaped_for_gather, absolute_indices, axis=0 + ) + gathered = keras.ops.reshape( + gathered, (num_batch, out_height, out_width, -1) + ) + oob_mask_expanded = keras.ops.expand_dims(oob_mask, axis=-1) + gathered_values = gathered * keras.ops.cast( + keras.ops.logical_not(oob_mask_expanded), dtype=gathered.dtype + ) + return gathered_values + + batch_indices = keras.ops.arange(0, num_batch, dtype="int32") + batch_indices = keras.ops.reshape(batch_indices, (num_batch, 1, 1)) + batch_indices = keras.ops.tile(batch_indices, (1, out_height, out_width)) + val_y0_x0 = gather_padded(data_permuted, y0, x0, h_in, w_in, height, width) + val_y0_x1 = gather_padded(data_permuted, y0, x1, h_in, w_in, height, width) + val_y1_x0 = gather_padded(data_permuted, y1, x0, h_in, w_in, height, width) + val_y1_x1 = gather_padded(data_permuted, y1, x1, h_in, w_in, height, width) + interp_val = ( + val_y0_x0 * keras.ops.expand_dims(w_y0_val * w_x0_val, -1) + + val_y0_x1 * keras.ops.expand_dims(w_y0_val * w_x1_val, -1) + + val_y1_x0 * keras.ops.expand_dims(w_y1_val * w_x0_val, -1) + + val_y1_x1 * keras.ops.expand_dims(w_y1_val * w_x1_val, -1) + ) + + return keras.ops.transpose(interp_val, (0, 3, 1, 2)) + + +def multi_scale_deformable_attention_v2( + value, + dynamic_spatial_shapes, + sampling_locations, + attention_weights, + num_points, + slice_sizes, + spatial_shapes, + num_levels, + num_queries, + method="default", +): + """Computes multi-scale deformable attention mechanism. + + This function implements the core of the multi-scale deformable attention + mechanism used in `DFineMultiScaleDeformableAttention`. It samples features + at multiple scales and locations based on learned attention weights and + sampling locations. + + Args: + value: Tensor, Feature values of shape `[batch, seq_len, num_heads, + hidden_dim]`. + dynamic_spatial_shapes: Tensor, Spatial shapes for each level. + sampling_locations: Tensor, Sampling locations of shape + `[batch, num_queries, num_heads, num_levels, num_points, 2]`. + attention_weights: Tensor, Attention weights of shape `[batch, + num_queries, num_heads, total_points]`. + num_points: list, Number of sampling points for each level. + slice_sizes: list, Sizes for slicing the value tensor. + spatial_shapes: list, Spatial shapes for each level. + num_levels: int, Number of feature levels. + num_queries: int, Number of queries. + method: str, Sampling method, either `"default"` or `"discrete"`. + Default is `"default"`. + + Returns: + Tensor: Output features of shape `[batch, num_queries, num_heads * + hidden_dim]`. 
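+
+    Example shape walk-through (illustrative values): with `num_levels=2`,
+    `num_points=[4, 4]` and `spatial_shapes=[(80, 80), (40, 40)]`, `value`
+    has `seq_len = 80 * 80 + 40 * 40 = 8000`, `slice_sizes=[6400, 1600]`
+    recovers the per-level feature maps before sampling, and the output
+    keeps the `[batch, num_queries, num_heads * hidden_dim]` shape above.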
+ """ + value_shape = keras.ops.shape(value) + batch_size = value_shape[0] + num_heads = value_shape[2] + hidden_dim = value_shape[3] + sampling_shape = keras.ops.shape(sampling_locations) + num_levels_from_shape = sampling_shape[3] + num_points_from_shape = sampling_shape[4] + permuted_value = keras.ops.transpose(value, axes=(0, 2, 3, 1)) + seq_len = value_shape[1] + flattened_value = keras.ops.reshape( + permuted_value, (-1, hidden_dim, seq_len) + ) + value_chunk_sizes = keras.ops.array(slice_sizes, dtype="int32") + cum_sizes = keras.ops.concatenate( + [ + keras.ops.zeros((1,), dtype="int32"), + keras.ops.cumsum(value_chunk_sizes), + ] + ) + values = [] + for i in range(len(spatial_shapes)): + start = cum_sizes[i] + current_slice_size = slice_sizes[i] + dynamic_slice_start_indices = (0, 0, start) + dynamic_slice_shape = ( + keras.ops.shape(flattened_value)[0], + keras.ops.shape(flattened_value)[1], + current_slice_size, + ) + sliced_value = keras.ops.slice( + flattened_value, dynamic_slice_start_indices, dynamic_slice_shape + ) + values.append(sliced_value) + if method == "default": + sampling_grids = 2 * sampling_locations - 1 + elif method == "discrete": + sampling_grids = sampling_locations + else: + sampling_grids = 2 * sampling_locations - 1 + permuted_sampling_grids = keras.ops.transpose( + sampling_grids, axes=(0, 2, 1, 3, 4) + ) + flattened_sampling_grids = keras.ops.reshape( + permuted_sampling_grids, + ( + batch_size * num_heads, + num_queries, + num_levels_from_shape, + num_points_from_shape, + ), + ) + cum_points = keras.ops.concatenate( + [ + keras.ops.zeros((1,), dtype="int32"), + keras.ops.cumsum(keras.ops.array(num_points, dtype="int32")), + ] + ) + sampling_grids = [] + for i in range(num_levels): + start = cum_points[i] + current_level_num_points = num_points[i] + slice_start_indices = (0, 0, start, 0) + slice_shape = ( + keras.ops.shape(flattened_sampling_grids)[0], + keras.ops.shape(flattened_sampling_grids)[1], + current_level_num_points, + keras.ops.shape(flattened_sampling_grids)[3], + ) + sliced_grid = keras.ops.slice( + flattened_sampling_grids, slice_start_indices, slice_shape + ) + sampling_grids.append(sliced_grid) + sampling_values = [] + for level_id in range(num_levels): + if spatial_shapes is not None and len(spatial_shapes) == num_levels: + height, width = spatial_shapes[level_id] + else: + height = dynamic_spatial_shapes[level_id, 0] + width = dynamic_spatial_shapes[level_id, 1] + value_l_ = keras.ops.reshape( + values[level_id], + (batch_size * num_heads, hidden_dim, height, width), + ) + sampling_grid_l_ = sampling_grids[level_id] + if method == "default": + sampling_value_l_ = grid_sample( + data=value_l_, + grid=sampling_grid_l_, + align_corners=False, + height=height, + width=width, + ) + elif method == "discrete": + scale_factors = keras.ops.cast( + keras.ops.array([width, height]), + dtype=sampling_grid_l_.dtype, + ) + sampling_coord_float = sampling_grid_l_ * scale_factors + sampling_coord_x_int = keras.ops.cast( + keras.ops.floor(sampling_coord_float[..., 0] + 0.5), "int32" + ) + sampling_coord_y_int = keras.ops.cast( + keras.ops.floor(sampling_coord_float[..., 1] + 0.5), "int32" + ) + clamped_coord_x = keras.ops.clip(sampling_coord_x_int, 0, width - 1) + clamped_coord_y = keras.ops.clip( + sampling_coord_y_int, 0, height - 1 + ) + sampling_coord_stacked = keras.ops.stack( + [clamped_coord_x, clamped_coord_y], axis=-1 + ) + B_prime = batch_size * num_heads + Q_dim = num_queries + P_level = num_points[level_id] + sampling_coord = 
keras.ops.reshape( + sampling_coord_stacked, (B_prime, Q_dim * P_level, 2) + ) + value_l_permuted = keras.ops.transpose(value_l_, (0, 2, 3, 1)) + y_coords_for_gather = sampling_coord[ + ..., 1 + ] # (B_prime, Q_dim * P_level) + x_coords_for_gather = sampling_coord[ + ..., 0 + ] # (B_prime, Q_dim * P_level) + indices = y_coords_for_gather * width + x_coords_for_gather + indices = keras.ops.expand_dims(indices, axis=-1) + value_l_flat = keras.ops.reshape( + value_l_permuted, (B_prime, height * width, hidden_dim) + ) + gathered_values = keras.ops.take_along_axis( + value_l_flat, indices, axis=1 + ) + permuted_gathered_values = keras.ops.transpose( + gathered_values, axes=(0, 2, 1) + ) + sampling_value_l_ = keras.ops.reshape( + permuted_gathered_values, (B_prime, hidden_dim, Q_dim, P_level) + ) + else: + sampling_value_l_ = grid_sample( + data=value_l_, + grid=sampling_grid_l_, + align_corners=False, + height=height, + width=width, + ) + sampling_values.append(sampling_value_l_) + attention_weights = keras.ops.transpose( + attention_weights, axes=(0, 2, 1, 3) + ) + attention_weights = keras.ops.reshape( + attention_weights, + (batch_size * num_heads, 1, num_queries, sum(num_points)), + ) + concatenated_sampling_values = keras.ops.concatenate( + sampling_values, axis=-1 + ) + weighted_values = concatenated_sampling_values * attention_weights + summed_values = keras.ops.sum(weighted_values, axis=-1) + output = keras.ops.reshape( + summed_values, (batch_size, num_heads * hidden_dim, num_queries) + ) + return keras.ops.transpose(output, axes=(0, 2, 1)) + + +def weighting_function(max_num_bins, upsampling_factor, reg_scale): + """Generates weighting values for binning operations. + + This function creates a set of weighting values used for integral-based + bounding box regression. It is used in `DFineDecoder` to create a + projection matrix for converting corner predictions into distances. The + weights follow an exponential distribution around zero. + + Args: + max_num_bins: int, Maximum number of bins to generate. + upsampling_factor: Tensor, A scaling hyperparameter that controls the + range of the bins used for integral-based bounding box regression. + reg_scale: float, Regularization scale factor. + + Returns: + Tensor: Weighting values of shape `[max_num_bins]`. + """ + upper_bound1 = abs(upsampling_factor[0]) * abs(reg_scale) + upper_bound2 = abs(upsampling_factor[0]) * abs(reg_scale) * 2 + step = (upper_bound1 + 1) ** (2 / (max_num_bins - 2)) + left_values = [ + -((step) ** i) + 1 for i in range(max_num_bins // 2 - 1, 0, -1) + ] + right_values = [(step) ** i - 1 for i in range(1, max_num_bins // 2)] + values = ( + [-upper_bound2] + + left_values + + [ + keras.ops.zeros_like( + keras.ops.expand_dims(upsampling_factor[0], axis=0) + ) + ] + + right_values + + [upper_bound2] + ) + values = keras.ops.concatenate(values, 0) + return values + + +def distance2bbox(points, distance, reg_scale): + """Converts distance predictions to bounding boxes. + + This function converts distance predictions from anchor points to + bounding boxes. It is a key part of the regression head in `DFineDecoder`, + transforming the output of the integral-based prediction into final + bounding box coordinates. + + Args: + points: Tensor, Anchor points of shape `[..., 4]` where the last + dimension contains `[x, y, width, height]`. + distance: Tensor, Distance predictions of shape `[..., 4]` where + the last dimension contains `[left, top, right, bottom]` distances. + reg_scale: float, Regularization scale factor. 
+ + Returns: + Tensor: Bounding boxes in center format of shape `[..., 4]` where + the last dimension contains `[center_x, center_y, width, height]`. + """ + reg_scale = abs(reg_scale) + top_left_x = points[..., 0] - (0.5 * reg_scale + distance[..., 0]) * ( + points[..., 2] / reg_scale + ) + top_left_y = points[..., 1] - (0.5 * reg_scale + distance[..., 1]) * ( + points[..., 3] / reg_scale + ) + bottom_right_x = points[..., 0] + (0.5 * reg_scale + distance[..., 2]) * ( + points[..., 2] / reg_scale + ) + bottom_right_y = points[..., 1] + (0.5 * reg_scale + distance[..., 3]) * ( + points[..., 3] / reg_scale + ) + bboxes = keras.ops.stack( + [top_left_x, top_left_y, bottom_right_x, bottom_right_y], axis=-1 + ) + return keras.utils.bounding_boxes.convert_format( + bboxes, + source="xyxy", + target="center_xywh", + dtype=points.dtype, + ) + + +def hungarian_assignment(cost_matrix, num_queries): + """Solves the linear assignment problem using the Hungarian algorithm. + + This function provides a JIT-compatible implementation of the Hungarian + (Munkres) algorithm using pure `keras.ops` operations. It is designed to + replace Scipy's `optimize.linear_sum_assignment` for backend-agnostic + end-to-end model compilation. The implementation uses a stateful loop + with `keras.ops.while_loop`, a state machine pattern with + `keras.ops.switch`, and tensor-only operations to ensure compatibility + with static graphs and standard accelerators. + + Args: + cost_matrix: Tensor, A 2D tensor of shape `(num_rows, num_cols)` + representing the cost of each potential assignment. `num_rows` + typically corresponds to the number of predictions (queries), + and `num_cols` corresponds to number of ground-truth targets. + num_queries: int, The fixed number of queries (predictions) from + the model, used to establish static shapes for JAX compatibility. + + Returns: + Tuple: A tuple `(row_ind, col_ind, valid_mask)` containing: + - row_ind: Tensor with integer indices for the rows (predictions). + - col_ind: Tensor with integer indices for the assigned columns + (targets). + - valid_mask: Boolean tensor where `True` indicates a valid + assignment that falls within the original (unpadded) cost + matrix dimensions. + """ + # Reference: https://github.com/bmc/munkres/blob/master/munkres.py + + original_num_rows, original_num_cols = keras.ops.shape(cost_matrix) + # Pad matrix to be square. + padded_cost_matrix = keras.ops.full( + (num_queries, num_queries), 1e9, dtype=cost_matrix.dtype + ) + padded_cost_matrix = keras.ops.slice_update( + padded_cost_matrix, + (0, 0), + cost_matrix, + ) + # Step 1: Subtract row minima. + cost = padded_cost_matrix - keras.ops.min( + padded_cost_matrix, axis=1, keepdims=True + ) + # Step 2: Subtract column minima. + cost = cost - keras.ops.min(cost, axis=0, keepdims=True) + + def body( + step, + cost, + starred_mask, + row_covered, + col_covered, + primed_mask, + path_start_row, + path_start_col, + ): + zero_mask = keras.ops.abs(cost) < 1e-6 + + def step_2(): + # Initial starring: Star zeros with no starred zero in their row or + # column. + s_mask = keras.ops.zeros_like(starred_mask, dtype="bool") + + def star_zeros(i, s_m): + def star_zeros_in_row(j, s_m_inner): + is_zero = zero_mask[i, j] + # Check if no starred zero in this row. + no_star_in_row = keras.ops.logical_not( + keras.ops.any(s_m_inner[i]) + ) + # Check if no starred zero in this column. 
+ no_star_in_col = keras.ops.logical_not( + keras.ops.any(s_m_inner[:, j]) + ) + + def can_star(): + return keras.ops.scatter_update( + s_m_inner, + [[i, j]], + [True], + ) + + def cannot_star(): + return s_m_inner + + should_star = keras.ops.logical_and( + keras.ops.logical_and(is_zero, no_star_in_row), + no_star_in_col, + ) + return keras.ops.cond(should_star, can_star, cannot_star) + + return keras.ops.fori_loop( + 0, num_queries, star_zeros_in_row, s_m + ) + + s_mask = keras.ops.fori_loop(0, num_queries, star_zeros, s_mask) + return ( + 3, + cost, + s_mask, + keras.ops.zeros_like(row_covered), + keras.ops.zeros_like(col_covered), + keras.ops.zeros_like(primed_mask), + -1, + -1, + ) + + def step_3(): + # Step 3: Cover each column containing a starred zero. + new_col_covered = keras.ops.any(starred_mask, axis=0) + num_covered = keras.ops.sum( + keras.ops.cast(new_col_covered, "int32") + ) + return keras.ops.cond( + num_covered >= num_queries, + lambda: ( + 0, + cost, + starred_mask, + row_covered, + new_col_covered, + primed_mask, + -1, + -1, + ), # Done + lambda: ( + 4, + cost, + starred_mask, + row_covered, + new_col_covered, + primed_mask, + -1, + -1, + ), # Continue to step 4 + ) + + def step_4(): + # Step 4: Find a noncovered zero and prime it. + uncovered_zeros = keras.ops.logical_and( + keras.ops.logical_and( + zero_mask, + keras.ops.logical_not( + keras.ops.expand_dims(row_covered, 1) + ), + ), + keras.ops.logical_not(keras.ops.expand_dims(col_covered, 0)), + ) + + def has_uncovered_zero(): + uncovered_zeros_flat = keras.ops.reshape(uncovered_zeros, [-1]) + first_idx = keras.ops.argmax( + keras.ops.cast(uncovered_zeros_flat, "int32") + ) + r = first_idx // num_queries + c = first_idx % num_queries + p_mask = keras.ops.scatter_update(primed_mask, [[r, c]], [True]) + starred_in_row = starred_mask[r] + + def has_starred_in_row(): + star_col = keras.ops.argmax( + keras.ops.cast(starred_in_row, "int32") + ) + r_cov = keras.ops.scatter_update(row_covered, [[r]], [True]) + c_cov = keras.ops.scatter_update( + col_covered, [[star_col]], [False] + ) + return 4, cost, starred_mask, r_cov, c_cov, p_mask, -1, -1 + + def no_starred_in_row(): + return ( + 5, + cost, + starred_mask, + row_covered, + col_covered, + p_mask, + r, + c, + ) + + return keras.ops.cond( + keras.ops.any(starred_in_row), + has_starred_in_row, + no_starred_in_row, + ) + + def no_uncovered_zero(): + return ( + 6, + cost, + starred_mask, + row_covered, + col_covered, + primed_mask, + -1, + -1, + ) + + return keras.ops.cond( + keras.ops.any(uncovered_zeros), + has_uncovered_zero, + no_uncovered_zero, + ) + + def step_5(): + # Step 5: Construct a series of alternating starred and primed + # zeros. + path = keras.ops.full((num_queries * 2, 2), -1, dtype="int32") + path = keras.ops.scatter_update( + path, [[0]], [[path_start_row, path_start_col]] + ) + + def build_path(count, path_state): + def continue_building(cnt, p): + current_col = p[cnt - 1, 1] + starred_in_col = starred_mask[:, current_col] + + def found_star(): + star_row = keras.ops.argmax( + keras.ops.cast(starred_in_col, "int32") + ) + p1 = keras.ops.scatter_update( + p, [[cnt]], [[star_row, current_col]] + ) + primed_in_star_row = primed_mask[star_row] + prime_col = keras.ops.argmax( + keras.ops.cast(primed_in_star_row, "int32") + ) + p2 = keras.ops.scatter_update( + p1, [[cnt + 1]], [[star_row, prime_col]] + ) + return cnt + 2, p2 + + def no_star(): + # Path complete. 
+ return cnt, p + + return keras.ops.cond( + keras.ops.any(starred_in_col), found_star, no_star + ) + + def should_continue(cnt, p): + return keras.ops.logical_and( + cnt < num_queries * 2, p[cnt - 1, 1] >= 0 + ) + + return keras.ops.while_loop( + should_continue, + continue_building, + (count, path_state), + maximum_iterations=num_queries, + ) + + path_count, final_path = build_path(1, path) + s_mask = starred_mask + + def update_star_mask(i, mask): + def apply_update(): + row_idx = final_path[i, 0] + col_idx = final_path[i, 1] + valid_row = keras.ops.logical_and( + row_idx >= 0, row_idx < num_queries + ) + valid_col = keras.ops.logical_and( + col_idx >= 0, col_idx < num_queries + ) + valid_indices = keras.ops.logical_and(valid_row, valid_col) + + def do_update(): + current_value = mask[row_idx, col_idx] + new_value = keras.ops.logical_not(current_value) + return keras.ops.scatter_update( + mask, [[row_idx, col_idx]], [new_value] + ) + + def skip_update(): + return mask + + return keras.ops.cond(valid_indices, do_update, skip_update) + + def skip_iteration(): + return mask + + should_process = i < path_count + return keras.ops.cond( + should_process, apply_update, skip_iteration + ) + + s_mask = keras.ops.fori_loop( + 0, num_queries * 2, update_star_mask, s_mask + ) + return ( + 3, + cost, + s_mask, + keras.ops.zeros_like(row_covered), + keras.ops.zeros_like(col_covered), + keras.ops.zeros_like(primed_mask), + -1, + -1, + ) + + def step_6(): + # Step 6: Add/subtract minimum uncovered value. + uncovered_mask = keras.ops.logical_and( + keras.ops.logical_not(keras.ops.expand_dims(row_covered, 1)), + keras.ops.logical_not(keras.ops.expand_dims(col_covered, 0)), + ) + min_val = keras.ops.min(keras.ops.where(uncovered_mask, cost, 1e9)) + # Add to covered rows. + row_adjustment = keras.ops.where( + keras.ops.expand_dims(row_covered, 1), min_val, 0.0 + ) + # Subtract from uncovered columns. + col_adjustment = keras.ops.where( + keras.ops.expand_dims(col_covered, 0), 0.0, -min_val + ) + new_cost = cost + row_adjustment + col_adjustment + return ( + 4, + new_cost, + starred_mask, + row_covered, + col_covered, + primed_mask, + -1, + -1, + ) + + return keras.ops.switch( + step - 2, [step_2, step_3, step_4, step_5, step_6] + ) + + # Main algorithm loop. 
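+    # State machine recap: step 2 stars independent zeros, step 3 covers
+    # columns containing starred zeros (and terminates once every column is
+    # covered), step 4 primes uncovered zeros, step 5 flips stars along an
+    # augmenting path of alternating starred and primed zeros, and step 6
+    # shifts costs by the minimum uncovered value.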
+ init_state = ( + 2, # Start at step 2 + cost, + keras.ops.zeros( + (num_queries, num_queries), dtype="bool" + ), # starred_mask + keras.ops.zeros((num_queries,), dtype="bool"), # row_covered + keras.ops.zeros((num_queries,), dtype="bool"), # col_covered + keras.ops.zeros( + (num_queries, num_queries), dtype="bool" + ), # primed_mask + -1, # path_start_row + -1, # path_start_col + ) + final_state = keras.ops.while_loop( + lambda step, *_: step > 0, + body, + init_state, + maximum_iterations=num_queries * num_queries, + ) + final_starred_mask = final_state[2] + row_ind = keras.ops.arange(num_queries, dtype="int32") + col_ind = keras.ops.argmax( + keras.ops.cast(final_starred_mask, "int32"), axis=1 + ) + valid_mask = keras.ops.logical_and( + row_ind < original_num_rows, col_ind < original_num_cols + ) + return row_ind, col_ind, valid_mask diff --git a/keras_hub/src/models/d_fine/d_fine_utils_test.py b/keras_hub/src/models/d_fine/d_fine_utils_test.py new file mode 100644 index 0000000000..e047dd350d --- /dev/null +++ b/keras_hub/src/models/d_fine/d_fine_utils_test.py @@ -0,0 +1,65 @@ +import keras +import numpy as np +from absl.testing import parameterized +from scipy.optimize import linear_sum_assignment + +from keras_hub.src.models.d_fine.d_fine_utils import hungarian_assignment +from keras_hub.src.tests.test_case import TestCase + + +class DFineUtilsTest(TestCase): + @parameterized.named_parameters( + ( + "square_matrix", + np.array([[4, 1, 3], [2, 0, 5], [3, 2, 2]], dtype="float32"), + ), + ( + "rectangular_more_rows", + np.array( + [[10, 20, 30], [40, 50, 5], [15, 25, 35], [5, 10, 15]], + dtype="float32", + ), + ), + ( + "rectangular_more_cols", + np.array( + [[10, 20, 30, 40], [50, 5, 15, 25], [35, 45, 55, 65]], + dtype="float32", + ), + ), + ( + "duplicate_min_costs", + np.array([[1, 1, 2], [2, 3, 1], [3, 1, 4]], dtype="float32"), + ), + ( + "larger_matrix", + np.array( + [ + [9, 2, 7, 8, 4], + [6, 4, 3, 7, 5], + [5, 8, 1, 8, 2], + [7, 6, 9, 4, 1], + [3, 5, 8, 5, 4], + ], + dtype="float32", + ), + ), + ) + def test_hungarian_assignment_equivalence(self, cost_matrix): + # Test if the Keras version is equivalent to SciPy's + # `optimize.linear_sum_assignment`. 
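+        # Tied costs can yield different (but equally optimal) assignments,
+        # so the check below compares total assignment cost rather than the
+        # exact index pairs.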
+ num_queries = max(cost_matrix.shape) + keras_row_ind, keras_col_ind, keras_valid_mask = hungarian_assignment( + keras.ops.convert_to_tensor(cost_matrix), + num_queries, + ) + scipy_row_ind, scipy_col_ind = linear_sum_assignment(cost_matrix) + scipy_cost = cost_matrix[scipy_row_ind, scipy_col_ind].sum() + valid_row_ind = keras.ops.convert_to_numpy( + keras_row_ind[keras_valid_mask] + ) + valid_col_ind = keras.ops.convert_to_numpy( + keras_col_ind[keras_valid_mask] + ) + keras_cost = cost_matrix[valid_row_ind, valid_col_ind].sum() + self.assertAllClose(keras_cost, scipy_cost) diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_backbone.py b/keras_hub/src/models/hgnetv2/hgnetv2_backbone.py index 12407b0f75..1eb7ff657b 100644 --- a/keras_hub/src/models/hgnetv2/hgnetv2_backbone.py +++ b/keras_hub/src/models/hgnetv2/hgnetv2_backbone.py @@ -157,7 +157,10 @@ def __init__( if stage_name in self.out_features } super().__init__( - inputs=pixel_values, outputs=feature_maps_output, **kwargs + inputs=pixel_values, + outputs=feature_maps_output, + dtype=dtype, + **kwargs, ) # === Config === diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_backbone_test.py b/keras_hub/src/models/hgnetv2/hgnetv2_backbone_test.py index 31c48d2c29..193709fb7e 100644 --- a/keras_hub/src/models/hgnetv2/hgnetv2_backbone_test.py +++ b/keras_hub/src/models/hgnetv2/hgnetv2_backbone_test.py @@ -111,8 +111,6 @@ def test_backbone_basics( init_kwargs=test_kwargs, input_data=self.input_data, expected_output_shape=expected_shapes, - run_mixed_precision_check=False, - run_data_format_check=False, ) @pytest.mark.large diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_encoder.py b/keras_hub/src/models/hgnetv2/hgnetv2_encoder.py index 9b4f2c98bf..50c2e63551 100644 --- a/keras_hub/src/models/hgnetv2/hgnetv2_encoder.py +++ b/keras_hub/src/models/hgnetv2/hgnetv2_encoder.py @@ -56,9 +56,10 @@ def __init__( use_learnable_affine_block, data_format=None, channel_axis=None, + dtype=None, **kwargs, ): - super().__init__(**kwargs) + super().__init__(dtype=dtype, **kwargs) self.stage_in_channels = stage_in_channels self.stage_mid_channels = stage_mid_channels self.stage_out_channels = stage_out_channels @@ -90,7 +91,7 @@ def __init__( name=f"{self.name}_stage_{stage_idx}" if self.name else f"stage_{stage_idx}", - dtype=self.dtype, + dtype=dtype, ) self.stages_list.append(stage_layer) diff --git a/keras_hub/src/models/hgnetv2/hgnetv2_layers.py b/keras_hub/src/models/hgnetv2/hgnetv2_layers.py index 4424e45283..0d571ad634 100644 --- a/keras_hub/src/models/hgnetv2/hgnetv2_layers.py +++ b/keras_hub/src/models/hgnetv2/hgnetv2_layers.py @@ -17,8 +17,8 @@ class HGNetV2LearnableAffineBlock(keras.layers.Layer): **kwargs: Additional keyword arguments passed to the parent class. 
""" - def __init__(self, scale_value=1.0, bias_value=0.0, **kwargs): - super().__init__(**kwargs) + def __init__(self, scale_value=1.0, bias_value=0.0, dtype=None, **kwargs): + super().__init__(dtype=dtype, **kwargs) self.scale_value = scale_value self.bias_value = bias_value @@ -87,9 +87,10 @@ def __init__( use_learnable_affine_block=False, data_format=None, channel_axis=None, + dtype=None, **kwargs, ): - super().__init__(**kwargs) + super().__init__(dtype=dtype, **kwargs) self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = kernel_size @@ -104,6 +105,7 @@ def __init__( padding=((pad, pad), (pad, pad)), data_format=self.data_format, name=f"{self.name}_pad" if self.name else None, + dtype=self.dtype_policy, ) self.convolution = keras.layers.Conv2D( filters=self.out_channels, @@ -156,7 +158,8 @@ def __init__( ) else: self.lab = keras.layers.Identity( - name=f"{self.name}_identity_lab" if self.name else None + name=f"{self.name}_identity_lab" if self.name else None, + dtype=self.dtype_policy, ) def build(self, input_shape): @@ -230,9 +233,10 @@ def __init__( use_learnable_affine_block=False, data_format=None, channel_axis=None, + dtype=None, **kwargs, ): - super().__init__(**kwargs) + super().__init__(dtype=dtype, **kwargs) self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = kernel_size @@ -327,9 +331,10 @@ def __init__( use_learnable_affine_block, data_format=None, channel_axis=None, + dtype=None, **kwargs, ): - super().__init__(**kwargs) + super().__init__(dtype=dtype, **kwargs) self.stem_channels = stem_channels self.hidden_act = hidden_act self.use_learnable_affine_block = use_learnable_affine_block @@ -352,6 +357,7 @@ def __init__( padding=((0, 1), (0, 1)), data_format=self.data_format, name=f"{self.name}_padding1" if self.name else "padding1", + dtype=self.dtype_policy, ) self.stem2a_layer = HGNetV2ConvLayer( in_channels=self.stem_channels[1], @@ -370,6 +376,7 @@ def __init__( padding=((0, 1), (0, 1)), data_format=self.data_format, name=f"{self.name}_padding2" if self.name else "padding2", + dtype=self.dtype_policy, ) self.stem2b_layer = HGNetV2ConvLayer( in_channels=self.stem_channels[1] // 2, @@ -390,10 +397,12 @@ def __init__( padding="valid", data_format=self.data_format, name=f"{self.name}_pool" if self.name else "pool", + dtype=self.dtype_policy, ) self.concatenate_layer = keras.layers.Concatenate( axis=self.channel_axis, name=f"{self.name}_concat" if self.name else "concat", + dtype=self.dtype_policy, ) self.stem3_layer = HGNetV2ConvLayer( in_channels=self.stem_channels[1] * 2, @@ -550,9 +559,10 @@ def __init__( use_learnable_affine_block=False, data_format=None, channel_axis=None, + dtype=None, **kwargs, ): - super().__init__(**kwargs) + super().__init__(dtype=dtype, **kwargs) self.in_channels_arg = in_channels self.middle_channels = middle_channels self.out_channels = out_channels @@ -635,23 +645,27 @@ def __init__( self.drop_path_rate, noise_shape=(None, 1, 1, 1), name=f"{self.name}_drop_path" if self.name else "drop_path", + dtype=self.dtype_policy, ) else: self.drop_path_layer = keras.layers.Identity( name=f"{self.name}_identity_drop_path" if self.name - else "identity_drop_path" + else "identity_drop_path", + dtype=self.dtype_policy, ) self.concatenate_layer = keras.layers.Concatenate( axis=self.channel_axis, name=f"{self.name}_concat" if self.name else "concat", + dtype=self.dtype_policy, ) if self.residual: self.add_layer = keras.layers.Add( name=f"{self.name}_add_residual" if self.name - else "add_residual" + 
else "add_residual", + dtype=self.dtype_policy, ) def build(self, input_shape): @@ -794,9 +808,10 @@ def __init__( drop_path: float = 0.0, data_format=None, channel_axis=None, + dtype=None, **kwargs, ): - super().__init__(**kwargs) + super().__init__(dtype=dtype, **kwargs) self.stage_in_channels = stage_in_channels self.stage_mid_channels = stage_mid_channels self.stage_out_channels = stage_out_channels @@ -842,7 +857,8 @@ def __init__( self.downsample_layer = keras.layers.Identity( name=f"{self.name}_identity_downsample" if self.name - else "identity_downsample" + else "identity_downsample", + dtype=self.dtype_policy, ) self.blocks_list = [] diff --git a/keras_hub/src/tests/test_case.py b/keras_hub/src/tests/test_case.py index 43eb8050c3..f70ab78840 100644 --- a/keras_hub/src/tests/test_case.py +++ b/keras_hub/src/tests/test_case.py @@ -499,6 +499,7 @@ def run_vision_backbone_test( init_kwargs, input_data, expected_output_shape, + spatial_output_keys=None, expected_pyramid_output_keys=None, expected_pyramid_image_sizes=None, variable_length_data=None, @@ -557,12 +558,47 @@ def run_vision_backbone_test( input_data = ops.transpose(input_data, axes=(2, 0, 1)) elif len(input_data_shape) == 4: input_data = ops.transpose(input_data, axes=(0, 3, 1, 2)) - if len(expected_output_shape) == 3: + if isinstance(expected_output_shape, dict): + # Handle dictionary of shapes. + transposed_shapes = {} + for key, shape in expected_output_shape.items(): + if spatial_output_keys and key not in spatial_output_keys: + transposed_shapes[key] = shape + continue + if len(shape) == 3: + transposed_shapes[key] = (shape[0], shape[2], shape[1]) + elif len(shape) == 4: + transposed_shapes[key] = ( + shape[0], + shape[3], + shape[1], + shape[2], + ) + else: + transposed_shapes[key] = shape + expected_output_shape = transposed_shapes + elif len(expected_output_shape) == 3: x = expected_output_shape expected_output_shape = (x[0], x[2], x[1]) elif len(expected_output_shape) == 4: x = expected_output_shape expected_output_shape = (x[0], x[3], x[1], x[2]) + original_init_kwargs = init_kwargs.copy() + init_kwargs = original_init_kwargs.copy() + # Handle nested `keras.Model` instances passed within `init_kwargs`. + for k, v in init_kwargs.items(): + if isinstance(v, keras.Model) and hasattr(v, "data_format"): + config = v.get_config() + config["data_format"] = "channels_first" + if ( + "image_shape" in config + and config["image_shape"] is not None + and len(config["image_shape"]) == 3 + ): + config["image_shape"] = tuple( + reversed(config["image_shape"]) + ) + init_kwargs[k] = v.__class__.from_config(config) if "image_shape" in init_kwargs: init_kwargs = init_kwargs.copy() init_kwargs["image_shape"] = tuple( diff --git a/keras_hub/src/utils/transformers/convert_t5gemma_test.py b/keras_hub/src/utils/transformers/convert_t5gemma_test.py index 939984eba5..36f40bf37d 100644 --- a/keras_hub/src/utils/transformers/convert_t5gemma_test.py +++ b/keras_hub/src/utils/transformers/convert_t5gemma_test.py @@ -1,3 +1,4 @@ +import keras import pytest from keras_hub.src.models.backbone import Backbone @@ -7,6 +8,13 @@ from keras_hub.src.tests.test_case import TestCase +# NOTE: This test is valid and should pass locally. It is skipped only on +# TensorFlow GPU CI because of ResourceExhaustedError (OOM). Revisit once +# TensorFlow GPU CI runs without hitting OOM. 
+@pytest.mark.skipif( + keras.backend.backend() == "tensorflow", + reason="TensorFlow GPU CI OOM (ResourceExhaustedError)", +) class TestTask(TestCase): @pytest.mark.large def test_convert_tiny_preset(self): diff --git a/tools/checkpoint_conversion/convert_d_fine_checkpoints.py b/tools/checkpoint_conversion/convert_d_fine_checkpoints.py new file mode 100644 index 0000000000..cb090df545 --- /dev/null +++ b/tools/checkpoint_conversion/convert_d_fine_checkpoints.py @@ -0,0 +1,728 @@ +import json +import os +import shutil + +import keras +import numpy as np +import torch +from absl import app +from absl import flags +from huggingface_hub import hf_hub_download +from PIL import Image +from safetensors.torch import load_file +from transformers import AutoImageProcessor +from transformers import DFineForObjectDetection + +import keras_hub +from keras_hub.src.models.d_fine.d_fine_backbone import DFineBackbone +from keras_hub.src.models.d_fine.d_fine_image_converter import ( + DFineImageConverter, +) +from keras_hub.src.models.d_fine.d_fine_layers import DFineConvNormLayer +from keras_hub.src.models.d_fine.d_fine_object_detector import ( + DFineObjectDetector, +) +from keras_hub.src.models.d_fine.d_fine_object_detector_preprocessor import ( + DFineObjectDetectorPreprocessor, +) +from keras_hub.src.models.hgnetv2.hgnetv2_backbone import HGNetV2Backbone +from keras_hub.src.models.hgnetv2.hgnetv2_layers import HGNetV2ConvLayer +from keras_hub.src.models.hgnetv2.hgnetv2_layers import ( + HGNetV2LearnableAffineBlock, +) + +FLAGS = flags.FLAGS + +flags.DEFINE_string( + "preset", + None, + "Must be one of 'dfine_large_coco', 'dfine_xlarge_coco', " + "'dfine_small_coco', 'dfine_nano_coco', 'dfine_medium_coco', " + "'dfine_small_obj365', 'dfine_medium_obj365', 'dfine_large_obj365', " + "'dfine_xlarge_obj365', 'dfine_small_obj2coco', 'dfine_medium_obj2coco', " + "'dfine_large_obj2coco_e25', or 'dfine_xlarge_obj2coco'", + required=True, +) +flags.DEFINE_string( + "upload_uri", + None, + 'Optional upload URI, e.g., "kaggle://keras/dfine/keras/dfine_xlarge_coco"', + required=False, +) + +PRESET_MAP = { + "dfine_large_coco": "ustc-community/dfine-large-coco", + "dfine_xlarge_coco": "ustc-community/dfine-xlarge-coco", + "dfine_small_coco": "ustc-community/dfine-small-coco", + "dfine_medium_coco": "ustc-community/dfine-medium-coco", + "dfine_nano_coco": "ustc-community/dfine-nano-coco", + "dfine_small_obj365": "ustc-community/dfine-small-obj365", + "dfine_medium_obj365": "ustc-community/dfine-medium-obj365", + "dfine_large_obj365": "ustc-community/dfine-large-obj365", + "dfine_xlarge_obj365": "ustc-community/dfine-xlarge-obj365", + "dfine_small_obj2coco": "ustc-community/dfine-small-obj2coco", + "dfine_medium_obj2coco": "ustc-community/dfine-medium-obj2coco", + "dfine_large_obj2coco_e25": "ustc-community/dfine-large-obj2coco-e25", + "dfine_xlarge_obj2coco": "ustc-community/dfine-xlarge-obj2coco", +} + + +def load_pytorch_model(hf_preset): + model_path = hf_hub_download( + repo_id=hf_preset, + filename="model.safetensors", + cache_dir="./hf_models", + ) + state_dict = load_file(model_path) + return state_dict + + +def get_keras_model(config, hf_preset): + backbone_config = config["backbone_config"] + stackwise_stage_filters = [ + [ + backbone_config["stage_in_channels"][i], + backbone_config["stage_mid_channels"][i], + backbone_config["stage_out_channels"][i], + backbone_config["stage_num_blocks"][i], + backbone_config["stage_numb_of_layers"][i], + backbone_config["stage_kernel_size"][i], + ] + for i in 
range(len(backbone_config["stage_in_channels"])) + ] + hgnetv2_params = { + "depths": backbone_config["depths"], + "embedding_size": backbone_config["embedding_size"], + "hidden_sizes": backbone_config["hidden_sizes"], + "stem_channels": backbone_config["stem_channels"], + "hidden_act": backbone_config["hidden_act"], + "use_learnable_affine_block": backbone_config[ + "use_learnable_affine_block" + ], + "stackwise_stage_filters": stackwise_stage_filters, + "apply_downsample": backbone_config["stage_downsample"], + "use_lightweight_conv_block": backbone_config["stage_light_block"], + "out_features": backbone_config["out_features"], + } + hgnetv2_backbone = HGNetV2Backbone(**hgnetv2_params) + dfine_params = { + "decoder_in_channels": config["decoder_in_channels"], + "encoder_hidden_dim": config["encoder_hidden_dim"], + "num_denoising": config["num_denoising"], + "num_labels": len(config["id2label"]), + "learn_initial_query": config["learn_initial_query"], + "num_queries": config["num_queries"], + "anchor_image_size": (640, 640), + "feat_strides": config["feat_strides"], + "num_feature_levels": config["num_feature_levels"], + "hidden_dim": config["d_model"], + "encoder_in_channels": config["encoder_in_channels"], + "encode_proj_layers": config["encode_proj_layers"], + "num_attention_heads": config["encoder_attention_heads"], + "encoder_ffn_dim": config["encoder_ffn_dim"], + "num_encoder_layers": config["encoder_layers"], + "hidden_expansion": config["hidden_expansion"], + "depth_multiplier": config["depth_mult"], + "eval_idx": config["eval_idx"], + "label_noise_ratio": config.get("label_noise_ratio", 0.5), + "box_noise_scale": config.get("box_noise_scale", 1.0), + "num_decoder_layers": config["decoder_layers"], + "decoder_attention_heads": config["decoder_attention_heads"], + "decoder_ffn_dim": config["decoder_ffn_dim"], + "decoder_n_points": config["decoder_n_points"], + "lqe_hidden_dim": config["lqe_hidden_dim"], + "num_lqe_layers": config["lqe_layers"], + "image_shape": (None, None, 3), + "out_features": backbone_config["out_features"], + "seed": 0, + } + backbone = DFineBackbone(backbone=hgnetv2_backbone, **dfine_params) + config_path = keras.utils.get_file( + origin=f"https://huggingface.co/{hf_preset}/raw/main/preprocessor_config.json", # noqa: E501 + cache_subdir=f"hf_models/{hf_preset}", + ) + with open(config_path, "r") as f: + preprocessor_config = json.load(f) + scale = None + offset = None + if preprocessor_config.get("do_rescale", False): + scale = preprocessor_config.get("rescale_factor") + if preprocessor_config.get("do_normalize", False): + mean = preprocessor_config["image_mean"] + std = preprocessor_config["image_std"] + if isinstance(scale, (float, int)): + scale = [scale / s for s in std] + else: + scale = [1.0 / s for s in std] + offset = [-m / s for m, s in zip(mean, std)] + image_converter = DFineImageConverter( + image_size=(640, 640), + scale=scale, + offset=offset, + crop_to_aspect_ratio=True, + ) + preprocessor = DFineObjectDetectorPreprocessor( + image_converter=image_converter, + ) + model = DFineObjectDetector( + backbone=backbone, + num_classes=len(config["id2label"]), + bounding_box_format="yxyx", + preprocessor=preprocessor, + matcher_class_cost=config["matcher_class_cost"], + matcher_bbox_cost=config["matcher_bbox_cost"], + matcher_giou_cost=config["matcher_giou_cost"], + use_focal_loss=config["use_focal_loss"], + matcher_alpha=config["matcher_alpha"], + matcher_gamma=config["matcher_gamma"], + weight_loss_vfl=config["weight_loss_vfl"], + 
weight_loss_bbox=config["weight_loss_bbox"], + weight_loss_ciou=config["weight_loss_giou"], + weight_loss_fgl=config.get("weight_loss_fgl", 0.15), + weight_loss_ddf=config.get("weight_loss_ddf", 1.5), + ) + return model + + +def set_conv_norm_weights(state_dict, prefix, k_conv): + if isinstance(k_conv, HGNetV2ConvLayer): + pt_conv_suffix = "convolution" + pt_norm_suffix = "normalization" + lab_suffix = "lab" + elif isinstance(k_conv, DFineConvNormLayer): + pt_conv_suffix = "conv" + pt_norm_suffix = "norm" + lab_suffix = None + else: + raise TypeError(f"Unsupported Keras ConvNormLayer type: {type(k_conv)}") + conv_weight_key = f"{prefix}.{pt_conv_suffix}.weight" + if conv_weight_key in state_dict: + k_conv.convolution.kernel.assign( + state_dict[conv_weight_key].permute(2, 3, 1, 0).numpy() + ) + norm_weight_key = f"{prefix}.{pt_norm_suffix}.weight" + norm_bias_key = f"{prefix}.{pt_norm_suffix}.bias" + norm_mean_key = f"{prefix}.{pt_norm_suffix}.running_mean" + norm_var_key = f"{prefix}.{pt_norm_suffix}.running_var" + if all( + key in state_dict + for key in [norm_weight_key, norm_bias_key, norm_mean_key, norm_var_key] + ): + k_conv.normalization.set_weights( + [ + state_dict[norm_weight_key].numpy(), + state_dict[norm_bias_key].numpy(), + state_dict[norm_mean_key].numpy(), + state_dict[norm_var_key].numpy(), + ] + ) + if isinstance(k_conv, HGNetV2ConvLayer) and isinstance( + k_conv.lab, HGNetV2LearnableAffineBlock + ): + lab_scale_key = f"{prefix}.{lab_suffix}.scale" + lab_bias_key = f"{prefix}.{lab_suffix}.bias" + if lab_scale_key in state_dict and lab_bias_key in state_dict: + k_conv.lab.scale.assign(state_dict[lab_scale_key].item()) + k_conv.lab.bias.assign(state_dict[lab_bias_key].item()) + + +def transfer_hgnet_backbone_weights(state_dict, k_backbone): + backbone_prefix = "model.backbone.model." + embedder_prefix = f"{backbone_prefix}embedder." + for stem in ["stem1", "stem2a", "stem2b", "stem3", "stem4"]: + k_conv = getattr(k_backbone.backbone.embedder_layer, f"{stem}_layer") + set_conv_norm_weights(state_dict, f"{embedder_prefix}{stem}", k_conv) + + stages_prefix = f"{backbone_prefix}encoder.stages." + for stage_idx, stage in enumerate( + k_backbone.backbone.encoder_layer.stages_list + ): + prefix = f"{stages_prefix}{stage_idx}." + if hasattr(stage, "downsample_layer") and not isinstance( + stage.downsample_layer, keras.layers.Identity + ): + set_conv_norm_weights( + state_dict, f"{prefix}downsample", stage.downsample_layer + ) + for block_idx, block in enumerate(stage.blocks_list): + block_prefix = f"{prefix}blocks.{block_idx}." 
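+            # Each HF checkpoint key maps onto one Keras sub-layer here;
+            # `set_conv_norm_weights` transposes conv kernels from PyTorch's
+            # (out, in, kH, kW) layout to Keras's (kH, kW, in, out) via
+            # `permute(2, 3, 1, 0)`.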
+ for layer_idx, layer in enumerate(block.layer_list): + if hasattr(layer, "conv1_layer"): + set_conv_norm_weights( + state_dict, + f"{block_prefix}layers.{layer_idx}.conv1", + layer.conv1_layer, + ) + set_conv_norm_weights( + state_dict, + f"{block_prefix}layers.{layer_idx}.conv2", + layer.conv2_layer, + ) + else: + set_conv_norm_weights( + state_dict, f"{block_prefix}layers.{layer_idx}", layer + ) + set_conv_norm_weights( + state_dict, + f"{block_prefix}aggregation.0", + block.aggregation_squeeze_conv, + ) + set_conv_norm_weights( + state_dict, + f"{block_prefix}aggregation.1", + block.aggregation_excitation_conv, + ) + + +def transfer_hybrid_encoder_weights(state_dict, k_encoder): + for i, lateral_conv in enumerate(k_encoder.lateral_convs): + set_conv_norm_weights( + state_dict, f"model.encoder.lateral_convs.{i}", lateral_conv + ) + + for i, fpn_block in enumerate(k_encoder.fpn_blocks): + prefix = f"model.encoder.fpn_blocks.{i}" + set_conv_norm_weights(state_dict, f"{prefix}.conv1", fpn_block.conv1) + set_conv_norm_weights(state_dict, f"{prefix}.conv2", fpn_block.conv2) + set_conv_norm_weights(state_dict, f"{prefix}.conv3", fpn_block.conv3) + set_conv_norm_weights(state_dict, f"{prefix}.conv4", fpn_block.conv4) + for j, bottleneck in enumerate(fpn_block.csp_rep1.bottleneck_layers): + set_conv_norm_weights( + state_dict, + f"{prefix}.csp_rep1.bottlenecks.{j}.conv1", + bottleneck.conv1_layer, + ) + set_conv_norm_weights( + state_dict, + f"{prefix}.csp_rep1.bottlenecks.{j}.conv2", + bottleneck.conv2_layer, + ) + set_conv_norm_weights( + state_dict, f"{prefix}.csp_rep1.conv1", fpn_block.csp_rep1.conv1 + ) + set_conv_norm_weights( + state_dict, f"{prefix}.csp_rep1.conv2", fpn_block.csp_rep1.conv2 + ) + for j, bottleneck in enumerate(fpn_block.csp_rep2.bottleneck_layers): + set_conv_norm_weights( + state_dict, + f"{prefix}.csp_rep2.bottlenecks.{j}.conv1", + bottleneck.conv1_layer, + ) + set_conv_norm_weights( + state_dict, + f"{prefix}.csp_rep2.bottlenecks.{j}.conv2", + bottleneck.conv2_layer, + ) + set_conv_norm_weights( + state_dict, f"{prefix}.csp_rep2.conv1", fpn_block.csp_rep2.conv1 + ) + set_conv_norm_weights( + state_dict, f"{prefix}.csp_rep2.conv2", fpn_block.csp_rep2.conv2 + ) + + for i, down_conv in enumerate(k_encoder.downsample_convs): + prefix = f"model.encoder.downsample_convs.{i}" + set_conv_norm_weights(state_dict, f"{prefix}.conv1", down_conv.conv1) + set_conv_norm_weights(state_dict, f"{prefix}.conv2", down_conv.conv2) + + for i, pan_block in enumerate(k_encoder.pan_blocks): + prefix = f"model.encoder.pan_blocks.{i}" + set_conv_norm_weights(state_dict, f"{prefix}.conv1", pan_block.conv1) + set_conv_norm_weights(state_dict, f"{prefix}.conv2", pan_block.conv2) + set_conv_norm_weights(state_dict, f"{prefix}.conv3", pan_block.conv3) + set_conv_norm_weights(state_dict, f"{prefix}.conv4", pan_block.conv4) + for j, bottleneck in enumerate(pan_block.csp_rep1.bottleneck_layers): + set_conv_norm_weights( + state_dict, + f"{prefix}.csp_rep1.bottlenecks.{j}.conv1", + bottleneck.conv1_layer, + ) + set_conv_norm_weights( + state_dict, + f"{prefix}.csp_rep1.bottlenecks.{j}.conv2", + bottleneck.conv2_layer, + ) + set_conv_norm_weights( + state_dict, f"{prefix}.csp_rep1.conv1", pan_block.csp_rep1.conv1 + ) + set_conv_norm_weights( + state_dict, f"{prefix}.csp_rep1.conv2", pan_block.csp_rep1.conv2 + ) + for j, bottleneck in enumerate(pan_block.csp_rep2.bottleneck_layers): + set_conv_norm_weights( + state_dict, + f"{prefix}.csp_rep2.bottlenecks.{j}.conv1", + bottleneck.conv1_layer, + ) + 
set_conv_norm_weights( + state_dict, + f"{prefix}.csp_rep2.bottlenecks.{j}.conv2", + bottleneck.conv2_layer, + ) + set_conv_norm_weights( + state_dict, f"{prefix}.csp_rep2.conv1", pan_block.csp_rep2.conv1 + ) + set_conv_norm_weights( + state_dict, f"{prefix}.csp_rep2.conv2", pan_block.csp_rep2.conv2 + ) + + +def transfer_transformer_encoder_weights(state_dict, k_encoder): + for i, layer in enumerate(k_encoder.encoder[0].encoder_layer): + prefix = f"model.encoder.encoder.0.layers.{i}" + for proj in ["q", "k", "v"]: + pt_weight = state_dict[ + f"{prefix}.self_attn.{proj}_proj.weight" + ].T.numpy() + head_dim = ( + k_encoder.encoder_hidden_dim // k_encoder.num_attention_heads + ) + k_weight = pt_weight.reshape( + k_encoder.encoder_hidden_dim, + k_encoder.num_attention_heads, + head_dim, + ) + k_proj = getattr(layer.self_attn, f"{proj}_proj") + k_proj.weights[0].assign(k_weight) + k_proj.weights[1].assign( + state_dict[f"{prefix}.self_attn.{proj}_proj.bias"] + .numpy() + .reshape(k_encoder.num_attention_heads, head_dim) + ) + layer.self_attn.out_proj.weights[0].assign( + state_dict[f"{prefix}.self_attn.out_proj.weight"].T.numpy() + ) + layer.self_attn.out_proj.weights[1].assign( + state_dict[f"{prefix}.self_attn.out_proj.bias"].numpy() + ) + layer.self_attn_layer_norm.set_weights( + [ + state_dict[f"{prefix}.self_attn_layer_norm.weight"].numpy(), + state_dict[f"{prefix}.self_attn_layer_norm.bias"].numpy(), + ] + ) + layer.fc1.weights[0].assign( + state_dict[f"{prefix}.fc1.weight"].T.numpy() + ) + layer.fc1.weights[1].assign(state_dict[f"{prefix}.fc1.bias"].numpy()) + layer.fc2.weights[0].assign( + state_dict[f"{prefix}.fc2.weight"].T.numpy() + ) + layer.fc2.weights[1].assign(state_dict[f"{prefix}.fc2.bias"].numpy()) + layer.final_layer_norm.set_weights( + [ + state_dict[f"{prefix}.final_layer_norm.weight"].numpy(), + state_dict[f"{prefix}.final_layer_norm.bias"].numpy(), + ] + ) + + +def transfer_decoder_weights(state_dict, k_decoder): + for i, layer in enumerate(k_decoder.decoder_layers): + prefix = f"model.decoder.layers.{i}" + for proj in ["q", "k", "v"]: + pt_weight = state_dict[ + f"{prefix}.self_attn.{proj}_proj.weight" + ].T.numpy() + head_dim = k_decoder.hidden_dim // k_decoder.decoder_attention_heads + k_weight = pt_weight.reshape( + k_decoder.hidden_dim, + k_decoder.decoder_attention_heads, + head_dim, + ) + k_proj = getattr(layer.self_attn, f"{proj}_proj") + k_proj.weights[0].assign(k_weight) + k_proj.weights[1].assign( + state_dict[f"{prefix}.self_attn.{proj}_proj.bias"] + .numpy() + .reshape(k_decoder.decoder_attention_heads, head_dim) + ) + layer.self_attn.out_proj.weights[0].assign( + state_dict[f"{prefix}.self_attn.out_proj.weight"].T.numpy() + ) + layer.self_attn.out_proj.weights[1].assign( + state_dict[f"{prefix}.self_attn.out_proj.bias"].numpy() + ) + layer.self_attn_layer_norm.set_weights( + [ + state_dict[f"{prefix}.self_attn_layer_norm.weight"].numpy(), + state_dict[f"{prefix}.self_attn_layer_norm.bias"].numpy(), + ] + ) + pytorch_offset_weight = state_dict[ + f"{prefix}.encoder_attn.sampling_offsets.weight" + ].T + keras_offset_kernel = layer.encoder_attn.sampling_offsets.kernel + keras_offset_kernel.assign( + pytorch_offset_weight.numpy().reshape(keras_offset_kernel.shape) + ) + pytorch_offset_bias = state_dict[ + f"{prefix}.encoder_attn.sampling_offsets.bias" + ] + keras_offset_bias = layer.encoder_attn.sampling_offsets.bias + keras_offset_bias.assign( + pytorch_offset_bias.numpy().reshape(keras_offset_bias.shape) + ) + pytorch_attn_weight = state_dict[ + 
f"{prefix}.encoder_attn.attention_weights.weight" + ].T + keras_attn_kernel = layer.encoder_attn.attention_weights.kernel + keras_attn_kernel.assign( + pytorch_attn_weight.numpy().reshape(keras_attn_kernel.shape) + ) + pytorch_attn_bias = state_dict[ + f"{prefix}.encoder_attn.attention_weights.bias" + ] + keras_attn_bias = layer.encoder_attn.attention_weights.bias + keras_attn_bias.assign( + pytorch_attn_bias.numpy().reshape(keras_attn_bias.shape) + ) + num_points_scale_key = f"{prefix}.encoder_attn.num_points_scale" + if num_points_scale_key in state_dict: + layer.encoder_attn.num_points_scale.assign( + state_dict[num_points_scale_key].numpy() + ) + layer.fc1.weights[0].assign( + state_dict[f"{prefix}.fc1.weight"].T.numpy() + ) + layer.fc1.weights[1].assign(state_dict[f"{prefix}.fc1.bias"].numpy()) + layer.fc2.weights[0].assign( + state_dict[f"{prefix}.fc2.weight"].T.numpy() + ) + layer.fc2.weights[1].assign(state_dict[f"{prefix}.fc2.bias"].numpy()) + layer.final_layer_norm.set_weights( + [ + state_dict[f"{prefix}.final_layer_norm.weight"].numpy(), + state_dict[f"{prefix}.final_layer_norm.bias"].numpy(), + ] + ) + layer.gateway.gate.weights[0].assign( + state_dict[f"{prefix}.gateway.gate.weight"].T.numpy() + ) + layer.gateway.gate.weights[1].assign( + state_dict[f"{prefix}.gateway.gate.bias"].numpy() + ) + layer.gateway.norm.set_weights( + [ + state_dict[f"{prefix}.gateway.norm.weight"].numpy(), + state_dict[f"{prefix}.gateway.norm.bias"].numpy(), + ] + ) + + for i, layer in enumerate(k_decoder.lqe_layers): + prefix = f"model.decoder.lqe_layers.{i}.reg_conf.layers" + for j, dense in enumerate(layer.reg_conf.dense_layers): + dense.weights[0].assign( + state_dict[f"{prefix}.{j}.weight"].T.numpy() + ) + dense.weights[1].assign(state_dict[f"{prefix}.{j}.bias"].numpy()) + + for i, dense in enumerate(k_decoder.pre_bbox_head.dense_layers): + prefix = f"model.decoder.pre_bbox_head.layers.{i}" + dense.weights[0].assign(state_dict[f"{prefix}.weight"].T.numpy()) + dense.weights[1].assign(state_dict[f"{prefix}.bias"].numpy()) + + for i, dense in enumerate(k_decoder.query_pos_head.dense_layers): + prefix = f"model.decoder.query_pos_head.layers.{i}" + dense.weights[0].assign(state_dict[f"{prefix}.weight"].T.numpy()) + dense.weights[1].assign(state_dict[f"{prefix}.bias"].numpy()) + + k_decoder.reg_scale.assign(state_dict["model.decoder.reg_scale"].numpy()) + k_decoder.upsampling_factor.assign(state_dict["model.decoder.up"].numpy()) + + +def transfer_prediction_heads(state_dict, k_decoder): + for i, class_embed in enumerate(k_decoder.class_embed): + prefix = f"model.decoder.class_embed.{i}" + class_embed.weights[0].assign(state_dict[f"{prefix}.weight"].T.numpy()) + class_embed.weights[1].assign(state_dict[f"{prefix}.bias"].numpy()) + for i, bbox_embed in enumerate(k_decoder.bbox_embed): + prefix = f"model.decoder.bbox_embed.{i}.layers" + for j, layer in enumerate(bbox_embed.dense_layers): + layer.weights[0].assign( + state_dict[f"{prefix}.{j}.weight"].T.numpy() + ) + layer.weights[1].assign(state_dict[f"{prefix}.{j}.bias"].numpy()) + + +def transfer_dfine_model_weights(state_dict, k_model): + backbone = k_model.backbone + transfer_hgnet_backbone_weights(state_dict, backbone) + + for i, proj_layers in enumerate(backbone.encoder_input_proj_layers): + prefix = f"model.encoder_input_proj.{i}" + conv_weight_key = f"{prefix}.0.weight" + if conv_weight_key in state_dict: + proj_layers[0].weights[0].assign( + state_dict[conv_weight_key].permute(2, 3, 1, 0).numpy() + ) + proj_layers[1].set_weights( + [ + 
state_dict[f"{prefix}.1.weight"].numpy(), + state_dict[f"{prefix}.1.bias"].numpy(), + state_dict[f"{prefix}.1.running_mean"].numpy(), + state_dict[f"{prefix}.1.running_var"].numpy(), + ] + ) + + transfer_hybrid_encoder_weights(state_dict, backbone.encoder) + transfer_transformer_encoder_weights(state_dict, backbone.encoder) + if backbone.denoising_class_embed is not None: + backbone.denoising_class_embed.weights[0].assign( + state_dict["model.denoising_class_embed.weight"].numpy() + ) + + backbone.enc_output_layers[0].weights[0].assign( + state_dict["model.enc_output.0.weight"].T.numpy() + ) + backbone.enc_output_layers[0].weights[1].assign( + state_dict["model.enc_output.0.bias"].numpy() + ) + backbone.enc_output_layers[1].set_weights( + [ + state_dict["model.enc_output.1.weight"].numpy(), + state_dict["model.enc_output.1.bias"].numpy(), + ] + ) + + backbone.enc_score_head.weights[0].assign( + state_dict["model.enc_score_head.weight"].T.numpy() + ) + backbone.enc_score_head.weights[1].assign( + state_dict["model.enc_score_head.bias"].numpy() + ) + + for i, dense in enumerate(backbone.enc_bbox_head.dense_layers): + prefix = f"model.enc_bbox_head.layers.{i}" + dense.weights[0].assign(state_dict[f"{prefix}.weight"].T.numpy()) + dense.weights[1].assign(state_dict[f"{prefix}.bias"].numpy()) + + for i, proj_layers in enumerate(backbone.decoder_input_proj_layers): + prefix = f"model.decoder_input_proj.{i}" + if isinstance(proj_layers, keras.layers.Identity): + continue + conv_weight_key = f"{prefix}.0.weight" + if conv_weight_key in state_dict: + proj_layers[0].weights[0].assign( + state_dict[conv_weight_key].permute(2, 3, 1, 0).numpy() + ) + proj_layers[1].set_weights( + [ + state_dict[f"{prefix}.1.weight"].numpy(), + state_dict[f"{prefix}.1.bias"].numpy(), + state_dict[f"{prefix}.1.running_mean"].numpy(), + state_dict[f"{prefix}.1.running_var"].numpy(), + ] + ) + + transfer_decoder_weights(state_dict, backbone.decoder) + transfer_prediction_heads(state_dict, backbone.decoder) + + +def validate_conversion(keras_model, hf_preset): + pt_model = DFineForObjectDetection.from_pretrained(hf_preset) + image_processor = AutoImageProcessor.from_pretrained(hf_preset) + pt_model.eval() + raw_image = np.random.uniform(0, 255, (640, 640, 3)).astype(np.uint8) + pil_image = Image.fromarray(raw_image) + inputs = image_processor(images=pil_image, return_tensors="pt") + with torch.no_grad(): + pt_outputs = pt_model(**inputs) + keras_input = np.expand_dims(raw_image, axis=0).astype(np.float32) + keras_preprocessed_input = keras_model.preprocessor(keras_input) + keras_outputs = keras_model(keras_preprocessed_input, training=False) + + def to_numpy(tensor): + if keras.backend.backend() == "torch": + return tensor.detach().cpu().numpy() + elif keras.backend.backend() == "jax": + return np.array(tensor) + elif keras.backend.backend() == "tensorflow": + return tensor.numpy() + else: + return np.array(tensor) + + pt_pred_boxes = pt_outputs["pred_boxes"].detach().cpu().numpy() + print("\n=== Output Comparison ===") + pt_logits = pt_outputs["logits"].detach().cpu().numpy() + k_logits = to_numpy(keras_outputs["logits"]) + k_pred_boxes = to_numpy(keras_outputs["pred_boxes"]) + boxes_diff = np.mean(np.abs(pt_pred_boxes - k_pred_boxes)) + if boxes_diff < 1e-5: + print(f"šŸ”¶ Predicted Bounding Boxes Difference: {boxes_diff:.6e}") + print("āœ… Validation successful") + print(f"PyTorch Logits Shape: {pt_logits.shape}, dtype: {pt_logits.dtype}") + print(f"Keras Logits Shape: {k_logits.shape}, dtype: {k_logits.dtype}") + 
print("\n=== Logits Statistics ===") + print( + f"PyTorch Logits Min: {np.min(pt_logits):.6f}, max: " + f"{np.max(pt_logits):.6f}, mean: {np.mean(pt_logits):.6f}, std: " + f"{np.std(pt_logits):.6f}" + ) + print( + f"Keras Logits Min: {np.min(k_logits):.6f}, max: {np.max(k_logits):.6f}" + f", mean: {np.mean(k_logits):.6f}, std: {np.std(k_logits):.6f}" + ) + print("\n=== Pred Boxes Statistics ===") + print( + f"PyTorch Pred Boxes Min: {np.min(pt_pred_boxes):.6f}, max: " + f"{np.max(pt_pred_boxes):.6f}, mean: {np.mean(pt_pred_boxes):.6f}, " + f"std: {np.std(pt_pred_boxes):.6f}" + ) + print( + f"Keras Pred Boxes Min: {np.min(k_pred_boxes):.6f}, max: " + f"{np.max(k_pred_boxes):.6f}, mean: {np.mean(k_pred_boxes):.6f}, std: " + f"{np.std(k_pred_boxes):.6f}" + ) + print(f"NaN in Keras Logits: {np.any(np.isnan(k_logits))}") + print(f"NaN in Keras Boxes: {np.any(np.isnan(k_pred_boxes))}") + + +def main(_): + keras.utils.set_random_seed(0) + torch.manual_seed(0) + torch.cuda.manual_seed_all(0) + if FLAGS.preset not in PRESET_MAP: + raise ValueError( + f"Invalid preset {FLAGS.preset}. Must be one of " + f"{list(PRESET_MAP.keys())}" + ) + hf_preset = PRESET_MAP[FLAGS.preset] + output_dir = FLAGS.preset + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + os.makedirs(output_dir) + print(f"\nāœ… Converting {FLAGS.preset}") + + state_dict = load_pytorch_model(hf_preset) + print("āœ… PyTorch state dict loaded") + + config_path = hf_hub_download( + repo_id=hf_preset, + filename="config.json", + cache_dir="./hf_models", + ) + with open(config_path, "r") as f: + config = json.load(f) + + keras_model = get_keras_model(config, hf_preset) + dummy_input = np.zeros((1, 640, 640, 3), dtype="float32") + keras_model(dummy_input) + print("āœ… Keras model constructed") + + transfer_dfine_model_weights(state_dict, keras_model) + print("āœ… Weights transferred") + validate_conversion(keras_model, hf_preset) + print("āœ… Validation completed") + + keras_model.save_to_preset(output_dir) + print(f"šŸ Preset saved to {output_dir}") + + if FLAGS.upload_uri: + keras_hub.upload_preset(uri=FLAGS.upload_uri, preset=output_dir) + print(f"šŸ Preset uploaded to {FLAGS.upload_uri}") + + +if __name__ == "__main__": + app.run(main)