diff --git a/configs/environment/default.yaml b/configs/environment/default.yaml
new file mode 100644
index 000000000..758b22bc2
--- /dev/null
+++ b/configs/environment/default.yaml
@@ -0,0 +1,2 @@
+defaults:
+  - _self_
diff --git a/configs/environment/lightning.yaml b/configs/environment/lightning.yaml
new file mode 100644
index 000000000..a5e70f457
--- /dev/null
+++ b/configs/environment/lightning.yaml
@@ -0,0 +1 @@
+_target_: lightning.fabric.plugins.environments.LightningEnvironment
diff --git a/configs/environment/slurm.yaml b/configs/environment/slurm.yaml
new file mode 100644
index 000000000..55a5522ab
--- /dev/null
+++ b/configs/environment/slurm.yaml
@@ -0,0 +1,3 @@
+_target_: lightning.fabric.plugins.environments.SLURMEnvironment
+auto_requeue: true
+requeue_signal: null
diff --git a/configs/eval.yaml b/configs/eval.yaml
index be312992b..c5d89c48f 100644
--- a/configs/eval.yaml
+++ b/configs/eval.yaml
@@ -5,10 +5,12 @@ defaults:
   - data: mnist # choose datamodule with `test_dataloader()` for evaluation
   - model: mnist
   - logger: null
+  - strategy: default
   - trainer: default
   - paths: default
   - extras: default
   - hydra: default
+  - environment: default
 
 task_name: "eval"
 
diff --git a/configs/strategy/ddp.yaml b/configs/strategy/ddp.yaml
new file mode 100644
index 000000000..14933f3a8
--- /dev/null
+++ b/configs/strategy/ddp.yaml
@@ -0,0 +1,4 @@
+_target_: lightning.pytorch.strategies.DDPStrategy
+static_graph: false
+gradient_as_bucket_view: false
+find_unused_parameters: true
diff --git a/configs/strategy/deepspeed.yaml b/configs/strategy/deepspeed.yaml
new file mode 100644
index 000000000..3c05b4c25
--- /dev/null
+++ b/configs/strategy/deepspeed.yaml
@@ -0,0 +1,5 @@
+_target_: lightning.pytorch.strategies.DeepSpeedStrategy
+stage: 2
+offload_optimizer: false
+allgather_bucket_size: 200_000_000
+reduce_bucket_size: 200_000_000
diff --git a/configs/strategy/default.yaml b/configs/strategy/default.yaml
new file mode 100644
index 000000000..758b22bc2
--- /dev/null
+++ b/configs/strategy/default.yaml
@@ -0,0 +1,2 @@
+defaults:
+  - _self_
diff --git a/configs/strategy/fsdp.yaml b/configs/strategy/fsdp.yaml
new file mode 100644
index 000000000..12e29e86f
--- /dev/null
+++ b/configs/strategy/fsdp.yaml
@@ -0,0 +1,12 @@
+_target_: lightning.pytorch.strategies.FSDPStrategy
+sharding_strategy: ${resolve_variable:torch.distributed.fsdp.ShardingStrategy.FULL_SHARD}
+cpu_offload: null
+activation_checkpointing: null
+mixed_precision:
+  _target_: torch.distributed.fsdp.MixedPrecision
+  param_dtype: null
+  reduce_dtype: null
+  buffer_dtype: null
+  keep_low_precision_grads: false
+  cast_forward_inputs: false
+  cast_root_forward_inputs: true
diff --git a/configs/strategy/optimized_ddp.yaml b/configs/strategy/optimized_ddp.yaml
new file mode 100644
index 000000000..b8b4b3122
--- /dev/null
+++ b/configs/strategy/optimized_ddp.yaml
@@ -0,0 +1,4 @@
+_target_: lightning.pytorch.strategies.DDPStrategy
+static_graph: true
+gradient_as_bucket_view: true
+find_unused_parameters: false
diff --git a/configs/train.yaml b/configs/train.yaml
index ef7bdab6e..5f084625f 100644
--- a/configs/train.yaml
+++ b/configs/train.yaml
@@ -8,10 +8,12 @@ defaults:
   - model: mnist
   - callbacks: default
   - logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
+  - strategy: default
   - trainer: default
   - paths: default
   - extras: default
   - hydra: default
+  - environment: default
 
 # experiment configs allow for version control of specific hyperparameters
 # e.g. best hyperparameters for given model and datamodule
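Aside (commentary, not part of the patch): each of the strategy configs above is a plain Hydra target spec, so instantiating one yields the corresponding Lightning strategy object. A minimal sketch, assuming `hydra-core`, `omegaconf`, and `lightning` are installed and the command is run from the repository root; values are illustrative only.

    # Sketch: what hydra.utils.instantiate does with configs/strategy/ddp.yaml.
    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    # Load the raw config (no defaults list involved for this file).
    ddp_cfg = OmegaConf.load("configs/strategy/ddp.yaml")

    # _target_ points at the class; the remaining keys become constructor kwargs.
    strategy = instantiate(ddp_cfg)  # DDPStrategy(static_graph=False,
                                     #             gradient_as_bucket_view=False,
                                     #             find_unused_parameters=True)
    print(type(strategy).__name__)  # "DDPStrategy"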
diff --git a/src/__init__.py b/src/__init__.py
index e69de29bb..c9b048bb5 100644
--- a/src/__init__.py
+++ b/src/__init__.py
@@ -0,0 +1,33 @@
+import importlib
+from typing import Any
+
+from omegaconf import OmegaConf
+
+
+def resolve_omegaconf_variable(variable_path: str) -> Any:
+    """Resolve an OmegaConf variable path to its value."""
+    # split the string into parts using the dot separator
+    parts = variable_path.rsplit(".", 1)
+
+    # get the module name from the first part of the path
+    module_name = parts[0]
+
+    # dynamically import the module using the module name
+    try:
+        module = importlib.import_module(module_name)
+        # use the imported module to get the requested attribute value
+        attribute = getattr(module, parts[1])
+    except Exception:
+        module = importlib.import_module(".".join(module_name.split(".")[:-1]))
+        inner_module = ".".join(module_name.split(".")[-1:])
+        # use the imported module to get the requested attribute value
+        attribute = getattr(getattr(module, inner_module), parts[1])
+
+    return attribute
+
+
+def register_custom_omegaconf_resolvers():
+    """Register custom OmegaConf resolvers."""
+    OmegaConf.register_new_resolver(
+        "resolve_variable", lambda variable_path: resolve_omegaconf_variable(variable_path)
+    )
diff --git a/src/eval.py b/src/eval.py
index b70faae8b..81686ac09 100644
--- a/src/eval.py
+++ b/src/eval.py
@@ -3,7 +3,9 @@
 import hydra
 import rootutils
 from lightning import LightningDataModule, LightningModule, Trainer
+from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
 from lightning.pytorch.loggers import Logger
+from lightning.pytorch.strategies.strategy import Strategy
 from omegaconf import DictConfig
 
 rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
@@ -24,6 +26,7 @@
 # more info: https://github.com/ashleve/rootutils
 # ------------------------------------------------------------------------------------ #
 
+from src import register_custom_omegaconf_resolvers, resolve_omegaconf_variable
 from src.utils import (
     RankedLogger,
     extras,
@@ -56,8 +59,47 @@ def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     log.info("Instantiating loggers...")
     logger: List[Logger] = instantiate_loggers(cfg.get("logger"))
 
+    plugins = None
+    if "_target_" in cfg.environment:
+        log.info(f"Instantiating environment <{cfg.environment._target_}>")
+        plugins: ClusterEnvironment = hydra.utils.instantiate(cfg.environment)
+
+    strategy = getattr(cfg.trainer, "strategy", None)
+    if "_target_" in cfg.strategy:
+        log.info(f"Instantiating strategy <{cfg.strategy._target_}>")
+        strategy: Strategy = hydra.utils.instantiate(cfg.strategy)
+        if "mixed_precision" in strategy.__dict__ and getattr(strategy, "mixed_precision", None) is not None:
+            strategy.mixed_precision.param_dtype = (
+                resolve_omegaconf_variable(cfg.strategy.mixed_precision.param_dtype)
+                if getattr(cfg.strategy.mixed_precision, "param_dtype", None) is not None
+                else None
+            )
+            strategy.mixed_precision.reduce_dtype = (
+                resolve_omegaconf_variable(cfg.strategy.mixed_precision.reduce_dtype)
+                if getattr(cfg.strategy.mixed_precision, "reduce_dtype", None) is not None
+                else None
+            )
+            strategy.mixed_precision.buffer_dtype = (
+                resolve_omegaconf_variable(cfg.strategy.mixed_precision.buffer_dtype)
+                if getattr(cfg.strategy.mixed_precision, "buffer_dtype", None) is not None
+                else None
+            )
+
     log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
-    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger)
+    trainer: Trainer = (
+        hydra.utils.instantiate(
+            cfg.trainer,
+            logger=logger,
+            plugins=plugins,
+            strategy=strategy,
+        )
+        if strategy is not None
+        else hydra.utils.instantiate(
+            cfg.trainer,
+            logger=logger,
+            plugins=plugins,
+        )
+    )
 
     object_dict = {
         "cfg": cfg,
@@ -96,4 +138,5 @@ def main(cfg: DictConfig) -> None:
 
 
 if __name__ == "__main__":
+    register_custom_omegaconf_resolvers()
     main()
diff --git a/src/models/mnist_module.py b/src/models/mnist_module.py
index 5d303ac2f..294d5341c 100644
--- a/src/models/mnist_module.py
+++ b/src/models/mnist_module.py
@@ -51,6 +51,7 @@ def __init__(
         :param net: The model to train.
         :param optimizer: The optimizer to use for training.
         :param scheduler: The learning rate scheduler to use for training.
+        :param compile: Whether to compile the model before training.
         """
         super().__init__()
@@ -198,7 +199,11 @@ def configure_optimizers(self) -> Dict[str, Any]:
         :return: A dict containing the configured optimizers and learning-rate schedulers to be used for
             training.
         """
-        optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
+        try:
+            optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
+        except TypeError:
+            # NOTE: strategies such as DeepSpeed require `params` to instead be specified as `model_params`
+            optimizer = self.hparams.optimizer(model_params=self.trainer.model.parameters())
         if self.hparams.scheduler is not None:
             scheduler = self.hparams.scheduler(optimizer=optimizer)
             return {
diff --git a/src/train.py b/src/train.py
index 4adbcf442..318b3d022 100644
--- a/src/train.py
+++ b/src/train.py
@@ -3,9 +3,11 @@
 import hydra
 import lightning as L
 import rootutils
-import torch
+import os
 from lightning import Callback, LightningDataModule, LightningModule, Trainer
+from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
 from lightning.pytorch.loggers import Logger
+from lightning.pytorch.strategies.strategy import Strategy
 from omegaconf import DictConfig
 
 rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
@@ -26,6 +28,7 @@
 # more info: https://github.com/ashleve/rootutils
 # ------------------------------------------------------------------------------------ #
 
+from src import register_custom_omegaconf_resolvers, resolve_omegaconf_variable
 from src.utils import (
     RankedLogger,
     extras,
@@ -66,8 +69,49 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     log.info("Instantiating loggers...")
     logger: List[Logger] = instantiate_loggers(cfg.get("logger"))
 
+    plugins = None
+    if "_target_" in cfg.environment:
+        log.info(f"Instantiating environment <{cfg.environment._target_}>")
+        plugins: ClusterEnvironment = hydra.utils.instantiate(cfg.environment)
+
+    strategy = getattr(cfg.trainer, "strategy", None)
+    if "_target_" in cfg.strategy:
+        log.info(f"Instantiating strategy <{cfg.strategy._target_}>")
+        strategy: Strategy = hydra.utils.instantiate(cfg.strategy)
+        if "mixed_precision" in strategy.__dict__ and getattr(strategy, "mixed_precision", None) is not None:
+            strategy.mixed_precision.param_dtype = (
+                resolve_omegaconf_variable(cfg.strategy.mixed_precision.param_dtype)
+                if getattr(cfg.strategy.mixed_precision, "param_dtype", None) is not None
+                else None
+            )
+            strategy.mixed_precision.reduce_dtype = (
+                resolve_omegaconf_variable(cfg.strategy.mixed_precision.reduce_dtype)
+                if getattr(cfg.strategy.mixed_precision, "reduce_dtype", None) is not None
+                else None
+            )
+            strategy.mixed_precision.buffer_dtype = (
+                resolve_omegaconf_variable(cfg.strategy.mixed_precision.buffer_dtype)
+                if getattr(cfg.strategy.mixed_precision, "buffer_dtype", None) is not None
+                else None
+            )
+
     log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
-    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)
+    trainer: Trainer = (
+        hydra.utils.instantiate(
+            cfg.trainer,
+            callbacks=callbacks,
+            logger=logger,
+            plugins=plugins,
+            strategy=strategy,
+        )
+        if strategy is not None
+        else hydra.utils.instantiate(
+            cfg.trainer,
+            callbacks=callbacks,
+            logger=logger,
+            plugins=plugins,
+        )
+    )
 
     object_dict = {
         "cfg": cfg,
@@ -84,7 +128,14 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
 
     if cfg.get("train"):
         log.info("Starting training!")
-        trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
+        ckpt_path = None
+        if cfg.get("ckpt_path") and os.path.exists(cfg.get("ckpt_path")):
+            ckpt_path = cfg.get("ckpt_path")
+        elif cfg.get("ckpt_path"):
+            log.warning(
+                "`ckpt_path` was given, but the path does not exist. Training with new model weights."
+            )
+        trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
 
     train_metrics = trainer.callback_metrics
 
@@ -129,4 +180,5 @@ def main(cfg: DictConfig) -> Optional[float]:
 
 
 if __name__ == "__main__":
+    register_custom_omegaconf_resolvers()
    main()
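Aside (commentary, not part of the patch): the `${resolve_variable:...}` interpolation used in `configs/strategy/fsdp.yaml` is handled by the resolver registered in `src/__init__.py`, which imports a dotted path and returns the named attribute, falling back to nested `getattr` for members such as enum values that are not importable modules themselves. A rough usage sketch, assuming `torch` and `omegaconf` are installed and the script is run from the repository root so that `src` is importable.

    # Illustrative only: exercising the custom resolver added in src/__init__.py.
    from omegaconf import OmegaConf

    from src import register_custom_omegaconf_resolvers, resolve_omegaconf_variable

    # Direct call: dotted path -> imported attribute.
    full_shard = resolve_omegaconf_variable("torch.distributed.fsdp.ShardingStrategy.FULL_SHARD")
    print(full_shard)  # ShardingStrategy.FULL_SHARD

    # Same lookup via the `resolve_variable` interpolation, as in configs/strategy/fsdp.yaml.
    register_custom_omegaconf_resolvers()
    cfg = OmegaConf.create(
        {"sharding_strategy": "${resolve_variable:torch.distributed.fsdp.ShardingStrategy.FULL_SHARD}"}
    )
    print(cfg.sharding_strategy)  # resolved lazily on access to the same enum member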