Skip to content

Commit 113e307

Browse files
authored
Merge pull request #216 from jo-mueller/color-histogram-bars
Color histogram bars
2 parents 4516bd3 + acfb10e commit 113e307

File tree

5 files changed

+60
-14
lines changed

5 files changed

+60
-14
lines changed

src/napari_matplotlib/histogram.py

+30-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from typing import Any, Optional
1+
from typing import Any, Optional, cast
22

33
import napari
44
import numpy as np
55
import numpy.typing as npt
6+
from matplotlib.container import BarContainer
67
from qtpy.QtWidgets import QComboBox, QLabel, QVBoxLayout, QWidget
78

89
from .base import SingleAxesWidget
@@ -162,12 +163,39 @@ def on_update_layers(self) -> None:
162163

163164
def draw(self) -> None:
164165
"""Clear the axes and histogram the currently selected layer/slice."""
166+
# get the colormap from the layer depending on its type
167+
if isinstance(self.layers[0], napari.layers.Points):
168+
colormap = self.layers[0].face_colormap
169+
self.layers[0].face_color = self.x_axis_key
170+
elif isinstance(self.layers[0], napari.layers.Vectors):
171+
colormap = self.layers[0].edge_colormap
172+
self.layers[0].edge_color = self.x_axis_key
173+
else:
174+
colormap = None
175+
176+
# apply new colors to the layer
177+
self.viewer.layers[self.layers[0].name].refresh_colors(True)
178+
self.viewer.layers[self.layers[0].name].refresh()
179+
180+
# Draw the histogram
165181
data, x_axis_name = self._get_data()
166182

167183
if data is None:
168184
return
169185

170-
self.axes.hist(data, bins=50, edgecolor="white", linewidth=0.3)
186+
_, bins, patches = self.axes.hist(
187+
data, bins=50, edgecolor="white", linewidth=0.3
188+
)
189+
patches = cast(BarContainer, patches)
190+
191+
# recolor the histogram plot
192+
if colormap is not None:
193+
self.bins_norm = (bins - bins.min()) / (bins.max() - bins.min())
194+
colors = colormap.map(self.bins_norm)
195+
196+
# Set histogram style:
197+
for idx, patch in enumerate(patches):
198+
patch.set_facecolor(colors[idx])
171199

172200
# set ax labels
173201
self.axes.set_xlabel(x_axis_name)
Binary file not shown.
Loading
Loading

src/napari_matplotlib/tests/test_histogram.py

+30-12
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ def test_histogram_3D(make_napari_viewer, brain_data):
3838
def test_feature_histogram(make_napari_viewer):
3939
n_points = 1000
4040
random_points = np.random.random((n_points, 3)) * 10
41+
random_directions = np.random.random((n_points, 3)) * 10
42+
random_vectors = np.stack([random_points, random_directions], axis=1)
4143
feature1 = np.random.random(n_points)
4244
feature2 = np.random.normal(size=n_points)
4345

@@ -47,10 +49,10 @@ def test_feature_histogram(make_napari_viewer):
4749
properties={"feature1": feature1, "feature2": feature2},
4850
name="points1",
4951
)
50-
viewer.add_points(
51-
random_points,
52+
viewer.add_vectors(
53+
random_vectors,
5254
properties={"feature1": feature1, "feature2": feature2},
53-
name="points2",
55+
name="vectors1",
5456
)
5557

5658
widget = FeaturesHistogramWidget(viewer)
@@ -70,26 +72,42 @@ def test_feature_histogram(make_napari_viewer):
7072

7173

7274
@pytest.mark.mpl_image_compare
73-
def test_feature_histogram2(make_napari_viewer):
74-
import numpy as np
75+
def test_feature_histogram_vectors(make_napari_viewer):
76+
n_points = 1000
77+
np.random.seed(42)
78+
random_points = np.random.random((n_points, 3)) * 10
79+
random_directions = np.random.random((n_points, 3)) * 10
80+
random_vectors = np.stack([random_points, random_directions], axis=1)
81+
feature1 = np.random.random(n_points)
82+
83+
viewer = make_napari_viewer()
84+
viewer.add_vectors(
85+
random_vectors,
86+
properties={"feature1": feature1},
87+
name="vectors1",
88+
)
89+
90+
widget = FeaturesHistogramWidget(viewer)
91+
viewer.window.add_dock_widget(widget)
92+
widget._set_axis_keys("feature1")
7593

94+
fig = FeaturesHistogramWidget(viewer).figure
95+
return deepcopy(fig)
96+
97+
98+
@pytest.mark.mpl_image_compare
99+
def test_feature_histogram_points(make_napari_viewer):
76100
np.random.seed(0)
77101
n_points = 1000
78102
random_points = np.random.random((n_points, 3)) * 10
79103
feature1 = np.random.random(n_points)
80-
feature2 = np.random.normal(size=n_points)
81104

82105
viewer = make_napari_viewer()
83106
viewer.add_points(
84107
random_points,
85-
properties={"feature1": feature1, "feature2": feature2},
108+
properties={"feature1": feature1},
86109
name="points1",
87110
)
88-
viewer.add_points(
89-
random_points,
90-
properties={"feature1": feature1, "feature2": feature2},
91-
name="points2",
92-
)
93111

94112
widget = FeaturesHistogramWidget(viewer)
95113
viewer.window.add_dock_widget(widget)

0 commit comments

Comments
 (0)