WIP: Initial FSDP2 support #3394

Draft: wants to merge 33 commits into main from dev/fsdp2

This view shows changes from 11 of the 33 commits.

Commits (all authored by S1ro1):
2cfe2db  Feat: initial conversion tool draft (Feb 11, 2025)
d055eab  Feat: add value mapping to conversion tool (Feb 11, 2025)
23ea666  Refactor: move from os to pathlib (Feb 11, 2025)
7deaaaa  Feat: add first tests (Feb 11, 2025)
d506b95  Feat: more tests (Feb 11, 2025)
3eab226  Feat: minor fixes + dataclass conversions (Feb 11, 2025)
269c217  Feat: more remapping (Feb 12, 2025)
ba5372c  Fix: namespace has no attribute version + style (Feb 12, 2025)
b16131b  Fix: offload params behavior (Feb 12, 2025)
b0cc66b  Feat: add option to only rename keys in the config file to (Feb 12, 2025)
8bc5cb3  Fix: wrong attr name (Feb 12, 2025)
c31cc55  Merge branch 'main' into dev/fsdp2 (Feb 12, 2025)
bc70ec2  Fix: partially resolve comments (Feb 13, 2025)
00dafc4  Feat: work on config command + minor fixes to reflect changes (Feb 13, 2025)
7a92dac  Refactor: style + quality (Feb 13, 2025)
b724f9c  Feat: fsdp2 initial work (Feb 13, 2025)
7cd2587  Feat: some cleanups and first running fsdp2 (Feb 13, 2025)
d920e94  Fix: version checks + mixed precision policy (Feb 17, 2025)
3cc7c20  Refactor: style + quality (Feb 17, 2025)
432b4ff  Remove obsolete todos (Feb 17, 2025)
bb83985  Feat: grad norm clipping (Feb 17, 2025)
7759d79  Fix: tests + rename attrs (Feb 17, 2025)
f8bfa04  Refactor: style + quality (Feb 17, 2025)
47a12b9  Fix: None object is not iterable (Feb 17, 2025)
030fa5d  Fix: default cpu_offload for fsdp2 (Feb 18, 2025)
c5526aa  Fix: cpu offload now behaves correctly (Feb 26, 2025)
4984c8a  Feat: apply_activation_checkpointing (Feb 26, 2025)
914e55f  Fix: append to models (Mar 3, 2025)
1304284  Feat: start on concept guide (Mar 6, 2025)
f9061f6  Merge branch 'main' into dev/fsdp2 (Mar 6, 2025)
c596c12  wip: concept guide (Mar 6, 2025)
d55bab7  Fix: toctree (Mar 6, 2025)
8494f7b  cleanup of the concept guide (Mar 7, 2025)
2 changes: 2 additions & 0 deletions src/accelerate/commands/accelerate_cli.py
@@ -20,6 +20,7 @@
from accelerate.commands.launch import launch_command_parser
from accelerate.commands.merge import merge_command_parser
from accelerate.commands.test import test_command_parser
+from accelerate.commands.to_fsdp2 import to_fsdp2_command_parser
from accelerate.commands.tpu import tpu_command_parser
from accelerate.commands.utils import CustomArgumentParser

@@ -36,6 +37,7 @@ def main():
    merge_command_parser(subparsers=subparsers)
    tpu_command_parser(subparsers=subparsers)
    test_command_parser(subparsers=subparsers)
+    to_fsdp2_command_parser(subparsers=subparsers)

    # Let's go
    args = parser.parse_args()
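These two additions register the converter as a `to-fsdp2` subcommand of the top-level `accelerate` CLI, alongside the existing `merge`, `tpu`, and `test` subcommands; the parser itself is defined in the new `to_fsdp2.py` module below.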
180 changes: 180 additions & 0 deletions src/accelerate/commands/to_fsdp2.py
@@ -0,0 +1,180 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
import logging
from pathlib import Path

import yaml

from accelerate.commands.utils import CustomArgumentParser
from accelerate.utils import is_rich_available


class ConversionStatus(enum.Enum):
    NOT_YET_IMPLEMENTED = 0
    REMOVED = -1


ARGUMENT_KEY_MAPPING = {
    "fsdp_version": "fsdp_version",
    # https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md
    # https://huggingface.co/docs/accelerate/en/usage_guides/fsdp
    "fsdp_auto_wrap_policy": "fsdp_auto_wrap_policy",
    "fsdp_backward_prefetch": ConversionStatus.REMOVED,
    "fsdp_forward_prefetch": ConversionStatus.NOT_YET_IMPLEMENTED,
    "fsdp_cpu_ram_efficient_loading": "fsdp_cpu_ram_efficient_loading",
    "fsdp_offload_params": "fsdp_offload_params",  # TODO: This becomes obsolete in FSDP2
    "fsdp_sharding_strategy": "fsdp_reshard_after_forward",
    "fsdp_state_dict_type": "fsdp_state_dict_type",
    "fsdp_sync_module_states": ConversionStatus.REMOVED,
    "fsdp_transformer_layer_cls_to_wrap": "fsdp_transformer_layer_cls_to_wrap",
    "fsdp_min_num_params": "fsdp_min_num_params",
    "fsdp_use_orig_params": ConversionStatus.REMOVED,
    "fsdp_activation_checkpointing": "fsdp_activation_checkpointing",  # TODO: not in the docs?
}

ARGUMENT_VALUE_MAPPING = {
    "fsdp_sharding_strategy": {
        "FULL_SHARD": True,
        "SHARD_GRAD_OP": False,
        "HYBRID_SHARD": True,
        "HYBRID_SHARD_ZERO2": False,
    },
    # TODO: do we need to handle mp/offload policy
}
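# Illustrative example (not part of the mapping tables): applying the key and
# value mappings above to an FSDP1 entry
#     fsdp_sharding_strategy: FULL_SHARD
# yields the FSDP2 entry
#     fsdp_reshard_after_forward: true
# while arguments marked REMOVED (e.g. fsdp_use_orig_params) are dropped.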

if is_rich_available():
    from rich.logging import RichHandler

    FORMAT = "%(message)s"
    logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()])


logger = logging.getLogger(__name__)


def _validate_to_fsdp2_args(args):
    if not Path(args.config_file).exists():
        raise FileNotFoundError(f"Config file {args.config_file} not found")

    if not args.overwrite and args.output_file is None:
        raise ValueError("If --overwrite is not set, --output_file must be provided")

    if not args.overwrite and Path(args.output_file).exists():
        raise FileExistsError(f"Output file {args.output_file} already exists and --overwrite is not set")


def convert_config_to_fsdp2(config: dict, only_rename: bool = False) -> dict:
    fsdp_config = config.get("fsdp_config", {})

    if not fsdp_config:
        logger.info("No FSDP config found in the config file, skipping conversion...")
        return config

    new_fsdp_config = {}

    if fsdp_config.get("fsdp_version", 1) == 2:
        logger.warning("Config already specifies FSDP2, skipping conversion...")
        logger.warning(
            "If the config doesn't use new argument names, change `fsdp_version` to `1` and rerun the command."
        )
        return config

    for key, value in fsdp_config.items():
        conversion_status = ARGUMENT_KEY_MAPPING.get(key, None)
        # Short-circuit if only renaming: unmapped or removed keys keep their original name
        if only_rename:
            if isinstance(conversion_status, ConversionStatus) or conversion_status is None:
                conversion_status = key
            new_fsdp_config[conversion_status] = value
            continue

        if conversion_status == ConversionStatus.REMOVED:
            logger.warning(f"Argument {key} has been removed in FSDP2, skipping this key...")
            continue

        if conversion_status == ConversionStatus.NOT_YET_IMPLEMENTED:
            logger.warning(f"Argument {key} is not yet implemented in FSDP2, skipping this key...")
            continue

        if conversion_status is None:
            logger.warning(f"Argument {key} is not being converted, skipping this key...")
            new_fsdp_config[key] = value
        else:
            if key in ARGUMENT_VALUE_MAPPING:
                value = ARGUMENT_VALUE_MAPPING[key].get(value, value)
            new_fsdp_config[ARGUMENT_KEY_MAPPING[key]] = value

    new_fsdp_config["fsdp_version"] = 1 if only_rename else 2
    config["fsdp_config"] = new_fsdp_config
    return config
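# Note (illustrative): with only_rename=True the keys are renamed to their
# FSDP2 spellings but the values are left untouched and fsdp_version is kept
# at 1, so the resulting file still targets FSDP1 under the new key names.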


def to_fsdp2_command_parser(subparsers=None):
    description = "Convert an Accelerate config from FSDP1 to FSDP2"

    if subparsers is not None:
        parser = subparsers.add_parser("to-fsdp2", description=description)
    else:
        parser = CustomArgumentParser(description=description)

    parser.add_argument("--config_file", type=str, help="The config file to convert to FSDP2", required=True)
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Overwrite the output file if it already exists",
        default=False,
    )
    parser.add_argument(
        "--output_file",
        type=str,
        help="The path to write the converted config to. If not provided, the input file is overwritten (requires --overwrite)",
        default=None,
    )
    parser.add_argument(
        "--only-rename",
        action="store_true",
        help="If set, only rename the keys in the config file without converting their values; this is needed because of breaking changes introduced in FSDP2",
        default=False,
    )

    if subparsers is not None:
        parser.set_defaults(func=to_fsdp2_command)

    return parser
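# Example invocations once the subcommand is registered with the `accelerate`
# CLI (file paths are illustrative):
#   accelerate to-fsdp2 --config_file fsdp1.yaml --output_file fsdp2.yaml
#   accelerate to-fsdp2 --config_file config.yaml --overwrite
#   accelerate to-fsdp2 --config_file config.yaml --overwrite --only-rename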


def load_config(config_file: str) -> dict:
    with open(config_file) as f:
        config = yaml.safe_load(f)
    if not config:
        raise ValueError("Config file is empty")

    return config


def to_fsdp2_command(args):
    _validate_to_fsdp2_args(args)
    config = load_config(args.config_file)

    if args.overwrite and args.output_file is None:
        args.output_file = args.config_file

    new_config = convert_config_to_fsdp2(config, args.only_rename)

    with open(args.output_file, "w") as f:
        yaml.dump(new_config, f)
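Taken together, `to_fsdp2_command` validates the arguments, loads the YAML config, converts it, and writes the result. A minimal sketch of exercising the converter directly on an in-memory config (the dict below is an illustrative fragment, not a complete Accelerate config):

from accelerate.commands.to_fsdp2 import convert_config_to_fsdp2

# Illustrative FSDP1-style fragment; values are made up for the example.
config = {
    "fsdp_config": {
        "fsdp_sharding_strategy": "FULL_SHARD",
        "fsdp_use_orig_params": True,  # removed in FSDP2, dropped with a warning
    }
}

converted = convert_config_to_fsdp2(config)
assert converted["fsdp_config"]["fsdp_reshard_after_forward"] is True
assert converted["fsdp_config"]["fsdp_version"] == 2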
2 changes: 1 addition & 1 deletion src/accelerate/test_utils/scripts/test_merge_weights.py
@@ -50,7 +50,7 @@ def setup():
    if AcceleratorState._shared_state != {}:
        AcceleratorState()._reset_state()
    plugin = FullyShardedDataParallelPlugin(
-        sharding_strategy=ShardingStrategy.FULL_SHARD, state_dict_type=StateDictType.SHARDED_STATE_DICT
+        reshard_after_forward=ShardingStrategy.FULL_SHARD, state_dict_type=StateDictType.SHARDED_STATE_DICT
    )
    model = TinyModel()
    with patch_environment(fsdp_auto_wrap_policy="SIZE_BASED_WRAP"):