
Commit 9d03945

YazhiGao authored and facebook-github-bot committed
fix embedding dim for int8 output (#1792)
Summary:
Pull Request resolved: #1792

* In TorchRec we had introduced flattening and reshape operations on the lookup outputs, while the tensor shapes are honored by the TBE allocation directly; this change removes that redundant round trip.

Reviewed By: xing-liu

Differential Revision: D54885770

fbshipit-source-id: dcf4bfbb28495c017f9a4ee8e6390a3e9e723811
1 parent: b4366b1

File tree

2 files changed: +4 −8

torchrec/distributed/embedding_lookup.py (+2 −3)
@@ -602,9 +602,8 @@ def forward(
             )
         )
         for i in range(len(self._emb_modules)):
-            embeddings.append(
-                self._emb_modules[i].forward(features_by_group[i]).view(-1)
-            )
+            # 2d embedding by nature
+            embeddings.append(self._emb_modules[i].forward(features_by_group[i]))
 
         return embeddings_cat_empty_rank_handle_inference(
             embeddings, device=self.device, dtype=self.output_dtype
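
Here the lookup output is appended as the 2-D tensor that TBE allocates, instead of being flattened and later reshaped. A minimal sketch of the round trip this removes, with an illustrative (batch, dim) shape standing in for a real TBE output:

import torch

# Stand-in for one per-group TBE lookup output; the (4, 8) shape is
# illustrative -- the point is that the output is already 2-D.
out = torch.randn(4, 8)

# Old path: flatten here, then rebuild the 2-D shape downstream with a
# separately tracked embedding_dim (the reshape removed in
# quant_embedding.py below).
flat = out.view(-1)
restored = flat.view(-1, 8)
assert torch.equal(restored, out)  # a redundant round trip

# New path: append the TBE output as-is and let
# embeddings_cat_empty_rank_handle_inference consume the 2-D tensors.
embeddings = [out]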

torchrec/distributed/quant_embedding.py (+2 −5)
@@ -673,11 +673,8 @@ def compute(
     ) -> List[List[torch.Tensor]]:
         ret: List[List[torch.Tensor]] = []
 
-        for lookup, features, sharding_type in zip(
-            self._lookups, dist_input, self._sharding_type_to_sharding.keys()
-        ):
-            embedding_dim = self._embedding_dim_for_sharding_type(sharding_type)
-            ret.append([o.view(-1, embedding_dim) for o in lookup.forward(features)])
+        for lookup, features in zip(self._lookups, dist_input):
+            ret.append(lookup.forward(features))
         return ret
 
     # pyre-ignore
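
For int8 output the removed reshape was not just redundant. If a quantized output row is wider in bytes than the logical embedding_dim (for example when per-row quantization parameters are packed after the values -- an assumption made here purely for illustration), view(-1, embedding_dim) silently mis-splits the rows, whereas honoring the shape TBE allocates cannot. A hedged sketch with made-up sizes:

import torch

# Illustrative sketch: suppose each int8 output row packs 8 extra bytes
# of per-row quantization parameters after the embedding values. The
# exact layout is an assumption; the shapes are what matter.
batch, dim, qparam_bytes = 4, 16, 8
row_bytes = dim + qparam_bytes                 # 24 bytes per row
out = torch.randint(-128, 127, (batch, row_bytes), dtype=torch.int8)

# The removed reshape assumed the logical embedding_dim:
mis_split = out.view(-1, dim)
assert mis_split.shape == (6, dim)             # not (batch, dim): rows mis-split

# Passing the lookup's native 2-D output through keeps the real layout.
assert out.shape == (batch, row_bytes)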
