8
8
# pyre-strict
9
9
10
10
import unittest
11
- from typing import Dict , Iterable , Union
11
+ from typing import Any , Dict , Iterable , Union
12
12
13
13
import torch
14
14
from torch import no_grad
@@ -31,6 +31,8 @@ def _test_segemented_ne_helper(
31
31
weights : torch .Tensor ,
32
32
expected_ne : torch .Tensor ,
33
33
grouping_keys : torch .Tensor ,
34
+ grouping_key_tensor_name : str = "grouping_keys" ,
35
+ cast_keys_to_int : bool = False ,
34
36
) -> None :
35
37
num_task = labels .shape [0 ]
36
38
batch_size = labels .shape [0 ]
@@ -41,7 +43,7 @@ def _test_segemented_ne_helper(
41
43
"weights" : {},
42
44
}
43
45
if grouping_keys is not None :
44
- inputs ["required_inputs" ] = {"grouping_keys" : grouping_keys }
46
+ inputs ["required_inputs" ] = {grouping_key_tensor_name : grouping_keys }
45
47
for i in range (num_task ):
46
48
task_info = RecTaskInfo (
47
49
name = f"Task:{ i } " ,
@@ -64,6 +66,10 @@ def _test_segemented_ne_helper(
64
66
tasks = task_list ,
65
67
# pyre-ignore
66
68
num_groups = max (2 , torch .unique (grouping_keys )[- 1 ].item () + 1 ),
69
+ # pyre-ignore
70
+ grouping_keys = grouping_key_tensor_name ,
71
+ # pyre-ignore
72
+ cast_keys_to_int = cast_keys_to_int ,
67
73
)
68
74
ne .update (** inputs )
69
75
actual_ne = ne .compute ()
@@ -95,7 +101,7 @@ def test_grouped_ne(self) -> None:
95
101
raise
96
102
97
103
98
- def generate_model_outputs_cases () -> Iterable [Dict [str , torch . _tensor . Tensor ]]:
104
+ def generate_model_outputs_cases () -> Iterable [Dict [str , Any ]]:
99
105
return [
100
106
# base condition
101
107
{
@@ -149,4 +155,23 @@ def generate_model_outputs_cases() -> Iterable[Dict[str, torch._tensor.Tensor]]:
149
155
), # for this case, both tasks have same groupings
150
156
"expected_ne" : torch .tensor ([[3.1615 , 1.6004 ], [1.0034 , 0.4859 ]]),
151
157
},
158
+ # Custom grouping key tensor name
159
+ {
160
+ "labels" : torch .tensor ([[1 , 0 , 0 , 1 , 1 ]]),
161
+ "predictions" : torch .tensor ([[0.2 , 0.6 , 0.8 , 0.4 , 0.9 ]]),
162
+ "weights" : torch .tensor ([[0.13 , 0.2 , 0.5 , 0.8 , 0.75 ]]),
163
+ "grouping_keys" : torch .tensor ([0 , 1 , 0 , 1 , 1 ]),
164
+ "expected_ne" : torch .tensor ([[3.1615 , 1.6004 ]]),
165
+ "grouping_key_tensor_name" : "custom_key" ,
166
+ },
167
+ # Cast grouping keys to int32
168
+ {
169
+ "labels" : torch .tensor ([[1 , 0 , 0 , 1 , 1 ]]),
170
+ "predictions" : torch .tensor ([[0.2 , 0.6 , 0.8 , 0.4 , 0.9 ]]),
171
+ "weights" : torch .tensor ([[0.13 , 0.2 , 0.5 , 0.8 , 0.75 ]]),
172
+ "grouping_keys" : torch .tensor ([0.0 , 1.0 , 0.0 , 1.0 , 1.0 ]),
173
+ "expected_ne" : torch .tensor ([[3.1615 , 1.6004 ]]),
174
+ "grouping_key_tensor_name" : "custom_key" ,
175
+ "cast_keys_to_int" : True ,
176
+ },
152
177
]
0 commit comments