Commit

adding iic encoder
jizong committed Jul 23, 2020
1 parent e9ccf20 commit 9dce6c8
Showing 4 changed files with 280 additions and 51 deletions.
101 changes: 52 additions & 49 deletions contrastyou/epocher/IIC_epocher.py
@@ -1,60 +1,63 @@
 import torch
-from torch import nn

+from deepclustering2 import optim
 from deepclustering2.meters2 import EpochResultDict
 from deepclustering2.tqdm import tqdm
+from deepclustering2.optim import get_lrs_from_optimizer
 from deepclustering2.trainer.trainer import T_loader, T_loss
-from deepclustering2.utils import simplex, class2one_hot
-from .base_epocher import SemiEpocher


-class TrainEpoch(SemiEpocher.TrainEpoch):

-    def __init__(self, model: nn.Module, labeled_loader: T_loader, unlabeled_loader: T_loader, sup_criteiron: T_loss,
-                 reg_criterion: T_loss, num_batches: int = 100, cur_epoch=0, device="cpu",
-                 reg_weight: float = 0.001) -> None:
-        super().__init__(model, labeled_loader, unlabeled_loader, sup_criteiron, reg_criterion, num_batches, cur_epoch,
-                         device, reg_weight)
-        assert reg_criterion  # todo: add constraints on the reg_criterion
+from torch import nn
+from torch.nn import functional as F

+from .contrast_epocher import PretrainEncoderEpoch as _PretrainEncoderEpoch


+class IICPretrainEcoderEpoch(_PretrainEncoderEpoch):

+    def __init__(self, model: nn.Module, projection_head: nn.Module, projection_classifier: nn.Module,
+                 optimizer: optim.Optimizer, pretrain_encoder_loader: T_loader = None,
+                 contrastive_criterion: T_loss = None, num_batches: int = 0,
+                 cur_epoch=0, device="cpu", iic_weight_ratio=1, *args, **kwargs) -> None:
+        """
+        :param model:
+        :param projection_head: here the projection head should be a classifier
+        :param optimizer:
+        :param pretrain_encoder_loader:
+        :param contrastive_criterion:
+        :param num_batches:
+        :param cur_epoch:
+        :param device:
+        :param args:
+        :param kwargs:
+        """
+        assert pretrain_encoder_loader is not None
+        self._projection_classifier = projection_classifier
+        from ..losses.iic_loss import IIDLoss
+        self._iic_criterion = IIDLoss()
+        self._iic_weight_ratio = iic_weight_ratio
+        super().__init__(model, projection_head, optimizer, pretrain_encoder_loader, contrastive_criterion, num_batches,
+                         cur_epoch, device, *args, **kwargs)

     def _run(self, *args, **kwargs) -> EpochResultDict:
         self._model.train()
         assert self._model.training, self._model.training
         report_dict: EpochResultDict
-        self.meters["lr"].add(self._model.get_lr()[0])
-        self.meters["reg_weight"].add(self._reg_weight)

-        with tqdm(range(self._num_batches)).set_desc_from_epocher(self) as indicator:
-            for i, label_data, unlabel_data in zip(indicator, self._labeled_loader, self._unlabeled_loader):
-                ((labelimage, labeltarget), (labelimage_tf, labeltarget_tf)), \
-                filename, partition_list, group_list = self._preprocess_data(label_data, self._device)
-                ((unlabelimage, _), (unlabelimage_tf, _)), unlabel_filename, \
-                unlabel_partition_list, unlabel_group_list = self._preprocess_data(unlabel_data, self._device)

-                predict_logits = self._model(
-                    torch.cat([labelimage, labelimage_tf, unlabelimage, unlabelimage_tf], dim=0),
-                    force_simplex=False)
-                assert not simplex(predict_logits), predict_logits
-                label_logit, label_logit_tf, unlabel_logit, unlabel_logit_tf \
-                    = torch.split(predict_logits,
-                                  [len(labelimage), len(labelimage_tf), len(unlabelimage), len(unlabelimage_tf)],
-                                  dim=0)
-                onehot_ltarget = class2one_hot(torch.cat([labeltarget.squeeze(), labeltarget_tf.squeeze()], dim=0),
-                                               4)
-                sup_loss = self._sup_criterion(torch.cat([label_logit, label_logit_tf], dim=0).softmax(1),
-                                               onehot_ltarget)
-                reg_loss = self._reg_criterion(unlabel_logit.softmax(1), unlabel_logit_tf.softmax(1))
-                total_loss = sup_loss + reg_loss * self._reg_weight

-                self._model.zero_grad()
+        self.meters["lr"].add(get_lrs_from_optimizer(self._optimizer)[0])

+        with tqdm(range(self._num_batches)).set_desc_from_epocher(self) as indicator:  # noqa
+            for i, data in zip(indicator, self._pretrain_encoder_loader):
+                (img, _), (img_tf, _), filename, partition_list, group_list = self._preprocess_data(data, self._device)
+                _, (e5, *_), *_ = self._model(torch.cat([img, img_tf], dim=0), return_features=True)
+                global_enc, global_tf_enc = torch.chunk(F.normalize(self._projection_head(e5), dim=1), chunks=2, dim=0)
+                # fixme: here lack of some code for IIC
+                labels = self._label_generation(partition_list, group_list)
+                contrastive_loss = self._contrastive_criterion(torch.stack([global_enc, global_tf_enc], dim=1),
+                                                               labels=labels)
+                iic_loss = self._iic_criterion()  # todo
+                total_loss = self._iic_weight_ratio * iic_loss + (1 - self._iic_weight_ratio) * contrastive_loss
+                self._optimizer.zero_grad()
                 total_loss.backward()
-                self._model.step()

+                self._optimizer.step()
+                # todo: meter recording.
                 with torch.no_grad():
-                    self.meters["sup_loss"].add(sup_loss.item())
-                    self.meters["ds"].add(label_logit.max(1)[1], labeltarget.squeeze(1),
-                                          group_name=list(group_list))
-                    self.meters["reg_loss"].add(reg_loss.item())
+                    self.meters["contrastive_loss"].add(contrastive_loss.item())
                     report_dict = self.meters.tracking_status()
                     indicator.set_postfix_dict(report_dict)
         report_dict = self.meters.tracking_status()
         return report_dict
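The IIC branch of the new epocher is still a stub in this commit: `_run` carries `# fixme: here lack of some code for IIC` and calls `iic_loss = self._iic_criterion()` with no arguments, while `IIDLoss.forward` expects two simplex tensors. A minimal sketch of one way the missing piece could be wired up, assuming `projection_classifier` ends in a `Softmax` so its output is a valid cluster simplex; the helper `iic_term` below is hypothetical and not part of the commit:

import torch
from torch import Tensor, nn


def iic_term(e5: Tensor, projection_classifier: nn.Module, iic_criterion) -> Tensor:
    # e5 stacks the encoder features of [img, img_tf] along the batch dimension,
    # exactly as produced inside IICPretrainEcoderEpoch._run.
    cluster_probs = projection_classifier(e5)              # (2B, K) simplex
    p, p_tf = torch.chunk(cluster_probs, chunks=2, dim=0)  # one (B, K) block per view
    loss, _ = iic_criterion(p, p_tf)                       # IIDLoss returns (loss, loss_no_lamb)
    return loss

Inside `_run`, something like `iic_loss = iic_term(e5, self._projection_classifier, self._iic_criterion)` would then replace the empty call.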
3 changes: 2 additions & 1 deletion contrastyou/epocher/__init__.py
@@ -1,2 +1,3 @@
 from .base_epocher import EvalEpoch, SimpleFineTuneEpoch, MeanTeacherEpocher
-from .contrast_epocher import PretrainDecoderEpoch, PretrainEncoderEpoch
+from .contrast_epocher import PretrainDecoderEpoch, PretrainEncoderEpoch
+from .IIC_epocher import IICPretrainEcoderEpoch
73 changes: 72 additions & 1 deletion contrastyou/losses/iic_loss.py
@@ -2,11 +2,82 @@
 from typing import Tuple

 import torch
+from deepclustering2.utils import simplex
 from termcolor import colored
 from torch import Tensor
 from torch import nn
 from torch.nn import functional as F

-from deepclustering2.utils import simplex

+class IIDLoss(nn.Module):
+    def __init__(self, lamb: float = 1.0, eps: float = sys.float_info.epsilon):
+        """
+        :param lamb:
+        :param eps:
+        """
+        super().__init__()
+        print(colored(f"Initialize {self.__class__.__name__}.", "green"))
+        self.lamb = float(lamb)
+        self.eps = float(eps)
+        self.torch_vision = torch.__version__

+    def forward(self, x_out: Tensor, x_tf_out: Tensor):
+        """
+        return the inverse of the MI. if the x_out == y_out, return the inverse of Entropy
+        :param x_out:
+        :param x_tf_out:
+        :return:
+        """
+        assert simplex(x_out), f"x_out not normalized."
+        assert simplex(x_tf_out), f"x_tf_out not normalized."
+        _, k = x_out.size()
+        p_i_j = compute_joint(x_out, x_tf_out)
+        assert p_i_j.size() == (k, k)

+        p_i = (
+            p_i_j.sum(dim=1).view(k, 1).expand(k, k)
+        )  # p_i should be the mean of the x_out
+        p_j = p_i_j.sum(dim=0).view(1, k).expand(k, k)  # but should be same, symmetric

+        # p_i = x_out.mean(0).view(k, 1).expand(k, k)
+        # p_j = x_tf_out.mean(0).view(1, k).expand(k, k)
+        #
+        # avoid NaN losses. Effect will get cancelled out by p_i_j tiny anyway
+        if self.torch_vision < "1.3.0":
+            p_i_j[p_i_j < self.eps] = self.eps
+            p_j[p_j < self.eps] = self.eps
+            p_i[p_i < self.eps] = self.eps

+        loss = -p_i_j * (
+            torch.log(p_i_j) - self.lamb * torch.log(p_j) - self.lamb * torch.log(p_i)
+        )
+        loss = loss.sum()
+        loss_no_lamb = -p_i_j * (torch.log(p_i_j) - torch.log(p_j) - torch.log(p_i))
+        loss_no_lamb = loss_no_lamb.sum()
+        return loss, loss_no_lamb


+def compute_joint(x_out: Tensor, x_tf_out: Tensor, symmetric=True) -> Tensor:
+    r"""
+    return joint probability
+    :param x_out: p1, simplex
+    :param x_tf_out: p2, simplex
+    :return: joint probability
+    """
+    # produces variable that requires grad (since args require grad)
+    assert simplex(x_out), f"x_out not normalized."
+    assert simplex(x_tf_out), f"x_tf_out not normalized."

+    bn, k = x_out.shape
+    assert x_tf_out.size(0) == bn and x_tf_out.size(1) == k

+    p_i_j = x_out.unsqueeze(2) * x_tf_out.unsqueeze(1)  # bn, k, k
+    p_i_j = p_i_j.sum(dim=0)  # k, k aggregated over one batch
+    if symmetric:
+        p_i_j = (p_i_j + p_i_j.t()) / 2.0  # symmetric
+    p_i_j /= p_i_j.sum()  # normalise

+    return p_i_j


 class IIDSegmentationLoss:
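With `lamb=1`, the `loss` returned above is (up to the `eps` clamping) the negative mutual information of the joint distribution that `compute_joint` builds from the two batches of cluster assignments. A small usage sketch, assuming the `contrastyou` package is importable; the tensors are random stand-ins rather than real network outputs:

import torch

from contrastyou.losses.iic_loss import IIDLoss, compute_joint

k = 20  # number of clusters
logits = torch.randn(8, k, requires_grad=True)
logits_tf = torch.randn(8, k, requires_grad=True)

criterion = IIDLoss(lamb=1.0)
loss, loss_no_lamb = criterion(logits.softmax(dim=1), logits_tf.softmax(dim=1))
loss.backward()  # gradients flow back through both softmax inputs

p_i_j = compute_joint(logits.softmax(dim=1), logits_tf.softmax(dim=1))
assert p_i_j.shape == (k, k) and abs(p_i_j.sum().item() - 1.0) < 1e-5  # normalised joint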
154 changes: 154 additions & 0 deletions contrastyou/trainer/iic_trainer.py
@@ -0,0 +1,154 @@
import itertools
import os
from pathlib import Path

import torch
from contrastyou import PROJECT_PATH
from contrastyou.epocher import PretrainDecoderEpoch, SimpleFineTuneEpoch, IICPretrainEcoderEpoch
from contrastyou.epocher.base_epocher import EvalEpoch
from contrastyou.losses.contrast_loss import SupConLoss
from contrastyou.trainer._utils import Flatten
from contrastyou.trainer.contrast_trainer import ContrastTrainer
from deepclustering2.loss import KL_div
from deepclustering2.meters2 import StorageIncomeDict
from deepclustering2.schedulers import GradualWarmupScheduler
from deepclustering2.writer import SummaryWriter
from torch import nn


class IICContrastTrainer(ContrastTrainer):
    RUN_PATH = Path(PROJECT_PATH) / "runs"

    def pretrain_encoder_init(self, num_clusters=20, *args, **kwargs):
        # adding optimizer and scheduler
        self._projector_contrastive = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            Flatten(),
            nn.Linear(256, 256),
            nn.LeakyReLU(0.01, inplace=True),
            nn.Linear(256, 256),
        )
        self._projector_iic = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            Flatten(),
            nn.Linear(256, 256),
            nn.LeakyReLU(0.01, inplace=True),
            nn.Linear(256, num_clusters),
            nn.Softmax(1)
        )
        self._optimizer = torch.optim.Adam(
            itertools.chain(self._model.parameters(), self._projector_contrastive.parameters(),
                            self._projector_iic.parameters()), lr=1e-6, weight_decay=1e-5)  # noqa
        self._scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self._optimizer,
                                                                     self._max_epoch_train_encoder - 10, 0)
        self._scheduler = GradualWarmupScheduler(self._optimizer, 300, 10, self._scheduler)  # noqa

    def pretrain_encoder_run(self, *args, **kwargs):
        self.to(self._device)
        self._model.enable_grad_encoder()  # noqa
        self._model.disable_grad_decoder()  # noqa

        for self._cur_epoch in range(self._start_epoch, self._max_epoch_train_encoder):
            pretrain_encoder_dict = IICPretrainEcoderEpoch(
                model=self._model, projection_head=self._projector_contrastive,
                projection_classifier=self._projector_iic,
                optimizer=self._optimizer,
                pretrain_encoder_loader=self._pretrain_loader,
                contrastive_criterion=SupConLoss(), num_batches=self._num_batches,
                cur_epoch=self._cur_epoch, device=self._device
            ).run()
            self._scheduler.step()
            storage_dict = StorageIncomeDict(PRETRAIN_ENCODER=pretrain_encoder_dict, )
            self._pretrain_encoder_storage.put_from_dict(storage_dict, epoch=self._cur_epoch)
            self._writer.add_scalar_with_StorageDict(storage_dict, self._cur_epoch)
            self._save_to("last.pth", path=os.path.join(self._save_dir, "pretrain_encoder"))
        self.train_encoder_done = True

    def pretrain_decoder_init(self, *args, **kwargs):
        # adding optimizer and scheduler
        self._projector = nn.Sequential(
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.01, inplace=True),
            nn.Conv2d(64, 32, 3, 1, 1)
        )
        self._optimizer = torch.optim.Adam(itertools.chain(self._model.parameters(), self._projector.parameters()),
                                           lr=1e-6, weight_decay=0)
        self._scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self._optimizer,
                                                                     self._max_epoch_train_decoder - 10, 0)
        self._scheduler = GradualWarmupScheduler(self._optimizer, 300, 10, self._scheduler)

    def pretrain_decoder_run(self):
        self.to(self._device)
        self._projector.to(self._device)

        self._model.enable_grad_decoder()  # noqa
        self._model.disable_grad_encoder()  # noqa

        for self._cur_epoch in range(self._start_epoch, self._max_epoch_train_decoder):
            pretrain_decoder_dict = PretrainDecoderEpoch(
                model=self._model, projection_head=self._projector,
                optimizer=self._optimizer,
                pretrain_decoder_loader=self._pretrain_loader,
                contrastive_criterion=SupConLoss(), num_batches=self._num_batches,
                cur_epoch=self._cur_epoch, device=self._device
            ).run()
            self._scheduler.step()
            storage_dict = StorageIncomeDict(PRETRAIN_DECODER=pretrain_decoder_dict, )
            self._pretrain_encoder_storage.put_from_dict(storage_dict, epoch=self._cur_epoch)
            self._writer.add_scalar_with_StorageDict(storage_dict, self._cur_epoch)
            self._save_to("last.pth", path=os.path.join(self._save_dir, "pretrain_decoder"))
        self.train_decoder_done = True

    def finetune_network_init(self, *args, **kwargs):

        self._optimizer = torch.optim.Adam(self._model.parameters(), lr=1e-7, weight_decay=1e-5)
        self._scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self._optimizer,
                                                                     self._max_epoch_train_finetune - 10, 0)
        self._scheduler = GradualWarmupScheduler(self._optimizer, 200, 10, self._scheduler)
        self._sup_criterion = KL_div()

    def finetune_network_run(self, epocher_type=SimpleFineTuneEpoch):
        self.to(self._device)
        self._model.enable_grad_decoder()  # noqa
        self._model.enable_grad_encoder()  # noqa

        for self._cur_epoch in range(self._start_epoch, self._max_epoch_train_finetune):
            finetune_dict = epocher_type.create_from_trainer(self).run()
            val_dict, cur_score = EvalEpoch(self._model, val_loader=self._val_loader, sup_criterion=self._sup_criterion,
                                            cur_epoch=self._cur_epoch, device=self._device).run()
            self._scheduler.step()
            storage_dict = StorageIncomeDict(finetune=finetune_dict, val=val_dict)
            self._finetune_storage.put_from_dict(storage_dict, epoch=self._cur_epoch)
            self._writer.add_scalar_with_StorageDict(storage_dict, self._cur_epoch)
            self.save(cur_score, os.path.join(self._save_dir, "finetune"))

    def start_training(self, checkpoint: str = None):

        with SummaryWriter(str(self._save_dir)) as self._writer:  # noqa
            if self.train_encoder:
                self.pretrain_encoder_init()
                if checkpoint is not None:
                    try:
                        self.load_state_dict_from_path(os.path.join(checkpoint, "pretrain_encoder"))
                    except Exception as e:
                        raise RuntimeError(f"loading pretrain_encoder_checkpoint failed with {e}, ")

                if not self.train_encoder_done:
                    self.pretrain_encoder_run()
            if self.train_decoder:
                self.pretrain_decoder_init()
                if checkpoint is not None:
                    try:
                        self.load_state_dict_from_path(os.path.join(checkpoint, "pretrain_decoder"))
                    except Exception as e:
                        print(f"loading pretrain_decoder_checkpoint failed with {e}, ")
                if not self.train_decoder_done:
                    self.pretrain_decoder_run()
            self.finetune_network_init()
            if checkpoint is not None:
                try:
                    self.load_state_dict_from_path(os.path.join(checkpoint, "finetune"))
                except Exception as e:
                    print(f"loading finetune_checkpoint failed with {e}, ")
            self.finetune_network_run()
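As a shape check of what the new trainer hands to the epocher: the IIC head above reduces a 256-channel encoder feature map to a `num_clusters`-way simplex, which is exactly the input `IIDLoss` expects. A hypothetical smoke test along those lines (not part of the commit; the feature-map size is made up):

import torch
from torch import nn

from contrastyou.losses.iic_loss import IIDLoss
from contrastyou.trainer._utils import Flatten

num_clusters = 20
projector_iic = nn.Sequential(  # same layout as _projector_iic in pretrain_encoder_init
    nn.AdaptiveAvgPool2d((1, 1)),
    Flatten(),
    nn.Linear(256, 256),
    nn.LeakyReLU(0.01, inplace=True),
    nn.Linear(256, num_clusters),
    nn.Softmax(1),
)

e5 = torch.randn(4, 256, 14, 14)  # fake encoder features for an [img, img_tf] batch
p, p_tf = torch.chunk(projector_iic(e5), chunks=2, dim=0)
loss, _ = IIDLoss()(p, p_tf)
loss.backward()  # gradients reach the projector parameters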
