Skip to content

Commit

Permalink
changes
Browse files Browse the repository at this point in the history
  • Loading branch information
rawsh authored Nov 6, 2024
1 parent 654dbbb commit 0fc3e40
Show file tree
Hide file tree
Showing 10 changed files with 1,063 additions and 340 deletions.
19 changes: 12 additions & 7 deletions mcts/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,18 @@
from tqdm import tqdm

client = AsyncOpenAI(
api_key="9FF74944EED19865193F979942FB1",
base_url="https://rawsh--vllm-qwen-serve.modal.run/v1"
api_key="9FF74944EED19865193F979942FB1",
base_url="https://rawsh--vllm-smollm-serve.modal.run/v1"
)

def format_thoughts(thoughts: List[str]) -> str:
    """Render a list of reasoning steps as numbered ``## Step N:`` sections.

    Each thought is placed on the line after its ``## Step N:`` header, with
    steps numbered from 1 and sections separated by a single newline.

    Args:
        thoughts: The reasoning steps, in order.

    Returns:
        The formatted text; an empty string when ``thoughts`` is empty.
    """
    # Numbering starts at 1 so the headers line up with the "\n## Step"
    # marker used elsewhere in this file as the generation stop sequence.
    return "\n".join(
        f"## Step {step}:\n{thought}"
        for step, thought in enumerate(thoughts, 1)
    )

template = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n\
<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n{assistant_partial}"
# template = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n\
# <|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n{assistant_partial}"

template = "<|im_start|>system\nYou are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>\n\
<|im_start|>human\n{user}<|im_end|>\n<|im_start|>assistant\n{assistant_partial}"

class ReasoningTrace:
def __init__(self, question: str, previous_thoughts: List[str], next_step: int):
Expand All @@ -36,10 +39,12 @@ async def generate_thought_batched(batch: List[ReasoningTrace]) -> List[Processe
prompts.append(prompt)

params = {
"model": "Qwen/Qwen2.5-0.5B-Instruct",
# "model": "Qwen/Qwen2.5-0.5B-Instruct",
"model": "HuggingFaceTB/SmolLM2-135M-Instruct",
"prompt": prompts,
"max_tokens": 200,
"temperature": 0.7,
# "temperature": 0.7,
"temperature": 0.0,
"stop": ["\n## Step"],
"timeout": 600
}
Expand All @@ -58,7 +63,7 @@ async def generate_thought_batched(batch: List[ReasoningTrace]) -> List[Processe
return None

async def format_thought_chain(question: str, chain: List[str]) -> List[ReasoningTrace]:
    """Expand a chain of thoughts into one ReasoningTrace per step.

    For each index ``i`` in ``chain`` a trace is built whose previous thoughts
    are ``chain[:i]`` and whose ``next_step`` is ``i + 1``.

    Args:
        question: The question passed to every trace.
        chain: The ordered thoughts to expand.

    Returns:
        ``len(chain)`` traces, in step order; an empty list for an empty chain.
    """
    # Start at index 0 so the first trace carries no previous thoughts
    # (next_step == 1); starting at 1 would skip that first-step trace.
    return [
        ReasoningTrace(question, chain[:i], i + 1)
        for i in range(len(chain))
    ]

async def process_batch(batch: List[ReasoningTrace], semaphore: asyncio.Semaphore) -> List[ProcessedReasoningTrace]:
async with semaphore:
Expand Down
Loading

0 comments on commit 0fc3e40

Please sign in to comment.