[WIP] Support plate loading for varying well sizes (ref #240) #241

Open · wants to merge 5 commits into base: master
168 changes: 150 additions & 18 deletions ome_zarr/reader.py
@@ -481,7 +481,7 @@ def __init__(self, node: Node) -> None:
LOGGER.debug("Plate created with ZarrLocation fmt: %s", self.zarr.fmt)
self.get_pyramid_lazy(node)

def get_pyramid_lazy(self, node: Node) -> None:
def get_pyramid_lazy(self, node: Node, loading: str = "new") -> None:
"""
Return a pyramid of dask data, where the highest resolution is the
stitched full-resolution images.
@@ -490,7 +490,7 @@ def get_pyramid_lazy(self, node: Node) -> None:
LOGGER.info("plate_data: %s", self.plate_data)
self.rows = self.plate_data.get("rows")
self.columns = self.plate_data.get("columns")
self.first_field = "0"

self.row_names = [row["name"] for row in self.rows]
self.col_names = [col["name"] for col in self.columns]

@@ -500,23 +500,50 @@ def get_pyramid_lazy(self, node: Node) -> None:
self.row_count = len(self.rows)
self.column_count = len(self.columns)

# Get the first well...
well_zarr = self.zarr.create(self.well_paths[0])
well_node = Node(well_zarr, node)
well_spec: Optional[Well] = well_node.first(Well)
if well_spec is None:
raise Exception("Could not find first well")
self.numpy_type = well_spec.numpy_type

LOGGER.debug("img_pyramid_shapes: %s", well_spec.img_pyramid_shapes)
# Default loading path
if loading == "default":
self.first_field = "0"
# Get the first well...
well_zarr = self.zarr.create(self.well_paths[0])
well_node = Node(well_zarr, node)
well_spec: Optional[Well] = well_node.first(Well)
if well_spec is None:
raise Exception("Could not find first well")
self.numpy_type = well_spec.numpy_type

LOGGER.debug("img_pyramid_shapes: %s", well_spec.img_pyramid_shapes)

self.axes = well_spec.img_metadata["axes"]
# Create a dask pyramid for the plate
pyramid = []
for level, tile_shape in enumerate(well_spec.img_pyramid_shapes):
lazy_plate = self.get_stitched_grid(level, tile_shape)
pyramid.append(lazy_plate)

# New loading path that handles wells with different shapes
elif loading == "new":
# Get all the well specs
# TODO: Find a way to speed up this loading; see the discussion at
# https://github.com/ome/ngff/issues/141
well_specs = self.get_plate_well_specs(node)

# Assumption: The following information is consistent across wells
# Take dtype, axes and pyramid depth from the first well
well_spec = well_specs[self.well_paths[0]]
self.numpy_type = well_spec.numpy_type
LOGGER.debug("img_pyramid_shapes: %s", well_spec.img_pyramid_shapes)
self.axes = well_spec.img_metadata["axes"]
self.levels = len(well_spec.img_pyramid_shapes)
# FIXME: Figure out the real downsampling factor
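# (one option, not implemented here: estimate the factor from the ratio of
# consecutive pyramid shapes, e.g. shape[level][-1] / shape[level + 1][-1])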
self.downsampling_factor = 2

pyramid = []
for level in range(self.levels):
lazy_plate = self.get_stitched_plate(level, well_specs)
pyramid.append(lazy_plate)

self.axes = well_spec.img_metadata["axes"]

# Create a dask pyramid for the plate
pyramid = []
for level, tile_shape in enumerate(well_spec.img_pyramid_shapes):
lazy_plate = self.get_stitched_grid(level, tile_shape)
pyramid.append(lazy_plate)
else:
raise Exception(f"Invalid loading mode: {loading}")

# Set the node.data to be pyramid view of the plate
node.data = pyramid
@@ -526,9 +553,114 @@ def get_pyramid_lazy(self, node: Node) -> None:
# "metadata" dict gets added to each 'plate' layer in napari
node.metadata.update({"metadata": {"plate": self.plate_data}})

def get_stitched_plate(self, level: int, well_specs: Dict) -> da.Array:
"""
Replaces get_stitched_grid, loading a potentially different
shape for each well.
"""
LOGGER.debug(f"get_stitched_plate() level: {level}")

def get_tile(tile_name: str, tile_shape: tuple) -> np.ndarray:
"""tile_name is 'row/col'"""
path = self.get_new_tile_path(level, tile_name)
LOGGER.debug("LOADING tile... %s with shape: %s", path, tile_shape)

try:
data = self.zarr.load(path)
except ValueError:
# With the new loading scheme this branch should not normally be
# hit, but keep the fallback in case a tile still fails to load
LOGGER.exception("Failed to load %s", path)
data = np.zeros(tile_shape, dtype=self.numpy_type)
return data

def get_max_well_size(well_specs: Dict, padding: int = 10):
"""
Calculate the maximum well shape at this pyramid level, including xy padding.

:param well_specs: Dict of well_spec (Well Node), keyed by well path
:param padding: xy padding to be added between wells
"""
max_well_dims = list(list(well_specs.values())[0].img_pyramid_shapes[level])
for well_spec in well_specs.values():
new_dims = well_spec.img_pyramid_shapes[level]
for dim in range(len(max_well_dims) - 2):
if new_dims[dim] > max_well_dims[dim]:
max_well_dims[dim] = new_dims[dim]
for dim in range(len(max_well_dims) - 2, len(max_well_dims)):
real_padding = padding * self.downsampling_factor ** -(
level - self.levels
)
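# e.g. with downsampling_factor=2 and levels=3: level 0 -> padding * 8,
# level 2 -> padding * 2, keeping the physical extent of the padding
# roughly constant across pyramid levels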
if new_dims[dim] + real_padding > max_well_dims[dim]:
max_well_dims[dim] = new_dims[dim] + real_padding
return max_well_dims

def calculate_required_padding(max_well_dims, tile_shape):
# Calculate the required padding by dimension
diff_size = []
for i in range(len(max_well_dims)):
diff_size.append(max_well_dims[i] - tile_shape[i])

# Decide which side gets padded
# Logic:
# 1. Pad x & y equally on both sides
# 2. Pad z, c, t on right side (keep aligned at the same 0)
# Limitations:
# 1. Does not take into account transformations
# 2. FIXME: Padding of channels is not optimal; it could make a
# channel appear as something that it's not in the viewer
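# Example (assuming 5D t,c,z,y,x data): diff_size [0, 1, 0, 5, 4]
# becomes padding ((0, 0), (0, 1), (0, 0), (2, 3), (2, 2))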
padding = []
for i in range(len(max_well_dims) - 2):
padding.append((0, diff_size[i]))

for i in range(len(max_well_dims) - 2, len(max_well_dims)):
# Split odd differences with the extra pixel on the right
padding.append((diff_size[i] // 2, diff_size[i] - diff_size[i] // 2))

return tuple(padding)

max_well_dims = get_max_well_size(well_specs)

lazy_reader = delayed(get_tile)

lazy_rows = []
for row_name in self.row_names:
lazy_row: List[da.Array] = []
for col_name in self.col_names:
tile_name = f"{row_name}/{col_name}"
if tile_name in well_specs:
tile_shape = well_specs[tile_name].img_pyramid_shapes[level]
lazy_tile = da.from_delayed(
lazy_reader(tile_name, tile_shape), shape=tile_shape, dtype=self.numpy_type
)
padding = calculate_required_padding(max_well_dims, tile_shape)
padded_lazy_tile = da.pad(
lazy_tile, pad_width=padding, mode="constant", constant_values=0
)
else:
# If a well does not exist on disk,
# just get an array of 0s of the fitting size
padded_lazy_tile = da.zeros(max_well_dims, dtype=self.numpy_type)
lazy_row.append(padded_lazy_tile)
lazy_rows.append(da.concatenate(lazy_row, axis=len(self.axes) - 1))  # stitch columns along x
return da.concatenate(lazy_rows, axis=len(self.axes) - 2)  # stitch rows along y

def get_plate_well_specs(self, node: Node) -> Dict:
well_specs = {}
for well_path in self.well_paths:
LOGGER.info(f"Loading Well spec for {well_path}")
well_zarr = self.zarr.create(well_path)
well_node = Node(well_zarr, node)
well_spec: Optional[Well] = well_node.first(Well)
if well_spec is None:
# Treat wells without a readable spec as missing; they are
# filled with zeros in get_stitched_plate()
LOGGER.warning("Could not find Well spec for %s", well_path)
continue
well_specs[well_path] = well_spec
return well_specs

def get_numpy_type(self, image_node: Node) -> np.dtype:
return image_node.data[0].dtype

def get_new_tile_path(
self, level: int, tile_name: str, image_index: int = 0
) -> str:
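# e.g. level=2, tile_name="B/3", image_index=0 -> "B/3/0/2"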
return f"{tile_name}/{image_index}/{level}"

def get_tile_path(self, level: int, row: int, col: int) -> str:
return (
f"{self.row_names[row]}/"
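For reference, a minimal sketch (not part of this PR) of how the stitched plate pyramid built by get_pyramid_lazy is consumed through the public Reader API; the plate path "plate.zarr" is an assumption, and the loading mode is internal to this method, defaulting to "new" on this branch:

from ome_zarr.io import parse_url
from ome_zarr.reader import Reader

# Open a plate; the root plate node is yielded first and carries the
# stitched dask pyramid produced by get_pyramid_lazy().
reader = Reader(parse_url("plate.zarr"))
plate_node = next(iter(reader()))
pyramid = plate_node.data  # one dask array per pyramid level
print([level.shape for level in pyramid])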