From 124f6ed43d958f8332efb447c4e35add61228b2c Mon Sep 17 00:00:00 2001 From: Robert Washbourne Date: Tue, 19 Nov 2024 06:02:44 +0000 Subject: [PATCH] m --- best_of_n.py | 368 ++++++++++++++++++++++++++++++++++++++++++++ eval_reward.py | 208 +++++++++++++++++++++++++ modal_prm_armorm.py | 5 +- 3 files changed, 579 insertions(+), 2 deletions(-) create mode 100644 best_of_n.py create mode 100644 eval_reward.py diff --git a/best_of_n.py b/best_of_n.py new file mode 100644 index 0000000..02a989e --- /dev/null +++ b/best_of_n.py @@ -0,0 +1,368 @@ +import aiohttp +import asyncio +import json +import os +import re +from tqdm.asyncio import tqdm_asyncio +from tqdm import tqdm +from typing import List, Dict, Any, Tuple +import gc +from datetime import datetime +from datasets import load_dataset + +# 1 sample +# Overall Statistics: +# Total correct answers: 58/100 +# Accuracy: 0.58 +# Average reward score: 0.0929 + +# best of 2 +# Overall Statistics: +# Total correct answers: 63/100 +# Accuracy: 0.63 +# Average reward score: 0.1148 + +# best of 4 +# Overall Statistics: +# Total correct answers: 64/100 +# Accuracy: 0.64 +# Average reward score: 0.1257 + +# best of 8 +# Overall Statistics: +# Total correct answers: 63/100 +# Accuracy: 0.63 +# Average reward score: 0.1307 + +# best of 16 +# Overall Statistics: +# Total correct answers: 70/100 +# Accuracy: 0.70 +# Average reward score: 0.1345 + +# best of 32 +# Overall Statistics: +# Total correct answers: 67/100 +# Accuracy: 0.67 +# Average reward score: 0.1380 + +# best of 64 +# Overall Statistics: +# Total correct answers: 67/100 +# Accuracy: 0.67 +# Average reward score: 0.1461 + + + +# Configuration +FIREWORKS_API_KEY = os.getenv('FIREWORKS_API_KEY') +if not FIREWORKS_API_KEY: + raise ValueError("FIREWORKS_API_KEY environment variable must be set") + +FIREWORKS_API_ENDPOINT = "https://api.fireworks.ai/inference/v1/chat/completions" +REWARD_MODEL_ENDPOINT = "https://rawsh--reward-api-model-score.modal.run" + +# Separate rate limits for different APIs +LLM_MAX_CONCURRENT = 100 +REWARD_MAX_CONCURRENT = 32 +BATCH_SIZE = 10 + +BEST_OF_N = 2 +MODEL_NAME = "accounts/fireworks/models/llama-v3p1-8b-instruct" + +# Timeout configurations +REWARD_MODEL_TIMEOUT = 20 +LLM_TIMEOUT = 10 +REWARD_MODEL_MAX_RETRIES = 3 + +class APIError(Exception): + """Custom exception for API-related errors""" + pass + +async def with_retry(func, max_retries=3, base_delay=1): + """Generic retry wrapper with exponential backoff""" + for i in range(max_retries): + try: + return await func() + except Exception as e: + if i == max_retries - 1: + raise + delay = base_delay * (2 ** i) + print(f"Attempt {i+1} failed: {str(e)}. Retrying in {delay}s...") + await asyncio.sleep(delay) + +def extract_answer(completion: str) -> Tuple[str, str]: + """Extract the final answer from the completion.""" + match = re.search(r"Answer:\s*([A-Z])", completion, re.IGNORECASE) + if match: + return completion.strip(), match.group(1).upper() + # Fallback: look for the last letter A-Z in the completion + letters = re.findall(r'[A-Z]', completion, re.IGNORECASE) + return completion.strip(), letters[-1].upper() if letters else "" + +async def get_reward_score( + reward_sem: asyncio.Semaphore, + session: aiohttp.ClientSession, + messages: List[Dict[str, str]] +) -> float: + """Get reward model score for a completion.""" + async def _get_score(): + async with reward_sem: + try: + async with session.post( + REWARD_MODEL_ENDPOINT, + json={"messages": messages}, + headers={"Content-Type": "application/json"}, + timeout=aiohttp.ClientTimeout(total=REWARD_MODEL_TIMEOUT) + ) as response: + if response.status != 200: + text = await response.text() + print(f"Error {response.status}: {text}") + raise APIError(f"Reward API returned status {response.status}") + result = await response.json() + return float(result.get('score', 0)) + except asyncio.TimeoutError: + print("Reward model request timed out") + raise + except Exception as e: + print(f"Exception in get_reward_score: {str(e)}") + raise + + try: + return await with_retry(_get_score, max_retries=REWARD_MODEL_MAX_RETRIES) + except Exception: + print("All reward score attempts failed") + return 0.0 + +async def verify_answer( + llm_sem: asyncio.Semaphore, + session: aiohttp.ClientSession, + student_answer: str, + correct_idx: int +) -> float: + """Verify if the student's answer is correct.""" + # Convert index to letter (0 -> A, 1 -> B, etc.) + correct_letter = chr(65 + correct_idx) # 65 is ASCII for 'A' + return 1.0 if student_answer.upper() == correct_letter else 0.0 + +async def get_completions( + llm_sem: asyncio.Semaphore, + reward_sem: asyncio.Semaphore, + session: aiohttp.ClientSession, + question: str, + choices: List[str], + n: int +) -> List[Tuple[str, str, float]]: + """Generate n completions and get their reward scores.""" + # Format question with options + formatted_question = f"{question}\n\nOptions:\n" + for idx, choice in enumerate(choices): + formatted_question += f"{chr(65 + idx)}) {choice}\n" + + print(f"\n[Generating {n} completions for question]") + print(f"Q: {formatted_question}") + + USER_PROMPT = ("You are an expert at truthful reasoning and you always pick the most accurate answer. " + "Think step by step and output your reasoning followed by your final answer using the following format:\n" + "Answer: X where X is one of the available letter options.\n\n") + + async def _get_completions(): + async with llm_sem: + messages = [ + {"role": "user", "content": ( + USER_PROMPT + + f"{formatted_question}" + )} + ] + + async with session.post( + FIREWORKS_API_ENDPOINT, + headers={ + "Accept": "application/json", + "Content-Type": "application/json", + "Authorization": f"Bearer {FIREWORKS_API_KEY}" + }, + json={ + "model": MODEL_NAME, + "messages": messages, + "n": n, + "temperature": 0.7, + "max_tokens": 4096, + "top_p": 1, + "top_k": 40, + "presence_penalty": 0, + "frequency_penalty": 0 + }, + timeout=aiohttp.ClientTimeout(total=LLM_TIMEOUT) + ) as response: + if response.status != 200: + text = await response.text() + raise APIError(f"OpenAI API returned status {response.status}: {text}") + + return await response.json() + + try: + result = await with_retry(_get_completions) + completions = [] + + # Get reward scores for each completion + for choice in result["choices"]: + full_completion, extracted_answer = extract_answer( + choice["message"]["content"].strip() + ) + + # Get reward score + reward_score = await get_reward_score( + reward_sem, + session, + [ + {"role": "user", "content": USER_PROMPT + formatted_question}, + {"role": "assistant", "content": full_completion} + ] + ) + + completions.append((full_completion, extracted_answer, reward_score)) + + # Log results + print("\n[Completion Results]") + for i, (_, answer, score) in enumerate(completions, 1): + print(f" {i}. {answer:<40} [reward: {score:.4f}]") + + return completions + + except Exception as e: + print(f"Exception in get_completions: {str(e)}") + return [("", "", 0.0)] * n + +async def evaluate_question( + llm_sem: asyncio.Semaphore, + reward_sem: asyncio.Semaphore, + session: aiohttp.ClientSession, + example +) -> Dict[str, Any]: + """Evaluate a single question with best-of-n completions.""" + question = example['question'] + mc1_targets = example['mc1_targets'] + choices = mc1_targets['choices'] + correct_idx = mc1_targets['labels'].index(1) # Find index where label is 1 + + # Get n completions with reasoning, extracted answers, and reward scores + completion_data = await get_completions(llm_sem, reward_sem, session, question, choices, BEST_OF_N) + completions, extracted_answers, reward_scores = zip(*completion_data) + + # Use reward scores to pick the best completion + best_idx = reward_scores.index(max(reward_scores)) + best_completion = completions[best_idx] + best_extracted = extracted_answers[best_idx] + + # Verify correctness of the best answer + correctness_score = await verify_answer( + llm_sem, session, best_extracted, correct_idx + ) + + return { + 'question': question, + 'choices': choices, + 'correct_answer': chr(65 + correct_idx), # Convert index to letter + 'completions': completions, + 'extracted_answers': extracted_answers, + 'reward_scores': reward_scores, + 'best_reward_score': reward_scores[best_idx], + 'best_completion': best_completion, + 'best_extracted_answer': best_extracted, + 'is_correct': bool(correctness_score) + } + +async def process_batch( + llm_sem: asyncio.Semaphore, + reward_sem: asyncio.Semaphore, + session: aiohttp.ClientSession, + batch_data: List[Dict] +) -> List[Dict[str, Any]]: + """Process a batch of questions.""" + batch_requests = [ + evaluate_question(llm_sem, reward_sem, session, example) + for example in batch_data + ] + return await tqdm_asyncio.gather(*batch_requests) + +async def evaluate_all(session: aiohttp.ClientSession, dataset) -> List[Dict[str, Any]]: + """Evaluate all questions in the dataset using batching.""" + llm_sem = asyncio.Semaphore(LLM_MAX_CONCURRENT) + reward_sem = asyncio.Semaphore(REWARD_MAX_CONCURRENT) + + # Convert dataset to list of dictionaries for easier processing + dataset_dicts = [ + { + 'question': item['question'], + 'mc1_targets': item['mc1_targets'] + } + for item in dataset + ] + + import random + random.seed(42) + random.shuffle(dataset_dicts) + dataset_dicts = dataset_dicts[:100] + + results = [] + print(f"\nEvaluating {len(dataset_dicts)} questions with {BEST_OF_N} completions each...") + + # Process in batches + for i in range(0, len(dataset_dicts), BATCH_SIZE): + batch_data = dataset_dicts[i:i + BATCH_SIZE] + print(f"\nProcessing batch {i//BATCH_SIZE + 1}/{(len(dataset_dicts) + BATCH_SIZE - 1)//BATCH_SIZE}") + + batch_results = await process_batch(llm_sem, reward_sem, session, batch_data) + results.extend(batch_results) + + # Periodic cleanup + gc.collect() + await asyncio.sleep(1) # Small delay between batches + + return results + +async def main(): + try: + # Load TruthfulQA dataset + dataset = load_dataset("truthful_qa", "multiple_choice") + validation_set = dataset["validation"] + print(f"Loaded {len(validation_set)} questions from TruthfulQA validation set") + + # Configure session with connection pooling + connector = aiohttp.TCPConnector( + limit=max(LLM_MAX_CONCURRENT, REWARD_MAX_CONCURRENT), + force_close=True + ) + timeout = aiohttp.ClientTimeout(total=60) + + # Create timestamp for output file + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session: + results = await evaluate_all(session, validation_set) + + if results: + print("\nOverall Statistics:") + correct_count = sum(1 for r in results if r['is_correct']) + total_count = len(results) + + print(f"Total correct answers: {correct_count}/{total_count}") + print(f"Accuracy: {correct_count/total_count:.2f}") + print(f"Average reward score: {sum(r['best_reward_score'] for r in results)/total_count:.4f}") + + # Save results + output_file = f'truthfulqa_mc_results.json' + with open(output_file, 'w') as f: + json.dump(results, f, indent=2) + print(f"\nDetailed results saved to {output_file}") + + except Exception as e: + print(f"Error in main: {str(e)}") + raise + finally: + if 'connector' in locals() and hasattr(connector, 'close'): + await connector.close() + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/eval_reward.py b/eval_reward.py new file mode 100644 index 0000000..04ae326 --- /dev/null +++ b/eval_reward.py @@ -0,0 +1,208 @@ +import aiohttp +import asyncio +import json +from tqdm.asyncio import tqdm_asyncio +from tqdm import tqdm +from datasets import load_dataset +import random +from datetime import datetime + +MODAL_ENDPOINT = "https://rawsh--reward-api-model-score.modal.run" +MAX_CONCURRENT = 32 +BATCH_SIZE = 10 + +async def get_score(sem, session, messages, question_id, option_idx, answer, is_correct): + """Get reward model score for a completion.""" + async with sem: + try: + async with session.post( + MODAL_ENDPOINT, + json={"messages": messages}, + headers={"Content-Type": "application/json"}, + timeout=aiohttp.ClientTimeout(total=20) + ) as response: + if response.status != 200: + print(f"Error {response.status}: {await response.text()}") + score = 0 + else: + result = await response.json() + score = result.get('score', 0) + + return { + 'question_id': question_id, + 'option_idx': option_idx, + 'answer': answer, + 'score': float(score), + 'is_correct': is_correct + } + except Exception as e: + print(f"Exception in get_score: {str(e)}") + return { + 'question_id': question_id, + 'option_idx': option_idx, + 'answer': answer, + 'score': 0, + 'is_correct': is_correct + } + +async def evaluate_question(sem, session, question_data, question_idx): + """Evaluate a single question with all its options.""" + question = question_data['question'] + choices = question_data['mc1_targets']['choices'] + correct_idx = question_data['mc1_targets']['labels'].index(1) + + print(f"\nEvaluating Question {question_idx}:") + print(f"Question: {question}") + print(f"Correct Answer: {chr(65 + correct_idx)}) {choices[correct_idx]}") + + all_requests = [] + + # Evaluate each option + for idx, choice in enumerate(choices): + formatted_question = f"{question}\n\nOptions:\n" + for i, c in enumerate(choices): + formatted_question += f"{chr(65 + i)}) {c}\n" + + messages = [ + {"role": "user", "content": ("You are an expert at truthful reasoning and you always pick the most accurate answer. " + "Think step by step and output your reasoning followed by your final answer.\n\n" + + formatted_question)}, + {"role": "assistant", "content": f"After careful consideration, I choose option {chr(65 + idx)}. {choice}"} + ] + + all_requests.append(get_score( + sem, + session, + messages, + question_idx, + idx, + choice, + idx == correct_idx + )) + + return await asyncio.gather(*all_requests) + +async def process_batch(sem, session, batch_data, start_idx): + """Process a batch of questions.""" + batch_requests = [ + evaluate_question(sem, session, example, idx) + for idx, example in enumerate(batch_data, start_idx) + ] + return await tqdm_asyncio.gather(*batch_requests) + +async def evaluate_all(session, dataset): + """Evaluate all questions in the dataset using batching.""" + sem = asyncio.Semaphore(MAX_CONCURRENT) + + # Convert dataset to list and take same subset as original code + dataset_list = list(dataset) + random.seed(42) # Same seed as original code + random.shuffle(dataset_list) + dataset_list = dataset_list[:100] # Same subset size as original code + + results = [] + print(f"\nEvaluating {len(dataset_list)} questions...") + + # Process in batches + for i in range(0, len(dataset_list), BATCH_SIZE): + batch_data = dataset_list[i:i + BATCH_SIZE] + print(f"\nProcessing batch {i//BATCH_SIZE + 1}/{(len(dataset_list) + BATCH_SIZE - 1)//BATCH_SIZE}") + + batch_results = await process_batch(sem, session, batch_data, i) + results.extend(batch_results) + + await asyncio.sleep(1) # Small delay between batches + + return results, dataset_list + +async def main(): + try: + # Load TruthfulQA dataset + dataset = load_dataset("truthful_qa", "multiple_choice") + validation_set = dataset["validation"] + print(f"Loaded {len(validation_set)} questions from TruthfulQA validation set") + + # Configure session + connector = aiohttp.TCPConnector(limit=MAX_CONCURRENT, force_close=True) + timeout = aiohttp.ClientTimeout(total=60) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session: + all_results, dataset_list = await evaluate_all(session, validation_set) + + if all_results: + # Process results by question + results_by_question = {} + for question_results in all_results: + for result in question_results: + qid = result['question_id'] + if qid not in results_by_question: + results_by_question[qid] = [] + results_by_question[qid].append(result) + + # Calculate statistics + total_questions = len(results_by_question) + rank_1_count = 0 + total_correct_rank = 0 + total_score_diff = 0 + total_correct_score = 0 + total_best_score = 0 + + print("\nDetailed Results:") + for qid, scores in results_by_question.items(): + # Sort by score + scores.sort(key=lambda x: x['score'], reverse=True) + + # Find correct answer details + correct_scores = [s for s in scores if s['is_correct']] + if correct_scores: + correct_score = correct_scores[0] + correct_rank = scores.index(correct_score) + 1 + + if correct_rank == 1: + rank_1_count += 1 + + total_correct_rank += correct_rank + total_score_diff += scores[0]['score'] - correct_score['score'] + total_correct_score += correct_score['score'] + total_best_score += scores[0]['score'] + + print(f"\nQuestion {qid}:") + print(f"Correct answer rank: {correct_rank} out of {len(scores)}") + print(f"Correct score: {correct_score['score']:.4f}") + print(f"Best score: {scores[0]['score']:.4f}") + print(f"Score difference: {scores[0]['score'] - correct_score['score']:.4f}") + + print("\nSummary Statistics:") + print(f"Average rank of correct answer: {total_correct_rank/total_questions:.2f}") + print(f"Times correct answer ranked first: {rank_1_count}/{total_questions}") + print(f"Average score difference from best: {total_score_diff/total_questions:.4f}") + print(f"Average correct answer score: {total_correct_score/total_questions:.4f}") + print(f"Average best score: {total_best_score/total_questions:.4f}") + + # Save results + output_file = f'truthfulqa_reward_results_{timestamp}.json' + with open(output_file, 'w') as f: + json.dump({ + 'results_by_question': results_by_question, + 'summary': { + 'total_questions': total_questions, + 'rank_1_count': rank_1_count, + 'avg_correct_rank': total_correct_rank/total_questions, + 'avg_score_diff': total_score_diff/total_questions, + 'avg_correct_score': total_correct_score/total_questions, + 'avg_best_score': total_best_score/total_questions + } + }, f, indent=2) + print(f"\nDetailed results saved to {output_file}") + + except Exception as e: + print(f"Error in main: {str(e)}") + raise + finally: + if 'connector' in locals() and hasattr(connector, 'close'): + await connector.close() + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/modal_prm_armorm.py b/modal_prm_armorm.py index d4dd5a4..31b68be 100644 --- a/modal_prm_armorm.py +++ b/modal_prm_armorm.py @@ -24,7 +24,7 @@ class RewardModelHelper: def __init__(self, model): self.model = model - @inference.dynamically(batch_size=32, timeout_ms=100.0) + @inference.dynamically(batch_size=64, timeout_ms=20.0) def score_batch(self, features: dict[str, torch.Tensor]) -> torch.Tensor: with torch.no_grad(): # Move input to same device as model @@ -34,7 +34,7 @@ def score_batch(self, features: dict[str, torch.Tensor]) -> torch.Tensor: @app.cls( gpu=modal.gpu.A10G(), allow_concurrent_inputs=1000, - container_idle_timeout=120, + container_idle_timeout=300, ) class Model: def load_model(self): @@ -69,6 +69,7 @@ async def score(self, messages_dict: Dict[str, List[Dict[str, str]]]): return_tensors="pt", padding=True, truncation=True, + tokenize=True ) score = await self.score_batch.acall({"input_ids": inputs}) return {"score": score[0].item()} \ No newline at end of file