Skip to content

Commit 8988ead

Browse files
Ivan Kobzarevfacebook-github-bot
Ivan Kobzarev
authored andcommitted
default _use_segment_sum_csr for dynamo (#1798)
Summary: Pull Request resolved: #1798 Constanting heuristic for dynamo for now. Dynamo can not pass it without concrete values of the batch compile time. Reviewed By: joshuadeng Differential Revision: D54966335 fbshipit-source-id: 1359b1bbc6e50061c65a312df66fc5b2bf01a8b1
1 parent 1970969 commit 8988ead

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

torchrec/sparse/jagged_tensor.py

+4
Original file line numberDiff line numberDiff line change
@@ -727,6 +727,10 @@ def _use_segment_sum_csr(stride_per_key: List[int]) -> bool:
727727
per segment that match performance between the kernel and PyTorch solution, to
728728
determine the threshold of when to use `segment_sum_csr`.
729729
"""
730+
if is_torchdynamo_compiling():
731+
# dynamo symbolic shapes can not pass this condition without concrete stride values
732+
return False
733+
730734
elements_per_segment = sum(stride_per_key) / len(stride_per_key)
731735
segment_threshold = int(
732736
1.39771

0 commit comments

Comments
 (0)