# TabCaps_Task.py
import torch
import numpy as np
from scipy.special import softmax
from lib.utils import PredictDataset
from abstract_model import TabCapsModel
from lib.multiclass_utils import infer_output_dim, check_output_dim
from torch.utils.data import DataLoader
from torch.nn.functional import cross_entropy, mse_loss


class TabCapsClassifier(TabCapsModel):
    """TabCaps model wrapper for classification tasks (binary and multi-class)."""

    def __post_init__(self):
        super(TabCapsClassifier, self).__post_init__()
        self._task = 'classification'
        self._default_loss = cross_entropy
        self._default_metric = 'logloss'

    def weight_updater(self, weights):
        """
        Updates weights dictionary according to target_mapper.

        Parameters
        ----------
        weights : bool or dict
            Given weights for balancing training.

        Returns
        -------
        bool or dict
            Same bool if weights are bool, updated dict otherwise.
        """
        # bool is a subclass of int, so boolean flags pass through unchanged
        if isinstance(weights, int):
            return weights
        elif isinstance(weights, dict):
            # remap the user-provided class labels to their integer indices
            return {self.target_mapper[key]: value for key, value in weights.items()}
        else:
            return weights
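
    # Illustrative example (labels and values are hypothetical): if target_mapper
    # is {'cat': 0, 'dog': 1}, then weight_updater({'cat': 0.3, 'dog': 0.7})
    # returns {0: 0.3, 1: 0.7}, while a bool/int flag such as weight_updater(1)
    # is returned as-is.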

    def prepare_target(self, y):
        # encode raw class labels as the contiguous integer indices used during training
        return np.vectorize(self.target_mapper.get)(y)

    def compute_loss(self, y_pred, y_true):
        return self.loss_fn(y_pred, y_true)

    def update_fit_params(self, X_train, y_train, eval_set):
        output_dim, train_labels = infer_output_dim(y_train)
        # every evaluation set must only contain labels seen in the training data
        for X, y in eval_set:
            check_output_dim(train_labels, y)
        self.output_dim = output_dim
        self._default_metric = 'auc' if self.output_dim == 2 else 'accuracy'
        self.classes_ = train_labels
        # label -> index mapping for encoding targets, index -> label for decoding predictions
        self.target_mapper = {
            class_label: index for index, class_label in enumerate(self.classes_)
        }
        self.preds_mapper = {
            str(index): class_label for index, class_label in enumerate(self.classes_)
        }

    def stack_batches(self, list_y_true, list_y_score):
        y_true = np.hstack(list_y_true)
        y_score = np.vstack(list_y_score)
        # per-batch logits are stacked and converted to probabilities for metric computation
        y_score = softmax(y_score, axis=1)
        return y_true, y_score

    def predict_func(self, outputs):
        return outputs
        # return np.vectorize(self.preds_mapper.get)(outputs.astype(str))

    def predict_proba(self, X):
        """
        Make class-probability predictions on a batch of inputs.

        Parameters
        ----------
        X : torch.Tensor
            Input data

        Returns
        -------
        res : np.ndarray
            Predicted probabilities with shape (n_samples, n_classes).
        """
        self.network.eval()
        dataloader = DataLoader(
            PredictDataset(X),
            batch_size=self.batch_size,
            shuffle=False,
        )
        results = []
        for batch_nb, data in enumerate(dataloader):
            data = data.to(self.device).float()
            output = self.network(data)
            # softmax over the class dimension; move to CPU and detach before converting to numpy
            predictions = torch.nn.Softmax(dim=1)(output).cpu().detach().numpy()
            results.append(predictions)
        res = np.vstack(results)
        return res
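

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the module). It assumes a
# TabNet-style training interface inherited from abstract_model.TabCapsModel,
# i.e. a fit(X_train, y_train, eval_set=[(X_valid, y_valid)], ...) method and
# numpy feature matrices; check TabCapsModel for the actual signature and
# hyperparameters before running.
# ---------------------------------------------------------------------------
# if __name__ == "__main__":
#     rng = np.random.default_rng(0)
#     X_train = rng.random((256, 16)).astype(np.float32)
#     y_train = rng.integers(0, 2, size=256)
#     X_valid = rng.random((64, 16)).astype(np.float32)
#     y_valid = rng.integers(0, 2, size=64)
#
#     clf = TabCapsClassifier()  # constructor arguments depend on TabCapsModel
#     clf.fit(X_train, y_train, eval_set=[(X_valid, y_valid)])
#     proba = clf.predict_proba(X_valid)  # -> np.ndarray of shape (64, n_classes)
#     print(proba.shape)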