search_config.py

import json
import logging

from reasoners import SearchConfig, LanguageModel as Model
from world_model import PromptAlignWorldModel, PromptAlignState, PromptAlignAction, PromptAlignExample
from prompt import optimize_prompt
from log_format import prompt_log, output_log, info_log
from utils import parse_json_output


class PromptAlignSearchConfig(SearchConfig[PromptAlignState, PromptAlignAction, PromptAlignExample]):
def __init__(self,
optimize_model: Model,
n_actions: int = 10,
temperature: float = 0.7
):
super().__init__()
self.optimize_model = optimize_model
self.n_actions = n_actions
self.temperature = temperature
# logging
logging.info("PromptAlignSearchConfig initialized with n_actions=%d, temperature=%f", n_actions, temperature)

    def get_actions(self, state: PromptAlignState) -> tuple[bool, list[PromptAlignAction]]:
# logging
logging.info("Generating actions for the current state")
# we need current system prompt, current query, current output and current eval_dict
current_system_prompt = state[-1].system_prompt
current_query = state[-1].query
current_output = state[-1].output
current_eval_dict = state[-1].eval_dict
        if len(current_eval_dict) == 0:
            logging.info(info_log.format(info="Error in output parsing, skipping optimization"))
            return True, [current_system_prompt]
        # average the per-aspect scores
        score = 0
        for aspect in current_eval_dict:
            score += int(current_eval_dict[aspect]["score"])
        score /= len(current_eval_dict)
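        # e.g. per-aspect scores {5, 4, 5} average to 14 / 3 ≈ 4.67, which
        # triggers the >4.5 early exit below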
# first let's check whether all eval_dict scores are 5
if all([int(current_eval_dict[aspect]["score"]) == 5 for aspect in current_eval_dict]):
# skip the optimization if all scores are 5
logging.info(info_log.format(info="All scores are 5, skipping optimization"))
return True, [current_system_prompt]
        elif score > 4.5:
            # skip the optimization if the average score is above 4.5
            logging.info(info_log.format(info="Avg score is >4.5, skipping optimization."))
            return True, [current_system_prompt]
# we also need all the previous system prompts
previous_system_prompts = [sub_result.system_prompt for sub_result in state]
# but we only need the last 5
previous_system_prompts = previous_system_prompts[-5:]
# construct the prompt
prompt = optimize_prompt.replace("[CURRENT_SYSTEM_PROMPT]", current_system_prompt)\
.replace("[QUERY]", current_query)\
.replace("[OUTPUT]", current_output)\
.replace("[OUTPUT_EVALUATION]", json.dumps(current_eval_dict, indent=4))\
.replace(
"[FORMER_SYSTEM_PROMPTS]",
"\n".join(f"---Version {i+1}---\n{p}" for i, p in enumerate(previous_system_prompts[:-1])) + "\n---Current Version---\n" + previous_system_prompts[-1]
)
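        # with two prior prompts, the [FORMER_SYSTEM_PROMPTS] block renders as:
        # ---Version 1---
        # <older prompt>
        # ---Current Version---
        # <current prompt>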
# logging the prompt, use "debug" level for the prompt
logging.debug(prompt_log.format(prompt=prompt))
# generate the new system prompt
        outputs = self.optimize_model.generate(
            user_prompt=prompt,
            temperature=self.temperature,
            top_p=0.95,
            max_new_tokens=2048,
            num_return_sequences=self.n_actions,
        )
        # normalize a single-string return into a list
        if isinstance(outputs, str):
            outputs = [outputs]
        new_prompts = []
        for output in outputs:
            # parse the structured JSON payload from the raw model output
            output = parse_json_output(output)
            logging.info(output_log.format(output=json.dumps(output, indent=4)))
            # append the new prompt, restoring literal newlines
            new_prompts.append(output["new_system_prompt"].replace("\\n", "\n"))
return False, new_prompts

    def fast_reward(self, state: PromptAlignState, action: PromptAlignAction, **kwargs) -> tuple[float, dict]:
        # fast_reward provides no signal; actions are scored only by reward() below
        return 0, {}

def reward(self, state: PromptAlignState, action: PromptAlignAction, **kwargs) -> float:
# get the eval_dict directly from kwargs
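        # assumed shape, inferred from the accesses below rather than stated in the source:
        # {"helpfulness": {"score": "4"}, "harmlessness": {"score": "5"}, ...}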
eval_dict = kwargs["eval_dict"]
if len(eval_dict) == 0:
return 0
        # calculate the reward by averaging the per-aspect scores
        reward = sum(int(eval_dict[aspect]["score"]) for aspect in eval_dict) / len(eval_dict)
return reward
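

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. `DummyModel` is a
    # hypothetical stand-in for the `Model.generate` interface used above; the
    # repository itself wires this config into a reasoners search algorithm
    # alongside PromptAlignWorldModel.
    class DummyModel:
        def generate(self, user_prompt, temperature, top_p, max_new_tokens, num_return_sequences):
            return json.dumps({"new_system_prompt": "You are a helpful assistant."})

    config = PromptAlignSearchConfig(optimize_model=DummyModel(), n_actions=2)
    # reward() averages the per-aspect scores, e.g. (4 + 5) / 2 = 4.5 here
    print(config.reward(state=[], action="", eval_dict={
        "helpfulness": {"score": "4"},
        "harmlessness": {"score": "5"},
    }))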