diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index c90c2ed30d1..b5844a0de55 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -86,6 +86,8 @@
     title: Gradient synchronization
   - local: concept_guides/fsdp_and_deepspeed
     title: FSDP vs DeepSpeed
+  - local: concept_guides/fsdp1_vs_fsdp2
+    title: FSDP1 vs FSDP2
   - local: concept_guides/low_precision_training
     title: Low precision training methods
   - local: concept_guides/training_tpu
diff --git a/docs/source/concept_guides/fsdp1_vs_fsdp2.md b/docs/source/concept_guides/fsdp1_vs_fsdp2.md
new file mode 100644
index 00000000000..28adf431853
--- /dev/null
+++ b/docs/source/concept_guides/fsdp1_vs_fsdp2.md
@@ -0,0 +1,67 @@
+
+
+# FSDP1 vs FSDP2
+
+This guide explains the key differences between `FSDP1` and `FSDP2` and helps you migrate your existing code to `FSDP2` with minimal changes.
+
+## What is FSDP2?
+
+`FSDP2` is a new and improved version of PyTorch's fully sharded data parallel training API. Compared to `FSDP1`, it offers:
+- Simpler internal implementation
+- Flexible parameter freezing
+- Support for `fp8` parameters
+- Faster and simpler checkpointing
+- Better memory efficiency
+
+## Key Differences
+
+Here are the main changes in configuration options when using `FSDP2` through the Accelerate CLI:
+
+Previous (`FSDP1`) | New (`FSDP2`) | What Changed
+-- | -- | --
+`--fsdp_sharding_strategy` | `--fsdp_reshard_after_forward` | replaces `--fsdp_sharding_strategy`; set to `true` (previously `FULL_SHARD`) or `false` (previously `SHARD_GRAD_OP`)
+`--fsdp_backward_prefetch` | \*\***REMOVED**\*\* | `FSDP2` always uses the previous `BACKWARD_PRE` option, as it is the only one that allows communication and computation to overlap
+`--fsdp_state_dict_type` | \*\***REMOVED**\*\* | `FSDP2` always uses `SHARDED_STATE_DICT`, i.e. each rank checkpoints only its own shard of the model, so no extra communication is needed
+`--fsdp_forward_prefetch` | \*\***NOT YET IMPLEMENTED**\*\* | How to implement this is still under discussion; for now it is not supported in `FSDP2`
+`--fsdp_cpu_ram_efficient_loading` | **TODO** | **TODO**
+`--fsdp_sync_module_states` | **TODO** | **TODO**
+`--fsdp_use_orig_params` | \*\***REMOVED**\*\* | `FSDP2` uses `DTensor` under the hood, so it *always* uses the original parameters
+\*\***NEW**\*\* | `--fsdp_version` | `1` is the default option, so existing configs keep working; set this to `2` to enable `FSDP2`
+
+For all other options that remain unchanged, see the [`FSDP` documentation](../usage_guides/fsdp.md).
+
+## How to Switch to FSDP2
+
+### If using Python code:
+Set `fsdp_version=2` when creating your plugin:
+
+```python
+from accelerate import FullyShardedDataParallelPlugin, Accelerator
+
+fsdp_plugin = FullyShardedDataParallelPlugin(
+    fsdp_version=2,
+    # other options...
+)
+accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
+```
+
+### If using YAML config:
+Use our conversion tool:
+```bash
+accelerate to-fsdp2 --config_file config.yaml --output_file new_config.yaml
+```
+
+This automatically converts all FSDP1 settings to their FSDP2 equivalents. Use `--overwrite` to update the existing file in place instead of creating a new one.
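For reference, the conversion that `accelerate to-fsdp2` performs is also exposed as a plain function, `convert_config_to_fsdp2`, in the new `accelerate/commands/to_fsdp2.py` module further down in this diff. A minimal sketch of what it does to an FSDP1-style config dict; the keys shown are an illustrative subset, not a complete Accelerate config:

```python
from accelerate.commands.to_fsdp2 import convert_config_to_fsdp2

# Illustrative subset of an FSDP1-style config (not a complete Accelerate config file)
config = {
    "fsdp_config": {
        "fsdp_sharding_strategy": "FULL_SHARD",
        "fsdp_backward_prefetch": "BACKWARD_PRE",
        "fsdp_use_orig_params": True,
        "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
    }
}

converted = convert_config_to_fsdp2(config)
print(converted["fsdp_config"])
# Expected shape of the result, per the key/value mappings added in this PR:
#   fsdp_sharding_strategy: FULL_SHARD -> fsdp_reshard_after_forward: True
#   fsdp_backward_prefetch, fsdp_use_orig_params -> dropped (removed in FSDP2)
#   fsdp_auto_wrap_policy -> kept unchanged
#   fsdp_version: 2 -> added
```

The `to-fsdp2` CLI command wraps exactly this function, plus the YAML load and dump.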
diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 1abcee49345..abe6640d52d 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -33,6 +33,7 @@ import torch.utils.hooks as hooks from huggingface_hub import split_torch_state_dict_into_shards +from accelerate.utils.dataclasses import get_module_class_from_name from accelerate.utils.imports import is_torchao_available from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state @@ -117,6 +118,7 @@ from .utils.constants import ( BETA_TP_AVAILABLE_PYTORCH_VERSION, BETA_TP_AVAILABLE_TRANSFORMERS_VERSION, + FSDP2_PYTORCH_VERSION, FSDP_PYTORCH_VERSION, PROFILE_PATTERN_NAME, ) @@ -388,6 +390,10 @@ def __init__( raise TypeError("`fsdp_plugin` must be a FullyShardedDataParallelPlugin object.") os.environ["ACCELERATE_USE_FSDP"] = "true" # use FSDP if plugin is provided + if fsdp_plugin is not None and fsdp_plugin.fsdp_version == 2: + if not is_torch_version(">=", FSDP2_PYTORCH_VERSION): + raise ValueError(f"FSDP2 requires PyTorch >= {FSDP2_PYTORCH_VERSION}") + if torch_tp_plugin is None: torch_tp_plugin = ( TorchTensorParallelPlugin() if os.environ.get("ACCELERATE_USE_TP", "false") == "true" else None @@ -1539,6 +1545,108 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e and _tp_plan attribute to model class." ) model.tensor_parallel(self.state.torch_tp_plugin.torch_device_mesh["tp"]) + elif self.distributed_type == DistributedType.FSDP and self.state.fsdp_plugin.fsdp_version == 2: + from torch.distributed.fsdp import FSDPModule, fully_shard + + is_type_fsdp = isinstance(model, FSDPModule) or ( + is_compiled_module(model) and isinstance(model._orig_mod, FSDPModule) + ) + + if not is_type_fsdp: + fsdp2_plugin = self.state.fsdp_plugin + + from torch.distributed.fsdp.wrap import ( + size_based_auto_wrap_policy, + transformer_auto_wrap_policy, + ) + + auto_wrap_policy_type = None # extract the original type to create custom fn later + if fsdp2_plugin.auto_wrap_policy is transformer_auto_wrap_policy: + auto_wrap_policy_type = "transformer" + elif fsdp2_plugin.auto_wrap_policy is size_based_auto_wrap_policy: + auto_wrap_policy_type = "size" + + fsdp2_plugin.set_auto_wrap_policy( + model + ) # we set the auto_wrap policy to a functools.partial, so we can use it in apply_activation_checkpointing + + kwargs = { + "reshard_after_forward": fsdp2_plugin.reshard_after_forward, + "offload_policy": fsdp2_plugin.cpu_offload, + "mp_policy": fsdp2_plugin.mixed_precision_policy, + } + + if fsdp2_plugin.activation_checkpointing: + from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + CheckpointImpl, + apply_activation_checkpointing, + checkpoint_wrapper, + ) + + apply_activation_checkpointing( + model, + checkpoint_wrapper_fn=functools.partial( + checkpoint_wrapper, + checkpoint_impl=CheckpointImpl.NO_REENTRANT, # TODO(siro1): This breaks + ), + auto_wrap_policy=fsdp2_plugin.auto_wrap_policy, + ) + if (auto_wrap_policy := auto_wrap_policy_type) is not None: + from torch.distributed.fsdp.wrap import ( + size_based_auto_wrap_policy, + transformer_auto_wrap_policy, + ) + + # Simulate the behavior of the old auto_wrap_policy + # TODO(siro1): abstract this into a function together with `set_auto_wrap_policy` + if auto_wrap_policy == "transformer": + no_split_modules = model._no_split_modules + if no_split_modules is None: + no_split_modules = [] + transformer_cls_names_to_wrap = list(no_split_modules) + if 
fsdp2_plugin.transformer_cls_names_to_wrap is not None: + transformer_cls_names_to_wrap = fsdp2_plugin.transformer_cls_names_to_wrap + transformer_cls_to_wrap = set() + + for layer_class in transformer_cls_names_to_wrap: + transformer_cls = get_module_class_from_name(model, layer_class) + if transformer_cls is None: + raise ValueError( + f"Could not find the transformer layer class {layer_class} in the model." + ) + transformer_cls_to_wrap.add(transformer_cls) + + def policy(module: torch.nn.Module) -> bool: + if fsdp2_plugin.transformer_cls_names_to_wrap is None: + return False + return isinstance(module, tuple(transformer_cls_to_wrap)) + + elif auto_wrap_policy == "size": + + def policy(module: torch.nn.Module) -> bool: + return module.numel() > fsdp2_plugin.min_num_params + + stack = [model] + ordered_modules = [] + while stack: + current_module = stack.pop() + for _, attr in current_module.named_children(): + if isinstance(attr, torch.nn.Module): + stack.append(attr) + ordered_modules.append(current_module) + + for module in ordered_modules[::-1][ + :-1 + ]: # Skip the top-most module, as that one is wrapped even without policy + if policy(module): + fully_shard(module, **kwargs) + + fully_shard(model, **kwargs) # Wrap the top-most module nonetheless + + if len(self._models) > 1 and (self._models[-2] is self._models[-1]): + del self._models[-2] + self._models[-1] = model + elif self.distributed_type == DistributedType.FSDP: # We need to fix the optimizer *before* sharding the model from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP @@ -1564,7 +1672,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e ) kwargs = { - "sharding_strategy": fsdp_plugin.sharding_strategy, + "sharding_strategy": fsdp_plugin.reshard_after_forward, "cpu_offload": fsdp_plugin.cpu_offload, "auto_wrap_policy": fsdp_plugin.auto_wrap_policy, "mixed_precision": fsdp_plugin.mixed_precision_policy, @@ -2457,7 +2565,12 @@ def clip_grad_norm_(self, parameters, max_norm, norm_type=2): parameters = [p for p in parameters] for model in self._models: if parameters == [p for p in model.parameters()]: - return model.clip_grad_norm_(max_norm, norm_type) + if self.fsdp_version == 1: + return model.clip_grad_norm_(max_norm, norm_type) + else: + return torch.nn.utils.clip_grad_norm_( + parameters, max_norm, norm_type=norm_type + ) # viz: https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md elif self.distributed_type == DistributedType.DEEPSPEED: # `accelerator.backward(loss)` is doing that automatically. Therefore, its implementation is not needed # We cannot return the gradient norm because DeepSpeed does it. 
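The new `prepare_model` branch above is selected purely by `fsdp_version == 2` on the plugin, and `clip_grad_norm_` now falls back to `torch.nn.utils.clip_grad_norm_` for FSDP2-wrapped models. A minimal end-to-end sketch of how user code exercises this path (the toy model, optimizer, and hyperparameters are placeholders, and the script is meant to be run under `accelerate launch` with a PyTorch version that satisfies `FSDP2_PYTORCH_VERSION`):

```python
import torch
from accelerate import Accelerator, FullyShardedDataParallelPlugin

# Toy model and optimizer, for illustration only
model = torch.nn.Sequential(torch.nn.Linear(128, 128), torch.nn.ReLU(), torch.nn.Linear(128, 2))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

fsdp_plugin = FullyShardedDataParallelPlugin(
    fsdp_version=2,              # routes prepare_model() through the fully_shard branch
    reshard_after_forward=True,  # a bool in FSDP2 (a ShardingStrategy in FSDP1)
    auto_wrap_policy="size_based_wrap",
)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
model, optimizer = accelerator.prepare(model, optimizer)

# One training step
inputs = torch.randn(8, 128, device=accelerator.device)
loss = model(inputs).sum()
accelerator.backward(loss)
# With fsdp_version == 2 this calls torch.nn.utils.clip_grad_norm_ directly
# instead of FSDP1's model.clip_grad_norm_.
accelerator.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
optimizer.zero_grad()
```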
diff --git a/src/accelerate/commands/accelerate_cli.py b/src/accelerate/commands/accelerate_cli.py index d9790e5805c..b878c8debd8 100644 --- a/src/accelerate/commands/accelerate_cli.py +++ b/src/accelerate/commands/accelerate_cli.py @@ -20,6 +20,7 @@ from accelerate.commands.launch import launch_command_parser from accelerate.commands.merge import merge_command_parser from accelerate.commands.test import test_command_parser +from accelerate.commands.to_fsdp2 import to_fsdp2_command_parser from accelerate.commands.tpu import tpu_command_parser from accelerate.commands.utils import CustomArgumentParser @@ -36,6 +37,7 @@ def main(): merge_command_parser(subparsers=subparsers) tpu_command_parser(subparsers=subparsers) test_command_parser(subparsers=subparsers) + to_fsdp2_command_parser(subparsers=subparsers) # Let's go args = parser.parse_args() diff --git a/src/accelerate/commands/config/cluster.py b/src/accelerate/commands/config/cluster.py index bb8aa9000dd..a8afaa19702 100644 --- a/src/accelerate/commands/config/cluster.py +++ b/src/accelerate/commands/config/cluster.py @@ -399,18 +399,36 @@ def get_cluster_input(): if use_fsdp: distributed_type = DistributedType.FSDP if distributed_type == DistributedType.FSDP: - sharding_strategy_query = "What should be your sharding strategy?" - fsdp_config["fsdp_sharding_strategy"] = _ask_options( - sharding_strategy_query, - FSDP_SHARDING_STRATEGY, - lambda x: FSDP_SHARDING_STRATEGY[int(x)], + fsdp_config["fsdp_version"] = _ask_options( + "What should be your FSDP version? [1]: ", + [1, 2], + lambda x: int(x) + 1, + default=1, ) + fsdp_version = fsdp_config["fsdp_version"] # extract to a variable to simplify usage later + + if fsdp_version == 1: + sharding_strategy_query = "What should be your sharding strategy?" + fsdp_config["fsdp_reshard_after_forward"] = _ask_options( + sharding_strategy_query, + FSDP_SHARDING_STRATEGY, + lambda x: FSDP_SHARDING_STRATEGY[int(x)], + ) + else: + fsdp_config["fsdp_reshard_after_forward"] = _ask_field( + "Do you want to enable resharding after forward? [YES/no]: ", + _convert_yes_no_to_bool, + default=True, + error_message="Please enter yes or no.", + ) + fsdp_config["fsdp_offload_params"] = _ask_field( "Do you want to offload parameters and gradients to CPU? [yes/NO]: ", _convert_yes_no_to_bool, default=False, error_message="Please enter yes or no.", ) + fsdp_wrap_query = "What should be your auto wrap policy?" fsdp_config["fsdp_auto_wrap_policy"] = _ask_options( fsdp_wrap_query, @@ -436,12 +454,14 @@ def get_cluster_input(): int, default=100000000, ) - fsdp_backward_prefetch_query = "What should be your FSDP's backward prefetch policy?" - fsdp_config["fsdp_backward_prefetch"] = _ask_options( - fsdp_backward_prefetch_query, - FSDP_BACKWARD_PREFETCH, - lambda x: FSDP_BACKWARD_PREFETCH[int(x)], - ) + # Removed in FSDP2, ask for user input for FSDP1 + if fsdp_version == 1: + fsdp_backward_prefetch_query = "What should be your FSDP's backward prefetch policy?" + fsdp_config["fsdp_backward_prefetch"] = _ask_options( + fsdp_backward_prefetch_query, + FSDP_BACKWARD_PREFETCH, + lambda x: FSDP_BACKWARD_PREFETCH[int(x)], + ) fsdp_state_dict_type_query = "What should be your FSDP's state dict type?" fsdp_config["fsdp_state_dict_type"] = _ask_options( fsdp_state_dict_type_query, @@ -449,33 +469,39 @@ def get_cluster_input(): lambda x: FSDP_STATE_DICT_TYPE[int(x)], default=2, ) - fsdp_config["fsdp_forward_prefetch"] = _ask_field( - "Do you want to enable FSDP's forward prefetch policy? 
[yes/NO]: ", - _convert_yes_no_to_bool, - default=False, - error_message="Please enter yes or no.", - ) - fsdp_config["fsdp_use_orig_params"] = _ask_field( - "Do you want to enable FSDP's `use_orig_params` feature? [YES/no]: ", - _convert_yes_no_to_bool, - default=True, - error_message="Please enter yes or no.", - ) + # Not implemented in FSDP2, ask for user input for FSDP1 + if fsdp_version == 1: + fsdp_config["fsdp_forward_prefetch"] = _ask_field( + "Do you want to enable FSDP's forward prefetch policy? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + # Obsolete in FSDP2, ask for user input for FSDP1 + if fsdp_version == 1: + fsdp_config["fsdp_use_orig_params"] = _ask_field( + "Do you want to enable FSDP's `use_orig_params` feature? [YES/no]: ", + _convert_yes_no_to_bool, + default=True, + error_message="Please enter yes or no.", + ) fsdp_config["fsdp_cpu_ram_efficient_loading"] = _ask_field( "Do you want to enable CPU RAM efficient model loading? Only applicable for 🤗 Transformers models. [YES/no]: ", _convert_yes_no_to_bool, default=True, error_message="Please enter yes or no.", ) - if fsdp_config["fsdp_cpu_ram_efficient_loading"]: - fsdp_config["fsdp_sync_module_states"] = True - else: - fsdp_config["fsdp_sync_module_states"] = _ask_field( - "Do you want each individually wrapped FSDP unit to broadcast module parameters from rank 0 at the start? [YES/no]: ", - _convert_yes_no_to_bool, - default=True, - error_message="Please enter yes or no.", - ) + # Obsolete in FSDP2, ask for user input for FSDP1 + if fsdp_version == 1: + if fsdp_config["fsdp_cpu_ram_efficient_loading"]: + fsdp_config["fsdp_sync_module_states"] = True + else: + fsdp_config["fsdp_sync_module_states"] = _ask_field( + "Do you want each individually wrapped FSDP unit to broadcast module parameters from rank 0 at the start? [YES/no]: ", + _convert_yes_no_to_bool, + default=True, + error_message="Please enter yes or no.", + ) fsdp_config["fsdp_activation_checkpointing"] = _ask_field( "Do you want to enable FSDP activation checkpointing? [yes/NO]: ", _convert_yes_no_to_bool, diff --git a/src/accelerate/commands/launch.py b/src/accelerate/commands/launch.py index c460ea3ffb5..7ff9012eeea 100644 --- a/src/accelerate/commands/launch.py +++ b/src/accelerate/commands/launch.py @@ -530,10 +530,10 @@ def launch_command_parser(subparsers=None): help="FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `use_fsdp` flag is passed).", ) fsdp_args.add_argument( - "--fsdp_sharding_strategy", + "--fsdp_reshard_after_forward", # TODO(s1ro1): Maybe too harsh to rename to FSDP2 naming and not support FSDP1 naming type=str, default="FULL_SHARD", - help="FSDP's Sharding Strategy. (useful only when `use_fsdp` flag is passed).", + help="FSDP's Reshard After Forward Strategy. (useful only when `use_fsdp` flag is passed). Supports either boolean (FSDP2) or `FULL_SHARD | SHARD_GRAD_OP | NO_RESHARD` (FSDP1).", ) fsdp_args.add_argument( "--fsdp_auto_wrap_policy", diff --git a/src/accelerate/commands/to_fsdp2.py b/src/accelerate/commands/to_fsdp2.py new file mode 100644 index 00000000000..f339814a5ef --- /dev/null +++ b/src/accelerate/commands/to_fsdp2.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import enum +import logging +from pathlib import Path + +import yaml + +from accelerate.commands.utils import CustomArgumentParser + + +class ConversionStatus(enum.Enum): + NOT_YET_IMPLEMENTED = 0 + REMOVED = -1 + + +ARGUMENT_KEY_MAPPING = { + # New keys in FSDP2 + "fsdp_version": "fsdp_version", + "fsdp_reshard_after_forward": "fsdp_reshard_after_forward", + # https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md + # https://huggingface.co/docs/accelerate/en/usage_guides/fsdp + "fsdp_auto_wrap_policy": "fsdp_auto_wrap_policy", + "fsdp_backward_prefetch": ConversionStatus.REMOVED, + "fsdp_forward_prefetch": ConversionStatus.NOT_YET_IMPLEMENTED, + "fsdp_cpu_ram_efficient_loading": "fsdp_cpu_ram_efficient_loading", + "fsdp_offload_params": "fsdp_offload_params", # TODO: This becomes obsolete in FSDP2 + "fsdp_sharding_strategy": "fsdp_reshard_after_forward", + "fsdp_state_dict_type": "fsdp_state_dict_type", + "fsdp_sync_module_states": ConversionStatus.REMOVED, + "fsdp_transformer_layer_cls_to_wrap": "fsdp_transformer_layer_cls_to_wrap", + "fsdp_min_num_params": "fsdp_min_num_params", + "fsdp_use_orig_params": ConversionStatus.REMOVED, + "fsdp_activation_checkpointing": "fsdp_activation_checkpointing", # TODO: not in the docs? +} + +ARGUMENT_VALUE_MAPPING = { + "fsdp_sharding_strategy": { + "FULL_SHARD": True, + "SHARD_GRAD_OP": False, + "HYBRID_SHARD": True, + "HYBRID_SHARD_ZERO2": False, + "NO_SHARD": False, + }, + "fsdp_reshard_after_forward": { # Needed to convert newly created configs using FSDP1 to FSDP2 + "FULL_SHARD": True, + "SHARD_GRAD_OP": False, + "HYBRID_SHARD": True, + "HYBRID_SHARD_ZERO2": False, + "NO_SHARD": False, + }, +} + +logger = logging.getLogger(__name__) + + +def _validate_to_fsdp2_args(args): + if not Path(args.config_file).exists(): + raise FileNotFoundError(f"Config file {args.config_file} not found") + + if not args.overwrite and args.output_file is None: + raise ValueError("If --overwrite is not set, --output_file must be provided") + + if not args.overwrite and Path(args.output_file).exists(): + raise FileExistsError(f"Output file {args.output_file} already exists and --overwrite is not set") + + +def convert_config_to_fsdp2(config: dict, only_rename: bool = False) -> dict: + fsdp_config = config.get("fsdp_config", {}) + + if not fsdp_config: + logger.info("No FSDP config found in the config file, skipping conversion...") + return config + + new_fsdp_config = {} + + if fsdp_config.get("fsdp_version", 1) == 2: + logger.warning("Config already specfies FSDP2, skipping conversion...") + logger.warning( + "If the config doesn't use new argument names, change `fsdp_version` to `1` and rerun the command." 
+ ) + return config + + for key, value in fsdp_config.items(): + conversion_status = ARGUMENT_KEY_MAPPING.get(key, None) + # short circuit if only renaming + if only_rename: + if isinstance(conversion_status, ConversionStatus) or conversion_status is None: + conversion_status = key + new_fsdp_config[conversion_status] = value + continue + + if conversion_status == ConversionStatus.REMOVED: + logger.warning(f"Argument {key} has been removed in FSDP2, skipping this key...") + continue + + if conversion_status == ConversionStatus.NOT_YET_IMPLEMENTED: + logger.warning(f"Argument {key} is not yet implemented in FSDP2, skipping this key...") + continue + + if conversion_status is None: + logger.warning(f"Argument {key} is not being converted, skipping this key...") + new_fsdp_config[key] = value + else: + if key in ARGUMENT_VALUE_MAPPING: + value = ARGUMENT_VALUE_MAPPING[key].get(value, value) + new_fsdp_config[ARGUMENT_KEY_MAPPING[key]] = value + + new_fsdp_config["fsdp_version"] = 1 if only_rename else 2 + config["fsdp_config"] = new_fsdp_config + return config + + +def to_fsdp2_command_parser(subparsers=None): + description = "Convert an Accelerate config from FSDP1 to FSDP2" + + if subparsers is not None: + parser = subparsers.add_parser("to-fsdp2", description=description) + else: + parser = CustomArgumentParser(description=description) + + parser.add_argument("--config_file", type=str, help="The config file to convert to FSDP2", required=True) + parser.add_argument( + "--overwrite", + action="store_true", + help="Overwrite the config file if it exists", + default=False, + ) + parser.add_argument( + "--output_file", + type=str, + help="The path to the output file to write the converted config to. If not provided, the input file will be overwritten (if --overwrite is set)", + default=None, + ) + parser.add_argument( + "--only-rename", + action="store_true", + help="If set to True, only rename keys in the config file and do not convert the config, this is required because of breaking changes introduced in FSDP2", + default=False, + ) + + if subparsers is not None: + parser.set_defaults(func=to_fsdp2_command) + + return parser + + +def load_config(config_file: str) -> dict: + with open(config_file) as f: + config = yaml.safe_load(f) + if not config: + raise ValueError("Config file is empty") + + return config + + +def to_fsdp2_command(args): + _validate_to_fsdp2_args(args) + config = load_config(args.config_file) + + if args.overwrite and args.output_file is None: + args.output_file = args.config_file + + new_config = convert_config_to_fsdp2(config, args.only_rename) + + with open(args.output_file, "w") as f: + yaml.dump(new_config, f) diff --git a/src/accelerate/test_utils/scripts/test_merge_weights.py b/src/accelerate/test_utils/scripts/test_merge_weights.py index a1390864047..911acd947de 100644 --- a/src/accelerate/test_utils/scripts/test_merge_weights.py +++ b/src/accelerate/test_utils/scripts/test_merge_weights.py @@ -50,7 +50,7 @@ def setup(): if AcceleratorState._shared_state != {}: AcceleratorState()._reset_state() plugin = FullyShardedDataParallelPlugin( - sharding_strategy=ShardingStrategy.FULL_SHARD, state_dict_type=StateDictType.SHARDED_STATE_DICT + reshard_after_forward=ShardingStrategy.FULL_SHARD, state_dict_type=StateDictType.SHARDED_STATE_DICT ) model = TinyModel() with patch_environment(fsdp_auto_wrap_policy="SIZE_BASED_WRAP"): diff --git a/src/accelerate/utils/constants.py b/src/accelerate/utils/constants.py index a9d840c896d..9413722b898 100644 --- 
a/src/accelerate/utils/constants.py +++ b/src/accelerate/utils/constants.py @@ -40,6 +40,7 @@ FSDP_PYTORCH_VERSION = ( "2.1.0.a0+32f93b1" # Technically should be 2.1.0, but MS-AMP uses this specific prerelease in their Docker image. ) +FSDP2_PYTORCH_VERSION = "2.6.0" # TODO(s1ro): check if this is 100% correct FSDP_MODEL_NAME = "pytorch_model_fsdp" DEEPSPEED_MULTINODE_LAUNCHERS = ["pdsh", "standard", "openmpi", "mvapich", "mpich"] TORCH_DYNAMO_MODES = ["default", "reduce-overhead", "max-autotune"] diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py index 0c3d7c438ad..4838280c503 100644 --- a/src/accelerate/utils/dataclasses.py +++ b/src/accelerate/utils/dataclasses.py @@ -1439,22 +1439,27 @@ class FullyShardedDataParallelPlugin: This plugin is used to enable fully sharded data parallelism. Args: - sharding_strategy (`Union[str, torch.distributed.fsdp.ShardingStrategy]`, defaults to `'FULL_SHARD'`): - Sharding strategy to use. Should be either a `str` or an instance of + fsdp_version (`int`, defaults to `1`): + The version of FSDP to use. Defaults to 1. If set to 2, launcher expects the config to be converted to + FSDP2 format. + sharding_strategy (`Union[str, torch.distributed.fsdp.ShardingStrategy, bool]`, defaults to `'FULL_SHARD'`): + Sharding strategy to use. Should be a bool if `fsdp_version` is set to 2 else a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`. backward_prefetch (`Union[str, torch.distributed.fsdp.BackwardPrefetch]`, defaults to `'NO_PREFETCH'`): Backward prefetch strategy to use. Should be either a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch`. - mixed_precision_policy (`Optional[Union[dict, torch.distributed.fsdp.MixedPrecision]]`, defaults to `None`): + mixed_precision_policy (`Optional[Union[dict, torch.distributed.fsdp.MixedPrecision, torch.distributed.fsdp.MixedPrecisionPolicy]]`, defaults to `None`): A config to enable mixed precision training with FullyShardedDataParallel. If passing in a `dict`, it - should have the following keys: `param_dtype`, `reduce_dtype`, and `buffer_dtype`. + should have the following keys: `param_dtype`, `reduce_dtype`, and `buffer_dtype`, can be an instance of + `torch.distributed.fsdp.MixedPrecisionPolicy` if `fsdp_version` is set to 2. auto_wrap_policy (`Optional(Union[Callable, Literal["transformer_based_wrap", "size_based_wrap", "no_wrap"]]), defaults to `NO_WRAP`): A callable or string specifying a policy to recursively wrap layers with FSDP. If a string, it must be one of `transformer_based_wrap`, `size_based_wrap`, or `no_wrap`. See `torch.distributed.fsdp.wrap.size_based_wrap_policy` for a direction on what it should look like. - cpu_offload (`Union[bool, torch.distributed.fsdp.CPUOffload]`, defaults to `False`): + cpu_offload (`Union[bool, torch.distributed.fsdp.CPUOffload, torch.distributed.fsdp.CPUOffloadPolicy]`, defaults to `False`): Whether to offload parameters to CPU. Should be either a `bool` or an instance of - `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload`. + `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload` or + `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffloadPolicy` if `fsdp_version` is set to 2. ignored_modules (`Optional[Iterable[torch.nn.Module]]`, defaults to `None`): A list of modules to ignore when wrapping with FSDP. 
state_dict_type (`Union[str, torch.distributed.fsdp.StateDictType]`, defaults to `'FULL_STATE_DICT'`): @@ -1495,23 +1500,33 @@ class FullyShardedDataParallelPlugin: is `size_based_wrap`. """ - sharding_strategy: Union[str, "torch.distributed.fsdp.ShardingStrategy"] = field( + fsdp_version: int = field( default=None, metadata={ - "help": "Sharding strategy to use. Should be either a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`. Defaults to 'FULL_SHARD'" + "help": "The version of FSDP to use. Defaults to 1. If set to 2, launcher expects the config to be converted to FSDP2 format." }, ) - backward_prefetch: Union[str, "torch.distributed.fsdp.BackwardPrefetch"] = field( + + reshard_after_forward: Union[str, "torch.distributed.fsdp.ShardingStrategy", bool] = field( + default=None, + metadata={ + "help": "Sharding strategy to use. Should be a bool if `fsdp_version` is set to 2 else a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`. Defaults to 'FULL_SHARD'" + }, + ) + backward_prefetch: Optional[Union[str, "torch.distributed.fsdp.BackwardPrefetch"]] = field( default=None, metadata={ - "help": "Backward prefetch strategy to use. Should be either a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch`. Defaults to 'NO_PREFETCH'" + "help": "Backward prefetch strategy to use. Should be either a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch`. Defaults to 'NO_PREFETCH'. This becomes obsolete in FSDP2." }, ) - mixed_precision_policy: Optional[Union[dict, "torch.distributed.fsdp.MixedPrecision"]] = field( + mixed_precision_policy: Optional[ + Union[dict, "torch.distributed.fsdp.MixedPrecision", "torch.distributed.fsdp.MixedPrecisionPolicy"] + ] = field( default=None, metadata={ "help": "A config to enable mixed precision training with FullyShardedDataParallel. " "If passing in a `dict`, it should have the following keys: `param_dtype`, `reduce_dtype`, and `buffer_dtype`." + "Can also be an instance of `torch.distributed.fsdp.MixedPrecisionPolicy` if `fsdp_version` is set to 2." }, ) auto_wrap_policy: Optional[Union[Callable, Literal["transformer_based_wrap", "size_based_wrap", "no_wrap"]]] = ( @@ -1523,10 +1538,10 @@ class FullyShardedDataParallelPlugin: }, ) ) - cpu_offload: Union[bool, "torch.distributed.fsdp.CPUOffload"] = field( + cpu_offload: Union[bool, "torch.distributed.fsdp.CPUOffload", "torch.distributed.fsdp.CPUOffloadPolicy"] = field( default=None, metadata={ - "help": "Whether to offload parameters to CPU. Should be either a `bool` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload`. Defaults to `False`" + "help": "Whether to offload parameters to CPU. Should be either a `bool` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload` or `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffloadPolicy` if `fsdp_version` is set to 2. Defaults to `False`" }, ) ignored_modules: Optional[Iterable[torch.nn.Module]] = field( @@ -1565,9 +1580,11 @@ class FullyShardedDataParallelPlugin: "Enabling this can help lower the number of CUDA malloc retries." }, ) - use_orig_params: bool = field( + use_orig_params: Optional[bool] = field( default=None, - metadata={"help": "Whether to use the original parameters for the optimizer. Defaults to `False`"}, + metadata={ + "help": "Whether to use the original parameters for the optimizer. Defaults to `False`. 
This becomes obsolete in FSDP2." + }, ) param_init_fn: Optional[Callable[[torch.nn.Module], None]] = field( default=None, @@ -1577,12 +1594,12 @@ class FullyShardedDataParallelPlugin: "Only applicable when `sync_module_states` is `True`. By default is a `lambda` which calls `to_empty` on the module." }, ) - sync_module_states: bool = field( + sync_module_states: Optional[bool] = field( default=None, metadata={ "help": "Whether each individually wrapped FSDP unit should broadcast module parameters from rank 0 " "to ensure they are the same across all ranks after initialization. Defaults to `False` unless " - "`cpu_ram_efficient_loading` is `True`, then will be forcibly enabled." + "`cpu_ram_efficient_loading` is `True`, then will be forcibly enabled. This becomes obsolete in FSDP2." }, ) forward_prefetch: bool = field( @@ -1623,27 +1640,40 @@ class FullyShardedDataParallelPlugin: def __post_init__(self): from torch.distributed.fsdp import ( BackwardPrefetch, - CPUOffload, ShardingStrategy, ) env_prefix = "FSDP_" # Strategy: By default we should always assume that values are passed in, else we check the environment variables - if self.sharding_strategy is None: - self.sharding_strategy = os.environ.get(env_prefix + "SHARDING_STRATEGY", "FULL_SHARD") - if isinstance(self.sharding_strategy, str): - # We need to remap based on custom enum values for user readability - if self.sharding_strategy.upper() in FSDP_SHARDING_STRATEGY: - self.sharding_strategy = FSDP_SHARDING_STRATEGY.index(self.sharding_strategy.upper()) + 1 - if isinstance(self.sharding_strategy, int) or self.sharding_strategy.isdigit(): - self.sharding_strategy = ShardingStrategy(int(self.sharding_strategy)) + if self.fsdp_version is None: + self.fsdp_version = int(os.environ.get(env_prefix + "VERSION", "1")) + + if self.reshard_after_forward is None: + self.reshard_after_forward = os.environ.get( + env_prefix + "RESHARD_AFTER_FORWARD", + True if self.fsdp_version == 2 else "FULL_SHARD", + ) + if isinstance(self.reshard_after_forward, str): + if self.fsdp_version == 2: + self.reshard_after_forward = str_to_bool(self.reshard_after_forward.lower(), to_bool=True) else: - self.sharding_strategy = ShardingStrategy[self.sharding_strategy.upper()] + # We need to remap based on custom enum values for user readability + if self.reshard_after_forward.upper() in FSDP_SHARDING_STRATEGY: + self.reshard_after_forward = FSDP_SHARDING_STRATEGY.index(self.reshard_after_forward.upper()) + 1 + if isinstance(self.reshard_after_forward, int) or self.reshard_after_forward.isdigit(): + self.reshard_after_forward = ShardingStrategy(int(self.reshard_after_forward)) + else: + self.reshard_after_forward = ShardingStrategy[self.reshard_after_forward.upper()] + if self.fsdp_version != 2 and isinstance(self.reshard_after_forward, bool): + raise ValueError( + f"reshard_after_forward set to {self.reshard_after_forward}. 
This is not supported with FSDP1, please set to a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`" + ) if self.cpu_offload is None: self.cpu_offload = str_to_bool(os.environ.get(env_prefix + "OFFLOAD_PARAMS", "False")) == 1 - if isinstance(self.cpu_offload, bool): - self.cpu_offload = CPUOffload(offload_params=self.cpu_offload) + + self.set_cpu_offload() # abstracted away to hide imports due to version checks + self.validate_cpu_offload() if self.backward_prefetch is None: self.backward_prefetch = os.environ.get(env_prefix + "BACKWARD_PREFETCH", None) @@ -1656,6 +1686,9 @@ def __post_init__(self): self.backward_prefetch = BackwardPrefetch(int(self.backward_prefetch)) else: self.backward_prefetch = BackwardPrefetch[self.backward_prefetch.upper()] + if self.fsdp_version == 2 and self.backward_prefetch is not None: + warnings.warn("Backward prefetch is not supported in FSDP2. Setting backward prefetch to None.") + self.backward_prefetch = None self.set_state_dict_type() @@ -1685,14 +1718,22 @@ def __post_init__(self): elif self.auto_wrap_policy.upper() == "NO_WRAP": self.auto_wrap_policy = None - if self.use_orig_params is None: + if self.use_orig_params is None and self.fsdp_version == 1: self.use_orig_params = str_to_bool(os.environ.get(env_prefix + "USE_ORIG_PARAMS", "False")) == 1 + if self.fsdp_version == 2 and self.use_orig_params is not None: + warnings.warn("use_orig_params is obsolete in FSDP2, as FSDP2 always uses the original parameters.") + self.use_orig_params = None - if self.sync_module_states is None: + if self.sync_module_states is None and self.fsdp_version == 1: self.sync_module_states = str_to_bool(os.environ.get(env_prefix + "SYNC_MODULE_STATES", "False")) == 1 + if self.fsdp_version == 2 and self.sync_module_states is not None: + warnings.warn("sync_module_states is obsolete in FSDP2, as it is not needed anymore.") + self.sync_module_states = None - if self.forward_prefetch is None: + if self.forward_prefetch is None and self.fsdp_version == 1: self.forward_prefetch = str_to_bool(os.environ.get(env_prefix + "FORWARD_PREFETCH", "False")) == 1 + if self.fsdp_version == 2 and self.forward_prefetch is not None: + raise ValueError("forward_prefetch is not yet implemented in FSDP2, set to None or use `fsdp_version=1`") if self.activation_checkpointing is None: self.activation_checkpointing = ( @@ -1704,15 +1745,25 @@ def __post_init__(self): str_to_bool(os.environ.get(env_prefix + "CPU_RAM_EFFICIENT_LOADING", "False")) == 1 ) - if self.cpu_ram_efficient_loading and not self.sync_module_states: + if self.cpu_ram_efficient_loading and self.sync_module_states is False: warnings.warn( "sync_module_states cannot be False since efficient cpu ram loading enabled. " "Setting sync_module_states to True." ) self.sync_module_states = True + # Invariant: sync_module_states is None in FSDP2 only + if self.cpu_ram_efficient_loading and self.sync_module_states is None: + # TODO: how does this interact with FSDP2 + warnings.warn( + "cpu_ram_efficient_loading is enabled, but sync_module_states is not set. 
" + "This is not properly tested yet in FSDP2" + ) + if isinstance(self.mixed_precision_policy, dict): self.set_mixed_precision(self.mixed_precision_policy) + if self.mixed_precision_policy is not None: + self.validate_mixed_precision_policy() if self.sync_module_states: if is_npu_available(): @@ -1822,12 +1873,22 @@ def set_mixed_precision(self, mixed_precision, buffer_autocast=False, override=F buffer_type = torch.float32 if buffer_autocast else dtype - from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision + if self.fsdp_version == 1: + from torch.distributed.fsdp import MixedPrecision + elif self.fsdp_version == 2: + from torch.distributed.fsdp import MixedPrecisionPolicy if override or self.mixed_precision_policy is None: - self.mixed_precision_policy = MixedPrecision( - param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=buffer_type - ) + if self.fsdp_version == 1: + self.mixed_precision_policy = MixedPrecision( + param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=buffer_type + ) + elif self.fsdp_version == 2: # at this point we're sure we are at a version that supports it + self.mixed_precision_policy = MixedPrecisionPolicy( + param_dtype=dtype, + reduce_dtype=dtype, + output_dtype=dtype, # TODO(s1ro1): `cast_forward_inputs`? + ) elif isinstance(self.mixed_precision_policy, dict): # Check for incompatible types missing_keys = [ @@ -1842,7 +1903,61 @@ def set_mixed_precision(self, mixed_precision, buffer_autocast=False, override=F f"Must be a `dict` with keys `param_dtype`, `reduce_dtype`, and `buffer_dtype`. " f"Values must be one of {list(mixed_precision_mapping.values())}" ) - self.mixed_precision_policy = MixedPrecision(**self.mixed_precision_policy) + if self.fsdp_version == 1: + self.mixed_precision_policy = MixedPrecision(**self.mixed_precision_policy) + elif self.fsdp_version == 2: + self.mixed_precision_policy = MixedPrecisionPolicy(**self.mixed_precision_policy) + + def validate_mixed_precision_policy(self): + """ + Validates the mixed precision policy, abstracted away to not bring in the imports if not needed. 
+ """ + if self.fsdp_version == 2: + from torch.distributed.fsdp import ( + MixedPrecisionPolicy, # at this point we're sure we are at a version that supports it + ) + + if not isinstance(self.mixed_precision_policy, MixedPrecisionPolicy): + raise ValueError( + "mixed_precision_policy must be an instance of `torch.distributed.fsdp.MixedPrecisionPolicy` if `fsdp_version` is set to 2" + ) + if self.fsdp_version == 1: + from torch.distributed.fsdp import MixedPrecision + + if not isinstance(self.mixed_precision_policy, MixedPrecision): + raise ValueError( + "mixed_precision_policy must be an instance of `torch.distributed.fsdp.MixedPrecision` if `fsdp_version` is set to 1" + ) + + def set_cpu_offload(self): + if self.fsdp_version == 2: + from torch.distributed.fsdp import CPUOffloadPolicy, OffloadPolicy + else: + from torch.distributed.fsdp import CPUOffload + + if isinstance(self.cpu_offload, bool): + if self.fsdp_version == 2: + if not self.cpu_offload: + self.cpu_offload = OffloadPolicy() + else: + self.cpu_offload = CPUOffloadPolicy() + else: + self.cpu_offload = CPUOffload(offload_params=self.cpu_offload) + + def validate_cpu_offload(self): + if self.fsdp_version == 2: + from torch.distributed.fsdp import OffloadPolicy + else: + from torch.distributed.fsdp import CPUOffload + + if self.fsdp_version == 2 and not isinstance(self.cpu_offload, OffloadPolicy): + raise ValueError( + f"`cpu_offload` must be an instance of `torch.distributed.fsdp.OffloadPolicy` in FSDP2, got {self.cpu_offload}" + ) + if self.fsdp_version == 1 and not isinstance(self.cpu_offload, CPUOffload): + raise ValueError( + f"`cpu_offload` must be an instance of `torch.distributed.fsdp.CPUOffload` in FSDP1, got {self.cpu_offload}" + ) @dataclass diff --git a/src/accelerate/utils/environment.py b/src/accelerate/utils/environment.py index 1f64a144333..bcafb2d9f01 100644 --- a/src/accelerate/utils/environment.py +++ b/src/accelerate/utils/environment.py @@ -22,7 +22,7 @@ from dataclasses import dataclass, field from functools import lru_cache, wraps from shutil import which -from typing import List, Optional +from typing import List, Optional, Union import torch from packaging.version import parse @@ -56,7 +56,7 @@ def convert_dict_to_env_variables(current_env: dict): return valid_env_items -def str_to_bool(value) -> int: +def str_to_bool(value, to_bool: bool = False) -> Union[int, bool]: """ Converts a string representation of truth to `True` (1) or `False` (0). diff --git a/src/accelerate/utils/launch.py b/src/accelerate/utils/launch.py index a7497d73781..a34da0871a9 100644 --- a/src/accelerate/utils/launch.py +++ b/src/accelerate/utils/launch.py @@ -278,8 +278,11 @@ def prepare_multi_gpu_env(args: argparse.Namespace) -> Dict[str, str]: if args.fsdp_cpu_ram_efficient_loading and not args.fsdp_sync_module_states: raise ValueError("When using `--fsdp_cpu_ram_efficient_loading` set `--fsdp_sync_module_states` to `True`") - current_env["FSDP_SHARDING_STRATEGY"] = str(args.fsdp_sharding_strategy) - current_env["FSDP_OFFLOAD_PARAMS"] = str(args.fsdp_offload_params).lower() + current_env["FSDP_VERSION"] = str(args.fsdp_version) if hasattr(args, "fsdp_version") else "1" + current_env["FSDP_RESHARD_AFTER_FORWARD"] = str( + args.fsdp_reshard_after_forward + ).lower() # TODO(s1ro1): this breaks with old configs, maybe too harsh? 
+ current_env["FSDP_OFFLOAD_PARAMS"] = str(args.fsdp_offload_params) current_env["FSDP_MIN_NUM_PARAMS"] = str(args.fsdp_min_num_params) if args.fsdp_auto_wrap_policy is not None: current_env["FSDP_AUTO_WRAP_POLICY"] = str(args.fsdp_auto_wrap_policy) diff --git a/tests/fsdp/test_fsdp.py b/tests/fsdp/test_fsdp.py index 083d32bfc26..3daebf0cd46 100644 --- a/tests/fsdp/test_fsdp.py +++ b/tests/fsdp/test_fsdp.py @@ -76,22 +76,22 @@ def test_sharding_strategy(self): # check that giving enums works fine for i, strategy in enumerate(FSDP_SHARDING_STRATEGY): env = self.fsdp_env.copy() - env["FSDP_SHARDING_STRATEGY"] = f"{i + 1}" + env["FSDP_RESHARD_AFTER_FORWARD"] = f"{i + 1}" with mockenv_context(**env): fsdp_plugin = FullyShardedDataParallelPlugin() - assert fsdp_plugin.sharding_strategy == ShardingStrategy(i + 1) - fsdp_plugin = FullyShardedDataParallelPlugin(sharding_strategy=ShardingStrategy(i + 1)) - assert fsdp_plugin.sharding_strategy == ShardingStrategy(i + 1) + assert fsdp_plugin.reshard_after_forward == ShardingStrategy(i + 1) + fsdp_plugin = FullyShardedDataParallelPlugin(reshard_after_forward=ShardingStrategy(i + 1)) + assert fsdp_plugin.reshard_after_forward == ShardingStrategy(i + 1) # check that giving names works fine for i, strategy in enumerate(FSDP_SHARDING_STRATEGY): env = self.fsdp_env.copy() - env["FSDP_SHARDING_STRATEGY"] = strategy + env["FSDP_RESHARD_AFTER_FORWARD"] = strategy with mockenv_context(**env): fsdp_plugin = FullyShardedDataParallelPlugin() - assert fsdp_plugin.sharding_strategy == ShardingStrategy(i + 1) - fsdp_plugin = FullyShardedDataParallelPlugin(sharding_strategy=strategy) - assert fsdp_plugin.sharding_strategy == ShardingStrategy(i + 1) + assert fsdp_plugin.reshard_after_forward == ShardingStrategy(i + 1) + fsdp_plugin = FullyShardedDataParallelPlugin(reshard_after_forward=strategy) + assert fsdp_plugin.reshard_after_forward == ShardingStrategy(i + 1) def test_backward_prefetch(self): from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch @@ -318,7 +318,7 @@ def test_performance(self): cmd_config = cmd.copy() for i, strategy in enumerate(FSDP_SHARDING_STRATEGY): if strategy.lower() in config: - cmd_config.append(f"--fsdp_sharding_strategy={strategy}") + cmd_config.append(f"--fsdp_reshard_after_forward={strategy}") break if "fp32" in config: @@ -362,7 +362,7 @@ def test_checkpointing(self): for i, strategy in enumerate(FSDP_SHARDING_STRATEGY): cmd_config = cmd.copy() - cmd_config.append(f"--fsdp_sharding_strategy={strategy}") + cmd_config.append(f"--fsdp_reshard_after_forward={strategy}") if strategy != "FULL_SHARD": continue state_dict_config_index = len(cmd_config) @@ -410,7 +410,7 @@ def test_peak_memory_usage(self): cmd_config.extend(["--use_fsdp"]) for i, strategy in enumerate(FSDP_SHARDING_STRATEGY): if strategy.lower() in spec: - cmd_config.append(f"--fsdp_sharding_strategy={strategy}") + cmd_config.append(f"--fsdp_reshard_after_forward={strategy}") break if "cpu_offload" in spec: diff --git a/tests/test_cli.py b/tests/test_cli.py index 28945a23515..abd65ad4c0c 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -23,6 +23,7 @@ from accelerate.commands.config.config_args import BaseConfig, ClusterConfig, SageMakerConfig, load_config_from_file from accelerate.commands.estimate import estimate_command, estimate_command_parser, gather_data from accelerate.commands.launch import _validate_launch_command, launch_command, launch_command_parser +from accelerate.commands.to_fsdp2 import to_fsdp2_command, to_fsdp2_command_parser 
from accelerate.commands.tpu import tpu_command_launcher, tpu_command_parser from accelerate.test_utils.testing import ( capture_call_output, @@ -535,3 +536,98 @@ def test_timm_model(self): assert ( total_size == output[0][2] ), f"Calculation for total size in `fp32` is incorrect, expected {total_size} but received {output[0][2]}" + + +class ToFSDP2Tester(unittest.TestCase): + """ + Test case for verifying the `accelerate to-fsdp2` CLI outputs. + """ + + parser = to_fsdp2_command_parser() + test_config_path = Path("tests/test_configs") + + @classmethod + def setUpClass(cls): + if (cls.test_config_path / "latest_fsdp.yaml").exists(): + cls.original_config = load_config_from_file(str(cls.test_config_path / "latest_fsdp.yaml")) + + @classmethod + def tearDownClass(cls): + if cls.original_config is not None: + cls.original_config.to_yaml_file(str(cls.test_config_path / "latest_fsdp.yaml")) + + def tearDown(self): + if (self.test_config_path / "output.yaml").exists(): + (self.test_config_path / "output.yaml").unlink() + + def test_nonexistent_config_file(self): + with self.assertRaises(FileNotFoundError, msg="Config file `nonexistent.yaml` not found"): + args = self.parser.parse_args(["--config_file", "nonexistent.yaml"]) + to_fsdp2_command(args) + + def test_no_output_without_overwrite(self): + with self.assertRaises(ValueError, msg="If --overwrite is not set, --output_file must be provided"): + args = self.parser.parse_args(["--config_file", str(self.test_config_path / "latest_fsdp.yaml")]) + to_fsdp2_command(args) + + @patch("pathlib.Path.exists") + def test_overwrite_when_output_file_exists(self, mock_exists): + mock_exists.side_effect = ( + lambda: str(mock_exists._mock_self) == "output.yaml" or mock_exists._mock_self.exists() + ) + + with self.assertRaises( + FileExistsError, msg="Output file `output.yaml` already exists and --overwrite is not set" + ): + args = self.parser.parse_args( + ["--config_file", str(self.test_config_path / "latest_fsdp.yaml"), "--output_file", "output.yaml"] + ) + to_fsdp2_command(args) + + def test_fsdp2_config(self): + args = self.parser.parse_args( + [ + "--config_file", + str(self.test_config_path / "latest_fsdp.yaml"), + "--output_file", + str(self.test_config_path / "output.yaml"), + ] + ) + to_fsdp2_command(args) + + config = load_config_from_file(str(self.test_config_path / "output.yaml")) + assert isinstance(config, ClusterConfig) + assert config.fsdp_config["fsdp_version"] == 2 + + def test_config_already_fsdp2(self): + args = self.parser.parse_args( + [ + "--config_file", + str(self.test_config_path / "latest_fsdp.yaml"), + "--output_file", + str(self.test_config_path / "output.yaml"), + ] + ) + + mock_config = {"fsdp_config": {"fsdp_version": 2}} + + with patch("accelerate.commands.to_fsdp2.load_config", return_value=mock_config): + with self.assertLogs(level="WARNING") as cm: + to_fsdp2_command(args) + + assert "Config already specfies FSDP2, skipping conversion..." 
in cm.output[0] + + # Has to be the last test because it overwrites the config file + def test_fsdp2_overwrite(self): + args = self.parser.parse_args( + [ + "--config_file", + str(self.test_config_path / "latest_fsdp.yaml"), + "--overwrite", + ] + ) + to_fsdp2_command(args) + + config = load_config_from_file(str(self.test_config_path / "latest_fsdp.yaml")) + assert isinstance(config, ClusterConfig) + assert config.fsdp_config["fsdp_version"] == 2 diff --git a/tests/test_configs/latest_fsdp.yaml b/tests/test_configs/latest_fsdp.yaml new file mode 100644 index 00000000000..ccee7d00c58 --- /dev/null +++ b/tests/test_configs/latest_fsdp.yaml @@ -0,0 +1,28 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: FSDP +downcast_bf16: 'no' +enable_cpu_affinity: false +fsdp_config: + fsdp_activation_checkpointing: false + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_backward_prefetch: BACKWARD_PRE + fsdp_cpu_ram_efficient_loading: true + fsdp_forward_prefetch: false + fsdp_offload_params: false + fsdp_sharding_strategy: FULL_SHARD + fsdp_state_dict_type: SHARDED_STATE_DICT + fsdp_sync_module_states: true + fsdp_transformer_layer_cls_to_wrap: BertLayer + fsdp_use_orig_params: true +machine_rank: 0 +main_training_function: main +mixed_precision: 'no' +num_machines: 1 +num_processes: 1 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false
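The `latest_fsdp.yaml` fixture above intentionally stays in FSDP1 form so the new `to-fsdp2` tests can convert it. At launch time, the equivalent settings reach `FullyShardedDataParallelPlugin.__post_init__` through the `FSDP_VERSION` and `FSDP_RESHARD_AFTER_FORWARD` environment variables exported in `utils/launch.py`. A rough sketch of that path, assuming a PyTorch version new enough for FSDP2 and that `str_to_bool(..., to_bool=True)` returns a real bool as its new signature suggests:

```python
import os

# Mimic what `accelerate launch` exports for an FSDP2 run (see utils/launch.py above)
os.environ["FSDP_VERSION"] = "2"
os.environ["FSDP_RESHARD_AFTER_FORWARD"] = "true"

from accelerate import FullyShardedDataParallelPlugin

plugin = FullyShardedDataParallelPlugin()
print(plugin.fsdp_version)           # expected: 2
print(plugin.reshard_after_forward)  # expected: True, a bool in FSDP2 rather than a ShardingStrategy
```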