diff --git a/CHANGELOG.md b/CHANGELOG.md index 56738df2c3..545f6a2f1b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). ### Added +- 🚀 Add MEBin post-processing method + ### Removed ### Changed diff --git a/docs/source/markdown/guides/reference/post_processing/index.md b/docs/source/markdown/guides/reference/post_processing/index.md index 02bdb4d638..493a293413 100644 --- a/docs/source/markdown/guides/reference/post_processing/index.md +++ b/docs/source/markdown/guides/reference/post_processing/index.md @@ -23,6 +23,16 @@ Post-processor for one-class anomaly detection. +++ [Learn more »](one-class-post-processor) ::: + +:::{grid-item-card} {octicon}`gear` MEBin Post-processor +:link: mebin-post-processor +:link-type: ref + +MEBin post-processor from AnomalyNCD. + ++++ +[Learn more »](mebin-post-processor) +::: :::: (base-post-processor)= @@ -44,3 +54,13 @@ Post-processor for one-class anomaly detection. :members: :show-inheritance: ``` + +(mebin-post-processor)= + +## MEBin Post-processor + +```{eval-rst} +.. automodule:: anomalib.post_processing.mebin_post_processor + :members: + :show-inheritance: +``` diff --git a/src/anomalib/metrics/__init__.py b/src/anomalib/metrics/__init__.py index f38d2ffde1..96017e6eeb 100644 --- a/src/anomalib/metrics/__init__.py +++ b/src/anomalib/metrics/__init__.py @@ -62,7 +62,7 @@ from .pimo import AUPIMO, PIMO from .precision_recall_curve import BinaryPrecisionRecallCurve from .pro import PRO -from .threshold import F1AdaptiveThreshold, ManualThreshold +from .threshold import F1AdaptiveThreshold, ManualThreshold, MEBin __all__ = [ "AUROC", @@ -83,4 +83,5 @@ "PRO", "PIMO", "AUPIMO", + "MEBin", ] diff --git a/src/anomalib/metrics/threshold/__init__.py b/src/anomalib/metrics/threshold/__init__.py index 1bf10854ec..0cebd87b82 100644 --- a/src/anomalib/metrics/threshold/__init__.py +++ b/src/anomalib/metrics/threshold/__init__.py @@ -24,5 +24,6 @@ from .base import BaseThreshold, Threshold from .f1_adaptive_threshold import F1AdaptiveThreshold from .manual_threshold import ManualThreshold +from .mebin import MEBin -__all__ = ["BaseThreshold", "Threshold", "F1AdaptiveThreshold", "ManualThreshold"] +__all__ = ["BaseThreshold", "Threshold", "F1AdaptiveThreshold", "ManualThreshold", "MEBin"] diff --git a/src/anomalib/metrics/threshold/mebin.py b/src/anomalib/metrics/threshold/mebin.py new file mode 100644 index 0000000000..fa13f94709 --- /dev/null +++ b/src/anomalib/metrics/threshold/mebin.py @@ -0,0 +1,290 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""MEBin (Main Element Binarization) adaptive thresholding for anomaly detection. + +This module provides the ``MEBin`` class which implements the Main Element +Binarization algorithm designed to address the non-prominence of anomalies +in anomaly maps. MEBin obtains anomaly-centered images by analyzing the +stability of connected components across multiple threshold levels. + +The algorithm is particularly effective for: +- Industrial anomaly detection scenarios +- Multi-class anomaly classification tasks +- Cases where anomalies are non-prominent in anomaly maps +- Avoiding the impact of incorrect detections + +The threshold is computed by: +1. Adaptively determining threshold search range from anomaly map statistics +2. Sampling anomaly maps at configurable rates across threshold range +3. Counting connected components at each threshold level +4. Finding stable intervals where component count remains constant +5. Selecting threshold from the longest stable interval + +MEBin was introduced in "AnomalyNCD: Towards Novel Anomaly Class Discovery +in Industrial Scenarios" (https://arxiv.org/abs/2410.14379). + +Example: + >>> import numpy as np + >>> from anomalib.metrics.threshold import MEBin + >>> + >>> # Create sample anomaly maps with simulated anomalous regions + >>> anomaly_maps = [] + >>> for i in range(5): + ... amap = np.random.rand(128, 128) * 50 # Background noise + ... amap[40:80, 40:80] = np.random.rand(40, 40) * 200 + 55 # Anomalous region + ... anomaly_maps.append(amap) + >>> + >>> # Initialize MEBin with appropriate parameters + >>> mebin = MEBin(anomaly_maps, sample_rate=8, min_interval_len=3) + >>> + >>> # Compute binary masks and thresholds + >>> binarized_maps, thresholds = mebin.binarize_anomaly_maps() + >>> print(f"Processed {len(binarized_maps)} maps, thresholds: {thresholds}") + Processed 5 maps, thresholds: [...] + +Note: + MEBin is designed for industrial scenarios where anomalies may be + non-prominent. The min_interval_len parameter should be tuned based + on the expected stability of connected component counts. +""" + +from __future__ import annotations + +import cv2 +import numpy as np +from tqdm import tqdm + + +class MEBin: + """MEBin (Main Element Binarization) adaptive thresholding algorithm. + + This class implements the Main Element Binarization algorithm designed + to address non-prominent anomalies in industrial anomaly detection scenarios. + MEBin determines optimal thresholds by analyzing the stability of connected + component counts across different threshold levels to obtain anomaly-centered + binary representations. + + The algorithm works by: + - Adaptively determining threshold search ranges from anomaly statistics + - Sampling anomaly maps at configurable rates across threshold range + - Counting connected components at each threshold level + - Identifying stable intervals where component count remains constant + - Selecting the optimal threshold from the longest stable interval + - Optionally applying morphological erosion to reduce noise + + Args: + anomaly_map_list (list[np.ndarray]): List of anomaly map arrays as numpy arrays. + sample_rate (int, optional): Sampling rate for threshold search. Higher + values reduce processing time but may affect accuracy. + Defaults to 4. + min_interval_len (int, optional): Minimum length of stable intervals. + Should be tuned based on the expected stability of anomaly score + distributions. + Defaults to 4. + erode (bool, optional): Whether to apply morphological erosion to + binarized results to reduce noise. + Defaults to True. + + Example: + >>> import numpy as np + >>> from anomalib.metrics.threshold import MEBin + >>> + >>> # Create sample anomaly maps with realistic structure + >>> anomaly_maps = [] + >>> for i in range(3): + ... # Background with low anomaly scores + ... amap = np.random.rand(64, 64) * 30 + ... # Add anomalous regions with higher scores + ... amap[20:40, 20:40] = np.random.rand(20, 20) * 150 + 100 + ... anomaly_maps.append(amap) + >>> + >>> # Initialize MEBin with custom parameters + >>> mebin = MEBin(anomaly_maps, sample_rate=4, min_interval_len=3, erode=True) + >>> + >>> # Binarize anomaly maps + >>> binary_masks, thresholds = mebin.binarize_anomaly_maps() + >>> print(f"Generated {len(binary_masks)} binary masks") + Generated 3 binary masks + """ + + def __init__( + self, + anomaly_map_list: list[np.ndarray], + sample_rate: int = 4, + min_interval_len: int = 4, + erode: bool = True, + ) -> None: + self.anomaly_map_list = anomaly_map_list + + self.sample_rate = sample_rate + self.min_interval_len = min_interval_len + self.erode = erode + + # Adaptively determine the threshold search range + self.max_th, self.min_th = self.get_search_range() + + def get_search_range(self) -> tuple[float, float]: + """Determine the threshold search range adaptively. + + This method analyzes all anomaly maps to determine the minimum and maximum + threshold values for the binarization process. The search range is based + on the actual anomaly score distributions in the input maps. + + Returns: + tuple[float, float]: A tuple containing: + - max_th: Maximum threshold for binarization. + - min_th: Minimum threshold for binarization. + """ + # Get the anomaly scores of all anomaly maps + anomaly_score_list = [np.max(x) for x in self.anomaly_map_list] + + # Select the maximum and minimum anomaly scores from images + max_score, min_score = max(anomaly_score_list), min(anomaly_score_list) + max_th, min_th = max_score, min_score + + return max_th, min_th + + def get_threshold( + self, + anomaly_num_sequence: list[int], + min_interval_len: int, + ) -> tuple[int, int]: + """Find the 'stable interval' in the anomaly region number sequence. + + Stable Interval: A continuous threshold range in which the number of connected components remains constant, + and the length of the threshold range is greater than or equal to the given length threshold + (min_interval_len). + + Args: + anomaly_num_sequence (list): Sequence of connected component counts + at each threshold level, ordered from high to low threshold. + min_interval_len (int): Minimum length requirement for stable intervals. + Longer intervals indicate more robust threshold selection. + + Returns: + tuple[int, int]: A tuple containing (threshold, est_anomaly_num) where + threshold is the final threshold for binarization and est_anomaly_num + is the estimated number of anomalies. + """ + interval_result = {} + current_index = 0 + while current_index < len(anomaly_num_sequence): + start = current_index + value = anomaly_num_sequence[start] + end = start + # Move the 'end' pointer forward until a different connected component number is encountered. + while ( + end < len(anomaly_num_sequence) - 1 and anomaly_num_sequence[end] == anomaly_num_sequence[end + 1] + ): + end += 1 + # If the length of the current stable interval is greater than or equal to the given + # threshold (min_interval_len), and the value is not zero, record this interval. + if end - start + 1 >= min_interval_len and value != 0: + if value not in interval_result: + interval_result[value] = [(start, end)] + else: + interval_result[value].append((start, end)) + current_index = end + 1 + + # If a 'stable interval' exists, calculate the final threshold based on the longest stable interval. + # If no stable interval is found, it indicates that no anomaly regions exist, and 255 is returned. + + if interval_result: + # Iterate through the stable intervals, calculating their lengths and corresponding + # number of connected component. + count_result = {} + for anomaly_num in interval_result: + count_result[anomaly_num] = max(x[1] - x[0] for x in interval_result[anomaly_num]) + est_anomaly_num = max(count_result, key=lambda k: count_result[k]) + est_anomaly_num_interval_result = interval_result[est_anomaly_num] + + # Find the longest stable interval. + longest_interval = sorted(est_anomaly_num_interval_result, key=lambda x: x[1] - x[0])[-1] + + # Use the endpoint threshold of the longest stable interval as the final threshold. + index = longest_interval[1] + threshold = 255 - index * self.sample_rate + threshold = int(threshold * (self.max_th - self.min_th) / 255 + self.min_th) + return threshold, est_anomaly_num + return 255, 0 + + def bin_and_erode(self, anomaly_map: np.ndarray, threshold: int) -> np.ndarray: + """Binarize anomaly map and optionally apply erosion. + + This method converts a continuous anomaly map to a binary mask using + the specified threshold, and optionally applies morphological erosion + to reduce noise and smooth the boundaries of anomaly regions. + + The binarization process: + 1. Pixels above threshold become 255 (anomalous) + 2. Pixels below threshold become 0 (normal) + 3. Optional erosion with 6x6 kernel to reduce noise + + Args: + anomaly_map (numpy.ndarray): Input anomaly map with continuous + anomaly scores to be binarized. + threshold (int): Threshold value for binarization. Pixels with + values above this threshold are considered anomalous. + + Returns: + numpy.ndarray: Binary mask where 255 indicates anomalous regions + and 0 indicates normal regions. The result is of type uint8. + + Note: + Erosion is applied with a 6x6 kernel and 1 iteration to balance + noise reduction with preservation of anomaly boundaries. + """ + bin_result = np.where(anomaly_map > threshold, 255, 0).astype(np.uint8) + + # Apply erosion operation to the binarized result + if self.erode: + kernel_size = 6 + iter_num = 1 + kernel = np.ones((kernel_size, kernel_size), np.uint8) + bin_result = cv2.erode(bin_result, kernel, iterations=iter_num) + return bin_result + + def binarize_anomaly_maps(self) -> tuple[list[np.ndarray], list[int]]: + """Perform binarization within the given threshold search range. + + Count the number of connected components in the binarized results. + Adaptively determine the threshold according to the count, + and perform binarization on the anomaly maps. + + Returns: + tuple[list[np.ndarray], list[int]]: A tuple containing (binarized_maps, thresholds) + where binarized_maps is a list of binarized images and thresholds is a list + of thresholds for each image. + """ + self.binarized_maps = [] + self.thresholds = [] + + for anomaly_map in tqdm(self.anomaly_map_list): + # Normalize the anomaly map within the given threshold search range. + if self.max_th == self.min_th: + # Rare case where all anomaly maps have identical max values + anomaly_map_norm = np.where(anomaly_map < self.min_th, 0, 255) + else: + anomaly_map_norm = np.where( + anomaly_map < self.min_th, + 0, + ((anomaly_map - self.min_th) / (self.max_th - self.min_th)) * 255, + ) + anomaly_num_sequence = [] + + # Search for the threshold from high to low within the given range using the specified sampling rate. + for score in range(255, 0, -self.sample_rate): + bin_result = self.bin_and_erode(anomaly_map_norm, score) + num_labels, *_rest = cv2.connectedComponentsWithStats(bin_result, connectivity=8) + anomaly_num = num_labels - 1 + anomaly_num_sequence.append(anomaly_num) + + # Adaptively determine the threshold based on the anomaly connected component count sequence. + threshold, _est_anomaly_num = self.get_threshold(anomaly_num_sequence, self.min_interval_len) + + # Binarize the anomaly image based on the determined threshold. + bin_result = self.bin_and_erode(anomaly_map, threshold) + self.binarized_maps.append(bin_result) + self.thresholds.append(threshold) + + return self.binarized_maps, self.thresholds diff --git a/src/anomalib/post_processing/__init__.py b/src/anomalib/post_processing/__init__.py index 5ca7bf598b..bc7a639eea 100644 --- a/src/anomalib/post_processing/__init__.py +++ b/src/anomalib/post_processing/__init__.py @@ -19,6 +19,7 @@ >>> predictions = post_processor(anomaly_maps=anomaly_maps) """ +from .mebin_post_processor import MEBinPostProcessor from .post_processor import PostProcessor -__all__ = ["PostProcessor"] +__all__ = ["PostProcessor", "MEBinPostProcessor"] diff --git a/src/anomalib/post_processing/mebin_post_processor.py b/src/anomalib/post_processing/mebin_post_processor.py new file mode 100644 index 0000000000..ee6db455ff --- /dev/null +++ b/src/anomalib/post_processing/mebin_post_processor.py @@ -0,0 +1,234 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Post-processing module for MEBin-based anomaly detection results. + +This module provides post-processing functionality for anomaly detection +outputs through the :class:`MEBinPostProcessor` class. + +MEBin was introduced in AnomalyNCD : https://arxiv.org/pdf/2410.14379 + +The MEBin post-processor handles: + - Converting anomaly maps to binary masks using MEBin algorithm + - Sampling anomaly maps at configurable rates for efficient processing + - Applying morphological operations (erosion) to refine binary masks + - Maintaining minimum interval lengths for consistent mask generation + - Formatting results for downstream use + +Example: + Example: + >>> from anomalib.post_processing import MEBinPostProcessor + >>> from anomalib.data import InferenceBatch + >>> import torch + >>> # Create sample anomaly maps + >>> anomaly_maps = torch.rand(4, 1, 256, 256) + >>> predictions = InferenceBatch(anomaly_map=anomaly_maps) + >>> post_processor = MEBinPostProcessor(sample_rate=4, min_interval_len=4) + >>> results = post_processor(predictions) +""" + +import numpy as np +import torch +from lightning import LightningModule, Trainer + +from anomalib.data import Batch, InferenceBatch +from anomalib.metrics import MEBin + +from .post_processor import PostProcessor + + +class MEBinPostProcessor(PostProcessor): + """Post-processor for MEBin-based anomaly detection. + + This class handles post-processing of anomaly detection results by: + - Converting continuous anomaly maps to binary masks using MEBin algorithm + - Sampling anomaly maps at configurable rates for efficient processing + - Applying morphological operations (erosion) to refine binary masks + - Maintaining minimum interval lengths for consistent mask generation + - Formatting results for downstream use + + Args: + sample_rate (int): Threshold sampling step size. + Defaults to 4 + min_interval_len (int): Minimum length of the stable interval. Can be adjusted based on the interval + between normal and abnormal score distributions in the anomaly score maps. + Decrease if there are many false negatives, increase if there are many false positives. + Defaults to 4 + erode (bool): Whether to perform erosion after binarization to eliminate noise. + Defaults to True + **kwargs: Additional keyword arguments passed to parent class. + + Example: + >>> from anomalib.post_processing import MEBinPostProcessor + >>> from anomalib.data import InferenceBatch + >>> import torch + >>> # Create sample predictions + >>> anomaly_maps = torch.rand(4, 1, 256, 256) + >>> predictions = InferenceBatch(anomaly_map=anomaly_maps) + >>> post_processor = MEBinPostProcessor(sample_rate=4, min_interval_len=4) + >>> results = post_processor(predictions) + """ + + def __init__( + self, + sample_rate: int = 4, + min_interval_len: int = 4, + erode: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.sample_rate = sample_rate + self.min_interval_len = min_interval_len + self.erode = erode + + def forward(self, predictions: InferenceBatch) -> InferenceBatch: + """Post-process model predictions using MEBin algorithm. + + This method converts continuous anomaly maps to binary masks using the MEBin + algorithm, which provides efficient and accurate binarization of anomaly + detection results. + + Args: + predictions (InferenceBatch): Batch containing model predictions with + anomaly maps to be processed. + + Returns: + InferenceBatch: Post-processed batch with binary masks generated from + anomaly maps using MEBin algorithm. + + Note: + The method automatically handles tensor-to-numpy conversion and back, + ensuring compatibility with the original tensor device and dtype. + """ + if predictions.anomaly_map is None: + msg = "Anomaly map is required for MEBin post-processing" + raise ValueError(msg) + + # Store the original tensor for device and dtype info + original_anomaly_map = predictions.anomaly_map + anomaly_maps = original_anomaly_map.detach().cpu().numpy() + if anomaly_maps.ndim == 4: + anomaly_maps = anomaly_maps[:, 0, :, :] # Remove channel dimension + + # Convert to proper format for MEBin (don't normalize individually) + # MEBin will handle normalization after determining the global min/max range + anomaly_maps_list = [amap.astype(np.float32) for amap in anomaly_maps] + + mebin = MEBin( + anomaly_map_list=anomaly_maps_list, + sample_rate=self.sample_rate, + min_interval_len=self.min_interval_len, + erode=self.erode, + ) + binarized_maps, _ = mebin.binarize_anomaly_maps() + + # Convert back to torch.Tensor and normalize to 0/1 + pred_masks = torch.stack([torch.from_numpy(bm).to(original_anomaly_map.device) for bm in binarized_maps]) + pred_masks = (pred_masks > 0).to(original_anomaly_map.dtype) + + # Create result with MEBin pred_mask + result = InferenceBatch( + pred_label=predictions.pred_label, + pred_score=predictions.pred_score, + pred_mask=pred_masks, + anomaly_map=predictions.anomaly_map, + ) + + # Apply parent class post-processing for normalization and thresholding + # This will compute pred_label from pred_score if needed + return super().forward(result) + + def on_test_batch_end( + self, + trainer: Trainer, + pl_module: LightningModule, + outputs: Batch, + *args, + **kwargs, + ) -> None: + """Apply MEBin post-processing to test batch predictions. + + Args: + trainer (Trainer): PyTorch Lightning trainer instance. + pl_module (LightningModule): PyTorch Lightning module instance. + outputs (Batch): Batch containing model predictions. + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + """ + del trainer, pl_module, args, kwargs # Unused arguments + self.post_process_batch(outputs) + + def on_predict_batch_end( + self, + trainer: Trainer, + pl_module: LightningModule, + outputs: Batch, + *args, + **kwargs, + ) -> None: + """Apply MEBin post-processing to prediction batch. + + Args: + trainer (Trainer): PyTorch Lightning trainer instance. + pl_module (LightningModule): PyTorch Lightning module instance. + outputs (Batch): Batch containing model predictions. + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + """ + del trainer, pl_module, args, kwargs # Unused arguments + self.post_process_batch(outputs) + + def post_process_batch(self, batch: Batch) -> None: + """Post-process a batch of predictions using MEBin algorithm. + + This method applies MEBin binarization to anomaly maps in the batch and + updates the pred_mask field with the binarized results. + + Args: + batch (Batch): Batch containing model predictions to be processed. + """ + if batch.anomaly_map is None: + return + + # Store the original tensor for device and dtype info + original_anomaly_map = batch.anomaly_map + anomaly_maps = original_anomaly_map.detach().cpu().numpy() + + # Handle different tensor shapes + if anomaly_maps.ndim == 4: + anomaly_maps = anomaly_maps[:, 0, :, :] # Remove channel dimension if present + elif anomaly_maps.ndim == 3: + # Already in correct format (batch, height, width) + pass + else: + msg = f"Unsupported anomaly map shape: {anomaly_maps.shape}" + raise ValueError(msg) + + # Convert to proper format for MEBin (don't normalize individually) + # MEBin will handle normalization after determining the global min/max range + anomaly_maps_list = [amap.astype(np.float32) for amap in anomaly_maps] + + # Apply MEBin binarization + mebin = MEBin( + anomaly_map_list=anomaly_maps_list, + sample_rate=self.sample_rate, + min_interval_len=self.min_interval_len, + erode=self.erode, + ) + binarized_maps, _ = mebin.binarize_anomaly_maps() + + # Convert back to torch.Tensor and normalize to 0/1 + pred_masks = torch.stack([torch.from_numpy(bm).to(original_anomaly_map.device) for bm in binarized_maps]) + pred_masks = (pred_masks > 0).to(original_anomaly_map.dtype) + + # Add channel dimension if original had one + if original_anomaly_map.ndim == 4: + pred_masks = pred_masks.unsqueeze(1) + + # Update the batch with binarized masks + batch.pred_mask = pred_masks + + # Apply parent class post-processing for normalization and thresholding + # This will compute pred_label from pred_score if needed + super().post_process_batch(batch) diff --git a/tests/unit/post_processing/test_mebin_post_processor.py b/tests/unit/post_processing/test_mebin_post_processor.py new file mode 100644 index 0000000000..2d0813d320 --- /dev/null +++ b/tests/unit/post_processing/test_mebin_post_processor.py @@ -0,0 +1,232 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Test the MEBinPostProcessor class.""" + +from unittest.mock import MagicMock, Mock, patch + +import numpy as np +import pytest +import torch + +from anomalib.data import InferenceBatch +from anomalib.post_processing import MEBinPostProcessor + + +class TestMEBinPostProcessor: + """Test the MEBinPostProcessor class.""" + + @staticmethod + def test_initialization_default_params() -> None: + """Test MEBinPostProcessor initialization with default parameters.""" + processor = MEBinPostProcessor() + + assert processor.sample_rate == 4 + assert processor.min_interval_len == 4 + assert processor.erode is True + + @staticmethod + @pytest.mark.parametrize( + ("sample_rate", "min_interval_len", "erode"), + [ + (2, 3, True), + (8, 6, False), + (1, 1, True), + ], + ) + def test_initialization_custom_params( + sample_rate: int, + min_interval_len: int, + erode: bool, + ) -> None: + """Test MEBinPostProcessor initialization with custom parameters.""" + processor = MEBinPostProcessor( + sample_rate=sample_rate, + min_interval_len=min_interval_len, + erode=erode, + ) + + assert processor.sample_rate == sample_rate + assert processor.min_interval_len == min_interval_len + assert processor.erode == erode + + @staticmethod + @patch("anomalib.post_processing.mebin_post_processor.MEBin") + def test_forward_single_anomaly_map(mock_mebin: MagicMock) -> None: + """Test forward method with single anomaly map.""" + # Setup mock + mock_mebin_instance = Mock() + mock_mebin_instance.binarize_anomaly_maps.return_value = ( + [np.array([[0, 0], [1, 1]], dtype=np.uint8)], + [0.5], + ) + mock_mebin.return_value = mock_mebin_instance + + # Create test data + anomaly_map = torch.rand(1, 1, 4, 4) + predictions = InferenceBatch( + pred_score=torch.tensor([0.8]), + pred_label=torch.tensor([1]), + anomaly_map=anomaly_map, + pred_mask=None, + ) + + # Test forward pass + processor = MEBinPostProcessor() + result = processor.forward(predictions) + + # Verify results + assert isinstance(result, InferenceBatch) + assert result.pred_mask is not None + assert result.pred_mask.shape == (1, 2, 2) + assert result.pred_mask.dtype == anomaly_map.dtype + + @staticmethod + @patch("anomalib.post_processing.mebin_post_processor.MEBin") + def test_forward_batch_anomaly_maps(mock_mebin: MagicMock) -> None: + """Test forward method with batch of anomaly maps.""" + # Setup mock + mock_mebin_instance = Mock() + mock_mebin_instance.binarize_anomaly_maps.return_value = ( + [ + np.array([[0, 0], [1, 1]], dtype=np.uint8), + np.array([[1, 0], [0, 1]], dtype=np.uint8), + ], + [0.5, 0.6], + ) + mock_mebin.return_value = mock_mebin_instance + + # Create test data + anomaly_maps = torch.rand(2, 1, 4, 4) + predictions = InferenceBatch( + pred_score=torch.tensor([0.8, 0.9]), + pred_label=torch.tensor([1, 1]), + anomaly_map=anomaly_maps, + pred_mask=None, + ) + + # Test forward pass + processor = MEBinPostProcessor() + result = processor.forward(predictions) + + # Verify results + assert isinstance(result, InferenceBatch) + assert result.pred_mask.shape == (2, 2, 2) + + @staticmethod + @patch("anomalib.post_processing.mebin_post_processor.MEBin") + def test_forward_normalization(mock_mebin: MagicMock) -> None: + """Test that anomaly maps are properly normalized to 0-255 range.""" + # Setup mock + mock_mebin_instance = Mock() + mock_mebin_instance.binarize_anomaly_maps.return_value = ( + [np.array([[0, 0], [1, 1]], dtype=np.uint8)], + [0.5], + ) + mock_mebin.return_value = mock_mebin_instance + + # Create test data with specific range + anomaly_maps = torch.tensor([[[[0.0, 0.5], [1.0, 0.2]]]]) + predictions = InferenceBatch( + pred_score=torch.tensor([0.8]), + pred_label=torch.tensor([1]), + anomaly_map=anomaly_maps, + pred_mask=None, + ) + + # Test forward pass + processor = MEBinPostProcessor() + processor.forward(predictions) + + # Verify MEBin was called with normalized data + mock_mebin.assert_called_once() + call_args = mock_mebin.call_args + anomaly_map_list = call_args[1]["anomaly_map_list"] + + # Check that the data is normalized to 0-255 range + assert len(anomaly_map_list) == 1 + assert anomaly_map_list[0].dtype == np.uint8 + assert anomaly_map_list[0].min() >= 0 + assert anomaly_map_list[0].max() <= 255 + + @staticmethod + @patch("anomalib.post_processing.mebin_post_processor.MEBin") + def test_forward_mebin_parameters(mock_mebin: MagicMock) -> None: + """Test that MEBin is called with correct parameters.""" + # Setup mock + mock_mebin_instance = Mock() + mock_mebin_instance.binarize_anomaly_maps.return_value = ( + [np.array([[0, 0], [1, 1]], dtype=np.uint8)], + [0.5], + ) + mock_mebin.return_value = mock_mebin_instance + + # Create test data + anomaly_maps = torch.rand(1, 1, 4, 4) + predictions = InferenceBatch( + pred_score=torch.tensor([0.8]), + pred_label=torch.tensor([1]), + anomaly_map=anomaly_maps, + pred_mask=None, + ) + + # Test with custom parameters + processor = MEBinPostProcessor( + sample_rate=8, + min_interval_len=6, + erode=False, + ) + _ = processor.forward(predictions) + + # Verify MEBin was called with correct parameters + mock_mebin.assert_called_once_with( + anomaly_map_list=mock_mebin.call_args[1]["anomaly_map_list"], + sample_rate=8, + min_interval_len=6, + erode=False, + ) + + @staticmethod + @patch("anomalib.post_processing.mebin_post_processor.MEBin") + def test_forward_binary_mask_conversion(mock_mebin: MagicMock) -> None: + """Test that binary masks are properly converted to 0/1 values.""" + # Setup mock to return masks with values > 0 + mock_mebin_instance = Mock() + mock_mebin_instance.binarize_anomaly_maps.return_value = ( + [np.array([[0, 128], [255, 64]], dtype=np.uint8)], + [0.5], + ) + mock_mebin.return_value = mock_mebin_instance + + # Create test data + anomaly_maps = torch.rand(1, 1, 2, 2) + predictions = InferenceBatch( + pred_score=torch.tensor([0.8]), + pred_label=torch.tensor([1]), + anomaly_map=anomaly_maps, + pred_mask=None, + ) + + # Test forward pass + processor = MEBinPostProcessor() + result = processor.forward(predictions) + + # Verify that all values are either 0 or 1 + unique_values = torch.unique(result.pred_mask) + assert torch.all((unique_values == 0) | (unique_values == 1)) + + @staticmethod + def test_forward_missing_anomaly_map() -> None: + """Test that ValueError is raised when anomaly_map is None.""" + # Create test data without anomaly_map + predictions = InferenceBatch( + pred_score=torch.tensor([0.8]), + pred_label=torch.tensor([1]), + anomaly_map=None, + pred_mask=None, + ) + + # Test forward pass should raise ValueError + processor = MEBinPostProcessor() + with pytest.raises(ValueError, match="Anomaly map is required for MEBin post-processing"): + processor.forward(predictions)