Skip to content

Commit 6b8ecdd

Browse files
author
pytorchbot
committed
2025-01-15 nightly release (7c7468c)
1 parent 42fa9d0 commit 6b8ecdd

File tree

2 files changed

+103
-0
lines changed

2 files changed

+103
-0
lines changed

torchrec/sparse/tensor_dict.py

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
from typing import List, Optional
9+
10+
import torch
11+
from tensordict import TensorDict
12+
13+
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
14+
15+
16+
def maybe_td_to_kjt(
17+
features: KeyedJaggedTensor, keys: Optional[List[str]] = None
18+
) -> KeyedJaggedTensor:
19+
if torch.jit.is_scripting():
20+
assert isinstance(features, KeyedJaggedTensor)
21+
return features
22+
if isinstance(features, TensorDict):
23+
if keys is None:
24+
keys = list(features.keys())
25+
values = torch.cat([features[key]._values for key in keys], dim=0)
26+
lengths = torch.cat(
27+
[
28+
(
29+
(features[key]._lengths)
30+
if features[key]._lengths is not None
31+
else torch.diff(features[key]._offsets)
32+
)
33+
for key in keys
34+
],
35+
dim=0,
36+
)
37+
return KeyedJaggedTensor(
38+
keys=keys,
39+
values=values,
40+
lengths=lengths,
41+
)
42+
else:
43+
return features
+60
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
11+
import unittest
12+
13+
import torch
14+
from tensordict import TensorDict
15+
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
16+
from torchrec.sparse.tensor_dict import maybe_td_to_kjt
17+
from torchrec.sparse.tests.utils import repeat_test
18+
19+
20+
class TestTensorDIct(unittest.TestCase):
21+
@repeat_test(device_str=["cpu", "cuda", "meta"])
22+
def test_kjt_input(self, device_str: str) -> None:
23+
device = torch.device(device_str)
24+
values = torch.tensor([0, 1, 2, 3, 2, 3, 4], device=device)
25+
kjt = KeyedJaggedTensor.from_offsets_sync(
26+
keys=["f1", "f2", "f3"],
27+
values=values,
28+
offsets=torch.tensor([0, 2, 2, 3, 4, 5, 7], device=device),
29+
)
30+
features = maybe_td_to_kjt(kjt)
31+
self.assertEqual(features, kjt)
32+
33+
@repeat_test(device_str=["cpu", "cuda", "meta"])
34+
def test_td_kjt(self, device_str: str) -> None:
35+
device = torch.device(device_str)
36+
values = torch.tensor([0, 1, 2, 3, 2, 3, 4], device=device)
37+
lengths = torch.tensor([2, 0, 1, 1, 1, 2], device=device)
38+
data = {
39+
"f2": torch.nested.nested_tensor_from_jagged(
40+
torch.tensor([2, 3], device=device),
41+
lengths=torch.tensor([1, 1], device=device),
42+
),
43+
"f1": torch.nested.nested_tensor_from_jagged(
44+
torch.arange(2, device=device),
45+
offsets=torch.tensor([0, 2, 2], device=device),
46+
),
47+
"f3": torch.nested.nested_tensor_from_jagged(
48+
torch.tensor([2, 3, 4], device=device),
49+
lengths=torch.tensor([1, 2], device=device),
50+
),
51+
}
52+
td = TensorDict(
53+
data, # type: ignore[arg-type]
54+
device=device,
55+
batch_size=[2],
56+
)
57+
58+
features = maybe_td_to_kjt(td, ["f1", "f2", "f3"]) # pyre-ignore[6]
59+
torch.testing.assert_close(features.values(), values)
60+
torch.testing.assert_close(features.lengths(), lengths)

0 commit comments

Comments
 (0)