Implement custom Unl Loss Scaler Scheduling
Eve-ning committed Nov 26, 2023
1 parent cc3b718 commit 50cbfb7
Showing 2 changed files with 32 additions and 20 deletions.
9 changes: 6 additions & 3 deletions mixmatch/main_pl.py
@@ -1,20 +1,23 @@
+from pathlib import Path
+
import numpy as np
import pytorch_lightning as pl
import torch

from mixmatch.dataset.cifar10 import SSLCIFAR10DataModule
from mixmatch.models.mixmatch_module import MixMatchModule
from mixmatch.models.wideresnet import WideResNet
+from mixmatch.utils.ease import ease_out

epochs: int = 100
batch_size: int = 64
k_augs: int = 2
# Scale LR due to removed interleaving
lr: float = 0.002 * np.sqrt((k_augs + 1))
+loss_unl_scaler = ease_out(50, 0.95)
weight_decay: float = 0.00004
ema_lr: float = 0.005
train_iters: int = 1024
-unl_loss_scale: float = 100
mix_beta_alpha: float = 0.75
sharpen_temp: float = 0.5
device: str = "cuda"
@@ -23,7 +26,7 @@
train_unl_size: float = 0.980

dm = SSLCIFAR10DataModule(
dir="../tests/data",
dir=Path(__file__).parents[1] / "tests/data",
train_lbl_size=train_lbl_size,
train_unl_size=train_unl_size,
batch_size=batch_size,
@@ -40,7 +43,7 @@
n_classes=10,
sharpen_temp=sharpen_temp,
mix_beta_alpha=mix_beta_alpha,
-    unl_loss_scale=unl_loss_scale,
+    loss_unl_scaler=loss_unl_scaler,
ema_lr=ema_lr,
lr=lr,
weight_decay=weight_decay,
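
main_pl.py now builds the unsupervised-loss scaler with ease_out(50, 0.95) from mixmatch.utils.ease, whose implementation is not part of this diff. Purely as an illustration, a minimal sketch of the kind of factory it could be, assuming the two arguments are the maximum scale and the training progress at which that maximum is reached:

# Hypothetical sketch -- the real mixmatch/utils/ease.py is not shown in this commit.
def ease_out(max_scale: float, at_progress: float = 1.0):
    """Return a callable mapping training progress in [0, 1] to a loss scale."""

    def scaler(progress: float) -> float:
        # Normalise so the schedule saturates once `at_progress` is reached.
        p = min(max(progress / at_progress, 0.0), 1.0)
        # Quadratic ease-out: steep ramp early in training, flat near the end.
        return max_scale * (1.0 - (1.0 - p) ** 2)

    return scaler

Whatever its exact shape, the returned callable only has to match the LossUnlScale protocol introduced in mixmatch_module.py below.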
43 changes: 26 additions & 17 deletions mixmatch/models/mixmatch_module.py
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
from copy import deepcopy
from dataclasses import dataclass
-from typing import Callable
+from typing import Callable, Protocol

import numpy as np
import pytorch_lightning as pl
@@ -12,8 +14,13 @@
from torch.nn.functional import one_hot
from torchmetrics.functional import accuracy

-import utils.interleave
-from utils.ema import WeightEMA
+import mixmatch.utils.interleave
+from mixmatch.utils.ema import WeightEMA
+
+
+class LossUnlScale(Protocol):
+    def __call__(self, progress: float) -> float:
+        return progress * 75


# The eq=False is to prevent overriding hash
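
The protocol is structural: any plain function, or object whose __call__ takes a progress float in [0, 1] and returns a float, satisfies it. A sketch (not part of this commit) that recovers the linear ramp removed from training_step, plus an illustrative warm-up variant whose parameter names are assumed:

# Sketch only -- neither of these scalers lives in the repository.
def linear_ramp(progress: float) -> float:
    # Reproduces the behaviour this commit removes: a linear ramp from 0 up to
    # a fixed scale of 75 over the course of training.
    return progress * 75


class WarmupThenConstant:
    """Illustrative scaler: zero until `warmup` progress, then a fixed scale."""

    def __init__(self, scale: float = 75.0, warmup: float = 0.1):
        self.scale = scale
        self.warmup = warmup

    def __call__(self, progress: float) -> float:
        return 0.0 if progress < self.warmup else self.scale
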
@@ -37,38 +44,42 @@ class MixMatchModule(pl.LightningModule):
n_classes: The number of classes in the dataset.
sharpen_temp: The temperature to use for sharpening.
mix_beta_alpha: The alpha to use for the beta distribution when mixing.
-        unl_loss_scale: The scale to use for the unsupervised loss.
+        loss_unl_scaler: Callable mapping training progress to the unsupervised loss scale.
ema_lr: The learning rate to use for the EMA.
lr: The learning rate to use for the optimizer.
weight_decay: The weight decay to use for the optimizer.
"""

model_fn: Callable[[], nn.Module]
+    loss_unl_scaler: LossUnlScale
n_classes: int = 10
sharpen_temp: float = 0.5
mix_beta_alpha: float = 0.75
-    unl_loss_scale: float = 75
ema_lr: float = 0.001
lr: float = 0.002
weight_decay: float = 0.00004

# See our wiki for details on interleave
interleave: bool = False

-    train_lbl_loss: Callable[
-        [torch.Tensor, torch.Tensor], torch.Tensor
-    ] = F.cross_entropy
+    get_loss_lbl: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = F.cross_entropy
# TODO: Not sure why this is different from MSELoss
# It's likely not a big deal, but it's worth investigating if we have
# too much time on our hands
-    train_unl_loss: Callable[
+    get_loss_unl: Callable[
[torch.Tensor, torch.Tensor], torch.Tensor
] = lambda pred, tgt: torch.mean((torch.softmax(pred, dim=1) - tgt) ** 2)

def __post_init__(self):
super().__init__()
self.save_hyperparameters(
ignore=["model_fn", "train_lbl_loss", "train_unl_loss", "model"]
ignore=[
"model_fn",
"get_loss_lbl",
"get_loss_unl",
"loss_unl_scaler",
"model",
]
)
self.model = self.model_fn()
self.ema_model = deepcopy(self.model)
@@ -160,12 +171,12 @@ def training_step(self, batch, batch_idx):
x_mix = list(torch.split(x_mix, batch_size))

# Interleave to get a consistent Batch Norm Calculation
-        x_mix = utils.interleave.interleave(x_mix, batch_size)
+        x_mix = mixmatch.utils.interleave.interleave(x_mix, batch_size)

y_mix_pred = [self(x) for x in x_mix]

# Un-interleave to shuffle back to original order
-        y_mix_pred = utils.interleave.interleave(y_mix_pred, batch_size)
+        y_mix_pred = mixmatch.utils.interleave.interleave(y_mix_pred, batch_size)

y_mix_lbl_pred = y_mix_pred[0]
y_mix_lbl = y_mix[:batch_size]
@@ -179,12 +190,10 @@
y_mix_lbl = y_mix[:batch_size]
y_mix_unl = y_mix[batch_size:]

-        loss_lbl = self.train_lbl_loss(y_mix_lbl_pred, y_mix_lbl)
-        loss_unl = self.train_unl_loss(y_mix_unl_pred, y_mix_unl)
+        loss_lbl = self.get_loss_lbl(y_mix_lbl_pred, y_mix_lbl)
+        loss_unl = self.get_loss_unl(y_mix_unl_pred, y_mix_unl)

-        # The scale is a linear ramp up from 0 to self.unl_loss_scale
-        # over the course of training.
-        loss_unl_scale = self.progress * self.unl_loss_scale
+        loss_unl_scale = self.loss_unl_scaler(progress=self.progress)

loss = loss_lbl + loss_unl * loss_unl_scale

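
Because loss_unl_scaler is now an ordinary constructor argument, changing the schedule means passing a different callable when the module is built. A usage sketch, where the cosine schedule and the stand-in model factory are illustrative and not part of the repository:

import math

import torch.nn as nn

from mixmatch.models.mixmatch_module import MixMatchModule


def cosine_ramp(progress: float) -> float:
    # Illustrative schedule: 0 at the start of training, 100 at the end.
    return 100 * (1 - math.cos(math.pi * progress)) / 2


def build_model() -> nn.Module:
    # Stand-in for the WideResNet factory used in main_pl.py; any zero-argument
    # callable returning an nn.Module fits the model_fn field.
    return nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))


module = MixMatchModule(
    model_fn=build_model,
    loss_unl_scaler=cosine_ramp,  # any LossUnlScale-compatible callable
)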
