Commit 367fb90

FEA implement InstanceHardnessCV cross-validation splitter (#1125)
Co-authored-by: Guillaume Lemaitre <[email protected]>
1 parent 709ca4e commit 367fb90

11 files changed

+488
-1
lines changed

doc/model_selection.rst

Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
1+
.. _cross_validation:
2+
3+
================
4+
Cross validation
5+
================
6+
7+
.. currentmodule:: imblearn.model_selection
8+
9+
10+
.. _instance_hardness_threshold_cv:
11+
12+
The term instance hardness is used in the literature to express the difficulty of correctly
13+
classifying an instance. An instance for which the predicted probability of the true class
14+
is low has large instance hardness. The way these hard-to-classify instances are
15+
distributed over train and test sets in cross validation has a significant effect on the
16+
test set performance metrics. The :class:`~imblearn.model_selection.InstanceHardnessCV`
17+
splitter distributes samples with large instance hardness equally over the folds,
18+
resulting in more robust cross validation.
19+
20+
We will discuss instance hardness in this document and explain how to use the
21+
:class:`~imblearn.model_selection.InstanceHardnessCV` splitter.
22+
23+
Instance hardness and average precision
24+
=======================================
25+
26+
Instance hardness is defined as 1 minus the probability of the most probable class:
27+
28+
.. math::
29+
30+
H(x) = 1 - P(\hat{y}|x)
31+
32+
In this equation :math:`H(x)` is the instance hardness for a sample with features
33+
:math:`x` and :math:`P(\hat{y}|x)` the probability of predicted label :math:`\hat{y}`
34+
given the features. If the model predicts label 0 and gives a `predict_proba` output
35+
of [0.9, 0.1], the probability of the most probable class (0) is 0.9 and the
36+
instance hardness is `1-0.9=0.1`.
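The formula above can be sketched in a few lines of NumPy. This is only an illustration of the stated definition; the helper name ``instance_hardness`` and the example probabilities (other than the ``[0.9, 0.1]`` row taken from the text) are assumptions made for this sketch.

```python
import numpy as np


def instance_hardness(proba):
    """Return 1 minus the largest class probability for each row.

    `proba` is expected to look like the output of `predict_proba`:
    one row per sample, one column per class.
    """
    proba = np.asarray(proba, dtype=float)
    return 1.0 - proba.max(axis=1)


# The first row is the example from the text; the other rows are made up.
proba = np.array([[0.9, 0.1], [0.4, 0.6], [0.5, 0.5]])
hardness = instance_hardness(proba)
print(hardness)
```

Under this definition the hardness of a binary problem never exceeds 0.5, which is reached when both classes are equally probable.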
37+
38+
Samples with large instance hardness have a significant effect on the area under the
39+
precision-recall curve, or average precision. In particular, samples with label 0
40+
and large instance hardness (so the model predicts label 1) reduce the average
41+
precision a lot, as these points affect the precision-recall curve on the left,
42+
where the area is largest; the precision is lowered in the range of low recall
43+
and high thresholds. When doing cross validation, e.g. in case of hyperparameter
44+
tuning or recursive feature elimination, random gathering of these points in
45+
some folds introduces variance in CV results that deteriorates the robustness of the
46+
cross validation task. The :class:`~imblearn.model_selection.InstanceHardnessCV`
47+
splitter aims to distribute the samples with large instance hardness over the
48+
folds in order to reduce undesired variance. Note that one should use this
49+
splitter to make model *selection* tasks, such as hyperparameter tuning and
50+
feature selection, more robust, but not for model *performance estimation*, for which you also
51+
want to know the variance of performance to be expected in production.
52+
53+
54+
Create imbalanced dataset with samples with large instance hardness
55+
===================================================================
56+
57+
Let's start by creating a dataset to work with. We create a dataset with 5% class
58+
imbalance using scikit-learn's :func:`~sklearn.datasets.make_blobs` function.
59+
60+
>>> import numpy as np
61+
>>> from matplotlib import pyplot as plt
62+
>>> from sklearn.datasets import make_blobs
63+
>>> from imblearn.datasets import make_imbalance
64+
>>> random_state = 10
65+
>>> X, y = make_blobs(n_samples=[950, 50], centers=((-3, 0), (3, 0)),
66+
... random_state=random_state)
67+
>>> plt.scatter(X[:, 0], X[:, 1], c=y)
68+
>>> plt.show()
69+
70+
.. image:: ./auto_examples/model_selection/images/sphx_glr_plot_instance_hardness_cv_001.png
71+
:target: ./auto_examples/model_selection/plot_instance_hardness_cv.html
72+
:align: center
73+
74+
Now we add some samples with large instance hardness:
75+
76+
>>> X_hard, y_hard = make_blobs(n_samples=10, centers=((3, 0), (-3, 0)),
77+
... cluster_std=1,
78+
... random_state=random_state)
79+
>>> X = np.vstack((X, X_hard))
80+
>>> y = np.hstack((y, y_hard))
81+
>>> plt.scatter(X[:, 0], X[:, 1], c=y)
82+
>>> plt.show()
83+
84+
.. image:: ./auto_examples/model_selection/images/sphx_glr_plot_instance_hardness_cv_002.png
85+
:target: ./auto_examples/model_selection/plot_instance_hardness_cv.html
86+
:align: center
87+
88+
Assess cross validation performance variance using `InstanceHardnessCV` splitter
89+
================================================================================
90+
91+
Then we take a :class:`~sklearn.linear_model.LogisticRegression` and assess the
92+
cross validation performance using a :class:`~sklearn.model_selection.StratifiedKFold`
93+
cv splitter and the :func:`~sklearn.model_selection.cross_validate` function.
94+
95+
>>> from sklearn.linear_model import LogisticRegression
96+
>>> clf = LogisticRegression(random_state=random_state)
97+
>>> skf_cv = StratifiedKFold(n_splits=5, shuffle=True,
98+
... random_state=random_state)
99+
>>> skf_result = cross_validate(clf, X, y, cv=skf_cv, scoring="average_precision")
100+
101+
Now, we do the same using an :class:`~imblearn.model_selection.InstanceHardnessCV`
102+
splitter. We provide our classifier to the splitter to calculate instance hardness
103+
and distribute samples with large instance hardness equally over the folds.
104+
105+
>>> ih_cv = InstanceHardnessCV(estimator=clf, n_splits=5,
106+
... random_state=random_state)
107+
>>> ih_result = cross_validate(clf, X, y, cv=ih_cv, scoring="average_precision")
108+
109+
When we plot the test scores for both cv splitters, we see that the variance using the
110+
:class:`~imblearn.model_selection.InstanceHardnessCV` splitter is lower than for the
111+
:class:`~sklearn.model_selection.StratifiedKFold` splitter.
112+
113+
>>> plt.boxplot([skf_result['test_score'], ih_result['test_score']],
114+
... tick_labels=["StratifiedKFold", "InstanceHardnessCV"],
115+
... vert=False)
116+
>>> plt.xlabel('Average precision')
117+
>>> plt.tight_layout()
118+
119+
.. image:: ./auto_examples/model_selection/images/sphx_glr_plot_instance_hardness_cv_003.png
120+
:target: ./auto_examples/model_selection/plot_instance_hardness_cv.html
121+
:align: center
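The boxplot can be complemented with a numeric summary of the spread of the fold scores. The sketch below rebuilds the dataset from this document and reports the mean and standard deviation of the ``StratifiedKFold`` scores; it deliberately depends only on scikit-learn and NumPy, and the variable names ``mean_score`` and ``spread`` are assumptions made for this sketch.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_validate

random_state = 10

# Same imbalanced dataset as above: 950 vs. 50 samples ...
X, y = make_blobs(
    n_samples=[950, 50], centers=((-3, 0), (3, 0)), random_state=random_state
)
# ... plus 10 samples placed in the blob of the opposite class.
X_hard, y_hard = make_blobs(
    n_samples=10, centers=((3, 0), (-3, 0)), cluster_std=1,
    random_state=random_state,
)
X = np.vstack((X, X_hard))
y = np.hstack((y, y_hard))

clf = LogisticRegression(random_state=random_state)
skf_cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=random_state)
skf_result = cross_validate(clf, X, y, cv=skf_cv, scoring="average_precision")

# One score per fold; the standard deviation quantifies the fold-to-fold spread.
mean_score = skf_result["test_score"].mean()
spread = skf_result["test_score"].std()
print(f"StratifiedKFold: mean={mean_score:.3f}, std={spread:.3f}")
```

The same two summary lines can be applied to ``ih_result["test_score"]`` to compare both splitters numerically.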
122+
123+
Be aware that the most important part of cross-validation splitters is to simulate the
124+
conditions that one will encounter in production. Therefore, if it is likely to get
125+
difficult samples in production, one should use a cross-validation splitter that
126+
emulates this situation. In our case, the
127+
:class:`~sklearn.model_selection.StratifiedKFold` splitter did not control how
128+
the difficult samples were distributed over the folds, which was likely a problem for our use case.

doc/references/index.rst

Lines changed: 1 addition & 0 deletions
@@ -18,5 +18,6 @@ This is the full API documentation of the `imbalanced-learn` toolbox.
1818
miscellaneous
1919
pipeline
2020
metrics
21+
model_selection
2122
datasets
2223
utils

doc/references/model_selection.rst

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
.. _model_selection_ref:
2+
3+
Model selection methods
4+
=======================
5+
6+
.. automodule:: imblearn.model_selection
7+
:no-members:
8+
:no-inherited-members:
9+
10+
Cross-validation splitters
11+
--------------------------
12+
13+
.. automodule:: imblearn.model_selection._split
14+
:no-members:
15+
:no-inherited-members:
16+
17+
.. currentmodule:: imblearn.model_selection
18+
19+
.. autosummary::
20+
:toctree: generated/
21+
:template: class.rst
22+
23+
InstanceHardnessCV

doc/user_guide.rst

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ User Guide
1919
ensemble.rst
2020
miscellaneous.rst
2121
metrics.rst
22+
model_selection.rst
2223
common_pitfalls.rst
2324
Dataset loading utilities <datasets/index.rst>
2425
developers_utils.rst

doc/whats_new/0.14.rst

Lines changed: 4 additions & 0 deletions
@@ -14,6 +14,10 @@ Bug fixes
1414
Enhancements
1515
............
1616

17+
- Add :class:`~imblearn.model_selection.InstanceHardnessCV` to split data and ensure
18+
that samples are distributed in folds based on their instance hardness.
19+
:pr:`1125` by :user:`Frits Hermans <fritshermans>`.
20+
1721
Compatibility
1822
.............
1923

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
1+
"""
2+
====================================================
3+
Distribute hard-to-classify datapoints over CV folds
4+
====================================================
5+
6+
'Instance hardness' refers to the difficulty to classify an instance. The way
7+
hard-to-classify instances are distributed over train and test sets has
8+
significant effect on the test set performance metrics. In this example we
9+
show how to deal with this problem. We compare the results with those of the standard
10+
:class:`~sklearn.model_selection.StratifiedKFold` cross-validation splitter.
11+
"""
12+
13+
# Authors: Frits Hermans, https://fritshermans.github.io
14+
# License: MIT
15+
16+
# %%
17+
print(__doc__)
18+
19+
# %%
20+
# Create an imbalanced dataset with instance hardness
21+
# ---------------------------------------------------
22+
#
23+
# We create an imbalanced dataset using scikit-learn's
24+
# :func:`~sklearn.datasets.make_blobs` function and set the class imbalance ratio to
25+
# 5%.
26+
import numpy as np
27+
from matplotlib import pyplot as plt
28+
from sklearn.datasets import make_blobs
29+
30+
X, y = make_blobs(n_samples=[950, 50], centers=((-3, 0), (3, 0)), random_state=10)
31+
plt.scatter(X[:, 0], X[:, 1], c=y)
32+
33+
# %%
34+
# To introduce instance hardness in our dataset, we add some hard-to-classify samples:
35+
X_hard, y_hard = make_blobs(
36+
n_samples=10, centers=((3, 0), (-3, 0)), cluster_std=1, random_state=10
37+
)
38+
X, y = np.vstack((X, X_hard)), np.hstack((y, y_hard))
39+
plt.scatter(X[:, 0], X[:, 1], c=y)
40+
41+
# %%
42+
# Compare cross validation scores using `StratifiedKFold` and `InstanceHardnessCV`
43+
# --------------------------------------------------------------------------------
44+
#
45+
# Now, we want to assess a linear predictive model. Therefore, we should use
46+
# cross-validation. The most important concept with cross-validation is to create
47+
# training and test splits that are representative of the data in production to have
48+
# statistical results that one can expect in production.
49+
#
50+
# By applying a standard :class:`~sklearn.model_selection.StratifiedKFold`
51+
# cross-validation splitter, we do not control in which fold the hard-to-classify
52+
# samples will be.
53+
#
54+
# The :class:`~imblearn.model_selection.InstanceHardnessCV` splitter allows us to
55+
# control the distribution of the hard-to-classify samples over the folds.
56+
#
57+
# Let's make an experiment to compare the results that we get with both splitters.
58+
# We use a :class:`~sklearn.linear_model.LogisticRegression` classifier and
59+
# :func:`~sklearn.model_selection.cross_validate` to calculate the cross validation
60+
# scores. We use average precision for scoring.
61+
import pandas as pd
62+
from sklearn.linear_model import LogisticRegression
63+
from sklearn.model_selection import StratifiedKFold, cross_validate
64+
65+
from imblearn.model_selection import InstanceHardnessCV
66+
67+
logistic_regression = LogisticRegression()
68+
69+
results = {}
70+
for cv in (
71+
StratifiedKFold(n_splits=5, shuffle=True, random_state=10),
72+
InstanceHardnessCV(estimator=LogisticRegression(), n_splits=5, random_state=10),
73+
):
74+
result = cross_validate(
75+
logistic_regression,
76+
X,
77+
y,
78+
cv=cv,
79+
scoring="average_precision",
80+
)
81+
results[cv.__class__.__name__] = result["test_score"]
82+
results = pd.DataFrame(results)
83+
84+
# %%
85+
ax = results.plot.box(vert=False, whis=[0, 100])
86+
ax.set(
87+
xlabel="Average precision",
88+
title="Cross validation scores with different splitters",
89+
xlim=(0, 1),
90+
)
91+
92+
# %%
93+
# The boxplot shows that the :class:`~imblearn.model_selection.InstanceHardnessCV`
94+
# splitter results in less variation of average precision than the
95+
# :class:`~sklearn.model_selection.StratifiedKFold` splitter. When doing
96+
# hyperparameter tuning or feature selection using a wrapper method (like
97+
# :class:`~sklearn.feature_selection.RFECV`) this will give more stable results.

imblearn/__init__.py

Lines changed: 5 additions & 1 deletion
@@ -11,14 +11,16 @@
1111
Module which provides methods generating an ensemble of
1212
under-sampled subsets.
1313
exceptions
14-
Module including custom warnings and error clases used across
14+
Module including custom warnings and error classes used across
1515
imbalanced-learn.
1616
keras
1717
Module which provides custom generator, layers for deep learning using
1818
keras.
1919
metrics
2020
Module which provides metrics to quantified the classification performance
2121
with imbalanced dataset.
22+
model_selection
23+
Module which provides methods to split the dataset into training and test sets.
2224
over_sampling
2325
Module which provides methods to over-sample a dataset.
2426
tensorflow
@@ -54,6 +56,7 @@
5456
ensemble,
5557
exceptions,
5658
metrics,
59+
model_selection,
5760
over_sampling,
5861
pipeline,
5962
tensorflow,
@@ -113,6 +116,7 @@ def __dir__(self):
113116
"exceptions",
114117
"keras",
115118
"metrics",
119+
"model_selection",
116120
"over_sampling",
117121
"tensorflow",
118122
"under_sampling",

imblearn/model_selection/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
1+
"""
2+
The :mod:`imblearn.model_selection` provides methods to split the dataset into
3+
training and test sets.
4+
"""
5+
6+
from ._split import InstanceHardnessCV
7+
8+
__all__ = ["InstanceHardnessCV"]
