Skip to content

Commit

Permalink
Merge pull request #64 from FR-DC/frml-145
Browse files Browse the repository at this point in the history
FRML-145 Update evaluation dataset to use all 8 orients
  • Loading branch information
Eve-ning authored May 29, 2024
2 parents 2716e60 + 3302063 commit 0547000
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 19 deletions.
4 changes: 2 additions & 2 deletions tests/model_tests/chestnut_dec_may/train_fixmatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from frdc.utils.training import predict, plot_confusion_matrix
from model_tests.utils import (
val_preprocess,
FRDCDatasetFlipped,
FRDCDatasetStaticEval,
n_weak_strong_aug,
get_y_encoder,
get_x_scaler,
Expand Down Expand Up @@ -116,7 +116,7 @@ def main(
)

y_true, y_pred = predict(
ds=FRDCDatasetFlipped(
ds=FRDCDatasetStaticEval(
"chestnut_nature_park",
"20210510",
"90deg43m85pct255deg",
Expand Down
4 changes: 2 additions & 2 deletions tests/model_tests/chestnut_dec_may/train_mixmatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from frdc.utils.training import predict, plot_confusion_matrix
from model_tests.utils import (
val_preprocess,
FRDCDatasetFlipped,
FRDCDatasetStaticEval,
n_strong_aug,
strong_aug,
get_y_encoder,
Expand Down Expand Up @@ -107,7 +107,7 @@ def main(
)

y_true, y_pred = predict(
ds=FRDCDatasetFlipped(
ds=FRDCDatasetStaticEval(
"chestnut_nature_park",
"20210510",
"90deg43m85pct255deg",
Expand Down
46 changes: 31 additions & 15 deletions tests/model_tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
import torch
from sklearn.preprocessing import OrdinalEncoder, StandardScaler
from torch import rot90
from torchvision.transforms import RandomVerticalFlip
from torchvision.transforms.v2 import (
Compose,
Expand All @@ -18,6 +19,7 @@
Resize,
)
from torchvision.transforms.v2 import RandomHorizontalFlip
from torchvision.transforms.v2.functional import hflip

from frdc.load.dataset import FRDCDataset

Expand All @@ -26,27 +28,41 @@
BANDS = ["NB", "NG", "NR", "RE", "NIR"]


class FRDCDatasetFlipped(FRDCDataset):
class FRDCDatasetStaticEval(FRDCDataset):
def __len__(self):
"""Assume that the dataset is 4x larger than it actually is.
"""Assume that the dataset is 8x larger than it actually is.
For example, for index 0, we return the original image. For index 1, we
return the horizontally flipped image and so on, until index 3.
Then, return the next image for index 4, and so on.
There are 8 possible orientations for each image.
1. As-is
2, 3, 4. Rotated 90, 180, 270 degrees
5. Horizontally flipped
6, 7, 8. Horizontally flipped and rotated 90, 180, 270 degrees
"""
return super().__len__() * 4
return super().__len__() * 8

def __getitem__(self, idx):
"""Alter the getitem method to implement the logic above."""
x, y = super().__getitem__(int(idx // 4))
if idx % 4 == 0:
return x, y
elif idx % 4 == 1:
return RandomHorizontalFlip(p=1)(x), y
elif idx % 4 == 2:
return RandomVerticalFlip(p=1)(x), y
elif idx % 4 == 3:
return RandomHorizontalFlip(p=1)(RandomVerticalFlip(p=1)(x)), y
x, y = super().__getitem__(int(idx // 8))
assert x.ndim == 3, "x must be a 3D tensor"
x_ = None
if idx % 8 == 0:
x_ = x
elif idx % 8 == 1:
x_ = rot90(x, 1, (1, 2))
elif idx % 8 == 2:
x_ = rot90(x, 2, (1, 2))
elif idx % 8 == 3:
x_ = rot90(x, 3, (1, 2))
elif idx % 8 == 4:
x_ = hflip(x)
elif idx % 8 == 5:
x_ = hflip(rot90(x, 1, (1, 2)))
elif idx % 8 == 6:
x_ = hflip(rot90(x, 2, (1, 2)))
elif idx % 8 == 7:
x_ = hflip(rot90(x, 3, (1, 2)))

return x_, y


def val_preprocess(size: int):
Expand Down

0 comments on commit 0547000

Please sign in to comment.