-
Notifications
You must be signed in to change notification settings - Fork 0
/
Model.py
123 lines (99 loc) · 4.65 KB
/
Model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from models.SimpleModel import LogisticRegression, NaiveBayes, RandomForest
from models.SimpleModel import SVM, RandomPredictor, DecisionTree
from models.MLP import MLP
from models.LanguageResNet import LanguageResNet
from models.ImageResNet import ImageResNet
from models.DistilBert import DistilBertClassifier
from models.ViT import ViTClassifier
import numpy as np
class Model:
    """Uniform front end over the project's classifier implementations.

    ``model_name`` selects one implementation: either a config-only model from
    ``models.SimpleModel`` or a model that is saved to ``SAVE_DIR`` during
    training (MLP, LanguageResNet, ImageResNet, DistilBert, ViT).  ``train``,
    ``predict_proba``, ``predict`` and ``load`` delegate to it.  When
    ``config['calib_frac'] == 1.0`` there is no training set, so the
    prediction methods return uninformative outputs instead of calling the
    wrapped model.
    """

    # Built from ``config`` alone.  Stored as strings so the class body does
    # not need the model classes at definition time.
    _SIMPLE_MODEL_NAMES = ('LogisticRegression', 'NaiveBayes', 'RandomForest',
                           'SVM', 'RandomPredictor', 'DecisionTree')
    # Built with SAVE_DIR and saved during training.
    _SAVED_MODEL_NAMES = ('MLP', 'LanguageResNet', 'ImageResNet', 'DistilBert', 'ViT')
    # The only models that accept save_scheme='all-epochs'.
    _ALL_EPOCH_MODEL_NAMES = ('DistilBert', 'ViT', 'LanguageResNet')
    _SAVE_SCHEMES = ('best-val-acc', 'all-epochs')

    def __init__(self, model_name, SAVE_DIR=None, config=None, from_saved=False, **kwargs):
        """Construct the wrapped model.

        Args:
            model_name: A name from ``_SIMPLE_MODEL_NAMES`` or
                ``_SAVED_MODEL_NAMES``.
            SAVE_DIR: Directory the model is saved to; required for the
                saved models.
            config: Settings dict.  Must contain ``'calib_frac'``; saved models
                also read ``'val_save_epoch'`` and ``'epochs'``.
            from_saved: Forwarded to the saved-model constructors.
            **kwargs: Optional ``dataset_obj`` (forwarded to LanguageResNet
                only, default None), ``saved_epoch`` (forwarded to DistilBert
                and ViT only, default 0) and ``save_scheme``
                (``'best-val-acc'`` by default, or ``'all-epochs'``).

        Raises:
            TypeError: ``config`` is None.
            AssertionError: ``save_scheme`` is invalid, or ``'all-epochs'`` is
                requested for a model that does not support it.
            ValueError: ``model_name`` is unknown, ``SAVE_DIR`` is missing for
                a saved model, or ``val_save_epoch`` lies past the last epoch.
        """
        if config is None:
            # Every model reads config; fail clearly rather than with
            # "'NoneType' object is not subscriptable".
            raise TypeError('Model requires a config dict (got None).')
        self.name = model_name
        self.SAVE_DIR = SAVE_DIR
        self.config = config
        self.calib_frac = config['calib_frac']
        dataset_obj = kwargs.get('dataset_obj')
        saved_epoch = kwargs.get('saved_epoch', 0)
        save_scheme = kwargs.get('save_scheme', 'best-val-acc')

        # Explicit raises instead of ``assert`` so these checks survive
        # ``python -O``; AssertionError is kept so callers see the same type.
        if save_scheme not in self._SAVE_SCHEMES:
            raise AssertionError(f'Unknown save_scheme {save_scheme!r}; '
                                 f'expected one of {self._SAVE_SCHEMES}.')
        if save_scheme != 'best-val-acc' and self.name not in self._ALL_EPOCH_MODEL_NAMES:
            raise AssertionError(f'all-epoch saving not supported for {self.name} model.')

        if self.name in self._SIMPLE_MODEL_NAMES:
            self.model = self._build_simple_model()
        elif self.name in self._SAVED_MODEL_NAMES:
            # Only recognised saved models reach these checks, so an unknown
            # name is no longer misreported as a missing save directory.
            self._check_saved_model_settings()
            self.model = self._build_saved_model(from_saved, dataset_obj,
                                                 saved_epoch, save_scheme)
        else:
            raise ValueError(f'Unknown model name: {model_name!r}')

    def _build_simple_model(self):
        """Instantiate a config-only model from ``models.SimpleModel``."""
        constructors = {
            'LogisticRegression': LogisticRegression,
            'NaiveBayes': NaiveBayes,
            'RandomForest': RandomForest,
            'SVM': SVM,
            'RandomPredictor': RandomPredictor,
            'DecisionTree': DecisionTree,
        }
        return constructors[self.name](self.config)

    def _check_saved_model_settings(self):
        """Validate the settings required by models saved during training."""
        # The model must be saved in order to be evaluated, so a directory is required.
        if self.SAVE_DIR is None:
            raise ValueError(f'Must provide a save directory for {self.name} model.')
        # Saving only happens once (# epochs elapsed) > val_save_epoch, so that
        # threshold has to be crossed before training ends.
        if self.config['val_save_epoch'] > self.config['epochs'] - 1:
            raise ValueError(('Note: val_save_epoch must be <= (# epochs - 1); ' +
                              'model only saved when (# epochs elapsed) > val_save_epoch.'))

    def _build_saved_model(self, from_saved, dataset_obj, saved_epoch, save_scheme):
        """Instantiate a saved model; each constructor takes different keywords."""
        name, save_dir, config = self.name, self.SAVE_DIR, self.config
        if name == 'MLP':
            return MLP(save_dir, config, from_saved=from_saved)
        if name == 'LanguageResNet':
            return LanguageResNet(save_dir, config, from_saved=from_saved,
                                  dataset_obj=dataset_obj, save_scheme=save_scheme)
        if name == 'ImageResNet':
            return ImageResNet(save_dir, config, from_saved=from_saved)
        if name == 'DistilBert':
            return DistilBertClassifier(save_dir, config, from_saved=from_saved,
                                        saved_epoch=saved_epoch, save_scheme=save_scheme)
        if name == 'ViT':
            return ViTClassifier(save_dir, config, from_saved=from_saved,
                                 saved_epoch=saved_epoch, save_scheme=save_scheme)
        raise ValueError(f'Unknown model name: {name!r}')

    def train(self, X_train, y_train, groups_train, X_val, y_val, groups_val):
        """Train the wrapped model; all arguments are forwarded unchanged."""
        self.model.train(X_train, y_train, groups_train, X_val, y_val, groups_val)

    def predict_proba(self, X, with_logits=False):
        """Return positive-class probabilities for ``X``.

        If ``with_logits`` is True, return ``(positive_probs, logits)`` where
        ``logits`` holds the wrapped model's logits for both classes.  With no
        training set (``calib_frac == 1.0``) every probability is 0.5, and the
        second element is that same ``(len(X), 2)`` array of 0.5s rather than
        true logits.
        """
        if self.calib_frac == 1.0:
            uniform = np.full((X.shape[0], 2), 0.5)
            return (uniform[:, 1], uniform) if with_logits else uniform[:, 1]
        out = self.model.predict_proba(X, with_logits)
        if with_logits:
            # The wrapped model returns (class probabilities, logits) here.
            probs, logits = out[0], out[1]
            return probs[:, 1], logits
        return out[:, 1]

    def predict(self, X):
        """Return hard class labels for ``X``.

        With no training set (``calib_frac == 1.0``) each label is drawn
        uniformly from {0, 1} using NumPy's global random state.
        """
        if self.calib_frac == 1.0:
            return np.random.choice([0, 1], size=X.shape[0])
        return self.model.predict(X)

    def load(self):
        """Reload the wrapped model's saved state; only MLP is handled here.

        NOTE(review): the other saved models are not reloaded by this method;
        presumably they restore via ``from_saved`` in their constructors. Confirm.
        """
        if self.name == 'MLP':
            self.model.load()