Skip to content

Commit

Permalink
Merge pull request #72 from FR-DC/frml-153
Browse files Browse the repository at this point in the history
FRML-153 Migrate const rotated dataset spec as variant in presets
  • Loading branch information
Eve-ning authored Jun 11, 2024
2 parents 236c019 + 4fd7875 commit 8e17245
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 63 deletions.
39 changes: 39 additions & 0 deletions src/frdc/load/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
import pandas as pd
from PIL import Image
from sklearn.preprocessing import StandardScaler
from torch import rot90
from torch.utils.data import Dataset, ConcatDataset
from torchvision.transforms.v2.functional import hflip

from frdc.conf import (
BAND_CONFIG,
Expand Down Expand Up @@ -334,3 +336,40 @@ def __getitem__(self, item):
if self.transform
else self.ar_segments[item]
)


class FRDCConstRotatedDataset(FRDCDataset):
def __len__(self):
"""Assume that the dataset is 8x larger than it actually is.
There are 8 possible orientations for each image.
1. As-is
2, 3, 4. Rotated 90, 180, 270 degrees
5. Horizontally flipped
6, 7, 8. Horizontally flipped and rotated 90, 180, 270 degrees
"""
return super().__len__() * 8

def __getitem__(self, idx):
"""Alter the getitem method to implement the logic above."""
x, y = super().__getitem__(int(idx // 8))
assert x.ndim == 3, "x must be a 3D tensor"
x_ = None
if idx % 8 == 0:
x_ = x
elif idx % 8 == 1:
x_ = rot90(x, 1, (1, 2))
elif idx % 8 == 2:
x_ = rot90(x, 2, (1, 2))
elif idx % 8 == 3:
x_ = rot90(x, 3, (1, 2))
elif idx % 8 == 4:
x_ = hflip(x)
elif idx % 8 == 5:
x_ = hflip(rot90(x, 1, (1, 2)))
elif idx % 8 == 6:
x_ = hflip(rot90(x, 2, (1, 2)))
elif idx % 8 == 7:
x_ = hflip(rot90(x, 3, (1, 2)))

return x_, y
44 changes: 43 additions & 1 deletion src/frdc/load/preset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
Resize,
)

from frdc.load.dataset import FRDCDataset, FRDCUnlabelledDataset
from frdc.load.dataset import (
FRDCDataset,
FRDCUnlabelledDataset,
FRDCConstRotatedDataset,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -147,6 +151,44 @@ def unlabelled(
polycrop_value=polycrop_value,
)

def const_rotated(
self,
transform: Callable[[np.ndarray], Any] = lambda x: x,
transform_scale: bool | StandardScaler = True,
target_transform: Callable[[str], str] = lambda x: x,
use_legacy_bounds: bool = False,
polycrop: bool = False,
polycrop_value: Any = np.nan,
):
"""Returns the Unlabelled Dataset.
Notes:
This simply masks away the labels during __getitem__.
The same behaviour can be achieved by setting __class__ to
FRDCUnlabelledDataset, but this is a more convenient way to do so.
Args:
transform: The transform to apply to the data.
transform_scale: Whether to scale the data. If True, it will fit
a StandardScaler to the data. If a StandardScaler is passed,
it will use that instead. If False, it will not scale the data.
target_transform: The transform to apply to the labels.
use_legacy_bounds: Whether to use the legacy bounds.
polycrop: Whether to use polycrop.
polycrop_value: The value to use for polycrop.
"""
return FRDCConstRotatedDataset(
self.site,
self.date,
self.version,
transform=transform,
transform_scale=transform_scale,
target_transform=target_transform,
use_legacy_bounds=use_legacy_bounds,
polycrop=polycrop,
polycrop_value=polycrop_value,
)


@dataclass
class FRDCDatasetPreset:
Expand Down
26 changes: 12 additions & 14 deletions tests/model_tests/chestnut_dec_may/train_fixmatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,28 +24,29 @@
from frdc.utils.training import predict, plot_confusion_matrix
from model_tests.utils import (
const_weak_aug,
FRDCDatasetStaticEval,
n_rand_weak_strong_aug,
rand_weak_aug,
)


# %%

# Uncomment this to run the W&B monitoring locally
# import os
#
# os.environ["WANDB_MODE"] = "offline"


def main(
batch_size=32,
epochs=10,
train_iters=25,
unlabelled_factor=2,
lr=1e-3,
accelerator="gpu",
wandb_active: bool = True,
wandb_name="chestnut_dec_may",
wandb_project="frdc",
batch_size=32,
epochs=10,
train_iters=25,
unlabelled_factor=2,
lr=1e-3,
accelerator="gpu",
wandb_active: bool = True,
wandb_name="chestnut_dec_may",
wandb_project="frdc",
):
if not wandb_active:
import os
Expand Down Expand Up @@ -112,10 +113,7 @@ def main(
)

y_true, y_pred = predict(
ds=FRDCDatasetStaticEval(
"chestnut_nature_park",
"20210510",
"90deg43m85pct255deg",
ds=ds.chestnut_20210510_43m.const_rotated(
transform=const_weak_aug(im_size),
transform_scale=train_lab_ds.x_scaler,
),
Expand Down
7 changes: 2 additions & 5 deletions tests/model_tests/chestnut_dec_may/train_mixmatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
)
from lightning.pytorch.loggers import WandbLogger

from frdc.load.dataset import FRDCConstRotatedDataset
from frdc.load.preset import FRDCDatasetPreset as ds
from frdc.models.efficientnetb1 import EfficientNetB1MixMatchModule
from frdc.train.frdc_datamodule import FRDCDataModule
from frdc.utils.training import predict, plot_confusion_matrix
from model_tests.utils import (
const_weak_aug,
FRDCDatasetStaticEval,
n_rand_strong_aug,
rand_strong_aug,
)
Expand Down Expand Up @@ -104,10 +104,7 @@ def main(
)

y_true, y_pred = predict(
ds=FRDCDatasetStaticEval(
"chestnut_nature_park",
"20210510",
"90deg43m85pct255deg",
ds=ds.chestnut_20210510_43m.const_rotated(
transform=const_weak_aug(im_size),
transform_scale=train_lab_ds.x_scaler,
),
Expand Down
43 changes: 0 additions & 43 deletions tests/model_tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@

from pathlib import Path

import numpy as np
import torch
from sklearn.preprocessing import OrdinalEncoder, StandardScaler
from torch import rot90
from torchvision.transforms import RandomVerticalFlip
from torchvision.transforms.v2 import (
Compose,
Expand All @@ -19,52 +16,12 @@
Resize,
)
from torchvision.transforms.v2 import RandomHorizontalFlip
from torchvision.transforms.v2.functional import hflip

from frdc.load.dataset import FRDCDataset

THIS_DIR = Path(__file__).parent

BANDS = ["NB", "NG", "NR", "RE", "NIR"]


class FRDCDatasetStaticEval(FRDCDataset):
def __len__(self):
"""Assume that the dataset is 8x larger than it actually is.
There are 8 possible orientations for each image.
1. As-is
2, 3, 4. Rotated 90, 180, 270 degrees
5. Horizontally flipped
6, 7, 8. Horizontally flipped and rotated 90, 180, 270 degrees
"""
return super().__len__() * 8

def __getitem__(self, idx):
"""Alter the getitem method to implement the logic above."""
x, y = super().__getitem__(int(idx // 8))
assert x.ndim == 3, "x must be a 3D tensor"
x_ = None
if idx % 8 == 0:
x_ = x
elif idx % 8 == 1:
x_ = rot90(x, 1, (1, 2))
elif idx % 8 == 2:
x_ = rot90(x, 2, (1, 2))
elif idx % 8 == 3:
x_ = rot90(x, 3, (1, 2))
elif idx % 8 == 4:
x_ = hflip(x)
elif idx % 8 == 5:
x_ = hflip(rot90(x, 1, (1, 2)))
elif idx % 8 == 6:
x_ = hflip(rot90(x, 2, (1, 2)))
elif idx % 8 == 7:
x_ = hflip(rot90(x, 3, (1, 2)))

return x_, y


def n_times(f, n: int):
return lambda x: [f(x) for _ in range(n)]

Expand Down

0 comments on commit 8e17245

Please sign in to comment.