From 86d11df7a27d085496e6c4f7666db74c4f96da74 Mon Sep 17 00:00:00 2001
From: Evening
Date: Tue, 2 Jan 2024 14:44:40 +0800
Subject: [PATCH 1/3] Implement W&B vis of label spread

---
 src/frdc/train/mixmatch_module.py | 68 +++++++++++++++++++++++++++++--
 1 file changed, 64 insertions(+), 4 deletions(-)

diff --git a/src/frdc/train/mixmatch_module.py b/src/frdc/train/mixmatch_module.py
index 784380b6..75e581e8 100644
--- a/src/frdc/train/mixmatch_module.py
+++ b/src/frdc/train/mixmatch_module.py
@@ -8,6 +8,7 @@
 import torch.nn.functional as F
 import torch.nn.parallel
 import torch.nn.parallel
+import wandb
 from lightning import LightningModule
 from sklearn.preprocessing import StandardScaler, OrdinalEncoder
 from torch.nn.functional import one_hot
@@ -52,6 +53,7 @@ def __init__(
         self.sharpen_temp = sharpen_temp
         self.mix_beta_alpha = mix_beta_alpha
         self.save_hyperparameters()
+        self.lbl_logger = WandBLabelLogger()

     @property
     @abstractmethod
@@ -150,10 +152,12 @@ def progress(self):
         ) / self.trainer.max_epochs

     def training_step(self, batch, batch_idx):
-        # Progress is a linear ramp from 0 to 1 over the course of training.
         (x_lbl, y_lbl), x_unls = batch
+        self.lbl_logger(
+            self.logger.experiment, "Input Y Label", y_lbl, flush_every=10
+        )

-        y_lbl = one_hot(y_lbl.long(), num_classes=self.n_classes)
+        y_lbl_ohe = one_hot(y_lbl.long(), num_classes=self.n_classes)

         # If x_unls is Truthy, then we are using MixMatch.
         # Otherwise, we are just using supervised learning.
@@ -164,7 +168,7 @@ def training_step(self, batch, batch_idx):
             y_unl = self.sharpen(y_unl, self.sharpen_temp)

             x = torch.cat([x_lbl, *x_unls], dim=0)
-            y = torch.cat([y_lbl, *(y_unl,) * len(x_unls)], dim=0)
+            y = torch.cat([y_lbl_ohe, *(y_unl,) * len(x_unls)], dim=0)
             x_mix, y_mix = self.mix_up(x, y, self.mix_beta_alpha)

             # This had interleaving, but it was removed as it's not
@@ -177,7 +181,19 @@ def training_step(self, batch, batch_idx):
             y_mix_unl = y_mix[batch_size:]

             loss_lbl = self.loss_lbl(y_mix_lbl_pred, y_mix_lbl)
+            self.lbl_logger(
+                self.logger.experiment,
+                "Labelled Y Pred",
+                torch.argmax(y_mix_lbl_pred, dim=1),
+                flush_every=10,
+            )
             loss_unl = self.loss_unl(y_mix_unl_pred, y_mix_unl)
+            self.lbl_logger(
+                self.logger.experiment,
+                "Unlabelled Y Pred",
+                torch.argmax(y_mix_unl_pred, dim=1),
+                flush_every=10,
+            )
             loss_unl_scale = self.loss_unl_scaler(progress=self.progress)

             loss = loss_lbl + loss_unl * loss_unl_scale
@@ -188,7 +204,7 @@ def training_step(self, batch, batch_idx):
         else:
             # This route implies that we are just using supervised learning
             y_pred = self(x_lbl)
-            loss = self.loss_lbl(y_pred, y_lbl.float())
+            loss = self.loss_lbl(y_pred, y_lbl_ohe.float())

         self.log("train_loss", loss)
         return loss
@@ -201,7 +217,16 @@ def on_after_backward(self) -> None:

     def validation_step(self, batch, batch_idx):
         x, y = batch
+        self.lbl_logger(
+            self.logger.experiment, "Val Input Y Label", y, flush_every=1
+        )
         y_pred = self.ema_model(x)
+        self.lbl_logger(
+            self.logger.experiment,
+            "Val Pred Y Label",
+            torch.argmax(y_pred, dim=1),
+            flush_every=1,
+        )
         loss = F.cross_entropy(y_pred, y.long())

         acc = accuracy(
@@ -299,3 +324,38 @@ def y_trans_fn(y):
             return (x_lab_trans, y_trans.long()), x_unl_trans
         else:
             return x_lab_trans, y_trans.long()
+
+
+class WandBLabelLogger(dict):
+    """Logger to log y labels to WandB"""
+
+    def __call__(
+        self,
+        logger: wandb.sdk.wandb_run.Run,
+        key: str,
+        value: torch.Tensor,
+        flush_every: int = 10,
+    ):
+        """Log the labels to WandB
+
+        Args:
+            logger: The W&B logger. Accessible through `self.logger.experiment`
+            key: The key to log the labels under.
+            value: The labels to log.
+            flush_every: How often to flush the labels to WandB.
+
+        """
+        if key not in self.keys():
+            self[key] = [value]
+        else:
+            self[key].append(value)
+
+        if len(self[key]) % flush_every == 0:
+            logger.log(
+                {
+                    key: wandb.Histogram(
+                        torch.flatten(value).detach().cpu().tolist()
+                    )
+                }
+            )
+            self[key] = []

From a355c39843f0cb835c7a1d9034346f479b191488 Mon Sep 17 00:00:00 2001
From: Evening
Date: Tue, 2 Jan 2024 14:45:01 +0800
Subject: [PATCH 2/3] Clean up train.py

---
 tests/model_tests/chestnut_dec_may/train.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/model_tests/chestnut_dec_may/train.py b/tests/model_tests/chestnut_dec_may/train.py
index c37404e8..28e79a92 100644
--- a/tests/model_tests/chestnut_dec_may/train.py
+++ b/tests/model_tests/chestnut_dec_may/train.py
@@ -43,13 +43,12 @@ def main(
 ):
     run = wandb.init()
     logger = WandbLogger(name="chestnut_dec_may", project="frdc")
+
     # Prepare the dataset
     train_lab_ds = ds.chestnut_20201218(transform=train_preprocess)
-
     train_unl_ds = ds.chestnut_20201218.unlabelled(
         transform=train_unl_preprocess(2)
     )
-
     val_ds = ds.chestnut_20210510_43m(transform=preprocess)

     oe = OrdinalEncoder(
@@ -65,12 +64,12 @@ def main(
     # Prepare the datamodule and trainer
     dm = FRDCDataModule(
         train_lab_ds=train_lab_ds,
-        # Pass in None to use the default supervised DM
-        train_unl_ds=train_unl_ds,
+        train_unl_ds=train_unl_ds,  # None to use supervised DM
         val_ds=val_ds,
         batch_size=batch_size,
         train_iters=train_iters,
         val_iters=val_iters,
+        sampling_strategy="stratified",
     )

     trainer = pl.Trainer(
@@ -90,6 +89,7 @@ def main(
         ],
         logger=logger,
     )
+
     m = InceptionV3MixMatchModule(
         n_classes=n_classes,
         lr=lr,

From dff83781b1461a91c524b5df5ef2a086720b3f8e Mon Sep 17 00:00:00 2001
From: Evening
Date: Tue, 2 Jan 2024 14:45:10 +0800
Subject: [PATCH 3/3] Make W&B Watch model

---
 tests/model_tests/chestnut_dec_may/train.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/model_tests/chestnut_dec_may/train.py b/tests/model_tests/chestnut_dec_may/train.py
index 28e79a92..863a9476 100644
--- a/tests/model_tests/chestnut_dec_may/train.py
+++ b/tests/model_tests/chestnut_dec_may/train.py
@@ -96,6 +96,7 @@ def main(
         x_scaler=ss,
         y_encoder=oe,
     )
+    logger.watch(m)

     trainer.fit(m, datamodule=dm)
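
For reference, a minimal sketch of how the WandBLabelLogger added in PATCH 1/3 behaves when exercised on its own. This is not part of the patches; the project name, label values, offline mode, and import path are illustrative assumptions, and inside MixMatchModule the run object comes from self.logger.experiment rather than a fresh wandb.init().

    # Illustrative sketch only: exercises WandBLabelLogger from PATCH 1/3.
    # Assumes the frdc package (src layout) is importable; project name,
    # offline mode, and the fake labels below are made up for this example.
    import torch
    import wandb

    from frdc.train.mixmatch_module import WandBLabelLogger

    run = wandb.init(project="frdc", mode="offline")  # no W&B server needed
    lbl_logger = WandBLabelLogger()

    for step in range(20):
        y = torch.randint(0, 10, (32,))  # stand-in for a batch of class indices
        # Calls 1-9 only accumulate; every 10th call logs a histogram of the
        # current batch's labels under "Input Y Label" and resets the buffer.
        lbl_logger(run, "Input Y Label", y, flush_every=10)

    run.finish()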