From 5b4931b0f0db1254026df6b5cad6a45224b88775 Mon Sep 17 00:00:00 2001 From: Devansh Agarwal Date: Wed, 26 Mar 2025 13:22:33 +0530 Subject: [PATCH 01/23] Initial Implementation of GLASS Model Signed-off-by: Devansh Agarwal --- src/anomalib/models/components/__init__.py | 2 +- .../components/feature_extractors/__init__.py | 2 +- .../network_feature_extractor.py | 91 +++++++ src/anomalib/models/image/glass/__init__.py | 1 + .../models/image/glass/torch_model.py | 237 ++++++++++++++++++ 5 files changed, 331 insertions(+), 2 deletions(-) create mode 100644 src/anomalib/models/components/feature_extractors/network_feature_extractor.py create mode 100644 src/anomalib/models/image/glass/__init__.py create mode 100644 src/anomalib/models/image/glass/torch_model.py diff --git a/src/anomalib/models/components/__init__.py b/src/anomalib/models/components/__init__.py index a9108c1c16..24057cca0d 100644 --- a/src/anomalib/models/components/__init__.py +++ b/src/anomalib/models/components/__init__.py @@ -38,7 +38,7 @@ from .base import AnomalibModule, BufferListMixin, DynamicBufferMixin, MemoryBankMixin from .dimensionality_reduction import PCA, SparseRandomProjection -from .feature_extractors import TimmFeatureExtractor +from .feature_extractors import TimmFeatureExtractor, NetworkFeatureAggregator from .filters import GaussianBlur2d from .sampling import KCenterGreedy from .stats import GaussianKDE, MultiVariateGaussian diff --git a/src/anomalib/models/components/feature_extractors/__init__.py b/src/anomalib/models/components/feature_extractors/__init__.py index 66a2f36c34..c0ed169f4a 100644 --- a/src/anomalib/models/components/feature_extractors/__init__.py +++ b/src/anomalib/models/components/feature_extractors/__init__.py @@ -28,7 +28,7 @@ from .timm import TimmFeatureExtractor from .utils import dryrun_find_featuremap_dims - +from .network_feature_extractor import NetworkFeatureAggregator __all__ = [ "dryrun_find_featuremap_dims", "TimmFeatureExtractor", diff --git a/src/anomalib/models/components/feature_extractors/network_feature_extractor.py b/src/anomalib/models/components/feature_extractors/network_feature_extractor.py new file mode 100644 index 0000000000..9967296203 --- /dev/null +++ b/src/anomalib/models/components/feature_extractors/network_feature_extractor.py @@ -0,0 +1,91 @@ +import torch +from torch import nn +import copy + + +class NetworkFeatureAggregator(torch.nn.Module): + """Efficient extraction of network features.""" + + def __init__(self, backbone, layers_to_extract_from, train_backbone=False): + super(NetworkFeatureAggregator, self).__init__() + """Extraction of network features. + + Runs a network only to the last layer of the list of layers where + network features should be extracted from. 
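+        Each layer named in ``layers_to_extract_from`` is given a forward
+        hook that stores its output in ``self.outputs``, so a single pass
+        through the backbone collects every requested feature map.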
+ + Args: + backbone: torchvision.model + layers_to_extract_from: [list of str] + """ + self.layers_to_extract_from = layers_to_extract_from + self.backbone = backbone + self.train_backbone = train_backbone + if not hasattr(backbone, "hook_handles"): + self.backbone.hook_handles = [] + for handle in self.backbone.hook_handles: + handle.remove() + self.outputs = {} + + for extract_layer in layers_to_extract_from: + self.register_hook(extract_layer) + + self.to(self.device) + + def forward(self, images, eval=True): + self.outputs.clear() + if self.train_backbone and not eval: + self.backbone(images) + else: + with torch.no_grad(): + try: + _ = self.backbone(images) + except LastLayerToExtractReachedException: + pass + return self.outputs + + def feature_dimensions(self, input_shape): + """Computes the feature dimensions for all layers given input_shape.""" + _input = torch.ones([1] + list(input_shape)).to(self.device) + _output = self(_input) + return [_output[layer].shape[1] for layer in self.layers_to_extract_from] + + def register_hook(self, layer_name): + module = self.find_module(self.backbone, layer_name) + if module is not None: + forward_hook = ForwardHook( + self.outputs, layer_name, self.layers_to_extract_from[-1] + ) + if isinstance(module, torch.nn.Sequential): + hook = module[-1].register_forward_hook(forward_hook) + else: + hook = module.register_forward_hook(forward_hook) + self.backbone.hook_handles.append(hook) + else: + raise ValueError(f"Module {layer_name} not found in the model") + + def find_module(self, model, module_name): + for name, module in model.named_modules(): + if name == module_name: + return module + elif "." in module_name: + father, child = module_name.split(".", 1) + if name == father: + return self.find_module(module, child) + return None + + +class ForwardHook: + def __init__(self, hook_dict, layer_name: str, last_layer_to_extract: str): + self.hook_dict = hook_dict + self.layer_name = layer_name + self.raise_exception_to_break = copy.deepcopy( + layer_name == last_layer_to_extract + ) + + def __call__(self, module, input, output): + self.hook_dict[self.layer_name] = output + return None + + +class LastLayerToExtractReachedException(Exception): + pass diff --git a/src/anomalib/models/image/glass/__init__.py b/src/anomalib/models/image/glass/__init__.py new file mode 100644 index 0000000000..3a6d9bf711 --- /dev/null +++ b/src/anomalib/models/image/glass/__init__.py @@ -0,0 +1 @@ +from .torch_model import GlassModel \ No newline at end of file diff --git a/src/anomalib/models/image/glass/torch_model.py b/src/anomalib/models/image/glass/torch_model.py new file mode 100644 index 0000000000..60627a1f47 --- /dev/null +++ b/src/anomalib/models/image/glass/torch_model.py @@ -0,0 +1,237 @@ +import torch +from torch import nn +import torch.nn.functional as F +from anomalib.models.components import NetworkFeatureAggregator +import math + +def init_weight(m): + if isinstance(m, torch.nn.Linear): + torch.nn.init.xavier_normal_(m.weight) + if isinstance(m, torch.nn.BatchNorm2d): + m.weight.data.normal_(1.0, 0.02) + m.bias.data.fill_(0) + elif isinstance(m, torch.nn.Conv2d): + m.weight.data.normal_(0.0, 0.02) + +class Preprocessing(torch.nn.Module): + def __init__(self, input_dims, output_dim): + super(Preprocessing, self).__init__() + self.input_dims = input_dims + self.output_dim = output_dim + + self.preprocessing_modules = torch.nn.ModuleList() + for _ in input_dims: + module = MeanMapper(output_dim) + self.preprocessing_modules.append(module) + + def 
forward(self, features): + _features = [] + for module, feature in zip(self.preprocessing_modules, features): + _features.append(module(feature)) + return torch.stack(_features, dim=1) + + +class MeanMapper(torch.nn.Module): + def __init__(self, preprocessing_dim): + super(MeanMapper, self).__init__() + self.preprocessing_dim = preprocessing_dim + + def forward(self, features): + features = features.reshape(len(features), 1, -1) + return F.adaptive_avg_pool1d(features, self.preprocessing_dim).squeeze(1) + + +class Aggregator(torch.nn.Module): + def __init__(self, target_dim): + super(Aggregator, self).__init__() + self.target_dim = target_dim + + def forward(self, features): + """Returns reshaped and average pooled features.""" + features = features.reshape(len(features), 1, -1) + features = F.adaptive_avg_pool1d(features, self.target_dim) + return features.reshape(len(features), -1) + +class Projection(torch.nn.Module): + def __init__(self, in_planes, out_planes=None, n_layers=1, layer_type=0): + super(Projection, self).__init__() + + if out_planes is None: + out_planes = in_planes + self.layers = torch.nn.Sequential() + _in = None + _out = None + for i in range(n_layers): + _in = in_planes if i == 0 else _out + _out = out_planes + self.layers.add_module(f"{i}fc", torch.nn.Linear(_in, _out)) + if i < n_layers - 1: + if layer_type > 1: + self.layers.add_module(f"{i}relu", torch.nn.LeakyReLU(.2)) + self.apply(init_weight) + + def forward(self, x): + + x = self.layers(x) + return x + +class Discriminator(torch.nn.Module): + def __init__(self, in_planes, n_layers=2, hidden=None): + super(Discriminator, self).__init__() + + _hidden = in_planes if hidden is None else hidden + self.body = torch.nn.Sequential() + for i in range(n_layers - 1): + _in = in_planes if i == 0 else _hidden + _hidden = int(_hidden // 1.5) if hidden is None else hidden + self.body.add_module('block%d' % (i + 1), + torch.nn.Sequential( + torch.nn.Linear(_in, _hidden), + torch.nn.BatchNorm1d(_hidden), + torch.nn.LeakyReLU(0.2) + )) + self.tail = torch.nn.Sequential(torch.nn.Linear(_hidden, 1, bias=False), + torch.nn.Sigmoid()) + self.apply(init_weight) + + def forward(self, x): + x = self.body(x) + x = self.tail(x) + return x + +class PatchMaker: + def __init__(self, patchsize, top_k=0, stride=None): + self.patchsize = patchsize + self.stride = stride + self.top_k = top_k + + def patchify(self, features, return_spatial_info=False): + padding = int((self.patchsize - 1) / 2) + unfolder = torch.nn.Unfold(kernel_size=self.patchsize, stride=self.stride, padding=padding, dilation=1) + unfolded_features = unfolder(features) + number_of_total_patches = [] + for s in features.shape[-2:]: + n_patches = (s + 2 * padding - 1 * (self.patchsize - 1) - 1) / self.stride + 1 + number_of_total_patches.append(int(n_patches)) + unfolded_features = unfolded_features.reshape( + *features.shape[:2], self.patchsize, self.patchsize, -1 + ) + unfolded_features = unfolded_features.permute(0, 4, 1, 2, 3) + + if return_spatial_info: + return unfolded_features, number_of_total_patches + return unfolded_features + + def unpatch_scores(self, x, batchsize): + return x.reshape(batchsize, -1, *x.shape[1:]) + + def score(self, x): + x = x[:, :, 0] + x = torch.max(x, dim=1).values + return x + +class GlassModel(nn.Module): + def __init__( + self, + input_shape, + pretrain_embed_dim, + target_embed_dim, + backbone: nn.Module, + patchsize: int =3, + patchstride: int =1, + pre_trained: bool =True, + layers: list[str] = ["layer1", "layer2", "layer3"], + pre_proj: 
int = 1, + dsc_layers=2, + dsc_hidden=1024 + ) -> None: + super().__init__() + self.backbone = backbone + self.layers = layers + self.input_shape = input_shape + + self.forward_modules = torch.ModuleDict({}) + feature_aggregator = NetworkFeatureAggregator( + self.backbone, self.layers, pre_trained + ) + feature_dimensions = feature_aggregator.feature_dimensions(input_shape) + self.forward_modules["feature_aggregator"] = feature_aggregator + + preprocessing = Preprocessing(feature_dimensions, pretrain_embed_dim) + self.forward_modules["preprocessing"] = preprocessing + self.target_embed_dimension = target_embed_dim + preadapt_aggregator = Aggregator(target_dim=target_embed_dim) + self.forward_modules["preadapt_aggregator"] = preadapt_aggregator + + self.pre_trained = pre_trained + + self.pre_proj = pre_proj + if self.pre_proj > 0: + self.pre_projection = Projection(self.target_embed_dimension, self.target_embed_dimension, pre_proj) + + self.discriminator = Discriminator(self.target_embed_dimension, n_layers=dsc_layers, hidden=dsc_hidden) + + self.patch_maker = PatchMaker(patchsize, stride=patchstride) + + def generate_embeddings(self, images, provide_patch_shapes=False, eval=False): + if not eval and not self.pre_trained: + self.forward_modules["feature_aggregator"].train() + features = self.forward_modules["feature_aggregator"](images, eval=eval) + else: + self.forward_modules["feature_aggregator"].eval() + with torch.no_grad(): + features = self.forward_modules["feature_aggregator"](images) + + features = [features[layer] for layer in self.layers_to_extract_from] + for i, feat in enumerate(features): + if len(feat.shape) == 3: + B, L, C = feat.shape + features[i] = feat.reshape(B, int(math.sqrt(L)), int(math.sqrt(L)), C).permute(0, 3, 1, 2) + + features = [self.patch_maker.patchify(x, return_spatial_info=True) for x in features] + patch_shapes = [x[1] for x in features] + patch_features = [x[0] for x in features] + ref_num_patches = patch_shapes[0] + + for i in range(1, len(patch_features)): + _features = patch_features[i] + patch_dims = patch_shapes[i] + + _features = _features.reshape( + _features.shape[0], patch_dims[0], patch_dims[1], *_features.shape[2:] + ) + _features = _features.permute(0, 3, 4, 5, 1, 2) + perm_base_shape = _features.shape + _features = _features.reshape(-1, *_features.shape[-2:]) + _features = F.interpolate( + _features.unsqueeze(1), + size=(ref_num_patches[0], ref_num_patches[1]), + mode="bilinear", + align_corners=False, + ) + _features = _features.squeeze(1) + _features = _features.reshape( + *perm_base_shape[:-2], ref_num_patches[0], ref_num_patches[1] + ) + _features = _features.permute(0, 4, 5, 1, 2, 3) + _features = _features.reshape(len(_features), -1, *_features.shape[-3:]) + patch_features[i] = _features + + patch_features = [x.reshape(-1, *x.shape[-3:]) for x in patch_features] + patch_features = self.forward_modules["preprocessing"](patch_features) + patch_features = self.forward_modules["preadapt_aggregator"](patch_features) + + return patch_features, patch_shapes + + def forward(self, images, eval=False): + self.forward_modules.eval() + with torch.no_grad(): + if self.pre_proj > 0: + outputs = self.pre_proj(self.generate_embeddings(images, eval)) + outputs = outputs[0] if len(outputs) == 2 else outputs + else: + outputs = self.generate_embeddings(images, eval)[0] + outputs = outputs[0] if len(outputs) == 2 else outputs + outputs = outputs.reshape(images.shape[0], -1, outputs.shape[-1]) + return outputs + From 4789f49216dc5e1b339b9247ade98bc9379b6584 
Mon Sep 17 00:00:00 2001 From: Devansh Agarwal Date: Mon, 14 Apr 2025 21:08:55 +0530 Subject: [PATCH 02/23] Created the trainer class for glass model Signed-off-by: Devansh Agarwal --- .../network_feature_extractor.py | 6 +- src/anomalib/models/image/glass/__init__.py | 2 +- src/anomalib/models/image/glass/backbones.py | 50 ++++++ .../models/image/glass/lightning_model.py | 161 ++++++++++++++++++ src/anomalib/models/image/glass/loss.py | 87 ++++++++++ src/anomalib/models/image/glass/perlin.py | 73 ++++++++ .../models/image/glass/torch_model.py | 64 +++++-- 7 files changed, 421 insertions(+), 22 deletions(-) create mode 100644 src/anomalib/models/image/glass/backbones.py create mode 100644 src/anomalib/models/image/glass/lightning_model.py create mode 100644 src/anomalib/models/image/glass/loss.py create mode 100644 src/anomalib/models/image/glass/perlin.py diff --git a/src/anomalib/models/components/feature_extractors/network_feature_extractor.py b/src/anomalib/models/components/feature_extractors/network_feature_extractor.py index 9967296203..ea4a96dea6 100644 --- a/src/anomalib/models/components/feature_extractors/network_feature_extractor.py +++ b/src/anomalib/models/components/feature_extractors/network_feature_extractor.py @@ -6,7 +6,7 @@ class NetworkFeatureAggregator(torch.nn.Module): """Efficient extraction of network features.""" - def __init__(self, backbone, layers_to_extract_from, train_backbone=False): + def __init__(self, backbone, layers_to_extract_from, pre_trained=False): super(NetworkFeatureAggregator, self).__init__() """Extraction of network features. @@ -19,7 +19,7 @@ def __init__(self, backbone, layers_to_extract_from, train_backbone=False): """ self.layers_to_extract_from = layers_to_extract_from self.backbone = backbone - self.train_backbone = train_backbone + self.pre_trained = pre_trained if not hasattr(backbone, "hook_handles"): self.backbone.hook_handles = [] for handle in self.backbone.hook_handles: @@ -33,7 +33,7 @@ def __init__(self, backbone, layers_to_extract_from, train_backbone=False): def forward(self, images, eval=True): self.outputs.clear() - if self.train_backbone and not eval: + if not self.pre_trained and not eval: self.backbone(images) else: with torch.no_grad(): diff --git a/src/anomalib/models/image/glass/__init__.py b/src/anomalib/models/image/glass/__init__.py index 3a6d9bf711..95b6a1c5b7 100644 --- a/src/anomalib/models/image/glass/__init__.py +++ b/src/anomalib/models/image/glass/__init__.py @@ -1 +1 @@ -from .torch_model import GlassModel \ No newline at end of file +from .lightning_model import Glass \ No newline at end of file diff --git a/src/anomalib/models/image/glass/backbones.py b/src/anomalib/models/image/glass/backbones.py new file mode 100644 index 0000000000..8f404a9653 --- /dev/null +++ b/src/anomalib/models/image/glass/backbones.py @@ -0,0 +1,50 @@ +import torchvision.models as models +import timm + +_BACKBONES = { + "alexnet": "models.alexnet(pretrained=True)", + "resnet18": "models.resnet18(pretrained=True)", + "resnet50": "models.resnet50(pretrained=True)", + "resnet101": "models.resnet101(pretrained=True)", + "resnext101": "models.resnext101_32x8d(pretrained=True)", + "resnet200": 'timm.create_model("resnet200", pretrained=True)', + "resnest50": 'timm.create_model("resnest50d_4s2x40d", pretrained=True)', + "resnetv2_50_bit": 'timm.create_model("resnetv2_50x3_bitm", pretrained=True)', + "resnetv2_50_21k": 'timm.create_model("resnetv2_50x3_bitm_in21k", pretrained=True)', + "resnetv2_101_bit": 
'timm.create_model("resnetv2_101x3_bitm", pretrained=True)', + "resnetv2_101_21k": 'timm.create_model("resnetv2_101x3_bitm_in21k", pretrained=True)', + "resnetv2_152_bit": 'timm.create_model("resnetv2_152x4_bitm", pretrained=True)', + "resnetv2_152_21k": 'timm.create_model("resnetv2_152x4_bitm_in21k", pretrained=True)', + "resnetv2_152_384": 'timm.create_model("resnetv2_152x2_bit_teacher_384", pretrained=True)', + "resnetv2_101": 'timm.create_model("resnetv2_101", pretrained=True)', + "vgg11": "models.vgg11(pretrained=True)", + "vgg19": "models.vgg19(pretrained=True)", + "vgg19_bn": "models.vgg19_bn(pretrained=True)", + "wideresnet50": "models.wide_resnet50_2(pretrained=True)", + "wideresnet101": "models.wide_resnet101_2(pretrained=True)", + "mnasnet_100": 'timm.create_model("mnasnet_100", pretrained=True)', + "mnasnet_a1": 'timm.create_model("mnasnet_a1", pretrained=True)', + "mnasnet_b1": 'timm.create_model("mnasnet_b1", pretrained=True)', + "densenet121": 'timm.create_model("densenet121", pretrained=True)', + "densenet201": 'timm.create_model("densenet201", pretrained=True)', + "inception_v4": 'timm.create_model("inception_v4", pretrained=True)', + "vit_small": 'timm.create_model("vit_small_patch16_224", pretrained=True)', + "vit_base": 'timm.create_model("vit_base_patch16_224", pretrained=True)', + "vit_large": 'timm.create_model("vit_large_patch16_224", pretrained=True)', + "vit_r50": 'timm.create_model("vit_large_r50_s32_224", pretrained=True)', + "vit_deit_base": 'timm.create_model("deit_base_patch16_224", pretrained=True)', + "vit_deit_distilled": 'timm.create_model("deit_base_distilled_patch16_224", pretrained=True)', + "vit_swin_base": 'timm.create_model("swin_base_patch4_window7_224", pretrained=True)', + "vit_swin_large": 'timm.create_model("swin_large_patch4_window7_224", pretrained=True)', + "efficientnet_b7": 'timm.create_model("tf_efficientnet_b7", pretrained=True)', + "efficientnet_b5": 'timm.create_model("tf_efficientnet_b5", pretrained=True)', + "efficientnet_b3": 'timm.create_model("tf_efficientnet_b3", pretrained=True)', + "efficientnet_b1": 'timm.create_model("tf_efficientnet_b1", pretrained=True)', + "efficientnetv2_m": 'timm.create_model("tf_efficientnetv2_m", pretrained=True)', + "efficientnetv2_l": 'timm.create_model("tf_efficientnetv2_l", pretrained=True)', + "efficientnet_b3a": 'timm.create_model("efficientnet_b3a", pretrained=True)', +} + + +def load(name): + return eval(_BACKBONES[name]) diff --git a/src/anomalib/models/image/glass/lightning_model.py b/src/anomalib/models/image/glass/lightning_model.py new file mode 100644 index 0000000000..be45b61ac5 --- /dev/null +++ b/src/anomalib/models/image/glass/lightning_model.py @@ -0,0 +1,161 @@ +import torch +from torch import nn +from torch import optim + +from lightning.pytorch.utilities.types import STEP_OUTPUT + +from anomalib.data import Batch +from anomalib.models.components import AnomalibModule +from anomalib.models.components import AnomalibModule +from anomalib.metrics import Evaluator +from anomalib.post_processing import PostProcessor +from anomalib.pre_processing import PreProcessor +from anomalib.visualization import Visualizer + +from .loss import FocalLoss +from .torch_model import GlassModel + +class Glass(AnomalibModule): + def __init__( + self, + backbone, + input_shape, + pretrain_embed_dim, + target_embed_dim, + patchsize: int = 3, + patchstride: int = 1, + pre_trained: bool = True, + layers: list[str] = ["layer1", "layer2", "layer3"], + pre_proj: int = 1, + dsc_layers: int = 2, + dsc_hidden: 
int = 1024, + dsc_margin: int = 0.5, + pre_processor: PreProcessor | bool = True, + post_processor: PostProcessor | bool = True, + evaluator: Evaluator | bool = True, + visualizer: Visualizer | bool = True, + mining: int = 1, + noise: float = 0.015, + radius: float = 0.75, + p: float = 0.5, + lr: int = 0.0001, + step: int = 0 + ): + super().__init__( + pre_processor=pre_processor, + post_processor=post_processor, + evaluator=evaluator, + visualizer=visualizer, + ) + + self.model = GlassModel( + input_shape=input_shape, + pretrain_embed_dim=pretrain_embed_dim, + target_embed_dim=target_embed_dim, + backbone=backbone, + pre_trained=pre_trained, + patchsize=patchsize, + patchstride=patchstride, + layers=layers, + pre_proj=pre_proj, + dsc_layers=dsc_layers, + dsc_hidden=dsc_hidden, + dsc_margin=dsc_margin + ) + + self.p = p + self.radius = radius + self.mining = mining + self.noise = noise + self.distribution = 0 + self.lr = lr + self.step = step + + self.focal_loss = FocalLoss() + + def configure_optimizers(self) -> list[optim.Optimizer]: + optimizers = [] + if not self.model.pre_trained: + backbone_opt = optim.AdamW(self.model.foward_modules["feature_aggregator"].backbone.parameters(), self.lr) + optimizers.append(backbone_opt) + else: + optimizers.append(None) + + if self.model.pre_proj > 0: + proj_opt = optim.AdamW(self.model.pre_projection.parameters(), self.lr, weight_decay=1e-5) + optimizers.append(proj_opt) + else: + optimizers.append(None) + + dsc_opt = optim.AdamW(self.model.discriminator.parameters(), lr=self.lr * 2) + optimizers.append(dsc_opt) + + return optimizers + + def training_step( + self, + batch: Batch, + batch_idx: int + ) -> STEP_OUTPUT: + backbone_opt, proj_opt, dsc_opt = self.optimizers() + + self.model.forward_modules.eval() + if self.model.pre_proj > 0: + self.pre_projection.train() + self.model.discriminator.train() + + dsc_opt.zero_grad() + if proj_opt is not None: + proj_opt.zero_grad() + if backbone_opt is not None: + backbone_opt.zero_grad() + + img = batch.image + aug = batch.aug + + true_feats, fake_feats = self.model(img, aug) + + mask_s_gt = batch.mask_s.reshape(-1, 1) + noise = torch.normal(0, self.noise, true_feats.shape) + gaus_feats = true_feats + noise + + for step in range(self.step + 1): + scores = self.model.discriminator(torch.cat([true_feats, gaus_feats])) + true_scores = scores[:len(true_feats)] + gaus_scores = scores[len(true_feats):] + true_loss = nn.BCELoss()(true_scores, torch.zeros_like(true_scores)) + gaus_loss = nn.BCELoss()(gaus_scores, torch.ones_like(gaus_scores)) + bce_loss = true_loss + gaus_loss + + if step == self.step: + break + + grad = torch.autograd.grad(gaus_loss, [gaus_feats])[0] + grad_norm = torch.norm(grad, dim=1) + grad_norm = grad_norm.view(-1, 1) + grad_normalized = grad / (grad_norm + 1e-10) + + with torch.no_grad(): + gaus_feats.add_(0.001 * grad_normalized) + + fake_scores = self.model.discriminator(fake_feats) + + if self.p > 0: + fake_dist = (fake_scores - mask_s_gt) ** 2 + d_hard = torch.quantile(fake_dist, q=self.p) + take_scores_ = fake_scores[fake_dist >= d_hard].unsqueeze(1) + mask_ = mask_s_gt[fake_dist >= d_hard].unsqueeze(1) + else: + fake_scores_ = fake_scores + mask_ = mask_s_gt + output = torch.cat([1 - fake_scores_, fake_scores_], dim=1) + focal_loss = self.focal_loss(output, mask_) + + loss = bce_loss + focal_loss + loss.backward() + + if proj_opt is not None: + proj_opt.step() + if backbone_opt is not None: + backbone_opt.step() + dsc_opt.step() \ No newline at end of file diff --git 
a/src/anomalib/models/image/glass/loss.py b/src/anomalib/models/image/glass/loss.py new file mode 100644 index 0000000000..4a9dd33368 --- /dev/null +++ b/src/anomalib/models/image/glass/loss.py @@ -0,0 +1,87 @@ +import torch +import numpy as np +import torch.nn as nn + + +class FocalLoss(nn.Module): + """ + copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py + This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in + 'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)' + Focal_Loss= -1*alpha*(1-pt)*log(pt) + :param num_class: + :param alpha: (tensor) 3D or 4D the scalar factor for this criterion + :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more + focus on hard misclassified example + :param smooth: (float,double) smooth value when cross entropy + :param balance_index: (int) balance class index, should be specific when alpha is float + :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch. + """ + + def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True): + super(FocalLoss, self).__init__() + self.apply_nonlin = apply_nonlin + self.alpha = alpha + self.gamma = gamma + self.balance_index = balance_index + self.smooth = smooth + self.size_average = size_average + + if self.smooth is not None: + if self.smooth < 0 or self.smooth > 1.0: + raise ValueError('smooth value should be in [0,1]') + + def forward(self, logit, target): + if self.apply_nonlin is not None: + logit = self.apply_nonlin(logit) + num_class = logit.shape[1] + + if logit.dim() > 2: + logit = logit.view(logit.size(0), logit.size(1), -1) + logit = logit.permute(0, 2, 1).contiguous() + logit = logit.view(-1, logit.size(-1)) + target = torch.squeeze(target, 1) + target = target.view(-1, 1) + + alpha = self.alpha + if alpha is None: + alpha = torch.ones(num_class, 1) + elif isinstance(alpha, (list, np.ndarray)): + assert len(alpha) == num_class + alpha = torch.FloatTensor(alpha).view(num_class, 1) + alpha = alpha / alpha.sum() + elif isinstance(alpha, float): + alpha = torch.ones(num_class, 1) + alpha = alpha * (1 - self.alpha) + alpha[self.balance_index] = self.alpha + + else: + raise TypeError('Not support alpha type') + + if alpha.device != logit.device: + alpha = alpha.to(logit.device) + + idx = target.cpu().long() + + one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_() + one_hot_key = one_hot_key.scatter_(1, idx, 1) + if one_hot_key.device != logit.device: + one_hot_key = one_hot_key.to(logit.device) + + if self.smooth: + one_hot_key = torch.clamp( + one_hot_key, self.smooth / (num_class - 1), 1.0 - self.smooth) + pt = (one_hot_key * logit).sum(1) + self.smooth + logpt = pt.log() + + gamma = self.gamma + + alpha = alpha[idx] + alpha = torch.squeeze(alpha) + loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt + + if self.size_average: + loss = loss.mean() + else: + loss = loss.sum() + return loss diff --git a/src/anomalib/models/image/glass/perlin.py b/src/anomalib/models/image/glass/perlin.py new file mode 100644 index 0000000000..0c0f27a328 --- /dev/null +++ b/src/anomalib/models/image/glass/perlin.py @@ -0,0 +1,73 @@ +import imgaug.augmenters as iaa +import numpy as np +import torch +import math + + +def generate_thr(img_shape, min=0, max=4): + min_perlin_scale = min + max_perlin_scale = max + perlin_scalex = 2 ** 
np.random.randint(min_perlin_scale, max_perlin_scale) + perlin_scaley = 2 ** np.random.randint(min_perlin_scale, max_perlin_scale) + perlin_noise_np = rand_perlin_2d_np((img_shape[1], img_shape[2]), (perlin_scalex, perlin_scaley)) + threshold = 0.5 + perlin_noise_np = iaa.Sequential([iaa.Affine(rotate=(-90, 90))])(image=perlin_noise_np) + perlin_thr = np.where(perlin_noise_np > threshold, np.ones_like(perlin_noise_np), np.zeros_like(perlin_noise_np)) + return perlin_thr + + +def perlin_mask(img_shape, feat_size, min, max, mask_fg, flag=0): + mask = np.zeros((feat_size, feat_size)) + while np.max(mask) == 0: + perlin_thr_1 = generate_thr(img_shape, min, max) + perlin_thr_2 = generate_thr(img_shape, min, max) + temp = torch.rand(1).numpy()[0] + if temp > 2 / 3: + perlin_thr = perlin_thr_1 + perlin_thr_2 + perlin_thr = np.where(perlin_thr > 0, np.ones_like(perlin_thr), np.zeros_like(perlin_thr)) + elif temp > 1 / 3: + perlin_thr = perlin_thr_1 * perlin_thr_2 + else: + perlin_thr = perlin_thr_1 + perlin_thr = torch.from_numpy(perlin_thr) + perlin_thr_fg = perlin_thr * mask_fg + down_ratio_y = int(img_shape[1] / feat_size) + down_ratio_x = int(img_shape[2] / feat_size) + mask_ = perlin_thr_fg + mask = torch.nn.functional.max_pool2d(perlin_thr_fg.unsqueeze(0).unsqueeze(0), (down_ratio_y, down_ratio_x)).float() + mask = mask.numpy()[0, 0] + mask_s = mask + if flag != 0: + mask_l = mask_.numpy() + if flag == 0: + return mask_s + else: + return mask_s, mask_l + + +def lerp_np(x, y, w): + fin_out = (y - x) * w + x + return fin_out + + +def rand_perlin_2d_np(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3): + delta = (res[0] / shape[0], res[1] / shape[1]) + d = (shape[0] // res[0], shape[1] // res[1]) + grid = np.mgrid[0:res[0]:delta[0], 0:res[1]:delta[1]].transpose(1, 2, 0) % 1 + + angles = 2 * math.pi * np.random.rand(res[0] + 1, res[1] + 1) + gradients = np.stack((np.cos(angles), np.sin(angles)), axis=-1) + tt = np.repeat(np.repeat(gradients, d[0], axis=0), d[1], axis=1) + + tile_grads = lambda slice1, slice2: np.repeat(np.repeat(gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]], d[0], axis=0), d[1], + axis=1) + dot = lambda grad, shift: ( + np.stack((grid[:shape[0], :shape[1], 0] + shift[0], grid[:shape[0], :shape[1], 1] + shift[1]), + axis=-1) * grad[:shape[0], :shape[1]]).sum(axis=-1) + + n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]) + n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]) + n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]) + n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]) + t = fade(grid[:shape[0], :shape[1]]) + return math.sqrt(2) * lerp_np(lerp_np(n00, n10, t[..., 0]), lerp_np(n01, n11, t[..., 0]), t[..., 1]) diff --git a/src/anomalib/models/image/glass/torch_model.py b/src/anomalib/models/image/glass/torch_model.py index 60627a1f47..5b952f5e9c 100644 --- a/src/anomalib/models/image/glass/torch_model.py +++ b/src/anomalib/models/image/glass/torch_model.py @@ -3,6 +3,8 @@ import torch.nn.functional as F from anomalib.models.components import NetworkFeatureAggregator import math +import scipy.ndimage as ndimage +import numpy as np def init_weight(m): if isinstance(m, torch.nn.Linear): @@ -130,6 +132,25 @@ def score(self, x): x = torch.max(x, dim=1).values return x +class RescaleSegmentor: + def __init__(self, target_size=288): + self.target_size = target_size + self.smoothing = 4 + + def convert_to_segmentation(self, patch_scores): + with torch.no_grad(): + if isinstance(patch_scores, np.ndarray): + patch_scores = torch.from_numpy(patch_scores) + 
_scores = patch_scores + _scores = _scores.unsqueeze(1) + _scores = F.interpolate( + _scores, size=self.target_size, mode="bilinear", align_corners=False + ) + _scores = _scores.squeeze(1) + patch_scores = _scores.cpu().numpy() + return [ndimage.gaussian_filter(patch_score, sigma=self.smoothing) for patch_score in patch_scores] + + class GlassModel(nn.Module): def __init__( self, @@ -137,13 +158,14 @@ def __init__( pretrain_embed_dim, target_embed_dim, backbone: nn.Module, - patchsize: int =3, - patchstride: int =1, - pre_trained: bool =True, + patchsize: int = 3, + patchstride: int = 1, + pre_trained: bool = True, layers: list[str] = ["layer1", "layer2", "layer3"], pre_proj: int = 1, - dsc_layers=2, - dsc_hidden=1024 + dsc_layers:int = 2, + dsc_hidden:int = 1024, + dsc_margin:int = 0.5 ) -> None: super().__init__() self.backbone = backbone @@ -169,10 +191,15 @@ def __init__( if self.pre_proj > 0: self.pre_projection = Projection(self.target_embed_dimension, self.target_embed_dimension, pre_proj) - self.discriminator = Discriminator(self.target_embed_dimension, n_layers=dsc_layers, hidden=dsc_hidden) + self.dsc_layers = dsc_layers + self.dsc_hidden = dsc_hidden + self.dsc_margin = dsc_margin + self.discriminator = Discriminator(self.target_embed_dimension, n_layers=self.dsc_layers, hidden=self.dsc_hidden) self.patch_maker = PatchMaker(patchsize, stride=patchstride) + self.anomaly_segmentor = RescaleSegmentor() + def generate_embeddings(self, images, provide_patch_shapes=False, eval=False): if not eval and not self.pre_trained: self.forward_modules["feature_aggregator"].train() @@ -223,15 +250,16 @@ def generate_embeddings(self, images, provide_patch_shapes=False, eval=False): return patch_features, patch_shapes - def forward(self, images, eval=False): - self.forward_modules.eval() - with torch.no_grad(): - if self.pre_proj > 0: - outputs = self.pre_proj(self.generate_embeddings(images, eval)) - outputs = outputs[0] if len(outputs) == 2 else outputs - else: - outputs = self.generate_embeddings(images, eval)[0] - outputs = outputs[0] if len(outputs) == 2 else outputs - outputs = outputs.reshape(images.shape[0], -1, outputs.shape[-1]) - return outputs - + def forward(self, img, aug, eval=False): + if self.pre_proj > 0: + fake_feats = self.pre_projection(self.generate_embeddings(aug, evaluation=eval)[0]) + fake_feats = fake_feats[0] if len(fake_feats) == 2 else fake_feats + true_feats = self.pre_projection(self.generate_embeddings(img, evaluation=eval)[0]) + true_feats = true_feats[0] if len(true_feats) == 2 else true_feats + else: + fake_feats = self.generate_embeddings(aug, evaluation=eval)[0] + fake_feats.requires_grad = True + true_feats = self.generate_embeddings(img, evaluation=eval)[0] + true_feats.requires_grad = True + + return true_feats, fake_feats From 050fd4c1d4316436b5ca576a533ec758407981ce Mon Sep 17 00:00:00 2001 From: Devansh Agarwal Date: Mon, 28 Apr 2025 01:18:31 +0530 Subject: [PATCH 03/23] Added suggested changes Signed-off-by: Devansh Agarwal --- src/anomalib/models/__init__.py | 2 + src/anomalib/models/image/__init__.py | 4 + src/anomalib/models/image/glass/__init__.py | 5 +- src/anomalib/models/image/glass/backbones.py | 6 + .../models/image/glass/lightning_model.py | 142 +++++++++++++----- src/anomalib/models/image/glass/loss.py | 17 ++- src/anomalib/models/image/glass/perlin.py | 88 +++++++++-- .../models/image/glass/torch_model.py | 112 +++++++++----- third-party-programs.txt | 4 + 9 files changed, 287 insertions(+), 93 deletions(-) diff --git 
a/src/anomalib/models/__init__.py b/src/anomalib/models/__init__.py index 78fc07245c..7b4d74f23d 100644 --- a/src/anomalib/models/__init__.py +++ b/src/anomalib/models/__init__.py @@ -66,6 +66,7 @@ Fastflow, Fre, Ganomaly, + Glass, Padim, Patchcore, ReverseDistillation, @@ -94,6 +95,7 @@ class UnknownModelError(ModuleNotFoundError): "Fastflow", "Fre", "Ganomaly", + "Glass", "Padim", "Patchcore", "ReverseDistillation", diff --git a/src/anomalib/models/image/__init__.py b/src/anomalib/models/image/__init__.py index 2717290f3a..d58f9805df 100644 --- a/src/anomalib/models/image/__init__.py +++ b/src/anomalib/models/image/__init__.py @@ -55,6 +55,7 @@ from .fastflow import Fastflow from .fre import Fre from .ganomaly import Ganomaly +from .glass import Glass from .padim import Padim from .patchcore import Patchcore from .reverse_distillation import ReverseDistillation @@ -76,6 +77,7 @@ "Fastflow", "Fre", "Ganomaly", + "Glass", "Padim", "Patchcore", "ReverseDistillation", @@ -84,4 +86,6 @@ "Uflow", "VlmAd", "WinClip", + "Padim", + "Glass", ] diff --git a/src/anomalib/models/image/glass/__init__.py b/src/anomalib/models/image/glass/__init__.py index 95b6a1c5b7..06a7d7a9c2 100644 --- a/src/anomalib/models/image/glass/__init__.py +++ b/src/anomalib/models/image/glass/__init__.py @@ -1 +1,4 @@ -from .lightning_model import Glass \ No newline at end of file +# Copyright (C) 2022-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .lightning_model import Glass as Glass diff --git a/src/anomalib/models/image/glass/backbones.py b/src/anomalib/models/image/glass/backbones.py index 8f404a9653..3795819ded 100644 --- a/src/anomalib/models/image/glass/backbones.py +++ b/src/anomalib/models/image/glass/backbones.py @@ -1,6 +1,12 @@ +# Copyright (C) 2022-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + import torchvision.models as models import timm +"""copy from: https://github.com/cqylunlun/GLASS/blob/main/backbones.py +This provides mechanism to import any of the given backbones using its name. 
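+
+Illustrative usage (a sketch; assumes the corresponding pretrained weights
+can be downloaded):
+
+    >>> backbone = load("resnet18")  # evaluates models.resnet18(pretrained=True)
+    >>> vit = load("vit_base")       # evaluates timm.create_model("vit_base_patch16_224", pretrained=True)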
+""" _BACKBONES = { "alexnet": "models.alexnet(pretrained=True)", "resnet18": "models.resnet18(pretrained=True)", diff --git a/src/anomalib/models/image/glass/lightning_model.py b/src/anomalib/models/image/glass/lightning_model.py index be45b61ac5..cfa8d7725f 100644 --- a/src/anomalib/models/image/glass/lightning_model.py +++ b/src/anomalib/models/image/glass/lightning_model.py @@ -1,45 +1,51 @@ -import torch -from torch import nn -from torch import optim +# Copyright (C) 2022-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any +import torch from lightning.pytorch.utilities.types import STEP_OUTPUT +from torch import nn, optim +from anomalib import LearningType from anomalib.data import Batch -from anomalib.models.components import AnomalibModule -from anomalib.models.components import AnomalibModule from anomalib.metrics import Evaluator +from anomalib.models.components import AnomalibModule from anomalib.post_processing import PostProcessor from anomalib.pre_processing import PreProcessor from anomalib.visualization import Visualizer from .loss import FocalLoss +from .perlin import PerlinNoise from .torch_model import GlassModel + class Glass(AnomalibModule): def __init__( - self, - backbone, - input_shape, - pretrain_embed_dim, - target_embed_dim, - patchsize: int = 3, - patchstride: int = 1, - pre_trained: bool = True, - layers: list[str] = ["layer1", "layer2", "layer3"], - pre_proj: int = 1, - dsc_layers: int = 2, - dsc_hidden: int = 1024, - dsc_margin: int = 0.5, - pre_processor: PreProcessor | bool = True, - post_processor: PostProcessor | bool = True, - evaluator: Evaluator | bool = True, - visualizer: Visualizer | bool = True, - mining: int = 1, - noise: float = 0.015, - radius: float = 0.75, - p: float = 0.5, - lr: int = 0.0001, - step: int = 0 + self, + input_shape, + anomaly_source_path: str, + backbone: str | nn.Module = "resnet18", + pretrain_embed_dim: int = 1024, + target_embed_dim: int = 1024, + patchsize: int = 3, + patchstride: int = 1, + pre_trained: bool = True, + layers: list[str] = ["layer1", "layer2", "layer3"], + pre_proj: int = 1, + dsc_layers: int = 2, + dsc_hidden: int = 1024, + dsc_margin: float = 0.5, + pre_processor: PreProcessor | bool = True, + post_processor: PostProcessor | bool = True, + evaluator: Evaluator | bool = True, + visualizer: Visualizer | bool = True, + mining: int = 1, + noise: float = 0.015, + radius: float = 0.75, + p: float = 0.5, + lr: float = 0.0001, + step: int = 20, ): super().__init__( pre_processor=pre_processor, @@ -48,6 +54,8 @@ def __init__( visualizer=visualizer, ) + self.perlin = PerlinNoise(anomaly_source_path) + self.model = GlassModel( input_shape=input_shape, pretrain_embed_dim=pretrain_embed_dim, @@ -60,9 +68,10 @@ def __init__( pre_proj=pre_proj, dsc_layers=dsc_layers, dsc_hidden=dsc_hidden, - dsc_margin=dsc_margin + dsc_margin=dsc_margin, ) + self.c = torch.tensor([1]) self.p = p self.radius = radius self.mining = mining @@ -91,11 +100,11 @@ def configure_optimizers(self) -> list[optim.Optimizer]: optimizers.append(dsc_opt) return optimizers - + def training_step( - self, - batch: Batch, - batch_idx: int + self, + batch: Batch, + batch_idx: int, ) -> STEP_OUTPUT: backbone_opt, proj_opt, dsc_opt = self.optimizers() @@ -111,24 +120,35 @@ def training_step( backbone_opt.zero_grad() img = batch.image - aug = batch.aug + aug, mask_s = self.perlin(img) true_feats, fake_feats = self.model(img, aug) - mask_s_gt = batch.mask_s.reshape(-1, 1) + mask_s_gt = mask_s.reshape(-1, 1) noise = 
torch.normal(0, self.noise, true_feats.shape) gaus_feats = true_feats + noise + center = self.c.repeat(img.shape[0], 1, 1) + center = center.reshape(-1, center.shape[-1]) + true_points = torch.concat([fake_feats[mask_s_gt[:, 0] == 0], true_feats], dim=0) + c_t_points = torch.concat([center[mask_s_gt[:, 0] == 0], center], dim=0) + dist_t = torch.norm(true_points - c_t_points, dim=1) + r_t = torch.tensor([torch.quantile(dist_t, q=self.radius)]).to(self.device) + for step in range(self.step + 1): scores = self.model.discriminator(torch.cat([true_feats, gaus_feats])) - true_scores = scores[:len(true_feats)] - gaus_scores = scores[len(true_feats):] + true_scores = scores[: len(true_feats)] + gaus_scores = scores[len(true_feats) :] true_loss = nn.BCELoss()(true_scores, torch.zeros_like(true_scores)) gaus_loss = nn.BCELoss()(gaus_scores, torch.ones_like(gaus_scores)) bce_loss = true_loss + gaus_loss if step == self.step: break + if self.mining == 0: + dist_g = torch.norm(gaus_feats - center, dim=1) + r_g = torch.tensor([torch.quantile(dist_g, q=self.radius)]) + break grad = torch.autograd.grad(gaus_loss, [gaus_feats])[0] grad_norm = torch.norm(grad, dim=1) @@ -138,12 +158,29 @@ def training_step( with torch.no_grad(): gaus_feats.add_(0.001 * grad_normalized) + fake_points = fake_feats[mask_s_gt[:, 0] == 1] + true_points = true_feats[mask_s_gt[:, 0] == 1] + c_f_points = center[mask_s_gt[:, 0] == 1] + dist_f = torch.norm(fake_points - c_f_points, dim=1) + r_f = torch.tensor([torch.quantile(dist_f, q=self.radius)]).to(self.device) + proj_feats = c_f_points if self.svd == 1 else true_points + r = r_t if self.svd == 1 else 1 + + if self.svd == 1: + h = fake_points - proj_feats + h_norm = dist_f if self.svd == 1 else torch.norm(h, dim=1) + alpha = torch.clamp(h_norm, 2 * r, 4 * r) + proj = (alpha / (h_norm + 1e-10)).view(-1, 1) + h = proj * h + fake_points = proj_feats + h + fake_feats[mask_s_gt[:, 0] == 1] = fake_points + fake_scores = self.model.discriminator(fake_feats) if self.p > 0: fake_dist = (fake_scores - mask_s_gt) ** 2 d_hard = torch.quantile(fake_dist, q=self.p) - take_scores_ = fake_scores[fake_dist >= d_hard].unsqueeze(1) + fake_scores_ = fake_scores[fake_dist >= d_hard].unsqueeze(1) mask_ = mask_s_gt[fake_dist >= d_hard].unsqueeze(1) else: fake_scores_ = fake_scores @@ -158,4 +195,31 @@ def training_step( proj_opt.step() if backbone_opt is not None: backbone_opt.step() - dsc_opt.step() \ No newline at end of file + dsc_opt.step() + + def on_train_start(self) -> None: + dataloader = self.trainer.train_dataloader + + with torch.no_grad(): + for i, batch in enumerate(dataloader): + if i == 0: + self.c = self.model.calculate_mean(batch.image) + else: + self.c += self.model.calculate_mean(batch.image) + + self.c /= len(dataloader) + + @property + def learning_type(self) -> LearningType: + """Return the learning type of the model. 
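+        GLASS trains only on normal images and synthesises anomalous samples
+        on the fly with Perlin-noise augmentation, so it is a one-class model.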
+ + Returns: + LearningType: Learning type (ONE_CLASS for GLASS) + """ + return LearningType.ONE_CLASS + + @property + def trainer_arguments(self) -> dict[str, Any]: + """Return GLASS trainer arguments.""" + return {"gradient_clip_val": 0, "num_sanity_val_steps": 0} + # TODO diff --git a/src/anomalib/models/image/glass/loss.py b/src/anomalib/models/image/glass/loss.py index 4a9dd33368..f1f5341d47 100644 --- a/src/anomalib/models/image/glass/loss.py +++ b/src/anomalib/models/image/glass/loss.py @@ -1,11 +1,13 @@ -import torch +# Copyright (C) 2022-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + import numpy as np -import torch.nn as nn +import torch +from torch import nn class FocalLoss(nn.Module): - """ - copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py + """copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in 'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)' Focal_Loss= -1*alpha*(1-pt)*log(pt) @@ -29,7 +31,7 @@ def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smoo if self.smooth is not None: if self.smooth < 0 or self.smooth > 1.0: - raise ValueError('smooth value should be in [0,1]') + raise ValueError("smooth value should be in [0,1]") def forward(self, logit, target): if self.apply_nonlin is not None: @@ -56,7 +58,7 @@ def forward(self, logit, target): alpha[self.balance_index] = self.alpha else: - raise TypeError('Not support alpha type') + raise TypeError("Not support alpha type") if alpha.device != logit.device: alpha = alpha.to(logit.device) @@ -69,8 +71,7 @@ def forward(self, logit, target): one_hot_key = one_hot_key.to(logit.device) if self.smooth: - one_hot_key = torch.clamp( - one_hot_key, self.smooth / (num_class - 1), 1.0 - self.smooth) + one_hot_key = torch.clamp(one_hot_key, self.smooth / (num_class - 1), 1.0 - self.smooth) pt = (one_hot_key * logit).sum(1) + self.smooth logpt = pt.log() diff --git a/src/anomalib/models/image/glass/perlin.py b/src/anomalib/models/image/glass/perlin.py index 0c0f27a328..2373d6f05d 100644 --- a/src/anomalib/models/image/glass/perlin.py +++ b/src/anomalib/models/image/glass/perlin.py @@ -1,7 +1,19 @@ +# Copyright (C) 2022-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import glob +import math + import imgaug.augmenters as iaa import numpy as np +import PIL +import PIL.Image import torch -import math +from torch import nn +from torchvision import transforms + +IMAGENET_MEAN = [0.485, 0.456, 0.406] +IMAGENET_STD = [0.229, 0.224, 0.225] def generate_thr(img_shape, min=0, max=4): @@ -34,15 +46,17 @@ def perlin_mask(img_shape, feat_size, min, max, mask_fg, flag=0): down_ratio_y = int(img_shape[1] / feat_size) down_ratio_x = int(img_shape[2] / feat_size) mask_ = perlin_thr_fg - mask = torch.nn.functional.max_pool2d(perlin_thr_fg.unsqueeze(0).unsqueeze(0), (down_ratio_y, down_ratio_x)).float() + mask = torch.nn.functional.max_pool2d( + perlin_thr_fg.unsqueeze(0).unsqueeze(0), + (down_ratio_y, down_ratio_x), + ).float() mask = mask.numpy()[0, 0] mask_s = mask if flag != 0: mask_l = mask_.numpy() if flag == 0: return mask_s - else: - return mask_s, mask_l + return mask_s, mask_l def lerp_np(x, y, w): @@ -50,24 +64,76 @@ def lerp_np(x, y, w): return fin_out -def rand_perlin_2d_np(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3): +def rand_perlin_2d_np(shape, res, 
fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3): delta = (res[0] / shape[0], res[1] / shape[1]) d = (shape[0] // res[0], shape[1] // res[1]) - grid = np.mgrid[0:res[0]:delta[0], 0:res[1]:delta[1]].transpose(1, 2, 0) % 1 + grid = np.mgrid[0 : res[0] : delta[0], 0 : res[1] : delta[1]].transpose(1, 2, 0) % 1 angles = 2 * math.pi * np.random.rand(res[0] + 1, res[1] + 1) gradients = np.stack((np.cos(angles), np.sin(angles)), axis=-1) tt = np.repeat(np.repeat(gradients, d[0], axis=0), d[1], axis=1) - tile_grads = lambda slice1, slice2: np.repeat(np.repeat(gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]], d[0], axis=0), d[1], - axis=1) + tile_grads = lambda slice1, slice2: np.repeat( + np.repeat(gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]], d[0], axis=0), + d[1], + axis=1, + ) dot = lambda grad, shift: ( - np.stack((grid[:shape[0], :shape[1], 0] + shift[0], grid[:shape[0], :shape[1], 1] + shift[1]), - axis=-1) * grad[:shape[0], :shape[1]]).sum(axis=-1) + np.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), axis=-1) + * grad[: shape[0], : shape[1]] + ).sum(axis=-1) n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]) n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]) n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]) n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]) - t = fade(grid[:shape[0], :shape[1]]) + t = fade(grid[: shape[0], : shape[1]]) return math.sqrt(2) * lerp_np(lerp_np(n00, n10, t[..., 0]), lerp_np(n01, n11, t[..., 0]), t[..., 1]) + + +class PerlinNoise(nn.Module): + def __init__(self, anomaly_source_path): + super().__init__() + self.anomaly_source_paths = sorted(glob.glob(anomaly_source_path + "/*/*.jpg")) + + def forward(self, image): + aug = PIL.Image.open(np.random.choice(self.anomaly_source_paths)).convert("RGB") + transform_aug = self.rand_augmenter() + aug = transform_aug(aug) + + mask_all = perlin_mask(image.shape, self.imgsize // self.downsampling, 0, 6, mask_fg, 1) + mask_s = torch.from_numpy(mask_all[0]) + mask_l = torch.from_numpy(mask_all[1]) + mask_fg = torch.tensor([1]) + + beta = np.random.normal(loc=self.mean, scale=self.std) + beta = np.clip(beta, 0.2, 0.8) + aug_image = image * (1 - mask_l) + (1 - beta) * aug * mask_l + beta * image * mask_l + return aug_image, mask_s + + def rand_augmenter(self): + list_aug = [ + transforms.ColorJitter(contrast=(0.8, 1.2)), + transforms.ColorJitter(brightness=(0.8, 1.2)), + transforms.ColorJitter(saturation=(0.8, 1.2), hue=(-0.2, 0.2)), + transforms.RandomHorizontalFlip(p=1), + transforms.RandomVerticalFlip(p=1), + transforms.RandomGrayscale(p=1), + transforms.RandomAutocontrast(p=1), + transforms.RandomEqualize(p=1), + transforms.RandomAffine(degrees=(-45, 45)), + ] + aug_idx = np.random.choice(np.arange(len(list_aug)), 3, replace=False) + + transform_aug = [ + transforms.Resize(self.resize), + list_aug[aug_idx[0]], + list_aug[aug_idx[1]], + list_aug[aug_idx[2]], + transforms.CenterCrop(self.imgsize), + transforms.ToTensor(), + transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), + ] + + transform_aug = transforms.Compose(transform_aug) + return transform_aug diff --git a/src/anomalib/models/image/glass/torch_model.py b/src/anomalib/models/image/glass/torch_model.py index 5b952f5e9c..d9c559be3a 100644 --- a/src/anomalib/models/image/glass/torch_model.py +++ b/src/anomalib/models/image/glass/torch_model.py @@ -1,10 +1,16 @@ +# Copyright (C) 2022-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import math + +import numpy as np import torch -from 
torch import nn import torch.nn.functional as F +from scipy import ndimage +from torch import nn + from anomalib.models.components import NetworkFeatureAggregator -import math -import scipy.ndimage as ndimage -import numpy as np + def init_weight(m): if isinstance(m, torch.nn.Linear): @@ -15,6 +21,7 @@ def init_weight(m): elif isinstance(m, torch.nn.Conv2d): m.weight.data.normal_(0.0, 0.02) + class Preprocessing(torch.nn.Module): def __init__(self, input_dims, output_dim): super(Preprocessing, self).__init__() @@ -28,7 +35,7 @@ def __init__(self, input_dims, output_dim): def forward(self, features): _features = [] - for module, feature in zip(self.preprocessing_modules, features): + for module, feature in zip(self.preprocessing_modules, features, strict=False): _features.append(module(feature)) return torch.stack(_features, dim=1) @@ -54,6 +61,7 @@ def forward(self, features): features = F.adaptive_avg_pool1d(features, self.target_dim) return features.reshape(len(features), -1) + class Projection(torch.nn.Module): def __init__(self, in_planes, out_planes=None, n_layers=1, layer_type=0): super(Projection, self).__init__() @@ -69,14 +77,14 @@ def __init__(self, in_planes, out_planes=None, n_layers=1, layer_type=0): self.layers.add_module(f"{i}fc", torch.nn.Linear(_in, _out)) if i < n_layers - 1: if layer_type > 1: - self.layers.add_module(f"{i}relu", torch.nn.LeakyReLU(.2)) + self.layers.add_module(f"{i}relu", torch.nn.LeakyReLU(0.2)) self.apply(init_weight) def forward(self, x): - x = self.layers(x) return x - + + class Discriminator(torch.nn.Module): def __init__(self, in_planes, n_layers=2, hidden=None): super(Discriminator, self).__init__() @@ -86,14 +94,15 @@ def __init__(self, in_planes, n_layers=2, hidden=None): for i in range(n_layers - 1): _in = in_planes if i == 0 else _hidden _hidden = int(_hidden // 1.5) if hidden is None else hidden - self.body.add_module('block%d' % (i + 1), - torch.nn.Sequential( - torch.nn.Linear(_in, _hidden), - torch.nn.BatchNorm1d(_hidden), - torch.nn.LeakyReLU(0.2) - )) - self.tail = torch.nn.Sequential(torch.nn.Linear(_hidden, 1, bias=False), - torch.nn.Sigmoid()) + self.body.add_module( + "block%d" % (i + 1), + torch.nn.Sequential( + torch.nn.Linear(_in, _hidden), + torch.nn.BatchNorm1d(_hidden), + torch.nn.LeakyReLU(0.2), + ), + ) + self.tail = torch.nn.Sequential(torch.nn.Linear(_hidden, 1, bias=False), torch.nn.Sigmoid()) self.apply(init_weight) def forward(self, x): @@ -101,6 +110,7 @@ def forward(self, x): x = self.tail(x) return x + class PatchMaker: def __init__(self, patchsize, top_k=0, stride=None): self.patchsize = patchsize @@ -116,7 +126,10 @@ def patchify(self, features, return_spatial_info=False): n_patches = (s + 2 * padding - 1 * (self.patchsize - 1) - 1) / self.stride + 1 number_of_total_patches.append(int(n_patches)) unfolded_features = unfolded_features.reshape( - *features.shape[:2], self.patchsize, self.patchsize, -1 + *features.shape[:2], + self.patchsize, + self.patchsize, + -1, ) unfolded_features = unfolded_features.permute(0, 4, 1, 2, 3) @@ -132,6 +145,7 @@ def score(self, x): x = torch.max(x, dim=1).values return x + class RescaleSegmentor: def __init__(self, target_size=288): self.target_size = target_size @@ -144,7 +158,10 @@ def convert_to_segmentation(self, patch_scores): _scores = patch_scores _scores = _scores.unsqueeze(1) _scores = F.interpolate( - _scores, size=self.target_size, mode="bilinear", align_corners=False + _scores, + size=self.target_size, + mode="bilinear", + align_corners=False, ) _scores = 
_scores.squeeze(1) patch_scores = _scores.cpu().numpy() @@ -155,17 +172,17 @@ class GlassModel(nn.Module): def __init__( self, input_shape, - pretrain_embed_dim, - target_embed_dim, - backbone: nn.Module, + pretrain_embed_dim: int = 1024, + target_embed_dim: int = 1024, + backbone: str | nn.Module = "resnet18", patchsize: int = 3, patchstride: int = 1, pre_trained: bool = True, layers: list[str] = ["layer1", "layer2", "layer3"], pre_proj: int = 1, - dsc_layers:int = 2, - dsc_hidden:int = 1024, - dsc_margin:int = 0.5 + dsc_layers: int = 2, + dsc_hidden: int = 1024, + dsc_margin: float = 0.5, ) -> None: super().__init__() self.backbone = backbone @@ -174,7 +191,9 @@ def __init__( self.forward_modules = torch.ModuleDict({}) feature_aggregator = NetworkFeatureAggregator( - self.backbone, self.layers, pre_trained + self.backbone, + self.layers, + pre_trained, ) feature_dimensions = feature_aggregator.feature_dimensions(input_shape) self.forward_modules["feature_aggregator"] = feature_aggregator @@ -194,12 +213,32 @@ def __init__( self.dsc_layers = dsc_layers self.dsc_hidden = dsc_hidden self.dsc_margin = dsc_margin - self.discriminator = Discriminator(self.target_embed_dimension, n_layers=self.dsc_layers, hidden=self.dsc_hidden) + self.discriminator = Discriminator( + self.target_embed_dimension, + n_layers=self.dsc_layers, + hidden=self.dsc_hidden, + ) self.patch_maker = PatchMaker(patchsize, stride=patchstride) self.anomaly_segmentor = RescaleSegmentor() + def calculate_mean(self, images): + self.forward_modules.eval() + with torch.no_grad(): + if self.pre_proj > 0: + outputs = self.pre_projection(self.generate_embeddings(images)) + outputs = outputs[0] if len(outputs) == 2 else outputs + else: + outputs = self._embed(images, evaluation=False)[0] + + outputs = outputs[0] if len(outputs) == 2 else outputs + outputs = outputs.reshape(images.shape[0], -1, outputs.shape[-1]) + + batch_mean = torch.mean(outputs, dim=0) + + return batch_mean + def generate_embeddings(self, images, provide_patch_shapes=False, eval=False): if not eval and not self.pre_trained: self.forward_modules["feature_aggregator"].train() @@ -208,7 +247,7 @@ def generate_embeddings(self, images, provide_patch_shapes=False, eval=False): self.forward_modules["feature_aggregator"].eval() with torch.no_grad(): features = self.forward_modules["feature_aggregator"](images) - + features = [features[layer] for layer in self.layers_to_extract_from] for i, feat in enumerate(features): if len(feat.shape) == 3: @@ -225,7 +264,10 @@ def generate_embeddings(self, images, provide_patch_shapes=False, eval=False): patch_dims = patch_shapes[i] _features = _features.reshape( - _features.shape[0], patch_dims[0], patch_dims[1], *_features.shape[2:] + _features.shape[0], + patch_dims[0], + patch_dims[1], + *_features.shape[2:], ) _features = _features.permute(0, 3, 4, 5, 1, 2) perm_base_shape = _features.shape @@ -238,7 +280,9 @@ def generate_embeddings(self, images, provide_patch_shapes=False, eval=False): ) _features = _features.squeeze(1) _features = _features.reshape( - *perm_base_shape[:-2], ref_num_patches[0], ref_num_patches[1] + *perm_base_shape[:-2], + ref_num_patches[0], + ref_num_patches[1], ) _features = _features.permute(0, 4, 5, 1, 2, 3) _features = _features.reshape(len(_features), -1, *_features.shape[-3:]) @@ -249,13 +293,13 @@ def generate_embeddings(self, images, provide_patch_shapes=False, eval=False): patch_features = self.forward_modules["preadapt_aggregator"](patch_features) return patch_features, patch_shapes - + def 
forward(self, img, aug, eval=False): if self.pre_proj > 0: - fake_feats = self.pre_projection(self.generate_embeddings(aug, evaluation=eval)[0]) - fake_feats = fake_feats[0] if len(fake_feats) == 2 else fake_feats - true_feats = self.pre_projection(self.generate_embeddings(img, evaluation=eval)[0]) - true_feats = true_feats[0] if len(true_feats) == 2 else true_feats + fake_feats = self.pre_projection(self.generate_embeddings(aug, evaluation=eval)[0]) + fake_feats = fake_feats[0] if len(fake_feats) == 2 else fake_feats + true_feats = self.pre_projection(self.generate_embeddings(img, evaluation=eval)[0]) + true_feats = true_feats[0] if len(true_feats) == 2 else true_feats else: fake_feats = self.generate_embeddings(aug, evaluation=eval)[0] fake_feats.requires_grad = True diff --git a/third-party-programs.txt b/third-party-programs.txt index 5eeaca8ea9..751477a0c9 100644 --- a/third-party-programs.txt +++ b/third-party-programs.txt @@ -46,3 +46,7 @@ terms are listed below. 8. AUPIMO metric implementation is based on the original code Copyright (c) 2023 @jpcbertoldo, https://github.com/jpcbertoldo/aupimo SPDX-License-Identifier: MIT + +9. GLASS Model implementation is based on the original code + Copyright (c) 2024 Qiyu Chen, https://github.com/cqylunlun/GLASS + SPDX-License-Identifier: MIT From cdd09841867a8bee1b40130109dedb067efe52e6 Mon Sep 17 00:00:00 2001 From: Devansh Agarwal Date: Mon, 28 Apr 2025 01:24:54 +0530 Subject: [PATCH 04/23] Modified forward method for model Signed-off-by: Devansh Agarwal --- src/anomalib/models/image/glass/torch_model.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/anomalib/models/image/glass/torch_model.py b/src/anomalib/models/image/glass/torch_model.py index d9c559be3a..7cf8814f73 100644 --- a/src/anomalib/models/image/glass/torch_model.py +++ b/src/anomalib/models/image/glass/torch_model.py @@ -294,16 +294,16 @@ def generate_embeddings(self, images, provide_patch_shapes=False, eval=False): return patch_features, patch_shapes - def forward(self, img, aug, eval=False): + def forward(self, img, aug, evaluation=False): if self.pre_proj > 0: - fake_feats = self.pre_projection(self.generate_embeddings(aug, evaluation=eval)[0]) + fake_feats = self.pre_projection(self.generate_embeddings(aug, evaluation=evaluation)[0]) fake_feats = fake_feats[0] if len(fake_feats) == 2 else fake_feats - true_feats = self.pre_projection(self.generate_embeddings(img, evaluation=eval)[0]) + true_feats = self.pre_projection(self.generate_embeddings(img, evaluation=evaluation)[0]) true_feats = true_feats[0] if len(true_feats) == 2 else true_feats else: - fake_feats = self.generate_embeddings(aug, evaluation=eval)[0] + fake_feats = self.generate_embeddings(aug, evaluation=evaluation)[0] fake_feats.requires_grad = True - true_feats = self.generate_embeddings(img, evaluation=eval)[0] + true_feats = self.generate_embeddings(img, evaluation=evaluation)[0] true_feats.requires_grad = True return true_feats, fake_feats From 381eec63668e33639f51b35e747442993b9e88db Mon Sep 17 00:00:00 2001 From: Devansh Agarwal Date: Wed, 30 Apr 2025 21:15:22 +0530 Subject: [PATCH 05/23] Fixed backbone loading logic Signed-off-by: Devansh Agarwal --- src/anomalib/models/image/glass/torch_model.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/anomalib/models/image/glass/torch_model.py b/src/anomalib/models/image/glass/torch_model.py index 7cf8814f73..9ff4463040 100644 --- a/src/anomalib/models/image/glass/torch_model.py +++ 
b/src/anomalib/models/image/glass/torch_model.py @@ -11,6 +11,8 @@ from anomalib.models.components import NetworkFeatureAggregator +from .backbones import load + def init_weight(m): if isinstance(m, torch.nn.Linear): @@ -168,6 +170,12 @@ def convert_to_segmentation(self, patch_scores): return [ndimage.gaussian_filter(patch_score, sigma=self.smoothing) for patch_score in patch_scores] +def process_backbone(backbone): + if isinstance(backbone, str): + return load(backbone) + return backbone + + class GlassModel(nn.Module): def __init__( self, @@ -185,7 +193,8 @@ def __init__( dsc_margin: float = 0.5, ) -> None: super().__init__() - self.backbone = backbone + + self.backbone = process_backbone(backbone) self.layers = layers self.input_shape = input_shape From 9b1c51a41c7e802fbec11d3b46e6c6db4ea49741 Mon Sep 17 00:00:00 2001 From: Devansh Agarwal Date: Sun, 4 May 2025 22:43:30 +0530 Subject: [PATCH 06/23] Added type for input shape Signed-off-by: Devansh Agarwal --- src/anomalib/models/image/glass/lightning_model.py | 2 +- src/anomalib/models/image/glass/torch_model.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/anomalib/models/image/glass/lightning_model.py b/src/anomalib/models/image/glass/lightning_model.py index cfa8d7725f..28c612a556 100644 --- a/src/anomalib/models/image/glass/lightning_model.py +++ b/src/anomalib/models/image/glass/lightning_model.py @@ -23,7 +23,7 @@ class Glass(AnomalibModule): def __init__( self, - input_shape, + input_shape: tuple[int, int, int], anomaly_source_path: str, backbone: str | nn.Module = "resnet18", pretrain_embed_dim: int = 1024, diff --git a/src/anomalib/models/image/glass/torch_model.py b/src/anomalib/models/image/glass/torch_model.py index 9ff4463040..748fec8dcd 100644 --- a/src/anomalib/models/image/glass/torch_model.py +++ b/src/anomalib/models/image/glass/torch_model.py @@ -179,7 +179,7 @@ def process_backbone(backbone): class GlassModel(nn.Module): def __init__( self, - input_shape, + input_shape: tuple[int, int, int], pretrain_embed_dim: int = 1024, target_embed_dim: int = 1024, backbone: str | nn.Module = "resnet18", From 161005ccd8da7caeefa44b1b9177d46fb7b57188 Mon Sep 17 00:00:00 2001 From: Devansh Agarwal Date: Sun, 4 May 2025 23:31:57 +0530 Subject: [PATCH 07/23] Fixed bugs Signed-off-by: Devansh Agarwal --- .../network_feature_extractor.py | 3 +- .../models/image/glass/lightning_model.py | 45 +++++++++---------- .../models/image/glass/torch_model.py | 4 +- 3 files changed, 24 insertions(+), 28 deletions(-) diff --git a/src/anomalib/models/components/feature_extractors/network_feature_extractor.py b/src/anomalib/models/components/feature_extractors/network_feature_extractor.py index ea4a96dea6..eec3ad5fc7 100644 --- a/src/anomalib/models/components/feature_extractors/network_feature_extractor.py +++ b/src/anomalib/models/components/feature_extractors/network_feature_extractor.py @@ -29,7 +29,6 @@ def __init__(self, backbone, layers_to_extract_from, pre_trained=False): for extract_layer in layers_to_extract_from: self.register_hook(extract_layer) - self.to(self.device) def forward(self, images, eval=True): self.outputs.clear() @@ -45,7 +44,7 @@ def forward(self, images, eval=True): def feature_dimensions(self, input_shape): """Computes the feature dimensions for all layers given input_shape.""" - _input = torch.ones([1] + list(input_shape)).to(self.device) + _input = torch.ones([1] + list(input_shape)) _output = self(_input) return [_output[layer].shape[1] for layer in self.layers_to_extract_from] diff --git 
a/src/anomalib/models/image/glass/lightning_model.py b/src/anomalib/models/image/glass/lightning_model.py index 28c612a556..ce14799fe9 100644 --- a/src/anomalib/models/image/glass/lightning_model.py +++ b/src/anomalib/models/image/glass/lightning_model.py @@ -16,9 +16,10 @@ from anomalib.visualization import Visualizer from .loss import FocalLoss -from .perlin import PerlinNoise from .torch_model import GlassModel +from anomalib.data.utils.generators.perlin import PerlinAnomalyGenerator + class Glass(AnomalibModule): def __init__( @@ -54,7 +55,7 @@ def __init__( visualizer=visualizer, ) - self.perlin = PerlinNoise(anomaly_source_path) + self.augmentor = PerlinAnomalyGenerator(anomaly_source_path) self.model = GlassModel( input_shape=input_shape, @@ -82,31 +83,27 @@ def __init__( self.focal_loss = FocalLoss() - def configure_optimizers(self) -> list[optim.Optimizer]: - optimizers = [] - if not self.model.pre_trained: - backbone_opt = optim.AdamW(self.model.foward_modules["feature_aggregator"].backbone.parameters(), self.lr) - optimizers.append(backbone_opt) + if pre_proj > 0: + self.proj_opt = optim.AdamW(self.model.pre_projection.parameters(), self.lr, weight_decay=1e-5) else: - optimizers.append(None) + self.proj_opt = None - if self.model.pre_proj > 0: - proj_opt = optim.AdamW(self.model.pre_projection.parameters(), self.lr, weight_decay=1e-5) - optimizers.append(proj_opt) + if not pre_trained: + self.backbone_opt = optim.AdamW(self.model.foward_modules["feature_aggregator"].backbone.parameters(), self.lr) else: - optimizers.append(None) + self.backbone_opt = None + def configure_optimizers(self) -> list[optim.Optimizer]: dsc_opt = optim.AdamW(self.model.discriminator.parameters(), lr=self.lr * 2) - optimizers.append(dsc_opt) - return optimizers + return dsc_opt def training_step( self, batch: Batch, batch_idx: int, ) -> STEP_OUTPUT: - backbone_opt, proj_opt, dsc_opt = self.optimizers() + dsc_opt = self.optimizers() self.model.forward_modules.eval() if self.model.pre_proj > 0: @@ -114,13 +111,13 @@ def training_step( self.model.discriminator.train() dsc_opt.zero_grad() - if proj_opt is not None: - proj_opt.zero_grad() - if backbone_opt is not None: - backbone_opt.zero_grad() + if self.proj_opt is not None: + self.proj_opt.zero_grad() + if self.backbone_opt is not None: + self.backbone_opt.zero_grad() img = batch.image - aug, mask_s = self.perlin(img) + aug, mask_s = self.augmentor(img) true_feats, fake_feats = self.model(img, aug) @@ -191,10 +188,10 @@ def training_step( loss = bce_loss + focal_loss loss.backward() - if proj_opt is not None: - proj_opt.step() - if backbone_opt is not None: - backbone_opt.step() + if self.proj_opt is not None: + self.proj_opt.step() + if self.backbone_opt is not None: + self.backbone_opt.step() dsc_opt.step() def on_train_start(self) -> None: diff --git a/src/anomalib/models/image/glass/torch_model.py b/src/anomalib/models/image/glass/torch_model.py index 748fec8dcd..cdcf95f503 100644 --- a/src/anomalib/models/image/glass/torch_model.py +++ b/src/anomalib/models/image/glass/torch_model.py @@ -198,7 +198,7 @@ def __init__( self.layers = layers self.input_shape = input_shape - self.forward_modules = torch.ModuleDict({}) + self.forward_modules = torch.nn.ModuleDict({}) feature_aggregator = NetworkFeatureAggregator( self.backbone, self.layers, @@ -257,7 +257,7 @@ def generate_embeddings(self, images, provide_patch_shapes=False, eval=False): with torch.no_grad(): features = self.forward_modules["feature_aggregator"](images) - features = [features[layer] for 
layer in self.layers_to_extract_from] + features = [features[layer] for layer in self.layers] for i, feat in enumerate(features): if len(feat.shape) == 3: B, L, C = feat.shape From 617cf498ffcdc6b2b293b10d2f9f016d037af923 Mon Sep 17 00:00:00 2001 From: Devansh Agarwal Date: Tue, 13 May 2025 20:41:23 +0530 Subject: [PATCH 08/23] Changed files as needed Signed-off-by: Devansh Agarwal --- src/anomalib/models/components/__init__.py | 2 +- .../components/feature_extractors/__init__.py | 1 - .../network_feature_extractor.py | 90 ------------ src/anomalib/models/image/__init__.py | 4 +- src/anomalib/models/image/glass/__init__.py | 2 +- src/anomalib/models/image/glass/backbones.py | 56 ------- .../models/image/glass/lightning_model.py | 90 +++++++++++- src/anomalib/models/image/glass/loss.py | 50 +++++-- src/anomalib/models/image/glass/perlin.py | 139 ------------------ .../models/image/glass/torch_model.py | 91 ++++++++---- 10 files changed, 181 insertions(+), 344 deletions(-) delete mode 100644 src/anomalib/models/components/feature_extractors/network_feature_extractor.py delete mode 100644 src/anomalib/models/image/glass/backbones.py delete mode 100644 src/anomalib/models/image/glass/perlin.py diff --git a/src/anomalib/models/components/__init__.py b/src/anomalib/models/components/__init__.py index 24057cca0d..a9108c1c16 100644 --- a/src/anomalib/models/components/__init__.py +++ b/src/anomalib/models/components/__init__.py @@ -38,7 +38,7 @@ from .base import AnomalibModule, BufferListMixin, DynamicBufferMixin, MemoryBankMixin from .dimensionality_reduction import PCA, SparseRandomProjection -from .feature_extractors import TimmFeatureExtractor, NetworkFeatureAggregator +from .feature_extractors import TimmFeatureExtractor from .filters import GaussianBlur2d from .sampling import KCenterGreedy from .stats import GaussianKDE, MultiVariateGaussian diff --git a/src/anomalib/models/components/feature_extractors/__init__.py b/src/anomalib/models/components/feature_extractors/__init__.py index c0ed169f4a..b9936e793d 100644 --- a/src/anomalib/models/components/feature_extractors/__init__.py +++ b/src/anomalib/models/components/feature_extractors/__init__.py @@ -28,7 +28,6 @@ from .timm import TimmFeatureExtractor from .utils import dryrun_find_featuremap_dims -from .network_feature_extractor import NetworkFeatureAggregator __all__ = [ "dryrun_find_featuremap_dims", "TimmFeatureExtractor", diff --git a/src/anomalib/models/components/feature_extractors/network_feature_extractor.py b/src/anomalib/models/components/feature_extractors/network_feature_extractor.py deleted file mode 100644 index eec3ad5fc7..0000000000 --- a/src/anomalib/models/components/feature_extractors/network_feature_extractor.py +++ /dev/null @@ -1,90 +0,0 @@ -import torch -from torch import nn -import copy - - -class NetworkFeatureAggregator(torch.nn.Module): - """Efficient extraction of network features.""" - - def __init__(self, backbone, layers_to_extract_from, pre_trained=False): - super(NetworkFeatureAggregator, self).__init__() - """Extraction of network features. - - Runs a network only to the last layer of the list of layers where - network features should be extracted from. 
- - Args: - backbone: torchvision.model - layers_to_extract_from: [list of str] - """ - self.layers_to_extract_from = layers_to_extract_from - self.backbone = backbone - self.pre_trained = pre_trained - if not hasattr(backbone, "hook_handles"): - self.backbone.hook_handles = [] - for handle in self.backbone.hook_handles: - handle.remove() - self.outputs = {} - - for extract_layer in layers_to_extract_from: - self.register_hook(extract_layer) - - - def forward(self, images, eval=True): - self.outputs.clear() - if not self.pre_trained and not eval: - self.backbone(images) - else: - with torch.no_grad(): - try: - _ = self.backbone(images) - except LastLayerToExtractReachedException: - pass - return self.outputs - - def feature_dimensions(self, input_shape): - """Computes the feature dimensions for all layers given input_shape.""" - _input = torch.ones([1] + list(input_shape)) - _output = self(_input) - return [_output[layer].shape[1] for layer in self.layers_to_extract_from] - - def register_hook(self, layer_name): - module = self.find_module(self.backbone, layer_name) - if module is not None: - forward_hook = ForwardHook( - self.outputs, layer_name, self.layers_to_extract_from[-1] - ) - if isinstance(module, torch.nn.Sequential): - hook = module[-1].register_forward_hook(forward_hook) - else: - hook = module.register_forward_hook(forward_hook) - self.backbone.hook_handles.append(hook) - else: - raise ValueError(f"Module {layer_name} not found in the model") - - def find_module(self, model, module_name): - for name, module in model.named_modules(): - if name == module_name: - return module - elif "." in module_name: - father, child = module_name.split(".", 1) - if name == father: - return self.find_module(module, child) - return None - - -class ForwardHook: - def __init__(self, hook_dict, layer_name: str, last_layer_to_extract: str): - self.hook_dict = hook_dict - self.layer_name = layer_name - self.raise_exception_to_break = copy.deepcopy( - layer_name == last_layer_to_extract - ) - - def __call__(self, module, input, output): - self.hook_dict[self.layer_name] = output - return None - - -class LastLayerToExtractReachedException(Exception): - pass diff --git a/src/anomalib/models/image/__init__.py b/src/anomalib/models/image/__init__.py index d58f9805df..6da89cbb12 100644 --- a/src/anomalib/models/image/__init__.py +++ b/src/anomalib/models/image/__init__.py @@ -85,7 +85,5 @@ "Supersimplenet", "Uflow", "VlmAd", - "WinClip", - "Padim", - "Glass", + "WinClip" ] diff --git a/src/anomalib/models/image/glass/__init__.py b/src/anomalib/models/image/glass/__init__.py index 06a7d7a9c2..4153e59684 100644 --- a/src/anomalib/models/image/glass/__init__.py +++ b/src/anomalib/models/image/glass/__init__.py @@ -1,4 +1,4 @@ -# Copyright (C) 2022-2025 Intel Corporation +# Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 from .lightning_model import Glass as Glass diff --git a/src/anomalib/models/image/glass/backbones.py b/src/anomalib/models/image/glass/backbones.py deleted file mode 100644 index 3795819ded..0000000000 --- a/src/anomalib/models/image/glass/backbones.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright (C) 2022-2025 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -import torchvision.models as models -import timm - -"""copy from: https://github.com/cqylunlun/GLASS/blob/main/backbones.py -This provides mechanism to import any of the given backbones using its name. 
-""" -_BACKBONES = { - "alexnet": "models.alexnet(pretrained=True)", - "resnet18": "models.resnet18(pretrained=True)", - "resnet50": "models.resnet50(pretrained=True)", - "resnet101": "models.resnet101(pretrained=True)", - "resnext101": "models.resnext101_32x8d(pretrained=True)", - "resnet200": 'timm.create_model("resnet200", pretrained=True)', - "resnest50": 'timm.create_model("resnest50d_4s2x40d", pretrained=True)', - "resnetv2_50_bit": 'timm.create_model("resnetv2_50x3_bitm", pretrained=True)', - "resnetv2_50_21k": 'timm.create_model("resnetv2_50x3_bitm_in21k", pretrained=True)', - "resnetv2_101_bit": 'timm.create_model("resnetv2_101x3_bitm", pretrained=True)', - "resnetv2_101_21k": 'timm.create_model("resnetv2_101x3_bitm_in21k", pretrained=True)', - "resnetv2_152_bit": 'timm.create_model("resnetv2_152x4_bitm", pretrained=True)', - "resnetv2_152_21k": 'timm.create_model("resnetv2_152x4_bitm_in21k", pretrained=True)', - "resnetv2_152_384": 'timm.create_model("resnetv2_152x2_bit_teacher_384", pretrained=True)', - "resnetv2_101": 'timm.create_model("resnetv2_101", pretrained=True)', - "vgg11": "models.vgg11(pretrained=True)", - "vgg19": "models.vgg19(pretrained=True)", - "vgg19_bn": "models.vgg19_bn(pretrained=True)", - "wideresnet50": "models.wide_resnet50_2(pretrained=True)", - "wideresnet101": "models.wide_resnet101_2(pretrained=True)", - "mnasnet_100": 'timm.create_model("mnasnet_100", pretrained=True)', - "mnasnet_a1": 'timm.create_model("mnasnet_a1", pretrained=True)', - "mnasnet_b1": 'timm.create_model("mnasnet_b1", pretrained=True)', - "densenet121": 'timm.create_model("densenet121", pretrained=True)', - "densenet201": 'timm.create_model("densenet201", pretrained=True)', - "inception_v4": 'timm.create_model("inception_v4", pretrained=True)', - "vit_small": 'timm.create_model("vit_small_patch16_224", pretrained=True)', - "vit_base": 'timm.create_model("vit_base_patch16_224", pretrained=True)', - "vit_large": 'timm.create_model("vit_large_patch16_224", pretrained=True)', - "vit_r50": 'timm.create_model("vit_large_r50_s32_224", pretrained=True)', - "vit_deit_base": 'timm.create_model("deit_base_patch16_224", pretrained=True)', - "vit_deit_distilled": 'timm.create_model("deit_base_distilled_patch16_224", pretrained=True)', - "vit_swin_base": 'timm.create_model("swin_base_patch4_window7_224", pretrained=True)', - "vit_swin_large": 'timm.create_model("swin_large_patch4_window7_224", pretrained=True)', - "efficientnet_b7": 'timm.create_model("tf_efficientnet_b7", pretrained=True)', - "efficientnet_b5": 'timm.create_model("tf_efficientnet_b5", pretrained=True)', - "efficientnet_b3": 'timm.create_model("tf_efficientnet_b3", pretrained=True)', - "efficientnet_b1": 'timm.create_model("tf_efficientnet_b1", pretrained=True)', - "efficientnetv2_m": 'timm.create_model("tf_efficientnetv2_m", pretrained=True)', - "efficientnetv2_l": 'timm.create_model("tf_efficientnetv2_l", pretrained=True)', - "efficientnet_b3a": 'timm.create_model("efficientnet_b3a", pretrained=True)', -} - - -def load(name): - return eval(_BACKBONES[name]) diff --git a/src/anomalib/models/image/glass/lightning_model.py b/src/anomalib/models/image/glass/lightning_model.py index ce14799fe9..0b39d08587 100644 --- a/src/anomalib/models/image/glass/lightning_model.py +++ b/src/anomalib/models/image/glass/lightning_model.py @@ -1,4 +1,18 @@ -# Copyright (C) 2022-2025 Intel Corporation +"""GLASS - Unsupervised anomaly detection via Gradient Ascent for Industrial Anomaly detection and localization. 
+ +This module implements the GLASS model for unsupervised anomaly detection and localization. GLASS synthesizes both global and local anomalies using Gaussian noise guided by gradient ascent to enhance weak defect detection in industrial settings. + +The model consists of: + - A feature extractor and feature adaptor to obtain robust normal representations + - A Global Anomaly Synthesis (GAS) module that perturbs features using Gaussian noise and gradient ascent with truncated projection + - A Local Anomaly Synthesis (LAS) module that overlays augmented textures onto images using Perlin noise masks + - A shared discriminator trained with features from normal, global, and local synthetic samples + +Paper: `A Unified Anomaly Synthesis Strategy with Gradient Ascent for Industrial Anomaly Detection and Localization +` +""" + +# Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 from typing import Any @@ -22,11 +36,66 @@ class Glass(AnomalibModule): + """PyTorch Lightning Implementation of the GLASS Model + + The model uses a pre-trained feature extractor to extract features and a feature adaptor to mitigate latent domain bias. + Global anomaly features are synthesized from adapted normal features using gradient ascent. + Local anomaly images are synthesized using texture overlay datasets like dtd which are then processed by feature extractor and feature adaptor. + All three different features are passed to the discriminator trained using loss functions. + + Args: + input_shape (tuple[int, int]): Input image dimensions as a tuple of (height, width). Required for shaping the input pipeline. + anomaly_source_path (str): Path to the dataset or source directory containing normal images and anomaly textures. + backbone (str, optional): Name of the CNN backbone used for feature extraction. + Defaults to `"resnet18"`. + pretrain_embed_dim (int, optional): Dimensionality of features extracted by the pre-trained backbone before adaptation. + Defaults to `1024`. + target_embed_dim (int, optional): Dimensionality of the target adapted features after projection. + Defaults to `1024`. + patchsize (int, optional): Size of the local patch used in feature aggregation (e.g., for neighborhood pooling). + Defaults to `3`. + patchstride (int, optional): Stride used when extracting patches for local feature aggregation. + Defaults to `1`. + pre_trained (bool, optional): Whether to use ImageNet pre-trained weights for the backbone network. + Defaults to `True`. + layers (list[str], optional): List of backbone layers to extract features from. + Defaults to `["layer1", "layer2", "layer3"]`. + pre_proj (int, optional): Number of projection layers used in the feature adaptor (e.g., MLP before discriminator). + Defaults to `1`. + dsc_layers (int, optional): Number of layers in the discriminator network. + Defaults to `2`. + dsc_hidden (int, optional): Number of hidden units in each discriminator layer. + Defaults to `1024`. + dsc_margin (float, optional): Margin used for contrastive or binary classification loss in discriminator training. + Defaults to `0.5`. + pre_processor (PreProcessor | bool, optional): reprocessing module or flag to enable default preprocessing. + Set to `True` to apply default normalization and resizing. + Defaults to `True`. + post_processor (PostProcessor | bool, optional): Postprocessing module or flag to enable default output smoothing or thresholding. + Defaults to `True`. 
+ evaluator (Evaluator | bool, optional): Evaluation module for calculating metrics such as AUROC and PRO. + Defaults to `True`. + visualizer (Visualizer | bool, optional): Visualization module to generate heatmaps, segmentation overlays, and anomaly scores. + Defaults to `True`. + mining (int, optional): Number of iterations or difficulty level for Online Hard Example Mining (OHEM) during training. + Defaults to `1`. + noise (float, optional): Standard deviation of Gaussian noise used in feature-level anomaly synthesis. + Defaults to `0.015`. + radius (float, optional): Radius parameter used for truncated projection in the anomaly synthesis strategy. + Determines the range for valid synthetic anomalies in the hypersphere or manifold. + Defaults to `0.75`. + p (float, optional): Probability used in random selection logic, such as anomaly mask generation or augmentation choice. + Defaults to `0.5`. + lr (float, optional): Learning rate for training the feature adaptor and discriminator networks. + Defaults to `0.0001`. + step (int, optional): Number of gradient ascent steps or + """ + def __init__( self, - input_shape: tuple[int, int, int], + input_shape: tuple[int, int], anomaly_source_path: str, - backbone: str | nn.Module = "resnet18", + backbone: str = "resnet18", pretrain_embed_dim: int = 1024, target_embed_dim: int = 1024, patchsize: int = 3, @@ -84,12 +153,17 @@ def __init__( self.focal_loss = FocalLoss() if pre_proj > 0: - self.proj_opt = optim.AdamW(self.model.pre_projection.parameters(), self.lr, weight_decay=1e-5) + self.proj_opt = optim.AdamW( + self.model.pre_projection.parameters(), self.lr, weight_decay=1e-5 + ) else: self.proj_opt = None if not pre_trained: - self.backbone_opt = optim.AdamW(self.model.foward_modules["feature_aggregator"].backbone.parameters(), self.lr) + self.backbone_opt = optim.AdamW( + self.model.foward_modules["feature_aggregator"].backbone.parameters(), + self.lr, + ) else: self.backbone_opt = None @@ -107,7 +181,7 @@ def training_step( self.model.forward_modules.eval() if self.model.pre_proj > 0: - self.pre_projection.train() + self.model.pre_projection.train() self.model.discriminator.train() dsc_opt.zero_grad() @@ -127,7 +201,9 @@ def training_step( center = self.c.repeat(img.shape[0], 1, 1) center = center.reshape(-1, center.shape[-1]) - true_points = torch.concat([fake_feats[mask_s_gt[:, 0] == 0], true_feats], dim=0) + true_points = torch.concat( + [fake_feats[mask_s_gt[:, 0] == 0], true_feats], dim=0 + ) c_t_points = torch.concat([center[mask_s_gt[:, 0] == 0], center], dim=0) dist_t = torch.norm(true_points - c_t_points, dim=1) r_t = torch.tensor([torch.quantile(dist_t, q=self.radius)]).to(self.device) diff --git a/src/anomalib/models/image/glass/loss.py b/src/anomalib/models/image/glass/loss.py index f1f5341d47..39a9b52ab5 100644 --- a/src/anomalib/models/image/glass/loss.py +++ b/src/anomalib/models/image/glass/loss.py @@ -1,5 +1,12 @@ -# Copyright (C) 2022-2025 Intel Corporation +# Original Code +# Copyright (c) 2021 @Hsuxu +# https://github.com/Hsuxu/Loss_ToolBox-PyTorch. # SPDX-License-Identifier: Apache-2.0 +# +# Modified +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + import numpy as np import torch @@ -7,20 +14,31 @@ class FocalLoss(nn.Module): - """copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py - This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in - 'Focal Loss for Dense Object Detection. 
(https://arxiv.org/abs/1708.02002)' - Focal_Loss= -1*alpha*(1-pt)*log(pt) - :param num_class: - :param alpha: (tensor) 3D or 4D the scalar factor for this criterion - :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more - focus on hard misclassified example - :param smooth: (float,double) smooth value when cross entropy - :param balance_index: (int) balance class index, should be specific when alpha is float - :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch. + """Implementation of Focal Loss with support for smoothed label cross-entropy, as proposed in + 'Focal Loss for Dense Object Detection' (https://arxiv.org/abs/1708.02002). + + The focal loss formula is: + Focal_Loss = -1 * alpha * (1 - pt) ** gamma * log(pt) + + Args: + num_class (int): Number of classes. + alpha (float or Tensor): Scalar or Tensor weight factor for class imbalance. If float, `balance_index` should be set. + gamma (float): Focusing parameter that reduces the relative loss for well-classified examples (gamma > 0). + smooth (float): Label smoothing factor for cross-entropy. + balance_index (int): Index of the class to balance when `alpha` is a float. + size_average (bool, optional): If True (default), the loss is averaged over the batch; otherwise, the loss is summed. + """ - def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True): + def __init__( + self, + apply_nonlin: nn.Module | None = None, + alpha: float | torch.Tensor = None, + gamma: float = 2, + balance_index: int = 0, + smooth: float = 1e-5, + size_average: bool = True, + ): super(FocalLoss, self).__init__() self.apply_nonlin = apply_nonlin self.alpha = alpha @@ -71,7 +89,11 @@ def forward(self, logit, target): one_hot_key = one_hot_key.to(logit.device) if self.smooth: - one_hot_key = torch.clamp(one_hot_key, self.smooth / (num_class - 1), 1.0 - self.smooth) + one_hot_key = torch.clamp( + one_hot_key, + self.smooth / (num_class - 1), + 1.0 - self.smooth, + ) pt = (one_hot_key * logit).sum(1) + self.smooth logpt = pt.log() diff --git a/src/anomalib/models/image/glass/perlin.py b/src/anomalib/models/image/glass/perlin.py deleted file mode 100644 index 2373d6f05d..0000000000 --- a/src/anomalib/models/image/glass/perlin.py +++ /dev/null @@ -1,139 +0,0 @@ -# Copyright (C) 2022-2025 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -import glob -import math - -import imgaug.augmenters as iaa -import numpy as np -import PIL -import PIL.Image -import torch -from torch import nn -from torchvision import transforms - -IMAGENET_MEAN = [0.485, 0.456, 0.406] -IMAGENET_STD = [0.229, 0.224, 0.225] - - -def generate_thr(img_shape, min=0, max=4): - min_perlin_scale = min - max_perlin_scale = max - perlin_scalex = 2 ** np.random.randint(min_perlin_scale, max_perlin_scale) - perlin_scaley = 2 ** np.random.randint(min_perlin_scale, max_perlin_scale) - perlin_noise_np = rand_perlin_2d_np((img_shape[1], img_shape[2]), (perlin_scalex, perlin_scaley)) - threshold = 0.5 - perlin_noise_np = iaa.Sequential([iaa.Affine(rotate=(-90, 90))])(image=perlin_noise_np) - perlin_thr = np.where(perlin_noise_np > threshold, np.ones_like(perlin_noise_np), np.zeros_like(perlin_noise_np)) - return perlin_thr - - -def perlin_mask(img_shape, feat_size, min, max, mask_fg, flag=0): - mask = np.zeros((feat_size, feat_size)) - while np.max(mask) == 0: - perlin_thr_1 = generate_thr(img_shape, min, max) - perlin_thr_2 = 
generate_thr(img_shape, min, max) - temp = torch.rand(1).numpy()[0] - if temp > 2 / 3: - perlin_thr = perlin_thr_1 + perlin_thr_2 - perlin_thr = np.where(perlin_thr > 0, np.ones_like(perlin_thr), np.zeros_like(perlin_thr)) - elif temp > 1 / 3: - perlin_thr = perlin_thr_1 * perlin_thr_2 - else: - perlin_thr = perlin_thr_1 - perlin_thr = torch.from_numpy(perlin_thr) - perlin_thr_fg = perlin_thr * mask_fg - down_ratio_y = int(img_shape[1] / feat_size) - down_ratio_x = int(img_shape[2] / feat_size) - mask_ = perlin_thr_fg - mask = torch.nn.functional.max_pool2d( - perlin_thr_fg.unsqueeze(0).unsqueeze(0), - (down_ratio_y, down_ratio_x), - ).float() - mask = mask.numpy()[0, 0] - mask_s = mask - if flag != 0: - mask_l = mask_.numpy() - if flag == 0: - return mask_s - return mask_s, mask_l - - -def lerp_np(x, y, w): - fin_out = (y - x) * w + x - return fin_out - - -def rand_perlin_2d_np(shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3): - delta = (res[0] / shape[0], res[1] / shape[1]) - d = (shape[0] // res[0], shape[1] // res[1]) - grid = np.mgrid[0 : res[0] : delta[0], 0 : res[1] : delta[1]].transpose(1, 2, 0) % 1 - - angles = 2 * math.pi * np.random.rand(res[0] + 1, res[1] + 1) - gradients = np.stack((np.cos(angles), np.sin(angles)), axis=-1) - tt = np.repeat(np.repeat(gradients, d[0], axis=0), d[1], axis=1) - - tile_grads = lambda slice1, slice2: np.repeat( - np.repeat(gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]], d[0], axis=0), - d[1], - axis=1, - ) - dot = lambda grad, shift: ( - np.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), axis=-1) - * grad[: shape[0], : shape[1]] - ).sum(axis=-1) - - n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]) - n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]) - n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]) - n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]) - t = fade(grid[: shape[0], : shape[1]]) - return math.sqrt(2) * lerp_np(lerp_np(n00, n10, t[..., 0]), lerp_np(n01, n11, t[..., 0]), t[..., 1]) - - -class PerlinNoise(nn.Module): - def __init__(self, anomaly_source_path): - super().__init__() - self.anomaly_source_paths = sorted(glob.glob(anomaly_source_path + "/*/*.jpg")) - - def forward(self, image): - aug = PIL.Image.open(np.random.choice(self.anomaly_source_paths)).convert("RGB") - transform_aug = self.rand_augmenter() - aug = transform_aug(aug) - - mask_all = perlin_mask(image.shape, self.imgsize // self.downsampling, 0, 6, mask_fg, 1) - mask_s = torch.from_numpy(mask_all[0]) - mask_l = torch.from_numpy(mask_all[1]) - mask_fg = torch.tensor([1]) - - beta = np.random.normal(loc=self.mean, scale=self.std) - beta = np.clip(beta, 0.2, 0.8) - aug_image = image * (1 - mask_l) + (1 - beta) * aug * mask_l + beta * image * mask_l - return aug_image, mask_s - - def rand_augmenter(self): - list_aug = [ - transforms.ColorJitter(contrast=(0.8, 1.2)), - transforms.ColorJitter(brightness=(0.8, 1.2)), - transforms.ColorJitter(saturation=(0.8, 1.2), hue=(-0.2, 0.2)), - transforms.RandomHorizontalFlip(p=1), - transforms.RandomVerticalFlip(p=1), - transforms.RandomGrayscale(p=1), - transforms.RandomAutocontrast(p=1), - transforms.RandomEqualize(p=1), - transforms.RandomAffine(degrees=(-45, 45)), - ] - aug_idx = np.random.choice(np.arange(len(list_aug)), 3, replace=False) - - transform_aug = [ - transforms.Resize(self.resize), - list_aug[aug_idx[0]], - list_aug[aug_idx[1]], - list_aug[aug_idx[2]], - transforms.CenterCrop(self.imgsize), - transforms.ToTensor(), - 
transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), - ] - - transform_aug = transforms.Compose(transform_aug) - return transform_aug diff --git a/src/anomalib/models/image/glass/torch_model.py b/src/anomalib/models/image/glass/torch_model.py index cdcf95f503..1d9e5fe50b 100644 --- a/src/anomalib/models/image/glass/torch_model.py +++ b/src/anomalib/models/image/glass/torch_model.py @@ -1,4 +1,4 @@ -# Copyright (C) 2022-2025 Intel Corporation +# Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 import math @@ -9,9 +9,8 @@ from scipy import ndimage from torch import nn -from anomalib.models.components import NetworkFeatureAggregator - -from .backbones import load +from anomalib.models.components import TimmFeatureExtractor +from anomalib.models.components.feature_extractors import dryrun_find_featuremap_dims def init_weight(m): @@ -24,6 +23,22 @@ def init_weight(m): m.weight.data.normal_(0.0, 0.02) +def _deduce_dims( + feature_extractor: TimmFeatureExtractor, + input_size: tuple[int, int], + layers: list[str], +) -> list[int | tuple[int, int]]: + dimensions_mapping = dryrun_find_featuremap_dims( + feature_extractor, input_size, layers + ) + + n_features_original = [ + dimensions_mapping[layer]["num_features"] for layer in layers + ] + + return n_features_original + + class Preprocessing(torch.nn.Module): def __init__(self, input_dims, output_dim): super(Preprocessing, self).__init__() @@ -104,7 +119,9 @@ def __init__(self, in_planes, n_layers=2, hidden=None): torch.nn.LeakyReLU(0.2), ), ) - self.tail = torch.nn.Sequential(torch.nn.Linear(_hidden, 1, bias=False), torch.nn.Sigmoid()) + self.tail = torch.nn.Sequential( + torch.nn.Linear(_hidden, 1, bias=False), torch.nn.Sigmoid() + ) self.apply(init_weight) def forward(self, x): @@ -121,11 +138,15 @@ def __init__(self, patchsize, top_k=0, stride=None): def patchify(self, features, return_spatial_info=False): padding = int((self.patchsize - 1) / 2) - unfolder = torch.nn.Unfold(kernel_size=self.patchsize, stride=self.stride, padding=padding, dilation=1) + unfolder = torch.nn.Unfold( + kernel_size=self.patchsize, stride=self.stride, padding=padding, dilation=1 + ) unfolded_features = unfolder(features) number_of_total_patches = [] for s in features.shape[-2:]: - n_patches = (s + 2 * padding - 1 * (self.patchsize - 1) - 1) / self.stride + 1 + n_patches = ( + s + 2 * padding - 1 * (self.patchsize - 1) - 1 + ) / self.stride + 1 number_of_total_patches.append(int(n_patches)) unfolded_features = unfolded_features.reshape( *features.shape[:2], @@ -167,22 +188,19 @@ def convert_to_segmentation(self, patch_scores): ) _scores = _scores.squeeze(1) patch_scores = _scores.cpu().numpy() - return [ndimage.gaussian_filter(patch_score, sigma=self.smoothing) for patch_score in patch_scores] - - -def process_backbone(backbone): - if isinstance(backbone, str): - return load(backbone) - return backbone + return [ + ndimage.gaussian_filter(patch_score, sigma=self.smoothing) + for patch_score in patch_scores + ] class GlassModel(nn.Module): def __init__( self, - input_shape: tuple[int, int, int], + input_shape: tuple[int, int], # (H, W) pretrain_embed_dim: int = 1024, target_embed_dim: int = 1024, - backbone: str | nn.Module = "resnet18", + backbone: str = "resnet18", patchsize: int = 3, patchstride: int = 1, pre_trained: bool = True, @@ -194,17 +212,18 @@ def __init__( ) -> None: super().__init__() - self.backbone = process_backbone(backbone) + self.backbone = backbone self.layers = layers self.input_shape = input_shape + 
self.pre_trained = pre_trained self.forward_modules = torch.nn.ModuleDict({}) - feature_aggregator = NetworkFeatureAggregator( - self.backbone, - self.layers, - pre_trained, + feature_aggregator = TimmFeatureExtractor( + backbone=self.backbone, + layers=self.layers, + pre_trained=self.pre_trained, ) - feature_dimensions = feature_aggregator.feature_dimensions(input_shape) + feature_dimensions = _deduce_dims(feature_aggregator, self.input_shape, layers) self.forward_modules["feature_aggregator"] = feature_aggregator preprocessing = Preprocessing(feature_dimensions, pretrain_embed_dim) @@ -213,11 +232,11 @@ def __init__( preadapt_aggregator = Aggregator(target_dim=target_embed_dim) self.forward_modules["preadapt_aggregator"] = preadapt_aggregator - self.pre_trained = pre_trained - self.pre_proj = pre_proj if self.pre_proj > 0: - self.pre_projection = Projection(self.target_embed_dimension, self.target_embed_dimension, pre_proj) + self.pre_projection = Projection( + self.target_embed_dimension, self.target_embed_dimension, pre_proj + ) self.dsc_layers = dsc_layers self.dsc_hidden = dsc_hidden @@ -236,7 +255,7 @@ def calculate_mean(self, images): self.forward_modules.eval() with torch.no_grad(): if self.pre_proj > 0: - outputs = self.pre_projection(self.generate_embeddings(images)) + outputs = self.pre_projection(self.generate_embeddings(images)[0]) outputs = outputs[0] if len(outputs) == 2 else outputs else: outputs = self._embed(images, evaluation=False)[0] @@ -261,9 +280,13 @@ def generate_embeddings(self, images, provide_patch_shapes=False, eval=False): for i, feat in enumerate(features): if len(feat.shape) == 3: B, L, C = feat.shape - features[i] = feat.reshape(B, int(math.sqrt(L)), int(math.sqrt(L)), C).permute(0, 3, 1, 2) + features[i] = feat.reshape( + B, int(math.sqrt(L)), int(math.sqrt(L)), C + ).permute(0, 3, 1, 2) - features = [self.patch_maker.patchify(x, return_spatial_info=True) for x in features] + features = [ + self.patch_maker.patchify(x, return_spatial_info=True) for x in features + ] patch_shapes = [x[1] for x in features] patch_features = [x[0] for x in features] ref_num_patches = patch_shapes[0] @@ -305,14 +328,18 @@ def generate_embeddings(self, images, provide_patch_shapes=False, eval=False): def forward(self, img, aug, evaluation=False): if self.pre_proj > 0: - fake_feats = self.pre_projection(self.generate_embeddings(aug, evaluation=evaluation)[0]) + fake_feats = self.pre_projection( + self.generate_embeddings(aug, eval=evaluation)[0] + ) fake_feats = fake_feats[0] if len(fake_feats) == 2 else fake_feats - true_feats = self.pre_projection(self.generate_embeddings(img, evaluation=evaluation)[0]) + true_feats = self.pre_projection( + self.generate_embeddings(img, eval=evaluation)[0] + ) true_feats = true_feats[0] if len(true_feats) == 2 else true_feats else: - fake_feats = self.generate_embeddings(aug, evaluation=evaluation)[0] + fake_feats = self.generate_embeddings(aug, eval=evaluation)[0] fake_feats.requires_grad = True - true_feats = self.generate_embeddings(img, evaluation=evaluation)[0] + true_feats = self.generate_embeddings(img, eval=evaluation)[0] true_feats.requires_grad = True return true_feats, fake_feats From 7fea20f2f2a1f22ce1d1bc6f0bdeebb42c8586c1 Mon Sep 17 00:00:00 2001 From: Devansh Agarwal Date: Thu, 19 Jun 2025 17:49:16 +0530 Subject: [PATCH 09/23] Matched code to the original implementation Signed-off-by: Devansh Agarwal --- .../models/image/glass/lightning_model.py | 87 ++++++++++++++-- .../models/image/glass/torch_model.py | 98 
+++++++++++++++---- 2 files changed, 155 insertions(+), 30 deletions(-) diff --git a/src/anomalib/models/image/glass/lightning_model.py b/src/anomalib/models/image/glass/lightning_model.py index 0b39d08587..4661aa323c 100644 --- a/src/anomalib/models/image/glass/lightning_model.py +++ b/src/anomalib/models/image/glass/lightning_model.py @@ -15,14 +15,18 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import math from typing import Any import torch from lightning.pytorch.utilities.types import STEP_OUTPUT from torch import nn, optim +from torch.nn import functional as F +from torchvision.transforms.v2 import CenterCrop, Compose, Normalize, Resize from anomalib import LearningType from anomalib.data import Batch +from anomalib.data.utils.generators.perlin import PerlinAnomalyGenerator from anomalib.metrics import Evaluator from anomalib.models.components import AnomalibModule from anomalib.post_processing import PostProcessor @@ -32,11 +36,9 @@ from .loss import FocalLoss from .torch_model import GlassModel -from anomalib.data.utils.generators.perlin import PerlinAnomalyGenerator - class Glass(AnomalibModule): - """PyTorch Lightning Implementation of the GLASS Model + """PyTorch Lightning Implementation of the GLASS Model. The model uses a pre-trained feature extractor to extract features and a feature adaptor to mitigate latent domain bias. Global anomaly features are synthesized from adapted normal features using gradient ascent. @@ -88,7 +90,10 @@ class Glass(AnomalibModule): Defaults to `0.5`. lr (float, optional): Learning rate for training the feature adaptor and discriminator networks. Defaults to `0.0001`. - step (int, optional): Number of gradient ascent steps or + step (int, optional): Number of gradient ascent steps for anomaly synthesis. + Defaults to `20`. + svd (int, optional): Flag to enable SVD-based feature projection. + Defaults to `0`. """ def __init__( @@ -116,6 +121,7 @@ def __init__( p: float = 0.5, lr: float = 0.0001, step: int = 20, + svd: int = 0, ): super().__init__( pre_processor=pre_processor, @@ -149,12 +155,15 @@ def __init__( self.distribution = 0 self.lr = lr self.step = step + self.svd = svd self.focal_loss = FocalLoss() if pre_proj > 0: self.proj_opt = optim.AdamW( - self.model.pre_projection.parameters(), self.lr, weight_decay=1e-5 + self.model.pre_projection.parameters(), + self.lr, + weight_decay=1e-5, ) else: self.proj_opt = None @@ -167,6 +176,31 @@ def __init__( else: self.backbone_opt = None + @classmethod + def configure_pre_processor( + cls, + image_size: tuple[int, int] | None = None, + center_crop_size: tuple[int, int] | None = None, + ) -> PreProcessor: + image_size = image_size or (256, 256) + + if center_crop_size is not None: + if center_crop_size[0] > image_size[0] or center_crop_size[1] > image_size[1]: + msg = f"Center crop size {center_crop_size} cannot be larger than image size {image_size}." 
+ raise ValueError(msg) + transform = Compose([ + Resize(image_size, antialias=True), + CenterCrop(center_crop_size), + Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + else: + transform = Compose([ + Resize(image_size, antialias=True), + Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + + return PreProcessor(transform=transform) + def configure_optimizers(self) -> list[optim.Optimizer]: dsc_opt = optim.AdamW(self.model.discriminator.parameters(), lr=self.lr * 2) @@ -177,6 +211,15 @@ def training_step( batch: Batch, batch_idx: int, ) -> STEP_OUTPUT: + """Training step for GLASS model. + + Args: + batch (Batch): Input batch containing images and metadata + batch_idx (int): Index of the current batch + + Returns: + STEP_OUTPUT: Dictionary containing loss values and metrics + """ dsc_opt = self.optimizers() self.model.forward_modules.eval() @@ -192,17 +235,28 @@ def training_step( img = batch.image aug, mask_s = self.augmentor(img) + batch_size = img.shape[0] true_feats, fake_feats = self.model(img, aug) - mask_s_gt = mask_s.reshape(-1, 1) + h_ratio = mask_s.shape[2] // int(math.sqrt(fake_feats.shape[0] // batch_size)) + w_ratio = mask_s.shape[3] // int(math.sqrt(fake_feats.shape[0] // batch_size)) + + mask_s_resized = F.interpolate( + mask_s.float(), + size=(mask_s.shape[2] // h_ratio, mask_s.shape[3] // w_ratio), + mode="nearest", + ) + mask_s_gt = mask_s_resized.reshape(-1, 1) + noise = torch.normal(0, self.noise, true_feats.shape) gaus_feats = true_feats + noise center = self.c.repeat(img.shape[0], 1, 1) center = center.reshape(-1, center.shape[-1]) true_points = torch.concat( - [fake_feats[mask_s_gt[:, 0] == 0], true_feats], dim=0 + [fake_feats[mask_s_gt[:, 0] == 0], true_feats], + dim=0, ) c_t_points = torch.concat([center[mask_s_gt[:, 0] == 0], center], dim=0) dist_t = torch.norm(true_points - c_t_points, dim=1) @@ -235,7 +289,6 @@ def training_step( true_points = true_feats[mask_s_gt[:, 0] == 1] c_f_points = center[mask_s_gt[:, 0] == 1] dist_f = torch.norm(fake_points - c_f_points, dim=1) - r_f = torch.tensor([torch.quantile(dist_f, q=self.radius)]).to(self.device) proj_feats = c_f_points if self.svd == 1 else true_points r = r_t if self.svd == 1 else 1 @@ -270,7 +323,18 @@ def training_step( self.backbone_opt.step() dsc_opt.step() + self.log("true_loss", true_loss, prog_bar=True) + self.log("gaus_loss", gaus_loss, prog_bar=True) + self.log("bce_loss", bce_loss, prog_bar=True) + self.log("focal_losss", focal_loss, prog_bar=True) + self.log("loss", loss, prog_bar=True) + def on_train_start(self) -> None: + """Initialize model by computing mean feature representation across training dataset. + + This method is called at the start of training and computes a mean feature vector + that serves as a reference point for the normal class distribution. + """ dataloader = self.trainer.train_dataloader with torch.no_grad(): @@ -293,6 +357,9 @@ def learning_type(self) -> LearningType: @property def trainer_arguments(self) -> dict[str, Any]: - """Return GLASS trainer arguments.""" + """Return GLASS trainer arguments. 
+ + Returns: + dict[str, Any]: Dictionary containing trainer configuration + """ return {"gradient_clip_val": 0, "num_sanity_val_steps": 0} - # TODO diff --git a/src/anomalib/models/image/glass/torch_model.py b/src/anomalib/models/image/glass/torch_model.py index 1d9e5fe50b..b311b47134 100644 --- a/src/anomalib/models/image/glass/torch_model.py +++ b/src/anomalib/models/image/glass/torch_model.py @@ -14,6 +14,10 @@ def init_weight(m): + """Initializes network weights using Xavier normal initialization for + linear layers and normal initialization for convolutional and batch + normalization layers. + """ if isinstance(m, torch.nn.Linear): torch.nn.init.xavier_normal_(m.weight) if isinstance(m, torch.nn.BatchNorm2d): @@ -28,18 +32,29 @@ def _deduce_dims( input_size: tuple[int, int], layers: list[str], ) -> list[int | tuple[int, int]]: + """Determines feature dimensions for each layer in the feature extractor. + Parameters: + feature_extractor: The backbone feature extractor + input_size: Input image dimensions + layers: List of layer names to extract features from + """ dimensions_mapping = dryrun_find_featuremap_dims( - feature_extractor, input_size, layers + feature_extractor, + input_size, + layers, ) - n_features_original = [ - dimensions_mapping[layer]["num_features"] for layer in layers - ] + n_features_original = [dimensions_mapping[layer]["num_features"] for layer in layers] return n_features_original class Preprocessing(torch.nn.Module): + """Handles initial feature preprocessing across multiple input dimensions. + Input: List of features from different backbone layers + Output: Processed features with consistent dimensionality + """ + def __init__(self, input_dims, output_dim): super(Preprocessing, self).__init__() self.input_dims = input_dims @@ -58,6 +73,11 @@ def forward(self, features): class MeanMapper(torch.nn.Module): + """Maps input features to a fixed dimension using adaptive average pooling. + Input: Variable-sized feature tensors + Output: Fixed-size feature representations + """ + def __init__(self, preprocessing_dim): super(MeanMapper, self).__init__() self.preprocessing_dim = preprocessing_dim @@ -68,6 +88,11 @@ def forward(self, features): class Aggregator(torch.nn.Module): + """Aggregates and reshapes features to a target dimension. + Input: Multi-dimensional feature tensors + Output: Reshaped and pooled features of specified target dimension + """ + def __init__(self, target_dim): super(Aggregator, self).__init__() self.target_dim = target_dim @@ -80,6 +105,14 @@ def forward(self, features): class Projection(torch.nn.Module): + """Multi-layer projection network for feature adaptation. + Parameters: + in_planes: Input feature dimension + out_planes: Output feature dimension + n_layers: Number of projection layers + layer_type: Type of intermediate layers + """ + def __init__(self, in_planes, out_planes=None, n_layers=1, layer_type=0): super(Projection, self).__init__() @@ -103,6 +136,13 @@ def forward(self, x): class Discriminator(torch.nn.Module): + """Discriminator network for anomaly detection. 
+ Parameters: + in_planes: Input feature dimension + n_layers: Number of layers + hidden: Hidden layer dimensions + """ + def __init__(self, in_planes, n_layers=2, hidden=None): super(Discriminator, self).__init__() @@ -120,7 +160,8 @@ def __init__(self, in_planes, n_layers=2, hidden=None): ), ) self.tail = torch.nn.Sequential( - torch.nn.Linear(_hidden, 1, bias=False), torch.nn.Sigmoid() + torch.nn.Linear(_hidden, 1, bias=False), + torch.nn.Sigmoid(), ) self.apply(init_weight) @@ -131,6 +172,14 @@ def forward(self, x): class PatchMaker: + """Handles patch-based processing of feature maps. + + Methods: + patchify: Converts features into patches + unpatch_scores: Reshapes patch scores back to original dimensions + score: Computes final scores from patch-wise predictions + """ + def __init__(self, patchsize, top_k=0, stride=None): self.patchsize = patchsize self.stride = stride @@ -139,14 +188,15 @@ def __init__(self, patchsize, top_k=0, stride=None): def patchify(self, features, return_spatial_info=False): padding = int((self.patchsize - 1) / 2) unfolder = torch.nn.Unfold( - kernel_size=self.patchsize, stride=self.stride, padding=padding, dilation=1 + kernel_size=self.patchsize, + stride=self.stride, + padding=padding, + dilation=1, ) unfolded_features = unfolder(features) number_of_total_patches = [] for s in features.shape[-2:]: - n_patches = ( - s + 2 * padding - 1 * (self.patchsize - 1) - 1 - ) / self.stride + 1 + n_patches = (s + 2 * padding - 1 * (self.patchsize - 1) - 1) / self.stride + 1 number_of_total_patches.append(int(n_patches)) unfolded_features = unfolded_features.reshape( *features.shape[:2], @@ -170,6 +220,12 @@ def score(self, x): class RescaleSegmentor: + """Handles rescaling of patch-based anomaly scores to full-image dimensions. 
+ Parameters: + target_size: Target image size for rescaling + smoothing: Gaussian smoothing parameter for score smoothing + """ + def __init__(self, target_size=288): self.target_size = target_size self.smoothing = 4 @@ -188,13 +244,12 @@ def convert_to_segmentation(self, patch_scores): ) _scores = _scores.squeeze(1) patch_scores = _scores.cpu().numpy() - return [ - ndimage.gaussian_filter(patch_score, sigma=self.smoothing) - for patch_score in patch_scores - ] + return [ndimage.gaussian_filter(patch_score, sigma=self.smoothing) for patch_score in patch_scores] class GlassModel(nn.Module): + """PyTorch Implementation of the GLASS Model.""" + def __init__( self, input_shape: tuple[int, int], # (H, W) @@ -235,7 +290,9 @@ def __init__( self.pre_proj = pre_proj if self.pre_proj > 0: self.pre_projection = Projection( - self.target_embed_dimension, self.target_embed_dimension, pre_proj + self.target_embed_dimension, + self.target_embed_dimension, + pre_proj, ) self.dsc_layers = dsc_layers @@ -281,12 +338,13 @@ def generate_embeddings(self, images, provide_patch_shapes=False, eval=False): if len(feat.shape) == 3: B, L, C = feat.shape features[i] = feat.reshape( - B, int(math.sqrt(L)), int(math.sqrt(L)), C + B, + int(math.sqrt(L)), + int(math.sqrt(L)), + C, ).permute(0, 3, 1, 2) - features = [ - self.patch_maker.patchify(x, return_spatial_info=True) for x in features - ] + features = [self.patch_maker.patchify(x, return_spatial_info=True) for x in features] patch_shapes = [x[1] for x in features] patch_features = [x[0] for x in features] ref_num_patches = patch_shapes[0] @@ -329,11 +387,11 @@ def generate_embeddings(self, images, provide_patch_shapes=False, eval=False): def forward(self, img, aug, evaluation=False): if self.pre_proj > 0: fake_feats = self.pre_projection( - self.generate_embeddings(aug, eval=evaluation)[0] + self.generate_embeddings(aug, eval=evaluation)[0], ) fake_feats = fake_feats[0] if len(fake_feats) == 2 else fake_feats true_feats = self.pre_projection( - self.generate_embeddings(img, eval=evaluation)[0] + self.generate_embeddings(img, eval=evaluation)[0], ) true_feats = true_feats[0] if len(true_feats) == 2 else true_feats else: From 1beedf5af03096515f8db9224cc94cd1a54a6172 Mon Sep 17 00:00:00 2001 From: Devansh Agarwal Date: Mon, 23 Jun 2025 14:00:43 +0530 Subject: [PATCH 10/23] Added support for gpu Signed-off-by: Devansh Agarwal --- src/anomalib/models/image/glass/__init__.py | 21 +- .../models/image/glass/lightning_model.py | 105 +++-- src/anomalib/models/image/glass/loss.py | 103 +++-- .../models/image/glass/torch_model.py | 368 ++++++++++++------ 4 files changed, 415 insertions(+), 182 deletions(-) diff --git a/src/anomalib/models/image/glass/__init__.py b/src/anomalib/models/image/glass/__init__.py index 4153e59684..a3070bacf4 100644 --- a/src/anomalib/models/image/glass/__init__.py +++ b/src/anomalib/models/image/glass/__init__.py @@ -1,4 +1,23 @@ +"""GLASS - Unsupervised anomaly detection via Gradient Ascent for Industrial Anomaly detection and localization. + +This module implements the GLASS model for unsupervised anomaly detection and localization. GLASS synthesizes both +global and local anomalies using Gaussian noise guided by gradient ascent to enhance weak defect detection in +industrial settings. 
+ +The model consists of: + - A feature extractor and feature adaptor to obtain robust normal representations + - A Global Anomaly Synthesis (GAS) module that perturbs features using Gaussian noise and gradient ascent with + truncated projection + - A Local Anomaly Synthesis (LAS) module that overlays augmented textures onto images using Perlin noise masks + - A shared discriminator trained with features from normal, global, and local synthetic samples + +Paper: `A Unified Anomaly Synthesis Strategy with Gradient Ascent for Industrial Anomaly Detection and Localization +` +""" + # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -from .lightning_model import Glass as Glass +from .lightning_model import Glass + +__all__ = ["Glass"] diff --git a/src/anomalib/models/image/glass/lightning_model.py b/src/anomalib/models/image/glass/lightning_model.py index 4661aa323c..56aed0f706 100644 --- a/src/anomalib/models/image/glass/lightning_model.py +++ b/src/anomalib/models/image/glass/lightning_model.py @@ -1,10 +1,13 @@ """GLASS - Unsupervised anomaly detection via Gradient Ascent for Industrial Anomaly detection and localization. -This module implements the GLASS model for unsupervised anomaly detection and localization. GLASS synthesizes both global and local anomalies using Gaussian noise guided by gradient ascent to enhance weak defect detection in industrial settings. +This module implements the GLASS model for unsupervised anomaly detection and localization. GLASS synthesizes both +global and local anomalies using Gaussian noise guided by gradient ascent to enhance weak defect detection in +industrial settings. The model consists of: - A feature extractor and feature adaptor to obtain robust normal representations - - A Global Anomaly Synthesis (GAS) module that perturbs features using Gaussian noise and gradient ascent with truncated projection + - A Global Anomaly Synthesis (GAS) module that perturbs features using Gaussian noise and gradient ascent with + truncated projection - A Local Anomaly Synthesis (LAS) module that overlays augmented textures onto images using Perlin noise masks - A shared discriminator trained with features from normal, global, and local synthetic samples @@ -21,7 +24,7 @@ import torch from lightning.pytorch.utilities.types import STEP_OUTPUT from torch import nn, optim -from torch.nn import functional as F +from torch.nn import functional as f from torchvision.transforms.v2 import CenterCrop, Compose, Normalize, Resize from anomalib import LearningType @@ -40,17 +43,21 @@ class Glass(AnomalibModule): """PyTorch Lightning Implementation of the GLASS Model. - The model uses a pre-trained feature extractor to extract features and a feature adaptor to mitigate latent domain bias. + The model uses a pre-trained feature extractor to extract features and a feature adaptor to mitigate latent domain + bias. Global anomaly features are synthesized from adapted normal features using gradient ascent. - Local anomaly images are synthesized using texture overlay datasets like dtd which are then processed by feature extractor and feature adaptor. + Local anomaly images are synthesized using texture overlay datasets like dtd which are then processed by feature + extractor and feature adaptor. All three different features are passed to the discriminator trained using loss functions. Args: - input_shape (tuple[int, int]): Input image dimensions as a tuple of (height, width). Required for shaping the input pipeline. 
- anomaly_source_path (str): Path to the dataset or source directory containing normal images and anomaly textures. + input_shape (tuple[int, int]): Input image dimensions as a tuple of (height, width). Required for shaping the + input pipeline. + anomaly_source_path (str): Path to the dataset or source directory containing normal images and anomaly textures backbone (str, optional): Name of the CNN backbone used for feature extraction. Defaults to `"resnet18"`. - pretrain_embed_dim (int, optional): Dimensionality of features extracted by the pre-trained backbone before adaptation. + pretrain_embed_dim (int, optional): Dimensionality of features extracted by the pre-trained backbone before + adaptation. Defaults to `1024`. target_embed_dim (int, optional): Dimensionality of the target adapted features after projection. Defaults to `1024`. @@ -62,31 +69,37 @@ class Glass(AnomalibModule): Defaults to `True`. layers (list[str], optional): List of backbone layers to extract features from. Defaults to `["layer1", "layer2", "layer3"]`. - pre_proj (int, optional): Number of projection layers used in the feature adaptor (e.g., MLP before discriminator). + pre_proj (int, optional): Number of projection layers used in the feature adaptor (e.g., MLP before + discriminator). Defaults to `1`. dsc_layers (int, optional): Number of layers in the discriminator network. Defaults to `2`. dsc_hidden (int, optional): Number of hidden units in each discriminator layer. Defaults to `1024`. - dsc_margin (float, optional): Margin used for contrastive or binary classification loss in discriminator training. + dsc_margin (float, optional): Margin used for contrastive or binary classification loss in discriminator + training. Defaults to `0.5`. pre_processor (PreProcessor | bool, optional): reprocessing module or flag to enable default preprocessing. Set to `True` to apply default normalization and resizing. Defaults to `True`. - post_processor (PostProcessor | bool, optional): Postprocessing module or flag to enable default output smoothing or thresholding. + post_processor (PostProcessor | bool, optional): Postprocessing module or flag to enable default output + smoothing or thresholding. Defaults to `True`. evaluator (Evaluator | bool, optional): Evaluation module for calculating metrics such as AUROC and PRO. Defaults to `True`. - visualizer (Visualizer | bool, optional): Visualization module to generate heatmaps, segmentation overlays, and anomaly scores. + visualizer (Visualizer | bool, optional): Visualization module to generate heatmaps, segmentation overlays, and + anomaly scores. Defaults to `True`. - mining (int, optional): Number of iterations or difficulty level for Online Hard Example Mining (OHEM) during training. + mining (int, optional): Number of iterations or difficulty level for Online Hard Example Mining (OHEM) during + training. Defaults to `1`. noise (float, optional): Standard deviation of Gaussian noise used in feature-level anomaly synthesis. Defaults to `0.015`. radius (float, optional): Radius parameter used for truncated projection in the anomaly synthesis strategy. Determines the range for valid synthetic anomalies in the hypersphere or manifold. Defaults to `0.75`. - p (float, optional): Probability used in random selection logic, such as anomaly mask generation or augmentation choice. + p (float, optional): Probability used in random selection logic, such as anomaly mask generation or augmentation + choice. Defaults to `0.5`. 
lr (float, optional): Learning rate for training the feature adaptor and discriminator networks. Defaults to `0.0001`. @@ -106,7 +119,7 @@ def __init__( patchsize: int = 3, patchstride: int = 1, pre_trained: bool = True, - layers: list[str] = ["layer1", "layer2", "layer3"], + layers: list[str] | None = None, pre_proj: int = 1, dsc_layers: int = 2, dsc_hidden: int = 1024, @@ -122,7 +135,7 @@ def __init__( lr: float = 0.0001, step: int = 20, svd: int = 0, - ): + ) -> None: super().__init__( pre_processor=pre_processor, post_processor=post_processor, @@ -130,6 +143,9 @@ def __init__( visualizer=visualizer, ) + if layers is None: + layers = ["layer1", "layer2", "layer3"] + self.augmentor = PerlinAnomalyGenerator(anomaly_source_path) self.model = GlassModel( @@ -157,6 +173,8 @@ def __init__( self.step = step self.svd = svd + self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.focal_loss = FocalLoss() if pre_proj > 0: @@ -170,7 +188,7 @@ def __init__( if not pre_trained: self.backbone_opt = optim.AdamW( - self.model.foward_modules["feature_aggregator"].backbone.parameters(), + self.model.forward_modules["feature_aggregator"].backbone.parameters(), self.lr, ) else: @@ -182,6 +200,30 @@ def configure_pre_processor( image_size: tuple[int, int] | None = None, center_crop_size: tuple[int, int] | None = None, ) -> PreProcessor: + """Configure the default pre-processor for GLASS. + + If valid center_crop_size is provided, the pre-processor will + also perform center cropping, according to the paper. + + Args: + image_size (tuple[int, int] | None, optional): Target size for + resizing. Defaults to ``(256, 256)``. + center_crop_size (tuple[int, int] | None, optional): Size for center + cropping. Defaults to ``None``. + + Returns: + PreProcessor: Configured pre-processor instance. + + Raises: + ValueError: If at least one dimension of ``center_crop_size`` is larger + than the corresponding ``image_size`` dimension. + + Example: + >>> pre_processor = Glass.configure_pre_processor( + ... image_size=(256, 256) + ... ) + >>> transformed_image = pre_processor(image) + """ image_size = image_size or (256, 256) if center_crop_size is not None: @@ -201,10 +243,13 @@ def configure_pre_processor( return PreProcessor(transform=transform) - def configure_optimizers(self) -> list[optim.Optimizer]: - dsc_opt = optim.AdamW(self.model.discriminator.parameters(), lr=self.lr * 2) + def configure_optimizers(self) -> optim.Optimizer: + """Configure optimizer for the discriminator. - return dsc_opt + Returns: + Optimizer: AdamW Optimizer for the discriminator.
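# Illustrative sketch (not from the patch) of the optimizer layout described above:
# the discriminator trains at twice the base learning rate, with a separate AdamW
# optimizer for the projection head. `in_dim` and the toy modules are assumptions.
import torch
from torch import nn, optim

lr, in_dim = 1e-4, 1024
pre_projection = nn.Linear(in_dim, in_dim)
discriminator = nn.Sequential(nn.Linear(in_dim, 1), nn.Sigmoid())

dsc_opt = optim.AdamW(discriminator.parameters(), lr=lr * 2)
proj_opt = optim.AdamW(pre_projection.parameters(), lr=lr)

feats = torch.randn(8, in_dim)
scores = discriminator(pre_projection(feats))
loss = nn.BCELoss()(scores, torch.zeros_like(scores))  # normal samples -> label 0
dsc_opt.zero_grad()
proj_opt.zero_grad()
loss.backward()
dsc_opt.step()
proj_opt.step()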
+ """ + return optim.AdamW(self.model.discriminator.parameters(), lr=self.lr * 2) def training_step( self, @@ -220,6 +265,7 @@ def training_step( Returns: STEP_OUTPUT: Dictionary containing loss values and metrics """ + del batch_idx dsc_opt = self.optimizers() self.model.forward_modules.eval() @@ -235,21 +281,22 @@ def training_step( img = batch.image aug, mask_s = self.augmentor(img) - batch_size = img.shape[0] + if img is not None: + batch_size = img.shape[0] true_feats, fake_feats = self.model(img, aug) h_ratio = mask_s.shape[2] // int(math.sqrt(fake_feats.shape[0] // batch_size)) w_ratio = mask_s.shape[3] // int(math.sqrt(fake_feats.shape[0] // batch_size)) - mask_s_resized = F.interpolate( + mask_s_resized = f.interpolate( mask_s.float(), size=(mask_s.shape[2] // h_ratio, mask_s.shape[3] // w_ratio), mode="nearest", ) mask_s_gt = mask_s_resized.reshape(-1, 1) - noise = torch.normal(0, self.noise, true_feats.shape) + noise = torch.normal(0, self.noise, true_feats.shape).to(self.dev) gaus_feats = true_feats + noise center = self.c.repeat(img.shape[0], 1, 1) @@ -260,7 +307,7 @@ def training_step( ) c_t_points = torch.concat([center[mask_s_gt[:, 0] == 0], center], dim=0) dist_t = torch.norm(true_points - c_t_points, dim=1) - r_t = torch.tensor([torch.quantile(dist_t, q=self.radius)]).to(self.device) + r_t = torch.tensor([torch.quantile(dist_t, q=self.radius)]).to(self.dev) for step in range(self.step + 1): scores = self.model.discriminator(torch.cat([true_feats, gaus_feats])) @@ -272,10 +319,6 @@ def training_step( if step == self.step: break - if self.mining == 0: - dist_g = torch.norm(gaus_feats - center, dim=1) - r_g = torch.tensor([torch.quantile(dist_g, q=self.radius)]) - break grad = torch.autograd.grad(gaus_loss, [gaus_feats])[0] grad_norm = torch.norm(grad, dim=1) @@ -326,7 +369,7 @@ def training_step( self.log("true_loss", true_loss, prog_bar=True) self.log("gaus_loss", gaus_loss, prog_bar=True) self.log("bce_loss", bce_loss, prog_bar=True) - self.log("focal_losss", focal_loss, prog_bar=True) + self.log("focal_loss", focal_loss, prog_bar=True) self.log("loss", loss, prog_bar=True) def on_train_start(self) -> None: @@ -340,9 +383,9 @@ def on_train_start(self) -> None: with torch.no_grad(): for i, batch in enumerate(dataloader): if i == 0: - self.c = self.model.calculate_mean(batch.image) + self.c = self.model.calculate_mean(batch.image.to(self.dev)) else: - self.c += self.model.calculate_mean(batch.image) + self.c += self.model.calculate_mean(batch.image.to(self.dev)) self.c /= len(dataloader) diff --git a/src/anomalib/models/image/glass/loss.py b/src/anomalib/models/image/glass/loss.py index 39a9b52ab5..db3c9da5cb 100644 --- a/src/anomalib/models/image/glass/loss.py +++ b/src/anomalib/models/image/glass/loss.py @@ -1,3 +1,45 @@ +"""Focal Loss for multi-class classification with optional label smoothing and class weighting. + +This loss function is designed to address class imbalance by down-weighting easy examples and focusing training +on hard, misclassified examples. It is based on the paper: +"Focal Loss for Dense Object Detection" (https://arxiv.org/abs/1708.02002). + +The focal loss formula is: + FL(pt) = -alpha * (1 - pt) ** gamma * log(pt) + +where: + - pt is the predicted probability of the correct class + - alpha is a class balancing factor + - gamma is a focusing parameter + +Supports optional label smoothing and flexible alpha input (scalar or per-class tensor). Can be used with raw logits, +applying a specified non-linearity (e.g., softmax or sigmoid). 
+ +Args: + apply_nonlin (nn.Module or None): Optional non-linearity to apply to the logits before loss computation. + For example, use `nn.Softmax(dim=1)` or `nn.Sigmoid()` if logits are not normalized. + alpha (float or torch.Tensor, optional): Class balancing factor. Can be: + - None: Equal weighting for all classes. + - float: Scalar for binary class weighting; applied to `balance_index`. + - Tensor: Per-class weights of shape (num_classes,). + gamma (float): Focusing parameter (> 0) to reduce the loss contribution from easy examples. Default is 2. + balance_index (int): Index of the class to apply `alpha` to when `alpha` is a float. + smooth (float): Label smoothing factor. A small value (e.g., 1e-5) helps prevent overconfidence. + size_average (bool): If True, average the loss over the batch; if False, sum the loss. + +Raises: + ValueError: If `smooth` is outside the range [0, 1]. + TypeError: If `alpha` is not a supported type. + +Inputs: + logit (torch.Tensor): Raw model outputs (logits) of shape (B, C, ...) where B is batch size and C is number of + classes. + target (torch.Tensor): Ground-truth class indices of shape (B, 1, ...) or broadcastable to match logit. + +Returns: + torch.Tensor: Scalar loss value (averaged or summed based on `size_average`). +""" + # Original Code # Copyright (c) 2021 @Hsuxu # https://github.com/Hsuxu/Loss_ToolBox-PyTorch. @@ -7,27 +49,27 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - import numpy as np import torch from torch import nn class FocalLoss(nn.Module): - """Implementation of Focal Loss with support for smoothed label cross-entropy, as proposed in - 'Focal Loss for Dense Object Detection' (https://arxiv.org/abs/1708.02002). + """Implementation of Focal Loss with support for smoothed label cross-entropy. + As proposed in 'Focal Loss for Dense Object Detection' (https://arxiv.org/abs/1708.02002). The focal loss formula is: Focal_Loss = -1 * alpha * (1 - pt) ** gamma * log(pt) Args: num_class (int): Number of classes. - alpha (float or Tensor): Scalar or Tensor weight factor for class imbalance. If float, `balance_index` should be set. + alpha (float or Tensor): Scalar or Tensor weight factor for class imbalance. If float, `balance_index` should be + set. gamma (float): Focusing parameter that reduces the relative loss for well-classified examples (gamma > 0). smooth (float): Label smoothing factor for cross-entropy. balance_index (int): Index of the class to balance when `alpha` is a float. - size_average (bool, optional): If True (default), the loss is averaged over the batch; otherwise, the loss is summed. - + size_average (bool, optional): If True (default), the loss is averaged over the batch; otherwise, the loss is + summed. """ def __init__( @@ -38,8 +80,21 @@ def __init__( balance_index: int = 0, smooth: float = 1e-5, size_average: bool = True, - ): - super(FocalLoss, self).__init__() + ) -> None: + """Initializes the FocalLoss instance. + + Args: + apply_nonlin (nn.Module or None): Optional non-linearity to apply to logits (e.g., softmax or sigmoid). + alpha (float or torch.Tensor, optional): Weighting factor for class imbalance. Can be: + - None: Equal weighting. + - float: Class at `balance_index` is weighted by `alpha`, others by 1 - `alpha`. + - Tensor: Direct per-class weights. + gamma (float): Focusing parameter for down-weighting easy examples (y > 0). + balance_index (int): Index of the class to apply `alpha` to when `alpha` is a float. + smooth (float): Label smoothing factor (0 to 1). 
+ size_average (bool): If True, average the loss over the batch. If False, sum the loss. + """ + super().__init__() self.apply_nonlin = apply_nonlin self.alpha = alpha self.gamma = gamma @@ -47,11 +102,21 @@ def __init__( self.smooth = smooth self.size_average = size_average - if self.smooth is not None: - if self.smooth < 0 or self.smooth > 1.0: - raise ValueError("smooth value should be in [0,1]") + if self.smooth is not None and (self.smooth < 0 or self.smooth > 1.0): + msg = "smooth value should be in [0,1]" + raise ValueError(msg) + + def forward(self, logit: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """Computes the focal loss between `logit` predictions and ground-truth `target`. + + Args: + logit (torch.Tensor): The predicted logits of shape (B, C, ...) where B is batch size and C is the number of + classes. + target (torch.Tensor): The ground-truth class indices of shape (B, 1, ...) or broadcastable to logit. - def forward(self, logit, target): + Returns: + torch.Tensor: Computed focal loss value (averaged or summed depending on `size_average`). + """ if self.apply_nonlin is not None: logit = self.apply_nonlin(logit) num_class = logit.shape[1] @@ -66,7 +131,7 @@ def forward(self, logit, target): alpha = self.alpha if alpha is None: alpha = torch.ones(num_class, 1) - elif isinstance(alpha, (list, np.ndarray)): + elif isinstance(alpha, (list | np.ndarray)): assert len(alpha) == num_class alpha = torch.FloatTensor(alpha).view(num_class, 1) alpha = alpha / alpha.sum() @@ -74,15 +139,14 @@ def forward(self, logit, target): alpha = torch.ones(num_class, 1) alpha = alpha * (1 - self.alpha) alpha[self.balance_index] = self.alpha - else: - raise TypeError("Not support alpha type") + msg = "Not support alpha type" + raise TypeError(msg) if alpha.device != logit.device: alpha = alpha.to(logit.device) idx = target.cpu().long() - one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_() one_hot_key = one_hot_key.scatter_(1, idx, 1) if one_hot_key.device != logit.device: @@ -98,13 +162,8 @@ def forward(self, logit, target): logpt = pt.log() gamma = self.gamma - alpha = alpha[idx] alpha = torch.squeeze(alpha) loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt - if self.size_average: - loss = loss.mean() - else: - loss = loss.sum() - return loss + return loss.mean() if self.size_average else loss.sum() diff --git a/src/anomalib/models/image/glass/torch_model.py b/src/anomalib/models/image/glass/torch_model.py index b311b47134..18f178fe76 100644 --- a/src/anomalib/models/image/glass/torch_model.py +++ b/src/anomalib/models/image/glass/torch_model.py @@ -1,22 +1,38 @@ +"""GLASS - Unsupervised anomaly detection via Gradient Ascent for Industrial Anomaly detection and localization. + +This module implements the GLASS model for unsupervised anomaly detection and localization. GLASS synthesizes both +global and local anomalies using Gaussian noise guided by gradient ascent to enhance weak defect detection in +industrial settings. 
+ +The model consists of: + - A feature extractor and feature adaptor to obtain robust normal representations + - A Global Anomaly Synthesis (GAS) module that perturbs features using Gaussian noise and gradient ascent with + truncated projection + - A Local Anomaly Synthesis (LAS) module that overlays augmented textures onto images using Perlin noise masks + - A shared discriminator trained with features from normal, global, and local synthetic samples + +Paper: `A Unified Anomaly Synthesis Strategy with Gradient Ascent for Industrial Anomaly Detection and Localization +` +""" + # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 import math -import numpy as np import torch -import torch.nn.functional as F -from scipy import ndimage +import torch.nn.functional as f from torch import nn from anomalib.models.components import TimmFeatureExtractor from anomalib.models.components.feature_extractors import dryrun_find_featuremap_dims -def init_weight(m): - """Initializes network weights using Xavier normal initialization for - linear layers and normal initialization for convolutional and batch - normalization layers. +def init_weight(m: nn.Module) -> None: + """Initializes network weights using Xavier normal initialization. + + Applies Xavier initialization for linear layers and normal initialization + for convolutional and batch normalization layers. """ if isinstance(m, torch.nn.Linear): torch.nn.init.xavier_normal_(m.weight) @@ -33,10 +49,11 @@ def _deduce_dims( layers: list[str], ) -> list[int | tuple[int, int]]: """Determines feature dimensions for each layer in the feature extractor. - Parameters: - feature_extractor: The backbone feature extractor - input_size: Input image dimensions - layers: List of layer names to extract features from + + Args: + feature_extractor: The backbone feature extractor + input_size: Input image dimensions + layers: List of layer names to extract features from """ dimensions_mapping = dryrun_find_featuremap_dims( feature_extractor, @@ -44,19 +61,18 @@ def _deduce_dims( layers, ) - n_features_original = [dimensions_mapping[layer]["num_features"] for layer in layers] - - return n_features_original + return [dimensions_mapping[layer]["num_features"] for layer in layers] class Preprocessing(torch.nn.Module): """Handles initial feature preprocessing across multiple input dimensions. + Input: List of features from different backbone layers Output: Processed features with consistent dimensionality """ - def __init__(self, input_dims, output_dim): - super(Preprocessing, self).__init__() + def __init__(self, input_dims: list[int | tuple[int, int]], output_dim: int) -> None: + super().__init__() self.input_dims = input_dims self.output_dim = output_dim @@ -65,127 +81,181 @@ def __init__(self, input_dims, output_dim): module = MeanMapper(output_dim) self.preprocessing_modules.append(module) - def forward(self, features): - _features = [] + def forward(self, features: list[torch.Tensor]) -> torch.Tensor: + """Applies preprocessing modules to a list of input feature tensors. + + Args: + features (list of torch.Tensor): List of feature maps from different + layers of the backbone network. Each tensor can have a different shape. + + Returns: + torch.Tensor: A single tensor with shape (B, N, D), where B is the batch size, + N is the number of feature maps, and D is the output dimension (`output_dim`). 
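# Illustrative sketch (not from the patch) of the pooling used by MeanMapper and
# Preprocessing: features with different channel counts are flattened and adaptively
# average-pooled to one shared dimension so they can be stacked. Shapes are assumptions.
import torch
import torch.nn.functional as F

output_dim = 1024
layer1 = torch.randn(8, 64 * 3 * 3)    # patch features from a shallow layer
layer2 = torch.randn(8, 128 * 3 * 3)   # deeper layer with a different width

pooled = [
    F.adaptive_avg_pool1d(x.reshape(len(x), 1, -1), output_dim).squeeze(1)
    for x in (layer1, layer2)
]
stacked = torch.stack(pooled, dim=1)   # (B, num_layers, output_dim)
print(stacked.shape)                   # torch.Size([8, 2, 1024])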
+ """ + features_ = [] for module, feature in zip(self.preprocessing_modules, features, strict=False): - _features.append(module(feature)) - return torch.stack(_features, dim=1) + features_.append(module(feature)) + return torch.stack(features_, dim=1) class MeanMapper(torch.nn.Module): """Maps input features to a fixed dimension using adaptive average pooling. + Input: Variable-sized feature tensors Output: Fixed-size feature representations """ - def __init__(self, preprocessing_dim): - super(MeanMapper, self).__init__() + def __init__(self, preprocessing_dim: int) -> None: + super().__init__() self.preprocessing_dim = preprocessing_dim - def forward(self, features): + def forward(self, features: torch.Tensor) -> torch.Tensor: + """Applies adaptive average pooling to reshape features to a fixed size. + + Args: + features (torch.Tensor): Input tensor of shape (B, *) where * denotes + any number of remaining dimensions. It is flattened before pooling. + + Returns: + torch.Tensor: Output tensor of shape (B, D), where D is `preprocessing_dim`. + """ features = features.reshape(len(features), 1, -1) - return F.adaptive_avg_pool1d(features, self.preprocessing_dim).squeeze(1) + return f.adaptive_avg_pool1d(features, self.preprocessing_dim).squeeze(1) class Aggregator(torch.nn.Module): """Aggregates and reshapes features to a target dimension. + Input: Multi-dimensional feature tensors Output: Reshaped and pooled features of specified target dimension """ - def __init__(self, target_dim): - super(Aggregator, self).__init__() + def __init__(self, target_dim: int) -> None: + super().__init__() self.target_dim = target_dim - def forward(self, features): + def forward(self, features: torch.Tensor) -> torch.Tensor: """Returns reshaped and average pooled features.""" features = features.reshape(len(features), 1, -1) - features = F.adaptive_avg_pool1d(features, self.target_dim) + features = f.adaptive_avg_pool1d(features, self.target_dim) return features.reshape(len(features), -1) class Projection(torch.nn.Module): """Multi-layer projection network for feature adaptation. - Parameters: - in_planes: Input feature dimension - out_planes: Output feature dimension - n_layers: Number of projection layers - layer_type: Type of intermediate layers + + Args: + in_planes: Input feature dimension + out_planes: Output feature dimension + n_layers: Number of projection layers + layer_type: Type of intermediate layers """ - def __init__(self, in_planes, out_planes=None, n_layers=1, layer_type=0): - super(Projection, self).__init__() + def __init__(self, in_planes: int, out_planes: int | None = None, n_layers: int = 1, layer_type: int = 0) -> None: + super().__init__() if out_planes is None: out_planes = in_planes self.layers = torch.nn.Sequential() - _in = None - _out = None + in_ = None + out = None for i in range(n_layers): - _in = in_planes if i == 0 else _out - _out = out_planes - self.layers.add_module(f"{i}fc", torch.nn.Linear(_in, _out)) - if i < n_layers - 1: - if layer_type > 1: - self.layers.add_module(f"{i}relu", torch.nn.LeakyReLU(0.2)) + in_ = in_planes if i == 0 else out + out = out_planes + self.layers.add_module(f"{i}fc", torch.nn.Linear(in_, out)) + if i < n_layers - 1 and layer_type > 1: + self.layers.add_module(f"{i}relu", torch.nn.LeakyReLU(0.2)) self.apply(init_weight) - def forward(self, x): - x = self.layers(x) - return x + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Applies the projection network to the input features. 
+ + Args: + x (torch.Tensor): Input tensor of shape (B, in_planes), where B is the batch size. + + Returns: + torch.Tensor: Transformed tensor of shape (B, out_planes). + """ + return self.layers(x) class Discriminator(torch.nn.Module): """Discriminator network for anomaly detection. - Parameters: - in_planes: Input feature dimension - n_layers: Number of layers - hidden: Hidden layer dimensions + + Args: + in_planes: Input feature dimension + n_layers: Number of layers + hidden: Hidden layer dimensions """ - def __init__(self, in_planes, n_layers=2, hidden=None): - super(Discriminator, self).__init__() + def __init__(self, in_planes: int, n_layers: int = 2, hidden: int | None = None) -> None: + super().__init__() - _hidden = in_planes if hidden is None else hidden + hidden_ = in_planes if hidden is None else hidden self.body = torch.nn.Sequential() for i in range(n_layers - 1): - _in = in_planes if i == 0 else _hidden - _hidden = int(_hidden // 1.5) if hidden is None else hidden + in_ = in_planes if i == 0 else hidden_ + hidden_ = int(hidden_ // 1.5) if hidden is None else hidden self.body.add_module( - "block%d" % (i + 1), + f"block{i + 1}", torch.nn.Sequential( - torch.nn.Linear(_in, _hidden), - torch.nn.BatchNorm1d(_hidden), + torch.nn.Linear(in_, hidden_), + torch.nn.BatchNorm1d(hidden_), torch.nn.LeakyReLU(0.2), ), ) self.tail = torch.nn.Sequential( - torch.nn.Linear(_hidden, 1, bias=False), + torch.nn.Linear(hidden_, 1, bias=False), torch.nn.Sigmoid(), ) self.apply(init_weight) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Performs a forward pass through the discriminator network. + + Args: + x (torch.Tensor): Input tensor of shape (B, in_planes), where B is the batch size. + + Returns: + torch.Tensor: Output tensor of shape (B, 1) containing probability scores. + """ x = self.body(x) - x = self.tail(x) - return x + return self.tail(x) class PatchMaker: """Handles patch-based processing of feature maps. - Methods: - patchify: Converts features into patches - unpatch_scores: Reshapes patch scores back to original dimensions - score: Computes final scores from patch-wise predictions + This class provides utilities for converting feature maps into patches, + reshaping patch scores back to original dimensions, and computing global + anomaly scores from patch-wise predictions. + + Attributes: + patchsize (int): Size of each patch (patchsize x patchsize). + stride (int or None): Stride used for patch extraction. Defaults to patchsize if None. + top_k (int): Number of top patch scores to consider. Used for score reduction. """ - def __init__(self, patchsize, top_k=0, stride=None): + def __init__(self, patchsize: int, top_k: int = 0, stride: int | None = None) -> None: self.patchsize = patchsize - self.stride = stride + self.stride = stride if stride is not None else patchsize self.top_k = top_k - def patchify(self, features, return_spatial_info=False): + def patchify( + self, + features: torch.Tensor, + return_spatial_info: bool = False, + ) -> tuple[torch.Tensor, list[int]] | torch.Tensor: + """Converts a batch of feature maps into patches. + + Args: + features (torch.Tensor): Input feature maps of shape (B, C, H, W). + return_spatial_info (bool): If True, also returns spatial patch count. Default is False. + + Returns: + torch.Tensor: Output tensor of shape (B, N, C, patchsize, patchsize), where N is number of patches. + list[int], optional: Number of patches in (height, width) dimensions, only if return_spatial_info is True. 
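# Illustrative sketch (not from the patch) of the patchify step documented above:
# torch.nn.Unfold extracts overlapping patchsize x patchsize neighbourhoods, padded
# so the patch grid matches the feature-map resolution. Sizes are assumptions.
import torch

patchsize, stride = 3, 1
features = torch.randn(2, 64, 28, 28)          # (B, C, H, W)
padding = (patchsize - 1) // 2
unfolder = torch.nn.Unfold(kernel_size=patchsize, stride=stride, padding=padding, dilation=1)
unfolded = unfolder(features)                  # (B, C * p * p, N)
patches = unfolded.reshape(2, 64, patchsize, patchsize, -1).permute(0, 4, 1, 2, 3)
print(patches.shape)                           # torch.Size([2, 784, 64, 3, 3])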
+ """ padding = int((self.patchsize - 1) / 2) unfolder = torch.nn.Unfold( kernel_size=self.patchsize, @@ -210,41 +280,31 @@ def patchify(self, features, return_spatial_info=False): return unfolded_features, number_of_total_patches return unfolded_features - def unpatch_scores(self, x, batchsize): - return x.reshape(batchsize, -1, *x.shape[1:]) + @staticmethod + def unpatch_scores(x: torch.Tensor, batchsize: int) -> torch.Tensor: + """Reshapes patch scores back into per-batch format. - def score(self, x): - x = x[:, :, 0] - x = torch.max(x, dim=1).values - return x + Args: + x (torch.Tensor): Input tensor of shape (B * N, ...). + batchsize (int): Original batch size. + Returns: + torch.Tensor: Reshaped tensor of shape (B, N, ...). + """ + return x.reshape(batchsize, -1, *x.shape[1:]) -class RescaleSegmentor: - """Handles rescaling of patch-based anomaly scores to full-image dimensions. - Parameters: - target_size: Target image size for rescaling - smoothing: Gaussian smoothing parameter for score smoothing - """ + @staticmethod + def score(x: torch.Tensor) -> torch.Tensor: + """Computes final anomaly scores from patch-wise predictions. - def __init__(self, target_size=288): - self.target_size = target_size - self.smoothing = 4 + Args: + x (torch.Tensor): Patch scores of shape (B, N, 1). - def convert_to_segmentation(self, patch_scores): - with torch.no_grad(): - if isinstance(patch_scores, np.ndarray): - patch_scores = torch.from_numpy(patch_scores) - _scores = patch_scores - _scores = _scores.unsqueeze(1) - _scores = F.interpolate( - _scores, - size=self.target_size, - mode="bilinear", - align_corners=False, - ) - _scores = _scores.squeeze(1) - patch_scores = _scores.cpu().numpy() - return [ndimage.gaussian_filter(patch_score, sigma=self.smoothing) for patch_score in patch_scores] + Returns: + torch.Tensor: Final anomaly score per image, shape (B,). + """ + x = x[:, :, 0] # remove last dimension if singleton + return torch.max(x, dim=1).to_numpy() class GlassModel(nn.Module): @@ -259,7 +319,7 @@ def __init__( patchsize: int = 3, patchstride: int = 1, pre_trained: bool = True, - layers: list[str] = ["layer1", "layer2", "layer3"], + layers: list[str] | None = None, pre_proj: int = 1, dsc_layers: int = 2, dsc_hidden: int = 1024, @@ -267,6 +327,9 @@ def __init__( ) -> None: super().__init__() + if layers is None: + layers = ["layer1", "layer2", "layer3"] + self.backbone = backbone self.layers = layers self.input_shape = input_shape @@ -306,9 +369,24 @@ def __init__( self.patch_maker = PatchMaker(patchsize, stride=patchstride) - self.anomaly_segmentor = RescaleSegmentor() + def calculate_mean(self, images: torch.Tensor) -> torch.Tensor: + """Computes the mean feature embedding across a batch of images. - def calculate_mean(self, images): + This method performs a forward pass through the model to extract feature embeddings + for a batch of input images, optionally passing them through a pre-projection module. + It then reshapes the output and calculates the mean across the batch dimension. + + Args: + images (torch.Tensor): Input image tensor of shape (B, C, H, W), where: + - B is the batch size, + - C is the number of channels, + - H and W are height and width. + + Returns: + torch.Tensor: Mean embedding tensor of shape (N, D), where: + - N is the number of patches or tokens per image, + - D is the feature dimension. 
+ """ self.forward_modules.eval() with torch.no_grad(): if self.pre_proj > 0: @@ -320,14 +398,37 @@ def calculate_mean(self, images): outputs = outputs[0] if len(outputs) == 2 else outputs outputs = outputs.reshape(images.shape[0], -1, outputs.shape[-1]) - batch_mean = torch.mean(outputs, dim=0) - - return batch_mean + return torch.mean(outputs, dim=0) - def generate_embeddings(self, images, provide_patch_shapes=False, eval=False): - if not eval and not self.pre_trained: + def generate_embeddings( + self, + images: torch.Tensor, + evaluation: bool = False, + ) -> tuple[list[torch.Tensor], list[tuple[int, int]]]: + """Generates patch-wise feature embeddings for a batch of input images. + + This method performs a forward pass through the model's feature extraction pipeline, + processes selected intermediate layers, reshapes them into patches, aligns their spatial sizes, + and passes them through preprocessing and aggregation modules. + + Args: + images (torch.Tensor): Input images of shape (B, C, H, W), where: + - B is the batch size, + - C is the number of channels, + - H and W are the image height and width. + evaluation (bool, optional): Whether to run in evaluation mode (disabling gradients). + Default is False. + + Returns: + tuple[list[torch.Tensor], list[tuple[int, int]]]: + - A list of patch-level feature tensors, each of shape (N, D, P, P), + where N is the number of patches, D is the channel dimension, and P is patch size. + - A list of (height, width) tuples indicating the number of patches in each spatial dimension + for each corresponding feature level. + """ + if not evaluation and not self.pre_trained: self.forward_modules["feature_aggregator"].train() - features = self.forward_modules["feature_aggregator"](images, eval=eval) + features = self.forward_modules["feature_aggregator"](images) else: self.forward_modules["feature_aggregator"].eval() with torch.no_grad(): @@ -336,7 +437,7 @@ def generate_embeddings(self, images, provide_patch_shapes=False, eval=False): features = [features[layer] for layer in self.layers] for i, feat in enumerate(features): if len(feat.shape) == 3: - B, L, C = feat.shape + B, L, C = feat.shape # noqa: N806 features[i] = feat.reshape( B, int(math.sqrt(L)), @@ -350,33 +451,33 @@ def generate_embeddings(self, images, provide_patch_shapes=False, eval=False): ref_num_patches = patch_shapes[0] for i in range(1, len(patch_features)): - _features = patch_features[i] + features_ = patch_features[i] patch_dims = patch_shapes[i] - _features = _features.reshape( - _features.shape[0], + features_ = features_.reshape( + features_.shape[0], patch_dims[0], patch_dims[1], - *_features.shape[2:], + *features_.shape[2:], ) - _features = _features.permute(0, 3, 4, 5, 1, 2) - perm_base_shape = _features.shape - _features = _features.reshape(-1, *_features.shape[-2:]) - _features = F.interpolate( - _features.unsqueeze(1), + features_ = features_.permute(0, 3, 4, 5, 1, 2) + perm_base_shape = features_.shape + features_ = features_.reshape(-1, *features_.shape[-2:]) + features_ = f.interpolate( + features_.unsqueeze(1), size=(ref_num_patches[0], ref_num_patches[1]), mode="bilinear", align_corners=False, ) - _features = _features.squeeze(1) - _features = _features.reshape( + features_ = features_.squeeze(1) + features_ = features_.reshape( *perm_base_shape[:-2], ref_num_patches[0], ref_num_patches[1], ) - _features = _features.permute(0, 4, 5, 1, 2, 3) - _features = _features.reshape(len(_features), -1, *_features.shape[-3:]) - patch_features[i] = _features + features_ = 
features_.permute(0, 4, 5, 1, 2, 3) + features_ = features_.reshape(len(features_), -1, *features_.shape[-3:]) + patch_features[i] = features_ patch_features = [x.reshape(-1, *x.shape[-3:]) for x in patch_features] patch_features = self.forward_modules["preprocessing"](patch_features) @@ -384,20 +485,31 @@ def generate_embeddings(self, images, provide_patch_shapes=False, eval=False): return patch_features, patch_shapes - def forward(self, img, aug, evaluation=False): + def forward( + self, + img: torch.Tensor, + aug: torch.Tensor, + evaluation: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Forward pass to compute patch-wise feature embeddings for original and augmented images. + + Depending on whether a pre-projection module is used, this method optionally applies it to the + embeddings generated for both `img` and `aug`. If not, the embeddings are directly obtained and + `requires_grad` is enabled for them, likely for gradient-based optimization or anomaly generation. + """ if self.pre_proj > 0: fake_feats = self.pre_projection( - self.generate_embeddings(aug, eval=evaluation)[0], + self.generate_embeddings(aug, evaluation=evaluation)[0], ) fake_feats = fake_feats[0] if len(fake_feats) == 2 else fake_feats true_feats = self.pre_projection( - self.generate_embeddings(img, eval=evaluation)[0], + self.generate_embeddings(img, evaluation=evaluation)[0], ) true_feats = true_feats[0] if len(true_feats) == 2 else true_feats else: - fake_feats = self.generate_embeddings(aug, eval=evaluation)[0] + fake_feats = self.generate_embeddings(aug, evaluation=evaluation)[0] fake_feats.requires_grad = True - true_feats = self.generate_embeddings(img, eval=evaluation)[0] + true_feats = self.generate_embeddings(img, evaluation=evaluation)[0] true_feats.requires_grad = True return true_feats, fake_feats From 838bc50ed1d486a3e8b9538c93b5e335f1812122 Mon Sep 17 00:00:00 2001 From: Devansh Agarwal Date: Tue, 1 Jul 2025 17:34:04 +0530 Subject: [PATCH 11/23] Refactored code from lightning model to torch model Signed-off-by: Devansh Agarwal --- .../models/image/glass/lightning_model.py | 100 +---------- .../models/image/glass/torch_model.py | 155 ++++++++++++++++-- 2 files changed, 148 insertions(+), 107 deletions(-) diff --git a/src/anomalib/models/image/glass/lightning_model.py b/src/anomalib/models/image/glass/lightning_model.py index 56aed0f706..ad2744280e 100644 --- a/src/anomalib/models/image/glass/lightning_model.py +++ b/src/anomalib/models/image/glass/lightning_model.py @@ -18,13 +18,11 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -import math from typing import Any import torch from lightning.pytorch.utilities.types import STEP_OUTPUT -from torch import nn, optim -from torch.nn import functional as f +from torch import optim from torchvision.transforms.v2 import CenterCrop, Compose, Normalize, Resize from anomalib import LearningType @@ -36,7 +34,6 @@ from anomalib.pre_processing import PreProcessor from anomalib.visualization import Visualizer -from .loss import FocalLoss from .torch_model import GlassModel @@ -150,6 +147,7 @@ def __init__( self.model = GlassModel( input_shape=input_shape, + anomaly_source_path=anomaly_source_path, pretrain_embed_dim=pretrain_embed_dim, target_embed_dim=target_embed_dim, backbone=backbone, @@ -161,22 +159,19 @@ def __init__( dsc_layers=dsc_layers, dsc_hidden=dsc_hidden, dsc_margin=dsc_margin, + step=step, + svd=svd, + mining=mining, + noise=noise, + radius=radius, + p=p, ) self.c = torch.tensor([1]) - self.p = p - 
self.radius = radius - self.mining = mining - self.noise = noise - self.distribution = 0 self.lr = lr - self.step = step - self.svd = svd self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.focal_loss = FocalLoss() - if pre_proj > 0: self.proj_opt = optim.AdamW( self.model.pre_projection.parameters(), @@ -280,84 +275,7 @@ def training_step( self.backbone_opt.zero_grad() img = batch.image - aug, mask_s = self.augmentor(img) - if img is not None: - batch_size = img.shape[0] - - true_feats, fake_feats = self.model(img, aug) - - h_ratio = mask_s.shape[2] // int(math.sqrt(fake_feats.shape[0] // batch_size)) - w_ratio = mask_s.shape[3] // int(math.sqrt(fake_feats.shape[0] // batch_size)) - - mask_s_resized = f.interpolate( - mask_s.float(), - size=(mask_s.shape[2] // h_ratio, mask_s.shape[3] // w_ratio), - mode="nearest", - ) - mask_s_gt = mask_s_resized.reshape(-1, 1) - - noise = torch.normal(0, self.noise, true_feats.shape).to(self.dev) - gaus_feats = true_feats + noise - - center = self.c.repeat(img.shape[0], 1, 1) - center = center.reshape(-1, center.shape[-1]) - true_points = torch.concat( - [fake_feats[mask_s_gt[:, 0] == 0], true_feats], - dim=0, - ) - c_t_points = torch.concat([center[mask_s_gt[:, 0] == 0], center], dim=0) - dist_t = torch.norm(true_points - c_t_points, dim=1) - r_t = torch.tensor([torch.quantile(dist_t, q=self.radius)]).to(self.dev) - - for step in range(self.step + 1): - scores = self.model.discriminator(torch.cat([true_feats, gaus_feats])) - true_scores = scores[: len(true_feats)] - gaus_scores = scores[len(true_feats) :] - true_loss = nn.BCELoss()(true_scores, torch.zeros_like(true_scores)) - gaus_loss = nn.BCELoss()(gaus_scores, torch.ones_like(gaus_scores)) - bce_loss = true_loss + gaus_loss - - if step == self.step: - break - - grad = torch.autograd.grad(gaus_loss, [gaus_feats])[0] - grad_norm = torch.norm(grad, dim=1) - grad_norm = grad_norm.view(-1, 1) - grad_normalized = grad / (grad_norm + 1e-10) - - with torch.no_grad(): - gaus_feats.add_(0.001 * grad_normalized) - - fake_points = fake_feats[mask_s_gt[:, 0] == 1] - true_points = true_feats[mask_s_gt[:, 0] == 1] - c_f_points = center[mask_s_gt[:, 0] == 1] - dist_f = torch.norm(fake_points - c_f_points, dim=1) - proj_feats = c_f_points if self.svd == 1 else true_points - r = r_t if self.svd == 1 else 1 - - if self.svd == 1: - h = fake_points - proj_feats - h_norm = dist_f if self.svd == 1 else torch.norm(h, dim=1) - alpha = torch.clamp(h_norm, 2 * r, 4 * r) - proj = (alpha / (h_norm + 1e-10)).view(-1, 1) - h = proj * h - fake_points = proj_feats + h - fake_feats[mask_s_gt[:, 0] == 1] = fake_points - - fake_scores = self.model.discriminator(fake_feats) - - if self.p > 0: - fake_dist = (fake_scores - mask_s_gt) ** 2 - d_hard = torch.quantile(fake_dist, q=self.p) - fake_scores_ = fake_scores[fake_dist >= d_hard].unsqueeze(1) - mask_ = mask_s_gt[fake_dist >= d_hard].unsqueeze(1) - else: - fake_scores_ = fake_scores - mask_ = mask_s_gt - output = torch.cat([1 - fake_scores_, fake_scores_], dim=1) - focal_loss = self.focal_loss(output, mask_) - - loss = bce_loss + focal_loss + true_loss, gaus_loss, bce_loss, focal_loss, loss = self.model(img, self.c) loss.backward() if self.proj_opt is not None: diff --git a/src/anomalib/models/image/glass/torch_model.py b/src/anomalib/models/image/glass/torch_model.py index 18f178fe76..926de0d9b6 100644 --- a/src/anomalib/models/image/glass/torch_model.py +++ b/src/anomalib/models/image/glass/torch_model.py @@ -24,9 +24,12 @@ import torch.nn.functional as f 
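# Illustrative sketch (not from the patch) of the local anomaly synthesis input that
# this refactor moves into the torch model: the Perlin-based augmentor returns an
# augmented image plus the pixel-level mask that later supervises the discriminator.
# Constructing it without anomaly_source_path is an assumption (noise-only augmentation).
import torch
from anomalib.data.utils.generators.perlin import PerlinAnomalyGenerator

augmentor = PerlinAnomalyGenerator()           # anomaly_source_path assumed optional
img = torch.rand(2, 3, 256, 256)
aug, mask = augmentor(img)                     # augmented batch and anomaly masks
print(aug.shape, mask.shape)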
from torch import nn +from anomalib.data.utils.generators.perlin import PerlinAnomalyGenerator from anomalib.models.components import TimmFeatureExtractor from anomalib.models.components.feature_extractors import dryrun_find_featuremap_dims +from .loss import FocalLoss + def init_weight(m: nn.Module) -> None: """Initializes network weights using Xavier normal initialization. @@ -313,6 +316,7 @@ class GlassModel(nn.Module): def __init__( self, input_shape: tuple[int, int], # (H, W) + anomaly_source_path: str, pretrain_embed_dim: int = 1024, target_embed_dim: int = 1024, backbone: str = "resnet18", @@ -324,6 +328,13 @@ def __init__( dsc_layers: int = 2, dsc_hidden: int = 1024, dsc_margin: float = 0.5, + mining: int = 1, + noise: float = 0.015, + radius: float = 0.75, + p: float = 0.5, + lr: float = 0.0001, + step: int = 20, + svd: int = 0, ) -> None: super().__init__() @@ -335,6 +346,12 @@ def __init__( self.input_shape = input_shape self.pre_trained = pre_trained + self.augmentor = PerlinAnomalyGenerator(anomaly_source_path) + + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + self.focal_loss = FocalLoss() + self.forward_modules = torch.nn.ModuleDict({}) feature_aggregator = TimmFeatureExtractor( backbone=self.backbone, @@ -367,6 +384,15 @@ def __init__( hidden=self.dsc_hidden, ) + self.p = p + self.radius = radius + self.mining = mining + self.noise = noise + self.distribution = 0 + self.lr = lr + self.step = step + self.svd = svd + self.patch_maker = PatchMaker(patchsize, stride=patchstride) def calculate_mean(self, images: torch.Tensor) -> torch.Tensor: @@ -400,6 +426,41 @@ def calculate_mean(self, images: torch.Tensor) -> torch.Tensor: return torch.mean(outputs, dim=0) + def calculate_features(self, + img: torch.Tensor, + aug: torch.Tensor, + evaluation: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Calculate and return feature embeddings for the input and augmented images. + + Depending on whether a pre-projection module is used, this method optionally applies it to the + + Args: + img (torch.Tensor): The original input image tensor. + aug (torch.Tensor): The augmented image tensor. + evaluation (bool, optional): Whether the model is in evaluation mode. Defaults to False. + + Returns: + tuple[torch.Tensor, torch.Tensor]: A tuple containing the feature embeddings for the original + image (`true_feats`) and the augmented image (`fake_feats`). + """ + if self.pre_proj > 0: + fake_feats = self.pre_projection( + self.generate_embeddings(aug, evaluation=evaluation)[0], + ) + fake_feats = fake_feats[0] if len(fake_feats) == 2 else fake_feats + true_feats = self.pre_projection( + self.generate_embeddings(img, evaluation=evaluation)[0], + ) + true_feats = true_feats[0] if len(true_feats) == 2 else true_feats + else: + fake_feats = self.generate_embeddings(aug, evaluation=evaluation)[0] + fake_feats.requires_grad = True + true_feats = self.generate_embeddings(img, evaluation=evaluation)[0] + true_feats.requires_grad = True + + return true_feats, fake_feats + def generate_embeddings( self, images: torch.Tensor, @@ -488,8 +549,7 @@ def generate_embeddings( def forward( self, img: torch.Tensor, - aug: torch.Tensor, - evaluation: bool = False, + c: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Forward pass to compute patch-wise feature embeddings for original and augmented images. @@ -497,19 +557,82 @@ def forward( embeddings generated for both `img` and `aug`. 
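# Illustrative sketch (not from the patch) of the global anomaly synthesis loop in
# forward(): normal features are perturbed with Gaussian noise, nudged along the
# gradient of the discriminator loss, then kept inside a truncated shell around the
# normal features. The tiny discriminator, radius, and step count are assumptions.
import torch
from torch import nn

disc = nn.Sequential(nn.Linear(1024, 1), nn.Sigmoid())
true_feats = torch.randn(32, 1024)
gaus_feats = true_feats + torch.normal(0, 0.015, true_feats.shape)
gaus_feats.requires_grad_(True)

for _ in range(3):
    scores = disc(gaus_feats)
    gaus_loss = nn.BCELoss()(scores, torch.ones_like(scores))
    grad = torch.autograd.grad(gaus_loss, [gaus_feats])[0]
    grad = grad / (grad.norm(dim=1, keepdim=True) + 1e-10)
    with torch.no_grad():
        gaus_feats.add_(0.001 * grad)          # gradient ascent on the anomaly loss

# truncated projection: clamp the offset from the normal features to a shell [r, 2r]
r = 0.5
h = gaus_feats.detach() - true_feats
h_norm = h.norm(dim=1, keepdim=True)
gaus_feats = true_feats + torch.clamp(h_norm, r, 2 * r) / (h_norm + 1e-10) * h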
If not, the embeddings are directly obtained and `requires_grad` is enabled for them, likely for gradient-based optimization or anomaly generation. """ - if self.pre_proj > 0: - fake_feats = self.pre_projection( - self.generate_embeddings(aug, evaluation=evaluation)[0], - ) - fake_feats = fake_feats[0] if len(fake_feats) == 2 else fake_feats - true_feats = self.pre_projection( - self.generate_embeddings(img, evaluation=evaluation)[0], - ) - true_feats = true_feats[0] if len(true_feats) == 2 else true_feats + aug, mask_s = self.augmentor(img) + if img is not None: + batch_size = img.shape[0] + + true_feats, fake_feats = self.calculate_features(img, aug) + + h_ratio = mask_s.shape[2] // int(math.sqrt(fake_feats.shape[0] // batch_size)) + w_ratio = mask_s.shape[3] // int(math.sqrt(fake_feats.shape[0] // batch_size)) + + mask_s_resized = f.interpolate( + mask_s.float(), + size=(mask_s.shape[2] // h_ratio, mask_s.shape[3] // w_ratio), + mode="nearest", + ) + mask_s_gt = mask_s_resized.reshape(-1, 1) + + noise = torch.normal(0, self.noise, true_feats.shape).to(self.device) + gaus_feats = true_feats + noise + + center = c.repeat(img.shape[0], 1, 1) + center = center.reshape(-1, center.shape[-1]) + true_points = torch.concat( + [fake_feats[mask_s_gt[:, 0] == 0], true_feats], + dim=0, + ) + c_t_points = torch.concat([center[mask_s_gt[:, 0] == 0], center], dim=0) + dist_t = torch.norm(true_points - c_t_points, dim=1) + r_t = torch.tensor([torch.quantile(dist_t, q=self.radius)]).to(self.device) + + for step in range(self.step + 1): + scores = self.discriminator(torch.cat([true_feats, gaus_feats])) + true_scores = scores[: len(true_feats)] + gaus_scores = scores[len(true_feats) :] + true_loss = nn.BCELoss()(true_scores, torch.zeros_like(true_scores)) + gaus_loss = nn.BCELoss()(gaus_scores, torch.ones_like(gaus_scores)) + bce_loss = true_loss + gaus_loss + + if step == self.step: + break + + grad = torch.autograd.grad(gaus_loss, [gaus_feats])[0] + grad_norm = torch.norm(grad, dim=1) + grad_norm = grad_norm.view(-1, 1) + grad_normalized = grad / (grad_norm + 1e-10) + + with torch.no_grad(): + gaus_feats.add_(0.001 * grad_normalized) + + fake_points = fake_feats[mask_s_gt[:, 0] == 1] + true_points = true_feats[mask_s_gt[:, 0] == 1] + c_f_points = center[mask_s_gt[:, 0] == 1] + dist_f = torch.norm(fake_points - c_f_points, dim=1) + proj_feats = c_f_points if self.svd == 1 else true_points + r = r_t if self.svd == 1 else 1 + + if self.svd == 1: + h = fake_points - proj_feats + h_norm = dist_f if self.svd == 1 else torch.norm(h, dim=1) + alpha = torch.clamp(h_norm, 2 * r, 4 * r) + proj = (alpha / (h_norm + 1e-10)).view(-1, 1) + h = proj * h + fake_points = proj_feats + h + fake_feats[mask_s_gt[:, 0] == 1] = fake_points + + fake_scores = self.discriminator(fake_feats) + + if self.p > 0: + fake_dist = (fake_scores - mask_s_gt) ** 2 + d_hard = torch.quantile(fake_dist, q=self.p) + fake_scores_ = fake_scores[fake_dist >= d_hard].unsqueeze(1) + mask_ = mask_s_gt[fake_dist >= d_hard].unsqueeze(1) else: - fake_feats = self.generate_embeddings(aug, evaluation=evaluation)[0] - fake_feats.requires_grad = True - true_feats = self.generate_embeddings(img, evaluation=evaluation)[0] - true_feats.requires_grad = True + fake_scores_ = fake_scores + mask_ = mask_s_gt + output = torch.cat([1 - fake_scores_, fake_scores_], dim=1) + focal_loss = self.focal_loss(output, mask_) - return true_feats, fake_feats + loss = bce_loss + focal_loss + return true_loss, gaus_loss, bce_loss, focal_loss, loss From 
1baa0b7274e82eec4a6cdf09c279ce6520078058 Mon Sep 17 00:00:00 2001 From: Devansh Agarwal Date: Wed, 2 Jul 2025 12:32:33 +0530 Subject: [PATCH 12/23] GPU bug fixed Signed-off-by: Devansh Agarwal --- src/anomalib/models/image/glass/lightning_model.py | 8 +++----- src/anomalib/models/image/glass/torch_model.py | 7 +++---- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/anomalib/models/image/glass/lightning_model.py b/src/anomalib/models/image/glass/lightning_model.py index ad2744280e..d2be528896 100644 --- a/src/anomalib/models/image/glass/lightning_model.py +++ b/src/anomalib/models/image/glass/lightning_model.py @@ -170,8 +170,6 @@ def __init__( self.c = torch.tensor([1]) self.lr = lr - self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") - if pre_proj > 0: self.proj_opt = optim.AdamW( self.model.pre_projection.parameters(), @@ -275,7 +273,7 @@ def training_step( self.backbone_opt.zero_grad() img = batch.image - true_loss, gaus_loss, bce_loss, focal_loss, loss = self.model(img, self.c) + true_loss, gaus_loss, bce_loss, focal_loss, loss = self.model(img, self.c, self.device) loss.backward() if self.proj_opt is not None: @@ -301,9 +299,9 @@ def on_train_start(self) -> None: with torch.no_grad(): for i, batch in enumerate(dataloader): if i == 0: - self.c = self.model.calculate_mean(batch.image.to(self.dev)) + self.c = self.model.calculate_mean(batch.image.to(self.device)) else: - self.c += self.model.calculate_mean(batch.image.to(self.dev)) + self.c += self.model.calculate_mean(batch.image.to(self.device)) self.c /= len(dataloader) diff --git a/src/anomalib/models/image/glass/torch_model.py b/src/anomalib/models/image/glass/torch_model.py index 926de0d9b6..80a79ed5a5 100644 --- a/src/anomalib/models/image/glass/torch_model.py +++ b/src/anomalib/models/image/glass/torch_model.py @@ -348,8 +348,6 @@ def __init__( self.augmentor = PerlinAnomalyGenerator(anomaly_source_path) - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.focal_loss = FocalLoss() self.forward_modules = torch.nn.ModuleDict({}) @@ -550,6 +548,7 @@ def forward( self, img: torch.Tensor, c: torch.Tensor | None = None, + device: torch.device | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Forward pass to compute patch-wise feature embeddings for original and augmented images. 
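# Illustrative sketch (not from the patch): the device handling these hunks converge
# on in the following patch, deriving the device from the input batch rather than
# storing it on the module, so newly created tensors land on the right device.
import torch

img = torch.randn(2, 3, 256, 256)              # may live on CPU or GPU
device = img.device
noise = torch.normal(0, 0.015, img.shape).to(device)
perturbed = img + noise                        # stays on the same device as the batch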
@@ -573,7 +572,7 @@ def forward( ) mask_s_gt = mask_s_resized.reshape(-1, 1) - noise = torch.normal(0, self.noise, true_feats.shape).to(self.device) + noise = torch.normal(0, self.noise, true_feats.shape).to(device) gaus_feats = true_feats + noise center = c.repeat(img.shape[0], 1, 1) @@ -584,7 +583,7 @@ def forward( ) c_t_points = torch.concat([center[mask_s_gt[:, 0] == 0], center], dim=0) dist_t = torch.norm(true_points - c_t_points, dim=1) - r_t = torch.tensor([torch.quantile(dist_t, q=self.radius)]).to(self.device) + r_t = torch.tensor([torch.quantile(dist_t, q=self.radius)]).to(device) for step in range(self.step + 1): scores = self.discriminator(torch.cat([true_feats, gaus_feats])) From f066b3c7d0a5467cbc4386d3c495cac9dd4a0588 Mon Sep 17 00:00:00 2001 From: Devansh Agarwal Date: Wed, 2 Jul 2025 15:36:23 +0530 Subject: [PATCH 13/23] used image device in torch model Signed-off-by: Devansh Agarwal --- .../models/image/glass/lightning_model.py | 30 ++++++++++++++++++- .../models/image/glass/torch_model.py | 2 +- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/src/anomalib/models/image/glass/lightning_model.py b/src/anomalib/models/image/glass/lightning_model.py index d2be528896..eb2dfa6903 100644 --- a/src/anomalib/models/image/glass/lightning_model.py +++ b/src/anomalib/models/image/glass/lightning_model.py @@ -273,7 +273,7 @@ def training_step( self.backbone_opt.zero_grad() img = batch.image - true_loss, gaus_loss, bce_loss, focal_loss, loss = self.model(img, self.c, self.device) + true_loss, gaus_loss, bce_loss, focal_loss, loss = self.model(img, self.c) loss.backward() if self.proj_opt is not None: @@ -288,6 +288,34 @@ def training_step( self.log("focal_loss", focal_loss, prog_bar=True) self.log("loss", loss, prog_bar=True) + def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: + """Validation step for GLASS model. + + This method is called during validation to compute the loss and metrics. + + Args: + batch (Batch): Input batch containing images and metadata + *args: Additional positional arguments + **kwargs: Additional keyword arguments + + Returns: + STEP_OUTPUT: Dictionary containing loss values and metrics + """ + del args, kwargs + self.model.forward_modules.eval() + if self.model.pre_proj > 0: + self.model.pre_projection.eval() + self.model.discriminator.eval() + + img = batch.image + true_loss, gaus_loss, bce_loss, focal_loss, loss = self.model(img, self.c, self.device) + + self.log("val/true_loss", true_loss, prog_bar=True) + self.log("val/gaus_loss", gaus_loss, prog_bar=True) + self.log("val/bce_loss", bce_loss, prog_bar=True) + self.log("val/focal_loss", focal_loss, prog_bar=True) + self.log("val/loss", loss, prog_bar=True) + def on_train_start(self) -> None: """Initialize model by computing mean feature representation across training dataset. diff --git a/src/anomalib/models/image/glass/torch_model.py b/src/anomalib/models/image/glass/torch_model.py index 80a79ed5a5..2f479587c7 100644 --- a/src/anomalib/models/image/glass/torch_model.py +++ b/src/anomalib/models/image/glass/torch_model.py @@ -548,7 +548,6 @@ def forward( self, img: torch.Tensor, c: torch.Tensor | None = None, - device: torch.device | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Forward pass to compute patch-wise feature embeddings for original and augmented images. @@ -556,6 +555,7 @@ def forward( embeddings generated for both `img` and `aug`. 
If not, the embeddings are directly obtained and `requires_grad` is enabled for them, likely for gradient-based optimization or anomaly generation. """ + device = img.device aug, mask_s = self.augmentor(img) if img is not None: batch_size = img.shape[0] From 6e780b05f44a080190b99b71c6d3f8b7930178e9 Mon Sep 17 00:00:00 2001 From: Devansh Agarwal Date: Wed, 2 Jul 2025 15:38:58 +0530 Subject: [PATCH 14/23] fixed bug Signed-off-by: Devansh Agarwal --- .../models/image/glass/lightning_model.py | 28 ------------------- 1 file changed, 28 deletions(-) diff --git a/src/anomalib/models/image/glass/lightning_model.py b/src/anomalib/models/image/glass/lightning_model.py index eb2dfa6903..99e82e489b 100644 --- a/src/anomalib/models/image/glass/lightning_model.py +++ b/src/anomalib/models/image/glass/lightning_model.py @@ -288,34 +288,6 @@ def training_step( self.log("focal_loss", focal_loss, prog_bar=True) self.log("loss", loss, prog_bar=True) - def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: - """Validation step for GLASS model. - - This method is called during validation to compute the loss and metrics. - - Args: - batch (Batch): Input batch containing images and metadata - *args: Additional positional arguments - **kwargs: Additional keyword arguments - - Returns: - STEP_OUTPUT: Dictionary containing loss values and metrics - """ - del args, kwargs - self.model.forward_modules.eval() - if self.model.pre_proj > 0: - self.model.pre_projection.eval() - self.model.discriminator.eval() - - img = batch.image - true_loss, gaus_loss, bce_loss, focal_loss, loss = self.model(img, self.c, self.device) - - self.log("val/true_loss", true_loss, prog_bar=True) - self.log("val/gaus_loss", gaus_loss, prog_bar=True) - self.log("val/bce_loss", bce_loss, prog_bar=True) - self.log("val/focal_loss", focal_loss, prog_bar=True) - self.log("val/loss", loss, prog_bar=True) - def on_train_start(self) -> None: """Initialize model by computing mean feature representation across training dataset. From b1be6f5da85f4a0ae380c26af14b56ba5dc67f20 Mon Sep 17 00:00:00 2001 From: Devansh Agarwal Date: Fri, 11 Jul 2025 22:44:05 +0530 Subject: [PATCH 15/23] Added validation step Signed-off-by: Devansh Agarwal --- .../models/image/glass/lightning_model.py | 24 ++++ .../models/image/glass/torch_model.py | 122 ++++++++++++++++-- 2 files changed, 137 insertions(+), 9 deletions(-) diff --git a/src/anomalib/models/image/glass/lightning_model.py b/src/anomalib/models/image/glass/lightning_model.py index 99e82e489b..47c47807d8 100644 --- a/src/anomalib/models/image/glass/lightning_model.py +++ b/src/anomalib/models/image/glass/lightning_model.py @@ -288,6 +288,30 @@ def training_step( self.log("focal_loss", focal_loss, prog_bar=True) self.log("loss", loss, prog_bar=True) + def validation_step( + self, + batch: Batch, + batch_idx: int, + ) -> STEP_OUTPUT: + """Performs a single validation step during model evaluation. + + Args: + batch (Batch): A batch of input data, typically containing images and ground truth labels. + batch_idx (int): Index of the batch (unused in this function). + + Returns: + STEP_OUTPUT: Output of the validation step, usually containing predictions and any associated metrics. 
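# Illustrative sketch (not from the patch) of the validation contract used above: the
# torch model returns an InferenceBatch named tuple, which the Lightning module
# unpacks into the batch via `_asdict()`. Shapes are assumptions.
import torch
from anomalib.data import InferenceBatch

predictions = InferenceBatch(
    pred_score=torch.rand(2),                  # image-level anomaly scores
    anomaly_map=torch.rand(2, 256, 256),       # pixel-level anomaly maps
)
print(sorted(predictions._asdict()))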
+ """ + del batch_idx + self.model.forward_modules.eval() + + if self.model.pre_proj > 0: + self.model.pre_projection.eval() + self.model.discriminator.eval() + + predictions = self.model(batch.image, self.c) + return batch.update(**predictions._asdict()) + def on_train_start(self) -> None: """Initialize model by computing mean feature representation across training dataset. diff --git a/src/anomalib/models/image/glass/torch_model.py b/src/anomalib/models/image/glass/torch_model.py index 2f479587c7..3910b7ebf6 100644 --- a/src/anomalib/models/image/glass/torch_model.py +++ b/src/anomalib/models/image/glass/torch_model.py @@ -20,10 +20,13 @@ import math +import kornia.filters as kf +import numpy as np import torch import torch.nn.functional as f from torch import nn +from anomalib.data import InferenceBatch from anomalib.data.utils.generators.perlin import PerlinAnomalyGenerator from anomalib.models.components import TimmFeatureExtractor from anomalib.models.components.feature_extractors import dryrun_find_featuremap_dims @@ -307,7 +310,60 @@ def score(x: torch.Tensor) -> torch.Tensor: torch.Tensor: Final anomaly score per image, shape (B,). """ x = x[:, :, 0] # remove last dimension if singleton - return torch.max(x, dim=1).to_numpy() + return torch.max(x, dim=1).values + + +class RescaleSegmentor: + """A utility class for rescaling and smoothing patch-level anomaly scores to generate segmentation masks. + + Attributes: + target_size (int): The spatial size (height and width) to which patch scores will be rescaled. + smoothing (int): The standard deviation used for Gaussian smoothing. + """ + + def __init__(self, target_size: int = 288) -> None: + """Initializes the RescaleSegmentor. + + Args: + target_size (int, optional): The desired output size (height/width) of segmentation maps. Defaults to 288. + """ + self.target_size = target_size + self.smoothing = 4 + + def convert_to_segmentation( + self, + patch_scores: np.ndarray | torch.Tensor, + device: torch.device, + ) -> list[torch.Tensor]: + """Converts patch-level scores to smoothed segmentation masks. + + Args: + patch_scores (np.ndarray | torch.Tensor): Patch-wise scores of shape [N, H, W]. + device (torch.device): Device on which to perform computation. + + Returns: + List[torch.Tensor]: A list of segmentation masks, each of shape [H, W], + rescaled to `target_size` and smoothed. + """ + with torch.no_grad(): + if isinstance(patch_scores, np.ndarray): + patch_scores = torch.from_numpy(patch_scores) + + scores = patch_scores.to(device) + scores = scores.unsqueeze(1) # [N, 1, H, W] + scores = f.interpolate( + scores, size=self.target_size, mode="bilinear", align_corners=False, + ) + patch_scores = scores.squeeze(1) # [N, H, W] + + patch_stack = patch_scores.unsqueeze(1) # [N, 1, H, W] + smoothed_stack = kf.gaussian_blur2d( + patch_stack, + kernel_size=(5, 5), + sigma=(self.smoothing, self.smoothing), + ) + + return [s.squeeze(0) for s in smoothed_stack] # List of [H, W] tensors class GlassModel(nn.Module): @@ -393,6 +449,8 @@ def __init__( self.patch_maker = PatchMaker(patchsize, stride=patchstride) + self.anomaly_segmentor = RescaleSegmentor(target_size=input_shape[:]) + def calculate_mean(self, images: torch.Tensor) -> torch.Tensor: """Computes the mean feature embedding across a batch of images. @@ -544,6 +602,34 @@ def generate_embeddings( return patch_features, patch_shapes + def calculate_anomaly_scores(self, images: torch.Tensor) -> torch.Tensor: + """Calculates anomaly scores and segmentation masks for a batch of input images. 
+ + Args: + images (torch.Tensor): Batch of input images of shape [B, C, H, W]. + + Returns: + tuple[torch.Tensor, list[torch.Tensor]]: + - image_scores: Tensor of anomaly scores per image, shape [B]. + - masks: List of segmentation masks for each image, each of shape [H, W]. + """ + with torch.no_grad(): + patch_features, patch_shapes = self.generate_embeddings(images, evaluation=True) + if self.pre_proj > 0: + patch_features = self.pre_projection(patch_features) + patch_features = patch_features[0] if len(patch_features) == 2 else patch_features + + patch_scores = image_scores = self.discriminator(patch_features) + patch_scores = self.patch_maker.unpatch_scores(patch_scores, images.shape[0]) + scales = patch_shapes[0] + patch_scores = patch_scores.reshape(images.shape[0], scales[0], scales[1]) + masks = self.anomaly_segmentor.convert_to_segmentation(patch_scores, device=images.device) + + image_scores = self.patch_maker.unpatch_scores(image_scores, batchsize=images.shape[0]) + image_scores = self.patch_maker.score(image_scores) + + return image_scores, masks + def forward( self, img: torch.Tensor, @@ -596,13 +682,26 @@ def forward( if step == self.step: break - grad = torch.autograd.grad(gaus_loss, [gaus_feats])[0] - grad_norm = torch.norm(grad, dim=1) - grad_norm = grad_norm.view(-1, 1) - grad_normalized = grad / (grad_norm + 1e-10) + if self.training: + grad = torch.autograd.grad(gaus_loss, [gaus_feats])[0] + grad_norm = torch.norm(grad, dim=1) + grad_norm = grad_norm.view(-1, 1) + grad_normalized = grad / (grad_norm + 1e-10) - with torch.no_grad(): - gaus_feats.add_(0.001 * grad_normalized) + with torch.no_grad(): + gaus_feats.add_(0.001 * grad_normalized) + + if (step + 1) % 5 == 0: + dist_g = torch.norm(gaus_feats - center, dim=1) + proj_feats = center if self.svd == 1 else true_feats + r = r_t if self.svd == 1 else 0.5 + + h = gaus_feats - proj_feats + h_norm = dist_g if self.svd == 1 else torch.norm(h, dim=1) + alpha = torch.clamp(h_norm, r, 2 * r) + proj = (alpha / (h_norm + 1e-10)).view(-1, 1) + h = proj * h + gaus_feats = proj_feats + h fake_points = fake_feats[mask_s_gt[:, 0] == 1] true_points = true_feats[mask_s_gt[:, 0] == 1] @@ -633,5 +732,10 @@ def forward( output = torch.cat([1 - fake_scores_, fake_scores_], dim=1) focal_loss = self.focal_loss(output, mask_) - loss = bce_loss + focal_loss - return true_loss, gaus_loss, bce_loss, focal_loss, loss + if self.training: + loss = bce_loss + focal_loss + return true_loss, gaus_loss, bce_loss, focal_loss, loss + + anomaly_scores, masks = self.calculate_anomaly_scores(img) + masks = torch.stack(masks) + return InferenceBatch(pred_score=anomaly_scores, anomaly_map=masks) From d5affe49e7709614990acd3b54f63621d180ba45 Mon Sep 17 00:00:00 2001 From: Devansh Agarwal Date: Mon, 28 Jul 2025 23:19:51 +0530 Subject: [PATCH 16/23] Refactored code for better readability Signed-off-by: Devansh Agarwal --- .../components/feature_extractors/__init__.py | 1 + src/anomalib/models/image/__init__.py | 2 +- src/anomalib/models/image/glass/__init__.py | 6 +- .../models/image/glass/lightning_model.py | 120 +++++++----------- src/anomalib/models/image/glass/loss.py | 70 +++++----- .../models/image/glass/torch_model.py | 107 ++++++++-------- 6 files changed, 143 insertions(+), 163 deletions(-) diff --git a/src/anomalib/models/components/feature_extractors/__init__.py b/src/anomalib/models/components/feature_extractors/__init__.py index b9936e793d..66a2f36c34 100644 --- a/src/anomalib/models/components/feature_extractors/__init__.py +++ 
b/src/anomalib/models/components/feature_extractors/__init__.py @@ -28,6 +28,7 @@ from .timm import TimmFeatureExtractor from .utils import dryrun_find_featuremap_dims + __all__ = [ "dryrun_find_featuremap_dims", "TimmFeatureExtractor", diff --git a/src/anomalib/models/image/__init__.py b/src/anomalib/models/image/__init__.py index 6da89cbb12..4f2c36dae6 100644 --- a/src/anomalib/models/image/__init__.py +++ b/src/anomalib/models/image/__init__.py @@ -85,5 +85,5 @@ "Supersimplenet", "Uflow", "VlmAd", - "WinClip" + "WinClip", ] diff --git a/src/anomalib/models/image/glass/__init__.py b/src/anomalib/models/image/glass/__init__.py index a3070bacf4..cac6c015e8 100644 --- a/src/anomalib/models/image/glass/__init__.py +++ b/src/anomalib/models/image/glass/__init__.py @@ -1,3 +1,6 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + """GLASS - Unsupervised anomaly detection via Gradient Ascent for Industrial Anomaly detection and localization. This module implements the GLASS model for unsupervised anomaly detection and localization. GLASS synthesizes both @@ -15,9 +18,6 @@ ` """ -# Copyright (C) 2025 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - from .lightning_model import Glass __all__ = ["Glass"] diff --git a/src/anomalib/models/image/glass/lightning_model.py b/src/anomalib/models/image/glass/lightning_model.py index 47c47807d8..3f28f04b2c 100644 --- a/src/anomalib/models/image/glass/lightning_model.py +++ b/src/anomalib/models/image/glass/lightning_model.py @@ -1,3 +1,6 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + """GLASS - Unsupervised anomaly detection via Gradient Ascent for Industrial Anomaly detection and localization. This module implements the GLASS model for unsupervised anomaly detection and localization. GLASS synthesizes both @@ -15,19 +18,14 @@ ` """ -# Copyright (C) 2025 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - from typing import Any -import torch from lightning.pytorch.utilities.types import STEP_OUTPUT from torch import optim from torchvision.transforms.v2 import CenterCrop, Compose, Normalize, Resize from anomalib import LearningType from anomalib.data import Batch -from anomalib.data.utils.generators.perlin import PerlinAnomalyGenerator from anomalib.metrics import Evaluator from anomalib.models.components import AnomalibModule from anomalib.post_processing import PostProcessor @@ -66,15 +64,15 @@ class Glass(AnomalibModule): Defaults to `True`. layers (list[str], optional): List of backbone layers to extract features from. Defaults to `["layer1", "layer2", "layer3"]`. - pre_proj (int, optional): Number of projection layers used in the feature adaptor (e.g., MLP before + pre_projection (int, optional): Number of projection layers used in the feature adaptor (e.g., MLP before discriminator). Defaults to `1`. - dsc_layers (int, optional): Number of layers in the discriminator network. + discriminator_layers (int, optional): Number of layers in the discriminator network. Defaults to `2`. - dsc_hidden (int, optional): Number of hidden units in each discriminator layer. + discriminator_hidden (int, optional): Number of hidden units in each discriminator layer. Defaults to `1024`. - dsc_margin (float, optional): Margin used for contrastive or binary classification loss in discriminator - training. + discriminator_margin (float, optional): Margin used for contrastive or binary classification loss in + discriminator training. Defaults to `0.5`. 
pre_processor (PreProcessor | bool, optional): reprocessing module or flag to enable default preprocessing. Set to `True` to apply default normalization and resizing. @@ -95,10 +93,10 @@ class Glass(AnomalibModule): radius (float, optional): Radius parameter used for truncated projection in the anomaly synthesis strategy. Determines the range for valid synthetic anomalies in the hypersphere or manifold. Defaults to `0.75`. - p (float, optional): Probability used in random selection logic, such as anomaly mask generation or augmentation - choice. + random_selection_prob (float, optional): Probability used in random selection logic, such as anomaly mask + generation or augmentation choice. Defaults to `0.5`. - lr (float, optional): Learning rate for training the feature adaptor and discriminator networks. + learning_rate (float, optional): Learning rate for training the feature adaptor and discriminator networks. Defaults to `0.0001`. step (int, optional): Number of gradient ascent steps for anomaly synthesis. Defaults to `20`. @@ -108,8 +106,8 @@ class Glass(AnomalibModule): def __init__( self, - input_shape: tuple[int, int], - anomaly_source_path: str, + input_shape: tuple[int, int] = (256, 256), + anomaly_source_path: str | None = None, backbone: str = "resnet18", pretrain_embed_dim: int = 1024, target_embed_dim: int = 1024, @@ -117,10 +115,10 @@ def __init__( patchstride: int = 1, pre_trained: bool = True, layers: list[str] | None = None, - pre_proj: int = 1, - dsc_layers: int = 2, - dsc_hidden: int = 1024, - dsc_margin: float = 0.5, + pre_projection: int = 1, + discriminator_layers: int = 2, + discriminator_hidden: int = 1024, + discriminator_margin: float = 0.5, pre_processor: PreProcessor | bool = True, post_processor: PostProcessor | bool = True, evaluator: Evaluator | bool = True, @@ -128,8 +126,8 @@ def __init__( mining: int = 1, noise: float = 0.015, radius: float = 0.75, - p: float = 0.5, - lr: float = 0.0001, + random_selection_prob: float = 0.5, + learning_rate: float = 0.0001, step: int = 20, svd: int = 0, ) -> None: @@ -143,8 +141,6 @@ def __init__( if layers is None: layers = ["layer1", "layer2", "layer3"] - self.augmentor = PerlinAnomalyGenerator(anomaly_source_path) - self.model = GlassModel( input_shape=input_shape, anomaly_source_path=anomaly_source_path, @@ -155,34 +151,33 @@ def __init__( patchsize=patchsize, patchstride=patchstride, layers=layers, - pre_proj=pre_proj, - dsc_layers=dsc_layers, - dsc_hidden=dsc_hidden, - dsc_margin=dsc_margin, + pre_projection=pre_projection, + discriminator_layers=discriminator_layers, + discriminator_hidden=discriminator_hidden, + discriminator_margin=discriminator_margin, step=step, svd=svd, mining=mining, noise=noise, radius=radius, - p=p, + random_selection_prob=random_selection_prob, ) - self.c = torch.tensor([1]) - self.lr = lr + self.learning_rate = learning_rate - if pre_proj > 0: - self.proj_opt = optim.AdamW( - self.model.pre_projection.parameters(), - self.lr, + if pre_projection > 0: + self.projection_opt = optim.AdamW( + self.model.projection.parameters(), + self.learning_rate, weight_decay=1e-5, ) else: - self.proj_opt = None + self.projection_opt = None if not pre_trained: self.backbone_opt = optim.AdamW( - self.mosdel.forward_modules["feature_aggregator"].backbone.parameters(), - self.lr, + self.model.forward_modules["feature_aggregator"].backbone.parameters(), + self.learning_rate, ) else: self.backbone_opt = None @@ -242,13 +237,9 @@ def configure_optimizers(self) -> optim.Optimizer: Returns: Optimizer: AdamW 
Optimizer for the discriminator. """ - return optim.AdamW(self.model.discriminator.parameters(), lr=self.lr * 2) + return optim.AdamW(self.model.discriminator.parameters(), lr=self.learning_rate * 2) - def training_step( - self, - batch: Batch, - batch_idx: int, - ) -> STEP_OUTPUT: + def training_step(self, batch: Batch, batch_idx: int) -> STEP_OUTPUT: """Training step for GLASS model. Args: @@ -259,28 +250,27 @@ def training_step( STEP_OUTPUT: Dictionary containing loss values and metrics """ del batch_idx - dsc_opt = self.optimizers() + discriminator_opt = self.optimizers() self.model.forward_modules.eval() - if self.model.pre_proj > 0: - self.model.pre_projection.train() + if self.model.pre_projection > 0: + self.model.projection.train() self.model.discriminator.train() - dsc_opt.zero_grad() - if self.proj_opt is not None: - self.proj_opt.zero_grad() + discriminator_opt.zero_grad() + if self.projection_opt is not None: + self.projection_opt.zero_grad() if self.backbone_opt is not None: self.backbone_opt.zero_grad() - img = batch.image - true_loss, gaus_loss, bce_loss, focal_loss, loss = self.model(img, self.c) + true_loss, gaus_loss, bce_loss, focal_loss, loss = self.model(batch.image) loss.backward() - if self.proj_opt is not None: - self.proj_opt.step() + if self.projection_opt is not None: + self.projection_opt.step() if self.backbone_opt is not None: self.backbone_opt.step() - dsc_opt.step() + discriminator_opt.step() self.log("true_loss", true_loss, prog_bar=True) self.log("gaus_loss", gaus_loss, prog_bar=True) @@ -288,11 +278,7 @@ def training_step( self.log("focal_loss", focal_loss, prog_bar=True) self.log("loss", loss, prog_bar=True) - def validation_step( - self, - batch: Batch, - batch_idx: int, - ) -> STEP_OUTPUT: + def validation_step(self, batch: Batch, batch_idx: int) -> STEP_OUTPUT: """Performs a single validation step during model evaluation. Args: @@ -305,11 +291,11 @@ def validation_step( del batch_idx self.model.forward_modules.eval() - if self.model.pre_proj > 0: - self.model.pre_projection.eval() + if self.model.pre_projection > 0: + self.model.projection.eval() self.model.discriminator.eval() - predictions = self.model(batch.image, self.c) + predictions = self.model(batch.image) return batch.update(**predictions._asdict()) def on_train_start(self) -> None: @@ -319,15 +305,7 @@ def on_train_start(self) -> None: that serves as a reference point for the normal class distribution. """ dataloader = self.trainer.train_dataloader - - with torch.no_grad(): - for i, batch in enumerate(dataloader): - if i == 0: - self.c = self.model.calculate_mean(batch.image.to(self.device)) - else: - self.c += self.model.calculate_mean(batch.image.to(self.device)) - - self.c /= len(dataloader) + self.model.calculate_center(dataloader, self.device) @property def learning_type(self) -> LearningType: diff --git a/src/anomalib/models/image/glass/loss.py b/src/anomalib/models/image/glass/loss.py index db3c9da5cb..73365f6752 100644 --- a/src/anomalib/models/image/glass/loss.py +++ b/src/anomalib/models/image/glass/loss.py @@ -1,3 +1,12 @@ +# Original Code +# Copyright (c) 2021 @Hsuxu +# https://github.com/Hsuxu/Loss_ToolBox-PyTorch. +# SPDX-License-Identifier: Apache-2.0 +# +# Modified +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + """Focal Loss for multi-class classification with optional label smoothing and class weighting. 
This loss function is designed to address class imbalance by down-weighting easy examples and focusing training @@ -16,7 +25,7 @@ applying a specified non-linearity (e.g., softmax or sigmoid). Args: - apply_nonlin (nn.Module or None): Optional non-linearity to apply to the logits before loss computation. + apply_nonlinearity (nn.Module or None): Optional non-linearity to apply to the logits before loss computation. For example, use `nn.Softmax(dim=1)` or `nn.Sigmoid()` if logits are not normalized. alpha (float or torch.Tensor, optional): Class balancing factor. Can be: - None: Equal weighting for all classes. @@ -40,15 +49,6 @@ torch.Tensor: Scalar loss value (averaged or summed based on `size_average`). """ -# Original Code -# Copyright (c) 2021 @Hsuxu -# https://github.com/Hsuxu/Loss_ToolBox-PyTorch. -# SPDX-License-Identifier: Apache-2.0 -# -# Modified -# Copyright (C) 2025 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - import numpy as np import torch from torch import nn @@ -74,7 +74,7 @@ class FocalLoss(nn.Module): def __init__( self, - apply_nonlin: nn.Module | None = None, + apply_nonlinearity: nn.Module | None = None, alpha: float | torch.Tensor = None, gamma: float = 2, balance_index: int = 0, @@ -84,7 +84,7 @@ def __init__( """Initializes the FocalLoss instance. Args: - apply_nonlin (nn.Module or None): Optional non-linearity to apply to logits (e.g., softmax or sigmoid). + apply_nonlinearity (nn.Module or None): Optional non-linearity to apply to logits (e.g., softmax or sigmoid) alpha (float or torch.Tensor, optional): Weighting factor for class imbalance. Can be: - None: Equal weighting. - float: Class at `balance_index` is weighted by `alpha`, others by 1 - `alpha`. @@ -95,7 +95,7 @@ def __init__( size_average (bool): If True, average the loss over the batch. If False, sum the loss. """ super().__init__() - self.apply_nonlin = apply_nonlin + self.apply_nonlinearity = apply_nonlinearity self.alpha = alpha self.gamma = gamma self.balance_index = balance_index @@ -106,59 +106,59 @@ def __init__( msg = "smooth value should be in [0,1]" raise ValueError(msg) - def forward(self, logit: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Computes the focal loss between `logit` predictions and ground-truth `target`. Args: - logit (torch.Tensor): The predicted logits of shape (B, C, ...) where B is batch size and C is the number of - classes. + logits (torch.Tensor): The predicted logits of shape (B, C, ...) where B is batch size and C is the + number of classes. target (torch.Tensor): The ground-truth class indices of shape (B, 1, ...) or broadcastable to logit. Returns: torch.Tensor: Computed focal loss value (averaged or summed depending on `size_average`). 
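
        Example:
            A minimal usage sketch (illustrative only; the batch size and class count
            below are assumed):

            >>> loss_fn = FocalLoss(apply_nonlinearity=nn.Softmax(dim=1), gamma=2)
            >>> logits = torch.randn(8, 2)             # (B, C) raw scores for 2 classes
            >>> target = torch.randint(0, 2, (8, 1))   # (B, 1) ground-truth class indices
            >>> loss = loss_fn(logits, target)         # scalar focal loss (mean over batch)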
""" - if self.apply_nonlin is not None: - logit = self.apply_nonlin(logit) - num_class = logit.shape[1] - - if logit.dim() > 2: - logit = logit.view(logit.size(0), logit.size(1), -1) - logit = logit.permute(0, 2, 1).contiguous() - logit = logit.view(-1, logit.size(-1)) + if self.apply_nonlinearity is not None: + logits = self.apply_nonlinearity(logits) + num_classes = logits.shape[1] + + if logits.dim() > 2: + logits = logits.view(logits.size(0), logits.size(1), -1) + logits = logits.permute(0, 2, 1).contiguous() + logits = logits.view(-1, logits.size(-1)) target = torch.squeeze(target, 1) target = target.view(-1, 1) alpha = self.alpha if alpha is None: - alpha = torch.ones(num_class, 1) + alpha = torch.ones(num_classes, 1) elif isinstance(alpha, (list | np.ndarray)): - assert len(alpha) == num_class - alpha = torch.FloatTensor(alpha).view(num_class, 1) + assert len(alpha) == num_classes + alpha = torch.FloatTensor(alpha).view(num_classes, 1) alpha = alpha / alpha.sum() elif isinstance(alpha, float): - alpha = torch.ones(num_class, 1) + alpha = torch.ones(num_classes, 1) alpha = alpha * (1 - self.alpha) alpha[self.balance_index] = self.alpha else: msg = "Not support alpha type" raise TypeError(msg) - if alpha.device != logit.device: - alpha = alpha.to(logit.device) + if alpha.device != logits.device: + alpha = alpha.to(logits.device) idx = target.cpu().long() - one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_() + one_hot_key = torch.FloatTensor(target.size(0), num_classes).zero_() one_hot_key = one_hot_key.scatter_(1, idx, 1) - if one_hot_key.device != logit.device: - one_hot_key = one_hot_key.to(logit.device) + if one_hot_key.device != logits.device: + one_hot_key = one_hot_key.to(logits.device) if self.smooth: one_hot_key = torch.clamp( one_hot_key, - self.smooth / (num_class - 1), + self.smooth / (num_classes - 1), 1.0 - self.smooth, ) - pt = (one_hot_key * logit).sum(1) + self.smooth + pt = (one_hot_key * logits).sum(1) + self.smooth logpt = pt.log() gamma = self.gamma diff --git a/src/anomalib/models/image/glass/torch_model.py b/src/anomalib/models/image/glass/torch_model.py index 3910b7ebf6..faae16f06b 100644 --- a/src/anomalib/models/image/glass/torch_model.py +++ b/src/anomalib/models/image/glass/torch_model.py @@ -1,3 +1,6 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + """GLASS - Unsupervised anomaly detection via Gradient Ascent for Industrial Anomaly detection and localization. This module implements the GLASS model for unsupervised anomaly detection and localization. GLASS synthesizes both @@ -15,9 +18,6 @@ ` """ -# Copyright (C) 2025 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - import math import kornia.filters as kf @@ -25,6 +25,7 @@ import torch import torch.nn.functional as f from torch import nn +from torch.utils.data import dataloader from anomalib.data import InferenceBatch from anomalib.data.utils.generators.perlin import PerlinAnomalyGenerator @@ -300,7 +301,7 @@ def unpatch_scores(x: torch.Tensor, batchsize: int) -> torch.Tensor: return x.reshape(batchsize, -1, *x.shape[1:]) @staticmethod - def score(x: torch.Tensor) -> torch.Tensor: + def compute_score(x: torch.Tensor) -> torch.Tensor: """Computes final anomaly scores from patch-wise predictions. 
Args: @@ -371,8 +372,8 @@ class GlassModel(nn.Module): def __init__( self, - input_shape: tuple[int, int], # (H, W) - anomaly_source_path: str, + input_shape: tuple[int, int] = (256, 256), # (H, W) + anomaly_source_path: str | None = None, pretrain_embed_dim: int = 1024, target_embed_dim: int = 1024, backbone: str = "resnet18", @@ -380,15 +381,14 @@ def __init__( patchstride: int = 1, pre_trained: bool = True, layers: list[str] | None = None, - pre_proj: int = 1, - dsc_layers: int = 2, - dsc_hidden: int = 1024, - dsc_margin: float = 0.5, + pre_projection: int = 1, + discriminator_layers: int = 2, + discriminator_hidden: int = 1024, + discriminator_margin: float = 0.5, mining: int = 1, noise: float = 0.015, radius: float = 0.75, - p: float = 0.5, - lr: float = 0.0001, + random_selection_prob: float = 0.5, step: int = 20, svd: int = 0, ) -> None: @@ -421,29 +421,28 @@ def __init__( preadapt_aggregator = Aggregator(target_dim=target_embed_dim) self.forward_modules["preadapt_aggregator"] = preadapt_aggregator - self.pre_proj = pre_proj - if self.pre_proj > 0: - self.pre_projection = Projection( + self.pre_projection = pre_projection + if self.pre_projection > 0: + self.projection = Projection( self.target_embed_dimension, self.target_embed_dimension, - pre_proj, + pre_projection, ) - self.dsc_layers = dsc_layers - self.dsc_hidden = dsc_hidden - self.dsc_margin = dsc_margin + self.discriminator_layers = discriminator_layers + self.discriminator_hidden = discriminator_hidden + self.discriminator_margin = discriminator_margin self.discriminator = Discriminator( self.target_embed_dimension, - n_layers=self.dsc_layers, - hidden=self.dsc_hidden, + n_layers=self.discriminator_layers, + hidden=self.discriminator_hidden, ) - self.p = p + self.random_selection_prob = random_selection_prob self.radius = radius self.mining = mining self.noise = noise self.distribution = 0 - self.lr = lr self.step = step self.svd = svd @@ -451,36 +450,39 @@ def __init__( self.anomaly_segmentor = RescaleSegmentor(target_size=input_shape[:]) - def calculate_mean(self, images: torch.Tensor) -> torch.Tensor: - """Computes the mean feature embedding across a batch of images. + def calculate_center(self, dataloader: dataloader, device: torch.device) -> None: + """Calculates and updates the center embedding from a dataset. - This method performs a forward pass through the model to extract feature embeddings - for a batch of input images, optionally passing them through a pre-projection module. - It then reshapes the output and calculates the mean across the batch dimension. + This method runs the model in evaluation mode and computes the mean feature + representation (center) across the entire dataset. The center is used for + further downstream tasks such as anomaly detection or feature normalization. Args: - images (torch.Tensor): Input image tensor of shape (B, C, H, W), where: - - B is the batch size, - - C is the number of channels, - - H and W are height and width. + dataloader (DataLoader): A PyTorch DataLoader providing batches of data, + where each batch contains an 'image' attribute. + device (torch.device): The device on which tensors should be processed + (e.g., torch.device("cuda") or torch.device("cpu")). Returns: - torch.Tensor: Mean embedding tensor of shape (N, D), where: - - N is the number of patches or tokens per image, - - D is the feature dimension. + None: The method updates `self.center` in-place with the computed center tensor. 
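
        Example:
            A minimal calling sketch (illustrative only; ``train_loader`` is an assumed
            DataLoader whose batches expose an ``image`` tensor attribute):

            >>> model = GlassModel()
            >>> model.calculate_center(train_loader, torch.device("cpu"))
            >>> center = model.center  # mean patch embedding used as the normal-class center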
""" self.forward_modules.eval() + self.center = torch.tensor([1]) with torch.no_grad(): - if self.pre_proj > 0: - outputs = self.pre_projection(self.generate_embeddings(images)[0]) - outputs = outputs[0] if len(outputs) == 2 else outputs - else: - outputs = self._embed(images, evaluation=False)[0] + for i, batch in enumerate(dataloader): + if self.pre_projection > 0: + outputs = self.projection(self.generate_embeddings(batch.image.to(device))[0]) + outputs = outputs[0] if len(outputs) == 2 else outputs + else: + outputs = self._embed(batch.image.to(device), evaluation=False)[0] - outputs = outputs[0] if len(outputs) == 2 else outputs - outputs = outputs.reshape(images.shape[0], -1, outputs.shape[-1]) + outputs = outputs[0] if len(outputs) == 2 else outputs + outputs = outputs.reshape(batch.image.to(device).shape[0], -1, outputs.shape[-1]) - return torch.mean(outputs, dim=0) + if i == 0: + self.center = torch.mean(outputs, dim=0) + else: + self.center += torch.mean(outputs, dim=0) def calculate_features(self, img: torch.Tensor, @@ -500,12 +502,12 @@ def calculate_features(self, tuple[torch.Tensor, torch.Tensor]: A tuple containing the feature embeddings for the original image (`true_feats`) and the augmented image (`fake_feats`). """ - if self.pre_proj > 0: - fake_feats = self.pre_projection( + if self.pre_projection > 0: + fake_feats = self.projection( self.generate_embeddings(aug, evaluation=evaluation)[0], ) fake_feats = fake_feats[0] if len(fake_feats) == 2 else fake_feats - true_feats = self.pre_projection( + true_feats = self.projection( self.generate_embeddings(img, evaluation=evaluation)[0], ) true_feats = true_feats[0] if len(true_feats) == 2 else true_feats @@ -615,8 +617,8 @@ def calculate_anomaly_scores(self, images: torch.Tensor) -> torch.Tensor: """ with torch.no_grad(): patch_features, patch_shapes = self.generate_embeddings(images, evaluation=True) - if self.pre_proj > 0: - patch_features = self.pre_projection(patch_features) + if self.pre_projection > 0: + patch_features = self.projection(patch_features) patch_features = patch_features[0] if len(patch_features) == 2 else patch_features patch_scores = image_scores = self.discriminator(patch_features) @@ -626,14 +628,13 @@ def calculate_anomaly_scores(self, images: torch.Tensor) -> torch.Tensor: masks = self.anomaly_segmentor.convert_to_segmentation(patch_scores, device=images.device) image_scores = self.patch_maker.unpatch_scores(image_scores, batchsize=images.shape[0]) - image_scores = self.patch_maker.score(image_scores) + image_scores = self.patch_maker.compute_score(image_scores) return image_scores, masks def forward( self, img: torch.Tensor, - c: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Forward pass to compute patch-wise feature embeddings for original and augmented images. 
@@ -661,7 +662,7 @@ def forward( noise = torch.normal(0, self.noise, true_feats.shape).to(device) gaus_feats = true_feats + noise - center = c.repeat(img.shape[0], 1, 1) + center = self.center.repeat(img.shape[0], 1, 1) center = center.reshape(-1, center.shape[-1]) true_points = torch.concat( [fake_feats[mask_s_gt[:, 0] == 0], true_feats], @@ -721,9 +722,9 @@ def forward( fake_scores = self.discriminator(fake_feats) - if self.p > 0: + if self.random_selection_prob > 0: fake_dist = (fake_scores - mask_s_gt) ** 2 - d_hard = torch.quantile(fake_dist, q=self.p) + d_hard = torch.quantile(fake_dist, q=self.random_selection_prob) fake_scores_ = fake_scores[fake_dist >= d_hard].unsqueeze(1) mask_ = mask_s_gt[fake_dist >= d_hard].unsqueeze(1) else: From a1097e5b2cfeff14e4805b09a9e02d1562d7ed8e Mon Sep 17 00:00:00 2001 From: Devansh Agarwal Date: Thu, 31 Jul 2025 20:34:12 +0530 Subject: [PATCH 17/23] Set automatic optimization to False and made component functions Signed-off-by: Devansh Agarwal --- .../models/image/glass/components/__init__.py | 19 ++ .../image/glass/components/aggregator.py | 25 ++ .../image/glass/components/discriminator.py | 52 +++ .../image/glass/components/init_weight.py | 22 ++ .../image/glass/components/patch_maker.py | 90 +++++ .../image/glass/components/preprocessing.py | 66 ++++ .../image/glass/components/projection.py | 46 +++ .../glass/components/rescale_segmentor.py | 62 ++++ .../models/image/glass/lightning_model.py | 4 +- src/anomalib/models/image/glass/loss.py | 2 +- .../models/image/glass/torch_model.py | 314 +----------------- 11 files changed, 387 insertions(+), 315 deletions(-) create mode 100644 src/anomalib/models/image/glass/components/__init__.py create mode 100644 src/anomalib/models/image/glass/components/aggregator.py create mode 100644 src/anomalib/models/image/glass/components/discriminator.py create mode 100644 src/anomalib/models/image/glass/components/init_weight.py create mode 100644 src/anomalib/models/image/glass/components/patch_maker.py create mode 100644 src/anomalib/models/image/glass/components/preprocessing.py create mode 100644 src/anomalib/models/image/glass/components/projection.py create mode 100644 src/anomalib/models/image/glass/components/rescale_segmentor.py diff --git a/src/anomalib/models/image/glass/components/__init__.py b/src/anomalib/models/image/glass/components/__init__.py new file mode 100644 index 0000000000..8695d68fe3 --- /dev/null +++ b/src/anomalib/models/image/glass/components/__init__.py @@ -0,0 +1,19 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Utility functions for GLASS Model.""" + +from .aggregator import Aggregator +from .discriminator import Discriminator +from .patch_maker import PatchMaker +from .preprocessing import Preprocessing +from .projection import Projection +from .rescale_segmentor import RescaleSegmentor + +__all__ = ["Aggregator", + "Discriminator", + "PatchMaker", + "Preprocessing", + "Projection", + "RescaleSegmentor", +] diff --git a/src/anomalib/models/image/glass/components/aggregator.py b/src/anomalib/models/image/glass/components/aggregator.py new file mode 100644 index 0000000000..86623876e2 --- /dev/null +++ b/src/anomalib/models/image/glass/components/aggregator.py @@ -0,0 +1,25 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Aggregates and reshapes features to a target dimension.""" + +import torch +import torch.nn.functional as f + + +class Aggregator(torch.nn.Module): + """Aggregates and reshapes features 
to a target dimension. + + Input: Multi-dimensional feature tensors + Output: Reshaped and pooled features of specified target dimension + """ + + def __init__(self, target_dim: int) -> None: + super().__init__() + self.target_dim = target_dim + + def forward(self, features: torch.Tensor) -> torch.Tensor: + """Returns reshaped and average pooled features.""" + features = features.reshape(len(features), 1, -1) + features = f.adaptive_avg_pool1d(features, self.target_dim) + return features.reshape(len(features), -1) diff --git a/src/anomalib/models/image/glass/components/discriminator.py b/src/anomalib/models/image/glass/components/discriminator.py new file mode 100644 index 0000000000..9c7d984fbb --- /dev/null +++ b/src/anomalib/models/image/glass/components/discriminator.py @@ -0,0 +1,52 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Discriminator network for anomaly detection.""" + +import torch + +from .init_weight import init_weight + + +class Discriminator(torch.nn.Module): + """Discriminator network for anomaly detection. + + Args: + in_planes: Input feature dimension + n_layers: Number of layers + hidden: Hidden layer dimensions + """ + + def __init__(self, in_planes: int, n_layers: int = 2, hidden: int | None = None) -> None: + super().__init__() + + hidden_ = in_planes if hidden is None else hidden + self.body = torch.nn.Sequential() + for i in range(n_layers - 1): + in_ = in_planes if i == 0 else hidden_ + hidden_ = int(hidden_ // 1.5) if hidden is None else hidden + self.body.add_module( + f"block{i + 1}", + torch.nn.Sequential( + torch.nn.Linear(in_, hidden_), + torch.nn.BatchNorm1d(hidden_), + torch.nn.LeakyReLU(0.2), + ), + ) + self.tail = torch.nn.Sequential( + torch.nn.Linear(hidden_, 1, bias=False), + torch.nn.Sigmoid(), + ) + self.apply(init_weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Performs a forward pass through the discriminator network. + + Args: + x (torch.Tensor): Input tensor of shape (B, in_planes), where B is the batch size. + + Returns: + torch.Tensor: Output tensor of shape (B, 1) containing probability scores. + """ + x = self.body(x) + return self.tail(x) diff --git a/src/anomalib/models/image/glass/components/init_weight.py b/src/anomalib/models/image/glass/components/init_weight.py new file mode 100644 index 0000000000..1e84d12fc8 --- /dev/null +++ b/src/anomalib/models/image/glass/components/init_weight.py @@ -0,0 +1,22 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Initializes network weights using Xavier normal initialization.""" + +import torch +from torch import nn + + +def init_weight(m: nn.Module) -> None: + """Initializes network weights using Xavier normal initialization. + + Applies Xavier initialization for linear layers and normal initialization + for convolutional and batch normalization layers. 
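
    Example:
        A small illustrative sketch (the layer sizes are assumed):

        >>> net = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8))
        >>> _ = net.apply(init_weight)  # re-initializes conv and batch-norm weights in place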
+ """ + if isinstance(m, torch.nn.Linear): + torch.nn.init.xavier_normal_(m.weight) + if isinstance(m, torch.nn.BatchNorm2d): + m.weight.data.normal_(1.0, 0.02) + m.bias.data.fill_(0) + elif isinstance(m, torch.nn.Conv2d): + m.weight.data.normal_(0.0, 0.02) diff --git a/src/anomalib/models/image/glass/components/patch_maker.py b/src/anomalib/models/image/glass/components/patch_maker.py new file mode 100644 index 0000000000..a8ff059d02 --- /dev/null +++ b/src/anomalib/models/image/glass/components/patch_maker.py @@ -0,0 +1,90 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Handles patch-based processing of feature maps.""" + +import torch + + +class PatchMaker: + """Handles patch-based processing of feature maps. + + This class provides utilities for converting feature maps into patches, + reshaping patch scores back to original dimensions, and computing global + anomaly scores from patch-wise predictions. + + Attributes: + patchsize (int): Size of each patch (patchsize x patchsize). + stride (int or None): Stride used for patch extraction. Defaults to patchsize if None. + top_k (int): Number of top patch scores to consider. Used for score reduction. + """ + + def __init__(self, patchsize: int, top_k: int = 0, stride: int | None = None) -> None: + self.patchsize = patchsize + self.stride = stride if stride is not None else patchsize + self.top_k = top_k + + def patchify( + self, + features: torch.Tensor, + return_spatial_info: bool = False, + ) -> tuple[torch.Tensor, list[int]] | torch.Tensor: + """Converts a batch of feature maps into patches. + + Args: + features (torch.Tensor): Input feature maps of shape (B, C, H, W). + return_spatial_info (bool): If True, also returns spatial patch count. Default is False. + + Returns: + torch.Tensor: Output tensor of shape (B, N, C, patchsize, patchsize), where N is number of patches. + list[int], optional: Number of patches in (height, width) dimensions, only if return_spatial_info is True. + """ + padding = int((self.patchsize - 1) / 2) + unfolder = torch.nn.Unfold( + kernel_size=self.patchsize, + stride=self.stride, + padding=padding, + dilation=1, + ) + unfolded_features = unfolder(features) + number_of_total_patches = [] + for s in features.shape[-2:]: + n_patches = (s + 2 * padding - 1 * (self.patchsize - 1) - 1) / self.stride + 1 + number_of_total_patches.append(int(n_patches)) + unfolded_features = unfolded_features.reshape( + *features.shape[:2], + self.patchsize, + self.patchsize, + -1, + ) + unfolded_features = unfolded_features.permute(0, 4, 1, 2, 3) + + if return_spatial_info: + return unfolded_features, number_of_total_patches + return unfolded_features + + @staticmethod + def unpatch_scores(x: torch.Tensor, batchsize: int) -> torch.Tensor: + """Reshapes patch scores back into per-batch format. + + Args: + x (torch.Tensor): Input tensor of shape (B * N, ...). + batchsize (int): Original batch size. + + Returns: + torch.Tensor: Reshaped tensor of shape (B, N, ...). + """ + return x.reshape(batchsize, -1, *x.shape[1:]) + + @staticmethod + def compute_score(x: torch.Tensor) -> torch.Tensor: + """Computes final anomaly scores from patch-wise predictions. + + Args: + x (torch.Tensor): Patch scores of shape (B, N, 1). + + Returns: + torch.Tensor: Final anomaly score per image, shape (B,). 
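
        Example:
            An end-to-end sketch of the patch utilities (illustrative only; the
            feature-map shape is assumed):

            >>> pm = PatchMaker(patchsize=3, stride=1)
            >>> patches, hw = pm.patchify(torch.randn(2, 64, 28, 28), return_spatial_info=True)
            >>> patches.shape, hw                      # (2, 784, 64, 3, 3), [28, 28]
            >>> scores = PatchMaker.unpatch_scores(torch.rand(2 * 784, 1), batchsize=2)
            >>> PatchMaker.compute_score(scores)       # per-image max score, shape (2,)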
+ """ + x = x[:, :, 0] # remove last dimension if singleton + return torch.max(x, dim=1).values diff --git a/src/anomalib/models/image/glass/components/preprocessing.py b/src/anomalib/models/image/glass/components/preprocessing.py new file mode 100644 index 0000000000..8a1150d616 --- /dev/null +++ b/src/anomalib/models/image/glass/components/preprocessing.py @@ -0,0 +1,66 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Maps input features to a fixed dimension using adaptive average pooling.""" + +import torch +import torch.nn.functional as f + + +class MeanMapper(torch.nn.Module): + """Maps input features to a fixed dimension using adaptive average pooling. + + Input: Variable-sized feature tensors + Output: Fixed-size feature representations + """ + + def __init__(self, preprocessing_dim: int) -> None: + super().__init__() + self.preprocessing_dim = preprocessing_dim + + def forward(self, features: torch.Tensor) -> torch.Tensor: + """Applies adaptive average pooling to reshape features to a fixed size. + + Args: + features (torch.Tensor): Input tensor of shape (B, *) where * denotes + any number of remaining dimensions. It is flattened before pooling. + + Returns: + torch.Tensor: Output tensor of shape (B, D), where D is `preprocessing_dim`. + """ + features = features.reshape(len(features), 1, -1) + return f.adaptive_avg_pool1d(features, self.preprocessing_dim).squeeze(1) + + +class Preprocessing(torch.nn.Module): + """Handles initial feature preprocessing across multiple input dimensions. + + Input: List of features from different backbone layers + Output: Processed features with consistent dimensionality + """ + + def __init__(self, input_dims: list[int | tuple[int, int]], output_dim: int) -> None: + super().__init__() + self.input_dims = input_dims + self.output_dim = output_dim + + self.preprocessing_modules = torch.nn.ModuleList() + for _ in input_dims: + module = MeanMapper(output_dim) + self.preprocessing_modules.append(module) + + def forward(self, features: list[torch.Tensor]) -> torch.Tensor: + """Applies preprocessing modules to a list of input feature tensors. + + Args: + features (list of torch.Tensor): List of feature maps from different + layers of the backbone network. Each tensor can have a different shape. + + Returns: + torch.Tensor: A single tensor with shape (B, N, D), where B is the batch size, + N is the number of feature maps, and D is the output dimension (`output_dim`). + """ + features_ = [] + for module, feature in zip(self.preprocessing_modules, features, strict=False): + features_.append(module(feature)) + return torch.stack(features_, dim=1) diff --git a/src/anomalib/models/image/glass/components/projection.py b/src/anomalib/models/image/glass/components/projection.py new file mode 100644 index 0000000000..bf829ff493 --- /dev/null +++ b/src/anomalib/models/image/glass/components/projection.py @@ -0,0 +1,46 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Multi-layer projection network for feature adaptation.""" + +import torch + +from .init_weight import init_weight + + +class Projection(torch.nn.Module): + """Multi-layer projection network for feature adaptation. 
+ + Args: + in_planes: Input feature dimension + out_planes: Output feature dimension + n_layers: Number of projection layers + layer_type: Type of intermediate layers + """ + + def __init__(self, in_planes: int, out_planes: int | None = None, n_layers: int = 1, layer_type: int = 0) -> None: + super().__init__() + + if out_planes is None: + out_planes = in_planes + self.layers = torch.nn.Sequential() + in_ = None + out = None + for i in range(n_layers): + in_ = in_planes if i == 0 else out + out = out_planes + self.layers.add_module(f"{i}fc", torch.nn.Linear(in_, out)) + if i < n_layers - 1 and layer_type > 1: + self.layers.add_module(f"{i}relu", torch.nn.LeakyReLU(0.2)) + self.apply(init_weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Applies the projection network to the input features. + + Args: + x (torch.Tensor): Input tensor of shape (B, in_planes), where B is the batch size. + + Returns: + torch.Tensor: Transformed tensor of shape (B, out_planes). + """ + return self.layers(x) diff --git a/src/anomalib/models/image/glass/components/rescale_segmentor.py b/src/anomalib/models/image/glass/components/rescale_segmentor.py new file mode 100644 index 0000000000..bb83e9326e --- /dev/null +++ b/src/anomalib/models/image/glass/components/rescale_segmentor.py @@ -0,0 +1,62 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""A utility class for rescaling and smoothing patch-level anomaly scores to generate segmentation masks.""" + +import kornia.filters as kf +import numpy as np +import torch +import torch.nn.functional as f + + +class RescaleSegmentor: + """A utility class for rescaling and smoothing patch-level anomaly scores to generate segmentation masks. + + Attributes: + target_size (int): The spatial size (height and width) to which patch scores will be rescaled. + smoothing (int): The standard deviation used for Gaussian smoothing. + """ + + def __init__(self, target_size: int = 288) -> None: + """Initializes the RescaleSegmentor. + + Args: + target_size (int, optional): The desired output size (height/width) of segmentation maps. Defaults to 288. + """ + self.target_size = target_size + self.smoothing = 4 + + def convert_to_segmentation( + self, + patch_scores: np.ndarray | torch.Tensor, + device: torch.device, + ) -> list[torch.Tensor]: + """Converts patch-level scores to smoothed segmentation masks. + + Args: + patch_scores (np.ndarray | torch.Tensor): Patch-wise scores of shape [N, H, W]. + device (torch.device): Device on which to perform computation. + + Returns: + List[torch.Tensor]: A list of segmentation masks, each of shape [H, W], + rescaled to `target_size` and smoothed. 
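
        Example:
            An illustrative sketch (the 36x36 patch-score grid is assumed):

            >>> segmentor = RescaleSegmentor(target_size=288)
            >>> masks = segmentor.convert_to_segmentation(torch.rand(2, 36, 36), torch.device("cpu"))
            >>> len(masks), masks[0].shape   # 2, torch.Size([288, 288])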
+ """ + with torch.no_grad(): + if isinstance(patch_scores, np.ndarray): + patch_scores = torch.from_numpy(patch_scores) + + scores = patch_scores.to(device) + scores = scores.unsqueeze(1) # [N, 1, H, W] + scores = f.interpolate( + scores, size=self.target_size, mode="bilinear", align_corners=False, + ) + patch_scores = scores.squeeze(1) # [N, H, W] + + patch_stack = patch_scores.unsqueeze(1) # [N, 1, H, W] + smoothed_stack = kf.gaussian_blur2d( + patch_stack, + kernel_size=(5, 5), + sigma=(self.smoothing, self.smoothing), + ) + + return [s.squeeze(0) for s in smoothed_stack] # List of [H, W] tensors diff --git a/src/anomalib/models/image/glass/lightning_model.py b/src/anomalib/models/image/glass/lightning_model.py index 3f28f04b2c..8a233ec50f 100644 --- a/src/anomalib/models/image/glass/lightning_model.py +++ b/src/anomalib/models/image/glass/lightning_model.py @@ -182,6 +182,8 @@ def __init__( else: self.backbone_opt = None + self.automatic_optimization = False + @classmethod def configure_pre_processor( cls, @@ -264,7 +266,7 @@ def training_step(self, batch: Batch, batch_idx: int) -> STEP_OUTPUT: self.backbone_opt.zero_grad() true_loss, gaus_loss, bce_loss, focal_loss, loss = self.model(batch.image) - loss.backward() + self.manual_backward(loss) if self.projection_opt is not None: self.projection_opt.step() diff --git a/src/anomalib/models/image/glass/loss.py b/src/anomalib/models/image/glass/loss.py index 73365f6752..89f2957446 100644 --- a/src/anomalib/models/image/glass/loss.py +++ b/src/anomalib/models/image/glass/loss.py @@ -75,7 +75,7 @@ class FocalLoss(nn.Module): def __init__( self, apply_nonlinearity: nn.Module | None = None, - alpha: float | torch.Tensor = None, + alpha: float | list | np.ndarray | None = None, gamma: float = 2, balance_index: int = 0, smooth: float = 1e-5, diff --git a/src/anomalib/models/image/glass/torch_model.py b/src/anomalib/models/image/glass/torch_model.py index faae16f06b..eb4f3ae32d 100644 --- a/src/anomalib/models/image/glass/torch_model.py +++ b/src/anomalib/models/image/glass/torch_model.py @@ -20,8 +20,6 @@ import math -import kornia.filters as kf -import numpy as np import torch import torch.nn.functional as f from torch import nn @@ -32,24 +30,10 @@ from anomalib.models.components import TimmFeatureExtractor from anomalib.models.components.feature_extractors import dryrun_find_featuremap_dims +from .components import Aggregator, Discriminator, PatchMaker, Preprocessing, Projection, RescaleSegmentor from .loss import FocalLoss -def init_weight(m: nn.Module) -> None: - """Initializes network weights using Xavier normal initialization. - - Applies Xavier initialization for linear layers and normal initialization - for convolutional and batch normalization layers. - """ - if isinstance(m, torch.nn.Linear): - torch.nn.init.xavier_normal_(m.weight) - if isinstance(m, torch.nn.BatchNorm2d): - m.weight.data.normal_(1.0, 0.02) - m.bias.data.fill_(0) - elif isinstance(m, torch.nn.Conv2d): - m.weight.data.normal_(0.0, 0.02) - - def _deduce_dims( feature_extractor: TimmFeatureExtractor, input_size: tuple[int, int], @@ -71,302 +55,6 @@ def _deduce_dims( return [dimensions_mapping[layer]["num_features"] for layer in layers] -class Preprocessing(torch.nn.Module): - """Handles initial feature preprocessing across multiple input dimensions. 
- - Input: List of features from different backbone layers - Output: Processed features with consistent dimensionality - """ - - def __init__(self, input_dims: list[int | tuple[int, int]], output_dim: int) -> None: - super().__init__() - self.input_dims = input_dims - self.output_dim = output_dim - - self.preprocessing_modules = torch.nn.ModuleList() - for _ in input_dims: - module = MeanMapper(output_dim) - self.preprocessing_modules.append(module) - - def forward(self, features: list[torch.Tensor]) -> torch.Tensor: - """Applies preprocessing modules to a list of input feature tensors. - - Args: - features (list of torch.Tensor): List of feature maps from different - layers of the backbone network. Each tensor can have a different shape. - - Returns: - torch.Tensor: A single tensor with shape (B, N, D), where B is the batch size, - N is the number of feature maps, and D is the output dimension (`output_dim`). - """ - features_ = [] - for module, feature in zip(self.preprocessing_modules, features, strict=False): - features_.append(module(feature)) - return torch.stack(features_, dim=1) - - -class MeanMapper(torch.nn.Module): - """Maps input features to a fixed dimension using adaptive average pooling. - - Input: Variable-sized feature tensors - Output: Fixed-size feature representations - """ - - def __init__(self, preprocessing_dim: int) -> None: - super().__init__() - self.preprocessing_dim = preprocessing_dim - - def forward(self, features: torch.Tensor) -> torch.Tensor: - """Applies adaptive average pooling to reshape features to a fixed size. - - Args: - features (torch.Tensor): Input tensor of shape (B, *) where * denotes - any number of remaining dimensions. It is flattened before pooling. - - Returns: - torch.Tensor: Output tensor of shape (B, D), where D is `preprocessing_dim`. - """ - features = features.reshape(len(features), 1, -1) - return f.adaptive_avg_pool1d(features, self.preprocessing_dim).squeeze(1) - - -class Aggregator(torch.nn.Module): - """Aggregates and reshapes features to a target dimension. - - Input: Multi-dimensional feature tensors - Output: Reshaped and pooled features of specified target dimension - """ - - def __init__(self, target_dim: int) -> None: - super().__init__() - self.target_dim = target_dim - - def forward(self, features: torch.Tensor) -> torch.Tensor: - """Returns reshaped and average pooled features.""" - features = features.reshape(len(features), 1, -1) - features = f.adaptive_avg_pool1d(features, self.target_dim) - return features.reshape(len(features), -1) - - -class Projection(torch.nn.Module): - """Multi-layer projection network for feature adaptation. - - Args: - in_planes: Input feature dimension - out_planes: Output feature dimension - n_layers: Number of projection layers - layer_type: Type of intermediate layers - """ - - def __init__(self, in_planes: int, out_planes: int | None = None, n_layers: int = 1, layer_type: int = 0) -> None: - super().__init__() - - if out_planes is None: - out_planes = in_planes - self.layers = torch.nn.Sequential() - in_ = None - out = None - for i in range(n_layers): - in_ = in_planes if i == 0 else out - out = out_planes - self.layers.add_module(f"{i}fc", torch.nn.Linear(in_, out)) - if i < n_layers - 1 and layer_type > 1: - self.layers.add_module(f"{i}relu", torch.nn.LeakyReLU(0.2)) - self.apply(init_weight) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Applies the projection network to the input features. 
- - Args: - x (torch.Tensor): Input tensor of shape (B, in_planes), where B is the batch size. - - Returns: - torch.Tensor: Transformed tensor of shape (B, out_planes). - """ - return self.layers(x) - - -class Discriminator(torch.nn.Module): - """Discriminator network for anomaly detection. - - Args: - in_planes: Input feature dimension - n_layers: Number of layers - hidden: Hidden layer dimensions - """ - - def __init__(self, in_planes: int, n_layers: int = 2, hidden: int | None = None) -> None: - super().__init__() - - hidden_ = in_planes if hidden is None else hidden - self.body = torch.nn.Sequential() - for i in range(n_layers - 1): - in_ = in_planes if i == 0 else hidden_ - hidden_ = int(hidden_ // 1.5) if hidden is None else hidden - self.body.add_module( - f"block{i + 1}", - torch.nn.Sequential( - torch.nn.Linear(in_, hidden_), - torch.nn.BatchNorm1d(hidden_), - torch.nn.LeakyReLU(0.2), - ), - ) - self.tail = torch.nn.Sequential( - torch.nn.Linear(hidden_, 1, bias=False), - torch.nn.Sigmoid(), - ) - self.apply(init_weight) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Performs a forward pass through the discriminator network. - - Args: - x (torch.Tensor): Input tensor of shape (B, in_planes), where B is the batch size. - - Returns: - torch.Tensor: Output tensor of shape (B, 1) containing probability scores. - """ - x = self.body(x) - return self.tail(x) - - -class PatchMaker: - """Handles patch-based processing of feature maps. - - This class provides utilities for converting feature maps into patches, - reshaping patch scores back to original dimensions, and computing global - anomaly scores from patch-wise predictions. - - Attributes: - patchsize (int): Size of each patch (patchsize x patchsize). - stride (int or None): Stride used for patch extraction. Defaults to patchsize if None. - top_k (int): Number of top patch scores to consider. Used for score reduction. - """ - - def __init__(self, patchsize: int, top_k: int = 0, stride: int | None = None) -> None: - self.patchsize = patchsize - self.stride = stride if stride is not None else patchsize - self.top_k = top_k - - def patchify( - self, - features: torch.Tensor, - return_spatial_info: bool = False, - ) -> tuple[torch.Tensor, list[int]] | torch.Tensor: - """Converts a batch of feature maps into patches. - - Args: - features (torch.Tensor): Input feature maps of shape (B, C, H, W). - return_spatial_info (bool): If True, also returns spatial patch count. Default is False. - - Returns: - torch.Tensor: Output tensor of shape (B, N, C, patchsize, patchsize), where N is number of patches. - list[int], optional: Number of patches in (height, width) dimensions, only if return_spatial_info is True. - """ - padding = int((self.patchsize - 1) / 2) - unfolder = torch.nn.Unfold( - kernel_size=self.patchsize, - stride=self.stride, - padding=padding, - dilation=1, - ) - unfolded_features = unfolder(features) - number_of_total_patches = [] - for s in features.shape[-2:]: - n_patches = (s + 2 * padding - 1 * (self.patchsize - 1) - 1) / self.stride + 1 - number_of_total_patches.append(int(n_patches)) - unfolded_features = unfolded_features.reshape( - *features.shape[:2], - self.patchsize, - self.patchsize, - -1, - ) - unfolded_features = unfolded_features.permute(0, 4, 1, 2, 3) - - if return_spatial_info: - return unfolded_features, number_of_total_patches - return unfolded_features - - @staticmethod - def unpatch_scores(x: torch.Tensor, batchsize: int) -> torch.Tensor: - """Reshapes patch scores back into per-batch format. 
- - Args: - x (torch.Tensor): Input tensor of shape (B * N, ...). - batchsize (int): Original batch size. - - Returns: - torch.Tensor: Reshaped tensor of shape (B, N, ...). - """ - return x.reshape(batchsize, -1, *x.shape[1:]) - - @staticmethod - def compute_score(x: torch.Tensor) -> torch.Tensor: - """Computes final anomaly scores from patch-wise predictions. - - Args: - x (torch.Tensor): Patch scores of shape (B, N, 1). - - Returns: - torch.Tensor: Final anomaly score per image, shape (B,). - """ - x = x[:, :, 0] # remove last dimension if singleton - return torch.max(x, dim=1).values - - -class RescaleSegmentor: - """A utility class for rescaling and smoothing patch-level anomaly scores to generate segmentation masks. - - Attributes: - target_size (int): The spatial size (height and width) to which patch scores will be rescaled. - smoothing (int): The standard deviation used for Gaussian smoothing. - """ - - def __init__(self, target_size: int = 288) -> None: - """Initializes the RescaleSegmentor. - - Args: - target_size (int, optional): The desired output size (height/width) of segmentation maps. Defaults to 288. - """ - self.target_size = target_size - self.smoothing = 4 - - def convert_to_segmentation( - self, - patch_scores: np.ndarray | torch.Tensor, - device: torch.device, - ) -> list[torch.Tensor]: - """Converts patch-level scores to smoothed segmentation masks. - - Args: - patch_scores (np.ndarray | torch.Tensor): Patch-wise scores of shape [N, H, W]. - device (torch.device): Device on which to perform computation. - - Returns: - List[torch.Tensor]: A list of segmentation masks, each of shape [H, W], - rescaled to `target_size` and smoothed. - """ - with torch.no_grad(): - if isinstance(patch_scores, np.ndarray): - patch_scores = torch.from_numpy(patch_scores) - - scores = patch_scores.to(device) - scores = scores.unsqueeze(1) # [N, 1, H, W] - scores = f.interpolate( - scores, size=self.target_size, mode="bilinear", align_corners=False, - ) - patch_scores = scores.squeeze(1) # [N, H, W] - - patch_stack = patch_scores.unsqueeze(1) # [N, 1, H, W] - smoothed_stack = kf.gaussian_blur2d( - patch_stack, - kernel_size=(5, 5), - sigma=(self.smoothing, self.smoothing), - ) - - return [s.squeeze(0) for s in smoothed_stack] # List of [H, W] tensors - - class GlassModel(nn.Module): """PyTorch Implementation of the GLASS Model.""" From 44dcd6024016f77f3b55753643e0063a97c59347 Mon Sep 17 00:00:00 2001 From: Devansh Agarwal Date: Tue, 12 Aug 2025 18:08:34 +0530 Subject: [PATCH 18/23] Added automated download for dtd dataset in Glass Model Signed-off-by: Devansh Agarwal --- .../image/glass/components/patch_maker.py | 4 +--- .../glass/components/rescale_segmentor.py | 2 +- .../models/image/glass/lightning_model.py | 21 +++++++++++++++---- .../models/image/glass/torch_model.py | 2 +- 4 files changed, 20 insertions(+), 9 deletions(-) diff --git a/src/anomalib/models/image/glass/components/patch_maker.py b/src/anomalib/models/image/glass/components/patch_maker.py index a8ff059d02..f953d324e8 100644 --- a/src/anomalib/models/image/glass/components/patch_maker.py +++ b/src/anomalib/models/image/glass/components/patch_maker.py @@ -16,13 +16,11 @@ class PatchMaker: Attributes: patchsize (int): Size of each patch (patchsize x patchsize). stride (int or None): Stride used for patch extraction. Defaults to patchsize if None. - top_k (int): Number of top patch scores to consider. Used for score reduction. 
""" - def __init__(self, patchsize: int, top_k: int = 0, stride: int | None = None) -> None: + def __init__(self, patchsize: int, stride: int | None = None) -> None: self.patchsize = patchsize self.stride = stride if stride is not None else patchsize - self.top_k = top_k def patchify( self, diff --git a/src/anomalib/models/image/glass/components/rescale_segmentor.py b/src/anomalib/models/image/glass/components/rescale_segmentor.py index bb83e9326e..42dcd7ba86 100644 --- a/src/anomalib/models/image/glass/components/rescale_segmentor.py +++ b/src/anomalib/models/image/glass/components/rescale_segmentor.py @@ -55,7 +55,7 @@ def convert_to_segmentation( patch_stack = patch_scores.unsqueeze(1) # [N, 1, H, W] smoothed_stack = kf.gaussian_blur2d( patch_stack, - kernel_size=(5, 5), + kernel_size=(33, 33), sigma=(self.smoothing, self.smoothing), ) diff --git a/src/anomalib/models/image/glass/lightning_model.py b/src/anomalib/models/image/glass/lightning_model.py index 8a233ec50f..e6ef407b6a 100644 --- a/src/anomalib/models/image/glass/lightning_model.py +++ b/src/anomalib/models/image/glass/lightning_model.py @@ -18,6 +18,7 @@ ` """ +from pathlib import Path from typing import Any from lightning.pytorch.utilities.types import STEP_OUTPUT @@ -26,6 +27,7 @@ from anomalib import LearningType from anomalib.data import Batch +from anomalib.data.utils import DownloadInfo, download_and_extract from anomalib.metrics import Evaluator from anomalib.models.components import AnomalibModule from anomalib.post_processing import PostProcessor @@ -34,6 +36,12 @@ from .torch_model import GlassModel +DTD_DOWNLOAD_INFO = DownloadInfo( + name="dtd-r1.0.1.tar.gz", + url="https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz", + hashsum="e42855a52a4950a3b59612834602aa253914755c95b0cff9ead6d07395f8e205", +) + class Glass(AnomalibModule): """PyTorch Lightning Implementation of the GLASS Model. 
@@ -119,10 +127,6 @@ def __init__( discriminator_layers: int = 2, discriminator_hidden: int = 1024, discriminator_margin: float = 0.5, - pre_processor: PreProcessor | bool = True, - post_processor: PostProcessor | bool = True, - evaluator: Evaluator | bool = True, - visualizer: Visualizer | bool = True, mining: int = 1, noise: float = 0.015, radius: float = 0.75, @@ -130,6 +134,10 @@ def __init__( learning_rate: float = 0.0001, step: int = 20, svd: int = 0, + pre_processor: PreProcessor | bool = True, + post_processor: PostProcessor | bool = True, + evaluator: Evaluator | bool = True, + visualizer: Visualizer | bool = True, ) -> None: super().__init__( pre_processor=pre_processor, @@ -184,6 +192,11 @@ def __init__( self.automatic_optimization = False + if anomaly_source_path is not None: + dtd_dir = Path(anomaly_source_path) + if not dtd_dir.is_dir(): + download_and_extract(dtd_dir, DTD_DOWNLOAD_INFO) + @classmethod def configure_pre_processor( cls, diff --git a/src/anomalib/models/image/glass/torch_model.py b/src/anomalib/models/image/glass/torch_model.py index eb4f3ae32d..19c3bf1b75 100644 --- a/src/anomalib/models/image/glass/torch_model.py +++ b/src/anomalib/models/image/glass/torch_model.py @@ -60,7 +60,7 @@ class GlassModel(nn.Module): def __init__( self, - input_shape: tuple[int, int] = (256, 256), # (H, W) + input_shape: tuple[int, int] = (288, 288), # (H, W) anomaly_source_path: str | None = None, pretrain_embed_dim: int = 1024, target_embed_dim: int = 1024, From da57095fb3bfd0a7293cff7b65ba0395ef7cf206 Mon Sep 17 00:00:00 2001 From: Devansh Agarwal Date: Thu, 14 Aug 2025 15:31:15 +0530 Subject: [PATCH 19/23] Removed some input args Signed-off-by: Devansh Agarwal --- examples/configs/model/glass.yaml | 26 +++++++++++++++++++ .../models/image/glass/lightning_model.py | 19 -------------- .../models/image/glass/torch_model.py | 25 +++++------------- 3 files changed, 33 insertions(+), 37 deletions(-) create mode 100644 examples/configs/model/glass.yaml diff --git a/examples/configs/model/glass.yaml b/examples/configs/model/glass.yaml new file mode 100644 index 0000000000..b6df2f283b --- /dev/null +++ b/examples/configs/model/glass.yaml @@ -0,0 +1,26 @@ +model: + class_path: anomalib.models.Glass + init_args: + input_shape: [288, 288] + backbone: resnet18 + pretrain_embed_dim: 1024 + target_embed_dim: 1024 + patchsize: 3 + patchstride: 1 + pre_trained: true + pre_projection: 1 + discriminator_layers: 2 + discriminator_hidden: 1024 + discriminator_margin: 0.5 + learning_rate: 0.0001 + step: 20 + svd: 0 + +trainer: + max_epochs: 640 + callbacks: + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + patience: 5 + monitor: pixel_AUROC + mode: max \ No newline at end of file diff --git a/src/anomalib/models/image/glass/lightning_model.py b/src/anomalib/models/image/glass/lightning_model.py index e6ef407b6a..c553c518ef 100644 --- a/src/anomalib/models/image/glass/lightning_model.py +++ b/src/anomalib/models/image/glass/lightning_model.py @@ -93,17 +93,6 @@ class Glass(AnomalibModule): visualizer (Visualizer | bool, optional): Visualization module to generate heatmaps, segmentation overlays, and anomaly scores. Defaults to `True`. - mining (int, optional): Number of iterations or difficulty level for Online Hard Example Mining (OHEM) during - training. - Defaults to `1`. - noise (float, optional): Standard deviation of Gaussian noise used in feature-level anomaly synthesis. - Defaults to `0.015`. 
- radius (float, optional): Radius parameter used for truncated projection in the anomaly synthesis strategy. - Determines the range for valid synthetic anomalies in the hypersphere or manifold. - Defaults to `0.75`. - random_selection_prob (float, optional): Probability used in random selection logic, such as anomaly mask - generation or augmentation choice. - Defaults to `0.5`. learning_rate (float, optional): Learning rate for training the feature adaptor and discriminator networks. Defaults to `0.0001`. step (int, optional): Number of gradient ascent steps for anomaly synthesis. @@ -127,10 +116,6 @@ def __init__( discriminator_layers: int = 2, discriminator_hidden: int = 1024, discriminator_margin: float = 0.5, - mining: int = 1, - noise: float = 0.015, - radius: float = 0.75, - random_selection_prob: float = 0.5, learning_rate: float = 0.0001, step: int = 20, svd: int = 0, @@ -165,10 +150,6 @@ def __init__( discriminator_margin=discriminator_margin, step=step, svd=svd, - mining=mining, - noise=noise, - radius=radius, - random_selection_prob=random_selection_prob, ) self.learning_rate = learning_rate diff --git a/src/anomalib/models/image/glass/torch_model.py b/src/anomalib/models/image/glass/torch_model.py index 19c3bf1b75..3929e1fd74 100644 --- a/src/anomalib/models/image/glass/torch_model.py +++ b/src/anomalib/models/image/glass/torch_model.py @@ -73,10 +73,6 @@ def __init__( discriminator_layers: int = 2, discriminator_hidden: int = 1024, discriminator_margin: float = 0.5, - mining: int = 1, - noise: float = 0.015, - radius: float = 0.75, - random_selection_prob: float = 0.5, step: int = 20, svd: int = 0, ) -> None: @@ -126,10 +122,6 @@ def __init__( hidden=self.discriminator_hidden, ) - self.random_selection_prob = random_selection_prob - self.radius = radius - self.mining = mining - self.noise = noise self.distribution = 0 self.step = step self.svd = svd @@ -347,7 +339,7 @@ def forward( ) mask_s_gt = mask_s_resized.reshape(-1, 1) - noise = torch.normal(0, self.noise, true_feats.shape).to(device) + noise = torch.normal(0, 0.015, true_feats.shape).to(device) gaus_feats = true_feats + noise center = self.center.repeat(img.shape[0], 1, 1) @@ -358,7 +350,7 @@ def forward( ) c_t_points = torch.concat([center[mask_s_gt[:, 0] == 0], center], dim=0) dist_t = torch.norm(true_points - c_t_points, dim=1) - r_t = torch.tensor([torch.quantile(dist_t, q=self.radius)]).to(device) + r_t = torch.tensor([torch.quantile(dist_t, q=0.75)]).to(device) for step in range(self.step + 1): scores = self.discriminator(torch.cat([true_feats, gaus_feats])) @@ -410,14 +402,11 @@ def forward( fake_scores = self.discriminator(fake_feats) - if self.random_selection_prob > 0: - fake_dist = (fake_scores - mask_s_gt) ** 2 - d_hard = torch.quantile(fake_dist, q=self.random_selection_prob) - fake_scores_ = fake_scores[fake_dist >= d_hard].unsqueeze(1) - mask_ = mask_s_gt[fake_dist >= d_hard].unsqueeze(1) - else: - fake_scores_ = fake_scores - mask_ = mask_s_gt + fake_dist = (fake_scores - mask_s_gt) ** 2 + d_hard = torch.quantile(fake_dist, q=0.5) + fake_scores_ = fake_scores[fake_dist >= d_hard].unsqueeze(1) + mask_ = mask_s_gt[fake_dist >= d_hard].unsqueeze(1) + output = torch.cat([1 - fake_scores_, fake_scores_], dim=1) focal_loss = self.focal_loss(output, mask_) From ba5a6dd86a90e465f88a4559e5ca288934e7eb40 Mon Sep 17 00:00:00 2001 From: Devansh Agarwal Date: Thu, 14 Aug 2025 20:54:40 +0530 Subject: [PATCH 20/23] Change in default parameters Signed-off-by: Devansh Agarwal --- 
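A minimal sketch of the quantile-based hard-example selection that the preceding commit hardcodes with `q=0.5`: only synthetic-feature scores whose squared error against the patch labels is at or above the median are kept for the focal loss. Tensor shapes of `(N, 1)` are assumed, mirroring the surrounding code; the helper name `select_hard_examples` is illustrative.

import torch


def select_hard_examples(
    fake_scores: torch.Tensor,  # (N, 1) discriminator outputs for synthetic features
    mask_s_gt: torch.Tensor,    # (N, 1) ground-truth patch labels (0 = normal, 1 = anomalous)
    q: float = 0.5,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Keep only the hardest synthetic examples, measured by squared score error."""
    fake_dist = (fake_scores - mask_s_gt) ** 2  # per-patch squared error
    d_hard = torch.quantile(fake_dist, q=q)     # hardness threshold at quantile q
    keep = fake_dist >= d_hard                  # boolean mask of hard examples
    return fake_scores[keep].unsqueeze(1), mask_s_gt[keep].unsqueeze(1)

Restricting the focal loss to the hardest half of the synthetic samples keeps the discriminator focused on the anomalies it still confuses with normal features, which is the behaviour the removed `random_selection_prob` argument previously parameterised.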
.../models/image/glass/lightning_model.py | 31 ++++++++++--------- .../models/image/glass/torch_model.py | 8 ++--- 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/src/anomalib/models/image/glass/lightning_model.py b/src/anomalib/models/image/glass/lightning_model.py index c553c518ef..5f291a0359 100644 --- a/src/anomalib/models/image/glass/lightning_model.py +++ b/src/anomalib/models/image/glass/lightning_model.py @@ -56,14 +56,15 @@ class Glass(AnomalibModule): Args: input_shape (tuple[int, int]): Input image dimensions as a tuple of (height, width). Required for shaping the input pipeline. + Defaults to `(288, 288)`. anomaly_source_path (str): Path to the dataset or source directory containing normal images and anomaly textures backbone (str, optional): Name of the CNN backbone used for feature extraction. - Defaults to `"resnet18"`. + Defaults to `"wideresnet50"`. pretrain_embed_dim (int, optional): Dimensionality of features extracted by the pre-trained backbone before adaptation. - Defaults to `1024`. + Defaults to `1536`. target_embed_dim (int, optional): Dimensionality of the target adapted features after projection. - Defaults to `1024`. + Defaults to `1536`. patchsize (int, optional): Size of the local patch used in feature aggregation (e.g., for neighborhood pooling). Defaults to `3`. patchstride (int, optional): Stride used when extracting patches for local feature aggregation. @@ -82,6 +83,12 @@ class Glass(AnomalibModule): discriminator_margin (float, optional): Margin used for contrastive or binary classification loss in discriminator training. Defaults to `0.5`. + learning_rate (float, optional): Learning rate for training the feature adaptor and discriminator networks. + Defaults to `0.0001`. + step (int, optional): Number of gradient ascent steps for anomaly synthesis. + Defaults to `20`. + svd (int, optional): Flag to enable SVD-based feature projection. + Defaults to `0`. pre_processor (PreProcessor | bool, optional): reprocessing module or flag to enable default preprocessing. Set to `True` to apply default normalization and resizing. Defaults to `True`. @@ -93,21 +100,15 @@ class Glass(AnomalibModule): visualizer (Visualizer | bool, optional): Visualization module to generate heatmaps, segmentation overlays, and anomaly scores. Defaults to `True`. - learning_rate (float, optional): Learning rate for training the feature adaptor and discriminator networks. - Defaults to `0.0001`. - step (int, optional): Number of gradient ascent steps for anomaly synthesis. - Defaults to `20`. - svd (int, optional): Flag to enable SVD-based feature projection. - Defaults to `0`. 
""" def __init__( self, - input_shape: tuple[int, int] = (256, 256), + input_shape: tuple[int, int] = (288, 288), anomaly_source_path: str | None = None, - backbone: str = "resnet18", - pretrain_embed_dim: int = 1024, - target_embed_dim: int = 1024, + backbone: str = "wideresnet50", + pretrain_embed_dim: int = 1536, + target_embed_dim: int = 1536, patchsize: int = 3, patchstride: int = 1, pre_trained: bool = True, @@ -132,7 +133,7 @@ def __init__( ) if layers is None: - layers = ["layer1", "layer2", "layer3"] + layers = ["layer2", "layer3"] self.model = GlassModel( input_shape=input_shape, @@ -155,7 +156,7 @@ def __init__( self.learning_rate = learning_rate if pre_projection > 0: - self.projection_opt = optim.AdamW( + self.projection_opt = optim.Adam( self.model.projection.parameters(), self.learning_rate, weight_decay=1e-5, diff --git a/src/anomalib/models/image/glass/torch_model.py b/src/anomalib/models/image/glass/torch_model.py index 3929e1fd74..79cc410186 100644 --- a/src/anomalib/models/image/glass/torch_model.py +++ b/src/anomalib/models/image/glass/torch_model.py @@ -62,9 +62,9 @@ def __init__( self, input_shape: tuple[int, int] = (288, 288), # (H, W) anomaly_source_path: str | None = None, - pretrain_embed_dim: int = 1024, - target_embed_dim: int = 1024, - backbone: str = "resnet18", + pretrain_embed_dim: int = 1536, + target_embed_dim: int = 1536, + backbone: str = "wideresnet50", patchsize: int = 3, patchstride: int = 1, pre_trained: bool = True, @@ -79,7 +79,7 @@ def __init__( super().__init__() if layers is None: - layers = ["layer1", "layer2", "layer3"] + layers = ["layer2", "layer3"] self.backbone = backbone self.layers = layers From 714a3c34c4d228f71edd7686495fbea3ad433af9 Mon Sep 17 00:00:00 2001 From: Devansh Agarwal Date: Thu, 14 Aug 2025 21:00:00 +0530 Subject: [PATCH 21/23] Fixed default backbone name Signed-off-by: Devansh Agarwal --- src/anomalib/models/image/glass/lightning_model.py | 4 ++-- src/anomalib/models/image/glass/torch_model.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/anomalib/models/image/glass/lightning_model.py b/src/anomalib/models/image/glass/lightning_model.py index 5f291a0359..c19b6e25b2 100644 --- a/src/anomalib/models/image/glass/lightning_model.py +++ b/src/anomalib/models/image/glass/lightning_model.py @@ -59,7 +59,7 @@ class Glass(AnomalibModule): Defaults to `(288, 288)`. anomaly_source_path (str): Path to the dataset or source directory containing normal images and anomaly textures backbone (str, optional): Name of the CNN backbone used for feature extraction. - Defaults to `"wideresnet50"`. + Defaults to `"wide_resnet50_2"`. pretrain_embed_dim (int, optional): Dimensionality of features extracted by the pre-trained backbone before adaptation. Defaults to `1536`. 
@@ -106,7 +106,7 @@ def __init__( self, input_shape: tuple[int, int] = (288, 288), anomaly_source_path: str | None = None, - backbone: str = "wideresnet50", + backbone: str = "wide_resnet50_2", pretrain_embed_dim: int = 1536, target_embed_dim: int = 1536, patchsize: int = 3, diff --git a/src/anomalib/models/image/glass/torch_model.py b/src/anomalib/models/image/glass/torch_model.py index 79cc410186..937bfbfe1e 100644 --- a/src/anomalib/models/image/glass/torch_model.py +++ b/src/anomalib/models/image/glass/torch_model.py @@ -64,7 +64,7 @@ def __init__( anomaly_source_path: str | None = None, pretrain_embed_dim: int = 1536, target_embed_dim: int = 1536, - backbone: str = "wideresnet50", + backbone: str = "wide_resnet50_2", patchsize: int = 3, patchstride: int = 1, pre_trained: bool = True, From 1a3519c9360171005196115d1b98020abae9ea61 Mon Sep 17 00:00:00 2001 From: Devansh Agarwal Date: Thu, 14 Aug 2025 22:02:03 +0530 Subject: [PATCH 22/23] Changed configure pre_processor method Signed-off-by: Devansh Agarwal --- src/anomalib/models/image/glass/lightning_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anomalib/models/image/glass/lightning_model.py b/src/anomalib/models/image/glass/lightning_model.py index c19b6e25b2..f18edca97c 100644 --- a/src/anomalib/models/image/glass/lightning_model.py +++ b/src/anomalib/models/image/glass/lightning_model.py @@ -209,7 +209,7 @@ def configure_pre_processor( ... ) >>> transformed_image = pre_processor(image) """ - image_size = image_size or (256, 256) + image_size = image_size or (288, 288) if center_crop_size is not None: if center_crop_size[0] > image_size[0] or center_crop_size[1] > image_size[1]: From 5466d46bf0b0ea1adec415ebceb43db664044f1e Mon Sep 17 00:00:00 2001 From: Devansh Agarwal Date: Sat, 13 Sep 2025 17:21:35 +0530 Subject: [PATCH 23/23] Made some changes to the workflow of GLASS Model Signed-off-by: Devansh Agarwal --- .../glass/components/rescale_segmentor.py | 7 +++++-- .../models/image/glass/lightning_model.py | 19 +++++++++++-------- src/anomalib/models/image/glass/loss.py | 9 ++++----- .../models/image/glass/torch_model.py | 11 +++++++---- 4 files changed, 27 insertions(+), 19 deletions(-) diff --git a/src/anomalib/models/image/glass/components/rescale_segmentor.py b/src/anomalib/models/image/glass/components/rescale_segmentor.py index 42dcd7ba86..ce1c11cc08 100644 --- a/src/anomalib/models/image/glass/components/rescale_segmentor.py +++ b/src/anomalib/models/image/glass/components/rescale_segmentor.py @@ -17,7 +17,7 @@ class RescaleSegmentor: smoothing (int): The standard deviation used for Gaussian smoothing. """ - def __init__(self, target_size: int = 288) -> None: + def __init__(self, target_size: tuple[int, int] = (288, 288)) -> None: """Initializes the RescaleSegmentor. 
Args: @@ -48,7 +48,10 @@ def convert_to_segmentation( scores = patch_scores.to(device) scores = scores.unsqueeze(1) # [N, 1, H, W] scores = f.interpolate( - scores, size=self.target_size, mode="bilinear", align_corners=False, + scores, + size=self.target_size, + mode="bilinear", + align_corners=False, ) patch_scores = scores.squeeze(1) # [N, H, W] diff --git a/src/anomalib/models/image/glass/lightning_model.py b/src/anomalib/models/image/glass/lightning_model.py index f18edca97c..c85233223a 100644 --- a/src/anomalib/models/image/glass/lightning_model.py +++ b/src/anomalib/models/image/glass/lightning_model.py @@ -135,6 +135,11 @@ def __init__( if layers is None: layers = ["layer2", "layer3"] + if anomaly_source_path is not None: + dtd_dir = Path(anomaly_source_path) + if not dtd_dir.is_dir(): + download_and_extract(dtd_dir, DTD_DOWNLOAD_INFO) + self.model = GlassModel( input_shape=input_shape, anomaly_source_path=anomaly_source_path, @@ -154,6 +159,7 @@ def __init__( ) self.learning_rate = learning_rate + self.pre_trained = pre_trained if pre_projection > 0: self.projection_opt = optim.Adam( @@ -164,7 +170,7 @@ def __init__( else: self.projection_opt = None - if not pre_trained: + if not self.pre_trained: self.backbone_opt = optim.AdamW( self.model.forward_modules["feature_aggregator"].backbone.parameters(), self.learning_rate, @@ -174,11 +180,6 @@ def __init__( self.automatic_optimization = False - if anomaly_source_path is not None: - dtd_dir = Path(anomaly_source_path) - if not dtd_dir.is_dir(): - download_and_extract(dtd_dir, DTD_DOWNLOAD_INFO) - @classmethod def configure_pre_processor( cls, @@ -247,9 +248,11 @@ def training_step(self, batch: Batch, batch_idx: int) -> STEP_OUTPUT: STEP_OUTPUT: Dictionary containing loss values and metrics """ del batch_idx + discriminator_opt = self.optimizers() - self.model.forward_modules.eval() + if not self.pre_trained: + self.model.forward_modules["feature_aggregator"].train() if self.model.pre_projection > 0: self.model.projection.train() self.model.discriminator.train() @@ -295,7 +298,7 @@ def validation_step(self, batch: Batch, batch_idx: int) -> STEP_OUTPUT: predictions = self.model(batch.image) return batch.update(**predictions._asdict()) - def on_train_start(self) -> None: + def on_train_epoch_start(self) -> None: """Initialize model by computing mean feature representation across training dataset. 
This method is called at the start of training and computes a mean feature vector diff --git a/src/anomalib/models/image/glass/loss.py b/src/anomalib/models/image/glass/loss.py index 89f2957446..4fb8e53c91 100644 --- a/src/anomalib/models/image/glass/loss.py +++ b/src/anomalib/models/image/glass/loss.py @@ -75,7 +75,7 @@ class FocalLoss(nn.Module): def __init__( self, apply_nonlinearity: nn.Module | None = None, - alpha: float | list | np.ndarray | None = None, + alpha: float | torch.Tensor | np.ndarray | None = None, gamma: float = 2, balance_index: int = 0, smooth: float = 1e-5, @@ -129,13 +129,12 @@ def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor: target = target.view(-1, 1) alpha = self.alpha - if alpha is None: + if self.alpha is None: alpha = torch.ones(num_classes, 1) - elif isinstance(alpha, (list | np.ndarray)): - assert len(alpha) == num_classes + elif isinstance(self.alpha, (list | np.ndarray)): alpha = torch.FloatTensor(alpha).view(num_classes, 1) alpha = alpha / alpha.sum() - elif isinstance(alpha, float): + elif isinstance(self.alpha, float): alpha = torch.ones(num_classes, 1) alpha = alpha * (1 - self.alpha) alpha[self.balance_index] = self.alpha diff --git a/src/anomalib/models/image/glass/torch_model.py b/src/anomalib/models/image/glass/torch_model.py index 937bfbfe1e..32485ad2ac 100644 --- a/src/anomalib/models/image/glass/torch_model.py +++ b/src/anomalib/models/image/glass/torch_model.py @@ -128,7 +128,7 @@ def __init__( self.patch_maker = PatchMaker(patchsize, stride=patchstride) - self.anomaly_segmentor = RescaleSegmentor(target_size=input_shape[:]) + self.anomaly_segmentor = RescaleSegmentor(target_size=input_shape) def calculate_center(self, dataloader: dataloader, device: torch.device) -> None: """Calculates and updates the center embedding from a dataset. @@ -164,7 +164,8 @@ def calculate_center(self, dataloader: dataloader, device: torch.device) -> None else: self.center += torch.mean(outputs, dim=0) - def calculate_features(self, + def calculate_features( + self, img: torch.Tensor, aug: torch.Tensor, evaluation: bool = False, @@ -193,8 +194,10 @@ def calculate_features(self, true_feats = true_feats[0] if len(true_feats) == 2 else true_feats else: fake_feats = self.generate_embeddings(aug, evaluation=evaluation)[0] + assert isinstance(fake_feats, torch.Tensor) fake_feats.requires_grad = True true_feats = self.generate_embeddings(img, evaluation=evaluation)[0] + assert isinstance(true_feats, torch.Tensor) true_feats.requires_grad = True return true_feats, fake_feats @@ -227,7 +230,7 @@ def generate_embeddings( """ if not evaluation and not self.pre_trained: self.forward_modules["feature_aggregator"].train() - features = self.forward_modules["feature_aggregator"](images) + features = self.forward_modules["feature_aggregator"](images, eval=evaluation) else: self.forward_modules["feature_aggregator"].eval() with torch.no_grad(): @@ -315,7 +318,7 @@ def calculate_anomaly_scores(self, images: torch.Tensor) -> torch.Tensor: def forward( self, img: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] | InferenceBatch: """Forward pass to compute patch-wise feature embeddings for original and augmented images. Depending on whether a pre-projection module is used, this method optionally applies it to the
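A minimal sketch of the rescale-and-smooth step that `RescaleSegmentor.convert_to_segmentation` performs after these changes: patch scores are bilinearly upsampled to the (288, 288) input size and smoothed with the 33x33 Gaussian kernel used in this series. It assumes `kf` in the diffs refers to `kornia.filters`; the function name `patch_scores_to_masks` is illustrative.

import torch
import torch.nn.functional as F
from kornia.filters import gaussian_blur2d


def patch_scores_to_masks(
    patch_scores: torch.Tensor,                 # [N, H_p, W_p] patch-level anomaly scores
    target_size: tuple[int, int] = (288, 288),  # default input size used across these patches
    smoothing: float = 4.0,                     # sigma used by RescaleSegmentor
) -> list[torch.Tensor]:
    """Upsample patch scores to image resolution and apply Gaussian smoothing."""
    with torch.no_grad():
        scores = patch_scores.unsqueeze(1)                          # [N, 1, H_p, W_p]
        scores = F.interpolate(scores, size=target_size,
                               mode="bilinear", align_corners=False)  # [N, 1, H, W]
        smoothed = gaussian_blur2d(scores, kernel_size=(33, 33),
                                   sigma=(smoothing, smoothing))      # [N, 1, H, W]
    return [s.squeeze(0) for s in smoothed]                           # N masks of [H, W]

The wider 33x33 kernel, replacing the earlier 5x5 one, yields noticeably smoother segmentation maps at the 288x288 resolution these patches standardise on.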