Commit

rawsh authored Nov 27, 2024
1 parent 3a2f4d0 commit 28a4e75
Showing 2 changed files with 208 additions and 135 deletions.
mcts/simple_sample.py: 61 additions & 35 deletions
@@ -10,10 +10,11 @@
 from collections import Counter
 from functools import wraps
 from collections import OrderedDict
+import math # Added for math.exp in evaluate_step
 
 # Configuration
-POLICY_URL = 'https://rawsh--vllm-qwen-simpo-serve.modal.run/v1/'
-PRM_URL = 'https://rawsh--mirrorqwen-prm-embedder-score-output.modal.run'
+POLICY_URL = 'https://rawsh--vllm-qwen-metamath-serve.modal.run/v1/'
+PRM_URL = 'https://rawsh--vllm-qwen-prm-serve.modal.run/v1/'
 API_KEY = '9FF74944EED19865193F979942FB1'
 BATCH_SIZE = 100 # Reduced batch size since we're doing multiple requests per question
 MAX_RETRIES = 5
@@ -22,7 +23,7 @@
 SAMPLES_PER_QUESTION = 1 # Default to single sample mode, override with CLI arg
 
 # Cache decorator for PRM scores
-def async_lru_cache(maxsize=2000):
+def async_lru_cache(maxsize=10000):
     cache = OrderedDict()
     def decorator(func):
         @wraps(func)
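
For readers skimming the diff: only the decorator's signature and first three lines appear in the hunk above. Below is a minimal sketch of how such an async LRU decorator over an OrderedDict is typically completed; the key construction and eviction logic are assumptions, not this repo's code.

from collections import OrderedDict
from functools import wraps

def async_lru_cache_sketch(maxsize=10000):
    cache = OrderedDict()
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Hypothetical cache key; the real implementation may differ.
            key = (args, tuple(sorted(kwargs.items())))
            if key in cache:
                cache.move_to_end(key)  # mark as most recently used
                return cache[key]
            result = await func(*args, **kwargs)
            cache[key] = result
            if len(cache) > maxsize:
                cache.popitem(last=False)  # evict least recently used entry
            return result
        return wrapper
    return decorator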
@@ -62,10 +63,12 @@ def update(self, any_correct: bool, best_correct: bool = None, sc_correct: bool
         self.processed += 1
         if any_correct:
             self.correct_any += 1
-        if best_correct:
-            self.correct_best += 1
-        if sc_correct:
-            self.correct_sc += 1
+        if best_correct is not None:
+            if best_correct:
+                self.correct_best += 1
+        if sc_correct is not None:
+            if sc_correct:
+                self.correct_sc += 1
         self.pbar.update(1)
         self.pbar.set_description(self.get_description())
 
@@ -93,7 +96,7 @@ async def retry_with_exponential_backoff(func, *args, **kwargs):
             delay = min(1.5 ** attempt + random.random(), 10)
             await asyncio.sleep(delay)
 
-@async_lru_cache(maxsize=1000)
+@async_lru_cache(maxsize=10000)
 async def get_prm_score(completion: str, session: aiohttp.ClientSession) -> float:
     """Get the PRM score for a completion."""
     async with session.post(PRM_URL, json={"prompt": completion}) as response:
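
The hunk above shows only the tail of retry_with_exponential_backoff. A sketch of the enclosing loop, assuming the usual pattern; only the jittered delay line is taken from the diff.

import asyncio
import random

MAX_RETRIES = 5

async def retry_with_exponential_backoff(func, *args, **kwargs):
    # Sketch: the loop and exception handling are assumed; only the
    # capped, jittered delay formula below appears in the diff.
    for attempt in range(MAX_RETRIES):
        try:
            return await func(*args, **kwargs)
        except Exception:
            if attempt == MAX_RETRIES - 1:
                raise  # out of retries; surface the error
            delay = min(1.5 ** attempt + random.random(), 10)
            await asyncio.sleep(delay)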
@@ -105,16 +108,30 @@ async def generate_completion(
     client: AsyncOpenAI,
     semaphore: Semaphore
 ) -> str:
-    """Generate a single completion."""
+    """Generate a single completion using chat-based API."""
     async with semaphore:
-        response = await client.completions.create(
-            # model="mirrorqwen2.5-0.5b-SimPO-3",
-            model="MetaMath-Qwen2.5-0.5b",
-            prompt=question,
+        messages = [
+            {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
+            {"role": "user", "content": question}
+        ]
+        response = await client.chat.completions.create(
+            timeout=TIMEOUT,
+            model="MetaMath-Qwen2.5-0.5b", # Ensure this is the correct model name
+            messages=messages,
             max_tokens=1500,
-            temperature=0.8
+            # temperature=1.2,
+            temperature=0.0,
+            stop=["<|endoftext|>", "<|im_end|>"],
+            extra_body={
+                "repetition_penalty": 1.05,
+                "top_p": 0.8,
+                "top_k": 20,
+                "frequency_penalty": 0.05,
+                "presence_penalty": 0.05,
+            }
         )
-        return response.choices[0].text.strip()
+        # print(response)
+        return response.choices[0].message.content.strip()
 
 async def evaluate_question(
     question: str,
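
Why stop=["<|endoftext|>", "<|im_end|>"]: Qwen chat models use the ChatML template, where each turn ends with <|im_end|>. A small illustration of roughly what the server renders from the messages list; the template details are an assumption about the serving config, not code from this commit.

messages = [
    {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
    {"role": "user", "content": "What is 2 + 2?"},
]
# Approximate ChatML rendering; the trailing tag cues the model to answer.
prompt = "".join(
    f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages
) + "<|im_start|>assistant\n"
print(prompt)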
@@ -136,7 +153,11 @@ async def evaluate_question(
 
     # For single sample mode, return simpler result
     if samples_per_question == 1:
-        is_correct = fr"\boxed{{{answer}}}" in completions[0]
+        # is_correct = fr"\boxed{{{answer}}}" in completions[0]
+        completion = completions[0].split("\n\n")[-1]
+        is_correct = answer in completion
+        print(completions[0].split("\n\n")[-1])
+        print(f"ANSWER: {completion} {answer} ({is_correct})")
         return {
             "question": question,
             "expected_answer": answer,
@@ -156,13 +177,15 @@
     is_correct = []
     extracted_answers = []
     for completion in completions:
-        correct = fr"\boxed{{{answer}}}" in completion
-        is_correct.append(correct)
+        # correct = fr"\boxed{{{answer}}}" in completion
+        extracted = completion.split("\n\n")[-1].split("The answer is: ")[-1].strip()
+        is_correct.append(answer in extracted)
 
         # Extract answer for self-consistency
-        if r"\boxed{" in completion:
-            extracted = completion.split(r"\boxed{")[1].split("}")[0]
-            extracted_answers.append(extracted)
+        # if r"\boxed{" in completion:
+        #     extracted = completion.split(r"\boxed{")[1].split("}")[0]
+        print(extracted)
+        extracted_answers.append(extracted)
 
     # Find best completion by PRM score
     best_idx = max(range(len(scores)), key=lambda i: scores[i])
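
To make the new matching rule concrete, a worked example on a hypothetical MetaMath-style completion; the completion text below is invented for illustration. Note that the substring check is more permissive than the old exact \boxed{} match.

completion = (
    "Natalia sold 48 clips in April.\n\n"
    "In May she sold 48 / 2 = 24 clips.\n\n"
    "In total she sold 48 + 24 = 72 clips.\nThe answer is: 72"
)
answer = "72"
# Take the last paragraph, then whatever follows "The answer is: ".
extracted = completion.split("\n\n")[-1].split("The answer is: ")[-1].strip()
print(extracted)            # -> 72
print(answer in extracted)  # -> True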
@@ -238,23 +261,25 @@ async def process_batch(
 
     return results
 
+# greedy: 57%
+
 async def main():
     import argparse
     parser = argparse.ArgumentParser()
-    parser.add_argument("--samples", type=int, default=10,
-                        help="Number of samples per question (default: 1)")
-    parser.add_argument("--num-questions", type=int, default=200,
-                        help="Number of questions to evaluate (default: 200)")
+    parser.add_argument("--samples", type=int, default=1,
+                        help="Number of samples per question (default: 1)")
+    parser.add_argument("--num-questions", type=int, default=100,
+                        help="Number of questions to evaluate (default: 200)")
     args = parser.parse_args()
 
     # Set random seed for reproducibility
-    random.seed(42)
+    random.seed(0)
 
     # Load and preprocess dataset
     gsm8k = load_dataset("openai/gsm8k", "main", split="test").shuffle(seed=42)
     questions = [(ex["question"], ex["answer"].split("\n#### ")[-1].strip())
                  for ex in gsm8k]
-    # questions = random.sample(questions, args.num_questions)
+    questions = random.sample(questions, args.num_questions)
 
     # Initialize API client and semaphore
     client = AsyncOpenAI(base_url=POLICY_URL, api_key=API_KEY)
@@ -276,13 +301,14 @@ async def main():
             )
             all_results.extend(results)
     else:
-        # Use None for session in single-sample mode
-        for i in range(0, len(questions), BATCH_SIZE):
-            batch = questions[i:i + BATCH_SIZE]
-            results = await process_batch(
-                batch, client, None, progress, semaphore, args.samples
-            )
-            all_results.extend(results)
+        # Use a dummy session since PRM is not needed in single-sample mode
+        async with aiohttp.ClientSession() as session:
+            for i in range(0, len(questions), BATCH_SIZE):
+                batch = questions[i:i + BATCH_SIZE]
+                results = await process_batch(
+                    batch, client, session, progress, semaphore, args.samples
+                )
+                all_results.extend(results)
 
     # Save results
     suffix = f"{args.samples}samples" if args.samples > 1 else "single"
@@ -294,4 +320,4 @@ async def main():
     progress.close()
 
 if __name__ == "__main__":
-    asyncio.run(main())
\ No newline at end of file
+    asyncio.run(main())