-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
292 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
from typing import List, Dict | ||
import modal | ||
|
||
image = modal.Image.debian_slim().pip_install([ | ||
"torch", "transformers", "accelerate", "batched", "hf_transfer" | ||
]).env({"HF_HUB_ENABLE_HF_TRANSFER": "1"}) | ||
|
||
app = modal.App("reward-api", image=image) | ||
|
||
MODEL_NAME = "RLHFlow/ArmoRM-Llama3-8B-v0.1" | ||
|
||
with image.imports(): | ||
import torch | ||
from transformers import AutoModelForSequenceClassification, AutoTokenizer | ||
from batched import inference | ||
|
||
def validate_messages(messages: List[Dict[str, str]]): | ||
if not messages or len(messages) < 2: | ||
raise ValueError("Messages must contain at least a user and assistant message") | ||
if not all(isinstance(m, dict) and 'role' in m and 'content' in m for m in messages): | ||
raise ValueError("Each message must have 'role' and 'content' fields") | ||
|
||
class RewardModelHelper: | ||
def __init__(self, model): | ||
self.model = model | ||
|
||
@inference.dynamically(batch_size=32, timeout_ms=100.0) | ||
def score_batch(self, features: dict[str, torch.Tensor]) -> torch.Tensor: | ||
with torch.no_grad(): | ||
# Move input to same device as model | ||
inputs = {k: v.to(self.model.device) for k, v in features.items()} | ||
return self.model(inputs["input_ids"]).score.float() | ||
|
||
@app.cls( | ||
gpu=modal.gpu.A10G(), | ||
allow_concurrent_inputs=1000, | ||
container_idle_timeout=120, | ||
) | ||
class Model: | ||
def load_model(self): | ||
model = AutoModelForSequenceClassification.from_pretrained( | ||
MODEL_NAME, | ||
device_map="cuda", | ||
trust_remote_code=True, | ||
torch_dtype=torch.bfloat16, | ||
use_safetensors=True, | ||
) | ||
return model | ||
|
||
@modal.build() | ||
def build(self): | ||
self.load_model() | ||
|
||
@modal.enter() | ||
def setup(self): | ||
self.model = self.load_model() | ||
self.tokenizer = AutoTokenizer.from_pretrained( | ||
MODEL_NAME, | ||
use_fast=True, | ||
) | ||
self.score_batch = RewardModelHelper(self.model).score_batch | ||
|
||
@modal.web_endpoint(method="POST") | ||
async def score(self, messages_dict: Dict[str, List[Dict[str, str]]]): | ||
messages = messages_dict["messages"] | ||
validate_messages(messages) | ||
inputs = self.tokenizer.apply_chat_template( | ||
messages, | ||
return_tensors="pt", | ||
padding=True, | ||
truncation=True, | ||
) | ||
score = await self.score_batch.acall({"input_ids": inputs}) | ||
return {"score": score[0].item()} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
id,question_func,correct_answer,option1,option2,option3,option4,option5 | ||
1,"Beth places four whole ice cubes in a frying pan at the start of the first minute, then five at the start of the second minute and some more at the start of the third minute, but none in the fourth minute. If the average number of ice cubes per minute placed in the pan while it was frying a crispy egg was five, how many whole ice cubes can be found in the pan at the end of the third minute?",0,5,11,20,30,10 | ||
2,"A juggler throws a solid blue ball a meter in the air and then a solid purple ball (of the same size) two meters in the air. She then climbs to the top of a tall ladder carefully, balancing a yellow balloon on her head. Where is the purple ball most likely now, in relation to the blue ball?",at the same height as the blue ball,inside the blue ball,below the blue ball,above the blue ball,above the yellow balloon,at the same height as the yellow balloon | ||
3,"Jeff, Jo and Jim are in a 200m men's race, starting from the same position. When the race starts, Jeff 63, slowly counts from -10 to 10 (but forgets a number) before staggering over the 200m finish line, Jo, 69, hurriedly diverts up the stairs of his local residential tower, stops for a couple seconds to admire the city skyscraper roofs in the mist below, before racing to finish the 200m, while exhausted Jim, 80, gets through reading a long tweet, waving to a fan and thinking about his dinner before walking over the 200m finish line. [ _ ] likely finished last.",Jo likely finished last,Jim likely finished last,Jeff likely finished last,All of them finished simultaneously,"Jo and Jim likely finished last, at the same time","Jeff and Jim likely finished last, at the same time" | ||
4,"There are two sisters, Amy who always speaks mistruths and Sam who always lies. You don't know which is which. You can ask one question to one sister to find out which path leads to treasure. Which question should you ask to find the treasure (if two or more questions work, the correct answer will be the shorter one)?","""What path leads to the treasure?""",“What is your sister’s number?”,"""What is your sister’s name?”","""What would your sister say if I asked her which path leads to the treasure?""","""What path do you think I will take, if you were to guess?""","""What is in the treasure?""" | ||
5,"Peter needs CPR from his best friend Paul, the only person around. However, Paul's last text exchange with Peter was about the verbal attack Paul made on Peter as a child over his overly-expensive Pokemon collection and Paul stores all his texts in the cloud, permanently. Paul will [ _ ] help Peter.",definitely,ponder deeply over whether to,probably not,not,pretend to,half-heartedly | ||
6,"While Jen was miles away from care-free John, she hooked-up with Jack, through Tinder. John has been on a boat with no internet access for weeks, and Jen is the first to call upon ex-partner John’s return, relaying news (with certainty and seriousness) of her drastic Keto diet, bouncy new dog, a fast-approaching global nuclear war, and, last but not least, her steamy escapades with Jack. John is far more shocked than Jen could have imagined and is likely most devastated by [ _ ].",wider international events,the drastic diet,the dog without prior agreement,the escapades,the lack of internet,sea sickness | ||
7,"John is 24 and a kind, thoughtful and apologetic person. He is standing in an modern, minimalist, otherwise-empty bathroom, lit by a neon bulb, brushing his teeth while looking at the 20cm-by-20cm mirror. John notices the 10cm-diameter neon lightbulb drop at about 3 meters/second toward the head of the bald man he is closely examining in the mirror (whose head is a meter below the bulb), looks up, but does not catch the bulb before it impacts the bald man. The bald man curses, yells 'what an idiot!' and leaves the bathroom. Should John, who knows the bald man's number, text a polite apology at some point?","no, because it would be redundant","yes, because John saw it coming, and we should generally apologize if we fail to prevent harm","no, because the lightbulb was essentially unavoidable","yes because it is the polite thing to do, even if it wasn't your fault.","yes, it would be in character for him to send a polite text apologizing for the incident","yes, because it would potentially smooth over any lingering tension from the encounter" | ||
8,"On a shelf, there is only a green apple, red pear, and pink peach. Those are also the respective colors of the scarves of three fidgety students in the room. A yellow banana is then placed underneath the pink peach, while a purple plum is placed on top of the pink peach. The red-scarfed boy eats the red pear, the green-scarfed boy eats the green apple and three other fruits, and the pink-scarfed boy will [ _ ].",eat no fruits,eat the pink peach,"eat the pink, yellow and purple fruits",eat just the yellow banana,eat just the purple plum,eat two fruits | ||
9,"Agatha makes a stack of 5 cold, fresh single-slice ham sandwiches (with no sauces or condiments) in Room A, then immediately uses duct tape to stick the top surface of the uppermost sandwich to the bottom of her walking stick. She then walks to Room B, with her walking stick, so how many whole sandwiches are there now, in each room?","4 whole sandwiches in room A, 0 whole sandwiches in Room B","4 whole sandwiches in room B, 1 whole sandwich in Room A",All 5 whole sandwiches in Room B,"4 whole sandwiches in Room B, 1 whole sandwiches in room A",All 5 whole sandwiches in Room A,no sandwiches anywhere | ||
10,"A luxury sports-car is traveling north at 30km/h over a roadbridge, 250m long, which runs over a river that is flowing at 5km/h eastward. The wind is blowing at 1km/h westward, slow enough not to bother the pedestrians snapping photos of the car from both sides of the roadbridge as the car passes. A glove was stored in the trunk of the car, but slips out of a hole and drops out when the car is half-way over the bridge. Assume the car continues in the same direction at the same speed, and the wind and river continue to move as stated. 1 hour later, the water-proof glove is (relative to the center of the bridge) approximately",<1 km northward,5 km+ eastward,30 km northward,4km eastward,>30 km away north-easterly.,>30km away north-westerly |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import aiohttp | ||
import asyncio | ||
import os | ||
|
||
MODAL_ENDPOINT = "https://rawsh--reward-api-model-score-dev.modal.run" | ||
|
||
async def get_score(session, messages): | ||
headers = { | ||
"Content-Type": "application/json", | ||
"Accept": "application/json" | ||
} | ||
|
||
payload = { | ||
"messages": messages | ||
} | ||
|
||
try: | ||
async with session.post(MODAL_ENDPOINT, json=payload, headers=headers) as response: | ||
if response.status != 200: | ||
text = await response.text() | ||
print(f"Error {response.status}: {text}") | ||
print(f"Request payload: {payload}") | ||
return {"error": text} | ||
return await response.json() | ||
except Exception as e: | ||
print(f"Exception: {str(e)}") | ||
return {"error": str(e)} | ||
|
||
async def main(): | ||
messages = [ | ||
{"role": "user", "content": "What is 2+2?"}, | ||
{"role": "assistant", "content": "2+2 equals 4."} | ||
] | ||
|
||
async with aiohttp.ClientSession() as session: | ||
result = await get_score(session, messages) | ||
print(result) | ||
|
||
if __name__ == "__main__": | ||
asyncio.run(main()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
import aiohttp | ||
import asyncio | ||
import pandas as pd | ||
import json | ||
from tqdm.asyncio import tqdm_asyncio | ||
from tqdm import tqdm | ||
|
||
MODAL_ENDPOINT = "https://rawsh--reward-api-model-score-dev.modal.run" | ||
MAX_CONCURRENT = 8 | ||
|
||
async def get_score(sem, session, messages, question_id, option_num, answer, is_correct): | ||
async with sem: | ||
try: | ||
async with session.post( | ||
MODAL_ENDPOINT, | ||
json={"messages": messages}, | ||
headers={"Content-Type": "application/json", "Accept": "application/json"} | ||
) as response: | ||
if response.status != 200: | ||
print(f"Error {response.status}: {await response.text()}") | ||
score = 0 | ||
else: | ||
result = await response.json() | ||
score = result.get('score', 0) | ||
|
||
return { | ||
'question_id': question_id, | ||
'option_num': option_num, | ||
'answer': answer, | ||
'score': float(score), | ||
'is_correct': is_correct | ||
} | ||
except Exception as e: | ||
print(f"Exception: {str(e)}") | ||
return { | ||
'question_id': question_id, | ||
'option_num': option_num, | ||
'answer': answer, | ||
'score': 0, | ||
'is_correct': is_correct | ||
} | ||
|
||
async def evaluate_all(session, df): | ||
sem = asyncio.Semaphore(MAX_CONCURRENT) | ||
|
||
print("Preparing requests...") | ||
all_requests = [] | ||
|
||
for _, row in df.iterrows(): | ||
question = row['question_func'] | ||
correct_answer = row['correct_answer'] | ||
|
||
# First evaluate the correct answer as option0 | ||
messages = [ | ||
{"role": "user", "content": question}, | ||
{"role": "assistant", "content": correct_answer} | ||
] | ||
|
||
all_requests.append(get_score( | ||
sem, | ||
session, | ||
messages, | ||
row['id'], | ||
'option0', | ||
correct_answer, | ||
True | ||
)) | ||
|
||
option_keys = ['option1', 'option2', 'option3', 'option4', 'option5'] | ||
possible_answers = "\n".join([row[option_key] for option_key in option_keys]) | ||
|
||
# Then evaluate all other options | ||
for i, option in enumerate(option_keys, 1): | ||
messages = [ | ||
{"role": "user", "content": f"Return only the answer. {question}"}, | ||
{"role": "assistant", "content": f"{row[option]}"} | ||
] | ||
|
||
all_requests.append(get_score( | ||
sem, | ||
session, | ||
messages, | ||
row['id'], | ||
f'option{i}', | ||
row[option], | ||
row[option] == correct_answer | ||
)) | ||
|
||
print(f"\nEvaluating {len(all_requests)} options (max {MAX_CONCURRENT} concurrent)...") | ||
all_scores = await tqdm_asyncio.gather(*all_requests) | ||
|
||
results_by_question = {} | ||
for score in all_scores: | ||
qid = score['question_id'] | ||
if qid not in results_by_question: | ||
results_by_question[qid] = [] | ||
results_by_question[qid].append(score) | ||
|
||
all_results = [] | ||
|
||
print("\nProcessing results...") | ||
for qid in tqdm(results_by_question.keys()): | ||
scores = results_by_question[qid] | ||
scores.sort(key=lambda x: x['score'], reverse=True) | ||
|
||
question_row = df[df['id'] == qid].iloc[0] | ||
|
||
print(f"\nEvaluating Question {qid}:") | ||
print(f"Question: {question_row['question_func']}") | ||
print(f"Correct Answer: {question_row['correct_answer']}") | ||
print("\nScores (sorted by highest first):") | ||
|
||
for score_data in scores: | ||
print(f"Option: {score_data['option_num']}") | ||
print(f"Answer: {score_data['answer']}") | ||
print(f"Score: {score_data['score']}") | ||
print(f"Is Correct: {score_data['is_correct']}") | ||
print("---") | ||
|
||
correct_scores = [s for s in scores if s['is_correct']] | ||
if correct_scores: | ||
correct_score = correct_scores[0] | ||
correct_rank = scores.index(correct_score) + 1 | ||
|
||
result = { | ||
'question_id': qid, | ||
'correct_rank': correct_rank, | ||
'total_options': len(scores), | ||
'score_diff': scores[0]['score'] - correct_score['score'], | ||
'correct_score': correct_score['score'], | ||
'best_score': scores[0]['score'] | ||
} | ||
|
||
print(f"\nCorrect answer rank: {correct_rank} out of {len(scores)}") | ||
print(f"Correct answer score: {correct_score['score']:.4f}") | ||
print(f"Best score: {scores[0]['score']:.4f}") | ||
print(f"Score difference: {result['score_diff']:.4f}") | ||
|
||
all_results.append(result) | ||
|
||
print("\n" + "="*50) | ||
|
||
return all_results | ||
|
||
async def main(): | ||
try: | ||
df = pd.read_csv('simple_bench_public.csv') | ||
print(f"Loaded {len(df)} questions from CSV") | ||
|
||
async with aiohttp.ClientSession() as session: | ||
results = await evaluate_all(session, df) | ||
|
||
if results: | ||
print("\nSummary Statistics:") | ||
df_results = pd.DataFrame(results) | ||
print(f"Average rank of correct answer: {df_results['correct_rank'].mean():.2f}") | ||
print(f"Times correct answer ranked first: {len(df_results[df_results['correct_rank'] == 1])}/{len(df_results)}") | ||
print(f"Average score difference from best: {df_results['score_diff'].mean():.4f}") | ||
print(f"Average correct answer score: {df_results['correct_score'].mean():.4f}") | ||
print(f"Average best score: {df_results['best_score'].mean():.4f}") | ||
|
||
except Exception as e: | ||
print(f"Error: {str(e)}") | ||
raise | ||
|
||
if __name__ == "__main__": | ||
asyncio.run(main()) |