Skip to content

Commit

Permalink
Allow StandardScaler to override default scaler
Browse files Browse the repository at this point in the history
  • Loading branch information
Eve-ning committed Jun 10, 2024
1 parent ab9de7c commit 13b1593
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 14 deletions.
15 changes: 12 additions & 3 deletions src/frdc/load/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __init__(
date: str,
version: str | None,
transform: Callable[[np.ndarray], Any] = lambda x: x,
transform_scale: bool = False,
transform_scale: bool | StandardScaler = True,
target_transform: Callable[[str], str] = lambda x: x,
use_legacy_bounds: bool = False,
polycrop: bool = False,
Expand All @@ -97,7 +97,9 @@ def __init__(
date: The date of the dataset, e.g. "20201218".
version: The version of the dataset, e.g. "183deg".
transform: The transform to apply to each segment.
transform_scale: Prepends a scaling transform to the transform.
transform_scale: Whether to scale the data. If True, a new
StandardScaler is fitted to this dataset's data. If an
already-fitted StandardScaler is passed, it is used to transform
the data without refitting. If False, no scaling is applied.
target_transform: The transform to apply to each label.
use_legacy_bounds: Whether to use the legacy bounds.csv file.
This will automatically be set to True if LABEL_STUDIO_CLIENT
Expand Down Expand Up @@ -129,7 +131,7 @@ def __init__(
self.transform = transform
self.target_transform = target_transform

if transform_scale:
if transform_scale is True:
self.x_scaler = StandardScaler()
self.x_scaler.fit(
np.concatenate(
Expand All @@ -146,6 +148,13 @@ def __init__(
x.shape
)
)
elif isinstance(transform_scale, StandardScaler):
self.x_scaler = transform_scale
self.transform = lambda x: transform(
self.x_scaler.transform(x.reshape(-1, x.shape[-1])).reshape(
x.shape
)
)
else:
self.x_scaler = None

Expand Down
19 changes: 13 additions & 6 deletions src/frdc/load/preset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import numpy as np
import torch
from sklearn.preprocessing import StandardScaler
from torchvision.transforms.v2 import (
Compose,
ToImage,
Expand Down Expand Up @@ -48,7 +49,7 @@ class FRDCDatasetPartial:
def __call__(
self,
transform: Callable[[np.ndarray], Any] = lambda x: x,
transform_scale: bool = True,
transform_scale: bool | StandardScaler = True,
target_transform: Callable[[str], str] = lambda x: x,
use_legacy_bounds: bool = False,
polycrop: bool = False,
Expand All @@ -58,7 +59,9 @@ def __call__(
Args:
transform: The transform to apply to the data.
transform_scale: Whether to scale the data.
transform_scale: Whether to scale the data. If True, a new
StandardScaler is fitted to this dataset's data. If an
already-fitted StandardScaler is passed, it is used to transform
the data without refitting. If False, no scaling is applied.
target_transform: The transform to apply to the labels.
use_legacy_bounds: Whether to use the legacy bounds.
polycrop: Whether to use polycrop.
Expand All @@ -76,7 +79,7 @@ def __call__(
def labelled(
self,
transform: Callable[[np.ndarray], Any] = lambda x: x,
transform_scale: bool = True,
transform_scale: bool | StandardScaler = True,
target_transform: Callable[[str], str] = lambda x: x,
use_legacy_bounds: bool = False,
polycrop: bool = False,
Expand All @@ -86,7 +89,9 @@ def labelled(
Args:
transform: The transform to apply to the data.
transform_scale: Whether to scale the data.
transform_scale: Whether to scale the data. If True, a new
StandardScaler is fitted to this dataset's data. If an
already-fitted StandardScaler is passed, it is used to transform
the data without refitting. If False, no scaling is applied.
target_transform: The transform to apply to the labels.
use_legacy_bounds: Whether to use the legacy bounds.
polycrop: Whether to use polycrop.
Expand All @@ -107,7 +112,7 @@ def labelled(
def unlabelled(
self,
transform: Callable[[np.ndarray], Any] = lambda x: x,
transform_scale: bool = True,
transform_scale: bool | StandardScaler = True,
target_transform: Callable[[str], str] = lambda x: x,
use_legacy_bounds: bool = False,
polycrop: bool = False,
Expand All @@ -122,7 +127,9 @@ def unlabelled(
Args:
transform: The transform to apply to the data.
transform_scale: Whether to scale the data.
transform_scale: Whether to scale the data. If True, a new
StandardScaler is fitted to this dataset's data. If an
already-fitted StandardScaler is passed, it is used to transform
the data without refitting. If False, no scaling is applied.
target_transform: The transform to apply to the labels.
use_legacy_bounds: Whether to use the legacy bounds.
polycrop: Whether to use polycrop.
Expand Down
9 changes: 6 additions & 3 deletions tests/model_tests/chestnut_dec_may/train_fixmatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
val_preprocess,
FRDCDatasetStaticEval,
n_weak_strong_aug,
get_y_encoder,
weak_aug,
)

Expand Down Expand Up @@ -57,9 +56,12 @@ def main(
im_size = 255
train_lab_ds = ds.chestnut_20201218(transform=weak_aug(im_size))
train_unl_ds = ds.chestnut_20201218.unlabelled(
transform=n_weak_strong_aug(im_size, unlabelled_factor)
transform=n_weak_strong_aug(im_size, unlabelled_factor),
)
val_ds = ds.chestnut_20210510_43m(
transform=val_preprocess(im_size),
transform_scale=train_lab_ds.x_scaler,
)
val_ds = ds.chestnut_20210510_43m(transform=val_preprocess(im_size))

# Prepare the datamodule and trainer
dm = FRDCDataModule(
Expand Down Expand Up @@ -115,6 +117,7 @@ def main(
"20210510",
"90deg43m85pct255deg",
transform=val_preprocess(im_size),
transform_scale=train_lab_ds.x_scaler,
),
model=m,
)
Expand Down
10 changes: 8 additions & 2 deletions tests/model_tests/chestnut_dec_may/train_mixmatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,16 @@ def main(
):
# Prepare the dataset
im_size = 299
train_lab_ds = ds.chestnut_20201218(transform=strong_aug(im_size))
train_lab_ds = ds.chestnut_20201218(
transform=strong_aug(im_size),
)
train_unl_ds = ds.chestnut_20201218.unlabelled(
transform=n_strong_aug(im_size, 2)
)
val_ds = ds.chestnut_20210510_43m(transform=val_preprocess(im_size))
val_ds = ds.chestnut_20210510_43m(
transform=val_preprocess(im_size),
transform_scale=train_lab_ds.x_scaler,
)

# Prepare the datamodule and trainer
dm = FRDCDataModule(
Expand Down Expand Up @@ -104,6 +109,7 @@ def main(
"20210510",
"90deg43m85pct255deg",
transform=val_preprocess(im_size),
transform_scale=train_lab_ds.x_scaler,
),
model=m,
)
Expand Down

0 comments on commit 13b1593

Please sign in to comment.