Skip to content

Commit 49fbc5f

Browse files
Ivan Kobzarevfacebook-github-bot
Ivan Kobzarev
authored andcommitted
Fix _length_per_key_from_stride_per_key empty cat
Summary: torch.cat fails on empty list, guarding this case. Reviewed By: zainhuda Differential Revision: D54305327 fbshipit-source-id: 82877e4f307631eed816a60b35e8b1ca52104b32
1 parent f1c716a commit 49fbc5f

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

torchrec/sparse/jagged_tensor.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -728,9 +728,13 @@ def _length_per_key_from_stride_per_key(
728728
1, stride_per_key_offsets, lengths
729729
).tolist()
730730
else:
731-
return torch.cat(
732-
[torch.sum(chunk).view(1) for chunk in torch.split(lengths, stride_per_key)]
733-
).tolist()
731+
tensor_list: List[torch.Tensor] = [
732+
torch.sum(chunk).view(1) for chunk in torch.split(lengths, stride_per_key)
733+
]
734+
if len(tensor_list) == 0:
735+
return []
736+
737+
return torch.cat(tensor_list).tolist()
734738

735739

736740
def _maybe_compute_length_per_key(

0 commit comments

Comments
 (0)