Skip to content

Commit 3ea57c5

Browse files
PaulZhang12facebook-github-bot
authored andcommitted
tolist() support for FunctionalTensor
Summary: Support tolist() for FunctionalTensor for KJT in torch.export bypass-github-pytorch-ci-checks Reviewed By: ezyang Differential Revision: D53731064 fbshipit-source-id: a226bef5de4cdbf6aa0dcc1a6dee28c0874ac484
1 parent b752fde commit 3ea57c5

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

torchrec/distributed/tests/test_pt2.py

+15
Original file line numberDiff line numberDiff line change
@@ -240,3 +240,18 @@ def test_maybe_compute_kjt_to_jt_dict(self) -> None:
240240
# TODO: turn on AOT Inductor test once the support is ready
241241
test_aot_inductor=False,
242242
)
243+
244+
def test_tensor_tolist(self) -> None:
245+
class M(torch.nn.Module):
246+
def forward(self, kjt: KeyedJaggedTensor):
247+
return kjt.values().tolist()
248+
249+
kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1])
250+
self._test_kjt_input_module(
251+
M(),
252+
kjt.keys(),
253+
(kjt._values, kjt._lengths),
254+
test_dynamo=False,
255+
test_aot_inductor=False,
256+
test_pt2_ir_export=True,
257+
)

0 commit comments

Comments
 (0)