diff --git a/mcts/generate.py b/mcts/generate.py index ed12c10..6e82bf7 100644 --- a/mcts/generate.py +++ b/mcts/generate.py @@ -7,15 +7,18 @@ from tqdm import tqdm client = AsyncOpenAI( - api_key="9FF74944EED19865193F979942FB1", - base_url="https://rawsh--vllm-qwen-serve.modal.run/v1" + api_key="9FF74944EED19865193F979942FB1", + base_url="https://rawsh--vllm-smollm-serve.modal.run/v1" ) def format_thoughts(thoughts: List[str]) -> str: return "\n".join(f"## Step {i}:\n{thought}" for i, thought in enumerate(thoughts, 1)) -template = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n\ -<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n{assistant_partial}" +# template = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n\ +# <|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n{assistant_partial}" + +template = "<|im_start|>system\nYou are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>\n\ +<|im_start|>human\n{user}<|im_end|>\n<|im_start|>assistant\n{assistant_partial}" class ReasoningTrace: def __init__(self, question: str, previous_thoughts: List[str], next_step: int): @@ -36,10 +39,12 @@ async def generate_thought_batched(batch: List[ReasoningTrace]) -> List[Processe prompts.append(prompt) params = { - "model": "Qwen/Qwen2.5-0.5B-Instruct", + # "model": "Qwen/Qwen2.5-0.5B-Instruct", + "model": "HuggingFaceTB/SmolLM2-135M-Instruct", "prompt": prompts, "max_tokens": 200, - "temperature": 0.7, + # "temperature": 0.7, + "temperature": 0.0, "stop": ["\n## Step"], "timeout": 600 } @@ -58,7 +63,7 @@ async def generate_thought_batched(batch: List[ReasoningTrace]) -> List[Processe return None async def format_thought_chain(question: str, chain: List[str]) -> List[ReasoningTrace]: - return [ReasoningTrace(question, chain[:i], i+1) for i in range(1, len(chain))] + return [ReasoningTrace(question, chain[:i], i+1) for i in range(0, len(chain))] async def process_batch(batch: List[ReasoningTrace], semaphore: asyncio.Semaphore) -> List[ProcessedReasoningTrace]: async with semaphore: diff --git a/mcts/process_results.py b/mcts/process_results.py new file mode 100644 index 0000000..802faff --- /dev/null +++ b/mcts/process_results.py @@ -0,0 +1,492 @@ +import json +from collections import defaultdict, Counter +from dataclasses import dataclass +from typing import Dict, List, Set, Tuple, Optional, Any +from pathlib import Path +import numpy as np + +# Changes to PathMetrics class definition: +@dataclass +class PathMetrics: + """Metrics for a single reasoning path""" + answer: str + is_correct: bool + path_length: int + prm_score: float + raw_steps: List[str] # Store original steps + steps: List[str] # Store preprocessed steps + +@dataclass +class QuestionAnalysis: + """Analysis results for a single question""" + question_text: str + correct_answer: str + binary_success: bool # Any correct path?
+ sc_score: float # Overall self-consistency score + sc_correct_percent: float # % of self-consistent answers that are correct + total_paths: int + correct_paths: List[PathMetrics] + incorrect_paths: List[PathMetrics] + answer_distribution: Counter + +class MathReasoningAnalyzer: + def __init__(self, path_data: str): + self.data_path = Path(path_data) + self.questions = self._load_data() + + def _load_data(self) -> List[dict]: + """Load and parse JSONL data""" + questions = [] + with open(self.data_path) as f: + for line in f: + if line.strip(): + questions.append(json.loads(line)) + return questions + + def _extract_steps(self, final_state: str) -> Tuple[List[str], List[str]]: + """Extract reasoning steps from final state and return both raw and processed steps""" + raw_steps = [step.strip() for step in final_state.split('\n\n')] + raw_steps = [step for step in raw_steps if step] + + if len(raw_steps) > 2: + # Remove first step (question repeat) + processed_steps = raw_steps[1:] + # Concatenate last formatting step to the previous step + if len(processed_steps) > 1: + processed_steps[-2] = processed_steps[-2] + "\n\n" + processed_steps[-1] + processed_steps = processed_steps[:-1] + else: + processed_steps = [] + + return raw_steps, processed_steps + + def _extract_answer(self, final_state: str) -> str: + """Extract final answer from the final state""" + if '\\boxed{' in final_state: + return final_state.split('\\boxed{')[1].split('}')[0] + return '' + + # Changes to analyze_question method: + def analyze_question(self, question: dict) -> QuestionAnalysis: + """Analyze a single question's reasoning paths""" + paths = [] + answers = Counter() + + # Process each terminal path + for path in question['terminal_paths']: + raw_steps, processed_steps = self._extract_steps(path['final_state']) + answer = self._extract_answer(path['final_state']) + + path_metrics = PathMetrics( + answer=answer, + is_correct=path['correct'], + path_length=len(processed_steps), # Use processed steps length + prm_score=path['score'], + raw_steps=raw_steps, + steps=processed_steps + ) + paths.append(path_metrics) + answers[answer] += 1 + + # Split paths + correct_paths = [p for p in paths if p.is_correct] + incorrect_paths = [p for p in paths if not p.is_correct] + + # Calculate self-consistency and SC correct % + total_paths = len(paths) + if total_paths > 0: + # Overall self-consistency + most_common = answers.most_common() + most_common_count = most_common[0][1] if most_common else 0 + sc_score = most_common_count / total_paths + + # Calculate % of self-consistent answers that are correct + sc_answers = [ans for ans, count in most_common + if count > total_paths * 0.2] # Consider answers that appear >20% of time + sc_correct = sum(1 for ans in sc_answers + if any(p.answer == ans and p.is_correct + for p in correct_paths)) + sc_correct_percent = sc_correct / len(sc_answers) if sc_answers else 0 + else: + sc_score = 0 + sc_correct_percent = 0 + + return QuestionAnalysis( + question_text=question['question'], + correct_answer=question['correct_answer'], + binary_success=bool(correct_paths), + sc_score=sc_score, + sc_correct_percent=sc_correct_percent, + total_paths=total_paths, + correct_paths=correct_paths, + incorrect_paths=incorrect_paths, + answer_distribution=answers + ) + + def get_paired_examples( + self, + analyses: List[QuestionAnalysis], + max_pairs: int = 10000 + ) -> List[Dict[str, Any]]: + """Get paired positive/negative examples for each question""" + paired_examples = [] + + for analysis in analyses: + 
if not analysis.correct_paths or not analysis.incorrect_paths: + continue + + # Find best correct path + shortest_correct = min(analysis.correct_paths, key=lambda p: p.path_length) + best_correct = max( + [p for p in analysis.correct_paths + if p.path_length <= shortest_correct.path_length * 1.2], + key=lambda p: p.prm_score + ) + + # Find most deceptive incorrect path + best_incorrect = max( + analysis.incorrect_paths, + key=lambda p: ( + p.prm_score, + -abs(p.path_length - best_correct.path_length) + ) + ) + + paired_examples.append({ + 'question': analysis.question_text, + 'correct_answer': analysis.correct_answer, + 'metrics': { + 'sc_score': analysis.sc_score, + 'sc_correct_percent': analysis.sc_correct_percent, + 'total_paths': analysis.total_paths, + 'answer_distribution': dict(analysis.answer_distribution) + }, + 'positive': { + 'steps': best_correct.steps, + 'answer': best_correct.answer, + 'prm_score': best_correct.prm_score, + 'path_length': best_correct.path_length + }, + 'negative': { + 'steps': best_incorrect.steps, + 'answer': best_incorrect.answer, + 'prm_score': best_incorrect.prm_score, + 'path_length': best_incorrect.path_length + } + }) + + # Sort by quality criteria including SC correct % + paired_examples.sort( + key=lambda x: ( + x['metrics']['sc_correct_percent'], # Higher correct % in SC answers + x['metrics']['sc_score'], # Higher overall SC + x['positive']['prm_score'], # Higher positive score + x['negative']['prm_score'], # Higher negative score (more deceptive) + ), + reverse=True + ) + + return paired_examples[:max_pairs] + + def generate_prm_training_data(self, analyses: List[QuestionAnalysis]) -> List[Dict[str, Any]]: + """Generate training data for Process Reward Model (PRM) from MCTS paths.""" + prm_examples = [] + original_correct_lengths = [] + original_incorrect_lengths = [] + + for analysis in analyses: + if analysis.sc_score < 0.6: + continue + + # Process correct paths + for path in analysis.correct_paths: + if not path.steps: # Skip if no steps after preprocessing + continue + + original_correct_lengths.append(len(path.steps)) + K = len(path.steps) + v_prev = 0 + + for k, step in enumerate(path.steps, 1): + partial_steps = path.steps[:k] + m_k = K - k + r_s_k = 0 + w_s_k = (1 - v_prev) / (m_k + 1) * (1 - 2 * r_s_k) + v_k = max(v_prev + w_s_k, 0) + + prm_examples.append({ + "question": analysis.question_text, + "steps": partial_steps, + "final_step_reward": float(v_k), + "metadata": { + "is_complete": k == K, + "is_correct": True, + "path_length": K, + "step_number": k, + "raw_path_length": len(path.raw_steps) + } + }) + v_prev = v_k + + # Process incorrect paths + for path in analysis.incorrect_paths: + if not path.steps: # Skip if no steps after preprocessing + continue + + original_incorrect_lengths.append(len(path.steps)) + K = len(path.steps) + v_prev = 0 + + for k, step in enumerate(path.steps, 1): + partial_steps = path.steps[:k] + penalize = k == K + m_k = K - k if not penalize else K - k + 1 + r_s_k = 0 if not penalize else 1 + w_s_k = (1 - v_prev) / (m_k + 1) * (1 - 2 * r_s_k) + v_k = max(v_prev + w_s_k, 0) + + prm_examples.append({ + "question": analysis.question_text, + "steps": partial_steps, + "final_step_reward": float(v_k), + "metadata": { + "is_complete": k == K, + "is_correct": False, + "path_length": K, + "step_number": k, + "raw_path_length": len(path.raw_steps) + } + }) + v_prev = v_k + + # Record length statistics + if original_correct_lengths: + print("\nOriginal Path Length Statistics:") + print(f"Correct paths mean length: 
{np.mean(original_correct_lengths):.1f} (±{np.std(original_correct_lengths):.1f})") + if original_incorrect_lengths: + print(f"Incorrect paths mean length: {np.mean(original_incorrect_lengths):.1f} (±{np.std(original_incorrect_lengths):.1f})") + + # Print complete path statistics + complete_correct = [ex for ex in prm_examples if ex["metadata"]["is_correct"] and ex["metadata"]["is_complete"]] + complete_incorrect = [ex for ex in prm_examples if not ex["metadata"]["is_correct"] and ex["metadata"]["is_complete"]] + + print("\nComplete Path Statistics:") + print(f"Complete correct paths: {len(complete_correct)}") + print(f"Complete incorrect paths: {len(complete_incorrect)}") + + if complete_correct: + print(f"Complete correct mean length: {np.mean([ex['metadata']['path_length'] for ex in complete_correct]):.1f}") + if complete_incorrect: + print(f"Complete incorrect mean length: {np.mean([ex['metadata']['path_length'] for ex in complete_incorrect]):.1f}") + + return prm_examples + +def main(): + analyzer = MathReasoningAnalyzer('mcts_results.jsonl') + + # Analyze all questions + analyses = [] + for question in analyzer.questions: + analysis = analyzer.analyze_question(question) + analyses.append(analysis) + + # Calculate overall statistics + total = len(analyses) + binary_success = sum(1 for a in analyses if a.binary_success) + avg_sc = np.mean([a.sc_score for a in analyses]) + avg_sc_correct = np.mean([a.sc_correct_percent for a in analyses]) + + # Terminal path statistics + total_paths = [a.total_paths for a in analyses] + correct_paths = [len(a.correct_paths) for a in analyses] + incorrect_paths = [len(a.incorrect_paths) for a in analyses] + + # Path length statistics + all_correct_lengths = [p.path_length for a in analyses for p in a.correct_paths] + all_incorrect_lengths = [p.path_length for a in analyses for p in a.incorrect_paths] + + # PRM score statistics + all_correct_scores = [p.prm_score for a in analyses for p in a.correct_paths] + all_incorrect_scores = [p.prm_score for a in analyses for p in a.incorrect_paths] + + # Best path analysis + best_paths_correct = 0 + total_questions = len(analyses) + + for question in analyzer.questions: + # Get highest scoring path + best_path = max(question['terminal_paths'], key=lambda x: x['score']) + if best_path['correct']: + best_paths_correct += 1 + + best_path_accuracy = (best_paths_correct / total_questions) * 100 + + print("\nBest Path Analysis:") + print(f"Questions where highest scoring path was correct: {best_paths_correct} ({best_path_accuracy:.1f}%)") + + print("\nOverall Statistics:") + print(f"Total questions analyzed: {total}") + print(f"Questions with at least one correct path: {binary_success} ({binary_success/total*100:.1f}%)") + print(f"Average self-consistency score: {avg_sc:.3f}") + print(f"Average % of self-consistent answers that are correct: {avg_sc_correct*100:.1f}%") + + print("\nTerminal Path Statistics:") + print(f"Average total paths per question: {np.mean(total_paths):.1f} (±{np.std(total_paths):.1f})") + print(f"Average correct paths per question: {np.mean(correct_paths):.1f} (±{np.std(correct_paths):.1f})") + print(f"Average incorrect paths per question: {np.mean(incorrect_paths):.1f} (±{np.std(incorrect_paths):.1f})") + + print("\nPath Length Statistics:") + if all_correct_lengths: + print(f"Average correct path length: {np.mean(all_correct_lengths):.1f} (±{np.std(all_correct_lengths):.1f})") + if all_incorrect_lengths: + print(f"Average incorrect path length: {np.mean(all_incorrect_lengths):.1f} 
(±{np.std(all_incorrect_lengths):.1f})") + + print("\nPRM Score Statistics:") + if all_correct_scores: + print(f"Average correct path PRM score: {np.mean(all_correct_scores):.3f} (±{np.std(all_correct_scores):.3f})") + if all_incorrect_scores: + print(f"Average incorrect path PRM score: {np.mean(all_incorrect_scores):.3f} (±{np.std(all_incorrect_scores):.3f})") + + # Distribution of number of paths + path_counts = Counter(total_paths) + print("\nPath Count Distribution:") + for count in sorted(path_counts.keys()): + questions = path_counts[count] + print(f"{count} paths: {questions} questions ({questions/total*100:.1f}%)") + + print("\nSelf-Consistency Breakdown:") + sc_thresholds = [0.2, 0.4, 0.6, 0.8] + for threshold in sc_thresholds: + questions_above = sum(1 for a in analyses if a.sc_score >= threshold) + correct_above = sum(1 for a in analyses + if a.sc_score >= threshold and a.sc_correct_percent > 0) + print(f"Questions with SC >= {threshold:.1f}: {questions_above} " + f"({questions_above/total*100:.1f}%) - " + f"Correct: {correct_above} ({correct_above/questions_above*100:.1f}% of SC)") + + + should_generate=False + if should_generate: + # Generate both preference pairs and PRM training data + paired_examples = analyzer.get_paired_examples(analyses) + prm_training_data = analyzer.generate_prm_training_data(analyses) + + print(f"\nSelected {len(paired_examples)} paired examples") + + # Statistics on selected pairs + if paired_examples: + print("\nSelected Pairs Statistics:") + pair_pos_prm = [ex['positive']['prm_score'] for ex in paired_examples] + pair_neg_prm = [ex['negative']['prm_score'] for ex in paired_examples] + pair_pos_len = [ex['positive']['path_length'] for ex in paired_examples] + pair_neg_len = [ex['negative']['path_length'] for ex in paired_examples] + pair_sc = [ex['metrics']['sc_score'] for ex in paired_examples] + pair_sc_correct = [ex['metrics']['sc_correct_percent'] for ex in paired_examples] + + print("\nPaired Examples Metrics:") + print(f"Average positive path length: {np.mean(pair_pos_len):.1f} (±{np.std(pair_pos_len):.1f})") + print(f"Average negative path length: {np.mean(pair_neg_len):.1f} (±{np.std(pair_neg_len):.1f})") + print(f"Average positive PRM score: {np.mean(pair_pos_prm):.3f} (±{np.std(pair_pos_prm):.3f})") + print(f"Average negative PRM score: {np.mean(pair_neg_prm):.3f} (±{np.std(pair_neg_prm):.3f})") + print(f"Average self-consistency: {np.mean(pair_sc):.3f} (±{np.std(pair_sc):.3f})") + print(f"Average % correct in SC: {np.mean(pair_sc_correct)*100:.1f}% (±{np.std(pair_sc_correct)*100:.1f}%)") + + # In main(), replace the PRM Training Data Statistics section with: + + # Print PRM Training Data Statistics + print("\nPRM Training Data Statistics:") + correct_examples = [ex for ex in prm_training_data if ex["metadata"]["is_correct"]] + incorrect_examples = [ex for ex in prm_training_data if not ex["metadata"]["is_correct"]] + + print(f"Total training examples: {len(prm_training_data)}") + print(f"Correct examples: {len(correct_examples)}") + print(f"Incorrect examples: {len(incorrect_examples)}") + + print("\nCorrect Examples Statistics:") + if correct_examples: + complete_correct = [ex for ex in correct_examples if ex["metadata"]["is_complete"]] + print(f"Complete paths: {len(complete_correct)}") + print(f"Average steps: {np.mean([len(ex['steps']) for ex in correct_examples]):.1f}") + print(f"Average reward: {np.mean([ex['final_step_reward'] for ex in correct_examples]):.3f}") + else: + print("No correct examples found") + + print("\nIncorrect 
Examples Statistics:") + if incorrect_examples: + complete_incorrect = [ex for ex in incorrect_examples if ex["metadata"]["is_complete"]] + print(f"Complete paths: {len(complete_incorrect)}") + print(f"Average steps: {np.mean([len(ex['steps']) for ex in incorrect_examples]):.1f}") + print(f"Average reward: {np.mean([ex['final_step_reward'] for ex in incorrect_examples]):.3f}") + else: + print("No incorrect examples found") + + # Add path length distribution + print("\nPath Length Distribution:") + correct_lengths = [len(ex['steps']) for ex in correct_examples] + incorrect_lengths = [len(ex['steps']) for ex in incorrect_examples] + + if correct_lengths: + correct_dist = Counter(correct_lengths) + print("\nCorrect path lengths:") + for length in sorted(correct_dist.keys()): + count = correct_dist[length] + percent = (count / len(correct_lengths)) * 100 + print(f"{length} steps: {count} examples ({percent:.1f}%)") + + if incorrect_lengths: + incorrect_dist = Counter(incorrect_lengths) + print("\nIncorrect path lengths:") + for length in sorted(incorrect_dist.keys()): + count = incorrect_dist[length] + percent = (count / len(incorrect_lengths)) * 100 + print(f"{length} steps: {count} examples ({percent:.1f}%)") + + # Add reward distribution + print("\nReward Distribution:") + if correct_examples: + correct_rewards = [ex['final_step_reward'] for ex in correct_examples] + print(f"\nCorrect rewards: min={min(correct_rewards):.3f}, " + f"mean={np.mean(correct_rewards):.3f}, " + f"max={max(correct_rewards):.3f}") + + if incorrect_examples: + incorrect_rewards = [ex['final_step_reward'] for ex in incorrect_examples] + print(f"Incorrect rewards: min={min(incorrect_rewards):.3f}, " + f"mean={np.mean(incorrect_rewards):.3f}, " + f"max={max(incorrect_rewards):.3f}") + + # Save both datasets + with open('paired_examples.json', 'w') as f: + json.dump({ + 'summary_stats': { + 'total_questions': total, + 'questions_with_correct': binary_success, + 'avg_self_consistency': float(avg_sc), + 'avg_sc_correct_percent': float(avg_sc_correct), + 'path_stats': { + 'avg_total_paths': float(np.mean(total_paths)), + 'avg_correct_paths': float(np.mean(correct_paths)), + 'avg_incorrect_paths': float(np.mean(incorrect_paths)), + 'avg_correct_length': float(np.mean(all_correct_lengths)) if all_correct_lengths else None, + 'avg_incorrect_length': float(np.mean(all_incorrect_lengths)) if all_incorrect_lengths else None, + 'avg_correct_prm': float(np.mean(all_correct_scores)) if all_correct_scores else None, + 'avg_incorrect_prm': float(np.mean(all_incorrect_scores)) if all_incorrect_scores else None + }, + 'selected_pairs': len(paired_examples) + }, + 'paired_examples': paired_examples + }, f, indent=2) + + # Save PRM training data in JSONL format + with open('prm_training.jsonl', 'w') as f: + for example in prm_training_data: + json.dump(example, f) + f.write('\n') + + print("\nOutput files written:") + print("- paired_examples.json") + print("- prm_training.jsonl") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/mcts/train_policy_sft.py b/mcts/train_policy_sft.py index b1e23d2..7b725c0 100644 --- a/mcts/train_policy_sft.py +++ b/mcts/train_policy_sft.py @@ -46,7 +46,8 @@ def train_sft(): load_in_4bit = False # Use 4bit quantization to reduce memory usage. Can be False. 
model, tokenizer = FastLanguageModel.from_pretrained( - model_name = "unsloth/gemma-2-2b", + # model_name = "unsloth/gemma-2-2b", + model_name = "Qwen/Qwen2.5-0.5B", max_seq_length = max_seq_length, dtype = dtype, load_in_4bit = load_in_4bit, @@ -129,4 +130,5 @@ def formatting_prompts_func(examples): trainer_stats = trainer.train() - model.push_to_hub_merged("rawsh/mirrorgemma-2-2b-SFT", tokenizer, save_method = "merged_16bit") \ No newline at end of file + # model.push_to_hub_merged("rawsh/mirrorgemma-2-2b-SFT", tokenizer, save_method = "merged_16bit") + model.push_to_hub_merged("rawsh/mirrorqwen2.5-0.5b-SFT", tokenizer, save_method = "merged_16bit") \ No newline at end of file diff --git a/mcts/train_reward.py b/mcts/train_reward.py index 4f19d4c..0062cfe 100644 --- a/mcts/train_reward.py +++ b/mcts/train_reward.py @@ -1,6 +1,5 @@ from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Union - import numpy as np import torch import torch.nn as nn @@ -12,83 +11,45 @@ PreTrainedTokenizerBase, Trainer, TrainingArguments, + EarlyStoppingCallback, ) from transformers.utils import PaddingStrategy - import random from collections import Counter -# Define and parse arguments. @dataclass class ScriptArguments: - local_rank: Optional[int] = field( - default=-1, metadata={"help": "Used for multi-gpu"} - ) - deepspeed: Optional[str] = field( - default=None, - metadata={ - "help": "Path to deepspeed config if using deepspeed. You may need this if the model that you want to train doesn't fit on a single GPU." - }, - ) - per_device_train_batch_size: Optional[int] = field(default=4) + local_rank: Optional[int] = field(default=-1, metadata={"help": "Used for multi-gpu"}) + deepspeed: Optional[str] = field(default=None) + per_device_train_batch_size: Optional[int] = field(default=8) per_device_eval_batch_size: Optional[int] = field(default=4) gradient_accumulation_steps: Optional[int] = field(default=32) - # learning_rate: Optional[float] = field(default=2e-6) - # embedding_learning_rate: Optional[float] = field(default=1e-6) - learning_rate: Optional[float] = field(default=1e-5) - weight_decay: Optional[float] = field(default=0.001) - model_name: Optional[str] = field( - default="google/gemma-2-2b", - metadata={ - "help": "The model that you want to train from the Hugging Face hub. E.g. gpt2, gpt2-xl, bert, etc." - }, - ) - bf16: Optional[bool] = field( - default=True, - metadata={ - "help": "This essentially cuts the training time in half if you want to sacrifice a little precision and have a supported GPU." 
- }, - ) - num_train_epochs: Optional[int] = field( - default=1, - # default=3, - metadata={"help": "The number of training epochs for the reward model."}, - ) - train_set_path: Optional[str] = field( - default="rawsh/magpie-ultra-v0.1-PRM-data-base", - metadata={"help": "The dir of the subset of the training data to use"}, - ) - eval_set_path: Optional[str] = field( - default="rawsh/magpie-ultra-v0.1-PRM-data-base", - metadata={"help": "The dir of the subset of the eval data to use"}, - ) - output_path: Optional[str] = field( - default="./mirrorgemma-2-2b-prm-base", - metadata={"help": "The dir for output model"}, - ) - gradient_checkpointing: Optional[bool] = field( - default=True, - metadata={"help": "Enables gradient checkpointing."}, - ) - optim: Optional[str] = field( - default="adamw_torch_fused", - metadata={"help": "The optimizer to use."}, - ) - lr_scheduler_type: Optional[str] = field( - default="cosine", - metadata={"help": "The lr scheduler"}, - ) + learning_rate: Optional[float] = field(default=8e-6) + weight_decay: Optional[float] = field(default=0.0001) + model_name: Optional[str] = field(default="Qwen/Qwen2.5-0.5B") + bf16: Optional[bool] = field(default=True) + num_train_epochs: Optional[int] = field(default=2) + train_set_path: Optional[str] = field(default="rawsh/magpie-ultra-v0.1-PRM-data-base") + eval_set_path: Optional[str] = field(default="rawsh/magpie-ultra-v0.1-PRM-data-base") + output_path: Optional[str] = field(default="./mirrorqwen2.5-0.5b-prm-base") + output_model_name: Optional[str] = field(default="rawsh/mirrorqwen2.5-0.5b-PRM") + gradient_checkpointing: Optional[bool] = field(default=True) + optim: Optional[str] = field(default="adamw_torch_fused") + lr_scheduler_type: Optional[str] = field(default="cosine") max_length: Optional[int] = field(default=8192) - save_every_steps: Optional[int] = field( - default=999999, - metadata={"help": "Save the model every x steps"}, - ) - eval_every_steps: Optional[int] = field( - default=999999, - metadata={"help": "Eval the model every x steps"}, - ) + save_every_steps: Optional[int] = field(default=999999) + eval_every_steps: Optional[int] = field(default=999999) + early_stopping_patience: Optional[int] = field(default=3) + early_stopping_threshold: Optional[float] = field(default=0.001) + disable_binning: Optional[bool] = field(default=False) + # Add new parameters for improved checkpointing + warmup_steps: Optional[int] = field(default=100) + save_total_limit: Optional[int] = field(default=3) + min_loss_threshold: Optional[float] = field(default=0.1) -def build_dataset(tokenizer, train_path, eval_path): + + +def build_dataset(tokenizer, train_path, eval_path, disable_binning: bool): def tokenize(sample): question = sample['question'] steps = sample['steps'] @@ -112,85 +73,91 @@ def tokenize(sample): ds_train = load_dataset(train_path, split="train").shuffle(seed=42) ds_train = ds_train.map(tokenize, num_proc=24) - # Step 2: Assign bin number to each sample in training data - def assign_bin(example): - final_step_reward = example['final_step_reward'] - # Calculate bin number (bins: 0.0-0.1 => bin 0, ..., 0.9-1.0 => bin 9) - bin_number = int(final_step_reward * 10) - if bin_number == 10: - bin_number = 9 # Handle the edge case where final_step_reward == 1.0 - example['bin'] = bin_number - return example - - ds_train = ds_train.map(assign_bin, num_proc=24) - - # Step 3: Get counts of samples in each bin for training data - bin_counts_train = Counter(ds_train['bin']) - print("Training bin counts before undersampling:", 
bin_counts_train) - - # Determine the minimum count across all bins in training data - min_count_train = min(bin_counts_train.values()) - print("Training minimum count per bin:", min_count_train) - - # Step 4: Create a mapping from bin to indices for training data - bin_to_indices_train = {i: [] for i in range(10)} # Bins 0 to 9 - for idx, bin_number in enumerate(ds_train['bin']): - bin_to_indices_train[bin_number].append(idx) - - # Randomly sample min_count_train indices per bin for training data - random.seed(42) - selected_indices_train = [] - for bin_number, indices in bin_to_indices_train.items(): - if len(indices) >= min_count_train: - sampled_indices = random.sample(indices, min_count_train) - else: - sampled_indices = indices # Keep all samples if less than min_count_train - selected_indices_train.extend(sampled_indices) - - # Shuffle the selected indices to mix the data - random.shuffle(selected_indices_train) - - # Step 5: Create the balanced training dataset - train_dataset = ds_train.select(selected_indices_train) - print("Total training samples after undersampling:", len(train_dataset)) + if not disable_binning: + # Step 2: Assign bin number to each sample in training data + def assign_bin(example): + final_step_reward = example['final_step_reward'] + # Calculate bin number (bins: 0.0-0.1 => bin 0, ..., 0.9-1.0 => bin 9) + bin_number = int(final_step_reward * 10) + if bin_number == 10: + bin_number = 9 # Handle the edge case where final_step_reward == 1.0 + example['bin'] = bin_number + return example + + ds_train = ds_train.map(assign_bin, num_proc=24) + + # Step 3: Get counts of samples in each bin for training data + bin_counts_train = Counter(ds_train['bin']) + print("Training bin counts before undersampling:", bin_counts_train) + + # Determine the minimum count across all bins in training data + min_count_train = min(bin_counts_train.values()) + print("Training minimum count per bin:", min_count_train) + + # Step 4: Create a mapping from bin to indices for training data + bin_to_indices_train = {i: [] for i in range(10)} # Bins 0 to 9 + for idx, bin_number in enumerate(ds_train['bin']): + bin_to_indices_train[bin_number].append(idx) + + # Randomly sample min_count_train indices per bin for training data + random.seed(42) + selected_indices_train = [] + for bin_number, indices in bin_to_indices_train.items(): + if len(indices) >= min_count_train: + sampled_indices = random.sample(indices, min_count_train) + else: + sampled_indices = indices # Keep all samples if less than min_count_train + selected_indices_train.extend(sampled_indices) + + # Shuffle the selected indices to mix the data + random.shuffle(selected_indices_train) + + # Step 5: Create the balanced training dataset + train_dataset = ds_train.select(selected_indices_train) + print("Total training samples after undersampling:", len(train_dataset)) + else: + train_dataset = ds_train # Now, build the evaluation dataset # Load and shuffle the evaluation dataset ds_eval = load_dataset(eval_path, split="train").shuffle(seed=42) ds_eval = ds_eval.map(tokenize, num_proc=24) - # Assign bins to evaluation data - ds_eval = ds_eval.map(assign_bin, num_proc=24) - - # Get counts of samples in each bin for evaluation data - bin_counts_eval = Counter(ds_eval['bin']) - print("Evaluation bin counts before undersampling:", bin_counts_eval) - - # Determine the minimum count per bin for evaluation data - # Set it to be 10% of min_count_train, at least 1 - eval_min_count_per_bin = max(1, int(min_count_train * 0.1)) - print("Evaluation 
minimum count per bin:", eval_min_count_per_bin) - - # Create a mapping from bin to indices for evaluation data - bin_to_indices_eval = {i: [] for i in range(10)} # Bins 0 to 9 - for idx, bin_number in enumerate(ds_eval['bin']): - bin_to_indices_eval[bin_number].append(idx) - - # Randomly sample eval_min_count_per_bin indices per bin for evaluation data - selected_indices_eval = [] - for bin_number, indices in bin_to_indices_eval.items(): - if len(indices) >= eval_min_count_per_bin: - sampled_indices = random.sample(indices, eval_min_count_per_bin) - else: - sampled_indices = indices # Keep all samples if less than eval_min_count_per_bin - selected_indices_eval.extend(sampled_indices) - - # Shuffle the selected indices to mix the data - random.shuffle(selected_indices_eval) - - # Create the balanced evaluation dataset - eval_dataset = ds_eval.select(selected_indices_eval) - print("Total evaluation samples after undersampling:", len(eval_dataset)) + if not disable_binning: + # Assign bins to evaluation data + ds_eval = ds_eval.map(assign_bin, num_proc=24) + + # Get counts of samples in each bin for evaluation data + bin_counts_eval = Counter(ds_eval['bin']) + print("Evaluation bin counts before undersampling:", bin_counts_eval) + + # Determine the minimum count per bin for evaluation data + # Set it to be 10% of min_count_train, at least 1 + eval_min_count_per_bin = max(1, int(min_count_train * 0.1)) + print("Evaluation minimum count per bin:", eval_min_count_per_bin) + + # Create a mapping from bin to indices for evaluation data + bin_to_indices_eval = {i: [] for i in range(10)} # Bins 0 to 9 + for idx, bin_number in enumerate(ds_eval['bin']): + bin_to_indices_eval[bin_number].append(idx) + + # Randomly sample eval_min_count_per_bin indices per bin for evaluation data + selected_indices_eval = [] + for bin_number, indices in bin_to_indices_eval.items(): + if len(indices) >= eval_min_count_per_bin: + sampled_indices = random.sample(indices, eval_min_count_per_bin) + else: + sampled_indices = indices # Keep all samples if less than eval_min_count_per_bin + selected_indices_eval.extend(sampled_indices) + + # Shuffle the selected indices to mix the data + random.shuffle(selected_indices_eval) + + # Create the balanced evaluation dataset + eval_dataset = ds_eval.select(selected_indices_eval) + print("Total evaluation samples after undersampling:", len(eval_dataset)) + else: + eval_dataset = ds_eval return train_dataset, eval_dataset @@ -226,12 +193,21 @@ def compute_metrics(eval_pred): predictions = eval_pred.predictions.squeeze() labels = eval_pred.label_ids mse = np.mean((predictions - labels) ** 2) - return {"mse": mse} + return { + "mse": mse, + "mse_moving_avg": mse # Just use MSE directly since we're serverless + } class RewardTrainer(Trainer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.warmup_steps = self.args.warmup_steps + self.current_step = 0 + def compute_loss(self, model, inputs, return_outputs=False): rewards = model( - input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"] + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"] )[0].squeeze() loss = nn.functional.mse_loss(rewards, inputs["rewards"]) @@ -239,40 +215,50 @@ def compute_loss(self, model, inputs, return_outputs=False): return loss, {"rewards": rewards} return loss -def train_reward_model(): - # Hardcode args (or you can parse arguments) - script_args = ScriptArguments() - # Load the model and tokenizer - tokenizer_name = script_args.model_name - 
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_auth_token=True) +def train_reward_model( + model_name=None, + dataset_path=None, + output_model_name=None, + disable_binning=False +): + script_args = ScriptArguments( + disable_binning=disable_binning, + warmup_steps=100, # Customize warmup period + save_total_limit=3, # Keep only last 3 checkpoints + min_loss_threshold=0.1 # Minimum loss threshold for saving + ) + + if model_name: + script_args.model_name = model_name + if output_model_name: + script_args.output_model_name = output_model_name + if dataset_path: + script_args.train_set_path = dataset_path + script_args.eval_set_path = dataset_path - # Adjusted according to the base model + # Initialize tokenizer + tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, use_auth_token=True) tokenizer.truncation_side = "left" tokenizer.model_max_length = script_args.max_length + tokenizer.pad_token = tokenizer.eos_token + + # Build datasets + train_dataset, eval_dataset = build_dataset( + tokenizer, + script_args.train_set_path, + script_args.eval_set_path, + script_args.disable_binning + ) - # Get the datasets - train_path = script_args.train_set_path - eval_path = script_args.eval_set_path - output_name = script_args.output_path - - train_dataset, eval_dataset = build_dataset(tokenizer, train_path, eval_path) - print("Training set size:", len(train_dataset)) - print("Evaluation set size:", len(eval_dataset)) - - # Define the training arguments + # Enhanced training arguments training_args = TrainingArguments( - output_dir=output_name, + output_dir=script_args.output_path, learning_rate=script_args.learning_rate, - # embedding_learning_rate=script_args.embedding_learning_rate, per_device_train_batch_size=script_args.per_device_train_batch_size, per_device_eval_batch_size=script_args.per_device_eval_batch_size, num_train_epochs=script_args.num_train_epochs, weight_decay=script_args.weight_decay, - evaluation_strategy="steps", - eval_steps=script_args.eval_every_steps, - save_strategy="steps", - save_steps=script_args.save_every_steps, gradient_accumulation_steps=script_args.gradient_accumulation_steps, gradient_checkpointing=script_args.gradient_checkpointing, deepspeed=script_args.deepspeed, @@ -281,24 +267,36 @@ def train_reward_model(): label_names=[], bf16=script_args.bf16, logging_strategy="steps", - logging_steps=10, + logging_steps=1, optim=script_args.optim, lr_scheduler_type=script_args.lr_scheduler_type, - warmup_ratio=0.03, + warmup_ratio=0.05, report_to='wandb', torch_compile=True, + # Enhanced checkpointing settings + load_best_model_at_end=True, + metric_for_best_model="mse_moving_avg", # Use moving average instead of raw MSE + greater_is_better=False, + save_strategy="steps", + save_steps=max(100, script_args.eval_every_steps), # Minimum 100 steps + evaluation_strategy="steps", + eval_steps=max(100, script_args.eval_every_steps), + save_total_limit=script_args.save_total_limit, + # Gradient clipping + max_grad_norm=1.0, ) + # Initialize model model = AutoModelForSequenceClassification.from_pretrained( script_args.model_name, num_labels=1, torch_dtype=torch.bfloat16, use_flash_attention_2=True, ) - + model.config.pad_token_id = model.config.eos_token_id model.config.use_cache = not script_args.gradient_checkpointing - # Initialize the trainer + # Initialize trainer with improved callbacks trainer = RewardTrainer( model=model, args=training_args, @@ -306,20 +304,29 @@ def train_reward_model(): eval_dataset=eval_dataset, compute_metrics=compute_metrics, 
data_collator=RewardDataCollatorWithPadding( - tokenizer=tokenizer, max_length=script_args.max_length + tokenizer=tokenizer, + max_length=script_args.max_length ), + callbacks=[ + EarlyStoppingCallback( + early_stopping_patience=script_args.early_stopping_patience, + early_stopping_threshold=script_args.early_stopping_threshold + ) + ], ) - # Start training - trainer.train() - - print("Saving last checkpoint of the model") - trainer.save_model(output_name + "/last_checkpoint") - tokenizer.save_pretrained(output_name + "/last_checkpoint") - # Push the model to Hugging Face Hub - # Ensure you have the necessary permissions and authentication - trainer.push_to_hub("rawsh/mirrorgemma-2-2b-PRM-base") + # Train and save + trainer.train() + + print("Saving final checkpoint") + trainer.save_model(script_args.output_path + "/final_checkpoint") + tokenizer.save_pretrained(script_args.output_path + "/final_checkpoint") + + # Push to Hub if specified + if script_args.output_model_name: + tokenizer.push_to_hub(script_args.output_model_name) + trainer.push_to_hub(script_args.output_model_name) if __name__ == "__main__": train_reward_model() diff --git a/mcts/tree_search.py b/mcts/tree_search.py index 7b130be..86a617c 100644 --- a/mcts/tree_search.py +++ b/mcts/tree_search.py @@ -4,32 +4,28 @@ import json from openai import AsyncOpenAI import time +import random # Added for jitter in retry logic from functools import wraps from collections import OrderedDict from asyncio import Semaphore, TimeoutError +from datasets import load_dataset +from tqdm import tqdm +from tqdm.asyncio import tqdm as atqdm -POLICY_URL = 'https://rawsh--vllm-gemma-serve.modal.run/v1/' -PRM_URL = 'https://rawsh--mirrorgemma-prm-embedder-score-output.modal.run' +# URLs and configuration +POLICY_URL = 'https://rawsh--vllm-qwen-ft-serve.modal.run/v1/' +PRM_URL = 'https://rawsh--mirrorqwen-prm-embedder-score-output.modal.run' API_KEY = '9FF74944EED19865193F979942FB1' -POLICY_SEMAPHORE = Semaphore(200) -PRM_SEMAPHORE = Semaphore(1) +CONCURRENT_MCTS_SEMAPHORE = Semaphore(50) +POLICY_SEMAPHORE = Semaphore(1000) +PRM_SEMAPHORE = Semaphore(1000) -MAX_RETRIES = 5 -TIMEOUT = 30 # seconds +MAX_RETRIES = 25 # Increased from 10 +TIMEOUT = 20 # Decreased from 30 to fail faster and retry -action_num = 0 - -class Node: - def __init__(self, state, parent=None): - self.state = state - self.parent = parent - self.children = {} - self.visits = 0 - self.total_value = 0 - self.prm_value = None - -def async_lru_cache(maxsize=128): +# Cache decorator and retry function +def async_lru_cache(maxsize=2000): cache = OrderedDict() def decorator(func): @wraps(func) @@ -50,30 +46,85 @@ async def retry_with_timeout(func, *args, **kwargs): except TimeoutError: if attempt == MAX_RETRIES - 1: raise - print(f"Attempt {attempt + 1} timed out. Retrying...") + # Exponential backoff with jitter + delay = min(1.5 ** attempt + random.random(), 10) + await asyncio.sleep(delay) except Exception as e: if attempt == MAX_RETRIES - 1: raise - print(f"Attempt {attempt + 1} failed with error: {str(e)}. Retrying...") - await asyncio.sleep(1) # Wait a bit before retrying - -async def mcts(root_state, correct_answer, num_iterations, session): - root = Node(root_state) - client = AsyncOpenAI(base_url=POLICY_URL, api_key=API_KEY) - terminal_nodes = set() + # Exponential backoff with jitter for other errors + delay = min(1.5 ** attempt + random.random(), 10) + print(f"Attempt {attempt + 1} failed with error: {str(e)}. 
Retrying in {delay:.1f}s...") + await asyncio.sleep(delay) - for i in range(num_iterations): - print(f"Starting iteration {i + 1}/{num_iterations}") - leaf = select(root) - is_term, is_corr = await retry_with_timeout(is_terminal, leaf.state, correct_answer, client, session) - if is_term: - terminal_nodes.add(leaf.state) - if not is_term: - child = await retry_with_timeout(expand, leaf, client, session) - value = await retry_with_timeout(simulate, child, correct_answer, client, session, terminal_nodes) - backpropagate(child, value) +class Node: + def __init__(self, state, parent=None): + self.state = state + self.parent = parent + self.children = {} + self.visits = 0 + self.total_value = 0 + self.prm_value = None - return root, terminal_nodes +# Progress tracking class with added metrics +class MCTSProgress: + def __init__(self, total_questions, iterations_per_question): + self.total_questions = total_questions + self.total_iterations = total_questions * iterations_per_question + self.completed_questions = 0 + self.completed_iterations = 0 + self.correct_sc = 0 # Self-consistency correct count + self.correct_any = 0 # Any-correct count + self.correct_best = 0 # Best PRM path correct count + self.total_actions = 0 # Global action counter + self.total_terminal_questions = 0 # Questions with at least one terminal node + + # Single progress bar with dynamic description + self.pbar = tqdm(total=self.total_iterations, + desc=self.get_progress_description()) + + def get_progress_description(self): + completed_pct = (self.completed_iterations / self.total_iterations) * 100 + sc_pct = (self.correct_sc / max(1, self.total_terminal_questions)) * 100 + any_pct = (self.correct_any / max(1, self.total_terminal_questions)) * 100 + best_pct = (self.correct_best / max(1, self.total_terminal_questions)) * 100 + return (f"#Q: {self.completed_questions}/{self.total_questions} | " + f"SC: {sc_pct:.1f}% | " + f"ANY: {any_pct:.1f}% | " + f"BEST: {best_pct:.1f}% | " + f"Actions: {self.total_actions}") + + def increment_iteration(self): + self.completed_iterations += 1 + self.pbar.update(1) + self.pbar.set_description(self.get_progress_description()) + + def complete_question(self, is_sc_correct, is_any_correct, is_best_correct, has_terminal_nodes): + self.completed_questions += 1 + if has_terminal_nodes: + self.total_terminal_questions += 1 + if is_sc_correct: + self.correct_sc += 1 + if is_any_correct: + self.correct_any += 1 + if is_best_correct: + self.correct_best += 1 + self.pbar.set_description(self.get_progress_description()) + + def close(self): + # Print final statistics + if self.total_terminal_questions > 0: + sc_pct = (self.correct_sc / self.total_terminal_questions) * 100 + any_pct = (self.correct_any / self.total_terminal_questions) * 100 + best_pct = (self.correct_best / self.total_terminal_questions) * 100 + print(f"\nFinal Results:") + print(f"Total Questions Processed: {self.completed_questions}") + print(f"Questions with Terminal Nodes: {self.total_terminal_questions}") + print(f"Self-Consistency Accuracy: {sc_pct:.2f}% ({self.correct_sc}/{self.total_terminal_questions})") + print(f"Any-Correct Accuracy: {any_pct:.2f}% ({self.correct_any}/{self.total_terminal_questions})") + print(f"Best-Path Accuracy: {best_pct:.2f}% ({self.correct_best}/{self.total_terminal_questions})") + print(f"Total Actions Taken: {self.total_actions}") + self.pbar.close() def select(node): while node.children: @@ -82,14 +133,22 @@ def select(node): node = best_uct_child(node) return node -async def expand(node, client, 
session): +def best_uct_child(node): + C = 1.41 + return max( + node.children.values(), + key=lambda child: (child.total_value / child.visits) + C * math.sqrt(math.log(node.visits) / child.visits) + ) + +async def expand(node, client, session, progress_tracker): action = await retry_with_timeout(get_next_action, node.state, client) new_state = apply_action(node.state, action) child = Node(new_state, parent=node) node.children[action] = child + progress_tracker.total_actions += 1 return child -async def simulate(node, correct_answer, client, session, terminal_nodes): +async def simulate(node, correct_answer, client, session, terminal_nodes, progress_tracker): state = node.state depth = 0 max_depth = 10 @@ -100,6 +159,7 @@ async def simulate(node, correct_answer, client, session, terminal_nodes): break action = await retry_with_timeout(get_next_action, state, client) state = apply_action(state, action) + progress_tracker.total_actions += 1 depth += 1 return await retry_with_timeout(evaluate_state, state, session) @@ -109,21 +169,11 @@ def backpropagate(node, value): node.total_value += value node = node.parent -def best_uct_child(node): - C = 1.41 - return max( - node.children.values(), - key=lambda child: (child.total_value / child.visits) + C * math.sqrt(math.log(node.visits) / child.visits) - ) - async def get_next_action(state, client): - global action_num - action_num += 1 - print(f"action {action_num}", end="\r") prompt = format_state_for_policy(state) async with POLICY_SEMAPHORE: response = await client.completions.create( - model="rawsh/mirrorgemma-2-2b-SFT", + model="rawsh/mirrorqwen2.5-0.5b-SFT", prompt=prompt, max_tokens=250, stop=["\n\n"], @@ -137,7 +187,6 @@ def is_correct(state, correct_answer): async def is_terminal(state, correct_answer, client, session): if is_correct(state, correct_answer): - print("CORRECT", state) return True, True if state.count("\n\n") < 2: @@ -145,7 +194,7 @@ async def is_terminal(state, correct_answer, client, session): async with POLICY_SEMAPHORE: response = await client.completions.create( - model="rawsh/mirrorgemma-2-2b-SFT", + model="rawsh/mirrorqwen2.5-0.5b-SFT", prompt=state, max_tokens=1, stop=["\n\n"], @@ -156,8 +205,6 @@ async def is_terminal(state, correct_answer, client, session): first_token_top_logprobs = response.choices[0].logprobs.top_logprobs[0] if "" in first_token_top_logprobs: scaled = math.exp(first_token_top_logprobs[""]) - res = response.choices[0].text.strip() - yes_bigger_than_no = True if "\n\n" in first_token_top_logprobs: scaled_no = math.exp(first_token_top_logprobs["\n\n"]) @@ -165,7 +212,6 @@ async def is_terminal(state, correct_answer, client, session): threshold = 0.95 terminal = (scaled >= threshold) and yes_bigger_than_no - print(first_token_top_logprobs[""], scaled, res, terminal) return terminal, False else: return False, False @@ -176,7 +222,7 @@ async def evaluate_state(state, session): async with PRM_SEMAPHORE: async with session.post(PRM_URL, json={"prompt": prompt}) as response: result = await response.json() - return float(result[0]['score']) + return float(result['score']) def apply_action(state, action): return f"{state}\n\n{action}" @@ -210,61 +256,131 @@ async def find_best_leaf_by_prm(node, session): async def evaluate_and_store_prm(node, session): node.prm_value = await retry_with_timeout(evaluate_state, node.state, session) -async def run_mcts(initial_state, correct_answer, num_iterations, session): - start_time = time.time() - root, terminal_nodes = await mcts(initial_state, correct_answer, 
num_iterations, session) - end_time = time.time() - best_leaf = await find_best_leaf_by_prm(root, session) - - terminal_paths = [] - for node in terminal_nodes: - score = await retry_with_timeout(evaluate_state, node, session) - terminal_paths.append({ - "final_state": node, - "score": score, - "correct": is_correct(node, correct_answer) - }) +async def mcts(root_state, correct_answer, num_iterations, session, progress_tracker): + root = Node(root_state) + client = AsyncOpenAI(base_url=POLICY_URL, api_key=API_KEY) + terminal_nodes = set() - result = { - "question": initial_state, - "correct_answer": correct_answer, - "statistics": { - "num_iterations": num_iterations, - "execution_time": end_time - start_time, - "total_terminal_nodes": len(terminal_nodes), - }, - "best_path": { - "final_state": best_leaf.state, - "score": best_leaf.prm_value, - "correct": is_correct(best_leaf.state, correct_answer) - }, - "terminal_paths": terminal_paths - } + for i in range(num_iterations): + leaf = select(root) + is_term, is_corr = await retry_with_timeout(is_terminal, leaf.state, correct_answer, client, session) + + if is_term: + terminal_nodes.add(leaf.state) + else: + child = await retry_with_timeout(expand, leaf, client, session, progress_tracker) + value = await retry_with_timeout(simulate, child, correct_answer, client, session, terminal_nodes, progress_tracker) + backpropagate(child, value) + + progress_tracker.increment_iteration() - return result + return root, terminal_nodes + +async def run_mcts(initial_state, correct_answer, num_iterations, session, progress_tracker): + async with CONCURRENT_MCTS_SEMAPHORE: + start_time = time.time() + root, terminal_nodes = await mcts(initial_state, correct_answer, num_iterations, session, progress_tracker) + end_time = time.time() + + best_leaf = await find_best_leaf_by_prm(root, session) + + terminal_paths = [] + terminal_correct_count = 0 + total_terminal_nodes = len(terminal_nodes) + max_prm_score = float('-inf') + best_prm_path_correct = False + + for node in terminal_nodes: + score = await retry_with_timeout(evaluate_state, node, session) + is_node_correct = is_correct(node, correct_answer) + if is_node_correct: + terminal_correct_count += 1 + if score > max_prm_score: + max_prm_score = score + best_prm_path_correct = is_node_correct + terminal_paths.append({ + "final_state": node, + "score": score, + "correct": is_node_correct + }) + + is_best_correct = is_correct(best_leaf.state, correct_answer) + + # Calculate metrics using only terminal nodes + # Self-consistency based on majority voting of terminal nodes (>50% correct) + has_terminal_nodes = total_terminal_nodes > 0 + is_sc_correct = (terminal_correct_count > total_terminal_nodes / 2) if has_terminal_nodes else False + is_any_correct = (terminal_correct_count > 0) # Any-correct using terminal nodes + + result = { + "question": initial_state, + "correct_answer": correct_answer, + "statistics": { + "num_iterations": num_iterations, + "execution_time": end_time - start_time, + "total_terminal_nodes": total_terminal_nodes, + "correct_terminal_nodes": terminal_correct_count, + "self_consistency_correct": is_sc_correct, + "any_correct": is_any_correct, + "has_terminal_nodes": has_terminal_nodes, + "best_prm_path_correct": best_prm_path_correct + }, + "best_path": { + "final_state": best_leaf.state, + "score": best_leaf.prm_value, + "correct": is_best_correct + }, + "terminal_paths": terminal_paths + } + + progress_tracker.complete_question(is_sc_correct, is_any_correct, best_prm_path_correct, 
has_terminal_nodes) + return result async def main(): - initial_states = [ - ("Janet hires six employees. Four of them are warehouse workers who make $15/hour, and the other two are managers who make $20/hour. Janet has to pay 10% of her workers' salaries in FICA taxes. If everyone works 25 days a month and 8 hours a day, how much does Janet owe total for their wages and taxes for one month?", "22000"), - ("Peggy is moving and is looking to get rid of her record collection. Sammy says that he will buy all of them for 4 dollars each. Bryan is only interested in half of the records but will offer 6 dollars each for the half that he is interested in and 1 dollar each for the remaining half that he is not interested in with the hopes that he can resell them in bulk later. If Peggy has 200 records, what is the difference in profit between Sammy versus Bryan's deal?", "100"), - ("Angelo and Melanie want to plan how many hours over the next week they should study together for their test next week. They have 2 chapters of their textbook to study and 4 worksheets to memorize. They figure out that they should dedicate 3 hours to each chapter of their textbook and 1.5 hours for each worksheet. If they plan to study no more than 4 hours each day, how many days should they plan to study total over the next week if they take a 10-minute break every hour, include 3 10-minute snack breaks each day, and 30 minutes for lunch each day?", "4"), - ("Carol is an aviation engineer deciding how much fuel to put in a jet. The empty plane needs 20 gallons of fuel per mile. Each person on the plane increases this amount by 3 gallons per mile, and each bag increases it by 2 gallons per mile. If there are 30 passengers and 5 flight crew, and each person brought two bags, how many gallons of fuel does the plane need for a 400-mile trip?", "106000"), - ("Susan is making jewelry with a repeating pattern that has 3 green beads, 5 purple beads, and twice as many red beads as green beads. If the pattern repeats three times per bracelet and 5 times per necklace, how many beads does she need to make 1 bracelets and 10 necklaces?", "742"), - ("A group of hawks is called a kettle. It is breeding season for hawks. A group of ornithologists are tracking 6 kettles of hawks. Each kettle has an average of 15 pregnancies that yield 4 babies per batch. How many babies are expected this season if approximately 25% are lost?", "270"), - ("Brendan makes $6/hour as a waiter. He's scheduled for 2 8-hour shifts and 1 12-hour shift this week. He also makes an average of $12 in tips each hour. Brendan is supposed to pay 20% of his income in taxes, but he only reports 1/3rd of his tips to the IRS. How much money does Brendan pay in taxes each week?", "56"), - ("Karen's students are about to take a standardized test. Karen gets a $500 bonus if their average score is above 75, plus an extra $10 bonus for every additional point the average score increases above 75. So far, Karen has graded 8 tests, and the average is 70. 
Given that each student can have a maximum score of 150, what combined score do the last two tests need to have for Karen to earn a $600 bonus?", "290") - ] + # Set random seed for reproducibility + random.seed(42) + + def process(example): + example["answer"] = example["answer"].split("\n#### ")[-1].strip() + return example + + gsm8k = load_dataset("openai/gsm8k", "main", split="train").shuffle(seed=42) + gsm8k = gsm8k.map(process, num_proc=24) + initial_states = [(example["question"], example["answer"]) for example in gsm8k] + + # Sample 100 questions + sample = False + if sample: + initial_states = random.sample(initial_states, 10) + num_iterations = 10 + # Initialize progress tracker + progress_tracker = MCTSProgress(len(initial_states), num_iterations) + async with aiohttp.ClientSession() as session: - tasks = [run_mcts(state, answer, num_iterations, session) for state, answer in initial_states] + tasks = [] + for state, answer in initial_states: + tasks.append(run_mcts(state, answer, num_iterations, session, progress_tracker)) + results = await asyncio.gather(*tasks) + progress_tracker.close() + + # Calculate and print final statistics + total_questions = len(results) + sc_correct = sum(1 for r in results if r["statistics"]["self_consistency_correct"]) + any_correct = sum(1 for r in results if r["statistics"]["any_correct"]) + + print(f"\nFinal Statistics:") + print(f"Total Questions: {total_questions}") + print(f"Self-Consistency Accuracy: {(sc_correct/total_questions)*100:.2f}%") + print(f"Any-Correct Accuracy: {(any_correct/total_questions)*100:.2f}%") + + # Write results with open("mcts_results.jsonl", "w") as f: for result in results: f.write(json.dumps(result) + "\n") - - print(f"\nAll MCTS processes completed and results written to mcts_results.jsonl") if __name__ == "__main__": asyncio.run(main()) \ No newline at end of file diff --git a/modal_prm_reward.py b/modal_prm_reward.py index f9269f7..7de4e30 100644 --- a/modal_prm_reward.py +++ b/modal_prm_reward.py @@ -2,53 +2,55 @@ image = ( modal.Image.debian_slim() - .pip_install("torch") - .pip_install("transformers") - .pip_install("accelerate") + .pip_install([ + "torch", + "transformers", + "accelerate", + "batched", + ]) ) -app = modal.App("mirrorgemma-prm", image=image) - +app = modal.App("mirrorqwen-prm", image=image) with image.imports(): from typing import List, Dict, Tuple import asyncio import torch from time import perf_counter as pc - import copy - # from transformers import AutoModelForSequenceClassification, AutoTokenizer from transformers import pipeline import os - # from lib import extract_tensors, test - # print(test()) + + class BatchProcessor: + def __init__(self): + import batched + self.batched = batched + + def create_batch_processor(self, pipeline_func): + @self.batched.dynamically(batch_size=256, timeout_ms=100.0, small_batch_threshold=4) + def _process_batch(prompts: List[str]) -> List[Dict]: + return pipeline_func(prompts) + return _process_batch @app.cls( gpu=modal.gpu.A10G(), + # gpu=modal.gpu.H100(), container_idle_timeout=120, - # allow_concurrent_inputs=100, - # volumes={"/data": modal.Volume.from_name("my-test-volume")} + allow_concurrent_inputs=1000, secrets=[ modal.Secret.from_name("hf-token"), - # modal.Secret.from_name("wandb-token") ], ) class Embedder: - # model_id = "RLHFlow/ArmoRM-Llama3-8B-v0.1" - model_id = "rawsh/mirrorgemma-2-2b-prm-base" - revision = "a6bc6d57b3d7c873ba30e88ffd3e304e4758c295" + model_id = "rawsh/mirrorqwen2.5-0.5b-prm" + # revision = 
"a1cd3547343bab37ff61fd248ef46b779d5a8dfa" # base + revision = "3ad692bde328cddbfd45666cb6f7307430cac181" device = "cuda" print(model_id) @modal.build() def build(self): - # cache print("build") dtype = torch.bfloat16 with torch.device("cuda"): - # from transformers import ( - # AutoTokenizer - # ) - # tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b", use_auth_token=True) - # tokenizer.push_to_hub("rawsh/mirrorgemma-2-2b-PRM-base") print("[build] loading model") start = pc() classifier = pipeline("sentiment-analysis", model=self.model_id, revision=self.revision, @@ -56,40 +58,65 @@ def build(self): elapsed = pc() - start print(f"[build] loading model took {elapsed} seconds") - # @modal.enter(snap=False) @modal.enter() def setup(self): - # Start the model to a GPU before doing any work. print("setup") - # os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" - - # faster model loading dtype = torch.bfloat16 with torch.device("cuda"): print("[setup] loading model") start = pc() self.pipeline = pipeline("sentiment-analysis", model=self.model_id, revision=self.revision, - trust_remote_code=True, torch_dtype=dtype, device="cuda") + trust_remote_code=True, torch_dtype=dtype, device="cuda", batch_size=256) elapsed = pc() - start print(f"[setup] loading model took {elapsed} seconds") + + # Initialize batch processor + batch_processor = BatchProcessor() + self._process_batch = batch_processor.create_batch_processor(self.pipeline) @modal.web_endpoint(method="POST", docs=True) - def score_output(self, inp: dict): + async def score_output(self, inp: dict): prompt = inp["prompt"] - print("score_output") - return self.pipeline(prompt) + # Handle both single inputs and lists of inputs + if isinstance(prompt, str): + prompts = [prompt] + else: + prompts = prompt + try: + # Use the batched processing method + results = await self._process_batch.acall(prompts) + + # Return single result if input was single, otherwise return list + if isinstance(inp["prompt"], str): + return results[0] + return results + except Exception as e: + return {"error": str(e)} -# @app.local_entrypoint() -# async def main(): -# # score the messages -# prompt = 'What are some synonyms for the word "beautiful"?' -# response1 = 'Nicely, Beautifully, Handsome, Stunning, Wonderful, Gorgeous, Pretty, Stunning, Elegant' -# response2 = 'bad' -# messages1 = [{"role": "user", "content": prompt}, {"role": "assistant", "content": response1}] -# messages2 = [{"role": "user", "content": prompt}, {"role": "assistant", "content": response2}] -# m1 = Embedder().score_output(messages1) -# m2 = Embedder().score_output(messages2) -# res = await asyncio.gather(*[m1,m2]) -# print(response1, res[0]) -# print(response2, res[1]) \ No newline at end of file +@app.local_entrypoint() +async def main(): + embedder = Embedder() + + # Test with multiple prompts + prompt = 'What are some synonyms for the word "beautiful"?' 
diff --git a/modal_train_policy_sft.py b/modal_train_policy_sft.py
index c79f8e8..a223876 100644
--- a/modal_train_policy_sft.py
+++ b/modal_train_policy_sft.py
@@ -31,8 +31,8 @@
 @app.function(
     cpu=2.0,
-    # gpu=modal.gpu.A10G(),
-    gpu=modal.gpu.H100(),
+    gpu=modal.gpu.A10G(),
+    # gpu=modal.gpu.H100(),
     # gpu=modal.gpu.A100(size="40GB"),
     timeout=20 * HOURS,
     secrets=[
diff --git a/modal_train_prm.py b/modal_train_prm_init.py
similarity index 85%
rename from modal_train_prm.py
rename to modal_train_prm_init.py
index 41c2b29..1755091 100644
--- a/modal_train_prm.py
+++ b/modal_train_prm_init.py
@@ -19,6 +19,8 @@
     .pip_install("datasets")
     .pip_install("wandb")
     .pip_install("bitsandbytes")
+    .pip_install("matplotlib")
+    .pip_install("seaborn")
 )
 app = modal.App("train_prm", image=image)
 
@@ -28,17 +30,20 @@
 MINUTES = 60 # seconds
 HOURS = 60 * MINUTES
 
+vol = modal.Volume.from_name("prm-tmp", create_if_missing=True)
+
 @app.function(
     cpu=2.0,
-    # gpu=modal.gpu.A10G(),
-    gpu=modal.gpu.H100(),
+    gpu=modal.gpu.A10G(),
+    # gpu=modal.gpu.H100(),
     # gpu=modal.gpu.A100(count=4, size="40GB"),
     # gpu=modal.gpu.A100(size="40GB"),
     timeout=20 * HOURS,
     secrets=[
         modal.Secret.from_name("hf-token"),
         modal.Secret.from_name("wandb-token")
-    ]
+    ],
+    volumes={"/out": vol},
 )
 def train_reward_model_upload_to_hf():
     train_reward_model()
diff --git a/modal_train_prm_st.py b/modal_train_prm_st.py
new file mode 100644
index 0000000..a68a412
--- /dev/null
+++ b/modal_train_prm_st.py
@@ -0,0 +1,59 @@
+import modal
+
+cuda_version = "12.4.0"  # should be no greater than host CUDA version
+flavor = "devel"  # includes full CUDA toolkit
+operating_sys = "ubuntu22.04"
+tag = f"{cuda_version}-{flavor}-{operating_sys}"
+
+image = (
+    # modal.Image.debian_slim()
+    modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.11")
+    .apt_install("git")
+    .pip_install("torch")
+    .pip_install("packaging")
+    .pip_install("wheel")
+    .run_commands("pip install flash-attn --no-build-isolation")
+    .pip_install("transformers")
+    .pip_install("accelerate")
+    .pip_install("numpy")
+    .pip_install("datasets")
+    .pip_install("wandb")
+    .pip_install("bitsandbytes")
+    .pip_install("matplotlib")
+    .pip_install("seaborn")
+)
+app = modal.App("train_prm", image=image)
+
+with image.imports():
+    from mcts.train_reward import train_reward_model
+
+MINUTES = 60  # seconds
+HOURS = 60 * MINUTES
+
+vol = modal.Volume.from_name("prm-tmp", create_if_missing=True)
+
+@app.function(
+    cpu=2.0,
+    # gpu=modal.gpu.A10G(),
+    gpu=modal.gpu.H100(),
+    # gpu=modal.gpu.A100(count=4, size="40GB"),
+    # gpu=modal.gpu.A100(size="40GB"),
+    timeout=20 * HOURS,
+    secrets=[
+        modal.Secret.from_name("hf-token"),
+        modal.Secret.from_name("wandb-token")
+    ],
+    volumes={"/out": vol},
+)
+def train_reward_model_upload_to_hf():
+    train_reward_model(
+        model_name="rawsh/mirrorqwen2.5-0.5b-prm",
+        dataset_path="rawsh/magpie-ultra-v0.1-PRM-data-ST-0",
+        output_model_name="rawsh/mirrorqwen2.5-0.5b-PRM-ST-0",
+        disable_binning=True
+    )
+
+@app.local_entrypoint()
+def main():
+    # run the function remotely on Modal
+    train_reward_model_upload_to_hf.remote()
\ No newline at end of file
diff --git a/modal_vllm.py b/modal_vllm.py
index 70615e9..190a067 100644
--- a/modal_vllm.py
+++ b/modal_vllm.py
@@ -15,13 +15,21 @@ def download_model_to_image(model_dir, model_name, model_revision):
     )
     move_cache()
 
+MODEL_DIR = "/qwen"
+MODEL_NAME = "rawsh/mirrorqwen2.5-0.5b-SFT"
+MODEL_REVISION = "1f75c1204888cc912ad0b186c5b7620235246ffa"
+
+# MODEL_DIR = "/smollm"
+# MODEL_NAME = "HuggingFaceTB/SmolLM2-135M-Instruct"
+# MODEL_REVISION = "7e27bd9f95328f0f3b08261d1252705110c806f8"
+
 # MODEL_DIR = "/qwen"
 # MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
 # MODEL_REVISION = "a8b602d9dafd3a75d382e62757d83d89fca3be54"
 
-MODEL_DIR = "/gemma"
-MODEL_NAME = "rawsh/mirrorgemma-2-2b-SFT"
-MODEL_REVISION = "0ec8c2eaead95160a9f908cd59f254bdace496bd"
+# MODEL_DIR = "/gemma"
+# MODEL_NAME = "rawsh/mirrorgemma-2-2b-SFT"
+# MODEL_REVISION = "0ec8c2eaead95160a9f908cd59f254bdace496bd"
 
 vllm_image = (
     modal.Image.debian_slim(python_version="3.10")
@@ -47,7 +55,9 @@ def download_model_to_image(model_dir, model_name, model_revision):
     .env({"VLLM_ALLOW_LONG_MAX_MODEL_LEN": "1"})
 )
 
-app = modal.App("vllm-gemma")
+
+# app = modal.App("vllm-smollm")
+app = modal.App("vllm-qwen-ft")
 
 N_GPU = 1  # tip: for best results, first upgrade to more powerful GPUs, and only then increase GPU count
@@ -59,7 +69,7 @@ def download_model_to_image(model_dir, model_name, model_revision):
 @app.function(
     image=vllm_image,
     # gpu=modal.gpu.H100(count=N_GPU),
-    # gpu=modal.gpu.A100(count=N_GPU, size="40GB"),
+    # gpu=modal.gpu.A100(count=N_GPU, size="80GB"),
     gpu=modal.gpu.A10G(count=N_GPU),
     container_idle_timeout=2 * MINUTES,
     timeout=20 * MINUTES,