From 226c1c68e1d7cbe3cd0786b8de4ce3a444b9b6be Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Wed, 14 Feb 2024 19:59:23 +0000 Subject: [PATCH] Move barrier calls in ParallelContext --- src/nanotron/parallel/context.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/nanotron/parallel/context.py b/src/nanotron/parallel/context.py index d87804d8..5a153036 100644 --- a/src/nanotron/parallel/context.py +++ b/src/nanotron/parallel/context.py @@ -62,8 +62,6 @@ def __init__( def _init_parallel_groups(self): """Initialize 3D parallelism's all process groups.""" - # NOTE: ensure all processes have joined the global group - # before creating other groups dist.barrier() world_size = int(os.environ["WORLD_SIZE"]) ranks = np.arange(0, world_size).reshape( @@ -92,9 +90,9 @@ def _init_parallel_groups(self): ) self.world_rank_matrix: np.ndarray = ranks - dist.barrier() def create_new_group(self, all_groups_ranks: np.ndarray) -> dist.ProcessGroup: + dist.barrier() rank = int(os.environ["RANK"]) new_group_containing_rank = None for group_ranks in all_groups_ranks: @@ -109,6 +107,7 @@ def create_new_group(self, all_groups_ranks: np.ndarray) -> dist.ProcessGroup: if rank in sorted_ranks: new_group_containing_rank = new_group + dist.barrier() return new_group_containing_rank def set_device(self):