-
Notifications
You must be signed in to change notification settings - Fork 349
/
run_qg.py
236 lines (197 loc) · 7.72 KB
/
run_qg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
import dataclasses
import json
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Dict, List, Optional
import numpy as np
import torch
from transformers import (
AutoModelForSeq2SeqLM,
AutoTokenizer,
T5Tokenizer,
BartTokenizer,
HfArgumentParser,
DataCollator,
TrainingArguments,
set_seed,
)
from trainer import Trainer
from data_collator import T2TDataCollator
from utils import freeze_embeds, assert_not_all_frozen
# Supported values of ``ModelArguments.model_type`` mapped to the tokenizer
# class main() uses to load the tokenizer for that model family.
MODEL_TYPE_TO_TOKENIZER = dict(
    t5=T5Tokenizer,
    bart=BartTokenizer,
)

# Module-level logger; its handlers/level are configured in main() via logging.basicConfig.
logger = logging.getLogger(__name__)
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    # Passed to AutoModelForSeq2SeqLM.from_pretrained; also used for the tokenizer
    # when tokenizer_name_or_path is not given.
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    # Key into MODEL_TYPE_TO_TOKENIZER; main() rejects any other value.
    model_type: str = field(metadata={"help": "One of 't5', 'bart'"})
    tokenizer_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )
    # Forwarded to the project's Trainer as `label_smoothing`; 0 disables it.
    label_smoothing: Optional[float] = field(
        default=0.0,
        metadata={"help": "label smoothing rate, set to > 0 if you want to enable label smoothing"},
    )
    # When True, main() calls utils.freeze_embeds(model) before training.
    freeze_embeds: bool = field(
        default=False,
        metadata={"help": "Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."},
    )
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    # Loaded with torch.load in main() only when --do_train is set.
    train_file_path: str = field(
        metadata={"help": "Path for cached train dataset"},
    )
    # Loaded with torch.load in main() only when --do_eval is set.
    valid_file_path: str = field(
        metadata={"help": "Path for cached valid dataset"},
    )
    # NOTE(review): the fields below are not read by main() in this script;
    # presumably they are consumed by the data-preparation step — confirm.
    data_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path for data files"},
    )
    task: Optional[str] = field(
        default=None,
        metadata={"help": "Which task 'qa', 'qg', 'e2e_qg', 'ans_ext', 'multi'. 'multi' means 'qa', 'qg', 'ans_ext' tasks"},
    )
    qg_format: Optional[str] = field(
        default='prepend_qg_format',
        metadata={"help": "How to format inputs for question generation, 'highlight_qg_format' or 'prepend_qg_format'"},
    )
    max_source_length: Optional[int] = field(
        default=512,
        metadata={"help": "Max input length for the source text"},
    )
    max_target_length: Optional[int] = field(
        default=32,
        metadata={"help": "Max length for the target text"},
    )
def main(args_file=None):
    """Fine-tune a seq2seq model (T5 or BART) for question generation.

    Arguments are parsed into ``ModelArguments``, ``DataTrainingArguments`` and
    ``TrainingArguments``. Training and evaluation run according to
    ``training_args.do_train`` / ``training_args.do_eval``.

    Args:
        args_file: Optional path to a JSON file holding all arguments. When
            None, arguments come from the command line, and a single
            command-line argument ending in ``.json`` is treated as that path.

    Returns:
        dict: Metrics from ``trainer.evaluate()`` when evaluation ran on this
        process (``local_rank`` -1 or 0); otherwise an empty dict.

    Raises:
        ValueError: If ``model_type`` is not a supported type, or if training
            is requested into a non-empty ``output_dir`` without
            ``overwrite_output_dir``.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if (len(sys.argv) == 2 and sys.argv[1].endswith(".json")) or args_file is not None:
        # An explicit args_file wins; otherwise the single CLI argument is the JSON path.
        args_file_path = os.path.abspath(sys.argv[1]) if args_file is None else args_file
        model_args, data_args, training_args = parser.parse_json_file(json_file=args_file_path)
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Explicit check instead of `assert`, which is stripped when running `python -O`.
    if model_args.model_type not in MODEL_TYPE_TO_TOKENIZER:
        raise ValueError(
            f"model type should be 't5' or 'bart', got {model_args.model_type!r}"
        )

    # Refuse to silently overwrite previous training output.
    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )

    # Setup logging: full INFO output only on the main process (local_rank -1 or 0).
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed
    set_seed(training_args.seed)

    # Default the W&B project name, but respect a project the user already set.
    os.environ.setdefault("WANDB_PROJECT", "question-generation")

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    tokenizer_cls = MODEL_TYPE_TO_TOKENIZER[model_args.model_type]
    tokenizer = tokenizer_cls.from_pretrained(
        model_args.tokenizer_name_or_path or model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
    # Keep the embedding matrix in sync with the tokenizer vocabulary size.
    model.resize_token_embeddings(len(tokenizer))

    if model_args.freeze_embeds:
        logger.info("freezing embeddings of the model")
        freeze_embeds(model)
        assert_not_all_frozen(model)

    # Get datasets
    # NOTE(review): torch.load unpickles arbitrary Python objects; only point
    # train_file_path / valid_file_path at dataset files you created or trust.
    logger.info('loading dataset')
    train_dataset = torch.load(data_args.train_file_path) if training_args.do_train else None
    valid_dataset = torch.load(data_args.valid_file_path) if training_args.do_eval else None
    logger.info('finished loading dataset')

    # Initialize data_collator
    data_collator = T2TDataCollator(
        tokenizer=tokenizer,
        model_type=model_args.model_type,
        mode="training",
        using_tpu=training_args.tpu_num_cores is not None,
    )

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        data_collator=data_collator,
        prediction_loss_only=True,
        label_smoothing=model_args.label_smoothing,
    )

    # disable wandb console logs
    logging.getLogger('wandb.run_manager').setLevel(logging.WARNING)

    # Training
    if training_args.do_train:
        # Resume from model_name_or_path only when it is a local checkpoint directory.
        trainer.train(
            model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
        )
        trainer.save_model()
        # For convenience, we also re-save the tokenizer to the same directory,
        # so that you can share your model easily on huggingface.co/models =)
        if trainer.is_world_master():
            tokenizer.save_pretrained(training_args.output_dir)

    # Evaluation (main process only)
    results = {}
    if training_args.do_eval and training_args.local_rank in [-1, 0]:
        logger.info("*** Evaluate ***")

        eval_output = trainer.evaluate()

        output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(eval_output.keys()):
                logger.info("  %s = %s", key, str(eval_output[key]))
                writer.write("%s = %s\n" % (key, str(eval_output[key])))

        results.update(eval_output)

    return results
def _mp_fn(index):
    """Per-process entry point for ``xla_spawn`` (TPU training).

    ``index`` is the process ordinal passed by the launcher. It is unused here
    because ``main`` reads its configuration from ``sys.argv``.
    """
    del index  # unused; required by the launcher's calling convention
    main()
def run_qg(args_dict):
    """Run :func:`main` with arguments supplied as a dictionary.

    The dictionary is serialized to a temporary JSON file, which ``main`` then
    parses through ``HfArgumentParser.parse_json_file``. The file is removed
    afterwards, even if training fails.

    Args:
        args_dict: Mapping of argument names (fields of ``ModelArguments``,
            ``DataTrainingArguments`` and ``TrainingArguments``) to
            JSON-serializable values.

    Returns:
        dict: The evaluation results returned by ``main`` (empty when
        evaluation did not run).
    """
    import tempfile

    # Use a unique temp file rather than a fixed "args.json" in the current
    # directory, so an existing file there is not clobbered and concurrent runs
    # do not race on the same path.
    fd, args_file = tempfile.mkstemp(suffix=".json")
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(args_dict, f)
        return main(args_file=args_file)
    finally:
        os.remove(args_file)
if __name__ == "__main__":
    # Script entry point: arguments come from the command line or a single JSON file path.
    main()