Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion alibi_detect/cd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .fet import FETDrift
from .fet_online import FETDriftOnline
from .context_aware import ContextMMDDrift
from .spectral import SpectralDrift

__all__ = [
"ChiSquareDrift",
Expand All @@ -32,5 +33,6 @@
"CVMDriftOnline",
"FETDrift",
"FETDriftOnline",
"ContextMMDDrift"
"ContextMMDDrift",
"SpectralDrift"
]
166 changes: 166 additions & 0 deletions alibi_detect/cd/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
from scipy.stats import binomtest, ks_2samp
from sklearn.model_selection import StratifiedKFold


logger = logging.getLogger(__name__)


if has_pytorch:
import torch

Expand Down Expand Up @@ -661,6 +665,168 @@ def predict(self, x: Union[np.ndarray, list], return_p_val: bool = True, return_
return cd


class BaseSpectralDrift(BaseDetector):
    """
    Base class for spectral (eigenvalue-based) data drift detectors.

    Stores the reference data (optionally preprocessing it at construction time), exposes a shared
    `preprocess` step and a `predict` method that turns the output of `score` into a drift
    prediction dictionary. Subclasses must implement `score`.
    """

    def __init__(
            self,
            x_ref: Union[np.ndarray, list],
            p_val: float = .05,
            x_ref_preprocessed: bool = False,
            preprocess_at_init: bool = True,
            update_x_ref: Optional[Dict[str, int]] = None,
            preprocess_fn: Optional[Callable] = None,
            threshold: Optional[float] = None,
            n_bootstraps: int = 1000,
            input_shape: Optional[tuple] = None,
            data_type: Optional[str] = None
    ) -> None:
        """
        Spectral eigenvalue-based data drift detector base class.

        Parameters
        ----------
        x_ref
            Data used as reference distribution.
        p_val
            p-value used for the significance of the test.
        x_ref_preprocessed
            Whether the given reference data `x_ref` has been preprocessed yet. If `x_ref_preprocessed=True`, only
            the test data `x` will be preprocessed at prediction time. If `x_ref_preprocessed=False`, the reference
            data will also be preprocessed.
        preprocess_at_init
            Whether to preprocess the reference data when the detector is instantiated. Otherwise, the reference
            data will be preprocessed at prediction time. Only applies if `x_ref_preprocessed=False`.
        update_x_ref
            Reference data can optionally be updated to the last n instances seen by the detector
            or via reservoir sampling with size n. For the former, the parameter equals {'last': n} while
            for reservoir sampling {'reservoir_sampling': n} is passed.
        preprocess_fn
            Function to preprocess the data before computing the data drift metrics.
        threshold
            Spectral ratio threshold for drift detection. If None, computed from p_val using bootstrap.
        n_bootstraps
            Number of bootstrap samples for threshold computation.
        input_shape
            Shape of input data.
        data_type
            Optionally specify the data type (tabular, image or time-series). Added to metadata.

        Raises
        ------
        ValueError
            If `x_ref` exposes a `shape` that is not at least 2-dimensional or has fewer than 2 entries
            along its second axis, or if `preprocess_fn` is given but is not callable.
        """
        super().__init__()

        if p_val is None:
            logger.warning('No p-value set for the drift threshold. Need to set it to detect data drift.')

        # Validate input dimensions for spectral analysis. Inputs without a `shape` attribute
        # (e.g. plain lists) are not checked here.
        # NOTE(review): the check runs on the raw `x_ref`, before any `preprocess_fn` is applied.
        ref_shape = getattr(x_ref, 'shape', None)
        if ref_shape is not None:
            # A 1-D input has no second axis; report it explicitly instead of failing with an IndexError.
            if len(ref_shape) < 2:
                raise ValueError(
                    f"Spectral analysis requires reference data with at least 2 dimensions "
                    f"(instances x features), got shape {tuple(ref_shape)}"
                )
            if ref_shape[1] < 2:
                raise ValueError(f"Spectral analysis requires at least 2 features, got {ref_shape[1]}")

        # x_ref preprocessing
        self.preprocess_at_init = preprocess_at_init
        self.x_ref_preprocessed = x_ref_preprocessed
        if preprocess_fn is not None and not callable(preprocess_fn):
            raise ValueError("`preprocess_fn` is not a valid Callable.")
        if self.preprocess_at_init and not self.x_ref_preprocessed and preprocess_fn is not None:
            self.x_ref = preprocess_fn(x_ref)
        else:
            self.x_ref = x_ref

        # Other attributes
        self.p_val = p_val
        self.update_x_ref = update_x_ref
        self.preprocess_fn = preprocess_fn
        self.threshold = threshold
        self.n_bootstraps = n_bootstraps
        # number of reference instances, taken from the `x_ref` passed in (not the preprocessed copy)
        self.n = len(x_ref)

        # store input shape for save and load functionality
        self.input_shape = get_input_shape(input_shape, x_ref)

        # set metadata
        self.meta.update({'detector_type': 'drift', 'online': False, 'data_type': data_type})

    def preprocess(self, x: Union[np.ndarray, list]) -> Tuple[np.ndarray, np.ndarray]:
        """
        Data preprocessing before computing the drift scores.

        Parameters
        ----------
        x
            Batch of instances.

        Returns
        -------
        Preprocessed reference data and new instances.
        """
        if self.preprocess_fn is None:
            # nothing to apply: hand back the stored reference data and `x` untouched
            return self.x_ref, x  # type: ignore[return-value]

        x = self.preprocess_fn(x)
        # The reference data only needs preprocessing here if that was neither done by the caller
        # (`x_ref_preprocessed`) nor at construction time (`preprocess_at_init`).
        if not self.preprocess_at_init and not self.x_ref_preprocessed:
            x_ref = self.preprocess_fn(self.x_ref)
        else:
            x_ref = self.x_ref
        return x_ref, x  # type: ignore[return-value]

    @abstractmethod
    def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, float]:
        """
        Compute spectral drift score.

        Parameters
        ----------
        x
            Batch of instances.

        Returns
        -------
        Tuple containing p-value, spectral ratio, and threshold.
        """
        pass

    def predict(self, x: Union[np.ndarray, list], return_p_val: bool = True, return_distance: bool = True) \
            -> Dict[str, Any]:
        """
        Predict whether a batch of data has drifted from the reference data.

        Parameters
        ----------
        x
            Batch of instances.
        return_p_val
            Whether to return the p-value of the test.
        return_distance
            Whether to return the spectral ratio between the new batch and reference data.

        Returns
        -------
        Dictionary containing ``'meta'`` and ``'data'`` dictionaries.
            - ``'meta'`` has the model's metadata.
            - ``'data'`` contains the drift prediction and optionally the p-value, threshold and spectral ratio.
        """
        # compute drift scores
        p_val, spectral_ratio, distance_threshold = self.score(x)
        # drift is flagged when the test p-value falls below the configured significance level
        drift_pred = int(p_val < self.p_val)

        # update reference dataset
        # When the stored reference data was preprocessed at init, preprocess `x` as well before
        # it is merged into the reference set.
        if isinstance(self.update_x_ref, dict) and self.preprocess_fn is not None and self.preprocess_at_init:
            x = self.preprocess_fn(x)
        # NOTE(review): `update_reference` is defined elsewhere; presumably it returns `x_ref` unchanged
        # when `update_x_ref` is None - confirm.
        self.x_ref = update_reference(self.x_ref, x, self.n, self.update_x_ref)  # type: ignore[arg-type]
        # used for reservoir sampling
        self.n += len(x)

        # populate drift dict
        cd = concept_drift_dict()
        cd['meta'] = self.meta
        cd['data']['is_drift'] = drift_pred
        if return_p_val:
            cd['data']['p_val'] = p_val
            cd['data']['threshold'] = self.p_val
        if return_distance:
            # the spectral ratio is exposed under both the generic 'distance' key and its own name
            cd['data']['distance'] = spectral_ratio
            cd['data']['distance_threshold'] = distance_threshold
            cd['data']['spectral_ratio'] = spectral_ratio
        return cd


class BaseLSDDDrift(BaseDetector):
# TODO: TBD: this is only created when _configure_normalization is called from backend-specific classes,
# is declaring it here the right thing to do?
Expand Down
Loading