-
Notifications
You must be signed in to change notification settings - Fork 818
π feat(model): add GLASS model into Anomalib #2629
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: feature/model/glass
Are you sure you want to change the base?
Changes from 21 commits
5b4931b
4789f49
050fd4c
cdd0984
381eec6
9b1c51a
161005c
3d78beb
617cf49
f9d3207
7fea20f
1beedf5
838bc50
1baa0b7
f066b3c
6e780b0
b1be6f5
20d97dd
d5affe4
f008537
a1097e5
7e9d4d4
44dcd60
da57095
ba5a6dd
714a3c3
1a3519c
9e12285
5466d46
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# Copyright (C) 2025 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
"""GLASS - Unsupervised anomaly detection via Gradient Ascent for Industrial Anomaly detection and localization. | ||
|
||
This module implements the GLASS model for unsupervised anomaly detection and localization. GLASS synthesizes both | ||
global and local anomalies using Gaussian noise guided by gradient ascent to enhance weak defect detection in | ||
industrial settings. | ||
|
||
The model consists of: | ||
- A feature extractor and feature adaptor to obtain robust normal representations | ||
- A Global Anomaly Synthesis (GAS) module that perturbs features using Gaussian noise and gradient ascent with | ||
truncated projection | ||
- A Local Anomaly Synthesis (LAS) module that overlays augmented textures onto images using Perlin noise masks | ||
- A shared discriminator trained with features from normal, global, and local synthetic samples | ||
|
||
Paper: `A Unified Anomaly Synthesis Strategy with Gradient Ascent for Industrial Anomaly Detection and Localization | ||
<https://arxiv.org/pdf/2407.09359>` | ||
""" | ||
|
||
from .lightning_model import Glass | ||
|
||
__all__ = ["Glass"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Copyright (C) 2025 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
"""Utility functions for GLASS Model.""" | ||
|
||
from .aggregator import Aggregator | ||
from .discriminator import Discriminator | ||
from .patch_maker import PatchMaker | ||
from .preprocessing import Preprocessing | ||
from .projection import Projection | ||
from .rescale_segmentor import RescaleSegmentor | ||
|
||
__all__ = ["Aggregator", | ||
"Discriminator", | ||
"PatchMaker", | ||
"Preprocessing", | ||
"Projection", | ||
"RescaleSegmentor", | ||
] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# Copyright (C) 2025 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
"""Aggregates and reshapes features to a target dimension.""" | ||
|
||
import torch | ||
import torch.nn.functional as f | ||
|
||
|
||
class Aggregator(torch.nn.Module): | ||
"""Aggregates and reshapes features to a target dimension. | ||
|
||
Input: Multi-dimensional feature tensors | ||
Output: Reshaped and pooled features of specified target dimension | ||
""" | ||
|
||
def __init__(self, target_dim: int) -> None: | ||
super().__init__() | ||
self.target_dim = target_dim | ||
|
||
def forward(self, features: torch.Tensor) -> torch.Tensor: | ||
"""Returns reshaped and average pooled features.""" | ||
features = features.reshape(len(features), 1, -1) | ||
features = f.adaptive_avg_pool1d(features, self.target_dim) | ||
return features.reshape(len(features), -1) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# Copyright (C) 2025 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
"""Discriminator network for anomaly detection.""" | ||
|
||
import torch | ||
|
||
from .init_weight import init_weight | ||
|
||
|
||
class Discriminator(torch.nn.Module): | ||
"""Discriminator network for anomaly detection. | ||
|
||
Args: | ||
in_planes: Input feature dimension | ||
n_layers: Number of layers | ||
hidden: Hidden layer dimensions | ||
""" | ||
|
||
def __init__(self, in_planes: int, n_layers: int = 2, hidden: int | None = None) -> None: | ||
super().__init__() | ||
|
||
hidden_ = in_planes if hidden is None else hidden | ||
self.body = torch.nn.Sequential() | ||
for i in range(n_layers - 1): | ||
in_ = in_planes if i == 0 else hidden_ | ||
hidden_ = int(hidden_ // 1.5) if hidden is None else hidden | ||
self.body.add_module( | ||
f"block{i + 1}", | ||
torch.nn.Sequential( | ||
torch.nn.Linear(in_, hidden_), | ||
torch.nn.BatchNorm1d(hidden_), | ||
torch.nn.LeakyReLU(0.2), | ||
), | ||
) | ||
self.tail = torch.nn.Sequential( | ||
torch.nn.Linear(hidden_, 1, bias=False), | ||
torch.nn.Sigmoid(), | ||
) | ||
self.apply(init_weight) | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
"""Performs a forward pass through the discriminator network. | ||
|
||
Args: | ||
x (torch.Tensor): Input tensor of shape (B, in_planes), where B is the batch size. | ||
|
||
Returns: | ||
torch.Tensor: Output tensor of shape (B, 1) containing probability scores. | ||
""" | ||
x = self.body(x) | ||
return self.tail(x) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Copyright (C) 2025 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
"""Initializes network weights using Xavier normal initialization.""" | ||
|
||
import torch | ||
from torch import nn | ||
|
||
|
||
def init_weight(m: nn.Module) -> None: | ||
"""Initializes network weights using Xavier normal initialization. | ||
|
||
Applies Xavier initialization for linear layers and normal initialization | ||
for convolutional and batch normalization layers. | ||
""" | ||
if isinstance(m, torch.nn.Linear): | ||
torch.nn.init.xavier_normal_(m.weight) | ||
if isinstance(m, torch.nn.BatchNorm2d): | ||
m.weight.data.normal_(1.0, 0.02) | ||
m.bias.data.fill_(0) | ||
elif isinstance(m, torch.nn.Conv2d): | ||
m.weight.data.normal_(0.0, 0.02) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
# Copyright (C) 2025 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
"""Handles patch-based processing of feature maps.""" | ||
|
||
import torch | ||
|
||
|
||
class PatchMaker: | ||
"""Handles patch-based processing of feature maps. | ||
|
||
This class provides utilities for converting feature maps into patches, | ||
reshaping patch scores back to original dimensions, and computing global | ||
anomaly scores from patch-wise predictions. | ||
|
||
Attributes: | ||
patchsize (int): Size of each patch (patchsize x patchsize). | ||
stride (int or None): Stride used for patch extraction. Defaults to patchsize if None. | ||
top_k (int): Number of top patch scores to consider. Used for score reduction. | ||
|
||
""" | ||
|
||
def __init__(self, patchsize: int, top_k: int = 0, stride: int | None = None) -> None: | ||
self.patchsize = patchsize | ||
self.stride = stride if stride is not None else patchsize | ||
self.top_k = top_k | ||
|
||
|
||
def patchify( | ||
self, | ||
features: torch.Tensor, | ||
return_spatial_info: bool = False, | ||
) -> tuple[torch.Tensor, list[int]] | torch.Tensor: | ||
"""Converts a batch of feature maps into patches. | ||
|
||
Args: | ||
features (torch.Tensor): Input feature maps of shape (B, C, H, W). | ||
return_spatial_info (bool): If True, also returns spatial patch count. Default is False. | ||
|
||
Returns: | ||
torch.Tensor: Output tensor of shape (B, N, C, patchsize, patchsize), where N is number of patches. | ||
list[int], optional: Number of patches in (height, width) dimensions, only if return_spatial_info is True. | ||
""" | ||
padding = int((self.patchsize - 1) / 2) | ||
unfolder = torch.nn.Unfold( | ||
kernel_size=self.patchsize, | ||
stride=self.stride, | ||
padding=padding, | ||
dilation=1, | ||
) | ||
unfolded_features = unfolder(features) | ||
number_of_total_patches = [] | ||
for s in features.shape[-2:]: | ||
n_patches = (s + 2 * padding - 1 * (self.patchsize - 1) - 1) / self.stride + 1 | ||
number_of_total_patches.append(int(n_patches)) | ||
unfolded_features = unfolded_features.reshape( | ||
*features.shape[:2], | ||
self.patchsize, | ||
self.patchsize, | ||
-1, | ||
) | ||
unfolded_features = unfolded_features.permute(0, 4, 1, 2, 3) | ||
|
||
if return_spatial_info: | ||
return unfolded_features, number_of_total_patches | ||
return unfolded_features | ||
|
||
@staticmethod | ||
def unpatch_scores(x: torch.Tensor, batchsize: int) -> torch.Tensor: | ||
"""Reshapes patch scores back into per-batch format. | ||
|
||
Args: | ||
x (torch.Tensor): Input tensor of shape (B * N, ...). | ||
batchsize (int): Original batch size. | ||
|
||
Returns: | ||
torch.Tensor: Reshaped tensor of shape (B, N, ...). | ||
""" | ||
return x.reshape(batchsize, -1, *x.shape[1:]) | ||
|
||
@staticmethod | ||
def compute_score(x: torch.Tensor) -> torch.Tensor: | ||
"""Computes final anomaly scores from patch-wise predictions. | ||
|
||
Args: | ||
x (torch.Tensor): Patch scores of shape (B, N, 1). | ||
|
||
Returns: | ||
torch.Tensor: Final anomaly score per image, shape (B,). | ||
""" | ||
x = x[:, :, 0] # remove last dimension if singleton | ||
return torch.max(x, dim=1).values |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
# Copyright (C) 2025 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
"""Maps input features to a fixed dimension using adaptive average pooling.""" | ||
|
||
import torch | ||
import torch.nn.functional as f | ||
|
||
|
||
class MeanMapper(torch.nn.Module): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This looks very similar to the Aggregator class. |
||
"""Maps input features to a fixed dimension using adaptive average pooling. | ||
|
||
Input: Variable-sized feature tensors | ||
Output: Fixed-size feature representations | ||
""" | ||
|
||
def __init__(self, preprocessing_dim: int) -> None: | ||
super().__init__() | ||
self.preprocessing_dim = preprocessing_dim | ||
|
||
def forward(self, features: torch.Tensor) -> torch.Tensor: | ||
"""Applies adaptive average pooling to reshape features to a fixed size. | ||
|
||
Args: | ||
features (torch.Tensor): Input tensor of shape (B, *) where * denotes | ||
any number of remaining dimensions. It is flattened before pooling. | ||
|
||
Returns: | ||
torch.Tensor: Output tensor of shape (B, D), where D is `preprocessing_dim`. | ||
""" | ||
features = features.reshape(len(features), 1, -1) | ||
return f.adaptive_avg_pool1d(features, self.preprocessing_dim).squeeze(1) | ||
|
||
|
||
class Preprocessing(torch.nn.Module): | ||
"""Handles initial feature preprocessing across multiple input dimensions. | ||
|
||
Input: List of features from different backbone layers | ||
Output: Processed features with consistent dimensionality | ||
""" | ||
|
||
def __init__(self, input_dims: list[int | tuple[int, int]], output_dim: int) -> None: | ||
super().__init__() | ||
self.input_dims = input_dims | ||
self.output_dim = output_dim | ||
|
||
self.preprocessing_modules = torch.nn.ModuleList() | ||
for _ in input_dims: | ||
module = MeanMapper(output_dim) | ||
self.preprocessing_modules.append(module) | ||
|
||
def forward(self, features: list[torch.Tensor]) -> torch.Tensor: | ||
"""Applies preprocessing modules to a list of input feature tensors. | ||
|
||
Args: | ||
features (list of torch.Tensor): List of feature maps from different | ||
layers of the backbone network. Each tensor can have a different shape. | ||
|
||
Returns: | ||
torch.Tensor: A single tensor with shape (B, N, D), where B is the batch size, | ||
N is the number of feature maps, and D is the output dimension (`output_dim`). | ||
""" | ||
features_ = [] | ||
for module, feature in zip(self.preprocessing_modules, features, strict=False): | ||
features_.append(module(feature)) | ||
return torch.stack(features_, dim=1) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# Copyright (C) 2025 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
"""Multi-layer projection network for feature adaptation.""" | ||
|
||
import torch | ||
|
||
from .init_weight import init_weight | ||
|
||
|
||
class Projection(torch.nn.Module): | ||
"""Multi-layer projection network for feature adaptation. | ||
|
||
Args: | ||
in_planes: Input feature dimension | ||
out_planes: Output feature dimension | ||
n_layers: Number of projection layers | ||
layer_type: Type of intermediate layers | ||
""" | ||
|
||
def __init__(self, in_planes: int, out_planes: int | None = None, n_layers: int = 1, layer_type: int = 0) -> None: | ||
super().__init__() | ||
|
||
if out_planes is None: | ||
out_planes = in_planes | ||
self.layers = torch.nn.Sequential() | ||
in_ = None | ||
out = None | ||
for i in range(n_layers): | ||
in_ = in_planes if i == 0 else out | ||
out = out_planes | ||
self.layers.add_module(f"{i}fc", torch.nn.Linear(in_, out)) | ||
if i < n_layers - 1 and layer_type > 1: | ||
self.layers.add_module(f"{i}relu", torch.nn.LeakyReLU(0.2)) | ||
self.apply(init_weight) | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
"""Applies the projection network to the input features. | ||
|
||
Args: | ||
x (torch.Tensor): Input tensor of shape (B, in_planes), where B is the batch size. | ||
|
||
Returns: | ||
torch.Tensor: Transformed tensor of shape (B, out_planes). | ||
""" | ||
return self.layers(x) |
Uh oh!
There was an error while loading. Please reload this page.