
Commit

ddp process group for backward
zzhhjjj committed Jul 2, 2024
1 parent 019d263 commit 07ebf2c
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions src/nanotron/helpers.py
@@ -381,8 +381,7 @@ def grad_optimizer_builder(named_param_groups):
     named_params_or_groups=named_param_groups,
     # TODO @thomasw21: We need a better API for gradient accumulation/zero etc ...
     optimizer_builder=optimizer_builder,
-    # dp_pg=parallel_context.dp_pg,
-    dp_pg=parallel_context.dp_sp_pg,
+    dp_pg=parallel_context.dp_sp_pg,  # sequence parallel and data parallel process group
 )

 # SANITY CHECK: assert that optimizer's named_params point to model's params (check only the first one)
@@ -407,8 +406,9 @@ def grad_optimizer_builder(named_param_groups):

 assert isinstance(grad_accumulator, FP32GradientAccumulator)
 grad_accumulator.assign_param_offsets(
-    # dp_rank=dist.get_rank(parallel_context.dp_pg),
-    dp_rank=dist.get_rank(parallel_context.dp_sp_pg),
+    dp_rank=dist.get_rank(
+        parallel_context.dp_sp_pg
+    ),  # sequence parallel and data parallel process group will synchronize the gradient together
     param_name_to_offsets=param_name_to_dp_rank_offsets,
 )

@@ -417,8 +417,7 @@ def grad_optimizer_builder(named_param_groups):
 assert isinstance(grad_accumulator, FP32GradientAccumulator)
 model.register_comm_hook(
     state=FP32GradBucketManager(
-        # dp_pg=parallel_context.dp_pg,
-        dp_pg=parallel_context.dp_sp_pg,
+        dp_pg=parallel_context.dp_sp_pg,  # sequence parallel and data parallel process group will synchronize the gradient together
         accumulator=grad_accumulator,
         param_id_to_name={
             id(param): param.get_tied_info().get_full_name_from_module_id_to_prefix(

@@ -705,7 +704,6 @@ def get_consumed_train_samples_of_a_data_stage_from_ckp(
     stage: DatasetStageArgs, metadata: TrainingMetadata
 ) -> Optional[int]:
     start_training_step = stage.start_training_step
-    # TODO: if a new dataset is added, we know that it has consumed 0 tokens? Cannot add new dataset for now.
     return next(
         (s.consumed_train_samples for s in metadata.data_stages if s.start_training_step == start_training_step),
         None,
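
Note on the change: the diff replaces the plain data-parallel group (dp_pg) with the combined data-parallel + sequence-parallel group (dp_sp_pg) everywhere the backward pass synchronizes gradients, so the DDP comm hook and the FP32 gradient accumulator reduce over both dimensions. Below is a minimal, hypothetical sketch in plain torch.distributed (not nanotron's ParallelContext) of building such a combined dp x sp group and averaging a gradient over it. The rank layout, the build_dp_sp_group helper, and the dp/sp/tp sizes are illustrative assumptions, not the repository's actual implementation.

import torch
import torch.distributed as dist


def build_dp_sp_group(dp_size: int, sp_size: int, tp_size: int = 1):
    # Assumed rank layout: global_rank = (dp * sp_size + sp) * tp_size + tp.
    # Every rank must call new_group() for every group, even ones it does not
    # belong to, so we loop over all tp indices and keep the group we are in.
    rank = dist.get_rank()
    my_group = None
    for tp in range(tp_size):
        ranks = [
            (d * sp_size + s) * tp_size + tp
            for d in range(dp_size)
            for s in range(sp_size)
        ]
        group = dist.new_group(ranks=ranks)
        if rank in ranks:
            my_group = group
    return my_group


if __name__ == "__main__":
    # Launch with e.g. `torchrun --nproc_per_node=4 this_file.py` so that
    # world_size == dp_size * sp_size * tp_size.
    dist.init_process_group(backend="gloo")  # use "nccl" on GPUs
    dp_sp_pg = build_dp_sp_group(dp_size=2, sp_size=2, tp_size=1)

    # Fake per-rank gradient; in DDP the comm hook would receive a real bucket.
    grad = torch.full((4,), float(dist.get_rank()))

    # Reducing over the combined dp x sp group mirrors what the commit does:
    # passing dp_sp_pg instead of dp_pg makes the backward all-reduce average
    # gradients across both the data-parallel and sequence-parallel replicas.
    dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=dp_sp_pg)
    grad /= dist.get_world_size(group=dp_sp_pg)
    print(f"rank {dist.get_rank()}: averaged grad = {grad.tolist()}")

Run under torchrun with 4 processes, every rank prints the same averaged gradient, which is the behavior the dp_sp_pg comments in the diff describe: the sequence-parallel and data-parallel ranks synchronize the gradient together.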

0 comments on commit 07ebf2c
