Skip to content

Commit ab21254

Browse files
joclement authored and glemaitre committed
FIX mislabelling multiclass target when labels is provided in top_k_accuracy_score (#19721)
1 parent 0143fe4 commit ab21254

File tree

3 files changed

+35
-1
lines changed

3 files changed

+35
-1
lines changed

doc/whats_new/v0.24.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,14 @@ Changelog
5959
- |Fix|: Fixed a bug in :class:`linear_model.LogisticRegression`: the
6060
sample_weight object is not modified anymore. :pr:`19182` by
6161
:user:`Yosuke KOBAYASHI <m7142yosuke>`.
62+
63+
:mod:`sklearn.metrics`
64+
......................
65+
66+
- |Fix| :func:`metrics.top_k_accuracy_score` now supports multiclass
67+
problems where only two classes appear in `y_true` and all the classes
68+
are specified in `labels`.
69+
:pr:`19721` by :user:`Joris Clement <flyingdutchman23>`.
6270

6371
:mod:`sklearn.model_selection`
6472
..............................

sklearn/metrics/_ranking.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1589,7 +1589,7 @@ def top_k_accuracy_score(y_true, y_score, *, k=2, normalize=True,
15891589
non-thresholded decision values (as returned by
15901590
:term:`decision_function` on some classifiers). The binary case expects
15911591
scores with shape (n_samples,) while the multiclass case expects scores
1592-
with shape (n_samples, n_classes). In the nulticlass case, the order of
1592+
with shape (n_samples, n_classes). In the multiclass case, the order of
15931593
the class scores must correspond to the order of ``labels``, if
15941594
provided, or else to the numerical or lexicographical order of the
15951595
labels in ``y_true``.
@@ -1646,6 +1646,8 @@ def top_k_accuracy_score(y_true, y_score, *, k=2, normalize=True,
16461646
y_true = check_array(y_true, ensure_2d=False, dtype=None)
16471647
y_true = column_or_1d(y_true)
16481648
y_type = type_of_target(y_true)
1649+
if y_type == "binary" and labels is not None and len(labels) > 2:
1650+
y_type = "multiclass"
16491651
y_score = check_array(y_score, ensure_2d=False)
16501652
y_score = column_or_1d(y_score) if y_type == 'binary' else y_score
16511653
check_consistent_length(y_true, y_score, sample_weight)

sklearn/metrics/tests/test_ranking.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1650,6 +1650,30 @@ def test_top_k_accuracy_score_binary(y_score, k, true_score):
16501650
assert score == score_acc == pytest.approx(true_score)
16511651

16521652

1653+
@pytest.mark.parametrize('y_true, true_score, labels', [
1654+
(np.array([0, 1, 1, 2]), 0.75, [0, 1, 2, 3]),
1655+
(np.array([0, 1, 1, 1]), 0.5, [0, 1, 2, 3]),
1656+
(np.array([1, 1, 1, 1]), 0.5, [0, 1, 2, 3]),
1657+
(np.array(['a', 'e', 'e', 'a']), 0.75, ['a', 'b', 'd', 'e']),
1658+
])
1659+
@pytest.mark.parametrize("labels_as_ndarray", [True, False])
1660+
def test_top_k_accuracy_score_multiclass_with_labels(
1661+
y_true, true_score, labels, labels_as_ndarray
1662+
):
1663+
"""Test when labels and y_score are multiclass."""
1664+
if labels_as_ndarray:
1665+
labels = np.asarray(labels)
1666+
y_score = np.array([
1667+
[0.4, 0.3, 0.2, 0.1],
1668+
[0.1, 0.3, 0.4, 0.2],
1669+
[0.4, 0.1, 0.2, 0.3],
1670+
[0.3, 0.2, 0.4, 0.1],
1671+
])
1672+
1673+
score = top_k_accuracy_score(y_true, y_score, k=2, labels=labels)
1674+
assert score == pytest.approx(true_score)
1675+
1676+
16531677
def test_top_k_accuracy_score_increasing():
16541678
# Make sure increasing k leads to a higher score
16551679
X, y = datasets.make_classification(n_classes=10, n_samples=1000,

0 commit comments

Comments
 (0)