Skip to content

Commit

Permalink
changes
Browse files Browse the repository at this point in the history
  • Loading branch information
rawsh authored Sep 22, 2024
1 parent 58c16e8 commit 0c62779
Show file tree
Hide file tree
Showing 8 changed files with 793 additions and 14 deletions.
108 changes: 108 additions & 0 deletions mcts/generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import asyncio
from openai import AsyncOpenAI
import json
from typing import List, Tuple
from datasets import load_dataset
from util import split_and_clean_steps, quality_filter, SEED
from tqdm import tqdm

client = AsyncOpenAI(
api_key="9FF74944EED19865193F979942FB1",
base_url="https://rawsh--vllm-qwen-serve.modal.run/v1"
)

def format_thoughts(thoughts: List[str]) -> str:
return "\n".join(f"## Step {i}:\n{thought}" for i, thought in enumerate(thoughts, 1))

template = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n\
<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n{assistant_partial}"

class ReasoningTrace:
def __init__(self, question: str, previous_thoughts: List[str], next_step: int):
self.question = question
self.previous_thoughts = previous_thoughts
self.next_step = next_step

class ProcessedReasoningTrace:
def __init__(self, question: str, thoughts: List[str]):
self.question = question
self.thoughts = thoughts

async def generate_thought_batched(batch: List[ReasoningTrace]) -> List[ProcessedReasoningTrace]:
prompts = []
for trace in batch:
formatted_thoughts = format_thoughts(trace.previous_thoughts)
prompt = template.format(user=trace.question, assistant_partial=f"{formatted_thoughts}\n## Step {trace.next_step}:\n")
prompts.append(prompt)

params = {
"model": "Qwen/Qwen2.5-0.5B-Instruct",
"prompt": prompts,
"max_tokens": 200,
"temperature": 0.7,
"stop": ["\n## Step"],
"timeout": 600
}

try:
response = await client.completions.create(**params)
processed = [
ProcessedReasoningTrace(
question=batch[i].question,
thoughts=batch[i].previous_thoughts + [response.choices[i].text.strip()]
) for i in range(len(batch))
]
return processed
except Exception as e:
print(f"An error occurred: {str(e)}")
return None

async def format_thought_chain(question: str, chain: List[str]) -> List[ReasoningTrace]:
return [ReasoningTrace(question, chain[:i], i+1) for i in range(1, len(chain))]

async def process_batch(batch: List[ReasoningTrace], semaphore: asyncio.Semaphore) -> List[ProcessedReasoningTrace]:
async with semaphore:
return await generate_thought_batched(batch)

async def process_all_thought_chains_batched(thought_chains: List[Tuple[str, List[str]]]) -> List[ProcessedReasoningTrace]:
batch_size = 200
all_traces = []

for question, chain in thought_chains:
all_traces.extend(await format_thought_chain(question, chain))

results = []
semaphore = asyncio.Semaphore(10) # Limit to 10 concurrent batches
tasks = []

for i in range(0, len(all_traces), batch_size):
batch = all_traces[i:i + batch_size]
task = asyncio.create_task(process_batch(batch, semaphore))
tasks.append(task)

for task in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc="Processing batches"):
processed_batch = await task
if processed_batch:
results.extend(processed_batch)

return results

async def main():
ds = load_dataset("argilla/magpie-ultra-v0.1")
filtered_ds = ds.filter(quality_filter)
split_ds = filtered_ds['train'].train_test_split(test_size=0.1, seed=SEED)
train_ds = split_ds['train']
correct_traces = [(row["instruction"], split_and_clean_steps(row["response"])) for row in train_ds]

# correct_traces = correct_traces[:1000]
generated_thoughts = await process_all_thought_chains_batched(correct_traces)

with open("out.jsonl", "w") as f:
for chain in generated_thoughts:
json.dump(chain.__dict__, f)
f.write("\n")

print(f"Results written to out.jsonl")

if __name__ == "__main__":
asyncio.run(main())
119 changes: 119 additions & 0 deletions mcts/reward.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
from datasets import load_dataset
import numpy as np
from util import split_and_clean_steps, quality_filter, SEED
import json

def initialize_prm(traces, last_step_correct=True):
"""
Initialize the Process Reward Model (PRM) using sets of reasoning traces.
Args:
traces (list of list of str): Reasoning traces
correct (bool): Whether the traces are correct (True) or incorrect (False)
Returns:
dict: Initialized PRM with quality values and weighted rewards
"""
# prm = {}
prm_data = []

for i, trace_tuple in enumerate(traces):
question, trace = trace_tuple
K = len(trace) # Total number of reasoning steps

# Initialize trace
prm_example = {"steps": [], "quality_values": [], "weighted_rewards": []}
v_prev = 0
for k, step in enumerate(trace, 1):
penalize = (not last_step_correct) and k == len(trace)
m_k = K - k if (not penalize) else K - k + 1 # One more step needed to correct mistake if incorrect
r_s_k = 0 if (not penalize) else 1 # 0 for correct steps, 1 for incorrect steps
w_s_k = (1 - v_prev) / (m_k + 1) * (1 - 2 * r_s_k)
v_k = max(v_prev + w_s_k, 0)

prm_example["question"] = question
prm_example["steps"].append(step)
prm_example["quality_values"].append(v_k)
prm_example["weighted_rewards"].append(w_s_k)
v_prev = v_k

prm_data.append(prm_example)

return prm_data


# Load and filter the dataset, then apply the 90:10 split
ds = load_dataset("argilla/magpie-ultra-v0.1")
# Filter the dataset
filtered_ds = ds.filter(quality_filter)
# Apply the 90:10 split on the filtered training data
split_ds = filtered_ds['train'].train_test_split(test_size=0.1, seed=SEED)
train_ds = split_ds['train']
print(len(train_ds))
# "Correct" traces generated by 405B
correct_traces = [(row["instruction"], split_and_clean_steps(row["response"])) for row in train_ds]

# Example usage:
# correct_traces = [
# ["Step 1: Correct", "Step 2: Correct", "Step 3: Correct"],
# ["Step 1: Correct", "Step 2: Correct"]
# ]

with open('out.jsonl') as f:
last_step_incorrect_data = [json.loads(line) for line in f]
last_step_incorrect_traces = [(ex["question"], ex["thoughts"]) for ex in last_step_incorrect_data]

# incorrect_traces = [['Identify all the possible outcomes of tossing four coins simultaneously. When tossing four coins simultaneously, each coin has 2 possible outcomes (heads or tails). Therefore, for four coins, the total number of possible outcomes is $2^4 = 16$.', 'List all the outcomes that result in more heads than tails. There are 4 outcomes that meet this criterion: HTHT, HHTT, THTH, TTHH. This gives us a total of 4 favorable outcomes.'], ['Identify all the possible outcomes of tossing four coins simultaneously. When tossing four coins simultaneously, each coin has 2 possible outcomes (heads or tails). Therefore, for four coins, the total number of possible outcomes is $2^4 = 16$.', 'Determine the favorable outcomes. We want more heads than tails, which means we need 3 heads and 1 tail, or 4 heads.', 'Count the number of outcomes with 3 heads and 1 tail. For 3 heads, there is only 1 way to arrange them (HHH). For 1 tail, there are 2 ways to arrange them (TTH and THT). So, there are a total of 1 + 2 = 3 favorable outcomes.'], ['Recognize that this is an arithmetic sequence with a common difference of 1.', 'To find the sum of the first 100 positive integers, we can use the formula for the sum of an arithmetic series, which is given by S = n/2 * (a1 + an), where n is the number of terms, a1 is the first term, and an is the last term.']]

# initialized_prm = initialize_prm(correct_traces)
# print(initialized_prm)
# print(initialized_prm["trace_1000"])

correct_prm_data = initialize_prm(correct_traces, last_step_correct=True)
print(len(correct_prm_data))
total_length = 0
correct_prm_data_step_values = []
for ex in correct_prm_data:
total_length += len(ex["steps"])
for i in range(len(ex["steps"])):
question = ex["question"]
partial_steps = ex["steps"][:i+1]
partial_reward = ex["quality_values"][i]
correct_prm_data_step_values.append({
"question": question,
"steps": partial_steps,
"final_step_reward": partial_reward
})

print("corr total # step values", total_length)

last_step_incorrect_prm_data = initialize_prm(last_step_incorrect_traces, last_step_correct=False)
print(len(last_step_incorrect_prm_data))

last_step_incorrect_prm_data_step_values = []
for ex in last_step_incorrect_prm_data:
i = len(ex["steps"]) - 1
question = ex["question"]
partial_steps = ex["steps"][:i+1]
partial_reward = ex["quality_values"][i]
last_step_incorrect_prm_data_step_values.append({
"question": question,
"steps": partial_steps,
"final_step_reward": partial_reward
})

print("last step incorr total # step values", len(last_step_incorrect_prm_data_step_values))

# print(initialized_prm)
# print(last_step_incorrect_prm_data[1000])

with open("reward.jsonl", "w") as f:
for prm_examples in correct_prm_data_step_values:
json.dump(prm_examples, f)
f.write("\n")

for prm_examples in last_step_incorrect_prm_data_step_values:
json.dump(prm_examples, f)
f.write("\n")

print(f"Results written to reward.jsonl")
51 changes: 51 additions & 0 deletions mcts/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import re

def split_and_clean_steps(text):
# Use regex to split the text into steps
steps = re.split(r'(?=##\s*Step\s+\d+:)', text)

# Remove any leading/trailing whitespace, empty steps, and the "## Step n:" prefix
cleaned_steps = []
for step in steps:
# Strip whitespace and check if step is not empty
step = step.strip()
if step:
# Remove the "## Step n:" prefix
step = re.sub(r'^##\s*Step\s+\d+:\s*', '', step)
cleaned_steps.append(step)

return cleaned_steps

# Example usage
text1 = """## Step 1: First step
Content of first step.
## Step 2: Second step
Content of second step.
## Step 10: Tenth step
Content of tenth step.
## Step 11: Eleventh step
Content of eleventh step.
sdfsdfsdfsdf
sdfsdfsd
step ## Step 12: Test"""

text2 = """## Step 1: Short step
Brief content.
## Step 99: Large step number
Content of step 99.
## Step 100: Three-digit step
Content of step 100."""

# Test with both examples
for i, text in enumerate([text1, text2], 1):
# print(f"Test case {i}:")
result = split_and_clean_steps(text)
for j, step in enumerate(result, 1):
print(f"Step {j}:")
print(step)
print()
print("---\n")
Loading

0 comments on commit 0c62779

Please sign in to comment.