-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathopen_learning.py
279 lines (215 loc) · 9.73 KB
/
open_learning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
""" Module for Open Learning """
from abc import ABC, abstractmethod
import numpy as np
from sklearn.metrics import matthews_corrcoef, f1_score
import torch
from torch.nn import Module
from torch.nn.functional import binary_cross_entropy
class OpenLearning(Module, ABC):
    """Common interface for open-world learning modules.

    Subclasses have to implement ``loss``, ``predict`` and ``reject``.
    ``fit`` is an optional hook that does nothing unless overridden.
    """

    @abstractmethod
    def loss(self, logits, labels):
        """Compute the training loss of ``logits`` with respect to ``labels``."""
        raise NotImplementedError("Abstract method called")

    def fit(self, logits, labels):
        """Hook for learning extra parameters on the whole training set.

        The default implementation learns nothing and returns ``self``.
        """
        return self

    @abstractmethod
    def predict(self, logits, subset=None):
        """Return the most likely class for every instance."""
        raise NotImplementedError("Abstract method called")

    @abstractmethod
    def reject(self, logits, subset=None):
        """Return a per-example mask: 1 to reject the example, 0 to keep it."""
        raise NotImplementedError("Abstract method called")

    def forward(self, logits, labels=None, subset=None):
        """Run rejection, prediction and (optionally) the loss in one call.

        Returns the tuple ``(reject_mask, predictions, loss)``; ``loss`` is
        ``None`` when no ``labels`` are given.
        """
        rejected = self.reject(logits, subset=subset)
        predicted = self.predict(logits, subset=subset)
        if labels is None:
            return rejected, predicted, None
        return rejected, predicted, self.loss(logits, labels)
class DeepOpenClassification(OpenLearning):
    """
    Deep Open ClassificatioN: Sigmoidal activation + Threshold based rejection

    Inputs should *not* be activated in any way.
    This module will apply sigmoid activations.
    """

    def __init__(self, threshold: float = 0.5,
                 reduce_risk: bool = False, alpha: float = 3.0,
                 num_classes=None, use_class_weights=False):
        """
        Arguments
        ---------
        threshold: Threshold for class rejection. Also kept as the lower
            bound for the per-class thresholds computed by ``fit()``.
        reduce_risk: If true, ``fit()`` estimates per-class thresholds by
            Gaussian fitting; otherwise ``fit()`` is a no-op.
        alpha: Factor of standard deviation to reduce open space risk
        num_classes: Number of output classes. If ``None``, it is taken
            from the second dimension of the logits where needed.
        use_class_weights: If true, ``loss()`` weights positive examples
            per class by (#negatives / #positives) within the given batch.
        """
        super().__init__()
        self.reduce_risk = bool(reduce_risk)
        self.alpha = float(alpha)
        self.threshold = float(threshold)
        self.num_classes = num_classes
        self.use_class_weights = use_class_weights
        # Minimum threshold if reduce_risk is True,
        # allows to call fit() multiple times
        self.min_threshold = self.threshold

    def _resolve_num_classes(self, logits):
        """Return the configured class count, else the logits' width."""
        if self.num_classes is not None:
            return self.num_classes
        # One logit column per class; unlike counting distinct labels this
        # stays correct when some class ids do not occur in the labels.
        return logits.size(1)

    def loss(self, logits, labels):
        """Binary cross entropy (with logits) against one-hot targets.

        ``labels`` holds integer class ids; they are one-hot encoded to the
        width of ``logits``. With ``use_class_weights`` the positive term of
        each class is weighted by the negative/positive ratio in ``labels``.
        """
        pos_weight = None
        if self.use_class_weights:
            with torch.no_grad():
                values, counts = torch.unique(labels, return_counts=True)
                counts = counts.float()
                total = counts.sum()
                # Neg examples / positive examples *per class*
                class_weights = (total - counts) / counts
                # Default zero, but doesn't matter: a class without positive
                # examples has no positive term to be weighted.
                pos_weight = torch.zeros(self._resolve_num_classes(logits),
                                         device=class_weights.device)
                pos_weight[values] = class_weights

        criterion = torch.nn.BCEWithLogitsLoss(reduction='mean',
                                               pos_weight=pos_weight)
        targets = torch.nn.functional.one_hot(labels,
                                              num_classes=logits.size(1))
        return criterion(logits, targets.float())

    def fit(self, logits, labels):
        """ Gaussian fitting of the thresholds per class.

        To be called on the full training set after actual training,
        but before evaluation!

        Replaces ``self.threshold`` by a tensor of per-class thresholds
        ``max(min_threshold, 1 - alpha * SD_i)`` and returns ``self``.
        Does nothing (besides printing a warning) if ``reduce_risk`` is False.
        """
        if not self.reduce_risk:
            print("[DOC/warning] fit() called but reduce_risk is False. Pass.")
            return self

        y = logits.detach().sigmoid()  # [num_examples, num_classes]
        # TODO: extend to online variant by computing *rolling* std. dev.?

        num_classes = self._resolve_num_classes(logits)
        # Classes that never occur in `labels` keep SD 0 (threshold 1.0)
        std_per_class = torch.zeros(num_classes, device=logits.device)
        for i in labels.unique():
            # posterior "probabilities" p(y=l_i | x_j, y_j = l_i):
            # filter rows for y_j == l_i, take column i
            y_i = y[labels == i, i]  # [num_examples_of_class_i]

            # for each existing point,
            # create a mirror point (not a probability),
            # mirrored on the mean of 1
            y_i_mirror = 1 + (1 - y_i)  # [num_examples_of_class_i]

            # estimate the standard deviation per class
            # using both existing and the created points
            y_i_all = torch.cat([y_i, y_i_mirror], dim=0)

            # TODO: unbiased SD? orig work did not specify...
            std_per_class[i] = y_i_all.std(dim=0, unbiased=True)  # scalar

        print("SD per class:\n", std_per_class)

        # Set the probability threshold t_i = max(0.5, 1 - alpha * SD_i)
        # Orig paper uses base threshold 0.5,
        # but we use a specified minimum threshold
        thresholds_per_class = (1 - self.alpha * std_per_class).clamp(
            min=self.min_threshold)
        self.threshold = thresholds_per_class  # [num_classes]
        print("Updated thresholds:\n", self.threshold)
        return self

    def reject(self, logits, subset=None):
        """Return a boolean mask, true where *all* (selected) class
        probabilities of an example lie below their threshold.

        ``subset``, if given, indexes the class columns to consider.
        """
        with torch.no_grad():
            if subset is not None:
                logits = logits[:, subset]
            # Reduce view on thresholds if subset is given,
            # AND if self.threshold is not just a float
            if subset is not None and not isinstance(self.threshold, float):
                threshold = self.threshold[subset]
            else:
                threshold = self.threshold
            y_proba = logits.sigmoid()
            # Dim1 is reduced by 'all' anyways, no mapping back needed
            reject_mask = (y_proba < threshold).all(dim=1)
        return reject_mask

    def predict(self, logits, subset=None):
        """Return the argmax class index per example.

        NOTE: when ``subset`` is given, the returned indices refer to
        positions within ``subset``; they are not mapped back to the
        original class ids here.
        """
        with torch.no_grad():
            if subset is not None:
                print(f"Reducing view to {len(subset)} known classes")
                logits = logits[:, subset]
            print("Logits\n", logits)
            y_proba = logits.sigmoid()
            print("Logits after sigmoid\n", y_proba)
            # Basic argmax
            __max_vals, max_indices = torch.max(y_proba, dim=1)
        return max_indices
class OpenMax(OpenLearning):
    """Placeholder for an OpenMax variant; defines nothing of its own yet."""
##########################
# Module-level functions #
##########################
def add_args(parser):
    """Register the open-learning command line options on ``parser``.

    Adds ``--open_learning``, ``--doc_threshold``, ``--doc_reduce_risk``,
    ``--doc_alpha`` and ``--doc_class_weights``. The parser is modified
    in place; nothing is returned.
    """
    parser.add_argument('--open_learning', default=None,
                        help="Method for self detection of unseen classes",
                        choices=["doc"])
    parser.add_argument('--doc_threshold', default=0.5, type=float,
                        help="Threshold for DOC")
    parser.add_argument('--doc_reduce_risk',
                        default=False, action='store_true',
                        help="Reduce Open Space Risk by Gaussian-fitting")
    # type=float so that a value given on the command line is parsed to a
    # number (like the default) instead of being passed on as a string.
    parser.add_argument('--doc_alpha', default=3.0, type=float,
                        help="Alpha for DOC")
    parser.add_argument('--doc_class_weights', default=False,
                        action='store_true',
                        help="Use class weights against class imbalance")
def build(args, num_classes=None):
    """Create the open-learning module selected by ``args.open_learning``.

    Only ``"doc"`` is available; ``"openmax"`` and any other key raise
    ``NotImplementedError``.
    """
    method = args.open_learning

    # Handle the unsupported keys first, the supported one last.
    if method == "openmax":
        raise NotImplementedError("OpenMax not yet implemented")
    if method != "doc":
        raise NotImplementedError(f"Unknown key: {args.open_learning}")

    doc_options = {
        "threshold": args.doc_threshold,
        "reduce_risk": args.doc_reduce_risk,
        "alpha": args.doc_alpha,
        "num_classes": num_classes,
        "use_class_weights": args.doc_class_weights,
    }
    return DeepOpenClassification(**doc_options)
def bool2pmone(x):
    """Map a boolean mask to an int array: True becomes 1, False becomes -1."""
    as_int = np.asarray(x, dtype=int)
    return 2 * as_int - 1
def evaluate(labels, unseen_classes,
             predictions, reject_mask):
    """Score rejection of unseen classes and open-set classification.

    Arguments
    ---------
    labels: True class ids per example (moved to CPU and converted to numpy)
    unseen_classes: Iterable of class ids that count as unseen
    predictions: Predicted class ids per example
    reject_mask: Per-example mask, truthy where the example was rejected

    Returns
    -------
    dict with 'open_mcc' (Matthews correlation between true and predicted
    rejection), 'open_f1_macro' (macro F1 where unseen labels and rejected
    predictions are both replaced by -100), and the rejection counts
    'open_tp', 'open_tn', 'open_fp', 'open_fn'.
    """
    # Shift stuff to CPU
    labels = labels.cpu()
    predictions = predictions.cpu()
    reject_mask = reject_mask.cpu()

    # Copy because we will later insert -100 for unseen
    labels = labels.clone().numpy()
    predictions = predictions.clone().numpy()

    unseen = list(unseen_classes)
    # Force a boolean dtype: `~` and mask indexing below are only correct
    # for booleans. A 0/1 integer mask would otherwise be used as integer
    # indices in `predictions[reject_mask] = -100`.
    reject_mask = np.asarray(reject_mask).astype(bool)

    print("Labels", labels)
    print("Unseen", unseen)
    true_reject = np.isin(labels, unseen)

    print("True reject", true_reject)
    print("Reject mask", reject_mask)

    print("False in true_reject:", False in true_reject)
    print("True in true_reject:", True in true_reject)
    print("False in reject_mask:", False in reject_mask)
    print("True in reject_mask:", True in reject_mask)

    # Confusion counts of the binary "reject or not" decision
    tp = (reject_mask & true_reject).sum()
    tn = (~reject_mask & ~true_reject).sum()
    fp = (reject_mask & ~true_reject).sum()
    fn = (~reject_mask & true_reject).sum()

    # MCC
    mcc = matthews_corrcoef(bool2pmone(true_reject),
                            bool2pmone(reject_mask))

    # Open F1 Macro: all unseen classes / rejections share the id -100
    labels[true_reject] = -100
    print("True lables with -100 for unseen:", labels, labels.shape)
    predictions[reject_mask] = -100
    print("Predictions including rejected:", predictions, predictions.shape)
    f1_macro = f1_score(labels, predictions, average='macro')

    return {'open_mcc': mcc, 'open_f1_macro': f1_macro,
            'open_tp': tp, 'open_tn': tn, 'open_fp': fp, 'open_fn': fn}