diff --git a/src/super_gradients/training/sg_trainer/sg_trainer.py b/src/super_gradients/training/sg_trainer/sg_trainer.py
index 918a10f20b..3cb5bf39f1 100755
--- a/src/super_gradients/training/sg_trainer/sg_trainer.py
+++ b/src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -643,6 +643,7 @@ def _save_checkpoint(
         self,
         optimizer: torch.optim.Optimizer = None,
         epoch: int = None,
+        train_metrics_dict: Optional[Dict[str, float]] = None,
         validation_results_dict: Optional[Dict[str, float]] = None,
         context: PhaseContext = None,
     ) -> None:
@@ -657,10 +658,28 @@ def _save_checkpoint(
         # COMPUTE THE CURRENT metric
         # IF idx IS A LIST - SUM ALL THE VALUES STORED IN THE LIST'S INDICES
-        metric = validation_results_dict[self.metric_to_watch]
+        curr_tracked_metric = float(validation_results_dict[self.metric_to_watch])
+
+        # create metrics dict to save
+        valid_metrics_titles = get_metrics_titles(self.valid_metrics)
+
+        all_metrics = {
+            "tracked_metric_name": self.metric_to_watch,
+            "valid": {metric_name: float(validation_results_dict[metric_name]) for metric_name in valid_metrics_titles},
+        }
+
+        if train_metrics_dict is not None:
+            train_metrics_titles = get_metrics_titles(self.train_metrics)
+            all_metrics["train"] = {metric_name: float(train_metrics_dict[metric_name]) for metric_name in train_metrics_titles}
 
         # BUILD THE state_dict
-        state = {"net": unwrap_model(self.net).state_dict(), "acc": metric, "epoch": epoch, "packages": get_installed_packages()}
+        state = {
+            "net": unwrap_model(self.net).state_dict(),
+            "acc": curr_tracked_metric,
+            "epoch": epoch,
+            "metrics": all_metrics,
+            "packages": get_installed_packages(),
+        }
 
         if optimizer is not None:
             state["optimizer_state_dict"] = optimizer.state_dict()
 
@@ -686,17 +705,16 @@ def _save_checkpoint(
             self.sg_logger.add_checkpoint(tag=f"ckpt_epoch_{epoch}.pth", state_dict=state, global_step=epoch)
 
         # OVERRIDE THE BEST CHECKPOINT AND best_metric IF metric GOT BETTER THAN THE PREVIOUS BEST
-        if (metric > self.best_metric and self.greater_metric_to_watch_is_better) or (metric < self.best_metric and not self.greater_metric_to_watch_is_better):
+        if (curr_tracked_metric > self.best_metric and self.greater_metric_to_watch_is_better) or (
+            curr_tracked_metric < self.best_metric and not self.greater_metric_to_watch_is_better
+        ):
             # STORE THE CURRENT metric AS BEST
-            self.best_metric = metric
+            self.best_metric = curr_tracked_metric
             self.sg_logger.add_checkpoint(tag=self.ckpt_best_name, state_dict=state, global_step=epoch)
 
             # RUN PHASE CALLBACKS
             self.phase_callback_handler.on_validation_end_best_epoch(context)
-
-            if isinstance(metric, torch.Tensor):
-                metric = metric.item()
-            logger.info("Best checkpoint overriden: validation " + self.metric_to_watch + ": " + str(metric))
+            logger.info("Best checkpoint overriden: validation " + self.metric_to_watch + ": " + str(curr_tracked_metric))
 
         if self.training_params.average_best_models:
             net_for_averaging = unwrap_model(self.ema_model.ema if self.ema else self.net)
@@ -1187,6 +1205,7 @@ def forward(self, inputs, targets):
         random_seed(is_ddp=device_config.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL, device=device_config.device, seed=self.training_params.seed)
 
         silent_mode = self.training_params.silent_mode or self.ddp_silent_mode
+
         # METRICS
         self._set_train_metrics(train_metrics_list=self.training_params.train_metrics_list)
         self._set_valid_metrics(valid_metrics_list=self.training_params.valid_metrics_list)
@@ -1938,7 +1957,13 @@ def _write_to_disk_operations(
 
         # SAVE THE CHECKPOINT
         if self.training_params.save_model:
-            self._save_checkpoint(self.optimizer, epoch + 1, validation_results_dict, context)
+            self._save_checkpoint(
+                optimizer=self.optimizer,
+                epoch=epoch + 1,
+                train_metrics_dict=train_metrics_dict,
+                validation_results_dict=validation_results_dict,
+                context=context,
+            )
 
     def _get_epoch_start_logging_values(self) -> dict:
         """Get all the values that should be logged at the start of each epoch.
diff --git a/tests/end_to_end_tests/trainer_test.py b/tests/end_to_end_tests/trainer_test.py
index 7a0f5916b7..0dbcbbad17 100644
--- a/tests/end_to_end_tests/trainer_test.py
+++ b/tests/end_to_end_tests/trainer_test.py
@@ -18,7 +18,16 @@ class TestTrainer(unittest.TestCase):
     def setUp(cls):
         super_gradients.init_trainer()
         # NAMES FOR THE EXPERIMENTS TO LATER DELETE
-        cls.experiment_names = ["test_train", "test_save_load", "test_load_w", "test_load_w2", "test_load_w3", "test_checkpoint_content", "analyze"]
+        cls.experiment_names = [
+            "test_train",
+            "test_save_load",
+            "test_load_w",
+            "test_load_w2",
+            "test_load_w3",
+            "test_checkpoint_content",
+            "analyze",
+            "test_yaml_metrics_present",
+        ]
         cls.training_params = {
             "max_epochs": 1,
             "silent_mode": True,
@@ -79,7 +88,7 @@ def test_checkpoint_content(self):
         ckpt_paths = [os.path.join(trainer.checkpoints_dir_path, suf) for suf in ckpt_filename]
         for ckpt_path in ckpt_paths:
             ckpt = torch.load(ckpt_path)
-            self.assertListEqual(sorted(["net", "acc", "epoch", "optimizer_state_dict", "scaler_state_dict", "packages"]), sorted(list(ckpt.keys())))
+            self.assertListEqual(sorted(["net", "acc", "epoch", "optimizer_state_dict", "scaler_state_dict", "metrics", "packages"]), sorted(list(ckpt.keys())))
         trainer._save_checkpoint()
         weights_only = torch.load(os.path.join(trainer.checkpoints_dir_path, "ckpt_latest_weights_only.pth"))
         self.assertListEqual(["net"], list(weights_only.keys()))
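For reference, a minimal sketch of how the new "metrics" entry written by _save_checkpoint could be inspected after training; the checkpoint path below is a hypothetical placeholder, not something defined in this diff.

# Illustrative sketch only (not part of the diff): reading the "metrics" entry
# that _save_checkpoint now stores alongside the weights.
import torch

# Hypothetical path to a checkpoint produced with the change above.
ckpt = torch.load("checkpoints/my_experiment/ckpt_best.pth", map_location="cpu")

metrics = ckpt["metrics"]
print(metrics["tracked_metric_name"])  # name of metric_to_watch, e.g. "Accuracy"
print(metrics["valid"])                # all validation metrics as {name: float}
print(metrics.get("train"))            # train metrics, or None if train_metrics_dict was not passed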