
Commit

updoot
rawsh authored Nov 16, 2024
1 parent de90584 commit 2807e0f
Showing 15 changed files with 1,129 additions and 99 deletions.
56 changes: 24 additions & 32 deletions mcts/process_results.py
@@ -133,7 +133,7 @@ def analyze_question(self, question: dict) -> QuestionAnalysis:
def get_paired_examples(
self,
analyses: List[QuestionAnalysis],
max_pairs: int = 10000,
max_pairs: int = 20000,
top_n_correct: int = 10,
top_n_incorrect: int = 10
) -> List[Dict[str, Any]]:
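Note: the default max_pairs doubles from 10000 to 20000 while the per-question top-N caps stay at 10. The body of get_paired_examples is not part of this diff; the sketch below is only a hypothetical illustration of how top-scoring correct and incorrect paths could be paired into preference examples, with made-up attribute names (correct_paths, incorrect_paths, score, text).

# Hypothetical sketch, not the committed implementation.
def pair_examples(analysis, top_n_correct=10, top_n_incorrect=10):
    best_correct = sorted(analysis.correct_paths, key=lambda p: p.score, reverse=True)[:top_n_correct]
    best_incorrect = sorted(analysis.incorrect_paths, key=lambda p: p.score, reverse=True)[:top_n_incorrect]
    return [
        {"prompt": analysis.question_text, "chosen": c.text, "rejected": r.text}
        for c in best_correct
        for r in best_incorrect
    ]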
@@ -261,7 +261,6 @@ def get_paired_examples(
def generate_prm_training_data(self, analyses: List[QuestionAnalysis]) -> List[Dict[str, Any]]:
"""Generate training data for Process Reward Model (PRM) from MCTS paths."""
prm_examples = []
seen_examples = set() # Track unique (question, steps) combinations
original_correct_lengths = []
original_incorrect_lengths = []

@@ -280,19 +279,6 @@ def generate_prm_training_data(self, analyses: List[QuestionAnalysis]) -> List[Dict[str, Any]]:

for k, step in enumerate(path.steps, 1):
partial_steps = path.steps[:k]

# Create unique key based on question and step sequence
example_key = (
hash(analysis.question_text),
hash(str(partial_steps))
)

# Skip if we've seen this exact example
if example_key in seen_examples:
continue

seen_examples.add(example_key)

m_k = K - k
r_s_k = 0
w_s_k = (1 - v_prev) / (m_k + 1) * (1 - 2 * r_s_k)
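For correct paths this hunk weights step k by the remaining depth: with K total steps, m_k = K - k steps remain and r_s_k = 0, so w_s_k reduces to (1 - v_prev) / (m_k + 1). The v_k update itself is outside the visible hunk; the sketch below assumes a simple accumulation v_k = v_prev + w_s_k, under which a correct path's value climbs to exactly 1.0 at its final step.

# Worked example for a correct 4-step path (assumption: v_k = v_prev + w_s_k).
K, v_prev = 4, 0.0
for k in range(1, K + 1):
    m_k = K - k                                        # steps remaining after step k
    r_s_k = 0                                          # no penalty on correct paths
    w_s_k = (1 - v_prev) / (m_k + 1) * (1 - 2 * r_s_k)
    v_prev = v_prev + w_s_k                            # 0.25, 0.50, 0.75, 1.00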
@@ -323,19 +309,6 @@ def generate_prm_training_data(self, analyses: List[QuestionAnalysis]) -> List[Dict[str, Any]]:

for k, step in enumerate(path.steps, 1):
partial_steps = path.steps[:k]

# Create unique key based on question and step sequence
example_key = (
hash(analysis.question_text),
hash(str(partial_steps))
)

# Skip if we've seen this exact example
if example_key in seen_examples:
continue

seen_examples.add(example_key)

penalize = k == K
m_k = K - k if not penalize else K - k + 1
r_s_k = 0 if not penalize else 1
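On incorrect paths only the terminal step is penalized: penalize is true at k == K, which bumps m_k by one and sets r_s_k = 1, flipping the sign of (1 - 2 * r_s_k). The w_s_k line for this branch falls outside the visible hunk; assuming it reuses the expression shown above for correct paths, the final step gets a negative weight that pushes the value estimate down instead of up.

# Hedged illustration of the terminal step of an incorrect path (example values
# K = 4, k = 4, v_prev = 0.75); reusing the correct-branch w_s_k expression is
# an assumption, not part of the visible hunk.
K, k, v_prev = 4, 4, 0.75
penalize = k == K                                      # True on the last step
m_k = K - k if not penalize else K - k + 1             # -> 1
r_s_k = 0 if not penalize else 1                       # -> 1
w_s_k = (1 - v_prev) / (m_k + 1) * (1 - 2 * r_s_k)     # -> -0.125 (negative weight)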
@@ -356,15 +329,34 @@ def generate_prm_training_data(self, analyses: List[QuestionAnalysis]) -> List[Dict[str, Any]]:
})
v_prev = v_k

# Print statistics about duplicates avoided
print(f"\nTotal examples generated: {len(prm_examples)}")
print(f"Unique (question, steps) combinations: {len(seen_examples)}")
print(f"Duplicates avoided: {len(seen_examples) - len(prm_examples)}")
# Record length statistics
if original_correct_lengths:
print("\nOriginal Path Length Statistics:")
print(f"Correct paths mean length: {np.mean(original_correct_lengths):.1f}{np.std(original_correct_lengths):.1f})")
if original_incorrect_lengths:
print(f"Incorrect paths mean length: {np.mean(original_incorrect_lengths):.1f}{np.std(original_incorrect_lengths):.1f})")

# Print complete path statistics
complete_correct = [ex for ex in prm_examples if ex["metadata"]["is_correct"] and ex["metadata"]["is_complete"]]
complete_incorrect = [ex for ex in prm_examples if not ex["metadata"]["is_correct"] and ex["metadata"]["is_complete"]]

print("\nComplete Path Statistics:")
print(f"Complete correct paths: {len(complete_correct)}")
print(f"Complete incorrect paths: {len(complete_incorrect)}")

if complete_correct:
print(f"Complete correct mean length: {np.mean([ex['metadata']['path_length'] for ex in complete_correct]):.1f}")
if complete_incorrect:
print(f"Complete incorrect mean length: {np.mean([ex['metadata']['path_length'] for ex in complete_incorrect]):.1f}")

return prm_examples

def main():
analyzer = MathReasoningAnalyzer('mcts_results.jsonl')
# analyzer = MathReasoningAnalyzer('mcts_results.jsonl.st1_orpo.bak')
# analyzer = MathReasoningAnalyzer('mcts_results.jsonl.st2_orpo.bak')
# analyzer = MathReasoningAnalyzer('mcts_results.jsonl.st3_orpo.bak')

# analyzer = MathReasoningAnalyzer('mcts_results.jsonl.st0.bak')
# analyzer = MathReasoningAnalyzer('mcts_results.jsonl.st1.bak')
# analyzer = MathReasoningAnalyzer('mcts_results.jsonl.st2-v1.bak')
5 changes: 3 additions & 2 deletions mcts/simple_sample.py
@@ -19,7 +19,7 @@
MAX_RETRIES = 5
TIMEOUT = 20
MAX_CONCURRENT = 100
SAMPLES_PER_QUESTION = 10 # Default to single sample mode, override with CLI arg
SAMPLES_PER_QUESTION = 1 # Default to single sample mode, override with CLI arg

# Cache decorator for PRM scores
def async_lru_cache(maxsize=2000):
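The decorator body is not shown in this hunk. A minimal async-aware LRU cache keyed on the call arguments might look like the sketch below; it is an assumption, not the committed code.

from collections import OrderedDict
from functools import wraps

def async_lru_cache(maxsize=2000):
    def decorator(fn):
        cache = OrderedDict()
        @wraps(fn)
        async def wrapper(*args, **kwargs):
            key = (args, tuple(sorted(kwargs.items())))
            if key in cache:
                cache.move_to_end(key)               # mark as recently used
                return cache[key]
            result = await fn(*args, **kwargs)
            cache[key] = result
            if len(cache) > maxsize:
                cache.popitem(last=False)            # evict least-recently-used entry
            return result
        return wrapper
    return decorator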
@@ -108,7 +108,8 @@ async def generate_completion(
"""Generate a single completion."""
async with semaphore:
response = await client.completions.create(
model="mirrorqwen2.5-0.5b-SimPO-3",
# model="mirrorqwen2.5-0.5b-SimPO-3",
model="MetaMath-Qwen2.5-0.5b",
prompt=question,
max_tokens=1500,
temperature=0.8
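The completion call now targets MetaMath-Qwen2.5-0.5b, the SFT checkpoint produced by the new train_policy_sft_metamath.py below. The client and semaphore construction are outside this diff; a plausible setup against an OpenAI-compatible endpoint (for example a local vLLM server; the base_url is a placeholder) is sketched here.

import asyncio
from openai import AsyncOpenAI

# Placeholder base_url; any OpenAI-compatible server hosting the model works.
client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
semaphore = asyncio.Semaphore(100)  # matches MAX_CONCURRENT = 100 above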
7 changes: 4 additions & 3 deletions mcts/train_policy_orpo.py
@@ -52,10 +52,11 @@ def train_orpo(
per_device_train_batch_size=8,
gradient_accumulation_steps=8,
# learning_rate=5e-7,
learning_rate=8e-6,
# learning_rate=8e-6,
lr_scheduler_type="linear",
beta=0.1,
# learning_rate=5e-6,
learning_rate=3e-6,
# max_steps
max_length=2048,
max_prompt_length=1024,
gradient_checkpointing=True,
@@ -71,7 +72,7 @@
# lr_scheduler_type="cosine",
do_eval=True,
evaluation_strategy="steps",
eval_steps=20,
eval_steps=10,
remove_unused_columns=False,
logging_steps=10,
logging_first_step=True
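Net effect of this file's changes: the ORPO learning rate drops from 8e-6 to 3e-6 and evaluation runs every 10 steps instead of 20. The model and dataset setup are outside the visible hunks; the sketch below only shows how arguments like these are typically passed to trl's ORPOConfig/ORPOTrainer, with placeholder names (model, tokenizer, train_ds, eval_ds).

from trl import ORPOConfig, ORPOTrainer

config = ORPOConfig(
    output_dir="orpo-out",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=8,
    learning_rate=3e-6,
    lr_scheduler_type="linear",
    beta=0.1,                          # ORPO odds-ratio weight
    max_length=2048,
    max_prompt_length=1024,
    gradient_checkpointing=True,
    do_eval=True,
    evaluation_strategy="steps",
    eval_steps=10,
    logging_steps=10,
)
trainer = ORPOTrainer(model=model, args=config, train_dataset=train_ds,
                      eval_dataset=eval_ds, tokenizer=tokenizer)
trainer.train()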
148 changes: 148 additions & 0 deletions mcts/train_policy_sft_metamath.py
@@ -0,0 +1,148 @@
from unsloth import FastLanguageModel
import torch
from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported
from unsloth import UnslothTrainer, UnslothTrainingArguments
from datasets import load_dataset
from unsloth.chat_templates import get_chat_template

# Constants
SEED = 42
max_seq_length = 8192
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = False

first = True

def format_answer(response):
global first
"""Extract answer from #### pattern and format response."""
# Split at #### and get everything before it
parts = response.split('####')
if len(parts) < 2:
return None


solution = "\n\n".join(parts[0].strip().split("\n"))
answer = parts[1].split('The answer is:')[0].strip()

if (first):
print(solution)
print(answer)
first = False

return f"{solution}\n\nThe final answer is: $\\boxed{{{answer}}}$"

def train_sft():
# Load model and tokenizer
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = "Qwen/Qwen2.5-0.5B",
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
)

# Set up chat template
tokenizer = get_chat_template(
tokenizer,
chat_template = "qwen-2.5",
)

# Configure PEFT
model = FastLanguageModel.get_peft_model(
model,
r = 128,
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",
"embed_tokens", "lm_head"],
lora_alpha = 32,
lora_dropout = 0,
bias = "none",
use_gradient_checkpointing = "unsloth",
random_state = 3407,
use_rslora = True,
loftq_config = None,
)

# Load dataset
ds = load_dataset("meta-math/MetaMathQA")
train_ds = ds['train']

# Format prompts
def formatting_prompts_func(examples):
conversations = []
for query, response in zip(examples['query'], examples['response']):
formatted_response = format_answer(response)
if formatted_response is None:
continue

conversation = [
{"role": "user", "content": query},
{"role": "assistant", "content": formatted_response}
]
conversations.append(conversation)

texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)
for convo in conversations]
return {"text": texts}

# <|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat is the total cost of purchasing equipment for all sixteen players on the football team, considering that each player requires a $25 jersey, a $15.20 pair of shorts, and a pair of socks priced at $6.80?<|im_end|>\n<|im_start|>assistant\nEach player requires a $25 jersey, a $15.20 pair of shorts, and a pair of socks priced at $6.80.\n\nSo the total cost for each player is $25 + $15.20 + $6.80 = $47.\n\nSince there are sixteen players on the football team, the total cost for all of them is 16 * $47 = $752.\n\nThe final answer is: $\\boxed{752}$<|im_end|>\n'

# Process dataset
formatted_dataset = train_ds.map(
formatting_prompts_func,
batched=True,
remove_columns=train_ds.column_names
)
print(len(formatted_dataset))
print(formatted_dataset[0])

# Configure trainer
trainer = UnslothTrainer(
model = model,
tokenizer = tokenizer,
train_dataset = formatted_dataset,
dataset_text_field = "text",
max_seq_length = max_seq_length,
dataset_num_proc = 8,
packing = True,
args = UnslothTrainingArguments(
per_device_train_batch_size = 8,
gradient_accumulation_steps = 8,
warmup_ratio = 0.1,
num_train_epochs = 3,
# learning_rate = 5e-6,
# embedding_learning_rate = 5e-7,
learning_rate = 8e-6,
embedding_learning_rate = 1e-6,
fp16 = not is_bfloat16_supported(),
bf16 = is_bfloat16_supported(),
logging_steps = 1,
optim = "adamw_torch_fused",
weight_decay = 0.01,
lr_scheduler_type = "cosine",
seed = 3407,
output_dir = "outputs",
),
)

# Print GPU stats
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

# Train
trainer_stats = trainer.train()

# Save model
model.push_to_hub_merged(
"rawsh/MetaMath-Qwen2.5-0.5b",
tokenizer,
save_method = "merged_16bit"
)

if __name__ == "__main__":
train_sft()
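The merged checkpoint pushed above ("rawsh/MetaMath-Qwen2.5-0.5b") is the model simple_sample.py now queries under the served name "MetaMath-Qwen2.5-0.5b". A minimal way to load the pushed weights directly, shown only as a usage sketch:

from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("rawsh/MetaMath-Qwen2.5-0.5b")
model = AutoModelForCausalLM.from_pretrained("rawsh/MetaMath-Qwen2.5-0.5b")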
