Commit

fixed mem issue
zhenghh04 committed Feb 6, 2025
1 parent 139eed7 · commit 02b31e9
Showing 2 changed files with 13 additions and 19 deletions.
28 changes: 11 additions & 17 deletions dlio_benchmark/checkpointing/base_checkpointing.py
@@ -70,7 +70,7 @@ def __init__(self, ext):
 
 
         self.layer_state = None
-        start_layer, end_layer = self.get_layer_index(self.args.my_rank, self.tp, self.pp, self.args.num_layers)
+        start_layer, end_layer = self.get_layer_index()
 
         if self.layer_parameters_predefined:
             # This is for old code, where the layer parameters are predefined
@@ -85,11 +85,8 @@ def __init__(self, ext):
         self.layer_state = dict()
         ss = 0.0
         for layer_index in range(start_layer, end_layer + 1):
-            if self.args.zero_stage < 3:
-                _, size = self.get_layer_state(layer_index)
-            else:
-                self.layer_state[str(layer_index)], size = self.get_layer_state(layer_index)
-            logging.info(f"{utcnow()} {self.args.my_rank}- {layer_index}: {size/1024./1024./1024:.4f} GB ")
+            self.layer_state[str(layer_index)], size = self.get_layer_state(layer_index)
+            #logging.info(f"{utcnow()} {self.args.my_rank} [{start_layer}-{end_layer}]:::{layer_index}: {size/1024./1024./1024:.4f} GB ")
             ss += size
         if self.args.my_rank == 0:
             logging.info(f"{utcnow()} Layer states defined! {ss/1024./1024./1024} GB per rank")
@@ -130,9 +127,6 @@ def __init__(self, ext):
         if self.args.zero_stage < 3:
             ss /= self.dp
         self.checkpoint_size = ss + opt
-
-
-
         if self.args.my_rank == 0:
             logging.info(f"{utcnow()} Total state size: {ss} GB")
             logging.info(f"{utcnow()} Total checkpoint size: {self.checkpoint_size} GB")
@@ -221,7 +215,7 @@ def get_optimization_groups(self):
         else:
             return []
 
-    def get_layer_index(self, rank, tensor_parallelism, pipeline_parallelism, total_layers):
+    def get_layer_index(self):
        '''
        if tensor_parallelism > 1:
            total_layers = total_layers + tensor_parallelism
@@ -248,25 +242,25 @@ def get_layer_index(self, rank, tensor_parallelism, pipeline_parallelism, total_layers):
        The transformer layers are from 1 to l. We only distribute the transformer layers among the ranks.
        We assume layer 0 is always on rank 0, and l+1 and l+2 are on the last rank.
        '''
-        pipeline_rank = (rank // tensor_parallelism) % pipeline_parallelism
-        remainder = total_layers%pipeline_parallelism
-        nl = total_layers//pipeline_parallelism
+        pipeline_rank = self.pp_rank
+        nl = self.args.num_layers//self.pp
+        remainder = self.args.num_layers%self.pp
         if pipeline_rank < remainder:
             start_layer = pipeline_rank * (nl + 1) + 1
             end_layer = start_layer + nl + 1
         else:
             start_layer = remainder * (nl + 1) + (pipeline_rank - remainder) * nl + 1
             end_layer = start_layer + nl
-        if pipeline_rank == pipeline_parallelism - 1:
-            end_layer = total_layers + 2
+        if pipeline_rank == self.pp - 1:
+            end_layer = self.args.num_layers + 2
         if pipeline_rank == 0:
             start_layer = 0
         return start_layer, end_layer
 
     @abstractmethod
     def checkpoint(self, epoch, step_number):
-        my_rank = DLIOMPI.get_instance().rank()
-        start_layer, end_layer = self.get_layer_index(my_rank,self.args.tensor_parallelism, self.args.pipeline_parallelism, self.args.num_layers)
+        start_layer, end_layer = self.get_layer_index()
         # create a specifc folder for each step
         checkpoint_id = f"global_epoch{epoch}_step{step_number}"
         self.checkpoint_storage.create_node(checkpoint_id, exist_ok=True)
@@ -283,7 +277,7 @@ def checkpoint(self, epoch, step_number):
         if self.dp_rank == 0 and self.args.num_layers > 0:
             # in this case, model is saved layer by layer
             for layer_index in range(start_layer, end_layer + 1):
-                self.save_state(suffix=f"{checkpoint_id}/layer_{layer_index}-model_{self.mp_rank}_model_states", state=self.get_layer_state(layer_index))
+                self.save_state(suffix=f"{checkpoint_id}/layer_{layer_index}-model_{self.mp_rank}_model_states", state=self.layer_state[str(layer_index)])
         else:
             # in this case, model is sharded across the data parallel ranks
             assert(self.pp == 1)
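
To make the revised partitioning concrete, the sketch below reproduces the arithmetic of the new get_layer_index() as a standalone function. It is illustrative only: the helper name partition_layers and the example sizes are not part of DLIO, and the real method reads pp_rank, pp, and num_layers from the benchmark's own state.

# Minimal standalone sketch of the layer partitioning in get_layer_index().
# The helper name and the example sizes are illustrative, not part of DLIO.
def partition_layers(pp_rank, pp, num_layers):
    """Return the inclusive (start_layer, end_layer) range owned by a pipeline rank.

    Transformer layers are numbered 1..num_layers; layer 0 stays on pipeline
    rank 0, and layers num_layers+1 and num_layers+2 stay on the last rank.
    """
    nl = num_layers // pp          # base number of layers per pipeline stage
    remainder = num_layers % pp    # the first `remainder` stages take one extra layer
    if pp_rank < remainder:
        start_layer = pp_rank * (nl + 1) + 1
        end_layer = start_layer + nl + 1
    else:
        start_layer = remainder * (nl + 1) + (pp_rank - remainder) * nl + 1
        end_layer = start_layer + nl
    if pp_rank == pp - 1:
        end_layer = num_layers + 2   # last stage also owns the two trailing layers
    if pp_rank == 0:
        start_layer = 0              # first stage also owns layer 0
    return start_layer, end_layer

if __name__ == "__main__":
    # Example: 10 transformer layers split across 4 pipeline stages.
    for rank in range(4):
        print(rank, partition_layers(rank, pp=4, num_layers=10))
    # Prints (0, 4), (4, 7), (7, 9), (9, 12) for ranks 0..3 with this arithmetic.
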
4 changes: 2 additions & 2 deletions dlio_benchmark/utils/config.py
@@ -608,9 +608,9 @@ def LoadConfig(args, config):
            args.layer_parameters = config['model']['layer_parameters']
 
        if 'parallelism' in config['model']:
-            if 'tensor_parallelism' in config['model']['parallelism']:
+            if 'tensor' in config['model']['parallelism']:
                args.tensor_parallelism = config['model']['parallelism']['tensor']
-            if 'pipeline_parallelism' in config['model']['parallelism']:
+            if 'pipeline' in config['model']['parallelism']:
                args.pipeline_parallelism = config['model']['parallelism']['pipeline']
            if 'zero_stage' in config['model']['parallelism']:
                args.zero_stage = config['model']['parallelism']['zero_stage']
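
For reference, here is a small hypothetical sketch of the nesting that LoadConfig now expects under the model.parallelism block. The dict literal stands in for DLIO's parsed workload configuration, and the numeric values are arbitrary.

# Hypothetical config fragment mirroring the keys LoadConfig now looks up.
# In practice this dict comes from the parsed workload YAML file.
config = {
    "model": {
        "parallelism": {
            "tensor": 4,       # read into args.tensor_parallelism
            "pipeline": 2,     # read into args.pipeline_parallelism
            "zero_stage": 1,   # read into args.zero_stage
        }
    }
}

parallelism = config["model"]["parallelism"]
if "tensor" in parallelism:
    print("tensor_parallelism =", parallelism["tensor"])
if "pipeline" in parallelism:
    print("pipeline_parallelism =", parallelism["pipeline"])
if "zero_stage" in parallelism:
    print("zero_stage =", parallelism["zero_stage"])
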
