
Merge pull request #85 from multimeric/fix-66-v2
Fix 66 v2
multimeric authored Oct 4, 2024
2 parents f57b489 + 3332bc6 commit 061c94a
Showing 7 changed files with 233 additions and 146 deletions.
216 changes: 122 additions & 94 deletions core/lls_core/models/lattice_data.py
@@ -1,6 +1,5 @@
from __future__ import annotations
# class for initializing lattice data and setting metadata
# TODO: handle scenes
from typing import Tuple, cast
from pydantic.v1 import Field, root_validator, validator
from dask.array.core import Array as DaskArray

@@ -9,19 +8,17 @@
from lls_core.llsz_core import crop_volume_deskew
from lls_core.models.crop import CropParams
from lls_core.models.deconvolution import DeconvolutionParams
from lls_core.models.output import OutputParams, SaveFileType
from lls_core.models.results import WorkflowSlices
from lls_core.models.utils import ignore_keyerror
from lls_core.types import ArrayLike
from lls_core.models.deskew import DeskewParams
from lls_core.models.output import OutputParams, SaveFileType
from napari_workflows import Workflow

from lls_core.workflow import get_workflow_output_name, workflow_set

if TYPE_CHECKING:
from lls_core.models.results import ImageSlice, ImageSlices, ProcessedSlice
from lls_core.writers import Writer
from xarray import DataArray
from lls_core.workflow import RawWorkflowOutput
from lls_core.types import ArrayLike
from lls_core.models.results import WorkflowSlices

import logging

@@ -54,6 +51,11 @@ class LatticeData(OutputParams, DeskewParams):
cli_description="Path to a JSON file specifying a napari_workflow-compatible workflow to add lightsheet processing onto"
)

progress_bar: bool = Field(
default = True,
description = "If true, show progress bars"
)

@root_validator(pre=True)
def read_image(cls, values: dict):
from lls_core.types import is_pathlike
@@ -87,6 +89,7 @@ def parse_workflow(cls, v: Any):

@validator("workflow", pre=False)
def validate_workflow(cls, v: Optional[Workflow]):
from lls_core.workflow import get_workflow_output_name
if v is not None:
if not "deskewed_image" in v.roots():
raise ValueError("The workflow has no deskewed_image parameter, so is not compatible with the lls processing.")
@@ -98,6 +101,7 @@ def validate_workflow(cls, v: Optional[Workflow]):

@validator("crop")
def default_z_range(cls, v: Optional[CropParams], values: dict) -> Optional[CropParams]:
from lls_core.models.utils import ignore_keyerror
if v is None:
return v
with ignore_keyerror():
@@ -127,6 +131,7 @@ def parse_time_range(cls, v: Any, values: dict) -> Any:
"""
Sets the default time range if undefined
"""
from lls_core.models.utils import ignore_keyerror
# This skips the conversion if no image was provided, to ensure a more
# user-friendly error is provided, namely "image was missing"
from collections.abc import Sequence
@@ -145,7 +150,9 @@ def parse_channel_range(cls, v: Any, values: dict) -> Any:
"""
Sets the default channel range if undefined
"""
from lls_core.models.utils import ignore_keyerror
from collections.abc import Sequence

with ignore_keyerror():
default_start = 0
default_end = values["input_image"].sizes["C"]
@@ -161,6 +168,7 @@ def disjoint_time_range(cls, v: range, values: dict):
"""
Validates that the time range is within the range of channels in our array
"""
from lls_core.models.utils import ignore_keyerror
with ignore_keyerror():
max_time = values["input_image"].sizes["T"]
if v.start < 0:
@@ -175,6 +183,7 @@ def disjoint_channel_range(cls, v: range, values: dict):
"""
Validates that the channel range is within the range of channels in our array
"""
from lls_core.models.utils import ignore_keyerror
with ignore_keyerror():
max_channel = values["input_image"].sizes["C"]
if v.start < 0:
@@ -185,6 +194,7 @@ def disjoint_channel_range(cls, v: range, values: dict):

@validator("channel_range")
def channel_range_subset(cls, v: Optional[range], values: dict):
from lls_core.models.utils import ignore_keyerror
with ignore_keyerror():
if v is not None and (min(v) < 0 or max(v) > values["input_image"].sizes["C"]):
raise ValueError("The output channel range must be a subset of the total available channels")
@@ -198,6 +208,7 @@ def time_range_subset(cls, v: Optional[range], values: dict):

@validator("deconvolution")
def check_psfs(cls, v: Optional[DeconvolutionParams], values: dict):
from lls_core.models.utils import ignore_keyerror
if v is None:
return v
with ignore_keyerror():
@@ -229,28 +240,37 @@ def slice_data(self, time: int, channel: int) -> DataArray:

return self.input_image.isel(T=time, C=channel)

def iter_slices(self) -> Iterable[ProcessedSlice[ArrayLike]]:
def iter_roi_indices(self) -> Iterable[Optional[int]]:
"""
Yields array slices for each time and channel of interest.
Params:
progress: If the progress bar is enabled
Yields region of interest indices, with a progress bar.
This yields `None` exactly once if cropping is disabled, for compatibility.
"""
from tqdm import tqdm
if self.cropping_enabled and self.crop is not None:
for index in tqdm(self.crop.roi_subset, desc="ROI", position=0, disable=not self.progress_bar):
yield index
else:
yield None

Returns:
An iterable of tuples. Each tuple contains (time_index, time, channel_index, channel, slice)
def iter_slices(self) -> Iterable[ProcessedSlice[ArrayLike]]:
"""
Yields 3D array slices for each time, channel and region of interest.
These are guaranteed to iterate in the following order: ROI (slowest), timepoint, channel (fastest)
"""
from lls_core.models.results import ProcessedSlice
from tqdm import tqdm

for time_idx, time in tqdm(enumerate(self.time_range), desc="Timepoints", total=len(self.time_range)):
for ch_idx, ch in tqdm(enumerate(self.channel_range), desc="Channels", total=len(self.channel_range), leave=False):
yield ProcessedSlice(
data=self.slice_data(time=time, channel=ch),
time_index=time_idx,
time= time,
channel_index=ch_idx,
channel=ch,
)
for roi_index in self.iter_roi_indices():
for time_idx, time in tqdm(enumerate(self.time_range), desc="Timepoints", total=len(self.time_range), disable=not self.progress_bar, leave=not self.cropping_enabled, position=1 if self.cropping_enabled else 0):
for ch_idx, ch in tqdm(enumerate(self.channel_range), desc="Channels", total=len(self.channel_range), leave=False, disable=not self.progress_bar, position=2 if self.cropping_enabled else 1):
yield ProcessedSlice(
data=self.slice_data(time=time, channel=ch),
roi_index=roi_index,
time_index=time_idx,
time=time,
channel_index=ch_idx,
channel=ch,
)

@property
def n_slices(self) -> int:
@@ -267,43 +287,54 @@ def iter_sublattices(self, update_with: dict = {}) -> Iterable[ProcessedSlice[La
update_with: dictionary of arguments to update the generated lattices with
"""
for subarray in self.iter_slices():

if subarray.roi_index is not None and self.crop is not None:
crop = self.crop.copy_validate(update = {
"roi_subset": [subarray.roi_index]
})
else:
crop = None
new_lattice = self.copy_validate(update={
"input_image": subarray.data,
"time_range": range(1),
"channel_range": range(1),
"crop": crop,
**update_with
})
yield subarray.copy_with_data( new_lattice)
yield subarray.copy_with_data(new_lattice)

def generate_workflows(
self,
) -> Iterable[ProcessedSlice[Workflow]]:
"""
Yields copies of the input workflow, modified with the addition of deskewing and optionally,
cropping and deconvolution
"""
if self.workflow is None:
return

from copy import copy
# We make a copy of the lattice for each slice, each of which has no associated workflow
for lattice_slice in self.iter_sublattices(update_with={"workflow": None}):
user_workflow = copy(self.workflow)
# We add a step whose result is called "input_img" that outputs a 2D image slice
user_workflow.set(
"deskewed_image",
LatticeData.process_into_image,
lattice_slice.data
"""
Yields copies of the input workflow, modified with the addition of deskewing and optionally,
cropping and deconvolution
"""
from lls_core.workflow import workflow_set

if self.workflow is None:
return

from copy import copy
# We make a copy of the lattice for each slice, each of which has no associated workflow
# Also hide the progress bar for each sublattice, because we already have a global progress bar at this point
for lattice_slice in self.iter_sublattices(update_with={"workflow": None, "progress_bar": False}):
user_workflow = copy(self.workflow)
# We add a step whose result is called "input_img" that outputs a 2D image slice
user_workflow.set(
"deskewed_image",
LatticeData.process_into_image,
lattice_slice.data
)
# Also add channel metadata to the workflow
for key in {"channel", "channel_index", "time", "time_index", "roi_index"}:
workflow_set(
user_workflow,
key,
getattr(lattice_slice, key)
)
# Also add channel metadata to the workflow
for key in {"channel", "channel_index", "time", "time_index", "roi_index"}:
workflow_set(
user_workflow,
key,
getattr(lattice_slice, key)
)
# The user can use any of these arguments as inputs to their tasks
yield lattice_slice.copy_with_data(user_workflow)
# The user can use any of these arguments as inputs to their tasks
yield lattice_slice.copy_with_data(user_workflow)

def check_incomplete_acquisition(self, volume: ArrayLike, time_point: int, channel: int):
"""
@@ -332,44 +363,40 @@ def _process_crop(self) -> Iterable[ImageSlice]:
"""
Yields processed image slices with cropping enabled
"""
from tqdm import tqdm
if self.crop is None:
raise Exception("This function can only be called when crop is set")

# We have an extra level of iteration for the crop path: iterating over each ROI
for roi_index, roi in enumerate(tqdm(self.crop.selected_rois, desc="ROI", position=0)):
# pass arguments for save tiff, callable and function arguments
logger.info(f"Processing ROI {self.crop.roi_subset[roi_index]}")

for slice in self.iter_slices():
deconv_args: dict[Any, Any] = {}
if self.deconvolution is not None:
deconv_args = dict(
num_iter = self.deconvolution.psf_num_iter,
psf = self.deconvolution.psf[slice.channel].to_numpy(),
decon_processing=self.deconvolution.decon_processing
)

for slice in self.iter_slices():
roi_index = cast(int, slice.roi_index)
roi = self.crop.roi_list[roi_index]
deconv_args: dict[Any, Any] = {}
if self.deconvolution is not None:
deconv_args = dict(
num_iter = self.deconvolution.psf_num_iter,
psf = self.deconvolution.psf[slice.channel].to_numpy(),
decon_processing=self.deconvolution.decon_processing
)

yield slice.copy(update={
"data": crop_volume_deskew(
original_volume=slice.data,
deconvolution=self.deconv_enabled,
get_deskew_and_decon=False,
debug=False,
roi_shape=list(roi),
linear_interpolation=True,
voxel_size_x=self.dx,
voxel_size_y=self.dy,
voxel_size_z=self.dz,
angle_in_degrees=self.angle,
deskewed_volume=self.deskewed_volume,
z_start=self.crop.z_range[0],
z_end=self.crop.z_range[1],
**deconv_args
),
"roi_index": self.crop.roi_subset[roi_index]
})
yield slice.copy(update={
"data": crop_volume_deskew(
original_volume=slice.data,
deconvolution=self.deconv_enabled,
get_deskew_and_decon=False,
debug=False,
roi_shape=list(roi),
linear_interpolation=True,
voxel_size_x=self.dx,
voxel_size_y=self.dy,
voxel_size_z=self.dz,
angle_in_degrees=self.angle,
deskewed_volume=self.deskewed_volume,
z_start=self.crop.z_range[0],
z_end=self.crop.z_range[1],
**deconv_args
),
"roi_index": roi_index
})

def _process_non_crop(self) -> Iterable[ImageSlice]:
"""
Yields processed image slices without cropping
@@ -417,19 +444,20 @@ def process_workflow(self) -> WorkflowSlices:
"""
Runs the workflow on each slice and returns the workflow results
"""
from lls_core.workflow import get_workflow_output_name
from lls_core.models.results import WorkflowSlices
from lls_core.models.utils import as_tuple

WorkflowSlices.update_forward_refs(LatticeData=LatticeData)
outputs: list[ProcessedSlice[Any]] = []
for workflow in self.generate_workflows():
outputs.append(
workflow.copy_with_data(
# Evaluates the workflow here.
workflow.data.get(get_workflow_output_name(workflow.data))
)
)

def _generator() -> Iterable[ProcessedSlice[Tuple[RawWorkflowOutput, ...]]]:
for workflow in self.generate_workflows():
# Evaluates the workflow here.
result = workflow.data.get(get_workflow_output_name(workflow.data))
yield workflow.copy_with_data(as_tuple(result))

return WorkflowSlices(
slices=outputs,
slices=_generator(),
lattice_data=self
)

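For readers following the iteration change above: `iter_slices` now walks ROI (slowest), then timepoint, then channel (fastest), with nested tqdm bars kept on separate lines via `position` and silenced by the new `progress_bar` field. A minimal standalone sketch of that nesting pattern, using made-up ranges in place of the real `LatticeData` fields:

from tqdm import tqdm

# Stand-ins for self.crop.roi_subset, self.time_range and self.channel_range
roi_subset = [0, 1]       # hypothetical ROI indices
time_range = range(3)     # hypothetical timepoints
channel_range = range(2)  # hypothetical channels
progress_bar = True       # mirrors the new LatticeData.progress_bar field

# ROI is the slowest axis and channel the fastest; position= gives each bar its own
# terminal line, and disable= turns all bars off when progress_bar is False
for roi in tqdm(roi_subset, desc="ROI", position=0, disable=not progress_bar):
    for t in tqdm(time_range, desc="Timepoints", position=1, leave=False, disable=not progress_bar):
        for c in tqdm(channel_range, desc="Channels", position=2, leave=False, disable=not progress_bar):
            pass  # each (roi, t, c) combination corresponds to one ProcessedSlice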
5 changes: 1 addition & 4 deletions core/lls_core/models/output.py
@@ -72,10 +72,7 @@ def make_filepath_df(self, suffix: str, result: DataFrame) -> Path:
"""
Returns a filepath for the non-image data
"""
if isinstance(result, DataFrame):
return self.get_unique_filepath(self.save_dir / Path(self.save_name + suffix).with_suffix(".csv"))

return
return self.get_unique_filepath(self.save_dir / Path(self.save_name + suffix).with_suffix(".csv"))

def get_unique_filepath(self, path: Path) -> Path:
"""
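The `make_filepath_df` change above removes the `isinstance` guard whose fallback returned `None` despite the `-> Path` annotation, so the method now always builds a CSV path from `save_dir`, `save_name` and the suffix before passing it through `get_unique_filepath`. A quick illustration of the pathlib behaviour it relies on, with made-up values:

from pathlib import Path

save_dir = Path("/data/output")   # stand-in for self.save_dir
save_name = "experiment1"         # stand-in for self.save_name
suffix = "_deskewed"

# with_suffix(".csv") appends the extension here because "experiment1_deskewed" has none yet
csv_path = save_dir / Path(save_name + suffix).with_suffix(".csv")
print(csv_path)  # /data/output/experiment1_deskewed.csv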
