diff --git a/core/lls_core/models/results.py b/core/lls_core/models/results.py index 6acb3701..d8781646 100644 --- a/core/lls_core/models/results.py +++ b/core/lls_core/models/results.py @@ -3,7 +3,7 @@ from pathlib import Path from typing import Iterable, Optional, Tuple, Union, cast, TYPE_CHECKING, overload -from typing_extensions import Generic, TypeVar +from typing_extensions import Generic, TypeVar, TypeAlias from pydantic.v1 import BaseModel, NonNegativeInt, Field from lls_core.types import ArrayLike, is_arraylike from lls_core.utils import make_filename_suffix @@ -75,6 +75,19 @@ class ImageSlices(ProcessedSlices[ArrayLike]): # This re-definition of the type is helpful for `mkdocs` slices: Iterable[ProcessedSlice[ArrayLike]] = Field(description="Iterable of result slices. For a given slice, you can access the image data through the `slice.data` property, which is a numpy-like array.") + def roi_previews(self) -> Iterable[ArrayLike]: + """ + Extracts a single 3D image for each ROI + """ + import numpy as np + def _preview(slices: Iterable[ProcessedSlice[ArrayLike]]) -> ArrayLike: + for slice in slices: + return slice.data + raise Exception("This ROI has no images. This shouldn't be possible") + + for roi_index, slices in groupby(self.slices, key=lambda slice: slice.roi_index): + yield _preview(slices) + def save_image(self): """ Saves result slices to disk @@ -124,7 +137,8 @@ def save(self) -> Path: else: return self.data -class WorkflowSlices(ProcessedSlices[Tuple[RawWorkflowOutput, ...]]): +MaybeTupleRawWorkflowOutput: TypeAlias = Union[Tuple[RawWorkflowOutput], RawWorkflowOutput] +class WorkflowSlices(ProcessedSlices[MaybeTupleRawWorkflowOutput]): """ The counterpart of `ImageSlices`, but for workflow outputs. This is needed because workflows have vastly different outputs that may include regular @@ -132,7 +146,7 @@ class WorkflowSlices(ProcessedSlices[Tuple[RawWorkflowOutput, ...]]): """ # This re-definition of the type is helpful for `mkdocs` - slices: Iterable[ProcessedSlice[Tuple[RawWorkflowOutput, ...]]] = Field(description="Iterable of raw workflow results, the exact nature of which is determined by the author of the workflow. Not typically useful directly, and using he result of `.process()` is recommended instead.") + slices: Iterable[ProcessedSlice[MaybeTupleRawWorkflowOutput]] = Field(description="Iterable of raw workflow results, the exact nature of which is determined by the author of the workflow. Not typically useful directly, and using he result of `.process()` is recommended instead.") def process(self) -> Iterable[ProcessedWorkflowOutput]: """ @@ -189,16 +203,20 @@ def process(self) -> Iterable[ProcessedWorkflowOutput]: else: yield ProcessedWorkflowOutput(index=i, roi_index=roi, data=pd.DataFrame(element), lattice_data=self.lattice_data) - def extract_preview(self) -> NDArray: + def roi_previews(self) -> Iterable[NDArray]: """ - Extracts a single 3D image for previewing purposes + Extracts a single 3D image for each ROI """ import numpy as np - for slice in self.slices: - for value in slice.as_tuple(): - if is_arraylike(value): - return np.asarray(value) - raise Exception("No image was returned from this workflow") + def _preview(slices: Iterable[ProcessedSlice[MaybeTupleRawWorkflowOutput]]) -> NDArray: + for slice in slices: + for value in slice.as_tuple(): + if is_arraylike(value): + return np.asarray(value) + raise Exception("This ROI has no images. This shouldn't be possible") + + for roi_index, slices in groupby(self.slices, key=lambda slice: slice.roi_index): + yield _preview(slices) def save(self) -> Iterable[Path]: """ diff --git a/core/tests/test_workflows.py b/core/tests/test_workflows.py index 9c1f2cfb..aceb5a96 100644 --- a/core/tests/test_workflows.py +++ b/core/tests/test_workflows.py @@ -107,8 +107,9 @@ def test_sum_preview(rbc_tiny: Path): workflow = "core/tests/workflows/binarisation/workflow.yml", save_dir = tmpdir ) - preview = params.process_workflow().extract_preview() - np.sum(preview, axis=(1, 2)) + previews = list(params.process_workflow().roi_previews()) + assert len(previews) == 1, "There should be 1 preview when cropping is disabled" + assert previews[0].ndim == 3, "A preview should be a 3D image" def test_crop_workflow(rbc_tiny: Path): # Tests that crop workflows only process each ROI lazily diff --git a/plugin/napari_lattice/dock_widget.py b/plugin/napari_lattice/dock_widget.py index c57ef838..7cf0aa69 100644 --- a/plugin/napari_lattice/dock_widget.py +++ b/plugin/napari_lattice/dock_widget.py @@ -149,20 +149,19 @@ def preview(self, header: str, time: int, channel: int): lattice.dy, lattice.dx ) - preview: ArrayLike + previews: Iterable[ArrayLike] # We extract the first available image to use as a preview # This works differently for workflows and non-workflows if lattice.workflow is None: - for slice in lattice.process().slices: - preview = slice.data - break + previews = lattice.process().roi_previews() else: - preview = lattice.process_workflow().extract_preview() + previews = lattice.process_workflow().roi_previews() - self.parent_viewer.add_image(preview, scale=scale, name="Napari Lattice Preview") - max_z = np.argmax(np.sum(preview, axis=(1, 2))) - self.parent_viewer.dims.set_current_step(0, max_z) + for preview in previews: + self.parent_viewer.add_image(preview, scale=scale, name="Napari Lattice Preview") + max_z = np.argmax(np.sum(preview, axis=(1, 2))) + self.parent_viewer.dims.set_current_step(0, max_z) @set_design(text="Save") diff --git a/plugin/napari_lattice/fields.py b/plugin/napari_lattice/fields.py index c46e1b59..ecce2829 100644 --- a/plugin/napari_lattice/fields.py +++ b/plugin/napari_lattice/fields.py @@ -20,7 +20,7 @@ from lls_core.models.deskew import DefinedPixelSizes from lls_core.models.output import SaveFileType from lls_core.workflow import workflow_from_path -from magicclass import FieldGroup, MagicTemplate, field, magicclass, set_design +from magicclass import FieldGroup, MagicTemplate, field, magicclass, set_design, vfield from magicclass.fields import MagicField from magicclass.widgets import ComboBox, Label, Widget from napari.layers import Image, Shapes @@ -32,9 +32,11 @@ from qtpy.QtWidgets import QTabWidget from strenum import StrEnum from napari_lattice.parent_connect import connect_parent +from napari_lattice.shape_selector import ShapeSelector if TYPE_CHECKING: from magicgui.widgets.bases import RangedWidget + from numpy.typing import NDArray logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -429,15 +431,8 @@ class CroppingFields(NapariFieldGroup): This is to support the workflow of performing a preview deskew and using that to calculate the cropping coordinates. """), widget_type="Label") fields_enabled = field(False, label="Enabled") - shapes= field(List[Shapes], widget_type="Select", label = "ROI Shape Layers").with_options(choices=lambda _x, _y: get_layers(Shapes)) - z_range = field(Tuple[int, int]).with_options( - label = "Z Range", - value = (0, 1), - options = dict( - min = 0, - ), - ) - errors = field(Label).with_options(label="Errors") + + shapes= vfield(ShapeSelector) @set_design(text="Import ROI") def import_roi(self, path: Path): @@ -455,7 +450,16 @@ def new_crop_layer(self): from napari_lattice.utils import get_viewer shapes = get_viewer().add_shapes(name="Napari Lattice Crop") shapes.mode = "ADD_RECTANGLE" - self.shapes.value += [shapes] + # self.shapes.value += [shapes] + + z_range = field(Tuple[int, int]).with_options( + label = "Z Range", + value = (0, 1), + options = dict( + min = 0, + ), + ) + errors = field(Label).with_options(label="Errors") @connect_parent("deskew_fields.img_layer") def _on_image_changed(self, field: MagicField): @@ -486,9 +490,12 @@ def _make_model(self) -> Optional[CropParams]: if self.fields_enabled.value: deskew = self._get_deskew() rois = [] - for shape_layer in self.shapes.value: - for x in shape_layer.data: - rois.append(Roi.from_array(x / deskew.dy)) + for shape in self.shapes.shapes.value: + # The Napari shape is an array with 2 dimensions. + # Each column is an axis and each row is a point defining the shape + # We drop all but the last two axes, giving us a 2D shape with XY coordinates + array: NDArray = shape.get_array()[..., -2:] / deskew.dy + rois.append(Roi.from_array(array)) return CropParams( # Convert from the input image space to the deskewed image space diff --git a/plugin/napari_lattice/shape_selection.py b/plugin/napari_lattice/shape_selection.py new file mode 100644 index 00000000..0a1fa6e3 --- /dev/null +++ b/plugin/napari_lattice/shape_selection.py @@ -0,0 +1,58 @@ +from __future__ import annotations +from napari.utils.events import EventEmitter, Event +from napari.layers import Shapes + +class ShapeLayerChangedEvent(Event): + """ + Event triggered when the shape layer selection changes. + """ + +class ShapeSelectionListener(EventEmitter): + """ + Manages shape selection events for a given Shapes layer. + + Examples: + This example code will open the viewer with an empty shape layer. + Any selection changes to that layer will trigger a notification popup. + >>> from napari import Viewer + >>> from napari.layers import Shapes + >>> viewer = Viewer() + >>> shapes = viewer.add_shapes() + >>> shape_selection = ShapeSelection(shapes) + >>> shape_selection.connect(lambda event: print("Shape selection changed!")) + """ + last_selection: set[int] + layer: Shapes + + def __init__(self, layer) -> None: + """ + Initializes the ShapeSelection with the given Shapes layer. + + Parameters: + layer: The Shapes layer to listen to. + """ + super().__init__(source=layer, event_class=ShapeLayerChangedEvent, type_name="shape_layer_selection_changed") + self.layer = layer + self.last_selection = set() + layer.events.highlight.connect(self._on_highlight) + + def _on_highlight(self, event) -> None: + new_selection = self.layer.selected_data + if new_selection != self.last_selection: + self() + self.last_selection = set(new_selection) + +def test_script(): + """ + Demo for testing the event behaviour. + """ + from napari import run, Viewer + from napari.utils.notifications import show_info + viewer = Viewer() + shapes = viewer.add_shapes() + event = ShapeSelectionListener(shapes) + event.connect(lambda x: show_info("Shape selection changed!")) + run() + +if __name__ == "__main__": + test_script() diff --git a/plugin/napari_lattice/shape_selector.py b/plugin/napari_lattice/shape_selector.py new file mode 100644 index 00000000..1cd63084 --- /dev/null +++ b/plugin/napari_lattice/shape_selector.py @@ -0,0 +1,168 @@ +from __future__ import annotations +from dataclasses import dataclass +from typing import Iterator, Tuple, TYPE_CHECKING +from magicclass import field, magicclass, set_design +from magicclass.fields._fields import MagicField +from magicgui.widgets import Select, Button +from napari.layers import Shapes +from napari.components.layerlist import LayerList +from collections import defaultdict +from contextlib import contextmanager +from napari.viewer import current_viewer +from napari_lattice.shape_selection import ShapeSelectionListener + +if TYPE_CHECKING: + from napari.utils.events.event import Event + from numpy.typing import NDArray + +@dataclass(frozen=True, eq=True) +class Shape: + """ + Holds data about a single shape within a Shapes layer + """ + layer: Shapes + index: int + + def __str__(self) -> str: + return f"{self.layer.name}: Shape {self.index}" + + def get_array(self) -> NDArray: + return self.layer.data[self.index] + +@magicclass +class ShapeSelector: + + def _get_shape_choices(self, widget: Select | None = None) -> Iterator[Tuple[str, Shape]]: + """ + Returns the choices to use for the Select box + """ + viewer = current_viewer() + if viewer is not None: + for layer in viewer.layers: + if isinstance(layer, Shapes): + for index in layer.features.index: + result = Shape(layer=layer, index=index) + yield str(result), result + + _blocked: bool + _listeners: list[ShapeSelectionListener] = [] + shapes: MagicField[Select] = field(Select, options={"choices": _get_shape_choices, "label": "ROIs"}) + + @set_design(text="Select All") + def select_all(self) -> None: + self.shapes.value = self.shapes.choices + + @set_design(text="Deselect All") + def deselect_all(self) -> None: + self.shapes.value = [] + + def __init__(self, enabled: bool, *args, **kwargs) -> None: + self._blocked = False + self.enabled = enabled + + @contextmanager + def _block(self): + """ + Context manager that prevents event handlers recursively calling each other. + Yields a boolean which means functions should proceed if `True`, or return immediately if `False` + """ + if self._blocked: + yield False + else: + self._blocked = True + yield True + self._blocked = False + + def _on_selection_change(self, event: Event) -> None: + """ + Triggered when the user clicks on one or more shapes. + The widget is then updated to synchronise + """ + # Prevent recursion + with self._block() as execute: + if not execute: + return + + source: Shapes = event.source + selection: list[Shape] = [] + for index in source.selected_data: + shape = Shape(layer=source, index=index) + if shape not in self.shapes.choices: + # If we ever encounter a shape that isn't a legal choice, we have to terminate to avoid an error + # This seems to happen on Windows only due to the order of events firing + return + selection.append(shape) + self.shapes.value = selection + + def _connect_shapes(self, shapes: Shapes) -> None: + """ + Called on a newly discovered `Shapes` layer. + Listens to events on that layer that we are interested in. + """ + shapes.events.data.connect(self._on_shape_change) + listener = ShapeSelectionListener(shapes) + self._listeners.append(listener) + listener.connect(self._on_selection_change) + + def _on_shape_change(self, event: Event) -> None: + """ + Triggered whenever a shape layer changes. + Resets the select box options + """ + if isinstance(event.source, Shapes): + self.shapes.reset_choices() + + def _on_layer_add(self, event: Event) -> None: + """ + Triggered whenever a new layer is inserted. + Ensures we listen for shape changes to that new layer + """ + if isinstance(event.source, LayerList): + for layer in event.source: + if isinstance(layer, Shapes): + self._connect_shapes(layer) + + def __post_init__(self) -> None: + """ + Whenever a new layer is inserted + """ + viewer = current_viewer() + + if viewer is not None: + + # Listen for new layers + viewer.layers.events.inserted.connect(self._on_layer_add) + + # Watch current layers + for layer in viewer.layers: + if isinstance(layer, Shapes): + self._connect_shapes(layer) + + # values is a list[Shape], but if we use the correct annotation it breaks magicclass + @shapes.connect + def _widget_changed(self, values: list) -> None: + """ + Triggered when the plugin widget is changed. + We then synchronise the Napari shape selection with it. + """ + viewer = current_viewer() + if viewer is None: + return + + with self._block() as execute: + if not execute: + return + layers: dict[Shapes, list[int]] = {layer: [] for layer in viewer.layers if isinstance(layer, Shapes)} + value: Shape + + # Find the current selection for each layer + for value in values: + layers[value.layer].append(value.index) + + # For each layer, set the appropriate selection (this can't be done incrementally) + for layer, shapes in layers.items(): + layer.selected_data = shapes + + # Re-calculate the selections for all Shapes layers (since some have been deselected) + for layer in layers.keys(): + layer.refresh()