Skip to content

Commit 6305e8d

Browse files
authored
Merge pull request #16178 from yozhikoff/add-multiple-label-support
ENH: Add multiple label support for Axes.plot()
2 parents 99e6240 + a161ae3 commit 6305e8d

File tree

4 files changed

+105
-2
lines changed

4 files changed

+105
-2
lines changed
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
An iterable object with labels can be passed to `.Axes.plot`
2+
------------------------------------------------------------
3+
4+
When plotting multiple datasets by passing 2D data as *y* value to
5+
`~.Axes.plot`, labels for the datasets can be passed as a list, the
6+
length matching the number of columns in *y*.
7+
8+
.. plot::
9+
10+
import matplotlib.pyplot as plt
11+
12+
x = [1, 2, 3]
13+
14+
y = [[1, 2],
15+
[2, 5],
16+
[4, 9]]
17+
18+
plt.plot(x, y, label=['low', 'high'])
19+
plt.legend()

lib/matplotlib/axes/_axes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1499,6 +1499,8 @@ def plot(self, *args, scalex=True, scaley=True, data=None, **kwargs):
14991499
15001500
If you make multiple lines with one plot call, the kwargs
15011501
apply to all those lines.
1502+
In case if label object is iterable, each its element is
1503+
used as label for a separate line.
15021504
15031505
Here is a list of available `.Line2D` properties:
15041506

lib/matplotlib/axes/_base.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,12 @@ def __call__(self, *args, data=None, **kwargs):
294294
replaced[label_namer_idx], args[label_namer_idx])
295295
args = replaced
296296

297+
if len(args) >= 4 and not cbook.is_scalar_or_string(
298+
kwargs.get("label")):
299+
raise ValueError("plot() with multiple groups of data (i.e., "
300+
"pairs of x and y) does not support multiple "
301+
"labels")
302+
297303
# Repeatedly grab (x, y) or (x, y, format) from the front of args and
298304
# massage them into arguments to plot() or fill().
299305

@@ -447,8 +453,22 @@ def _plot_args(self, tup, kwargs, return_kwargs=False):
447453
ncx, ncy = x.shape[1], y.shape[1]
448454
if ncx > 1 and ncy > 1 and ncx != ncy:
449455
raise ValueError(f"x has {ncx} columns but y has {ncy} columns")
450-
result = (func(x[:, j % ncx], y[:, j % ncy], kw, kwargs)
451-
for j in range(max(ncx, ncy)))
456+
457+
label = kwargs.get('label')
458+
n_datasets = max(ncx, ncy)
459+
if n_datasets > 1 and not cbook.is_scalar_or_string(label):
460+
if len(label) != n_datasets:
461+
raise ValueError(f"label must be scalar or have the same "
462+
f"length as the input data, but found "
463+
f"{len(label)} for {n_datasets} datasets.")
464+
labels = label
465+
else:
466+
labels = [label] * n_datasets
467+
468+
result = (func(x[:, j % ncx], y[:, j % ncy], kw,
469+
{**kwargs, 'label': label})
470+
for j, label in enumerate(labels))
471+
452472
if return_kwargs:
453473
return list(result)
454474
else:

lib/matplotlib/tests/test_legend.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,3 +671,65 @@ def test_no_warn_big_data_when_loc_specified():
671671
ax.plot(np.arange(5000), label=idx)
672672
legend = ax.legend('best')
673673
fig.draw_artist(legend) # Check that no warning is emitted.
674+
675+
676+
@pytest.mark.parametrize('label_array', [['low', 'high'],
677+
('low', 'high'),
678+
np.array(['low', 'high'])])
679+
def test_plot_multiple_input_multiple_label(label_array):
680+
# test ax.plot() with multidimensional input
681+
# and multiple labels
682+
x = [1, 2, 3]
683+
y = [[1, 2],
684+
[2, 5],
685+
[4, 9]]
686+
687+
fig, ax = plt.subplots()
688+
ax.plot(x, y, label=label_array)
689+
leg = ax.legend()
690+
legend_texts = [entry.get_text() for entry in leg.get_texts()]
691+
assert legend_texts == ['low', 'high']
692+
693+
694+
@pytest.mark.parametrize('label', ['one', 1, int])
695+
def test_plot_multiple_input_single_label(label):
696+
# test ax.plot() with multidimensional input
697+
# and single label
698+
x = [1, 2, 3]
699+
y = [[1, 2],
700+
[2, 5],
701+
[4, 9]]
702+
703+
fig, ax = plt.subplots()
704+
ax.plot(x, y, label=label)
705+
leg = ax.legend()
706+
legend_texts = [entry.get_text() for entry in leg.get_texts()]
707+
assert legend_texts == [str(label)] * 2
708+
709+
710+
@pytest.mark.parametrize('label_array', [['low', 'high'],
711+
('low', 'high'),
712+
np.array(['low', 'high'])])
713+
def test_plot_single_input_multiple_label(label_array):
714+
# test ax.plot() with 1D array like input
715+
# and iterable label
716+
x = [1, 2, 3]
717+
y = [2, 5, 6]
718+
fig, ax = plt.subplots()
719+
ax.plot(x, y, label=label_array)
720+
leg = ax.legend()
721+
assert len(leg.get_texts()) == 1
722+
assert leg.get_texts()[0].get_text() == str(label_array)
723+
724+
725+
def test_plot_multiple_label_incorrect_length_exception():
726+
# check that excepton is raised if multiple labels
727+
# are given, but number of on labels != number of lines
728+
with pytest.raises(ValueError):
729+
x = [1, 2, 3]
730+
y = [[1, 2],
731+
[2, 5],
732+
[4, 9]]
733+
label = ['high', 'low', 'medium']
734+
fig, ax = plt.subplots()
735+
ax.plot(x, y, label=label)

0 commit comments

Comments
 (0)