diff --git a/src/frdc/train/mixmatch_module.py b/src/frdc/train/mixmatch_module.py index fb24b9b..9617751 100644 --- a/src/frdc/train/mixmatch_module.py +++ b/src/frdc/train/mixmatch_module.py @@ -266,7 +266,7 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: save_unfrozen( self, checkpoint, - include_also=lambda k: k.startswith("_ema_model."), + include_also=lambda k: k.startswith("_ema_model.fc."), ) def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: