Skip to content

Commit b752fde

Browse files
PaulZhang12facebook-github-bot
authored andcommitted
Introduce pt2 ir export unit test
Summary: Introduce unit testing torch.export Reviewed By: IvanKobzarev Differential Revision: D53426083 fbshipit-source-id: 79ab5e6de47786cd2be02df0fe8ffa05f97597ad
1 parent bc52746 commit b752fde

File tree

1 file changed

+27
-4
lines changed

1 file changed

+27
-4
lines changed

torchrec/distributed/tests/test_pt2.py

+27-4
Original file line numberDiff line numberDiff line change
@@ -103,17 +103,20 @@ def _test_kjt_input_module(
103103
kjt_keys: List[str],
104104
# pyre-ignore
105105
inputs,
106+
test_dynamo: bool = True,
106107
test_aot_inductor: bool = True,
108+
test_pt2_ir_export: bool = False,
107109
) -> None:
108110
with dynamo_skipfiles_allow("torchrec"):
109111
EM: torch.nn.Module = KJTInputExportWrapper(kjt_input_module, kjt_keys)
110112
eager_output = EM(*inputs)
111-
x = torch._dynamo.export(EM, same_signature=True)(*inputs)
113+
if test_dynamo:
114+
x = torch._dynamo.export(EM, same_signature=True)(*inputs)
112115

113-
export_gm = x.graph_module
114-
export_gm_output = export_gm(*inputs)
116+
export_gm = x.graph_module
117+
export_gm_output = export_gm(*inputs)
115118

116-
assert_close(eager_output, export_gm_output)
119+
assert_close(eager_output, export_gm_output)
117120

118121
if test_aot_inductor:
119122
# pyre-ignore
@@ -127,6 +130,11 @@ def _test_kjt_input_module(
127130
aot_actual_output = aot_inductor_module(*inputs)
128131
assert_close(eager_output, aot_actual_output)
129132

133+
if test_pt2_ir_export:
134+
pt2_ir = torch.export.export(EM, inputs, {}, strict=False)
135+
pt2_ir_output = pt2_ir(*inputs)
136+
assert_close(eager_output, pt2_ir_output)
137+
130138
def test_kjt_split(self) -> None:
131139
class M(torch.nn.Module):
132140
def forward(self, kjt: KeyedJaggedTensor, segments: List[int]):
@@ -155,6 +163,21 @@ def forward(self, kjt: KeyedJaggedTensor, indices: List[int]):
155163
test_aot_inductor=False,
156164
)
157165

166+
def test_kjt_length_per_key(self) -> None:
167+
class M(torch.nn.Module):
168+
def forward(self, kjt: KeyedJaggedTensor):
169+
return kjt.length_per_key()
170+
171+
kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1])
172+
173+
self._test_kjt_input_module(
174+
M(),
175+
kjt.keys(),
176+
(kjt._values, kjt._lengths),
177+
test_aot_inductor=False,
178+
test_pt2_ir_export=True,
179+
)
180+
158181
# pyre-ignore
159182
@unittest.skipIf(
160183
torch.cuda.device_count() <= 1,

0 commit comments

Comments
 (0)