diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 1d383115..972060d0 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -36,7 +36,7 @@ jobs: - name: Lint with black uses: psf/black@stable with: - options: "--check --verbose --line-length 120" + options: "--check --verbose" src: "chatglm_cpp examples tests setup.py" - name: Test with pytest run: | diff --git a/.gitignore b/.gitignore index 47826aca..31d3823e 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ __pycache__/ *.egg-info/ dist/ *.so +*.whl .hypothesis/ # cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 8b7ac72e..373b98c7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -92,7 +92,19 @@ file(GLOB PY_SOURCES add_custom_target(lint COMMAND clang-format -i ${CPP_SOURCES} COMMAND isort ${PY_SOURCES} - COMMAND black ${PY_SOURCES} --line-length 120) + COMMAND black ${PY_SOURCES} --verbose) + +# mypy +add_custom_target(mypy + mypy chatglm_cpp examples --exclude __init__.pyi + WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} +) + +# stub +add_custom_target(stub + pybind11-stubgen chatglm_cpp -o . + WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} +) if (MSVC) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall") diff --git a/README.md b/README.md index 3c05e02d..fc87692c 100644 --- a/README.md +++ b/README.md @@ -105,11 +105,70 @@ python3 chatglm_cpp/convert.py -i THUDM/chatglm2-6b -t q4_0 -o chatglm2-ggml.bin
ChatGLM3-6B +ChatGLM3-6B further supports function call and code interpreter in addition to chat mode. + +Chat mode: ```sh python3 chatglm_cpp/convert.py -i THUDM/chatglm3-6b -t q4_0 -o chatglm3-ggml.bin ./build/bin/main -m chatglm3-ggml.bin -p 你好 --top_p 0.8 --temp 0.8 # 你好👋!我是人工智能助手 ChatGLM3-6B,很高兴见到你,欢迎问我任何问题。 ``` + +Setting system prompt: +```sh +./build/bin/main -m chatglm3-ggml.bin -p 你好 -s "You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown." +# 你好👋!我是 ChatGLM3,有什么问题可以帮您解答吗? +``` + +Function call: +~~~ +$ ./build/bin/main -m chatglm3-ggml.bin --top_p 0.8 --temp 0.8 --sp examples/system/function_call.txt -i +System > Answer the following questions as best as you can. You have access to the following tools: ... +Prompt > 生成一个随机数 +ChatGLM3 > random_number_generator +```python +tool_call(seed=42, range=(0, 100)) +``` +Tool Call > Please manually call function `random_number_generator` with args `tool_call(seed=42, range=(0, 100))` and provide the results below. +Observation > 23 +ChatGLM3 > 根据您的要求,我使用随机数生成器API生成了一个随机数。根据API返回结果,生成的随机数为23。 +~~~ + +Code interpreter: +~~~ +$ ./build/bin/main -m chatglm3-ggml.bin --top_p 0.8 --temp 0.8 --sp examples/system/code_interpreter.txt -i +System > 你是一位智能AI助手,你叫ChatGLM,你连接着一台电脑,但请注意不能联网。在使用Python解决任务时,你可以运行代码并得到结果,如果运行结果有错误,你需要尽可能对代码进行改进。你可以处理用户上传到电脑上的文件,文件默认存储路径是/mnt/data/。 +Prompt > 列出100以内的所有质数 +ChatGLM3 > 好的,我会为您列出100以内的所有质数。 +```python +def is_prime(n): + """Check if a number is prime.""" + if n <= 1: + return False + if n <= 3: + return True + if n % 2 == 0 or n % 3 == 0: + return False + i = 5 + while i * i <= n: + if n % i == 0 or n % (i + 2) == 0: + return False + i += 6 + return True + +primes_upto_100 = [i for i in range(2, 101) if is_prime(i)] +primes_upto_100 +``` + +Code Interpreter > Please manually run the code and provide the results below. +Observation > [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97] +ChatGLM3 > 100以内的所有质数为: + +$$ +2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97 +$$ +~~~ +
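The same tool-call loop can also be driven programmatically through the Python bindings described later in this document. The sketch below is illustrative only: the model path, the `examples/system/function_call.txt` prompt, and the hard-coded `23` observation are assumptions, and the tool itself is executed by hand exactly as in the CLI session above.
```python
# Sketch: drive the ChatGLM3 function-call round trip from Python.
# Assumptions: chatglm3-ggml.bin in the repo root, the bundled system prompt file,
# and a fake "23" tool result standing in for a real random_number_generator.
from pathlib import Path

import chatglm_cpp

pipeline = chatglm_cpp.Pipeline("chatglm3-ggml.bin")
system = Path("examples/system/function_call.txt").read_text(encoding="utf-8")

messages = [
    chatglm_cpp.ChatMessage(role=chatglm_cpp.ChatMessage.ROLE_SYSTEM, content=system),
    chatglm_cpp.ChatMessage(role=chatglm_cpp.ChatMessage.ROLE_USER, content="生成一个随机数"),
]

reply = pipeline.chat(messages, temperature=0.8, top_p=0.8)
messages.append(reply)

if reply.tool_calls and reply.tool_calls[0].type == "function":
    call = reply.tool_calls[0].function
    print("model requested:", call.name, call.arguments)
    # Run the tool yourself, then feed the result back as an observation round
    # so the model can compose its final answer.
    messages.append(
        chatglm_cpp.ChatMessage(role=chatglm_cpp.ChatMessage.ROLE_OBSERVATION, content="23")
    )
    final = pipeline.chat(messages, temperature=0.8, top_p=0.8)
    print(final.content)
```
The assistant reply exposes the parsed tool call via `tool_calls`, so the caller decides how (or whether) to execute it before returning the result as an observation message.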
@@ -251,7 +310,7 @@ pip install . Pre-built wheels for CPU backend on Linux / MacOS / Windows are published on [release](https://github.com/li-plus/chatglm.cpp/releases). For CUDA / Metal backends, please compile from source code or source distribution. -**Using pre-converted ggml models** +**Using Pre-converted GGML Models** Here is a simple demo that uses `chatglm_cpp.Pipeline` to load the GGML model and chat with it. First enter the examples folder (`cd examples`) and launch a Python interactive shell: ```python @@ -264,7 +323,7 @@ Here is a simple demo that uses `chatglm_cpp.Pipeline` to load the GGML model an To chat in stream, run the below Python example: ```sh -python3 cli_chat.py -m ../chatglm-ggml.bin -i +python3 cli_demo.py -m ../chatglm-ggml.bin -i ``` Launch a web demo to chat in your browser: @@ -280,7 +339,7 @@ For other models: ChatGLM2-6B ```sh -python3 cli_chat.py -m ../chatglm2-ggml.bin -p 你好 --temp 0.8 --top_p 0.8 # CLI demo +python3 cli_demo.py -m ../chatglm2-ggml.bin -p 你好 --temp 0.8 --top_p 0.8 # CLI demo python3 web_demo.py -m ../chatglm2-ggml.bin --temp 0.8 --top_p 0.8 # web demo ```
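Note that `Pipeline.chat` now takes a list of chat messages rather than a flat string history. A minimal sketch, assuming a converted `chatglm2-ggml.bin` one directory up; plain `role`/`content` dicts are also accepted and converted internally by the helper added in `chatglm_cpp/__init__.py` later in this diff:
```python
# Sketch of the message-based chat API (model path is an assumption).
import chatglm_cpp

pipeline = chatglm_cpp.Pipeline("../chatglm2-ggml.bin")

# ChatMessage objects and plain {"role", "content"} dicts are both accepted.
messages = [
    chatglm_cpp.ChatMessage(role="user", content="你好"),
    chatglm_cpp.ChatMessage(role="assistant", content="你好👋!我是人工智能助手,很高兴见到你,欢迎问我任何问题。"),
    {"role": "user", "content": "晚上睡不着应该怎么办"},
]

reply = pipeline.chat(messages, temperature=0.8, top_p=0.8)
print(reply.content)
```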
@@ -288,10 +347,40 @@ python3 web_demo.py -m ../chatglm2-ggml.bin --temp 0.8 --top_p 0.8 # web demo
ChatGLM3-6B +**CLI Demo** + +Chat mode: +```sh +python3 cli_demo.py -m ../chatglm3-ggml.bin -p 你好 --temp 0.8 --top_p 0.8 +``` + +Function call: ```sh -python3 cli_chat.py -m ../chatglm3-ggml.bin -p 你好 --temp 0.8 --top_p 0.8 # CLI demo -python3 web_demo.py -m ../chatglm3-ggml.bin --temp 0.8 --top_p 0.8 # web demo +python3 cli_demo.py -m ../chatglm3-ggml.bin --temp 0.8 --top_p 0.8 --sp system/function_call.txt -i ``` + +Code interpreter: +```sh +python3 cli_demo.py -m ../chatglm3-ggml.bin --temp 0.8 --top_p 0.8 --sp system/code_interpreter.txt -i +``` + +**Web Demo** + +Install the Python dependencies and the IPython kernel required by the code interpreter. +```sh +pip install streamlit jupyter_client ipython ipykernel +ipython kernel install --name chatglm3-demo --user +``` + +Launch the web demo: +```sh +streamlit run chatglm3_demo.py +``` + +| Function Call | Code Interpreter | +|-----------------------------|--------------------------------| +| ![](docs/function_call.png) | ![](docs/code_interpreter.png) | +
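Streaming works the same way from Python: with `stream=True`, `chat` yields `DeltaMessage` chunks that can be printed as they arrive and later stitched back into a full `ChatMessage`. A minimal sketch, assuming the converted `chatglm3-ggml.bin` from above:
```python
# Sketch of streaming chat with DeltaMessage chunks (model path is an assumption).
import chatglm_cpp

pipeline = chatglm_cpp.Pipeline("../chatglm3-ggml.bin")
messages = [chatglm_cpp.ChatMessage(role="user", content="你好")]

chunks = []
for chunk in pipeline.chat(messages, temperature=0.8, top_p=0.8, stream=True):
    chunks.append(chunk)                     # each DeltaMessage carries role, content, token_ids
    print(chunk.content, end="", flush=True)
print()

# Re-assemble the streamed chunks into a full ChatMessage (tool calls included, if any).
reply = pipeline.merge_streaming_messages(chunks)
messages.append(reply)
```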
@@ -299,7 +388,7 @@ python3 web_demo.py -m ../chatglm3-ggml.bin --temp 0.8 --top_p 0.8 # web demo ```sh # CLI demo -python3 cli_chat.py -m ../codegeex2-ggml.bin --temp 0 --mode generate -p "\ +python3 cli_demo.py -m ../codegeex2-ggml.bin --temp 0 --mode generate -p "\ # language: Python # write a bubble sort function " @@ -312,7 +401,7 @@ python3 web_demo.py -m ../codegeex2-ggml.bin --temp 0 --max_length 512 --mode ge Baichuan-13B-Chat ```sh -python3 cli_chat.py -m ../baichuan-13b-chat-ggml.bin -p 你好 --top_k 5 --top_p 0.85 --temp 0.3 --repeat_penalty 1.1 # CLI demo +python3 cli_demo.py -m ../baichuan-13b-chat-ggml.bin -p 你好 --top_k 5 --top_p 0.85 --temp 0.3 --repeat_penalty 1.1 # CLI demo python3 web_demo.py -m ../baichuan-13b-chat-ggml.bin --top_k 5 --top_p 0.85 --temp 0.3 --repeat_penalty 1.1 # web demo ```
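The generation mode used by CodeGeeX2 above maps onto `Pipeline.generate`, which completes a raw prompt without applying any chat template. A minimal sketch, assuming a converted `codegeex2-ggml.bin`:
```python
# Sketch of plain generation mode for code completion (model path is an assumption).
import chatglm_cpp

pipeline = chatglm_cpp.Pipeline("../codegeex2-ggml.bin")
prompt = "# language: Python\n# write a bubble sort function\n"

# do_sample=False requests greedy decoding (the CLI demo above achieves a similar
# effect with --temp 0); stream=True yields text pieces as they are generated.
for piece in pipeline.generate(prompt, do_sample=False, max_length=512, stream=True):
    print(piece, end="", flush=True)
print()
```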
@@ -321,7 +410,7 @@ python3 web_demo.py -m ../baichuan-13b-chat-ggml.bin --top_k 5 --top_p 0.85 --te Baichuan2-7B-Chat ```sh -python3 cli_chat.py -m ../baichuan2-7b-chat-ggml.bin -p 你好 --top_k 5 --top_p 0.85 --temp 0.3 --repeat_penalty 1.05 # CLI demo +python3 cli_demo.py -m ../baichuan2-7b-chat-ggml.bin -p 你好 --top_k 5 --top_p 0.85 --temp 0.3 --repeat_penalty 1.05 # CLI demo python3 web_demo.py -m ../baichuan2-7b-chat-ggml.bin --top_k 5 --top_p 0.85 --temp 0.3 --repeat_penalty 1.05 # web demo ``` @@ -330,7 +419,7 @@ python3 web_demo.py -m ../baichuan2-7b-chat-ggml.bin --top_k 5 --top_p 0.85 --te Baichuan2-13B-Chat ```sh -python3 cli_chat.py -m ../baichuan2-13b-chat-ggml.bin -p 你好 --top_k 5 --top_p 0.85 --temp 0.3 --repeat_penalty 1.05 # CLI demo +python3 cli_demo.py -m ../baichuan2-13b-chat-ggml.bin -p 你好 --top_k 5 --top_p 0.85 --temp 0.3 --repeat_penalty 1.05 # CLI demo python3 web_demo.py -m ../baichuan2-13b-chat-ggml.bin --top_k 5 --top_p 0.85 --temp 0.3 --repeat_penalty 1.05 # web demo ``` @@ -339,7 +428,7 @@ python3 web_demo.py -m ../baichuan2-13b-chat-ggml.bin --top_k 5 --top_p 0.85 --t InternLM-Chat-7B ```sh -python3 cli_chat.py -m ../internlm-chat-7b-ggml.bin -p 你好 --top_p 0.8 --temp 0.8 # CLI demo +python3 cli_demo.py -m ../internlm-chat-7b-ggml.bin -p 你好 --top_p 0.8 --temp 0.8 # CLI demo python3 web_demo.py -m ../internlm-chat-7b-ggml.bin --top_p 0.8 --temp 0.8 # web demo ``` @@ -348,12 +437,12 @@ python3 web_demo.py -m ../internlm-chat-7b-ggml.bin --top_p 0.8 --temp 0.8 # we InternLM-Chat-20B ```sh -python3 cli_chat.py -m ../internlm-chat-20b-ggml.bin -p 你好 --top_p 0.8 --temp 0.8 # CLI demo +python3 cli_demo.py -m ../internlm-chat-20b-ggml.bin -p 你好 --top_p 0.8 --temp 0.8 # CLI demo python3 web_demo.py -m ../internlm-chat-20b-ggml.bin --top_p 0.8 --temp 0.8 # web demo ``` -**Load and optimize Hugging Face LLMs in one line of code** +**Converting Hugging Face LLMs at Runtime** Sometimes it might be inconvenient to convert and save the intermediate GGML models beforehand. Here is an option to directly load from the original Hugging Face model, quantize it into GGML models in a minute, and start serving. All you need is to replace the GGML model path with the Hugging Face model name or path. ```python @@ -369,7 +458,7 @@ Processing model states: 100%|████████████████ Likewise, replace the GGML model path with Hugging Face model in any example script, and it just works. For example: ```sh -python3 cli_chat.py -m THUDM/chatglm-6b -p 你好 -i +python3 cli_demo.py -m THUDM/chatglm-6b -p 你好 -i ``` ## API Server @@ -443,7 +532,7 @@ docker build . 
--network=host -t chatglm.cpp # cpp demo docker run -it --rm -v $PWD:/opt chatglm.cpp ./build/bin/main -m /opt/chatglm-ggml.bin -p "你好" # python demo -docker run -it --rm -v $PWD:/opt chatglm.cpp python3 examples/cli_chat.py -m /opt/chatglm-ggml.bin -p "你好" +docker run -it --rm -v $PWD:/opt chatglm.cpp python3 examples/cli_demo.py -m /opt/chatglm-ggml.bin -p "你好" # langchain api server docker run -it --rm -v $PWD:/opt -p 8000:8000 -e MODEL=/opt/chatglm-ggml.bin chatglm.cpp \ uvicorn chatglm_cpp.langchain_api:app --host 0.0.0.0 --port 8000 diff --git a/chatglm.cpp b/chatglm.cpp index 0fdcdd2b..3cac2a41 100644 --- a/chatglm.cpp +++ b/chatglm.cpp @@ -136,6 +136,23 @@ ggml_tensor *tensor_to_cpu(ggml_tensor *tensor) { return tensor; } +const std::string ToolCallMessage::TYPE_FUNCTION = "function"; +const std::string ToolCallMessage::TYPE_CODE = "code"; + +const std::string ChatMessage::ROLE_USER = "user"; +const std::string ChatMessage::ROLE_ASSISTANT = "assistant"; +const std::string ChatMessage::ROLE_SYSTEM = "system"; +const std::string ChatMessage::ROLE_OBSERVATION = "observation"; + +void BaseTokenizer::check_chat_messages(const std::vector &messages) { + CHATGLM_CHECK(messages.size() % 2 == 1) << "invalid chat messages size " << messages.size(); + for (size_t i = 0; i < messages.size(); i++) { + const std::string &target_role = (i % 2 == 0) ? ChatMessage::ROLE_USER : ChatMessage::ROLE_ASSISTANT; + CHATGLM_CHECK(messages[i].role == target_role) + << "expect messages[" << i << "].role to be " << target_role << ", but got " << messages[i].role; + } +} + // Adapted from https://github.com/ggerganov/llama.cpp/blob/master/llama.cpp void ggml_graph_compute_helper(std::vector &buf, ggml_cgraph *graph, int n_threads) { struct ggml_cplan plan = ggml_graph_plan(graph, n_threads); @@ -192,6 +209,24 @@ void StreamerGroup::end() { } } +// reference: https://stackoverflow.com/questions/216823/how-to-trim-a-stdstring + +// trim from start (in place) +static inline void ltrim(std::string &s) { + s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) { return !std::isspace(ch); })); +} + +// trim from end (in place) +static inline void rtrim(std::string &s) { + s.erase(std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) { return !std::isspace(ch); }).base(), s.end()); +} + +// trim from both ends (in place) +static inline void trim(std::string &s) { + rtrim(s); + ltrim(s); +} + void TextStreamer::put(const std::vector &output_ids) { if (is_prompt_) { // skip prompt @@ -203,6 +238,9 @@ void TextStreamer::put(const std::vector &output_ids) { token_cache_.insert(token_cache_.end(), output_ids.begin(), output_ids.end()); std::string text = tokenizer_->decode(token_cache_); + if (is_first_line_) { + ltrim(text); + } if (text.empty()) { return; } @@ -211,6 +249,7 @@ void TextStreamer::put(const std::vector &output_ids) { if (text.back() == '\n') { // flush the cache after newline printable_text = text.substr(print_len_); + is_first_line_ = false; token_cache_.clear(); print_len_ = 0; } else if (std::find(puncts.begin(), puncts.end(), text.back()) != puncts.end()) { @@ -227,8 +266,12 @@ void TextStreamer::put(const std::vector &output_ids) { void TextStreamer::end() { std::string text = tokenizer_->decode(token_cache_); + if (is_first_line_) { + ltrim(text); + } os_ << text.substr(print_len_) << std::endl; is_prompt_ = true; + is_first_line_ = true; token_cache_.clear(); print_len_ = 0; } @@ -418,20 +461,20 @@ int get_default_num_threads() { std::string to_string(ModelType model_type) { 
switch (model_type) { - case MODEL_TYPE_CHATGLM: + case ModelType::CHATGLM: return "ChatGLM"; - case MODEL_TYPE_CHATGLM2: + case ModelType::CHATGLM2: return "ChatGLM2"; - case MODEL_TYPE_CHATGLM3: + case ModelType::CHATGLM3: return "ChatGLM3"; - case MODEL_TYPE_BAICHUAN7B: + case ModelType::BAICHUAN7B: return "Baichuan7B"; - case MODEL_TYPE_BAICHUAN13B: + case ModelType::BAICHUAN13B: return "Baichuan13B"; - case MODEL_TYPE_INTERNLM: + case ModelType::INTERNLM: return "InternLM"; default: - CHATGLM_THROW << "unknown model type " << model_type; + CHATGLM_THROW << "unknown model type " << (int)model_type; } } @@ -632,7 +675,9 @@ std::vector BaseModelForCausalLM::generate(const std::vector &input_id streamer->put({next_token_id}); } - if (next_token_id == config.eos_token_id) { + if (next_token_id == config.eos_token_id || + std::find(config.extra_eos_token_ids.begin(), config.extra_eos_token_ids.end(), next_token_id) != + config.extra_eos_token_ids.end()) { break; } } @@ -669,23 +714,23 @@ std::vector ChatGLMTokenizer::encode(const std::string &text, int max_lengt return ids; } -std::vector ChatGLMTokenizer::encode_history(const std::vector &history, int max_length) const { - std::string prompt = build_prompt(history); +std::vector ChatGLMTokenizer::encode_messages(const std::vector &messages, int max_length) const { + std::string prompt = build_prompt(messages); std::vector input_ids = encode(prompt, max_length); return input_ids; } -std::string ChatGLMTokenizer::build_prompt(const std::vector &history) { - CHATGLM_CHECK(history.size() % 2 == 1) << "invalid history size " << history.size(); +std::string ChatGLMTokenizer::build_prompt(const std::vector &messages) { + check_chat_messages(messages); std::ostringstream oss_prompt; - if (history.size() == 1) { - oss_prompt << history.front(); + if (messages.size() == 1) { + oss_prompt << messages.front().content; } else { - for (size_t i = 0; i < history.size(); i += 2) { - oss_prompt << "[Round " << i / 2 << "]\n问:" << history[i] << "\n答:"; - if (i < history.size() - 1) { - oss_prompt << history[i + 1] << "\n"; + for (size_t i = 0; i < messages.size(); i += 2) { + oss_prompt << "[Round " << i / 2 << "]\n问:" << messages[i].content << "\n答:"; + if (i + 1 < messages.size()) { + oss_prompt << messages[i + 1].content << "\n"; } } } @@ -909,20 +954,20 @@ std::string ChatGLM2Tokenizer::decode(const std::vector &ids) const { return text; } -std::vector ChatGLM2Tokenizer::encode_history(const std::vector &history, int max_length) const { - std::string prompt = build_prompt(history); +std::vector ChatGLM2Tokenizer::encode_messages(const std::vector &messages, int max_length) const { + std::string prompt = build_prompt(messages); std::vector input_ids = encode(prompt, max_length); return input_ids; } -std::string ChatGLM2Tokenizer::build_prompt(const std::vector &history) { - CHATGLM_CHECK(history.size() % 2 == 1) << "invalid history size " << history.size(); +std::string ChatGLM2Tokenizer::build_prompt(const std::vector &messages) { + check_chat_messages(messages); std::ostringstream oss_prompt; - for (size_t i = 0; i < history.size(); i += 2) { - oss_prompt << "[Round " << i / 2 + 1 << "]\n\n问:" << history[i] << "\n\n答:"; - if (i < history.size() - 1) { - oss_prompt << history[i + 1] << "\n\n"; + for (size_t i = 0; i < messages.size(); i += 2) { + oss_prompt << "[Round " << i / 2 + 1 << "]\n\n问:" << messages[i].content << "\n\n答:"; + if (i < messages.size() - 1) { + oss_prompt << messages[i + 1].content << "\n\n"; } } return oss_prompt.str(); @@ -1014,6 
+1059,22 @@ ChatGLM3Tokenizer::ChatGLM3Tokenizer(std::string_view serialized_model_proto) { user_token_id = special_id++; assistant_token_id = special_id++; observation_token_id = special_id++; + + special_tokens = { + {"[MASK]", mask_token_id}, + {"[gMASK]", gmask_token_id}, + {"[sMASK]", smask_token_id}, + {"sop", sop_token_id}, + {"eop", eop_token_id}, + {"<|system|>", system_token_id}, + {"<|user|>", user_token_id}, + {"<|assistant|>", assistant_token_id}, + {"<|observation|>", observation_token_id}, + }; + + for (const auto &item : special_tokens) { + index_special_tokens[item.second] = item.first; + } } std::vector ChatGLM3Tokenizer::encode(const std::string &text, int max_length) const { @@ -1025,44 +1086,138 @@ std::vector ChatGLM3Tokenizer::encode(const std::string &text, int max_leng } std::string ChatGLM3Tokenizer::decode(const std::vector &ids) const { - // filter out special tokens - std::vector normal_ids(ids); - normal_ids.erase(std::remove_if(normal_ids.begin(), normal_ids.end(), [this](int id) { return is_special_id(id); }), - normal_ids.end()); + std::string text = decode_with_special_tokens(ids); + text = remove_special_tokens(text); + return text; +} - std::string text; - sp.Decode(normal_ids, &text); - text = replace_punctuations(text); +std::string ChatGLM3Tokenizer::decode_with_special_tokens(const std::vector &ids) const { + std::vector pieces; + for (int id : ids) { + auto pos = index_special_tokens.find(id); + if (pos != index_special_tokens.end()) { + // special tokens + pieces.emplace_back(pos->second); + } else { + // normal tokens + pieces.emplace_back(sp.IdToPiece(id)); + } + } + + std::string text = sp.DecodePieces(pieces); return text; } -std::vector ChatGLM3Tokenizer::encode_history(const std::vector &history, int max_length) const { - // TODO: need a new api for system / tools / metadata prompt +std::string ChatGLM3Tokenizer::remove_special_tokens(const std::string &text) { + std::string output = text; + static const std::vector special_token_regex{ + // std::regex(R"(<\|assistant\|> interpreter)"), + // std::regex(R"(<\|assistant\|> interpre)"), + std::regex(R"(<\|assistant\|>)"), + std::regex(R"(<\|user\|>)"), + std::regex(R"(<\|observation\|>)"), + }; + for (const auto &re : special_token_regex) { + output = std::regex_replace(output, re, ""); + } + return output; +} + +std::vector ChatGLM3Tokenizer::encode_single_message(const std::string &role, const std::string &content) const { + std::vector input_ids; + input_ids.emplace_back(get_command("<|" + role + "|>")); + // TODO: support metadata std::vector newline_ids; sp.Encode("\n", &newline_ids); + input_ids.insert(input_ids.end(), newline_ids.begin(), newline_ids.end()); + std::vector content_ids; + sp.Encode(content, &content_ids); + input_ids.insert(input_ids.end(), content_ids.begin(), content_ids.end()); + return input_ids; +} + +std::vector ChatGLM3Tokenizer::encode_messages(const std::vector &messages, int max_length) const { std::vector input_ids{gmask_token_id, sop_token_id}; - for (size_t i = 0; i < history.size(); i++) { - // TODO: support all roles - input_ids.emplace_back((i % 2 == 0) ? 
user_token_id : assistant_token_id); - // TODO: support metadata - input_ids.insert(input_ids.end(), newline_ids.begin(), newline_ids.end()); - std::vector content_ids; - sp.Encode(history[i], &content_ids); - input_ids.insert(input_ids.end(), content_ids.begin(), content_ids.end()); + for (const auto &msg : messages) { + auto msg_ids = encode_single_message(msg.role, msg.content); + input_ids.insert(input_ids.end(), msg_ids.begin(), msg_ids.end()); + + // encode code block into a separate message + if (!msg.tool_calls.empty() && msg.tool_calls.front().type == ToolCallMessage::TYPE_CODE) { + auto code_ids = encode_single_message(msg.role, msg.tool_calls.front().code.input); + input_ids.insert(input_ids.end(), code_ids.begin(), code_ids.end()); + } } input_ids.emplace_back(assistant_token_id); - // NOTE: push '\n' into input_ids to avoid model generating it, saving 2 tokens - input_ids.insert(input_ids.end(), newline_ids.begin(), newline_ids.end()); truncate(input_ids, max_length); return input_ids; } -bool ChatGLM3Tokenizer::is_special_id(int id) const { - return id == mask_token_id || id == gmask_token_id || id == smask_token_id || id == sop_token_id || - id == eop_token_id || id == system_token_id || id == user_token_id || id == assistant_token_id || - id == observation_token_id; +ChatMessage ChatGLM3Tokenizer::decode_message(const std::vector &ids) const { + ChatMessage message; + if (!ids.empty() && ids.back() == observation_token_id) { + // insert an <|assistant|> token before content to match possible interpreter delimiter + std::vector full_ids{assistant_token_id}; + full_ids.insert(full_ids.end(), ids.begin(), ids.end()); + + std::string output = decode_with_special_tokens(full_ids); + const std::string ci_delim = "<|assistant|> interpreter"; + size_t ci_pos = output.find(ci_delim); + if (ci_pos != std::string::npos) { + // code interpreter + std::string chat_output = output.substr(0, ci_pos); + chat_output = remove_special_tokens(chat_output); + trim(chat_output); + std::string code_output = output.substr(ci_pos + ci_delim.size()); + code_output = remove_special_tokens(code_output); + trim(code_output); + message = ChatMessage(ChatMessage::ROLE_ASSISTANT, std::move(chat_output), + {ToolCallMessage(CodeMessage(std::move(code_output)))}); + } else { + // tool call + output = remove_special_tokens(output); + + // parse tool name + std::string tool_name = "PARSE_ERROR"; + size_t pos = output.find('\n'); + if (pos != std::string::npos) { + // split tool name and args by 1st linebreak + tool_name = output.substr(0, pos); + trim(tool_name); + output.erase(0, pos + 1); + } + + // post process output + trim(output); + + // extract args + std::string tool_args = "PARSE_ERROR"; + static const std::regex args_regex(R"(```.*?\n(.*?)\n```)"); + std::smatch sm; + if (std::regex_search(output, sm, args_regex)) { + CHATGLM_CHECK(sm.size() == 2) << "unexpected regex match results"; + tool_args = sm[1]; + } + + message = ChatMessage(ChatMessage::ROLE_ASSISTANT, std::move(output), + {ToolCallMessage(FunctionMessage(std::move(tool_name), std::move(tool_args)))}); + } + } else { + // conversation + message = BaseTokenizer::decode_message(ids); + trim(message.content); // strip leading linebreak in conversation mode + } + return message; } +int ChatGLM3Tokenizer::get_command(const std::string &token) const { + auto pos = special_tokens.find(token); + CHATGLM_CHECK(pos != special_tokens.end()) << token << " is not a special token"; + return pos->second; +} + +bool ChatGLM3Tokenizer::is_special_id(int id) 
const { return index_special_tokens.count(id) > 0; } + void ChatGLM3Tokenizer::truncate(std::vector &ids, int max_length) { if ((int)ids.size() > max_length) { // sliding window: drop the least recent history while keeping the two special prefix tokens @@ -1095,18 +1250,14 @@ std::string BaichuanTokenizer::decode(const std::vector &ids) const { return text; } -std::vector BaichuanTokenizer::encode_history(const std::vector &history, int max_length) const { - CHATGLM_CHECK(history.size() % 2 == 1) << "invalid history size " << history.size(); +std::vector BaichuanTokenizer::encode_messages(const std::vector &messages, int max_length) const { + check_chat_messages(messages); std::vector ids; ids.reserve(max_length); - for (size_t i = 0; i < history.size(); i++) { - if (i % 2 == 0) { - ids.push_back(USER_TOKEN_ID); - } else { - ids.push_back(ASSISTANT_TOKEN_ID); - } - std::vector content_ids = encode(history[i], max_length); + for (const auto &msg : messages) { + ids.push_back((msg.role == ChatMessage::ROLE_USER) ? USER_TOKEN_ID : ASSISTANT_TOKEN_ID); + std::vector content_ids = encode(msg.content, max_length); ids.insert(ids.end(), content_ids.begin(), content_ids.end()); } ids.push_back(ASSISTANT_TOKEN_ID); @@ -1242,20 +1393,21 @@ std::string InternLMTokenizer::decode(const std::vector &ids) const { return text; } -std::vector InternLMTokenizer::encode_history(const std::vector &history, int max_length) const { - std::string prompt = build_prompt(history); +std::vector InternLMTokenizer::encode_messages(const std::vector &messages, int max_length) const { + std::string prompt = build_prompt(messages); std::vector input_ids = encode(prompt, max_length); return input_ids; } -std::string InternLMTokenizer::build_prompt(const std::vector &history) { - CHATGLM_CHECK(history.size() % 2 == 1) << "invalid history size " << history.size(); +std::string InternLMTokenizer::build_prompt(const std::vector &messages) { + check_chat_messages(messages); std::ostringstream oss_prompt; - for (size_t i = 0; i < history.size(); i += 2) { - oss_prompt << "<|User|>:" << history[i] << "\n<|Bot|>:"; - if (i < history.size() - 1) { - oss_prompt << history[i + 1] << "\n"; + for (const auto &msg : messages) { + if (msg.role == ChatMessage::ROLE_USER) { + oss_prompt << "<|User|>:" << msg.content << "\n<|Bot|>:"; + } else { + oss_prompt << msg.content << "\n"; } } return oss_prompt.str(); @@ -1324,7 +1476,7 @@ Pipeline::Pipeline(const std::string &path) { ModelType model_type = (ModelType)loader.read_basic(); // load version int version = loader.read_basic(); - if (model_type == MODEL_TYPE_CHATGLM) { + if (model_type == ModelType::CHATGLM) { CHATGLM_CHECK(version == 1) << "only support version 1 for now but got " << version; // load config @@ -1339,7 +1491,7 @@ Pipeline::Pipeline(const std::string &path) { // load model model = std::make_unique(config); model->load(loader); - } else if (model_type == MODEL_TYPE_CHATGLM2 || model_type == MODEL_TYPE_CHATGLM3) { + } else if (model_type == ModelType::CHATGLM2 || model_type == ModelType::CHATGLM3) { CHATGLM_CHECK(version == 1) << "only support version 1 for now but got " << version; // load config @@ -1350,17 +1502,19 @@ Pipeline::Pipeline(const std::string &path) { std::string_view serialized_model_proto((char *)mapped_file->data + loader.tell(), proto_size); loader.seek(proto_size, SEEK_CUR); - if (model_type == MODEL_TYPE_CHATGLM2) { + if (model_type == ModelType::CHATGLM2) { tokenizer = std::make_unique(serialized_model_proto); model = std::make_unique(config); } else { 
- tokenizer = std::make_unique(serialized_model_proto); + auto chatglm3_tokenizer = std::make_unique(serialized_model_proto); + config.extra_eos_token_ids = {chatglm3_tokenizer->observation_token_id, chatglm3_tokenizer->user_token_id}; + tokenizer = std::move(chatglm3_tokenizer); model = std::make_unique(config); } // load model model->load(loader); - } else if (model_type == MODEL_TYPE_BAICHUAN7B) { + } else if (model_type == ModelType::BAICHUAN7B) { CHATGLM_CHECK(version == 1) << "only support version 1 for now but got " << version; // load config @@ -1376,7 +1530,7 @@ Pipeline::Pipeline(const std::string &path) { // load model model = std::make_unique(config); model->load(loader); - } else if (model_type == MODEL_TYPE_BAICHUAN13B) { + } else if (model_type == ModelType::BAICHUAN13B) { CHATGLM_CHECK(version == 1) << "only support version 1 for now but got " << version; // load config @@ -1392,7 +1546,7 @@ Pipeline::Pipeline(const std::string &path) { // load model model = std::make_unique(config); model->load(loader); - } else if (model_type == MODEL_TYPE_INTERNLM) { + } else if (model_type == ModelType::INTERNLM) { CHATGLM_CHECK(version == 1) << "only support version 1 for now but got " << version; // load config @@ -1413,7 +1567,7 @@ Pipeline::Pipeline(const std::string &path) { } model->load(loader); } else { - CHATGLM_THROW << "invalid model type " << model_type; + CHATGLM_THROW << "invalid model type " << (int)model_type; } } @@ -1432,11 +1586,11 @@ std::string Pipeline::generate(const std::string &prompt, const GenerationConfig return output; } -std::string Pipeline::chat(const std::vector &history, const GenerationConfig &gen_config, +ChatMessage Pipeline::chat(const std::vector &messages, const GenerationConfig &gen_config, BaseStreamer *streamer) const { - std::vector input_ids = tokenizer->encode_history(history, gen_config.max_context_length); + std::vector input_ids = tokenizer->encode_messages(messages, gen_config.max_context_length); std::vector new_output_ids = generate(input_ids, gen_config, streamer); - std::string output = tokenizer->decode(new_output_ids); + ChatMessage output = tokenizer->decode_message(new_output_ids); return output; } diff --git a/chatglm.h b/chatglm.h index cf3562b2..2394e159 100644 --- a/chatglm.h +++ b/chatglm.h @@ -2,10 +2,10 @@ #include #include +#include #include #include #include -#include #ifdef GGML_USE_METAL #include @@ -46,13 +46,13 @@ ggml_tensor *tensor_to_device(ggml_tensor *tensor); ggml_tensor *tensor_to_cpu(ggml_tensor *tensor); -enum ModelType { - MODEL_TYPE_CHATGLM = 1, - MODEL_TYPE_CHATGLM2 = 2, - MODEL_TYPE_CHATGLM3 = 3, - MODEL_TYPE_BAICHUAN7B = 1024, - MODEL_TYPE_BAICHUAN13B = 1025, - MODEL_TYPE_INTERNLM = 1280, +enum class ModelType { + CHATGLM = 1, + CHATGLM2 = 2, + CHATGLM3 = 3, + BAICHUAN7B = 1024, + BAICHUAN13B = 1025, + INTERNLM = 1280, }; std::string to_string(ModelType model_type); @@ -87,21 +87,23 @@ class ModelConfig { ModelConfig(ModelType model_type, ggml_type dtype, int vocab_size, int hidden_size, int num_attention_heads, int num_kv_heads, int num_hidden_layers, int intermediate_size, float norm_eps, int max_length, - int bos_token_id, int eos_token_id, int pad_token_id, int sep_token_id) + int bos_token_id, int eos_token_id, int pad_token_id, int sep_token_id, + std::vector extra_eos_token_ids) : model_type(model_type), dtype(dtype), vocab_size(vocab_size), hidden_size(hidden_size), num_attention_heads(num_attention_heads), num_kv_heads(num_kv_heads), num_hidden_layers(num_hidden_layers), 
intermediate_size(intermediate_size), norm_eps(norm_eps), max_length(max_length), bos_token_id(bos_token_id), - eos_token_id(eos_token_id), pad_token_id(pad_token_id), sep_token_id(sep_token_id) {} + eos_token_id(eos_token_id), pad_token_id(pad_token_id), sep_token_id(sep_token_id), + extra_eos_token_ids(std::move(extra_eos_token_ids)) {} ModelConfig(ModelType model_type, const ConfigRecordV1 &rec) : ModelConfig(model_type, rec.dtype, rec.vocab_size, rec.hidden_size, rec.num_attention_heads, rec.num_attention_heads, rec.num_hidden_layers, rec.intermediate_size, 1e-5, rec.max_length, - rec.bos_token_id, rec.eos_token_id, rec.pad_token_id, rec.sep_token_id) {} + rec.bos_token_id, rec.eos_token_id, rec.pad_token_id, rec.sep_token_id, {}) {} ModelConfig(ModelType model_type, const ConfigRecordV2 &rec) : ModelConfig(model_type, rec.dtype, rec.vocab_size, rec.hidden_size, rec.num_attention_heads, rec.num_kv_heads, rec.num_hidden_layers, rec.intermediate_size, 1e-5, rec.max_length, rec.bos_token_id, - rec.eos_token_id, rec.pad_token_id, rec.sep_token_id) {} + rec.eos_token_id, rec.pad_token_id, rec.sep_token_id, {}) {} std::string model_type_name() const { return to_string(model_type); } @@ -120,14 +122,91 @@ class ModelConfig { int eos_token_id; int pad_token_id; int sep_token_id; + std::vector extra_eos_token_ids; +}; + +struct FunctionMessage { + std::string name; + std::string arguments; + + FunctionMessage() = default; + FunctionMessage(std::string name, std::string arguments) : name(std::move(name)), arguments(std::move(arguments)) {} + + friend std::ostream &operator<<(std::ostream &os, const FunctionMessage &self) { + return os << "FunctionMessage(name=" << std::quoted(self.name) << ", arguments=" << std::quoted(self.arguments) + << ")"; + } +}; + +struct CodeMessage { + std::string input; + + CodeMessage() = default; + CodeMessage(std::string input) : input(std::move(input)) {} + + friend std::ostream &operator<<(std::ostream &os, const CodeMessage &self) { + return os << "CodeMessage(input=" << std::quoted(self.input) << ")"; + } +}; + +struct ToolCallMessage { + std::string type; + FunctionMessage function; + CodeMessage code; + + static const std::string TYPE_FUNCTION; + static const std::string TYPE_CODE; + + ToolCallMessage(FunctionMessage function) : type(TYPE_FUNCTION), function(std::move(function)) {} + + ToolCallMessage(CodeMessage code) : type(TYPE_CODE), code(std::move(code)) {} + + friend std::ostream &operator<<(std::ostream &os, const ToolCallMessage &self) { + return os << "ToolCallMessage(type=" << std::quoted(self.type) << ", function=" << self.function + << ", code=" << self.code << ")"; + } +}; + +struct ChatMessage { + std::string role; + std::string content; + std::vector tool_calls; + + static const std::string ROLE_USER; + static const std::string ROLE_ASSISTANT; + static const std::string ROLE_SYSTEM; + static const std::string ROLE_OBSERVATION; + + ChatMessage() = default; + ChatMessage(std::string role, std::string content, std::vector tool_calls = {}) + : role(std::move(role)), content(std::move(content)), tool_calls(std::move(tool_calls)) {} + + friend std::ostream &operator<<(std::ostream &os, const ChatMessage &self) { + os << "ChatMessage(role=" << std::quoted(self.role) << ", content=" << std::quoted(self.content) + << ", tool_calls=["; + for (size_t i = 0; i < self.tool_calls.size(); i++) { + os << (i > 0 ? 
", " : "") << self.tool_calls[i]; + } + return os << "])"; + } }; class BaseTokenizer { public: virtual ~BaseTokenizer() = default; + virtual std::vector encode(const std::string &text, int max_length) const = 0; + virtual std::string decode(const std::vector &ids) const = 0; - virtual std::vector encode_history(const std::vector &history, int max_length) const = 0; + + virtual std::vector encode_messages(const std::vector &messages, int max_length) const = 0; + + virtual ChatMessage decode_message(const std::vector &ids) const { + return {ChatMessage::ROLE_ASSISTANT, decode(ids)}; + } + + protected: + static void check_chat_messages(const std::vector &messages); }; struct ggml_context_deleter_t { @@ -237,20 +316,20 @@ class RMSNorm { float eps; }; -enum ActivationType { - ACT_TYPE_GELU, - ACT_TYPE_SILU, +enum class ActivationType { + GELU, + SILU, }; template static inline ggml_tensor *apply_activation_inplace(ggml_context *ctx, ggml_tensor *hidden_states) { - static_assert(ACT_TYPE == ACT_TYPE_GELU || ACT_TYPE == ACT_TYPE_SILU); - if constexpr (ACT_TYPE == ACT_TYPE_GELU) { + static_assert(ACT_TYPE == ActivationType::GELU || ACT_TYPE == ActivationType::SILU); + if constexpr (ACT_TYPE == ActivationType::GELU) { hidden_states = tensor_assign_buffers(ggml_gelu_inplace(ctx, hidden_states)); - } else if constexpr (ACT_TYPE == ACT_TYPE_SILU) { + } else if constexpr (ACT_TYPE == ActivationType::SILU) { hidden_states = tensor_assign_buffers(ggml_silu_inplace(ctx, hidden_states)); } else { - CHATGLM_THROW << "Unknown activation type " << ACT_TYPE; + CHATGLM_THROW << "Unknown activation type " << (int)ACT_TYPE; } return hidden_states; } @@ -650,7 +729,7 @@ class StreamerGroup : public BaseStreamer { class TextStreamer : public BaseStreamer { public: TextStreamer(std::ostream &os, BaseTokenizer *tokenizer) - : os_(os), tokenizer_(tokenizer), is_prompt_(true), print_len_(0) {} + : os_(os), tokenizer_(tokenizer), is_prompt_(true), is_first_line_(true), print_len_(0) {} void put(const std::vector &output_ids) override; void end() override; @@ -658,6 +737,7 @@ class TextStreamer : public BaseStreamer { std::ostream &os_; BaseTokenizer *tokenizer_; bool is_prompt_; + bool is_first_line_; std::vector token_cache_; int print_len_; }; @@ -870,9 +950,9 @@ class ChatGLMTokenizer : public BaseTokenizer { std::string decode(const std::vector &ids) const override; - std::vector encode_history(const std::vector &history, int max_length) const override; + std::vector encode_messages(const std::vector &messages, int max_length) const override; - static std::string build_prompt(const std::vector &history); + static std::string build_prompt(const std::vector &messages); private: static std::string preprocess(const std::string &text); @@ -894,7 +974,7 @@ struct GLMContextMasker { using GLMAttention = BasicAttention; -using GLMMLP = BasicMLP; +using GLMMLP = BasicMLP; // NOTE: disable inplace norm since it causes nonsense on cuda when sequence length >= 144 class GLMBlock : public BasicBlock { @@ -942,10 +1022,11 @@ class ChatGLM2Tokenizer : public BaseTokenizer { std::string decode(const std::vector &ids) const override; - std::vector encode_history(const std::vector &history, int max_length) const override; + std::vector encode_messages(const std::vector &messages, int max_length) const override; - static std::string build_prompt(const std::vector &history); + static std::string build_prompt(const std::vector &messages); + private: bool is_special_id(int id) const; public: @@ -959,7 +1040,7 @@ class ChatGLM2Tokenizer 
: public BaseTokenizer { using GLM2Attention = BasicAttention, false, CausalContextMasker>; -using GLM2MLP = BasicGLU; +using GLM2MLP = BasicGLU; using GLM2Block = BasicBlock; @@ -991,11 +1072,21 @@ class ChatGLM3Tokenizer : public BaseTokenizer { std::string decode(const std::vector &ids) const override; - std::vector encode_history(const std::vector &history, int max_length) const override; + std::vector encode_messages(const std::vector &messages, int max_length) const override; + + ChatMessage decode_message(const std::vector &ids) const override; + + private: + std::vector encode_single_message(const std::string &role, const std::string &content) const; + + std::string decode_with_special_tokens(const std::vector &ids) const; + + static std::string remove_special_tokens(const std::string &text); + + int get_command(const std::string &token) const; bool is_special_id(int id) const; - protected: static void truncate(std::vector &ids, int max_length); public: @@ -1009,6 +1100,8 @@ class ChatGLM3Tokenizer : public BaseTokenizer { int user_token_id; int assistant_token_id; int observation_token_id; + std::unordered_map special_tokens; + std::unordered_map index_special_tokens; }; using ChatGLM3Model = ChatGLM2Model; @@ -1025,11 +1118,11 @@ class BaichuanTokenizer : public BaseTokenizer { std::string decode(const std::vector &ids) const override; - std::vector encode_history(const std::vector &history, int max_length) const override; + std::vector encode_messages(const std::vector &messages, int max_length) const override; + private: bool is_special_id(int id) const; - protected: static void truncate(std::vector &ids, int max_length); public: @@ -1047,7 +1140,7 @@ class BaichuanTokenizer : public BaseTokenizer { using Baichuan7BAttention = BasicAttention, false, CausalContextMasker>; -using Baichuan7BMLP = BasicGLU; +using Baichuan7BMLP = BasicGLU; using Baichuan7BBlock = BasicBlock; @@ -1073,7 +1166,7 @@ class Baichuan7BForCausalLM : public BasicModelForCausalLM { using Baichuan13BAttention = BasicAttention; -using Baichuan13BMLP = BasicGLU; +using Baichuan13BMLP = BasicGLU; using Baichuan13BBlock = BasicBlock; @@ -1105,10 +1198,11 @@ class InternLMTokenizer : public BaseTokenizer { std::string decode(const std::vector &ids) const override; - std::vector encode_history(const std::vector &history, int max_length) const override; + std::vector encode_messages(const std::vector &messages, int max_length) const override; - static std::string build_prompt(const std::vector &history); + static std::string build_prompt(const std::vector &messages); + private: bool is_special_id(int id) const { return id == unk_token_id || id == bos_token_id || id == eos_token_id; } public: @@ -1121,7 +1215,7 @@ class InternLMTokenizer : public BaseTokenizer { using InternLM7BAttention = BasicAttention, false, CausalContextMasker>; -using InternLM7BMLP = BasicGLU; +using InternLM7BMLP = BasicGLU; using InternLM7BBlock = BasicBlock; @@ -1130,7 +1224,7 @@ using InternLM7BModel = BasicModel, false, CausalContextMasker>; -using InternLM20BMLP = BasicGLU; +using InternLM20BMLP = BasicGLU; using InternLM20BBlock = BasicBlock; @@ -1171,7 +1265,7 @@ class Pipeline { std::string generate(const std::string &prompt, const GenerationConfig &gen_config, BaseStreamer *streamer = nullptr) const; - std::string chat(const std::vector &history, const GenerationConfig &gen_config, + ChatMessage chat(const std::vector &messages, const GenerationConfig &gen_config, BaseStreamer *streamer = nullptr) const; public: diff --git 
a/chatglm_cpp/_C.pyi b/chatglm_cpp/_C.pyi new file mode 100644 index 00000000..20a5df80 --- /dev/null +++ b/chatglm_cpp/_C.pyi @@ -0,0 +1,193 @@ +""" +ChatGLM.cpp python binding +""" +from __future__ import annotations +import typing +__all__ = ['Baichuan13BForCausalLM', 'Baichuan7BForCausalLM', 'BaichuanTokenizer', 'BaseModelForCausalLM', 'BaseTokenizer', 'ChatGLM2ForCausalLM', 'ChatGLM2Tokenizer', 'ChatGLM3Tokenizer', 'ChatGLMForCausalLM', 'ChatGLMTokenizer', 'ChatMessage', 'CodeMessage', 'FunctionMessage', 'GenerationConfig', 'InternLM20BForCausalLM', 'InternLM7BForCausalLM', 'InternLMTokenizer', 'ModelConfig', 'ModelType', 'Pipeline', 'ToolCallMessage'] +class Baichuan13BForCausalLM(BaseModelForCausalLM): + pass +class Baichuan7BForCausalLM(BaseModelForCausalLM): + pass +class BaichuanTokenizer(BaseTokenizer): + pass +class BaseModelForCausalLM: + def generate_next_token(self, input_ids: list[int], gen_config: GenerationConfig, n_past: int, n_ctx: int) -> int: + ... + @property + def config(self) -> ModelConfig: + ... +class BaseTokenizer: + def decode(self, ids: list[int]) -> str: + ... + def decode_message(self, ids: list[int]) -> ChatMessage: + ... + def encode(self, text: str, max_length: int) -> list[int]: + ... + def encode_messages(self, messages: list[ChatMessage], max_length: int) -> list[int]: + ... +class ChatGLM2ForCausalLM(BaseModelForCausalLM): + pass +class ChatGLM2Tokenizer(BaseTokenizer): + pass +class ChatGLM3Tokenizer(BaseTokenizer): + pass +class ChatGLMForCausalLM(BaseModelForCausalLM): + pass +class ChatGLMTokenizer(BaseTokenizer): + pass +class ChatMessage: + ROLE_ASSISTANT: typing.ClassVar[str] = 'assistant' + ROLE_OBSERVATION: typing.ClassVar[str] = 'observation' + ROLE_SYSTEM: typing.ClassVar[str] = 'system' + ROLE_USER: typing.ClassVar[str] = 'user' + content: str + role: str + tool_calls: list[ToolCallMessage] + def __init__(self, role: str, content: str, tool_calls: list[ToolCallMessage] = []) -> None: + ... + def __repr__(self) -> str: + ... + def __str__(self) -> str: + ... +class CodeMessage: + input: str + def __repr__(self) -> str: + ... + def __str__(self) -> str: + ... +class FunctionMessage: + arguments: str + name: str + def __repr__(self) -> str: + ... + def __str__(self) -> str: + ... +class GenerationConfig: + do_sample: bool + max_context_length: int + max_length: int + num_threads: int + repetition_penalty: float + temperature: float + top_k: int + top_p: float + def __init__(self, max_length: int = 2048, max_context_length: int = 512, do_sample: bool = True, top_k: int = 0, top_p: float = 0.7, temperature: float = 0.95, repetition_penalty: float = 1.0, num_threads: int = 0) -> None: + ... +class InternLM20BForCausalLM(BaseModelForCausalLM): + pass +class InternLM7BForCausalLM(BaseModelForCausalLM): + pass +class InternLMTokenizer(BaseTokenizer): + pass +class ModelConfig: + @property + def bos_token_id(self) -> int: + ... + @property + def eos_token_id(self) -> int: + ... + @property + def extra_eos_token_ids(self) -> list[int]: + ... + @property + def hidden_size(self) -> int: + ... + @property + def intermediate_size(self) -> int: + ... + @property + def max_length(self) -> int: + ... + @property + def model_type(self) -> ModelType: + ... + @property + def model_type_name(self) -> str: + ... + @property + def norm_eps(self) -> float: + ... + @property + def num_attention_heads(self) -> int: + ... + @property + def num_hidden_layers(self) -> int: + ... + @property + def num_kv_heads(self) -> int: + ... 
+ @property + def pad_token_id(self) -> int: + ... + @property + def sep_token_id(self) -> int: + ... + @property + def vocab_size(self) -> int: + ... +class ModelType: + """ + Members: + + CHATGLM + + CHATGLM2 + + CHATGLM3 + + BAICHUAN7B + + BAICHUAN13B + + INTERNLM + """ + BAICHUAN13B: typing.ClassVar[ModelType] # value = + BAICHUAN7B: typing.ClassVar[ModelType] # value = + CHATGLM: typing.ClassVar[ModelType] # value = + CHATGLM2: typing.ClassVar[ModelType] # value = + CHATGLM3: typing.ClassVar[ModelType] # value = + INTERNLM: typing.ClassVar[ModelType] # value = + __members__: typing.ClassVar[dict[str, ModelType]] # value = {'CHATGLM': , 'CHATGLM2': , 'CHATGLM3': , 'BAICHUAN7B': , 'BAICHUAN13B': , 'INTERNLM': } + def __eq__(self, other: typing.Any) -> bool: + ... + def __getstate__(self) -> int: + ... + def __hash__(self) -> int: + ... + def __index__(self) -> int: + ... + def __init__(self, value: int) -> None: + ... + def __int__(self) -> int: + ... + def __ne__(self, other: typing.Any) -> bool: + ... + def __repr__(self) -> str: + ... + def __setstate__(self, state: int) -> None: + ... + def __str__(self) -> str: + ... + @property + def name(self) -> str: + ... + @property + def value(self) -> int: + ... +class Pipeline: + def __init__(self, path: str) -> None: + ... + @property + def model(self) -> BaseModelForCausalLM: + ... + @property + def tokenizer(self) -> BaseTokenizer: + ... +class ToolCallMessage: + code: CodeMessage + function: FunctionMessage + type: str + def __repr__(self) -> str: + ... + def __str__(self) -> str: + ... diff --git a/chatglm_cpp/__init__.py b/chatglm_cpp/__init__.py index a1dc1836..d3938bd0 100644 --- a/chatglm_cpp/__init__.py +++ b/chatglm_cpp/__init__.py @@ -1,11 +1,29 @@ import tempfile -import warnings +from dataclasses import dataclass from pathlib import Path -from typing import Iterator, List, Optional, Union +from typing import Any, Dict, Iterator, List, Optional, Union import chatglm_cpp._C as _C +from chatglm_cpp._C import ChatMessage -__version__ = "0.2.10" +__version__ = "0.3.0" + + +@dataclass +class DeltaMessage: + role: str + content: str + token_ids: List[int] + + +def _ensure_chat_message(message: Union[ChatMessage, Dict[str, Any]]) -> ChatMessage: + if isinstance(message, ChatMessage): + chat_message = message + elif isinstance(message, dict): + chat_message = ChatMessage(**message) + else: + raise TypeError(f"expect message type to be ChatMessage or dict, but got {type(message)}") + return chat_message class Pipeline(_C.Pipeline): @@ -26,7 +44,7 @@ def __init__(self, model_path: str, *, dtype: Optional[str] = None) -> None: def chat( self, - history: List[str], + messages: List[ChatMessage], *, max_length: int = 2048, max_context_length: int = 512, @@ -37,10 +55,10 @@ def chat( repetition_penalty: float = 1.0, num_threads: int = 0, stream: bool = False, - ) -> Union[Iterator[str], str]: - input_ids = self.tokenizer.encode_history(history, max_context_length) - return self._generate( - input_ids=input_ids, + ) -> Union[Iterator[DeltaMessage], ChatMessage]: + messages = [_ensure_chat_message(msg) for msg in messages] + input_ids = self.tokenizer.encode_messages(messages, max_context_length) + gen_config = _C.GenerationConfig( max_length=max_length, max_context_length=max_context_length, do_sample=do_sample, @@ -49,8 +67,10 @@ def chat( temperature=temperature, repetition_penalty=repetition_penalty, num_threads=num_threads, - stream=stream, ) + if stream: + return self._stream_chat(input_ids=input_ids, gen_config=gen_config) + return 
self._sync_chat(input_ids=input_ids, gen_config=gen_config) def generate( self, @@ -67,33 +87,6 @@ def generate( stream: bool = False, ) -> Union[Iterator[str], str]: input_ids = self.tokenizer.encode(prompt, max_context_length) - return self._generate( - input_ids=input_ids, - max_length=max_length, - max_context_length=max_context_length, - do_sample=do_sample, - top_k=top_k, - top_p=top_p, - temperature=temperature, - repetition_penalty=repetition_penalty, - num_threads=num_threads, - stream=stream, - ) - - def _generate( - self, - input_ids: List[int], - *, - max_length: int = 2048, - max_context_length: int = 512, - do_sample: bool = True, - top_k: int = 0, - top_p: float = 0.7, - temperature: float = 0.95, - repetition_penalty: float = 1.0, - num_threads: int = 0, - stream: bool = False, - ) -> Union[Iterator[str], str]: gen_config = _C.GenerationConfig( max_length=max_length, max_context_length=max_context_length, @@ -104,83 +97,68 @@ def _generate( repetition_penalty=repetition_penalty, num_threads=num_threads, ) + if stream: + return self._stream_generate(input_ids=input_ids, gen_config=gen_config) + return self._sync_generate(input_ids=input_ids, gen_config=gen_config) - generate_fn = self._stream_generate if stream else self._sync_generate - return generate_fn(input_ids=input_ids, gen_config=gen_config) - - def _stream_generate(self, input_ids: List[int], gen_config: _C.GenerationConfig) -> Iterator[str]: - input_ids = [x for x in input_ids] # make a copy + def _stream_generate_ids(self, input_ids: List[int], gen_config: _C.GenerationConfig) -> Iterator[int]: + input_ids = input_ids.copy() n_past = 0 n_ctx = len(input_ids) - token_cache = [] - print_len = 0 while len(input_ids) < gen_config.max_length: next_token_id = self.model.generate_next_token(input_ids, gen_config, n_past, n_ctx) + yield next_token_id n_past = len(input_ids) input_ids.append(next_token_id) + if next_token_id in [self.model.config.eos_token_id, *self.model.config.extra_eos_token_ids]: + break + + def _stream_chat(self, input_ids: List[int], gen_config: _C.GenerationConfig) -> Iterator[DeltaMessage]: + token_cache = [] + print_len = 0 + print_token_len = 0 + for next_token_id in self._stream_generate_ids(input_ids=input_ids, gen_config=gen_config): token_cache.append(next_token_id) output = self.tokenizer.decode(token_cache) if output.endswith("\n"): - yield output[print_len:] + yield DeltaMessage( + role=ChatMessage.ROLE_ASSISTANT, content=output[print_len:], token_ids=token_cache[print_token_len:] + ) token_cache = [] print_len = 0 + print_token_len = 0 elif output.endswith((",", "!", ":", ";", "?", "�")): pass else: - yield output[print_len:] + yield DeltaMessage( + role=ChatMessage.ROLE_ASSISTANT, content=output[print_len:], token_ids=token_cache[print_token_len:] + ) print_len = len(output) - - if next_token_id == self.model.config.eos_token_id: - break + print_token_len = len(token_cache) output = self.tokenizer.decode(token_cache) - yield output[print_len:] + yield DeltaMessage( + role=ChatMessage.ROLE_ASSISTANT, content=output[print_len:], token_ids=token_cache[print_token_len:] + ) - def _sync_generate(self, input_ids: List[int], gen_config: _C.GenerationConfig) -> str: - input_ids = [x for x in input_ids] # make a copy - n_past = 0 - n_ctx = len(input_ids) + def _stream_generate(self, input_ids: List[int], gen_config: _C.GenerationConfig) -> Iterator[str]: + for msg in self._stream_chat(input_ids=input_ids, gen_config=gen_config): + yield msg.content - while len(input_ids) < gen_config.max_length: 
- next_token_id = self.model.generate_next_token(input_ids, gen_config, n_past, n_ctx) - n_past = len(input_ids) - input_ids.append(next_token_id) - if next_token_id == self.model.config.eos_token_id: - break + def _sync_generate_ids(self, input_ids: List[int], gen_config: _C.GenerationConfig) -> List[int]: + return list(self._stream_generate_ids(input_ids=input_ids, gen_config=gen_config)) - output = self.tokenizer.decode(input_ids[n_ctx:]) - return output + def _sync_generate(self, input_ids: List[int], gen_config: _C.GenerationConfig) -> str: + output_ids = self._sync_generate_ids(input_ids=input_ids, gen_config=gen_config) + return self.tokenizer.decode(output_ids) - def stream_chat( - self, - history: List[str], - *, - max_length: int = 2048, - max_context_length: int = 512, - do_sample: bool = True, - top_k: int = 0, - top_p: float = 0.7, - temperature: float = 0.95, - repetition_penalty: float = 1.0, - num_threads: int = 0, - ) -> Iterator[str]: - warnings.warn( - "stream_chat is deprecated in favor of chat(..., stream=True), and will be removed in the next major version of chatglm-cpp", - DeprecationWarning, - stacklevel=2, - ) - return self.chat( - history=history, - max_length=max_length, - max_context_length=max_context_length, - do_sample=do_sample, - top_k=top_k, - top_p=top_p, - temperature=temperature, - repetition_penalty=repetition_penalty, - num_threads=num_threads, - stream=True, - ) + def _sync_chat(self, input_ids: List[int], gen_config: _C.GenerationConfig) -> ChatMessage: + output_ids = self._sync_generate_ids(input_ids=input_ids, gen_config=gen_config) + return self.tokenizer.decode_message(output_ids) + + def merge_streaming_messages(self, chunks: List[DeltaMessage]) -> ChatMessage: + output_ids = [x for chunk in chunks for x in chunk.token_ids] + return self.tokenizer.decode_message(output_ids) diff --git a/chatglm_cpp/convert.py b/chatglm_cpp/convert.py index 53d5f4c7..32275b98 100644 --- a/chatglm_cpp/convert.py +++ b/chatglm_cpp/convert.py @@ -25,7 +25,7 @@ if platform.system() == "Darwin": # cpm_kernels doesn't support macOS but transformers will check missing packages, so mock it - sys.modules["cpm_kernels"] = object() + sys.modules["cpm_kernels"] = object() # type: ignore class GGMLType(Enum): @@ -47,7 +47,7 @@ class ModelType(Enum): INTERNLM = 1280 -def quantize_q8_0(tensor: torch.Tensor) -> torch.CharTensor: +def quantize_q8_0(tensor: torch.Tensor) -> torch.Tensor: # equivalent to ggml_quantize_q8_0 in ggml.c assert tensor.shape[1] % GGML_QK8_0 == 0 tensor = tensor.view(-1, GGML_QK8_0) @@ -58,7 +58,7 @@ def quantize_q8_0(tensor: torch.Tensor) -> torch.CharTensor: return tensor -def quantize_q4_0(tensor: torch.Tensor) -> torch.CharTensor: +def quantize_q4_0(tensor: torch.Tensor) -> torch.Tensor: # equivalent to ggml_quantize_q4_0 in ggml.c assert tensor.shape[1] % GGML_QK4_0 == 0 tensor = tensor.view(-1, GGML_QK4_0) @@ -73,7 +73,7 @@ def quantize_q4_0(tensor: torch.Tensor) -> torch.CharTensor: return tensor -def quantize_q4_1(tensor: torch.Tensor) -> torch.CharTensor: +def quantize_q4_1(tensor: torch.Tensor) -> torch.Tensor: # equivalent to ggml_quantize_q4_1 in ggml.c assert tensor.shape[1] % GGML_QK4_1 == 0 tensor = tensor.view(-1, GGML_QK4_1) @@ -88,7 +88,7 @@ def quantize_q4_1(tensor: torch.Tensor) -> torch.CharTensor: return tensor -def quantize_q5_0(tensor: torch.Tensor) -> torch.CharTensor: +def quantize_q5_0(tensor: torch.Tensor) -> torch.Tensor: # equivalent to ggml_quantize_q5_0 in ggml.c assert tensor.shape[1] % GGML_QK5_0 == 0 tensor = 
tensor.view(-1, GGML_QK5_0) @@ -106,7 +106,7 @@ def quantize_q5_0(tensor: torch.Tensor) -> torch.CharTensor: return tensor -def quantize_q5_1(tensor: torch.Tensor) -> torch.CharTensor: +def quantize_q5_1(tensor: torch.Tensor) -> torch.Tensor: # equivalent to ggml_quantize_q5_1 in ggml.c assert tensor.shape[1] % GGML_QK5_1 == 0 tensor = tensor.view(-1, GGML_QK5_1) diff --git a/chatglm_cpp/langchain_api.py b/chatglm_cpp/langchain_api.py index 5cb19115..ceea5988 100644 --- a/chatglm_cpp/langchain_api.py +++ b/chatglm_cpp/langchain_api.py @@ -53,17 +53,27 @@ class ChatResponse(BaseModel): @app.post("/") async def chat(body: ChatRequest) -> ChatResponse: - chat_history = [msg for pair in body.history for msg in pair] + [body.prompt] - response = pipeline.chat( - chat_history, + messages = [] + for prompt, response in body.history: + messages += [ + chatglm_cpp.ChatMessage(role="user", content=prompt), + chatglm_cpp.ChatMessage(role="assistant", content=response), + ] + messages.append(chatglm_cpp.ChatMessage(role="user", content=body.prompt)) + + output = pipeline.chat( + messages, max_length=body.max_length, do_sample=body.temperature > 0, top_p=body.top_p, temperature=body.temperature, ) - history = body.history + [(body.prompt, response)] + history = body.history + [(body.prompt, output.content)] answer = ChatResponse( - response=response, history=history, status=status.HTTP_200_OK, time=datetime.now().strftime("%Y-%m-%d %H:%M:%S") + response=output.content, + history=history, + status=status.HTTP_200_OK, + time=datetime.now().strftime("%Y-%m-%d %H:%M:%S"), ) - logging.info(f'prompt: "{body.prompt}", response: "{response}"') + logging.info(f'prompt: "{body.prompt}", response: "{output.content}"') return answer diff --git a/chatglm_cpp/openai_api.py b/chatglm_cpp/openai_api.py index fbda7e95..20976899 100644 --- a/chatglm_cpp/openai_api.py +++ b/chatglm_cpp/openai_api.py @@ -102,14 +102,14 @@ class ChatCompletionResponse(BaseModel): lock = asyncio.Lock() -def stream_chat(history, body): +def stream_chat(messages, body): yield ChatCompletionResponse( object="chat.completion.chunk", choices=[ChatCompletionResponseStreamChoice(delta=DeltaMessage(role="assistant"))], ) - for piece in pipeline.chat( - history, + for chunk in pipeline.chat( + messages=messages, max_length=body.max_tokens, do_sample=body.temperature > 0, top_p=body.top_p, @@ -119,7 +119,7 @@ def stream_chat(history, body): ): yield ChatCompletionResponse( object="chat.completion.chunk", - choices=[ChatCompletionResponseStreamChoice(delta=DeltaMessage(content=piece))], + choices=[ChatCompletionResponseStreamChoice(delta=DeltaMessage(content=chunk.content))], ) yield ChatCompletionResponse( @@ -144,31 +144,31 @@ async def stream_chat_event_publisher(history, body): @app.post("/v1/chat/completions") async def create_chat_completion(body: ChatCompletionRequest) -> ChatCompletionResponse: - # ignore system messages - history = [msg.content for msg in body.messages if msg.role != "system"] - if len(history) % 2 != 1: - raise HTTPException(status.HTTP_400_BAD_REQUEST, "invalid history size") + if not body.messages: + raise HTTPException(status.HTTP_400_BAD_REQUEST, "empty messages") + + messages = [chatglm_cpp.ChatMessage(role=msg.role, content=msg.content) for msg in body.messages] if body.stream: - generator = stream_chat_event_publisher(history, body) + generator = stream_chat_event_publisher(messages, body) return EventSourceResponse(generator) max_context_length = 512 output = pipeline.chat( - history=history, + messages=messages, 
max_length=body.max_tokens, max_context_length=max_context_length, do_sample=body.temperature > 0, top_p=body.top_p, temperature=body.temperature, ) - logging.info(f'prompt: "{history[-1]}", sync response: "{output}"') - prompt_tokens = len(pipeline.tokenizer.encode_history(history, max_context_length)) - completion_tokens = len(pipeline.tokenizer.encode(output, body.max_tokens)) + logging.info(f'prompt: "{messages[-1].content}", sync response: "{output.content}"') + prompt_tokens = len(pipeline.tokenizer.encode_messages(messages, max_context_length)) + completion_tokens = len(pipeline.tokenizer.encode(output.content, body.max_tokens)) return ChatCompletionResponse( object="chat.completion", - choices=[ChatCompletionResponseChoice(message=ChatMessage(role="assistant", content=output))], + choices=[ChatCompletionResponseChoice(message=ChatMessage(role="assistant", content=output.content))], usage=ChatCompletionUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens), ) diff --git a/chatglm_pybind.cpp b/chatglm_pybind.cpp index 2fcd7c7d..24749b8d 100644 --- a/chatglm_pybind.cpp +++ b/chatglm_pybind.cpp @@ -17,8 +17,8 @@ class PyBaseTokenizer : public BaseTokenizer { std::string decode(const std::vector &ids) const override { PYBIND11_OVERLOAD_PURE(std::string, BaseTokenizer, decode, ids); } - std::vector encode_history(const std::vector &history, int max_length) const override { - PYBIND11_OVERLOAD_PURE(std::vector, BaseTokenizer, encode_history, history, max_length); + std::vector encode_messages(const std::vector &history, int max_length) const override { + PYBIND11_OVERLOAD_PURE(std::vector, BaseTokenizer, encode_messages, history, max_length); } }; @@ -32,12 +32,27 @@ class PyBaseModelForCausalLM : public BaseModelForCausalLM { } }; +template +static inline std::string to_string(const T &obj) { + std::ostringstream oss; + oss << obj; + return oss.str(); +} + PYBIND11_MODULE(_C, m) { m.doc() = "ChatGLM.cpp python binding"; + py::enum_(m, "ModelType") + .value("CHATGLM", ModelType::CHATGLM) + .value("CHATGLM2", ModelType::CHATGLM2) + .value("CHATGLM3", ModelType::CHATGLM3) + .value("BAICHUAN7B", ModelType::BAICHUAN7B) + .value("BAICHUAN13B", ModelType::BAICHUAN13B) + .value("INTERNLM", ModelType::INTERNLM); + py::class_(m, "ModelConfig") .def_readonly("model_type", &ModelConfig::model_type) - .def_readonly("dtype", &ModelConfig::dtype) + // .def_readonly("dtype", &ModelConfig::dtype) .def_readonly("vocab_size", &ModelConfig::vocab_size) .def_readonly("hidden_size", &ModelConfig::hidden_size) .def_readonly("num_attention_heads", &ModelConfig::num_attention_heads) @@ -50,17 +65,9 @@ PYBIND11_MODULE(_C, m) { .def_readonly("eos_token_id", &ModelConfig::eos_token_id) .def_readonly("pad_token_id", &ModelConfig::pad_token_id) .def_readonly("sep_token_id", &ModelConfig::sep_token_id) + .def_readonly("extra_eos_token_ids", &ModelConfig::extra_eos_token_ids) .def_property_readonly("model_type_name", &ModelConfig::model_type_name); - py::class_(m, "BaseTokenizer") - .def("encode", &BaseTokenizer::encode) - .def("decode", &BaseTokenizer::decode) - .def("encode_history", &BaseTokenizer::encode_history); - - py::class_(m, "BaseModelForCausalLM") - .def("generate_next_token", &BaseModelForCausalLM::generate_next_token) - .def_readonly("config", &BaseModelForCausalLM::config); - py::class_(m, "GenerationConfig") .def(py::init(), "max_length"_a = 2048, "max_context_length"_a = 512, "do_sample"_a = true, "top_k"_a = 0, "top_p"_a = 0.7, "temperature"_a = 0.95, @@ -74,6 +81,48 @@ 
PYBIND11_MODULE(_C, m) { .def_readwrite("repetition_penalty", &GenerationConfig::repetition_penalty) .def_readwrite("num_threads", &GenerationConfig::num_threads); + py::class_(m, "FunctionMessage") + .def("__repr__", &to_string) + .def("__str__", &to_string) + .def_readwrite("name", &FunctionMessage::name) + .def_readwrite("arguments", &FunctionMessage::arguments); + + py::class_(m, "CodeMessage") + .def("__repr__", &to_string) + .def("__str__", &to_string) + .def_readwrite("input", &CodeMessage::input); + + py::class_(m, "ToolCallMessage") + .def("__repr__", &to_string) + .def("__str__", &to_string) + .def_readwrite("type", &ToolCallMessage::type) + .def_readwrite("function", &ToolCallMessage::function) + .def_readwrite("code", &ToolCallMessage::code); + + py::class_(m, "ChatMessage") + .def(py::init>(), "role"_a, "content"_a, + "tool_calls"_a = std::vector{}) + .def("__repr__", &to_string) + .def("__str__", &to_string) + .def_readonly_static("ROLE_SYSTEM", &ChatMessage::ROLE_SYSTEM) + .def_readonly_static("ROLE_USER", &ChatMessage::ROLE_USER) + .def_readonly_static("ROLE_ASSISTANT", &ChatMessage::ROLE_ASSISTANT) + .def_readonly_static("ROLE_OBSERVATION", &ChatMessage::ROLE_OBSERVATION) + .def_readwrite("role", &ChatMessage::role) + .def_readwrite("content", &ChatMessage::content) + .def_readwrite("tool_calls", &ChatMessage::tool_calls); + + py::class_(m, "BaseTokenizer") + .def("encode", &BaseTokenizer::encode, "text"_a, "max_length"_a) + .def("decode", &BaseTokenizer::decode, "ids"_a) + .def("encode_messages", &BaseTokenizer::encode_messages, "messages"_a, "max_length"_a) + .def("decode_message", &BaseTokenizer::decode_message, "ids"_a); + + py::class_(m, "BaseModelForCausalLM") + .def("generate_next_token", &BaseModelForCausalLM::generate_next_token, "input_ids"_a, "gen_config"_a, + "n_past"_a, "n_ctx"_a) + .def_readonly("config", &BaseModelForCausalLM::config); + // ===== ChatGLM ===== py::class_(m, "ChatGLMTokenizer"); @@ -109,7 +158,7 @@ PYBIND11_MODULE(_C, m) { // ===== Pipeline ==== py::class_(m, "Pipeline") - .def(py::init()) + .def(py::init(), "path"_a) .def_property_readonly("model", [](const Pipeline &self) { return self.model.get(); }) .def_property_readonly("tokenizer", [](const Pipeline &self) { return self.tokenizer.get(); }); } diff --git a/chatglm_test.cpp b/chatglm_test.cpp index 6df30289..03b8cd5d 100644 --- a/chatglm_test.cpp +++ b/chatglm_test.cpp @@ -1022,12 +1022,15 @@ TEST(Pipeline, ChatGLM) { // prompter { - EXPECT_EQ(ChatGLMTokenizer::build_prompt({"你好"}), "你好"); - EXPECT_EQ(ChatGLMTokenizer::build_prompt( - {"你好", "你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。", - "晚上睡不着应该怎么办"}), - "[Round 0]\n问:你好\n答:你好👋!我是人工智能助手 " - "ChatGLM-6B,很高兴见到你,欢迎问我任何问题。\n[Round 1]\n问:晚上睡不着应该怎么办\n答:"); + EXPECT_EQ(ChatGLMTokenizer::build_prompt({{ChatMessage::ROLE_USER, "你好"}}), "你好"); + EXPECT_EQ( + ChatGLMTokenizer::build_prompt({ + {ChatMessage::ROLE_USER, "你好"}, + {ChatMessage::ROLE_ASSISTANT, "你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。"}, + {ChatMessage::ROLE_USER, "晚上睡不着应该怎么办"}, + }), + "[Round 0]\n问:你好\n答:你好👋!我是人工智能助手 " + "ChatGLM-6B,很高兴见到你,欢迎问我任何问题。\n[Round 1]\n问:晚上睡不着应该怎么办\n答:"); } // memory test @@ -1041,17 +1044,17 @@ TEST(Pipeline, ChatGLM) { for (int i = 0; i < gen_config.max_context_length; i++) { oss << "你好"; } - std::vector history{oss.str()}; - pipeline.chat(history, gen_config); + std::vector messages{{ChatMessage::ROLE_USER, oss.str()}}; + pipeline.chat(messages, gen_config); } // chat { GenerationConfig gen_config; gen_config.do_sample = false; - std::vector 
history{"你好"}; - std::string output = pipeline.chat(history, gen_config); - EXPECT_EQ(output, "你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。"); + std::vector messages{{ChatMessage::ROLE_USER, "你好"}}; + ChatMessage output = pipeline.chat(messages, gen_config); + EXPECT_EQ(output.content, "你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。"); } } @@ -1083,12 +1086,15 @@ TEST(Pipeline, ChatGLM2) { // prompter { - EXPECT_EQ(ChatGLM2Tokenizer::build_prompt({"你好"}), "[Round 1]\n\n问:你好\n\n答:"); - EXPECT_EQ(ChatGLM2Tokenizer::build_prompt( - {"你好", "你好👋!我是人工智能助手 ChatGLM2-6B,很高兴见到你,欢迎问我任何问题。", - "晚上睡不着应该怎么办"}), - "[Round 1]\n\n问:你好\n\n答:你好👋!我是人工智能助手 " - "ChatGLM2-6B,很高兴见到你,欢迎问我任何问题。\n\n[Round 2]\n\n问:晚上睡不着应该怎么办\n\n答:"); + EXPECT_EQ(ChatGLM2Tokenizer::build_prompt({{ChatMessage::ROLE_USER, "你好"}}), "[Round 1]\n\n问:你好\n\n答:"); + EXPECT_EQ( + ChatGLM2Tokenizer::build_prompt({ + {ChatMessage::ROLE_USER, "你好"}, + {ChatMessage::ROLE_ASSISTANT, "你好👋!我是人工智能助手 ChatGLM2-6B,很高兴见到你,欢迎问我任何问题。"}, + {ChatMessage::ROLE_USER, "晚上睡不着应该怎么办"}, + }), + "[Round 1]\n\n问:你好\n\n答:你好👋!我是人工智能助手 " + "ChatGLM2-6B,很高兴见到你,欢迎问我任何问题。\n\n[Round 2]\n\n问:晚上睡不着应该怎么办\n\n答:"); } // memory test @@ -1102,20 +1108,25 @@ TEST(Pipeline, ChatGLM2) { for (int i = 0; i < gen_config.max_context_length; i++) { oss << "你好"; } - std::vector history{oss.str()}; - pipeline.chat(history, gen_config); + std::vector messages{{ChatMessage::ROLE_USER, oss.str()}}; + pipeline.chat(messages, gen_config); } // chat { GenerationConfig gen_config; gen_config.do_sample = false; - std::vector history{"你好"}; - std::string output = pipeline.chat(history, gen_config); - EXPECT_EQ(output, "你好👋!我是人工智能助手 ChatGLM2-6B,很高兴见到你,欢迎问我任何问题。"); + std::vector messages{{ChatMessage::ROLE_USER, "你好"}}; + ChatMessage output = pipeline.chat(messages, gen_config); + EXPECT_EQ(output.content, "你好👋!我是人工智能助手 ChatGLM2-6B,很高兴见到你,欢迎问我任何问题。"); } } +static inline std::string read_text(const fs::path &path) { + MappedFile mapped_file(path.string()); + return std::string(mapped_file.data, mapped_file.size); +} + TEST(Pipeline, ChatGLM3) { fs::path model_path = fs::path(__FILE__).parent_path() / "chatglm3-ggml.bin"; if (!fs::exists(model_path)) { @@ -1124,29 +1135,62 @@ TEST(Pipeline, ChatGLM3) { Pipeline pipeline(model_path.string()); EXPECT_TRUE(dynamic_cast(pipeline.model.get())); + const std::string system_tool_call = + read_text(fs::path(__FILE__).parent_path() / "examples/system/function_call.txt"); + const std::string system_ci = read_text(fs::path(__FILE__).parent_path() / "examples/system/code_interpreter.txt"); + // tokenizer { - std::vector cases{{"你好", {64790, 64792, 36474, 54591}}}; - check_tokenizer(pipeline.tokenizer.get(), cases); - - { - std::vector history{"你好"}; - std::vector input_ids = pipeline.tokenizer->encode_history(history, 2048); - std::vector target_ids{64790, 64792, 64795, 30910, 13, 36474, 54591, 64796, 30910, 13}; - EXPECT_EQ(input_ids, target_ids); - } - { - std::vector history{"你好", - "你好👋!我是人工智能助手 ChatGLM3-6B,很高兴见到你,欢迎问我任何问题。", - "晚上睡不着应该怎么办"}; - std::vector input_ids = pipeline.tokenizer->encode_history(history, 2048); - std::vector target_ids{64790, 64792, 64795, 30910, 13, 36474, 54591, 64796, 30910, 13, - 36474, 54591, 243, 162, 148, 142, 31404, 33030, 34797, 42481, - 22011, 10461, 30944, 30966, 30941, 30978, 30949, 31123, 48895, 35214, - 54622, 31123, 32616, 39905, 31901, 31639, 31155, 64795, 30910, 13, - 30910, 32820, 54266, 31876, 35153, 64796, 30910, 13}; - EXPECT_EQ(input_ids, target_ids); - } + std::vector target_ids{64790, 64792, 36474, 54591}; + std::vector 
input_ids = pipeline.tokenizer->encode("你好", 2048); + EXPECT_EQ(input_ids, target_ids); + } + { + std::vector messages{{ChatMessage::ROLE_USER, "你好"}}; + std::vector input_ids = pipeline.tokenizer->encode_messages(messages, 2048); + std::vector target_ids{64790, 64792, 64795, 30910, 13, 36474, 54591, 64796}; + EXPECT_EQ(input_ids, target_ids); + } + { + std::vector messages{ + {ChatMessage::ROLE_USER, "你好"}, + {ChatMessage::ROLE_ASSISTANT, "你好👋!我是人工智能助手 ChatGLM3-6B,很高兴见到你,欢迎问我任何问题。"}, + {ChatMessage::ROLE_USER, "晚上睡不着应该怎么办"}, + }; + std::vector input_ids = pipeline.tokenizer->encode_messages(messages, 2048); + std::vector target_ids{64790, 64792, 64795, 30910, 13, 36474, 54591, 64796, 30910, 13, 36474, 54591, + 243, 162, 148, 142, 31404, 33030, 34797, 42481, 22011, 10461, 30944, 30966, + 30941, 30978, 30949, 31123, 48895, 35214, 54622, 31123, 32616, 39905, 31901, 31639, + 31155, 64795, 30910, 13, 30910, 32820, 54266, 31876, 35153, 64796}; + EXPECT_EQ(input_ids, target_ids); + } + { + std::vector messages{ + {ChatMessage::ROLE_SYSTEM, system_tool_call}, + {ChatMessage::ROLE_USER, "生成一个随机数"}, + }; + std::vector input_ids = pipeline.tokenizer->encode_messages(messages, 2048); + std::vector target_ids{ + 64790, 64792, 64794, 30910, 13, 20115, 267, 1762, 2554, 362, 1077, 362, 344, 457, 30930, + 809, 431, 1675, 289, 267, 1762, 4159, 30954, 13, 30982, 13, 296, 30955, 16599, 30962, + 11228, 30962, 7311, 1306, 2932, 729, 13, 352, 30955, 2323, 2932, 449, 16599, 30962, 11228, + 30962, 7311, 1306, 1252, 13, 352, 30955, 16302, 2932, 449, 9398, 711, 260, 5402, 1276, + 1994, 30932, 268, 30930, 30912, 30930, 2288, 30995, 30940, 30996, 14819, 1994, 906, 2288, 30995, + 30939, 30996, 1252, 13, 352, 30955, 12209, 2932, 790, 13, 753, 30982, 13, 647, 30955, + 2323, 2932, 449, 24794, 1252, 13, 647, 30955, 16302, 2932, 449, 1036, 5402, 9352, 1050, + 422, 267, 17009, 1252, 13, 647, 30955, 3543, 2932, 449, 592, 1252, 13, 647, 30955, + 20379, 2932, 2033, 13, 753, 4143, 13, 753, 30982, 13, 647, 30955, 2323, 2932, 449, + 7855, 1252, 13, 647, 30955, 16302, 2932, 449, 1036, 2288, 290, 267, 7383, 3859, 1252, + 13, 647, 30955, 3543, 2932, 449, 30912, 16471, 30995, 592, 30932, 558, 30996, 1252, 13, + 647, 30955, 20379, 2932, 2033, 13, 753, 30983, 13, 352, 30996, 13, 296, 4143, 13, + 296, 30955, 752, 30962, 27564, 2932, 729, 13, 352, 30955, 2323, 2932, 449, 752, 30962, + 27564, 1252, 13, 352, 30955, 16302, 2932, 449, 4867, 267, 1465, 5100, 332, 4256, 17654, + 30962, 2323, 31040, 1252, 13, 352, 30955, 12209, 2932, 790, 13, 753, 30982, 13, 647, + 30955, 2323, 2932, 449, 17654, 30962, 2323, 1252, 13, 647, 30955, 16302, 2932, 449, 1036, + 1462, 290, 267, 1911, 289, 330, 580, 266, 819, 1252, 13, 647, 30955, 3543, 2932, + 449, 2069, 1252, 13, 647, 30955, 20379, 2932, 2033, 13, 753, 30983, 13, 352, 30996, + 13, 296, 30983, 13, 30983, 64795, 30910, 13, 30910, 36454, 31623, 37853, 54744, 64796}; + EXPECT_EQ(input_ids, target_ids); } // memory test @@ -1160,17 +1204,90 @@ TEST(Pipeline, ChatGLM3) { for (int i = 0; i < gen_config.max_context_length; i++) { oss << "你好"; } - std::vector history{oss.str()}; - pipeline.chat(history, gen_config); + std::vector messages{{ChatMessage::ROLE_USER, oss.str()}}; + pipeline.chat(messages, gen_config); } // chat { GenerationConfig gen_config; gen_config.do_sample = false; - std::vector history{"你好"}; - std::string output = pipeline.chat(history, gen_config); - EXPECT_EQ(output, "你好👋!我是人工智能助手 ChatGLM3-6B,很高兴见到你,欢迎问我任何问题。"); + std::vector messages{{ChatMessage::ROLE_USER, "你好"}}; + ChatMessage 
output = pipeline.chat(messages, gen_config); + EXPECT_EQ(output.content, "你好👋!我是人工智能助手 ChatGLM3-6B,很高兴见到你,欢迎问我任何问题。"); + } + + // tool call + { + GenerationConfig gen_config; + gen_config.do_sample = false; + std::vector messages{ + {ChatMessage::ROLE_SYSTEM, system_tool_call}, + {ChatMessage::ROLE_USER, "生成一个随机数"}, + }; + { + ChatMessage output = pipeline.chat(messages, gen_config); + EXPECT_EQ(output.role, ChatMessage::ROLE_ASSISTANT); + EXPECT_EQ(output.content, "```python\n" + "tool_call(seed=42, range=(0, 100))\n" + "```"); + messages.emplace_back(std::move(output)); + } + messages.emplace_back(ChatMessage::ROLE_OBSERVATION, "22"); + { + ChatMessage output = pipeline.chat(messages, gen_config); + EXPECT_EQ(output.role, ChatMessage::ROLE_ASSISTANT); + EXPECT_EQ(output.content, "根据您的要求,我使用随机数生成器API生成了一个在0和100之间的随机数,结果为22。"); + } + } + + // code interpreter + { + GenerationConfig gen_config; + gen_config.do_sample = false; + std::vector messages{ + {ChatMessage::ROLE_SYSTEM, system_ci}, + {ChatMessage::ROLE_USER, "列出100以内的所有质数"}, + }; + { + ChatMessage output = pipeline.chat(messages, gen_config); + EXPECT_EQ(output.role, ChatMessage::ROLE_ASSISTANT); + EXPECT_EQ(output.content, "好的,我会为您列出100以内的所有质数。\n\n质数是指只能被1和它本身整除的大于1" + "的整数。例如,2、3、5、7等都是质数。\n\n让我们开始吧!"); + EXPECT_EQ(output.tool_calls.front().code.input, R"(```python +def is_prime(n): + """Check if a number is prime.""" + if n <= 1: + return False + if n <= 3: + return True + if n % 2 == 0 or n % 3 == 0: + return False + i = 5 + while i * i <= n: + if n % i == 0 or n % (i + 2) == 0: + return False + i += 6 + return True + +# Get all prime numbers up to 100 +primes_upto_100 = [i for i in range(2, 101) if is_prime(i)] +primes_upto_100 +```)"); + messages.emplace_back(std::move(output)); + } + messages.emplace_back( + ChatMessage::ROLE_OBSERVATION, + "[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97]"); + { + ChatMessage output = pipeline.chat(messages, gen_config); + EXPECT_EQ(output.role, ChatMessage::ROLE_ASSISTANT); + EXPECT_EQ(output.content, R"(100以内的所有质数为: + +$$ +2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97 +$$)"); + } } } @@ -1234,9 +1351,12 @@ TEST(Pipeline, Baichuan13B) { 1910, 73, 6011, 31169, 4315, 1766, 72, 1231, 11533, 31490, 31182, 21934}}}; check_tokenizer(pipeline.tokenizer.get(), cases); - std::vector history{"你好呀", "你好!很高兴和你交流。请问有什么我可以帮助你的吗?", - "你叫什么名字?"}; - std::vector input_ids = pipeline.tokenizer->encode_history(history, 2048); + std::vector messages{ + {ChatMessage::ROLE_USER, "你好呀"}, + {ChatMessage::ROLE_ASSISTANT, "你好!很高兴和你交流。请问有什么我可以帮助你的吗?"}, + {ChatMessage::ROLE_USER, "你叫什么名字?"}, + }; + std::vector input_ids = pipeline.tokenizer->encode_messages(messages, 2048); std::vector target_input_ids{195, 9875, 31213, 32889, 196, 9875, 31213, 74, 17318, 31906, 14822, 5536, 73, 20389, 7713, 31182, 1231, 4090, 2689, 31763, 75, 195, 9875, 32177, 1534, 10240, 75, 196}; @@ -1259,9 +1379,9 @@ TEST(Pipeline, Baichuan13B) { GenerationConfig gen_config; gen_config.do_sample = false; gen_config.repetition_penalty = 1.1; - std::vector history{"你好呀"}; - std::string output = pipeline.chat(history, gen_config); - EXPECT_EQ(output, "你好!很高兴见到你。请问有什么我可以帮助你的吗?"); + std::vector messages{{ChatMessage::ROLE_USER, "你好呀"}}; + ChatMessage output = pipeline.chat(messages, gen_config); + EXPECT_EQ(output.content, "你好!很高兴见到你。请问有什么我可以帮助你的吗?"); } } @@ -1285,9 +1405,12 @@ TEST(Pipeline, Baichuan2_7B) { 2089, 23672, 1940, 1760, 66, 4173, 23181, 1754, 
65, 65351, 39975, 14590}}}; check_tokenizer(pipeline.tokenizer.get(), cases); - std::vector history{"你好呀", "你好!很高兴和你交流。请问有什么问题我可以帮助你解决吗?", - "你叫什么名字?"}; - std::vector input_ids = pipeline.tokenizer->encode_history(history, 2048); + std::vector messages{ + {ChatMessage::ROLE_USER, "你好呀"}, + {ChatMessage::ROLE_ASSISTANT, "你好!很高兴和你交流。请问有什么问题我可以帮助你解决吗?"}, + {ChatMessage::ROLE_USER, "你叫什么名字?"}, + }; + std::vector input_ids = pipeline.tokenizer->encode_messages(messages, 2048); std::vector target_input_ids{195, 16829, 94278, 196, 16829, 67, 52160, 10329, 3341, 66, 23216, 5817, 1754, 92392, 21777, 92430, 2740, 93122, 68, 195, 92430, 93410, 1747, 6642, 68, 196}; @@ -1310,9 +1433,9 @@ TEST(Pipeline, Baichuan2_7B) { GenerationConfig gen_config; gen_config.do_sample = false; gen_config.repetition_penalty = 1.05; - std::vector history{"你好呀"}; - std::string output = pipeline.chat(history, gen_config); - EXPECT_EQ(output, "你好!很高兴为你服务。请问有什么问题我可以帮助你解决?"); + std::vector messages{{ChatMessage::ROLE_USER, "你好呀"}}; + ChatMessage output = pipeline.chat(messages, gen_config); + EXPECT_EQ(output.content, "你好!很高兴为你服务。请问有什么问题我可以帮助你解决?"); } } @@ -1336,9 +1459,12 @@ TEST(Pipeline, Baichuan2_13B) { 2089, 23672, 1940, 1760, 66, 4173, 23181, 1754, 65, 65351, 39975, 14590}}}; check_tokenizer(pipeline.tokenizer.get(), cases); - std::vector history{"你好呀", "你好!很高兴和你交流。请问有什么我可以帮助你的吗?", - "你叫什么名字?"}; - std::vector input_ids = pipeline.tokenizer->encode_history(history, 2048); + std::vector messages{ + {ChatMessage::ROLE_USER, "你好呀"}, + {ChatMessage::ROLE_ASSISTANT, "你好!很高兴和你交流。请问有什么我可以帮助你的吗?"}, + {ChatMessage::ROLE_USER, "你叫什么名字?"}, + }; + std::vector input_ids = pipeline.tokenizer->encode_messages(messages, 2048); std::vector target_input_ids{195, 16829, 94278, 196, 16829, 67, 52160, 10329, 3341, 66, 23216, 5817, 92392, 21777, 2193, 93122, 68, 195, 92430, 93410, 1747, 6642, 68, 196}; EXPECT_TRUE(equal(input_ids, target_input_ids)); @@ -1349,9 +1475,9 @@ TEST(Pipeline, Baichuan2_13B) { GenerationConfig gen_config; gen_config.do_sample = false; gen_config.repetition_penalty = 1.05; - std::vector history{"你好呀"}; - std::string output = pipeline.chat(history, gen_config); - EXPECT_EQ(output, "你好!很高兴见到你。请问有什么我可以帮助你的吗?"); + std::vector messages{{ChatMessage::ROLE_USER, "你好呀"}}; + ChatMessage output = pipeline.chat(messages, gen_config); + EXPECT_EQ(output.content, "你好!很高兴见到你。请问有什么我可以帮助你的吗?"); } } @@ -1378,8 +1504,12 @@ TEST(Pipeline, InternLM) { // prompter { - EXPECT_EQ(InternLMTokenizer::build_prompt({"你好"}), "<|User|>:你好\n<|Bot|>:"); - EXPECT_EQ(InternLMTokenizer::build_prompt({"你好", "你好,有什么我可以帮助你的吗?", "晚上睡不着应该怎么办"}), + EXPECT_EQ(InternLMTokenizer::build_prompt({{ChatMessage::ROLE_USER, "你好"}}), "<|User|>:你好\n<|Bot|>:"); + EXPECT_EQ(InternLMTokenizer::build_prompt({ + {ChatMessage::ROLE_USER, "你好"}, + {ChatMessage::ROLE_ASSISTANT, "你好,有什么我可以帮助你的吗?"}, + {ChatMessage::ROLE_USER, "晚上睡不着应该怎么办"}, + }), "<|User|>:你好\n<|Bot|>:你好,有什么我可以帮助你的吗?\n<|User|>:晚上睡不着应该怎么办" "\n<|Bot|>:"); } @@ -1399,9 +1529,9 @@ TEST(Pipeline, InternLM) { { GenerationConfig gen_config; gen_config.do_sample = false; - std::vector history{"你好"}; - std::string output = pipeline.chat(history, gen_config); - EXPECT_EQ(output, "你好,有什么我可以帮助你的吗?"); + std::vector messages{{ChatMessage::ROLE_USER, "你好"}}; + ChatMessage output = pipeline.chat(messages, gen_config); + EXPECT_EQ(output.content, "你好,有什么我可以帮助你的吗?"); } } @@ -1416,8 +1546,11 @@ static void run_benchmark(const fs::path &model_path) { int64_t load_model_ms = ggml_time_ms() - start_ms; start_ms = ggml_time_ms(); - 
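The test changes above all follow the same migration: plain-string histories become ChatMessage lists, encode_history becomes encode_messages, and chat() returns a ChatMessage rather than a string. A minimal Python sketch of that flow, with `chatglm3-ggml.bin` as a placeholder for any locally converted model:

```python
# Minimal sketch of the message-based API the tests above migrate to.
# "chatglm3-ggml.bin" is a placeholder for any locally converted model.
import chatglm_cpp

pipeline = chatglm_cpp.Pipeline("chatglm3-ggml.bin")

messages = [
    chatglm_cpp.ChatMessage(role="user", content="你好"),
    chatglm_cpp.ChatMessage(role="assistant", content="你好👋!我是人工智能助手 ChatGLM3-6B,很高兴见到你,欢迎问我任何问题。"),
    chatglm_cpp.ChatMessage(role="user", content="晚上睡不着应该怎么办"),
]

# encode_messages replaces the former encode_history and takes ChatMessage objects.
input_ids = pipeline.tokenizer.encode_messages(messages, 2048)
print(len(input_ids), "prompt tokens")

# chat() now returns a ChatMessage; the reply text lives in .content.
reply = pipeline.chat(messages, do_sample=False)
print(reply.role, ">", reply.content)
```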
std::vector history{"你好", "你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。", - "晚上睡不着应该怎么办"}; + std::vector messages{ + {ChatMessage::ROLE_USER, "你好"}, + {ChatMessage::ROLE_ASSISTANT, "你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。"}, + {ChatMessage::ROLE_USER, "晚上睡不着应该怎么办"}, + }; GenerationConfig gen_config; gen_config.do_sample = false; @@ -1425,7 +1558,7 @@ static void run_benchmark(const fs::path &model_path) { PerfStreamer streamer; start_ms = ggml_time_ms(); - pipeline.chat(history, gen_config, &streamer); + pipeline.chat(messages, gen_config, &streamer); int64_t gen_s = (ggml_time_ms() - start_ms) / 1000.f; std::cout << "======== benchmark results for " << model_path.filename() << " ========\n" diff --git a/docs/code_interpreter.png b/docs/code_interpreter.png new file mode 100644 index 00000000..6ae1873c Binary files /dev/null and b/docs/code_interpreter.png differ diff --git a/docs/function_call.png b/docs/function_call.png new file mode 100644 index 00000000..91b29efd Binary files /dev/null and b/docs/function_call.png differ diff --git a/examples/chatglm3_demo.py b/examples/chatglm3_demo.py new file mode 100644 index 00000000..5c200bb8 --- /dev/null +++ b/examples/chatglm3_demo.py @@ -0,0 +1,344 @@ +from __future__ import annotations + +import base64 +import functools +import io +import json +import queue +import re +import traceback +from enum import Enum +from pathlib import Path +from typing import Callable + +import chatglm_cpp +import jupyter_client +import streamlit as st +from PIL import Image + +IPYKERNEL = "chatglm3-demo" +MODEL_PATH = Path(__file__).resolve().parent.parent / "chatglm3-ggml.bin" + +CHAT_SYSTEM_PROMPT = "You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown." + +TOOLS = [ + { + "name": "random_number_generator", + "description": "Generates a random number x, s.t. range[0] <= x < range[1]", + "parameters": { + "type": "object", + "properties": { + "seed": {"description": "The random seed used by the generator", "type": "integer"}, + "range": { + "description": "The range of the generated numbers", + "type": "array", + "items": [{"type": "integer"}, {"type": "integer"}], + }, + }, + "required": ["seed", "range"], + }, + }, + { + "name": "get_weather", + "description": "Get the current weather for `city_name`", + "parameters": { + "type": "object", + "properties": {"city_name": {"description": "The name of the city to be queried", "type": "string"}}, + "required": ["city_name"], + }, + }, +] + +TOOL_SYSTEM_PROMPT = ( + "Answer the following questions as best as you can. 
You have access to the following tools:\n" + + json.dumps(TOOLS, indent=4) +) + +CI_SYSTEM_PROMPT = "你是一位智能AI助手,你叫ChatGLM,你连接着一台电脑,但请注意不能联网。在使用Python解决任务时,你可以运行代码并得到结果,如果运行结果有错误,你需要尽可能对代码进行改进。你可以处理用户上传到电脑上的文件,文件默认存储路径是/mnt/data/。" + + +class Mode(str, Enum): + CHAT = "💬 Chat" + TOOL = "🛠️ Tool" + CI = "🧑‍💻 Code Interpreter" + + +@st.cache_resource +def get_model(model_path: str) -> chatglm_cpp.Pipeline: + return chatglm_cpp.Pipeline(model_path) + + +class Message(chatglm_cpp.ChatMessage): + def __init__( + self, role: str, content: str, tool_calls: list | None = None, image: Image.Image | None = None + ) -> None: + if tool_calls is None: + tool_calls = [] + super().__init__(role, content, tool_calls) + self.image = image + + @staticmethod + def from_cpp(cpp_message: chatglm_cpp.ChatMessage) -> Message: + return Message( + role=cpp_message.role, content=cpp_message.content, tool_calls=cpp_message.tool_calls, image=None + ) + + +def show_message(message: Message) -> None: + role_avatars = {"user": "user", "observation": "user", "assistant": "assistant"} + avatar = role_avatars.get(message.role) + if avatar is None: + st.error(f"Unexpected message role {message.role}") + return + + display_content = message.content + if message.tool_calls: + (tool_call,) = message.tool_calls + if tool_call.type == "function": + display_content = f"{tool_call.function.name}\n{display_content}" + elif tool_call.type == "code": + display_content += "\n" + tool_call.code.input + + if message.role == "observation": + display_content = f"```\n{display_content.strip()}\n```" + + with st.chat_message(name=message.role, avatar=avatar): + if message.image: + st.image(message.image) + else: + st.markdown(display_content) + + +# ----- begin function call ----- + +_FUNCTION_REGISTRY = {} + + +def register_function(func: Callable) -> Callable: + _FUNCTION_REGISTRY[func.__name__] = func + + @functools.wraps(func) + def wrap(*args, **kwargs): + return func(*args, **kwargs) + + return wrap + + +@register_function +def random_number_generator(seed: int, range: tuple[int, int]) -> int: + import random + + return random.Random(seed).randint(*range) + + +@register_function +def get_weather(city_name: str) -> str: + import requests + + key_selection = { + "current_condition": ["temp_C", "FeelsLikeC", "humidity", "weatherDesc", "observation_time"], + } + resp = requests.get(f"https://wttr.in/{city_name}?format=j1") + resp.raise_for_status() + resp = resp.json() + + ret = {k: {_v: resp[k][0][_v] for _v in v} for k, v in key_selection.items()} + return json.dumps(ret) + + +def run_function(name: str, arguments: str) -> str: + def tool_call(**kwargs): + return kwargs + + func = _FUNCTION_REGISTRY.get(name) + if func is None: + return f"Function `{name}` is not defined" + + try: + kwargs = eval(arguments, dict(tool_call=tool_call)) + except Exception: + return f"Invalid arguments {arguments}" + + try: + return str(func(**kwargs)) + except Exception: + return traceback.format_exc() + + +# ----- end function call ----- + +# ----- begin code interpreter ----- + + +@st.cache_resource +def get_kernel_client(kernel_name) -> jupyter_client.BlockingKernelClient: + km = jupyter_client.KernelManager(kernel_name=kernel_name) + km.start_kernel() + + kc: jupyter_client.BlockingKernelClient = km.blocking_client() + kc.start_channels() + + return kc + + +def clean_ansi_codes(text: str) -> str: + ansi_escape = re.compile(r"(\x9B|\x1B\[|\u001b\[)[0-?]*[ -/]*[@-~]") + return ansi_escape.sub("", text) + + +def extract_code(text: str) -> str: + return 
re.search(r"```.*?\n(.*?)```", text, re.DOTALL)[1] + + +def run_code(kc: jupyter_client.BlockingKernelClient, code: str) -> str | Image.Image: + kc.execute(code) + + try: + shell_msg = kc.get_shell_msg(timeout=30) + io_msg_content = None + while True: + try: + next_io_msg_content = kc.get_iopub_msg(timeout=30)["content"] + except queue.Empty: + break + if next_io_msg_content.get("execution_state") == "idle": + break + io_msg_content = next_io_msg_content + + if shell_msg["metadata"]["status"] == "timeout": + return "Execution Timeout Expired" + + if shell_msg["metadata"]["status"] == "error": + try: + traceback_content = clean_ansi_codes(io_msg_content["traceback"][-1]) + except Exception: + traceback_content = "Traceback Error" + return traceback_content + + if "text" in io_msg_content: + return io_msg_content["text"] + + data_content = io_msg_content.get("data") + if data_content is not None: + image_content = data_content.get("image/png") + if image_content is not None: + return Image.open(io.BytesIO(base64.b64decode(image_content))) + + text_content = data_content.get("text/plain") + if text_content is not None: + return text_content + + return "" + + except Exception: + return traceback.format_exc() + + +# ----- end code interpreter ----- + + +def main(): + st.set_page_config(page_title="ChatGLM3 Demo", page_icon="🚀", layout="centered", initial_sidebar_state="auto") + + pipeline = get_model(MODEL_PATH) + + st.session_state.setdefault("messages", []) + + st.title("ChatGLM3 Demo") + + prompt = st.chat_input("Chat with ChatGLM3!", key="chat_input") + + mode = st.radio("Mode", [x.value for x in Mode], horizontal=True, label_visibility="hidden") + + DEFAULT_SYSTEM_PROMPT_MAP = { + Mode.CHAT: CHAT_SYSTEM_PROMPT, + Mode.TOOL: TOOL_SYSTEM_PROMPT, + Mode.CI: CI_SYSTEM_PROMPT, + } + default_system_prompt = DEFAULT_SYSTEM_PROMPT_MAP.get(mode) + if default_system_prompt is None: + st.error(f"Unexpected mode {mode}") + + with st.sidebar: + top_p = st.slider(label="Top P", min_value=0.0, max_value=1.0, value=0.8, step=0.01) + temperature = st.slider(label="Temperature", min_value=0.0, max_value=1.5, value=0.8, step=0.01) + max_length = st.slider(label="Max Length", min_value=128, max_value=2048, value=2048, step=16) + max_context_length = st.slider(label="Max Context Length", min_value=128, max_value=2048, value=1536, step=16) + system_prompt = st.text_area(label="System Prompt", value=default_system_prompt, height=300) + if st.button(label="Clear Context", type="primary"): + st.session_state.messages = [] + + messages: list[Message] = st.session_state.messages + + for msg in messages: + show_message(msg) + + if not prompt: + return + + prompt = prompt.strip() + messages.append(Message(role="user", content=prompt)) + show_message(messages[-1]) + + TOOL_CALL_MAX_RETRY = 5 + for _ in range(TOOL_CALL_MAX_RETRY): + messages_with_system = [] + if system_prompt: + messages_with_system.append(Message(role="system", content=system_prompt)) + messages_with_system += messages + + chunks = [] + response = "" + + with st.chat_message(name="assistant", avatar="assistant"): + message_placeholder = st.empty() + + for chunk in pipeline.chat( + messages_with_system, + max_length=max_length, + max_context_length=max_context_length, + do_sample=temperature > 0, + top_k=0, + top_p=top_p, + temperature=temperature, + repetition_penalty=1.0, + num_threads=0, + stream=True, + ): + response += chunk.content + chunks.append(chunk) + message_placeholder.markdown(response + "▌") + + message_placeholder.markdown(response) + + 
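The streaming loop here, together with merge_streaming_messages and the function registry, is essentially all the demo needs for tool calls: stream a reply, fold the chunks into a ChatMessage, and if it carries a function call, run the function and feed the result back as an observation message. A condensed sketch under those assumptions (placeholder model path, a single registered tool, bounded retries, code-interpreter branch omitted):

```python
# Condensed tool-call loop in the spirit of the demo above (not the demo itself).
# MODEL_PATH, MAX_ROUNDS and the single registered tool are illustrative placeholders.
from pathlib import Path
import random

import chatglm_cpp

MODEL_PATH = "chatglm3-ggml.bin"
SYSTEM_PROMPT = Path("examples/system/function_call.txt").read_text()
MAX_ROUNDS = 5


def random_number_generator(seed: int, range: tuple) -> int:
    return random.Random(seed).randint(*range)


TOOLS = {"random_number_generator": random_number_generator}


def run_tool(name: str, arguments: str) -> str:
    # The model emits arguments as a `tool_call(...)` expression such as
    # "tool_call(seed=42, range=(0, 100))"; evaluate it with a helper that
    # simply collects the keyword arguments.
    kwargs = eval(arguments, {"tool_call": lambda **kw: kw})
    return str(TOOLS[name](**kwargs))


pipeline = chatglm_cpp.Pipeline(MODEL_PATH)
messages = [
    chatglm_cpp.ChatMessage(role="system", content=SYSTEM_PROMPT),
    chatglm_cpp.ChatMessage(role="user", content="生成一个随机数"),
]

for _ in range(MAX_ROUNDS):
    chunks = list(pipeline.chat(messages, do_sample=False, stream=True))
    reply = pipeline.merge_streaming_messages(chunks)
    messages.append(reply)
    if not reply.tool_calls:
        break  # plain answer, no tool requested
    tool_call = reply.tool_calls[0]
    observation = run_tool(tool_call.function.name, tool_call.function.arguments)
    messages.append(chatglm_cpp.ChatMessage(role="observation", content=observation))

print(messages[-1].content)
```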
reply_message = Message.from_cpp(pipeline.merge_streaming_messages(chunks)) + messages.append(reply_message) + if not reply_message.tool_calls: + break + + (tool_call,) = reply_message.tool_calls + if tool_call.type == "function": + with st.spinner(f"Calling function `{tool_call.function.name}` ..."): + observation = run_function(tool_call.function.name, tool_call.function.arguments) + elif tool_call.type == "code": + kc = get_kernel_client(IPYKERNEL) + code = extract_code(tool_call.code.input) + with st.spinner(f"Executing code ..."): + observation = run_code(kc, code) + else: + st.error(f"Unexpected tool call type {tool_call.type}") + return + + OBSERVATION_MAX_LENGTH = 1024 + if isinstance(observation, str) and len(observation) > OBSERVATION_MAX_LENGTH: + observation = observation[:OBSERVATION_MAX_LENGTH] + " [TRUNCATED]" + + if isinstance(observation, str): + messages.append(Message(role="observation", content=observation)) + else: + messages.append(Message(role="observation", content="[IMAGE]", image=observation)) + show_message(messages[-1]) + + +if __name__ == "__main__": + main() diff --git a/examples/cli_chat.py b/examples/cli_demo.py similarity index 51% rename from examples/cli_chat.py rename to examples/cli_demo.py index c5cda138..1a07efaa 100644 --- a/examples/cli_chat.py +++ b/examples/cli_demo.py @@ -1,5 +1,6 @@ import argparse from pathlib import Path +from typing import List import chatglm_cpp @@ -18,11 +19,24 @@ WELCOME_MESSAGE = "Welcome to ChatGLM.cpp! Ask whatever you want. Type 'clear' to clear context. Type 'stop' to exit." -def main(): +def main() -> None: parser = argparse.ArgumentParser() - parser.add_argument("-m", "--model", default=DEFAULT_MODEL_PATH, type=Path, help="model path") + parser.add_argument("-m", "--model", default=DEFAULT_MODEL_PATH, type=str, help="model path") parser.add_argument("--mode", default="chat", type=str, choices=["chat", "generate"], help="inference mode") parser.add_argument("-p", "--prompt", default="你好", type=str, help="prompt to start generation with") + parser.add_argument( + "--pp", "--prompt_path", default=None, type=Path, help="path to the plain text file that stores the prompt" + ) + parser.add_argument( + "-s", "--system", default=None, type=str, help="system message to set the behavior of the assistant" + ) + parser.add_argument( + "--sp", + "--system_path", + default=None, + type=Path, + help="path to the plain text file that stores the system message", + ) parser.add_argument("-i", "--interactive", action="store_true", help="run in interactive mode") parser.add_argument( "-l", "--max_length", default=2048, type=int, help="max total length including prompt and output" @@ -35,6 +49,14 @@ def main(): parser.add_argument("-t", "--threads", default=0, type=int, help="number of threads for inference") args = parser.parse_args() + prompt = args.prompt + if args.pp: + prompt = args.pp.read_text() + + system = args.system + if args.sp: + system = args.sp.read_text() + pipeline = chatglm_cpp.Pipeline(args.model) if args.mode != "chat" and args.interactive: @@ -52,14 +74,20 @@ def main(): stream=True, ) + system_messages: List[chatglm_cpp.ChatMessage] = [] + if system is not None: + system_messages.append(chatglm_cpp.ChatMessage(role="system", content=system)) + + messages = system_messages.copy() + if not args.interactive: - generator = ( - pipeline.chat([args.prompt], **generation_kwargs) - if args.mode == "chat" - else pipeline.generate(args.prompt, **generation_kwargs) - ) - for piece in generator: - print(piece, sep="", end="", 
flush=True) + if args.mode == "chat": + messages.append(chatglm_cpp.ChatMessage(role="user", content=prompt)) + for chunk in pipeline.chat(messages, **generation_kwargs): + print(chunk.content, sep="", end="", flush=True) + else: + for chunk in pipeline.generate(prompt, **generation_kwargs): + print(chunk, sep="", end="", flush=True) print() return @@ -67,27 +95,52 @@ def main(): print() print(WELCOME_MESSAGE) print() - history = [] + + prompt_width = len(pipeline.model.config.model_type_name) + + if system: + print(f"{'System':{prompt_width}} > {system}") + while True: + if messages and messages[-1].tool_calls: + (tool_call,) = messages[-1].tool_calls + if tool_call.type == "function": + print( + f"Function Call > Please manually call function `{tool_call.function.name}` and provide the results below." + ) + input_prompt = "Observation > " + elif tool_call.type == "code": + print(f"Code Interpreter > Please manually run the code and provide the results below.") + input_prompt = "Observation > " + else: + raise ValueError(f"unexpected tool call type {tool_call.type}") + role = "observation" + else: + input_prompt = f"{'Prompt':{prompt_width}} > " + role = "user" + try: - prompt = input(f"{'Prompt':{len(pipeline.model.config.model_type_name)}} > ") + prompt = input(input_prompt) except EOFError: break + if not prompt: continue if prompt == "stop": break if prompt == "clear": - history = [] + messages = system_messages continue - history.append(prompt) + + messages.append(chatglm_cpp.ChatMessage(role=role, content=prompt)) print(f"{pipeline.model.config.model_type_name} > ", sep="", end="") - output = "" - for piece in pipeline.chat(history, **generation_kwargs): - print(piece, sep="", end="", flush=True) - output += piece + chunks = [] + for chunk in pipeline.chat(messages, **generation_kwargs): + print(chunk.content, sep="", end="", flush=True) + chunks.append(chunk) print() - history.append(output) + messages.append(pipeline.merge_streaming_messages(chunks)) + print("Bye") diff --git a/examples/langchain_openai_client.py b/examples/langchain_openai_client.py new file mode 100644 index 00000000..a06781ce --- /dev/null +++ b/examples/langchain_openai_client.py @@ -0,0 +1,4 @@ +from langchain.chat_models import ChatOpenAI + +chat_model = ChatOpenAI() +print(chat_model.predict(text="你好", max_tokens=2048)) diff --git a/examples/system/code_interpreter.txt b/examples/system/code_interpreter.txt new file mode 100644 index 00000000..7561ed20 --- /dev/null +++ b/examples/system/code_interpreter.txt @@ -0,0 +1 @@ +你是一位智能AI助手,你叫ChatGLM,你连接着一台电脑,但请注意不能联网。在使用Python解决任务时,你可以运行代码并得到结果,如果运行结果有错误,你需要尽可能对代码进行改进。你可以处理用户上传到电脑上的文件,文件默认存储路径是/mnt/data/。 \ No newline at end of file diff --git a/examples/system/default.txt b/examples/system/default.txt new file mode 100644 index 00000000..2345cf50 --- /dev/null +++ b/examples/system/default.txt @@ -0,0 +1 @@ +You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown. \ No newline at end of file diff --git a/examples/system/function_call.txt b/examples/system/function_call.txt new file mode 100644 index 00000000..d25a10c7 --- /dev/null +++ b/examples/system/function_call.txt @@ -0,0 +1,33 @@ +Answer the following questions as best as you can. You have access to the following tools: +{ + "random_number_generator": { + "name": "random_number_generator", + "description": "Generates a random number x, s.t. 
range[0] <= x < range[1]", + "params": [ + { + "name": "seed", + "description": "The random seed used by the generator", + "type": "int", + "required": true + }, + { + "name": "range", + "description": "The range of the generated numbers", + "type": "tuple[int, int]", + "required": true + } + ] + }, + "get_weather": { + "name": "get_weather", + "description": "Get the current weather for `city_name`", + "params": [ + { + "name": "city_name", + "description": "The name of the city to be queried", + "type": "str", + "required": true + } + ] + } +} \ No newline at end of file diff --git a/examples/web_demo.py b/examples/web_demo.py index 54ab1b87..7fe74cdf 100644 --- a/examples/web_demo.py +++ b/examples/web_demo.py @@ -30,10 +30,9 @@ def postprocess(text): return text -def predict(input, chatbot, max_length, top_p, temperature, history): +def predict(input, chatbot, max_length, top_p, temperature, messages): chatbot.append((postprocess(input), "")) - response = "" - history.append(input) + messages.append(chatglm_cpp.ChatMessage(role="user", content=input)) generation_kwargs = dict( max_length=max_length, @@ -46,19 +45,23 @@ def predict(input, chatbot, max_length, top_p, temperature, history): num_threads=args.threads, stream=True, ) - generator = ( - pipeline.chat(history, **generation_kwargs) - if args.mode == "chat" - else pipeline.generate(input, **generation_kwargs) - ) - for response_piece in generator: - response += response_piece - chatbot[-1] = (chatbot[-1][0], postprocess(response)) - yield chatbot, history - - history.append(response) - yield chatbot, history + response = "" + if args.mode == "chat": + chunks = [] + for chunk in pipeline.chat(messages, **generation_kwargs): + response += chunk.content + chunks.append(chunk) + chatbot[-1] = (chatbot[-1][0], postprocess(response)) + yield chatbot, messages + messages.append(pipeline.merge_streaming_messages(chunks)) + else: + for chunk in pipeline.generate(input, **generation_kwargs): + response += chunk + chatbot[-1] = (chatbot[-1][0], postprocess(response)) + yield chatbot, messages + + yield chatbot, messages def reset_user_input(): @@ -83,13 +86,16 @@ def reset_state(): temperature = gr.Slider(0, 1, value=args.temp, step=0.01, label="Temperature", interactive=True) emptyBtn = gr.Button("Clear History") - history = gr.State([]) + messages = gr.State([]) submitBtn.click( - predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history], show_progress=True + predict, + [user_input, chatbot, max_length, top_p, temperature, messages], + [chatbot, messages], + show_progress=True, ) submitBtn.click(reset_user_input, [], [user_input]) - emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True) + emptyBtn.click(reset_state, outputs=[chatbot, messages], show_progress=True) demo.queue().launch(share=False, inbrowser=True) diff --git a/main.cpp b/main.cpp index cf305a71..e5bed73d 100644 --- a/main.cpp +++ b/main.cpp @@ -1,4 +1,5 @@ #include "chatglm.h" +#include #include #include @@ -23,7 +24,9 @@ static inline InferenceMode to_inference_mode(const std::string &s) { struct Args { std::string model_path = "chatglm-ggml.bin"; InferenceMode mode = INFERENCE_MODE_CHAT; + bool sync = false; std::string prompt = "你好"; + std::string system = ""; int max_length = 2048; int max_context_length = 512; bool interactive = false; @@ -42,7 +45,11 @@ static void usage(const std::string &prog) { << " -h, --help show this help message and exit\n" << " -m, --model PATH model path (default: chatglm-ggml.bin)\n" << 
" --mode inference mode chose from {chat, generate} (default: chat)\n" + << " --sync synchronized generation without streaming\n" << " -p, --prompt PROMPT prompt to start generation with (default: 你好)\n" + << " --pp, --prompt_path path to the plain text file that stores the prompt\n" + << " -s, --system SYSTEM system message to set the behavior of the assistant\n" + << " --sp, --system_path path to the plain text file that stores the system message\n" << " -i, --interactive run in interactive mode\n" << " -l, --max_length N max total length including prompt and output (default: 2048)\n" << " -c, --max_context_length N\n" @@ -55,42 +62,58 @@ static void usage(const std::string &prog) { << " -v, --verbose display verbose output including config/system/performance info\n"; } +static std::string read_text(std::string path) { + std::ifstream fin(path); + CHATGLM_CHECK(fin) << "cannot open file " << path; + std::ostringstream oss; + oss << fin.rdbuf(); + return oss.str(); +} + static Args parse_args(const std::vector &argv) { Args args; for (size_t i = 1; i < argv.size(); i++) { - const std::string &arg = argv[i]; + const std::string &arg = argv.at(i); if (arg == "-h" || arg == "--help") { - usage(argv[0]); + usage(argv.at(0)); exit(EXIT_SUCCESS); } else if (arg == "-m" || arg == "--model") { - args.model_path = argv[++i]; + args.model_path = argv.at(++i); } else if (arg == "--mode") { - args.mode = to_inference_mode(argv[++i]); + args.mode = to_inference_mode(argv.at(++i)); + } else if (arg == "--sync") { + args.sync = true; } else if (arg == "-p" || arg == "--prompt") { - args.prompt = argv[++i]; + args.prompt = argv.at(++i); + } else if (arg == "--pp" || arg == "--prompt_path") { + args.prompt = read_text(argv.at(++i)); + } else if (arg == "-s" || arg == "--system") { + args.system = argv.at(++i); + } else if (arg == "--sp" || arg == "--system_path") { + args.system = read_text(argv.at(++i)); } else if (arg == "-i" || arg == "--interactive") { args.interactive = true; } else if (arg == "-l" || arg == "--max_length") { - args.max_length = std::stoi(argv[++i]); + args.max_length = std::stoi(argv.at(++i)); } else if (arg == "-c" || arg == "--max_context_length") { - args.max_context_length = std::stoi(argv[++i]); + args.max_context_length = std::stoi(argv.at(++i)); } else if (arg == "--top_k") { - args.top_k = std::stoi(argv[++i]); + args.top_k = std::stoi(argv.at(++i)); } else if (arg == "--top_p") { - args.top_p = std::stof(argv[++i]); + args.top_p = std::stof(argv.at(++i)); } else if (arg == "--temp") { - args.temp = std::stof(argv[++i]); + args.temp = std::stof(argv.at(++i)); } else if (arg == "--repeat_penalty") { - args.repeat_penalty = std::stof(argv[++i]); + args.repeat_penalty = std::stof(argv.at(++i)); } else if (arg == "-t" || arg == "--threads") { - args.num_threads = std::stoi(argv[++i]); + args.num_threads = std::stoi(argv.at(++i)); } else if (arg == "-v" || arg == "--verbose") { args.verbose = true; } else { std::cerr << "Unknown argument: " << arg << std::endl; - usage(argv[0]); + usage(argv.at(0)); exit(EXIT_FAILURE); } } @@ -133,6 +156,13 @@ static bool get_utf8_line(std::string &line) { #endif } +static inline void print_message(const chatglm::ChatMessage &message) { + std::cout << message.content << "\n"; + if (!message.tool_calls.empty() && message.tool_calls.front().type == chatglm::ToolCallMessage::TYPE_CODE) { + std::cout << message.tool_calls.front().code.input << "\n"; + } +} + static void chat(Args &args) { ggml_time_init(); int64_t start_load_us = ggml_time_us(); @@ 
-143,8 +173,11 @@ static void chat(Args &args) { auto text_streamer = std::make_shared(std::cout, pipeline.tokenizer.get()); auto perf_streamer = std::make_shared(); - auto streamer = std::make_shared( - std::vector>{text_streamer, perf_streamer}); + std::vector> streamers{perf_streamer}; + if (!args.sync) { + streamers.emplace_back(text_streamer); + } + auto streamer = std::make_unique(std::move(streamers)); chatglm::GenerationConfig gen_config(args.max_length, args.max_context_length, args.temp > 0, args.top_k, args.top_p, args.temp, args.repeat_penalty, args.num_threads); @@ -172,6 +205,7 @@ static void chat(Args &args) { << "top_k = " << args.top_k << " | " << "top_p = " << args.top_p << " | " << "temperature = " << args.temp << " | " + << "repetition_penalty = " << args.repeat_penalty << " | " << "num_threads = " << args.num_threads << " |\n"; std::cout << "loaded " << pipeline.model->config.model_type_name() << " model from " << args.model_path @@ -185,6 +219,11 @@ static void chat(Args &args) { args.interactive = false; } + std::vector system_messages; + if (!args.system.empty()) { + system_messages.emplace_back(chatglm::ChatMessage::ROLE_SYSTEM, args.system); + } + if (args.interactive) { std::cout << R"( ________ __ ________ __ ___ )" << '\n' << R"( / ____/ /_ ____ _/ /_/ ____/ / / |/ /_________ ____ )" << '\n' @@ -198,10 +237,33 @@ static void chat(Args &args) { << "Welcome to ChatGLM.cpp! Ask whatever you want. Type 'clear' to clear context. Type 'stop' to exit.\n" << "\n"; - std::vector history; + std::vector messages = system_messages; + if (!args.system.empty()) { + std::cout << std::setw(model_name.size()) << std::left << "System" + << " > " << args.system << std::endl; + } while (1) { - std::cout << std::setw(model_name.size()) << std::left << "Prompt" - << " > " << std::flush; + std::string role; + if (!messages.empty() && !messages.back().tool_calls.empty()) { + const auto &tool_call = messages.back().tool_calls.front(); + if (tool_call.type == chatglm::ToolCallMessage::TYPE_FUNCTION) { + // function call + std::cout << "Function Call > Please manually call function `" << tool_call.function.name + << "` with args `" << tool_call.function.arguments << "` and provide the results below.\n" + << "Observation > " << std::flush; + } else if (tool_call.type == chatglm::ToolCallMessage::TYPE_CODE) { + // code interpreter + std::cout << "Code Interpreter > Please manually run the code and provide the results below.\n" + << "Observation > " << std::flush; + } else { + CHATGLM_THROW << "unexpected tool type " << tool_call.type; + } + role = chatglm::ChatMessage::ROLE_OBSERVATION; + } else { + std::cout << std::setw(model_name.size()) << std::left << "Prompt" + << " > " << std::flush; + role = chatglm::ChatMessage::ROLE_USER; + } std::string prompt; if (!get_utf8_line(prompt) || prompt == "stop") { break; @@ -210,13 +272,16 @@ static void chat(Args &args) { continue; } if (prompt == "clear") { - history.clear(); + messages = system_messages; continue; } - history.emplace_back(std::move(prompt)); + messages.emplace_back(std::move(role), std::move(prompt)); std::cout << model_name << " > "; - std::string output = pipeline.chat(history, gen_config, streamer.get()); - history.emplace_back(std::move(output)); + chatglm::ChatMessage output = pipeline.chat(messages, gen_config, streamer.get()); + if (args.sync) { + print_message(output); + } + messages.emplace_back(std::move(output)); if (args.verbose) { std::cout << "\n" << perf_streamer->to_string() << "\n\n"; } @@ -225,9 +290,17 @@ 
static void chat(Args &args) { std::cout << "Bye\n"; } else { if (args.mode == INFERENCE_MODE_CHAT) { - pipeline.chat({args.prompt}, gen_config, streamer.get()); + std::vector messages = system_messages; + messages.emplace_back(chatglm::ChatMessage::ROLE_USER, args.prompt); + chatglm::ChatMessage output = pipeline.chat(messages, gen_config, streamer.get()); + if (args.sync) { + print_message(output); + } } else { - pipeline.generate(args.prompt, gen_config, streamer.get()); + std::string output = pipeline.generate(args.prompt, gen_config, streamer.get()); + if (args.sync) { + std::cout << output << "\n"; + } } if (args.verbose) { std::cout << "\n" << perf_streamer->to_string() << "\n\n"; diff --git a/pyproject.toml b/pyproject.toml index 5dda0b29..fc442293 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,3 +39,8 @@ api = [ [project.urls] Homepage = "https://github.com/li-plus/chatglm.cpp" Repository = "https://github.com/li-plus/chatglm.cpp.git" + +# reference: https://black.readthedocs.io/en/stable/usage_and_configuration/the_basics.html#configuration-format +[tool.black] +line-length = 120 +include = '\.py$' diff --git a/tests/test_chatglm_cpp.py b/tests/test_chatglm_cpp.py index 29b2829c..ae1bb37a 100644 --- a/tests/test_chatglm_cpp.py +++ b/tests/test_chatglm_cpp.py @@ -20,58 +20,34 @@ def test_chatglm_version(): print(chatglm_cpp.__version__) -@pytest.mark.skipif(not CHATGLM_MODEL_PATH.exists(), reason="model file not found") -def test_chatglm_pipeline(): - history = ["你好"] - target = "你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。" +def check_pipeline(model_path, prompt, target, gen_kwargs={}): + messages = [chatglm_cpp.ChatMessage(role="user", content=prompt)] - pipeline = chatglm_cpp.Pipeline(CHATGLM_MODEL_PATH) - output = pipeline.chat(history, do_sample=False) + pipeline = chatglm_cpp.Pipeline(model_path) + output = pipeline.chat(messages, do_sample=False, **gen_kwargs).content assert output == target - stream_output = pipeline.stream_chat(history, do_sample=False) - stream_output = "".join(stream_output) + stream_output = pipeline.chat(messages, do_sample=False, stream=True, **gen_kwargs) + stream_output = "".join([msg.content for msg in stream_output]) + if model_path == CHATGLM3_MODEL_PATH: + # hack for ChatGLM3 + stream_output = stream_output.strip() assert stream_output == target - stream_output = pipeline.chat(history, do_sample=False, stream=True) - stream_output = "".join(stream_output) - assert stream_output == target + +@pytest.mark.skipif(not CHATGLM_MODEL_PATH.exists(), reason="model file not found") +def test_chatglm_pipeline(): + check_pipeline(model_path=CHATGLM_MODEL_PATH, prompt="你好", target="你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。") @pytest.mark.skipif(not CHATGLM2_MODEL_PATH.exists(), reason="model file not found") def test_chatglm2_pipeline(): - history = ["你好"] - target = "你好👋!我是人工智能助手 ChatGLM2-6B,很高兴见到你,欢迎问我任何问题。" - - pipeline = chatglm_cpp.Pipeline(CHATGLM2_MODEL_PATH) - output = pipeline.chat(history, do_sample=False) - assert output == target - - stream_output = pipeline.stream_chat(history, do_sample=False) - stream_output = "".join(stream_output) - assert stream_output == target - - stream_output = pipeline.chat(history, do_sample=False, stream=True) - stream_output = "".join(stream_output) - assert stream_output == target + check_pipeline(model_path=CHATGLM2_MODEL_PATH, prompt="你好", target="你好👋!我是人工智能助手 ChatGLM2-6B,很高兴见到你,欢迎问我任何问题。") @pytest.mark.skipif(not CHATGLM3_MODEL_PATH.exists(), reason="model file not found") def test_chatglm3_pipeline(): - 
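The new -s/--system, --pp/--sp and --sync options in main.cpp and cli_demo.py map directly onto the Python API: read the system message from a file, build a ChatMessage list, and call chat() without streaming for a single synchronous reply. A small illustrative equivalent (paths are placeholders, relative to the repository root):

```python
# One-shot, non-streaming chat mirroring `main -m ... --sp examples/system/default.txt --sync`.
# Paths are placeholders; the system prompt file is optional.
from pathlib import Path

import chatglm_cpp

pipeline = chatglm_cpp.Pipeline("chatglm3-ggml.bin")

messages = []
system_path = Path("examples/system/default.txt")
if system_path.exists():
    messages.append(chatglm_cpp.ChatMessage(role="system", content=system_path.read_text()))
messages.append(chatglm_cpp.ChatMessage(role="user", content="你好"))

# stream=False (the default) returns one ChatMessage, the analogue of --sync in main.cpp.
reply = pipeline.chat(messages, do_sample=False, stream=False)
print(reply.content)
```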
history = ["你好"] - target = "你好👋!我是人工智能助手 ChatGLM3-6B,很高兴见到你,欢迎问我任何问题。" - - pipeline = chatglm_cpp.Pipeline(CHATGLM3_MODEL_PATH) - output = pipeline.chat(history, do_sample=False) - assert output == target - - stream_output = pipeline.stream_chat(history, do_sample=False) - stream_output = "".join(stream_output) - assert stream_output == target - - stream_output = pipeline.chat(history, do_sample=False, stream=True) - stream_output = "".join(stream_output) - assert stream_output == target + check_pipeline(model_path=CHATGLM3_MODEL_PATH, prompt="你好", target="你好👋!我是人工智能助手 ChatGLM3-6B,很高兴见到你,欢迎问我任何问题。") @pytest.mark.skipif(not CODEGEEX2_MODEL_PATH.exists(), reason="model file not found") @@ -100,99 +76,39 @@ def bubble_sort(list): @pytest.mark.skipif(not BAICHUAN13B_MODEL_PATH.exists(), reason="model file not found") def test_baichuan13b_pipeline(): - history = ["你好呀"] - target = "你好!很高兴见到你。请问有什么我可以帮助你的吗?" - - gen_kwargs = dict(do_sample=False, repetition_penalty=1.1) - - pipeline = chatglm_cpp.Pipeline(BAICHUAN13B_MODEL_PATH) - output = pipeline.chat(history, **gen_kwargs) - assert output == target - - stream_output = pipeline.stream_chat(history, **gen_kwargs) - stream_output = "".join(stream_output) - assert stream_output == target - - stream_output = pipeline.chat(history, **gen_kwargs, stream=True) - stream_output = "".join(stream_output) - assert stream_output == target + check_pipeline( + model_path=BAICHUAN13B_MODEL_PATH, + prompt="你好呀", + target="你好!很高兴见到你。请问有什么我可以帮助你的吗?", + gen_kwargs=dict(repetition_penalty=1.1), + ) @pytest.mark.skipif(not BAICHUAN2_7B_MODEL_PATH.exists(), reason="model file not found") def test_baichuan2_7b_pipeline(): - history = ["你好呀"] - target = "你好!很高兴为你服务。请问有什么问题我可以帮助你解决?" - - gen_kwargs = dict(do_sample=False, repetition_penalty=1.05) - - pipeline = chatglm_cpp.Pipeline(BAICHUAN2_7B_MODEL_PATH) - output = pipeline.chat(history, **gen_kwargs) - assert output == target - - stream_output = pipeline.stream_chat(history, **gen_kwargs) - stream_output = "".join(stream_output) - assert stream_output == target - - stream_output = pipeline.chat(history, **gen_kwargs, stream=True) - stream_output = "".join(stream_output) - assert stream_output == target + check_pipeline( + model_path=BAICHUAN2_7B_MODEL_PATH, + prompt="你好呀", + target="你好!很高兴为你服务。请问有什么问题我可以帮助你解决?", + gen_kwargs=dict(repetition_penalty=1.05), + ) @pytest.mark.skipif(not BAICHUAN2_13B_MODEL_PATH.exists(), reason="model file not found") def test_baichuan2_13b_pipeline(): - history = ["你好呀"] - target = "你好!很高兴见到你。请问有什么我可以帮助你的吗?" - - gen_kwargs = dict(do_sample=False, repetition_penalty=1.05) - - pipeline = chatglm_cpp.Pipeline(BAICHUAN2_13B_MODEL_PATH) - output = pipeline.chat(history, **gen_kwargs) - assert output == target - - stream_output = pipeline.stream_chat(history, **gen_kwargs) - stream_output = "".join(stream_output) - assert stream_output == target - - stream_output = pipeline.chat(history, **gen_kwargs, stream=True) - stream_output = "".join(stream_output) - assert stream_output == target + check_pipeline( + model_path=BAICHUAN2_13B_MODEL_PATH, + prompt="你好呀", + target="你好!很高兴见到你。请问有什么我可以帮助你的吗?", + gen_kwargs=dict(repetition_penalty=1.05), + ) @pytest.mark.skipif(not INTERNLM7B_MODEL_PATH.exists(), reason="model file not found") def test_internlm7b_pipeline(): - history = ["你好"] - target = "你好,有什么我可以帮助你的吗?" 
- - gen_kwargs = dict(do_sample=False) - - pipeline = chatglm_cpp.Pipeline(INTERNLM7B_MODEL_PATH) - output = pipeline.chat(history, **gen_kwargs) - assert output == target - - stream_output = pipeline.stream_chat(history, **gen_kwargs) - stream_output = "".join(stream_output) - assert stream_output == target - - stream_output = pipeline.chat(history, **gen_kwargs, stream=True) - stream_output = "".join(stream_output) - assert stream_output == target + check_pipeline(model_path=INTERNLM7B_MODEL_PATH, prompt="你好", target="你好,有什么我可以帮助你的吗?") @pytest.mark.skipif(not INTERNLM20B_MODEL_PATH.exists(), reason="model file not found") def test_internlm20b_pipeline(): - history = ["你好"] - target = "你好!有什么我可以帮助你的吗?" - - gen_kwargs = dict(do_sample=False) - - pipeline = chatglm_cpp.Pipeline(INTERNLM20B_MODEL_PATH) - output = pipeline.chat(history, **gen_kwargs) - assert output == target - - stream_output = pipeline.stream_chat(history, **gen_kwargs) - stream_output = "".join(stream_output) - assert stream_output == target - - stream_output = pipeline.chat(history, **gen_kwargs, stream=True) - stream_output = "".join(stream_output) - assert stream_output == target + check_pipeline(model_path=INTERNLM20B_MODEL_PATH, prompt="你好", target="你好!有什么我可以帮助你的吗?")
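The refactored tests exercise both code paths of the new Python API through a single helper; the pattern they share, and the replacement for the deprecated stream_chat, looks roughly like this (placeholder model file):

```python
# Shared pattern behind the refactored tests: one ChatMessage history,
# a synchronous reply, and a streamed reply merged back into the history.
# "chatglm-ggml.bin" is a placeholder for whichever converted model is present.
import chatglm_cpp

pipeline = chatglm_cpp.Pipeline("chatglm-ggml.bin")
messages = [chatglm_cpp.ChatMessage(role="user", content="你好")]

# Synchronous path: chat() returns a ChatMessage.
reply = pipeline.chat(messages, do_sample=False)
print(reply.content)

# Streaming path: chat(..., stream=True) supersedes the deprecated stream_chat().
chunks = []
for chunk in pipeline.chat(messages, do_sample=False, stream=True):
    print(chunk.content, end="", flush=True)
    chunks.append(chunk)
print()

# Fold the streamed chunks back into a ChatMessage to extend the history.
messages.append(pipeline.merge_streaming_messages(chunks))
```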