From 13b15930a9b1dabc27ee2f235a02e876ad8bf5b8 Mon Sep 17 00:00:00 2001 From: Evening Date: Mon, 10 Jun 2024 17:12:26 +0800 Subject: [PATCH] Allow StandardScaler to override default scaler --- src/frdc/load/dataset.py | 15 ++++++++++++--- src/frdc/load/preset.py | 19 +++++++++++++------ .../chestnut_dec_may/train_fixmatch.py | 9 ++++++--- .../chestnut_dec_may/train_mixmatch.py | 10 ++++++++-- 4 files changed, 39 insertions(+), 14 deletions(-) diff --git a/src/frdc/load/dataset.py b/src/frdc/load/dataset.py index 9ab6cb6..4666283 100644 --- a/src/frdc/load/dataset.py +++ b/src/frdc/load/dataset.py @@ -73,7 +73,7 @@ def __init__( date: str, version: str | None, transform: Callable[[np.ndarray], Any] = lambda x: x, - transform_scale: bool = False, + transform_scale: bool | StandardScaler = True, target_transform: Callable[[str], str] = lambda x: x, use_legacy_bounds: bool = False, polycrop: bool = False, @@ -97,7 +97,9 @@ def __init__( date: The date of the dataset, e.g. "20201218". version: The version of the dataset, e.g. "183deg". transform: The transform to apply to each segment. - transform_scale: Prepends a scaling transform to the transform. + transform_scale: Whether to scale the data. If True, it will fit + a StandardScaler to the data. If a StandardScaler is passed, + it will use that instead. If False, it will not scale the data. target_transform: The transform to apply to each label. use_legacy_bounds: Whether to use the legacy bounds.csv file. This will automatically be set to True if LABEL_STUDIO_CLIENT @@ -129,7 +131,7 @@ def __init__( self.transform = transform self.target_transform = target_transform - if transform_scale: + if transform_scale is True: self.x_scaler = StandardScaler() self.x_scaler.fit( np.concatenate( @@ -146,6 +148,13 @@ def __init__( x.shape ) ) + elif isinstance(transform_scale, StandardScaler): + self.x_scaler = transform_scale + self.transform = lambda x: transform( + self.x_scaler.transform(x.reshape(-1, x.shape[-1])).reshape( + x.shape + ) + ) else: self.x_scaler = None diff --git a/src/frdc/load/preset.py b/src/frdc/load/preset.py index 4d09c88..6fa54df 100644 --- a/src/frdc/load/preset.py +++ b/src/frdc/load/preset.py @@ -6,6 +6,7 @@ import numpy as np import torch +from sklearn.preprocessing import StandardScaler from torchvision.transforms.v2 import ( Compose, ToImage, @@ -48,7 +49,7 @@ class FRDCDatasetPartial: def __call__( self, transform: Callable[[np.ndarray], Any] = lambda x: x, - transform_scale: bool = True, + transform_scale: bool | StandardScaler = True, target_transform: Callable[[str], str] = lambda x: x, use_legacy_bounds: bool = False, polycrop: bool = False, @@ -58,7 +59,9 @@ def __call__( Args: transform: The transform to apply to the data. - transform_scale: Whether to scale the data. + transform_scale: Whether to scale the data. If True, it will fit + a StandardScaler to the data. If a StandardScaler is passed, + it will use that instead. If False, it will not scale the data. target_transform: The transform to apply to the labels. use_legacy_bounds: Whether to use the legacy bounds. polycrop: Whether to use polycrop. @@ -76,7 +79,7 @@ def __call__( def labelled( self, transform: Callable[[np.ndarray], Any] = lambda x: x, - transform_scale: bool = True, + transform_scale: bool | StandardScaler = True, target_transform: Callable[[str], str] = lambda x: x, use_legacy_bounds: bool = False, polycrop: bool = False, @@ -86,7 +89,9 @@ def labelled( Args: transform: The transform to apply to the data. - transform_scale: Whether to scale the data. + transform_scale: Whether to scale the data. If True, it will fit + a StandardScaler to the data. If a StandardScaler is passed, + it will use that instead. If False, it will not scale the data. target_transform: The transform to apply to the labels. use_legacy_bounds: Whether to use the legacy bounds. polycrop: Whether to use polycrop. @@ -107,7 +112,7 @@ def labelled( def unlabelled( self, transform: Callable[[np.ndarray], Any] = lambda x: x, - transform_scale: bool = True, + transform_scale: bool | StandardScaler = True, target_transform: Callable[[str], str] = lambda x: x, use_legacy_bounds: bool = False, polycrop: bool = False, @@ -122,7 +127,9 @@ def unlabelled( Args: transform: The transform to apply to the data. - transform_scale: Whether to scale the data. + transform_scale: Whether to scale the data. If True, it will fit + a StandardScaler to the data. If a StandardScaler is passed, + it will use that instead. If False, it will not scale the data. target_transform: The transform to apply to the labels. use_legacy_bounds: Whether to use the legacy bounds. polycrop: Whether to use polycrop. diff --git a/tests/model_tests/chestnut_dec_may/train_fixmatch.py b/tests/model_tests/chestnut_dec_may/train_fixmatch.py index b5804de..c83bf43 100644 --- a/tests/model_tests/chestnut_dec_may/train_fixmatch.py +++ b/tests/model_tests/chestnut_dec_may/train_fixmatch.py @@ -26,7 +26,6 @@ val_preprocess, FRDCDatasetStaticEval, n_weak_strong_aug, - get_y_encoder, weak_aug, ) @@ -57,9 +56,12 @@ def main( im_size = 255 train_lab_ds = ds.chestnut_20201218(transform=weak_aug(im_size)) train_unl_ds = ds.chestnut_20201218.unlabelled( - transform=n_weak_strong_aug(im_size, unlabelled_factor) + transform=n_weak_strong_aug(im_size, unlabelled_factor), + ) + val_ds = ds.chestnut_20210510_43m( + transform=val_preprocess(im_size), + transform_scale=train_lab_ds.x_scaler, ) - val_ds = ds.chestnut_20210510_43m(transform=val_preprocess(im_size)) # Prepare the datamodule and trainer dm = FRDCDataModule( @@ -115,6 +117,7 @@ def main( "20210510", "90deg43m85pct255deg", transform=val_preprocess(im_size), + transform_scale=train_lab_ds.x_scaler, ), model=m, ) diff --git a/tests/model_tests/chestnut_dec_may/train_mixmatch.py b/tests/model_tests/chestnut_dec_may/train_mixmatch.py index 7b85260..0aab51f 100644 --- a/tests/model_tests/chestnut_dec_may/train_mixmatch.py +++ b/tests/model_tests/chestnut_dec_may/train_mixmatch.py @@ -44,11 +44,16 @@ def main( ): # Prepare the dataset im_size = 299 - train_lab_ds = ds.chestnut_20201218(transform=strong_aug(im_size)) + train_lab_ds = ds.chestnut_20201218( + transform=strong_aug(im_size), + ) train_unl_ds = ds.chestnut_20201218.unlabelled( transform=n_strong_aug(im_size, 2) ) - val_ds = ds.chestnut_20210510_43m(transform=val_preprocess(im_size)) + val_ds = ds.chestnut_20210510_43m( + transform=val_preprocess(im_size), + transform_scale=train_lab_ds.x_scaler, + ) # Prepare the datamodule and trainer dm = FRDCDataModule( @@ -104,6 +109,7 @@ def main( "20210510", "90deg43m85pct255deg", transform=val_preprocess(im_size), + transform_scale=train_lab_ds.x_scaler, ), model=m, )