Skip to content

Commit 6a415a7

Browse files
joshuadeng and facebook-github-bot
authored and committed
avoid permuting inverse indices when not necessary (#1807)
Summary:
Pull Request resolved: #1807

If the inverse indices and the embedding names are in the same order and there are no duplicates, we can avoid the index select that permutes the inverse indices tensor.

Reviewed By: zainhuda

Differential Revision: D55039535

fbshipit-source-id: c00a193e5a8f34914cd67c627cb88d79e02cfdc9
1 parent ce7b919 commit 6a415a7

File tree

1 file changed

+17
-15
lines changed

1 file changed

+17
-15
lines changed

torchrec/distributed/embeddingbag.py

+17 -15
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ def __init__(
311311
self,
312312
awaitables: List[Awaitable[torch.Tensor]],
313313
inverse_indices: Tuple[List[str], torch.Tensor],
314-
inverse_indices_permute_indices: torch.Tensor,
314+
inverse_indices_permute_indices: Optional[torch.Tensor],
315315
batch_size_per_feature_pre_a2a: List[int],
316316
uncombined_embedding_dims: List[int],
317317
embedding_names: List[str],
@@ -331,9 +331,11 @@ def __init__(
331331
def _wait_impl(self) -> KeyedTensor:
332332
embeddings = [w.wait() for w in self._awaitables]
333333
batch_size = self._inverse_indices[1].numel() // len(self._inverse_indices[0])
334-
indices = torch.index_select(
335-
self._inverse_indices[1], 0, self._inverse_indices_permute_indices
336-
)
334+
permute_indices = self._inverse_indices_permute_indices
335+
if permute_indices is not None:
336+
indices = torch.index_select(self._inverse_indices[1], 0, permute_indices)
337+
else:
338+
indices = self._inverse_indices[1]
337339
reindex_output = torch.ops.fbgemm.batch_index_select_dim0(
338340
inputs=embeddings[0] if len(embeddings) == 1 else torch.cat(embeddings),
339341
indices=indices.view(-1),
@@ -768,25 +770,25 @@ def _create_inverse_indices_permute_indices(
768770
index_per_name[name.split("@")[0]]
769771
for name in self._uncombined_embedding_names
770772
]
771-
self._inverse_indices_permute_indices = _pin_and_move(
772-
torch.tensor(permute_indices),
773-
inverse_indices[1].device,
774-
)
773+
if len(permute_indices) != len(index_per_name) or permute_indices != sorted(
774+
permute_indices
775+
):
776+
self._inverse_indices_permute_indices = _pin_and_move(
777+
torch.tensor(permute_indices),
778+
inverse_indices[1].device,
779+
)
775780

776781
# pyre-ignore [14]
777782
def input_dist(
778783
self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor
779784
) -> Awaitable[Awaitable[KJTList]]:
785+
ctx.variable_batch_per_feature = features.variable_stride_per_key()
786+
ctx.inverse_indices = features.inverse_indices_or_none()
780787
if self._has_uninitialized_input_dist:
781788
self._create_input_dist(features.keys())
782789
self._has_uninitialized_input_dist = False
783-
ctx.variable_batch_per_feature = features.variable_stride_per_key()
784-
ctx.inverse_indices = features.inverse_indices_or_none()
785-
if (
786-
ctx.variable_batch_per_feature
787-
and self._inverse_indices_permute_indices is None
788-
):
789-
self._create_inverse_indices_permute_indices(ctx.inverse_indices)
790+
if ctx.variable_batch_per_feature:
791+
self._create_inverse_indices_permute_indices(ctx.inverse_indices)
790792
with torch.no_grad():
791793
if self._has_features_permute:
792794
features = features.permute(

0 commit comments

Comments
 (0)