
Commit

update recompute num
caozhou committed May 30, 2024
1 parent 1255ade commit 47d89ec
Showing 3 changed files with 39 additions and 55 deletions.
16 changes: 14 additions & 2 deletions flagscale/auto_tuner/prune/history.py
@@ -97,13 +97,25 @@ def prune_by_recompute(config, strategy, history=[]):
                 and item["performance"]):
             if recompute_num_layers > item["recompute_num_layers"]:
                 logger.info(
-                    f"The strategy {strategy} has been pruned by recompute_num_layers performance."
+                    f"The strategy {strategy} has been pruned by block recompute_num_layers performance."
                 )
                 strategy["performance"] = item["performance"]
                 strategy["max_mem"] = item["max_mem"]
                 strategy["pruned"] = True
                 return True
 
+        if (use_recompute and item["use_recompute"]
+                and recompute_method == "uniform"
+                and recompute_method == item["recompute_method"]
+                and item["performance"]):
+            if recompute_num_layers > item["recompute_num_layers"]:
+                logger.info(
+                    f"The strategy {strategy} has been pruned by uniform recompute_num_layers performance."
+                )
+                strategy["performance"] = item["performance"]
+                strategy["max_mem"] = item["max_mem"]
+                strategy["pruned"] = True
+                return True
         # memory prune
         if not use_recompute and item["use_recompute"] and item[
                 "max_mem"] == "OOM":
@@ -170,4 +182,4 @@ def prune_by_sequence_parallel(config, strategy, history=[]):
strategy["performance"] = None
strategy["pruned"] = True
return True
return False
return False
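
The new branch mirrors the existing block-recompute prune for the uniform method: if a previously measured strategy with the same (uniform) recompute method already has a performance result and recomputes fewer layers, a candidate that recomputes more layers is assumed to be no faster, so it inherits the measured performance and memory and is skipped. A minimal, self-contained sketch of that rule follows; the names should_prune_by_uniform_recompute, candidate, and measured are illustrative, not part of the repository.

def should_prune_by_uniform_recompute(candidate, measured):
    # Prune `candidate` if a measured strategy with the same uniform recompute
    # method already has a performance result and recomputes fewer layers:
    # recomputing more layers is assumed not to be faster.
    return (candidate["use_recompute"] and measured["use_recompute"]
            and candidate["recompute_method"] == "uniform"
            and candidate["recompute_method"] == measured["recompute_method"]
            and measured["performance"] is not None
            and candidate["recompute_num_layers"] > measured["recompute_num_layers"])

# Example: a 4-layer uniform candidate is pruned against a measured 2-layer run.
measured = {"use_recompute": True, "recompute_method": "uniform",
            "recompute_num_layers": 2, "performance": 123.4}
candidate = {"use_recompute": True, "recompute_method": "uniform",
             "recompute_num_layers": 4, "performance": None}
assert should_prune_by_uniform_recompute(candidate, measured)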
5 changes: 5 additions & 0 deletions flagscale/auto_tuner/search/searcher.py
@@ -404,6 +404,11 @@ def _product_recompute_dims(self, micro_batch_size_vpp_part, space,
                         )
                         if recompute_num_layers > layers_per_stage:
                             continue
+                        if recompute_method == "uniform":
+                            if not divisible(
+                                    config.train.model.num_layers,
+                                    recompute_num_layers):
+                                continue
                         product_dim["recompute_num_layers"] = (
                             recompute_num_layers)
                         self._append(result, unique_result, product_dim)
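
The added guard restricts uniform recomputation to values of recompute_num_layers that evenly divide the model's layer count, since uniform recompute splits the transformer layers into equal-sized chunks. A rough sketch of the effect, using a stand-in divisible helper and an assumed 24-layer model (the project's own helper and loop structure may differ):

def divisible(total, part):
    # Stand-in helper: True when `part` evenly divides `total`.
    return part != 0 and total % part == 0

num_layers = 24  # assumed model depth, for illustration only
candidates = range(1, num_layers + 1)
valid_uniform = [n for n in candidates if divisible(num_layers, n)]
print(valid_uniform)  # [1, 2, 3, 4, 6, 8, 12, 24]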
73 changes: 20 additions & 53 deletions flagscale/auto_tuner/utils.py
@@ -9,54 +9,18 @@ def beside(keys, strategy, history):
     from .search.searcher import __BUILT_IN_STRATEGY_DIMS__
 
     retrieval = []
-    if strategy == {
-            'data_parallel_size': 1,
-            'tensor_model_parallel_size': 1,
-            'pipeline_model_parallel_size': 8,
-            'expert_model_parallel_size': 1,
-            'context_parallel_size': 1,
-            'use_distributed_optimizer': None,
-            'sequence_parallel': None,
-            'acc_step': 4,
-            'micro_batch_size': 4,
-            'num_layers_per_virtual_pipeline_stage': None,
-            'use_recompute': True,
-            'recompute_method': 'uniform',
-            'recompute_granularity': 'full',
-            'recompute_num_layers': 1
-    }:
-
-        for task in history:
-            is_same = True
-            print(f"task {task}")
-            for dim in task:
-                print(f"dim {dim}")
-                if dim not in __BUILT_IN_STRATEGY_DIMS__:
-                    print(f"dim {dim} not in ")
-                    continue
-                if dim in keys:
-                    print(f"dim {dim} in ")
-                    continue
-                if strategy[dim] != task[dim]:
-                    print(f"dim {dim} !=")
-                    is_same = False
-                    break
-            print(f"is_same: {is_same}")
-            if is_same:
-                retrieval.append(task)
-    else:
-        for task in history:
-            is_same = True
-            for dim in task:
-                if dim not in __BUILT_IN_STRATEGY_DIMS__:
-                    continue
-                if dim in keys:
-                    continue
-                if strategy[dim] != task[dim]:
-                    is_same = False
-                    break
-            if is_same:
-                retrieval.append(task)
+    for task in history:
+        is_same = True
+        for dim in task:
+            if dim not in __BUILT_IN_STRATEGY_DIMS__:
+                continue
+            if dim in keys:
+                continue
+            if strategy[dim] != task[dim]:
+                is_same = False
+                break
+        if is_same:
+            retrieval.append(task)
     return retrieval
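
With the hard-coded debug branch and print statements removed, beside() is a single loop: it returns the history tasks that agree with the candidate strategy on every built-in dimension except those listed in keys. A stand-alone sketch of that matching rule (beside_sketch, BUILT_IN_DIMS, and the sample values are illustrative; the real function imports __BUILT_IN_STRATEGY_DIMS__ from the searcher module):

BUILT_IN_DIMS = {"tensor_model_parallel_size", "micro_batch_size", "use_recompute"}

def beside_sketch(keys, strategy, history):
    # Keep a history task if it matches `strategy` on every built-in
    # dimension that is not explicitly ignored via `keys`.
    retrieval = []
    for task in history:
        if all(strategy[dim] == task[dim]
               for dim in task
               if dim in BUILT_IN_DIMS and dim not in keys):
            retrieval.append(task)
    return retrieval

strategy = {"tensor_model_parallel_size": 2, "micro_batch_size": 4, "use_recompute": True}
history = [
    {"tensor_model_parallel_size": 2, "micro_batch_size": 2, "use_recompute": True},
    {"tensor_model_parallel_size": 4, "micro_batch_size": 4, "use_recompute": True},
]
# Only the first task matches when micro_batch_size is ignored.
print(beside_sketch(["micro_batch_size"], strategy, history))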


@@ -65,13 +29,16 @@ def sort_by_memory(strategy):
-strategy["tensor_model_parallel_size"],
-strategy["pipeline_model_parallel_size"],
-strategy["use_recompute"],
strategy["micro_batch_size"]
strategy["micro_batch_size"],
)


def sort_by_performance(strategy):
magic_number = 4
return (-strategy["use_recompute"],
(strategy["tensor_model_parallel_size"] % magic_number),
(strategy["micro_batch_size"] % magic_number),
strategy["pipeline_model_parallel_size"])
return (
-strategy["use_recompute"],
(strategy["tensor_model_parallel_size"] % magic_number),
(strategy["micro_batch_size"] % magic_number),
strategy["pipeline_model_parallel_size"],
strategy["recompute_num_layers"] if strategy["recompute_num_layers"] else float('inf'),
)
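
The extra tuple element makes recompute_num_layers a final tie-breaker in sort_by_performance: among otherwise equal strategies, fewer recomputed layers sorts first, and strategies without a recompute_num_layers value fall to the end of that tie-break via float('inf'). A quick illustration with made-up strategies (only the fields read by the key are filled in; the key function below just restates the new tuple):

def sort_key(s):
    magic_number = 4
    return (
        -s["use_recompute"],
        s["tensor_model_parallel_size"] % magic_number,
        s["micro_batch_size"] % magic_number,
        s["pipeline_model_parallel_size"],
        s["recompute_num_layers"] if s["recompute_num_layers"] else float("inf"),
    )

strategies = [
    {"use_recompute": True, "tensor_model_parallel_size": 1, "micro_batch_size": 1,
     "pipeline_model_parallel_size": 1, "recompute_num_layers": 4},
    {"use_recompute": True, "tensor_model_parallel_size": 1, "micro_batch_size": 1,
     "pipeline_model_parallel_size": 1, "recompute_num_layers": 2},
]
# The 2-layer variant now sorts ahead of the otherwise identical 4-layer one.
print(sorted(strategies, key=sort_key))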
