From 2807e0fcf5c6ca85daa6d2363fa9a57b63800371 Mon Sep 17 00:00:00 2001 From: Robert Washbourne Date: Sat, 16 Nov 2024 22:25:13 +0000 Subject: [PATCH] updoot --- mcts/process_results.py | 56 ++-- mcts/simple_sample.py | 5 +- mcts/train_policy_orpo.py | 7 +- mcts/train_policy_sft_metamath.py | 148 +++++++++ mcts/train_reward.py | 63 ++-- mcts/tree_search.py | 20 +- mcts/tree_search_chat.py | 488 +++++++++++++++++++++++++++++ modal_prm_reward.py | 6 +- modal_train_policy_orpo.py | 7 +- modal_train_policy_sft_metamath.py | 51 +++ modal_train_prm_rlhf_flow.py | 68 ++++ modal_train_prm_st.py | 8 +- modal_vllm.py | 41 +-- modal_vllm_chat.py | 180 +++++++++++ prm_rlhf_flow/qwen.yml | 80 +++++ 15 files changed, 1129 insertions(+), 99 deletions(-) create mode 100644 mcts/train_policy_sft_metamath.py create mode 100644 mcts/tree_search_chat.py create mode 100644 modal_train_policy_sft_metamath.py create mode 100644 modal_train_prm_rlhf_flow.py create mode 100644 modal_vllm_chat.py create mode 100644 prm_rlhf_flow/qwen.yml diff --git a/mcts/process_results.py b/mcts/process_results.py index 65d4f92..621c5ee 100644 --- a/mcts/process_results.py +++ b/mcts/process_results.py @@ -133,7 +133,7 @@ def analyze_question(self, question: dict) -> QuestionAnalysis: def get_paired_examples( self, analyses: List[QuestionAnalysis], - max_pairs: int = 10000, + max_pairs: int = 20000, top_n_correct: int = 10, top_n_incorrect: int = 10 ) -> List[Dict[str, Any]]: @@ -261,7 +261,6 @@ def get_paired_examples( def generate_prm_training_data(self, analyses: List[QuestionAnalysis]) -> List[Dict[str, Any]]: """Generate training data for Process Reward Model (PRM) from MCTS paths.""" prm_examples = [] - seen_examples = set() # Track unique (question, steps) combinations original_correct_lengths = [] original_incorrect_lengths = [] @@ -280,19 +279,6 @@ def generate_prm_training_data(self, analyses: List[QuestionAnalysis]) -> List[D for k, step in enumerate(path.steps, 1): partial_steps = path.steps[:k] - - # Create unique key based on question and step sequence - example_key = ( - hash(analysis.question_text), - hash(str(partial_steps)) - ) - - # Skip if we've seen this exact example - if example_key in seen_examples: - continue - - seen_examples.add(example_key) - m_k = K - k r_s_k = 0 w_s_k = (1 - v_prev) / (m_k + 1) * (1 - 2 * r_s_k) @@ -323,19 +309,6 @@ def generate_prm_training_data(self, analyses: List[QuestionAnalysis]) -> List[D for k, step in enumerate(path.steps, 1): partial_steps = path.steps[:k] - - # Create unique key based on question and step sequence - example_key = ( - hash(analysis.question_text), - hash(str(partial_steps)) - ) - - # Skip if we've seen this exact example - if example_key in seen_examples: - continue - - seen_examples.add(example_key) - penalize = k == K m_k = K - k if not penalize else K - k + 1 r_s_k = 0 if not penalize else 1 @@ -356,15 +329,34 @@ def generate_prm_training_data(self, analyses: List[QuestionAnalysis]) -> List[D }) v_prev = v_k - # Print statistics about duplicates avoided - print(f"\nTotal examples generated: {len(prm_examples)}") - print(f"Unique (question, steps) combinations: {len(seen_examples)}") - print(f"Duplicates avoided: {len(seen_examples) - len(prm_examples)}") + # Record length statistics + if original_correct_lengths: + print("\nOriginal Path Length Statistics:") + print(f"Correct paths mean length: {np.mean(original_correct_lengths):.1f} (±{np.std(original_correct_lengths):.1f})") + if original_incorrect_lengths: + print(f"Incorrect paths mean length: 
{np.mean(original_incorrect_lengths):.1f} (±{np.std(original_incorrect_lengths):.1f})") + + # Print complete path statistics + complete_correct = [ex for ex in prm_examples if ex["metadata"]["is_correct"] and ex["metadata"]["is_complete"]] + complete_incorrect = [ex for ex in prm_examples if not ex["metadata"]["is_correct"] and ex["metadata"]["is_complete"]] + + print("\nComplete Path Statistics:") + print(f"Complete correct paths: {len(complete_correct)}") + print(f"Complete incorrect paths: {len(complete_incorrect)}") + + if complete_correct: + print(f"Complete correct mean length: {np.mean([ex['metadata']['path_length'] for ex in complete_correct]):.1f}") + if complete_incorrect: + print(f"Complete incorrect mean length: {np.mean([ex['metadata']['path_length'] for ex in complete_incorrect]):.1f}") return prm_examples def main(): analyzer = MathReasoningAnalyzer('mcts_results.jsonl') + # analyzer = MathReasoningAnalyzer('mcts_results.jsonl.st1_orpo.bak') + # analyzer = MathReasoningAnalyzer('mcts_results.jsonl.st2_orpo.bak') + # analyzer = MathReasoningAnalyzer('mcts_results.jsonl.st3_orpo.bak') + # analyzer = MathReasoningAnalyzer('mcts_results.jsonl.st0.bak') # analyzer = MathReasoningAnalyzer('mcts_results.jsonl.st1.bak') # analyzer = MathReasoningAnalyzer('mcts_results.jsonl.st2-v1.bak') diff --git a/mcts/simple_sample.py b/mcts/simple_sample.py index 578abbc..8663dc0 100644 --- a/mcts/simple_sample.py +++ b/mcts/simple_sample.py @@ -19,7 +19,7 @@ MAX_RETRIES = 5 TIMEOUT = 20 MAX_CONCURRENT = 100 -SAMPLES_PER_QUESTION = 10 # Default to single sample mode, override with CLI arg +SAMPLES_PER_QUESTION = 1 # Default to single sample mode, override with CLI arg # Cache decorator for PRM scores def async_lru_cache(maxsize=2000): @@ -108,7 +108,8 @@ async def generate_completion( """Generate a single completion.""" async with semaphore: response = await client.completions.create( - model="mirrorqwen2.5-0.5b-SimPO-3", + # model="mirrorqwen2.5-0.5b-SimPO-3", + model="MetaMath-Qwen2.5-0.5b", prompt=question, max_tokens=1500, temperature=0.8 diff --git a/mcts/train_policy_orpo.py b/mcts/train_policy_orpo.py index d700cac..2c9e996 100644 --- a/mcts/train_policy_orpo.py +++ b/mcts/train_policy_orpo.py @@ -52,10 +52,11 @@ def train_orpo( per_device_train_batch_size=8, gradient_accumulation_steps=8, # learning_rate=5e-7, - learning_rate=8e-6, + # learning_rate=8e-6, lr_scheduler_type="linear", beta=0.1, - # learning_rate=5e-6, + learning_rate=3e-6, + # max_steps max_length=2048, max_prompt_length=1024, gradient_checkpointing=True, @@ -71,7 +72,7 @@ def train_orpo( # lr_scheduler_type="cosine", do_eval=True, evaluation_strategy="steps", - eval_steps=20, + eval_steps=10, remove_unused_columns=False, logging_steps=10, logging_first_step=True diff --git a/mcts/train_policy_sft_metamath.py b/mcts/train_policy_sft_metamath.py new file mode 100644 index 0000000..a52aafa --- /dev/null +++ b/mcts/train_policy_sft_metamath.py @@ -0,0 +1,148 @@ +from unsloth import FastLanguageModel +import torch +from trl import SFTTrainer +from transformers import TrainingArguments +from unsloth import is_bfloat16_supported +from unsloth import UnslothTrainer, UnslothTrainingArguments +from datasets import load_dataset +from unsloth.chat_templates import get_chat_template + +# Constants +SEED = 42 +max_seq_length = 8192 +dtype = None # None for auto detection. 
Float16 for Tesla T4, V100, Bfloat16 for Ampere+ +load_in_4bit = False + +first = True + +def format_answer(response): + global first + """Extract answer from #### pattern and format response.""" + # Split at #### and get everything before it + parts = response.split('####') + if len(parts) < 2: + return None + + + solution = "\n\n".join(parts[0].strip().split("\n")) + answer = parts[1].split('The answer is:')[0].strip() + + if (first): + print(solution) + print(answer) + first = False + + return f"{solution}\n\nThe final answer is: $\\boxed{{{answer}}}$" + +def train_sft(): + # Load model and tokenizer + model, tokenizer = FastLanguageModel.from_pretrained( + model_name = "Qwen/Qwen2.5-0.5B", + max_seq_length = max_seq_length, + dtype = dtype, + load_in_4bit = load_in_4bit, + ) + + # Set up chat template + tokenizer = get_chat_template( + tokenizer, + chat_template = "qwen-2.5", + ) + + # Configure PEFT + model = FastLanguageModel.get_peft_model( + model, + r = 128, + target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", + "gate_proj", "up_proj", "down_proj", + "embed_tokens", "lm_head"], + lora_alpha = 32, + lora_dropout = 0, + bias = "none", + use_gradient_checkpointing = "unsloth", + random_state = 3407, + use_rslora = True, + loftq_config = None, + ) + + # Load dataset + ds = load_dataset("meta-math/MetaMathQA") + train_ds = ds['train'] + + # Format prompts + def formatting_prompts_func(examples): + conversations = [] + for query, response in zip(examples['query'], examples['response']): + formatted_response = format_answer(response) + if formatted_response is None: + continue + + conversation = [ + {"role": "user", "content": query}, + {"role": "assistant", "content": formatted_response} + ] + conversations.append(conversation) + + texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) + for convo in conversations] + return {"text": texts} + + # <|im_start|>system\nYou are Qwen, created by Alibaba Cloud. 
You are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat is the total cost of purchasing equipment for all sixteen players on the football team, considering that each player requires a $25 jersey, a $15.20 pair of shorts, and a pair of socks priced at $6.80?<|im_end|>\n<|im_start|>assistant\nEach player requires a $25 jersey, a $15.20 pair of shorts, and a pair of socks priced at $6.80.\n\nSo the total cost for each player is $25 + $15.20 + $6.80 = $47.\n\nSince there are sixteen players on the football team, the total cost for all of them is 16 * $47 = $752.\n\nThe final answer is: $\\boxed{752}$<|im_end|>\n' + + # Process dataset + formatted_dataset = train_ds.map( + formatting_prompts_func, + batched=True, + remove_columns=train_ds.column_names + ) + print(len(formatted_dataset)) + print(formatted_dataset[0]) + + # Configure trainer + trainer = UnslothTrainer( + model = model, + tokenizer = tokenizer, + train_dataset = formatted_dataset, + dataset_text_field = "text", + max_seq_length = max_seq_length, + dataset_num_proc = 8, + packing = True, + args = UnslothTrainingArguments( + per_device_train_batch_size = 8, + gradient_accumulation_steps = 8, + warmup_ratio = 0.1, + num_train_epochs = 3, + # learning_rate = 5e-6, + # embedding_learning_rate = 5e-7, + learning_rate = 8e-6, + embedding_learning_rate = 1e-6, + fp16 = not is_bfloat16_supported(), + bf16 = is_bfloat16_supported(), + logging_steps = 1, + optim = "adamw_torch_fused", + weight_decay = 0.01, + lr_scheduler_type = "cosine", + seed = 3407, + output_dir = "outputs", + ), + ) + + # Print GPU stats + gpu_stats = torch.cuda.get_device_properties(0) + start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3) + max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3) + print(f"GPU = {gpu_stats.name}. 
Max memory = {max_memory} GB.") + print(f"{start_gpu_memory} GB of memory reserved.") + + # Train + trainer_stats = trainer.train() + + # Save model + model.push_to_hub_merged( + "rawsh/MetaMath-Qwen2.5-0.5b", + tokenizer, + save_method = "merged_16bit" + ) + +if __name__ == "__main__": + train_sft() \ No newline at end of file diff --git a/mcts/train_reward.py b/mcts/train_reward.py index e14041c..e97b13a 100644 --- a/mcts/train_reward.py +++ b/mcts/train_reward.py @@ -85,31 +85,40 @@ def tokenize(sample): # Step 2: Assign bin number to each sample in training data def assign_bin(example): final_step_reward = example['final_step_reward'] - # Calculate bin number (bins: 0.0-0.1 => bin 0, ..., 0.9-1.0 => bin 9) - bin_number = int(final_step_reward * 10) - if bin_number == 10: - bin_number = 9 # Handle the edge case where final_step_reward == 1.0 - example['bin'] = bin_number + if final_step_reward <= 0.1: + # Samples with rewards <= 0.1 get assigned to bin -1 (won't be balanced) + example['bin'] = -1 + else: + # Calculate bin number for rewards > 0.1 (bins: 0.1-0.2 => bin 0, ..., 0.9-1.0 => bin 8) + bin_number = int((final_step_reward - 0.1) * 10) + if bin_number == 9: + bin_number = 8 # Handle the edge case where final_step_reward == 1.0 + example['bin'] = bin_number return example ds_train = ds_train.map(assign_bin, num_proc=24) - # Step 3: Get counts of samples in each bin for training data - bin_counts_train = Counter(ds_train['bin']) - print("Training bin counts before undersampling:", bin_counts_train) + # Step 3: Separate low reward samples and get counts for other bins + low_reward_indices = [idx for idx, bin_num in enumerate(ds_train['bin']) if bin_num == -1] + bin_counts_train = Counter([b for b in ds_train['bin'] if b >= 0]) + print("Training bin counts before undersampling (excluding ≤0.1):", bin_counts_train) - # Determine the minimum count across all bins in training data - min_count_train = min(bin_counts_train.values()) - print("Training minimum count per bin:", min_count_train) + # Determine the minimum count across all bins in training data (excluding bin -1) + min_count_train = min(bin_counts_train.values()) if bin_counts_train else 0 + print("Training minimum count per bin (excluding ≤0.1):", min_count_train) # Step 4: Create a mapping from bin to indices for training data - bin_to_indices_train = {i: [] for i in range(10)} # Bins 0 to 9 + bin_to_indices_train = {i: [] for i in range(9)} # Bins 0 to 8 (for rewards 0.1-1.0) for idx, bin_number in enumerate(ds_train['bin']): - bin_to_indices_train[bin_number].append(idx) + if bin_number >= 0: # Only process samples with rewards > 0.1 + bin_to_indices_train[bin_number].append(idx) # Randomly sample min_count_train indices per bin for training data random.seed(42) selected_indices_train = [] + # First add all low reward samples (≤0.1) + selected_indices_train.extend(low_reward_indices) + # Then sample from other bins for bin_number, indices in bin_to_indices_train.items(): if len(indices) >= min_count_train: sampled_indices = random.sample(indices, min_count_train) @@ -122,12 +131,13 @@ def assign_bin(example): # Step 5: Create the balanced training dataset train_dataset = ds_train.select(selected_indices_train) - print("Total training samples after undersampling:", len(train_dataset)) + print("Total training samples after processing:", len(train_dataset)) + print("- Samples with reward ≤0.1:", len(low_reward_indices)) + print("- Samples per bin >0.1:", min_count_train) else: train_dataset = ds_train - # Now, build the 
evaluation dataset - # Load and shuffle the evaluation dataset + # Now, build the evaluation dataset similarly ds_eval = load_dataset(eval_path, split="train").shuffle(seed=42) ds_eval = ds_eval.map(tokenize, num_proc=24) @@ -135,22 +145,27 @@ def assign_bin(example): # Assign bins to evaluation data ds_eval = ds_eval.map(assign_bin, num_proc=24) - # Get counts of samples in each bin for evaluation data - bin_counts_eval = Counter(ds_eval['bin']) - print("Evaluation bin counts before undersampling:", bin_counts_eval) + # Separate low reward samples and get counts for other bins + eval_low_reward_indices = [idx for idx, bin_num in enumerate(ds_eval['bin']) if bin_num == -1] + bin_counts_eval = Counter([b for b in ds_eval['bin'] if b >= 0]) + print("Evaluation bin counts before undersampling (excluding ≤0.1):", bin_counts_eval) # Determine the minimum count per bin for evaluation data # Set it to be 10% of min_count_train, at least 1 eval_min_count_per_bin = max(1, int(min_count_train * 0.1)) - print("Evaluation minimum count per bin:", eval_min_count_per_bin) + print("Evaluation minimum count per bin (excluding ≤0.1):", eval_min_count_per_bin) # Create a mapping from bin to indices for evaluation data - bin_to_indices_eval = {i: [] for i in range(10)} # Bins 0 to 9 + bin_to_indices_eval = {i: [] for i in range(9)} # Bins 0 to 8 for idx, bin_number in enumerate(ds_eval['bin']): - bin_to_indices_eval[bin_number].append(idx) + if bin_number >= 0: # Only process samples with rewards > 0.1 + bin_to_indices_eval[bin_number].append(idx) # Randomly sample eval_min_count_per_bin indices per bin for evaluation data selected_indices_eval = [] + # First add all low reward samples (≤0.1) + selected_indices_eval.extend(eval_low_reward_indices) + # Then sample from other bins for bin_number, indices in bin_to_indices_eval.items(): if len(indices) >= eval_min_count_per_bin: sampled_indices = random.sample(indices, eval_min_count_per_bin) @@ -163,7 +178,9 @@ def assign_bin(example): # Create the balanced evaluation dataset eval_dataset = ds_eval.select(selected_indices_eval) - print("Total evaluation samples after undersampling:", len(eval_dataset)) + print("Total evaluation samples after processing:", len(eval_dataset)) + print("- Evaluation samples with reward ≤0.1:", len(eval_low_reward_indices)) + print("- Evaluation samples per bin >0.1:", eval_min_count_per_bin) else: eval_dataset = ds_eval diff --git a/mcts/tree_search.py b/mcts/tree_search.py index 679fca1..da2368c 100644 --- a/mcts/tree_search.py +++ b/mcts/tree_search.py @@ -17,7 +17,10 @@ # POLICY_MODEL_NAME = 'mirrorqwen2.5-0.5b-SimPO-3' # POLICY_MODEL_NAME = 'mirrorqwen2.5-0.5b-SimPO-0' # POLICY_MODEL_NAME = 'mirrorqwen2.5-0.5b-SFT' -POLICY_MODEL_NAME = 'mirrorqwen2.5-0.5b-ORPO-1' +# POLICY_MODEL_NAME = 'mirrorqwen2.5-0.5b-ORPO-1' +# POLICY_MODEL_NAME = 'mirrorqwen2.5-0.5b-ORPO-2' +# POLICY_MODEL_NAME = 'mirrorqwen2.5-0.5b-ORPO-3' +POLICY_MODEL_NAME = 'MetaMath-Qwen2.5-0.5b' # POLICY_URL = 'https://rawsh--vllm-qwen-simpo-serve.modal.run/v1/' # POLICY_URL = 'https://rawsh--vllm-qwen-base-serve.modal.run/v1/' POLICY_URL = 'https://rawsh--vllm-qwen-orpo-serve.modal.run/v1/' @@ -30,7 +33,7 @@ PRM_SEMAPHORE = Semaphore(1000) MAX_RETRIES = 20 # Increased from 10s -TIMEOUT = 20 # Decreased from 30 to fail faster and retry +TIMEOUT = 30 # Decreased from 30 to fail faster and retry # Cache decorator and retry function def async_lru_cache(maxsize=2000): @@ -363,9 +366,9 @@ async def run_mcts(initial_state, correct_answer, num_iterations, session, progr 
async def main(): # Set random seed for reproducibility - # random.seed(0) # eval set - all models + random.seed(0) # eval set - all models # random.seed(42) # st 0 - random.seed(4242) # st 1 + # random.seed(4242) # st 1 # random.seed(424242) # st 2 # random.seed(42424242) # st 3 @@ -373,15 +376,15 @@ def process(example): example["answer"] = example["answer"].split("\n#### ")[-1].strip() return example - # gsm8k = load_dataset("openai/gsm8k", "main", split="test").shuffle(seed=42) - gsm8k = load_dataset("openai/gsm8k", "main", split="train").shuffle(seed=42) + gsm8k = load_dataset("openai/gsm8k", "main", split="test").shuffle(seed=42) + # gsm8k = load_dataset("openai/gsm8k", "main", split="train").shuffle(seed=42) gsm8k = gsm8k.map(process, num_proc=24) initial_states = [(example["question"], example["answer"]) for example in gsm8k] # SAMPLE 200 QUESTIONS - SELF TRAINING - initial_states = random.sample(initial_states, 200) + initial_states = random.sample(initial_states, 10) # initial_states = random.sample(initial_states, 1000) - num_iterations = 100 + num_iterations = 250 print("cold starting policy vllm + prm api") @@ -394,6 +397,7 @@ def process(example): stop=["\n\n"], temperature=0.3, logprobs=20, + response_role="assistant" ) async with aiohttp.ClientSession() as session: diff --git a/mcts/tree_search_chat.py b/mcts/tree_search_chat.py new file mode 100644 index 0000000..1370c49 --- /dev/null +++ b/mcts/tree_search_chat.py @@ -0,0 +1,488 @@ +import asyncio +import math +import aiohttp +import json +from openai import AsyncOpenAI +import time +import random # Added for jitter in retry logic +from functools import wraps +from collections import OrderedDict +from asyncio import Semaphore, TimeoutError +from datasets import load_dataset +from tqdm import tqdm +from tqdm.asyncio import tqdm as atqdm + +# URLs and configuration +# POLICY_URL = 'https://rawsh--vllm-qwen-ft-serve.modal.run/v1/' +# POLICY_MODEL_NAME = 'mirrorqwen2.5-0.5b-SimPO-3' +# POLICY_MODEL_NAME = 'mirrorqwen2.5-0.5b-SimPO-0' +# POLICY_MODEL_NAME = 'mirrorqwen2.5-0.5b-SFT' +# POLICY_MODEL_NAME = 'mirrorqwen2.5-0.5b-ORPO-1' +# POLICY_MODEL_NAME = 'mirrorqwen2.5-0.5b-ORPO-2' +# POLICY_MODEL_NAME = 'mirrorqwen2.5-0.5b-ORPO-3' +POLICY_MODEL_NAME = 'MetaMath-Qwen2.5-0.5b' +# POLICY_URL = 'https://rawsh--vllm-qwen-simpo-serve.modal.run/v1/' +# POLICY_URL = 'https://rawsh--vllm-qwen-base-serve.modal.run/v1/' +# POLICY_URL = 'https://rawsh--vllm-qwen-orpo-serve.modal.run/v1/' +POLICY_URL = 'https://rawsh--vllm-qwen-metamath-serve.modal.run/v1/' +# PRM_URL = 'https://rawsh--mirrorqwen-prm-embedder-score-output.modal.run' +PRM_URL = 'https://rawsh--mirrorqwen-prm-st-embedder-score-output.modal.run' +API_KEY = '9FF74944EED19865193F979942FB1' + +CONCURRENT_MCTS_SEMAPHORE = Semaphore(20) +POLICY_SEMAPHORE = Semaphore(1000) +PRM_SEMAPHORE = Semaphore(1000) + +MAX_RETRIES = 20 # Increased from 10s +TIMEOUT = 10 # Decreased from 30 to fail faster and retry + +# Cache decorator and retry function +def async_lru_cache(maxsize=2000): + cache = OrderedDict() + def decorator(func): + @wraps(func) + async def wrapper(*args, **kwargs): + key = str(args) + str(kwargs) + if key not in cache: + if len(cache) >= maxsize: + cache.popitem(last=False) + cache[key] = await func(*args, **kwargs) + return cache[key] + return wrapper + return decorator + +async def retry_with_timeout(func, *args, **kwargs): + for attempt in range(MAX_RETRIES): + try: + return await asyncio.wait_for(func(*args, **kwargs), timeout=TIMEOUT * max(1, attempt / 10 )) + 
except TimeoutError: + if attempt == MAX_RETRIES - 1: + raise + # Exponential backoff with jitter + delay = min(1.5 ** attempt + random.random(), 10) + await asyncio.sleep(delay) + except Exception as e: + if attempt == MAX_RETRIES - 1: + raise + # Exponential backoff with jitter for other errors + delay = min(1.5 ** attempt + random.random(), 10) + print(f"Attempt {attempt + 1} failed with error: {str(e)}. Retrying in {delay:.1f}s...") + await asyncio.sleep(delay) + +class Node: + def __init__(self, state, parent=None): + self.state = state + self.parent = parent + self.children = {} + self.visits = 0 + self.total_value = 0 + self.prm_value = None + +# Progress tracking class with added metrics +class MCTSProgress: + def __init__(self, total_questions, iterations_per_question): + self.total_questions = total_questions + self.total_iterations = total_questions * iterations_per_question + self.iterations_per_question = iterations_per_question + self.completed_iterations = 0 + self.correct_sc = 0 # Self-consistency correct count + self.correct_any = 0 # Any-correct count + self.correct_best = 0 # Best PRM path correct count + self.total_actions = 0 # Global action counter + self.questions_with_terminal = 0 # Questions with at least one terminal path + self.fully_completed_questions = 0 # Questions that completed all iterations + + # Single progress bar with dynamic description + self.pbar = tqdm(total=self.total_iterations, + desc=self.get_progress_description()) + + def get_progress_description(self): + sc_pct = (self.correct_sc / max(1, self.fully_completed_questions)) * 100 + any_pct = (self.correct_any / max(1, self.fully_completed_questions)) * 100 + best_pct = (self.correct_best / max(1, self.fully_completed_questions)) * 100 + q_pct = (self.questions_with_terminal / self.total_questions) * 100 + return (f"#Q ({self.questions_with_terminal}/{self.total_questions}): {q_pct:.0f}% | " + f"SC: {sc_pct:.1f}% | " + f"ANY: {any_pct:.1f}% | " + f"BEST: {best_pct:.1f}% | " + f"Actions: {self.total_actions}") + + def increment_iteration(self): + self.completed_iterations += 1 + self.pbar.update(1) + # No need to update description here + + def complete_question(self, is_sc_correct, is_any_correct, is_best_correct, is_fully_completed, has_terminal_nodes): + if has_terminal_nodes: + self.questions_with_terminal += 1 + if is_fully_completed: + self.fully_completed_questions += 1 + if is_sc_correct: + self.correct_sc += 1 + if is_any_correct: + self.correct_any += 1 + if is_best_correct: + self.correct_best += 1 + self.pbar.set_description(self.get_progress_description()) + + def close(self): + # Print final statistics + if self.fully_completed_questions > 0: + sc_pct = (self.correct_sc / self.fully_completed_questions) * 100 + any_pct = (self.correct_any / self.fully_completed_questions) * 100 + best_pct = (self.correct_best / self.fully_completed_questions) * 100 + print(f"\nFinal Results:") + print(f"Questions with Terminal Paths: {self.questions_with_terminal}") + print(f"Fully Completed Questions: {self.fully_completed_questions}") + print(f"Self-Consistency Accuracy: {sc_pct:.2f}% ({self.correct_sc}/{self.fully_completed_questions})") + print(f"Any-Correct Accuracy: {any_pct:.2f}% ({self.correct_any}/{self.fully_completed_questions})") + print(f"Best-Path Accuracy: {best_pct:.2f}% ({self.correct_best}/{self.fully_completed_questions})") + print(f"Total Actions Taken: {self.total_actions}") + self.pbar.close() + +def select(node): + while node.children: + if len(node.children) < 
len(get_possible_actions(node.state)): + return node + node = best_uct_child(node) + return node + +def best_uct_child(node): + C = 1.41 + return max( + node.children.values(), + key=lambda child: (child.total_value / child.visits) + C * math.sqrt(math.log(node.visits) / child.visits) + ) + +async def expand(node, client, session, progress_tracker): + action = await retry_with_timeout(get_next_action, node.state, client) + new_state = apply_action(node.state, action) + child = Node(new_state, parent=node) + node.children[action] = child + progress_tracker.total_actions += 1 + return child + +async def simulate(node, correct_answer, client, session, terminal_nodes, progress_tracker): + state = node.state + depth = 0 + max_depth = 10 + while depth < max_depth: + is_term, is_corr = await retry_with_timeout(is_terminal, state, correct_answer, client, session) + if is_term: + terminal_nodes.add(state) + break + action = await retry_with_timeout(get_next_action, state, client) + state = apply_action(state, action) + progress_tracker.total_actions += 1 + depth += 1 + return await retry_with_timeout(evaluate_state, state, session) + +def backpropagate(node, value): + while node: + node.visits += 1 + node.total_value += value + node = node.parent + +async def get_next_action(state, client): + prompt = format_state_for_policy(state) + async with POLICY_SEMAPHORE: + steps = prompt.split("\n\n") + question = steps[0] + answer = None + if len(steps) > 0: + answer = "\n\n".join(steps[1:]) + + messages = [ + {"role": "user", "content": question} + ] + if answer is not None: + messages.append({"role": "assistant", "content": answer}) + + response = await client.chat.completions.create( + model=POLICY_MODEL_NAME, + messages=messages, + max_tokens=250, + stop=["\n\n"], + temperature=0.8, + ) + # return response.choices[0].text.strip() + return response.choices[0].message.content.strip() + +def is_correct(state, correct_answer): + last_step = state.split("\n\n")[-1] + return fr"\boxed{{{correct_answer}}}" in last_step + +async def is_terminal(state, correct_answer, client, session): + if is_correct(state, correct_answer): + return True, True + + if state.count("\n\n") < 2: + return False, False + + async with POLICY_SEMAPHORE: + steps = state.split("\n\n") + question = steps[0] + answer = None + if len(steps) > 0: + answer = "\n\n".join(steps[1:]) + + messages = [ + {"role": "user", "content": question} + ] + if answer is not None: + messages.append({"role": "assistant", "content": answer}) + + response = await client.chat.completions.create( + model=POLICY_MODEL_NAME, + messages=messages, + max_tokens=1, + stop=["\n\n"], + temperature=0.8, + logprobs=True, + top_logprobs=20 + ) + # response = await client.completions.create( + # model=POLICY_MODEL_NAME, + # prompt=state, + # max_tokens=1, + # stop=["\n\n"], + # temperature=0.3, + # logprobs=20, + # ) + + first_token_top_logprobs = response.choices[0].logprobs.content[0].top_logprobs + first_token_top_logprobs_map = dict() + for token_logprob in first_token_top_logprobs: + first_token_top_logprobs_map[token_logprob.token] = token_logprob.logprob + + if "" in first_token_top_logprobs_map: + scaled = math.exp(first_token_top_logprobs_map[""]) + yes_bigger_than_no = True + if "\n\n" in first_token_top_logprobs_map: + scaled_no = math.exp(first_token_top_logprobs_map["\n\n"]) + yes_bigger_than_no = (scaled > scaled_no) + + threshold = 0.95 + terminal = (scaled >= threshold) and yes_bigger_than_no + return terminal, False + else: + return False, False + 
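+# The PRM scorer below posts the flattened reasoning trace to PRM_URL and
+# returns the reward model's scalar score; async_lru_cache memoizes results,
+# so an identical partial solution is only scored once per run.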
+@async_lru_cache(maxsize=1000) +async def evaluate_state(state, session): + prompt = format_state_for_prm(state) + async with PRM_SEMAPHORE: + async with session.post(PRM_URL, json={"prompt": prompt}) as response: + result = await response.json() + return float(result['score']) + +def apply_action(state, action): + return f"{state}\n\n{action}" + +def get_possible_actions(state): + return range(3) + +def format_state_for_policy(state): + return f"{state}\n\n" + +def format_state_for_prm(state): + return state + +def collect_leaf_nodes(node, leaf_nodes): + if not node.children: + leaf_nodes.append(node) + else: + for child in node.children.values(): + collect_leaf_nodes(child, leaf_nodes) + +async def find_best_leaf_by_prm(node, session): + leaf_nodes = [] + collect_leaf_nodes(node, leaf_nodes) + tasks = [] + for leaf in leaf_nodes: + if leaf.prm_value is None: + tasks.append(evaluate_and_store_prm(leaf, session)) + await asyncio.gather(*tasks) + return max(leaf_nodes, key=lambda leaf: leaf.prm_value if leaf.prm_value is not None else float('-inf')) + +async def evaluate_and_store_prm(node, session): + node.prm_value = await retry_with_timeout(evaluate_state, node.state, session) + +async def mcts(root_state, correct_answer, num_iterations, session, progress_tracker): + root = Node(root_state) + client = AsyncOpenAI(base_url=POLICY_URL, api_key=API_KEY) + terminal_nodes = set() + + for i in range(num_iterations): + leaf = select(root) + is_term, is_corr = await retry_with_timeout(is_terminal, leaf.state, correct_answer, client, session) + + if is_term: + terminal_nodes.add(leaf.state) + else: + child = await retry_with_timeout(expand, leaf, client, session, progress_tracker) + value = await retry_with_timeout(simulate, child, correct_answer, client, session, terminal_nodes, progress_tracker) + backpropagate(child, value) + + progress_tracker.increment_iteration() + + return root, terminal_nodes + + +async def run_mcts(initial_state, correct_answer, num_iterations, session, progress_tracker): + async with CONCURRENT_MCTS_SEMAPHORE: + start_time = time.time() + root, terminal_nodes = await mcts(initial_state, correct_answer, num_iterations, session, progress_tracker) + end_time = time.time() + + best_leaf = await find_best_leaf_by_prm(root, session) + + terminal_paths = [] + answers = {} # Track answer frequencies + max_prm_score = float('-inf') + best_prm_path_correct = False + terminal_correct_count = 0 # Add this counter + + for node in terminal_nodes: + score = await retry_with_timeout(evaluate_state, node, session) + is_node_correct = is_correct(node, correct_answer) + if is_node_correct: + terminal_correct_count += 1 # Increment counter + + # Extract answer from the node + last_step = node.split("\n\n")[-1] + if r"\boxed{" in last_step: + answer = last_step.split(r"\boxed{")[1].split("}")[0] + answers[answer] = answers.get(answer, 0) + 1 + + if score > max_prm_score: + max_prm_score = score + best_prm_path_correct = is_node_correct + + terminal_paths.append({ + "final_state": node, + "score": score, + "correct": is_node_correct + }) + + is_best_correct = is_correct(best_leaf.state, correct_answer) + + # Calculate SC using most common answer + has_terminal_nodes = len(terminal_nodes) > 0 + is_sc_correct = False + if has_terminal_nodes and answers: + most_common_answer = max(answers.items(), key=lambda x: x[1])[0] + is_sc_correct = any(p["correct"] and most_common_answer == p["final_state"].split(r"\boxed{")[1].split("}")[0] + for p in terminal_paths) + + is_any_correct = 
any(p["correct"] for p in terminal_paths) + is_fully_completed = len(terminal_nodes) > 0 and num_iterations == progress_tracker.iterations_per_question + + result = { + "question": initial_state, + "correct_answer": correct_answer, + "statistics": { + "num_iterations": num_iterations, + "execution_time": end_time - start_time, + "total_terminal_nodes": len(terminal_nodes), # Use len() directly + "correct_terminal_nodes": terminal_correct_count, + "self_consistency_correct": is_sc_correct, + "any_correct": is_any_correct, + "has_terminal_nodes": has_terminal_nodes, + "best_prm_path_correct": best_prm_path_correct, + "fully_completed": is_fully_completed + }, + "best_path": { + "final_state": best_leaf.state, + "score": best_leaf.prm_value, + "correct": is_best_correct + }, + "terminal_paths": terminal_paths + } + + progress_tracker.complete_question(is_sc_correct, is_any_correct, best_prm_path_correct, is_fully_completed, has_terminal_nodes) + return result + +async def main(): + # Set random seed for reproducibility + random.seed(0) # eval set - all models + # random.seed(42) # st 0 + # random.seed(4242) # st 1 + # random.seed(424242) # st 2 + # random.seed(42424242) # st 3 + + def process(example): + example["answer"] = example["answer"].split("\n#### ")[-1].strip() + return example + + gsm8k = load_dataset("openai/gsm8k", "main", split="test").shuffle(seed=42) + # gsm8k = load_dataset("openai/gsm8k", "main", split="train").shuffle(seed=42) + gsm8k = gsm8k.map(process, num_proc=24) + initial_states = [(example["question"], example["answer"]) for example in gsm8k] + + # SAMPLE 200 QUESTIONS - SELF TRAINING + initial_states = random.sample(initial_states, 100) + # initial_states = random.sample(initial_states, 1000) + num_iterations = 10 + + print("cold starting policy vllm + prm api") + + # warm up the chat API + client = AsyncOpenAI(base_url=POLICY_URL, api_key=API_KEY) + completion_promise = client.chat.completions.create( + model=POLICY_MODEL_NAME, + messages=[ + {"role": "user", "content": "Which is larger 9.11 or 9.9? 
Respond with just the answer."} + ], + # max_tokens=3, + # stop=["\n\n"], + stop=["<|endoftext|>"], + temperature=0.8, + # Note: logprobs is not available in chat format + ) + # res = await completion_promise + # print(res) + # return + + async with aiohttp.ClientSession() as session: + # warm up PRM api + async with session.post(PRM_URL, json={"prompt": "TEST"}) as response: + prm_promise = response.json() + prm_score = await prm_promise + assert('score' in prm_score) + print("warmed up PRM api") + + completion = await completion_promise + assert(len(completion.choices) == 1) + print(completion.choices[0]) + print("warmed up vllm") + + + # Initialize progress tracker + progress_tracker = MCTSProgress(len(initial_states), num_iterations) + + tasks = [] + for state, answer in initial_states: + tasks.append(run_mcts(state, answer, num_iterations, session, progress_tracker)) + + results = await asyncio.gather(*tasks) + + progress_tracker.close() + + # Calculate and print final statistics + total_questions = len(results) + sc_correct = sum(1 for r in results if r["statistics"]["self_consistency_correct"]) + any_correct = sum(1 for r in results if r["statistics"]["any_correct"]) + + print(f"\nFinal Statistics:") + print(f"Total Questions: {total_questions}") + print(f"Self-Consistency Accuracy: {(sc_correct/total_questions)*100:.2f}%") + print(f"Any-Correct Accuracy: {(any_correct/total_questions)*100:.2f}%") + + # Write results + with open("mcts_results.jsonl", "w") as f: + for result in results: + f.write(json.dumps(result) + "\n") + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/modal_prm_reward.py b/modal_prm_reward.py index fc15a47..8f241a2 100644 --- a/modal_prm_reward.py +++ b/modal_prm_reward.py @@ -46,11 +46,15 @@ def _process_batch(prompts: List[str]) -> List[Dict]: class Embedder: model_id = "rawsh/mirrorqwen2.5-0.5b-prm" # revision = "894341fbd81d0c1abdd98b4e0630de932aa63c6f" # base - revision = "42e07d1b708282ac2aae338050d8116f8c69398d" # st0 + # revision = "42e07d1b708282ac2aae338050d8116f8c69398d" # st0 # revision = "65f4a7601dffacc40e0ef7fa4733d346c926bd18" # st1 v1 # revision = "80da7ccc4f107e0cb6bf937d61be4702badfb96b" # st1 v2 # revision = "4d618515c90069993f4b32e4201783efdeebbc22" # st2 # revision = "b052380b619e5c62ce9f407522362f5caf7b8346" # st3 + # note: orpo 1 st for prm used strong/weak to generate samples. 
+ # inference pair to gen data for orpo 2 was orpo 1 policy + st0 + # revision = "e49e4ca7c847194be48c42c52ad8f871da204300" # orpo2 + revision = "ecae5a74ef094d6e839dcb2a32500c36e6786ad1" # orpo3 device = "cuda" print(model_id) diff --git a/modal_train_policy_orpo.py b/modal_train_policy_orpo.py index ef310bc..c4cf23d 100644 --- a/modal_train_policy_orpo.py +++ b/modal_train_policy_orpo.py @@ -54,14 +54,15 @@ def train_policy_orpo(): # Run training with specified parameters train_orpo( # model_name="rawsh/mirrorqwen2.5-0.5b-SFT", - model_name="rawsh/mirrorqwen2.5-0.5b-ORPO-1", + # model_name="rawsh/mirrorqwen2.5-0.5b-ORPO-1", + model_name="rawsh/mirrorqwen2.5-0.5b-ORPO-2", # model_name="rawsh/mirrorqwen2.5-0.5b-SimPO-0", # model_name="rawsh/mirrorqwen2.5-0.5b-SimPO-1", # model_name="rawsh/mirrorqwen2.5-0.5b-SimPO-2", # dataset_name="rawsh/mirrorqwen2.5-0.5B-gsm8k-policy-data-ST-3", # dataset_name="rawsh/mirrorqwen2.5-0.5B-gsm8k-policy-data-ST-0", - dataset_name="rawsh/mirrorqwen2.5-0.5B-gsm8k-policy-data-ORPO-1", - output_model_name="rawsh/mirrorqwen2.5-0.5b-ORPO-2", + dataset_name="rawsh/mirrorqwen2.5-0.5B-gsm8k-policy-data-ORPO-2", + output_model_name="rawsh/mirrorqwen2.5-0.5b-ORPO-3", hub_token=hf_token ) except Exception as e: diff --git a/modal_train_policy_sft_metamath.py b/modal_train_policy_sft_metamath.py new file mode 100644 index 0000000..bb05f87 --- /dev/null +++ b/modal_train_policy_sft_metamath.py @@ -0,0 +1,51 @@ +import modal + +cuda_version = "12.4.0" # should be no greater than host CUDA version +flavor = "devel" # includes full CUDA toolkit +operating_sys = "ubuntu22.04" +tag = f"{cuda_version}-{flavor}-{operating_sys}" + +image = ( + # modal.Image.debian_slim() + modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.11") + .apt_install("git") + .pip_install("torch") + .pip_install("packaging") + .pip_install("wheel") + .run_commands("pip install flash-attn --no-build-isolation") + .pip_install("transformers") + .pip_install("accelerate") + .pip_install("numpy") + .pip_install("datasets") + .pip_install("wandb") + .pip_install("bitsandbytes") + .pip_install("unsloth @ git+https://github.com/unslothai/unsloth.git") + .pip_install("unsloth_zoo") + .pip_install("xformers") +) +app = modal.App("train_policy_sft", image=image) + +with image.imports(): + from mcts.train_policy_sft_metamath import train_sft + +MINUTES = 60 # seconds +HOURS = 60 * MINUTES + +@app.function( + cpu=2.0, + # gpu=modal.gpu.A10G(), + gpu=modal.gpu.H100(), + # gpu=modal.gpu.A100(size="40GB"), + timeout=20 * HOURS, + secrets=[ + modal.Secret.from_name("hf-token"), + modal.Secret.from_name("wandb-token") + ] +) +def train_policy_model_sft_upload_to_hf(): + train_sft() + +@app.local_entrypoint() +def main(): + # run the function remotely on Modal + train_policy_model_sft_upload_to_hf.remote() \ No newline at end of file diff --git a/modal_train_prm_rlhf_flow.py b/modal_train_prm_rlhf_flow.py new file mode 100644 index 0000000..babbb7b --- /dev/null +++ b/modal_train_prm_rlhf_flow.py @@ -0,0 +1,68 @@ +# train.py +import modal +import yaml +import os +from pathlib import Path + +# CUDA setup +AXOLOTL_REGISTRY_SHA = "9578c47333bdcc9ad7318e54506b9adaf283161092ae780353d506f7a656590a" +image = ( + modal.Image.from_registry(f"winglian/axolotl@sha256:{AXOLOTL_REGISTRY_SHA}") + .pip_install( + "huggingface_hub==0.23.2", + "hf-transfer==0.1.5", + "wandb==0.16.3", + "fastapi==0.110.0", + "pydantic==2.6.3", + ) + .env( + dict( + HUGGINGFACE_HUB_CACHE="/pretrained", + HF_HUB_ENABLE_HF_TRANSFER="1", + 
AXOLOTL_NCCL_TIMEOUT="60", + ) + ) + .entrypoint([]) +) + +app = modal.App("train-hf", image=image) + +# Constants +MINUTES = 60 +HOURS = 60 * MINUTES + +# Create volume for persistent storage +training_vol = modal.Volume.from_name("training-data", create_if_missing=True) + +@app.function( + cpu=8, + gpu=modal.gpu.H100(), + timeout=20 * HOURS, + volumes={"/training": training_vol}, + secrets=[ + modal.Secret.from_name("hf-token"), + modal.Secret.from_name("wandb-token") + ], +) +def run_training(config): + import subprocess + + # Write the config to the container + config_path = Path("/training/config.yml") + with open(config_path, 'w') as f: + yaml.dump(config, f) + + # Run training - Axolotl will handle HF upload if push_to_hub is True + subprocess.run(["python", "-m", "axolotl.cli.train", config_path]) + +@app.local_entrypoint() +def main(): + # Read the local config file + with open("prm_rlhf_flow/qwen.yml", 'r') as f: + config = yaml.safe_load(f) + + # Run the training + run_training.remote(config) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/modal_train_prm_st.py b/modal_train_prm_st.py index b21e708..fa32572 100644 --- a/modal_train_prm_st.py +++ b/modal_train_prm_st.py @@ -52,10 +52,12 @@ def train_reward_model_upload_to_hf(): # model_revision="aed1bcf7d3d984272e329c3843f9c5fd0dfe5ca5", # base # model_revision="42e07d1b708282ac2aae338050d8116f8c69398d", # st0 # model_revision="80da7ccc4f107e0cb6bf937d61be4702badfb96b", # st1 - model_revision="4d618515c90069993f4b32e4201783efdeebbc22", # st2 - dataset_path="rawsh/mirrorqwen2.5-0.5B-gsm8k-PRM-data-ST-2", + # model_revision="4d618515c90069993f4b32e4201783efdeebbc22", # st2 + # fucked up orpo2 prm - it used st0 as base model as well. + model_revision="e49e4ca7c847194be48c42c52ad8f871da204300", # orpo2 + dataset_path="rawsh/mirrorqwen2.5-0.5B-gsm8k-PRM-data-ORPO-2", output_model_name="rawsh/mirrorqwen2.5-0.5b-prm", - disable_binning=True + disable_binning=False ) @app.local_entrypoint() diff --git a/modal_vllm.py b/modal_vllm.py index f066157..a500df9 100644 --- a/modal_vllm.py +++ b/modal_vllm.py @@ -17,24 +17,10 @@ def download_model_to_image(model_dir, model_name, model_revision): move_cache() MODEL_DIR = "/qwen" -MODEL_NAME = "rawsh/mirrorqwen2.5-0.5b-ORPO-1" -MODEL_REVISION = "a3e4731f8fb3384b07ba112a37cbcc2d4f531623" -# MODEL_DIR = "/qwen" -# MODEL_NAME = "rawsh/mirrorqwen2.5-0.5b-SFT" -# MODEL_REVISION = "1f75c1204888cc912ad0b186c5b7620235246ffa" -# # st0 -# MODEL_NAME = "rawsh/mirrorqwen2.5-0.5b-SimPO-0" -# MODEL_REVISION = "c699a3f7e82a805d6a4b158b033c5d7919230dd1" -# st1 -# MODEL_NAME = "rawsh/mirrorqwen2.5-0.5b-SimPO-1" -# MODEL_REVISION = "4ba061377ace8d0fb15802aaf943b4184420ea7d" -# st1 v2 -# MODEL_NAME = "rawsh/mirrorqwen2.5-0.5b-SimPO-1" -# MODEL_REVISION = "9e6d25903688b5678bdbe333c537a58488212024" -# MODEL_NAME = "rawsh/mirrorqwen2.5-0.5b-SimPO-2" -# MODEL_REVISION = "a41b6dd0307cf080a83cf20efc25bbf025b47852" -# MODEL_NAME = "rawsh/mirrorqwen2.5-0.5b-SimPO-3" -# MODEL_REVISION = "4bf9608e31850cf1020de695d99f0c1fb9e0575f" +MODEL_NAME = "rawsh/MetaMath-Qwen2.5-0.5b" +MODEL_REVISION = "286ca8b160074c923b89c318652ab4b979627550" +# MODEL_NAME = "rawsh/mirrorqwen2.5-0.5b-ORPO-3" +# MODEL_REVISION = "4b3e3eb18fe84477ee949058484ec951a5b8beb6" vllm_image = ( modal.Image.debian_slim(python_version="3.10") @@ -60,9 +46,7 @@ def download_model_to_image(model_dir, model_name, model_revision): .env({"VLLM_ALLOW_LONG_MAX_MODEL_LEN": "1"}) ) -# app = modal.App("vllm-qwen-simpo") -# app = 
modal.App("vllm-qwen-base") -app = modal.App("vllm-qwen-orpo") +app = modal.App("vllm-qwen-metamath") N_GPU = 1 MINUTES = 60 @@ -91,11 +75,8 @@ async def lifespan(app): @app.function( image=vllm_image, gpu=modal.gpu.A10G(count=N_GPU), - # gpu=modal.gpu.T4(), - # gpu=modal.gpu.A100(), container_idle_timeout=2 * MINUTES, timeout=20 * MINUTES, - # allow_concurrent_inputs=1000, allow_concurrent_inputs=1000, secrets=[modal.Secret.from_name("vllm-token")] ) @@ -173,6 +154,7 @@ async def setup_engine(): BaseModelPath(name=MODEL_NAME.split("/")[1], model_path=MODEL_NAME) ] + # Set up completion endpoint api_server.completion = lambda s: OpenAIServingCompletion( engine, model_config=model_config, @@ -182,4 +164,15 @@ async def setup_engine(): request_logger=request_logger, ) + # Set up chat endpoint + api_server.chat = lambda s: OpenAIServingChat( + engine, + model_config=model_config, + base_model_paths=base_model_paths, + lora_modules=[], + prompt_adapters=[], + request_logger=request_logger, + response_role="assistant" + ) + return web_app \ No newline at end of file diff --git a/modal_vllm_chat.py b/modal_vllm_chat.py new file mode 100644 index 0000000..ddb4e91 --- /dev/null +++ b/modal_vllm_chat.py @@ -0,0 +1,180 @@ +import modal +import asyncio +from contextlib import asynccontextmanager + +def download_model_to_image(model_dir, model_name, model_revision): + import os + from huggingface_hub import snapshot_download + from transformers.utils import move_cache + + os.makedirs(model_dir, exist_ok=True) + snapshot_download( + model_name, + revision=model_revision, + local_dir=model_dir, + ignore_patterns=["*.pt", "*.bin"], # Using safetensors + ) + move_cache() + +MODEL_DIR = "/qwen" +MODEL_NAME = "rawsh/MetaMath-Qwen2.5-0.5b" +MODEL_REVISION = "a1a6e9afd500586ce620efa67e278a8dd3ac575e" + +vllm_image = ( + modal.Image.debian_slim(python_version="3.10") + .pip_install( + "vllm==0.6.2", + "torch==2.4.0", + "transformers>=4.45", + "ray==2.36.0", + "hf-transfer==0.1.8", + "huggingface_hub==0.25.0", + ) + .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"}) + .run_function( + download_model_to_image, + timeout=60 * 20, + secrets=[modal.Secret.from_name("hf-token")], + kwargs={ + "model_dir": MODEL_DIR, + "model_name": MODEL_NAME, + "model_revision": MODEL_REVISION, + }, + ) + .env({"VLLM_ALLOW_LONG_MAX_MODEL_LEN": "1"}) +) + +app = modal.App("vllm-qwen-metamath") + +N_GPU = 1 +MINUTES = 60 +HOURS = 60 * MINUTES + +async def get_model_config(engine): + try: + return await engine.get_model_config() + except Exception as e: + print(f"Error getting model config: {e}") + raise + +@asynccontextmanager +async def lifespan(app): + try: + await asyncio.sleep(0) + yield + finally: + tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] + for task in tasks: + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + +@app.function( + image=vllm_image, + gpu=modal.gpu.A10G(count=N_GPU), + container_idle_timeout=2 * MINUTES, + timeout=20 * MINUTES, + allow_concurrent_inputs=1000, + secrets=[modal.Secret.from_name("vllm-token")] +) +@modal.asgi_app() +def serve(): + import os + import fastapi + import vllm.entrypoints.openai.api_server as api_server + from vllm.engine.arg_utils import AsyncEngineArgs + from vllm.engine.async_llm_engine import AsyncLLMEngine + from vllm.entrypoints.logger import RequestLogger + from vllm.entrypoints.openai.serving_chat import OpenAIServingChat + from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion + from 
vllm.entrypoints.openai.serving_engine import BaseModelPath + from vllm.usage.usage_lib import UsageContext + from transformers import AutoTokenizer + + web_app = fastapi.FastAPI( + title=f"OpenAI-compatible {MODEL_NAME} server", + description="Run an OpenAI-compatible LLM server with vLLM on modal.com", + version="0.0.1", + docs_url="/docs", + lifespan=lifespan + ) + + http_bearer = fastapi.security.HTTPBearer( + scheme_name="Bearer Token", + description="See code for authentication details.", + ) + web_app.add_middleware( + fastapi.middleware.cors.CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + TOKEN = os.environ["API_TOKEN"] + async def is_authenticated(api_key: str = fastapi.Security(http_bearer)): + if api_key.credentials != TOKEN: + raise fastapi.HTTPException( + status_code=fastapi.status.HTTP_401_UNAUTHORIZED, + detail="Invalid authentication credentials", + ) + return {"username": "authenticated_user"} + + router = fastapi.APIRouter(dependencies=[fastapi.Depends(is_authenticated)]) + router.include_router(api_server.router) + web_app.include_router(router) + + engine_args = AsyncEngineArgs( + model=MODEL_DIR, + tensor_parallel_size=N_GPU, + gpu_memory_utilization=0.90, + max_model_len=8096, + enforce_eager=False, + enable_prefix_caching=True + ) + + engine = AsyncLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.OPENAI_API_SERVER + ) + + async def setup_engine(): + model_config = await get_model_config(engine) + return model_config + + model_config = asyncio.run(setup_engine()) + request_logger = RequestLogger(max_log_len=2048) + + base_model_paths = [ + BaseModelPath(name=MODEL_NAME.split("/")[1], model_path=MODEL_NAME) + ] + + # Qwen chat template with exact formatting + TEMPLATE = """<|im_start|>system +You are Qwen, created by Alibaba Cloud. 
You are a helpful assistant.<|im_end|> +{% for message in messages %}<|im_start|>{{ message['role'] }} +{{ message['content'] }}<|im_end|> +{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant +{% endif %}""" + + # Set up completion endpoint + api_server.completion = lambda s: OpenAIServingCompletion( + engine, + model_config=model_config, + base_model_paths=base_model_paths, + lora_modules=[], + prompt_adapters=[], + request_logger=request_logger, + ) + + # Set up chat endpoint with tokenizer's chat template + api_server.chat = lambda s: OpenAIServingChat( + engine, + model_config=model_config, + base_model_paths=base_model_paths, + lora_modules=[], + prompt_adapters=[], + request_logger=request_logger, + response_role="assistant", + chat_template=TEMPLATE + ) + + return web_app \ No newline at end of file diff --git a/prm_rlhf_flow/qwen.yml b/prm_rlhf_flow/qwen.yml new file mode 100644 index 0000000..fd68abd --- /dev/null +++ b/prm_rlhf_flow/qwen.yml @@ -0,0 +1,80 @@ +# config.yml +base_model: Qwen/Qwen2.5-0.5B +# base_model: rawsh/MetaMath-Qwen2.5-0.5b +model_type: AutoModelForCausalLM +tokenizer_type: AutoTokenizer + +# HuggingFace settings +push_to_hub: true # Enable direct upload to HF +hub_model_id: "rawsh/MetaMath-Qwen2.5-0.5b-PRM" # Target repo name +hub_strategy: "every_save" # or "end", "checkpoint", "all_checkpoints" + +# Model loading settings +load_in_8bit: false +load_in_4bit: false +strict: false + +# # Dataset configuration +# chat_template: llama3 +# datasets: +# - path: RLHFlow/Mistral-PRM-Data +# type: chat_template +# split: "train" +# train_on_split: "train" +# field_messages: conversations +# message_field_role: role +# message_field_content: content + + +datasets: + - path: RLHFlow/Mistral-PRM-Data + conversation: llama3 + type: sharegpt + split: "train" + train_on_split: "train" + +# Training settings +warmup_ratio: 0.05 +val_set_size: 0.0 +output_dir: /training/prm +train_on_inputs: false + +# Weights & Biases settings +wandb_project: "preference-models" +wandb_name: "qwen2.5-0.5b-bs32_lr2e-6_prm" +# wandb_watch: false +# wandb_log_model: false + +# Model saving settings +save_safetensors: true +dataset_prepared_path: /training/data/prepared + +# Training hyperparameters +sequence_len: 8192 +sample_packing: true +pad_to_sequence_len: true +trust_remote_code: true +gradient_checkpointing: true +gradient_accumulation_steps: 4 +micro_batch_size: 1 +num_epochs: 1 +optimizer: paged_adamw_32bit +lr_scheduler: cosine +learning_rate: 2.0e-6 +weight_decay: 0.0 +max_grad_norm: 1.0 + +# Hardware settings +bf16: auto +fp16: false +tf32: true +flash_attention: true + +# Logging and checkpointing +logging_steps: 2 +save_strategy: "epoch" +save_total_limit: 4 + +# Special tokens +special_tokens: + pad_token: <|endoftext|> \ No newline at end of file
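For reference, a minimal usage sketch (not part of the patch) of the chat-completions endpoint served by modal_vllm_chat.py. The base URL and model name are the ones mcts/tree_search_chat.py points at; the API token is a placeholder for the value stored in the vllm-token Modal secret, and the prompt is illustrative.

import asyncio
from openai import AsyncOpenAI

async def demo():
    # Endpoint and model name taken from POLICY_URL / POLICY_MODEL_NAME in
    # mcts/tree_search_chat.py; the token below is a placeholder.
    client = AsyncOpenAI(
        base_url="https://rawsh--vllm-qwen-metamath-serve.modal.run/v1/",
        api_key="<API_TOKEN>",
    )
    response = await client.chat.completions.create(
        model="MetaMath-Qwen2.5-0.5b",
        messages=[{"role": "user", "content": "What is 15% of 80?"}],
        max_tokens=256,
        stop=["\n\n"],
        temperature=0.8,
    )
    print(response.choices[0].message.content)

if __name__ == "__main__":
    asyncio.run(demo())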