diff --git a/examples/aquila/conf/config_auto_tuner.yaml b/examples/aquila/conf/config_auto_tuner.yaml
new file mode 100644
index 000000000..a1a7cc253
--- /dev/null
+++ b/examples/aquila/conf/config_auto_tuner.yaml
@@ -0,0 +1,43 @@
+defaults:
+  - train: demo
+  - _self_
+
+experiment:
+  exp_name: aquila2
+  exp_dir: ./outputs
+  task:
+    type: train
+    backend: megatron
+    entrypoint: ./flagscale/train/train_aquila.py
+  runner:
+    backend: torchrun
+    nnodes: 1
+    nproc_per_node: 8
+  envs:
+    CUDA_VISIBLE_DEVICES: 0,1,2,3,4,5,6,7
+    CUDA_DEVICE_MAX_CONNECTIONS: 1
+  auto_tuner:
+    space:
+      data_parallel_size: "auto"
+      use_distributed_optimizer: [true, false]
+      tensor_model_parallel_size: [2, 4, 8]
+      sequence_parallel: [true]
+      pipeline_model_parallel_size: "auto"
+      num_layers_per_virtual_pipeline_stage: [1]
+      context_parallel_size: "auto"
+      expert_model_parallel_size: [1]
+      micro_batch_size: "auto"
+      use_recompute: [true]
+      recompute_method: "auto"
+      recompute_granularity: "auto"
+      recompute_num_layers: "auto"
+    control:
+      max_time_per_task: 300
+      train_iters: 5
+      max_time: 600
+
+action: run
+
+hydra:
+  run:
+    dir: ${experiment.exp_dir}/hydra
\ No newline at end of file
diff --git a/flagscale/auto_tuner/generate.py b/flagscale/auto_tuner/generate.py
index 1a6b5439e..18c0cd6b2 100644
--- a/flagscale/auto_tuner/generate.py
+++ b/flagscale/auto_tuner/generate.py
@@ -7,8 +7,8 @@ class Generator:
     def __init__(self, config):
         self.config = config
         # TODO: Just a temporary solution, need to be configurated by user
-        if "args_mapping" in config.auto_tuner:
-            self.args_mapping = config.auto_tuner.args_mapping
+        if "args_mapping" in config.experiment.auto_tuner:
+            self.args_mapping = config.experiment.auto_tuner.args_mapping
         else:
             self.args_mapping = {
                 "data_parallel_size": "data_parallel_size",
@@ -50,8 +50,8 @@ def gen(self, strategy):
         config.experiment.runner.tee = 3
         config.experiment.runner.redirects = 3
 
-        # FLAGSCALE_AUTOTUNER should be true, it will not save ckpt when train ended and report memory every iteration
-        config.experiment.envs.FLAGSCALE_AUTOTUNER = True
+        # auto_tune should be true; it skips the final ckpt save and reports memory every iteration
+        config.train.system.auto_tune = True
 
         # Del lr_warmup_samples and train_samples to run megatron.
         assert "optimizer" in config.train.model
@@ -79,8 +79,8 @@ def gen(self, strategy):
         config.train.system.checkpoint.save_interval = 2000
 
         # Set train_iters of each task
-        if "control" in config.auto_tuner:
-            config.train.model.train_iters = config.auto_tuner.control.get(
+        if "control" in config.experiment.auto_tuner:
+            config.train.model.train_iters = config.experiment.auto_tuner.control.get(
                 "train_iters", 5)
         else:
             config.train.model.train_iters = 5
diff --git a/flagscale/auto_tuner/record/recorder.py b/flagscale/auto_tuner/record/recorder.py
index f40e414a1..24d0885ff 100644
--- a/flagscale/auto_tuner/record/recorder.py
+++ b/flagscale/auto_tuner/record/recorder.py
@@ -15,15 +15,15 @@ def __init__(self, config):
             "history.csv",
         )
         # Metric to grep in the last rank of last node log file
-        if "auto_tuner" in self.config and "performance" in self.config.auto_tuner:
-            self.metric = self.config.auto_tuner.performance.get(
+        if "auto_tuner" in self.config.experiment and "performance" in self.config.experiment.auto_tuner:
+            self.metric = self.config.experiment.auto_tuner.performance.get(
                 "name", "elapsed time per iteration \(ms\):")
         else:
             self.metric = "elapsed time per iteration \(ms\):"
 
         # Sort order of performance, order just in [ascend, and descend], default ascend
-        if "auto_tuner" in self.config and "performance" in self.config.auto_tuner:
-            self.sorted_order = self.config.auto_tuner.performance.get(
+        if "auto_tuner" in self.config.experiment and "performance" in self.config.experiment.auto_tuner:
+            self.sorted_order = self.config.experiment.auto_tuner.performance.get(
                 "order", "ascend")
         else:
             self.sorted_order = "ascend"
@@ -66,8 +66,8 @@ def record(self, task, strategy):
             strategy["error"] = None
 
         # Pass back to platform if need
-        if ("airs_switch" in self.config.auto_tuner.platform
-                and self.config.auto_tuner.platform.airs_switch
+        if ("airs_switch" in self.config.experiment.auto_tuner.platform
+                and self.config.experiment.auto_tuner.platform.airs_switch
                 and strategy["performance"]):
             self.pass_back_to_platform(strategy)
 
diff --git a/flagscale/auto_tuner/search/searcher.py b/flagscale/auto_tuner/search/searcher.py
index 72f4eb11f..0a4835643 100644
--- a/flagscale/auto_tuner/search/searcher.py
+++ b/flagscale/auto_tuner/search/searcher.py
@@ -83,33 +83,33 @@ def _sort(self, key, dim, priority=None):
     def build_space(self, config):
         """Set value of each dim and sort."""
         space = {}
-        cards = config.auto_tuner.cards
-        cards_per_node = config.auto_tuner.nproc_per_node
+        cards = config.experiment.auto_tuner.cards
+        cards_per_node = config.experiment.auto_tuner.nproc_per_node
         num_layers = config.train.model.num_layers
         gbs = config.train.model.global_batch_size
-        if "space" not in config.auto_tuner:
-            config.auto_tuner.space = {}
+        if "space" not in config.experiment.auto_tuner:
+            config.experiment.auto_tuner.space = {}
 
-        if "algo" not in self.config.auto_tuner:
-            self.config.auto_tuner.algo = {"name": "grid", "priority": None}
-        priority = config.auto_tuner.algo.get("priority", None)
-        if config.auto_tuner.platform.get("airs_switch", False):
+        if "algo" not in self.config.experiment.auto_tuner:
+            self.config.experiment.auto_tuner.algo = {"name": "grid", "priority": None}
+        priority = config.experiment.auto_tuner.algo.get("priority", None)
+        if config.experiment.auto_tuner.platform.get("airs_switch", False):
             priority = "memory"
 
         # Set data parallel degree
         space["data_parallel_size"] = (
             [i for i in range(1, cards + 1)]
-            if "data_parallel_size" not in config.auto_tuner.space
-            or config.auto_tuner.space.data_parallel_size == "auto"
-            else config.auto_tuner.space.data_parallel_size
+            if "data_parallel_size" not in config.experiment.auto_tuner.space
+            or config.experiment.auto_tuner.space.data_parallel_size == "auto"
+            else config.experiment.auto_tuner.space.data_parallel_size
         )
         self._sort("data_parallel_size", space["data_parallel_size"], priority)
 
         # Set distributed optimizer
         space["use_distributed_optimizer"] = (
             [True, False]
-            if "use_distributed_optimizer" not in config.auto_tuner.space
-            or config.auto_tuner.space.use_distributed_optimizer == "auto"
-            else config.auto_tuner.space.use_distributed_optimizer
+            if "use_distributed_optimizer" not in config.experiment.auto_tuner.space
+            or config.experiment.auto_tuner.space.use_distributed_optimizer == "auto"
+            else config.experiment.auto_tuner.space.use_distributed_optimizer
         )
         self._sort(
             "use_distributed_optimizer", space["use_distributed_optimizer"], priority
@@ -118,9 +118,9 @@ def build_space(self, config):
         # Set tensor parallel degree
         space["tensor_model_parallel_size"] = (
             [i for i in range(1, cards_per_node + 1)]
-            if "tensor_model_parallel_size" not in config.auto_tuner.space
-            or config.auto_tuner.space.tensor_model_parallel_size == "auto"
-            else config.auto_tuner.space.tensor_model_parallel_size
+            if "tensor_model_parallel_size" not in config.experiment.auto_tuner.space
+            or config.experiment.auto_tuner.space.tensor_model_parallel_size == "auto"
+            else config.experiment.auto_tuner.space.tensor_model_parallel_size
         )
         self._sort(
             "tensor_model_parallel_size", space["tensor_model_parallel_size"], priority
@@ -129,18 +129,18 @@ def build_space(self, config):
         # Set sequence parallel
         space["sequence_parallel"] = (
             [True, False]
-            if "sequence_parallel" not in config.auto_tuner.space
-            or config.auto_tuner.space.sequence_parallel == "auto"
-            else config.auto_tuner.space.sequence_parallel
+            if "sequence_parallel" not in config.experiment.auto_tuner.space
+            or config.experiment.auto_tuner.space.sequence_parallel == "auto"
+            else config.experiment.auto_tuner.space.sequence_parallel
         )
         self._sort("sequence_parallel", space["sequence_parallel"], priority)
 
         # Set pipeline parallel degree
        space["pipeline_model_parallel_size"] = (
             [i for i in range(1, cards + 1)]
-            if "pipeline_model_parallel_size" not in config.auto_tuner.space
-            or config.auto_tuner.space.pipeline_model_parallel_size == "auto"
-            else config.auto_tuner.space.pipeline_model_parallel_size
+            if "pipeline_model_parallel_size" not in config.experiment.auto_tuner.space
+            or config.experiment.auto_tuner.space.pipeline_model_parallel_size == "auto"
+            else config.experiment.auto_tuner.space.pipeline_model_parallel_size
         )
         self._sort(
             "pipeline_model_parallel_size",
@@ -151,9 +151,9 @@ def build_space(self, config):
         # Set virtual pipeline parallel degree
         space["num_layers_per_virtual_pipeline_stage"] = (
             [i for i in range(1, num_layers + 1)]
-            if "num_layers_per_virtual_pipeline_stage" not in config.auto_tuner.space
-            or config.auto_tuner.space.num_layers_per_virtual_pipeline_stage == "auto"
-            else config.auto_tuner.space.num_layers_per_virtual_pipeline_stage
+            if "num_layers_per_virtual_pipeline_stage" not in config.experiment.auto_tuner.space
+            or config.experiment.auto_tuner.space.num_layers_per_virtual_pipeline_stage == "auto"
+            else config.experiment.auto_tuner.space.num_layers_per_virtual_pipeline_stage
         )
         self._sort(
             "num_layers_per_virtual_pipeline_stage",
@@ -164,53 +164,53 @@ def build_space(self, config):
         # Set use recompute
         space["use_recompute"] = (
             [True, False]
-            if "use_recompute" not in config.auto_tuner.space
-            or config.auto_tuner.space.use_recompute == "auto"
-            else config.auto_tuner.space.use_recompute
+            if "use_recompute" not in config.experiment.auto_tuner.space
+            or config.experiment.auto_tuner.space.use_recompute == "auto"
+            else config.experiment.auto_tuner.space.use_recompute
         )
         self._sort("use_recompute", space["use_recompute"], priority)
 
         # Set recompute method
         space["recompute_method"] = (
             ["uniform", "block"]
-            if "recompute_method" not in config.auto_tuner.space
-            or config.auto_tuner.space.recompute_method == "auto"
-            else config.auto_tuner.space.recompute_method
+            if "recompute_method" not in config.experiment.auto_tuner.space
+            or config.experiment.auto_tuner.space.recompute_method == "auto"
+            else config.experiment.auto_tuner.space.recompute_method
         )
         self._sort("recompute_method", space["recompute_method"], priority)
 
         # Set recompute granularity
         space["recompute_granularity"] = (
             ["full", "selective"]
-            if "recompute_granularity" not in config.auto_tuner.space
-            or config.auto_tuner.space.recompute_granularity == "auto"
-            else config.auto_tuner.space.recompute_granularity
+            if "recompute_granularity" not in config.experiment.auto_tuner.space
+            or config.experiment.auto_tuner.space.recompute_granularity == "auto"
+            else config.experiment.auto_tuner.space.recompute_granularity
        )
         self._sort("recompute_granularity", space["recompute_granularity"], priority)
 
         # Set recompute num layers
         space["recompute_num_layers"] = (
             [i for i in range(1, num_layers + 1)]
-            if "recompute_num_layers" not in config.auto_tuner.space
-            or config.auto_tuner.space.recompute_num_layers == "auto"
-            else config.auto_tuner.space.recompute_num_layers
+            if "recompute_num_layers" not in config.experiment.auto_tuner.space
+            or config.experiment.auto_tuner.space.recompute_num_layers == "auto"
+            else config.experiment.auto_tuner.space.recompute_num_layers
         )
         self._sort("recompute_num_layers", space["recompute_num_layers"], priority)
 
         # Set micro batch size
         space["micro_batch_size"] = (
             [i for i in range(1, gbs + 1)]
-            if "micro_batch_size" not in config.auto_tuner.space
-            or config.auto_tuner.space.micro_batch_size == "auto"
-            else config.auto_tuner.space.micro_batch_size
+            if "micro_batch_size" not in config.experiment.auto_tuner.space
+            or config.experiment.auto_tuner.space.micro_batch_size == "auto"
+            else config.experiment.auto_tuner.space.micro_batch_size
         )
         self._sort("micro_batch_size", space["micro_batch_size"], priority)
 
         # Set context parallel degree
         space["context_parallel_size"] = (
             [i for i in range(1, cards + 1)]
-            if "context_parallel_size" not in config.auto_tuner.space
-            or config.auto_tuner.space.context_parallel_size == "auto"
-            else config.auto_tuner.space.context_parallel_size
+            if "context_parallel_size" not in config.experiment.auto_tuner.space
+            or config.experiment.auto_tuner.space.context_parallel_size == "auto"
+            else config.experiment.auto_tuner.space.context_parallel_size
         )
         self._sort("context_parallel_size", space["context_parallel_size"], priority)
@@ -218,9 +218,9 @@ def build_space(self, config):
         # NOTE: Expert parallel degree is not supported now
         space["expert_model_parallel_size"] = (
             [1]
-            if "expert_model_parallel_size" not in config.auto_tuner.space
-            or config.auto_tuner.space.expert_model_parallel_size == "auto"
-            else config.auto_tuner.space.expert_model_parallel_size
+            if "expert_model_parallel_size" not in config.experiment.auto_tuner.space
+            or config.experiment.auto_tuner.space.expert_model_parallel_size == "auto"
+            else config.experiment.auto_tuner.space.expert_model_parallel_size
         )
         self._sort(
             "expert_model_parallel_size", space["expert_model_parallel_size"], priority
@@ -240,7 +240,7 @@ def build_strategies(self, space, config):
         return recompute_part
 
     def build_algo(self, strategies, config):
-        name = self.config.auto_tuner.algo.name
+        name = self.config.experiment.auto_tuner.algo.name
         if name == "grid":
             from .algorithm import GridAlgo
 
@@ -251,7 +251,7 @@ def _product_parallel_dims(self, space, config):
         # Avoid space explosion after product
         product_parallelism_dims = []
-        cards = config.auto_tuner.cards
+        cards = config.experiment.auto_tuner.cards
         for data_parallel_size in space["data_parallel_size"]:
             dims = {}
             if not divisible(cards, data_parallel_size):
diff --git a/flagscale/auto_tuner/tuner.py b/flagscale/auto_tuner/tuner.py
index 898f128bb..da031a41a 100644
--- a/flagscale/auto_tuner/tuner.py
+++ b/flagscale/auto_tuner/tuner.py
@@ -44,33 +44,35 @@ def __init__(self, config: DictConfig):
         self.config = copy.deepcopy(config)
 
         # Set config of auto tuner
-        if "auto_tuner" not in self.config:
-            self.config.auto_tuner = {}
+        if "auto_tuner" not in self.config.experiment:
+            self.config.experiment.auto_tuner = {}
 
         # Add nodes, nproc_per_node, cards to build search space or prune
         assert "experiment" in config, "experiment is not in yaml file."
         assert "runner" in config.experiment, "runner is not in yaml file."
         nnodes = config.experiment.runner.get("nnodes", 1)
         nproc_per_node = config.experiment.runner.get("nproc_per_node", 8)
-        self.config.auto_tuner.nproc_per_node = nproc_per_node
-        self.config.auto_tuner.nnodes = nnodes
+        self.config.experiment.auto_tuner.nproc_per_node = nproc_per_node
+        self.config.experiment.auto_tuner.nnodes = nnodes
 
         # Set tuner configs
         # The interval of task monitoring
-        if "control" not in self.config.auto_tuner:
-            self.config.auto_tuner.control = {}
-        self.interval = self.config.auto_tuner.control.get("interval", 10)
+        if "control" not in self.config.experiment.auto_tuner:
+            self.config.experiment.auto_tuner.control = {}
+        self.interval = self.config.experiment.auto_tuner.control.get(
+            "interval", 10)
 
         # Set platform envs
-        if "platform" not in self.config.auto_tuner:
-            self.config.auto_tuner.platform = {}
+        if "platform" not in self.config.experiment.auto_tuner:
+            self.config.experiment.auto_tuner.platform = {}
 
         # As long as AIRS_SWITCH has value it means running on the platform
         if os.environ.get("AIRS_SWITCH", None):
-            self.config.auto_tuner.platform.airs_switch = True
+            self.config.experiment.auto_tuner.platform.airs_switch = True
 
         if os.environ.get("AIRS_SIZE", None):
-            self.config.auto_tuner.nnodes = int(os.environ["AIRS_SIZE"])
+            self.config.experiment.auto_tuner.nnodes = int(
+                os.environ["AIRS_SIZE"])
             # Set original config
             self.orig_config.experiment.runner.nnodes = int(
                 os.environ["AIRS_SIZE"])
@@ -79,7 +81,7 @@ def __init__(self, config: DictConfig):
                 os.environ["AIRS_SIZE"])
 
         if os.environ.get("AIRS_ACCELERATOR_COUNT", None):
-            self.config.auto_tuner.nproc_per_node = int(
+            self.config.experiment.auto_tuner.nproc_per_node = int(
                 os.environ["AIRS_ACCELERATOR_COUNT"])
             # Set original config
             self.orig_config.experiment.runner.nproc_per_node = int(
@@ -89,7 +91,8 @@ def __init__(self, config: DictConfig):
                 os.environ["AIRS_ACCELERATOR_COUNT"])
 
         if os.environ.get("AIRS_FBMEM", None):
-            self.config.auto_tuner.memory = int(os.environ["AIRS_FBMEM"])
+            self.config.experiment.auto_tuner.memory = int(
+                os.environ["AIRS_FBMEM"])
 
         if os.environ.get("AIRS_HOSTFILE_PATH", None):
             # Set original config
@@ -99,7 +102,7 @@ def __init__(self, config: DictConfig):
             self.config.experiment.runner.hostfile = os.environ[
                 "AIRS_HOSTFILE_PATH"]
 
-        self.config.auto_tuner.cards = self.config.auto_tuner.nnodes * self.config.auto_tuner.nproc_per_node
+        self.config.experiment.auto_tuner.cards = self.config.experiment.auto_tuner.nnodes * self.config.experiment.auto_tuner.nproc_per_node
 
         # Build core sub modules, such as Searcher, Pruner, Generator and Recorder
         self.searcher = Searcher(self.config)
@@ -112,11 +115,12 @@ def __init__(self, config: DictConfig):
 
         # The max time per task, unit: second
         # NOTE: The task will be stopped if the time is reached or done.
-        self.max_time_per_task = self.config.auto_tuner.control.get(
+        self.max_time_per_task = self.config.experiment.auto_tuner.control.get(
             "max_time_per_task", 300)
 
         # The max time of auto tuner, if None, no limit.
-        self.max_time = self.config.auto_tuner.control.get("max_time", None)
+        self.max_time = self.config.experiment.auto_tuner.control.get(
+            "max_time", None)
 
         # The start time of each task, used to control each task when stop
         self.start_task_time = None
@@ -155,7 +159,7 @@ def tune(self):
                 self.record()
 
                 if (self.cur_strategy["performance"]
-                        and self.config.auto_tuner.platform.get(
+                        and self.config.experiment.auto_tuner.platform.get(
                             "airs_switch", False)
                         and not self.has_checkout):
                     self.checkout()
@@ -172,7 +176,7 @@ def tune(self):
             f"AutoTuner Ended in {tuner_end_time - tuner_start_time} seconds.")
 
         # Run the best task
-        if self.config.auto_tuner.control.get("run_best", True):
+        if self.config.experiment.auto_tuner.control.get("run_best", True):
             best_strategy = self.get_best()
             if best_strategy:
                 self.logger.info(f"Run best Strategy: {best_strategy}")
@@ -180,6 +184,7 @@ def tune(self):
                 raise ValueError(f"No strategy can run.")
             best_task = self.generator.gen_best_task(best_strategy,
                                                      self.orig_config)
+            best_task.action = "run"
             runner = SSHRunner(best_task)
             runner.run()
 
diff --git a/flagscale/train/train.py b/flagscale/train/train.py
index d0ed965b5..9cc688479 100644
--- a/flagscale/train/train.py
+++ b/flagscale/train/train.py
@@ -295,7 +295,7 @@ def pretrain(train_valid_test_dataset_provider,
 
     print_datetime('after training is done')
 
-    if not os.environ.get("FLAGSCALE_AUTOTUNER", False):
+    if not args.auto_tune:
         if args.save and iteration != 0 and iteration % args.save_interval != 0:
             save_checkpoint(
                 iteration,
@@ -951,7 +951,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
         total_loss_dict[skipped_iters_key] = 0
         total_loss_dict[nan_iters_key] = 0
         print_rank_last(log_string)
-        if not os.environ.get("FLAGSCALE_AUTOTUNER", False):
+        if not args.auto_tune:
             if report_memory_flag and learning_rate > 0.0:
                 # Report memory after optimizer state has been initialized.
                 if torch.distributed.get_rank() == 0:
diff --git a/megatron/megatron/training/arguments.py b/megatron/megatron/training/arguments.py
index d38869337..4435cd34d 100644
--- a/megatron/megatron/training/arguments.py
+++ b/megatron/megatron/training/arguments.py
@@ -47,6 +47,7 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
     parser = _add_serving_args(parser)
     parser = _add_customized_device_args(parser)
     parser = _add_hetero_args(parser)
+    parser = _add_auto_tuner_args(parser)
 
     # Custom arguments.
     if extra_args_provider is not None:
@@ -1935,3 +1936,12 @@ def _add_hetero_args(parser):
                        'The order should be consistent with --hetero-device-types.')
 
     return parser
+
+
+def _add_auto_tuner_args(parser):
+    group = parser.add_argument_group(title="auto tuner")
+
+    group.add_argument('--auto-tune', action='store_true',
+                       help='use auto tuner')
+
+    return parser
diff --git a/megatron/megatron/training/training.py b/megatron/megatron/training/training.py
index 3f5fb9d34..771ca15a7 100644
--- a/megatron/megatron/training/training.py
+++ b/megatron/megatron/training/training.py
@@ -301,7 +301,7 @@ def pretrain(train_valid_test_dataset_provider,
                    extra_valid_dataset_provider)
     print_datetime('after training is done')
 
-    if not os.environ.get("FLAGSCALE_AUTOTUNER", False):
+    if not args.auto_tune:
         if args.save and iteration != 0 and iteration % args.save_interval != 0:
             save_checkpoint(iteration, model, optimizer, opt_param_scheduler,
                             num_floating_point_operations_so_far, checkpointing_context)
@@ -952,7 +952,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
         total_loss_dict[skipped_iters_key] = 0
         total_loss_dict[nan_iters_key] = 0
         print_rank_last(log_string)
-        if not os.environ.get("FLAGSCALE_AUTOTUNER", False):
+        if not args.auto_tune:
             if report_memory_flag and learning_rate > 0.:
                 # Report memory after optimizer state has been initialized.
                 if torch.distributed.get_rank() == 0:
diff --git a/megatron/megatron/training/utils.py b/megatron/megatron/training/utils.py
index 2c481ee54..46bbbbcac 100644
--- a/megatron/megatron/training/utils.py
+++ b/megatron/megatron/training/utils.py
@@ -119,7 +119,8 @@ def report_memory(name):
         torch.cuda.memory_reserved() / mega_bytes)
     string += ' | max reserved: {}'.format(
         torch.cuda.max_memory_reserved() / mega_bytes)
-    if not os.environ.get("FLAGSCALE_AUTOTUNER", False):
+    args = get_args()
+    if not args.auto_tune:
         # Each rank prints the memory report.
         if mpu.get_data_parallel_rank() == 0:
             print("[Rank {}] {}".format(torch.distributed.get_rank(), string),
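
Note on the flag plumbing above: the auto tuner previously signaled "tuning in progress" via the FLAGSCALE_AUTOTUNER environment variable; after this diff the Generator sets config.train.system.auto_tune, which reaches Megatron as the new --auto-tune argument and gates the final checkpoint save and the per-iteration memory report. A minimal, self-contained sketch of that gating, reusing the exact argparse wiring from _add_auto_tuner_args (the prints stand in for the real save_checkpoint/report_memory calls and are illustrative only):

import argparse

# Same group and flag as _add_auto_tuner_args in the diff.
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title="auto tuner")
group.add_argument('--auto-tune', action='store_true',
                   help='use auto tuner')

args = parser.parse_args(['--auto-tune'])  # simulate a tuning run

# Mirrors the guards in pretrain()/training_log()/report_memory().
if not args.auto_tune:
    print("normal run: save the final checkpoint")
else:
    print("tuning run: skip the final checkpoint, report memory every iteration")

Launching the tuner with the new YAML would presumably follow FlagScale's usual Hydra entrypoint, e.g. `python run.py --config-path ./examples/aquila/conf --config-name config_auto_tuner`; the exact CLI is an assumption, not part of this diff.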