diff --git a/tests/model_tests/chestnut_dec_may/train.py b/tests/model_tests/chestnut_dec_may/train.py index 847ef65..a23cbcf 100644 --- a/tests/model_tests/chestnut_dec_may/train.py +++ b/tests/model_tests/chestnut_dec_may/train.py @@ -18,7 +18,7 @@ from sklearn.preprocessing import StandardScaler, OrdinalEncoder from frdc.load.preset import FRDCDatasetPreset as ds -from frdc.models.inceptionv3 import InceptionV3MixMatchModule +from frdc.models.efficientnetb1 import EfficientNetB1MixMatchModule from frdc.train.frdc_datamodule import FRDCDataModule from frdc.utils.training import predict, plot_confusion_matrix from model_tests.utils import ( @@ -61,11 +61,14 @@ def main( wandb_project="frdc", ): # Prepare the dataset - train_lab_ds = ds.chestnut_20201218(transform=train_preprocess_augment) + im_size = 299 + train_lab_ds = ds.chestnut_20201218( + transform=train_preprocess_augment(im_size) + ) train_unl_ds = ds.chestnut_20201218.unlabelled( - transform=train_unl_preprocess(2) + transform=train_unl_preprocess(im_size, 2) ) - val_ds = ds.chestnut_20210510_43m(transform=val_preprocess) + val_ds = ds.chestnut_20210510_43m(transform=val_preprocess(im_size)) # Prepare the datamodule and trainer dm = FRDCDataModule( @@ -103,13 +106,13 @@ def main( oe = get_y_encoder(train_lab_ds.targets) ss = get_x_scaler(train_lab_ds.ar_segments) - m = InceptionV3MixMatchModule( + m = EfficientNetB1MixMatchModule( in_channels=train_lab_ds.ar.shape[-1], n_classes=len(oe.categories_[0]), lr=lr, x_scaler=ss, y_encoder=oe, - imagenet_scaling=True, + frozen=True, ) trainer.fit(m, datamodule=dm) @@ -125,10 +128,9 @@ def main( "chestnut_nature_park", "20210510", "90deg43m85pct255deg", - transform=val_preprocess, + transform=val_preprocess(im_size), ), - model_cls=InceptionV3MixMatchModule, - ckpt_pth=Path(ckpt.best_model_path), + model=m, ) fig, ax = plot_confusion_matrix(y_true, y_pred, oe.categories_[0]) acc = np.sum(y_true == y_pred) / len(y_true) @@ -152,6 +154,6 @@ def main( epochs=EPOCHS, train_iters=TRAIN_ITERS, lr=LR, - wandb_name="Try with Inception Unfrozen & Random Erasing", + wandb_name="EfficientNet 299x299", wandb_project="frdc-dev", )