diff --git a/tests/model_tests/chestnut_dec_may/train.py b/tests/model_tests/chestnut_dec_may/train.py index 863a9476..8d4aad1c 100644 --- a/tests/model_tests/chestnut_dec_may/train.py +++ b/tests/model_tests/chestnut_dec_may/train.py @@ -3,6 +3,7 @@ This test is done by training a model on the 20201218 dataset, then testing on the 20210510 dataset. """ +import os # Uncomment this to run the W&B monitoring locally # import os @@ -41,9 +42,6 @@ def main( val_iters=15, lr=1e-3, ): - run = wandb.init() - logger = WandbLogger(name="chestnut_dec_may", project="frdc") - # Prepare the dataset train_lab_ds = ds.chestnut_20201218(transform=train_preprocess) train_unl_ds = ds.chestnut_20201218.unlabelled( @@ -87,7 +85,9 @@ def main( monitor="val_loss", mode="min", save_top_k=1 ), ], - logger=logger, + logger=( + logger := WandbLogger(name="chestnut_dec_may", project="frdc") + ), ) m = InceptionV3MixMatchModule( @@ -103,7 +103,7 @@ def main( with open(Path(__file__).parent / "report.md", "w") as f: f.write( f"# Chestnut Nature Park (Dec 2020 vs May 2021)\n" - f"- Results: [WandB Report]({run.get_url()})" + f"- Results: [WandB Report]({wandb.run.get_url()})" ) y_true, y_pred = predict( @@ -133,8 +133,8 @@ def main( VAL_ITERS = 15 LR = 1e-3 - assert wandb.run is None - wandb.setup(wandb.Settings(program=__name__, program_relpath=__name__)) + wandb.login(key=os.environ["WANDB_API_KEY"]) + main( batch_size=BATCH_SIZE, epochs=EPOCHS,