Skip to content

Commit

Permalink
fix!: Fixed patches not being aligned with chunks when using chunk sa…
Browse files Browse the repository at this point in the history
…mpling
  • Loading branch information
Karol-G committed May 7, 2024
1 parent f0b31aa commit 156d47a
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 57 deletions.
6 changes: 3 additions & 3 deletions examples/inference_chunk.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from patchly import GridSampler, Aggregator
from patchly import GridSampler, Aggregator, SamplingMode
from torch.utils.data import DataLoader, Dataset
import torch

Expand All @@ -16,15 +16,15 @@ def example():
chunk_size = (500, 500)

# Init GridSampler
sampler = GridSampler(image=image, spatial_size=spatial_size, patch_size=patch_size, step_size=step_size, spatial_first=False)
sampler = GridSampler(image=image, spatial_size=spatial_size, patch_size=patch_size, step_size=step_size, chunk_size=chunk_size, spatial_first=False, mode=SamplingMode.SAMPLE_SQUEEZE)
# Convert sampler into a PyTorch dataset
loader = SamplerDataset(sampler)
# Init dataloader
loader = DataLoader(loader, batch_size=4, num_workers=2, shuffle=False, pin_memory=False)
# Create an empty prediction passed to the aggregator
prediction = np.zeros(spatial_size, dtype=np.uint8)
# Init aggregator
aggregator = Aggregator(sampler=sampler, output=prediction, chunk_size=chunk_size, weights='gaussian', softmax_dim=0, spatial_first=False, has_batch_dim=True)
aggregator = Aggregator(sampler=sampler, output=prediction, weights='gaussian', softmax_dim=0, spatial_first=False, has_batch_dim=True)

# Run inference
with torch.no_grad():
Expand Down
40 changes: 3 additions & 37 deletions patchly/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class PatchStatus(Enum):


class Aggregator:
def __init__(self, sampler: GridSampler, output_size: Optional[Union[Tuple, npt.ArrayLike]] = None, output: Optional[npt.ArrayLike] = None, chunk_size: Optional[Union[Tuple, npt.ArrayLike]] = None,
def __init__(self, sampler: GridSampler, output_size: Optional[Union[Tuple, npt.ArrayLike]] = None, output: Optional[npt.ArrayLike] = None,
weights: Union[str, Callable] = 'avg', softmax_dim: Optional[int] = None, has_batch_dim: bool = False, spatial_first: bool = True, device: str = 'cpu'):
"""
Initializes the Aggregator object for aggregating patches into a larger output image.
Expand All @@ -45,7 +45,7 @@ def __init__(self, sampler: GridSampler, output_size: Optional[Union[Tuple, npt.
self.image_size_s = sampler.image_size_s
self.patch_size_s = sampler.patch_size_s
self.step_size_s = sampler.step_size_s
self.chunk_size_s = chunk_size
self.chunk_size_s = sampler.chunk_size_s
self.spatial_first = spatial_first
self.mode = sampler.mode
self.softmax_dim = softmax_dim
Expand Down Expand Up @@ -138,12 +138,6 @@ def check_sanity(self) -> None:
raise RuntimeError("The spatial size of the given output {} is unequal to the given spatial size {}.".format(self.output_h.shape[:len(self.image_size_s)], self.image_size_s))
if (not self.spatial_first) and (self.output_h.shape[-len(self.image_size_s):] != tuple(self.image_size_s)):
raise RuntimeError("The spatial size of the given output {} is unequal to the given spatial size {}.".format(self.output_h.shape[-len(self.image_size_s):], self.image_size_s))
if self.chunk_size_s is not None and np.any(self.chunk_size_s > self.image_size_s):
raise RuntimeError("The chunk size ({}) cannot be greater than the spatial size ({}) in one or more dimensions.".format(self.chunk_size_s, self.image_size_s))
if self.chunk_size_s is not None and np.any(self.patch_size_s >= self.chunk_size_s):
raise RuntimeError("The patch size ({}) cannot be greater or equal to the chunk size ({}) in one or more dimensions.".format(self.patch_size_s, self.chunk_size_s))
if self.chunk_size_s is not None and len(self.image_size_s) != len(self.chunk_size_s):
raise RuntimeError("The dimensionality of the chunk size ({}) is required to be the same as the spatial size ({}).".format(self.chunk_size_s, self.image_size_s))
if self.has_batch_dim and self.spatial_first:
raise RuntimeError("The arguments has_batch_dim and spatial_first cannot both be true at the same time.")
if self.mode.name.startswith('PAD_') and self.chunk_size_s is not None:
Expand Down Expand Up @@ -377,7 +371,7 @@ def __init__(self, sampler: GridSampler, image_size_s: Union[Tuple, npt.ArrayLik
self.step_size_s = step_size_s
self.chunk_size_s = chunk_size_s
self.chunk_dtype = self.set_chunk_dtype()
self.chunk_sampler, self.chunk_patch_dict, self.patch_chunk_dict = self.compute_patches()
self.chunk_sampler, self.chunk_patch_dict, self.patch_chunk_dict = self.sampler.chunk_sampler, self.sampler.chunk_patch_dict, self.sampler.patch_chunk_dict
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=2)

def set_chunk_dtype(self) -> np.dtype:
Expand All @@ -391,34 +385,6 @@ def set_chunk_dtype(self) -> np.dtype:
else:
return np.float32

def compute_patches(self) -> Tuple[_AdaptiveGridSampler, defaultdict, defaultdict]:
"""
Computes and organizes the patches and chunks needed for chunk-based aggregation in the _ChunkAggregator class.
:return: Tuple[_AdaptiveGridSampler, defaultdict, defaultdict] - Returns a tuple containing the chunk sampler, a dictionary mapping chunk IDs to patch data, and a dictionary mapping patch bounding boxes to chunk IDs.
"""
patch_sampler = self.sampler
chunk_sampler = _AdaptiveGridSampler(image_h=None, image_size_s=self.image_size_s, patch_size_s=self.chunk_size_s, step_size_s=self.chunk_size_s)
chunk_patch_dict = defaultdict(dict)
patch_chunk_dict = defaultdict(dict)

for idx in range(len(patch_sampler)):
patch_bbox_s = patch_sampler._get_bbox(idx)
patch_h = utils.LazyArray()
patch_chunk_dict[str(patch_bbox_s)]["patch"] = patch_h
patch_chunk_dict[str(patch_bbox_s)]["chunks"] = []
for chunk_id, chunk_bbox_s in enumerate(chunk_sampler):
if utils.is_overlapping(chunk_bbox_s, patch_bbox_s):
# Shift to chunk coordinate system
valid_patch_bbox_s = patch_bbox_s - np.array([chunk_bbox_s[:, 0], chunk_bbox_s[:, 0]]).T
# Crop patch bbox to chunk bounds
valid_patch_bbox_s = np.array([[max(valid_patch_bbox_s[i][0], 0), min(valid_patch_bbox_s[i][1], chunk_bbox_s[i][1] - chunk_bbox_s[i][0])] for i in range(len(chunk_bbox_s))])
crop_patch_bbox_s = valid_patch_bbox_s + np.array([chunk_bbox_s[:, 0], chunk_bbox_s[:, 0]]).T - np.array([patch_bbox_s[:, 0], patch_bbox_s[:, 0]]).T
chunk_patch_dict[chunk_id][str(patch_bbox_s)] = {"valid_patch_bbox": valid_patch_bbox_s, "crop_patch_bbox": crop_patch_bbox_s, "patch": patch_h, "status": PatchStatus.EMPTY}
patch_chunk_dict[str(patch_bbox_s)]["chunks"].append(chunk_id)

return chunk_sampler, chunk_patch_dict, patch_chunk_dict

def append(self, patch_h: npt.ArrayLike, patch_bbox_s: Union[Tuple, npt.ArrayLike]) -> None:
"""
Appends a patch to the chunk aggregator. This method is part of the _ChunkAggregator class and handles the complex logic of chunk-based patch aggregation.
Expand Down
76 changes: 71 additions & 5 deletions patchly/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Union, Optional, Tuple
import numpy.typing as npt
from enum import Enum
from collections import defaultdict


class SamplingMode(Enum):
Expand All @@ -14,9 +15,15 @@ class SamplingMode(Enum):
PAD_UNKNOWN = 5


class PatchStatus(Enum):
EMPTY = 1
FILLED = 2
COMPLETED = 3


class GridSampler:
def __init__(self, image: npt.ArrayLike, spatial_size: Union[Tuple, npt.ArrayLike], patch_size: Union[Tuple, npt.ArrayLike], step_size: Optional[Union[Tuple, npt.ArrayLike]] = None,
spatial_first: bool = True, mode: SamplingMode = SamplingMode.SAMPLE_SQUEEZE, pad_kwargs: dict = None):
chunk_size: Optional[Union[Tuple, npt.ArrayLike]] = None, spatial_first: bool = True, mode: SamplingMode = SamplingMode.SAMPLE_SQUEEZE, pad_kwargs: dict = None):
"""
Initializes the GridSampler object with specified parameters for sampling patches from an image.
Expand All @@ -26,6 +33,7 @@ def __init__(self, image: npt.ArrayLike, spatial_size: Union[Tuple, npt.ArrayLik
:param spatial_size: Union[Tuple, npt.ArrayLike] - The size of the spatial dimensions of the image.
:param patch_size: Union[Tuple, npt.ArrayLike] - The size of the patches to be sampled.
:param step_size: Optional[Union[Tuple, npt.ArrayLike]] - The step size between patches. Defaults to the same as patch_size if None.
:param chunk_size: Optional[Union[Tuple, npt.ArrayLike]] - The size of chunks for chunk-based processing. Optional.
:param spatial_first: bool - Indicates whether spatial dimensions come first in the image array. Defaults to True.
:param mode: SamplingMode - The sampling mode to use, which affects how patch borders are handled. Defaults to SamplingMode.SAMPLE_SQUEEZE.
:param pad_kwargs: dict - Additional keyword arguments for numpy's pad function, used in certain padding modes. Defaults to None.
Expand All @@ -34,12 +42,15 @@ def __init__(self, image: npt.ArrayLike, spatial_size: Union[Tuple, npt.ArrayLik
self.image_size_s = np.asarray(spatial_size)
self.patch_size_s = np.asarray(patch_size)
self.step_size_s = self.set_step_size(step_size, patch_size)
self.chunk_size_s = chunk_size
self.spatial_first = spatial_first
self.mode = mode
self.pad_kwargs = pad_kwargs
self.pad_width = None
self.check_sanity()
self.sampler = self.create_sampler()
self.chunk_sampler, self.chunk_patch_dict, self.patch_chunk_dict, self.patch_str_dict = self.set_chunks(chunk_size)
self.align_patches_with_chunks()

def set_step_size(self, step_size_s: Union[Tuple, np.ndarray], patch_size_s: Union[Tuple, np.ndarray]) -> np.ndarray:
"""
Expand All @@ -54,6 +65,12 @@ def set_step_size(self, step_size_s: Union[Tuple, np.ndarray], patch_size_s: Uni
else:
step_size_s = np.asarray(step_size_s)
return step_size_s

def set_chunks(self, chunk_size):
if chunk_size is None:
return None, None, None, None
else:
return self.compute_chunks()

def check_sanity(self):
"""
Expand All @@ -76,6 +93,12 @@ def check_sanity(self):
if self.step_size_s is not None and len(self.image_size_s) != len(self.step_size_s):
raise RuntimeError("The dimensionality of the patch offset ({}) is required to be the same as the spatial size ({})."
.format(self.step_size_s, self.image_size_s))
if self.chunk_size_s is not None and np.any(self.chunk_size_s > self.image_size_s):
raise RuntimeError("The chunk size ({}) cannot be greater than the spatial size ({}) in one or more dimensions.".format(self.chunk_size_s, self.image_size_s))
if self.chunk_size_s is not None and np.any(self.patch_size_s >= self.chunk_size_s):
raise RuntimeError("The patch size ({}) cannot be greater or equal to the chunk size ({}) in one or more dimensions.".format(self.patch_size_s, self.chunk_size_s))
if self.chunk_size_s is not None and len(self.image_size_s) != len(self.chunk_size_s):
raise RuntimeError("The dimensionality of the chunk size ({}) is required to be the same as the spatial size ({}).".format(self.chunk_size_s, self.image_size_s))
if self.mode.name.startswith('PAD_') and (self.image_h is None or not isinstance(self.image_h, np.ndarray)):
raise RuntimeError("The given sampling mode ({}) requires the image to be given and as type np.ndarray.".format(self.mode))

Expand Down Expand Up @@ -162,14 +185,57 @@ def __next__(self):
"""
return self.sampler.__next__()

def _get_bbox(self, idx: int) -> np.ndarray:
def get_bbox(self, idx: int) -> np.ndarray:
"""
Retrieves the bounding box coordinates of the patch at the specified index. This internal method is used to determine the spatial location of a patch within the larger image.
:param idx: int - The index of the patch for which the bounding box is required.
:return: np.ndarray - The bounding box coordinates of the specified patch.
"""
return self.sampler._get_bbox(idx)
return self.sampler.get_bbox(idx)

def compute_chunks(self):
"""
Computes and organizes the patches and chunks needed for chunk-based aggregation in the _ChunkAggregator class.
:return: Tuple[_AdaptiveGridSampler, defaultdict, defaultdict] - Returns a tuple containing the chunk sampler, a dictionary mapping chunk IDs to patch data, and a dictionary mapping patch bounding boxes to chunk IDs.
"""
patch_sampler = self.sampler
chunk_sampler = _AdaptiveGridSampler(image_h=None, image_size_s=self.image_size_s, patch_size_s=self.chunk_size_s, step_size_s=self.chunk_size_s)
chunk_patch_dict = defaultdict(dict)
patch_chunk_dict = defaultdict(dict)
patch_str_dict = defaultdict(dict)

for idx in range(len(patch_sampler)):
patch_bbox_s = patch_sampler.get_bbox(idx)
patch_str_dict[str(patch_bbox_s)] = patch_bbox_s
patch_h = utils.LazyArray()
patch_chunk_dict[str(patch_bbox_s)]["patch"] = patch_h
patch_chunk_dict[str(patch_bbox_s)]["chunks"] = []
for chunk_id, chunk_bbox_s in enumerate(chunk_sampler):
if utils.is_overlapping(chunk_bbox_s, patch_bbox_s):
# Shift to chunk coordinate system
valid_patch_bbox_s = patch_bbox_s - np.array([chunk_bbox_s[:, 0], chunk_bbox_s[:, 0]]).T
# Crop patch bbox to chunk bounds
valid_patch_bbox_s = np.array([[max(valid_patch_bbox_s[i][0], 0), min(valid_patch_bbox_s[i][1], chunk_bbox_s[i][1] - chunk_bbox_s[i][0])] for i in range(len(chunk_bbox_s))])
crop_patch_bbox_s = valid_patch_bbox_s + np.array([chunk_bbox_s[:, 0], chunk_bbox_s[:, 0]]).T - np.array([patch_bbox_s[:, 0], patch_bbox_s[:, 0]]).T
chunk_patch_dict[chunk_id][str(patch_bbox_s)] = {"valid_patch_bbox": valid_patch_bbox_s, "crop_patch_bbox": crop_patch_bbox_s, "patch": patch_h, "status": PatchStatus.EMPTY}
patch_chunk_dict[str(patch_bbox_s)]["chunks"].append(chunk_id)

return chunk_sampler, chunk_patch_dict, patch_chunk_dict, patch_str_dict

def align_patches_with_chunks(self):
if self.chunk_size_s is not None:
patch_chunk_tuples = [(patch, np.min(chunks["chunks"])) for patch, chunks in self.patch_chunk_dict.items()]
# patch_chunk_tuples = sorted(patch_chunk_tuples, key=lambda item: (item[1], item[0]))
patch_chunk_tuples = sorted(patch_chunk_tuples, key=lambda item: item[1])
patch_index_mapping = {str(self.sampler.get_bbox(idx)): idx for idx in range(len(self.sampler))}
aligned_patch_positions_s = [self.patch_str_dict[patch] for patch, _ in patch_chunk_tuples]
aligned_patch_indices = [patch_index_mapping[str(pos)] for pos in aligned_patch_positions_s]
aligned_patch_positions_s = self.sampler.patch_positions_s[aligned_patch_indices]
aligned_patch_sizes_s = self.sampler.patch_sizes_s[aligned_patch_indices]
self.sampler.patch_positions_s = aligned_patch_positions_s
self.sampler.patch_sizes_s = aligned_patch_sizes_s


class _CropGridSampler:
Expand Down Expand Up @@ -230,7 +296,7 @@ def __getitem__(self, idx: int):
:param idx: int - The index of the patch to retrieve.
:return: The patch and patch location at the specified index.
"""
patch_bbox_s = self._get_bbox(idx)
patch_bbox_s = self.get_bbox(idx)
patch_result = self.get_patch_result(patch_bbox_s)
return patch_result

Expand All @@ -246,7 +312,7 @@ def __next__(self):
else:
raise StopIteration

def _get_bbox(self, idx: int) -> np.ndarray:
def get_bbox(self, idx: int) -> np.ndarray:
"""
Computes the bounding box for the patch at the specified index. This internal method calculates the spatial coordinates defining the area of the image
covered by the patch, facilitating the extraction of the specific patch.
Expand Down
8 changes: 4 additions & 4 deletions patchly/tests/test_adaptive/test_adaptive_chunk_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,8 @@ def _test_aggregator(self, image, spatial_size, patch_size, chunk_size, step_siz
output_size = np.moveaxis(image.shape, softmax_dim, 0)[1:]

# Test with output size
sampler = GridSampler(image=copy.deepcopy(image), spatial_size=spatial_size, patch_size=patch_size, step_size=step_size, spatial_first=spatial_first_sampler, mode=SamplingMode.SAMPLE_ADAPTIVE)
aggregator = Aggregator(sampler=sampler, output_size=output_size, chunk_size=chunk_size, weights=weights, spatial_first=spatial_first_aggregator, softmax_dim=softmax_dim)
sampler = GridSampler(image=copy.deepcopy(image), spatial_size=spatial_size, patch_size=patch_size, step_size=step_size, chunk_size=chunk_size, spatial_first=spatial_first_sampler, mode=SamplingMode.SAMPLE_ADAPTIVE)
aggregator = Aggregator(sampler=sampler, output_size=output_size, weights=weights, spatial_first=spatial_first_aggregator, softmax_dim=softmax_dim)

for i, (patch, patch_bbox) in enumerate(sampler):
if multiply_elements_by_two:
Expand All @@ -230,8 +230,8 @@ def _test_aggregator(self, image, spatial_size, patch_size, chunk_size, step_siz
# Test without output array
if output is None:
output = np.zeros_like(image)
sampler = GridSampler(image=copy.deepcopy(image), spatial_size=spatial_size, patch_size=patch_size, step_size=step_size, spatial_first=spatial_first_sampler, mode=SamplingMode.SAMPLE_ADAPTIVE)
aggregator = Aggregator(sampler=sampler, output=output, chunk_size=chunk_size, weights=weights, spatial_first=spatial_first_aggregator, softmax_dim=softmax_dim)
sampler = GridSampler(image=copy.deepcopy(image), spatial_size=spatial_size, patch_size=patch_size, step_size=step_size, chunk_size=chunk_size, spatial_first=spatial_first_sampler, mode=SamplingMode.SAMPLE_ADAPTIVE)
aggregator = Aggregator(sampler=sampler, output=output, weights=weights, spatial_first=spatial_first_aggregator, softmax_dim=softmax_dim)

for i, (patch, patch_bbox) in enumerate(sampler):
if multiply_elements_by_two:
Expand Down
Loading

0 comments on commit 156d47a

Please sign in to comment.