Commit 58485d5

try to make decode training work

jizong committed Jul 25, 2020
1 parent 9dce6c8 commit 58485d5
Showing 16 changed files with 542 additions and 179 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -18,7 +18,7 @@ venv
 __pycache__/
 *.py[cod]
 *$py.class
-
+*.out
 # C extensions
 *.so
12 changes: 9 additions & 3 deletions config/config.yaml
@@ -18,9 +18,15 @@ Trainer:
  max_epoch_train_finetune: 100
  train_encoder: True
  train_decoder: True
  # for mt trainer
  transform_axis: [1, 2]
  reg_weight: 10

PretrainEncoder:
  group_option: patient

PretrainDecoder:
  null
FineTune:
  reg_weight: 15

#Checkpoint: runs/test_pipeline
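
For orientation, a minimal sketch of how such a config might be consumed at runtime, assuming plain PyYAML (the repository's actual config loader is not shown in this diff):

    import yaml

    with open("config/config.yaml") as f:
        config = yaml.safe_load(f)

    reg_weight = config["Trainer"]["reg_weight"]              # 10, used by the mt trainer
    group_option = config["PretrainEncoder"]["group_option"]  # "patient"
    assert config["PretrainDecoder"] is None                  # no decoder-pretraining options yet
    finetune_reg_weight = config["FineTune"]["reg_weight"]    # 15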
47 changes: 12 additions & 35 deletions contrastyou/augment/__init__.py
@@ -1,47 +1,24 @@
-from typing import Callable, Union, List, Tuple
+from torchvision import transforms

-from deepclustering2.augment import pil_augment, SequentialWrapper
+from contrastyou.augment.sequential_wrapper import SequentialWrapperTwice, SequentialWrapper
+from deepclustering2.augment import pil_augment


-class SequentialWrapperTwice(SequentialWrapper):
-
-    def __init__(self, img_transform: Callable = None, target_transform: Callable = None,
-                 if_is_target: Union[List[bool], Tuple[bool, ...]] = []) -> None:
-        super().__init__(img_transform, target_transform, if_is_target)
-
-    def __call__(
-            self, *imgs, random_seed=None
-    ):
-        return [
-            super(SequentialWrapperTwice, self).__call__(*imgs, random_seed=random_seed),
-            super(SequentialWrapperTwice, self).__call__(*imgs, random_seed=random_seed),
-        ]
-
-
-class ACDC_transforms:
+class ACDCTransforms:
     train = SequentialWrapperTwice(
-        pil_augment.Compose([
+        comm_transform=pil_augment.Compose([
             pil_augment.RandomCrop(224),
             pil_augment.RandomRotation(30),
-            pil_augment.ToTensor()
         ]),
-        pil_augment.Compose([
-            pil_augment.RandomCrop(224),
-            pil_augment.RandomRotation(30),
+        img_transform=pil_augment.Compose([
+            transforms.ColorJitter(brightness=[0.5, 1.5], contrast=[0.5, 1.5], saturation=[0.5, 1.5]),
+            transforms.ToTensor()
+        ]),
+        target_transform=pil_augment.Compose([
             pil_augment.ToLabel()
         ]),
-        if_is_target=[False, True]
+        total_freedom=True
     )
     val = SequentialWrapper(
-        pil_augment.Compose([
-            pil_augment.CenterCrop(224),
-            pil_augment.ToTensor()
-        ]),
-        pil_augment.Compose([
-            pil_augment.CenterCrop(224),
-            pil_augment.ToLabel()
-        ]),
-        if_is_target=[False, True]
+        comm_transform=pil_augment.CenterCrop(224)
     )
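
A usage sketch of the new transform API, assuming the reconstructed classes above and hypothetical ACDC file names:

    from PIL import Image

    from contrastyou.augment import ACDCTransforms

    img = Image.open("patient001_00_01.png")        # hypothetical slice image
    target = Image.open("patient001_00_01_gt.png")  # hypothetical label map

    # Training: SequentialWrapperTwice returns two independently augmented
    # views of the same pair; each view is a [img_tensor, target_tensor] list.
    view1, view2 = ACDCTransforms.train(imgs=[img], targets=[target])

    # Validation: one deterministic view, center-cropped identically for
    # image and target.
    img_t, target_t = ACDCTransforms.val(imgs=[img], targets=[target])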
100 changes: 100 additions & 0 deletions contrastyou/augment/sequential_wrapper.py
@@ -0,0 +1,100 @@
import random
from typing import Callable, List

from PIL import Image
from torch import Tensor

from deepclustering2.augment import pil_augment
from deepclustering2.decorator import FixRandomSeed


class SequentialWrapper:

    def __init__(
        self,
        comm_transform: Callable[[Image.Image], Image.Image] = None,
        img_transform: Callable[[Image.Image], Tensor] = pil_augment.ToTensor(),
        target_transform: Callable[[Image.Image], Tensor] = pil_augment.ToLabel()
    ) -> None:
        """
        :param comm_transform: common geometric transformation, applied identically to images and targets
        :param img_transform: transformation applied only to images
        :param target_transform: transformation applied only to targets
        """
        self._comm_transform = comm_transform
        self._img_transform = img_transform
        self._target_transform = target_transform

    def __call__(self, imgs: List[Image.Image], targets: List[Image.Image] = None, comm_seed=None, img_seed=None,
                 target_seed=None):
        _comm_seed: int = int(random.randint(0, int(1e5))) if comm_seed is None else int(comm_seed)  # type: ignore
        imgs_after_comm, targets_after_comm = imgs, targets
        if self._comm_transform:
            imgs_after_comm, targets_after_comm = [], []
            for img in imgs:
                with FixRandomSeed(_comm_seed):
                    img_ = self._comm_transform(img)
                imgs_after_comm.append(img_)
            if targets:
                for target in targets:
                    with FixRandomSeed(_comm_seed):
                        target_ = self._comm_transform(target)
                    targets_after_comm.append(target_)
        imgs_after_img_transform = []
        targets_after_target_transform = []
        _img_seed: int = int(random.randint(0, int(1e5))) if img_seed is None else int(img_seed)  # type: ignore
        for img in imgs_after_comm:
            with FixRandomSeed(_img_seed):
                img_ = self._img_transform(img)
            imgs_after_img_transform.append(img_)

        _target_seed: int = int(random.randint(0, int(1e5))) if target_seed is None else int(target_seed)  # type: ignore
        if targets_after_comm:
            for target in targets_after_comm:
                with FixRandomSeed(_target_seed):
                    target_ = self._target_transform(target)
                targets_after_target_transform.append(target_)

        if targets is None:
            targets_after_target_transform = None

        if targets_after_target_transform is None:
            return imgs_after_img_transform
        return [*imgs_after_img_transform, *targets_after_target_transform]

    def __repr__(self):
        return (
            f"comm_transform:{self._comm_transform}\n"
            f"img_transform:{self._img_transform}.\n"
            f"target_transform: {self._target_transform}"
        )


class SequentialWrapperTwice(SequentialWrapper):

    def __init__(self, comm_transform: Callable[[Image.Image], Image.Image] = None,
                 img_transform: Callable[[Image.Image], Tensor] = pil_augment.ToTensor(),
                 target_transform: Callable[[Image.Image], Tensor] = pil_augment.ToLabel(),
                 total_freedom=True) -> None:
        """
        :param total_freedom: if True, the two generated views draw independent seeds for the common,
            image, and target transforms; otherwise only the image seed differs between the two views
        """
        super().__init__(comm_transform, img_transform, target_transform)
        self._total_freedom = total_freedom

    def __call__(self, imgs: List[Image.Image], targets: List[Image.Image] = None, global_seed=None, **kwargs):
        global_seed = int(random.randint(0, int(1e5))) if global_seed is None else int(global_seed)  # type: ignore
        with FixRandomSeed(global_seed):
            comm_seed1, comm_seed2 = int(random.randint(0, int(1e5))), int(random.randint(0, int(1e5)))
            img_seed1, img_seed2 = int(random.randint(0, int(1e5))), int(random.randint(0, int(1e5)))
            target_seed1, target_seed2 = int(random.randint(0, int(1e5))), int(random.randint(0, int(1e5)))
            if self._total_freedom:
                return [
                    super().__call__(imgs, targets, comm_seed1, img_seed1, target_seed1),
                    super().__call__(imgs, targets, comm_seed2, img_seed2, target_seed2),
                ]
            return [
                super().__call__(imgs, targets, comm_seed1, img_seed1, target_seed1),
                super().__call__(imgs, targets, comm_seed1, img_seed2, target_seed1),
            ]
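
The per-call seeding is what keeps geometric augmentation consistent: every application of comm_transform within one call runs under the same seed, so image and target receive the identical crop and rotation. FixRandomSeed is imported from deepclustering2 and is not part of this diff; a minimal sketch of such a context manager (an assumption, not the library's actual code) would save, reseed, and restore the RNG states:

    import random
    from contextlib import contextmanager

    import numpy as np
    import torch


    @contextmanager
    def fix_random_seed(seed: int):
        # Save the current RNG states, reseed all three generators, and
        # restore the saved states on exit.
        py_state = random.getstate()
        np_state = np.random.get_state()
        torch_state = torch.get_rng_state()
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        try:
            yield
        finally:
            random.setstate(py_state)
            np.random.set_state(np_state)
            torch.set_rng_state(torch_state)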
2 changes: 1 addition & 1 deletion contrastyou/dataloader/_seg_datset.py
@@ -56,7 +56,7 @@ class _SamplerIterator:
     def __init__(self, group2index, partion2index, group_sample_num=4, partition_sample_num=1) -> None:
         self._group2index, self._partition2index = dcp(group2index), dcp(partion2index)

-        assert group_sample_num >= 1 and group_sample_num <= len(self._group2index.keys()), group_sample_num
+        assert 1 <= group_sample_num <= len(self._group2index.keys()), group_sample_num
         self._group_sample_num = group_sample_num
         self._partition_sample_num = partition_sample_num
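
The rewritten assert uses Python's chained comparison, which reads like the mathematical statement and evaluates group_sample_num only once; the two forms are equivalent:

    group_sample_num, n_groups = 4, 10
    assert (1 <= group_sample_num <= n_groups) == (group_sample_num >= 1 and group_sample_num <= n_groups)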

18 changes: 11 additions & 7 deletions contrastyou/dataloader/acdc_dataset.py
@@ -1,12 +1,13 @@
 import os
 import re
+from pathlib import Path
 from typing import List, Tuple, Union

 import numpy as np
 from torch import Tensor

+from contrastyou.augment.sequential_wrapper import SequentialWrapper
 from contrastyou.dataloader._seg_datset import ContrastDataset
-from deepclustering2.augment import SequentialWrapper
 from deepclustering2.dataset import ACDCDataset as _ACDCDataset, ACDCSemiInterface as _ACDCSemiInterface

@@ -15,31 +16,34 @@ class ACDCDataset(ContrastDataset, _ACDCDataset):
     zip_name = "ACDC_contrast.zip"
     folder_name = "ACDC_contrast"

-    def __init__(self, root_dir: str, mode: str, transforms: SequentialWrapper = None,
+    def __init__(self, root_dir: str, mode: str, transforms: SequentialWrapper = SequentialWrapper(),
                  verbose=True, *args, **kwargs) -> None:
         super().__init__(root_dir, mode, ["img", "gt"], transforms, verbose)
         self._acdc_info = np.load(os.path.join(self._root_dir, "acdc_info.npy"), allow_pickle=True).item()
         assert isinstance(self._acdc_info, dict) and len(self._acdc_info) == 200
+        self._transform = transforms

     def __getitem__(self, index) -> Tuple[List[Tensor], str, str, str]:
-        data, filename = super().__getitem__(index)
+        [img_png, target_png], filename_list = self._getitem_index(index)
+        filename = Path(filename_list[0]).stem
+        data = self._transform(imgs=[img_png], targets=[target_png])
         partition = self._get_partition(filename)
         group = self._get_group(filename)
         return data, filename, partition, group

     def _get_group(self, filename) -> Union[str, int]:
-        return self._get_group_name(filename)
+        return str(self._get_group_name(filename))

     def _get_partition(self, filename) -> Union[str, int]:
         # set partition
         max_len_given_group = self._acdc_info[self._get_group_name(filename)]
         cutting_point = max_len_given_group // 3
         cur_index = int(re.compile(r"\d+").findall(filename)[-1])
         if cur_index <= cutting_point - 1:
-            return 0
+            return str(0)
         if cur_index <= 2 * cutting_point:
-            return 1
-        return 2
+            return str(1)
+        return str(2)

     def show_paritions(self) -> List[Union[str, int]]:
         return [self._get_partition(f) for f in list(self._filenames.values())[0]]
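
The partition rule splits each patient's slice stack into thirds by slice index, presumably so contrastive grouping can distinguish slice levels within a patient. Restated as a standalone sketch (hypothetical helper; inputs are a slice filename ending in its index and the patient's slice count from acdc_info):

    import re


    def get_partition(filename: str, max_len_given_group: int) -> str:
        cutting_point = max_len_given_group // 3
        cur_index = int(re.findall(r"\d+", filename)[-1])
        if cur_index <= cutting_point - 1:
            return "0"  # first third of the slice stack
        if cur_index <= 2 * cutting_point:
            return "1"  # middle third
        return "2"      # remaining slices


    # e.g. a 10-slice patient: slices 0-2 -> "0", 3-6 -> "1", 7-9 -> "2"
    assert get_partition("patient001_00_05", 10) == "1"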
82 changes: 71 additions & 11 deletions contrastyou/epocher/IIC_epocher.py
@@ -1,40 +1,45 @@
+import random
+
 import torch
 from contrastyou.epocher._utils import unfold_position
 from deepclustering2 import optim
 from deepclustering2.decorator import FixRandomSeed
 from deepclustering2.meters2 import EpochResultDict
 from deepclustering2.optim import get_lrs_from_optimizer
 from deepclustering2.trainer.trainer import T_loader, T_loss
 from torch import nn
 from torch.nn import functional as F

+from .contrast_epocher import PretrainDecoderEpoch as _PretrainDecoderEpoch
 from .contrast_epocher import PretrainEncoderEpoch as _PretrainEncoderEpoch


 class IICPretrainEcoderEpoch(_PretrainEncoderEpoch):

     def __init__(self, model: nn.Module, projection_head: nn.Module, projection_classifier: nn.Module,
-                 optimizer: optim.Optimizer, pretrain_encoder_loader: T_loader = None,
-                 contrastive_criterion: T_loss = None, num_batches: int = 0,
-                 cur_epoch=0, device="cpu", iic_weight_ratio=1, *args, **kwargs) -> None:
+                 optimizer: optim.Optimizer, pretrain_encoder_loader: T_loader,
+                 contrastive_criterion: T_loss, num_batches: int = 0,
+                 cur_epoch=0, device="cpu", group_option: str = "partition", iic_weight_ratio=1) -> None:
         """
         :param model:
-        :param projection_head: here the projection head should be a classifier
+        :param projection_head:
+        :param projection_classifier: classification head
         :param optimizer:
-        :param pretrain_encoder_loader:
+        :param pretrain_encoder_loader: infinite dataloader with `total_freedom=True`
         :param contrastive_criterion:
         :param num_batches:
        :param cur_epoch:
         :param device:
-        :param args:
-        :param kwargs:
+        :param iic_weight_ratio: weight balancing the IIC loss against the contrastive loss
         """
+        super(IICPretrainEcoderEpoch, self).__init__(model, projection_head, optimizer, pretrain_encoder_loader,
+                                                     contrastive_criterion, num_batches,
+                                                     cur_epoch, device, group_option=group_option)
+        assert pretrain_encoder_loader is not None
         self._projection_classifier = projection_classifier
         from ..losses.iic_loss import IIDLoss
         self._iic_criterion = IIDLoss()
         self._iic_weight_ratio = iic_weight_ratio
-        super().__init__(model, projection_head, optimizer, pretrain_encoder_loader, contrastive_criterion, num_batches,
-                         cur_epoch, device, *args, **kwargs)

     def _run(self, *args, **kwargs) -> EpochResultDict:
         self._model.train()
@@ -46,11 +51,12 @@ def _run(self, *args, **kwargs) -> EpochResultDict:
                 (img, _), (img_tf, _), filename, partition_list, group_list = self._preprocess_data(data, self._device)
                 _, (e5, *_), *_ = self._model(torch.cat([img, img_tf], dim=0), return_features=True)
                 global_enc, global_tf_enc = torch.chunk(F.normalize(self._projection_head(e5), dim=1), chunks=2, dim=0)
+                global_probs, global_tf_probs = torch.chunk(self._projection_classifier(e5), chunks=2, dim=0)
                 # fixme: some IIC code is still missing here
                 labels = self._label_generation(partition_list, group_list)
                 contrastive_loss = self._contrastive_criterion(torch.stack([global_enc, global_tf_enc], dim=1),
                                                                labels=labels)
-                iic_loss = self._iic_criterion()  # todo
+                iic_loss = self._iic_criterion(global_probs, global_tf_probs)  # todo
                 total_loss = self._iic_weight_ratio * iic_loss + (1 - self._iic_weight_ratio) * contrastive_loss
                 self._optimizer.zero_grad()
                 total_loss.backward()
@@ -61,3 +67,57 @@ def _run(self, *args, **kwargs) -> EpochResultDict:
         report_dict = self.meters.tracking_status()
         indicator.set_postfix_dict(report_dict)
         return report_dict
+
+
+class IICPretrainDecoderEpoch(_PretrainDecoderEpoch):
+    def __init__(self, model: nn.Module, projection_head: nn.Module, projection_classifier: nn.Module,
+                 optimizer: optim.Optimizer, pretrain_decoder_loader: T_loader, contrastive_criterion: T_loss,
+                 iic_criterion: T_loss, num_batches: int = 0, cur_epoch=0, device="cpu") -> None:
+        super().__init__(model, projection_head, optimizer, pretrain_decoder_loader, contrastive_criterion,
+                         num_batches, cur_epoch, device)
+        self._projection_classifier = projection_classifier
+        self._iic_criterion = iic_criterion
+
+    def _run(self, *args, **kwargs) -> EpochResultDict:
+        self._model.train()
+        assert self._model.training, self._model.training
+        self.meters["lr"].add(get_lrs_from_optimizer(self._optimizer)[0])
+
+        with tqdm(range(self._num_batches)).set_desc_from_epocher(self) as indicator:  # noqa
+            for i, data in zip(indicator, self._pretrain_decoder_loader):
+                (img, _), (img_ctf, _), filename, partition_list, group_list = self._preprocess_data(data, self._device)
+                seed = random.randint(0, int(1e5))
+                with FixRandomSeed(seed):
+                    img_gtf = torch.stack([self._transformer(x) for x in img], dim=0)
+                assert img_gtf.shape == img.shape, (img_gtf.shape, img.shape)
+
+                _, *_, (_, d4, *_) = self._model(torch.cat([img_gtf, img_ctf], dim=0), return_features=True)
+                d4_gtf, d4_ctf = torch.chunk(d4, chunks=2, dim=0)
+                with FixRandomSeed(seed):
+                    d4_ctf_gtf = torch.stack([self._transformer(x) for x in d4_ctf], dim=0)
+                assert d4_ctf_gtf.shape == d4_ctf.shape, (d4_ctf_gtf.shape, d4_ctf.shape)
+                d4_tf = torch.cat([d4_gtf, d4_ctf_gtf], dim=0)
+                local_enc_tf, local_enc_tf_ctf = torch.chunk(self._projection_head(d4_tf), chunks=2, dim=0)
+                # todo: convert representation to distance
+                local_enc_unfold, _ = unfold_position(local_enc_tf, partition_num=(2, 2))
+                local_tf_enc_unfold, _fold_partition = unfold_position(local_enc_tf_ctf, partition_num=(2, 2))
+                b, *_ = local_enc_unfold.shape
+                local_enc_unfold_norm = F.normalize(local_enc_unfold.view(b, -1), p=2, dim=1)
+                local_tf_enc_unfold_norm = F.normalize(local_tf_enc_unfold.view(b, -1), p=2, dim=1)
+
+                labels = self._label_generation(partition_list, group_list, _fold_partition)
+                contrastive_loss = self._contrastive_criterion(
+                    torch.stack([local_enc_unfold_norm, local_tf_enc_unfold_norm], dim=1),
+                    labels=labels
+                )
+                if torch.isnan(contrastive_loss):
+                    raise RuntimeError(contrastive_loss)
+                self._optimizer.zero_grad()
+                contrastive_loss.backward()
+                self._optimizer.step()
+                # todo: meter recording
+                with torch.no_grad():
+                    self.meters["contrastive_loss"].add(contrastive_loss.item())
+        report_dict = self.meters.tracking_status()
+        indicator.set_postfix_dict(report_dict)
+        return report_dict
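
IIDLoss is imported from contrastyou.losses.iic_loss, which this commit does not show. Assuming it implements the standard IIC objective (Ji et al., 2019) over the two views' cluster probabilities, a minimal sketch is below; the encoder epocher then mixes it with the contrastive term as total_loss = w * iic_loss + (1 - w) * contrastive_loss:

    import torch


    def iid_loss(p1: torch.Tensor, p2: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
        # p1, p2: (batch, k) cluster probabilities for the two views; assumed
        # already softmax-normalized by the projection classifier.
        joint = (p1.unsqueeze(2) * p2.unsqueeze(1)).mean(dim=0)  # (k, k) joint distribution
        joint = ((joint + joint.t()) / 2).clamp_min(eps)         # symmetrize
        marg1 = joint.sum(dim=1, keepdim=True)                   # row marginal
        marg2 = joint.sum(dim=0, keepdim=True)                   # column marginal
        # Negative mutual information: minimizing this maximizes I(z1; z2).
        return -(joint * (joint.log() - marg1.log() - marg2.log())).sum()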