Merge pull request #69 from huggingface/nouamane/refacto-pgs
Refactor `ParallelContext` and some process groups creation
NouamaneTazi authored Feb 16, 2024
2 parents fa4685a + 226c1c6 commit e98e3ed
Showing 5 changed files with 46 additions and 105 deletions.
src/nanotron/models/starcoder2.py (12 changes: 2 additions & 10 deletions)

@@ -1445,16 +1445,8 @@ def tie_custom_params(self) -> None:
                     shared_weights = [
                         (
                             name,
-                            # This adds all the tp_ranks in one go
-                            tuple(
-                                sorted(
-                                    self.parallel_context.world_rank_matrix[
-                                        dist.get_rank(self.parallel_context.pp_pg),
-                                        dist.get_rank(self.parallel_context.dp_pg),
-                                        :,
-                                    ]
-                                )
-                            ),
+                            # sync across TP group
+                            tuple(sorted(dist.get_process_group_ranks(self.parallel_context.tp_pg))),
                         )
                     ]
                     tie_parameters(

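The tied-weights spec above now asks the TP process group for its member ranks instead of slicing `world_rank_matrix`. The sketch below is not part of the commit; the grid sizes and the `pp_rank`/`dp_rank` values are made up for illustration. It shows why the two expressions describe the same set of ranks: the TP group containing a given rank was built from exactly the row `world_rank_matrix[pp_rank, dp_rank, :]`, which is what `dist.get_process_group_ranks(tp_pg)` returns for that rank.

```python
# Illustration only: a toy (pp, dp, tp) grid, no torch.distributed initialization.
# In the real code, pp_rank/dp_rank come from dist.get_rank(pp_pg)/dist.get_rank(dp_pg)
# and the right-hand side comes from dist.get_process_group_ranks(tp_pg).
import numpy as np

pp_size, dp_size, tp_size = 2, 2, 4
world_rank_matrix = np.arange(pp_size * dp_size * tp_size).reshape(pp_size, dp_size, tp_size)

pp_rank, dp_rank = 1, 0  # hypothetical coordinates of the current rank

# Old expression: all tp_ranks for the caller's (pp_rank, dp_rank) cell.
old_ranks = tuple(sorted(world_rank_matrix[pp_rank, dp_rank, :].tolist()))

# The TP groups are created from the rows of this reshape, so the group that
# contains the caller holds exactly the same global ranks.
tp_groups = world_rank_matrix.reshape(pp_size * dp_size, tp_size)
new_ranks = tuple(sorted(tp_groups[pp_rank * dp_size + dp_rank].tolist()))

assert old_ranks == new_ranks  # (8, 9, 10, 11)
```
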
src/nanotron/parallel/context.py (94 changes: 35 additions & 59 deletions)

@@ -62,77 +62,53 @@ def __init__(

     def _init_parallel_groups(self):
         """Initialize 3D parallelism's all process groups."""
         # NOTE: ensure all processes have joined the global group
         # before creating other groups
-        dist.barrier(group=self.world_pg)
-
-        rank = int(os.environ["RANK"])
+        dist.barrier()
         world_size = int(os.environ["WORLD_SIZE"])

         ranks = np.arange(0, world_size).reshape(
             (self.pipeline_parallel_size, self.data_parallel_size, self.tensor_parallel_size)
         )
-        world_ranks_to_pg = {}
+        self.world_ranks_to_pg = {}

-        tp_pg: dist.ProcessGroup
-        ranks_with_tp_last = ranks.reshape(
-            (self.pipeline_parallel_size * self.data_parallel_size, self.tensor_parallel_size)
-        )
-        for tp_ranks in ranks_with_tp_last:
-            sorted_ranks = tuple(sorted(tp_ranks))
-            if sorted_ranks not in world_ranks_to_pg:
-                new_group = dist.new_group(ranks=tp_ranks)
-                world_ranks_to_pg[sorted_ranks] = new_group
-            else:
-                new_group = world_ranks_to_pg[sorted_ranks]
-            if rank in tp_ranks:
-                tp_pg = new_group
-
-        dp_pg: dist.ProcessGroup
-        ranks_with_dp_last = ranks.transpose((0, 2, 1)).reshape(
-            (self.pipeline_parallel_size * self.tensor_parallel_size, self.data_parallel_size)
-        )
-        for dp_ranks in ranks_with_dp_last:
-            sorted_ranks = tuple(sorted(dp_ranks))
-            if sorted_ranks not in world_ranks_to_pg:
-                new_group = dist.new_group(ranks=dp_ranks)
-                world_ranks_to_pg[sorted_ranks] = new_group
-            else:
-                new_group = world_ranks_to_pg[sorted_ranks]
-            if rank in dp_ranks:
-                dp_pg = new_group
-
-        pp_pg: dist.ProcessGroup
-        ranks_with_pp_last = ranks.transpose((2, 1, 0)).reshape(
-            (self.tensor_parallel_size * self.data_parallel_size, self.pipeline_parallel_size)
-        )
-        for pp_ranks in ranks_with_pp_last:
-            sorted_ranks = tuple(sorted(pp_ranks))
-            if sorted_ranks not in world_ranks_to_pg:
-                new_group = dist.new_group(ranks=pp_ranks)
-                world_ranks_to_pg[sorted_ranks] = new_group
-            else:
-                new_group = world_ranks_to_pg[sorted_ranks]
-            if rank in pp_ranks:
-                pp_pg = new_group
-
-        # TODO(xrsrke): this looks unnecessary, remove it if possible
-        # We build model parallel group (combination of both tensor parallel and pipeline parallel)
-        for dp_rank in range(self.data_parallel_size):
-            pp_and_tp_ranks = ranks[:, dp_rank, :].reshape(-1)
-            sorted_ranks = tuple(sorted(pp_and_tp_ranks))
-            if sorted_ranks not in world_ranks_to_pg:
-                new_group = dist.new_group(ranks=pp_and_tp_ranks)
-                world_ranks_to_pg[sorted_ranks] = new_group
-
-        self.tp_pg = tp_pg
-        self.dp_pg = dp_pg
-        self.pp_pg = pp_pg
-
-        self.world_rank_matrix = ranks
-        self.world_ranks_to_pg = world_ranks_to_pg
+        # Relevent process groups containing the current rank
+        self.tp_pg = self.create_new_group(
+            ranks.reshape((self.pipeline_parallel_size * self.data_parallel_size, self.tensor_parallel_size))
+        )
+        self.dp_pg = self.create_new_group(
+            ranks.transpose((0, 2, 1)).reshape(
+                (self.pipeline_parallel_size * self.tensor_parallel_size, self.data_parallel_size)
+            )
+        )
+        self.pp_pg = self.create_new_group(
+            ranks.transpose((2, 1, 0)).reshape(
+                (self.tensor_parallel_size * self.data_parallel_size, self.pipeline_parallel_size)
+            )
+        )
+        # model parallel group = combination of tp and pp for a given dp rank
+        self.mp_pg = self.create_new_group(
+            [ranks[:, dp_rank, :].reshape(-1) for dp_rank in range(self.data_parallel_size)]
+        )
+
+        self.world_rank_matrix: np.ndarray = ranks
+
+    def create_new_group(self, all_groups_ranks: np.ndarray) -> dist.ProcessGroup:
+        dist.barrier()
+        rank = int(os.environ["RANK"])
+        new_group_containing_rank = None
+        for group_ranks in all_groups_ranks:
+            sorted_ranks = tuple(sorted(group_ranks))
+
+            # add new group to `world_ranks_to_pg`
+            if sorted_ranks not in self.world_ranks_to_pg:
+                new_group = dist.new_group(ranks=group_ranks)
+                self.world_ranks_to_pg[sorted_ranks] = new_group
+            else:
+                new_group = self.world_ranks_to_pg[sorted_ranks]
+
+            if rank in sorted_ranks:
+                new_group_containing_rank = new_group
+        dist.barrier()
+        return new_group_containing_rank

     def set_device(self):
         local_rank = int(os.getenv("LOCAL_RANK", "0"))

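For readers not used to the rank-grid bookkeeping, here is a minimal single-process sketch (not from the repository; the parallel sizes are hypothetical) of the rank lists each `create_new_group` call above receives. Each row becomes one process group, and `world_ranks_to_pg` deduplicates rows by their sorted tuple so a group is only created once.

```python
# Single-process illustration of the group layouts built in _init_parallel_groups.
# Hypothetical sizes: pp=2, dp=2, tp=2 -> world_size=8. dist.new_group is never
# called here; we only print the rank lists each create_new_group() call receives.
import numpy as np

pp, dp, tp = 2, 2, 2
ranks = np.arange(pp * dp * tp).reshape(pp, dp, tp)  # world_rank_matrix

tp_groups = ranks.reshape(pp * dp, tp)                        # one row per (pp, dp) pair
dp_groups = ranks.transpose((0, 2, 1)).reshape(pp * tp, dp)   # one row per (pp, tp) pair
pp_groups = ranks.transpose((2, 1, 0)).reshape(tp * dp, pp)   # one row per (tp, dp) pair
mp_groups = [ranks[:, d, :].reshape(-1) for d in range(dp)]   # tp and pp combined, per dp rank

print(tp_groups.tolist())               # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(dp_groups.tolist())               # [[0, 2], [1, 3], [4, 6], [5, 7]]
print(pp_groups.tolist())               # [[0, 4], [2, 6], [1, 5], [3, 7]]
print([g.tolist() for g in mp_groups])  # [[0, 1, 4, 5], [2, 3, 6, 7]]
```

Every rank walks over all rows (so `dist.new_group` is called collectively), but `create_new_group` only returns the group whose sorted rank tuple contains the caller's own rank.
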
src/nanotron/trainer.py (21 changes: 3 additions & 18 deletions)

@@ -350,17 +350,8 @@ def training_step(
                 for name, param in self.unwrapped_model.get_named_params_with_correct_tied()
                 if param.requires_grad
             ]
-            # TODO @nouamane: we need to split `world_rank_matrix` along PP axis, to separate ref from active model
             self.grad_norm_unclipped = clip_grad_norm(
-                mp_pg=self.parallel_context.world_ranks_to_pg[
-                    tuple(
-                        sorted(
-                            self.parallel_context.world_rank_matrix[
-                                :, dist.get_rank(self.parallel_context.dp_pg), :
-                            ].reshape(-1)
-                        )
-                    )
-                ],
+                mp_pg=self.parallel_context.mp_pg,
                 named_parameters=named_parameters,
                 grad_accumulator=self.grad_accumulator,
                 max_norm=self.config.optimizer.clip_grad,
@@ -784,14 +775,8 @@ def mark_tied_parameters(
         shared_weights = [
             (
                 name,
-                # This adds all the tp_ranks in one go
-                tuple(
-                    sorted(
-                        parallel_context.world_rank_matrix[
-                            dist.get_rank(parallel_context.pp_pg), dist.get_rank(parallel_context.dp_pg), :
-                        ]
-                    )
-                ),
+                # sync across TP group
+                tuple(sorted(dist.get_process_group_ranks(parallel_context.tp_pg))),
             )
         ]

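`clip_grad_norm` now receives the ready-made `mp_pg`. The group matters because each rank only holds a tensor- and pipeline-parallel shard of the model, so the gradient norm has to be reduced across `mp_pg`, while data-parallel replicas already hold identical gradients and need no extra reduction. The following is a rough sketch of that reduction under stated assumptions (a 2-norm, an initialized process group, parameters that all have gradients); the helper name is made up and this is not nanotron's `clip_grad_norm` implementation.

```python
# Hedged sketch: how a sharded 2-norm could be combined across the model-parallel
# group. The function and its structure are illustrative, not nanotron's API.
import torch
import torch.distributed as dist

def sharded_grad_norm(parameters, mp_pg: dist.ProcessGroup) -> torch.Tensor:
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    # Sum of squared norms over the local shard only.
    local_sq = torch.stack([g.float().norm() ** 2 for g in grads]).sum()
    # Reduce across tensor- and pipeline-parallel ranks (mp_pg), not across DP.
    dist.all_reduce(local_sq, op=dist.ReduceOp.SUM, group=mp_pg)
    return local_sq.sqrt()
```
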
tests/helpers/dummy.py (8 changes: 2 additions & 6 deletions)

@@ -83,12 +83,8 @@ def init_dummy_model(parallel_context: ParallelContext, dtype: torch.dtype = tor
         shared_weights = [
             (
                 name,
-                # This adds all the tp_ranks in one go
-                set(
-                    parallel_context.world_rank_matrix[
-                        dist.get_rank(parallel_context.pp_pg), dist.get_rank(parallel_context.dp_pg), :
-                    ]
-                ),
+                # sync across TP group
+                tuple(sorted(dist.get_process_group_ranks(parallel_context.tp_pg))),
             )
         ]
         tie_parameters(

tests/test_clip_grads.py (16 changes: 4 additions & 12 deletions)

@@ -138,9 +138,7 @@ def _test_clip_grads_with_pp(parallel_context: ParallelContext, norm_type: float
     old_bias_grad = non_linear.bias.grad.clone()
     # Clip grads
     total_norm = clip_grad_norm(
-        mp_pg=parallel_context.world_ranks_to_pg[
-            tuple(sorted(parallel_context.world_rank_matrix[:, dist.get_rank(parallel_context.dp_pg), :].reshape(-1)))
-        ],
+        mp_pg=parallel_context.mp_pg,
         named_parameters=model.named_parameters(),
         grad_accumulator=None,
         max_norm=1.0,
@@ -311,9 +309,7 @@ def _test_clip_grads_with_tp(
     old_grad = column_linear.weight.grad.clone()
     # Clip grads
     total_norm = clip_grad_norm(
-        mp_pg=parallel_context.world_ranks_to_pg[
-            tuple(sorted(parallel_context.world_rank_matrix[:, dist.get_rank(parallel_context.dp_pg), :].reshape(-1)))
-        ],
+        mp_pg=parallel_context.mp_pg,
         named_parameters=column_linear.named_parameters(),
         grad_accumulator=None,
         max_norm=1.0,
@@ -420,9 +416,7 @@ def _test_clip_grads_tied_weights(parallel_context: ParallelContext, norm_type:
     old_grad = weight.grad.clone()
     # Clip grads
     total_norm = clip_grad_norm(
-        mp_pg=parallel_context.world_ranks_to_pg[
-            tuple(sorted(parallel_context.world_rank_matrix[:, dist.get_rank(parallel_context.dp_pg), :].reshape(-1)))
-        ],
+        mp_pg=parallel_context.mp_pg,
         named_parameters=model.named_parameters(),
         grad_accumulator=None,
         max_norm=1.0,
@@ -560,9 +554,7 @@ def _test_clip_grads_fp32_accumulator(

     # Clip grads
     total_norm = clip_grad_norm(
-        mp_pg=parallel_context.world_ranks_to_pg[
-            tuple(sorted(parallel_context.world_rank_matrix[:, dist.get_rank(parallel_context.dp_pg), :].reshape(-1)))
-        ],
+        mp_pg=parallel_context.mp_pg,
         named_parameters=model.named_parameters(),
         grad_accumulator=grad_accumulator,
         max_norm=1.0,
