Skip to content

ENH: Streamplot control for integration max step and error #29333

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions doc/users/next_whats_new/streamplot_integration_control.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
Streamplot integration control
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Two new options have been added to the `~.axes.Axes.streamplot` function that
give the user better control of the streamline integration. The first is called
``integration_max_step_scale`` and multiplies the default max step computed by the
integrator. The second is called ``integration_max_error_scale`` and multiplies the
default max error set by the integrator. Values for these parameters between
zero and one reduce (tighten) the max step or error to improve streamline
accuracy by performing more computation. Values greater than one increase
(loosen) the max step or error to reduce computation time at the cost of lower
streamline accuracy.

The integrator defaults are both hand-tuned values and may not be applicable to
all cases, so this allows customizing the behavior to specific use cases.
Modifying only ``integration_max_step_scale`` has proved effective, but it may be useful
to control the error as well.
87 changes: 87 additions & 0 deletions galleries/examples/images_contours_and_fields/plot_streamplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
* Unbroken streamlines even when exceeding the limit of lines within a single
grid cell.
"""
import time

import matplotlib.pyplot as plt
import numpy as np

Expand Down Expand Up @@ -74,6 +76,91 @@
axs[7].streamplot(X, Y, U, V, broken_streamlines=False)
axs[7].set_title('Streamplot with unbroken streamlines')

plt.tight_layout()
# plt.show()

# %%
# Streamline computation
# ----------------------
#
# The streamlines are computed by integrating along the provided vector field
# from the seed points, which are either automatically generated or manually
# specified. The accuracy and smoothness of the streamlines can be adjusted using
# the ``integration_max_step_scale`` and ``integration_max_error_scale`` optional
# parameters. See the `~.axes.Axes.streamplot` function documentation for more
# details.
#
# This example shows how adjusting the maximum allowed step size and error for
# the integrator changes the appearance of the streamline. The differences can
# be subtle, but can be observed particularly where the streamlines have
# high curvature (as shown in the zoomed in region).

# Linear potential flow over a lifting cylinder
n = 50
x, y = np.meshgrid(np.linspace(-2, 2, n), np.linspace(-3, 3, n))
th = np.arctan2(y, x)
r = np.sqrt(x**2 + y**2)
vr = -np.cos(th) / r**2
vt = -np.sin(th) / r**2 - 1 / r
vx = vr * np.cos(th) - vt * np.sin(th) + 1.0
vy = vr * np.sin(th) + vt * np.cos(th)

# Seed points
n_seed = 50
seed_pts = np.column_stack((np.full(n_seed, -1.75), np.linspace(-2, 2, n_seed)))

_, axs = plt.subplots(3, 1, figsize=(6, 14))
th_circ = np.linspace(0, 2 * np.pi, 100)
for ax, max_val in zip(axs, [0.05, 1, 5]):
ax_ins = ax.inset_axes([0.0, 0.7, 0.3, 0.35])
for ax_curr, is_inset in zip([ax, ax_ins], [False, True]):
t_start = time.time()
ax_curr.streamplot(
x,
y,
vx,
vy,
start_points=seed_pts,
broken_streamlines=False,
arrowsize=1e-10,
linewidth=2 if is_inset else 0.6,
color="k",
integration_max_step_scale=max_val,
integration_max_error_scale=max_val,
)
if is_inset:
t_total = time.time() - t_start

# Draw the cylinder
ax_curr.fill(
np.cos(th_circ),
np.sin(th_circ),
color="w",
ec="k",
lw=6 if is_inset else 2,
)

# Set axis properties
ax_curr.set_aspect("equal")

# Label properties of each circle
text = f"integration_max_step_scale: {max_val}\n" \
f"integration_max_error_scale: {max_val}\n" \
f"streamplot time: {t_total:.2f} sec"
if max_val == 1:
text += "\n(default)"
ax.text(0.0, 0.0, text, ha="center", va="center")

# Set axis limits and show zoomed region
ax_ins.set_xlim(-1.2, -0.7)
ax_ins.set_ylim(-0.8, -0.4)
ax_ins.set_yticks(())
ax_ins.set_xticks(())

ax.set_ylim(-1.5, 1.5)
ax.axis("off")
ax.indicate_inset_zoom(ax_ins, ec="k")

plt.tight_layout()
plt.show()
# %%
Expand Down
4 changes: 4 additions & 0 deletions lib/matplotlib/pyplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4128,6 +4128,8 @@ def streamplot(
integration_direction="both",
broken_streamlines=True,
*,
integration_max_step_scale=1.0,
integration_max_error_scale=1.0,
num_arrows=1,
data=None,
):
Expand All @@ -4150,6 +4152,8 @@ def streamplot(
maxlength=maxlength,
integration_direction=integration_direction,
broken_streamlines=broken_streamlines,
integration_max_step_scale=integration_max_step_scale,
integration_max_error_scale=integration_max_error_scale,
num_arrows=num_arrows,
**({"data": data} if data is not None else {}),
)
Expand Down
58 changes: 50 additions & 8 deletions lib/matplotlib/streamplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
cmap=None, norm=None, arrowsize=1, arrowstyle='-|>',
minlength=0.1, transform=None, zorder=None, start_points=None,
maxlength=4.0, integration_direction='both',
broken_streamlines=True, *, num_arrows=1):
broken_streamlines=True, *, integration_max_step_scale=1.0,
integration_max_error_scale=1.0, num_arrows=1):
"""
Draw streamlines of a vector flow.

Expand Down Expand Up @@ -73,6 +74,24 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
If False, forces streamlines to continue until they
leave the plot domain. If True, they may be terminated if they
come too close to another streamline.
integration_max_step_scale : float, default: 1.0
Multiplier on the maximum allowable step in the streamline integration routine.
A value between zero and one results in a max integration step smaller than
the default max step, resulting in more accurate streamlines at the cost
of greater computation time; a value greater than one does the converse. Must be
greater than zero.

.. versionadded:: 3.11

integration_max_error_scale : float, default: 1.0
Multiplier on the maximum allowable error in the streamline integration routine.
A value between zero and one results in a tighter max integration error than
the default max error, resulting in more accurate streamlines at the cost
of greater computation time; a value greater than one does the converse. Must be
greater than zero.

.. versionadded:: 3.11

num_arrows : int
Number of arrows per streamline. The arrows are spaced equally along the steps
each streamline takes. Note that this can be different to being spaced equally
Expand All @@ -97,6 +116,18 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
mask = StreamMask(density)
dmap = DomainMap(grid, mask)

if integration_max_step_scale <= 0.0:
raise ValueError(
"The value of integration_max_step_scale must be > 0, " +
f"got {integration_max_step_scale}"
)

if integration_max_error_scale <= 0.0:
raise ValueError(
"The value of integration_max_error_scale must be > 0, " +
f"got {integration_max_error_scale}"
)

if num_arrows < 0:
raise ValueError(f"The value of num_arrows must be >= 0, got {num_arrows=}")

Expand Down Expand Up @@ -160,7 +191,9 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
for xm, ym in _gen_starting_points(mask.shape):
if mask[ym, xm] == 0:
xg, yg = dmap.mask2grid(xm, ym)
t = integrate(xg, yg, broken_streamlines)
t = integrate(xg, yg, broken_streamlines,
integration_max_step_scale,
integration_max_error_scale)
if t is not None:
trajectories.append(t)
else:
Expand Down Expand Up @@ -188,7 +221,8 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
xg = np.clip(xg, 0, grid.nx - 1)
yg = np.clip(yg, 0, grid.ny - 1)

t = integrate(xg, yg, broken_streamlines)
t = integrate(xg, yg, broken_streamlines, integration_max_step_scale,
integration_max_error_scale)
if t is not None:
trajectories.append(t)

Expand Down Expand Up @@ -481,7 +515,8 @@ def backward_time(xi, yi):
dxi, dyi = forward_time(xi, yi)
return -dxi, -dyi

def integrate(x0, y0, broken_streamlines=True):
def integrate(x0, y0, broken_streamlines=True, integration_max_step_scale=1.0,
integration_max_error_scale=1.0):
"""
Return x, y grid-coordinates of trajectory based on starting point.

Expand All @@ -501,14 +536,18 @@ def integrate(x0, y0, broken_streamlines=True):
return None
if integration_direction in ['both', 'backward']:
s, xyt = _integrate_rk12(x0, y0, dmap, backward_time, maxlength,
broken_streamlines)
broken_streamlines,
integration_max_step_scale,
integration_max_error_scale)
stotal += s
xy_traj += xyt[::-1]

if integration_direction in ['both', 'forward']:
dmap.reset_start_point(x0, y0)
s, xyt = _integrate_rk12(x0, y0, dmap, forward_time, maxlength,
broken_streamlines)
broken_streamlines,
integration_max_step_scale,
integration_max_error_scale)
stotal += s
xy_traj += xyt[1:]

Expand All @@ -525,7 +564,9 @@ class OutOfBounds(IndexError):
pass


def _integrate_rk12(x0, y0, dmap, f, maxlength, broken_streamlines=True):
def _integrate_rk12(x0, y0, dmap, f, maxlength, broken_streamlines=True,
integration_max_step_scale=1.0,
integration_max_error_scale=1.0):
"""
2nd-order Runge-Kutta algorithm with adaptive step size.

Expand All @@ -551,7 +592,7 @@ def _integrate_rk12(x0, y0, dmap, f, maxlength, broken_streamlines=True):
# This error is below that needed to match the RK4 integrator. It
# is set for visual reasons -- too low and corners start
# appearing ugly and jagged. Can be tuned.
maxerror = 0.003
maxerror = 0.003 * integration_max_error_scale

# This limit is important (for all integrators) to avoid the
# trajectory skipping some mask cells. We could relax this
Expand All @@ -560,6 +601,7 @@ def _integrate_rk12(x0, y0, dmap, f, maxlength, broken_streamlines=True):
# nature of the interpolation, this doesn't boost speed by much
# for quite a bit of complexity.
maxds = min(1. / dmap.mask.nx, 1. / dmap.mask.ny, 0.1)
maxds *= integration_max_step_scale

ds = maxds
stotal = 0
Expand Down
2 changes: 2 additions & 0 deletions lib/matplotlib/streamplot.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ def streamplot(
integration_direction: Literal["forward", "backward", "both"] = ...,
broken_streamlines: bool = ...,
*,
integration_max_step_scale: float = ...,
integration_max_error_scale: float = ...,
num_arrows: int = ...,
) -> StreamplotSet: ...

Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
76 changes: 74 additions & 2 deletions lib/matplotlib/tests/test_streamplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,66 @@ def test_direction():
linewidth=2, density=2)


@image_comparison(['streamplot_integration.png'], style='mpl20', tol=0.05)
def test_integration_options():
# Linear potential flow over a lifting cylinder
n = 50
x, y = np.meshgrid(np.linspace(-2, 2, n), np.linspace(-3, 3, n))
th = np.arctan2(y, x)
r = np.sqrt(x**2 + y**2)
vr = -np.cos(th) / r**2
vt = -np.sin(th) / r**2 - 1 / r
vx = vr * np.cos(th) - vt * np.sin(th) + 1.0
vy = vr * np.sin(th) + vt * np.cos(th)

# Seed points
n_seed = 50
seed_pts = np.column_stack((np.full(n_seed, -1.75), np.linspace(-2, 2, n_seed)))

fig, axs = plt.subplots(3, 1, figsize=(6, 14))
th_circ = np.linspace(0, 2 * np.pi, 100)
for ax, max_val in zip(axs, [0.05, 1, 5]):
ax_ins = ax.inset_axes([0.0, 0.7, 0.3, 0.35])
for ax_curr, is_inset in zip([ax, ax_ins], [False, True]):
ax_curr.streamplot(
x,
y,
vx,
vy,
start_points=seed_pts,
broken_streamlines=False,
arrowsize=1e-10,
linewidth=2 if is_inset else 0.6,
color="k",
integration_max_step_scale=max_val,
integration_max_error_scale=max_val,
)

# Draw the cylinder
ax_curr.fill(
np.cos(th_circ),
np.sin(th_circ),
color="w",
ec="k",
lw=6 if is_inset else 2,
)

# Set axis properties
ax_curr.set_aspect("equal")

# Set axis limits and show zoomed region
ax_ins.set_xlim(-1.2, -0.7)
ax_ins.set_ylim(-0.8, -0.4)
ax_ins.set_yticks(())
ax_ins.set_xticks(())

ax.set_ylim(-1.5, 1.5)
ax.axis("off")
ax.indicate_inset_zoom(ax_ins, ec="k")

fig.tight_layout()


def test_streamplot_limits():
ax = plt.axes()
x = np.linspace(-5, 10, 20)
Expand Down Expand Up @@ -156,8 +216,20 @@ def test_streamplot_grid():
x = np.array([0, 20, 40])
y = np.array([0, 20, 10])

with pytest.raises(ValueError, match="'y' must be strictly increasing"):
plt.streamplot(x, y, u, v)

def test_streamplot_integration_params():
x = np.array([[10, 20], [10, 20]])
y = np.array([[10, 10], [20, 20]])
u = np.ones((2, 2))
v = np.zeros((2, 2))

err_str = "The value of integration_max_step_scale must be > 0, got -0.5"
with pytest.raises(ValueError, match=err_str):
plt.streamplot(x, y, u, v, integration_max_step_scale=-0.5)

err_str = "The value of integration_max_error_scale must be > 0, got 0.0"
with pytest.raises(ValueError, match=err_str):
plt.streamplot(x, y, u, v, integration_max_error_scale=0.0)


def test_streamplot_inputs(): # test no exception occurs.
Expand Down
Loading