Skip to content

Commit

Permalink
tweak
Browse files Browse the repository at this point in the history
  • Loading branch information
rawsh authored Nov 6, 2024
1 parent 0fc3e40 commit d59673b
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 32 deletions.
2 changes: 1 addition & 1 deletion mcts/train_reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class ScriptArguments:
per_device_train_batch_size: Optional[int] = field(default=8)
per_device_eval_batch_size: Optional[int] = field(default=4)
gradient_accumulation_steps: Optional[int] = field(default=32)
learning_rate: Optional[float] = field(default=8e-6)
learning_rate: Optional[float] = field(default=2e-5)
weight_decay: Optional[float] = field(default=0.0001)
model_name: Optional[str] = field(default="Qwen/Qwen2.5-0.5B")
bf16: Optional[bool] = field(default=True)
Expand Down
63 changes: 32 additions & 31 deletions mcts/tree_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
PRM_URL = 'https://rawsh--mirrorqwen-prm-embedder-score-output.modal.run'
API_KEY = '9FF74944EED19865193F979942FB1'

CONCURRENT_MCTS_SEMAPHORE = Semaphore(50)
CONCURRENT_MCTS_SEMAPHORE = Semaphore(20)
POLICY_SEMAPHORE = Semaphore(1000)
PRM_SEMAPHORE = Semaphore(1000)

MAX_RETRIES = 25 # Increased from 10
TIMEOUT = 20 # Decreased from 30 to fail faster and retry
MAX_RETRIES = 20 # Increased from 10
TIMEOUT = 15 # Decreased from 30 to fail faster and retry

# Cache decorator and retry function
def async_lru_cache(maxsize=2000):
Expand Down Expand Up @@ -71,24 +71,25 @@ class MCTSProgress:
def __init__(self, total_questions, iterations_per_question):
self.total_questions = total_questions
self.total_iterations = total_questions * iterations_per_question
self.completed_questions = 0
self.iterations_per_question = iterations_per_question
self.completed_iterations = 0
self.correct_sc = 0 # Self-consistency correct count
self.correct_any = 0 # Any-correct count
self.correct_best = 0 # Best PRM path correct count
self.total_actions = 0 # Global action counter
self.total_terminal_questions = 0 # Questions with at least one terminal node
self.questions_with_terminal = 0 # Questions with at least one terminal path
self.fully_completed_questions = 0 # Questions that completed all iterations

# Single progress bar with dynamic description
self.pbar = tqdm(total=self.total_iterations,
desc=self.get_progress_description())

def get_progress_description(self):
completed_pct = (self.completed_iterations / self.total_iterations) * 100
sc_pct = (self.correct_sc / max(1, self.total_terminal_questions)) * 100
any_pct = (self.correct_any / max(1, self.total_terminal_questions)) * 100
best_pct = (self.correct_best / max(1, self.total_terminal_questions)) * 100
return (f"#Q: {self.completed_questions}/{self.total_questions} | "
sc_pct = (self.correct_sc / max(1, self.fully_completed_questions)) * 100
any_pct = (self.correct_any / max(1, self.fully_completed_questions)) * 100
best_pct = (self.correct_best / max(1, self.fully_completed_questions)) * 100
q_pct = (self.questions_with_terminal / self.total_questions) * 100
return (f"#Q ({self.questions_with_terminal}/{self.total_questions}): {q_pct:.0f}% | "
f"SC: {sc_pct:.1f}% | "
f"ANY: {any_pct:.1f}% | "
f"BEST: {best_pct:.1f}% | "
Expand All @@ -97,12 +98,13 @@ def get_progress_description(self):
def increment_iteration(self):
self.completed_iterations += 1
self.pbar.update(1)
self.pbar.set_description(self.get_progress_description())
# No need to update description here

def complete_question(self, is_sc_correct, is_any_correct, is_best_correct, has_terminal_nodes):
self.completed_questions += 1
def complete_question(self, is_sc_correct, is_any_correct, is_best_correct, is_fully_completed, has_terminal_nodes):
if has_terminal_nodes:
self.total_terminal_questions += 1
self.questions_with_terminal += 1
if is_fully_completed:
self.fully_completed_questions += 1
if is_sc_correct:
self.correct_sc += 1
if is_any_correct:
Expand All @@ -113,16 +115,16 @@ def complete_question(self, is_sc_correct, is_any_correct, is_best_correct, has_

def close(self):
# Print final statistics
if self.total_terminal_questions > 0:
sc_pct = (self.correct_sc / self.total_terminal_questions) * 100
any_pct = (self.correct_any / self.total_terminal_questions) * 100
best_pct = (self.correct_best / self.total_terminal_questions) * 100
if self.fully_completed_questions > 0:
sc_pct = (self.correct_sc / self.fully_completed_questions) * 100
any_pct = (self.correct_any / self.fully_completed_questions) * 100
best_pct = (self.correct_best / self.fully_completed_questions) * 100
print(f"\nFinal Results:")
print(f"Total Questions Processed: {self.completed_questions}")
print(f"Questions with Terminal Nodes: {self.total_terminal_questions}")
print(f"Self-Consistency Accuracy: {sc_pct:.2f}% ({self.correct_sc}/{self.total_terminal_questions})")
print(f"Any-Correct Accuracy: {any_pct:.2f}% ({self.correct_any}/{self.total_terminal_questions})")
print(f"Best-Path Accuracy: {best_pct:.2f}% ({self.correct_best}/{self.total_terminal_questions})")
print(f"Questions with Terminal Paths: {self.questions_with_terminal}")
print(f"Fully Completed Questions: {self.fully_completed_questions}")
print(f"Self-Consistency Accuracy: {sc_pct:.2f}% ({self.correct_sc}/{self.fully_completed_questions})")
print(f"Any-Correct Accuracy: {any_pct:.2f}% ({self.correct_any}/{self.fully_completed_questions})")
print(f"Best-Path Accuracy: {best_pct:.2f}% ({self.correct_best}/{self.fully_completed_questions})")
print(f"Total Actions Taken: {self.total_actions}")
self.pbar.close()

Expand Down Expand Up @@ -312,6 +314,9 @@ async def run_mcts(initial_state, correct_answer, num_iterations, session, progr
is_sc_correct = (terminal_correct_count > total_terminal_nodes / 2) if has_terminal_nodes else False
is_any_correct = (terminal_correct_count > 0) # Any-correct using terminal nodes

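# Mark the question as fully completed when it produced at least one terminal node and was run with the tracker's configured per-question iteration count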
is_fully_completed = total_terminal_nodes > 0 and num_iterations == progress_tracker.iterations_per_question

result = {
"question": initial_state,
"correct_answer": correct_answer,
Expand All @@ -323,7 +328,8 @@ async def run_mcts(initial_state, correct_answer, num_iterations, session, progr
"self_consistency_correct": is_sc_correct,
"any_correct": is_any_correct,
"has_terminal_nodes": has_terminal_nodes,
"best_prm_path_correct": best_prm_path_correct
"best_prm_path_correct": best_prm_path_correct,
"fully_completed": is_fully_completed
},
"best_path": {
"final_state": best_leaf.state,
Expand All @@ -333,7 +339,7 @@ async def run_mcts(initial_state, correct_answer, num_iterations, session, progr
"terminal_paths": terminal_paths
}

progress_tracker.complete_question(is_sc_correct, is_any_correct, best_prm_path_correct, has_terminal_nodes)
progress_tracker.complete_question(is_sc_correct, is_any_correct, best_prm_path_correct, is_fully_completed, has_terminal_nodes)
return result

async def main():
Expand All @@ -344,15 +350,10 @@ def process(example):
example["answer"] = example["answer"].split("\n#### ")[-1].strip()
return example

gsm8k = load_dataset("openai/gsm8k", "main", split="train").shuffle(seed=42)
gsm8k = load_dataset("openai/gsm8k", "main", split="test").shuffle(seed=42)
gsm8k = gsm8k.map(process, num_proc=24)
initial_states = [(example["question"], example["answer"]) for example in gsm8k]

# Sample 100 questions
sample = False
if sample:
initial_states = random.sample(initial_states, 10)

num_iterations = 10

# Initialize progress tracker
Expand Down

0 comments on commit d59673b

Please sign in to comment.