diff --git a/examples/inference_chunk.py b/examples/inference_chunk.py index 021fd9d..95e771f 100644 --- a/examples/inference_chunk.py +++ b/examples/inference_chunk.py @@ -1,5 +1,5 @@ import numpy as np -from patchly import GridSampler, Aggregator +from patchly import GridSampler, Aggregator, SamplingMode from torch.utils.data import DataLoader, Dataset import torch @@ -16,7 +16,7 @@ def example(): chunk_size = (500, 500) # Init GridSampler - sampler = GridSampler(image=image, spatial_size=spatial_size, patch_size=patch_size, step_size=step_size, spatial_first=False) + sampler = GridSampler(image=image, spatial_size=spatial_size, patch_size=patch_size, step_size=step_size, chunk_size=chunk_size, spatial_first=False, mode=SamplingMode.SAMPLE_SQUEEZE) # Convert sampler into a PyTorch dataset loader = SamplerDataset(sampler) # Init dataloader @@ -24,7 +24,7 @@ def example(): # Create an empty prediction passed to the aggregator prediction = np.zeros(spatial_size, dtype=np.uint8) # Init aggregator - aggregator = Aggregator(sampler=sampler, output=prediction, chunk_size=chunk_size, weights='gaussian', softmax_dim=0, spatial_first=False, has_batch_dim=True) + aggregator = Aggregator(sampler=sampler, output=prediction, weights='gaussian', softmax_dim=0, spatial_first=False, has_batch_dim=True) # Run inference with torch.no_grad(): diff --git a/patchly/aggregator.py b/patchly/aggregator.py index bb68e50..3a640a0 100644 --- a/patchly/aggregator.py +++ b/patchly/aggregator.py @@ -24,7 +24,7 @@ class PatchStatus(Enum): class Aggregator: - def __init__(self, sampler: GridSampler, output_size: Optional[Union[Tuple, npt.ArrayLike]] = None, output: Optional[npt.ArrayLike] = None, chunk_size: Optional[Union[Tuple, npt.ArrayLike]] = None, + def __init__(self, sampler: GridSampler, output_size: Optional[Union[Tuple, npt.ArrayLike]] = None, output: Optional[npt.ArrayLike] = None, weights: Union[str, Callable] = 'avg', softmax_dim: Optional[int] = None, has_batch_dim: bool = False, 
spatial_first: bool = True, device: str = 'cpu'): """ Initializes the Aggregator object for aggregating patches into a larger output image. @@ -45,7 +45,7 @@ def __init__(self, sampler: GridSampler, output_size: Optional[Union[Tuple, npt. self.image_size_s = sampler.image_size_s self.patch_size_s = sampler.patch_size_s self.step_size_s = sampler.step_size_s - self.chunk_size_s = chunk_size + self.chunk_size_s = sampler.chunk_size_s self.spatial_first = spatial_first self.mode = sampler.mode self.softmax_dim = softmax_dim @@ -138,12 +138,6 @@ def check_sanity(self) -> None: raise RuntimeError("The spatial size of the given output {} is unequal to the given spatial size {}.".format(self.output_h.shape[:len(self.image_size_s)], self.image_size_s)) if (not self.spatial_first) and (self.output_h.shape[-len(self.image_size_s):] != tuple(self.image_size_s)): raise RuntimeError("The spatial size of the given output {} is unequal to the given spatial size {}.".format(self.output_h.shape[-len(self.image_size_s):], self.image_size_s)) - if self.chunk_size_s is not None and np.any(self.chunk_size_s > self.image_size_s): - raise RuntimeError("The chunk size ({}) cannot be greater than the spatial size ({}) in one or more dimensions.".format(self.chunk_size_s, self.image_size_s)) - if self.chunk_size_s is not None and np.any(self.patch_size_s >= self.chunk_size_s): - raise RuntimeError("The patch size ({}) cannot be greater or equal to the chunk size ({}) in one or more dimensions.".format(self.patch_size_s, self.chunk_size_s)) - if self.chunk_size_s is not None and len(self.image_size_s) != len(self.chunk_size_s): - raise RuntimeError("The dimensionality of the chunk size ({}) is required to be the same as the spatial size ({}).".format(self.chunk_size_s, self.image_size_s)) if self.has_batch_dim and self.spatial_first: raise RuntimeError("The arguments has_batch_dim and spatial_first cannot both be true at the same time.") if self.mode.name.startswith('PAD_') and 
self.chunk_size_s is not None: @@ -377,7 +371,7 @@ def __init__(self, sampler: GridSampler, image_size_s: Union[Tuple, npt.ArrayLik self.step_size_s = step_size_s self.chunk_size_s = chunk_size_s self.chunk_dtype = self.set_chunk_dtype() - self.chunk_sampler, self.chunk_patch_dict, self.patch_chunk_dict = self.compute_patches() + self.chunk_sampler, self.chunk_patch_dict, self.patch_chunk_dict = self.sampler.chunk_sampler, self.sampler.chunk_patch_dict, self.sampler.patch_chunk_dict self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=2) def set_chunk_dtype(self) -> np.dtype: @@ -391,34 +385,6 @@ def set_chunk_dtype(self) -> np.dtype: else: return np.float32 - def compute_patches(self) -> Tuple[_AdaptiveGridSampler, defaultdict, defaultdict]: - """ - Computes and organizes the patches and chunks needed for chunk-based aggregation in the _ChunkAggregator class. - - :return: Tuple[_AdaptiveGridSampler, defaultdict, defaultdict] - Returns a tuple containing the chunk sampler, a dictionary mapping chunk IDs to patch data, and a dictionary mapping patch bounding boxes to chunk IDs. 
- """ - patch_sampler = self.sampler - chunk_sampler = _AdaptiveGridSampler(image_h=None, image_size_s=self.image_size_s, patch_size_s=self.chunk_size_s, step_size_s=self.chunk_size_s) - chunk_patch_dict = defaultdict(dict) - patch_chunk_dict = defaultdict(dict) - - for idx in range(len(patch_sampler)): - patch_bbox_s = patch_sampler._get_bbox(idx) - patch_h = utils.LazyArray() - patch_chunk_dict[str(patch_bbox_s)]["patch"] = patch_h - patch_chunk_dict[str(patch_bbox_s)]["chunks"] = [] - for chunk_id, chunk_bbox_s in enumerate(chunk_sampler): - if utils.is_overlapping(chunk_bbox_s, patch_bbox_s): - # Shift to chunk coordinate system - valid_patch_bbox_s = patch_bbox_s - np.array([chunk_bbox_s[:, 0], chunk_bbox_s[:, 0]]).T - # Crop patch bbox to chunk bounds - valid_patch_bbox_s = np.array([[max(valid_patch_bbox_s[i][0], 0), min(valid_patch_bbox_s[i][1], chunk_bbox_s[i][1] - chunk_bbox_s[i][0])] for i in range(len(chunk_bbox_s))]) - crop_patch_bbox_s = valid_patch_bbox_s + np.array([chunk_bbox_s[:, 0], chunk_bbox_s[:, 0]]).T - np.array([patch_bbox_s[:, 0], patch_bbox_s[:, 0]]).T - chunk_patch_dict[chunk_id][str(patch_bbox_s)] = {"valid_patch_bbox": valid_patch_bbox_s, "crop_patch_bbox": crop_patch_bbox_s, "patch": patch_h, "status": PatchStatus.EMPTY} - patch_chunk_dict[str(patch_bbox_s)]["chunks"].append(chunk_id) - - return chunk_sampler, chunk_patch_dict, patch_chunk_dict - def append(self, patch_h: npt.ArrayLike, patch_bbox_s: Union[Tuple, npt.ArrayLike]) -> None: """ Appends a patch to the chunk aggregator. This method is part of the _ChunkAggregator class and handles the complex logic of chunk-based patch aggregation. 
diff --git a/patchly/sampler.py b/patchly/sampler.py index 48c6752..7e0a99b 100644 --- a/patchly/sampler.py +++ b/patchly/sampler.py @@ -4,6 +4,7 @@ from typing import Union, Optional, Tuple import numpy.typing as npt from enum import Enum +from collections import defaultdict class SamplingMode(Enum): @@ -14,9 +15,15 @@ class SamplingMode(Enum): PAD_UNKNOWN = 5 +class PatchStatus(Enum): + EMPTY = 1 + FILLED = 2 + COMPLETED = 3 + + class GridSampler: def __init__(self, image: npt.ArrayLike, spatial_size: Union[Tuple, npt.ArrayLike], patch_size: Union[Tuple, npt.ArrayLike], step_size: Optional[Union[Tuple, npt.ArrayLike]] = None, - spatial_first: bool = True, mode: SamplingMode = SamplingMode.SAMPLE_SQUEEZE, pad_kwargs: dict = None): + chunk_size: Optional[Union[Tuple, npt.ArrayLike]] = None, spatial_first: bool = True, mode: SamplingMode = SamplingMode.SAMPLE_SQUEEZE, pad_kwargs: dict = None): """ Initializes the GridSampler object with specified parameters for sampling patches from an image. @@ -26,6 +33,7 @@ def __init__(self, image: npt.ArrayLike, spatial_size: Union[Tuple, npt.ArrayLik :param spatial_size: Union[Tuple, npt.ArrayLike] - The size of the spatial dimensions of the image. :param patch_size: Union[Tuple, npt.ArrayLike] - The size of the patches to be sampled. :param step_size: Optional[Union[Tuple, npt.ArrayLike]] - The step size between patches. Defaults to the same as patch_size if None. + :param chunk_size: Optional[Union[Tuple, npt.ArrayLike]] - The size of chunks for chunk-based processing. Optional. :param spatial_first: bool - Indicates whether spatial dimensions come first in the image array. Defaults to True. :param mode: SamplingMode - The sampling mode to use, which affects how patch borders are handled. Defaults to SamplingMode.SAMPLE_SQUEEZE. :param pad_kwargs: dict - Additional keyword arguments for numpy's pad function, used in certain padding modes. Defaults to None. 
@@ -34,12 +42,15 @@ def __init__(self, image: npt.ArrayLike, spatial_size: Union[Tuple, npt.ArrayLik self.image_size_s = np.asarray(spatial_size) self.patch_size_s = np.asarray(patch_size) self.step_size_s = self.set_step_size(step_size, patch_size) + self.chunk_size_s = chunk_size self.spatial_first = spatial_first self.mode = mode self.pad_kwargs = pad_kwargs self.pad_width = None self.check_sanity() self.sampler = self.create_sampler() + self.chunk_sampler, self.chunk_patch_dict, self.patch_chunk_dict, self.patch_str_dict = self.set_chunks(chunk_size) + self.align_patches_with_chunks() def set_step_size(self, step_size_s: Union[Tuple, np.ndarray], patch_size_s: Union[Tuple, np.ndarray]) -> np.ndarray: """ @@ -54,6 +65,12 @@ def set_step_size(self, step_size_s: Union[Tuple, np.ndarray], patch_size_s: Uni else: step_size_s = np.asarray(step_size_s) return step_size_s + + def set_chunks(self, chunk_size): + if chunk_size is None: + return None, None, None, None + else: + return self.compute_chunks() def check_sanity(self): """ @@ -76,6 +93,12 @@ def check_sanity(self): if self.step_size_s is not None and len(self.image_size_s) != len(self.step_size_s): raise RuntimeError("The dimensionality of the patch offset ({}) is required to be the same as the spatial size ({})." 
.format(self.step_size_s, self.image_size_s)) + if self.chunk_size_s is not None and np.any(self.chunk_size_s > self.image_size_s): + raise RuntimeError("The chunk size ({}) cannot be greater than the spatial size ({}) in one or more dimensions.".format(self.chunk_size_s, self.image_size_s)) + if self.chunk_size_s is not None and np.any(self.patch_size_s >= self.chunk_size_s): + raise RuntimeError("The patch size ({}) cannot be greater or equal to the chunk size ({}) in one or more dimensions.".format(self.patch_size_s, self.chunk_size_s)) + if self.chunk_size_s is not None and len(self.image_size_s) != len(self.chunk_size_s): + raise RuntimeError("The dimensionality of the chunk size ({}) is required to be the same as the spatial size ({}).".format(self.chunk_size_s, self.image_size_s)) if self.mode.name.startswith('PAD_') and (self.image_h is None or not isinstance(self.image_h, np.ndarray)): raise RuntimeError("The given sampling mode ({}) requires the image to be given and as type np.ndarray.".format(self.mode)) @@ -162,14 +185,57 @@ def __next__(self): """ return self.sampler.__next__() - def _get_bbox(self, idx: int) -> np.ndarray: + def get_bbox(self, idx: int) -> np.ndarray: """ Retrieves the bounding box coordinates of the patch at the specified index. This internal method is used to determine the spatial location of a patch within the larger image. :param idx: int - The index of the patch for which the bounding box is required. :return: np.ndarray - The bounding box coordinates of the specified patch. """ - return self.sampler._get_bbox(idx) + return self.sampler.get_bbox(idx) + + def compute_chunks(self): + """ + Computes and organizes the patches and chunks needed for chunk-based aggregation in the _ChunkAggregator class. + + :return: Tuple[_AdaptiveGridSampler, defaultdict, defaultdict, defaultdict] - Returns a tuple containing the chunk sampler, a dictionary mapping chunk IDs to patch data, a dictionary mapping patch bounding boxes to chunk IDs, and a dictionary mapping each patch bounding-box string to its bounding-box array. 
+ """ + patch_sampler = self.sampler + chunk_sampler = _AdaptiveGridSampler(image_h=None, image_size_s=self.image_size_s, patch_size_s=self.chunk_size_s, step_size_s=self.chunk_size_s) + chunk_patch_dict = defaultdict(dict) + patch_chunk_dict = defaultdict(dict) + patch_str_dict = defaultdict(dict) + + for idx in range(len(patch_sampler)): + patch_bbox_s = patch_sampler.get_bbox(idx) + patch_str_dict[str(patch_bbox_s)] = patch_bbox_s + patch_h = utils.LazyArray() + patch_chunk_dict[str(patch_bbox_s)]["patch"] = patch_h + patch_chunk_dict[str(patch_bbox_s)]["chunks"] = [] + for chunk_id, chunk_bbox_s in enumerate(chunk_sampler): + if utils.is_overlapping(chunk_bbox_s, patch_bbox_s): + # Shift to chunk coordinate system + valid_patch_bbox_s = patch_bbox_s - np.array([chunk_bbox_s[:, 0], chunk_bbox_s[:, 0]]).T + # Crop patch bbox to chunk bounds + valid_patch_bbox_s = np.array([[max(valid_patch_bbox_s[i][0], 0), min(valid_patch_bbox_s[i][1], chunk_bbox_s[i][1] - chunk_bbox_s[i][0])] for i in range(len(chunk_bbox_s))]) + crop_patch_bbox_s = valid_patch_bbox_s + np.array([chunk_bbox_s[:, 0], chunk_bbox_s[:, 0]]).T - np.array([patch_bbox_s[:, 0], patch_bbox_s[:, 0]]).T + chunk_patch_dict[chunk_id][str(patch_bbox_s)] = {"valid_patch_bbox": valid_patch_bbox_s, "crop_patch_bbox": crop_patch_bbox_s, "patch": patch_h, "status": PatchStatus.EMPTY} + patch_chunk_dict[str(patch_bbox_s)]["chunks"].append(chunk_id) + + return chunk_sampler, chunk_patch_dict, patch_chunk_dict, patch_str_dict + + def align_patches_with_chunks(self): + if self.chunk_size_s is not None: + patch_chunk_tuples = [(patch, np.min(chunks["chunks"])) for patch, chunks in self.patch_chunk_dict.items()] + # patch_chunk_tuples = sorted(patch_chunk_tuples, key=lambda item: (item[1], item[0])) + patch_chunk_tuples = sorted(patch_chunk_tuples, key=lambda item: item[1]) + patch_index_mapping = {str(self.sampler.get_bbox(idx)): idx for idx in range(len(self.sampler))} + aligned_patch_positions_s = 
[self.patch_str_dict[patch] for patch, _ in patch_chunk_tuples] + aligned_patch_indices = [patch_index_mapping[str(pos)] for pos in aligned_patch_positions_s] + aligned_patch_positions_s = self.sampler.patch_positions_s[aligned_patch_indices] + aligned_patch_sizes_s = self.sampler.patch_sizes_s[aligned_patch_indices] + self.sampler.patch_positions_s = aligned_patch_positions_s + self.sampler.patch_sizes_s = aligned_patch_sizes_s class _CropGridSampler: @@ -230,7 +296,7 @@ def __getitem__(self, idx: int): :param idx: int - The index of the patch to retrieve. :return: The patch and patch location at the specified index. """ - patch_bbox_s = self._get_bbox(idx) + patch_bbox_s = self.get_bbox(idx) patch_result = self.get_patch_result(patch_bbox_s) return patch_result @@ -246,7 +312,7 @@ def __next__(self): else: raise StopIteration - def _get_bbox(self, idx: int) -> np.ndarray: + def get_bbox(self, idx: int) -> np.ndarray: """ Computes the bounding box for the patch at the specified index. This internal method calculates the spatial coordinates defining the area of the image covered by the patch, facilitating the extraction of the specific patch. 
diff --git a/patchly/tests/test_adaptive/test_adaptive_chunk_aggregator.py b/patchly/tests/test_adaptive/test_adaptive_chunk_aggregator.py index 3fdb67a..f3cd8bd 100644 --- a/patchly/tests/test_adaptive/test_adaptive_chunk_aggregator.py +++ b/patchly/tests/test_adaptive/test_adaptive_chunk_aggregator.py @@ -209,8 +209,8 @@ def _test_aggregator(self, image, spatial_size, patch_size, chunk_size, step_siz output_size = np.moveaxis(image.shape, softmax_dim, 0)[1:] # Test with output size - sampler = GridSampler(image=copy.deepcopy(image), spatial_size=spatial_size, patch_size=patch_size, step_size=step_size, spatial_first=spatial_first_sampler, mode=SamplingMode.SAMPLE_ADAPTIVE) - aggregator = Aggregator(sampler=sampler, output_size=output_size, chunk_size=chunk_size, weights=weights, spatial_first=spatial_first_aggregator, softmax_dim=softmax_dim) + sampler = GridSampler(image=copy.deepcopy(image), spatial_size=spatial_size, patch_size=patch_size, step_size=step_size, chunk_size=chunk_size, spatial_first=spatial_first_sampler, mode=SamplingMode.SAMPLE_ADAPTIVE) + aggregator = Aggregator(sampler=sampler, output_size=output_size, weights=weights, spatial_first=spatial_first_aggregator, softmax_dim=softmax_dim) for i, (patch, patch_bbox) in enumerate(sampler): if multiply_elements_by_two: @@ -230,8 +230,8 @@ def _test_aggregator(self, image, spatial_size, patch_size, chunk_size, step_siz # Test without output array if output is None: output = np.zeros_like(image) - sampler = GridSampler(image=copy.deepcopy(image), spatial_size=spatial_size, patch_size=patch_size, step_size=step_size, spatial_first=spatial_first_sampler, mode=SamplingMode.SAMPLE_ADAPTIVE) - aggregator = Aggregator(sampler=sampler, output=output, chunk_size=chunk_size, weights=weights, spatial_first=spatial_first_aggregator, softmax_dim=softmax_dim) + sampler = GridSampler(image=copy.deepcopy(image), spatial_size=spatial_size, patch_size=patch_size, step_size=step_size, chunk_size=chunk_size, 
spatial_first=spatial_first_sampler, mode=SamplingMode.SAMPLE_ADAPTIVE) + aggregator = Aggregator(sampler=sampler, output=output, weights=weights, spatial_first=spatial_first_aggregator, softmax_dim=softmax_dim) for i, (patch, patch_bbox) in enumerate(sampler): if multiply_elements_by_two: diff --git a/patchly/tests/test_edge/test_edge_chunk_aggregator.py b/patchly/tests/test_edge/test_edge_chunk_aggregator.py index eb9c013..4a4b354 100644 --- a/patchly/tests/test_edge/test_edge_chunk_aggregator.py +++ b/patchly/tests/test_edge/test_edge_chunk_aggregator.py @@ -212,8 +212,8 @@ def _test_aggregator(self, image, spatial_size, patch_size, chunk_size, step_siz output_size = np.moveaxis(image.shape, softmax_dim, 0)[1:] # Test with output size - sampler = GridSampler(image=copy.deepcopy(image), spatial_size=spatial_size, patch_size=patch_size, step_size=step_size, spatial_first=spatial_first_sampler, mode=SamplingMode.SAMPLE_EDGE) - aggregator = Aggregator(sampler=sampler, output_size=output_size, chunk_size=chunk_size, weights=weights, spatial_first=spatial_first_aggregator, softmax_dim=softmax_dim) + sampler = GridSampler(image=copy.deepcopy(image), spatial_size=spatial_size, patch_size=patch_size, step_size=step_size, chunk_size=chunk_size, spatial_first=spatial_first_sampler, mode=SamplingMode.SAMPLE_EDGE) + aggregator = Aggregator(sampler=sampler, output_size=output_size, weights=weights, spatial_first=spatial_first_aggregator, softmax_dim=softmax_dim) for i, (patch, patch_bbox) in enumerate(sampler): if multiply_elements_by_two: @@ -233,8 +233,8 @@ def _test_aggregator(self, image, spatial_size, patch_size, chunk_size, step_siz # Test without output array if output is None: output = np.zeros_like(image) - sampler = GridSampler(image=copy.deepcopy(image), spatial_size=spatial_size, patch_size=patch_size, step_size=step_size, spatial_first=spatial_first_sampler, mode=SamplingMode.SAMPLE_EDGE) - aggregator = Aggregator(sampler=sampler, output=output, 
chunk_size=chunk_size, weights=weights, spatial_first=spatial_first_aggregator, softmax_dim=softmax_dim) + sampler = GridSampler(image=copy.deepcopy(image), spatial_size=spatial_size, patch_size=patch_size, step_size=step_size, chunk_size=chunk_size, spatial_first=spatial_first_sampler, mode=SamplingMode.SAMPLE_EDGE) + aggregator = Aggregator(sampler=sampler, output=output, weights=weights, spatial_first=spatial_first_aggregator, softmax_dim=softmax_dim) for i, (patch, patch_bbox) in enumerate(sampler): if multiply_elements_by_two: diff --git a/patchly/tests/test_squeeze/test_squeeze_chunk_aggregator.py b/patchly/tests/test_squeeze/test_squeeze_chunk_aggregator.py index b000c9e..48f9d7f 100644 --- a/patchly/tests/test_squeeze/test_squeeze_chunk_aggregator.py +++ b/patchly/tests/test_squeeze/test_squeeze_chunk_aggregator.py @@ -212,8 +212,8 @@ def _test_aggregator(self, image, spatial_size, patch_size, chunk_size, step_siz output_size = np.moveaxis(image.shape, softmax_dim, 0)[1:] # Test with output size - sampler = GridSampler(image=copy.deepcopy(image), spatial_size=spatial_size, patch_size=patch_size, step_size=step_size, spatial_first=spatial_first_sampler, mode=SamplingMode.SAMPLE_SQUEEZE) - aggregator = Aggregator(sampler=sampler, output_size=output_size, chunk_size=chunk_size, weights=weights, spatial_first=spatial_first_aggregator, softmax_dim=softmax_dim) + sampler = GridSampler(image=copy.deepcopy(image), spatial_size=spatial_size, patch_size=patch_size, step_size=step_size, chunk_size=chunk_size, spatial_first=spatial_first_sampler, mode=SamplingMode.SAMPLE_SQUEEZE) + aggregator = Aggregator(sampler=sampler, output_size=output_size, weights=weights, spatial_first=spatial_first_aggregator, softmax_dim=softmax_dim) for i, (patch, patch_bbox) in enumerate(sampler): if multiply_elements_by_two: @@ -233,8 +233,8 @@ def _test_aggregator(self, image, spatial_size, patch_size, chunk_size, step_siz # Test without output array if output is None: output = 
np.zeros_like(image) - sampler = GridSampler(image=copy.deepcopy(image), spatial_size=spatial_size, patch_size=patch_size, step_size=step_size, spatial_first=spatial_first_sampler, mode=SamplingMode.SAMPLE_SQUEEZE) - aggregator = Aggregator(sampler=sampler, output=output, chunk_size=chunk_size, weights=weights, spatial_first=spatial_first_aggregator, softmax_dim=softmax_dim) + sampler = GridSampler(image=copy.deepcopy(image), spatial_size=spatial_size, patch_size=patch_size, step_size=step_size, chunk_size=chunk_size, spatial_first=spatial_first_sampler, mode=SamplingMode.SAMPLE_SQUEEZE) + aggregator = Aggregator(sampler=sampler, output=output, weights=weights, spatial_first=spatial_first_aggregator, softmax_dim=softmax_dim) for i, (patch, patch_bbox) in enumerate(sampler): if multiply_elements_by_two: