Stratified Group KFold implementation #18649

Merged
merged 73 commits into from
Mar 20, 2021
Commits
4df86c6
Initial implementation
hermidalc Oct 13, 2019
6be3594
Forgot to add to second __add__ list
hermidalc Oct 13, 2019
2f28673
Update split method parameter doc
hermidalc Oct 13, 2019
2365735
Added example; changed default test_size to 0.1; added to author list
hermidalc Oct 14, 2019
b3d2b5a
Merge branch 'master' of github.com:scikit-learn/scikit-learn into st…
hermidalc Mar 18, 2020
aa8f288
StratifiedGroupKFold impl and other improvements
hermidalc Mar 18, 2020
647a97e
Add class to __all__ spec
hermidalc Mar 18, 2020
36babe5
Remove random_state when no shuffle
hermidalc Mar 18, 2020
32e502a
Tighter formatting
hermidalc Mar 18, 2020
c7ad3f3
Merge branch 'master' into stratified-groupshufflesplit
Oct 7, 2020
4826d96
Update the implementation of StratifiedGroupKFold
Oct 9, 2020
13801a7
Add StratifiedGroupKFold to __init__
Oct 9, 2020
8367133
Add y checks to StartifiedGroupKFold
Oct 9, 2020
bca2dbc
Raise error if n_splits > max num samples in class
Oct 9, 2020
31fc183
Warn if n_splits > mn num samples in class
Oct 9, 2020
0d9a58f
Add SGKfold to general repr test
Oct 9, 2020
3c2c639
Add SGKFold to 2d_y test case
Oct 9, 2020
8648519
Add SGKfold to value erros test case
Oct 9, 2020
7005af2
Add SGKFold to StratifiedKFold test cases
Oct 9, 2020
6a52ae9
Add SGKFold to reproducibility test case
Oct 9, 2020
6f83a85
Add SGKFold to GroupKFold test case
Oct 9, 2020
7fbc736
Add SGKFold to nested cv test case
Oct 9, 2020
d4f99e6
Add SGKFold to random_state with shuffle=False test case
Oct 9, 2020
a38a872
Add SGKFold to constant splits test case
Oct 9, 2020
490b503
Fix repr test case
Oct 9, 2020
cc8da98
Fix formatting issues
Oct 9, 2020
6990a91
Add samples to a fold with least num samples
Oct 9, 2020
1f4da2b
Remove GroupShuffleSplit impl
Oct 10, 2020
9359cc7
Add notes to StratifiedGroupKFold
Oct 10, 2020
6386faa
Fix doctest
Oct 10, 2020
536c4c9
Added stratified group kfold tests
Oct 19, 2020
9681a61
Better variable naming
Oct 19, 2020
b7e4fc8
Add section to documentation
Oct 19, 2020
e3112b4
Merge branch 'master' into stratified-group-kfold
Oct 19, 2020
2580a81
Remove leftover StratifiedGroupShuffleSplit import
Oct 19, 2020
81be001
Merge remote-tracking branch 'upstream/master' into stratified-group-…
marrodion Oct 30, 2020
113c06a
Add changelist and reference to original kernel
marrodion Oct 30, 2020
72ebb9f
Better naming for least populated class check
marrodion Oct 30, 2020
7093e70
Better expression for number of labels
marrodion Oct 30, 2020
2b5e71c
Remove use of Counter
marrodion Oct 30, 2020
d36473e
Add tests for homogeneous groups
marrodion Oct 30, 2020
f65d873
Add StratifiedGroupKFold test against GroupKFold
marrodion Oct 30, 2020
627fc9f
Add changes to changelist in docstring
marrodion Oct 30, 2020
25fcb42
Add StratifiedGroupKFold to classes.rst
marrodion Oct 30, 2020
57c53a5
Fix description of StratifiedGroupKFold
marrodion Oct 30, 2020
484bf9d
Merge branch 'main' into stratified-group-kfold
marrodion Jan 24, 2021
234f290
Move license notice out of docstring
marrodion Jan 24, 2021
0eb6080
Disambiguate labels to classes in doc
marrodion Jan 24, 2021
2e0ee20
Merge remote-tracking branch 'upstream/main' into stratified-group-kfold
marrodion Feb 17, 2021
f20718e
Merge remote-tracking branch 'upstream/main' into stratified-group-kfold
marrodion Mar 5, 2021
cf912af
Add changelog entry
marrodion Mar 5, 2021
42e00ed
Fix changelog author entry
marrodion Mar 10, 2021
a1d0f9f
Fix StratifiedGroupKFold docstring
marrodion Mar 10, 2021
8e3c852
Better variable names
marrodion Mar 10, 2021
096b23b
Remove defaultdict in favor of numpy indexing
marrodion Mar 10, 2021
0464839
Extracted best_fold search into a separate method
marrodion Mar 11, 2021
c0f907a
Make use of numpy broadcasting instead of for loop
marrodion Mar 10, 2021
41036ad
Encode groups and use arrays instead of dicts
marrodion Mar 11, 2021
93fcdd4
Use numpy sort instead of python
marrodion Mar 11, 2021
ae67ad3
Clarify shuffling behavior of StratifiedGroupKF in docs
marrodion Mar 11, 2021
3940e23
Switch name from label_idx to class_idx
marrodion Mar 11, 2021
9cff771
Merge remote-tracking branch 'upstream/main' into stratified-group-kfold
marrodion Mar 11, 2021
cccbff7
Remove accidentally leftover comment
marrodion Mar 11, 2021
5cd8fdb
Fix np.sort keyword to support numpy < 1.15
marrodion Mar 11, 2021
4fbc7a8
Merge remote-tracking branch 'upstream/main' into stratified-group-kfold
marrodion Mar 17, 2021
024d3c6
Fix typo in docstring
marrodion Mar 17, 2021
09cfc37
Add StratifiedGroupKFold to visualization doc
marrodion Mar 17, 2021
6c8d5da
Merge remote-tracking branch 'upstream/main' into stratified-group-kfold
marrodion Mar 17, 2021
60cb778
Add visualization for uneven group as an example
marrodion Mar 17, 2021
29a0fe5
Fix image numbers to match updated example
marrodion Mar 18, 2021
859d28a
Add author
marrodion Mar 19, 2021
6fd5ec8
Add SGKF visualization to docs
marrodion Mar 19, 2021
42a6b80
Add comments for groups in stratified CV tests
marrodion Mar 19, 2021
1 change: 1 addition & 0 deletions doc/modules/classes.rst
@@ -1176,6 +1176,7 @@ Splitter Classes
   model_selection.ShuffleSplit
   model_selection.StratifiedKFold
   model_selection.StratifiedShuffleSplit
   model_selection.StratifiedGroupKFold
   model_selection.TimeSeriesSplit

Splitter Functions
64 changes: 58 additions & 6 deletions doc/modules/cross_validation.rst
@@ -353,7 +353,7 @@ Example of 2-fold cross-validation on a dataset with 4 samples::
Here is a visualization of the cross-validation behavior. Note that
:class:`KFold` is not affected by classes or groups.

.. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_cv_indices_004.png
.. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_cv_indices_006.png
   :target: ../auto_examples/model_selection/plot_cv_indices.html
   :align: center
   :scale: 75%
@@ -509,7 +509,7 @@ Here is a usage example::
Here is a visualization of the cross-validation behavior. Note that
:class:`ShuffleSplit` is not affected by classes or groups.

.. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_cv_indices_006.png
.. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_cv_indices_008.png
   :target: ../auto_examples/model_selection/plot_cv_indices.html
   :align: center
   :scale: 75%
@@ -566,7 +566,7 @@ We can see that :class:`StratifiedKFold` preserves the class ratios

Here is a visualization of the cross-validation behavior.

.. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_cv_indices_007.png
.. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_cv_indices_009.png
   :target: ../auto_examples/model_selection/plot_cv_indices.html
   :align: center
   :scale: 75%
@@ -585,7 +585,7 @@ percentage for each target class as in the complete set.

Here is a visualization of the cross-validation behavior.

.. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_cv_indices_009.png
.. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_cv_indices_012.png
   :target: ../auto_examples/model_selection/plot_cv_indices.html
   :align: center
   :scale: 75%
@@ -645,6 +645,58 @@ size due to the imbalance in the data.

Here is a visualization of the cross-validation behavior.

.. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_cv_indices_007.png
   :target: ../auto_examples/model_selection/plot_cv_indices.html
   :align: center
   :scale: 75%

.. _stratified_group_k_fold:

StratifiedGroupKFold
^^^^^^^^^^^^^^^^^^^^

:class:`StratifiedGroupKFold` is a cross-validation scheme that combines both
:class:`StratifiedKFold` and :class:`GroupKFold`. The idea is to try to
preserve the distribution of classes in each split while keeping each group
within a single split. That might be useful when you have an unbalanced
dataset so that using just :class:`GroupKFold` might produce skewed splits.

Example::

  >>> from sklearn.model_selection import StratifiedGroupKFold
  >>> X = list(range(18))
  >>> y = [1] * 6 + [0] * 12
  >>> groups = [1, 2, 3, 3, 4, 4, 1, 1, 2, 2, 3, 4, 5, 5, 5, 6, 6, 6]
  >>> sgkf = StratifiedGroupKFold(n_splits=3)
  >>> for train, test in sgkf.split(X, y, groups=groups):
  ...     print("%s %s" % (train, test))
  [ 0  2  3  4  5  6  7 10 11 15 16 17] [ 1  8  9 12 13 14]
  [ 0  1  4  5  6  7  8  9 11 12 13 14] [ 2  3 10 15 16 17]
  [ 1  2  3  8  9 10 12 13 14 15 16 17] [ 0  4  5  6  7 11]
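
For illustration only (not part of this PR's diff): the three test folds printed in the example above can be checked by hand. The sketch below copies `y`, `groups` and the test indices from that output and uses a hypothetical helper, `summarize`, to list which groups and how many samples of each class land in every test fold. It is plain Python and does not call scikit-learn.

```python
from collections import Counter

# Data and test folds copied from the doctest output in the example.
y = [1] * 6 + [0] * 12
groups = [1, 2, 3, 3, 4, 4, 1, 1, 2, 2, 3, 4, 5, 5, 5, 6, 6, 6]
test_folds = [
    [1, 8, 9, 12, 13, 14],
    [2, 3, 10, 15, 16, 17],
    [0, 4, 5, 6, 7, 11],
]


def summarize(test_folds, y, groups):
    """Return (set of groups, class counts) for each test fold.

    Hypothetical helper for this illustration, not a scikit-learn function.
    """
    return [
        ({groups[i] for i in fold}, Counter(y[i] for i in fold))
        for fold in test_folds
    ]


summary = summarize(test_folds, y, groups)
for fold_groups, class_counts in summary:
    print(sorted(fold_groups), dict(class_counts))
```

The folds hold groups {2, 5}, {3, 6} and {1, 4}, so no group is split across folds. The positive class (6 of 18 samples overall) contributes 1, 2 and 3 samples to the three test folds, which shows that stratification is approximate once the group constraint is enforced.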

Implementation notes:

- With the current implementation, a full shuffle is not possible in most
  scenarios. When shuffle=True, the following happens:

  1. All groups are shuffled.
  2. Groups are sorted by standard deviation of classes using stable sort.
  3. Sorted groups are iterated over and assigned to folds.

  That means that only groups with the same standard deviation of class
  distribution will be shuffled, which might be useful when each group has only
  a single class.
- The algorithm greedily assigns each group to one of n_splits test sets,
  choosing the test set that minimises the variance in class distribution
  across test sets. Group assignment proceeds from groups with highest to
  lowest variance in class frequency, i.e. large groups peaked on one or few
  classes are assigned first.
- This split is suboptimal in the sense that it might produce imbalanced splits
  even if perfect stratification is possible. If you have relatively close
  distribution of classes in each group, using :class:`GroupKFold` is better.
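
The greedy assignment described in the notes above can be sketched in plain Python. This is an illustration under stated assumptions, not the code added by this PR: the function name `greedy_group_folds` is hypothetical, shuffling is omitted, the score is taken to be the mean over classes of the standard deviation across folds of each class's share, and ties are broken in favour of the fold that currently holds the fewest samples (an assumption based on the commit message "Add samples to a fold with least num samples").

```python
from collections import Counter
from statistics import pstdev


def greedy_group_folds(y, groups, n_splits):
    """Assign each group to one test fold, greedily balancing class counts.

    Hypothetical sketch; returns (group -> fold mapping, per-fold class counts).
    """
    classes = sorted(set(y))
    class_totals = Counter(y)

    # Per-group class counts, e.g. {group: [n_class0, n_class1, ...]}.
    group_counts = {}
    for label, group in zip(y, groups):
        counts = group_counts.setdefault(group, [0] * len(classes))
        counts[classes.index(label)] += 1

    # Groups with the most uneven class counts go first. sorted() is stable,
    # so groups with equal spread keep their first-seen order.
    order = sorted(group_counts, key=lambda g: -pstdev(group_counts[g]))

    fold_counts = [[0] * len(classes) for _ in range(n_splits)]
    assignment = {}
    for group in order:
        best_fold, best_key = None, None
        for fold in range(n_splits):
            # Tentatively add the group to this fold.
            trial = [row[:] for row in fold_counts]
            trial[fold] = [a + b for a, b in zip(trial[fold], group_counts[group])]
            # Spread across folds of each class's share, averaged over classes.
            spread = sum(
                pstdev([row[k] / class_totals[c] for row in trial])
                for k, c in enumerate(classes)
            ) / len(classes)
            # Assumed tie-break: prefer the fold with the fewest samples so far.
            key = (round(spread, 12), sum(fold_counts[fold]))
            if best_key is None or key < best_key:
                best_fold, best_key = fold, key
        assignment[group] = best_fold
        fold_counts[best_fold] = [
            a + b for a, b in zip(fold_counts[best_fold], group_counts[group])
        ]
    return assignment, fold_counts


# Usage with the data from the documentation example.
y = [1] * 6 + [0] * 12
groups = [1, 2, 3, 3, 4, 4, 1, 1, 2, 2, 3, 4, 5, 5, 5, 6, 6, 6]
assignment, fold_counts = greedy_group_folds(y, groups, n_splits=3)
print(assignment)
print(fold_counts)
```

Because every group is placed exactly once and never revisited, the procedure is fast but can miss a perfectly stratified split, which is the suboptimality the last note refers to.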

Member

Insert the visualisation here?

Contributor Author

Done, thank you.
Wasn't sure if needed, not every CV has a visualization in this documentation page.

Member

Perhaps not needed, but helpful. Happy for you to make the docs more consistent in another pr! ;)

Here is a visualization of cross-validation behavior for uneven groups:

.. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_cv_indices_005.png
   :target: ../auto_examples/model_selection/plot_cv_indices.html
   :align: center
@@ -733,7 +785,7 @@ Here is a usage example::

Here is a visualization of the cross-validation behavior.

.. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_cv_indices_008.png
.. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_cv_indices_011.png
   :target: ../auto_examples/model_selection/plot_cv_indices.html
   :align: center
   :scale: 75%
@@ -835,7 +887,7 @@ Example of 3-split time series cross-validation on a dataset with 6 samples::

Here is a visualization of the cross-validation behavior.

.. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_cv_indices_010.png
.. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_cv_indices_013.png
   :target: ../auto_examples/model_selection/plot_cv_indices.html
   :align: center
   :scale: 75%
10 changes: 10 additions & 0 deletions doc/whats_new/v1.0.rst
@@ -183,6 +183,16 @@ Changelog
  are integral.
  :pr:`9843` by :user:`Jon Crall <Erotemic>`.

:mod:`sklearn.model_selection`
..............................

- |Feature| added :class:`model_selection.StratifiedGroupKFold`, that combines
  :class:`model_selection.StratifiedKFold` and :class:`model_selection.GroupKFold`,
  providing an ability to split data preserving the distribution of classes in
  each split while keeping each group within a single split.
  :pr:`18649` by :user:`Leandro Hermida <hermidalc>` and
  :user:`Rodion Martynov <marrodion>`.

:mod:`sklearn.naive_bayes`
..........................

35 changes: 26 additions & 9 deletions examples/model_selection/plot_cv_indices.py
@@ -13,7 +13,8 @@

from sklearn.model_selection import (TimeSeriesSplit, KFold, ShuffleSplit,
                                     StratifiedKFold, GroupShuffleSplit,
                                     GroupKFold, StratifiedShuffleSplit)
                                     GroupKFold, StratifiedShuffleSplit,
                                     StratifiedGroupKFold)
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
@@ -113,16 +114,32 @@ def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10):
# %%
# As you can see, by default the KFold cross-validation iterator does not
# take either datapoint class or group into consideration. We can change this
# by using the ``StratifiedKFold`` like so.
# by using either:
#
# - ``StratifiedKFold`` to preserve the percentage of samples for each class.
# - ``GroupKFold`` to ensure that the same group will not appear in two
# different folds.
# - ``StratifiedGroupKFold`` to keep the constraint of ``GroupKFold`` while
# attempting to return stratified folds.

fig, ax = plt.subplots()
cv = StratifiedKFold(n_splits)
plot_cv_indices(cv, X, y, groups, ax, n_splits)
# To better demonstrate the difference, we will assign samples to groups
# unevenly:

uneven_groups = np.sort(np.random.randint(0, 10, n_points))

cvs = [StratifiedKFold, GroupKFold, StratifiedGroupKFold]

for cv in cvs:
    fig, ax = plt.subplots(figsize=(6, 3))
    plot_cv_indices(cv(n_splits), X, y, uneven_groups, ax, n_splits)
    ax.legend([Patch(color=cmap_cv(.8)), Patch(color=cmap_cv(.02))],
              ['Testing set', 'Training set'], loc=(1.02, .8))
    # Make the legend fit
    plt.tight_layout()
    fig.subplots_adjust(right=.7)

# %%
# In this case, the cross-validation retained the same ratio of classes across
# each CV split. Next we'll visualize this behavior for a number of CV
# iterators.
# Next we'll visualize this behavior for a number of CV iterators.
#
# Visualize cross-validation indices for many CV objects
# ------------------------------------------------------
@@ -133,7 +150,7 @@ def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10):
#
# Note how some use the group/class information while others do not.

cvs = [KFold, GroupKFold, ShuffleSplit, StratifiedKFold,
cvs = [KFold, GroupKFold, ShuffleSplit, StratifiedKFold, StratifiedGroupKFold,
       GroupShuffleSplit, StratifiedShuffleSplit, TimeSeriesSplit]


2 changes: 2 additions & 0 deletions sklearn/model_selection/__init__.py
@@ -14,6 +14,7 @@
from ._split import ShuffleSplit
from ._split import GroupShuffleSplit
from ._split import StratifiedShuffleSplit
from ._split import StratifiedGroupKFold
from ._split import PredefinedSplit
from ._split import train_test_split
from ._split import check_cv
@@ -57,6 +58,7 @@
'RandomizedSearchCV',
'ShuffleSplit',
'StratifiedKFold',
'StratifiedGroupKFold',
'StratifiedShuffleSplit',
'check_cv',
'cross_val_predict',