Skip to content

Commit

Permalink
Swap to EfficientNet & Update aug params
Browse files Browse the repository at this point in the history
  • Loading branch information
Eve-ning committed Feb 21, 2024
1 parent 555d812 commit b3e805c
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions tests/model_tests/chestnut_dec_may/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from sklearn.preprocessing import StandardScaler, OrdinalEncoder

from frdc.load.preset import FRDCDatasetPreset as ds
from frdc.models.inceptionv3 import InceptionV3MixMatchModule
from frdc.models.efficientnetb1 import EfficientNetB1MixMatchModule
from frdc.train.frdc_datamodule import FRDCDataModule
from frdc.utils.training import predict, plot_confusion_matrix
from model_tests.utils import (
Expand Down Expand Up @@ -61,11 +61,14 @@ def main(
wandb_project="frdc",
):
# Prepare the dataset
train_lab_ds = ds.chestnut_20201218(transform=train_preprocess_augment)
im_size = 299
train_lab_ds = ds.chestnut_20201218(
transform=train_preprocess_augment(im_size)
)
train_unl_ds = ds.chestnut_20201218.unlabelled(
transform=train_unl_preprocess(2)
transform=train_unl_preprocess(im_size, 2)
)
val_ds = ds.chestnut_20210510_43m(transform=val_preprocess)
val_ds = ds.chestnut_20210510_43m(transform=val_preprocess(im_size))

# Prepare the datamodule and trainer
dm = FRDCDataModule(
Expand Down Expand Up @@ -103,13 +106,13 @@ def main(
oe = get_y_encoder(train_lab_ds.targets)
ss = get_x_scaler(train_lab_ds.ar_segments)

m = InceptionV3MixMatchModule(
m = EfficientNetB1MixMatchModule(
in_channels=train_lab_ds.ar.shape[-1],
n_classes=len(oe.categories_[0]),
lr=lr,
x_scaler=ss,
y_encoder=oe,
imagenet_scaling=True,
frozen=True,
)

trainer.fit(m, datamodule=dm)
Expand All @@ -125,10 +128,9 @@ def main(
"chestnut_nature_park",
"20210510",
"90deg43m85pct255deg",
transform=val_preprocess,
transform=val_preprocess(im_size),
),
model_cls=InceptionV3MixMatchModule,
ckpt_pth=Path(ckpt.best_model_path),
model=m,
)
fig, ax = plot_confusion_matrix(y_true, y_pred, oe.categories_[0])
acc = np.sum(y_true == y_pred) / len(y_true)
Expand All @@ -152,6 +154,6 @@ def main(
epochs=EPOCHS,
train_iters=TRAIN_ITERS,
lr=LR,
wandb_name="Try with Inception Unfrozen & Random Erasing",
wandb_name="EfficientNet 299x299",
wandb_project="frdc-dev",
)

0 comments on commit b3e805c

Please sign in to comment.