Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added metrics logging to checkpoint and separate yaml file #1562

Merged
merged 24 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
4b11551
Added metrics logging to checkpoint and separate yaml file
philmarchenko Oct 23, 2023
58c96ca
Merge branch 'master' into feature/saving_metrics_to_yaml
philmarchenko Oct 23, 2023
6b3999c
Fixed variable name in docstr to add_yaml_summary
philmarchenko Oct 23, 2023
fb8ae7d
Merged changes in feature/saving_metrics_to_yaml and making commit si…
philmarchenko Oct 23, 2023
9446489
Added method to abstract sglogger class and fixed some places comment…
philmarchenko Oct 24, 2023
80dc4b9
Added method to abstract sglogger class and fixed some places comment…
philmarchenko Oct 24, 2023
0c2e066
Fixed metric saving and added/fixed some tests
philmarchenko Oct 24, 2023
53b5294
Merge branch 'master' into feature/saving_metrics_to_yaml
philmarchenko Oct 24, 2023
cff188e
Merge branch 'master' into feature/saving_metrics_to_yaml
philmarchenko Oct 24, 2023
ca054c7
Merge branch 'master' into feature/saving_metrics_to_yaml
philmarchenko Oct 25, 2023
aea5860
Merge branch 'master' into feature/saving_metrics_to_yaml
philmarchenko Oct 25, 2023
c60612d
Changed casting function from __maybe_get_item_from_tensor to simple …
philmarchenko Oct 25, 2023
133b9f7
Merge branch 'feature/saving_metrics_to_yaml' of github.com:hakuryuu9…
philmarchenko Oct 25, 2023
2bf40e2
Merge branch 'master' into feature/saving_metrics_to_yaml
philmarchenko Oct 26, 2023
534fbfa
Left only metrics saved to checkpoint
philmarchenko Oct 26, 2023
34ab3fd
Removed test for yaml files
philmarchenko Oct 26, 2023
256dbae
Fixed some place in code according to comments in PR
philmarchenko Oct 26, 2023
cb84783
Changed float metric value back to int in test of schedulers :(
philmarchenko Oct 26, 2023
a96c2ab
Merge branch 'master' into feature/saving_metrics_to_yaml
philmarchenko Oct 27, 2023
e7b8d83
Merge branch 'master' into feature/saving_metrics_to_yaml
BloodAxe Oct 29, 2023
2ce24a7
Merge branch 'master' into feature/saving_metrics_to_yaml
philmarchenko Oct 30, 2023
412b62d
Fixed linters (trailing spaces)
philmarchenko Oct 30, 2023
6168a75
Merge branch 'master' into feature/saving_metrics_to_yaml
philmarchenko Oct 30, 2023
0a02eac
Merge branch 'master' into feature/saving_metrics_to_yaml
Louis-Dupont Oct 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion src/super_gradients/common/sg_loggers/base_sg_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
import os
import signal
import time
from typing import Union, Any
from typing import Union, Any, Optional

import matplotlib.pyplot as plt
import numpy as np
import psutil
import torch
from PIL import Image
import shutil
import yaml

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.auto_logging.auto_logger import AutoLoggerConfig
Expand Down Expand Up @@ -329,6 +330,22 @@ def _save_checkpoint(self, path: str, state_dict: dict) -> None:
if self.save_checkpoints_remote:
self.model_checkpoints_data_interface.save_remote_checkpoints_file(self.experiment_name, self._local_dir, name)

@multi_process_safe
def add_yaml_summary(self, tag: str, summary_dict: dict, global_step: Optional[int] = None) -> None:
    """Dump a dictionary to a YAML file inside the logger's local directory.

    The file is written to ``<self._local_dir>/<tag>.yml``, or to
    ``<self._local_dir>/<tag>_<global_step>.yml`` when ``global_step`` is given.
    Initially added for saving metrics in something more easily parsable than .pth checkpoints.

    :param tag:             Identifier of the summary; used as the stem of the file name.
    :param summary_dict:    Dictionary to serialize with ``yaml.dump`` (block style, ``default_flow_style=False``).
    :param global_step:     Optional step number appended to the file name when not None
                            (the trainer call shown alongside this method passes the epoch number).
    """
    suffix = f"_{global_step}" if global_step is not None else ""
    name = f"{tag}{suffix}.yml"

    # The ``with`` statement closes the file on exit (including on error),
    # so no explicit ``close()`` call is needed.
    with open(os.path.join(self._local_dir, name), "w") as outfile:
        yaml.dump(summary_dict, outfile, default_flow_style=False)
BloodAxe marked this conversation as resolved.
Show resolved Hide resolved

def add(self, tag: str, obj: Any, global_step: int = None):
    """No-op in this implementation: the arguments are accepted and ignored.

    :param tag:             Identifier for the object (unused here).
    :param obj:             Object to be added (unused here).
    :param global_step:     Optional step number (unused here).
    """
    return None

Expand Down
37 changes: 29 additions & 8 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,7 @@ def _save_checkpoint(
self,
optimizer: torch.optim.Optimizer = None,
epoch: int = None,
train_metrics_dict: Optional[Dict[str, float]] = None,
validation_results_dict: Optional[Dict[str, float]] = None,
context: PhaseContext = None,
) -> None:
Expand All @@ -654,10 +655,18 @@ def _save_checkpoint(

# COMPUTE THE CURRENT metric
# IF idx IS A LIST - SUM ALL THE VALUES STORED IN THE LIST'S INDICES
metric = validation_results_dict[self.metric_to_watch]
curr_tracked_metric = validation_results_dict[self.metric_to_watch]
philmarchenko marked this conversation as resolved.
Show resolved Hide resolved

# create metrics dict to save
all_metrics = {
"tracked_metric_name": self.metric_to_watch,
"metrics": {"valid": {metric_name: validation_results_dict[metric_name] for metric_name in self.valid_metrics}},
}
if train_metrics_dict is not None:
all_metrics["metrics"]["train"] = {metric_name: train_metrics_dict[metric_name].item() for metric_name in self.train_metrics}

# BUILD THE state_dict
state = {"net": unwrap_model(self.net).state_dict(), "acc": metric, "epoch": epoch}
state = {"net": unwrap_model(self.net).state_dict(), "acc": curr_tracked_metric, "epoch": epoch, "all_metrics": all_metrics}
philmarchenko marked this conversation as resolved.
Show resolved Hide resolved

if optimizer is not None:
state["optimizer_state_dict"] = optimizer.state_dict()
Expand All @@ -677,23 +686,28 @@ def _save_checkpoint(

# SAVES CURRENT MODEL AS ckpt_latest
self.sg_logger.add_checkpoint(tag="ckpt_latest.pth", state_dict=state, global_step=epoch)
self.sg_logger.add_yaml_summary(tag="metrics_latest", summary_dict=all_metrics)

# SAVE MODEL AT SPECIFIC EPOCHS DETERMINED BY save_ckpt_epoch_list
if epoch in self.training_params.save_ckpt_epoch_list:
self.sg_logger.add_checkpoint(tag=f"ckpt_epoch_{epoch}.pth", state_dict=state, global_step=epoch)
self.sg_logger.add_yaml_summary(tag="metrics_epoch", summary_dict=all_metrics, global_step=epoch)

# OVERRIDE THE BEST CHECKPOINT AND best_metric IF metric GOT BETTER THAN THE PREVIOUS BEST
if (metric > self.best_metric and self.greater_metric_to_watch_is_better) or (metric < self.best_metric and not self.greater_metric_to_watch_is_better):
if (curr_tracked_metric > self.best_metric and self.greater_metric_to_watch_is_better) or (
curr_tracked_metric < self.best_metric and not self.greater_metric_to_watch_is_better
):
# STORE THE CURRENT metric AS BEST
self.best_metric = metric
self.best_metric = curr_tracked_metric
self.sg_logger.add_checkpoint(tag=self.ckpt_best_name, state_dict=state, global_step=epoch)
self.sg_logger.add_yaml_summary(tag="metrics_best", summary_dict=all_metrics)

# RUN PHASE CALLBACKS
self.phase_callback_handler.on_validation_end_best_epoch(context)

if isinstance(metric, torch.Tensor):
metric = metric.item()
logger.info("Best checkpoint overriden: validation " + self.metric_to_watch + ": " + str(metric))
if isinstance(curr_tracked_metric, torch.Tensor):
curr_tracked_metric = curr_tracked_metric.item()
logger.info("Best checkpoint overriden: validation " + self.metric_to_watch + ": " + str(curr_tracked_metric))

if self.training_params.average_best_models:
net_for_averaging = unwrap_model(self.ema_model.ema if self.ema else self.net)
Expand Down Expand Up @@ -1184,6 +1198,7 @@ def forward(self, inputs, targets):
random_seed(is_ddp=device_config.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL, device=device_config.device, seed=self.training_params.seed)

silent_mode = self.training_params.silent_mode or self.ddp_silent_mode

# METRICS
self._set_train_metrics(train_metrics_list=self.training_params.train_metrics_list)
self._set_valid_metrics(valid_metrics_list=self.training_params.valid_metrics_list)
Expand Down Expand Up @@ -1925,7 +1940,13 @@ def _write_to_disk_operations(

# SAVE THE CHECKPOINT
if self.training_params.save_model:
self._save_checkpoint(self.optimizer, epoch + 1, validation_results_dict, context)
self._save_checkpoint(
optimizer=self.optimizer,
epoch=epoch + 1,
train_metrics_dict=train_metrics_dict,
validation_results_dict=validation_results_dict,
context=context,
)

def _get_epoch_start_logging_values(self) -> dict:
"""Get all the values that should be logged at the start of each epoch.
Expand Down