Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FRML-81 Implement Stratified Sampling #43

Merged
merged 10 commits into from
Jan 2, 2024
31 changes: 19 additions & 12 deletions src/frdc/train/frdc_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Literal

from lightning import LightningDataModule
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data import DataLoader, RandomSampler, Sampler

from frdc.load.dataset import FRDCDataset, FRDCUnlabelledDataset
from frdc.train.stratified_sampling import RandomStratifiedSampler


@dataclass
Expand Down Expand Up @@ -61,6 +63,7 @@ class FRDCDataModule(LightningDataModule):
batch_size: int = 4
train_iters: int = 100
val_iters: int = 100
sampling_strategy: Literal["stratified", "random"] = "stratified"

def __post_init__(self):
super().__init__()
Expand All @@ -70,24 +73,29 @@ def __post_init__(self):

def train_dataloader(self):
num_samples = self.batch_size * self.train_iters
if self.sampling_strategy == "stratified":
sampler = lambda ds: RandomStratifiedSampler(
ds.targets, num_samples=num_samples, replacement=True
)
elif self.sampling_strategy == "random":
sampler = lambda ds: RandomSampler(
ds, num_samples=num_samples, replacement=True
)
else:
raise ValueError(
f"Invalid sampling strategy: {self.sampling_strategy}"
)

lab_dl = DataLoader(
self.train_lab_ds,
batch_size=self.batch_size,
sampler=RandomSampler(
self.train_lab_ds,
num_samples=num_samples,
replacement=False,
),
sampler=sampler(self.train_lab_ds),
)
unl_dl = (
DataLoader(
self.train_unl_ds,
batch_size=self.batch_size,
sampler=RandomSampler(
self.train_unl_ds,
num_samples=self.batch_size * self.train_iters,
replacement=False,
),
sampler=sampler(self.train_unl_ds),
)
if self.train_unl_ds is not None
# This is a hacky way to create an empty dataloader.
Expand All @@ -99,7 +107,6 @@ def train_dataloader(self):
sampler=RandomSampler(
empty,
num_samples=num_samples,
replacement=False,
),
)
)
Expand Down
68 changes: 64 additions & 4 deletions src/frdc/train/mixmatch_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch.nn.functional as F
import torch.nn.parallel
import torch.nn.parallel
import wandb
from lightning import LightningModule
from sklearn.preprocessing import StandardScaler, OrdinalEncoder
from torch.nn.functional import one_hot
Expand Down Expand Up @@ -52,6 +53,7 @@ def __init__(
self.sharpen_temp = sharpen_temp
self.mix_beta_alpha = mix_beta_alpha
self.save_hyperparameters()
self.lbl_logger = WandBLabelLogger()

@property
@abstractmethod
Expand Down Expand Up @@ -150,10 +152,12 @@ def progress(self):
) / self.trainer.max_epochs

def training_step(self, batch, batch_idx):
# Progress is a linear ramp from 0 to 1 over the course of training.
(x_lbl, y_lbl), x_unls = batch
self.lbl_logger(
self.logger.experiment, "Input Y Label", y_lbl, flush_every=10
)

y_lbl = one_hot(y_lbl.long(), num_classes=self.n_classes)
y_lbl_ohe = one_hot(y_lbl.long(), num_classes=self.n_classes)

# If x_unls is Truthy, then we are using MixMatch.
# Otherwise, we are just using supervised learning.
Expand All @@ -164,7 +168,7 @@ def training_step(self, batch, batch_idx):
y_unl = self.sharpen(y_unl, self.sharpen_temp)

x = torch.cat([x_lbl, *x_unls], dim=0)
y = torch.cat([y_lbl, *(y_unl,) * len(x_unls)], dim=0)
y = torch.cat([y_lbl_ohe, *(y_unl,) * len(x_unls)], dim=0)
x_mix, y_mix = self.mix_up(x, y, self.mix_beta_alpha)

# This had interleaving, but it was removed as it's not
Expand All @@ -177,7 +181,19 @@ def training_step(self, batch, batch_idx):
y_mix_unl = y_mix[batch_size:]

loss_lbl = self.loss_lbl(y_mix_lbl_pred, y_mix_lbl)
self.lbl_logger(
self.logger.experiment,
"Labelled Y Pred",
torch.argmax(y_mix_lbl_pred, dim=1),
flush_every=10,
)
loss_unl = self.loss_unl(y_mix_unl_pred, y_mix_unl)
self.lbl_logger(
self.logger.experiment,
"Unlabelled Y Pred",
torch.argmax(y_mix_unl_pred, dim=1),
flush_every=10,
)
loss_unl_scale = self.loss_unl_scaler(progress=self.progress)

loss = loss_lbl + loss_unl * loss_unl_scale
Expand All @@ -188,7 +204,7 @@ def training_step(self, batch, batch_idx):
else:
# This route implies that we are just using supervised learning
y_pred = self(x_lbl)
loss = self.loss_lbl(y_pred, y_lbl.float())
loss = self.loss_lbl(y_pred, y_lbl_ohe.float())

self.log("train_loss", loss)
return loss
Expand All @@ -201,7 +217,16 @@ def on_after_backward(self) -> None:

def validation_step(self, batch, batch_idx):
x, y = batch
self.lbl_logger(
self.logger.experiment, "Val Input Y Label", y, flush_every=1
)
y_pred = self.ema_model(x)
self.lbl_logger(
self.logger.experiment,
"Val Pred Y Label",
torch.argmax(y_pred, dim=1),
flush_every=1,
)
loss = F.cross_entropy(y_pred, y.long())

acc = accuracy(
Expand Down Expand Up @@ -299,3 +324,38 @@ def y_trans_fn(y):
return (x_lab_trans, y_trans.long()), x_unl_trans
else:
return x_lab_trans, y_trans.long()


class WandBLabelLogger(dict):
"""Logger to log y labels to WandB"""

def __call__(
self,
logger: wandb.sdk.wandb_run.Run,
key: str,
value: torch.Tensor,
flush_every: int = 10,
):
"""Log the labels to WandB

Args:
logger: The W&B logger. Accessible through `self.logger.experiment`
key: The key to log the labels under.
value: The labels to log.
flush_every: How often to flush the labels to WandB.

"""
if key not in self.keys():
self[key] = [value]
else:
self[key].append(value)

if len(self[key]) % flush_every == 0:
logger.log(
{
key: wandb.Histogram(
torch.flatten(value).detach().cpu().tolist()
)
}
)
self[key] = []
65 changes: 65 additions & 0 deletions src/frdc/train/stratified_sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from __future__ import annotations

from typing import Iterator, Any, Sequence

import pandas as pd
import torch
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Sampler


class RandomStratifiedSampler(Sampler[int]):
def __init__(
self,
targets: Sequence[Any],
num_samples: int | None = None,
replacement: bool = True,
) -> None:
"""Stratified sampling from a dataset, such that each class is
sampled with equal probability.

Examples:
Use this with DataLoader to sample from a dataset in a stratified
fashion. For example::

ds = TensorDataset(...)
dl = DataLoader(
ds,
batch_size=...,
sampler=RandomStratifiedSampler(),
)

This will use the targets' frequency as the inverse probability
for sampling. For example, if the targets are [0, 0, 1, 2],
then the probability of sampling the

Args:
targets: The targets to stratify by. Must be integers.
num_samples: The number of samples to draw. If None, the
number of samples is equal to the length of the dataset.
"""
super().__init__()

# Given targets [0, 0, 1]
# bincount = [2, 1]
# 1 / bincount = [0.5, 1]
# 1 / bincount / len(bincount) = [0.25, 0.5]
# The indexing then just projects it to the original targets.
targets_lab = torch.tensor(LabelEncoder().fit_transform(targets))
self.target_probs: torch.Tensor = (
1 / (bincount := torch.bincount(targets_lab)) / len(bincount)
)[targets_lab]

self.num_samples = num_samples if num_samples else len(targets)
self.replacement = replacement

def __len__(self) -> int:
return self.num_samples

def __iter__(self) -> Iterator[int]:
"""This should be a generator that yields indices from the dataset."""
yield from torch.multinomial(
self.target_probs,
num_samples=self.num_samples,
replacement=self.replacement,
)
10 changes: 6 additions & 4 deletions tests/model_tests/chestnut_dec_may/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from frdc.load.preset import FRDCDatasetPreset as ds
from frdc.models.inceptionv3 import InceptionV3MixMatchModule
from frdc.train.frdc_datamodule import FRDCDataModule
from frdc.utils.training import predict, plot_confusion_matrix
from model_tests.utils import (
train_preprocess,
train_unl_preprocess,
Expand All @@ -42,13 +43,12 @@ def main(
):
run = wandb.init()
logger = WandbLogger(name="chestnut_dec_may", project="frdc")

# Prepare the dataset
train_lab_ds = ds.chestnut_20201218(transform=train_preprocess)

train_unl_ds = ds.chestnut_20201218.unlabelled(
transform=train_unl_preprocess(2)
)

val_ds = ds.chestnut_20210510_43m(transform=preprocess)

oe = OrdinalEncoder(
Expand All @@ -64,12 +64,12 @@ def main(
# Prepare the datamodule and trainer
dm = FRDCDataModule(
train_lab_ds=train_lab_ds,
# Pass in None to use the default supervised DM
train_unl_ds=train_unl_ds,
train_unl_ds=train_unl_ds, # None to use supervised DM
val_ds=val_ds,
batch_size=batch_size,
train_iters=train_iters,
val_iters=val_iters,
sampling_strategy="stratified",
)

trainer = pl.Trainer(
Expand All @@ -89,12 +89,14 @@ def main(
],
logger=logger,
)

m = InceptionV3MixMatchModule(
n_classes=n_classes,
lr=lr,
x_scaler=ss,
y_encoder=oe,
)
logger.watch(m)

trainer.fit(m, datamodule=dm)

Expand Down
Empty file.
47 changes: 47 additions & 0 deletions tests/unit_tests/train/test_stratified_sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from __future__ import annotations

import torch
from torch.utils.data import DataLoader, TensorDataset

from frdc.train.stratified_sampling import RandomStratifiedSampler


def test_stratifed_sampling_has_correct_probs():
sampler = RandomStratifiedSampler(["A", "A", "B"])

assert torch.all(sampler.target_probs == torch.tensor([0.25, 0.25, 0.5]))


def test_stratified_sampling_fairly_samples():
"""This test checks that the stratified sampler works with a dataloader."""

# This is a simple example of a dataset with 2 classes.
# The first 2 samples are class 0, the third is class 1.
x = torch.tensor([0, 1, 2])
y = ["A", "A", "B"]

# To check that it's truly stratified, we'll sample 1000 times
# then assert that both classes are sampled roughly equally.

# In this case, the first 2 x should be sampled roughly 250 times,
# and the third x should be sampled roughly 500 times.

num_samples = 1000
batch_size = 10
dl = DataLoader(
TensorDataset(x),
batch_size=batch_size,
sampler=RandomStratifiedSampler(y, num_samples=num_samples),
)

# Note that when we sample from a TensorDataset, we get a tuple of tensors.
# So we need to unpack the tuple.
x_samples = torch.cat([x for (x,) in dl])

assert len(x_samples) == num_samples
assert torch.allclose(
torch.bincount(x_samples),
torch.tensor([250, 250, 500]),
# atol is the absolute tolerance, so the result can differ by 50
atol=50,
)