diff --git a/src/frdc/train/mixmatch_module.py b/src/frdc/train/mixmatch_module.py
index 784380b6..75e581e8 100644
--- a/src/frdc/train/mixmatch_module.py
+++ b/src/frdc/train/mixmatch_module.py
@@ -8,6 +8,7 @@
 import torch.nn.functional as F
 import torch.nn.parallel
 import torch.nn.parallel
+import wandb
 from lightning import LightningModule
 from sklearn.preprocessing import StandardScaler, OrdinalEncoder
 from torch.nn.functional import one_hot
@@ -52,6 +53,7 @@ def __init__(
         self.sharpen_temp = sharpen_temp
         self.mix_beta_alpha = mix_beta_alpha
         self.save_hyperparameters()
+        self.lbl_logger = WandBLabelLogger()
 
     @property
     @abstractmethod
@@ -150,10 +152,12 @@ def progress(self):
         ) / self.trainer.max_epochs
 
     def training_step(self, batch, batch_idx):
-        # Progress is a linear ramp from 0 to 1 over the course of training.
         (x_lbl, y_lbl), x_unls = batch
+        self.lbl_logger(
+            self.logger.experiment, "Input Y Label", y_lbl, flush_every=10
+        )
 
-        y_lbl = one_hot(y_lbl.long(), num_classes=self.n_classes)
+        y_lbl_ohe = one_hot(y_lbl.long(), num_classes=self.n_classes)
 
         # If x_unls is Truthy, then we are using MixMatch.
         # Otherwise, we are just using supervised learning.
@@ -164,7 +168,7 @@ def training_step(self, batch, batch_idx):
                 y_unl = self.sharpen(y_unl, self.sharpen_temp)
 
             x = torch.cat([x_lbl, *x_unls], dim=0)
-            y = torch.cat([y_lbl, *(y_unl,) * len(x_unls)], dim=0)
+            y = torch.cat([y_lbl_ohe, *(y_unl,) * len(x_unls)], dim=0)
             x_mix, y_mix = self.mix_up(x, y, self.mix_beta_alpha)
 
             # This had interleaving, but it was removed as it's not
@@ -177,7 +181,19 @@ def training_step(self, batch, batch_idx):
             y_mix_unl = y_mix[batch_size:]
 
             loss_lbl = self.loss_lbl(y_mix_lbl_pred, y_mix_lbl)
+            self.lbl_logger(
+                self.logger.experiment,
+                "Labelled Y Pred",
+                torch.argmax(y_mix_lbl_pred, dim=1),
+                flush_every=10,
+            )
             loss_unl = self.loss_unl(y_mix_unl_pred, y_mix_unl)
+            self.lbl_logger(
+                self.logger.experiment,
+                "Unlabelled Y Pred",
+                torch.argmax(y_mix_unl_pred, dim=1),
+                flush_every=10,
+            )
             loss_unl_scale = self.loss_unl_scaler(progress=self.progress)
 
             loss = loss_lbl + loss_unl * loss_unl_scale
@@ -188,7 +204,7 @@ def training_step(self, batch, batch_idx):
         else:
             # This route implies that we are just using supervised learning
             y_pred = self(x_lbl)
-            loss = self.loss_lbl(y_pred, y_lbl.float())
+            loss = self.loss_lbl(y_pred, y_lbl_ohe.float())
 
         self.log("train_loss", loss)
         return loss
@@ -201,7 +217,16 @@ def on_after_backward(self) -> None:
 
     def validation_step(self, batch, batch_idx):
         x, y = batch
+        self.lbl_logger(
+            self.logger.experiment, "Val Input Y Label", y, flush_every=1
+        )
         y_pred = self.ema_model(x)
+        self.lbl_logger(
+            self.logger.experiment,
+            "Val Pred Y Label",
+            torch.argmax(y_pred, dim=1),
+            flush_every=1,
+        )
         loss = F.cross_entropy(y_pred, y.long())
 
         acc = accuracy(
@@ -299,3 +324,38 @@ def y_trans_fn(y):
         return (x_lab_trans, y_trans.long()), x_unl_trans
     else:
         return x_lab_trans, y_trans.long()
+
+
+class WandBLabelLogger(dict):
+    """Logger to log y labels to WandB"""
+
+    def __call__(
+        self,
+        logger: wandb.sdk.wandb_run.Run,
+        key: str,
+        value: torch.Tensor,
+        flush_every: int = 10,
+    ):
+        """Log the labels to WandB
+
+        Args:
+            logger: The W&B logger. Accessible through `self.logger.experiment`
+            key: The key to log the labels under.
+            value: The labels to log.
+            flush_every: How often to flush the labels to WandB.
+
+        """
+        if key not in self.keys():
+            self[key] = [value]
+        else:
+            self[key].append(value)
+
+        if len(self[key]) % flush_every == 0:
+            logger.log(
+                {
+                    key: wandb.Histogram(
+                        torch.flatten(value).detach().cpu().tolist()
+                    )
+                }
+            )
+            self[key] = []
diff --git a/tests/model_tests/chestnut_dec_may/train.py b/tests/model_tests/chestnut_dec_may/train.py
index c37404e8..863a9476 100644
--- a/tests/model_tests/chestnut_dec_may/train.py
+++ b/tests/model_tests/chestnut_dec_may/train.py
@@ -43,13 +43,12 @@ def main(
 ):
     run = wandb.init()
     logger = WandbLogger(name="chestnut_dec_may", project="frdc")
+
     # Prepare the dataset
     train_lab_ds = ds.chestnut_20201218(transform=train_preprocess)
-
     train_unl_ds = ds.chestnut_20201218.unlabelled(
         transform=train_unl_preprocess(2)
     )
-
     val_ds = ds.chestnut_20210510_43m(transform=preprocess)
 
     oe = OrdinalEncoder(
@@ -65,12 +64,12 @@ def main(
     # Prepare the datamodule and trainer
     dm = FRDCDataModule(
         train_lab_ds=train_lab_ds,
-        # Pass in None to use the default supervised DM
-        train_unl_ds=train_unl_ds,
+        train_unl_ds=train_unl_ds,  # None to use supervised DM
         val_ds=val_ds,
         batch_size=batch_size,
         train_iters=train_iters,
        val_iters=val_iters,
+        sampling_strategy="stratified",
     )
 
     trainer = pl.Trainer(
@@ -90,12 +89,14 @@ def main(
         ],
         logger=logger,
     )
+
     m = InceptionV3MixMatchModule(
         n_classes=n_classes,
         lr=lr,
         x_scaler=ss,
         y_encoder=oe,
     )
+
     logger.watch(m)
     trainer.fit(m, datamodule=dm)
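
Note: the sketch below shows how the added WandBLabelLogger behaves on its own, outside of Lightning. It is a minimal usage example, not part of the diff; the offline wandb run and the fake label tensors are assumptions made purely for illustration. The logger buffers the tensors passed under each key and, on every `flush_every`-th call for that key, logs a histogram of the most recent batch of labels to the W&B run and clears that key's buffer.

import torch
import wandb

from frdc.train.mixmatch_module import WandBLabelLogger

# Offline run so the example does not need W&B credentials.
run = wandb.init(mode="offline")
lbl_logger = WandBLabelLogger()

for step in range(20):
    # Fake integer class labels for a batch of 32 samples over 10 classes.
    y = torch.randint(0, 10, (32,))
    # Buffers under "Input Y Label"; logs a histogram on every 10th call.
    lbl_logger(run, "Input Y Label", y, flush_every=10)

run.finish()

As written in the diff, the flush histograms only the latest `value` rather than the whole buffer, so with `flush_every=10` the logged histogram reflects the 10th batch of labels; inside the module this is called once per training or validation step, which is why the validation calls use `flush_every=1`.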