From 50cbfb7b5fbf1daa25c3b411a9b483d3284d7ff4 Mon Sep 17 00:00:00 2001
From: Eve-ning
Date: Sun, 26 Nov 2023 16:31:24 +0800
Subject: [PATCH] Implement custom Unl Loss Scaler Scheduling

---
 mixmatch/main_pl.py                |  9 ++++---
 mixmatch/models/mixmatch_module.py | 43 ++++++++++++++++++------------
 2 files changed, 32 insertions(+), 20 deletions(-)

diff --git a/mixmatch/main_pl.py b/mixmatch/main_pl.py
index 7e2226d..fa7d32e 100644
--- a/mixmatch/main_pl.py
+++ b/mixmatch/main_pl.py
@@ -1,3 +1,5 @@
+from pathlib import Path
+
 import numpy as np
 import pytorch_lightning as pl
 import torch
@@ -5,16 +7,17 @@
 from mixmatch.dataset.cifar10 import SSLCIFAR10DataModule
 from mixmatch.models.mixmatch_module import MixMatchModule
 from mixmatch.models.wideresnet import WideResNet
+from mixmatch.utils.ease import ease_out
 
 epochs: int = 100
 batch_size: int = 64
 k_augs: int = 2
 # Scale LR due to removed interleaving
 lr: float = 0.002 * np.sqrt((k_augs + 1))
+loss_unl_scaler = ease_out(50, 0.95)
 weight_decay: float = 0.00004
 ema_lr: float = 0.005
 train_iters: int = 1024
-unl_loss_scale: float = 100
 mix_beta_alpha: float = 0.75
 sharpen_temp: float = 0.5
 device: str = "cuda"
@@ -23,7 +26,7 @@
 train_unl_size: float = 0.980
 
 dm = SSLCIFAR10DataModule(
-    dir="../tests/data",
+    dir=Path(__file__).parents[1] / "tests/data",
     train_lbl_size=train_lbl_size,
     train_unl_size=train_unl_size,
     batch_size=batch_size,
@@ -40,7 +43,7 @@
     n_classes=10,
     sharpen_temp=sharpen_temp,
     mix_beta_alpha=mix_beta_alpha,
-    unl_loss_scale=unl_loss_scale,
+    loss_unl_scaler=loss_unl_scaler,
     ema_lr=ema_lr,
     lr=lr,
     weight_decay=weight_decay,
diff --git a/mixmatch/models/mixmatch_module.py b/mixmatch/models/mixmatch_module.py
index f9db88e..a18093d 100644
--- a/mixmatch/models/mixmatch_module.py
+++ b/mixmatch/models/mixmatch_module.py
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
 from copy import deepcopy
 from dataclasses import dataclass
-from typing import Callable
+from typing import Callable, Protocol
 
 import numpy as np
 import pytorch_lightning as pl
@@ -12,8 +14,13 @@
 from torch.nn.functional import one_hot
 from torchmetrics.functional import accuracy
 
-import utils.interleave
-from utils.ema import WeightEMA
+import mixmatch.utils.interleave
+from mixmatch.utils.ema import WeightEMA
+
+
+class LossUnlScale(Protocol):
+    def __call__(self, progress: float) -> float:
+        return progress * 75
 
 
 # The eq=False is to prevent overriding hash
@@ -37,17 +44,17 @@ class MixMatchModule(pl.LightningModule):
         n_classes: The number of classes in the dataset.
         sharpen_temp: The temperature to use for sharpening.
         mix_beta_alpha: The alpha to use for the beta distribution when mixing.
-        unl_loss_scale: The scale to use for the unsupervised loss.
+        loss_unl_scaler: Scales the unsupervised loss given the training progress.
         ema_lr: The learning rate to use for the EMA.
         lr: The learning rate to use for the optimizer.
         weight_decay: The weight decay to use for the optimizer.
""" model_fn: Callable[[], nn.Module] + loss_unl_scaler: LossUnlScale n_classes: int = 10 sharpen_temp: float = 0.5 mix_beta_alpha: float = 0.75 - unl_loss_scale: float = 75 ema_lr: float = 0.001 lr: float = 0.002 weight_decay: float = 0.00004 @@ -55,20 +62,24 @@ class MixMatchModule(pl.LightningModule): # See our wiki for details on interleave interleave: bool = False - train_lbl_loss: Callable[ - [torch.Tensor, torch.Tensor], torch.Tensor - ] = F.cross_entropy + get_loss_lbl: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = F.cross_entropy # TODO: Not sure why this is different from MSELoss # It's likely not a big deal, but it's worth investigating if we have # too much time on our hands - train_unl_loss: Callable[ + get_loss_unl: Callable[ [torch.Tensor, torch.Tensor], torch.Tensor ] = lambda pred, tgt: torch.mean((torch.softmax(pred, dim=1) - tgt) ** 2) def __post_init__(self): super().__init__() self.save_hyperparameters( - ignore=["model_fn", "train_lbl_loss", "train_unl_loss", "model"] + ignore=[ + "model_fn", + "get_loss_lbl", + "get_loss_unl", + "loss_unl_scaler", + "model", + ] ) self.model = self.model_fn() self.ema_model = deepcopy(self.model) @@ -160,12 +171,12 @@ def training_step(self, batch, batch_idx): x_mix = list(torch.split(x_mix, batch_size)) # Interleave to get a consistent Batch Norm Calculation - x_mix = utils.interleave.interleave(x_mix, batch_size) + x_mix = mixmatch.utils.interleave.interleave(x_mix, batch_size) y_mix_pred = [self(x) for x in x_mix] # Un-interleave to shuffle back to original order - y_mix_pred = utils.interleave.interleave(y_mix_pred, batch_size) + y_mix_pred = mixmatch.utils.interleave.interleave(y_mix_pred, batch_size) y_mix_lbl_pred = y_mix_pred[0] y_mix_lbl = y_mix[:batch_size] @@ -179,12 +190,10 @@ def training_step(self, batch, batch_idx): y_mix_lbl = y_mix[:batch_size] y_mix_unl = y_mix[batch_size:] - loss_lbl = self.train_lbl_loss(y_mix_lbl_pred, y_mix_lbl) - loss_unl = self.train_unl_loss(y_mix_unl_pred, y_mix_unl) + loss_lbl = self.get_loss_lbl(y_mix_lbl_pred, y_mix_lbl) + loss_unl = self.get_loss_unl(y_mix_unl_pred, y_mix_unl) - # The scale is a linear ramp up from 0 to self.unl_loss_scale - # over the course of training. - loss_unl_scale = self.progress * self.unl_loss_scale + loss_unl_scale = self.loss_unl_scaler(progress=self.progress) loss = loss_lbl + loss_unl * loss_unl_scale