diff --git a/mcts/tree_search_mathrm.py b/mcts/tree_search_mathrm.py index 7d4e4f0..cdb3d31 100644 --- a/mcts/tree_search_mathrm.py +++ b/mcts/tree_search_mathrm.py @@ -28,7 +28,7 @@ PRM_SEMAPHORE = Semaphore(1000) # More aggressive retry settings -MAX_RETRIES = 5 +MAX_RETRIES = 10 TIMEOUT = 45 # Cache decorator and retry function @@ -519,8 +519,8 @@ def process(example): gsm8k = load_dataset("openai/gsm8k", "main", split="test").shuffle(seed=42) gsm8k = gsm8k.map(process, num_proc=24) initial_states = [(example["question"], example["answer"]) for example in gsm8k] - initial_states = random.sample(initial_states, 1) - num_iterations = 1 + initial_states = random.sample(initial_states, 100) + num_iterations = 50 print("cold starting policy vllm + prm api")