From 9dce6c8646060edcc96e9a157b7fe8430fb90af1 Mon Sep 17 00:00:00 2001 From: jizong Date: Thu, 23 Jul 2020 13:30:56 -0400 Subject: [PATCH] adding iic encoder --- contrastyou/epocher/IIC_epocher.py | 101 ++++++++++--------- contrastyou/epocher/__init__.py | 3 +- contrastyou/losses/iic_loss.py | 73 +++++++++++++- contrastyou/trainer/iic_trainer.py | 154 +++++++++++++++++++++++++++++ 4 files changed, 280 insertions(+), 51 deletions(-) create mode 100644 contrastyou/trainer/iic_trainer.py diff --git a/contrastyou/epocher/IIC_epocher.py b/contrastyou/epocher/IIC_epocher.py index 7138ca7c..261f31c8 100644 --- a/contrastyou/epocher/IIC_epocher.py +++ b/contrastyou/epocher/IIC_epocher.py @@ -1,60 +1,63 @@ import torch -from torch import nn - +from deepclustering2 import optim from deepclustering2.meters2 import EpochResultDict -from deepclustering2.tqdm import tqdm +from deepclustering2.optim import get_lrs_from_optimizer from deepclustering2.trainer.trainer import T_loader, T_loss -from deepclustering2.utils import simplex, class2one_hot -from .base_epocher import SemiEpocher - - -class TrainEpoch(SemiEpocher.TrainEpoch): - - def __init__(self, model: nn.Module, labeled_loader: T_loader, unlabeled_loader: T_loader, sup_criteiron: T_loss, - reg_criterion: T_loss, num_batches: int = 100, cur_epoch=0, device="cpu", - reg_weight: float = 0.001) -> None: - super().__init__(model, labeled_loader, unlabeled_loader, sup_criteiron, reg_criterion, num_batches, cur_epoch, - device, reg_weight) - assert reg_criterion # todo: add constraints on the reg_criterion +from torch import nn +from torch.nn import functional as F + +from .contrast_epocher import PretrainEncoderEpoch as _PretrainEncoderEpoch + + +class IICPretrainEcoderEpoch(_PretrainEncoderEpoch): + + def __init__(self, model: nn.Module, projection_head: nn.Module, projection_classifier: nn.Module, + optimizer: optim.Optimizer, pretrain_encoder_loader: T_loader = None, + contrastive_criterion: T_loss = None, num_batches: int = 0, + cur_epoch=0, device="cpu", iic_weight_ratio=1, *args, **kwargs) -> None: + """ + + :param model: + :param projection_head: here the projection head should be a classifier + :param optimizer: + :param pretrain_encoder_loader: + :param contrastive_criterion: + :param num_batches: + :param cur_epoch: + :param device: + :param args: + :param kwargs: + """ + assert pretrain_encoder_loader is not None + self._projection_classifier = projection_classifier + from ..losses.iic_loss import IIDLoss + self._iic_criterion = IIDLoss() + self._iic_weight_ratio = iic_weight_ratio + super().__init__(model, projection_head, optimizer, pretrain_encoder_loader, contrastive_criterion, num_batches, + cur_epoch, device, *args, **kwargs) def _run(self, *args, **kwargs) -> EpochResultDict: + self._model.train() assert self._model.training, self._model.training - report_dict: EpochResultDict - self.meters["lr"].add(self._model.get_lr()[0]) - self.meters["reg_weight"].add(self._reg_weight) - - with tqdm(range(self._num_batches)).set_desc_from_epocher(self) as indicator: - for i, label_data, unlabel_data in zip(indicator, self._labeled_loader, self._unlabeled_loader): - ((labelimage, labeltarget), (labelimage_tf, labeltarget_tf)), \ - filename, partition_list, group_list = self._preprocess_data(label_data, self._device) - ((unlabelimage, _), (unlabelimage_tf, _)), unlabel_filename, \ - unlabel_partition_list, unlabel_group_list = self._preprocess_data(unlabel_data, self._device) - - predict_logits = self._model( - torch.cat([labelimage, labelimage_tf, unlabelimage, unlabelimage_tf], dim=0), - force_simplex=False) - assert not simplex(predict_logits), predict_logits - label_logit, label_logit_tf, unlabel_logit, unlabel_logit_tf \ - = torch.split(predict_logits, - [len(labelimage), len(labelimage_tf), len(unlabelimage), len(unlabelimage_tf)], - dim=0) - onehot_ltarget = class2one_hot(torch.cat([labeltarget.squeeze(), labeltarget_tf.squeeze()], dim=0), - 4) - sup_loss = self._sup_criterion(torch.cat([label_logit, label_logit_tf], dim=0).softmax(1), - onehot_ltarget) - reg_loss = self._reg_criterion(unlabel_logit.softmax(1), unlabel_logit_tf.softmax(1)) - total_loss = sup_loss + reg_loss * self._reg_weight - - self._model.zero_grad() + self.meters["lr"].add(get_lrs_from_optimizer(self._optimizer)[0]) + + with tqdm(range(self._num_batches)).set_desc_from_epocher(self) as indicator: # noqa + for i, data in zip(indicator, self._pretrain_encoder_loader): + (img, _), (img_tf, _), filename, partition_list, group_list = self._preprocess_data(data, self._device) + _, (e5, *_), *_ = self._model(torch.cat([img, img_tf], dim=0), return_features=True) + global_enc, global_tf_enc = torch.chunk(F.normalize(self._projection_head(e5), dim=1), chunks=2, dim=0) + # fixme: here lack of some code for IIC + labels = self._label_generation(partition_list, group_list) + contrastive_loss = self._contrastive_criterion(torch.stack([global_enc, global_tf_enc], dim=1), + labels=labels) + iic_loss = self._iic_criterion() # todo + total_loss = self._iic_weight_ratio * iic_loss + (1 - self._iic_weight_ratio) * contrastive_loss + self._optimizer.zero_grad() total_loss.backward() - self._model.step() - + self._optimizer.step() + # todo: meter recording. with torch.no_grad(): - self.meters["sup_loss"].add(sup_loss.item()) - self.meters["ds"].add(label_logit.max(1)[1], labeltarget.squeeze(1), - group_name=list(group_list)) - self.meters["reg_loss"].add(reg_loss.item()) + self.meters["contrastive_loss"].add(contrastive_loss.item()) report_dict = self.meters.tracking_status() indicator.set_postfix_dict(report_dict) - report_dict = self.meters.tracking_status() return report_dict diff --git a/contrastyou/epocher/__init__.py b/contrastyou/epocher/__init__.py index ad3cfe8c..98b3fa47 100644 --- a/contrastyou/epocher/__init__.py +++ b/contrastyou/epocher/__init__.py @@ -1,2 +1,3 @@ from .base_epocher import EvalEpoch, SimpleFineTuneEpoch, MeanTeacherEpocher -from .contrast_epocher import PretrainDecoderEpoch, PretrainEncoderEpoch \ No newline at end of file +from .contrast_epocher import PretrainDecoderEpoch, PretrainEncoderEpoch +from .IIC_epocher import IICPretrainEcoderEpoch \ No newline at end of file diff --git a/contrastyou/losses/iic_loss.py b/contrastyou/losses/iic_loss.py index dbc55a42..4017ef46 100644 --- a/contrastyou/losses/iic_loss.py +++ b/contrastyou/losses/iic_loss.py @@ -2,11 +2,82 @@ from typing import Tuple import torch +from deepclustering2.utils import simplex from termcolor import colored from torch import Tensor +from torch import nn from torch.nn import functional as F -from deepclustering2.utils import simplex + +class IIDLoss(nn.Module): + def __init__(self, lamb: float = 1.0, eps: float = sys.float_info.epsilon): + """ + :param lamb: + :param eps: + """ + super().__init__() + print(colored(f"Initialize {self.__class__.__name__}.", "green")) + self.lamb = float(lamb) + self.eps = float(eps) + self.torch_vision = torch.__version__ + + def forward(self, x_out: Tensor, x_tf_out: Tensor): + """ + return the inverse of the MI. if the x_out == y_out, return the inverse of Entropy + :param x_out: + :param x_tf_out: + :return: + """ + assert simplex(x_out), f"x_out not normalized." + assert simplex(x_tf_out), f"x_tf_out not normalized." + _, k = x_out.size() + p_i_j = compute_joint(x_out, x_tf_out) + assert p_i_j.size() == (k, k) + + p_i = ( + p_i_j.sum(dim=1).view(k, 1).expand(k, k) + ) # p_i should be the mean of the x_out + p_j = p_i_j.sum(dim=0).view(1, k).expand(k, k) # but should be same, symmetric + + # p_i = x_out.mean(0).view(k, 1).expand(k, k) + # p_j = x_tf_out.mean(0).view(1, k).expand(k, k) + # + # avoid NaN losses. Effect will get cancelled out by p_i_j tiny anyway + if self.torch_vision < "1.3.0": + p_i_j[p_i_j < self.eps] = self.eps + p_j[p_j < self.eps] = self.eps + p_i[p_i < self.eps] = self.eps + + loss = -p_i_j * ( + torch.log(p_i_j) - self.lamb * torch.log(p_j) - self.lamb * torch.log(p_i) + ) + loss = loss.sum() + loss_no_lamb = -p_i_j * (torch.log(p_i_j) - torch.log(p_j) - torch.log(p_i)) + loss_no_lamb = loss_no_lamb.sum() + return loss, loss_no_lamb + + +def compute_joint(x_out: Tensor, x_tf_out: Tensor, symmetric=True) -> Tensor: + r""" + return joint probability + :param x_out: p1, simplex + :param x_tf_out: p2, simplex + :return: joint probability + """ + # produces variable that requires grad (since args require grad) + assert simplex(x_out), f"x_out not normalized." + assert simplex(x_tf_out), f"x_tf_out not normalized." + + bn, k = x_out.shape + assert x_tf_out.size(0) == bn and x_tf_out.size(1) == k + + p_i_j = x_out.unsqueeze(2) * x_tf_out.unsqueeze(1) # bn, k, k + p_i_j = p_i_j.sum(dim=0) # k, k aggregated over one batch + if symmetric: + p_i_j = (p_i_j + p_i_j.t()) / 2.0 # symmetric + p_i_j /= p_i_j.sum() # normalise + + return p_i_j class IIDSegmentationLoss: diff --git a/contrastyou/trainer/iic_trainer.py b/contrastyou/trainer/iic_trainer.py new file mode 100644 index 00000000..5ff6bc06 --- /dev/null +++ b/contrastyou/trainer/iic_trainer.py @@ -0,0 +1,154 @@ +import itertools +import os +from pathlib import Path + +import torch +from contrastyou import PROJECT_PATH +from contrastyou.epocher import PretrainDecoderEpoch, SimpleFineTuneEpoch, IICPretrainEcoderEpoch +from contrastyou.epocher.base_epocher import EvalEpoch +from contrastyou.losses.contrast_loss import SupConLoss +from contrastyou.trainer._utils import Flatten +from contrastyou.trainer.contrast_trainer import ContrastTrainer +from deepclustering2.loss import KL_div +from deepclustering2.meters2 import StorageIncomeDict +from deepclustering2.schedulers import GradualWarmupScheduler +from deepclustering2.writer import SummaryWriter +from torch import nn + + +class IICContrastTrainer(ContrastTrainer): + RUN_PATH = Path(PROJECT_PATH) / "runs" + + def pretrain_encoder_init(self, num_clusters=20, *args, **kwargs): + # adding optimizer and scheduler + self._projector_contrastive = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + Flatten(), + nn.Linear(256, 256), + nn.LeakyReLU(0.01, inplace=True), + nn.Linear(256, 256), + ) + self._projector_iic = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + Flatten(), + nn.Linear(256, 256), + nn.LeakyReLU(0.01, inplace=True), + nn.Linear(256, num_clusters), + nn.Softmax(1) + ) + self._optimizer = torch.optim.Adam( + itertools.chain(self._model.parameters(), self._projector_contrastive.parameters(), + self._projector_iic.parameters()), lr=1e-6, weight_decay=1e-5) # noqa + self._scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self._optimizer, + self._max_epoch_train_encoder - 10, 0) + self._scheduler = GradualWarmupScheduler(self._optimizer, 300, 10, self._scheduler) # noqa + + def pretrain_encoder_run(self, *args, **kwargs): + self.to(self._device) + self._model.enable_grad_encoder() # noqa + self._model.disable_grad_decoder() # noqa + + for self._cur_epoch in range(self._start_epoch, self._max_epoch_train_encoder): + pretrain_encoder_dict = IICPretrainEcoderEpoch( + model=self._model, projection_head=self._projector_contrastive, + projection_classifier=self._projector_iic, + optimizer=self._optimizer, + pretrain_encoder_loader=self._pretrain_loader, + contrastive_criterion=SupConLoss(), num_batches=self._num_batches, + cur_epoch=self._cur_epoch, device=self._device + ).run() + self._scheduler.step() + storage_dict = StorageIncomeDict(PRETRAIN_ENCODER=pretrain_encoder_dict, ) + self._pretrain_encoder_storage.put_from_dict(storage_dict, epoch=self._cur_epoch) + self._writer.add_scalar_with_StorageDict(storage_dict, self._cur_epoch) + self._save_to("last.pth", path=os.path.join(self._save_dir, "pretrain_encoder")) + self.train_encoder_done = True + + def pretrain_decoder_init(self, *args, **kwargs): + # adding optimizer and scheduler + self._projector = nn.Sequential( + nn.Conv2d(64, 64, 3, 1, 1), + nn.BatchNorm2d(64), + nn.LeakyReLU(0.01, inplace=True), + nn.Conv2d(64, 32, 3, 1, 1) + ) + self._optimizer = torch.optim.Adam(itertools.chain(self._model.parameters(), self._projector.parameters()), + lr=1e-6, weight_decay=0) + self._scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self._optimizer, + self._max_epoch_train_decoder - 10, 0) + self._scheduler = GradualWarmupScheduler(self._optimizer, 300, 10, self._scheduler) + + def pretrain_decoder_run(self): + self.to(self._device) + self._projector.to(self._device) + + self._model.enable_grad_decoder() # noqa + self._model.disable_grad_encoder() # noqa + + for self._cur_epoch in range(self._start_epoch, self._max_epoch_train_decoder): + pretrain_decoder_dict = PretrainDecoderEpoch( + model=self._model, projection_head=self._projector, + optimizer=self._optimizer, + pretrain_decoder_loader=self._pretrain_loader, + contrastive_criterion=SupConLoss(), num_batches=self._num_batches, + cur_epoch=self._cur_epoch, device=self._device + ).run() + self._scheduler.step() + storage_dict = StorageIncomeDict(PRETRAIN_DECODER=pretrain_decoder_dict, ) + self._pretrain_encoder_storage.put_from_dict(storage_dict, epoch=self._cur_epoch) + self._writer.add_scalar_with_StorageDict(storage_dict, self._cur_epoch) + self._save_to("last.pth", path=os.path.join(self._save_dir, "pretrain_decoder")) + self.train_decoder_done = True + + def finetune_network_init(self, *args, **kwargs): + + self._optimizer = torch.optim.Adam(self._model.parameters(), lr=1e-7, weight_decay=1e-5) + self._scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self._optimizer, + self._max_epoch_train_finetune - 10, 0) + self._scheduler = GradualWarmupScheduler(self._optimizer, 200, 10, self._scheduler) + self._sup_criterion = KL_div() + + def finetune_network_run(self, epocher_type=SimpleFineTuneEpoch): + self.to(self._device) + self._model.enable_grad_decoder() # noqa + self._model.enable_grad_encoder() # noqa + + for self._cur_epoch in range(self._start_epoch, self._max_epoch_train_finetune): + finetune_dict = epocher_type.create_from_trainer(self).run() + val_dict, cur_score = EvalEpoch(self._model, val_loader=self._val_loader, sup_criterion=self._sup_criterion, + cur_epoch=self._cur_epoch, device=self._device).run() + self._scheduler.step() + storage_dict = StorageIncomeDict(finetune=finetune_dict, val=val_dict) + self._finetune_storage.put_from_dict(storage_dict, epoch=self._cur_epoch) + self._writer.add_scalar_with_StorageDict(storage_dict, self._cur_epoch) + self.save(cur_score, os.path.join(self._save_dir, "finetune")) + + def start_training(self, checkpoint: str = None): + + with SummaryWriter(str(self._save_dir)) as self._writer: # noqa + if self.train_encoder: + self.pretrain_encoder_init() + if checkpoint is not None: + try: + self.load_state_dict_from_path(os.path.join(checkpoint, "pretrain_encoder")) + except Exception as e: + raise RuntimeError(f"loading pretrain_encoder_checkpoint failed with {e}, ") + + if not self.train_encoder_done: + self.pretrain_encoder_run() + if self.train_decoder: + self.pretrain_decoder_init() + if checkpoint is not None: + try: + self.load_state_dict_from_path(os.path.join(checkpoint, "pretrain_decoder")) + except Exception as e: + print(f"loading pretrain_decoder_checkpoint failed with {e}, ") + if not self.train_decoder_done: + self.pretrain_decoder_run() + self.finetune_network_init() + if checkpoint is not None: + try: + self.load_state_dict_from_path(os.path.join(checkpoint, "finetune")) + except Exception as e: + print(f"loading finetune_checkpoint failed with {e}, ") + self.finetune_network_run()