From 6dd6be3a6c623ecbc71d124aeb0fa4e18daa443d Mon Sep 17 00:00:00 2001 From: Michael Milton Date: Thu, 10 Oct 2024 16:31:47 +1100 Subject: [PATCH] Fixes for multiple shape layers --- plugin/napari_lattice/shape_selector.py | 71 +++++++++++++++++-------- 1 file changed, 48 insertions(+), 23 deletions(-) diff --git a/plugin/napari_lattice/shape_selector.py b/plugin/napari_lattice/shape_selector.py index 81e966e0..732e2e53 100644 --- a/plugin/napari_lattice/shape_selector.py +++ b/plugin/napari_lattice/shape_selector.py @@ -5,6 +5,8 @@ from magicgui.widgets import Select from napari.layers import Shapes from napari.components.layerlist import LayerList +from collections import defaultdict +from contextlib import contextmanager from plugin.napari_lattice.utils import get_viewer @@ -12,7 +14,7 @@ from napari.utils.events.event import Event from numpy.typing import NDArray -@dataclass +@dataclass(frozen=True, eq=True) class Shape: """ Holds data about a single shape within a Shapes layer @@ -21,16 +23,32 @@ class Shape: index: int def __str__(self) -> str: - return f"{self.layer.name} {self.index}" + return f"{self.layer.name}: Shape {self.index}" def get_array(self) -> NDArray: return self.layer.data[self.index] @magicclass class ShapeSelector: + + _blocked: bool + def __init__(self, *args, **kwargs) -> None: # Needed to handle extra kwargs - pass + self._blocked = False + + @contextmanager + def _block(self): + """ + Context manager that prevents event handlers recursively calling each other. + Yields a boolean, which means functions should proceed if `True`, or return immediately if `False` + """ + if self._blocked: + yield False + else: + self._blocked = True + yield True + self._blocked = False def _get_shape_choices(self, widget: Select | None = None) -> Iterator[Tuple[str, Shape]]: """ @@ -48,11 +66,16 @@ def _on_selection_change(self, event: Event) -> None: Triggered when the user clicks on one or more shapes. The widget is then updated to synchronise """ - source: Shapes = event.source - selection: list[Shape] = [] - for index in source.selected_data: - selection.append(Shape(layer=source, index=index)) - self.shapes.value = selection + # Prevent recursion + with self._block() as execute: + if not execute: + return + + source: Shapes = event.source + selection: list[Shape] = [] + for index in source.selected_data: + selection.append(Shape(layer=source, index=index)) + self.shapes.value = selection def _connect_shapes(self, shapes: Shapes) -> None: """ @@ -104,18 +127,20 @@ def _widget_changed(self, values: list) -> None: Triggered when the plugin widget is changed. We then synchronise the Napari shape selection with it. """ - layers: set[Shapes] = set() - indices: set[int] = set() - value: Shape - - for value in values: - layers.add(value.layer) - indices.add(value.index) - - if len(layers) > 1: - raise Exception("Shapes from multiple layers selected. This shouldn't be possible") - - if layers: - layer = layers.pop() - layer.selected_data = indices - layer.refresh() + with self._block() as execute: + if not execute: + return + layers: dict[Shapes, list[int]] = {layer: [] for layer in get_viewer().layers if isinstance(layer, Shapes)} + value: Shape + + # Find the current selection for each layer + for value in values: + layers[value.layer].append(value.index) + + # For each layer, set the appropriate selection (this can't be done incrementally) + for layer, shapes in layers.items(): + layer.selected_data = shapes + + # Re-calculate the selections for all Shapes layers (since some have been deselected) + for layer in layers.keys(): + layer.refresh()