@@ -103,17 +103,20 @@ def _test_kjt_input_module(
103
103
kjt_keys : List [str ],
104
104
# pyre-ignore
105
105
inputs ,
106
+ test_dynamo : bool = True ,
106
107
test_aot_inductor : bool = True ,
108
+ test_pt2_ir_export : bool = False ,
107
109
) -> None :
108
110
with dynamo_skipfiles_allow ("torchrec" ):
109
111
EM : torch .nn .Module = KJTInputExportWrapper (kjt_input_module , kjt_keys )
110
112
eager_output = EM (* inputs )
111
- x = torch ._dynamo .export (EM , same_signature = True )(* inputs )
113
+ if test_dynamo :
114
+ x = torch ._dynamo .export (EM , same_signature = True )(* inputs )
112
115
113
- export_gm = x .graph_module
114
- export_gm_output = export_gm (* inputs )
116
+ export_gm = x .graph_module
117
+ export_gm_output = export_gm (* inputs )
115
118
116
- assert_close (eager_output , export_gm_output )
119
+ assert_close (eager_output , export_gm_output )
117
120
118
121
if test_aot_inductor :
119
122
# pyre-ignore
@@ -127,6 +130,11 @@ def _test_kjt_input_module(
127
130
aot_actual_output = aot_inductor_module (* inputs )
128
131
assert_close (eager_output , aot_actual_output )
129
132
133
+ if test_pt2_ir_export :
134
+ pt2_ir = torch .export .export (EM , inputs , {}, strict = False )
135
+ pt2_ir_output = pt2_ir (* inputs )
136
+ assert_close (eager_output , pt2_ir_output )
137
+
130
138
def test_kjt_split (self ) -> None :
131
139
class M (torch .nn .Module ):
132
140
def forward (self , kjt : KeyedJaggedTensor , segments : List [int ]):
@@ -155,6 +163,21 @@ def forward(self, kjt: KeyedJaggedTensor, indices: List[int]):
155
163
test_aot_inductor = False ,
156
164
)
157
165
166
+ def test_kjt_length_per_key (self ) -> None :
167
+ class M (torch .nn .Module ):
168
+ def forward (self , kjt : KeyedJaggedTensor ):
169
+ return kjt .length_per_key ()
170
+
171
+ kjt : KeyedJaggedTensor = make_kjt ([2 , 3 , 4 , 5 , 6 ], [1 , 2 , 1 , 1 ])
172
+
173
+ self ._test_kjt_input_module (
174
+ M (),
175
+ kjt .keys (),
176
+ (kjt ._values , kjt ._lengths ),
177
+ test_aot_inductor = False ,
178
+ test_pt2_ir_export = True ,
179
+ )
180
+
158
181
# pyre-ignore
159
182
@unittest .skipIf (
160
183
torch .cuda .device_count () <= 1 ,
0 commit comments