diff --git a/composer/loggers/mosaicml_logger.py b/composer/loggers/mosaicml_logger.py
index a00f6f43c0..7d1f34633e 100644
--- a/composer/loggers/mosaicml_logger.py
+++ b/composer/loggers/mosaicml_logger.py
@@ -14,17 +14,21 @@
 import time
 import warnings
 from concurrent.futures import wait
+from dataclasses import dataclass
 from functools import reduce
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
 import mcli
 import torch
+import torch.utils.data
 
-from composer.core.time import TimeUnit
+from composer.core.event import Event
+from composer.core.time import Time, TimeUnit
 from composer.loggers import Logger
 from composer.loggers.logger_destination import LoggerDestination
 from composer.loggers.wandb_logger import WandBLogger
 from composer.utils import dist
+from composer.utils.file_helpers import parse_uri
 
 if TYPE_CHECKING:
     from composer.core import State
@@ -46,6 +50,18 @@ class MosaicMLLogger(LoggerDestination):
     Logs metrics to the MosaicML platform.
 
     Logging only happens on rank 0 every ``log_interval`` seconds to avoid performance issues.
 
+    Additionally, the following metrics are logged upon ``INIT``:
+    - ``composer/autoresume``: Whether or not the run can be stopped / resumed during training.
+    - ``composer/precision``: The precision to use for training.
+    - ``composer/eval_loaders``: A list containing the labels of each evaluation dataloader.
+    - ``composer/optimizers``: A list of dictionaries containing information about each optimizer.
+    - ``composer/algorithms``: A list containing the names of the algorithms used for training.
+    - ``composer/loggers``: A list containing the loggers used in the ``Trainer``.
+    - ``composer/cloud_provided_load_path``: The cloud provider for the load path.
+    - ``composer/cloud_provided_save_folder``: The cloud provider for the save folder.
+    - ``composer/save_interval``: The save interval for the run.
+    - ``composer/fsdp_config``: The FSDP config used for training.
+
     When running on the MosaicML platform, the logger is automatically enabled by Trainer.
     To disable, the environment variable 'MOSAICML_PLATFORM' can be set to False.
@@ -62,6 +78,7 @@
             (default: ``None``)
         ignore_exceptions: Flag to disable logging exceptions. Defaults to False.
+        analytics_data (MosaicAnalyticsData, optional): A dataclass containing the variables used to log analytics. Defaults to ``None``.
""" def __init__( @@ -69,10 +86,12 @@ def __init__( log_interval: int = 60, ignore_keys: Optional[List[str]] = None, ignore_exceptions: bool = False, + analytics_data: Optional[MosaicAnalyticsData] = None, ) -> None: self.log_interval = log_interval self.ignore_keys = ignore_keys self.ignore_exceptions = ignore_exceptions + self.analytics_data = analytics_data self._enabled = dist.get_global_rank() == 0 if self._enabled: self.time_last_logged = 0 @@ -96,10 +115,57 @@ def log_hyperparameters(self, hyperparameters: Dict[str, Any]): def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None: self._log_metadata(metrics) + def log_analytics(self, state: State, loggers: Tuple[LoggerDestination, ...]) -> None: + if self.analytics_data is None: + return + + metrics: Dict[str, Any] = { + 'composer/autoresume': self.analytics_data.autoresume, + 'composer/precision': state.precision, + } + metrics['composer/eval_loaders'] = [evaluator.label for evaluator in state.evaluators] + metrics['composer/optimizers'] = [{ + optimizer.__class__.__name__: optimizer.defaults, + } for optimizer in state.optimizers] + metrics['composer/algorithms'] = [algorithm.__class__.__name__ for algorithm in state.algorithms] + metrics['composer/loggers'] = [logger.__class__.__name__ for logger in loggers] + + # Take the service provider out of the URI and log it to metadata. If no service provider + # is found (i.e. backend = ''), then we assume 'local' for the cloud provider. + if self.analytics_data.load_path is not None: + backend, _, _ = parse_uri(self.analytics_data.load_path) + metrics['composer/cloud_provided_load_path'] = backend if backend else 'local' + if self.analytics_data.save_folder is not None: + backend, _, _ = parse_uri(self.analytics_data.save_folder) + metrics['composer/cloud_provided_save_folder'] = backend if backend else 'local' + + # Save interval can be passed in w/ multiple types. If the type is a function, then + # we log 'callable' as the save_interval value for analytics. + if isinstance(self.analytics_data.save_interval, Union[str, int]): + save_interval_str = str(self.analytics_data.save_interval) + elif isinstance(self.analytics_data.save_interval, Time): + save_interval_str = f'{self.analytics_data.save_interval._value}{self.analytics_data.save_interval._unit}' + else: + save_interval_str = 'callable' + metrics['composer/save_interval'] = save_interval_str + + if state.fsdp_config: + # Keys need to be sorted so they can be parsed consistently in SQL queries + metrics['composer/fsdp_config'] = json.dumps(state.fsdp_config, sort_keys=True) + + self.log_metrics(metrics) + self._flush_metadata(force_flush=True) + def log_exception(self, exception: Exception): self._log_metadata({'exception': exception_to_json_serializable_dict(exception)}) self._flush_metadata(force_flush=True) + def init(self, state: State, logger: Logger) -> None: + try: + self.log_analytics(state, logger.destinations) + except: + warnings.warn('Failed to log analytics data to MosaicML. 
+
     def after_load(self, state: State, logger: Logger) -> None:
         # Log model data downloaded and initialized for run events
         log.debug(f'Logging model initialized time to metadata')
@@ -229,6 +295,14 @@ def _get_training_progress_metrics(self, state: State) -> Dict[str, Any]:
         return training_progress_metrics
 
 
+@dataclass(frozen=True)
+class MosaicAnalyticsData:
+    autoresume: bool
+    save_interval: Union[str, int, Time, Callable[[State, Event], bool]]
+    load_path: Optional[str]
+    save_folder: Optional[str]
+
+
 def format_data_to_json_serializable(data: Any):
     """Recursively formats data to be JSON serializable.
diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index 064c22a73b..41f34a51d8 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -83,7 +83,11 @@
     RemoteUploaderDownloader,
     WandBLogger,
 )
-from composer.loggers.mosaicml_logger import MOSAICML_ACCESS_TOKEN_ENV_VAR, MOSAICML_PLATFORM_ENV_VAR
+from composer.loggers.mosaicml_logger import (
+    MOSAICML_ACCESS_TOKEN_ENV_VAR,
+    MOSAICML_PLATFORM_ENV_VAR,
+    MosaicAnalyticsData,
+)
 from composer.models import ComposerModel
 from composer.optim import ComposerScheduler, DecoupledSGDW, compile_composer_scheduler
 from composer.profiler import Profiler
@@ -1284,8 +1288,24 @@ def __init__(
             MOSAICML_ACCESS_TOKEN_ENV_VAR,
         ) is not None and not any(isinstance(x, MosaicMLLogger) for x in loggers):
             log.info('Detected run on MosaicML platform. Adding MosaicMLLogger to loggers.')
-            mosaicml_logger = MosaicMLLogger()
+
+            analytics_data = MosaicAnalyticsData(
+                autoresume=autoresume,
+                save_interval=save_interval,
+                load_path=load_path,
+                save_folder=save_folder,
+            )
+            mosaicml_logger = MosaicMLLogger(analytics_data=analytics_data)
             loggers.append(mosaicml_logger)
+        elif any(isinstance(x, MosaicMLLogger) for x in loggers):
+            # If a MosaicMLLogger is already present (i.e. passed into the Trainer), update its analytics data
+            mosaicml_logger = next(logger for logger in loggers if isinstance(logger, MosaicMLLogger))
+            mosaicml_logger.analytics_data = MosaicAnalyticsData(
+                autoresume=autoresume,
+                save_interval=save_interval,
+                load_path=load_path,
+                save_folder=save_folder,
+            )
 
         # Remote Uploader Downloader
         # Keep the ``RemoteUploaderDownloader`` below client-provided loggers so the loggers init callbacks run before
diff --git a/tests/loggers/test_mosaicml_logger.py b/tests/loggers/test_mosaicml_logger.py
index 795c8da56b..5048398378 100644
--- a/tests/loggers/test_mosaicml_logger.py
+++ b/tests/loggers/test_mosaicml_logger.py
@@ -384,3 +384,33 @@ def test_epoch_zero_no_dataloader_progress_metrics():
     assert training_progress['training_progress'] == '[epoch=1/3]'
     assert 'training_sub_progress' in training_progress
     assert training_progress['training_sub_progress'] == '[batch=1]'
+
+
+def test_logged_metrics(monkeypatch):
+    mock_mapi = MockMAPI()
+    monkeypatch.setenv('MOSAICML_PLATFORM', 'True')
+    monkeypatch.setattr(mcli, 'update_run_metadata', mock_mapi.update_run_metadata)
+    run_name = 'test-run-name'
+    monkeypatch.setenv('RUN_NAME', run_name)
+    trainer = Trainer(
+        model=SimpleModel(),
+        train_dataloader=DataLoader(RandomClassificationDataset()),
+        train_subset_num_batches=1,
+        max_duration='4ba',
+        loggers=[MosaicMLLogger()],
+    )
+    trainer.fit()
+
+    # Check that analytics metrics were logged
+    metadata = mock_mapi.run_metadata[run_name]
+    analytics = {k: v for k, v in metadata.items() if k.startswith('mosaicml/composer/')}
+    assert len(analytics) > 0
+
+    key_name = lambda x: f'mosaicml/composer/{x}'
+    assert key_name('autoresume') in analytics and analytics[key_name('autoresume')] is False
+    assert key_name('precision') in analytics and analytics[key_name('precision')] == 'Precision.FP32'
+    assert key_name('eval_loaders') in analytics and analytics[key_name('eval_loaders')] == []
+    assert key_name('algorithms') in analytics and analytics[key_name('algorithms')] == []
+    assert key_name('loggers') in analytics
+    assert analytics[key_name('loggers')] == ['MosaicMLLogger', 'ProgressBarLogger']
+    assert key_name('save_interval') in analytics and analytics[key_name('save_interval')] == '1ep'
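
Note for reviewers: a minimal sketch (not part of the diff) of how the new ``MosaicAnalyticsData`` is constructed and how ``parse_uri`` drives the ``composer/cloud_provided_*`` metrics. The URI and run values below are illustrative assumptions, not taken from the change itself.

    from composer.loggers.mosaicml_logger import MosaicAnalyticsData, MosaicMLLogger
    from composer.utils.file_helpers import parse_uri

    # Illustrative values: an S3 load path and a plain local save folder.
    analytics_data = MosaicAnalyticsData(
        autoresume=False,
        save_interval='1ep',  # str/int/Time are logged as strings; a callable logs 'callable'
        load_path='s3://my-bucket/checkpoints/latest-rank0.pt',  # hypothetical URI
        save_folder='/tmp/checkpoints',  # no URI scheme -> backend '' -> logged as 'local'
    )

    # parse_uri returns (backend, bucket, path); an empty backend means a local path.
    backend, _, _ = parse_uri(analytics_data.load_path)
    assert backend == 's3'
    backend, _, _ = parse_uri(analytics_data.save_folder)
    assert backend == ''

    # On the MosaicML platform the Trainer wires this up automatically; passing it
    # explicitly to the logger is equivalent:
    mosaicml_logger = MosaicMLLogger(analytics_data=analytics_data)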
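
A second small sketch of the ``save_interval`` normalization performed in ``log_analytics``; ``Time`` and ``TimeUnit`` are existing composer types, and the interval values are arbitrary examples.

    from composer.core.time import Time, TimeUnit

    # str/int intervals are logged via str(); a Time stringifies to '<value><unit>';
    # any other type (e.g. a callable) is recorded as 'callable'.
    for interval in ('1ep', 500, Time(2, TimeUnit.BATCH), lambda state, event: True):
        if isinstance(interval, (str, int)):
            save_interval_str = str(interval)
        elif isinstance(interval, Time):
            save_interval_str = f'{interval.value}{interval.unit.value}'
        else:
            save_interval_str = 'callable'
        print(save_interval_str)  # prints: 1ep, 500, 2ba, callable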