Skip to content

Commit babbee0

Browse files
committed
Updated input types for MultiNorm.__call__()
1 parent c6cf321 commit babbee0

File tree

2 files changed

+182
-12
lines changed

2 files changed

+182
-12
lines changed

lib/matplotlib/colors.py

Lines changed: 60 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3413,10 +3413,9 @@ def __call__(self, values, clip=None, structured_output=None):
34133413
Parameters
34143414
----------
34153415
values : array-like
3416-
Data to normalize, as tuple, scalar array or structured array.
3416+
Data to normalize, as tuple or list or structured array.
34173417
3418-
- If tuple, must be of length `n_components`
3419-
- If scalar array, the first axis must be of length `n_components`
3418+
- If tuple or list, must be of length `n_components`
34203419
- If structured array, must have `n_components` fields.
34213420
34223421
clip : list of bools or bool or None, optional
@@ -3530,22 +3529,72 @@ def _iterable_components_in_data(data, n_components):
35303529
Parameters
35313530
----------
35323531
data : np.ndarray, tuple or list
3533-
The input array. It must either be an array with n_components fields or have
3534-
a length (n_components)
3532+
The input data, as a tuple or list or structured array.
3533+
3534+
- If tuple or list, must be of length `n_components`
3535+
- If structured array, must have `n_components` fields.
35353536
35363537
Returns
35373538
-------
35383539
tuple of np.ndarray
35393540
35403541
"""
3541-
if isinstance(data, np.ndarray) and data.dtype.fields is not None:
3542-
data = tuple(data[descriptor[0]] for descriptor in data.dtype.descr)
3543-
if len(data) != n_components:
3544-
raise ValueError("The input to this `MultiNorm` must be of shape "
3545-
f"({n_components}, ...), or be structured array or scalar "
3546-
f"with {n_components} fields.")
3542+
if isinstance(data, np.ndarray):
3543+
if data.dtype.fields is not None:
3544+
data = tuple(data[descriptor[0]] for descriptor in data.dtype.descr)
3545+
if len(data) != n_components:
3546+
raise ValueError(f"{MultiNorm._get_input_err(n_components)}"
3547+
f". A structured array with "
3548+
f"{len(data)} fields is not compatible")
3549+
else:
3550+
# Input is a scalar array, which we do not support.
3551+
# try to give a hint as to how the data can be converted to
3552+
# an accepted format
3553+
if ((len(data.shape) == 1 and
3554+
data.shape[0] == n_components) or
3555+
(len(data.shape) > 1 and
3556+
data.shape[0] == n_components and
3557+
data.shape[-1] != n_components)
3558+
):
3559+
raise ValueError(f"{MultiNorm._get_input_err(n_components)}"
3560+
". You can use `list(data)` to convert"
3561+
f" the input data of shape {data.shape} to"
3562+
" a compatible list")
3563+
3564+
elif (len(data.shape) > 1 and
3565+
data.shape[-1] == n_components and
3566+
data.shape[0] != n_components):
3567+
raise ValueError(f"{MultiNorm._get_input_err(n_components)}"
3568+
". You can use "
3569+
"`rfn.unstructured_to_structured(data)` available "
3570+
"with `from numpy.lib import recfunctions as rfn` "
3571+
"to convert the input array of shape "
3572+
f"{data.shape} to a structured array")
3573+
else:
3574+
# Cannot give shape hint
3575+
# Either neither first nor last axis matches, or both do.
3576+
raise ValueError(f"{MultiNorm._get_input_err(n_components)}"
3577+
f". An np.ndarray of shape {data.shape} is"
3578+
" not compatible")
3579+
elif isinstance(data, (tuple, list)):
3580+
if len(data) != n_components:
3581+
raise ValueError(f"{MultiNorm._get_input_err(n_components)}"
3582+
f". A {type(data)} of length {len(data)} is"
3583+
" not compatible")
3584+
else:
3585+
raise ValueError(f"{MultiNorm._get_input_err(n_components)}"
3586+
f". Input of type {type(data)} is not supported")
3587+
35473588
return data
35483589

3590+
@staticmethod
3591+
def _get_input_err(n_components):
3592+
# returns the start of the error message given when a
3593+
# MultiNorm receives incompatible input
3594+
return ("The input to this `MultiNorm` must be a list or tuple "
3595+
f"of length {n_components}, or be structured array "
3596+
f"with {n_components} fields")
3597+
35493598
@staticmethod
35503599
def _ensure_multicomponent_data(data, n_components):
35513600
"""

lib/matplotlib/tests/test_colors.py

Lines changed: 122 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import base64
1010
import platform
1111

12+
from numpy.lib import recfunctions as rfn
1213
from numpy.testing import assert_array_equal, assert_array_almost_equal
1314

1415
from matplotlib import cbook, cm
@@ -1891,7 +1892,7 @@ def test_close_error_name():
18911892
matplotlib.colormaps["grays"]
18921893

18931894

1894-
def test_multi_norm():
1895+
def test_multi_norm_creation():
18951896
# tests for mcolors.MultiNorm
18961897

18971898
# test wrong input
@@ -1911,13 +1912,24 @@ def test_multi_norm():
19111912
match="Invalid norm str name"):
19121913
mcolors.MultiNorm(["None"])
19131914

1915+
norm = mpl.colors.MultiNorm(['linear', 'linear'])
1916+
1917+
1918+
def test_multi_norm_call_vmin_vmax():
19141919
# test get vmin, vmax
19151920
norm = mpl.colors.MultiNorm(['linear', 'log'])
19161921
norm.vmin = 1
19171922
norm.vmax = 2
19181923
assert norm.vmin == (1, 1)
19191924
assert norm.vmax == (2, 2)
19201925

1926+
1927+
def test_multi_norm_call_clip_inverse():
1928+
# test get vmin, vmax
1929+
norm = mpl.colors.MultiNorm(['linear', 'log'])
1930+
norm.vmin = 1
1931+
norm.vmax = 2
1932+
19211933
# test call with clip
19221934
assert_array_equal(norm([3, 3], clip=False), [2.0, 1.584962500721156])
19231935
assert_array_equal(norm([3, 3], clip=True), [1.0, 1.0])
@@ -1933,6 +1945,9 @@ def test_multi_norm():
19331945
# test inverse
19341946
assert_array_almost_equal(norm.inverse([0.5, 0.5849625007211562]), [1.5, 1.5])
19351947

1948+
1949+
def test_multi_norm_autoscale():
1950+
norm = mpl.colors.MultiNorm(['linear', 'log'])
19361951
# test autoscale
19371952
norm.autoscale([[0, 1, 2, 3], [0.1, 1, 2, 3]])
19381953
assert_array_equal(norm.vmin, [0, 0.1])
@@ -1945,3 +1960,109 @@ def test_multi_norm():
19451960
assert_array_equal(norm([5, 0]), [1, 0.5])
19461961
assert_array_equal(norm.vmin, (0, -50))
19471962
assert_array_equal(norm.vmax, (5, 50))
1963+
1964+
1965+
def test_mult_norm_call_types():
1966+
mn = mpl.colors.MultiNorm(['linear', 'linear'])
1967+
mn.vmin = -2
1968+
mn.vmax = 2
1969+
1970+
vals = np.arange(6).reshape((3,2))
1971+
target = np.ma.array([(0.5, 0.75),
1972+
(1., 1.25),
1973+
(1.5, 1.75)])
1974+
1975+
# test structured array as input
1976+
structured_target = rfn.unstructured_to_structured(target)
1977+
from_mn= mn(rfn.unstructured_to_structured(vals))
1978+
assert from_mn.dtype == structured_target.dtype
1979+
assert_array_almost_equal(rfn.structured_to_unstructured(from_mn),
1980+
rfn.structured_to_unstructured(structured_target))
1981+
1982+
# test list of arrays as input
1983+
assert_array_almost_equal(mn(list(vals.T)),
1984+
list(target.T))
1985+
# test list of floats as input
1986+
assert_array_almost_equal(mn(list(vals[0])),
1987+
list(target[0]))
1988+
# test tuple of arrays as input
1989+
assert_array_almost_equal(mn(tuple(vals.T)),
1990+
list(target.T))
1991+
1992+
1993+
# test setting structured_output true/false:
1994+
# structured input, structured output
1995+
from_mn = mn(rfn.unstructured_to_structured(vals), structured_output=True)
1996+
assert from_mn.dtype == structured_target.dtype
1997+
assert_array_almost_equal(rfn.structured_to_unstructured(from_mn),
1998+
rfn.structured_to_unstructured(structured_target))
1999+
# structured input, list as output
2000+
from_mn = mn(rfn.unstructured_to_structured(vals), structured_output=False)
2001+
assert_array_almost_equal(from_mn,
2002+
list(target.T))
2003+
# list as input, structured output
2004+
from_mn= mn(list(vals.T), structured_output=True)
2005+
assert from_mn.dtype == structured_target.dtype
2006+
assert_array_almost_equal(rfn.structured_to_unstructured(from_mn),
2007+
rfn.structured_to_unstructured(structured_target))
2008+
# list as input, list as output
2009+
from_mn = mn(list(vals.T), structured_output=False)
2010+
assert_array_almost_equal(from_mn,
2011+
list(target.T))
2012+
2013+
# test with NoNorm, list as input
2014+
mn_no_norm = mpl.colors.MultiNorm(['linear', mcolors.NoNorm()])
2015+
no_norm_out = mn_no_norm(list(vals.T))
2016+
assert_array_almost_equal(no_norm_out,
2017+
[[0., 0.5, 1.],
2018+
[1, 3, 5]])
2019+
assert no_norm_out[0].dtype == np.dtype('float64')
2020+
assert no_norm_out[1].dtype == np.dtype('int64')
2021+
2022+
# test with NoNorm, structured array as input
2023+
mn_no_norm = mpl.colors.MultiNorm(['linear', mcolors.NoNorm()])
2024+
no_norm_out = mn_no_norm(rfn.unstructured_to_structured(vals))
2025+
assert_array_almost_equal(rfn.structured_to_unstructured(no_norm_out),
2026+
np.array(\
2027+
[[0., 0.5, 1.],
2028+
[1, 3, 5]]).T)
2029+
assert no_norm_out.dtype['f0'] == np.dtype('float64')
2030+
assert no_norm_out.dtype['f1'] == np.dtype('int64')
2031+
2032+
# test single int as input
2033+
with pytest.raises(ValueError,
2034+
match="Input of type <class 'int'> is not supported"):
2035+
mn(1)
2036+
2037+
# test list of incompatible size
2038+
with pytest.raises(ValueError,
2039+
match="A <class 'list'> of length 3 is not compatible"):
2040+
mn([3, 2, 1])
2041+
2042+
# np.arrays of shapes that can be converted:
2043+
for data in [np.zeros(2), np.zeros((2,3)), np.zeros((2,3,3))]:
2044+
with pytest.raises(ValueError,
2045+
match=r"You can use `list\(data\)` to convert"):
2046+
mn(data)
2047+
2048+
for data in [np.zeros((3, 2)), np.zeros((3, 3, 2))]:
2049+
with pytest.raises(ValueError,
2050+
match=r"You can use `rfn.unstructured_to_structured"):
2051+
mn(data)
2052+
2053+
# np.ndarray that can be converted, but unclear if first or last axis
2054+
for data in [np.zeros((2, 2)), np.zeros((2, 3, 2))]:
2055+
with pytest.raises(ValueError,
2056+
match="An np.ndarray of shape"):
2057+
mn(data)
2058+
2059+
# incompatible arrays where no relevant axis matches
2060+
for data in [np.zeros(3), np.zeros((3, 2, 3))]:
2061+
with pytest.raises(ValueError,
2062+
match=r"An np.ndarray of shape"):
2063+
mn(data)
2064+
2065+
# test incompatible class
2066+
with pytest.raises(ValueError,
2067+
match="Input of type <class 'str'> is not supported"):
2068+
mn("An object of incompatible class")

0 commit comments

Comments
 (0)