From 0c11e6ae9cd947d0a0de77becf8544e55e55b8fc Mon Sep 17 00:00:00 2001
From: Evening
Date: Wed, 5 Jun 2024 17:22:55 +0800
Subject: [PATCH 1/2] Make partial saving function more generic

---
 src/frdc/models/efficientnetb1.py | 30 +-------------
 src/frdc/models/utils.py          | 68 ++++++++++++++++++++-----------
 src/frdc/train/fixmatch_module.py | 10 ++++-
 src/frdc/train/mixmatch_module.py | 13 +++++-
 4 files changed, 67 insertions(+), 54 deletions(-)

diff --git a/src/frdc/models/efficientnetb1.py b/src/frdc/models/efficientnetb1.py
index 28276aa..5d5b4ef 100644
--- a/src/frdc/models/efficientnetb1.py
+++ b/src/frdc/models/efficientnetb1.py
@@ -10,7 +10,7 @@
     EfficientNet_B1_Weights,
 )
 
-from frdc.models.utils import on_save_checkpoint, on_load_checkpoint
+from frdc.models.utils import save_unfrozen, load_checkpoint_lenient
 from frdc.train.fixmatch_module import FixMatchModule
 from frdc.train.mixmatch_module import MixMatchModule
 from frdc.utils.ema import EMA
@@ -139,7 +139,6 @@ def update_ema(self):
         self.ema_updater.update(self.ema_lr)
 
     def forward(self, x: torch.Tensor):
-        """Forward pass."""
         return self.fc(self.eff(x))
 
     def configure_optimizers(self):
@@ -147,21 +146,6 @@ def configure_optimizers(self):
             self.parameters(), lr=self.lr, weight_decay=self.weight_decay
         )
 
-    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        # TODO: MixMatch's saving is a bit complicated due to the dependency
-        #   on the EMA model. This only saves the FC for both the
-        #   main model and the EMA model.
-        #   This may be the reason certain things break when loading
-        if checkpoint["hyper_parameters"]["frozen"]:
-            on_save_checkpoint(
-                self,
-                checkpoint,
-                saved_module_prefixes=("_ema_model.fc.", "fc."),
-            )
-
-    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        on_load_checkpoint(self, checkpoint)
-
 
 class EfficientNetB1FixMatchModule(FixMatchModule):
     MIN_SIZE = 255
@@ -209,21 +193,9 @@ def __init__(
         )
 
     def forward(self, x: torch.Tensor):
-        """Forward pass."""
         return self.fc(self.eff(x))
 
     def configure_optimizers(self):
         return torch.optim.Adam(
             self.parameters(), lr=self.lr, weight_decay=self.weight_decay
         )
-
-    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        if checkpoint["hyper_parameters"]["frozen"]:
-            on_save_checkpoint(
-                self,
-                checkpoint,
-                saved_module_prefixes=("fc.",),
-            )
-
-    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        on_load_checkpoint(self, checkpoint)
diff --git a/src/frdc/models/utils.py b/src/frdc/models/utils.py
index 2b75e1c..a3b94b3 100644
--- a/src/frdc/models/utils.py
+++ b/src/frdc/models/utils.py
@@ -1,15 +1,12 @@
-from typing import Dict, Any, Sequence
+import logging
+import warnings
+from typing import Dict, Any, Sequence, Callable
 
 
-def on_save_checkpoint(
+def save_unfrozen(
     self,
     checkpoint: Dict[str, Any],
-    saved_module_prefixes: Sequence[str] = ("fc.",),
-    saved_module_suffixes: Sequence[str] = (
-        "running_mean",
-        "running_var",
-        "num_batches_tracked",
-    ),
+    include_also: Callable[[str], bool] = lambda k: False,
 ) -> None:
-    """Saving only the classifier if frozen.
+    """Save only the unfrozen parameters of the model.
 
     Notes:
@@ -20,29 +17,54 @@
 
     This usually reduces the model size by 99.9%, so it's worth it.
 
-    By default, this will save the classifier and the BatchNorm running
-    statistics.
+    By default, this will save any parameter that requires grad
+    and the BatchNorm running statistics.
 
     Args:
-        self: Not used, but kept for consistency with on_load_checkpoint.
+        self: The module being saved, used to look up requires_grad.
        checkpoint: The checkpoint to save.
-        saved_module_prefixes: The prefixes of the modules to save.
-        saved_module_suffixes: The suffixes of the modules to save.
+        include_also: A function that returns whether to include a parameter,
+            on top of any parameter that requires grad and BatchNorm running
+            statistics.
     """
-    if checkpoint["hyper_parameters"]["frozen"]:
-        # Keep only the classifier
-        checkpoint["state_dict"] = {
-            k: v
-            for k, v in checkpoint["state_dict"].items()
-            if (
-                k.startswith(saved_module_prefixes)
-                or k.endswith(saved_module_suffixes)
-            )
-        }
+    # Keep only the parameters that are still trainable
+    new_state_dict = {}
+    for k, v in checkpoint["state_dict"].items():
+        # We keep two things:
+        # 1. The BatchNorm running statistics
+        # 2. Anything that requires grad
 
-def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-    """Loading only the classifier if frozen
+        # BatchNorm running statistics should be kept
+        # for closer reconstruction of the model
+        is_bn_var = k.endswith(
+            ("running_mean", "running_var", "num_batches_tracked")
+        )
+        try:
+            # We must look this up on the live module, as tensors in the
+            # state dict don't carry requires_grad
+            is_required_grad = self.get_parameter(k).requires_grad
+        except AttributeError:
+            if not is_bn_var:
+                warnings.warn(
+                    f"Unknown non-parameter key in state_dict: {k}. "
+                    f"This is an edge case where it's neither a parameter "
+                    f"nor BatchNorm running statistics. It will still be "
+                    f"saved."
+                )
+            is_required_grad = True
+
+        # These are additional parameters to keep
+        is_include = include_also(k)
+
+        if is_required_grad or is_bn_var or is_include:
+            logging.debug(f"Keeping {k}")
+            new_state_dict[k] = v
+
+    checkpoint["state_dict"] = new_state_dict
+
+
+def load_checkpoint_lenient(self, checkpoint: Dict[str, Any]) -> None:
+    """Load a checkpoint saved with save_unfrozen.
 
     Notes:
diff --git a/src/frdc/train/fixmatch_module.py b/src/frdc/train/fixmatch_module.py
index 67e788d..0d8cd80 100644
--- a/src/frdc/train/fixmatch_module.py
+++ b/src/frdc/train/fixmatch_module.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from abc import abstractmethod
-from typing import Any
+from typing import Any, Dict
 
 import torch
 import torch.nn.functional as F
@@ -10,6 +10,7 @@
 from sklearn.preprocessing import StandardScaler, OrdinalEncoder
 from torchmetrics.functional import accuracy
 
+from frdc.models.utils import save_unfrozen, load_checkpoint_lenient
 from frdc.train.utils import (
     wandb_hist,
     preprocess,
@@ -92,6 +93,7 @@ def training_step(self, batch, batch_idx):
         ℓ Loss: ℓ_lbl + ℓ_unl
     """
 
+
     def training_step(self, batch, batch_idx):
         (x_lbl, y_lbl), x_unls = batch
         opt = self.optimizers()
@@ -246,3 +248,9 @@ def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
             y_encoder=self.y_encoder,
             x_unl=x_unl,
         )
+
+    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+        save_unfrozen(self, checkpoint)
+
+    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+        load_checkpoint_lenient(self, checkpoint)
diff --git a/src/frdc/train/mixmatch_module.py b/src/frdc/train/mixmatch_module.py
index 300d25e..fb24b9b 100644
--- a/src/frdc/train/mixmatch_module.py
+++ b/src/frdc/train/mixmatch_module.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from abc import abstractmethod
-from typing import Any
+from typing import Any, Dict
 
 import torch
 import torch.nn.functional as F
@@ -11,6 +11,7 @@
 from torch.nn.functional import one_hot
 from torchmetrics.functional import accuracy
 
+from frdc.models.utils import save_unfrozen, load_checkpoint_lenient
 from frdc.train.utils import (
     mix_up,
     sharpen,
@@ -260,3 +261,13 @@ def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
y_encoder=self.y_encoder, x_unl=x_unl, ) + + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + save_unfrozen( + self, + checkpoint, + include_also=lambda k: k.startswith("_ema_model."), + ) + + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + load_checkpoint_lenient(self, checkpoint) From ce8f7bb1d687cd6762e071cea21c5e281ac1ce46 Mon Sep 17 00:00:00 2001 From: Evening Date: Wed, 5 Jun 2024 17:26:14 +0800 Subject: [PATCH 2/2] Reduce scope of MixMatch custom criteria saving --- src/frdc/train/mixmatch_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/frdc/train/mixmatch_module.py b/src/frdc/train/mixmatch_module.py index fb24b9b..9617751 100644 --- a/src/frdc/train/mixmatch_module.py +++ b/src/frdc/train/mixmatch_module.py @@ -266,7 +266,7 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: save_unfrozen( self, checkpoint, - include_also=lambda k: k.startswith("_ema_model."), + include_also=lambda k: k.startswith("_ema_model.fc."), ) def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
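
Taken together, the two patches swap the old prefix/suffix allow-list for a
requires_grad-based filter: anything still trainable is kept, BatchNorm
running statistics are always kept, and include_also opts extra keys in
(MixMatch's EMA classifier). The toy below is a minimal sketch of that
filter, not code from this repo; the keep() helper is a hypothetical
stand-in for the loop inside save_unfrozen:

    import torch.nn as nn

    model = nn.Sequential(
        nn.Linear(8, 8),    # stand-in for the frozen pretrained backbone
        nn.BatchNorm1d(8),  # its running stats are buffers, not parameters
        nn.Linear(8, 2),    # stand-in for the trainable classifier head
    )
    for p in model[0].parameters():
        p.requires_grad = False  # "freeze" the backbone

    def keep(module: nn.Module, key: str) -> bool:
        # BatchNorm buffers are kept for closer reconstruction of the model
        if key.endswith(
            ("running_mean", "running_var", "num_batches_tracked")
        ):
            return True
        try:
            # requires_grad must be read off the live module; tensors in
            # the state dict don't carry it
            return module.get_parameter(key).requires_grad
        except AttributeError:
            # unknown non-parameter keys are kept, as in the patch
            return True

    state = {k: v for k, v in model.state_dict().items() if keep(model, k)}
    print(sorted(state))
    # ['1.bias', '1.num_batches_tracked', '1.running_mean', '1.running_var',
    #  '1.weight', '2.bias', '2.weight']: the frozen '0.*' weights are gone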
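
The loading half, load_checkpoint_lenient, has no body in this diff, so the
restore below is an assumption about how such a trimmed checkpoint is
typically rebuilt, not the repo's actual implementation: re-create the
model, which brings back the frozen pretrained weights, then overlay the
saved subset with strict=False.

    # Continuing the sketch above: round-trip the trimmed state dict.
    fresh = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8), nn.Linear(8, 2))
    missing, unexpected = fresh.load_state_dict(state, strict=False)
    print(missing)     # ['0.weight', '0.bias']: the deliberately dropped
                       # frozen weights, supplied by the re-created model
    print(unexpected)  # []: every saved key maps onto the architecture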