Skip to content

Commit

Permalink
adding mean teacher as a baseline
Browse files Browse the repository at this point in the history
  • Loading branch information
jizong committed Jul 21, 2020
1 parent 83efc97 commit f0d7cc9
Show file tree
Hide file tree
Showing 14 changed files with 353 additions and 338 deletions.
Empty file added byol_demo/__init__.py
Empty file.
30 changes: 4 additions & 26 deletions demo_cifar.py → byol_demo/byol_cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torch import nn
from torch.utils.data import DataLoader

from byol_demo.utils import loss_fn, TransTwice
from contrastyou import DATA_PATH, PROJECT_PATH
from deepclustering2 import ModelMode
from deepclustering2.augment import pil_augment
Expand All @@ -19,22 +20,9 @@
from deepclustering2.tqdm import tqdm
from deepclustering2.trainer.trainer import T_loader, Trainer
from deepclustering2.writer import SummaryWriter
from torch.nn import functional as F

class TransTwice:
    """Callable wrapper that applies the given transform twice to one image.

    Used to produce the two independently-augmented views required by
    contrastive / BYOL-style training.
    """

    def __init__(self, transform) -> None:
        super().__init__()
        self._transform = transform

    def __call__(self, img):
        # Two separate calls so any randomness in the transform differs per view.
        first = self._transform(img)
        second = self._transform(img)
        return [first, second]

def loss_fn(x, y):
    """BYOL regression loss: ``2 - 2 * cos(x, y)`` along the last dimension.

    Both inputs are L2-normalized before the dot product, so the result is 0
    for identical directions and 2 for opposite ones.
    """
    similarity = (F.normalize(x, dim=-1, p=2) * F.normalize(y, dim=-1, p=2)).sum(dim=-1)
    return 2 - 2 * similarity

# todo: redo that
class ContrastEpocher(_Epocher):
def __init__(self, model: Model, target_model: EMA_Model, data_loader: T_loader, num_batches: int = 1000,
cur_epoch=0, device="cpu") -> None:
Expand Down Expand Up @@ -84,6 +72,7 @@ def _preprocess_data(data, device):
return (data[0][0].to(device), data[0][1].to(device)), data[1].to(device)


# todo: redo that
class FineTuneEpocher(_Epocher):
def __init__(self, model: Model, classify_model: Model, data_loader: T_loader, num_batches: int = 1000,
cur_epoch=0, device="cpu") -> None:
Expand Down Expand Up @@ -134,6 +123,7 @@ def _preprocess_data(data, device):
return (data[0][0].to(device), data[0][1].to(device)), data[1].to(device)


# todo: redo that
class EvalEpocher(FineTuneEpocher):

def __init__(self, model: Model, classify_model: Model, val_loader, num_batches: int = 1000, cur_epoch=0,
Expand Down Expand Up @@ -184,18 +174,6 @@ def __init__(self, model: Model, target_model: EMA_Model, classify_model: Model,
self._finetune_loader = finetune_loader
self._val_loader = val_loader

def pretrain_epoch(self, *args, **kwargs):
epocher = ContrastEpocher.create_from_trainer(self)
return epocher.run()

def finetune_epoch(self, *args, **kwargs):
epocher = FineTuneEpocher.create_from_trainer(self)
return epocher.run()

def eval_epoch(self):
epocher = EvalEpocher.create_from_trainer(self)
return epocher.run()

def _start_contrastive_training(self):
save_path = os.path.join(self._save_dir, "pretrain")
Path(save_path).mkdir(exist_ok=True, parents=True)
Expand Down
17 changes: 17 additions & 0 deletions byol_demo/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from torch.nn import functional as F


class TransTwice:
    """Wraps a transform so a single call yields two augmented views.

    Each view comes from a separate invocation of the wrapped transform,
    so stochastic augmentations produce two different results.
    """

    def __init__(self, transform) -> None:
        super().__init__()
        self._transform = transform

    def __call__(self, img):
        return [self._transform(img) for _ in range(2)]


def loss_fn(x, y):
    """BYOL prediction loss: ``2 - 2 * cosine_similarity(x, y)`` on the last dim.

    Equals 0 when x and y point the same way, 2 when orthogonal, 4 when opposite.
    """
    x_hat = F.normalize(x, p=2, dim=-1)
    y_hat = F.normalize(y, p=2, dim=-1)
    cosine = (x_hat * y_hat).sum(dim=-1)
    return 2 - 2 * cosine
15 changes: 10 additions & 5 deletions config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,23 @@ Arch:
input_dim: 1
num_classes: 4

Data:
labeled_data_ratio: 0.05
unlabeled_data_ratio: 0.95

Trainer:
save_dir: test_pipeline
name: contrast
save_dir: test_semi_trainer
device: cuda
num_batches: 1000
num_batches: 500
max_epoch_train_decoder: 100
max_epoch_train_encoder: 100
max_epoch_train_finetune: 100
train_encoder: True
train_decoder: True
# for mt trainer
transform_axis: [1, 2]
reg_weight: 10

Data:
labeled_data_ratio: 0.05
unlabeled_data_ratio: 0.95

#Checkpoint: runs/test_pipeline
2 changes: 2 additions & 0 deletions contrastyou/epocher/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .base_epocher import EvalEpoch, SimpleFineTuneEpoch, MeanTeacherEpocher
from .contrast_epocher import PretrainDecoderEpoch, PretrainEncoderEpoch
42 changes: 38 additions & 4 deletions contrastyou/epocher/_utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
import random

from deepclustering2.type import to_device, torch
from deepclustering2.utils import assert_list
from torch import Tensor


def preprocess_input_with_twice_transformation(data, device):
def preprocess_input_with_twice_transformation(data, device, non_blocking=True):
[(image, target), (image_tf, target_tf)], filename, partition_list, group_list = \
to_device(data[0], device), data[1], data[2], data[3]
to_device(data[0], device, non_blocking), data[1], data[2], data[3]
return (image, target), (image_tf, target_tf), filename, partition_list, group_list


def preprocess_input_with_single_transformation(data, device):
return data[0][0].to(device), data[0][1].to(device), data[1], data[2], data[3]
def preprocess_input_with_single_transformation(data, device, non_blocking=True):
    """Move a single-view batch onto *device*; pass metadata through unchanged.

    Expects ``data`` laid out as ``((image, target), filename, partition, group)``
    — TODO confirm against the dataloader's collate function.
    """
    batch, filename, partition, group = data[0], data[1], data[2], data[3]
    image = batch[0].to(device, non_blocking=non_blocking)
    target = batch[1].to(device, non_blocking=non_blocking)
    return image, target, filename, partition, group


def unfold_position(features: torch.Tensor, partition_num=(4, 4), ):
Expand All @@ -27,6 +32,35 @@ def unfold_position(features: torch.Tensor, partition_num=(4, 4), ):
return torch.cat(result, dim=0), result_flag


class TensorRandomFlip:
    """Randomly flips a tensor along each configured axis with probability 0.5.

    ``axis`` may be a single int, a list/tuple of ints, or None (identity).
    The input tensor is cloned, so the caller's tensor is never mutated.
    """

    def __init__(self, axis=None) -> None:
        if isinstance(axis, int):
            self._axis = [axis]
        elif isinstance(axis, (list, tuple)):
            # BUG FIX: the original line read `assert_list(...), axis` — a bare
            # tuple expression whose validation result was discarded, so non-int
            # axes were silently accepted. Validate (and fail loudly) instead.
            assert all(isinstance(x, int) for x in axis), axis
            self._axis = axis
        elif axis is None:
            self._axis = axis
        else:
            raise ValueError(str(axis))

    def __call__(self, tensor: Tensor) -> Tensor:
        # Clone first so neither branch can alias the caller's tensor.
        tensor = tensor.clone()
        if self._axis is not None:
            for _one_axis in self._axis:
                # Each configured axis is flipped independently with p = 0.5.
                if random.random() < 0.5:
                    tensor = tensor.flip(_one_axis)
        return tensor

    def __repr__(self):
        string = f"{self.__class__.__name__}"
        # Empty suffix when no axes are configured (None or empty list).
        axis = "" if not self._axis else f" with axis={self._axis}."
        return string + axis


if __name__ == '__main__':
features = torch.randn(10, 3, 256, 256, requires_grad=True)

Expand Down
Loading

0 comments on commit f0d7cc9

Please sign in to comment.