-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtrain_infogain_pvi.py
144 lines (120 loc) · 5.19 KB
/
train_infogain_pvi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import os
import transformers
from datasets import load_dataset, load_metric
import json
from transformers import AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer, EarlyStoppingCallback
import numpy as np
from transformers.trainer_utils import get_last_checkpoint
import torch
import shutil
# ---------------------------------------------------------------------------
# Global configuration for fine-tuning T5-large on the InfoGain PVI data.
# ---------------------------------------------------------------------------
max_input_length = 512     # token budget for the encoder input
max_target_length = 64     # token budget for the decoder target
padding = "max_length"     # pad every example up to the fixed lengths above
model_name = "t5-large"
label_pad_token_id = -100  # index ignored by the seq2seq cross-entropy loss
pad_token = '<pad>'        # prepended to decoder inputs in postprocess_test_data

# NOTE: the three calls below download/load pretrained assets as a module
# side effect (network access on first run).
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, config=config)

batch_size = 16
output_dir = 'PVI/infogain_models/norepNLnomarkov'
do_train = True
do_eval = True
do_predict = True
# When True, preprocessing reads the 'prev_inputs' field instead of 'inputs'.
# (The original had a module-level `global no_input` statement, which is a
# no-op at module scope and has been removed.)
no_input = False
overwrite_output_dir = True
def postprocess_test_data(examples):
    """Tokenize a batch for teacher-forced scoring.

    Builds encoder inputs plus explicit ``decoder_input_ids`` and
    ``decoder_attention_mask``; each target is prefixed with the pad token
    (T5 decoding starts from ``<pad>``). Returns PyTorch tensors.
    """
    # Choose the source field according to the module-level `no_input` flag.
    source_field = 'prev_inputs' if no_input else 'inputs'
    source_texts = [prefix + text for text in examples[source_field]]
    model_inputs = tokenizer(
        source_texts,
        max_length=max_input_length,
        padding=padding,
        truncation=True,
        return_tensors="pt",
    )
    # Tokenize the targets under target-side tokenizer settings.
    with tokenizer.as_target_tokenizer():
        target_texts = [pad_token + label for label in examples['labels']]
        tokenized_targets = tokenizer(
            target_texts,
            max_length=max_target_length,
            padding=padding,
            truncation=True,
            return_tensors="pt",
        )
    model_inputs["decoder_input_ids"] = tokenized_targets["input_ids"]
    model_inputs["decoder_attention_mask"] = tokenized_targets["attention_mask"]
    return model_inputs
def preprocess_data(examples):
    """Tokenize a batch for training.

    Builds encoder inputs and a ``labels`` field in which padding positions
    are replaced by -100 so the loss skips them.
    """
    # Choose the source field according to the module-level `no_input` flag.
    source_field = 'prev_inputs' if no_input else 'inputs'
    model_inputs = tokenizer(
        [prefix + text for text in examples[source_field]],
        max_length=max_input_length,
        padding=padding,
        truncation=True,
    )
    # Tokenize the targets under target-side tokenizer settings.
    with tokenizer.as_target_tokenizer():
        tokenized_labels = tokenizer(
            examples["labels"],
            max_length=max_target_length,
            padding=padding,
            truncation=True,
        )
    # Mask out pad tokens so the cross-entropy loss ignores those positions.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [-100 if token_id == pad_id else token_id for token_id in row]
        for row in tokenized_labels["input_ids"]
    ]
    return model_inputs
def compute_metrics(eval_pred):
    """Decode predictions/references and return ROUGE F1 scores as
    percentages rounded to two decimals (keyed by ROUGE variant name)."""
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # -100 is not a valid token id; map it back to the pad id before decoding.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    rouge_scores = metric.compute(
        predictions=decoded_preds, references=decoded_labels, use_stemmer=True
    )
    # Keep the mid f-measure of each ROUGE variant, as a percentage.
    return {
        name: round(score.mid.fmeasure * 100, 2)
        for name, score in rouge_scores.items()
    }
# Task prefix prepended to every encoder input.
prefix = 'Generate final answer: '

# Load the train/dev/test splits from JSON (records live under the "data" field).
dataset = load_dataset(
    'json',
    data_files={
        'train': 'PVI/IG_norepNLnomarkov_tree/train.json',
        'dev': 'PVI/IG_norepNLnomarkov_tree/dev.json',
        'test': 'PVI/IG_norepNLnomarkov_tree/test.json',
    },
    field="data",
)
tokenized_dataset = dataset.map(preprocess_data, batched=True)
# predict_dataset = dataset['test'].map(postprocess_test_data, batched=True)

args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    evaluation_strategy="epoch",
    logging_strategy="epoch",
    save_strategy="epoch",
    learning_rate=3e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=1,          # keep only the most recent checkpoint on disk
    num_train_epochs=10,
    predict_with_generate=True,  # generate sequences so ROUGE can be computed
    fp16=True,
    load_best_model_at_end=True,
    metric_for_best_model="eval_rougeL",
    overwrite_output_dir=overwrite_output_dir,
)

# Collator pads labels with -100 so padded positions are excluded from the loss.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, label_pad_token_id=label_pad_token_id)
metric = load_metric("rouge")
# Trainer with early stopping: halt when eval_rougeL fails to improve
# for 3 consecutive evaluations.
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["dev"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)
if do_train:
    # Resume from the newest checkpoint in output_dir when one exists;
    # get_last_checkpoint returns None if the directory holds no checkpoint.
    resume_point = get_last_checkpoint(output_dir) if os.path.isdir(output_dir) else None
    train_result = trainer.train(resume_from_checkpoint=resume_point)
    trainer.save_model()  # also saves the tokenizer for easy upload
    train_metrics = train_result.metrics
    trainer.log_metrics("train", train_metrics)
    trainer.save_metrics("train", train_metrics)
    trainer.save_state()
# Evaluation on the dev split, decoding with beam search (8 beams).
results = {}
if do_eval:
    eval_metrics = trainer.evaluate(
        max_length=max_target_length, num_beams=8, metric_key_prefix="eval"
    )
    trainer.log_metrics("eval", eval_metrics)
    trainer.save_metrics("eval", eval_metrics)
# Prediction on the held-out test split.
if do_predict:
    # BUG FIX: the original called
    #   trainer.predict(tokenized_dataset['test'], dataset['test'])
    # but Trainer.predict's second positional parameter is `ignore_keys`
    # (a list of model-output keys to drop), not a dataset. Only the
    # tokenized test set belongs here.
    results = trainer.predict(tokenized_dataset['test'])
    predict_metrics = results.metrics
    trainer.log_metrics("predict", predict_metrics)
    trainer.save_metrics("predict", predict_metrics)

# Delete the remaining checkpoint directory to reclaim disk space
# (save_model above already wrote the final model to output_dir).
last_checkpoint = get_last_checkpoint(output_dir)
if last_checkpoint is not None:  # rmtree(None) would raise TypeError
    shutil.rmtree(last_checkpoint)