
Commit 1970969

Min Yu authored and facebook-github-bot committed
Fix the type mismatch issue during TGIF publish (#1796)
Summary:
Pull Request resolved: #1796

Annotate _embedding_names_per_rank_per_sharding to make sure its type is List[List[List[str]]]. The annotation has to live in a separate function with an input parameter so that it won't be dropped during symbolic trace.

Reviewed By: s4ayub

Differential Revision: D54442150

fbshipit-source-id: 8940a31c405eb6a5cb947c057ca4e77a08b6f245
1 parent da1c013 · commit 1970969
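The fix in the summary hinges on how torch.fx symbolic tracing treats wrapped functions. Below is a minimal, hypothetical sketch (annotate_names, Toy, and the feature names are illustrative and not from torchrec): a helper marked with @torch.fx.wrap is recorded as an opaque call_function node rather than traced through, so the torch.jit.annotate inside it survives tracing, and the otherwise-unused tensor parameter gives the call a traced input so it is not simply evaluated to a constant at trace time. The actual commit additionally marks its helper with @torch.fx.has_side_effect, presumably so dead-code elimination keeps the node.

from typing import List

import torch
import torch.fx


# Hypothetical helper mirroring the pattern in this commit: @torch.fx.wrap keeps
# the call opaque to symbolic_trace, and the otherwise-unused tensor argument
# gives it a Proxy input so the call is recorded in the graph instead of being
# executed (and folded away) during tracing.
@torch.fx.wrap
def annotate_names(names: List[str], dummy: torch.Tensor) -> List[str]:
    return torch.jit.annotate(List[str], names)


class Toy(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self._names: List[str] = ["feature_a", "feature_b"]

    def forward(self, x: torch.Tensor) -> List[str]:
        return annotate_names(self._names, x)


gm = torch.fx.symbolic_trace(Toy())
print(gm.graph)  # annotate_names appears as a call_function node

In the commit, the same idea is applied to each inner list via annotate_embedding_names, with the output tensors of output_dist serving as the dummy argument.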

1 file changed (+35, -1)


torchrec/distributed/quant_embedding.py (+35, -1)

@@ -134,6 +134,9 @@ def _construct_jagged_tensors_tw(
     for i in range(len(embeddings)):
         embeddings_i: torch.Tensor = embeddings[i]
         features_i: KeyedJaggedTensor = features[i]
+        if features_i.lengths().numel() == 0:
+            # No table on the rank, skip.
+            continue

         lengths = features_i.lengths().view(-1, features_i.stride())
         values = features_i.values()
@@ -287,6 +290,34 @@ def _construct_jagged_tensors(
         return _construct_jagged_tensors_tw(embeddings, features, need_indices)


+# Wrap the annotation in a separate function with input parameter so that it won't be dropped during symbolic trace.
+# Please note the input parameter is necessary, though is not used, otherwise this function will be optimized.
+@torch.fx.has_side_effect
+@torch.fx.wrap
+def annotate_embedding_names(
+    embedding_names: List[str],
+    dummy: List[List[torch.Tensor]],
+) -> List[str]:
+    return torch.jit.annotate(List[str], embedding_names)
+
+
+def format_embedding_names_per_rank_per_sharding(
+    embedding_names_per_rank_per_sharding: List[List[List[str]]],
+    dummy: List[List[torch.Tensor]],
+) -> List[List[List[str]]]:
+    annotated_embedding_names_per_rank_per_sharding: List[List[List[str]]] = []
+    for embedding_names_per_rank in embedding_names_per_rank_per_sharding:
+        annotated_embedding_names_per_rank: List[List[str]] = []
+        for embedding_names in embedding_names_per_rank:
+            annotated_embedding_names_per_rank.append(
+                annotate_embedding_names(embedding_names, dummy)
+            )
+        annotated_embedding_names_per_rank_per_sharding.append(
+            annotated_embedding_names_per_rank
+        )
+    return annotated_embedding_names_per_rank_per_sharding
+
+
 @torch.fx.wrap
 def output_jt_dict(
     sharding_types: List[str],
@@ -709,11 +740,14 @@ def output_dist(
                 # pyre-ignore
                 sharding_ctx.features_before_input_dist
             )
+
         return output_jt_dict(
             sharding_types=list(self._sharding_type_to_sharding.keys()),
             emb_per_sharding=emb_per_sharding,
             features_per_sharding=features_per_sharding,
-            embedding_names_per_rank_per_sharding=self._embedding_names_per_rank_per_sharding,
+            embedding_names_per_rank_per_sharding=format_embedding_names_per_rank_per_sharding(
+                self._embedding_names_per_rank_per_sharding, output
+            ),
             need_indices=self._need_indices,
             features_before_input_dist_per_sharding=features_before_input_dist_per_sharding,
             unbucketize_tensor_idxs_per_sharding=unbucketize_tensor_idxs_per_sharding,
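For illustration only, a usage sketch of the helper added above (assuming a torchrec build that includes this change; the example names and dummy tensors are made up): outside of tracing, format_embedding_names_per_rank_per_sharding is effectively an identity transformation that rebuilds the nested list while pinning its element type to List[List[List[str]]].

from typing import List

import torch

# Assumes torchrec with this commit applied is installed.
from torchrec.distributed.quant_embedding import (
    format_embedding_names_per_rank_per_sharding,
)

# Hypothetical shape: one sharding, two ranks, with their embedding names.
names: List[List[List[str]]] = [[["t0_e0", "t0_e1"], ["t1_e0"]]]
# Mirrors the `output: List[List[torch.Tensor]]` argument passed in output_dist.
dummy: List[List[torch.Tensor]] = [[torch.empty(0)]]

formatted = format_embedding_names_per_rank_per_sharding(names, dummy)
assert formatted == names  # same contents, now annotated as List[List[List[str]]]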
