From f7967eadc48bd8be5f3108d8c7b8fc6f6a93d9c0 Mon Sep 17 00:00:00 2001 From: Eve-ning Date: Fri, 24 Nov 2023 15:05:08 +0800 Subject: [PATCH 01/38] Update magic constants for new GPU --- tests/test_main.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_main.py b/tests/test_main.py index 2f3028b..8e5ee1c 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -9,13 +9,13 @@ def test_main_seed_fast(): epochs=epochs, train_iteration=train_iteration ) - assert best_acc_1 == 8.53333332570394 - assert mean_acc_1 == 8.79 + assert best_acc_1 == 8.399999992370606 + assert mean_acc_1 == 8.81 def test_main_seed_epoch(): """Ensure that the model doesn't change when refactoring""" epochs = 1 best_acc_1, mean_acc_1 = main(epochs=epochs) - assert best_acc_1 == 24.266666646321614 - assert mean_acc_1 == 23.01 + assert best_acc_1 == 23.733333368937174 + assert mean_acc_1 == 22.22 From da8b8b7a8f50052745c8edbcd1b657ab28afc3d0 Mon Sep 17 00:00:00 2001 From: Eve-ning Date: Fri, 24 Nov 2023 15:05:54 +0800 Subject: [PATCH 02/38] Remove unused arg --- mixmatch/main.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mixmatch/main.py b/mixmatch/main.py index dd4d8ba..df84b45 100644 --- a/mixmatch/main.py +++ b/mixmatch/main.py @@ -20,7 +20,6 @@ def main( epochs: int = 1024, batch_size: int = 64, lr: float = 0.002, - n_labeled: int = 250, train_iteration: int = 1024, ema_decay: float = 0.999, lambda_u: float = 75, From cec8c3ea489e6df6898fd3e5ac11b41de7499ebd Mon Sep 17 00:00:00 2001 From: Eve-ning Date: Fri, 24 Nov 2023 15:17:41 +0800 Subject: [PATCH 03/38] Lint code --- mixmatch/dataset/cifar10.py | 8 ++------ mixmatch/main.py | 30 ++++++++++++------------------ mixmatch/models/wideresnet.py | 12 +++--------- tests/test_main.py | 4 +--- tests/test_model.py | 4 +--- 5 files changed, 19 insertions(+), 39 deletions(-) diff --git a/mixmatch/dataset/cifar10.py b/mixmatch/dataset/cifar10.py index c383e36..422d28d 100644 --- a/mixmatch/dataset/cifar10.py +++ b/mixmatch/dataset/cifar10.py @@ -179,12 +179,8 @@ def get_dataloaders( num_workers=num_workers, ) - train_lbl_dl = DataLoader( - train_lbl_ds, shuffle=True, drop_last=True, **dl_args - ) - train_unl_dl = DataLoader( - train_unl_ds, shuffle=True, drop_last=True, **dl_args - ) + train_lbl_dl = DataLoader(train_lbl_ds, shuffle=True, drop_last=True, **dl_args) + train_unl_dl = DataLoader(train_unl_ds, shuffle=True, drop_last=True, **dl_args) val_dl = DataLoader(val_ds, shuffle=False, **dl_args) test_dl = DataLoader(src_test_ds, shuffle=False, **dl_args) diff --git a/mixmatch/main.py b/mixmatch/main.py index df84b45..1939569 100644 --- a/mixmatch/main.py +++ b/mixmatch/main.py @@ -4,12 +4,11 @@ import numpy as np import torch import torch.nn as nn -import torch.nn.parallel import torch.optim as optim from torch.utils.data import DataLoader from mixmatch.dataset.cifar10 import get_dataloaders -import mixmatch.models.wideresnet as models +from models.wideresnet import WideResNet from utils.ema import WeightEMA from utils.eval import validate, train from utils.loss import SemiLoss @@ -27,6 +26,8 @@ def main( t: float = 0.5, device: str = "cuda", seed: int = 42, + train_lbl_size=0.005, + train_unl_size=0.980, ): random.seed(seed) np.random.seed(seed) @@ -36,21 +37,19 @@ def main( torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False - best_acc = 0 - # Data print(f"==> Preparing cifar10") ( train_lbl_dl, train_unl_dl, - val_loader, - test_loader, + val_dl, + test_dl, classes, ) = get_dataloaders( 
dataset_dir="./data", - train_lbl_size=0.005, - train_unl_size=0.980, + train_lbl_size=train_lbl_size, + train_unl_size=train_unl_size, batch_size=batch_size, seed=seed, ) @@ -58,17 +57,11 @@ def main( # Model print("==> creating WRN-28-2") - model = models.WideResNet(num_classes=10).to(device) + model = WideResNet(num_classes=len(classes)).to(device) ema_model = deepcopy(model).to(device) for param in ema_model.parameters(): param.detach_() - # cudnn.benchmark = True - print( - " Total params: %.2fM" - % (sum(p.numel() for p in model.parameters()) / 1000000.0) - ) - train_loss_fn = SemiLoss() val_loss_fn = nn.CrossEntropyLoss() train_optim = optim.Adam(model.parameters(), lr=lr) @@ -76,6 +69,7 @@ def main( ema_optim = WeightEMA(model, ema_model, alpha=ema_decay, lr=lr) test_accs = [] + best_acc = 0 # Train and val for epoch in range(epochs): print("\nEpoch: [%d | %d] LR: %f" % (epoch + 1, epochs, lr)) @@ -88,7 +82,7 @@ def main( ema_optim=ema_optim, loss_fn=train_loss_fn, epoch=epoch, - device="cuda", + device=device, train_iters=train_iteration, lambda_u=lambda_u, mix_beta_alpha=alpha, @@ -105,8 +99,8 @@ def val_ema(dl: DataLoader): ) _, train_acc = val_ema(train_lbl_dl) - val_loss, val_acc = val_ema(val_loader) - test_loss, test_acc = val_ema(test_loader) + val_loss, val_acc = val_ema(val_dl) + test_loss, test_acc = val_ema(test_dl) best_acc = max(val_acc, best_acc) test_accs.append(test_acc) diff --git a/mixmatch/models/wideresnet.py b/mixmatch/models/wideresnet.py index 68a8f2e..a8bfb90 100644 --- a/mixmatch/models/wideresnet.py +++ b/mixmatch/models/wideresnet.py @@ -112,9 +112,7 @@ def forward(self, x): class WideResNet(nn.Module): - def __init__( - self, num_classes, depth=28, widen_factor=2, dropRate=0.0, seed=42 - ): + def __init__(self, num_classes, depth=28, widen_factor=2, dropRate=0.0, seed=42): torch.manual_seed(42) super(WideResNet, self).__init__() nChannels = [ @@ -141,13 +139,9 @@ def __init__( activate_before_residual=True, ) # 2nd block - self.block2 = NetworkBlock( - n, nChannels[1], nChannels[2], block, 2, dropRate - ) + self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) # 3rd block - self.block3 = NetworkBlock( - n, nChannels[2], nChannels[3], block, 2, dropRate - ) + self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) # global average pooling and classifier self.bn1 = nn.BatchNorm2d(nChannels[3], momentum=0.001) self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True) diff --git a/tests/test_main.py b/tests/test_main.py index 8e5ee1c..a9d913f 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -5,9 +5,7 @@ def test_main_seed_fast(): """The fast variant to ensure that the model doesn't change.""" epochs = 1 train_iteration = 8 - best_acc_1, mean_acc_1 = main( - epochs=epochs, train_iteration=train_iteration - ) + best_acc_1, mean_acc_1 = main(epochs=epochs, train_iteration=train_iteration) assert best_acc_1 == 8.399999992370606 assert mean_acc_1 == 8.81 diff --git a/tests/test_model.py b/tests/test_model.py index d87a860..e25b9ec 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -24,7 +24,5 @@ def create_model(ema=False): for param_1, param_2 in zip(model_1.parameters(), model_2.parameters()): assert torch.all(torch.eq(param_1, param_2)) - for param_1, param_2 in zip( - ema_model_1.parameters(), ema_model_2.parameters() - ): + for param_1, param_2 in zip(ema_model_1.parameters(), ema_model_2.parameters()): assert torch.all(torch.eq(param_1, param_2)) From 
99a8ff50529ec1b304b12c0db389a79f9af064e4 Mon Sep 17 00:00:00 2001 From: Eve-ning Date: Fri, 24 Nov 2023 15:20:56 +0800 Subject: [PATCH 04/38] Fix unclear variable names in main --- mixmatch/main.py | 20 ++++++++++---------- mixmatch/utils/ema.py | 7 +++---- mixmatch/utils/eval.py | 4 ++-- mixmatch/utils/loss.py | 4 ++-- 4 files changed, 17 insertions(+), 18 deletions(-) diff --git a/mixmatch/main.py b/mixmatch/main.py index 1939569..d328ef4 100644 --- a/mixmatch/main.py +++ b/mixmatch/main.py @@ -20,14 +20,14 @@ def main( batch_size: int = 64, lr: float = 0.002, train_iteration: int = 1024, - ema_decay: float = 0.999, - lambda_u: float = 75, - alpha: float = 0.75, - t: float = 0.5, + ema_wgt_decay: float = 0.999, + unl_loss_scale: float = 75, + mix_beta_alpha: float = 0.75, + sharpen_temp: float = 0.5, device: str = "cuda", seed: int = 42, - train_lbl_size=0.005, - train_unl_size=0.980, + train_lbl_size: int = 0.005, + train_unl_size: int = 0.980, ): random.seed(seed) np.random.seed(seed) @@ -66,7 +66,7 @@ def main( val_loss_fn = nn.CrossEntropyLoss() train_optim = optim.Adam(model.parameters(), lr=lr) - ema_optim = WeightEMA(model, ema_model, alpha=ema_decay, lr=lr) + ema_optim = WeightEMA(model, ema_model, ema_wgt_decay=ema_wgt_decay, lr=lr) test_accs = [] best_acc = 0 @@ -84,10 +84,10 @@ def main( epoch=epoch, device=device, train_iters=train_iteration, - lambda_u=lambda_u, - mix_beta_alpha=alpha, + unl_loss_scale=unl_loss_scale, + mix_beta_alpha=mix_beta_alpha, epochs=epochs, - sharpen_temp=t, + sharpen_temp=sharpen_temp, ) def val_ema(dl: DataLoader): diff --git a/mixmatch/utils/ema.py b/mixmatch/utils/ema.py index 1725a73..5ee596b 100644 --- a/mixmatch/utils/ema.py +++ b/mixmatch/utils/ema.py @@ -8,12 +8,12 @@ def __init__( self, model: nn.Module, ema_model: nn.Module, - alpha: float = 0.999, + ema_wgt_decay: float = 0.999, lr: float = 0.002, ): self.model = model self.ema_model = ema_model - self.alpha = alpha + self.alpha = ema_wgt_decay self.params = list(model.state_dict().values()) self.ema_params = list(ema_model.state_dict().values()) self.wd = 0.02 * lr @@ -22,10 +22,9 @@ def __init__( param.data.copy_(ema_param.data) def step(self): - one_minus_alpha = 1.0 - self.alpha for param, ema_param in zip(self.params, self.ema_params): if ema_param.dtype == torch.float32: ema_param.mul_(self.alpha) - ema_param.add_(param * one_minus_alpha) + ema_param.add_(param * (1.0 - self.alpha)) # customized weight decay param.mul_(1 - self.wd) diff --git a/mixmatch/utils/eval.py b/mixmatch/utils/eval.py index c16e446..dee21d9 100644 --- a/mixmatch/utils/eval.py +++ b/mixmatch/utils/eval.py @@ -54,7 +54,7 @@ def train( epochs: int, device: str, train_iters: int, - lambda_u: float, + unl_loss_scale: float, mix_beta_alpha: float, sharpen_temp: float, ) -> tuple[float, float, float]: @@ -121,7 +121,7 @@ def train( x_unl=y_mix_unl_pred, y_unl=y_mix_unl, epoch=epoch + batch_idx / train_iters, - lambda_u=lambda_u, + loss_unl_scale=unl_loss_scale, epochs=epochs, ) diff --git a/mixmatch/utils/loss.py b/mixmatch/utils/loss.py index 8d6763d..f95844a 100644 --- a/mixmatch/utils/loss.py +++ b/mixmatch/utils/loss.py @@ -20,7 +20,7 @@ def __call__( x_unl: torch.Tensor, y_unl: torch.Tensor, epoch: float, - lambda_u: float, + loss_unl_scale: float, epochs: int, ): probs_u = torch.softmax(x_unl, dim=1) @@ -33,5 +33,5 @@ def __call__( return ( l_x, l_u, - lambda_u * self.linear_rampup(epoch, epochs), + loss_unl_scale * self.linear_rampup(epoch, epochs), ) From d26fad839478cac7b6689a690722288c43befc6b Mon Sep 17 
00:00:00 2001 From: Eve-ning Date: Fri, 24 Nov 2023 15:29:38 +0800 Subject: [PATCH 05/38] Remove unused object inherit --- mixmatch/utils/ema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mixmatch/utils/ema.py b/mixmatch/utils/ema.py index 5ee596b..cb33c05 100644 --- a/mixmatch/utils/ema.py +++ b/mixmatch/utils/ema.py @@ -3,7 +3,7 @@ import torch.nn.parallel -class WeightEMA(object): +class WeightEMA: def __init__( self, model: nn.Module, From 42299474e14c4bef2912a5f1863b07d7a334ee75 Mon Sep 17 00:00:00 2001 From: Eve-ning Date: Fri, 24 Nov 2023 15:29:52 +0800 Subject: [PATCH 06/38] Add deterministic arg --- mixmatch/dataset/cifar10.py | 39 +++++++++++++------------------------ mixmatch/main.py | 35 +++++++++++++++++++++++++-------- 2 files changed, 41 insertions(+), 33 deletions(-) diff --git a/mixmatch/dataset/cifar10.py b/mixmatch/dataset/cifar10.py index 422d28d..4c3ba42 100644 --- a/mixmatch/dataset/cifar10.py +++ b/mixmatch/dataset/cifar10.py @@ -94,7 +94,7 @@ def get_dataloaders( train_unl_size: float = 0.980, batch_size: int = 48, num_workers: int = 0, - seed: int = 42, + seed: int | None = 42, ) -> tuple[DataLoader, DataLoader, DataLoader, DataLoader, list[str]]: """Get the dataloaders for the CIFAR10 dataset. @@ -108,13 +108,18 @@ def get_dataloaders( train_unl_size: The size of the unlabelled training set. batch_size: The batch size. num_workers: The number of workers for the dataloaders. - seed: The seed for the random number generators. + seed: The seed for the random number generators. If None, then it'll be + non-deterministic. Returns: 4 DataLoaders: train_lbl_dl, train_unl_dl, val_unl_dl, test_dl """ - torch.manual_seed(seed) - np.random.seed(seed) + deterministic = seed is not None + + if deterministic: + torch.manual_seed(seed) + np.random.seed(seed) + src_train_ds = CIFAR10( dataset_dir, train=True, @@ -148,31 +153,15 @@ def get_dataloaders( stratify=lbl_targets, ) + ds_args = dict(root=dataset_dir, train=True, download=True, transform=tf_preproc) + train_lbl_ds = CIFAR10SubsetKAug( - dataset_dir, - idxs=train_lbl_ixs, - train=True, - transform=tf_preproc, - download=True, - k_augs=1, - aug=tf_aug, + **ds_args, idxs=train_lbl_ixs, k_augs=1, aug=tf_aug ) train_unl_ds = CIFAR10SubsetKAug( - dataset_dir, - idxs=train_unl_ixs, - train=True, - transform=tf_preproc, - download=True, - k_augs=2, - aug=tf_aug, - ) - val_ds = CIFAR10Subset( - dataset_dir, - idxs=val_ixs, - train=True, - transform=tf_preproc, - download=True, + **ds_args, idxs=train_unl_ixs, k_augs=2, aug=tf_aug ) + val_ds = CIFAR10Subset(**ds_args, idxs=val_ixs) dl_args = dict( batch_size=batch_size, diff --git a/mixmatch/main.py b/mixmatch/main.py index d328ef4..eea90a8 100644 --- a/mixmatch/main.py +++ b/mixmatch/main.py @@ -25,17 +25,36 @@ def main( mix_beta_alpha: float = 0.75, sharpen_temp: float = 0.5, device: str = "cuda", - seed: int = 42, + seed: int | None = 42, train_lbl_size: int = 0.005, train_unl_size: int = 0.980, ): - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False + """The main function to run the MixMatch algorithm + + Args: + epochs: Number of epochs to run. + batch_size: The batch size to use. + lr: The learning rate to use. + train_iteration: The number of iterations to train for. + ema_wgt_decay: The weight decay to use for the EMA model. + unl_loss_scale: The scaling factor for the unlabeled loss. 
+ mix_beta_alpha: The beta alpha to use for the mixup. + sharpen_temp: The temperature to use for sharpening. + device: The device to use. + seed: The seed to use. If None, then it'll be non-deterministic. + train_lbl_size: The size of the labeled training set. + train_unl_size: The size of the unlabeled training set. + """ + deterministic = seed is not None + + if deterministic: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False # Data print(f"==> Preparing cifar10") From 1a3d8ea51847973a73f1d50cd228864869773963 Mon Sep 17 00:00:00 2001 From: Eve-ning Date: Fri, 24 Nov 2023 15:47:23 +0800 Subject: [PATCH 07/38] Update arg naming for partitioning --- mixmatch/dataset/cifar10.py | 106 +++++++++++++++--------------------- mixmatch/main.py | 4 +- tests/test_dataset_load.py | 8 +-- 3 files changed, 49 insertions(+), 69 deletions(-) diff --git a/mixmatch/dataset/cifar10.py b/mixmatch/dataset/cifar10.py index 4c3ba42..0af0c61 100644 --- a/mixmatch/dataset/cifar10.py +++ b/mixmatch/dataset/cifar10.py @@ -1,11 +1,13 @@ from __future__ import annotations + +from dataclasses import dataclass, KW_ONLY from pathlib import Path -from typing import Callable, Sequence, List +from typing import Callable, Sequence import numpy as np import torch from sklearn.model_selection import train_test_split -from torch.utils.data import DataLoader, Subset, Dataset +from torch.utils.data import DataLoader from torchvision import transforms from torchvision.datasets import CIFAR10 from torchvision.transforms.v2 import ( @@ -24,12 +26,7 @@ [ lambda x: torch.nn.functional.pad( x, - ( - 4, - 4, - 4, - 4, - ), + (4,) * 4, mode="reflect", ), RandomCrop(32), @@ -38,50 +35,33 @@ ) +@dataclass class CIFAR10Subset(CIFAR10): - def __init__( - self, - root: str, - idxs: Sequence[int] | None = None, - train: bool = True, - transform: Callable | None = None, - target_transform: Callable | None = None, - download: bool = False, - ): + _: KW_ONLY + root: str + idxs: Sequence[int] | None = None + train: bool = True + transform: Callable | None = None + target_transform: Callable | None = None + download: bool = False + + def __post_init__(self): super().__init__( - root, - train=train, - transform=transform, - target_transform=target_transform, - download=download, + root=self.root, + train=self.train, + transform=self.transform, + target_transform=self.target_transform, ) - if idxs is not None: - self.data = self.data[idxs] - self.targets = np.array(self.targets)[idxs].tolist() + if self.idxs is not None: + self.data = self.data[self.idxs] + self.targets = np.array(self.targets)[self.idxs].tolist() +@dataclass class CIFAR10SubsetKAug(CIFAR10Subset): - def __init__( - self, - root: str, - k_augs: int, - aug: Callable, - idxs: Sequence[int] | None = None, - train: bool = True, - transform: Callable | None = None, - target_transform: Callable | None = None, - download: bool = False, - ): - super().__init__( - root=root, - idxs=idxs, - train=train, - transform=transform, - target_transform=target_transform, - download=download, - ) - self.k_augs = k_augs - self.aug = aug + _: KW_ONLY + k_augs: int = 1 + aug: Callable = lambda x: x def __getitem__(self, item): img, target = super().__getitem__(item) @@ -90,8 +70,8 @@ def __getitem__(self, item): def get_dataloaders( dataset_dir: Path | str, - train_lbl_size: float = 0.005, - train_unl_size: float = 0.980, + n_train_lbl: float = 0.005, + n_train_unl: 
float = 0.980, batch_size: int = 48, num_workers: int = 0, seed: int | None = 42, @@ -104,8 +84,8 @@ def get_dataloaders( Args: dataset_dir: The directory where the dataset is stored. - train_lbl_size: The size of the labelled training set. - train_unl_size: The size of the unlabelled training set. + n_train_lbl: The size of the labelled training set. + n_train_unl: The size of the unlabelled training set. batch_size: The batch size. num_workers: The number of workers for the dataloaders. seed: The seed for the random number generators. If None, then it'll be @@ -133,35 +113,35 @@ def get_dataloaders( transform=tf_preproc, ) - train_size = len(src_train_ds) - train_unl_size = int(train_size * train_unl_size) - train_lbl_size = int(train_size * train_lbl_size) - val_size = int(train_size - train_unl_size - train_lbl_size) + n_train = len(src_train_ds) + n_train_unl = int(n_train * n_train_unl) + n_train_lbl = int(n_train * n_train_lbl) + n_val = int(n_train - n_train_unl - n_train_lbl) targets = np.array(src_train_ds.targets) ixs = np.arange(len(targets)) - train_unl_ixs, lbl_ixs = train_test_split( + ixs_train_unl, ixs_lbl = train_test_split( ixs, - train_size=train_unl_size, + train_size=n_train_unl, stratify=targets, ) - lbl_targets = targets[lbl_ixs] + lbl_targets = targets[ixs_lbl] - val_ixs, train_lbl_ixs = train_test_split( - lbl_ixs, - train_size=val_size, + ixs_val, ixs_train_lbl = train_test_split( + ixs_lbl, + train_size=n_val, stratify=lbl_targets, ) ds_args = dict(root=dataset_dir, train=True, download=True, transform=tf_preproc) train_lbl_ds = CIFAR10SubsetKAug( - **ds_args, idxs=train_lbl_ixs, k_augs=1, aug=tf_aug + **ds_args, idxs=ixs_train_lbl, k_augs=1, aug=tf_aug ) train_unl_ds = CIFAR10SubsetKAug( - **ds_args, idxs=train_unl_ixs, k_augs=2, aug=tf_aug + **ds_args, idxs=ixs_train_unl, k_augs=2, aug=tf_aug ) - val_ds = CIFAR10Subset(**ds_args, idxs=val_ixs) + val_ds = CIFAR10Subset(**ds_args, idxs=ixs_val) dl_args = dict( batch_size=batch_size, diff --git a/mixmatch/main.py b/mixmatch/main.py index eea90a8..3e97224 100644 --- a/mixmatch/main.py +++ b/mixmatch/main.py @@ -67,8 +67,8 @@ def main( classes, ) = get_dataloaders( dataset_dir="./data", - train_lbl_size=train_lbl_size, - train_unl_size=train_unl_size, + n_train_lbl=train_lbl_size, + n_train_unl=train_unl_size, batch_size=batch_size, seed=seed, ) diff --git a/tests/test_dataset_load.py b/tests/test_dataset_load.py index 8853836..3baff70 100644 --- a/tests/test_dataset_load.py +++ b/tests/test_dataset_load.py @@ -17,8 +17,8 @@ def test_load_seeded(): classes, ) = get_dataloaders( dataset_dir="./data", - train_lbl_size=0.005, - train_unl_size=0.980, + n_train_lbl=0.005, + n_train_unl=0.980, batch_size=batch_size, seed=seed, ) @@ -35,8 +35,8 @@ def test_load_seeded(): classes, ) = get_dataloaders( dataset_dir="./data", - train_lbl_size=0.005, - train_unl_size=0.980, + n_train_lbl=0.005, + n_train_unl=0.980, batch_size=batch_size, seed=seed, ) From 1285c130b255f722afa2449946d7b72588ac7735 Mon Sep 17 00:00:00 2001 From: Eve-ning Date: Fri, 24 Nov 2023 16:00:47 +0800 Subject: [PATCH 08/38] Refactor our scaling out of loss_fn --- mixmatch/utils/eval.py | 10 +++------- mixmatch/utils/loss.py | 25 ++++++++----------------- 2 files changed, 11 insertions(+), 24 deletions(-) diff --git a/mixmatch/utils/eval.py b/mixmatch/utils/eval.py index dee21d9..89cce17 100644 --- a/mixmatch/utils/eval.py +++ b/mixmatch/utils/eval.py @@ -114,18 +114,14 @@ def train( y_mix_unl_pred = torch.cat(y_mix_pred[1:], dim=0) y_mix_unl = 
y_mix[batch_size:] - # TODO: Pretty ugly that we throw in epoch, epochs and lambda_u here - loss_lbl, loss_unl, loss_unl_scale = loss_fn( + loss_lbl, loss_unl = loss_fn( x_lbl=y_mix_lbl_pred, y_lbl=y_mix_lbl, x_unl=y_mix_unl_pred, y_unl=y_mix_unl, - epoch=epoch + batch_idx / train_iters, - loss_unl_scale=unl_loss_scale, - epochs=epochs, ) - - loss = loss_lbl + loss_unl_scale * loss_unl + loss_unl_scale = (epoch + batch_idx / train_iters) / epochs * unl_loss_scale + loss = loss_lbl + loss_unl * loss_unl_scale losses.append(loss) losses_x.append(loss_lbl) diff --git a/mixmatch/utils/loss.py b/mixmatch/utils/loss.py index f95844a..ae59d2c 100644 --- a/mixmatch/utils/loss.py +++ b/mixmatch/utils/loss.py @@ -4,25 +4,20 @@ from torch.nn.functional import cross_entropy -class SemiLoss(object): - @staticmethod - def linear_rampup(current: float, rampup_length: int): - if rampup_length == 0: - return 1.0 - else: - current = np.clip(current / rampup_length, 0.0, 1.0) - return float(current) +def linear_rampup(current: float, rampup_length: int): + if rampup_length == 0: + return 1.0 + return np.clip(current / rampup_length, 0, 1) + +class SemiLoss(object): def __call__( self, x_lbl: torch.Tensor, y_lbl: torch.Tensor, x_unl: torch.Tensor, y_unl: torch.Tensor, - epoch: float, - loss_unl_scale: float, - epochs: int, - ): + ) -> tuple[torch.Tensor, torch.Tensor]: probs_u = torch.softmax(x_unl, dim=1) l_x = cross_entropy(x_lbl, y_lbl) @@ -30,8 +25,4 @@ def __call__( # It's likely not a big deal, but it's worth investigating if we have # too much time on our hands l_u = torch.mean((probs_u - y_unl) ** 2) - return ( - l_x, - l_u, - loss_unl_scale * self.linear_rampup(epoch, epochs), - ) + return l_x, l_u From 9ef95e6c20f4876f5fdbc921463c961ba002eeee Mon Sep 17 00:00:00 2001 From: Eve-ning Date: Fri, 24 Nov 2023 16:13:17 +0800 Subject: [PATCH 09/38] Refactor our sharpening from guess_labels --- mixmatch/utils/eval.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/mixmatch/utils/eval.py b/mixmatch/utils/eval.py index 89cce17..2e1bf7c 100644 --- a/mixmatch/utils/eval.py +++ b/mixmatch/utils/eval.py @@ -30,16 +30,23 @@ def mix_up( return x_mix, y_mix +def sharpen(y: torch.Tensor, temp: float) -> torch.Tensor: + """Sharpen the predictions by raising them to the power of 1 / temp""" + y_sharp = y ** (1 / temp) + # Sharpening will change the sum of the predictions. 
+ y_sharp /= y_sharp.sum(dim=1, keepdim=True) + return y_sharp.detach() + + def guess_labels( model: nn.Module, x_unls: list[torch.Tensor], - sharpen_temp: float, ) -> torch.Tensor: """Guess labels from the unlabelled data""" - y_unls = [torch.softmax(model(u), dim=1) for u in x_unls] - p = sum(y_unls) / 2 - pt = p ** (1 / sharpen_temp) - return pt / pt.sum(dim=1, keepdim=True).detach() + y_unls: list[torch.Tensor] = [torch.softmax(model(u), dim=1) for u in x_unls] + # The sum will sum the tensors in the list, it doesn't reduce the tensors + y_unl = sum(y_unls) / len(y_unls) + return y_unl def train( @@ -89,11 +96,8 @@ def train( x_unls = [u.to(device) for u in x_unls] with torch.no_grad(): - y_unl = guess_labels( - model=model, - x_unls=x_unls, - sharpen_temp=sharpen_temp, - ) + y_unl = guess_labels(model=model, x_unls=x_unls) + y_unl = sharpen(y_unl, sharpen_temp) x = torch.cat([x_lbl, *x_unls], dim=0) y = torch.cat([y_lbl, y_unl, y_unl], dim=0) From 2ef59e5c0fc090e133fb80daf51f397c80102361 Mon Sep 17 00:00:00 2001 From: Eve-ning Date: Fri, 24 Nov 2023 16:17:44 +0800 Subject: [PATCH 10/38] Rename train to train_epoch --- mixmatch/main.py | 4 ++-- mixmatch/utils/eval.py | 7 +++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/mixmatch/main.py b/mixmatch/main.py index 3e97224..a028485 100644 --- a/mixmatch/main.py +++ b/mixmatch/main.py @@ -10,7 +10,7 @@ from mixmatch.dataset.cifar10 import get_dataloaders from models.wideresnet import WideResNet from utils.ema import WeightEMA -from utils.eval import validate, train +from utils.eval import validate, train_epoch from utils.loss import SemiLoss @@ -93,7 +93,7 @@ def main( for epoch in range(epochs): print("\nEpoch: [%d | %d] LR: %f" % (epoch + 1, epochs, lr)) - train_loss, train_lbl_loss, train_unl_loss = train( + train_loss, train_lbl_loss, train_unl_loss = train_epoch( train_lbl_dl=train_lbl_dl, train_unl_dl=train_unl_dl, model=model, diff --git a/mixmatch/utils/eval.py b/mixmatch/utils/eval.py index 2e1bf7c..de01d91 100644 --- a/mixmatch/utils/eval.py +++ b/mixmatch/utils/eval.py @@ -35,7 +35,7 @@ def sharpen(y: torch.Tensor, temp: float) -> torch.Tensor: y_sharp = y ** (1 / temp) # Sharpening will change the sum of the predictions. 
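As an aside (not taken from any commit in this series), a self-contained sketch of the temperature sharpening that the guess-label refactor above pulls out into its own function; the input probabilities below are made up for illustration:

import torch

def sharpen(y: torch.Tensor, temp: float) -> torch.Tensor:
    # Raise each probability to 1/temp, then renormalise rows so they sum to 1 again.
    y_sharp = y ** (1 / temp)
    return y_sharp / y_sharp.sum(dim=1, keepdim=True)

y = torch.tensor([[0.40, 0.35, 0.25]])
print(sharpen(y, temp=0.5))  # roughly [[0.464, 0.355, 0.181]]: mass shifts toward the largest entry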
y_sharp /= y_sharp.sum(dim=1, keepdim=True) - return y_sharp.detach() + return y_sharp def guess_labels( @@ -49,7 +49,7 @@ def guess_labels( return y_unl -def train( +def train_epoch( *, train_lbl_dl: DataLoader, train_unl_dl: DataLoader, @@ -87,8 +87,6 @@ def train( unl_iter = iter(train_unl_dl) x_unls, _ = next(unl_iter) - batch_size = x_lbl.size(0) - y_lbl = one_hot(y_lbl.long(), num_classes=10) x_lbl = x_lbl.to(device) @@ -105,6 +103,7 @@ def train( # interleave labeled and unlabeled samples between batches to # get correct batchnorm calculation + batch_size = x_lbl.shape[0] x_mix = list(torch.split(x_mix, batch_size)) x_mix = interleave(x_mix, batch_size) From 2ecf4d7f98e71278580c95daecb21d6d53a7ae4b Mon Sep 17 00:00:00 2001 From: Eve-ning Date: Fri, 24 Nov 2023 16:30:19 +0800 Subject: [PATCH 11/38] Refactor CIFAR10 to use PyTorch Lightning DataModule --- mixmatch/dataset/cifar10.py | 175 +++++++++++++++++------------------- mixmatch/main.py | 22 ++--- 2 files changed, 92 insertions(+), 105 deletions(-) diff --git a/mixmatch/dataset/cifar10.py b/mixmatch/dataset/cifar10.py index 0af0c61..9f71c0c 100644 --- a/mixmatch/dataset/cifar10.py +++ b/mixmatch/dataset/cifar10.py @@ -68,95 +68,86 @@ def __getitem__(self, item): return tuple(self.aug(img) for _ in range(self.k_augs)), target -def get_dataloaders( - dataset_dir: Path | str, - n_train_lbl: float = 0.005, - n_train_unl: float = 0.980, - batch_size: int = 48, - num_workers: int = 0, - seed: int | None = 42, -) -> tuple[DataLoader, DataLoader, DataLoader, DataLoader, list[str]]: - """Get the dataloaders for the CIFAR10 dataset. - - Notes: - The train_lbl_size and train_unl_size must sum to less than 1. - The leftover data is used for the validation set. - - Args: - dataset_dir: The directory where the dataset is stored. - n_train_lbl: The size of the labelled training set. - n_train_unl: The size of the unlabelled training set. - batch_size: The batch size. - num_workers: The number of workers for the dataloaders. - seed: The seed for the random number generators. If None, then it'll be - non-deterministic. 
- - Returns: - 4 DataLoaders: train_lbl_dl, train_unl_dl, val_unl_dl, test_dl - """ - deterministic = seed is not None - - if deterministic: - torch.manual_seed(seed) - np.random.seed(seed) - - src_train_ds = CIFAR10( - dataset_dir, - train=True, - download=True, - transform=tf_preproc, - ) - src_test_ds = CIFAR10( - dataset_dir, - train=False, - download=True, - transform=tf_preproc, - ) - - n_train = len(src_train_ds) - n_train_unl = int(n_train * n_train_unl) - n_train_lbl = int(n_train * n_train_lbl) - n_val = int(n_train - n_train_unl - n_train_lbl) - - targets = np.array(src_train_ds.targets) - ixs = np.arange(len(targets)) - ixs_train_unl, ixs_lbl = train_test_split( - ixs, - train_size=n_train_unl, - stratify=targets, - ) - lbl_targets = targets[ixs_lbl] - - ixs_val, ixs_train_lbl = train_test_split( - ixs_lbl, - train_size=n_val, - stratify=lbl_targets, - ) - - ds_args = dict(root=dataset_dir, train=True, download=True, transform=tf_preproc) - - train_lbl_ds = CIFAR10SubsetKAug( - **ds_args, idxs=ixs_train_lbl, k_augs=1, aug=tf_aug - ) - train_unl_ds = CIFAR10SubsetKAug( - **ds_args, idxs=ixs_train_unl, k_augs=2, aug=tf_aug - ) - val_ds = CIFAR10Subset(**ds_args, idxs=ixs_val) - - dl_args = dict( - batch_size=batch_size, - num_workers=num_workers, - ) - - train_lbl_dl = DataLoader(train_lbl_ds, shuffle=True, drop_last=True, **dl_args) - train_unl_dl = DataLoader(train_unl_ds, shuffle=True, drop_last=True, **dl_args) - val_dl = DataLoader(val_ds, shuffle=False, **dl_args) - test_dl = DataLoader(src_test_ds, shuffle=False, **dl_args) - - return ( - train_lbl_dl, - train_unl_dl, - val_dl, - test_dl, - src_train_ds.classes, - ) +import pytorch_lightning as pl + + +class CIFAR10DataModule(pl.LightningDataModule): + def __init__( + self, + dataset_dir: Path | str, + n_train_lbl: float = 0.005, + n_train_unl: float = 0.980, + batch_size: int = 48, + num_workers: int = 0, + seed: int | None = 42, + ): + super().__init__() + self.dir = dataset_dir + self.n_train_lbl = n_train_lbl + self.n_train_unl = n_train_unl + self.batch_size = batch_size + self.num_workers = num_workers + deterministic = seed is not None + + if deterministic: + torch.manual_seed(seed) + np.random.seed(seed) + + self.ds_args = dict(root=self.dir, train=True, download=True, transform=tf_preproc, ) + self.dl_args = dict( + batch_size=self.batch_size, + num_workers=self.num_workers, + ) + + def setup(self, stage: str | None = None): + src_train_ds = CIFAR10( + self.dir, + train=True, + download=True, + transform=tf_preproc, + ) + self.test_ds = CIFAR10( + self.dir, + train=False, + download=True, + transform=tf_preproc, + ) + + n_train = len(src_train_ds) + n_train_unl = int(n_train * self.n_train_unl) + n_train_lbl = int(n_train * self.n_train_lbl) + n_val = int(n_train - n_train_unl - n_train_lbl) + + targets = np.array(src_train_ds.targets) + ixs = np.arange(len(targets)) + ixs_train_unl, ixs_lbl = train_test_split( + ixs, + train_size=n_train_unl, + stratify=targets, + ) + lbl_targets = targets[ixs_lbl] + + ixs_val, ixs_train_lbl = train_test_split( + ixs_lbl, + train_size=n_val, + stratify=lbl_targets, + ) + self.train_lbl_ds = CIFAR10SubsetKAug( + **self.ds_args, idxs=ixs_train_lbl, k_augs=1, aug=tf_aug + ) + self.train_unl_ds = CIFAR10SubsetKAug( + **self.ds_args, idxs=ixs_train_unl, k_augs=2, aug=tf_aug + ) + self.val_ds = CIFAR10Subset(**self.ds_args, idxs=ixs_val) + + def train_lbl_dataloader(self): + return DataLoader(self.train_lbl_ds, shuffle=True, drop_last=True, **self.dl_args) + + def 
train_unl_dataloader(self): + return DataLoader(self.train_unl_ds, shuffle=True, drop_last=True, **self.dl_args) + + def val_dataloader(self): + return DataLoader(self.val_ds, shuffle=False, **self.dl_args) + + def test_dataloader(self): + return DataLoader(self.test_ds, shuffle=False, **self.dl_args) diff --git a/mixmatch/main.py b/mixmatch/main.py index a028485..27bce43 100644 --- a/mixmatch/main.py +++ b/mixmatch/main.py @@ -7,7 +7,7 @@ import torch.optim as optim from torch.utils.data import DataLoader -from mixmatch.dataset.cifar10 import get_dataloaders +from mixmatch.dataset.cifar10 import CIFAR10DataModule from models.wideresnet import WideResNet from utils.ema import WeightEMA from utils.eval import validate, train_epoch @@ -59,19 +59,15 @@ def main( # Data print(f"==> Preparing cifar10") - ( - train_lbl_dl, - train_unl_dl, - val_dl, - test_dl, - classes, - ) = get_dataloaders( + dm = CIFAR10DataModule( dataset_dir="./data", n_train_lbl=train_lbl_size, n_train_unl=train_unl_size, batch_size=batch_size, seed=seed, ) + dm.setup() + classes = dm.test_ds.classes # Model print("==> creating WRN-28-2") @@ -94,8 +90,8 @@ def main( print("\nEpoch: [%d | %d] LR: %f" % (epoch + 1, epochs, lr)) train_loss, train_lbl_loss, train_unl_loss = train_epoch( - train_lbl_dl=train_lbl_dl, - train_unl_dl=train_unl_dl, + train_lbl_dl=dm.train_lbl_dataloader(), + train_unl_dl=dm.train_unl_dataloader(), model=model, optim=train_optim, ema_optim=ema_optim, @@ -117,9 +113,9 @@ def val_ema(dl: DataLoader): device=device, ) - _, train_acc = val_ema(train_lbl_dl) - val_loss, val_acc = val_ema(val_dl) - test_loss, test_acc = val_ema(test_dl) + _, train_acc = val_ema(dm.train_lbl_dataloader()) + val_loss, val_acc = val_ema(dm.val_dataloader()) + test_loss, test_acc = val_ema(dm.test_dataloader()) best_acc = max(val_acc, best_acc) test_accs.append(test_acc) From 5ebfcdd09885c098c9d11a85c728351663b3a1cd Mon Sep 17 00:00:00 2001 From: Eve-ning Date: Fri, 24 Nov 2023 16:36:39 +0800 Subject: [PATCH 12/38] Migrate DataModule to use dataclass --- mixmatch/dataset/cifar10.py | 54 ++++++++++++++++++------------------- mixmatch/main.py | 2 +- 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/mixmatch/dataset/cifar10.py b/mixmatch/dataset/cifar10.py index 9f71c0c..85a4db6 100644 --- a/mixmatch/dataset/cifar10.py +++ b/mixmatch/dataset/cifar10.py @@ -1,6 +1,6 @@ from __future__ import annotations -from dataclasses import dataclass, KW_ONLY +from dataclasses import dataclass, KW_ONLY, field from pathlib import Path from typing import Callable, Sequence @@ -71,33 +71,29 @@ def __getitem__(self, item): import pytorch_lightning as pl +@dataclass class CIFAR10DataModule(pl.LightningDataModule): - def __init__( - self, - dataset_dir: Path | str, - n_train_lbl: float = 0.005, - n_train_unl: float = 0.980, - batch_size: int = 48, - num_workers: int = 0, - seed: int | None = 42, - ): + dir: Path | str + n_train_lbl: float = 0.005 + n_train_unl: float = 0.980 + batch_size: int = 48 + num_workers: int = 0 + seed: int | None = 42 + train_lbl_ds: CIFAR10Subset = field(init=False) + train_unl_ds: CIFAR10Subset = field(init=False) + val_ds: CIFAR10Subset = field(init=False) + test_ds: CIFAR10 = field(init=False) + + def __post_init__(self): super().__init__() - self.dir = dataset_dir - self.n_train_lbl = n_train_lbl - self.n_train_unl = n_train_unl - self.batch_size = batch_size - self.num_workers = num_workers - deterministic = seed is not None - - if deterministic: - torch.manual_seed(seed) - np.random.seed(seed) - - 
self.ds_args = dict(root=self.dir, train=True, download=True, transform=tf_preproc, ) - self.dl_args = dict( - batch_size=self.batch_size, - num_workers=self.num_workers, + if self.seed is not None: + torch.manual_seed(self.seed) + np.random.seed(self.seed) + + self.ds_args = dict( + root=self.dir, train=True, download=True, transform=tf_preproc ) + self.dl_args = dict(batch_size=self.batch_size, num_workers=self.num_workers) def setup(self, stage: str | None = None): src_train_ds = CIFAR10( @@ -141,10 +137,14 @@ def setup(self, stage: str | None = None): self.val_ds = CIFAR10Subset(**self.ds_args, idxs=ixs_val) def train_lbl_dataloader(self): - return DataLoader(self.train_lbl_ds, shuffle=True, drop_last=True, **self.dl_args) + return DataLoader( + self.train_lbl_ds, shuffle=True, drop_last=True, **self.dl_args + ) def train_unl_dataloader(self): - return DataLoader(self.train_unl_ds, shuffle=True, drop_last=True, **self.dl_args) + return DataLoader( + self.train_unl_ds, shuffle=True, drop_last=True, **self.dl_args + ) def val_dataloader(self): return DataLoader(self.val_ds, shuffle=False, **self.dl_args) diff --git a/mixmatch/main.py b/mixmatch/main.py index 27bce43..a598b97 100644 --- a/mixmatch/main.py +++ b/mixmatch/main.py @@ -60,7 +60,7 @@ def main( print(f"==> Preparing cifar10") dm = CIFAR10DataModule( - dataset_dir="./data", + dir="./data", n_train_lbl=train_lbl_size, n_train_unl=train_unl_size, batch_size=batch_size, From 7f507121966865bd256685b96533ee699d52aa84 Mon Sep 17 00:00:00 2001 From: Eve-ning Date: Fri, 24 Nov 2023 17:40:32 +0800 Subject: [PATCH 13/38] Refactor model file --- mixmatch/main.py | 4 +- mixmatch/models/wideresnet.py | 218 +++++++++++++++++++--------------- tests/test_model.py | 2 +- 3 files changed, 126 insertions(+), 98 deletions(-) diff --git a/mixmatch/main.py b/mixmatch/main.py index a598b97..2928798 100644 --- a/mixmatch/main.py +++ b/mixmatch/main.py @@ -8,7 +8,7 @@ from torch.utils.data import DataLoader from mixmatch.dataset.cifar10 import CIFAR10DataModule -from models.wideresnet import WideResNet +from models.wideresnet import WideResNet, WideResNetModule from utils.ema import WeightEMA from utils.eval import validate, train_epoch from utils.loss import SemiLoss @@ -72,7 +72,7 @@ def main( # Model print("==> creating WRN-28-2") - model = WideResNet(num_classes=len(classes)).to(device) + model = WideResNetModule(n_classes=len(classes)).to(device) ema_model = deepcopy(model).to(device) for param in ema_model.parameters(): param.detach_() diff --git a/mixmatch/models/wideresnet.py b/mixmatch/models/wideresnet.py index a8bfb90..72d82d1 100644 --- a/mixmatch/models/wideresnet.py +++ b/mixmatch/models/wideresnet.py @@ -1,4 +1,5 @@ import math + import torch import torch.nn as nn import torch.nn.functional as F @@ -7,164 +8,191 @@ class BasicBlock(nn.Module): def __init__( self, - in_planes, - out_planes, + in_dim, + out_dim, stride, - dropRate=0.0, + drop_rate=0.0, activate_before_residual=False, ): super(BasicBlock, self).__init__() - self.bn1 = nn.BatchNorm2d(in_planes, momentum=0.001) + self.bn1 = nn.BatchNorm2d(in_dim, momentum=0.001) self.relu1 = nn.LeakyReLU(negative_slope=0.1, inplace=True) self.conv1 = nn.Conv2d( - in_planes, - out_planes, + in_dim, + out_dim, kernel_size=3, stride=stride, padding=1, bias=False, ) - self.bn2 = nn.BatchNorm2d(out_planes, momentum=0.001) + self.bn2 = nn.BatchNorm2d(out_dim, momentum=0.001) self.relu2 = nn.LeakyReLU(negative_slope=0.1, inplace=True) self.conv2 = nn.Conv2d( - out_planes, - out_planes, + out_dim, + 
out_dim, kernel_size=3, stride=1, padding=1, bias=False, ) - self.droprate = dropRate - self.equalInOut = in_planes == out_planes - self.convShortcut = ( - (not self.equalInOut) - and nn.Conv2d( - in_planes, - out_planes, - kernel_size=1, - stride=stride, - padding=0, - bias=False, - ) - or None - ) + self.drop_rate = drop_rate + self.equal_in_out = in_dim == out_dim + if not self.equal_in_out: + self.conv_shortcut = nn.Conv2d(in_dim, out_dim, kernel_size=1, stride=stride, padding=0, bias=False) + else: + self.conv_shortcut = None + self.activate_before_residual = activate_before_residual def forward(self, x): - if not self.equalInOut and self.activate_before_residual == True: + if self.equal_in_out or not self.activate_before_residual: + out = self.relu1(self.bn1(x)) + else: x = self.relu1(self.bn1(x)) + + if self.equal_in_out: + out = self.relu2(self.bn2(self.conv1(out))) else: - out = self.relu1(self.bn1(x)) - out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) - if self.droprate > 0: - out = F.dropout(out, p=self.droprate, training=self.training) + out = self.relu2(self.bn2(self.conv1(x))) + + out = F.dropout(out, p=self.drop_rate, training=self.training) out = self.conv2(out) - return torch.add(x if self.equalInOut else self.convShortcut(x), out) + + if self.equal_in_out: + y = torch.add(x, out) + else: + y = torch.add(self.conv_shortcut(x), out) + + return y class NetworkBlock(nn.Module): def __init__( self, - nb_layers, - in_planes, - out_planes, + n_blocks, + in_dim, + out_dim, block, stride, - dropRate=0.0, + drop_rate=0.0, activate_before_residual=False, ): super(NetworkBlock, self).__init__() - self.layer = self._make_layer( - block, - in_planes, - out_planes, - nb_layers, - stride, - dropRate, - activate_before_residual, - ) - - def _make_layer( - self, - block, - in_planes, - out_planes, - nb_layers, - stride, - dropRate, - activate_before_residual, - ): layers = [] - for i in range(int(nb_layers)): + for i in range(int(n_blocks)): layers.append( block( - i == 0 and in_planes or out_planes, - out_planes, + i == 0 and in_dim or out_dim, + out_dim, i == 0 and stride or 1, - dropRate, + drop_rate, activate_before_residual, ) ) - return nn.Sequential(*layers) + self.layer = nn.Sequential(*layers) def forward(self, x): return self.layer(x) class WideResNet(nn.Module): - def __init__(self, num_classes, depth=28, widen_factor=2, dropRate=0.0, seed=42): - torch.manual_seed(42) - super(WideResNet, self).__init__() - nChannels = [ + def __init__( + self, + n_classes: int, + depth: int = 28, + width: int = 2, + drop_rate: float = 0.0, + seed: int = 42, + ): + assert (depth - 4) % 6 == 0, "depth should be 6n+4" + super().__init__() + torch.manual_seed(seed) + n_channels = [ 16, - 16 * widen_factor, - 32 * widen_factor, - 64 * widen_factor, + 16 * width, + 32 * width, + 64 * width, ] - assert (depth - 4) % 6 == 0 - n = (depth - 4) / 6 - block = BasicBlock + blocks = (depth - 4) / 6 # 1st conv before any network block - self.conv1 = nn.Conv2d( - 3, nChannels[0], kernel_size=3, stride=1, padding=1, bias=False - ) + self.conv1 = nn.Conv2d(3, n_channels[0], kernel_size=3, padding=1, bias=False) # 1st block self.block1 = NetworkBlock( - n, - nChannels[0], - nChannels[1], - block, + blocks, + n_channels[0], + n_channels[1], + BasicBlock, 1, - dropRate, + drop_rate, activate_before_residual=True, ) # 2nd block - self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) + self.block2 = NetworkBlock( + blocks, n_channels[1], n_channels[2], BasicBlock, 2, 
drop_rate + ) # 3rd block - self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) + self.block3 = NetworkBlock( + blocks, n_channels[2], n_channels[3], BasicBlock, 2, drop_rate + ) # global average pooling and classifier - self.bn1 = nn.BatchNorm2d(nChannels[3], momentum=0.001) - self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True) - self.fc = nn.Linear(nChannels[3], num_classes) - self.nChannels = nChannels[3] + self.bn = nn.BatchNorm2d(n_channels[3], momentum=0.001) + self.leaky_relu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + self.fc = nn.Linear(n_channels[3], n_classes) + self.out_dim = n_channels[3] for m in self.modules(): if isinstance(m, nn.Conv2d): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.data.normal_(0, math.sqrt(2.0 / n)) - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() + blocks = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2.0 / blocks)) elif isinstance(m, nn.Linear): nn.init.xavier_normal_(m.weight.data) m.bias.data.zero_() def forward(self, x): - out = self.conv1(x) - out = self.block1(out) - out = self.block2(out) - out = self.block3(out) - out = self.relu(self.bn1(out)) - out = F.avg_pool2d(out, 8) - out = out.view(-1, self.nChannels) - return self.fc(out) + x = self.conv1(x) + x = self.block1(x) + x = self.block2(x) + x = self.block3(x) + x = self.leaky_relu(self.bn(x)) + x = F.avg_pool2d(x, 8) + x = x.view(-1, self.out_dim) + return self.fc(x) + +import pytorch_lightning as pl + +class WideResNetModule(pl.LightningModule): + def __init__( + self, + n_classes: int, + depth: int = 28, + width: int = 2, + drop_rate: float = 0.0, + seed: int = 42, + ): + super().__init__() + self.model = WideResNet( + n_classes=n_classes, + depth=depth, + width=width, + drop_rate=drop_rate, + seed=seed, + ) + + def forward(self, x): + return self.model(x) + + # def training_step(self, batch, batch_idx): + # x, y = batch + # y_hat = self.model(x) + # loss = F.cross_entropy(y_hat, y) + # self.log("train_loss", loss) + # return loss + # + # def validation_step(self, batch, batch_idx): + # x, y = batch + # y_hat = self.model(x) + # loss = F.cross_entropy(y_hat, y) + # self.log("val_loss", loss) + # + # def configure_optimizers(self): + # return torch.optim.Adam(self.model.parameters(), lr=0.002) \ No newline at end of file diff --git a/tests/test_model.py b/tests/test_model.py index e25b9ec..8ef4cdb 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -7,7 +7,7 @@ def test_model_seeded(): """Test that the model is always initialized the same way when seeded.""" def create_model(ema=False): - model_ = WideResNet(num_classes=10) + model_ = WideResNet(n_classes=10) model_ = model_.cuda() if ema: From 8d0f435c340aa26328af63b641d38af7457fffe2 Mon Sep 17 00:00:00 2001 From: Eve-ning Date: Sat, 25 Nov 2023 08:30:09 +0800 Subject: [PATCH 14/38] Set Checkpoint --- mixmatch/dataset/cifar10.py | 67 +++++++++--- mixmatch/main.py | 6 +- mixmatch/models/wideresnet.py | 200 +++++++++++++++++++++++++++++----- mixmatch/utils/ema.py | 12 +- 4 files changed, 229 insertions(+), 56 deletions(-) diff --git a/mixmatch/dataset/cifar10.py b/mixmatch/dataset/cifar10.py index 85a4db6..1c1f988 100644 --- a/mixmatch/dataset/cifar10.py +++ b/mixmatch/dataset/cifar10.py @@ -7,7 +7,7 @@ import numpy as np import torch from sklearn.model_selection import train_test_split -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, RandomSampler from 
torchvision import transforms from torchvision.datasets import CIFAR10 from torchvision.transforms.v2 import ( @@ -77,12 +77,13 @@ class CIFAR10DataModule(pl.LightningDataModule): n_train_lbl: float = 0.005 n_train_unl: float = 0.980 batch_size: int = 48 + train_iters: int = 1024 num_workers: int = 0 seed: int | None = 42 - train_lbl_ds: CIFAR10Subset = field(init=False) - train_unl_ds: CIFAR10Subset = field(init=False) - val_ds: CIFAR10Subset = field(init=False) + train_ds: CIFAR10 = field(init=False) + val_ds: CIFAR10 = field(init=False) test_ds: CIFAR10 = field(init=False) + k_augs: int = 2 def __post_init__(self): super().__init__() @@ -93,7 +94,7 @@ def __post_init__(self): self.ds_args = dict( root=self.dir, train=True, download=True, transform=tf_preproc ) - self.dl_args = dict(batch_size=self.batch_size, num_workers=self.num_workers) + self.dl_args = dict(batch_size=self.batch_size, pin_memory=True) def setup(self, stage: str | None = None): src_train_ds = CIFAR10( @@ -132,22 +133,56 @@ def setup(self, stage: str | None = None): **self.ds_args, idxs=ixs_train_lbl, k_augs=1, aug=tf_aug ) self.train_unl_ds = CIFAR10SubsetKAug( - **self.ds_args, idxs=ixs_train_unl, k_augs=2, aug=tf_aug + **self.ds_args, idxs=ixs_train_unl, k_augs=self.k_augs, aug=tf_aug ) self.val_ds = CIFAR10Subset(**self.ds_args, idxs=ixs_val) - def train_lbl_dataloader(self): + def train_dataloader(self) -> list[DataLoader]: + """The training dataloader returns a list of two dataloaders. + + Notes: + This train dataloader is special in that + 1) The labelled and unlabelled are sampled separately. + 2) Despite labelled being smaller than unlabelled, the dataloader + will sample with replacement to match training iterations. + + The num_samples supplied to the sampler is the exact number of + samples, so we need to multiply by the batch size. + + Returns: + A list of two dataloaders, the first for labelled data, the second + for unlabelled data. + + """ + return [ + DataLoader( + self.train_lbl_ds, + sampler=RandomSampler( + self.train_lbl_ds, + num_samples=self.batch_size * self.train_iters, + replacement=False, + ), + **self.dl_args, + num_workers=self.num_workers // 2, + ), + DataLoader( + self.train_unl_ds, + sampler=RandomSampler( + self.train_unl_ds, + num_samples=self.batch_size * self.train_iters, + replacement=False, + ), + **self.dl_args, + num_workers=self.num_workers // 2, + ), + ] + + def val_dataloader(self) -> DataLoader: return DataLoader( - self.train_lbl_ds, shuffle=True, drop_last=True, **self.dl_args + self.val_ds, shuffle=False, **self.dl_args, num_workers=self.num_workers ) - def train_unl_dataloader(self): + def test_dataloader(self) -> DataLoader: return DataLoader( - self.train_unl_ds, shuffle=True, drop_last=True, **self.dl_args + self.test_ds, shuffle=False, **self.dl_args, num_workers=self.num_workers ) - - def val_dataloader(self): - return DataLoader(self.val_ds, shuffle=False, **self.dl_args) - - def test_dataloader(self): - return DataLoader(self.test_ds, shuffle=False, **self.dl_args) diff --git a/mixmatch/main.py b/mixmatch/main.py index 2928798..b199c40 100644 --- a/mixmatch/main.py +++ b/mixmatch/main.py @@ -20,7 +20,7 @@ def main( batch_size: int = 64, lr: float = 0.002, train_iteration: int = 1024, - ema_wgt_decay: float = 0.999, + ema_lr: float = 0.999, unl_loss_scale: float = 75, mix_beta_alpha: float = 0.75, sharpen_temp: float = 0.5, @@ -36,7 +36,7 @@ def main( batch_size: The batch size to use. lr: The learning rate to use. 
train_iteration: The number of iterations to train for. - ema_wgt_decay: The weight decay to use for the EMA model. + ema_lr: The learning rate to use for the EMA. unl_loss_scale: The scaling factor for the unlabeled loss. mix_beta_alpha: The beta alpha to use for the mixup. sharpen_temp: The temperature to use for sharpening. @@ -81,7 +81,7 @@ def main( val_loss_fn = nn.CrossEntropyLoss() train_optim = optim.Adam(model.parameters(), lr=lr) - ema_optim = WeightEMA(model, ema_model, ema_wgt_decay=ema_wgt_decay, lr=lr) + ema_optim = WeightEMA(model, ema_model, ema_lr=ema_lr) test_accs = [] best_acc = 0 diff --git a/mixmatch/models/wideresnet.py b/mixmatch/models/wideresnet.py index 72d82d1..f709f8a 100644 --- a/mixmatch/models/wideresnet.py +++ b/mixmatch/models/wideresnet.py @@ -1,8 +1,18 @@ import math +from copy import deepcopy +from dataclasses import dataclass +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F +import torch.nn.parallel +import torch.nn.parallel +from torch.nn.functional import one_hot +from torchmetrics.functional import accuracy + +from utils import SemiLoss, WeightEMA +from utils.interleave import interleave class BasicBlock(nn.Module): @@ -38,7 +48,9 @@ def __init__( self.drop_rate = drop_rate self.equal_in_out = in_dim == out_dim if not self.equal_in_out: - self.conv_shortcut = nn.Conv2d(in_dim, out_dim, kernel_size=1, stride=stride, padding=0, bias=False) + self.conv_shortcut = nn.Conv2d( + in_dim, out_dim, kernel_size=1, stride=stride, padding=0, bias=False + ) else: self.conv_shortcut = None @@ -158,41 +170,171 @@ def forward(self, x): x = x.view(-1, self.out_dim) return self.fc(x) + import pytorch_lightning as pl + +# The eq=False is to prevent overriding hash +@dataclass(eq=False) class WideResNetModule(pl.LightningModule): - def __init__( - self, - n_classes: int, - depth: int = 28, - width: int = 2, - drop_rate: float = 0.0, - seed: int = 42, - ): + n_classes: int + depth: int = 28 + width: int = 2 + drop_rate: float = 0.0 + seed: int = 42 + sharpen_temp: float = 0.5 + mix_beta_alpha: float = 0.75 + unl_loss_scale: float = 75 + ema_lr: float = 0.001 + lr: float = 0.002 + weight_decay: float = 0.0005 + + train_loss_fn: SemiLoss = SemiLoss() + + def __post_init__(self): super().__init__() self.model = WideResNet( - n_classes=n_classes, - depth=depth, - width=width, - drop_rate=drop_rate, - seed=seed, + n_classes=self.n_classes, + depth=self.depth, + width=self.width, + drop_rate=self.drop_rate, + seed=self.seed, ) + self.ema_model = deepcopy(self.model) + for param in self.ema_model.parameters(): + param.detach_() + + self.ema_updater = WeightEMA(self.model, self.ema_model, ema_lr=self.ema_lr) def forward(self, x): return self.model(x) - # def training_step(self, batch, batch_idx): - # x, y = batch - # y_hat = self.model(x) - # loss = F.cross_entropy(y_hat, y) - # self.log("train_loss", loss) - # return loss - # - # def validation_step(self, batch, batch_idx): - # x, y = batch - # y_hat = self.model(x) - # loss = F.cross_entropy(y_hat, y) - # self.log("val_loss", loss) - # - # def configure_optimizers(self): - # return torch.optim.Adam(self.model.parameters(), lr=0.002) \ No newline at end of file + @staticmethod + def mix_up( + x: torch.Tensor, + y: torch.Tensor, + alpha: float, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Mix up the data + + Args: + x: The data to mix up. + y: The labels to mix up. + alpha: The alpha to use for the beta distribution. + + Returns: + The mixed up data and labels. 
+ """ + ratio = np.random.beta(alpha, alpha) + ratio = max(ratio, 1 - ratio) + + shuf_idx = torch.randperm(x.size(0)) + + x_mix = ratio * x + (1 - ratio) * x[shuf_idx] + y_mix = ratio * y + (1 - ratio) * y[shuf_idx] + return x_mix, y_mix + + @staticmethod + def sharpen(y: torch.Tensor, temp: float) -> torch.Tensor: + """Sharpen the predictions by raising them to the power of 1 / temp + + Args: + y: The predictions to sharpen. + temp: The temperature to use. + + Returns: + The probability-normalized sharpened predictions + """ + y_sharp = y ** (1 / temp) + # Sharpening will change the sum of the predictions. + y_sharp /= y_sharp.sum(dim=1, keepdim=True) + return y_sharp + + def guess_labels( + self, + x_unls: list[torch.Tensor], + ) -> torch.Tensor: + """Guess labels from the unlabelled data""" + y_unls: list[torch.Tensor] = [torch.softmax(self(u), dim=1) for u in x_unls] + # The sum will sum the tensors in the list, it doesn't reduce the tensors + y_unl = sum(y_unls) / len(y_unls) + return y_unl + + def training_step(self, batch, batch_idx): + (x_lbl, y_lbl), (x_unls, _) = batch + x_lbl = x_lbl[0] + y_lbl = one_hot(y_lbl.long(), num_classes=self.n_classes) + + with torch.no_grad(): + y_unl = self.guess_labels(x_unls=x_unls) + y_unl = self.sharpen(y_unl, self.sharpen_temp) + + x = torch.cat([x_lbl, *x_unls], dim=0) + y = torch.cat([y_lbl, y_unl, y_unl], dim=0) + x_mix, y_mix = self.mix_up(x, y, self.mix_beta_alpha) + + is_interleave = True + if is_interleave: + # interleave labeled and unlabeled samples between batches to + # get correct batchnorm calculation + batch_size = x_lbl.shape[0] + x_mix = list(torch.split(x_mix, batch_size)) + x_mix = interleave(x_mix, batch_size) + + y_mix_pred = [self(x) for x in x_mix] + + # put interleaved samples back + y_mix_pred = interleave(y_mix_pred, batch_size) + + y_mix_lbl_pred = y_mix_pred[0] + y_mix_lbl = y_mix[:batch_size] + y_mix_unl_pred = torch.cat(y_mix_pred[1:], dim=0) + y_mix_unl = y_mix[batch_size:] + else: + batch_size = x_lbl.shape[0] + y_mix_pred = self(x_mix) + y_mix_lbl_pred = y_mix_pred[:batch_size] + y_mix_unl_pred = y_mix_pred[batch_size:] + y_mix_lbl = y_mix[:batch_size] + y_mix_unl = y_mix[batch_size:] + + loss_lbl, loss_unl = self.train_loss_fn( + x_lbl=y_mix_lbl_pred, + y_lbl=y_mix_lbl, + x_unl=y_mix_unl_pred, + y_unl=y_mix_unl, + ) + loss_unl_scale = ( + (self.current_epoch + batch_idx / self.trainer.num_training_batches) + / self.trainer.max_epochs + * self.unl_loss_scale + ) + loss = loss_lbl + loss_unl * loss_unl_scale + + self.log("train_loss", loss, prog_bar=True) + self.log("train_loss_lbl", loss_lbl) + self.log("train_loss_unl", loss_unl) + + return loss + + def validation_step(self, batch, batch_idx): + x, y = batch + y_pred = self(x) + loss = F.cross_entropy(y_pred, y.long()) + + acc = accuracy( + y_pred, + y, + task="multiclass", + num_classes=y_pred.shape[1], + ) + self.log("val_loss", loss) + self.log("val_acc", acc, prog_bar=True) + return loss + + def on_after_backward(self) -> None: + self.ema_updater.step() + + def configure_optimizers(self): + return torch.optim.Adam(self.model.parameters(), lr=self.lr, + weight_decay=self.weight_decay) diff --git a/mixmatch/utils/ema.py b/mixmatch/utils/ema.py index cb33c05..7d64f40 100644 --- a/mixmatch/utils/ema.py +++ b/mixmatch/utils/ema.py @@ -8,15 +8,13 @@ def __init__( self, model: nn.Module, ema_model: nn.Module, - ema_wgt_decay: float = 0.999, - lr: float = 0.002, + ema_lr: float, ): self.model = model self.ema_model = ema_model - self.alpha = ema_wgt_decay + 
self.ema_lr = ema_lr self.params = list(model.state_dict().values()) self.ema_params = list(ema_model.state_dict().values()) - self.wd = 0.02 * lr for param, ema_param in zip(self.params, self.ema_params): param.data.copy_(ema_param.data) @@ -24,7 +22,5 @@ def __init__( def step(self): for param, ema_param in zip(self.params, self.ema_params): if ema_param.dtype == torch.float32: - ema_param.mul_(self.alpha) - ema_param.add_(param * (1.0 - self.alpha)) - # customized weight decay - param.mul_(1 - self.wd) + ema_param.mul_(1 - self.ema_lr) + ema_param.add_(param * self.ema_lr) From f07779d825eb7c15166cd353c837b276c4996d32 Mon Sep 17 00:00:00 2001 From: Eve-ning Date: Sat, 25 Nov 2023 08:58:11 +0800 Subject: [PATCH 15/38] Fix incorrect model used to guess and eval --- mixmatch/models/wideresnet.py | 10 +++++++--- mixmatch/utils/ema.py | 2 -- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/mixmatch/models/wideresnet.py b/mixmatch/models/wideresnet.py index f709f8a..8412fae 100644 --- a/mixmatch/models/wideresnet.py +++ b/mixmatch/models/wideresnet.py @@ -204,7 +204,11 @@ def __post_init__(self): for param in self.ema_model.parameters(): param.detach_() - self.ema_updater = WeightEMA(self.model, self.ema_model, ema_lr=self.ema_lr) + self.ema_updater = WeightEMA( + model=self.model, + ema_model=self.ema_model, + ema_lr=self.ema_lr + ) def forward(self, x): return self.model(x) @@ -255,7 +259,7 @@ def guess_labels( x_unls: list[torch.Tensor], ) -> torch.Tensor: """Guess labels from the unlabelled data""" - y_unls: list[torch.Tensor] = [torch.softmax(self(u), dim=1) for u in x_unls] + y_unls: list[torch.Tensor] = [torch.softmax(self.ema_model(u), dim=1) for u in x_unls] # The sum will sum the tensors in the list, it doesn't reduce the tensors y_unl = sum(y_unls) / len(y_unls) return y_unl @@ -319,7 +323,7 @@ def training_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx): x, y = batch - y_pred = self(x) + y_pred = self.ema_model(x) loss = F.cross_entropy(y_pred, y.long()) acc = accuracy( diff --git a/mixmatch/utils/ema.py b/mixmatch/utils/ema.py index 7d64f40..c608da9 100644 --- a/mixmatch/utils/ema.py +++ b/mixmatch/utils/ema.py @@ -10,8 +10,6 @@ def __init__( ema_model: nn.Module, ema_lr: float, ): - self.model = model - self.ema_model = ema_model self.ema_lr = ema_lr self.params = list(model.state_dict().values()) self.ema_params = list(ema_model.state_dict().values()) From eca3fa2bc9d8bc46b50f1f4e504b2243b368f821 Mon Sep 17 00:00:00 2001 From: Eve-ning Date: Sat, 25 Nov 2023 10:17:56 +0800 Subject: [PATCH 16/38] Fix EMA params not updated --- mixmatch/utils/ema.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/mixmatch/utils/ema.py b/mixmatch/utils/ema.py index c608da9..89545db 100644 --- a/mixmatch/utils/ema.py +++ b/mixmatch/utils/ema.py @@ -11,14 +11,12 @@ def __init__( ema_lr: float, ): self.ema_lr = ema_lr - self.params = list(model.state_dict().values()) - self.ema_params = list(ema_model.state_dict().values()) - - for param, ema_param in zip(self.params, self.ema_params): - param.data.copy_(ema_param.data) + self.model = model + self.ema_model = ema_model def step(self): - for param, ema_param in zip(self.params, self.ema_params): + for param, ema_param in zip(self.model.parameters(), + self.ema_model.parameters()): if ema_param.dtype == torch.float32: ema_param.mul_(1 - self.ema_lr) ema_param.add_(param * self.ema_lr) From 1c615a0b63ae49c8bebad776dffa9d1e0c07d411 Mon Sep 17 00:00:00 2001 From: Eve-ning Date: Sat, 25 
Nov 2023 11:20:12 +0800 Subject: [PATCH 17/38] Fix issue with longer wait times for runner spinup --- mixmatch/dataset/cifar10.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mixmatch/dataset/cifar10.py b/mixmatch/dataset/cifar10.py index 1c1f988..6d777e6 100644 --- a/mixmatch/dataset/cifar10.py +++ b/mixmatch/dataset/cifar10.py @@ -94,7 +94,9 @@ def __post_init__(self): self.ds_args = dict( root=self.dir, train=True, download=True, transform=tf_preproc ) - self.dl_args = dict(batch_size=self.batch_size, pin_memory=True) + self.dl_args = dict(batch_size=self.batch_size, + persistent_workers=True, + pin_memory=True) def setup(self, stage: str | None = None): src_train_ds = CIFAR10( From 55872a7249f6f879e5e83dbaf6aaf1f017912158 Mon Sep 17 00:00:00 2001 From: Eve-ning Date: Sat, 25 Nov 2023 19:20:07 +0800 Subject: [PATCH 18/38] Fix memory leak on updating EMA --- mixmatch/models/wideresnet.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mixmatch/models/wideresnet.py b/mixmatch/models/wideresnet.py index 8412fae..0b03358 100644 --- a/mixmatch/models/wideresnet.py +++ b/mixmatch/models/wideresnet.py @@ -336,6 +336,7 @@ def validation_step(self, batch, batch_idx): self.log("val_acc", acc, prog_bar=True) return loss + @torch.no_grad() def on_after_backward(self) -> None: self.ema_updater.step() From 171d5bd73387a41a26f247e44df1582b78a7f1d7 Mon Sep 17 00:00:00 2001 From: Eve-ning Date: Sat, 25 Nov 2023 19:24:12 +0800 Subject: [PATCH 19/38] Parametrize interleaving --- mixmatch/models/wideresnet.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/mixmatch/models/wideresnet.py b/mixmatch/models/wideresnet.py index 0b03358..c6602b1 100644 --- a/mixmatch/models/wideresnet.py +++ b/mixmatch/models/wideresnet.py @@ -12,7 +12,7 @@ from torchmetrics.functional import accuracy from utils import SemiLoss, WeightEMA -from utils.interleave import interleave +import utils.interleave class BasicBlock(nn.Module): @@ -189,6 +189,9 @@ class WideResNetModule(pl.LightningModule): lr: float = 0.002 weight_decay: float = 0.0005 + # See our wiki for details on interleave + interleave: bool = False + train_loss_fn: SemiLoss = SemiLoss() def __post_init__(self): @@ -277,18 +280,18 @@ def training_step(self, batch, batch_idx): y = torch.cat([y_lbl, y_unl, y_unl], dim=0) x_mix, y_mix = self.mix_up(x, y, self.mix_beta_alpha) - is_interleave = True - if is_interleave: - # interleave labeled and unlabeled samples between batches to - # get correct batchnorm calculation + if self.interleave: + # This performs interleaving, see our wiki for details. batch_size = x_lbl.shape[0] x_mix = list(torch.split(x_mix, batch_size)) - x_mix = interleave(x_mix, batch_size) + + # Interleave to get a consistent Batch Norm Calculation + x_mix = utils.interleave(x_mix, batch_size) y_mix_pred = [self(x) for x in x_mix] - # put interleaved samples back - y_mix_pred = interleave(y_mix_pred, batch_size) + # Un-interleave to shuffle back to original order + y_mix_pred = utils.interleave(y_mix_pred, batch_size) y_mix_lbl_pred = y_mix_pred[0] y_mix_lbl = y_mix[:batch_size] @@ -336,6 +339,8 @@ def validation_step(self, batch, batch_idx): self.log("val_acc", acc, prog_bar=True) return loss + # PyTorch Lightning doesn't automatically no_grads the EMA step. + # It's important to keep this to avoid a memory leak. 
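+    # (The exact failure mode isn't spelled out in the commit, but the in-place
+    # EMA update mixes grad-tracked model parameters into the EMA weights, so
+    # without no_grad autograd would likely keep each step's graph alive.)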
@torch.no_grad() def on_after_backward(self) -> None: self.ema_updater.step() From 0220e3273b157877de68741eee244ca2d6bea4e1 Mon Sep 17 00:00:00 2001 From: Eve-ning Date: Sat, 25 Nov 2023 19:52:58 +0800 Subject: [PATCH 20/38] Improve Documentation for cifar10.py --- mixmatch/dataset/cifar10.py | 81 ++++++++++++++++++++++++++++++------- 1 file changed, 67 insertions(+), 14 deletions(-) diff --git a/mixmatch/dataset/cifar10.py b/mixmatch/dataset/cifar10.py index 6d777e6..f8d6fe6 100644 --- a/mixmatch/dataset/cifar10.py +++ b/mixmatch/dataset/cifar10.py @@ -71,19 +71,69 @@ def __getitem__(self, item): import pytorch_lightning as pl +# TODO: We should make this dataset agnostic, so we can use it for other +# datasets. @dataclass -class CIFAR10DataModule(pl.LightningDataModule): +class SSLCIFAR10DataModule(pl.LightningDataModule): + """The CIFAR10 datamodule for semi-supervised learning. + + Notes: + This datamodule is configured for SSL on CIFAR10. + + The major difference is that despite the labelled data being smaller + than the unlabelled data, the dataloader will sample with replacement + to match training iterations. Hence, each epoch will have the same + number of training iterations for labelled and unlabelled data. + + The batch size, thus, doesn't affect the number of training iterations, + each iteration will have the specified batch size. + + For example: + train_lbl_size = 0.005 (250) + train_unl_size = 0.980 (49000) + batch_size = 48 + train_iters = 1024 + + In pseudocode + + for epoch in range(epochs): + for train_iter in range(1024): + lbl = sample(lbl_pool, 48) + unl = sample(unl_pool, 48) + + Each epoch will have 1024 training iterations. + Each training iteration will pull 48 labelled and 48 unlabelled + samples from the above pools, with replacement. Therefore, unlike + traditional dataloaders, we can see repeated samples in the same + epoch. (replacement=False in our RandomSampler only prevents + replacements within a minibatch) + + Args: + dir: The directory to store the data. + train_lbl_size: The size of the labelled training set. + train_unl_size: The size of the unlabelled training set. + batch_size: The batch size to use. + train_iters: The number of training iterations per epoch. + seed: The seed to use for reproducibility. If None, no seed is used. + k_augs: The number of augmentations to use for unlabelled data. + num_workers: The number of workers to use for the dataloaders. + persistent_workers: Whether to use persistent workers for the dataloaders. + pin_memory: Whether to pin memory for the dataloaders. 
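+
+    Example:
+        A minimal usage sketch; the directory, batch size, and worker count
+        below are illustrative, not prescribed defaults.
+
+            dm = SSLCIFAR10DataModule(
+                dir="./data", batch_size=48, train_iters=1024, num_workers=4
+            )
+            dm.setup()
+            train_lbl_dl, train_unl_dl = dm.train_dataloader()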
+ """ dir: Path | str - n_train_lbl: float = 0.005 - n_train_unl: float = 0.980 + train_lbl_size: float = 0.005 + train_unl_size: float = 0.980 batch_size: int = 48 train_iters: int = 1024 - num_workers: int = 0 seed: int | None = 42 - train_ds: CIFAR10 = field(init=False) + k_augs: int = 2 + num_workers: int = 0 + persistent_workers: bool = True + pin_memory: bool = True + train_lbl_ds: CIFAR10 = field(init=False) + train_unl_ds: CIFAR10 = field(init=False) val_ds: CIFAR10 = field(init=False) test_ds: CIFAR10 = field(init=False) - k_augs: int = 2 def __post_init__(self): super().__init__() @@ -94,9 +144,11 @@ def __post_init__(self): self.ds_args = dict( root=self.dir, train=True, download=True, transform=tf_preproc ) - self.dl_args = dict(batch_size=self.batch_size, - persistent_workers=True, - pin_memory=True) + self.dl_args = dict( + batch_size=self.batch_size, + persistent_workers=self.persistent_workers, + pin_memory=self.pin_memory, + ) def setup(self, stage: str | None = None): src_train_ds = CIFAR10( @@ -113,8 +165,8 @@ def setup(self, stage: str | None = None): ) n_train = len(src_train_ds) - n_train_unl = int(n_train * self.n_train_unl) - n_train_lbl = int(n_train * self.n_train_lbl) + n_train_unl = int(n_train * self.train_unl_size) + n_train_lbl = int(n_train * self.train_lbl_size) n_val = int(n_train - n_train_unl - n_train_lbl) targets = np.array(src_train_ds.targets) @@ -154,8 +206,9 @@ def train_dataloader(self) -> list[DataLoader]: Returns: A list of two dataloaders, the first for labelled data, the second for unlabelled data. - """ + lbl_workers = self.num_workers // (self.k_augs + 1) + unl_workers = self.num_workers - lbl_workers return [ DataLoader( self.train_lbl_ds, @@ -164,8 +217,8 @@ def train_dataloader(self) -> list[DataLoader]: num_samples=self.batch_size * self.train_iters, replacement=False, ), + num_workers=lbl_workers, **self.dl_args, - num_workers=self.num_workers // 2, ), DataLoader( self.train_unl_ds, @@ -174,8 +227,8 @@ def train_dataloader(self) -> list[DataLoader]: num_samples=self.batch_size * self.train_iters, replacement=False, ), + num_workers=unl_workers, **self.dl_args, - num_workers=self.num_workers // 2, ), ] From 0063ec706af6faf733df6d6fe6f1c3091786378a Mon Sep 17 00:00:00 2001 From: Eve-ning Date: Sat, 25 Nov 2023 19:53:36 +0800 Subject: [PATCH 21/38] Format --- mixmatch/dataset/cifar10.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mixmatch/dataset/cifar10.py b/mixmatch/dataset/cifar10.py index f8d6fe6..aa1be10 100644 --- a/mixmatch/dataset/cifar10.py +++ b/mixmatch/dataset/cifar10.py @@ -120,6 +120,7 @@ class SSLCIFAR10DataModule(pl.LightningDataModule): persistent_workers: Whether to use persistent workers for the dataloaders. pin_memory: Whether to pin memory for the dataloaders. 
""" + dir: Path | str train_lbl_size: float = 0.005 train_unl_size: float = 0.980 From 825403080e345a4283138bd3b71874206a94b8ca Mon Sep 17 00:00:00 2001 From: Eve-ning Date: Sat, 25 Nov 2023 19:53:43 +0800 Subject: [PATCH 22/38] Format with Black --- mixmatch/utils/ema.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mixmatch/utils/ema.py b/mixmatch/utils/ema.py index 89545db..ad1a50a 100644 --- a/mixmatch/utils/ema.py +++ b/mixmatch/utils/ema.py @@ -15,8 +15,9 @@ def __init__( self.ema_model = ema_model def step(self): - for param, ema_param in zip(self.model.parameters(), - self.ema_model.parameters()): + for param, ema_param in zip( + self.model.parameters(), self.ema_model.parameters() + ): if ema_param.dtype == torch.float32: ema_param.mul_(1 - self.ema_lr) ema_param.add_(param * self.ema_lr) From 9d126fc60c60a3edea04c529b454672c737c6495 Mon Sep 17 00:00:00 2001 From: Eve-ning Date: Sat, 25 Nov 2023 20:02:23 +0800 Subject: [PATCH 23/38] Migrate Module to separate file --- mixmatch/models/mixmatch_module.py | 208 +++++++++++++++++++++++++++++ mixmatch/models/wideresnet.py | 187 -------------------------- 2 files changed, 208 insertions(+), 187 deletions(-) create mode 100644 mixmatch/models/mixmatch_module.py diff --git a/mixmatch/models/mixmatch_module.py b/mixmatch/models/mixmatch_module.py new file mode 100644 index 0000000..251c20a --- /dev/null +++ b/mixmatch/models/mixmatch_module.py @@ -0,0 +1,208 @@ +from copy import deepcopy +from dataclasses import dataclass +from typing import Callable + +import numpy as np +import pytorch_lightning as pl +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.parallel +import torch.nn.parallel +from torch.nn.functional import one_hot +from torchmetrics.functional import accuracy + +import utils.interleave +from utils import WeightEMA + + +# The eq=False is to prevent overriding hash +@dataclass(eq=False) +class MixMatchModule(pl.LightningModule): + """PyTorch Lightning Module for MixMatch + + Notes: + This performs MixMatch as described in the paper. + https://arxiv.org/abs/1905.02249 + + This module is designed to be used with any model, not only + the WideResNet model. + + Furthermore, while it's possible to switch datasets, take a look + at how we implement the CIFAR10DataModule's DataLoaders to see + how to implement a new dataset. + + Args: + model: The model to train. + sharpen_temp: The temperature to use for sharpening. + mix_beta_alpha: The alpha to use for the beta distribution when mixing. + unl_loss_scale: The scale to use for the unsupervised loss. + ema_lr: The learning rate to use for the EMA. + lr: The learning rate to use for the optimizer. + weight_decay: The weight decay to use for the optimizer. 
+ """ + + model: nn.Module + sharpen_temp: float = 0.5 + mix_beta_alpha: float = 0.75 + unl_loss_scale: float = 75 + ema_lr: float = 0.001 + lr: float = 0.002 + weight_decay: float = 0.0005 + + # See our wiki for details on interleave + interleave: bool = False + + train_lbl_loss: Callable[ + [torch.Tensor, torch.Tensor], torch.Tensor + ] = F.cross_entropy + # TODO: Not sure why this is different from MSELoss + # It's likely not a big deal, but it's worth investigating if we have + # too much time on our hands + train_unl_loss: Callable[ + [torch.Tensor, torch.Tensor], torch.Tensor + ] = lambda pred, tgt: torch.mean((torch.softmax(pred, dim=1) - tgt) ** 2) + + def __post_init__(self): + super().__init__() + self.ema_model = deepcopy(self.model) + for param in self.ema_model.parameters(): + param.detach_() + + self.ema_updater = WeightEMA( + model=self.model, ema_model=self.ema_model, ema_lr=self.ema_lr + ) + + def forward(self, x): + return self.model(x) + + @staticmethod + def mix_up( + x: torch.Tensor, + y: torch.Tensor, + alpha: float, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Mix up the data + + Args: + x: The data to mix up. + y: The labels to mix up. + alpha: The alpha to use for the beta distribution. + + Returns: + The mixed up data and labels. + """ + ratio = np.random.beta(alpha, alpha) + ratio = max(ratio, 1 - ratio) + + shuf_idx = torch.randperm(x.size(0)) + + x_mix = ratio * x + (1 - ratio) * x[shuf_idx] + y_mix = ratio * y + (1 - ratio) * y[shuf_idx] + return x_mix, y_mix + + @staticmethod + def sharpen(y: torch.Tensor, temp: float) -> torch.Tensor: + """Sharpen the predictions by raising them to the power of 1 / temp + + Args: + y: The predictions to sharpen. + temp: The temperature to use. + + Returns: + The probability-normalized sharpened predictions + """ + y_sharp = y ** (1 / temp) + # Sharpening will change the sum of the predictions. + y_sharp /= y_sharp.sum(dim=1, keepdim=True) + return y_sharp + + def guess_labels( + self, + x_unls: list[torch.Tensor], + ) -> torch.Tensor: + """Guess labels from the unlabelled data""" + y_unls: list[torch.Tensor] = [ + torch.softmax(self.ema_model(u), dim=1) for u in x_unls + ] + # The sum will sum the tensors in the list, it doesn't reduce the tensors + y_unl = sum(y_unls) / len(y_unls) + return y_unl + + def training_step(self, batch, batch_idx): + (x_lbl, y_lbl), (x_unls, _) = batch + x_lbl = x_lbl[0] + y_lbl = one_hot(y_lbl.long(), num_classes=self.n_classes) + + with torch.no_grad(): + y_unl = self.guess_labels(x_unls=x_unls) + y_unl = self.sharpen(y_unl, self.sharpen_temp) + + x = torch.cat([x_lbl, *x_unls], dim=0) + y = torch.cat([y_lbl, y_unl, y_unl], dim=0) + x_mix, y_mix = self.mix_up(x, y, self.mix_beta_alpha) + + if self.interleave: + # This performs interleaving, see our wiki for details. 
+ batch_size = x_lbl.shape[0] + x_mix = list(torch.split(x_mix, batch_size)) + + # Interleave to get a consistent Batch Norm Calculation + x_mix = utils.interleave(x_mix, batch_size) + + y_mix_pred = [self(x) for x in x_mix] + + # Un-interleave to shuffle back to original order + y_mix_pred = utils.interleave(y_mix_pred, batch_size) + + y_mix_lbl_pred = y_mix_pred[0] + y_mix_lbl = y_mix[:batch_size] + y_mix_unl_pred = torch.cat(y_mix_pred[1:], dim=0) + y_mix_unl = y_mix[batch_size:] + else: + batch_size = x_lbl.shape[0] + y_mix_pred = self(x_mix) + y_mix_lbl_pred = y_mix_pred[:batch_size] + y_mix_unl_pred = y_mix_pred[batch_size:] + y_mix_lbl = y_mix[:batch_size] + y_mix_unl = y_mix[batch_size:] + + loss_lbl = self.train_lbl_loss(y_mix_lbl_pred, y_mix_lbl) + loss_unl = self.train_unl_loss(y_mix_unl_pred, y_mix_unl) + + # The scale is a linear ramp up from 0 to self.unl_loss_scale + # over the course of training. + loss_unl_scale = ( + (self.current_epoch + batch_idx / self.trainer.num_training_batches) + / self.trainer.max_epochs + * self.unl_loss_scale + ) + + loss = loss_lbl + loss_unl * loss_unl_scale + + self.log("train_loss", loss, prog_bar=True) + self.log("train_loss_lbl", loss_lbl) + self.log("train_loss_unl", loss_unl) + + return loss + + def validation_step(self, batch, batch_idx): + x, y = batch + y_pred = self.ema_model(x) + loss = F.cross_entropy(y_pred, y.long()) + + acc = accuracy(y_pred, y, task="multiclass", num_classes=y_pred.shape[1]) + self.log("val_loss", loss) + self.log("val_acc", acc, prog_bar=True) + return loss + + # PyTorch Lightning doesn't automatically no_grads the EMA step. + # It's important to keep this to avoid a memory leak. + @torch.no_grad() + def on_after_backward(self) -> None: + self.ema_updater.step() + + def configure_optimizers(self): + return torch.optim.Adam( + self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay + ) diff --git a/mixmatch/models/wideresnet.py b/mixmatch/models/wideresnet.py index c6602b1..df9b748 100644 --- a/mixmatch/models/wideresnet.py +++ b/mixmatch/models/wideresnet.py @@ -1,18 +1,10 @@ import math -from copy import deepcopy -from dataclasses import dataclass -import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.nn.parallel import torch.nn.parallel -from torch.nn.functional import one_hot -from torchmetrics.functional import accuracy - -from utils import SemiLoss, WeightEMA -import utils.interleave class BasicBlock(nn.Module): @@ -169,182 +161,3 @@ def forward(self, x): x = F.avg_pool2d(x, 8) x = x.view(-1, self.out_dim) return self.fc(x) - - -import pytorch_lightning as pl - - -# The eq=False is to prevent overriding hash -@dataclass(eq=False) -class WideResNetModule(pl.LightningModule): - n_classes: int - depth: int = 28 - width: int = 2 - drop_rate: float = 0.0 - seed: int = 42 - sharpen_temp: float = 0.5 - mix_beta_alpha: float = 0.75 - unl_loss_scale: float = 75 - ema_lr: float = 0.001 - lr: float = 0.002 - weight_decay: float = 0.0005 - - # See our wiki for details on interleave - interleave: bool = False - - train_loss_fn: SemiLoss = SemiLoss() - - def __post_init__(self): - super().__init__() - self.model = WideResNet( - n_classes=self.n_classes, - depth=self.depth, - width=self.width, - drop_rate=self.drop_rate, - seed=self.seed, - ) - self.ema_model = deepcopy(self.model) - for param in self.ema_model.parameters(): - param.detach_() - - self.ema_updater = WeightEMA( - model=self.model, - ema_model=self.ema_model, - ema_lr=self.ema_lr - ) - - def 
forward(self, x): - return self.model(x) - - @staticmethod - def mix_up( - x: torch.Tensor, - y: torch.Tensor, - alpha: float, - ) -> tuple[torch.Tensor, torch.Tensor]: - """Mix up the data - - Args: - x: The data to mix up. - y: The labels to mix up. - alpha: The alpha to use for the beta distribution. - - Returns: - The mixed up data and labels. - """ - ratio = np.random.beta(alpha, alpha) - ratio = max(ratio, 1 - ratio) - - shuf_idx = torch.randperm(x.size(0)) - - x_mix = ratio * x + (1 - ratio) * x[shuf_idx] - y_mix = ratio * y + (1 - ratio) * y[shuf_idx] - return x_mix, y_mix - - @staticmethod - def sharpen(y: torch.Tensor, temp: float) -> torch.Tensor: - """Sharpen the predictions by raising them to the power of 1 / temp - - Args: - y: The predictions to sharpen. - temp: The temperature to use. - - Returns: - The probability-normalized sharpened predictions - """ - y_sharp = y ** (1 / temp) - # Sharpening will change the sum of the predictions. - y_sharp /= y_sharp.sum(dim=1, keepdim=True) - return y_sharp - - def guess_labels( - self, - x_unls: list[torch.Tensor], - ) -> torch.Tensor: - """Guess labels from the unlabelled data""" - y_unls: list[torch.Tensor] = [torch.softmax(self.ema_model(u), dim=1) for u in x_unls] - # The sum will sum the tensors in the list, it doesn't reduce the tensors - y_unl = sum(y_unls) / len(y_unls) - return y_unl - - def training_step(self, batch, batch_idx): - (x_lbl, y_lbl), (x_unls, _) = batch - x_lbl = x_lbl[0] - y_lbl = one_hot(y_lbl.long(), num_classes=self.n_classes) - - with torch.no_grad(): - y_unl = self.guess_labels(x_unls=x_unls) - y_unl = self.sharpen(y_unl, self.sharpen_temp) - - x = torch.cat([x_lbl, *x_unls], dim=0) - y = torch.cat([y_lbl, y_unl, y_unl], dim=0) - x_mix, y_mix = self.mix_up(x, y, self.mix_beta_alpha) - - if self.interleave: - # This performs interleaving, see our wiki for details. - batch_size = x_lbl.shape[0] - x_mix = list(torch.split(x_mix, batch_size)) - - # Interleave to get a consistent Batch Norm Calculation - x_mix = utils.interleave(x_mix, batch_size) - - y_mix_pred = [self(x) for x in x_mix] - - # Un-interleave to shuffle back to original order - y_mix_pred = utils.interleave(y_mix_pred, batch_size) - - y_mix_lbl_pred = y_mix_pred[0] - y_mix_lbl = y_mix[:batch_size] - y_mix_unl_pred = torch.cat(y_mix_pred[1:], dim=0) - y_mix_unl = y_mix[batch_size:] - else: - batch_size = x_lbl.shape[0] - y_mix_pred = self(x_mix) - y_mix_lbl_pred = y_mix_pred[:batch_size] - y_mix_unl_pred = y_mix_pred[batch_size:] - y_mix_lbl = y_mix[:batch_size] - y_mix_unl = y_mix[batch_size:] - - loss_lbl, loss_unl = self.train_loss_fn( - x_lbl=y_mix_lbl_pred, - y_lbl=y_mix_lbl, - x_unl=y_mix_unl_pred, - y_unl=y_mix_unl, - ) - loss_unl_scale = ( - (self.current_epoch + batch_idx / self.trainer.num_training_batches) - / self.trainer.max_epochs - * self.unl_loss_scale - ) - loss = loss_lbl + loss_unl * loss_unl_scale - - self.log("train_loss", loss, prog_bar=True) - self.log("train_loss_lbl", loss_lbl) - self.log("train_loss_unl", loss_unl) - - return loss - - def validation_step(self, batch, batch_idx): - x, y = batch - y_pred = self.ema_model(x) - loss = F.cross_entropy(y_pred, y.long()) - - acc = accuracy( - y_pred, - y, - task="multiclass", - num_classes=y_pred.shape[1], - ) - self.log("val_loss", loss) - self.log("val_acc", acc, prog_bar=True) - return loss - - # PyTorch Lightning doesn't automatically no_grads the EMA step. - # It's important to keep this to avoid a memory leak. 
- @torch.no_grad() - def on_after_backward(self) -> None: - self.ema_updater.step() - - def configure_optimizers(self): - return torch.optim.Adam(self.model.parameters(), lr=self.lr, - weight_decay=self.weight_decay) From e1320b58c8beca12b829af291548bff67dd030b0 Mon Sep 17 00:00:00 2001 From: Eve-ning Date: Sat, 25 Nov 2023 20:03:33 +0800 Subject: [PATCH 24/38] Clean up unused dependencies --- mixmatch/main.py | 140 --------------------------- mixmatch/utils/__init__.py | 9 -- mixmatch/utils/eval.py | 189 ------------------------------------- mixmatch/utils/loss.py | 28 ------ tests/test_main.py | 19 ---- 5 files changed, 385 deletions(-) delete mode 100644 mixmatch/main.py delete mode 100644 mixmatch/utils/__init__.py delete mode 100644 mixmatch/utils/eval.py delete mode 100644 mixmatch/utils/loss.py delete mode 100644 tests/test_main.py diff --git a/mixmatch/main.py b/mixmatch/main.py deleted file mode 100644 index b199c40..0000000 --- a/mixmatch/main.py +++ /dev/null @@ -1,140 +0,0 @@ -import random -from copy import deepcopy - -import numpy as np -import torch -import torch.nn as nn -import torch.optim as optim -from torch.utils.data import DataLoader - -from mixmatch.dataset.cifar10 import CIFAR10DataModule -from models.wideresnet import WideResNet, WideResNetModule -from utils.ema import WeightEMA -from utils.eval import validate, train_epoch -from utils.loss import SemiLoss - - -def main( - *, - epochs: int = 1024, - batch_size: int = 64, - lr: float = 0.002, - train_iteration: int = 1024, - ema_lr: float = 0.999, - unl_loss_scale: float = 75, - mix_beta_alpha: float = 0.75, - sharpen_temp: float = 0.5, - device: str = "cuda", - seed: int | None = 42, - train_lbl_size: int = 0.005, - train_unl_size: int = 0.980, -): - """The main function to run the MixMatch algorithm - - Args: - epochs: Number of epochs to run. - batch_size: The batch size to use. - lr: The learning rate to use. - train_iteration: The number of iterations to train for. - ema_lr: The learning rate to use for the EMA. - unl_loss_scale: The scaling factor for the unlabeled loss. - mix_beta_alpha: The beta alpha to use for the mixup. - sharpen_temp: The temperature to use for sharpening. - device: The device to use. - seed: The seed to use. If None, then it'll be non-deterministic. - train_lbl_size: The size of the labeled training set. - train_unl_size: The size of the unlabeled training set. 
- """ - deterministic = seed is not None - - if deterministic: - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - - # Data - print(f"==> Preparing cifar10") - - dm = CIFAR10DataModule( - dir="./data", - n_train_lbl=train_lbl_size, - n_train_unl=train_unl_size, - batch_size=batch_size, - seed=seed, - ) - dm.setup() - classes = dm.test_ds.classes - - # Model - print("==> creating WRN-28-2") - - model = WideResNetModule(n_classes=len(classes)).to(device) - ema_model = deepcopy(model).to(device) - for param in ema_model.parameters(): - param.detach_() - - train_loss_fn = SemiLoss() - val_loss_fn = nn.CrossEntropyLoss() - train_optim = optim.Adam(model.parameters(), lr=lr) - - ema_optim = WeightEMA(model, ema_model, ema_lr=ema_lr) - - test_accs = [] - best_acc = 0 - # Train and val - for epoch in range(epochs): - print("\nEpoch: [%d | %d] LR: %f" % (epoch + 1, epochs, lr)) - - train_loss, train_lbl_loss, train_unl_loss = train_epoch( - train_lbl_dl=dm.train_lbl_dataloader(), - train_unl_dl=dm.train_unl_dataloader(), - model=model, - optim=train_optim, - ema_optim=ema_optim, - loss_fn=train_loss_fn, - epoch=epoch, - device=device, - train_iters=train_iteration, - unl_loss_scale=unl_loss_scale, - mix_beta_alpha=mix_beta_alpha, - epochs=epochs, - sharpen_temp=sharpen_temp, - ) - - def val_ema(dl: DataLoader): - return validate( - valloader=dl, - model=ema_model, - loss_fn=val_loss_fn, - device=device, - ) - - _, train_acc = val_ema(dm.train_lbl_dataloader()) - val_loss, val_acc = val_ema(dm.val_dataloader()) - test_loss, test_acc = val_ema(dm.test_dataloader()) - - best_acc = max(val_acc, best_acc) - test_accs.append(test_acc) - - print( - f"Train Loss: {train_loss:.3f} | Train Acc: {train_acc:.3f} | " - f"Val Loss: {val_loss:.3f} | Val Acc: {val_acc:.3f} | " - f"Test Loss: {test_loss:.3f} | Test Acc: {test_acc:.3f} | " - f"Best Acc: {best_acc:.3f} | " - f"Mean Acc: {np.mean(test_accs[-20:]):.3f} | " - f"LR: {lr:.5f} | " - f"Train Loss X: {train_lbl_loss:.3f} | " - f"Train Loss U: {train_unl_loss:.3f} " - ) - - print("Best acc:") - print(best_acc) - - print("Mean acc:") - print(np.mean(test_accs[-20:])) - - return best_acc, np.mean(test_accs[-20:]) diff --git a/mixmatch/utils/__init__.py b/mixmatch/utils/__init__.py deleted file mode 100644 index 1fc86bb..0000000 --- a/mixmatch/utils/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Useful utils -""" -# progress bar -import os -import sys - -from .eval import * - -sys.path.append(os.path.join(os.path.dirname(__file__), "progress")) diff --git a/mixmatch/utils/eval.py b/mixmatch/utils/eval.py deleted file mode 100644 index de01d91..0000000 --- a/mixmatch/utils/eval.py +++ /dev/null @@ -1,189 +0,0 @@ -from typing import Callable, Sequence - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.parallel -from torch.nn.functional import one_hot -from torch.optim import Optimizer -from torch.utils.data import DataLoader -from torchmetrics.functional import accuracy -from tqdm import tqdm - -from utils.ema import WeightEMA -from utils.interleave import interleave -from utils.loss import SemiLoss - - -def mix_up( - x: torch.Tensor, - y: torch.Tensor, - alpha: float, -) -> tuple[torch.Tensor, torch.Tensor]: - ratio = np.random.beta(alpha, alpha) - ratio = max(ratio, 1 - ratio) - - shuf_idx = torch.randperm(x.size(0)) - - x_mix = ratio * x + (1 - ratio) * x[shuf_idx] - y_mix = ratio * y + (1 - ratio) 
* y[shuf_idx] - return x_mix, y_mix - - -def sharpen(y: torch.Tensor, temp: float) -> torch.Tensor: - """Sharpen the predictions by raising them to the power of 1 / temp""" - y_sharp = y ** (1 / temp) - # Sharpening will change the sum of the predictions. - y_sharp /= y_sharp.sum(dim=1, keepdim=True) - return y_sharp - - -def guess_labels( - model: nn.Module, - x_unls: list[torch.Tensor], -) -> torch.Tensor: - """Guess labels from the unlabelled data""" - y_unls: list[torch.Tensor] = [torch.softmax(model(u), dim=1) for u in x_unls] - # The sum will sum the tensors in the list, it doesn't reduce the tensors - y_unl = sum(y_unls) / len(y_unls) - return y_unl - - -def train_epoch( - *, - train_lbl_dl: DataLoader, - train_unl_dl: DataLoader, - model: nn.Module, - optim: Optimizer, - ema_optim: WeightEMA, - loss_fn: SemiLoss, - epoch: int, - epochs: int, - device: str, - train_iters: int, - unl_loss_scale: float, - mix_beta_alpha: float, - sharpen_temp: float, -) -> tuple[float, float, float]: - losses = [] - losses_x = [] - losses_u = [] - n = [] - - lbl_iter = iter(train_lbl_dl) - unl_iter = iter(train_unl_dl) - - model.train() - for batch_idx in tqdm(range(train_iters)): - try: - (x_lbl,), y_lbl = next(lbl_iter) - except StopIteration: - lbl_iter = iter(train_lbl_dl) - (x_lbl,), y_lbl = next(lbl_iter) - - try: - x_unls, _ = next(unl_iter) - except StopIteration: - unl_iter = iter(train_unl_dl) - x_unls, _ = next(unl_iter) - - y_lbl = one_hot(y_lbl.long(), num_classes=10) - - x_lbl = x_lbl.to(device) - y_lbl = y_lbl.to(device) - x_unls = [u.to(device) for u in x_unls] - - with torch.no_grad(): - y_unl = guess_labels(model=model, x_unls=x_unls) - y_unl = sharpen(y_unl, sharpen_temp) - - x = torch.cat([x_lbl, *x_unls], dim=0) - y = torch.cat([y_lbl, y_unl, y_unl], dim=0) - x_mix, y_mix = mix_up(x, y, mix_beta_alpha) - - # interleave labeled and unlabeled samples between batches to - # get correct batchnorm calculation - batch_size = x_lbl.shape[0] - x_mix = list(torch.split(x_mix, batch_size)) - x_mix = interleave(x_mix, batch_size) - - y_mix_pred = [model(x) for x in x_mix] - - # put interleaved samples back - y_mix_pred = interleave(y_mix_pred, batch_size) - - y_mix_lbl_pred = y_mix_pred[0] - y_mix_lbl = y_mix[:batch_size] - y_mix_unl_pred = torch.cat(y_mix_pred[1:], dim=0) - y_mix_unl = y_mix[batch_size:] - - loss_lbl, loss_unl = loss_fn( - x_lbl=y_mix_lbl_pred, - y_lbl=y_mix_lbl, - x_unl=y_mix_unl_pred, - y_unl=y_mix_unl, - ) - loss_unl_scale = (epoch + batch_idx / train_iters) / epochs * unl_loss_scale - loss = loss_lbl + loss_unl * loss_unl_scale - - losses.append(loss) - losses_x.append(loss_lbl) - losses_u.append(loss_unl) - n.append(x_lbl.size(0)) - - optim.zero_grad() - loss.backward() - optim.step() - ema_optim.step() - - return ( - sum([loss * n for loss, n in zip(losses, n)]) / sum(n), - sum([loss * n for loss, n in zip(losses_x, n)]) / sum(n), - sum([loss * n for loss, n in zip(losses_u, n)]) / sum(n), - ) - - -def validate( - *, - valloader: DataLoader, - model: nn.Module, - loss_fn: Callable, - device: str, -): - n = [] - losses = [] - accs = [] - - model.eval() - with torch.no_grad(): - for x, y in tqdm(valloader): - # TODO: Pretty hacky but this is for the train loader. 
- if isinstance(x, Sequence): - x = x[0] - - x = x.to(device) - y = y.to(device) - - y_pred = model(x) - loss = loss_fn(y_pred, y.long()) - - # TODO: Technically, we shouldn't * 100, but it's fine for now as - # it doesn't impact training - acc = ( - accuracy( - y_pred, - y, - task="multiclass", - num_classes=y_pred.shape[1], - ) - * 100 - ) - losses.append(loss.item()) - accs.append(acc.item()) - n.append(x.size(0)) - - # return weighted mean - return ( - sum([loss * n for loss, n in zip(losses, n)]) / sum(n), - sum([top * n for top, n in zip(accs, n)]) / sum(n), - ) diff --git a/mixmatch/utils/loss.py b/mixmatch/utils/loss.py deleted file mode 100644 index ae59d2c..0000000 --- a/mixmatch/utils/loss.py +++ /dev/null @@ -1,28 +0,0 @@ -import numpy as np -import torch -import torch.nn.parallel -from torch.nn.functional import cross_entropy - - -def linear_rampup(current: float, rampup_length: int): - if rampup_length == 0: - return 1.0 - return np.clip(current / rampup_length, 0, 1) - - -class SemiLoss(object): - def __call__( - self, - x_lbl: torch.Tensor, - y_lbl: torch.Tensor, - x_unl: torch.Tensor, - y_unl: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: - probs_u = torch.softmax(x_unl, dim=1) - - l_x = cross_entropy(x_lbl, y_lbl) - # TODO: Not sure why this is different from MSELoss - # It's likely not a big deal, but it's worth investigating if we have - # too much time on our hands - l_u = torch.mean((probs_u - y_unl) ** 2) - return l_x, l_u diff --git a/tests/test_main.py b/tests/test_main.py deleted file mode 100644 index a9d913f..0000000 --- a/tests/test_main.py +++ /dev/null @@ -1,19 +0,0 @@ -from mixmatch.main import main - - -def test_main_seed_fast(): - """The fast variant to ensure that the model doesn't change.""" - epochs = 1 - train_iteration = 8 - best_acc_1, mean_acc_1 = main(epochs=epochs, train_iteration=train_iteration) - - assert best_acc_1 == 8.399999992370606 - assert mean_acc_1 == 8.81 - - -def test_main_seed_epoch(): - """Ensure that the model doesn't change when refactoring""" - epochs = 1 - best_acc_1, mean_acc_1 = main(epochs=epochs) - assert best_acc_1 == 23.733333368937174 - assert mean_acc_1 == 22.22 From 4a80a8f35e4391f9031e772df9bf462a1242c347 Mon Sep 17 00:00:00 2001 From: Eve-ning Date: Sat, 25 Nov 2023 20:03:48 +0800 Subject: [PATCH 25/38] Add main for PyTorch Lightning --- mixmatch/main_pl.py | 48 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 mixmatch/main_pl.py diff --git a/mixmatch/main_pl.py b/mixmatch/main_pl.py new file mode 100644 index 0000000..fb4b80a --- /dev/null +++ b/mixmatch/main_pl.py @@ -0,0 +1,48 @@ +import pytorch_lightning as pl +import torch + +from mixmatch.dataset.cifar10 import SSLCIFAR10DataModule +from mixmatch.models.wideresnet import WideResNet +from mixmatch.models.mixmatch_module import MixMatchModule + +epochs: int = 1024 +batch_size: int = 64 +k_augs: int = 2 +lr: float = 0.002 +weight_decay: float = 0.00004 +ema_lr: float = 0.001 +train_iters: int = 1024 +unl_loss_scale: float = 75 +mix_beta_alpha: float = 0.75 +sharpen_temp: float = 0.5 +device: str = "cuda" +seed: int | None = 42 +train_lbl_size: float = 0.005 +train_unl_size: float = 0.980 + +dm = SSLCIFAR10DataModule( + dir="../tests/data", + train_lbl_size=train_lbl_size, + train_unl_size=train_unl_size, + batch_size=batch_size, + train_iters=train_iters, + seed=seed, + k_augs=k_augs, + num_workers=32, +) + +mm_model = MixMatchModule( + model=WideResNet(n_classes=10, depth=28, width=2, drop_rate=0.0, 
seed=seed), + sharpen_temp=sharpen_temp, + mix_beta_alpha=mix_beta_alpha, + unl_loss_scale=unl_loss_scale, + ema_lr=ema_lr, + lr=lr, + weight_decay=weight_decay, +) + +torch.set_float32_matmul_precision("high") + +trainer = pl.Trainer(max_epochs=epochs, accelerator="gpu") + +trainer.fit(mm_model, dm) From 9cf04406c0fea3407a39ab26b58b73c76ae37399 Mon Sep 17 00:00:00 2001 From: Eve-ning Date: Sat, 25 Nov 2023 20:04:08 +0800 Subject: [PATCH 26/38] Add empty init for utils --- mixmatch/utils/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 mixmatch/utils/__init__.py diff --git a/mixmatch/utils/__init__.py b/mixmatch/utils/__init__.py new file mode 100644 index 0000000..e69de29 From 337cb8bb4605181916540902ca4458013ba79c85 Mon Sep 17 00:00:00 2001 From: Eve-ning Date: Sat, 25 Nov 2023 20:05:38 +0800 Subject: [PATCH 27/38] Format and clean up files --- mixmatch/models/mixmatch_module.py | 2 +- mixmatch/models/wideresnet.py | 2 -- mixmatch/utils/ema.py | 1 - 3 files changed, 1 insertion(+), 4 deletions(-) diff --git a/mixmatch/models/mixmatch_module.py b/mixmatch/models/mixmatch_module.py index 251c20a..fc12ce1 100644 --- a/mixmatch/models/mixmatch_module.py +++ b/mixmatch/models/mixmatch_module.py @@ -13,7 +13,7 @@ from torchmetrics.functional import accuracy import utils.interleave -from utils import WeightEMA +from utils.ema import WeightEMA # The eq=False is to prevent overriding hash diff --git a/mixmatch/models/wideresnet.py b/mixmatch/models/wideresnet.py index df9b748..df91f6d 100644 --- a/mixmatch/models/wideresnet.py +++ b/mixmatch/models/wideresnet.py @@ -3,8 +3,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -import torch.nn.parallel -import torch.nn.parallel class BasicBlock(nn.Module): diff --git a/mixmatch/utils/ema.py b/mixmatch/utils/ema.py index ad1a50a..9a32d1e 100644 --- a/mixmatch/utils/ema.py +++ b/mixmatch/utils/ema.py @@ -1,6 +1,5 @@ import torch import torch.nn as nn -import torch.nn.parallel class WeightEMA: From dbe4e22cfc9fe6cf7db12dba7e1f237178b696c2 Mon Sep 17 00:00:00 2001 From: Eve-ning Date: Sat, 25 Nov 2023 22:05:26 +0800 Subject: [PATCH 28/38] Fix issue with dataclass assign module order --- mixmatch/main_pl.py | 7 +++++-- mixmatch/models/mixmatch_module.py | 12 ++++++++---- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/mixmatch/main_pl.py b/mixmatch/main_pl.py index fb4b80a..5fe6e8f 100644 --- a/mixmatch/main_pl.py +++ b/mixmatch/main_pl.py @@ -5,7 +5,7 @@ from mixmatch.models.wideresnet import WideResNet from mixmatch.models.mixmatch_module import MixMatchModule -epochs: int = 1024 +epochs: int = 100 batch_size: int = 64 k_augs: int = 2 lr: float = 0.002 @@ -32,7 +32,10 @@ ) mm_model = MixMatchModule( - model=WideResNet(n_classes=10, depth=28, width=2, drop_rate=0.0, seed=seed), + model_fn=lambda: WideResNet( + n_classes=10, depth=28, width=2, drop_rate=0.0, seed=seed + ), + n_classes=10, sharpen_temp=sharpen_temp, mix_beta_alpha=mix_beta_alpha, unl_loss_scale=unl_loss_scale, diff --git a/mixmatch/models/mixmatch_module.py b/mixmatch/models/mixmatch_module.py index fc12ce1..eb9bafe 100644 --- a/mixmatch/models/mixmatch_module.py +++ b/mixmatch/models/mixmatch_module.py @@ -1,5 +1,5 @@ from copy import deepcopy -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Callable import numpy as np @@ -42,7 +42,8 @@ class MixMatchModule(pl.LightningModule): weight_decay: The weight decay to use for the optimizer. 
""" - model: nn.Module + model_fn: Callable[[], nn.Module] + n_classes: int = 10 sharpen_temp: float = 0.5 mix_beta_alpha: float = 0.75 unl_loss_scale: float = 75 @@ -63,8 +64,11 @@ class MixMatchModule(pl.LightningModule): [torch.Tensor, torch.Tensor], torch.Tensor ] = lambda pred, tgt: torch.mean((torch.softmax(pred, dim=1) - tgt) ** 2) + model: nn.Module = field(init=False) + def __post_init__(self): super().__init__() + self.model = self.model_fn() self.ema_model = deepcopy(self.model) for param in self.ema_model.parameters(): param.detach_() @@ -148,12 +152,12 @@ def training_step(self, batch, batch_idx): x_mix = list(torch.split(x_mix, batch_size)) # Interleave to get a consistent Batch Norm Calculation - x_mix = utils.interleave(x_mix, batch_size) + x_mix = utils.interleave.interleave(x_mix, batch_size) y_mix_pred = [self(x) for x in x_mix] # Un-interleave to shuffle back to original order - y_mix_pred = utils.interleave(y_mix_pred, batch_size) + y_mix_pred = utils.interleave.interleave(y_mix_pred, batch_size) y_mix_lbl_pred = y_mix_pred[0] y_mix_lbl = y_mix[:batch_size] From 44ff3b8a7d534991ac63e7d1e2fe0a5db0764af1 Mon Sep 17 00:00:00 2001 From: Eve-ning Date: Sun, 26 Nov 2023 13:58:15 +0800 Subject: [PATCH 29/38] Rename EMA Updater to be clearer on function call --- mixmatch/utils/ema.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/mixmatch/utils/ema.py b/mixmatch/utils/ema.py index 9a32d1e..ad0047d 100644 --- a/mixmatch/utils/ema.py +++ b/mixmatch/utils/ema.py @@ -7,16 +7,20 @@ def __init__( self, model: nn.Module, ema_model: nn.Module, - ema_lr: float, ): - self.ema_lr = ema_lr self.model = model self.ema_model = ema_model - def step(self): + def update(self, lr: float): + """Update the EMA model with the current model's parameters. + + Args: + lr: A fraction controlling how much should the EMA learn from the + current model. + """ for param, ema_param in zip( self.model.parameters(), self.ema_model.parameters() ): if ema_param.dtype == torch.float32: - ema_param.mul_(1 - self.ema_lr) - ema_param.add_(param * self.ema_lr) + ema_param.mul_(1 - lr) + ema_param.add_(param * lr) From 36d05473e3dca9bec306b51cbd7c5cb723537a61 Mon Sep 17 00:00:00 2001 From: Eve-ning Date: Sun, 26 Nov 2023 13:58:38 +0800 Subject: [PATCH 30/38] Implement EMA LR updating --- mixmatch/models/mixmatch_module.py | 43 ++++++++++++++++++------------ 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/mixmatch/models/mixmatch_module.py b/mixmatch/models/mixmatch_module.py index eb9bafe..aadc495 100644 --- a/mixmatch/models/mixmatch_module.py +++ b/mixmatch/models/mixmatch_module.py @@ -1,5 +1,5 @@ from copy import deepcopy -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Callable import numpy as np @@ -33,11 +33,13 @@ class MixMatchModule(pl.LightningModule): how to implement a new dataset. Args: - model: The model to train. + model_fn: The function to use to create the model. + n_classes: The number of classes in the dataset. sharpen_temp: The temperature to use for sharpening. mix_beta_alpha: The alpha to use for the beta distribution when mixing. unl_loss_scale: The scale to use for the unsupervised loss. ema_lr: The learning rate to use for the EMA. + ema_lr_exp: The exponent decay to use for the EMA learning rate. lr: The learning rate to use for the optimizer. weight_decay: The weight decay to use for the optimizer. 
""" @@ -48,8 +50,9 @@ class MixMatchModule(pl.LightningModule): mix_beta_alpha: float = 0.75 unl_loss_scale: float = 75 ema_lr: float = 0.001 + ema_lr_exp: float = 0.25 lr: float = 0.002 - weight_decay: float = 0.0005 + weight_decay: float = 0.00004 # See our wiki for details on interleave interleave: bool = False @@ -64,10 +67,11 @@ class MixMatchModule(pl.LightningModule): [torch.Tensor, torch.Tensor], torch.Tensor ] = lambda pred, tgt: torch.mean((torch.softmax(pred, dim=1) - tgt) ** 2) - model: nn.Module = field(init=False) - def __post_init__(self): super().__init__() + self.save_hyperparameters( + ignore=["model_fn", "train_lbl_loss", "train_unl_loss", "model"] + ) self.model = self.model_fn() self.ema_model = deepcopy(self.model) for param in self.ema_model.parameters(): @@ -82,9 +86,9 @@ def forward(self, x): @staticmethod def mix_up( - x: torch.Tensor, - y: torch.Tensor, - alpha: float, + x: torch.Tensor, + y: torch.Tensor, + alpha: float, ) -> tuple[torch.Tensor, torch.Tensor]: """Mix up the data @@ -122,8 +126,8 @@ def sharpen(y: torch.Tensor, temp: float) -> torch.Tensor: return y_sharp def guess_labels( - self, - x_unls: list[torch.Tensor], + self, + x_unls: list[torch.Tensor], ) -> torch.Tensor: """Guess labels from the unlabelled data""" y_unls: list[torch.Tensor] = [ @@ -133,7 +137,13 @@ def guess_labels( y_unl = sum(y_unls) / len(y_unls) return y_unl + @property + def progress(self): + # Progress is a linear ramp from 0 to 1 over the course of training. + return (self.global_step / self.trainer.num_training_batches) / self.trainer.max_epochs + def training_step(self, batch, batch_idx): + # Progress is a linear ramp from 0 to 1 over the course of training.q (x_lbl, y_lbl), (x_unls, _) = batch x_lbl = x_lbl[0] y_lbl = one_hot(y_lbl.long(), num_classes=self.n_classes) @@ -176,15 +186,12 @@ def training_step(self, batch, batch_idx): # The scale is a linear ramp up from 0 to self.unl_loss_scale # over the course of training. - loss_unl_scale = ( - (self.current_epoch + batch_idx / self.trainer.num_training_batches) - / self.trainer.max_epochs - * self.unl_loss_scale - ) + loss_unl_scale = self.progress * self.unl_loss_scale loss = loss_lbl + loss_unl * loss_unl_scale - self.log("train_loss", loss, prog_bar=True) + self.log("loss_unl_scale", loss_unl_scale, prog_bar=True) + self.log("train_loss", loss) self.log("train_loss_lbl", loss_lbl) self.log("train_loss_unl", loss_unl) @@ -204,7 +211,9 @@ def validation_step(self, batch, batch_idx): # It's important to keep this to avoid a memory leak. 
@torch.no_grad() def on_after_backward(self) -> None: - self.ema_updater.step() + ema_update_lr = self.ema_lr * (self.ema_lr_exp ** self.progress) + self.log("ema_update_lr", ema_update_lr, prog_bar=True) + self.ema_updater.update(ema_update_lr) def configure_optimizers(self): return torch.optim.Adam( From d4286291537919fa4b3aa4f8e4054d3d3a742d4b Mon Sep 17 00:00:00 2001 From: Eve-ning Date: Sun, 26 Nov 2023 13:58:53 +0800 Subject: [PATCH 31/38] Fine tune model --- mixmatch/main_pl.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/mixmatch/main_pl.py b/mixmatch/main_pl.py index 5fe6e8f..2d40d94 100644 --- a/mixmatch/main_pl.py +++ b/mixmatch/main_pl.py @@ -1,18 +1,21 @@ +import numpy as np import pytorch_lightning as pl import torch from mixmatch.dataset.cifar10 import SSLCIFAR10DataModule -from mixmatch.models.wideresnet import WideResNet from mixmatch.models.mixmatch_module import MixMatchModule +from mixmatch.models.wideresnet import WideResNet epochs: int = 100 batch_size: int = 64 k_augs: int = 2 -lr: float = 0.002 +# Scale LR due to removed interleaving +lr: float = 0.002 * np.sqrt((k_augs + 1)) weight_decay: float = 0.00004 -ema_lr: float = 0.001 +ema_lr: float = 0.005 +ema_lr_exp: float = 1 train_iters: int = 1024 -unl_loss_scale: float = 75 +unl_loss_scale: float = 100 mix_beta_alpha: float = 0.75 sharpen_temp: float = 0.5 device: str = "cuda" @@ -40,12 +43,20 @@ mix_beta_alpha=mix_beta_alpha, unl_loss_scale=unl_loss_scale, ema_lr=ema_lr, + ema_lr_exp=ema_lr_exp, lr=lr, weight_decay=weight_decay, ) torch.set_float32_matmul_precision("high") -trainer = pl.Trainer(max_epochs=epochs, accelerator="gpu") +trainer = pl.Trainer( + max_epochs=epochs, + accelerator="gpu", + callbacks=[ + pl.callbacks.LearningRateMonitor(), + pl.callbacks.StochasticWeightAveraging(swa_lrs=lr) + ], +) trainer.fit(mm_model, dm) From c5cb6cb8e196997cc2e52b36f3be3705d4a1291c Mon Sep 17 00:00:00 2001 From: Eve-ning Date: Sun, 26 Nov 2023 14:02:01 +0800 Subject: [PATCH 32/38] Remove EMA scaling --- mixmatch/models/mixmatch_module.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/mixmatch/models/mixmatch_module.py b/mixmatch/models/mixmatch_module.py index aadc495..b30ac03 100644 --- a/mixmatch/models/mixmatch_module.py +++ b/mixmatch/models/mixmatch_module.py @@ -39,7 +39,6 @@ class MixMatchModule(pl.LightningModule): mix_beta_alpha: The alpha to use for the beta distribution when mixing. unl_loss_scale: The scale to use for the unsupervised loss. ema_lr: The learning rate to use for the EMA. - ema_lr_exp: The exponent decay to use for the EMA learning rate. lr: The learning rate to use for the optimizer. weight_decay: The weight decay to use for the optimizer. """ @@ -50,7 +49,6 @@ class MixMatchModule(pl.LightningModule): mix_beta_alpha: float = 0.75 unl_loss_scale: float = 75 ema_lr: float = 0.001 - ema_lr_exp: float = 0.25 lr: float = 0.002 weight_decay: float = 0.00004 @@ -211,9 +209,7 @@ def validation_step(self, batch, batch_idx): # It's important to keep this to avoid a memory leak. 
@torch.no_grad() def on_after_backward(self) -> None: - ema_update_lr = self.ema_lr * (self.ema_lr_exp ** self.progress) - self.log("ema_update_lr", ema_update_lr, prog_bar=True) - self.ema_updater.update(ema_update_lr) + self.ema_updater.update(self.ema_lr) def configure_optimizers(self): return torch.optim.Adam( From 03d455e152a67e677d6c812345339a8dac85df0f Mon Sep 17 00:00:00 2001 From: Eve-ning Date: Sun, 26 Nov 2023 14:02:12 +0800 Subject: [PATCH 33/38] Add model checkpoint saving --- mixmatch/main_pl.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/mixmatch/main_pl.py b/mixmatch/main_pl.py index 2d40d94..1d98488 100644 --- a/mixmatch/main_pl.py +++ b/mixmatch/main_pl.py @@ -55,7 +55,13 @@ accelerator="gpu", callbacks=[ pl.callbacks.LearningRateMonitor(), - pl.callbacks.StochasticWeightAveraging(swa_lrs=lr) + pl.callbacks.StochasticWeightAveraging(swa_lrs=lr), + pl.callbacks.ModelCheckpoint( + monitor="val_acc", + filename="mm-{epoch:02d}-{val_acc:.2f}", + save_top_k=1, + mode="max", + ), ], ) From e34dc47cd7a148132fbbc3cc488133b4530bdf0e Mon Sep 17 00:00:00 2001 From: Eve-ning Date: Sun, 26 Nov 2023 14:02:29 +0800 Subject: [PATCH 34/38] Remove ema exp scaling --- mixmatch/main_pl.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/mixmatch/main_pl.py b/mixmatch/main_pl.py index 1d98488..7e2226d 100644 --- a/mixmatch/main_pl.py +++ b/mixmatch/main_pl.py @@ -13,7 +13,6 @@ lr: float = 0.002 * np.sqrt((k_augs + 1)) weight_decay: float = 0.00004 ema_lr: float = 0.005 -ema_lr_exp: float = 1 train_iters: int = 1024 unl_loss_scale: float = 100 mix_beta_alpha: float = 0.75 @@ -43,7 +42,6 @@ mix_beta_alpha=mix_beta_alpha, unl_loss_scale=unl_loss_scale, ema_lr=ema_lr, - ema_lr_exp=ema_lr_exp, lr=lr, weight_decay=weight_decay, ) From cc3b718cd746051a25cea0fdf7439530d09b0a08 Mon Sep 17 00:00:00 2001 From: Eve-ning Date: Sun, 26 Nov 2023 14:03:13 +0800 Subject: [PATCH 35/38] Remove ema exp scaling --- mixmatch/models/mixmatch_module.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/mixmatch/models/mixmatch_module.py b/mixmatch/models/mixmatch_module.py index b30ac03..f9db88e 100644 --- a/mixmatch/models/mixmatch_module.py +++ b/mixmatch/models/mixmatch_module.py @@ -75,18 +75,16 @@ def __post_init__(self): for param in self.ema_model.parameters(): param.detach_() - self.ema_updater = WeightEMA( - model=self.model, ema_model=self.ema_model, ema_lr=self.ema_lr - ) + self.ema_updater = WeightEMA(model=self.model, ema_model=self.ema_model) def forward(self, x): return self.model(x) @staticmethod def mix_up( - x: torch.Tensor, - y: torch.Tensor, - alpha: float, + x: torch.Tensor, + y: torch.Tensor, + alpha: float, ) -> tuple[torch.Tensor, torch.Tensor]: """Mix up the data @@ -124,8 +122,8 @@ def sharpen(y: torch.Tensor, temp: float) -> torch.Tensor: return y_sharp def guess_labels( - self, - x_unls: list[torch.Tensor], + self, + x_unls: list[torch.Tensor], ) -> torch.Tensor: """Guess labels from the unlabelled data""" y_unls: list[torch.Tensor] = [ @@ -138,7 +136,9 @@ def guess_labels( @property def progress(self): # Progress is a linear ramp from 0 to 1 over the course of training. 
- return (self.global_step / self.trainer.num_training_batches) / self.trainer.max_epochs + return ( + self.global_step / self.trainer.num_training_batches + ) / self.trainer.max_epochs def training_step(self, batch, batch_idx): # Progress is a linear ramp from 0 to 1 over the course of training.q From 50cbfb7b5fbf1daa25c3b411a9b483d3284d7ff4 Mon Sep 17 00:00:00 2001 From: Eve-ning Date: Sun, 26 Nov 2023 16:31:24 +0800 Subject: [PATCH 36/38] Implement custom Unl Loss Scaler Scheduling --- mixmatch/main_pl.py | 9 ++++--- mixmatch/models/mixmatch_module.py | 43 ++++++++++++++++++------------ 2 files changed, 32 insertions(+), 20 deletions(-) diff --git a/mixmatch/main_pl.py b/mixmatch/main_pl.py index 7e2226d..fa7d32e 100644 --- a/mixmatch/main_pl.py +++ b/mixmatch/main_pl.py @@ -1,3 +1,5 @@ +from pathlib import Path + import numpy as np import pytorch_lightning as pl import torch @@ -5,16 +7,17 @@ from mixmatch.dataset.cifar10 import SSLCIFAR10DataModule from mixmatch.models.mixmatch_module import MixMatchModule from mixmatch.models.wideresnet import WideResNet +from mixmatch.utils.ease import ease_out epochs: int = 100 batch_size: int = 64 k_augs: int = 2 # Scale LR due to removed interleaving lr: float = 0.002 * np.sqrt((k_augs + 1)) +loss_unl_scaler = ease_out(50, 0.95) weight_decay: float = 0.00004 ema_lr: float = 0.005 train_iters: int = 1024 -unl_loss_scale: float = 100 mix_beta_alpha: float = 0.75 sharpen_temp: float = 0.5 device: str = "cuda" @@ -23,7 +26,7 @@ train_unl_size: float = 0.980 dm = SSLCIFAR10DataModule( - dir="../tests/data", + dir=Path(__file__).parents[1] / "tests/data", train_lbl_size=train_lbl_size, train_unl_size=train_unl_size, batch_size=batch_size, @@ -40,7 +43,7 @@ n_classes=10, sharpen_temp=sharpen_temp, mix_beta_alpha=mix_beta_alpha, - unl_loss_scale=unl_loss_scale, + loss_unl_scaler=loss_unl_scaler, ema_lr=ema_lr, lr=lr, weight_decay=weight_decay, diff --git a/mixmatch/models/mixmatch_module.py b/mixmatch/models/mixmatch_module.py index f9db88e..a18093d 100644 --- a/mixmatch/models/mixmatch_module.py +++ b/mixmatch/models/mixmatch_module.py @@ -1,6 +1,8 @@ +from __future__ import annotations + from copy import deepcopy from dataclasses import dataclass -from typing import Callable +from typing import Callable, Protocol import numpy as np import pytorch_lightning as pl @@ -12,8 +14,13 @@ from torch.nn.functional import one_hot from torchmetrics.functional import accuracy -import utils.interleave -from utils.ema import WeightEMA +import mixmatch.utils.interleave +from mixmatch.utils.ema import WeightEMA + + +class LossUnlScale(Protocol): + def __call__(self, progress: float) -> float: + return progress * 75 # The eq=False is to prevent overriding hash @@ -37,17 +44,17 @@ class MixMatchModule(pl.LightningModule): n_classes: The number of classes in the dataset. sharpen_temp: The temperature to use for sharpening. mix_beta_alpha: The alpha to use for the beta distribution when mixing. - unl_loss_scale: The scale to use for the unsupervised loss. + loss_unl_scaler: The scale to use for the unsupervised loss. ema_lr: The learning rate to use for the EMA. lr: The learning rate to use for the optimizer. weight_decay: The weight decay to use for the optimizer. 
""" model_fn: Callable[[], nn.Module] + loss_unl_scaler: LossUnlScale n_classes: int = 10 sharpen_temp: float = 0.5 mix_beta_alpha: float = 0.75 - unl_loss_scale: float = 75 ema_lr: float = 0.001 lr: float = 0.002 weight_decay: float = 0.00004 @@ -55,20 +62,24 @@ class MixMatchModule(pl.LightningModule): # See our wiki for details on interleave interleave: bool = False - train_lbl_loss: Callable[ - [torch.Tensor, torch.Tensor], torch.Tensor - ] = F.cross_entropy + get_loss_lbl: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = F.cross_entropy # TODO: Not sure why this is different from MSELoss # It's likely not a big deal, but it's worth investigating if we have # too much time on our hands - train_unl_loss: Callable[ + get_loss_unl: Callable[ [torch.Tensor, torch.Tensor], torch.Tensor ] = lambda pred, tgt: torch.mean((torch.softmax(pred, dim=1) - tgt) ** 2) def __post_init__(self): super().__init__() self.save_hyperparameters( - ignore=["model_fn", "train_lbl_loss", "train_unl_loss", "model"] + ignore=[ + "model_fn", + "get_loss_lbl", + "get_loss_unl", + "loss_unl_scaler", + "model", + ] ) self.model = self.model_fn() self.ema_model = deepcopy(self.model) @@ -160,12 +171,12 @@ def training_step(self, batch, batch_idx): x_mix = list(torch.split(x_mix, batch_size)) # Interleave to get a consistent Batch Norm Calculation - x_mix = utils.interleave.interleave(x_mix, batch_size) + x_mix = mixmatch.utils.interleave.interleave(x_mix, batch_size) y_mix_pred = [self(x) for x in x_mix] # Un-interleave to shuffle back to original order - y_mix_pred = utils.interleave.interleave(y_mix_pred, batch_size) + y_mix_pred = mixmatch.utils.interleave.interleave(y_mix_pred, batch_size) y_mix_lbl_pred = y_mix_pred[0] y_mix_lbl = y_mix[:batch_size] @@ -179,12 +190,10 @@ def training_step(self, batch, batch_idx): y_mix_lbl = y_mix[:batch_size] y_mix_unl = y_mix[batch_size:] - loss_lbl = self.train_lbl_loss(y_mix_lbl_pred, y_mix_lbl) - loss_unl = self.train_unl_loss(y_mix_unl_pred, y_mix_unl) + loss_lbl = self.get_loss_lbl(y_mix_lbl_pred, y_mix_lbl) + loss_unl = self.get_loss_unl(y_mix_unl_pred, y_mix_unl) - # The scale is a linear ramp up from 0 to self.unl_loss_scale - # over the course of training. 
- loss_unl_scale = self.progress * self.unl_loss_scale + loss_unl_scale = self.loss_unl_scaler(progress=self.progress) loss = loss_lbl + loss_unl * loss_unl_scale From b6ffe40946f363d02295bca202f2a3eda935212b Mon Sep 17 00:00:00 2001 From: Eve-ning Date: Sun, 26 Nov 2023 19:50:35 +0800 Subject: [PATCH 37/38] Initial commit for new resnet --- mixmatch/models/nested_dict.py | 231 ++++++++++++++++++++++++++++++ mixmatch/models/utils.py | 71 +++++++++ mixmatch/models/wideresnet_new.py | 66 +++++++++ 3 files changed, 368 insertions(+) create mode 100644 mixmatch/models/nested_dict.py create mode 100644 mixmatch/models/utils.py create mode 100644 mixmatch/models/wideresnet_new.py diff --git a/mixmatch/models/nested_dict.py b/mixmatch/models/nested_dict.py new file mode 100644 index 0000000..9b7b41a --- /dev/null +++ b/mixmatch/models/nested_dict.py @@ -0,0 +1,231 @@ +#!/usr/bin/env python +"""`nested_dict` provides dictionaries with multiple levels of nested-ness.""" +from __future__ import print_function +from __future__ import division + +################################################################################ +# +# nested_dict.py +# +# Copyright (c) 2009, 2015 Leo Goodstadt +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +################################################################################# + + +from collections import defaultdict + +import sys + + +def flatten_nested_items(dictionary): + """ + Flatten a nested_dict. + + iterate through nested dictionary (with iterkeys() method) + and return with nested keys flattened into a tuple + """ + if sys.hexversion < 0x03000000: + keys = dictionary.iterkeys + keystr = "iterkeys" + else: + keys = dictionary.keys + keystr = "keys" + for key in keys(): + value = dictionary[key] + if hasattr(value, keystr): + for keykey, value in flatten_nested_items(value): + yield (key,) + keykey, value + else: + yield (key,), value + + +class _recursive_dict(defaultdict): + """ + Parent class of nested_dict. + + Defined separately for _nested_levels to work + transparently, so dictionaries with a specified (and constant) degree of nestedness + can be created easily. + + The "_flat" functions are defined here rather than in nested_dict because they work + recursively. 
+ + """ + + def iteritems_flat(self): + """Iterate through items with nested keys flattened into a tuple.""" + for key, value in flatten_nested_items(self): + yield key, value + + def iterkeys_flat(self): + """Iterate through keys with nested keys flattened into a tuple.""" + for key, value in flatten_nested_items(self): + yield key + + def itervalues_flat(self): + """Iterate through values with nested keys flattened into a tuple.""" + for key, value in flatten_nested_items(self): + yield value + + items_flat = iteritems_flat + keys_flat = iterkeys_flat + values_flat = itervalues_flat + + def to_dict(self, input_dict=None): + """Convert the nested dictionary to a nested series of standard ``dict`` objects.""" + # + # Calls itself recursively to unwind the dictionary. + # Use to_dict() to start at the top level of nesting + plain_dict = dict() + if input_dict is None: + input_dict = self + for key in input_dict.keys(): + value = input_dict[key] + if isinstance(value, _recursive_dict): + # print "recurse", value + plain_dict[key] = self.to_dict(value) + else: + # print "plain", value + plain_dict[key] = value + return plain_dict + + def __str__(self, indent=None): + """Representation of self as a string.""" + import json + return json.dumps(self.to_dict(), indent=indent) + + +class _any_type(object): + pass + + +def _nested_levels(level, nested_type): + """Helper function to create a specified degree of nested dictionaries.""" + if level > 2: + return lambda: _recursive_dict(_nested_levels(level - 1, nested_type)) + if level == 2: + if isinstance(nested_type, _any_type): + return lambda: _recursive_dict() + else: + return lambda: _recursive_dict(_nested_levels(level - 1, nested_type)) + return nested_type + + +if sys.hexversion < 0x03000000: + iteritems = dict.iteritems +else: + iteritems = dict.items + + +# _________________________________________________________________________________________ +# +# nested_dict +# +# _________________________________________________________________________________________ +def nested_dict_from_dict(orig_dict, nd): + """Helper to build nested_dict from a dict.""" + for key, value in iteritems(orig_dict): + if isinstance(value, (dict,)): + nd[key] = nested_dict_from_dict(value, nested_dict()) + else: + nd[key] = value + return nd + + +def _recursive_update(nd, other): + for key, value in iteritems(other): + #print ("key=", key) + if isinstance(value, (dict,)): + + # recursive update if my item is nested_dict + if isinstance(nd[key], (_recursive_dict,)): + #print ("recursive update", key, type(nd[key])) + _recursive_update(nd[key], other[key]) + + # update if my item is dict + elif isinstance(nd[key], (dict,)): + #print ("update", key, type(nd[key])) + nd[key].update(other[key]) + + # overwrite + else: + #print ("self not nested dict or dict: overwrite", key) + nd[key] = value + # other not dict: overwrite + else: + #print ("other not dict: overwrite", key) + nd[key] = value + return nd + + +# _________________________________________________________________________________________ +# +# nested_dict +# +# _________________________________________________________________________________________ +class nested_dict(_recursive_dict): + """ + Nested dict. + + Uses defaultdict to automatically add levels of nested dicts and other types. + """ + + def update(self, other): + """Update recursively.""" + _recursive_update(self, other) + + def __init__(self, *param, **named_param): + """ + Constructor. 
+ + Takes one or two parameters + 1) int, [TYPE] + 1) dict + """ + if not len(param): + self.factory = nested_dict + defaultdict.__init__(self, self.factory) + return + + if len(param) == 1: + # int = level + if isinstance(param[0], int): + self.factory = _nested_levels(param[0], _any_type()) + defaultdict.__init__(self, self.factory) + return + # existing dict + if isinstance(param[0], dict): + self.factory = nested_dict + defaultdict.__init__(self, self.factory) + nested_dict_from_dict(param[0], self) + return + + if len(param) == 2: + if isinstance(param[0], int): + self.factory = _nested_levels(*param) + defaultdict.__init__(self, self.factory) + return + + raise Exception("nested_dict should be initialised with either " + "1) the number of nested levels and an optional type, or " + "2) an existing dict to be converted into a nested dict " + "(factory = %s. len(param) = %d, param = %s" + % (self.factory, len(param), param)) \ No newline at end of file diff --git a/mixmatch/models/utils.py b/mixmatch/models/utils.py new file mode 100644 index 0000000..c1f974f --- /dev/null +++ b/mixmatch/models/utils.py @@ -0,0 +1,71 @@ +import torch +from torch.nn.init import kaiming_normal_ +import torch.nn.functional as F +from torch.nn.parallel._functions import Broadcast +from torch.nn.parallel import scatter, parallel_apply, gather +from functools import partial +from mixmatch.models.nested_dict import nested_dict + + +def cast(params, dtype='float'): + if isinstance(params, dict): + return {k: cast(v, dtype) for k, v in params.items()} + else: + return getattr(params.cuda() if torch.cuda.is_available() else params, dtype)() + + +def conv_params(ni, no, k=1): + return kaiming_normal_(torch.Tensor(no, ni, k, k)) + + +def linear_params(ni, no): + return {'weight': kaiming_normal_(torch.Tensor(no, ni)), 'bias': torch.zeros(no)} + + +def bnparams(n): + return {'weight': torch.rand(n), + 'bias': torch.zeros(n), + 'running_mean': torch.zeros(n), + 'running_var': torch.ones(n)} + + +def data_parallel(f, input, params, device_ids, output_device=None): + assert isinstance(device_ids, list) + if output_device is None: + output_device = device_ids[0] + + if len(device_ids) == 1: + return f(input, params) + + params_all = Broadcast.apply(device_ids, *params.values()) + params_replicas = [{k: params_all[i + j * len(params)] for i, k in enumerate(params.keys())} + for j in range(len(device_ids))] + + replicas = [partial(f, params=p) + for p in params_replicas] + inputs = scatter([input], device_ids) + outputs = parallel_apply(replicas, inputs) + return gather(outputs, output_device) + + +def flatten(params): + return {'.'.join(k): v for k, v in nested_dict(params).items_flat() if v is not None} + + +def batch_norm(x, params, base): + return F.batch_norm(x, weight=params[base + '.weight'], + bias=params[base + '.bias'], + running_mean=params[base + '.running_mean'], + running_var=params[base + '.running_var'],) + + +def print_tensor_dict(params): + kmax = max(len(key) for key in params.keys()) + for i, (key, v) in enumerate(params.items()): + print(str(i).ljust(5), key.ljust(kmax + 3), str(tuple(v.shape)).ljust(23), torch.typename(v), v.requires_grad) + + +def set_requires_grad_except_bn_(params): + for k, v in params.items(): + if not k.endswith('running_mean') and not k.endswith('running_var'): + v.requires_grad = True diff --git a/mixmatch/models/wideresnet_new.py b/mixmatch/models/wideresnet_new.py new file mode 100644 index 0000000..9a16894 --- /dev/null +++ b/mixmatch/models/wideresnet_new.py @@ -0,0 +1,66 
@@ +import torch.nn.functional as F +import mixmatch.models.utils as utils + + +def resnet(depth, width, num_classes): + assert (depth - 4) % 6 == 0, 'depth should be 6n+4' + n = (depth - 4) // 6 + widths = [int(v * width) for v in (16, 32, 64)] + + def gen_block_params(ni, no): + return { + 'conv0': utils.conv_params(ni, no, 3), + 'conv1': utils.conv_params(no, no, 3), + 'bn0': utils.bnparams(ni), + 'bn1': utils.bnparams(no), + 'convdim': utils.conv_params(ni, no, 1) if ni != no else None, + } + + def gen_group_params(ni, no, count): + return {'block%d' % i: gen_block_params(ni if i == 0 else no, no) + for i in range(count)} + + flat_params = utils.cast(utils.flatten({ + 'conv0': utils.conv_params(3, 16, 3), + 'group0': gen_group_params(16, widths[0], n), + 'group1': gen_group_params(widths[0], widths[1], n), + 'group2': gen_group_params(widths[1], widths[2], n), + 'bn': utils.bnparams(widths[2]), + 'fc': utils.linear_params(widths[2], num_classes), + })) + + utils.set_requires_grad_except_bn_(flat_params) + + def block(x, params, base, stride): + o1 = F.relu(utils.batch_norm(x, params, base + '.bn0'), inplace=True) + y = F.conv2d(o1, params[base + '.conv0'], stride=stride, padding=1) + o2 = F.relu(utils.batch_norm(y, params, base + '.bn1'), inplace=True) + z = F.conv2d(o2, params[base + '.conv1'], stride=1, padding=1) + if base + '.convdim' in params: + return z + F.conv2d(o1, params[base + '.convdim'], stride=stride) + else: + return z + x + + def group(o, params, base, stride): + for i in range(n): + o = block(o, params, '%s.block%d' % (base, i), stride if i == 0 else 1) + return o + + def f(input, params): + x = F.conv2d(input, params['conv0'], padding=1) + g0 = group(x, params, 'group0', 1) + g1 = group(g0, params, 'group1', 2) + g2 = group(g1, params, 'group2', 2) + o = F.relu(utils.batch_norm(g2, params, 'bn')) + o = F.avg_pool2d(o, 8, 1, 0) + o = o.view(o.size(0), -1) + o = F.linear(o, params['fc.weight'], params['fc.bias']) + return o + + return f, flat_params + + +f, p = resnet(28, 2, 10) +import torch + +f(torch.rand(16, 3, 100, 100), p, ) From 4c3ef31ca723ef9b5914fdc463bed2a4bd71018d Mon Sep 17 00:00:00 2001 From: Evening Date: Sun, 26 Nov 2023 21:22:22 +0800 Subject: [PATCH 38/38] Add annotations for WideResNet --- mixmatch/models/utils.py | 2 +- mixmatch/models/wideresnet_new.py | 120 ++++++++++++++++++++++++++++-- 2 files changed, 114 insertions(+), 8 deletions(-) diff --git a/mixmatch/models/utils.py b/mixmatch/models/utils.py index c1f974f..34a8f5c 100644 --- a/mixmatch/models/utils.py +++ b/mixmatch/models/utils.py @@ -11,7 +11,7 @@ def cast(params, dtype='float'): if isinstance(params, dict): return {k: cast(v, dtype) for k, v in params.items()} else: - return getattr(params.cuda() if torch.cuda.is_available() else params, dtype)() + return getattr(params, dtype)() def conv_params(ni, no, k=1): diff --git a/mixmatch/models/wideresnet_new.py b/mixmatch/models/wideresnet_new.py index 9a16894..d82d256 100644 --- a/mixmatch/models/wideresnet_new.py +++ b/mixmatch/models/wideresnet_new.py @@ -1,10 +1,13 @@ import torch.nn.functional as F +from torch import nn + import mixmatch.models.utils as utils def resnet(depth, width, num_classes): assert (depth - 4) % 6 == 0, 'depth should be 6n+4' - n = (depth - 4) // 6 + block_repeats = (depth - 4) // 6 + widths = [int(v * width) for v in (16, 32, 64)] def gen_block_params(ni, no): @@ -22,9 +25,9 @@ def gen_group_params(ni, no, count): flat_params = utils.cast(utils.flatten({ 'conv0': utils.conv_params(3, 16, 3), - 'group0': 
gen_group_params(16, widths[0], n),
-        'group1': gen_group_params(widths[0], widths[1], n),
-        'group2': gen_group_params(widths[1], widths[2], n),
+        'group0': gen_group_params(16, widths[0], block_repeats),
+        'group1': gen_group_params(widths[0], widths[1], block_repeats),
+        'group2': gen_group_params(widths[1], widths[2], block_repeats),
         'bn': utils.bnparams(widths[2]),
         'fc': utils.linear_params(widths[2], num_classes),
     }))
@@ -32,28 +35,46 @@ def gen_group_params(ni, no, count):
     utils.set_requires_grad_except_bn_(flat_params)
 
     def block(x, params, base, stride):
+        print(f"\t\tBN -> ReLU = X")
         o1 = F.relu(utils.batch_norm(x, params, base + '.bn0'), inplace=True)
+        print(f"\t\tConv {params[base + '.conv0'].shape} Stride {stride} Pad 1")
         y = F.conv2d(o1, params[base + '.conv0'], stride=stride, padding=1)
+        print(f"\t\tBN -> ReLU")
         o2 = F.relu(utils.batch_norm(y, params, base + '.bn1'), inplace=True)
+        print(f"\t\tConv {params[base + '.conv1'].shape} Stride 1 Pad 1 = Z")
         z = F.conv2d(o2, params[base + '.conv1'], stride=1, padding=1)
+
         if base + '.convdim' in params:
+            print(f"\t\t\tX -> Conv {params[base + '.convdim'].shape} Stride {stride} Pad 0")
+            print(f"\t\t\tZ + X")
             return z + F.conv2d(o1, params[base + '.convdim'], stride=stride)
         else:
+            print(f"\t\t\tZ + X")
             return z + x
 
     def group(o, params, base, stride):
-        for i in range(n):
-            o = block(o, params, '%s.block%d' % (base, i), stride if i == 0 else 1)
+        for i in range(block_repeats):
+            print(f"\tBlock {i}")
+            o = block(o, params, '%s.block%d' % (base, i),
+                      stride if i == 0 else 1)
         return o
 
     def f(input, params):
+        print(f"Conv {params['conv0'].shape} Stride 1 Pad 1")
         x = F.conv2d(input, params['conv0'], padding=1)
+        print(f"Group 0")
         g0 = group(x, params, 'group0', 1)
+        print(f"Group 1")
         g1 = group(g0, params, 'group1', 2)
+        print(f"Group 2")
         g2 = group(g1, params, 'group2', 2)
+        print(f"BN -> ReLU")
         o = F.relu(utils.batch_norm(g2, params, 'bn'))
+        print(f"AvgPool 8 Stride 1 Pad 0")
         o = F.avg_pool2d(o, 8, 1, 0)
+        print(f"View")
         o = o.view(o.size(0), -1)
+        print(f"Linear {params['fc.weight'].shape}")
         o = F.linear(o, params['fc.weight'], params['fc.bias'])
         return o
 
@@ -63,4 +84,89 @@ def f(input, params):
 f, p = resnet(28, 2, 10)
 import torch
 
-f(torch.rand(16, 3, 100, 100), p, )
+a = f(torch.rand(16, 3, 32, 32), p, )
+
+
+# 6x
+# 1 Block:
+#     X --> BN --> ReLU
+#       --> Conv2D Stride ? Pad 1
+#       --> BN --> ReLU
+#       --> Conv2D Stride 1 Pad 1 --> Y
+#     If ConvDim : X + Y --> Conv2D Stride ? Pad 0
+#     Else      : X + Y
+#
+# The ConvDim is there to match dimensions when the width changes between blocks
+#
+#
+#
+# class Block(nn.Module):
+#     def __init__(
+#             self,
+#             dims: tuple[int, int, int] | tuple[int, int, int, int],
+#             ksizes: tuple[int, int] | tuple[int, int, int] = (3, 3, 3),
+#             strides: tuple[int, int] | tuple[int, int, int] = (2, 1, 2),
+#             pads: tuple[int, int] | tuple[int, int, int] = (1, 1, 0),
+#     ):
+#         """
+#
+#         Args:
+#             dims:
+#             ksizes:
+#             strides:
+#             pads:
+#         """
+#         super().__init__()
+#         assert len(dims) in (3, 4), ("Only supply 3 or 4 dimensions. 
" +# "See docstring for more info.") +# self.bn0 = nn.BatchNorm2d(dims[0]) +# self.relu0 = nn.ReLU() +# self.conv0 = nn.Conv2d( +# dims[0], dims[1], ksizes[0], +# stride=strides[0], padding=pads[0] +# ) +# self.bn1 = nn.BatchNorm2d(dims[1]) +# self.relu1 = nn.ReLU() +# self.conv1 = nn.Conv2d( +# dims[1], dims[2], ksizes[1], +# stride=strides[1], padding=pads[1] +# ) +# if len(dims) == 4: +# self.conv_proj = nn.Conv2d( +# dims[2], dims[3], ksizes[2], +# stride=strides[2], padding=pads[2] +# ) +# else: +# self.conv_proj = None +# +# def forward(self, x): +# x0 = self.relu0(self.bn0(x)) +# x1 = self.conv0(x0) +# x1 = self.conv1(self.relu1(self.bn1(x1))) +# if self.conv_proj is not None: +# x0 = self.conv_proj(x1) +# return x + x_ +# +# +# class Group(nn.Module): +# def __init__( +# self, +# dim_in: int, +# dim_block: int, +# dim_out: int, +# n_blocks: int = 6, +# stride: int = 1, +# ): +# super().__init__() +# self.blocks = nn.Sequential( +# Block((dim_in, dim_block, dim_block), strides=(stride, stride)), +# *[Block((dim_block, dim_block, dim_block), +# strides=(stride, stride)) for _ in range(n_blocks - 2)], +# Block((dim_block, dim_block, dim_block, dim_out), +# strides=(stride, stride, stride)), +# ) +# +# def forward(self, x): +# return self.blocks(x) +# +# Group(16, 32, 64)(torch.rand(16, 16, 32, 32)).shape