diff --git a/tests/model_tests/chestnut_dec_may/train.py b/tests/model_tests/chestnut_dec_may/train.py index 53099ece..97123f02 100644 --- a/tests/model_tests/chestnut_dec_may/train.py +++ b/tests/model_tests/chestnut_dec_may/train.py @@ -5,8 +5,11 @@ """ # Uncomment this to run the W&B monitoring locally -# import os -# os.environ["WANDB_MODE"] = "offline" +import os + +from frdc.utils.training import predict, plot_confusion_matrix + +os.environ["WANDB_MODE"] = "offline" from pathlib import Path @@ -21,14 +24,13 @@ from lightning.pytorch.loggers import WandbLogger from sklearn.preprocessing import StandardScaler, OrdinalEncoder -from frdc.load.dataset import FRDCUnlabelledDataset, FRDCDatasetPreset +from frdc.load.dataset import FRDCDatasetPreset as ds from frdc.models.inceptionv3 import InceptionV3MixMatchModule from frdc.train.frdc_datamodule import FRDCDataModule from model_tests.utils import ( train_preprocess, train_unl_preprocess, preprocess, - evaluate, FRDCDatasetFlipped, ) @@ -43,15 +45,13 @@ def main( run = wandb.init() logger = WandbLogger(name="chestnut_dec_may", project="frdc") # Prepare the dataset - train_lab_ds = FRDCDatasetPreset.chestnut_20201218( - transform=train_preprocess - ) + train_lab_ds = ds.chestnut_20201218(transform=train_preprocess) - train_unl_ds = FRDCDatasetPreset.chestnut_20201218.unlabelled( - transform=train_unl_preprocess(2), + train_unl_ds = ds.chestnut_20201218.unlabelled( + transform=train_unl_preprocess(2) ) - val_ds = FRDCDatasetPreset.chestnut_20210510_43m(transform=preprocess) + val_ds = ds.chestnut_20210510_43m(transform=preprocess) oe = OrdinalEncoder( handle_unknown="use_encoded_value", @@ -106,15 +106,20 @@ def main( f"- Results: [WandB Report]({run.get_url()})" ) - fig, acc = evaluate( + y_true, y_pred = predict( ds=FRDCDatasetFlipped( "chestnut_nature_park", "20210510", "90deg43m85pct255deg", transform=preprocess, ), + model_cls=InceptionV3MixMatchModule, ckpt_pth=Path(ckpt.best_model_path), ) + fig, ax = plot_confusion_matrix(y_true, y_pred, oe.categories_[0]) + acc = np.sum(y_true == y_pred) / len(y_true) + ax.set_title(f"Accuracy: {acc:.2%}") + wandb.log({"confusion_matrix": wandb.Image(fig)}) wandb.log({"eval_accuracy": acc})