From 3289d49491bb24497112abaddcb3f75b832688a3 Mon Sep 17 00:00:00 2001 From: Evening Date: Fri, 29 Dec 2023 16:03:03 +0800 Subject: [PATCH 1/9] Implement Stratified Sampling --- src/frdc/train/stratified_sampling.py | 60 +++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 src/frdc/train/stratified_sampling.py diff --git a/src/frdc/train/stratified_sampling.py b/src/frdc/train/stratified_sampling.py new file mode 100644 index 00000000..4d1e96cd --- /dev/null +++ b/src/frdc/train/stratified_sampling.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +from typing import Iterator + +import torch +from torch.utils.data import Sampler + + +class RandomStratifiedSampler(Sampler[int]): + def __init__( + self, + targets: torch.Tensor, + num_samples: int | None = None, + ) -> None: + """Stratified sampling from a dataset, such that each class is + sampled with equal probability. + + Examples: + Use this with DataLoader to sample from a dataset in a stratified + fashion. For example:: + + ds = TensorDataset(...) + dl = DataLoader( + ds, + batch_size=..., + sampler=RandomStratifiedSampler(), + ) + + This will use the targets' frequency as the inverse probability + for sampling. For example, if the targets are [0, 0, 1, 2], + then the probability of sampling the + + Args: + targets: The targets to stratify by. Must be integers. + num_samples: The number of samples to draw. If None, the + number of samples is equal to the length of the dataset. + """ + super().__init__() + + # Given targets [0, 0, 1] + # bincount = [2, 1] + # 1 / bincount = [0.5, 1] + # 1 / bincount / len(bincount) = [0.25, 0.5] + # The indexing then just projects it to the original targets. + self.target_probs: torch.Tensor = ( + 1 / (bincount := torch.bincount(targets)) / len(bincount) + )[targets] + + self.num_samples = num_samples if num_samples else len(targets) + + def __len__(self) -> int: + return self.num_samples + + def __iter__(self) -> Iterator[int]: + """This should be a generator that yields indices from the dataset.""" + yield from torch.multinomial( + self.target_probs, + num_samples=self.num_samples, + replacement=True, + ) From fdfa17a99b8d384c32f0d0793532640fa3783af2 Mon Sep 17 00:00:00 2001 From: Evening Date: Fri, 29 Dec 2023 16:03:09 +0800 Subject: [PATCH 2/9] Add test for Stratified Sampling --- tests/unit_tests/train/__init__.py | 0 .../train/test_stratified_sampling.py | 47 +++++++++++++++++++ 2 files changed, 47 insertions(+) create mode 100644 tests/unit_tests/train/__init__.py create mode 100644 tests/unit_tests/train/test_stratified_sampling.py diff --git a/tests/unit_tests/train/__init__.py b/tests/unit_tests/train/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit_tests/train/test_stratified_sampling.py b/tests/unit_tests/train/test_stratified_sampling.py new file mode 100644 index 00000000..78eecb7b --- /dev/null +++ b/tests/unit_tests/train/test_stratified_sampling.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +import torch +from torch.utils.data import DataLoader, TensorDataset + +from frdc.train.stratified_sampling import RandomStratifiedSampler + + +def test_stratifed_sampling_has_correct_probs(): + sampler = RandomStratifiedSampler(torch.tensor([0, 0, 1])) + + assert torch.all(sampler.target_probs == torch.tensor([0.25, 0.25, 0.5])) + + +def test_stratified_sampling_fairly_samples(): + """This test checks that the stratified sampler works with a dataloader.""" + + # This is a simple example of a dataset with 2 classes. 
+ # The first 2 samples are class 0, the third is class 1. + x = torch.tensor([0, 1, 2]) + y = torch.tensor([0, 0, 1]) + + # To check that it's truly stratified, we'll sample 1000 times + # then assert that both classes are sampled roughly equally. + + # In this case, the first 2 x should be sampled roughly 250 times, + # and the third x should be sampled roughly 500 times. + + num_samples = 1000 + batch_size = 10 + dl = DataLoader( + TensorDataset(x), + batch_size=batch_size, + sampler=RandomStratifiedSampler(y, num_samples=num_samples), + ) + + # Note that when we sample from a TensorDataset, we get a tuple of tensors. + # So we need to unpack the tuple. + x_samples = torch.cat([x for (x,) in dl]) + + assert len(x_samples) == num_samples + assert torch.allclose( + torch.bincount(x_samples), + torch.tensor([250, 250, 500]), + # atol is the absolute tolerance, so the result can differ by 50 + atol=50, + ) From 349e7cd3051924899351a03edc7ec544e3f9aa90 Mon Sep 17 00:00:00 2001 From: Evening Date: Tue, 2 Jan 2024 11:36:30 +0800 Subject: [PATCH 3/9] Implement Stratified Sampling on DM --- src/frdc/train/frdc_datamodule.py | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/src/frdc/train/frdc_datamodule.py b/src/frdc/train/frdc_datamodule.py index cabcb604..5e4e6dbd 100644 --- a/src/frdc/train/frdc_datamodule.py +++ b/src/frdc/train/frdc_datamodule.py @@ -1,11 +1,13 @@ from __future__ import annotations from dataclasses import dataclass +from typing import Literal from lightning import LightningDataModule -from torch.utils.data import DataLoader, RandomSampler +from torch.utils.data import DataLoader, RandomSampler, Sampler from frdc.load.dataset import FRDCDataset, FRDCUnlabelledDataset +from frdc.train.stratified_sampling import RandomStratifiedSampler @dataclass @@ -61,6 +63,7 @@ class FRDCDataModule(LightningDataModule): batch_size: int = 4 train_iters: int = 100 val_iters: int = 100 + sampling_strategy: Literal["stratified", "random"] = "stratified" def __post_init__(self): super().__init__() @@ -70,24 +73,29 @@ def __post_init__(self): def train_dataloader(self): num_samples = self.batch_size * self.train_iters + if self.sampling_strategy == "stratified": + sampler = lambda ds: RandomStratifiedSampler( + ds.targets, num_samples=num_samples, replacement=True + ) + elif self.sampling_strategy == "random": + sampler = lambda ds: RandomSampler( + ds, num_samples=num_samples, replacement=True + ) + else: + raise ValueError( + f"Invalid sampling strategy: {self.sampling_strategy}" + ) + lab_dl = DataLoader( self.train_lab_ds, batch_size=self.batch_size, - sampler=RandomSampler( - self.train_lab_ds, - num_samples=num_samples, - replacement=False, - ), + sampler=sampler(self.train_lab_ds), ) unl_dl = ( DataLoader( self.train_unl_ds, batch_size=self.batch_size, - sampler=RandomSampler( - self.train_unl_ds, - num_samples=self.batch_size * self.train_iters, - replacement=False, - ), + sampler=sampler(self.train_unl_ds), ) if self.train_unl_ds is not None # This is a hacky way to create an empty dataloader. 
@@ -99,7 +107,6 @@ def train_dataloader(self): sampler=RandomSampler( empty, num_samples=num_samples, - replacement=False, ), ) ) From dc05b35ab13f9bde505efc1fd450dd2b9aa421f2 Mon Sep 17 00:00:00 2001 From: Evening Date: Tue, 2 Jan 2024 11:42:16 +0800 Subject: [PATCH 4/9] Allow Stratified Sampling for arbitrary seq types --- src/frdc/train/stratified_sampling.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/frdc/train/stratified_sampling.py b/src/frdc/train/stratified_sampling.py index 4d1e96cd..dd17762c 100644 --- a/src/frdc/train/stratified_sampling.py +++ b/src/frdc/train/stratified_sampling.py @@ -1,16 +1,19 @@ from __future__ import annotations -from typing import Iterator +from typing import Iterator, Any, Sequence +import pandas as pd import torch +from sklearn.preprocessing import LabelEncoder from torch.utils.data import Sampler class RandomStratifiedSampler(Sampler[int]): def __init__( self, - targets: torch.Tensor, + targets: Sequence[Any], num_samples: int | None = None, + replacement: bool = True, ) -> None: """Stratified sampling from a dataset, such that each class is sampled with equal probability. @@ -42,11 +45,13 @@ def __init__( # 1 / bincount = [0.5, 1] # 1 / bincount / len(bincount) = [0.25, 0.5] # The indexing then just projects it to the original targets. + targets_lab = torch.tensor(LabelEncoder().fit_transform(targets)) self.target_probs: torch.Tensor = ( - 1 / (bincount := torch.bincount(targets)) / len(bincount) - )[targets] + 1 / (bincount := torch.bincount(targets_lab)) / len(bincount) + )[targets_lab] self.num_samples = num_samples if num_samples else len(targets) + self.replacement = replacement def __len__(self) -> int: return self.num_samples @@ -56,5 +61,5 @@ def __iter__(self) -> Iterator[int]: yield from torch.multinomial( self.target_probs, num_samples=self.num_samples, - replacement=True, + replacement=self.replacement, ) From a8dcafcf9acbad694fb1b1b40395d6b5ac482f9b Mon Sep 17 00:00:00 2001 From: Evening Date: Tue, 2 Jan 2024 11:42:35 +0800 Subject: [PATCH 5/9] Fix missing imports for pred and plot --- tests/model_tests/chestnut_dec_may/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/model_tests/chestnut_dec_may/train.py b/tests/model_tests/chestnut_dec_may/train.py index 70f3dada..c37404e8 100644 --- a/tests/model_tests/chestnut_dec_may/train.py +++ b/tests/model_tests/chestnut_dec_may/train.py @@ -25,6 +25,7 @@ from frdc.load.preset import FRDCDatasetPreset as ds from frdc.models.inceptionv3 import InceptionV3MixMatchModule from frdc.train.frdc_datamodule import FRDCDataModule +from frdc.utils.training import predict, plot_confusion_matrix from model_tests.utils import ( train_preprocess, train_unl_preprocess, From e6f6a9c6ea984ebe1658d48380168c29903cb0bb Mon Sep 17 00:00:00 2001 From: Evening Date: Tue, 2 Jan 2024 11:42:53 +0800 Subject: [PATCH 6/9] Change test to use str list --- tests/unit_tests/train/test_stratified_sampling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit_tests/train/test_stratified_sampling.py b/tests/unit_tests/train/test_stratified_sampling.py index 78eecb7b..e8019b64 100644 --- a/tests/unit_tests/train/test_stratified_sampling.py +++ b/tests/unit_tests/train/test_stratified_sampling.py @@ -7,7 +7,7 @@ def test_stratifed_sampling_has_correct_probs(): - sampler = RandomStratifiedSampler(torch.tensor([0, 0, 1])) + sampler = RandomStratifiedSampler(["A", "A", "B"]) assert torch.all(sampler.target_probs == torch.tensor([0.25, 0.25, 
0.5])) @@ -18,7 +18,7 @@ def test_stratified_sampling_fairly_samples(): # This is a simple example of a dataset with 2 classes. # The first 2 samples are class 0, the third is class 1. x = torch.tensor([0, 1, 2]) - y = torch.tensor([0, 0, 1]) + y = ["A", "A", "B"] # To check that it's truly stratified, we'll sample 1000 times # then assert that both classes are sampled roughly equally. From 86d11df7a27d085496e6c4f7666db74c4f96da74 Mon Sep 17 00:00:00 2001 From: Evening Date: Tue, 2 Jan 2024 14:44:40 +0800 Subject: [PATCH 7/9] Implement W&B vis of label spread --- src/frdc/train/mixmatch_module.py | 68 +++++++++++++++++++++++++++++-- 1 file changed, 64 insertions(+), 4 deletions(-) diff --git a/src/frdc/train/mixmatch_module.py b/src/frdc/train/mixmatch_module.py index 784380b6..75e581e8 100644 --- a/src/frdc/train/mixmatch_module.py +++ b/src/frdc/train/mixmatch_module.py @@ -8,6 +8,7 @@ import torch.nn.functional as F import torch.nn.parallel import torch.nn.parallel +import wandb from lightning import LightningModule from sklearn.preprocessing import StandardScaler, OrdinalEncoder from torch.nn.functional import one_hot @@ -52,6 +53,7 @@ def __init__( self.sharpen_temp = sharpen_temp self.mix_beta_alpha = mix_beta_alpha self.save_hyperparameters() + self.lbl_logger = WandBLabelLogger() @property @abstractmethod @@ -150,10 +152,12 @@ def progress(self): ) / self.trainer.max_epochs def training_step(self, batch, batch_idx): - # Progress is a linear ramp from 0 to 1 over the course of training. (x_lbl, y_lbl), x_unls = batch + self.lbl_logger( + self.logger.experiment, "Input Y Label", y_lbl, flush_every=10 + ) - y_lbl = one_hot(y_lbl.long(), num_classes=self.n_classes) + y_lbl_ohe = one_hot(y_lbl.long(), num_classes=self.n_classes) # If x_unls is Truthy, then we are using MixMatch. # Otherwise, we are just using supervised learning. 
@@ -164,7 +168,7 @@ def training_step(self, batch, batch_idx): y_unl = self.sharpen(y_unl, self.sharpen_temp) x = torch.cat([x_lbl, *x_unls], dim=0) - y = torch.cat([y_lbl, *(y_unl,) * len(x_unls)], dim=0) + y = torch.cat([y_lbl_ohe, *(y_unl,) * len(x_unls)], dim=0) x_mix, y_mix = self.mix_up(x, y, self.mix_beta_alpha) # This had interleaving, but it was removed as it's not @@ -177,7 +181,19 @@ def training_step(self, batch, batch_idx): y_mix_unl = y_mix[batch_size:] loss_lbl = self.loss_lbl(y_mix_lbl_pred, y_mix_lbl) + self.lbl_logger( + self.logger.experiment, + "Labelled Y Pred", + torch.argmax(y_mix_lbl_pred, dim=1), + flush_every=10, + ) loss_unl = self.loss_unl(y_mix_unl_pred, y_mix_unl) + self.lbl_logger( + self.logger.experiment, + "Unlabelled Y Pred", + torch.argmax(y_mix_unl_pred, dim=1), + flush_every=10, + ) loss_unl_scale = self.loss_unl_scaler(progress=self.progress) loss = loss_lbl + loss_unl * loss_unl_scale @@ -188,7 +204,7 @@ def training_step(self, batch, batch_idx): else: # This route implies that we are just using supervised learning y_pred = self(x_lbl) - loss = self.loss_lbl(y_pred, y_lbl.float()) + loss = self.loss_lbl(y_pred, y_lbl_ohe.float()) self.log("train_loss", loss) return loss @@ -201,7 +217,16 @@ def on_after_backward(self) -> None: def validation_step(self, batch, batch_idx): x, y = batch + self.lbl_logger( + self.logger.experiment, "Val Input Y Label", y, flush_every=1 + ) y_pred = self.ema_model(x) + self.lbl_logger( + self.logger.experiment, + "Val Pred Y Label", + torch.argmax(y_pred, dim=1), + flush_every=1, + ) loss = F.cross_entropy(y_pred, y.long()) acc = accuracy( @@ -299,3 +324,38 @@ def y_trans_fn(y): return (x_lab_trans, y_trans.long()), x_unl_trans else: return x_lab_trans, y_trans.long() + + +class WandBLabelLogger(dict): + """Logger to log y labels to WandB""" + + def __call__( + self, + logger: wandb.sdk.wandb_run.Run, + key: str, + value: torch.Tensor, + flush_every: int = 10, + ): + """Log the labels to WandB + + Args: + logger: The W&B logger. Accessible through `self.logger.experiment` + key: The key to log the labels under. + value: The labels to log. + flush_every: How often to flush the labels to WandB. 
+ + """ + if key not in self.keys(): + self[key] = [value] + else: + self[key].append(value) + + if len(self[key]) % flush_every == 0: + logger.log( + { + key: wandb.Histogram( + torch.flatten(value).detach().cpu().tolist() + ) + } + ) + self[key] = [] From a355c39843f0cb835c7a1d9034346f479b191488 Mon Sep 17 00:00:00 2001 From: Evening Date: Tue, 2 Jan 2024 14:45:01 +0800 Subject: [PATCH 8/9] Clean up train.py --- tests/model_tests/chestnut_dec_may/train.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/model_tests/chestnut_dec_may/train.py b/tests/model_tests/chestnut_dec_may/train.py index c37404e8..28e79a92 100644 --- a/tests/model_tests/chestnut_dec_may/train.py +++ b/tests/model_tests/chestnut_dec_may/train.py @@ -43,13 +43,12 @@ def main( ): run = wandb.init() logger = WandbLogger(name="chestnut_dec_may", project="frdc") + # Prepare the dataset train_lab_ds = ds.chestnut_20201218(transform=train_preprocess) - train_unl_ds = ds.chestnut_20201218.unlabelled( transform=train_unl_preprocess(2) ) - val_ds = ds.chestnut_20210510_43m(transform=preprocess) oe = OrdinalEncoder( @@ -65,12 +64,12 @@ def main( # Prepare the datamodule and trainer dm = FRDCDataModule( train_lab_ds=train_lab_ds, - # Pass in None to use the default supervised DM - train_unl_ds=train_unl_ds, + train_unl_ds=train_unl_ds, # None to use supervised DM val_ds=val_ds, batch_size=batch_size, train_iters=train_iters, val_iters=val_iters, + sampling_strategy="stratified", ) trainer = pl.Trainer( @@ -90,6 +89,7 @@ def main( ], logger=logger, ) + m = InceptionV3MixMatchModule( n_classes=n_classes, lr=lr, From dff83781b1461a91c524b5df5ef2a086720b3f8e Mon Sep 17 00:00:00 2001 From: Evening Date: Tue, 2 Jan 2024 14:45:10 +0800 Subject: [PATCH 9/9] Make W&B Watch model --- tests/model_tests/chestnut_dec_may/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/model_tests/chestnut_dec_may/train.py b/tests/model_tests/chestnut_dec_may/train.py index 28e79a92..863a9476 100644 --- a/tests/model_tests/chestnut_dec_may/train.py +++ b/tests/model_tests/chestnut_dec_may/train.py @@ -96,6 +96,7 @@ def main( x_scaler=ss, y_encoder=oe, ) + logger.watch(m) trainer.fit(m, datamodule=dm)
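
Usage sketch (not part of the patches above): a minimal example of the RandomStratifiedSampler added in PATCH 1/9 and generalised to arbitrary label sequences in PATCH 4/9, assuming the frdc package from this series is importable. The docstring's example works out as follows: for targets [0, 0, 1, 2], bincount = [2, 1, 1], so the per-sample probabilities are [1/6, 1/6, 1/3, 1/3] and every class is drawn with equal total probability (1/3). The sketch below checks the same property on a two-class toy dataset::

    from __future__ import annotations

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from frdc.train.stratified_sampling import RandomStratifiedSampler

    # Two classes: "A" appears twice, "B" once (string labels are fine
    # after PATCH 4/9, which label-encodes the targets internally).
    x = torch.tensor([10, 20, 30])
    y = ["A", "A", "B"]

    # bincount = [2, 1] -> 1 / bincount / len(bincount) = [0.25, 0.5],
    # projected back onto the targets as [0.25, 0.25, 0.5].
    sampler = RandomStratifiedSampler(y, num_samples=1000)
    print(sampler.target_probs)  # tensor([0.2500, 0.2500, 0.5000])

    dl = DataLoader(TensorDataset(x), batch_size=10, sampler=sampler)
    drawn = torch.cat([batch for (batch,) in dl])

    # Each class should be drawn ~500 times in total: ~250 for each of
    # the two "A" samples and ~500 for the single "B" sample.
    print(torch.bincount(drawn)[[10, 20, 30]])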
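
Similarly, a minimal sketch of the label-spread logging that PATCH 7/9 wires into the MixMatch module through WandBLabelLogger: the label tensor is flattened, moved to CPU, and logged as a W&B histogram under a named key. The project and run names below are placeholders, not values from the patches::

    import torch
    import wandb

    run = wandb.init(project="frdc", name="label-spread-demo")

    # Same call pattern as WandBLabelLogger.__call__ uses when it flushes:
    # flatten the labels and log them as a histogram under a fixed key.
    y_lbl = torch.tensor([0, 0, 1, 2, 2, 2])
    run.log(
        {"Input Y Label": wandb.Histogram(torch.flatten(y_lbl).detach().cpu().tolist())}
    )
    run.finish()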