diff --git a/mcts/train_policy_sft_metamath.py b/mcts/train_policy_sft_metamath.py index a52aafa..88ce605 100644 --- a/mcts/train_policy_sft_metamath.py +++ b/mcts/train_policy_sft_metamath.py @@ -1,75 +1,114 @@ from unsloth import FastLanguageModel import torch -from trl import SFTTrainer -from transformers import TrainingArguments +import wandb +from datasets import load_dataset from unsloth import is_bfloat16_supported from unsloth import UnslothTrainer, UnslothTrainingArguments -from datasets import load_dataset from unsloth.chat_templates import get_chat_template # Constants -SEED = 42 -max_seq_length = 8192 +max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally! dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+ -load_in_4bit = False +load_in_4bit = False # Use 4bit quantization to reduce memory usage. Can be False. -first = True +first_type1 = True +first_type2 = True def format_answer(response): - global first """Extract answer from #### pattern and format response.""" + global first_type1 + global first_type2 + # Split at #### and get everything before it parts = response.split('####') if len(parts) < 2: - return None - - + # combine the last two steps + steps = parts[0].strip().split("\n") + if len(steps) > 1: + steps[-2] = steps[-2] + f"\n{steps[-1]}" + steps = steps[:-1] + sol = "\n\n".join(steps) + + if (first_type1): + print(response) + first_type1 = False + + return sol + else: + return None + solution = "\n\n".join(parts[0].strip().split("\n")) - answer = parts[1].split('The answer is:')[0].strip() + answer = parts[1].split('The answer is:') + answer = answer[0].strip() + sol = f"{solution}\nThe answer is: {answer}" - if (first): - print(solution) - print(answer) - first = False - - return f"{solution}\n\nThe final answer is: $\\boxed{{{answer}}}$" + if (first_type2): + print(response) + first_type2 = False + + return sol def train_sft(): - # Load model and tokenizer + # Load base and instruct models model, tokenizer = FastLanguageModel.from_pretrained( - model_name = "Qwen/Qwen2.5-0.5B", + model_name = "unsloth/Qwen2.5-0.5B", max_seq_length = max_seq_length, dtype = dtype, load_in_4bit = load_in_4bit, ) - # Set up chat template - tokenizer = get_chat_template( - tokenizer, - chat_template = "qwen-2.5", + model_instruct, tokenizer_instruct = FastLanguageModel.from_pretrained( + model_name = "unsloth/Qwen2.5-0.5B-Instruct", + max_seq_length = max_seq_length, + dtype = dtype, + load_in_4bit = load_in_4bit, ) - # Configure PEFT + # Transfer chat token embeddings from instruct to base model + base_embeddings = model.get_input_embeddings() + instruct_embeddings = model_instruct.get_input_embeddings() + chat_tokens = ["<|im_start|>", "<|im_end|>", "system", "assistant", "user"] + with torch.no_grad(): + for token in chat_tokens: + try: + instruct_id = tokenizer_instruct.convert_tokens_to_ids(token) + base_id = tokenizer.convert_tokens_to_ids(token) + if instruct_id != tokenizer_instruct.unk_token_id and base_id != tokenizer.unk_token_id: + base_embeddings.weight[base_id] = instruct_embeddings.weight[instruct_id].clone() + print(f"Transferred embedding for token: {token}") + else: + print(f"Warning: Token {token} not found in one of the vocabularies") + except Exception as e: + print(f"Error transferring token {token}: {str(e)}") + + # Add LoRA adapters model = FastLanguageModel.get_peft_model( model, - r = 128, + r = 128, # Choose any number > 0! 
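# --- Editor's illustrative sketch (not part of the patch) ---
# What the '####' branch of format_answer above produces, shown on a made-up
# MetaMathQA-style response. Step lines are re-joined with blank lines, which
# matches how the MCTS code later splits partial solutions on "\n\n".
def _format_answer_happy_path(response: str) -> str:
    parts = response.split("####")
    solution = "\n\n".join(parts[0].strip().split("\n"))
    answer = parts[1].split("The answer is:")[0].strip()
    return f"{solution}\nThe answer is: {answer}"

_example = "Natalia sold 48 / 2 = 24 clips in May.\nIn total she sold 48 + 24 = 72 clips.\n#### 72"
# -> "Natalia sold 48 / 2 = 24 clips in May.\n\nIn total she sold 48 + 24 = 72 clips.\nThe answer is: 72"
assert _format_answer_happy_path(_example).endswith("The answer is: 72")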
Suggested 8, 16, 32, 64, 128 target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", - "embed_tokens", "lm_head"], + "embed_tokens", "lm_head",], # Add for continual pretraining lora_alpha = 32, - lora_dropout = 0, - bias = "none", - use_gradient_checkpointing = "unsloth", + lora_dropout = 0, # Supports any, but = 0 is optimized + bias = "none", # Supports any, but = "none" is optimized + use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context random_state = 3407, - use_rslora = True, - loftq_config = None, + use_rslora = True, # We support rank stabilized LoRA + loftq_config = None, # And LoftQ + ) + + # Set up tokenizer with chat template + tokenizer = get_chat_template( + tokenizer, + chat_template = "qwen-2.5", ) + tokenizer.eos_token = "<|im_end|>" + print(tokenizer.eos_token) + print(tokenizer.pad_token) - # Load dataset - ds = load_dataset("meta-math/MetaMathQA") - train_ds = ds['train'] + # Load and process dataset + dataset = load_dataset("meta-math/MetaMathQA", split="train") - # Format prompts def formatting_prompts_func(examples): conversations = [] for query, response in zip(examples['query'], examples['response']): @@ -82,52 +121,58 @@ def formatting_prompts_func(examples): {"role": "assistant", "content": formatted_response} ] conversations.append(conversation) - - texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) + + texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in conversations] return {"text": texts} - - # <|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat is the total cost of purchasing equipment for all sixteen players on the football team, considering that each player requires a $25 jersey, a $15.20 pair of shorts, and a pair of socks priced at $6.80?<|im_end|>\n<|im_start|>assistant\nEach player requires a $25 jersey, a $15.20 pair of shorts, and a pair of socks priced at $6.80.\n\nSo the total cost for each player is $25 + $15.20 + $6.80 = $47.\n\nSince there are sixteen players on the football team, the total cost for all of them is 16 * $47 = $752.\n\nThe final answer is: $\\boxed{752}$<|im_end|>\n' - # Process dataset - formatted_dataset = train_ds.map( - formatting_prompts_func, - batched=True, - remove_columns=train_ds.column_names - ) - print(len(formatted_dataset)) - print(formatted_dataset[0]) + dataset = dataset.map(formatting_prompts_func, batched=True, remove_columns=dataset.column_names) + + # Debug tokenizer output - show examples + print("Example of tokenized output:") + print(dataset[5]["text"]) + print("\nAnother example:") + print(dataset[100]["text"]) # Configure trainer trainer = UnslothTrainer( model = model, tokenizer = tokenizer, - train_dataset = formatted_dataset, + train_dataset = dataset, dataset_text_field = "text", max_seq_length = max_seq_length, dataset_num_proc = 8, - packing = True, + args = UnslothTrainingArguments( - per_device_train_batch_size = 8, + learning_rate = 5e-5, + embedding_learning_rate = 5e-6, + per_device_train_batch_size = 8, # With gradient_accumulation_steps=8 this gives effective batch size 64 gradient_accumulation_steps = 8, - warmup_ratio = 0.1, + lr_scheduler_type = "cosine", num_train_epochs = 3, - # learning_rate = 5e-6, - # embedding_learning_rate = 5e-7, - learning_rate = 8e-6, - embedding_learning_rate = 1e-6, + warmup_ratio = 0.1, + max_seq_length = 2048, fp16 = not 
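# --- Editor's note (illustrative, not part of the patch) ---
# The batch size comment above is easy to verify: the per-optimizer-step batch is
# per_device_train_batch_size * gradient_accumulation_steps * number_of_devices.
def _effective_batch_size(per_device=8, grad_accum=8, n_devices=1):
    return per_device * grad_accum * n_devices

assert _effective_batch_size() == 64  # 8 * 8 on a single GPU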
is_bfloat16_supported(), bf16 = is_bfloat16_supported(), - logging_steps = 1, - optim = "adamw_torch_fused", + optim = "adamw_8bit", weight_decay = 0.01, - lr_scheduler_type = "cosine", + logging_steps = 1, seed = 3407, output_dir = "outputs", + report_to = "wandb", + run_name = "metamath", + hub_strategy = "every_save", + save_strategy = "steps", + save_steps = 100, + hub_model_id = "rawsh/MetaMath-Qwen2.5-0.5b" ), ) - # Print GPU stats + # Set up wandb + # wandb.login(key="YOUR_WANDB_KEY") # Replace with your key + # wandb.init(project='metamath') + + # Print initial GPU stats gpu_stats = torch.cuda.get_device_properties(0) start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3) max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3) @@ -137,11 +182,24 @@ def formatting_prompts_func(examples): # Train trainer_stats = trainer.train() - # Save model + # Show final memory and time stats + used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3) + used_memory_for_lora = round(used_memory - start_gpu_memory, 3) + used_percentage = round(used_memory/max_memory*100, 3) + lora_percentage = round(used_memory_for_lora/max_memory*100, 3) + + print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.") + print(f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.") + print(f"Peak reserved memory = {used_memory} GB.") + print(f"Peak reserved memory for training = {used_memory_for_lora} GB.") + print(f"Peak reserved memory % of max memory = {used_percentage} %.") + print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.") + + # Save model to HuggingFace Hub model.push_to_hub_merged( - "rawsh/MetaMath-Qwen2.5-0.5b", + "rawsh/MetaMath-Qwen2.5-0.5b", # Replace with your username tokenizer, - save_method = "merged_16bit" + save_method="merged_16bit", ) if __name__ == "__main__": diff --git a/mcts/tree_search_mathrm.py b/mcts/tree_search_mathrm.py new file mode 100644 index 0000000..9d5c9e0 --- /dev/null +++ b/mcts/tree_search_mathrm.py @@ -0,0 +1,715 @@ +import asyncio +import math +import aiohttp +import json +from openai import AsyncOpenAI +import time +import random # Added for jitter in retry logic +from functools import wraps +from collections import OrderedDict +from asyncio import Semaphore, TimeoutError +from datasets import load_dataset +from tqdm import tqdm +from tqdm.asyncio import tqdm as atqdm + +POLICY_MODEL_NAME = 'MetaMath-Qwen2.5-0.5b' +POLICY_URL = 'https://rawsh--vllm-qwen-metamath-serve.modal.run/v1/' +# POLICY_URL = 'https://rawsh--vllm-qwen-metamath-serve-dev.modal.run/v1/' +PRM_URL = 'https://rawsh--vllm-qwen-prm-serve.modal.run/v1/' +PRM_MODEL_NAME = 'MetaMath-Qwen2.5-0.5b-PRM' +API_KEY = '9FF74944EED19865193F979942FB1' + +# Global clients +POLICY_CLIENT = AsyncOpenAI(base_url=POLICY_URL, api_key=API_KEY) +PRM_CLIENT = AsyncOpenAI(base_url=PRM_URL, api_key=API_KEY) + +# More aggressive semaphore limits +CONCURRENT_MCTS_SEMAPHORE = Semaphore(50) +POLICY_SEMAPHORE = Semaphore(100) +PRM_SEMAPHORE = Semaphore(100) + +# More aggressive retry settings +MAX_RETRIES = 10 +TIMEOUT = 10 + + +# Cache decorator and retry function +def async_lru_cache(maxsize=2000): + cache = OrderedDict() + def decorator(func): + @wraps(func) + async def wrapper(*args, **kwargs): + key = str(args) + str(kwargs) + if key not in cache: + if len(cache) >= maxsize: + cache.popitem(last=False) + cache[key] = await func(*args, **kwargs) + return cache[key] + return wrapper + return 
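# --- Editor's note (illustrative sketch, not part of the patch) ---
# As written, async_lru_cache evicts in insertion order (FIFO): cache hits never
# reorder entries, so the "LRU" name is approximate. If true LRU behaviour matters,
# a move_to_end() on hits is enough; a hypothetical variant:
def async_lru_cache_true_lru(maxsize=2000):
    cache = OrderedDict()
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            key = str(args) + str(kwargs)
            if key in cache:
                cache.move_to_end(key)        # mark as most recently used
                return cache[key]
            if len(cache) >= maxsize:
                cache.popitem(last=False)     # drop the least recently used entry
            cache[key] = await func(*args, **kwargs)
            return cache[key]
        return wrapper
    return decorator
# Note: concurrent callers with the same key can still both miss and compute twice;
# a per-key asyncio.Lock would close that gap if the duplicated API cost ever matters.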
decorator + +async def retry_with_timeout(func, *args, **kwargs): + for attempt in range(MAX_RETRIES): + try: + return await asyncio.wait_for(func(*args, **kwargs), timeout=TIMEOUT) + except (TimeoutError, Exception) as e: + print(f"WARNING: timeout during attempt {attempt}") + if attempt == MAX_RETRIES - 1: + raise + # Faster backoff + delay = min(0.1 * (attempt + 1), 1.0) + await asyncio.sleep(delay) + + +class Node: + def __init__(self, state, parent=None): + self.state = state + self.parent = parent + self.children = {} + self.visits = 0 + self.total_value = 0 + self.prm_value = None + self.step_scores = None + + def __hash__(self): + return hash(self.state) + + def __eq__(self, other): + return isinstance(other, Node) and self.state == other.state + + async def evaluate(self): + """Evaluate this node, caching step scores for reuse by children.""" + if self.step_scores is None: + # Get parts of the solution + parts = self.state.split("\n\n") + if len(parts) < 2: + print("WARNING: len(parts) < 2") + self.step_scores = [] + self.prm_value = 1e-10 + return self.prm_value + + # # Evaluate only the new step if we can reuse parent scores + # if self.parent and self.parent.step_scores is not None: + # self.step_scores = self.parent.step_scores.copy() + # new_prefix = self.state + # new_score = await evaluate_step(new_prefix) + # self.step_scores.append(new_score) + # else: + # # Evaluate all steps for root or if parent scores not available + # self.step_scores = [] + # for i in range(2, len(parts) + 1): + # prefix = "\n\n".join(parts[:i]) + # score = await evaluate_step(prefix) + # self.step_scores.append(score) + # # Calculate average score + # self.prm_value = sum(self.step_scores) / len(self.step_scores) if self.step_scores else 1e-10 + # self.step_scores = self.parent.step_scores.copy() + new_prefix = self.state + new_score = await evaluate_step(new_prefix) + self.prm_value = new_score + + return self.prm_value + +# Progress tracking class with added metrics +class MCTSProgress: + def __init__(self, total_questions, iterations_per_question): + self.total_questions = total_questions + self.total_iterations = total_questions * iterations_per_question + self.iterations_per_question = iterations_per_question + self.completed_iterations = 0 + self.correct_sc = 0 # Self-consistency correct count + self.correct_any = 0 # Any-correct count + self.correct_best = 0 # Best PRM path correct count + self.total_actions = 0 # Global action counter + self.questions_with_terminal = 0 # Questions with at least one terminal path + self.fully_completed_questions = 0 # Questions that completed all iterations + + # Single progress bar with dynamic description + self.pbar = tqdm(total=self.total_iterations, + desc=self.get_progress_description()) + + def get_progress_description(self): + sc_pct = (self.correct_sc / max(1, self.fully_completed_questions)) * 100 + any_pct = (self.correct_any / max(1, self.fully_completed_questions)) * 100 + best_pct = (self.correct_best / max(1, self.fully_completed_questions)) * 100 + q_pct = (self.questions_with_terminal / self.total_questions) * 100 + return (f"#Q ({self.questions_with_terminal}/{self.total_questions}): {q_pct:.0f}% | " + f"SC: {sc_pct:.1f}% | " + f"ANY: {any_pct:.1f}% | " + f"BEST: {best_pct:.1f}% | " + f"Actions: {self.total_actions}") + + def increment_iteration(self): + self.completed_iterations += 1 + self.pbar.update(1) + # No need to update description here + + def complete_question(self, is_sc_correct, is_any_correct, is_best_correct, 
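# --- Editor's note (illustrative, not part of the patch) ---
# The three accuracies tracked by MCTSProgress: SC (majority vote over terminal
# answers), ANY (at least one terminal answer correct), BEST (the terminal path with
# the highest PRM score is correct). The majority-vote step on toy data:
_votes = {"The answer is: 72": 3, "The answer is: 70": 1}
_majority_answer = max(_votes.items(), key=lambda kv: kv[1])[0]
assert _majority_answer == "The answer is: 72"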
is_fully_completed, has_terminal_nodes): + if has_terminal_nodes: + self.questions_with_terminal += 1 + if is_fully_completed: + self.fully_completed_questions += 1 + if is_sc_correct: + self.correct_sc += 1 + if is_any_correct: + self.correct_any += 1 + if is_best_correct: + self.correct_best += 1 + self.pbar.set_description(self.get_progress_description()) + + def close(self): + # Print final statistics + if self.fully_completed_questions > 0: + sc_pct = (self.correct_sc / self.fully_completed_questions) * 100 + any_pct = (self.correct_any / self.fully_completed_questions) * 100 + best_pct = (self.correct_best / self.fully_completed_questions) * 100 + print(f"\nFinal Results:") + print(f"Questions with Terminal Paths: {self.questions_with_terminal}") + print(f"Fully Completed Questions: {self.fully_completed_questions}") + print(f"Self-Consistency Accuracy: {sc_pct:.2f}% ({self.correct_sc}/{self.fully_completed_questions})") + print(f"Any-Correct Accuracy: {any_pct:.2f}% ({self.correct_any}/{self.fully_completed_questions})") + print(f"Best-Path Accuracy: {best_pct:.2f}% ({self.correct_best}/{self.fully_completed_questions})") + print(f"Total Actions Taken: {self.total_actions}") + self.pbar.close() + +def select(node): + while node.children: + if len(node.children) < len(get_possible_actions(node.state)): + return node + node = best_uct_child(node) + return node + +def best_uct_child(node): + C = 1.41 + return max( + node.children.values(), + key=lambda child: (child.total_value / child.visits) + C * math.sqrt(math.log(node.visits) / child.visits) + ) + +async def expand(node, client, session, progress_tracker): + action = await retry_with_timeout(get_next_action, node.state, client) + new_state = apply_action(node.state, action) + child = Node(new_state, parent=node) + node.children[action] = child + progress_tracker.total_actions += 1 + return child + +async def simulate(node, correct_answer, client, session, terminal_nodes, progress_tracker): + current_node = node + depth = 0 + max_depth = 10 + + while depth < max_depth: + if current_node in terminal_nodes: + break + + action, is_term = await retry_with_timeout(get_next_action, current_node.state, client) + new_state = apply_action(current_node.state, action) + child_node = Node(new_state, parent=current_node) + progress_tracker.total_actions += 1 + + if is_term or is_correct(new_state, correct_answer): + terminal_nodes.add(child_node) + current_node = child_node + break + + current_node = child_node + depth += 1 + + return await retry_with_timeout(evaluate_state, current_node.state, session) + + +def backpropagate(node, value): + while node: + node.visits += 1 + node.total_value += value + node = node.parent + +async def get_next_action(state, client): + steps = state.split("\n\n") + question = steps[0] + answer = "\n\n".join(steps[1:]) if len(steps) > 1 else None + + messages = [ + {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. 
You are a helpful assistant."}, + {"role": "user", "content": question} + ] + partial_answer = "" + if answer: + messages.append({"role": "assistant", "content": answer + "\n\n"}) + partial_answer = f"{answer}\n\n" + else: + messages.append({"role": "assistant", "content": ""}) + + + # WITH prefill + # Final Statistics: + # Total Questions: 100 + # Self-Consistency Accuracy: 61.00% + # Any-Correct Accuracy: 90.00% + # Best: 51.0% + # ------ + # Final Statistics: + # Total Questions: 100 + # Self-Consistency Accuracy: 58.00% + # Any-Correct Accuracy: 82.00% + # Best: 64% + # ------ + response = await client.chat.completions.create( + model=POLICY_MODEL_NAME, + messages=messages, + max_tokens=150, + stop=["<|endoftext|>", "<|im_end|>", "\n\n"], + # temperature=0.7, + temperature=0.7, + # top_p=0.8, + extra_body={ + "repetition_penalty": 1.05, + "top_p": 0.8, + "top_k": 20, + "frequency_penalty": 0.1, + "presence_penalty": 0.1, + # "add_generation_prompt": True, + } + ) + content = response.choices[0].message.content.strip() + + # # Final Statistics: + # # Total Questions: 100 + # # Self-Consistency Accuracy: 58.00% + # # Any-Correct Accuracy: 89.00% + # # Best: 44% + # # ------ + # response = await client.completions.create( + # model=POLICY_MODEL_NAME, + # prompt=f"<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n{question.strip()}<|im_end|>\n<|im_start|>assistant\n{partial_answer}", + # max_tokens=150, + # stop=["<|endoftext|>", "<|im_end|>", "\n\n", + # # "()()()", "rawrawraw", "raw()raw()raw()", + # # "rushing rushing rushing", "rawword", "**********", + # # "() () ()" + # ], + # temperature=0.7, + # extra_body={ + # "repetition_penalty": 1.05, + # # "repetition_penalty": 1.2, + # "top_p": 0.8, + # "top_k": 20, + # # "frequency_penalty": 1.05, + # # "presence_penalty": 1.05, + # # "frequency_penalty": 0.2, + # # "presence_penalty": 0.2, + # } + # ) + # content = response.choices[0].text.strip() + + + # Determine if the assistant has stopped generating due to the stop sequence + is_term = (response.choices[0].finish_reason == 'stop' and \ + response.choices[0].stop_reason != '\n\n') + + # print(content, is_term) + return content, is_term + + +def is_correct(state, correct_answer): + last_step = state.split("\n\n")[-1] + # Normalize the strings for comparison + return correct_answer.strip() in last_step.strip() + + + +# Create single global client +PRM_CLIENT = AsyncOpenAI(base_url=PRM_URL, api_key=API_KEY) + +# Cache for step scores +step_scores_cache = {} + +# q +# 1 +# 2 + +# @async_lru_cache(maxsize=10000) +# async def evaluate_step(step_prefix: str) -> float: +# """Evaluate a single solution step using PRM.""" +# steps = step_prefix.split("\n\n") +# question = steps[0] +# curr_step = steps[-1] + +# # Format messages for just this step evaluation +# if len(steps) == 2: +# messages = [ +# # {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."}, +# {"role": "user", "content": f"{question} Step 1: {curr_step}"} + +# ] +# else: +# messages = [ +# # {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. 
You are a helpful assistant."}, +# {"role": "user", "content": f"{question} Step 1: {steps[1]}"} +# ] +# for i, step in enumerate(steps[2:-1], start=2): +# messages.extend([ +# {"role": "assistant", "content": "+"}, +# {"role": "user", "content": f"Step {i}: {step}"} +# ]) +# curr_step_num = len(steps)-1 +# messages.extend([ +# {"role": "assistant", "content": "+"}, +# {"role": "user", "content": f"Step {curr_step_num}: {curr_step}"} +# ]) + +# # messages.append({"role": "assistant", "content": ""}) +# # print(messages) + +# async with PRM_SEMAPHORE: +# response = await PRM_CLIENT.chat.completions.create( +# model=PRM_MODEL_NAME, +# messages=messages, +# max_tokens=1, +# temperature=0.0, +# logprobs=True, +# top_logprobs=20, +# extra_body={ +# "repetition_penalty": 1.05, +# "top_p": 0.8, +# "top_k": 20, +# "frequency_penalty": 0.1, +# "presence_penalty": 0.1, +# "add_generation_prompt": True, +# } +# ) + +# logprobs = response.choices[0].logprobs.content[0].top_logprobs +# # Get raw probabilities, defaulting to very small number if token not found +# prob_plus = next((math.exp(lp.logprob) for lp in logprobs if lp.token == "+"), 1e-10) +# # prob_minus = next((math.exp(lp.logprob) for lp in logprobs if lp.token == "-"), 1e-10) + +# # Normalize between + and - +# # final_prob = prob_plus / (prob_plus + prob_minus) if (prob_plus + prob_minus) > 0 else 1e-10 +# final_prob = prob_plus +# return final_prob + +@async_lru_cache(maxsize=10000) +async def evaluate_step(step_prefix: str) -> float: + """Evaluate a single solution step using PRM.""" + steps = step_prefix.split("\n\n") + question = steps[0] + curr_step = steps[-1] + + # Format messages for just this step evaluation + if len(steps) == 2: + messages = [ + # {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."}, + {"role": "user", "content": f"{question} Step 1: {curr_step}"} + ] + else: + messages = [ + # {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. 
You are a helpful assistant."}, + {"role": "user", "content": f"{question} Step 1: {steps[1]}"} + ] + for i, step in enumerate(steps[2:-1], start=2): + messages.extend([ + {"role": "assistant", "content": "+"}, + {"role": "user", "content": f"Step {i}: {step}"} + ]) + curr_step_num = len(steps)-1 + messages.extend([ + {"role": "assistant", "content": "+"}, + {"role": "user", "content": f"Step {curr_step_num}: {curr_step}"} + ]) + + # messages.append({"role": "assistant", "content": ""}) + # print(messages) + + async with PRM_SEMAPHORE: + response = await PRM_CLIENT.chat.completions.create( + model=PRM_MODEL_NAME, + messages=messages, + max_tokens=1, + temperature=0.0, + logprobs=True, + top_logprobs=20, + extra_body={ + "repetition_penalty": 1.05, + "top_p": 0.8, + "top_k": 20, + "frequency_penalty": 0.1, + "presence_penalty": 0.1, + "add_generation_prompt": True, + } + ) + + logprobs = response.choices[0].logprobs.content[0].top_logprobs + # Get raw probabilities, defaulting to very small number if token not found + prob_plus = next((math.exp(lp.logprob) for lp in logprobs if lp.token == "+"), 1e-10) + # prob_minus = next((math.exp(lp.logprob) for lp in logprobs if lp.token == "-"), 1e-10) + + # Normalize between + and - + # final_prob = prob_plus / (prob_plus + prob_minus) if (prob_plus + prob_minus) > 0 else 1e-10 + final_prob = prob_plus + return final_prob + +async def evaluate_state(state, session): + """Simplified evaluate_state that creates a temporary node for evaluation.""" + node = Node(state) + score = await node.evaluate() + return score + +def apply_action(state, action): + return f"{state}\n\n{action}" + +def get_possible_actions(state): + return range(3) + +def format_state_for_policy(state): + return f"{state}\n\n" + +def format_state_for_prm(state): + return state + +def collect_leaf_nodes(node, leaf_nodes): + if not node.children: + leaf_nodes.append(node) + else: + for child in node.children.values(): + collect_leaf_nodes(child, leaf_nodes) + +async def find_best_leaf_by_prm(node, session): + """Modified to use cached node evaluations.""" + leaf_nodes = [] + collect_leaf_nodes(node, leaf_nodes) + + # Evaluate all leaves in parallel + await asyncio.gather(*( + leaf.evaluate() + for leaf in leaf_nodes + if leaf.prm_value is None + )) + + return max(leaf_nodes, key=lambda leaf: leaf.prm_value if leaf.prm_value is not None else float('-inf')) + +async def evaluate_and_store_prm(node, session): + node.prm_value = await retry_with_timeout(evaluate_state, node.state, session) + +async def mcts(root_state, correct_answer, num_iterations, session, progress_tracker): + root = Node(root_state) + client = AsyncOpenAI(base_url=POLICY_URL, api_key=API_KEY) + terminal_nodes = set() + + for i in range(num_iterations): + leaf = select(root) + + if leaf.state in terminal_nodes: + continue + + action, is_term = await retry_with_timeout(get_next_action, leaf.state, client) + new_state = apply_action(leaf.state, action) + child = Node(new_state, parent=leaf) + leaf.children[action] = child + progress_tracker.total_actions += 1 + + # Check if the last step contains the correct answer + if is_term or is_correct(new_state, correct_answer): + terminal_nodes.add(child) + value = await retry_with_timeout(evaluate_state, child.state, session) + backpropagate(child, value) + else: + value = await retry_with_timeout( + simulate, child, correct_answer, client, session, terminal_nodes, progress_tracker + ) + backpropagate(child, value) + + progress_tracker.increment_iteration() + + return root, 
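# --- Editor's sketch (not part of the patch): the UCT rule behind select() ---
# best_uct_child() ranks children by mean value plus an exploration bonus:
#   UCT(child) = total_value / visits + C * sqrt(ln(parent.visits) / visits), C = 1.41
# Toy numbers showing why a rarely visited child can win the selection:
import math as _math
def _uct(total_value, visits, parent_visits, c=1.41):
    return total_value / visits + c * _math.sqrt(_math.log(parent_visits) / visits)

_exploit = _uct(total_value=4.0, visits=8, parent_visits=10)  # ~0.50 + 0.76 = 1.26
_explore = _uct(total_value=0.4, visits=1, parent_visits=10)  # ~0.40 + 2.14 = 2.54
assert _explore > _exploit  # the bonus steers the search toward the unexplored branch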
terminal_nodes + + + +async def run_mcts(initial_state, correct_answer, num_iterations, session, progress_tracker): + async with CONCURRENT_MCTS_SEMAPHORE: + start_time = time.time() + root, terminal_nodes = await mcts(initial_state, correct_answer, num_iterations, session, progress_tracker) + end_time = time.time() + + best_leaf = await find_best_leaf_by_prm(root, session) + + terminal_paths = [] + answers = {} + max_prm_score = float('-inf') + best_prm_path_correct = False + terminal_correct_count = 0 + + for node in terminal_nodes: + await node.evaluate() + is_node_correct = is_correct(node.state, correct_answer) + if is_node_correct: + terminal_correct_count += 1 + + last_step = node.state.split("\n\n")[-1] + answer = last_step.strip() + answers[answer] = answers.get(answer, 0) + 1 + + if node.prm_value > max_prm_score: + max_prm_score = node.prm_value + best_prm_path_correct = is_node_correct + + terminal_paths.append({ + "final_state": node.state, + "score": node.prm_value, + "correct": is_node_correct + }) + + is_best_correct = is_correct(best_leaf.state, correct_answer) + + # Determine self-consistency correctness + is_sc_correct = False + if answers: + most_common_answer = max(answers.items(), key=lambda x: x[1])[0] + is_sc_correct = correct_answer.strip() in most_common_answer + + is_any_correct = terminal_correct_count > 0 + is_fully_completed = num_iterations == progress_tracker.iterations_per_question + + result = { + "question": initial_state, + "correct_answer": correct_answer, + "statistics": { + "num_iterations": num_iterations, + "execution_time": end_time - start_time, + "total_terminal_nodes": len(terminal_nodes), + "correct_terminal_nodes": terminal_correct_count, + "self_consistency_correct": is_sc_correct, + "any_correct": is_any_correct, + "has_terminal_nodes": len(terminal_nodes) > 0, + "best_prm_path_correct": best_prm_path_correct, + "fully_completed": is_fully_completed + }, + "best_path": { + "final_state": best_leaf.state, + "score": best_leaf.prm_value, + "correct": is_best_correct + }, + "terminal_paths": terminal_paths + } + + progress_tracker.complete_question( + is_sc_correct, is_any_correct, best_prm_path_correct, is_fully_completed, len(terminal_nodes) > 0 + ) + return result + + +# 100 questions 10 iter (+/-) +# Final Statistics: +# Total Questions: 100 +# Self-Consistency Accuracy: 67.00% +# Any-Correct Accuracy: 87.00% +# Best: 64% + +# 100 questions 20 iter (+/-) +# Final Statistics: +# Total Questions: 100 +# Self-Consistency Accuracy: 74.00% +# Any-Correct Accuracy: 95.00% +# Best: 72.0% + +# 100 questions 30 iter (+/-) +# Final Statistics: +# Total Questions: 100 +# Self-Consistency Accuracy: 72.00% +# Any-Correct Accuracy: 93.00% +# Best: 69% + +async def main(): + # Set random seed for reproducibility + random.seed(0) + + def process(example): + example["answer"] = example["answer"].split("\n#### ")[-1].strip() + return example + + gsm8k = load_dataset("openai/gsm8k", "main", split="test").shuffle(seed=42) + gsm8k = gsm8k.map(process, num_proc=24) + initial_states = [(example["question"], example["answer"]) for example in gsm8k] + initial_states = random.sample(initial_states, 100) + num_iterations = 20 + + print("cold starting policy vllm + prm api") + + # warm up the chat API + client = AsyncOpenAI(base_url=POLICY_URL, api_key=API_KEY) + + async with aiohttp.ClientSession() as session: + # First warm up vLLM API + completion_promise = client.chat.completions.create( + model=POLICY_MODEL_NAME, + messages=[ + {"role": "system", "content": "You 
are Qwen, created by Alibaba Cloud. You are a helpful assistant."}, + # {"role": "user", "content": "Which is bigger, 9.11 or 9.9?"} + {"role": "user", "content": "What is 5+45+4=?"} + ], + # prompt="<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhich is larger 9.11 or 9.9? Respond with just the answer.<|im_end|>\n<|im_start|>assistant\n", # <|im_end|>\n", + stop=["<|im_end|>"], + # eos_token="<|im_end|>", + temperature=0.3, + max_tokens=200, + ) + + if False: + completion = await completion_promise + print(completion) + assert(len(completion.choices) == 1) + print("warmed up vllm") + return + + # Then warm up PRM api + prm_client = AsyncOpenAI(base_url=PRM_URL, api_key=API_KEY) + prm_response = await prm_client.chat.completions.create( + model=PRM_MODEL_NAME, + messages=[ + {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."}, + {"role": "user", "content": "1+1=2"}, + {"role": "assistant", "content": "+"}, + {"role": "user", "content": "Next, 2+2=4"} + ], + max_tokens=1, + temperature=0.0, + logprobs=True, + top_logprobs=20 + ) + assert(len(prm_response.choices) == 1) + print("warmed up PRM api") + + completion = await completion_promise + assert(len(completion.choices) == 1) + print("warmed up vllm") + + # Initialize progress tracker + progress_tracker = MCTSProgress(len(initial_states), num_iterations) + + tasks = [] + for state, answer in initial_states: + tasks.append(run_mcts(state, answer, num_iterations, session, progress_tracker)) + + results = await asyncio.gather(*tasks) + + progress_tracker.close() + + # Calculate and print final statistics + total_questions = len(results) + sc_correct = sum(1 for r in results if r["statistics"]["self_consistency_correct"]) + any_correct = sum(1 for r in results if r["statistics"]["any_correct"]) + + print(f"\nFinal Statistics:") + print(f"Total Questions: {total_questions}") + print(f"Self-Consistency Accuracy: {(sc_correct/total_questions)*100:.2f}%") + print(f"Any-Correct Accuracy: {(any_correct/total_questions)*100:.2f}%") + + # Write results + with open("mcts_results.jsonl", "w") as f: + for result in results: + f.write(json.dumps(result) + "\n") + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/modal_vllm_chat.py b/modal_vllm_chat.py index ddb4e91..f8e8044 100644 --- a/modal_vllm_chat.py +++ b/modal_vllm_chat.py @@ -18,7 +18,7 @@ def download_model_to_image(model_dir, model_name, model_revision): MODEL_DIR = "/qwen" MODEL_NAME = "rawsh/MetaMath-Qwen2.5-0.5b" -MODEL_REVISION = "a1a6e9afd500586ce620efa67e278a8dd3ac575e" +MODEL_REVISION = "779b469ef1bb4ef8faac05e46b94c09d38112194" vllm_image = ( modal.Image.debian_slim(python_version="3.10") @@ -71,7 +71,7 @@ async def lifespan(app): @app.function( image=vllm_image, gpu=modal.gpu.A10G(count=N_GPU), - container_idle_timeout=2 * MINUTES, + container_idle_timeout=5 * MINUTES, timeout=20 * MINUTES, allow_concurrent_inputs=1000, secrets=[modal.Secret.from_name("vllm-token")] @@ -148,12 +148,21 @@ async def setup_engine(): ] # Qwen chat template with exact formatting - TEMPLATE = """<|im_start|>system -You are Qwen, created by Alibaba Cloud. 
You are a helpful assistant.<|im_end|> -{% for message in messages %}<|im_start|>{{ message['role'] }} -{{ message['content'] }}<|im_end|> -{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant -{% endif %}""" +# TEMPLATE = """{%- for message in messages %} +# {{- '<|im_start|>' + message.role + '\n' + message.content.strip() + '\n<|im_end|>\n' }} +# {%- endfor %} +# {%- if add_generation_prompt %} +# {{- '<|im_start|>assistant\n' }} +# {%- endif %}""" +#NICEE +# TEMPLATE = """{%- for message in messages %} +# {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' }} +# {%- endfor %} +# <|im_start|>assistant +# """ + TEMPLATE = """{%- for message in messages %}{%- set content = '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' %}{%- if loop.last and message.role == 'assistant' %}{%- set content = '<|im_start|>' + message.role + '\n' + message.content %}{%- endif %}{{- content }}{%- endfor %}""" +# TEMPLATE = """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %} +# {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}""" # Set up completion endpoint api_server.completion = lambda s: OpenAIServingCompletion( @@ -174,7 +183,7 @@ async def setup_engine(): prompt_adapters=[], request_logger=request_logger, response_role="assistant", - chat_template=TEMPLATE + chat_template=TEMPLATE, ) return web_app \ No newline at end of file diff --git a/modal_vllm_prm.py b/modal_vllm_prm.py new file mode 100644 index 0000000..d9c027f --- /dev/null +++ b/modal_vllm_prm.py @@ -0,0 +1,182 @@ +import modal +import asyncio +from contextlib import asynccontextmanager + +def download_model_to_image(model_dir, model_name, model_revision): + import os + from huggingface_hub import snapshot_download + from transformers.utils import move_cache + + os.makedirs(model_dir, exist_ok=True) + snapshot_download( + model_name, + revision=model_revision, + local_dir=model_dir, + ignore_patterns=["*.pt", "*.bin"], # Using safetensors + ) + move_cache() + +MODEL_DIR = "/qwen" +MODEL_NAME = "rawsh/MetaMath-Qwen2.5-0.5b-PRM" +MODEL_REVISION = "d230f00aa86b0967a4ee474df3c1f616f7ee7c57" + +vllm_image = ( + modal.Image.debian_slim(python_version="3.10") + .pip_install( + "vllm==0.6.2", + "torch==2.4.0", + "transformers>=4.45", + "ray==2.36.0", + "hf-transfer==0.1.8", + "huggingface_hub==0.25.0", + ) + .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"}) + .run_function( + download_model_to_image, + timeout=60 * 20, + secrets=[modal.Secret.from_name("hf-token")], + kwargs={ + "model_dir": MODEL_DIR, + "model_name": MODEL_NAME, + "model_revision": MODEL_REVISION, + }, + ) + .env({"VLLM_ALLOW_LONG_MAX_MODEL_LEN": "1"}) +) + +app = modal.App("vllm-qwen-prm") + +N_GPU = 1 +MINUTES = 60 +HOURS = 60 * MINUTES + +async def get_model_config(engine): + try: + return await engine.get_model_config() + except Exception as e: + print(f"Error getting model config: {e}") + raise + +@asynccontextmanager +async def lifespan(app): + try: + await asyncio.sleep(0) + yield + finally: + tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] + for task in tasks: + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + +@app.function( + image=vllm_image, + gpu=modal.gpu.A10G(count=N_GPU), + container_idle_timeout=5 * MINUTES, + timeout=20 * MINUTES, + 
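# --- Editor's sketch (not part of the patch): what the new chat template changes ---
# The replacement TEMPLATE above renders every turn as <|im_start|>role\ncontent<|im_end|>\n,
# except a trailing assistant turn, which is left open (no <|im_end|>). That is what lets
# the tree-search client prefill a partial solution and have the policy continue it.
# Quick check with jinja2 (an extra dependency, used here only for illustration):
from jinja2 import Template
_tmpl = Template(
    "{%- for message in messages %}"
    "{%- set content = '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' %}"
    "{%- if loop.last and message.role == 'assistant' %}"
    "{%- set content = '<|im_start|>' + message.role + '\n' + message.content %}"
    "{%- endif %}{{- content }}{%- endfor %}"
)
_rendered = _tmpl.render(messages=[
    {"role": "user", "content": "What is 2+3?"},
    {"role": "assistant", "content": "Step 1: 2+3 = 5\n\n"},  # partial answer to continue
])
assert _rendered.endswith("Step 1: 2+3 = 5\n\n")  # no closing <|im_end|> on the open turn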
allow_concurrent_inputs=1000, + secrets=[modal.Secret.from_name("vllm-token")] +) +@modal.asgi_app() +def serve(): + import os + import fastapi + import vllm.entrypoints.openai.api_server as api_server + from vllm.engine.arg_utils import AsyncEngineArgs + from vllm.engine.async_llm_engine import AsyncLLMEngine + from vllm.entrypoints.logger import RequestLogger + from vllm.entrypoints.openai.serving_chat import OpenAIServingChat + from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion + from vllm.entrypoints.openai.serving_engine import BaseModelPath + from vllm.usage.usage_lib import UsageContext + from transformers import AutoTokenizer + + web_app = fastapi.FastAPI( + title=f"OpenAI-compatible {MODEL_NAME} server", + description="Run an OpenAI-compatible LLM server with vLLM on modal.com", + version="0.0.1", + docs_url="/docs", + lifespan=lifespan + ) + + http_bearer = fastapi.security.HTTPBearer( + scheme_name="Bearer Token", + description="See code for authentication details.", + ) + web_app.add_middleware( + fastapi.middleware.cors.CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + TOKEN = os.environ["API_TOKEN"] + async def is_authenticated(api_key: str = fastapi.Security(http_bearer)): + if api_key.credentials != TOKEN: + raise fastapi.HTTPException( + status_code=fastapi.status.HTTP_401_UNAUTHORIZED, + detail="Invalid authentication credentials", + ) + return {"username": "authenticated_user"} + + router = fastapi.APIRouter(dependencies=[fastapi.Depends(is_authenticated)]) + router.include_router(api_server.router) + web_app.include_router(router) + + engine_args = AsyncEngineArgs( + model=MODEL_DIR, + tensor_parallel_size=N_GPU, + gpu_memory_utilization=0.90, + max_model_len=8096, + enforce_eager=False, + enable_prefix_caching=True + ) + + engine = AsyncLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.OPENAI_API_SERVER + ) + + async def setup_engine(): + model_config = await get_model_config(engine) + return model_config + + model_config = asyncio.run(setup_engine()) + request_logger = RequestLogger(max_log_len=2048) + + base_model_paths = [ + BaseModelPath(name=MODEL_NAME.split("/")[1], model_path=MODEL_NAME) + ] + + # Qwen chat template with exact formatting + TEMPLATE = """<|im_start|>system +You are Qwen, created by Alibaba Cloud. 
You are a helpful assistant.<|im_end|> +{% for message in messages %}<|im_start|>{{ message['role'] }} +{{ message['content'] }}<|im_end|> +{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant +{% endif %}""" + + # TEMPLATE = """{%- for message in messages %}{%- set content = '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' %}{%- if loop.last and message.role == 'assistant' %}{%- set content = '<|im_start|>' + message.role + '\n' + message.content %}{%- endif %}{{- content }}{%- endfor %}""" + + # Set up completion endpoint + api_server.completion = lambda s: OpenAIServingCompletion( + engine, + model_config=model_config, + base_model_paths=base_model_paths, + lora_modules=[], + prompt_adapters=[], + request_logger=request_logger, + ) + + # Set up chat endpoint with tokenizer's chat template + api_server.chat = lambda s: OpenAIServingChat( + engine, + model_config=model_config, + base_model_paths=base_model_paths, + lora_modules=[], + prompt_adapters=[], + request_logger=request_logger, + response_role="assistant", + chat_template=TEMPLATE + ) + + return web_app \ No newline at end of file diff --git a/test_vllm_prm.py b/test_vllm_prm.py new file mode 100644 index 0000000..b37ca59 --- /dev/null +++ b/test_vllm_prm.py @@ -0,0 +1,117 @@ +from openai import AsyncOpenAI +import asyncio +import math + +class RewardModelClient: + def __init__(self): + self.client = AsyncOpenAI( + base_url="https://rawsh--vllm-qwen-prm-serve.modal.run/v1/", + api_key="9FF74944EED19865193F979942FB1" + ) + self.model_name = "MetaMath-Qwen2.5-0.5b-PRM" + + async def get_token_probability(self, response) -> float: + """Extract probability of + token from response""" + logprobs = response.choices[0].logprobs.content[0].top_logprobs + + # Print tokens and their probabilities for debugging + token_probs = {lp.token: math.exp(lp.logprob) for lp in logprobs} + print("Available tokens and probs:", token_probs) + + # Get raw probabilities, defaulting to very small number if token not found + prob_plus = next((math.exp(lp.logprob) for lp in logprobs if lp.token == "+"), 1e-10) + prob_minus = next((math.exp(lp.logprob) for lp in logprobs if lp.token == "-"), 1e-10) + + # Normalize between + and - + return prob_plus / (prob_plus + prob_minus) if (prob_plus + prob_minus) > 0 else 0.5 + + async def evaluate_steps(self, question: str, steps: list[str]) -> list[float]: + """ + Evaluate each step in the solution getting probabilities of + vs - + Returns probability of + for each step + """ + probabilities = [] + + # First evaluate question + first step + messages = [ + {"role": "user", "content": f"{question}\n{steps[0]}"} + ] + + try: + response = await self.client.chat.completions.create( + model=self.model_name, + messages=messages, + max_tokens=1, + temperature=0, + logprobs=True, + top_logprobs=20 + ) + prob = await self.get_token_probability(response) + probabilities.append(prob) + + except Exception as e: + print(f"Error evaluating first step: {str(e)}") + probabilities.append(0.5) + + # For remaining steps + for i in range(1, len(steps)): + try: + # Build conversation including previous steps + messages = [ + {"role": "user", "content": f"{question}\n{steps[0]}"} + ] + + for prev_step in steps[1:i]: + messages.extend([ + {"role": "assistant", "content": "+"}, + {"role": "user", "content": prev_step} + ]) + + messages.append({"role": "assistant", "content": "+"}) + messages.append({"role": "user", "content": steps[i]}) + + response = await self.client.chat.completions.create( + 
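# --- Editor's sketch (not part of the patch): the "+" / "-" scoring rule ---
# get_token_probability() above turns the PRM's next-token logprobs into a step score:
# exponentiate the logprobs of "+" and "-" and normalise P(+) / (P(+) + P(-)).
# Offline illustration with made-up logprob values (no API call involved):
import math as _math
def _step_score(logprob_plus: float, logprob_minus: float) -> float:
    p_plus, p_minus = _math.exp(logprob_plus), _math.exp(logprob_minus)
    return p_plus / (p_plus + p_minus) if (p_plus + p_minus) > 0 else 0.5

assert abs(_step_score(-0.1, -3.0) - 0.948) < 0.01  # confident "good step"
assert _step_score(-2.0, -2.0) == 0.5               # undecided between + and -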
model=self.model_name, + messages=messages, + max_tokens=1, + temperature=0, + logprobs=True, + top_logprobs=20 + ) + + prob = await self.get_token_probability(response) + probabilities.append(prob) + + except Exception as e: + print(f"Error evaluating step {i+1}: {str(e)}") + probabilities.append(0.5) + + return probabilities + +async def main(): + # Initialize client + reward_model = RewardModelClient() + + # Example problem + question = "Janet has 3 apples and buys 2 more. How many apples does she have?" + steps = [ + "Step 1: If Janet has 3 apples and buys 2 more, total apples = 3 + 2 = 5.", + "Step 2: Therefore, Janet has 5 apples. The answer is: 5", + ] + + try: + # Get evaluations + probabilities = await reward_model.evaluate_steps(question, steps) + + # Print results + print("\nResults:") + print("Question:", question) + print("\nStep Evaluations:") + for step, prob in zip(steps, probabilities): + print(f"P(+) = {prob:.3f}: {step}") + + except Exception as e: + print(f"Error occurred: {str(e)}") + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file
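# --- Editor's sketch (not part of the patch): aggregating per-step PRM scores ---
# evaluate_steps() returns one P(+) per step. The MCTS code in mcts/tree_search_mathrm.py
# currently scores a node by its newest step only (see Node.evaluate); the commented-out
# variant averaged all steps. Two other common choices are the minimum ("a chain is as
# strong as its weakest step") and the product. Purely illustrative helper:
import math

def aggregate_step_scores(step_probs, how="min"):
    if not step_probs:
        return 0.0
    if how == "min":
        return min(step_probs)
    if how == "mean":
        return sum(step_probs) / len(step_probs)
    if how == "prod":
        return math.prod(step_probs)
    raise ValueError(f"unknown aggregation: {how}")

probs = [0.92, 0.81, 0.40]                   # made-up per-step P(+) values
print(aggregate_step_scores(probs, "min"))   # 0.4
print(aggregate_step_scores(probs, "mean"))  # ~0.71
print(aggregate_step_scores(probs, "prod"))  # ~0.298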