Refactor common modules to FRDCModule
Eve-ning committed Jun 10, 2024
1 parent 8497c3e commit ab9de7c
Showing 8 changed files with 212 additions and 201 deletions.
28 changes: 9 additions & 19 deletions src/frdc/models/efficientnetb1.py
@@ -1,16 +1,14 @@
from copy import deepcopy
from typing import Dict, Any
from typing import Sequence

import torch
from sklearn.preprocessing import OrdinalEncoder, StandardScaler
from torch import nn
from torchvision.models import (
EfficientNet,
efficientnet_b1,
EfficientNet_B1_Weights,
)

from frdc.models.utils import save_unfrozen, load_checkpoint_lenient
from frdc.train.fixmatch_module import FixMatchModule
from frdc.train.mixmatch_module import MixMatchModule
from frdc.utils.ema import EMA
@@ -81,9 +79,8 @@ def __init__(
self,
*,
in_channels: int,
n_classes: int,
out_targets: Sequence[str],
lr: float,
y_encoder: OrdinalEncoder,
ema_lr: float = 0.001,
weight_decay: float = 1e-5,
frozen: bool = True,
@@ -92,9 +89,8 @@ def __init__(
Args:
in_channels: The number of input channels.
n_classes: The number of classes.
out_targets: The output targets.
lr: The learning rate.
y_encoder: The Y input OrdinalEncoder.
ema_lr: The learning rate for the EMA model.
weight_decay: The weight decay.
frozen: Whether to freeze the base model.
@@ -106,15 +102,14 @@ def __init__(
self.weight_decay = weight_decay

super().__init__(
n_classes=n_classes,
y_encoder=y_encoder,
out_targets=out_targets,
sharpen_temp=0.5,
mix_beta_alpha=0.75,
)

self.eff = efficientnet_b1_backbone(in_channels, frozen)
self.fc = nn.Sequential(
nn.Linear(self.EFF_OUT_DIMS, n_classes),
nn.Linear(self.EFF_OUT_DIMS, self.n_classes),
nn.Softmax(dim=1),
)

@@ -152,19 +147,17 @@ def __init__(
self,
*,
in_channels: int,
n_classes: int,
out_targets: Sequence[str],
lr: float,
y_encoder: OrdinalEncoder,
weight_decay: float = 1e-5,
frozen: bool = True,
):
"""Initialize the EfficientNet model.
Args:
in_channels: The number of input channels.
n_classes: The number of classes.
out_targets: The output targets.
lr: The learning rate.
y_encoder: The Y input OrdinalEncoder.
weight_decay: The weight decay.
frozen: Whether to freeze the base model.
@@ -174,15 +167,12 @@ def __init__(
self.lr = lr
self.weight_decay = weight_decay

super().__init__(
n_classes=n_classes,
y_encoder=y_encoder,
)
super().__init__(out_targets=out_targets)

self.eff = efficientnet_b1_backbone(in_channels, frozen)

self.fc = nn.Sequential(
nn.Linear(self.EFF_OUT_DIMS, n_classes),
nn.Linear(self.EFF_OUT_DIMS, self.n_classes),
nn.Softmax(dim=1),
)

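The constructor now takes the output target names directly; the base FRDCModule builds and fits its own OrdinalEncoder from them, and n_classes is derived from the encoder's categories. A minimal sketch of the new call site, assuming the EfficientNetB1MixMatchModule class from this file and placeholder channel count, targets, and learning rate:

from frdc.models.efficientnetb1 import EfficientNetB1MixMatchModule

# The module fits its own OrdinalEncoder on out_targets; no pre-fitted
# y_encoder or explicit n_classes is passed any more.
model = EfficientNetB1MixMatchModule(
    in_channels=8,                     # placeholder channel count
    out_targets=["Tree A", "Tree B"],  # placeholder label strings
    lr=1e-3,                           # placeholder learning rate
)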
81 changes: 27 additions & 54 deletions src/frdc/train/fixmatch_module.py
@@ -1,28 +1,23 @@
from __future__ import annotations

from abc import abstractmethod
from typing import Any, Dict
from typing import Sequence

import torch
import torch.nn.functional as F
import wandb
from lightning import LightningModule
from sklearn.preprocessing import StandardScaler, OrdinalEncoder
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torchmetrics.functional import accuracy

from frdc.models.utils import save_unfrozen, load_checkpoint_lenient
from frdc.train.utils import (
wandb_hist,
preprocess,
)
from frdc.train.frdc_module import FRDCModule
from frdc.train.utils import wandb_hist


class FixMatchModule(LightningModule):
class FixMatchModule(FRDCModule):
def __init__(
self,
*,
y_encoder: OrdinalEncoder,
n_classes: int = 10,
out_targets: Sequence[str],
unl_conf_threshold: float = 0.95,
):
"""PyTorch Lightning Module for MixMatch
@@ -39,16 +34,13 @@ def __init__(
how to implement a new dataset.
Args:
n_classes: The number of classes in the dataset.
y_encoder: The OrdinalEncoder to use for the labels.
out_targets: The output targets for the model.
unl_conf_threshold: The confidence threshold for unlabelled data
to be considered correctly labelled.
"""

super().__init__()
super().__init__(out_targets=out_targets)

self.y_encoder = y_encoder
self.n_classes = n_classes
self.unl_conf_threshold = unl_conf_threshold
self.save_hyperparameters()

@@ -60,7 +52,11 @@ def __init__(
def forward(self, x):
...

def training_step(self, batch, batch_idx):
def training_step(
self,
batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]],
batch_idx: int,
):
"""A single training step for a batch
Notes:
@@ -91,7 +87,6 @@ def training_step(self, batch, batch_idx):
Loss: ℓ_lbl + ℓ_unl
"""

def training_step(self, batch, batch_idx):
(x_lbl, y_lbl), x_unls = batch
opt = self.optimizers()

@@ -172,7 +167,11 @@ def training_step(self, batch, batch_idx):
}
)

def validation_step(self, batch, batch_idx):
def validation_step(
self,
batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]],
batch_idx: int,
):
# The batch outputs x_unls due to our on_before_batch_transfer
(x, y), _x_unls = batch
wandb.log({"val/y_lbl": wandb_hist(y, self.n_classes)})
@@ -194,7 +193,11 @@ def validation_step(self, batch, batch_idx):
self.log("val/acc", acc, prog_bar=True)
return loss

def test_step(self, batch, batch_idx):
def test_step(
self,
batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]],
batch_idx: int,
) -> STEP_OUTPUT:
# The batch outputs x_unls due to our on_before_batch_transfer
(x, y), _x_unls = batch
y_pred = self(x)
@@ -207,7 +210,10 @@ def test_step(self, batch, batch_idx):
self.log("test/acc", acc, prog_bar=True)
return loss

def predict_step(self, batch, *args, **kwargs) -> Any:
def predict_step(
self,
batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]],
):
(x, y), _x_unls = batch
y_pred = self(x)
y_true_str = self.y_encoder.inverse_transform(
@@ -217,36 +223,3 @@ def predict_step(self, batch, *args, **kwargs) -> Any:
y_pred.argmax(dim=1).cpu().numpy().reshape(-1, 1)
)
return y_true_str, y_pred_str

@torch.no_grad()
def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
"""This method is called before any data transfer to the device.
We leverage this to do some preprocessing on the data.
Namely, we use the StandardScaler and OrdinalEncoder to transform the
data.
Notes:
PyTorch Lightning may complain about this being on the Module
instead of the DataModule. However, this is intentional as we
want to export the model alongside the transformations.
"""

if self.training:
(x_lbl, y_lbl), x_unl = batch
else:
x_lbl, y_lbl = batch
x_unl = None

return preprocess(
x_lbl=x_lbl,
y_lbl=y_lbl,
y_encoder=self.y_encoder,
x_unl=x_unl,
)

def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
save_unfrozen(self, checkpoint)

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
load_checkpoint_lenient(self, checkpoint)
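With the preprocessing and checkpoint hooks moved into FRDCModule, the step methods above mainly add type annotations for the shared batch schema. A small illustrative batch matching tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]], with placeholder shapes:

import torch

x_lbl = torch.rand(4, 3, 32, 32)                       # labelled inputs (placeholder shape)
y_lbl = torch.randint(0, 10, (4,))                     # integer-encoded labels
x_unls = [torch.rand(4, 3, 32, 32) for _ in range(2)]  # unlabelled augmented views

batch = ((x_lbl, y_lbl), x_unls)
# Validation, test and predict steps receive the same schema, but with an
# empty x_unls list, as produced by FRDCModule.on_before_batch_transfer.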
143 changes: 143 additions & 0 deletions src/frdc/train/frdc_module.py
@@ -0,0 +1,143 @@
from typing import Any, Dict, Sequence

import numpy as np
import torch
from lightning import LightningModule
from sklearn.preprocessing import OrdinalEncoder

from frdc.models.utils import save_unfrozen, load_checkpoint_lenient
from frdc.utils.utils import fn_recursive


class FRDCModule(LightningModule):
def __init__(
self,
*,
out_targets: Sequence[str],
nan_mask_missing_y_labels: bool = True,
):
"""Base Lightning Module for MixMatch
Notes:
This is the base class for MixMatch and FixMatch.
This implements the Y-Encoder logic so that all modules can
encode and decode the tree string labels.
Generally the hierarchy is:
<Model><Architecture>Module
-> <Architecture>Module
-> FRDCModule
E.g.
EfficientNetB1MixMatchModule
-> MixMatchModule
-> FRDCModule
WideResNetFixMatchModule
-> FixMatchModule
-> FRDCModule
Args:
out_targets: The output targets for the model.
nan_mask_missing_y_labels: Whether to mask away x values that
have missing y labels. This happens when the y label is not
present in the OrdinalEncoder's categories, which happens
during non-training steps, e.g. when a new, unseen tree is inferred.
"""

super().__init__()

self.y_encoder: OrdinalEncoder = OrdinalEncoder(
handle_unknown="use_encoded_value",
unknown_value=np.nan,
)
self.y_encoder.fit(np.array(out_targets).reshape(-1, 1))
self.nan_mask_missing_y_labels = nan_mask_missing_y_labels
self.save_hyperparameters()

@property
def n_classes(self):
return len(self.y_encoder.categories_[0])

@torch.no_grad()
def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
"""This method is called before any data transfer to the device.
Notes:
This method wraps OrdinalEncoder to convert labels from str to int
before transferring the data to the device.
Note that this step must happen before the transfer as tensors
don't support str types.
PyTorch Lightning may complain about this being on the Module
instead of the DataModule. However, this is intentional as we
want to export the model alongside the transformations.
"""

if self.training:
(x_lbl, y_lbl), x_unl = batch
else:
x_lbl, y_lbl = batch
x_unl = []

y_trans = torch.from_numpy(
self.y_encoder.transform(np.array(y_lbl).reshape(-1, 1))[..., 0]
)

# Remove NaN values from the batch.
# The OrdinalEncoder returns np.nan for values not in its categories;
# we remove those entries from the batch.
nan = (
~torch.isnan(y_trans) # Keeps all non-nan values
if self.nan_mask_missing_y_labels
else torch.ones_like(y_trans).bool() # Keeps all values
)

x_lbl_trans = torch.nan_to_num(x_lbl[nan])

# This function applies nan_to_num to all tensors in the list,
# regardless of how deeply nested they are.
x_unl_trans = fn_recursive(
x_unl,
fn=lambda x: torch.nan_to_num(x[nan]),
type_atom=torch.Tensor,
type_list=list,
)
y_trans = y_trans[nan].long()

return (x_lbl_trans, y_trans), x_unl_trans

def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
save_unfrozen(self, checkpoint)

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
load_checkpoint_lenient(self, checkpoint)

# The following methods are to enforce the batch schema typing.
def training_step(
self,
batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]],
batch_idx: int,
):
...

def validation_step(
self,
batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]],
batch_idx: int,
):
...

def test_step(
self,
batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]],
batch_idx: int,
):
...

def predict_step(
self,
batch: tuple[tuple[torch.Tensor, torch.Tensor], list[torch.Tensor]],
) -> Any:
...
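To illustrate the unknown-label handling above, here is a standalone sketch of the OrdinalEncoder plus NaN-mask behaviour, using placeholder target names rather than real tree labels:

import numpy as np
import torch
from sklearn.preprocessing import OrdinalEncoder

out_targets = ["Oak", "Pine"]  # placeholder target names
enc = OrdinalEncoder(
    handle_unknown="use_encoded_value",
    unknown_value=np.nan,
)
enc.fit(np.array(out_targets).reshape(-1, 1))

y_lbl = ["Oak", "Birch", "Pine"]  # "Birch" was never fitted
y_trans = torch.from_numpy(
    enc.transform(np.array(y_lbl).reshape(-1, 1))[..., 0]
)  # tensor([0., nan, 1.])

nan = ~torch.isnan(y_trans)       # mask that keeps only known labels
x_lbl = torch.rand(3, 8, 32, 32)  # dummy inputs aligned with y_lbl
x_kept, y_kept = x_lbl[nan], y_trans[nan].long()
# y_kept == tensor([0, 1]); the "Birch" row is masked away, mirroring the
# nan_mask_missing_y_labels=True behaviour in on_before_batch_transfer.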