GatewayTokenClassifier.py

#!/usr/bin/env python3
# add parent dir to sys path for import of modules
import os
import sys

# find the project root dir by walking up until a README.md is found
# (os.getcwd() instead of os.getcwdb(): the bytes variant wrapped in str()
#  yields "b'...'" and breaks the path checks below)
parent_dir = os.getcwd()
while not os.path.exists(os.path.join(parent_dir, "README.md")):
    parent_dir = os.path.abspath(os.path.join(parent_dir, os.pardir))
sys.path.insert(0, parent_dir)

import argparse
import logging
from typing import List

import numpy as np
import tensorflow as tf  # needed for tf.keras.Model and the loss/metrics below
import transformers

from Ensemble import Ensemble
from token_approaches.token_data_preparation import create_token_cls_dataset_full, create_token_cls_dataset_cv
from training import cross_validation, full_training
from token_approaches.metrics import *
from utils import config, generate_args_logdir, set_seeds

logger = logging.getLogger('Gateway Token Classifier')

parser = argparse.ArgumentParser()

# Standard params
parser.add_argument("--batch_size", default=8, type=int, help="Batch size.")
parser.add_argument("--epochs", default=1, type=int, help="Number of epochs.")
parser.add_argument("--seed_general", default=42, type=int, help="Random seed.")
parser.add_argument("--seeds_ensemble", default="0-1", type=str, help="Random seed range to use for ensembles.")

# Routine params
parser.add_argument("--routine", default="cv", type=str, help="Cross validation 'cv' or "
                                                              "full training without validation 'ft'.")
parser.add_argument("--folds", default=2, type=int, help="Number of folds in the cross validation routine.")
parser.add_argument("--store_weights", default=False, type=bool, help="Flag if best weights should be stored.")

# Architecture / data params
parser.add_argument("--ensemble", default=True, type=bool, help="Use ensemble learning with config.json seeds.")
parser.add_argument("--labels", default=ALL, type=str, help="Label set to use.")
parser.add_argument("--other_labels_weight", default=0.1, type=float, help="Sample weight for non-gateway tokens.")
parser.add_argument("--sampling_strategy", default=NORMAL, type=str, help="How to sample samples.")
parser.add_argument("--use_synonyms", default=False, type=bool, help="Include synonym samples.")
parser.add_argument("--activity_masking", default=NOT, type=str, help="How to include activity data.")


class GatewayTokenClassifier(tf.keras.Model):

    def __init__(self, args: argparse.Namespace, bert_model=None, train_size: int = None,
                 weights_path: str = None) -> None:
        """
        creates a GatewayTokenClassifier
        :param args: args Namespace
        :param bert_model: BERT-like transformer token classification model
        :param train_size: train dataset size
        :param weights_path: path of stored weights; if set, load weights from there
        """
        logger.info("Create and initialize a GatewayTokenClassifier")

        # A) ARCHITECTURE
        inputs = {
            "input_ids": tf.keras.layers.Input(shape=[None], dtype=tf.int32),
            "attention_mask": tf.keras.layers.Input(shape=[None], dtype=tf.int32)
        }
        # the head of the following model is randomly initialized by the seed:
        # - in case of a single model, the seed is set at the beginning of the script
        # - in case of a model in an ensemble, the seed is set before this constructor call
        if not bert_model:
            bert_model = transformers.TFAutoModelForTokenClassification.from_pretrained(
                config[KEYWORDS_FILTERED_APPROACH][BERT_MODEL_NAME],
                num_labels=config[KEYWORDS_FILTERED_APPROACH][LABEL_NUMBER])
        # includes one dense layer with linear activation function
        predictions = bert_model(inputs).logits
        super().__init__(inputs=inputs, outputs=predictions)
        # attributes may only be set after super().__init__ on a tf.keras.Model
        self.weights_path = weights_path

        # B) COMPILE (only needed when training is intended)
        if args and train_size:
            optimizer, lr_schedule = transformers.create_optimizer(
                init_lr=2e-5,
                num_train_steps=(train_size // args.batch_size) * args.epochs,
                weight_decay_rate=0.01,
                num_warmup_steps=0,
            )
            self.compile(optimizer=optimizer,
                         loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                         # general accuracy over all labels (except the 0 class for padding tokens)
                         weighted_metrics=[tf.metrics.SparseCategoricalAccuracy(name="overall_accuracy")],
                         # metrics for the classes of interest
                         metrics=[xor_precision, xor_recall, xor_f1, and_recall, and_precision, and_f1])

        # if a weights path is passed, restore the weights
        if self.weights_path:
            self.load_weights(weights_path)
            logger.info(f"Restored weights from {weights_path}")

    def predict(self, tokens: transformers.BatchEncoding) -> np.ndarray:
        """
        create predictions for given data
        :param tokens: tokens as BatchEncoding
        :return: numpy array of predictions
        """
        return super().predict({"input_ids": tokens["input_ids"], "attention_mask": tokens["attention_mask"]})
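

# Illustrative usage (a sketch, not part of the training pipeline; the tokenizer
# checkpoint is read from config.json, so the exact model name is an assumption):
# >>> tokenizer = transformers.AutoTokenizer.from_pretrained(
# ...     config[KEYWORDS_FILTERED_APPROACH][BERT_MODEL_NAME])
# >>> tokens = tokenizer(["Either ship the goods or cancel the order"],
# ...                    return_tensors="tf", padding=True)
# >>> model = GatewayTokenClassifier(args=None, weights_path=None)
# >>> logits = model.predict(tokens)  # shape: (batch, sequence_length, num_labels)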


class GTCEnsemble(Ensemble):
    """
    Ensemble (seeds) of token classifier models
    """

    def __init__(self, seeds: List = None, ensemble_path: str = None, es_monitor: str = 'val_loss',
                 seed_limit: int = None, **kwargs) -> None:
        """
        see super class for param description
        override to fix the model class
        """
        super().__init__(GatewayTokenClassifier, seeds, ensemble_path, es_monitor, seed_limit, **kwargs)

    def predict(self, tokens: transformers.BatchEncoding) -> np.ndarray:
        """
        create predictions for the given data with each model and average the results on the token axis
        :param tokens: tokens as BatchEncoding
        :return: numpy array of averaged predictions
        """
        predictions = [model.predict(tokens=tokens) for model in self.models]
        return np.mean(predictions, axis=0)
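

# Sketch of the averaging step above (hypothetical values): the per-seed logits all
# share the shape (batch, seq_len, num_labels) and are averaged elementwise.
# >>> p0 = np.array([[[0.2, 0.8]]])  # member trained with seed 0
# >>> p1 = np.array([[[0.6, 0.4]]])  # member trained with seed 1
# >>> np.mean([p0, p1], axis=0)
# array([[[0.4, 0.6]]])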


def convert_predictions_into_labels(predictions: np.ndarray, word_ids: List[List[int]]) -> List[List[int]]:
    """
    convert predictions for every token (logits) into a list of labels for each sample
    :param predictions: logits as np.ndarray
    :param word_ids: original word ids of tokens
    :return: list of labels for each sample
    """
    converted_results = []  # list (for each sample): a dict with word_id -> predicted class(es)

    for i, sample in enumerate(predictions):
        important_token_pairs = [(token_index, word_id) for token_index, word_id in enumerate(word_ids[i])
                                 if word_id is not None]
        converted_sample = {}

        # store prediction(s) for every original word
        for token_index, word_id in important_token_pairs:
            token_prediction = np.argmax(sample[token_index])
            if word_id not in converted_sample:
                converted_sample[word_id] = [token_prediction]
            else:
                converted_sample[word_id].append(token_prediction)

        # merge predictions (necessary if multiple BERT tokens are mapped to one input word)
        for word_id, token_predictions in converted_sample.items():
            token_predictions = list(set(token_predictions))
            # if different labels were predicted, take the highest => 3 (AND) > 2 (XOR) > 1 (other)
            if len(token_predictions) > 1:
                token_predictions.sort(reverse=True)
                token_predictions = token_predictions[:1]
            converted_sample[word_id] = token_predictions[0]

        # sort by word index to be independent of dict insertion order
        converted_sample = [(idx, label) for idx, label in converted_sample.items()]
        converted_sample.sort(key=lambda idx_label_pair: idx_label_pair[0])

        # reduce to an ordered list of labels
        converted_sample = [label for idx, label in converted_sample]
        converted_results.append(converted_sample)
    return converted_results
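

# Worked example for the merge logic above (hypothetical logits): one input word is
# split into two BERT subtokens whose argmax predictions disagree (2 = XOR, 3 = AND),
# so the higher label wins after deduplication.
# >>> preds = np.array([[[0.1, 0.1, 0.7, 0.1],    # subtoken 0 -> argmax 2 (XOR)
# ...                    [0.1, 0.1, 0.1, 0.7]]])  # subtoken 1 -> argmax 3 (AND)
# >>> convert_predictions_into_labels(preds, [[0, 0]])
# [[3]]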


def train_routine(args: argparse.Namespace) -> None:
    """
    run GatewayTokenClassifier training based on the passed args
    :param args: namespace args
    :return:
    """
    if args.labels == 'filtered':
        args.num_labels = 4
    elif args.labels == 'all':
        args.num_labels = 9
    else:
        raise ValueError(f"args.labels must be 'filtered' or 'all' and not '{args.labels}'")
    logger.info(f"Use {args.labels} labels ({args.num_labels})")
    logger.info(args)

    # load the model
    logger.info(f"Load transformer model and tokenizer ({config[KEYWORDS_FILTERED_APPROACH][BERT_MODEL_NAME]})")
    token_cls_model = transformers.TFAutoModelForTokenClassification.from_pretrained(
        config[KEYWORDS_FILTERED_APPROACH][BERT_MODEL_NAME], num_labels=args.num_labels)

    # cross validation
    if args.routine == 'cv':
        folded_datasets = create_token_cls_dataset_cv(args)
        cross_validation(args, GatewayTokenClassifier, folded_datasets, token_cls_model)
    # full training without validation
    elif args.routine == 'ft':
        train = create_token_cls_dataset_full(args)
        full_training(args, GatewayTokenClassifier, train, token_cls_model)
    else:
        raise ValueError(f"Invalid training routine: {args.routine}")


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    args = parser.parse_args([] if "__file__" not in globals() else None)
    args.logdir = generate_args_logdir(args, script_name="GatewayTokenClassifier")
    # this seed is used by default (only overwritten in case of an ensemble)
    set_seeds(args.seed_general, "args - used for dataset split/shuffling")
    train_routine(args)