Merge pull request #69 from FR-DC/frml-149
FRML-149 Make Partial saving function more generic
Eve-ning authored Jun 6, 2024
2 parents 71534c6 + ce8f7bb commit a8247ac
Showing 4 changed files with 67 additions and 54 deletions.
30 changes: 1 addition & 29 deletions src/frdc/models/efficientnetb1.py
@@ -10,7 +10,7 @@
EfficientNet_B1_Weights,
)

-from frdc.models.utils import on_save_checkpoint, on_load_checkpoint
+from frdc.models.utils import save_unfrozen, load_checkpoint_lenient
from frdc.train.fixmatch_module import FixMatchModule
from frdc.train.mixmatch_module import MixMatchModule
from frdc.utils.ema import EMA
@@ -139,29 +139,13 @@ def update_ema(self):
self.ema_updater.update(self.ema_lr)

def forward(self, x: torch.Tensor):
"""Forward pass."""
return self.fc(self.eff(x))

def configure_optimizers(self):
return torch.optim.Adam(
self.parameters(), lr=self.lr, weight_decay=self.weight_decay
)

-def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-# TODO: MixMatch's saving is a bit complicated due to the dependency
-# on the EMA model. This only saves the FC for both the
-# main model and the EMA model.
-# This may be the reason certain things break when loading
-if checkpoint["hyper_parameters"]["frozen"]:
-on_save_checkpoint(
-self,
-checkpoint,
-saved_module_prefixes=("_ema_model.fc.", "fc."),
-)
-
-def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-on_load_checkpoint(self, checkpoint)


class EfficientNetB1FixMatchModule(FixMatchModule):
MIN_SIZE = 255
@@ -209,21 +193,9 @@ def __init__(
)

def forward(self, x: torch.Tensor):
"""Forward pass."""
return self.fc(self.eff(x))

def configure_optimizers(self):
return torch.optim.Adam(
self.parameters(), lr=self.lr, weight_decay=self.weight_decay
)

-def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-if checkpoint["hyper_parameters"]["frozen"]:
-on_save_checkpoint(
-self,
-checkpoint,
-saved_module_prefixes=("fc.",),
-)
-
-def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-on_load_checkpoint(self, checkpoint)
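The hooks removed above hard-coded which key prefixes to keep ("fc.", "_ema_model.fc."). The generic replacement keys off requires_grad instead, which only works because freezing the EfficientNet backbone leaves just the classifier head trainable. A minimal sketch of that premise, assuming a torchvision EfficientNet-B1 backbone with a separate fc head; TinyClassifier and its sizes are illustrative, not FRDC code:

```python
import torch.nn as nn
from torchvision.models import efficientnet_b1


class TinyClassifier(nn.Module):
    def __init__(self, n_classes: int, frozen: bool = True):
        super().__init__()
        self.eff = efficientnet_b1(weights=None)
        self.eff.classifier = nn.Identity()  # keep only the feature extractor
        self.fc = nn.Linear(1280, n_classes)  # EfficientNet-B1 emits 1280 features
        if frozen:
            # Freeze the backbone: only the classifier head keeps requires_grad
            for p in self.eff.parameters():
                p.requires_grad = False


model = TinyClassifier(n_classes=3)
print([k for k, p in model.named_parameters() if p.requires_grad])
# -> ['fc.weight', 'fc.bias']: the only weights a requires-grad rule would keep
```

With the frozen flag behaving this way, the model subclasses no longer need to override the checkpoint hooks themselves; the shared train modules below take over.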
68 changes: 45 additions & 23 deletions src/frdc/models/utils.py
@@ -1,15 +1,12 @@
-from typing import Dict, Any, Sequence
+import logging
+import warnings
+from typing import Dict, Any, Sequence, Callable


-def on_save_checkpoint(
+def save_unfrozen(
self,
checkpoint: Dict[str, Any],
-saved_module_prefixes: Sequence[str] = ("fc.",),
-saved_module_suffixes: Sequence[str] = (
-"running_mean",
-"running_var",
-"num_batches_tracked",
-),
+include_also: Callable[[str], bool] = lambda k: False,
) -> None:
"""Saving only the classifier if frozen.
@@ -20,29 +17,54 @@ def on_save_checkpoint(
This usually reduces the model size by 99.9%, so it's worth it.
-By default, this will save the classifier and the BatchNorm running
-statistics.
+By default, this will save any parameter that requires grad
+and the BatchNorm running statistics.
Args:
self: Not used, but kept for consistency with on_load_checkpoint.
checkpoint: The checkpoint to save.
-saved_module_prefixes: The prefixes of the modules to save.
-saved_module_suffixes: The suffixes of the modules to save.
+include_also: A function that returns whether to include a parameter,
+on top of any parameter that requires grad and BatchNorm running
+statistics.
"""

-if checkpoint["hyper_parameters"]["frozen"]:
-# Keep only the classifier
-checkpoint["state_dict"] = {
-k: v
-for k, v in checkpoint["state_dict"].items()
-if (
-k.startswith(saved_module_prefixes)
-or k.endswith(saved_module_suffixes)
-)
-}
+# Keep only the still-trainable parameters and BatchNorm statistics
+new_state_dict = {}
+
+for k, v in checkpoint["state_dict"].items():
+# We keep 2 things,
+# 1. The BatchNorm running statistics
+# 2. Anything that requires grad
+
-def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+# BatchNorm running statistics should be kept
+# for closer reconstruction of the model
+is_bn_var = k.endswith(
+("running_mean", "running_var", "num_batches_tracked")
+)
+try:
+# We need to check the original model, as tensors in the
+# state dict do not carry requires_grad information
+is_required_grad = self.get_parameter(k).requires_grad
+except AttributeError:
+# Not an nn.Parameter (e.g. a buffer), so requires_grad doesn't apply
+is_required_grad = False
+if not is_bn_var:
+warnings.warn(
+f"Unknown non-parameter key in state_dict: {k}. "
+f"This is an edge case where it's neither a parameter nor "
+f"BatchNorm running statistics. It will still be saved."
+)
+is_required_grad = True
+
+# These are additional parameters to keep
+is_include = include_also(k)
+
+if is_required_grad or is_bn_var or is_include:
+logging.debug(f"Keeping {k}")
+new_state_dict[k] = v
+
+checkpoint["state_dict"] = new_state_dict


+def load_checkpoint_lenient(self, checkpoint: Dict[str, Any]) -> None:
"""Loading only the classifier if frozen
Notes:
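A hedged, self-contained sketch of what the rewritten save_unfrozen keeps, run on a toy module. The Toy class and the checkpoint dict are made up for illustration; only save_unfrozen comes from frdc.models.utils, and the "frozen" hyperparameter is included defensively to mirror the previous API in case the function still reads it:

```python
import torch.nn as nn

from frdc.models.utils import save_unfrozen


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
        self.head = nn.Linear(4, 2)
        for p in self.body.parameters():  # freeze everything except the head
            p.requires_grad = False


toy = Toy()
ckpt = {"hyper_parameters": {"frozen": True}, "state_dict": toy.state_dict()}
save_unfrozen(toy, ckpt)

# Expected survivors: head.* (still requires grad) plus the BatchNorm buffers
# body.1.running_mean / body.1.running_var / body.1.num_batches_tracked;
# the frozen body weights are dropped.
print(sorted(ckpt["state_dict"]))
```

Keying off requires_grad is what makes the function generic: it no longer needs to know the classifier's attribute name or a fixed list of prefixes.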
10 changes: 9 additions & 1 deletion src/frdc/train/fixmatch_module.py
@@ -1,7 +1,7 @@
from __future__ import annotations

from abc import abstractmethod
-from typing import Any
+from typing import Any, Dict

import torch
import torch.nn.functional as F
@@ -10,6 +10,7 @@
from sklearn.preprocessing import StandardScaler, OrdinalEncoder
from torchmetrics.functional import accuracy

+from frdc.models.utils import save_unfrozen, load_checkpoint_lenient
from frdc.train.utils import (
wandb_hist,
preprocess,
@@ -92,6 +93,7 @@ def training_step(self, batch, batch_idx):
Loss: ℓ_lbl + ℓ_unl
"""

def training_step(self, batch, batch_idx):
(x_lbl, y_lbl), x_unls = batch
opt = self.optimizers()
@@ -246,3 +248,9 @@ def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
y_encoder=self.y_encoder,
x_unl=x_unl,
)

+def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+save_unfrozen(self, checkpoint)
+
+def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+load_checkpoint_lenient(self, checkpoint)
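The body of load_checkpoint_lenient is collapsed in this diff, so the sketch below only illustrates the general idea behind lenient loading rather than the repository's exact implementation: a checkpoint trimmed by save_unfrozen no longer covers every key, so it has to be applied on top of freshly constructed weights with strict=False (or an equivalent mechanism):

```python
import torch.nn as nn

# A stand-in model: a "backbone" (indices 0-1) plus a trainable head (index 2).
model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4), nn.Linear(4, 2))

# Pretend the checkpoint was trimmed down to the head, as save_unfrozen would do.
partial_state = {k: v for k, v in model.state_dict().items() if k.startswith("2.")}

# strict=False tolerates the keys that were dropped at save time.
missing, unexpected = model.load_state_dict(partial_state, strict=False)
print(missing)     # backbone keys, expected to come from the freshly built weights
print(unexpected)  # should be empty
```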
13 changes: 12 additions & 1 deletion src/frdc/train/mixmatch_module.py
@@ -1,7 +1,7 @@
from __future__ import annotations

from abc import abstractmethod
-from typing import Any
+from typing import Any, Dict

import torch
import torch.nn.functional as F
@@ -11,6 +11,7 @@
from torch.nn.functional import one_hot
from torchmetrics.functional import accuracy

+from frdc.models.utils import save_unfrozen, load_checkpoint_lenient
from frdc.train.utils import (
mix_up,
sharpen,
@@ -260,3 +261,13 @@ def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
y_encoder=self.y_encoder,
x_unl=x_unl,
)

+def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+save_unfrozen(
+self,
+checkpoint,
+include_also=lambda k: k.startswith("_ema_model.fc."),
+)
+
+def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+load_checkpoint_lenient(self, checkpoint)
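MixMatch is the one caller that needs include_also: the EMA copy of the model is updated by averaging rather than by backpropagation, so its weights typically do not require grad and would be dropped by the requires-grad rule alone (whether FRDC's EMA wrapper sets requires_grad this way is an assumption here). The predicate re-admits the EMA classifier head by key prefix; the sample keys below are illustrative only:

```python
def include_also(k: str) -> bool:
    # Same predicate as passed to save_unfrozen above.
    return k.startswith("_ema_model.fc.")


keys = [
    "fc.weight",                           # trainable head -> kept via requires_grad
    "eff.features.0.0.weight",             # frozen backbone -> dropped
    "_ema_model.fc.weight",                # EMA head -> kept only via include_also
    "_ema_model.eff.features.0.0.weight",  # EMA backbone -> still dropped
]
print([k for k in keys if include_also(k)])  # ['_ema_model.fc.weight']
```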
