From 395702024f0cba384098da3f19ca07598c496579 Mon Sep 17 00:00:00 2001
From: Alex Morehead
Date: Sat, 7 Oct 2023 12:15:35 -0500
Subject: [PATCH 1/9] Add support for using custom Environments and Strategies

---
 configs/environment/default.yaml    |  2 ++
 configs/environment/lightning.yaml  |  1 +
 configs/environment/slurm.yaml      |  3 ++
 configs/eval.yaml                   |  2 ++
 configs/strategy/ddp.yaml           |  4 +++
 configs/strategy/deepspeed.yaml     |  5 +++
 configs/strategy/default.yaml       |  2 ++
 configs/strategy/fsdp.yaml          | 12 ++++++++
 configs/strategy/optimized_ddp.yaml |  4 +++
 configs/train.yaml                  |  2 ++
 src/__init__.py                     | 33 ++++++++++++++++++++
 src/eval.py                         | 45 ++++++++++++++++++++++++++-
 src/train.py                        | 48 +++++++++++++++++++++++++++--
 13 files changed, 160 insertions(+), 3 deletions(-)
 create mode 100644 configs/environment/default.yaml
 create mode 100644 configs/environment/lightning.yaml
 create mode 100644 configs/environment/slurm.yaml
 create mode 100644 configs/strategy/ddp.yaml
 create mode 100644 configs/strategy/deepspeed.yaml
 create mode 100644 configs/strategy/default.yaml
 create mode 100644 configs/strategy/fsdp.yaml
 create mode 100644 configs/strategy/optimized_ddp.yaml

diff --git a/configs/environment/default.yaml b/configs/environment/default.yaml
new file mode 100644
index 000000000..758b22bc2
--- /dev/null
+++ b/configs/environment/default.yaml
@@ -0,0 +1,2 @@
+defaults:
+  - _self_
diff --git a/configs/environment/lightning.yaml b/configs/environment/lightning.yaml
new file mode 100644
index 000000000..a5e70f457
--- /dev/null
+++ b/configs/environment/lightning.yaml
@@ -0,0 +1 @@
+_target_: lightning.fabric.plugins.environments.LightningEnvironment
diff --git a/configs/environment/slurm.yaml b/configs/environment/slurm.yaml
new file mode 100644
index 000000000..55a5522ab
--- /dev/null
+++ b/configs/environment/slurm.yaml
@@ -0,0 +1,3 @@
+_target_: lightning.fabric.plugins.environments.SLURMEnvironment
+auto_requeue: true
+requeue_signal: null
diff --git a/configs/eval.yaml b/configs/eval.yaml
index be312992b..c5d89c48f 100644
--- a/configs/eval.yaml
+++ b/configs/eval.yaml
@@ -5,10 +5,12 @@ defaults:
   - data: mnist # choose datamodule with `test_dataloader()` for evaluation
   - model: mnist
   - logger: null
+  - strategy: default
   - trainer: default
   - paths: default
   - extras: default
   - hydra: default
+  - environment: default
 
 task_name: "eval"
 
diff --git a/configs/strategy/ddp.yaml b/configs/strategy/ddp.yaml
new file mode 100644
index 000000000..14933f3a8
--- /dev/null
+++ b/configs/strategy/ddp.yaml
@@ -0,0 +1,4 @@
+_target_: lightning.pytorch.strategies.DDPStrategy
+static_graph: false
+gradient_as_bucket_view: false
+find_unused_parameters: true
diff --git a/configs/strategy/deepspeed.yaml b/configs/strategy/deepspeed.yaml
new file mode 100644
index 000000000..3c05b4c25
--- /dev/null
+++ b/configs/strategy/deepspeed.yaml
@@ -0,0 +1,5 @@
+_target_: lightning.pytorch.strategies.DeepSpeedStrategy
+stage: 2
+offload_optimizer: false
+allgather_bucket_size: 200_000_000
+reduce_bucket_size: 200_000_000
diff --git a/configs/strategy/default.yaml b/configs/strategy/default.yaml
new file mode 100644
index 000000000..758b22bc2
--- /dev/null
+++ b/configs/strategy/default.yaml
@@ -0,0 +1,2 @@
+defaults:
+  - _self_
diff --git a/configs/strategy/fsdp.yaml b/configs/strategy/fsdp.yaml
new file mode 100644
index 000000000..12e29e86f
--- /dev/null
+++ b/configs/strategy/fsdp.yaml
@@ -0,0 +1,12 @@
+_target_: lightning.pytorch.strategies.FSDPStrategy
+sharding_strategy: ${resolve_variable:torch.distributed.fsdp.ShardingStrategy.FULL_SHARD}
+cpu_offload: null
+activation_checkpointing: null
+mixed_precision:
+  _target_: torch.distributed.fsdp.MixedPrecision
+  param_dtype: null
+  reduce_dtype: null
+  buffer_dtype: null
+  keep_low_precision_grads: false
+  cast_forward_inputs: false
+  cast_root_forward_inputs: true
diff --git a/configs/strategy/optimized_ddp.yaml b/configs/strategy/optimized_ddp.yaml
new file mode 100644
index 000000000..b8b4b3122
--- /dev/null
+++ b/configs/strategy/optimized_ddp.yaml
@@ -0,0 +1,4 @@
+_target_: lightning.pytorch.strategies.DDPStrategy
+static_graph: true
+gradient_as_bucket_view: true
+find_unused_parameters: false
diff --git a/configs/train.yaml b/configs/train.yaml
index ef7bdab6e..5f084625f 100644
--- a/configs/train.yaml
+++ b/configs/train.yaml
@@ -8,10 +8,12 @@ defaults:
   - model: mnist
   - callbacks: default
   - logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
+  - strategy: default
   - trainer: default
   - paths: default
   - extras: default
   - hydra: default
+  - environment: default
 
 # experiment configs allow for version control of specific hyperparameters
 # e.g. best hyperparameters for given model and datamodule
diff --git a/src/__init__.py b/src/__init__.py
index e69de29bb..184fee583 100644
--- a/src/__init__.py
+++ b/src/__init__.py
@@ -0,0 +1,33 @@
+import importlib
+
+from omegaconf import OmegaConf
+from typing import Any
+
+
+def resolve_omegaconf_variable(variable_path: str) -> Any:
+    """Resolve an OmegaConf variable path to its value."""
+    # split the string into parts using the dot separator
+    parts = variable_path.rsplit(".", 1)
+
+    # get the module name from the first part of the path
+    module_name = parts[0]
+
+    # dynamically import the module using the module name
+    try:
+        module = importlib.import_module(module_name)
+        # use the imported module to get the requested attribute value
+        attribute = getattr(module, parts[1])
+    except Exception:
+        module = importlib.import_module(".".join(module_name.split(".")[:-1]))
+        inner_module = ".".join(module_name.split(".")[-1:])
+        # use the imported module to get the requested attribute value
+        attribute = getattr(getattr(module, inner_module), parts[1])
+
+    return attribute
+
+
+def register_custom_omegaconf_resolvers():
+    """Register custom OmegaConf resolvers."""
+    OmegaConf.register_new_resolver(
+        "resolve_variable", lambda variable_path: resolve_omegaconf_variable(variable_path)
+    )
diff --git a/src/eval.py b/src/eval.py
index b70faae8b..64790347a 100644
--- a/src/eval.py
+++ b/src/eval.py
@@ -3,7 +3,9 @@
 import hydra
 import rootutils
 from lightning import LightningDataModule, LightningModule, Trainer
+from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
 from lightning.pytorch.loggers import Logger
+from lightning.pytorch.strategies.strategy import Strategy
 from omegaconf import DictConfig
 
 rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
@@ -24,6 +26,7 @@
 # more info: https://github.com/ashleve/rootutils
 # ------------------------------------------------------------------------------------ #
 
+from src import register_custom_omegaconf_resolvers, resolve_omegaconf_variable
 from src.utils import (
     RankedLogger,
     extras,
@@ -56,8 +59,47 @@ def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     log.info("Instantiating loggers...")
     logger: List[Logger] = instantiate_loggers(cfg.get("logger"))
 
+    plugins = None
+    if "_target_" in cfg.environment:
+        log.info(f"Instantiating environment <{cfg.environment._target_}>")
+        plugins: ClusterEnvironment = hydra.utils.instantiate(cfg.environment)
+
+    strategy = getattr(cfg.trainer, "strategy", None)
+    if "_target_" in cfg.strategy:
+        log.info(f"Instantiating strategy <{cfg.strategy._target_}>")
+        strategy: Strategy = hydra.utils.instantiate(cfg.strategy)
+        if "mixed_precision" in strategy.__dict__:
+            strategy.mixed_precision.param_dtype = (
+                resolve_omegaconf_variable(cfg.strategy.mixed_precision.param_dtype)
+                if cfg.strategy.mixed_precision.param_dtype is not None
+                else None
+            )
+            strategy.mixed_precision.reduce_dtype = (
+                resolve_omegaconf_variable(cfg.strategy.mixed_precision.reduce_dtype)
+                if cfg.strategy.mixed_precision.reduce_dtype is not None
+                else None
+            )
+            strategy.mixed_precision.buffer_dtype = (
+                resolve_omegaconf_variable(cfg.strategy.mixed_precision.buffer_dtype)
+                if cfg.strategy.mixed_precision.buffer_dtype is not None
+                else None
+            )
+
     log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
-    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger)
+    trainer: Trainer = (
+        hydra.utils.instantiate(
+            cfg.trainer,
+            logger=logger,
+            plugins=plugins,
+            strategy=strategy,
+        )
+        if strategy is not None
+        else hydra.utils.instantiate(
+            cfg.trainer,
+            logger=logger,
+            plugins=plugins,
+        )
+    )
 
     object_dict = {
         "cfg": cfg,
@@ -96,4 +138,5 @@ def main(cfg: DictConfig) -> None:
 
 
 if __name__ == "__main__":
+    register_custom_omegaconf_resolvers()
     main()
diff --git a/src/train.py b/src/train.py
index 4adbcf442..f8b6bed8a 100644
--- a/src/train.py
+++ b/src/train.py
@@ -3,9 +3,10 @@
 import hydra
 import lightning as L
 import rootutils
-import torch
 from lightning import Callback, LightningDataModule, LightningModule, Trainer
+from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
 from lightning.pytorch.loggers import Logger
+from lightning.pytorch.strategies.strategy import Strategy
 from omegaconf import DictConfig
 
 rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
@@ -26,6 +27,7 @@
 # more info: https://github.com/ashleve/rootutils
 # ------------------------------------------------------------------------------------ #
 
+from src import register_custom_omegaconf_resolvers, resolve_omegaconf_variable
 from src.utils import (
     RankedLogger,
     extras,
@@ -66,8 +68,49 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     log.info("Instantiating loggers...")
     logger: List[Logger] = instantiate_loggers(cfg.get("logger"))
 
+    plugins = None
+    if "_target_" in cfg.environment:
+        log.info(f"Instantiating environment <{cfg.environment._target_}>")
+        plugins: ClusterEnvironment = hydra.utils.instantiate(cfg.environment)
+
+    strategy = getattr(cfg.trainer, "strategy", None)
+    if "_target_" in cfg.strategy:
+        log.info(f"Instantiating strategy <{cfg.strategy._target_}>")
+        strategy: Strategy = hydra.utils.instantiate(cfg.strategy)
+        if "mixed_precision" in strategy.__dict__:
+            strategy.mixed_precision.param_dtype = (
+                resolve_omegaconf_variable(cfg.strategy.mixed_precision.param_dtype)
+                if cfg.strategy.mixed_precision.param_dtype is not None
+                else None
+            )
+            strategy.mixed_precision.reduce_dtype = (
+                resolve_omegaconf_variable(cfg.strategy.mixed_precision.reduce_dtype)
+                if cfg.strategy.mixed_precision.reduce_dtype is not None
+                else None
+            )
+            strategy.mixed_precision.buffer_dtype = (
+                resolve_omegaconf_variable(cfg.strategy.mixed_precision.buffer_dtype)
+                if cfg.strategy.mixed_precision.buffer_dtype is not None
+                else None
+            )
+
     log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
-    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)
+    trainer: Trainer = (
+        hydra.utils.instantiate(
+            cfg.trainer,
+            callbacks=callbacks,
+            logger=logger,
+            plugins=plugins,
+            strategy=strategy,
+        )
+        if strategy is not None
+        else hydra.utils.instantiate(
+            cfg.trainer,
+            callbacks=callbacks,
+            logger=logger,
+            plugins=plugins,
+        )
+    )
 
     object_dict = {
         "cfg": cfg,
@@ -129,4 +172,5 @@ def main(cfg: DictConfig) -> Optional[float]:
 
 
 if __name__ == "__main__":
+    register_custom_omegaconf_resolvers()
     main()

From 47707c69bc279a57a507a65a3b145b4a338a8af6 Mon Sep 17 00:00:00 2001
From: Alex Morehead
Date: Sat, 7 Oct 2023 12:28:39 -0500
Subject: [PATCH 2/9] Address pre-commit concerns

---
 src/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/__init__.py b/src/__init__.py
index 184fee583..c9b048bb5 100644
--- a/src/__init__.py
+++ b/src/__init__.py
@@ -1,7 +1,7 @@
 import importlib
+from typing import Any
 
 from omegaconf import OmegaConf
-from typing import Any
 
 
 def resolve_omegaconf_variable(variable_path: str) -> Any:

From fc265af1924aab6784613f485ab8832cb1d610d6 Mon Sep 17 00:00:00 2001
From: Alex Morehead
Date: Sun, 8 Oct 2023 16:13:12 -0500
Subject: [PATCH 3/9] Handle edge case for DeepSpeed optimization

---
 src/models/mnist_module.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/models/mnist_module.py b/src/models/mnist_module.py
index 5d303ac2f..6f613865c 100644
--- a/src/models/mnist_module.py
+++ b/src/models/mnist_module.py
@@ -198,7 +198,11 @@ def configure_optimizers(self) -> Dict[str, Any]:
         :return: A dict containing the configured optimizers and learning-rate schedulers to be
             used for training.
         """
-        optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
+        try:
+            optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
+        except TypeError:
+            # NOTE: strategies such as DeepSpeed require `params` to instead be specified as `model_params`
+            optimizer = self.hparams.optimizer(model_params=self.trainer.model.parameters())
         if self.hparams.scheduler is not None:
             scheduler = self.hparams.scheduler(optimizer=optimizer)
             return {

From 6709b363ef73c4b0c67bc3e8f416871006ff6ed3 Mon Sep 17 00:00:00 2001
From: Alex Morehead
Date: Sun, 8 Oct 2023 16:34:07 -0500
Subject: [PATCH 4/9] Add warning message for passing ckpt_path that points to a non-existent file

---
 src/train.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/src/train.py b/src/train.py
index f8b6bed8a..c61a976fa 100644
--- a/src/train.py
+++ b/src/train.py
@@ -3,6 +3,7 @@
 import hydra
 import lightning as L
 import rootutils
+import os
 from lightning import Callback, LightningDataModule, LightningModule, Trainer
 from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
 from lightning.pytorch.loggers import Logger
 from lightning.pytorch.strategies.strategy import Strategy
 from omegaconf import DictConfig
 
 rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
@@ -127,7 +128,14 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
 
     if cfg.get("train"):
         log.info("Starting training!")
-        trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
+        ckpt_path = None
+        if cfg.get("ckpt_path") and os.path.exists(cfg.get("ckpt_path")):
+            ckpt_path = cfg.get("ckpt_path")
+        elif cfg.get("ckpt_path"):
+            log.warning(
+                "`ckpt_path` was given, but the path does not exist. Training with new model weights."
+            )
+        trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
 
     train_metrics = trainer.callback_metrics
 

From 70a8bb6cc0a4a70403098776aad9470b3aa83503 Mon Sep 17 00:00:00 2001
From: Alex Morehead
Date: Tue, 17 Oct 2023 16:48:58 -0500
Subject: [PATCH 5/9] Add missing if-check for edge case

---
 src/train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/train.py b/src/train.py
index c61a976fa..d21afdcd5 100644
--- a/src/train.py
+++ b/src/train.py
@@ -78,7 +78,7 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     if "_target_" in cfg.strategy:
         log.info(f"Instantiating strategy <{cfg.strategy._target_}>")
         strategy: Strategy = hydra.utils.instantiate(cfg.strategy)
-        if "mixed_precision" in strategy.__dict__:
+        if "mixed_precision" in strategy.__dict__ and strategy.mixed_precision is not None:
             strategy.mixed_precision.param_dtype = (
                 resolve_omegaconf_variable(cfg.strategy.mixed_precision.param_dtype)
                 if cfg.strategy.mixed_precision.param_dtype is not None

From 0a6d1bb43e8f5449b2dc0ffed972c060b3480fb8 Mon Sep 17 00:00:00 2001
From: Alex Morehead
Date: Tue, 17 Oct 2023 16:49:36 -0500
Subject: [PATCH 6/9] Add if-check to account for edge case

---
 src/eval.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/eval.py b/src/eval.py
index 64790347a..9e15a3e16 100644
--- a/src/eval.py
+++ b/src/eval.py
@@ -68,7 +68,7 @@ def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     if "_target_" in cfg.strategy:
         log.info(f"Instantiating strategy <{cfg.strategy._target_}>")
         strategy: Strategy = hydra.utils.instantiate(cfg.strategy)
-        if "mixed_precision" in strategy.__dict__:
+        if "mixed_precision" in strategy.__dict__ and strategy.mixed_precision is not None:
             strategy.mixed_precision.param_dtype = (
                 resolve_omegaconf_variable(cfg.strategy.mixed_precision.param_dtype)
                 if cfg.strategy.mixed_precision.param_dtype is not None

From bffa428a8a260d0b25b5a80f2cc9f09de430af90 Mon Sep 17 00:00:00 2001
From: Alex Morehead
Date: Tue, 17 Oct 2023 16:56:17 -0500
Subject: [PATCH 7/9] Add extra guards

---
 src/train.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/train.py b/src/train.py
index d21afdcd5..318b3d022 100644
--- a/src/train.py
+++ b/src/train.py
@@ -78,20 +78,20 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     if "_target_" in cfg.strategy:
         log.info(f"Instantiating strategy <{cfg.strategy._target_}>")
         strategy: Strategy = hydra.utils.instantiate(cfg.strategy)
-        if "mixed_precision" in strategy.__dict__ and strategy.mixed_precision is not None:
+        if "mixed_precision" in strategy.__dict__ and getattr(strategy, "mixed_precision", None) is not None:
             strategy.mixed_precision.param_dtype = (
                 resolve_omegaconf_variable(cfg.strategy.mixed_precision.param_dtype)
-                if cfg.strategy.mixed_precision.param_dtype is not None
+                if getattr(cfg.strategy.mixed_precision, "param_dtype", None) is not None
                 else None
             )
             strategy.mixed_precision.reduce_dtype = (
                 resolve_omegaconf_variable(cfg.strategy.mixed_precision.reduce_dtype)
-                if cfg.strategy.mixed_precision.reduce_dtype is not None
+                if getattr(cfg.strategy.mixed_precision, "reduce_dtype", None) is not None
                 else None
             )
             strategy.mixed_precision.buffer_dtype = (
                 resolve_omegaconf_variable(cfg.strategy.mixed_precision.buffer_dtype)
-                if cfg.strategy.mixed_precision.buffer_dtype is not None
+                if getattr(cfg.strategy.mixed_precision, "buffer_dtype", None) is not None
                 else None
             )
 

From 473cfbf94a7937d626f4541a42d6d9cfc1c90330 Mon Sep 17 00:00:00 2001
From: Alex Morehead
Date: Tue, 17 Oct 2023 16:56:39 -0500
Subject: [PATCH 8/9] Add extra guards

---
 src/eval.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/eval.py b/src/eval.py
index 9e15a3e16..81686ac09 100644
--- a/src/eval.py
+++ b/src/eval.py
@@ -68,20 +68,20 @@ def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     if "_target_" in cfg.strategy:
         log.info(f"Instantiating strategy <{cfg.strategy._target_}>")
         strategy: Strategy = hydra.utils.instantiate(cfg.strategy)
-        if "mixed_precision" in strategy.__dict__ and strategy.mixed_precision is not None:
+        if "mixed_precision" in strategy.__dict__ and getattr(strategy, "mixed_precision", None) is not None:
             strategy.mixed_precision.param_dtype = (
                 resolve_omegaconf_variable(cfg.strategy.mixed_precision.param_dtype)
-                if cfg.strategy.mixed_precision.param_dtype is not None
+                if getattr(cfg.strategy.mixed_precision, "param_dtype", None) is not None
                 else None
             )
             strategy.mixed_precision.reduce_dtype = (
                 resolve_omegaconf_variable(cfg.strategy.mixed_precision.reduce_dtype)
-                if cfg.strategy.mixed_precision.reduce_dtype is not None
+                if getattr(cfg.strategy.mixed_precision, "reduce_dtype", None) is not None
                 else None
             )
             strategy.mixed_precision.buffer_dtype = (
                 resolve_omegaconf_variable(cfg.strategy.mixed_precision.buffer_dtype)
-                if cfg.strategy.mixed_precision.buffer_dtype is not None
+                if getattr(cfg.strategy.mixed_precision, "buffer_dtype", None) is not None
                 else None
             )
 

From 27fd56224976852f21c164b500ec9d4de6024508 Mon Sep 17 00:00:00 2001
From: Alex Morehead
Date: Fri, 1 Dec 2023 14:49:36 -0600
Subject: [PATCH 9/9] Update docstring

---
 src/models/mnist_module.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/models/mnist_module.py b/src/models/mnist_module.py
index 6f613865c..294d5341c 100644
--- a/src/models/mnist_module.py
+++ b/src/models/mnist_module.py
@@ -51,6 +51,7 @@ def __init__(
         :param net: The model to train.
         :param optimizer: The optimizer to use for training.
         :param scheduler: The learning rate scheduler to use for training.
+        :param compile: Whether to compile the model before training.
         """
         super().__init__()
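
Usage sketch (not part of the patch series itself): once the series is applied, the new `strategy` and `environment` config groups are selected like any other Hydra group in this template. Assuming the stock `src/train.py` entry point and default trainer config, two illustrative invocations:

    # DDP inside a SLURM allocation; SLURMEnvironment supplies rank and requeue handling
    python src/train.py strategy=ddp environment=slurm

    # FSDP with bf16 parameters; the dotted dtype path is resolved at runtime by the
    # custom `resolve_variable` OmegaConf resolver registered in src/__init__.py
    python src/train.py strategy=fsdp strategy.mixed_precision.param_dtype=torch.bfloat16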