
FRML-146 Migrate Preprocessing toward dataset for independent dataset processing #71

Merged
merged 4 commits on Jun 10, 2024
45 changes: 37 additions & 8 deletions src/frdc/load/dataset.py
@@ -9,6 +9,7 @@
import numpy as np
import pandas as pd
from PIL import Image
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset, ConcatDataset

from frdc.conf import (
@@ -71,8 +72,9 @@ def __init__(
site: str,
date: str,
version: str | None,
transform: Callable[[list[np.ndarray]], Any] = None,
target_transform: Callable[[list[str]], list[str]] = None,
transform: Callable[[np.ndarray], Any] = lambda x: x,
transform_scale: bool | StandardScaler = True,
target_transform: Callable[[str], str] = lambda x: x,
use_legacy_bounds: bool = False,
polycrop: bool = False,
polycrop_value: Any = np.nan,
@@ -95,13 +97,17 @@
date: The date of the dataset, e.g. "20201218".
version: The version of the dataset, e.g. "183deg".
transform: The transform to apply to each segment.
transform_scale: Whether to scale the data. If True, it will fit
a StandardScaler to the data. If a StandardScaler is passed,
it will use that instead. If False, it will not scale the data.
target_transform: The transform to apply to each label.
use_legacy_bounds: Whether to use the legacy bounds.csv file.
This will automatically be set to True if LABEL_STUDIO_CLIENT
is None, which happens when Label Studio cannot be connected
to.
polycrop: Whether to further crop the segments via their polygon
bounds. The cropped area is padded with polycrop_value.
polycrop_value: The value to pad the cropped area with.
"""
self.site = site
self.date = date
@@ -125,17 +131,40 @@ def __init__(
self.transform = transform
self.target_transform = target_transform

if transform_scale is True:
self.x_scaler = StandardScaler()
self.x_scaler.fit(
np.concatenate(
[
# Segments: [H x W x C] -> [H*W, C]
# Reshaping is necessary for StandardScaler
segm.reshape(-1, segm.shape[-1])
for segm in self.ar_segments
]
)
)
self.transform = lambda x: transform(
self.x_scaler.transform(x.reshape(-1, x.shape[-1])).reshape(
x.shape
)
)
elif isinstance(transform_scale, StandardScaler):
self.x_scaler = transform_scale
self.transform = lambda x: transform(
self.x_scaler.transform(x.reshape(-1, x.shape[-1])).reshape(
x.shape
)
)
else:
self.x_scaler = None

def __len__(self):
return len(self.ar_segments)

def __getitem__(self, idx):
return (
self.transform(self.ar_segments[idx])
if self.transform
else self.ar_segments[idx],
self.target_transform(self.targets[idx])
if self.target_transform
else self.targets[idx],
self.transform(self.ar_segments[idx]),
self.target_transform(self.targets[idx]),
)

@property
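A standalone sketch of the reshape-fit-transform round trip the dataset now performs, mirroring the scaling logic above; the segment shapes are illustrative:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Two segments of differing H x W but a shared channel count C.
segments = [np.random.rand(32, 32, 8), np.random.rand(16, 24, 8)]

scaler = StandardScaler()
# Each [H, W, C] segment flattens to [H*W, C]; stacking them lets the
# scaler fit per-channel statistics across every pixel of every segment.
scaler.fit(np.concatenate([s.reshape(-1, s.shape[-1]) for s in segments]))

def scale(x: np.ndarray) -> np.ndarray:
    # Flatten, scale per channel, then restore the original shape:
    # the same round trip as the wrapped transform in __init__.
    return scaler.transform(x.reshape(-1, x.shape[-1])).reshape(x.shape)

assert scale(segments[0]).shape == (32, 32, 8)
```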
64 changes: 51 additions & 13 deletions src/frdc/load/preset.py
@@ -6,6 +6,7 @@

import numpy as np
import torch
from sklearn.preprocessing import StandardScaler
from torchvision.transforms.v2 import (
Compose,
ToImage,
@@ -47,15 +48,28 @@ class FRDCDatasetPartial:

def __call__(
self,
transform: Callable[[list[np.ndarray]], Any] = None,
target_transform: Callable[[list[str]], list[str]] = None,
transform: Callable[[np.ndarray], Any] = lambda x: x,
transform_scale: bool | StandardScaler = True,
target_transform: Callable[[str], str] = lambda x: x,
use_legacy_bounds: bool = False,
polycrop: bool = False,
polycrop_value: Any = np.nan,
):
"""Alias for labelled()."""
"""Alias for labelled().

Args:
transform: The transform to apply to the data.
transform_scale: Whether to scale the data. If True, it will fit
a StandardScaler to the data. If a StandardScaler is passed,
it will use that instead. If False, it will not scale the data.
target_transform: The transform to apply to the labels.
use_legacy_bounds: Whether to use the legacy bounds.csv file.
polycrop: Whether to further crop the segments via their polygon
bounds.
polycrop_value: The value to pad the cropped area with.
"""
return self.labelled(
transform,
transform_scale,
target_transform,
use_legacy_bounds=use_legacy_bounds,
polycrop=polycrop,
@@ -64,28 +78,42 @@ def __call__(

def labelled(
self,
transform: Callable[[list[np.ndarray]], Any] = None,
target_transform: Callable[[list[str]], list[str]] = None,
transform: Callable[[np.ndarray], Any] = lambda x: x,
transform_scale: bool | StandardScaler = True,
target_transform: Callable[[str], str] = lambda x: x,
use_legacy_bounds: bool = False,
polycrop: bool = False,
polycrop_value: Any = np.nan,
):
"""Returns the Labelled Dataset."""
"""Returns the Labelled Dataset.

Args:
transform: The transform to apply to the data.
transform_scale: Whether to scale the data. If True, it will fit
a StandardScaler to the data. If a StandardScaler is passed,
it will use that instead. If False, it will not scale the data.
target_transform: The transform to apply to the labels.
use_legacy_bounds: Whether to use the legacy bounds.csv file.
polycrop: Whether to further crop the segments via their polygon
bounds.
polycrop_value: The value to pad the cropped area with.
"""
return FRDCDataset(
self.site,
self.date,
self.version,
transform,
target_transform,
transform=transform,
transform_scale=transform_scale,
target_transform=target_transform,
use_legacy_bounds=use_legacy_bounds,
polycrop=polycrop,
polycrop_value=polycrop_value,
)

def unlabelled(
self,
transform: Callable[[list[np.ndarray]], Any] = None,
target_transform: Callable[[list[str]], list[str]] = None,
transform: Callable[[np.ndarray], Any] = lambda x: x,
transform_scale: bool | StandardScaler = True,
target_transform: Callable[[str], str] = lambda x: x,
use_legacy_bounds: bool = False,
polycrop: bool = False,
polycrop_value: Any = np.nan,
@@ -96,13 +124,24 @@
This simply masks away the labels during __getitem__.
The same behaviour can be achieved by setting __class__ to
FRDCUnlabelledDataset, but this is a more convenient way to do so.

Args:
transform: The transform to apply to the data.
transform_scale: Whether to scale the data. If True, it will fit
a StandardScaler to the data. If a StandardScaler is passed,
it will use that instead. If False, it will not scale the data.
target_transform: The transform to apply to the labels.
use_legacy_bounds: Whether to use the legacy bounds.csv file.
polycrop: Whether to further crop the segments via their polygon
bounds.
polycrop_value: The value to pad the cropped area with.
"""
return FRDCUnlabelledDataset(
self.site,
self.date,
self.version,
transform,
target_transform,
transform=transform,
transform_scale=transform_scale,
target_transform=target_transform,
use_legacy_bounds=use_legacy_bounds,
polycrop=polycrop,
polycrop_value=polycrop_value,
@@ -167,5 +206,4 @@ class FRDCDatasetPreset:
Resize((resize, resize)),
]
),
target_transform=None,
)
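A hedged usage sketch of the partial API with the new `transform_scale` keyword. `chestnut_20201218` is a hypothetical preset name, and accessing presets as class attributes is an assumption; only the method and keyword names come from this diff:

```python
from frdc.load.preset import FRDCDatasetPreset

# Hypothetical preset; substitute one that exists in your checkout.
partial = FRDCDatasetPreset.chestnut_20201218

# labelled() (or calling the partial directly) fits a StandardScaler
# on the labelled segments by default (transform_scale=True).
train = partial.labelled()

# Reuse the fitted scaler so the unlabelled split is standardised with
# the training statistics instead of fitting its own.
unlabelled = partial.unlabelled(transform_scale=train.x_scaler)
```

Passing `transform_scale=False` disables scaling entirely, in which case `x_scaler` is `None`.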
34 changes: 9 additions & 25 deletions src/frdc/models/efficientnetb1.py
@@ -1,16 +1,14 @@
from copy import deepcopy
from typing import Dict, Any
from typing import Sequence

import torch
from sklearn.preprocessing import OrdinalEncoder, StandardScaler
from torch import nn
from torchvision.models import (
EfficientNet,
efficientnet_b1,
EfficientNet_B1_Weights,
)

from frdc.models.utils import save_unfrozen, load_checkpoint_lenient
from frdc.train.fixmatch_module import FixMatchModule
from frdc.train.mixmatch_module import MixMatchModule
from frdc.utils.ema import EMA
@@ -81,10 +79,8 @@ def __init__(
self,
*,
in_channels: int,
n_classes: int,
out_targets: Sequence[str],
lr: float,
x_scaler: StandardScaler,
y_encoder: OrdinalEncoder,
ema_lr: float = 0.001,
weight_decay: float = 1e-5,
frozen: bool = True,
@@ -93,10 +89,8 @@

Args:
in_channels: The number of input channels.
n_classes: The number of classes.
out_targets: The output targets.
lr: The learning rate.
x_scaler: The X input StandardScaler.
y_encoder: The Y input OrdinalEncoder.
ema_lr: The learning rate for the EMA model.
weight_decay: The weight decay.
frozen: Whether to freeze the base model.
@@ -108,16 +102,14 @@
self.weight_decay = weight_decay

super().__init__(
n_classes=n_classes,
x_scaler=x_scaler,
y_encoder=y_encoder,
out_targets=out_targets,
sharpen_temp=0.5,
mix_beta_alpha=0.75,
)

self.eff = efficientnet_b1_backbone(in_channels, frozen)
self.fc = nn.Sequential(
nn.Linear(self.EFF_OUT_DIMS, n_classes),
nn.Linear(self.EFF_OUT_DIMS, self.n_classes),
nn.Softmax(dim=1),
)

@@ -155,21 +147,17 @@ def __init__(
self,
*,
in_channels: int,
n_classes: int,
out_targets: Sequence[str],
lr: float,
x_scaler: StandardScaler,
y_encoder: OrdinalEncoder,
weight_decay: float = 1e-5,
frozen: bool = True,
):
"""Initialize the EfficientNet model.

Args:
in_channels: The number of input channels.
n_classes: The number of classes.
out_targets: The output targets.
lr: The learning rate.
x_scaler: The X input StandardScaler.
y_encoder: The Y input OrdinalEncoder.
weight_decay: The weight decay.
frozen: Whether to freeze the base model.

@@ -179,16 +167,12 @@ def __init__(
self.lr = lr
self.weight_decay = weight_decay

super().__init__(
n_classes=n_classes,
x_scaler=x_scaler,
y_encoder=y_encoder,
)
super().__init__(out_targets=out_targets)

self.eff = efficientnet_b1_backbone(in_channels, frozen)

self.fc = nn.Sequential(
nn.Linear(self.EFF_OUT_DIMS, n_classes),
nn.Linear(self.EFF_OUT_DIMS, self.n_classes),
nn.Softmax(dim=1),
)

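A construction sketch under the new module signature: `n_classes`, `x_scaler`, and `y_encoder` are gone, and `out_targets` carries the label names. The class name `EfficientNetB1MixMatchModule` and all argument values are assumptions, since the diff shows only the constructor bodies:

```python
from frdc.models.efficientnetb1 import EfficientNetB1MixMatchModule

model = EfficientNetB1MixMatchModule(
    in_channels=8,  # e.g. 8 spectral bands; placeholder value
    out_targets=["species_a", "species_b", "species_c"],  # placeholder labels
    lr=1e-3,
)
# The parent module presumably derives self.n_classes from out_targets,
# which is why the classifier head now reads self.n_classes.
```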