Skip to content

Commit cc2517c

Browse files
committed
Apply unit conversion early in errorbar().
This allow using normal numpy constructs rather than manually looping and broadcasting. _process_unit_info was already special-handling `data is None` in a few places; the change here only handle the (theoretical) extra case where a custom unit converter would fail to properly pass None through.
1 parent d235b02 commit cc2517c

File tree

2 files changed

+18
-46
lines changed

2 files changed

+18
-46
lines changed

lib/matplotlib/axes/_axes.py

Lines changed: 14 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3281,27 +3281,12 @@ def errorbar(self, x, y, yerr=None, xerr=None,
32813281
kwargs = {k: v for k, v in kwargs.items() if v is not None}
32823282
kwargs.setdefault('zorder', 2)
32833283

3284-
self._process_unit_info([("x", x), ("y", y)], kwargs, convert=False)
3285-
3286-
# Make sure all the args are iterable; use lists not arrays to preserve
3287-
# units.
3288-
if not np.iterable(x):
3289-
x = [x]
3290-
3291-
if not np.iterable(y):
3292-
y = [y]
3293-
3284+
x, y, xerr, yerr = self._process_unit_info(
3285+
[("x", x), ("y", y), ("x", xerr), ("y", yerr)], kwargs)
3286+
x, y = np.atleast_1d(x, y) # Make sure all the args are iterable.
32943287
if len(x) != len(y):
32953288
raise ValueError("'x' and 'y' must have the same size")
32963289

3297-
if xerr is not None:
3298-
if not np.iterable(xerr):
3299-
xerr = [xerr] * len(x)
3300-
3301-
if yerr is not None:
3302-
if not np.iterable(yerr):
3303-
yerr = [yerr] * len(y)
3304-
33053290
if isinstance(errorevery, Integral):
33063291
errorevery = (0, errorevery)
33073292
if isinstance(errorevery, tuple):
@@ -3313,10 +3298,8 @@ def errorbar(self, x, y, yerr=None, xerr=None,
33133298
raise ValueError(
33143299
f'errorevery={errorevery!r} is a not a tuple of two '
33153300
f'integers')
3316-
33173301
elif isinstance(errorevery, slice):
33183302
pass
3319-
33203303
elif not isinstance(errorevery, str) and np.iterable(errorevery):
33213304
# fancy indexing
33223305
try:
@@ -3328,6 +3311,8 @@ def errorbar(self, x, y, yerr=None, xerr=None,
33283311
else:
33293312
raise ValueError(
33303313
f"errorevery={errorevery!r} is not a recognized value")
3314+
everymask = np.zeros(len(x), bool)
3315+
everymask[errorevery] = True
33313316

33323317
label = kwargs.pop("label", None)
33333318
kwargs['label'] = '_nolegend_'
@@ -3410,13 +3395,8 @@ def errorbar(self, x, y, yerr=None, xerr=None,
34103395
xlolims = np.broadcast_to(xlolims, len(x)).astype(bool)
34113396
xuplims = np.broadcast_to(xuplims, len(x)).astype(bool)
34123397

3413-
everymask = np.zeros(len(x), bool)
3414-
everymask[errorevery] = True
3415-
3416-
def apply_mask(arrays, mask):
3417-
# Return, for each array in *arrays*, the elements for which *mask*
3418-
# is True, without using fancy indexing.
3419-
return [[*itertools.compress(array, mask)] for array in arrays]
3398+
# Vectorized fancy-indexer.
3399+
def apply_mask(arrays, mask): return [array[mask] for array in arrays]
34203400

34213401
def extract_err(name, err, data, lolims, uplims):
34223402
"""
@@ -3437,24 +3417,14 @@ def extract_err(name, err, data, lolims, uplims):
34373417
Error is only applied on **lower** side when this is True. See
34383418
the note in the main docstring about this parameter's name.
34393419
"""
3440-
try: # Asymmetric error: pair of 1D iterables.
3441-
a, b = err
3442-
iter(a)
3443-
iter(b)
3444-
except (TypeError, ValueError):
3445-
a = b = err # Symmetric error: 1D iterable.
3446-
if np.ndim(a) > 1 or np.ndim(b) > 1:
3420+
try:
3421+
low, high = np.broadcast_to(err, (2, len(data)))
3422+
except ValueError:
34473423
raise ValueError(
3448-
f"{name}err must be a scalar or a 1D or (2, n) array-like")
3449-
# Using list comprehensions rather than arrays to preserve units.
3450-
for e in [a, b]:
3451-
if len(data) != len(e):
3452-
raise ValueError(
3453-
f"The lengths of the data ({len(data)}) and the "
3454-
f"error {len(e)} do not match")
3455-
low = [v if lo else v - e for v, e, lo in zip(data, a, lolims)]
3456-
high = [v if up else v + e for v, e, up in zip(data, b, uplims)]
3457-
return low, high
3424+
f"'{name}err' (shape: {np.shape(err)}) must be a scalar "
3425+
f"or a 1D or (2, n) array-like whose shape matches "
3426+
f"'{name}' (shape: {np.shape(data)})") from None
3427+
return data - low * ~lolims, data + high * ~uplims # low, high
34583428

34593429
if xerr is not None:
34603430
left, right = extract_err('x', xerr, x, xlolims, xuplims)

lib/matplotlib/axes/_base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2312,7 +2312,8 @@ def _process_unit_info(self, datasets=None, kwargs=None, *, convert=True):
23122312
----------
23132313
datasets : list
23142314
List of (axis_name, dataset) pairs (where the axis name is defined
2315-
as in `._get_axis_map`.
2315+
as in `._get_axis_map`). Individual datasets can also be None
2316+
(which gets passed through).
23162317
kwargs : dict
23172318
Other parameters from which unit info (i.e., the *xunits*,
23182319
*yunits*, *zunits* (for 3D axes), *runits* and *thetaunits* (for
@@ -2359,7 +2360,8 @@ def _process_unit_info(self, datasets=None, kwargs=None, *, convert=True):
23592360
for dataset_axis_name, data in datasets:
23602361
if dataset_axis_name == axis_name and data is not None:
23612362
axis.update_units(data)
2362-
return [axis_map[axis_name].convert_units(data) if convert else data
2363+
return [axis_map[axis_name].convert_units(data)
2364+
if convert and data is not None else data
23632365
for axis_name, data in datasets]
23642366

23652367
def in_axes(self, mouseevent):

0 commit comments

Comments
 (0)