add trainer with no mirroring and segmentation resampling order 0
wasserth committed Feb 6, 2024
1 parent 52cab35 · commit 164d94c
Showing 1 changed file with 53 additions and 0 deletions.
@@ -102,3 +102,56 @@ def get_dataloaders(self):
max(1, allowed_num_processes // 2), 3, None, True, 0.02)

return mt_gen_train, mt_gen_val


class nnUNetTrainer_DASegOrd0_NoMirroring(nnUNetTrainer):
def get_dataloaders(self):
"""
changed order_resampling_data, order_resampling_seg
"""
# we use the patch size to determine whether we need 2D or 3D dataloaders. We also use it to determine whether
# we need to use dummy 2D augmentation (in case of 3D training) and what our initial patch size should be
patch_size = self.configuration_manager.patch_size
dim = len(patch_size)

# needed for deep supervision: how much do we need to downscale the segmentation targets for the different
# outputs?
deep_supervision_scales = self._get_deep_supervision_scales()

rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \
self.configure_rotation_dummyDA_mirroring_and_inital_patch_size()

        # Deactivate mirroring data augmentation (clearing inference_allowed_mirroring_axes
        # below also disables test-time mirroring during validation/prediction)
mirror_axes = None
self.inference_allowed_mirroring_axes = None

# training pipeline
tr_transforms = self.get_training_transforms(
patch_size, rotation_for_DA, deep_supervision_scales, mirror_axes, do_dummy_2d_data_aug,
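            # order 3 (cubic spline) for image data; order 0 (nearest-neighbor) for
            # segmentations, so resampling cannot blend discrete label values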
order_resampling_data=3, order_resampling_seg=0,
use_mask_for_norm=self.configuration_manager.use_mask_for_norm,
is_cascaded=self.is_cascaded, foreground_labels=self.label_manager.all_labels,
regions=self.label_manager.foreground_regions if self.label_manager.has_regions else None,
ignore_label=self.label_manager.ignore_label)

# validation pipeline
val_transforms = self.get_validation_transforms(deep_supervision_scales,
is_cascaded=self.is_cascaded,
foreground_labels=self.label_manager.all_labels,
regions=self.label_manager.foreground_regions if
self.label_manager.has_regions else None,
ignore_label=self.label_manager.ignore_label)

dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim)

allowed_num_processes = get_allowed_n_proc_DA()
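        # with no background workers allowed, run the augmentation pipeline in the main thread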
if allowed_num_processes == 0:
mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms)
mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms)
else:
mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, dl_tr, tr_transforms,
allowed_num_processes, 6, None, True, 0.02)
mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, dl_val, val_transforms,
max(1, allowed_num_processes // 2), 3, None, True, 0.02)

return mt_gen_train, mt_gen_val
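
Note: assuming the standard nnU-Net v2 command-line entry point, a custom trainer class
like this one is selected at training time via the -tr flag (the dataset ID, configuration
and fold below are placeholders):

    nnUNetv2_train 101 3d_fullres 0 -tr nnUNetTrainer_DASegOrd0_NoMirroring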
