diff --git a/eval/eval_gpt_review.py b/eval/eval_gpt_review.py
index 444b7474..1268d8e1 100644
--- a/eval/eval_gpt_review.py
+++ b/eval/eval_gpt_review.py
@@ -5,15 +5,11 @@
 import time
 import openai
-from tqdm import tqdm
 import ray
+from tqdm import tqdm
 import shortuuid
 import logging
-import numpy as np
-import os
-import openai
-openai.api_key = os.getenv("OPENAI_API_KEY")
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -21,22 +17,38 @@ MAX_API_RETRY = 1000
 REQ_TIME_GAP = 2
 
+openai.api_key = os.getenv("OPENAI_API_KEY")
+
+class PromptGenerator:
+    def __init__(self, reviewer_jsons, prompt_jsons):
+        self.reviewer_jsons = reviewer_jsons
+        self.prompt_jsons = prompt_jsons
+
+    def gen_prompt(self, cat, ques, ans1, ans2):
+        reviewer_idx = next((idx for idx, reviewer in enumerate(self.reviewer_jsons) if reviewer["category"] == cat), 0)
+        prompt_id = self.reviewer_jsons[reviewer_idx]["prompt_id"]
+        prompt_json = self.prompt_jsons[prompt_id - 1]
+        assert prompt_json["prompt_id"] == prompt_id
+
+        sys_prompt = prompt_json["system_prompt"]
+        prompt_template = prompt_json["prompt_template"]
+        defaults = prompt_json["defaults"]
+        prompt = prompt_template.format(question=ques, answer_1=ans1, answer_2=ans2, **defaults)
+
+        return sys_prompt, prompt, reviewer_idx + 1
+
 @ray.remote(num_cpus=4)
 def get_eval(sys_prompt, user_prompt: str, max_tokens: int, model: str):
-    logging.basicConfig(level=logging.INFO)
     for i in range(MAX_API_RETRY):
         try:
             response = openai.ChatCompletion.create(
                 model=model,
                 messages=[
                     {"role": "system", "content": sys_prompt},
-                    {
-                        "role": "user",
-                        "content": user_prompt,
-                    },
+                    {"role": "user", "content": user_prompt},
                 ],
-                temperature=0.2,  # TODO: figure out which temperature is best for evaluation
+                temperature=0.2,
                 max_tokens=max_tokens,
             )
             content = response["choices"][0]["message"]["content"]
@@ -44,20 +56,20 @@ def get_eval(sys_prompt, user_prompt: str, max_tokens: int, model: str):
             return content
         except Exception as e:
             logger.error(e)
-            time.sleep(min(5*(i+1), 100))
+            time.sleep(min(5 * (i + 1), 100))
     logger.error(f"Failed after {MAX_API_RETRY} retries.")
     return "error"
 
+
 def parse_three_class_score(review):
     try:
         score = int(review.strip().split("\n")[-1].strip())
         return score
     except Exception as e:
-        logger.error(
-            f"{e}\nContent: {review}\n" "You must manually fix the score pair."
-        )
+        logger.error(f"{e}\nContent: {review}\nYou must manually fix the score pair.")
         return -1
-    
+
+
 def parse_score(review):
     try:
         score_pair = review.split("\n")[0]
@@ -68,43 +80,19 @@ def parse_score(review):
         else:
             raise Exception("Invalid score pair.")
     except Exception as e:
-        logger.error(
-            f"{e}\nContent: {review}\n" "You must manually fix the score pair."
-        )
+        logger.error(f"{e}\nContent: {review}\nYou must manually fix the score pair.")
         return [-1, -1]
 
 
-def gen_prompt(reviewer_jsons, prompt_jsons, cat, ques, ans1, ans2):
-    # Default to general category (index=0)
-    reviewer_idx = 0
-    for idx, reviewer in enumerate(reviewer_jsons):
-        if reviewer["category"] == cat:
-            reviewer_idx = idx
-            break
-    prompt_id = reviewer_jsons[reviewer_idx]["prompt_id"]
-    prompt_json = prompt_jsons[prompt_id - 1]
-    assert prompt_json["prompt_id"] == prompt_id
-
-    sys_prompt = prompt_json["system_prompt"]
-    prompt_template = prompt_json["prompt_template"]
-    defaults = prompt_json["defaults"]
-    prompt = prompt_template.format(
-        question=ques, answer_1=ans1, answer_2=ans2, **defaults
-    )
-
-    return sys_prompt, prompt, reviewer_idx + 1
-
-
 def get_json_list(file_path):
     file_path = os.path.expanduser(file_path)
     with open(file_path, "r") as f:
-        json_list = []
-        for line in f:
-            json_list.append(json.loads(line))
+        json_list = [json.loads(line) for line in f]
     return json_list
 
+
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="ChatGPT-based QA evaluation.")
+    parser = argparse.ArgumentParser(description="ChatGPT-based QA evaluation.", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     parser.add_argument("-q", "--question-file")
     parser.add_argument("-a", "--answer-file-list", nargs="+", default=[])
     parser.add_argument("-p", "--prompt-file")
@@ -112,12 +100,7 @@ def get_json_list(file_path):
     parser.add_argument("-o", "--output-review-file")
     parser.add_argument("-m", "--model", default='gpt-4')
    parser.add_argument("-id", "--id-key", default='question_id')
-    parser.add_argument(
-        "--max-tokens",
-        type=int,
-        default=1024,
-        help="maximum number of tokens produced in the output",
-    )
+    parser.add_argument("--max-tokens", type=int, default=1024, help="maximum number of tokens produced in the output")
     args = parser.parse_args()
 
     if not os.path.isdir(args.output_review_file):
@@ -148,69 +131,66 @@ def get_json_list(file_path):
         key=lambda x: x[args.id_key]
     )
 
-    # check if # of questions, answers are the same
     assert len(question_jsons) == len(answer1_jsons) == len(answer2_jsons)
 
+    prompt_generator = PromptGenerator(reviewer_jsons, prompt_jsons)
     handles = []
     review_jsons = []
-    total_len = len(question_jsons)
-    question_idx_list = list(range(total_len))
-    for i in tqdm(question_idx_list):
+    for i, question in enumerate(tqdm(question_jsons)):
         assert (
             answer1_jsons[i][args.id_key]
-            == question_jsons[i][args.id_key]
+            == question[args.id_key]
            == answer2_jsons[i][args.id_key]
         )
-        ques = question_jsons[i]["text"]
-        cat = question_jsons[i]["category"]
+        ques = question["text"]
+        cat = question["category"]
 
         if 'generation_truncated' in answer1_jsons[i]:
-            ans1 = answer1_jsons[i]["generation_truncated"].split(
-                'A chat between a curious human and an artificial intelligence')[0]
+            ans1 = answer1_jsons[i]["generation_truncated"].split('A chat between a curious human and an artificial intelligence')[0]
        elif 'generation' in answer1_jsons[i]:
-            ans1 = answer1_jsons[i]["generation"].split(
-                'A chat between a curious human and an artificial intelligence')[0]
+            ans1 = answer1_jsons[i]["generation"].split('A chat between a curious human and an artificial intelligence')[0]
         else:
             ans1 = answer1_jsons[i]["text"]
-        # ans1 = answer1_jsons[i]["text"]
         if 'generation_truncated' in answer2_jsons[i]:
-            ans2 = answer2_jsons[i]["generation_truncated"].split(
-                'A chat between a curious human and an artificial intelligence')[0]
+            ans2 = answer2_jsons[i]["generation_truncated"].split('A chat between a curious human and an artificial intelligence')[0]
         elif 'generation' in answer2_jsons[i]:
-            ans2 = answer2_jsons[i]["generation"].split(
-                'A chat between a curious human and an artificial intelligence')[0]
+            ans2 = answer2_jsons[i]["generation"].split('A chat between a curious human and an artificial intelligence')[0]
         else:
             ans2 = answer2_jsons[i]["text"]
 
-        sys_prompt, prompt, reviewer_id = gen_prompt(
-            reviewer_jsons, prompt_jsons, cat, ques, ans1, ans2
-        )
+        sys_prompt, prompt, reviewer_id = prompt_generator.gen_prompt(cat, ques, ans1, ans2)
         review_id = shortuuid.uuid()
         review_jsons.append(
             {
                 "review_id": review_id,
-                args.id_key: question_jsons[i][args.id_key],
-                "answer1_id": answer1_jsons[i]["answer_id"] if 'answer_id' in answer1_jsons[i] else shortuuid.uuid(ans1),
-                "answer2_id": answer2_jsons[i]["answer_id"] if 'answer_id' in answer2_jsons[i] else shortuuid.uuid(ans2),
+                args.id_key: question[args.id_key],
+                "answer1_id": answer1_jsons[i].get("answer_id", shortuuid.uuid(ans1)),
+                "answer2_id": answer2_jsons[i].get("answer_id", shortuuid.uuid(ans2)),
                 "reviewer_id": reviewer_id,
-                "metadata": {},
+                "score": None,
+                "review": None,
             }
         )
-        # To avoid the rate limit set by OpenAI
-        handles.append(get_eval.remote(sys_prompt, prompt, args.max_tokens, args.model))
-        logger.info(
-            f"Waiting for {REQ_TIME_GAP} seconds before sending the next request."
+        handles.append(
+            get_eval.remote(sys_prompt, prompt, args.max_tokens, args.model)
         )
         time.sleep(REQ_TIME_GAP)
 
-    reviews = ray.get(handles)
-    with open(dest, "w") as output_review_file:
-        for idx, review in enumerate(reviews):
-            if 'threeclass' in args.prompt_file:
-                scores = parse_three_class_score(review)
-            else:
-                scores = parse_score(review)
-            review_jsons[idx]["text"] = review
-            review_jsons[idx]["score"] = scores
-            output_review_file.write(json.dumps(review_jsons[idx]) + "\n")
-            output_review_file.flush()
+    review_texts = ray.get(handles)
+
+    assert len(review_jsons) == len(review_texts)
+
+    for i, review_text in enumerate(review_texts):
+        review_jsons[i]["review"] = review_text.strip()
+        if 'threeclass' in args.prompt_file:
+            review_jsons[i]["score"] = parse_three_class_score(review_text)
+        else:
+            review_jsons[i]["score"] = parse_score(review_text)
+
+    with open(dest, "w") as f:
+        for review_json in review_jsons:
+            f.write(json.dumps(review_json) + "\n")
+
+    logger.info(f"Review written to {dest}")
+
+    ray.shutdown()