diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
index 1d383115..972060d0 100644
--- a/.github/workflows/python-package.yml
+++ b/.github/workflows/python-package.yml
@@ -36,7 +36,7 @@ jobs:
- name: Lint with black
uses: psf/black@stable
with:
- options: "--check --verbose --line-length 120"
+ options: "--check --verbose"
src: "chatglm_cpp examples tests setup.py"
- name: Test with pytest
run: |
diff --git a/.gitignore b/.gitignore
index 47826aca..31d3823e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,6 +10,7 @@ __pycache__/
*.egg-info/
dist/
*.so
+*.whl
.hypothesis/
# cpp
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8b7ac72e..373b98c7 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -92,7 +92,19 @@ file(GLOB PY_SOURCES
add_custom_target(lint
COMMAND clang-format -i ${CPP_SOURCES}
COMMAND isort ${PY_SOURCES}
- COMMAND black ${PY_SOURCES} --line-length 120)
+ COMMAND black ${PY_SOURCES} --verbose)
+
+# mypy
+add_custom_target(mypy
+ mypy chatglm_cpp examples --exclude __init__.pyi
+ WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}
+)
+
+# stub
+add_custom_target(stub
+ pybind11-stubgen chatglm_cpp -o .
+ WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}
+)
if (MSVC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall")
diff --git a/README.md b/README.md
index 3c05e02d..fc87692c 100644
--- a/README.md
+++ b/README.md
@@ -105,11 +105,70 @@ python3 chatglm_cpp/convert.py -i THUDM/chatglm2-6b -t q4_0 -o chatglm2-ggml.bin
ChatGLM3-6B
+In addition to chat mode, ChatGLM3-6B supports function call and code interpreter.
+
+Chat mode:
```sh
python3 chatglm_cpp/convert.py -i THUDM/chatglm3-6b -t q4_0 -o chatglm3-ggml.bin
./build/bin/main -m chatglm3-ggml.bin -p 你好 --top_p 0.8 --temp 0.8
# 你好👋!我是人工智能助手 ChatGLM3-6B,很高兴见到你,欢迎问我任何问题。
```
+
+Setting the system prompt:
+```sh
+./build/bin/main -m chatglm3-ggml.bin -p 你好 -s "You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown."
+# 你好👋!我是 ChatGLM3,有什么问题可以帮您解答吗?
+```
+
+Function call:
+~~~
+$ ./build/bin/main -m chatglm3-ggml.bin --top_p 0.8 --temp 0.8 --sp examples/system/function_call.txt -i
+System > Answer the following questions as best as you can. You have access to the following tools: ...
+Prompt > 生成一个随机数
+ChatGLM3 > random_number_generator
+```python
+tool_call(seed=42, range=(0, 100))
+```
+Tool Call > Please manually call function `random_number_generator` with args `tool_call(seed=42, range=(0, 100))` and provide the results below.
+Observation > 23
+ChatGLM3 > 根据您的要求,我使用随机数生成器API生成了一个随机数。根据API返回结果,生成的随机数为23。
+~~~
+
+Code interpreter:
+~~~
+$ ./build/bin/main -m chatglm3-ggml.bin --top_p 0.8 --temp 0.8 --sp examples/system/code_interpreter.txt -i
+System > 你是一位智能AI助手,你叫ChatGLM,你连接着一台电脑,但请注意不能联网。在使用Python解决任务时,你可以运行代码并得到结果,如果运行结果有错误,你需要尽可能对代码进行改进。你可以处理用户上传到电脑上的文件,文件默认存储路径是/mnt/data/。
+Prompt > 列出100以内的所有质数
+ChatGLM3 > 好的,我会为您列出100以内的所有质数。
+```python
+def is_prime(n):
+ """Check if a number is prime."""
+ if n <= 1:
+ return False
+ if n <= 3:
+ return True
+ if n % 2 == 0 or n % 3 == 0:
+ return False
+ i = 5
+ while i * i <= n:
+ if n % i == 0 or n % (i + 2) == 0:
+ return False
+ i += 6
+ return True
+
+primes_upto_100 = [i for i in range(2, 101) if is_prime(i)]
+primes_upto_100
+```
+
+Code Interpreter > Please manually run the code and provide the results below.
+Observation > [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97]
+ChatGLM3 > 100以内的所有质数为:
+
+$$
+2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97
+$$
+~~~
+
@@ -251,7 +310,7 @@ pip install .
Pre-built wheels for CPU backend on Linux / MacOS / Windows are published on [release](https://github.com/li-plus/chatglm.cpp/releases). For CUDA / Metal backends, please compile from source code or source distribution.
-**Using pre-converted ggml models**
+**Using Pre-converted GGML Models**
Here is a simple demo that uses `chatglm_cpp.Pipeline` to load the GGML model and chat with it. First enter the examples folder (`cd examples`) and launch a Python interactive shell:
```python
@@ -264,7 +323,7 @@ Here is a simple demo that uses `chatglm_cpp.Pipeline` to load the GGML model an
To chat in stream, run the below Python example:
```sh
-python3 cli_chat.py -m ../chatglm-ggml.bin -i
+python3 cli_demo.py -m ../chatglm-ggml.bin -i
```
Launch a web demo to chat in your browser:
@@ -280,7 +339,7 @@ For other models:
ChatGLM2-6B
```sh
-python3 cli_chat.py -m ../chatglm2-ggml.bin -p 你好 --temp 0.8 --top_p 0.8 # CLI demo
+python3 cli_demo.py -m ../chatglm2-ggml.bin -p 你好 --temp 0.8 --top_p 0.8 # CLI demo
python3 web_demo.py -m ../chatglm2-ggml.bin --temp 0.8 --top_p 0.8 # web demo
```
@@ -288,10 +347,40 @@ python3 web_demo.py -m ../chatglm2-ggml.bin --temp 0.8 --top_p 0.8 # web demo
ChatGLM3-6B
+**CLI Demo**
+
+Chat mode:
+```sh
+python3 cli_demo.py -m ../chatglm3-ggml.bin -p 你好 --temp 0.8 --top_p 0.8
+```
+
+Function call:
```sh
-python3 cli_chat.py -m ../chatglm3-ggml.bin -p 你好 --temp 0.8 --top_p 0.8 # CLI demo
-python3 web_demo.py -m ../chatglm3-ggml.bin --temp 0.8 --top_p 0.8 # web demo
+python3 cli_demo.py -m ../chatglm3-ggml.bin --temp 0.8 --top_p 0.8 --sp system/function_call.txt -i
```
+
+Code interpreter:
+```sh
+python3 cli_demo.py -m ../chatglm3-ggml.bin --temp 0.8 --top_p 0.8 --sp system/code_interpreter.txt -i
+```
+
+**Web Demo**
+
+Install the Python dependencies and the IPython kernel for the code interpreter:
+```sh
+pip install streamlit jupyter_client ipython ipykernel
+ipython kernel install --name chatglm3-demo --user
+```
+
+Launch the web demo:
+```sh
+streamlit run chatglm3_demo.py
+```
+
+| Function Call | Code Interpreter |
+|-----------------------------|--------------------------------|
+| ![](docs/function_call.png) | ![](docs/code_interpreter.png) |
+
@@ -299,7 +388,7 @@ python3 web_demo.py -m ../chatglm3-ggml.bin --temp 0.8 --top_p 0.8 # web demo
```sh
# CLI demo
-python3 cli_chat.py -m ../codegeex2-ggml.bin --temp 0 --mode generate -p "\
+python3 cli_demo.py -m ../codegeex2-ggml.bin --temp 0 --mode generate -p "\
# language: Python
# write a bubble sort function
"
@@ -312,7 +401,7 @@ python3 web_demo.py -m ../codegeex2-ggml.bin --temp 0 --max_length 512 --mode ge
Baichuan-13B-Chat
```sh
-python3 cli_chat.py -m ../baichuan-13b-chat-ggml.bin -p 你好 --top_k 5 --top_p 0.85 --temp 0.3 --repeat_penalty 1.1 # CLI demo
+python3 cli_demo.py -m ../baichuan-13b-chat-ggml.bin -p 你好 --top_k 5 --top_p 0.85 --temp 0.3 --repeat_penalty 1.1 # CLI demo
python3 web_demo.py -m ../baichuan-13b-chat-ggml.bin --top_k 5 --top_p 0.85 --temp 0.3 --repeat_penalty 1.1 # web demo
```
@@ -321,7 +410,7 @@ python3 web_demo.py -m ../baichuan-13b-chat-ggml.bin --top_k 5 --top_p 0.85 --te
Baichuan2-7B-Chat
```sh
-python3 cli_chat.py -m ../baichuan2-7b-chat-ggml.bin -p 你好 --top_k 5 --top_p 0.85 --temp 0.3 --repeat_penalty 1.05 # CLI demo
+python3 cli_demo.py -m ../baichuan2-7b-chat-ggml.bin -p 你好 --top_k 5 --top_p 0.85 --temp 0.3 --repeat_penalty 1.05 # CLI demo
python3 web_demo.py -m ../baichuan2-7b-chat-ggml.bin --top_k 5 --top_p 0.85 --temp 0.3 --repeat_penalty 1.05 # web demo
```
@@ -330,7 +419,7 @@ python3 web_demo.py -m ../baichuan2-7b-chat-ggml.bin --top_k 5 --top_p 0.85 --te
Baichuan2-13B-Chat
```sh
-python3 cli_chat.py -m ../baichuan2-13b-chat-ggml.bin -p 你好 --top_k 5 --top_p 0.85 --temp 0.3 --repeat_penalty 1.05 # CLI demo
+python3 cli_demo.py -m ../baichuan2-13b-chat-ggml.bin -p 你好 --top_k 5 --top_p 0.85 --temp 0.3 --repeat_penalty 1.05 # CLI demo
python3 web_demo.py -m ../baichuan2-13b-chat-ggml.bin --top_k 5 --top_p 0.85 --temp 0.3 --repeat_penalty 1.05 # web demo
```
@@ -339,7 +428,7 @@ python3 web_demo.py -m ../baichuan2-13b-chat-ggml.bin --top_k 5 --top_p 0.85 --t
InternLM-Chat-7B
```sh
-python3 cli_chat.py -m ../internlm-chat-7b-ggml.bin -p 你好 --top_p 0.8 --temp 0.8 # CLI demo
+python3 cli_demo.py -m ../internlm-chat-7b-ggml.bin -p 你好 --top_p 0.8 --temp 0.8 # CLI demo
python3 web_demo.py -m ../internlm-chat-7b-ggml.bin --top_p 0.8 --temp 0.8 # web demo
```
@@ -348,12 +437,12 @@ python3 web_demo.py -m ../internlm-chat-7b-ggml.bin --top_p 0.8 --temp 0.8 # we
InternLM-Chat-20B
```sh
-python3 cli_chat.py -m ../internlm-chat-20b-ggml.bin -p 你好 --top_p 0.8 --temp 0.8 # CLI demo
+python3 cli_demo.py -m ../internlm-chat-20b-ggml.bin -p 你好 --top_p 0.8 --temp 0.8 # CLI demo
python3 web_demo.py -m ../internlm-chat-20b-ggml.bin --top_p 0.8 --temp 0.8 # web demo
```
-**Load and optimize Hugging Face LLMs in one line of code**
+**Converting Hugging Face LLMs at Runtime**
Sometimes it might be inconvenient to convert and save the intermediate GGML models beforehand. Here is an option to directly load from the original Hugging Face model, quantize it into GGML models in a minute, and start serving. All you need is to replace the GGML model path with the Hugging Face model name or path.
```python
@@ -369,7 +458,7 @@ Processing model states: 100%|████████████████
Likewise, replace the GGML model path with Hugging Face model in any example script, and it just works. For example:
```sh
-python3 cli_chat.py -m THUDM/chatglm-6b -p 你好 -i
+python3 cli_demo.py -m THUDM/chatglm-6b -p 你好 -i
```
## API Server
@@ -443,7 +532,7 @@ docker build . --network=host -t chatglm.cpp
# cpp demo
docker run -it --rm -v $PWD:/opt chatglm.cpp ./build/bin/main -m /opt/chatglm-ggml.bin -p "你好"
# python demo
-docker run -it --rm -v $PWD:/opt chatglm.cpp python3 examples/cli_chat.py -m /opt/chatglm-ggml.bin -p "你好"
+docker run -it --rm -v $PWD:/opt chatglm.cpp python3 examples/cli_demo.py -m /opt/chatglm-ggml.bin -p "你好"
# langchain api server
docker run -it --rm -v $PWD:/opt -p 8000:8000 -e MODEL=/opt/chatglm-ggml.bin chatglm.cpp \
uvicorn chatglm_cpp.langchain_api:app --host 0.0.0.0 --port 8000
diff --git a/chatglm.cpp b/chatglm.cpp
index 0fdcdd2b..3cac2a41 100644
--- a/chatglm.cpp
+++ b/chatglm.cpp
@@ -136,6 +136,23 @@ ggml_tensor *tensor_to_cpu(ggml_tensor *tensor) {
return tensor;
}
+const std::string ToolCallMessage::TYPE_FUNCTION = "function";
+const std::string ToolCallMessage::TYPE_CODE = "code";
+
+const std::string ChatMessage::ROLE_USER = "user";
+const std::string ChatMessage::ROLE_ASSISTANT = "assistant";
+const std::string ChatMessage::ROLE_SYSTEM = "system";
+const std::string ChatMessage::ROLE_OBSERVATION = "observation";
+
+void BaseTokenizer::check_chat_messages(const std::vector<ChatMessage> &messages) {
+ CHATGLM_CHECK(messages.size() % 2 == 1) << "invalid chat messages size " << messages.size();
+ for (size_t i = 0; i < messages.size(); i++) {
+ const std::string &target_role = (i % 2 == 0) ? ChatMessage::ROLE_USER : ChatMessage::ROLE_ASSISTANT;
+ CHATGLM_CHECK(messages[i].role == target_role)
+ << "expect messages[" << i << "].role to be " << target_role << ", but got " << messages[i].role;
+ }
+}
+
// Adapted from https://github.com/ggerganov/llama.cpp/blob/master/llama.cpp
void ggml_graph_compute_helper(std::vector &buf, ggml_cgraph *graph, int n_threads) {
struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
@@ -192,6 +209,24 @@ void StreamerGroup::end() {
}
}
+// reference: https://stackoverflow.com/questions/216823/how-to-trim-a-stdstring
+
+// trim from start (in place)
+static inline void ltrim(std::string &s) {
+ s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) { return !std::isspace(ch); }));
+}
+
+// trim from end (in place)
+static inline void rtrim(std::string &s) {
+ s.erase(std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) { return !std::isspace(ch); }).base(), s.end());
+}
+
+// trim from both ends (in place)
+static inline void trim(std::string &s) {
+ rtrim(s);
+ ltrim(s);
+}
+
void TextStreamer::put(const std::vector<int> &output_ids) {
if (is_prompt_) {
// skip prompt
@@ -203,6 +238,9 @@ void TextStreamer::put(const std::vector &output_ids) {
token_cache_.insert(token_cache_.end(), output_ids.begin(), output_ids.end());
std::string text = tokenizer_->decode(token_cache_);
+ if (is_first_line_) {
+ ltrim(text);
+ }
if (text.empty()) {
return;
}
@@ -211,6 +249,7 @@ void TextStreamer::put(const std::vector &output_ids) {
if (text.back() == '\n') {
// flush the cache after newline
printable_text = text.substr(print_len_);
+ is_first_line_ = false;
token_cache_.clear();
print_len_ = 0;
} else if (std::find(puncts.begin(), puncts.end(), text.back()) != puncts.end()) {
@@ -227,8 +266,12 @@ void TextStreamer::put(const std::vector &output_ids) {
void TextStreamer::end() {
std::string text = tokenizer_->decode(token_cache_);
+ if (is_first_line_) {
+ ltrim(text);
+ }
os_ << text.substr(print_len_) << std::endl;
is_prompt_ = true;
+ is_first_line_ = true;
token_cache_.clear();
print_len_ = 0;
}
@@ -418,20 +461,20 @@ int get_default_num_threads() {
std::string to_string(ModelType model_type) {
switch (model_type) {
- case MODEL_TYPE_CHATGLM:
+ case ModelType::CHATGLM:
return "ChatGLM";
- case MODEL_TYPE_CHATGLM2:
+ case ModelType::CHATGLM2:
return "ChatGLM2";
- case MODEL_TYPE_CHATGLM3:
+ case ModelType::CHATGLM3:
return "ChatGLM3";
- case MODEL_TYPE_BAICHUAN7B:
+ case ModelType::BAICHUAN7B:
return "Baichuan7B";
- case MODEL_TYPE_BAICHUAN13B:
+ case ModelType::BAICHUAN13B:
return "Baichuan13B";
- case MODEL_TYPE_INTERNLM:
+ case ModelType::INTERNLM:
return "InternLM";
default:
- CHATGLM_THROW << "unknown model type " << model_type;
+ CHATGLM_THROW << "unknown model type " << (int)model_type;
}
}
@@ -632,7 +675,9 @@ std::vector BaseModelForCausalLM::generate(const std::vector &input_id
streamer->put({next_token_id});
}
- if (next_token_id == config.eos_token_id) {
+ if (next_token_id == config.eos_token_id ||
+ std::find(config.extra_eos_token_ids.begin(), config.extra_eos_token_ids.end(), next_token_id) !=
+ config.extra_eos_token_ids.end()) {
break;
}
}
@@ -669,23 +714,23 @@ std::vector ChatGLMTokenizer::encode(const std::string &text, int max_lengt
return ids;
}
-std::vector<int> ChatGLMTokenizer::encode_history(const std::vector<std::string> &history, int max_length) const {
- std::string prompt = build_prompt(history);
+std::vector<int> ChatGLMTokenizer::encode_messages(const std::vector<ChatMessage> &messages, int max_length) const {
+ std::string prompt = build_prompt(messages);
std::vector<int> input_ids = encode(prompt, max_length);
return input_ids;
}
-std::string ChatGLMTokenizer::build_prompt(const std::vector<std::string> &history) {
- CHATGLM_CHECK(history.size() % 2 == 1) << "invalid history size " << history.size();
+std::string ChatGLMTokenizer::build_prompt(const std::vector<ChatMessage> &messages) {
+ check_chat_messages(messages);
std::ostringstream oss_prompt;
- if (history.size() == 1) {
- oss_prompt << history.front();
+ if (messages.size() == 1) {
+ oss_prompt << messages.front().content;
} else {
- for (size_t i = 0; i < history.size(); i += 2) {
- oss_prompt << "[Round " << i / 2 << "]\n问:" << history[i] << "\n答:";
- if (i < history.size() - 1) {
- oss_prompt << history[i + 1] << "\n";
+ for (size_t i = 0; i < messages.size(); i += 2) {
+ oss_prompt << "[Round " << i / 2 << "]\n问:" << messages[i].content << "\n答:";
+ if (i + 1 < messages.size()) {
+ oss_prompt << messages[i + 1].content << "\n";
}
}
}
@@ -909,20 +954,20 @@ std::string ChatGLM2Tokenizer::decode(const std::vector &ids) const {
return text;
}
-std::vector<int> ChatGLM2Tokenizer::encode_history(const std::vector<std::string> &history, int max_length) const {
- std::string prompt = build_prompt(history);
+std::vector<int> ChatGLM2Tokenizer::encode_messages(const std::vector<ChatMessage> &messages, int max_length) const {
+ std::string prompt = build_prompt(messages);
std::vector<int> input_ids = encode(prompt, max_length);
return input_ids;
}
-std::string ChatGLM2Tokenizer::build_prompt(const std::vector<std::string> &history) {
- CHATGLM_CHECK(history.size() % 2 == 1) << "invalid history size " << history.size();
+std::string ChatGLM2Tokenizer::build_prompt(const std::vector<ChatMessage> &messages) {
+ check_chat_messages(messages);
std::ostringstream oss_prompt;
- for (size_t i = 0; i < history.size(); i += 2) {
- oss_prompt << "[Round " << i / 2 + 1 << "]\n\n问:" << history[i] << "\n\n答:";
- if (i < history.size() - 1) {
- oss_prompt << history[i + 1] << "\n\n";
+ for (size_t i = 0; i < messages.size(); i += 2) {
+ oss_prompt << "[Round " << i / 2 + 1 << "]\n\n问:" << messages[i].content << "\n\n答:";
+ if (i < messages.size() - 1) {
+ oss_prompt << messages[i + 1].content << "\n\n";
}
}
return oss_prompt.str();
@@ -1014,6 +1059,22 @@ ChatGLM3Tokenizer::ChatGLM3Tokenizer(std::string_view serialized_model_proto) {
user_token_id = special_id++;
assistant_token_id = special_id++;
observation_token_id = special_id++;
+
+ special_tokens = {
+ {"[MASK]", mask_token_id},
+ {"[gMASK]", gmask_token_id},
+ {"[sMASK]", smask_token_id},
+ {"sop", sop_token_id},
+ {"eop", eop_token_id},
+ {"<|system|>", system_token_id},
+ {"<|user|>", user_token_id},
+ {"<|assistant|>", assistant_token_id},
+ {"<|observation|>", observation_token_id},
+ };
+
+ for (const auto &item : special_tokens) {
+ index_special_tokens[item.second] = item.first;
+ }
}
std::vector<int> ChatGLM3Tokenizer::encode(const std::string &text, int max_length) const {
@@ -1025,44 +1086,138 @@ std::vector ChatGLM3Tokenizer::encode(const std::string &text, int max_leng
}
std::string ChatGLM3Tokenizer::decode(const std::vector<int> &ids) const {
- // filter out special tokens
- std::vector<int> normal_ids(ids);
- normal_ids.erase(std::remove_if(normal_ids.begin(), normal_ids.end(), [this](int id) { return is_special_id(id); }),
- normal_ids.end());
+ std::string text = decode_with_special_tokens(ids);
+ text = remove_special_tokens(text);
+ return text;
+}
- std::string text;
- sp.Decode(normal_ids, &text);
- text = replace_punctuations(text);
+std::string ChatGLM3Tokenizer::decode_with_special_tokens(const std::vector<int> &ids) const {
+ std::vector<std::string> pieces;
+ for (int id : ids) {
+ auto pos = index_special_tokens.find(id);
+ if (pos != index_special_tokens.end()) {
+ // special tokens
+ pieces.emplace_back(pos->second);
+ } else {
+ // normal tokens
+ pieces.emplace_back(sp.IdToPiece(id));
+ }
+ }
+
+ std::string text = sp.DecodePieces(pieces);
return text;
}
-std::vector<int> ChatGLM3Tokenizer::encode_history(const std::vector<std::string> &history, int max_length) const {
- // TODO: need a new api for system / tools / metadata prompt
+std::string ChatGLM3Tokenizer::remove_special_tokens(const std::string &text) {
+ std::string output = text;
+ static const std::vector<std::regex> special_token_regex{
+ // std::regex(R"(<\|assistant\|> interpreter)"),
+ // std::regex(R"(<\|assistant\|> interpre)"),
+ std::regex(R"(<\|assistant\|>)"),
+ std::regex(R"(<\|user\|>)"),
+ std::regex(R"(<\|observation\|>)"),
+ };
+ for (const auto &re : special_token_regex) {
+ output = std::regex_replace(output, re, "");
+ }
+ return output;
+}
+
+std::vector<int> ChatGLM3Tokenizer::encode_single_message(const std::string &role, const std::string &content) const {
+ std::vector<int> input_ids;
+ input_ids.emplace_back(get_command("<|" + role + "|>"));
+ // TODO: support metadata
std::vector<int> newline_ids;
sp.Encode("\n", &newline_ids);
+ input_ids.insert(input_ids.end(), newline_ids.begin(), newline_ids.end());
+ std::vector<int> content_ids;
+ sp.Encode(content, &content_ids);
+ input_ids.insert(input_ids.end(), content_ids.begin(), content_ids.end());
+ return input_ids;
+}
+
+std::vector<int> ChatGLM3Tokenizer::encode_messages(const std::vector<ChatMessage> &messages, int max_length) const {
std::vector<int> input_ids{gmask_token_id, sop_token_id};
- for (size_t i = 0; i < history.size(); i++) {
- // TODO: support all roles
- input_ids.emplace_back((i % 2 == 0) ? user_token_id : assistant_token_id);
- // TODO: support metadata
- input_ids.insert(input_ids.end(), newline_ids.begin(), newline_ids.end());
- std::vector<int> content_ids;
- sp.Encode(history[i], &content_ids);
- input_ids.insert(input_ids.end(), content_ids.begin(), content_ids.end());
+ for (const auto &msg : messages) {
+ auto msg_ids = encode_single_message(msg.role, msg.content);
+ input_ids.insert(input_ids.end(), msg_ids.begin(), msg_ids.end());
+
+ // encode code block into a separate message
+ if (!msg.tool_calls.empty() && msg.tool_calls.front().type == ToolCallMessage::TYPE_CODE) {
+ auto code_ids = encode_single_message(msg.role, msg.tool_calls.front().code.input);
+ input_ids.insert(input_ids.end(), code_ids.begin(), code_ids.end());
+ }
}
input_ids.emplace_back(assistant_token_id);
- // NOTE: push '\n' into input_ids to avoid model generating it, saving 2 tokens
- input_ids.insert(input_ids.end(), newline_ids.begin(), newline_ids.end());
truncate(input_ids, max_length);
return input_ids;
}
-bool ChatGLM3Tokenizer::is_special_id(int id) const {
- return id == mask_token_id || id == gmask_token_id || id == smask_token_id || id == sop_token_id ||
- id == eop_token_id || id == system_token_id || id == user_token_id || id == assistant_token_id ||
- id == observation_token_id;
+ChatMessage ChatGLM3Tokenizer::decode_message(const std::vector<int> &ids) const {
+ ChatMessage message;
+ if (!ids.empty() && ids.back() == observation_token_id) {
+ // insert an <|assistant|> token before content to match possible interpreter delimiter
+ std::vector<int> full_ids{assistant_token_id};
+ full_ids.insert(full_ids.end(), ids.begin(), ids.end());
+
+ std::string output = decode_with_special_tokens(full_ids);
+ const std::string ci_delim = "<|assistant|> interpreter";
+ size_t ci_pos = output.find(ci_delim);
+ if (ci_pos != std::string::npos) {
+ // code interpreter
+ std::string chat_output = output.substr(0, ci_pos);
+ chat_output = remove_special_tokens(chat_output);
+ trim(chat_output);
+ std::string code_output = output.substr(ci_pos + ci_delim.size());
+ code_output = remove_special_tokens(code_output);
+ trim(code_output);
+ message = ChatMessage(ChatMessage::ROLE_ASSISTANT, std::move(chat_output),
+ {ToolCallMessage(CodeMessage(std::move(code_output)))});
+ } else {
+ // tool call
+ output = remove_special_tokens(output);
+
+ // parse tool name
+ std::string tool_name = "PARSE_ERROR";
+ size_t pos = output.find('\n');
+ if (pos != std::string::npos) {
+ // split tool name and args by 1st linebreak
+ tool_name = output.substr(0, pos);
+ trim(tool_name);
+ output.erase(0, pos + 1);
+ }
+
+ // post process output
+ trim(output);
+
+ // extract args
+ std::string tool_args = "PARSE_ERROR";
+ static const std::regex args_regex(R"(```.*?\n(.*?)\n```)");
+ std::smatch sm;
+ if (std::regex_search(output, sm, args_regex)) {
+ CHATGLM_CHECK(sm.size() == 2) << "unexpected regex match results";
+ tool_args = sm[1];
+ }
+
+ message = ChatMessage(ChatMessage::ROLE_ASSISTANT, std::move(output),
+ {ToolCallMessage(FunctionMessage(std::move(tool_name), std::move(tool_args)))});
+ }
+ } else {
+ // conversation
+ message = BaseTokenizer::decode_message(ids);
+ trim(message.content); // strip leading linebreak in conversation mode
+ }
+ return message;
}
+int ChatGLM3Tokenizer::get_command(const std::string &token) const {
+ auto pos = special_tokens.find(token);
+ CHATGLM_CHECK(pos != special_tokens.end()) << token << " is not a special token";
+ return pos->second;
+}
+
+bool ChatGLM3Tokenizer::is_special_id(int id) const { return index_special_tokens.count(id) > 0; }
+
void ChatGLM3Tokenizer::truncate(std::vector<int> &ids, int max_length) {
if ((int)ids.size() > max_length) {
// sliding window: drop the least recent history while keeping the two special prefix tokens
@@ -1095,18 +1250,14 @@ std::string BaichuanTokenizer::decode(const std::vector &ids) const {
return text;
}
-std::vector<int> BaichuanTokenizer::encode_history(const std::vector<std::string> &history, int max_length) const {
- CHATGLM_CHECK(history.size() % 2 == 1) << "invalid history size " << history.size();
+std::vector<int> BaichuanTokenizer::encode_messages(const std::vector<ChatMessage> &messages, int max_length) const {
+ check_chat_messages(messages);
std::vector<int> ids;
ids.reserve(max_length);
- for (size_t i = 0; i < history.size(); i++) {
- if (i % 2 == 0) {
- ids.push_back(USER_TOKEN_ID);
- } else {
- ids.push_back(ASSISTANT_TOKEN_ID);
- }
- std::vector<int> content_ids = encode(history[i], max_length);
+ for (const auto &msg : messages) {
+ ids.push_back((msg.role == ChatMessage::ROLE_USER) ? USER_TOKEN_ID : ASSISTANT_TOKEN_ID);
+ std::vector<int> content_ids = encode(msg.content, max_length);
ids.insert(ids.end(), content_ids.begin(), content_ids.end());
}
ids.push_back(ASSISTANT_TOKEN_ID);
@@ -1242,20 +1393,21 @@ std::string InternLMTokenizer::decode(const std::vector &ids) const {
return text;
}
-std::vector<int> InternLMTokenizer::encode_history(const std::vector<std::string> &history, int max_length) const {
- std::string prompt = build_prompt(history);
+std::vector<int> InternLMTokenizer::encode_messages(const std::vector<ChatMessage> &messages, int max_length) const {
+ std::string prompt = build_prompt(messages);
std::vector<int> input_ids = encode(prompt, max_length);
return input_ids;
}
-std::string InternLMTokenizer::build_prompt(const std::vector<std::string> &history) {
- CHATGLM_CHECK(history.size() % 2 == 1) << "invalid history size " << history.size();
+std::string InternLMTokenizer::build_prompt(const std::vector<ChatMessage> &messages) {
+ check_chat_messages(messages);
std::ostringstream oss_prompt;
- for (size_t i = 0; i < history.size(); i += 2) {
- oss_prompt << "<|User|>:" << history[i] << "\n<|Bot|>:";
- if (i < history.size() - 1) {
- oss_prompt << history[i + 1] << "\n";
+ for (const auto &msg : messages) {
+ if (msg.role == ChatMessage::ROLE_USER) {
+ oss_prompt << "<|User|>:" << msg.content << "\n<|Bot|>:";
+ } else {
+ oss_prompt << msg.content << "\n";
}
}
return oss_prompt.str();
@@ -1324,7 +1476,7 @@ Pipeline::Pipeline(const std::string &path) {
ModelType model_type = (ModelType)loader.read_basic<int>();
// load version
int version = loader.read_basic<int>();
- if (model_type == MODEL_TYPE_CHATGLM) {
+ if (model_type == ModelType::CHATGLM) {
CHATGLM_CHECK(version == 1) << "only support version 1 for now but got " << version;
// load config
@@ -1339,7 +1491,7 @@ Pipeline::Pipeline(const std::string &path) {
// load model
model = std::make_unique<ChatGLMForCausalLM>(config);
model->load(loader);
- } else if (model_type == MODEL_TYPE_CHATGLM2 || model_type == MODEL_TYPE_CHATGLM3) {
+ } else if (model_type == ModelType::CHATGLM2 || model_type == ModelType::CHATGLM3) {
CHATGLM_CHECK(version == 1) << "only support version 1 for now but got " << version;
// load config
@@ -1350,17 +1502,19 @@ Pipeline::Pipeline(const std::string &path) {
std::string_view serialized_model_proto((char *)mapped_file->data + loader.tell(), proto_size);
loader.seek(proto_size, SEEK_CUR);
- if (model_type == MODEL_TYPE_CHATGLM2) {
+ if (model_type == ModelType::CHATGLM2) {
tokenizer = std::make_unique<ChatGLM2Tokenizer>(serialized_model_proto);
model = std::make_unique<ChatGLM2ForCausalLM>(config);
} else {
- tokenizer = std::make_unique<ChatGLM3Tokenizer>(serialized_model_proto);
+ auto chatglm3_tokenizer = std::make_unique<ChatGLM3Tokenizer>(serialized_model_proto);
+ config.extra_eos_token_ids = {chatglm3_tokenizer->observation_token_id, chatglm3_tokenizer->user_token_id};
+ tokenizer = std::move(chatglm3_tokenizer);
model = std::make_unique<ChatGLM3ForCausalLM>(config);
}
// load model
model->load(loader);
- } else if (model_type == MODEL_TYPE_BAICHUAN7B) {
+ } else if (model_type == ModelType::BAICHUAN7B) {
CHATGLM_CHECK(version == 1) << "only support version 1 for now but got " << version;
// load config
@@ -1376,7 +1530,7 @@ Pipeline::Pipeline(const std::string &path) {
// load model
model = std::make_unique<Baichuan7BForCausalLM>(config);
model->load(loader);
- } else if (model_type == MODEL_TYPE_BAICHUAN13B) {
+ } else if (model_type == ModelType::BAICHUAN13B) {
CHATGLM_CHECK(version == 1) << "only support version 1 for now but got " << version;
// load config
@@ -1392,7 +1546,7 @@ Pipeline::Pipeline(const std::string &path) {
// load model
model = std::make_unique<Baichuan13BForCausalLM>(config);
model->load(loader);
- } else if (model_type == MODEL_TYPE_INTERNLM) {
+ } else if (model_type == ModelType::INTERNLM) {
CHATGLM_CHECK(version == 1) << "only support version 1 for now but got " << version;
// load config
@@ -1413,7 +1567,7 @@ Pipeline::Pipeline(const std::string &path) {
}
model->load(loader);
} else {
- CHATGLM_THROW << "invalid model type " << model_type;
+ CHATGLM_THROW << "invalid model type " << (int)model_type;
}
}
@@ -1432,11 +1586,11 @@ std::string Pipeline::generate(const std::string &prompt, const GenerationConfig
return output;
}
-std::string Pipeline::chat(const std::vector<std::string> &history, const GenerationConfig &gen_config,
+ChatMessage Pipeline::chat(const std::vector<ChatMessage> &messages, const GenerationConfig &gen_config,
BaseStreamer *streamer) const {
- std::vector<int> input_ids = tokenizer->encode_history(history, gen_config.max_context_length);
+ std::vector<int> input_ids = tokenizer->encode_messages(messages, gen_config.max_context_length);
std::vector<int> new_output_ids = generate(input_ids, gen_config, streamer);
- std::string output = tokenizer->decode(new_output_ids);
+ ChatMessage output = tokenizer->decode_message(new_output_ids);
return output;
}
diff --git a/chatglm.h b/chatglm.h
index cf3562b2..2394e159 100644
--- a/chatglm.h
+++ b/chatglm.h
@@ -2,10 +2,10 @@
#include
#include
+#include <iomanip>
#include
#include
#include
-#include
#ifdef GGML_USE_METAL
#include
@@ -46,13 +46,13 @@ ggml_tensor *tensor_to_device(ggml_tensor *tensor);
ggml_tensor *tensor_to_cpu(ggml_tensor *tensor);
-enum ModelType {
- MODEL_TYPE_CHATGLM = 1,
- MODEL_TYPE_CHATGLM2 = 2,
- MODEL_TYPE_CHATGLM3 = 3,
- MODEL_TYPE_BAICHUAN7B = 1024,
- MODEL_TYPE_BAICHUAN13B = 1025,
- MODEL_TYPE_INTERNLM = 1280,
+enum class ModelType {
+ CHATGLM = 1,
+ CHATGLM2 = 2,
+ CHATGLM3 = 3,
+ BAICHUAN7B = 1024,
+ BAICHUAN13B = 1025,
+ INTERNLM = 1280,
};
std::string to_string(ModelType model_type);
@@ -87,21 +87,23 @@ class ModelConfig {
ModelConfig(ModelType model_type, ggml_type dtype, int vocab_size, int hidden_size, int num_attention_heads,
int num_kv_heads, int num_hidden_layers, int intermediate_size, float norm_eps, int max_length,
- int bos_token_id, int eos_token_id, int pad_token_id, int sep_token_id)
+ int bos_token_id, int eos_token_id, int pad_token_id, int sep_token_id,
+ std::vector<int> extra_eos_token_ids)
: model_type(model_type), dtype(dtype), vocab_size(vocab_size), hidden_size(hidden_size),
num_attention_heads(num_attention_heads), num_kv_heads(num_kv_heads), num_hidden_layers(num_hidden_layers),
intermediate_size(intermediate_size), norm_eps(norm_eps), max_length(max_length), bos_token_id(bos_token_id),
- eos_token_id(eos_token_id), pad_token_id(pad_token_id), sep_token_id(sep_token_id) {}
+ eos_token_id(eos_token_id), pad_token_id(pad_token_id), sep_token_id(sep_token_id),
+ extra_eos_token_ids(std::move(extra_eos_token_ids)) {}
ModelConfig(ModelType model_type, const ConfigRecordV1 &rec)
: ModelConfig(model_type, rec.dtype, rec.vocab_size, rec.hidden_size, rec.num_attention_heads,
rec.num_attention_heads, rec.num_hidden_layers, rec.intermediate_size, 1e-5, rec.max_length,
- rec.bos_token_id, rec.eos_token_id, rec.pad_token_id, rec.sep_token_id) {}
+ rec.bos_token_id, rec.eos_token_id, rec.pad_token_id, rec.sep_token_id, {}) {}
ModelConfig(ModelType model_type, const ConfigRecordV2 &rec)
: ModelConfig(model_type, rec.dtype, rec.vocab_size, rec.hidden_size, rec.num_attention_heads, rec.num_kv_heads,
rec.num_hidden_layers, rec.intermediate_size, 1e-5, rec.max_length, rec.bos_token_id,
- rec.eos_token_id, rec.pad_token_id, rec.sep_token_id) {}
+ rec.eos_token_id, rec.pad_token_id, rec.sep_token_id, {}) {}
std::string model_type_name() const { return to_string(model_type); }
@@ -120,14 +122,91 @@ class ModelConfig {
int eos_token_id;
int pad_token_id;
int sep_token_id;
+ std::vector<int> extra_eos_token_ids;
+};
+
+struct FunctionMessage {
+ std::string name;
+ std::string arguments;
+
+ FunctionMessage() = default;
+ FunctionMessage(std::string name, std::string arguments) : name(std::move(name)), arguments(std::move(arguments)) {}
+
+ friend std::ostream &operator<<(std::ostream &os, const FunctionMessage &self) {
+ return os << "FunctionMessage(name=" << std::quoted(self.name) << ", arguments=" << std::quoted(self.arguments)
+ << ")";
+ }
+};
+
+struct CodeMessage {
+ std::string input;
+
+ CodeMessage() = default;
+ CodeMessage(std::string input) : input(std::move(input)) {}
+
+ friend std::ostream &operator<<(std::ostream &os, const CodeMessage &self) {
+ return os << "CodeMessage(input=" << std::quoted(self.input) << ")";
+ }
+};
+
+struct ToolCallMessage {
+ std::string type;
+ FunctionMessage function;
+ CodeMessage code;
+
+ static const std::string TYPE_FUNCTION;
+ static const std::string TYPE_CODE;
+
+ ToolCallMessage(FunctionMessage function) : type(TYPE_FUNCTION), function(std::move(function)) {}
+
+ ToolCallMessage(CodeMessage code) : type(TYPE_CODE), code(std::move(code)) {}
+
+ friend std::ostream &operator<<(std::ostream &os, const ToolCallMessage &self) {
+ return os << "ToolCallMessage(type=" << std::quoted(self.type) << ", function=" << self.function
+ << ", code=" << self.code << ")";
+ }
+};
+
+struct ChatMessage {
+ std::string role;
+ std::string content;
+ std::vector<ToolCallMessage> tool_calls;
+
+ static const std::string ROLE_USER;
+ static const std::string ROLE_ASSISTANT;
+ static const std::string ROLE_SYSTEM;
+ static const std::string ROLE_OBSERVATION;
+
+ ChatMessage() = default;
+ ChatMessage(std::string role, std::string content, std::vector<ToolCallMessage> tool_calls = {})
+ : role(std::move(role)), content(std::move(content)), tool_calls(std::move(tool_calls)) {}
+
+ friend std::ostream &operator<<(std::ostream &os, const ChatMessage &self) {
+ os << "ChatMessage(role=" << std::quoted(self.role) << ", content=" << std::quoted(self.content)
+ << ", tool_calls=[";
+ for (size_t i = 0; i < self.tool_calls.size(); i++) {
+ os << (i > 0 ? ", " : "") << self.tool_calls[i];
+ }
+ return os << "])";
+ }
};
class BaseTokenizer {
public:
virtual ~BaseTokenizer() = default;
+
virtual std::vector<int> encode(const std::string &text, int max_length) const = 0;
+
virtual std::string decode(const std::vector<int> &ids) const = 0;
- virtual std::vector<int> encode_history(const std::vector<std::string> &history, int max_length) const = 0;
+
+ virtual std::vector<int> encode_messages(const std::vector<ChatMessage> &messages, int max_length) const = 0;
+
+ virtual ChatMessage decode_message(const std::vector<int> &ids) const {
+ return {ChatMessage::ROLE_ASSISTANT, decode(ids)};
+ }
+
+ protected:
+ static void check_chat_messages(const std::vector<ChatMessage> &messages);
};
struct ggml_context_deleter_t {
@@ -237,20 +316,20 @@ class RMSNorm {
float eps;
};
-enum ActivationType {
- ACT_TYPE_GELU,
- ACT_TYPE_SILU,
+enum class ActivationType {
+ GELU,
+ SILU,
};
template <ActivationType ACT_TYPE>
static inline ggml_tensor *apply_activation_inplace(ggml_context *ctx, ggml_tensor *hidden_states) {
- static_assert(ACT_TYPE == ACT_TYPE_GELU || ACT_TYPE == ACT_TYPE_SILU);
- if constexpr (ACT_TYPE == ACT_TYPE_GELU) {
+ static_assert(ACT_TYPE == ActivationType::GELU || ACT_TYPE == ActivationType::SILU);
+ if constexpr (ACT_TYPE == ActivationType::GELU) {
hidden_states = tensor_assign_buffers(ggml_gelu_inplace(ctx, hidden_states));
- } else if constexpr (ACT_TYPE == ACT_TYPE_SILU) {
+ } else if constexpr (ACT_TYPE == ActivationType::SILU) {
hidden_states = tensor_assign_buffers(ggml_silu_inplace(ctx, hidden_states));
} else {
- CHATGLM_THROW << "Unknown activation type " << ACT_TYPE;
+ CHATGLM_THROW << "Unknown activation type " << (int)ACT_TYPE;
}
return hidden_states;
}
@@ -650,7 +729,7 @@ class StreamerGroup : public BaseStreamer {
class TextStreamer : public BaseStreamer {
public:
TextStreamer(std::ostream &os, BaseTokenizer *tokenizer)
- : os_(os), tokenizer_(tokenizer), is_prompt_(true), print_len_(0) {}
+ : os_(os), tokenizer_(tokenizer), is_prompt_(true), is_first_line_(true), print_len_(0) {}
void put(const std::vector<int> &output_ids) override;
void end() override;
@@ -658,6 +737,7 @@ class TextStreamer : public BaseStreamer {
std::ostream &os_;
BaseTokenizer *tokenizer_;
bool is_prompt_;
+ bool is_first_line_;
std::vector<int> token_cache_;
int print_len_;
};
@@ -870,9 +950,9 @@ class ChatGLMTokenizer : public BaseTokenizer {
std::string decode(const std::vector<int> &ids) const override;
- std::vector<int> encode_history(const std::vector<std::string> &history, int max_length) const override;
+ std::vector<int> encode_messages(const std::vector<ChatMessage> &messages, int max_length) const override;
- static std::string build_prompt(const std::vector<std::string> &history);
+ static std::string build_prompt(const std::vector<ChatMessage> &messages);
private:
static std::string preprocess(const std::string &text);
@@ -894,7 +974,7 @@ struct GLMContextMasker {
using GLMAttention = BasicAttention;
-using GLMMLP = BasicMLP;
+using GLMMLP = BasicMLP;
// NOTE: disable inplace norm since it causes nonsense on cuda when sequence length >= 144
class GLMBlock : public BasicBlock {
@@ -942,10 +1022,11 @@ class ChatGLM2Tokenizer : public BaseTokenizer {
std::string decode(const std::vector<int> &ids) const override;
- std::vector<int> encode_history(const std::vector<std::string> &history, int max_length) const override;
+ std::vector<int> encode_messages(const std::vector<ChatMessage> &messages, int max_length) const override;
- static std::string build_prompt(const std::vector<std::string> &history);
+ static std::string build_prompt(const std::vector<ChatMessage> &messages);
+ private:
bool is_special_id(int id) const;
public:
@@ -959,7 +1040,7 @@ class ChatGLM2Tokenizer : public BaseTokenizer {
using GLM2Attention = BasicAttention, false, CausalContextMasker>;
-using GLM2MLP = BasicGLU;
+using GLM2MLP = BasicGLU;
using GLM2Block = BasicBlock;
@@ -991,11 +1072,21 @@ class ChatGLM3Tokenizer : public BaseTokenizer {
std::string decode(const std::vector<int> &ids) const override;
- std::vector<int> encode_history(const std::vector<std::string> &history, int max_length) const override;
+ std::vector<int> encode_messages(const std::vector<ChatMessage> &messages, int max_length) const override;
+
+ ChatMessage decode_message(const std::vector<int> &ids) const override;
+
+ private:
+ std::vector<int> encode_single_message(const std::string &role, const std::string &content) const;
+
+ std::string decode_with_special_tokens(const std::vector<int> &ids) const;
+
+ static std::string remove_special_tokens(const std::string &text);
+
+ int get_command(const std::string &token) const;
bool is_special_id(int id) const;
- protected:
static void truncate(std::vector<int> &ids, int max_length);
public:
@@ -1009,6 +1100,8 @@ class ChatGLM3Tokenizer : public BaseTokenizer {
int user_token_id;
int assistant_token_id;
int observation_token_id;
+ std::unordered_map<std::string, int> special_tokens;
+ std::unordered_map<int, std::string> index_special_tokens;
};
using ChatGLM3Model = ChatGLM2Model;
@@ -1025,11 +1118,11 @@ class BaichuanTokenizer : public BaseTokenizer {
std::string decode(const std::vector<int> &ids) const override;
- std::vector<int> encode_history(const std::vector<std::string> &history, int max_length) const override;
+ std::vector<int> encode_messages(const std::vector<ChatMessage> &messages, int max_length) const override;
+ private:
bool is_special_id(int id) const;
- protected:
static void truncate(std::vector<int> &ids, int max_length);
public:
@@ -1047,7 +1140,7 @@ class BaichuanTokenizer : public BaseTokenizer {
using Baichuan7BAttention =
BasicAttention, false, CausalContextMasker>;
-using Baichuan7BMLP = BasicGLU;
+using Baichuan7BMLP = BasicGLU;
using Baichuan7BBlock = BasicBlock;
@@ -1073,7 +1166,7 @@ class Baichuan7BForCausalLM : public BasicModelForCausalLM {
using Baichuan13BAttention = BasicAttention;
-using Baichuan13BMLP = BasicGLU;
+using Baichuan13BMLP = BasicGLU;
using Baichuan13BBlock = BasicBlock;
@@ -1105,10 +1198,11 @@ class InternLMTokenizer : public BaseTokenizer {
std::string decode(const std::vector<int> &ids) const override;
- std::vector<int> encode_history(const std::vector<std::string> &history, int max_length) const override;
+ std::vector<int> encode_messages(const std::vector<ChatMessage> &messages, int max_length) const override;
- static std::string build_prompt(const std::vector<std::string> &history);
+ static std::string build_prompt(const std::vector<ChatMessage> &messages);
+ private:
bool is_special_id(int id) const { return id == unk_token_id || id == bos_token_id || id == eos_token_id; }
public:
@@ -1121,7 +1215,7 @@ class InternLMTokenizer : public BaseTokenizer {
using InternLM7BAttention =
BasicAttention, false, CausalContextMasker>;
-using InternLM7BMLP = BasicGLU;
+using InternLM7BMLP = BasicGLU;
using InternLM7BBlock = BasicBlock;
@@ -1130,7 +1224,7 @@ using InternLM7BModel = BasicModel, false, CausalContextMasker>;
-using InternLM20BMLP = BasicGLU;
+using InternLM20BMLP = BasicGLU;
using InternLM20BBlock = BasicBlock;
@@ -1171,7 +1265,7 @@ class Pipeline {
std::string generate(const std::string &prompt, const GenerationConfig &gen_config,
BaseStreamer *streamer = nullptr) const;
- std::string chat(const std::vector<std::string> &history, const GenerationConfig &gen_config,
+ ChatMessage chat(const std::vector<ChatMessage> &messages, const GenerationConfig &gen_config,
BaseStreamer *streamer = nullptr) const;
public:
diff --git a/chatglm_cpp/_C.pyi b/chatglm_cpp/_C.pyi
new file mode 100644
index 00000000..20a5df80
--- /dev/null
+++ b/chatglm_cpp/_C.pyi
@@ -0,0 +1,193 @@
+"""
+ChatGLM.cpp python binding
+"""
+from __future__ import annotations
+import typing
+__all__ = ['Baichuan13BForCausalLM', 'Baichuan7BForCausalLM', 'BaichuanTokenizer', 'BaseModelForCausalLM', 'BaseTokenizer', 'ChatGLM2ForCausalLM', 'ChatGLM2Tokenizer', 'ChatGLM3Tokenizer', 'ChatGLMForCausalLM', 'ChatGLMTokenizer', 'ChatMessage', 'CodeMessage', 'FunctionMessage', 'GenerationConfig', 'InternLM20BForCausalLM', 'InternLM7BForCausalLM', 'InternLMTokenizer', 'ModelConfig', 'ModelType', 'Pipeline', 'ToolCallMessage']
+class Baichuan13BForCausalLM(BaseModelForCausalLM):
+ pass
+class Baichuan7BForCausalLM(BaseModelForCausalLM):
+ pass
+class BaichuanTokenizer(BaseTokenizer):
+ pass
+class BaseModelForCausalLM:
+ def generate_next_token(self, input_ids: list[int], gen_config: GenerationConfig, n_past: int, n_ctx: int) -> int:
+ ...
+ @property
+ def config(self) -> ModelConfig:
+ ...
+class BaseTokenizer:
+ def decode(self, ids: list[int]) -> str:
+ ...
+ def decode_message(self, ids: list[int]) -> ChatMessage:
+ ...
+ def encode(self, text: str, max_length: int) -> list[int]:
+ ...
+ def encode_messages(self, messages: list[ChatMessage], max_length: int) -> list[int]:
+ ...
+class ChatGLM2ForCausalLM(BaseModelForCausalLM):
+ pass
+class ChatGLM2Tokenizer(BaseTokenizer):
+ pass
+class ChatGLM3Tokenizer(BaseTokenizer):
+ pass
+class ChatGLMForCausalLM(BaseModelForCausalLM):
+ pass
+class ChatGLMTokenizer(BaseTokenizer):
+ pass
+class ChatMessage:
+ ROLE_ASSISTANT: typing.ClassVar[str] = 'assistant'
+ ROLE_OBSERVATION: typing.ClassVar[str] = 'observation'
+ ROLE_SYSTEM: typing.ClassVar[str] = 'system'
+ ROLE_USER: typing.ClassVar[str] = 'user'
+ content: str
+ role: str
+ tool_calls: list[ToolCallMessage]
+ def __init__(self, role: str, content: str, tool_calls: list[ToolCallMessage] = []) -> None:
+ ...
+ def __repr__(self) -> str:
+ ...
+ def __str__(self) -> str:
+ ...
+class CodeMessage:
+ input: str
+ def __repr__(self) -> str:
+ ...
+ def __str__(self) -> str:
+ ...
+class FunctionMessage:
+ arguments: str
+ name: str
+ def __repr__(self) -> str:
+ ...
+ def __str__(self) -> str:
+ ...
+class GenerationConfig:
+ do_sample: bool
+ max_context_length: int
+ max_length: int
+ num_threads: int
+ repetition_penalty: float
+ temperature: float
+ top_k: int
+ top_p: float
+ def __init__(self, max_length: int = 2048, max_context_length: int = 512, do_sample: bool = True, top_k: int = 0, top_p: float = 0.7, temperature: float = 0.95, repetition_penalty: float = 1.0, num_threads: int = 0) -> None:
+ ...
+class InternLM20BForCausalLM(BaseModelForCausalLM):
+ pass
+class InternLM7BForCausalLM(BaseModelForCausalLM):
+ pass
+class InternLMTokenizer(BaseTokenizer):
+ pass
+class ModelConfig:
+ @property
+ def bos_token_id(self) -> int:
+ ...
+ @property
+ def eos_token_id(self) -> int:
+ ...
+ @property
+ def extra_eos_token_ids(self) -> list[int]:
+ ...
+ @property
+ def hidden_size(self) -> int:
+ ...
+ @property
+ def intermediate_size(self) -> int:
+ ...
+ @property
+ def max_length(self) -> int:
+ ...
+ @property
+ def model_type(self) -> ModelType:
+ ...
+ @property
+ def model_type_name(self) -> str:
+ ...
+ @property
+ def norm_eps(self) -> float:
+ ...
+ @property
+ def num_attention_heads(self) -> int:
+ ...
+ @property
+ def num_hidden_layers(self) -> int:
+ ...
+ @property
+ def num_kv_heads(self) -> int:
+ ...
+ @property
+ def pad_token_id(self) -> int:
+ ...
+ @property
+ def sep_token_id(self) -> int:
+ ...
+ @property
+ def vocab_size(self) -> int:
+ ...
+class ModelType:
+ """
+ Members:
+
+ CHATGLM
+
+ CHATGLM2
+
+ CHATGLM3
+
+ BAICHUAN7B
+
+ BAICHUAN13B
+
+ INTERNLM
+ """
+ BAICHUAN13B: typing.ClassVar[ModelType] # value = <ModelType.BAICHUAN13B: 1025>
+ BAICHUAN7B: typing.ClassVar[ModelType] # value = <ModelType.BAICHUAN7B: 1024>
+ CHATGLM: typing.ClassVar[ModelType] # value = <ModelType.CHATGLM: 1>
+ CHATGLM2: typing.ClassVar[ModelType] # value = <ModelType.CHATGLM2: 2>
+ CHATGLM3: typing.ClassVar[ModelType] # value = <ModelType.CHATGLM3: 3>
+ INTERNLM: typing.ClassVar[ModelType] # value = <ModelType.INTERNLM: 1280>
+ __members__: typing.ClassVar[dict[str, ModelType]] # value = {'CHATGLM': <ModelType.CHATGLM: 1>, 'CHATGLM2': <ModelType.CHATGLM2: 2>, 'CHATGLM3': <ModelType.CHATGLM3: 3>, 'BAICHUAN7B': <ModelType.BAICHUAN7B: 1024>, 'BAICHUAN13B': <ModelType.BAICHUAN13B: 1025>, 'INTERNLM': <ModelType.INTERNLM: 1280>}
+ def __eq__(self, other: typing.Any) -> bool:
+ ...
+ def __getstate__(self) -> int:
+ ...
+ def __hash__(self) -> int:
+ ...
+ def __index__(self) -> int:
+ ...
+ def __init__(self, value: int) -> None:
+ ...
+ def __int__(self) -> int:
+ ...
+ def __ne__(self, other: typing.Any) -> bool:
+ ...
+ def __repr__(self) -> str:
+ ...
+ def __setstate__(self, state: int) -> None:
+ ...
+ def __str__(self) -> str:
+ ...
+ @property
+ def name(self) -> str:
+ ...
+ @property
+ def value(self) -> int:
+ ...
+class Pipeline:
+ def __init__(self, path: str) -> None:
+ ...
+ @property
+ def model(self) -> BaseModelForCausalLM:
+ ...
+ @property
+ def tokenizer(self) -> BaseTokenizer:
+ ...
+class ToolCallMessage:
+ code: CodeMessage
+ function: FunctionMessage
+ type: str
+ def __repr__(self) -> str:
+ ...
+ def __str__(self) -> str:
+ ...
diff --git a/chatglm_cpp/__init__.py b/chatglm_cpp/__init__.py
index a1dc1836..d3938bd0 100644
--- a/chatglm_cpp/__init__.py
+++ b/chatglm_cpp/__init__.py
@@ -1,11 +1,29 @@
import tempfile
-import warnings
+from dataclasses import dataclass
from pathlib import Path
-from typing import Iterator, List, Optional, Union
+from typing import Any, Dict, Iterator, List, Optional, Union
import chatglm_cpp._C as _C
+from chatglm_cpp._C import ChatMessage
-__version__ = "0.2.10"
+__version__ = "0.3.0"
+
+
+@dataclass
+class DeltaMessage:
+ role: str
+ content: str
+ token_ids: List[int]
+
+
+def _ensure_chat_message(message: Union[ChatMessage, Dict[str, Any]]) -> ChatMessage:
+ if isinstance(message, ChatMessage):
+ chat_message = message
+ elif isinstance(message, dict):
+ chat_message = ChatMessage(**message)
+ else:
+ raise TypeError(f"expect message type to be ChatMessage or dict, but got {type(message)}")
+ return chat_message
class Pipeline(_C.Pipeline):
@@ -26,7 +44,7 @@ def __init__(self, model_path: str, *, dtype: Optional[str] = None) -> None:
def chat(
self,
- history: List[str],
+ messages: List[ChatMessage],
*,
max_length: int = 2048,
max_context_length: int = 512,
@@ -37,10 +55,10 @@ def chat(
repetition_penalty: float = 1.0,
num_threads: int = 0,
stream: bool = False,
- ) -> Union[Iterator[str], str]:
- input_ids = self.tokenizer.encode_history(history, max_context_length)
- return self._generate(
- input_ids=input_ids,
+ ) -> Union[Iterator[DeltaMessage], ChatMessage]:
+ messages = [_ensure_chat_message(msg) for msg in messages]
+ input_ids = self.tokenizer.encode_messages(messages, max_context_length)
+ gen_config = _C.GenerationConfig(
max_length=max_length,
max_context_length=max_context_length,
do_sample=do_sample,
@@ -49,8 +67,10 @@ def chat(
temperature=temperature,
repetition_penalty=repetition_penalty,
num_threads=num_threads,
- stream=stream,
)
+ if stream:
+ return self._stream_chat(input_ids=input_ids, gen_config=gen_config)
+ return self._sync_chat(input_ids=input_ids, gen_config=gen_config)
def generate(
self,
@@ -67,33 +87,6 @@ def generate(
stream: bool = False,
) -> Union[Iterator[str], str]:
input_ids = self.tokenizer.encode(prompt, max_context_length)
- return self._generate(
- input_ids=input_ids,
- max_length=max_length,
- max_context_length=max_context_length,
- do_sample=do_sample,
- top_k=top_k,
- top_p=top_p,
- temperature=temperature,
- repetition_penalty=repetition_penalty,
- num_threads=num_threads,
- stream=stream,
- )
-
- def _generate(
- self,
- input_ids: List[int],
- *,
- max_length: int = 2048,
- max_context_length: int = 512,
- do_sample: bool = True,
- top_k: int = 0,
- top_p: float = 0.7,
- temperature: float = 0.95,
- repetition_penalty: float = 1.0,
- num_threads: int = 0,
- stream: bool = False,
- ) -> Union[Iterator[str], str]:
gen_config = _C.GenerationConfig(
max_length=max_length,
max_context_length=max_context_length,
@@ -104,83 +97,68 @@ def _generate(
repetition_penalty=repetition_penalty,
num_threads=num_threads,
)
+ if stream:
+ return self._stream_generate(input_ids=input_ids, gen_config=gen_config)
+ return self._sync_generate(input_ids=input_ids, gen_config=gen_config)
- generate_fn = self._stream_generate if stream else self._sync_generate
- return generate_fn(input_ids=input_ids, gen_config=gen_config)
-
- def _stream_generate(self, input_ids: List[int], gen_config: _C.GenerationConfig) -> Iterator[str]:
- input_ids = [x for x in input_ids] # make a copy
+ def _stream_generate_ids(self, input_ids: List[int], gen_config: _C.GenerationConfig) -> Iterator[int]:
+ input_ids = input_ids.copy()
n_past = 0
n_ctx = len(input_ids)
- token_cache = []
- print_len = 0
while len(input_ids) < gen_config.max_length:
next_token_id = self.model.generate_next_token(input_ids, gen_config, n_past, n_ctx)
+ yield next_token_id
n_past = len(input_ids)
input_ids.append(next_token_id)
+ if next_token_id in [self.model.config.eos_token_id, *self.model.config.extra_eos_token_ids]:
+ break
+
+ def _stream_chat(self, input_ids: List[int], gen_config: _C.GenerationConfig) -> Iterator[DeltaMessage]:
+ token_cache = []
+ print_len = 0
+ print_token_len = 0
+ for next_token_id in self._stream_generate_ids(input_ids=input_ids, gen_config=gen_config):
token_cache.append(next_token_id)
output = self.tokenizer.decode(token_cache)
if output.endswith("\n"):
- yield output[print_len:]
+ yield DeltaMessage(
+ role=ChatMessage.ROLE_ASSISTANT, content=output[print_len:], token_ids=token_cache[print_token_len:]
+ )
token_cache = []
print_len = 0
+ print_token_len = 0
elif output.endswith((",", "!", ":", ";", "?", "�")):
pass
else:
- yield output[print_len:]
+ yield DeltaMessage(
+ role=ChatMessage.ROLE_ASSISTANT, content=output[print_len:], token_ids=token_cache[print_token_len:]
+ )
print_len = len(output)
-
- if next_token_id == self.model.config.eos_token_id:
- break
+ print_token_len = len(token_cache)
output = self.tokenizer.decode(token_cache)
- yield output[print_len:]
+ yield DeltaMessage(
+ role=ChatMessage.ROLE_ASSISTANT, content=output[print_len:], token_ids=token_cache[print_token_len:]
+ )
- def _sync_generate(self, input_ids: List[int], gen_config: _C.GenerationConfig) -> str:
- input_ids = [x for x in input_ids] # make a copy
- n_past = 0
- n_ctx = len(input_ids)
+ def _stream_generate(self, input_ids: List[int], gen_config: _C.GenerationConfig) -> Iterator[str]:
+ for msg in self._stream_chat(input_ids=input_ids, gen_config=gen_config):
+ yield msg.content
- while len(input_ids) < gen_config.max_length:
- next_token_id = self.model.generate_next_token(input_ids, gen_config, n_past, n_ctx)
- n_past = len(input_ids)
- input_ids.append(next_token_id)
- if next_token_id == self.model.config.eos_token_id:
- break
+ def _sync_generate_ids(self, input_ids: List[int], gen_config: _C.GenerationConfig) -> List[int]:
+ return list(self._stream_generate_ids(input_ids=input_ids, gen_config=gen_config))
- output = self.tokenizer.decode(input_ids[n_ctx:])
- return output
+ def _sync_generate(self, input_ids: List[int], gen_config: _C.GenerationConfig) -> str:
+ output_ids = self._sync_generate_ids(input_ids=input_ids, gen_config=gen_config)
+ return self.tokenizer.decode(output_ids)
- def stream_chat(
- self,
- history: List[str],
- *,
- max_length: int = 2048,
- max_context_length: int = 512,
- do_sample: bool = True,
- top_k: int = 0,
- top_p: float = 0.7,
- temperature: float = 0.95,
- repetition_penalty: float = 1.0,
- num_threads: int = 0,
- ) -> Iterator[str]:
- warnings.warn(
- "stream_chat is deprecated in favor of chat(..., stream=True), and will be removed in the next major version of chatglm-cpp",
- DeprecationWarning,
- stacklevel=2,
- )
- return self.chat(
- history=history,
- max_length=max_length,
- max_context_length=max_context_length,
- do_sample=do_sample,
- top_k=top_k,
- top_p=top_p,
- temperature=temperature,
- repetition_penalty=repetition_penalty,
- num_threads=num_threads,
- stream=True,
- )
+ def _sync_chat(self, input_ids: List[int], gen_config: _C.GenerationConfig) -> ChatMessage:
+ output_ids = self._sync_generate_ids(input_ids=input_ids, gen_config=gen_config)
+ return self.tokenizer.decode_message(output_ids)
+
+ def merge_streaming_messages(self, chunks: List[DeltaMessage]) -> ChatMessage:
+ output_ids = [x for chunk in chunks for x in chunk.token_ids]
+ return self.tokenizer.decode_message(output_ids)
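To illustrate the reworked Python API: `chat()` now takes a list of `ChatMessage` (plain dicts are also accepted via `_ensure_chat_message`) and returns a `ChatMessage`, while `stream=True` yields `DeltaMessage` chunks that can be folded back into a full message with `merge_streaming_messages`. A minimal sketch, with an illustrative model path:

```python
import chatglm_cpp

pipeline = chatglm_cpp.Pipeline("./chatglm3-ggml.bin")
messages = [{"role": "user", "content": "你好"}]  # dicts are converted to ChatMessage

# synchronous chat returns a ChatMessage
reply = pipeline.chat(messages)
print(reply.role, reply.content)

# streaming chat yields DeltaMessage chunks
chunks = []
for chunk in pipeline.chat(messages, stream=True):
    print(chunk.content, end="", flush=True)
    chunks.append(chunk)
full_reply = pipeline.merge_streaming_messages(chunks)  # ChatMessage incl. tool_calls
print()
```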
diff --git a/chatglm_cpp/convert.py b/chatglm_cpp/convert.py
index 53d5f4c7..32275b98 100644
--- a/chatglm_cpp/convert.py
+++ b/chatglm_cpp/convert.py
@@ -25,7 +25,7 @@
if platform.system() == "Darwin":
# cpm_kernels doesn't support macOS but transformers will check missing packages, so mock it
- sys.modules["cpm_kernels"] = object()
+ sys.modules["cpm_kernels"] = object() # type: ignore
class GGMLType(Enum):
@@ -47,7 +47,7 @@ class ModelType(Enum):
INTERNLM = 1280
-def quantize_q8_0(tensor: torch.Tensor) -> torch.CharTensor:
+def quantize_q8_0(tensor: torch.Tensor) -> torch.Tensor:
# equivalent to ggml_quantize_q8_0 in ggml.c
assert tensor.shape[1] % GGML_QK8_0 == 0
tensor = tensor.view(-1, GGML_QK8_0)
@@ -58,7 +58,7 @@ def quantize_q8_0(tensor: torch.Tensor) -> torch.CharTensor:
return tensor
-def quantize_q4_0(tensor: torch.Tensor) -> torch.CharTensor:
+def quantize_q4_0(tensor: torch.Tensor) -> torch.Tensor:
# equivalent to ggml_quantize_q4_0 in ggml.c
assert tensor.shape[1] % GGML_QK4_0 == 0
tensor = tensor.view(-1, GGML_QK4_0)
@@ -73,7 +73,7 @@ def quantize_q4_0(tensor: torch.Tensor) -> torch.CharTensor:
return tensor
-def quantize_q4_1(tensor: torch.Tensor) -> torch.CharTensor:
+def quantize_q4_1(tensor: torch.Tensor) -> torch.Tensor:
# equivalent to ggml_quantize_q4_1 in ggml.c
assert tensor.shape[1] % GGML_QK4_1 == 0
tensor = tensor.view(-1, GGML_QK4_1)
@@ -88,7 +88,7 @@ def quantize_q4_1(tensor: torch.Tensor) -> torch.CharTensor:
return tensor
-def quantize_q5_0(tensor: torch.Tensor) -> torch.CharTensor:
+def quantize_q5_0(tensor: torch.Tensor) -> torch.Tensor:
# equivalent to ggml_quantize_q5_0 in ggml.c
assert tensor.shape[1] % GGML_QK5_0 == 0
tensor = tensor.view(-1, GGML_QK5_0)
@@ -106,7 +106,7 @@ def quantize_q5_0(tensor: torch.Tensor) -> torch.CharTensor:
return tensor
-def quantize_q5_1(tensor: torch.Tensor) -> torch.CharTensor:
+def quantize_q5_1(tensor: torch.Tensor) -> torch.Tensor:
# equivalent to ggml_quantize_q5_1 in ggml.c
assert tensor.shape[1] % GGML_QK5_1 == 0
tensor = tensor.view(-1, GGML_QK5_1)
diff --git a/chatglm_cpp/langchain_api.py b/chatglm_cpp/langchain_api.py
index 5cb19115..ceea5988 100644
--- a/chatglm_cpp/langchain_api.py
+++ b/chatglm_cpp/langchain_api.py
@@ -53,17 +53,27 @@ class ChatResponse(BaseModel):
@app.post("/")
async def chat(body: ChatRequest) -> ChatResponse:
- chat_history = [msg for pair in body.history for msg in pair] + [body.prompt]
- response = pipeline.chat(
- chat_history,
+ messages = []
+ for prompt, response in body.history:
+ messages += [
+ chatglm_cpp.ChatMessage(role="user", content=prompt),
+ chatglm_cpp.ChatMessage(role="assistant", content=response),
+ ]
+ messages.append(chatglm_cpp.ChatMessage(role="user", content=body.prompt))
+
+ output = pipeline.chat(
+ messages,
max_length=body.max_length,
do_sample=body.temperature > 0,
top_p=body.top_p,
temperature=body.temperature,
)
- history = body.history + [(body.prompt, response)]
+ history = body.history + [(body.prompt, output.content)]
answer = ChatResponse(
- response=response, history=history, status=status.HTTP_200_OK, time=datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+ response=output.content,
+ history=history,
+ status=status.HTTP_200_OK,
+ time=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
)
- logging.info(f'prompt: "{body.prompt}", response: "{response}"')
+ logging.info(f'prompt: "{body.prompt}", response: "{output.content}"')
return answer
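The LangChain API keeps its external schema (a `prompt` plus a `history` of `(prompt, response)` pairs) and only converts to `ChatMessage` internally, so existing clients keep working. A hedged example request, assuming the server is running locally on port 8000 as in the Docker instructions and using the third-party `requests` package:

```python
import requests  # not a chatglm-cpp dependency, used only for illustration

resp = requests.post(
    "http://127.0.0.1:8000/",
    json={"prompt": "你好", "history": [], "max_length": 2048, "top_p": 0.7, "temperature": 0.95},
)
data = resp.json()
print(data["response"])  # assistant reply
print(data["history"])   # history extended with the new (prompt, response) pair
```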
diff --git a/chatglm_cpp/openai_api.py b/chatglm_cpp/openai_api.py
index fbda7e95..20976899 100644
--- a/chatglm_cpp/openai_api.py
+++ b/chatglm_cpp/openai_api.py
@@ -102,14 +102,14 @@ class ChatCompletionResponse(BaseModel):
lock = asyncio.Lock()
-def stream_chat(history, body):
+def stream_chat(messages, body):
yield ChatCompletionResponse(
object="chat.completion.chunk",
choices=[ChatCompletionResponseStreamChoice(delta=DeltaMessage(role="assistant"))],
)
- for piece in pipeline.chat(
- history,
+ for chunk in pipeline.chat(
+ messages=messages,
max_length=body.max_tokens,
do_sample=body.temperature > 0,
top_p=body.top_p,
@@ -119,7 +119,7 @@ def stream_chat(history, body):
):
yield ChatCompletionResponse(
object="chat.completion.chunk",
- choices=[ChatCompletionResponseStreamChoice(delta=DeltaMessage(content=piece))],
+ choices=[ChatCompletionResponseStreamChoice(delta=DeltaMessage(content=chunk.content))],
)
yield ChatCompletionResponse(
@@ -144,31 +144,31 @@ async def stream_chat_event_publisher(history, body):
@app.post("/v1/chat/completions")
async def create_chat_completion(body: ChatCompletionRequest) -> ChatCompletionResponse:
- # ignore system messages
- history = [msg.content for msg in body.messages if msg.role != "system"]
- if len(history) % 2 != 1:
- raise HTTPException(status.HTTP_400_BAD_REQUEST, "invalid history size")
+ if not body.messages:
+ raise HTTPException(status.HTTP_400_BAD_REQUEST, "empty messages")
+
+ messages = [chatglm_cpp.ChatMessage(role=msg.role, content=msg.content) for msg in body.messages]
if body.stream:
- generator = stream_chat_event_publisher(history, body)
+ generator = stream_chat_event_publisher(messages, body)
return EventSourceResponse(generator)
max_context_length = 512
output = pipeline.chat(
- history=history,
+ messages=messages,
max_length=body.max_tokens,
max_context_length=max_context_length,
do_sample=body.temperature > 0,
top_p=body.top_p,
temperature=body.temperature,
)
- logging.info(f'prompt: "{history[-1]}", sync response: "{output}"')
- prompt_tokens = len(pipeline.tokenizer.encode_history(history, max_context_length))
- completion_tokens = len(pipeline.tokenizer.encode(output, body.max_tokens))
+ logging.info(f'prompt: "{messages[-1].content}", sync response: "{output.content}"')
+ prompt_tokens = len(pipeline.tokenizer.encode_messages(messages, max_context_length))
+ completion_tokens = len(pipeline.tokenizer.encode(output.content, body.max_tokens))
return ChatCompletionResponse(
object="chat.completion",
- choices=[ChatCompletionResponseChoice(message=ChatMessage(role="assistant", content=output))],
+ choices=[ChatCompletionResponseChoice(message=ChatMessage(role="assistant", content=output.content))],
usage=ChatCompletionUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
)
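
With this change every incoming message, including system messages, is forwarded as a `chatglm_cpp.ChatMessage`, and only an empty message list is rejected. A hedged request sketch against the `/v1/chat/completions` route is shown below; the base URL and any request-model defaults (such as `model` or `max_tokens`) are assumptions.

```python
# Hypothetical OpenAI-style request; base URL and unset fields rely on assumed server defaults.
import requests

resp = requests.post(
    "http://127.0.0.1:8000/v1/chat/completions",
    json={
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello"},
        ],
        "temperature": 0.8,
        "stream": False,
    },
    timeout=60,
)
print(resp.json()["choices"][0]["message"]["content"])
```
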
diff --git a/chatglm_pybind.cpp b/chatglm_pybind.cpp
index 2fcd7c7d..24749b8d 100644
--- a/chatglm_pybind.cpp
+++ b/chatglm_pybind.cpp
@@ -17,8 +17,8 @@ class PyBaseTokenizer : public BaseTokenizer {
std::string decode(const std::vector<int> &ids) const override {
PYBIND11_OVERLOAD_PURE(std::string, BaseTokenizer, decode, ids);
}
- std::vector<int> encode_history(const std::vector<std::string> &history, int max_length) const override {
- PYBIND11_OVERLOAD_PURE(std::vector<int>, BaseTokenizer, encode_history, history, max_length);
+ std::vector<int> encode_messages(const std::vector<ChatMessage> &history, int max_length) const override {
+ PYBIND11_OVERLOAD_PURE(std::vector<int>, BaseTokenizer, encode_messages, history, max_length);
}
};
@@ -32,12 +32,27 @@ class PyBaseModelForCausalLM : public BaseModelForCausalLM {
}
};
+template <typename T>
+static inline std::string to_string(const T &obj) {
+ std::ostringstream oss;
+ oss << obj;
+ return oss.str();
+}
+
PYBIND11_MODULE(_C, m) {
m.doc() = "ChatGLM.cpp python binding";
+ py::enum_<ModelType>(m, "ModelType")
+ .value("CHATGLM", ModelType::CHATGLM)
+ .value("CHATGLM2", ModelType::CHATGLM2)
+ .value("CHATGLM3", ModelType::CHATGLM3)
+ .value("BAICHUAN7B", ModelType::BAICHUAN7B)
+ .value("BAICHUAN13B", ModelType::BAICHUAN13B)
+ .value("INTERNLM", ModelType::INTERNLM);
+
py::class_<ModelConfig>(m, "ModelConfig")
.def_readonly("model_type", &ModelConfig::model_type)
- .def_readonly("dtype", &ModelConfig::dtype)
+ // .def_readonly("dtype", &ModelConfig::dtype)
.def_readonly("vocab_size", &ModelConfig::vocab_size)
.def_readonly("hidden_size", &ModelConfig::hidden_size)
.def_readonly("num_attention_heads", &ModelConfig::num_attention_heads)
@@ -50,17 +65,9 @@ PYBIND11_MODULE(_C, m) {
.def_readonly("eos_token_id", &ModelConfig::eos_token_id)
.def_readonly("pad_token_id", &ModelConfig::pad_token_id)
.def_readonly("sep_token_id", &ModelConfig::sep_token_id)
+ .def_readonly("extra_eos_token_ids", &ModelConfig::extra_eos_token_ids)
.def_property_readonly("model_type_name", &ModelConfig::model_type_name);
- py::class_<BaseTokenizer, PyBaseTokenizer>(m, "BaseTokenizer")
- .def("encode", &BaseTokenizer::encode)
- .def("decode", &BaseTokenizer::decode)
- .def("encode_history", &BaseTokenizer::encode_history);
-
- py::class_<BaseModelForCausalLM, PyBaseModelForCausalLM>(m, "BaseModelForCausalLM")
- .def("generate_next_token", &BaseModelForCausalLM::generate_next_token)
- .def_readonly("config", &BaseModelForCausalLM::config);
-
py::class_<GenerationConfig>(m, "GenerationConfig")
.def(py::init<int, int, bool, int, float, float, float, int>(), "max_length"_a = 2048,
"max_context_length"_a = 512, "do_sample"_a = true, "top_k"_a = 0, "top_p"_a = 0.7, "temperature"_a = 0.95,
@@ -74,6 +81,48 @@ PYBIND11_MODULE(_C, m) {
.def_readwrite("repetition_penalty", &GenerationConfig::repetition_penalty)
.def_readwrite("num_threads", &GenerationConfig::num_threads);
+ py::class_<FunctionMessage>(m, "FunctionMessage")
+ .def("__repr__", &to_string<FunctionMessage>)
+ .def("__str__", &to_string<FunctionMessage>)
+ .def_readwrite("name", &FunctionMessage::name)
+ .def_readwrite("arguments", &FunctionMessage::arguments);
+
+ py::class_<CodeMessage>(m, "CodeMessage")
+ .def("__repr__", &to_string<CodeMessage>)
+ .def("__str__", &to_string<CodeMessage>)
+ .def_readwrite("input", &CodeMessage::input);
+
+ py::class_<ToolCallMessage>(m, "ToolCallMessage")
+ .def("__repr__", &to_string<ToolCallMessage>)
+ .def("__str__", &to_string<ToolCallMessage>)
+ .def_readwrite("type", &ToolCallMessage::type)
+ .def_readwrite("function", &ToolCallMessage::function)
+ .def_readwrite("code", &ToolCallMessage::code);
+
+ py::class_<ChatMessage>(m, "ChatMessage")
+ .def(py::init<std::string, std::string, std::vector<ToolCallMessage>>(), "role"_a, "content"_a,
+ "tool_calls"_a = std::vector<ToolCallMessage>{})
+ .def("__repr__", &to_string<ChatMessage>)
+ .def("__str__", &to_string<ChatMessage>)
+ .def_readonly_static("ROLE_SYSTEM", &ChatMessage::ROLE_SYSTEM)
+ .def_readonly_static("ROLE_USER", &ChatMessage::ROLE_USER)
+ .def_readonly_static("ROLE_ASSISTANT", &ChatMessage::ROLE_ASSISTANT)
+ .def_readonly_static("ROLE_OBSERVATION", &ChatMessage::ROLE_OBSERVATION)
+ .def_readwrite("role", &ChatMessage::role)
+ .def_readwrite("content", &ChatMessage::content)
+ .def_readwrite("tool_calls", &ChatMessage::tool_calls);
+
+ py::class_<BaseTokenizer, PyBaseTokenizer>(m, "BaseTokenizer")
+ .def("encode", &BaseTokenizer::encode, "text"_a, "max_length"_a)
+ .def("decode", &BaseTokenizer::decode, "ids"_a)
+ .def("encode_messages", &BaseTokenizer::encode_messages, "messages"_a, "max_length"_a)
+ .def("decode_message", &BaseTokenizer::decode_message, "ids"_a);
+
+ py::class_<BaseModelForCausalLM, PyBaseModelForCausalLM>(m, "BaseModelForCausalLM")
+ .def("generate_next_token", &BaseModelForCausalLM::generate_next_token, "input_ids"_a, "gen_config"_a,
+ "n_past"_a, "n_ctx"_a)
+ .def_readonly("config", &BaseModelForCausalLM::config);
+
// ===== ChatGLM =====
py::class_(m, "ChatGLMTokenizer");
@@ -109,7 +158,7 @@ PYBIND11_MODULE(_C, m) {
// ===== Pipeline ====
py::class_(m, "Pipeline")
- .def(py::init())
+ .def(py::init(), "path"_a)
.def_property_readonly("model", [](const Pipeline &self) { return self.model.get(); })
.def_property_readonly("tokenizer", [](const Pipeline &self) { return self.tokenizer.get(); });
}
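
Taken together, the bindings switch the Python API from plain-string histories to `ChatMessage` objects that may carry `tool_calls`. Below is a minimal usage sketch, not the package's documented example: it assumes a converted `chatglm3-ggml.bin` in the working directory and that the high-level `Pipeline.chat` wrapper supplies default generation settings.

```python
# Minimal sketch of the message-based API exposed above; the model path is an assumption.
import chatglm_cpp

pipeline = chatglm_cpp.Pipeline("chatglm3-ggml.bin")
messages = [chatglm_cpp.ChatMessage(role="user", content="你好")]

reply = pipeline.chat(messages)  # returns a ChatMessage rather than a plain string
print(reply.role, reply.content)

# Function call / code interpreter turns surface their payload via tool_calls.
for tool_call in reply.tool_calls:
    if tool_call.type == "function":
        print(tool_call.function.name, tool_call.function.arguments)
    elif tool_call.type == "code":
        print(tool_call.code.input)
```
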
diff --git a/chatglm_test.cpp b/chatglm_test.cpp
index 6df30289..03b8cd5d 100644
--- a/chatglm_test.cpp
+++ b/chatglm_test.cpp
@@ -1022,12 +1022,15 @@ TEST(Pipeline, ChatGLM) {
// prompter
{
- EXPECT_EQ(ChatGLMTokenizer::build_prompt({"你好"}), "你好");
- EXPECT_EQ(ChatGLMTokenizer::build_prompt(
- {"你好", "你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。",
- "晚上睡不着应该怎么办"}),
- "[Round 0]\n问:你好\n答:你好👋!我是人工智能助手 "
- "ChatGLM-6B,很高兴见到你,欢迎问我任何问题。\n[Round 1]\n问:晚上睡不着应该怎么办\n答:");
+ EXPECT_EQ(ChatGLMTokenizer::build_prompt({{ChatMessage::ROLE_USER, "你好"}}), "你好");
+ EXPECT_EQ(
+ ChatGLMTokenizer::build_prompt({
+ {ChatMessage::ROLE_USER, "你好"},
+ {ChatMessage::ROLE_ASSISTANT, "你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。"},
+ {ChatMessage::ROLE_USER, "晚上睡不着应该怎么办"},
+ }),
+ "[Round 0]\n问:你好\n答:你好👋!我是人工智能助手 "
+ "ChatGLM-6B,很高兴见到你,欢迎问我任何问题。\n[Round 1]\n问:晚上睡不着应该怎么办\n答:");
}
// memory test
@@ -1041,17 +1044,17 @@ TEST(Pipeline, ChatGLM) {
for (int i = 0; i < gen_config.max_context_length; i++) {
oss << "你好";
}
- std::vector<std::string> history{oss.str()};
- pipeline.chat(history, gen_config);
+ std::vector<ChatMessage> messages{{ChatMessage::ROLE_USER, oss.str()}};
+ pipeline.chat(messages, gen_config);
}
// chat
{
GenerationConfig gen_config;
gen_config.do_sample = false;
- std::vector<std::string> history{"你好"};
- std::string output = pipeline.chat(history, gen_config);
- EXPECT_EQ(output, "你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。");
+ std::vector<ChatMessage> messages{{ChatMessage::ROLE_USER, "你好"}};
+ ChatMessage output = pipeline.chat(messages, gen_config);
+ EXPECT_EQ(output.content, "你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。");
}
}
@@ -1083,12 +1086,15 @@ TEST(Pipeline, ChatGLM2) {
// prompter
{
- EXPECT_EQ(ChatGLM2Tokenizer::build_prompt({"你好"}), "[Round 1]\n\n问:你好\n\n答:");
- EXPECT_EQ(ChatGLM2Tokenizer::build_prompt(
- {"你好", "你好👋!我是人工智能助手 ChatGLM2-6B,很高兴见到你,欢迎问我任何问题。",
- "晚上睡不着应该怎么办"}),
- "[Round 1]\n\n问:你好\n\n答:你好👋!我是人工智能助手 "
- "ChatGLM2-6B,很高兴见到你,欢迎问我任何问题。\n\n[Round 2]\n\n问:晚上睡不着应该怎么办\n\n答:");
+ EXPECT_EQ(ChatGLM2Tokenizer::build_prompt({{ChatMessage::ROLE_USER, "你好"}}), "[Round 1]\n\n问:你好\n\n答:");
+ EXPECT_EQ(
+ ChatGLM2Tokenizer::build_prompt({
+ {ChatMessage::ROLE_USER, "你好"},
+ {ChatMessage::ROLE_ASSISTANT, "你好👋!我是人工智能助手 ChatGLM2-6B,很高兴见到你,欢迎问我任何问题。"},
+ {ChatMessage::ROLE_USER, "晚上睡不着应该怎么办"},
+ }),
+ "[Round 1]\n\n问:你好\n\n答:你好👋!我是人工智能助手 "
+ "ChatGLM2-6B,很高兴见到你,欢迎问我任何问题。\n\n[Round 2]\n\n问:晚上睡不着应该怎么办\n\n答:");
}
// memory test
@@ -1102,20 +1108,25 @@ TEST(Pipeline, ChatGLM2) {
for (int i = 0; i < gen_config.max_context_length; i++) {
oss << "你好";
}
- std::vector<std::string> history{oss.str()};
- pipeline.chat(history, gen_config);
+ std::vector<ChatMessage> messages{{ChatMessage::ROLE_USER, oss.str()}};
+ pipeline.chat(messages, gen_config);
}
// chat
{
GenerationConfig gen_config;
gen_config.do_sample = false;
- std::vector<std::string> history{"你好"};
- std::string output = pipeline.chat(history, gen_config);
- EXPECT_EQ(output, "你好👋!我是人工智能助手 ChatGLM2-6B,很高兴见到你,欢迎问我任何问题。");
+ std::vector<ChatMessage> messages{{ChatMessage::ROLE_USER, "你好"}};
+ ChatMessage output = pipeline.chat(messages, gen_config);
+ EXPECT_EQ(output.content, "你好👋!我是人工智能助手 ChatGLM2-6B,很高兴见到你,欢迎问我任何问题。");
}
}
+static inline std::string read_text(const fs::path &path) {
+ MappedFile mapped_file(path.string());
+ return std::string(mapped_file.data, mapped_file.size);
+}
+
TEST(Pipeline, ChatGLM3) {
fs::path model_path = fs::path(__FILE__).parent_path() / "chatglm3-ggml.bin";
if (!fs::exists(model_path)) {
@@ -1124,29 +1135,62 @@ TEST(Pipeline, ChatGLM3) {
Pipeline pipeline(model_path.string());
EXPECT_TRUE(dynamic_cast<ChatGLM3ForCausalLM *>(pipeline.model.get()));
+ const std::string system_tool_call =
+ read_text(fs::path(__FILE__).parent_path() / "examples/system/function_call.txt");
+ const std::string system_ci = read_text(fs::path(__FILE__).parent_path() / "examples/system/code_interpreter.txt");
+
// tokenizer
{
- std::vector<TokenizerTestCase> cases{{"你好", {64790, 64792, 36474, 54591}}};
- check_tokenizer(pipeline.tokenizer.get(), cases);
-
- {
- std::vector<std::string> history{"你好"};
- std::vector<int> input_ids = pipeline.tokenizer->encode_history(history, 2048);
- std::vector<int> target_ids{64790, 64792, 64795, 30910, 13, 36474, 54591, 64796, 30910, 13};
- EXPECT_EQ(input_ids, target_ids);
- }
- {
- std::vector<std::string> history{"你好",
- "你好👋!我是人工智能助手 ChatGLM3-6B,很高兴见到你,欢迎问我任何问题。",
- "晚上睡不着应该怎么办"};
- std::vector<int> input_ids = pipeline.tokenizer->encode_history(history, 2048);
- std::vector<int> target_ids{64790, 64792, 64795, 30910, 13, 36474, 54591, 64796, 30910, 13,
- 36474, 54591, 243, 162, 148, 142, 31404, 33030, 34797, 42481,
- 22011, 10461, 30944, 30966, 30941, 30978, 30949, 31123, 48895, 35214,
- 54622, 31123, 32616, 39905, 31901, 31639, 31155, 64795, 30910, 13,
- 30910, 32820, 54266, 31876, 35153, 64796, 30910, 13};
- EXPECT_EQ(input_ids, target_ids);
- }
+ std::vector<int> target_ids{64790, 64792, 36474, 54591};
+ std::vector<int> input_ids = pipeline.tokenizer->encode("你好", 2048);
+ EXPECT_EQ(input_ids, target_ids);
+ }
+ {
+ std::vector<ChatMessage> messages{{ChatMessage::ROLE_USER, "你好"}};
+ std::vector<int> input_ids = pipeline.tokenizer->encode_messages(messages, 2048);
+ std::vector<int> target_ids{64790, 64792, 64795, 30910, 13, 36474, 54591, 64796};
+ EXPECT_EQ(input_ids, target_ids);
+ }
+ {
+ std::vector<ChatMessage> messages{
+ {ChatMessage::ROLE_USER, "你好"},
+ {ChatMessage::ROLE_ASSISTANT, "你好👋!我是人工智能助手 ChatGLM3-6B,很高兴见到你,欢迎问我任何问题。"},
+ {ChatMessage::ROLE_USER, "晚上睡不着应该怎么办"},
+ };
+ std::vector<int> input_ids = pipeline.tokenizer->encode_messages(messages, 2048);
+ std::vector<int> target_ids{64790, 64792, 64795, 30910, 13, 36474, 54591, 64796, 30910, 13, 36474, 54591,
+ 243, 162, 148, 142, 31404, 33030, 34797, 42481, 22011, 10461, 30944, 30966,
+ 30941, 30978, 30949, 31123, 48895, 35214, 54622, 31123, 32616, 39905, 31901, 31639,
+ 31155, 64795, 30910, 13, 30910, 32820, 54266, 31876, 35153, 64796};
+ EXPECT_EQ(input_ids, target_ids);
+ }
+ {
+ std::vector<ChatMessage> messages{
+ {ChatMessage::ROLE_SYSTEM, system_tool_call},
+ {ChatMessage::ROLE_USER, "生成一个随机数"},
+ };
+ std::vector<int> input_ids = pipeline.tokenizer->encode_messages(messages, 2048);
+ std::vector<int> target_ids{
+ 64790, 64792, 64794, 30910, 13, 20115, 267, 1762, 2554, 362, 1077, 362, 344, 457, 30930,
+ 809, 431, 1675, 289, 267, 1762, 4159, 30954, 13, 30982, 13, 296, 30955, 16599, 30962,
+ 11228, 30962, 7311, 1306, 2932, 729, 13, 352, 30955, 2323, 2932, 449, 16599, 30962, 11228,
+ 30962, 7311, 1306, 1252, 13, 352, 30955, 16302, 2932, 449, 9398, 711, 260, 5402, 1276,
+ 1994, 30932, 268, 30930, 30912, 30930, 2288, 30995, 30940, 30996, 14819, 1994, 906, 2288, 30995,
+ 30939, 30996, 1252, 13, 352, 30955, 12209, 2932, 790, 13, 753, 30982, 13, 647, 30955,
+ 2323, 2932, 449, 24794, 1252, 13, 647, 30955, 16302, 2932, 449, 1036, 5402, 9352, 1050,
+ 422, 267, 17009, 1252, 13, 647, 30955, 3543, 2932, 449, 592, 1252, 13, 647, 30955,
+ 20379, 2932, 2033, 13, 753, 4143, 13, 753, 30982, 13, 647, 30955, 2323, 2932, 449,
+ 7855, 1252, 13, 647, 30955, 16302, 2932, 449, 1036, 2288, 290, 267, 7383, 3859, 1252,
+ 13, 647, 30955, 3543, 2932, 449, 30912, 16471, 30995, 592, 30932, 558, 30996, 1252, 13,
+ 647, 30955, 20379, 2932, 2033, 13, 753, 30983, 13, 352, 30996, 13, 296, 4143, 13,
+ 296, 30955, 752, 30962, 27564, 2932, 729, 13, 352, 30955, 2323, 2932, 449, 752, 30962,
+ 27564, 1252, 13, 352, 30955, 16302, 2932, 449, 4867, 267, 1465, 5100, 332, 4256, 17654,
+ 30962, 2323, 31040, 1252, 13, 352, 30955, 12209, 2932, 790, 13, 753, 30982, 13, 647,
+ 30955, 2323, 2932, 449, 17654, 30962, 2323, 1252, 13, 647, 30955, 16302, 2932, 449, 1036,
+ 1462, 290, 267, 1911, 289, 330, 580, 266, 819, 1252, 13, 647, 30955, 3543, 2932,
+ 449, 2069, 1252, 13, 647, 30955, 20379, 2932, 2033, 13, 753, 30983, 13, 352, 30996,
+ 13, 296, 30983, 13, 30983, 64795, 30910, 13, 30910, 36454, 31623, 37853, 54744, 64796};
+ EXPECT_EQ(input_ids, target_ids);
}
// memory test
@@ -1160,17 +1204,90 @@ TEST(Pipeline, ChatGLM3) {
for (int i = 0; i < gen_config.max_context_length; i++) {
oss << "你好";
}
- std::vector<std::string> history{oss.str()};
- pipeline.chat(history, gen_config);
+ std::vector<ChatMessage> messages{{ChatMessage::ROLE_USER, oss.str()}};
+ pipeline.chat(messages, gen_config);
}
// chat
{
GenerationConfig gen_config;
gen_config.do_sample = false;
- std::vector<std::string> history{"你好"};
- std::string output = pipeline.chat(history, gen_config);
- EXPECT_EQ(output, "你好👋!我是人工智能助手 ChatGLM3-6B,很高兴见到你,欢迎问我任何问题。");
+ std::vector<ChatMessage> messages{{ChatMessage::ROLE_USER, "你好"}};
+ ChatMessage output = pipeline.chat(messages, gen_config);
+ EXPECT_EQ(output.content, "你好👋!我是人工智能助手 ChatGLM3-6B,很高兴见到你,欢迎问我任何问题。");
+ }
+
+ // tool call
+ {
+ GenerationConfig gen_config;
+ gen_config.do_sample = false;
+ std::vector<ChatMessage> messages{
+ {ChatMessage::ROLE_SYSTEM, system_tool_call},
+ {ChatMessage::ROLE_USER, "生成一个随机数"},
+ };
+ {
+ ChatMessage output = pipeline.chat(messages, gen_config);
+ EXPECT_EQ(output.role, ChatMessage::ROLE_ASSISTANT);
+ EXPECT_EQ(output.content, "```python\n"
+ "tool_call(seed=42, range=(0, 100))\n"
+ "```");
+ messages.emplace_back(std::move(output));
+ }
+ messages.emplace_back(ChatMessage::ROLE_OBSERVATION, "22");
+ {
+ ChatMessage output = pipeline.chat(messages, gen_config);
+ EXPECT_EQ(output.role, ChatMessage::ROLE_ASSISTANT);
+ EXPECT_EQ(output.content, "根据您的要求,我使用随机数生成器API生成了一个在0和100之间的随机数,结果为22。");
+ }
+ }
+
+ // code interpreter
+ {
+ GenerationConfig gen_config;
+ gen_config.do_sample = false;
+ std::vector<ChatMessage> messages{
+ {ChatMessage::ROLE_SYSTEM, system_ci},
+ {ChatMessage::ROLE_USER, "列出100以内的所有质数"},
+ };
+ {
+ ChatMessage output = pipeline.chat(messages, gen_config);
+ EXPECT_EQ(output.role, ChatMessage::ROLE_ASSISTANT);
+ EXPECT_EQ(output.content, "好的,我会为您列出100以内的所有质数。\n\n质数是指只能被1和它本身整除的大于1"
+ "的整数。例如,2、3、5、7等都是质数。\n\n让我们开始吧!");
+ EXPECT_EQ(output.tool_calls.front().code.input, R"(```python
+def is_prime(n):
+ """Check if a number is prime."""
+ if n <= 1:
+ return False
+ if n <= 3:
+ return True
+ if n % 2 == 0 or n % 3 == 0:
+ return False
+ i = 5
+ while i * i <= n:
+ if n % i == 0 or n % (i + 2) == 0:
+ return False
+ i += 6
+ return True
+
+# Get all prime numbers up to 100
+primes_upto_100 = [i for i in range(2, 101) if is_prime(i)]
+primes_upto_100
+```)");
+ messages.emplace_back(std::move(output));
+ }
+ messages.emplace_back(
+ ChatMessage::ROLE_OBSERVATION,
+ "[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97]");
+ {
+ ChatMessage output = pipeline.chat(messages, gen_config);
+ EXPECT_EQ(output.role, ChatMessage::ROLE_ASSISTANT);
+ EXPECT_EQ(output.content, R"(100以内的所有质数为:
+
+$$
+2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97
+$$)");
+ }
}
}
@@ -1234,9 +1351,12 @@ TEST(Pipeline, Baichuan13B) {
1910, 73, 6011, 31169, 4315, 1766, 72, 1231, 11533, 31490, 31182, 21934}}};
check_tokenizer(pipeline.tokenizer.get(), cases);
- std::vector<std::string> history{"你好呀", "你好!很高兴和你交流。请问有什么我可以帮助你的吗?",
- "你叫什么名字?"};
- std::vector<int> input_ids = pipeline.tokenizer->encode_history(history, 2048);
+ std::vector<ChatMessage> messages{
+ {ChatMessage::ROLE_USER, "你好呀"},
+ {ChatMessage::ROLE_ASSISTANT, "你好!很高兴和你交流。请问有什么我可以帮助你的吗?"},
+ {ChatMessage::ROLE_USER, "你叫什么名字?"},
+ };
+ std::vector<int> input_ids = pipeline.tokenizer->encode_messages(messages, 2048);
+ std::vector<int> target_input_ids{195, 9875, 31213, 32889, 196, 9875, 31213, 74, 17318, 31906,
14822, 5536, 73, 20389, 7713, 31182, 1231, 4090, 2689, 31763,
75, 195, 9875, 32177, 1534, 10240, 75, 196};
@@ -1259,9 +1379,9 @@ TEST(Pipeline, Baichuan13B) {
GenerationConfig gen_config;
gen_config.do_sample = false;
gen_config.repetition_penalty = 1.1;
- std::vector<std::string> history{"你好呀"};
- std::string output = pipeline.chat(history, gen_config);
- EXPECT_EQ(output, "你好!很高兴见到你。请问有什么我可以帮助你的吗?");
+ std::vector<ChatMessage> messages{{ChatMessage::ROLE_USER, "你好呀"}};
+ ChatMessage output = pipeline.chat(messages, gen_config);
+ EXPECT_EQ(output.content, "你好!很高兴见到你。请问有什么我可以帮助你的吗?");
}
}
@@ -1285,9 +1405,12 @@ TEST(Pipeline, Baichuan2_7B) {
2089, 23672, 1940, 1760, 66, 4173, 23181, 1754, 65, 65351, 39975, 14590}}};
check_tokenizer(pipeline.tokenizer.get(), cases);
- std::vector<std::string> history{"你好呀", "你好!很高兴和你交流。请问有什么问题我可以帮助你解决吗?",
- "你叫什么名字?"};
- std::vector