diff --git a/src/frdc/train/frdc_datamodule.py b/src/frdc/train/frdc_datamodule.py
index cabcb604..5e4e6dbd 100644
--- a/src/frdc/train/frdc_datamodule.py
+++ b/src/frdc/train/frdc_datamodule.py
@@ -1,11 +1,13 @@
 from __future__ import annotations

 from dataclasses import dataclass
+from typing import Literal

 from lightning import LightningDataModule
-from torch.utils.data import DataLoader, RandomSampler
+from torch.utils.data import DataLoader, RandomSampler, Sampler

 from frdc.load.dataset import FRDCDataset, FRDCUnlabelledDataset
+from frdc.train.stratified_sampling import RandomStratifiedSampler


 @dataclass
@@ -61,6 +63,7 @@ class FRDCDataModule(LightningDataModule):
     batch_size: int = 4
     train_iters: int = 100
     val_iters: int = 100
+    sampling_strategy: Literal["stratified", "random"] = "stratified"

     def __post_init__(self):
         super().__init__()
@@ -70,24 +73,29 @@ def __post_init__(self):

     def train_dataloader(self):
         num_samples = self.batch_size * self.train_iters
+        if self.sampling_strategy == "stratified":
+            sampler = lambda ds: RandomStratifiedSampler(
+                ds.targets, num_samples=num_samples, replacement=True
+            )
+        elif self.sampling_strategy == "random":
+            sampler = lambda ds: RandomSampler(
+                ds, num_samples=num_samples, replacement=True
+            )
+        else:
+            raise ValueError(
+                f"Invalid sampling strategy: {self.sampling_strategy}"
+            )
+
         lab_dl = DataLoader(
             self.train_lab_ds,
             batch_size=self.batch_size,
-            sampler=RandomSampler(
-                self.train_lab_ds,
-                num_samples=num_samples,
-                replacement=False,
-            ),
+            sampler=sampler(self.train_lab_ds),
         )
         unl_dl = (
             DataLoader(
                 self.train_unl_ds,
                 batch_size=self.batch_size,
-                sampler=RandomSampler(
-                    self.train_unl_ds,
-                    num_samples=self.batch_size * self.train_iters,
-                    replacement=False,
-                ),
+                sampler=sampler(self.train_unl_ds),
             )
             if self.train_unl_ds is not None
             # This is a hacky way to create an empty dataloader.
@@ -99,7 +107,6 @@ def train_dataloader(self):
                 sampler=RandomSampler(
                     empty,
                     num_samples=num_samples,
-                    replacement=False,
                 ),
             )
         )
diff --git a/src/frdc/train/mixmatch_module.py b/src/frdc/train/mixmatch_module.py
index 784380b6..75e581e8 100644
--- a/src/frdc/train/mixmatch_module.py
+++ b/src/frdc/train/mixmatch_module.py
@@ -8,6 +8,7 @@
 import torch.nn.functional as F
 import torch.nn.parallel
 import torch.nn.parallel
+import wandb
 from lightning import LightningModule
 from sklearn.preprocessing import StandardScaler, OrdinalEncoder
 from torch.nn.functional import one_hot
@@ -52,6 +53,7 @@ def __init__(
         self.sharpen_temp = sharpen_temp
         self.mix_beta_alpha = mix_beta_alpha
         self.save_hyperparameters()
+        self.lbl_logger = WandBLabelLogger()

     @property
     @abstractmethod
@@ -150,10 +152,12 @@ def progress(self):
         ) / self.trainer.max_epochs

     def training_step(self, batch, batch_idx):
-        # Progress is a linear ramp from 0 to 1 over the course of training.
         (x_lbl, y_lbl), x_unls = batch
+        self.lbl_logger(
+            self.logger.experiment, "Input Y Label", y_lbl, flush_every=10
+        )

-        y_lbl = one_hot(y_lbl.long(), num_classes=self.n_classes)
+        y_lbl_ohe = one_hot(y_lbl.long(), num_classes=self.n_classes)

         # If x_unls is Truthy, then we are using MixMatch.
         # Otherwise, we are just using supervised learning.
@@ -164,7 +168,7 @@ def training_step(self, batch, batch_idx):
                 y_unl = self.sharpen(y_unl, self.sharpen_temp)

             x = torch.cat([x_lbl, *x_unls], dim=0)
-            y = torch.cat([y_lbl, *(y_unl,) * len(x_unls)], dim=0)
+            y = torch.cat([y_lbl_ohe, *(y_unl,) * len(x_unls)], dim=0)
             x_mix, y_mix = self.mix_up(x, y, self.mix_beta_alpha)

             # This had interleaving, but it was removed as it's not
@@ -177,7 +181,19 @@ def training_step(self, batch, batch_idx):
             y_mix_unl = y_mix[batch_size:]

             loss_lbl = self.loss_lbl(y_mix_lbl_pred, y_mix_lbl)
+            self.lbl_logger(
+                self.logger.experiment,
+                "Labelled Y Pred",
+                torch.argmax(y_mix_lbl_pred, dim=1),
+                flush_every=10,
+            )
             loss_unl = self.loss_unl(y_mix_unl_pred, y_mix_unl)
+            self.lbl_logger(
+                self.logger.experiment,
+                "Unlabelled Y Pred",
+                torch.argmax(y_mix_unl_pred, dim=1),
+                flush_every=10,
+            )
             loss_unl_scale = self.loss_unl_scaler(progress=self.progress)

             loss = loss_lbl + loss_unl * loss_unl_scale
@@ -188,7 +204,7 @@ def training_step(self, batch, batch_idx):
         else:
             # This route implies that we are just using supervised learning
             y_pred = self(x_lbl)
-            loss = self.loss_lbl(y_pred, y_lbl.float())
+            loss = self.loss_lbl(y_pred, y_lbl_ohe.float())

         self.log("train_loss", loss)
         return loss
@@ -201,7 +217,16 @@ def on_after_backward(self) -> None:

     def validation_step(self, batch, batch_idx):
         x, y = batch
+        self.lbl_logger(
+            self.logger.experiment, "Val Input Y Label", y, flush_every=1
+        )
         y_pred = self.ema_model(x)
+        self.lbl_logger(
+            self.logger.experiment,
+            "Val Pred Y Label",
+            torch.argmax(y_pred, dim=1),
+            flush_every=1,
+        )
         loss = F.cross_entropy(y_pred, y.long())

         acc = accuracy(
@@ -299,3 +324,38 @@ def y_trans_fn(y):
             return (x_lab_trans, y_trans.long()), x_unl_trans
         else:
             return x_lab_trans, y_trans.long()
+
+
+class WandBLabelLogger(dict):
+    """Logger to log y labels to WandB"""
+
+    def __call__(
+        self,
+        logger: wandb.sdk.wandb_run.Run,
+        key: str,
+        value: torch.Tensor,
+        flush_every: int = 10,
+    ):
+        """Log the labels to WandB
+
+        Args:
+            logger: The W&B logger. Accessible through `self.logger.experiment`
+            key: The key to log the labels under.
+            value: The labels to log.
+            flush_every: How often to flush the labels to WandB.
+
+        """
+        if key not in self.keys():
+            self[key] = [value]
+        else:
+            self[key].append(value)
+
+        if len(self[key]) % flush_every == 0:
+            logger.log(
+                {
+                    key: wandb.Histogram(
+                        torch.flatten(value).detach().cpu().tolist()
+                    )
+                }
+            )
+            self[key] = []
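A minimal usage sketch of the WandBLabelLogger added above (not part of the diff; the offline W&B mode and the random labels are assumptions for illustration). The logger buffers labels per key and, on every `flush_every`-th call, logs a `wandb.Histogram` of that call's labels and clears the buffer.

    import torch
    import wandb

    from frdc.train.mixmatch_module import WandBLabelLogger

    run = wandb.init(project="frdc", mode="offline")  # offline run, no account needed
    lbl_logger = WandBLabelLogger()

    for step in range(20):
        y_lbl = torch.randint(0, 10, (4,))
        # Calls 10 and 20 each log a histogram of that call's labels.
        lbl_logger(run, "Input Y Label", y_lbl, flush_every=10)

    run.finish()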
diff --git a/src/frdc/train/stratified_sampling.py b/src/frdc/train/stratified_sampling.py
new file mode 100644
index 00000000..dd17762c
--- /dev/null
+++ b/src/frdc/train/stratified_sampling.py
@@ -0,0 +1,65 @@
+from __future__ import annotations
+
+from typing import Iterator, Any, Sequence
+
+import pandas as pd
+import torch
+from sklearn.preprocessing import LabelEncoder
+from torch.utils.data import Sampler
+
+
+class RandomStratifiedSampler(Sampler[int]):
+    def __init__(
+        self,
+        targets: Sequence[Any],
+        num_samples: int | None = None,
+        replacement: bool = True,
+    ) -> None:
+        """Stratified sampling from a dataset, such that each class is
+        sampled with equal probability.
+
+        Examples:
+            Use this with DataLoader to sample from a dataset in a stratified
+            fashion. For example::
+
+                ds = TensorDataset(...)
+                dl = DataLoader(
+                    ds,
+                    batch_size=...,
+                    sampler=RandomStratifiedSampler(targets=...),
+                )
+
+            Each index is sampled with probability proportional to the inverse
+            of its target's frequency. For example, if the targets are
+            [0, 0, 1, 2], the per-index probabilities are [1/6, 1/6, 1/3, 1/3].
+
+        Args:
+            targets: The targets to stratify by; label-encoded internally.
+            num_samples: The number of samples to draw. If None, the
+                number of samples is equal to the length of the dataset.
+        """
+        super().__init__()
+
+        # Given targets [0, 0, 1]
+        # bincount = [2, 1]
+        # 1 / bincount = [0.5, 1]
+        # 1 / bincount / len(bincount) = [0.25, 0.5]
+        # The indexing then just projects it to the original targets.
+        targets_lab = torch.tensor(LabelEncoder().fit_transform(targets))
+        self.target_probs: torch.Tensor = (
+            1 / (bincount := torch.bincount(targets_lab)) / len(bincount)
+        )[targets_lab]
+
+        self.num_samples = num_samples if num_samples else len(targets)
+        self.replacement = replacement
+
+    def __len__(self) -> int:
+        return self.num_samples
+
+    def __iter__(self) -> Iterator[int]:
+        """This should be a generator that yields indices from the dataset."""
+        yield from torch.multinomial(
+            self.target_probs,
+            num_samples=self.num_samples,
+            replacement=self.replacement,
+        )
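A minimal, self-contained sketch of the sampler defined above used inside a DataLoader (not part of the diff; the toy tensors and sizes are illustrative). The only thing the caller needs is a sequence of targets, which is what FRDCDataModule passes via `ds.targets`.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from frdc.train.stratified_sampling import RandomStratifiedSampler

    # Toy stand-in for a labelled dataset: 3 samples with targets ["A", "A", "B"].
    xs = torch.arange(3).float().unsqueeze(1)
    targets = ["A", "A", "B"]
    ds = TensorDataset(xs)

    # Roughly what sampler(self.train_lab_ds) resolves to when
    # sampling_strategy="stratified": class-balanced draws with replacement.
    dl = DataLoader(
        ds,
        batch_size=4,
        sampler=RandomStratifiedSampler(targets, num_samples=8, replacement=True),
    )
    for (batch,) in dl:
        print(batch.shape)  # Two batches of shape (4, 1).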
diff --git a/tests/model_tests/chestnut_dec_may/train.py b/tests/model_tests/chestnut_dec_may/train.py
index 70f3dada..863a9476 100644
--- a/tests/model_tests/chestnut_dec_may/train.py
+++ b/tests/model_tests/chestnut_dec_may/train.py
@@ -25,6 +25,7 @@
 from frdc.load.preset import FRDCDatasetPreset as ds
 from frdc.models.inceptionv3 import InceptionV3MixMatchModule
 from frdc.train.frdc_datamodule import FRDCDataModule
+from frdc.utils.training import predict, plot_confusion_matrix
 from model_tests.utils import (
     train_preprocess,
     train_unl_preprocess,
@@ -42,13 +43,12 @@ def main(
 ):
     run = wandb.init()
     logger = WandbLogger(name="chestnut_dec_may", project="frdc")

+    # Prepare the dataset
     train_lab_ds = ds.chestnut_20201218(transform=train_preprocess)
-
     train_unl_ds = ds.chestnut_20201218.unlabelled(
         transform=train_unl_preprocess(2)
     )
-
     val_ds = ds.chestnut_20210510_43m(transform=preprocess)

     oe = OrdinalEncoder(
@@ -64,12 +64,12 @@ def main(
     # Prepare the datamodule and trainer
     dm = FRDCDataModule(
         train_lab_ds=train_lab_ds,
-        # Pass in None to use the default supervised DM
-        train_unl_ds=train_unl_ds,
+        train_unl_ds=train_unl_ds,  # None to use supervised DM
         val_ds=val_ds,
         batch_size=batch_size,
         train_iters=train_iters,
         val_iters=val_iters,
+        sampling_strategy="stratified",
     )

     trainer = pl.Trainer(
@@ -89,12 +89,14 @@ def main(
         ],
         logger=logger,
     )
+
     m = InceptionV3MixMatchModule(
         n_classes=n_classes,
         lr=lr,
         x_scaler=ss,
         y_encoder=oe,
     )
+    logger.watch(m)

     trainer.fit(m, datamodule=dm)

diff --git a/tests/unit_tests/train/__init__.py b/tests/unit_tests/train/__init__.py
new file mode 100644
index 00000000..e69de29b
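For comparison with the training script above, a hedged sketch of a purely supervised, uniformly sampled configuration of the same datamodule. It assumes `preprocess` is importable from `model_tests.utils` alongside the transforms already imported in train.py; the hyperparameter values are illustrative.

    from frdc.load.preset import FRDCDatasetPreset as ds
    from frdc.train.frdc_datamodule import FRDCDataModule
    from model_tests.utils import preprocess, train_preprocess

    dm = FRDCDataModule(
        train_lab_ds=ds.chestnut_20201218(transform=train_preprocess),
        train_unl_ds=None,  # None to use the supervised DM
        val_ds=ds.chestnut_20210510_43m(transform=preprocess),
        batch_size=4,
        train_iters=100,
        val_iters=100,
        sampling_strategy="random",  # falls back to torch's RandomSampler
    )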
diff --git a/tests/unit_tests/train/test_stratified_sampling.py b/tests/unit_tests/train/test_stratified_sampling.py
new file mode 100644
index 00000000..e8019b64
--- /dev/null
+++ b/tests/unit_tests/train/test_stratified_sampling.py
@@ -0,0 +1,47 @@
+from __future__ import annotations
+
+import torch
+from torch.utils.data import DataLoader, TensorDataset
+
+from frdc.train.stratified_sampling import RandomStratifiedSampler
+
+
+def test_stratified_sampling_has_correct_probs():
+    sampler = RandomStratifiedSampler(["A", "A", "B"])
+
+    assert torch.all(sampler.target_probs == torch.tensor([0.25, 0.25, 0.5]))
+
+
+def test_stratified_sampling_fairly_samples():
+    """This test checks that the stratified sampler works with a dataloader."""
+
+    # This is a simple example of a dataset with 2 classes.
+    # The first 2 samples are class "A", the third is class "B".
+    x = torch.tensor([0, 1, 2])
+    y = ["A", "A", "B"]
+
+    # To check that it's truly stratified, we'll sample 1000 times,
+    # then assert that both classes are sampled roughly equally.
+
+    # In this case, the first 2 x should each be sampled roughly 250 times,
+    # and the third x should be sampled roughly 500 times.
+
+    num_samples = 1000
+    batch_size = 10
+    dl = DataLoader(
+        TensorDataset(x),
+        batch_size=batch_size,
+        sampler=RandomStratifiedSampler(y, num_samples=num_samples),
+    )
+
+    # Note that when we sample from a TensorDataset, we get a tuple of tensors,
+    # so we need to unpack it.
+    x_samples = torch.cat([x for (x,) in dl])
+
+    assert len(x_samples) == num_samples
+    assert torch.allclose(
+        torch.bincount(x_samples),
+        torch.tensor([250, 250, 500]),
+        # atol is the absolute tolerance, so each count can differ by up to 50
+        atol=50,
+    )
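As a quick cross-check of the docstring example in stratified_sampling.py, a standalone sketch (not part of the test suite) that recomputes the expected per-index probabilities:

    import torch

    from frdc.train.stratified_sampling import RandomStratifiedSampler

    # Targets [0, 0, 1, 2]: class 0 appears twice, classes 1 and 2 once each.
    sampler = RandomStratifiedSampler([0, 0, 1, 2])

    # 1 / bincount / n_classes = [1/2, 1, 1] / 3, indexed back onto the targets.
    expected = torch.tensor([1 / 6, 1 / 6, 1 / 3, 1 / 3])
    assert torch.allclose(sampler.target_probs, expected)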