Skip to content

Commit

Permalink
Add shape selector widget
Browse files Browse the repository at this point in the history
  • Loading branch information
multimeric committed Oct 10, 2024
1 parent 88e210c commit 917c34f
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 2 deletions.
6 changes: 4 additions & 2 deletions plugin/napari_lattice/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from lls_core.models.deskew import DefinedPixelSizes
from lls_core.models.output import SaveFileType
from lls_core.workflow import workflow_from_path
from magicclass import FieldGroup, MagicTemplate, field, magicclass, set_design
from magicclass import FieldGroup, MagicTemplate, field, magicclass, set_design, vfield
from magicclass.fields import MagicField
from magicclass.widgets import ComboBox, Label, Widget
from napari.layers import Image, Shapes
Expand All @@ -32,6 +32,7 @@
from qtpy.QtWidgets import QTabWidget
from strenum import StrEnum
from napari_lattice.parent_connect import connect_parent
from plugin.napari_lattice.shape_selector import ShapeSelector

if TYPE_CHECKING:
from magicgui.widgets.bases import RangedWidget
Expand Down Expand Up @@ -429,7 +430,8 @@ class CroppingFields(NapariFieldGroup):
This is to support the workflow of performing a preview deskew and using that to calculate the cropping coordinates.
"""), widget_type="Label")
fields_enabled = field(False, label="Enabled")
shapes= field(List[Shapes], widget_type="Select", label = "ROI Shape Layers").with_options(choices=lambda _x, _y: get_layers(Shapes))

shapes= vfield(ShapeSelector)
z_range = field(Tuple[int, int]).with_options(
label = "Z Range",
value = (0, 1),
Expand Down
121 changes: 121 additions & 0 deletions plugin/napari_lattice/shape_selector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Iterator, Tuple, TYPE_CHECKING
from magicclass import field, magicclass
from magicgui.widgets import Select
from napari.layers import Shapes
from napari.components.layerlist import LayerList

from plugin.napari_lattice.utils import get_viewer

if TYPE_CHECKING:
from napari.utils.events.event import Event
from numpy.typing import NDArray

@dataclass
class Shape:
"""
Holds data about a single shape within a Shapes layer
"""
layer: Shapes
index: int

def __str__(self) -> str:
return f"{self.layer.name} {self.index}"

def get_array(self) -> NDArray:
return self.layer.data[self.index]

@magicclass
class ShapeSelector:
def __init__(self, *args, **kwargs) -> None:
# Needed to handle extra kwargs
pass

def _get_shape_choices(self, widget: Select | None = None) -> Iterator[Tuple[str, Shape]]:
"""
Returns the choices to use for the Select box
"""
viewer = get_viewer()
for layer in viewer.layers:
if isinstance(layer, Shapes):
for index in layer.features.index:
result = Shape(layer=layer, index=index)
yield str(result), result

def _on_selection_change(self, event: Event) -> None:
"""
Triggered when the user clicks on one or more shapes.
The widget is then updated to synchronise
"""
source: Shapes = event.source
selection: list[Shape] = []
for index in source.selected_data:
selection.append(Shape(layer=source, index=index))
self.shapes.value = selection

def _connect_shapes(self, shapes: Shapes) -> None:
"""
Called on a newly discovered `Shapes` layer.
Listens to events on that layer that we are interested in.
"""
shapes.events.data.connect(self._on_shape_change)
shapes.events.highlight.connect(self._on_selection_change)
# shapes.events.current_properties.connect(self._on_selection_change)

def _on_shape_change(self, event: Event) -> None:
"""
Triggered whenever a shape layer changes.
Resets the select box options
"""
if isinstance(event.source, Shapes):
self.shapes.reset_choices()

def _on_layer_add(self, event: Event) -> None:
"""
Triggered whenever a new layer is inserted.
Ensures we listen for shape changes to that new layer
"""
if isinstance(event.source, LayerList):
for layer in event.source:
if isinstance(layer, Shapes):
self._connect_shapes(layer)

def __post_init__(self) -> None:
"""
Whenever a new layer is inserted
"""
viewer = get_viewer()

# Listen for new layers
viewer.layers.events.inserted.connect(self._on_layer_add)

# Watch current layers
for layer in viewer.layers:
if isinstance(layer, Shapes):
self._connect_shapes(layer)

shapes = field(Select, options={"choices": _get_shape_choices})

# values is a list[Shape], but if we use the correct annotation it breaks magicclass
@shapes.connect
def _widget_changed(self, values: list) -> None:
"""
Triggered when the plugin widget is changed.
We then synchronise the Napari shape selection with it.
"""
layers: set[Shapes] = set()
indices: set[int] = set()
value: Shape

for value in values:
layers.add(value.layer)
indices.add(value.index)

if len(layers) > 1:
raise Exception("Shapes from multiple layers selected. This shouldn't be possible")

if layers:
layer = layers.pop()
layer.selected_data = indices
layer.refresh()

0 comments on commit 917c34f

Please sign in to comment.