Skip to content

Commit fd417d2

Browse files
authored
Merge pull request #143 from kushalkolar/large-images
HeatmapGraphic, supports dims larger than 8192
2 parents 70b4908 + e464925 commit fd417d2

File tree

7 files changed

+268
-19
lines changed

7 files changed

+268
-19
lines changed

fastplotlib/graphics/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from .histogram import HistogramGraphic
22
from .line import LineGraphic
33
from .scatter import ScatterGraphic
4-
from .image import ImageGraphic
5-
from .heatmap import HeatmapGraphic
4+
from .image import ImageGraphic, HeatmapGraphic
5+
# from .heatmap import HeatmapGraphic
66
from .text import TextGraphic
77
from .line_collection import LineCollection, LineStack
88

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from ._colors import ColorFeature, CmapFeature, ImageCmapFeature
2-
from ._data import PointsDataFeature, ImageDataFeature
1+
from ._colors import ColorFeature, CmapFeature, ImageCmapFeature, HeatmapCmapFeature
2+
from ._data import PointsDataFeature, ImageDataFeature, HeatmapDataFeature
33
from ._present import PresentFeature
44
from ._thickness import ThicknessFeature
55
from ._base import GraphicFeature, GraphicFeatureIndexable

fastplotlib/graphics/features/_base.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,33 @@
77
from pygfx import Buffer
88

99

10+
supported_dtypes = [
11+
np.uint8,
12+
np.uint16,
13+
np.uint32,
14+
np.int8,
15+
np.int16,
16+
np.int32,
17+
np.float16,
18+
np.float32
19+
]
20+
21+
22+
def to_gpu_supported_dtype(array):
23+
if isinstance(array, np.ndarray):
24+
if array.dtype not in supported_dtypes:
25+
if np.issubdtype(array.dtype, np.integer):
26+
warn(f"converting {array.dtype} array to int32")
27+
return array.astype(np.int32)
28+
elif np.issubdtype(array.dtype, np.floating):
29+
warn(f"converting {array.dtype} array to float32")
30+
return array.astype(np.float32, copy=False)
31+
else:
32+
raise TypeError("Unsupported type, supported array types must be int or float dtypes")
33+
34+
return array
35+
36+
1037
class FeatureEvent:
1138
"""
1239
type: <feature_name>, example: "colors"
@@ -43,7 +70,7 @@ def __init__(self, parent, data: Any, collection_index: int = None):
4370
"""
4471
self._parent = parent
4572
if isinstance(data, np.ndarray):
46-
data = data.astype(np.float32)
73+
data = to_gpu_supported_dtype(data)
4774

4875
self._data = data
4976

@@ -227,3 +254,4 @@ def _update_range_indices(self, key):
227254
self._buffer.update_range(ix, size=1)
228255
else:
229256
raise TypeError("must pass int or slice to update range")
257+

fastplotlib/graphics/features/_colors.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,3 +238,16 @@ def _feature_changed(self, key, new_data):
238238
event_data = FeatureEvent(type="cmap", pick_info=pick_info)
239239

240240
self._call_event_handlers(event_data)
241+
242+
243+
class HeatmapCmapFeature(ImageCmapFeature):
244+
"""
245+
Colormap for HeatmapGraphic
246+
"""
247+
248+
def _set(self, cmap_name: str):
249+
self._parent._material.map.texture.data[:] = make_colors(256, cmap_name)
250+
self._parent._material.map.texture.update_range((0, 0, 0), size=(256, 1, 1))
251+
self.name = cmap_name
252+
253+
self._feature_changed(key=None, new_data=self.name)

fastplotlib/graphics/features/_data.py

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,7 @@
33
import numpy as np
44
from pygfx import Buffer, Texture
55

6-
from ._base import GraphicFeatureIndexable, cleanup_slice, FeatureEvent
7-
8-
9-
def to_float32(array):
10-
if isinstance(array, np.ndarray):
11-
return array.astype(np.float32, copy=False)
12-
13-
return array
6+
from ._base import GraphicFeatureIndexable, cleanup_slice, FeatureEvent, to_gpu_supported_dtype
147

158

169
class PointsDataFeature(GraphicFeatureIndexable):
@@ -102,7 +95,7 @@ def __init__(self, parent, data: Any):
10295
"``[x_dim, y_dim]`` or ``[x_dim, y_dim, rgb]``"
10396
)
10497

105-
data = to_float32(data)
98+
data = to_gpu_supported_dtype(data)
10699
super(ImageDataFeature, self).__init__(parent, data)
107100

108101
@property
@@ -114,7 +107,7 @@ def __getitem__(self, item):
114107

115108
def __setitem__(self, key, value):
116109
# make sure float32
117-
value = to_float32(value)
110+
value = to_gpu_supported_dtype(value)
118111

119112
self._buffer.data[key] = value
120113
self._update_range(key)
@@ -145,3 +138,47 @@ def _feature_changed(self, key, new_data):
145138
event_data = FeatureEvent(type="data", pick_info=pick_info)
146139

147140
self._call_event_handlers(event_data)
141+
142+
143+
class HeatmapDataFeature(ImageDataFeature):
144+
@property
145+
def _buffer(self) -> List[Texture]:
146+
return [img.geometry.grid.texture for img in self._parent.world_object.children]
147+
148+
def __getitem__(self, item):
149+
return self._data[item]
150+
151+
def __setitem__(self, key, value):
152+
# make sure supported type, not float64 etc.
153+
value = to_gpu_supported_dtype(value)
154+
155+
self._data[key] = value
156+
self._update_range(key)
157+
158+
# avoid creating dicts constantly if there are no events to handle
159+
if len(self._event_handlers) > 0:
160+
self._feature_changed(key, value)
161+
162+
def _update_range(self, key):
163+
for buffer in self._buffer:
164+
buffer.update_range((0, 0, 0), size=buffer.size)
165+
166+
def _feature_changed(self, key, new_data):
167+
if key is not None:
168+
key = cleanup_slice(key, self._upper_bound)
169+
if isinstance(key, int):
170+
indices = [key]
171+
elif isinstance(key, slice):
172+
indices = range(key.start, key.stop, key.step)
173+
elif key is None:
174+
indices = None
175+
176+
pick_info = {
177+
"index": indices,
178+
"world_object": self._parent.world_object,
179+
"new_data": new_data
180+
}
181+
182+
event_data = FeatureEvent(type="data", pick_info=pick_info)
183+
184+
self._call_event_handlers(event_data)

fastplotlib/graphics/image.py

Lines changed: 175 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
from typing import *
2+
from math import ceil
3+
from itertools import product
24

35
import pygfx
6+
from pygfx.utils import unpack_bitfield
47

58
from ._base import Graphic, Interaction, PreviouslyModifiedData
6-
from .features import ImageCmapFeature, ImageDataFeature
9+
from .features import ImageCmapFeature, ImageDataFeature, HeatmapDataFeature, HeatmapCmapFeature
710
from ..utils import quick_min_max
811

912

@@ -119,5 +122,176 @@ def _reset_feature(self, feature: str):
119122
pass
120123

121124

125+
class _ImageTile(pygfx.Image):
126+
"""
127+
Similar to pygfx.Image, only difference is that it contains a few properties to keep track of
128+
row chunk index, column chunk index
122129
123130
131+
"""
132+
def _wgpu_get_pick_info(self, pick_value):
133+
tex = self.geometry.grid
134+
if hasattr(tex, "texture"):
135+
tex = tex.texture # tex was a view
136+
# This should match with the shader
137+
values = unpack_bitfield(pick_value, wobject_id=20, x=22, y=22)
138+
x = values["x"] / 4194304 * tex.size[0] - 0.5
139+
y = values["y"] / 4194304 * tex.size[1] - 0.5
140+
ix, iy = int(x + 0.5), int(y + 0.5)
141+
return {
142+
"index": (ix, iy),
143+
"pixel_coord": (x - ix, y - iy),
144+
"row_chunk_index": self.row_chunk_index,
145+
"col_chunk_index": self.col_chunk_index
146+
}
147+
148+
@property
149+
def row_chunk_index(self) -> int:
150+
return self._row_chunk_index
151+
152+
@row_chunk_index.setter
153+
def row_chunk_index(self, index: int):
154+
self._row_chunk_index = index
155+
156+
@property
157+
def col_chunk_index(self) -> int:
158+
return self._col_chunk_index
159+
160+
@col_chunk_index.setter
161+
def col_chunk_index(self, index: int):
162+
self._col_chunk_index = index
163+
164+
165+
class HeatmapGraphic(Graphic, Interaction):
166+
feature_events = (
167+
"data",
168+
"cmap",
169+
)
170+
171+
def __init__(
172+
self,
173+
data: Any,
174+
vmin: int = None,
175+
vmax: int = None,
176+
cmap: str = 'plasma',
177+
filter: str = "nearest",
178+
chunk_size: int = 8192,
179+
*args,
180+
**kwargs
181+
):
182+
"""
183+
Create an Image Graphic
184+
185+
Parameters
186+
----------
187+
data: array-like
188+
array-like, usually numpy.ndarray, must support ``memoryview()``
189+
Tensorflow Tensors also work **probably**, but not thoroughly tested
190+
| shape must be ``[x_dim, y_dim]``
191+
vmin: int, optional
192+
minimum value for color scaling, calculated from data if not provided
193+
vmax: int, optional
194+
maximum value for color scaling, calculated from data if not provided
195+
cmap: str, optional, default "plasma"
196+
colormap to use to display the data
197+
filter: str, optional, default "nearest"
198+
interpolation filter, one of "nearest" or "linear"
199+
chunk_size: int, default 8192, max 8192
200+
chunk size for each tile used to make up the heatmap texture
201+
args:
202+
additional arguments passed to Graphic
203+
kwargs:
204+
additional keyword arguments passed to Graphic
205+
206+
Examples
207+
--------
208+
.. code-block:: python
209+
210+
from fastplotlib import Plot
211+
# create a `Plot` instance
212+
plot = Plot()
213+
# make some random 2D image data
214+
data = np.random.rand(512, 512)
215+
# plot the image data
216+
plot.add_image(data=data)
217+
# show the plot
218+
plot.show()
219+
"""
220+
221+
super().__init__(*args, **kwargs)
222+
223+
if chunk_size > 8192:
224+
raise ValueError("Maximum chunk size is 8192")
225+
226+
self.data = HeatmapDataFeature(self, data)
227+
228+
row_chunks = range(ceil(data.shape[0] / chunk_size))
229+
col_chunks = range(ceil(data.shape[1] / chunk_size))
230+
231+
chunks = list(product(row_chunks, col_chunks))
232+
# chunks is the index position of each chunk
233+
234+
start_ixs = [list(map(lambda c: c * chunk_size, chunk)) for chunk in chunks]
235+
stop_ixs = [list(map(lambda c: c + chunk_size, chunk)) for chunk in start_ixs]
236+
237+
self._world_object = pygfx.Group()
238+
239+
if (vmin is None) or (vmax is None):
240+
vmin, vmax = quick_min_max(data)
241+
242+
self.cmap = HeatmapCmapFeature(self, cmap)
243+
self._material = pygfx.ImageBasicMaterial(clim=(vmin, vmax), map=self.cmap())
244+
245+
for start, stop, chunk in zip(start_ixs, stop_ixs, chunks):
246+
row_start, col_start = start
247+
row_stop, col_stop = stop
248+
249+
# x and y positions of the Tile in world space coordinates
250+
y_pos, x_pos = row_start, col_start
251+
252+
tex_view = pygfx.Texture(data[row_start:row_stop, col_start:col_stop], dim=2).get_view(filter=filter)
253+
geometry = pygfx.Geometry(grid=tex_view)
254+
# material = pygfx.ImageBasicMaterial(clim=(0, 1), map=self.cmap())
255+
256+
img = _ImageTile(geometry, self._material)
257+
258+
# row and column chunk index for this Tile
259+
img.row_chunk_index = chunk[0]
260+
img.col_chunk_index = chunk[1]
261+
262+
img.position.set_x(x_pos)
263+
img.position.set_y(y_pos)
264+
265+
self.world_object.add(img)
266+
267+
@property
268+
def vmin(self) -> float:
269+
"""Minimum contrast limit."""
270+
return self._material.clim[0]
271+
272+
@vmin.setter
273+
def vmin(self, value: float):
274+
"""Minimum contrast limit."""
275+
self._material.clim = (
276+
value,
277+
self._material.clim[1]
278+
)
279+
280+
@property
281+
def vmax(self) -> float:
282+
"""Maximum contrast limit."""
283+
return self._material.clim[1]
284+
285+
@vmax.setter
286+
def vmax(self, value: float):
287+
"""Maximum contrast limit."""
288+
self._material.clim = (
289+
self._material.clim[0],
290+
value
291+
)
292+
293+
def _set_feature(self, feature: str, new_data: Any, indices: Any):
294+
pass
295+
296+
def _reset_feature(self, feature: str):
297+
pass

fastplotlib/layouts/_subplot.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -255,9 +255,6 @@ def add_graphic(self, graphic, center: bool = True):
255255
graphic.world_object.position.z = len(self._graphics)
256256
super(Subplot, self).add_graphic(graphic, center)
257257

258-
if isinstance(graphic, graphics.HeatmapGraphic):
259-
self.controller.scale.y = copysign(self.controller.scale.y, -1)
260-
261258
def set_axes_visibility(self, visible: bool):
262259
"""Toggles axes visibility."""
263260
if visible:

0 commit comments

Comments
 (0)