Skip to content

Commit

Permalink
update qwen
Browse files Browse the repository at this point in the history
  • Loading branch information
jstzwj committed Dec 26, 2023
1 parent b5d0f5f commit 32caace
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 13 deletions.
33 changes: 24 additions & 9 deletions chatproto/conversation/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,25 @@ def create_llama(settings: ConversationSettings, system: Optional[str], messages
ret += section
return ret, indices

# Render a conversation as a ChatML-style prompt and record where each piece
# of text sits inside the rendered string.
#
# Args:
#   settings: conversation settings; this function reads `settings.sep` and
#     forwards `settings` to `create_system_prompt`.
#   system: optional system message, forwarded to `create_system_prompt`.
#   messages: (role, message) pairs in conversation order.
#
# Returns:
#   A `(prompt, indices)` tuple. `prompt` is the full prompt string. `indices`
#   is a list of (start, end) character offsets into `prompt`: a span starting
#   at 0 for the system prompt, followed by one span per non-empty message
#   covering only the message text (not the role header or end marker).
def create_chatml(settings: ConversationSettings, system: Optional[str], messages: List[Tuple[str, str]]) -> Tuple[str, List[Tuple[int, int]]]:
im_start, im_end = "<|im_start|>", "<|im_end|>"
indices = []
system_prompt = create_system_prompt(settings, system)
# A non-empty system prompt is terminated with the separator so the first
# message header starts after it.
if len(system_prompt) != 0:
system_prompt += settings.sep
# Span of the system prompt (including the separator appended above).
# NOTE(review): indentation is lost in this listing, so it is unclear whether
# this append sits inside the `if` above or runs unconditionally — confirm.
indices.append((0, len(system_prompt)))

ret = system_prompt
# NOTE(review): the `enumerate` index `i` is never used in the loop body.
for i, (role, message) in enumerate(messages):
if message:
# Complete turn: header, text, end marker, separator.
section = im_start + role + "\n" + message + im_end + settings.sep
# `prefix` is everything that precedes the message text, so its length is
# the message's start offset in the final prompt.
# NOTE(review): this builds a full copy of the prompt on every turn only to
# measure it; len(ret) + len(im_start) + len(role) + 1 gives the same offset.
prefix = ret + im_start + role + "\n"
indices.append((len(prefix), len(prefix) + len(message)))
else:
# Falsy message: emit only the role header, with no end marker or separator
# and no index entry — presumably to leave the turn open for generation.
section = im_start + role + "\n"
ret += section
return ret, indices

@dataclasses.dataclass
class ConversationHistory:
"""A class that keeps all conversation history."""
Expand Down Expand Up @@ -237,6 +256,9 @@ def get_prompt_and_indices(self) -> Tuple[str, List[Tuple[int, int]]]:
elif self.settings.sep_style == SeparatorStyle.CHATGLM:
ret, indices = create_chatglm(self.settings, self.system, self.messages)
return ret
elif self.settings.sep_style == SeparatorStyle.CHATML:
ret, indices = create_chatml(self.settings, self.system, self.messages)
return ret
else:
raise Exception("Indices not support yet.")

Expand Down Expand Up @@ -273,15 +295,8 @@ def get_prompt(self) -> str:
elif self.settings.sep_style == SeparatorStyle.CHATGLM:
ret, indices = create_chatglm(self.settings, self.system, self.messages)
return ret
elif self.settings.sep_style == SeparatorStyle.CHATLM:
im_start, im_end = "<|im_start|>", "<|im_end|>"
ret = system_prompt + self.settings.sep

for i, (role, message) in enumerate(self.messages):
if message:
ret += im_start + role + "\n" + message + im_end + self.settings.sep
else:
ret += im_start + role + "\n"
elif self.settings.sep_style == SeparatorStyle.CHATML:
ret, indices = create_chatml(self.settings, self.system, self.messages)
return ret
else:
raise ValueError(f"Invalid style: {self.settings.sep_style}")
Expand Down
5 changes: 4 additions & 1 deletion chatproto/conversation/models/chatglm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,7 @@
sep_style=SeparatorStyle.CHATGLM,
sep="\n\n",
stop_str="\n\n",
)
)

# Version-specific names for the same settings: `alias` returns a new settings
# object carrying the same field values under the given name.
chatglm2 = chatglm.alias("chatglm2")
chatglm3 = chatglm.alias("chatglm3")
3 changes: 3 additions & 0 deletions chatproto/conversation/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@
sep2=" </s><s>",
stop_str=["[/INST]", "[INST]"]
)

# Version-specific names for the same settings: `alias` returns a new settings
# object carrying the same field values under the given name.
llama1 = llama.alias("llama1")
llama2 = llama.alias("llama2")
9 changes: 7 additions & 2 deletions chatproto/conversation/models/qwen.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,13 @@
# Conversation settings for Qwen: user/assistant roles, ChatML separator
# style, and a system template wrapped in <|im_start|>system ... <|im_end|>.
# NOTE(review): this listing is a rendered diff, so `sep_style` and `stop_str`
# each appear twice (pre- and post-commit values, presumably in that order).
# Only one of each pair can exist in the real file — confirm against the
# repository before relying on either value.
qwen = ConversationSettings(
name="qwen",
roles=("user", "assistant"),
sep_style=SeparatorStyle.CHATLM,
sep_style=SeparatorStyle.CHATML,
system_template="<|im_start|>system\n{system_message}<|im_end|>",
sep="\n",
stop_str="<|im_end|>",
stop_token_ids=[
151643,
151644,
151645,
], # "<|endoftext|>", "<|im_start|>", "<|im_end|>"
stop_str="<|endoftext|>",
)
14 changes: 13 additions & 1 deletion chatproto/conversation/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class SeparatorStyle(Enum):
PHOENIX = auto()
LLAMA = auto()
CHATGLM = auto()
CHATLM = auto()
CHATML = auto()


@dataclasses.dataclass
Expand Down Expand Up @@ -48,4 +48,16 @@ def copy(self):
stop_str=self.stop_str,
stop_token_ids=self.stop_token_ids,
)

def alias(self, new_name: str) -> "ConversationSettings":
    """Return a copy of these settings registered under a different name.

    Every field except ``name`` is carried over unchanged. The copy is
    shallow: mutable field values (for example a ``stop_token_ids`` list)
    are shared with the original, exactly as in the previous hand-written
    version.

    Args:
        new_name: Name to give the returned settings object.

    Returns:
        A new ``ConversationSettings`` identical to ``self`` apart from
        ``name``.
    """
    # dataclasses.replace re-invokes __init__ with all current field values,
    # so fields added to the dataclass later are carried over automatically
    # instead of being silently reset to their defaults.
    return dataclasses.replace(self, name=new_name)

0 comments on commit 32caace

Please sign in to comment.