@@ -690,6 +690,7 @@ def __init__(
         init_data_parallel: bool = True,
         init_parameters: bool = True,
         data_parallel_wrapper: Optional[DataParallelWrapper] = None,
+        use_inter_host_allreduce: bool = False,
     ) -> None:
         assert device.type == "cuda", "DMPCollection only supports CUDA"
         self._device = device
@@ -705,13 +706,16 @@ def __init__(
                 global_rank=self._global_rank,
                 world_size=world_size,
                 local_size=sharding_group_size,
+                use_inter_host_allreduce=use_inter_host_allreduce,
             )
         )
 
         self._remap_sharding_plan(
             plan=plan,
             rank=self._global_rank,
-            num_nodes=world_size // sharding_group_size,
+            step=world_size // sharding_group_size,
+            sharding_group_size=sharding_group_size,
+            use_inter_host_allreduce=use_inter_host_allreduce,
         )
         super().__init__(
             module,
@@ -720,6 +724,7 @@ def __init__(
                 sharding_pg=self._sharding_pg,
                 device_mesh=self._device_mesh,
                 node_group_size=node_group_size,
+                use_inter_host_allreduce=use_inter_host_allreduce,
             ),
             device,
             plan,
@@ -768,7 +773,11 @@ def sync(self, include_optimizer_state: bool = True) -> None:
         handle.wait()
 
     def _create_process_groups(
-        self, global_rank: int, world_size: int, local_size: int
+        self,
+        global_rank: int,
+        world_size: int,
+        local_size: int,
+        use_inter_host_allreduce: bool = False,
     ) -> Tuple[DeviceMesh, dist.ProcessGroup, dist.ProcessGroup]:
         """
         Creates process groups for sharding and replication, the process groups
@@ -784,17 +793,29 @@ def _create_process_groups(
         replication process group, and allreduce process group.
         """
         peer_matrix = []
-        num_nodes = world_size // local_size
+        mesh, sharding_pg, replica_pg = None, None, None
 
-        for group_rank in range(world_size // local_size):
-            peers = [num_nodes * r + group_rank for r in range(local_size)]
-            peer_matrix.append(peers)
+        logger.warning(f"[2D] Use inter host all reduce: {use_inter_host_allreduce}")
+
+        if use_inter_host_allreduce:
+            # We shard on a contiguous set of ranks and nodes, which forces the all-reduce to be inter-host.
+            # Under this scheme, sharding types such as TWRW and GRID take
+            # advantage of intra-node comms as a result of the contiguous ranks.
+            peer_matrix = [
+                list(range(i, i + local_size)) for i in range(0, world_size, local_size)
+            ]
+        else:
+            step = world_size // local_size
+            for group_rank in range(world_size // local_size):
+                peers = [step * r + group_rank for r in range(local_size)]
+                peer_matrix.append(peers)
 
         mesh = DeviceMesh(
             device_type=self._device.type,
             mesh=peer_matrix,
             mesh_dim_names=("replicate", "shard"),
         )
+
         logger.warning(f"[Connection] 2D Device Mesh created: {mesh}")
         sharding_pg = mesh.get_group(mesh_dim="shard")
         logger.warning(
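To make the two layouts concrete, here is a small standalone sketch (not part of the PR) of the peer-matrix construction above, assuming a hypothetical world_size of 8 and local_size of 4 with four GPUs per host; rows feed the "shard" mesh dimension and columns the "replicate" (all-reduce) dimension.

```python
# Sketch of the two peer-matrix layouts built in _create_process_groups,
# for an assumed world_size=8, local_size=4 (two hosts, four GPUs each).
world_size, local_size = 8, 4

# use_inter_host_allreduce=True: each sharding group is a contiguous block of
# ranks (here, one whole host), so the replicate/all-reduce groups (the
# columns) pair ranks across hosts.
inter_host = [list(range(i, i + local_size)) for i in range(0, world_size, local_size)]

# Default layout: strided ranks, so each sharding group spans hosts and the
# replicate/all-reduce groups (the columns) are adjacent ranks on one host.
step = world_size // local_size
default = [[step * r + g for r in range(local_size)] for g in range(world_size // local_size)]

print(inter_host)  # [[0, 1, 2, 3], [4, 5, 6, 7]] -> all-reduce over pairs like [0, 4]
print(default)     # [[0, 2, 4, 6], [1, 3, 5, 7]] -> all-reduce over pairs like [0, 1]
```

The trade-off, per the comment in the hunk above: the flag moves the weight-sync all-reduce across hosts in exchange for giving sharding types such as TWRW and GRID contiguous intra-host ranks.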
@@ -808,7 +829,12 @@ def _create_process_groups(
         return mesh, sharding_pg, replica_pg
 
     def _remap_sharding_plan(
-        self, plan: ShardingPlan, rank: int, num_nodes: int
+        self,
+        plan: ShardingPlan,
+        rank: int,
+        step: int,
+        sharding_group_size: int,
+        use_inter_host_allreduce: bool = False,
     ) -> None:
         """
         Remaps the sharding plan to the local replica process group ranks
@@ -822,20 +848,32 @@ def _remap_sharding_plan(
             global_rank (int): The global rank of the current process.
             num_nodes (int): The number of nodes.
         """
-
-        group_start = rank % num_nodes
+        group_start = rank % step
         for key in plan.plan:
             # pyre-ignore[16]
             for _, param_sharding in plan.plan[key].items():
                 new_ranks = []
-                for shard_rank in param_sharding.ranks:
-                    new_ranks.append(shard_rank * num_nodes + group_start)
+                if use_inter_host_allreduce:
+                    group = rank // sharding_group_size
+                    new_ranks = [
+                        shard_rank + (group * sharding_group_size)
+                        for shard_rank in param_sharding.ranks
+                    ]
+                else:
+                    for shard_rank in param_sharding.ranks:
+                        new_ranks.append(shard_rank * step + group_start)
                 param_sharding.ranks = new_ranks
+
                 if isinstance(param_sharding.sharding_spec, EnumerableShardingSpec):
                     shards = param_sharding.sharding_spec.shards
                     if shards is not None:
                         for shard in shards:
-                            shard_rank = shard.placement._rank * num_nodes + group_start
+                            if use_inter_host_allreduce:
+                                shard_rank = shard.placement._rank + (
+                                    (rank // sharding_group_size) * sharding_group_size
+                                )
+                            else:
+                                shard_rank = shard.placement._rank * step + group_start
                             shard.placement = _remote_device(
                                 f"rank:{shard_rank}/cuda:{shard_rank % get_local_size()}"
                             )
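A similarly hedged sketch of the rank remapping in `_remap_sharding_plan`, reusing the assumed 8-rank, sharding_group_size=4 setup: a plan whose shard ranks are the local ranks [0, 1, 2, 3] is remapped for the process at global rank 5.

```python
# Sketch of the rank remapping in _remap_sharding_plan for an assumed
# world_size=8, sharding_group_size=4, current global rank 5, and a plan
# placed on local shard ranks 0..3.
world_size, sharding_group_size = 8, 4
step = world_size // sharding_group_size  # number of replica groups
plan_ranks = [0, 1, 2, 3]
rank = 5

# use_inter_host_allreduce=True: shift the whole block into this rank's
# contiguous sharding group.
group = rank // sharding_group_size
inter_host_ranks = [r + group * sharding_group_size for r in plan_ranks]

# Default: stride the ranks by `step` and offset by this rank's replica
# group index.
group_start = rank % step
default_ranks = [r * step + group_start for r in plan_ranks]

print(inter_host_ranks)  # [4, 5, 6, 7] -> the contiguous peer row containing rank 5
print(default_ranks)     # [1, 3, 5, 7] -> the strided peer row containing rank 5
```

In both cases the remapped ranks land on the peer-matrix row (sharding group) that contains the current rank, which keeps the plan consistent with the device mesh built above.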