diff --git a/.gitignore b/.gitignore index 402f8845..fc5cf880 100644 --- a/.gitignore +++ b/.gitignore @@ -18,7 +18,7 @@ venv __pycache__/ *.py[cod] *$py.class - +*.out # C extensions *.so diff --git a/config/config.yaml b/config/config.yaml index 203cc383..d7003bf2 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -18,9 +18,15 @@ Trainer: max_epoch_train_finetune: 100 train_encoder: True train_decoder: True - # for mt trainer - transform_axis: [1, 2] - reg_weight: 10 + +PretrainEncoder: + group_option: patient + +PretrainDecoder: + null +FineTune: + reg_weight: 15 + #Checkpoint: runs/test_pipeline \ No newline at end of file diff --git a/contrastyou/augment/__init__.py b/contrastyou/augment/__init__.py index ffa4d740..3cd3dea8 100644 --- a/contrastyou/augment/__init__.py +++ b/contrastyou/augment/__init__.py @@ -1,47 +1,24 @@ -from typing import Callable, Union, List, Tuple +from torchvision import transforms -from deepclustering2.augment import pil_augment, SequentialWrapper +from contrastyou.augment.sequential_wrapper import SequentialWrapperTwice, SequentialWrapper +from deepclustering2.augment import pil_augment -class SequentialWrapperTwice(SequentialWrapper): - - def __init__(self, img_transform: Callable = None, target_transform: Callable = None, - if_is_target: Union[List[bool], Tuple[bool, ...]] = []) -> None: - super().__init__(img_transform, target_transform, if_is_target) - - def __call__( - self, *imgs, random_seed=None - ): - return [ - super(SequentialWrapperTwice, self).__call__(*imgs, random_seed=random_seed), - super(SequentialWrapperTwice, self).__call__(*imgs, random_seed=random_seed), - ] - - -class ACDC_transforms: +class ACDCTransforms: train = SequentialWrapperTwice( - pil_augment.Compose([ + comm_transform=pil_augment.Compose([ pil_augment.RandomCrop(224), pil_augment.RandomRotation(30), - pil_augment.ToTensor() ]), - pil_augment.Compose([ - pil_augment.RandomCrop(224), - pil_augment.RandomRotation(30), + img_transform=pil_augment.Compose([ + transforms.ColorJitter(brightness=[0.5, 1.5], contrast=[0.5, 1.5], saturation=[0.5, 1.5]), + transforms.ToTensor() + ]), + target_transform=pil_augment.Compose([ pil_augment.ToLabel() ]), - if_is_target=[False, True] - + total_freedom=True ) val = SequentialWrapper( - pil_augment.Compose([ - pil_augment.CenterCrop(224), - pil_augment.ToTensor() - ]), - pil_augment.Compose([ - pil_augment.CenterCrop(224), - pil_augment.ToLabel() - ]), - if_is_target=[False, True] - + comm_transform=pil_augment.CenterCrop(224) ) diff --git a/contrastyou/augment/sequential_wrapper.py b/contrastyou/augment/sequential_wrapper.py new file mode 100644 index 00000000..db376939 --- /dev/null +++ b/contrastyou/augment/sequential_wrapper.py @@ -0,0 +1,100 @@ +import random +from typing import Callable, List + +from PIL import Image +from torch import Tensor + +from deepclustering2.augment import pil_augment +from deepclustering2.decorator import FixRandomSeed + + +class SequentialWrapper: + + def __init__( + self, + comm_transform: Callable[[Image.Image], Image.Image] = None, + img_transform: Callable[[Image.Image], Tensor] = pil_augment.ToTensor(), + target_transform: Callable[[Image.Image], Tensor] = pil_augment.ToLabel() + ) -> None: + """ + :param comm_transform: common geo-transformation + :param img_transform: transformation only applied for images + :param target_transform: transformation only applied for targets + """ + self._comm_transform = comm_transform + self._img_transform = img_transform + self._target_transform = target_transform + + def __call__(self, imgs: List[Image.Image], targets: List[Image.Image] = None, comm_seed=None, img_seed=None, + target_seed=None): + _comm_seed: int = int(random.randint(0, int(1e5))) if comm_seed is None else int(comm_seed) # type ignore + imgs_after_comm, targets_after_comm = imgs, targets + if self._comm_transform: + imgs_after_comm, targets_after_comm = [], [] + for img in imgs: + with FixRandomSeed(_comm_seed): + img_ = self._comm_transform(img) + imgs_after_comm.append(img_) + if targets: + for target in targets: + with FixRandomSeed(_comm_seed): + target_ = self._comm_transform(target) + targets_after_comm.append(target_) + imgs_after_img_transform = [] + targets_after_target_transform = [] + _img_seed: int = int(random.randint(0, int(1e5))) if img_seed is None else int(img_seed) # type ignore + for img in imgs_after_comm: + with FixRandomSeed(_img_seed): + img_ = self._img_transform(img) + imgs_after_img_transform.append(img_) + + _target_seed: int = int(random.randint(0, int(1e5))) if target_seed is None else int(target_seed) # type ignore + if targets_after_comm: + for target in targets_after_comm: + with FixRandomSeed(_target_seed): + target_ = self._target_transform(target) + targets_after_target_transform.append(target_) + + if targets is None: + targets_after_target_transform = None + + if targets_after_target_transform is None: + return imgs_after_img_transform + return [*imgs_after_img_transform, *targets_after_target_transform] + + def __repr__(self): + return ( + f"comm_transform:{self._comm_transform}\n" + f"img_transform:{self._img_transform}.\n" + f"target_transform: {self._target_transform}" + ) + + +class SequentialWrapperTwice(SequentialWrapper): + + def __init__(self, comm_transform: Callable[[Image.Image], Image.Image] = None, + img_transform: Callable[[Image.Image], Tensor] = pil_augment.ToTensor(), + target_transform: Callable[[Image.Image], Tensor] = pil_augment.ToLabel(), + total_freedom=True) -> None: + """ + :param total_freedom: if True, the two-time generated images are using different seeds for all aspect, + otherwise, the images are used different random seed only for img_seed + """ + super().__init__(comm_transform, img_transform, target_transform) + self._total_freedom = total_freedom + + def __call__(self, imgs: List[Image.Image], targets: List[Image.Image] = None, global_seed=None, **kwargs): + global_seed = int(random.randint(0, int(1e5))) if global_seed is None else int(global_seed) # type ignore + with FixRandomSeed(global_seed): + comm_seed1, comm_seed2 = int(random.randint(0, int(1e5))), int(random.randint(0, int(1e5))) + img_seed1, img_seed2 = int(random.randint(0, int(1e5))), int(random.randint(0, int(1e5))) + target_seed1, target_seed2 = int(random.randint(0, int(1e5))), int(random.randint(0, int(1e5))) + if self._total_freedom: + return [ + super().__call__(imgs, targets, comm_seed1, img_seed1, target_seed1), + super().__call__(imgs, targets, comm_seed2, img_seed2, target_seed2), + ] + return [ + super().__call__(imgs, targets, comm_seed1, img_seed1, target_seed1), + super().__call__(imgs, targets, comm_seed1, img_seed2, target_seed1), + ] diff --git a/contrastyou/dataloader/_seg_datset.py b/contrastyou/dataloader/_seg_datset.py index 0ba3c9f7..7bfd45c3 100644 --- a/contrastyou/dataloader/_seg_datset.py +++ b/contrastyou/dataloader/_seg_datset.py @@ -56,7 +56,7 @@ class _SamplerIterator: def __init__(self, group2index, partion2index, group_sample_num=4, partition_sample_num=1) -> None: self._group2index, self._partition2index = dcp(group2index), dcp(partion2index) - assert group_sample_num >= 1 and group_sample_num <= len(self._group2index.keys()), group_sample_num + assert 1 <= group_sample_num <= len(self._group2index.keys()), group_sample_num self._group_sample_num = group_sample_num self._partition_sample_num = partition_sample_num diff --git a/contrastyou/dataloader/acdc_dataset.py b/contrastyou/dataloader/acdc_dataset.py index 09decd55..49ad3e36 100644 --- a/contrastyou/dataloader/acdc_dataset.py +++ b/contrastyou/dataloader/acdc_dataset.py @@ -1,12 +1,13 @@ import os import re +from pathlib import Path from typing import List, Tuple, Union import numpy as np from torch import Tensor +from contrastyou.augment.sequential_wrapper import SequentialWrapper from contrastyou.dataloader._seg_datset import ContrastDataset -from deepclustering2.augment import SequentialWrapper from deepclustering2.dataset import ACDCDataset as _ACDCDataset, ACDCSemiInterface as _ACDCSemiInterface @@ -15,20 +16,23 @@ class ACDCDataset(ContrastDataset, _ACDCDataset): zip_name = "ACDC_contrast.zip" folder_name = "ACDC_contrast" - def __init__(self, root_dir: str, mode: str, transforms: SequentialWrapper = None, + def __init__(self, root_dir: str, mode: str, transforms: SequentialWrapper = SequentialWrapper(), verbose=True, *args, **kwargs) -> None: super().__init__(root_dir, mode, ["img", "gt"], transforms, verbose) self._acdc_info = np.load(os.path.join(self._root_dir, "acdc_info.npy"), allow_pickle=True).item() assert isinstance(self._acdc_info, dict) and len(self._acdc_info) == 200 + self._transform = transforms def __getitem__(self, index) -> Tuple[List[Tensor], str, str, str]: - data, filename = super().__getitem__(index) + [img_png, target_png], filename_list = self._getitem_index(index) + filename = Path(filename_list[0]).stem + data = self._transform(imgs=[img_png], targets=[target_png], ) partition = self._get_partition(filename) group = self._get_group(filename) return data, filename, partition, group def _get_group(self, filename) -> Union[str, int]: - return self._get_group_name(filename) + return str(self._get_group_name(filename)) def _get_partition(self, filename) -> Union[str, int]: # set partition @@ -36,10 +40,10 @@ def _get_partition(self, filename) -> Union[str, int]: cutting_point = max_len_given_group // 3 cur_index = int(re.compile(r"\d+").findall(filename)[-1]) if cur_index <= cutting_point - 1: - return 0 + return str(0) if cur_index <= 2 * cutting_point: - return 1 - return 2 + return str(1) + return str(2) def show_paritions(self) -> List[Union[str, int]]: return [self._get_partition(f) for f in list(self._filenames.values())[0]] diff --git a/contrastyou/epocher/IIC_epocher.py b/contrastyou/epocher/IIC_epocher.py index 261f31c8..5f87280b 100644 --- a/contrastyou/epocher/IIC_epocher.py +++ b/contrastyou/epocher/IIC_epocher.py @@ -1,40 +1,45 @@ +import random + import torch +from contrastyou.epocher._utils import unfold_position from deepclustering2 import optim +from deepclustering2.decorator import FixRandomSeed from deepclustering2.meters2 import EpochResultDict from deepclustering2.optim import get_lrs_from_optimizer from deepclustering2.trainer.trainer import T_loader, T_loss from torch import nn from torch.nn import functional as F +from .contrast_epocher import PretrainDecoderEpoch as _PretrainDecoderEpoch from .contrast_epocher import PretrainEncoderEpoch as _PretrainEncoderEpoch class IICPretrainEcoderEpoch(_PretrainEncoderEpoch): def __init__(self, model: nn.Module, projection_head: nn.Module, projection_classifier: nn.Module, - optimizer: optim.Optimizer, pretrain_encoder_loader: T_loader = None, - contrastive_criterion: T_loss = None, num_batches: int = 0, - cur_epoch=0, device="cpu", iic_weight_ratio=1, *args, **kwargs) -> None: + optimizer: optim.Optimizer, pretrain_encoder_loader: T_loader, + contrastive_criterion: T_loss, num_batches: int = 0, + cur_epoch=0, device="cpu", group_option: str = "partition", iic_weight_ratio=1) -> None: """ - :param model: - :param projection_head: here the projection head should be a classifier + :param projection_head: + :param projection_classifier: classification head :param optimizer: - :param pretrain_encoder_loader: + :param pretrain_encoder_loader: infinite dataloader with `total freedom = True` :param contrastive_criterion: :param num_batches: :param cur_epoch: :param device: - :param args: - :param kwargs: + :param iic_weight_ratio: iic weight_ratio """ + super(IICPretrainEcoderEpoch, self).__init__(model, projection_head, optimizer, pretrain_encoder_loader, + contrastive_criterion, num_batches, + cur_epoch, device, group_option=group_option) assert pretrain_encoder_loader is not None self._projection_classifier = projection_classifier from ..losses.iic_loss import IIDLoss self._iic_criterion = IIDLoss() self._iic_weight_ratio = iic_weight_ratio - super().__init__(model, projection_head, optimizer, pretrain_encoder_loader, contrastive_criterion, num_batches, - cur_epoch, device, *args, **kwargs) def _run(self, *args, **kwargs) -> EpochResultDict: self._model.train() @@ -46,11 +51,12 @@ def _run(self, *args, **kwargs) -> EpochResultDict: (img, _), (img_tf, _), filename, partition_list, group_list = self._preprocess_data(data, self._device) _, (e5, *_), *_ = self._model(torch.cat([img, img_tf], dim=0), return_features=True) global_enc, global_tf_enc = torch.chunk(F.normalize(self._projection_head(e5), dim=1), chunks=2, dim=0) + global_probs, global_tf_probs = torch.chunk(self._projection_classifier(e5), chunks=2, dim=0) # fixme: here lack of some code for IIC labels = self._label_generation(partition_list, group_list) contrastive_loss = self._contrastive_criterion(torch.stack([global_enc, global_tf_enc], dim=1), labels=labels) - iic_loss = self._iic_criterion() # todo + iic_loss = self._iic_criterion(global_probs, global_tf_probs) # todo total_loss = self._iic_weight_ratio * iic_loss + (1 - self._iic_weight_ratio) * contrastive_loss self._optimizer.zero_grad() total_loss.backward() @@ -61,3 +67,57 @@ def _run(self, *args, **kwargs) -> EpochResultDict: report_dict = self.meters.tracking_status() indicator.set_postfix_dict(report_dict) return report_dict + + +class IICPretrainDecoderEpoch(_PretrainDecoderEpoch): + def __init__(self, model: nn.Module, projection_head: nn.Module, projection_classifier: nn.Module, + optimizer: optim.Optimizer, pretrain_decoder_loader: T_loader, contrastive_criterion: T_loss, + iic_criterion: T_loss, num_batches: int = 0, cur_epoch=0, device="cpu") -> None: + super().__init__(model, projection_head, optimizer, pretrain_decoder_loader, contrastive_criterion, num_batches, + cur_epoch, device) + self._projection_classifer = projection_classifier + self._iic_criterion = iic_criterion + + def _run(self, *args, **kwargs) -> EpochResultDict: + self._model.train() + assert self._model.training, self._model.training + self.meters["lr"].add(get_lrs_from_optimizer(self._optimizer)[0]) + + with tqdm(range(self._num_batches)).set_desc_from_epocher(self) as indicator: # noqa + for i, data in zip(indicator, self._pretrain_decoder_loader): + (img, _), (img_ctf, _), filename, partition_list, group_list = self._preprocess_data(data, self._device) + seed = random.randint(0, int(1e5)) + with FixRandomSeed(seed): + img_gtf = torch.stack([self._transformer(x) for x in img], dim=0) + assert img_gtf.shape == img.shape, (img_gtf.shape, img.shape) + + _, *_, (_, d4, *_) = self._model(torch.cat([img_gtf, img_ctf], dim=0), return_features=True) + d4_gtf, d4_ctf = torch.chunk(d4, chunks=2, dim=0) + with FixRandomSeed(seed): + d4_ctf_gtf = torch.stack([self._transformer(x) for x in d4_ctf], dim=0) + assert d4_ctf_gtf.shape == d4_ctf.shape, (d4_ctf_gtf.shape, d4_ctf.shape) + d4_tf = torch.cat([d4_gtf, d4_ctf_gtf], dim=0) + local_enc_tf, local_enc_tf_ctf = torch.chunk(self._projection_head(d4_tf), chunks=2, dim=0) + # todo: convert representation to distance + local_enc_unfold, _ = unfold_position(local_enc_tf, partition_num=(2, 2)) + local_tf_enc_unfold, _fold_partition = unfold_position(local_enc_tf_ctf, partition_num=(2, 2)) + b, *_ = local_enc_unfold.shape + local_enc_unfold_norm = F.normalize(local_enc_unfold.view(b, -1), p=2, dim=1) + local_tf_enc_unfold_norm = F.normalize(local_tf_enc_unfold.view(b, -1), p=2, dim=1) + + labels = self._label_generation(partition_list, group_list, _fold_partition) + contrastive_loss = self._contrastive_criterion( + torch.stack([local_enc_unfold_norm, local_tf_enc_unfold_norm], dim=1), + labels=labels + ) + if torch.isnan(contrastive_loss): + raise RuntimeError(contrastive_loss) + self._optimizer.zero_grad() + contrastive_loss.backward() + self._optimizer.step() + # todo: meter recording. + with torch.no_grad(): + self.meters["contrastive_loss"].add(contrastive_loss.item()) + report_dict = self.meters.tracking_status() + indicator.set_postfix_dict(report_dict) + return report_dict diff --git a/contrastyou/epocher/_utils.py b/contrastyou/epocher/_utils.py index 167357de..632548be 100644 --- a/contrastyou/epocher/_utils.py +++ b/contrastyou/epocher/_utils.py @@ -1,8 +1,20 @@ -import random +from typing import List +import numpy as np from deepclustering2.type import to_device, torch -from deepclustering2.utils import assert_list -from torch import Tensor + + +def unique_mapping(name_list): + unique_map = np.unique(name_list) + mapping = {} + for i, u in enumerate(unique_map): + mapping[u] = i + return [mapping[n] for n in name_list] + + +def _string_list_adding(list1, list2): + assert len(list1) == len(list2) + return [x + "_" + y for x, y in zip(list1, list2)] def preprocess_input_with_twice_transformation(data, device, non_blocking=True): @@ -32,33 +44,41 @@ def unfold_position(features: torch.Tensor, partition_num=(4, 4), ): return torch.cat(result, dim=0), result_flag -class TensorRandomFlip: - def __init__(self, axis=None) -> None: - if isinstance(axis, int): - self._axis = [axis] - elif isinstance(axis, (list, tuple)): - assert_list(lambda x: isinstance(x, int), axis), axis - self._axis = axis - elif axis is None: - self._axis = axis - else: - raise ValueError(str(axis)) - - def __call__(self, tensor: Tensor): - tensor = tensor.clone() - if self._axis is not None: - for _one_axis in self._axis: - if random.random() < 0.5: - tensor = tensor.flip(_one_axis) - return tensor - else: - return tensor - - def __repr__(self): - string = f"{self.__class__.__name__}" - axis = "" if not self._axis else f" with axis={self._axis}." - - return string + axis +class GlobalLabelGenerator: + + def __init__(self, contrastive_on_patient=False, contrastive_on_partition=True) -> None: + self._contrastive_on_patient = contrastive_on_patient + self._contrastive_on_partition = contrastive_on_partition + + def __call__(self, partition_list: List[str], patient_list: List[str]) -> List[int]: + assert len(partition_list) == len(patient_list), (len(partition_list), len(patient_list)) + batch_size = len(partition_list) + + final_string = [""] * batch_size + if self._contrastive_on_patient: + final_string = _string_list_adding(final_string, patient_list) + + if self._contrastive_on_partition: + final_string = _string_list_adding(final_string, partition_list) + + return unique_mapping(final_string) + + +class LocalLabelGenerator(GlobalLabelGenerator): + + def __init__(self, ) -> None: + super().__init__(True, True) + + def __call__(self, partition_list: List[str], patient_list: List[str], location_list: List[str]) -> List[int]: + partition_list = [str(x) for x in partition_list] + patient_list = [str(x) for x in patient_list] + location_list = [str(x) for x in location_list] + mul_factor = int(len(location_list) // len(patient_list)) + partition_list = partition_list * mul_factor + patient_list = patient_list * mul_factor + assert len(location_list) == len(partition_list) + + return super().__call__(_string_list_adding(patient_list, partition_list), location_list) if __name__ == '__main__': diff --git a/contrastyou/epocher/base_epocher.py b/contrastyou/epocher/base_epocher.py index f5e13663..98f2c193 100644 --- a/contrastyou/epocher/base_epocher.py +++ b/contrastyou/epocher/base_epocher.py @@ -2,9 +2,7 @@ from typing import Union, Tuple import torch -from torch import nn -from torch.utils.data import DataLoader - +from deepclustering2.augment.tensor_augment import TensorRandomFlip from deepclustering2.decorator import FixRandomSeed from deepclustering2.epoch import _Epocher, proxy_trainer # noqa from deepclustering2.loss import simplex @@ -14,8 +12,10 @@ from deepclustering2.tqdm import tqdm from deepclustering2.trainer.trainer import T_loss, T_optim, T_loader from deepclustering2.utils import class2one_hot -from ._utils import preprocess_input_with_single_transformation, preprocess_input_with_twice_transformation, \ - TensorRandomFlip +from torch import nn +from torch.utils.data import DataLoader + +from ._utils import preprocess_input_with_single_transformation, preprocess_input_with_twice_transformation class EvalEpoch(_Epocher): @@ -82,7 +82,7 @@ def __init__(self, model: nn.Module, optimizer: T_optim, labeled_loader: T_loade @classmethod def create_from_trainer(cls, trainer): return cls( - model=trainer._model, optimizer=trainer._optimizer, labeled_loader=trainer._fine_tune_loader, + model=trainer._model, optimizer=trainer._optimizer, labeled_loader=trainer._fine_tune_loader_iter, sup_criterion=trainer._sup_criterion, num_batches=trainer._num_batches, cur_epoch=trainer._cur_epoch, device=trainer._device ) @@ -146,8 +146,9 @@ def __init__(self, model: nn.Module, teacher_model: nn.Module, optimizer: T_opti @classmethod def create_from_trainer(cls, trainer): return cls(model=trainer._model, teacher_model=trainer._teacher_model, optimizer=trainer._optimizer, - labeled_loader=trainer._fine_tune_loader, tra_loader=trainer._pretrain_loader, - sup_criterion=trainer._sup_criterion, reg_criterion=trainer._reg_criterion, num_batches=trainer._num_batches, + labeled_loader=trainer._fine_tune_loader_iter, tra_loader=trainer._pretrain_loader, + sup_criterion=trainer._sup_criterion, reg_criterion=trainer._reg_criterion, + num_batches=trainer._num_batches, cur_epoch=trainer._cur_epoch, device=trainer._device, transform_axis=trainer._transform_axis, reg_weight=trainer._reg_weight, ema_updater=trainer._ema_updater) diff --git a/contrastyou/epocher/contrast_epocher.py b/contrastyou/epocher/contrast_epocher.py index e5533e3f..12892d79 100644 --- a/contrastyou/epocher/contrast_epocher.py +++ b/contrastyou/epocher/contrast_epocher.py @@ -1,33 +1,67 @@ +import random from typing import List import torch -from torch import nn -from torch.nn import functional as F - from deepclustering2 import optim +from deepclustering2.decorator import FixRandomSeed from deepclustering2.epoch import _Epocher, proxy_trainer # noqa -from deepclustering2.loss import KL_div -from deepclustering2.meters2 import EpochResultDict, MeterInterface, AverageValueMeter, UniversalDice -from deepclustering2.tqdm import tqdm +from deepclustering2.meters2 import EpochResultDict, MeterInterface, AverageValueMeter from deepclustering2.optim import get_lrs_from_optimizer -from deepclustering2.trainer.trainer import T_loader, T_loss, T_optim -from deepclustering2.utils import simplex, class2one_hot, np -from ._utils import preprocess_input_with_twice_transformation, unfold_position +from deepclustering2.tqdm import tqdm +from deepclustering2.trainer.trainer import T_loader, T_loss +from torch import nn +from torch.nn import functional as F + +from ._utils import preprocess_input_with_twice_transformation, unfold_position, GlobalLabelGenerator, \ + LocalLabelGenerator class PretrainEncoderEpoch(_Epocher): """using a pretrained network to train with a data loader with contrastive loss.""" def __init__(self, model: nn.Module, projection_head: nn.Module, optimizer: optim.Optimizer, - pretrain_encoder_loader: T_loader = None, contrastive_criterion: T_loss = None, num_batches: int = 0, - cur_epoch=0, device="cpu", *args, - **kwargs) -> None: + pretrain_encoder_loader: T_loader, contrastive_criterion: T_loss, num_batches: int = 0, + cur_epoch=0, device="cpu", group_option: str = None) -> None: + """ + PretrainEncoder Epocher + :param model: nn.Module for a model + :param projection_head: shallow projection head + :param optimizer: optimizer for both network and shallow projection head. + :param pretrain_encoder_loader: dataloader for epocher + :param contrastive_criterion: contrastive loss, can be any loss given the normalized norm. + :param num_batches: num_batches to be used + :param cur_epoch: current epoch + :param device: device for images + :param group_option: group option for contrastive loss + :param args: additional args + :param kwargs: additional kwargs + """ super().__init__(model, cur_epoch, device) self._projection_head = projection_head self._optimizer = optimizer self._pretrain_encoder_loader = pretrain_encoder_loader self._contrastive_criterion = contrastive_criterion + assert isinstance(num_batches, int) and num_batches > 0, num_batches self._num_batches = num_batches + assert isinstance(group_option, str) and group_option in ("partition", "patient", "both"), group_option + self._group_option = group_option + self._init_label_generator(self._group_option) + + def _init_label_generator(self, group_option): + contrastive_on_partition = False + contrastive_on_patient = False + + if group_option == "partition": + contrastive_on_partition = True + if group_option == "patient": + contrastive_on_patient = True + if group_option == "both": + contrastive_on_patient = True + contrastive_on_partition = True + self._global_contrastive_label_generator = GlobalLabelGenerator( + contrastive_on_partition=contrastive_on_partition, + contrastive_on_patient=contrastive_on_patient + ) @classmethod def create_from_trainer(cls, trainer): @@ -36,7 +70,7 @@ def create_from_trainer(cls, trainer): def _configure_meters(self, meters: MeterInterface) -> MeterInterface: meters.register_meter("contrastive_loss", AverageValueMeter()) - meters.register_meter("lr",AverageValueMeter()) + meters.register_meter("lr", AverageValueMeter()) return meters def _run(self, *args, **kwargs) -> EpochResultDict: @@ -67,21 +101,26 @@ def _run(self, *args, **kwargs) -> EpochResultDict: def _preprocess_data(data, device): return preprocess_input_with_twice_transformation(data, device) - @staticmethod - def _label_generation(partition_list: List[str], group_list: List[str]): + def _label_generation(self, partition_list: List[str], group_list: List[str]): """override this to provide more mask """ - return partition_list + return self._global_contrastive_label_generator(partition_list=partition_list, + patient_list=group_list) class PretrainDecoderEpoch(PretrainEncoderEpoch): """using a pretrained network to train with a dataloader, for decoder part""" def __init__(self, model: nn.Module, projection_head: nn.Module, optimizer: optim.Optimizer, - pretrain_decoder_loader: T_loader = None, contrastive_criterion: T_loss = None, num_batches: int = 0, - cur_epoch=0, device="cpu", *args, **kwargs) -> None: + pretrain_decoder_loader: T_loader, contrastive_criterion: T_loss, num_batches: int = 0, cur_epoch=0, + device="cpu", ) -> None: super().__init__(model, projection_head, optimizer, pretrain_decoder_loader, contrastive_criterion, num_batches, - cur_epoch, device, *args, **kwargs) + cur_epoch, device, "both", ) self._pretrain_decoder_loader = self._pretrain_encoder_loader + from deepclustering2.augment.tensor_augment import TensorRandomFlip + self._transformer = TensorRandomFlip(axis=[1, 2], threshold=1) + + def _init_label_generator(self, group_option): + self._local_contrastive_label_generator = LocalLabelGenerator() def _run(self, *args, **kwargs) -> EpochResultDict: self._model.train() @@ -90,12 +129,22 @@ def _run(self, *args, **kwargs) -> EpochResultDict: with tqdm(range(self._num_batches)).set_desc_from_epocher(self) as indicator: # noqa for i, data in zip(indicator, self._pretrain_decoder_loader): - (img, _), (img_tf, _), filename, partition_list, group_list = self._preprocess_data(data, self._device) - _, *_, (_, d4, *_) = self._model(torch.cat([img, img_tf], dim=0), return_features=True) - local_enc, local_tf_enc = torch.chunk(F.normalize(self._projection_head(d4), dim=1), chunks=2, dim=0) + (img, _), (img_ctf, _), filename, partition_list, group_list = self._preprocess_data(data, self._device) + seed = random.randint(0, int(1e5)) + with FixRandomSeed(seed): + img_gtf = torch.stack([self._transformer(x) for x in img], dim=0) + assert img_gtf.shape == img.shape, (img_gtf.shape, img.shape) + + _, *_, (_, d4, *_) = self._model(torch.cat([img_gtf, img_ctf], dim=0), return_features=True) + d4_gtf, d4_ctf = torch.chunk(d4, chunks=2, dim=0) + with FixRandomSeed(seed): + d4_ctf_gtf = torch.stack([self._transformer(x) for x in d4_ctf], dim=0) + assert d4_ctf_gtf.shape == d4_ctf.shape, (d4_ctf_gtf.shape, d4_ctf.shape) + d4_tf = torch.cat([d4_gtf, d4_ctf_gtf]) + local_enc_tf, local_enc_tf_ctf = torch.chunk(self._projection_head(d4_tf), chunks=2, dim=0) # todo: convert representation to distance - local_enc_unfold, _ = unfold_position(local_enc, partition_num=(2, 2)) - local_tf_enc_unfold, _fold_partition = unfold_position(local_tf_enc, partition_num=(2, 2)) + local_enc_unfold, _ = unfold_position(local_enc_tf, partition_num=(2, 2)) + local_tf_enc_unfold, _fold_partition = unfold_position(local_enc_tf_ctf, partition_num=(2, 2)) b, *_ = local_enc_unfold.shape local_enc_unfold_norm = F.normalize(local_enc_unfold.view(b, -1), p=2, dim=1) local_tf_enc_unfold_norm = F.normalize(local_tf_enc_unfold.view(b, -1), p=2, dim=1) @@ -117,18 +166,6 @@ def _run(self, *args, **kwargs) -> EpochResultDict: indicator.set_postfix_dict(report_dict) return report_dict - @staticmethod - def _label_generation(partition_list: List[str], group_list: List[str], folder_partition: List[str]): - if len(folder_partition) > len(partition_list): - ratio = int(len(folder_partition) / len(partition_list)) - partition_list = partition_list.tolist() * ratio - group_list = group_list * ratio - - def tolabel(encode): - unique_labels = np.unique(encode) - mapping = {k: i for i, k in enumerate(unique_labels)} - return [mapping[k] for k in encode] - - return tolabel([str(g) + str(p) + str(f) for g, p, f in zip(group_list, partition_list, folder_partition)]) - - + def _label_generation(self, partition_list: List[str], patient_list: List[str], location_list: List[str]): + return self._local_contrastive_label_generator(partition_list=partition_list, patient_list=patient_list, + location_list=location_list) diff --git a/contrastyou/trainer/contrast_trainer.py b/contrastyou/trainer/contrast_trainer.py index 69496c7a..cb5526fd 100644 --- a/contrastyou/trainer/contrast_trainer.py +++ b/contrastyou/trainer/contrast_trainer.py @@ -3,9 +3,6 @@ from pathlib import Path import torch -from torch import nn -from torch.utils.data import DataLoader - from contrastyou import PROJECT_PATH from contrastyou.epocher import PretrainEncoderEpoch, PretrainDecoderEpoch, SimpleFineTuneEpoch, MeanTeacherEpocher from contrastyou.epocher.base_epocher import EvalEpoch @@ -16,6 +13,8 @@ from deepclustering2.schedulers import GradualWarmupScheduler from deepclustering2.trainer.trainer import Trainer, T_loader from deepclustering2.writer import SummaryWriter +from torch import nn +from torch.utils.data import DataLoader class ContrastTrainer(Trainer): @@ -24,7 +23,7 @@ class ContrastTrainer(Trainer): def __init__(self, model: nn.Module, pretrain_loader: T_loader, fine_tune_loader: T_loader, val_loader: DataLoader, save_dir: str = "base", max_epoch_train_encoder: int = 100, max_epoch_train_decoder: int = 100, max_epoch_train_finetune: int = 100, num_batches: int = 256, device: str = "cpu", configuration=None, - train_encoder: bool = True, train_decoder: bool = True, *args, **kwargs): + train_encoder: bool = True, train_decoder: bool = True): """ ContrastTraining Trainer :param model: nn.module network to be pretrained @@ -63,7 +62,7 @@ def __init__(self, model: nn.Module, pretrain_loader: T_loader, fine_tune_loader self._projector = None self._sup_criterion = None - def pretrain_encoder_init(self, *args, **kwargs): + def pretrain_encoder_init(self, group_option: str): # adding optimizer and scheduler self._projector = nn.Sequential( nn.AdaptiveAvgPool2d((1, 1)), @@ -78,6 +77,12 @@ def pretrain_encoder_init(self, *args, **kwargs): self._max_epoch_train_encoder - 10, 0) self._scheduler = GradualWarmupScheduler(self._optimizer, 300, 10, self._scheduler) # noqa + self._group_option = group_option # noqa + + # set augmentation method as `total_freedom = True` + self._pretrain_loader.dataset._transform._total_freedom = True # noqa + self._pretrain_loader_iter = iter(self._pretrain_loader) # noqa + def pretrain_encoder_run(self): self.to(self._device) self._model.enable_grad_encoder() # noqa @@ -87,9 +92,9 @@ def pretrain_encoder_run(self): pretrain_encoder_dict = PretrainEncoderEpoch( model=self._model, projection_head=self._projector, optimizer=self._optimizer, - pretrain_encoder_loader=self._pretrain_loader, + pretrain_encoder_loader=self._pretrain_loader_iter, contrastive_criterion=SupConLoss(), num_batches=self._num_batches, - cur_epoch=self._cur_epoch, device=self._device + cur_epoch=self._cur_epoch, device=self._device, group_option=self._group_option ).run() self._scheduler.step() storage_dict = StorageIncomeDict(PRETRAIN_ENCODER=pretrain_encoder_dict, ) @@ -112,6 +117,11 @@ def pretrain_decoder_init(self, *args, **kwargs): self._max_epoch_train_decoder - 10, 0) self._scheduler = GradualWarmupScheduler(self._optimizer, 300, 10, self._scheduler) + # set augmentation method as `total_freedom = False` + self._pretrain_loader.dataset._transform._total_freedom = False # noqa + + self._pretrain_loader_iter = iter(self._pretrain_loader) # noqa + def pretrain_decoder_run(self): self.to(self._device) self._projector.to(self._device) @@ -123,9 +133,9 @@ def pretrain_decoder_run(self): pretrain_decoder_dict = PretrainDecoderEpoch( model=self._model, projection_head=self._projector, optimizer=self._optimizer, - pretrain_decoder_loader=self._pretrain_loader, + pretrain_decoder_loader=self._pretrain_loader_iter, contrastive_criterion=SupConLoss(), num_batches=self._num_batches, - cur_epoch=self._cur_epoch, device=self._device + cur_epoch=self._cur_epoch, device=self._device, ).run() self._scheduler.step() storage_dict = StorageIncomeDict(PRETRAIN_DECODER=pretrain_decoder_dict, ) @@ -136,12 +146,16 @@ def pretrain_decoder_run(self): def finetune_network_init(self, *args, **kwargs): - self._optimizer = torch.optim.Adam(self._model.parameters(), lr=1e-7, weight_decay=1e-5) + self._optimizer = torch.optim.Adam(self._model.parameters(), lr=1e-6, weight_decay=1e-5) self._scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self._optimizer, self._max_epoch_train_finetune - 10, 0) - self._scheduler = GradualWarmupScheduler(self._optimizer, 200, 10, self._scheduler) + self._scheduler = GradualWarmupScheduler(self._optimizer, 300, 10, self._scheduler) self._sup_criterion = KL_div() + # set augmentation method as `total_freedom = True` + self._fine_tune_loader.dataset._transform._total_freedom = True # noqa + self._fine_tune_loader_iter = iter(self._fine_tune_loader) # noqa + def finetune_network_run(self, epocher_type=SimpleFineTuneEpoch): self.to(self._device) self._model.enable_grad_decoder() # noqa @@ -157,11 +171,22 @@ def finetune_network_run(self, epocher_type=SimpleFineTuneEpoch): self._writer.add_scalar_with_StorageDict(storage_dict, self._cur_epoch) self.save(cur_score, os.path.join(self._save_dir, "finetune")) - def start_training(self, checkpoint: str = None): - + def start_training( + self, checkpoint: str = None, + pretrain_encoder_init_options=None, + pretrain_decoder_init_options=None, + finetune_network_init_options=None + ): + + if finetune_network_init_options is None: + finetune_network_init_options = {} + if pretrain_decoder_init_options is None: + pretrain_decoder_init_options = {} + if pretrain_encoder_init_options is None: + pretrain_encoder_init_options = {} with SummaryWriter(str(self._save_dir)) as self._writer: # noqa if self.train_encoder: - self.pretrain_encoder_init() + self.pretrain_encoder_init(**pretrain_encoder_init_options) if checkpoint is not None: try: self.load_state_dict_from_path(os.path.join(checkpoint, "pretrain_encoder")) @@ -171,7 +196,7 @@ def start_training(self, checkpoint: str = None): if not self.train_encoder_done: self.pretrain_encoder_run() if self.train_decoder: - self.pretrain_decoder_init() + self.pretrain_decoder_init(**pretrain_decoder_init_options) if checkpoint is not None: try: self.load_state_dict_from_path(os.path.join(checkpoint, "pretrain_decoder")) @@ -179,7 +204,7 @@ def start_training(self, checkpoint: str = None): print(f"loading pretrain_decoder_checkpoint failed with {e}, ") if not self.train_decoder_done: self.pretrain_decoder_run() - self.finetune_network_init() + self.finetune_network_init(**finetune_network_init_options) if checkpoint is not None: try: self.load_state_dict_from_path(os.path.join(checkpoint, "finetune")) @@ -190,21 +215,13 @@ def start_training(self, checkpoint: str = None): class ContrastTrainerMT(ContrastTrainer): - def __init__(self, model: nn.Module, pretrain_loader: T_loader, fine_tune_loader: T_loader, val_loader: DataLoader, - save_dir: str = "base", max_epoch_train_encoder: int = 100, max_epoch_train_decoder: int = 100, - max_epoch_train_finetune: int = 100, num_batches: int = 256, device: str = "cpu", configuration=None, - train_encoder: bool = True, train_decoder: bool = True, transform_axis=[1, 2], reg_weight=0.0, - reg_criterion=nn.MSELoss(), *args, **kwargs): - super().__init__(model, pretrain_loader, fine_tune_loader, val_loader, save_dir, max_epoch_train_encoder, - max_epoch_train_decoder, max_epoch_train_finetune, num_batches, device, configuration, - train_encoder, train_decoder, *args, **kwargs) - self._teacher_model = None - self._transform_axis = transform_axis + def finetune_network_init(self, reg_weight: float = 0.0, reg_criterion=nn.MSELoss(), transform_axis=[1, 2], *args, **kwargs): + super().finetune_network_init() + self._reg_weight = reg_weight self._reg_criterion = reg_criterion + self._transform_axis = transform_axis - def finetune_network_init(self, *args, **kwargs): - super().finetune_network_init(*args, **kwargs) from contrastyou.arch import UNet from deepclustering2.models import ema_updater # here we initialize the MT diff --git a/contrastyou/trainer/iic_trainer.py b/contrastyou/trainer/iic_trainer.py index 5ff6bc06..63552f16 100644 --- a/contrastyou/trainer/iic_trainer.py +++ b/contrastyou/trainer/iic_trainer.py @@ -19,7 +19,7 @@ class IICContrastTrainer(ContrastTrainer): RUN_PATH = Path(PROJECT_PATH) / "runs" - def pretrain_encoder_init(self, num_clusters=20, *args, **kwargs): + def pretrain_encoder_init(self, group_option, num_clusters=20): # adding optimizer and scheduler self._projector_contrastive = nn.Sequential( nn.AdaptiveAvgPool2d((1, 1)), diff --git a/main_contrast.py b/main_contrast.py index d0876609..1227c877 100644 --- a/main_contrast.py +++ b/main_contrast.py @@ -5,7 +5,7 @@ from contrastyou import DATA_PATH, PROJECT_PATH from contrastyou.arch import UNet -from contrastyou.augment import ACDC_transforms +from contrastyou.augment import ACDCTransforms from contrastyou.dataloader._seg_datset import ContrastBatchSampler # noqa from contrastyou.dataloader.acdc_dataset import ACDCSemiInterface, ACDCDataset from contrastyou.trainer import trainer_zoos @@ -28,11 +28,11 @@ unlabeled_data_ratio=config["Data"]["unlabeled_data_ratio"]) label_set, unlabel_set, val_set = acdc_manager._create_semi_supervised_datasets( # noqa - labeled_transform=ACDC_transforms.train, - unlabeled_transform=ACDC_transforms.train, - val_transform=ACDC_transforms.val + labeled_transform=ACDCTransforms.train, + unlabeled_transform=ACDCTransforms.train, + val_transform=ACDCTransforms.val ) -train_set = ACDCDataset(root_dir=DATA_PATH, mode="train", transforms=ACDC_transforms.train) +train_set = ACDCDataset(root_dir=DATA_PATH, mode="train", transforms=ACDCTransforms.train) # all training set is with ContrastBatchSampler train_loader = DataLoader(train_set, # noqa @@ -51,7 +51,9 @@ checkpoint = config.pop("Checkpoint", None) Trainer = trainer_zoos[config["Trainer"].pop("name")] assert Trainer, Trainer -trainer = Trainer(model=model, pretrain_loader=iter(train_loader), fine_tune_loader=iter(labeled_loader), +trainer = Trainer(model=model, pretrain_loader=train_loader, fine_tune_loader=labeled_loader, val_loader=val_loader, configuration=cmanager.config, **config["Trainer"], ) -trainer.start_training(checkpoint=checkpoint) +trainer.start_training(checkpoint=checkpoint, pretrain_encoder_init_options=config["PretrainEncoder"], + pretrain_decoder_init_options=config["PretrainDecoder"], + finetune_network_init_options=config["FineTune"]) diff --git a/references/deepclustering2 b/references/deepclustering2 index 1b3841ca..d66f8bfb 160000 --- a/references/deepclustering2 +++ b/references/deepclustering2 @@ -1 +1 @@ -Subproject commit 1b3841caabab895ff361d3cbe2ac1bbd0b4d0e68 +Subproject commit d66f8bfbc181858292c3e227d90c472cf4bd8269 diff --git a/run_script.py b/run_script.py index c6ca6a44..615371cf 100644 --- a/run_script.py +++ b/run_script.py @@ -5,10 +5,13 @@ parser = argparse.ArgumentParser() -parser.add_argument("--label_ratio", default=0.1, type=float) -parser.add_argument("--trainer_name", required=True, type=str) -parser.add_argument("--num_batches", default=500, type=int) -parser.add_argument("--random_seed", default=1, type=int) +parser.add_argument("-l", "--label_ratio", default=0.1, type=float) +parser.add_argument("-n", "--trainer_name", required=True, type=str) +parser.add_argument("-b", "--num_batches", default=100, type=int) +parser.add_argument("-s", "--random_seed", default=1, type=int) +parser.add_argument("-o", "--contrast_on", default="partition", type=str) +parser.add_argument("-w", "--reg_weight", default=0.0, type=float) + args = parser.parse_args() num_batches = args.num_batches @@ -19,9 +22,18 @@ # trainer_name="contrast" # or contrastMT trainer_name = args.trainer_name -save_dir = f"label_data_ration_{labeled_data_ratio}/{trainer_name}" +contrast_on = args.contrast_on +save_dir = f"label_data_ration_{labeled_data_ratio}/{trainer_name}/contrast_on_{contrast_on}" + +if trainer_name == "contrastMT": + save_dir = save_dir + f"/reg_weight_{args.reg_weight:.2f}" + +common_opts = f" Trainer.name={trainer_name} PretrainEncoder.group_option={contrast_on} RandomSeed={random_seed} " \ + f" Data.labeled_data_ratio={labeled_data_ratio} Data.unlabeled_data_ratio={unlabeled_data_ratio} " \ + f" Trainer.num_batches={num_batches} " +if trainer_name == "contrastMT": + common_opts += f" FineTune.reg_weight={args.reg_weight} " -common_opts = f" Trainer.name={trainer_name} RandomSeed={random_seed} Data.labeled_data_ratio={labeled_data_ratio} Data.unlabeled_data_ratio={unlabeled_data_ratio} Trainer.num_batches={num_batches} " jobs = [ f"python main_contrast.py {common_opts} Trainer.save_dir={save_dir}/baseline Trainer.train_encoder=False Trainer.train_decoder=False ", f"python main_contrast.py {common_opts} Trainer.save_dir={save_dir}/encoder Trainer.train_encoder=True Trainer.train_decoder=False ", @@ -31,8 +43,9 @@ # CC things accounts = cycle(["def-chdesa", "def-mpederso", "rrg-mpederso"]) -jobsubmiter = JobSubmiter(project_path="./", on_local=False, time=8) +jobsubmiter = JobSubmiter(project_path="./", on_local=False, time=4) for j in jobs: jobsubmiter.prepare_env(["source ./venv/bin/activate ", "export OMP_NUM_THREADS=1", ]) jobsubmiter.account = next(accounts) jobsubmiter.run(j) + # print(j) diff --git a/test/test_sequenetial_wrapper.py b/test/test_sequenetial_wrapper.py new file mode 100644 index 00000000..d3db02e2 --- /dev/null +++ b/test/test_sequenetial_wrapper.py @@ -0,0 +1,126 @@ +from io import BytesIO +from unittest import TestCase + +import matplotlib.pyplot as plt +import numpy as np +import requests +from PIL import Image +from torchvision import transforms + +from contrastyou.augment.sequential_wrapper import SequentialWrapper, SequentialWrapperTwice +from deepclustering2.augment import pil_augment +from deepclustering2.decorator import FixRandomSeed + +url = "https://www.sciencemag.org/sites/default/files/styles/article_main_image_-_1280w__no_aspect_/public/dogs_1280p_0.jpg?itok=6jQzdNB8" +response = requests.get(url) + + +class TestTransformationWrapper(TestCase): + def setUp(self) -> None: + super().setUp() + self._img1 = Image.open(BytesIO(response.content)) + self._img2 = Image.fromarray((255.0 - np.asarray(self._img1)).astype(np.uint8)) + self._target1 = Image.fromarray((np.asarray(self._img1) < 128).astype(np.uint8)).convert("L") + self._target2 = Image.fromarray((np.asarray(self._img1) >= 128).astype(np.uint8)).convert("L") + + def test_sqeuential_wrapper(self): + comm_transform = pil_augment.Compose([ + pil_augment.RandomCrop(224), + pil_augment.RandomRotation(23), + ]) + img_transform = pil_augment.ToTensor() + target_transform = pil_augment.ToLabel() + wrapper = SequentialWrapper(comm_transform=comm_transform, + img_transform=img_transform, + target_transform=target_transform) + + imgs, targets = wrapper(imgs=[self._img1, self._img2], targets=[self._target1, self._target2]) + + plt.imshow(imgs[0].numpy().transpose(1, 2, 0)) + plt.show() + plt.imshow(imgs[1].numpy().transpose(1, 2, 0)) + plt.show() + plt.imshow(targets[0].numpy().squeeze()) + plt.show() + + plt.imshow(targets[1].numpy().squeeze()) + plt.show() + + def test_on_double(self): + comm_transform = pil_augment.Compose([ + pil_augment.RandomCrop(224), + pil_augment.RandomRotation(23), + ]) + img_transform = pil_augment.Compose([ + transforms.ColorJitter(brightness=[0.5, 1.1]), + pil_augment.ToTensor() + ]) + target_transform = pil_augment.ToLabel() + wrapper = SequentialWrapper(comm_transform=comm_transform, + img_transform=img_transform, + target_transform=target_transform) + with FixRandomSeed(2): + imgs1, targets1 = wrapper(imgs=[self._img1, self._img2], targets=[self._target1, self._target2]) + with FixRandomSeed(2): + imgs2, targets2 = wrapper(imgs=[self._img1, self._img2], targets=[self._target1, self._target2], img_seed=3) + plt.imshow(imgs1[0].numpy().transpose(1, 2, 0)) + plt.show() + plt.imshow(targets1[0].numpy().squeeze()) + plt.show() + plt.imshow(imgs2[0].numpy().transpose(1, 2, 0)) + plt.show() + plt.imshow(targets2[0].numpy().squeeze()) + plt.show() + + def test_on_img(self): + comm_transform = pil_augment.Compose([ + pil_augment.RandomCrop(224), + pil_augment.RandomRotation(23), + ]) + img_transform = pil_augment.Compose([ + transforms.ColorJitter(brightness=[0.5, 1.1]), + pil_augment.ToTensor() + ]) + target_transform = pil_augment.ToLabel() + wrapper = SequentialWrapper(comm_transform=comm_transform, + img_transform=img_transform, + target_transform=target_transform) + imgs1, targets1 = wrapper(imgs=[self._img1, self._img2]) + + def test_on_twice_wrapper(self): + comm_transform = pil_augment.Compose([ + pil_augment.RandomCrop(224), + pil_augment.RandomRotation(23), + ]) + img_transform = pil_augment.Compose([ + transforms.ColorJitter(brightness=[0, 1.1], contrast=[0, 1.5]), + pil_augment.ToTensor() + ]) + target_transform = pil_augment.ToLabel() + wrapper = SequentialWrapperTwice(comm_transform=comm_transform, + img_transform=img_transform, + target_transform=target_transform, + total_freedom=True) + (imgs1, targets1), (imgs2, targets2) = wrapper(imgs=[self._img1, self._img2], targets=[self._target1]) + plt.imshow(imgs1[0].numpy().transpose(1, 2, 0)) + plt.show() + plt.imshow(targets1[0].numpy().squeeze()) + plt.show() + plt.imshow(imgs2[0].numpy().transpose(1, 2, 0)) + plt.show() + plt.imshow(targets2[0].numpy().squeeze()) + plt.show() + + wrapper = SequentialWrapperTwice(comm_transform=comm_transform, + img_transform=img_transform, + target_transform=target_transform, + total_freedom=False) + (imgs1, targets1), (imgs2, targets2) = wrapper(imgs=[self._img1, self._img2], targets=[self._target1]) + plt.imshow(imgs1[0].numpy().transpose(1, 2, 0)) + plt.show() + plt.imshow(targets1[0].numpy().squeeze()) + plt.show() + plt.imshow(imgs2[0].numpy().transpose(1, 2, 0)) + plt.show() + plt.imshow(targets2[0].numpy().squeeze()) + plt.show()