import json
import os

import openai
import torch
from transformers import StoppingCriteria, StoppingCriteriaList

from args import printd


class OpenAIGenerator:
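    """Generate completions via the OpenAI API.

    Targets the pre-1.0 `openai` SDK interface (openai.Completion /
    openai.ChatCompletion). Sampling parameters (temperature, top_p, n, ...)
    are passed through **kwargs; the tokenizer is only used to decode prompt
    IDs back into text.
    """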
    def __init__(self, model, tokenizer, stop_token, legacy, **kwargs):
        openai.api_key = os.environ.get('OPENAI_API_KEY')
        self.model = model
        self.tokenizer = tokenizer
        self.legacy = legacy
        self.stop_token = stop_token
        self.__dict__.update(kwargs)  # Sampling parameters arrive as kwargs

    def generate(self, input_ids, system_ids):
        args = {
            "model": self.model,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "stop": self.stop_token,
            "n": self.n,
            "frequency_penalty": self.repetition_penalty,
            "presence_penalty": self.presence_penalty,
        }
        input_str = self.tokenizer.decode(input_ids)  # Decode the prompt IDs back into text
        if self.legacy:
            args["prompt"] = input_str
            args["max_tokens"] = self.max_new_tokens  # The legacy completions endpoint does not default to unlimited tokens
        else:
            messages = []
            if len(system_ids) > 0:
                system_str = self.tokenizer.decode(system_ids)
                input_str = input_str.replace(system_str, '')  # input_str includes system_str, which must be removed
                messages.append({
                    "role": "system",
                    "content": system_str
                })
            messages.append({
                "role": "user",
                "content": input_str
            })
            args["messages"] = messages
            printd("-----------OPENAI CALL------------")  # For debugging
            printd(json.dumps(messages, indent=4))
            printd("----------------------------------")
        print(' Waiting for OpenAI result ...', end='', flush=True)
        if self.legacy:
            completion = openai.Completion.create(**args)
        else:
            completion = openai.ChatCompletion.create(**args)
        for choice in completion.choices:
            if self.legacy:
                yield choice.text
            else:
                yield choice.message.content


class HFGenerator:
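    """Generate completions locally with a Hugging Face transformers model.

    Stops generation on a user-supplied stop token (falling back to the
    model's eos token) and yields one decoded string per returned sequence,
    with the prompt and the stop token trimmed off.
    """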
    def __init__(self, model, tokenizer, stop_token, seq2seq, **kwargs):
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.model = model
        self.tokenizer = tokenizer
        self.seq2seq = seq2seq
        stop_token_id = tokenizer.encode(stop_token, add_special_tokens=False)
        if len(stop_token_id) == 1:
            self.stop_token_id = stop_token_id[0]
        else:
            print("WARNING: stop_token encodes to more than one token, using only eos_token as stop token")
            self.stop_token_id = tokenizer.eos_token_id
        self.__dict__.update(kwargs)  # Sampling/beam-search parameters arrive as kwargs

    def generate(self, input_ids, _):
        stopping_criteria = StopTokenCriteria(self.stop_token_id, self.tokenizer.eos_token_id, self.tokenizer)
        stopping_criteria_list = StoppingCriteriaList([stopping_criteria])
        inputs = {'input_ids': torch.as_tensor([input_ids]).to(self.device)}
        if self.seq2seq == "codet5p":  # Only for the larger CodeT5+ models (not fine-tuned ones)
            inputs['decoder_input_ids'] = inputs['input_ids'].clone()
        for output in self.model.generate(
            **inputs,
            temperature=self.temperature,
            top_p=self.top_p,
            min_new_tokens=0,
            max_new_tokens=self.max_new_tokens,
            do_sample=self.do_sample,
            num_return_sequences=self.n,
            num_beams=self.num_beams,
            num_beam_groups=self.num_beam_groups,
            diversity_penalty=self.diversity_penalty,
            repetition_penalty=self.repetition_penalty,
            pad_token_id=self.tokenizer.eos_token_id,
            stopping_criteria=stopping_criteria_list,
            eos_token_id=self.stop_token_id,
        ):
            output = output[-stopping_criteria.generated - 1:]  # Trim the prompt from the output
            output = output[output != self.stop_token_id]  # Remove the stop token itself
            yield self.tokenizer.decode(output, skip_special_tokens=True)


class StopTokenCriteria(StoppingCriteria):
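    """Stopping criterion that halts generation once every sequence in the
    batch has produced the stop token or the eos token, while counting the
    generation steps so the prompt can be trimmed from the output afterwards.
    """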
    def __init__(self, stop_token_id, eos_token_id, tokenizer):
        self.generated = 0  # Number of generation steps taken so far
        self.stop_token_id = stop_token_id
        self.eos_token_id = eos_token_id
        self.reached_stop_token = []  # Batch indices that have already hit a stop/eos token
        self.tokenizer = tokenizer

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        self.generated += 1
        stop = False
        for k, tokens in enumerate(input_ids):
            if k not in self.reached_stop_token and (tokens[-1] == self.stop_token_id or tokens[-1] == self.eos_token_id):
                self.reached_stop_token.append(k)
                stop = True
        if stop:
            print('S', end='', flush=True)  # This step reached a stop/eos token
            if len(self.reached_stop_token) == len(input_ids):
                self.generated -= 1  # Do not include the final stop token in the generation count (like eos)
                return True  # Stop once every sequence has hit its stop or eos token
        else:
            print('.', end='', flush=True)
        return False
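

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustration, not part of the original module):
# drives HFGenerator end to end. The model name "gpt2", the stop token, the
# prompt and all sampling parameters below are assumptions chosen for the
# demo; any causal LM whose stop token encodes to a single token works the
# same way.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed example model
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    model.to("cuda:0" if torch.cuda.is_available() else "cpu")

    generator = HFGenerator(
        model, tokenizer,
        stop_token="\n",   # "\n" is a single token in the GPT-2 vocabulary
        seq2seq=None,      # not a CodeT5+ seq2seq model
        temperature=0.8, top_p=0.95, max_new_tokens=64, do_sample=True,
        n=1, num_beams=1, num_beam_groups=1,
        diversity_penalty=0.0, repetition_penalty=1.0,
    )
    prompt_ids = tokenizer.encode("def add(a, b):")
    for completion in generator.generate(prompt_ids, []):  # second argument (system_ids) is unused here
        print(repr(completion))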