 # pyre-strict
 
+from collections import defaultdict
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Type, Union
+from enum import Enum
+from typing import Dict, List, Optional, OrderedDict, Tuple, Type, Union
 
 import torch
+from torch import nn
+from torch.nn.modules.module import _IncompatibleKeys
+from torch.nn.parallel import DistributedDataParallel
 
 from torchrec.distributed.embedding_types import (
     BaseEmbeddingSharder,
[...]
     ShardingEnv,
     ShardingType,
 )
+from torchrec.distributed.utils import filter_state_dict
 from torchrec.modules.itep_embedding_modules import ITEPEmbeddingBagCollection
-from torchrec.modules.itep_modules import GenericITEPModule
+from torchrec.modules.itep_modules import GenericITEPModule, RowwiseShardedITEPModule
 
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
 
@@ -40,6 +46,19 @@ class ITEPEmbeddingBagCollectionContext(EmbeddingBagCollectionContext):
     is_reindexed: bool = False
 
 
+class ShardingTypeGroup(Enum):
+    CW_GROUP = "column_wise_group"
+    RW_GROUP = "row_wise_group"
+
+
+SHARDING_TYPE_TO_GROUP: Dict[str, ShardingTypeGroup] = {
+    ShardingType.ROW_WISE.value: ShardingTypeGroup.RW_GROUP,
+    ShardingType.TABLE_ROW_WISE.value: ShardingTypeGroup.RW_GROUP,
+    ShardingType.COLUMN_WISE.value: ShardingTypeGroup.CW_GROUP,
+    ShardingType.TABLE_WISE.value: ShardingTypeGroup.CW_GROUP,
+}
+
+
 class ShardedITEPEmbeddingBagCollection(
     ShardedEmbeddingModule[
         KJTList,
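The mapping above is the pivot for the rest of the diff: each of the four TorchRec sharding types resolves to one of two ITEP groups, and the dispatch below keys on that group. A minimal, self-contained sketch of the routing; the string keys here are stand-ins mirroring the `ShardingType.*.value` strings, not the torchrec enum itself:

```python
from enum import Enum
from typing import Dict


class ShardingTypeGroup(Enum):
    CW_GROUP = "column_wise_group"
    RW_GROUP = "row_wise_group"


# Stand-in keys mirroring ShardingType.*.value from torchrec.
SHARDING_TYPE_TO_GROUP: Dict[str, ShardingTypeGroup] = {
    "row_wise": ShardingTypeGroup.RW_GROUP,
    "table_row_wise": ShardingTypeGroup.RW_GROUP,
    "column_wise": ShardingTypeGroup.CW_GROUP,
    "table_wise": ShardingTypeGroup.CW_GROUP,
}

# RW and TWRW tables are handled by the row-wise ITEP module;
# CW and TW tables keep using the generic ITEP module.
assert SHARDING_TYPE_TO_GROUP["table_row_wise"] is ShardingTypeGroup.RW_GROUP
assert SHARDING_TYPE_TO_GROUP["table_wise"] is ShardingTypeGroup.CW_GROUP
```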
@@ -75,10 +94,34 @@ def __init__(
             )
         )
 
+        self.table_name_to_sharding_type: Dict[str, str] = {}
+        for table_name in table_name_to_parameter_sharding.keys():
+            self.table_name_to_sharding_type[table_name] = (
+                table_name_to_parameter_sharding[table_name].sharding_type
+            )
+
+        # Group lookups and table_name_to_unpruned_hash_sizes by sharding type; pass each group to its own ITEP module.
+        (grouped_lookups, grouped_table_unpruned_size_map) = (
+            self._group_lookups_and_table_unpruned_size_map(
+                module._itep_module.table_name_to_unpruned_hash_sizes,
+            )
+        )
+
         # Instantiate ITEP Module in sharded case, re-using metadata from non-sharded case
         self._itep_module: GenericITEPModule = GenericITEPModule(
-            table_name_to_unpruned_hash_sizes=module._itep_module.table_name_to_unpruned_hash_sizes,
-            lookups=self._embedding_bag_collection._lookups,
+            table_name_to_unpruned_hash_sizes=grouped_table_unpruned_size_map[
+                ShardingTypeGroup.CW_GROUP
+            ],
+            lookups=grouped_lookups[ShardingTypeGroup.CW_GROUP],
+            pruning_interval=module._itep_module.pruning_interval,
+            enable_pruning=module._itep_module.enable_pruning,
+        )
+        self._rowwise_itep_module: RowwiseShardedITEPModule = RowwiseShardedITEPModule(
+            table_name_to_sharding_type=self.table_name_to_sharding_type,
+            table_name_to_unpruned_hash_sizes=grouped_table_unpruned_size_map[
+                ShardingTypeGroup.RW_GROUP
+            ],
+            lookups=grouped_lookups[ShardingTypeGroup.RW_GROUP],
             pruning_interval=module._itep_module.pruning_interval,
             enable_pruning=module._itep_module.enable_pruning,
         )
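The constructor change relies on `_sharding_types` and `_lookups` being parallel lists on the inner sharded embedding bag collection. A small sketch of the grouping it performs, with plain strings standing in for the lookup modules and a local stand-in for `ShardingTypeGroup`:

```python
from collections import defaultdict
from enum import Enum


class Group(Enum):  # stand-in for ShardingTypeGroup
    CW_GROUP = "column_wise_group"
    RW_GROUP = "row_wise_group"


TYPE_TO_GROUP = {
    "row_wise": Group.RW_GROUP,
    "table_row_wise": Group.RW_GROUP,
    "column_wise": Group.CW_GROUP,
    "table_wise": Group.CW_GROUP,
}

# Parallel lists, as assumed on the inner sharded collection.
sharding_types = ["column_wise", "row_wise", "table_wise"]
lookups = ["lookup_cw", "lookup_rw", "lookup_tw"]  # stand-ins for nn.Modules

grouped_lookups = defaultdict(list)
for sharding_type, lookup in zip(sharding_types, lookups):
    grouped_lookups[TYPE_TO_GROUP[sharding_type]].append(lookup)

assert grouped_lookups[Group.CW_GROUP] == ["lookup_cw", "lookup_tw"]
assert grouped_lookups[Group.RW_GROUP] == ["lookup_rw"]
```

Each group then feeds its own module: the CW group goes to `GenericITEPModule`, the RW group to `RowwiseShardedITEPModule`, both reusing the pruning settings from the unsharded module.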
@@ -106,8 +149,16 @@ def input_dist(
         return self._embedding_bag_collection.input_dist(ctx, features)
 
     def _reindex(self, dist_input: KJTList) -> KJTList:
-        for i in range(len(dist_input)):
-            remapped_kjt = self._itep_module(dist_input[i], self._iter.item())
+        for i, (sharding, features) in enumerate(
+            zip(
+                self._embedding_bag_collection._sharding_types,
+                dist_input,
+            )
+        ):
+            if SHARDING_TYPE_TO_GROUP[sharding] == ShardingTypeGroup.CW_GROUP:
+                remapped_kjt = self._itep_module(features, self._iter.item())
+            else:
+                remapped_kjt = self._rowwise_itep_module(features, self._iter.item())
             dist_input[i] = remapped_kjt
         return dist_input
 
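The new `_reindex` walks the distributed KJTs in sharding order and picks the ITEP module by group; `compute_and_output_dist` below applies the identical pattern. A toy version of the dispatch, with strings standing in for `KeyedJaggedTensor`s and labeling functions replacing the two modules:

```python
CW_TYPES = {"column_wise", "table_wise"}


def generic_itep(kjt: str, cur_iter: int) -> str:  # stand-in for _itep_module
    return f"cw_remap({kjt}, iter={cur_iter})"


def rowwise_itep(kjt: str, cur_iter: int) -> str:  # stand-in for _rowwise_itep_module
    return f"rw_remap({kjt}, iter={cur_iter})"


sharding_types = ["column_wise", "row_wise"]
dist_input = ["kjt_0", "kjt_1"]

# Mirror the per-shard dispatch: group decides which module remaps the KJT.
for i, (sharding, features) in enumerate(zip(sharding_types, dist_input)):
    if sharding in CW_TYPES:
        dist_input[i] = generic_itep(features, 7)
    else:
        dist_input[i] = rowwise_itep(features, 7)

assert dist_input == ["cw_remap(kjt_0, iter=7)", "rw_remap(kjt_1, iter=7)"]
```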
@@ -136,8 +187,16 @@ def compute_and_output_dist(
         self, ctx: ITEPEmbeddingBagCollectionContext, input: KJTList
     ) -> LazyAwaitable[KeyedTensor]:
         # Insert forward() function of GenericITEPModule into compute_and_output_dist()
-        for i in range(len(input)):
-            remapped_kjt = self._itep_module(input[i], self._iter.item())
+        for i, (sharding, features) in enumerate(
+            zip(
+                self._embedding_bag_collection._sharding_types,
+                input,
+            )
+        ):
+            if SHARDING_TYPE_TO_GROUP[sharding] == ShardingTypeGroup.CW_GROUP:
+                remapped_kjt = self._itep_module(features, self._iter.item())
+            else:
+                remapped_kjt = self._rowwise_itep_module(features, self._iter.item())
             input[i] = remapped_kjt
         self._iter += 1
         ebc_awaitable = self._embedding_bag_collection.compute_and_output_dist(
@@ -148,6 +207,63 @@ def compute_and_output_dist(
     def create_context(self) -> ITEPEmbeddingBagCollectionContext:
         return ITEPEmbeddingBagCollectionContext()
 
+    # pyre-fixme[14]: `load_state_dict` overrides method defined in `Module`
+    #  inconsistently.
+    def load_state_dict(
+        self,
+        state_dict: "OrderedDict[str, torch.Tensor]",
+        strict: bool = True,
+    ) -> _IncompatibleKeys:
+        missing_keys = []
+        unexpected_keys = []
+        self._iter = state_dict["_iter"]
+        for name, child_module in self._modules.items():
+            if child_module is not None:
+                missing, unexpected = child_module.load_state_dict(
+                    filter_state_dict(state_dict, name),
+                    strict,
+                )
+                missing_keys.extend(missing)
+                unexpected_keys.extend(unexpected)
+        return _IncompatibleKeys(
+            missing_keys=missing_keys, unexpected_keys=unexpected_keys
+        )
+
+    def _group_lookups_and_table_unpruned_size_map(
+        self, table_name_to_unpruned_hash_sizes: Dict[str, int]
+    ) -> Tuple[
+        Dict[ShardingTypeGroup, List[nn.Module]],
+        Dict[ShardingTypeGroup, Dict[str, int]],
+    ]:
+        """
+        Group ebc lookups and table_name_to_unpruned_hash_sizes by sharding type.
+        CW and TW are grouped into CW_GROUP; RW and TWRW are grouped into RW_GROUP.
+
+        Returns a tuple of (grouped_lookups, grouped_table_unpruned_size_map).
+        """
+        grouped_lookups: Dict[ShardingTypeGroup, List[nn.Module]] = defaultdict(list)
+        grouped_table_unpruned_size_map: Dict[ShardingTypeGroup, Dict[str, int]] = (
+            defaultdict(dict)
+        )
+        for sharding_type, lookup in zip(
+            self._embedding_bag_collection._sharding_types,
+            self._embedding_bag_collection._lookups,
+        ):
+            sharding_group = SHARDING_TYPE_TO_GROUP[sharding_type]
+            # Group lookups by sharding group.
+            grouped_lookups[sharding_group].append(lookup)
+            # Group table_name_to_unpruned_hash_sizes, unwrapping DDP first.
+            while isinstance(lookup, DistributedDataParallel):
+                lookup = lookup.module
+            for emb_config in lookup.grouped_configs:
+                for table in emb_config.embedding_tables:
+                    if table.name in table_name_to_unpruned_hash_sizes.keys():
+                        grouped_table_unpruned_size_map[sharding_group][table.name] = (
+                            table_name_to_unpruned_hash_sizes[table.name]
+                        )
+
+        return grouped_lookups, grouped_table_unpruned_size_map
+
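The `load_state_dict` override restores `_iter` directly, then delegates to each child module, using `filter_state_dict` to carve out that child's entries. A stand-in re-implementation of what the override relies on, assuming `filter_state_dict` keeps keys under the `name.` prefix and strips it (this is a sketch, not the torchrec implementation):

```python
from collections import OrderedDict


def filter_state_dict_sketch(
    state_dict: "OrderedDict[str, object]", name: str
) -> "OrderedDict[str, object]":
    # Keep only entries under `name.` and strip the prefix, so the child
    # module can load them as its own state dict.
    filtered: "OrderedDict[str, object]" = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(name + "."):
            filtered[key[len(name) + 1 :]] = value
    return filtered


sd: "OrderedDict[str, object]" = OrderedDict(
    [("_iter", 42), ("_embedding_bag_collection.weight", "w")]
)
assert filter_state_dict_sketch(sd, "_embedding_bag_collection") == OrderedDict(
    [("weight", "w")]
)
```

Note also the `while isinstance(lookup, DistributedDataParallel)` loop in the grouping helper, which unwraps nested DDP wrappers before reading `grouped_configs` off the underlying lookup module.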
 
 class ITEPEmbeddingBagCollectionSharder(
     BaseEmbeddingSharder[ITEPEmbeddingBagCollection]
@@ -196,8 +312,5 @@ def module_type(self) -> Type[ITEPEmbeddingBagCollection]:
         return ITEPEmbeddingBagCollection
 
     def sharding_types(self, compute_device_type: str) -> List[str]:
-        types = [
-            ShardingType.COLUMN_WISE.value,
-            ShardingType.TABLE_WISE.value,
-        ]
+        types = list(SHARDING_TYPE_TO_GROUP.keys())
         return types
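This is a behavior change worth flagging: the sharder previously advertised only CW and TW, and deriving the list from `SHARDING_TYPE_TO_GROUP` adds RW and TWRW. Since Python dicts preserve insertion order, the advertised order is deterministic; a sketch with the values elided (only key order matters here):

```python
TYPE_TO_GROUP = {  # stand-in; values elided, only key order matters
    "row_wise": None,
    "table_row_wise": None,
    "column_wise": None,
    "table_wise": None,
}

types = list(TYPE_TO_GROUP.keys())
assert types == ["row_wise", "table_row_wise", "column_wise", "table_wise"]
```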