Skip to content

Commit

Permalink
Migrate references for train
Browse files Browse the repository at this point in the history
  • Loading branch information
Eve-ning committed Dec 28, 2023
1 parent 2928147 commit 60b835d
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions tests/model_tests/chestnut_dec_may/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
"""

# Uncomment this to run the W&B monitoring locally
# import os
# os.environ["WANDB_MODE"] = "offline"
import os

from frdc.utils.training import predict, plot_confusion_matrix

os.environ["WANDB_MODE"] = "offline"

from pathlib import Path

Expand All @@ -21,14 +24,13 @@
from lightning.pytorch.loggers import WandbLogger
from sklearn.preprocessing import StandardScaler, OrdinalEncoder

from frdc.load.dataset import FRDCUnlabelledDataset, FRDCDatasetPreset
from frdc.load.dataset import FRDCDatasetPreset as ds
from frdc.models.inceptionv3 import InceptionV3MixMatchModule
from frdc.train.frdc_datamodule import FRDCDataModule
from model_tests.utils import (
train_preprocess,
train_unl_preprocess,
preprocess,
evaluate,
FRDCDatasetFlipped,
)

Expand All @@ -43,15 +45,13 @@ def main(
run = wandb.init()
logger = WandbLogger(name="chestnut_dec_may", project="frdc")
# Prepare the dataset
train_lab_ds = FRDCDatasetPreset.chestnut_20201218(
transform=train_preprocess
)
train_lab_ds = ds.chestnut_20201218(transform=train_preprocess)

train_unl_ds = FRDCDatasetPreset.chestnut_20201218.unlabelled(
transform=train_unl_preprocess(2),
train_unl_ds = ds.chestnut_20201218.unlabelled(
transform=train_unl_preprocess(2)
)

val_ds = FRDCDatasetPreset.chestnut_20210510_43m(transform=preprocess)
val_ds = ds.chestnut_20210510_43m(transform=preprocess)

oe = OrdinalEncoder(
handle_unknown="use_encoded_value",
Expand Down Expand Up @@ -106,15 +106,20 @@ def main(
f"- Results: [WandB Report]({run.get_url()})"
)

fig, acc = evaluate(
y_true, y_pred = predict(
ds=FRDCDatasetFlipped(
"chestnut_nature_park",
"20210510",
"90deg43m85pct255deg",
transform=preprocess,
),
model_cls=InceptionV3MixMatchModule,
ckpt_pth=Path(ckpt.best_model_path),
)
fig, ax = plot_confusion_matrix(y_true, y_pred, oe.categories_[0])
acc = np.sum(y_true == y_pred) / len(y_true)
ax.set_title(f"Accuracy: {acc:.2%}")

wandb.log({"confusion_matrix": wandb.Image(fig)})
wandb.log({"eval_accuracy": acc})

Expand Down

0 comments on commit 60b835d

Please sign in to comment.