From 14409c2b8a5ae22ebac51881103c9099e8f59560 Mon Sep 17 00:00:00 2001 From: Robert Washbourne Date: Mon, 11 Nov 2024 01:16:34 +0000 Subject: [PATCH] fix --- mcts/process_results.py | 149 +++++++++++++-------------- mcts/train_policy_simpo.py | 128 +++++++++++++++++++++++ mcts/train_reward.py | 196 ++++++++++++++++++++++++++++-------- mcts/tree_search.py | 91 ++++++++++++----- modal_prm_reward.py | 10 +- modal_train_policy_simpo.py | 81 +++++++++++++++ modal_train_prm_init.py | 12 ++- modal_train_prm_st.py | 7 +- modal_vllm.py | 129 ++++++++++-------------- 9 files changed, 577 insertions(+), 226 deletions(-) create mode 100644 mcts/train_policy_simpo.py create mode 100644 modal_train_policy_simpo.py diff --git a/mcts/process_results.py b/mcts/process_results.py index 802faff..206f509 100644 --- a/mcts/process_results.py +++ b/mcts/process_results.py @@ -24,6 +24,7 @@ class QuestionAnalysis: binary_success: bool # Any correct path? sc_score: float # Overall self-consistency score sc_correct_percent: float # % of self-consistent answers that are correct + sc_most_common_correct: bool # Whether most common answer is correct total_paths: int correct_paths: List[PathMetrics] incorrect_paths: List[PathMetrics] @@ -66,7 +67,6 @@ def _extract_answer(self, final_state: str) -> str: return final_state.split('\\boxed{')[1].split('}')[0] return '' - # Changes to analyze_question method: def analyze_question(self, question: dict) -> QuestionAnalysis: """Analyze a single question's reasoning paths""" paths = [] @@ -80,7 +80,7 @@ def analyze_question(self, question: dict) -> QuestionAnalysis: path_metrics = PathMetrics( answer=answer, is_correct=path['correct'], - path_length=len(processed_steps), # Use processed steps length + path_length=len(processed_steps), prm_score=path['score'], raw_steps=raw_steps, steps=processed_steps @@ -98,18 +98,24 @@ def analyze_question(self, question: dict) -> QuestionAnalysis: # Overall self-consistency most_common = answers.most_common() most_common_count = most_common[0][1] if most_common else 0 + most_common_answer = most_common[0][0] if most_common else '' sc_score = most_common_count / total_paths # Calculate % of self-consistent answers that are correct sc_answers = [ans for ans, count in most_common - if count > total_paths * 0.2] # Consider answers that appear >20% of time + if count > total_paths * 0.2] # Consider answers that appear >20% of time sc_correct = sum(1 for ans in sc_answers - if any(p.answer == ans and p.is_correct - for p in correct_paths)) + if any(p.answer == ans and p.is_correct + for p in correct_paths)) sc_correct_percent = sc_correct / len(sc_answers) if sc_answers else 0 + + # Check if most common answer is correct + sc_most_common_correct = any(p.answer == most_common_answer and p.is_correct + for p in correct_paths) else: sc_score = 0 sc_correct_percent = 0 + sc_most_common_correct = False return QuestionAnalysis( question_text=question['question'], @@ -117,64 +123,78 @@ def analyze_question(self, question: dict) -> QuestionAnalysis: binary_success=bool(correct_paths), sc_score=sc_score, sc_correct_percent=sc_correct_percent, + sc_most_common_correct=sc_most_common_correct, total_paths=total_paths, correct_paths=correct_paths, incorrect_paths=incorrect_paths, answer_distribution=answers ) + # New version of get_paired_examples: def get_paired_examples( self, analyses: List[QuestionAnalysis], - max_pairs: int = 10000 + max_pairs: int = 10000, + top_n_correct: int = 50 # New parameter ) -> List[Dict[str, Any]]: - """Get paired positive/negative examples for each question""" + """Get paired examples considering multiple correct paths per question""" paired_examples = [] for analysis in analyses: if not analysis.correct_paths or not analysis.incorrect_paths: continue - # Find best correct path - shortest_correct = min(analysis.correct_paths, key=lambda p: p.path_length) - best_correct = max( - [p for p in analysis.correct_paths - if p.path_length <= shortest_correct.path_length * 1.2], - key=lambda p: p.prm_score + # Sort correct paths by quality (shorter length + higher PRM score) + sorted_correct = sorted( + analysis.correct_paths, + key=lambda p: (-p.prm_score, p.path_length) ) - # Find most deceptive incorrect path - best_incorrect = max( - analysis.incorrect_paths, - key=lambda p: ( - p.prm_score, - -abs(p.path_length - best_correct.path_length) - ) - ) + # Take top N correct paths + top_correct_paths = sorted_correct[:top_n_correct] - paired_examples.append({ - 'question': analysis.question_text, - 'correct_answer': analysis.correct_answer, - 'metrics': { - 'sc_score': analysis.sc_score, - 'sc_correct_percent': analysis.sc_correct_percent, - 'total_paths': analysis.total_paths, - 'answer_distribution': dict(analysis.answer_distribution) - }, - 'positive': { - 'steps': best_correct.steps, - 'answer': best_correct.answer, - 'prm_score': best_correct.prm_score, - 'path_length': best_correct.path_length - }, - 'negative': { - 'steps': best_incorrect.steps, - 'answer': best_incorrect.answer, - 'prm_score': best_incorrect.prm_score, - 'path_length': best_incorrect.path_length - } - }) - + # Get shortest correct path length for reference + shortest_correct_len = min(p.path_length for p in analysis.correct_paths) + + # Filter correct paths that aren't too much longer than shortest + filtered_correct = [ + p for p in top_correct_paths + if p.path_length <= shortest_correct_len * 1.4 + ] + + # For each correct path, find the most deceptive incorrect path + for correct_path in filtered_correct: + # Find most deceptive incorrect path relative to this correct path + best_incorrect = max( + analysis.incorrect_paths, + key=lambda p: ( + p.prm_score, + -abs(p.path_length - correct_path.path_length) + ) + ) + + paired_examples.append({ + 'question': analysis.question_text, + 'correct_answer': analysis.correct_answer, + 'metrics': { + 'sc_score': analysis.sc_score, + 'sc_correct_percent': analysis.sc_correct_percent, + 'total_paths': analysis.total_paths + }, + 'positive': { + 'steps': correct_path.steps, + 'answer': correct_path.answer, + 'prm_score': correct_path.prm_score, + 'path_length': correct_path.path_length + }, + 'negative': { + 'steps': best_incorrect.steps, + 'answer': best_incorrect.answer, + 'prm_score': best_incorrect.prm_score, + 'path_length': best_incorrect.path_length + } + }) + # Sort by quality criteria including SC correct % paired_examples.sort( key=lambda x: ( @@ -282,7 +302,8 @@ def generate_prm_training_data(self, analyses: List[QuestionAnalysis]) -> List[D return prm_examples def main(): - analyzer = MathReasoningAnalyzer('mcts_results.jsonl') + # analyzer = MathReasoningAnalyzer('mcts_results.jsonl') + analyzer = MathReasoningAnalyzer('mcts_results.jsonl.st0.bak') # Analyze all questions analyses = [] @@ -295,6 +316,7 @@ def main(): binary_success = sum(1 for a in analyses if a.binary_success) avg_sc = np.mean([a.sc_score for a in analyses]) avg_sc_correct = np.mean([a.sc_correct_percent for a in analyses]) + sc_accuracy = sum(1 for a in analyses if a.sc_most_common_correct) / total * 100 # Terminal path statistics total_paths = [a.total_paths for a in analyses] @@ -327,6 +349,7 @@ def main(): print("\nOverall Statistics:") print(f"Total questions analyzed: {total}") print(f"Questions with at least one correct path: {binary_success} ({binary_success/total*100:.1f}%)") + print(f"Accuracy using most common answer (SC): {sc_accuracy:.1f}%") print(f"Average self-consistency score: {avg_sc:.3f}") print(f"Average % of self-consistent answers that are correct: {avg_sc_correct*100:.1f}%") @@ -359,13 +382,12 @@ def main(): for threshold in sc_thresholds: questions_above = sum(1 for a in analyses if a.sc_score >= threshold) correct_above = sum(1 for a in analyses - if a.sc_score >= threshold and a.sc_correct_percent > 0) + if a.sc_score >= threshold and a.sc_most_common_correct) print(f"Questions with SC >= {threshold:.1f}: {questions_above} " f"({questions_above/total*100:.1f}%) - " f"Correct: {correct_above} ({correct_above/questions_above*100:.1f}% of SC)") - - should_generate=False + should_generate=True if should_generate: # Generate both preference pairs and PRM training data paired_examples = analyzer.get_paired_examples(analyses) @@ -391,8 +413,6 @@ def main(): print(f"Average self-consistency: {np.mean(pair_sc):.3f} (±{np.std(pair_sc):.3f})") print(f"Average % correct in SC: {np.mean(pair_sc_correct)*100:.1f}% (±{np.std(pair_sc_correct)*100:.1f}%)") - # In main(), replace the PRM Training Data Statistics section with: - # Print PRM Training Data Statistics print("\nPRM Training Data Statistics:") correct_examples = [ex for ex in prm_training_data if ex["metadata"]["is_correct"]] @@ -455,27 +475,11 @@ def main(): f"mean={np.mean(incorrect_rewards):.3f}, " f"max={max(incorrect_rewards):.3f}") - # Save both datasets - with open('paired_examples.json', 'w') as f: - json.dump({ - 'summary_stats': { - 'total_questions': total, - 'questions_with_correct': binary_success, - 'avg_self_consistency': float(avg_sc), - 'avg_sc_correct_percent': float(avg_sc_correct), - 'path_stats': { - 'avg_total_paths': float(np.mean(total_paths)), - 'avg_correct_paths': float(np.mean(correct_paths)), - 'avg_incorrect_paths': float(np.mean(incorrect_paths)), - 'avg_correct_length': float(np.mean(all_correct_lengths)) if all_correct_lengths else None, - 'avg_incorrect_length': float(np.mean(all_incorrect_lengths)) if all_incorrect_lengths else None, - 'avg_correct_prm': float(np.mean(all_correct_scores)) if all_correct_scores else None, - 'avg_incorrect_prm': float(np.mean(all_incorrect_scores)) if all_incorrect_scores else None - }, - 'selected_pairs': len(paired_examples) - }, - 'paired_examples': paired_examples - }, f, indent=2) + # Save paired examples in JSONL format + with open('paired_examples.jsonl', 'w') as f: + for example in paired_examples: + json.dump(example, f) + f.write('\n') # Save PRM training data in JSONL format with open('prm_training.jsonl', 'w') as f: @@ -484,9 +488,8 @@ def main(): f.write('\n') print("\nOutput files written:") - print("- paired_examples.json") + print("- paired_examples.jsonl") print("- prm_training.jsonl") - if __name__ == "__main__": main() \ No newline at end of file diff --git a/mcts/train_policy_simpo.py b/mcts/train_policy_simpo.py new file mode 100644 index 0000000..6d5fedf --- /dev/null +++ b/mcts/train_policy_simpo.py @@ -0,0 +1,128 @@ +from dataclasses import dataclass, field +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer +from trl import CPOConfig, CPOTrainer +import wandb + +@dataclass +class ScriptArguments: + model_name: str = field(default="Qwen/Qwen2-0.5B-Instruct") + dataset_name: str = field(default="rawsh/mirrorqwen2.5-0.5B-gsm8k-policy-data-ST-0") + output_dir: str = field(default="simpo-math-model") + warmup_ratio: float = field(default=0.1) # 10% warmup + lr_scheduler_type: str = field(default="cosine") # Cosine decay + max_grad_norm: float = field(default=1.0) + output_model_name: str = field(default=None) + hub_token: str = field(default=None) + push_to_hub: bool = field(default=True) + # learning_rate: float = field(default=3e-7) + learning_rate: float = field(default=5e-7) + batch_size: int = field(default=8) + num_train_epochs: int = field(default=7) + # max_steps: int = field(default=-1) + # max_steps: int = field(default=10) + gradient_accumulation_steps: int = field(default=8) + beta: float = field(default=2.0) + simpo_gamma: float = field(default=0.5) + +# class CustomCPOTrainer(CPOTrainer): +# def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): +# loss, outputs = super().compute_loss(model, inputs, return_outputs=True) +# wandb.log({"loss": loss.item()}, step=self.state.step) +# if return_outputs: +# return loss, outputs +# return loss + +def train_simpo( + model_name=None, + dataset_name=None, + output_model_name=None, + hub_token=None + ): + args = ScriptArguments() + if model_name: + args.model_name = model_name + if dataset_name: + args.dataset_name = dataset_name + if output_model_name: + args.output_model_name = output_model_name + if hub_token: + args.hub_token = hub_token + + wandb.init(project="simpo-training") + + tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True) + tokenizer.pad_token = tokenizer.eos_token + + model = AutoModelForCausalLM.from_pretrained( + args.model_name, + trust_remote_code=True, + torch_dtype=torch.float16, + device_map="auto" + ) + model.config.use_cache = False + + dataset = load_dataset(args.dataset_name, token=args.hub_token) + train_dataset = dataset["train"].map( + lambda examples: { + "prompt": examples["question"], + "chosen": ["\n\n".join(ex["steps"]) for ex in examples["positive"]], + "rejected": ["\n\n".join(ex["steps"]) for ex in examples["negative"]] + }, + batched=True, + remove_columns=dataset["train"].column_names + ) + + training_args = CPOConfig( + output_dir=args.output_dir, + num_train_epochs=args.num_train_epochs, + per_device_train_batch_size=args.batch_size, + gradient_accumulation_steps=args.gradient_accumulation_steps, + learning_rate=args.learning_rate, + # max_steps=args.max_steps, + remove_unused_columns=False, + loss_type="simpo", + cpo_alpha=0.5, + beta=args.beta, + simpo_gamma=args.simpo_gamma, + max_length=2048, + max_prompt_length=1024, + gradient_checkpointing=True, + push_to_hub=args.push_to_hub, + hub_model_id=args.output_model_name, + hub_token=args.hub_token, + hub_strategy="end", + report_to=["wandb"], + # Mixed precision settings + bf16=True, # Use bfloat16 instead of fp16 + tf32=True, + optim="paged_adamw_32bit", # Use 32-bit optimizer + max_grad_norm=args.max_grad_norm, + warmup_ratio=args.warmup_ratio, + lr_scheduler_type=args.lr_scheduler_type, + do_eval=True, + evaluation_strategy="steps", + eval_steps=20, + ) + + trainer = CPOTrainer( + model=model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=train_dataset, + processing_class=tokenizer + ) + + trainer.train() + trainer.save_model() + + # if args.push_to_hub and args.output_model_name: + # print("saving model") + # trainer.push_to_hub(repo_id=args.output_model_name, commit_message="Final SimPO model") + # tokenizer.push_to_hub(repo_id=args.output_model_name) + + wandb.finish() + +if __name__ == "__main__": + train_simpo() \ No newline at end of file diff --git a/mcts/train_reward.py b/mcts/train_reward.py index f939d99..eacd759 100644 --- a/mcts/train_reward.py +++ b/mcts/train_reward.py @@ -14,6 +14,8 @@ EarlyStoppingCallback, ) from transformers.utils import PaddingStrategy +from huggingface_hub import HfFolder +import os import random from collections import Counter @@ -24,15 +26,16 @@ class ScriptArguments: per_device_train_batch_size: Optional[int] = field(default=8) per_device_eval_batch_size: Optional[int] = field(default=4) gradient_accumulation_steps: Optional[int] = field(default=32) - learning_rate: Optional[float] = field(default=2e-5) + learning_rate: Optional[float] = field(default=1e-5) weight_decay: Optional[float] = field(default=0.0001) model_name: Optional[str] = field(default="Qwen/Qwen2.5-0.5B") + model_revision: Optional[str] = field(default=None) bf16: Optional[bool] = field(default=True) num_train_epochs: Optional[int] = field(default=2) train_set_path: Optional[str] = field(default="rawsh/magpie-ultra-v0.1-PRM-data-base") eval_set_path: Optional[str] = field(default="rawsh/magpie-ultra-v0.1-PRM-data-base") output_path: Optional[str] = field(default="./mirrorqwen2.5-0.5b-prm-base") - output_model_name: Optional[str] = field(default="rawsh/mirrorqwen2.5-0.5b-PRM") + output_model_name: Optional[str] = field(default="rawsh/mirrorqwen2.5-0.5b-prm") gradient_checkpointing: Optional[bool] = field(default=True) optim: Optional[str] = field(default="adamw_torch_fused") lr_scheduler_type: Optional[str] = field(default="cosine") @@ -42,12 +45,17 @@ class ScriptArguments: early_stopping_patience: Optional[int] = field(default=3) early_stopping_threshold: Optional[float] = field(default=0.001) disable_binning: Optional[bool] = field(default=False) - # Add new parameters for improved checkpointing warmup_steps: Optional[int] = field(default=100) save_total_limit: Optional[int] = field(default=3) min_loss_threshold: Optional[float] = field(default=0.1) - - + hub_token: Optional[str] = field( + default=None, + metadata={"help": "HuggingFace Hub token. If not provided, will try to use the cached token."} + ) + push_to_hub: Optional[bool] = field( + default=True, + metadata={"help": "Whether to push the model to the HuggingFace Hub"} + ) def build_dataset(tokenizer, train_path, eval_path, disable_binning: bool): def tokenize(sample): @@ -163,7 +171,7 @@ def assign_bin(example): @dataclass class RewardDataCollatorWithPadding: - tokenizer: AutoTokenizer + tokenizer: PreTrainedTokenizerBase padding: Union[bool, str, PaddingStrategy] = True max_length: Optional[int] = None pad_to_multiple_of: Optional[int] = None @@ -190,6 +198,13 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: return batch def compute_metrics(eval_pred): + """ + Compute metrics for the model evaluation. + Args: + eval_pred: tuple of predictions and labels + Returns: + dict: Dictionary containing metrics + """ predictions = eval_pred.predictions.squeeze() labels = eval_pred.label_ids mse = np.mean((predictions - labels) ** 2) @@ -215,43 +230,101 @@ def compute_loss(self, model, inputs, return_outputs=False): return loss, {"rewards": rewards} return loss + def _push_to_hub(self, commit_message: Optional[str] = "End of training", **kwargs) -> str: + """Override to force push to hub by using the low-level API""" + if not self.args.push_to_hub: + return + + if not self.args.hub_token and not HfFolder.get_token(): + raise ValueError( + "No token provided and no token found in cache. " + "Please provide a token via hub_token parameter or log in using `huggingface-cli login`" + ) + + # Set token if provided + token = self.args.hub_token or HfFolder.get_token() + if token: + os.environ["HF_TOKEN"] = token + HfFolder.save_token(token) + + try: + from huggingface_hub import HfApi, create_repo + api = HfApi() + + # Create or ensure repo exists + repo_id = self.args.hub_model_id + try: + create_repo(repo_id, token=token, exist_ok=True, private=True) + print(f"Repository {repo_id} is ready") + except Exception as e: + print(f"Note: {str(e)}") + + # Save model and tokenizer locally first + local_path = os.path.join(self.args.output_dir, "for_hub_push") + os.makedirs(local_path, exist_ok=True) + self.save_model(local_path) + + # Use upload_folder with low-level API + api.upload_folder( + folder_path=local_path, + repo_id=repo_id, + token=token, + repo_type="model", + commit_message=commit_message, + revision="main", + create_pr=False, + ) + print(f"Successfully uploaded folder to {repo_id}") + + return repo_id + + except Exception as e: + print(f"Error in push_to_hub: {str(e)}") + raise e def train_reward_model( model_name=None, + model_revision=None, dataset_path=None, output_model_name=None, - disable_binning=False + disable_binning=False, + hub_token=None, + push_to_hub=True ): script_args = ScriptArguments( disable_binning=disable_binning, - warmup_steps=100, # Customize warmup period - save_total_limit=3, # Keep only last 3 checkpoints - min_loss_threshold=0.1 # Minimum loss threshold for saving + warmup_steps=100, + save_total_limit=3, + min_loss_threshold=0.1, + hub_token=hub_token, + push_to_hub=push_to_hub ) if model_name: script_args.model_name = model_name + if model_revision: + script_args.model_revision = model_revision if output_model_name: script_args.output_model_name = output_model_name if dataset_path: script_args.train_set_path = dataset_path script_args.eval_set_path = dataset_path - # Initialize tokenizer - tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, use_auth_token=True) - tokenizer.truncation_side = "left" - tokenizer.model_max_length = script_args.max_length - tokenizer.pad_token = tokenizer.eos_token + # Initialize tokenizer with better error handling + try: + tokenizer = AutoTokenizer.from_pretrained( + script_args.model_name, + use_auth_token=True if script_args.hub_token else None, + revision=script_args.model_revision + ) + tokenizer.truncation_side = "left" + tokenizer.model_max_length = script_args.max_length + tokenizer.pad_token = tokenizer.eos_token - # Build datasets - train_dataset, eval_dataset = build_dataset( - tokenizer, - script_args.train_set_path, - script_args.eval_set_path, - script_args.disable_binning - ) + except Exception as e: + raise RuntimeError(f"Failed to initialize tokenizer: {str(e)}") - # Enhanced training arguments + # Enhanced training arguments with Hub settings training_args = TrainingArguments( output_dir=script_args.output_path, learning_rate=script_args.learning_rate, @@ -273,30 +346,45 @@ def train_reward_model( warmup_ratio=0.05, report_to='wandb', torch_compile=True, - # Enhanced checkpointing settings load_best_model_at_end=True, - metric_for_best_model="mse_moving_avg", # Use moving average instead of raw MSE + metric_for_best_model="mse_moving_avg", greater_is_better=False, save_strategy="steps", - save_steps=max(100, script_args.eval_every_steps), # Minimum 100 steps + save_steps=max(100, script_args.eval_every_steps), evaluation_strategy="steps", eval_steps=max(100, script_args.eval_every_steps), save_total_limit=script_args.save_total_limit, - # Gradient clipping max_grad_norm=1.0, + # Hub-specific settings + push_to_hub=script_args.push_to_hub, + hub_model_id=script_args.output_model_name if script_args.push_to_hub else None, + hub_token=script_args.hub_token, ) - # Initialize model - model = AutoModelForSequenceClassification.from_pretrained( - script_args.model_name, - num_labels=1, - torch_dtype=torch.bfloat16, - use_flash_attention_2=True, + # Initialize model with better error handling + try: + model = AutoModelForSequenceClassification.from_pretrained( + script_args.model_name, + num_labels=1, + torch_dtype=torch.bfloat16, + use_flash_attention_2=True, + use_auth_token=True if script_args.hub_token else None, + revision=script_args.model_revision + ) + model.config.pad_token_id = model.config.eos_token_id + model.config.use_cache = not script_args.gradient_checkpointing + except Exception as e: + raise RuntimeError(f"Failed to initialize model: {str(e)}") + + # Build datasets + train_dataset, eval_dataset = build_dataset( + tokenizer, + script_args.train_set_path, + script_args.eval_set_path, + script_args.disable_binning ) - model.config.pad_token_id = model.config.eos_token_id - model.config.use_cache = not script_args.gradient_checkpointing - # Initialize trainer with improved callbacks + # Initialize trainer trainer = RewardTrainer( model=model, args=training_args, @@ -315,18 +403,38 @@ def train_reward_model( ], ) - # Train and save trainer.train() print("Saving final checkpoint") - trainer.save_model(script_args.output_path + "/final_checkpoint") - tokenizer.save_pretrained(script_args.output_path + "/final_checkpoint") + trainer.save_model(script_args.output_path) + tokenizer.save_pretrained(script_args.output_path) - # Push to Hub if specified - if script_args.output_model_name: - tokenizer.push_to_hub(script_args.output_model_name) - trainer.push_to_hub(script_args.output_model_name) + # Push to Hub with force push handling + if script_args.push_to_hub and script_args.output_model_name: + try: + print(f"Pushing model to hub: {script_args.output_model_name}") + # First try to push the tokenizer separately with force + try: + tokenizer.push_to_hub( + script_args.output_model_name, + use_auth_token=script_args.hub_token, + force=True + ) + print("Successfully pushed tokenizer to hub!") + except Exception as e: + print(f"Warning: Failed to push tokenizer: {str(e)}") + + # Then push the model with force + trainer.push_to_hub( + commit_message="Final trained model", + blocking=True, # Wait for push to complete + force=True # Force the push + ) + print("Successfully pushed model to hub!") + except Exception as e: + print(f"Error pushing to hub: {str(e)}") + print("Saving model locally anyway...") if __name__ == "__main__": - train_reward_model() + train_reward_model() \ No newline at end of file diff --git a/mcts/tree_search.py b/mcts/tree_search.py index 02288c1..bddba7f 100644 --- a/mcts/tree_search.py +++ b/mcts/tree_search.py @@ -13,16 +13,18 @@ from tqdm.asyncio import tqdm as atqdm # URLs and configuration -POLICY_URL = 'https://rawsh--vllm-qwen-ft-serve.modal.run/v1/' +# POLICY_URL = 'https://rawsh--vllm-qwen-ft-serve.modal.run/v1/' +POLICY_MODEL_NAME = 'mirrorqwen2.5-0.5b-SimPO-1' +POLICY_URL = 'https://rawsh--vllm-qwen-simpo-serve.modal.run/v1/' PRM_URL = 'https://rawsh--mirrorqwen-prm-embedder-score-output.modal.run' API_KEY = '9FF74944EED19865193F979942FB1' -CONCURRENT_MCTS_SEMAPHORE = Semaphore(20) +CONCURRENT_MCTS_SEMAPHORE = Semaphore(50) POLICY_SEMAPHORE = Semaphore(1000) PRM_SEMAPHORE = Semaphore(1000) -MAX_RETRIES = 20 # Increased from 10 -TIMEOUT = 15 # Decreased from 30 to fail faster and retry +MAX_RETRIES = 20 # Increased from 10s +TIMEOUT = 20 # Decreased from 30 to fail faster and retry # Cache decorator and retry function def async_lru_cache(maxsize=2000): @@ -42,7 +44,7 @@ async def wrapper(*args, **kwargs): async def retry_with_timeout(func, *args, **kwargs): for attempt in range(MAX_RETRIES): try: - return await asyncio.wait_for(func(*args, **kwargs), timeout=TIMEOUT) + return await asyncio.wait_for(func(*args, **kwargs), timeout=TIMEOUT * max(1, attempt / 10 )) except TimeoutError: if attempt == MAX_RETRIES - 1: raise @@ -175,7 +177,7 @@ async def get_next_action(state, client): prompt = format_state_for_policy(state) async with POLICY_SEMAPHORE: response = await client.completions.create( - model="rawsh/mirrorqwen2.5-0.5b-SFT", + model=POLICY_MODEL_NAME, prompt=prompt, max_tokens=250, stop=["\n\n"], @@ -196,7 +198,7 @@ async def is_terminal(state, correct_answer, client, session): async with POLICY_SEMAPHORE: response = await client.completions.create( - model="rawsh/mirrorqwen2.5-0.5b-SFT", + model=POLICY_MODEL_NAME, prompt=state, max_tokens=1, stop=["\n\n"], @@ -278,6 +280,7 @@ async def mcts(root_state, correct_answer, num_iterations, session, progress_tra return root, terminal_nodes + async def run_mcts(initial_state, correct_answer, num_iterations, session, progress_tracker): async with CONCURRENT_MCTS_SEMAPHORE: start_time = time.time() @@ -287,19 +290,27 @@ async def run_mcts(initial_state, correct_answer, num_iterations, session, progr best_leaf = await find_best_leaf_by_prm(root, session) terminal_paths = [] - terminal_correct_count = 0 - total_terminal_nodes = len(terminal_nodes) + answers = {} # Track answer frequencies max_prm_score = float('-inf') best_prm_path_correct = False + terminal_correct_count = 0 # Add this counter for node in terminal_nodes: score = await retry_with_timeout(evaluate_state, node, session) is_node_correct = is_correct(node, correct_answer) if is_node_correct: - terminal_correct_count += 1 + terminal_correct_count += 1 # Increment counter + + # Extract answer from the node + last_step = node.split("\n\n")[-1] + if r"\boxed{" in last_step: + answer = last_step.split(r"\boxed{")[1].split("}")[0] + answers[answer] = answers.get(answer, 0) + 1 + if score > max_prm_score: max_prm_score = score best_prm_path_correct = is_node_correct + terminal_paths.append({ "final_state": node, "score": score, @@ -308,14 +319,16 @@ async def run_mcts(initial_state, correct_answer, num_iterations, session, progr is_best_correct = is_correct(best_leaf.state, correct_answer) - # Calculate metrics using only terminal nodes - # Self-consistency based on majority voting of terminal nodes (>50% correct) - has_terminal_nodes = total_terminal_nodes > 0 - is_sc_correct = (terminal_correct_count > total_terminal_nodes / 2) if has_terminal_nodes else False - is_any_correct = (terminal_correct_count > 0) # Any-correct using terminal nodes + # Calculate SC using most common answer + has_terminal_nodes = len(terminal_nodes) > 0 + is_sc_correct = False + if has_terminal_nodes and answers: + most_common_answer = max(answers.items(), key=lambda x: x[1])[0] + is_sc_correct = any(p["correct"] and most_common_answer == p["final_state"].split(r"\boxed{")[1].split("}")[0] + for p in terminal_paths) - # Check if this question completed all iterations - is_fully_completed = total_terminal_nodes > 0 and num_iterations == progress_tracker.iterations_per_question + is_any_correct = any(p["correct"] for p in terminal_paths) + is_fully_completed = len(terminal_nodes) > 0 and num_iterations == progress_tracker.iterations_per_question result = { "question": initial_state, @@ -323,7 +336,7 @@ async def run_mcts(initial_state, correct_answer, num_iterations, session, progr "statistics": { "num_iterations": num_iterations, "execution_time": end_time - start_time, - "total_terminal_nodes": total_terminal_nodes, + "total_terminal_nodes": len(terminal_nodes), # Use len() directly "correct_terminal_nodes": terminal_correct_count, "self_consistency_correct": is_sc_correct, "any_correct": is_any_correct, @@ -344,29 +357,57 @@ async def run_mcts(initial_state, correct_answer, num_iterations, session, progr async def main(): # Set random seed for reproducibility - random.seed(42) + # random.seed(42) # st 0 + # random.seed(4242) # st 1 + random.seed(424242) # st 2 def process(example): example["answer"] = example["answer"].split("\n#### ")[-1].strip() return example - gsm8k = load_dataset("openai/gsm8k", "main", split="test").shuffle(seed=42) + # gsm8k = load_dataset("openai/gsm8k", "main", split="test").shuffle(seed=42) + gsm8k = load_dataset("openai/gsm8k", "main", split="train").shuffle(seed=42) gsm8k = gsm8k.map(process, num_proc=24) initial_states = [(example["question"], example["answer"]) for example in gsm8k] - - num_iterations = 10 + # initial_states = random.sample(initial_states, 200) + + # SAMPLE 200 QUESTIONS - SELF TRAINING + initial_states = random.sample(initial_states, 200) + num_iterations = 100 - # Initialize progress tracker - progress_tracker = MCTSProgress(len(initial_states), num_iterations) + print("cold starting policy vllm + prm api") + + # warm up the chat API + client = AsyncOpenAI(base_url=POLICY_URL, api_key=API_KEY) + completion = await client.completions.create( + model=POLICY_MODEL_NAME, + prompt="TEST", + max_tokens=1, + stop=["\n\n"], + temperature=0.3, + logprobs=20, + ) + assert(len(completion.choices) == 1) + print("warmed up vllm") async with aiohttp.ClientSession() as session: + # warm up PRM api + async with session.post(PRM_URL, json={"prompt": "TEST"}) as response: + result = await response.json() + + assert('score' in result) + print("warmed up PRM api") + + # Initialize progress tracker + progress_tracker = MCTSProgress(len(initial_states), num_iterations) + tasks = [] for state, answer in initial_states: tasks.append(run_mcts(state, answer, num_iterations, session, progress_tracker)) results = await asyncio.gather(*tasks) - progress_tracker.close() + progress_tracker.close() # Calculate and print final statistics total_questions = len(results) diff --git a/modal_prm_reward.py b/modal_prm_reward.py index 7de4e30..5647529 100644 --- a/modal_prm_reward.py +++ b/modal_prm_reward.py @@ -25,15 +25,18 @@ def __init__(self): self.batched = batched def create_batch_processor(self, pipeline_func): - @self.batched.dynamically(batch_size=256, timeout_ms=100.0, small_batch_threshold=4) + @self.batched.dynamically(batch_size=256, timeout_ms=200.0, small_batch_threshold=4) def _process_batch(prompts: List[str]) -> List[Dict]: return pipeline_func(prompts) return _process_batch @app.cls( + # gpu=modal.gpu.T4(), gpu=modal.gpu.A10G(), # gpu=modal.gpu.H100(), + # gpu=modal.gpu.A100(), container_idle_timeout=120, + # allow_concurrent_inputs=1000, allow_concurrent_inputs=1000, secrets=[ modal.Secret.from_name("hf-token"), @@ -41,8 +44,9 @@ def _process_batch(prompts: List[str]) -> List[Dict]: ) class Embedder: model_id = "rawsh/mirrorqwen2.5-0.5b-prm" - # revision = "a1cd3547343bab37ff61fd248ef46b779d5a8dfa" # base - revision = "3ad692bde328cddbfd45666cb6f7307430cac181" + # revision = "894341fbd81d0c1abdd98b4e0630de932aa63c6f" # base + # revision = "42e07d1b708282ac2aae338050d8116f8c69398d" # st0 + revision = "65f4a7601dffacc40e0ef7fa4733d346c926bd18" # st1 device = "cuda" print(model_id) diff --git a/modal_train_policy_simpo.py b/modal_train_policy_simpo.py new file mode 100644 index 0000000..0d7cbe6 --- /dev/null +++ b/modal_train_policy_simpo.py @@ -0,0 +1,81 @@ +import modal +import sys +import traceback + +# Define CUDA specifications +cuda_version = "12.4.0" +flavor = "devel" +operating_sys = "ubuntu22.04" +tag = f"{cuda_version}-{flavor}-{operating_sys}" + +# Create Modal image with all necessary dependencies +image = ( + modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.11") + .apt_install("git") + .pip_install("torch") + .pip_install("transformers") + .pip_install("accelerate") + .pip_install("datasets") + .pip_install("wandb") + .pip_install("trl>=0.7.6") + .pip_install("huggingface_hub") + .pip_install("bitsandbytes") +) + +with image.imports(): + from mcts.train_policy_simpo import train_simpo # Import from our new simplified script + +# Create Modal app +app = modal.App("train-policy-simpo", image=image) + +@app.function( + cpu=4.0, + gpu=modal.gpu.H100(count=1), + timeout=24 * 60 * 60, + memory=32768, + secrets=[ + modal.Secret.from_name("hf-token"), + modal.Secret.from_name("wandb-token") + ], +) +def train_policy_simpo(): + import os + from huggingface_hub import HfFolder + import wandb + + try: + # Set up HuggingFace token + hf_token = os.environ["HF_TOKEN"] + HfFolder.save_token(hf_token) + + # Set up Weights & Biases + wandb.login(key=os.environ["WANDB_API_KEY"]) + + # Run training with specified parameters + train_simpo( + # model_name="rawsh/mirrorqwen2.5-0.5b-SFT", + model_name="rawsh/mirrorqwen2.5-0.5b-SimPO-0", + dataset_name="rawsh/mirrorqwen2.5-0.5B-gsm8k-policy-data-ST-1", + output_model_name="rawsh/mirrorqwen2.5-0.5b-SimPO-1", + hub_token=hf_token + ) + except Exception as e: + print(f"Error during training: {str(e)}", file=sys.stderr) + print("Traceback:", file=sys.stderr) + traceback.print_exc(file=sys.stderr) + # Make sure to finish wandb run even on error + try: + wandb.finish() + except: + pass + raise e + +@app.local_entrypoint() +def main(): + print("Starting full model SimPO training on Modal...") + try: + train_policy_simpo.remote() + print("Training job submitted to Modal. Check W&B dashboard for training progress.") + except Exception as e: + print(f"Error in training job: {str(e)}") + sys.exit(1) \ No newline at end of file diff --git a/modal_train_prm_init.py b/modal_train_prm_init.py index 1755091..745e867 100644 --- a/modal_train_prm_init.py +++ b/modal_train_prm_init.py @@ -34,8 +34,8 @@ @app.function( cpu=2.0, - gpu=modal.gpu.A10G(), - # gpu=modal.gpu.H100(), + # gpu=modal.gpu.A10G(), + gpu=modal.gpu.H100(), # gpu=modal.gpu.A100(count=4, size="40GB"), # gpu=modal.gpu.A100(size="40GB"), timeout=20 * HOURS, @@ -46,7 +46,13 @@ volumes={"/out": vol}, ) def train_reward_model_upload_to_hf(): - train_reward_model() + train_reward_model( + # add revision + model_name="Qwen/Qwen2.5-0.5B", + dataset_path="rawsh/magpie-ultra-v0.1-PRM-data-base", + output_model_name="rawsh/mirrorqwen2.5-0.5b-prm", + disable_binning=False + ) @app.local_entrypoint() def main(): diff --git a/modal_train_prm_st.py b/modal_train_prm_st.py index a68a412..da877c3 100644 --- a/modal_train_prm_st.py +++ b/modal_train_prm_st.py @@ -47,9 +47,12 @@ ) def train_reward_model_upload_to_hf(): train_reward_model( + # add revision model_name="rawsh/mirrorqwen2.5-0.5b-prm", - dataset_path="rawsh/magpie-ultra-v0.1-PRM-data-ST-0", - output_model_name="rawsh/mirrorqwen2.5-0.5b-PRM-ST-0", + # model_revision="aed1bcf7d3d984272e329c3843f9c5fd0dfe5ca5", # base + model_revision="42e07d1b708282ac2aae338050d8116f8c69398d", # st0 + dataset_path="rawsh/mirrorqwen2.5-0.5B-gsm8k-PRM-data-ST-1", + output_model_name="rawsh/mirrorqwen2.5-0.5b-prm", disable_binning=True ) diff --git a/modal_vllm.py b/modal_vllm.py index 190a067..7aab0c3 100644 --- a/modal_vllm.py +++ b/modal_vllm.py @@ -1,4 +1,6 @@ import modal +import asyncio +from contextlib import asynccontextmanager def download_model_to_image(model_dir, model_name, model_revision): import os @@ -6,7 +8,6 @@ def download_model_to_image(model_dir, model_name, model_revision): from transformers.utils import move_cache os.makedirs(model_dir, exist_ok=True) - snapshot_download( model_name, revision=model_revision, @@ -16,27 +17,19 @@ def download_model_to_image(model_dir, model_name, model_revision): move_cache() MODEL_DIR = "/qwen" -MODEL_NAME = "rawsh/mirrorqwen2.5-0.5b-SFT" -MODEL_REVISION = "1f75c1204888cc912ad0b186c5b7620235246ffa" - -# MODEL_DIR = "/smollm" -# MODEL_NAME = "HuggingFaceTB/SmolLM2-135M-Instruct" -# MODEL_REVISION = "7e27bd9f95328f0f3b08261d1252705110c806f8" - -# MODEL_DIR = "/qwen" -# MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct" -# MODEL_REVISION = "a8b602d9dafd3a75d382e62757d83d89fca3be54" - -# MODEL_DIR = "/gemma" -# MODEL_NAME = "rawsh/mirrorgemma-2-2b-SFT" -# MODEL_REVISION = "0ec8c2eaead95160a9f908cd59f254bdace496bd" +# # st0 +# MODEL_NAME = "rawsh/mirrorqwen2.5-0.5b-SimPO-0" +# MODEL_REVISION = "c699a3f7e82a805d6a4b158b033c5d7919230dd1" +# st1 +MODEL_NAME = "rawsh/mirrorqwen2.5-0.5b-SimPO-1" +MODEL_REVISION = "4ba061377ace8d0fb15802aaf943b4184420ea7d" vllm_image = ( modal.Image.debian_slim(python_version="3.10") .pip_install( - "vllm==0.6.1.post2", + "vllm==0.6.2", "torch==2.4.0", - "transformers==4.44.2", + "transformers>=4.45", "ray==2.36.0", "hf-transfer==0.1.8", "huggingface_hub==0.25.0", @@ -55,30 +48,42 @@ def download_model_to_image(model_dir, model_name, model_revision): .env({"VLLM_ALLOW_LONG_MAX_MODEL_LEN": "1"}) ) +app = modal.App("vllm-qwen-simpo") -# app = modal.App("vllm-smollm") -app = modal.App("vllm-qwen-ft") - -N_GPU = 1 # tip: for best results, first upgrade to more powerful GPUs, and only then increase GPU count - -MINUTES = 60 # seconds +N_GPU = 1 +MINUTES = 60 HOURS = 60 * MINUTES -# key: 9FF74944EED19865193F979942FB1 +async def get_model_config(engine): + try: + return await engine.get_model_config() + except Exception as e: + print(f"Error getting model config: {e}") + raise + +@asynccontextmanager +async def lifespan(app): + # Startup + try: + await asyncio.sleep(0) # Give chance for event loop to start + yield + finally: + # Shutdown: Cancel all pending tasks + tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] + for task in tasks: + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) @app.function( image=vllm_image, - # gpu=modal.gpu.H100(count=N_GPU), - # gpu=modal.gpu.A100(count=N_GPU, size="80GB"), gpu=modal.gpu.A10G(count=N_GPU), + # gpu=modal.gpu.T4(), + # gpu=modal.gpu.A100(), container_idle_timeout=2 * MINUTES, timeout=20 * MINUTES, + # allow_concurrent_inputs=1000, allow_concurrent_inputs=1000, - secrets=[ - modal.Secret.from_name("vllm-token"), - # modal.Secret.from_name("hf-token"), - ] - # volumes={MODELS_DIR: volume}, + secrets=[modal.Secret.from_name("vllm-token")] ) @modal.asgi_app() def serve(): @@ -89,40 +94,18 @@ def serve(): from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.serving_chat import OpenAIServingChat - from vllm.entrypoints.openai.serving_completion import ( - OpenAIServingCompletion, - ) + from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion + from vllm.entrypoints.openai.serving_engine import BaseModelPath from vllm.usage.usage_lib import UsageContext - def get_model_config(engine): - import asyncio - - try: # adapted from vLLM source -- https://github.com/vllm-project/vllm/blob/507ef787d85dec24490069ffceacbd6b161f4f72/vllm/entrypoints/openai/api_server.py#L235C1-L247C1 - event_loop = asyncio.get_running_loop() - except RuntimeError: - event_loop = None - - if event_loop is not None and event_loop.is_running(): - # If the current is instanced by Ray Serve, - # there is already a running event loop - model_config = event_loop.run_until_complete(engine.get_model_config()) - else: - # When using single vLLM without engine_use_ray - model_config = asyncio.run(engine.get_model_config()) - - return model_config - - # volume.reload() # ensure we have the latest version of the weights - - # create a fastAPI app that uses vLLM's OpenAI-compatible router web_app = fastapi.FastAPI( title=f"OpenAI-compatible {MODEL_NAME} server", description="Run an OpenAI-compatible LLM server with vLLM on modal.com", version="0.0.1", docs_url="/docs", + lifespan=lifespan ) - # security: CORS middleware for external requests http_bearer = fastapi.security.HTTPBearer( scheme_name="Bearer Token", description="See code for authentication details.", @@ -135,7 +118,6 @@ def get_model_config(engine): allow_headers=["*"], ) - # security: inject dependency on authed routes TOKEN = os.environ["API_TOKEN"] async def is_authenticated(api_key: str = fastapi.Security(http_bearer)): if api_key.credentials != TOKEN: @@ -146,20 +128,18 @@ async def is_authenticated(api_key: str = fastapi.Security(http_bearer)): return {"username": "authenticated_user"} router = fastapi.APIRouter(dependencies=[fastapi.Depends(is_authenticated)]) - + # wrap vllm's router in auth router router.include_router(api_server.router) # add authed vllm to our fastAPI app web_app.include_router(router) engine_args = AsyncEngineArgs( - # model=MODELS_DIR + "/" + MODEL_NAME, model=MODEL_DIR, tensor_parallel_size=N_GPU, gpu_memory_utilization=0.90, max_model_len=8096, - # enforce_eager=True, - enforce_eager=False, # capture the graph for faster inference, but slower cold starts (30s > 20s) + enforce_eager=False, enable_prefix_caching=True ) @@ -167,28 +147,25 @@ async def is_authenticated(api_key: str = fastapi.Security(http_bearer)): engine_args, usage_context=UsageContext.OPENAI_API_SERVER ) - model_config = get_model_config(engine) + async def setup_engine(): + model_config = await get_model_config(engine) + return model_config + # Use asyncio.run to properly handle the async setup + model_config = asyncio.run(setup_engine()) request_logger = RequestLogger(max_log_len=2048) - api_server.openai_serving_chat = OpenAIServingChat( - engine, - model_config=model_config, - served_model_names=[MODEL_NAME], - chat_template=None, - response_role="assistant", - lora_modules=[], - prompt_adapters=[], - request_logger=request_logger, - ) - api_server.openai_serving_completion = OpenAIServingCompletion( + base_model_paths = [ + BaseModelPath(name=MODEL_NAME.split("/")[1], model_path=MODEL_NAME) + ] + + api_server.completion = lambda s: OpenAIServingCompletion( engine, model_config=model_config, - served_model_names=[MODEL_NAME], + base_model_paths=base_model_paths, lora_modules=[], prompt_adapters=[], request_logger=request_logger, ) - return web_app - + return web_app \ No newline at end of file