tune_transformer.py
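"""Fine-tune GPT-2 on the caption text collected for a hashtag and generate
new captions from the fine-tuned model.

Summary of the code below: finetune(tag) trains GPT-2 on
./text/training_text/<tag>.txt and saves the result to ./trained_models/<tag>/;
generate_captions(...) samples captions from that saved model and writes them
to ./text/generated_text/<tag>_gen.txt. See main() for the command-line flags.
"""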
import argparse
import os
import sys

import torch
from download import write_line_by_line
from transformers import (
    AutoConfig, AutoModelWithLMHead, AutoTokenizer,
    DataCollatorForLanguageModeling, set_seed,
    TextDataset, Trainer, TrainingArguments)

# GPT-2 tokenizer shared by fine-tuning and generation; the script assumes a
# CUDA-capable GPU is available.
tokenizer = AutoTokenizer.from_pretrained('gpt2')
device = torch.device('cuda')


def finetune(tag):
    """Fine-tune GPT-2 on the caption dataset for the given tag."""
    global tokenizer
    config = AutoConfig.from_pretrained('gpt2')
    model = AutoModelWithLMHead.from_pretrained('gpt2', config=config)
    block_size = tokenizer.max_len
    # https://github.com/huggingface/transformers/blob/448c467256332e4be8c122a159b482c1ef039b98/src/transformers/data/datasets/language_modeling.py
    try:
        train_dataset = TextDataset(
            tokenizer=tokenizer, file_path=f'./text/training_text/{tag}.txt',
            block_size=block_size, overwrite_cache=True)
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
        epochs = 8
        training_args = TrainingArguments(
            output_dir='logging/output',
            overwrite_output_dir=True,
            do_train=True,
            num_train_epochs=epochs,
            gradient_accumulation_steps=1,
            learning_rate=1e-4,
            per_gpu_train_batch_size=1,
            logging_steps=50,
            save_steps=0)
        set_seed(training_args.seed)
        trainer = Trainer(
            model=model,
            args=training_args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            prediction_loss_only=True)
        # redirect stdout so the training progress is written to a per-tag log file
        with open(f'./logging/training_stats/training_{tag}.log', 'w') as log:
            sys.stdout = log
            trainer.train()
        sys.stdout = sys.__stdout__
        if not os.path.exists(f'./trained_models/{tag}/'):
            os.makedirs(f'./trained_models/{tag}/')
        # save the fine-tuned model so generate_captions() can load it later
        model.save_pretrained(f'./trained_models/{tag}/')
        print('Done!')
    except AssertionError:
        # TextDataset asserts that the input file exists
        print(f'The training text with the tag = {tag} does not exist. No model was trained!')


# TODO: captions can always be cleaned/scored better
def generate_captions(tag, prompt, max_length, min_length, num_return_sequences):
    """Generate captions from our fine-tuned model."""
    def clean_token(text):
        """Handle the edge case where the endoftext token (or a truncated piece of it) is left in the generated text."""
        token = '<|endoftext|>'
        while len(token) > 1:
            text = text.replace(token, '')
            token = token[:-1]
        text = text.strip()
        # drop a dangling quotation mark left by truncation
        if text and text[-1] == '"' and text.count('"') % 2:
            text = text[:-1]
        return text.strip()
    try:
        model = AutoModelWithLMHead.from_pretrained(f'./trained_models/{tag}/').to(device)
        encoded_sentence = tokenizer.encode(prompt, add_special_tokens=False, return_tensors='pt').to(device)
        # https://huggingface.co/transformers/_modules/transformers/modeling_utils.html#PreTrainedModel.generate
        output_sequences = model.generate(
            input_ids=encoded_sentence,
            max_length=max_length,
            min_length=min_length,
            temperature=1.,
            top_p=0.95,
            do_sample=True,
            num_return_sequences=num_return_sequences)
        stop_token = '\n'
        generated_sequences = []
        for generated_sequence in output_sequences:
            text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
            if stop_token in text:
                # keep only the text up to the first newline
                text = text[: text.find(stop_token)]
            # + 2 because it could be punctuation or emojis or both
            if len(text) > (len(prompt) + 2):
                generated_sequences.append(clean_token(text))
        # remove duplicates
        generated_sequences = list(set(generated_sequences))
        # prefix each caption so the output file is easier to scan
        generated_text = '\nCAPTION: '.join(generated_sequences)
        generated_text = 'CAPTION: ' + generated_text
        write_line_by_line(f'./text/generated_text/{tag}_gen.txt', generated_text)
        print(f'Writing captions: /Hugging-Captions/text/generated_text/{tag}_gen.txt')
        print('Done!')
    except EnvironmentError:
        # from_pretrained raises an EnvironmentError when no saved model exists at that path
        print(f'A model with the tag = {tag} does not exist. No captions were generated. Train a model first.')
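

# A minimal programmatic sketch of the two steps above (the tag and prompt are
# hypothetical; it assumes ./text/training_text/sunset.txt already exists):
#
#     finetune('sunset')
#     generate_captions('sunset', 'Golden hour', max_length=60, min_length=20,
#                       num_return_sequences=40)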


def main():
    parser = argparse.ArgumentParser(
        description='Tune transformer model and generate captions')
    parser.add_argument('--tag', type=str, required=True,
                        help='Hashtag whose downloaded captions are used for training')
    # fine-tuning options (more parameters could be added later)
    parser.add_argument('--train', action='store_true', default=False,
                        help='Fine-tune the model (default: False)')
    # caption-generation options
    parser.add_argument('--generate', action='store_true', default=False,
                        help='Generate captions (default: False)')
    parser.add_argument('--prompt', type=str, default='My day',
                        help='Text the model starts from when generating; 1-5 words will do (default: "My day")')
    parser.add_argument('--max-length', type=int, default=60,
                        help='Maximum length of the caption text (default: 60)')
    parser.add_argument('--min-length', type=int, default=20,
                        help='Minimum length of the caption text (default: 20)')
    parser.add_argument('--num-captions', type=int, default=40,
                        help='Number of captions to generate; duplicates are dropped (default: 40)')
    args = parser.parse_args()
    if args.train and args.generate:
        print('Training and generating captions ...')
        finetune(args.tag)
        generate_captions(args.tag, args.prompt, args.max_length,
                          args.min_length, args.num_captions)
    elif args.train:
        print('Training ...')
        finetune(args.tag)
    elif args.generate:
        print('Generating captions ...')
        generate_captions(args.tag, args.prompt, args.max_length,
                          args.min_length, args.num_captions)
    else:
        print('Please choose either --train or --generate or both')


if __name__ == '__main__':
    main()
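
# Example invocations (the tag value is hypothetical; the flags mirror the
# parser defined in main()):
#
#     python tune_transformer.py --tag sunset --train
#     python tune_transformer.py --tag sunset --generate --prompt "Golden hour" --num-captions 20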