From fafcbb260f891b9a7961515fb002c71978eb7120 Mon Sep 17 00:00:00 2001 From: ado Date: Mon, 2 Jun 2025 09:17:21 -0400 Subject: [PATCH 1/7] Add backbone, layer and test of swin transform - image encoder --- .../swin_transformers_backbone.py | 143 ++++ .../swin_transformers_backbone_test.py | 51 ++ .../swin_transformers_layers.py | 704 ++++++++++++++++++ 3 files changed, 898 insertions(+) create mode 100644 keras_hub/src/models/swin_transformers/swin_transformers_backbone.py create mode 100644 keras_hub/src/models/swin_transformers/swin_transformers_backbone_test.py create mode 100644 keras_hub/src/models/swin_transformers/swin_transformers_layers.py diff --git a/keras_hub/src/models/swin_transformers/swin_transformers_backbone.py b/keras_hub/src/models/swin_transformers/swin_transformers_backbone.py new file mode 100644 index 0000000000..6bc9b80b97 --- /dev/null +++ b/keras_hub/src/models/swin_transformers/swin_transformers_backbone.py @@ -0,0 +1,143 @@ +import keras +from keras import layers +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.swin_transformers.swin_transformers_layers import ( + PatchEmbedding, + SwinTransformerStage, + PatchMerging +) + +@keras_hub_export("keras_hub.models.SwinTransformersBackbone") +class SwinTransformersBackbone(Backbone): + """Swin Transformer backbone. + + This backbone implements the Swin Transformer architecture as described in + [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030). + + The Swin Transformer is a hierarchical vision transformer that uses shifted + windows for self-attention computation. It has several advantages: + + 1. Hierarchical feature maps with downsampling like CNNs + 2. Linear computational complexity with respect to image size + 3. Support for various vision tasks, including image classification, + object detection, and semantic segmentation + + Args: + image_shape: A tuple or list of 3 integers representing the shape of the + input image `(height, width, channels)`. + patch_size: int. The size of each patch (both height and width). + embed_dim: int. The embedding dimension for the first stage. + depths: list of ints. Number of transformer blocks in each stage. + num_heads: list of ints. Number of attention heads in each stage. + window_size: int. Size of attention window (both height and width). + mlp_ratio: float. Ratio of MLP hidden dimension to embedding dimension. + qkv_bias: bool. If True, add a learnable bias to query, key, value. + dropout_rate: float. Dropout rate for embedding and transformer layers. + attention_dropout: float. Dropout rate for attention projections. + path_dropout: float. Stochastic depth rate for transformer blocks. + patch_norm: bool. If True, add normalization after patch embedding. + data_format: str. One of `"channels_last"` or `"channels_first"`. + dtype: The dtype of the layer weights. Defaults to None. + **kwargs: Additional keyword arguments to be passed to the parent + `Backbone` class. 
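+
+    Example:
+
+    A sketch of the intended constructor usage (illustrative only; presets
+    will be added in a future PR):
+
+    ```python
+    import numpy as np
+
+    backbone = keras_hub.models.SwinTransformersBackbone(
+        image_shape=(224, 224, 3),
+        patch_size=4,
+        embed_dim=96,
+        depths=[2, 2, 6, 2],
+        num_heads=[3, 6, 12, 24],
+        window_size=7,
+    )
+    features = backbone(np.ones((1, 224, 224, 3), dtype="float32"))
+    ```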
+ """ + + def __init__( + self, + image_shape=(224, 224, 3), + patch_size=4, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4.0, + qkv_bias=True, + dropout_rate=0.0, + attention_dropout=0.0, + path_dropout=0.2, + patch_norm=True, + data_format="channels_last", + dtype=None, + **kwargs, + ): + if len(depths) != len(num_heads): + raise ValueError( + f"Length of depths ({len(depths)}) must match " + f"length of num_heads ({len(num_heads)})" + ) + + self.patch_embedding = PatchEmbedding( + patch_size=patch_size, + embed_dim=embed_dim, + data_format=data_format, + patch_norm=patch_norm, + name="patch_embedding" + ) + + self.pos_dropout = layers.Dropout(dropout_rate, name="pos_dropout") if dropout_rate > 0.0 else None + + self.stages = [] + for i, (depth, num_head) in enumerate(zip(depths, num_heads)): + dim = embed_dim * (2 ** i) + downsample = PatchMerging(dim=dim // 2, name=f"downsample_{i-1}") if i > 0 else None + + stage = SwinTransformerStage( + dim=dim, + depth=depth, + num_heads=num_head, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + dropout_rate=dropout_rate, + attention_dropout=attention_dropout, + path_dropout=path_dropout, + downsample=downsample, + name=f"stage_{i}" + ) + self.stages.append(stage) + + self.norm = layers.LayerNormalization(epsilon=1e-5, name="norm") + + inputs = keras.layers.Input(shape=image_shape) + x = self.patch_embedding(inputs) + if self.pos_dropout is not None: + x = self.pos_dropout(x) + for stage in self.stages: + x = stage(x) + x = self.norm(x) + + super().__init__(inputs=inputs, outputs=x, dtype=dtype, **kwargs) + + self.data_format = data_format + self.image_shape = image_shape + self.patch_size = patch_size + self.embed_dim = embed_dim + self.depths = depths + self.num_heads = num_heads + self.window_size = window_size + self.mlp_ratio = mlp_ratio + self.qkv_bias = qkv_bias + self.dropout_rate = dropout_rate + self.attention_dropout = attention_dropout + self.path_dropout = path_dropout + self.patch_norm = patch_norm + + def get_config(self): + config = super().get_config() + config.update({ + "image_shape": self.image_shape, + "patch_size": self.patch_size, + "embed_dim": self.embed_dim, + "depths": self.depths, + "num_heads": self.num_heads, + "window_size": self.window_size, + "mlp_ratio": self.mlp_ratio, + "qkv_bias": self.qkv_bias, + "dropout_rate": self.dropout_rate, + "attention_dropout": self.attention_dropout, + "path_dropout": self.path_dropout, + "patch_norm": self.patch_norm, + "data_format": self.data_format, + }) + return config \ No newline at end of file diff --git a/keras_hub/src/models/swin_transformers/swin_transformers_backbone_test.py b/keras_hub/src/models/swin_transformers/swin_transformers_backbone_test.py new file mode 100644 index 0000000000..48e04e1ecf --- /dev/null +++ b/keras_hub/src/models/swin_transformers/swin_transformers_backbone_test.py @@ -0,0 +1,51 @@ +import pytest +from keras import ops + +from keras_hub.src.models.swin_transformers.swin_transformers_backbone import ( + SwinTransformersBackbone, +) +from keras_hub.src.tests.test_case import TestCase + +class SwinTransformersBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "image_shape": (64, 64, 3), + "patch_size": 2, + "embed_dim": 32, + "depths": [1, 1, 1, 1], + "num_heads": [1, 2, 4, 8], + "window_size": 4, + "mlp_ratio": 4.0, + "qkv_bias": True, + "dropout_rate": 0.0, + "attention_dropout": 0.0, + "path_dropout": 0.1, + "patch_norm": True, + "data_format": 
"channels_last", + "dtype": "float32", + } + self.input_data = ops.ones((2, 64, 64, 3), dtype="float32") + + def test_backbone_basics(self): + self.run_backbone_test( + cls=SwinTransformersBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 2, 2, 256), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=SwinTransformersBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.large + def test_smallest_preset(self): + pass # Will be added in a future PR when presets are implemented + + @pytest.mark.extra_large + def test_all_presets(self): + pass # Will be added in a future PR when presets are implemented diff --git a/keras_hub/src/models/swin_transformers/swin_transformers_layers.py b/keras_hub/src/models/swin_transformers/swin_transformers_layers.py new file mode 100644 index 0000000000..0352eba520 --- /dev/null +++ b/keras_hub/src/models/swin_transformers/swin_transformers_layers.py @@ -0,0 +1,704 @@ +import keras +from keras import layers +from keras import ops +import collections.abc + +def window_partition(x, window_size): + """Partition the input tensor into non-overlapping windows.""" + batch_size, height, width, channels = ops.shape(x) + + x = ops.reshape( + x, + ( + batch_size, + height // window_size, + window_size, + width // window_size, + window_size, + channels + ) + ) + + x = ops.transpose(x, (0, 1, 3, 2, 4, 5)) + windows = ops.reshape( + x, (-1, window_size, window_size, channels) + ) + return windows + + +def window_reverse(windows, window_size, height, width, channels): + """Reverse window partitioning.""" + batch_size = ops.shape(windows)[0] // ((height // window_size) * (width // window_size)) + + x = ops.reshape( + windows, + ( + batch_size, + height // window_size, + width // window_size, + window_size, + window_size, + channels + ) + ) + + x = ops.transpose(x, (0, 1, 3, 2, 4, 5)) + x = ops.reshape(x, (batch_size, height, width, channels)) + return x + + +class DropPath(layers.Layer): + """Drop paths (Stochastic Depth) per sample. + + This is an implementation of the paper "Deep Networks with Stochastic Depth", + which randomly drops entire layers for regularization. + + Args: + drop_prob: float, probability of dropping path. + """ + + def __init__(self, drop_prob=0.0, **kwargs): + super().__init__(**kwargs) + self.drop_prob = drop_prob + + def call(self, x, training=None): + if self.drop_prob == 0.0 or not training: + return x + + # Keep probability + keep_prob = 1.0 - self.drop_prob + + # Create binary mask with shape [batch_size, 1, 1, 1] + batch_size = ops.shape(x)[0] + random_tensor = keep_prob + ops.random.uniform((batch_size, 1, 1, 1), dtype=x.dtype) + binary_mask = ops.floor(random_tensor) + + # Scale output to preserve expected value + output = x / keep_prob * binary_mask + return output + + def get_config(self): + config = super().get_config() + config.update({"drop_prob": self.drop_prob}) + return config + + +class Mlp(layers.Layer): + """MLP module for Transformer. + + Args: + in_features: Input dimension. + hidden_features: Hidden dimension. + out_features: Output dimension. + dropout_rate: Dropout rate. 
+ """ + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + dropout_rate=0.0, + **kwargs, + ): + super().__init__(**kwargs) + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.in_features = in_features + self.hidden_features = hidden_features + self.out_features = out_features + self.dropout_rate = dropout_rate + + self.fc1 = layers.Dense(hidden_features, name="fc1") + self.act = keras.activations.gelu + self.fc2 = layers.Dense(out_features, name="fc2") + self.drop = layers.Dropout(dropout_rate) if dropout_rate > 0.0 else None + + def call(self, x): + x = self.fc1(x) + x = self.act(x) + if self.drop is not None: + x = self.drop(x) + x = self.fc2(x) + if self.drop is not None: + x = self.drop(x) + return x + + def get_config(self): + config = super().get_config() + config.update({ + "in_features": self.in_features, + "hidden_features": self.hidden_features, + "out_features": self.out_features, + "dropout_rate": self.dropout_rate, + }) + return config + + +class WindowAttention(layers.Layer): + """Window based multi-head self attention. + + Args: + dim: Number of input channels + window_size: Window size + num_heads: Number of attention heads + qkv_bias: Add bias to query, key, value projections + attention_dropout: Attention dropout rate + dropout: Dropout rate + """ + + def __init__( + self, + dim, + window_size, + num_heads, + qkv_bias=True, + attention_dropout=0., + dropout=0., + **kwargs + ): + super().__init__(**kwargs) + + self.dim = dim + self.window_size = ( + window_size + if isinstance(window_size, collections.abc.Iterable) + else (window_size, window_size) + ) + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + # Linear layers for Q, K, V + self.qkv = layers.Dense( + dim * 3, + use_bias=qkv_bias, + name="qkv" + ) + + # Relative position encoding + self.relative_position_bias_table = self.add_weight( + shape=((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads), + initializer="zeros", + trainable=True, + name="relative_position_bias_table" + ) + + # Get pair-wise relative position index + coords = ops.stack(ops.meshgrid( + ops.arange(self.window_size[0]), + ops.arange(self.window_size[1]) + )) + coords = ops.reshape(coords, [2, -1]) + relative_coords = coords[:, :, None] - coords[:, None, :] + relative_coords = ops.transpose(relative_coords, [1, 2, 0]) + + relative_coords = relative_coords + self.window_size[0] - 1 + relative_coords = relative_coords * (2 * self.window_size[0] - 1) + relative_position_index = ops.sum(relative_coords, -1) + + self.relative_position_index = relative_position_index + + self.attn_drop = layers.Dropout(attention_dropout) + self.proj = layers.Dense(dim) + self.proj_drop = layers.Dropout(dropout) + + def build(self, input_shape): + self.num_windows = input_shape[0] // ( + self.window_size[0] * self.window_size[1] + ) + super().build(input_shape) + + def call(self, x, mask=None): + """Forward pass. + + Args: + x: Input tensor with shape [batch*num_windows, window_size*window_size, dim]. + mask: Optional mask for shifted window attention. + + Returns: + Output tensor with shape [batch*num_windows, window_size*window_size, dim]. 
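+
+        Shape walk-through (an illustrative example assuming `window_size=7`,
+        `num_heads=3` and `dim=96`, so `N = 7 * 7 = 49` and
+        `head_dim = 96 // 3 = 32`):
+
+        - `q`, `k`, `v`: `[batch*num_windows, 3, 49, 32]`
+        - `attn` scores: `[batch*num_windows, 3, 49, 49]`
+        - relative position bias (broadcast over windows): `[1, 3, 49, 49]`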
+ """ + B_, N, C = ops.shape(x) + + # QKV projection + qkv = self.qkv(x) # [B_, N, 3*C] + + # Calculate exact dimensions + qkv_dim = ops.shape(qkv)[-1] + dim_per_head = C // self.num_heads + + # Split QKV + # This splits the last dimension into 3 equal parts + chunk_size = qkv_dim // 3 + q = qkv[:, :, :chunk_size] + k = qkv[:, :, chunk_size:2*chunk_size] + v = qkv[:, :, 2*chunk_size:] + + # Reshape to separate heads + q = ops.reshape(q, (B_, N, self.num_heads, dim_per_head)) + k = ops.reshape(k, (B_, N, self.num_heads, dim_per_head)) + v = ops.reshape(v, (B_, N, self.num_heads, dim_per_head)) + + # Transpose to [B_, num_heads, N, head_dim] + q = ops.transpose(q, (0, 2, 1, 3)) + k = ops.transpose(k, (0, 2, 1, 3)) + v = ops.transpose(v, (0, 2, 1, 3)) + + # Scale query + q = q * self.scale + + # Compute attention scores + attn = ops.matmul(q, ops.transpose(k, (0, 1, 3, 2))) + + # Add relative position bias + relative_position_bias = ops.take( + self.relative_position_bias_table, + self.relative_position_index, + ) + + relative_position_bias = ops.reshape( + relative_position_bias, + (self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + self.num_heads) + ) + + # Transpose to [num_heads, Wh*Ww, Wh*Ww] + relative_position_bias = ops.transpose(relative_position_bias, (2, 0, 1)) + + # Add to attention [B_, num_heads, N, N] + attn = attn + ops.expand_dims(relative_position_bias, axis=0) + + # Apply attention mask if provided + if mask is not None: + nW = mask.shape[0] # num_windows + # attn: [B_/nW, nW, num_heads, N, N] + # mask: [1, nW, 1, N, N] + attn = ops.reshape(attn, (-1, nW, self.num_heads, N, N)) + mask = ops.expand_dims(mask, axis=1) # [nW, 1, N, N] -> [1, nW, 1, N, N] + attn = attn + ops.cast(mask, attn.dtype) * -100.0 + attn = ops.reshape(attn, (-1, self.num_heads, N, N)) + + # Softmax normalization and dropout + attn = ops.softmax(attn, axis=-1) + if self.attn_drop is not None: + attn = self.attn_drop(attn) + + # Apply attention to values + x = ops.matmul(attn, v) # [B_, num_heads, N, head_dim] + + # Transpose back to [B_, N, C] + x = ops.transpose(x, (0, 2, 1, 3)) + x = ops.reshape(x, (B_, N, C)) + + # Output projection and dropout + x = self.proj(x) + if self.proj_drop is not None: + x = self.proj_drop(x) + + return x + + def get_config(self): + config = super().get_config() + config.update({ + "dim": self.dim, + "window_size": self.window_size, + "num_heads": self.num_heads, + "qkv_bias": self.qkv_bias, + "attention_dropout": self.attention_dropout, + "dropout": self.dropout, + }) + return config + + +class SwinTransformerBlock(layers.Layer): + """Swin Transformer Block. + + Args: + dim: Number of input channels. + input_resolution: Input resolution (height, width). + num_heads: Number of attention heads. + window_size: Window size for attention. + shift_size: Shift size for shifted window attention (0 or window_size//2). + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + qkv_bias: If True, add a learnable bias to query, key, value. + dropout_rate: Dropout rate. + attention_dropout: Dropout rate for attention. + path_dropout: Stochastic depth rate. + norm_layer: Normalization layer class. 
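+
+    Note that blocks are intended to be used in pairs: an even-indexed block
+    with `shift_size=0` (regular window attention) followed by an odd-indexed
+    block with `shift_size=window_size // 2` (shifted window attention), which
+    is how `SwinTransformerStage` below wires them up.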
+ """ + + def __init__( + self, + dim, + input_resolution=None, + num_heads=1, + window_size=7, + shift_size=0, + mlp_ratio=4.0, + qkv_bias=True, + dropout_rate=0.0, + attention_dropout=0.0, + path_dropout=0.0, + norm_layer=layers.LayerNormalization, + **kwargs, + ): + super().__init__(**kwargs) + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(epsilon=1e-5, name="norm1") + self.attn = WindowAttention( + dim=dim, + window_size=window_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + attention_dropout=attention_dropout, + dropout=dropout_rate, + name="attn" + ) + self.drop_path = DropPath(path_dropout) if path_dropout > 0. else None + self.norm2 = norm_layer(epsilon=1e-5, name="norm2") + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + dropout_rate=dropout_rate, + name="mlp" + ) + + if self.shift_size > 0: + H, W = self.input_resolution + img_mask = ops.zeros((1, H, W, 1)) + + h_slices = [ + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None) + ] + w_slices = [ + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None) + ] + + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask_segment = ops.ones((1, H, W, 1)) + img_mask_segment = ops.index_update( + img_mask_segment, (..., h, w, ...), ops.ones((1, h.stop - h.start if h.stop else H - h.start, + w.stop - w.start if w.stop else W - w.start, 1)) * cnt + ) + img_mask = img_mask + img_mask_segment + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) + mask_windows = ops.reshape(mask_windows, (-1, self.window_size * self.window_size)) + attn_mask = ops.expand_dims(mask_windows, axis=1) - ops.expand_dims(mask_windows, axis=2) + attn_mask = ops.where(attn_mask != 0, -100.0, 0.0) + self.attn_mask = attn_mask + else: + self.attn_mask = None + + def call(self, x): + B, L, C = ops.shape(x) + H, W = self.input_resolution + + window_size = self.window_size + shift_size = self.shift_size + + if min(H, W) <= window_size: + window_size = min(H, W) + shift_size = 0 + + x = ops.reshape(x, (B, H, W, C)) + + if self.shift_size > 0: + shifted_x = ops.roll(x, shift=(-self.shift_size, -self.shift_size), axis=(1, 2)) + else: + shifted_x = x + + x_windows = window_partition(shifted_x, self.window_size) # [B*num_windows, window_size, window_size, C] + x_windows = ops.reshape(x_windows, (-1, self.window_size * self.window_size, C)) # [B*num_windows, window_size*window_size, C] + + identity = x_windows + + x_windows = self.norm1(x_windows) + attn_windows = self.attn(x_windows, mask=self.attn_mask) # [B*num_windows, window_size*window_size, C] + + if self.drop_path is not None: + attn_windows = self.drop_path(attn_windows) + + x_windows = identity + attn_windows + + identity = x_windows + x_windows = self.norm2(x_windows) + x_windows = self.mlp(x_windows) + + if self.drop_path is not None: + x_windows = self.drop_path(x_windows) + + x_windows = identity + x_windows + + x_windows = ops.reshape(x_windows, (-1, self.window_size, self.window_size, C)) + + if self.shift_size > 0: + x = window_reverse(x_windows, self.window_size, H, W, C) + x = ops.roll(x, shift=(self.shift_size, self.shift_size), axis=(1, 2)) + else: + x = window_reverse(x_windows, self.window_size, H, W, C) + + x = ops.reshape(x, (B, H * W, C)) + + return x + + def get_config(self): + 
config = super().get_config() + config.update({ + "dim": self.dim, + "input_resolution": self.input_resolution, + "num_heads": self.num_heads, + "window_size": self.window_size, + "shift_size": self.shift_size, + "mlp_ratio": self.mlp_ratio, + }) + return config + + def compute_output_shape(self, input_shape): + return input_shape + + +class PatchMerging(layers.Layer): + """Patch Merging Layer. + + This layer performs downsampling by concatenating patches and using a linear layer. + + Args: + dim: Number of input channels. + """ + + def __init__(self, dim, **kwargs): + super().__init__(**kwargs) + self.dim = dim + self.reduction = layers.Dense(2 * dim, use_bias=False, name="reduction") + self.norm = layers.LayerNormalization(epsilon=1e-5, name="norm") + + def call(self, x, H, W): + """Forward pass. + + Args: + x: Input tensor with shape [B, H*W, C]. + H: Height of feature map. + W: Width of feature map. + + Returns: + Downsampled feature map with shape [B, H/2*W/2, 2*C]. + """ + B, L, C = ops.shape(x) + + x = ops.reshape(x, (B, H, W, C)) + pad_values = ((0, 0), (0, H % 2), (0, W % 2), (0, 0)) + x = ops.pad(x, pad_values) + + # Reshape to group patches + x0 = x[:, 0::2, 0::2, :] + x1 = x[:, 1::2, 0::2, :] + x2 = x[:, 0::2, 1::2, :] + x3 = x[:, 1::2, 1::2, :] + + x = ops.concatenate([x0, x1, x2, x3], axis=-1) + x = self.norm(x) + x = self.reduction(x) + x = ops.reshape(x, (B, -1, 2 * C)) + + return x + + def get_config(self): + config = super().get_config() + config.update({"dim": self.dim}) + return config + + def compute_output_shape(self, input_shape): + batch_size, seq_len, channels = input_shape + return (batch_size, seq_len // 4, channels * 2) + + +class PatchEmbedding(layers.Layer): + """Image to Patch Embedding. + + Args: + patch_size: Size of each patch. + embed_dim: Embedding dimension. + norm_layer: Normalization layer. + data_format: Format of the input data, either "channels_last" or "channels_first". + patch_norm: If True, add normalization after patch embedding. + """ + + def __init__( + self, + patch_size=4, + embed_dim=96, + norm_layer=None, + data_format="channels_last", + patch_norm=False, + **kwargs, + ): + super().__init__(**kwargs) + self.patch_size = patch_size + self.embed_dim = embed_dim + self.data_format = data_format + + self.proj = layers.Conv2D( + embed_dim, + kernel_size=patch_size, + strides=patch_size, + padding="valid", + data_format=data_format, + name="proj", + ) + + self.norm = norm_layer(epsilon=1e-5, name="norm") if patch_norm and norm_layer else None + + def call(self, x): + """Forward pass. + + Args: + x: Input images with shape [B, H, W, C] in channels_last format + or [B, C, H, W] in channels_first format. + + Returns: + Patch embeddings with shape [B, H//patch_size * W//patch_size, embed_dim]. + """ + B = ops.shape(x)[0] + + x = self.proj(x) + + if self.data_format == "channels_last": + _, H, W, C = ops.shape(x) + x = ops.reshape(x, (B, H * W, C)) + else: + _, C, H, W = ops.shape(x) + x = ops.transpose(x, (0, 2, 3, 1)) # [B, H, W, C] + x = ops.reshape(x, (B, H * W, C)) + + if self.norm is not None: + x = self.norm(x) + + return x + + def get_config(self): + config = super().get_config() + config.update({ + "patch_size": self.patch_size, + "embed_dim": self.embed_dim, + "data_format": self.data_format, + }) + return config + + +class SwinTransformerStage(layers.Layer): + """Swin Transformer Stage. + + A stage consists of multiple Swin Transformer blocks with the same resolution, + and an optional patch merging layer at the beginning. 
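+
+    For example, with `embed_dim=96` and `depths=[2, 2, 6, 2]`, the four
+    stages run at channel dims 96, 192, 384 and 768, since each
+    `PatchMerging` step doubles the channel dim while halving the spatial
+    resolution in each direction.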
+ + Args: + dim: Number of input channels. + depth: Number of blocks in this stage. + num_heads: Number of attention heads. + window_size: Local window size. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + qkv_bias: If True, add a learnable bias to query, key, value. + dropout_rate: Dropout rate. + attention_dropout: Dropout rate for attention. + path_dropout: Stochastic depth rate. + downsample: Downsample layer at the end of the layer. + """ + + def __init__( + self, + dim, + depth, + num_heads, + window_size=7, + mlp_ratio=4.0, + qkv_bias=True, + dropout_rate=0.0, + attention_dropout=0.0, + path_dropout=0.0, + downsample=None, + **kwargs, + ): + super().__init__(**kwargs) + self.dim = dim + self.depth = depth + self.window_size = window_size + self.num_heads = num_heads + self.mlp_ratio = mlp_ratio + + self.blocks = [] + for i in range(depth): + self.blocks.append( + SwinTransformerBlock( + dim=dim, + input_resolution=None, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + dropout_rate=dropout_rate, + attention_dropout=attention_dropout, + path_dropout=path_dropout[i] if isinstance(path_dropout, list) else path_dropout, + name=f"blocks_{i}", + ) + ) + + self.downsample = downsample + + def call(self, x): + """Forward pass. + + Args: + x: Input feature with shape [B, H*W, C]. + + Returns: + Output feature with shape [B, H/2*W/2, 2*C] if downsample is applied, + otherwise [B, H*W, C]. + """ + B, L, C = ops.shape(x) + + H_W = ops.cast(ops.sqrt(ops.cast(L, "float32")), "int32") + + for block in self.blocks: + block.input_resolution = (H_W, H_W) + + for block in self.blocks: + x = block(x) + + if self.downsample is not None: + x = self.downsample(x, H_W, H_W) + + return x + + def get_config(self): + config = super().get_config() + config.update({ + "dim": self.dim, + "depth": self.depth, + "window_size": self.window_size, + "mlp_ratio": self.mlp_ratio, + }) + return config + + def compute_output_shape(self, input_shape): + batch_size, seq_len, channels = input_shape + if self.downsample is not None: + return (batch_size, seq_len // 4, channels * 2) + return input_shape From 69019f13290aa67c10624a18846c7bf1f9dbc92a Mon Sep 17 00:00:00 2001 From: ado Date: Fri, 13 Jun 2025 09:50:08 -0400 Subject: [PATCH 2/7] First iteration of swin transformer layers in keras hub models --- .../swin_transformer_backbone.py | 191 +++++ .../swin_transformer_backbone_test.py | 45 ++ .../swin_transformer_layers.py | 719 ++++++++++++++++++ 3 files changed, 955 insertions(+) create mode 100644 keras_hub/src/models/swin_transformer/swin_transformer_backbone.py create mode 100644 keras_hub/src/models/swin_transformer/swin_transformer_backbone_test.py create mode 100644 keras_hub/src/models/swin_transformer/swin_transformer_layers.py diff --git a/keras_hub/src/models/swin_transformer/swin_transformer_backbone.py b/keras_hub/src/models/swin_transformer/swin_transformer_backbone.py new file mode 100644 index 0000000000..89d1fa6776 --- /dev/null +++ b/keras_hub/src/models/swin_transformer/swin_transformer_backbone.py @@ -0,0 +1,191 @@ +import keras +from keras import layers +from keras import ops +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.swin_transformer.swin_transformer_layers import ( + PatchEmbedding, + SwinTransformerStage, + PatchMerging +) + +def swin_kernel_initializer(stddev=0.02): + return 
keras.initializers.TruncatedNormal(stddev=stddev) + +@keras_hub_export("keras_hub.models.SwinTransformerBackbone") +class SwinTransformerBackbone(Backbone): + """A Swin Transformer backbone network. + + This network implements a hierarchical vision transformer as described in + ["Swin Transformer: Hierarchical Vision Transformer using Shifted Windows"](https://arxiv.org/abs/2103.14030). + It includes the patch embedding, transformer stages with shifted windows, + and final normalization, but not the classification head. + + The default constructor gives a fully customizable, randomly initialized + Swin Transformer with any number of layers, heads, and embedding dimensions. + To load preset architectures and weights, use the `from_preset()` constructor. + + Args: + image_shape: tuple of ints. The shape of the input images, excluding batch dimension. + patch_size: int. Size of the patches to be extracted from the input images. + embed_dim: int. Base dimension of the transformer. + depths: tuple of ints. Number of transformer blocks in each stage. + num_heads: tuple of ints. Number of attention heads in each stage. + window_size: int. Size of the attention window. + mlp_ratio: float. Ratio of mlp hidden dim to embedding dim. + qkv_bias: bool. If True, add a learnable bias to query, key, value. + drop: float. Dropout rate. + attn_drop: float. Dropout rate for attention. + drop_path: float. Stochastic depth rate. + patch_norm: bool. If True, add normalization after patch embedding. + data_format: str. Format of the input data, either "channels_last" or "channels_first". + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for model computations and weights. + + Examples: + ```python + # Pretrained Swin Transformer backbone. + model = keras_hub.models.SwinTransformerBackbone.from_preset( + "swin_tiny_224" + ) + model(np.ones((1, 224, 224, 3))) + + # Randomly initialized Swin Transformer with custom config. 
+ model = keras_hub.models.SwinTransformerBackbone( + image_shape=(224, 224, 3), + patch_size=4, + embed_dim=96, + depths=(2, 2, 6, 2), + num_heads=(3, 6, 12, 24), + window_size=7, + mlp_ratio=4.0, + ) + model(np.ones((1, 224, 224, 3))) + ``` + """ + + def __init__( + self, + image_shape, + patch_size=4, + embed_dim=96, + depths=(2, 2, 6, 2), + num_heads=(3, 6, 12, 24), + window_size=7, + mlp_ratio=4.0, + qkv_bias=True, + drop=0.0, + attn_drop=0.0, + drop_path=0.1, + patch_norm=True, + data_format="channels_last", + dtype=None, + **kwargs, + ): + dtype = dtype or keras.backend.floatx() + + # === Layers === + self.patch_embedding = PatchEmbedding( + patch_size=patch_size, + embed_dim=embed_dim, + norm_layer=layers.LayerNormalization if patch_norm else None, + data_format=data_format, + patch_norm=patch_norm, + name="patch_embedding", + ) + + # Stochastic depth decay rule + dpr = [float(x) for x in ops.linspace(0.0, drop_path, sum(depths))] + + # === Functional Model === + inputs = keras.Input(shape=image_shape) + x = self.patch_embedding(inputs) + h, w = image_shape[0] // patch_size, image_shape[1] // patch_size + + # Build stages + self.stages = [] + for i in range(len(depths)): + stage = SwinTransformerStage( + dim=int(embed_dim * 2 ** i), + depth=depths[i], + num_heads=num_heads[i], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop, + attn_drop=attn_drop, + drop_path=dpr[sum(depths[:i]) : sum(depths[: i + 1])], + downsample=PatchMerging if (i < len(depths) - 1) else None, + input_resolution=(h, w), + name=f"stage_{i}", + ) + self.stages.append(stage) + h //= 2 + w //= 2 + + # Final norm + self.norm_layers = [ + layers.LayerNormalization(epsilon=1e-5, name=f"norm_{i}") for i in range(len(depths)) + ] + + # Forward pass + features = [] + + for i, stage in enumerate(self.stages): + x = stage(x) + + def reshape_and_norm(tensor, norm_layer=self.norm_layers[i]): + shape = ops.shape(tensor) + B = shape[0] + L = shape[1] + C = shape[2] + H_float = ops.sqrt(ops.cast(L, x.dtype)) + H = ops.cast(H_float, "int32") + W = H + tensor = ops.reshape(tensor, (B, H, W, C)) + return norm_layer(tensor) + + x_reshaped = keras.layers.Lambda(reshape_and_norm)(x) + features.append(x_reshaped) + + + super().__init__( + inputs=inputs, + outputs=features[-1], + dtype=None, + **kwargs + ) + + # === Config === + self.image_shape = image_shape + self.patch_size = patch_size + self.embed_dim = embed_dim + self.depths = depths + self.num_heads = num_heads + self.window_size = window_size + self.mlp_ratio = mlp_ratio + self.qkv_bias = qkv_bias + self.drop = drop + self.attn_drop = attn_drop + self.drop_path = drop_path + self.patch_norm = patch_norm + self.data_format = data_format + + def get_config(self): + config = super().get_config() + config.update({ + "image_shape": self.image_shape, + "patch_size": self.patch_size, + "embed_dim": self.embed_dim, + "depths": self.depths, + "num_heads": self.num_heads, + "window_size": self.window_size, + "mlp_ratio": self.mlp_ratio, + "qkv_bias": self.qkv_bias, + "drop": self.drop, + "attn_drop": self.attn_drop, + "drop_path": self.drop_path, + "patch_norm": self.patch_norm, + "data_format": self.data_format, + }) + return config \ No newline at end of file diff --git a/keras_hub/src/models/swin_transformer/swin_transformer_backbone_test.py b/keras_hub/src/models/swin_transformer/swin_transformer_backbone_test.py new file mode 100644 index 0000000000..569cc4777f --- /dev/null +++ 
b/keras_hub/src/models/swin_transformer/swin_transformer_backbone_test.py @@ -0,0 +1,45 @@ +import pytest +from keras import ops, mixed_precision + +from keras_hub.src.models.swin_transformer.swin_transformer_backbone import ( + SwinTransformerBackbone, +) +from keras_hub.src.tests.test_case import TestCase + +class SwinTransformerBackboneTest(TestCase): + def setUp(self): + mixed_precision.set_global_policy("float32") + self.init_kwargs = { + "image_shape": (224, 224, 3), + "patch_size": 4, + "embed_dim": 96, + "depths": (2, 2, 6, 2), + "num_heads": (3, 6, 12, 24), + "window_size": 7, + "mlp_ratio": 4.0, + } + self.input_data = ops.ones((1, 224, 224, 3)) + + def test_backbone_basics(self): + self.run_backbone_test( + cls=SwinTransformerBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape = (1, 7, 7, 768), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=SwinTransformerBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.large + def test_smallest_preset(self): + pass # Will be added in a future PR when presets are implemented + + @pytest.mark.extra_large + def test_all_presets(self): + pass # Will be added in a future PR when presets are implemented diff --git a/keras_hub/src/models/swin_transformer/swin_transformer_layers.py b/keras_hub/src/models/swin_transformer/swin_transformer_layers.py new file mode 100644 index 0000000000..45b494c63c --- /dev/null +++ b/keras_hub/src/models/swin_transformer/swin_transformer_layers.py @@ -0,0 +1,719 @@ +import keras +from keras import layers +from keras import ops +import collections.abc +from typing import Union, Tuple, Any +import numpy as np + +def get_relative_position_index(win_h, win_w): + """Get pair-wise relative position index for each token inside the window. + + Args: + win_h: Height of the window. + win_w: Width of the window. + + Returns: + A tensor of shape (win_h*win_w, win_h*win_w) containing the relative + position indices for each pair of tokens in the window. + """ + xx, yy = ops.meshgrid(ops.arange(win_h), ops.arange(win_w), indexing="ij") + coords = ops.stack([yy, xx], axis=0) + coords_flatten = ops.reshape(coords, (2, -1)) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = ops.transpose(relative_coords, (1, 2, 0)) + xx = (relative_coords[:, :, 0] + win_h - 1) * (2 * win_w - 1) + yy = relative_coords[:, :, 1] + win_w - 1 + relative_coords = ops.stack([xx, yy], axis=-1) + relative_position_index = ops.sum(relative_coords, axis=-1) + return relative_position_index + +def window_partition(x, window_size): + """Partition the input tensor into non-overlapping windows. 
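+
+    For example, a `[B, 56, 56, C]` feature map with `window_size=7` is split
+    into `(56 // 7) ** 2 = 64` windows per image, giving windows of shape
+    `[B * 64, 7, 7, C]`; inputs whose height or width is not a multiple of
+    `window_size` are zero-padded first.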
+ + Args: + x: Input tensor with shape [B, H, W, C] + window_size: Size of the window + + Returns: + Windows with shape [B*num_windows, window_size, window_size, C] + """ + shape = ops.shape(x) + if len(shape) != 4: + raise ValueError(f"Expected input tensor to have 4 dimensions, got {len(shape)}") + + B = shape[0] + H = shape[1] + W = shape[2] + C = shape[3] + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = ops.pad(x, [[0, 0], [0, pad_h], [0, pad_w], [0, 0]]) + H = H + pad_h + W = W + pad_w + + num_windows_h = H // window_size + num_windows_w = W // window_size + + # Reshape to windows + x = ops.reshape( + x, + ( + B, + num_windows_h, + window_size, + num_windows_w, + window_size, + C + ) + ) + x = ops.transpose(x, (0, 1, 3, 2, 4, 5)) + windows = ops.reshape( + x, + (-1, window_size, window_size, C) + ) + + return windows, (H, W) + + +def window_reverse(windows, window_size, height, width, channels): + """Reverse window partitioning. + + Args: + windows: Windows with shape [B*num_windows, window_size, window_size, C] + window_size: Size of the window + height: Height of the feature map + width: Width of the feature map + channels: Number of channels + + Returns: + Feature map with shape [B, H, W, C] + """ + # Calculate number of windows + num_windows_h = height // window_size + num_windows_w = width // window_size + batch_size = ops.shape(windows)[0] // (num_windows_h * num_windows_w) + + # Reshape windows to [B, num_windows_h, num_windows_w, window_size, window_size, C] + x = ops.reshape( + windows, + ( + batch_size, + num_windows_h, + num_windows_w, + window_size, + window_size, + channels + ) + ) + + # Permute dimensions to get [B, H, W, C] + x = ops.transpose(x, (0, 1, 3, 2, 4, 5)) + x = ops.reshape(x, (batch_size, height, width, channels)) + + return x + + +class DropPath(layers.Layer): + """Drop paths (Stochastic Depth) per sample. + + This is an implementation of the paper "Deep Networks with Stochastic Depth", + which randomly drops entire layers for regularization. + + Args: + drop_prob: float, probability of dropping path. + """ + + def __init__(self, drop_prob=0.0, **kwargs): + super().__init__(**kwargs) + self.drop_prob = drop_prob + + def call(self, x, training=None): + if self.drop_prob == 0.0 or not training: + return x + keep_prob = 1.0 - self.drop_prob + + batch_size = ops.shape(x)[0] + random_tensor = keep_prob + ops.random.uniform((batch_size, 1, 1, 1)) + binary_mask = ops.floor(random_tensor) + output = x / keep_prob * binary_mask + return output + + def get_config(self): + config = super().get_config() + config.update({"drop_prob": self.drop_prob}) + return config + + +class Mlp(layers.Layer): + """MLP module for Transformer. + + Args: + in_features: Input dimension. + hidden_features: Hidden dimension. + out_features: Output dimension. + act_layer: Activation function to use (e.g., keras.activations.gelu). + dropout_rate: Dropout rate. 
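+
+    Example:
+
+    A minimal sketch mirroring how `SwinTransformerBlock` uses this layer
+    (GELU activation and a 4x hidden expansion for `dim=96`):
+
+    ```python
+    mlp = Mlp(
+        in_features=96,
+        hidden_features=384,
+        act_layer=keras.activations.gelu,
+        dropout_rate=0.0,
+    )
+    y = mlp(ops.ones((2, 49, 96)))  # -> shape (2, 49, 96)
+    ```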
+ """ + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=keras.activations.relu, + dropout_rate=0.0, + **kwargs, + ): + super().__init__(**kwargs) + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.in_features = in_features + self.hidden_features = hidden_features + self.out_features = out_features + self.act_layer = act_layer + self.dropout_rate = dropout_rate + + self.fc1 = layers.Dense(hidden_features, name="fc1") + self.fc2 = layers.Dense(out_features, name="fc2") + self.drop = layers.Dropout(dropout_rate) if dropout_rate > 0.0 else None + + def call(self, x): + x = self.fc1(x) + x = self.act_layer(x) + if self.drop is not None: + x = self.drop(x) + x = self.fc2(x) + if self.drop is not None: + x = self.drop(x) + return x + + def get_config(self): + config = super().get_config() + config.update({ + "in_features": self.in_features, + "hidden_features": self.hidden_features, + "out_features": self.out_features, + "act_layer": keras.activations.serialize(self.act_layer), + "dropout_rate": self.dropout_rate, + }) + return config + + @classmethod + def from_config(cls, config): + config["act_layer"] = keras.activations.deserialize(config["act_layer"]) + return cls(**config) + + +class WindowAttention(keras.layers.Layer): + """Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + head_dim (int): Number of channels per head (dim // num_heads if not set) + window_size (tuple[int]): The height and width of the window. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float, optional): Override default scaling factor for queries and keys (default: head_dim ** -0.5) + attn_drop (float, optional): Dropout ratio of attention weights. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__( + self, + dim, + num_heads, + head_dim=None, + window_size=7, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + **kwargs, + ): + super().__init__(**kwargs) + self.dim = dim + self.window_size = ( + window_size + if isinstance(window_size, collections.abc.Iterable) + else (window_size, window_size) + ) + self.win_h, self.win_w = self.window_size + self.window_area = self.win_h * self.win_w + self.num_heads = num_heads + self.head_dim = head_dim or (dim // num_heads) + self.scale = qk_scale if qk_scale is not None else self.head_dim ** -0.5 + self.attn_dim = self.head_dim * self.num_heads + self.qkv_bias = qkv_bias + self.attn_drop_rate = attn_drop + self.proj_drop_rate = proj_drop + + self.relative_position_index = get_relative_position_index( + win_h=self.win_h, + win_w=self.win_w + ) + + def build(self, input_shape): + self.qkv = keras.layers.Dense( + self.head_dim * self.num_heads * 3, use_bias=self.qkv_bias, name="attention_qkv" + ) + self.attn_drop = keras.layers.Dropout(self.attn_drop_rate) + self.proj = keras.layers.Dense(self.dim, name="attention_projection") + self.proj_drop = keras.layers.Dropout(self.proj_drop_rate) + + self.relative_position_bias_table = self.add_weight( + shape=((2 * self.win_h - 1) * (2 * self.win_w - 1), self.num_heads), + initializer=keras.initializers.TruncatedNormal(stddev=0.02), + trainable=True, + name="relative_position_bias_table", + ) + super().build(input_shape) + + def _get_rel_pos_bias(self) -> Any: + relative_position_bias = ops.take( + self.relative_position_bias_table, + self.relative_position_index, + axis=0, + ) + return ops.transpose(relative_position_bias, (2, 0, 1)) + + def call( + self, x, mask=None, return_attns=False + ) -> Union[Any, Tuple[Any, Any]]: + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = ops.shape(x)[0], ops.shape(x)[1], ops.shape(x)[2] + qkv = self.qkv(x) + qkv = ops.reshape(qkv, (B_, N, 3, self.num_heads, -1)) + qkv = ops.transpose(qkv, (2, 0, 3, 1, 4)) + + q, k, v = ops.unstack(qkv, 3) + + scale = ops.cast(self.scale, dtype=qkv.dtype) + q = q * scale + attn = ops.matmul(q, ops.transpose(k, axes=[0, 1, 3, 2])) + attn = attn + self._get_rel_pos_bias() + + if mask is not None: + num_win = ops.shape(mask)[0] + attn = ops.reshape( + attn, (B_ // num_win, num_win, self.num_heads, N, N) + ) + attn = attn + ops.expand_dims(mask, 1)[None, ...] + + attn = ops.reshape(attn, (-1, self.num_heads, N, N)) + attn = ops.nn.softmax(attn, -1) + else: + attn = ops.nn.softmax(attn, -1) + + attn = self.attn_drop(attn) + + x = ops.matmul(attn, v) + x = ops.transpose(x, axes=[0, 2, 1, 3]) + x = ops.reshape(x, (B_, N, C)) + + x = self.proj(x) + x = self.proj_drop(x) + + if return_attns: + return x, attn + else: + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "dim": self.dim, + "window_size": self.window_size, + "win_h": self.win_h, + "win_w": self.win_w, + "num_heads": self.num_heads, + "head_dim": self.head_dim, + "attn_dim": self.attn_dim, + "scale": self.scale, + "qkv_bias": self.qkv_bias, + "attn_drop": self.attn_drop, + "proj_drop": self.proj_drop, + } + ) + return config + + +class SwinTransformerBlock(keras.layers.Layer): + """Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + num_heads (int): Number of attention heads. + window_size (int): Window size. 
+ shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (keras.layers.Layer, optional): Activation layer. Default: keras.layers.Activation("gelu") + norm_layer (keras.layers.Layer, optional): Normalization layer. Default: keras.layers.LayerNormalization(epsilon=1e-5) + """ + + def __init__( + self, + dim, + input_resolution, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4.0, + qkv_bias=True, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=keras.activations.gelu, + norm_layer=keras.layers.LayerNormalization, + **kwargs, + ): + super().__init__(**kwargs) + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + self.act_layer = act_layer + + if min(self.input_resolution) <= self.window_size: + self.shift_size = 0 + self.window_size = min(self.input_resolution) + + self.norm1 = norm_layer(epsilon=1e-5, name="norm1") + self.attn = WindowAttention( + dim=dim, + num_heads=num_heads, + window_size=window_size, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + name="attn", + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else keras.layers.Identity() + self.norm2 = norm_layer(epsilon=1e-5, name="norm2") + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=self.act_layer, + dropout_rate=drop, + name="mlp", + ) + + def call(self, x): + H, W = self.input_resolution + B, L, C = ops.shape(x) + + shortcut = x + x = self.norm1(x) + x = ops.reshape(x, (B, H, W, C)) + + attn_mask = None + if self.shift_size > 0: + shifted_x = ops.roll(x, shift=(-self.shift_size, -self.shift_size), axis=(1, 2)) + img_mask = np.zeros((1, H, W, 1), dtype=np.int32) + cnt = 0 + h_slices = [ + (0, H // 2), + (H // 2, H - self.shift_size), + (H - self.shift_size, H), + ] + w_slices = [ + (0, W // 2), + (W // 2, W - self.shift_size), + (W - self.shift_size, W), + ] + for h in h_slices: + for w in w_slices: + img_mask[:, h[0]:h[1], w[0]:w[1], :] = cnt + cnt += 1 + img_mask = ops.convert_to_tensor(img_mask) + + mask_windows = window_partition(img_mask, self.window_size)[0] + mask_windows = ops.reshape(mask_windows, (-1, self.window_size * self.window_size)) + attn_mask = ops.expand_dims(mask_windows, 1) - ops.expand_dims(mask_windows, 2) + attn_mask = ops.where(attn_mask != 0, -100.0, 0.0) + else: + shifted_x = x + + x_windows, (H_pad, W_pad) = window_partition(x=shifted_x, window_size=self.window_size) + x_windows = ops.reshape(x_windows, (-1, self.window_size * self.window_size, C)) + attn_windows = self.attn(x_windows, mask=attn_mask) + + attn_windows = ops.reshape(attn_windows, (-1, self.window_size, self.window_size, C)) + shifted_x = window_reverse(attn_windows, self.window_size, H_pad, W_pad, C) + + if self.shift_size > 0: + x = ops.roll(shifted_x, shift=(self.shift_size, self.shift_size), axis=(1, 2)) + else: + x = shifted_x + + if H_pad > H or W_pad > W: + x = x[:, :H, :W, :] + + x = ops.reshape(x, (B, H * W, C)) + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + def get_config(self): + config = 
super().get_config() + config.update( + { + "dim": self.dim, + "input_resolution": self.input_resolution, + "num_heads": self.num_heads, + "window_size": self.window_size, + "shift_size": self.shift_size, + "mlp_ratio": self.mlp_ratio, + } + ) + return config + + +class PatchMerging(layers.Layer): + """Patch Merging Layer. + + This layer performs downsampling by concatenating patches and using a linear layer. + + Args: + dim: Number of input channels. + """ + + def __init__(self, dim, **kwargs): + super().__init__(**kwargs) + self.dim = dim + self.reduction = layers.Dense(2 * dim, use_bias=False, name="reduction") + self.norm = layers.LayerNormalization(epsilon=1e-5, name="norm") + + def call(self, x, H, W): + """Forward pass. + + Args: + x: Input tensor with shape [B, H*W, C]. + H: Height of feature map. + W: Width of feature map. + + Returns: + Downsampled feature map with shape [B, H/2*W/2, 2*C]. + """ + B, L, C = ops.shape(x) + + x = ops.reshape(x, (B, H, W, C)) + pad_values = ((0, 0), (0, H % 2), (0, W % 2), (0, 0)) + x = ops.pad(x, pad_values) + + # Reshape to group patches + x0 = x[:, 0::2, 0::2, :] + x1 = x[:, 1::2, 0::2, :] + x2 = x[:, 0::2, 1::2, :] + x3 = x[:, 1::2, 1::2, :] + + x = ops.concatenate([x0, x1, x2, x3], axis=-1) + x = self.norm(x) + x = self.reduction(x) + x = ops.reshape(x, (B, -1, 2 * C)) + + return x + + def get_config(self): + config = super().get_config() + config.update({"dim": self.dim}) + return config + +class PatchEmbedding(layers.Layer): + """Image to Patch Embedding layer for Swin Transformer. + + Args: + patch_size: int. Patch size (usually 4). + embed_dim: int. Output embedding dimension. + norm_layer: Callable layer class for normalization (e.g., LayerNormalization). + data_format: str. Either "channels_last" or "channels_first". + patch_norm: bool. Whether to apply normalization. + """ + + def __init__( + self, + patch_size=4, + embed_dim=96, + norm_layer=None, + data_format="channels_last", + patch_norm=True, + + **kwargs, + ): + super().__init__(**kwargs) + self.patch_size = patch_size + self.embed_dim = embed_dim + self.data_format = data_format + self.patch_norm = patch_norm + self.norm_layer = norm_layer + + self.proj = layers.Conv2D( + filters=embed_dim, + kernel_size=patch_size, + strides=patch_size, + padding="valid", + data_format=data_format, + name="proj", + ) + + if self.patch_norm and self.norm_layer is not None: + self.norm = norm_layer(name="norm") + else: + self.norm = None + + def call(self, x): + x = self.proj(x) # shape: (B, H//P, W//P, C) + if self.data_format == "channels_first": + x = ops.transpose(x, [0, 2, 3, 1]) + x = ops.reshape(x, [ops.shape(x)[0], -1, self.embed_dim]) + if self.norm: + x = self.norm(x) + return x + + def get_config(self): + config = super().get_config() + config.update({ + "patch_size": self.patch_size, + "embed_dim": self.embed_dim, + "data_format": self.data_format, + "patch_norm": self.patch_norm, + "norm_layer": keras.saving.serialize_keras_object(self.norm_layer) + if self.norm_layer else None, + }) + return config + + @classmethod + def from_config(cls, config): + config["norm_layer"] = keras.saving.deserialize_keras_object(config["norm_layer"]) + return cls(**config) + + +class SwinTransformerStage(layers.Layer): + """Swin Transformer Stage. + + A stage consists of multiple Swin Transformer blocks with the same resolution, + and an optional patch merging layer at the beginning. + + Args: + dim: Number of input channels. + depth: Number of blocks in this stage. + num_heads: Number of attention heads. 
+ window_size: Local window size. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + qkv_bias: If True, add a learnable bias to query, key, value. + drop: Dropout rate. + attn_drop: Dropout rate for attention. + drop_path: Stochastic depth rate. + downsample: Downsample layer at the end of the layer. + """ + + def __init__( + self, + dim, + depth, + num_heads, + window_size=7, + mlp_ratio=4.0, + qkv_bias=True, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + downsample=None, + input_resolution=None, + **kwargs, + ): + super().__init__(**kwargs) + self.dim = dim + self.depth = depth + self.window_size = window_size + self.num_heads = num_heads + self.mlp_ratio = mlp_ratio + self.blocks = [] + self.downsample = downsample + self._drop_path = drop_path + self._qkv_bias = qkv_bias + self._drop = drop + self._attn_drop = attn_drop + self.input_resolution = input_resolution + + def build(self, input_shape): + for i in range(self.depth): + self.blocks.append( + SwinTransformerBlock( + dim=self.dim, + input_resolution=self.input_resolution, + num_heads=self.num_heads, + window_size=self.window_size, + shift_size=0 if (i % 2 == 0) else self.window_size // 2, + mlp_ratio=self.mlp_ratio, + qkv_bias=self._qkv_bias, + drop=self._drop, + attn_drop=self._attn_drop, + drop_path=self._drop_path[i] if isinstance(self._drop_path, list) else self._drop_path, + name=f"blocks_{i}", + ) + ) + + if self.downsample is not None: + self.downsample = self.downsample( + dim=self.dim, + name="downsample", + ) + + super().build(input_shape) + + def call(self, x): + """Forward pass. + + Args: + x: Input feature with shape [B, H*W, C]. + + Returns: + Output feature with shape [B, H/2*W/2, 2*C] if downsample is applied, + otherwise [B, H*W, C]. + """ + for block in self.blocks: + x = block(x) + + if self.downsample is not None: + H, W = self.input_resolution + x = self.downsample(x, H=H, W=W) + + return x + + def get_config(self): + config = super().get_config() + config.update({ + "dim": self.dim, + "depth": self.depth, + "num_heads": self.num_heads, + "window_size": self.window_size, + "mlp_ratio": self.mlp_ratio, + "qkv_bias": self._qkv_bias, + "drop": self._drop, + "attn_drop": self._attn_drop, + "drop_path": self._drop_path, + "downsample": keras.utils.serialize_keras_object(self.downsample) if self.downsample else None, + }) + return config + + @classmethod + def from_config(cls, config): + config["downsample"] = keras.utils.deserialize_keras_object(config["downsample"]) if config["downsample"] else None + return cls(**config) + From 874c524fa735f04ecf5470a15dd31a688639bfd1 Mon Sep 17 00:00:00 2001 From: ado Date: Fri, 13 Jun 2025 09:53:30 -0400 Subject: [PATCH 3/7] Deleted Swin Transformer files, remove typo --- .../swin_transformers_backbone.py | 143 ---- .../swin_transformers_backbone_test.py | 51 -- .../swin_transformers_layers.py | 704 ------------------ 3 files changed, 898 deletions(-) delete mode 100644 keras_hub/src/models/swin_transformers/swin_transformers_backbone.py delete mode 100644 keras_hub/src/models/swin_transformers/swin_transformers_backbone_test.py delete mode 100644 keras_hub/src/models/swin_transformers/swin_transformers_layers.py diff --git a/keras_hub/src/models/swin_transformers/swin_transformers_backbone.py b/keras_hub/src/models/swin_transformers/swin_transformers_backbone.py deleted file mode 100644 index 6bc9b80b97..0000000000 --- a/keras_hub/src/models/swin_transformers/swin_transformers_backbone.py +++ /dev/null @@ -1,143 +0,0 @@ -import keras -from keras import layers -from 
keras_hub.src.api_export import keras_hub_export -from keras_hub.src.models.backbone import Backbone -from keras_hub.src.models.swin_transformers.swin_transformers_layers import ( - PatchEmbedding, - SwinTransformerStage, - PatchMerging -) - -@keras_hub_export("keras_hub.models.SwinTransformersBackbone") -class SwinTransformersBackbone(Backbone): - """Swin Transformer backbone. - - This backbone implements the Swin Transformer architecture as described in - [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030). - - The Swin Transformer is a hierarchical vision transformer that uses shifted - windows for self-attention computation. It has several advantages: - - 1. Hierarchical feature maps with downsampling like CNNs - 2. Linear computational complexity with respect to image size - 3. Support for various vision tasks, including image classification, - object detection, and semantic segmentation - - Args: - image_shape: A tuple or list of 3 integers representing the shape of the - input image `(height, width, channels)`. - patch_size: int. The size of each patch (both height and width). - embed_dim: int. The embedding dimension for the first stage. - depths: list of ints. Number of transformer blocks in each stage. - num_heads: list of ints. Number of attention heads in each stage. - window_size: int. Size of attention window (both height and width). - mlp_ratio: float. Ratio of MLP hidden dimension to embedding dimension. - qkv_bias: bool. If True, add a learnable bias to query, key, value. - dropout_rate: float. Dropout rate for embedding and transformer layers. - attention_dropout: float. Dropout rate for attention projections. - path_dropout: float. Stochastic depth rate for transformer blocks. - patch_norm: bool. If True, add normalization after patch embedding. - data_format: str. One of `"channels_last"` or `"channels_first"`. - dtype: The dtype of the layer weights. Defaults to None. - **kwargs: Additional keyword arguments to be passed to the parent - `Backbone` class. 
- """ - - def __init__( - self, - image_shape=(224, 224, 3), - patch_size=4, - embed_dim=96, - depths=[2, 2, 6, 2], - num_heads=[3, 6, 12, 24], - window_size=7, - mlp_ratio=4.0, - qkv_bias=True, - dropout_rate=0.0, - attention_dropout=0.0, - path_dropout=0.2, - patch_norm=True, - data_format="channels_last", - dtype=None, - **kwargs, - ): - if len(depths) != len(num_heads): - raise ValueError( - f"Length of depths ({len(depths)}) must match " - f"length of num_heads ({len(num_heads)})" - ) - - self.patch_embedding = PatchEmbedding( - patch_size=patch_size, - embed_dim=embed_dim, - data_format=data_format, - patch_norm=patch_norm, - name="patch_embedding" - ) - - self.pos_dropout = layers.Dropout(dropout_rate, name="pos_dropout") if dropout_rate > 0.0 else None - - self.stages = [] - for i, (depth, num_head) in enumerate(zip(depths, num_heads)): - dim = embed_dim * (2 ** i) - downsample = PatchMerging(dim=dim // 2, name=f"downsample_{i-1}") if i > 0 else None - - stage = SwinTransformerStage( - dim=dim, - depth=depth, - num_heads=num_head, - window_size=window_size, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - dropout_rate=dropout_rate, - attention_dropout=attention_dropout, - path_dropout=path_dropout, - downsample=downsample, - name=f"stage_{i}" - ) - self.stages.append(stage) - - self.norm = layers.LayerNormalization(epsilon=1e-5, name="norm") - - inputs = keras.layers.Input(shape=image_shape) - x = self.patch_embedding(inputs) - if self.pos_dropout is not None: - x = self.pos_dropout(x) - for stage in self.stages: - x = stage(x) - x = self.norm(x) - - super().__init__(inputs=inputs, outputs=x, dtype=dtype, **kwargs) - - self.data_format = data_format - self.image_shape = image_shape - self.patch_size = patch_size - self.embed_dim = embed_dim - self.depths = depths - self.num_heads = num_heads - self.window_size = window_size - self.mlp_ratio = mlp_ratio - self.qkv_bias = qkv_bias - self.dropout_rate = dropout_rate - self.attention_dropout = attention_dropout - self.path_dropout = path_dropout - self.patch_norm = patch_norm - - def get_config(self): - config = super().get_config() - config.update({ - "image_shape": self.image_shape, - "patch_size": self.patch_size, - "embed_dim": self.embed_dim, - "depths": self.depths, - "num_heads": self.num_heads, - "window_size": self.window_size, - "mlp_ratio": self.mlp_ratio, - "qkv_bias": self.qkv_bias, - "dropout_rate": self.dropout_rate, - "attention_dropout": self.attention_dropout, - "path_dropout": self.path_dropout, - "patch_norm": self.patch_norm, - "data_format": self.data_format, - }) - return config \ No newline at end of file diff --git a/keras_hub/src/models/swin_transformers/swin_transformers_backbone_test.py b/keras_hub/src/models/swin_transformers/swin_transformers_backbone_test.py deleted file mode 100644 index 48e04e1ecf..0000000000 --- a/keras_hub/src/models/swin_transformers/swin_transformers_backbone_test.py +++ /dev/null @@ -1,51 +0,0 @@ -import pytest -from keras import ops - -from keras_hub.src.models.swin_transformers.swin_transformers_backbone import ( - SwinTransformersBackbone, -) -from keras_hub.src.tests.test_case import TestCase - -class SwinTransformersBackboneTest(TestCase): - def setUp(self): - self.init_kwargs = { - "image_shape": (64, 64, 3), - "patch_size": 2, - "embed_dim": 32, - "depths": [1, 1, 1, 1], - "num_heads": [1, 2, 4, 8], - "window_size": 4, - "mlp_ratio": 4.0, - "qkv_bias": True, - "dropout_rate": 0.0, - "attention_dropout": 0.0, - "path_dropout": 0.1, - "patch_norm": True, - "data_format": 
"channels_last", - "dtype": "float32", - } - self.input_data = ops.ones((2, 64, 64, 3), dtype="float32") - - def test_backbone_basics(self): - self.run_backbone_test( - cls=SwinTransformersBackbone, - init_kwargs=self.init_kwargs, - input_data=self.input_data, - expected_output_shape=(2, 2, 2, 256), - ) - - @pytest.mark.large - def test_saved_model(self): - self.run_model_saving_test( - cls=SwinTransformersBackbone, - init_kwargs=self.init_kwargs, - input_data=self.input_data, - ) - - @pytest.mark.large - def test_smallest_preset(self): - pass # Will be added in a future PR when presets are implemented - - @pytest.mark.extra_large - def test_all_presets(self): - pass # Will be added in a future PR when presets are implemented diff --git a/keras_hub/src/models/swin_transformers/swin_transformers_layers.py b/keras_hub/src/models/swin_transformers/swin_transformers_layers.py deleted file mode 100644 index 0352eba520..0000000000 --- a/keras_hub/src/models/swin_transformers/swin_transformers_layers.py +++ /dev/null @@ -1,704 +0,0 @@ -import keras -from keras import layers -from keras import ops -import collections.abc - -def window_partition(x, window_size): - """Partition the input tensor into non-overlapping windows.""" - batch_size, height, width, channels = ops.shape(x) - - x = ops.reshape( - x, - ( - batch_size, - height // window_size, - window_size, - width // window_size, - window_size, - channels - ) - ) - - x = ops.transpose(x, (0, 1, 3, 2, 4, 5)) - windows = ops.reshape( - x, (-1, window_size, window_size, channels) - ) - return windows - - -def window_reverse(windows, window_size, height, width, channels): - """Reverse window partitioning.""" - batch_size = ops.shape(windows)[0] // ((height // window_size) * (width // window_size)) - - x = ops.reshape( - windows, - ( - batch_size, - height // window_size, - width // window_size, - window_size, - window_size, - channels - ) - ) - - x = ops.transpose(x, (0, 1, 3, 2, 4, 5)) - x = ops.reshape(x, (batch_size, height, width, channels)) - return x - - -class DropPath(layers.Layer): - """Drop paths (Stochastic Depth) per sample. - - This is an implementation of the paper "Deep Networks with Stochastic Depth", - which randomly drops entire layers for regularization. - - Args: - drop_prob: float, probability of dropping path. - """ - - def __init__(self, drop_prob=0.0, **kwargs): - super().__init__(**kwargs) - self.drop_prob = drop_prob - - def call(self, x, training=None): - if self.drop_prob == 0.0 or not training: - return x - - # Keep probability - keep_prob = 1.0 - self.drop_prob - - # Create binary mask with shape [batch_size, 1, 1, 1] - batch_size = ops.shape(x)[0] - random_tensor = keep_prob + ops.random.uniform((batch_size, 1, 1, 1), dtype=x.dtype) - binary_mask = ops.floor(random_tensor) - - # Scale output to preserve expected value - output = x / keep_prob * binary_mask - return output - - def get_config(self): - config = super().get_config() - config.update({"drop_prob": self.drop_prob}) - return config - - -class Mlp(layers.Layer): - """MLP module for Transformer. - - Args: - in_features: Input dimension. - hidden_features: Hidden dimension. - out_features: Output dimension. - dropout_rate: Dropout rate. 
- """ - - def __init__( - self, - in_features, - hidden_features=None, - out_features=None, - dropout_rate=0.0, - **kwargs, - ): - super().__init__(**kwargs) - out_features = out_features or in_features - hidden_features = hidden_features or in_features - - self.in_features = in_features - self.hidden_features = hidden_features - self.out_features = out_features - self.dropout_rate = dropout_rate - - self.fc1 = layers.Dense(hidden_features, name="fc1") - self.act = keras.activations.gelu - self.fc2 = layers.Dense(out_features, name="fc2") - self.drop = layers.Dropout(dropout_rate) if dropout_rate > 0.0 else None - - def call(self, x): - x = self.fc1(x) - x = self.act(x) - if self.drop is not None: - x = self.drop(x) - x = self.fc2(x) - if self.drop is not None: - x = self.drop(x) - return x - - def get_config(self): - config = super().get_config() - config.update({ - "in_features": self.in_features, - "hidden_features": self.hidden_features, - "out_features": self.out_features, - "dropout_rate": self.dropout_rate, - }) - return config - - -class WindowAttention(layers.Layer): - """Window based multi-head self attention. - - Args: - dim: Number of input channels - window_size: Window size - num_heads: Number of attention heads - qkv_bias: Add bias to query, key, value projections - attention_dropout: Attention dropout rate - dropout: Dropout rate - """ - - def __init__( - self, - dim, - window_size, - num_heads, - qkv_bias=True, - attention_dropout=0., - dropout=0., - **kwargs - ): - super().__init__(**kwargs) - - self.dim = dim - self.window_size = ( - window_size - if isinstance(window_size, collections.abc.Iterable) - else (window_size, window_size) - ) - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.scale = self.head_dim ** -0.5 - - # Linear layers for Q, K, V - self.qkv = layers.Dense( - dim * 3, - use_bias=qkv_bias, - name="qkv" - ) - - # Relative position encoding - self.relative_position_bias_table = self.add_weight( - shape=((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads), - initializer="zeros", - trainable=True, - name="relative_position_bias_table" - ) - - # Get pair-wise relative position index - coords = ops.stack(ops.meshgrid( - ops.arange(self.window_size[0]), - ops.arange(self.window_size[1]) - )) - coords = ops.reshape(coords, [2, -1]) - relative_coords = coords[:, :, None] - coords[:, None, :] - relative_coords = ops.transpose(relative_coords, [1, 2, 0]) - - relative_coords = relative_coords + self.window_size[0] - 1 - relative_coords = relative_coords * (2 * self.window_size[0] - 1) - relative_position_index = ops.sum(relative_coords, -1) - - self.relative_position_index = relative_position_index - - self.attn_drop = layers.Dropout(attention_dropout) - self.proj = layers.Dense(dim) - self.proj_drop = layers.Dropout(dropout) - - def build(self, input_shape): - self.num_windows = input_shape[0] // ( - self.window_size[0] * self.window_size[1] - ) - super().build(input_shape) - - def call(self, x, mask=None): - """Forward pass. - - Args: - x: Input tensor with shape [batch*num_windows, window_size*window_size, dim]. - mask: Optional mask for shifted window attention. - - Returns: - Output tensor with shape [batch*num_windows, window_size*window_size, dim]. 
- """ - B_, N, C = ops.shape(x) - - # QKV projection - qkv = self.qkv(x) # [B_, N, 3*C] - - # Calculate exact dimensions - qkv_dim = ops.shape(qkv)[-1] - dim_per_head = C // self.num_heads - - # Split QKV - # This splits the last dimension into 3 equal parts - chunk_size = qkv_dim // 3 - q = qkv[:, :, :chunk_size] - k = qkv[:, :, chunk_size:2*chunk_size] - v = qkv[:, :, 2*chunk_size:] - - # Reshape to separate heads - q = ops.reshape(q, (B_, N, self.num_heads, dim_per_head)) - k = ops.reshape(k, (B_, N, self.num_heads, dim_per_head)) - v = ops.reshape(v, (B_, N, self.num_heads, dim_per_head)) - - # Transpose to [B_, num_heads, N, head_dim] - q = ops.transpose(q, (0, 2, 1, 3)) - k = ops.transpose(k, (0, 2, 1, 3)) - v = ops.transpose(v, (0, 2, 1, 3)) - - # Scale query - q = q * self.scale - - # Compute attention scores - attn = ops.matmul(q, ops.transpose(k, (0, 1, 3, 2))) - - # Add relative position bias - relative_position_bias = ops.take( - self.relative_position_bias_table, - self.relative_position_index, - ) - - relative_position_bias = ops.reshape( - relative_position_bias, - (self.window_size[0] * self.window_size[1], - self.window_size[0] * self.window_size[1], - self.num_heads) - ) - - # Transpose to [num_heads, Wh*Ww, Wh*Ww] - relative_position_bias = ops.transpose(relative_position_bias, (2, 0, 1)) - - # Add to attention [B_, num_heads, N, N] - attn = attn + ops.expand_dims(relative_position_bias, axis=0) - - # Apply attention mask if provided - if mask is not None: - nW = mask.shape[0] # num_windows - # attn: [B_/nW, nW, num_heads, N, N] - # mask: [1, nW, 1, N, N] - attn = ops.reshape(attn, (-1, nW, self.num_heads, N, N)) - mask = ops.expand_dims(mask, axis=1) # [nW, 1, N, N] -> [1, nW, 1, N, N] - attn = attn + ops.cast(mask, attn.dtype) * -100.0 - attn = ops.reshape(attn, (-1, self.num_heads, N, N)) - - # Softmax normalization and dropout - attn = ops.softmax(attn, axis=-1) - if self.attn_drop is not None: - attn = self.attn_drop(attn) - - # Apply attention to values - x = ops.matmul(attn, v) # [B_, num_heads, N, head_dim] - - # Transpose back to [B_, N, C] - x = ops.transpose(x, (0, 2, 1, 3)) - x = ops.reshape(x, (B_, N, C)) - - # Output projection and dropout - x = self.proj(x) - if self.proj_drop is not None: - x = self.proj_drop(x) - - return x - - def get_config(self): - config = super().get_config() - config.update({ - "dim": self.dim, - "window_size": self.window_size, - "num_heads": self.num_heads, - "qkv_bias": self.qkv_bias, - "attention_dropout": self.attention_dropout, - "dropout": self.dropout, - }) - return config - - -class SwinTransformerBlock(layers.Layer): - """Swin Transformer Block. - - Args: - dim: Number of input channels. - input_resolution: Input resolution (height, width). - num_heads: Number of attention heads. - window_size: Window size for attention. - shift_size: Shift size for shifted window attention (0 or window_size//2). - mlp_ratio: Ratio of mlp hidden dim to embedding dim. - qkv_bias: If True, add a learnable bias to query, key, value. - dropout_rate: Dropout rate. - attention_dropout: Dropout rate for attention. - path_dropout: Stochastic depth rate. - norm_layer: Normalization layer class. 
- """ - - def __init__( - self, - dim, - input_resolution=None, - num_heads=1, - window_size=7, - shift_size=0, - mlp_ratio=4.0, - qkv_bias=True, - dropout_rate=0.0, - attention_dropout=0.0, - path_dropout=0.0, - norm_layer=layers.LayerNormalization, - **kwargs, - ): - super().__init__(**kwargs) - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - - self.norm1 = norm_layer(epsilon=1e-5, name="norm1") - self.attn = WindowAttention( - dim=dim, - window_size=window_size, - num_heads=num_heads, - qkv_bias=qkv_bias, - attention_dropout=attention_dropout, - dropout=dropout_rate, - name="attn" - ) - self.drop_path = DropPath(path_dropout) if path_dropout > 0. else None - self.norm2 = norm_layer(epsilon=1e-5, name="norm2") - self.mlp = Mlp( - in_features=dim, - hidden_features=int(dim * mlp_ratio), - dropout_rate=dropout_rate, - name="mlp" - ) - - if self.shift_size > 0: - H, W = self.input_resolution - img_mask = ops.zeros((1, H, W, 1)) - - h_slices = [ - slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None) - ] - w_slices = [ - slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None) - ] - - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask_segment = ops.ones((1, H, W, 1)) - img_mask_segment = ops.index_update( - img_mask_segment, (..., h, w, ...), ops.ones((1, h.stop - h.start if h.stop else H - h.start, - w.stop - w.start if w.stop else W - w.start, 1)) * cnt - ) - img_mask = img_mask + img_mask_segment - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) - mask_windows = ops.reshape(mask_windows, (-1, self.window_size * self.window_size)) - attn_mask = ops.expand_dims(mask_windows, axis=1) - ops.expand_dims(mask_windows, axis=2) - attn_mask = ops.where(attn_mask != 0, -100.0, 0.0) - self.attn_mask = attn_mask - else: - self.attn_mask = None - - def call(self, x): - B, L, C = ops.shape(x) - H, W = self.input_resolution - - window_size = self.window_size - shift_size = self.shift_size - - if min(H, W) <= window_size: - window_size = min(H, W) - shift_size = 0 - - x = ops.reshape(x, (B, H, W, C)) - - if self.shift_size > 0: - shifted_x = ops.roll(x, shift=(-self.shift_size, -self.shift_size), axis=(1, 2)) - else: - shifted_x = x - - x_windows = window_partition(shifted_x, self.window_size) # [B*num_windows, window_size, window_size, C] - x_windows = ops.reshape(x_windows, (-1, self.window_size * self.window_size, C)) # [B*num_windows, window_size*window_size, C] - - identity = x_windows - - x_windows = self.norm1(x_windows) - attn_windows = self.attn(x_windows, mask=self.attn_mask) # [B*num_windows, window_size*window_size, C] - - if self.drop_path is not None: - attn_windows = self.drop_path(attn_windows) - - x_windows = identity + attn_windows - - identity = x_windows - x_windows = self.norm2(x_windows) - x_windows = self.mlp(x_windows) - - if self.drop_path is not None: - x_windows = self.drop_path(x_windows) - - x_windows = identity + x_windows - - x_windows = ops.reshape(x_windows, (-1, self.window_size, self.window_size, C)) - - if self.shift_size > 0: - x = window_reverse(x_windows, self.window_size, H, W, C) - x = ops.roll(x, shift=(self.shift_size, self.shift_size), axis=(1, 2)) - else: - x = window_reverse(x_windows, self.window_size, H, W, C) - - x = ops.reshape(x, (B, H * W, C)) - - return x - - def get_config(self): - 
config = super().get_config() - config.update({ - "dim": self.dim, - "input_resolution": self.input_resolution, - "num_heads": self.num_heads, - "window_size": self.window_size, - "shift_size": self.shift_size, - "mlp_ratio": self.mlp_ratio, - }) - return config - - def compute_output_shape(self, input_shape): - return input_shape - - -class PatchMerging(layers.Layer): - """Patch Merging Layer. - - This layer performs downsampling by concatenating patches and using a linear layer. - - Args: - dim: Number of input channels. - """ - - def __init__(self, dim, **kwargs): - super().__init__(**kwargs) - self.dim = dim - self.reduction = layers.Dense(2 * dim, use_bias=False, name="reduction") - self.norm = layers.LayerNormalization(epsilon=1e-5, name="norm") - - def call(self, x, H, W): - """Forward pass. - - Args: - x: Input tensor with shape [B, H*W, C]. - H: Height of feature map. - W: Width of feature map. - - Returns: - Downsampled feature map with shape [B, H/2*W/2, 2*C]. - """ - B, L, C = ops.shape(x) - - x = ops.reshape(x, (B, H, W, C)) - pad_values = ((0, 0), (0, H % 2), (0, W % 2), (0, 0)) - x = ops.pad(x, pad_values) - - # Reshape to group patches - x0 = x[:, 0::2, 0::2, :] - x1 = x[:, 1::2, 0::2, :] - x2 = x[:, 0::2, 1::2, :] - x3 = x[:, 1::2, 1::2, :] - - x = ops.concatenate([x0, x1, x2, x3], axis=-1) - x = self.norm(x) - x = self.reduction(x) - x = ops.reshape(x, (B, -1, 2 * C)) - - return x - - def get_config(self): - config = super().get_config() - config.update({"dim": self.dim}) - return config - - def compute_output_shape(self, input_shape): - batch_size, seq_len, channels = input_shape - return (batch_size, seq_len // 4, channels * 2) - - -class PatchEmbedding(layers.Layer): - """Image to Patch Embedding. - - Args: - patch_size: Size of each patch. - embed_dim: Embedding dimension. - norm_layer: Normalization layer. - data_format: Format of the input data, either "channels_last" or "channels_first". - patch_norm: If True, add normalization after patch embedding. - """ - - def __init__( - self, - patch_size=4, - embed_dim=96, - norm_layer=None, - data_format="channels_last", - patch_norm=False, - **kwargs, - ): - super().__init__(**kwargs) - self.patch_size = patch_size - self.embed_dim = embed_dim - self.data_format = data_format - - self.proj = layers.Conv2D( - embed_dim, - kernel_size=patch_size, - strides=patch_size, - padding="valid", - data_format=data_format, - name="proj", - ) - - self.norm = norm_layer(epsilon=1e-5, name="norm") if patch_norm and norm_layer else None - - def call(self, x): - """Forward pass. - - Args: - x: Input images with shape [B, H, W, C] in channels_last format - or [B, C, H, W] in channels_first format. - - Returns: - Patch embeddings with shape [B, H//patch_size * W//patch_size, embed_dim]. - """ - B = ops.shape(x)[0] - - x = self.proj(x) - - if self.data_format == "channels_last": - _, H, W, C = ops.shape(x) - x = ops.reshape(x, (B, H * W, C)) - else: - _, C, H, W = ops.shape(x) - x = ops.transpose(x, (0, 2, 3, 1)) # [B, H, W, C] - x = ops.reshape(x, (B, H * W, C)) - - if self.norm is not None: - x = self.norm(x) - - return x - - def get_config(self): - config = super().get_config() - config.update({ - "patch_size": self.patch_size, - "embed_dim": self.embed_dim, - "data_format": self.data_format, - }) - return config - - -class SwinTransformerStage(layers.Layer): - """Swin Transformer Stage. - - A stage consists of multiple Swin Transformer blocks with the same resolution, - and an optional patch merging layer at the beginning. 
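# NumPy sketch of the 2x2 patch-merging step above: gather the four spatial
# neighbours of every 2x2 group, concatenate them on the channel axis, and
# (in the real layer) a Dense projection then maps 4*C down to 2*C.
import numpy as np

B, H, W, C = 1, 4, 4, 3
x = np.arange(B * H * W * C, dtype="float32").reshape(B, H, W, C)
x0 = x[:, 0::2, 0::2, :]   # top-left of every 2x2 group
x1 = x[:, 1::2, 0::2, :]   # bottom-left
x2 = x[:, 0::2, 1::2, :]   # top-right
x3 = x[:, 1::2, 1::2, :]   # bottom-right
merged = np.concatenate([x0, x1, x2, x3], axis=-1)
print(merged.shape)        # (1, 2, 2, 12): half the resolution, 4x the channels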
- - Args: - dim: Number of input channels. - depth: Number of blocks in this stage. - num_heads: Number of attention heads. - window_size: Local window size. - mlp_ratio: Ratio of mlp hidden dim to embedding dim. - qkv_bias: If True, add a learnable bias to query, key, value. - dropout_rate: Dropout rate. - attention_dropout: Dropout rate for attention. - path_dropout: Stochastic depth rate. - downsample: Downsample layer at the end of the layer. - """ - - def __init__( - self, - dim, - depth, - num_heads, - window_size=7, - mlp_ratio=4.0, - qkv_bias=True, - dropout_rate=0.0, - attention_dropout=0.0, - path_dropout=0.0, - downsample=None, - **kwargs, - ): - super().__init__(**kwargs) - self.dim = dim - self.depth = depth - self.window_size = window_size - self.num_heads = num_heads - self.mlp_ratio = mlp_ratio - - self.blocks = [] - for i in range(depth): - self.blocks.append( - SwinTransformerBlock( - dim=dim, - input_resolution=None, - num_heads=num_heads, - window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - dropout_rate=dropout_rate, - attention_dropout=attention_dropout, - path_dropout=path_dropout[i] if isinstance(path_dropout, list) else path_dropout, - name=f"blocks_{i}", - ) - ) - - self.downsample = downsample - - def call(self, x): - """Forward pass. - - Args: - x: Input feature with shape [B, H*W, C]. - - Returns: - Output feature with shape [B, H/2*W/2, 2*C] if downsample is applied, - otherwise [B, H*W, C]. - """ - B, L, C = ops.shape(x) - - H_W = ops.cast(ops.sqrt(ops.cast(L, "float32")), "int32") - - for block in self.blocks: - block.input_resolution = (H_W, H_W) - - for block in self.blocks: - x = block(x) - - if self.downsample is not None: - x = self.downsample(x, H_W, H_W) - - return x - - def get_config(self): - config = super().get_config() - config.update({ - "dim": self.dim, - "depth": self.depth, - "window_size": self.window_size, - "mlp_ratio": self.mlp_ratio, - }) - return config - - def compute_output_shape(self, input_shape): - batch_size, seq_len, channels = input_shape - if self.downsample is not None: - return (batch_size, seq_len // 4, channels * 2) - return input_shape From a4f65954ef1899f1050d71c837f70d23b118fcc7 Mon Sep 17 00:00:00 2001 From: ado Date: Tue, 17 Jun 2025 14:35:02 -0400 Subject: [PATCH 4/7] Fixed dtype issue --- .../models/swin_transformer/swin_transformer_backbone.py | 6 +++--- .../swin_transformer/swin_transformer_backbone_test.py | 4 +++- .../src/models/swin_transformer/swin_transformer_layers.py | 2 +- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/keras_hub/src/models/swin_transformer/swin_transformer_backbone.py b/keras_hub/src/models/swin_transformer/swin_transformer_backbone.py index 89d1fa6776..cc9fc48f44 100644 --- a/keras_hub/src/models/swin_transformer/swin_transformer_backbone.py +++ b/keras_hub/src/models/swin_transformer/swin_transformer_backbone.py @@ -82,7 +82,8 @@ def __init__( dtype=None, **kwargs, ): - dtype = dtype or keras.backend.floatx() + if dtype is None: + dtype = keras.backend.floatx() # === Layers === self.patch_embedding = PatchEmbedding( @@ -148,11 +149,10 @@ def reshape_and_norm(tensor, norm_layer=self.norm_layers[i]): x_reshaped = keras.layers.Lambda(reshape_and_norm)(x) features.append(x_reshaped) - super().__init__( inputs=inputs, outputs=features[-1], - dtype=None, + dtype=dtype, **kwargs ) diff --git a/keras_hub/src/models/swin_transformer/swin_transformer_backbone_test.py 
b/keras_hub/src/models/swin_transformer/swin_transformer_backbone_test.py index 569cc4777f..2c9101497e 100644 --- a/keras_hub/src/models/swin_transformer/swin_transformer_backbone_test.py +++ b/keras_hub/src/models/swin_transformer/swin_transformer_backbone_test.py @@ -8,7 +8,7 @@ class SwinTransformerBackboneTest(TestCase): def setUp(self): - mixed_precision.set_global_policy("float32") + super().setUp() self.init_kwargs = { "image_shape": (224, 224, 3), "patch_size": 4, @@ -26,6 +26,8 @@ def test_backbone_basics(self): init_kwargs=self.init_kwargs, input_data=self.input_data, expected_output_shape = (1, 7, 7, 768), + run_mixed_precision_check=False, + run_quantization_check=False, ) @pytest.mark.large diff --git a/keras_hub/src/models/swin_transformer/swin_transformer_layers.py b/keras_hub/src/models/swin_transformer/swin_transformer_layers.py index 45b494c63c..cf4b16227d 100644 --- a/keras_hub/src/models/swin_transformer/swin_transformer_layers.py +++ b/keras_hub/src/models/swin_transformer/swin_transformer_layers.py @@ -552,7 +552,6 @@ def __init__( norm_layer=None, data_format="channels_last", patch_norm=True, - **kwargs, ): super().__init__(**kwargs) @@ -620,6 +619,7 @@ class SwinTransformerStage(layers.Layer): attn_drop: Dropout rate for attention. drop_path: Stochastic depth rate. downsample: Downsample layer at the end of the layer. + input_resolution: Input resolution. """ def __init__( From 3f55d9ba5279c3a62346498fb10276459102d6e3 Mon Sep 17 00:00:00 2001 From: ado Date: Wed, 18 Jun 2025 09:50:39 -0400 Subject: [PATCH 5/7] Fix formatting and undefined variable errors --- .../swin_transformer_backbone.py | 65 ++-- .../swin_transformer_backbone_test.py | 7 +- .../swin_transformer_layers.py | 297 ++++++++++-------- 3 files changed, 214 insertions(+), 155 deletions(-) diff --git a/keras_hub/src/models/swin_transformer/swin_transformer_backbone.py b/keras_hub/src/models/swin_transformer/swin_transformer_backbone.py index cc9fc48f44..e3719b59c3 100644 --- a/keras_hub/src/models/swin_transformer/swin_transformer_backbone.py +++ b/keras_hub/src/models/swin_transformer/swin_transformer_backbone.py @@ -1,17 +1,24 @@ import keras from keras import layers from keras import ops + from keras_hub.src.api_export import keras_hub_export from keras_hub.src.models.backbone import Backbone from keras_hub.src.models.swin_transformer.swin_transformer_layers import ( PatchEmbedding, +) +from keras_hub.src.models.swin_transformer.swin_transformer_layers import ( + PatchMerging, +) +from keras_hub.src.models.swin_transformer.swin_transformer_layers import ( SwinTransformerStage, - PatchMerging ) + def swin_kernel_initializer(stddev=0.02): return keras.initializers.TruncatedNormal(stddev=stddev) + @keras_hub_export("keras_hub.models.SwinTransformerBackbone") class SwinTransformerBackbone(Backbone): """A Swin Transformer backbone network. @@ -23,11 +30,14 @@ class SwinTransformerBackbone(Backbone): The default constructor gives a fully customizable, randomly initialized Swin Transformer with any number of layers, heads, and embedding dimensions. - To load preset architectures and weights, use the `from_preset()` constructor. + To load preset architectures and weights, use the `from_preset()` + constructor. Args: - image_shape: tuple of ints. The shape of the input images, excluding batch dimension. - patch_size: int. Size of the patches to be extracted from the input images. + image_shape: tuple of ints. The shape of the input images, excluding + batch dimension. + patch_size: int. 
Size of the patches to be extracted from the input + images. embed_dim: int. Base dimension of the transformer. depths: tuple of ints. Number of transformer blocks in each stage. num_heads: tuple of ints. Number of attention heads in each stage. @@ -38,7 +48,8 @@ class SwinTransformerBackbone(Backbone): attn_drop: float. Dropout rate for attention. drop_path: float. Stochastic depth rate. patch_norm: bool. If True, add normalization after patch embedding. - data_format: str. Format of the input data, either "channels_last" or "channels_first". + data_format: str. Format of the input data, either "channels_last" or + "channels_first". dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use for model computations and weights. @@ -107,7 +118,7 @@ def __init__( self.stages = [] for i in range(len(depths)): stage = SwinTransformerStage( - dim=int(embed_dim * 2 ** i), + dim=int(embed_dim * 2**i), depth=depths[i], num_heads=num_heads[i], window_size=window_size, @@ -126,7 +137,8 @@ def __init__( # Final norm self.norm_layers = [ - layers.LayerNormalization(epsilon=1e-5, name=f"norm_{i}") for i in range(len(depths)) + layers.LayerNormalization(epsilon=1e-5, name=f"norm_{i}") + for i in range(len(depths)) ] # Forward pass @@ -150,10 +162,7 @@ def reshape_and_norm(tensor, norm_layer=self.norm_layers[i]): features.append(x_reshaped) super().__init__( - inputs=inputs, - outputs=features[-1], - dtype=dtype, - **kwargs + inputs=inputs, outputs=features[-1], dtype=dtype, **kwargs ) # === Config === @@ -173,19 +182,21 @@ def reshape_and_norm(tensor, norm_layer=self.norm_layers[i]): def get_config(self): config = super().get_config() - config.update({ - "image_shape": self.image_shape, - "patch_size": self.patch_size, - "embed_dim": self.embed_dim, - "depths": self.depths, - "num_heads": self.num_heads, - "window_size": self.window_size, - "mlp_ratio": self.mlp_ratio, - "qkv_bias": self.qkv_bias, - "drop": self.drop, - "attn_drop": self.attn_drop, - "drop_path": self.drop_path, - "patch_norm": self.patch_norm, - "data_format": self.data_format, - }) - return config \ No newline at end of file + config.update( + { + "image_shape": self.image_shape, + "patch_size": self.patch_size, + "embed_dim": self.embed_dim, + "depths": self.depths, + "num_heads": self.num_heads, + "window_size": self.window_size, + "mlp_ratio": self.mlp_ratio, + "qkv_bias": self.qkv_bias, + "drop": self.drop, + "attn_drop": self.attn_drop, + "drop_path": self.drop_path, + "patch_norm": self.patch_norm, + "data_format": self.data_format, + } + ) + return config diff --git a/keras_hub/src/models/swin_transformer/swin_transformer_backbone_test.py b/keras_hub/src/models/swin_transformer/swin_transformer_backbone_test.py index 2c9101497e..0be4e8ff10 100644 --- a/keras_hub/src/models/swin_transformer/swin_transformer_backbone_test.py +++ b/keras_hub/src/models/swin_transformer/swin_transformer_backbone_test.py @@ -1,11 +1,12 @@ import pytest -from keras import ops, mixed_precision - +from keras import ops + from keras_hub.src.models.swin_transformer.swin_transformer_backbone import ( SwinTransformerBackbone, ) from keras_hub.src.tests.test_case import TestCase + class SwinTransformerBackboneTest(TestCase): def setUp(self): super().setUp() @@ -25,7 +26,7 @@ def test_backbone_basics(self): cls=SwinTransformerBackbone, init_kwargs=self.init_kwargs, input_data=self.input_data, - expected_output_shape = (1, 7, 7, 768), + expected_output_shape=(1, 7, 7, 768), run_mixed_precision_check=False, run_quantization_check=False, ) diff --git 
a/keras_hub/src/models/swin_transformer/swin_transformer_layers.py b/keras_hub/src/models/swin_transformer/swin_transformer_layers.py index cf4b16227d..706d3ab612 100644 --- a/keras_hub/src/models/swin_transformer/swin_transformer_layers.py +++ b/keras_hub/src/models/swin_transformer/swin_transformer_layers.py @@ -1,46 +1,53 @@ +import collections.abc +from typing import Any +from typing import Tuple +from typing import Union + import keras +import numpy as np from keras import layers from keras import ops -import collections.abc -from typing import Union, Tuple, Any -import numpy as np + def get_relative_position_index(win_h, win_w): """Get pair-wise relative position index for each token inside the window. - + Args: win_h: Height of the window. win_w: Width of the window. - + Returns: A tensor of shape (win_h*win_w, win_h*win_w) containing the relative position indices for each pair of tokens in the window. """ xx, yy = ops.meshgrid(ops.arange(win_h), ops.arange(win_w), indexing="ij") - coords = ops.stack([yy, xx], axis=0) - coords_flatten = ops.reshape(coords, (2, -1)) - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] - relative_coords = ops.transpose(relative_coords, (1, 2, 0)) + coords = ops.stack([yy, xx], axis=0) + coords_flatten = ops.reshape(coords, (2, -1)) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = ops.transpose(relative_coords, (1, 2, 0)) xx = (relative_coords[:, :, 0] + win_h - 1) * (2 * win_w - 1) yy = relative_coords[:, :, 1] + win_w - 1 relative_coords = ops.stack([xx, yy], axis=-1) - relative_position_index = ops.sum(relative_coords, axis=-1) + relative_position_index = ops.sum(relative_coords, axis=-1) return relative_position_index + def window_partition(x, window_size): """Partition the input tensor into non-overlapping windows. - + Args: x: Input tensor with shape [B, H, W, C] window_size: Size of the window - + Returns: Windows with shape [B*num_windows, window_size, window_size, C] """ shape = ops.shape(x) if len(shape) != 4: - raise ValueError(f"Expected input tensor to have 4 dimensions, got {len(shape)}") - + raise ValueError( + f"Expected input tensor to have 4 dimensions, got {len(shape)}" + ) + B = shape[0] H = shape[1] W = shape[2] @@ -51,41 +58,30 @@ def window_partition(x, window_size): x = ops.pad(x, [[0, 0], [0, pad_h], [0, pad_w], [0, 0]]) H = H + pad_h W = W + pad_w - + num_windows_h = H // window_size num_windows_w = W // window_size - + # Reshape to windows x = ops.reshape( - x, - ( - B, - num_windows_h, - window_size, - num_windows_w, - window_size, - C - ) - ) - x = ops.transpose(x, (0, 1, 3, 2, 4, 5)) - windows = ops.reshape( - x, - (-1, window_size, window_size, C) + x, (B, num_windows_h, window_size, num_windows_w, window_size, C) ) - + x = ops.transpose(x, (0, 1, 3, 2, 4, 5)) + windows = ops.reshape(x, (-1, window_size, window_size, C)) + return windows, (H, W) def window_reverse(windows, window_size, height, width, channels): """Reverse window partitioning. 
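# Quick check of the padding used by the reworked window_partition above, in
# plain Python. The exact pad_h/pad_w computation is not visible in this hunk,
# so the "distance to the next multiple of window_size" formula here is an
# assumption consistent with the ops.pad call shown.
window_size = 7
H = W = 30
pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
assert (H + pad_h) % window_size == 0 and (W + pad_w) % window_size == 0
print(pad_h, pad_w, (H + pad_h) // window_size)   # 5 5 5 -> a 5x5 grid of windows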
- + Args: windows: Windows with shape [B*num_windows, window_size, window_size, C] window_size: Size of the window height: Height of the feature map width: Width of the feature map channels: Number of channels - + Returns: Feature map with shape [B, H, W, C] """ @@ -93,8 +89,9 @@ def window_reverse(windows, window_size, height, width, channels): num_windows_h = height // window_size num_windows_w = width // window_size batch_size = ops.shape(windows)[0] // (num_windows_h * num_windows_w) - - # Reshape windows to [B, num_windows_h, num_windows_w, window_size, window_size, C] + + # Reshape windows to [B, num_windows_h, num_windows_w, window_size, + # window_size, C] x = ops.reshape( windows, ( @@ -103,23 +100,23 @@ def window_reverse(windows, window_size, height, width, channels): num_windows_w, window_size, window_size, - channels - ) + channels, + ), ) - + # Permute dimensions to get [B, H, W, C] x = ops.transpose(x, (0, 1, 3, 2, 4, 5)) x = ops.reshape(x, (batch_size, height, width, channels)) - + return x class DropPath(layers.Layer): """Drop paths (Stochastic Depth) per sample. - - This is an implementation of the paper "Deep Networks with Stochastic Depth", - which randomly drops entire layers for regularization. - + + This is an implementation of the paper "Deep Networks with Stochastic + Depth", which randomly drops entire layers for regularization. + Args: drop_prob: float, probability of dropping path. """ @@ -132,7 +129,7 @@ def call(self, x, training=None): if self.drop_prob == 0.0 or not training: return x keep_prob = 1.0 - self.drop_prob - + batch_size = ops.shape(x)[0] random_tensor = keep_prob + ops.random.uniform((batch_size, 1, 1, 1)) binary_mask = ops.floor(random_tensor) @@ -191,13 +188,15 @@ def call(self, x): def get_config(self): config = super().get_config() - config.update({ - "in_features": self.in_features, - "hidden_features": self.hidden_features, - "out_features": self.out_features, - "act_layer": keras.activations.serialize(self.act_layer), - "dropout_rate": self.dropout_rate, - }) + config.update( + { + "in_features": self.in_features, + "hidden_features": self.hidden_features, + "out_features": self.out_features, + "act_layer": keras.activations.serialize(self.act_layer), + "dropout_rate": self.dropout_rate, + } + ) return config @classmethod @@ -207,17 +206,21 @@ def from_config(cls, config): class WindowAttention(keras.layers.Layer): - """Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. + """Window based multi-head self attention (W-MSA) module with relative + position bias. It supports both of shifted and non-shifted window. Args: dim (int): Number of input channels. num_heads (int): Number of attention heads. - head_dim (int): Number of channels per head (dim // num_heads if not set) + head_dim (int): Number of channels per head (dim // num_heads if not + set) window_size (tuple[int]): The height and width of the window. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float, optional): Override default scaling factor for queries and keys (default: head_dim ** -0.5) - attn_drop (float, optional): Dropout ratio of attention weights. Default: 0.0 + qkv_bias (bool, optional): If True, add a learnable bias to query, + key, value. 
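# NumPy sketch of the stochastic-depth trick in DropPath above: a per-sample
# binary mask drops whole residual branches at training time, and dividing by
# keep_prob keeps the expected value of the output unchanged (illustration only).
import numpy as np

drop_prob, keep_prob = 0.2, 0.8
rng = np.random.default_rng(0)
x = np.ones((10000, 1, 1, 1), dtype="float32")
mask = np.floor(keep_prob + rng.uniform(size=x.shape))   # 1 with prob keep_prob
out = x / keep_prob * mask
print(out.mean())   # ~1.0, matching the identity path in expectation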
Default: True + qk_scale (float, optional): Override default scaling factor for + queries and keys (default: head_dim ** -0.5) + attn_drop (float, optional): Dropout ratio of attention weights. + Default: 0.0 proj_drop (float, optional): Dropout ratio of output. Default: 0.0 """ @@ -244,20 +247,21 @@ def __init__( self.window_area = self.win_h * self.win_w self.num_heads = num_heads self.head_dim = head_dim or (dim // num_heads) - self.scale = qk_scale if qk_scale is not None else self.head_dim ** -0.5 + self.scale = qk_scale if qk_scale is not None else self.head_dim**-0.5 self.attn_dim = self.head_dim * self.num_heads self.qkv_bias = qkv_bias self.attn_drop_rate = attn_drop self.proj_drop_rate = proj_drop self.relative_position_index = get_relative_position_index( - win_h=self.win_h, - win_w=self.win_w + win_h=self.win_h, win_w=self.win_w ) def build(self, input_shape): self.qkv = keras.layers.Dense( - self.head_dim * self.num_heads * 3, use_bias=self.qkv_bias, name="attention_qkv" + self.head_dim * self.num_heads * 3, + use_bias=self.qkv_bias, + name="attention_qkv", ) self.attn_drop = keras.layers.Dropout(self.attn_drop_rate) self.proj = keras.layers.Dense(self.dim, name="attention_projection") @@ -285,7 +289,8 @@ def call( """ Args: x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or + None """ B_, N, C = ops.shape(x)[0], ops.shape(x)[1], ops.shape(x)[2] qkv = self.qkv(x) @@ -355,12 +360,15 @@ class SwinTransformerBlock(keras.layers.Layer): window_size (int): Window size. shift_size (int): Shift size for SW-MSA. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qkv_bias (bool, optional): If True, add a learnable bias to query, + key, value. Default: True drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (keras.layers.Layer, optional): Activation layer. Default: keras.layers.Activation("gelu") - norm_layer (keras.layers.Layer, optional): Normalization layer. Default: keras.layers.LayerNormalization(epsilon=1e-5) + act_layer (keras.layers.Layer, optional): Activation layer. Default: + keras.layers.Activation("gelu") + norm_layer (keras.layers.Layer, optional): Normalization layer. 
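# Pseudo-wiring sketch of how the block composes its two pre-norm residual
# branches; norm1, attn, norm2, mlp and drop_path stand for the layers defined
# in this file, and the names here are placeholders, not the exact call graph.
def swin_block(x, norm1, attn, norm2, mlp, drop_path):
    x = x + drop_path(attn(norm1(x)))   # (shifted) window attention branch
    x = x + drop_path(mlp(norm2(x)))    # MLP branch
    return x

# with both branches zeroed out the block reduces to the identity:
print(swin_block(1.0, lambda t: t, lambda t: 0.0 * t,
                 lambda t: t, lambda t: 0.0 * t, lambda t: t))   # 1.0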
+ Default: keras.layers.LayerNormalization(epsilon=1e-5) """ def __init__( @@ -402,7 +410,9 @@ def __init__( proj_drop=drop, name="attn", ) - self.drop_path = DropPath(drop_path) if drop_path > 0.0 else keras.layers.Identity() + self.drop_path = ( + DropPath(drop_path) if drop_path > 0.0 else keras.layers.Identity() + ) self.norm2 = norm_layer(epsilon=1e-5, name="norm2") mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp( @@ -423,7 +433,9 @@ def call(self, x): attn_mask = None if self.shift_size > 0: - shifted_x = ops.roll(x, shift=(-self.shift_size, -self.shift_size), axis=(1, 2)) + shifted_x = ops.roll( + x, shift=(-self.shift_size, -self.shift_size), axis=(1, 2) + ) img_mask = np.zeros((1, H, W, 1), dtype=np.int32) cnt = 0 h_slices = [ @@ -438,26 +450,40 @@ def call(self, x): ] for h in h_slices: for w in w_slices: - img_mask[:, h[0]:h[1], w[0]:w[1], :] = cnt + img_mask[:, h[0] : h[1], w[0] : w[1], :] = cnt cnt += 1 img_mask = ops.convert_to_tensor(img_mask) mask_windows = window_partition(img_mask, self.window_size)[0] - mask_windows = ops.reshape(mask_windows, (-1, self.window_size * self.window_size)) - attn_mask = ops.expand_dims(mask_windows, 1) - ops.expand_dims(mask_windows, 2) + mask_windows = ops.reshape( + mask_windows, (-1, self.window_size * self.window_size) + ) + attn_mask = ops.expand_dims(mask_windows, 1) - ops.expand_dims( + mask_windows, 2 + ) attn_mask = ops.where(attn_mask != 0, -100.0, 0.0) else: shifted_x = x - x_windows, (H_pad, W_pad) = window_partition(x=shifted_x, window_size=self.window_size) - x_windows = ops.reshape(x_windows, (-1, self.window_size * self.window_size, C)) + x_windows, (H_pad, W_pad) = window_partition( + x=shifted_x, window_size=self.window_size + ) + x_windows = ops.reshape( + x_windows, (-1, self.window_size * self.window_size, C) + ) attn_windows = self.attn(x_windows, mask=attn_mask) - attn_windows = ops.reshape(attn_windows, (-1, self.window_size, self.window_size, C)) - shifted_x = window_reverse(attn_windows, self.window_size, H_pad, W_pad, C) + attn_windows = ops.reshape( + attn_windows, (-1, self.window_size, self.window_size, C) + ) + shifted_x = window_reverse( + attn_windows, self.window_size, H_pad, W_pad, C + ) if self.shift_size > 0: - x = ops.roll(shifted_x, shift=(self.shift_size, self.shift_size), axis=(1, 2)) + x = ops.roll( + shifted_x, shift=(self.shift_size, self.shift_size), axis=(1, 2) + ) else: x = shifted_x @@ -486,9 +512,10 @@ def get_config(self): class PatchMerging(layers.Layer): """Patch Merging Layer. - - This layer performs downsampling by concatenating patches and using a linear layer. - + + This layer performs downsampling by concatenating patches and using a + linear layer. + Args: dim: Number of input channels. """ @@ -501,32 +528,32 @@ def __init__(self, dim, **kwargs): def call(self, x, H, W): """Forward pass. - + Args: x: Input tensor with shape [B, H*W, C]. H: Height of feature map. W: Width of feature map. - + Returns: Downsampled feature map with shape [B, H/2*W/2, 2*C]. 
""" B, L, C = ops.shape(x) - + x = ops.reshape(x, (B, H, W, C)) pad_values = ((0, 0), (0, H % 2), (0, W % 2), (0, 0)) x = ops.pad(x, pad_values) - + # Reshape to group patches - x0 = x[:, 0::2, 0::2, :] - x1 = x[:, 1::2, 0::2, :] - x2 = x[:, 0::2, 1::2, :] - x3 = x[:, 1::2, 1::2, :] - + x0 = x[:, 0::2, 0::2, :] + x1 = x[:, 1::2, 0::2, :] + x2 = x[:, 0::2, 1::2, :] + x3 = x[:, 1::2, 1::2, :] + x = ops.concatenate([x0, x1, x2, x3], axis=-1) x = self.norm(x) x = self.reduction(x) x = ops.reshape(x, (B, -1, 2 * C)) - + return x def get_config(self): @@ -534,13 +561,15 @@ def get_config(self): config.update({"dim": self.dim}) return config + class PatchEmbedding(layers.Layer): """Image to Patch Embedding layer for Swin Transformer. Args: patch_size: int. Patch size (usually 4). embed_dim: int. Output embedding dimension. - norm_layer: Callable layer class for normalization (e.g., LayerNormalization). + norm_layer: Callable layer class for normalization (e.g., + LayerNormalization). data_format: str. Either "channels_last" or "channels_first". patch_norm: bool. Whether to apply normalization. """ @@ -586,28 +615,35 @@ def call(self, x): def get_config(self): config = super().get_config() - config.update({ - "patch_size": self.patch_size, - "embed_dim": self.embed_dim, - "data_format": self.data_format, - "patch_norm": self.patch_norm, - "norm_layer": keras.saving.serialize_keras_object(self.norm_layer) - if self.norm_layer else None, - }) + config.update( + { + "patch_size": self.patch_size, + "embed_dim": self.embed_dim, + "data_format": self.data_format, + "patch_norm": self.patch_norm, + "norm_layer": keras.saving.serialize_keras_object( + self.norm_layer + ) + if self.norm_layer + else None, + } + ) return config @classmethod def from_config(cls, config): - config["norm_layer"] = keras.saving.deserialize_keras_object(config["norm_layer"]) + config["norm_layer"] = keras.saving.deserialize_keras_object( + config["norm_layer"] + ) return cls(**config) class SwinTransformerStage(layers.Layer): """Swin Transformer Stage. - - A stage consists of multiple Swin Transformer blocks with the same resolution, - and an optional patch merging layer at the beginning. - + + A stage consists of multiple Swin Transformer blocks with the same + resolution, and an optional patch merging layer at the beginning. + Args: dim: Number of input channels. depth: Number of blocks in this stage. @@ -664,11 +700,13 @@ def build(self, input_shape): qkv_bias=self._qkv_bias, drop=self._drop, attn_drop=self._attn_drop, - drop_path=self._drop_path[i] if isinstance(self._drop_path, list) else self._drop_path, + drop_path=self._drop_path[i] + if isinstance(self._drop_path, list) + else self._drop_path, name=f"blocks_{i}", ) ) - + if self.downsample is not None: self.downsample = self.downsample( dim=self.dim, @@ -679,41 +717,50 @@ def build(self, input_shape): def call(self, x): """Forward pass. - + Args: x: Input feature with shape [B, H*W, C]. - + Returns: - Output feature with shape [B, H/2*W/2, 2*C] if downsample is applied, - otherwise [B, H*W, C]. + Output feature with shape [B, H/2*W/2, 2*C] if downsample is + applied, otherwise [B, H*W, C]. 
""" for block in self.blocks: x = block(x) - + if self.downsample is not None: H, W = self.input_resolution x = self.downsample(x, H=H, W=W) - + return x - + def get_config(self): config = super().get_config() - config.update({ - "dim": self.dim, - "depth": self.depth, - "num_heads": self.num_heads, - "window_size": self.window_size, - "mlp_ratio": self.mlp_ratio, - "qkv_bias": self._qkv_bias, - "drop": self._drop, - "attn_drop": self._attn_drop, - "drop_path": self._drop_path, - "downsample": keras.utils.serialize_keras_object(self.downsample) if self.downsample else None, - }) + config.update( + { + "dim": self.dim, + "depth": self.depth, + "num_heads": self.num_heads, + "window_size": self.window_size, + "mlp_ratio": self.mlp_ratio, + "qkv_bias": self._qkv_bias, + "drop": self._drop, + "attn_drop": self._attn_drop, + "drop_path": self._drop_path, + "downsample": keras.utils.serialize_keras_object( + self.downsample + ) + if self.downsample + else None, + } + ) return config @classmethod def from_config(cls, config): - config["downsample"] = keras.utils.deserialize_keras_object(config["downsample"]) if config["downsample"] else None + config["downsample"] = ( + keras.utils.deserialize_keras_object(config["downsample"]) + if config["downsample"] + else None + ) return cls(**config) - From 76d339a5bf64e716d2ccaa4dfd5bb3204b79981a Mon Sep 17 00:00:00 2001 From: ado Date: Fri, 18 Jul 2025 11:40:57 -0400 Subject: [PATCH 6/7] Added missing dropout, and initialize it directly as a Keras tensor --- .../swin_transformer_backbone.py | 21 ++++++---------- .../swin_transformer_backbone_test.py | 2 +- .../swin_transformer_layers.py | 25 +++++++++---------- 3 files changed, 20 insertions(+), 28 deletions(-) diff --git a/keras_hub/src/models/swin_transformer/swin_transformer_backbone.py b/keras_hub/src/models/swin_transformer/swin_transformer_backbone.py index e3719b59c3..fca0056cc8 100644 --- a/keras_hub/src/models/swin_transformer/swin_transformer_backbone.py +++ b/keras_hub/src/models/swin_transformer/swin_transformer_backbone.py @@ -106,12 +106,16 @@ def __init__( name="patch_embedding", ) + self.pos_drop = layers.Dropout(drop) + # Stochastic depth decay rule dpr = [float(x) for x in ops.linspace(0.0, drop_path, sum(depths))] # === Functional Model === inputs = keras.Input(shape=image_shape) x = self.patch_embedding(inputs) + x = self.pos_drop(x) + h, w = image_shape[0] // patch_size, image_shape[1] // patch_size # Build stages @@ -146,23 +150,12 @@ def __init__( for i, stage in enumerate(self.stages): x = stage(x) + features.append(self.norm_layers[i](x)) - def reshape_and_norm(tensor, norm_layer=self.norm_layers[i]): - shape = ops.shape(tensor) - B = shape[0] - L = shape[1] - C = shape[2] - H_float = ops.sqrt(ops.cast(L, x.dtype)) - H = ops.cast(H_float, "int32") - W = H - tensor = ops.reshape(tensor, (B, H, W, C)) - return norm_layer(tensor) - - x_reshaped = keras.layers.Lambda(reshape_and_norm)(x) - features.append(x_reshaped) + x = layers.LayerNormalization(epsilon=1e-5, name="output_norm")(x) super().__init__( - inputs=inputs, outputs=features[-1], dtype=dtype, **kwargs + inputs=inputs, outputs=x, dtype=dtype, **kwargs ) # === Config === diff --git a/keras_hub/src/models/swin_transformer/swin_transformer_backbone_test.py b/keras_hub/src/models/swin_transformer/swin_transformer_backbone_test.py index 0be4e8ff10..d2325083e6 100644 --- a/keras_hub/src/models/swin_transformer/swin_transformer_backbone_test.py +++ b/keras_hub/src/models/swin_transformer/swin_transformer_backbone_test.py @@ -26,7 
+26,7 @@ def test_backbone_basics(self): cls=SwinTransformerBackbone, init_kwargs=self.init_kwargs, input_data=self.input_data, - expected_output_shape=(1, 7, 7, 768), + expected_output_shape=(1, 49, 768), run_mixed_precision_check=False, run_quantization_check=False, ) diff --git a/keras_hub/src/models/swin_transformer/swin_transformer_layers.py b/keras_hub/src/models/swin_transformer/swin_transformer_layers.py index 706d3ab612..75eb32cbd4 100644 --- a/keras_hub/src/models/swin_transformer/swin_transformer_layers.py +++ b/keras_hub/src/models/swin_transformer/swin_transformer_layers.py @@ -436,32 +436,28 @@ def call(self, x): shifted_x = ops.roll( x, shift=(-self.shift_size, -self.shift_size), axis=(1, 2) ) - img_mask = np.zeros((1, H, W, 1), dtype=np.int32) + img_mask = ops.zeros((1, H, W, 1), dtype="int32") cnt = 0 h_slices = [ - (0, H // 2), - (H // 2, H - self.shift_size), + (0, ops.cast(H / 2, 'int32')), + (ops.cast(H / 2, 'int32'), H - self.shift_size), (H - self.shift_size, H), ] w_slices = [ - (0, W // 2), - (W // 2, W - self.shift_size), + (0, ops.cast(W / 2, 'int32')), + (ops.cast(W / 2, 'int32'), W - self.shift_size), (W - self.shift_size, W), ] for h in h_slices: for w in w_slices: - img_mask[:, h[0] : h[1], w[0] : w[1], :] = cnt + img_mask = ops.slice_update(img_mask, [0, h[0], w[0], 0], ops.ones((1, h[1] - h[0], w[1] - w[0], 1), dtype='int32') * cnt) cnt += 1 - img_mask = ops.convert_to_tensor(img_mask) mask_windows = window_partition(img_mask, self.window_size)[0] mask_windows = ops.reshape( mask_windows, (-1, self.window_size * self.window_size) ) - attn_mask = ops.expand_dims(mask_windows, 1) - ops.expand_dims( - mask_windows, 2 - ) - attn_mask = ops.where(attn_mask != 0, -100.0, 0.0) + attn_mask = ops.cast(ops.expand_dims(mask_windows, 1) != ops.expand_dims(mask_windows, 2), dtype='float32') * -100.0 else: shifted_x = x @@ -594,9 +590,10 @@ def __init__( filters=embed_dim, kernel_size=patch_size, strides=patch_size, - padding="valid", + padding="VALID", data_format=data_format, - name="proj", + name="conv_projection", + kernel_initializer="lecun_normal", ) if self.patch_norm and self.norm_layer is not None: @@ -608,11 +605,13 @@ def call(self, x): x = self.proj(x) # shape: (B, H//P, W//P, C) if self.data_format == "channels_first": x = ops.transpose(x, [0, 2, 3, 1]) + h, w = ops.shape(x)[1], ops.shape(x)[2] x = ops.reshape(x, [ops.shape(x)[0], -1, self.embed_dim]) if self.norm: x = self.norm(x) return x + def get_config(self): config = super().get_config() config.update( From c441c0a8a429ce27f3177feb68c36cbccbad2a5d Mon Sep 17 00:00:00 2001 From: ado Date: Fri, 18 Jul 2025 12:03:38 -0400 Subject: [PATCH 7/7] Added normalization init and name --- .../src/models/swin_transformer/swin_transformer_backbone.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/keras_hub/src/models/swin_transformer/swin_transformer_backbone.py b/keras_hub/src/models/swin_transformer/swin_transformer_backbone.py index fca0056cc8..665a73f2bc 100644 --- a/keras_hub/src/models/swin_transformer/swin_transformer_backbone.py +++ b/keras_hub/src/models/swin_transformer/swin_transformer_backbone.py @@ -144,6 +144,7 @@ def __init__( layers.LayerNormalization(epsilon=1e-5, name=f"norm_{i}") for i in range(len(depths)) ] + self.norm = layers.LayerNormalization(epsilon=1e-5, name="output_norm") # Forward pass features = [] @@ -152,7 +153,7 @@ def __init__( x = stage(x) features.append(self.norm_layers[i](x)) - x = layers.LayerNormalization(epsilon=1e-5, name="output_norm")(x) + x = 
self.norm(x) super().__init__( inputs=inputs, outputs=x, dtype=dtype, **kwargs