Commit 9dce6c8
jizong committed Jul 23, 2020
1 parent: e9ccf20
Showing 4 changed files with 280 additions and 51 deletions.
@@ -1,60 +1,63 @@
import torch
from torch import nn

from deepclustering2 import optim
from deepclustering2.meters2 import EpochResultDict
from deepclustering2.tqdm import tqdm
from deepclustering2.optim import get_lrs_from_optimizer
from deepclustering2.trainer.trainer import T_loader, T_loss
from deepclustering2.utils import simplex, class2one_hot
from .base_epocher import SemiEpocher


class TrainEpoch(SemiEpocher.TrainEpoch):

    def __init__(self, model: nn.Module, labeled_loader: T_loader, unlabeled_loader: T_loader, sup_criteiron: T_loss,
                 reg_criterion: T_loss, num_batches: int = 100, cur_epoch=0, device="cpu",
                 reg_weight: float = 0.001) -> None:
        super().__init__(model, labeled_loader, unlabeled_loader, sup_criteiron, reg_criterion, num_batches, cur_epoch,
                         device, reg_weight)
        assert reg_criterion  # todo: add constraints on the reg_criterion
from torch import nn
from torch.nn import functional as F

from .contrast_epocher import PretrainEncoderEpoch as _PretrainEncoderEpoch


class IICPretrainEcoderEpoch(_PretrainEncoderEpoch):

    def __init__(self, model: nn.Module, projection_head: nn.Module, projection_classifier: nn.Module,
                 optimizer: optim.Optimizer, pretrain_encoder_loader: T_loader = None,
                 contrastive_criterion: T_loss = None, num_batches: int = 0,
                 cur_epoch=0, device="cpu", iic_weight_ratio=1, *args, **kwargs) -> None:
        """
        :param model:
        :param projection_head: here the projection head should be a classifier
        :param optimizer:
        :param pretrain_encoder_loader:
        :param contrastive_criterion:
        :param num_batches:
        :param cur_epoch:
        :param device:
        :param args:
        :param kwargs:
        """
        assert pretrain_encoder_loader is not None
        self._projection_classifier = projection_classifier
        from ..losses.iic_loss import IIDLoss
        self._iic_criterion = IIDLoss()
        self._iic_weight_ratio = iic_weight_ratio
        super().__init__(model, projection_head, optimizer, pretrain_encoder_loader, contrastive_criterion, num_batches,
                         cur_epoch, device, *args, **kwargs)

    def _run(self, *args, **kwargs) -> EpochResultDict:
        self._model.train()
        assert self._model.training, self._model.training
        report_dict: EpochResultDict
        self.meters["lr"].add(self._model.get_lr()[0])
        self.meters["reg_weight"].add(self._reg_weight)

        with tqdm(range(self._num_batches)).set_desc_from_epocher(self) as indicator:
            for i, label_data, unlabel_data in zip(indicator, self._labeled_loader, self._unlabeled_loader):
                ((labelimage, labeltarget), (labelimage_tf, labeltarget_tf)), \
                    filename, partition_list, group_list = self._preprocess_data(label_data, self._device)
                ((unlabelimage, _), (unlabelimage_tf, _)), unlabel_filename, \
                    unlabel_partition_list, unlabel_group_list = self._preprocess_data(unlabel_data, self._device)

                predict_logits = self._model(
                    torch.cat([labelimage, labelimage_tf, unlabelimage, unlabelimage_tf], dim=0),
                    force_simplex=False)
                assert not simplex(predict_logits), predict_logits
                label_logit, label_logit_tf, unlabel_logit, unlabel_logit_tf \
                    = torch.split(predict_logits,
                                  [len(labelimage), len(labelimage_tf), len(unlabelimage), len(unlabelimage_tf)],
                                  dim=0)
                onehot_ltarget = class2one_hot(torch.cat([labeltarget.squeeze(), labeltarget_tf.squeeze()], dim=0),
                                               4)
                sup_loss = self._sup_criterion(torch.cat([label_logit, label_logit_tf], dim=0).softmax(1),
                                               onehot_ltarget)
                reg_loss = self._reg_criterion(unlabel_logit.softmax(1), unlabel_logit_tf.softmax(1))
                total_loss = sup_loss + reg_loss * self._reg_weight

                self._model.zero_grad()
        self.meters["lr"].add(get_lrs_from_optimizer(self._optimizer)[0])

        with tqdm(range(self._num_batches)).set_desc_from_epocher(self) as indicator:  # noqa
            for i, data in zip(indicator, self._pretrain_encoder_loader):
                (img, _), (img_tf, _), filename, partition_list, group_list = self._preprocess_data(data, self._device)
                _, (e5, *_), *_ = self._model(torch.cat([img, img_tf], dim=0), return_features=True)
                global_enc, global_tf_enc = torch.chunk(F.normalize(self._projection_head(e5), dim=1), chunks=2, dim=0)
                # fixme: here lack of some code for IIC
                labels = self._label_generation(partition_list, group_list)
                contrastive_loss = self._contrastive_criterion(torch.stack([global_enc, global_tf_enc], dim=1),
                                                               labels=labels)
                iic_loss = self._iic_criterion()  # todo
                total_loss = self._iic_weight_ratio * iic_loss + (1 - self._iic_weight_ratio) * contrastive_loss
                self._optimizer.zero_grad()
                total_loss.backward()
                self._model.step()

                self._optimizer.step()
                # todo: meter recording.
                with torch.no_grad():
                    self.meters["sup_loss"].add(sup_loss.item())
                    self.meters["ds"].add(label_logit.max(1)[1], labeltarget.squeeze(1),
                                          group_name=list(group_list))
                    self.meters["reg_loss"].add(reg_loss.item())
                    self.meters["contrastive_loss"].add(contrastive_loss.item())
                    report_dict = self.meters.tracking_status()
                    indicator.set_postfix_dict(report_dict)
        report_dict = self.meters.tracking_status()
        return report_dict
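
Editor's note: the new _run leaves the clustering term unfinished. A "# fixme" marks the missing pass through self._projection_classifier, and iic_loss = self._iic_criterion()  # todo is called without arguments. For reference, the objective that IIDLoss (imported from ..losses.iic_loss) presumably implements is the mutual-information loss of Invariant Information Clustering (Ji et al., 2019), computed from the soft cluster assignments of an image and its transformed view. A minimal sketch, assuming the classifier outputs row-wise softmax probabilities of shape (N, K); the exact signature of IIDLoss in this repository may differ, and the wiring in the comment is hypothetical:

import torch

def iid_loss(p1: torch.Tensor, p2: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # p1, p2: (N, K) soft cluster assignments for the two views, e.g. obtained by
    # torch.chunk(self._projection_classifier(e5), 2, dim=0) in the epocher above (hypothetical wiring).
    n, k = p1.shape
    # Joint distribution over cluster pairs, symmetrised and normalised.
    joint = (p1.unsqueeze(2) * p2.unsqueeze(1)).sum(dim=0)  # (K, K)
    joint = ((joint + joint.t()) / 2).clamp(min=eps)
    joint = joint / joint.sum()
    # Marginals, broadcast back to (K, K).
    p_i = joint.sum(dim=1, keepdim=True).expand(k, k)
    p_j = joint.sum(dim=0, keepdim=True).expand(k, k)
    # Negative mutual information: minimising this maximises I(z1; z2).
    return -(joint * (joint.log() - p_i.log() - p_j.log())).sum()

Note also that iic_weight_ratio defaults to 1, so the combined loss in _run reduces to the IIC term alone; a value in (0, 1) is presumably intended when mixing it with the contrastive loss.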
@@ -1,2 +1,3 @@
from .base_epocher import EvalEpoch, SimpleFineTuneEpoch, MeanTeacherEpocher
from .contrast_epocher import PretrainDecoderEpoch, PretrainEncoderEpoch
from .contrast_epocher import PretrainDecoderEpoch, PretrainEncoderEpoch
from .IIC_epocher import IICPretrainEcoderEpoch
@@ -0,0 +1,154 @@
import itertools
import os
from pathlib import Path

import torch
from contrastyou import PROJECT_PATH
from contrastyou.epocher import PretrainDecoderEpoch, SimpleFineTuneEpoch, IICPretrainEcoderEpoch
from contrastyou.epocher.base_epocher import EvalEpoch
from contrastyou.losses.contrast_loss import SupConLoss
from contrastyou.trainer._utils import Flatten
from contrastyou.trainer.contrast_trainer import ContrastTrainer
from deepclustering2.loss import KL_div
from deepclustering2.meters2 import StorageIncomeDict
from deepclustering2.schedulers import GradualWarmupScheduler
from deepclustering2.writer import SummaryWriter
from torch import nn


class IICContrastTrainer(ContrastTrainer):
    RUN_PATH = Path(PROJECT_PATH) / "runs"

    def pretrain_encoder_init(self, num_clusters=20, *args, **kwargs):
        # adding optimizer and scheduler
        self._projector_contrastive = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            Flatten(),
            nn.Linear(256, 256),
            nn.LeakyReLU(0.01, inplace=True),
            nn.Linear(256, 256),
        )
        self._projector_iic = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            Flatten(),
            nn.Linear(256, 256),
            nn.LeakyReLU(0.01, inplace=True),
            nn.Linear(256, num_clusters),
            nn.Softmax(1)
        )
        self._optimizer = torch.optim.Adam(
            itertools.chain(self._model.parameters(), self._projector_contrastive.parameters(),
                            self._projector_iic.parameters()), lr=1e-6, weight_decay=1e-5)  # noqa
        self._scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self._optimizer,
                                                                     self._max_epoch_train_encoder - 10, 0)
        self._scheduler = GradualWarmupScheduler(self._optimizer, 300, 10, self._scheduler)  # noqa

    def pretrain_encoder_run(self, *args, **kwargs):
        self.to(self._device)
        self._model.enable_grad_encoder()  # noqa
        self._model.disable_grad_decoder()  # noqa

        for self._cur_epoch in range(self._start_epoch, self._max_epoch_train_encoder):
            pretrain_encoder_dict = IICPretrainEcoderEpoch(
                model=self._model, projection_head=self._projector_contrastive,
                projection_classifier=self._projector_iic,
                optimizer=self._optimizer,
                pretrain_encoder_loader=self._pretrain_loader,
                contrastive_criterion=SupConLoss(), num_batches=self._num_batches,
                cur_epoch=self._cur_epoch, device=self._device
            ).run()
            self._scheduler.step()
            storage_dict = StorageIncomeDict(PRETRAIN_ENCODER=pretrain_encoder_dict, )
            self._pretrain_encoder_storage.put_from_dict(storage_dict, epoch=self._cur_epoch)
            self._writer.add_scalar_with_StorageDict(storage_dict, self._cur_epoch)
            self._save_to("last.pth", path=os.path.join(self._save_dir, "pretrain_encoder"))
        self.train_encoder_done = True

    def pretrain_decoder_init(self, *args, **kwargs):
        # adding optimizer and scheduler
        self._projector = nn.Sequential(
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.01, inplace=True),
            nn.Conv2d(64, 32, 3, 1, 1)
        )
        self._optimizer = torch.optim.Adam(itertools.chain(self._model.parameters(), self._projector.parameters()),
                                           lr=1e-6, weight_decay=0)
        self._scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self._optimizer,
                                                                     self._max_epoch_train_decoder - 10, 0)
        self._scheduler = GradualWarmupScheduler(self._optimizer, 300, 10, self._scheduler)

    def pretrain_decoder_run(self):
        self.to(self._device)
        self._projector.to(self._device)

        self._model.enable_grad_decoder()  # noqa
        self._model.disable_grad_encoder()  # noqa

        for self._cur_epoch in range(self._start_epoch, self._max_epoch_train_decoder):
            pretrain_decoder_dict = PretrainDecoderEpoch(
                model=self._model, projection_head=self._projector,
                optimizer=self._optimizer,
                pretrain_decoder_loader=self._pretrain_loader,
                contrastive_criterion=SupConLoss(), num_batches=self._num_batches,
                cur_epoch=self._cur_epoch, device=self._device
            ).run()
            self._scheduler.step()
            storage_dict = StorageIncomeDict(PRETRAIN_DECODER=pretrain_decoder_dict, )
            self._pretrain_encoder_storage.put_from_dict(storage_dict, epoch=self._cur_epoch)
            self._writer.add_scalar_with_StorageDict(storage_dict, self._cur_epoch)
            self._save_to("last.pth", path=os.path.join(self._save_dir, "pretrain_decoder"))
        self.train_decoder_done = True

    def finetune_network_init(self, *args, **kwargs):

        self._optimizer = torch.optim.Adam(self._model.parameters(), lr=1e-7, weight_decay=1e-5)
        self._scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self._optimizer,
                                                                     self._max_epoch_train_finetune - 10, 0)
        self._scheduler = GradualWarmupScheduler(self._optimizer, 200, 10, self._scheduler)
        self._sup_criterion = KL_div()

    def finetune_network_run(self, epocher_type=SimpleFineTuneEpoch):
        self.to(self._device)
        self._model.enable_grad_decoder()  # noqa
        self._model.enable_grad_encoder()  # noqa

        for self._cur_epoch in range(self._start_epoch, self._max_epoch_train_finetune):
            finetune_dict = epocher_type.create_from_trainer(self).run()
            val_dict, cur_score = EvalEpoch(self._model, val_loader=self._val_loader, sup_criterion=self._sup_criterion,
                                            cur_epoch=self._cur_epoch, device=self._device).run()
            self._scheduler.step()
            storage_dict = StorageIncomeDict(finetune=finetune_dict, val=val_dict)
            self._finetune_storage.put_from_dict(storage_dict, epoch=self._cur_epoch)
            self._writer.add_scalar_with_StorageDict(storage_dict, self._cur_epoch)
            self.save(cur_score, os.path.join(self._save_dir, "finetune"))

    def start_training(self, checkpoint: str = None):

        with SummaryWriter(str(self._save_dir)) as self._writer:  # noqa
            if self.train_encoder:
                self.pretrain_encoder_init()
                if checkpoint is not None:
                    try:
                        self.load_state_dict_from_path(os.path.join(checkpoint, "pretrain_encoder"))
                    except Exception as e:
                        raise RuntimeError(f"loading pretrain_encoder_checkpoint failed with {e}, ")

                if not self.train_encoder_done:
                    self.pretrain_encoder_run()
            if self.train_decoder:
                self.pretrain_decoder_init()
                if checkpoint is not None:
                    try:
                        self.load_state_dict_from_path(os.path.join(checkpoint, "pretrain_decoder"))
                    except Exception as e:
                        print(f"loading pretrain_decoder_checkpoint failed with {e}, ")
                if not self.train_decoder_done:
                    self.pretrain_decoder_run()
            self.finetune_network_init()
            if checkpoint is not None:
                try:
                    self.load_state_dict_from_path(os.path.join(checkpoint, "finetune"))
                except Exception as e:
                    print(f"loading finetune_checkpoint failed with {e}, ")
            self.finetune_network_run()
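
Editor's note: each stage builds its schedule the same way, a CosineAnnealingLR over max_epoch - 10 epochs wrapped in GradualWarmupScheduler(optimizer, multiplier, total_epoch, after_scheduler). Assuming the deepclustering2 wrapper follows the common pytorch-gradual-warmup-lr semantics (ramp the learning rate from its base value up to base x multiplier over total_epoch epochs, then hand control to the wrapped scheduler), the encoder stage would warm up from 1e-6 to 3e-4 over 10 epochs and then cosine-anneal toward 0. A rough equivalent with current PyTorch built-ins, as a sketch under those assumptions rather than a drop-in replacement (max_epochs is hypothetical, and toy parameters stand in for the encoder plus the two projection heads):

import torch
from torch import nn, optim

# Toy parameters in place of the encoder and both projectors.
params = nn.Linear(8, 8).parameters()

multiplier, warmup_epochs, max_epochs = 300, 10, 100
peak_lr = 1e-6 * multiplier  # 3e-4 once warm-up finishes, under the assumed semantics

optimizer = optim.Adam(params, lr=peak_lr, weight_decay=1e-5)
# Linear ramp from peak_lr / multiplier up to peak_lr over the warm-up epochs ...
warmup = optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0 / multiplier,
                                     end_factor=1.0, total_iters=warmup_epochs)
# ... then cosine decay over the remaining epochs, matching T_max = max_epoch - 10 above.
cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs - warmup_epochs, eta_min=0)
scheduler = optim.lr_scheduler.SequentialLR(optimizer, [warmup, cosine], milestones=[warmup_epochs])

for epoch in range(max_epochs):
    # ... run one training epoch here, e.g. an epocher's .run() ...
    scheduler.step()
    print(epoch, optimizer.param_groups[0]["lr"])  # inspect the warm-up, then the cosine decay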