diff --git a/src/frdc/load/dataset.py b/src/frdc/load/dataset.py index e0a0d9c..4666283 100644 --- a/src/frdc/load/dataset.py +++ b/src/frdc/load/dataset.py @@ -9,6 +9,7 @@ import numpy as np import pandas as pd from PIL import Image +from sklearn.preprocessing import StandardScaler from torch.utils.data import Dataset, ConcatDataset from frdc.conf import ( @@ -71,8 +72,9 @@ def __init__( site: str, date: str, version: str | None, - transform: Callable[[list[np.ndarray]], Any] = None, - target_transform: Callable[[list[str]], list[str]] = None, + transform: Callable[[np.ndarray], Any] = lambda x: x, + transform_scale: bool | StandardScaler = True, + target_transform: Callable[[str], str] = lambda x: x, use_legacy_bounds: bool = False, polycrop: bool = False, polycrop_value: Any = np.nan, @@ -95,6 +97,9 @@ def __init__( date: The date of the dataset, e.g. "20201218". version: The version of the dataset, e.g. "183deg". transform: The transform to apply to each segment. + transform_scale: Whether to scale the data. If True, it will fit + a StandardScaler to the data. If a StandardScaler is passed, + it will use that instead. If False, it will not scale the data. target_transform: The transform to apply to each label. use_legacy_bounds: Whether to use the legacy bounds.csv file. This will automatically be set to True if LABEL_STUDIO_CLIENT @@ -102,6 +107,7 @@ def __init__( to. polycrop: Whether to further crop the segments via its polygon bounds. The cropped area will be padded with np.nan. + polycrop_value: The value to pad the cropped area with. """ self.site = site self.date = date @@ -125,17 +131,40 @@ def __init__( self.transform = transform self.target_transform = target_transform + if transform_scale is True: + self.x_scaler = StandardScaler() + self.x_scaler.fit( + np.concatenate( + [ + # Segments: [H x W x C] -> [H*W, C] + # Reshaping is necessary for StandardScaler + segm.reshape(-1, segm.shape[-1]) + for segm in self.ar_segments + ] + ) + ) + self.transform = lambda x: transform( + self.x_scaler.transform(x.reshape(-1, x.shape[-1])).reshape( + x.shape + ) + ) + elif isinstance(transform_scale, StandardScaler): + self.x_scaler = transform_scale + self.transform = lambda x: transform( + self.x_scaler.transform(x.reshape(-1, x.shape[-1])).reshape( + x.shape + ) + ) + else: + self.x_scaler = None + def __len__(self): return len(self.ar_segments) def __getitem__(self, idx): return ( - self.transform(self.ar_segments[idx]) - if self.transform - else self.ar_segments[idx], - self.target_transform(self.targets[idx]) - if self.target_transform - else self.targets[idx], + self.transform(self.ar_segments[idx]), + self.target_transform(self.targets[idx]), ) @property diff --git a/src/frdc/load/preset.py b/src/frdc/load/preset.py index da0aedc..6fa54df 100644 --- a/src/frdc/load/preset.py +++ b/src/frdc/load/preset.py @@ -6,6 +6,7 @@ import numpy as np import torch +from sklearn.preprocessing import StandardScaler from torchvision.transforms.v2 import ( Compose, ToImage, @@ -47,15 +48,28 @@ class FRDCDatasetPartial: def __call__( self, - transform: Callable[[list[np.ndarray]], Any] = None, - target_transform: Callable[[list[str]], list[str]] = None, + transform: Callable[[np.ndarray], Any] = lambda x: x, + transform_scale: bool | StandardScaler = True, + target_transform: Callable[[str], str] = lambda x: x, use_legacy_bounds: bool = False, polycrop: bool = False, polycrop_value: Any = np.nan, ): - """Alias for labelled().""" + """Alias for labelled(). 
+ + Args: + transform: The transform to apply to the data. + transform_scale: Whether to scale the data. If True, it will fit + a StandardScaler to the data. If a StandardScaler is passed, + it will use that instead. If False, it will not scale the data. + target_transform: The transform to apply to the labels. + use_legacy_bounds: Whether to use the legacy bounds. + polycrop: Whether to use polycrop. + polycrop_value: The value to use for polycrop. + """ return self.labelled( transform, + transform_scale, target_transform, use_legacy_bounds=use_legacy_bounds, polycrop=polycrop, @@ -64,19 +78,32 @@ def __call__( def labelled( self, - transform: Callable[[list[np.ndarray]], Any] = None, - target_transform: Callable[[list[str]], list[str]] = None, + transform: Callable[[np.ndarray], Any] = lambda x: x, + transform_scale: bool | StandardScaler = True, + target_transform: Callable[[str], str] = lambda x: x, use_legacy_bounds: bool = False, polycrop: bool = False, polycrop_value: Any = np.nan, ): - """Returns the Labelled Dataset.""" + """Returns the Labelled Dataset. + + Args: + transform: The transform to apply to the data. + transform_scale: Whether to scale the data. If True, it will fit + a StandardScaler to the data. If a StandardScaler is passed, + it will use that instead. If False, it will not scale the data. + target_transform: The transform to apply to the labels. + use_legacy_bounds: Whether to use the legacy bounds. + polycrop: Whether to use polycrop. + polycrop_value: The value to use for polycrop. + """ return FRDCDataset( self.site, self.date, self.version, - transform, - target_transform, + transform=transform, + transform_scale=transform_scale, + target_transform=target_transform, use_legacy_bounds=use_legacy_bounds, polycrop=polycrop, polycrop_value=polycrop_value, @@ -84,8 +111,9 @@ def labelled( def unlabelled( self, - transform: Callable[[list[np.ndarray]], Any] = None, - target_transform: Callable[[list[str]], list[str]] = None, + transform: Callable[[np.ndarray], Any] = lambda x: x, + transform_scale: bool | StandardScaler = True, + target_transform: Callable[[str], str] = lambda x: x, use_legacy_bounds: bool = False, polycrop: bool = False, polycrop_value: Any = np.nan, @@ -96,13 +124,24 @@ def unlabelled( This simply masks away the labels during __getitem__. The same behaviour can be achieved by setting __class__ to FRDCUnlabelledDataset, but this is a more convenient way to do so. + + Args: + transform: The transform to apply to the data. + transform_scale: Whether to scale the data. If True, it will fit + a StandardScaler to the data. If a StandardScaler is passed, + it will use that instead. If False, it will not scale the data. + target_transform: The transform to apply to the labels. + use_legacy_bounds: Whether to use the legacy bounds. + polycrop: Whether to use polycrop. + polycrop_value: The value to use for polycrop. 
""" return FRDCUnlabelledDataset( self.site, self.date, self.version, - transform, - target_transform, + transform=transform, + transform_scale=transform_scale, + target_transform=target_transform, use_legacy_bounds=use_legacy_bounds, polycrop=polycrop, polycrop_value=polycrop_value, @@ -167,5 +206,4 @@ class FRDCDatasetPreset: Resize((resize, resize)), ] ), - target_transform=None, ) diff --git a/src/frdc/models/efficientnetb1.py b/src/frdc/models/efficientnetb1.py index 5d5b4ef..ae87d8d 100644 --- a/src/frdc/models/efficientnetb1.py +++ b/src/frdc/models/efficientnetb1.py @@ -1,8 +1,7 @@ from copy import deepcopy -from typing import Dict, Any +from typing import Sequence import torch -from sklearn.preprocessing import OrdinalEncoder, StandardScaler from torch import nn from torchvision.models import ( EfficientNet, @@ -10,7 +9,6 @@ EfficientNet_B1_Weights, ) -from frdc.models.utils import save_unfrozen, load_checkpoint_lenient from frdc.train.fixmatch_module import FixMatchModule from frdc.train.mixmatch_module import MixMatchModule from frdc.utils.ema import EMA @@ -81,10 +79,8 @@ def __init__( self, *, in_channels: int, - n_classes: int, + out_targets: Sequence[str], lr: float, - x_scaler: StandardScaler, - y_encoder: OrdinalEncoder, ema_lr: float = 0.001, weight_decay: float = 1e-5, frozen: bool = True, @@ -93,10 +89,8 @@ def __init__( Args: in_channels: The number of input channels. - n_classes: The number of classes. + out_targets: The output targets. lr: The learning rate. - x_scaler: The X input StandardScaler. - y_encoder: The Y input OrdinalEncoder. ema_lr: The learning rate for the EMA model. weight_decay: The weight decay. frozen: Whether to freeze the base model. @@ -108,16 +102,14 @@ def __init__( self.weight_decay = weight_decay super().__init__( - n_classes=n_classes, - x_scaler=x_scaler, - y_encoder=y_encoder, + out_targets=out_targets, sharpen_temp=0.5, mix_beta_alpha=0.75, ) self.eff = efficientnet_b1_backbone(in_channels, frozen) self.fc = nn.Sequential( - nn.Linear(self.EFF_OUT_DIMS, n_classes), + nn.Linear(self.EFF_OUT_DIMS, self.n_classes), nn.Softmax(dim=1), ) @@ -155,10 +147,8 @@ def __init__( self, *, in_channels: int, - n_classes: int, + out_targets: Sequence[str], lr: float, - x_scaler: StandardScaler, - y_encoder: OrdinalEncoder, weight_decay: float = 1e-5, frozen: bool = True, ): @@ -166,10 +156,8 @@ def __init__( Args: in_channels: The number of input channels. - n_classes: The number of classes. + out_targets: The output targets. lr: The learning rate. - x_scaler: The X input StandardScaler. - y_encoder: The Y input OrdinalEncoder. weight_decay: The weight decay. frozen: Whether to freeze the base model. 
@@ -179,16 +167,12 @@ def __init__( self.lr = lr self.weight_decay = weight_decay - super().__init__( - n_classes=n_classes, - x_scaler=x_scaler, - y_encoder=y_encoder, - ) + super().__init__(out_targets=out_targets) self.eff = efficientnet_b1_backbone(in_channels, frozen) self.fc = nn.Sequential( - nn.Linear(self.EFF_OUT_DIMS, n_classes), + nn.Linear(self.EFF_OUT_DIMS, self.n_classes), nn.Softmax(dim=1), ) diff --git a/src/frdc/train/fixmatch_module.py b/src/frdc/train/fixmatch_module.py index 0d8cd80..f29316c 100644 --- a/src/frdc/train/fixmatch_module.py +++ b/src/frdc/train/fixmatch_module.py @@ -1,29 +1,23 @@ from __future__ import annotations from abc import abstractmethod -from typing import Any, Dict +from typing import Sequence import torch import torch.nn.functional as F import wandb -from lightning import LightningModule -from sklearn.preprocessing import StandardScaler, OrdinalEncoder +from lightning.pytorch.utilities.types import STEP_OUTPUT from torchmetrics.functional import accuracy -from frdc.models.utils import save_unfrozen, load_checkpoint_lenient -from frdc.train.utils import ( - wandb_hist, - preprocess, -) +from frdc.train.frdc_module import FRDCModule +from frdc.train.utils import wandb_hist -class FixMatchModule(LightningModule): +class FixMatchModule(FRDCModule): def __init__( self, *, - x_scaler: StandardScaler, - y_encoder: OrdinalEncoder, - n_classes: int = 10, + out_targets: Sequence[str], unl_conf_threshold: float = 0.95, ): """PyTorch Lightning Module for MixMatch @@ -40,18 +34,13 @@ def __init__( how to implement a new dataset. Args: - n_classes: The number of classes in the dataset. - x_scaler: The StandardScaler to use for the data. - y_encoder: The OrdinalEncoder to use for the labels. + out_targets: The output targets for the model. unl_conf_threshold: The confidence threshold for unlabelled data to be considered correctly labelled. """ - super().__init__() + super().__init__(out_targets=out_targets) - self.x_scaler = x_scaler - self.y_encoder = y_encoder - self.n_classes = n_classes self.unl_conf_threshold = unl_conf_threshold self.save_hyperparameters() @@ -63,7 +52,11 @@ def __init__( def forward(self, x): ... 
- def training_step(self, batch, batch_idx): + def training_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ): """A single training step for a batch Notes: @@ -94,7 +87,6 @@ def training_step(self, batch, batch_idx): Loss: ℓ_lbl + ℓ_unl """ - def training_step(self, batch, batch_idx): (x_lbl, y_lbl), x_unls = batch opt = self.optimizers() @@ -175,7 +167,11 @@ def training_step(self, batch, batch_idx): } ) - def validation_step(self, batch, batch_idx): + def validation_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ): # The batch outputs x_unls due to our on_before_batch_transfer (x, y), _x_unls = batch wandb.log({"val/y_lbl": wandb_hist(y, self.n_classes)}) @@ -197,7 +193,11 @@ def validation_step(self, batch, batch_idx): self.log("val/acc", acc, prog_bar=True) return loss - def test_step(self, batch, batch_idx): + def test_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ) -> STEP_OUTPUT: # The batch outputs x_unls due to our on_before_batch_transfer (x, y), _x_unls = batch y_pred = self(x) @@ -210,7 +210,10 @@ def test_step(self, batch, batch_idx): self.log("test/acc", acc, prog_bar=True) return loss - def predict_step(self, batch, *args, **kwargs) -> Any: + def predict_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + ): (x, y), _x_unls = batch y_pred = self(x) y_true_str = self.y_encoder.inverse_transform( @@ -220,37 +223,3 @@ def predict_step(self, batch, *args, **kwargs) -> Any: y_pred.argmax(dim=1).cpu().numpy().reshape(-1, 1) ) return y_true_str, y_pred_str - - @torch.no_grad() - def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: - """This method is called before any data transfer to the device. - - We leverage this to do some preprocessing on the data. - Namely, we use the StandardScaler and OrdinalEncoder to transform the - data. - - Notes: - PyTorch Lightning may complain about this being on the Module - instead of the DataModule. However, this is intentional as we - want to export the model alongside the transformations. - """ - - if self.training: - (x_lbl, y_lbl), x_unl = batch - else: - x_lbl, y_lbl = batch - x_unl = None - - return preprocess( - x_lbl=x_lbl, - y_lbl=y_lbl, - x_scaler=self.x_scaler, - y_encoder=self.y_encoder, - x_unl=x_unl, - ) - - def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - save_unfrozen(self, checkpoint) - - def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - load_checkpoint_lenient(self, checkpoint) diff --git a/src/frdc/train/frdc_module.py b/src/frdc/train/frdc_module.py new file mode 100644 index 0000000..22582e9 --- /dev/null +++ b/src/frdc/train/frdc_module.py @@ -0,0 +1,143 @@ +from typing import Any, Dict, Sequence + +import numpy as np +import torch +from lightning import LightningModule +from sklearn.preprocessing import OrdinalEncoder + +from frdc.models.utils import save_unfrozen, load_checkpoint_lenient +from frdc.utils.utils import fn_recursive + + +class FRDCModule(LightningModule): + def __init__( + self, + *, + out_targets: Sequence[str], + nan_mask_missing_y_labels: bool = True, + ): + """Base Lightning Module for MixMatch + + Notes: + This is the base class for MixMatch and FixMatch. + This implements the Y-Encoder logic so that all modules can + encode and decode the tree string labels. + + Generally the hierarchy is: + Module + -> Module + -> FRDCModule + + E.g. 
+            EfficientNetB1MixMatchModule
+                -> MixMatchModule
+                    -> FRDCModule
+
+            WideResNetFixMatchModule
+                -> FixMatchModule
+                    -> FRDCModule
+
+        Args:
+            out_targets: The output targets for the model.
+            nan_mask_missing_y_labels: Whether to mask away x values that
+                have missing y labels. This happens when the y label is not
+                present in the OrdinalEncoder's categories, which happens
+                during non-training steps. E.g. A new unseen tree is inferred.
+        """
+
+        super().__init__()
+
+        self.y_encoder: OrdinalEncoder = OrdinalEncoder(
+            handle_unknown="use_encoded_value",
+            unknown_value=np.nan,
+        )
+        self.y_encoder.fit(np.array(out_targets).reshape(-1, 1))
+        self.nan_mask_missing_y_labels = nan_mask_missing_y_labels
+        self.save_hyperparameters()
+
+    @property
+    def n_classes(self):
+        return len(self.y_encoder.categories_[0])
+
+    @torch.no_grad()
+    def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
+        """This method is called before any data transfer to the device.
+
+        Notes:
+            This method wraps OrdinalEncoder to convert labels from str to int
+            before transferring the data to the device.
+
+            Note that this step must happen before the transfer as tensors
+            don't support str types.
+
+            PyTorch Lightning may complain about this being on the Module
+            instead of the DataModule. However, this is intentional as we
+            want to export the model alongside the transformations.
+        """
+
+        if self.training:
+            (x_lbl, y_lbl), x_unl = batch
+        else:
+            x_lbl, y_lbl = batch
+            x_unl = []
+
+        y_trans = torch.from_numpy(
+            self.y_encoder.transform(np.array(y_lbl).reshape(-1, 1))[..., 0]
+        )
+
+        # Remove nan values from the batch
+        # Ordinal Encoders can return a np.nan if the value is not in the
+        # categories. We will remove that from the batch.
+        nan = (
+            ~torch.isnan(y_trans)  # Keeps all non-nan values
+            if self.nan_mask_missing_y_labels
+            else torch.ones_like(y_trans).bool()  # Keeps all values
+        )
+
+        x_lbl_trans = torch.nan_to_num(x_lbl[nan])
+
+        # This function applies nan_to_num to all tensors in the list,
+        # regardless of how deeply nested they are.
+        x_unl_trans = fn_recursive(
+            x_unl,
+            fn=lambda x: torch.nan_to_num(x[nan]),
+            type_atom=torch.Tensor,
+            type_list=list,
+        )
+        y_trans = y_trans[nan].long()
+
+        return (x_lbl_trans, y_trans), x_unl_trans
+
+    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+        save_unfrozen(self, checkpoint)
+
+    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+        load_checkpoint_lenient(self, checkpoint)
+
+    # The following methods are to enforce the batch schema typing.
+    def training_step(
+        self,
+        batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]],
+        batch_idx: int,
+    ):
+        ...
+
+    def validation_step(
+        self,
+        batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]],
+        batch_idx: int,
+    ):
+        ...
+
+    def test_step(
+        self,
+        batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]],
+        batch_idx: int,
+    ):
+        ...
+
+    def predict_step(
+        self,
+        batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]],
+    ) -> Any:
+        ...
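# Illustrative sketch (not part of the patch): what FRDCModule now does with
# `out_targets`. It fits an OrdinalEncoder on the known labels and, in
# on_before_batch_transfer above, masks out samples whose label was never seen
# (e.g. a new tree at inference time). The label strings below are made up.
import numpy as np
import torch
from sklearn.preprocessing import OrdinalEncoder

out_targets = ["oak", "pine"]
y_encoder = OrdinalEncoder(
    handle_unknown="use_encoded_value", unknown_value=np.nan
)
y_encoder.fit(np.array(out_targets).reshape(-1, 1))

y_lbl = ["pine", "unseen-tree"]
y_trans = torch.from_numpy(
    y_encoder.transform(np.array(y_lbl).reshape(-1, 1))[..., 0]
)
# y_trans is tensor([1., nan], dtype=float64): the unseen label became NaN.
nan = ~torch.isnan(y_trans)  # keep only labels known to the encoder
print(y_trans[nan].long())   # tensor([1])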
diff --git a/src/frdc/train/mixmatch_module.py b/src/frdc/train/mixmatch_module.py index 9617751..3b85785 100644 --- a/src/frdc/train/mixmatch_module.py +++ b/src/frdc/train/mixmatch_module.py @@ -1,32 +1,28 @@ from __future__ import annotations from abc import abstractmethod -from typing import Any, Dict +from typing import Any, Dict, Sequence import torch import torch.nn.functional as F import wandb -from lightning import LightningModule -from sklearn.preprocessing import StandardScaler, OrdinalEncoder from torch.nn.functional import one_hot from torchmetrics.functional import accuracy -from frdc.models.utils import save_unfrozen, load_checkpoint_lenient +from frdc.models.utils import save_unfrozen +from frdc.train.frdc_module import FRDCModule from frdc.train.utils import ( mix_up, sharpen, wandb_hist, - preprocess, ) -class MixMatchModule(LightningModule): +class MixMatchModule(FRDCModule): def __init__( self, *, - x_scaler: StandardScaler, - y_encoder: OrdinalEncoder, - n_classes: int = 10, + out_targets: Sequence[str], sharpen_temp: float = 0.5, mix_beta_alpha: float = 0.75, ): @@ -44,17 +40,14 @@ def __init__( how to implement a new dataset. Args: - n_classes: The number of classes in the dataset. + out_targets: The output targets for the model. sharpen_temp: The temperature to use for sharpening. mix_beta_alpha: The alpha to use for the beta distribution when mixing. """ - super().__init__() + super().__init__(out_targets=out_targets) - self.x_scaler = x_scaler - self.y_encoder = y_encoder - self.n_classes = n_classes self.sharpen_temp = sharpen_temp self.mix_beta_alpha = mix_beta_alpha self.save_hyperparameters() @@ -114,7 +107,11 @@ def progress(self): self.global_step / self.trainer.num_training_batches ) / self.trainer.max_epochs - def training_step(self, batch, batch_idx): + def training_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ): (x_lbl, y_lbl), x_unls = batch self.log("train/x_lbl_mean", x_lbl.mean()) @@ -191,7 +188,11 @@ def training_step(self, batch, batch_idx): def on_after_backward(self) -> None: self.update_ema() - def validation_step(self, batch, batch_idx): + def validation_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ): (x, y), _x_unls = batch wandb.log({"val/y_lbl": wandb_hist(y, self.n_classes)}) y_pred = self.ema_model(x) @@ -211,7 +212,11 @@ def validation_step(self, batch, batch_idx): self.log("val/acc", acc, prog_bar=True) return loss - def test_step(self, batch, batch_idx): + def test_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ): (x, y), _x_unls = batch y_pred = self.ema_model(x) loss = F.cross_entropy(y_pred, y.long()) @@ -223,7 +228,10 @@ def test_step(self, batch, batch_idx): self.log("test/acc", acc, prog_bar=True) return loss - def predict_step(self, batch, *args, **kwargs) -> Any: + def predict_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + ) -> Any: (x, y), _x_unls = batch y_pred = self.ema_model(x) y_true_str = self.y_encoder.inverse_transform( @@ -234,40 +242,10 @@ def predict_step(self, batch, *args, **kwargs) -> Any: ) return y_true_str, y_pred_str - @torch.no_grad() - def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: - """This method is called before any data transfer to the device. - - We leverage this to do some preprocessing on the data. 
- Namely, we use the StandardScaler and OrdinalEncoder to transform the - data. - - Notes: - PyTorch Lightning may complain about this being on the Module - instead of the DataModule. However, this is intentional as we - want to export the model alongside the transformations. - """ - - if self.training: - (x_lbl, y_lbl), x_unl = batch - else: - x_lbl, y_lbl = batch - x_unl = None - - return preprocess( - x_lbl=x_lbl, - y_lbl=y_lbl, - x_scaler=self.x_scaler, - y_encoder=self.y_encoder, - x_unl=x_unl, - ) - def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + """This override the original method to save the EMAs as well.""" save_unfrozen( self, checkpoint, include_also=lambda k: k.startswith("_ema_model.fc."), ) - - def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - load_checkpoint_lenient(self, checkpoint) diff --git a/src/frdc/train/utils.py b/src/frdc/train/utils.py index 6e85fe7..b42ce91 100644 --- a/src/frdc/train/utils.py +++ b/src/frdc/train/utils.py @@ -52,109 +52,6 @@ def sharpen(y: torch.Tensor, temp: float) -> torch.Tensor: return y_sharp -def preprocess( - x_lbl: torch.Tensor, - y_lbl: torch.Tensor, - x_scaler: StandardScaler, - y_encoder: OrdinalEncoder, - x_unl: list[torch.Tensor] = None, - nan_mask: bool = True, -) -> tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]]: - """Preprocesses the data - - Notes: - The reason why x and y's preprocessing is coupled is due to the NaN - elimination step. The NaN elimination step is due to unseen labels by y - - fn_recursive is to recursively apply some function to a nested list. - This happens due to unlabelled being a list of tensors. - - Args: - x_lbl: The data to preprocess. - y_lbl: The labels to preprocess. - x_scaler: The StandardScaler to use. - y_encoder: The OrdinalEncoder to use. - x_unl: The unlabelled data to preprocess. - nan_mask: Whether to remove nan values from the batch. - - Returns: - The preprocessed data and labels. - """ - - x_unl = [] if x_unl is None else x_unl - - x_lbl_trans = x_standard_scale(x_scaler, x_lbl) - y_trans = y_encode(y_encoder, y_lbl) - x_unl_trans = fn_recursive( - x_unl, - fn=lambda x: x_standard_scale(x_scaler, x), - type_atom=torch.Tensor, - type_list=list, - ) - - # Remove nan values from the batch - # Ordinal Encoders can return a np.nan if the value is not in the - # categories. We will remove that from the batch. - nan = ( - ~torch.isnan(y_trans) if nan_mask else torch.ones_like(y_trans).bool() - ) - x_lbl_trans = x_lbl_trans[nan] - x_lbl_trans = torch.nan_to_num(x_lbl_trans) - x_unl_trans = fn_recursive( - x_unl_trans, - fn=lambda x: torch.nan_to_num(x[nan]), - type_atom=torch.Tensor, - type_list=list, - ) - y_trans = y_trans[nan] - - return (x_lbl_trans, y_trans.long()), x_unl_trans - - -def x_standard_scale( - x_scaler: StandardScaler, x: torch.Tensor -) -> torch.Tensor: - """Standard scales the data - - Notes: - This is a wrapper around the StandardScaler to handle PyTorch tensors. - - Args: - x_scaler: The StandardScaler to use. - x: The data to standard scale, of shape (B, C, H, W). - """ - # Standard Scaler only accepts (n_samples, n_features), - # so we need to do some fancy reshaping. - # Note that moving dimensions then reshaping is different from just - # reshaping! 
- - # Move Channel to the last dimension then transform - # B x C x H x W -> B x H x W x C - b, c, h, w = x.shape - x_ss = x_scaler.transform(x.permute(0, 2, 3, 1).reshape(-1, c)) - - # Move Channel back to the second dimension - # B x H x W x C -> B x C x H x W - return torch.nan_to_num( - torch.from_numpy(x_ss.reshape(b, h, w, c)).permute(0, 3, 1, 2).float() - ) - - -def y_encode(y_encoder: OrdinalEncoder, y: torch.Tensor) -> torch.Tensor: - """Encodes the labels - - Notes: - This is a wrapper around the OrdinalEncoder to handle PyTorch tensors. - - Args: - y_encoder: The OrdinalEncoder to use. - y: The labels to encode. - """ - return torch.from_numpy( - y_encoder.transform(np.array(y).reshape(-1, 1))[..., 0] - ) - - def wandb_hist(x: torch.Tensor, num_bins: int) -> wandb.Histogram: """Records a W&B Histogram""" return wandb.Histogram( diff --git a/tests/integration_tests/test_pipeline.py b/tests/integration_tests/test_pipeline.py index 97676dc..3de3e6b 100644 --- a/tests/integration_tests/test_pipeline.py +++ b/tests/integration_tests/test_pipeline.py @@ -2,10 +2,8 @@ from pathlib import Path import lightning as pl -import numpy as np import pytest import torch -from sklearn.preprocessing import StandardScaler, OrdinalEncoder from frdc.models.efficientnetb1 import ( EfficientNetB1MixMatchModule, @@ -33,23 +31,10 @@ def test_manual_segmentation_pipeline(model_fn, ds): val_ds=ds, batch_size=BATCH_SIZE, ) - - oe = OrdinalEncoder( - handle_unknown="use_encoded_value", - unknown_value=np.nan, - ) - oe.fit(np.array(ds.targets).reshape(-1, 1)) - n_classes = len(oe.categories_[0]) - - ss = StandardScaler() - ss.fit(ds.ar.reshape(-1, ds.ar.shape[-1])) - m = model_fn( in_channels=ds.ar.shape[-1], - n_classes=n_classes, lr=1e-3, - x_scaler=ss, - y_encoder=oe, + out_targets=ds.targets, frozen=True, ) @@ -93,4 +78,4 @@ def test_manual_segmentation_pipeline(model_fn, ds): # E.g. achieved via hash comparison. # This is because BatchNorm usually keeps running statistics # and reloading the model will reset them. - # We don't necessarily need to + # We don't necessarily need to check for this. 
diff --git a/tests/model_tests/chestnut_dec_may/train_fixmatch.py b/tests/model_tests/chestnut_dec_may/train_fixmatch.py index b1df238..c83bf43 100644 --- a/tests/model_tests/chestnut_dec_may/train_fixmatch.py +++ b/tests/model_tests/chestnut_dec_may/train_fixmatch.py @@ -26,8 +26,6 @@ val_preprocess, FRDCDatasetStaticEval, n_weak_strong_aug, - get_y_encoder, - get_x_scaler, weak_aug, ) @@ -58,9 +56,12 @@ def main( im_size = 255 train_lab_ds = ds.chestnut_20201218(transform=weak_aug(im_size)) train_unl_ds = ds.chestnut_20201218.unlabelled( - transform=n_weak_strong_aug(im_size, unlabelled_factor) + transform=n_weak_strong_aug(im_size, unlabelled_factor), + ) + val_ds = ds.chestnut_20210510_43m( + transform=val_preprocess(im_size), + transform_scale=train_lab_ds.x_scaler, ) - val_ds = ds.chestnut_20210510_43m(transform=val_preprocess(im_size)) # Prepare the datamodule and trainer dm = FRDCDataModule( @@ -95,15 +96,10 @@ def main( ), ) - oe = get_y_encoder(train_lab_ds.targets) - ss = get_x_scaler(train_lab_ds.ar_segments) - m = EfficientNetB1FixMatchModule( in_channels=train_lab_ds.ar.shape[-1], - n_classes=len(oe.categories_[0]), + out_targets=train_lab_ds.targets, lr=lr, - x_scaler=ss, - y_encoder=oe, frozen=True, ) @@ -121,10 +117,11 @@ def main( "20210510", "90deg43m85pct255deg", transform=val_preprocess(im_size), + transform_scale=train_lab_ds.x_scaler, ), model=m, ) - fig, ax = plot_confusion_matrix(y_true, y_pred, oe.categories_[0]) + fig, ax = plot_confusion_matrix(y_true, y_pred, m.y_encoder.categories_[0]) acc = np.sum(y_true == y_pred) / len(y_true) ax.set_title(f"Accuracy: {acc:.2%}") diff --git a/tests/model_tests/chestnut_dec_may/train_mixmatch.py b/tests/model_tests/chestnut_dec_may/train_mixmatch.py index 1bf839f..0aab51f 100644 --- a/tests/model_tests/chestnut_dec_may/train_mixmatch.py +++ b/tests/model_tests/chestnut_dec_may/train_mixmatch.py @@ -25,8 +25,6 @@ FRDCDatasetStaticEval, n_strong_aug, strong_aug, - get_y_encoder, - get_x_scaler, ) @@ -46,11 +44,16 @@ def main( ): # Prepare the dataset im_size = 299 - train_lab_ds = ds.chestnut_20201218(transform=strong_aug(im_size)) + train_lab_ds = ds.chestnut_20201218( + transform=strong_aug(im_size), + ) train_unl_ds = ds.chestnut_20201218.unlabelled( transform=n_strong_aug(im_size, 2) ) - val_ds = ds.chestnut_20210510_43m(transform=val_preprocess(im_size)) + val_ds = ds.chestnut_20210510_43m( + transform=val_preprocess(im_size), + transform_scale=train_lab_ds.x_scaler, + ) # Prepare the datamodule and trainer dm = FRDCDataModule( @@ -85,15 +88,10 @@ def main( ), ) - oe = get_y_encoder(train_lab_ds.targets) - ss = get_x_scaler(train_lab_ds.ar_segments) - m = EfficientNetB1MixMatchModule( in_channels=train_lab_ds.ar.shape[-1], - n_classes=len(oe.categories_[0]), + out_targets=train_lab_ds.targets, lr=lr, - x_scaler=ss, - y_encoder=oe, frozen=True, ) @@ -111,10 +109,11 @@ def main( "20210510", "90deg43m85pct255deg", transform=val_preprocess(im_size), + transform_scale=train_lab_ds.x_scaler, ), model=m, ) - fig, ax = plot_confusion_matrix(y_true, y_pred, oe.categories_[0]) + fig, ax = plot_confusion_matrix(y_true, y_pred, m.y_encoder.categories_[0]) acc = np.sum(y_true == y_pred) / len(y_true) ax.set_title(f"Accuracy: {acc:.2%}") diff --git a/tests/model_tests/utils.py b/tests/model_tests/utils.py index ac5b54a..e4816b3 100644 --- a/tests/model_tests/utils.py +++ b/tests/model_tests/utils.py @@ -132,11 +132,3 @@ def get_y_encoder(targets): ) oe.fit(np.array(targets).reshape(-1, 1)) return oe - - -def get_x_scaler(segments): - 
    ss = StandardScaler()
-    ss.fit(
-        np.concatenate([segm.reshape(-1, segm.shape[-1]) for segm in segments])
-    )
-    return ss
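# Illustrative sketch (not part of the patch): intended usage of the new
# `transform_scale` argument, mirroring train_fixmatch.py / train_mixmatch.py
# above. The import alias is an assumption, and running this needs the FRDC
# data to be downloadable.
from frdc.load.preset import FRDCDatasetPreset as ds

# transform_scale=True (the default): fit a StandardScaler on this dataset's
# own segments, reshaped from [H, W, C] to [H*W, C].
train_ds = ds.chestnut_20201218()

# Reuse the training scaler for validation instead of refitting, so both
# datasets share the same normalisation statistics.
val_ds = ds.chestnut_20210510_43m(transform_scale=train_ds.x_scaler)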