diff --git a/plugin/napari_lattice/fields.py b/plugin/napari_lattice/fields.py index 2360c69f..b2fc7a3b 100644 --- a/plugin/napari_lattice/fields.py +++ b/plugin/napari_lattice/fields.py @@ -20,7 +20,7 @@ from lls_core.models.deskew import DefinedPixelSizes from lls_core.models.output import SaveFileType from lls_core.workflow import workflow_from_path -from magicclass import FieldGroup, MagicTemplate, field, magicclass, set_design +from magicclass import FieldGroup, MagicTemplate, field, magicclass, set_design, vfield from magicclass.fields import MagicField from magicclass.widgets import ComboBox, Label, Widget from napari.layers import Image, Shapes @@ -32,6 +32,7 @@ from qtpy.QtWidgets import QTabWidget from strenum import StrEnum from napari_lattice.parent_connect import connect_parent +from plugin.napari_lattice.shape_selector import ShapeSelector if TYPE_CHECKING: from magicgui.widgets.bases import RangedWidget @@ -429,7 +430,8 @@ class CroppingFields(NapariFieldGroup): This is to support the workflow of performing a preview deskew and using that to calculate the cropping coordinates. """), widget_type="Label") fields_enabled = field(False, label="Enabled") - shapes= field(List[Shapes], widget_type="Select", label = "ROI Shape Layers").with_options(choices=lambda _x, _y: get_layers(Shapes)) + + shapes= vfield(ShapeSelector) z_range = field(Tuple[int, int]).with_options( label = "Z Range", value = (0, 1), diff --git a/plugin/napari_lattice/shape_selector.py b/plugin/napari_lattice/shape_selector.py new file mode 100644 index 00000000..81e966e0 --- /dev/null +++ b/plugin/napari_lattice/shape_selector.py @@ -0,0 +1,121 @@ +from __future__ import annotations +from dataclasses import dataclass +from typing import Iterator, Tuple, TYPE_CHECKING +from magicclass import field, magicclass +from magicgui.widgets import Select +from napari.layers import Shapes +from napari.components.layerlist import LayerList + +from plugin.napari_lattice.utils import get_viewer + +if TYPE_CHECKING: + from napari.utils.events.event import Event + from numpy.typing import NDArray + +@dataclass +class Shape: + """ + Holds data about a single shape within a Shapes layer + """ + layer: Shapes + index: int + + def __str__(self) -> str: + return f"{self.layer.name} {self.index}" + + def get_array(self) -> NDArray: + return self.layer.data[self.index] + +@magicclass +class ShapeSelector: + def __init__(self, *args, **kwargs) -> None: + # Needed to handle extra kwargs + pass + + def _get_shape_choices(self, widget: Select | None = None) -> Iterator[Tuple[str, Shape]]: + """ + Returns the choices to use for the Select box + """ + viewer = get_viewer() + for layer in viewer.layers: + if isinstance(layer, Shapes): + for index in layer.features.index: + result = Shape(layer=layer, index=index) + yield str(result), result + + def _on_selection_change(self, event: Event) -> None: + """ + Triggered when the user clicks on one or more shapes. + The widget is then updated to synchronise + """ + source: Shapes = event.source + selection: list[Shape] = [] + for index in source.selected_data: + selection.append(Shape(layer=source, index=index)) + self.shapes.value = selection + + def _connect_shapes(self, shapes: Shapes) -> None: + """ + Called on a newly discovered `Shapes` layer. + Listens to events on that layer that we are interested in. + """ + shapes.events.data.connect(self._on_shape_change) + shapes.events.highlight.connect(self._on_selection_change) + # shapes.events.current_properties.connect(self._on_selection_change) + + def _on_shape_change(self, event: Event) -> None: + """ + Triggered whenever a shape layer changes. + Resets the select box options + """ + if isinstance(event.source, Shapes): + self.shapes.reset_choices() + + def _on_layer_add(self, event: Event) -> None: + """ + Triggered whenever a new layer is inserted. + Ensures we listen for shape changes to that new layer + """ + if isinstance(event.source, LayerList): + for layer in event.source: + if isinstance(layer, Shapes): + self._connect_shapes(layer) + + def __post_init__(self) -> None: + """ + Whenever a new layer is inserted + """ + viewer = get_viewer() + + # Listen for new layers + viewer.layers.events.inserted.connect(self._on_layer_add) + + # Watch current layers + for layer in viewer.layers: + if isinstance(layer, Shapes): + self._connect_shapes(layer) + + shapes = field(Select, options={"choices": _get_shape_choices}) + + # values is a list[Shape], but if we use the correct annotation it breaks magicclass + @shapes.connect + def _widget_changed(self, values: list) -> None: + """ + Triggered when the plugin widget is changed. + We then synchronise the Napari shape selection with it. + """ + layers: set[Shapes] = set() + indices: set[int] = set() + value: Shape + + for value in values: + layers.add(value.layer) + indices.add(value.index) + + if len(layers) > 1: + raise Exception("Shapes from multiple layers selected. This shouldn't be possible") + + if layers: + layer = layers.pop() + layer.selected_data = indices + layer.refresh()