first try on iic loss
jizong committed Jul 25, 2020
1 parent 58485d5 commit 4c052cb
Showing 6 changed files with 126 additions and 116 deletions.
2 changes: 1 addition & 1 deletion config/config.yaml
@@ -20,7 +20,7 @@ Trainer:
train_decoder: True

PretrainEncoder:
group_option: patient
group_option: partition

PretrainDecoder:
null
21 changes: 15 additions & 6 deletions contrastyou/epocher/IIC_epocher.py
@@ -1,15 +1,16 @@
import random

import torch
from torch import nn
from torch.nn import functional as F

from contrastyou.epocher._utils import unfold_position
from deepclustering2 import optim
from deepclustering2.decorator import FixRandomSeed
from deepclustering2.meters2 import EpochResultDict
from deepclustering2.meters2 import EpochResultDict, MeterInterface, AverageValueMeter
from deepclustering2.optim import get_lrs_from_optimizer
from deepclustering2.tqdm import tqdm
from deepclustering2.trainer.trainer import T_loader, T_loss
from torch import nn
from torch.nn import functional as F

from .contrast_epocher import PretrainDecoderEpoch as _PretrainDecoderEpoch
from .contrast_epocher import PretrainEncoderEpoch as _PretrainEncoderEpoch

@@ -41,6 +42,11 @@ def __init__(self, model: nn.Module, projection_head: nn.Module, projection_clas
self._iic_criterion = IIDLoss()
self._iic_weight_ratio = iic_weight_ratio

def _configure_meters(self, meters: MeterInterface) -> MeterInterface:
meters.register_meter("iic_loss", AverageValueMeter())
meters = super()._configure_meters(meters)
return meters

def _run(self, *args, **kwargs) -> EpochResultDict:
self._model.train()
assert self._model.training, self._model.training
@@ -56,14 +62,15 @@ def _run(self, *args, **kwargs) -> EpochResultDict:
labels = self._label_generation(partition_list, group_list)
contrastive_loss = self._contrastive_criterion(torch.stack([global_enc, global_tf_enc], dim=1),
labels=labels)
iic_loss = self._iic_criterion(global_probs, global_tf_probs) # todo
iic_loss = self._iic_criterion(global_probs, global_tf_probs)[0] # todo
total_loss = self._iic_weight_ratio * iic_loss + (1 - self._iic_weight_ratio) * contrastive_loss
self._optimizer.zero_grad()
total_loss.backward()
self._optimizer.step()
# todo: meter recording.
with torch.no_grad():
self.meters["contrastive_loss"].add(contrastive_loss.item())
self.meters["iic_loss"].add(iic_loss.item())
report_dict = self.meters.tracking_status()
indicator.set_postfix_dict(report_dict)
return report_dict
@@ -98,7 +105,9 @@ def _run(self, *args, **kwargs) -> EpochResultDict:
assert d4_ctf_gtf.shape == d4_ctf.shape, (d4_ctf_gtf.shape, d4_ctf.shape)
d4_tf = torch.cat([d4_gtf, d4_ctf_gtf], dim=0)
local_enc_tf, local_enc_tf_ctf = torch.chunk(self._projection_head(d4_tf), chunks=2, dim=0)
# todo: convert representation to distance
# todo: iic local presentation
pass

local_enc_unfold, _ = unfold_position(local_enc_tf, partition_num=(2, 2))
local_tf_enc_unfold, _fold_partition = unfold_position(local_enc_tf_ctf, partition_num=(2, 2))
b, *_ = local_enc_unfold.shape
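As a toy illustration (assumed values, not repository code), the weighted objective computed in IICPretrainEcoderEpoch._run above reduces to the pure contrastive baseline at iic_weight_ratio=0.0 and to the pure IIC objective at 1.0, which is the same 0.0 / 0.5 / 1.0 sweep launched by run_script_iic.py further below:

import torch

contrastive_loss = torch.tensor(0.8)  # stand-in for SupConLoss on the stacked global features
iic_loss = torch.tensor(1.2)          # stand-in for IIDLoss(...)[0] on the two softmax outputs
for iic_weight_ratio in (0.0, 0.5, 1.0):
    total_loss = iic_weight_ratio * iic_loss + (1 - iic_weight_ratio) * contrastive_loss
    print(iic_weight_ratio, total_loss.item())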
9 changes: 2 additions & 7 deletions contrastyou/losses/iic_loss.py
@@ -42,17 +42,12 @@ def forward(self, x_out: Tensor, x_tf_out: Tensor):
# p_i = x_out.mean(0).view(k, 1).expand(k, k)
# p_j = x_tf_out.mean(0).view(1, k).expand(k, k)
#
# avoid NaN losses. Effect will get cancelled out by p_i_j tiny anyway
if self.torch_vision < "1.3.0":
p_i_j[p_i_j < self.eps] = self.eps
p_j[p_j < self.eps] = self.eps
p_i[p_i < self.eps] = self.eps

loss = -p_i_j * (
torch.log(p_i_j) - self.lamb * torch.log(p_j) - self.lamb * torch.log(p_i)
torch.log(p_i_j+1e-10) - self.lamb * torch.log(p_j+1e-10) - self.lamb * torch.log(p_i+1e-10)
)
loss = loss.sum()
loss_no_lamb = -p_i_j * (torch.log(p_i_j) - torch.log(p_j) - torch.log(p_i))
loss_no_lamb = -p_i_j * (torch.log(p_i_j+1e-10) - torch.log(p_j+1e-10) - torch.log(p_i+1e-10))
loss_no_lamb = loss_no_lamb.sum()
return loss, loss_no_lamb
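
For context, below is a minimal, self-contained sketch of the standard IID (Invariant Information Clustering) objective in the epsilon-stabilised form used above; the joint-distribution computation and the lamb default follow the published IIC formulation and may differ in detail from this repository's IIDLoss:

import torch
from torch import Tensor

def iid_loss(x_out: Tensor, x_tf_out: Tensor, lamb: float = 1.0, eps: float = 1e-10):
    """Standard IIC mutual-information loss for two softmax outputs of shape (N, K)."""
    _, k = x_out.shape
    p_i_j = (x_out.unsqueeze(2) * x_tf_out.unsqueeze(1)).sum(dim=0)  # (K, K) joint distribution
    p_i_j = (p_i_j + p_i_j.t()) / 2.0                                # symmetrise
    p_i_j = p_i_j / p_i_j.sum()                                      # normalise to sum to 1
    p_i = p_i_j.sum(dim=1).view(k, 1).expand(k, k)                   # row marginal
    p_j = p_i_j.sum(dim=0).view(1, k).expand(k, k)                   # column marginal
    loss = -(p_i_j * (torch.log(p_i_j + eps)
                      - lamb * torch.log(p_j + eps)
                      - lamb * torch.log(p_i + eps))).sum()
    return loss

# usage: 16 samples over 40 clusters (PretrainEncoder.num_clusters=40 in this commit)
probs = torch.softmax(torch.randn(16, 40), dim=1)
probs_tf = torch.softmax(torch.randn(16, 40), dim=1)
print(iid_loss(probs, probs_tf))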

3 changes: 2 additions & 1 deletion contrastyou/trainer/__init__.py
@@ -1,3 +1,4 @@
from .contrast_trainer import ContrastTrainer, ContrastTrainerMT
from .iic_trainer import IICContrastTrainer

trainer_zoos = {"contrast": ContrastTrainer, "contrastMT":ContrastTrainerMT}
trainer_zoos = {"contrast": ContrastTrainer, "contrastMT": ContrastTrainerMT, "iiccontrast": IICContrastTrainer}
157 changes: 56 additions & 101 deletions contrastyou/trainer/iic_trainer.py
@@ -3,24 +3,21 @@
from pathlib import Path

import torch
from torch import nn

from contrastyou import PROJECT_PATH
from contrastyou.epocher import PretrainDecoderEpoch, SimpleFineTuneEpoch, IICPretrainEcoderEpoch
from contrastyou.epocher.base_epocher import EvalEpoch
from contrastyou.epocher.IIC_epocher import IICPretrainEcoderEpoch, IICPretrainDecoderEpoch
from contrastyou.losses.contrast_loss import SupConLoss
from contrastyou.trainer._utils import Flatten
from contrastyou.trainer.contrast_trainer import ContrastTrainer
from deepclustering2.loss import KL_div
from deepclustering2.meters2 import StorageIncomeDict
from deepclustering2.schedulers import GradualWarmupScheduler
from deepclustering2.writer import SummaryWriter
from torch import nn


class IICContrastTrainer(ContrastTrainer):
RUN_PATH = Path(PROJECT_PATH) / "runs"

def pretrain_encoder_init(self, group_option, num_clusters=20):
# adding optimizer and scheduler
def pretrain_encoder_init(self, group_option, num_clusters=40, iic_weight=1):
self._projector_contrastive = nn.Sequential(
nn.AdaptiveAvgPool2d((1, 1)),
Flatten(),
@@ -43,19 +40,25 @@ def pretrain_encoder_init(self, group_option, num_clusters=20):
self._max_epoch_train_encoder - 10, 0)
self._scheduler = GradualWarmupScheduler(self._optimizer, 300, 10, self._scheduler) # noqa

def pretrain_encoder_run(self, *args, **kwargs):
self._group_option = group_option # noqa

# set augmentation method as `total_freedom = True`
self._pretrain_loader.dataset._transform._total_freedom = True # noqa
self._pretrain_loader_iter = iter(self._pretrain_loader) # noqa
self._iic_weight = iic_weight

def pretrain_encoder_run(self):
self.to(self._device)
self._model.enable_grad_encoder() # noqa
self._model.disable_grad_decoder() # noqa

for self._cur_epoch in range(self._start_epoch, self._max_epoch_train_encoder):
pretrain_encoder_dict = IICPretrainEcoderEpoch(
model=self._model, projection_head=self._projector_contrastive,
projection_classifier=self._projector_iic,
optimizer=self._optimizer,
pretrain_encoder_loader=self._pretrain_loader,
contrastive_criterion=SupConLoss(), num_batches=self._num_batches,
cur_epoch=self._cur_epoch, device=self._device
projection_classifier=self._projector_iic, optimizer=self._optimizer,
pretrain_encoder_loader=self._pretrain_loader_iter, contrastive_criterion=SupConLoss(),
num_batches=self._num_batches, cur_epoch=self._cur_epoch, device=self._device,
group_option=self._group_option, iic_weight_ratio=self._iic_weight,
).run()
self._scheduler.step()
storage_dict = StorageIncomeDict(PRETRAIN_ENCODER=pretrain_encoder_dict, )
@@ -64,91 +67,43 @@ def pretrain_encoder_run(self, *args, **kwargs):
self._save_to("last.pth", path=os.path.join(self._save_dir, "pretrain_encoder"))
self.train_encoder_done = True

def pretrain_decoder_init(self, *args, **kwargs):
# adding optimizer and scheduler
self._projector = nn.Sequential(
nn.Conv2d(64, 64, 3, 1, 1),
nn.BatchNorm2d(64),
nn.LeakyReLU(0.01, inplace=True),
nn.Conv2d(64, 32, 3, 1, 1)
)
self._optimizer = torch.optim.Adam(itertools.chain(self._model.parameters(), self._projector.parameters()),
lr=1e-6, weight_decay=0)
self._scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self._optimizer,
self._max_epoch_train_decoder - 10, 0)
self._scheduler = GradualWarmupScheduler(self._optimizer, 300, 10, self._scheduler)

def pretrain_decoder_run(self):
self.to(self._device)
self._projector.to(self._device)

self._model.enable_grad_decoder() # noqa
self._model.disable_grad_encoder() # noqa

for self._cur_epoch in range(self._start_epoch, self._max_epoch_train_decoder):
pretrain_decoder_dict = PretrainDecoderEpoch(
model=self._model, projection_head=self._projector,
optimizer=self._optimizer,
pretrain_decoder_loader=self._pretrain_loader,
contrastive_criterion=SupConLoss(), num_batches=self._num_batches,
cur_epoch=self._cur_epoch, device=self._device
).run()
self._scheduler.step()
storage_dict = StorageIncomeDict(PRETRAIN_DECODER=pretrain_decoder_dict, )
self._pretrain_encoder_storage.put_from_dict(storage_dict, epoch=self._cur_epoch)
self._writer.add_scalar_with_StorageDict(storage_dict, self._cur_epoch)
self._save_to("last.pth", path=os.path.join(self._save_dir, "pretrain_decoder"))
self.train_decoder_done = True

def finetune_network_init(self, *args, **kwargs):

self._optimizer = torch.optim.Adam(self._model.parameters(), lr=1e-7, weight_decay=1e-5)
self._scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self._optimizer,
self._max_epoch_train_finetune - 10, 0)
self._scheduler = GradualWarmupScheduler(self._optimizer, 200, 10, self._scheduler)
self._sup_criterion = KL_div()

def finetune_network_run(self, epocher_type=SimpleFineTuneEpoch):
self.to(self._device)
self._model.enable_grad_decoder() # noqa
self._model.enable_grad_encoder() # noqa

for self._cur_epoch in range(self._start_epoch, self._max_epoch_train_finetune):
finetune_dict = epocher_type.create_from_trainer(self).run()
val_dict, cur_score = EvalEpoch(self._model, val_loader=self._val_loader, sup_criterion=self._sup_criterion,
cur_epoch=self._cur_epoch, device=self._device).run()
self._scheduler.step()
storage_dict = StorageIncomeDict(finetune=finetune_dict, val=val_dict)
self._finetune_storage.put_from_dict(storage_dict, epoch=self._cur_epoch)
self._writer.add_scalar_with_StorageDict(storage_dict, self._cur_epoch)
self.save(cur_score, os.path.join(self._save_dir, "finetune"))

def start_training(self, checkpoint: str = None):

with SummaryWriter(str(self._save_dir)) as self._writer: # noqa
if self.train_encoder:
self.pretrain_encoder_init()
if checkpoint is not None:
try:
self.load_state_dict_from_path(os.path.join(checkpoint, "pretrain_encoder"))
except Exception as e:
raise RuntimeError(f"loading pretrain_encoder_checkpoint failed with {e}, ")

if not self.train_encoder_done:
self.pretrain_encoder_run()
if self.train_decoder:
self.pretrain_decoder_init()
if checkpoint is not None:
try:
self.load_state_dict_from_path(os.path.join(checkpoint, "pretrain_decoder"))
except Exception as e:
print(f"loading pretrain_decoder_checkpoint failed with {e}, ")
if not self.train_decoder_done:
self.pretrain_decoder_run()
self.finetune_network_init()
if checkpoint is not None:
try:
self.load_state_dict_from_path(os.path.join(checkpoint, "finetune"))
except Exception as e:
print(f"loading finetune_checkpoint failed with {e}, ")
self.finetune_network_run()
# def pretrain_decoder_init(self, *args, **kwargs):
# # adding optimizer and scheduler
# self._projector_contrastive = nn.Sequential(
# nn.Conv2d(64, 64, 3, 1, 1),
# nn.BatchNorm2d(64),
# nn.LeakyReLU(0.01, inplace=True),
# nn.Conv2d(64, 32, 3, 1, 1)
# )
# self._proejctor_iic = None
# self._optimizer = torch.optim.Adam(itertools.chain(self._model.parameters(), self._projector.parameters()),
# lr=1e-6, weight_decay=0)
# self._scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self._optimizer,
# self._max_epoch_train_decoder - 10, 0)
# self._scheduler = GradualWarmupScheduler(self._optimizer, 300, 10, self._scheduler)
#
# # set augmentation method as `total_freedom = False`
# self._pretrain_loader.dataset._transform._total_freedom = False # noqa
# self._pretrain_loader_iter = iter(self._pretrain_loader) # noqa
#
# def pretrain_decoder_run(self):
# self.to(self._device)
# self._projector.to(self._device)
#
# self._model.enable_grad_decoder() # noqa
# self._model.disable_grad_encoder() # noqa
#
# for self._cur_epoch in range(self._start_epoch, self._max_epoch_train_decoder):
# pretrain_decoder_dict = IICPretrainDecoderEpoch(
# model=self._model, projection_head=self._projector, projection_classifier=self._projector_iic,
# optimizer=self._optimizer,
# pretrain_decoder_loader=self._pretrain_loader_iter,
# contrastive_criterion=SupConLoss(), num_batches=self._num_batches,
# cur_epoch=self._cur_epoch, device=self._device,
# ).run()
# self._scheduler.step()
# storage_dict = StorageIncomeDict(PRETRAIN_DECODER=pretrain_decoder_dict, )
# self._pretrain_encoder_storage.put_from_dict(storage_dict, epoch=self._cur_epoch)
# self._writer.add_scalar_with_StorageDict(storage_dict, self._cur_epoch)
# self._save_to("last.pth", path=os.path.join(self._save_dir, "pretrain_decoder"))
# self.train_decoder_done = True
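
The hunk above truncates the definitions of the two projection heads added in pretrain_encoder_init; a purely illustrative sketch of such a dual-head setup (layer sizes, depth, and the Flatten stand-in are assumptions, since the actual definitions are not visible in this diff) could look like:

from torch import nn

class Flatten(nn.Module):  # stand-in for contrastyou.trainer._utils.Flatten
    def forward(self, x):
        return x.flatten(start_dim=1)

projector_contrastive = nn.Sequential(  # global feature -> embedding for SupConLoss
    nn.AdaptiveAvgPool2d((1, 1)),
    Flatten(),
    nn.Linear(256, 256),
    nn.LeakyReLU(0.01, inplace=True),
    nn.Linear(256, 128),
)
projector_iic = nn.Sequential(  # global feature -> soft assignment over num_clusters=40
    nn.AdaptiveAvgPool2d((1, 1)),
    Flatten(),
    nn.Linear(256, 40),
    nn.Softmax(dim=1),
)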
50 changes: 50 additions & 0 deletions run_script_iic.py
@@ -0,0 +1,50 @@
import argparse
from itertools import cycle

from deepclustering2.cchelper import JobSubmiter

parser = argparse.ArgumentParser()

parser.add_argument("-l", "--label_ratio", default=0.1, type=float)
parser.add_argument("-n", "--trainer_name", required=True, type=str)
parser.add_argument("-b", "--num_batches", default=100, type=int)
parser.add_argument("-s", "--random_seed", default=1, type=int)
parser.add_argument("-o", "--contrast_on", default="partition", type=str)
parser.add_argument("-c", "--num_clusters", default=40, type=int)

args = parser.parse_args()

num_batches = args.num_batches
random_seed = args.random_seed

labeled_data_ratio = args.label_ratio
unlabeled_data_ratio = 1 - labeled_data_ratio

trainer_name = args.trainer_name
assert trainer_name == "iiccontrast"
contrast_on = args.contrast_on
save_dir = f"iic_contrast/label_data_ration_{labeled_data_ratio}/{trainer_name}/contrast_on_{contrast_on}"

common_opts = f" Trainer.name={trainer_name} PretrainEncoder.group_option={contrast_on} " \
f" RetrainEncoder.num_clusters={args.num_clusters} RandomSeed={random_seed} " \
f" Data.labeled_data_ratio={labeled_data_ratio} Data.unlabeled_data_ratio={unlabeled_data_ratio} " \
f" Trainer.num_batches={num_batches} "
if trainer_name == "contrastMT":
common_opts += f" FineTune.reg_weight={args.reg_weight} "

jobs = [
f"python -O main_contrast.py {common_opts} Trainer.save_dir={save_dir}/baseline Trainer.train_encoder=False Trainer.train_decoder=False ",
f"python -O main_contrast.py {common_opts} Trainer.save_dir={save_dir}/iic_0.0 Trainer.train_encoder=True Trainer.train_decoder=False PretrainEncoder.iic_weight=0.0",
f"python -O main_contrast.py {common_opts} Trainer.save_dir={save_dir}/iic_0.5 Trainer.train_encoder=True Trainer.train_decoder=False PretrainEncoder.iic_weight=0.5",
f"python -O main_contrast.py {common_opts} Trainer.save_dir={save_dir}/iic_1.0 Trainer.train_encoder=True Trainer.train_decoder=False PretrainEncoder.iic_weight=1.0",
]

# CC things
accounts = cycle(["def-chdesa", "def-mpederso", "rrg-mpederso"])

jobsubmiter = JobSubmiter(project_path="./", on_local=True, time=4)
for j in jobs:
jobsubmiter.prepare_env(["source ./venv/bin/activate ", "export OMP_NUM_THREADS=1", ])
jobsubmiter.account = next(accounts)
jobsubmiter.run(j)
# print(j)
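
For reference, the sweep above can be launched with an invocation along the lines of python run_script_iic.py -n iiccontrast -l 0.1 -o partition -c 40 -s 1 (flags follow the argparse definitions above; -n must be "iiccontrast" because of the assert, and JobSubmiter is instantiated with on_local=True, which presumably runs the jobs locally rather than submitting them to the listed cluster accounts).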
