From d96030ddacc5c224eb1bfa283bd0202069da6fdb Mon Sep 17 00:00:00 2001 From: Evening Date: Mon, 8 Jan 2024 13:50:17 +0800 Subject: [PATCH 1/3] Minor Black formatting --- src/frdc/conf.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/frdc/conf.py b/src/frdc/conf.py index bbbbdf68..0d32eb45 100644 --- a/src/frdc/conf.py +++ b/src/frdc/conf.py @@ -88,5 +88,3 @@ f"LABEL_STUDIO_CLIENT will be None." ) LABEL_STUDIO_CLIENT = None - - From 3794501c943c9e3ee007cac941bb7afb4d2ea176 Mon Sep 17 00:00:00 2001 From: Evening Date: Mon, 8 Jan 2024 13:50:47 +0800 Subject: [PATCH 2/3] Fix issue with WandB hist logger too many bins --- src/frdc/train/mixmatch_module.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/src/frdc/train/mixmatch_module.py b/src/frdc/train/mixmatch_module.py index 75e581e8..9e3af191 100644 --- a/src/frdc/train/mixmatch_module.py +++ b/src/frdc/train/mixmatch_module.py @@ -154,7 +154,11 @@ def progress(self): def training_step(self, batch, batch_idx): (x_lbl, y_lbl), x_unls = batch self.lbl_logger( - self.logger.experiment, "Input Y Label", y_lbl, flush_every=10 + self.logger.experiment, + "Input Y Label", + y_lbl, + flush_every=10, + num_bins=self.n_classes, ) y_lbl_ohe = one_hot(y_lbl.long(), num_classes=self.n_classes) @@ -186,6 +190,7 @@ def training_step(self, batch, batch_idx): "Labelled Y Pred", torch.argmax(y_mix_lbl_pred, dim=1), flush_every=10, + num_bins=self.n_classes, ) loss_unl = self.loss_unl(y_mix_unl_pred, y_mix_unl) self.lbl_logger( @@ -193,6 +198,7 @@ def training_step(self, batch, batch_idx): "Unlabelled Y Pred", torch.argmax(y_mix_unl_pred, dim=1), flush_every=10, + num_bins=self.n_classes, ) loss_unl_scale = self.loss_unl_scaler(progress=self.progress) @@ -218,7 +224,11 @@ def on_after_backward(self) -> None: def validation_step(self, batch, batch_idx): x, y = batch self.lbl_logger( - self.logger.experiment, "Val Input Y Label", y, flush_every=1 + self.logger.experiment, + "Val Input Y Label", + y, + flush_every=1, + num_bins=self.n_classes, ) y_pred = self.ema_model(x) self.lbl_logger( @@ -226,6 +236,7 @@ def validation_step(self, batch, batch_idx): "Val Pred Y Label", torch.argmax(y_pred, dim=1), flush_every=1, + num_bins=self.n_classes, ) loss = F.cross_entropy(y_pred, y.long()) @@ -334,6 +345,7 @@ def __call__( logger: wandb.sdk.wandb_run.Run, key: str, value: torch.Tensor, + num_bins: int, flush_every: int = 10, ): """Log the labels to WandB @@ -354,7 +366,8 @@ def __call__( logger.log( { key: wandb.Histogram( - torch.flatten(value).detach().cpu().tolist() + torch.flatten(value).detach().cpu().tolist(), + num_bins=num_bins, ) } ) From c8b050a2ebe10fe88d17e39b068f66f1e30480f9 Mon Sep 17 00:00:00 2001 From: Evening Date: Mon, 8 Jan 2024 13:51:06 +0800 Subject: [PATCH 3/3] Fix issue with redundant initializing wandb --- tests/model_tests/chestnut_dec_may/train.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/model_tests/chestnut_dec_may/train.py b/tests/model_tests/chestnut_dec_may/train.py index 863a9476..8d4aad1c 100644 --- a/tests/model_tests/chestnut_dec_may/train.py +++ b/tests/model_tests/chestnut_dec_may/train.py @@ -3,6 +3,7 @@ This test is done by training a model on the 20201218 dataset, then testing on the 20210510 dataset. """ +import os # Uncomment this to run the W&B monitoring locally # import os @@ -41,9 +42,6 @@ def main( val_iters=15, lr=1e-3, ): - run = wandb.init() - logger = WandbLogger(name="chestnut_dec_may", project="frdc") - # Prepare the dataset train_lab_ds = ds.chestnut_20201218(transform=train_preprocess) train_unl_ds = ds.chestnut_20201218.unlabelled( @@ -87,7 +85,9 @@ def main( monitor="val_loss", mode="min", save_top_k=1 ), ], - logger=logger, + logger=( + logger := WandbLogger(name="chestnut_dec_may", project="frdc") + ), ) m = InceptionV3MixMatchModule( @@ -103,7 +103,7 @@ def main( with open(Path(__file__).parent / "report.md", "w") as f: f.write( f"# Chestnut Nature Park (Dec 2020 vs May 2021)\n" - f"- Results: [WandB Report]({run.get_url()})" + f"- Results: [WandB Report]({wandb.run.get_url()})" ) y_true, y_pred = predict( @@ -133,8 +133,8 @@ def main( VAL_ITERS = 15 LR = 1e-3 - assert wandb.run is None - wandb.setup(wandb.Settings(program=__name__, program_relpath=__name__)) + wandb.login(key=os.environ["WANDB_API_KEY"]) + main( batch_size=BATCH_SIZE, epochs=EPOCHS,