From de905844d1e1e069183afac43c3d77cb75c4005f Mon Sep 17 00:00:00 2001 From: Robert Washbourne Date: Tue, 12 Nov 2024 04:48:17 +0000 Subject: [PATCH] doot --- mcts/process_results.py | 174 +++++++++++++++------- mcts/simple_sample.py | 296 +++++++++++++++++++++++++++++++++++++ mcts/train_policy_orpo.py | 125 ++++++++++++++++ mcts/tree_search.py | 20 ++- modal_prm_reward.py | 7 +- modal_train_policy_orpo.py | 86 +++++++++++ modal_vllm.py | 13 +- 7 files changed, 652 insertions(+), 69 deletions(-) create mode 100644 mcts/simple_sample.py create mode 100644 mcts/train_policy_orpo.py create mode 100644 modal_train_policy_orpo.py diff --git a/mcts/process_results.py b/mcts/process_results.py index ad998a7..65d4f92 100644 --- a/mcts/process_results.py +++ b/mcts/process_results.py @@ -129,22 +129,23 @@ def analyze_question(self, question: dict) -> QuestionAnalysis: incorrect_paths=incorrect_paths, answer_distribution=answers ) - - # New version of get_paired_examples: + def get_paired_examples( self, analyses: List[QuestionAnalysis], max_pairs: int = 10000, - top_n_correct: int = 50 # New parameter + top_n_correct: int = 10, + top_n_incorrect: int = 10 ) -> List[Dict[str, Any]]: - """Get paired examples considering multiple correct paths per question""" + """Get paired examples with diverse incorrect paths for each correct path""" paired_examples = [] + seen_pairs = set() # Track unique pairs to avoid duplicates for analysis in analyses: if not analysis.correct_paths or not analysis.incorrect_paths: continue - # Sort correct paths by quality (shorter length + higher PRM score) + # Sort correct paths by quality (higher PRM score, shorter length) sorted_correct = sorted( analysis.correct_paths, key=lambda p: (-p.prm_score, p.path_length) @@ -161,39 +162,88 @@ def get_paired_examples( p for p in top_correct_paths if p.path_length <= shortest_correct_len * 1.4 ] - - # For each correct path, find the most deceptive incorrect path + + # For each correct path, select diverse incorrect paths for correct_path in filtered_correct: - # Find most deceptive incorrect path relative to this correct path - best_incorrect = max( - analysis.incorrect_paths, - key=lambda p: ( - p.prm_score, - -abs(p.path_length - correct_path.path_length) + # Calculate deceptiveness scores for all incorrect paths + scored_incorrect = [ + ( + incorrect_path, + ( + incorrect_path.prm_score, # Higher score = more deceptive + -abs(incorrect_path.path_length - correct_path.path_length), # Similar length preferred + hash(str(incorrect_path.steps)) # Use step content as tiebreaker + ) + ) + for incorrect_path in analysis.incorrect_paths + ] + + # Sort by deceptiveness score + scored_incorrect.sort(key=lambda x: x[1], reverse=True) + + # Try to select diverse incorrect paths + selected_incorrect = [] + seen_lengths = set() + seen_answers = set() + + # First pass - try to get diverse lengths and answers + for incorrect_path, _ in scored_incorrect: + pair_key = ( + hash(str(correct_path.steps)), + hash(str(incorrect_path.steps)) ) - ) + + # Skip if we've seen this exact pair + if pair_key in seen_pairs: + continue + + # Try to get diverse lengths and answers + if ( + len(selected_incorrect) < top_n_incorrect and + incorrect_path.path_length not in seen_lengths and + incorrect_path.answer not in seen_answers + ): + selected_incorrect.append(incorrect_path) + seen_pairs.add(pair_key) + seen_lengths.add(incorrect_path.path_length) + seen_answers.add(incorrect_path.answer) - paired_examples.append({ - 'question': analysis.question_text, - 'correct_answer': analysis.correct_answer, - 'metrics': { - 'sc_score': analysis.sc_score, - 'sc_correct_percent': analysis.sc_correct_percent, - 'total_paths': analysis.total_paths - }, - 'positive': { - 'steps': correct_path.steps, - 'answer': correct_path.answer, - 'prm_score': correct_path.prm_score, - 'path_length': correct_path.path_length - }, - 'negative': { - 'steps': best_incorrect.steps, - 'answer': best_incorrect.answer, - 'prm_score': best_incorrect.prm_score, - 'path_length': best_incorrect.path_length - } - }) + # Second pass - fill remaining slots if needed + if len(selected_incorrect) < top_n_incorrect: + for incorrect_path, _ in scored_incorrect: + pair_key = ( + hash(str(correct_path.steps)), + hash(str(incorrect_path.steps)) + ) + if pair_key not in seen_pairs: + selected_incorrect.append(incorrect_path) + seen_pairs.add(pair_key) + if len(selected_incorrect) >= top_n_incorrect: + break + + # Create pairs with selected incorrect paths + for incorrect_path in selected_incorrect: + paired_examples.append({ + 'question': analysis.question_text, + 'correct_answer': analysis.correct_answer, + 'metrics': { + 'sc_score': analysis.sc_score, + 'sc_correct_percent': analysis.sc_correct_percent, + 'total_paths': analysis.total_paths + }, + 'positive': { + 'steps': correct_path.steps, + 'answer': correct_path.answer, + 'prm_score': correct_path.prm_score, + 'path_length': correct_path.path_length + }, + 'negative': { + 'steps': incorrect_path.steps, + 'answer': incorrect_path.answer, + 'prm_score': incorrect_path.prm_score, + 'path_length': incorrect_path.path_length + } + }) # Sort by quality criteria including SC correct % paired_examples.sort( @@ -211,6 +261,7 @@ def get_paired_examples( def generate_prm_training_data(self, analyses: List[QuestionAnalysis]) -> List[Dict[str, Any]]: """Generate training data for Process Reward Model (PRM) from MCTS paths.""" prm_examples = [] + seen_examples = set() # Track unique (question, steps) combinations original_correct_lengths = [] original_incorrect_lengths = [] @@ -229,6 +280,19 @@ def generate_prm_training_data(self, analyses: List[QuestionAnalysis]) -> List[D for k, step in enumerate(path.steps, 1): partial_steps = path.steps[:k] + + # Create unique key based on question and step sequence + example_key = ( + hash(analysis.question_text), + hash(str(partial_steps)) + ) + + # Skip if we've seen this exact example + if example_key in seen_examples: + continue + + seen_examples.add(example_key) + m_k = K - k r_s_k = 0 w_s_k = (1 - v_prev) / (m_k + 1) * (1 - 2 * r_s_k) @@ -259,6 +323,19 @@ def generate_prm_training_data(self, analyses: List[QuestionAnalysis]) -> List[D for k, step in enumerate(path.steps, 1): partial_steps = path.steps[:k] + + # Create unique key based on question and step sequence + example_key = ( + hash(analysis.question_text), + hash(str(partial_steps)) + ) + + # Skip if we've seen this exact example + if example_key in seen_examples: + continue + + seen_examples.add(example_key) + penalize = k == K m_k = K - k if not penalize else K - k + 1 r_s_k = 0 if not penalize else 1 @@ -279,35 +356,20 @@ def generate_prm_training_data(self, analyses: List[QuestionAnalysis]) -> List[D }) v_prev = v_k - # Record length statistics - if original_correct_lengths: - print("\nOriginal Path Length Statistics:") - print(f"Correct paths mean length: {np.mean(original_correct_lengths):.1f} (±{np.std(original_correct_lengths):.1f})") - if original_incorrect_lengths: - print(f"Incorrect paths mean length: {np.mean(original_incorrect_lengths):.1f} (±{np.std(original_incorrect_lengths):.1f})") - - # Print complete path statistics - complete_correct = [ex for ex in prm_examples if ex["metadata"]["is_correct"] and ex["metadata"]["is_complete"]] - complete_incorrect = [ex for ex in prm_examples if not ex["metadata"]["is_correct"] and ex["metadata"]["is_complete"]] - - print("\nComplete Path Statistics:") - print(f"Complete correct paths: {len(complete_correct)}") - print(f"Complete incorrect paths: {len(complete_incorrect)}") - - if complete_correct: - print(f"Complete correct mean length: {np.mean([ex['metadata']['path_length'] for ex in complete_correct]):.1f}") - if complete_incorrect: - print(f"Complete incorrect mean length: {np.mean([ex['metadata']['path_length'] for ex in complete_incorrect]):.1f}") + # Print statistics about duplicates avoided + print(f"\nTotal examples generated: {len(prm_examples)}") + print(f"Unique (question, steps) combinations: {len(seen_examples)}") + print(f"Duplicates avoided: {len(seen_examples) - len(prm_examples)}") return prm_examples def main(): - # analyzer = MathReasoningAnalyzer('mcts_results.jsonl') + analyzer = MathReasoningAnalyzer('mcts_results.jsonl') # analyzer = MathReasoningAnalyzer('mcts_results.jsonl.st0.bak') # analyzer = MathReasoningAnalyzer('mcts_results.jsonl.st1.bak') # analyzer = MathReasoningAnalyzer('mcts_results.jsonl.st2-v1.bak') # analyzer = MathReasoningAnalyzer('mcts_results.jsonl.st2-v2.bak') - analyzer = MathReasoningAnalyzer('mcts_results.jsonl.st3.bak') + # analyzer = MathReasoningAnalyzer('mcts_results.jsonl.st3.bak') # Analyze all questions analyses = [] diff --git a/mcts/simple_sample.py b/mcts/simple_sample.py new file mode 100644 index 0000000..578abbc --- /dev/null +++ b/mcts/simple_sample.py @@ -0,0 +1,296 @@ +import asyncio +import aiohttp +from openai import AsyncOpenAI +import random +from datasets import load_dataset +from tqdm.asyncio import tqdm +from typing import List, Tuple, Dict +import json +from asyncio import Semaphore +from collections import Counter +from functools import wraps +from collections import OrderedDict + +# Configuration +POLICY_URL = 'https://rawsh--vllm-qwen-simpo-serve.modal.run/v1/' +PRM_URL = 'https://rawsh--mirrorqwen-prm-embedder-score-output.modal.run' +API_KEY = '9FF74944EED19865193F979942FB1' +BATCH_SIZE = 100 # Reduced batch size since we're doing multiple requests per question +MAX_RETRIES = 5 +TIMEOUT = 20 +MAX_CONCURRENT = 100 +SAMPLES_PER_QUESTION = 10 # Default to single sample mode, override with CLI arg + +# Cache decorator for PRM scores +def async_lru_cache(maxsize=2000): + cache = OrderedDict() + def decorator(func): + @wraps(func) + async def wrapper(*args, **kwargs): + key = str(args) + str(kwargs) + if key not in cache: + if len(cache) >= maxsize: + cache.popitem(last=False) + cache[key] = await func(*args, **kwargs) + return cache[key] + return wrapper + return decorator + +class BatchProgress: + def __init__(self, total_questions: int, samples_per_question: int): + self.total = total_questions + self.samples = samples_per_question + self.correct_any = 0 + self.correct_best = 0 + self.correct_sc = 0 + self.processed = 0 + self.pbar = tqdm(total=total_questions, desc=self.get_description()) + + def get_description(self) -> str: + if self.processed == 0: + return "Starting..." + + any_acc = (self.correct_any / self.processed) * 100 + if self.samples > 1: + best_acc = (self.correct_best / self.processed) * 100 + sc_acc = (self.correct_sc / self.processed) * 100 + return f"Processed: {self.processed}/{self.total} | Any: {any_acc:.1f}% | Best: {best_acc:.1f}% | SC: {sc_acc:.1f}%" + else: + return f"Processed: {self.processed}/{self.total} | Accuracy: {any_acc:.1f}%" + + def update(self, any_correct: bool, best_correct: bool = None, sc_correct: bool = None): + self.processed += 1 + if any_correct: + self.correct_any += 1 + if best_correct: + self.correct_best += 1 + if sc_correct: + self.correct_sc += 1 + self.pbar.update(1) + self.pbar.set_description(self.get_description()) + + def close(self): + self.pbar.close() + if self.processed > 0: + any_acc = (self.correct_any / self.processed) * 100 + print(f"\nFinal Results:") + print(f"Total Questions: {self.processed}") + print(f"Single Sample Accuracy: {any_acc:.2f}%") + + if self.samples > 1: + best_acc = (self.correct_best / self.processed) * 100 + sc_acc = (self.correct_sc / self.processed) * 100 + print(f"Best-of-{self.samples} Accuracy: {best_acc:.2f}%") + print(f"Self-Consistency Accuracy: {sc_acc:.2f}%") + +async def retry_with_exponential_backoff(func, *args, **kwargs): + for attempt in range(MAX_RETRIES): + try: + return await asyncio.wait_for(func(*args, **kwargs), timeout=TIMEOUT) + except Exception as e: + if attempt == MAX_RETRIES - 1: + raise + delay = min(1.5 ** attempt + random.random(), 10) + await asyncio.sleep(delay) + +@async_lru_cache(maxsize=1000) +async def get_prm_score(completion: str, session: aiohttp.ClientSession) -> float: + """Get the PRM score for a completion.""" + async with session.post(PRM_URL, json={"prompt": completion}) as response: + result = await response.json() + return float(result['score']) + +async def generate_completion( + question: str, + client: AsyncOpenAI, + semaphore: Semaphore +) -> str: + """Generate a single completion.""" + async with semaphore: + response = await client.completions.create( + model="mirrorqwen2.5-0.5b-SimPO-3", + prompt=question, + max_tokens=1500, + temperature=0.8 + ) + return response.choices[0].text.strip() + +async def evaluate_question( + question: str, + answer: str, + client: AsyncOpenAI, + session: aiohttp.ClientSession, + semaphore: Semaphore, + samples_per_question: int +) -> Dict: + """Evaluate a question with single or multiple samples.""" + try: + # Generate completions + completions = [] + for _ in range(samples_per_question): + completion = await retry_with_exponential_backoff( + generate_completion, question, client, semaphore + ) + completions.append(completion) + + # For single sample mode, return simpler result + if samples_per_question == 1: + is_correct = fr"\boxed{{{answer}}}" in completions[0] + return { + "question": question, + "expected_answer": answer, + "completion": completions[0], + "correct": is_correct + } + + # For multi-sample mode, evaluate with PRM + scores = [] + for completion in completions: + score = await retry_with_exponential_backoff( + get_prm_score, completion, session + ) + scores.append(score) + + # Evaluate correctness and extract answers + is_correct = [] + extracted_answers = [] + for completion in completions: + correct = fr"\boxed{{{answer}}}" in completion + is_correct.append(correct) + + # Extract answer for self-consistency + if r"\boxed{" in completion: + extracted = completion.split(r"\boxed{")[1].split("}")[0] + extracted_answers.append(extracted) + + # Find best completion by PRM score + best_idx = max(range(len(scores)), key=lambda i: scores[i]) + + # Calculate self-consistency + answer_counts = Counter(extracted_answers) + most_common_answer = answer_counts.most_common(1)[0][0] if answer_counts else None + is_sc_correct = most_common_answer == answer if most_common_answer else False + + return { + "question": question, + "expected_answer": answer, + "completions": [ + { + "text": compl, + "score": score, + "correct": corr + } + for compl, score, corr in zip(completions, scores, is_correct) + ], + "best_completion": { + "text": completions[best_idx], + "score": scores[best_idx], + "correct": is_correct[best_idx] + }, + "statistics": { + "any_correct": any(is_correct), + "best_correct": is_correct[best_idx], + "self_consistency_correct": is_sc_correct, + "unique_answers": len(answer_counts), + "most_common_answer": most_common_answer, + "most_common_count": answer_counts.most_common(1)[0][1] if answer_counts else 0 + } + } + + except Exception as e: + return { + "question": question, + "expected_answer": answer, + "error": str(e) + } + +async def process_batch( + batch: List[Tuple[str, str]], + client: AsyncOpenAI, + session: aiohttp.ClientSession, + progress: BatchProgress, + semaphore: Semaphore, + samples_per_question: int +) -> List[dict]: + """Process a batch of questions concurrently.""" + tasks = [] + for question, answer in batch: + tasks.append( + evaluate_question( + question, answer, client, session, semaphore, samples_per_question + ) + ) + + results = await asyncio.gather(*tasks) + + # Update progress based on mode + for result in results: + if "error" not in result: + if samples_per_question == 1: + progress.update(result["correct"]) + else: + progress.update( + result["statistics"]["any_correct"], + result["statistics"]["best_correct"], + result["statistics"]["self_consistency_correct"] + ) + + return results + +async def main(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--samples", type=int, default=10, + help="Number of samples per question (default: 1)") + parser.add_argument("--num-questions", type=int, default=200, + help="Number of questions to evaluate (default: 200)") + args = parser.parse_args() + + # Set random seed for reproducibility + random.seed(42) + + # Load and preprocess dataset + gsm8k = load_dataset("openai/gsm8k", "main", split="test").shuffle(seed=42) + questions = [(ex["question"], ex["answer"].split("\n#### ")[-1].strip()) + for ex in gsm8k] + # questions = random.sample(questions, args.num_questions) + + # Initialize API client and semaphore + client = AsyncOpenAI(base_url=POLICY_URL, api_key=API_KEY) + semaphore = Semaphore(MAX_CONCURRENT) + + # Initialize progress tracker + progress = BatchProgress(len(questions), args.samples) + + # Process in batches + all_results = [] + + # Create session only if needed (multi-sample mode) + if args.samples > 1: + async with aiohttp.ClientSession() as session: + for i in range(0, len(questions), BATCH_SIZE): + batch = questions[i:i + BATCH_SIZE] + results = await process_batch( + batch, client, session, progress, semaphore, args.samples + ) + all_results.extend(results) + else: + # Use None for session in single-sample mode + for i in range(0, len(questions), BATCH_SIZE): + batch = questions[i:i + BATCH_SIZE] + results = await process_batch( + batch, client, None, progress, semaphore, args.samples + ) + all_results.extend(results) + + # Save results + suffix = f"{args.samples}samples" if args.samples > 1 else "single" + filename = f"sampling_results_{suffix}.jsonl" + with open(filename, "w") as f: + for result in all_results: + f.write(json.dumps(result) + "\n") + + progress.close() + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/mcts/train_policy_orpo.py b/mcts/train_policy_orpo.py new file mode 100644 index 0000000..d700cac --- /dev/null +++ b/mcts/train_policy_orpo.py @@ -0,0 +1,125 @@ +from dataclasses import dataclass, field +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser +from trl import ModelConfig, ORPOConfig, ORPOTrainer, get_peft_config +import wandb + +@dataclass +class ScriptArguments: + model_name: str = field(default="Qwen/Qwen2-0.5B-Instruct") + dataset_name: str = field(default="rawsh/mirrorqwen2.5-0.5B-gsm8k-policy-data-ST-0") + dataset_train_split: str = field(default="train") + dataset_test_split: str = field(default="train") # Using train as test since original doesn't have test split + output_model_name: str = field(default=None) + hub_token: str = field(default=None) + use_peft: bool = field(default=False) + +@dataclass +class ModelArguments(ModelConfig): + model_name_or_path: str = field(default="Qwen/Qwen2-0.5B-Instruct") + trust_remote_code: bool = field(default=True) + +def train_orpo( + model_name=None, + dataset_name=None, + output_model_name=None, + hub_token=None + ): + # Initialize wandb + wandb.init(project="orpo-training") + + # Initialize base arguments + script_args = ScriptArguments() + if model_name: + script_args.model_name = model_name + if dataset_name: + script_args.dataset_name = dataset_name + if output_model_name: + script_args.output_model_name = output_model_name + if hub_token: + script_args.hub_token = hub_token + + # Set up model arguments + model_args = ModelArguments( + model_name_or_path=script_args.model_name + ) + + # Set up training configuration + training_args = ORPOConfig( + output_dir="orpo-math-model", + num_train_epochs=1, + per_device_train_batch_size=8, + gradient_accumulation_steps=8, + # learning_rate=5e-7, + learning_rate=8e-6, + lr_scheduler_type="linear", + beta=0.1, + # learning_rate=5e-6, + max_length=2048, + max_prompt_length=1024, + gradient_checkpointing=True, + push_to_hub=True, + hub_model_id=script_args.output_model_name, + hub_strategy="end", + report_to=["wandb"], + bf16=True, + tf32=True, + optim="paged_adamw_32bit", + max_grad_norm=1.0, + warmup_ratio=0.1, + # lr_scheduler_type="cosine", + do_eval=True, + evaluation_strategy="steps", + eval_steps=20, + remove_unused_columns=False, + logging_steps=10, + logging_first_step=True + ) + + # Load tokenizer + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + trust_remote_code=model_args.trust_remote_code + ) + tokenizer.pad_token = tokenizer.eos_token + + # Load model + model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, + trust_remote_code=model_args.trust_remote_code, + torch_dtype=torch.float16, + device_map="auto" + ) + model.config.use_cache = False + + # Load and process dataset + dataset = load_dataset(script_args.dataset_name, token=script_args.hub_token) + train_dataset = dataset["train"].map( + lambda examples: { + "prompt": examples["question"], + "chosen": ["\n\n".join(ex["steps"]) for ex in examples["positive"]], + "rejected": ["\n\n".join(ex["steps"]) for ex in examples["negative"]] + }, + batched=True, + remove_columns=dataset["train"].column_names + ) + + # Initialize trainer + trainer = ORPOTrainer( + model=model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=train_dataset, + processing_class=tokenizer, + peft_config=get_peft_config(model_args) if script_args.use_peft else None, + ) + + # Train the model + trainer.train() + trainer.save_model() + + wandb.finish() + +if __name__ == "__main__": + train_orpo() \ No newline at end of file diff --git a/mcts/tree_search.py b/mcts/tree_search.py index 776aea6..679fca1 100644 --- a/mcts/tree_search.py +++ b/mcts/tree_search.py @@ -14,9 +14,15 @@ # URLs and configuration # POLICY_URL = 'https://rawsh--vllm-qwen-ft-serve.modal.run/v1/' -POLICY_MODEL_NAME = 'mirrorqwen2.5-0.5b-SimPO-3' -POLICY_URL = 'https://rawsh--vllm-qwen-simpo-serve.modal.run/v1/' -PRM_URL = 'https://rawsh--mirrorqwen-prm-embedder-score-output.modal.run' +# POLICY_MODEL_NAME = 'mirrorqwen2.5-0.5b-SimPO-3' +# POLICY_MODEL_NAME = 'mirrorqwen2.5-0.5b-SimPO-0' +# POLICY_MODEL_NAME = 'mirrorqwen2.5-0.5b-SFT' +POLICY_MODEL_NAME = 'mirrorqwen2.5-0.5b-ORPO-1' +# POLICY_URL = 'https://rawsh--vllm-qwen-simpo-serve.modal.run/v1/' +# POLICY_URL = 'https://rawsh--vllm-qwen-base-serve.modal.run/v1/' +POLICY_URL = 'https://rawsh--vllm-qwen-orpo-serve.modal.run/v1/' +# PRM_URL = 'https://rawsh--mirrorqwen-prm-embedder-score-output.modal.run' +PRM_URL = 'https://rawsh--mirrorqwen-prm-st-embedder-score-output.modal.run' API_KEY = '9FF74944EED19865193F979942FB1' CONCURRENT_MCTS_SEMAPHORE = Semaphore(50) @@ -24,7 +30,7 @@ PRM_SEMAPHORE = Semaphore(1000) MAX_RETRIES = 20 # Increased from 10s -TIMEOUT = 20 # Decreased from 30 to fail faster and retry +TIMEOUT = 20 # Decreased from 30 to fail faster and retry # Cache decorator and retry function def async_lru_cache(maxsize=2000): @@ -359,9 +365,9 @@ async def main(): # Set random seed for reproducibility # random.seed(0) # eval set - all models # random.seed(42) # st 0 - # random.seed(4242) # st 1 + random.seed(4242) # st 1 # random.seed(424242) # st 2 - random.seed(42424242) # st 3 + # random.seed(42424242) # st 3 def process(example): example["answer"] = example["answer"].split("\n#### ")[-1].strip() @@ -371,10 +377,10 @@ def process(example): gsm8k = load_dataset("openai/gsm8k", "main", split="train").shuffle(seed=42) gsm8k = gsm8k.map(process, num_proc=24) initial_states = [(example["question"], example["answer"]) for example in gsm8k] - # initial_states = random.sample(initial_states, 200) # SAMPLE 200 QUESTIONS - SELF TRAINING initial_states = random.sample(initial_states, 200) + # initial_states = random.sample(initial_states, 1000) num_iterations = 100 print("cold starting policy vllm + prm api") diff --git a/modal_prm_reward.py b/modal_prm_reward.py index 75c84b3..fc15a47 100644 --- a/modal_prm_reward.py +++ b/modal_prm_reward.py @@ -9,7 +9,8 @@ "batched", ]) ) -app = modal.App("mirrorqwen-prm", image=image) +# app = modal.App("mirrorqwen-prm", image=image) +app = modal.App("mirrorqwen-prm-st", image=image) with image.imports(): from typing import List, Dict, Tuple @@ -45,11 +46,11 @@ def _process_batch(prompts: List[str]) -> List[Dict]: class Embedder: model_id = "rawsh/mirrorqwen2.5-0.5b-prm" # revision = "894341fbd81d0c1abdd98b4e0630de932aa63c6f" # base - # revision = "42e07d1b708282ac2aae338050d8116f8c69398d" # st0 + revision = "42e07d1b708282ac2aae338050d8116f8c69398d" # st0 # revision = "65f4a7601dffacc40e0ef7fa4733d346c926bd18" # st1 v1 # revision = "80da7ccc4f107e0cb6bf937d61be4702badfb96b" # st1 v2 # revision = "4d618515c90069993f4b32e4201783efdeebbc22" # st2 - revesion = "b052380b619e5c62ce9f407522362f5caf7b8346" # st3 + # revision = "b052380b619e5c62ce9f407522362f5caf7b8346" # st3 device = "cuda" print(model_id) diff --git a/modal_train_policy_orpo.py b/modal_train_policy_orpo.py new file mode 100644 index 0000000..ef310bc --- /dev/null +++ b/modal_train_policy_orpo.py @@ -0,0 +1,86 @@ +import modal +import sys +import traceback + +# Define CUDA specifications +cuda_version = "12.4.0" +flavor = "devel" +operating_sys = "ubuntu22.04" +tag = f"{cuda_version}-{flavor}-{operating_sys}" + +# Create Modal image with all necessary dependencies +image = ( + modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.11") + .apt_install("git") + .pip_install("torch") + .pip_install("transformers") + .pip_install("accelerate") + .pip_install("datasets") + .pip_install("wandb") + .pip_install("trl>=0.7.6") + .pip_install("huggingface_hub") + .pip_install("bitsandbytes") +) + +with image.imports(): + from mcts.train_policy_orpo import train_orpo # Import from our new simplified script + +# Create Modal app +app = modal.App("train-policy-orpo", image=image) + +@app.function( + cpu=4.0, + gpu=modal.gpu.H100(count=1), + timeout=24 * 60 * 60, + # memory=32768, + secrets=[ + modal.Secret.from_name("hf-token"), + modal.Secret.from_name("wandb-token") + ], +) +def train_policy_orpo(): + import os + from huggingface_hub import HfFolder + import wandb + + try: + # Set up HuggingFace token + hf_token = os.environ["HF_TOKEN"] + HfFolder.save_token(hf_token) + + # Set up Weights & Biases + wandb.login(key=os.environ["WANDB_API_KEY"]) + + # Run training with specified parameters + train_orpo( + # model_name="rawsh/mirrorqwen2.5-0.5b-SFT", + model_name="rawsh/mirrorqwen2.5-0.5b-ORPO-1", + # model_name="rawsh/mirrorqwen2.5-0.5b-SimPO-0", + # model_name="rawsh/mirrorqwen2.5-0.5b-SimPO-1", + # model_name="rawsh/mirrorqwen2.5-0.5b-SimPO-2", + # dataset_name="rawsh/mirrorqwen2.5-0.5B-gsm8k-policy-data-ST-3", + # dataset_name="rawsh/mirrorqwen2.5-0.5B-gsm8k-policy-data-ST-0", + dataset_name="rawsh/mirrorqwen2.5-0.5B-gsm8k-policy-data-ORPO-1", + output_model_name="rawsh/mirrorqwen2.5-0.5b-ORPO-2", + hub_token=hf_token + ) + except Exception as e: + print(f"Error during training: {str(e)}", file=sys.stderr) + print("Traceback:", file=sys.stderr) + traceback.print_exc(file=sys.stderr) + # Make sure to finish wandb run even on error + try: + wandb.finish() + except: + pass + raise e + +@app.local_entrypoint() +def main(): + print("Starting full model ORPO training on Modal...") + try: + train_policy_orpo.remote() + print("Training job submitted to Modal. Check W&B dashboard for training progress.") + except Exception as e: + print(f"Error in training job: {str(e)}") + sys.exit(1) \ No newline at end of file diff --git a/modal_vllm.py b/modal_vllm.py index d037e05..f066157 100644 --- a/modal_vllm.py +++ b/modal_vllm.py @@ -17,6 +17,11 @@ def download_model_to_image(model_dir, model_name, model_revision): move_cache() MODEL_DIR = "/qwen" +MODEL_NAME = "rawsh/mirrorqwen2.5-0.5b-ORPO-1" +MODEL_REVISION = "a3e4731f8fb3384b07ba112a37cbcc2d4f531623" +# MODEL_DIR = "/qwen" +# MODEL_NAME = "rawsh/mirrorqwen2.5-0.5b-SFT" +# MODEL_REVISION = "1f75c1204888cc912ad0b186c5b7620235246ffa" # # st0 # MODEL_NAME = "rawsh/mirrorqwen2.5-0.5b-SimPO-0" # MODEL_REVISION = "c699a3f7e82a805d6a4b158b033c5d7919230dd1" @@ -28,8 +33,8 @@ def download_model_to_image(model_dir, model_name, model_revision): # MODEL_REVISION = "9e6d25903688b5678bdbe333c537a58488212024" # MODEL_NAME = "rawsh/mirrorqwen2.5-0.5b-SimPO-2" # MODEL_REVISION = "a41b6dd0307cf080a83cf20efc25bbf025b47852" -MODEL_NAME = "rawsh/mirrorqwen2.5-0.5b-SimPO-3" -MODEL_REVISION = "4bf9608e31850cf1020de695d99f0c1fb9e0575f" +# MODEL_NAME = "rawsh/mirrorqwen2.5-0.5b-SimPO-3" +# MODEL_REVISION = "4bf9608e31850cf1020de695d99f0c1fb9e0575f" vllm_image = ( modal.Image.debian_slim(python_version="3.10") @@ -55,7 +60,9 @@ def download_model_to_image(model_dir, model_name, model_revision): .env({"VLLM_ALLOW_LONG_MAX_MODEL_LEN": "1"}) ) -app = modal.App("vllm-qwen-simpo") +# app = modal.App("vllm-qwen-simpo") +# app = modal.App("vllm-qwen-base") +app = modal.App("vllm-qwen-orpo") N_GPU = 1 MINUTES = 60