diff --git a/examples/configs/model/glass.yaml b/examples/configs/model/glass.yaml new file mode 100644 index 0000000000..b6df2f283b --- /dev/null +++ b/examples/configs/model/glass.yaml @@ -0,0 +1,26 @@ +model: + class_path: anomalib.models.Glass + init_args: + input_shape: [288, 288] + backbone: resnet18 + pretrain_embed_dim: 1024 + target_embed_dim: 1024 + patchsize: 3 + patchstride: 1 + pre_trained: true + pre_projection: 1 + discriminator_layers: 2 + discriminator_hidden: 1024 + discriminator_margin: 0.5 + learning_rate: 0.0001 + step: 20 + svd: 0 + +trainer: + max_epochs: 640 + callbacks: + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + patience: 5 + monitor: pixel_AUROC + mode: max \ No newline at end of file diff --git a/src/anomalib/models/__init__.py b/src/anomalib/models/__init__.py index f7ed76d1f4..2a3afadb91 100644 --- a/src/anomalib/models/__init__.py +++ b/src/anomalib/models/__init__.py @@ -67,6 +67,7 @@ Fastflow, Fre, Ganomaly, + Glass, Padim, Patchcore, ReverseDistillation, @@ -105,6 +106,7 @@ class UnknownModelError(ModuleNotFoundError): "Fastflow", "Fre", "Ganomaly", + "Glass", "Padim", "Patchcore", "ReverseDistillation", diff --git a/src/anomalib/models/image/__init__.py b/src/anomalib/models/image/__init__.py index d5adc65ead..8bfc00236e 100644 --- a/src/anomalib/models/image/__init__.py +++ b/src/anomalib/models/image/__init__.py @@ -56,6 +56,7 @@ from .fastflow import Fastflow from .fre import Fre from .ganomaly import Ganomaly +from .glass import Glass from .padim import Padim from .patchcore import Patchcore from .reverse_distillation import ReverseDistillation @@ -78,6 +79,7 @@ "Fastflow", "Fre", "Ganomaly", + "Glass", "Padim", "Patchcore", "ReverseDistillation", diff --git a/src/anomalib/models/image/glass/__init__.py b/src/anomalib/models/image/glass/__init__.py new file mode 100644 index 0000000000..cac6c015e8 --- /dev/null +++ b/src/anomalib/models/image/glass/__init__.py @@ -0,0 +1,23 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""GLASS - Unsupervised anomaly detection via Gradient Ascent for Industrial Anomaly detection and localization. + +This module implements the GLASS model for unsupervised anomaly detection and localization. GLASS synthesizes both +global and local anomalies using Gaussian noise guided by gradient ascent to enhance weak defect detection in +industrial settings. 
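For context, a minimal training sketch using the new model through anomalib's Python API. The `Engine` and `MVTecAD` names are assumptions based on the current anomalib API (adjust to your installed version and data), and the hyperparameters mirror `examples/configs/model/glass.yaml`; the same values can equally be supplied via that config file.

```python
# Minimal training sketch (assumed anomalib Engine/MVTecAD API); hyperparameters
# mirror examples/configs/model/glass.yaml.
from anomalib.data import MVTecAD  # assumed datamodule name; swap in your own data
from anomalib.engine import Engine
from anomalib.models import Glass

datamodule = MVTecAD(category="bottle")
model = Glass(
    input_shape=(288, 288),
    backbone="resnet18",
    pretrain_embed_dim=1024,
    target_embed_dim=1024,
    learning_rate=1e-4,
    step=20,
)
engine = Engine(max_epochs=640)
engine.fit(model=model, datamodule=datamodule)
```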
+ +The model consists of: + - A feature extractor and feature adaptor to obtain robust normal representations + - A Global Anomaly Synthesis (GAS) module that perturbs features using Gaussian noise and gradient ascent with + truncated projection + - A Local Anomaly Synthesis (LAS) module that overlays augmented textures onto images using Perlin noise masks + - A shared discriminator trained with features from normal, global, and local synthetic samples + +Paper: `A Unified Anomaly Synthesis Strategy with Gradient Ascent for Industrial Anomaly Detection and Localization +` +""" + +from .lightning_model import Glass + +__all__ = ["Glass"] diff --git a/src/anomalib/models/image/glass/components/__init__.py b/src/anomalib/models/image/glass/components/__init__.py new file mode 100644 index 0000000000..8695d68fe3 --- /dev/null +++ b/src/anomalib/models/image/glass/components/__init__.py @@ -0,0 +1,19 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Utility functions for GLASS Model.""" + +from .aggregator import Aggregator +from .discriminator import Discriminator +from .patch_maker import PatchMaker +from .preprocessing import Preprocessing +from .projection import Projection +from .rescale_segmentor import RescaleSegmentor + +__all__ = ["Aggregator", + "Discriminator", + "PatchMaker", + "Preprocessing", + "Projection", + "RescaleSegmentor", +] diff --git a/src/anomalib/models/image/glass/components/aggregator.py b/src/anomalib/models/image/glass/components/aggregator.py new file mode 100644 index 0000000000..86623876e2 --- /dev/null +++ b/src/anomalib/models/image/glass/components/aggregator.py @@ -0,0 +1,25 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Aggregates and reshapes features to a target dimension.""" + +import torch +import torch.nn.functional as f + + +class Aggregator(torch.nn.Module): + """Aggregates and reshapes features to a target dimension. + + Input: Multi-dimensional feature tensors + Output: Reshaped and pooled features of specified target dimension + """ + + def __init__(self, target_dim: int) -> None: + super().__init__() + self.target_dim = target_dim + + def forward(self, features: torch.Tensor) -> torch.Tensor: + """Returns reshaped and average pooled features.""" + features = features.reshape(len(features), 1, -1) + features = f.adaptive_avg_pool1d(features, self.target_dim) + return features.reshape(len(features), -1) diff --git a/src/anomalib/models/image/glass/components/discriminator.py b/src/anomalib/models/image/glass/components/discriminator.py new file mode 100644 index 0000000000..9c7d984fbb --- /dev/null +++ b/src/anomalib/models/image/glass/components/discriminator.py @@ -0,0 +1,52 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Discriminator network for anomaly detection.""" + +import torch + +from .init_weight import init_weight + + +class Discriminator(torch.nn.Module): + """Discriminator network for anomaly detection. 
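As a quick illustration of the `Aggregator` defined above, a shape sketch with arbitrary sizes: each patch feature is flattened and adaptively average-pooled down to `target_dim`.

```python
# Aggregator shape sketch: (N, levels, dim) -> flatten -> adaptive pool -> (N, target_dim).
import torch

from anomalib.models.image.glass.components import Aggregator

aggregator = Aggregator(target_dim=1024)
patch_features = torch.randn(8, 2, 1536)   # 8 patches, 2 backbone levels, pooled dim 1536
print(aggregator(patch_features).shape)    # torch.Size([8, 1024])
```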
+ + Args: + in_planes: Input feature dimension + n_layers: Number of layers + hidden: Hidden layer dimensions + """ + + def __init__(self, in_planes: int, n_layers: int = 2, hidden: int | None = None) -> None: + super().__init__() + + hidden_ = in_planes if hidden is None else hidden + self.body = torch.nn.Sequential() + for i in range(n_layers - 1): + in_ = in_planes if i == 0 else hidden_ + hidden_ = int(hidden_ // 1.5) if hidden is None else hidden + self.body.add_module( + f"block{i + 1}", + torch.nn.Sequential( + torch.nn.Linear(in_, hidden_), + torch.nn.BatchNorm1d(hidden_), + torch.nn.LeakyReLU(0.2), + ), + ) + self.tail = torch.nn.Sequential( + torch.nn.Linear(hidden_, 1, bias=False), + torch.nn.Sigmoid(), + ) + self.apply(init_weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Performs a forward pass through the discriminator network. + + Args: + x (torch.Tensor): Input tensor of shape (B, in_planes), where B is the batch size. + + Returns: + torch.Tensor: Output tensor of shape (B, 1) containing probability scores. + """ + x = self.body(x) + return self.tail(x) diff --git a/src/anomalib/models/image/glass/components/init_weight.py b/src/anomalib/models/image/glass/components/init_weight.py new file mode 100644 index 0000000000..1e84d12fc8 --- /dev/null +++ b/src/anomalib/models/image/glass/components/init_weight.py @@ -0,0 +1,22 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Initializes network weights using Xavier normal initialization.""" + +import torch +from torch import nn + + +def init_weight(m: nn.Module) -> None: + """Initializes network weights using Xavier normal initialization. + + Applies Xavier initialization for linear layers and normal initialization + for convolutional and batch normalization layers. + """ + if isinstance(m, torch.nn.Linear): + torch.nn.init.xavier_normal_(m.weight) + if isinstance(m, torch.nn.BatchNorm2d): + m.weight.data.normal_(1.0, 0.02) + m.bias.data.fill_(0) + elif isinstance(m, torch.nn.Conv2d): + m.weight.data.normal_(0.0, 0.02) diff --git a/src/anomalib/models/image/glass/components/patch_maker.py b/src/anomalib/models/image/glass/components/patch_maker.py new file mode 100644 index 0000000000..f953d324e8 --- /dev/null +++ b/src/anomalib/models/image/glass/components/patch_maker.py @@ -0,0 +1,88 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Handles patch-based processing of feature maps.""" + +import torch + + +class PatchMaker: + """Handles patch-based processing of feature maps. + + This class provides utilities for converting feature maps into patches, + reshaping patch scores back to original dimensions, and computing global + anomaly scores from patch-wise predictions. + + Attributes: + patchsize (int): Size of each patch (patchsize x patchsize). + stride (int or None): Stride used for patch extraction. Defaults to patchsize if None. + """ + + def __init__(self, patchsize: int, stride: int | None = None) -> None: + self.patchsize = patchsize + self.stride = stride if stride is not None else patchsize + + def patchify( + self, + features: torch.Tensor, + return_spatial_info: bool = False, + ) -> tuple[torch.Tensor, list[int]] | torch.Tensor: + """Converts a batch of feature maps into patches. + + Args: + features (torch.Tensor): Input feature maps of shape (B, C, H, W). + return_spatial_info (bool): If True, also returns spatial patch count. Default is False. 
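A small sanity-check sketch for the `Discriminator` above, with illustrative sizes: with `hidden=1024` and `n_layers=2` it maps each patch feature to a single sigmoid score.

```python
# Discriminator shape check: (B, in_planes) -> (B, 1) probabilities in (0, 1).
import torch

from anomalib.models.image.glass.components import Discriminator

discriminator = Discriminator(in_planes=1536, n_layers=2, hidden=1024)
patch_features = torch.randn(32, 1536)
scores = discriminator(patch_features)
print(scores.shape, float(scores.min()) >= 0.0, float(scores.max()) <= 1.0)
# torch.Size([32, 1]) True True
```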
+ + Returns: + torch.Tensor: Output tensor of shape (B, N, C, patchsize, patchsize), where N is number of patches. + list[int], optional: Number of patches in (height, width) dimensions, only if return_spatial_info is True. + """ + padding = int((self.patchsize - 1) / 2) + unfolder = torch.nn.Unfold( + kernel_size=self.patchsize, + stride=self.stride, + padding=padding, + dilation=1, + ) + unfolded_features = unfolder(features) + number_of_total_patches = [] + for s in features.shape[-2:]: + n_patches = (s + 2 * padding - 1 * (self.patchsize - 1) - 1) / self.stride + 1 + number_of_total_patches.append(int(n_patches)) + unfolded_features = unfolded_features.reshape( + *features.shape[:2], + self.patchsize, + self.patchsize, + -1, + ) + unfolded_features = unfolded_features.permute(0, 4, 1, 2, 3) + + if return_spatial_info: + return unfolded_features, number_of_total_patches + return unfolded_features + + @staticmethod + def unpatch_scores(x: torch.Tensor, batchsize: int) -> torch.Tensor: + """Reshapes patch scores back into per-batch format. + + Args: + x (torch.Tensor): Input tensor of shape (B * N, ...). + batchsize (int): Original batch size. + + Returns: + torch.Tensor: Reshaped tensor of shape (B, N, ...). + """ + return x.reshape(batchsize, -1, *x.shape[1:]) + + @staticmethod + def compute_score(x: torch.Tensor) -> torch.Tensor: + """Computes final anomaly scores from patch-wise predictions. + + Args: + x (torch.Tensor): Patch scores of shape (B, N, 1). + + Returns: + torch.Tensor: Final anomaly score per image, shape (B,). + """ + x = x[:, :, 0] # remove last dimension if singleton + return torch.max(x, dim=1).values diff --git a/src/anomalib/models/image/glass/components/preprocessing.py b/src/anomalib/models/image/glass/components/preprocessing.py new file mode 100644 index 0000000000..8a1150d616 --- /dev/null +++ b/src/anomalib/models/image/glass/components/preprocessing.py @@ -0,0 +1,66 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Maps input features to a fixed dimension using adaptive average pooling.""" + +import torch +import torch.nn.functional as f + + +class MeanMapper(torch.nn.Module): + """Maps input features to a fixed dimension using adaptive average pooling. + + Input: Variable-sized feature tensors + Output: Fixed-size feature representations + """ + + def __init__(self, preprocessing_dim: int) -> None: + super().__init__() + self.preprocessing_dim = preprocessing_dim + + def forward(self, features: torch.Tensor) -> torch.Tensor: + """Applies adaptive average pooling to reshape features to a fixed size. + + Args: + features (torch.Tensor): Input tensor of shape (B, *) where * denotes + any number of remaining dimensions. It is flattened before pooling. + + Returns: + torch.Tensor: Output tensor of shape (B, D), where D is `preprocessing_dim`. + """ + features = features.reshape(len(features), 1, -1) + return f.adaptive_avg_pool1d(features, self.preprocessing_dim).squeeze(1) + + +class Preprocessing(torch.nn.Module): + """Handles initial feature preprocessing across multiple input dimensions. 
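The patch arithmetic in `patchify` can be checked with a short sketch: with `patchsize=3` and `stride=1`, the padding keeps the spatial grid, so an HxW feature map produces H*W patches.

```python
# patchify shape sketch: (B, C, H, W) -> (B, H*W, C, 3, 3) with patchsize=3, stride=1.
import torch

from anomalib.models.image.glass.components import PatchMaker

patch_maker = PatchMaker(patchsize=3, stride=1)
feature_map = torch.randn(2, 512, 36, 36)
patches, spatial = patch_maker.patchify(feature_map, return_spatial_info=True)
print(patches.shape, spatial)   # torch.Size([2, 1296, 512, 3, 3]) [36, 36]
```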
+ + Input: List of features from different backbone layers + Output: Processed features with consistent dimensionality + """ + + def __init__(self, input_dims: list[int | tuple[int, int]], output_dim: int) -> None: + super().__init__() + self.input_dims = input_dims + self.output_dim = output_dim + + self.preprocessing_modules = torch.nn.ModuleList() + for _ in input_dims: + module = MeanMapper(output_dim) + self.preprocessing_modules.append(module) + + def forward(self, features: list[torch.Tensor]) -> torch.Tensor: + """Applies preprocessing modules to a list of input feature tensors. + + Args: + features (list of torch.Tensor): List of feature maps from different + layers of the backbone network. Each tensor can have a different shape. + + Returns: + torch.Tensor: A single tensor with shape (B, N, D), where B is the batch size, + N is the number of feature maps, and D is the output dimension (`output_dim`). + """ + features_ = [] + for module, feature in zip(self.preprocessing_modules, features, strict=False): + features_.append(module(feature)) + return torch.stack(features_, dim=1) diff --git a/src/anomalib/models/image/glass/components/projection.py b/src/anomalib/models/image/glass/components/projection.py new file mode 100644 index 0000000000..bf829ff493 --- /dev/null +++ b/src/anomalib/models/image/glass/components/projection.py @@ -0,0 +1,46 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Multi-layer projection network for feature adaptation.""" + +import torch + +from .init_weight import init_weight + + +class Projection(torch.nn.Module): + """Multi-layer projection network for feature adaptation. + + Args: + in_planes: Input feature dimension + out_planes: Output feature dimension + n_layers: Number of projection layers + layer_type: Type of intermediate layers + """ + + def __init__(self, in_planes: int, out_planes: int | None = None, n_layers: int = 1, layer_type: int = 0) -> None: + super().__init__() + + if out_planes is None: + out_planes = in_planes + self.layers = torch.nn.Sequential() + in_ = None + out = None + for i in range(n_layers): + in_ = in_planes if i == 0 else out + out = out_planes + self.layers.add_module(f"{i}fc", torch.nn.Linear(in_, out)) + if i < n_layers - 1 and layer_type > 1: + self.layers.add_module(f"{i}relu", torch.nn.LeakyReLU(0.2)) + self.apply(init_weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Applies the projection network to the input features. + + Args: + x (torch.Tensor): Input tensor of shape (B, in_planes), where B is the batch size. + + Returns: + torch.Tensor: Transformed tensor of shape (B, out_planes). + """ + return self.layers(x) diff --git a/src/anomalib/models/image/glass/components/rescale_segmentor.py b/src/anomalib/models/image/glass/components/rescale_segmentor.py new file mode 100644 index 0000000000..ce1c11cc08 --- /dev/null +++ b/src/anomalib/models/image/glass/components/rescale_segmentor.py @@ -0,0 +1,65 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""A utility class for rescaling and smoothing patch-level anomaly scores to generate segmentation masks.""" + +import kornia.filters as kf +import numpy as np +import torch +import torch.nn.functional as f + + +class RescaleSegmentor: + """A utility class for rescaling and smoothing patch-level anomaly scores to generate segmentation masks. + + Attributes: + target_size (int): The spatial size (height and width) to which patch scores will be rescaled. 
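A shape sketch for the `Preprocessing` stage above: features from two backbone levels with different channel counts are each mean-pooled to the same `output_dim` and stacked along a new axis.

```python
# Preprocessing shape sketch: two feature levels -> one (N, levels, output_dim) tensor.
import torch

from anomalib.models.image.glass.components import Preprocessing

preprocessing = Preprocessing(input_dims=[512, 1024], output_dim=1536)
layer2_patches = torch.randn(1296, 512, 3, 3)   # patches from layer2
layer3_patches = torch.randn(1296, 1024, 3, 3)  # patches from layer3, already grid-aligned
print(preprocessing([layer2_patches, layer3_patches]).shape)   # torch.Size([1296, 2, 1536])
```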
+ smoothing (int): The standard deviation used for Gaussian smoothing. + """ + + def __init__(self, target_size: tuple[int, int] = (288, 288)) -> None: + """Initializes the RescaleSegmentor. + + Args: + target_size (int, optional): The desired output size (height/width) of segmentation maps. Defaults to 288. + """ + self.target_size = target_size + self.smoothing = 4 + + def convert_to_segmentation( + self, + patch_scores: np.ndarray | torch.Tensor, + device: torch.device, + ) -> list[torch.Tensor]: + """Converts patch-level scores to smoothed segmentation masks. + + Args: + patch_scores (np.ndarray | torch.Tensor): Patch-wise scores of shape [N, H, W]. + device (torch.device): Device on which to perform computation. + + Returns: + List[torch.Tensor]: A list of segmentation masks, each of shape [H, W], + rescaled to `target_size` and smoothed. + """ + with torch.no_grad(): + if isinstance(patch_scores, np.ndarray): + patch_scores = torch.from_numpy(patch_scores) + + scores = patch_scores.to(device) + scores = scores.unsqueeze(1) # [N, 1, H, W] + scores = f.interpolate( + scores, + size=self.target_size, + mode="bilinear", + align_corners=False, + ) + patch_scores = scores.squeeze(1) # [N, H, W] + + patch_stack = patch_scores.unsqueeze(1) # [N, 1, H, W] + smoothed_stack = kf.gaussian_blur2d( + patch_stack, + kernel_size=(33, 33), + sigma=(self.smoothing, self.smoothing), + ) + + return [s.squeeze(0) for s in smoothed_stack] # List of [H, W] tensors diff --git a/src/anomalib/models/image/glass/lightning_model.py b/src/anomalib/models/image/glass/lightning_model.py new file mode 100644 index 0000000000..c85233223a --- /dev/null +++ b/src/anomalib/models/image/glass/lightning_model.py @@ -0,0 +1,326 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""GLASS - Unsupervised anomaly detection via Gradient Ascent for Industrial Anomaly detection and localization. + +This module implements the GLASS model for unsupervised anomaly detection and localization. GLASS synthesizes both +global and local anomalies using Gaussian noise guided by gradient ascent to enhance weak defect detection in +industrial settings. 
+ +The model consists of: + - A feature extractor and feature adaptor to obtain robust normal representations + - A Global Anomaly Synthesis (GAS) module that perturbs features using Gaussian noise and gradient ascent with + truncated projection + - A Local Anomaly Synthesis (LAS) module that overlays augmented textures onto images using Perlin noise masks + - A shared discriminator trained with features from normal, global, and local synthetic samples + +Paper: `A Unified Anomaly Synthesis Strategy with Gradient Ascent for Industrial Anomaly Detection and Localization +` +""" + +from pathlib import Path +from typing import Any + +from lightning.pytorch.utilities.types import STEP_OUTPUT +from torch import optim +from torchvision.transforms.v2 import CenterCrop, Compose, Normalize, Resize + +from anomalib import LearningType +from anomalib.data import Batch +from anomalib.data.utils import DownloadInfo, download_and_extract +from anomalib.metrics import Evaluator +from anomalib.models.components import AnomalibModule +from anomalib.post_processing import PostProcessor +from anomalib.pre_processing import PreProcessor +from anomalib.visualization import Visualizer + +from .torch_model import GlassModel + +DTD_DOWNLOAD_INFO = DownloadInfo( + name="dtd-r1.0.1.tar.gz", + url="https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz", + hashsum="e42855a52a4950a3b59612834602aa253914755c95b0cff9ead6d07395f8e205", +) + + +class Glass(AnomalibModule): + """PyTorch Lightning Implementation of the GLASS Model. + + The model uses a pre-trained feature extractor to extract features and a feature adaptor to mitigate latent domain + bias. + Global anomaly features are synthesized from adapted normal features using gradient ascent. + Local anomaly images are synthesized using texture overlay datasets like dtd which are then processed by feature + extractor and feature adaptor. + All three different features are passed to the discriminator trained using loss functions. + + Args: + input_shape (tuple[int, int]): Input image dimensions as a tuple of (height, width). Required for shaping the + input pipeline. + Defaults to `(288, 288)`. + anomaly_source_path (str): Path to the dataset or source directory containing normal images and anomaly textures + backbone (str, optional): Name of the CNN backbone used for feature extraction. + Defaults to `"wide_resnet50_2"`. + pretrain_embed_dim (int, optional): Dimensionality of features extracted by the pre-trained backbone before + adaptation. + Defaults to `1536`. + target_embed_dim (int, optional): Dimensionality of the target adapted features after projection. + Defaults to `1536`. + patchsize (int, optional): Size of the local patch used in feature aggregation (e.g., for neighborhood pooling). + Defaults to `3`. + patchstride (int, optional): Stride used when extracting patches for local feature aggregation. + Defaults to `1`. + pre_trained (bool, optional): Whether to use ImageNet pre-trained weights for the backbone network. + Defaults to `True`. + layers (list[str], optional): List of backbone layers to extract features from. + Defaults to `["layer1", "layer2", "layer3"]`. + pre_projection (int, optional): Number of projection layers used in the feature adaptor (e.g., MLP before + discriminator). + Defaults to `1`. + discriminator_layers (int, optional): Number of layers in the discriminator network. + Defaults to `2`. + discriminator_hidden (int, optional): Number of hidden units in each discriminator layer. + Defaults to `1024`. 
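A usage sketch for the `RescaleSegmentor` defined above: patch scores on the coarse grid are bilinearly upsampled to the input resolution and Gaussian-smoothed into per-image masks (grid and batch sizes here are illustrative).

```python
# RescaleSegmentor sketch: (N, h, w) patch scores -> N smoothed (288, 288) masks.
import torch

from anomalib.models.image.glass.components import RescaleSegmentor

segmentor = RescaleSegmentor(target_size=(288, 288))
patch_scores = torch.rand(4, 36, 36)
masks = segmentor.convert_to_segmentation(patch_scores, device=torch.device("cpu"))
print(len(masks), masks[0].shape)   # 4 torch.Size([288, 288])
```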
+ discriminator_margin (float, optional): Margin used for contrastive or binary classification loss in + discriminator training. + Defaults to `0.5`. + learning_rate (float, optional): Learning rate for training the feature adaptor and discriminator networks. + Defaults to `0.0001`. + step (int, optional): Number of gradient ascent steps for anomaly synthesis. + Defaults to `20`. + svd (int, optional): Flag to enable SVD-based feature projection. + Defaults to `0`. + pre_processor (PreProcessor | bool, optional): reprocessing module or flag to enable default preprocessing. + Set to `True` to apply default normalization and resizing. + Defaults to `True`. + post_processor (PostProcessor | bool, optional): Postprocessing module or flag to enable default output + smoothing or thresholding. + Defaults to `True`. + evaluator (Evaluator | bool, optional): Evaluation module for calculating metrics such as AUROC and PRO. + Defaults to `True`. + visualizer (Visualizer | bool, optional): Visualization module to generate heatmaps, segmentation overlays, and + anomaly scores. + Defaults to `True`. + """ + + def __init__( + self, + input_shape: tuple[int, int] = (288, 288), + anomaly_source_path: str | None = None, + backbone: str = "wide_resnet50_2", + pretrain_embed_dim: int = 1536, + target_embed_dim: int = 1536, + patchsize: int = 3, + patchstride: int = 1, + pre_trained: bool = True, + layers: list[str] | None = None, + pre_projection: int = 1, + discriminator_layers: int = 2, + discriminator_hidden: int = 1024, + discriminator_margin: float = 0.5, + learning_rate: float = 0.0001, + step: int = 20, + svd: int = 0, + pre_processor: PreProcessor | bool = True, + post_processor: PostProcessor | bool = True, + evaluator: Evaluator | bool = True, + visualizer: Visualizer | bool = True, + ) -> None: + super().__init__( + pre_processor=pre_processor, + post_processor=post_processor, + evaluator=evaluator, + visualizer=visualizer, + ) + + if layers is None: + layers = ["layer2", "layer3"] + + if anomaly_source_path is not None: + dtd_dir = Path(anomaly_source_path) + if not dtd_dir.is_dir(): + download_and_extract(dtd_dir, DTD_DOWNLOAD_INFO) + + self.model = GlassModel( + input_shape=input_shape, + anomaly_source_path=anomaly_source_path, + pretrain_embed_dim=pretrain_embed_dim, + target_embed_dim=target_embed_dim, + backbone=backbone, + pre_trained=pre_trained, + patchsize=patchsize, + patchstride=patchstride, + layers=layers, + pre_projection=pre_projection, + discriminator_layers=discriminator_layers, + discriminator_hidden=discriminator_hidden, + discriminator_margin=discriminator_margin, + step=step, + svd=svd, + ) + + self.learning_rate = learning_rate + self.pre_trained = pre_trained + + if pre_projection > 0: + self.projection_opt = optim.Adam( + self.model.projection.parameters(), + self.learning_rate, + weight_decay=1e-5, + ) + else: + self.projection_opt = None + + if not self.pre_trained: + self.backbone_opt = optim.AdamW( + self.model.forward_modules["feature_aggregator"].backbone.parameters(), + self.learning_rate, + ) + else: + self.backbone_opt = None + + self.automatic_optimization = False + + @classmethod + def configure_pre_processor( + cls, + image_size: tuple[int, int] | None = None, + center_crop_size: tuple[int, int] | None = None, + ) -> PreProcessor: + """Configure the default pre-processor for GLASS. + + If valid center_crop_size is provided, the pre-processor will + also perform center cropping, according to the paper. 
+ + Args: + image_size (tuple[int, int] | None, optional): Target size for + resizing. Defaults to ``(256, 256)``. + center_crop_size (tuple[int, int] | None, optional): Size for center + cropping. Defaults to ``None``. + + Returns: + PreProcessor: Configured pre-processor instance. + + Raises: + ValueError: If at least one dimension of ``center_crop_size`` is larger + than correspondent ``image_size`` dimension. + + Example: + >>> pre_processor = Glass.configure_pre_processor( + ... image_size=(256, 256) + ... ) + >>> transformed_image = pre_processor(image) + """ + image_size = image_size or (288, 288) + + if center_crop_size is not None: + if center_crop_size[0] > image_size[0] or center_crop_size[1] > image_size[1]: + msg = f"Center crop size {center_crop_size} cannot be larger than image size {image_size}." + raise ValueError(msg) + transform = Compose([ + Resize(image_size, antialias=True), + CenterCrop(center_crop_size), + Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + else: + transform = Compose([ + Resize(image_size, antialias=True), + Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + + return PreProcessor(transform=transform) + + def configure_optimizers(self) -> optim.Optimizer: + """Configure optimizer for the discriminator. + + Returns: + Optimizer: AdamW Optimizer for the discriminator. + """ + return optim.AdamW(self.model.discriminator.parameters(), lr=self.learning_rate * 2) + + def training_step(self, batch: Batch, batch_idx: int) -> STEP_OUTPUT: + """Training step for GLASS model. + + Args: + batch (Batch): Input batch containing images and metadata + batch_idx (int): Index of the current batch + + Returns: + STEP_OUTPUT: Dictionary containing loss values and metrics + """ + del batch_idx + + discriminator_opt = self.optimizers() + + if not self.pre_trained: + self.model.forward_modules["feature_aggregator"].train() + if self.model.pre_projection > 0: + self.model.projection.train() + self.model.discriminator.train() + + discriminator_opt.zero_grad() + if self.projection_opt is not None: + self.projection_opt.zero_grad() + if self.backbone_opt is not None: + self.backbone_opt.zero_grad() + + true_loss, gaus_loss, bce_loss, focal_loss, loss = self.model(batch.image) + self.manual_backward(loss) + + if self.projection_opt is not None: + self.projection_opt.step() + if self.backbone_opt is not None: + self.backbone_opt.step() + discriminator_opt.step() + + self.log("true_loss", true_loss, prog_bar=True) + self.log("gaus_loss", gaus_loss, prog_bar=True) + self.log("bce_loss", bce_loss, prog_bar=True) + self.log("focal_loss", focal_loss, prog_bar=True) + self.log("loss", loss, prog_bar=True) + + def validation_step(self, batch: Batch, batch_idx: int) -> STEP_OUTPUT: + """Performs a single validation step during model evaluation. + + Args: + batch (Batch): A batch of input data, typically containing images and ground truth labels. + batch_idx (int): Index of the batch (unused in this function). + + Returns: + STEP_OUTPUT: Output of the validation step, usually containing predictions and any associated metrics. + """ + del batch_idx + self.model.forward_modules.eval() + + if self.model.pre_projection > 0: + self.model.projection.eval() + self.model.discriminator.eval() + + predictions = self.model(batch.image) + return batch.update(**predictions._asdict()) + + def on_train_epoch_start(self) -> None: + """Initialize model by computing mean feature representation across training dataset. 
+ + This method is called at the start of training and computes a mean feature vector + that serves as a reference point for the normal class distribution. + """ + dataloader = self.trainer.train_dataloader + self.model.calculate_center(dataloader, self.device) + + @property + def learning_type(self) -> LearningType: + """Return the learning type of the model. + + Returns: + LearningType: Learning type (ONE_CLASS for GLASS) + """ + return LearningType.ONE_CLASS + + @property + def trainer_arguments(self) -> dict[str, Any]: + """Return GLASS trainer arguments. + + Returns: + dict[str, Any]: Dictionary containing trainer configuration + """ + return {"gradient_clip_val": 0, "num_sanity_val_steps": 0} diff --git a/src/anomalib/models/image/glass/loss.py b/src/anomalib/models/image/glass/loss.py new file mode 100644 index 0000000000..4fb8e53c91 --- /dev/null +++ b/src/anomalib/models/image/glass/loss.py @@ -0,0 +1,168 @@ +# Original Code +# Copyright (c) 2021 @Hsuxu +# https://github.com/Hsuxu/Loss_ToolBox-PyTorch. +# SPDX-License-Identifier: Apache-2.0 +# +# Modified +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Focal Loss for multi-class classification with optional label smoothing and class weighting. + +This loss function is designed to address class imbalance by down-weighting easy examples and focusing training +on hard, misclassified examples. It is based on the paper: +"Focal Loss for Dense Object Detection" (https://arxiv.org/abs/1708.02002). + +The focal loss formula is: + FL(pt) = -alpha * (1 - pt) ** gamma * log(pt) + +where: + - pt is the predicted probability of the correct class + - alpha is a class balancing factor + - gamma is a focusing parameter + +Supports optional label smoothing and flexible alpha input (scalar or per-class tensor). Can be used with raw logits, +applying a specified non-linearity (e.g., softmax or sigmoid). + +Args: + apply_nonlinearity (nn.Module or None): Optional non-linearity to apply to the logits before loss computation. + For example, use `nn.Softmax(dim=1)` or `nn.Sigmoid()` if logits are not normalized. + alpha (float or torch.Tensor, optional): Class balancing factor. Can be: + - None: Equal weighting for all classes. + - float: Scalar for binary class weighting; applied to `balance_index`. + - Tensor: Per-class weights of shape (num_classes,). + gamma (float): Focusing parameter (> 0) to reduce the loss contribution from easy examples. Default is 2. + balance_index (int): Index of the class to apply `alpha` to when `alpha` is a float. + smooth (float): Label smoothing factor. A small value (e.g., 1e-5) helps prevent overconfidence. + size_average (bool): If True, average the loss over the batch; if False, sum the loss. + +Raises: + ValueError: If `smooth` is outside the range [0, 1]. + TypeError: If `alpha` is not a supported type. + +Inputs: + logit (torch.Tensor): Raw model outputs (logits) of shape (B, C, ...) where B is batch size and C is number of + classes. + target (torch.Tensor): Ground-truth class indices of shape (B, 1, ...) or broadcastable to match logit. + +Returns: + torch.Tensor: Scalar loss value (averaged or summed based on `size_average`). +""" + +import numpy as np +import torch +from torch import nn + + +class FocalLoss(nn.Module): + """Implementation of Focal Loss with support for smoothed label cross-entropy. + + As proposed in 'Focal Loss for Dense Object Detection' (https://arxiv.org/abs/1708.02002). 
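To make the focusing behaviour concrete, a small numeric sketch of the focal term (alpha omitted): with gamma=2 a confident correct prediction contributes almost nothing, while a hard example keeps most of its cross-entropy weight.

```python
# Focal term FL(pt) = -(1 - pt)**gamma * log(pt) versus plain -log(pt), with gamma = 2.
import torch

def focal_term(pt: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
    return -((1 - pt) ** gamma) * torch.log(pt)

easy, hard = torch.tensor(0.9), torch.tensor(0.1)
print(float(focal_term(easy)), float(-torch.log(easy)))   # ~0.001 vs ~0.105 -> easy example suppressed
print(float(focal_term(hard)), float(-torch.log(hard)))   # ~1.87  vs ~2.30  -> hard example mostly kept
```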
+ The focal loss formula is: + Focal_Loss = -1 * alpha * (1 - pt) ** gamma * log(pt) + + Args: + num_class (int): Number of classes. + alpha (float or Tensor): Scalar or Tensor weight factor for class imbalance. If float, `balance_index` should be + set. + gamma (float): Focusing parameter that reduces the relative loss for well-classified examples (gamma > 0). + smooth (float): Label smoothing factor for cross-entropy. + balance_index (int): Index of the class to balance when `alpha` is a float. + size_average (bool, optional): If True (default), the loss is averaged over the batch; otherwise, the loss is + summed. + """ + + def __init__( + self, + apply_nonlinearity: nn.Module | None = None, + alpha: float | torch.Tensor | np.ndarray | None = None, + gamma: float = 2, + balance_index: int = 0, + smooth: float = 1e-5, + size_average: bool = True, + ) -> None: + """Initializes the FocalLoss instance. + + Args: + apply_nonlinearity (nn.Module or None): Optional non-linearity to apply to logits (e.g., softmax or sigmoid) + alpha (float or torch.Tensor, optional): Weighting factor for class imbalance. Can be: + - None: Equal weighting. + - float: Class at `balance_index` is weighted by `alpha`, others by 1 - `alpha`. + - Tensor: Direct per-class weights. + gamma (float): Focusing parameter for down-weighting easy examples (y > 0). + balance_index (int): Index of the class to apply `alpha` to when `alpha` is a float. + smooth (float): Label smoothing factor (0 to 1). + size_average (bool): If True, average the loss over the batch. If False, sum the loss. + """ + super().__init__() + self.apply_nonlinearity = apply_nonlinearity + self.alpha = alpha + self.gamma = gamma + self.balance_index = balance_index + self.smooth = smooth + self.size_average = size_average + + if self.smooth is not None and (self.smooth < 0 or self.smooth > 1.0): + msg = "smooth value should be in [0,1]" + raise ValueError(msg) + + def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """Computes the focal loss between `logit` predictions and ground-truth `target`. + + Args: + logits (torch.Tensor): The predicted logits of shape (B, C, ...) where B is batch size and C is the + number of classes. + target (torch.Tensor): The ground-truth class indices of shape (B, 1, ...) or broadcastable to logit. + + Returns: + torch.Tensor: Computed focal loss value (averaged or summed depending on `size_average`). 
+ """ + if self.apply_nonlinearity is not None: + logits = self.apply_nonlinearity(logits) + num_classes = logits.shape[1] + + if logits.dim() > 2: + logits = logits.view(logits.size(0), logits.size(1), -1) + logits = logits.permute(0, 2, 1).contiguous() + logits = logits.view(-1, logits.size(-1)) + target = torch.squeeze(target, 1) + target = target.view(-1, 1) + + alpha = self.alpha + if self.alpha is None: + alpha = torch.ones(num_classes, 1) + elif isinstance(self.alpha, (list | np.ndarray)): + alpha = torch.FloatTensor(alpha).view(num_classes, 1) + alpha = alpha / alpha.sum() + elif isinstance(self.alpha, float): + alpha = torch.ones(num_classes, 1) + alpha = alpha * (1 - self.alpha) + alpha[self.balance_index] = self.alpha + else: + msg = "Not support alpha type" + raise TypeError(msg) + + if alpha.device != logits.device: + alpha = alpha.to(logits.device) + + idx = target.cpu().long() + one_hot_key = torch.FloatTensor(target.size(0), num_classes).zero_() + one_hot_key = one_hot_key.scatter_(1, idx, 1) + if one_hot_key.device != logits.device: + one_hot_key = one_hot_key.to(logits.device) + + if self.smooth: + one_hot_key = torch.clamp( + one_hot_key, + self.smooth / (num_classes - 1), + 1.0 - self.smooth, + ) + pt = (one_hot_key * logits).sum(1) + self.smooth + logpt = pt.log() + + gamma = self.gamma + alpha = alpha[idx] + alpha = torch.squeeze(alpha) + loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt + + return loss.mean() if self.size_average else loss.sum() diff --git a/src/anomalib/models/image/glass/torch_model.py b/src/anomalib/models/image/glass/torch_model.py new file mode 100644 index 0000000000..32485ad2ac --- /dev/null +++ b/src/anomalib/models/image/glass/torch_model.py @@ -0,0 +1,422 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""GLASS - Unsupervised anomaly detection via Gradient Ascent for Industrial Anomaly detection and localization. + +This module implements the GLASS model for unsupervised anomaly detection and localization. GLASS synthesizes both +global and local anomalies using Gaussian noise guided by gradient ascent to enhance weak defect detection in +industrial settings. + +The model consists of: + - A feature extractor and feature adaptor to obtain robust normal representations + - A Global Anomaly Synthesis (GAS) module that perturbs features using Gaussian noise and gradient ascent with + truncated projection + - A Local Anomaly Synthesis (LAS) module that overlays augmented textures onto images using Perlin noise masks + - A shared discriminator trained with features from normal, global, and local synthetic samples + +Paper: `A Unified Anomaly Synthesis Strategy with Gradient Ascent for Industrial Anomaly Detection and Localization +` +""" + +import math + +import torch +import torch.nn.functional as f +from torch import nn +from torch.utils.data import dataloader + +from anomalib.data import InferenceBatch +from anomalib.data.utils.generators.perlin import PerlinAnomalyGenerator +from anomalib.models.components import TimmFeatureExtractor +from anomalib.models.components.feature_extractors import dryrun_find_featuremap_dims + +from .components import Aggregator, Discriminator, PatchMaker, Preprocessing, Projection, RescaleSegmentor +from .loss import FocalLoss + + +def _deduce_dims( + feature_extractor: TimmFeatureExtractor, + input_size: tuple[int, int], + layers: list[str], +) -> list[int | tuple[int, int]]: + """Determines feature dimensions for each layer in the feature extractor. 
+ + Args: + feature_extractor: The backbone feature extractor + input_size: Input image dimensions + layers: List of layer names to extract features from + """ + dimensions_mapping = dryrun_find_featuremap_dims( + feature_extractor, + input_size, + layers, + ) + + return [dimensions_mapping[layer]["num_features"] for layer in layers] + + +class GlassModel(nn.Module): + """PyTorch Implementation of the GLASS Model.""" + + def __init__( + self, + input_shape: tuple[int, int] = (288, 288), # (H, W) + anomaly_source_path: str | None = None, + pretrain_embed_dim: int = 1536, + target_embed_dim: int = 1536, + backbone: str = "wide_resnet50_2", + patchsize: int = 3, + patchstride: int = 1, + pre_trained: bool = True, + layers: list[str] | None = None, + pre_projection: int = 1, + discriminator_layers: int = 2, + discriminator_hidden: int = 1024, + discriminator_margin: float = 0.5, + step: int = 20, + svd: int = 0, + ) -> None: + super().__init__() + + if layers is None: + layers = ["layer2", "layer3"] + + self.backbone = backbone + self.layers = layers + self.input_shape = input_shape + self.pre_trained = pre_trained + + self.augmentor = PerlinAnomalyGenerator(anomaly_source_path) + + self.focal_loss = FocalLoss() + + self.forward_modules = torch.nn.ModuleDict({}) + feature_aggregator = TimmFeatureExtractor( + backbone=self.backbone, + layers=self.layers, + pre_trained=self.pre_trained, + ) + feature_dimensions = _deduce_dims(feature_aggregator, self.input_shape, layers) + self.forward_modules["feature_aggregator"] = feature_aggregator + + preprocessing = Preprocessing(feature_dimensions, pretrain_embed_dim) + self.forward_modules["preprocessing"] = preprocessing + self.target_embed_dimension = target_embed_dim + preadapt_aggregator = Aggregator(target_dim=target_embed_dim) + self.forward_modules["preadapt_aggregator"] = preadapt_aggregator + + self.pre_projection = pre_projection + if self.pre_projection > 0: + self.projection = Projection( + self.target_embed_dimension, + self.target_embed_dimension, + pre_projection, + ) + + self.discriminator_layers = discriminator_layers + self.discriminator_hidden = discriminator_hidden + self.discriminator_margin = discriminator_margin + self.discriminator = Discriminator( + self.target_embed_dimension, + n_layers=self.discriminator_layers, + hidden=self.discriminator_hidden, + ) + + self.distribution = 0 + self.step = step + self.svd = svd + + self.patch_maker = PatchMaker(patchsize, stride=patchstride) + + self.anomaly_segmentor = RescaleSegmentor(target_size=input_shape) + + def calculate_center(self, dataloader: dataloader, device: torch.device) -> None: + """Calculates and updates the center embedding from a dataset. + + This method runs the model in evaluation mode and computes the mean feature + representation (center) across the entire dataset. The center is used for + further downstream tasks such as anomaly detection or feature normalization. + + Args: + dataloader (DataLoader): A PyTorch DataLoader providing batches of data, + where each batch contains an 'image' attribute. + device (torch.device): The device on which tensors should be processed + (e.g., torch.device("cuda") or torch.device("cpu")). + + Returns: + None: The method updates `self.center` in-place with the computed center tensor. 
+ """ + self.forward_modules.eval() + self.center = torch.tensor([1]) + with torch.no_grad(): + for i, batch in enumerate(dataloader): + if self.pre_projection > 0: + outputs = self.projection(self.generate_embeddings(batch.image.to(device))[0]) + outputs = outputs[0] if len(outputs) == 2 else outputs + else: + outputs = self._embed(batch.image.to(device), evaluation=False)[0] + + outputs = outputs[0] if len(outputs) == 2 else outputs + outputs = outputs.reshape(batch.image.to(device).shape[0], -1, outputs.shape[-1]) + + if i == 0: + self.center = torch.mean(outputs, dim=0) + else: + self.center += torch.mean(outputs, dim=0) + + def calculate_features( + self, + img: torch.Tensor, + aug: torch.Tensor, + evaluation: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Calculate and return feature embeddings for the input and augmented images. + + Depending on whether a pre-projection module is used, this method optionally applies it to the + + Args: + img (torch.Tensor): The original input image tensor. + aug (torch.Tensor): The augmented image tensor. + evaluation (bool, optional): Whether the model is in evaluation mode. Defaults to False. + + Returns: + tuple[torch.Tensor, torch.Tensor]: A tuple containing the feature embeddings for the original + image (`true_feats`) and the augmented image (`fake_feats`). + """ + if self.pre_projection > 0: + fake_feats = self.projection( + self.generate_embeddings(aug, evaluation=evaluation)[0], + ) + fake_feats = fake_feats[0] if len(fake_feats) == 2 else fake_feats + true_feats = self.projection( + self.generate_embeddings(img, evaluation=evaluation)[0], + ) + true_feats = true_feats[0] if len(true_feats) == 2 else true_feats + else: + fake_feats = self.generate_embeddings(aug, evaluation=evaluation)[0] + assert isinstance(fake_feats, torch.Tensor) + fake_feats.requires_grad = True + true_feats = self.generate_embeddings(img, evaluation=evaluation)[0] + assert isinstance(true_feats, torch.Tensor) + true_feats.requires_grad = True + + return true_feats, fake_feats + + def generate_embeddings( + self, + images: torch.Tensor, + evaluation: bool = False, + ) -> tuple[list[torch.Tensor], list[tuple[int, int]]]: + """Generates patch-wise feature embeddings for a batch of input images. + + This method performs a forward pass through the model's feature extraction pipeline, + processes selected intermediate layers, reshapes them into patches, aligns their spatial sizes, + and passes them through preprocessing and aggregation modules. + + Args: + images (torch.Tensor): Input images of shape (B, C, H, W), where: + - B is the batch size, + - C is the number of channels, + - H and W are the image height and width. + evaluation (bool, optional): Whether to run in evaluation mode (disabling gradients). + Default is False. + + Returns: + tuple[list[torch.Tensor], list[tuple[int, int]]]: + - A list of patch-level feature tensors, each of shape (N, D, P, P), + where N is the number of patches, D is the channel dimension, and P is patch size. + - A list of (height, width) tuples indicating the number of patches in each spatial dimension + for each corresponding feature level. 
+ """ + if not evaluation and not self.pre_trained: + self.forward_modules["feature_aggregator"].train() + features = self.forward_modules["feature_aggregator"](images, eval=evaluation) + else: + self.forward_modules["feature_aggregator"].eval() + with torch.no_grad(): + features = self.forward_modules["feature_aggregator"](images) + + features = [features[layer] for layer in self.layers] + for i, feat in enumerate(features): + if len(feat.shape) == 3: + B, L, C = feat.shape # noqa: N806 + features[i] = feat.reshape( + B, + int(math.sqrt(L)), + int(math.sqrt(L)), + C, + ).permute(0, 3, 1, 2) + + features = [self.patch_maker.patchify(x, return_spatial_info=True) for x in features] + patch_shapes = [x[1] for x in features] + patch_features = [x[0] for x in features] + ref_num_patches = patch_shapes[0] + + for i in range(1, len(patch_features)): + features_ = patch_features[i] + patch_dims = patch_shapes[i] + + features_ = features_.reshape( + features_.shape[0], + patch_dims[0], + patch_dims[1], + *features_.shape[2:], + ) + features_ = features_.permute(0, 3, 4, 5, 1, 2) + perm_base_shape = features_.shape + features_ = features_.reshape(-1, *features_.shape[-2:]) + features_ = f.interpolate( + features_.unsqueeze(1), + size=(ref_num_patches[0], ref_num_patches[1]), + mode="bilinear", + align_corners=False, + ) + features_ = features_.squeeze(1) + features_ = features_.reshape( + *perm_base_shape[:-2], + ref_num_patches[0], + ref_num_patches[1], + ) + features_ = features_.permute(0, 4, 5, 1, 2, 3) + features_ = features_.reshape(len(features_), -1, *features_.shape[-3:]) + patch_features[i] = features_ + + patch_features = [x.reshape(-1, *x.shape[-3:]) for x in patch_features] + patch_features = self.forward_modules["preprocessing"](patch_features) + patch_features = self.forward_modules["preadapt_aggregator"](patch_features) + + return patch_features, patch_shapes + + def calculate_anomaly_scores(self, images: torch.Tensor) -> torch.Tensor: + """Calculates anomaly scores and segmentation masks for a batch of input images. + + Args: + images (torch.Tensor): Batch of input images of shape [B, C, H, W]. + + Returns: + tuple[torch.Tensor, list[torch.Tensor]]: + - image_scores: Tensor of anomaly scores per image, shape [B]. + - masks: List of segmentation masks for each image, each of shape [H, W]. + """ + with torch.no_grad(): + patch_features, patch_shapes = self.generate_embeddings(images, evaluation=True) + if self.pre_projection > 0: + patch_features = self.projection(patch_features) + patch_features = patch_features[0] if len(patch_features) == 2 else patch_features + + patch_scores = image_scores = self.discriminator(patch_features) + patch_scores = self.patch_maker.unpatch_scores(patch_scores, images.shape[0]) + scales = patch_shapes[0] + patch_scores = patch_scores.reshape(images.shape[0], scales[0], scales[1]) + masks = self.anomaly_segmentor.convert_to_segmentation(patch_scores, device=images.device) + + image_scores = self.patch_maker.unpatch_scores(image_scores, batchsize=images.shape[0]) + image_scores = self.patch_maker.compute_score(image_scores) + + return image_scores, masks + + def forward( + self, + img: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] | InferenceBatch: + """Forward pass to compute patch-wise feature embeddings for original and augmented images. + + Depending on whether a pre-projection module is used, this method optionally applies it to the + embeddings generated for both `img` and `aug`. 
If not, the embeddings are directly obtained and + `requires_grad` is enabled for them, likely for gradient-based optimization or anomaly generation. + """ + device = img.device + aug, mask_s = self.augmentor(img) + if img is not None: + batch_size = img.shape[0] + + true_feats, fake_feats = self.calculate_features(img, aug) + + h_ratio = mask_s.shape[2] // int(math.sqrt(fake_feats.shape[0] // batch_size)) + w_ratio = mask_s.shape[3] // int(math.sqrt(fake_feats.shape[0] // batch_size)) + + mask_s_resized = f.interpolate( + mask_s.float(), + size=(mask_s.shape[2] // h_ratio, mask_s.shape[3] // w_ratio), + mode="nearest", + ) + mask_s_gt = mask_s_resized.reshape(-1, 1) + + noise = torch.normal(0, 0.015, true_feats.shape).to(device) + gaus_feats = true_feats + noise + + center = self.center.repeat(img.shape[0], 1, 1) + center = center.reshape(-1, center.shape[-1]) + true_points = torch.concat( + [fake_feats[mask_s_gt[:, 0] == 0], true_feats], + dim=0, + ) + c_t_points = torch.concat([center[mask_s_gt[:, 0] == 0], center], dim=0) + dist_t = torch.norm(true_points - c_t_points, dim=1) + r_t = torch.tensor([torch.quantile(dist_t, q=0.75)]).to(device) + + for step in range(self.step + 1): + scores = self.discriminator(torch.cat([true_feats, gaus_feats])) + true_scores = scores[: len(true_feats)] + gaus_scores = scores[len(true_feats) :] + true_loss = nn.BCELoss()(true_scores, torch.zeros_like(true_scores)) + gaus_loss = nn.BCELoss()(gaus_scores, torch.ones_like(gaus_scores)) + bce_loss = true_loss + gaus_loss + + if step == self.step: + break + + if self.training: + grad = torch.autograd.grad(gaus_loss, [gaus_feats])[0] + grad_norm = torch.norm(grad, dim=1) + grad_norm = grad_norm.view(-1, 1) + grad_normalized = grad / (grad_norm + 1e-10) + + with torch.no_grad(): + gaus_feats.add_(0.001 * grad_normalized) + + if (step + 1) % 5 == 0: + dist_g = torch.norm(gaus_feats - center, dim=1) + proj_feats = center if self.svd == 1 else true_feats + r = r_t if self.svd == 1 else 0.5 + + h = gaus_feats - proj_feats + h_norm = dist_g if self.svd == 1 else torch.norm(h, dim=1) + alpha = torch.clamp(h_norm, r, 2 * r) + proj = (alpha / (h_norm + 1e-10)).view(-1, 1) + h = proj * h + gaus_feats = proj_feats + h + + fake_points = fake_feats[mask_s_gt[:, 0] == 1] + true_points = true_feats[mask_s_gt[:, 0] == 1] + c_f_points = center[mask_s_gt[:, 0] == 1] + dist_f = torch.norm(fake_points - c_f_points, dim=1) + proj_feats = c_f_points if self.svd == 1 else true_points + r = r_t if self.svd == 1 else 1 + + if self.svd == 1: + h = fake_points - proj_feats + h_norm = dist_f if self.svd == 1 else torch.norm(h, dim=1) + alpha = torch.clamp(h_norm, 2 * r, 4 * r) + proj = (alpha / (h_norm + 1e-10)).view(-1, 1) + h = proj * h + fake_points = proj_feats + h + fake_feats[mask_s_gt[:, 0] == 1] = fake_points + + fake_scores = self.discriminator(fake_feats) + + fake_dist = (fake_scores - mask_s_gt) ** 2 + d_hard = torch.quantile(fake_dist, q=0.5) + fake_scores_ = fake_scores[fake_dist >= d_hard].unsqueeze(1) + mask_ = mask_s_gt[fake_dist >= d_hard].unsqueeze(1) + + output = torch.cat([1 - fake_scores_, fake_scores_], dim=1) + focal_loss = self.focal_loss(output, mask_) + + if self.training: + loss = bce_loss + focal_loss + return true_loss, gaus_loss, bce_loss, focal_loss, loss + + anomaly_scores, masks = self.calculate_anomaly_scores(img) + masks = torch.stack(masks) + return InferenceBatch(pred_score=anomaly_scores, anomaly_map=masks) diff --git a/third-party-programs.txt b/third-party-programs.txt index 4f71dbc247..937224b5bd 
100644 --- a/third-party-programs.txt +++ b/third-party-programs.txt @@ -49,4 +49,7 @@ terms are listed below. 9. UninetModel Copyright (c) 2025 @pangdatangtt, https://github.com/pangdatangtt/UniNet/ - SPDX-License-Identifier: MIT + +10. GLASS Model implementation is based on the original code + Copyright (c) 2024 Qiyu Chen, https://github.com/cqylunlun/GLASS + SPDX-License-Identifier: MIT
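For clarity, the truncated projection used in the gradient-ascent (GAS) loop of `GlassModel.forward` above can be written as a standalone sketch: the Gaussian-perturbed features are pulled back so that their distance from the anchor features stays within [r, 2r] (r = 0.5 in the non-SVD branch).

```python
# Truncated projection from the GAS loop (standalone sketch of the non-SVD branch).
import torch

def truncated_projection(gaus_feats: torch.Tensor, anchor: torch.Tensor, r: float = 0.5) -> torch.Tensor:
    """Clamp the offset norm of gaus_feats around anchor to the band [r, 2r]."""
    h = gaus_feats - anchor
    h_norm = torch.norm(h, dim=1)
    alpha = torch.clamp(h_norm, r, 2 * r)
    return anchor + (alpha / (h_norm + 1e-10)).view(-1, 1) * h

true_feats = torch.randn(16, 1536)
gaus_feats = true_feats + torch.normal(0, 0.015, true_feats.shape)
dist = torch.norm(truncated_projection(gaus_feats, true_feats) - true_feats, dim=1)
print(float(dist.min()), float(dist.max()))   # all distances now lie in [0.5, 1.0]
```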