Skip to content

Commit

Permalink
[MCTS] Add self-refined MCTS (hpcaitech#6098)
Browse files Browse the repository at this point in the history
* add reasoner

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update code

* delete llama

* update prompts

* update readme

* update readme

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
TongLi3701 and pre-commit-ci[bot] authored Oct 24, 2024
1 parent 4294ae8 commit 89a9a60
Show file tree
Hide file tree
Showing 5 changed files with 324 additions and 3 deletions.
21 changes: 18 additions & 3 deletions applications/ColossalChat/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@
- [Alternative Option For RLHF: SimPO](#alternative-option-for-rlhf-simple-preference-optimization-simpo)
- [Alternative Option For RLHF: ORPO](#alternative-option-for-rlhf-odds-ratio-preference-optimization-orpo)
- [Alternative Option For RLHF: KTO](#alternative-option-for-rlhf-kahneman-tversky-optimization-kto)
- [O1 Journey](#o1-journey)
- [Inference with Self-refined MCTS](#inference-with-self-refined-mcts)
- [FAQ](#faq)
- [How to save/load checkpoint](#faq)
- [How to train with limited resources](#faq)
- [The Plan](#the-plan)
- [Real-time progress](#real-time-progress)
- [Invitation to open-source contribution](#invitation-to-open-source-contribution)
- [Quick Preview](#quick-preview)
- [Authors](#authors)
Expand Down Expand Up @@ -272,7 +272,7 @@ Odds Ratio Preference Optimization (ORPO) from this [paper](https://arxiv.org/pd
## Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO)
We support the method introduced in the paper [KTO:Model Alignment as Prospect Theoretic Optimization](https://arxiv.org/pdf/2402.01306) (KTO). Which is a aligment method that directly maximize "human utility" of generation results. Read this [README](./examples/README.md) for more information.

### Inference Quantization and Serving - After Training
## Inference Quantization and Serving - After Training

We provide an online inference server and a benchmark. We aim to run inference on single GPU, so quantization is essential when using large models.

Expand All @@ -281,6 +281,21 @@ We support 8-bit quantization (RTN), 4-bit quantization (GPTQ), and FP16 inferen
Online inference server scripts can help you deploy your own services.
For more details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/inference).

## O1 Journey
### Inference with Self-refined MCTS
We provide the implementation of MCT Self-Refine (MCTSr) algorithm, an innovative integration of Large Language Models with Monte Carlo Tree Search.
To run inference with MCTS, simply use the following script.
```python
from coati.reasoner.guided_search.mcts import MCTS
from coati.reasoner.guided_search.prompt_store.qwen import Qwen32B_prompt_CFG

problem = "How Many R in 'Strawberry'"

search_tree = MCTS(problem=problem, max_simulations=8, cfg=Qwen32B_prompt_CFG)
answer = search_tree.simulate()
print(answer)
```

## Coati7B examples

### Generation
Expand Down
26 changes: 26 additions & 0 deletions applications/ColossalChat/coati/reasoner/guided_search/llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import openai
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam

API_KEY = "Dummy API Key"


def get_client(base_url: str | None = None) -> openai.Client:
return openai.Client(api_key=API_KEY, base_url=base_url)


def chat_completion(
messages: list[ChatCompletionMessageParam],
model: str,
base_url: str | None = None,
temperature: float = 0.8,
**kwargs,
) -> ChatCompletion:
client = get_client(base_url)
response = client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
**kwargs,
)
return response
250 changes: 250 additions & 0 deletions applications/ColossalChat/coati/reasoner/guided_search/mcts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
"""
Implementation of MCTS + Self-refine algorithm.
Reference:
1. "Accessing GPT-4 level Mathematical Olympiad Solutions via Monte
Carlo Tree Self-refine with LLaMa-3 8B: A Technical Report"
2. https://github.com/BrendanGraham14/mcts-llm/
3. https://github.com/trotsky1997/MathBlackBox/
4. https://github.com/openreasoner/openr/blob/main/reason/guided_search/tree.py
"""

from __future__ import annotations

import math
from collections import deque

import numpy as np
import tqdm
from coati.reasoner.guided_search.llm import chat_completion
from coati.reasoner.guided_search.prompt_store.base import PromptCFG
from pydantic import BaseModel


class MCTSNode(BaseModel):
"""
Node for MCTS.
"""

answer: str
parent: MCTSNode = None
children: list[MCTSNode] = []
num_visits: int = 0
Q: int = 0
rewards: list[int] = []

def expand_node(self, node) -> None:
self.children.append(node)

def add_reward(self, reward: int) -> None:
self.rewards.append(reward)
self.Q = (np.min(self.rewards) + np.mean(self.rewards)) / 2


class MCTS(BaseModel):
"""
Simulation of MCTS process.
"""

problem: str
max_simulations: int
cfg: PromptCFG
C: float = 1.4
max_children: int = 2
epsilon: float = 1e-5
root: MCTSNode = None

def initialization(self):
"""
Root Initiation.
"""
# Dummy answer as root.
base_answer = self.sample_base_answer()
self.root = MCTSNode(answer=base_answer)
self.self_evaluate(self.root)

def is_fully_expanded(self, node: MCTSNode):
return len(node.children) >= self.max_children or any(child.Q > node.Q for child in node.children)

def select_node(self) -> MCTSNode:
"""
Select next node to explore.
"""
candidates: list[MCTSNode] = []
to_explore = deque([self.root])

while to_explore:
current_node = to_explore.popleft()
if not self.is_fully_expanded(current_node):
candidates.append(current_node)
to_explore.extend(current_node.children)

if not candidates:
return self.root

return max(candidates, key=self.compute_uct)

def self_evaluate(self, node: MCTSNode):
"""
Sample reward of the answer.
"""
reward = self.sample_reward(node)
node.add_reward(reward)

def back_propagation(self, node: MCTSNode):
"""
Back propagate the value of the refined answer.
"""
parent = node.parent
while parent:
best_child_Q = max(child.Q for child in parent.children)
parent.Q = (parent.Q + best_child_Q) / 2
parent.num_visits += 1
parent = parent.parent

def compute_uct(self, node: MCTSNode):
"""
Compute UCT.
"""
if node.parent is None:
return -100
return node.Q + self.C * math.sqrt(math.log(node.parent.num_visits + 1) / (node.num_visits + self.epsilon))

def simulate(self):
self.initialization()
for _ in tqdm.tqdm(range(self.max_simulations)):
node = self.select_node()
child = self.self_refine(node)
node.expand_node(child)
self.self_evaluate(child)
self.back_propagation(child)

return self.get_best_answer()

def get_best_answer(self):
to_visit = deque([self.root])
best_node = self.root

while to_visit:
current_node = to_visit.popleft()
if current_node.Q > best_node.Q:
best_node = current_node
to_visit.extend(current_node.children)

return best_node.answer

def self_refine(self, node: MCTSNode):
"""
Refine node.
"""
critique_response = chat_completion(
messages=[
{
"role": "system",
"content": self.cfg.critic_system_prompt,
},
{
"role": "user",
"content": "\n\n".join(
[
f"<problem>\n{self.problem}\n</problem>",
f"<current_answer>\n{node.answer}\n</current_answer>",
]
),
},
],
model=self.cfg.model,
base_url=self.cfg.base_url,
max_tokens=self.cfg.max_tokens,
)
critique = critique_response.choices[0].message.content
assert critique is not None
refined_answer_response = chat_completion(
messages=[
{
"role": "system",
"content": self.cfg.refine_system_prompt,
},
{
"role": "user",
"content": "\n\n".join(
[
f"<problem>\n{self.problem}\n</problem>",
f"<current_answer>\n{node.answer}\n</current_answer>",
f"<critique>\n{critique}\n</critique>",
]
),
},
],
model=self.cfg.model,
base_url=self.cfg.base_url,
max_tokens=self.cfg.max_tokens,
)
refined_answer = refined_answer_response.choices[0].message.content
assert refined_answer is not None

return MCTSNode(answer=refined_answer, parent=node)

def sample_base_answer(self):
response = chat_completion(
messages=[
{
"role": "system",
"content": "The user will provide a problem. Solve the problem. The response should begin with [reasoning process]...[Verification]... and end with [Final Answer]. \nThe answer is [answer] \n#### [answer].",
},
{
"role": "user",
"content": f"<problem>\n {self.problem} \n</problem> \nLet's think step by step",
},
],
model=self.cfg.model,
base_url=self.cfg.base_url,
max_tokens=self.cfg.max_tokens,
)
assert response.choices[0].message.content is not None
return response.choices[0].message.content

def sample_reward(self, node: MCTSNode):
"""
Calculate reward.
"""
messages = [
{
"role": "system",
"content": self.cfg.evaluate_system_prompt,
},
{
"role": "user",
"content": "\n\n".join(
[
f"<problem>\n{self.problem}\n</problem>",
f"<answer>\n{node.answer}\n</answer>",
]
),
},
]
for attempt in range(3):
try:
response = chat_completion(
messages=messages,
model=self.cfg.model,
base_url=self.cfg.base_url,
max_tokens=self.cfg.max_tokens,
)
assert response.choices[0].message.content is not None
return int(response.choices[0].message.content)
except ValueError:
messages.extend(
[
{
"role": "assistant",
"content": response.choices[0].message.content,
},
{
"role": "user",
"content": "Failed to parse reward as an integer.",
},
]
)
if attempt == 2:
raise
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from pydantic import BaseModel


class PromptCFG(BaseModel):
model: str
base_url: str
max_tokens: int = 4096
critic_system_prompt: str
refine_system_prompt: str
evaluate_system_prompt: str
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""
Prompts for Qwen Series.
"""

from coati.reasoner.guided_search.prompt_store.base import PromptCFG

Qwen32B_prompt_CFG = PromptCFG(
base_url="http://0.0.0.0:8008/v1",
model="Qwen2.5-32B-Instruct",
critic_system_prompt="Provide a detailed and constructive critique to improve the answer. "
"Highlight specific areas that need refinement or correction.",
refine_system_prompt="""# Instruction
Refine the answer based on the critique. The response should begin with [reasoning process]...[Verification]... and end with [Final Answer].
""",
evaluate_system_prompt=(
"Analyze this answer strictly and critic, provide a reward score between -100 and 100 for the answer quality, using very strict standards. "
"Do not give a full score above 95. Make sure the reward score is an integer. "
"Return *ONLY* the score."
),
)

0 comments on commit 89a9a60

Please sign in to comment.