From f0d7cc95b015aec2eb749db8d4d7c0cd5b9e69fe Mon Sep 17 00:00:00 2001 From: jizong Date: Tue, 21 Jul 2020 16:56:25 -0400 Subject: [PATCH] adding mean teacher as a baseline --- byol_demo/__init__.py | 0 demo_cifar.py => byol_demo/byol_cifar.py | 30 +- byol_demo/utils.py | 17 ++ config/config.yaml | 15 +- contrastyou/epocher/__init__.py | 2 + contrastyou/epocher/_utils.py | 42 ++- contrastyou/epocher/base_epocher.py | 366 +++++++++++++---------- contrastyou/epocher/contrast_epocher.py | 69 +---- contrastyou/trainer/__init__.py | 5 +- contrastyou/trainer/base_trainer.py | 44 --- contrastyou/trainer/contrast_trainer.py | 75 +++-- main_contrast.py | 12 +- references/deepclustering2 | 2 +- run_script.py | 12 +- 14 files changed, 353 insertions(+), 338 deletions(-) create mode 100644 byol_demo/__init__.py rename demo_cifar.py => byol_demo/byol_cifar.py (94%) create mode 100644 byol_demo/utils.py delete mode 100644 contrastyou/trainer/base_trainer.py diff --git a/byol_demo/__init__.py b/byol_demo/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/demo_cifar.py b/byol_demo/byol_cifar.py similarity index 94% rename from demo_cifar.py rename to byol_demo/byol_cifar.py index c353c256..4f2f2672 100644 --- a/demo_cifar.py +++ b/byol_demo/byol_cifar.py @@ -8,6 +8,7 @@ from torch import nn from torch.utils.data import DataLoader +from byol_demo.utils import loss_fn, TransTwice from contrastyou import DATA_PATH, PROJECT_PATH from deepclustering2 import ModelMode from deepclustering2.augment import pil_augment @@ -19,22 +20,9 @@ from deepclustering2.tqdm import tqdm from deepclustering2.trainer.trainer import T_loader, Trainer from deepclustering2.writer import SummaryWriter -from torch.nn import functional as F -class TransTwice: - - def __init__(self, transform) -> None: - super().__init__() - self._transform = transform - - def __call__(self, img): - return [self._transform(img), self._transform(img)] - -def loss_fn(x, y): - x = F.normalize(x, dim=-1, p=2) - 
y = F.normalize(y, dim=-1, p=2) - return 2 - 2 * (x * y).sum(dim=-1) +# todo: redo that class ContrastEpocher(_Epocher): def __init__(self, model: Model, target_model: EMA_Model, data_loader: T_loader, num_batches: int = 1000, cur_epoch=0, device="cpu") -> None: @@ -84,6 +72,7 @@ def _preprocess_data(data, device): return (data[0][0].to(device), data[0][1].to(device)), data[1].to(device) +# todo: redo that class FineTuneEpocher(_Epocher): def __init__(self, model: Model, classify_model: Model, data_loader: T_loader, num_batches: int = 1000, cur_epoch=0, device="cpu") -> None: @@ -134,6 +123,7 @@ def _preprocess_data(data, device): return (data[0][0].to(device), data[0][1].to(device)), data[1].to(device) +# todo: redo that class EvalEpocher(FineTuneEpocher): def __init__(self, model: Model, classify_model: Model, val_loader, num_batches: int = 1000, cur_epoch=0, @@ -184,18 +174,6 @@ def __init__(self, model: Model, target_model: EMA_Model, classify_model: Model, self._finetune_loader = finetune_loader self._val_loader = val_loader - def pretrain_epoch(self, *args, **kwargs): - epocher = ContrastEpocher.create_from_trainer(self) - return epocher.run() - - def finetune_epoch(self, *args, **kwargs): - epocher = FineTuneEpocher.create_from_trainer(self) - return epocher.run() - - def eval_epoch(self): - epocher = EvalEpocher.create_from_trainer(self) - return epocher.run() - def _start_contrastive_training(self): save_path = os.path.join(self._save_dir, "pretrain") Path(save_path).mkdir(exist_ok=True, parents=True) diff --git a/byol_demo/utils.py b/byol_demo/utils.py new file mode 100644 index 00000000..ab71f3f5 --- /dev/null +++ b/byol_demo/utils.py @@ -0,0 +1,17 @@ +from torch.nn import functional as F + + +class TransTwice: + + def __init__(self, transform) -> None: + super().__init__() + self._transform = transform + + def __call__(self, img): + return [self._transform(img), self._transform(img)] + + +def loss_fn(x, y): + x = F.normalize(x, dim=-1, p=2) + y = 
F.normalize(y, dim=-1, p=2) + return 2 - 2 * (x * y).sum(dim=-1) diff --git a/config/config.yaml b/config/config.yaml index cb6b956d..0bd33345 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -4,18 +4,23 @@ Arch: input_dim: 1 num_classes: 4 +Data: + labeled_data_ratio: 0.05 + unlabeled_data_ratio: 0.95 + Trainer: - save_dir: test_pipeline + name: contrast + save_dir: test_semi_trainer device: cuda - num_batches: 1000 + num_batches: 500 max_epoch_train_decoder: 100 max_epoch_train_encoder: 100 max_epoch_train_finetune: 100 train_encoder: True train_decoder: True + # for mt trainer + transform_axis: [1, 2] + reg_weight: 10, -Data: - labeled_data_ratio: 0.05 - unlabeled_data_ratio: 0.95 #Checkpoint: runs/test_pipeline \ No newline at end of file diff --git a/contrastyou/epocher/__init__.py b/contrastyou/epocher/__init__.py index e69de29b..ad3cfe8c 100644 --- a/contrastyou/epocher/__init__.py +++ b/contrastyou/epocher/__init__.py @@ -0,0 +1,2 @@ +from .base_epocher import EvalEpoch, SimpleFineTuneEpoch, MeanTeacherEpocher +from .contrast_epocher import PretrainDecoderEpoch, PretrainEncoderEpoch \ No newline at end of file diff --git a/contrastyou/epocher/_utils.py b/contrastyou/epocher/_utils.py index 64489049..167357de 100644 --- a/contrastyou/epocher/_utils.py +++ b/contrastyou/epocher/_utils.py @@ -1,14 +1,19 @@ +import random + from deepclustering2.type import to_device, torch +from deepclustering2.utils import assert_list +from torch import Tensor -def preprocess_input_with_twice_transformation(data, device): +def preprocess_input_with_twice_transformation(data, device, non_blocking=True): [(image, target), (image_tf, target_tf)], filename, partition_list, group_list = \ - to_device(data[0], device), data[1], data[2], data[3] + to_device(data[0], device, non_blocking), data[1], data[2], data[3] return (image, target), (image_tf, target_tf), filename, partition_list, group_list -def preprocess_input_with_single_transformation(data, device): - return 
data[0][0].to(device), data[0][1].to(device), data[1], data[2], data[3] +def preprocess_input_with_single_transformation(data, device, non_blocking=True): + return data[0][0].to(device, non_blocking=non_blocking), data[0][1].to(device, non_blocking=non_blocking), data[1], \ + data[2], data[3] def unfold_position(features: torch.Tensor, partition_num=(4, 4), ): @@ -27,6 +32,35 @@ def unfold_position(features: torch.Tensor, partition_num=(4, 4), ): return torch.cat(result, dim=0), result_flag +class TensorRandomFlip: + def __init__(self, axis=None) -> None: + if isinstance(axis, int): + self._axis = [axis] + elif isinstance(axis, (list, tuple)): + assert_list(lambda x: isinstance(x, int), axis), axis + self._axis = axis + elif axis is None: + self._axis = axis + else: + raise ValueError(str(axis)) + + def __call__(self, tensor: Tensor): + tensor = tensor.clone() + if self._axis is not None: + for _one_axis in self._axis: + if random.random() < 0.5: + tensor = tensor.flip(_one_axis) + return tensor + else: + return tensor + + def __repr__(self): + string = f"{self.__class__.__name__}" + axis = "" if not self._axis else f" with axis={self._axis}." 
+ + return string + axis + + if __name__ == '__main__': features = torch.randn(10, 3, 256, 256, requires_grad=True) diff --git a/contrastyou/epocher/base_epocher.py b/contrastyou/epocher/base_epocher.py index 46fb49ff..f8595ea3 100644 --- a/contrastyou/epocher/base_epocher.py +++ b/contrastyou/epocher/base_epocher.py @@ -1,180 +1,216 @@ +import random from typing import Union, Tuple import torch -from torch.utils.data import DataLoader - -from deepclustering2 import ModelMode +from deepclustering2.decorator import FixRandomSeed from deepclustering2.epoch import _Epocher, proxy_trainer # noqa from deepclustering2.loss import simplex from deepclustering2.meters2 import EpochResultDict, MeterInterface, AverageValueMeter, UniversalDice from deepclustering2.models import Model from deepclustering2.tqdm import tqdm -from deepclustering2.trainer.trainer import T_loader, T_loss +from deepclustering2.trainer.trainer import T_loss, T_optim, T_loader from deepclustering2.utils import class2one_hot -from ._utils import preprocess_input_with_single_transformation, preprocess_input_with_twice_transformation - - -class FSEpocher: - class TrainEpoch(_Epocher): - - def __init__(self, model: Model, data_loader: T_loader, sup_criteiron: T_loss, num_batches: int = 100, - cur_epoch=0, device="cpu") -> None: - super().__init__(model, cur_epoch, device) - self._data_loader = data_loader - self._sup_criterion = sup_criteiron - self._num_batches = num_batches - - @classmethod - @proxy_trainer - def create_from_trainer(cls, trainer): - return cls(trainer._model, trainer._tra_loader, trainer._sup_criterion, trainer._num_batches, # noqa - trainer._cur_epoch, trainer._device) # noqa - - def _configure_meters(self, meters: MeterInterface) -> MeterInterface: - meters.register_meter("lr", AverageValueMeter()) - meters.register_meter("sup_loss", AverageValueMeter()) - meters.register_meter("ds", UniversalDice(4, [1, 2, 3])) - return meters - - def _run(self, *args, **kwargs) -> 
Union[EpochResultDict, Tuple[EpochResultDict, float]]: - self._model.set_mode(ModelMode.TRAIN) - assert self._model.training, self._model.training - self.meters["lr"].add(self._model.get_lr()[0]) - with tqdm(range(self._num_batches)).set_desc_from_epocher(self) as indicator: # noqa - for i, data in zip(indicator, self._data_loader): - images, targets, filename, partition_list, group_list = self._preprocess_data(data, self._device) - predict_logits = self._model(images) - assert not simplex(predict_logits), predict_logits.shape - onehot_targets = class2one_hot(targets.squeeze(1), 4) - loss = self._sup_criterion(predict_logits.softmax(1), onehot_targets) - self._model.zero_grad() - loss.backward() - self._model.step() - with torch.no_grad(): - self.meters["sup_loss"].add(loss.item()) - self.meters["ds"].add(predict_logits.max(1)[1], targets.squeeze(1), group_name=list(group_list)) +from torch import nn +from torch.utils.data import DataLoader + +from ._utils import preprocess_input_with_single_transformation, preprocess_input_with_twice_transformation, \ + TensorRandomFlip + + +class EvalEpoch(_Epocher): + + def __init__(self, model: Union[Model, nn.Module], val_loader: DataLoader, sup_criterion: T_loss, cur_epoch=0, + device="cpu") -> None: + """ + :param model: Model or nn.Module instance, network + :param val_loader: validation loader that is an instance of DataLoader, without infinitesampler + :param sup_criterion: Supervised loss to record the val_loss + :param cur_epoch: current epoch to record + :param device: cuda or cpu + """ + super().__init__(model, cur_epoch, device) + assert isinstance(val_loader, DataLoader), f"`val_loader` should be an instance of `DataLoader`, " \ + f"given {val_loader.__class__.__name__}" + assert callable(sup_criterion), f"sup_criterion must be callable, given {sup_criterion.__class__.__name__}" + self._val_loader = val_loader + self._sup_criterion = sup_criterion + + @classmethod + def create_from_trainer(cls, trainer): + pass + + 
def _configure_meters(self, meters: MeterInterface) -> MeterInterface: + meters.register_meter("lr", AverageValueMeter()) + meters.register_meter("sup_loss", AverageValueMeter()) + meters.register_meter("ds", UniversalDice(4, [1, 2, 3])) + return meters + + @torch.no_grad() + def _run(self, *args, **kwargs) -> Tuple[EpochResultDict, float]: + self._model.eval() + assert not self._model.training, self._model.training + with tqdm(self._val_loader).set_desc_from_epocher(self) as indicator: + for i, data in enumerate(indicator): + images, targets, filename, partiton_list, group_list = self._preprocess_data(data, self._device) + predict_logits = self._model(images) + assert not simplex(predict_logits), predict_logits.shape + onehot_targets = class2one_hot(targets.squeeze(1), 4) + loss = self._sup_criterion(predict_logits.softmax(1), onehot_targets, disable_assert=True) + self.meters["sup_loss"].add(loss.item()) + self.meters["ds"].add(predict_logits.max(1)[1], targets.squeeze(1), group_name=list(group_list)) + report_dict = self.meters.tracking_status() + indicator.set_postfix_dict(report_dict) + report_dict = self.meters.tracking_status() + return report_dict, report_dict["ds"]["DSC_mean"] + + @staticmethod + def _preprocess_data(data, device, non_blocking=True): + return preprocess_input_with_single_transformation(data, device, non_blocking) + + +class SimpleFineTuneEpoch(_Epocher): + def __init__(self, model: nn.Module, optimizer: T_optim, labeled_loader: T_loader, num_batches: int = 100, + sup_criterion: T_loss = None, cur_epoch=0, device="cpu") -> None: + super().__init__(model, cur_epoch, device) + assert isinstance(num_batches, int) and num_batches > 0, num_batches + assert callable(sup_criterion), sup_criterion + self._labeled_loader = labeled_loader + self._sup_criterion = sup_criterion + self._num_batches = num_batches + self._optimizer = optimizer + + @classmethod + def create_from_trainer(cls, trainer): + return cls( + model=trainer._model, 
optimizer=trainer._optimizer, labeled_loader=trainer._fine_tune_loader, + sup_criterion=trainer._sup_criterion, num_batches=trainer._num_batches, cur_epoch=trainer._cur_epoch, + device=trainer._device + ) + + def _configure_meters(self, meters: MeterInterface) -> MeterInterface: + meters.register_meter("lr", AverageValueMeter()) + meters.register_meter("sup_loss", AverageValueMeter()) + meters.register_meter("ds", UniversalDice(4, [1, 2, 3])) + return meters + + def _run(self, *args, **kwargs) -> EpochResultDict: + self._model.train() + assert self._model.training, self._model.training + report_dict: EpochResultDict + self.meters["lr"].add(self._optimizer.param_groups[0]["lr"]) + with tqdm(range(self._num_batches)).set_desc_from_epocher(self) as indicator: + for i, label_data in zip(indicator, self._labeled_loader): + (labelimage, labeltarget), _, filename, partition_list, group_list \ + = self._preprocess_data(label_data, self._device) + predict_logits = self._model(labelimage) + assert not simplex(predict_logits), predict_logits + + onehot_ltarget = class2one_hot(labeltarget.squeeze(1), 4) + sup_loss = self._sup_criterion(predict_logits.softmax(1), onehot_ltarget) + + self._optimizer.zero_grad() + sup_loss.backward() + self._optimizer.step() + + with torch.no_grad(): + self.meters["sup_loss"].add(sup_loss.item()) + self.meters["ds"].add(predict_logits.max(1)[1], labeltarget.squeeze(1), + group_name=list(group_list)) report_dict = self.meters.tracking_status() indicator.set_postfix_dict(report_dict) report_dict = self.meters.tracking_status() - return report_dict - - @staticmethod - def _preprocess_data(data, device): - return preprocess_input_with_single_transformation(data, device) - - class EvalEpoch(TrainEpoch): - def __init__(self, model: Model, val_data_loader: T_loader, sup_criterion, cur_epoch=0, device="cpu"): - super().__init__(model=model, data_loader=val_data_loader, sup_criteiron=sup_criterion, - num_batches=None, cur_epoch=cur_epoch, device=device) # 
noqa - assert isinstance(val_data_loader, DataLoader), type(val_data_loader) - - @classmethod - @proxy_trainer - def create_from_trainer(cls, trainer): - return cls(trainer._model, trainer._val_loader, trainer._sup_criterion, trainer._cur_epoch, # noqa - trainer._device) # noqa - - def _configure_meters(self, meters: MeterInterface) -> MeterInterface: - super()._configure_meters(meters) - meters.delete_meters(["lr"]) - return meters - - @torch.no_grad() - def _run(self, *args, **kwargs) -> Union[EpochResultDict, Tuple[EpochResultDict, float]]: - self._model.eval() - assert not self._model.training, self._model.training - with tqdm(range(len(self._data_loader))).set_desc_from_epocher(self) as indicator: - for i, data in zip(indicator, self._data_loader): - images, targets, filename, partiton_list, group_list = self._preprocess_data(data, self._device) - predict_logits = self._model(images) - assert not simplex(predict_logits), predict_logits.shape - onehot_targets = class2one_hot(targets.squeeze(1), 4) - loss = self._sup_criterion(predict_logits.softmax(1), onehot_targets, disable_assert=True) - self.meters["sup_loss"].add(loss.item()) - self.meters["ds"].add(predict_logits.max(1)[1], targets.squeeze(1), group_name=list(group_list)) + return report_dict + + @staticmethod + def _preprocess_data(data, device): + return preprocess_input_with_twice_transformation(data, device) + + +class MeanTeacherEpocher(SimpleFineTuneEpoch): + + def __init__(self, model: nn.Module, teacher_model: nn.Module, optimizer: T_optim, labeled_loader: T_loader, + tra_loader: T_loader, num_batches: int = 100, sup_criterion: T_loss = None, + reg_criterion: T_loss = None, cur_epoch=0, device="cpu", transform_axis=None, + reg_weight: float = 0.0, ema_updater = None) -> None: + super().__init__(model, optimizer, labeled_loader, num_batches, sup_criterion, cur_epoch, device) + self._teacher_model = teacher_model + assert callable(reg_criterion), reg_weight + self._reg_criterion = reg_criterion + 
self._tra_loader = tra_loader + self._transformer = TensorRandomFlip(transform_axis) + print(self._transformer) + self._reg_weight = float(reg_weight) + assert ema_updater + self._ema_updater = ema_updater + + @classmethod + def create_from_trainer(cls, trainer): + return cls(model=trainer._model, teacher_model=trainer._teacher_model, optimizer=trainer._optimizer, + labeled_loader=trainer._fine_tune_loader, tra_loader=trainer._pretrain_loader, + sup_criterion=trainer._sup_criterion, reg_criterion=trainer._reg_criterion, + cur_epoch=trainer._cur_epoch, device=trainer._device, transform_axis=trainer._transform_axis, + reg_weight=trainer._reg_weight, ema_updater=trainer._ema_updater) + + def _configure_meters(self, meters: MeterInterface) -> MeterInterface: + meters = super()._configure_meters(meters) + meters.register_meter("reg_loss", AverageValueMeter()) + meters.register_meter("reg_weight", AverageValueMeter()) + return meters + + + + def _run(self, *args, **kwargs) -> EpochResultDict: + self._model.train() + self._teacher_model.train() + assert self._model.training, self._model.training + assert self._teacher_model.training, self._teacher_model.training + self.meters["lr"].add(self._optimizer.param_groups[0]["lr"]) + self.meters["reg_weight"].add(self._reg_weight) + report_dict: EpochResultDict + + with tqdm(range(self._num_batches)).set_desc_from_epocher(self) as indicator: + for i, label_data, all_data in zip(indicator, self._labeled_loader, self._tra_loader): + (labelimage, labeltarget), _, filename, partition_list, group_list \ + = self._preprocess_data(label_data, self._device) + (unlabelimage, _), *_ = self._preprocess_data(label_data, self._device) + + seed = random.randint(0, int(1e6)) + with FixRandomSeed(seed): + unlabelimage_tf = torch.stack([self._transformer(x) for x in unlabelimage], dim=0) + assert unlabelimage_tf.shape == unlabelimage.shape + + student_logits = self._model(torch.cat([labelimage, unlabelimage_tf], dim=0)) + if 
simplex(student_logits): + raise RuntimeError("output of the model should be logits, instead of simplex") + student_sup_logits, student_unlabel_logits_tf = student_logits[:len(labelimage)], \ + student_logits[len(labelimage):] + + with torch.no_grad(): + teacher_unlabel_logits = self._teacher_model(unlabelimage) + with FixRandomSeed(seed): + teacher_unlabel_logits_tf = torch.stack([self._transformer(x) for x in teacher_unlabel_logits]) + assert teacher_unlabel_logits.shape == teacher_unlabel_logits_tf.shape + + # calculate the loss + onehot_ltarget = class2one_hot(labeltarget.squeeze(1), 4) + sup_loss = self._sup_criterion(student_sup_logits.softmax(1), onehot_ltarget) + + reg_loss = self._reg_criterion(student_unlabel_logits_tf.softmax(1), teacher_unlabel_logits_tf.detach().softmax(1)) + total_loss = sup_loss + self._reg_weight * reg_loss + + self._optimizer.zero_grad() + total_loss.backward() + self._optimizer.step() + + # update ema + self._ema_updater(ema_model=self._teacher_model, student_model= self._model) + + with torch.no_grad(): + self.meters["sup_loss"].add(sup_loss.item()) + self.meters["reg_loss"].add(reg_loss.item()) + self.meters["ds"].add(student_sup_logits.max(1)[1], labeltarget.squeeze(1), + group_name=list(group_list)) report_dict = self.meters.tracking_status() indicator.set_postfix_dict(report_dict) report_dict = self.meters.tracking_status() - return report_dict, report_dict["ds"]["DSC_mean"] - - @staticmethod - def _preprocess_data(data, device): - return preprocess_input_with_single_transformation(data, device) - - -class SemiEpocher: - class TrainEpoch(_Epocher): - - def __init__(self, model: Model, labeled_loader: T_loader, unlabeled_loader: T_loader, sup_criteiron: T_loss, - reg_criterion: T_loss, num_batches: int = 100, cur_epoch=0, device="cpu", - reg_weight: float = 0.001) -> None: - super().__init__(model, cur_epoch, device) - assert isinstance(num_batches, int) and num_batches > 0, num_batches - self._labeled_loader = labeled_loader - 
self._unlabeled_loader = unlabeled_loader - self._sup_criterion = sup_criteiron - self._reg_criterion = reg_criterion - self._num_batches = num_batches - self._reg_weight = reg_weight - - @classmethod - @proxy_trainer - def create_from_trainer(cls, trainer): - return cls(trainer._model, trainer._labeled_loader, trainer._unlabeled_loader, trainer._sup_criterion, - trainer._reg_criterion, trainer._num_batches, trainer._cur_epoch, trainer._device, - trainer._reg_weight) - - def _configure_meters(self, meters: MeterInterface) -> MeterInterface: - meters.register_meter("lr", AverageValueMeter()) - meters.register_meter("sup_loss", AverageValueMeter()) - meters.register_meter("reg_weight", AverageValueMeter()) - meters.register_meter("reg_loss", AverageValueMeter()) - meters.register_meter("ds", UniversalDice(4, [1, 2, 3])) - return meters - - def _run(self, *args, **kwargs) -> EpochResultDict: - self._model.set_mode(ModelMode.TRAIN) - assert self._model.training, self._model.training - report_dict: EpochResultDict - self.meters["lr"].add(self._model.get_lr()[0]) - self.meters["reg_weight"].add(self._reg_weight) - - with tqdm(range(self._num_batches)).set_desc_from_epocher(self) as indicator: - for i, label_data, unlabel_data in zip(indicator, self._labeled_loader, self._unlabeled_loader): - (labelimage, labeltarget), (labelimage_tf, labeltarget_tf), filename, partition_list, group_list, ( - unlabelimage, unlabelimage_tf) = self._preprocess_data(label_data, unlabel_data, self._device) - predict_logits = self._model( - torch.cat([labelimage, labelimage_tf, unlabelimage, unlabelimage_tf], dim=0), - force_simplex=False) - assert not simplex(predict_logits), predict_logits - label_logit, label_logit_tf, unlabel_logit, unlabel_logit_tf \ - = torch.split(predict_logits, - [len(labelimage), len(labelimage_tf), len(unlabelimage), len(unlabelimage_tf)], - dim=0) - onehot_ltarget = class2one_hot(torch.cat([labeltarget.squeeze(), labeltarget_tf.squeeze()], dim=0), - 4) - sup_loss = 
self._sup_criterion(torch.cat([label_logit, label_logit_tf], dim=0).softmax(1), - onehot_ltarget) - reg_loss = self._reg_criterion(unlabel_logit.softmax(1), unlabel_logit_tf.softmax(1)) - total_loss = sup_loss + reg_loss * self._reg_weight - - self._model.zero_grad() - total_loss.backward() - self._model.step() - - with torch.no_grad(): - self.meters["sup_loss"].add(sup_loss.item()) - self.meters["ds"].add(label_logit.max(1)[1], labeltarget.squeeze(1), - group_name=list(group_list)) - self.meters["reg_loss"].add(reg_loss.item()) - report_dict = self.meters.tracking_status() - indicator.set_postfix_dict(report_dict) - report_dict = self.meters.tracking_status() - return report_dict - - @staticmethod - def _preprocess_data(data, device): - return preprocess_input_with_twice_transformation(data, device) - - class EvalEpoch(FSEpocher.EvalEpoch): - pass + return report_dict diff --git a/contrastyou/epocher/contrast_epocher.py b/contrastyou/epocher/contrast_epocher.py index b3367e79..e5533e3f 100644 --- a/contrastyou/epocher/contrast_epocher.py +++ b/contrastyou/epocher/contrast_epocher.py @@ -9,6 +9,7 @@ from deepclustering2.loss import KL_div from deepclustering2.meters2 import EpochResultDict, MeterInterface, AverageValueMeter, UniversalDice from deepclustering2.tqdm import tqdm +from deepclustering2.optim import get_lrs_from_optimizer from deepclustering2.trainer.trainer import T_loader, T_loss, T_optim from deepclustering2.utils import simplex, class2one_hot, np from ._utils import preprocess_input_with_twice_transformation, unfold_position @@ -41,7 +42,7 @@ def _configure_meters(self, meters: MeterInterface) -> MeterInterface: def _run(self, *args, **kwargs) -> EpochResultDict: self._model.train() assert self._model.training, self._model.training - self.meters["lr"].add(self._optimizer.param_groups[0]["lr"]) + self.meters["lr"].add(get_lrs_from_optimizer(self._optimizer)[0]) with tqdm(range(self._num_batches)).set_desc_from_epocher(self) as indicator: # noqa for i, 
data in zip(indicator, self._pretrain_encoder_loader): @@ -49,7 +50,7 @@ def _run(self, *args, **kwargs) -> EpochResultDict: _, (e5, *_), *_ = self._model(torch.cat([img, img_tf], dim=0), return_features=True) global_enc, global_tf_enc = torch.chunk(F.normalize(self._projection_head(e5), dim=1), chunks=2, dim=0) # todo: convert representation to distance - labels = self._mask_generation(partition_list, group_list) + labels = self._label_generation(partition_list, group_list) contrastive_loss = self._contrastive_criterion(torch.stack([global_enc, global_tf_enc], dim=1), labels=labels) self._optimizer.zero_grad() @@ -67,7 +68,7 @@ def _preprocess_data(data, device): return preprocess_input_with_twice_transformation(data, device) @staticmethod - def _mask_generation(partition_list: List[str], group_list: List[str]): + def _label_generation(partition_list: List[str], group_list: List[str]): """override this to provide more mask """ return partition_list @@ -85,7 +86,7 @@ def __init__(self, model: nn.Module, projection_head: nn.Module, optimizer: opti def _run(self, *args, **kwargs) -> EpochResultDict: self._model.train() assert self._model.training, self._model.training - self.meters["lr"].add(self._optimizer.param_groups[0]["lr"]) + self.meters["lr"].add(get_lrs_from_optimizer(self._optimizer)[0]) with tqdm(range(self._num_batches)).set_desc_from_epocher(self) as indicator: # noqa for i, data in zip(indicator, self._pretrain_decoder_loader): @@ -99,13 +100,13 @@ def _run(self, *args, **kwargs) -> EpochResultDict: local_enc_unfold_norm = F.normalize(local_enc_unfold.view(b, -1), p=2, dim=1) local_tf_enc_unfold_norm = F.normalize(local_tf_enc_unfold.view(b, -1), p=2, dim=1) - labels = self._mask_generation(partition_list, group_list, _fold_partition) + labels = self._label_generation(partition_list, group_list, _fold_partition) contrastive_loss = self._contrastive_criterion( torch.stack([local_enc_unfold_norm, local_tf_enc_unfold_norm], dim=1), labels=labels ) if 
torch.isnan(contrastive_loss): - raise RuntimeError() + raise RuntimeError(contrastive_loss) self._optimizer.zero_grad() contrastive_loss.backward() self._optimizer.step() @@ -117,7 +118,7 @@ def _run(self, *args, **kwargs) -> EpochResultDict: return report_dict @staticmethod - def _mask_generation(partition_list: List[str], group_list: List[str], folder_partition: List[str]): + def _label_generation(partition_list: List[str], group_list: List[str], folder_partition: List[str]): if len(folder_partition) > len(partition_list): ratio = int(len(folder_partition) / len(partition_list)) partition_list = partition_list.tolist() * ratio @@ -131,57 +132,3 @@ def tolabel(encode): return tolabel([str(g) + str(p) + str(f) for g, p, f in zip(group_list, partition_list, folder_partition)]) -class FineTuneEpoch(_Epocher): - def __init__(self, model: nn.Module, optimizer: T_optim, labeled_loader: T_loader, num_batches: int = 100, - cur_epoch=0, device="cpu") -> None: - super().__init__(model, cur_epoch, device) - assert isinstance(num_batches, int) and num_batches > 0, num_batches - self._labeled_loader = labeled_loader - self._sup_criterion = KL_div() - self._num_batches = num_batches - self._optimizer = optimizer - - @classmethod - @proxy_trainer - def create_from_trainer(cls, trainer): - return cls(trainer._model, trainer._labeled_loader, trainer._unlabeled_loader, trainer._sup_criterion, - trainer._reg_criterion, trainer._num_batches, trainer._cur_epoch, trainer._device, - trainer._reg_weight) - - def _configure_meters(self, meters: MeterInterface) -> MeterInterface: - meters.register_meter("lr", AverageValueMeter()) - meters.register_meter("sup_loss", AverageValueMeter()) - meters.register_meter("ds", UniversalDice(4, [1, 2, 3])) - return meters - - def _run(self, *args, **kwargs) -> EpochResultDict: - self._model.train() - assert self._model.training, self._model.training - report_dict: EpochResultDict - self.meters["lr"].add(self._optimizer.param_groups[0]["lr"]) - with 
tqdm(range(self._num_batches)).set_desc_from_epocher(self) as indicator: - for i, label_data in zip(indicator, self._labeled_loader): - (labelimage, labeltarget), _, filename, partition_list, group_list \ - = self._preprocess_data(label_data, self._device) - predict_logits = self._model(labelimage) - assert not simplex(predict_logits), predict_logits - - onehot_ltarget = class2one_hot(labeltarget.squeeze(1), 4) - sup_loss = self._sup_criterion(predict_logits.softmax(1), onehot_ltarget) - - self._optimizer.zero_grad() - sup_loss.backward() - self._optimizer.step() - - with torch.no_grad(): - self.meters["sup_loss"].add(sup_loss.item()) - self.meters["ds"].add(predict_logits.max(1)[1], labeltarget.squeeze(1), - group_name=list(group_list)) - report_dict = self.meters.tracking_status() - indicator.set_postfix_dict(report_dict) - report_dict = self.meters.tracking_status() - return report_dict - - @staticmethod - def _preprocess_data(data, device): - return preprocess_input_with_twice_transformation(data, device) diff --git a/contrastyou/trainer/__init__.py b/contrastyou/trainer/__init__.py index c34700ab..b7435db0 100644 --- a/contrastyou/trainer/__init__.py +++ b/contrastyou/trainer/__init__.py @@ -1,4 +1,3 @@ -from .base_trainer import FSTrainer, SemiTrainer -from .contrast_trainer import ContrastTrainer +from .contrast_trainer import ContrastTrainer, ContrastTrainerMT -trainer_zoos = {"fs": FSTrainer, "semi": SemiTrainer, "contrast": ContrastTrainer} +trainer_zoos = {"contrast": ContrastTrainer, "contrastMT":ContrastTrainerMT} diff --git a/contrastyou/trainer/base_trainer.py b/contrastyou/trainer/base_trainer.py deleted file mode 100644 index ace1412c..00000000 --- a/contrastyou/trainer/base_trainer.py +++ /dev/null @@ -1,44 +0,0 @@ -from pathlib import Path -from typing import Tuple - -from torch.utils.data import DataLoader - -from contrastyou import PROJECT_PATH -from contrastyou.epocher.base_epocher import FSEpocher, SemiEpocher -from 
deepclustering2.epoch._epocher import _Epocher # noqa -from deepclustering2.meters2 import EpochResultDict -from deepclustering2.models import Model -from deepclustering2.trainer.trainer import Trainer, T_loader, T_loss - - -class FSTrainer(Trainer): - RUN_PATH = Path(PROJECT_PATH) / "runs" - - def __init__(self, model: Model, tra_loader: T_loader, labeled_loader: T_loader, unlabeled_loader: T_loader, - val_loader: DataLoader, - sup_criterion: T_loss, reg_criterion=T_loss, save_dir: str = "base", max_epoch: int = 100, - num_batches: int = None, reg_weight=0.0001, - device: str = "cpu", configuration=None): - super().__init__(model, save_dir, max_epoch, num_batches, device, configuration) - self._tra_loader = tra_loader - self._labeled_loader = labeled_loader - self._unlabeled_loader = unlabeled_loader - self._val_loader = val_loader - self._sup_criterion = sup_criterion - self._reg_criterion = reg_criterion - self._reg_weight = reg_weight - - def _run_epoch(self, epocher: _Epocher = FSEpocher.TrainEpoch, *args, **kwargs) -> EpochResultDict: - return super()._run_epoch(epocher, *args, **kwargs) - - def _eval_epoch(self, epocher: _Epocher = FSEpocher.EvalEpoch, *args, **kwargs) -> Tuple[EpochResultDict, float]: - eval_epocher = epocher.create_from_trainer(trainer=self) - return eval_epocher.run() - - -class SemiTrainer(FSTrainer): - def _run_epoch(self, epocher: _Epocher = SemiEpocher.TrainEpoch, *args, **kwargs) -> EpochResultDict: - return super()._run_epoch(epocher, *args, **kwargs) - - def _eval_epoch(self, epocher: _Epocher = SemiEpocher.EvalEpoch, *args, **kwargs) -> Tuple[EpochResultDict, float]: - return super()._eval_epoch(epocher, *args, **kwargs) diff --git a/contrastyou/trainer/contrast_trainer.py b/contrastyou/trainer/contrast_trainer.py index eda17e1a..49da2cae 100644 --- a/contrastyou/trainer/contrast_trainer.py +++ b/contrastyou/trainer/contrast_trainer.py @@ -3,12 +3,9 @@ from pathlib import Path import torch -from torch import nn -from 
torch.utils.data import DataLoader - from contrastyou import PROJECT_PATH -from contrastyou.epocher.base_epocher import FSEpocher -from contrastyou.epocher.contrast_epocher import PretrainEncoderEpoch, PretrainDecoderEpoch, FineTuneEpoch +from contrastyou.epocher import PretrainEncoderEpoch, PretrainDecoderEpoch, SimpleFineTuneEpoch, MeanTeacherEpocher +from contrastyou.epocher.base_epocher import EvalEpoch from contrastyou.losses.contrast_loss import SupConLoss from contrastyou.trainer._utils import Flatten from deepclustering2.loss import KL_div @@ -16,6 +13,8 @@ from deepclustering2.schedulers import GradualWarmupScheduler from deepclustering2.trainer.trainer import Trainer, T_loader from deepclustering2.writer import SummaryWriter +from torch import nn +from torch.utils.data import DataLoader class ContrastTrainer(Trainer): @@ -24,7 +23,7 @@ class ContrastTrainer(Trainer): def __init__(self, model: nn.Module, pretrain_loader: T_loader, fine_tune_loader: T_loader, val_loader: DataLoader, save_dir: str = "base", max_epoch_train_encoder: int = 100, max_epoch_train_decoder: int = 100, max_epoch_train_finetune: int = 100, num_batches: int = 256, device: str = "cpu", configuration=None, - train_encoder: bool = True, train_decoder: bool = True): + train_encoder: bool = True, train_decoder: bool = True,*args,**kwargs): """ ContrastTraining Trainer :param model: nn.module network to be pretrained @@ -61,8 +60,9 @@ def __init__(self, model: nn.Module, pretrain_loader: T_loader, fine_tune_loader self._optimizer = None self._scheduler = None self._projector = None + self._sup_criterion = None - def pretrain_encoder_init(self): + def pretrain_encoder_init(self, *args, **kwargs): # adding optimizer and scheduler self._projector = nn.Sequential( nn.AdaptiveAvgPool2d((1, 1)), @@ -97,7 +97,7 @@ def pretrain_encoder_run(self): self._save_to("last.pth", path=os.path.join(self._save_dir, "pretrain_encoder")) self.train_encoder_done = True - def pretrain_decoder_init(self): + def 
pretrain_decoder_init(self, *args, **kwargs): # adding optimizer and scheduler self._projector = nn.Sequential( nn.Conv2d(64, 64, 3, 1, 1), @@ -133,27 +133,23 @@ def pretrain_decoder_run(self): self._save_to("last.pth", path=os.path.join(self._save_dir, "pretrain_decoder")) self.train_decoder_done = True - def finetune_network_init(self): + def finetune_network_init(self, *args, **kwargs): self._optimizer = torch.optim.Adam(self._model.parameters(), lr=1e-7, weight_decay=1e-5) self._scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self._optimizer, self._max_epoch_train_finetune - 10, 0) self._scheduler = GradualWarmupScheduler(self._optimizer, 200, 10, self._scheduler) + self._sup_criterion = KL_div() - def finetune_network_run(self): + def finetune_network_run(self, epocher_type=SimpleFineTuneEpoch): self.to(self._device) self._model.enable_grad_decoder() # noqa self._model.enable_grad_encoder() # noqa for self._cur_epoch in range(self._start_epoch, self._max_epoch_train_finetune): - finetune_dict = FineTuneEpoch( - model=self._model, optimizer=self._optimizer, - labeled_loader=self._fine_tune_loader, num_batches=self._num_batches, - cur_epoch=self._cur_epoch, device=self._device - ).run() - val_dict, cur_score = FSEpocher.EvalEpoch(self._model, val_data_loader=self._val_loader, - sup_criterion=KL_div(), - cur_epoch=self._cur_epoch, device=self._device).run() + finetune_dict = epocher_type.create_from_trainer(self).run() + val_dict, cur_score = EvalEpoch(self._model, val_loader=self._val_loader, sup_criterion=self._sup_criterion, + cur_epoch=self._cur_epoch, device=self._device).run() self._scheduler.step() storage_dict = StorageIncomeDict(finetune=finetune_dict, val=val_dict) self._finetune_storage.put_from_dict(storage_dict, epoch=self._cur_epoch) @@ -189,3 +185,46 @@ def start_training(self, checkpoint: str = None): except Exception as e: print(f"loading finetune_checkpoint failed with {e}, ") self.finetune_network_run() + + +class 
ContrastTrainerMT(ContrastTrainer): + + def __init__(self, model: nn.Module, pretrain_loader: T_loader, fine_tune_loader: T_loader, val_loader: DataLoader, + save_dir: str = "base", max_epoch_train_encoder: int = 100, max_epoch_train_decoder: int = 100, + max_epoch_train_finetune: int = 100, num_batches: int = 256, device: str = "cpu", configuration=None, + train_encoder: bool = True, train_decoder: bool = True, transform_axis=[1, 2], reg_weight=0.0, + reg_criterion=nn.MSELoss(), *args, **kwargs): + super().__init__(model, pretrain_loader, fine_tune_loader, val_loader, save_dir, max_epoch_train_encoder, + max_epoch_train_decoder, max_epoch_train_finetune, num_batches, device, configuration, + train_encoder, train_decoder,*args, **kwargs) + self._teacher_model = None + self._transform_axis = transform_axis + self._reg_weight = reg_weight + self._reg_criterion = reg_criterion + + def finetune_network_init(self, *args, **kwargs): + super().finetune_network_init(*args,**kwargs) + from contrastyou.arch import UNet + from deepclustering2.models import ema_updater + # here we initialize the MT + self._teacher_model = UNet(**self._configuration["Arch"]) + for param in self._teacher_model.parameters(): + param.detach_() + self._teacher_model.train() + self._ema_updater = ema_updater(alpha=0.999, justify_alpha=True, weight_decay=1e-5, update_bn=False) + + def finetune_network_run(self, epocher_type=MeanTeacherEpocher): + self.to(self._device) + self._model.enable_grad_decoder() # noqa + self._model.enable_grad_encoder() # noqa + + for self._cur_epoch in range(self._start_epoch, self._max_epoch_train_finetune): + finetune_dict = epocher_type.create_from_trainer(self).run() + val_dict, cur_score = EvalEpoch(self._teacher_model, val_loader=self._val_loader, + sup_criterion=self._sup_criterion, + cur_epoch=self._cur_epoch, device=self._device).run() + self._scheduler.step() + storage_dict = StorageIncomeDict(finetune=finetune_dict, val=val_dict) + 
self._finetune_storage.put_from_dict(storage_dict, epoch=self._cur_epoch) + self._writer.add_scalar_with_StorageDict(storage_dict, self._cur_epoch) + self.save(cur_score, os.path.join(self._save_dir, "finetune")) diff --git a/main_contrast.py b/main_contrast.py index a9dc6451..854af825 100644 --- a/main_contrast.py +++ b/main_contrast.py @@ -1,17 +1,16 @@ from pathlib import Path -from torch.utils.data import DataLoader - from contrastyou import DATA_PATH, PROJECT_PATH from contrastyou.arch import UNet from contrastyou.augment import ACDC_transforms from contrastyou.dataloader._seg_datset import ContrastBatchSampler # noqa from contrastyou.dataloader.acdc_dataset import ACDCSemiInterface, ACDCDataset -from contrastyou.trainer import ContrastTrainer +from contrastyou.trainer import trainer_zoos from deepclustering2.configparser import ConfigManger from deepclustering2.dataloader.sampler import InfiniteRandomSampler from deepclustering2.dataset import PatientSampler from deepclustering2.utils import set_benchmark +from torch.utils.data import DataLoader # load configure from yaml and argparser cmanager = ConfigManger(Path(PROJECT_PATH) / "config/config.yaml") @@ -47,8 +46,9 @@ shuffle=False), pin_memory=True) checkpoint = config.pop("Checkpoint", None) - -trainer = ContrastTrainer(model, iter(train_loader), iter(labeled_loader), val_loader, - configuration=cmanager.config, **config["Trainer"], ) +Trainer = trainer_zoos[config["Trainer"].pop("name")] +assert Trainer, Trainer +trainer = Trainer(model=model, pretrain_loader=iter(train_loader), fine_tune_loader=iter(labeled_loader), + val_loader=val_loader, configuration=cmanager.config, **config["Trainer"], ) trainer.start_training(checkpoint=checkpoint) diff --git a/references/deepclustering2 b/references/deepclustering2 index 9d428358..7de3444a 160000 --- a/references/deepclustering2 +++ b/references/deepclustering2 @@ -1 +1 @@ -Subproject commit 9d42835895a02a018fe628444bc4229d29dc9957 +Subproject commit 
7de3444a41dd2f1a3e36785f0320f2d02ad14ed1 diff --git a/run_script.py b/run_script.py index 7ee0a344..31a3320d 100644 --- a/run_script.py +++ b/run_script.py @@ -2,14 +2,16 @@ from deepclustering2.cchelper import JobSubmiter -save_dir = "first_try" -num_batches = 2000 +num_batches = 500 random_seed = 1 -labeled_data_ratio = 0.05 +labeled_data_ratio = 0.1 unlabeled_data_ratio = 1 - labeled_data_ratio +trainer_name="contrast" # or contrastMT +save_dir = f"label_data_ration:{labeled_data_ratio}/{trainer_name}" -common_opts = f" RandomSeed={random_seed} Data.labeled_data_ratio={labeled_data_ratio} Data.unlabeled_data_ratio={unlabeled_data_ratio} Trainer.num_batches={num_batches} " + +common_opts = f" Trainer.name={trainer_name} RandomSeed={random_seed} Data.labeled_data_ratio={labeled_data_ratio} Data.unlabeled_data_ratio={unlabeled_data_ratio} Trainer.num_batches={num_batches} " jobs = [ f"python main_contrast.py {common_opts} Trainer.save_dir={save_dir}/baseline Trainer.train_encoder=False Trainer.train_decoder=False ", f"python main_contrast.py {common_opts} Trainer.save_dir={save_dir}/encoder Trainer.train_encoder=True Trainer.train_decoder=False ", @@ -21,6 +23,6 @@ jobsubmiter = JobSubmiter(project_path="./", on_local=False) for j in jobs: - jobsubmiter.prepare_env(["export OMP_NUM_THREADS=1", "source "]) + jobsubmiter.prepare_env([ "source ./venv/bin/activate ", "export OMP_NUM_THREADS=1",]) jobsubmiter.account = next(accounts) jobsubmiter.run(j)