Skip to content

Commit

Permalink
Update evaluation dataset to use all 8 orients
Browse files Browse the repository at this point in the history
  • Loading branch information
Eve-ning committed May 29, 2024
1 parent 2716e60 commit 3302063
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 19 deletions.
4 changes: 2 additions & 2 deletions tests/model_tests/chestnut_dec_may/train_fixmatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from frdc.utils.training import predict, plot_confusion_matrix
from model_tests.utils import (
val_preprocess,
FRDCDatasetFlipped,
FRDCDatasetStaticEval,
n_weak_strong_aug,
get_y_encoder,
get_x_scaler,
Expand Down Expand Up @@ -116,7 +116,7 @@ def main(
)

y_true, y_pred = predict(
ds=FRDCDatasetFlipped(
ds=FRDCDatasetStaticEval(
"chestnut_nature_park",
"20210510",
"90deg43m85pct255deg",
Expand Down
4 changes: 2 additions & 2 deletions tests/model_tests/chestnut_dec_may/train_mixmatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from frdc.utils.training import predict, plot_confusion_matrix
from model_tests.utils import (
val_preprocess,
FRDCDatasetFlipped,
FRDCDatasetStaticEval,
n_strong_aug,
strong_aug,
get_y_encoder,
Expand Down Expand Up @@ -107,7 +107,7 @@ def main(
)

y_true, y_pred = predict(
ds=FRDCDatasetFlipped(
ds=FRDCDatasetStaticEval(
"chestnut_nature_park",
"20210510",
"90deg43m85pct255deg",
Expand Down
46 changes: 31 additions & 15 deletions tests/model_tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
import torch
from sklearn.preprocessing import OrdinalEncoder, StandardScaler
from torch import rot90
from torchvision.transforms import RandomVerticalFlip
from torchvision.transforms.v2 import (
Compose,
Expand All @@ -18,6 +19,7 @@
Resize,
)
from torchvision.transforms.v2 import RandomHorizontalFlip
from torchvision.transforms.v2.functional import hflip

from frdc.load.dataset import FRDCDataset

Expand All @@ -26,27 +28,41 @@
BANDS = ["NB", "NG", "NR", "RE", "NIR"]


class FRDCDatasetFlipped(FRDCDataset):
class FRDCDatasetStaticEval(FRDCDataset):
def __len__(self):
"""Assume that the dataset is 4x larger than it actually is.
"""Assume that the dataset is 8x larger than it actually is.
For example, for index 0, we return the original image. For index 1, we
return the horizontally flipped image and so on, until index 3.
Then, return the next image for index 4, and so on.
There are 8 possible orientations for each image.
1. As-is
2, 3, 4. Rotated 90, 180, 270 degrees
5. Horizontally flipped
6, 7, 8. Horizontally flipped and rotated 90, 180, 270 degrees
"""
return super().__len__() * 4
return super().__len__() * 8

def __getitem__(self, idx):
"""Alter the getitem method to implement the logic above."""
x, y = super().__getitem__(int(idx // 4))
if idx % 4 == 0:
return x, y
elif idx % 4 == 1:
return RandomHorizontalFlip(p=1)(x), y
elif idx % 4 == 2:
return RandomVerticalFlip(p=1)(x), y
elif idx % 4 == 3:
return RandomHorizontalFlip(p=1)(RandomVerticalFlip(p=1)(x)), y
x, y = super().__getitem__(int(idx // 8))
assert x.ndim == 3, "x must be a 3D tensor"
x_ = None
if idx % 8 == 0:
x_ = x
elif idx % 8 == 1:
x_ = rot90(x, 1, (1, 2))
elif idx % 8 == 2:
x_ = rot90(x, 2, (1, 2))
elif idx % 8 == 3:
x_ = rot90(x, 3, (1, 2))
elif idx % 8 == 4:
x_ = hflip(x)
elif idx % 8 == 5:
x_ = hflip(rot90(x, 1, (1, 2)))
elif idx % 8 == 6:
x_ = hflip(rot90(x, 2, (1, 2)))
elif idx % 8 == 7:
x_ = hflip(rot90(x, 3, (1, 2)))

return x_, y


def val_preprocess(size: int):
Expand Down

0 comments on commit 3302063

Please sign in to comment.