Skip to content

Commit

Permalink
patchcommander rework (#69)
Browse files Browse the repository at this point in the history
When sampling with more complex patch configurations, patches along the border of the image were not always sampled correctly. For example, with center=True, patches whose center lies outside the image were not sampled, so the patches from the bottom row of the image were ignored. Additionally, with overlap, some unsampled patches around the borders of the image still contained tissue. For example, in the case of a config with half-overlapping slides, the patch half a patch size to the upper left of the (0, 0) coordinate was not sampled, even though the bottom quarter of this patch still contained the image. These issues are fixed in this PR. To look up the visual validation plots for this PR, see intermediate commit 7910ed8.
  • Loading branch information
JoeySpronck authored Nov 7, 2024
1 parent 41e501c commit 1e4c6ca
Showing 1 changed file with 35 additions and 12 deletions.
47 changes: 35 additions & 12 deletions wholeslidedata/buffer/patchcommander.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@
import numpy as np

from wholeslidedata.image.wholeslideimage import WholeSlideImage

from wholeslidedata.samplers.utils import crop_data

@dataclass
class PatchConfiguration:
    """Settings that describe how patches are laid out when sampling a slide.

    All fields have defaults, so ``PatchConfiguration()`` yields a
    non-overlapping, non-centered 512x512x3 configuration at spacing 0.5.
    """

    # Patch size; the sliding commander reads indices 0, 1 and 2 of this tuple.
    patch_shape: tuple = (512, 512, 3)
    # Spacing(s) to sample at.
    spacings: tuple = (0.5,)
    # Per-axis overlap; it is subtracted from patch_shape to get the step size.
    overlap: tuple = (0, 0)
    # This currently has to be provided already scaled.
    offset: tuple = (0, 0)
    # When True, the sliding commander extends its sampling range by half a
    # patch so patches centered near the image border are still generated.
    center: bool = False
    # This can be used when you crop off (overlapping) borders of the patches;
    # only patches with mask in the inner part will be sampled.
    write_shape: tuple = None

class PatchCommander(Commander):
def __init__(
Expand Down Expand Up @@ -81,20 +81,41 @@ def create_message(self) -> dict:

class SlidingPatchCommander(PatchCommander):
def get_patch_messages(self):
messages = []
step_row = int(self._patch_configuration.patch_shape[0] * self._ratio) - int(
self._patch_configuration.overlap[0] * self._ratio
)
step_col = int(self._patch_configuration.patch_shape[1] * self._ratio) - int(
self._patch_configuration.overlap[1] * self._ratio
)
step_shape = (int(self._patch_configuration.patch_shape[0]) - int(self._patch_configuration.overlap[0]), int(self._patch_configuration.patch_shape[1]) - int(self._patch_configuration.overlap[1]))
step_shape_scaled = (step_shape[0] * self._ratio, step_shape[1] * self._ratio)
# TODO: # offset has/had to be provided already scaled, I think this should be changed though
# offset_scaled = (self._patch_configuration.offset[0] * self._ratio, self._patch_configuration.offset[1] * self._ratio)
# offset_scaled = tuple(o%ss for o, ss in zip(offset_scaled, step_shape_scaled))
offset_scaled = (self._patch_configuration.offset[0], self._patch_configuration.offset[1])
offset_scaled = tuple(o%ss for o, ss in zip(offset_scaled, step_shape_scaled))
patch_shape_scaled = (self._patch_configuration.patch_shape[0] * self._ratio, self._patch_configuration.patch_shape[1] * self._ratio, self._patch_configuration.patch_shape[2])

if self._patch_configuration.center:
first_coord = [
offset_scaled[0] - ((offset_scaled[0] + patch_shape_scaled[0] // 2) // step_shape_scaled[0] - (1 if (offset_scaled[0] + patch_shape_scaled[0] // 2) % step_shape_scaled[0] == 0 else 0)) * step_shape_scaled[0],
offset_scaled[1] - ((offset_scaled[1] + patch_shape_scaled[1] // 2) // step_shape_scaled[1] - (1 if (offset_scaled[1] + patch_shape_scaled[1] // 2) % step_shape_scaled[1] == 0 else 0)) * step_shape_scaled[1]
]
else:
first_coord = [
offset_scaled[0] - ((offset_scaled[0] + patch_shape_scaled[0]) // step_shape_scaled[0] - (1 if (offset_scaled[0] + patch_shape_scaled[0]) % step_shape_scaled[0] == 0 else 0)) * step_shape_scaled[0],
offset_scaled[1] - ((offset_scaled[1] + patch_shape_scaled[1]) // step_shape_scaled[1] - (1 if (offset_scaled[1] + patch_shape_scaled[1]) % step_shape_scaled[1] == 0 else 0)) * step_shape_scaled[1]
]

max_i = int(self._x_dims + offset_scaled[0] + patch_shape_scaled[0] // 2) if self._patch_configuration.center else int(self._x_dims + offset_scaled[0])
max_j = int(self._y_dims + offset_scaled[1] + patch_shape_scaled[0] // 2) if self._patch_configuration.center else int(self._y_dims + offset_scaled[1])
max_i = max_i if not self._patch_configuration.center else max_i + patch_shape_scaled[0] // 2
max_j = max_j if not self._patch_configuration.center else max_j + patch_shape_scaled[1] // 2

range_i = list(range(first_coord[0], max_i, step_shape_scaled[0]))
range_j = list(range(first_coord[1], max_j, step_shape_scaled[1]))

wsm = None
if self._mask_path is not None:
wsm = WholeSlideImage(self._mask_path, backend=self._backend, auto_resample=True)

for row in range(self._patch_configuration.offset[1], self._y_dims, step_row):
for col in range(self._patch_configuration.offset[0], self._x_dims, step_col):
messages = []
for row in range_j:
for col in range_i:
if wsm is not None:
mask_patch = wsm.get_patch(
x=col,
Expand All @@ -106,6 +127,8 @@ def get_patch_messages(self):
relative=self._level_0_spacing,
)

if self._patch_configuration.write_shape is not None:
mask_patch = crop_data(mask_patch, self._patch_configuration.write_shape)
if np.all(mask_patch == 0):
continue

Expand Down

0 comments on commit 1e4c6ca

Please sign in to comment.