Changes from 2 commits
2 changes: 2 additions & 0 deletions imblearn/over_sampling/__init__.py
@@ -10,6 +10,7 @@
from ._smote import KMeansSMOTE
from ._smote import SVMSMOTE
from ._smote import SMOTENC
from ._smote import SLSMOTE

__all__ = [
"ADASYN",
@@ -19,4 +20,5 @@
"BorderlineSMOTE",
"SVMSMOTE",
"SMOTENC",
"SLSMOTE",
]
328 changes: 324 additions & 4 deletions imblearn/over_sampling/_smote.py
@@ -586,12 +586,14 @@ def _fit_resample(self, X, y):
n_generated_samples = int(fractions * (n_samples + 1))
if np.count_nonzero(danger_bool) > 0:
nns = self.nn_k_.kneighbors(
_safe_indexing(support_vector, np.flatnonzero(danger_bool)),
_safe_indexing(
support_vector, np.flatnonzero(danger_bool)),
return_distance=False,
)[:, 1:]

X_new_1, y_new_1 = self._make_samples(
_safe_indexing(support_vector, np.flatnonzero(danger_bool)),
_safe_indexing(
support_vector, np.flatnonzero(danger_bool)),
y.dtype,
class_sample,
X_class,
@@ -602,12 +604,14 @@

if np.count_nonzero(safety_bool) > 0:
nns = self.nn_k_.kneighbors(
_safe_indexing(support_vector, np.flatnonzero(safety_bool)),
_safe_indexing(
support_vector, np.flatnonzero(safety_bool)),
return_distance=False,
)[:, 1:]

X_new_2, y_new_2 = self._make_samples(
_safe_indexing(support_vector, np.flatnonzero(safety_bool)),
_safe_indexing(
support_vector, np.flatnonzero(safety_bool)),
y.dtype,
class_sample,
X_class,
@@ -1308,3 +1312,319 @@ def _fit_resample(self, X, y):
y_resampled = np.hstack((y_resampled, y_new))

return X_resampled, y_resampled


@Substitution(
sampling_strategy=BaseOverSampler._sampling_strategy_docstring,
random_state=_random_state_docstring,
)
class SLSMOTE(BaseSMOTE):
Member: @glemaitre SafeLevelSMOTE vs SLSMOTE

"""Class to perform over-sampling using safe-level SMOTE.
This is an implementation of the Safe-level-SMOTE described in [2]_.

Parameters
-----------
{sampling_strategy}

{random_state}

k_neighbors : int or object, optional (default=5)
If ``int``, number of nearest neighbours used to construct synthetic
samples. If object, an estimator that inherits from
:class:`sklearn.neighbors.base.KNeighborsMixin` that will be used to
find the k_neighbors.

m_neighbors : int or object, optional (default=10)
If ``int``, number of nearest neighbours used to determine the safe
level of an instance. If object, an estimator that inherits from
:class:`sklearn.neighbors.base.KNeighborsMixin` that will be used
to find the m_neighbors.

n_jobs : int or None, optional (default=None)
Number of CPU cores used during the nearest-neighbours search.
``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
``-1`` means using all processors. See
`Glossary <https://scikit-learn.org/stable/glossary.html#term-n-jobs>`_
for more details.


Notes
-----
See the original paper [2]_ for more details.

Supports multi-class resampling. A one-vs.-rest scheme is used as
originally proposed in [1]_.

See also
--------
SMOTE : Over-sample using SMOTE.

SMOTENC : Over-sample using SMOTE for continuous and categorical features.

SVMSMOTE : Over-sample using SVM-SMOTE variant.

BorderlineSMOTE : Over-sample using Borderline-SMOTE.

ADASYN : Over-sample using ADASYN.

KMeansSMOTE : Over-sample using KMeans-SMOTE variant.

References
----------
.. [1] N. V. Chawla, K. W. Bowyer, L. O. Hall, W. P. Kegelmeyer, "SMOTE:
synthetic minority over-sampling technique," Journal of Artificial
Intelligence Research, vol. 16, 321-357, 2002.

.. [2] C. Bunkhumpornpat, K. Sinapiromsaran, C. Lursinsap, "Safe-level-
SMOTE: Safe-level-synthetic minority over-sampling technique for
handling the class imbalanced problem," In: Theeramunkong T.,
Kijsirikul B., Cercone N., Ho TB. (eds) Advances in Knowledge Discovery
and Data Mining. PAKDD 2009. Lecture Notes in Computer Science,
vol 5476. Springer, Berlin, Heidelberg, 475-482, 2009.


Examples
--------

>>> from collections import Counter
>>> from sklearn.datasets import make_classification
>>> from imblearn.over_sampling import \
SLSMOTE # doctest: +NORMALIZE_WHITESPACE
>>> X, y = make_classification(n_classes=2, class_sep=2,
... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0,
... n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10)
>>> print('Original dataset shape %s' % Counter(y))
Original dataset shape Counter({{1: 900, 0: 100}})
>>> sm = SLSMOTE(random_state=42)
>>> X_res, y_res = sm.fit_resample(X, y)
>>> print('Resampled dataset shape %s' % Counter(y_res))
Resampled dataset shape Counter({{0: 900, 1: 900}})

"""

def __init__(self,
sampling_strategy='auto',
random_state=None,
k_neighbors=5,
m_neighbors=10,
n_jobs=None):

super().__init__(sampling_strategy=sampling_strategy,
random_state=random_state, k_neighbors=k_neighbors,
n_jobs=n_jobs)

self.m_neighbors = m_neighbors

def _assign_sl(self, nn_estimator, samples, target_class, y):
Member: I would use the name _assign_safe_levels unless it hurts the readability in the calling code.

"""Assign the safe levels to the instances in the target class.

Parameters
----------
nn_estimator : estimator
An estimator that inherits from
:class:`sklearn.neighbors.base.KNeighborsMixin`. It gets the
nearest neighbors that are used to determine the safe levels.

samples : {array-like, sparse matrix}, shape (n_samples, n_features)
The samples to which the safe levels are assigned.

target_class : int or str
The class targeted by the over-sampling.

y : array-like, shape (n_samples,)
The true labels, used to compute the safe levels.

Returns
-------
output : ndarray, shape (n_samples,)
An ndarray containing the safe level of each instance in the
target class.
"""

# Drop each query point itself (the first neighbour), then count how
# many of the remaining m neighbours belong to the target class.
x = nn_estimator.kneighbors(samples, return_distance=False)[:, 1:]
nn_label = (y[x] == target_class).astype(int)
sl = np.sum(nn_label, axis=1)
return sl
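
For illustration, a standalone sketch of what _assign_sl computes, on invented toy data (this snippet is not part of the PR): the safe level of an instance is the number of its m nearest neighbours that belong to the same class.

import numpy as np
from sklearn.neighbors import NearestNeighbors

# Toy 1-D data (invented); class 0 is the minority class.
X = np.array([[0.0], [0.1], [0.2], [0.15], [5.0], [5.1], [5.2]])
y = np.array([0, 0, 0, 1, 1, 1, 1])

# m_neighbors=3 is emulated with n_neighbors=4: the first neighbour
# returned for each query is the point itself and is dropped, as in
# _assign_sl.
nn = NearestNeighbors(n_neighbors=4).fit(X)
minority = X[y == 0]
neigh = nn.kneighbors(minority, return_distance=False)[:, 1:]

# Safe level = number of the m neighbours sharing the target class.
safe_levels = (y[neigh] == 0).sum(axis=1)
print(safe_levels)  # [2 2 2]: each minority point has one majority neighbour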

def _validate_estimator(self):
super()._validate_estimator()
self.nn_m_ = check_neighbors_object('m_neighbors', self.m_neighbors,
additional_neighbor=1)
self.nn_m_.set_params(**{"n_jobs": self.n_jobs})

def _fit_resample(self, X, y):
self._validate_estimator()

X_resampled = X.copy()
y_resampled = y.copy()

for class_sample, n_samples in self.sampling_strategy_.items():
if n_samples == 0:
continue
target_class_indices = np.flatnonzero(y == class_sample)
X_class = _safe_indexing(X, target_class_indices)

self.nn_m_.fit(X)
sl = self._assign_sl(self.nn_m_, X_class, class_sample, y)

# Filter the points in X_class that have a safe level > 0; points
# with safe level 0 are not used to generate synthetic instances.
X_safe_indices = np.flatnonzero(sl != 0)
X_safe_class = _safe_indexing(X_class, X_safe_indices)

self.nn_k_.fit(X_class)
nns = self.nn_k_.kneighbors(X_safe_class,
return_distance=False)[:, 1:]

sl_safe_class = sl[X_safe_indices]
sl_nns = sl[nns]
# Column vector of the safe levels of the safe instances; dividing
# by a neighbour's safe level of 0 yields inf, handled later in
# _generate_gap.
sl_safe_t = np.array([sl_safe_class]).transpose()
with np.errstate(divide='ignore'):
sl_ratio = np.divide(sl_safe_t, sl_nns)

X_new, y_new = self._make_samples_sl(X_safe_class, y.dtype,
class_sample, X_class,
nns, n_samples, sl_ratio,
1.0)

if sparse.issparse(X_new):
X_resampled = sparse.vstack([X_resampled, X_new])
else:
X_resampled = np.vstack((X_resampled, X_new))
y_resampled = np.hstack((y_resampled, y_new))

return X_resampled, y_resampled
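
To make the np.errstate guard above concrete, here is a toy computation of sl_ratio with invented safe levels: a neighbour whose safe level is 0 yields an inf ratio, which _generate_gap later maps to a gap of 0, so the synthetic point collapses onto the safe instance.

import numpy as np

sl_safe_class = np.array([2, 3])   # safe levels of the safe instances
sl_nns = np.array([[2, 0, 1],      # safe levels of their k neighbours
                   [3, 3, 6]])

sl_safe_t = sl_safe_class[:, np.newaxis]  # column vector, as above
with np.errstate(divide='ignore'):
    sl_ratio = np.divide(sl_safe_t, sl_nns)
print(sl_ratio)
# [[1.  inf 2. ]
#  [1.  1.  0.5]]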

def _make_samples_sl(self, X, y_dtype, y_type, nn_data, nn_num,
n_samples, sl_ratio, step_size=1.):
"""A support function that returns artificial samples using
safe-level SMOTE. It is similar to _make_samples method for SMOTE.

Parameters
----------
X : {array-like, sparse matrix}, shape (n_samples_safe, n_features)
Points from which the new samples will be created.

y_dtype : dtype
The data type of the targets.

y_type : str or int
The minority target value, used to build the array of target
values for the synthetic samples with the correct length.

nn_data : ndarray, shape (n_samples_all, n_features)
Data set carrying all the neighbours to be used.

nn_num : ndarray, shape (n_samples_safe, k_nearest_neighbours)
The nearest neighbours of each sample in `nn_data`.

n_samples : int
The number of samples to generate.

sl_ratio : ndarray, shape (n_samples_safe, k_nearest_neighbours)
The ratio between the safe level of each safe instance and the
safe level of each of its nearest neighbours.

step_size : float, optional (default=1.)
The step size to create samples.


Returns
-------
X_new : {ndarray, sparse matrix}, shape (n_samples_new, n_features)
Synthetically generated samples using the safe-level method.

y_new : ndarray, shape (n_samples_new,)
Target values for synthetic samples.

"""

random_state = check_random_state(self.random_state)
samples_indices = random_state.randint(low=0,
high=len(nn_num.flatten()),
size=n_samples)
# Map each flat index to a (safe instance, neighbour) pair and pick
# the gap pre-drawn for that pair.
rows = np.floor_divide(samples_indices, nn_num.shape[1])
cols = np.mod(samples_indices, nn_num.shape[1])
gap_arr = step_size * self._vgenerate_gap(sl_ratio)
gaps = gap_arr.flatten()[samples_indices]

y_new = np.array([y_type] * n_samples, dtype=y_dtype)

if sparse.issparse(X):
row_indices, col_indices, samples = [], [], []
for i, (row, col, gap) in enumerate(zip(rows, cols, gaps)):
if X[row].nnz:
sample = self._generate_sample(
X, nn_data, nn_num, row, col, gap)
row_indices += [i] * len(sample.indices)
col_indices += sample.indices.tolist()
samples += sample.data.tolist()
return (
sparse.csr_matrix(
(samples, (row_indices, col_indices)),
[len(samples_indices), X.shape[1]],
dtype=X.dtype,
),
y_new,
)

else:
X_new = np.zeros((n_samples, X.shape[1]), dtype=X.dtype)
for i, (row, col, gap) in enumerate(zip(rows, cols, gaps)):
X_new[i] = self._generate_sample(X, nn_data, nn_num,
row, col, gap)

return X_new, y_new
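
For intuition about the interpolation delegated to _generate_sample, the snippet below assumes the standard SMOTE formula x_new = x + gap * (neighbour - x) with invented numbers; it is a hedged sketch, not the base-class implementation.

import numpy as np

x = np.array([1.0, 1.0])          # a safe minority instance
neighbour = np.array([3.0, 2.0])  # one of its k nearest neighbours
gap = 0.25                        # drawn by _generate_gap from sl_ratio

# A point on the segment between x and its neighbour: gap = 0
# reproduces x, gap = 1 lands on the neighbour.
x_new = x + gap * (neighbour - x)
print(x_new)  # [1.5  1.25]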

def _generate_gap(self, a_ratio, rand_state=None):
""" generate gap according to sl_ratio, non-vectorized version.

Parameters
----------
a_ratio: float
sl_ratio of a single data point

rand_state: random state object or int


Returns
------------
gap: float
a number between 0 and 1

"""

random_state = check_random_state(rand_state)
if np.isinf(a_ratio):
gap = 0
elif a_ratio >= 1:
gap = random_state.uniform(0, 1/a_ratio)
elif 0 < a_ratio < 1:
gap = random_state.uniform(1-a_ratio, 1)
else:
raise ValueError('sl_ratio should be positive')
return gap

def _vgenerate_gap(self, sl_ratio):
"""
generate gap according to sl_ratio, vectorized version of _generate_gap

Parameters
-----------
sl_ratio: ndarray shape (n_samples_safe, k_nearest_neighbours)
sl_ratio of all instances with safe_level>0 in the specified
class

Returns
------------
gap_arr: ndarray shape (n_samples_safe, k_nearest_neighbours)
the gap for all instances with safe_level>0 in the specified
class

"""
prng = check_random_state(self.random_state)
# One seed per (instance, neighbour) pair keeps the vectorized call
# reproducible for a fixed random_state.
rand_state = prng.randint(sl_ratio.size + 1, size=sl_ratio.shape)
vgap = np.vectorize(self._generate_gap)
gap_arr = vgap(sl_ratio, rand_state)
return gap_arr
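
Tying the gap logic together, a self-contained re-implementation of the same case analysis (the helper name generate_gap is hypothetical, mirroring _generate_gap above), applied to the toy sl_ratio matrix from the earlier sketch:

import numpy as np
from sklearn.utils import check_random_state

def generate_gap(a_ratio, rand_state=None):
    # Same case analysis as _generate_gap above.
    rs = check_random_state(rand_state)
    if np.isinf(a_ratio):
        return 0.0                           # duplicate the safe instance
    if a_ratio >= 1:
        return rs.uniform(0, 1 / a_ratio)    # stay close to the safe instance
    if 0 < a_ratio < 1:
        return rs.uniform(1 - a_ratio, 1)    # move towards the safer neighbour
    raise ValueError('sl_ratio should be positive')

sl_ratio = np.array([[1.0, np.inf, 2.0],
                     [1.0, 1.0, 0.5]])
gaps = np.vectorize(generate_gap)(sl_ratio, 0)
print(gaps)  # the inf entry maps to 0; ratios >= 1 give gaps in [0, 1/ratio]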