From 02a4d57e794ada0bb347290fa97ae0e52ee13c52 Mon Sep 17 00:00:00 2001 From: Evening Date: Fri, 7 Jun 2024 22:19:43 +0800 Subject: [PATCH 1/4] Migrate x as preprocessing step in dataset --- src/frdc/load/dataset.py | 36 ++++++++--- src/frdc/load/preset.py | 57 ++++++++++++++---- src/frdc/models/efficientnetb1.py | 6 -- src/frdc/train/fixmatch_module.py | 4 -- src/frdc/train/mixmatch_module.py | 3 - src/frdc/train/utils.py | 59 ++----------------- .../chestnut_dec_may/train_fixmatch.py | 3 - tests/model_tests/utils.py | 8 --- 8 files changed, 76 insertions(+), 100 deletions(-) diff --git a/src/frdc/load/dataset.py b/src/frdc/load/dataset.py index e0a0d9c0..9ab6cb6c 100644 --- a/src/frdc/load/dataset.py +++ b/src/frdc/load/dataset.py @@ -9,6 +9,7 @@ import numpy as np import pandas as pd from PIL import Image +from sklearn.preprocessing import StandardScaler from torch.utils.data import Dataset, ConcatDataset from frdc.conf import ( @@ -71,8 +72,9 @@ def __init__( site: str, date: str, version: str | None, - transform: Callable[[list[np.ndarray]], Any] = None, - target_transform: Callable[[list[str]], list[str]] = None, + transform: Callable[[np.ndarray], Any] = lambda x: x, + transform_scale: bool = False, + target_transform: Callable[[str], str] = lambda x: x, use_legacy_bounds: bool = False, polycrop: bool = False, polycrop_value: Any = np.nan, @@ -95,6 +97,7 @@ def __init__( date: The date of the dataset, e.g. "20201218". version: The version of the dataset, e.g. "183deg". transform: The transform to apply to each segment. + transform_scale: Prepends a scaling transform to the transform. target_transform: The transform to apply to each label. use_legacy_bounds: Whether to use the legacy bounds.csv file. This will automatically be set to True if LABEL_STUDIO_CLIENT @@ -102,6 +105,7 @@ def __init__( to. polycrop: Whether to further crop the segments via its polygon bounds. The cropped area will be padded with np.nan. + polycrop_value: The value to pad the cropped area with. 
""" self.site = site self.date = date @@ -125,17 +129,33 @@ def __init__( self.transform = transform self.target_transform = target_transform + if transform_scale: + self.x_scaler = StandardScaler() + self.x_scaler.fit( + np.concatenate( + [ + # Segments: [H x W x C] -> [H*W, C] + # Reshaping is necessary for StandardScaler + segm.reshape(-1, segm.shape[-1]) + for segm in self.ar_segments + ] + ) + ) + self.transform = lambda x: transform( + self.x_scaler.transform(x.reshape(-1, x.shape[-1])).reshape( + x.shape + ) + ) + else: + self.x_scaler = None + def __len__(self): return len(self.ar_segments) def __getitem__(self, idx): return ( - self.transform(self.ar_segments[idx]) - if self.transform - else self.ar_segments[idx], - self.target_transform(self.targets[idx]) - if self.target_transform - else self.targets[idx], + self.transform(self.ar_segments[idx]), + self.target_transform(self.targets[idx]), ) @property diff --git a/src/frdc/load/preset.py b/src/frdc/load/preset.py index da0aedcb..4d09c88d 100644 --- a/src/frdc/load/preset.py +++ b/src/frdc/load/preset.py @@ -47,15 +47,26 @@ class FRDCDatasetPartial: def __call__( self, - transform: Callable[[list[np.ndarray]], Any] = None, - target_transform: Callable[[list[str]], list[str]] = None, + transform: Callable[[np.ndarray], Any] = lambda x: x, + transform_scale: bool = True, + target_transform: Callable[[str], str] = lambda x: x, use_legacy_bounds: bool = False, polycrop: bool = False, polycrop_value: Any = np.nan, ): - """Alias for labelled().""" + """Alias for labelled(). + + Args: + transform: The transform to apply to the data. + transform_scale: Whether to scale the data. + target_transform: The transform to apply to the labels. + use_legacy_bounds: Whether to use the legacy bounds. + polycrop: Whether to use polycrop. + polycrop_value: The value to use for polycrop. + """ return self.labelled( transform, + transform_scale, target_transform, use_legacy_bounds=use_legacy_bounds, polycrop=polycrop, @@ -64,19 +75,30 @@ def __call__( def labelled( self, - transform: Callable[[list[np.ndarray]], Any] = None, - target_transform: Callable[[list[str]], list[str]] = None, + transform: Callable[[np.ndarray], Any] = lambda x: x, + transform_scale: bool = True, + target_transform: Callable[[str], str] = lambda x: x, use_legacy_bounds: bool = False, polycrop: bool = False, polycrop_value: Any = np.nan, ): - """Returns the Labelled Dataset.""" + """Returns the Labelled Dataset. + + Args: + transform: The transform to apply to the data. + transform_scale: Whether to scale the data. + target_transform: The transform to apply to the labels. + use_legacy_bounds: Whether to use the legacy bounds. + polycrop: Whether to use polycrop. + polycrop_value: The value to use for polycrop. + """ return FRDCDataset( self.site, self.date, self.version, - transform, - target_transform, + transform=transform, + transform_scale=transform_scale, + target_transform=target_transform, use_legacy_bounds=use_legacy_bounds, polycrop=polycrop, polycrop_value=polycrop_value, @@ -84,8 +106,9 @@ def labelled( def unlabelled( self, - transform: Callable[[list[np.ndarray]], Any] = None, - target_transform: Callable[[list[str]], list[str]] = None, + transform: Callable[[np.ndarray], Any] = lambda x: x, + transform_scale: bool = True, + target_transform: Callable[[str], str] = lambda x: x, use_legacy_bounds: bool = False, polycrop: bool = False, polycrop_value: Any = np.nan, @@ -96,13 +119,22 @@ def unlabelled( This simply masks away the labels during __getitem__. 
The same behaviour can be achieved by setting __class__ to FRDCUnlabelledDataset, but this is a more convenient way to do so. + + Args: + transform: The transform to apply to the data. + transform_scale: Whether to scale the data. + target_transform: The transform to apply to the labels. + use_legacy_bounds: Whether to use the legacy bounds. + polycrop: Whether to use polycrop. + polycrop_value: The value to use for polycrop. """ return FRDCUnlabelledDataset( self.site, self.date, self.version, - transform, - target_transform, + transform=transform, + transform_scale=transform_scale, + target_transform=target_transform, use_legacy_bounds=use_legacy_bounds, polycrop=polycrop, polycrop_value=polycrop_value, @@ -167,5 +199,4 @@ class FRDCDatasetPreset: Resize((resize, resize)), ] ), - target_transform=None, ) diff --git a/src/frdc/models/efficientnetb1.py b/src/frdc/models/efficientnetb1.py index 5d5b4ef1..d793249c 100644 --- a/src/frdc/models/efficientnetb1.py +++ b/src/frdc/models/efficientnetb1.py @@ -83,7 +83,6 @@ def __init__( in_channels: int, n_classes: int, lr: float, - x_scaler: StandardScaler, y_encoder: OrdinalEncoder, ema_lr: float = 0.001, weight_decay: float = 1e-5, @@ -95,7 +94,6 @@ def __init__( in_channels: The number of input channels. n_classes: The number of classes. lr: The learning rate. - x_scaler: The X input StandardScaler. y_encoder: The Y input OrdinalEncoder. ema_lr: The learning rate for the EMA model. weight_decay: The weight decay. @@ -109,7 +107,6 @@ def __init__( super().__init__( n_classes=n_classes, - x_scaler=x_scaler, y_encoder=y_encoder, sharpen_temp=0.5, mix_beta_alpha=0.75, @@ -157,7 +154,6 @@ def __init__( in_channels: int, n_classes: int, lr: float, - x_scaler: StandardScaler, y_encoder: OrdinalEncoder, weight_decay: float = 1e-5, frozen: bool = True, @@ -168,7 +164,6 @@ def __init__( in_channels: The number of input channels. n_classes: The number of classes. lr: The learning rate. - x_scaler: The X input StandardScaler. y_encoder: The Y input OrdinalEncoder. weight_decay: The weight decay. frozen: Whether to freeze the base model. @@ -181,7 +176,6 @@ def __init__( super().__init__( n_classes=n_classes, - x_scaler=x_scaler, y_encoder=y_encoder, ) diff --git a/src/frdc/train/fixmatch_module.py b/src/frdc/train/fixmatch_module.py index 0d8cd80c..93e8c9ab 100644 --- a/src/frdc/train/fixmatch_module.py +++ b/src/frdc/train/fixmatch_module.py @@ -21,7 +21,6 @@ class FixMatchModule(LightningModule): def __init__( self, *, - x_scaler: StandardScaler, y_encoder: OrdinalEncoder, n_classes: int = 10, unl_conf_threshold: float = 0.95, @@ -41,7 +40,6 @@ def __init__( Args: n_classes: The number of classes in the dataset. - x_scaler: The StandardScaler to use for the data. y_encoder: The OrdinalEncoder to use for the labels. unl_conf_threshold: The confidence threshold for unlabelled data to be considered correctly labelled. 
@@ -49,7 +47,6 @@ def __init__( super().__init__() - self.x_scaler = x_scaler self.y_encoder = y_encoder self.n_classes = n_classes self.unl_conf_threshold = unl_conf_threshold @@ -244,7 +241,6 @@ def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: return preprocess( x_lbl=x_lbl, y_lbl=y_lbl, - x_scaler=self.x_scaler, y_encoder=self.y_encoder, x_unl=x_unl, ) diff --git a/src/frdc/train/mixmatch_module.py b/src/frdc/train/mixmatch_module.py index 96177517..cf6442ae 100644 --- a/src/frdc/train/mixmatch_module.py +++ b/src/frdc/train/mixmatch_module.py @@ -24,7 +24,6 @@ class MixMatchModule(LightningModule): def __init__( self, *, - x_scaler: StandardScaler, y_encoder: OrdinalEncoder, n_classes: int = 10, sharpen_temp: float = 0.5, @@ -52,7 +51,6 @@ def __init__( super().__init__() - self.x_scaler = x_scaler self.y_encoder = y_encoder self.n_classes = n_classes self.sharpen_temp = sharpen_temp @@ -257,7 +255,6 @@ def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: return preprocess( x_lbl=x_lbl, y_lbl=y_lbl, - x_scaler=self.x_scaler, y_encoder=self.y_encoder, x_unl=x_unl, ) diff --git a/src/frdc/train/utils.py b/src/frdc/train/utils.py index 6e85fe7a..bb1b2e47 100644 --- a/src/frdc/train/utils.py +++ b/src/frdc/train/utils.py @@ -55,7 +55,6 @@ def sharpen(y: torch.Tensor, temp: float) -> torch.Tensor: def preprocess( x_lbl: torch.Tensor, y_lbl: torch.Tensor, - x_scaler: StandardScaler, y_encoder: OrdinalEncoder, x_unl: list[torch.Tensor] = None, nan_mask: bool = True, @@ -72,7 +71,6 @@ def preprocess( Args: x_lbl: The data to preprocess. y_lbl: The labels to preprocess. - x_scaler: The StandardScaler to use. y_encoder: The OrdinalEncoder to use. x_unl: The unlabelled data to preprocess. nan_mask: Whether to remove nan values from the batch. @@ -83,13 +81,8 @@ def preprocess( x_unl = [] if x_unl is None else x_unl - x_lbl_trans = x_standard_scale(x_scaler, x_lbl) - y_trans = y_encode(y_encoder, y_lbl) - x_unl_trans = fn_recursive( - x_unl, - fn=lambda x: x_standard_scale(x_scaler, x), - type_atom=torch.Tensor, - type_list=list, + y_trans = torch.from_numpy( + y_encoder.transform(np.array(y_lbl).reshape(-1, 1))[..., 0] ) # Remove nan values from the batch @@ -98,10 +91,10 @@ def preprocess( nan = ( ~torch.isnan(y_trans) if nan_mask else torch.ones_like(y_trans).bool() ) - x_lbl_trans = x_lbl_trans[nan] + x_lbl_trans = x_lbl[nan] x_lbl_trans = torch.nan_to_num(x_lbl_trans) x_unl_trans = fn_recursive( - x_unl_trans, + x_unl, fn=lambda x: torch.nan_to_num(x[nan]), type_atom=torch.Tensor, type_list=list, @@ -111,50 +104,6 @@ def preprocess( return (x_lbl_trans, y_trans.long()), x_unl_trans -def x_standard_scale( - x_scaler: StandardScaler, x: torch.Tensor -) -> torch.Tensor: - """Standard scales the data - - Notes: - This is a wrapper around the StandardScaler to handle PyTorch tensors. - - Args: - x_scaler: The StandardScaler to use. - x: The data to standard scale, of shape (B, C, H, W). - """ - # Standard Scaler only accepts (n_samples, n_features), - # so we need to do some fancy reshaping. - # Note that moving dimensions then reshaping is different from just - # reshaping! 
- - # Move Channel to the last dimension then transform - # B x C x H x W -> B x H x W x C - b, c, h, w = x.shape - x_ss = x_scaler.transform(x.permute(0, 2, 3, 1).reshape(-1, c)) - - # Move Channel back to the second dimension - # B x H x W x C -> B x C x H x W - return torch.nan_to_num( - torch.from_numpy(x_ss.reshape(b, h, w, c)).permute(0, 3, 1, 2).float() - ) - - -def y_encode(y_encoder: OrdinalEncoder, y: torch.Tensor) -> torch.Tensor: - """Encodes the labels - - Notes: - This is a wrapper around the OrdinalEncoder to handle PyTorch tensors. - - Args: - y_encoder: The OrdinalEncoder to use. - y: The labels to encode. - """ - return torch.from_numpy( - y_encoder.transform(np.array(y).reshape(-1, 1))[..., 0] - ) - - def wandb_hist(x: torch.Tensor, num_bins: int) -> wandb.Histogram: """Records a W&B Histogram""" return wandb.Histogram( diff --git a/tests/model_tests/chestnut_dec_may/train_fixmatch.py b/tests/model_tests/chestnut_dec_may/train_fixmatch.py index b1df238a..79758837 100644 --- a/tests/model_tests/chestnut_dec_may/train_fixmatch.py +++ b/tests/model_tests/chestnut_dec_may/train_fixmatch.py @@ -27,7 +27,6 @@ FRDCDatasetStaticEval, n_weak_strong_aug, get_y_encoder, - get_x_scaler, weak_aug, ) @@ -96,13 +95,11 @@ def main( ) oe = get_y_encoder(train_lab_ds.targets) - ss = get_x_scaler(train_lab_ds.ar_segments) m = EfficientNetB1FixMatchModule( in_channels=train_lab_ds.ar.shape[-1], n_classes=len(oe.categories_[0]), lr=lr, - x_scaler=ss, y_encoder=oe, frozen=True, ) diff --git a/tests/model_tests/utils.py b/tests/model_tests/utils.py index ac5b54ac..e4816b36 100644 --- a/tests/model_tests/utils.py +++ b/tests/model_tests/utils.py @@ -132,11 +132,3 @@ def get_y_encoder(targets): ) oe.fit(np.array(targets).reshape(-1, 1)) return oe - - -def get_x_scaler(segments): - ss = StandardScaler() - ss.fit( - np.concatenate([segm.reshape(-1, segm.shape[-1]) for segm in segments]) - ) - return ss From 8497c3e72307c1827056a0963d0abbb42f2433ba Mon Sep 17 00:00:00 2001 From: Evening Date: Mon, 10 Jun 2024 11:23:35 +0800 Subject: [PATCH 2/4] Fix unexpected additional x scaler arg --- tests/integration_tests/test_pipeline.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/integration_tests/test_pipeline.py b/tests/integration_tests/test_pipeline.py index 97676dcb..9db094c0 100644 --- a/tests/integration_tests/test_pipeline.py +++ b/tests/integration_tests/test_pipeline.py @@ -48,7 +48,6 @@ def test_manual_segmentation_pipeline(model_fn, ds): in_channels=ds.ar.shape[-1], n_classes=n_classes, lr=1e-3, - x_scaler=ss, y_encoder=oe, frozen=True, ) From ab9de7ca2e5f8601b32936ff868494a15a91b50e Mon Sep 17 00:00:00 2001 From: Evening Date: Mon, 10 Jun 2024 12:40:21 +0800 Subject: [PATCH 3/4] Refactor common modules to FRDCModule --- src/frdc/models/efficientnetb1.py | 28 ++-- src/frdc/train/fixmatch_module.py | 81 ++++------ src/frdc/train/frdc_module.py | 143 ++++++++++++++++++ src/frdc/train/mixmatch_module.py | 73 ++++----- src/frdc/train/utils.py | 52 ------- tests/integration_tests/test_pipeline.py | 18 +-- .../chestnut_dec_may/train_fixmatch.py | 7 +- .../chestnut_dec_may/train_mixmatch.py | 11 +- 8 files changed, 212 insertions(+), 201 deletions(-) create mode 100644 src/frdc/train/frdc_module.py diff --git a/src/frdc/models/efficientnetb1.py b/src/frdc/models/efficientnetb1.py index d793249c..ae87d8d1 100644 --- a/src/frdc/models/efficientnetb1.py +++ b/src/frdc/models/efficientnetb1.py @@ -1,8 +1,7 @@ from copy import deepcopy -from typing import Dict, Any +from typing import Sequence 
import torch -from sklearn.preprocessing import OrdinalEncoder, StandardScaler from torch import nn from torchvision.models import ( EfficientNet, @@ -10,7 +9,6 @@ EfficientNet_B1_Weights, ) -from frdc.models.utils import save_unfrozen, load_checkpoint_lenient from frdc.train.fixmatch_module import FixMatchModule from frdc.train.mixmatch_module import MixMatchModule from frdc.utils.ema import EMA @@ -81,9 +79,8 @@ def __init__( self, *, in_channels: int, - n_classes: int, + out_targets: Sequence[str], lr: float, - y_encoder: OrdinalEncoder, ema_lr: float = 0.001, weight_decay: float = 1e-5, frozen: bool = True, @@ -92,9 +89,8 @@ def __init__( Args: in_channels: The number of input channels. - n_classes: The number of classes. + out_targets: The output targets. lr: The learning rate. - y_encoder: The Y input OrdinalEncoder. ema_lr: The learning rate for the EMA model. weight_decay: The weight decay. frozen: Whether to freeze the base model. @@ -106,15 +102,14 @@ def __init__( self.weight_decay = weight_decay super().__init__( - n_classes=n_classes, - y_encoder=y_encoder, + out_targets=out_targets, sharpen_temp=0.5, mix_beta_alpha=0.75, ) self.eff = efficientnet_b1_backbone(in_channels, frozen) self.fc = nn.Sequential( - nn.Linear(self.EFF_OUT_DIMS, n_classes), + nn.Linear(self.EFF_OUT_DIMS, self.n_classes), nn.Softmax(dim=1), ) @@ -152,9 +147,8 @@ def __init__( self, *, in_channels: int, - n_classes: int, + out_targets: Sequence[str], lr: float, - y_encoder: OrdinalEncoder, weight_decay: float = 1e-5, frozen: bool = True, ): @@ -162,9 +156,8 @@ def __init__( Args: in_channels: The number of input channels. - n_classes: The number of classes. + out_targets: The output targets. lr: The learning rate. - y_encoder: The Y input OrdinalEncoder. weight_decay: The weight decay. frozen: Whether to freeze the base model. @@ -174,15 +167,12 @@ def __init__( self.lr = lr self.weight_decay = weight_decay - super().__init__( - n_classes=n_classes, - y_encoder=y_encoder, - ) + super().__init__(out_targets=out_targets) self.eff = efficientnet_b1_backbone(in_channels, frozen) self.fc = nn.Sequential( - nn.Linear(self.EFF_OUT_DIMS, n_classes), + nn.Linear(self.EFF_OUT_DIMS, self.n_classes), nn.Softmax(dim=1), ) diff --git a/src/frdc/train/fixmatch_module.py b/src/frdc/train/fixmatch_module.py index 93e8c9ab..f29316c0 100644 --- a/src/frdc/train/fixmatch_module.py +++ b/src/frdc/train/fixmatch_module.py @@ -1,28 +1,23 @@ from __future__ import annotations from abc import abstractmethod -from typing import Any, Dict +from typing import Sequence import torch import torch.nn.functional as F import wandb -from lightning import LightningModule -from sklearn.preprocessing import StandardScaler, OrdinalEncoder +from lightning.pytorch.utilities.types import STEP_OUTPUT from torchmetrics.functional import accuracy -from frdc.models.utils import save_unfrozen, load_checkpoint_lenient -from frdc.train.utils import ( - wandb_hist, - preprocess, -) +from frdc.train.frdc_module import FRDCModule +from frdc.train.utils import wandb_hist -class FixMatchModule(LightningModule): +class FixMatchModule(FRDCModule): def __init__( self, *, - y_encoder: OrdinalEncoder, - n_classes: int = 10, + out_targets: Sequence[str], unl_conf_threshold: float = 0.95, ): """PyTorch Lightning Module for MixMatch @@ -39,16 +34,13 @@ def __init__( how to implement a new dataset. Args: - n_classes: The number of classes in the dataset. - y_encoder: The OrdinalEncoder to use for the labels. + out_targets: The output targets for the model. 
unl_conf_threshold: The confidence threshold for unlabelled data to be considered correctly labelled. """ - super().__init__() + super().__init__(out_targets=out_targets) - self.y_encoder = y_encoder - self.n_classes = n_classes self.unl_conf_threshold = unl_conf_threshold self.save_hyperparameters() @@ -60,7 +52,11 @@ def __init__( def forward(self, x): ... - def training_step(self, batch, batch_idx): + def training_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ): """A single training step for a batch Notes: @@ -91,7 +87,6 @@ def training_step(self, batch, batch_idx): Loss: ℓ_lbl + ℓ_unl """ - def training_step(self, batch, batch_idx): (x_lbl, y_lbl), x_unls = batch opt = self.optimizers() @@ -172,7 +167,11 @@ def training_step(self, batch, batch_idx): } ) - def validation_step(self, batch, batch_idx): + def validation_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ): # The batch outputs x_unls due to our on_before_batch_transfer (x, y), _x_unls = batch wandb.log({"val/y_lbl": wandb_hist(y, self.n_classes)}) @@ -194,7 +193,11 @@ def validation_step(self, batch, batch_idx): self.log("val/acc", acc, prog_bar=True) return loss - def test_step(self, batch, batch_idx): + def test_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ) -> STEP_OUTPUT: # The batch outputs x_unls due to our on_before_batch_transfer (x, y), _x_unls = batch y_pred = self(x) @@ -207,7 +210,10 @@ def test_step(self, batch, batch_idx): self.log("test/acc", acc, prog_bar=True) return loss - def predict_step(self, batch, *args, **kwargs) -> Any: + def predict_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + ): (x, y), _x_unls = batch y_pred = self(x) y_true_str = self.y_encoder.inverse_transform( @@ -217,36 +223,3 @@ def predict_step(self, batch, *args, **kwargs) -> Any: y_pred.argmax(dim=1).cpu().numpy().reshape(-1, 1) ) return y_true_str, y_pred_str - - @torch.no_grad() - def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: - """This method is called before any data transfer to the device. - - We leverage this to do some preprocessing on the data. - Namely, we use the StandardScaler and OrdinalEncoder to transform the - data. - - Notes: - PyTorch Lightning may complain about this being on the Module - instead of the DataModule. However, this is intentional as we - want to export the model alongside the transformations. 
- """ - - if self.training: - (x_lbl, y_lbl), x_unl = batch - else: - x_lbl, y_lbl = batch - x_unl = None - - return preprocess( - x_lbl=x_lbl, - y_lbl=y_lbl, - y_encoder=self.y_encoder, - x_unl=x_unl, - ) - - def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - save_unfrozen(self, checkpoint) - - def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - load_checkpoint_lenient(self, checkpoint) diff --git a/src/frdc/train/frdc_module.py b/src/frdc/train/frdc_module.py new file mode 100644 index 00000000..22582e92 --- /dev/null +++ b/src/frdc/train/frdc_module.py @@ -0,0 +1,143 @@ +from typing import Any, Dict, Sequence + +import numpy as np +import torch +from lightning import LightningModule +from sklearn.preprocessing import OrdinalEncoder + +from frdc.models.utils import save_unfrozen, load_checkpoint_lenient +from frdc.utils.utils import fn_recursive + + +class FRDCModule(LightningModule): + def __init__( + self, + *, + out_targets: Sequence[str], + nan_mask_missing_y_labels: bool = True, + ): + """Base Lightning Module for MixMatch + + Notes: + This is the base class for MixMatch and FixMatch. + This implements the Y-Encoder logic so that all modules can + encode and decode the tree string labels. + + Generally the hierarchy is: + Module + -> Module + -> FRDCModule + + E.g. + EfficientNetB1MixMatchModule + -> MixMatchModule + -> FRDCModule + + WideResNetFixMatchModule + -> FixMatchModule + -> FRDCModule + + Args: + out_targets: The output targets for the model. + nan_mask_missing_y_labels: Whether to mask away x values that + have missing y labels. This happens when the y label is not + present in the OrdinalEncoder's categories, which happens + during non-training steps. E.g. A new unseen tree is inferred. + """ + + super().__init__() + + self.y_encoder: OrdinalEncoder = OrdinalEncoder( + handle_unknown="use_encoded_value", + unknown_value=np.nan, + ) + self.y_encoder.fit(np.array(out_targets).reshape(-1, 1)) + self.nan_mask_missing_y_labels = nan_mask_missing_y_labels + self.save_hyperparameters() + + @property + def n_classes(self): + return len(self.y_encoder.categories_[0]) + + @torch.no_grad() + def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: + """This method is called before any data transfer to the device. + + Notes: + This method wraps OrdinalEncoder to convert labels from str to int + before transferring the data to the device. + + Note that this step must happen before the transfer as tensors + don't support str types. + + PyTorch Lightning may complain about this being on the Module + instead of the DataModule. However, this is intentional as we + want to export the model alongside the transformations. + """ + + if self.training: + (x_lbl, y_lbl), x_unl = batch + else: + x_lbl, y_lbl = batch + x_unl = [] + + y_trans = torch.from_numpy( + self.y_encoder.transform(np.array(y_lbl).reshape(-1, 1))[..., 0] + ) + + # Remove nan values from the batch + # Ordinal Encoders can return a np.nan if the value is not in the + # categories. We will remove that from the batch. + nan = ( + ~torch.isnan(y_trans) # Keeps all non-nan values + if self.nan_mask_missing_y_labels + else torch.ones_like(y_trans).bool() # Keeps all values + ) + + x_lbl_trans = torch.nan_to_num(x_lbl[nan]) + + # This function applies nan_to_num to all tensors in the list, + # regardless of how deeply nested they are. 
+ x_unl_trans = fn_recursive( + x_unl, + fn=lambda x: torch.nan_to_num(x[nan]), + type_atom=torch.Tensor, + type_list=list, + ) + y_trans = y_trans[nan].long() + + return (x_lbl_trans, y_trans), x_unl_trans + + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + save_unfrozen(self, checkpoint) + + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + load_checkpoint_lenient(self, checkpoint) + + # The following methods are to enforce the batch schema typing. + def training_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ): + ... + + def validation_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ): + ... + + def test_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ): + ... + + def predict_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + ) -> Any: + ... diff --git a/src/frdc/train/mixmatch_module.py b/src/frdc/train/mixmatch_module.py index cf6442ae..3b857851 100644 --- a/src/frdc/train/mixmatch_module.py +++ b/src/frdc/train/mixmatch_module.py @@ -1,31 +1,28 @@ from __future__ import annotations from abc import abstractmethod -from typing import Any, Dict +from typing import Any, Dict, Sequence import torch import torch.nn.functional as F import wandb -from lightning import LightningModule -from sklearn.preprocessing import StandardScaler, OrdinalEncoder from torch.nn.functional import one_hot from torchmetrics.functional import accuracy -from frdc.models.utils import save_unfrozen, load_checkpoint_lenient +from frdc.models.utils import save_unfrozen +from frdc.train.frdc_module import FRDCModule from frdc.train.utils import ( mix_up, sharpen, wandb_hist, - preprocess, ) -class MixMatchModule(LightningModule): +class MixMatchModule(FRDCModule): def __init__( self, *, - y_encoder: OrdinalEncoder, - n_classes: int = 10, + out_targets: Sequence[str], sharpen_temp: float = 0.5, mix_beta_alpha: float = 0.75, ): @@ -43,16 +40,14 @@ def __init__( how to implement a new dataset. Args: - n_classes: The number of classes in the dataset. + out_targets: The output targets for the model. sharpen_temp: The temperature to use for sharpening. mix_beta_alpha: The alpha to use for the beta distribution when mixing. 
""" - super().__init__() + super().__init__(out_targets=out_targets) - self.y_encoder = y_encoder - self.n_classes = n_classes self.sharpen_temp = sharpen_temp self.mix_beta_alpha = mix_beta_alpha self.save_hyperparameters() @@ -112,7 +107,11 @@ def progress(self): self.global_step / self.trainer.num_training_batches ) / self.trainer.max_epochs - def training_step(self, batch, batch_idx): + def training_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ): (x_lbl, y_lbl), x_unls = batch self.log("train/x_lbl_mean", x_lbl.mean()) @@ -189,7 +188,11 @@ def training_step(self, batch, batch_idx): def on_after_backward(self) -> None: self.update_ema() - def validation_step(self, batch, batch_idx): + def validation_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ): (x, y), _x_unls = batch wandb.log({"val/y_lbl": wandb_hist(y, self.n_classes)}) y_pred = self.ema_model(x) @@ -209,7 +212,11 @@ def validation_step(self, batch, batch_idx): self.log("val/acc", acc, prog_bar=True) return loss - def test_step(self, batch, batch_idx): + def test_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ): (x, y), _x_unls = batch y_pred = self.ema_model(x) loss = F.cross_entropy(y_pred, y.long()) @@ -221,7 +228,10 @@ def test_step(self, batch, batch_idx): self.log("test/acc", acc, prog_bar=True) return loss - def predict_step(self, batch, *args, **kwargs) -> Any: + def predict_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + ) -> Any: (x, y), _x_unls = batch y_pred = self.ema_model(x) y_true_str = self.y_encoder.inverse_transform( @@ -232,39 +242,10 @@ def predict_step(self, batch, *args, **kwargs) -> Any: ) return y_true_str, y_pred_str - @torch.no_grad() - def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: - """This method is called before any data transfer to the device. - - We leverage this to do some preprocessing on the data. - Namely, we use the StandardScaler and OrdinalEncoder to transform the - data. - - Notes: - PyTorch Lightning may complain about this being on the Module - instead of the DataModule. However, this is intentional as we - want to export the model alongside the transformations. - """ - - if self.training: - (x_lbl, y_lbl), x_unl = batch - else: - x_lbl, y_lbl = batch - x_unl = None - - return preprocess( - x_lbl=x_lbl, - y_lbl=y_lbl, - y_encoder=self.y_encoder, - x_unl=x_unl, - ) - def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + """This override the original method to save the EMAs as well.""" save_unfrozen( self, checkpoint, include_also=lambda k: k.startswith("_ema_model.fc."), ) - - def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - load_checkpoint_lenient(self, checkpoint) diff --git a/src/frdc/train/utils.py b/src/frdc/train/utils.py index bb1b2e47..b42ce91e 100644 --- a/src/frdc/train/utils.py +++ b/src/frdc/train/utils.py @@ -52,58 +52,6 @@ def sharpen(y: torch.Tensor, temp: float) -> torch.Tensor: return y_sharp -def preprocess( - x_lbl: torch.Tensor, - y_lbl: torch.Tensor, - y_encoder: OrdinalEncoder, - x_unl: list[torch.Tensor] = None, - nan_mask: bool = True, -) -> tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]]: - """Preprocesses the data - - Notes: - The reason why x and y's preprocessing is coupled is due to the NaN - elimination step. 
The NaN elimination step is due to unseen labels by y - - fn_recursive is to recursively apply some function to a nested list. - This happens due to unlabelled being a list of tensors. - - Args: - x_lbl: The data to preprocess. - y_lbl: The labels to preprocess. - y_encoder: The OrdinalEncoder to use. - x_unl: The unlabelled data to preprocess. - nan_mask: Whether to remove nan values from the batch. - - Returns: - The preprocessed data and labels. - """ - - x_unl = [] if x_unl is None else x_unl - - y_trans = torch.from_numpy( - y_encoder.transform(np.array(y_lbl).reshape(-1, 1))[..., 0] - ) - - # Remove nan values from the batch - # Ordinal Encoders can return a np.nan if the value is not in the - # categories. We will remove that from the batch. - nan = ( - ~torch.isnan(y_trans) if nan_mask else torch.ones_like(y_trans).bool() - ) - x_lbl_trans = x_lbl[nan] - x_lbl_trans = torch.nan_to_num(x_lbl_trans) - x_unl_trans = fn_recursive( - x_unl, - fn=lambda x: torch.nan_to_num(x[nan]), - type_atom=torch.Tensor, - type_list=list, - ) - y_trans = y_trans[nan] - - return (x_lbl_trans, y_trans.long()), x_unl_trans - - def wandb_hist(x: torch.Tensor, num_bins: int) -> wandb.Histogram: """Records a W&B Histogram""" return wandb.Histogram( diff --git a/tests/integration_tests/test_pipeline.py b/tests/integration_tests/test_pipeline.py index 9db094c0..3de3e6be 100644 --- a/tests/integration_tests/test_pipeline.py +++ b/tests/integration_tests/test_pipeline.py @@ -2,10 +2,8 @@ from pathlib import Path import lightning as pl -import numpy as np import pytest import torch -from sklearn.preprocessing import StandardScaler, OrdinalEncoder from frdc.models.efficientnetb1 import ( EfficientNetB1MixMatchModule, @@ -33,22 +31,10 @@ def test_manual_segmentation_pipeline(model_fn, ds): val_ds=ds, batch_size=BATCH_SIZE, ) - - oe = OrdinalEncoder( - handle_unknown="use_encoded_value", - unknown_value=np.nan, - ) - oe.fit(np.array(ds.targets).reshape(-1, 1)) - n_classes = len(oe.categories_[0]) - - ss = StandardScaler() - ss.fit(ds.ar.reshape(-1, ds.ar.shape[-1])) - m = model_fn( in_channels=ds.ar.shape[-1], - n_classes=n_classes, lr=1e-3, - y_encoder=oe, + out_targets=ds.targets, frozen=True, ) @@ -92,4 +78,4 @@ def test_manual_segmentation_pipeline(model_fn, ds): # E.g. achieved via hash comparison. # This is because BatchNorm usually keeps running statistics # and reloading the model will reset them. - # We don't necessarily need to + # We don't necessarily need to check for this. 
diff --git a/tests/model_tests/chestnut_dec_may/train_fixmatch.py b/tests/model_tests/chestnut_dec_may/train_fixmatch.py index 79758837..b5804de1 100644 --- a/tests/model_tests/chestnut_dec_may/train_fixmatch.py +++ b/tests/model_tests/chestnut_dec_may/train_fixmatch.py @@ -94,13 +94,10 @@ def main( ), ) - oe = get_y_encoder(train_lab_ds.targets) - m = EfficientNetB1FixMatchModule( in_channels=train_lab_ds.ar.shape[-1], - n_classes=len(oe.categories_[0]), + out_targets=train_lab_ds.targets, lr=lr, - y_encoder=oe, frozen=True, ) @@ -121,7 +118,7 @@ def main( ), model=m, ) - fig, ax = plot_confusion_matrix(y_true, y_pred, oe.categories_[0]) + fig, ax = plot_confusion_matrix(y_true, y_pred, m.y_encoder.categories_[0]) acc = np.sum(y_true == y_pred) / len(y_true) ax.set_title(f"Accuracy: {acc:.2%}") diff --git a/tests/model_tests/chestnut_dec_may/train_mixmatch.py b/tests/model_tests/chestnut_dec_may/train_mixmatch.py index 1bf839f5..7b852607 100644 --- a/tests/model_tests/chestnut_dec_may/train_mixmatch.py +++ b/tests/model_tests/chestnut_dec_may/train_mixmatch.py @@ -25,8 +25,6 @@ FRDCDatasetStaticEval, n_strong_aug, strong_aug, - get_y_encoder, - get_x_scaler, ) @@ -85,15 +83,10 @@ def main( ), ) - oe = get_y_encoder(train_lab_ds.targets) - ss = get_x_scaler(train_lab_ds.ar_segments) - m = EfficientNetB1MixMatchModule( in_channels=train_lab_ds.ar.shape[-1], - n_classes=len(oe.categories_[0]), + out_targets=train_lab_ds.targets, lr=lr, - x_scaler=ss, - y_encoder=oe, frozen=True, ) @@ -114,7 +107,7 @@ def main( ), model=m, ) - fig, ax = plot_confusion_matrix(y_true, y_pred, oe.categories_[0]) + fig, ax = plot_confusion_matrix(y_true, y_pred, m.y_encoder.categories_[0]) acc = np.sum(y_true == y_pred) / len(y_true) ax.set_title(f"Accuracy: {acc:.2%}") From 13b15930a9b1dabc27ee2f235a02e876ad8bf5b8 Mon Sep 17 00:00:00 2001 From: Evening Date: Mon, 10 Jun 2024 17:12:26 +0800 Subject: [PATCH 4/4] Allow StandardScaler to override default scaler --- src/frdc/load/dataset.py | 15 ++++++++++++--- src/frdc/load/preset.py | 19 +++++++++++++------ .../chestnut_dec_may/train_fixmatch.py | 9 ++++++--- .../chestnut_dec_may/train_mixmatch.py | 10 ++++++++-- 4 files changed, 39 insertions(+), 14 deletions(-) diff --git a/src/frdc/load/dataset.py b/src/frdc/load/dataset.py index 9ab6cb6c..4666283f 100644 --- a/src/frdc/load/dataset.py +++ b/src/frdc/load/dataset.py @@ -73,7 +73,7 @@ def __init__( date: str, version: str | None, transform: Callable[[np.ndarray], Any] = lambda x: x, - transform_scale: bool = False, + transform_scale: bool | StandardScaler = True, target_transform: Callable[[str], str] = lambda x: x, use_legacy_bounds: bool = False, polycrop: bool = False, @@ -97,7 +97,9 @@ def __init__( date: The date of the dataset, e.g. "20201218". version: The version of the dataset, e.g. "183deg". transform: The transform to apply to each segment. - transform_scale: Prepends a scaling transform to the transform. + transform_scale: Whether to scale the data. If True, it will fit + a StandardScaler to the data. If a StandardScaler is passed, + it will use that instead. If False, it will not scale the data. target_transform: The transform to apply to each label. use_legacy_bounds: Whether to use the legacy bounds.csv file. 
This will automatically be set to True if LABEL_STUDIO_CLIENT @@ -129,7 +131,7 @@ def __init__( self.transform = transform self.target_transform = target_transform - if transform_scale: + if transform_scale is True: self.x_scaler = StandardScaler() self.x_scaler.fit( np.concatenate( @@ -146,6 +148,13 @@ def __init__( x.shape ) ) + elif isinstance(transform_scale, StandardScaler): + self.x_scaler = transform_scale + self.transform = lambda x: transform( + self.x_scaler.transform(x.reshape(-1, x.shape[-1])).reshape( + x.shape + ) + ) else: self.x_scaler = None diff --git a/src/frdc/load/preset.py b/src/frdc/load/preset.py index 4d09c88d..6fa54dff 100644 --- a/src/frdc/load/preset.py +++ b/src/frdc/load/preset.py @@ -6,6 +6,7 @@ import numpy as np import torch +from sklearn.preprocessing import StandardScaler from torchvision.transforms.v2 import ( Compose, ToImage, @@ -48,7 +49,7 @@ class FRDCDatasetPartial: def __call__( self, transform: Callable[[np.ndarray], Any] = lambda x: x, - transform_scale: bool = True, + transform_scale: bool | StandardScaler = True, target_transform: Callable[[str], str] = lambda x: x, use_legacy_bounds: bool = False, polycrop: bool = False, @@ -58,7 +59,9 @@ def __call__( Args: transform: The transform to apply to the data. - transform_scale: Whether to scale the data. + transform_scale: Whether to scale the data. If True, it will fit + a StandardScaler to the data. If a StandardScaler is passed, + it will use that instead. If False, it will not scale the data. target_transform: The transform to apply to the labels. use_legacy_bounds: Whether to use the legacy bounds. polycrop: Whether to use polycrop. @@ -76,7 +79,7 @@ def __call__( def labelled( self, transform: Callable[[np.ndarray], Any] = lambda x: x, - transform_scale: bool = True, + transform_scale: bool | StandardScaler = True, target_transform: Callable[[str], str] = lambda x: x, use_legacy_bounds: bool = False, polycrop: bool = False, @@ -86,7 +89,9 @@ def labelled( Args: transform: The transform to apply to the data. - transform_scale: Whether to scale the data. + transform_scale: Whether to scale the data. If True, it will fit + a StandardScaler to the data. If a StandardScaler is passed, + it will use that instead. If False, it will not scale the data. target_transform: The transform to apply to the labels. use_legacy_bounds: Whether to use the legacy bounds. polycrop: Whether to use polycrop. @@ -107,7 +112,7 @@ def labelled( def unlabelled( self, transform: Callable[[np.ndarray], Any] = lambda x: x, - transform_scale: bool = True, + transform_scale: bool | StandardScaler = True, target_transform: Callable[[str], str] = lambda x: x, use_legacy_bounds: bool = False, polycrop: bool = False, @@ -122,7 +127,9 @@ def unlabelled( Args: transform: The transform to apply to the data. - transform_scale: Whether to scale the data. + transform_scale: Whether to scale the data. If True, it will fit + a StandardScaler to the data. If a StandardScaler is passed, + it will use that instead. If False, it will not scale the data. target_transform: The transform to apply to the labels. use_legacy_bounds: Whether to use the legacy bounds. polycrop: Whether to use polycrop. 
diff --git a/tests/model_tests/chestnut_dec_may/train_fixmatch.py b/tests/model_tests/chestnut_dec_may/train_fixmatch.py index b5804de1..c83bf430 100644 --- a/tests/model_tests/chestnut_dec_may/train_fixmatch.py +++ b/tests/model_tests/chestnut_dec_may/train_fixmatch.py @@ -26,7 +26,6 @@ val_preprocess, FRDCDatasetStaticEval, n_weak_strong_aug, - get_y_encoder, weak_aug, ) @@ -57,9 +56,12 @@ def main( im_size = 255 train_lab_ds = ds.chestnut_20201218(transform=weak_aug(im_size)) train_unl_ds = ds.chestnut_20201218.unlabelled( - transform=n_weak_strong_aug(im_size, unlabelled_factor) + transform=n_weak_strong_aug(im_size, unlabelled_factor), + ) + val_ds = ds.chestnut_20210510_43m( + transform=val_preprocess(im_size), + transform_scale=train_lab_ds.x_scaler, ) - val_ds = ds.chestnut_20210510_43m(transform=val_preprocess(im_size)) # Prepare the datamodule and trainer dm = FRDCDataModule( @@ -115,6 +117,7 @@ def main( "20210510", "90deg43m85pct255deg", transform=val_preprocess(im_size), + transform_scale=train_lab_ds.x_scaler, ), model=m, ) diff --git a/tests/model_tests/chestnut_dec_may/train_mixmatch.py b/tests/model_tests/chestnut_dec_may/train_mixmatch.py index 7b852607..0aab51f0 100644 --- a/tests/model_tests/chestnut_dec_may/train_mixmatch.py +++ b/tests/model_tests/chestnut_dec_may/train_mixmatch.py @@ -44,11 +44,16 @@ def main( ): # Prepare the dataset im_size = 299 - train_lab_ds = ds.chestnut_20201218(transform=strong_aug(im_size)) + train_lab_ds = ds.chestnut_20201218( + transform=strong_aug(im_size), + ) train_unl_ds = ds.chestnut_20201218.unlabelled( transform=n_strong_aug(im_size, 2) ) - val_ds = ds.chestnut_20210510_43m(transform=val_preprocess(im_size)) + val_ds = ds.chestnut_20210510_43m( + transform=val_preprocess(im_size), + transform_scale=train_lab_ds.x_scaler, + ) # Prepare the datamodule and trainer dm = FRDCDataModule( @@ -104,6 +109,7 @@ def main( "20210510", "90deg43m85pct255deg", transform=val_preprocess(im_size), + transform_scale=train_lab_ds.x_scaler, ), model=m, )
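---

Post-series note (not part of any patch above): the snippet below is a minimal usage sketch of the API after all four patches, pieced together from the test scripts in this series. The dataset attributes (`ar`, `targets`, `x_scaler`), the `out_targets=` and `transform_scale=` arguments, and `lr=1e-3` come straight from the diffs; the preset import alias and running without any augmentation transform are assumptions made only to keep the sketch self-contained.

    # Sketch only: assumes FRDCDatasetPreset is what the test scripts alias as `ds`.
    from frdc.load.preset import FRDCDatasetPreset as ds
    from frdc.models.efficientnetb1 import EfficientNetB1FixMatchModule

    # Patches 1/4 and 4/4: the labelled dataset fits its own StandardScaler
    # over its segments, since transform_scale now defaults to True.
    train_lab_ds = ds.chestnut_20201218()

    # Patch 4/4: the evaluation dataset reuses the scaler fitted on the
    # training split instead of fitting a new one on evaluation data.
    val_ds = ds.chestnut_20210510_43m(transform_scale=train_lab_ds.x_scaler)

    # Patch 3/4: modules no longer take x_scaler / y_encoder / n_classes;
    # FRDCModule fits an OrdinalEncoder from out_targets internally.
    model = EfficientNetB1FixMatchModule(
        in_channels=train_lab_ds.ar.shape[-1],
        out_targets=train_lab_ds.targets,
        lr=1e-3,
        frozen=True,
    )

The design point this illustrates: scaling is now owned by the dataset (fit once on the training split, handed to other splits via `transform_scale=`), and label encoding is owned by `FRDCModule` (derived from `out_targets`), so neither a `StandardScaler` nor an `OrdinalEncoder` has to be constructed and threaded through the training scripts by hand.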