Handle NaN values in plot_surface zsort #20646
Conversation
@eric-wieser @WeatherGod did you have any comments on whether this is preferred to #18114?
@jklymak just to be sure there is no confusion, I think this is potentially useful regardless of #18114. The relationship with #18114 is that if this PR is merged, then we can solve the masking issue in a much simpler way than what I did in #18114. But if it is decided that you don't want to support NaNs in
lib/mpl_toolkits/mplot3d/art3d.py
Outdated
def nansafe(func):
    def f(x):
        value = func(x)
        return np.inf if np.isnan(value) else value
Is this actually correct behavior? Would it be better to remove the NaNs before computing the mean / min / max / whatever?
Hmmm yes I think we need both. I was just looking at plots where polys containing NaNs are completely masked, but in cases where you have just one masked corner the rendering is still wrong. One way to solve this would be to replace the _zsort_functions with their nan equivalents and then work around the RuntimeWarning raised if all elements of the tested array are NaN.
So
matplotlib/lib/mpl_toolkits/mplot3d/art3d.py
Lines 720 to 724 in 498bcdd
_zsort_functions = {
    'average': np.average,
    'min': np.min,
    'max': np.max,
}
would be replaced with
_zsort_functions = {
    'average': np.nanmean,  # mean and average the same with unweighted data?
    'min': np.nanmin,
    'max': np.nanmax,
}
and then we'd rewrite the nansafe wrapper (or the method like you suggest) to be
def nansafe(func):
    def f(x):
        return np.inf if np.isnan(x).all() else func(x)
    return f
Do you think that would be a solution? Do you see any problem with replacing the functions inside _zsort_functions?
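The proposal above can be sketched outside of matplotlib as follows. The dictionary here is a hypothetical stand-in for `_zsort_functions`, and wrapping each entry with `nansafe` as well as the sample arrays are assumptions for illustration, not the code that was eventually merged.

```python
import numpy as np


def nansafe(func):
    """Wrap a NaN-aware reduction so that all-NaN input yields np.inf."""
    def f(x):
        x = np.asarray(x, dtype=float)
        # Checking first avoids the RuntimeWarning that np.nanmean, np.nanmin
        # and np.nanmax emit when every element is NaN.
        return np.inf if np.isnan(x).all() else func(x)
    return f


# Hypothetical stand-in for art3d._zsort_functions, using the nan* variants.
zsort_functions = {
    'average': nansafe(np.nanmean),
    'min': nansafe(np.nanmin),
    'max': nansafe(np.nanmax),
}

one_masked_corner = [1.0, 2.0, np.nan, 3.0]
fully_masked = [np.nan] * 4

# Partially masked polygons are reduced over their valid corners only.
partial = {name: f(one_masked_corner) for name, f in zsort_functions.items()}
# Fully masked polygons get np.inf, so the sort key is never NaN.
masked = {name: f(fully_masked) for name, f in zsort_functions.items()}
```

With this arrangement a polygon with a single masked corner still gets a finite sort key, and only a fully masked polygon falls back to `np.inf`.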
I haven't had time to look into the details here. But keep in mind that the sorting is done every time you rotate the view, which we want to be a real-time response. No question that filtering NaNs is necessary for correctness, but if you have multiple ways of doing that, please evaluate the performance.
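One way to compare candidate approaches in isolation is a small `timeit` harness like the sketch below. The workload size, the fraction of NaNs, and all three function names are assumptions made for illustration. The vectorized variant in particular is not something proposed in this thread, and absolute timings will depend on the machine.

```python
import timeit

import numpy as np

rng = np.random.default_rng(0)
# Hypothetical workload: 2401 quads with 4 z values each,
# and roughly 5% of the values set to NaN.
zs = rng.normal(size=(2401, 4))
zs[rng.random(zs.shape) < 0.05] = np.nan


def wrap_result(polys):
    # Reduce first, then map a NaN result to inf (the original wrapper).
    out = []
    for z in polys:
        value = np.average(z)
        out.append(np.inf if np.isnan(value) else value)
    return out


def nan_reduction(polys):
    # NaN-aware reduction; inf only when every corner is NaN.
    return [np.inf if np.isnan(z).all() else np.nanmean(z) for z in polys]


def vectorized(polys):
    # One pass over the whole array; all-NaN rows are filled with inf first.
    all_nan = np.isnan(polys).all(axis=1)
    safe = np.where(all_nan[:, None], np.inf, polys)
    return np.nanmean(safe, axis=1)


# Best-of-three wall time in seconds for each variant.
timings = {
    f.__name__: min(timeit.repeat(lambda: f(zs), number=1, repeat=3))
    for f in (wrap_result, nan_reduction, vectorized)
}
```

Printing `timings` gives a rough comparison; it only measures the reduction step, not the full draw that happens on each rotation.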
        self._sort_zpos = None
        self.stale = True

    def _zsortval(self, zs):
    def _zsortval(self, zs):
        """
        Compute the value to use for z-sorting given the viewer z coordinates of an object `zs`, with
        larger values drawn underneath smaller values.

        This function should never return `nan`, and returns `np.inf` if no non-nan value is computable.
        """
In response to the comment from @timhoffm I had a quick look at performance in a rather unscientific manner (please let me know if you want something better). I adapted the example from #12395 to perform a full rotation and looked at how long it took.
Test code
import numpy as np
import matplotlib.pyplot as plt

# Generate the data
x = np.linspace(-1.0, 1.0, 50)
y = np.linspace(-1.0, 1.0, 50)
x, y = np.meshgrid(x, y)
z = (1 - y / x).clip(min=-5.0, max=5.0)

# place NaNs at the discontinuity
pos = np.where(np.abs(np.diff(z)) >= 5.0)
z[pos] = np.nan

# Create the plot
ax = plt.figure().add_subplot(projection="3d")
surf = ax.plot_surface(
    x,
    y,
    z,
    rstride=1,
    cstride=1,
    cmap="coolwarm",
    linewidth=0,
    vmin=np.nanmin(z),
    vmax=np.nanmax(z),
    antialiased=False,
)
ax.set_title("1 - y/x")
ax.set(xlim=(-1, 1), ylim=(-1, 1), zlim=(-5, 5))
ax.set(xlabel="x", ylabel="y", zlabel="z")

# rotate the axes and update
for angle in range(0, 360):
    ax.view_init(30, angle)
    plt.draw()
    plt.pause(0.001)
Wall times:
I also tried using the with
I ran these a few times and the results were pretty consistent, although I realise this is a pretty rubbish way to test this kind of thing.
Can you cache the NaNs somehow? I've not traced this code carefully, but is there any reason to keep NaNs in the list of vertices and colors etc?
Even caching a mask of where the NaNs are would probably help.
@jklymak Yeah I think this can be solved by just stripping any NaNs before making the PolyCollection3D. I actually have another branch that does exactly that in What would be the preferred way forward? To open a new PR (third time lucky!) or force push here?
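The idea of stripping NaN-containing polygons once, before the collection is built, might look roughly like the sketch below. `strip_nan_polys` is a hypothetical helper written for illustration and is not the code from the branch mentioned above. It assumes all polygons have the same number of vertices so they fit in one array.

```python
import numpy as np


def strip_nan_polys(polys, colors=None):
    """Drop every polygon that has a NaN vertex (hypothetical helper)."""
    polys = np.asarray(polys, dtype=float)    # shape (n_polys, n_verts, 3)
    # The mask is computed once up front rather than on every redraw.
    keep = ~np.isnan(polys).any(axis=(1, 2))
    if colors is None:
        return polys[keep], None
    # Keep per-polygon colors aligned with the surviving polygons.
    return polys[keep], np.asarray(colors)[keep]


quads = np.zeros((3, 4, 3))
quads[1, 2, 2] = np.nan                       # one masked corner
facecolors = np.array(['r', 'g', 'b'])

clean_quads, clean_colors = strip_nan_polys(quads, facecolors)
```

Because the filtering happens before any sorting, the per-rotation sort never sees a NaN key and pays no extra cost for the check.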
Whatever is easiest for you, but if you opened new then make sure to include relevant comparisons
Ok thanks, I've converted this to a draft while I sort everything out. I'll make sure to clearly include before and after comparisons of the changes. Thanks again for all the comments so far.
Replaced by #20725, I believe.
PR Summary
Closes #8222 #12395
While investigating different ways of solving #18114 I came across an issue where NaNs in the data cause strange plotting artifacts. I tracked this down to the sorting of polygons, where the Python builtin sorted function is used. This doesn’t deal with NaNs, so we need to make the zsort functions NaN-safe. This PR does that by checking whether the function returns NaN and, if so, returning sys.maxsize.
I want to separate this from #18114 as I found two open issues I think it solves, #8222 and #12395. plot_surface does warn that it can't deal with NaNs, but it looks like that was added in response to #12395. This seems to solve that, so perhaps the warning can be removed?
I still need to add some tests but I wanted to check there was interest in adding this before writing those. It seems like a fairly simple problem to solve, so maybe there is some deeper reason this hasn't been done?
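The sorting problem described in this summary can be reproduced without matplotlib. The sketch below uses made-up (key, label) pairs and hypothetical key functions. It maps NaN to `math.inf`, in line with the `np.inf` used in the review diff, rather than the `sys.maxsize` mentioned in the summary.

```python
import math

nan = float('nan')
# (sort key, label) pairs, standing in for per-polygon z values.
polys = [(2.0, 'a'), (nan, 'b'), (1.0, 'c'), (3.0, 'd')]


def key_raw(item):
    return item[0]


def key_nansafe(item):
    # Map NaN to +inf so every pair of keys has a well-defined order.
    return math.inf if math.isnan(item[0]) else item[0]


# Every comparison against NaN is False, so this ordering is unreliable
# and depends on where the NaN happens to sit in the input.
raw_order = [label for _, label in sorted(polys, key=key_raw, reverse=True)]

# With the NaN-safe key the NaN entry always sorts first.
safe_order = [label for _, label in sorted(polys, key=key_nansafe, reverse=True)]
```

Since larger values are drawn underneath smaller ones, sorting the NaN entry first places that polygon at the bottom of the draw order instead of at an arbitrary position.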
The before and after plots of the two issues are
#8222 Before
#8222 After
#12395 Before
#12395 After