From ab9de7ca2e5f8601b32936ff868494a15a91b50e Mon Sep 17 00:00:00 2001 From: Evening Date: Mon, 10 Jun 2024 12:40:21 +0800 Subject: [PATCH] Refactor common modules to FRDCModule --- src/frdc/models/efficientnetb1.py | 28 ++-- src/frdc/train/fixmatch_module.py | 81 ++++------ src/frdc/train/frdc_module.py | 143 ++++++++++++++++++ src/frdc/train/mixmatch_module.py | 73 ++++----- src/frdc/train/utils.py | 52 ------- tests/integration_tests/test_pipeline.py | 18 +-- .../chestnut_dec_may/train_fixmatch.py | 7 +- .../chestnut_dec_may/train_mixmatch.py | 11 +- 8 files changed, 212 insertions(+), 201 deletions(-) create mode 100644 src/frdc/train/frdc_module.py diff --git a/src/frdc/models/efficientnetb1.py b/src/frdc/models/efficientnetb1.py index d793249..ae87d8d 100644 --- a/src/frdc/models/efficientnetb1.py +++ b/src/frdc/models/efficientnetb1.py @@ -1,8 +1,7 @@ from copy import deepcopy -from typing import Dict, Any +from typing import Sequence import torch -from sklearn.preprocessing import OrdinalEncoder, StandardScaler from torch import nn from torchvision.models import ( EfficientNet, @@ -10,7 +9,6 @@ EfficientNet_B1_Weights, ) -from frdc.models.utils import save_unfrozen, load_checkpoint_lenient from frdc.train.fixmatch_module import FixMatchModule from frdc.train.mixmatch_module import MixMatchModule from frdc.utils.ema import EMA @@ -81,9 +79,8 @@ def __init__( self, *, in_channels: int, - n_classes: int, + out_targets: Sequence[str], lr: float, - y_encoder: OrdinalEncoder, ema_lr: float = 0.001, weight_decay: float = 1e-5, frozen: bool = True, @@ -92,9 +89,8 @@ def __init__( Args: in_channels: The number of input channels. - n_classes: The number of classes. + out_targets: The output targets. lr: The learning rate. - y_encoder: The Y input OrdinalEncoder. ema_lr: The learning rate for the EMA model. weight_decay: The weight decay. frozen: Whether to freeze the base model. @@ -106,15 +102,14 @@ def __init__( self.weight_decay = weight_decay super().__init__( - n_classes=n_classes, - y_encoder=y_encoder, + out_targets=out_targets, sharpen_temp=0.5, mix_beta_alpha=0.75, ) self.eff = efficientnet_b1_backbone(in_channels, frozen) self.fc = nn.Sequential( - nn.Linear(self.EFF_OUT_DIMS, n_classes), + nn.Linear(self.EFF_OUT_DIMS, self.n_classes), nn.Softmax(dim=1), ) @@ -152,9 +147,8 @@ def __init__( self, *, in_channels: int, - n_classes: int, + out_targets: Sequence[str], lr: float, - y_encoder: OrdinalEncoder, weight_decay: float = 1e-5, frozen: bool = True, ): @@ -162,9 +156,8 @@ def __init__( Args: in_channels: The number of input channels. - n_classes: The number of classes. + out_targets: The output targets. lr: The learning rate. - y_encoder: The Y input OrdinalEncoder. weight_decay: The weight decay. frozen: Whether to freeze the base model. @@ -174,15 +167,12 @@ def __init__( self.lr = lr self.weight_decay = weight_decay - super().__init__( - n_classes=n_classes, - y_encoder=y_encoder, - ) + super().__init__(out_targets=out_targets) self.eff = efficientnet_b1_backbone(in_channels, frozen) self.fc = nn.Sequential( - nn.Linear(self.EFF_OUT_DIMS, n_classes), + nn.Linear(self.EFF_OUT_DIMS, self.n_classes), nn.Softmax(dim=1), ) diff --git a/src/frdc/train/fixmatch_module.py b/src/frdc/train/fixmatch_module.py index 93e8c9a..f29316c 100644 --- a/src/frdc/train/fixmatch_module.py +++ b/src/frdc/train/fixmatch_module.py @@ -1,28 +1,23 @@ from __future__ import annotations from abc import abstractmethod -from typing import Any, Dict +from typing import Sequence import torch import torch.nn.functional as F import wandb -from lightning import LightningModule -from sklearn.preprocessing import StandardScaler, OrdinalEncoder +from lightning.pytorch.utilities.types import STEP_OUTPUT from torchmetrics.functional import accuracy -from frdc.models.utils import save_unfrozen, load_checkpoint_lenient -from frdc.train.utils import ( - wandb_hist, - preprocess, -) +from frdc.train.frdc_module import FRDCModule +from frdc.train.utils import wandb_hist -class FixMatchModule(LightningModule): +class FixMatchModule(FRDCModule): def __init__( self, *, - y_encoder: OrdinalEncoder, - n_classes: int = 10, + out_targets: Sequence[str], unl_conf_threshold: float = 0.95, ): """PyTorch Lightning Module for MixMatch @@ -39,16 +34,13 @@ def __init__( how to implement a new dataset. Args: - n_classes: The number of classes in the dataset. - y_encoder: The OrdinalEncoder to use for the labels. + out_targets: The output targets for the model. unl_conf_threshold: The confidence threshold for unlabelled data to be considered correctly labelled. """ - super().__init__() + super().__init__(out_targets=out_targets) - self.y_encoder = y_encoder - self.n_classes = n_classes self.unl_conf_threshold = unl_conf_threshold self.save_hyperparameters() @@ -60,7 +52,11 @@ def __init__( def forward(self, x): ... - def training_step(self, batch, batch_idx): + def training_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ): """A single training step for a batch Notes: @@ -91,7 +87,6 @@ def training_step(self, batch, batch_idx): Loss: ℓ_lbl + ℓ_unl """ - def training_step(self, batch, batch_idx): (x_lbl, y_lbl), x_unls = batch opt = self.optimizers() @@ -172,7 +167,11 @@ def training_step(self, batch, batch_idx): } ) - def validation_step(self, batch, batch_idx): + def validation_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ): # The batch outputs x_unls due to our on_before_batch_transfer (x, y), _x_unls = batch wandb.log({"val/y_lbl": wandb_hist(y, self.n_classes)}) @@ -194,7 +193,11 @@ def validation_step(self, batch, batch_idx): self.log("val/acc", acc, prog_bar=True) return loss - def test_step(self, batch, batch_idx): + def test_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ) -> STEP_OUTPUT: # The batch outputs x_unls due to our on_before_batch_transfer (x, y), _x_unls = batch y_pred = self(x) @@ -207,7 +210,10 @@ def test_step(self, batch, batch_idx): self.log("test/acc", acc, prog_bar=True) return loss - def predict_step(self, batch, *args, **kwargs) -> Any: + def predict_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + ): (x, y), _x_unls = batch y_pred = self(x) y_true_str = self.y_encoder.inverse_transform( @@ -217,36 +223,3 @@ def predict_step(self, batch, *args, **kwargs) -> Any: y_pred.argmax(dim=1).cpu().numpy().reshape(-1, 1) ) return y_true_str, y_pred_str - - @torch.no_grad() - def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: - """This method is called before any data transfer to the device. - - We leverage this to do some preprocessing on the data. - Namely, we use the StandardScaler and OrdinalEncoder to transform the - data. - - Notes: - PyTorch Lightning may complain about this being on the Module - instead of the DataModule. However, this is intentional as we - want to export the model alongside the transformations. - """ - - if self.training: - (x_lbl, y_lbl), x_unl = batch - else: - x_lbl, y_lbl = batch - x_unl = None - - return preprocess( - x_lbl=x_lbl, - y_lbl=y_lbl, - y_encoder=self.y_encoder, - x_unl=x_unl, - ) - - def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - save_unfrozen(self, checkpoint) - - def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - load_checkpoint_lenient(self, checkpoint) diff --git a/src/frdc/train/frdc_module.py b/src/frdc/train/frdc_module.py new file mode 100644 index 0000000..22582e9 --- /dev/null +++ b/src/frdc/train/frdc_module.py @@ -0,0 +1,143 @@ +from typing import Any, Dict, Sequence + +import numpy as np +import torch +from lightning import LightningModule +from sklearn.preprocessing import OrdinalEncoder + +from frdc.models.utils import save_unfrozen, load_checkpoint_lenient +from frdc.utils.utils import fn_recursive + + +class FRDCModule(LightningModule): + def __init__( + self, + *, + out_targets: Sequence[str], + nan_mask_missing_y_labels: bool = True, + ): + """Base Lightning Module for MixMatch + + Notes: + This is the base class for MixMatch and FixMatch. + This implements the Y-Encoder logic so that all modules can + encode and decode the tree string labels. + + Generally the hierarchy is: + Module + -> Module + -> FRDCModule + + E.g. + EfficientNetB1MixMatchModule + -> MixMatchModule + -> FRDCModule + + WideResNetFixMatchModule + -> FixMatchModule + -> FRDCModule + + Args: + out_targets: The output targets for the model. + nan_mask_missing_y_labels: Whether to mask away x values that + have missing y labels. This happens when the y label is not + present in the OrdinalEncoder's categories, which happens + during non-training steps. E.g. A new unseen tree is inferred. + """ + + super().__init__() + + self.y_encoder: OrdinalEncoder = OrdinalEncoder( + handle_unknown="use_encoded_value", + unknown_value=np.nan, + ) + self.y_encoder.fit(np.array(out_targets).reshape(-1, 1)) + self.nan_mask_missing_y_labels = nan_mask_missing_y_labels + self.save_hyperparameters() + + @property + def n_classes(self): + return len(self.y_encoder.categories_[0]) + + @torch.no_grad() + def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: + """This method is called before any data transfer to the device. + + Notes: + This method wraps OrdinalEncoder to convert labels from str to int + before transferring the data to the device. + + Note that this step must happen before the transfer as tensors + don't support str types. + + PyTorch Lightning may complain about this being on the Module + instead of the DataModule. However, this is intentional as we + want to export the model alongside the transformations. + """ + + if self.training: + (x_lbl, y_lbl), x_unl = batch + else: + x_lbl, y_lbl = batch + x_unl = [] + + y_trans = torch.from_numpy( + self.y_encoder.transform(np.array(y_lbl).reshape(-1, 1))[..., 0] + ) + + # Remove nan values from the batch + # Ordinal Encoders can return a np.nan if the value is not in the + # categories. We will remove that from the batch. + nan = ( + ~torch.isnan(y_trans) # Keeps all non-nan values + if self.nan_mask_missing_y_labels + else torch.ones_like(y_trans).bool() # Keeps all values + ) + + x_lbl_trans = torch.nan_to_num(x_lbl[nan]) + + # This function applies nan_to_num to all tensors in the list, + # regardless of how deeply nested they are. + x_unl_trans = fn_recursive( + x_unl, + fn=lambda x: torch.nan_to_num(x[nan]), + type_atom=torch.Tensor, + type_list=list, + ) + y_trans = y_trans[nan].long() + + return (x_lbl_trans, y_trans), x_unl_trans + + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + save_unfrozen(self, checkpoint) + + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + load_checkpoint_lenient(self, checkpoint) + + # The following methods are to enforce the batch schema typing. + def training_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ): + ... + + def validation_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ): + ... + + def test_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ): + ... + + def predict_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + ) -> Any: + ... diff --git a/src/frdc/train/mixmatch_module.py b/src/frdc/train/mixmatch_module.py index cf6442a..3b85785 100644 --- a/src/frdc/train/mixmatch_module.py +++ b/src/frdc/train/mixmatch_module.py @@ -1,31 +1,28 @@ from __future__ import annotations from abc import abstractmethod -from typing import Any, Dict +from typing import Any, Dict, Sequence import torch import torch.nn.functional as F import wandb -from lightning import LightningModule -from sklearn.preprocessing import StandardScaler, OrdinalEncoder from torch.nn.functional import one_hot from torchmetrics.functional import accuracy -from frdc.models.utils import save_unfrozen, load_checkpoint_lenient +from frdc.models.utils import save_unfrozen +from frdc.train.frdc_module import FRDCModule from frdc.train.utils import ( mix_up, sharpen, wandb_hist, - preprocess, ) -class MixMatchModule(LightningModule): +class MixMatchModule(FRDCModule): def __init__( self, *, - y_encoder: OrdinalEncoder, - n_classes: int = 10, + out_targets: Sequence[str], sharpen_temp: float = 0.5, mix_beta_alpha: float = 0.75, ): @@ -43,16 +40,14 @@ def __init__( how to implement a new dataset. Args: - n_classes: The number of classes in the dataset. + out_targets: The output targets for the model. sharpen_temp: The temperature to use for sharpening. mix_beta_alpha: The alpha to use for the beta distribution when mixing. """ - super().__init__() + super().__init__(out_targets=out_targets) - self.y_encoder = y_encoder - self.n_classes = n_classes self.sharpen_temp = sharpen_temp self.mix_beta_alpha = mix_beta_alpha self.save_hyperparameters() @@ -112,7 +107,11 @@ def progress(self): self.global_step / self.trainer.num_training_batches ) / self.trainer.max_epochs - def training_step(self, batch, batch_idx): + def training_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ): (x_lbl, y_lbl), x_unls = batch self.log("train/x_lbl_mean", x_lbl.mean()) @@ -189,7 +188,11 @@ def training_step(self, batch, batch_idx): def on_after_backward(self) -> None: self.update_ema() - def validation_step(self, batch, batch_idx): + def validation_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ): (x, y), _x_unls = batch wandb.log({"val/y_lbl": wandb_hist(y, self.n_classes)}) y_pred = self.ema_model(x) @@ -209,7 +212,11 @@ def validation_step(self, batch, batch_idx): self.log("val/acc", acc, prog_bar=True) return loss - def test_step(self, batch, batch_idx): + def test_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + batch_idx: int, + ): (x, y), _x_unls = batch y_pred = self.ema_model(x) loss = F.cross_entropy(y_pred, y.long()) @@ -221,7 +228,10 @@ def test_step(self, batch, batch_idx): self.log("test/acc", acc, prog_bar=True) return loss - def predict_step(self, batch, *args, **kwargs) -> Any: + def predict_step( + self, + batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], + ) -> Any: (x, y), _x_unls = batch y_pred = self.ema_model(x) y_true_str = self.y_encoder.inverse_transform( @@ -232,39 +242,10 @@ def predict_step(self, batch, *args, **kwargs) -> Any: ) return y_true_str, y_pred_str - @torch.no_grad() - def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: - """This method is called before any data transfer to the device. - - We leverage this to do some preprocessing on the data. - Namely, we use the StandardScaler and OrdinalEncoder to transform the - data. - - Notes: - PyTorch Lightning may complain about this being on the Module - instead of the DataModule. However, this is intentional as we - want to export the model alongside the transformations. - """ - - if self.training: - (x_lbl, y_lbl), x_unl = batch - else: - x_lbl, y_lbl = batch - x_unl = None - - return preprocess( - x_lbl=x_lbl, - y_lbl=y_lbl, - y_encoder=self.y_encoder, - x_unl=x_unl, - ) - def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + """This override the original method to save the EMAs as well.""" save_unfrozen( self, checkpoint, include_also=lambda k: k.startswith("_ema_model.fc."), ) - - def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - load_checkpoint_lenient(self, checkpoint) diff --git a/src/frdc/train/utils.py b/src/frdc/train/utils.py index bb1b2e4..b42ce91 100644 --- a/src/frdc/train/utils.py +++ b/src/frdc/train/utils.py @@ -52,58 +52,6 @@ def sharpen(y: torch.Tensor, temp: float) -> torch.Tensor: return y_sharp -def preprocess( - x_lbl: torch.Tensor, - y_lbl: torch.Tensor, - y_encoder: OrdinalEncoder, - x_unl: list[torch.Tensor] = None, - nan_mask: bool = True, -) -> tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]]: - """Preprocesses the data - - Notes: - The reason why x and y's preprocessing is coupled is due to the NaN - elimination step. The NaN elimination step is due to unseen labels by y - - fn_recursive is to recursively apply some function to a nested list. - This happens due to unlabelled being a list of tensors. - - Args: - x_lbl: The data to preprocess. - y_lbl: The labels to preprocess. - y_encoder: The OrdinalEncoder to use. - x_unl: The unlabelled data to preprocess. - nan_mask: Whether to remove nan values from the batch. - - Returns: - The preprocessed data and labels. - """ - - x_unl = [] if x_unl is None else x_unl - - y_trans = torch.from_numpy( - y_encoder.transform(np.array(y_lbl).reshape(-1, 1))[..., 0] - ) - - # Remove nan values from the batch - # Ordinal Encoders can return a np.nan if the value is not in the - # categories. We will remove that from the batch. - nan = ( - ~torch.isnan(y_trans) if nan_mask else torch.ones_like(y_trans).bool() - ) - x_lbl_trans = x_lbl[nan] - x_lbl_trans = torch.nan_to_num(x_lbl_trans) - x_unl_trans = fn_recursive( - x_unl, - fn=lambda x: torch.nan_to_num(x[nan]), - type_atom=torch.Tensor, - type_list=list, - ) - y_trans = y_trans[nan] - - return (x_lbl_trans, y_trans.long()), x_unl_trans - - def wandb_hist(x: torch.Tensor, num_bins: int) -> wandb.Histogram: """Records a W&B Histogram""" return wandb.Histogram( diff --git a/tests/integration_tests/test_pipeline.py b/tests/integration_tests/test_pipeline.py index 9db094c..3de3e6b 100644 --- a/tests/integration_tests/test_pipeline.py +++ b/tests/integration_tests/test_pipeline.py @@ -2,10 +2,8 @@ from pathlib import Path import lightning as pl -import numpy as np import pytest import torch -from sklearn.preprocessing import StandardScaler, OrdinalEncoder from frdc.models.efficientnetb1 import ( EfficientNetB1MixMatchModule, @@ -33,22 +31,10 @@ def test_manual_segmentation_pipeline(model_fn, ds): val_ds=ds, batch_size=BATCH_SIZE, ) - - oe = OrdinalEncoder( - handle_unknown="use_encoded_value", - unknown_value=np.nan, - ) - oe.fit(np.array(ds.targets).reshape(-1, 1)) - n_classes = len(oe.categories_[0]) - - ss = StandardScaler() - ss.fit(ds.ar.reshape(-1, ds.ar.shape[-1])) - m = model_fn( in_channels=ds.ar.shape[-1], - n_classes=n_classes, lr=1e-3, - y_encoder=oe, + out_targets=ds.targets, frozen=True, ) @@ -92,4 +78,4 @@ def test_manual_segmentation_pipeline(model_fn, ds): # E.g. achieved via hash comparison. # This is because BatchNorm usually keeps running statistics # and reloading the model will reset them. - # We don't necessarily need to + # We don't necessarily need to check for this. diff --git a/tests/model_tests/chestnut_dec_may/train_fixmatch.py b/tests/model_tests/chestnut_dec_may/train_fixmatch.py index 7975883..b5804de 100644 --- a/tests/model_tests/chestnut_dec_may/train_fixmatch.py +++ b/tests/model_tests/chestnut_dec_may/train_fixmatch.py @@ -94,13 +94,10 @@ def main( ), ) - oe = get_y_encoder(train_lab_ds.targets) - m = EfficientNetB1FixMatchModule( in_channels=train_lab_ds.ar.shape[-1], - n_classes=len(oe.categories_[0]), + out_targets=train_lab_ds.targets, lr=lr, - y_encoder=oe, frozen=True, ) @@ -121,7 +118,7 @@ def main( ), model=m, ) - fig, ax = plot_confusion_matrix(y_true, y_pred, oe.categories_[0]) + fig, ax = plot_confusion_matrix(y_true, y_pred, m.y_encoder.categories_[0]) acc = np.sum(y_true == y_pred) / len(y_true) ax.set_title(f"Accuracy: {acc:.2%}") diff --git a/tests/model_tests/chestnut_dec_may/train_mixmatch.py b/tests/model_tests/chestnut_dec_may/train_mixmatch.py index 1bf839f..7b85260 100644 --- a/tests/model_tests/chestnut_dec_may/train_mixmatch.py +++ b/tests/model_tests/chestnut_dec_may/train_mixmatch.py @@ -25,8 +25,6 @@ FRDCDatasetStaticEval, n_strong_aug, strong_aug, - get_y_encoder, - get_x_scaler, ) @@ -85,15 +83,10 @@ def main( ), ) - oe = get_y_encoder(train_lab_ds.targets) - ss = get_x_scaler(train_lab_ds.ar_segments) - m = EfficientNetB1MixMatchModule( in_channels=train_lab_ds.ar.shape[-1], - n_classes=len(oe.categories_[0]), + out_targets=train_lab_ds.targets, lr=lr, - x_scaler=ss, - y_encoder=oe, frozen=True, ) @@ -114,7 +107,7 @@ def main( ), model=m, ) - fig, ax = plot_confusion_matrix(y_true, y_pred, oe.categories_[0]) + fig, ax = plot_confusion_matrix(y_true, y_pred, m.y_encoder.categories_[0]) acc = np.sum(y_true == y_pred) / len(y_true) ax.set_title(f"Accuracy: {acc:.2%}")