Skip to content

Commit

Permalink
Fix issue with MixMatch Run
Browse files Browse the repository at this point in the history
  • Loading branch information
Eve-ning committed Jun 28, 2024
1 parent bf801ee commit bcea011
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 13 deletions.
23 changes: 11 additions & 12 deletions tests/model_tests/chestnut_dec_may/train_mixmatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)
from lightning.pytorch.loggers import WandbLogger

from frdc.load.dataset import FRDCConstRotatedDataset
from frdc.load.dataset import FRDCConstRotatedDataset, ImageStandardScaler
from frdc.load.preset import FRDCDatasetPreset as ds
from frdc.models.efficientnetb1 import EfficientNetB1MixMatchModule
from frdc.train.frdc_datamodule import FRDCDataModule
Expand All @@ -25,6 +25,7 @@
const_weak_aug,
n_rand_strong_aug,
rand_strong_aug,
rand_weak_aug,
)


Expand All @@ -44,16 +45,17 @@ def main(
):
# Prepare the dataset
im_size = 299
train_lab_ds = ds.chestnut_20201218(
transform=rand_strong_aug(im_size),
)
train_lab_ds = ds.chestnut_20201218(transform=rand_weak_aug(im_size))
iss = ImageStandardScaler().fit_nested(train_lab_ds[:][0])
train_lab_ds.transform.transforms.append(iss.transform_nested)
train_unl_ds = ds.chestnut_20201218.unlabelled(
transform=n_rand_strong_aug(im_size, 2)
)
train_unl_ds.transform.transforms.append(iss.transform_nested)
val_ds = ds.chestnut_20210510_43m(
transform=const_weak_aug(im_size),
transform_scale=train_lab_ds.x_scaler,
)
val_ds.transform.transforms.append(iss.transform_nested)

# Prepare the datamodule and trainer
dm = FRDCDataModule(
Expand Down Expand Up @@ -102,14 +104,11 @@ def main(
f"# Chestnut Nature Park (Dec 2020 vs May 2021) MixMatch\n"
f"- Results: [WandB Report]({wandb.run.get_url()})\n"
)

y_true, y_pred = predict(
ds=ds.chestnut_20210510_43m.const_rotated(
transform=const_weak_aug(im_size),
transform_scale=train_lab_ds.x_scaler,
),
model=m,
test_ds = ds.chestnut_20210510_43m.const_rotated(
transform=const_weak_aug(im_size),
)
test_ds.transform.transforms.append(iss.transform_nested)
y_true, y_pred = predict(ds=test_ds, model=m)
fig, ax = plot_confusion_matrix(y_true, y_pred, m.y_encoder.categories_[0])
acc = np.sum(y_true == y_pred) / len(y_true)
ax.set_title(f"Accuracy: {acc:.2%}")
Expand Down
2 changes: 1 addition & 1 deletion tests/model_tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@


def n_times(f, n: int):
return lambda x: [f(x) for _ in range(n)]
return Compose([lambda x: [f(x) for _ in range(n)]])


def n_rand_weak_aug(size, n_aug: int = 2):
Expand Down

0 comments on commit bcea011

Please sign in to comment.