
Commit cb6b69a

gnahzg authored and facebook-github-bot committed
Support Hybrid Sharding for DI (#1786)
Summary:
Pull Request resolved: #1786

Minimum change to support hybrid sharding for DI.

Context: https://docs.google.com/document/d/1Y0H1TntfZkW5Cgw0_B_gydC9qPIQtpirchVtzSADid8/edit#heading=h.z2j5qijdvagp

TLDR: DI needs to shard tables from the same EC such that some tables go to CPU and some go to GPU. Currently we only support treating all hosts/devices as a single env. The changes below enable sharding by device group. Most of the implementation is copied from D54570308, with adjustments to shard according to device instead of sharding type.

TODO: generalize support for hybrid sharding by
(1) explicitly supporting sharding plans with a different world_size per device group,
(2) cleaning up the code, and
(3) supporting generation of sharding plans by device group and merging the resulting plans.

Reviewed By: IvanKobzarev

Differential Revision: D54805360

fbshipit-source-id: d7b8c7e0232d6de457fc991f55c4c6bd3d53812c
1 parent 0ec0b2e commit cb6b69a

4 files changed: +152 −6
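For orientation before the diffs: the core idea is that `env` may now be a mapping from device type to ShardingEnv, and each group of tables is sharded against the env for the device its shards are placed on. Below is a minimal sketch of the intended call pattern, not part of the commit; `quantized_model`, `sharder`, and `plan` are placeholders, set up the way the new test at the bottom of this diff does.

import torch

from torchrec.distributed.shard import _shard_modules
from torchrec.distributed.types import ShardingEnv

# One ShardingEnv per device group, keyed by the torch.device type string
# ("cpu"/"cuda") that the table placements resolve to.
env_dict = {
    "cpu": ShardingEnv.from_local(3, 0),   # world_size=3, rank=0 for the CPU group
    "cuda": ShardingEnv.from_local(2, 0),  # world_size=2, rank=0 for the GPU group
}

# When env is a dict, a manually constructed ShardingPlan is required
# (the planner path asserts a single ShardingEnv; see shard.py below).
sharded_model = _shard_modules(
    module=quantized_model,  # placeholder: quantized EC model as in the test below
    sharders=[sharder],      # placeholder: QuantEmbeddingCollectionSharder()
    device=torch.device("cpu"),
    plan=plan,               # placeholder: manual plan placing tables on cpu/cuda
    env=env_dict,
)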

torchrec/distributed/quant_embedding.py (+35 −4)
@@ -10,7 +10,7 @@

 from collections import defaultdict, deque
 from dataclasses import dataclass
-from typing import Any, cast, Dict, List, Optional, Tuple, Type
+from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union

 import torch
 from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
@@ -80,6 +80,24 @@ def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
         ctx.record_stream(stream)


+def get_device_from_parameter_sharding(ps: ParameterSharding) -> str:
+    # pyre-ignore
+    return ps.sharding_spec.shards[0].placement.device().type
+
+
+def get_device_from_sharding_type(
+    emb_shard_infos: List[EmbeddingShardingInfo],
+) -> str:
+    res = list(
+        {
+            get_device_from_parameter_sharding(ps.param_sharding)
+            for ps in emb_shard_infos
+        }
+    )
+    assert len(res) == 1, "All shards should be on the same type of device"
+    return res[0]
+
+
 def create_infer_embedding_sharding(
     sharding_type: str,
     sharding_infos: List[EmbeddingShardingInfo],
@@ -336,19 +354,25 @@ def __init__(
         self,
         module: QuantEmbeddingCollection,
         table_name_to_parameter_sharding: Dict[str, ParameterSharding],
-        env: ShardingEnv,
+        # TODO: Consolidate to use Dict[str, ShardingEnv]
+        env: Union[
+            ShardingEnv, Dict[str, ShardingEnv]
+        ],  # Support hybrid sharding for DI
         fused_params: Optional[Dict[str, Any]] = None,
         device: Optional[torch.device] = None,
     ) -> None:
         super().__init__()

         self._embedding_configs: List[EmbeddingConfig] = module.embedding_configs()

+        self._is_hybrid_sharding: bool = isinstance(env, Dict)
+
         self._sharding_type_to_sharding_infos: Dict[
             str, List[EmbeddingShardingInfo]
         ] = create_sharding_infos_by_sharding(
             module, table_name_to_parameter_sharding, fused_params
         )
+
         self._sharding_type_to_sharding: Dict[
             str,
             EmbeddingSharding[
@@ -359,7 +383,14 @@ def __init__(
             ],
         ] = {
             sharding_type: create_infer_embedding_sharding(
-                sharding_type, embedding_confings, env
+                sharding_type,
+                embedding_confings,
+                (
+                    env
+                    if not self._is_hybrid_sharding
+                    # pyre-ignore
+                    else env[get_device_from_sharding_type(embedding_confings)]
+                ),
             )
             for sharding_type, embedding_confings in self._sharding_type_to_sharding_infos.items()
         }
@@ -732,7 +763,7 @@ def shard(
         self,
         module: QuantEmbeddingCollection,
         params: Dict[str, ParameterSharding],
-        env: ShardingEnv,
+        env: Union[ShardingEnv, Dict[str, ShardingEnv]],
         device: Optional[torch.device] = None,
     ) -> ShardedQuantEmbeddingCollection:
         fused_params = self.fused_params if self.fused_params else {}
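A side note on how the device-group key is derived (illustration only, not part of the commit): a shard placement such as "rank:0/cuda:0" resolves, via torch's private _remote_device helper, to a torch.device whose .type is "cuda", and that string is what indexes the env dict in the hunk above.

import torch
from torch.distributed._remote_device import _remote_device

# ParameterSharding.sharding_spec.shards[0].placement is a _remote_device;
# .device().type yields the device-group key ("cpu" or "cuda").
placement = _remote_device("rank:0/cuda:0")
assert placement.device().type == "cuda"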

torchrec/distributed/shard.py (+7 −1)
@@ -189,7 +189,10 @@ def init_weights(m):

 def _shard_modules(  # noqa: C901
     module: nn.Module,
-    env: Optional[ShardingEnv] = None,
+    # TODO: Consolidate to using Dict[str, ShardingEnv]
+    env: Optional[
+        Union[ShardingEnv, Dict[str, ShardingEnv]]
+    ] = None,  # Support hybrid sharding
     device: Optional[torch.device] = None,
     plan: Optional[ShardingPlan] = None,
     sharders: Optional[List[ModuleSharder[torch.nn.Module]]] = None,
@@ -220,6 +223,9 @@ def _shard_modules(  # noqa: C901
     }

     if plan is None:
+        assert isinstance(
+            env, ShardingEnv
+        ), "Currently hybrid sharding only support use manual sharding plan"
         planner = EmbeddingShardingPlanner(
             topology=Topology(
                 local_world_size=get_local_size(env.world_size),
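In other words, the planner path still assumes a single ShardingEnv (it needs one world_size for the Topology), so hybrid sharding currently requires a pre-built plan. A quick sketch of what happens otherwise (illustration only; `quantized_model`, `sharder`, and `env_dict` are the placeholders from the earlier sketch):

# Passing a dict env without a plan trips the new assert above.
try:
    _shard_modules(module=quantized_model, sharders=[sharder], env=env_dict)
except AssertionError as err:
    print(err)  # "Currently hybrid sharding only support use manual sharding plan"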

torchrec/distributed/test_utils/infer_utils.py (+1 −1)
@@ -382,7 +382,7 @@ def shard(
         self,
         module: QuantEmbeddingCollection,
         params: Dict[str, ParameterSharding],
-        env: ShardingEnv,
+        env: Union[Dict[str, ShardingEnv], ShardingEnv],
         device: Optional[torch.device] = None,
     ) -> ShardedQuantEmbeddingCollection:
         fused_params = self.fused_params if self.fused_params else {}
New test file (+109 −0)

@@ -0,0 +1,109 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

#!/usr/bin/env python3

import unittest

import torch
from torchrec import EmbeddingCollection, EmbeddingConfig
from torchrec.distributed.quant_embedding import QuantEmbeddingCollectionSharder
from torchrec.distributed.shard import _shard_modules
from torchrec.distributed.sharding_plan import (
    construct_module_sharding_plan,
    row_wise,
    table_wise,
)
from torchrec.distributed.test_utils.infer_utils import KJTInputWrapper, quantize
from torchrec.distributed.types import ShardingEnv, ShardingPlan


class InferHeteroShardingsTest(unittest.TestCase):
    # pyre-ignore
    @unittest.skipIf(
        torch.cuda.device_count() <= 3,
        "Not enough GPUs available",
    )
    def test_sharder_different_world_sizes(self) -> None:
        num_embeddings = 10
        emb_dim = 16
        world_size = 2
        local_size = 1
        tables = [
            EmbeddingConfig(
                num_embeddings=num_embeddings,
                embedding_dim=emb_dim,
                name=f"table_{i}",
                feature_names=[f"feature_{i}"],
            )
            for i in range(3)
        ]
        model = KJTInputWrapper(
            module_kjt_input=torch.nn.Sequential(
                EmbeddingCollection(
                    tables=tables,
                    device=torch.device("cpu"),
                )
            )
        )
        non_sharded_model = quantize(
            model,
            inplace=False,
            quant_state_dict_split_scale_bias=True,
            weight_dtype=torch.qint8,
        )
        sharder = QuantEmbeddingCollectionSharder()
        module_plan = construct_module_sharding_plan(
            non_sharded_model._module_kjt_input[0],
            per_param_sharding={
                "table_0": row_wise(([20, 10, 100], "cpu")),
                "table_1": table_wise(rank=0, device="cuda"),
                "table_2": table_wise(rank=1, device="cuda"),
            },
            # pyre-ignore
            sharder=sharder,
            local_size=local_size,
            world_size=world_size,
        )
        plan = ShardingPlan(plan={"_module_kjt_input.0": module_plan})
        env_dict = {
            "cpu": ShardingEnv.from_local(
                3,
                0,
            ),
            "cuda": ShardingEnv.from_local(
                2,
                0,
            ),
        }
        sharded_model = _shard_modules(
            module=non_sharded_model,
            # pyre-ignore
            sharders=[sharder],
            device=torch.device("cpu"),
            plan=plan,
            env=env_dict,
        )
        self.assertTrue(hasattr(sharded_model._module_kjt_input[0], "_lookups"))
        self.assertTrue(len(sharded_model._module_kjt_input[0]._lookups) == 2)
        for i, env in enumerate(env_dict.values()):
            self.assertTrue(
                hasattr(
                    sharded_model._module_kjt_input[0]._lookups[i],
                    "_embedding_lookups_per_rank",
                )
            )
            self.assertTrue(
                len(
                    sharded_model._module_kjt_input[0]
                    ._lookups[i]
                    ._embedding_lookups_per_rank
                )
                == env.world_size
            )
