Skip to content

Add sizes to scatter plots #289

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jul 15, 2023
56 changes: 56 additions & 0 deletions examples/desktop/scatter/scatter_size.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""
Scatter Plot
============
Example showing point size change for scatter plot.
"""

# test_example = true
import numpy as np
import fastplotlib as fpl

# grid with 2 rows and 3 columns
grid_shape = (2,1)

# pan-zoom controllers for each view
# views are synced if they have the
# same controller ID
controllers = [
[0],
[0]
]


# you can give string names for each subplot within the gridplot
names = [
["scalar_size"],
["array_size"]
]

# Create the grid plot
plot = fpl.GridPlot(
shape=grid_shape,
controllers=controllers,
names=names,
size=(1000, 1000)
)

# get y_values using sin function
angles = np.arange(0, 20*np.pi+0.001, np.pi / 20)
y_values = 30*np.sin(angles) # 1 thousand points
x_values = np.array([x for x in range(len(y_values))], dtype=np.float32)

data = np.column_stack([x_values, y_values])

plot["scalar_size"].add_scatter(data=data, sizes=5, colors="blue") # add a set of scalar sizes

non_scalar_sizes = np.abs((y_values / np.pi)) # ensure minimum size of 5
plot["array_size"].add_scatter(data=data, sizes=non_scalar_sizes, colors="red")

for graph in plot:
graph.auto_scale(maintain_aspect=True)

plot.show()

if __name__ == "__main__":
print(__doc__)
fpl.run()
3 changes: 3 additions & 0 deletions examples/desktop/screenshots/scatter_size.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
71 changes: 71 additions & 0 deletions examples/notebooks/scatter_sizes_animation.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from time import time\n",
"\n",
"import numpy as np\n",
"import fastplotlib as fpl\n",
"\n",
"plot = fpl.Plot()\n",
"\n",
"points = np.array([[-1,0,1],[-1,0,1]], dtype=np.float32).swapaxes(0,1)\n",
"size_delta_scales = np.array([10, 40, 100], dtype=np.float32)\n",
"min_sizes = 6\n",
"\n",
"def update_positions():\n",
" current_time = time()\n",
" newPositions = points + np.sin(((current_time / 4) % 1)*np.pi)\n",
" plot.graphics[0].data = newPositions\n",
" plot.camera.width = 4*np.max(newPositions[0,:])\n",
" plot.camera.height = 4*np.max(newPositions[1,:])\n",
"\n",
"def update_sizes():\n",
" current_time = time()\n",
" sin_sample = np.sin(((current_time / 4) % 1)*np.pi)\n",
" size_delta = sin_sample*size_delta_scales\n",
" plot.graphics[0].sizes = min_sizes + size_delta\n",
"\n",
"points = np.array([[0,0], \n",
" [1,1], \n",
" [2,2]])\n",
"scatter = plot.add_scatter(points, colors=[\"red\", \"green\", \"blue\"], sizes=12)\n",
"plot.add_animations(update_positions, update_sizes)\n",
"plot.show(autoscale=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "fastplotlib-dev",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
86 changes: 86 additions & 0 deletions examples/notebooks/scatter_sizes_grid.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\n",
"Scatter Plot\n",
"============\n",
"Example showing point size change for scatter plot.\n",
"\"\"\"\n",
"\n",
"# test_example = true\n",
"import numpy as np\n",
"import fastplotlib as fpl\n",
"\n",
"# grid with 2 rows and 3 columns\n",
"grid_shape = (2,1)\n",
"\n",
"# pan-zoom controllers for each view\n",
"# views are synced if they have the \n",
"# same controller ID\n",
"controllers = [\n",
" [0],\n",
" [0]\n",
"]\n",
"\n",
"\n",
"# you can give string names for each subplot within the gridplot\n",
"names = [\n",
" [\"scalar_size\"],\n",
" [\"array_size\"]\n",
"]\n",
"\n",
"# Create the grid plot\n",
"plot = fpl.GridPlot(\n",
" shape=grid_shape,\n",
" controllers=controllers,\n",
" names=names,\n",
" size=(1000, 1000)\n",
")\n",
"\n",
"# get y_values using sin function\n",
"angles = np.arange(0, 20*np.pi+0.001, np.pi / 20)\n",
"y_values = 30*np.sin(angles) # 1 thousand points\n",
"x_values = np.array([x for x in range(len(y_values))], dtype=np.float32)\n",
"\n",
"data = np.column_stack([x_values, y_values])\n",
"\n",
"plot[\"scalar_size\"].add_scatter(data=data, sizes=5, colors=\"blue\") # add a set of scalar sizes\n",
"\n",
"non_scalar_sizes = np.abs((y_values / np.pi)) # ensure minimum size of 5\n",
"plot[\"array_size\"].add_scatter(data=data, sizes=non_scalar_sizes, colors=\"red\")\n",
"\n",
"for graph in plot:\n",
" graph.auto_scale(maintain_aspect=True)\n",
"\n",
"plot.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "fastplotlib-dev",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
2 changes: 2 additions & 0 deletions fastplotlib/graphics/_features/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from ._colors import ColorFeature, CmapFeature, ImageCmapFeature, HeatmapCmapFeature
from ._data import PointsDataFeature, ImageDataFeature, HeatmapDataFeature
from ._sizes import PointsSizesFeature
from ._present import PresentFeature
from ._thickness import ThicknessFeature
from ._base import GraphicFeature, GraphicFeatureIndexable, FeatureEvent, to_gpu_supported_dtype
Expand All @@ -11,6 +12,7 @@
"ImageCmapFeature",
"HeatmapCmapFeature",
"PointsDataFeature",
"PointsSizesFeature",
"ImageDataFeature",
"HeatmapDataFeature",
"PresentFeature",
Expand Down
108 changes: 108 additions & 0 deletions fastplotlib/graphics/_features/_sizes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from typing import Any

import numpy as np

import pygfx

from ._base import (
GraphicFeatureIndexable,
cleanup_slice,
FeatureEvent,
to_gpu_supported_dtype,
cleanup_array_slice,
)


class PointsSizesFeature(GraphicFeatureIndexable):
"""
Access to the vertex buffer data shown in the graphic.
Supports fancy indexing if the data array also supports it.
"""

def __init__(self, parent, sizes: Any, collection_index: int = None):
sizes = self._fix_sizes(sizes, parent)
super(PointsSizesFeature, self).__init__(
parent, sizes, collection_index=collection_index
)

@property
def buffer(self) -> pygfx.Buffer:
return self._parent.world_object.geometry.sizes

def __getitem__(self, item):
return self.buffer.data[item]

def _fix_sizes(self, sizes, parent):
graphic_type = parent.__class__.__name__

n_datapoints = parent.data().shape[0]
if not isinstance(sizes, (list, tuple, np.ndarray)):
sizes = np.full(n_datapoints, sizes, dtype=np.float32) # force it into a float to avoid weird gpu errors
elif not isinstance(sizes, np.ndarray): # if it's not a ndarray already, make it one
sizes = np.array(sizes, dtype=np.float32) # read it in as a numpy.float32
if (sizes.ndim != 1) or (sizes.size != parent.data().shape[0]):
raise ValueError(
f"sequence of `sizes` must be 1 dimensional with "
f"the same length as the number of datapoints"
)

sizes = to_gpu_supported_dtype(sizes)

if any(s < 0 for s in sizes):
raise ValueError("All sizes must be positive numbers greater than or equal to 0.0.")

if sizes.ndim == 1:
if graphic_type == "ScatterGraphic":
sizes = np.array(sizes)
else:
raise ValueError(f"Sizes must be an array of shape (n,) where n == the number of data points provided.\
Received shape={sizes.shape}.")

return np.array(sizes)

def __setitem__(self, key, value):
if isinstance(key, np.ndarray):
# make sure 1D array of int or boolean
key = cleanup_array_slice(key, self._upper_bound)

# put sizes into right shape if they're only indexing datapoints
if isinstance(key, (slice, int, np.ndarray, np.integer)):
value = self._fix_sizes(value, self._parent)
# otherwise assume that they have the right shape
# numpy will throw errors if it can't broadcast

if value.size != self.buffer.data[key].size:
raise ValueError(f"{value.size} is not equal to buffer size {self.buffer.data[key].size}.\
If you want to set size to a non-scalar value, make sure it's the right length!")

self.buffer.data[key] = value
self._update_range(key)
# avoid creating dicts constantly if there are no events to handle
if len(self._event_handlers) > 0:
self._feature_changed(key, value)

def _update_range(self, key):
self._update_range_indices(key)

def _feature_changed(self, key, new_data):
if key is not None:
key = cleanup_slice(key, self._upper_bound)
if isinstance(key, (int, np.integer)):
indices = [key]
elif isinstance(key, slice):
indices = range(key.start, key.stop, key.step)
elif isinstance(key, np.ndarray):
indices = key
elif key is None:
indices = None

pick_info = {
"index": indices,
"collection-index": self._collection_index,
"world_object": self._parent.world_object,
"new_data": new_data,
}

event_data = FeatureEvent(type="sizes", pick_info=pick_info)

self._call_event_handlers(event_data)
23 changes: 5 additions & 18 deletions fastplotlib/graphics/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@

from ..utils import parse_cmap_values
from ._base import Graphic
from ._features import PointsDataFeature, ColorFeature, CmapFeature
from ._features import PointsDataFeature, ColorFeature, CmapFeature, PointsSizesFeature


class ScatterGraphic(Graphic):
feature_events = ("data", "colors", "cmap", "present")
feature_events = ("data", "sizes", "colors", "cmap", "present")

def __init__(
self,
data: np.ndarray,
sizes: Union[int, np.ndarray, list] = 1,
sizes: Union[int, float, np.ndarray, list] = 1,
colors: np.ndarray = "w",
alpha: float = 1.0,
cmap: str = None,
Expand Down Expand Up @@ -86,24 +86,11 @@ def __init__(
self, self.colors(), cmap_name=cmap, cmap_values=cmap_values
)

if isinstance(sizes, int):
sizes = np.full(self.data().shape[0], sizes, dtype=np.float32)
elif isinstance(sizes, np.ndarray):
if (sizes.ndim != 1) or (sizes.size != self.data().shape[0]):
raise ValueError(
f"numpy array of `sizes` must be 1 dimensional with "
f"the same length as the number of datapoints"
)
elif isinstance(sizes, list):
if len(sizes) != self.data().shape[0]:
raise ValueError(
"list of `sizes` must have the same length as the number of datapoints"
)

self.sizes = PointsSizesFeature(self, sizes)
super(ScatterGraphic, self).__init__(*args, **kwargs)

world_object = pygfx.Points(
pygfx.Geometry(positions=self.data(), sizes=sizes, colors=self.colors()),
pygfx.Geometry(positions=self.data(), sizes=self.sizes(), colors=self.colors()),
material=pygfx.PointsMaterial(vertex_colors=True, vertex_sizes=True),
)

Expand Down