From bcea011b71f4184754a09394e561b4f637d1c809 Mon Sep 17 00:00:00 2001 From: Eve-ning Date: Fri, 28 Jun 2024 15:56:01 +0800 Subject: [PATCH] Fix issue with MixMatch Run --- .../chestnut_dec_may/train_mixmatch.py | 23 +++++++++---------- tests/model_tests/utils.py | 2 +- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/tests/model_tests/chestnut_dec_may/train_mixmatch.py b/tests/model_tests/chestnut_dec_may/train_mixmatch.py index edae259..0d1727e 100644 --- a/tests/model_tests/chestnut_dec_may/train_mixmatch.py +++ b/tests/model_tests/chestnut_dec_may/train_mixmatch.py @@ -16,7 +16,7 @@ ) from lightning.pytorch.loggers import WandbLogger -from frdc.load.dataset import FRDCConstRotatedDataset +from frdc.load.dataset import FRDCConstRotatedDataset, ImageStandardScaler from frdc.load.preset import FRDCDatasetPreset as ds from frdc.models.efficientnetb1 import EfficientNetB1MixMatchModule from frdc.train.frdc_datamodule import FRDCDataModule @@ -25,6 +25,7 @@ const_weak_aug, n_rand_strong_aug, rand_strong_aug, + rand_weak_aug, ) @@ -44,16 +45,17 @@ def main( ): # Prepare the dataset im_size = 299 - train_lab_ds = ds.chestnut_20201218( - transform=rand_strong_aug(im_size), - ) + train_lab_ds = ds.chestnut_20201218(transform=rand_weak_aug(im_size)) + iss = ImageStandardScaler().fit_nested(train_lab_ds[:][0]) + train_lab_ds.transform.transforms.append(iss.transform_nested) train_unl_ds = ds.chestnut_20201218.unlabelled( transform=n_rand_strong_aug(im_size, 2) ) + train_unl_ds.transform.transforms.append(iss.transform_nested) val_ds = ds.chestnut_20210510_43m( transform=const_weak_aug(im_size), - transform_scale=train_lab_ds.x_scaler, ) + val_ds.transform.transforms.append(iss.transform_nested) # Prepare the datamodule and trainer dm = FRDCDataModule( @@ -102,14 +104,11 @@ def main( f"# Chestnut Nature Park (Dec 2020 vs May 2021) MixMatch\n" f"- Results: [WandB Report]({wandb.run.get_url()})\n" ) - - y_true, y_pred = predict( - ds=ds.chestnut_20210510_43m.const_rotated( - transform=const_weak_aug(im_size), - transform_scale=train_lab_ds.x_scaler, - ), - model=m, + test_ds = ds.chestnut_20210510_43m.const_rotated( + transform=const_weak_aug(im_size), ) + test_ds.transform.transforms.append(iss.transform_nested) + y_true, y_pred = predict(ds=test_ds, model=m) fig, ax = plot_confusion_matrix(y_true, y_pred, m.y_encoder.categories_[0]) acc = np.sum(y_true == y_pred) / len(y_true) ax.set_title(f"Accuracy: {acc:.2%}") diff --git a/tests/model_tests/utils.py b/tests/model_tests/utils.py index 5399039..d21ef1d 100644 --- a/tests/model_tests/utils.py +++ b/tests/model_tests/utils.py @@ -23,7 +23,7 @@ def n_times(f, n: int): - return lambda x: [f(x) for _ in range(n)] + return Compose([lambda x: [f(x) for _ in range(n)]]) def n_rand_weak_aug(size, n_aug: int = 2):