
Commit dfc82b5

Author: pytorchbot
Commit message: 2025-03-08 nightly release (411876a)
1 parent dd8636c, commit dfc82b5

File tree: 2 files changed (+429, -16 lines)


torchrec/distributed/itep_embeddingbag.py (+125, -12)
@@ -7,10 +7,15 @@

 # pyre-strict

+from collections import defaultdict
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Type, Union
+from enum import Enum
+from typing import Dict, List, Optional, OrderedDict, Tuple, Type, Union

 import torch
+from torch import nn
+from torch.nn.modules.module import _IncompatibleKeys
+from torch.nn.parallel import DistributedDataParallel

 from torchrec.distributed.embedding_types import (
     BaseEmbeddingSharder,
@@ -30,8 +35,9 @@
     ShardingEnv,
     ShardingType,
 )
+from torchrec.distributed.utils import filter_state_dict
 from torchrec.modules.itep_embedding_modules import ITEPEmbeddingBagCollection
-from torchrec.modules.itep_modules import GenericITEPModule
+from torchrec.modules.itep_modules import GenericITEPModule, RowwiseShardedITEPModule
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor

@@ -40,6 +46,19 @@ class ITEPEmbeddingBagCollectionContext(EmbeddingBagCollectionContext):
     is_reindexed: bool = False


+class ShardingTypeGroup(Enum):
+    CW_GROUP = "column_wise_group"
+    RW_GROUP = "row_wise_group"
+
+
+SHARDING_TYPE_TO_GROUP: Dict[str, ShardingTypeGroup] = {
+    ShardingType.ROW_WISE.value: ShardingTypeGroup.RW_GROUP,
+    ShardingType.TABLE_ROW_WISE.value: ShardingTypeGroup.RW_GROUP,
+    ShardingType.COLUMN_WISE.value: ShardingTypeGroup.CW_GROUP,
+    ShardingType.TABLE_WISE.value: ShardingTypeGroup.CW_GROUP,
+}
+
+
 class ShardedITEPEmbeddingBagCollection(
     ShardedEmbeddingModule[
         KJTList,
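For orientation, here is a minimal, runnable sketch of the grouping table this hunk introduces. The string keys are stand-ins for the TorchRec ShardingType.*.value strings, so the snippet runs without torchrec installed:

from enum import Enum


class ShardingTypeGroup(Enum):
    CW_GROUP = "column_wise_group"
    RW_GROUP = "row_wise_group"


# Stand-in keys for ShardingType.ROW_WISE.value, TABLE_ROW_WISE.value, etc.
SHARDING_TYPE_TO_GROUP = {
    "row_wise": ShardingTypeGroup.RW_GROUP,
    "table_row_wise": ShardingTypeGroup.RW_GROUP,
    "column_wise": ShardingTypeGroup.CW_GROUP,
    "table_wise": ShardingTypeGroup.CW_GROUP,
}

# Column-wise and table-wise sharding share one ITEP path; row-wise and
# table-row-wise sharding share the other.
assert SHARDING_TYPE_TO_GROUP["table_wise"] is ShardingTypeGroup.CW_GROUP
assert SHARDING_TYPE_TO_GROUP["table_row_wise"] is ShardingTypeGroup.RW_GROUP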
@@ -75,10 +94,34 @@ def __init__(
             )
         )

+        self.table_name_to_sharding_type: Dict[str, str] = {}
+        for table_name in table_name_to_parameter_sharding.keys():
+            self.table_name_to_sharding_type[table_name] = (
+                table_name_to_parameter_sharding[table_name].sharding_type
+            )
+
+        # Group lookups, table_name_to_unpruned_hash_sizes by sharding type and pass to separate itep modules
+        (grouped_lookups, grouped_table_unpruned_size_map) = (
+            self._group_lookups_and_table_unpruned_size_map(
+                module._itep_module.table_name_to_unpruned_hash_sizes,
+            )
+        )
+
         # Instantiate ITEP Module in sharded case, re-using metadata from non-sharded case
         self._itep_module: GenericITEPModule = GenericITEPModule(
-            table_name_to_unpruned_hash_sizes=module._itep_module.table_name_to_unpruned_hash_sizes,
-            lookups=self._embedding_bag_collection._lookups,
+            table_name_to_unpruned_hash_sizes=grouped_table_unpruned_size_map[
+                ShardingTypeGroup.CW_GROUP
+            ],
+            lookups=grouped_lookups[ShardingTypeGroup.CW_GROUP],
+            pruning_interval=module._itep_module.pruning_interval,
+            enable_pruning=module._itep_module.enable_pruning,
+        )
+        self._rowwise_itep_module: RowwiseShardedITEPModule = RowwiseShardedITEPModule(
+            table_name_to_sharding_type=self.table_name_to_sharding_type,
+            table_name_to_unpruned_hash_sizes=grouped_table_unpruned_size_map[
+                ShardingTypeGroup.RW_GROUP
+            ],
+            lookups=grouped_lookups[ShardingTypeGroup.RW_GROUP],
             pruning_interval=module._itep_module.pruning_interval,
             enable_pruning=module._itep_module.enable_pruning,
         )
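The constructor now builds one GenericITEPModule for the column-wise group and one RowwiseShardedITEPModule for the row-wise group, each fed only its group's lookups and unpruned-size map. A minimal sketch of that grouping step, reusing ShardingTypeGroup and SHARDING_TYPE_TO_GROUP from the first sketch (the lookup strings and sharding assignments here are hypothetical stand-ins for nn.Module lookups):

from collections import defaultdict

# One sharding type per lookup, as in
# self._embedding_bag_collection._sharding_types / ._lookups.
sharding_types = ["table_wise", "row_wise"]
lookups = ["tw_lookup", "rw_lookup"]  # stand-ins for nn.Module lookups

grouped_lookups = defaultdict(list)
for sharding_type, lookup in zip(sharding_types, lookups):
    grouped_lookups[SHARDING_TYPE_TO_GROUP[sharding_type]].append(lookup)

# Each ITEP module only ever sees the lookups for its own group.
assert grouped_lookups[ShardingTypeGroup.CW_GROUP] == ["tw_lookup"]
assert grouped_lookups[ShardingTypeGroup.RW_GROUP] == ["rw_lookup"]

The real helper, added later in this diff, additionally unwraps DistributedDataParallel wrappers and splits table_name_to_unpruned_hash_sizes along the same group boundaries.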
@@ -106,8 +149,16 @@ def input_dist(
         return self._embedding_bag_collection.input_dist(ctx, features)

     def _reindex(self, dist_input: KJTList) -> KJTList:
-        for i in range(len(dist_input)):
-            remapped_kjt = self._itep_module(dist_input[i], self._iter.item())
+        for i, (sharding, features) in enumerate(
+            zip(
+                self._embedding_bag_collection._sharding_types,
+                dist_input,
+            )
+        ):
+            if SHARDING_TYPE_TO_GROUP[sharding] == ShardingTypeGroup.CW_GROUP:
+                remapped_kjt = self._itep_module(features, self._iter.item())
+            else:
+                remapped_kjt = self._rowwise_itep_module(features, self._iter.item())
             dist_input[i] = remapped_kjt
         return dist_input
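Both _reindex and compute_and_output_dist (next hunk) now dispatch each incoming KJT to the ITEP module matching its sharding group. A minimal sketch of that routing, again reusing the stand-in mapping from the first sketch; remap_cw and remap_rw are hypothetical stubs for GenericITEPModule and RowwiseShardedITEPModule:

from typing import List


def remap_cw(features: str, iteration: int) -> str:  # stub for self._itep_module
    return f"cw({features}@{iteration})"


def remap_rw(features: str, iteration: int) -> str:  # stub for self._rowwise_itep_module
    return f"rw({features}@{iteration})"


def reindex(sharding_types: List[str], dist_input: List[str], iteration: int) -> List[str]:
    # Route each per-sharding input to the matching ITEP module, in place.
    for i, (sharding, features) in enumerate(zip(sharding_types, dist_input)):
        if SHARDING_TYPE_TO_GROUP[sharding] is ShardingTypeGroup.CW_GROUP:
            dist_input[i] = remap_cw(features, iteration)
        else:
            dist_input[i] = remap_rw(features, iteration)
    return dist_input


assert reindex(["table_wise", "row_wise"], ["kjt0", "kjt1"], 0) == [
    "cw(kjt0@0)",
    "rw(kjt1@0)",
]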

@@ -136,8 +187,16 @@ def compute_and_output_dist(
         self, ctx: ITEPEmbeddingBagCollectionContext, input: KJTList
     ) -> LazyAwaitable[KeyedTensor]:
         # Insert forward() function of GenericITEPModule into compute_and_output_dist()
-        for i in range(len(input)):
-            remapped_kjt = self._itep_module(input[i], self._iter.item())
+        for i, (sharding, features) in enumerate(
+            zip(
+                self._embedding_bag_collection._sharding_types,
+                input,
+            )
+        ):
+            if SHARDING_TYPE_TO_GROUP[sharding] == ShardingTypeGroup.CW_GROUP:
+                remapped_kjt = self._itep_module(features, self._iter.item())
+            else:
+                remapped_kjt = self._rowwise_itep_module(features, self._iter.item())
             input[i] = remapped_kjt
         self._iter += 1
         ebc_awaitable = self._embedding_bag_collection.compute_and_output_dist(
@@ -148,6 +207,63 @@ def compute_and_output_dist(
     def create_context(self) -> ITEPEmbeddingBagCollectionContext:
         return ITEPEmbeddingBagCollectionContext()

+    # pyre-fixme[14]: `load_state_dict` overrides method defined in `Module`
+    #  inconsistently.
+    def load_state_dict(
+        self,
+        state_dict: "OrderedDict[str, torch.Tensor]",
+        strict: bool = True,
+    ) -> _IncompatibleKeys:
+        missing_keys = []
+        unexpected_keys = []
+        self._iter = state_dict["_iter"]
+        for name, child_module in self._modules.items():
+            if child_module is not None:
+                missing, unexpected = child_module.load_state_dict(
+                    filter_state_dict(state_dict, name),
+                    strict,
+                )
+                missing_keys.extend(missing)
+                unexpected_keys.extend(unexpected)
+        return _IncompatibleKeys(
+            missing_keys=missing_keys, unexpected_keys=unexpected_keys
+        )
+
+    def _group_lookups_and_table_unpruned_size_map(
+        self, table_name_to_unpruned_hash_sizes: Dict[str, int]
+    ) -> Tuple[
+        Dict[ShardingTypeGroup, List[nn.Module]],
+        Dict[ShardingTypeGroup, Dict[str, int]],
+    ]:
+        """
+        Group ebc lookups and table_name_to_unpruned_hash_sizes by sharding types.
+        CW and TW are grouped into CW_GROUP, RW and TWRW are grouped into RW_GROUP.
+
+        Return a tuple of (grouped_lookups, grouped_table_unpruned_size_map).
+        """
+        grouped_lookups: Dict[ShardingTypeGroup, List[nn.Module]] = defaultdict(list)
+        grouped_table_unpruned_size_map: Dict[ShardingTypeGroup, Dict[str, int]] = (
+            defaultdict(dict)
+        )
+        for sharding_type, lookup in zip(
+            self._embedding_bag_collection._sharding_types,
+            self._embedding_bag_collection._lookups,
+        ):
+            sharding_group = SHARDING_TYPE_TO_GROUP[sharding_type]
+            # group lookups
+            grouped_lookups[sharding_group].append(lookup)
+            # group table_name_to_unpruned_hash_sizes
+            while isinstance(lookup, DistributedDataParallel):
+                lookup = lookup.module
+            for emb_config in lookup.grouped_configs:
+                for table in emb_config.embedding_tables:
+                    if table.name in table_name_to_unpruned_hash_sizes.keys():
+                        grouped_table_unpruned_size_map[sharding_group][table.name] = (
+                            table_name_to_unpruned_hash_sizes[table.name]
+                        )
+
+        return grouped_lookups, grouped_table_unpruned_size_map
+

 class ITEPEmbeddingBagCollectionSharder(
     BaseEmbeddingSharder[ITEPEmbeddingBagCollection]
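The new load_state_dict restores the _iter counter directly, then hands each child module the subset of entries under its own name via torchrec's filter_state_dict. A self-contained sketch of that prefix filtering; filter_by_prefix is my own minimal stand-in for what filter_state_dict does, not the library code:

from collections import OrderedDict


def filter_by_prefix(
    state_dict: "OrderedDict[str, object]", name: str
) -> "OrderedDict[str, object]":
    # Keep entries under "<name>." and strip the prefix, so the child
    # module sees keys relative to itself.
    prefix = name + "."
    return OrderedDict(
        (key[len(prefix):], value)
        for key, value in state_dict.items()
        if key.startswith(prefix)
    )


sd = OrderedDict(
    [
        ("_iter", 7),  # consumed directly by the parent, not a child key
        ("_itep_module.some_buffer", "b0"),
        ("_rowwise_itep_module.some_buffer", "b1"),
    ]
)
assert filter_by_prefix(sd, "_itep_module") == OrderedDict([("some_buffer", "b0")])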
@@ -196,8 +312,5 @@ def module_type(self) -> Type[ITEPEmbeddingBagCollection]:
         return ITEPEmbeddingBagCollection

     def sharding_types(self, compute_device_type: str) -> List[str]:
-        types = [
-            ShardingType.COLUMN_WISE.value,
-            ShardingType.TABLE_WISE.value,
-        ]
+        types = list(SHARDING_TYPE_TO_GROUP.keys())
         return types
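Net effect on the sharder: the supported sharding types are now derived from the grouping table instead of being hard-coded, which adds ROW_WISE and TABLE_ROW_WISE alongside the previously supported COLUMN_WISE and TABLE_WISE. With the stand-in mapping from the first sketch:

# Dict insertion order is preserved, so the list follows the mapping order above.
types = list(SHARDING_TYPE_TO_GROUP.keys())
assert types == ["row_wise", "table_row_wise", "column_wise", "table_wise"]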
