Skip to content

Commit

Permalink
fix: ensure CLI args take precedence over config file. (#3409)
Browse files Browse the repository at this point in the history
* fix: ensure CLI args take precedence over config file.

* add test case

* remove inappropriate comment

---------

Co-authored-by: 차영록 <[email protected]>
  • Loading branch information
cyr0930 and 차영록 authored Feb 28, 2025
1 parent 90f8198 commit c7b3625
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 27 deletions.
37 changes: 11 additions & 26 deletions src/accelerate/commands/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,35 +1043,20 @@ def _validate_launch_command(args):
# Update args with the defaults
for name, attr in defaults.__dict__.items():
if isinstance(attr, dict):
for k in defaults.deepspeed_config:
setattr(args, k, defaults.deepspeed_config[k])
for k in defaults.fsdp_config:
arg_to_set = k
if "fsdp" not in arg_to_set:
arg_to_set = "fsdp_" + arg_to_set
setattr(args, arg_to_set, defaults.fsdp_config[k])
for k in defaults.tp_config:
setattr(args, k, defaults.tp_config[k])
for k in defaults.megatron_lm_config:
setattr(args, k, defaults.megatron_lm_config[k])
for k in defaults.dynamo_config:
setattr(args, k, defaults.dynamo_config[k])
for k in defaults.ipex_config:
setattr(args, k, defaults.ipex_config[k])
for k in defaults.mpirun_config:
setattr(args, k, defaults.mpirun_config[k])
for k in defaults.fp8_config:
arg_to_set = k
if "fp8" not in arg_to_set:
arg_to_set = "fp8_" + arg_to_set
setattr(args, arg_to_set, defaults.fp8_config[k])
continue

# Those args are handled separately
if (
# Copy defaults.somedict.somearg to args.somearg and
# defaults.fsdp_config.x to args.fsdp_x
for key, value in attr.items():
if name == "fsdp_config" and not key.startswith("fsdp"):
key = "fsdp_" + key
elif name == "fp8_config" and not key.startswith("fp8"):
key = "fp8_" + key
if hasattr(args, "nondefault") and key not in args.nondefault:
setattr(args, key, value)
elif (
name not in ["compute_environment", "mixed_precision", "distributed_type"]
and getattr(args, name, None) is None
):
# Those args are handled separately
setattr(args, name, attr)
if not args.debug:
args.debug = defaults.debug
Expand Down
5 changes: 4 additions & 1 deletion src/accelerate/commands/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ def __init__(self, *args, **kwargs):

def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, values)
if not hasattr(namespace, "nondefault"):
namespace.nondefault = set()
namespace.nondefault.add(self.dest)


class _StoreConstAction(_StoreAction):
Expand All @@ -51,7 +54,7 @@ def __init__(self, option_strings, dest, const, default=None, required=False, he
)

def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, self.const)
super().__call__(parser, namespace, self.const, option_string)


class _StoreTrueAction(_StoreConstAction):
Expand Down
22 changes: 22 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,28 @@ def test_mpi_multicpu_config_cmd(self):
self.assertEqual(python_script_cmd[1], str(self.test_file_path))
self.assertEqual(python_script_cmd[2], test_file_arg)

def test_validate_launch_command(self):
"""Test that the validation function combines args and defaults."""
parser = launch_command_parser()
args = parser.parse_args(
[
"--num-processes",
"2",
"--deepspeed_config_file",
"path/to/be/accepted",
"--config-file",
str(self.test_config_path / "validate_launch_cmd.yaml"),
"test.py",
]
)
self.assertFalse(args.debug)
self.assertTrue(args.fsdp_sync_module_states)
_validate_launch_command(args)
self.assertTrue(args.debug)
self.assertEqual(2, args.num_processes)
self.assertFalse(args.fsdp_sync_module_states)
self.assertEqual("path/to/be/accepted", args.deepspeed_config_file)


class LaunchArgTester(unittest.TestCase):
"""
Expand Down
8 changes: 8 additions & 0 deletions tests/test_configs/validate_launch_cmd.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
compute_environment: LOCAL_MACHINE
debug: true
num_processes: 1
distributed_type: 'NO'
fsdp_config:
fsdp_sync_module_states: false
deepspeed_config:
deepspeed_config_file: path/to/be/ignored

0 comments on commit c7b3625

Please sign in to comment.