-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
793 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
import asyncio | ||
from openai import AsyncOpenAI | ||
import json | ||
from typing import List, Tuple | ||
from datasets import load_dataset | ||
from util import split_and_clean_steps, quality_filter, SEED | ||
from tqdm import tqdm | ||
|
||
client = AsyncOpenAI( | ||
api_key="9FF74944EED19865193F979942FB1", | ||
base_url="https://rawsh--vllm-qwen-serve.modal.run/v1" | ||
) | ||
|
||
def format_thoughts(thoughts: List[str]) -> str: | ||
return "\n".join(f"## Step {i}:\n{thought}" for i, thought in enumerate(thoughts, 1)) | ||
|
||
template = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n\ | ||
<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n{assistant_partial}" | ||
|
||
class ReasoningTrace: | ||
def __init__(self, question: str, previous_thoughts: List[str], next_step: int): | ||
self.question = question | ||
self.previous_thoughts = previous_thoughts | ||
self.next_step = next_step | ||
|
||
class ProcessedReasoningTrace: | ||
def __init__(self, question: str, thoughts: List[str]): | ||
self.question = question | ||
self.thoughts = thoughts | ||
|
||
async def generate_thought_batched(batch: List[ReasoningTrace]) -> List[ProcessedReasoningTrace]: | ||
prompts = [] | ||
for trace in batch: | ||
formatted_thoughts = format_thoughts(trace.previous_thoughts) | ||
prompt = template.format(user=trace.question, assistant_partial=f"{formatted_thoughts}\n## Step {trace.next_step}:\n") | ||
prompts.append(prompt) | ||
|
||
params = { | ||
"model": "Qwen/Qwen2.5-0.5B-Instruct", | ||
"prompt": prompts, | ||
"max_tokens": 200, | ||
"temperature": 0.7, | ||
"stop": ["\n## Step"], | ||
"timeout": 600 | ||
} | ||
|
||
try: | ||
response = await client.completions.create(**params) | ||
processed = [ | ||
ProcessedReasoningTrace( | ||
question=batch[i].question, | ||
thoughts=batch[i].previous_thoughts + [response.choices[i].text.strip()] | ||
) for i in range(len(batch)) | ||
] | ||
return processed | ||
except Exception as e: | ||
print(f"An error occurred: {str(e)}") | ||
return None | ||
|
||
async def format_thought_chain(question: str, chain: List[str]) -> List[ReasoningTrace]: | ||
return [ReasoningTrace(question, chain[:i], i+1) for i in range(1, len(chain))] | ||
|
||
async def process_batch(batch: List[ReasoningTrace], semaphore: asyncio.Semaphore) -> List[ProcessedReasoningTrace]: | ||
async with semaphore: | ||
return await generate_thought_batched(batch) | ||
|
||
async def process_all_thought_chains_batched(thought_chains: List[Tuple[str, List[str]]]) -> List[ProcessedReasoningTrace]: | ||
batch_size = 200 | ||
all_traces = [] | ||
|
||
for question, chain in thought_chains: | ||
all_traces.extend(await format_thought_chain(question, chain)) | ||
|
||
results = [] | ||
semaphore = asyncio.Semaphore(10) # Limit to 10 concurrent batches | ||
tasks = [] | ||
|
||
for i in range(0, len(all_traces), batch_size): | ||
batch = all_traces[i:i + batch_size] | ||
task = asyncio.create_task(process_batch(batch, semaphore)) | ||
tasks.append(task) | ||
|
||
for task in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc="Processing batches"): | ||
processed_batch = await task | ||
if processed_batch: | ||
results.extend(processed_batch) | ||
|
||
return results | ||
|
||
async def main(): | ||
ds = load_dataset("argilla/magpie-ultra-v0.1") | ||
filtered_ds = ds.filter(quality_filter) | ||
split_ds = filtered_ds['train'].train_test_split(test_size=0.1, seed=SEED) | ||
train_ds = split_ds['train'] | ||
correct_traces = [(row["instruction"], split_and_clean_steps(row["response"])) for row in train_ds] | ||
|
||
# correct_traces = correct_traces[:1000] | ||
generated_thoughts = await process_all_thought_chains_batched(correct_traces) | ||
|
||
with open("out.jsonl", "w") as f: | ||
for chain in generated_thoughts: | ||
json.dump(chain.__dict__, f) | ||
f.write("\n") | ||
|
||
print(f"Results written to out.jsonl") | ||
|
||
if __name__ == "__main__": | ||
asyncio.run(main()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
from datasets import load_dataset | ||
import numpy as np | ||
from util import split_and_clean_steps, quality_filter, SEED | ||
import json | ||
|
||
def initialize_prm(traces, last_step_correct=True): | ||
""" | ||
Initialize the Process Reward Model (PRM) using sets of reasoning traces. | ||
Args: | ||
traces (list of list of str): Reasoning traces | ||
correct (bool): Whether the traces are correct (True) or incorrect (False) | ||
Returns: | ||
dict: Initialized PRM with quality values and weighted rewards | ||
""" | ||
# prm = {} | ||
prm_data = [] | ||
|
||
for i, trace_tuple in enumerate(traces): | ||
question, trace = trace_tuple | ||
K = len(trace) # Total number of reasoning steps | ||
|
||
# Initialize trace | ||
prm_example = {"steps": [], "quality_values": [], "weighted_rewards": []} | ||
v_prev = 0 | ||
for k, step in enumerate(trace, 1): | ||
penalize = (not last_step_correct) and k == len(trace) | ||
m_k = K - k if (not penalize) else K - k + 1 # One more step needed to correct mistake if incorrect | ||
r_s_k = 0 if (not penalize) else 1 # 0 for correct steps, 1 for incorrect steps | ||
w_s_k = (1 - v_prev) / (m_k + 1) * (1 - 2 * r_s_k) | ||
v_k = max(v_prev + w_s_k, 0) | ||
|
||
prm_example["question"] = question | ||
prm_example["steps"].append(step) | ||
prm_example["quality_values"].append(v_k) | ||
prm_example["weighted_rewards"].append(w_s_k) | ||
v_prev = v_k | ||
|
||
prm_data.append(prm_example) | ||
|
||
return prm_data | ||
|
||
|
||
# Load and filter the dataset, then apply the 90:10 split | ||
ds = load_dataset("argilla/magpie-ultra-v0.1") | ||
# Filter the dataset | ||
filtered_ds = ds.filter(quality_filter) | ||
# Apply the 90:10 split on the filtered training data | ||
split_ds = filtered_ds['train'].train_test_split(test_size=0.1, seed=SEED) | ||
train_ds = split_ds['train'] | ||
print(len(train_ds)) | ||
# "Correct" traces generated by 405B | ||
correct_traces = [(row["instruction"], split_and_clean_steps(row["response"])) for row in train_ds] | ||
|
||
# Example usage: | ||
# correct_traces = [ | ||
# ["Step 1: Correct", "Step 2: Correct", "Step 3: Correct"], | ||
# ["Step 1: Correct", "Step 2: Correct"] | ||
# ] | ||
|
||
with open('out.jsonl') as f: | ||
last_step_incorrect_data = [json.loads(line) for line in f] | ||
last_step_incorrect_traces = [(ex["question"], ex["thoughts"]) for ex in last_step_incorrect_data] | ||
|
||
# incorrect_traces = [['Identify all the possible outcomes of tossing four coins simultaneously. When tossing four coins simultaneously, each coin has 2 possible outcomes (heads or tails). Therefore, for four coins, the total number of possible outcomes is $2^4 = 16$.', 'List all the outcomes that result in more heads than tails. There are 4 outcomes that meet this criterion: HTHT, HHTT, THTH, TTHH. This gives us a total of 4 favorable outcomes.'], ['Identify all the possible outcomes of tossing four coins simultaneously. When tossing four coins simultaneously, each coin has 2 possible outcomes (heads or tails). Therefore, for four coins, the total number of possible outcomes is $2^4 = 16$.', 'Determine the favorable outcomes. We want more heads than tails, which means we need 3 heads and 1 tail, or 4 heads.', 'Count the number of outcomes with 3 heads and 1 tail. For 3 heads, there is only 1 way to arrange them (HHH). For 1 tail, there are 2 ways to arrange them (TTH and THT). So, there are a total of 1 + 2 = 3 favorable outcomes.'], ['Recognize that this is an arithmetic sequence with a common difference of 1.', 'To find the sum of the first 100 positive integers, we can use the formula for the sum of an arithmetic series, which is given by S = n/2 * (a1 + an), where n is the number of terms, a1 is the first term, and an is the last term.']] | ||
|
||
# initialized_prm = initialize_prm(correct_traces) | ||
# print(initialized_prm) | ||
# print(initialized_prm["trace_1000"]) | ||
|
||
correct_prm_data = initialize_prm(correct_traces, last_step_correct=True) | ||
print(len(correct_prm_data)) | ||
total_length = 0 | ||
correct_prm_data_step_values = [] | ||
for ex in correct_prm_data: | ||
total_length += len(ex["steps"]) | ||
for i in range(len(ex["steps"])): | ||
question = ex["question"] | ||
partial_steps = ex["steps"][:i+1] | ||
partial_reward = ex["quality_values"][i] | ||
correct_prm_data_step_values.append({ | ||
"question": question, | ||
"steps": partial_steps, | ||
"final_step_reward": partial_reward | ||
}) | ||
|
||
print("corr total # step values", total_length) | ||
|
||
last_step_incorrect_prm_data = initialize_prm(last_step_incorrect_traces, last_step_correct=False) | ||
print(len(last_step_incorrect_prm_data)) | ||
|
||
last_step_incorrect_prm_data_step_values = [] | ||
for ex in last_step_incorrect_prm_data: | ||
i = len(ex["steps"]) - 1 | ||
question = ex["question"] | ||
partial_steps = ex["steps"][:i+1] | ||
partial_reward = ex["quality_values"][i] | ||
last_step_incorrect_prm_data_step_values.append({ | ||
"question": question, | ||
"steps": partial_steps, | ||
"final_step_reward": partial_reward | ||
}) | ||
|
||
print("last step incorr total # step values", len(last_step_incorrect_prm_data_step_values)) | ||
|
||
# print(initialized_prm) | ||
# print(last_step_incorrect_prm_data[1000]) | ||
|
||
with open("reward.jsonl", "w") as f: | ||
for prm_examples in correct_prm_data_step_values: | ||
json.dump(prm_examples, f) | ||
f.write("\n") | ||
|
||
for prm_examples in last_step_incorrect_prm_data_step_values: | ||
json.dump(prm_examples, f) | ||
f.write("\n") | ||
|
||
print(f"Results written to reward.jsonl") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import re | ||
|
||
def split_and_clean_steps(text): | ||
# Use regex to split the text into steps | ||
steps = re.split(r'(?=##\s*Step\s+\d+:)', text) | ||
|
||
# Remove any leading/trailing whitespace, empty steps, and the "## Step n:" prefix | ||
cleaned_steps = [] | ||
for step in steps: | ||
# Strip whitespace and check if step is not empty | ||
step = step.strip() | ||
if step: | ||
# Remove the "## Step n:" prefix | ||
step = re.sub(r'^##\s*Step\s+\d+:\s*', '', step) | ||
cleaned_steps.append(step) | ||
|
||
return cleaned_steps | ||
|
||
# Example usage | ||
text1 = """## Step 1: First step | ||
Content of first step. | ||
## Step 2: Second step | ||
Content of second step. | ||
## Step 10: Tenth step | ||
Content of tenth step. | ||
## Step 11: Eleventh step | ||
Content of eleventh step. | ||
sdfsdfsdfsdf | ||
sdfsdfsd | ||
step ## Step 12: Test""" | ||
|
||
text2 = """## Step 1: Short step | ||
Brief content. | ||
## Step 99: Large step number | ||
Content of step 99. | ||
## Step 100: Three-digit step | ||
Content of step 100.""" | ||
|
||
# Test with both examples | ||
for i, text in enumerate([text1, text2], 1): | ||
# print(f"Test case {i}:") | ||
result = split_and_clean_steps(text) | ||
for j, step in enumerate(result, 1): | ||
print(f"Step {j}:") | ||
print(step) | ||
print() | ||
print("---\n") |
Oops, something went wrong.