Refactor process group initialization in ParallelContext
NouamaneTazi committed Feb 14, 2024
1 parent b457d01 commit 5d3eb67
Showing 1 changed file with 34 additions and 61 deletions.

src/nanotron/parallel/context.py
@@ -64,79 +64,52 @@ def _init_parallel_groups(self):
         """Initialize 3D parallelism's all process groups."""
         # NOTE: ensure all processes have joined the global group
         # before creating other groups
-        dist.barrier(group=self.world_pg)
-
-        rank = int(os.environ["RANK"])
+        dist.barrier()
         world_size = int(os.environ["WORLD_SIZE"])
 
         ranks = np.arange(0, world_size).reshape(
             (self.pipeline_parallel_size, self.data_parallel_size, self.tensor_parallel_size)
         )
-        world_ranks_to_pg = {}
+        self.world_ranks_to_pg = {}
 
-        tp_pg: dist.ProcessGroup
-        ranks_with_tp_last = ranks.reshape(
-            (self.pipeline_parallel_size * self.data_parallel_size, self.tensor_parallel_size)
+        # Relevent process groups containing the current rank
+        self.tp_pg = self.create_new_group(
+            ranks.reshape((self.pipeline_parallel_size * self.data_parallel_size, self.tensor_parallel_size))
         )
-        for tp_ranks in ranks_with_tp_last:
-            sorted_ranks = tuple(sorted(tp_ranks))
-            if sorted_ranks not in world_ranks_to_pg:
-                new_group = dist.new_group(ranks=tp_ranks)
-                world_ranks_to_pg[sorted_ranks] = new_group
-            else:
-                new_group = world_ranks_to_pg[sorted_ranks]
-            if rank in tp_ranks:
-                tp_pg = new_group
-
-        dp_pg: dist.ProcessGroup
-        ranks_with_dp_last = ranks.transpose((0, 2, 1)).reshape(
-            (self.pipeline_parallel_size * self.tensor_parallel_size, self.data_parallel_size)
+        self.dp_pg = self.create_new_group(
+            ranks.transpose((0, 2, 1)).reshape(
+                (self.pipeline_parallel_size * self.tensor_parallel_size, self.data_parallel_size)
+            )
         )
+        self.pp_pg = self.create_new_group(
+            ranks.transpose((2, 1, 0)).reshape(
+                (self.tensor_parallel_size * self.data_parallel_size, self.pipeline_parallel_size)
+            )
+        )
-        for dp_ranks in ranks_with_dp_last:
-            sorted_ranks = tuple(sorted(dp_ranks))
-            if sorted_ranks not in world_ranks_to_pg:
-                new_group = dist.new_group(ranks=dp_ranks)
-                world_ranks_to_pg[sorted_ranks] = new_group
-            else:
-                new_group = world_ranks_to_pg[sorted_ranks]
-            if rank in dp_ranks:
-                dp_pg = new_group
-
-        pp_pg: dist.ProcessGroup
-        ranks_with_pp_last = ranks.transpose((2, 1, 0)).reshape(
-            (self.tensor_parallel_size * self.data_parallel_size, self.pipeline_parallel_size)
+        # model parallel group = combination of tp and pp for a given dp rank
+        self.mp_pg = self.create_new_group(
+            [ranks[:, dp_rank, :].reshape(-1) for dp_rank in range(self.data_parallel_size)]
        )
-        for pp_ranks in ranks_with_pp_last:
-            sorted_ranks = tuple(sorted(pp_ranks))
-            if sorted_ranks not in world_ranks_to_pg:
-                new_group = dist.new_group(ranks=pp_ranks)
-                world_ranks_to_pg[sorted_ranks] = new_group
-            else:
-                new_group = world_ranks_to_pg[sorted_ranks]
-            if rank in pp_ranks:
-                pp_pg = new_group
-
-        # We build model parallel group (combination of both tensor parallel and pipeline parallel)
-        for dp_rank in range(self.data_parallel_size):
-            pp_and_tp_ranks = ranks[:, dp_rank, :].reshape(-1)
-            sorted_ranks = tuple(sorted(pp_and_tp_ranks))
-            if sorted_ranks not in world_ranks_to_pg:
-                new_group = dist.new_group(ranks=pp_and_tp_ranks)
-                world_ranks_to_pg[sorted_ranks] = new_group
-            else:
-                new_group = world_ranks_to_pg[sorted_ranks]
-            if rank in pp_and_tp_ranks:
-                mp_pg = new_group
-
-        self.tp_pg = tp_pg
-        self.dp_pg = dp_pg
-        self.pp_pg = pp_pg
-        self.mp_pg = mp_pg  # model parallel group = combination of tp and pp for a given dp rank
+        self.world_rank_matrix: np.ndarray = ranks
+        dist.barrier()
 
-        self.world_rank_matrix = ranks
-        self.world_ranks_to_pg = world_ranks_to_pg
+    def create_new_group(self, all_groups_ranks: np.ndarray) -> dist.ProcessGroup:
+        rank = int(os.environ["RANK"])
+        new_group_containing_rank = None
+        for group_ranks in all_groups_ranks:
+            sorted_ranks = tuple(sorted(group_ranks))
+
+            # add new group to `world_ranks_to_pg`
+            if sorted_ranks not in self.world_ranks_to_pg:
+                new_group = dist.new_group(ranks=group_ranks)
+                self.world_ranks_to_pg[sorted_ranks] = new_group
+            else:
+                new_group = self.world_ranks_to_pg[sorted_ranks]
+
+            dist.barrier()
+            if rank in sorted_ranks:
+                new_group_containing_rank = new_group
+        return new_group_containing_rank
 
     def set_device(self):
         local_rank = int(os.getenv("LOCAL_RANK", "0"))
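The refactor collapses four nearly identical group-creation loops into a single create_new_group helper; each caller now just passes a 2D array whose rows are the rank sets of the groups to create. Below is a minimal sketch (not part of the commit; the parallelism sizes are invented for illustration) of how the (pp, dp, tp) rank grid produces those rows:

import numpy as np

# Toy layout: 2 pipeline stages x 2 data replicas x 2 tensor shards
pp_size, dp_size, tp_size = 2, 2, 2
world_size = pp_size * dp_size * tp_size

# Same layout as the commit: axes are (pp, dp, tp)
ranks = np.arange(world_size).reshape((pp_size, dp_size, tp_size))

# Each row is one process group, exactly the arrays handed to create_new_group
tp_groups = ranks.reshape((pp_size * dp_size, tp_size))
dp_groups = ranks.transpose((0, 2, 1)).reshape((pp_size * tp_size, dp_size))
pp_groups = ranks.transpose((2, 1, 0)).reshape((tp_size * dp_size, pp_size))
mp_groups = [ranks[:, dp_rank, :].reshape(-1) for dp_rank in range(dp_size)]

print(tp_groups)  # [[0 1] [2 3] [4 5] [6 7]]  -> ranks differing only in tp
print(dp_groups)  # [[0 2] [1 3] [4 6] [5 7]]  -> ranks differing only in dp
print(pp_groups)  # [[0 4] [2 6] [1 5] [3 7]]  -> ranks differing only in pp
print(mp_groups)  # [array([0, 1, 4, 5]), array([2, 3, 6, 7])]  -> tp+pp per dp rank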
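Note that create_new_group keys self.world_ranks_to_pg by the sorted tuple of ranks: dist.new_group must be called by every process for every group, and the same rank set can appear under more than one group type (for example, with pipeline_parallel_size == 1 each model-parallel group contains exactly the ranks of a tensor-parallel group, so the existing group is reused). A small stand-alone sketch of that caching behaviour, using hypothetical stand-in objects instead of real process groups:

world_ranks_to_pg = {}

def get_or_create_group(group_ranks):
    # Cache key is order-independent, mirroring tuple(sorted(group_ranks))
    key = tuple(sorted(group_ranks))
    if key not in world_ranks_to_pg:
        world_ranks_to_pg[key] = object()  # stand-in for dist.new_group(ranks=group_ranks)
    return world_ranks_to_pg[key]

tp = get_or_create_group([0, 1])
mp = get_or_create_group([1, 0])  # same rank set, different order
assert tp is mp  # the cached group is reused rather than created twice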
