Skip to content

ENH/DEP add class method and deprecate plot function for confusion matrix #18543

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 37 commits into from
Jan 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
bff9c61
ENH/DEP add class method to display confusion matrix and deprecate pl…
glemaitre Oct 6, 2020
131d175
MNT add a tool to inject docstring
glemaitre Oct 6, 2020
7a6f326
add docstring in private methods
glemaitre Oct 6, 2020
0d27d42
add deprecation filter in test
glemaitre Oct 6, 2020
57eafe9
remove warning documentation
glemaitre Oct 6, 2020
aa1d7c3
DOC add an entry in whats new
glemaitre Oct 6, 2020
ecb11f0
update examples
glemaitre Oct 6, 2020
99edffa
TST add new files
glemaitre Oct 6, 2020
46270b2
TST add test display custom labels
glemaitre Oct 6, 2020
987fd8e
TST add tests for plotting ConfusionMatrixDisplay
glemaitre Oct 6, 2020
59ecd00
TST adding back right error message for plot_confusion_matrix
glemaitre Oct 6, 2020
9a1641d
TST add test with different complex pipeline
glemaitre Oct 6, 2020
e2b4431
DOC fix docstring
glemaitre Oct 6, 2020
1b5ce74
TST add test for docstring substitution
glemaitre Oct 6, 2020
c1f0790
TST add really the file
glemaitre Oct 6, 2020
0633035
first round addressing review
glemaitre Oct 7, 2020
9fa42d3
improve injecter
glemaitre Oct 7, 2020
0ee9b99
Fix example docstring
glemaitre Oct 7, 2020
ee1ba71
remove fixture for more explicit tests
glemaitre Oct 7, 2020
1753197
Add see also
glemaitre Oct 7, 2020
399a04c
FIX avoid double validation
glemaitre Oct 7, 2020
5935559
Merge branch 'master' into is/15880_confusion_matrix
glemaitre Oct 12, 2020
1bf9241
iter
glemaitre Oct 12, 2020
b64f2d1
remove the need of double braces
glemaitre Oct 12, 2020
648e510
Revert "remove the need of double braces"
glemaitre Oct 17, 2020
d9d6274
remove string injection
glemaitre Oct 17, 2020
ca85e7c
Merge remote-tracking branch 'origin/master' into is/15880_confusion_…
glemaitre Oct 21, 2020
a5622fc
fix
glemaitre Oct 25, 2020
78fc223
Apply suggestions from code review
glemaitre Nov 3, 2020
3075bbb
MNT use explicit constructor
glemaitre Nov 3, 2020
9e08343
Merge remote-tracking branch 'origin/master' into is/15880_confusion_…
glemaitre Jan 5, 2021
9ccd1cc
MNT update deprecation version
glemaitre Jan 5, 2021
9870fc5
Merge remote-tracking branch 'origin/master' into is/15880_confusion_…
glemaitre Jan 5, 2021
c7e52d6
Apply suggestions from code review
glemaitre Jan 20, 2021
f1027bc
address adrin comments
glemaitre Jan 21, 2021
2c06de3
Merge remote-tracking branch 'origin/main' into is/15880_confusion_ma…
glemaitre Jan 22, 2021
29e8f1d
reorder whats new
glemaitre Jan 22, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/modules/model_evaluation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ predicted to be in group :math:`j`. Here is an example::
[0, 0, 1],
[1, 0, 2]])

:func:`plot_confusion_matrix` can be used to visually represent a confusion
:class:`ConfusionMatrixDisplay` can be used to visually represent a confusion
matrix as shown in the
:ref:`sphx_glr_auto_examples_model_selection_plot_confusion_matrix.py`
example, which creates the following figure:
Expand Down
11 changes: 11 additions & 0 deletions doc/whats_new/v1.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,17 @@ Changelog
:pr:`17743` by :user:`Maria Telenczuk <maikia>` and
:user:`Alexandre Gramfort <agramfort>`.

:mod:`sklearn.metrics`
......................

- |API| :class:`metrics.ConfusionMatrixDisplay` exposes two class methods
:func:`~metrics.ConfusionMatrixDisplay.from_estimator` and
:func:`~metrics.ConfusionMatrixDisplay.from_predictions` allowing to create
a confusion matrix plot using an estimator or the predictions.
:func:`metrics.plot_confusion_matrix` is deprecated in favor of these two
class methods and will be removed in 1.2.
:pr:`18543` by `Guillaume Lemaitre`_.

:mod:`sklearn.naive_bayes`
..........................

Expand Down
2 changes: 1 addition & 1 deletion examples/classification/plot_digits_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@
# We can also plot a :ref:`confusion matrix <confusion_matrix>` of the
# true digit values and the predicted digit values.

disp = metrics.plot_confusion_matrix(clf, X_test, y_test)
disp = metrics.ConfusionMatrixDisplay.from_predictions(y_test, predicted)
disp.figure_.suptitle("Confusion Matrix")
print(f"Confusion matrix:\n{disp.confusion_matrix}")

Expand Down
10 changes: 5 additions & 5 deletions examples/model_selection/plot_confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

from sklearn import svm, datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import plot_confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay

# import some data to play with
iris = datasets.load_iris()
Expand All @@ -52,10 +52,10 @@
titles_options = [("Confusion matrix, without normalization", None),
("Normalized confusion matrix", 'true')]
for title, normalize in titles_options:
disp = plot_confusion_matrix(classifier, X_test, y_test,
display_labels=class_names,
cmap=plt.cm.Blues,
normalize=normalize)
disp = ConfusionMatrixDisplay.from_estimator(
classifier, X_test, y_test, display_labels=class_names,
cmap=plt.cm.Blues, normalize=normalize
)
disp.ax_.set_title(title)

print(title)
Expand Down
290 changes: 286 additions & 4 deletions sklearn/metrics/_plot/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from .. import confusion_matrix
from ...utils import check_matplotlib_support
from ...utils import deprecated
from ...utils.multiclass import unique_labels
from ...utils.validation import _deprecate_positional_args
from ...base import is_classifier
Expand All @@ -12,7 +13,9 @@
class ConfusionMatrixDisplay:
"""Confusion Matrix visualization.

It is recommend to use :func:`~sklearn.metrics.plot_confusion_matrix` to
It is recommend to use
:func:`~sklearn.metrics.ConfusionMatrixDisplay.from_estimator` or
:func:`~sklearn.metrics.ConfusionMatrixDisplay.from_predictions` to
create a :class:`ConfusionMatrixDisplay`. All parameters are stored as
attributes.

Expand Down Expand Up @@ -161,7 +164,274 @@ def plot(self, *, include_values=True, cmap='viridis',
self.ax_ = ax
return self

@classmethod
def from_estimator(
cls,
estimator,
X,
y,
*,
labels=None,
sample_weight=None,
normalize=None,
display_labels=None,
include_values=True,
xticks_rotation="horizontal",
values_format=None,
cmap="viridis",
ax=None,
colorbar=True,
):
"""Plot Confusion Matrix given an estimator and some data.

Read more in the :ref:`User Guide <confusion_matrix>`.

.. versionadded:: 1.0

Parameters
----------
estimator : estimator instance
Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
in which the last estimator is a classifier.

X : {array-like, sparse matrix} of shape (n_samples, n_features)
Input values.

y : array-like of shape (n_samples,)
Target values.

labels : array-like of shape (n_classes,), default=None
List of labels to index the confusion matrix. This may be used to
reorder or select a subset of labels. If `None` is given, those
that appear at least once in `y_true` or `y_pred` are used in
sorted order.

sample_weight : array-like of shape (n_samples,), default=None
Sample weights.

normalize : {'true', 'pred', 'all'}, default=None
Either to normalize the counts display in the matrix:

- if `'true'`, the confusion matrix is normalized over the true
conditions (e.g. rows);
- if `'pred'`, the confusion matrix is normalized over the
predicted conditions (e.g. columns);
- if `'all'`, the confusion matrix is normalized by the total
number of samples;
- if `None` (default), the confusion matrix will not be normalized.

display_labels : array-like of shape (n_classes,), default=None
Target names used for plotting. By default, `labels` will be used
if it is defined, otherwise the unique labels of `y_true` and
`y_pred` will be used.

include_values : bool, default=True
Includes values in confusion matrix.

xticks_rotation : {'vertical', 'horizontal'} or float, \
default='horizontal'
Rotation of xtick labels.

values_format : str, default=None
Format specification for values in confusion matrix. If `None`, the
format specification is 'd' or '.2g' whichever is shorter.

cmap : str or matplotlib Colormap, default='viridis'
Colormap recognized by matplotlib.

ax : matplotlib Axes, default=None
Axes object to plot on. If `None`, a new figure and axes is
created.

colorbar : bool, default=True
Whether or not to add a colorbar to the plot.

Returns
-------
display : :class:`~sklearn.metrics.ConfusionMatrixDisplay`

See Also
--------
ConfusionMatrixDisplay.from_predictions : Plot the confusion matrix
given the true and predicted labels.

Examples
--------
>>> import matplotlib.pyplot as plt # doctest: +SKIP
>>> from sklearn.datasets import make_classification
>>> from sklearn.metrics import ConfusionMatrixDisplay
>>> from sklearn.model_selection import train_test_split
>>> from sklearn.svm import SVC
>>> X, y = make_classification(random_state=0)
>>> X_train, X_test, y_train, y_test = train_test_split(
... X, y, random_state=0)
>>> clf = SVC(random_state=0)
>>> clf.fit(X_train, y_train)
SVC(random_state=0)
>>> ConfusionMatrixDisplay.from_estimator(
... clf, X_test, y_test) # doctest: +SKIP
>>> plt.show() # doctest: +SKIP
"""
method_name = f"{cls.__name__}.from_estimator"
check_matplotlib_support(method_name)
if not is_classifier(estimator):
raise ValueError(f"{method_name} only supports classifiers")
y_pred = estimator.predict(X)

return cls.from_predictions(
y,
y_pred,
sample_weight=sample_weight,
labels=labels,
normalize=normalize,
display_labels=display_labels,
include_values=include_values,
cmap=cmap,
ax=ax,
xticks_rotation=xticks_rotation,
values_format=values_format,
colorbar=colorbar,
)

@classmethod
def from_predictions(
cls,
y_true,
y_pred,
*,
labels=None,
sample_weight=None,
normalize=None,
display_labels=None,
include_values=True,
xticks_rotation="horizontal",
values_format=None,
cmap="viridis",
ax=None,
colorbar=True,
):
"""Plot Confusion Matrix given true and predicted labels.

Read more in the :ref:`User Guide <confusion_matrix>`.

.. versionadded:: 0.24

Parameters
----------
y_true : array-like of shape (n_samples,)
True labels.

y_pred : array-like of shape (n_samples,)
The predicted labels given by the method `predict` of an
classifier.

labels : array-like of shape (n_classes,), default=None
List of labels to index the confusion matrix. This may be used to
reorder or select a subset of labels. If `None` is given, those
that appear at least once in `y_true` or `y_pred` are used in
sorted order.

sample_weight : array-like of shape (n_samples,), default=None
Sample weights.

normalize : {'true', 'pred', 'all'}, default=None
Either to normalize the counts display in the matrix:

- if `'true'`, the confusion matrix is normalized over the true
conditions (e.g. rows);
- if `'pred'`, the confusion matrix is normalized over the
predicted conditions (e.g. columns);
- if `'all'`, the confusion matrix is normalized by the total
number of samples;
- if `None` (default), the confusion matrix will not be normalized.

display_labels : array-like of shape (n_classes,), default=None
Target names used for plotting. By default, `labels` will be used
if it is defined, otherwise the unique labels of `y_true` and
`y_pred` will be used.

include_values : bool, default=True
Includes values in confusion matrix.

xticks_rotation : {'vertical', 'horizontal'} or float, \
default='horizontal'
Rotation of xtick labels.

values_format : str, default=None
Format specification for values in confusion matrix. If `None`, the
format specification is 'd' or '.2g' whichever is shorter.

cmap : str or matplotlib Colormap, default='viridis'
Colormap recognized by matplotlib.

ax : matplotlib Axes, default=None
Axes object to plot on. If `None`, a new figure and axes is
created.

colorbar : bool, default=True
Whether or not to add a colorbar to the plot.

Returns
-------
display : :class:`~sklearn.metrics.ConfusionMatrixDisplay`

See Also
--------
ConfusionMatrixDisplay.from_estimator : Plot the confusion matrix
given an estimator, the data, and the label.

Examples
--------
>>> import matplotlib.pyplot as plt # doctest: +SKIP
>>> from sklearn.datasets import make_classification
>>> from sklearn.metrics import ConfusionMatrixDisplay
>>> from sklearn.model_selection import train_test_split
>>> from sklearn.svm import SVC
>>> X, y = make_classification(random_state=0)
>>> X_train, X_test, y_train, y_test = train_test_split(
... X, y, random_state=0)
>>> clf = SVC(random_state=0)
>>> clf.fit(X_train, y_train)
SVC(random_state=0)
>>> y_pred = clf.predict(X_test)
>>> ConfusionMatrixDisplay.from_predictions(
... y_test, y_pred) # doctest: +SKIP
>>> plt.show() # doctest: +SKIP
"""
check_matplotlib_support(f"{cls.__name__}.from_predictions")

if display_labels is None:
if labels is None:
display_labels = unique_labels(y_true, y_pred)
else:
display_labels = labels

cm = confusion_matrix(
y_true,
y_pred,
sample_weight=sample_weight,
labels=labels,
normalize=normalize,
)

disp = cls(confusion_matrix=cm, display_labels=display_labels)

return disp.plot(
include_values=include_values,
cmap=cmap,
ax=ax,
xticks_rotation=xticks_rotation,
values_format=values_format,
colorbar=colorbar,
)


@deprecated(
"Function plot_confusion_matrix is deprecated in 1.0 and will be "
"removed in 1.2. Use one of the class methods: "
"ConfusionMatrixDisplay.from_predictions or "
"ConfusionMatrixDisplay.from_estimator."
)
@_deprecate_positional_args
def plot_confusion_matrix(estimator, X, y_true, *, labels=None,
sample_weight=None, normalize=None,
Expand All @@ -173,6 +443,12 @@ def plot_confusion_matrix(estimator, X, y_true, *, labels=None,

Read more in the :ref:`User Guide <confusion_matrix>`.

.. deprecated:: 1.0
`plot_confusion_matrix` is deprecated in 1.0 and will be removed in
1.2. Use one of the following class methods:
:func:`~sklearn.metrics.ConfusionMatrixDisplay.from_predictions` or
:func:`~sklearn.metrics.ConfusionMatrixDisplay.from_estimator`.

Parameters
----------
estimator : estimator instance
Expand All @@ -194,9 +470,15 @@ def plot_confusion_matrix(estimator, X, y_true, *, labels=None,
Sample weights.

normalize : {'true', 'pred', 'all'}, default=None
Normalizes confusion matrix over the true (rows), predicted (columns)
conditions or all the population. If None, confusion matrix will not be
normalized.
Either to normalize the counts display in the matrix:

- if `'true'`, the confusion matrix is normalized over the true
conditions (e.g. rows);
- if `'pred'`, the confusion matrix is normalized over the
predicted conditions (e.g. columns);
- if `'all'`, the confusion matrix is normalized by the total
number of samples;
- if `None` (default), the confusion matrix will not be normalized.

display_labels : array-like of shape (n_classes,), default=None
Target names used for plotting. By default, `labels` will be used if
Expand Down
Loading