Skip to content

Commit

Permalink
Update recompute_num_layers handling: prune uniform-recompute strategies by performance, require num_layers divisible by recompute_num_layers for uniform recompute, and use recompute_num_layers as a sort tie-breaker
Browse files Browse the repository at this point in the history
  • Loading branch information
caozhou committed May 30, 2024
1 parent 1255ade commit b318cad
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 3 deletions.
16 changes: 14 additions & 2 deletions flagscale/auto_tuner/prune/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,25 @@ def prune_by_recompute(config, strategy, history=[]):
and item["performance"]):
if recompute_num_layers > item["recompute_num_layers"]:
logger.info(
f"The strategy {strategy} has been pruned by recompute_num_layers performance."
f"The strategy {strategy} has been pruned by block recompute_num_layers performance."
)
strategy["performance"] = item["performance"]
strategy["max_mem"] = item["max_mem"]
strategy["pruned"] = True
return True

if (use_recompute and item["use_recompute"]
and recompute_method == "uniform"
and recompute_method == item["recompute_method"]
and item["performance"]):
if recompute_num_layers > item["recompute_num_layers"]:
logger.info(
f"The strategy {strategy} has been pruned by uniform recompute_num_layers performance."
)
strategy["performance"] = item["performance"]
strategy["max_mem"] = item["max_mem"]
strategy["pruned"] = True
return True
# memory prune
if not use_recompute and item["use_recompute"] and item[
"max_mem"] == "OOM":
Expand Down Expand Up @@ -170,4 +182,4 @@ def prune_by_sequence_parallel(config, strategy, history=[]):
strategy["performance"] = None
strategy["pruned"] = True
return True
return False
return False
5 changes: 5 additions & 0 deletions flagscale/auto_tuner/search/searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,11 @@ def _product_recompute_dims(self, micro_batch_size_vpp_part, space,
)
if recompute_num_layers > layers_per_stage:
continue
if recompute_method == "uniform":
if not divisible(
config.train.model.num_layers,
recompute_num_layers):
continue
product_dim["recompute_num_layers"] = (
recompute_num_layers)
self._append(result, unique_result, product_dim)
Expand Down
3 changes: 2 additions & 1 deletion flagscale/auto_tuner/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,5 @@ def sort_by_performance(strategy):
return (-strategy["use_recompute"],
(strategy["tensor_model_parallel_size"] % magic_number),
(strategy["micro_batch_size"] % magic_number),
strategy["pipeline_model_parallel_size"])
strategy["pipeline_model_parallel_size"],
strategy["recompute_num_layers"])

0 comments on commit b318cad

Please sign in to comment.