-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprompting.py
67 lines (54 loc) · 2.16 KB
/
prompting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from OpenAIInterface import OpenAIInterface as oAI
from omegaconf import OmegaConf
import time
import json
import re
def generate_completion(prompts, conf):
    """Send every prompt through the OpenAI wrapper and return the generated texts.

    Args:
        prompts: The prompts, passed unchanged to
            ``oAI.getCompletionForAllPrompts``.
        conf: OmegaConf config with the request settings. ``conf.model``
            selects both the batching strategy and how responses are parsed.

    Returns:
        A list holding the generated text of each response returned by the
        wrapper, in the order the wrapper returned them.
    """
    print("Calling OpenAI API...")
    # perf_counter is monotonic and high-resolution, so the reported duration
    # is not distorted by system clock adjustments (unlike time.time()).
    prompt_tic = time.perf_counter()
    if conf.model in oAI.CHAT_GPT_MODEL_NAME:
        # Models matched by CHAT_GPT_MODEL_NAME: parallel requests, batches of 20.
        prompt_responses = oAI.getCompletionForAllPrompts(conf, prompts, batch_size=20, use_parallel=True)
    else:
        # All other models: sequential requests, batches of 10.
        prompt_responses = oAI.getCompletionForAllPrompts(conf, prompts, batch_size=10, use_parallel=False)
    prompt_toc = time.perf_counter()
    print("Called OpenAI API in", prompt_toc - prompt_tic, "seconds.")
    oAI.save_cache()

    # Legacy "davinci" completion responses carry the output under "text";
    # other responses carry it under message.content.
    # NOTE(review): this check differs from the CHAT_GPT_MODEL_NAME check used
    # for batching above -- confirm both agree for every model that is used.
    is_legacy_completion = "davinci" in conf.model
    return [
        response["text"] if is_legacy_completion else response["message"]["content"]
        for response in prompt_responses
    ]
def main(data_address="./prompts/zero_shot_10.json", result_address=None, limit=None):
    """Run zero-shot prompting over a prompt file and save the model outputs as JSON.

    Args:
        data_address: Path to a JSON file holding the prompts. Other prompt
            sets referenced by this script: ``./prompts/zero_shot_3.json``,
            ``./prompts/zero_shot_10_complement.json`` and
            ``./prompts/zero_shot_3_complement.json``.
        result_address: Path of the JSON file the predictions are written to.
            Defaults to ``./predictions/<input file stem>_predictions.json``
            (e.g. ``./predictions/zero_shot_10_predictions.json``), which keeps
            input and output files paired when switching prompt sets.
        limit: If given, only the first ``limit`` prompts are sent (handy for
            cheap trial runs). ``None`` sends all prompts.
    """
    from pathlib import Path  # local import: only this script entry point needs it

    if result_address is None:
        result_address = Path("./predictions") / f"{Path(data_address).stem}_predictions.json"

    with open(data_address, encoding="utf-8") as f:
        prompts = json.load(f)
    print(len(prompts))
    # Preview the first prompt; an empty prompt file must not crash with IndexError.
    if prompts:
        print(prompts[0])
    if limit is not None:
        prompts = prompts[:limit]

    oaicfg = {
        'model': "gpt-4-turbo-preview",
        'temperature': 0.0,
        'max_tokens': 1000,  # alternate setting noted in the original: 350
        'top_p': 1.0,
        'frequency_penalty': 0.0,
        'presence_penalty': 0.0,
        'stop': [],
        'logprobs': None,
        'echo': False,
    }
    conf = OmegaConf.create(oaicfg)

    # Create the output directory *before* the paid API calls, so the
    # predictions cannot be lost to a FileNotFoundError when writing them out.
    Path(result_address).parent.mkdir(parents=True, exist_ok=True)

    predictions = generate_completion(prompts, conf)

    with open(result_address, 'w', encoding="utf-8") as f:
        json.dump(predictions, f, indent=4)
# Run the prompting pipeline only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()