Commit b4366b1

ZhengkaiZ authored and facebook-github-bot committed
Back out "Fix dtype mismatch between weight and per_sample_weights" (#1791)
Summary:
Pull Request resolved: #1791

The inference tbe_input_combine op does not support fp16, so we need to revert to avoid any prod impact.

Original commit changeset: 969cb64c4af3
Original Phabricator Diff: D54526190

Reviewed By: peking2, houseroad

Differential Revision: D54884092

fbshipit-source-id: 728d89ae1ccb6fd219510c378d88c1a9ac20c880
1 parent d23a447 commit b4366b1
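
For context, the change being backed out cast per-feature pooling weights to half precision whenever the table itself was FP16, so that per_sample_weights matched the EmbeddingBag weight dtype. Below is a minimal standalone sketch of that (now reverted) logic; the helper name is hypothetical, since the real code lived inline in EmbeddingBagCollection.forward (see the diff further down):

from typing import Optional

import torch

# Hypothetical helper mirroring the backed-out cast: for FP16 tables, downcast
# fp32 pooling weights to half so they match the embedding weight dtype.
def cast_per_sample_weights(
    weights: Optional[torch.Tensor], table_is_fp16: bool
) -> Optional[torch.Tensor]:
    if weights is None:
        return None
    return weights.half() if table_is_fp16 else weights

Per the summary above, the inference-side tbe_input_combine op cannot consume the resulting fp16 weights, which is why the cast is being removed.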

File tree

1 file changed: +1 −12 lines changed

torchrec/modules/embedding_modules.py (+1 −12)
@@ -162,7 +162,6 @@ def __init__(
         self.embedding_bags: nn.ModuleDict = nn.ModuleDict()
         self._embedding_bag_configs = tables
         self._lengths_per_embedding: List[int] = []
-        self._dtypes: List[int] = []

         table_names = set()
         for embedding_config in tables:
@@ -184,7 +183,6 @@ def __init__(
             )
             if device is None:
                 device = self.embedding_bags[embedding_config.name].weight.device
-            self._dtypes.append(embedding_config.data_type.value)

             if not embedding_config.feature_names:
                 embedding_config.feature_names = [embedding_config.name]
@@ -221,19 +219,10 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
         for i, embedding_bag in enumerate(self.embedding_bags.values()):
             for feature_name in self._feature_names[i]:
                 f = feature_dict[feature_name]
-                per_sample_weights: Optional[torch.Tensor] = None
-                if self._is_weighted:
-                    per_sample_weights = (
-                        f.weights().half()
-                        if self._dtypes[i] == DataType.FP16.value
-                        else f.weights()
-                    )
                 res = embedding_bag(
                     input=f.values(),
                     offsets=f.offsets(),
-                    per_sample_weights=(
-                        per_sample_weights if self._is_weighted else None
-                    ),
+                    per_sample_weights=f.weights() if self._is_weighted else None,
                 ).float()
                 pooled_embeddings.append(res)
         return KeyedTensor(
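
After the revert, forward passes f.weights() through untouched, which lines up with EmbeddingBag's expectation that per_sample_weights share the weight tensor's dtype (fp32 in the default case). A minimal, self-contained sketch of that kept code path — the table size, ids, and offsets here are illustrative, not from the source:

import torch
import torch.nn as nn

# Illustrative standalone version of the kept code path: per-sample weights
# are forwarded as-is; with an fp32 table their dtypes already agree.
bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, mode="sum")
values = torch.tensor([1, 2, 4, 5], dtype=torch.long)  # flattened ids
offsets = torch.tensor([0, 2], dtype=torch.long)       # two bags
per_sample_weights = torch.rand(4)                     # fp32, like the table

pooled = bag(values, offsets, per_sample_weights=per_sample_weights).float()
print(pooled.shape)  # torch.Size([2, 4])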
