Commit 5856c4d

ge0405 authored and facebook-github-bot committed
Fix dtype mismatch between weight and per_sample_weights (#1758)
Summary:
Pull Request resolved: #1758

# Background

T176105639

| case | embedding bag weight | per_sample_weights | nn.EmbeddingBag, device="cpu" | nn.EmbeddingBag, device="cuda" | nn.EmbeddingBag, device="meta" | fbgemm lookup |
|---|---|---|---|---|---|---|
| A | fp32 | fp32 | good | good | good | good |
| B | fp16 | fp32 | Error: Expected tensor for argument #1 'weight' to have the same type as tensor for argument #1 'per_sample_weights'; but type torch.HalfTensor does not equal torch.FloatTensor | Error: expected scalar type Half but found Float | failed [check](https://fburl.com/code/ng9pv1vp) that forces weight dtype == per_sample_weights dtype | good |
| C | fp16 | fp16 | good | good | good | good now with D54370192. Previous error: P1046999270, RuntimeError: "expected scalar type Float but found Half" from the fbgemm call |

Notebook to see nn.EmbeddingBag forward errors: N5007274.

Currently we are in case A: users need to set `use_fp32_embedding` in training to force the embedding bag dtype to fp32. What users actually want is to use fp16 for the embedding bag weight to reduce memory usage. But when they delete `use_fp32_embedding` they land in case B and fail the [check that forces weight dtype == per_sample_weights dtype](https://www.internalfb.com/code/fbsource/[e750b9f69f7f758682000804409456103510078c]/fbcode/caffe2/torch/_meta_registrations.py?lines=3521-3524) in meta_registration.

Therefore, this diff aims to achieve case C: make per_sample_weights the same dtype as the embedding module weight. With the fbgemm lookup backend now supporting Half for per_sample_weights (D54370192), this diff introduces `dtype` in all feature processor classes and initializes per_sample_weights according to the passed dtype.

# Reference diffs to resolve this issue

Diff 1: D52591217. This passes the embedding bag dtype to the feature processor so that per_sample_weights gets the same dtype as the embedding bag weight. However, is_meta also had to be passed, because fbgemm did not support fp16 per_sample_weights at that time (see the table above), so users were forced to make per_sample_weights fp16 only when it was on a meta device. The solution requires too many hacks.

Diff 2: D53232739. This does basically the same thing as diff 1 (D52591217), except that the hack lives in the TorchRec library: it adds an `if` in EBC and PEA so that when the embedding bag weight is fp16, per_sample_weights is forced to fp16 too.

Reviewed By: henrylhtsang

Differential Revision: D54526190

fbshipit-source-id: 969cb64c4af345ea222e8a2c7e5be0d9af0d0ae3
1 parent 1cd088f commit 5856c4d
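For context, here is a minimal, self-contained sketch of cases B and C from the table in the summary, using plain `torch.nn.EmbeddingBag` on CPU rather than TorchRec; the tensor values and variable names are illustrative, not taken from the diff.

```python
import torch
import torch.nn as nn

# An EmbeddingBag held in fp16 to reduce memory, with mode="sum" so that
# per_sample_weights is accepted.
bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, mode="sum").half()

indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
offsets = torch.tensor([0, 4])
weights_fp32 = torch.rand(indices.shape)  # per-sample weights arrive as fp32

# Case B: fp16 weight + fp32 per_sample_weights -> dtype-mismatch RuntimeError
# (e.g. "expected scalar type Half but found Float").
try:
    bag(indices, offsets, per_sample_weights=weights_fp32)
except RuntimeError as err:
    print(f"case B fails: {err}")

# Case C: cast per_sample_weights to the bag's weight dtype first -> works.
out = bag(indices, offsets, per_sample_weights=weights_fp32.half())
print(out.dtype)  # torch.float16
```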

File tree: 1 file changed (+12 −1 lines)

torchrec/modules/embedding_modules.py (+12 −1)

```diff
@@ -162,6 +162,7 @@ def __init__(
         self.embedding_bags: nn.ModuleDict = nn.ModuleDict()
         self._embedding_bag_configs = tables
         self._lengths_per_embedding: List[int] = []
+        self._dtypes: List[int] = []
 
         table_names = set()
         for embedding_config in tables:
@@ -183,6 +184,7 @@ def __init__(
             )
             if device is None:
                 device = self.embedding_bags[embedding_config.name].weight.device
+            self._dtypes.append(embedding_config.data_type.value)
 
             if not embedding_config.feature_names:
                 embedding_config.feature_names = [embedding_config.name]
@@ -219,10 +221,19 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
         for i, embedding_bag in enumerate(self.embedding_bags.values()):
             for feature_name in self._feature_names[i]:
                 f = feature_dict[feature_name]
+                per_sample_weights: Optional[torch.Tensor] = None
+                if self._is_weighted:
+                    per_sample_weights = (
+                        f.weights().half()
+                        if self._dtypes[i] == DataType.FP16.value
+                        else f.weights()
+                    )
                 res = embedding_bag(
                     input=f.values(),
                     offsets=f.offsets(),
-                    per_sample_weights=f.weights() if self._is_weighted else None,
+                    per_sample_weights=(
+                        per_sample_weights if self._is_weighted else None
+                    ),
                 ).float()
                 pooled_embeddings.append(res)
         return KeyedTensor(
```
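A hedged usage sketch of the changed path follows. It assumes the standard TorchRec entry points (`EmbeddingBagConfig`, `EmbeddingBagCollection`, `KeyedJaggedTensor`) and that `data_type=DataType.FP16` puts the underlying `nn.EmbeddingBag` weights in fp16; table and feature names, sizes, and values are illustrative.

```python
import torch
from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# One fp16 table backing one weighted feature.
table = EmbeddingBagConfig(
    name="t1",
    embedding_dim=8,
    num_embeddings=100,
    feature_names=["f1"],
    data_type=DataType.FP16,
)
ebc = EmbeddingBagCollection(tables=[table], is_weighted=True)

# Two bags of ids for feature "f1"; the per-sample weights arrive as fp32.
features = KeyedJaggedTensor.from_offsets_sync(
    keys=["f1"],
    values=torch.tensor([1, 2, 4, 5, 4, 3]),
    offsets=torch.tensor([0, 2, 6]),
    weights=torch.rand(6),  # fp32
)

# With this change, forward() casts the fp32 weights to fp16 to match the
# table dtype before calling nn.EmbeddingBag, instead of raising a
# Half/Float mismatch; the pooled result is cast back to fp32 (.float()).
pooled = ebc(features)
print(pooled.to_dict()["f1"].shape)  # torch.Size([2, 8])
print(pooled.to_dict()["f1"].dtype)  # torch.float32
```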
