
Commit cf48e9e

Author: pytorchbot
Committed: 2025-02-13 nightly release (fd45bdc)
1 parent 8cb817f · commit cf48e9e

12 files changed: +441 -37 lines

torchrec/distributed/comm.py (+8 -3)

@@ -226,9 +226,14 @@ def intra_and_cross_node_pg_2D(

     if _INTRA_PG_2D is None:
         for group_rank in range(step):
-            sharding_pg_peers = [
-                step * r + group_rank for r in range(sharding_group_size)
-            ]
+            if env.use_inter_host_allreduce:
+                # for inter host all reduce, we change the sharding group calculation to be continuous
+                ranks = group_rank * sharding_group_size
+                sharding_pg_peers = list(range(ranks, ranks + sharding_group_size))
+            else:
+                sharding_pg_peers = [
+                    step * r + group_rank for r in range(sharding_group_size)
+                ]
             for group in range(len(sharding_pg_peers) // devices_per_node):
                 intra_pg_peers = sharding_pg_peers[
                     group * devices_per_node : (group + 1) * devices_per_node
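
The hunk above switches the sharding peer calculation when `env.use_inter_host_allreduce` is set: instead of striding ranks across replica groups, each sharding group takes a contiguous block of ranks, so the cross-group allreduce runs between hosts. A minimal standalone sketch of the two layouts (not torchrec code; world_size=8 and sharding_group_size=4 are assumed values):

```python
# Standalone illustration of the two peer layouts above, assuming
# world_size=8 and sharding_group_size=4, i.e. step = 2 replica groups.
world_size = 8
sharding_group_size = 4
step = world_size // sharding_group_size

# default: peers are strided across the replica groups
strided = [
    [step * r + group_rank for r in range(sharding_group_size)]
    for group_rank in range(step)
]

# use_inter_host_allreduce=True: peers are a contiguous block of ranks
contiguous = [
    list(range(start, start + sharding_group_size))
    for start in range(0, world_size, sharding_group_size)
]

print(strided)     # [[0, 2, 4, 6], [1, 3, 5, 7]]
print(contiguous)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```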

torchrec/distributed/embedding.py (+21 -6)

@@ -32,6 +32,7 @@
 from torch.distributed._shard.sharding_spec import EnumerableShardingSpec
 from torch.distributed._tensor import DTensor
 from torch.nn.parallel import DistributedDataParallel
+from torchrec.distributed.comm import get_local_size
 from torchrec.distributed.embedding_sharding import (
     EmbeddingSharding,
     EmbeddingShardingInfo,
@@ -69,13 +70,16 @@
     QuantizedCommCodecs,
     ShardedTensor,
     ShardingEnv,
+    ShardingEnv2D,
     ShardMetadata,
 )
 from torchrec.distributed.utils import (
     add_params_from_parameter_sharding,
     convert_to_fbgemm_types,
+    create_global_tensor_shape_stride_from_metadata,
     maybe_annotate_embedding_event,
     merge_fused_params,
+    none_throws,
     optimizer_type_to_emb_opt_type,
 )
 from torchrec.modules.embedding_configs import (
@@ -534,12 +538,9 @@ def __init__(
                 if table_name in self._table_names
             },
         )
-        # output parameters as DTensor in state dict
-        self._output_dtensor: bool = (
-            fused_params.get("output_dtensor", False) if fused_params else False
-        )
-
         self._env = env
+        # output parameters as DTensor in state dict
+        self._output_dtensor: bool = env.output_dtensor
         # TODO get rid of get_ec_index_dedup global flag
         self._use_index_dedup: bool = use_index_dedup or get_ec_index_dedup()
         sharding_type_to_sharding_infos = create_sharding_infos_by_sharding(
@@ -842,6 +843,14 @@ def _initialize_torch_state(self) -> None:  # noqa
                         )
                     )
                 else:
+                    shape, stride = create_global_tensor_shape_stride_from_metadata(
+                        none_throws(self.module_sharding_plan[table_name]),
+                        (
+                            self._env.node_group_size
+                            if isinstance(self._env, ShardingEnv2D)
+                            else get_local_size(self._env.world_size)
+                        ),
+                    )
                     # empty shard case
                     self._model_parallel_name_to_dtensor[table_name] = (
                         DTensor.from_local(
@@ -851,6 +860,8 @@ def _initialize_torch_state(self) -> None:  # noqa
                             ),
                             device_mesh=self._env.device_mesh,
                             run_check=False,
+                            shape=shape,
+                            stride=stride,
                         )
                     )
             else:
@@ -861,7 +872,11 @@ def _initialize_torch_state(self) -> None:  # noqa
                     ShardedTensor._init_from_local_shards(
                         local_shards,
                         self._name_to_table_size[table_name],
-                        process_group=self._env.process_group,
+                        process_group=(
+                            self._env.sharding_pg
+                            if isinstance(self._env, ShardingEnv2D)
+                            else self._env.process_group
+                        ),
                     )
                 )
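
In the empty-shard branch above, `DTensor.from_local` cannot infer the global tensor size from an empty local shard, so the commit now passes an explicit `shape` and `stride` from `create_global_tensor_shape_stride_from_metadata`. That helper's body is not part of this diff; as a hedged sketch, assuming the global tensor is contiguous and row-major, the stride follows from the global shape like this:

```python
# Hedged sketch: contiguous (row-major) stride for an assumed global shape.
# The real helper in torchrec.distributed.utils derives the shape from the
# ShardMetadata in the module sharding plan; the shape here is hypothetical.
from typing import List, Tuple


def contiguous_stride(shape: List[int]) -> Tuple[int, ...]:
    # Last dim has stride 1; each earlier stride multiplies up the sizes to its right.
    stride = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        stride[i] = stride[i + 1] * shape[i + 1]
    return tuple(stride)


global_shape = [1_000_000, 128]  # hypothetical (num_embeddings, embedding_dim)
print(contiguous_stride(global_shape))  # (128, 1)
```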

torchrec/distributed/model_parallel.py (+50 -12)

@@ -690,6 +690,7 @@ def __init__(
         init_data_parallel: bool = True,
         init_parameters: bool = True,
         data_parallel_wrapper: Optional[DataParallelWrapper] = None,
+        use_inter_host_allreduce: bool = False,
     ) -> None:
         assert device.type == "cuda", "DMPCollection only supports CUDA"
         self._device = device
@@ -705,13 +706,16 @@ def __init__(
                 global_rank=self._global_rank,
                 world_size=world_size,
                 local_size=sharding_group_size,
+                use_inter_host_allreduce=use_inter_host_allreduce,
             )
         )

         self._remap_sharding_plan(
             plan=plan,
             rank=self._global_rank,
-            num_nodes=world_size // sharding_group_size,
+            step=world_size // sharding_group_size,
+            sharding_group_size=sharding_group_size,
+            use_inter_host_allreduce=use_inter_host_allreduce,
         )
         super().__init__(
             module,
@@ -720,6 +724,7 @@ def __init__(
                 sharding_pg=self._sharding_pg,
                 device_mesh=self._device_mesh,
                 node_group_size=node_group_size,
+                use_inter_host_allreduce=use_inter_host_allreduce,
             ),
             device,
             plan,
@@ -768,7 +773,11 @@ def sync(self, include_optimizer_state: bool = True) -> None:
             handle.wait()

     def _create_process_groups(
-        self, global_rank: int, world_size: int, local_size: int
+        self,
+        global_rank: int,
+        world_size: int,
+        local_size: int,
+        use_inter_host_allreduce: bool = False,
     ) -> Tuple[DeviceMesh, dist.ProcessGroup, dist.ProcessGroup]:
         """
         Creates process groups for sharding and replication, the process groups
@@ -784,17 +793,29 @@ def _create_process_groups(
         replication process group, and allreduce process group.
         """
         peer_matrix = []
-        num_nodes = world_size // local_size
+        mesh, sharding_pg, replica_pg = None, None, None

-        for group_rank in range(world_size // local_size):
-            peers = [num_nodes * r + group_rank for r in range(local_size)]
-            peer_matrix.append(peers)
+        logger.warning(f"[2D] Use inter host all reduce: {use_inter_host_allreduce}")
+
+        if use_inter_host_allreduce:
+            # We shard on continuous set of ranks and nodes. Thereby forcing our all reduce to be inter host.
+            # Under this scheme sharding types such as TWRW and GRID will now take
+            # advantage of intra node comms as a result of the continuous set of ranks.
+            peer_matrix = [
+                list(range(i, i + local_size)) for i in range(0, world_size, local_size)
+            ]
+        else:
+            step = world_size // local_size
+            for group_rank in range(world_size // local_size):
+                peers = [step * r + group_rank for r in range(local_size)]
+                peer_matrix.append(peers)

         mesh = DeviceMesh(
             device_type=self._device.type,
             mesh=peer_matrix,
             mesh_dim_names=("replicate", "shard"),
         )
+
         logger.warning(f"[Connection] 2D Device Mesh created: {mesh}")
         sharding_pg = mesh.get_group(mesh_dim="shard")
         logger.warning(
@@ -808,7 +829,12 @@ def _create_process_groups(
         return mesh, sharding_pg, replica_pg

     def _remap_sharding_plan(
-        self, plan: ShardingPlan, rank: int, num_nodes: int
+        self,
+        plan: ShardingPlan,
+        rank: int,
+        step: int,
+        sharding_group_size: int,
+        use_inter_host_allreduce: bool = False,
     ) -> None:
         """
         Remaps the sharding plan to the local replica process group ranks
@@ -822,20 +848,32 @@ def _remap_sharding_plan(
             global_rank (int): The global rank of the current process.
             num_nodes (int): The number of nodes.
         """
-
-        group_start = rank % num_nodes
+        group_start = rank % step
         for key in plan.plan:
             # pyre-ignore[16]
             for _, param_sharding in plan.plan[key].items():
                 new_ranks = []
-                for shard_rank in param_sharding.ranks:
-                    new_ranks.append(shard_rank * num_nodes + group_start)
+                if use_inter_host_allreduce:
+                    group = rank // sharding_group_size
+                    new_ranks = [
+                        shard_rank + (group * sharding_group_size)
+                        for shard_rank in param_sharding.ranks
+                    ]
+                else:
+                    for shard_rank in param_sharding.ranks:
+                        new_ranks.append(shard_rank * step + group_start)
                 param_sharding.ranks = new_ranks
+
                 if isinstance(param_sharding.sharding_spec, EnumerableShardingSpec):
                     shards = param_sharding.sharding_spec.shards
                     if shards is not None:
                         for shard in shards:
-                            shard_rank = shard.placement._rank * num_nodes + group_start
+                            if use_inter_host_allreduce:
+                                shard_rank = shard.placement._rank + (
+                                    (rank // sharding_group_size) * sharding_group_size
+                                )
+                            else:
+                                shard_rank = shard.placement._rank * step + group_start
                             shard.placement = _remote_device(
                                 f"rank:{shard_rank}/cuda:{shard_rank % get_local_size()}"
                             )
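
The `_remap_sharding_plan` change above maps each planner-assigned rank into its replica group with different arithmetic depending on the allreduce layout. A standalone illustration of the two schemes (not torchrec code; world_size=8 and sharding_group_size=4 are assumed values):

```python
# Mirrors the arithmetic in _remap_sharding_plan for assumed sizes:
# world_size=8, sharding_group_size=4, so step = world_size // sharding_group_size = 2.
world_size, sharding_group_size = 8, 4
step = world_size // sharding_group_size
plan_ranks = [0, 1, 2, 3]  # logical ranks the planner assigned within one sharding group


def remap_default(rank: int) -> list:
    # strided layout: logical rank r lands on global rank r * step + group_start
    group_start = rank % step
    return [r * step + group_start for r in plan_ranks]


def remap_inter_host(rank: int) -> list:
    # contiguous layout: each replica group owns a contiguous block of ranks
    group = rank // sharding_group_size
    return [r + group * sharding_group_size for r in plan_ranks]


print(remap_default(5))     # [1, 3, 5, 7] -> rank 5 shards with the odd ranks
print(remap_inter_host(5))  # [4, 5, 6, 7] -> rank 5 shards with its host block
```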

torchrec/distributed/sharding/cw_sharding.py (+2 -1)

@@ -45,6 +45,7 @@
     QuantizedCommCodecs,
     ShardedTensorMetadata,
     ShardingEnv,
+    ShardingType,
     ShardMetadata,
 )
 from torchrec.distributed.utils import none_throws
@@ -191,7 +192,7 @@ def _shard(
         for i, rank in enumerate(info.param_sharding.ranks):
             # Remap rank by number of replica groups if 2D parallelism is enabled
             rank = (
-                rank // self._env.num_sharding_groups()  # pyre-ignore[16]
+                self._env.remap_rank(rank, ShardingType.COLUMN_WISE)  # pyre-ignore[16]
                 if self._is_2D_parallel
                 else rank
             )
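
`ShardingEnv2D.remap_rank` replaces the inline `rank // num_sharding_groups()` division here (and in the grid and table-wise sharders below), but its implementation is not part of this commit. A purely hypothetical sketch, assuming it simply inverts the two plan-remapping layouts shown in model_parallel.py above, and ignoring the `ShardingType` argument the real method takes:

```python
# Hypothetical only: ShardingEnv2D.remap_rank is not shown in this diff.
# Assumes it maps a plan rank back to its logical rank inside the sharding
# group, for either peer layout; the real method also receives a ShardingType.
def remap_rank_sketch(
    rank: int,
    world_size: int,
    sharding_group_size: int,
    use_inter_host_allreduce: bool,
) -> int:
    step = world_size // sharding_group_size  # number of replica groups
    if use_inter_host_allreduce:
        # contiguous layout: group g owns ranks [g * size, (g + 1) * size)
        return rank % sharding_group_size
    # strided layout: logical rank r sits at global rank r * step + group_start
    return rank // step


print(remap_rank_sketch(5, 8, 4, use_inter_host_allreduce=False))  # 2
print(remap_rank_sketch(5, 8, 4, use_inter_host_allreduce=True))   # 1
```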

torchrec/distributed/sharding/grid_sharding.py (+1 -1)

@@ -250,7 +250,7 @@ def _shard(
         # pyre-fixme [6]
         for i, rank in enumerate(info.param_sharding.ranks):
             rank = (
-                rank // self._env.num_sharding_groups()  # pyre-ignore[16]
+                self._env.remap_rank(rank, ShardingType.GRID_SHARD)  # pyre-ignore[16]
                 if self._is_2D_parallel
                 else rank
             )

torchrec/distributed/sharding/tw_sharding.py (+5 -4)

@@ -48,6 +48,7 @@
     ShardedTensorMetadata,
     ShardingEnv,
     ShardingEnv2D,
+    ShardingType,
     ShardMetadata,
 )
 from torchrec.distributed.utils import none_throws
@@ -128,7 +129,7 @@ def _shard(
         )

         dtensor_metadata = None
-        if info.fused_params.get("output_dtensor", False):  # pyre-ignore[16]
+        if self._env.output_dtensor:
             dtensor_metadata = DTensorMetadata(
                 mesh=(
                     self._env.device_mesh["replicate"]  # pyre-ignore[16]
@@ -142,12 +143,12 @@ def _shard(
                 ),
                 stride=info.param.stride(),
             )
-            # to not pass onto TBE
-            info.fused_params.pop("output_dtensor", None)  # pyre-ignore[16]

         rank = (
             # pyre-ignore [16]
-            info.param_sharding.ranks[0] // self._env.num_sharding_groups()
+            self._env.remap_rank(
+                info.param_sharding.ranks[0], ShardingType.TABLE_WISE  # pyre-ignore[16]
+            )
             if self._is_2D_parallel
             else info.param_sharding.ranks[0]
         )

torchrec/distributed/sharding/twrw_sharding.py (+4 -3)

@@ -13,7 +13,7 @@

 import torch
 import torch.distributed as dist
-from torch.distributed._tensor import Shard
+from torch.distributed._tensor import Replicate, Shard
 from torch.distributed.distributed_c10d import get_process_group_ranks
 from torchrec.distributed.comm import (
     get_local_size,
@@ -165,10 +165,11 @@ def _shard(

         dtensor_metadata = None
         if self._env.output_dtensor:
-            placements = (Shard(0),)
             dtensor_metadata = DTensorMetadata(
                 mesh=self._env.device_mesh,
-                placements=placements,
+                placements=(
+                    (Replicate(), Shard(1)) if self._is_2D_parallel else (Shard(1),)
+                ),
                 size=(
                     info.embedding_config.num_embeddings,
                     info.embedding_config.embedding_dim,
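
The TWRW change above only touches the DTensor placements: on the 2D ("replicate", "shard") mesh the table is replicated across the replica dimension and sharded on dim 1 across the shard dimension, and the 1D case now uses `Shard(1)` rather than `Shard(0)`. A minimal sketch of the placement tuples (placement objects can be constructed without an initialized process group):

```python
# Minimal sketch of the placement tuples used for the TWRW DTensor metadata.
from torch.distributed._tensor import Replicate, Shard


def twrw_placements(is_2d_parallel: bool) -> tuple:
    # 2D mesh ("replicate", "shard"): replicate across replica groups, shard dim 1.
    # 1D mesh: only the Shard(1) placement is needed.
    return (Replicate(), Shard(1)) if is_2d_parallel else (Shard(1),)


print(twrw_placements(True))
print(twrw_placements(False))
```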

torchrec/distributed/test_utils/test_model_parallel.py (+2)

@@ -149,6 +149,7 @@ def _test_sharding(
         global_constant_batch: bool = False,
         pooling: PoolingType = PoolingType.SUM,
         data_type: DataType = DataType.FP32,
+        use_inter_host_allreduce: bool = False,
     ) -> None:
         self._build_tables_and_groups(data_type=data_type)
         self._run_multi_process_test(
@@ -170,6 +171,7 @@ def _test_sharding(
             apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
             variable_batch_per_feature=variable_batch_per_feature,
             global_constant_batch=global_constant_batch,
+            use_inter_host_allreduce=use_inter_host_allreduce,
         )

torchrec/distributed/test_utils/test_sharding.py (+2)

@@ -315,6 +315,7 @@ def sharding_single_rank_test(
     global_constant_batch: bool = False,
     world_size_2D: Optional[int] = None,
     node_group_size: Optional[int] = None,
+    use_inter_host_allreduce: bool = False,
     input_type: str = "kjt",  # "kjt" or "td"
 ) -> None:
     with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
@@ -432,6 +433,7 @@ def sharding_single_rank_test(
                 plan=plan,
                 sharders=sharders,
                 device=ctx.device,
+                use_inter_host_allreduce=use_inter_host_allreduce,
             )
         else:
             local_model = DistributedModelParallel(
