Skip to content

Commit 24325da

Browse files
committed
Use pytest.mark.parametrize
1 parent 79c1b63 commit 24325da

File tree

1 file changed

+23
-28
lines changed

1 file changed

+23
-28
lines changed

lib/matplotlib/tests/test_legend.py

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -673,54 +673,49 @@ def test_no_warn_big_data_when_loc_specified():
673673
fig.draw_artist(legend) # Check that no warning is emitted.
674674

675675

676-
def test_plot_multiple_input_multiple_label():
676+
@pytest.mark.parametrize('label_array', [['low', 'high'], ('low', 'high'), np.array(['low', 'high'])])
677+
def test_plot_multiple_input_multiple_label(label_array):
677678
# test ax.plot() with multidimensional input
678679
# and multiple labels
679680
x = [1, 2, 3]
680681
y = [[1, 2],
681682
[2, 5],
682683
[4, 9]]
683-
label_arrays = [['low', 'high'],
684-
('low', 'high'),
685-
np.array(['low', 'high'])]
686-
for label in label_arrays:
687-
fig, ax = plt.subplots()
688-
ax.plot(x, y, label=label)
689-
leg = ax.legend()
690-
legend_texts = [entry.get_text() for entry in leg.get_texts()]
691-
assert legend_texts == ['low', 'high']
692684

685+
fig, ax = plt.subplots()
686+
ax.plot(x, y, label=label_array)
687+
leg = ax.legend()
688+
legend_texts = [entry.get_text() for entry in leg.get_texts()]
689+
assert legend_texts == ['low', 'high']
693690

694-
def test_plot_multiple_input_single_label():
691+
692+
@pytest.mark.parametrize('label', ['one', 1, int])
693+
def test_plot_multiple_input_single_label(label):
695694
# test ax.plot() with multidimensional input
696695
# and single label
697696
x = [1, 2, 3]
698697
y = [[1, 2],
699698
[2, 5],
700699
[4, 9]]
701-
labels = ['one', 1, int]
702-
for label in labels:
703-
fig, ax = plt.subplots()
704-
ax.plot(x, y, label=label)
705-
leg = ax.legend()
706-
legend_texts = [entry.get_text() for entry in leg.get_texts()]
707-
assert legend_texts == [str(label)] * 2
708700

701+
fig, ax = plt.subplots()
702+
ax.plot(x, y, label=label)
703+
leg = ax.legend()
704+
legend_texts = [entry.get_text() for entry in leg.get_texts()]
705+
assert legend_texts == [str(label)] * 2
709706

710-
def test_plot_single_input_multiple_label():
707+
708+
@pytest.mark.parametrize('label_array', [['low', 'high'], ('low', 'high'), np.array(['low', 'high'])])
709+
def test_plot_single_input_multiple_label(label_array):
711710
# test ax.plot() with 1D array like input
712711
# and iterable label
713712
x = [1, 2, 3]
714713
y = [2, 5, 6]
715-
label_arrays = [['low', 'high'],
716-
('low', 'high'),
717-
np.array(['low', 'high'])]
718-
for label in label_arrays:
719-
fig, ax = plt.subplots()
720-
ax.plot(x, y, label=label)
721-
leg = ax.legend()
722-
assert len(leg.get_texts()) == 1
723-
assert leg.get_texts()[0].get_text() == str(label)
714+
fig, ax = plt.subplots()
715+
ax.plot(x, y, label=label_array)
716+
leg = ax.legend()
717+
assert len(leg.get_texts()) == 1
718+
assert leg.get_texts()[0].get_text() == str(label_array)
724719

725720

726721
def test_plot_multiple_label_incorrect_length_exception():

0 commit comments

Comments
 (0)