diff --git a/flagscale/auto_tuner/prune/history.py b/flagscale/auto_tuner/prune/history.py
index d053d7578..5359d7a5a 100644
--- a/flagscale/auto_tuner/prune/history.py
+++ b/flagscale/auto_tuner/prune/history.py
@@ -97,13 +97,25 @@ def prune_by_recompute(config, strategy, history=[]):
                 and item["performance"]):
             if recompute_num_layers > item["recompute_num_layers"]:
                 logger.info(
-                    f"The strategy {strategy} has been pruned by recompute_num_layers performance."
+                    f"The strategy {strategy} has been pruned by block recompute_num_layers performance."
                 )
                 strategy["performance"] = item["performance"]
                 strategy["max_mem"] = item["max_mem"]
                 strategy["pruned"] = True
                 return True
+        if (use_recompute and item["use_recompute"]
+                and recompute_method == "uniform"
+                and recompute_method == item["recompute_method"]
+                and item["performance"]):
+            if recompute_num_layers > item["recompute_num_layers"]:
+                logger.info(
+                    f"The strategy {strategy} has been pruned by uniform recompute_num_layers performance."
+                )
+                strategy["performance"] = item["performance"]
+                strategy["max_mem"] = item["max_mem"]
+                strategy["pruned"] = True
+                return True
 
         # memory prune
         if not use_recompute and item["use_recompute"] and item[
                 "max_mem"] == "OOM":
@@ -170,4 +182,4 @@ def prune_by_sequence_parallel(config, strategy, history=[]):
             strategy["performance"] = None
             strategy["pruned"] = True
             return True
-    return False
\ No newline at end of file
+    return False
diff --git a/flagscale/auto_tuner/search/searcher.py b/flagscale/auto_tuner/search/searcher.py
index 8c00ff827..4046bd3ee 100644
--- a/flagscale/auto_tuner/search/searcher.py
+++ b/flagscale/auto_tuner/search/searcher.py
@@ -404,6 +404,11 @@ def _product_recompute_dims(self, micro_batch_size_vpp_part, space,
                             )
                             if recompute_num_layers > layers_per_stage:
                                 continue
+                            if recompute_method == "uniform":
+                                if not divisible(
+                                        config.train.model.num_layers,
+                                        recompute_num_layers):
+                                    continue
                             product_dim["recompute_num_layers"] = (
                                 recompute_num_layers)
                             self._append(result, unique_result, product_dim)
diff --git a/flagscale/auto_tuner/utils.py b/flagscale/auto_tuner/utils.py
index 737a17cff..21cdda8f3 100644
--- a/flagscale/auto_tuner/utils.py
+++ b/flagscale/auto_tuner/utils.py
@@ -9,54 +9,18 @@ def beside(keys, strategy, history):
     from .search.searcher import __BUILT_IN_STRATEGY_DIMS__
     retrieval = []
-    if strategy == {
-            'data_parallel_size': 1,
-            'tensor_model_parallel_size': 1,
-            'pipeline_model_parallel_size': 8,
-            'expert_model_parallel_size': 1,
-            'context_parallel_size': 1,
-            'use_distributed_optimizer': None,
-            'sequence_parallel': None,
-            'acc_step': 4,
-            'micro_batch_size': 4,
-            'num_layers_per_virtual_pipeline_stage': None,
-            'use_recompute': True,
-            'recompute_method': 'uniform',
-            'recompute_granularity': 'full',
-            'recompute_num_layers': 1
-    }:
-
-        for task in history:
-            is_same = True
-            print(f"task {task}")
-            for dim in task:
-                print(f"dim {dim}")
-                if dim not in __BUILT_IN_STRATEGY_DIMS__:
-                    print(f"dim {dim} not in ")
-                    continue
-                if dim in keys:
-                    print(f"dim {dim} in ")
-                    continue
-                if strategy[dim] != task[dim]:
-                    print(f"dim {dim} !=")
-                    is_same = False
-                    break
-            print(f"is_same: {is_same}")
-            if is_same:
-                retrieval.append(task)
-    else:
-        for task in history:
-            is_same = True
-            for dim in task:
-                if dim not in __BUILT_IN_STRATEGY_DIMS__:
-                    continue
-                if dim in keys:
-                    continue
-                if strategy[dim] != task[dim]:
-                    is_same = False
-                    break
-            if is_same:
-                retrieval.append(task)
+    for task in history:
+        is_same = True
+        for dim in task:
+            if dim not in __BUILT_IN_STRATEGY_DIMS__:
+                continue
+            if dim in keys:
+                continue
+            if strategy[dim] != task[dim]:
+                is_same = False
+                break
+        if is_same:
+            retrieval.append(task)
+
     return retrieval
 
 
@@ -65,13 +29,20 @@ def sort_by_memory(strategy):
         -strategy["tensor_model_parallel_size"],
         -strategy["pipeline_model_parallel_size"],
         -strategy["use_recompute"],
-        strategy["micro_batch_size"]
+        strategy["micro_batch_size"],
     )
 
 
 def sort_by_performance(strategy):
     magic_number = 4
-    return (-strategy["use_recompute"],
-            (strategy["tensor_model_parallel_size"] % magic_number),
-            (strategy["micro_batch_size"] % magic_number),
-            strategy["pipeline_model_parallel_size"])
+    return (
+        -strategy["use_recompute"],
+        (strategy["tensor_model_parallel_size"] % magic_number),
+        (strategy["micro_batch_size"] % magic_number),
+        strategy["pipeline_model_parallel_size"],
+        (
+            strategy["recompute_num_layers"]
+            if strategy["recompute_num_layers"]
+            else float("inf")
+        ),
+    )
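
Note on the searcher.py change: Megatron-style uniform recomputation splits the transformer into chunks of recompute_num_layers layers each, so a candidate is only schedulable when recompute_num_layers divides num_layers exactly. Below is a minimal standalone sketch of the constraint the new check enforces; the divisible helper here is an assumption that mirrors what the name suggests (a % b == 0), not the repository's actual implementation:

def divisible(a: int, b: int) -> bool:
    # Assumed semantics of the searcher's helper: an exact-division test.
    return a % b == 0


def uniform_recompute_chunks(num_layers: int, recompute_num_layers: int):
    # Uniform recompute groups the transformer layers into equal-sized
    # chunks; an inexact division would leave a ragged final chunk, so the
    # auto-tuner prunes such candidates before launching any run.
    if not divisible(num_layers, recompute_num_layers):
        raise ValueError(
            f"recompute_num_layers={recompute_num_layers} does not divide "
            f"num_layers={num_layers}; candidate would be pruned")
    return [list(range(i, i + recompute_num_layers))
            for i in range(0, num_layers, recompute_num_layers)]


print(uniform_recompute_chunks(8, 2))  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(divisible(8, 3))                 # False -> candidate pruned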
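
Note on the utils.py change: sort_by_performance now breaks ties on recompute_num_layers, mapping a falsy value (None or 0, i.e. no recompute layer count recorded) to float("inf") so those strategies sort last. A quick check of that ordering, reusing the key function from the diff inside a hypothetical harness:

magic_number = 4


def sort_by_performance(strategy):
    # Key function exactly as added by this diff.
    return (
        -strategy["use_recompute"],
        (strategy["tensor_model_parallel_size"] % magic_number),
        (strategy["micro_batch_size"] % magic_number),
        strategy["pipeline_model_parallel_size"],
        (
            strategy["recompute_num_layers"]
            if strategy["recompute_num_layers"]
            else float("inf")
        ),
    )


# Hypothetical strategies, identical except for recompute_num_layers.
base = {"use_recompute": True, "tensor_model_parallel_size": 1,
        "micro_batch_size": 1, "pipeline_model_parallel_size": 1}
strategies = [dict(base, recompute_num_layers=n) for n in (4, None, 2)]

print([s["recompute_num_layers"]
       for s in sorted(strategies, key=sort_by_performance)])
# -> [2, 4, None]: smaller recompute counts first, unset values last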