diff --git a/wholeslidedata/buffer/patchcommander.py b/wholeslidedata/buffer/patchcommander.py index 538a621..70ec98a 100644 --- a/wholeslidedata/buffer/patchcommander.py +++ b/wholeslidedata/buffer/patchcommander.py @@ -7,16 +7,16 @@ import numpy as np from wholeslidedata.image.wholeslideimage import WholeSlideImage - +from wholeslidedata.samplers.utils import crop_data @dataclass class PatchConfiguration: patch_shape: tuple = (512, 512, 3) spacings: tuple = (0.5,) overlap: tuple = (0, 0) - offset: tuple = (0, 0) + offset: tuple = (0, 0) # This currently has to be provided already scaled center: bool = False - + write_shape: tuple = None # This can be used when you crop off (overlapping) borders of the patches, only patches with mask in the inner part will be sampled class PatchCommander(Commander): def __init__( @@ -81,20 +81,41 @@ def create_message(self) -> dict: class SlidingPatchCommander(PatchCommander): def get_patch_messages(self): - messages = [] - step_row = int(self._patch_configuration.patch_shape[0] * self._ratio) - int( - self._patch_configuration.overlap[0] * self._ratio - ) - step_col = int(self._patch_configuration.patch_shape[1] * self._ratio) - int( - self._patch_configuration.overlap[1] * self._ratio - ) + step_shape = (int(self._patch_configuration.patch_shape[0]) - int(self._patch_configuration.overlap[0]), int(self._patch_configuration.patch_shape[1]) - int(self._patch_configuration.overlap[1])) + step_shape_scaled = (step_shape[0] * self._ratio, step_shape[1] * self._ratio) + # TODO: # offset has/had to be provided already scaled, I think this should be changed though + # offset_scaled = (self._patch_configuration.offset[0] * self._ratio, self._patch_configuration.offset[1] * self._ratio) + # offset_scaled = tuple(o%ss for o, ss in zip(offset_scaled, step_shape_scaled)) + offset_scaled = (self._patch_configuration.offset[0], self._patch_configuration.offset[1]) + offset_scaled = tuple(o%ss for o, ss in zip(offset_scaled, step_shape_scaled)) + patch_shape_scaled = (self._patch_configuration.patch_shape[0] * self._ratio, self._patch_configuration.patch_shape[1] * self._ratio, self._patch_configuration.patch_shape[2]) + + if self._patch_configuration.center: + first_coord = [ + offset_scaled[0] - ((offset_scaled[0] + patch_shape_scaled[0] // 2) // step_shape_scaled[0] - (1 if (offset_scaled[0] + patch_shape_scaled[0] // 2) % step_shape_scaled[0] == 0 else 0)) * step_shape_scaled[0], + offset_scaled[1] - ((offset_scaled[1] + patch_shape_scaled[1] // 2) // step_shape_scaled[1] - (1 if (offset_scaled[1] + patch_shape_scaled[1] // 2) % step_shape_scaled[1] == 0 else 0)) * step_shape_scaled[1] + ] + else: + first_coord = [ + offset_scaled[0] - ((offset_scaled[0] + patch_shape_scaled[0]) // step_shape_scaled[0] - (1 if (offset_scaled[0] + patch_shape_scaled[0]) % step_shape_scaled[0] == 0 else 0)) * step_shape_scaled[0], + offset_scaled[1] - ((offset_scaled[1] + patch_shape_scaled[1]) // step_shape_scaled[1] - (1 if (offset_scaled[1] + patch_shape_scaled[1]) % step_shape_scaled[1] == 0 else 0)) * step_shape_scaled[1] + ] + max_i = int(self._x_dims + offset_scaled[0] + patch_shape_scaled[0] // 2) if self._patch_configuration.center else int(self._x_dims + offset_scaled[0]) + max_j = int(self._y_dims + offset_scaled[1] + patch_shape_scaled[0] // 2) if self._patch_configuration.center else int(self._y_dims + offset_scaled[1]) + max_i = max_i if not self._patch_configuration.center else max_i + patch_shape_scaled[0] // 2 + max_j = max_j if not self._patch_configuration.center else max_j + patch_shape_scaled[1] // 2 + + range_i = list(range(first_coord[0], max_i, step_shape_scaled[0])) + range_j = list(range(first_coord[1], max_j, step_shape_scaled[1])) + wsm = None if self._mask_path is not None: wsm = WholeSlideImage(self._mask_path, backend=self._backend, auto_resample=True) - for row in range(self._patch_configuration.offset[1], self._y_dims, step_row): - for col in range(self._patch_configuration.offset[0], self._x_dims, step_col): + messages = [] + for row in range_j: + for col in range_i: if wsm is not None: mask_patch = wsm.get_patch( x=col, @@ -106,6 +127,8 @@ def get_patch_messages(self): relative=self._level_0_spacing, ) + if self._patch_configuration.write_shape is not None: + mask_patch = crop_data(mask_patch, self._patch_configuration.write_shape) if np.all(mask_patch == 0): continue