@@ -311,7 +311,7 @@ def __init__(
311
311
self ,
312
312
awaitables : List [Awaitable [torch .Tensor ]],
313
313
inverse_indices : Tuple [List [str ], torch .Tensor ],
314
- inverse_indices_permute_indices : torch .Tensor ,
314
+ inverse_indices_permute_indices : Optional [ torch .Tensor ] ,
315
315
batch_size_per_feature_pre_a2a : List [int ],
316
316
uncombined_embedding_dims : List [int ],
317
317
embedding_names : List [str ],
@@ -331,9 +331,11 @@ def __init__(
331
331
def _wait_impl (self ) -> KeyedTensor :
332
332
embeddings = [w .wait () for w in self ._awaitables ]
333
333
batch_size = self ._inverse_indices [1 ].numel () // len (self ._inverse_indices [0 ])
334
- indices = torch .index_select (
335
- self ._inverse_indices [1 ], 0 , self ._inverse_indices_permute_indices
336
- )
334
+ permute_indices = self ._inverse_indices_permute_indices
335
+ if permute_indices is not None :
336
+ indices = torch .index_select (self ._inverse_indices [1 ], 0 , permute_indices )
337
+ else :
338
+ indices = self ._inverse_indices [1 ]
337
339
reindex_output = torch .ops .fbgemm .batch_index_select_dim0 (
338
340
inputs = embeddings [0 ] if len (embeddings ) == 1 else torch .cat (embeddings ),
339
341
indices = indices .view (- 1 ),
@@ -768,25 +770,25 @@ def _create_inverse_indices_permute_indices(
768
770
index_per_name [name .split ("@" )[0 ]]
769
771
for name in self ._uncombined_embedding_names
770
772
]
771
- self ._inverse_indices_permute_indices = _pin_and_move (
772
- torch .tensor (permute_indices ),
773
- inverse_indices [1 ].device ,
774
- )
773
+ if len (permute_indices ) != len (index_per_name ) or permute_indices != sorted (
774
+ permute_indices
775
+ ):
776
+ self ._inverse_indices_permute_indices = _pin_and_move (
777
+ torch .tensor (permute_indices ),
778
+ inverse_indices [1 ].device ,
779
+ )
775
780
776
781
# pyre-ignore [14]
777
782
def input_dist (
778
783
self , ctx : EmbeddingBagCollectionContext , features : KeyedJaggedTensor
779
784
) -> Awaitable [Awaitable [KJTList ]]:
785
+ ctx .variable_batch_per_feature = features .variable_stride_per_key ()
786
+ ctx .inverse_indices = features .inverse_indices_or_none ()
780
787
if self ._has_uninitialized_input_dist :
781
788
self ._create_input_dist (features .keys ())
782
789
self ._has_uninitialized_input_dist = False
783
- ctx .variable_batch_per_feature = features .variable_stride_per_key ()
784
- ctx .inverse_indices = features .inverse_indices_or_none ()
785
- if (
786
- ctx .variable_batch_per_feature
787
- and self ._inverse_indices_permute_indices is None
788
- ):
789
- self ._create_inverse_indices_permute_indices (ctx .inverse_indices )
790
+ if ctx .variable_batch_per_feature :
791
+ self ._create_inverse_indices_permute_indices (ctx .inverse_indices )
790
792
with torch .no_grad ():
791
793
if self ._has_features_permute :
792
794
features = features .permute (
0 commit comments