diff --git a/.github/workflows/plugin_preview.yml b/.github/workflows/plugin_preview.yml deleted file mode 100644 index bfd5c9ef..00000000 --- a/.github/workflows/plugin_preview.yml +++ /dev/null @@ -1,23 +0,0 @@ -name: napari hub Preview Page # we use this name to find your preview page artifact, so don't change it! -# For more info on this action, see https://github.com/chanzuckerberg/napari-hub-preview-action/blob/main/action.yml - -on: - pull_request: - branches: - - "**" - paths: - - "**/README.md" - -jobs: - preview-page: - name: Preview Page Deploy - runs-on: ubuntu-latest - - steps: - - name: Checkout repo - uses: actions/checkout@v2 - - - name: napari hub Preview Page Builder - uses: chanzuckerberg/napari-hub-preview-action@v0.1.5 - with: - hub-ref: main diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index ac3906f8..71ae1b9e 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -1,105 +1,105 @@ -# This workflows will upload a Python Package using Twine when a release is created -# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries -# For pytest-qt related fixes: https://pytest-qt.readthedocs.io/en/latest/troubleshooting.html#github-actions - -name: tests - -on: - pull_request: - paths-ignore: - - "**/README.md" - push: - paths-ignore: - - "**/README.md" - -jobs: - test: - runs-on: ubuntu-latest - defaults: - run: - shell: bash -l {0} - strategy: - fail-fast: false - matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] - env: - DISPLAY: ":99.0" - steps: - - uses: actions/checkout@v3 - - - uses: conda-incubator/setup-miniconda@v2 - with: - auto-update-conda: true - activate-environment: test - python-version: ${{ matrix.python-version }} - channels: conda-forge - - - name: Conda info - run: conda info - - - name: Save conda location - run: echo "python=$(which python)" >> "$GITHUB_ENV" - - - name: Install core - timeout-minutes: 10 - run: | - conda install -y pyopencl pocl - python --version - pip install --upgrade pip setuptools wheel - pip install --use-pep517 -e './core[testing]' - - - name: Lint core - uses: ./.github/lint - with: - directory: core - python: ${{ env.python }} - - - name: Test core - run: pytest -v --cov=core --cov-report=xml core/ - - - uses: tlambert03/setup-qt-libs@v1 - - - name: Install plugin - timeout-minutes: 10 - run: | - /sbin/start-stop-daemon --start --quiet --pidfile /tmp/custom_xvfb_99.pid --make-pidfile --background --exec /usr/bin/Xvfb -- :99 -screen 0 1920x1200x24 -ac +extension GLX - pip install pytest-qt PyQt5 - pip install -e './plugin[testing]' - - - name: Lint plugin - uses: ./.github/lint - with: - directory: plugin - python: ${{ env.python }} - - - name: Test plugin - run: pytest -v --cov=plugin --cov-report=xml plugin - - - name: Coverage - uses: codecov/codecov-action@v3 - - deploy: - # this will run when you have tagged a commit, starting with "v*" - # and requires that you have put your twine API key in your - # github secrets (see readme for details) - needs: [test] - runs-on: ubuntu-latest - if: contains(github.ref, 'tags') - steps: - - uses: actions/checkout@v3 - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: "3.x" - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install -U setuptools setuptools_scm wheel twine - - name: Build and publish - env: - TWINE_USERNAME: __token__ - TWINE_PASSWORD: ${{ secrets.TWINE_TOKEN }} - run: | - git tag - python setup.py sdist bdist_wheel - twine upload dist/* +# This workflows will upload a Python Package using Twine when a release is created +# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries +# For pytest-qt related fixes: https://pytest-qt.readthedocs.io/en/latest/troubleshooting.html#github-actions + +name: tests + +on: + pull_request: + paths-ignore: + - "**/README.md" + push: + paths-ignore: + - "**/README.md" + +jobs: + test: + runs-on: ubuntu-latest + defaults: + run: + shell: bash -l {0} + strategy: + fail-fast: false + matrix: + python-version: ["3.8", "3.9", "3.10", "3.11"] + env: + DISPLAY: ":99.0" + steps: + - uses: actions/checkout@v3 + + - uses: conda-incubator/setup-miniconda@v2 + with: + auto-update-conda: true + activate-environment: test + python-version: ${{ matrix.python-version }} + channels: conda-forge + + - name: Conda info + run: conda info + + - name: Save conda location + run: echo "python=$(which python)" >> "$GITHUB_ENV" + + - name: Install core + timeout-minutes: 10 + run: | + conda install -y pyopencl pocl + python --version + pip install --upgrade pip setuptools wheel + pip install --use-pep517 -e './core[testing]' + + - name: Lint core + uses: ./.github/lint + with: + directory: core + python: ${{ env.python }} + + - name: Test core + run: pytest -v --cov=core --cov-report=xml core/ + + - uses: tlambert03/setup-qt-libs@v1 + + - name: Install plugin + timeout-minutes: 10 + run: | + /sbin/start-stop-daemon --start --quiet --pidfile /tmp/custom_xvfb_99.pid --make-pidfile --background --exec /usr/bin/Xvfb -- :99 -screen 0 1920x1200x24 -ac +extension GLX + pip install pytest-qt PyQt5 + pip install -e './plugin[testing]' + + - name: Lint plugin + uses: ./.github/lint + with: + directory: plugin + python: ${{ env.python }} + + - name: Test plugin + run: pytest -v --cov=plugin --cov-report=xml plugin + + - name: Coverage + uses: codecov/codecov-action@v3 + + deploy: + # this will run when you have tagged a commit, starting with "v*" + # and requires that you have put your twine API key in your + # github secrets (see readme for details) + needs: [test] + runs-on: ubuntu-latest + if: contains(github.ref, 'tags') + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.x" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -U setuptools setuptools_scm wheel twine + - name: Build and publish + env: + TWINE_USERNAME: __token__ + TWINE_PASSWORD: ${{ secrets.TWINE_TOKEN }} + run: | + git tag + python setup.py sdist bdist_wheel + twine upload dist/* diff --git a/core/lls_core/__init__.py b/core/lls_core/__init__.py index 320f3780..eb2eddf6 100644 --- a/core/lls_core/__init__.py +++ b/core/lls_core/__init__.py @@ -1,26 +1,18 @@ __version__ = "0.2.6" - +from strenum import StrEnum from enum import Enum - -# Initialize configuration options - -#Deskew Direction from pyclesperanto_prototype._tier8._affine_transform_deskew_3d import DeskewDirection +# Initialize configuration options #Choice of Deconvolution -class DeconvolutionChoice(Enum): +class DeconvolutionChoice(StrEnum): cuda_gpu = "cuda_gpu" opencl_gpu = "opencl_gpu" cpu = "cpu" -#Choice of File extension to save -class SaveFileType(Enum): - h5 = "h5" - tiff = "tiff" - #CONFIGURE LOGGING using a dictionary (can also be done with yaml file) import logging.config LOGGING_CONFIG = { diff --git a/core/lls_core/cmds/__main__.py b/core/lls_core/cmds/__main__.py index 66ae8d74..a2fabccf 100644 --- a/core/lls_core/cmds/__main__.py +++ b/core/lls_core/cmds/__main__.py @@ -1,772 +1,235 @@ # lattice_processing.py + # Run processing on command line instead of napari. # Example for deskewing files in a folder # python lattice_processing.py --input /home/pradeep/to_deskew --output /home/pradeep/output_save/ --processing deskew +from __future__ import annotations -import argparse -import os -import glob -import sys -import re -from lls_core.io import save_img, save_img_workflow -from lls_core.lattice_data import lattice_from_aics -from lls_core.utils import read_imagej_roi, get_all_py_files, get_first_last_image_and_task, modify_workflow_task, check_dimensions -from lls_core.llsz_core import crop_volume_deskew -from aicsimageio import AICSImage -from aicsimageio.types import PhysicalPixelSizes -import pyclesperanto_prototype as cle -from tqdm import tqdm -import dask.array as da -from napari_workflows._io_yaml_v1 import load_workflow +from enum import auto from pathlib import Path -import yaml -from typing import Sequence, TYPE_CHECKING - -from lls_core.deconvolution import read_psf -from lls_core import DeskewDirection, DeconvolutionChoice, SaveFileType -from enum import Enum - -from napari_workflows import Workflow - - -# define parser class so as to print help message -class ArgParser(argparse.ArgumentParser): - def error(self, message): - sys.stderr.write('error: %s\n' % message) - self.print_help() - sys.exit(2) - - -class ProcessingOptions(Enum): - deskew = "deskew" - crop = "crop" - workflow = "workflow" - workflow_crop = "workflow_crop" - - -def make_parser(): - """ Parse input arguments""" - parser = argparse.ArgumentParser(description="Lattice Data Analysis") - parser.add_argument('--input', type=str, nargs=1, help="Enter input file", - required=False) # only false if using config file - parser.add_argument('--output', type=str, nargs=1, help="Enter save folder", - required=False) # only false if using config file - parser.add_argument('--skew_direction', type=DeskewDirection, nargs=1, - help="Enter the direction of skew (default is Y)", - action="store", - choices=("Y", "X"), - default=DeskewDirection.Y) - parser.add_argument('--deskew_angle', type=float, nargs=1, - help="Enter the deskew angle (default is 30)", - default=30.0) - parser.add_argument('--processing', type=ProcessingOptions, nargs=1, - help="Enter the processing option: deskew, crop, workflow or workflow_crop", required=False, - action="store", - choices=(ProcessingOptions.deskew, ProcessingOptions.crop, - ProcessingOptions.workflow, ProcessingOptions.workflow_crop), - default=None) - parser.add_argument('--deconvolution', type=DeconvolutionChoice, nargs=1, - help="Specify the device to use for deconvolution. Options are cpu or cuda_gpu", - action="store") - parser.add_argument('--deconvolution_num_iter', type=int, nargs=1, - help="Enter the number of iterations to run Richardson-Lucy deconvolution (default is 10)") - parser.add_argument('--deconvolution_psf', type=str, nargs="+", - help="Enter paths to psf file/s separated by commas or you can enter each path with double quotes") # use + for nargs for flexible no of args - parser.add_argument('--roi_file', type=str, nargs=1, - help="Enter the path to the ROI file for performing cropping (only valid for -processing where crop or workflow_crop is specified") - parser.add_argument('--voxel_sizes', type=float, nargs=3, - help="Enter the voxel sizes as (dz,dy,dx). Make sure they are in brackets", - default=(0.3, 0.1499219272808386, 0.1499219272808386)) - parser.add_argument('--file_extension', type=str, nargs=1, - help="If choosing a folder, enter the extension of the files (make sure you enter it with the dot at the start, i.e., .czi or .tif), else .czi and .tif files will be used") - parser.add_argument('--time_range', type=int, nargs=2, - help="Enter time range to extract, default will be entire timeseries if no range is specified. For example, 0 9 will extract first 10 timepoints", - default=(None, None)) - parser.add_argument('--channel_range', type=int, nargs=2, - help="Enter channel range to extract, default will be all channels if no range is specified. For example, 0 1 will extract first two channels.", - default=(None, None)) - parser.add_argument('--workflow_path', type=str, nargs=1, - help="Enter path to the workflow file '.yml", default=[None]) - parser.add_argument('--output_file_type', type=SaveFileType, nargs=1, - help="Save as either tiff or h5, defaults to h5", - action="store", - choices=(SaveFileType.tiff, SaveFileType.h5), - default=[SaveFileType.h5]), - parser.add_argument('--channel', type=bool, nargs=1, - help="If input is a tiff file and there are channel dimensions but no time dimensions, choose as True", - default=False) - parser.add_argument('--config', type=str, nargs=1, - help="Location of config file, all other arguments will be ignored and overwriten by those in the yaml file", default=None) - parser.add_argument('--roi_number', type=int, nargs=1, - help="Process an individual ROI, loop all if unset", default=None) - parser.add_argument('--set_logging', type=str, nargs=1, - help="Set logging level [INFO,DEBUG]", default=["INFO"]) - return parser - -def main(argv: Sequence[str] = sys.argv[1:]): - parser = make_parser() - args = parser.parse_args(argv) - - # Enable Logging - import logging - logger = logging.getLogger(__name__) - - # Setting empty strings for psf paths - psf_ch1_path = "" - psf_ch2_path = "" - psf_ch3_path = "" - psf_ch4_path = "" - - # IF using a config file, set a lot of parameters here - # the rest are scattered throughout the code when needed - # could be worth bringing everything up top - print(args.config) - - if args.config: - try: - with open(args.config[0], 'r') as con: - try: - processing_parameters = yaml.safe_load(con) - - except Exception as exc: - print(exc) - - except FileNotFoundError as exc: - exit(f"Config yml file {args.config[0]} not found, please specify") - - if not processing_parameters: - logging.error(f"Config file not loaded properly") - # this is how I propose setting the command line variables - # If they're in the config, get from there. If not - # Look in command line. If not there, then exit - if 'input' in processing_parameters: - input_path = processing_parameters['input'] - elif args.input is not None: - input_path = args.input[0] - else: - exit("Input not set") - print("Processing file %s" % input_path) - - if 'output' in processing_parameters: - output_path = processing_parameters['output'] - elif args.output is not None: - output_path = args.output[0] + os.sep - else: - exit("Output not set") - - # If requried setting not in config file = use defaults from argparse - # get method for a dictionary will get a value if specified, if not, will use value from args as default value - - dz, dy, dx = processing_parameters.get('voxel_sizes', args.voxel_sizes) - channel_dimension = processing_parameters.get('channel', args.channel) - skew_dir = processing_parameters.get( - 'skew_direction', DeskewDirection.Y) - deskew_angle = processing_parameters.get('deskew_angle', 30.0) - processing = ProcessingOptions[processing_parameters.get( - 'processing', None).lower()] - time_start, time_end = processing_parameters.get( - 'time_range', (None, None)) - channel_start, channel_end = processing_parameters.get( - 'channel_range', (None, None)) - output_file_type = SaveFileType[processing_parameters.get( - 'output_file_type', args.output_file_type).lower()] - - # to allow for either/or CLI/config file Todo for rest of parameters? - if 'roi_number' in processing_parameters: - roi_to_process = processing_parameters.get('roi_number') - elif args.roi_number is not None: - roi_to_process = args.roi_number[0] +from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from strenum import StrEnum + +from lls_core.models.lattice_data import LatticeData +from lls_core.models.deskew import DeskewParams, DefinedPixelSizes +from lls_core.models.deconvolution import DeconvolutionParams +from lls_core.models.output import OutputParams +from lls_core.models.crop import CropParams +from lls_core import DeconvolutionChoice +from typer import Typer, Argument, Option, Context, Exit + +from lls_core.models.output import SaveFileType +from pydantic import ValidationError + +if TYPE_CHECKING: + from lls_core.models.utils import FieldAccessModel + from typing import Type, Any, Iterable + from rich.table import Table + +class CliDeskewDirection(StrEnum): + X = auto() + Y = auto() + +CLI_PARAM_MAP = { + "input_image": ["input_image"], + "angle": ["angle"], + "skew": ["skew"], + "pixel_sizes": ["physical_pixel_sizes"], + "rois": ["crop", "roi_list"], + "roi_indices": ["crop", "roi_subset"], + "z_start": ["crop", "z_range", 0], + "z_end": ["crop", "z_range", 1], + "decon_processing": ["deconvolution", "decon_processing"], + "psf": ["deconvolution", "psf"], + "psf_num_iter": ["deconvolution", "psf_num_iter"], + "background": ["deconvolution", "background"], + "workflow": ["workflow"], + "time_start": ["time_range", 0], + "time_end": ["time_range", 1], + "channel_start": ["channel_range", 0], + "channel_end": ["channel_range", 1], + "save_dir": ["save_dir"], + "save_name": ["save_name"], + "save_type": ["save_type"], +} + +app = Typer(add_completion=False, rich_markup_mode="rich", no_args_is_help=True) + +def field_from_model(model: Type[FieldAccessModel], field_name: str, extra_description: str = "", description: Optional[str] = None, default: Optional[Any] = None, **kwargs) -> Any: + """ + Generates a type Field from a Pydantic model field + """ + field = model.__fields__[field_name] + + from enum import Enum + if default is None: + default = field.get_default() + if isinstance(default, Enum): + default = default.name + + if description is None: + description = f"{field.field_info.description} {extra_description}" + + return Option( + default = default, + help=description, + **kwargs + ) + +def handle_merge(values: list): + if len(values) > 1: + raise ValueError(f"A parameter has been passed multiple times! Got: {', '.join(values)}") + return values[0] + +def rich_validation(e: ValidationError) -> Table: + """ + Converts + """ + from rich.table import Table + + table = Table(title="Validation Errors") + table.add_column("Parameter") + # table.add_column("Command Line Argument") + table.add_column("Error") + + for error in e.errors(): + table.add_row( + str(error["loc"][0]), + str(error["msg"]), + ) + + return table + +def pairwise(iterable: Iterable) -> Iterable: + """ + An implementation of the pairwise() function in Python 3.10+ + See: https://docs.python.org/3.12/library/itertools.html#itertools.pairwise + """ + from itertools import tee + a, b = tee(iterable) + next(b, None) + return zip(a, b) + +def update_nested_data(data: Union[dict, list], keys: list, new_value: Any): + current = data + + for key, next_key in pairwise(keys): + next = {} if isinstance(next_key, str) else [] + if isinstance(current, dict): + current = current.setdefault(key, next) + elif isinstance(current, list): + if key >= len(current): + current.insert(key, next) + current = current[key] else: - roi_to_process = None - - log_level = processing_parameters.get('--set_logging', "INFO") - workflow_path = processing_parameters.get('workflow_path', None) - if workflow_path is not None: - workflow_path = Path(workflow_path) - - logging.basicConfig(level=log_level.upper()) - logging.info(f"Logging set to {log_level.upper()}") - - if not processing: - logging.error("Processing option not set.") - exit() - - deconvolution = processing_parameters.get('deconvolution', None) - if deconvolution is not None: - deconvolution = DeconvolutionChoice[deconvolution.lower()] - psf_arg = "psf" - deconvolution_num_iter = processing_parameters.get( - 'deconvolution_num_iter', 10) - psf_paths = processing_parameters.get('deconvolution_psf') - logging.debug(psf_paths) - if not psf_paths: - logging.error("PSF paths not set option not set.") - exit() - else: - psf_ch1_path = psf_paths[0].replace(",", "").strip() - psf_ch2_path = psf_paths[1].replace(",", "").strip() - psf_ch3_path = psf_paths[2].replace(",", "").strip() - psf_ch4_path = psf_paths[3].replace(",", "").strip() + raise ValueError(f"Unknown data type {type(current)}. Cannot traverse.") - else: - deconvolution = False - - roi_file: str - if processing == ProcessingOptions.crop or processing == ProcessingOptions.workflow_crop: - if 'roi_file' in processing_parameters: - roi_file = processing_parameters.get('roi_file', False) - elif args.roi_file is not None: - roi_file = args.roi_file[0] - else: - exit("Specify roi file") - assert os.path.exists(roi_file), "Cannot find " + roi_file - print("Processing using roi file %s" % roi_file) - - assert os.path.exists(input_path), "Cannot find input " + input_path - assert os.path.exists(output_path), "Cannot find output " + output_path - - file_extension = processing_parameters.get( - 'file_extension', [".czi", ".tif", ".tiff"]) - - # setting (some - see above) parameters from CLI - else: # if not using config file - input_path = args.input[0] - output_path = args.output[0] + os.sep - dz, dy, dx = args.voxel_sizes - deskew_angle = args.deskew_angle - channel_dimension = args.channel - time_start, time_end = args.time_range - channel_start, channel_end = args.channel_range - skew_dir = args.skew_direction - processing = args.processing[0] - output_file_type = args.output_file_type[0] - roi_to_process = args.roi_number - workflow_path = args.workflow_path[0] - - log_level = args.set_logging[0] - logging.basicConfig(level=log_level.upper()) - logging.info(f"Logging set to {log_level.upper()}") - - if roi_to_process: - logging.info(f"Processing ROI {roi_to_process}") - - if not processing: - logging.error("Processing option not set.") - exit() - - # deconvolution - if args.deconvolution: - deconvolution = args.deconvolution[0] - psf_arg = "psf" - if args.deconvolution_psf: - psf_paths = re.split(';|,', args.deconvolution_psf[0]) - logging.debug(psf_paths) - try: - psf_ch1_path = psf_paths[0] - psf_ch2_path = psf_paths[1] - psf_ch3_path = psf_paths[2] - psf_ch4_path = psf_paths[3] - except IndexError: - pass - else: - logging.error("PSF paths not set option not set.") - exit() - # num of iter default is 10 if nothing specified - if args.deconvolution_num_iter: - deconvolution_num_iter = args.deconvolution_num_iter - else: - deconvolution_num_iter = 10 - else: - deconvolution = False - - # output file type to save - if not output_file_type: - output_file_type = SaveFileType.h5 - - # Get - if processing == ProcessingOptions.crop or processing == ProcessingOptions.workflow_crop: - assert args.roi_file, "Specify roi_file (ImageJ/FIJI ROI Zip file)" - roi_file = args.roi_file[0] - if os.path.isfile(roi_file): # if file make sure it is a zip file or roi file - - roi_file_extension = os.path.splitext(roi_file)[1] - assert roi_file_extension == ".zip" or roi_file_extension == ".roi", "ROI file is not a zip or .roi file" - - # Check if input and output paths exist - assert os.path.exists(input_path), "Cannot find input " + input_path - assert os.path.exists(output_path), "Cannot find output " + output_path - - if not args.file_extension: - file_extension = [".czi", ".tif", ".tiff"] - else: - file_extension = args.file_extension - - # Initialise list of images and ROIs - img_list = [] - roi_list = [] - logging.debug(f"Deconvolution is set to {deconvolution} and option is ") - - logging.debug(f"Output file type is {output_file_type}") - # If input_path a directory, get a list of images - if os.path.isdir(input_path): - for file_type in file_extension: - img_list.extend(glob.glob(input_path + os.sep + '*' + file_type)) - print("List of images: ", img_list) - elif os.path.isfile(input_path) and (os.path.splitext(input_path))[1] in file_extension: - # if a single file, just add filename to the image list - img_list.append(input_path) + last_key = keys[-1] + if isinstance(current, dict): + current[last_key] = new_value + elif isinstance(current, list): + current.insert(last_key, new_value) else: - sys.exit("Do not recognise " + input_path + " as directory or file") - - # If cropping, get list of roi files with matching image names - if processing == ProcessingOptions.crop or processing == ProcessingOptions.workflow_crop: - if os.path.isdir(roi_file): - for img in img_list: - img_name = os.path.basename(os.path.splitext(img)[0]) - roi_temp = roi_file + os.sep + img_name + ".zip" - - if os.path.exists(roi_temp): - roi_list.append(roi_temp) - else: - sys.exit("Cannot find ROI file for " + img) - - print("List of ROIs: ", roi_list) - elif os.path.isfile(roi_file): - roi_list.append(roi_file) - assert len(roi_list) == len( - img_list), "Image and ROI lists do not match" - else: - # add list of empty strings so that it can run through for loop - no_files = len(img_list) - roi_list = [""] * no_files - - # loop through the list of images and rois - for img, roi_path in zip(img_list, roi_list): - save_name = os.path.splitext(os.path.basename(img))[0] - - print("Processing Image " + img) - if processing == ProcessingOptions.crop or processing == ProcessingOptions.workflow_crop: - print("Processing ROI " + roi_path) - aics_img = AICSImage(img) - - # check if scene valid; if not it iterates through all scenes - len_scenes = len(aics_img.scenes) - for scene in range(len_scenes): - aics_img.set_scene(scene) - test = aics_img.get_image_dask_data("YX", T=0, C=0, Z=0) - try: - test_max = test.max().compute() - if test_max: - print(f"Scene {scene} is valid") - break - except Exception: - print(f"Scene {scene} not valid") - - # Initialize Lattice class - lattice = lattice_from_aics(aics_img, angle=deskew_angle, skew=skew_dir, save_name=save_name, physical_pixel_sizes=PhysicalPixelSizes(dx, dy, dz)) - - if time_start is None or time_end is None: - time_start, time_end = 0, lattice.time - 1 - - if channel_start is None or channel_end is None: - channel_start, channel_end = 0, lattice.channels - 1 - - # Verify dimensions - check_dimensions(time_start, time_end, channel_start, - channel_end, lattice.channels, lattice.time) - print("dimensions verified") - # If deconvolution, set the parameters in the lattice class - if deconvolution: - lattice.decon_processing = deconvolution - lattice.psf_num_iter = deconvolution_num_iter - logging.debug(f"Num of iterations decon, {lattice.psf_num_iter}") - logging.info("Performing Deconvolution") - lattice.psf = [] - lattice.otf_path = [] - # Remove empty values and convert to Path - lattice.psf = list(read_psf([Path(it) for it in [psf_ch1_path, psf_ch2_path, psf_ch3_path, psf_ch4_path] if it is not None], decon_option=lattice.decon_processing, lattice_class=lattice)) - - else: - lattice.decon_processing = None - - # Override pixel values by reading metadata if file is czi - if os.path.splitext(img)[1] == ".czi": - dz, dy, dx = lattice.dz, lattice.dy, lattice.dx - logging.info(f"Pixel values from metadata (zyx): {dz},{dy},{dx}") - - # Setup workflows based on user input - if processing == ProcessingOptions.workflow or processing == ProcessingOptions.workflow_crop: - # load workflow from path - # if args.config and 'workflow_path' in processing_parameters: - #workflow_path = Path(processing_parameters['workflow_path']) - # else: - #workflow_path = Path(workflow_path) - - # load custom modules (*.py) in same directory as workflow file - import importlib - parent_dir = Path(workflow_path).resolve().parents[0].__str__() + os.sep - - sys.path.append(parent_dir) - custom_py_files = get_all_py_files(parent_dir) - - if len(custom_py_files) > 0: - modules = map(importlib.import_module, custom_py_files) - logging.info(f"Custom modules imported {modules}") - - # workflow has to be reloaded for each image and reinitialised - user_workflow = load_workflow(workflow_path.__str__()) - if not isinstance(user_workflow, Workflow): - raise ValueError("Workflow file is not a napari workflow object. Check file!") - - input_arg_first, input_arg_last, first_task_name, last_task_name = get_first_last_image_and_task( - user_workflow) - logging.debug(input_arg_first, input_arg_last, - first_task_name, last_task_name) - - # get list of tasks - task_list = list(user_workflow._tasks.keys()) - - print("Workflow loaded:") - logging.info(user_workflow) - - task_name_start = first_task_name[0] - try: - task_name_last = last_task_name[0] - except IndexError: - task_name_last = task_name_start - - # if workflow involves cropping, assign first task as crop_volume_deskew - if processing == ProcessingOptions.workflow_crop: - deskewed_shape = lattice.deskew_vol_shape - deskewed_volume = da.zeros(deskewed_shape) - z_start = 0 - z_end = deskewed_shape[0] - roi = "roi" - volume = "volume" - # Create workflow for cropping and deskewing - # volume and roi used will be set dynamically - user_workflow.set("crop_deskew", crop_volume_deskew, - original_volume=volume, - deskewed_volume=deskewed_volume, - roi_shape=roi, - angle_in_degrees=deskew_angle, - voxel_size_x=dx, - voxel_size_y=dy, - voxel_size_z=dz, - z_start=z_start, - z_end=z_end, - skew_dir=lattice.skew) - # change the first task so it accepts "crop_deskew as input" - new_task = modify_workflow_task( - old_arg=input_arg_first, task_key=task_name_start, new_arg="crop_deskew", workflow=user_workflow) - user_workflow.set(task_name_start, new_task) - - elif processing == ProcessingOptions.workflow: - # Verify if deskewing function is in workflow; if not, add as first task - if user_workflow.get_task(task_name_start)[0] not in (cle.deskew_y, cle.deskew_x): - custom_workflow = True - input = "input" - - # add task to the workflow - user_workflow.set("deskew_image", lattice.deskew_func, - input_image=input, - angle_in_degrees=deskew_angle, - voxel_size_x=dx, - voxel_size_y=dy, - voxel_size_z=dz, - linear_interpolation=True) - # Set input of the workflow to be from deskewing - # change the first task so it accepts "deskew_image" as input - new_task = modify_workflow_task(old_arg=input_arg_first, task_key=task_name_start, - new_arg="deskew_image", workflow=user_workflow) - user_workflow.set(task_name_start, new_task) - else: - custom_workflow = False - - img_data = lattice.data - - # Create save directory for each image - save_path = output_path + os.sep + \ - os.path.basename(os.path.splitext(img)[0]) + os.sep - - if not os.path.exists(save_path): - try: - os.mkdir(save_path) - except FileExistsError: - # this is sometimes caused when running parallel jobs - # can safely be ignored (I hope) - pass - - logging.info(f"Saving at {save_path}") - - # Deskewing only - - if processing == ProcessingOptions.deskew: - # deconvolution - if lattice.decon_processing: - save_img(vol=img_data, - func=lattice.deskew_func, - time_start=time_start, - time_end=time_end, - channel_start=channel_start, - channel_end=channel_end, - save_path=save_path, - save_name=save_name, - save_file_type=output_file_type, - dx=dx, - dy=dy, - dz=dz, - angle=deskew_angle, - terminal=True, - lattice=lattice, - angle_in_degrees=deskew_angle, - voxel_size_x=dx, - voxel_size_y=dy, - voxel_size_z=dz, - linear_interpolation=True - ) - - else: - save_img(vol=img_data, - func=lattice.deskew_func, - time_start=time_start, - time_end=time_end, - channel_start=channel_start, - channel_end=channel_end, - save_path=save_path, - save_name=save_name, - save_file_type=output_file_type, - dx=dx, - dy=dy, - dz=dz, - angle=deskew_angle, - angle_in_degrees=deskew_angle, - voxel_size_x=dx, - voxel_size_y=dy, - voxel_size_z=dz - ) - - # Crop and deskew - elif processing == ProcessingOptions.crop or processing == ProcessingOptions.workflow_crop: - print(roi_path) - roi_img = read_imagej_roi(roi_path) - - deskewed_shape = lattice.deskew_vol_shape - deskewed_volume = da.zeros(deskewed_shape) - - # Can modify for entering custom z values - z_start = 0 - z_end = deskewed_shape[0] - - # if roi number is specified, roi_img will be a list containing only one roi. - if roi_to_process is not None: - # just do one ROI - assert roi_to_process < len( - roi_img), f"ROI specified is {roi_to_process}, which is less than total ROIs ({len(roi_img)})" - logging.info(f"Processing single ROI: {roi_to_process}") - # If only one ROI, single loop - roi_img = [roi_img[roi_to_process]] - - # loop through rois in the roi list - for idx, roi_layer in enumerate(tqdm(roi_img, desc="ROI:", position=0)): - - if roi_to_process is not None: - roi_label = str(roi_to_process) - else: - roi_label = str(idx) - - print("Processing ROI " + str(idx) + - " of " + str(len(roi_img))) - deskewed_shape = lattice.deskew_vol_shape - deskewed_volume = da.zeros(deskewed_shape) - - # Can modify for entering custom z values - z_start = 0 - z_end = deskewed_shape[0] - - if processing == ProcessingOptions.crop: - # deconvolution - if lattice.decon_processing: - save_img(img_data, - func=crop_volume_deskew, - time_start=time_start, - time_end=time_end, - channel_start=channel_start, - channel_end=channel_end, - save_name_prefix="ROI_" + roi_label + "_", - save_path=save_path, - save_name=save_name, - save_file_type=output_file_type, - dx=dx, - dy=dy, - dz=dz, - angle=deskew_angle, - terminal=True, - lattice=lattice, - deskewed_volume=deskewed_volume, - roi_shape=roi_layer, - angle_in_degrees=deskew_angle, - z_start=z_start, - z_end=z_end, - voxel_size_x=dx, - voxel_size_y=dy, - voxel_size_z=dz, - ) - else: - print("SHOULD BE DOING THIS") - print(save_path) - print(save_name) - save_img(img_data, - func=crop_volume_deskew, - time_start=time_start, - time_end=time_end, - channel_start=channel_start, - channel_end=channel_end, - save_name_prefix="ROI_" + roi_label + "_", - save_path=save_path, - save_name=save_name, - save_file_type=output_file_type, - dx=dx, - dy=dy, - dz=dz, - angle=deskew_angle, - deskewed_volume=deskewed_volume, - roi_shape=roi_layer, - angle_in_degrees=deskew_angle, - z_start=z_start, - z_end=z_end, - voxel_size_x=dx, - voxel_size_y=dy, - voxel_size_z=dz, - ) - - elif processing == ProcessingOptions.workflow_crop: - # deconvolution - user_workflow.set(roi, roi_layer) - - if lattice.decon_processing: - save_img_workflow(vol=img_data, - workflow=user_workflow, - input_arg=volume, - first_task="crop_deskew", - last_task=task_name_last, - time_start=time_start, - time_end=time_end, - channel_start=channel_start, - channel_end=channel_end, - save_path=save_path, - save_name_prefix="ROI_" + roi_label, - save_name=save_name, - save_file_type=output_file_type, - dx=dx, - dy=dy, - dz=dz, - angle=deskew_angle, - deconvolution=True, - decon_processing=lattice.decon_processing, - psf=lattice.psf, - psf_arg=psf_arg) - else: - save_img_workflow(vol=img_data, - workflow=user_workflow, - input_arg=volume, - first_task="crop_deskew", - last_task=task_name_last, - time_start=time_start, - time_end=time_end, - channel_start=channel_start, - channel_end=channel_end, - save_path=save_path, - save_file_type=output_file_type, - save_name_prefix="ROI_" + roi_label, - save_name=save_name, - dx=dx, - dy=dy, - dz=dz, - angle=deskew_angle, - deconvolution=False) - - elif processing == ProcessingOptions.workflow: - # if deskew_image task set above manually - if custom_workflow: - if lattice.decon_processing: - save_img_workflow(vol=img_data, - workflow=user_workflow, - input_arg=input, - first_task="deskew_image", - last_task=task_name_last, - time_start=time_start, - time_end=time_end, - channel_start=channel_start, - channel_end=channel_end, - save_path=save_path, - save_name=save_name, - save_file_type=output_file_type, - dx=dx, - dy=dy, - dz=dz, - angle=deskew_angle, - deconvolution=True, - decon_processing=lattice.decon_processing, - psf=lattice.psf, - psf_arg=psf_arg) - else: - save_img_workflow(vol=img_data, - workflow=user_workflow, - input_arg=input, - first_task="deskew_image", - last_task=task_name_last, - time_start=time_start, - time_end=time_end, - channel_start=channel_start, - channel_end=channel_end, - save_path=save_path, - save_name=save_name, - save_file_type=output_file_type, - dx=dx, - dy=dy, - dz=dz, - angle=deskew_angle) - else: - if lattice.decon_processing: - save_img_workflow(vol=img_data, - workflow=user_workflow, - input_arg=input_arg_first, - first_task=first_task_name, - last_task=task_name_last, - time_start=time_start, - time_end=time_end, - channel_start=channel_start, - channel_end=channel_end, - save_path=save_path, - save_name=save_name, - save_file_type=output_file_type, - dx=dx, - dy=dy, - dz=dz, - angle=deskew_angle, - deconvolution=True, - decon_processing=lattice.decon_processing, - psf=lattice.psf, - psf_arg=psf_arg) - else: - save_img_workflow(vol=img_data, - workflow=user_workflow, - input_arg=input_arg_first, - first_task=first_task_name, - last_task=task_name_last, - time_start=time_start, - time_end=time_end, - channel_start=channel_start, - channel_end=channel_end, - save_path=save_path, - save_name=save_name, - save_file_type=output_file_type, - dx=dx, - dy=dy, - dz=dz, - angle=deskew_angle) - + raise ValueError(f"Unknown data type {type(current)}. Cannot traverse.") + +# Example usage: +@app.command() +def process( + ctx: Context, + input_image: Path = Argument(None, help="Path to the image file to read, in a format readable by AICSImageIO, for example .tiff or .czi", show_default=False), + skew: CliDeskewDirection = field_from_model(DeskewParams, "skew"),# DeskewParams.make_typer_field("skew"), + angle: float = field_from_model(DeskewParams, "angle") , + pixel_sizes: Tuple[float, float, float] = field_from_model(DeskewParams, "physical_pixel_sizes", extra_description="This takes three arguments, corresponding to the Z, Y and X pixel dimensions respectively", default=( + DefinedPixelSizes.get_default("Z"), + DefinedPixelSizes.get_default("Y"), + DefinedPixelSizes.get_default("X") + )), + + rois: List[Path] = field_from_model(CropParams, "roi_list", description="A list of paths pointing to regions of interest to crop to, in ImageJ format."), #Option([], help="A list of paths pointing to regions of interest to crop to, in ImageJ format."), + roi_indices: List[int] = field_from_model(CropParams, "roi_subset"), + # Ideally this and other range values would be defined as Tuples, but these seem to be broken: https://github.com/tiangolo/typer/discussions/667 + z_start: Optional[int] = Option(0, help="The index of the first Z slice to use. All prior Z slices will be discarded.", show_default=False), + z_end: Optional[int] = Option(None, help="The index of the last Z slice to use. The selected index and all subsequent Z slices will be discarded. Defaults to the last z index of the image.", show_default=False), + + enable_deconvolution: bool = Option(False, "--deconvolution/--disable-deconvolution", rich_help_panel="Deconvolution"), + decon_processing: DeconvolutionChoice = field_from_model(DeconvolutionParams, "decon_processing", rich_help_panel="Deconvolution"), + psf: List[Path] = field_from_model(DeconvolutionParams, "psf", description="One or more paths pointing to point spread functions to use for deconvolution. Each file should in a standard image format (.czi, .tiff etc), containing a 3D image array. This option can be used multiple times to provide multiple PSF files.", rich_help_panel="Deconvolution"), + psf_num_iter: int = field_from_model(DeconvolutionParams, "psf_num_iter", rich_help_panel="Deconvolution"), + background: str = field_from_model(DeconvolutionParams, "background", rich_help_panel="Deconvolution"), + + time_start: Optional[int] = Option(0, help="Index of the first time slice to use (inclusive). Defaults to the first time index of the image.", rich_help_panel="Output"), + time_end: Optional[int] = Option(None, help="Index of the first time slice to use (exclusive). Defaults to the last time index of the image.", show_default=False, rich_help_panel="Output"), + + channel_start: Optional[int] = Option(0, help="Index of the first channel slice to use (inclusive). Defaults to the first channel index of the image.", rich_help_panel="Output"), + channel_end: Optional[int] = Option(None, help="Index of the first channel slice to use (exclusive). Defaults to the last channel index of the image.", show_default=False, rich_help_panel="Output"), + + save_dir: Path = field_from_model(OutputParams, "save_dir", rich_help_panel="Output"), + save_name: Optional[str] = field_from_model(OutputParams, "save_name", rich_help_panel="Output"), + save_type: SaveFileType = field_from_model(OutputParams, "save_type", rich_help_panel="Output"), + + workflow: Optional[Path] = Option(None, help="Path to a Napari Workflow file, in YAML format. If provided, the configured desekewing processing will be added to the chosen workflow.", show_default=False), + json_config: Optional[Path] = Option(None, show_default=False, help="Path to a JSON file from which parameters will be read."), + yaml_config: Optional[Path] = Option(None, show_default=False, help="Path to a YAML file from which parameters will be read."), + + show_schema: bool = Option(default=False, help="If provided, image processing will not be performed, and instead a JSON document outlining the JSON/YAML options will be printed to stdout. This can be used to assist with writing a config file for use with the --json-config and --yaml-config options.") +) -> None: + from click.core import ParameterSource + from rich.console import Console + + console = Console(stderr=True) + + if show_schema: + import json + import sys + json.dump( + LatticeData.to_definition_dict(), + sys.stdout, + indent=4 + ) + return + + # Just print help if the user didn't provide any arguments + if all(src != ParameterSource.COMMANDLINE for src in ctx._parameter_source.values()): + print(ctx.get_help()) + raise Exit() + + from toolz.dicttoolz import merge_with + cli_args = {} + for source, dest in CLI_PARAM_MAP.items(): + from click.core import ParameterSource + if ctx.get_parameter_source(source) != ParameterSource.DEFAULT: + update_nested_data(cli_args, dest, ctx.params[source]) + + json_args = {} + if json_config is not None: + import json + with json_config.open() as fp: + json_args = json.load(fp) + + yaml_args = {} + if yaml_config is not None: + with yaml_config.open() as fp: + from yaml import safe_load + yaml_args = safe_load(fp) + + try: + lattice = LatticeData.parse_obj( + # Merge all three sources of config: YAML, JSON and CLI + merge_with( + handle_merge, + [yaml_args, json_args, cli_args] + ) + ) + except ValidationError as e: + console.print(rich_validation(e)) + raise Exit(code=1) + + lattice.save() + console.print(f"Processing successful. Results can be found in {lattice.save_dir.resolve()}") + + +def main(): + app() if __name__ == '__main__': main() diff --git a/core/lls_core/config.py b/core/lls_core/config.py index 84601679..c63a3aec 100644 --- a/core/lls_core/config.py +++ b/core/lls_core/config.py @@ -2,5 +2,5 @@ #https://docs.python.org/3/faq/programming.html?highlight=global#how-do-i-share-global-variables-across-modules channel = 0 time = 0 -#configure default logging level for use in packages; inherit from _dock_widget or __main__.py +#configure default logging level for use in packages; inherit from dock_widget or __main__.py log_level = 10 \ No newline at end of file diff --git a/core/lls_core/cropping.py b/core/lls_core/cropping.py new file mode 100644 index 00000000..2f0c35a9 --- /dev/null +++ b/core/lls_core/cropping.py @@ -0,0 +1,75 @@ +from __future__ import annotations +from typing import TYPE_CHECKING, NamedTuple, Tuple, List + +if TYPE_CHECKING: + from lls_core.types import PathLike + from typing_extensions import Self + from numpy.typing import NDArray + +RoiCoord = Tuple[float, float] + +class Roi(NamedTuple): + top_left: RoiCoord + top_right: RoiCoord + bottom_left: RoiCoord + bottom_right: RoiCoord + + @classmethod + def from_array(cls, array: NDArray) -> Self: + import numpy as np + return Roi(*np.reshape(array, (-1, 2)).tolist()) + +def read_roi_array(roi: PathLike) -> NDArray: + from read_roi import read_roi_file + from numpy import array + return array(read_roi_file(str(roi))) + +def read_imagej_roi(roi_path: PathLike) -> List[Roi]: + """Read an ImageJ ROI zip file so it loaded into napari shapes layer + If non rectangular ROI, will convert into a rectangle based on extreme points + Args: + roi_zip_path (zip file): ImageJ ROI zip file + + Returns: + list: List of ROIs + """ + from pathlib import Path + from os import fspath + from read_roi import read_roi_file, read_roi_zip + + roi_path = Path(fspath(roi_path)) + + # handle reading single roi or collection of rois in zip file + if roi_path.suffix == ".zip": + ij_roi = read_roi_zip(roi_path) + elif roi_path.suffix == ".roi": + ij_roi = read_roi_file(roi_path) + else: + raise Exception("ImageJ ROI file needs to be a zip/roi file") + + if ij_roi is None: + raise Exception("Failed reading ROI file") + + # initialise list of rois + roi_list = [] + + # Read through each roi and create a list so that it matches the organisation of the shapes from napari shapes layer + for value in ij_roi.values(): + if value['type'] in ('oval', 'rectangle'): + width = int(value['width']) + height = int(value['height']) + left = int(value['left']) + top = int(value['top']) + roi = Roi((top, left), (top, left+width), (top+height, left+width), (top+height, left)) + roi_list.append(roi) + elif value['type'] in ('polygon', 'freehand'): + left = min(int(it) for it in value['x']) + top = min(int(it) for it in value['y']) + right = max(int(it) for it in value['x']) + bottom = max(int(it) for it in value['y']) + roi = Roi((top, left), (top, right), (bottom, right), (bottom, left)) + roi_list.append(roi) + else: + print(f"Cannot read ROI {value}. Recognised as type {value['type']}") + + return roi_list diff --git a/core/lls_core/deconvolution.py b/core/lls_core/deconvolution.py index 52a72708..c3d7c877 100644 --- a/core/lls_core/deconvolution.py +++ b/core/lls_core/deconvolution.py @@ -1,11 +1,14 @@ +from __future__ import annotations + from pathlib import Path + +from resource_backed_dask_array import ResourceBackedDaskArray from lls_core import DeconvolutionChoice -from lls_core.lattice_data import LatticeData import pyclesperanto_prototype as cle import logging import importlib.util -from typing import Collection, Iterable,Union,Literal -from aicsimageio import AICSImage +from typing import Collection, Iterable,Union,Literal, Optional, TYPE_CHECKING +from aicsimageio.aics_image import AICSImage from skimage.io import imread from aicspylibczi import CziFile from numpy.typing import NDArray @@ -13,9 +16,12 @@ import numpy as np import dask.array as da from dask.array.core import Array as DaskArray -from resource_backed_dask_array import ResourceBackedDaskArray -from lls_core.utils import pad_image_nearest_multiple +from lls_core.utils import array_to_dask, pad_image_nearest_multiple +from lls_core.types import ArrayLike, is_arraylike + +if TYPE_CHECKING: + from lls_core.models.lattice_data import LatticeData logger = logging.getLogger(__name__) @@ -82,13 +88,13 @@ def read_psf(psf_paths: Collection[Path], # libraries are better designed.. Atleast until RL deconvolution is available in pyclesperant # Talley Lamberts pycudadecon is a great library and highly optimised. def pycuda_decon( - image, - otf_path=None, - dzdata=0.3, - dxdata=0.1449922, - dzpsf=0.3, - dxpsf=0.1449922, - psf=None, + image: ArrayLike, + otf_path: Optional[str]=None, + dzdata: float=0.3, + dxdata: float=0.1449922, + dzpsf: float=0.3, + dxpsf: float=0.1449922, + psf: Optional[ArrayLike]=None, num_iter: int = 10, cropping: bool = False, background: Union[float,Literal["auto","second_last"]] = 0 @@ -97,7 +103,7 @@ def pycuda_decon( pycudadecon can return cropped images, so we pad the images with dimensions that are a multiple of 64 Args: - image (np.array): _description_ + image : _description_ otf_path : (path to the generated otf file, if available. Otherwise psf needs to be provided) dzdata : (pixel size in z in microns) dxdata : (pixel size in xy in microns) @@ -120,10 +126,7 @@ def pycuda_decon( background = np.median(image[-2]) # if dask array, convert to numpy array - if type(image) in [ - da.core.Array, - ResourceBackedDaskArray, - ]: + if isinstance(image, DaskArray): image = np.array(image) orig_img_shape = image.shape @@ -146,13 +149,7 @@ def pycuda_decon( # pad image to a multiple of 64 image = pad_image_nearest_multiple(img=image, nearest_multiple=64) - if type(psf) in [ - np.ndarray, - np.array, - DaskArray, - ResourceBackedDaskArray, - cle._tier0._pycl.OCLArray, - ]: + if is_arraylike(psf): from pycudadecon import RLContext, TemporaryOTF, rl_decon psf = np.squeeze(psf) # remove unit dimensions @@ -229,11 +226,7 @@ def skimage_decon( from skimage.restoration import richardson_lucy as rl_decon_skimage depth = tuple(np.array(psf.shape) // 2) - if not isinstance(vol_zyx, ( - DaskArray, - ResourceBackedDaskArray, - )): - vol_zyx = da.asarray(vol_zyx) + vol_zyx = array_to_dask(vol_zyx) decon_data = vol_zyx.map_overlap( rl_decon_skimage, psf=psf, diff --git a/core/lls_core/io.py b/core/lls_core/io.py deleted file mode 100644 index 0ea4c363..00000000 --- a/core/lls_core/io.py +++ /dev/null @@ -1,618 +0,0 @@ -from __future__ import annotations -# Opening and saving files -import aicsimageio -from aicsimageio.types import ImageLike, ArrayLike -from aicsimageio import AICSImage - -from pathlib import Path - -import pyclesperanto_prototype as cle -import sys -import dask -from resource_backed_dask_array import ResourceBackedDaskArray -import dask.array as da -from dask.array.core import Array as DaskArray -import pandas as pd -from typing import TYPE_CHECKING, Callable, Optional, Literal, Any -from os import PathLike - -from dask.distributed import Client -from dask.cache import Cache - -from lls_core.utils import etree_to_dict -from lls_core.llsz_core import crop_volume_deskew -from lls_core.deconvolution import skimage_decon, pycuda_decon -from lls_core import config, DeconvolutionChoice, SaveFileType -from lls_core.lattice_data import LatticeData - -import os -import numpy as np -from tqdm import tqdm -from tifffile import imwrite, TiffWriter - -import npy2bdv - -# Enable Logging -import logging -logger = logging.getLogger(__name__) - -if TYPE_CHECKING: - from napari.types import ImageData - from napari_workflows import Workflow - -def convert_imgdata_aics(img_data: ImageData): - """Return AICSimage object from napari ImageData type - - Args: - img_data ([type]): [description] - - Returns: - AICImage: [description] - """ - # Error handling for czi file - try: - # stack=aicsimageio.imread_dask(img_location) - # using AICSImage will read data as STCZYX - stack = aicsimageio.AICSImage(img_data) - # stack_meta=aicsimageio.imread_dask(img_location) - except Exception: - print("Error: A ", sys.exc_info()[ - 0], "has occurred. See below for details.") - raise - - # Dask setup - # Setting up dask scheduler - # start distributed scheduler locally. Launch dashboard - client = Client(processes=False) - # memory_limit='30GB', - dask.config.set({"temporary_directory": "C:\\Dask_temp\\", "optimization.fuse.active": False, - 'array.slicing.split_large_chunks': False}) - cache = Cache(2e9) - cache.register() - print("If using dask, dask Client can be viewed at:", client.dashboard_link) - - # Check metadata to verify postprocessing or any modifications by Zen - return stack - - -# will flesh this out once Zeiss lattice has more relevant metadata in the czi file -def check_metadata(img_path: ImageLike): - print("Checking CZI metadata") - metadatadict_czi = etree_to_dict(aicsimageio.AICSImage(img_path).metadata) - metadatadict_czi = metadatadict_czi["ImageDocument"]["Metadata"] - acquisition_mode_setup = metadatadict_czi["Experiment"]["ExperimentBlocks"][ - "AcquisitionBlock"]["HelperSetups"]["AcquisitionModeSetup"]["Detectors"] - print(acquisition_mode_setup["Detector"]["ImageOrientation"]) - print("Image Orientation: If any post processing has been applied, it will appear here.\n \ - For example, Zen 3.2 flipped the image so the coverslip was towards the bottom of Z stack. So, value will be 'Flip'") - -# TODO: write save function for deskew and for crop - - -def save_img(vol: ArrayLike, - func: Callable, - time_start: int, - time_end: int, - channel_start: int, - channel_end: int, - save_file_type: SaveFileType, - save_path: PathLike, - save_name_prefix: str = "", - save_name: str = "img", - dx: float = 1, - dy: float = 1, - dz: float = 1, - angle: Optional[float] = None, - # This is a MagicTemplate object but we want to avoid depending on magicclass - # TODO: refactor this out - LLSZWidget: Optional[Any]=None, - terminal: bool = False, - lattice: Optional[LatticeData]=None, - *args, **kwargs): - """ - Applies a function as described in callable - Args: - vol (_type_): Volume to process - func (callable): _description_ - time_start (int): _description_ - time_end (int): _description_ - channel_start (int): _description_ - channel_end (int): _description_ - save_file_type: either 'tiff' or SaveFileType.h5 - save_path (Path): _description_ - save_name_prefix (str, optional): Add a prefix to name. For example, if processng ROIs, add ROI_1_. Defaults to "". - save_name (str, optional): name of file being saved. Defaults to "img". - dx (float, optional): _description_. Defaults to 1. - dy (float, optional): _description_. Defaults to 1. - dz (float, optional): _description_. Defaults to 1. - angle(float, optional) = Deskewing angle in degrees, used to calculate new z - LLSZWidget(class,optional) = LLSZWidget class - """ - - # save_path = save_path.__str__() - _save_path: Path = Path(save_path) - - # replace any : with _ and remove spaces in case it hasn't been processed/skipped - save_name = str(save_name).replace(":", "_").replace(" ", "") - - time_range = range(time_start, time_end+1) - - channel_range = range(channel_start, channel_end+1) - - # Calculate new_pixel size in z after deskewing - if angle > 0: - import math - new_dz = math.sin(angle * math.pi / 180.0) * dz - else: - new_dz = dz - - if func is crop_volume_deskew: - # create folder for each ROI; disabled as each roi is saved as hyperstack - save_name_prefix = save_name_prefix + "_" - #save_path = save_path+os.sep+save_name_prefix+os.sep - # if not os.path.exists(save_path): - # os.makedirs(save_path) - im_final = [] - - # setup bdvwriter - if save_file_type == SaveFileType.h5: - if func is crop_volume_deskew: - save_path_h5 = _save_path / (save_name_prefix + "_" + save_name + ".h5") - else: - save_path_h5 = _save_path / (save_name + ".h5") - - bdv_writer = npy2bdv.BdvWriter(str(save_path_h5), - compression='gzip', - nchannels=len(channel_range), - subsamp=( - (1, 1, 1), (1, 2, 2), (2, 4, 4)), - overwrite=False) - - # bdv_writer = npy2bdv.BdvWriter(save_path_h5, compression=None, nchannels=len(channel_range)) #~30% faster, but up to 10x bigger filesize - else: - pass - - if terminal: - if lattice.decon_processing: - decon_value = True - decon_option = lattice.decon_processing # decon_processing holds the choice - lattice_class = lattice - logging.debug(f"Decon option {decon_option}") - - else: - try: - decon_value = LLSZWidget.LlszMenu.deconvolution.value - lattice_class = LLSZWidget.LlszMenu.lattice - decon_option = LLSZWidget.LlszMenu.lattice.decon_processing - except: - decon_value = 0 - lattice_class = 0 - decon_option = 0 - - # loop is ordered so image is saved in order TCZYX for ometiffwriter - for loop_time_idx, time_point in enumerate(tqdm(time_range, desc="Time", position=0)): - images_array = [] - for loop_ch_idx, ch in enumerate(tqdm(channel_range, desc="Channels", position=1, leave=False)): - try: - if len(vol.shape) == 3: - raw_vol = vol - elif len(vol.shape) == 4: - raw_vol = vol[time_point, :, :, :] - elif len(vol.shape) == 5: - raw_vol = vol[time_point, ch, :, :, :] - except IndexError: - assert ch <= channel_end, f"Channel out of range. Got {ch}, but image has channels {channel_end+1}" - assert time_point <= channel_end, f"Channel out of range. Got {ch}, but image has channels {channel_end+1}" - assert len(vol.shape) in [ - 3, 4, 5], f"Check shape of volume. Expected volume with shape 3,4 or 5. Got {vol.shape} with shape {len(vol.shape)}" - print(f"Using time points {time_point} and channel {ch}") - exit() - - #raw_vol = np.array(raw_vol) - image_type = raw_vol.dtype - #print(decon_value) - #print(decon_option) - #print(func) - # Add a check for last timepoint, in case acquisition incomplete - if time_point == time_end: - orig_shape = raw_vol.shape - raw_vol = raw_vol.compute() - if raw_vol.shape != orig_shape: - print( - f"Time {time_point}, channel {ch} is incomplete. Actual shape {orig_shape}, got {raw_vol.shape}") - z_diff, y_diff, x_diff = np.subtract( - orig_shape, raw_vol.shape) - print(f"Padding with{z_diff,y_diff,x_diff}") - raw_vol = np.pad( - raw_vol, ((0, z_diff), (0, y_diff), (0, x_diff))) - assert raw_vol.shape == orig_shape, f"Shape of last timepoint still doesn't match. Got {raw_vol.shape}" - - # If deconvolution is checked - if decon_value and func != crop_volume_deskew: - # Use CUDA or skimage for deconvolution based on user choice - if decon_option == DeconvolutionChoice.cuda_gpu: - raw_vol = pycuda_decon(image=raw_vol, - #otf_path = LLSZWidget.LlszMenu.lattice.otf_path[ch], - psf=lattice_class.psf[ch], - dzdata=lattice_class.dz, - dxdata=lattice_class.dx, - dzpsf=lattice_class.dz, - dxpsf=lattice_class.dx, - num_iter=lattice_class.psf_num_iter) - else: - raw_vol = skimage_decon(vol_zyx=raw_vol, - psf=lattice_class.psf[ch], - num_iter=lattice_class.psf_num_iter, - clip=False, filter_epsilon=0, boundary='nearest') - - # The following will apply the user-passed function to the input image - if func is crop_volume_deskew and decon_value == True: - processed_vol = func(original_volume=raw_vol, - deconvolution=decon_value, - decon_processing=decon_option, - psf=lattice_class.psf[ch], - num_iter=lattice_class.psf_num_iter, - *args, **kwargs).astype(image_type) - - elif func is cle.deskew_y or func is cle.deskew_x: - processed_vol = func(input_image=raw_vol, - *args, **kwargs).astype(image_type) - - elif func is crop_volume_deskew: - processed_vol = func( - original_volume=raw_vol, *args, **kwargs).astype(image_type) - else: - # if its not deskew or crop/deskew, apply the user-passed function and any specific parameters - processed_vol = func(*args, **kwargs).astype(image_type) - - processed_vol = cle.pull_zyx(processed_vol) - - if save_file_type == SaveFileType.h5: - # convert opencl array to dask array - #pvol = da.asarray(processed_vol) - # channel and time index are based on loop iteration - bdv_writer.append_view(processed_vol, - time=loop_time_idx, - channel=loop_ch_idx, - voxel_size_xyz=(dx, dy, new_dz), - voxel_units='um') - - #print("\nAppending volume to h5\n") - else: - images_array.append(processed_vol) - - # if function is not for cropping, then dataset can be quite large, so save each channel and timepoint separately - # otherwise, append it into im_final - - if func != crop_volume_deskew and save_file_type == SaveFileType.tiff: - final_name = _save_path / (save_name_prefix + "C" + str(ch) + "T" + str( time_point) + "_" + save_name+".tif") - images_array = np.array(images_array) - images_array = np.expand_dims(images_array, axis=0) - images_array = np.swapaxes(images_array, 1, 2) - imwrite(final_name, - images_array, - bigtiff=True, - resolution=(1./dx, 1./dy, "MICROMETER"), - metadata={'spacing': new_dz, 'unit': 'um', 'axes': 'TZCYX'}, imagej=True) - elif save_file_type == SaveFileType.tiff: - # convert list of arrays into a numpy array and then append to im_final - im_final.append(np.array(images_array)) - - # close the h5 writer or if its tiff, save images - if save_file_type == SaveFileType.h5: - bdv_writer.write_xml() - bdv_writer.close() - - elif func is crop_volume_deskew and save_file_type == SaveFileType.tiff: - - im_final = np.array(im_final) - final_name = _save_path / (save_name_prefix + "_" + save_name + ".tif") - - im_final = np.swapaxes(im_final, 1, 2) - # imagej=True; ImageJ hyperstack axes must be in TZCYXS order - - imwrite(final_name, - im_final, - # specify resolution unit for consistent metadata) - resolution=(1./dx, 1./dy, "MICROMETER"), - metadata={'spacing': new_dz, 'unit': 'um', 'axes': 'TZCYX'}, - imagej=True) - im_final = None - - -def save_img_workflow(vol, - workflow: Workflow, - input_arg: str, - first_task: str, - last_task: str, - time_start: int, - time_end: int, - channel_start: int, - channel_end: int, - save_file_type, - save_path: Path, - save_name_prefix: str = "", - save_name: str = "img", - dx: float = 1, - dy: float = 1, - dz: float = 1, - angle: float = None, - deconvolution: bool = False, - decon_processing: str = None, - psf_arg=None, - psf=None, - otf_path=None): - """ - Applies a workflow to the image and saves the output - Use of workflows ensures its agnostic to the processing operation - Args: - vol (_type_): Volume to process - workflow (Workflow): napari workflow - input_arg (str): name for input image - task_name (str): name of the task that should be executed in the workflow - time_start (int): _description_ - time_start (int): _description_ - time_end (int): _description_ - channel_start (int): _description_ - channel_end (int): _description_ - save_path (Path): _description_ - save_name_prefix (str, optional): Add a prefix to name. For example, if processng ROIs, add ROI_1_. Defaults to "". - save_name (str, optional): name of file being saved. Defaults to "img". - dx (float, optional): _description_. Defaults to 1. - dy (float, optional): _description_. Defaults to 1. - dz (float, optional): _description_. Defaults to 1. - angle(float, optional) = Deskewing angle in degrees, used to calculate new z - """ - - # TODO: Implement h5 saving - - save_path = save_path.__str__() - - # replace any : with _ and remove spaces in case it hasn't been processed/skipped - save_name = save_name.replace(":", "_").replace(" ", "") - - # adding +1 at the end so the last channel and time is included - - time_range = range(time_start, time_end + 1) - channel_range = range(channel_start, channel_end + 1) - - # Calculate new_pixel size in z - # convert voxel sixes to an aicsimage physicalpixelsizes object for metadata - if angle: - import math - new_dz = math.sin(angle * math.pi / 180.0) * dz - #aics_image_pixel_sizes = PhysicalPixelSizes(new_dz,dy,dx) - else: - #aics_image_pixel_sizes = PhysicalPixelSizes(dz,dy,dx) - new_dz = dz - - # get list of all functions in the workflow - workflow_functions = [i[0] for i in workflow._tasks.values()] - - # iterate through time and channels and apply workflow - # TODO: add error handling so the image writers will "close",if an error causes the program to exit - # try except? - for loop_time_idx, time_point in enumerate(tqdm(time_range, desc="Time", position=0)): - output_array = [] - data_table = [] - for loop_ch_idx, ch in enumerate(tqdm(channel_range, desc="Channels", position=1, leave=False)): - - if len(vol.shape) == 3: - raw_vol = vol - else: - raw_vol = vol[time_point, ch, :, :, :] - - # TODO: disable if support for resourc backed dask array is added - # if type(raw_vol) in [resource_backed_dask_array]: - # raw_vol = raw_vol.compute() #convert to numpy array as resource backed dask array not su - - # to access current time and channel, create a file config.py in same dir as workflow or in home directory - # add "channel = 0" and "time=0" in the file and save - # https://docs.python.org/3/faq/programming.html?highlight=global#how-do-i-share-global-variables-across-modules - - config.channel = ch - config.time = time_point - - # if deconvolution, need to define psf and choose the channel appropriate one - if deconvolution: - workflow.set(psf_arg, psf[ch]) - # if decon_processing == "cuda_gpu": - # workflow.set("psf",psf[ch]) - # else: - # workflow.set("psf",psf[ch]) - - # Add a check for last timepoint, in case acquisition incomplete or an incomplete frame - if time_point == time_end: - orig_shape = raw_vol.shape - raw_vol = raw_vol.compute() - if raw_vol.shape != orig_shape: - print( - f"Time {time_point}, channel {ch} is incomplete. Actual shape {orig_shape}, got {raw_vol.shape}") - z_diff, y_diff, x_diff = np.subtract( - orig_shape, raw_vol.shape) - print(f"Padding with{z_diff,y_diff,x_diff}") - raw_vol = np.pad( - raw_vol, ((0, z_diff), (0, y_diff), (0, x_diff))) - assert raw_vol.shape == orig_shape, f"Shape of last timepoint still doesn't match. Got {raw_vol.shape}" - - # Set input to the workflow to be volume from each time point and channel - workflow.set(input_arg, raw_vol) - # execute workflow - processed_vol = workflow.get(last_task) - - output_array.append(processed_vol) - - output_array = np.array(output_array) - - # if workflow returns multiple objects (images, dictionaries, lsits etc..), each object can be accessed by - # output_array[:,index_of_object] - - # use data from first timepoint to get the output type from workflow - # check if multiple objects in the workflow output, if so, get the index for each item - # currently, images, lists and dictionaries are supported - if loop_time_idx == 0: - - # get no of elements - - no_elements = len(processed_vol) - # initialize lsits to hold indexes for each datatype - list_element_index = [] # store indices of lists - dict_element_index = [] # store indices of dicts - # store indices of images (numpy array, dask array and pyclesperanto array) - image_element_index = [] - - # single output and is just dictionary - if isinstance(processed_vol, dict): - dict_element_index = [0] - # if image - elif isinstance(processed_vol, (np.ndarray, cle._tier0._pycl.OCLArray, DaskArray, ResourceBackedDaskArray)): - image_element_index = [0] - no_elements = 1 - # multiple elements - # list with values returns no_elements>1 so make sure its actually a list with different objects - # test this with different workflows - elif no_elements > 1 and type(processed_vol[0]) not in [np.int16, np.int32, np.float16, np.float32, np.float64, int, float]: - array_element_type = [type(output_array[0, i]) - for i in range(no_elements)] - image_element_index = [idx for idx, data_type in enumerate( - array_element_type) if data_type in [np.ndarray, cle._tier0._pycl.OCLArray, da.core.Array]] - dict_element_index = [idx for idx, data_type in enumerate( - array_element_type) if data_type in [dict]] - list_element_index = [idx for idx, data_type in enumerate( - array_element_type) if data_type in [list]] - elif type(processed_vol) is list: - list_element_index = [0] - - # setup required image writers - if len(image_element_index) > 0: - # pass list of images and index to function - writer_list = [] - # create an image writer for each image - for element in range(len(image_element_index)): - final_save_path = save_path + os.sep + save_name_prefix + "_" + \ - str(element)+"_" + save_name + \ - "." + save_file_type.value - # setup writer based on user choice of filetype - if save_file_type == SaveFileType.h5: - bdv_writer = npy2bdv.BdvWriter(final_save_path, - compression='gzip', - nchannels=len( - channel_range), - subsamp=( - (1, 1, 1), (1, 2, 2), (2, 4, 4)), - overwrite=True) # overwrite set to True; is this good practice? - writer_list.append(bdv_writer) - else: - # imagej =true throws an error - writer_list.append(TiffWriter( - final_save_path, bigtiff=True)) - - # handle image saving: either h5 or tiff saving - if len(image_element_index) > 0: - # writer_idx is for each writers, image_idx will be the index of images - for writer_idx, image_idx in enumerate(image_element_index): - # access the image - # print(len(time_range)) - if len(channel_range) == 1: - if (len(time_range)) == 1: # if only one timepoint - im_final = np.stack( - output_array[image_idx, ...]).astype(raw_vol.dtype) - else: - im_final = np.stack( - output_array[0, image_idx]).astype(raw_vol.dtype) - else: - im_final = np.stack( - output_array[:, image_idx]).astype(raw_vol.dtype) - if save_file_type == SaveFileType.h5: - for ch_idx in channel_range: - # write h5 images as 3D stacks - assert len( - im_final.shape) >= 3, f"Image shape should be >=3, got {im_final.shape}" - # print(im_final.shape) - if len(im_final.shape) == 3: - im_channel = im_final - elif len(im_final.shape) > 3: - im_channel = im_final[ch_idx, ...] - - writer_list[writer_idx].append_view(im_channel, - time=loop_time_idx, - channel=loop_ch_idx, - voxel_size_xyz=( - dx, dy, new_dz), - voxel_units='um') - else: # default to tif - # Use below with imagej=True - # if len(im_final.shape) ==4: #if only one image with no channel, then dimension will 1,z,y,x, so swap 0 and 1 - # im_final = np.swapaxes(im_final,0,1).astype(raw_vol.dtype) #was 1,2,but when stacking images, dimension is CZYX - # im_final = im_final[np.newaxis,...] #adding extra dimension for T - # elif len(im_final.shape)>4: - # im_final = np.swapaxes(im_final,1,2).astype(raw_vol.dtype) #if image with multiple channels, , it will be 1,c,z,y,x - # imagej=True; ImageJ hyperstack axes must be in TZCYXS order - #images_array = np.swapaxes(images_array,0,1).astype(raw_vol.dtype) - writer_list[writer_idx].write(im_final, - resolution=( - 1./dx, 1./dy, "MICROMETER"), - metadata={'spacing': new_dz, 'unit': 'um', 'axes': 'TZCYX', 'PhysicalSizeX': dx, - 'PhysicalSizeXUnit': 'µm', 'PhysicalSizeY': dy, 'PhysicalSizeYUnit': 'µm'}) - im_final = None - - # handle dict saving - # convert to pandas dataframe; update columns with channel and time - if len(dict_element_index) > 0: - # Iterate through the dict output from workflow and add columns for Channel and timepoint - for element in dict_element_index: - for j in channel_range: - output_array[j, element].update({"Channel": "C"+str(j)}) - output_array[j, element].update( - {"Time": "T"+str(time_point)}) - - # convert to pandas dataframe - output_dict_pd = [pd.DataFrame(i) - for i in output_array[:, element]] - - output_dict_pd = pd.concat(output_dict_pd) - # set index to the channel/time - output_dict_pd = output_dict_pd.set_index(["Time", "Channel"]) - - # Save path - dict_save_path = os.path.join( - save_path, "Measurement_"+save_name_prefix) - if not (os.path.exists(dict_save_path)): - os.mkdir(dict_save_path) - - #dict_save_path = os.path.join(dict_save_path,"C" + str(ch) + "T" + str(time_point)+"_"+str(element) + "_measurement.csv") - dict_save_path = os.path.join( - dict_save_path, "Summary_measurement_"+save_name_prefix+"_"+str(element)+"_.csv") - # Opens csv and appends it if file already exists; not efficient. - if os.path.exists(dict_save_path): - output_dict_pd_existing = pd.read_csv( - dict_save_path, index_col=["Time", "Channel"]) - output_dict_summary = pd.concat( - (output_dict_pd_existing, output_dict_pd)) - output_dict_summary.to_csv(dict_save_path) - else: - output_dict_pd.to_csv(dict_save_path) - - if len(list_element_index) > 0: - row_idx = [] - for element in dict_element_index: - for j in channel_range: - row_idx.append("C"+str(j)+"T"+str(time_point)) - output_list_pd = pd.DataFrame( - np.vstack(output_array[:, element]), index=row_idx) - # Save path - list_save_path = os.path.join( - save_path, "Measurement_"+save_name_prefix) - if not (os.path.exists(list_save_path)): - os.mkdir(list_save_path) - list_save_path = os.path.join(list_save_path, "C" + str(ch) + "T" + str( - time_point)+"_"+save_name_prefix+"_"+str(element) + "_measurement.csv") - output_list_pd.to_csv(list_save_path) - - if len(image_element_index) > 0: - for writer_idx in range(len(image_element_index)): - if save_file_type == SaveFileType.h5: - # write h5 metadata - writer_list[writer_idx].write_xml() - # close the writers (applies for both tiff and h5) - writer_list[writer_idx].close() diff --git a/core/lls_core/lattice_data.py b/core/lls_core/lattice_data.py deleted file mode 100644 index a4f250b0..00000000 --- a/core/lls_core/lattice_data.py +++ /dev/null @@ -1,207 +0,0 @@ -from __future__ import annotations -# class for initializing lattice data and setting metadata -# TODO: handle scenes -from dataclasses import dataclass, field -from aicsimageio.aics_image import AICSImage -from aicsimageio.dimensions import Dimensions -from numpy.typing import NDArray -from dataclasses import dataclass -import math -import numpy as np - -from typing import Any, List, Literal, Optional, TYPE_CHECKING, Tuple, TypeVar - -from aicsimageio.types import ArrayLike, PhysicalPixelSizes -import pyclesperanto_prototype as cle - -from lls_core import DeskewDirection, DeconvolutionChoice -from lls_core.utils import get_deskewed_shape - -if TYPE_CHECKING: - import pyclesperanto_prototype as cle - -T = TypeVar("T") -def raise_if_none(obj: Optional[T], message: str) -> T: - if obj is None: - raise TypeError(message) - return obj - -@dataclass -class DefinedPixelSizes: - """ - Like PhysicalPixelSizes, but it's a dataclass, and - none of its fields are None - """ - X: float = 0.14 - Y: float = 0.14 - Z: float = 0.3 - -@dataclass -class LatticeData: - """ - Holds data and metadata for a given image in a consistent format - """ - #: 3-5D array - data: ArrayLike - dims: Dimensions - - #: The filename of this data when it is saved - save_name: str - - #: Geometry of the light path - skew: DeskewDirection = DeskewDirection.Y - angle: float = 30.0 - - decon_processing: Optional[DeconvolutionChoice] = None - - #: Pixel size in microns - physical_pixel_sizes: DefinedPixelSizes = field(default_factory=DefinedPixelSizes) - - new_dz: Optional[float] = None - - # Dimensions of the deskewed output - deskew_vol_shape: Optional[Tuple[int, ...]] = None - deskew_affine_transform: Optional[cle.AffineTransform3D] = None - - # PSF data that should be refactored into another class eventually - psf: Optional[List[NDArray]] = None - psf_num_iter: Optional[int] = None - otf_path: Optional[List] = None - - #: Number of time points - time: int = 0 - #: Number of channels - channels: int = 0 - - # TODO: add defaults here, rather than in the CLI - # Hack to ensure that .skew_dir behaves identically to .skew - @property - def skew_dir(self) -> DeskewDirection: - return self.skew - - @skew_dir.setter - def skew_dir(self, value: DeskewDirection): - self.skew = value - - @property - def deskew_func(self): - # Chance deskew function absed on skew direction - if self.skew == DeskewDirection.Y: - return cle.deskew_y - elif self.skew == DeskewDirection.X: - return cle.deskew_x - else: - raise ValueError() - - @property - def dx(self) -> float: - return self.physical_pixel_sizes.X - - @dx.setter - def dx(self, value: float): - self.physical_pixel_sizes.X = value - - @property - def dy(self) -> float: - return self.physical_pixel_sizes.Y - - @dy.setter - def dy(self, value: float): - self.physical_pixel_sizes.Y = value - - @property - def dz(self) -> float: - return self.physical_pixel_sizes.Z - - @dz.setter - def dz(self, value: float): - self.physical_pixel_sizes.Z = value - - def get_angle(self) -> float: - return self.angle - - def set_angle(self, angle: float) -> None: - self.angle = angle - - def set_skew(self, skew: DeskewDirection) -> None: - self.skew = skew - - def __post_init__(self): - # set new z voxel size - if self.skew == DeskewDirection.Y or self.skew == DeskewDirection.X: - self.new_dz = math.sin(self.angle * math.pi / 180.0) * self.dz - - # process the file to get shape of final deskewed image - self.deskew_vol_shape, self.deskew_affine_transform = get_deskewed_shape(self.data, self.angle, self.dx, self.dy, self.dz) - print(f"Channels: {self.channels}, Time: {self.time}") - print("If channel and time need to be swapped, you can enforce this by choosing 'Last dimension is channel' when initialising the plugin") - -def lattice_from_aics(img: AICSImage, physical_pixel_sizes: PhysicalPixelSizes = PhysicalPixelSizes(None, None, None), **kwargs: Any) -> LatticeData: - # Note: The reason we copy all of these fields rather than just storing the AICSImage is because that class is mostly immutable and so not suitable - - pixel_sizes = DefinedPixelSizes( - X = physical_pixel_sizes[0] or img.physical_pixel_sizes.X or LatticeData.physical_pixel_sizes.X, - Y = physical_pixel_sizes[1] or img.physical_pixel_sizes.Y or LatticeData.physical_pixel_sizes.Y, - Z = physical_pixel_sizes[2] or img.physical_pixel_sizes.Z or LatticeData.physical_pixel_sizes.Z - ) - - return LatticeData( - data = img.dask_data, - dims = img.dims, - time = img.dims.T, - channels = img.dims.C, - physical_pixel_sizes = pixel_sizes, - **kwargs - ) - -def img_from_array(arr: ArrayLike, last_dimension: Optional[Literal["channel", "time"]] = None, **kwargs: Any) -> AICSImage: - """ - Creates an AICSImage from an array without metadata - - Args: - arr (ArrayLike): An array - last_dimension: How to handle the dimension order - kwargs: Additional arguments to pass to the AICSImage constructor - """ - dim_order: str - - if len(arr.shape) < 3 or len(arr.shape) > 5: - raise ValueError("Array dimensions must be in the range [3, 5]") - - # if aicsimageio tiffreader assigns last dim as time when it should be channel, user can override this - if len(arr.shape) == 3: - dim_order="ZYX" - else: - if last_dimension not in ["channel", "time"]: - raise ValueError("last_dimension must be either channel or time") - if len(arr.shape) == 4: - if last_dimension == "channel": - dim_order = "CZYX" - elif last_dimension == "time": - dim_order = "TZYX" - elif len(arr.shape) == 5: - if last_dimension == "channel": - dim_order = "CTZYX" - elif last_dimension == "time": - dim_order = "TCZYX" - else: - raise ValueError() - - img = AICSImage(image=arr, dim_order=dim_order, **kwargs) - - # if last axes of "aicsimage data" shape is not equal to time, then swap channel and time - if img.data.shape[0] != img.dims.T or img.data.shape[1] != img.dims.C: - arr = np.swapaxes(arr, 0, 1) - return AICSImage(image=arr, dim_order=dim_order, **kwargs) - - -def lattice_fom_array(arr: ArrayLike, last_dimension: Optional[Literal["channel", "time"]] = None, **kwargs: Any) -> LatticeData: - """ - Creates a `LatticeData` from an array - - Args: - arr: Array to use as the data source - last_dimension: See img_from_array - """ - aics = img_from_array(arr, last_dimension) - return lattice_from_aics(aics, **kwargs) \ No newline at end of file diff --git a/core/lls_core/llsz_core.py b/core/lls_core/llsz_core.py index 3d2d143a..cb41fab8 100644 --- a/core/lls_core/llsz_core.py +++ b/core/lls_core/llsz_core.py @@ -1,12 +1,12 @@ from __future__ import annotations -from aicsimageio.types import ArrayLike import numpy as np import pyclesperanto_prototype as cle from dask.array.core import Array as DaskArray import dask.array as da from resource_backed_dask_array import ResourceBackedDaskArray -from typing import Optional, Union, TYPE_CHECKING +from typing import Any, Optional, Union, TYPE_CHECKING, overload, Literal, Tuple +from typing_extensions import Unpack, TypedDict, Required from pyclesperanto_prototype._tier8._affine_transform_deskew_3d import ( affine_transform_deskew_3d, ) @@ -14,6 +14,7 @@ from lls_core.utils import calculate_crop_bbox, check_subclass, is_napari_shape, pad_image_nearest_multiple from lls_core import config, DeskewDirection, DeconvolutionChoice +from lls_core.types import ArrayLike from lls_core.deconvolution import pycuda_decon, skimage_decon # Enable Logging @@ -33,10 +34,36 @@ cle._tier0._pycl.OCLArray, ] +class CommonArgs(TypedDict, total=False): + original_volume: Required[ArrayLike] + deskewed_volume: Union[ ArrayLike, None ] + roi_shape: Union[list, NDArray, None] + angle_in_degrees: float + voxel_size_x: float + voxel_size_y: float + voxel_size_z: float + z_start: int + z_end: int + deconvolution: bool + decon_processing: Optional[DeconvolutionChoice] + psf: Union[Psf, None] + num_iter: int + linear_interpolation: bool + skew_dir: DeskewDirection + +@overload +def crop_volume_deskew(*, debug: Literal[True], get_deskew_and_decon: bool = False, **kwargs: Unpack[CommonArgs]) -> Tuple[NDArray, NDArray]: + ... +@overload +def crop_volume_deskew(*, debug: Literal[False] = False, get_deskew_and_decon: Literal[True], **kwargs: Unpack[CommonArgs]) -> Tuple[NDArray, NDArray]: + ... +@overload +def crop_volume_deskew(*, debug: Literal[False] = False, get_deskew_and_decon: Literal[False] = False, **kwargs: Unpack[CommonArgs]) -> NDArray: + ... def crop_volume_deskew( original_volume: ArrayLike, deskewed_volume: Union[ ArrayLike, None ] = None, - roi_shape: Union[Shapes, list, NDArray, None] = None, + roi_shape: Union[list, NDArray, None] = None, angle_in_degrees: float = 30, voxel_size_x: float = 1, voxel_size_y: float = 1, @@ -45,7 +72,7 @@ def crop_volume_deskew( z_end: int = 1, debug: bool = False, deconvolution: bool = False, - decon_processing: Optional[str]=None, + decon_processing: Optional[DeconvolutionChoice]=None, psf: Union[Psf, None]=None, num_iter: int = 10, linear_interpolation: bool=True, @@ -85,15 +112,15 @@ def crop_volume_deskew( # if shapes layer, get first one # TODO: test this - if is_napari_shape(roi_shape): - shape = roi_shape.data[0] + # if is_napari_shape(roi_shape): + # shape = roi_shape.data[0] # if its a list and each element has a shape of 4, its a list of rois - elif type(roi_shape) is list and len(roi_shape[0]) == 4: + if isinstance(roi_shape, list) and len(roi_shape[0]) == 4: # TODO:change to accept any roi by passing index shape = roi_shape[0] # len(roi_shape) >= 1: # if its a array or list with shape of 4, its a single ROI - elif len(roi_shape) == 4 and type(roi_shape) in (np.ndarray, list): + elif len(roi_shape) == 4 and isinstance(roi_shape, (np.ndarray, list)): shape = roi_shape assert len(shape) == 4, print("Shape must be an array of shape 4") @@ -245,9 +272,9 @@ def crop_volume_deskew( crop_height = crop_vol_shape[1] # Find "excess" volume on both sides due to deskewing - crop_excess = ( - int(round((deskewed_height - crop_height) / 2)) - + out_bounds_correction + crop_excess: int = max( + int(round((deskewed_height - crop_height) / 2)) + out_bounds_correction, + 0 ) # Crop in Y deskewed_prelim = np.asarray(deskewed_prelim) @@ -258,10 +285,11 @@ def crop_volume_deskew( elif skew_dir == DeskewDirection.X: deskewed_width = deskewed_prelim.shape[2] crop_width = crop_vol_shape[2] + # Find "excess" volume on both sides due to deskewing - crop_excess = ( - int(round((deskewed_width - crop_width) / 2)) - + out_bounds_correction + crop_excess = max( + int(round((deskewed_width - crop_width) / 2)) + out_bounds_correction, + 0 ) # Crop in X deskewed_prelim = np.asarray(deskewed_prelim) diff --git a/core/lls_core/models/__init__.py b/core/lls_core/models/__init__.py new file mode 100644 index 00000000..3a102951 --- /dev/null +++ b/core/lls_core/models/__init__.py @@ -0,0 +1,5 @@ +from lls_core.models.crop import CropParams +from lls_core.models.deconvolution import DeconvolutionParams +from lls_core.models.deskew import DeskewParams +from lls_core.models.output import OutputParams +from lls_core.models.lattice_data import LatticeData diff --git a/core/lls_core/models/crop.py b/core/lls_core/models/crop.py new file mode 100644 index 00000000..db7adea2 --- /dev/null +++ b/core/lls_core/models/crop.py @@ -0,0 +1,68 @@ +from typing import Iterable, List, Tuple, Any +from pydantic import Field, NonNegativeInt, validator +from lls_core.models.utils import FieldAccessModel +from lls_core.cropping import Roi + +class CropParams(FieldAccessModel): + """ + Parameters for the optional cropping step. + Note that cropping is performed in the space of the deskewed shape. + This is to support the workflow of performing a preview deskew and using that + to calculate the cropping coordinates. + """ + roi_list: List[Roi] = Field( + description="List of regions of interest, each of which must be an NxD array, where N is the number of vertices and D the coordinates of each vertex.", + cli_description="List of regions of interest, each of which must be the file path to ImageJ ROI file.", + default = [] + ) + roi_subset: List[int] = Field( + description="A subset of all the ROIs to process. Each array item should be an index into the ROI list indicating an ROI to include.", + default=None + ) + z_range: Tuple[NonNegativeInt, NonNegativeInt] = Field( + default=None, + description="The range of Z slices to take. All Z slices before the first index or after the last index will be cropped out.", + cli_description="An array with two items, indicating the index of the first and last Z slice to include." + ) + + @property + def selected_rois(self) -> Iterable[Roi]: + "Returns the relevant ROIs that should be processed" + for i in self.roi_subset: + yield self.roi_list[i] + + @validator("roi_list", pre=True) + def read_roi(cls, v: Any) -> List[Roi]: + from lls_core.types import is_pathlike + from lls_core.cropping import read_imagej_roi + from numpy import ndarray + # Allow a single path + if is_pathlike(v): + v = [v] + + rois: List[Roi] = [] + for item in v: + if is_pathlike(item): + rois += read_imagej_roi(item) + elif isinstance(item, ndarray): + rois.append(Roi.from_array(item)) + elif isinstance(item, Roi): + rois.append(item) + else: + # Try converting an iterable to ROI + try: + rois.append(Roi(*item)) + except: + raise ValueError(f"{item} cannot be intepreted as an ROI") + + if len(rois) == 0: + raise ValueError("At least one region of interest must be specified if cropping is enabled") + + return rois + + @validator("roi_subset", pre=True, always=True) + def default_roi_range(cls, v: Any, values: dict): + # If the roi range isn't provided, assume all rois should be processed + if v is None and "roi_list" in values: + return list(range(len(values["roi_list"]))) + return v diff --git a/core/lls_core/models/deconvolution.py b/core/lls_core/models/deconvolution.py new file mode 100644 index 00000000..373734ee --- /dev/null +++ b/core/lls_core/models/deconvolution.py @@ -0,0 +1,52 @@ + +from pydantic import Field, NonNegativeInt, validator + +from typing import Any, List, Literal, Union +from typing_extensions import TypedDict + +from xarray import DataArray + +from lls_core import DeconvolutionChoice +from lls_core.models.utils import enum_choices, FieldAccessModel + +from lls_core.types import image_like_to_image, ImageLike + +Background = Union[float, Literal["auto", "second_last"]] +class DeconvolutionParams(FieldAccessModel): + """ + Parameters for the optional deconvolution step + """ + decon_processing: DeconvolutionChoice = Field( + default=DeconvolutionChoice.cpu, + description=f"Hardware to use to perform the deconvolution. Choices: {enum_choices(DeconvolutionChoice)}" + ) + psf: List[DataArray] = Field( + default=[], + description="List of Point Spread Functions to use for deconvolution. Each of which should be a 3D array." + ) + psf_num_iter: NonNegativeInt = Field( + default=10, + description="Number of iterations to perform in deconvolution" + ) + background: Background = Field( + default=0, + description='Background value to subtract for deconvolution. Only used when decon_processing is set to GPU. This can either be a literal number, "auto" which uses the median of the last slice, or "second_last" which uses the median of the last slice.' + ) + + @validator("decon_processing", pre=True) + def convert_decon(cls, v: Any): + if isinstance(v, str): + return DeconvolutionChoice[v] + return v + + @validator("psf", pre=True, each_item=True, allow_reuse=True) + def convert_image(cls, v): + img = image_like_to_image(v) + # Ensure the PSF is 3D + if "C" in img.dims: + img = img.isel(C=0) + if "T" in img.dims: + img = img.isel(T=0) + if len(img.dims) != 3: + raise ValueError("PSF is not a 3D array!") + return img diff --git a/core/lls_core/models/deskew.py b/core/lls_core/models/deskew.py new file mode 100644 index 00000000..4e4b3d88 --- /dev/null +++ b/core/lls_core/models/deskew.py @@ -0,0 +1,243 @@ +from __future__ import annotations +# class for initializing lattice data and setting metadata +# TODO: handle scenes +from pydantic import Field, NonNegativeFloat, validator, root_validator + +from typing import Any, Tuple +from typing_extensions import Self, TYPE_CHECKING + +import pyclesperanto_prototype as cle + +from lls_core import DeskewDirection +from xarray import DataArray + +from lls_core.models.utils import FieldAccessModel, enum_choices +from lls_core.types import image_like_to_image, is_arraylike, is_pathlike +from lls_core.utils import get_deskewed_shape + +if TYPE_CHECKING: + from aicsimageio.types import PhysicalPixelSizes + +class DefinedPixelSizes(FieldAccessModel): + """ + Like PhysicalPixelSizes, but it's a dataclass, and + none of its fields are None + """ + X: NonNegativeFloat = Field(default=0.1499219272808386, description="Size of the X dimension of the microscope pixels, in microns.") + Y: NonNegativeFloat = Field(default=0.1499219272808386, description="Size of the Y dimension of the microscope pixels, in microns.") + Z: NonNegativeFloat = Field(default=0.3, description="Size of the Z dimension of the microscope pixels, in microns.") + + @classmethod + def from_physical(cls, pixels: PhysicalPixelSizes) -> Self: + from lls_core.utils import raise_if_none + + return DefinedPixelSizes( + X=raise_if_none(pixels.X, "All pixels must be defined"), + Y=raise_if_none(pixels.Y, "All pixels must be defined"), + Z=raise_if_none(pixels.Z, "All pixels must be defined"), + ) + +class DerivedDeskewFields(FieldAccessModel): + """ + Fields that are automatically calculated based on other fields in DeskewParams. + Grouping these together into one model makes validation simpler. + """ + deskew_vol_shape: Tuple[int, ...] = Field( + init_var=False, + default=None, + description="Dimensions of the deskewed output. This is set automatically based on other input parameters, and doesn't need to be provided by the user." + ) + + deskew_affine_transform: cle.AffineTransform3D = Field(init_var=False, default=None, description="Deskewing transformation function. This is set automatically based on other input parameters, and doesn't need to be provided by the user.") + + +class DeskewParams(FieldAccessModel): + input_image: DataArray = Field( + description="A 3-5D array containing the image data.", + cli_description="A path to any standard image file (TIFF, H5 etc) containing a 3-5D array to process." + ) + skew: DeskewDirection = Field( + default=DeskewDirection.Y, + description=f"Axis along which to deskew the image. Choices: {enum_choices(DeskewDirection)}." + ) + angle: float = Field( + default=30.0, + description="Angle of deskewing, in degrees, as a float." + ) + physical_pixel_sizes: DefinedPixelSizes = Field( + # No default, because we need to distinguish between user provided arguments and defaults + description="Pixel size of the microscope, in microns.", + default=None + ) + derived: DerivedDeskewFields = Field( + init_var=False, + default=None, + description="Refer to the DerivedDeskewFields docstring", + cli_hide=True + ) + # Hack to ensure that .skew_dir behaves identically to .skew + @property + def skew_dir(self) -> DeskewDirection: + return self.skew + + @skew_dir.setter + def skew_dir(self, value: DeskewDirection): + self.skew = value + + @property + def deskew_func(self): + # Chance deskew function absed on skew direction + if self.skew == DeskewDirection.Y: + return cle.deskew_y + elif self.skew == DeskewDirection.X: + return cle.deskew_x + else: + raise ValueError() + + @property + def dx(self) -> float: + return self.physical_pixel_sizes.X + + @dx.setter + def dx(self, value: float): + self.physical_pixel_sizes.X = value + + @property + def dy(self) -> float: + return self.physical_pixel_sizes.Y + + @dy.setter + def dy(self, value: float) -> None: + self.physical_pixel_sizes.Y = value + + @property + def dz(self) -> float: + return self.physical_pixel_sizes.Z + + @dz.setter + def dz(self, value: float): + self.physical_pixel_sizes.Z = value + + def get_angle(self) -> float: + return self.angle + + def set_angle(self, angle: float) -> None: + self.angle = angle + + def set_skew(self, skew: DeskewDirection) -> None: + self.skew = skew + + @property + def dims(self): + return self.input_image.dims + + @property + def time(self) -> int: + """Number of time points""" + return self.input_image.sizes["T"] + + @property + def channels(self) -> int: + """Number of channels""" + return self.input_image.sizes["C"] + + @property + def nslices(self) -> int: + """The number of 3D slices within the image""" + return self.time * self.channels + + @property + def new_dz(self): + import math + return math.sin(self.angle * math.pi / 180.0) * self.dz + + @validator("skew", pre=True) + def convert_skew(cls, v: Any): + # Allow skew to be provided as a string + if isinstance(v, str): + return DeskewDirection[v] + return v + + @validator("physical_pixel_sizes", pre=True, always=True) + def convert_pixels(cls, v: Any, values: dict[Any, Any]): + from aicsimageio.types import PhysicalPixelSizes + if isinstance(v, PhysicalPixelSizes): + v = DefinedPixelSizes.from_physical(v) + elif isinstance(v, tuple) and len(v) == 3: + # Allow the pixel sizes to be specified as a tuple + v = DefinedPixelSizes(Z=v[0], Y=v[1], X=v[2]) + elif v is None: + # At this point, we have exhausted all other methods of obtaining pixel sizes: + # User defined and image metadata. So we just use the defaults + return DefinedPixelSizes() + + return v + + @root_validator(pre=True) + def read_image(cls, values: dict): + from aicsimageio import AICSImage + from os import fspath + + img = values["input_image"] + + aics: AICSImage | None = None + if is_pathlike(img): + aics = AICSImage(fspath(img)) + elif isinstance(img, AICSImage): + aics = img + elif is_arraylike(img): + values["input_image"] = DataArray(img) + else: + raise ValueError("Value of input_image was neither a path, an AICSImage, or array-like.") + + # If the image was convertible to AICSImage, we should use the metadata from there + if aics: + values["input_image"] = aics.xarray_dask_data + # Take pixel sizes from the image metadata, but only if they're defined + # and only if we don't already have them + if all(size is not None for size in aics.physical_pixel_sizes) and values.get("physical_pixel_sizes") is None: + values["physical_pixel_sizes"] = aics.physical_pixel_sizes + + # In all cases, input_image will be a DataArray (XArray) at this point + + return values + + @validator("input_image", pre=True) + def reshaping(cls, v: DataArray): + # This allows a user to pass in any array-like object and have it + # converted and reshaped appropriately + array = v + if not set(array.dims).issuperset({"X", "Y", "Z"}): + raise ValueError("The input array must at least have XYZ coordinates") + if "T" not in array.dims: + array = array.expand_dims("T") + if "C" not in array.dims: + array = array.expand_dims("C") + return array.transpose("T", "C", "Z", "Y", "X") + + def get_3d_slice(self) -> DataArray: + return self.input_image.isel(C=0, T=0) + + @validator("derived", always=True) + def calculate_derived(cls, v: Any, values: dict) -> DerivedDeskewFields: + """ + Sets the default deskew shape values if the user has not provided them + """ + data: DataArray = values["input_image"] + if isinstance(v, DerivedDeskewFields): + return v + elif v is None: + deskew_vol_shape, deskew_affine_transform = get_deskewed_shape( + data.isel(C=0, T=0), + values["angle"], + values["physical_pixel_sizes"].X, + values["physical_pixel_sizes"].Y, + values["physical_pixel_sizes"].Z, + values["skew"] + ) + return DerivedDeskewFields( + deskew_affine_transform=deskew_affine_transform, + deskew_vol_shape=deskew_vol_shape + ) + else: + raise ValueError("Invalid derived fields") diff --git a/core/lls_core/models/lattice_data.py b/core/lls_core/models/lattice_data.py new file mode 100644 index 00000000..a8816e0b --- /dev/null +++ b/core/lls_core/models/lattice_data.py @@ -0,0 +1,474 @@ +from __future__ import annotations +# class for initializing lattice data and setting metadata +# TODO: handle scenes +from pydantic import Field, root_validator, validator +from dask.array.core import Array as DaskArray + +from typing import Any, Iterable, Optional, TYPE_CHECKING, Type +from lls_core import DeconvolutionChoice +from lls_core.deconvolution import pycuda_decon, skimage_decon +from lls_core.llsz_core import crop_volume_deskew +from lls_core.models.crop import CropParams +from lls_core.models.deconvolution import DeconvolutionParams +from lls_core.models.output import OutputParams, SaveFileType +from lls_core.models.results import WorkflowSlices +from lls_core.models.utils import ignore_keyerror +from lls_core.types import ArrayLike +from lls_core.models.deskew import DeskewParams +from napari_workflows import Workflow + +from lls_core.workflow import get_workflow_output_name, workflow_set + +if TYPE_CHECKING: + from lls_core.models.results import ImageSlice, ImageSlices, ProcessedSlice + from lls_core.writers import Writer + from xarray import DataArray + +import logging + +logger = logging.getLogger(__name__) + +class LatticeData(OutputParams, DeskewParams): + """ + Holds data and metadata for a given image in a consistent format + """ + + # Note: originally the save-related fields were included via composition and not inheritance + # (similar to how `crop` and `workflow` are handled), but this was impractical for implementing validations + + #: If this is None, then deconvolution is disabled + deconvolution: Optional[DeconvolutionParams] = None + + #: If this is None, then cropping is disabled + crop: Optional[CropParams] = None + + workflow: Optional[Workflow] = Field( + default=None, + description="If defined, this is a workflow to add lightsheet processing onto", + cli_description="Path to a JSON file specifying a napari_workflow-compatible workflow to add lightsheet processing onto" + ) + + @root_validator(pre=True) + def read_image(cls, values: dict): + from lls_core.types import is_pathlike + from pathlib import Path + input_image = values.get("input_image") + if is_pathlike(input_image): + if values.get("save_name") is None: + values["save_name"] = Path(values["input_image"]).stem + + save_dir = values.get("save_dir") + if save_dir is None: + # By default, make the save dir be the same dir as the input + values["save_dir"] = Path(input_image).parent + elif is_pathlike(save_dir): + # Convert a string path to a Path object + values["save_dir"] = Path(save_dir) + + # Use the Deskew version of this validator, to do the actual image loading + return super().read_image(values) + + @validator("workflow", pre=True) + def parse_workflow(cls, v: Any): + # Load the workflow from disk if it was provided as a path + from lls_core.types import is_pathlike + from lls_core.workflow import workflow_from_path + from pathlib import Path + + if is_pathlike(v): + return workflow_from_path(Path(v)) + return v + + @validator("workflow", pre=False) + def validate_workflow(cls, v: Optional[Workflow]): + if v is not None: + if not "deskewed_image" in v.roots(): + raise ValueError("The workflow has no deskewed_image parameter, so is not compatible with the lls processing.") + try: + get_workflow_output_name(v) + except: + raise ValueError("The workflow has multiple output tasks. Only one is currently supported.") + return v + + @validator("crop") + def default_z_range(cls, v: Optional[CropParams], values: dict) -> Optional[CropParams]: + if v is None: + return v + with ignore_keyerror(): + # Fill in missing parts of the z range + # The max allowed value is the length of the deskew Z axis + default_start = 0 + default_end = values["derived"].deskew_vol_shape[0] + + # Set defaults + if v.z_range is None: + v.z_range = (default_start, default_end) + if v.z_range[0] is None: + v.z_range[0] = default_start + if v.z_range[1] is None: + v.z_range[1] = default_end + + # Validate + if v.z_range[1] > default_end: + raise ValueError(f"The z-index endpoint of {v.z_range[1]} is outside the size of the z-axis ({default_end})") + if v.z_range[0] < default_start: + raise ValueError(f"The z-index start of {v.z_range[0]} is outside the size of the z-axis") + + return v + + @validator("time_range", pre=True, always=True) + def parse_time_range(cls, v: Any, values: dict) -> Any: + """ + Sets the default time range if undefined + """ + # This skips the conversion if no image was provided, to ensure a more + # user-friendly error is provided, namely "image was missing" + from collections.abc import Sequence + with ignore_keyerror(): + default_start = 0 + default_end = values["input_image"].sizes["T"] + if v is None: + return range(default_start, default_end) + elif isinstance(v, Sequence) and len(v) == 2: + # Allow 2-tuples to be used as input for this field + return range(v[0] or default_start, v[1] or default_end) + return v + + @validator("channel_range", pre=True, always=True) + def parse_channel_range(cls, v: Any, values: dict) -> Any: + """ + Sets the default channel range if undefined + """ + from collections.abc import Sequence + with ignore_keyerror(): + default_start = 0 + default_end = values["input_image"].sizes["C"] + if v is None: + return range(default_start, default_end) + elif isinstance(v, Sequence) and len(v) == 2: + # Allow 2-tuples to be used as input for this field + return range(v[0] or default_start, v[1] or default_end) + return v + + @validator("time_range") + def disjoint_time_range(cls, v: range, values: dict): + """ + Validates that the time range is within the range of channels in our array + """ + with ignore_keyerror(): + max_time = values["input_image"].sizes["T"] + if v.start < 0: + raise ValueError("The lowest valid start value is 0") + if v.stop > max_time: + raise ValueError(f"The highest valid time value is the length of the time axis, which is {max_time}") + + return v + + @validator("channel_range") + def disjoint_channel_range(cls, v: range, values: dict): + """ + Validates that the channel range is within the range of channels in our array + """ + with ignore_keyerror(): + max_channel = values["input_image"].sizes["C"] + if v.start < 0: + raise ValueError("The lowest valid start value is 0") + if v.stop > max_channel: + raise ValueError(f"The highest valid channel value is the length of the channel axis, which is {max_channel}") + return v + + @validator("channel_range") + def channel_range_subset(cls, v: Optional[range], values: dict): + with ignore_keyerror(): + if v is not None and (min(v) < 0 or max(v) > values["input_image"].sizes["C"]): + raise ValueError("The output channel range must be a subset of the total available channels") + return v + + @validator("time_range") + def time_range_subset(cls, v: Optional[range], values: dict): + if v is not None and (min(v) < 0 or max(v) > values["input_image"].sizes["T"]): + raise ValueError("The output time range must be a subset of the total available time points") + return v + + @validator("deconvolution") + def check_psfs(cls, v: Optional[DeconvolutionParams], values: dict): + if v is None: + return v + with ignore_keyerror(): + channels = values["input_image"].sizes["C"] + psfs = len(v.psf) + if psfs != channels: + raise ValueError(f"There should be one PSF per channel, but there are {psfs} PSFs and {channels} channels.") + return v + + @property + def cropping_enabled(self) -> bool: + "True if cropping should be performed" + return self.crop is not None + + @property + def deconv_enabled(self) -> bool: + "True if deconvolution should be performed" + return self.deconvolution is not None + + def __post_init__(self): + logger.info(f"Channels: {self.channels}, Time: {self.time}") + logger.info("If channel and time need to be swapped, you can enforce this by choosing 'Last dimension is channel' when initialising the plugin") + + def slice_data(self, time: int, channel: int) -> DataArray: + if time > self.time: + raise ValueError("time is out of range") + if channel > self.channels: + raise ValueError("channel is out of range") + + return self.input_image.isel(T=time, C=channel) + + def iter_slices(self) -> Iterable[ProcessedSlice[ArrayLike]]: + """ + Yields array slices for each time and channel of interest. + + Params: + progress: If the progress bar is enabled + + Returns: + An iterable of tuples. Each tuple contains (time_index, time, channel_index, channel, slice) + """ + from lls_core.models.results import ProcessedSlice + from tqdm import tqdm + + for time_idx, time in tqdm(enumerate(self.time_range), desc="Timepoints", total=len(self.time_range)): + for ch_idx, ch in tqdm(enumerate(self.channel_range), desc="Channels", total=len(self.channel_range), leave=False): + yield ProcessedSlice( + data=self.slice_data(time=time, channel=ch), + time_index=time_idx, + time= time, + channel_index=ch_idx, + channel=ch, + ) + + @property + def n_slices(self) -> int: + """ + Returns the number of slices that will be returned by the `iter_*` methods. + """ + return len(self.time_range) * len(self.channel_range) + + def iter_sublattices(self, update_with: dict = {}) -> Iterable[ProcessedSlice[LatticeData]]: + """ + Yields copies of the current LatticeData, one for each slice. + These copies can then be processed separately. + Args: + update_with: dictionary of arguments to update the generated lattices with + """ + for subarray in self.iter_slices(): + new_lattice = self.copy_validate(update={ + "input_image": subarray.data, + "time_range": range(1), + "channel_range": range(1), + **update_with + }) + yield subarray.copy_with_data( new_lattice) + + def generate_workflows( + self, + ) -> Iterable[ProcessedSlice[Workflow]]: + """ + Yields copies of the input workflow, modified with the addition of deskewing and optionally, + cropping and deconvolution + """ + if self.workflow is None: + return + + from copy import copy + # We make a copy of the lattice for each slice, each of which has no associated workflow + for lattice_slice in self.iter_sublattices(update_with={"workflow": None}): + user_workflow = copy(self.workflow) + # We add a step whose result is called "input_img" that outputs a 2D image slice + user_workflow.set( + "deskewed_image", + LatticeData.process_into_image, + lattice_slice.data + ) + # Also add channel metadata to the workflow + for key in {"channel", "channel_index", "time", "time_index", "roi_index"}: + workflow_set( + user_workflow, + key, + getattr(lattice_slice, key) + ) + # The user can use any of these arguments as inputs to their tasks + yield lattice_slice.copy_with_data(user_workflow) + + def check_incomplete_acquisition(self, volume: ArrayLike, time_point: int, channel: int): + """ + Checks for a slice with incomplete data, caused by incomplete acquisition + """ + import numpy as np + if not isinstance(volume, DaskArray): + return volume + orig_shape = volume.shape + raw_vol = volume.compute() + if raw_vol.shape != orig_shape: + logger.warn(f"Time {time_point}, channel {channel} is incomplete. Actual shape {orig_shape}, got {raw_vol.shape}") + z_diff, y_diff, x_diff = np.subtract(orig_shape, raw_vol.shape) + logger.info(f"Padding with{z_diff,y_diff,x_diff}") + raw_vol = np.pad(raw_vol, ((0, z_diff), (0, y_diff), (0, x_diff))) + if raw_vol.shape != orig_shape: + raise Exception(f"Shape of last timepoint still doesn't match. Got {raw_vol.shape}") + return raw_vol + + @property + def deskewed_volume(self) -> DaskArray: + from dask.array import zeros + return zeros(self.derived.deskew_vol_shape) + + def _process_crop(self) -> Iterable[ImageSlice]: + """ + Yields processed image slices with cropping enabled + """ + from tqdm import tqdm + if self.crop is None: + raise Exception("This function can only be called when crop is set") + + # We have an extra level of iteration for the crop path: iterating over each ROI + for roi_index, roi in enumerate(tqdm(self.crop.selected_rois, desc="ROI", position=0)): + # pass arguments for save tiff, callable and function arguments + logger.info(f"Processing ROI {roi_index}") + + for slice in self.iter_slices(): + deconv_args: dict[Any, Any] = {} + if self.deconvolution is not None: + deconv_args = dict( + num_iter = self.deconvolution.psf_num_iter, + psf = self.deconvolution.psf[slice.channel].to_numpy(), + decon_processing=self.deconvolution.decon_processing + ) + + yield slice.copy(update={ + "data": crop_volume_deskew( + original_volume=slice.data, + deconvolution=self.deconv_enabled, + get_deskew_and_decon=False, + debug=False, + roi_shape=list(roi), + linear_interpolation=True, + voxel_size_x=self.dx, + voxel_size_y=self.dy, + voxel_size_z=self.dz, + angle_in_degrees=self.angle, + deskewed_volume=self.deskewed_volume, + z_start=self.crop.z_range[0], + z_end=self.crop.z_range[1], + **deconv_args + ), + "roi_index": roi_index + }) + + def _process_non_crop(self) -> Iterable[ImageSlice]: + """ + Yields processed image slices without cropping + """ + import pyclesperanto_prototype as cle + + for slice in self.iter_slices(): + data: ArrayLike = slice.data + if isinstance(slice.data, DaskArray): + data = slice.data.compute() + if self.deconvolution is not None: + if self.deconvolution.decon_processing == DeconvolutionChoice.cuda_gpu: + data = pycuda_decon( + image=data, + psf=self.deconvolution.psf[slice.channel].to_numpy(), + background=self.deconvolution.background, + dzdata=self.dz, + dxdata=self.dx, + dzpsf=self.dz, + dxpsf=self.dx, + num_iter=self.deconvolution.psf_num_iter + ) + else: + data = skimage_decon( + vol_zyx=data, + psf=self.deconvolution.psf[slice.channel].to_numpy(), + num_iter=self.deconvolution.psf_num_iter, + clip=False, + filter_epsilon=0, + boundary='nearest' + ) + + yield slice.copy_with_data( + cle.pull_zyx(self.deskew_func( + input_image=data, + angle_in_degrees=self.angle, + linear_interpolation=True, + voxel_size_x=self.dx, + voxel_size_y=self.dy, + voxel_size_z=self.dz + )) + ) + + def process_workflow(self) -> WorkflowSlices: + """ + Runs the workflow on each slice and returns the workflow results + """ + from lls_core.models.results import WorkflowSlices + WorkflowSlices.update_forward_refs(LatticeData=LatticeData) + outputs: list[ProcessedSlice[Any]] = [] + for workflow in self.generate_workflows(): + outputs.append( + workflow.copy_with_data( + # Evaluates the workflow here. + workflow.data.get(get_workflow_output_name(workflow.data)) + ) + ) + + return WorkflowSlices( + slices=outputs, + lattice_data=self + ) + + def process(self) -> ImageSlices: + """ + Execute the processing and return the result. + This will not execute the attached workflow. + """ + from lls_core.models.results import ImageSlices + ImageSlices.update_forward_refs(LatticeData=LatticeData) + + if self.cropping_enabled: + return ImageSlices( + lattice_data=self, + slices=self._process_crop() + ) + else: + return ImageSlices( + lattice_data=self, + slices=self._process_non_crop() + ) + + def save(self): + """ + + This is the main public API for processing + """ + if self.workflow: + list(self.process_workflow().save()) + else: + self.process().save_image() + + def process_into_image(self) -> ArrayLike: + """ + Shortcut method for calling process, then extracting one image layer. + This is mostly here to simplify the Workflow integration + """ + for slice in self.process().slices: + return slice.data + raise Exception("No slices produced!") + + def get_writer(self) -> Type[Writer]: + from lls_core.writers import BdvWriter, TiffWriter + if self.save_type == SaveFileType.h5: + return BdvWriter + elif self.save_type == SaveFileType.tiff: + return TiffWriter + raise Exception("Unknown output type") diff --git a/core/lls_core/models/output.py b/core/lls_core/models/output.py new file mode 100644 index 00000000..91c45e9e --- /dev/null +++ b/core/lls_core/models/output.py @@ -0,0 +1,86 @@ +from pydantic import Field, DirectoryPath, validator +from strenum import StrEnum +from pathlib import Path +from typing import TYPE_CHECKING, Union +from pandas import DataFrame +from lls_core.models.utils import FieldAccessModel, enum_choices + +if TYPE_CHECKING: + pass + +#Choice of File extension to save +class SaveFileType(StrEnum): + h5 = "h5" + tiff = "tiff" + +class OutputParams(FieldAccessModel): + save_dir: DirectoryPath = Field( + description="The directory where the output data will be saved." + ) + save_suffix: str = Field( + default="_deskewed", + description="The filename suffix that will be used for output files. This will be added as a suffix to the input file name if the input image was specified using a file name. If the input image was provided as an in-memory object, the `save_name` field should instead be specified.", + cli_description="The filename suffix that will be used for output files. This will be added as a suffix to the input file name if the --save-name flag was not specified." + ) + save_name: str = Field( + description="The filename that will be used for output files. This should not contain a leading directory or file extension. The final output files will have additional elements added to the end of this prefix to indicate the region of interest, channel, timepoint, file extension etc.", + default=None + ) + save_type: SaveFileType = Field( + default=SaveFileType.h5, + description=f"The data type to save the result as. This will also be used to determine the file extension of the output files. Choices: {enum_choices(SaveFileType)}." + ) + time_range: range = Field( + default=None, + description="The range of times to process. This defaults to all time points in the image array.", + cli_description="The range of times to process, as an array with two items: the first and last time index. This defaults to all time points in the image array." + ) + channel_range: range = Field( + default=None, + description="The range of channels to process. This defaults to all time points in the image array.", + cli_description="The range of channels to process, as an array with two items: the first and last channel index. This defaults to all channels in the image array." + ) + + @validator("save_dir", pre=True) + def validate_save_dir(cls, v: Path): + if isinstance(v, Path) and not v.is_absolute(): + # This stops the empty path being considered a valid directory + raise ValueError("The save directory must be an absolute path that exists") + return v + + @validator("save_name") + def add_save_suffix(cls, v: str, values: dict): + # This is the only place that the save suffix is used. + return v + values["save_suffix"] + + @property + def file_extension(self): + if self.save_type == SaveFileType.h5: + return "h5" + else: + return "tif" + + def make_filepath(self, suffix: str) -> Path: + """ + Returns a filepath for the resulting data + """ + return self.get_unique_filepath(self.save_dir / Path(self.save_name + suffix).with_suffix("." + self.file_extension)) + + def make_filepath_df(self, suffix: str, result: DataFrame) -> Path: + """ + Returns a filepath for the non-image data + """ + if isinstance(result, DataFrame): + return self.get_unique_filepath(self.save_dir / Path(self.save_name + suffix).with_suffix(".csv")) + + return + + def get_unique_filepath(self, path: Path) -> Path: + """ + Returns a unique filepath by appending a number to the filename if the file already exists. + """ + counter = 1 + while path.exists(): + path = path.with_name(f"{path.stem}_{counter}{path.suffix}") + counter += 1 + return path diff --git a/core/lls_core/models/results.py b/core/lls_core/models/results.py new file mode 100644 index 00000000..e3defb80 --- /dev/null +++ b/core/lls_core/models/results.py @@ -0,0 +1,176 @@ +from __future__ import annotations +from itertools import groupby +from pathlib import Path + +from typing import Iterable, Optional, Tuple, Union, cast, TYPE_CHECKING, overload +from typing_extensions import Generic, TypeVar +from pydantic import BaseModel, NonNegativeInt +from lls_core.types import ArrayLike, is_arraylike +from lls_core.utils import make_filename_suffix +from lls_core.writers import RoiIndex, Writer +from pandas import DataFrame +from lls_core.workflow import RawWorkflowOutput + +if TYPE_CHECKING: + from lls_core.models.lattice_data import LatticeData + from numpy.typing import NDArray + +T = TypeVar("T") +S = TypeVar("S") +R = TypeVar("R") +class ProcessedSlice(BaseModel, Generic[T], arbitrary_types_allowed=True): + """ + A single slice of some data that is split across multiple slices along time or channel axes + This class is generic over T, the type of data that is sliced. + """ + data: T + time_index: NonNegativeInt + time: NonNegativeInt + channel_index: NonNegativeInt + channel: NonNegativeInt + roi_index: Optional[NonNegativeInt] = None + + def copy_with_data(self, data: S) -> ProcessedSlice[S]: + """ + Return a modified version of this with new inner data + """ + from typing_extensions import cast + return cast( + ProcessedSlice[S], + self.copy(update={ + "data": data + }) + ) + + @overload + def as_tuple(self: ProcessedSlice[Tuple[R]]) -> Tuple[R]: + ... + @overload + def as_tuple(self: ProcessedSlice[T]) -> Tuple[T]: + ... + def as_tuple(self): + """ + Converts the results to a tuple if they weren't already + """ + return self.data if isinstance(self.data, (tuple, list)) else (self.data,) + +class ProcessedSlices(BaseModel, Generic[T], arbitrary_types_allowed=True): + """ + A generic parent class for holding deskewing outputs. + This will never be instantiated directly. + Refer to the concrete child classes for more detail. + """ + #: Iterable of result slices. + #: Note that this is a finite iterator that can only be iterated once + slices: Iterable[ProcessedSlice[T]] + + #: The "parent" LatticeData that was used to create this result + lattice_data: LatticeData + +ImageSlice = ProcessedSlice[ArrayLike] +class ImageSlices(ProcessedSlices[ArrayLike]): + """ + A collection of image slices, which is the main output from deskewing. + This holds an iterable of output image slices before they are saved to disk, + and provides a `save_image()` method for this purpose. + """ + def save_image(self): + """ + Saves result slices to disk + """ + Writer = self.lattice_data.get_writer() + for roi, roi_results in groupby(self.slices, key=lambda it: it.roi_index): + writer = Writer(self.lattice_data, roi_index=roi) + for slice in roi_results: + writer.write_slice(slice) + writer.close() + + +ProcessedWorkflowOutput = Union[ + # A path indicates a saved file + Path, + DataFrame +] + +class WorkflowSlices(ProcessedSlices[Union[Tuple[RawWorkflowOutput], RawWorkflowOutput]]): + """ + The counterpart of `ImageSlices`, but for workflow outputs. + This is needed because workflows have vastly different outputs that may include regular + Python types rather than only image slices. + """ + def process(self) -> Iterable[Tuple[RoiIndex, ProcessedWorkflowOutput]]: + """ + Incrementally processes the workflow outputs, and returns both image paths and data frames of the outputs, + for image slices and dict/list outputs respectively + """ + import pandas as pd + + # Handle each ROI separately + for roi, roi_results in groupby(self.slices, key=lambda it: it.roi_index): + values: list[Union[Writer, list]] = [] + for result in roi_results: + # If the user didn't return a tuple, put it into one + for i, element in enumerate(result.as_tuple()): + # If the element is array like, we assume it's an image to write to disk + if is_arraylike(element): + # Make the writer the first time only + if len(values) <= i: + values.append(self.lattice_data.get_writer()(self.lattice_data, roi_index=roi)) + + writer = cast(Writer, values[i]) + writer.write_slice( + result.copy_with_data( + element + ) + ) + else: + # Otherwise, we assume it's one row to be added to a data frame + if len(values) <= i: + values.append([]) + + rows = cast(list, values[i]) + + if isinstance(element, dict): + # If the row is a dict, it has column names + element = {"time": f"T{result.time_index}", "channel": f"C{result.channel_index}", **element} + elif isinstance(element, Iterable): + # If the row is a list, it has no column names + # We add the channel and time + element = [f"T{result.time_index}", f"C{result.channel_index}", *element] + else: + # If the row is just a value, we turn that value into a single column of the data frame + element = [f"T{result.time_index}", f"C{result.channel_index}", element] + + rows.append(element) + + for element in values: + if isinstance(element, Writer): + element.close() + for file in element.written_files: + yield roi, file + else: + yield roi, pd.DataFrame(element) + + def extract_preview(self) -> NDArray: + import numpy as np + for slice in self.slices: + for value in slice.as_tuple(): + if is_arraylike(value): + return np.asarray(value) + raise Exception("No image was returned from this workflow") + + def save(self) -> Iterable[Path]: + """ + Processes all workflow outputs and saves them to disk. + Images are saved in the format specified in the `LatticeData`, while + other data types are saved as a data frame. + """ + from pandas import DataFrame, Series + for i, (roi, result) in enumerate(self.process()): + if isinstance(result, DataFrame): + path = self.lattice_data.make_filepath_df(make_filename_suffix(roi_index=roi, prefix=f"_output_{i}"), result) + result = result.apply(Series.explode) + result.to_csv(str(path)) + yield path + else: + yield result diff --git a/core/lls_core/models/utils.py b/core/lls_core/models/utils.py new file mode 100644 index 00000000..2ec98a0c --- /dev/null +++ b/core/lls_core/models/utils.py @@ -0,0 +1,90 @@ + +from typing import Any, Type +from typing_extensions import Self +from enum import Enum +from pydantic import BaseModel, Extra +from contextlib import contextmanager + +def enum_choices(enum: Type[Enum]) -> str: + """ + Returns a human readable list of enum choices + """ + return "{" + ", ".join([it.name for it in enum]) + "}" + +@contextmanager +def ignore_keyerror(): + """ + Context manager that ignores KeyErrors from missing fields. + This allows for the validation to continue even if a single field + is missing, eventually resulting in a more user-friendly error message + """ + try: + yield + except KeyError: + pass + +class FieldAccessModel(BaseModel): + """ + Adds methods to a BaseModel for accessing useful field information + """ + class Config: + extra = Extra.forbid + arbitrary_types_allowed = True + validate_assignment = True + + @classmethod + def get_default(cls, field_name: str) -> Any: + """ + Shortcut method for returning the default value of a given field + """ + return cls.__fields__[field_name].get_default() + + @classmethod + def get_description(cls, field_name: str) -> str: + """ + Shortcut method for returning the description of a given field + """ + return cls.__fields__[field_name].field_info.description + + @classmethod + def to_definition_dict(cls) -> dict: + """ + Recursively converts the model into a dictionary whose keys are field names and whose + values are field descriptions. This is used to document the model to users + """ + ret = {} + for key, value in cls.__fields__.items(): + if value.field_info.extra.get("cli_hide"): + # Hide any fields that have cli_hide = True + continue + + if isinstance(value.outer_type_, type) and issubclass(value.outer_type_, FieldAccessModel): + rhs = value.outer_type_.to_definition_dict() + else: + rhs: str + if "cli_description" in value.field_info.extra: + # cli_description can be used to configure the help text that appears for fields for the CLI only + rhs = value.field_info.extra["cli_description"] + else: + rhs = value.field_info.description + rhs += f" Default: {value.get_default()}." + ret[key] = rhs + return ret + + def copy_validate(self, **kwargs) -> Self: + """ + Like `.copy()`, but validates the results. + See https://github.com/pydantic/pydantic/issues/418 for more information + """ + updated = self.copy(**kwargs) + return updated.validate(updated.dict()) + + @classmethod + def make(cls, validate: bool = True, **kwargs: Any): + """ + Creates an instance of this class, with validation either enabled or disabled + """ + if validate: + return cls(**kwargs) + else: + return cls.construct(**kwargs) diff --git a/core/lls_core/sample/LLS7_t1_ch1.czi b/core/lls_core/sample/LLS7_t1_ch1.czi new file mode 100755 index 00000000..5aa11594 Binary files /dev/null and b/core/lls_core/sample/LLS7_t1_ch1.czi differ diff --git a/core/lls_core/sample/LLS7_t1_ch3.czi b/core/lls_core/sample/LLS7_t1_ch3.czi new file mode 100755 index 00000000..07d160f5 Binary files /dev/null and b/core/lls_core/sample/LLS7_t1_ch3.czi differ diff --git a/core/lls_core/sample/LLS7_t2_ch1.czi b/core/lls_core/sample/LLS7_t2_ch1.czi new file mode 100755 index 00000000..1bb1fc68 Binary files /dev/null and b/core/lls_core/sample/LLS7_t2_ch1.czi differ diff --git a/core/lls_core/sample/LLS7_t2_ch3.czi b/core/lls_core/sample/LLS7_t2_ch3.czi new file mode 100755 index 00000000..10655d2d Binary files /dev/null and b/core/lls_core/sample/LLS7_t2_ch3.czi differ diff --git a/sample_data/RBC_lattice.tif b/core/lls_core/sample/RBC_lattice.tif similarity index 100% rename from sample_data/RBC_lattice.tif rename to core/lls_core/sample/RBC_lattice.tif diff --git a/sample_data/RBC_tiny.czi b/core/lls_core/sample/RBC_tiny.czi similarity index 100% rename from sample_data/RBC_tiny.czi rename to core/lls_core/sample/RBC_tiny.czi diff --git a/core/lls_core/sample/README.txt b/core/lls_core/sample/README.txt new file mode 100755 index 00000000..0a98099c --- /dev/null +++ b/core/lls_core/sample/README.txt @@ -0,0 +1,4 @@ +LLS7_t1_ch1: 3D array +LLS7_t1_ch3: 4D array with channel dimension value of 3 +LLS7_t2_ch1: 4D array with time dimension value of 2 +LLS7_t2_ch3: 4D array with time dimension of value 2 and channel of value 3 diff --git a/core/lls_core/sample/__init__.py b/core/lls_core/sample/__init__.py new file mode 100644 index 00000000..f2e4d560 --- /dev/null +++ b/core/lls_core/sample/__init__.py @@ -0,0 +1,3 @@ +import importlib_resources + +resources = importlib_resources.files(__name__) diff --git a/sample_data/config/README.md b/core/lls_core/sample/config/README.md similarity index 100% rename from sample_data/config/README.md rename to core/lls_core/sample/config/README.md diff --git a/sample_data/config/config.yaml b/core/lls_core/sample/config/config.yaml similarity index 100% rename from sample_data/config/config.yaml rename to core/lls_core/sample/config/config.yaml diff --git a/sample_data/config/config.yml b/core/lls_core/sample/config/config.yml similarity index 100% rename from sample_data/config/config.yml rename to core/lls_core/sample/config/config.yml diff --git a/core/lls_core/sample/multich_multi_time.tif b/core/lls_core/sample/multich_multi_time.tif new file mode 100755 index 00000000..0be59a54 Binary files /dev/null and b/core/lls_core/sample/multich_multi_time.tif differ diff --git a/sample_data/psfs/zeiss_simulated/488.czi b/core/lls_core/sample/psfs/zeiss_simulated/488.czi similarity index 100% rename from sample_data/psfs/zeiss_simulated/488.czi rename to core/lls_core/sample/psfs/zeiss_simulated/488.czi diff --git a/sample_data/psfs/zeiss_simulated/488.tif b/core/lls_core/sample/psfs/zeiss_simulated/488.tif similarity index 100% rename from sample_data/psfs/zeiss_simulated/488.tif rename to core/lls_core/sample/psfs/zeiss_simulated/488.tif diff --git a/sample_data/psfs/zeiss_simulated/561.czi b/core/lls_core/sample/psfs/zeiss_simulated/561.czi similarity index 100% rename from sample_data/psfs/zeiss_simulated/561.czi rename to core/lls_core/sample/psfs/zeiss_simulated/561.czi diff --git a/sample_data/psfs/zeiss_simulated/561.tif b/core/lls_core/sample/psfs/zeiss_simulated/561.tif similarity index 100% rename from sample_data/psfs/zeiss_simulated/561.tif rename to core/lls_core/sample/psfs/zeiss_simulated/561.tif diff --git a/sample_data/psfs/zeiss_simulated/640.czi b/core/lls_core/sample/psfs/zeiss_simulated/640.czi similarity index 100% rename from sample_data/psfs/zeiss_simulated/640.czi rename to core/lls_core/sample/psfs/zeiss_simulated/640.czi diff --git a/sample_data/psfs/zeiss_simulated/640.tif b/core/lls_core/sample/psfs/zeiss_simulated/640.tif similarity index 100% rename from sample_data/psfs/zeiss_simulated/640.tif rename to core/lls_core/sample/psfs/zeiss_simulated/640.tif diff --git a/sample_data/psfs/zeiss_simulated/README.md b/core/lls_core/sample/psfs/zeiss_simulated/README.md similarity index 100% rename from sample_data/psfs/zeiss_simulated/README.md rename to core/lls_core/sample/psfs/zeiss_simulated/README.md diff --git a/sample_data/psfs/zeiss_simulated/description.txt b/core/lls_core/sample/psfs/zeiss_simulated/description.txt similarity index 98% rename from sample_data/psfs/zeiss_simulated/description.txt rename to core/lls_core/sample/psfs/zeiss_simulated/description.txt index c2bcd55a..744c1b02 100644 --- a/sample_data/psfs/zeiss_simulated/description.txt +++ b/core/lls_core/sample/psfs/zeiss_simulated/description.txt @@ -1,4 +1,4 @@ -Simulated PSFs for Zeiss Lattice Lightsheet microscope generated using Zen Blue software -The names correspond to the excitation wavelength of the simulated fluorescence bead - +Simulated PSFs for Zeiss Lattice Lightsheet microscope generated using Zen Blue software +The names correspond to the excitation wavelength of the simulated fluorescence bead + Courtesy: Niall Geoghegan, Walter and Eliza Hall Institute of Medical Research, Melbourne, Australia \ No newline at end of file diff --git a/core/lls_core/types.py b/core/lls_core/types.py new file mode 100644 index 00000000..7cfe3a2e --- /dev/null +++ b/core/lls_core/types.py @@ -0,0 +1,38 @@ +from typing import Union +from typing_extensions import TypeGuard, Any, TypeAlias +from dask.array.core import Array as DaskArray +# from numpy.typing import NDArray +from pyopencl.array import Array as OCLArray +import numpy as np +from numpy.typing import NDArray +from xarray import DataArray +from aicsimageio import AICSImage +from os import fspath, PathLike as OriginalPathLike + +# This is a superset of os.PathLike +PathLike: TypeAlias = Union[str, bytes, OriginalPathLike] +def is_pathlike(x: Any) -> TypeGuard[PathLike]: + return isinstance(x, (str, bytes, OriginalPathLike)) + +ArrayLike: TypeAlias = Union[DaskArray, NDArray, OCLArray, DataArray] + +def is_arraylike(arr: Any) -> TypeGuard[ArrayLike]: + return isinstance(arr, (DaskArray, np.ndarray, OCLArray, DataArray)) + +ImageLike: TypeAlias = Union[PathLike, AICSImage, ArrayLike] +def image_like_to_image(img: ImageLike) -> DataArray: + """ + Converts an image in one of many formats to a DataArray + """ + # First try treating it as a path + try: + img = AICSImage(fspath(img)) + except TypeError: + pass + if isinstance(img, AICSImage): + return img.xarray_dask_data + else: + for required_key in ("shape", "dtype", "ndim", "__array__", "__array_ufunc__"): + if not hasattr(img, required_key): + raise ValueError(f"The provided object {img} is not array like!") + return DataArray(img) diff --git a/core/lls_core/utils.py b/core/lls_core/utils.py index fed022f5..b39e12a9 100644 --- a/core/lls_core/utils.py +++ b/core/lls_core/utils.py @@ -1,31 +1,27 @@ from __future__ import annotations -import numpy as np from collections import defaultdict from contextlib import contextmanager, redirect_stderr, redirect_stdout from os import devnull, path -import os -from typing_extensions import Any, TYPE_CHECKING, TypeGuard -from typing import List, Tuple, Union, Collection -from numpy.typing import NDArray - -import pandas as pd -import dask.array as da +from typing import Collection, List, Optional, Tuple, TypeVar, Union +import numpy as np import pyclesperanto_prototype as cle -from read_roi import read_roi_zip, read_roi_file +from lls_core.types import ArrayLike +from numpy.typing import NDArray +from read_roi import read_roi_file, read_roi_zip +from typing_extensions import TYPE_CHECKING, Any, TypeGuard -from tifffile import imsave -from . import config, DeskewDirection -from aicsimageio.types import ArrayLike +from . import DeskewDirection, config if TYPE_CHECKING: from xml.etree.ElementTree import Element + from dask.array.core import Array as DaskArray from napari.layers import Shapes - from napari_workflows import Workflow # Enable Logging import logging + logger = logging.getLogger(__name__) logger.setLevel(config.log_level) @@ -43,7 +39,7 @@ def check_subclass(obj: Any, pkg_name: str, cls_name: str) -> bool: def is_napari_shape(obj: Any) -> TypeGuard[Shapes]: return check_subclass(obj, "napari.shapes", "Shapes") -def calculate_crop_bbox(shape: int, z_start: int, z_end: int) -> tuple[List[List[Any]], List[int]]: +def calculate_crop_bbox(shape: list, z_start: int, z_end: int) -> tuple[List[List[Any]], List[int]]: """Get bounding box as vertices in 3D in the form xyz Args: @@ -103,7 +99,9 @@ def get_deskewed_shape(volume: ArrayLike, tuple: Shape of deskewed volume in zyx np.array: Affine transform for deskewing """ - from pyclesperanto_prototype._tier8._affine_transform import _determine_translation_and_bounding_box + from pyclesperanto_prototype._tier8._affine_transform import ( + _determine_translation_and_bounding_box, + ) deskew_transform = cle.AffineTransform3D() @@ -192,291 +190,6 @@ def dask_expand_dims(a: ArrayLike, axis: Union[Collection[int], int]): return a.reshape(shape) -def read_imagej_roi(roi_zip_path: str): - """Read an ImageJ ROI zip file so it loaded into napari shapes layer - If non rectangular ROI, will convert into a rectangle based on extreme points - Args: - roi_zip_path (zip file): ImageJ ROI zip file - - Returns: - list: List of ROIs - """ - roi_extension = path.splitext(roi_zip_path)[1] - - # handle reading single roi or collection of rois in zip file - if roi_extension == ".zip": - ij_roi = read_roi_zip(roi_zip_path) - elif roi_extension == ".roi": - ij_roi = read_roi_file(roi_zip_path) - else: - raise Exception("ImageJ ROI file needs to be a zip/roi file") - - # initialise list of rois - roi_list = [] - - # Read through each roi and create a list so that it matches the organisation of the shapes from napari shapes layer - for k in ij_roi.keys(): - if ij_roi[k]['type'] in ('oval', 'rectangle'): - width = ij_roi[k]['width'] - height = ij_roi[k]['height'] - left = ij_roi[k]['left'] - top = ij_roi[k]['top'] - roi = [[top, left], [top, left+width], - [top+height, left+width], [top+height, left]] - roi_list.append(roi) - elif ij_roi[k]['type'] in ('polygon', 'freehand'): - left = min(ij_roi[k]['x']) - top = min(ij_roi[k]['y']) - right = max(ij_roi[k]['x']) - bottom = max(ij_roi[k]['y']) - roi = [[top, left], [top, right], [bottom, right], [bottom, left]] - roi_list.append(roi) - else: - print("Cannot read ROI ", - ij_roi[k], ".Recognised as type ", ij_roi[k]['type']) - return roi_list - -# Functions to deal with cle workflow -# TODO: Clean up this function - - -def get_first_last_image_and_task(user_workflow: Workflow): - """Get images and tasks for first and last entry - Args: - user_workflow (Workflow): _description_ - Returns: - list: name of first input image, last input image, first task, last task - """ - - # get image with no preprocessing step (first image) - input_arg_first = user_workflow.roots()[0] - # get last image - input_arg_last = user_workflow.leafs()[0] - # get name of preceding image as that is the input to last task - img_source = user_workflow.sources_of(input_arg_last)[0] - first_task_name = [] - last_task_name = [] - - # loop through workflow keys and get key that has - for key in user_workflow._tasks.keys(): - for task in user_workflow._tasks[key]: - if task == input_arg_first: - first_task_name.append(key) - elif task == img_source: - last_task_name.append(key) - - return input_arg_first, img_source, first_task_name, last_task_name - - -def modify_workflow_task(old_arg, task_key: str, new_arg, workflow): - """_Modify items in a workflow task - Workflow is not modified, only a new task with updated arg is returned - Args: - old_arg (_type_): The argument in the workflow that needs to be modified - new_arg (_type_): New argument - task_key (str): Name of the task within the workflow - workflow (napari-workflow): Workflow - - Returns: - tuple: Modified task with name task_key - """ - task = workflow._tasks[task_key] - # convert tuple to list for modification - task_list = list(task) - try: - item_index = task_list.index(old_arg) - except ValueError: - print(old_arg, " not found in workflow file") - task_list[item_index] = new_arg - modified_task = tuple(task_list) - return modified_task - -def load_custom_py_modules(custom_py_files): - from importlib import reload, import_module - import sys - test_first_module_import = import_module(custom_py_files[0]) - if test_first_module_import not in sys.modules: - modules = map(import_module, custom_py_files) - else: - modules = map(reload, custom_py_files) - return modules - - -# TODO: CHANGE so user can select modules? Safer -def get_all_py_files(directory: str) -> list[str]: - """get all py files within directory and return as a list of filenames - Args: - directory: Directory with .py files - """ - from os.path import dirname, basename, isfile, join - import glob - - modules = glob.glob(join(dirname(directory), "*.py")) - all = [basename(f)[:-3] for f in modules if isfile(f) - and not f.endswith('__init__.py')] - print(f"Files found are: {all}") - - return all - - -def as_type(img, ref_vol): - """return image same dtype as ref_vol - - Args: - img (_type_): _description_ - ref_vol (_type_): _description_ - - Returns: - _type_: _description_ - """ - img.astype(ref_vol.dtype) - return img - - -def process_custom_workflow_output(workflow_output, - save_dir=None, - idx=None, - LLSZWidget=None, - widget_class=None, - channel=0, - time=0, - preview: bool = True): - """Check the output from a custom workflow; - saves tables and images separately - - Args: - workflow_output (_type_): _description_ - save_dir (_type_): _description_ - idx (_type_): _description_ - LLSZWidget (_type_): _description_ - widget_class (_type_): _description_ - channel (_type_): _description_ - time (_type_): _description_ - """ - if type(workflow_output) in [dict, list]: - # create function for tthis dataframe bit - df = pd.DataFrame(workflow_output) - if preview: - save_path = path.join( - save_dir, "lattice_measurement_"+str(idx)+".csv") - print(f"Detected a dictionary as output, saving preview at", save_path) - df.to_csv(save_path, index=False) - return df - - else: - return df - elif type(workflow_output) in [np.ndarray, cle._tier0._pycl.OCLArray, da.core.Array]: - if preview: - suffix_name = str(idx)+"_c" + str(channel) + "_t" + str(time) - scale = (LLSZWidget.LlszMenu.lattice.new_dz, - LLSZWidget.LlszMenu.lattice.dy, LLSZWidget.LlszMenu.lattice.dx) - widget_class.parent_viewer.add_image( - workflow_output, name="Workflow_preview_" + suffix_name, scale=scale) - else: - return workflow_output - - -def _process_custom_workflow_output_batch(ref_vol, - no_elements, - array_element_type, - channel_range, - images_array, - save_path, - time_point, - ch, - save_name_prefix, - save_name, - dx=None, - dy=None, - new_dz=None - ): - # create columns index for the list - if list in array_element_type: - row_idx = [] - - # Iterate through the dict or list output from workflow and add columns for Channel and timepoint - for i in range(no_elements): - for j in channel_range: - if type(images_array[j, i]) in [dict]: - # images_array[j,i].update({"Channel/Time":"C"+str(j)+"T"+str(time_point)}) - images_array[j, i].update({"Channel": "C"+str(j)}) - images_array[j, i].update({"Time": "T"+str(time_point)}) - elif type(images_array[j, i]) in [list]: - row_idx.append("C"+str(j)+"T"+str(time_point)) - # row_idx.append("C"+str(j)) - # row_idx.append("T"+str(time_point)) - - for element in range(no_elements): - if(array_element_type[element]) in [dict]: - # convert to pandas dataframe - output_dict_pd = [pd.DataFrame(i) - for i in images_array[:, element]] - - output_dict_pd = pd.concat(output_dict_pd) - # set index to the channel/time - output_dict_pd = output_dict_pd.set_index(["Time", "Channel"]) - - # Save path - dict_save_path = os.path.join( - save_path, "Measurement_"+save_name_prefix) - if not(os.path.exists(dict_save_path)): - os.mkdir(dict_save_path) - - #dict_save_path = os.path.join(dict_save_path,"C" + str(ch) + "T" + str(time_point)+"_"+str(element) + "_measurement.csv") - dict_save_path = os.path.join( - dict_save_path, "Summary_measurement_"+save_name_prefix+"_"+str(element)+"_.csv") - # Opens csv and appends it if file already exists; not efficient. - if os.path.exists(dict_save_path): - output_dict_pd_existing = pd.read_csv( - dict_save_path, index_col=["Time", "Channel"]) - output_dict_summary = pd.concat( - (output_dict_pd_existing, output_dict_pd)) - output_dict_summary.to_csv(dict_save_path) - else: - output_dict_pd.to_csv(dict_save_path) - - # TODO:modify this so one file saved for measurement - elif(array_element_type[element]) in [list]: - - output_list_pd = pd.DataFrame( - np.vstack(images_array[:, element]), index=row_idx) - # Save path - list_save_path = os.path.join( - save_path, "Measurement_"+save_name_prefix) - if not(os.path.exists(list_save_path)): - os.mkdir(list_save_path) - list_save_path = os.path.join(list_save_path, "C" + str(ch) + "T" + str( - time_point)+"_"+save_name_prefix+"_"+str(element) + "_measurement.csv") - output_list_pd.to_csv(list_save_path) - - elif(array_element_type[element]) in [np.ndarray, cle._tier0._pycl.OCLArray, da.core.Array]: - - # Save path - img_save_path = os.path.join( - save_path, "Measurement_"+save_name_prefix) - if not(os.path.exists(img_save_path)): - os.mkdir(img_save_path) - - im_final = np.stack(images_array[:, element]).astype(ref_vol.dtype) - final_name = os.path.join(img_save_path, save_name_prefix + "_"+str(element) + "_T" + str( - time_point) + "_" + save_name + ".tif") - # "C" + str(ch) + - #OmeTiffWriter.save(images_array, final_name, physical_pixel_sizes=aics_image_pixel_sizes) - # if only one image with no channel, then dimension will 1,z,y,x, so swap 0 and 1 - if len(im_final.shape) == 4: - # was 1,2,but when stacking images, dimension is CZYX - im_final = np.swapaxes(im_final, 0, 1) - # adding extra dimension for T - im_final = im_final[np.newaxis, ...] - elif len(im_final.shape) > 4: # if - # if image with multiple channels, , it will be 1,c,z,y,x - im_final = np.swapaxes(im_final, 1, 2) - # imagej=True; ImageJ hyperstack axes must be in TZCYXS order - imsave(final_name, im_final, bigtiff=True, imagej=True, resolution=(1./dx, 1./dy), - metadata={'spacing': new_dz, 'unit': 'um', 'axes': 'TZCYX'}) # imagej=True - im_final = None - return - def pad_image_nearest_multiple(img: NDArray, nearest_multiple: int) -> NDArray: """pad an Image to the nearest multiple of provided number @@ -499,7 +212,7 @@ def pad_image_nearest_multiple(img: NDArray, nearest_multiple: int) -> NDArray: return padded_img -def check_dimensions(user_time_start: int, user_time_end, user_channel_start: int, user_channel_end: int, total_channels: int, total_time: int): +def check_dimensions(user_time_start: int, user_time_end: int, user_channel_start: int, user_channel_end: int, total_channels: int, total_time: int): if total_time == 1 or total_time == 2: max_time = 1 @@ -572,3 +285,52 @@ def crop_psf(psf_img: np.ndarray, threshold: float = 3e-3): psf_crop = psf_img[min_z:max_z, min_y:max_y, min_x:max_x] return psf_crop + +def as_type(img, ref_vol): + """return image same dtype as ref_vol + + Args: + img (_type_): _description_ + ref_vol (_type_): _description_ + + Returns: + _type_: _description_ + """ + img.astype(ref_vol.dtype) + return img + +T = TypeVar("T") +def raise_if_none(obj: Optional[T], message: str) -> T: + """ + Asserts that `obj` is not None + """ + if obj is None: + raise TypeError(message) + return obj + +def array_to_dask(arr: ArrayLike) -> DaskArray: + from dask.array.core import Array as DaskArray, from_array + from xarray import DataArray + from resource_backed_dask_array import ResourceBackedDaskArray + + if isinstance(arr, DataArray): + arr = arr.data + if isinstance(arr, (DaskArray, ResourceBackedDaskArray)): + return arr + else: + return from_array(arr) + +def make_filename_suffix(prefix: Optional[str] = None, roi_index: Optional[str] = None, channel: Optional[str] = None, time: Optional[str] = None) -> str: + """ + Generates a filename for this result + """ + components: List[str] = [] + if prefix is not None: + components.append(prefix) + if roi_index is not None: + components.append(f"ROI_{roi_index}") + if channel is not None: + components.append(f"C{channel}") + if time is not None: + components.append(f"T{time}") + return "_".join(components) diff --git a/core/lls_core/workflow.py b/core/lls_core/workflow.py new file mode 100644 index 00000000..ab1ad2e3 --- /dev/null +++ b/core/lls_core/workflow.py @@ -0,0 +1,178 @@ +""" +Functions related to manipulating Napari Workflows +""" +from __future__ import annotations + +from pathlib import Path +from typing import Any, Generator, Iterable, Iterator, Tuple, TypeVar, Union + +from typing_extensions import TYPE_CHECKING + +from lls_core.types import ArrayLike + +if TYPE_CHECKING: + from napari_workflows import Workflow + +import logging +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +RawWorkflowOutput = Union[ + ArrayLike, + dict, + list +] + +def get_workflow_output_name(workflow: Workflow) -> str: + """ + Returns the name of the singular workflow output + """ + results = [ + leaf + for leaf in workflow.leafs() + if leaf not in {"deskewed_image", "channel", "channel_index", "time", "time_index", "roi_index"} + ] + if len(results) > 1: + raise Exception("Only workflows with one output are supported.") + return results[0] + +def workflow_set(workflow: Workflow, name: str, func_or_data: Any, args: list = []): + """ + The same as Workflow.set, but less buggy + """ + workflow._tasks[name] = tuple([func_or_data] + args) + +def get_workflow_inputs(workflow: Workflow) -> Generator[Tuple[str, int, str]]: + """ + Yields tuples of (task_name, argument_index, input_argument) corresponding to the workflow's inputs, + namely the arguments that are unfilled. + Note that the index returned is the index in the overall task tuple, which includes the task name + """ + for root_arg in workflow.roots(): + for taskname, (task_func, *args) in workflow._tasks.items(): + if root_arg in args: + yield taskname, args.index(root_arg) + 1, root_arg + +def update_workflow(workflow: Workflow, task_name: str, task_index: int, new_value: Any) -> None: + """ + Mutates `workflow` by finding the task with name `task_name`, and setting the argument with index + `task_index` to `new_value`. + """ + task = list(workflow.get_task(task_name)) + task[task_index] = new_value + workflow.set_task(task_name, tuple(task)) + +def get_first_last_image_and_task(user_workflow: Workflow) -> Tuple[str, str, str, str]: + """ + Get images and tasks for first and last entry + Returns: + Tuple of (name of first input image, name of last input image, name of first task, name of last task) + """ + + # get image with no preprocessing step (first image) + input_arg_first = user_workflow.roots()[0] + # get last image + input_arg_last = user_workflow.leafs()[0] + # get name of preceding image as that is the input to last task + img_source = user_workflow.sources_of(input_arg_last)[0] + first_task_name = [] + last_task_name = [] + + # loop through workflow keys and get key that has + for key in user_workflow._tasks.keys(): + for task in user_workflow._tasks[key]: + if task == input_arg_first: + first_task_name.append(key) + elif task == img_source: + last_task_name.append(key) + + return input_arg_first, img_source, first_task_name[0], last_task_name[0] if len(last_task_name) > 0 else first_task_name[0] + + +def modify_workflow_task(old_arg: str, task_key: str, new_arg: str, workflow: Workflow) -> tuple: + """ + Replies one argument in a workflow with another + Workflow is not modified, only a new task with updated arg is returned + Args: + old_arg: The argument in the workflow that needs to be modified + new_arg: New argument + task_key: Name of the task within the workflow + workflow: Workflow + + Returns: + tuple: Modified task with name task_key + """ + task = workflow._tasks[task_key] + # convert tuple to list for modification + task_list = list(task) + try: + item_index = task_list.index(old_arg) + except ValueError: + raise Exception(f"{old_arg} not found in workflow file") + + task_list[item_index] = new_arg + modified_task = tuple(task_list) + return modified_task + +# TODO: CHANGE so user can select modules? Safer +def get_all_py_files(directory: str) -> list[str]: + """get all py files within directory and return as a list of filenames + Args: + directory: Directory with .py files + """ + import glob + from os.path import basename, dirname, isfile, join + + modules = glob.glob(join(dirname(directory), "*.py")) + all = [basename(f)[:-3] for f in modules if isfile(f) + and not f.endswith('__init__.py')] + print(f"Files found are: {all}") + + return all + +def import_script(script: Path): + """ + Imports a Python script given its path + """ + import importlib.util + import sys + module_name = script.stem + spec = importlib.util.spec_from_file_location(module_name, script) + if spec is None or spec.loader is None: + raise Exception(f"Failed to import {script}!") + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + +def _import_workflow_modules(workflow: Path) -> None: + """ + Imports all the Python files that might be used in a given custom workflow + + Args: + workflow: Path to the workflow YAML file + """ + if not workflow.exists(): + raise Exception("Workflow doesn't exist!") + if not workflow.is_file(): + raise Exception("Workflow must be a file!") + + counter = 0 + for script in workflow.parent.glob("*.py"): + if script.stem == "__init__": + # Skip __init__.py + continue + import_script(script) + counter += 1 + + if counter == 0: + logger.warn(f"No custom modules imported. If you'd like to use a custom module, place a *.py file in same folder as the workflow file {workflow.parent}") + else: + logger.info(f"{counter} custom modules imported") + +def workflow_from_path(workflow: Path) -> Workflow: + """ + Imports the dependency modules for a workflow, and loads it from disk + """ + from napari_workflows._io_yaml_v1 import load_workflow + _import_workflow_modules(workflow) + return load_workflow(str(workflow)) diff --git a/core/lls_core/writers.py b/core/lls_core/writers.py new file mode 100644 index 00000000..4a00b566 --- /dev/null +++ b/core/lls_core/writers.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, List, Optional + +from lls_core.types import ArrayLike + +from pydantic import NonNegativeInt + +from lls_core.utils import make_filename_suffix +RoiIndex = Optional[NonNegativeInt] + +if TYPE_CHECKING: + from lls_core.models.lattice_data import LatticeData + import npy2bdv + from lls_core.models.results import ProcessedSlice, ImageSlice + from pathlib import Path + + +@dataclass +class Writer(ABC): + """ + A writer is an abstraction over the logic used to write image slices to disk + Writers need to work incrementally, in order that we don't need the entire multidimensional + image in memory at the same time + """ + lattice: LatticeData + roi_index: RoiIndex + written_files: List[Path] = field(default_factory=list, init=False) + + @abstractmethod + def write_slice(self, slice: ProcessedSlice[ArrayLike]): + """ + Writes a 3D image slice + """ + pass + + def close(self): + """ + Called when no more image slices are available, and the writer should finalise its output files + """ + pass + +@dataclass +class BdvWriter(Writer): + """ + A writer for for Fiji BigDataViewer output format + """ + bdv_writer: npy2bdv.BdvWriter = field(init=False) + + def __post_init__(self): + import npy2bdv + suffix = f"_{make_filename_suffix(roi_index=str(self.roi_index))}" if self.roi_index is not None else "" + path = self.lattice.make_filepath(suffix) + self.bdv_writer = npy2bdv.BdvWriter( + filename=str(path), + compression='gzip', + nchannels=len(self.lattice.channel_range), + subsamp=((1, 1, 1), (1, 2, 2), (2, 4, 4)), + overwrite=False + ) + self.written_files.append(path) + + def write_slice(self, slice: ProcessedSlice[ArrayLike]): + import numpy as np + self.bdv_writer.append_view( + np.array(slice.data), + # We need to use the indices here to ensure they start from 0 and + # are contiguous + time=slice.time_index, + channel=slice.channel_index, + voxel_size_xyz=(self.lattice.dx, self.lattice.dy, self.lattice.new_dz), + voxel_units='um' + ) + + def close(self): + self.bdv_writer.write_xml() + self.bdv_writer.close() + +@dataclass +class TiffWriter(Writer): + """ + A writer for for TIFF output format + """ + pending_slices: List[ImageSlice] = field(default_factory=list, init=False) + time: Optional[NonNegativeInt] = None + + def __post_init__(self): + self.pending_slices = [] + + def flush(self): + "Write out all pending slices" + import numpy as np + import tifffile + if len(self.pending_slices) > 0: + first_result = self.pending_slices[0] + images_array = np.swapaxes( + np.expand_dims([result.data for result in self.pending_slices], axis=0), + 1, 2 + ).astype("uint16") + # ImageJ TIFF can only handle 16-bit uints, not 32 + path = self.lattice.make_filepath( + make_filename_suffix( + channel=first_result.channel, + time=first_result.time, + roi_index=first_result.roi_index + ) + ) + tifffile.imwrite( + str(path), + data = images_array, + bigtiff=True, + resolution=(1./self.lattice.dx, 1./self.lattice.dy, "MICROMETER"), + metadata={'spacing': self.lattice.new_dz, 'unit': 'um', 'axes': 'TZCYX'}, + imagej=True + ) + self.written_files.append(path) + + # Reinitialise + self.pending_slices = [] + + def write_slice(self, slice: ProcessedSlice[ArrayLike]): + if slice.time != self.time: + self.flush() + + self.time = slice.time + self.pending_slices.append(slice) + + def close(self): + self.flush() diff --git a/core/pyproject.toml b/core/pyproject.toml index ab533555..2b472e28 100644 --- a/core/pyproject.toml +++ b/core/pyproject.toml @@ -39,21 +39,31 @@ dependencies = [ # Earlier versions don't have Python 3.11 binaries, and the sdist # is misconfigured: https://github.com/AllenCellModeling/aicspylibczi/issues/90 "aicspylibczi>=3.1.1", + "click", "dask", "dask[distributed]", "fsspec>=2022.8.2", - "pyclesperanto_prototype>=0.20.0", + "importlib_resources", + # Older tifffile are not compatible with aicsimageio, see: https://github.com/AllenCellModeling/aicsimageio/issues/518 "napari-workflows>=0.2.8", "npy2bdv", "numpy", "pandas", + "pyclesperanto_prototype>=0.20.0", + "pyopencl", + "pydantic~=1.0", "pyyaml", "read-roi", + "rich", "resource-backed-dask-array>=0.1.0", "scikit-image", - "tifffile", + "StrEnum", + "tifffile>=2023.3.15", #>=2022.8.12 + "toolz", "tqdm", - "typing_extensions" + "typer", + "typing_extensions>=4.7.0", + "xarray[parallel]", ] [project.urls] @@ -73,17 +83,37 @@ testing = [ "pytest-cov", # https://pytest-cov.readthedocs.io/en/latest/ "napari" # We use napari for type checking only, and not at runtime ] -psf = [ - "pycudadecon" +deconvolution = [ + "pycudadecon~=0.5", + "numpy<2" ] [project.scripts] lls-pipeline = "lls_core.cmds.__main__:main" +[tool.mypy] +plugins = [ + "pydantic.mypy" +] + +[tool.pydantic-mypy] +init_typed = false + [tool.pyright] typeCheckingMode = "off" reportUndefinedVariable = "error" reportMissingImports = "none" +reportMissingTypeStubs = false +reportUnknownVariableType = false +reportUnknownArgumentType = false +reportUnknownLambdaType = false +reportUnknownMemberType = false +reportUnknownParameterType = false +reportUntypedFunctionDecorator = false +reportMissingTypeArgument = false +reportPrivateUsage = false +reportPrivateImportUsage = false +reportUnnecessaryComparison = false [tool.fawltydeps] ignore_unused = [ diff --git a/core/tests/conftest.py b/core/tests/conftest.py new file mode 100644 index 00000000..8eaf5022 --- /dev/null +++ b/core/tests/conftest.py @@ -0,0 +1,120 @@ +from importlib_resources import as_file +from typer.testing import CliRunner +import pytest +from skimage.io import imsave +import numpy as np +from pathlib import Path +import pyclesperanto_prototype as cle +import tempfile +from numpy.typing import NDArray +from copy import copy +from lls_core.sample import resources + +from napari_workflows import Workflow +from napari_workflows._io_yaml_v1 import save_workflow + +@pytest.fixture +def runner() -> CliRunner: + return CliRunner() + +@pytest.fixture +def rbc_tiny(): + with as_file(resources / "RBC_tiny.czi") as image_path: + yield image_path + +@pytest.fixture +def multi_channel_time(): + with as_file(resources / "multich_multi_time.tif") as image_path: + yield image_path + +@pytest.fixture(params=[ + "LLS7_t1_ch1.czi", + "LLS7_t1_ch3.czi", + "LLS7_t2_ch1.czi", + "LLS7_t2_ch3.czi", +]) +def minimal_image_path(request: pytest.FixtureRequest): + """ + Fixture function that yields a minimal set of test images as file paths + """ + with as_file(resources / request.param) as image_path: + yield image_path + +@pytest.fixture(params=[ + "RBC_tiny.czi", + "RBC_lattice.tif", + "LLS7_t1_ch1.czi", + "LLS7_t1_ch3.czi", + "LLS7_t2_ch1.czi", + "LLS7_t2_ch3.czi", + "multich_multi_time.tif" +]) +def image_path(request: pytest.FixtureRequest): + """ + Fixture function that yields test images as file paths + """ + with as_file(resources / request.param) as image_path: + yield image_path + +@pytest.fixture +def image_workflow() -> Workflow: + # Simple segmentation workflow that returns an image + image_seg_workflow = Workflow() + image_seg_workflow.set("gaussian", cle.gaussian_blur, "deskewed_image", sigma_x=1, sigma_y=1, sigma_z=1) + image_seg_workflow.set("binarisation", cle.threshold, "gaussian", constant=0.5) + image_seg_workflow.set("labeling", cle.connected_components_labeling_box, "binarisation") + return image_seg_workflow + +@pytest.fixture +def table_workflow(image_workflow: Workflow) -> Workflow: + # Complex workflow that returns a tuple of (image, dict, list, int) + ret = copy(image_workflow) + ret.set("result", lambda x: ( + x, + { + "foo": 1, + "bar": 2 + }, + ["foo", "bar"], + 1 + ), "labeling") + return ret + +@pytest.fixture +def test_image() -> NDArray[np.float64]: + raw = np.zeros((5, 5, 5)) + raw[2, 2, 2] = 10 + return raw + +@pytest.fixture +def workflow_config(image_workflow: Workflow, test_image: NDArray): + # Create a config file + yield { + "input_image": test_image, + "workflow": image_workflow, + } + +@pytest.fixture +def workflow_config_cli(image_workflow: Workflow, test_image: NDArray): + with tempfile.TemporaryDirectory() as tempdir_str: + tempdir = Path(tempdir_str) + input = tempdir / "raw.tiff" + output = tempdir / "output" + output.mkdir(parents=True) + workflow_path = tempdir / "workflow.json" + save_workflow(str(workflow_path), image_workflow) + + # Create a zero array of shape 5x5x5 with a value of 10 at (2,4,2) + imsave(input, test_image) + assert input.exists() + + # Create a config file + yield { + key: str(val) + for key, val in + { + "input_image": input, + "save_dir": output, + "workflow": workflow_path, + }.items() + } diff --git a/core/tests/data/crop/roi_1.tif b/core/tests/data/crop/roi_1.tif new file mode 100644 index 00000000..b702ebe0 Binary files /dev/null and b/core/tests/data/crop/roi_1.tif differ diff --git a/core/tests/data/crop/roi_2.tif b/core/tests/data/crop/roi_2.tif new file mode 100644 index 00000000..d31ee2ca Binary files /dev/null and b/core/tests/data/crop/roi_2.tif differ diff --git a/core/tests/data/crop/two_rois.zip b/core/tests/data/crop/two_rois.zip new file mode 100644 index 00000000..2a60e8e1 Binary files /dev/null and b/core/tests/data/crop/two_rois.zip differ diff --git a/core/tests/params.py b/core/tests/params.py new file mode 100644 index 00000000..2bb1f208 --- /dev/null +++ b/core/tests/params.py @@ -0,0 +1,20 @@ +import pytest +import json +import yaml + +from lls_core.models.output import SaveFileType + +parameterized = pytest.mark.parametrize("args", [ + {"skew": "X"}, + {"skew": "Y"}, + {"angle": 30}, + {"physical_pixel_sizes": (1, 1, 1)}, + {"save_type": SaveFileType.h5}, + {"save_type": SaveFileType.tiff}, +]) + +# Allows parameterisation over two serialization formats +config_types = pytest.mark.parametrize(["save_func", "cli_param"], [ + (json.dump, "--json-config"), + (yaml.safe_dump, "--yaml-config") +]) diff --git a/core/tests/test_arg_parser.py b/core/tests/test_arg_parser.py index afcf0af1..26bb1e20 100644 --- a/core/tests/test_arg_parser.py +++ b/core/tests/test_arg_parser.py @@ -1,13 +1,18 @@ -from lls_core.cmds.__main__ import make_parser +from typer.main import get_command +from click import Context +from lls_core.cmds.__main__ import app + def test_voxel_parsing(): # Tests that we can parse voxel lists correctly - parser = make_parser() - args = parser.parse_args([ - "--input", "input", - "--output", "output", - "--processing", "deskew", - "--output_file_type", "tiff", - "--voxel_sizes", "1", "1", "1" + command = get_command(app) + ctx = Context(command) + parser = command.make_parser(ctx) + args, _, _ = parser.parse_args(args=[ + "process", + "input", + "--save-name", "output", + "--save-type", "tiff", + "--pixel-sizes", "1", "1", "1" ]) - assert args.voxel_sizes == [1.0, 1.0, 1.0] \ No newline at end of file + assert args["pixel_sizes"] == ("1", "1", "1") diff --git a/core/tests/test_batch_deskew_args.py b/core/tests/test_batch_deskew_args.py deleted file mode 100644 index eb09bc35..00000000 --- a/core/tests/test_batch_deskew_args.py +++ /dev/null @@ -1,53 +0,0 @@ -# Tests for napari_lattice using arguments and saving output files as h5, as well as tiff - -from skimage.io import imsave -import numpy as np -from pathlib import Path -import tempfile -from lls_core.cmds.__main__ import main as run_cli - -def create_image(path: Path): - # Create a zero array of shape 5x5x5 with a value of 10 at (2,4,2) - raw = np.zeros((5, 5, 5)) - raw[2, 4, 2] = 10 - # Save image as a tif filw in home directory - imsave(str(path), raw) - assert path.exists() - - -def test_batch_deskew_h5(): - """Write image to disk and then execute napari_lattice from terminal - Checks if an deskewed output file is created for both tif and h5 - """ - with tempfile.TemporaryDirectory() as out_dir: - out_dir = Path(out_dir) - input_file = out_dir / 'raw.tiff' - create_image(input_file) - # Batch deskew and save as h5 - run_cli([ - "--input", str(input_file), - "--output", str(out_dir), - "--processing", "deskew", - "--output_file_type", "h5" - ]) - - # checks if h5 files written - assert (out_dir / "raw" / "raw.h5").exists() - assert (out_dir / "raw" / "raw.xml").exists() - - -def test_batch_deskew_tiff(): - # tiff file deskew - with tempfile.TemporaryDirectory() as out_dir: - out_dir = Path(out_dir) - input_file = out_dir / 'raw.tiff' - create_image(input_file) - run_cli([ - "--input", str(input_file), - "--output", str(out_dir), - "--processing", "deskew", - "--output_file_type", "tiff" - ]) - - # checks if tiff written - assert (out_dir / "raw" / "C0T0_raw.tif").exists() diff --git a/core/tests/test_batch_deskew_yaml.py b/core/tests/test_batch_deskew_yaml.py deleted file mode 100644 index b8c4f3ea..00000000 --- a/core/tests/test_batch_deskew_yaml.py +++ /dev/null @@ -1,67 +0,0 @@ -# Tests for napari_lattice using the config file and saving ouput as h5 -# Thanks to DrLachie for cool function to write the config file - -from skimage.io import imread, imsave -import tempfile -import numpy as np -from pathlib import Path -from lls_core.cmds.__main__ import main as run_cli - -def write_config_file(config_settings: dict, output_file_location: Path): - # Write config file for napari_lattice - with output_file_location.open('w') as f: - for key, val in config_settings.items(): - if val is not None: - if type(val) is str: - print('%s: "%s"' % (key, val), file=f) - - if type(val) is int: - print('%s: %i' % (key, val), file=f) - - if type(val) is list: - print("%s:" % key, file=f) - for x in val: - if type(x) is int: - print(" - %i" % x, file=f) - else: - print(' - "%s"' % x, file=f) - - print("Config found written to %s" % output_file_location) - -def create_data(dir: Path) -> Path: - input_file = dir / 'raw.tiff' - config_location = dir / "config_deskew.yaml" - - # Create a zero array of shape 5x5x5 with a value of 10 at (2,4,2) - raw = np.zeros((5, 5, 5)) - raw[2, 4, 2] = 10 - # Save image as a tif filw in home directory - imsave(input_file, raw) - assert input_file.exists() - - config: dict[str, str] = { - "input": str(input_file), - "output": str(dir), - "processing": "deskew", - "output_file_type": "h5" - } - - write_config_file(config, config_location) - assert config_location.exists() - - return config_location - - -def test_yaml_deskew(): - """Write image to disk and then execute napari_lattice from terminal - Checks if an deskewed output file is created for both tif and h5 - """ - with tempfile.TemporaryDirectory() as test_dir: - test_dir = Path(test_dir) - config_location = create_data(test_dir) - # Batch deskew and save as h5 - run_cli(["--config", str(config_location)]) - - # checks if h5 files written - assert (test_dir / "raw" / "raw.h5").exists() - assert (test_dir / "raw" / "raw.xml").exists() diff --git a/core/tests/test_cli.py b/core/tests/test_cli.py new file mode 100644 index 00000000..7d96b33f --- /dev/null +++ b/core/tests/test_cli.py @@ -0,0 +1,103 @@ +# Tests for napari_lattice using arguments and saving output files as h5, as well as tiff + +from typing import Callable, List +import pytest +from aicsimageio.aics_image import AICSImage +from npy2bdv import BdvEditor +import numpy as np +from pathlib import Path +import tempfile +from tests.utils import invoke +import yaml + +def create_image(path: Path): + # Create a zero array of shape 5x5x5 with a value of 10 at (2,4,2) + raw = np.zeros((5, 5, 5)) + raw[2, 4, 2] = 10 + # Save image as a tif filw in home directory + AICSImage(raw).save(path) + assert path.exists() + + +def create_data(dir: Path) -> Path: + # Creates and returns a YAML config file + input_file = dir / 'raw.tiff' + config_location = dir / "config_deskew.yaml" + + # Create a zero array of shape 5x5x5 with a value of 10 at (2,4,2) + raw = np.zeros((5, 5, 5)) + raw[2, 4, 2] = 10 + # Save image as a tif filw in home directory + AICSImage(raw).save(input_file) + assert input_file.exists() + + config: dict[str, str] = { + "input_image": str(input_file), + "save_dir": str(dir), + "save_type": "h5" + } + + with config_location.open("w") as fp: + yaml.safe_dump(config, fp) + + return config_location + +def assert_tiff(output_dir: Path): + """Checks that a valid TIFF was generated in the directory""" + results = list(output_dir.glob("*.tif")) + assert len(results) > 0 + for result in results: + AICSImage(result).get_image_data() + +def assert_h5(output_dir: Path): + """Checks that a valid H5 was generated""" + h5s = list(output_dir.glob("*.h5")) + assert len(h5s) > 0 + assert len(list(output_dir.glob("*.xml"))) == len(h5s) + for h5 in h5s: + BdvEditor(str(h5)).read_view() + +@pytest.mark.parametrize( + ["flags", "check_fn"], + [ + [["--save-type", "h5"], assert_h5], + [["--save-type", "tiff"], assert_tiff], + [["--save-type", "tiff", "--time-start", "0", "--time-end", "1"], assert_tiff], + ] +) +def test_batch_deskew(flags: List[str], check_fn: Callable[[Path], None]): + """ + Write image to disk and then execute napari_lattice from terminal + Checks if an deskewed output file is created for both tif and h5 + """ + with tempfile.TemporaryDirectory() as _test_dir: + test_dir = Path(_test_dir) + + # Inputs + input_file = test_dir / "raw.tiff" + create_image(input_file) + + # Outputs + out_dir = test_dir / "output" + out_dir.mkdir() + + # Batch deskew and save as h5 + invoke([ + str(input_file), + "--save-dir", str(out_dir), + *flags + ]) + + check_fn(out_dir) + +def test_yaml_deskew(): + """ + Write image to disk and then execute napari_lattice from terminal + Checks if an deskewed output file is created for both tif and h5 + """ + with tempfile.TemporaryDirectory() as test_dir: + test_dir = Path(test_dir) + config_location = create_data(test_dir) + # Batch deskew and save as h5 + invoke(["--yaml-config", str(config_location)], ) + assert_h5(test_dir) diff --git a/core/tests/test_crop_deskew.py b/core/tests/test_crop_deskew.py index e873c3f7..9b3a4b21 100644 --- a/core/tests/test_crop_deskew.py +++ b/core/tests/test_crop_deskew.py @@ -1,35 +1,35 @@ -import pyclesperanto_prototype as cle -import numpy as np - -from lls_core.llsz_core import crop_volume_deskew - - -def test_crop_deskew(): - raw = np.zeros((5,5,5)) - raw[2,4,2] = 10 - deskew_angle = 60 - - deskewed = cle.deskew_y(raw,angle_in_degrees=deskew_angle).astype(raw.dtype) - - #print(np.argwhere(deskewed>0)) - - #Crop deskewed volume - ref_crop_deskew_img = deskewed[0:4,3:5,0:5] - - #Similarly, generate an roi with coordinates (x1=0,y1=3,z1=0) to (x2=5,y2=5,z2=5) - #Use this for cropping deskewed volume to get matching area - roi = np.array(((3,0),(3,5),(5,5),(5,0))) - z1 = 0 - z2 = 5 - - cropped_deskew_img = crop_volume_deskew(original_volume = raw, - deskewed_volume = deskewed, - roi_shape = roi, - angle_in_degrees = deskew_angle, - z_start = z1, - z_end = z2, - linear_interpolation=True).astype(raw.dtype) - - assert cropped_deskew_img[0, 1, 2] == ref_crop_deskew_img[0,1,2] - assert cropped_deskew_img[0, 0, 2] == ref_crop_deskew_img[0,0,2] - assert ref_crop_deskew_img.shape == cropped_deskew_img.shape \ No newline at end of file +import pyclesperanto_prototype as cle +import numpy as np + +from lls_core.llsz_core import crop_volume_deskew + + +def test_crop_deskew(): + raw = np.zeros((5,5,5)) + raw[2,4,2] = 10 + deskew_angle = 60 + + deskewed = cle.deskew_y(raw,angle_in_degrees=deskew_angle).astype(raw.dtype) + + #print(np.argwhere(deskewed>0)) + + #Crop deskewed volume + ref_crop_deskew_img = deskewed[0:4,3:5,0:5] + + #Similarly, generate an roi with coordinates (x1=0,y1=3,z1=0) to (x2=5,y2=5,z2=5) + #Use this for cropping deskewed volume to get matching area + roi = np.array(((3,0),(3,5),(5,5),(5,0))) + z1 = 0 + z2 = 5 + + cropped_deskew_img = crop_volume_deskew(original_volume = raw, + deskewed_volume = deskewed, + roi_shape = roi, + angle_in_degrees = deskew_angle, + z_start = z1, + z_end = z2, + linear_interpolation=True).astype(raw.dtype) + + assert cropped_deskew_img[0, 1, 2] == ref_crop_deskew_img[0,1,2] + assert cropped_deskew_img[0, 0, 2] == ref_crop_deskew_img[0,0,2] + assert ref_crop_deskew_img.shape == cropped_deskew_img.shape diff --git a/core/tests/test_deskew.py b/core/tests/test_deskew.py index 8cd06c93..adbcd6cc 100644 --- a/core/tests/test_deskew.py +++ b/core/tests/test_deskew.py @@ -1,20 +1,28 @@ -#filename and function name should start with "test_" when using pytest -import pyclesperanto_prototype as cle -import numpy as np -from lls_core.lattice_data import lattice_fom_array - -def test_deskew(): - - raw = np.zeros((5,5,5)) - raw[2,0,0] = 10 - - deskewed = cle.deskew_y(raw,angle_in_degrees=60) - - #np.argwhere(deskewed>0) - assert deskewed.shape == (4,8,5) - assert deskewed[2,2,0] == 0.5662433505058289 - -def test_lattice_data_deskew(): - raw = np.zeros((5, 5, 5)) - lattice = lattice_fom_array(raw, physical_pixel_sizes = (1, 1, 1), save_name="test") - assert lattice.deskew_vol_shape == [2, 9, 5] +#filename and function name should start with "test_" when using pytest +import pyclesperanto_prototype as cle +import numpy as np +from lls_core.models.lattice_data import LatticeData +from xarray import DataArray +import tempfile + +def test_deskew(): + + raw = np.zeros((5,5,5)) + raw[2,0,0] = 10 + + deskewed = cle.deskew_y(raw,angle_in_degrees=60) + + #np.argwhere(deskewed>0) + assert deskewed.shape == (4,8,5) + assert deskewed[2,2,0] == 0.5662433505058289 + +def test_lattice_data_deskew(): + raw = DataArray(np.zeros((5, 5, 5)), dims=["X", "Y", "Z"]) + with tempfile.TemporaryDirectory() as tmpdir: + lattice = LatticeData( + input_image=raw, + physical_pixel_sizes = (1, 1, 1), + save_name="test", + save_dir=tmpdir + ) + assert lattice.derived.deskew_vol_shape == (2, 9, 5) diff --git a/core/tests/test_process.py b/core/tests/test_process.py new file mode 100644 index 00000000..13d6efd4 --- /dev/null +++ b/core/tests/test_process.py @@ -0,0 +1,223 @@ +from typing import Any, List, Optional +import pytest +from lls_core.models import LatticeData +from lls_core.models.crop import CropParams +from lls_core.models.output import SaveFileType +from lls_core.sample import resources +from importlib_resources import as_file +import tempfile +from pathlib import Path +from napari_workflows import Workflow +from pytest import FixtureRequest + +from .params import parameterized + +root = Path(__file__).parent / "data" + + +def open_psf(name: str): + with as_file(resources / "psfs" / "zeiss_simulated" / name) as path: + return path + + +@parameterized +def test_process(minimal_image_path: str, args: dict): + # Processes a minimal set of images, with multiple parameter combinations + for slice in ( + LatticeData.parse_obj({"input_image": minimal_image_path, **args}) + .process() + .slices + ): + assert slice.data.ndim == 3 + + +def test_process_all(image_path: str): + # Processes all input images, but without parameter combinations + for slice in ( + LatticeData.parse_obj({"input_image": image_path}).process().slices + ): + assert slice.data.ndim == 3 + + +@parameterized +def test_save(minimal_image_path: str, args: dict): + with tempfile.TemporaryDirectory() as tempdir: + LatticeData.parse_obj( + {"input_image": minimal_image_path, "save_dir": tempdir, **args} + ).process().save_image() + results = list(Path(tempdir).iterdir()) + assert len(results) > 0 + + +def test_process_deconv_crop(): + for slice in ( + LatticeData.parse_obj( + { + "input_image": root / "raw.tif", + "deconvolution": { + "psf": [root / "psf.tif"], + }, + "crop": CropParams( + roi_list=[[[0, 0], [0, 110], [95, 0], [95, 110]]] + ), + } + ) + .process() + .slices + ): + assert slice.data.ndim == 3 + + +def test_process_time_range(multi_channel_time: Path): + from lls_core.models.output import SaveFileType + + with tempfile.TemporaryDirectory() as outdir: + LatticeData.parse_obj( + { + "input_image": multi_channel_time, + # Channels 2 & 3 + "channel_range": range(1, 3), + # Time point 2 + "time_range": range(1, 2), + "save_dir": outdir, + "save_type": SaveFileType.h5, + } + ).save() + + +@pytest.mark.parametrize(["background"], [(1,), ("auto",), ("second_last",)]) +@parameterized +def test_process_deconvolution(args: dict, background: Any): + for slice in ( + LatticeData.parse_obj( + { + "input_image": root / "raw.tif", + "deconvolution": { + "psf": [root / "psf.tif"], + "background": background, + }, + **args, + } + ) + .process() + .slices + ): + assert slice.data.ndim == 3 + + +@parameterized +@pytest.mark.parametrize( + ["workflow_name"], [("image_workflow",), ("table_workflow",)] +) +def test_process_workflow( + args: dict, request: FixtureRequest, workflow_name: str +): + from pandas import DataFrame + + workflow: Workflow = request.getfixturevalue(workflow_name) + with tempfile.TemporaryDirectory() as tmpdir: + for roi, output in ( + LatticeData.parse_obj( + { + "input_image": root / "raw.tif", + "workflow": workflow, + "save_dir": tmpdir, + **args, + } + ) + .process_workflow() + .process() + ): + assert roi is None or isinstance(roi, int) + assert isinstance(output, (Path, DataFrame)) + +def test_table_workflow( + rbc_tiny: Path, table_workflow: Workflow +): + with tempfile.TemporaryDirectory() as _tmpdir: + tmpdir = Path(_tmpdir) + results = set(LatticeData.parse_obj( + { + "input_image": rbc_tiny, + "workflow": table_workflow, + "save_dir": tmpdir + } + ).process_workflow().save()) + # There should be one output for each element of the tuple + assert {result.name for result in results} == {'RBC_tiny_deskewed_output_3.csv', 'RBC_tiny_deskewed.h5', 'RBC_tiny_deskewed_output_1.csv', 'RBC_tiny_deskewed_output_2.csv'} + +@pytest.mark.parametrize( + ["roi_subset"], + [ + [None], + [[0]], + [[0, 1]], + ], +) +@parameterized +def test_process_crop_roi_file(args: dict, roi_subset: Optional[List[int]]): + # Test cropping with a roi zip file, selecting different subsets from that file + with as_file(resources / "RBC_tiny.czi") as lattice_path: + rois = root / "crop" / "two_rois.zip" + slices = list( + LatticeData.parse_obj( + { + "input_image": lattice_path, + "crop": {"roi_list": [rois], "roi_subset": roi_subset}, + **args, + } + ) + .process() + .slices + ) + # Check we made the correct number of slices + assert len(slices) == len(roi_subset) if roi_subset is not None else 2 + for slice in slices: + # Check correct dimensionality + assert slice.data.ndim == 3 + + +def test_process_crop_workflow(table_workflow: Workflow): + # Test cropping with a roi zip file, selecting different subsets from that file + with as_file( + resources / "RBC_tiny.czi" + ) as lattice_path, tempfile.TemporaryDirectory() as outdir: + LatticeData.parse_obj( + { + "input_image": lattice_path, + "workflow": table_workflow, + "save_dir": outdir, + "save_type": SaveFileType.h5, + "crop": { + "roi_list": [root / "crop" / "two_rois.zip"], + }, + } + ).process().save_image() + # Two separate H5 files should be created in this scenario: one for each ROI + results = list(Path(outdir).glob("*.h5")) + assert len(results) == 2 + + +@pytest.mark.parametrize( + ["roi"], + [ + [[[(174.0, 24.0), (174.0, 88.0), (262.0, 88.0), (262.0, 24.0)]]], + [[[(174.13, 24.2), (173.98, 87.87), (262.21, 88.3), (261.99, 23.79)]]], + ], +) +@parameterized +def test_process_crop_roi_manual(args: dict, roi: List): + # Test manually provided ROIs, both with integer and float values + with as_file(resources / "RBC_tiny.czi") as lattice_path: + for slice in ( + LatticeData.parse_obj( + { + "input_image": lattice_path, + "crop": {"roi_list": roi}, + **args, + } + ) + .process() + .slices + ): + assert slice.data.ndim == 3 diff --git a/core/tests/test_validation.py b/core/tests/test_validation.py new file mode 100644 index 00000000..bcf12106 --- /dev/null +++ b/core/tests/test_validation.py @@ -0,0 +1,54 @@ +from pathlib import Path +from lls_core.models.crop import CropParams +from lls_core.models.lattice_data import LatticeData +from lls_core.models.deskew import DeskewParams +from lls_core.models.output import OutputParams +import pytest +from pydantic import ValidationError +import tempfile +from unittest.mock import patch, PropertyMock + +def test_default_save_dir(rbc_tiny: Path): + # Test that the save dir is inferred to be the input dir + params = LatticeData(input_image=rbc_tiny) + assert params.save_dir == rbc_tiny.parent + +def test_auto_z_range(rbc_tiny: Path): + # Tests that the Z range is automatically set, and it is set + # based on the size of the deskewed volume + params = LatticeData(input_image=rbc_tiny, crop=CropParams( + roi_list=[[[0, 0], [0, 1], [1, 0], [1, 1]]] + )) + assert params.crop.z_range == (0, 59) + +def test_reject_crop(): + # Tests that the parameters fail validation if cropping is specified without an ROI + with pytest.raises(ValidationError): + CropParams( + roi_list=[] + ) + +def test_pixel_tuple_order(rbc_tiny: Path): + # Tests that a tuple of Z, Y, X is appropriately assigned in the right order + deskew = DeskewParams( + input_image=rbc_tiny, + physical_pixel_sizes=(1., 2., 3.) + ) + + assert deskew.physical_pixel_sizes.X == 3. + assert deskew.physical_pixel_sizes.Y == 2. + assert deskew.physical_pixel_sizes.Z == 1. + +def test_allow_trailing_slash(): + with tempfile.TemporaryDirectory() as tmpdir: + output = OutputParams( + save_dir=f"{tmpdir}/" + ) + assert str(output.save_dir) == tmpdir + +def test_infer_czi_pixel_sizes(rbc_tiny: Path): + mock = PropertyMock() + with patch("aicsimageio.AICSImage.physical_pixel_sizes", new=mock): + DeskewParams(input_image=rbc_tiny) + # The AICSImage should be queried for the pixel sizes + assert mock.called diff --git a/core/tests/test_workflows.py b/core/tests/test_workflows.py index bd285f20..85785c2e 100644 --- a/core/tests/test_workflows.py +++ b/core/tests/test_workflows.py @@ -1,126 +1,108 @@ -from skimage.io import imread, imsave -import os -import numpy as np -from pathlib import Path -import platform -import pyclesperanto_prototype as cle - -from napari_workflows import Workflow -from napari_workflows._io_yaml_v1 import load_workflow, save_workflow - -from lls_core.cmds.__main__ import main as run_cli - -# For testing in Windows -if platform.system() == "Windows": - home_dir = str(Path.home()) - home_dir = home_dir.replace("\\", "\\\\") - img_dir = home_dir + "\\\\raw.tiff" - workflow_location = home_dir + "\\\\deskew_segment.yaml" - config_location = home_dir + "\\\\config_deskew.yaml" -else: - home_dir = str(Path.home()) - img_dir = os.path.join(home_dir, "raw.tiff") - workflow_location = os.path.join(home_dir, "deskew_segment.yaml") - config_location = os.path.join(home_dir, "config_deskew.yaml") - - -def write_config_file(config_settings, output_file_location): - with open(output_file_location, 'w') as f: - for key, val in config_settings.items(): - if val is not None: - if type(val) is str: - print('%s: "%s"' % (key, val), file=f) - - if type(val) is int: - print('%s: %i' % (key, val), file=f) - - if type(val) is list: - print("%s:" % key, file=f) - for x in val: - if type(x) is int: - print(" - %i" % x, file=f) - else: - print(' - "%s"' % x, file=f) - - print("Config found written to %s" % output_file_location) - - -def create_data(): - # Create a zero array of shape 5x5x5 with a value of 10 at (2,4,2) - raw = np.zeros((5, 5, 5)) - raw[2, 2, 2] = 10 - # Save image as a tif filw in home directory - imsave(img_dir, raw) - assert os.path.exists(img_dir) - - # Create a config file - config = { - "input": img_dir, - "output": home_dir, - "processing": "workflow", - "workflow_path": workflow_location, - "output_file_type": "h5"} - - write_config_file(config, config_location) - assert os.path.exists(config_location) - - -def create_workflow(): - # Zeiss lattice - voxel_size_x_in_microns = 0.14499219272808386 - voxel_size_y_in_microns = 0.14499219272808386 - voxel_size_z_in_microns = 0.3 - deskewing_angle_in_degrees = 30.0 - - # Instantiate segmentation workflow - image_seg_workflow = Workflow() - - image_seg_workflow.set("gaussian", cle.gaussian_blur, - "input", sigma_x=1, sigma_y=1, sigma_z=1) - - image_seg_workflow.set("binarisation", cle.threshold, - "gaussian", constant=0.5) - - image_seg_workflow.set( - "labeling", cle.connected_components_labeling_box, "binarisation") - - save_workflow(workflow_location, image_seg_workflow) - - assert os.path.exists(workflow_location) - - -def test_napari_workflow(): - """Test napari workflow to see if it works before we run it using napari_lattice - This is without deskewing - """ - create_data() - create_workflow() - - image_seg_workflow = load_workflow(workflow_location) - - # Open the saved image from above - raw = imread(img_dir) - # Set input image to be the "raw" image - image_seg_workflow.set("input", raw) - labeling = image_seg_workflow.get("labeling") - assert (labeling[2, 2, 2] == 1) - - -def test_workflow_lattice(): - """Test workflow by loading into napari lattice - This will apply deskewing before processing the workflow - """ - # Deskew, apply workflow and save as h5 - run_cli(["--config", config_location]) - - # checks if h5 file written - h5_img = os.path.join(home_dir, "raw", "_0_raw.h5") - assert os.path.exists(h5_img) - - import npy2bdv - h5_file = npy2bdv.npy2bdv.BdvEditor(h5_img) - - label_img = h5_file.read_view(time=0, channel=0) - - assert (label_img.shape == (3, 14, 5)) - assert (label_img[1, 6, 2] == 1) +from typing import Callable +from copy import copy +from numpy.typing import NDArray + +from napari_workflows import Workflow +import tempfile + +from pandas import DataFrame +from lls_core.models.lattice_data import LatticeData + +from tests.utils import invoke +from pathlib import Path +from .params import config_types +from .utils import invoke, valid_image_path + + +def test_napari_workflow(image_workflow: Workflow, test_image: NDArray): + """ + Test napari workflow to see if it works before we run it using napari_lattice + This is without deskewing + """ + workflow = copy(image_workflow) + # Set input image to be the "raw" image + workflow.set("deskewed_image", test_image) + labeling = workflow.get("labeling") + assert labeling[2, 2, 2] == 1 + +@config_types +def test_workflow_cli(workflow_config_cli: dict, save_func: Callable, cli_param: str): + """ + Test workflow processing via CLI + This will apply deskewing before processing the workflow + """ + with tempfile.NamedTemporaryFile(mode="w") as fp: + save_func(workflow_config_cli, fp) + fp.flush() + + # Deskew, apply workflow and save as h5 + invoke([ + cli_param, fp.name + ]) + + # checks if h5 file written + save_dir = Path(workflow_config_cli["save_dir"]) + saved_files = list(save_dir.glob("*.h5")) + assert len(saved_files) > 0 + assert len(list(save_dir.glob("*.xml"))) > 0 + + import npy2bdv + for h5_img in saved_files: + h5_file = npy2bdv.npy2bdv.BdvEditor(str(h5_img)) + label_img = h5_file.read_view(time=0, channel=0) + assert label_img.shape == (3, 14, 5) + assert label_img[1, 6, 2] == 1 + +def test_image_workflow(minimal_image_path: Path, image_workflow: Workflow): + # Test that a regular workflow that returns an image directly works + with tempfile.TemporaryDirectory() as tmpdir: + for roi, output in LatticeData( + input_image = minimal_image_path, + workflow = image_workflow, + save_dir = tmpdir + ).process_workflow().process(): + assert isinstance(output, Path) + assert valid_image_path(output) + +def test_table_workflow(minimal_image_path: Path, table_workflow: Workflow): + # Test a complex workflow that returns a tuple of images and data + with tempfile.TemporaryDirectory() as tmpdir: + params = LatticeData( + input_image = minimal_image_path, + workflow = table_workflow, + save_dir = tmpdir + ) + for _roi, output in params.process_workflow().process(): + assert isinstance(output, (DataFrame, Path)) + if isinstance(output, DataFrame): + nrow, ncol = output.shape + assert nrow == params.nslices + assert ncol > 0 + # Check that time and channel are included + assert output.iloc[0, 0] == "T0" + assert output.iloc[0, 1] == "C0" + else: + assert valid_image_path(output) + +def test_argument_order(rbc_tiny: Path): + # Tests that only the first unfilled argument is passed an array + with tempfile.TemporaryDirectory() as tmpdir: + params = LatticeData( + input_image = rbc_tiny, + workflow = "core/tests/workflows/argument_order/test_workflow.yml", + save_dir = tmpdir + ) + for roi, output in params.process_workflow().process(): + print(output) + +def test_sum_preview(rbc_tiny: Path): + import numpy as np + # Tests that we can sum the preview result. This is required for the plugin + with tempfile.TemporaryDirectory() as tmpdir: + params = LatticeData( + input_image = rbc_tiny, + workflow = "core/tests/workflows/binarisation/workflow.yml", + save_dir = tmpdir + ) + preview = params.process_workflow().extract_preview() + np.sum(preview, axis=(1, 2)) diff --git a/core/tests/utils.py b/core/tests/utils.py new file mode 100644 index 00000000..c9744a03 --- /dev/null +++ b/core/tests/utils.py @@ -0,0 +1,17 @@ +from pathlib import Path +from typing import Sequence +from typer.testing import CliRunner +from lls_core.cmds.__main__ import app +import npy2bdv +from aicsimageio import AICSImage + +def invoke(args: Sequence[str]): + CliRunner().invoke(app, args, catch_exceptions=False) + +def valid_image_path(path: Path) -> bool: + if path.suffix in {".hdf5", ".h5"}: + npy2bdv.npy2bdv.BdvEditor(str(path)).read_view() + return True + else: + AICSImage(path).get_image_data() + return True diff --git a/core/tests/workflows/argument_order/custom_function.py b/core/tests/workflows/argument_order/custom_function.py new file mode 100644 index 00000000..07ac327e --- /dev/null +++ b/core/tests/workflows/argument_order/custom_function.py @@ -0,0 +1,7 @@ +from numpy import ndarray + +def test(a, b, c): + assert isinstance(a, ndarray) + assert isinstance(b, str) + assert isinstance(c, int) + return a diff --git a/core/tests/workflows/argument_order/test_workflow.yml b/core/tests/workflows/argument_order/test_workflow.yml new file mode 100644 index 00000000..ba479d79 --- /dev/null +++ b/core/tests/workflows/argument_order/test_workflow.yml @@ -0,0 +1,7 @@ +!!python/object:napari_workflows._workflow.Workflow +_tasks: + test: !!python/tuple + - !!python/name:custom_function.test '' + - deskewed_image + - "abc" + - 3 diff --git a/core/tests/workflows/binarisation/workflow.yml b/core/tests/workflows/binarisation/workflow.yml new file mode 100644 index 00000000..78ee333c --- /dev/null +++ b/core/tests/workflows/binarisation/workflow.yml @@ -0,0 +1,14 @@ +!!python/object:napari_workflows._workflow.Workflow +_tasks: + binarisation: !!python/tuple + - !!python/name:pyclesperanto_prototype.greater_constant '' + - median + - null + - 100 + median: !!python/tuple + - !!python/name:pyclesperanto_prototype.median_sphere '' + - deskewed_image + - null + - 2 + - 2 + - 2 diff --git a/plugin/napari_lattice/_dock_widget.py b/plugin/napari_lattice/_dock_widget.py deleted file mode 100644 index f488e831..00000000 --- a/plugin/napari_lattice/_dock_widget.py +++ /dev/null @@ -1,1153 +0,0 @@ -import os -import sys -import yaml -import numpy as np -from pathlib import Path -import dask.array as da -import pandas as pd -from typing import Union, Optional, Callable, Literal -from enum import Enum - -from magicclass.wrappers import set_design -from magicgui import magicgui -from magicclass import magicclass, field, vfield, set_options, MagicTemplate -from magicclass.utils import click -from qtpy.QtCore import Qt - -from napari.layers import Layer, Shapes -from napari.types import ImageData -from napari.utils import history - -import pyclesperanto_prototype as cle - -from napari.types import ImageData, ShapesData - -from tqdm import tqdm - -from napari_workflows import Workflow, WorkflowManager -from napari_workflows._io_yaml_v1 import load_workflow - -from lls_core import config, DeskewDirection, DeconvolutionChoice, SaveFileType, Log_Levels -from lls_core.io import LatticeData, save_img, save_img_workflow -from lls_core.utils import read_imagej_roi, get_first_last_image_and_task, modify_workflow_task, get_all_py_files, as_type, process_custom_workflow_output, check_dimensions, load_custom_py_modules -from lls_core.llsz_core import crop_volume_deskew -from lls_core.deconvolution import read_psf, pycuda_decon, skimage_decon - -from napari_lattice.ui_core import _Preview, _Deskew_Save -from napari_lattice._reader import lattice_from_napari - -# Enable Logging -import logging -logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) - -class LastDimensionOptions(Enum): - channel = "Channel" - time = "Time" - get_from_metadata = "Get from Metadata" - -@magicclass(widget_type="split") -class LLSZWidget(MagicTemplate): - - @magicclass(widget_type="split") - class LlszMenu(MagicTemplate): - open_file: bool = False - lattice: LatticeData = None - skew_dir: DeskewDirection - angle_value: float - deskew_func: Callable - - main_heading = field("

Napari Lattice: Visualization & Analysis

", widget_type="Label") - heading1 = field("Drag and drop an image file onto napari.\nOnce image has opened, initialize the\nplugin by clicking the button below.\nEnsure the image layer and voxel sizes are accurate in the prompt.\n If everything initalises properly, the button turns green.", widget_type="Label") - - @set_design(background_color="magenta", font_family="Consolas", visible=True, text="Initialize Plugin", max_height=75, font_size=13) - @set_options(pixel_size_dx={"widget_type": "FloatSpinBox", "value": 0.1449922, "step": 0.000000001}, - pixel_size_dy={"widget_type": "FloatSpinBox", - "value": 0.1449922, "step": 0.000000001}, - pixel_size_dz={"widget_type": "FloatSpinBox", - "value": 0.3, "step": 0.000000001}, - angle={"widget_type": "FloatSpinBox", - "value": 30, "step": 0.1}, - select_device={"widget_type": "ComboBox", "choices": cle.available_device_names( - ), "value": cle.available_device_names()[0]}, - last_dimension_channel={"widget_type": "ComboBox", - "label": "Set Last dimension (channel/time)", "tooltip": "If the last dimension is initialised incorrectly, you can assign it as either channel/time"}, - merge_all_channel_layers={"widget_type": "CheckBox", "value": True, "label": "Merge all napari layers as channels", - "tooltip": "Use this option if the channels are in separate layers. napari-lattice requires all channels to be in same layer"}, - skew_dir={"widget_type": "ComboBox", "choices": DeskewDirection, "value": DeskewDirection.Y, - "label": "Direction of skew (Y or X)", "tooltip": "Skew direction when image is acquired. Ask your microscopist for details"}, - set_logging={"widget_type": "ComboBox", "choices": Log_Levels, "value": Log_Levels.INFO, - "label": "Log Level", "tooltip": "Only use for debugging. Leave it as INFO for regular operation"} - ) - def Choose_Image_Layer(self, - img_layer: Layer, - pixel_size_dx: float = 0.1449922, - pixel_size_dy: float = 0.1449922, - pixel_size_dz: float = 0.3, - angle: float = 30, - select_device: str = cle.available_device_names()[ - 0], - last_dimension_channel: LastDimensionOptions = LastDimensionOptions.get_from_metadata, - merge_all_channel_layers: bool = False, - skew_dir: DeskewDirection=DeskewDirection.Y, - set_logging: Log_Levels=Log_Levels.INFO): - - logger.setLevel(set_logging.value) - config.log_level = set_logging.value - logger.info(f"Logging set to {set_logging}") - logger.info("Using existing image layer") - - if self.parent_viewer is None: - raise Exception("This function can only be used when inside of a Napari viewer") - - # Select device for processing - cle.select_device(select_device) - - #assert skew_dir in DeskewDirection, "Skew direction not recognised. Enter either Y or X" - LLSZWidget.LlszMenu.skew_dir = skew_dir - LLSZWidget.LlszMenu.angle_value = angle - - if LLSZWidget.LlszMenu.skew_dir == DeskewDirection.Y: - LLSZWidget.LlszMenu.deskew_func = cle.deskew_y - #LLSZWidget.LlszMenu.skew_dir = DeskewDirection.Y - elif LLSZWidget.LlszMenu.skew_dir == DeskewDirection.X: - LLSZWidget.LlszMenu.deskew_func = cle.deskew_x - #LLSZWidget.LlszMenu.skew_dir = DeskewDirection.X - - # merge all napari image layers as one multidimensional image - if merge_all_channel_layers: - from napari.layers.utils.stack_utils import images_to_stack - # get list of napari layers as a list - layer_list = list(self.parent_viewer.layers) - # if more than one layer - if len(layer_list) > 1: - # convert the list of images into a stack - new_layer = images_to_stack(layer_list) - # select all current layers - self.parent_viewer.layers.select_all() - # remove selected layers - self.parent_viewer.layers.remove_selected() - # add the new composite image layer - self.parent_viewer.add_layer(new_layer) - img_layer = new_layer - - LLSZWidget.LlszMenu.lattice = lattice_from_napari( - img=img_layer, - last_dimension=None if last_dimension_channel == LastDimensionOptions.get_from_metadata else last_dimension_channel, - angle=angle, - skew=LLSZWidget.LlszMenu.skew_dir, - physical_pixel_sizes=(pixel_size_dx, pixel_size_dy, pixel_size_dz) - ) - #LLSZWidget.LlszMenu.aics = LLSZWidget.LlszMenu.lattice.data - - # LLSZWidget.LlszMenu.dask = False # Use GPU by default - - # We initialise these variables here, but they can be changed in the deconvolution section - # list to store psf images for each channel - LLSZWidget.LlszMenu.lattice.psf = [] - LLSZWidget.LlszMenu.lattice.psf_num_iter = 10 - LLSZWidget.LlszMenu.lattice.decon_processing = DeconvolutionChoice.cpu - # list to store otf paths for each channel (Deprecated) - LLSZWidget.LlszMenu.lattice.otf_path = [] - # if not using GPU - #LLSZWidget.LlszMenu.dask = not use_GPU - - # flag for ensuring a file has been opened and plugin initialised - LLSZWidget.LlszMenu.open_file = True - - logger.info( - f"Pixel size (ZYX) in microns: {LLSZWidget.LlszMenu.lattice.dz,LLSZWidget.LlszMenu.lattice.dy,LLSZWidget.LlszMenu.lattice.dx}") - logger.info( - f"Dimensions of image layer (ZYX): {list(LLSZWidget.LlszMenu.lattice.data.shape[-3:])}") - logger.info( - f"Dimensions of deskewed image (ZYX): {LLSZWidget.LlszMenu.lattice.deskew_vol_shape}") - logger.info( - f"Deskewing angle is :{LLSZWidget.LlszMenu.lattice.angle}") - logger.info( - f"Deskew Direction :{LLSZWidget.LlszMenu.lattice.skew}") - # Add dimension labels correctly - # if channel, and not time - if LLSZWidget.LlszMenu.lattice.time == 0 and (last_dimension_channel or LLSZWidget.LlszMenu.lattice.channels > 0): - self.parent_viewer.dims.axis_labels = ('Channel', "Z", "Y", "X") - # if no channel, but has time - elif LLSZWidget.LlszMenu.lattice.channels == 0 and LLSZWidget.LlszMenu.lattice.time > 0: - self.parent_viewer.dims.axis_labels = ('Time', "Z", "Y", "X") - # if it has channels - elif LLSZWidget.LlszMenu.lattice.channels > 1: - # If merge to stack is used, channel slider goes to the bottom - if int(self.parent_viewer.dims.dict()["range"][0][1]) == LLSZWidget.LlszMenu.lattice.channels: - self.parent_viewer.dims.axis_labels = ('Channel', "Time", "Z", "Y", "X") - else: - self.parent_viewer.dims.axis_labels = ('Time', "Channel", "Z", "Y", "X") - # if channels initialized by aicsimagio, then channels is 1 - elif LLSZWidget.LlszMenu.lattice.channels == 1 and LLSZWidget.LlszMenu.lattice.time > 1: - self.parent_viewer.dims.axis_labels = ('Time', "Z", "Y", "X") - - logger.info(f"Initialised") - self["Choose_Image_Layer"].background_color = "green" - self["Choose_Image_Layer"].text = "Plugin Initialised" - - return - - # Pycudadecon library for deconvolution - # options={"enabled": True}, - deconvolution = vfield(bool, name="Use Deconvolution") - deconvolution.value = False - - @deconvolution.connect - def _set_decon(self): - if self.deconvolution: - logger.info("Deconvolution Activated") - LLSZWidget.LlszMenu.deconvolution.value = True - else: - logger.info("Deconvolution Disabled") - LLSZWidget.LlszMenu.deconvolution.value = False - return - - @set_design(background_color="magenta", font_family="Consolas", visible=True, text="Click to select PSFs for deconvolution", max_height=75, font_size=11) - @set_options(header=dict(widget_type="Label", label="

Enter path to the PSF images

"), - psf_ch1_path={"widget_type": "FileEdit", - "label": "Channel 1:"}, - psf_ch2_path={"widget_type": "FileEdit", - "label": "Channel 2"}, - psf_ch3_path={"widget_type": "FileEdit", - "label": "Channel 3"}, - psf_ch4_path={"widget_type": "FileEdit", - "label": "Channel 4"}, - device_option={ - "widget_type": "ComboBox", "label": "Choose processing device", "choices": DeconvolutionChoice}, - no_iter={ - "widget_type": "SpinBox", "label": "No of iterations (Deconvolution)", "value": 10, "min": 1, "max": 50, "step": 1} - ) - def deconvolution_gui(self, - header, - psf_ch1_path: Path, - psf_ch2_path: Path, - psf_ch3_path: Path, - psf_ch4_path: Path, - device_option, - no_iter: int): - """GUI for Deconvolution button""" - LLSZWidget.LlszMenu.lattice.decon_processing = device_option - assert LLSZWidget.LlszMenu.deconvolution.value == True, "Deconvolution is set to False. Tick the box to activate deconvolution." - LLSZWidget.LlszMenu.lattice.psf = list(read_psf([ - psf_ch1_path, - psf_ch2_path, - psf_ch3_path, - psf_ch4_path, - ], - device_option, - lattice_class=LLSZWidget.LlszMenu.lattice - )) - LLSZWidget.LlszMenu.lattice.psf_num_iter = no_iter - self["deconvolution_gui"].background_color = "green" - self["deconvolution_gui"].text = "PSFs added" - - @magicclass(widget_type="collapsible") - class Preview: - @magicgui(header=dict(widget_type="Label", label="

Preview Deskew

"), - time=dict(label="Time:", max=2**15), - channel=dict(label="Channel:"), - call_button="Preview") - def Preview_Deskew(self, - header, - time: int, - channel: int, - img_data: ImageData): - """ - Preview deskewed data for a single timepoint and channel - - """ - _Preview(LLSZWidget, - self, - time, - channel, - img_data) - return - - # Tabbed Widget container to house all the widgets - @magicclass(widget_type="tabbed", name="Functions") - class WidgetContainer(MagicTemplate): - - @magicclass(name="Deskew", widget_type="scrollable", properties={"min_width": 100}) - class DeskewWidget(MagicTemplate): - - @magicgui(header=dict(widget_type="Label", label="

Deskew and Save

"), - time_start=dict(label="Time Start:", max=2**20), - time_end=dict(label="Time End:", value=1, max=2**20), - ch_start=dict(label="Channel Start:"), - ch_end=dict(label="Channel End:", value=1), - save_as_type={ - "label": "Save as filetype:", "choices": SaveFileType, "value": SaveFileType.h5}, - save_path=dict(mode='d', label="Directory to save"), - call_button="Save") - def Deskew_Save(self, - header, - time_start: int, - time_end: int, - ch_start: int, - ch_end: int, - save_as_type: str, - save_path: Path = Path(history.get_save_history()[0])): - """ Widget to Deskew and Save Data""" - _Deskew_Save(LLSZWidget, - time_start, - time_end, - ch_start, - ch_end, - save_as_type, - save_path) - return - - @magicclass(name="Crop and Deskew", widget_type="scrollable") - class CropWidget(MagicTemplate): - - # add function for previewing cropped image - @magicclass(name="Cropping Preview", widget_type="scrollable", properties={ - "min_width": 100, - "shapes_layer": Shapes - }) - class Preview_Crop_Menu(MagicTemplate): - shapes_layer: Shapes - - @set_design(font_size=10, text="Click to activate Cropping Layer", background_color="magenta") - @click(enables=["Import_ImageJ_ROI", "Crop_Preview"]) - def activate_cropping(self): - LLSZWidget.WidgetContainer.CropWidget.Preview_Crop_Menu.shapes_layer = self.parent_viewer.add_shapes(shape_type='polygon', edge_width=1, edge_color='white', - face_color=[1, 1, 1, 0], name="Cropping BBOX layer") - # TO select ROIs if needed - LLSZWidget.WidgetContainer.CropWidget.Preview_Crop_Menu.shapes_layer.mode = "SELECT" - self["activate_cropping"].text = "Cropping layer active" - self["activate_cropping"].background_color = "green" - return - - heading2 = field("You can either import ImageJ ROI (.zip) files or manually define ROIs using the shape layer", widget_type="Label") - - @click(enabled=False) - def Import_ImageJ_ROI(self, path: Path = Path(history.get_open_history()[0])): - logger.info(f"Opening{path}") - roi_list = read_imagej_roi(path) - # convert to canvas coordinates - roi_list = (np.array(roi_list) * - LLSZWidget.LlszMenu.lattice.dy).tolist() - LLSZWidget.WidgetContainer.CropWidget.Preview_Crop_Menu.shapes_layer.add(roi_list, shape_type='polygon', edge_width=1, edge_color='yellow', - face_color=[1, 1, 1, 0]) - return - - time_crop = field( - int, options={"min": 0, "step": 1, "max": 2**20}, name="Time") - chan_crop = field( - int, options={"min": 0, "step": 1}, name="Channels") - heading_roi = field("If there are multiple ROIs, select the ROI before clicking button below", widget_type="Label") - #roi_idx = field(int, options={"min": 0, "step": 1}, name="ROI number") - - @click(enabled=False) - # -> LayerDataTuple: - def Crop_Preview(self, roi_layer: ShapesData): - assert roi_layer, "No coordinates found for cropping. Check if right shapes layer or initialise shapes layer and draw ROIs." - # TODO: Add assertion to check if bbox layer or coordinates - time = self.time_crop.value - channel = self.chan_crop.value - - assert time < LLSZWidget.LlszMenu.lattice.time, "Time is out of range" - assert channel < LLSZWidget.LlszMenu.lattice.channels, "Channel is out of range" - - logger.info(f"Using channel {channel} and time {time}") - - vol = LLSZWidget.LlszMenu.lattice.data - vol_zyx = vol[time, channel, ...] - vol_zyx = np.array(vol_zyx) - - deskewed_shape = LLSZWidget.LlszMenu.lattice.deskew_vol_shape - # Create a dask array same shape as deskewed image - deskewed_volume = da.zeros(deskewed_shape) - - # Option for entering custom z start value? - z_start = 0 - z_end = deskewed_shape[0] - - # if only one roi drawn, use the first ROI for cropping - if len(roi_layer) == 1: - roi_idx = 0 - else: - assert len( - LLSZWidget.WidgetContainer.CropWidget.Preview_Crop_Menu.shapes_layer.selected_data) > 0, "Please select an ROI" - roi_idx = list( - LLSZWidget.WidgetContainer.CropWidget.Preview_Crop_Menu.shapes_layer.selected_data)[0] - - roi_choice = roi_layer[roi_idx] - # As the original image is scaled, the coordinates are in microns, so we need to convert - # roi from micron to canvas/world coordinates - roi_choice = [ - x/LLSZWidget.LlszMenu.lattice.dy for x in roi_choice] - logger.info(f"Previewing ROI {roi_idx}") - - # crop - if LLSZWidget.LlszMenu.deconvolution.value: - logger.info( - f"Deskewing for Time:{time} and Channel: {channel} with deconvolution") - #psf = LLSZWidget.LlszMenu.lattice.psf[channel] - if LLSZWidget.LlszMenu.lattice.decon_processing == DeconvolutionChoice.cuda_gpu: - crop_roi_vol_desk = crop_volume_deskew(original_volume=vol_zyx, - deskewed_volume=deskewed_volume, - roi_shape=roi_choice, - angle_in_degrees=LLSZWidget.LlszMenu.angle_value, - voxel_size_x=LLSZWidget.LlszMenu.lattice.dx, - voxel_size_y=LLSZWidget.LlszMenu.lattice.dy, - voxel_size_z=LLSZWidget.LlszMenu.lattice.dz, - z_start=z_start, - z_end=z_end, - deconvolution=LLSZWidget.LlszMenu.deconvolution.value, - decon_processing=LLSZWidget.LlszMenu.lattice.decon_processing, - psf=LLSZWidget.LlszMenu.lattice.psf[channel], - num_iter=LLSZWidget.LlszMenu.lattice.psf_num_iter).astype(vol_zyx.dtype) - else: - crop_roi_vol_desk = crop_volume_deskew(original_volume=vol_zyx, - deskewed_volume=deskewed_volume, - roi_shape=roi_choice, - angle_in_degrees=LLSZWidget.LlszMenu.angle_value, - voxel_size_x=LLSZWidget.LlszMenu.lattice.dx, - voxel_size_y=LLSZWidget.LlszMenu.lattice.dy, - voxel_size_z=LLSZWidget.LlszMenu.lattice.dz, - z_start=z_start, - z_end=z_end, - deconvolution=LLSZWidget.LlszMenu.deconvolution.value, - decon_processing=LLSZWidget.LlszMenu.lattice.decon_processing, - psf=LLSZWidget.LlszMenu.lattice.psf[channel], - num_iter=LLSZWidget.LlszMenu.lattice.psf_num_iter).astype(vol_zyx.dtype) - else: - crop_roi_vol_desk = crop_volume_deskew(original_volume=vol_zyx, - deskewed_volume=deskewed_volume, - roi_shape=roi_choice, - angle_in_degrees=LLSZWidget.LlszMenu.angle_value, - voxel_size_x=LLSZWidget.LlszMenu.lattice.dx, - voxel_size_y=LLSZWidget.LlszMenu.lattice.dy, - voxel_size_z=LLSZWidget.LlszMenu.lattice.dz, - z_start=z_start, - z_end=z_end, - skew_dir=LLSZWidget.LlszMenu.skew_dir).astype(vol_zyx.dtype) - crop_roi_vol_desk = cle.pull(crop_roi_vol_desk) - - # get array back from gpu or addding cle array to napari can throw errors - - scale = (LLSZWidget.LlszMenu.lattice.new_dz, - LLSZWidget.LlszMenu.lattice.dy, - LLSZWidget.LlszMenu.lattice.dx) - self.parent_viewer.add_image( - crop_roi_vol_desk, scale=scale) - - @magicclass(name="Crop and Save Data") - class CropSaveData(MagicTemplate): - @magicgui(header=dict(widget_type="Label", label="

Crop and Save Data

"), - time_start=dict(label="Time Start:"), - time_end=dict(label="Time End:", value=1), - ch_start=dict(label="Channel Start:"), - ch_end=dict(label="Channel End:", value=1), - save_as_type={ - "label": "Save as filetype:", "choices": SaveFileType}, - save_path=dict(mode='d', label="Directory to save ")) - def Crop_Save(self, - header, - time_start: int, - time_end: int, - ch_start: int, - ch_end: int, - save_as_type: str, - roi_layer_list: ShapesData, - save_path: Path = Path(history.get_save_history()[0])): - - if not roi_layer_list: - logger.error( - "No coordinates found or cropping. Initialise shapes layer and draw ROIs.") - else: - assert LLSZWidget.LlszMenu.open_file, "Image not initialised" - - check_dimensions(time_start, time_end, ch_start, ch_end, - LLSZWidget.LlszMenu.lattice.channels, LLSZWidget.LlszMenu.lattice.time) - - angle = LLSZWidget.LlszMenu.lattice.angle - dx = LLSZWidget.LlszMenu.lattice.dx - dy = LLSZWidget.LlszMenu.lattice.dy - dz = LLSZWidget.LlszMenu.lattice.dz - - # get image data - img_data = LLSZWidget.LlszMenu.lattice.data - # Get shape of deskewed image - deskewed_shape = LLSZWidget.LlszMenu.lattice.deskew_vol_shape - deskewed_volume = da.zeros(deskewed_shape) - z_start = 0 - z_end = deskewed_shape[0] - - logger.info("Cropping and saving files...") - - # necessary when scale is used for napari.viewer.add_image operations - roi_layer_list = [ - x/LLSZWidget.LlszMenu.lattice.dy for x in roi_layer_list] - - for idx, roi_layer in enumerate(tqdm(roi_layer_list, desc="ROI:", position=0)): - # pass arguments for save tiff, callable and function arguments - logger.info("Processing ROI ", idx) - # pass parameters for the crop_volume_deskew function - - save_img(vol=img_data, - func=crop_volume_deskew, - time_start=time_start, - time_end=time_end, - channel_start=ch_start, - channel_end=ch_end, - save_name_prefix="ROI_" + - str(idx), - save_path=save_path, - save_file_type=save_as_type, - save_name=LLSZWidget.LlszMenu.lattice.save_name, - dx=dx, - dy=dy, - dz=dz, - angle=angle, - deskewed_volume=deskewed_volume, - roi_shape=roi_layer, - angle_in_degrees=angle, - z_start=z_start, - z_end=z_end, - voxel_size_x=dx, - voxel_size_y=dy, - voxel_size_z=dz, - LLSZWidget=LLSZWidget - ) - - logger.info( - f"Cropping and Saving Complete -> {save_path}") - return - - @magicclass(name="Workflow", widget_type="scrollable") - class WorkflowWidget: - @magicclass(name="Preview Workflow", widget_type="scrollable") - class PreviewWorkflow: - #time_preview= field(int, options={"min": 0, "step": 1}, name="Time") - #chan_preview = field(int, options={"min": 0, "step": 1}, name="Channels") - @magicgui(header=dict(widget_type="Label", label="

Preview Workflow

"), - time_preview=dict(label="Time:", max=2**20), - chan_preview=dict(label="Channel:"), - get_active_workflow=dict( - widget_type="Checkbox", label="Get active workflow in napari-workflow", value=False), - workflow_path=dict( - mode='r', label="Load custom workflow (.yaml/yml)"), - Use_Cropping=dict( - widget_type="Checkbox", label="Crop Data", value=False), - #custom_module=dict(widget_type="Checkbox",label="Load custom module (looks for *.py files in the workflow directory)",value = False), - call_button="Apply and Preview Workflow") - def Workflow_Preview(self, - header, - time_preview: int, - chan_preview: int, - get_active_workflow: bool, - Use_Cropping: bool, - roi_layer_list: ShapesData, - workflow_path: Path = Path.home()): - """ - Apply napari_workflows to the processing pipeline - User can define a pipeline which can be inspected in napari workflow inspector - and then execute it by ticking the get active workflow checkbox, - OR - Use a predefined workflow - - In both cases, if deskewing is not present as first step, it will be added on - and rest of the task will be made followers - Args: - - """ - print("Previewing deskewed channel and time with workflow") - if get_active_workflow: - # installs the workflow to napari - user_workflow = WorkflowManager.install( - self.parent_viewer).workflow - parent_dir = workflow_path.resolve( - ).parents[0].__str__()+os.sep - logger.info("Workflow loaded from napari") - else: - - try: - # Automatically scan workflow file directory for *.py files. - # If it findss one, load it as a module - - parent_dir = workflow_path.resolve( - ).parents[0].__str__()+os.sep - sys.path.append(parent_dir) - custom_py_files = get_all_py_files(parent_dir) - if len(custom_py_files) == 0: - logger.error( - f"No custom modules imported. If you'd like to use a cusotm module, place a *.py file in same folder as the workflow file {parent_dir}") - else: - modules = load_custom_py_modules( - custom_py_files) - - logger.info( - f"Custom modules imported {modules}") - user_workflow = load_workflow( - workflow_path.__str__()) - except yaml.loader.ConstructorError as e: - logger.error( - "\033[91m While loading workflow, got the following error which may mean you need to install the corresponding module in your Python environment: \033[0m") - logger.error(e) - - #user_workflow = load_workflow(workflow_path) - logger.info("Workflow loaded from file") - - assert type( - user_workflow) is Workflow, "Workflow loading error. Check if file is workflow or if required libraries are installed" - - input_arg_first, input_arg_last, first_task_name, last_task_name = get_first_last_image_and_task( - user_workflow) - #print(input_arg_first, input_arg_last, first_task_name,last_task_name ) - # get list of tasks - task_list = list(user_workflow._tasks.keys()) - logger.info(f"Workflow loaded:{user_workflow}") - # logger.info() - - # when using fields, self.time_preview.value - assert time_preview < LLSZWidget.LlszMenu.lattice.time, "Time is out of range" - assert chan_preview < LLSZWidget.LlszMenu.lattice.channels, "Channel is out of range" - - time = time_preview - channel = chan_preview - - # to access current time and channel and pass it to workflow file - config.channel = channel - config.time = time - - logger.info( - f"Processing for Time: {time} and Channel: {channel}") - - vol = LLSZWidget.LlszMenu.lattice.data - vol_zyx = vol[time, channel, ...] - vol_zyx = np.array(vol_zyx) - - task_name_start = first_task_name[0] - try: - task_name_last = last_task_name[0] - except IndexError: - task_name_last = task_name_start - - # variables to hold task name, initialize it as None - # if gpu, set otf_path, otherwise use psf - psf = None - otf_path = None - - if LLSZWidget.LlszMenu.lattice.decon_processing == DeconvolutionChoice.cuda_gpu: - otf_path = "otf_path" - else: - psf = "psf" - - # if cropping, set that as first task - # get the function associated with the first task and check if its deskewing - if Use_Cropping: - # use deskewed volume for cropping function - deskewed_shape = LLSZWidget.LlszMenu.lattice.deskew_vol_shape - deskewed_volume = da.zeros(deskewed_shape) - z_start = 0 - z_end = deskewed_shape[0] - if user_workflow.get_task(task_name_start)[0] not in [crop_volume_deskew]: - # if only one roi drawn, use the first ROI for cropping - if len(roi_layer_list) == 1: - roi_idx = 0 - else: # else get the user selection - assert len( - LLSZWidget.WidgetContainer.CropWidget.Preview_Crop_Menu.shapes_layer.selected_data) > 0, "Please select an ROI" - roi_idx = list( - LLSZWidget.WidgetContainer.CropWidget.Preview_Crop_Menu.shapes_layer.selected_data)[0] - - roi_choice = roi_layer_list[roi_idx] - # As the original image is scaled, the coordinates are in microns, so we need to convert - # roi to from micron to canvas/world coordinates - roi_choice = [ - x/LLSZWidget.LlszMenu.lattice.dy for x in roi_choice] - logger.info(f"Previewing ROI {roi_idx}") - if LLSZWidget.LlszMenu.deconvolution.value: - user_workflow.set("crop_deskew_image", crop_volume_deskew, - original_volume=vol_zyx, - deskewed_volume=deskewed_volume, - roi_shape=roi_choice, - angle_in_degrees=LLSZWidget.LlszMenu.lattice.angle, - voxel_size_x=LLSZWidget.LlszMenu.lattice.dx, - voxel_size_y=LLSZWidget.LlszMenu.lattice.dy, - voxel_size_z=LLSZWidget.LlszMenu.lattice.dz, - z_start=z_start, - z_end=z_end, - deconvolution=LLSZWidget.LlszMenu.deconvolution.value, - decon_processing=LLSZWidget.LlszMenu.lattice.decon_processing, - otf_path=otf_path, - psf=LLSZWidget.LlszMenu.lattice.psf[channel], - skew_dir=LLSZWidget.LlszMenu.skew_dir) - else: - user_workflow.set("crop_deskew_image", crop_volume_deskew, - original_volume=vol_zyx, - deskewed_volume=deskewed_volume, - roi_shape=roi_choice, - angle_in_degrees=LLSZWidget.LlszMenu.lattice.angle, - voxel_size_x=LLSZWidget.LlszMenu.lattice.dx, - voxel_size_y=LLSZWidget.LlszMenu.lattice.dy, - voxel_size_z=LLSZWidget.LlszMenu.lattice.dz, - z_start=z_start, - z_end=z_end, - skew_dir=LLSZWidget.LlszMenu.skew_dir) - - # Set input of the workflow to be crop_deskewing, i.e., the original first operation will now have crop_deskew_image as an input (becoming second instead) - user_workflow.set( - input_arg_first, "crop_deskew_image") - else: - user_workflow.set(input_arg_first, vol_zyx) - # Not cropping; If deskew not in workflow, append to start - elif user_workflow.get_task(task_name_start)[0] not in (cle.deskew_y, cle.deskew_x): - # if deconvolution checked, add it to start of workflow (add upstream of deskewing) - if LLSZWidget.LlszMenu.deconvolution.value: - psf = LLSZWidget.LlszMenu.lattice.psf[channel] - input_arg_first_decon, input_arg_last_decon, first_task_name_decon, last_task_name_decon = get_first_last_image_and_task( - user_workflow) - - if LLSZWidget.LlszMenu.lattice.decon_processing == DeconvolutionChoice.cuda_gpu: - user_workflow.set("deconvolution", - pycuda_decon, - image=vol_zyx, - psf=LLSZWidget.LlszMenu.lattice.psf[channel], - dzdata=LLSZWidget.LlszMenu.lattice.dz, - dxdata=LLSZWidget.LlszMenu.lattice.dx, - dzpsf=LLSZWidget.LlszMenu.lattice.dz, - dxpsf=LLSZWidget.LlszMenu.lattice.dx, - num_iter=LLSZWidget.LlszMenu.lattice.psf_num_iter) - # user_workflow.set(input_arg_first_decon,"deconvolution") - else: - user_workflow.set("deconvolution", - skimage_decon, - vol_zyx=vol_zyx, - psf=LLSZWidget.LlszMenu.lattice.psf[channel], - num_iter=LLSZWidget.LlszMenu.lattice.psf_num_iter, - clip=False, - filter_epsilon=0, - boundary='nearest') - # user_workflow.set(input_arg_first_decon,"deconvolution") - - user_workflow.set("deskew_image", - LLSZWidget.LlszMenu.deskew_func, - "deconvolution", - angle_in_degrees=LLSZWidget.LlszMenu.lattice.angle, - voxel_size_x=LLSZWidget.LlszMenu.lattice.dx, - voxel_size_y=LLSZWidget.LlszMenu.lattice.dy, - voxel_size_z=LLSZWidget.LlszMenu.lattice.dz, - linear_interpolation=True) - - # user_workflow.set("change_bitdepth",as_type,"deskew_image",vol_zyx) - # Set input of the workflow to be from deskewing output with same bit depth as original volume - # user_workflow.set(input_arg_first,"change_bitdepth") - - else: - user_workflow.set("deskew_image", - LLSZWidget.LlszMenu.deskew_func, - vol_zyx, - angle_in_degrees=LLSZWidget.LlszMenu.lattice.angle, - voxel_size_x=LLSZWidget.LlszMenu.lattice.dx, - voxel_size_y=LLSZWidget.LlszMenu.lattice.dy, - voxel_size_z=LLSZWidget.LlszMenu.lattice.dz, - linear_interpolation=True) - # Set input of the workflow to be from deskewing - # user_workflow.set(input_arg_first,"deskew_image") - - user_workflow.set( - "change_bitdepth", as_type, "deskew_image", vol_zyx) - # Set input of the workflow to be from deskewing with same bit depth as original volume - user_workflow.set( - input_arg_first, "change_bitdepth") - - else: - # if deskew already in workflow, just check if deconvolution needs to be added - # repitition of above (maybe create a function?) - # if deconvolution checked, add it to start of workflow (add upstream of deskewing) - if LLSZWidget.LlszMenu.deconvolution.value: - psf = LLSZWidget.LlszMenu.lattice.psf[channel] - input_arg_first, input_arg_last, first_task_name, last_task_name = get_first_last_image_and_task( - user_workflow) - - if LLSZWidget.LlszMenu.lattice.decon_processing == DeconvolutionChoice.cuda_gpu: - user_workflow.set("deconvolution", - pycuda_decon, - image=vol_zyx, - psf=LLSZWidget.LlszMenu.lattice.psf[channel], - dzdata=LLSZWidget.LlszMenu.lattice.dz, - dxdata=LLSZWidget.LlszMenu.lattice.dx, - dzpsf=LLSZWidget.LlszMenu.lattice.dz, - dxpsf=LLSZWidget.LlszMenu.lattice.dx, - num_iter=LLSZWidget.LlszMenu.lattice.psf_num_iter) - # user_workflow.set(input_arg_first,"deconvolution") - else: - user_workflow.set("deconvolution", - skimage_decon, - vol_zyx=vol_zyx, - psf=LLSZWidget.LlszMenu.lattice.psf[channel], - num_iter=LLSZWidget.LlszMenu.lattice.psf_num_iter, - clip=False, - filter_epsilon=0, - boundary='nearest') - - # set input to subsequent task as deconvolution output - user_workflow.set( - input_arg_first, "deconvolution") - - logger.info("Workflow to be executed:") - logger.info(user_workflow) - # Execute workflow - processed_vol = user_workflow.get(task_name_last) - - # check if a measurement table (usually a dictionary or list) or a tuple with different data types - # The function below saves the tables and adds any images to napari window - if type(processed_vol) in [dict, list, tuple]: - if (len(processed_vol) > 1): - df = pd.DataFrame() - for idx, i in enumerate(processed_vol): - df_temp = process_custom_workflow_output( - i, parent_dir, idx, LLSZWidget, self, channel, time, preview=True) - final_df = pd.concat([df, df_temp]) - # append dataframes from every loop and have table command outside loop? - # TODO: Figure out why table is not displaying - from napari_spreadsheet import _widget - table_viewer = _widget.TableViewerWidget( - show=True) - table_viewer.add_spreadsheet(final_df) - # widgets.Table(value=final_df).show() - - else: - # add image to napari window - # TODO: check if its an image napari supports? - process_custom_workflow_output( - processed_vol, parent_dir, 0, LLSZWidget, self, channel, time) - - print("Workflow complete") - pass - - @magicgui(header=dict(widget_type="Label", label="

Apply Workflow and Save Output

"), - time_start=dict(label="Time Start:", max=2**20), - time_end=dict(label="Time End:", - value=1, max=2**20), - ch_start=dict(label="Channel Start:"), - ch_end=dict(label="Channel End:", value=1), - Use_Cropping=dict( - widget_type="Checkbox", label="Crop Data", value=False), - get_active_workflow=dict( - widget_type="Checkbox", label="Get active workflow in napari-workflow", value=False), - workflow_path=dict( - mode='r', label="Load custom workflow (.yaml/yml)"), - save_as_type={ - "label": "Save as filetype:", "choices": SaveFileType}, - save_path=dict( - mode='d', label="Directory to save "), - #custom_module=dict(widget_type="Checkbox",label="Load custom module (same dir as workflow)",value = False), - call_button="Apply Workflow and Save Result") - def Apply_Workflow_and_Save(self, - header, - time_start: int, - time_end: int, - ch_start: int, - ch_end: int, - Use_Cropping, - roi_layer_list: ShapesData, - get_active_workflow: bool = False, - workflow_path: Path = Path.home(), - save_as_type: str = SaveFileType.tiff, - save_path: Path = Path(history.get_save_history()[0])): - """ - Apply a user-defined analysis workflow using napari-workflows - - Args: - time_start (int): Start Time - time_end (int): End Time - ch_start (int): Start Channel - ch_end (int): End Channel - Use_Cropping (_type_): Use cropping based on ROIs in the shapes layer - roi_layer_list (ShapesData): Shapes layer to use for cropping; can be a list of shapes - get_active_workflow (bool, optional): Gets active workflow in napari. Defaults to False. - workflow_path (Path, optional): User can also choose a custom workflow defined in a yaml file. - save_path (Path, optional): Path to save resulting data - """ - assert LLSZWidget.LlszMenu.open_file, "Image not initialised" - - check_dimensions(time_start, time_end, ch_start, ch_end, - LLSZWidget.LlszMenu.lattice.channels, LLSZWidget.LlszMenu.lattice.time) - - # Get parameters - angle = LLSZWidget.LlszMenu.lattice.angle - dx = LLSZWidget.LlszMenu.lattice.dx - dy = LLSZWidget.LlszMenu.lattice.dy - dz = LLSZWidget.LlszMenu.lattice.dz - - if get_active_workflow: - # installs the workflow to napari - user_workflow = WorkflowManager.install( - self.parent_viewer).workflow - print("Workflow installed") - else: - # Automatically scan workflow file directory for *.py files. - # If it findss one, load it as a module - import importlib - parent_dir = workflow_path.resolve( - ).parents[0].__str__()+os.sep - sys.path.append(parent_dir) - custom_py_files = get_all_py_files(parent_dir) - if len(custom_py_files) == 0: - print( - f"No custom modules imported. If you'd like to use a cusotm module, place a *.py file in same folder as the workflow file {parent_dir}") - else: - modules = map( - importlib.import_module, custom_py_files) - print(f"Custom modules imported {modules}") - user_workflow = load_workflow(workflow_path) - - assert type( - user_workflow) is Workflow, "Workflow file is not a napari workflow object. Check file! You can use workflow inspector if needed" - - input_arg_first, input_arg_last, first_task_name, last_task_name = get_first_last_image_and_task( - user_workflow) - print(input_arg_first, input_arg_last, - first_task_name, last_task_name) - # get list of tasks - task_list = list(user_workflow._tasks.keys()) - print("Workflow loaded:") - print(user_workflow) - - vol = LLSZWidget.LlszMenu.lattice.data - - #vol_zyx= vol[time,channel,...] - - task_name_start = first_task_name[0] - - try: - task_name_last = last_task_name[0] - except IndexError: - task_name_last = task_name_start - - # variables to hold task name, initialize it as None - # if gpu, set otf_path, otherwise use psf - psf = None - otf_path = None - - if LLSZWidget.LlszMenu.lattice.decon_processing == DeconvolutionChoice.cuda_gpu: - #otf_path = "otf_path" - psf_arg = "psf" - psf = LLSZWidget.LlszMenu.lattice.psf - else: - psf_arg = "psf" - psf = LLSZWidget.LlszMenu.lattice.psf - # if cropping, set that as first task - - if Use_Cropping: - # convert Roi pixel coordinates to canvas coordinates - # necessary only when scale is used for napari.viewer.add_image operations - roi_layer_list = [ - x/LLSZWidget.LlszMenu.lattice.dy for x in roi_layer_list] - - deskewed_shape = LLSZWidget.LlszMenu.lattice.deskew_vol_shape - deskewed_volume = da.zeros(deskewed_shape) - z_start = 0 - z_end = deskewed_shape[0] - roi = "roi" - volume = "volume" - # Check if decon ticked, if so set as first and crop as second? - - # Create workflow for cropping and deskewing - # volume and roi used will be set dynamically - user_workflow.set("crop_deskew_image", crop_volume_deskew, - original_volume=volume, - deskewed_volume=deskewed_volume, - roi_shape=roi, - angle_in_degrees=angle, - voxel_size_x=dx, - voxel_size_y=dy, - voxel_size_z=dz, - z_start=z_start, - z_end=z_end, - deconvolution=LLSZWidget.LlszMenu.deconvolution.value, - decon_processing=LLSZWidget.LlszMenu.lattice.decon_processing, - psf=psf_arg, - skew_dir=LLSZWidget.LlszMenu.skew_dir) - - # change the first task so it accepts "crop_deskew as input" - new_task = modify_workflow_task( - old_arg=input_arg_first, task_key=task_name_start, new_arg="crop_deskew_image", workflow=user_workflow) - user_workflow.set(task_name_start, new_task) - - for idx, roi_layer in enumerate(tqdm(roi_layer_list, desc="ROI:", position=0)): - print("Processing ROI ", idx) - user_workflow.set(roi, roi_layer) - save_img_workflow(vol=vol, - workflow=user_workflow, - input_arg=volume, - first_task="crop_deskew_image", - last_task=task_name_last, - time_start=time_start, - time_end=time_end, - channel_start=ch_start, - channel_end=ch_end, - save_file_type=save_as_type, - save_path=save_path, - #roi_layer = roi_layer, - save_name_prefix="ROI_" + \ - str(idx), - save_name=LLSZWidget.LlszMenu.lattice.save_name, - dx=dx, - dy=dy, - dz=dz, - angle=angle, - deconvolution=LLSZWidget.LlszMenu.deconvolution.value, - decon_processing=LLSZWidget.LlszMenu.lattice.decon_processing, - otf_path=otf_path, - psf_arg=psf_arg, - psf=psf) - - # IF just deskewing and its not in the tasks, add that as first task - elif user_workflow.get_task(task_name_start)[0] not in (cle.deskew_y, cle.deskew_x): - input = "input" - # add task to the workflow - user_workflow.set("deskew_image", - LLSZWidget.LlszMenu.deskew_func, - input_image=input, - angle_in_degrees=angle, - voxel_size_x=dx, - voxel_size_y=dy, - voxel_size_z=dz, - linear_interpolation=True) - # Set input of the workflow to be from deskewing - # change workflow task starts from is "deskew_image" and - new_task = modify_workflow_task( - old_arg=input_arg_first, task_key=task_name_start, new_arg="deskew_image", workflow=user_workflow) - user_workflow.set(task_name_start, new_task) - - # if deconvolution checked, add it to start of workflow (add upstream of deskewing) - if LLSZWidget.LlszMenu.deconvolution.value: - psf = "psf" - otf_path = "otf_path" - input_arg_first, input_arg_last, first_task_name, last_task_name = get_first_last_image_and_task( - user_workflow) - - if LLSZWidget.LlszMenu.lattice.decon_processing == DeconvolutionChoice.cuda_gpu: - user_workflow.set("deconvolution", - pycuda_decon, - image=input, - psf=psf_arg, - dzdata=LLSZWidget.LlszMenu.lattice.dz, - dxdata=LLSZWidget.LlszMenu.lattice.dx, - dzpsf=LLSZWidget.LlszMenu.lattice.dz, - dxpsf=LLSZWidget.LlszMenu.lattice.dx, - num_iter=LLSZWidget.LlszMenu.lattice.psf_num_iter) - # user_workflow.set(input_arg_first,"deconvolution") - else: - user_workflow.set("deconvolution", - skimage_decon, - vol_zyx=input, - psf=psf_arg, - num_iter=LLSZWidget.LlszMenu.lattice.psf_num_iter, - clip=False, - filter_epsilon=0, - boundary='nearest') - # modify the user workflow so "deconvolution" is accepted - new_task = modify_workflow_task( - old_arg=input_arg_first, task_key=task_name_start, new_arg="deconvolution", workflow=user_workflow) - user_workflow.set(task_name_start, new_task) - input_arg_first, input_arg_last, first_task_name, last_task_name = get_first_last_image_and_task( - user_workflow) - task_name_start = first_task_name[0] - - save_img_workflow(vol=vol, - workflow=user_workflow, - input_arg=input, - first_task=task_name_start, - last_task=task_name_last, - time_start=time_start, - time_end=time_end, - channel_start=ch_start, - channel_end=ch_end, - save_file_type=save_as_type, - save_path=save_path, - save_name=LLSZWidget.LlszMenu.lattice.save_name, - dx=dx, - dy=dy, - dz=dz, - angle=angle, - deconvolution=LLSZWidget.LlszMenu.deconvolution.value, - decon_processing=LLSZWidget.LlszMenu.lattice.decon_processing, - otf_path=otf_path, - psf_arg=psf_arg, - psf=psf) - - # If deskewing is already as a task, then set the first argument to input so we can modify that later - else: - # if deskewing is already first task, then check if deconvolution needed - # if deconvolution checked, add it to start of workflow (add upstream of deskewing) - if LLSZWidget.LlszMenu.deconvolution.value: - psf = "psf" - otf_path = "otf_path" - input_arg_first, input_arg_last, first_task_name, last_task_name = get_first_last_image_and_task( - user_workflow) - - if LLSZWidget.LlszMenu.lattice.decon_processing == DeconvolutionChoice.cuda_gpu: - user_workflow.set("deconvolution", - pycuda_decon, - image=input, - psf=psf_arg, - dzdata=LLSZWidget.LlszMenu.lattice.dz, - dxdata=LLSZWidget.LlszMenu.lattice.dx, - dzpsf=LLSZWidget.LlszMenu.lattice.dz, - dxpsf=LLSZWidget.LlszMenu.lattice.dx, - num_iter=LLSZWidget.LlszMenu.lattice.psf_num_iter) - # user_workflow.set(input_arg_first,"deconvolution") - else: - user_workflow.set("deconvolution", - skimage_decon, - vol_zyx=input, - psf=psf_arg, - num_iter=LLSZWidget.LlszMenu.lattice.psf_num_iter, - clip=False, - filter_epsilon=0, - boundary='nearest') - # modify the user workflow so "deconvolution" is accepted - new_task = modify_workflow_task( - old_arg=input_arg_first, task_key=task_name_start, new_arg="deconvolution", workflow=user_workflow) - user_workflow.set(task_name_start, new_task) - input_arg_first, input_arg_last, first_task_name, last_task_name = get_first_last_image_and_task( - user_workflow) - task_name_start = first_task_name[0] - - # we pass first argument as input - save_img_workflow(vol=vol, - workflow=user_workflow, - input_arg=input_arg_first, - first_task=task_name_start, - last_task=task_name_last, - time_start=time_start, - time_end=time_end, - channel_start=ch_start, - channel_end=ch_end, - save_file_type=save_as_type, - save_path=save_path, - save_name=LLSZWidget.LlszMenu.lattice.save_name, - dx=dx, - dy=dy, - dz=dz, - angle=angle, - deconvolution=LLSZWidget.LlszMenu.deconvolution.value, - decon_processing=LLSZWidget.LlszMenu.lattice.decon_processing, - otf_path=otf_path, - psf_arg=psf_arg, - psf=psf) - - print("Workflow complete") - return - - pass - -def _napari_lattice_widget_wrapper() -> LLSZWidget: - # split widget type enables a resizable widget - #max_height = 50 - # Important to have this or napari won't recognize the classes and magicclass qidgets - widget = LLSZWidget() - # aligning collapsible widgets at the top instead of having them centered vertically - widget._widget._layout.setAlignment(Qt.AlignTop) - - # widget._widget._layout.setWidgetResizable(True) - return widget diff --git a/plugin/napari_lattice/_reader.py b/plugin/napari_lattice/_reader.py deleted file mode 100644 index f872dbbc..00000000 --- a/plugin/napari_lattice/_reader.py +++ /dev/null @@ -1,160 +0,0 @@ -""" -reader plugin for h5 saved using np2bdv -https://github.com/nvladimus/npy2bdv -#TODO: pass pyramidal layer to napari -##use ilevel parameter in read_view to access different subsamples/pyramids -#pass a list of images with different resolution for pyramid; use is_pyramid=True flag in napari.add_image -, however pyramidal support for 3D not available yet -""" -from __future__ import annotations - -import dask.array as da -import dask.delayed as delayed -import os -import numpy as np -from napari.layers import image, Layer -from napari.layers._data_protocols import LayerDataProtocol - -from typing_extensions import Literal -from typing import Any, Optional, cast, TYPE_CHECKING, Tuple, List - -from lls_core.lattice_data import lattice_from_aics, LatticeData, img_from_array -from aicsimageio.types import ArrayLike, ImageLike - -if TYPE_CHECKING: - from aicsimageio.aics_image import AICSImage - -def lattice_from_napari( - img: Layer, - last_dimension: Optional[Literal["channel", "time"]], - **kwargs: Any -) -> LatticeData: - """ - Factory function for generating a LatticeData from a Napari Image - - Arguments: - kwargs: Extra arguments to pass to the LatticeData constructor - """ - - img_data_aics: AICSImage - - if 'aicsimage' in img.metadata.keys(): - img_data_aics = img.metadata['aicsimage'] - else: - if not last_dimension: - raise ValueError("Either the Napari image must have dimensional metadata, or last_dimension must be provided") - img_data_aics = img_from_array(cast(ArrayLike, img.data), last_dimension=last_dimension, physical_pixel_sizes=kwargs.get("physical_pixel_sizes")) - - save_name: str - if img.source.path is None: - # remove colon (:) and any leading spaces - save_name = img.name.replace(":", "").strip() - # replace any group of spaces with "_" - save_name = '_'.join(save_name.split()) - else: - file_name_noext = os.path.basename(img.source.path) - file_name = os.path.splitext(file_name_noext)[0] - # remove colon (:) and any leading spaces - save_name = file_name.replace(":", "").strip() - # replace any group of spaces with "_" - save_name = '_'.join(save_name.split()) - - return lattice_from_aics(img_data_aics, save_name=save_name, **kwargs) - -def napari_get_reader(path: list[str] | str): - """Check if file ends with h5 and returns reader function if true - Parameters - ---------- - path : str or list of str - Path to file, or list of paths. - Returns - ------- - function - """ - if isinstance(path, list): - # reader plugins may be handed single path, or a list of paths. - # if it is a list, we are only going to open first file - path = path[0] - - tiff_formats = (".tif",".tiff") - - if path.endswith(".h5"): - return bdv_h5_reader - elif path.endswith(tiff_formats): - return tiff_reader - # if we know we cannot read the file, we immediately return None. - else: - return None - - -def bdv_h5_reader(path): - """Take a path and returns a list of LayerData tuples.""" - - os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE" - - #print(path) - import npy2bdv - h5_file = npy2bdv.npy2bdv.BdvEditor(path) - - img = [] - - #get dimensions of first image - first_timepoint = h5_file.read_view(time=0,channel=0) - - #Threshold to figure out when to use out-of-memory loading/dask - #Got the idea from napari-aicsimageio - #https://github.com/AllenCellModeling/napari-aicsimageio/blob/22934757c2deda30c13f39ec425343182fa91a89/napari_aicsimageio/core.py#L222 - mem_threshold_bytes = 4e9 - mem_per_threshold = 0.3 - - from psutil import virtual_memory - - file_size = os.path.getsize(path) - avail_mem = virtual_memory().available - - #if file size <30% of available memory and <4GB, open - if file_size<=mem_per_threshold*avail_mem and file_size List[Tuple[AICSImage, dict, str]]: - """Take path to tiff image and returns a list of LayerData tuples. - Specifying tiff_reader to have better control over tifffile related errors when using AICSImage - """ - - try: - image = AICSImage(path) - except Exception as e: - raise Exception("Error reading TIFF. Try upgrading tifffile library: pip install tifffile --upgrade.") from e - - # optional kwargs for the corresponding viewer.add_* method - add_kwargs = {} - - layer_type = "image" # optional, default is "image" - return [(image, add_kwargs, layer_type)] diff --git a/plugin/napari_lattice/circle-regular.svg b/plugin/napari_lattice/circle-regular.svg new file mode 100644 index 00000000..13e8ddaf --- /dev/null +++ b/plugin/napari_lattice/circle-regular.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/plugin/napari_lattice/disableable.py b/plugin/napari_lattice/disableable.py new file mode 100644 index 00000000..fbd3e9f1 --- /dev/null +++ b/plugin/napari_lattice/disableable.py @@ -0,0 +1,34 @@ +from tkinter import BaseWidget +from qtpy import QtWidgets as QtW, QtCore +from qtpy.QtCore import Qt, QEvent, QSize +from magicgui.application import use_app +from magicgui.widgets import Widget +from magicgui.widgets._concrete import _LabeledWidget +from magicgui.backends._qtpy.widgets import ( + QBaseWidget, + Container as ContainerBase, + MainWindow as MainWindowBase, +) + +class _Disableable(ContainerBase): + def __init__(self, layout="vertical", scrollable: bool = False, **kwargs): + BaseWidget.__init__(self, QtW.QWidget) + + if layout == "horizontal": + self._layout: QtW.QLayout = QtW.QHBoxLayout() + else: + self._layout = QtW.QVBoxLayout() + + self._stacked_widget = QtW.QStackedWidget(self._qwidget) + self._stacked_widget.setContentsMargins(0, 0, 0, 0) + self._inner_qwidget = QtW.QWidget(self._qwidget) + self._qwidget.setLayout(self._layout) + self._layout.addWidget(self._stacked_widget) + self._layout.addWidget(self._inner_qwidget) + + def _mgui_insert_widget(self, position: int, widget: Widget): + self._stacked_widget.insertWidget(position, widget.native) + + def _mgui_remove_widget(self, widget: Widget): + self._stacked_widget.removeWidget(widget.native) + widget.native.setParent(None) diff --git a/plugin/napari_lattice/dock_widget.py b/plugin/napari_lattice/dock_widget.py new file mode 100644 index 00000000..c57ef838 --- /dev/null +++ b/plugin/napari_lattice/dock_widget.py @@ -0,0 +1,178 @@ +from __future__ import annotations + +import logging +from textwrap import dedent +from typing import TYPE_CHECKING +import numpy as np +from lls_core.models.lattice_data import LatticeData +from magicclass import MagicTemplate, field, magicclass, set_options, vfield +from magicclass.wrappers import set_design +from napari_lattice.fields import ( + CroppingFields, + DeconvolutionFields, + DeskewFields, + OutputFields, + WorkflowFields, +) +from qtpy.QtCore import Qt +from qtpy.QtWidgets import QTabWidget +from napari_lattice.parent_connect import ParentConnect + +if TYPE_CHECKING: + from typing import Iterable + from napari_lattice.fields import NapariFieldGroup + from lls_core.types import ArrayLike + +# Enable Logging +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +@magicclass(widget_type="split") +class LLSZWidget(MagicTemplate): + def __post_init__(self): + # aligning collapsible widgets at the top instead of having them centered vertically + self._widget._layout.setAlignment(Qt.AlignTop) + + + def _check_validity(self) -> bool: + """ + Returns True if the model is valid + """ + try: + self._make_model() + return True + except: + return False + + def _make_model(self, validate: bool = True) -> LatticeData: + from rich import print + from sys import stdout + + deskew_args = self.LlszMenu.WidgetContainer.deskew_fields._get_kwargs() + output_args = self.LlszMenu.WidgetContainer.output_fields._make_model(validate=False) + params = LatticeData.make( + validate=validate, + + # Deskew + input_image=deskew_args["data"], + angle=deskew_args["angle"], + physical_pixel_sizes=deskew_args["physical_pixel_sizes"], + skew=deskew_args["skew"], + + # Output + channel_range=output_args.channel_range, + time_range=output_args.time_range, + save_dir=output_args.save_dir, + save_name=output_args.save_name or deskew_args["save_name"], + save_type=output_args.save_type, + save_suffix=output_args.save_suffix, + + workflow=self.LlszMenu.WidgetContainer.workflow_fields._make_model(), + deconvolution=self.LlszMenu.WidgetContainer.deconv_fields._make_model(), + crop=self.LlszMenu.WidgetContainer.cropping_fields._make_model() + ) + # Log the lattice + print(params, file=stdout) + return params + + @magicclass(widget_type="split") + class LlszMenu(MagicTemplate): + main_heading = field("

Napari Lattice: Visualization & Analysis

", widget_type="Label") + heading1 = field(dedent(""" +
+ Specify deskewing parameters and image layers in Tab 1.  + Additional analysis parameters can be configured in the other tabs.  + When you are ready to save, go to Tab 5.  + Output to specify the output directory.  + For more information, please refer to the documentation here. +
+ """.strip()), widget_type="Label") + + def __post_init__(self): + from qtpy.QtCore import Qt + from qtpy.QtWidgets import QLabel, QLayout + + if isinstance(self._widget._layout, QLayout): + self._widget._layout.setAlignment(Qt.AlignmentFlag.AlignTop) + + if isinstance(self.heading1.native, QLabel): + self.heading1.native.setWordWrap(True) + + # Tabbed Widget container to house all the widgets + @magicclass(widget_type="tabbed", name="Functions", labels=False) + class WidgetContainer(MagicTemplate): + + def __post_init__(self): + tab_widget: QTabWidget= self._widget._tab_widget + # Manually set the tab labels, because by default magicgui uses the widget names, but setting + # the names to human readable text makes them difficult to access via self + for i, label in enumerate(["1. Deskew", "2. Deconvolution", "3. Crop", "4. Workflow", "5. Output"]): + tab_widget.setTabText(i, label) + for field in [self.deskew_fields, self.deconv_fields, self.cropping_fields, self.workflow_fields, self.output_fields]: + # Connect event handlers + for subfield_name in dir(field): + subfield = getattr(field, subfield_name) + if isinstance(subfield, ParentConnect): + subfield.resolve(self, field, subfield_name) + # Trigger validation of default data + field._validate() + + # Using vfields here seems to prevent https://github.com/hanjinliu/magic-class/issues/110 + deskew_fields = vfield(DeskewFields) + deconv_fields = vfield(DeconvolutionFields) + cropping_fields = vfield(CroppingFields) + workflow_fields = vfield(WorkflowFields) + output_fields = vfield(OutputFields) + + @set_options(header=dict(widget_type="Label", label="

Preview Deskew

"), + time=dict(label="Time:", max=2**15), + channel=dict(label="Channel:"), + call_button="Preview" + ) + @set_design(text="Preview") + def preview(self, header: str, time: int, channel: int): + from pathlib import Path + + # We only need to process one time point for the preview, + # so we made a copy using a subset of the times + lattice = self._make_model(validate=False).copy_validate(update=dict( + time_range = range(time, time+1), + channel_range = range(channel, channel+1), + # Patch in a placeholder for the save dir because previewing doesn't use it + # TODO: use a more elegant solution such as making the "saveable" lattice + # a child class which more validations + save_dir = Path.home() + )) + + scale = ( + lattice.new_dz, + lattice.dy, + lattice.dx + ) + preview: ArrayLike + + # We extract the first available image to use as a preview + # This works differently for workflows and non-workflows + if lattice.workflow is None: + for slice in lattice.process().slices: + preview = slice.data + break + else: + preview = lattice.process_workflow().extract_preview() + + self.parent_viewer.add_image(preview, scale=scale, name="Napari Lattice Preview") + max_z = np.argmax(np.sum(preview, axis=(1, 2))) + self.parent_viewer.dims.set_current_step(0, max_z) + + + @set_design(text="Save") + def save(self): + from napari.utils.notifications import show_info + lattice = self._make_model() + lattice.save() + show_info(f"Deskewing successfuly completed. Results are located in {lattice.save_dir}") + + def _get_fields(self) -> Iterable[NapariFieldGroup]: + """Yields all the child Field classes which inherit from NapariFieldGroup""" + container = self.LlszMenu.WidgetContainer + yield from set(container.__magicclass_children__) diff --git a/plugin/napari_lattice/fields.py b/plugin/napari_lattice/fields.py new file mode 100644 index 00000000..c5c5e899 --- /dev/null +++ b/plugin/napari_lattice/fields.py @@ -0,0 +1,578 @@ +# FieldGroups that the users interact with to input data +import logging +from pathlib import Path +from textwrap import dedent +from typing import Any, Callable, List, Optional, Tuple, TYPE_CHECKING +from typing_extensions import TypeVar +import pyclesperanto_prototype as cle +from xarray import DataArray +from lls_core import ( + DeconvolutionChoice, + DeskewDirection, + Log_Levels, +) +from lls_core.models import ( + CropParams, + DeconvolutionParams, + DeskewParams, + LatticeData, + OutputParams, +) +from lls_core.models.deskew import DefinedPixelSizes +from lls_core.models.output import SaveFileType +from lls_core.workflow import workflow_from_path +from magicclass import FieldGroup, MagicTemplate, field, magicclass, set_design +from magicclass.fields import MagicField +from magicclass.widgets import ComboBox, Label, Widget +from napari.layers import Image, Shapes +from napari.types import ShapesData +from napari_lattice.icons import GREEN, GREY, RED +from napari_lattice.reader import NapariImageParams, lattice_params_from_napari +from napari_lattice.utils import get_layers +from napari_workflows import Workflow, WorkflowManager +from qtpy.QtWidgets import QTabWidget +from strenum import StrEnum +from napari_lattice.parent_connect import connect_parent + +if TYPE_CHECKING: + from magicgui.widgets.bases import RangedWidget + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +def adjust_maximum(widget: "RangedWidget", max: int): + """ + Updates the maximum value + """ + widget.max = max + if widget.value > max: + widget.value = max + +def exception_to_html(e: BaseException) -> str: + """ + Converts an exception to HTML for reporting back to the user + """ + from pydantic import ValidationError + if isinstance(e, ValidationError): + message = [] + for error in e.errors(): + header = ", ".join([str(it).capitalize() for it in error['loc']]) + message.append(f"
  • {header} {error['msg']}
  • ") + joined = '\n'.join(message) + return f"" + else: + return f"{type(e).__name__}: {e}" + +def get_friendly_validations(model: FieldGroup) -> str: + """ + Generates a BaseModel, but returns validation errors in a user friendly way + """ + try: + model._make_model() + return "" + except BaseException as e: + return exception_to_html(e) + +class PixelSizeSource(StrEnum): + Metadata = "Image Metadata" + Manual = "Manual" + +class WorkflowSource(StrEnum): + ActiveWorkflow = "Active Workflow" + CustomPath = "Custom Path" + +class BackgroundSource(StrEnum): + Auto = "Automatic" + SecondLast = "Second Last" + Custom = "Custom" + +def enable_field(field: MagicField, enabled: bool = True) -> None: + """ + Enable the widget associated with a field + + Args: + field: The field to enable/disable + enabled: If False, disable the field instead of enabling it + """ + for real_field in field._guis.values(): + if not isinstance(real_field, Widget): + raise Exception("Define your fields with field() not vfield()!") + try: + real_field.visible = enabled + real_field.enabled = enabled + except RuntimeError: + pass + +FieldValueType = TypeVar("FieldValueType") +SelfType = TypeVar("SelfType") +def enable_if(fields: List[MagicField]): + """ + Makes an event handler that dynamically disables and enables a set of fields based on a criteria + Args: + condition: A function that takes an instance of the class and returns True if the fields should be enabled + fields: A list of fields to be dynamically enabled or disabled + Example: + :: + @some_field.connect + @enable_if( + [some_field] + ) + def _enable_fields(self, value) -> bool: + return value + """ + # Ideally we could use subclassing to add both the vfield and this event handler, but there + # seems to be a bug preventing this: https://github.com/hanjinliu/magic-class/issues/113. + + # Start by disabling all the fields + + def _decorator(fn: Callable[[SelfType, FieldValueType], bool])-> Callable[[SelfType, FieldValueType], None]: + for field in fields: + field.enabled = False + field.visible = False + + def make_handler(fn: Callable[[SelfType, FieldValueType], bool]) -> Callable[[SelfType, FieldValueType], None]: + def handler(self: Any, value: Any): + enable = fn(self, value) + for field in fields: + if enable: + logger.info(f"{field.name} Activated") + else: + logger.info(f"{field.name} Deactivated") + enable_field(field, enabled=enable) + return handler + + return make_handler(fn) + + return _decorator + +class StackAlong(StrEnum): + CHANNEL = "Channel" + TIME = "Time" + +class NapariFieldGroup(MagicTemplate): + def __post_init__(self): + self.changed.connect(self._validate, unique=False) + + # Style the error label. + # We have to check this is a QLabel because in theory this might run in a non-QT backend + errors = self.errors.native + from qtpy.QtWidgets import QLabel + if isinstance(errors, QLabel): + errors.setStyleSheet("color: red;") + # errors.setWordWrap(True) + + from qtpy.QtCore import Qt + self._widget._layout.setAlignment(Qt.AlignmentFlag.AlignTop) + + def _get_deskew(self) -> DeskewParams: + "Returns the DeskewParams from the other tab" + from napari_lattice.dock_widget import LLSZWidget + parent = self.find_ancestor(LLSZWidget) + return parent.LlszMenu.WidgetContainer.deskew_fields._make_model() + + def _get_parent_tab_widget(self) -> QTabWidget: + qwidget = self.native + # Walk up the widget tree until we find the tab widget + while not isinstance(qwidget, QTabWidget): + qwidget = qwidget.parent() + return qwidget + + def _get_tab_index(self) -> int: + return self._get_parent_tab_widget().indexOf(self._widget._qwidget) + + def _set_valid(self, valid: bool): + from qtpy.QtGui import QIcon + from importlib_resources import as_file + tab_parent = self._get_parent_tab_widget() + index = self._get_tab_index() + + if hasattr(self, "fields_enabled") and not self.fields_enabled.value: + # Special case for "diabled" sections + icon = GREY + elif valid: + icon = GREEN + else: + icon = RED + + with as_file(icon) as path: + tab_parent.setTabIcon(index, QIcon(str(path))) + + def reset_choices(self): + # This is used to prevent validation from re-running when a napari layer is added or removed + from magicgui.widgets import Container + with self.changed.blocked(): + super(Container, self).reset_choices() + + def _validate(self): + self.errors.value = get_friendly_validations(self) + valid = not bool(self.errors.value) + self.errors.visible = not valid + self._set_valid(valid) + + def _make_model(self): + raise NotImplementedError() + +class DeskewKwargs(NapariImageParams): + angle: float + skew: DeskewDirection + +@magicclass +class DeskewFields(NapariFieldGroup): + def _get_dimension_options(self, _) -> List[str]: + """ + Returns the list of dimension order options that might be possible for the current image stack + """ + default = ["Get from Metadata"] + ndims = max([len(layer.data.shape) for layer in self.img_layer.value], default=None) + if ndims is None: + return default + elif ndims == 3: + return ["ZYX"] + default + elif ndims == 4: + return ["TZYX", "CZYX"] + default + elif ndims == 5: + return ["TCZYX", "CTZYX"] + default + else: + raise Exception("Only 3-5 dimensional arrays are supported") + + img_layer = field(List[Image], widget_type='Select').with_options( + label="Image Layer(s) to Deskew", + tooltip="All the image layers you select will be stacked into one image and then deskewed. To select multiple layers, hold Command (MacOS) or Control (Windows, Linux)." + ).with_choices(lambda _x, _y: get_layers(Image)) + stack_along = field( + str, + ).with_choices( + [it.value for it in StackAlong] + ).with_options( + label="Stack Along", + tooltip="The direction along which to stack multiple selected layers.", + value=StackAlong.CHANNEL + ) + pixel_sizes_source = field(PixelSizeSource.Metadata, widget_type="RadioButtons").with_options(label="Pixel Size Source", orientation="horizontal").with_choices([it.value for it in PixelSizeSource]) + pixel_sizes = field(Tuple[float, float, float]).with_options( + label="Pixel Sizes: XYZ (µm)", + tooltip="The size of each pixel in microns. The first field selects the X pixel size, then Y, then Z." + ) + angle = field(LatticeData.get_default("angle")).with_options( + value=LatticeData.get_default("angle"), + label="Skew Angle (°)", + tooltip="The angle to deskew the image, in degrees" + ) + device = field(str).with_choices(cle.available_device_names()).with_options( + label="Graphics Device", + tooltip="The GPU that will be used to perform the processing" + ) + # merge_all_channels = field(False).with_options(label="Merge all Channels") + dimension_order = field( + str + ).with_choices( + _get_dimension_options + ).with_options( + label="Dimension Order", + tooltip="Specifies the order of dimensions in the input images. For example, if your image is a 4D array with multiple channels along the first axis, you will specify CZYX.", + value="Get from Metadata" + ) + skew_dir = field(DeskewDirection.Y, widget_type="RadioButtons").with_options( + label="Skew Direction", + tooltip="The axis along which to deskew", + orientation="horizontal" + ) + errors = field(Label).with_options(label="Errors") + + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + from magicgui.widgets import TupleEdit + from qtpy.QtWidgets import QDoubleSpinBox + + # Enormous hack to set the precision + # A better method has been requested here: https://github.com/pyapp-kit/magicgui/issues/581#issuecomment-1709467219 + if isinstance(self.pixel_sizes, TupleEdit): + for subwidget in self.pixel_sizes._list: + if isinstance(subwidget.native, QDoubleSpinBox): + subwidget.native.setDecimals(10) + # We have to re-set the default value after changing the precision, because it's already been rounded up + # Also, we have to block emitting the changed signal at construction time + with self.changed.blocked(): + self.pixel_sizes.value = ( + DefinedPixelSizes.get_default("X"), + DefinedPixelSizes.get_default("Y"), + DefinedPixelSizes.get_default("Z") + ) + + @img_layer.connect + def _img_changed(self) -> None: + # Recalculate the dimension options whenever the image changes + self.dimension_order.reset_choices() + + @pixel_sizes_source.connect + @pixel_sizes.connect + def _rescale_image(self): + # Whenever the pixel sizes are changed, this should be reflected in the viewer + image: Image + from napari_lattice.utils import get_viewer + try: + pixels = self._get_kwargs()["physical_pixel_sizes"] + for image in self.img_layer.value: + image.scale = ( + *image.scale[0:-3], + pixels.Z, + pixels.Y, + pixels.X, + ) + viewer = get_viewer() + viewer.reset_view() + except: + pass + + @pixel_sizes_source.connect + @enable_if([pixel_sizes]) + def _hide_pixel_sizes(self, pixel_sizes_source: str): + # Hide the "Pixel Sizes" option unless the user specifies manual pixel size source + return pixel_sizes_source == PixelSizeSource.Manual + + @img_layer.connect + @enable_if([stack_along]) + def _hide_stack_along(self, img_layer: List[Image]): + # Hide the "Stack Along" option if we only have one image + return len(img_layer) > 1 + + def _get_kwargs(self) -> DeskewKwargs: + """ + Returns the LatticeData fields that the Deskew tab can provide + """ + from aicsimageio.types import PhysicalPixelSizes + DeskewParams.update_forward_refs() + params = lattice_params_from_napari( + imgs=self.img_layer.value, + dimension_order=None if self.dimension_order.value == "Get from Metadata" else self.dimension_order.value, + physical_pixel_sizes= None if self.pixel_sizes_source.value == PixelSizeSource.Metadata else PhysicalPixelSizes( + X = self.pixel_sizes.value[0], + Y = self.pixel_sizes.value[1], + Z = self.pixel_sizes.value[2] + ), + stack_along="C" if self.stack_along.value == StackAlong.CHANNEL else "T" + ) + return DeskewKwargs( + **params, + angle=self.angle.value, + skew = self.skew_dir.value, + ) + + def _make_model(self) -> DeskewParams: + kwargs = self._get_kwargs() + return DeskewParams( + input_image=kwargs["data"], + physical_pixel_sizes=kwargs["physical_pixel_sizes"], + angle=kwargs["angle"], + skew = kwargs["skew"] + ) + +@magicclass +class DeconvolutionFields(NapariFieldGroup): + # A counterpart to the DeconvolutionParams Pydantic class + fields_enabled = field(False, label="Enabled") + decon_processing = field(DeconvolutionChoice, label="Processing Algorithm") + psf = field(Tuple[Path, Path, Path, Path], label = "PSFs").with_options( + tooltip="PSFs must be in the same order as the image channels", + layout="vertical" + ) + psf_num_iter = field(int, label = "Number of Iterations") + background = field(ComboBox).with_choices( + [it.value for it in BackgroundSource] + ).with_options(label="Background") + background_custom = field(float).with_options( + visible=False, + label="Custom Background" + ) + errors = field(Label).with_options(label="Errors") + + @background.connect + @enable_if( + [background_custom] + ) + def _enable_custom_background(self, background: str) -> bool: + return background == BackgroundSource.Custom + + @fields_enabled.connect + @enable_if( + fields = [ + decon_processing, + psf, + psf_num_iter, + background + ] + ) + def _enable_fields(self, enabled: bool) -> bool: + return enabled + + def _make_model(self) -> Optional[DeconvolutionParams]: + if not self.fields_enabled.value: + return None + if self.background.value == BackgroundSource.Custom: + background = self.background_custom.value + elif self.background.value == BackgroundSource.Auto: + background = "auto" + else: + background = "second_last" + return DeconvolutionParams( + decon_processing=self.decon_processing.value, + background=background, + # Filter out unset PSFs + psf=[psf for psf in self.psf.value if psf.is_file()], + psf_num_iter=self.psf_num_iter.value + ) + +@magicclass +class CroppingFields(NapariFieldGroup): + # A counterpart to the CropParams Pydantic class + header = field(dedent(""" + Note that all cropping, including the regions of interest and Z range, is performed in the space of the deskewed shape. + This is to support the workflow of performing a preview deskew and using that to calculate the cropping coordinates. + """), widget_type="Label") + fields_enabled = field(False, label="Enabled") + shapes= field(List[Shapes], widget_type="Select", label = "ROI Shape Layers").with_options(choices=lambda _x, _y: get_layers(Shapes)) + z_range = field(Tuple[int, int]).with_options( + label = "Z Range", + value = (0, 1), + options = dict( + min = 0, + ), + ) + errors = field(Label).with_options(label="Errors") + + @set_design(text="Import ROI") + def import_roi(self, path: Path): + from lls_core.cropping import read_imagej_roi + from napari_lattice.utils import get_viewer + import numpy as np + roi_list = read_imagej_roi(path) + # convert to canvas coordinates + roi_list = (np.array(roi_list) * self._get_deskew().dy).tolist() + viewer = get_viewer() + viewer.add_shapes(roi_list, shape_type='polygon', edge_width=1, edge_color='yellow', face_color=[1, 1, 1, 0]) + + @set_design(text="New Crop") + def new_crop_layer(self): + from napari_lattice.utils import get_viewer + shapes = get_viewer().add_shapes(name="Napari Lattice Crop") + shapes.mode = "ADD_RECTANGLE" + self.shapes.value += [shapes] + + @connect_parent("deskew_fields.img_layer") + def _on_image_changed(self, field: MagicField): + try: + deskew = self._get_deskew() + except: + # Ignore if the deskew parameters are invalid + return + + deskewed_zmax = deskew.derived.deskew_vol_shape[0] + + # Update the allowed Z based the deskewed shape + for widget in self.z_range: + adjust_maximum(widget, deskewed_zmax) + + # Update the current max value to be the max of the shape + self.z_range[1].value = deskewed_zmax + + @fields_enabled.connect + @enable_if([shapes, z_range]) + def _enable_crop(self, enabled: bool) -> bool: + return enabled + + def _make_model(self) -> Optional[CropParams]: + import numpy as np + if self.fields_enabled.value: + deskew = self._get_deskew() + return CropParams( + # Convert from the input image space to the deskewed image space + # We assume here that dx == dy which isn't ideal + roi_list=ShapesData([np.array(shape.data) / deskew.dy for shape in self.shapes.value if len(shape.data) > 0]), + z_range=tuple(self.z_range.value), + ) + return None + +@magicclass +class WorkflowFields(NapariFieldGroup): + """ + Handles the workflow related parameters + """ + fields_enabled = field(False, label="Enabled") + workflow_source = field(ComboBox).with_options(label = "Workflow Source").with_choices([it.value for it in WorkflowSource]) + workflow_path = field(Path).with_options(label = "Workflow Path", visible=False) + errors = field(Label).with_options(label="Errors") + + @fields_enabled.connect + @enable_if([workflow_source]) + def _enable_workflow(self, enabled: bool) -> bool: + return enabled + + @workflow_source.connect + @enable_if([workflow_path]) + def _workflow_path(self, workflow_source: WorkflowSource) -> bool: + return workflow_source == WorkflowSource.CustomPath + + def _make_model(self) -> Optional[Workflow]: + if not self.fields_enabled.value: + return None + if self.workflow_source.value == WorkflowSource.ActiveWorkflow: + return WorkflowManager.install(self.parent_viewer).workflow + else: + return workflow_from_path(self.workflow_path.value) + +@magicclass +class OutputFields(NapariFieldGroup): + set_logging = field(Log_Levels.INFO).with_options(label="Logging Level") + time_range = field(Tuple[int, int]).with_options( + label="Time Export Range", + value=(0, 1), + options = dict( + min=0, + max=100, + ) + ) + channel_range = field(Tuple[int, int]).with_options( + label="Channel Range", + value=(0, 1), + options = dict( + min=0, + max=100, + ) + ) + save_type = field(SaveFileType).with_options( + label = "Save Format" + ) + save_path = field(Path).with_options( + label = "Save Directory", + # Directory select + mode="d" + ) + save_suffix = field(str).with_options( + value=OutputParams.get_default("save_suffix"), + label = "Save Suffix", + ) + errors = field(Label).with_options(label="Errors") + + def _make_model(self, validate: bool = True) -> OutputParams: + return OutputParams.make( + validate=validate, + + channel_range=range(self.channel_range.value[0], self.channel_range.value[1]), + time_range=range(self.time_range.value[0], self.time_range.value[1]), + save_dir=self.save_path.value, + save_suffix=self.save_suffix.value, + save_type=self.save_type.value, + ) + + @connect_parent("deskew_fields.img_layer") + def _on_image_changed(self, field: MagicField): + try: + img = self._get_deskew().input_image + except: + return + + # Update the maximum T and C + for widget in self.time_range: + adjust_maximum(widget, img.sizes["T"]) + for widget in self.channel_range: + adjust_maximum(widget, img.sizes["C"]) diff --git a/plugin/napari_lattice/green.png b/plugin/napari_lattice/green.png new file mode 100644 index 00000000..8addaf35 Binary files /dev/null and b/plugin/napari_lattice/green.png differ diff --git a/plugin/napari_lattice/grey.png b/plugin/napari_lattice/grey.png new file mode 100644 index 00000000..d3218dbc Binary files /dev/null and b/plugin/napari_lattice/grey.png differ diff --git a/plugin/napari_lattice/icons.py b/plugin/napari_lattice/icons.py new file mode 100644 index 00000000..dff6eb2d --- /dev/null +++ b/plugin/napari_lattice/icons.py @@ -0,0 +1,7 @@ +import importlib_resources + +resources = importlib_resources.files(__name__) + +GREEN = resources / "valid.svg" +GREY = resources / "circle-regular.svg" +RED = resources / "invalid.svg" diff --git a/plugin/napari_lattice/invalid.svg b/plugin/napari_lattice/invalid.svg new file mode 100644 index 00000000..e1d7dcc5 --- /dev/null +++ b/plugin/napari_lattice/invalid.svg @@ -0,0 +1,8 @@ + + + + + + + + diff --git a/plugin/napari_lattice/napari.yaml b/plugin/napari_lattice/napari.yaml index 9f0bc229..2ec0f9a5 100644 --- a/plugin/napari_lattice/napari.yaml +++ b/plugin/napari_lattice/napari.yaml @@ -2,9 +2,9 @@ name: napari-lattice display_name: Lattice Lightsheet Analysis contributions: commands: - - id: napari-lattice._dock_widget + - id: napari-lattice.dock_widget title: Create napari_lattice widget - python_name: napari_lattice._dock_widget:_napari_lattice_widget_wrapper + python_name: napari_lattice.dock_widget:LLSZWidget # ~~ Reader ~~ - id: napari-lattice.get_reader @@ -19,7 +19,7 @@ contributions: # python_name: napari_lattice.use_workflow:_workflow_widget widgets: - - command: napari-lattice._dock_widget + - command: napari-lattice.dock_widget display_name: Lattice Lightsheet Analysis #- command: napari-lattice.crop_deskew diff --git a/plugin/napari_lattice/parent_connect.py b/plugin/napari_lattice/parent_connect.py new file mode 100644 index 00000000..ad55a36d --- /dev/null +++ b/plugin/napari_lattice/parent_connect.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable, TYPE_CHECKING +from operator import attrgetter + +if TYPE_CHECKING: + from magicclass import MagicTemplate + from magicgui.widgets.bases import ValueWidget + + +@dataclass +class ParentConnect: + """ + A function that wants to be connected to a parent or sibling's field. + This will be resolved after the GUI is instantiated. + """ + path: str + func: Callable + + def resolve(self, root: MagicTemplate, event_owner: MagicTemplate, field_name: str) -> None:# -> Callable[..., Any]: + """ + Converts this object into a true function that is connected to the appropriate change event + """ + field: ValueWidget = attrgetter(self.path)(root) + # field_owner = field.parent + bound_func = field.changed.connect(lambda: self.func(event_owner, field.value)) + setattr(event_owner, field_name, bound_func) + +def connect_parent(path: str) -> Callable[..., ParentConnect]: + """ + Mark this function as wanting to connect to a parent or sibling event + """ + def decorator(fn: Callable) -> ParentConnect: + return ParentConnect( + path=path, + func=fn + ) + return decorator diff --git a/plugin/napari_lattice/reader.py b/plugin/napari_lattice/reader.py new file mode 100644 index 00000000..cbe90615 --- /dev/null +++ b/plugin/napari_lattice/reader.py @@ -0,0 +1,208 @@ +""" +reader plugin for h5 saved using np2bdv +https://github.com/nvladimus/npy2bdv +#TODO: pass pyramidal layer to napari +##use ilevel parameter in read_view to access different subsamples/pyramids +#pass a list of images with different resolution for pyramid; use is_pyramid=True flag in napari.add_image +, however pyramidal support for 3D not available yet +""" +from __future__ import annotations + +import dask.array as da +import dask.delayed as delayed +import os +import numpy as np +from napari.layers import Image +from aicsimageio.aics_image import AICSImage + +from typing import List, Optional, Tuple, Collection, TYPE_CHECKING, TypedDict + +from aicsimageio.types import PhysicalPixelSizes +from lls_core.models.deskew import DefinedPixelSizes + +from logging import getLogger +logger = getLogger(__name__) + +if TYPE_CHECKING: + from aicsimageio.types import ImageLike + from xarray import DataArray + +class NapariImageParams(TypedDict): + data: DataArray + physical_pixel_sizes: DefinedPixelSizes + save_name: str + +def lattice_params_from_napari( + imgs: Collection[Image], + stack_along: str, + dimension_order: Optional[str] = None, + physical_pixel_sizes: Optional[PhysicalPixelSizes] = None, +) -> NapariImageParams: + """ + Factory function for generating a LatticeData from a Napari Image + """ + from xarray import DataArray, concat + + if len(imgs) < 1: + raise ValueError("At least one image must be provided.") + + if len(set(len(it.data.shape) for it in imgs)) > 1: + size_message = ",".join(f"{img.name}: {len(img.data.shape)}" for img in imgs) + raise ValueError(f"The input images have multiple different dimensions, which napari lattice doesn't support: {size_message}") + + save_name: str + # This is a set of all pixel sizes that we have seen so far + metadata_pixel_sizes: set[PhysicalPixelSizes] = set() + save_names = [] + # The pixel sizes according to the AICS metadata, if any + final_imgs: list[DataArray] = [] + + for img in imgs: + if img.source.path is None: + # remove colon (:) and any leading spaces + save_name = img.name.replace(":", "").strip() + # replace any group of spaces with "_" + save_name = '_'.join(save_name.split()) + else: + file_name_noext = os.path.basename(img.source.path) + file_name = os.path.splitext(file_name_noext)[0] + # remove colon (:) and any leading spaces + save_name = file_name.replace(":", "").strip() + # replace any group of spaces with "_" + save_name = '_'.join(save_name.split()) + + save_names.append(save_name) + + if 'aicsimage' in img.metadata.keys(): + img_data_aics: AICSImage = img.metadata['aicsimage'] + # If the user has not provided pixel sizes, we extract them fro the metadata + # Only process pixel sizes that are not none + if physical_pixel_sizes is None and all(img_data_aics.physical_pixel_sizes): + metadata_pixel_sizes.add(img_data_aics.physical_pixel_sizes) + + metadata_order = list(img_data_aics.dims.order) + metadata_shape = list(img_data_aics.dims.shape) + while len(metadata_order) > len(img.data.shape): + logger.info(f"Image metadata implies there are more dimensions ({len(metadata_order)}) than the image actually has ({len(img.data.shape)})") + for i, size in enumerate(metadata_shape): + if size not in img.data.shape: + logger.info(f"Excluding the {metadata_order[i]} dimension to reconcile dimension order") + del metadata_order[i] + del metadata_shape[i] + calculated_order = metadata_order + elif dimension_order is None: + raise ValueError("Either the Napari image must have dimensional metadata, or a dimension order must be provided") + else: + calculated_order = tuple(dimension_order) + + final_imgs.append(DataArray(img.data, dims=calculated_order)) + + if physical_pixel_sizes: + final_pixel_size = DefinedPixelSizes.from_physical(physical_pixel_sizes) + else: + if len(metadata_pixel_sizes) > 1: + raise Exception(f"Two or more layers that you have tried to merge have different pixel sizes according to their metadata! {metadata_pixel_sizes}") + elif len(metadata_pixel_sizes) < 1: + raise Exception("No pixel sizes could be determined from the image metadata. Consider manually specifying the pixel sizes.") + else: + final_pixel_size = DefinedPixelSizes.from_physical(metadata_pixel_sizes.pop()) + + final_img = concat(final_imgs, dim=stack_along) + return NapariImageParams(save_name=save_names[0], physical_pixel_sizes=final_pixel_size, data=final_img, dims=final_img.shape) + +def napari_get_reader(path: list[str] | str): + """Check if file ends with h5 and returns reader function if true + Parameters + ---------- + path : str or list of str + Path to file, or list of paths. + Returns + ------- + function + """ + if isinstance(path, list): + # reader plugins may be handed single path, or a list of paths. + # if it is a list, we are only going to open first file + path = path[0] + + tiff_formats = (".tif",".tiff") + + if path.endswith(".h5"): + return bdv_h5_reader + elif path.endswith(tiff_formats): + return tiff_reader + # if we know we cannot read the file, we immediately return None. + else: + return None + + +def bdv_h5_reader(path): + """Take a path and returns a list of LayerData tuples.""" + + os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE" + + #print(path) + import npy2bdv + h5_file = npy2bdv.npy2bdv.BdvEditor(path) + + img = [] + + #get dimensions of first image + first_timepoint = h5_file.read_view(time=0,channel=0) + + #Threshold to figure out when to use out-of-memory loading/dask + #Got the idea from napari-aicsimageio + #https://github.com/AllenCellModeling/napari-aicsimageio/blob/22934757c2deda30c13f39ec425343182fa91a89/napari_aicsimageio/core.py#L222 + mem_threshold_bytes = 4e9 + mem_per_threshold = 0.3 + + from psutil import virtual_memory + + file_size = os.path.getsize(path) + avail_mem = virtual_memory().available + + #if file size <30% of available memory and <4GB, open + if file_size<=mem_per_threshold*avail_mem and file_size List[Tuple[AICSImage, dict, str]]: + """Take path to tiff image and returns a list of LayerData tuples. + Specifying tiff_reader to have better control over tifffile related errors when using AICSImage + """ + + try: + image = AICSImage(path) + except Exception as e: + raise Exception("Error reading TIFF. Try upgrading tifffile library: pip install tifffile --upgrade.") from e + + # optional kwargs for the corresponding viewer.add_* method + add_kwargs = {} + + layer_type = "image" # optional, default is "image" + return [(image, add_kwargs, layer_type)] diff --git a/plugin/napari_lattice/red.png b/plugin/napari_lattice/red.png new file mode 100644 index 00000000..1ca32a4d Binary files /dev/null and b/plugin/napari_lattice/red.png differ diff --git a/plugin/napari_lattice/ui_core.py b/plugin/napari_lattice/ui_core.py deleted file mode 100644 index 1e9af07e..00000000 --- a/plugin/napari_lattice/ui_core.py +++ /dev/null @@ -1,143 +0,0 @@ -from __future__ import annotations - -import numpy as np -from pathlib import Path -from typing import TYPE_CHECKING - -import pyclesperanto_prototype as cle -from lls_core.utils import check_dimensions -from lls_core.deconvolution import pycuda_decon, skimage_decon -from lls_core import config, DeskewDirection, DeconvolutionChoice -from lls_core .io import save_img - -if TYPE_CHECKING: - from napari.types import ImageData - -# Enable Logging -import logging -logger = logging.getLogger(__name__) -# inherit log level from config -logger.setLevel(config.log_level) - -def _Preview(LLSZWidget, - self_class, - time: int, - channel: int, - img_data: ImageData): - - logger.info("Previewing deskewed channel and time") - assert img_data.size, "No image open or selected" - assert time < LLSZWidget.LlszMenu.lattice.time, "Time is out of range" - assert channel < LLSZWidget.LlszMenu.lattice.channels, "Channel is out of range" - assert LLSZWidget.LlszMenu.lattice.skew in DeskewDirection, f"Skew direction not recognised. Got {LLSZWidget.LlszMenu.lattice.skew}" - - vol = LLSZWidget.LlszMenu.lattice.data - vol_zyx = np.array(vol[time, channel, :, :, :]) - - # apply deconvolution if checked - if LLSZWidget.LlszMenu.deconvolution.value: - print( - f"Deskewing for Time:{time} and Channel: {channel} with deconvolution") - psf = LLSZWidget.LlszMenu.lattice.psf[channel] - if LLSZWidget.LlszMenu.lattice.decon_processing == DeconvolutionChoice.cuda_gpu: - decon_data = pycuda_decon(image=vol_zyx, - psf=psf, - dzdata=LLSZWidget.LlszMenu.lattice.dz, - dxdata=LLSZWidget.LlszMenu.lattice.dx, - dzpsf=LLSZWidget.LlszMenu.lattice.dz, - dxpsf=LLSZWidget.LlszMenu.lattice.dx) - # pycuda_decon(image,otf_path,dzdata,dxdata,dzpsf,dxpsf) - else: - decon_data = skimage_decon( - vol_zyx=vol_zyx, psf=psf, num_iter=10, clip=False, filter_epsilon=0, boundary='nearest') - - deskew_final = LLSZWidget.LlszMenu.deskew_func(decon_data, - angle_in_degrees=LLSZWidget.LlszMenu.angle_value, - voxel_size_x=LLSZWidget.LlszMenu.lattice.dx, - voxel_size_y=LLSZWidget.LlszMenu.lattice.dy, - voxel_size_z=LLSZWidget.LlszMenu.lattice.dz, - linear_interpolation=True).astype(vol.dtype) - else: - logger.info(f"Deskewing for Time:{time} and Channel: {channel}") - deskew_final = LLSZWidget.LlszMenu.deskew_func(vol_zyx, - angle_in_degrees=LLSZWidget.LlszMenu.angle_value, - voxel_size_x=LLSZWidget.LlszMenu.lattice.dx, - voxel_size_y=LLSZWidget.LlszMenu.lattice.dy, - voxel_size_z=LLSZWidget.LlszMenu.lattice.dz, - linear_interpolation=True).astype(vol.dtype) - - # if getting an error LogicError: clSetKernelArg failed: #INVALID_ARG_SIZE - when processing arg#13 (1-based) - # make sure array is pulled from GPU - - deskew_final = cle.pull(deskew_final) - # TODO: Use dask - # if LLSZWidget.LlszMenu.dask: - #logger.info(f"Using CPU for deskewing") - # use cle library for affine transforms, but use dask and scipy - # deskew_final = deskew_final.compute() - - max_proj_deskew = cle.maximum_z_projection(deskew_final) - - # add channel and time information to the name - suffix_name = "_c" + str(channel) + "_t" + str(time) - scale = (LLSZWidget.LlszMenu.lattice.new_dz, - LLSZWidget.LlszMenu.lattice.dy, LLSZWidget.LlszMenu.lattice.dx) - # TODO:adding img of difff scales change dim slider - self_class.parent_viewer.add_image( - deskew_final, name="Deskewed image" + suffix_name, scale=scale) - self_class.parent_viewer.add_image( - max_proj_deskew, name="Deskew_MIP", scale=scale[1:3]) - self_class.parent_viewer.layers[0].visible = False - - logger.info(f"Preview: Deskewing complete") - return - - -def _Deskew_Save(LLSZWidget, - time_start: int, - time_end: int, - ch_start: int, - ch_end: int, - save_as_type, - save_path: Path): - - assert LLSZWidget.LlszMenu.open_file, "Image not initialised" - check_dimensions(time_start, time_end, ch_start, ch_end, - LLSZWidget.LlszMenu.lattice.channels, LLSZWidget.LlszMenu.lattice.time) - #time_range = range(time_start, time_end) - #channel_range = range(ch_start, ch_end) - angle = LLSZWidget.LlszMenu.lattice.angle - dx = LLSZWidget.LlszMenu.lattice.dx - dy = LLSZWidget.LlszMenu.lattice.dy - dz = LLSZWidget.LlszMenu.lattice.dz - - # Convert path to string - #save_path = save_path.__str__() - - # get the image data as dask array - img_data = LLSZWidget.LlszMenu.lattice.data - - # pass arguments for save tiff, callable and function arguments - save_img(vol=img_data, - func=LLSZWidget.LlszMenu.deskew_func, - time_start=time_start, - time_end=time_end, - channel_start=ch_start, - channel_end=ch_end, - save_file_type=save_as_type, - save_path=save_path, - save_name=LLSZWidget.LlszMenu.lattice.save_name, - dx=dx, - dy=dy, - dz=dz, - angle=angle, - angle_in_degrees=angle, - voxel_size_x=dx, - voxel_size_y=dy, - voxel_size_z=dz, - linear_interpolation=True, - LLSZWidget=LLSZWidget) - - print("Deskewing and Saving Complete -> ", save_path) - return - diff --git a/plugin/napari_lattice/utils.py b/plugin/napari_lattice/utils.py new file mode 100644 index 00000000..fd7594ae --- /dev/null +++ b/plugin/napari_lattice/utils.py @@ -0,0 +1,25 @@ +from napari.viewer import current_viewer, Viewer +from napari.layers import Layer +from typing_extensions import TypeVar +from typing import Sequence, Type + +def get_viewer() -> Viewer: + """ + Returns the current viewer, throwing an exception if one doesn't exist + """ + viewer = current_viewer() + if viewer is None: + raise Exception("No viewer present!") + return viewer + +LayerType = TypeVar("LayerType", bound=Layer) +def get_layers(type: Type[LayerType]) -> Sequence[LayerType]: + """ + Returns all layers in the current napari viewer of a given `Layer` subtype. + For example, if you pass `napari.layers.Image`, it will return a list of + Image layers + """ + viewer = current_viewer() + if viewer is None: + return [] + return [layer for layer in viewer.layers if isinstance(layer, type)] diff --git a/plugin/napari_lattice/valid.svg b/plugin/napari_lattice/valid.svg new file mode 100644 index 00000000..81ad09bf --- /dev/null +++ b/plugin/napari_lattice/valid.svg @@ -0,0 +1,8 @@ + + + + + + + + diff --git a/plugin/pyproject.toml b/plugin/pyproject.toml index 0b29b4f9..45eef7d4 100644 --- a/plugin/pyproject.toml +++ b/plugin/pyproject.toml @@ -35,29 +35,33 @@ classifiers = [ ] requires-python = ">=3.8" dependencies = [ - "aicsimageio>=4.11.0", - "dask", + "aicsimageio>=4.6.3", "dask[distributed]", # This isn't used directly, but we need to pin this version "fsspec>=2022.8.2", - # We need this Python 3.8 fix: https://github.com/hanjinliu/magic-class/pull/108 - "magic-class>=0.7.4", - "magicgui", + "importlib_resources", "lls_core", - "napari-aicsimageio>=0.7.2", + # The lower bound is because we need this Python 3.8 fix: https://github.com/hanjinliu/magic-class/pull/108 + # The upper bound is because we are waiting on https://github.com/hanjinliu/magic-class/issues/128 + "magic-class>=0.7.5,<0.8.0", + "magicgui<0.8.0", + # Currently commented out to avoid installation issues, although + # This can be reinstated once https://github.com/pypa/pip/pull/12095 is merged + # "napari-aicsimageio>=0.7.2", "napari-spreadsheet", "napari-workflow-inspector", "napari-workflows>=0.2.8", - "napari>=0.4.11", + "napari>=0.4.11,<0.5", "npy2bdv", "numpy", - "pandas", "psutil", "pyclesperanto_prototype>=0.20.0", + "pydantic", "qtpy", - "tqdm", - "typing_extensions", - "pyyaml", + "typing_extensions>=4.7.0", + "rich", + "StrEnum", + "xarray" ] [project.urls] @@ -68,7 +72,7 @@ SourceCode = "https://github.com/BioimageAnalysisCoreWEHI/napari_lattice" UserSupport = "https://github.com/BioimageAnalysisCoreWEHI/napari_lattice/issues" [tool.setuptools.package-data] -napari_lattice = ["*.yaml"] +napari_lattice = ["*.yaml", "*.svg", "*.png"] [project.optional-dependencies] testing = [ @@ -92,6 +96,7 @@ ignore_unused = [ # These napari plugins are needed to use the plugin, but aren't imported directly "napari-aicsimageio", "napari-workflow-inspector", + "napari-spreadsheet", # This is pinned but unused "fsspec", @@ -103,3 +108,14 @@ output_format = "human_detailed" typeCheckingMode = "off" reportUndefinedVariable = "error" reportMissingImports = "none" +reportMissingTypeStubs = false +reportUnknownVariableType = false +reportUnknownArgumentType = false +reportUnknownLambdaType = false +reportUnknownMemberType = false +reportUnknownParameterType = false +reportUntypedFunctionDecorator = false +reportMissingTypeArgument = false +reportPrivateUsage = false +reportPrivateImportUsage = false +reportUnnecessaryComparison = false diff --git a/plugin/tests/test_dock_widget.py b/plugin/tests/test_dock_widget.py index 7772c6a1..ef846cfb 100644 --- a/plugin/tests/test_dock_widget.py +++ b/plugin/tests/test_dock_widget.py @@ -1,54 +1,87 @@ -from __future__ import annotations -from napari_lattice._dock_widget import _napari_lattice_widget_wrapper -import numpy as np -from typing import Callable, TYPE_CHECKING -from magicclass.testing import check_function_gui_buildable, FunctionGuiTester -from napari.layers import Image -from magicclass import MagicTemplate -from magicclass.widgets import Widget -from magicclass._gui._gui_modes import ErrorMode - -if TYPE_CHECKING: - from napari import Viewer - -# Test if the widget can be created - -# make_napari_viewer is a pytest fixture that returns a napari viewer object -# Commenting this out as github CI is fixed -# @pytest.mark.skip(reason="GUI tests currently fail in github CI, unclear why") -# When testing locally, need pytest-qt - -def set_debug(cls: MagicTemplate): - """ - Recursively disables GUI error handling, so that this works with pytest - """ - def _handler(e: Exception, parent: Widget): - raise e - ErrorMode.get_handler = lambda self: _handler - cls._error_mode = ErrorMode.stderr - for child in cls.__magicclass_children__: - set_debug(child) - -def test_dock_widget(make_napari_viewer: Callable[[], Viewer]): - # make viewer and add an image layer using our fixture - viewer = make_napari_viewer() - - # Check if an image can be added as a layer - viewer.add_image(np.random.random((100, 100))) - - # Test if napari-lattice widget can be created in napari - gui = _napari_lattice_widget_wrapper() - viewer.window.add_dock_widget(gui) - -def test_check_buildable(): - widget = _napari_lattice_widget_wrapper() - check_function_gui_buildable(widget) - -def test_plugin_initialize(make_napari_viewer: Callable[[], Viewer]): - ui = _napari_lattice_widget_wrapper() - viewer = make_napari_viewer() - viewer.window.add_dock_widget(ui) - image = Image(np.random.random((100, 100, 100, 100))) - set_debug(ui) - tester = FunctionGuiTester(ui.LlszMenu.Choose_Image_Layer) - tester.call(img_layer=image, last_dimension_channel="time") +from __future__ import annotations + +from importlib_resources import as_file +from napari_lattice.dock_widget import LLSZWidget +from typing import Callable, TYPE_CHECKING +from magicclass.testing import check_function_gui_buildable, FunctionGuiTester +from magicclass import MagicTemplate +from magicclass.widgets import Widget +from magicclass._gui._gui_modes import ErrorMode +import pytest +from lls_core.sample import resources +from aicsimageio.aics_image import AICSImage +from napari_lattice.fields import PixelSizeSource +from tempfile import TemporaryDirectory + +if TYPE_CHECKING: + from napari import Viewer + +# Test if the widget can be created + +# make_napari_viewer is a pytest fixture that returns a napari viewer object +# Commenting this out as github CI is fixed +# @pytest.mark.skip(reason="GUI tests currently fail in github CI, unclear why") +# When testing locally, need pytest-qt + +@pytest.fixture(params=[ + "RBC_tiny.czi", + "LLS7_t1_ch1.czi", + "LLS7_t1_ch3.czi", + "LLS7_t2_ch1.czi", + "LLS7_t2_ch3.czi", +]) +def image_data(request: pytest.FixtureRequest): + """ + Fixture function that yields test images as file paths + """ + with as_file(resources / request.param) as image_path: + yield AICSImage(image_path) + +def set_debug(cls: MagicTemplate): + """ + Recursively disables GUI error handling, so that this works with pytest + """ + def _handler(e: Exception, parent: Widget): + raise e + ErrorMode.get_handler = lambda self: _handler + cls._error_mode = ErrorMode.stderr + for child in cls.__magicclass_children__: + set_debug(child) + +def test_dock_widget(make_napari_viewer: Callable[[], Viewer], image_data: AICSImage): + # make viewer and add an image layer using our fixture + viewer = make_napari_viewer() + + # Check if an image can be added as a layer + viewer.add_image(image_data.xarray_dask_data) + + # Test if napari-lattice widget can be created in napari + ui = LLSZWidget() + set_debug(ui) + viewer.window.add_dock_widget(ui) + + # Set the input parameters and execute the processing + with TemporaryDirectory() as tmpdir: + # Specify values for all the required GUI fields + fields = ui.LlszMenu.WidgetContainer.deskew_fields + # TODO: refactor this logic into a `lattice_params_from_aics` method + fields.img_layer.value = list(viewer.layers) + fields.dimension_order.value = image_data.dims.order + fields.pixel_sizes_source.value = PixelSizeSource.Manual + + # Test previewing + tester = FunctionGuiTester(ui.preview) + tester.call("", 0, 0) + + # Add the save path which shouldn't be needed for previewing + ui.LlszMenu.WidgetContainer.output_fields.save_path.value = tmpdir + + # Test saving + tester = FunctionGuiTester(ui.save) + tester.call() + + +def test_check_buildable(): + ui = LLSZWidget() + set_debug(ui) + check_function_gui_buildable(ui) diff --git a/pyproject.toml b/pyproject.toml deleted file mode 100644 index 37798396..00000000 --- a/pyproject.toml +++ /dev/null @@ -1,10 +0,0 @@ -[build-system] -requires = ["setuptools", "wheel"] -build-backend = "setuptools.build_meta" - -[tool.black] -line-length = 79 - -[tool.isort] -profile = "black" -line_length = 79