diff --git a/3rdparty/tvm b/3rdparty/tvm
index c8f7ec8dc0..ce58d63453 160000
--- a/3rdparty/tvm
+++ b/3rdparty/tvm
@@ -1 +1 @@
-Subproject commit c8f7ec8dc0377ad362e1c81b194c6e2322f27a75
+Subproject commit ce58d63453ff83b930fa2be665647621b2eec4d2
diff --git a/cpp/json_ffi/conv_template.cc b/cpp/json_ffi/conv_template.cc
index e23258f0b8..a386e09921 100644
--- a/cpp/json_ffi/conv_template.cc
+++ b/cpp/json_ffi/conv_template.cc
@@ -131,7 +131,7 @@ ModelConfig ModelConfig::FromJSON(const picojson::object& json_obj) {

 /****************** Conversation template ******************/

-std::map<MessagePlaceholders, std::string> PLACEHOLDERS = {
+std::unordered_map<MessagePlaceholders, std::string> PLACEHOLDERS = {
     {MessagePlaceholders::SYSTEM, "{system_message}"},
     {MessagePlaceholders::USER, "{user_message}"},
     {MessagePlaceholders::ASSISTANT, "{assistant_message}"},
@@ -153,120 +153,213 @@ Conversation::Conversation()
           {"assistant", PLACEHOLDERS[MessagePlaceholders::ASSISTANT]},
           {"tool", PLACEHOLDERS[MessagePlaceholders::TOOL]}}) {}

-Result<std::vector<Data>> Conversation::AsPrompt(ModelConfig config, DLDevice device) {
-  using TResult = Result<std::vector<Data>>;
-  // Get the system message
-  std::string system_msg = system_template;
-  size_t pos = system_msg.find(PLACEHOLDERS[MessagePlaceholders::SYSTEM]);
+std::string Conversation::GetSystemText(const std::string& system_msg) const {
+  std::string system_text = this->system_template;
+  static std::string system_placeholder = PLACEHOLDERS[MessagePlaceholders::SYSTEM];
+  size_t pos = system_text.find(system_placeholder);
   if (pos != std::string::npos) {
-    system_msg.replace(pos, PLACEHOLDERS[MessagePlaceholders::SYSTEM].length(),
-                       this->system_message);
+    system_text.replace(pos, system_placeholder.length(), system_msg);
   }
+  return system_text;
+}

-  // Get the message strings
-  std::vector<Data> message_list;
-  std::vector<std::string> separators = seps;
-  if (separators.size() == 1) {
-    separators.push_back(separators[0]);
+std::string Conversation::GetRoleText(const std::string& role, const std::string& content,
+                                      const std::optional<std::string>& fn_call_string) const {
+  std::string role_text = this->role_templates.at(role);
+  std::string placeholder = PLACEHOLDERS[MessagePlaceholderFromString(role)];
+  size_t pos = role_text.find(placeholder);
+  if (pos != std::string::npos) {
+    role_text.replace(pos, placeholder.length(), content);
+  }
+  if (fn_call_string) {
+    // replace placeholder[FUNCTION] with the function string.
+    // This assumes function calling is used in a single-request scenario only.
+    pos = role_text.find(PLACEHOLDERS[MessagePlaceholders::FUNCTION]);
+    if (pos != std::string::npos) {
+      role_text.replace(pos, PLACEHOLDERS[MessagePlaceholders::FUNCTION].length(),
+                        fn_call_string.value());
+    }
   }
+  return role_text;
+}

-  if (!system_msg.empty()) {
-    system_msg += separators[0];
-    message_list.push_back(TextData(system_message));
+/// Try to detect if function calling is needed; if so, return the function calling string
+Result<std::optional<std::string>> TryGetFunctionCallingString(
+    const Conversation& conv, const ChatCompletionRequest& request) {
+  using TResult = Result<std::optional<std::string>>;
+  if (!request.tools.has_value() ||
+      (request.tool_choice.has_value() && request.tool_choice.value() == "none")) {
+    return TResult::Ok(std::nullopt);
+  }
+  std::vector<ChatTool> tools_ = request.tools.value();
+  std::string tool_choice_ = request.tool_choice.value();
+
+  // TODO: support tool choice given as a dict
+  for (const auto& tool : tools_) {
+    if (tool.function.name == tool_choice_) {
+      picojson::value function_str(tool.function.AsJSON());
+      return TResult::Ok(function_str.serialize());
+    }
   }
-  for (int i = 0; i < messages.size(); i++) {
-    std::string role = messages[i].role;
-    // Todo(mlc-team): support content to be a single string.
-    std::optional<std::vector<std::unordered_map<std::string, std::string>>> content =
-        messages[i].content;
-    if (roles.find(role) == roles.end()) {
-      return TResult::Error("Role \"" + role + "\" is not supported");
-    }
+  if (tool_choice_ != "auto") {
+    return TResult::Error("Invalid tool_choice value in the request: " + tool_choice_);
+  }
+
+  picojson::array function_list;
+  for (const auto& tool : tools_) {
+    function_list.push_back(picojson::value(tool.function.AsJSON()));
+  }

-    std::string separator = separators[role == "assistant"];  // check assistant role
+  picojson::value function_list_json(function_list);
+  return TResult::Ok(function_list_json.serialize());
+}

-    // If content is empty, add the role and separator
-    // assistant's turn to generate text
-    if (!content.has_value()) {
-      message_list.push_back(TextData(roles[role] + role_empty_sep));
-      continue;
-    }
+Result<std::vector<Data>> CreatePrompt(const Conversation& conv,
+                                       const ChatCompletionRequest& request,
+                                       const ModelConfig& config, DLDevice device) {
+  using TResult = Result<std::vector<Data>>;
+
+  Result<std::optional<std::string>> fn_call_str_tmp = TryGetFunctionCallingString(conv, request);
+  if (fn_call_str_tmp.IsErr()) {
+    return TResult::Error(fn_call_str_tmp.UnwrapErr());
+  }
+  std::optional<std::string> fn_call_string = fn_call_str_tmp.Unwrap();
+
+  // Handle the system message.
+  // Concatenate all custom system messages into one system prompt.
+  bool has_custom_system = false;
+  std::string custom_system_inputs;

-    std::string message = "";
-    std::string role_prefix = "";
-    // Do not append role prefix if this is the first message and there
-    // is already a system message
-    if (add_role_after_system_message || system_msg.empty() || i != 0) {
-      role_prefix = roles[role] + role_content_sep;
+  auto f_populate_system_message = [&](const std::vector<ChatCompletionMessage>& msg_vec) {
+    for (ChatCompletionMessage msg : msg_vec) {
+      if (msg.role == "system") {
+        ICHECK(msg.content.IsText()) << "System message must be text";
+        custom_system_inputs += msg.content.Text();
+        has_custom_system = true;
+      }
     }
+  };
+  // Go through both the messages stored in the template and those passed in.
+  f_populate_system_message(conv.messages);
+  f_populate_system_message(request.messages);

-    message += role_prefix;
+  // pending_text records the text to be put into data.
+  // We lazily accumulate the pending text
+  // to reduce the number of segments in the Data vector.
+  std::string pending_text =
+      conv.GetSystemText(has_custom_system ? custom_system_inputs : conv.system_message);

-    for (const auto& item : content.value()) {
-      auto it_type = item.find("type");
-      if (it_type == item.end()) {
-        return TResult::Error("The content of a message does not have \"type\" field");
+  // the separator after the system message.
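+  // Editorial note: an illustrative walk-through, not part of the change itself.
+  // Assuming a template with system_template = "<<SYS>>{system_message}<</SYS>>"
+  // and a request carrying the system message "You are a helpful assistant",
+  // pending_text starts out as "<<SYS>>You are a helpful assistant<</SYS>>",
+  // and conv.seps[0] is appended below before any user/assistant turns follow.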
+  if (!pending_text.empty()) {
+    pending_text += conv.seps[0];
+  }
+
+  // Get the message strings
+  std::vector<Data> message_list;
+  size_t non_system_msg_count = 0;
+
+  // Returns an error if one occurs; std::nullopt on success.
+  auto f_process_messages =
+      [&](const std::vector<ChatCompletionMessage>& msg_vec) -> std::optional<TResult> {
+    for (size_t i = 0; i < msg_vec.size(); ++i) {
+      const ChatCompletionMessage& msg = msg_vec[i];
+      auto role_it = conv.roles.find(msg.role);
+      if (role_it == conv.roles.end()) {
+        return TResult::Error("Role \"" + msg.role + "\" is not supported");
       }
-      if (it_type->second == "text") {
-        auto it_text = item.find("text");
-        if (it_text == item.end()) {
-          return TResult::Error("The text type content of a message does not have \"text\" field");
-        }
-        // replace placeholder[ROLE] with input message from role
-        std::string role_text = role_templates[role];
-        std::string placeholder = PLACEHOLDERS[MessagePlaceholderFromString(role)];
-        size_t pos = role_text.find(placeholder);
-        if (pos != std::string::npos) {
-          role_text.replace(pos, placeholder.length(), it_text->second);
-        }
-        if (use_function_calling) {
-          // replace placeholder[FUNCTION] with function_string
-          // this assumes function calling is used for a single request scenario only
-          if (!function_string.has_value()) {
-            return TResult::Error(
-                "The function string in conversation template is not defined for function "
-                "calling.");
+      const std::string& role_name = role_it->second;
+      // Skip system messages; they were already folded into pending_text above.
+      if (msg.role == "system") continue;
+      // When content is empty, emit only the role header:
+      // this is the assistant's turn to generate text.
+      if (msg.content.IsNull()) {
+        pending_text += role_name + conv.role_empty_sep;
+        continue;
+      }
+      ++non_system_msg_count;
+      // assistant uses conv.seps[1] if there are two seps
+      int sep_offset = msg.role == "assistant" ? 1 : 0;
+      const std::string& separator = conv.seps[sep_offset % conv.seps.size()];
+      // set up the role prefix
+      std::string role_prefix = "";
+      // Do not append the role prefix if this is the first message and there is already a system
+      // message
+      if (conv.add_role_after_system_message || pending_text.empty() || non_system_msg_count != 1) {
+        role_prefix = role_name + conv.role_content_sep;
+      }
+      pending_text += role_prefix;
+
+      if (msg.content.IsParts()) {
+        for (const auto& item : msg.content.Parts()) {
+          auto it_type = item.find("type");
+          if (it_type == item.end()) {
+            return TResult::Error("The content of a message does not have \"type\" field");
           }
-          pos = role_text.find(PLACEHOLDERS[MessagePlaceholders::FUNCTION]);
-          if (pos != std::string::npos) {
-            role_text.replace(pos, PLACEHOLDERS[MessagePlaceholders::FUNCTION].length(),
-                              function_string.value());
+          if (it_type->second == "text") {
+            auto it_text = item.find("text");
+            if (it_text == item.end()) {
+              return TResult::Error(
+                  "The text type content of a message does not have \"text\" field");
+            }
+            // replace placeholder[ROLE] with the input message from the role
+            pending_text += conv.GetRoleText(msg.role, it_text->second, fn_call_string);
+          } else if (it_type->second == "image_url") {
+            if (item.find("image_url") == item.end()) {
+              return TResult::Error("Content should have an image_url field");
+            }
+            std::string image_url =
+                item.at("image_url");  // TODO(mlc-team): According to OpenAI API reference this
+                                       // should be a map, with a "url" key containing the URL, but
+                                       // we are just assuming this as the URL for now
+            std::string base64_image = image_url.substr(image_url.find(",") + 1);
+            Result<NDArray> image_data_res = LoadImageFromBase64(base64_image);
+            if (image_data_res.IsErr()) {
+              return TResult::Error(image_data_res.UnwrapErr());
+            }
+            if (!config.vision_config.has_value()) {
+              return TResult::Error("Vision config is required for image input");
+            }
+            int image_size = config.vision_config.value().image_size;
+            int patch_size = config.vision_config.value().patch_size;
+
+            int embed_size = (image_size * image_size) / (patch_size * patch_size);
+
+            auto image_ndarray = ClipPreprocessor(image_data_res.Unwrap(), image_size, device);
+            // lazily commit the accumulated text before emitting image data
+            if (pending_text.length() != 0) {
+              message_list.push_back(TextData(pending_text));
+              pending_text = "";
+            }
+            message_list.push_back(ImageData(image_ndarray, embed_size));
+          } else {
+            return TResult::Error("Unsupported content type: " + it_type->second);
           }
         }
-        message += role_text;
-      } else if (it_type->second == "image_url") {
-        if (item.find("image_url") == item.end()) {
-          return TResult::Error("Content should have an image_url field");
-        }
-        std::string image_url =
-            item.at("image_url");  // TODO(mlc-team): According to OpenAI API reference this
-                                   // should be a map, with a "url" key containing the URL, but
-                                   // we are just assuming this as the URL for now
-        std::string base64_image = image_url.substr(image_url.find(",") + 1);
-        Result<NDArray> image_data_res = LoadImageFromBase64(base64_image);
-        if (image_data_res.IsErr()) {
-          return TResult::Error(image_data_res.UnwrapErr());
-        }
-        if (!config.vision_config.has_value()) {
-          return TResult::Error("Vision config is required for image input");
-        }
-        int image_size = config.vision_config.value().image_size;
-        int patch_size = config.vision_config.value().patch_size;
-
-        int embed_size = (image_size * image_size) / (patch_size * patch_size);
-
-        auto image_ndarray = ClipPreprocessor(image_data_res.Unwrap(), image_size, device);
-        message_list.push_back(ImageData(image_ndarray, embed_size));
       } else {
-        return TResult::Error("Unsupported content type: " + it_type->second);
+        ICHECK(msg.content.IsText());
+        pending_text += conv.GetRoleText(msg.role, msg.content.Text(), fn_call_string);
       }
+      pending_text += separator;
     }
+    return std::nullopt;
+  };

-    message += separator;
-    message_list.push_back(TextData(message));
+  if (auto err = f_process_messages(conv.messages)) {
+    return err.value();
+  }
+  if (auto err = f_process_messages(request.messages)) {
+    return err.value();
+  }
+  // Append the final assistant begin message so generation starts from the assistant role.
+  ChatCompletionMessage last_assistant_begin;
+  last_assistant_begin.role = "assistant";
+  last_assistant_begin.content = std::nullopt;
+  if (auto err = f_process_messages({last_assistant_begin})) {
+    return err.value();
+  }
+  if (pending_text.length() != 0) {
+    message_list.push_back(TextData(pending_text));
   }
-
   return TResult::Ok(message_list);
 }

@@ -383,7 +476,10 @@ Result<Conversation> Conversation::FromJSON(const picojson::object& json_obj) {
       content.push_back(std::move(item_map));
     }
   }
-    conv.messages.push_back({role_res.Unwrap(), content});
+    ChatCompletionMessage msg;
+    msg.role = role_res.Unwrap();
+    msg.content = content;
+    conv.messages.push_back(msg);
   }

   Result<picojson::array> seps_arr_res =
@@ -438,21 +534,6 @@ Result<Conversation> Conversation::FromJSON(const picojson::object& json_obj) {
     }
     conv.stop_token_ids.push_back(stop.get<int64_t>());
   }
-
-  Result<std::optional<std::string>> function_string_res =
-      json::LookupOptionalWithResultReturn<std::string>(json_obj, "function_string");
-  if (function_string_res.IsErr()) {
-    return TResult::Error(function_string_res.UnwrapErr());
-  }
-  conv.function_string = function_string_res.Unwrap();
-
-  Result<bool> use_function_calling_res = json::LookupOrDefaultWithResultReturn(
-      json_obj, "use_function_calling", conv.use_function_calling);
-  if (use_function_calling_res.IsErr()) {
-    return TResult::Error(use_function_calling_res.UnwrapErr());
-  }
-  conv.use_function_calling = use_function_calling_res.Unwrap();
-
   return TResult::Ok(conv);
 }
diff --git a/cpp/json_ffi/conv_template.h b/cpp/json_ffi/conv_template.h
index 8217c5d6e5..e6c8e784f7 100644
--- a/cpp/json_ffi/conv_template.h
+++ b/cpp/json_ffi/conv_template.h
@@ -11,6 +11,7 @@

 #include "../serve/data.h"
 #include "../support/result.h"
+#include "openai_api_protocol.h"
 #include "picojson.h"

 using namespace mlc::llm::serve;
@@ -62,12 +63,6 @@ enum class MessagePlaceholders { SYSTEM, USER, ASSISTANT, TOOL, FUNCTION };

 MessagePlaceholders MessagePlaceholderFromString(const std::string& role);

-class Message {
- public:
-  std::string role;
-  std::optional<std::vector<std::unordered_map<std::string, std::string>>> content = std::nullopt;
-};
-
 /**
  * @brief A struct that specifies the convention template of conversation
  * and contains the conversation history.
@@ -102,7 +97,7 @@ struct Conversation {
   // The conversation history messages.
   // Each message is a pair of strings, denoting "(role, content)".
   // The content can be None.
-  std::vector<Message> messages;
+  std::vector<ChatCompletionMessage> messages;

   // The separators between messages when concatenating into a single prompt.
   // List size should be either 1 or 2.
@@ -121,15 +116,24 @@ struct Conversation {
   std::vector<std::string> stop_str;
   std::vector<int32_t> stop_token_ids;

-  // Function call fields
-  // whether using function calling or not, helps check for output message format in API call
-  std::optional<std::string> function_string = std::nullopt;
-  bool use_function_calling = false;
-
   Conversation();

-  /*! \brief Create the list of prompts from the messages based on the conversation template. */
-  Result<std::vector<Data>> AsPrompt(ModelConfig config, DLDevice device);
+  /*!
+   * \brief Get the system text (with the prompt template) given the system prompt message.
+   * \param system_msg The system prompt message.
+   * \return The created system text.
+   */
+  std::string GetSystemText(const std::string& system_msg) const;
+
+  /*!
+   * \brief Replace the content from the role with the correct role text in the template.
+   * \param role The input role.
+   * \param content The input content from the role.
+   * \param fn_call_str The function calling string, if any.
+   * \return The created text.
+   */
+  std::string GetRoleText(const std::string& role, const std::string& content,
+                          const std::optional<std::string>& fn_call_str) const;

   /*! \brief Create a Conversation instance from the given JSON object. */
   static Result<Conversation> FromJSON(const picojson::object& json);
@@ -137,6 +141,11 @@ struct Conversation {
   static Result<Conversation> FromJSON(const std::string& json_str);
 };

+/*! \brief Create the list of prompts from the messages based on the conversation template. */
+Result<std::vector<Data>> CreatePrompt(const Conversation& conv,
+                                       const ChatCompletionRequest& request,
+                                       const ModelConfig& config, DLDevice device);
+
 }  // namespace json_ffi
 }  // namespace llm
 }  // namespace mlc
diff --git a/cpp/json_ffi/json_ffi_engine.cc b/cpp/json_ffi/json_ffi_engine.cc
index 65f3183424..98d00061a8 100644
--- a/cpp/json_ffi/json_ffi_engine.cc
+++ b/cpp/json_ffi/json_ffi_engine.cc
@@ -31,7 +31,7 @@ void JSONFFIEngine::StreamBackError(std::string request_id) {
   ChatCompletionMessage delta;
   delta.content = std::vector<std::unordered_map<std::string, std::string>>{
       {{"type", "text"}, {"text", this->err_}}};
-  delta.role = Role::assistant;
+  delta.role = "assistant";

   ChatCompletionStreamResponseChoice choice;
   choice.finish_reason = FinishReason::error;
@@ -44,7 +44,9 @@ void JSONFFIEngine::StreamBackError(std::string request_id) {
   response.model = "json_ffi";  // TODO: Return model name from engine (or from args)
   response.system_fingerprint = "";

-  this->request_stream_callback_(Array<String>{picojson::value(response.AsJSON()).serialize()});
+  picojson::array response_arr;
+  response_arr.push_back(picojson::value(response.AsJSON()));
+  this->request_stream_callback_(picojson::value(response_arr).serialize());
 }

 bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request_id) {
@@ -54,38 +56,9 @@ bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request
     return false;
   }
   ChatCompletionRequest request = request_res.Unwrap();
-  // Create Request
-  // TODO: Check if request_id is present already
-
-  // inputs
-  Conversation conv_template = this->conv_template_;
-  std::vector<Message> messages;
-  for (const auto& message : request.messages) {
-    std::string role;
-    if (message.role == Role::user) {
-      role = "user";
-    } else if (message.role == Role::assistant) {
-      role = "assistant";
-    } else if (message.role == Role::tool) {
-      role = "tool";
-    } else {
-      role = "system";
-    }
-    messages.push_back({role, message.content});
-  }
-  messages.push_back({"assistant", std::nullopt});
-  conv_template.messages = messages;
-
-  // check function calling
-  Result<Conversation> updated_conv_template = request.CheckFunctionCalling(conv_template);
-  if (updated_conv_template.IsErr()) {
-    err_ = updated_conv_template.UnwrapErr();
-    return false;
-  }
-  conv_template = updated_conv_template.Unwrap();
-
-  // get prompt
-  Result<std::vector<Data>> inputs_obj = conv_template.AsPrompt(this->model_config_, this->device_);
+  // Get the prompt. Note: an empty assistant turn is appended at the end.
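+  // Editorial note: CreatePrompt consumes both the template's stored messages and
+  // this request's messages, then finishes with an empty assistant turn (the role
+  // name plus role_empty_sep), so generation resumes from the assistant role.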
+  Result<std::vector<Data>> inputs_obj =
+      CreatePrompt(this->conv_template_, request, this->model_config_, this->device_);
   if (inputs_obj.IsErr()) {
     err_ = inputs_obj.UnwrapErr();
     return false;
@@ -94,8 +67,8 @@ bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request

   // generation_cfg
   Array<String> stop_strs;
-  stop_strs.reserve(conv_template.stop_str.size());
-  for (const std::string& stop_str : conv_template.stop_str) {
+  stop_strs.reserve(this->conv_template_.stop_str.size());
+  for (const std::string& stop_str : this->conv_template_.stop_str) {
     stop_strs.push_back(stop_str);
   }
   if (request.stop.has_value()) {
@@ -110,7 +83,7 @@ bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request
       /*repetition_penalty=*/std::nullopt, request.logprobs, request.top_logprobs,
       request.logit_bias, request.seed, request.ignore_eos, request.max_tokens,
       std::move(stop_strs),
-      conv_template.stop_token_ids, /*response_format=*/std::nullopt,
+      conv_template_.stop_token_ids, /*response_format=*/std::nullopt,
       this->default_generation_cfg_json_str_);

   Request engine_request(request_id, inputs, generation_cfg);
@@ -146,8 +119,9 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode {
   TVM_MODULE_VTABLE_ENTRY("exit_background_loop", &JSONFFIEngineImpl::ExitBackgroundLoop);
   TVM_MODULE_VTABLE_END();

-  void InitBackgroundEngine(Device device, Optional<PackedFunc> request_stream_callback,
-                            Optional<EventTraceRecorder> trace_recorder) {
+  void InitBackgroundEngine(int device_type, int device_id,
+                            Optional<PackedFunc> request_stream_callback) {
+    DLDevice device{static_cast<DLDeviceType>(device_type), device_id};
     this->device_ = device;
     CHECK(request_stream_callback.defined())
         << "JSONFFIEngine requires request stream callback function, but it is not given.";
@@ -156,13 +130,12 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode {
     auto frequest_stream_callback_wrapper = [this](TVMArgs args, TVMRetValue* ret) {
       ICHECK_EQ(args.size(), 1);
       Array<RequestStreamOutput> delta_outputs = args[0];
-      Array<String> responses = this->GetResponseFromStreamOutput(delta_outputs);
+      String responses = this->GetResponseFromStreamOutput(delta_outputs);
       this->request_stream_callback_(responses);
     };

     request_stream_callback = PackedFunc(frequest_stream_callback_wrapper);
-    this->engine_->InitThreadedEngine(device, std::move(request_stream_callback),
-                                      std::move(trace_recorder));
+    this->engine_->InitThreadedEngine(device, std::move(request_stream_callback), NullOpt);
   }

   void Reload(String engine_config_json_str) {
@@ -198,7 +171,7 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode {

   void RunBackgroundStreamBackLoop() { this->engine_->RunBackgroundStreamBackLoop(); }

-  Array<String> GetResponseFromStreamOutput(Array<RequestStreamOutput> delta_outputs) {
+  String GetResponseFromStreamOutput(Array<RequestStreamOutput> delta_outputs) {
     std::unordered_map<std::string, std::vector<ChatCompletionStreamResponseChoice>> response_map;
     for (const auto& delta_output : delta_outputs) {
      std::string request_id = delta_output->request_id;
@@ -232,27 +205,24 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode {
       // Size of delta_output->group_delta_token_ids Array should be 1
       IntTuple delta_token_ids = delta_output->group_delta_token_ids[0];
       std::vector<int32_t> delta_token_ids_vec(delta_token_ids.begin(), delta_token_ids.end());
-      delta.content = std::vector<std::unordered_map<std::string, std::string>>();
-      delta.content.value().push_back(std::unordered_map<std::string, std::string>{
-          {"type", "text"}, {"text", this->streamer_->Put(delta_token_ids_vec)}});
-
-      delta.role = Role::assistant;
+      delta.content = this->streamer_->Put(delta_token_ids_vec);
+      delta.role = "assistant";

       choice.delta = delta;
       response_map[request_id].push_back(choice);
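+      // Editorial note: choices that share a request id accumulate here and are
+      // merged into one ChatCompletionStreamResponse per request below.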
     }

-    Array<String> response_arr;
+    picojson::array response_arr;
     for (const auto& [request_id, choices] : response_map) {
       ChatCompletionStreamResponse response;
       response.id = request_id;
       response.choices = choices;
       response.model = "json_ffi";  // TODO: Return model name from engine (or from args)
       response.system_fingerprint = "";
-      response_arr.push_back(picojson::value(response.AsJSON()).serialize());
+      response_arr.push_back(picojson::value(response.AsJSON()));
     }
-    return response_arr;
+    return picojson::value(response_arr).serialize();
   }
 };
diff --git a/cpp/json_ffi/openai_api_protocol.cc b/cpp/json_ffi/openai_api_protocol.cc
index c07de8fef5..525366440a 100644
--- a/cpp/json_ffi/openai_api_protocol.cc
+++ b/cpp/json_ffi/openai_api_protocol.cc
@@ -170,25 +170,37 @@ picojson::object ChatToolCall::AsJSON() const {

 Result<ChatCompletionMessage> ChatCompletionMessage::FromJSON(const picojson::object& json_obj) {
   using TResult = Result<ChatCompletionMessage>;
   ChatCompletionMessage message;
+  ChatCompletionMessageContent content;

   // content
-  Result<picojson::array> content_arr_res =
-      json::LookupWithResultReturn<picojson::array>(json_obj, "content");
-  if (content_arr_res.IsErr()) {
-    return TResult::Error(content_arr_res.UnwrapErr());
-  }
-  std::vector<std::unordered_map<std::string, std::string>> content;
-  for (const auto& item : content_arr_res.Unwrap()) {
-    // Todo(mlc-team): allow content item to be a single string.
-    if (!item.is<picojson::object>()) {
-      return TResult::Error("The content of chat completion message is not an object");
+  auto it = json_obj.find("content");
+  if (it == json_obj.end()) {
+    return TResult::Error("ValueError: key \"content\" not found in the chat completion.");
+  }
+  if (it->second.is<std::string>()) {
+    content = it->second.get<std::string>();
+  } else if (it->second.is<picojson::null>()) {
+    // skip: the content stays null
+  } else {
+    // most complicated case: the content is an array of typed parts
+    std::vector<std::unordered_map<std::string, std::string>> parts;
+    Result<picojson::array> content_arr_res =
+        json::LookupWithResultReturn<picojson::array>(json_obj, "content");
+    if (content_arr_res.IsErr()) {
+      return TResult::Error(content_arr_res.UnwrapErr());
     }
-    picojson::object item_obj = item.get<picojson::object>();
-    std::unordered_map<std::string, std::string> item_map;
-    for (const auto& [key, value] : item_obj) {
-      item_map[key] = value.to_str();
+    for (const auto& item : content_arr_res.Unwrap()) {
+      if (!item.is<picojson::object>()) {
+        return TResult::Error("The content of chat completion message is not an object");
+      }
+      picojson::object item_obj = item.get<picojson::object>();
+      std::unordered_map<std::string, std::string> item_map;
+      for (const auto& [key, value] : item_obj) {
+        item_map[key] = value.to_str();
+      }
+      parts.push_back(std::move(item_map));
     }
-    content.push_back(std::move(item_map));
+    content = parts;
   }
   message.content = content;
@@ -198,14 +210,8 @@ Result<ChatCompletionMessage> ChatCompletionMessage::FromJSON(const picojson::ob
     return TResult::Error(role_str_res.UnwrapErr());
   }
   std::string role_str = role_str_res.Unwrap();
-  if (role_str == "system") {
-    message.role = Role::system;
-  } else if (role_str == "user") {
-    message.role = Role::user;
-  } else if (role_str == "assistant") {
-    message.role = Role::assistant;
-  } else if (role_str == "tool") {
-    message.role = Role::tool;
+  if (role_str == "system" || role_str == "user" || role_str == "assistant" ||
+      role_str == "tool") {
+    message.role = role_str;
   } else {
     return TResult::Error("Invalid role in chat completion message: " + role_str);
   }
@@ -282,7 +288,8 @@ Result<ChatCompletionRequest> ChatCompletionRequest::FromJSON(const std::string&
   request.messages = messages;

   // model
-  Result<std::string> model_res = json::LookupWithResultReturn<std::string>(json_obj, "model");
+  Result<std::optional<std::string>> model_res =
+      json::LookupOptionalWithResultReturn<std::string>(json_obj, "model");
   if (model_res.IsErr()) {
     return TResult::Error(model_res.UnwrapErr());
   }
@@ -344,30 +351,28 @@ Result<ChatCompletionRequest> ChatCompletionRequest::FromJSON(const std::string&
   }

   // TODO: Other parameters
-
   return TResult::Ok(request);
 }

 picojson::object ChatCompletionMessage::AsJSON() const {
   picojson::object obj;
-  picojson::array content_arr;
-  for (const auto& item : this->content.value()) {
-    picojson::object item_obj;
-    for (const auto& pair : item) {
-      item_obj[pair.first] = picojson::value(pair.second);
+
+  if (this->content.IsText()) {
+    obj["content"] = picojson::value(this->content.Text());
+  } else if (this->content.IsParts()) {
+    picojson::array content_arr;
+    for (const auto& item : this->content.Parts()) {
+      picojson::object item_obj;
+      for (const auto& pair : item) {
+        item_obj[pair.first] = picojson::value(pair.second);
+      }
+      content_arr.push_back(picojson::value(item_obj));
     }
-    content_arr.push_back(picojson::value(item_obj));
-  }
-  obj["content"] = picojson::value(content_arr);
-  if (this->role == Role::system) {
-    obj["role"] = picojson::value("system");
-  } else if (this->role == Role::user) {
-    obj["role"] = picojson::value("user");
-  } else if (this->role == Role::assistant) {
-    obj["role"] = picojson::value("assistant");
-  } else if (this->role == Role::tool) {
-    obj["role"] = picojson::value("tool");
+    obj["content"] = picojson::value(content_arr);
   }
+
+  obj["role"] = picojson::value(this->role);
+
   if (this->name.has_value()) {
     obj["name"] = picojson::value(this->name.value());
   }
@@ -384,40 +389,6 @@ picojson::object ChatCompletionMessage::AsJSON() const {
   return obj;
 }

-Result<Conversation> ChatCompletionRequest::CheckFunctionCalling(Conversation conv_template) {
-  using TResult = Result<Conversation>;
-  if (!tools.has_value() || (tool_choice.has_value() && tool_choice.value() == "none")) {
-    conv_template.use_function_calling = false;
-    return TResult::Ok(conv_template);
-  }
-  std::vector<ChatTool> tools_ = tools.value();
-  std::string tool_choice_ = tool_choice.value();
-
-  // TODO: support with tool choice as dict
-  for (const auto& tool : tools_) {
-    if (tool.function.name == tool_choice_) {
-      conv_template.use_function_calling = true;
-      picojson::value function_str(tool.function.AsJSON());
-      conv_template.function_string = function_str.serialize();
-      return TResult::Ok(conv_template);
-    }
-  }
-
-  if (tool_choice_ != "auto") {
-    return TResult::Error("Invalid tool_choice value in the request: " + tool_choice_);
-  }
-
-  picojson::array function_list;
-  for (const auto& tool : tools_) {
-    function_list.push_back(picojson::value(tool.function.AsJSON()));
-  }
-
-  conv_template.use_function_calling = true;
-  picojson::value function_list_json(function_list);
-  conv_template.function_string = function_list_json.serialize();
-  return TResult::Ok(conv_template);
-};
-
 picojson::object ChatCompletionResponseChoice::AsJSON() const {
   picojson::object obj;
   if (!this->finish_reason.has_value()) {
diff --git a/cpp/json_ffi/openai_api_protocol.h b/cpp/json_ffi/openai_api_protocol.h
index 914366c2f1..50f7315778 100644
--- a/cpp/json_ffi/openai_api_protocol.h
+++ b/cpp/json_ffi/openai_api_protocol.h
@@ -14,14 +14,12 @@
 #include <vector>

 #include "../support/result.h"
-#include "conv_template.h"
 #include "picojson.h"

 namespace mlc {
 namespace llm {
 namespace json_ffi {

-enum class Role { system, user, assistant, tool };
 enum class Type { text, json_object, function };
 enum class FinishReason { stop, length, tool_calls, error };

@@ -80,11 +78,41 @@ class ChatToolCall {
   picojson::object AsJSON() const;
 };

+class ChatCompletionMessageContent {
+ public:
+  ChatCompletionMessageContent() = default;
+
+  ChatCompletionMessageContent(std::nullopt_t) {}  // NOLINT(*)
+
+  ChatCompletionMessageContent(std::string text) : text_(text) {}  // NOLINT(*)
+
+  ChatCompletionMessageContent(
+      std::vector<std::unordered_map<std::string, std::string>> parts)  // NOLINT(*)
+      : parts_(parts) {}
+
+  bool IsNull() const { return !IsText() && !IsParts(); }
+
+  bool IsText() const { return text_.operator bool(); }
+
+  bool IsParts() const { return parts_.operator bool(); }
+
+  const std::string& Text() const { return text_.value(); }
+
+  const std::vector<std::unordered_map<std::string, std::string>>& Parts() const {
+    return parts_.value();
+  }
+
+ private:
+  /*! \brief used to store text content */
+  std::optional<std::string> text_;
+  /*! \brief used to store structured content parts */
+  std::optional<std::vector<std::unordered_map<std::string, std::string>>> parts_;
+};
+
 class ChatCompletionMessage {
  public:
-  std::optional<std::vector<std::unordered_map<std::string, std::string>>> content =
+  ChatCompletionMessageContent content =
       std::nullopt;  // Assuming content is a list of string key-value pairs
-  Role role;
+  std::string role;
   std::optional<std::string> name = std::nullopt;
   std::optional<std::vector<ChatToolCall>> tool_calls = std::nullopt;
   std::optional<std::string> tool_call_id = std::nullopt;
@@ -102,7 +130,7 @@ class RequestResponseFormat {
 class ChatCompletionRequest {
  public:
   std::vector<ChatCompletionMessage> messages;
-  std::string model;
+  std::optional<std::string> model = std::nullopt;
   std::optional<double> frequency_penalty = std::nullopt;
   std::optional<double> presence_penalty = std::nullopt;
   bool logprobs = false;
@@ -124,7 +152,6 @@ class ChatCompletionRequest {
   /*! \brief Parse and create a ChatCompletionRequest instance from the given JSON string. */
   static Result<ChatCompletionRequest> FromJSON(const std::string& json_str);

-  Result<Conversation> CheckFunctionCalling(Conversation conv_template);
   // TODO: check_penalty_range, check_logit_bias, check_logprobs
 };
diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc
index 93de185eb2..a8d2edc11a 100644
--- a/cpp/llm_chat.cc
+++ b/cpp/llm_chat.cc
@@ -302,7 +302,7 @@ struct FunctionTable {
     Device null_device{DLDeviceType(0), 0};
     if (this->use_disco) {
       DRef empty_func = sess->GetGlobalFunc("runtime.disco.empty");
-      return sess->CallPacked(empty_func, shape, dtype, null_device);
+      return sess->CallPacked(empty_func, shape, dtype, null_device, false);
     } else {
       return NDArray::Empty(shape, dtype, device);
     }
diff --git a/cpp/metadata/model.cc b/cpp/metadata/model.cc
index 62ba2787b9..e3e9a79b3c 100644
--- a/cpp/metadata/model.cc
+++ b/cpp/metadata/model.cc
@@ -63,8 +63,17 @@ ModelMetadata ModelMetadata::FromJSON(const picojson::object& metadata,
   if (metadata.count("attention_sink_size"))  // remove after sink is decoupled from model lib
     result.attention_sink_size = json::Lookup<int64_t>(metadata, "attention_sink_size");
   result.tensor_parallel_shards = json::Lookup<int64_t>(metadata, "tensor_parallel_shards");
-  result.kv_cache_metadata =
-      KVCacheMetadata::FromJSON(json::Lookup<picojson::object>(metadata, "kv_cache"));
+  result.kv_state_kind = KVStateKindFromString(
+      json::LookupOrDefault<std::string>(metadata, "kv_state_kind", "kv_cache"));
+  if (result.kv_state_kind != KVStateKind::kNone) {
+    result.kv_cache_metadata =
+        KVCacheMetadata::FromJSON(json::Lookup<picojson::object>(metadata, "kv_cache"));
+  } else {
+    result.kv_cache_metadata = {/*num_hidden_layers=*/0,
+                                /*head_dim=*/0,
+                                /*num_attention_heads=*/0,
+                                /*num_key_value_heads=*/0};
+  }
   {
     std::vector<ModelMetadata::Param>& params = result.params;
     picojson::array json_params = json::Lookup<picojson::array>(metadata, "params");
@@ -94,7 +103,7 @@ ModelMetadata ModelMetadata::FromModule(tvm::runtime::Module module,
   try {
     return ModelMetadata::FromJSON(json, model_config);
   } catch (const std::exception& e) {
-    LOG(WARNING) << "Failed to parse metadata:\n" << json_str;
+    LOG(WARNING) << "Failed to parse metadata:\n" << json_str << "\nerror: " << e.what();
     throw e;
   }
 }
diff --git a/cpp/metadata/model.h b/cpp/metadata/model.h
index ede06b6b3f..4b204f6902 100644
--- a/cpp/metadata/model.h
+++ b/cpp/metadata/model.h
@@ -16,6 +16,36 @@
 namespace mlc {
 namespace llm {

+/*! \brief The kind of cache. */
+enum class KVStateKind : int {
+  kKVCache = 0,
+  kRNNState = 1,
+  kNone = 2,
+};
+
+inline std::string KVStateKindToString(KVStateKind kv_state_kind) {
+  if (kv_state_kind == KVStateKind::kKVCache) {
+    return "kv_cache";
+  } else if (kv_state_kind == KVStateKind::kRNNState) {
+    return "rnn_state";
+  } else if (kv_state_kind == KVStateKind::kNone) {
+    return "none";
+  } else {
+    LOG(FATAL) << "Invalid kv state kind: " << static_cast<int>(kv_state_kind);
+  }
+}
+
+inline KVStateKind KVStateKindFromString(const std::string& kv_state_kind) {
+  if (kv_state_kind == "kv_cache") {
+    return KVStateKind::kKVCache;
+  } else if (kv_state_kind == "rnn_state") {
+    return KVStateKind::kRNNState;
+  } else if (kv_state_kind == "none") {
+    return KVStateKind::kNone;
+  } else {
+    LOG(FATAL) << "Invalid kv state kind string: " << kv_state_kind;
+  }
+}
+
 struct ModelMetadata {
   struct Param {
     struct Preproc {
@@ -49,6 +79,7 @@ struct ModelMetadata {
   int64_t attention_sink_size;
   std::vector<Param> params;
   std::unordered_map<std::string, int64_t> memory_usage;
+  KVStateKind kv_state_kind;
   KVCacheMetadata kv_cache_metadata;

   static ModelMetadata FromJSON(const picojson::object& json_str,
diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc
index 9b9d5ba65a..cbc4c6c613 100644
--- a/cpp/serve/config.cc
+++ b/cpp/serve/config.cc
@@ -248,7 +248,6 @@ EngineConfig EngineConfig::FromJSONAndInferredConfig(
   CHECK(inferred_config.max_single_sequence_length.has_value());
   CHECK(inferred_config.prefill_chunk_size.has_value());
   CHECK(inferred_config.max_history_size.has_value());
-  CHECK(inferred_config.kv_state_kind.has_value());
   ObjectPtr<EngineConfigNode> n = make_object<EngineConfigNode>();

   // - Get models and model libs.
@@ -290,7 +289,6 @@ EngineConfig EngineConfig::FromJSONAndInferredConfig(
   n->max_single_sequence_length = inferred_config.max_single_sequence_length.value();
   n->prefill_chunk_size = inferred_config.prefill_chunk_size.value();
   n->max_history_size = inferred_config.max_history_size.value();
-  n->kv_state_kind = inferred_config.kv_state_kind.value();
   return EngineConfig(n);
 }

@@ -356,7 +354,6 @@ String EngineConfigNode::AsJSONString() const {
   config["max_single_sequence_length"] =
       picojson::value(static_cast<int64_t>(this->max_single_sequence_length));
   config["prefill_chunk_size"] = picojson::value(static_cast<int64_t>(this->prefill_chunk_size));
   config["max_history_size"] = picojson::value(static_cast<int64_t>(this->max_history_size));
-  config["kv_state_kind"] = picojson::value(KVStateKindToString(this->kv_state_kind));
   config["speculative_mode"] = picojson::value(SpeculativeModeToString(this->speculative_mode));
   config["spec_draft_length"] = picojson::value(static_cast<int64_t>(this->spec_draft_length));
   config["verbose"] = picojson::value(static_cast<bool>(this->verbose));
@@ -428,14 +425,18 @@ Result<ModelConfigLimits> GetModelConfigLimits(const std::vector<picojson::objec
         json::Lookup<int64_t>(compile_time_model_config, "max_batch_size"));
   }
   ICHECK_NE(model_max_prefill_chunk_size, std::numeric_limits<int64_t>::max());
   ICHECK_NE(model_max_batch_size, std::numeric_limits<int64_t>::max());
+  ICHECK_GT(model_max_prefill_chunk_size, 0);
+  ICHECK_GT(model_max_batch_size, 0);
   return Result<ModelConfigLimits>::Ok(
       {model_max_single_sequence_length, model_max_prefill_chunk_size, model_max_batch_size});
 }
@@ -689,7 +690,6 @@ Result<InferrableEngineConfig> InferrableEngineConfig::InferForKVCache(
              << " MB). The actual usage might be slightly larger than the estimated number.";
   }

-  inferred_config.kv_state_kind = KVStateKind::kKVCache;
   inferred_config.max_history_size = 0;
   return Result<InferrableEngineConfig>::Ok(inferred_config);
 }
@@ -853,7 +853,6 @@ Result<InferrableEngineConfig> InferrableEngineConfig::InferForRNNState(
              << " MB). The actual usage might be slightly larger than the estimated number.";
   }

-  inferred_config.kv_state_kind = KVStateKind::kRNNState;
   return Result<InferrableEngineConfig>::Ok(inferred_config);
 }
diff --git a/cpp/serve/config.h b/cpp/serve/config.h
index 8437232d37..2680eb755c 100644
--- a/cpp/serve/config.h
+++ b/cpp/serve/config.h
@@ -114,12 +114,8 @@ enum class SpeculativeMode : int {
   kSmallDraft = 1,
   /*! \brief The eagle-style speculative decoding. */
   kEagle = 2,
-};
-
-/*! \brief The kind of cache. */
-enum class KVStateKind : int {
-  kKVCache = 0,
-  kRNNState = 1,
+  /*! \brief The Medusa-style speculative decoding. */
+  kMedusa = 3,
 };

 class InferrableEngineConfig;
@@ -172,8 +168,6 @@ class EngineConfigNode : public Object {
   int prefill_chunk_size = 1024;
   /*! \brief The maximum history size for RNN state. KV cache does not need this. */
   int max_history_size = 0;
-  /*! \brief The kind of cache. Whether it's KV cache or RNN state. */
-  KVStateKind kv_state_kind = KVStateKind::kKVCache;

   /*************** Speculative decoding ***************/

@@ -216,7 +210,6 @@ struct InferrableEngineConfig {
   std::optional<int64_t> max_single_sequence_length;
   std::optional<int64_t> prefill_chunk_size;
   std::optional<int64_t> max_history_size;
-  std::optional<KVStateKind> kv_state_kind;

   /*! \brief Infer the config for KV cache from a given initial config. */
   TVM_DLL static Result<InferrableEngineConfig> InferForKVCache(
@@ -238,9 +231,16 @@ struct InferrableEngineConfig {
 Result<bool> ModelsUseKVCache(const std::vector<picojson::object>& model_configs);

 inline std::string EngineModeToString(EngineMode mode) {
-  return mode == EngineMode::kLocal         ? "local"
-         : mode == EngineMode::kInteractive ? "interactive"
-                                            : "server";
+  if (mode == EngineMode::kLocal) {
+    return "local";
+  } else if (mode == EngineMode::kInteractive) {
+    return "interactive";
+  } else if (mode == EngineMode::kServer) {
+    return "server";
+  } else {
+    LOG(FATAL) << "Invalid engine mode: " << static_cast<int>(mode);
+    throw;
+  }
 }

 inline EngineMode EngineModeFromString(const std::string& mode) {
@@ -252,13 +252,22 @@ inline EngineMode EngineModeFromString(const std::string& mode) {
     return EngineMode::kServer;
   } else {
     LOG(FATAL) << "Invalid engine mode string: " << mode;
+    throw;
   }
 }

 inline std::string SpeculativeModeToString(SpeculativeMode speculative_mode) {
-  return speculative_mode == SpeculativeMode::kDisable      ? "disable"
-         : speculative_mode == SpeculativeMode::kSmallDraft ? "small_draft"
-                                                            : "eagle";
+  if (speculative_mode == SpeculativeMode::kDisable) {
+    return "disable";
+  } else if (speculative_mode == SpeculativeMode::kSmallDraft) {
+    return "small_draft";
+  } else if (speculative_mode == SpeculativeMode::kEagle) {
+    return "eagle";
+  } else if (speculative_mode == SpeculativeMode::kMedusa) {
+    return "medusa";
+  } else {
+    LOG(FATAL) << "Invalid speculative mode: " << static_cast<int>(speculative_mode);
+    throw;
+  }
 }

 inline SpeculativeMode SpeculativeModeFromString(const std::string& speculative_mode) {
@@ -268,22 +277,11 @@ inline SpeculativeMode SpeculativeModeFromString(const std::string& speculative_
     return SpeculativeMode::kSmallDraft;
   } else if (speculative_mode == "eagle") {
     return SpeculativeMode::kEagle;
+  } else if (speculative_mode == "medusa") {
+    return SpeculativeMode::kMedusa;
   } else {
     LOG(FATAL) << "Invalid speculative mode string: " << speculative_mode;
-  }
-}
-
-inline std::string KVStateKindToString(KVStateKind kv_state_kind) {
-  return kv_state_kind == KVStateKind::kKVCache ? "kv_cache" : "rnn_State";
-}
-
-inline KVStateKind KVStateKindFromString(const std::string& kv_state_kind) {
-  if (kv_state_kind == "kv_cache") {
-    return KVStateKind::kKVCache;
-  } else if (kv_state_kind == "rnn_state") {
-    return KVStateKind::kRNNState;
-  } else {
-    LOG(FATAL) << "Invalid kv state kind string: " << kv_state_kind;
+    throw;
   }
 }
diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc
index 616c463d9c..418cabfc91 100644
--- a/cpp/serve/engine.cc
+++ b/cpp/serve/engine.cc
@@ -105,8 +105,7 @@ class EngineImpl : public Engine {
       model->SetPrefillChunkSize(engine_config->prefill_chunk_size);
       model->CreateKVCache(engine_config->kv_cache_page_size, engine_config->max_num_sequence,
                            engine_config->max_total_sequence_length,
-                           engine_config->prefill_chunk_size, engine_config->max_history_size,
-                           engine_config->kv_state_kind);
+                           engine_config->prefill_chunk_size, engine_config->max_history_size);
       n->model_workspaces_.push_back(
           ModelWorkspace{model->AllocEmbeddingTensor(), model->AllocHiddenStatesTensor()});
     }
@@ -122,7 +121,7 @@ class EngineImpl : public Engine {
     }
     n->token_table_ =
         Tokenizer::PostProcessTokenTable(n->tokenizer_->TokenTable(), token_table_postproc_method);
-    n->grammar_init_context_storage_ = GrammarInitContextStorage(n->token_table_);
+    n->grammar_init_context_cache_ = GrammarInitContextCache(n->token_table_);

     // - Create the logit processor and sampler, and
     // the DraftTokenWorkspaceManager for speculative decoding.
     int max_num_tokens = engine_config->max_num_sequence;
@@ -161,6 +160,18 @@ class EngineImpl : public Engine {
                             n->model_workspaces_, draft_token_workspace_manager, engine_config,
                             n->trace_recorder_)};
         break;
+      case SpeculativeMode::kMedusa:
+        n->actions_ = {EngineAction::EagleNewRequestPrefill(n->models_,            //
+                                                            logit_processor,       //
+                                                            sampler,               //
+                                                            n->model_workspaces_,  //
+                                                            draft_token_workspace_manager,  //
+                                                            engine_config,         //
+                                                            n->trace_recorder_),
+                       EngineAction::EagleBatchVerify(
+                           n->models_, logit_processor, sampler, n->model_workspaces_,
+                           draft_token_workspace_manager, engine_config, n->trace_recorder_)};
+        break;
       default:
         n->actions_ = {
             EngineAction::NewRequestPrefill(n->models_,  //
@@ -422,13 +433,9 @@ class EngineImpl : public Engine {
         json::LookupOptional<int64_t>(config, "max_history_size");
     std::optional<std::string> kv_state_kind_str =
         json::LookupOptional<std::string>(config, "kv_state_kind");
-    std::optional<KVStateKind> kv_state_kind;
-    if (kv_state_kind_str.has_value()) {
-      kv_state_kind = KVStateKindFromString(kv_state_kind_str.value());
-    }
-    InferrableEngineConfig inferrable_cfg{max_num_sequence,  max_total_sequence_length,
+    InferrableEngineConfig inferrable_cfg{max_num_sequence, max_total_sequence_length,
                                           max_single_sequence_length, prefill_chunk_size,
-                                          max_history_size,  kv_state_kind};
+                                          max_history_size};

     // - Get the model metadata.
     std::vector<ModelMetadata> model_metadata;
@@ -440,28 +447,13 @@ class EngineImpl : public Engine {
     if (use_kv_cache.IsErr()) {
       return TResult::Error(use_kv_cache.UnwrapErr());
     }
-    KVStateKind inferred_kv_state_kind;
     Result<InferrableEngineConfig> inferrable_cfg_res;
     if (use_kv_cache.Unwrap()) {
-      inferred_kv_state_kind = KVStateKind::kKVCache;
-      // - Check if the kv state kind from config is valid.
-      if (kv_state_kind.has_value() && kv_state_kind.value() != inferred_kv_state_kind) {
-        return TResult::Error(
-            "Invalid kv state kind in EngineConfig. The models use KV cache, but RNN state is "
-            "specified in EngineConfig.");
-      }
       // - Infer configuration.
       inferrable_cfg_res = InferrableEngineConfig::InferForKVCache(
           mode, device_, gpu_memory_utilization, model_configs, model_metadata, inferrable_cfg,
           verbose);
     } else {
-      inferred_kv_state_kind = KVStateKind::kRNNState;
-      // - Check if the kv state kind from config is valid.
-      if (kv_state_kind.has_value() && kv_state_kind.value() != inferred_kv_state_kind) {
-        return TResult::Error(
-            "Invalid kv state kind in EngineConfig. The models use RNN state, but KV cache is "
-            "specified in EngineConfig.");
-      }
       // - Infer configuration.
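+      // Editorial note: kv_state_kind is no longer inferred here; it now comes
+      // from each model's metadata (see model.cc above), so the old cross-check
+      // between the engine config and the model kind is dropped.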
       inferrable_cfg_res = InferrableEngineConfig::InferForRNNState(
           mode, device_, gpu_memory_utilization, model_configs, model_metadata, inferrable_cfg,
@@ -477,7 +469,6 @@ class EngineImpl : public Engine {
     ICHECK(inferrable_cfg.max_single_sequence_length.has_value());
     ICHECK(inferrable_cfg.prefill_chunk_size.has_value());
     ICHECK(inferrable_cfg.max_history_size.has_value());
-    ICHECK(inferrable_cfg.kv_state_kind.has_value());
     return TResult::Ok(EngineConfig::FromJSONAndInferredConfig(config, inferrable_cfg));
   }

@@ -499,9 +490,9 @@ class EngineImpl : public Engine {
     if (response_format.type != "json_object") {
       return std::nullopt;
     } else if (!response_format.schema) {
-      return grammar_init_context_storage_->GetInitContextForJSON();
+      return grammar_init_context_cache_->GetInitContextForJSON();
     } else {
-      return grammar_init_context_storage_->GetInitContextForJSONSchema(
+      return grammar_init_context_cache_->GetInitContextForJSONSchema(
           response_format.schema.value());
     }
   }
@@ -513,7 +504,7 @@ class EngineImpl : public Engine {
   Tokenizer tokenizer_;
   std::vector<std::string> token_table_;
   // Helper to get the grammar init context for requests.
-  GrammarInitContextStorage grammar_init_context_storage_;
+  GrammarInitContextCache grammar_init_context_cache_;
   // Models
   Array<Model> models_;
   // Device that the models run on.
diff --git a/cpp/serve/engine_actions/action_commons.cc b/cpp/serve/engine_actions/action_commons.cc
index af0dfe978d..3289ef57c6 100644
--- a/cpp/serve/engine_actions/action_commons.cc
+++ b/cpp/serve/engine_actions/action_commons.cc
@@ -211,6 +211,26 @@ RequestStateEntry PreemptLastRunningRequestStateEntry(
   return rsentry;
 }

+std::pair<NDArray, std::vector<SampleResult>> ApplyLogitProcessorAndSample(
+    const LogitProcessor& logit_processor, const Sampler& sampler, const NDArray& logits,
+    const Array<GenerationConfig>& generation_cfg, const Array<String>& request_ids,
+    const Array<RequestModelState>& mstates, const std::vector<RandomGenerator*>& rngs,
+    const std::vector<int>& sample_indices) {
+  // - Update logits.
+  logit_processor->InplaceUpdateLogits(logits, generation_cfg, mstates, request_ids);
+
+  // - Compute probability distributions.
+  NDArray probs_on_device =
+      logit_processor->ComputeProbsFromLogits(logits, generation_cfg, request_ids);
+
+  // - Sample tokens.
+  NDArray renormalized_probs = sampler->BatchRenormalizeProbsByTopP(probs_on_device, sample_indices,
+                                                                    request_ids, generation_cfg);
+  std::vector<SampleResult> sample_results = sampler->BatchSampleTokensWithProbAfterTopP(
+      renormalized_probs, sample_indices, request_ids, generation_cfg, rngs);
+  return {std::move(probs_on_device), std::move(sample_results)};
+}
+
 }  // namespace serve
 }  // namespace llm
 }  // namespace mlc
diff --git a/cpp/serve/engine_actions/action_commons.h b/cpp/serve/engine_actions/action_commons.h
index 07bef2d2d9..de98e11e67 100644
--- a/cpp/serve/engine_actions/action_commons.h
+++ b/cpp/serve/engine_actions/action_commons.h
@@ -75,6 +75,24 @@ inline std::vector<RequestStateEntry> GetRunningRequestStateEntries(const Engine
   return rsentries;
 }

+/*!
+ * \brief Apply the logit processor to the logits and sample one token for each request.
+ * \param logit_processor The logit processor to apply.
+ * \param sampler The sampler to sample tokens.
+ * \param logits The logits to process.
+ * \param generation_cfg The generation configurations of the requests.
+ * \param request_ids The request ids.
+ * \param mstates The model states of the requests.
+ * \param rngs The random generators of the requests.
+ * \param sample_indices The indices of the requests to sample.
+ * \return The processed logits and the sampled results.
+ */
+std::pair<NDArray, std::vector<SampleResult>> ApplyLogitProcessorAndSample(
+    const LogitProcessor& logit_processor, const Sampler& sampler, const NDArray& logits,
+    const Array<GenerationConfig>& generation_cfg, const Array<String>& request_ids,
+    const Array<RequestModelState>& mstates, const std::vector<RandomGenerator*>& rngs,
+    const std::vector<int>& sample_indices);
+
 }  // namespace serve
 }  // namespace llm
 }  // namespace mlc
diff --git a/cpp/serve/engine_actions/batch_prefill_base.cc b/cpp/serve/engine_actions/batch_prefill_base.cc
new file mode 100644
index 0000000000..f570551417
--- /dev/null
+++ b/cpp/serve/engine_actions/batch_prefill_base.cc
@@ -0,0 +1,315 @@
+/*!
+ *  Copyright (c) 2024 by Contributors
+ * \file serve/engine_actions/batch_prefill_base.cc
+ */
+
+#include "batch_prefill_base.h"
+
+namespace mlc {
+namespace llm {
+namespace serve {
+
+BatchPrefillBaseActionObj::BatchPrefillBaseActionObj(Array<Model> models,
+                                                     EngineConfig engine_config,
+                                                     Optional<EventTraceRecorder> trace_recorder)
+    : models_(models), engine_config_(engine_config), trace_recorder_(trace_recorder) {}
+
+/*!
+ * \brief Find one or multiple request state entries to run prefill.
+ * \param estate The engine state.
+ * \return The request entries to prefill, together with their input lengths.
+ */
+std::vector<BatchPrefillBaseActionObj::PrefillInput>
+BatchPrefillBaseActionObj::GetRequestStateEntriesToPrefill(EngineState estate) {
+  if (estate->waiting_queue.empty()) {
+    // No request to prefill.
+    return {};
+  }
+
+  std::vector<PrefillInput> prefill_inputs;
+
+  // - Try to prefill pending requests.
+  int total_input_length = 0;
+  int total_required_pages = 0;
+  int num_available_pages = models_[0]->GetNumAvailablePages();
+  int num_running_rsentries = GetRunningRequestStateEntries(estate).size();
+  int current_total_seq_len = models_[0]->GetCurrentTotalSequenceLength();
+  KVStateKind kv_state_kind = models_[0]->GetMetadata().kv_state_kind;
+
+  int num_prefill_rsentries = 0;
+  for (const Request& request : estate->waiting_queue) {
+    RequestState rstate = estate->GetRequestState(request);
+    bool prefill_stops = false;
+    for (const RequestStateEntry& rsentry : rstate->entries) {
+      // A request state entry can be prefilled only when:
+      // - it has inputs, and
+      // - it has no parent or its parent is alive and has no remaining input.
+      if (rsentry->mstates[0]->inputs.empty() ||
+          (rsentry->parent_idx != -1 &&
+           (rstate->entries[rsentry->parent_idx]->status == RequestStateStatus::kPending ||
+            !rstate->entries[rsentry->parent_idx]->mstates[0]->inputs.empty()))) {
+        continue;
+      }
+
+      int input_length = rsentry->mstates[0]->GetInputLength();
+      int num_require_pages = (input_length + engine_config_->kv_cache_page_size - 1) /
+                              engine_config_->kv_cache_page_size;
+      total_input_length += input_length;
+      total_required_pages += num_require_pages;
+      // - Attempt 1. Check if the entire request state entry can fit for prefill.
+      bool can_prefill = false;
+      for (int num_child_to_activate = rsentry->child_indices.size(); num_child_to_activate >= 0;
+           --num_child_to_activate) {
+        if (CanPrefill(estate, num_prefill_rsentries + 1 + num_child_to_activate,
+                       total_input_length, total_required_pages, num_available_pages,
+                       current_total_seq_len, num_running_rsentries, kv_state_kind)) {
+          prefill_inputs.push_back({rsentry, input_length, num_child_to_activate});
+          num_prefill_rsentries += 1 + num_child_to_activate;
+          can_prefill = true;
+          break;
+        }
+      }
+      if (can_prefill) {
+        continue;
+      }
+      total_input_length -= input_length;
+      total_required_pages -= num_require_pages;

+      // - Attempt 2. Check if the request state entry can partially fit by input chunking.
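+      // Editorial worked example (illustrative numbers): with prefill_chunk_size
+      // = 2048 and total_input_length = 1500 already scheduled, an entry with
+      // input_length = 800 fails attempt 1; since 2048 - 1500 = 548 < 800, the
+      // entry is chunked and only its first 548 tokens may be scheduled this
+      // round, subject to the CanPrefill check below.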
+      ICHECK_LE(total_input_length, engine_config_->prefill_chunk_size);
+      if (engine_config_->prefill_chunk_size - total_input_length >= input_length ||
+          engine_config_->prefill_chunk_size == total_input_length) {
+        // 1. If the input length can fit the remaining prefill chunk size,
+        // it means the failure of attempt 1 is not because of the input
+        // length being too long, and thus chunking does not help.
+        // 2. If the total input length already reaches the prefill chunk size,
+        // the current request state entry will not be able to be processed.
+        // So we can safely return in either case.
+        prefill_stops = true;
+        break;
+      }
+      input_length = engine_config_->prefill_chunk_size - total_input_length;
+      num_require_pages = (input_length + engine_config_->kv_cache_page_size - 1) /
+                          engine_config_->kv_cache_page_size;
+      total_input_length += input_length;
+      total_required_pages += num_require_pages;
+      if (CanPrefill(estate, num_prefill_rsentries + 1, total_input_length, total_required_pages,
+                     num_available_pages, current_total_seq_len, num_running_rsentries,
+                     kv_state_kind)) {
+        prefill_inputs.push_back({rsentry, input_length, 0});
+        num_prefill_rsentries += 1;
+      }
+
+      // - Prefill stops here.
+      prefill_stops = true;
+      break;
+    }
+    if (prefill_stops) {
+      break;
+    }
+  }
+
+  return prefill_inputs;
+}
+
+/*! \brief Check if the input requests can be prefilled under conditions. */
+bool BatchPrefillBaseActionObj::CanPrefill(EngineState estate, int num_prefill_rsentries,
+                                           int total_input_length, int num_required_pages,
+                                           int num_available_pages, int current_total_seq_len,
+                                           int num_running_rsentries, KVStateKind kv_state_kind) {
+  ICHECK_LE(num_running_rsentries, engine_config_->max_num_sequence);
+
+  // For RNN state, prefill can proceed as long as the request can be instantiated.
+  if (kv_state_kind == KVStateKind::kRNNState || kv_state_kind == KVStateKind::kNone) {
+    return true;
+  }
+
+  // No exceeding of the maximum allowed requests that can
+  // run simultaneously.
+  int spec_factor = engine_config_->speculative_mode != SpeculativeMode::kDisable
+                        ? (engine_config_->spec_draft_length + 1)
+                        : 1;
+  if ((num_running_rsentries + num_prefill_rsentries) * spec_factor >
+      std::min(engine_config_->max_num_sequence, engine_config_->prefill_chunk_size)) {
+    return false;
+  }
+
+  // NOTE: The conditions are heuristic and can be revised.
+  // Cond 1: total input length <= prefill chunk size.
+  // Cond 2: at least one decode can be performed after prefill.
+  // Cond 3: the number of total tokens after 8 rounds of decode does not
+  // exceed the limit, where 8 is a watermark number that can
+  // be configured and adjusted in the future.
+  int new_batch_size = num_running_rsentries + num_prefill_rsentries;
+  return total_input_length <= engine_config_->prefill_chunk_size &&
+         num_required_pages + new_batch_size <= num_available_pages &&
+         current_total_seq_len + total_input_length + 8 * new_batch_size <=
+             engine_config_->max_total_sequence_length;
+}
+
+/*!
+ * \brief Chunk the input of the given RequestModelState for prefill
+ * with regard to the provided maximum allowed prefill length.
+ * Return the list of input for prefill and the total prefill length.
+ * The `inputs` field of the given `mstate` will be mutated to exclude
+ * the returned input.
+ * \param mstate The RequestModelState whose input data is to be chunked.
+ * \param max_prefill_length The maximum allowed prefill length for the mstate.
+ * \return The list of input for prefill and the total prefill length.
+ */
+std::pair<std::vector<Data>, int> BatchPrefillBaseActionObj::ChunkPrefillInputData(
+    const RequestModelState& mstate, int max_prefill_length) {
+  ICHECK(!mstate->inputs.empty());
+  std::vector<Data> inputs;
+  int cum_input_length = 0;
+  inputs.reserve(mstate->inputs.size());
+  for (int i = 0; i < static_cast<int>(mstate->inputs.size()); ++i) {
+    inputs.push_back(mstate->inputs[i]);
+    int input_length = mstate->inputs[i]->GetLength();
+    cum_input_length += input_length;
+    // Case 0. the cumulative input length does not reach the maximum prefill length.
+    if (cum_input_length < max_prefill_length) {
+      continue;
+    }
+
+    // Case 1. the cumulative input length equals the maximum prefill length.
+    if (cum_input_length == max_prefill_length) {
+      if (i == static_cast<int>(mstate->inputs.size()) - 1) {
+        // - If `i` is the last input, we just copy and reset `mstate->inputs`.
+        mstate->inputs.clear();
+      } else {
+        // - Otherwise, set the new input array.
+        mstate->inputs = Array<Data>{mstate->inputs.begin() + i + 1, mstate->inputs.end()};
+      }
+      return {inputs, cum_input_length};
+    }
+
+    // Case 2. cum_input_length > max_prefill_length
+    // The input `i` itself needs chunking if it is TokenData,
+    // or otherwise it cannot be chunked.
+    Data input = mstate->inputs[i];
+    inputs.pop_back();
+    cum_input_length -= input_length;
+    const auto* token_input = input.as<TokenDataNode>();
+    if (token_input == nullptr) {
+      // Cannot chunk the input.
+      if (i != 0) {
+        mstate->inputs = Array<Data>{mstate->inputs.begin() + i, mstate->inputs.end()};
+      }
+      return {inputs, cum_input_length};
+    }
+
+    // Split the token data into two parts.
+    // Return the first part for prefill, and keep the second part.
+    int chunked_input_length = max_prefill_length - cum_input_length;
+    ICHECK_GT(input_length, chunked_input_length);
+    TokenData chunked_input(IntTuple{token_input->token_ids.begin(),
+                                     token_input->token_ids.begin() + chunked_input_length});
+    TokenData remaining_input(IntTuple{token_input->token_ids.begin() + chunked_input_length,
+                                       token_input->token_ids.end()});
+    inputs.push_back(chunked_input);
+    cum_input_length += chunked_input_length;
+    std::vector<Data> remaining_inputs{mstate->inputs.begin() + i + 1, mstate->inputs.end()};
+    remaining_inputs.insert(remaining_inputs.begin(), remaining_input);
+    mstate->inputs = remaining_inputs;
+    return {inputs, cum_input_length};
+  }
+
+  ICHECK(false) << "Cannot reach here";
+}
+
+void BatchPrefillBaseActionObj::UpdateRequestToAlive(
+    const std::vector<PrefillInput>& prefill_inputs,
+    const EngineState& estate, Array<String>* request_ids,
+    std::vector<RequestState>* rstates_of_entries,
+    std::vector<RequestStateStatus>* status_before_prefill) {
+  int num_rsentries = prefill_inputs.size();
+  request_ids->reserve(num_rsentries);
+  rstates_of_entries->reserve(num_rsentries);
+  status_before_prefill->reserve(num_rsentries);
+  for (const PrefillInput& prefill_input : prefill_inputs) {
+    const RequestStateEntry& rsentry = prefill_input.rsentry;
+    const Request& request = rsentry->request;
+    RequestState request_rstate = estate->GetRequestState(request);
+    request_ids->push_back(request->id);
+    status_before_prefill->push_back(rsentry->status);
+    rsentry->status = RequestStateStatus::kAlive;
+
+    if (status_before_prefill->back() == RequestStateStatus::kPending) {
+      // - Add the request to the running queue if the request state
+      // status was pending and all its request states were pending.
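+      // Editorial note: a request can own several entries when parallel
+      // generation creates child entries; it joins the running queue only if no
+      // sibling entry is alive yet, which the scan below verifies.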
+      bool alive_state_existed = false;
+      for (const RequestStateEntry& rsentry_ : request_rstate->entries) {
+        if (rsentry_->status == RequestStateStatus::kAlive && !rsentry_.same_as(rsentry)) {
+          alive_state_existed = true;
+        }
+      }
+      if (!alive_state_existed) {
+        estate->running_queue.push_back(request);
+      }
+    }
+    rstates_of_entries->push_back(std::move(request_rstate));
+  }
+}
+
+std::vector<Request> BatchPrefillBaseActionObj::RemoveProcessedRequests(
+    const std::vector<PrefillInput>& prefill_inputs, const EngineState& estate,
+    const std::vector<RequestState>& rstates_of_entries) {
+  // - Remove the request from waiting queue if all its request states
+  //   are now alive and have no remaining chunked inputs.
+  std::vector<Request> processed_requests;
+  int num_rsentries = prefill_inputs.size();
+  processed_requests.reserve(num_rsentries);
+  std::unordered_set<const RequestNode*> dedup_map;
+  for (int i = 0; i < num_rsentries; ++i) {
+    const RequestStateEntry& rsentry = prefill_inputs[i].rsentry;
+    if (dedup_map.find(rsentry->request.get()) != dedup_map.end()) {
+      continue;
+    }
+    dedup_map.insert(rsentry->request.get());
+    processed_requests.push_back(rsentry->request);
+
+    bool pending_state_exists = false;
+    for (const RequestStateEntry& rsentry_ : rstates_of_entries[i]->entries) {
+      if (rsentry_->status == RequestStateStatus::kPending ||
+          !rsentry_->mstates[0]->inputs.empty()) {
+        pending_state_exists = true;
+        break;
+      }
+    }
+    if (!pending_state_exists) {
+      auto it =
+          std::find(estate->waiting_queue.begin(), estate->waiting_queue.end(), rsentry->request);
+      ICHECK(it != estate->waiting_queue.end());
+      estate->waiting_queue.erase(it);
+    }
+  }
+  return processed_requests;
+}
+
+void BatchPrefillBaseActionObj::UpdateRequestStateEntriesWithSampleResults(
+    const std::vector<RequestStateEntry>& rsentries_for_sample,
+    const std::vector<bool>& rsentry_activated, const std::vector<SampleResult>& sample_results) {
+  auto tnow = std::chrono::high_resolution_clock::now();
+  for (int i = 0; i < static_cast<int>(rsentries_for_sample.size()); ++i) {
+    // Update all model states of the request state entry.
+    for (const RequestModelState& mstate : rsentries_for_sample[i]->mstates) {
+      mstate->CommitToken(sample_results[i]);
+      if (!rsentry_activated[i]) {
+        // When the child rsentry is not activated,
+        // add the sampled token as an input of the mstate for prefill.
+        mstate->inputs.push_back(
+            TokenData(std::vector<int64_t>{sample_results[i].sampled_token_id.first}));
+      }
+    }
+    if (rsentries_for_sample[i]->mstates[0]->committed_tokens.size() == 1) {
+      rsentries_for_sample[i]->tprefill_finish = tnow;
+    }
+  }
+}
+
+}  // namespace serve
+}  // namespace llm
+}  // namespace mlc
diff --git a/cpp/serve/engine_actions/batch_prefill_base.h b/cpp/serve/engine_actions/batch_prefill_base.h
new file mode 100644
index 0000000000..122a214496
--- /dev/null
+++ b/cpp/serve/engine_actions/batch_prefill_base.h
@@ -0,0 +1,107 @@
+/*!
+ *  Copyright (c) 2024 by Contributors
+ * \file serve/engine_actions/batch_prefill_base.h
+ */
+
+#include <vector>
+
+#include "../config.h"
+#include "../model.h"
+#include "action.h"
+#include "action_commons.h"
+
+namespace mlc {
+namespace llm {
+namespace serve {
+
+/*!
+ * \brief The base action that prefills requests in the `waiting_queue` of
+ * the engine state.
+ */
+class BatchPrefillBaseActionObj : public EngineActionObj {
+ protected:
+  /*! \brief The class of request state entry and its maximum allowed length for prefill.
*/ + struct PrefillInput { + RequestStateEntry rsentry; + int max_prefill_length = 0; + int num_child_to_activate = 0; + }; + + BatchPrefillBaseActionObj(Array models, EngineConfig engine_config, + Optional trace_recorder); + + /*! + * \brief Find one or multiple request state entries to run prefill. + * \param estate The engine state. + * \return The request entries to prefill, together with their input lengths. + */ + std::vector GetRequestStateEntriesToPrefill(EngineState estate); + + /*! \brief Check if the input requests can be prefilled under conditions. */ + bool CanPrefill(EngineState estate, int num_prefill_rsentries, int total_input_length, + int num_required_pages, int num_available_pages, int current_total_seq_len, + int num_running_rsentries, KVStateKind kv_state_kind); + + /*! + * \brief Chunk the input of the given RequestModelState for prefill + * with regard to the provided maximum allowed prefill length. + * Return the list of input for prefill and the total prefill length. + * The `inputs` field of the given `mstate` will be mutated to exclude + * the returned input. + * \param mstate The RequestModelState whose input data is to be chunked. + * \param max_prefill_length The maximum allowed prefill length for the mstate. + * \return The list of input for prefill and the total prefill length. + */ + std::pair, int> ChunkPrefillInputData(const RequestModelState& mstate, + int max_prefill_length); + + /*! + * \brief Update status of request states from pending to alive and collect request state entries + * from the prefill input. + * \param prefill_inputs The prefill input. + * \param estate The engine state. + * \param[out] request_ids The array to store the request ids of the request state entries. + * \param[out] rstates_of_entries The vector to store the request state entries. + * \param[out] status_before_prefill The vector to store the status of the request state entries + * before prefill. + */ + void UpdateRequestToAlive(const std::vector& prefill_inputs, + const EngineState& estate, Array* request_ids, + std::vector* rstates_of_entries, + std::vector* status_before_prefill); + + /*! + * \brief Remove the request from waiting queue if all its request states are now alive and have + * no remaining chunked inputs. + * \param prefill_inputs The prefill input. + * \param estate The engine state. + * \param rstates_of_entries The request state entries for each prefill input. + * \return The processed requests. + */ + std::vector RemoveProcessedRequests(const std::vector& prefill_inputs, + const EngineState& estate, + const std::vector& rstates_of_entries); + + /*! + * \brief Update the committed tokens of states. If a request is first-time prefilled, set the + * prefill finish time. + * \param rsentries_for_sample The request state entries for sample. + * \param + * rsentry_activated The activation status of the request state entries. + * \param sample_results The sample results. + */ + void UpdateRequestStateEntriesWithSampleResults( + const std::vector& rsentries_for_sample, + const std::vector& rsentry_activated, const std::vector& sample_results); + + /*! \brief The models to run prefill in. */ + Array models_; + /*! \brief The engine config. */ + EngineConfig engine_config_; + /*! \brief Event trace recorder. 
*/ + Optional trace_recorder_; +}; + +} // namespace serve +} // namespace llm +} // namespace mlc diff --git a/cpp/serve/engine_actions/eagle_batch_verify.cc b/cpp/serve/engine_actions/eagle_batch_verify.cc index 71daaf1bf9..1a8bec2eea 100644 --- a/cpp/serve/engine_actions/eagle_batch_verify.cc +++ b/cpp/serve/engine_actions/eagle_batch_verify.cc @@ -179,7 +179,8 @@ class EagleBatchVerifyActionObj : public EngineActionObj { // - Slice and save hidden_states_for_sample last_accepted_hidden_positions.push_back(cum_verify_lengths[i] + accept_length - 1); } - if (!fully_accepted_rsentries.empty()) { + if (!fully_accepted_rsentries.empty() && + engine_config_->speculative_mode == SpeculativeMode::kEagle) { // - Run a step of batch decode for requests whose drafts are fully accepted. // When a request's draft is fully accepted, there is an extra token proposed // by the draft model but not added into the draft model's KV cache. @@ -218,11 +219,11 @@ class EagleBatchVerifyActionObj : public EngineActionObj { hidden_states, hidden_states_positions_for_fully_accepted, &model_workspaces_[draft_model_id_].hidden_states); // - Invoke model decode. - ObjectRef fused_embedding_hidden_states = - models_[draft_model_id_]->FuseEmbedHidden(embeddings, hidden_states_for_fully_accepted, - /*batch_size*/ num_rsentries, /*seq_len*/ 1); + ObjectRef fused_embedding_hidden_states = models_[draft_model_id_]->FuseEmbedHidden( + embeddings, hidden_states_for_fully_accepted, + /*batch_size*/ fully_accepted_rsentries.size(), /*seq_len*/ 1); hidden_states_for_fully_accepted = models_[draft_model_id_]->BatchDecodeToLastHidden( - fused_embedding_hidden_states, request_internal_ids); + fused_embedding_hidden_states, fully_accepted_request_internal_ids); // - We explicitly synchronize to avoid the input tokens getting overriden in the // next runs of BatchDecode. // This is because we do not do sample for this round of batch decode. @@ -239,9 +240,10 @@ class EagleBatchVerifyActionObj : public EngineActionObj { // One step draft for the following steps // Gather hidden states for the last accepted tokens. - hidden_states = models_[draft_model_id_]->GatherHiddenStates( - hidden_states, last_accepted_hidden_positions, - &model_workspaces_[draft_model_id_].hidden_states); + // Use the function and the workspace of the verify model because the information about the + // hidden states is not available in the draft model for medusa. + hidden_states = models_[0]->GatherHiddenStates(hidden_states, last_accepted_hidden_positions, + &model_workspaces_[0].hidden_states); std::vector input_tokens; Array mstates; @@ -255,61 +257,50 @@ class EagleBatchVerifyActionObj : public EngineActionObj { input_tokens.push_back(mstates[i]->committed_tokens.back().sampled_token_id.first); } - // - Compute embeddings. - RECORD_EVENT(trace_recorder_, request_ids, "start proposal embedding"); - embeddings = models_[draft_model_id_]->TokenEmbed( - {IntTuple{input_tokens.begin(), input_tokens.end()}}); - RECORD_EVENT(trace_recorder_, request_ids, "finish proposal embedding"); - - // - Invoke model decode. 
- RECORD_EVENT(trace_recorder_, request_ids, "start proposal decode"); - ObjectRef fused_embedding_hidden_states = models_[draft_model_id_]->FuseEmbedHidden( - embeddings, hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1); - hidden_states = models_[draft_model_id_]->BatchDecodeToLastHidden( - fused_embedding_hidden_states, request_internal_ids); - - if (models_[draft_model_id_]->CanGetLogits()) { - logits = models_[draft_model_id_]->GetLogits(hidden_states); - } else { - // - Use base model's head. - logits = models_[0]->GetLogits(hidden_states); + Array multi_step_logits{nullptr}; // for medusa output + if (engine_config_->speculative_mode == SpeculativeMode::kEagle) { + // - Compute embeddings. + RECORD_EVENT(trace_recorder_, request_ids, "start proposal embedding"); + embeddings = models_[draft_model_id_]->TokenEmbed( + {IntTuple{input_tokens.begin(), input_tokens.end()}}); + RECORD_EVENT(trace_recorder_, request_ids, "finish proposal embedding"); + + // - Invoke model decode. + RECORD_EVENT(trace_recorder_, request_ids, "start proposal decode"); + ObjectRef fused_embedding_hidden_states = models_[draft_model_id_]->FuseEmbedHidden( + embeddings, hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1); + hidden_states = models_[draft_model_id_]->BatchDecodeToLastHidden( + fused_embedding_hidden_states, request_internal_ids); + + int lm_head_model_id = models_[draft_model_id_]->CanGetLogits() ? draft_model_id_ : 0; + logits = models_[lm_head_model_id]->GetLogits(hidden_states); + RECORD_EVENT(trace_recorder_, request_ids, "finish proposal decode"); + ICHECK_EQ(logits->ndim, 2); + ICHECK_EQ(logits->shape[0], num_rsentries); + } else if (engine_config_->speculative_mode == SpeculativeMode::kMedusa) { + multi_step_logits = models_[draft_model_id_]->GetMultiStepLogits(hidden_states); } - RECORD_EVENT(trace_recorder_, request_ids, "finish proposal decode"); - ICHECK_EQ(logits->ndim, 2); - ICHECK_EQ(logits->shape[0], num_rsentries); - // - Update logits. - logit_processor_->InplaceUpdateLogits(logits, generation_cfg, mstates, request_ids); - - // - Compute probability distributions. - probs_on_device = - logit_processor_->ComputeProbsFromLogits(logits, generation_cfg, request_ids); - - // - Sample tokens. // Fill range [0, num_rsentries) into `sample_indices`. std::vector sample_indices(num_rsentries); std::iota(sample_indices.begin(), sample_indices.end(), 0); - NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP( - probs_on_device, sample_indices, request_ids, generation_cfg); - std::vector sample_results = sampler_->BatchSampleTokensWithProbAfterTopP( - renormalized_probs, sample_indices, request_ids, generation_cfg, rngs); - ICHECK_EQ(sample_results.size(), num_rsentries); - // - Slice and save hidden_states_for_sample - draft_token_workspace_manager_->AllocSlots(num_rsentries, &draft_token_slots_); - models_[draft_model_id_]->ScatterDraftProbs( - renormalized_probs, draft_token_slots_, - &model_workspaces_[verify_model_id_].draft_probs_storage); - models_[draft_model_id_]->ScatterHiddenStates( - hidden_states, draft_token_slots_, - &model_workspaces_[verify_model_id_].draft_hidden_states_storage); - // - Add draft token to the state. 
- for (int i = 0; i < num_rsentries; ++i) { - mstates[i]->AddDraftToken(sample_results[i], draft_token_slots_[i]); - estate->stats.total_draft_length += 1; + if (engine_config_->speculative_mode == SpeculativeMode::kEagle) { + const auto& [renormalized_probs, sample_results] = + ApplyLogitProcessorAndSample(logit_processor_, sampler_, logits, generation_cfg, + request_ids, mstates, rngs, sample_indices); + UpdateRequestStatesWithDraftProposals(mstates, sample_results, draft_model_id_, + renormalized_probs, hidden_states, estate); + } else if (engine_config_->speculative_mode == SpeculativeMode::kMedusa) { + for (int draft_id = 0; draft_id < engine_config_->spec_draft_length; draft_id++) { + const auto& [renormalized_probs, sample_results] = ApplyLogitProcessorAndSample( + logit_processor_, sampler_, multi_step_logits[draft_id], generation_cfg, request_ids, + mstates, rngs, sample_indices); + UpdateRequestStatesWithDraftProposals(mstates, sample_results, draft_model_id_, + renormalized_probs, hidden_states, estate); + } } } - auto tend = std::chrono::high_resolution_clock::now(); estate->stats.engine_total_decode_time += static_cast((tend - tstart).count()) / 1e9; @@ -371,6 +362,24 @@ class EagleBatchVerifyActionObj : public EngineActionObj { return num_required_pages <= num_available_pages; } + void UpdateRequestStatesWithDraftProposals(const Array& mstates, + const std::vector& sample_results, + int model_id, const NDArray& renormalized_probs, + const ObjectRef& hidden_states_for_sample, + EngineState estate) { + draft_token_workspace_manager_->AllocSlots(mstates.size(), &draft_token_slots_); + models_[0]->ScatterDraftProbs(renormalized_probs, draft_token_slots_, + &model_workspaces_[0].draft_probs_storage); + if (engine_config_->speculative_mode == SpeculativeMode::kEagle && + engine_config_->spec_draft_length > 1) { + models_[0]->ScatterHiddenStates(hidden_states_for_sample, draft_token_slots_, + &model_workspaces_[0].draft_hidden_states_storage); + } + for (int i = 0; i < static_cast(mstates.size()); ++i) { + mstates[i]->AddDraftToken(sample_results[i], draft_token_slots_[i]); + estate->stats.total_draft_length += 1; + } + } /*! * \brief The model to run decode in. When there are multiple * models, the `Step` function of the created action will not take effect. diff --git a/cpp/serve/engine_actions/eagle_new_request_prefill.cc b/cpp/serve/engine_actions/eagle_new_request_prefill.cc index e2d2d661f8..a2da53e171 100644 --- a/cpp/serve/engine_actions/eagle_new_request_prefill.cc +++ b/cpp/serve/engine_actions/eagle_new_request_prefill.cc @@ -10,6 +10,7 @@ #include "../sampler/sampler.h" #include "action.h" #include "action_commons.h" +#include "batch_prefill_base.h" namespace mlc { namespace llm { @@ -19,7 +20,7 @@ namespace serve { * \brief The action that prefills requests in the `waiting_queue` of * the engine state. 
*/ -class EagleNewRequestPrefillActionObj : public EngineActionObj { +class EagleNewRequestPrefillActionObj : public BatchPrefillBaseActionObj { public: explicit EagleNewRequestPrefillActionObj(Array models, LogitProcessor logit_processor, Sampler sampler, @@ -27,13 +28,12 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { DraftTokenWorkspaceManager draft_token_workspace_manager, EngineConfig engine_config, Optional trace_recorder) - : models_(std::move(models)), + : BatchPrefillBaseActionObj(std::move(models), std::move(engine_config), + std::move(trace_recorder)), logit_processor_(std::move(logit_processor)), sampler_(std::move(sampler)), model_workspaces_(std::move(model_workspaces)), - draft_token_workspace_manager_(std::move(draft_token_workspace_manager)), - engine_config_(std::move(engine_config)), - trace_recorder_(std::move(trace_recorder)) {} + draft_token_workspace_manager_(std::move(draft_token_workspace_manager)) {} Array Step(EngineState estate) final { // - Find the requests in `waiting_queue` that can prefill in this step. @@ -53,32 +53,8 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { Array request_ids; std::vector rstates_of_entries; std::vector status_before_prefill; - request_ids.reserve(num_rsentries); - rstates_of_entries.reserve(num_rsentries); - status_before_prefill.reserve(num_rsentries); - for (const PrefillInput& prefill_input : prefill_inputs) { - const RequestStateEntry& rsentry = prefill_input.rsentry; - const Request& request = rsentry->request; - RequestState request_rstate = estate->GetRequestState(request); - request_ids.push_back(request->id); - status_before_prefill.push_back(rsentry->status); - rsentry->status = RequestStateStatus::kAlive; - - if (status_before_prefill.back() == RequestStateStatus::kPending) { - // - Add the request to running queue if the request state - // status was pending and all its request states were pending. - bool alive_state_existed = false; - for (const RequestStateEntry& rsentry_ : request_rstate->entries) { - if (rsentry_->status == RequestStateStatus::kAlive && !rsentry_.same_as(rsentry)) { - alive_state_existed = true; - } - } - if (!alive_state_existed) { - estate->running_queue.push_back(request); - } - } - rstates_of_entries.push_back(std::move(request_rstate)); - } + UpdateRequestToAlive(prefill_inputs, estate, &request_ids, &rstates_of_entries, + &status_before_prefill); // - Get embedding and run prefill for each model. std::vector prefill_lengths; @@ -135,6 +111,11 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { } } request_internal_ids.push_back(mstate->internal_id); + + if (engine_config_->speculative_mode == SpeculativeMode::kMedusa && model_id > 0) { + // Embedding is only needed for the base model in Medusa. + continue; + } RECORD_EVENT(trace_recorder_, prefill_inputs[i].rsentry->request->id, "start embedding"); // Speculative models shift left the input tokens by 1 when base model has committed tokens. // Note: for n > 1 cases Eagle doesn't work because parent entry doesn't shift input tokens. 
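
The admission logic that this action now inherits from `BatchPrefillBaseActionObj` is compact enough to see in isolation. The sketch below restates the three `CanPrefill` conditions as a standalone toy: `EngineConfigToy`, its field values, and the use of a plain draft-length field in place of the engine's `SpeculativeMode` enum are all invented stand-ins for illustration, not the real MLC types.

```cpp
// Standalone sketch of the paged-KV admission heuristic in
// BatchPrefillBaseActionObj::CanPrefill. All names and values are toy
// stand-ins, not the real EngineConfig.
#include <algorithm>
#include <iostream>

struct EngineConfigToy {
  int max_num_sequence = 8;
  int prefill_chunk_size = 2048;
  int max_total_sequence_length = 8192;
  int spec_draft_length = 0;  // 0: speculation disabled (toy stand-in for SpeculativeMode)
};

bool CanPrefillToy(const EngineConfigToy& cfg, int num_running_rsentries,
                   int num_prefill_rsentries, int total_input_length, int num_required_pages,
                   int num_available_pages, int current_total_seq_len) {
  // Respect the cap on concurrently running entries, inflated by the draft
  // length when speculative decoding is on.
  int spec_factor = cfg.spec_draft_length > 0 ? cfg.spec_draft_length + 1 : 1;
  if ((num_running_rsentries + num_prefill_rsentries) * spec_factor >
      std::min(cfg.max_num_sequence, cfg.prefill_chunk_size)) {
    return false;
  }
  int new_batch_size = num_running_rsentries + num_prefill_rsentries;
  // Cond 1: the prefill fits in one chunk.
  // Cond 2: at least one decode step still has KV pages left afterwards.
  // Cond 3: 8 decode steps' worth of tokens still fit under the length cap.
  return total_input_length <= cfg.prefill_chunk_size &&
         num_required_pages + new_batch_size <= num_available_pages &&
         current_total_seq_len + total_input_length + 8 * new_batch_size <=
             cfg.max_total_sequence_length;
}

int main() {
  EngineConfigToy cfg;
  std::cout << std::boolalpha
            << CanPrefillToy(cfg, /*num_running_rsentries=*/2, /*num_prefill_rsentries=*/1,
                             /*total_input_length=*/512, /*num_required_pages=*/32,
                             /*num_available_pages=*/100, /*current_total_seq_len=*/1024)
            << "\n";  // prints "true": all three conditions hold
}
```
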
@@ -149,59 +130,56 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj {
       }
 
       RECORD_EVENT(trace_recorder_, request_ids, "start prefill");
-      ObjectRef embedding_or_hidden_states{nullptr};
-      if (model_id == 0) {
-        embedding_or_hidden_states = embeddings;
-      } else {
-        embedding_or_hidden_states = models_[model_id]->FuseEmbedHidden(
-            embeddings, hidden_states_for_input, /*batch_size*/ 1, /*seq_len*/ cum_prefill_length);
-      }
-      // hidden_states: (b * s, h)
-      ObjectRef hidden_states = models_[model_id]->BatchPrefillToLastHidden(
-          embedding_or_hidden_states, request_internal_ids, prefill_lengths);
-      RECORD_EVENT(trace_recorder_, request_ids, "finish prefill");
-      if (model_id == 0) {
-        // We only need to sample for model 0 in prefill.
-        hidden_states_for_input = hidden_states;
-      }
+      Array<NDArray> multi_step_logits{nullptr};
 
-      // Whether to use base model to get logits.
-      int sample_model_id = !models_[model_id]->CanGetLogits() ? 0 : model_id;
+      if (model_id == 0 || engine_config_->speculative_mode == SpeculativeMode::kEagle) {
+        ObjectRef embedding_or_hidden_states{nullptr};
+        if (model_id == 0) {
+          embedding_or_hidden_states = embeddings;
+        } else {
+          embedding_or_hidden_states =
+              models_[model_id]->FuseEmbedHidden(embeddings, hidden_states_for_input,
+                                                 /*batch_size*/ 1, /*seq_len*/ cum_prefill_length);
+        }
+        // hidden_states: (b * s, h)
+        ObjectRef hidden_states = models_[model_id]->BatchPrefillToLastHidden(
+            embedding_or_hidden_states, request_internal_ids, prefill_lengths);
+        RECORD_EVENT(trace_recorder_, request_ids, "finish prefill");
+
+        if (model_id == 0) {
+          // We only need to sample for model 0 in prefill.
+          hidden_states_for_input = hidden_states;
+        }
 
-      std::vector<int> logit_positions;
-      {
-        // Prepare the logit positions
-        logit_positions.reserve(prefill_lengths.size());
-        int total_len = 0;
-        for (int i = 0; i < prefill_lengths.size(); ++i) {
-          total_len += prefill_lengths[i];
-          logit_positions.push_back(total_len - 1);
+        // Whether to use base model to get logits.
+        int sample_model_id = !models_[model_id]->CanGetLogits() ? 0 : model_id;
+
+        std::vector<int> logit_positions;
+        {
+          // Prepare the logit positions
+          logit_positions.reserve(prefill_lengths.size());
+          int total_len = 0;
+          for (int i = 0; i < prefill_lengths.size(); ++i) {
+            total_len += prefill_lengths[i];
+            logit_positions.push_back(total_len - 1);
+          }
         }
+        // hidden_states_for_sample: (b * s, h)
+        hidden_states_for_sample = models_[sample_model_id]->GatherHiddenStates(
+            hidden_states, logit_positions, &model_workspaces_[model_id].hidden_states);
+        // logits_for_sample: (b * s, v)
+        logits_for_sample = models_[sample_model_id]->GetLogits(hidden_states_for_sample);
+      } else if (engine_config_->speculative_mode == SpeculativeMode::kMedusa) {
+        // Note: spec_draft_length in the engine config has to match the model config in Medusa.
+        multi_step_logits = models_[model_id]->GetMultiStepLogits(hidden_states_for_sample);
+      } else {
+        LOG(FATAL) << "unreachable";
       }
-      // hidden_states_for_sample: (b * s, h)
-      hidden_states_for_sample = models_[sample_model_id]->GatherHiddenStates(
-          hidden_states, logit_positions, &model_workspaces_[model_id].hidden_states);
-      // logits_for_sample: (b * s, v)
-      logits_for_sample = models_[sample_model_id]->GetLogits(hidden_states_for_sample);
-      // - Update logits.
- ICHECK(logits_for_sample.defined()); - Array generation_cfg; - Array mstates_for_logitproc; - generation_cfg.reserve(num_rsentries); - mstates_for_logitproc.reserve(num_rsentries); - for (int i = 0; i < num_rsentries; ++i) { - generation_cfg.push_back(prefill_inputs[i].rsentry->request->generation_cfg); - mstates_for_logitproc.push_back(prefill_inputs[i].rsentry->mstates[sample_model_id]); - } - logit_processor_->InplaceUpdateLogits(logits_for_sample, generation_cfg, - mstates_for_logitproc, request_ids); - // - Compute probability distributions. - NDArray probs_on_device = - logit_processor_->ComputeProbsFromLogits(logits_for_sample, generation_cfg, request_ids); + Array request_ids_for_logitproc = request_ids; - // - Sample tokens. + // - Prepare the configurations for the sampler. // For prefill_inputs which have children, sample // one token for each rstate that is depending. // Otherwise, sample a token for the current rstate. @@ -209,12 +187,12 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { std::vector rsentries_for_sample; std::vector rngs; std::vector rsentry_activated; + Array generation_cfg; sample_indices.reserve(num_rsentries); rsentries_for_sample.reserve(num_rsentries); rngs.reserve(num_rsentries); rsentry_activated.reserve(num_rsentries); request_ids.clear(); - generation_cfg.clear(); for (int i = 0; i < num_rsentries; ++i) { const RequestStateEntry& rsentry = prefill_inputs[i].rsentry; // No sample for rsentries with remaining inputs. @@ -275,27 +253,26 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { } } - NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP( - probs_on_device, sample_indices, request_ids, generation_cfg); - std::vector sample_results = sampler_->BatchSampleTokensWithProbAfterTopP( - renormalized_probs, sample_indices, request_ids, generation_cfg, rngs); - ICHECK_EQ(sample_results.size(), rsentries_for_sample.size()); - - // - Update the committed tokens of states. - // - If a request is first-time prefilled, set the prefill finish time. - auto tnow = std::chrono::high_resolution_clock::now(); - if (model_id == 0) { - for (int i = 0; i < static_cast(rsentries_for_sample.size()); ++i) { - for (int mid = 0; mid < static_cast(models_.size()); ++mid) { - rsentries_for_sample[i]->mstates[mid]->CommitToken(sample_results[i]); - if (!rsentry_activated[i]) { - // When the child rsentry is not activated, - // add the sampled token as an input of the mstate for prefill. - rsentries_for_sample[i]->mstates[mid]->inputs.push_back( - TokenData(std::vector{sample_results[i].sampled_token_id.first})); - } - if (mid > 0) { - // Add the sampled token as an input of the eagle models. + // - Prepare input for logit processor. 
+ ICHECK(logits_for_sample.defined()); + Array generation_cfg_for_logitproc; + Array mstates_for_logitproc; + generation_cfg_for_logitproc.reserve(num_rsentries); + mstates_for_logitproc.reserve(num_rsentries); + for (int i = 0; i < num_rsentries; ++i) { + generation_cfg_for_logitproc.push_back(prefill_inputs[i].rsentry->request->generation_cfg); + mstates_for_logitproc.push_back(prefill_inputs[i].rsentry->mstates[model_id]); + } + if (model_id == 0 || engine_config_->speculative_mode == SpeculativeMode::kEagle) { + const auto& [renormalized_probs, sample_results] = ApplyLogitProcessorAndSample( + logit_processor_, sampler_, logits_for_sample, generation_cfg_for_logitproc, + request_ids_for_logitproc, mstates_for_logitproc, rngs, sample_indices); + if (model_id == 0) { + UpdateRequestStateEntriesWithSampleResults(rsentries_for_sample, rsentry_activated, + sample_results); + // Add the sampled token as an input of the eagle models. + for (int i = 0; i < static_cast(rsentries_for_sample.size()); ++i) { + for (int mid = 1; mid < static_cast(models_.size()); ++mid) { TokenData token_data = Downcast(rsentries_for_sample[i]->mstates[mid]->inputs.back()); std::vector token_ids = {token_data->token_ids.begin(), @@ -306,25 +283,21 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { ninputs - 1, TokenData(IntTuple(token_ids.begin(), token_ids.end()))); } } - // Only base model trigger timing records. - if (rsentries_for_sample[i]->mstates[0]->committed_tokens.size() == 1) { - rsentries_for_sample[i]->tprefill_finish = tnow; - } - } - } else { - // - Slice and save hidden_states_for_sample - draft_token_workspace_manager_->AllocSlots(rsentries_for_sample.size(), - &draft_token_slots_); - models_[model_id]->ScatterDraftProbs(renormalized_probs, draft_token_slots_, - &model_workspaces_[0].draft_probs_storage); - if (engine_config_->spec_draft_length > 1) { - models_[model_id]->ScatterHiddenStates(hidden_states_for_sample, draft_token_slots_, - &model_workspaces_[0].draft_hidden_states_storage); + } else { + // - Slice and save hidden_states_for_sample + UpdateRequestStatesWithDraftProposals(rsentries_for_sample, sample_results, model_id, + renormalized_probs, hidden_states_for_sample, + estate); } - for (int i = 0; i < static_cast(rsentries_for_sample.size()); ++i) { - rsentries_for_sample[i]->mstates[model_id]->AddDraftToken(sample_results[i], - draft_token_slots_[i]); - estate->stats.total_draft_length += 1; + } else if (engine_config_->speculative_mode == SpeculativeMode::kMedusa) { + for (int draft_id = 0; draft_id < engine_config_->spec_draft_length; ++draft_id) { + const auto& [renormalized_probs, sample_results] = ApplyLogitProcessorAndSample( + logit_processor_, sampler_, multi_step_logits[draft_id], generation_cfg_for_logitproc, + request_ids_for_logitproc, mstates_for_logitproc, rngs, sample_indices); + + UpdateRequestStatesWithDraftProposals(rsentries_for_sample, sample_results, model_id, + renormalized_probs, + /*hidden_states=*/ObjectRef{nullptr}, estate); } } } @@ -332,246 +305,32 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { auto tend = std::chrono::high_resolution_clock::now(); estate->stats.engine_total_prefill_time += static_cast((tend - tstart).count()) / 1e9; - // - Remove the request from waiting queue if all its request states - // are now alive and have no remaining chunked inputs. 
- std::vector processed_requests; - { - processed_requests.reserve(num_rsentries); - std::unordered_set dedup_map; - for (int i = 0; i < num_rsentries; ++i) { - const RequestStateEntry& rsentry = prefill_inputs[i].rsentry; - if (dedup_map.find(rsentry->request.get()) != dedup_map.end()) { - continue; - } - dedup_map.insert(rsentry->request.get()); - processed_requests.push_back(rsentry->request); - - bool pending_state_exists = false; - for (const RequestStateEntry& rsentry_ : rstates_of_entries[i]->entries) { - if (rsentry_->status == RequestStateStatus::kPending || - !rsentry_->mstates[0]->inputs.empty()) { - pending_state_exists = true; - break; - } - } - if (!pending_state_exists) { - auto it = std::find(estate->waiting_queue.begin(), estate->waiting_queue.end(), - rsentry->request); - ICHECK(it != estate->waiting_queue.end()); - estate->waiting_queue.erase(it); - } - } - } + std::vector processed_requests = + RemoveProcessedRequests(prefill_inputs, estate, rstates_of_entries); return processed_requests; } - private: - /*! \brief The class of request state entry and its maximum allowed length for prefill. */ - struct PrefillInput { - RequestStateEntry rsentry; - int max_prefill_length = 0; - int num_child_to_activate = 0; - }; - - /*! - * \brief Find one or multiple request state entries to run prefill. - * \param estate The engine state. - * \return The request entries to prefill, together with their input lengths. - */ - std::vector GetRequestStateEntriesToPrefill(EngineState estate) { - if (estate->waiting_queue.empty()) { - // No request to prefill. - return {}; + void UpdateRequestStatesWithDraftProposals( + const std::vector& rsentries_for_sample, + const std::vector& sample_results, int model_id, + const NDArray& renormalized_probs, const ObjectRef& hidden_states_for_sample, + EngineState estate) { + draft_token_workspace_manager_->AllocSlots(rsentries_for_sample.size(), &draft_token_slots_); + models_[0]->ScatterDraftProbs(renormalized_probs, draft_token_slots_, + &model_workspaces_[0].draft_probs_storage); + if (engine_config_->speculative_mode == SpeculativeMode::kEagle && + engine_config_->spec_draft_length > 1) { + models_[0]->ScatterHiddenStates(hidden_states_for_sample, draft_token_slots_, + &model_workspaces_[0].draft_hidden_states_storage); } - - std::vector prefill_inputs; - - // - Try to prefill pending requests. - int total_input_length = 0; - int total_required_pages = 0; - int num_available_pages = models_[0]->GetNumAvailablePages(); - int num_running_rsentries = GetRunningRequestStateEntries(estate).size(); - int current_total_seq_len = models_[0]->GetCurrentTotalSequenceLength(); - - int num_prefill_rsentries = 0; - for (const Request& request : estate->waiting_queue) { - RequestState rstate = estate->GetRequestState(request); - bool prefill_stops = false; - for (const RequestStateEntry& rsentry : rstate->entries) { - // A request state entry can be prefilled only when: - // - it has inputs, and - // - it has no parent or its parent is alive and has no remaining input. 
- if (rsentry->mstates[0]->inputs.empty() || - (rsentry->parent_idx != -1 && - (rstate->entries[rsentry->parent_idx]->status == RequestStateStatus::kPending || - !rstate->entries[rsentry->parent_idx]->mstates[0]->inputs.empty()))) { - continue; - } - - int input_length = rsentry->mstates[0]->GetInputLength(); - int num_require_pages = (input_length + engine_config_->kv_cache_page_size - 1) / - engine_config_->kv_cache_page_size; - total_input_length += input_length; - total_required_pages += num_require_pages; - // - Attempt 1. Check if the entire request state entry can fit for prefill. - bool can_prefill = false; - for (int num_child_to_activate = rsentry->child_indices.size(); num_child_to_activate >= 0; - --num_child_to_activate) { - if (CanPrefill(estate, num_prefill_rsentries + 1 + num_child_to_activate, - total_input_length, total_required_pages, num_available_pages, - current_total_seq_len, num_running_rsentries)) { - prefill_inputs.push_back({rsentry, input_length, num_child_to_activate}); - num_prefill_rsentries += 1 + num_child_to_activate; - can_prefill = true; - break; - } - } - if (can_prefill) { - continue; - } - total_input_length -= input_length; - total_required_pages -= num_require_pages; - - // - Attempt 2. Check if the request state entry can partially fit by input chunking. - ICHECK_LE(total_input_length, engine_config_->prefill_chunk_size); - if (engine_config_->prefill_chunk_size - total_input_length >= input_length || - engine_config_->prefill_chunk_size == total_input_length) { - // 1. If the input length can fit the remaining prefill chunk size, - // it means the failure of attempt 1 is not because of the input - // length being too long, and thus chunking does not help. - // 2. If the total input length already reaches the prefill chunk size, - // the current request state entry will not be able to be processed. - // So we can safely return in either case. - prefill_stops = true; - break; - } - input_length = engine_config_->prefill_chunk_size - total_input_length; - num_require_pages = (input_length + engine_config_->kv_cache_page_size - 1) / - engine_config_->kv_cache_page_size; - total_input_length += input_length; - total_required_pages += num_require_pages; - if (CanPrefill(estate, num_prefill_rsentries + 1, total_input_length, total_required_pages, - num_available_pages, current_total_seq_len, num_running_rsentries)) { - prefill_inputs.push_back({rsentry, input_length, 0}); - num_prefill_rsentries += 1; - } - - // - Prefill stops here. - prefill_stops = true; - break; - } - if (prefill_stops) { - break; - } + for (int i = 0; i < static_cast(rsentries_for_sample.size()); ++i) { + rsentries_for_sample[i]->mstates[model_id]->AddDraftToken(sample_results[i], + draft_token_slots_[i]); + estate->stats.total_draft_length += 1; } - - return prefill_inputs; - } - - /*! \brief Check if the input requests can be prefilled under conditions. */ - bool CanPrefill(EngineState estate, int num_prefill_rsentries, int total_input_length, - int num_required_pages, int num_available_pages, int current_total_seq_len, - int num_running_rsentries) { - ICHECK_LE(num_running_rsentries, engine_config_->max_num_sequence); - - // No exceeding of the maximum allowed requests that can - // run simultaneously. - int spec_factor = engine_config_->speculative_mode != SpeculativeMode::kDisable - ? 
(engine_config_->spec_draft_length + 1) - : 1; - if ((num_running_rsentries + num_prefill_rsentries) * spec_factor > - std::min(engine_config_->max_num_sequence, engine_config_->prefill_chunk_size)) { - return false; - } - - // NOTE: The conditions are heuristic and can be revised. - // Cond 1: total input length <= prefill chunk size. - // Cond 2: at least one decode can be performed after prefill. - // Cond 3: number of total tokens after 8 times of decode does not - // exceed the limit, where 8 is a watermark number can - // be configured and adjusted in the future. - int new_batch_size = num_running_rsentries + num_prefill_rsentries; - return total_input_length <= engine_config_->prefill_chunk_size && - num_required_pages + new_batch_size <= num_available_pages && - current_total_seq_len + total_input_length + 8 * new_batch_size <= - engine_config_->max_total_sequence_length; } - /*! - * \brief Chunk the input of the given RequestModelState for prefill - * with regard to the provided maximum allowed prefill length. - * Return the list of input for prefill and the total prefill length. - * The `inputs` field of the given `mstate` will be mutated to exclude - * the returned input. - * \param mstate The RequestModelState whose input data is to be chunked. - * \param max_prefill_length The maximum allowed prefill length for the mstate. - * \return The list of input for prefill and the total prefill length. - */ - std::pair, int> ChunkPrefillInputData(const RequestModelState& mstate, - int max_prefill_length) { - if (mstate->inputs.empty()) { - } - ICHECK(!mstate->inputs.empty()); - std::vector inputs; - int cum_input_length = 0; - inputs.reserve(mstate->inputs.size()); - for (int i = 0; i < static_cast(mstate->inputs.size()); ++i) { - inputs.push_back(mstate->inputs[i]); - int input_length = mstate->inputs[i]->GetLength(); - cum_input_length += input_length; - // Case 0. the cumulative input length does not reach the maximum prefill length. - if (cum_input_length < max_prefill_length) { - continue; - } - - // Case 1. the cumulative input length equals the maximum prefill length. - if (cum_input_length == max_prefill_length) { - if (i == static_cast(mstate->inputs.size()) - 1) { - // - If `i` is the last input, we just copy and reset `mstate->inputs`. - mstate->inputs.clear(); - } else { - // - Otherwise, set the new input array. - mstate->inputs = Array{mstate->inputs.begin() + i + 1, mstate->inputs.end()}; - } - return {inputs, cum_input_length}; - } - - // Case 2. cum_input_length > max_prefill_length - // The input `i` itself needs chunking if it is TokenData, - // or otherwise it cannot be chunked. - Data input = mstate->inputs[i]; - inputs.pop_back(); - cum_input_length -= input_length; - const auto* token_input = input.as(); - if (token_input == nullptr) { - // Cannot chunk the input. - if (i != 0) { - mstate->inputs = Array{mstate->inputs.begin() + i, mstate->inputs.end()}; - } - return {inputs, cum_input_length}; - } - - // Split the token data into two parts. - // Return the first part for prefill, and keep the second part. 
- int chunked_input_length = max_prefill_length - cum_input_length; - ICHECK_GT(input_length, chunked_input_length); - TokenData chunked_input(IntTuple{token_input->token_ids.begin(), - token_input->token_ids.begin() + chunked_input_length}); - TokenData remaining_input(IntTuple{token_input->token_ids.begin() + chunked_input_length, - token_input->token_ids.end()}); - inputs.push_back(chunked_input); - cum_input_length += chunked_input_length; - std::vector remaining_inputs{mstate->inputs.begin() + i + 1, mstate->inputs.end()}; - remaining_inputs.insert(remaining_inputs.begin(), remaining_input); - mstate->inputs = remaining_inputs; - return {inputs, cum_input_length}; - } - - ICHECK(false) << "Cannot reach here"; - } - - /*! \brief The models to run prefill in. */ - Array models_; + private: /*! \brief The logit processor. */ LogitProcessor logit_processor_; /*! \brief The sampler to sample new tokens. */ @@ -580,10 +339,6 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { std::vector model_workspaces_; /*! \brief The draft token workspace manager. */ DraftTokenWorkspaceManager draft_token_workspace_manager_; - /*! \brief The engine config. */ - EngineConfig engine_config_; - /*! \brief Event trace recorder. */ - Optional trace_recorder_; /*! \brief Temporary buffer to store the slots of the current draft tokens */ std::vector draft_token_slots_; }; diff --git a/cpp/serve/engine_actions/new_request_prefill.cc b/cpp/serve/engine_actions/new_request_prefill.cc index 5a5847aaa0..038a6cc66c 100644 --- a/cpp/serve/engine_actions/new_request_prefill.cc +++ b/cpp/serve/engine_actions/new_request_prefill.cc @@ -10,6 +10,7 @@ #include "../sampler/sampler.h" #include "action.h" #include "action_commons.h" +#include "batch_prefill_base.h" namespace mlc { namespace llm { @@ -19,18 +20,17 @@ namespace serve { * \brief The action that prefills requests in the `waiting_queue` of * the engine state. */ -class NewRequestPrefillActionObj : public EngineActionObj { +class NewRequestPrefillActionObj : public BatchPrefillBaseActionObj { public: explicit NewRequestPrefillActionObj(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, EngineConfig engine_config, Optional trace_recorder) - : models_(std::move(models)), + : BatchPrefillBaseActionObj(std::move(models), std::move(engine_config), + std::move(trace_recorder)), logit_processor_(std::move(logit_processor)), sampler_(std::move(sampler)), - model_workspaces_(std::move(model_workspaces)), - engine_config_(std::move(engine_config)), - trace_recorder_(std::move(trace_recorder)) {} + model_workspaces_(std::move(model_workspaces)) {} Array Step(EngineState estate) final { // - Find the requests in `waiting_queue` that can prefill in this step. 
@@ -50,32 +50,8 @@ class NewRequestPrefillActionObj : public EngineActionObj { Array request_ids; std::vector rstates_of_entries; std::vector status_before_prefill; - request_ids.reserve(num_rsentries); - rstates_of_entries.reserve(num_rsentries); - status_before_prefill.reserve(num_rsentries); - for (const PrefillInput& prefill_input : prefill_inputs) { - const RequestStateEntry& rsentry = prefill_input.rsentry; - const Request& request = rsentry->request; - RequestState request_rstate = estate->GetRequestState(request); - request_ids.push_back(request->id); - status_before_prefill.push_back(rsentry->status); - rsentry->status = RequestStateStatus::kAlive; - - if (status_before_prefill.back() == RequestStateStatus::kPending) { - // - Add the request to running queue if the request state - // status was pending and all its request states were pending. - bool alive_state_existed = false; - for (const RequestStateEntry& rsentry_ : request_rstate->entries) { - if (rsentry_->status == RequestStateStatus::kAlive && !rsentry_.same_as(rsentry)) { - alive_state_existed = true; - } - } - if (!alive_state_existed) { - estate->running_queue.push_back(request); - } - } - rstates_of_entries.push_back(std::move(request_rstate)); - } + UpdateRequestToAlive(prefill_inputs, estate, &request_ids, &rstates_of_entries, + &status_before_prefill); // - Get embedding and run prefill for each model. std::vector prefill_lengths; @@ -237,280 +213,24 @@ class NewRequestPrefillActionObj : public EngineActionObj { // - Update the committed tokens of states. // - If a request is first-time prefilled, set the prefill finish time. - auto tnow = std::chrono::high_resolution_clock::now(); - for (int i = 0; i < static_cast(rsentries_for_sample.size()); ++i) { - for (const RequestModelState& mstate : rsentries_for_sample[i]->mstates) { - mstate->CommitToken(sample_results[i]); - if (!rsentry_activated[i]) { - // When the child rsentry is not activated, - // add the sampled token as an input of the mstate for prefill. - mstate->inputs.push_back( - TokenData(std::vector{sample_results[i].sampled_token_id.first})); - } - } - if (rsentries_for_sample[i]->mstates[0]->committed_tokens.size() == 1) { - rsentries_for_sample[i]->tprefill_finish = tnow; - } - } + UpdateRequestStateEntriesWithSampleResults(rsentries_for_sample, rsentry_activated, + sample_results); auto tend = std::chrono::high_resolution_clock::now(); estate->stats.engine_total_prefill_time += static_cast((tend - tstart).count()) / 1e9; - // - Remove the request from waiting queue if all its request states - // are now alive and have no remaining chunked inputs. 
- std::vector processed_requests; - { - processed_requests.reserve(num_rsentries); - std::unordered_set dedup_map; - for (int i = 0; i < num_rsentries; ++i) { - const RequestStateEntry& rsentry = prefill_inputs[i].rsentry; - if (dedup_map.find(rsentry->request.get()) != dedup_map.end()) { - continue; - } - dedup_map.insert(rsentry->request.get()); - processed_requests.push_back(rsentry->request); - - bool pending_state_exists = false; - for (const RequestStateEntry& rsentry_ : rstates_of_entries[i]->entries) { - if (rsentry_->status == RequestStateStatus::kPending || - !rsentry_->mstates[0]->inputs.empty()) { - pending_state_exists = true; - break; - } - } - if (!pending_state_exists) { - auto it = std::find(estate->waiting_queue.begin(), estate->waiting_queue.end(), - rsentry->request); - ICHECK(it != estate->waiting_queue.end()); - estate->waiting_queue.erase(it); - } - } - } + std::vector processed_requests = + RemoveProcessedRequests(prefill_inputs, estate, rstates_of_entries); return processed_requests; } private: - /*! \brief The class of request state entry and its maximum allowed length for prefill. */ - struct PrefillInput { - RequestStateEntry rsentry; - int max_prefill_length = 0; - int num_child_to_activate = 0; - }; - - /*! - * \brief Find one or multiple request state entries to run prefill. - * \param estate The engine state. - * \return The request entries to prefill, together with their input lengths. - */ - std::vector GetRequestStateEntriesToPrefill(EngineState estate) { - if (estate->waiting_queue.empty()) { - // No request to prefill. - return {}; - } - - std::vector prefill_inputs; - - // - Try to prefill pending requests. - int total_input_length = 0; - int total_required_pages = 0; - int num_available_pages = models_[0]->GetNumAvailablePages(); - int num_running_rsentries = GetRunningRequestStateEntries(estate).size(); - int current_total_seq_len = models_[0]->GetCurrentTotalSequenceLength(); - - int num_prefill_rsentries = 0; - for (const Request& request : estate->waiting_queue) { - RequestState rstate = estate->GetRequestState(request); - bool prefill_stops = false; - for (const RequestStateEntry& rsentry : rstate->entries) { - // A request state entry can be prefilled only when: - // - it has inputs, and - // - it has no parent or its parent is alive and has no remaining input. - if (rsentry->mstates[0]->inputs.empty() || - (rsentry->parent_idx != -1 && - (rstate->entries[rsentry->parent_idx]->status == RequestStateStatus::kPending || - !rstate->entries[rsentry->parent_idx]->mstates[0]->inputs.empty()))) { - continue; - } - - int input_length = rsentry->mstates[0]->GetInputLength(); - int num_require_pages = (input_length + engine_config_->kv_cache_page_size - 1) / - engine_config_->kv_cache_page_size; - total_input_length += input_length; - total_required_pages += num_require_pages; - // - Attempt 1. Check if the entire request state entry can fit for prefill. 
- bool can_prefill = false; - for (int num_child_to_activate = rsentry->child_indices.size(); num_child_to_activate >= 0; - --num_child_to_activate) { - if (CanPrefill(estate, num_prefill_rsentries + 1 + num_child_to_activate, - total_input_length, total_required_pages, num_available_pages, - current_total_seq_len, num_running_rsentries)) { - prefill_inputs.push_back({rsentry, input_length, num_child_to_activate}); - num_prefill_rsentries += 1 + num_child_to_activate; - can_prefill = true; - break; - } - } - if (can_prefill) { - continue; - } - total_input_length -= input_length; - total_required_pages -= num_require_pages; - - // - Attempt 2. Check if the request state entry can partially fit by input chunking. - ICHECK_LE(total_input_length, engine_config_->prefill_chunk_size); - if (engine_config_->prefill_chunk_size - total_input_length >= input_length || - engine_config_->prefill_chunk_size == total_input_length) { - // 1. If the input length can fit the remaining prefill chunk size, - // it means the failure of attempt 1 is not because of the input - // length being too long, and thus chunking does not help. - // 2. If the total input length already reaches the prefill chunk size, - // the current request state entry will not be able to be processed. - // So we can safely return in either case. - prefill_stops = true; - break; - } - input_length = engine_config_->prefill_chunk_size - total_input_length; - num_require_pages = (input_length + engine_config_->kv_cache_page_size - 1) / - engine_config_->kv_cache_page_size; - total_input_length += input_length; - total_required_pages += num_require_pages; - if (CanPrefill(estate, num_prefill_rsentries + 1, total_input_length, total_required_pages, - num_available_pages, current_total_seq_len, num_running_rsentries)) { - prefill_inputs.push_back({rsentry, input_length, 0}); - num_prefill_rsentries += 1; - } - - // - Prefill stops here. - prefill_stops = true; - break; - } - if (prefill_stops) { - break; - } - } - - return prefill_inputs; - } - - /*! \brief Check if the input requests can be prefilled under conditions. */ - bool CanPrefill(EngineState estate, int num_prefill_rsentries, int total_input_length, - int num_required_pages, int num_available_pages, int current_total_seq_len, - int num_running_rsentries) { - ICHECK_LE(num_running_rsentries, engine_config_->max_num_sequence); - - // For RNN State, it can prefill as long as it can be instantiated. - if (engine_config_->kv_state_kind == KVStateKind::kRNNState) { - return true; - } - - // No exceeding of the maximum allowed requests that can - // run simultaneously. - int spec_factor = engine_config_->speculative_mode != SpeculativeMode::kDisable - ? (engine_config_->spec_draft_length + 1) - : 1; - if ((num_running_rsentries + num_prefill_rsentries) * spec_factor > - std::min(engine_config_->max_num_sequence, engine_config_->prefill_chunk_size)) { - return false; - } - - // NOTE: The conditions are heuristic and can be revised. - // Cond 1: total input length <= prefill chunk size. - // Cond 2: at least one decode can be performed after prefill. - // Cond 3: number of total tokens after 8 times of decode does not - // exceed the limit, where 8 is a watermark number can - // be configured and adjusted in the future. 
- int new_batch_size = num_running_rsentries + num_prefill_rsentries; - return total_input_length <= engine_config_->prefill_chunk_size && - num_required_pages + new_batch_size <= num_available_pages && - current_total_seq_len + total_input_length + 8 * new_batch_size <= - engine_config_->max_total_sequence_length; - } - - /*! - * \brief Chunk the input of the given RequestModelState for prefill - * with regard to the provided maximum allowed prefill length. - * Return the list of input for prefill and the total prefill length. - * The `inputs` field of the given `mstate` will be mutated to exclude - * the returned input. - * \param mstate The RequestModelState whose input data is to be chunked. - * \param max_prefill_length The maximum allowed prefill length for the mstate. - * \return The list of input for prefill and the total prefill length. - */ - std::pair, int> ChunkPrefillInputData(const RequestModelState& mstate, - int max_prefill_length) { - if (mstate->inputs.empty()) { - } - ICHECK(!mstate->inputs.empty()); - std::vector inputs; - int cum_input_length = 0; - inputs.reserve(mstate->inputs.size()); - for (int i = 0; i < static_cast(mstate->inputs.size()); ++i) { - inputs.push_back(mstate->inputs[i]); - int input_length = mstate->inputs[i]->GetLength(); - cum_input_length += input_length; - // Case 0. the cumulative input length does not reach the maximum prefill length. - if (cum_input_length < max_prefill_length) { - continue; - } - - // Case 1. the cumulative input length equals the maximum prefill length. - if (cum_input_length == max_prefill_length) { - if (i == static_cast(mstate->inputs.size()) - 1) { - // - If `i` is the last input, we just copy and reset `mstate->inputs`. - mstate->inputs.clear(); - } else { - // - Otherwise, set the new input array. - mstate->inputs = Array{mstate->inputs.begin() + i + 1, mstate->inputs.end()}; - } - return {inputs, cum_input_length}; - } - - // Case 2. cum_input_length > max_prefill_length - // The input `i` itself needs chunking if it is TokenData, - // or otherwise it cannot be chunked. - Data input = mstate->inputs[i]; - inputs.pop_back(); - cum_input_length -= input_length; - const auto* token_input = input.as(); - if (token_input == nullptr) { - // Cannot chunk the input. - if (i != 0) { - mstate->inputs = Array{mstate->inputs.begin() + i, mstate->inputs.end()}; - } - return {inputs, cum_input_length}; - } - - // Split the token data into two parts. - // Return the first part for prefill, and keep the second part. - int chunked_input_length = max_prefill_length - cum_input_length; - ICHECK_GT(input_length, chunked_input_length); - TokenData chunked_input(IntTuple{token_input->token_ids.begin(), - token_input->token_ids.begin() + chunked_input_length}); - TokenData remaining_input(IntTuple{token_input->token_ids.begin() + chunked_input_length, - token_input->token_ids.end()}); - inputs.push_back(chunked_input); - cum_input_length += chunked_input_length; - std::vector remaining_inputs{mstate->inputs.begin() + i + 1, mstate->inputs.end()}; - remaining_inputs.insert(remaining_inputs.begin(), remaining_input); - mstate->inputs = remaining_inputs; - return {inputs, cum_input_length}; - } - - ICHECK(false) << "Cannot reach here"; - } - - /*! \brief The models to run prefill in. */ - Array models_; /*! \brief The logit processor. */ LogitProcessor logit_processor_; /*! \brief The sampler to sample new tokens. */ Sampler sampler_; /*! \brief Workspace of each model. */ std::vector model_workspaces_; - /*! \brief The engine config. 
*/ - EngineConfig engine_config_; - /*! \brief Event trace recorder. */ - Optional trace_recorder_; }; EngineAction EngineAction::NewRequestPrefill(Array models, LogitProcessor logit_processor, diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc index bdf28dfdb5..2ed864f298 100644 --- a/cpp/serve/function_table.cc +++ b/cpp/serve/function_table.cc @@ -232,7 +232,6 @@ void FunctionTable::_InitFunctions() { } else { this->create_kv_cache_func_ = mod_get_func("create_tir_paged_kv_cache"); } - ICHECK(this->create_kv_cache_func_.defined()); } this->reset_kv_cache_func_ = get_global_func("vm.builtin.kv_state_clear"); this->kv_cache_add_sequence_func_ = get_global_func("vm.builtin.kv_state_add_sequence"); @@ -271,7 +270,7 @@ ObjectRef FunctionTable::Empty(ShapeTuple shape, DataType dtype, Device device) Device null_device{DLDeviceType(0), 0}; if (this->use_disco) { DRef empty_func = sess->GetGlobalFunc("runtime.disco.empty"); - return sess->CallPacked(empty_func, shape, dtype, null_device); + return sess->CallPacked(empty_func, shape, dtype, null_device, false); } else { return NDArray::Empty(shape, dtype, device); } diff --git a/cpp/serve/grammar/grammar.cc b/cpp/serve/grammar/grammar.cc index c8d760538c..2f0d7f565f 100644 --- a/cpp/serve/grammar/grammar.cc +++ b/cpp/serve/grammar/grammar.cc @@ -5,9 +5,9 @@ #include "grammar.h" +#include "grammar_functor.h" #include "grammar_parser.h" #include "grammar_serializer.h" -#include "grammar_simplifier.h" #include "json_schema_converter.h" namespace mlc { @@ -21,18 +21,28 @@ std::ostream& operator<<(std::ostream& os, const BNFGrammar& grammar) { return os; } -BNFGrammar BNFGrammar::FromEBNFString(const std::string& ebnf_string, const std::string& main_rule, - bool normalize, bool simplify) { +BNFGrammar BNFGrammar::FromEBNFString(const std::string& ebnf_string, + const std::string& main_rule) { auto grammar = EBNFParser::Parse(ebnf_string, main_rule); - if (normalize) { - grammar = NestedRuleUnwrapper(grammar).Apply(); - } + // Normalize the grammar by default + grammar = BNFGrammarNormalizer().Apply(grammar); return grammar; } TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarFromEBNFString") - .set_body_typed([](String ebnf_string, String main_rule, bool normalize, bool simplify) { - return BNFGrammar::FromEBNFString(ebnf_string, main_rule, normalize, simplify); + .set_body_typed([](String ebnf_string, String main_rule) { + return BNFGrammar::FromEBNFString(ebnf_string, main_rule); + }); + +// Parse the EBNF string but not normalize it +BNFGrammar DebugFromEBNFStringNoNormalize(const std::string& ebnf_string, + const std::string& main_rule) { + return EBNFParser::Parse(ebnf_string, main_rule); +} + +TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarDebugFromEBNFStringNoNormalize") + .set_body_typed([](String ebnf_string, String main_rule) { + return DebugFromEBNFStringNoNormalize(ebnf_string, main_rule); }); BNFGrammar BNFGrammar::FromJSON(const std::string& json_string) { @@ -69,79 +79,90 @@ TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarFromSchema").set_body([](TVMArgs args, *rv = BNFGrammar::FromSchema(args[0], indent, separators, args[3]); }); +// Optimized json grammar for the speed of the grammar state matcher const std::string kJSONGrammarString = R"( main ::= ( - "{" ws members_or_embrace | - "[" ws elements_or_embrace + "{" [ \n\t]* members_and_embrace | + "[" [ \n\t]* elements_or_embrace ) -value ::= ( - "{" ws members_or_embrace | - "[" ws elements_or_embrace | - "\"" characters "\"" | - [0-9] fraction exponent | - [1-9] digits fraction 
exponent | +value_non_str ::= ( + "{" [ \n\t]* members_and_embrace | + "[" [ \n\t]* elements_or_embrace | + "0" fraction exponent | + [1-9] [0-9]* fraction exponent | "-" [0-9] fraction exponent | - "-" [1-9] digits fraction exponent | + "-" [1-9] [0-9]* fraction exponent | "true" | "false" | "null" -) -members_or_embrace ::= ( - "\"" characters "\"" ws ":" ws value members_rest ws "}" | - "}" -) -members ::= "\"" characters "\"" ws ":" ws value members_rest -members_rest ::= ( - "" | - "," ws "\"" characters "\"" ws ":" ws value members_rest | - " " ws "," ws "\"" characters "\"" ws ":" ws value members_rest | - "\n" ws "," ws "\"" characters "\"" ws ":" ws value members_rest | - "\t" ws "," ws "\"" characters "\"" ws ":" ws value members_rest -) +) (= [ \n\t,}\]]) +members_and_embrace ::= ("\"" characters_and_colon [ \n\t]* members_suffix | "}") (= [ \n\t,}\]]) +members_suffix ::= ( + value_non_str [ \n\t]* member_suffix_suffix | + "\"" characters_and_embrace | + "\"" characters_and_comma [ \n\t]* "\"" characters_and_colon [ \n\t]* members_suffix +) (= [ \n\t,}\]]) +member_suffix_suffix ::= ( + "}" | + "," [ \n\t]* "\"" characters_and_colon [ \n\t]* members_suffix +) (= [ \n\t,}\]]) elements_or_embrace ::= ( - "{" ws members_or_embrace elements_rest ws "]" | - "[" ws elements_or_embrace elements_rest ws "]" | - "\"" characters "\"" elements_rest ws "]" | - [0-9] fraction exponent elements_rest ws "]" | - [1-9] digits fraction exponent elements_rest ws "]" | - "-" [0-9] fraction exponent elements_rest ws "]" | - "-" [1-9] digits fraction exponent elements_rest ws "]" | - "true" elements_rest ws "]" | - "false" elements_rest ws "]" | - "null" elements_rest ws "]" | + "{" [ \n\t]* members_and_embrace elements_rest [ \n\t]* "]" | + "[" [ \n\t]* elements_or_embrace elements_rest [ \n\t]* "]" | + "\"" characters_item elements_rest [ \n\t]* "]" | + "0" fraction exponent elements_rest [ \n\t]* "]" | + [1-9] [0-9]* fraction exponent elements_rest [ \n\t]* "]" | + "-" "0" fraction exponent elements_rest [ \n\t]* "]" | + "-" [1-9] [0-9]* fraction exponent elements_rest [ \n\t]* "]" | + "true" elements_rest [ \n\t]* "]" | + "false" elements_rest [ \n\t]* "]" | + "null" elements_rest [ \n\t]* "]" | "]" ) elements ::= ( - "{" ws members_or_embrace elements_rest | - "[" ws elements_or_embrace elements_rest | - "\"" characters "\"" elements_rest | - [0-9] fraction exponent elements_rest | - [1-9] digits fraction exponent elements_rest | + "{" [ \n\t]* members_and_embrace elements_rest | + "[" [ \n\t]* elements_or_embrace elements_rest | + "\"" characters_item elements_rest | + "0" fraction exponent elements_rest | + [1-9] [0-9]* fraction exponent elements_rest | "-" [0-9] fraction exponent elements_rest | - "-" [1-9] digits fraction exponent elements_rest | + "-" [1-9] [0-9]* fraction exponent elements_rest | "true" elements_rest | "false" elements_rest | "null" elements_rest ) elements_rest ::= ( "" | - "," ws elements | - " " ws "," ws elements | - "\n" ws "," ws elements | - "\t" ws "," ws elements + [ \n\t]* "," [ \n\t]* elements ) -characters ::= "" | [^"\\\r\n] characters | "\\" escape characters +characters_and_colon ::= ( + "\"" [ \n\t]* ":" | + [^"\\\x00-\x1F] characters_and_colon | + "\\" escape characters_and_colon +) (=[ \n\t]* [\"{[0-9tfn-]) +characters_and_comma ::= ( + "\"" [ \n\t]* "," | + [^"\\\x00-\x1F] characters_and_comma | + "\\" escape characters_and_comma +) (=[ \n\t]* "\"") +characters_and_embrace ::= ( + "\"" [ \n\t]* "}" | + [^"\\\x00-\x1F] characters_and_embrace | + "\\" 
escape characters_and_embrace +) (=[ \n\t]* [},]) +characters_item ::= ( + "\"" | + [^"\\\x00-\x1F] characters_item | + "\\" escape characters_item +) (= [ \n\t]* [,\]]) escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] -digits ::= [0-9] | [0-9] digits -fraction ::= "" | "." digits -exponent ::= "" | "e" sign digits | "E" sign digits +fraction ::= "" | "." [0-9] [0-9]* +exponent ::= "" | "e" sign [0-9] [0-9]* | "E" sign [0-9] [0-9]* sign ::= "" | "+" | "-" -ws ::= [ \n\t]* )"; BNFGrammar BNFGrammar::GetGrammarOfJSON() { - static const BNFGrammar grammar = - BNFGrammar::FromEBNFString(kJSONGrammarString, "main", true, false); + static const BNFGrammar grammar = BNFGrammar::FromEBNFString(kJSONGrammarString, "main"); return grammar; } diff --git a/cpp/serve/grammar/grammar.h b/cpp/serve/grammar/grammar.h index ba15e58af3..b7922301cb 100644 --- a/cpp/serve/grammar/grammar.h +++ b/cpp/serve/grammar/grammar.h @@ -44,16 +44,15 @@ using namespace tvm::runtime; * #### Types of RuleExprs * Every RuleExpr is represented by a type as well as a variable-length array containing its data. * RuleExpr has several types: + * - Byte string: a string of bytes (0~255). Supports UTF-8 strings. * - Character class: a range of characters (each character is a unicode codepoint), e.g. [a-z], - * [ac-z]. - * A single character is represented by a character class with the same lower and upper bound. - * A string is represented by a sequence of character classes. - * - Negated character class: all characters that are not in the range, e.g. [^a-z], [^ac-z] + * [ac-z]. Can be negated: [^a-z], [^ac-z]. For now only ASCII characters are allowed in [], but + * the expression can still accept or reject unicode characters. + * - Character class star: a star quantifier of a character class. e.g. [a-z]*, [^a-z]*. * - EmptyStr: an empty string, i.e. "" * - Rule reference: a reference to another rule * - Sequence: a sequence of rule_exprs, e.g. ("a" "b"). These rule_exprs are concatenated together. * - Choices: a choice of rule_exprs, e.g. ("a" "b") | "c". Each rule_expr can be matched. - * - Character class star: special support for a repetition of a character class. e.g. [a-z]* * * #### Storage of RuleExprs * Each type of RuleExpr has a different data format. For the format of each type of RuleExpr, see @@ -76,6 +75,9 @@ class BNFGrammarNode : public Object { std::string name; /*! \brief The RuleExpr id of the body of the rule. */ int32_t body_expr_id; + /*! \brief The id of the associated lookahead assertion expr. For now it must be the id of a + * sequence RuleExpr; -1 if it does not exist. */ + int32_t lookahead_assertion_id = -1; }; /*! \brief Get the number of rules. */ @@ -86,6 +88,8 @@ class BNFGrammarNode : public Object { << "rule_id " << rule_id << " is out of bound"; return rules_[rule_id]; } + /*! \brief Get the main rule id of the grammar. */ + int32_t GetMainRuleId() const { return main_rule_id_; } /*! \brief Get the main rule of the grammar. */ const Rule& GetMainRule() const { DCHECK(main_rule_id_ >= 0 && main_rule_id_ < static_cast<int32_t>(rules_.size())) @@ -95,10 +99,11 @@ /*! \brief The type of the rule expr. */ enum class RuleExprType : int32_t { - // data format: [lower0, upper0, lower1, upper1, ...] + // data format: [byte0, byte1, ...] + kByteString, + // data format: [is_negative, lower0, upper0, lower1, upper1, ...] kCharacterClass, - // data format: [lower0, upper0, lower1, upper1, ...] - kNegCharacterClass, + kCharacterClassStar, // data format: [] kEmptyStr, // data format: [rule_id] @@ -107,8 +112,6 @@ kSequence, // data format: [rule_expr_id0, rule_expr_id1, ...] kChoices, - // data format: [rule_expr_id] - kCharacterClassStar, }; /*! \brief The object representing a rule expr. */ @@ -154,8 +157,8 @@ class BNFGrammarNode : public Object { std::vector<Rule> rules_; /*! \brief The data of all rule_exprs. */ std::vector<int32_t> rule_expr_data_; - /*! \brief The start index of every rule_expr in rule_expr_data_. rule_expr_id corresponds the - * index of this vector. */ + /*! \brief The start index of every rule_expr in rule_expr_data_. rule_expr_id is the index + * into this vector. */ std::vector<int32_t> rule_expr_indptr_; /*! \brief The id of the main rule. */ int32_t main_rule_id_ = -1; @@ -168,25 +171,13 @@ class BNFGrammar : public ObjectRef { public: /*! - * \brief Construct a BNF grammar with a EBNF-formatted string. Will parse the string and - * transform it into BNF AST. + * \brief Construct a BNF grammar with an EBNF-formatted string. The grammar will be normalized + * (simplified) by default. * \param ebnf_string The EBNF-formatted string. * \param main_rule The name of the main rule. - * \param normalize Whether to normalize the grammar. Default: true. Only set to false for the - * purpose of testing. - * - * \note In The normalized form of a BNF grammar, every rule is in the form: - * `rule_name ::= ("" | (element1_1 element1_2 ...) | (element2_1 element2_2 ...) | ...)`. - * - * I.e. a list of choices, each choice is a sequence of elements. Elements can be a character - * class or a rule reference. And if the rule can be empty, the first choice will be an empty - * string. - * \param simplify Whether to simplify the grammar to make matching more efficient. Default: true. - * Not implemented yet. */ static BNFGrammar FromEBNFString(const std::string& ebnf_string, - const std::string& main_rule = "main", bool normalize = true, - bool simplify = true); + const std::string& main_rule = "main"); /*! * \brief Construct a BNF grammar from the dumped JSON string. diff --git a/cpp/serve/grammar/grammar_builder.h b/cpp/serve/grammar/grammar_builder.h index 0854cc9789..7987a67f98 100644 --- a/cpp/serve/grammar/grammar_builder.h +++ b/cpp/serve/grammar/grammar_builder.h @@ -56,6 +56,16 @@ class BNFGrammarBuilder { return static_cast<int32_t>(grammar_->rule_expr_indptr_.size()) - 1; } + /*! + * \brief Add a RuleExpr for a string stored in bytes. + * \param bytes A vector of int32_t, each representing a byte (0~255) in the string. + * The string is stored in an int32 vector to match the storage format of the grammar. + */ + int32_t AddByteString(const std::vector<int32_t>& bytes) { + return AddRuleExpr( + {RuleExprType::kByteString, bytes.data(), static_cast<int32_t>(bytes.size())}); + } + /*! * \brief One element of a character class, containing a lower and a upper bound. Both bounds are * inclusive. @@ -66,19 +76,39 @@ class BNFGrammarBuilder { }; /*! - * \brief Add a RuleExpr for character class. + * \brief Add a RuleExpr for a character class. * \param elements A vector of CharacterClassElement, each containing a lower and a upper bound. - * \param is_neg_range Whether the character class is negated. + * \param is_negative Whether the character class is negated. */ int32_t AddCharacterClass(const std::vector<CharacterClassElement>& elements, - bool is_neg_range = false) { + bool is_negative = false) { std::vector<int32_t> data; + data.reserve(1 + elements.size() * 2); + data.push_back(static_cast<int32_t>(is_negative)); for (const auto& range : elements) { data.push_back(range.lower); data.push_back(range.upper); } - auto type = is_neg_range ? RuleExprType::kNegCharacterClass : RuleExprType::kCharacterClass; - return AddRuleExpr({type, data.data(), static_cast<int32_t>(data.size())}); + return AddRuleExpr( + {RuleExprType::kCharacterClass, data.data(), static_cast<int32_t>(data.size())}); + } + + /*! + * \brief Add a RuleExpr for a star quantifier of a character class. + * \param elements A vector of CharacterClassElement, each containing a lower and a upper bound. + * \param is_negative Whether the character class is negated. + */ + int32_t AddCharacterClassStar(const std::vector<CharacterClassElement>& elements, + bool is_negative = false) { + std::vector<int32_t> data; + data.reserve(1 + elements.size() * 2); + data.push_back(static_cast<int32_t>(is_negative)); + for (const auto& range : elements) { + data.push_back(range.lower); + data.push_back(range.upper); + } + return AddRuleExpr( + {RuleExprType::kCharacterClassStar, data.data(), static_cast<int32_t>(data.size())}); } /*! \brief Add a RuleExpr for empty string.*/ @@ -93,23 +123,14 @@ class BNFGrammarBuilder { /*! \brief Add a RuleExpr for RuleExpr sequence.*/ int32_t AddSequence(const std::vector<int32_t>& elements) { - std::vector<int32_t> data; - data.insert(data.end(), elements.begin(), elements.end()); - return AddRuleExpr({RuleExprType::kSequence, data.data(), static_cast<int32_t>(data.size())}); + return AddRuleExpr( + {RuleExprType::kSequence, elements.data(), static_cast<int32_t>(elements.size())}); } /*! \brief Add a RuleExpr for RuleExpr choices.*/ int32_t AddChoices(const std::vector<int32_t>& choices) { - std::vector<int32_t> data; - data.insert(data.end(), choices.begin(), choices.end()); - return AddRuleExpr({RuleExprType::kChoices, data.data(), static_cast<int32_t>(data.size())}); - } - - int32_t AddCharacterClassStar(int32_t element) { - std::vector<int32_t> data; - data.push_back(element); return AddRuleExpr( - {RuleExprType::kCharacterClassStar, data.data(), static_cast<int32_t>(data.size())}); + {RuleExprType::kChoices, choices.data(), static_cast<int32_t>(choices.size())}); } size_t NumRuleExprs() const { return grammar_->NumRuleExprs(); } @@ -154,7 +175,7 @@ class BNFGrammarBuilder { * rule body of a rule inserted by BNFGrammarBuilder::AddEmptyRule. */ void UpdateRuleBody(int32_t rule_id, int32_t body_expr_id) { - CHECK(rule_id < static_cast<int32_t>(grammar_->rules_.size())) + CHECK(rule_id >= 0 && rule_id < static_cast<int32_t>(grammar_->rules_.size())) << "Rule id " << rule_id << " is out of range."; grammar_->rules_[rule_id].body_expr_id = body_expr_id; } @@ -169,6 +190,28 @@ class BNFGrammarBuilder { UpdateRuleBody(rule_id, body_expr_id); } + /*! + * \brief Add a lookahead assertion to the rule referred to by the given rule_id. The lookahead + * assertion should be a sequence RuleExpr id. An id of -1 means no lookahead assertion. + */ + void AddLookaheadAssertion(int32_t rule_id, int32_t lookahead_assertion_id) { + CHECK(rule_id < static_cast<int32_t>(grammar_->rules_.size())) + << "Rule id " << rule_id << " is out of range."; + CHECK(grammar_->rules_[rule_id].lookahead_assertion_id == -1) + << "Rule " << rule_id << " already has a lookahead assertion."; + grammar_->rules_[rule_id].lookahead_assertion_id = lookahead_assertion_id; + } + + /*! + * \brief Add a lookahead assertion to the rule referred to by the given name. The lookahead + * assertion should be a sequence RuleExpr id. An id of -1 means no lookahead assertion. + */ + void AddLookaheadAssertion(std::string rule_name, int32_t lookahead_assertion_id) { + int32_t rule_id = GetRuleId(rule_name); + CHECK(rule_id != -1) << "Rule " << rule_name << " is not found."; + AddLookaheadAssertion(rule_id, lookahead_assertion_id); + } + /*! * \brief Find a name for a new rule starting with the given name hint. Some integer suffix (_1, * _2, ...) may be added to avoid name conflict. diff --git a/cpp/serve/grammar/grammar_simplifier.cc b/cpp/serve/grammar/grammar_functor.cc similarity index 54% rename from cpp/serve/grammar/grammar_simplifier.cc rename to cpp/serve/grammar/grammar_functor.cc index 109b5d85e1..ae4e108233 100644 --- a/cpp/serve/grammar/grammar_simplifier.cc +++ b/cpp/serve/grammar/grammar_functor.cc @@ -1,56 +1,101 @@ /*! * Copyright (c) 2023 by Contributors - * \file serve/grammar/grammar_simplifier.cc + * \file serve/grammar/grammar_functor.cc */ -#include "grammar_simplifier.h" +#include "grammar_functor.h" + +#include "../../support/encoding.h" namespace mlc { namespace llm { namespace serve { /*! - * \brief Eliminates single-element sequence or choice nodes in the grammar. - * \example The sequence `(a)` or the choice `(a)` will be replaced by `a` in a rule. - * \example The rule `A ::= ((b) (((d))))` will be replaced by `A ::= (b d)`. + * \brief Eliminates single-element sequences, choices, and character classes in the grammar. + * \example `A ::= choices("a")` --> `A ::= "a"` (the body is a string) + * \example `A ::= sequence("a")` --> `A ::= "a"` (the body is a string) + * \example `A ::= [a-a]` --> `A ::= "a"` (the body is a string) */ -class SingleElementSequenceOrChoiceEliminator : public BNFGrammarMutator { +class SingleElementExprEliminator : public BNFGrammarMutator { public: using BNFGrammarMutator::Apply; using BNFGrammarMutator::BNFGrammarMutator; private: - int32_t VisitSequence(const RuleExpr& rule_expr) { + // Keep the sequence expr in the lookahead assertion + int32_t VisitLookaheadAssertion(int32_t lookahead_assertion_id) final { + if (lookahead_assertion_id == -1) { + return -1; + } + auto rule_expr = grammar_->GetRuleExpr(lookahead_assertion_id); + CHECK(rule_expr.type == RuleExprType::kSequence); + + std::vector<int32_t> sequence_ids; + for (int32_t i : rule_expr) { + sequence_ids.push_back(VisitExpr(i)); + } + return builder_.AddSequence(sequence_ids); + } + + int32_t VisitSequence(const RuleExpr& rule_expr) final { std::vector<int32_t> sequence_ids; for (int32_t i : rule_expr) { - sequence_ids.push_back(VisitExpr(grammar_->GetRuleExpr(i))); + sequence_ids.push_back(VisitExpr(i)); } if (sequence_ids.size() == 1) { return sequence_ids[0]; - } else { - return builder_.AddSequence(sequence_ids); } + return builder_.AddSequence(sequence_ids); } - int32_t VisitChoices(const RuleExpr& rule_expr) { + int32_t VisitChoices(const RuleExpr& rule_expr) final { std::vector<int32_t> choice_ids; for (int32_t i : rule_expr) { - choice_ids.push_back(VisitExpr(grammar_->GetRuleExpr(i))); + choice_ids.push_back(VisitExpr(i)); } if (choice_ids.size() == 1) { return choice_ids[0]; - } else { - return builder_.AddChoices(choice_ids); } + return builder_.AddChoices(choice_ids); + } + + int32_t VisitCharacterClass(const RuleExpr& rule_expr) final { + if (rule_expr.data_len == 3 && rule_expr[0] == 0 && rule_expr[1] == rule_expr[2]) { + std::string str = PrintAsUTF8(rule_expr[1]); + std::vector<int32_t> bytes; + bytes.reserve(str.size()); + for (char c : str) { + bytes.push_back(static_cast<int32_t>(c)); + } + return builder_.AddByteString(bytes); + } + return builder_.AddRuleExpr(rule_expr); } }; -class NestedRuleUnwrapperImpl : public BNFGrammarMutator { +/*! + * \brief Unwrap the rules containing nested expressions. After unwrapping, each rule will be in + * the form: `rule_name ::= ("" | (element1_1 element1_2 ...) | (element2_1 element2_2 ...) | ...)`. + * + * I.e. a list of choices, each choice is a sequence of elements. Elements can be a character class + * or a rule reference. And if the rule can be empty, the first choice will be an empty string. + * + * \example The rule `A ::= ((a) (((b)) (c)) "")` will be replaced by `A ::= ((a b c))`. One choice + * containing a sequence of three elements. The empty string is removed. + * \example The rule `A ::= (a | (b | (c | "")))` will be replaced by + * `A ::= ("" | (a) | (b) | (c))`. The first choice is an empty string, and each of the other three + * choices is a sequence containing a single element. + * \example The rule `A ::= (a | (b (c | d)))` will be replaced by + * `A ::= ((a) | (b B)), B ::= ((c) | (d))`. A new rule B is created to represent the nested + * choices. + */ +class NestedRuleUnwrapper : public BNFGrammarMutator { public: using BNFGrammarMutator::BNFGrammarMutator; - BNFGrammar Apply() final { - grammar_ = SingleElementSequenceOrChoiceEliminator(grammar_).Apply(); + BNFGrammar Apply(const BNFGrammar& grammar) final { + Init(grammar); for (int i = 0; i < static_cast<int>(grammar_->NumRules()); ++i) { builder_.AddEmptyRule(grammar_->GetRule(i).name); } @@ -60,11 +105,20 @@ class NestedRuleUnwrapperImpl : public BNFGrammarMutator { cur_rule_name_ = rule.name; auto new_body_expr_id = VisitRuleBody(rule_expr); builder_.UpdateRuleBody(i, new_body_expr_id); + builder_.AddLookaheadAssertion(i, VisitLookaheadAssertion(rule.lookahead_assertion_id)); } return builder_.Get(grammar_->GetMainRule().name); } private: + int32_t VisitLookaheadAssertion(int32_t lookahead_assertion_id) final { + if (lookahead_assertion_id == -1) { + return -1; + } + auto assertion_expr = grammar_->GetRuleExpr(lookahead_assertion_id); + return builder_.AddSequence(VisitSequence_(assertion_expr)); + } + /*! \brief Visit a RuleExpr as a rule body. */ int32_t VisitRuleBody(const RuleExpr& rule_expr) { switch (rule_expr.type) { @@ -74,12 +128,11 @@ class NestedRuleUnwrapperImpl : public BNFGrammarMutator { return builder_.AddChoices(VisitChoices_(rule_expr)); case RuleExprType::kEmptyStr: return builder_.AddChoices({builder_.AddEmptyStr()}); + case RuleExprType::kByteString: case RuleExprType::kCharacterClass: - case RuleExprType::kNegCharacterClass: + case RuleExprType::kCharacterClassStar: case RuleExprType::kRuleRef: return builder_.AddChoices({builder_.AddSequence({builder_.AddRuleExpr(rule_expr)})}); - case RuleExprType::kCharacterClassStar: - return builder_.AddCharacterClassStar(VisitExpr(grammar_->GetRuleExpr(rule_expr[0]))); default: LOG(FATAL) << "Unexpected sequence type: " << static_cast<int>(rule_expr.type); } @@ -104,14 +157,12 @@ class NestedRuleUnwrapperImpl : public BNFGrammarMutator { case RuleExprType::kEmptyStr: found_empty = true; break; + case RuleExprType::kByteString: case RuleExprType::kCharacterClass: - case RuleExprType::kNegCharacterClass: + case RuleExprType::kCharacterClassStar: case RuleExprType::kRuleRef: VisitElementInChoices(choice_expr, &new_choice_ids); break; - case RuleExprType::kCharacterClassStar: - VisitCharacterClassStarInChoices(choice_expr, &new_choice_ids); - break; default: LOG(FATAL) << "Unexpected choice type: " << static_cast<int>(choice_expr.type); } @@ -154,16 +205,6 @@ class NestedRuleUnwrapperImpl : public BNFGrammarMutator { new_choice_ids->push_back(builder_.AddSequence({sub_expr_id})); } - /*! \brief Visit a character class star RuleExpr that is one of a list of choices. */ - void VisitCharacterClassStarInChoices(const RuleExpr& rule_expr, - std::vector<int32_t>* new_choice_ids) { - auto sub_expr_id = builder_.AddRuleExpr(grammar_->GetRuleExpr(rule_expr[0])); - auto new_star_id = builder_.AddCharacterClassStar(sub_expr_id); - auto new_rule_id = builder_.AddRuleWithHint(cur_rule_name_ + "_star", new_star_id); - auto new_rule_ref_id = builder_.AddRuleRef(new_rule_id); - new_choice_ids->push_back(builder_.AddSequence({new_rule_ref_id})); - } - /*! * \brief Visit a RuleExpr containing a sequence. * \returns A list of new sequence RuleExpr ids.
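Note: as a quick illustration of the normalization pipeline this file implements, the sketch below parses a rule with nested choices and prints the normalized result. It is illustrative only: it assumes BNFGrammarPrinter can be constructed directly from a grammar, as its registration in grammar_serializer.cc suggests, and the helper rule created for the nested choices is named by AddRuleWithHint, so the exact generated name may differ.

#include <iostream>
#include "grammar.h"
#include "grammar_serializer.h"

using namespace mlc::llm::serve;

void ExampleNormalizeNestedRule() {
  // FromEBNFString now always normalizes: single-element exprs are
  // eliminated, nested rules are unwrapped into flat choices of sequences,
  // and consecutive byte strings are fused.
  BNFGrammar grammar = BNFGrammar::FromEBNFString(
      "main ::= (\"a\" | (\"b\" (\"c\" | \"d\")))", "main");
  // Expected shape, following the \example notes above:
  //   main ::= (("a") | ("b" helper)), helper ::= (("c") | ("d"))
  // where "helper" stands for the generated rule name.
  std::cout << BNFGrammarPrinter(grammar).ToString() << std::endl;
}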
@@ -171,26 +212,24 @@ class NestedRuleUnwrapperImpl : public BNFGrammarMutator { std::vector<int32_t> VisitSequence_(const RuleExpr& rule_expr) { std::vector<int32_t> new_sequence_ids; for (auto i : rule_expr) { - auto seq_expr = grammar_->GetRuleExpr(i); - switch (seq_expr.type) { + auto element_expr = grammar_->GetRuleExpr(i); + switch (element_expr.type) { case RuleExprType::kSequence: - VisitSequenceInSequence(seq_expr, &new_sequence_ids); + VisitSequenceInSequence(element_expr, &new_sequence_ids); break; case RuleExprType::kChoices: - VisitChoiceInSequence(seq_expr, &new_sequence_ids); + VisitChoiceInSequence(element_expr, &new_sequence_ids); break; case RuleExprType::kEmptyStr: break; + case RuleExprType::kByteString: case RuleExprType::kCharacterClass: - case RuleExprType::kNegCharacterClass: - case RuleExprType::kRuleRef: - VisitElementInSequence(seq_expr, &new_sequence_ids); - break; case RuleExprType::kCharacterClassStar: - VisitCharacterClassStarInSequence(seq_expr, &new_sequence_ids); + case RuleExprType::kRuleRef: + VisitElementInSequence(element_expr, &new_sequence_ids); break; default: - LOG(FATAL) << "Unexpected sequence type: " << static_cast<int>(seq_expr.type); + LOG(FATAL) << "Unexpected sequence type: " << static_cast<int>(element_expr.type); } } return new_sequence_ids; @@ -223,22 +262,58 @@ class NestedRuleUnwrapperImpl : public BNFGrammarMutator { void VisitElementInSequence(const RuleExpr& rule_expr, std::vector<int32_t>* new_sequence_ids) { new_sequence_ids->push_back(builder_.AddRuleExpr(rule_expr)); } +}; - /*! \brief Visit a character class star RuleExpr that is in a sequence. */ - void VisitCharacterClassStarInSequence(const RuleExpr& rule_expr, - std::vector<int32_t>* new_sequence_ids) { - auto sub_expr_id = builder_.AddRuleExpr(grammar_->GetRuleExpr(rule_expr[0])); - auto new_star_id = builder_.AddCharacterClassStar(sub_expr_id); - auto new_rule_id = builder_.AddRuleWithHint(cur_rule_name_ + "_star", new_star_id); - auto new_rule_ref_id = builder_.AddRuleRef(new_rule_id); - new_sequence_ids->push_back(new_rule_ref_id); - } +class ByteStringFuser : public BNFGrammarMutator { + public: + using BNFGrammarMutator::Apply; + using BNFGrammarMutator::BNFGrammarMutator; - /*! \brief The name of the current rule being visited. */ - std::string cur_rule_name_; + private: + /*! + * \brief Visit a RuleExpr containing a sequence. + * \returns The id of the new sequence RuleExpr with consecutive byte strings fused. + */ + int32_t VisitSequence(const RuleExpr& rule_expr) final { + std::vector<int32_t> new_sequence_ids; + std::vector<int32_t> cur_byte_string; + for (auto i : rule_expr) { + auto element_expr = grammar_->GetRuleExpr(i); + if (element_expr.type == RuleExprType::kByteString) { + cur_byte_string.insert(cur_byte_string.end(), element_expr.begin(), element_expr.end()); + continue; + } else { + if (!cur_byte_string.empty()) { + new_sequence_ids.push_back(builder_.AddByteString(cur_byte_string)); + cur_byte_string.clear(); + } + new_sequence_ids.push_back(builder_.AddRuleExpr(element_expr)); + } + } + if (!cur_byte_string.empty()) { + new_sequence_ids.push_back(builder_.AddByteString(cur_byte_string)); + } + return builder_.AddSequence(new_sequence_ids); + } }; -BNFGrammar NestedRuleUnwrapper::Apply() { return NestedRuleUnwrapperImpl(grammar_).Apply(); } +// Return the list of all normalizers in the class. The normalizers are applied one by one. +std::vector<std::unique_ptr<BNFGrammarMutator>> BNFGrammarNormalizer::GetNormalizerList() { + std::vector<std::unique_ptr<BNFGrammarMutator>> normalizer_mutators; + normalizer_mutators.emplace_back(std::make_unique<SingleElementExprEliminator>()); + normalizer_mutators.emplace_back(std::make_unique<NestedRuleUnwrapper>()); + normalizer_mutators.emplace_back(std::make_unique<ByteStringFuser>()); + return normalizer_mutators; +} + +BNFGrammar BNFGrammarNormalizer::Apply(const BNFGrammar& grammar) { + std::vector<std::unique_ptr<BNFGrammarMutator>> normalizer_mutators = GetNormalizerList(); + grammar_ = grammar; + for (auto& mutator : normalizer_mutators) { + grammar_ = mutator->Apply(grammar_); + } + return grammar_; +} } // namespace serve } // namespace llm diff --git a/cpp/serve/grammar/grammar_simplifier.h b/cpp/serve/grammar/grammar_functor.h similarity index 58% rename from cpp/serve/grammar/grammar_simplifier.h rename to cpp/serve/grammar/grammar_functor.h index 50f3804387..123700778e 100644 --- a/cpp/serve/grammar/grammar_simplifier.h +++ b/cpp/serve/grammar/grammar_functor.h @@ -1,11 +1,11 @@ /*! * Copyright (c) 2023 by Contributors - * \file serve/grammar/grammar_simplifier.h + * \file serve/grammar/grammar_functor.h * \brief The header for the simplification of the BNF AST. */ -#ifndef MLC_LLM_SERVE_GRAMMAR_GRAMMAR_SIMPLIFIER_H_ -#define MLC_LLM_SERVE_GRAMMAR_GRAMMAR_SIMPLIFIER_H_ +#ifndef MLC_LLM_SERVE_GRAMMAR_GRAMMAR_FUNCTOR_H_ +#define MLC_LLM_SERVE_GRAMMAR_GRAMMAR_FUNCTOR_H_ #include #include @@ -27,29 +27,44 @@ namespace serve { * are void (for visitor) and BNFGrammar (for mutator). */ template <typename T = int32_t, typename ReturnType = BNFGrammar> -class BNFGrammarMutator { +class BNFGrammarFunctor { public: /*! * \brief Constructor. * \param grammar The grammar to visit or mutate. */ - explicit BNFGrammarMutator(const BNFGrammar& grammar) : grammar_(grammar) {} + explicit BNFGrammarFunctor() {} /*! * \brief Apply the transformation to the grammar, or visit the grammar. * \return The transformed grammar, or the visiting result, or void. - * \note Should be called only once after the mutator is constructed. */ - virtual ReturnType Apply() { - if constexpr (std::is_same<T, int32_t>::value && std::is_same<ReturnType, BNFGrammar>::value) { + virtual ReturnType Apply(const BNFGrammar& grammar) { + Init(grammar); + if constexpr (std::is_same<T, void>::value) { for (int i = 0; i < static_cast<int>(grammar_->NumRules()); ++i) { auto rule = grammar_->GetRule(i); - auto rule_expr = grammar_->GetRuleExpr(rule.body_expr_id); - auto new_body_expr_id = VisitExpr(rule_expr); - builder_.AddRule(rule.name, new_body_expr_id); + cur_rule_name_ = rule.name; + VisitExpr(rule.body_expr_id); + VisitLookaheadAssertion(rule.lookahead_assertion_id); + } + } else if constexpr (std::is_same<T, int32_t>::value && + std::is_same<ReturnType, BNFGrammar>::value) { + // First add empty rules to ensure the new rule ids are the same as the old ones, then update + // the rule bodies + for (int i = 0; i < static_cast<int>(grammar_->NumRules()); ++i) { + builder_.AddEmptyRule(grammar_->GetRule(i).name); + } + for (int i = 0; i < static_cast<int>(grammar_->NumRules()); ++i) { + auto rule = grammar_->GetRule(i); + cur_rule_name_ = rule.name; + auto new_body_expr_id = VisitExpr(rule.body_expr_id); + builder_.UpdateRuleBody(i, new_body_expr_id); + // Handle lookahead assertion + builder_.AddLookaheadAssertion(i, VisitLookaheadAssertion(rule.lookahead_assertion_id)); } return builder_.Get(grammar_->GetMainRule().name); - } else if constexpr (!std::is_same<ReturnType, void>::value) { + } else { return ReturnType(); } } @@ -59,6 +74,25 @@ class BNFGrammarMutator { using RuleExpr = BNFGrammarNode::RuleExpr; using RuleExprType = BNFGrammarNode::RuleExprType; + /*! \brief Initialize the functor. Should be called at the beginning of Apply(). */ + virtual void Init(const BNFGrammar& grammar) { + grammar_ = grammar; + builder_ = BNFGrammarBuilder(); + } + + /*! \brief Visit a lookahead assertion expr referred by id. */ + virtual T VisitLookaheadAssertion(int32_t lookahead_assertion_id) { + if (lookahead_assertion_id == -1) { + return -1; + } + return VisitExpr(lookahead_assertion_id); + } + + /*! \brief Visit a RuleExpr by id. */ + virtual T VisitExpr(int32_t old_rule_expr_id) { + return VisitExpr(grammar_->GetRuleExpr(old_rule_expr_id)); + } + /*! \brief Visit a RuleExpr. Dispatch to the corresponding Visit function. */ virtual T VisitExpr(const RuleExpr& rule_expr) { switch (rule_expr.type) { @@ -68,47 +102,48 @@ return VisitChoices(rule_expr); case RuleExprType::kEmptyStr: return VisitEmptyStr(rule_expr); + case RuleExprType::kByteString: + return VisitByteString(rule_expr); case RuleExprType::kCharacterClass: - case RuleExprType::kNegCharacterClass: return VisitCharacterClass(rule_expr); - case RuleExprType::kRuleRef: - return VisitRuleRef(rule_expr); case RuleExprType::kCharacterClassStar: return VisitCharacterClassStar(rule_expr); + case RuleExprType::kRuleRef: + return VisitRuleRef(rule_expr); default: LOG(FATAL) << "Unexpected sequence type: " << static_cast<int>(rule_expr.type); } } - /*! \brief Visit a sequence RuleExpr. */ - virtual T VisitSequence(const RuleExpr& rule_expr) { + /*! \brief Visit a choices RuleExpr. */ + virtual T VisitChoices(const RuleExpr& rule_expr) { if constexpr (std::is_same<T, void>::value) { for (auto i : rule_expr) { - VisitExpr(grammar_->GetRuleExpr(i)); + VisitExpr(i); } } else if constexpr (std::is_same<T, int32_t>::value) { - std::vector<int32_t> sequence_ids; + std::vector<int32_t> choice_ids; for (int32_t i : rule_expr) { - sequence_ids.push_back(VisitExpr(grammar_->GetRuleExpr(i))); + choice_ids.push_back(VisitExpr(i)); } - return builder_.AddSequence(sequence_ids); + return builder_.AddChoices(choice_ids); } else { return T(); } } - /*! \brief Visit a choices RuleExpr. */ - virtual T VisitChoices(const RuleExpr& rule_expr) { + /*! \brief Visit a sequence RuleExpr. */ + virtual T VisitSequence(const RuleExpr& rule_expr) { if constexpr (std::is_same<T, void>::value) { for (auto i : rule_expr) { - VisitExpr(grammar_->GetRuleExpr(i)); + VisitExpr(i); } } else if constexpr (std::is_same<T, int32_t>::value) { - std::vector<int32_t> choice_ids; + std::vector<int32_t> sequence_ids; for (int32_t i : rule_expr) { - choice_ids.push_back(VisitExpr(grammar_->GetRuleExpr(i))); + sequence_ids.push_back(VisitExpr(i)); } - return builder_.AddChoices(choice_ids); + return builder_.AddSequence(sequence_ids); } else { return T(); } @@ -128,23 +163,18 @@ /*! \brief Visit an empty string RuleExpr. */ virtual T VisitEmptyStr(const RuleExpr& rule_expr) { return VisitElement(rule_expr); } + /*! \brief Visit a byte string RuleExpr. */ + virtual T VisitByteString(const RuleExpr& rule_expr) { return VisitElement(rule_expr); } + /*! \brief Visit a character class RuleExpr. */ virtual T VisitCharacterClass(const RuleExpr& rule_expr) { return VisitElement(rule_expr); } + /*! \brief Visit a star quantifier RuleExpr. */ + virtual T VisitCharacterClassStar(const RuleExpr& rule_expr) { return VisitElement(rule_expr); } + /*! \brief Visit a rule reference RuleExpr. */ virtual T VisitRuleRef(const RuleExpr& rule_expr) { return VisitElement(rule_expr); } - /*! \brief Visit a star quantifier RuleExpr. */ - virtual T VisitCharacterClassStar(const RuleExpr& rule_expr) { - if constexpr (std::is_same<T, void>::value) { - VisitExpr(grammar_->GetRuleExpr(rule_expr[0])); - } else if constexpr (std::is_same<T, int32_t>::value) { - return builder_.AddCharacterClassStar(VisitExpr(grammar_->GetRuleExpr(rule_expr[0]))); - } else { - return T(); - } - } - /*! \brief The grammar to visit or mutate. */ BNFGrammar grammar_; /*! @@ -152,33 +182,38 @@ * can be used to build a new grammar in subclasses. */ BNFGrammarBuilder builder_; + /*! \brief The name of the current rule being visited. */ + std::string cur_rule_name_; }; /*! - * \brief Unwrap the rules containing nested expressions. After unwrapping, each rule will be in - * the form: `rule_name ::= ("" | (element1_1 element1_2 ...) | (element2_1 element2_2 ...) | ...)`. - * - * I.e. a list of choices, each choice is a sequence of elements. Elements can be a character class - * or a rule reference. And if the rule can be empty, the first choice will be an empty string. - * - * \example The rule `A ::= ((a) (((b)) (c)) "")` will be replaced by `A ::= ((a b c))`. One choice - * containing a sequence of three elements. The empty string is removed. - * \example The rule `A ::= (a | (b | (c | "")))` will be replaced by - * `A ::= ("" | (a) | (b) | (c))`. The first choice is an empty string, and each of the other three - * choices is a sequence containing a single element. - * \example The rule `A ::= (a | (b (c | d)))` will be replaced by - * `A ::= ((a) | (b B)), B ::= ((c) | (d))`. A new rule B is created to represent the nested - * choices. + * \brief Visitor of BNFGrammar. + * \tparam ReturnType The return type of the Apply() function. Denotes the collected information. */ -class NestedRuleUnwrapper : public BNFGrammarMutator { +template <typename ReturnType> +using BNFGrammarVisitor = BNFGrammarFunctor<void, ReturnType>; + +/*! + * \brief Mutator of BNFGrammar. The Apply() function returns the updated grammar. + */ +using BNFGrammarMutator = BNFGrammarFunctor<int32_t, BNFGrammar>; + +/*! + * \brief Normalize a BNFGrammar: expand the nested rules, combine consecutive sequences and strings, + * etc. + */ +class BNFGrammarNormalizer : public BNFGrammarMutator { public: using BNFGrammarMutator::BNFGrammarMutator; - BNFGrammar Apply() final; + BNFGrammar Apply(const BNFGrammar& grammar) final; + + private: + std::vector<std::unique_ptr<BNFGrammarMutator>> GetNormalizerList(); }; } // namespace serve } // namespace llm } // namespace mlc -#endif // MLC_LLM_SERVE_GRAMMAR_GRAMMAR_SIMPLIFIER_H_ +#endif // MLC_LLM_SERVE_GRAMMAR_GRAMMAR_FUNCTOR_H_ diff --git a/cpp/serve/grammar/grammar_parser.cc b/cpp/serve/grammar/grammar_parser.cc index a4eda4e395..2799ee4ba9 100644 --- a/cpp/serve/grammar/grammar_parser.cc +++ b/cpp/serve/grammar/grammar_parser.cc @@ -29,6 +29,7 @@ class EBNFParserImpl { int32_t ParseRuleRef(); int32_t ParseElement(); int32_t ParseQuantifier(); + int32_t ParseLookaheadAssertion(); int32_t ParseSequence(); int32_t ParseChoices(); Rule ParseRule(); @@ -157,10 +158,10 @@ int32_t EBNFParserImpl::ParseCharacterClass() { } auto [codepoint, new_cur] = ParseNextUTF8OrEscaped(cur_, kCustomEscapeMap); - if (codepoint == static_cast<TCodepoint>(CharHandlingError::kInvalidUtf8)) { + if (codepoint == CharHandlingError::kInvalidUTF8) { ThrowParseError("Invalid UTF8 sequence"); } - if (codepoint == static_cast<TCodepoint>(CharHandlingError::kInvalidEscape)) { + if (codepoint == CharHandlingError::kInvalidEscape) { ThrowParseError("Invalid escape sequence"); } Consume(new_cur - cur_); @@ -189,26 +190,37 @@ int32_t EBNFParserImpl::ParseCharacterClass() { // parse a c style string with utf8 support int32_t EBNFParserImpl::ParseString() { - std::vector<int32_t> character_classes; + std::vector<TCodepoint> codepoints; while (Peek() && Peek() != '\"') { if (Peek() == '\r' || Peek() == '\n') { ThrowParseError("There should be no newline character in a string literal"); } auto [codepoint, new_cur] = ParseNextUTF8OrEscaped(cur_); - if (codepoint == static_cast<TCodepoint>(CharHandlingError::kInvalidUtf8)) { + if (codepoint == CharHandlingError::kInvalidUTF8) { ThrowParseError("Invalid utf8 sequence"); } - if (codepoint == static_cast<TCodepoint>(CharHandlingError::kInvalidEscape)) { + if (codepoint == CharHandlingError::kInvalidEscape) { ThrowParseError("Invalid escape sequence"); } Consume(new_cur - cur_); - character_classes.push_back(builder_.AddCharacterClass({{codepoint, codepoint}})); + codepoints.push_back(codepoint); } - if (character_classes.empty()) { + if (codepoints.empty()) { return builder_.AddEmptyStr(); } - return builder_.AddSequence(character_classes); + + // convert codepoints to string + std::string str; + for (auto codepoint : codepoints) { + str += PrintAsUTF8(codepoint); + } + // convert str to int32_t vector + std::vector<int32_t> bytes; + for (auto c : str) { + bytes.push_back(static_cast<int32_t>(c)); + } + return builder_.AddByteString(bytes); } int32_t EBNFParserImpl::ParseRuleRef() { @@ -264,9 +276,11 @@ int32_t EBNFParserImpl::ParseElement() { } int32_t EBNFParserImpl::HandleStarQuantifier(int32_t rule_expr_id) { - if (builder_.GetRuleExpr(rule_expr_id).type == BNFGrammarBuilder::RuleExprType::kCharacterClass) { + BNFGrammarNode::RuleExpr rule_expr = builder_.GetRuleExpr(rule_expr_id); + if (rule_expr.type == BNFGrammarBuilder::RuleExprType::kCharacterClass) { // We have special handling for character class star, e.g.
[a-z]* - return builder_.AddCharacterClassStar(rule_expr_id); + rule_expr.type = BNFGrammarBuilder::RuleExprType::kCharacterClassStar; + return builder_.AddRuleExpr(rule_expr); } else { // For other star quantifiers, we transform it into a rule: // a* --> rule ::= a rule | "" @@ -327,12 +341,11 @@ int32_t EBNFParserImpl::ParseQuantifier() { int32_t EBNFParserImpl::ParseSequence() { std::vector<int32_t> elements; - elements.push_back(ParseQuantifier()); - ConsumeSpace(in_parentheses_); - while (Peek() && Peek() != '|' && Peek() != ')' && Peek() != '\n' && Peek() != '\r') { + do { elements.push_back(ParseQuantifier()); ConsumeSpace(in_parentheses_); - } + } while (Peek() && Peek() != '|' && Peek() != ')' && Peek() != '\n' && Peek() != '\r' && + (Peek() != '(' || Peek(1) != '=')); return builder_.AddSequence(elements); } @@ -350,6 +363,24 @@ int32_t EBNFParserImpl::ParseChoices() { return builder_.AddChoices(choices); } +int32_t EBNFParserImpl::ParseLookaheadAssertion() { + if (Peek() != '(' || Peek(1) != '=') { + return -1; + } + Consume(2); + auto prev_in_parentheses = in_parentheses_; + in_parentheses_ = true; + ConsumeSpace(in_parentheses_); + auto result = ParseSequence(); + ConsumeSpace(in_parentheses_); + if (Peek() != ')') { + ThrowParseError("Expect )"); + } + Consume(); + in_parentheses_ = prev_in_parentheses; + return result; +} + EBNFParserImpl::Rule EBNFParserImpl::ParseRule() { std::string name = ParseName(); cur_rule_name_ = name; @@ -359,7 +390,10 @@ } Consume(3); ConsumeSpace(); - return {name, ParseChoices()}; + auto body_id = ParseChoices(); + ConsumeSpace(); + auto lookahead_id = ParseLookaheadAssertion(); + return {name, body_id, lookahead_id}; } void EBNFParserImpl::BuildRuleNameToId() { @@ -399,8 +433,14 @@ BNFGrammar EBNFParserImpl::DoParse(std::string ebnf_string, std::string main_rul ResetStringIterator(ebnf_string.c_str()); ConsumeSpace(); while (Peek()) { + // Throw an error when there are multiple lookahead assertions + if (Peek() == '(' && Peek(1) == '=') { + ThrowParseError("Unexpected lookahead assertion"); + } auto new_rule = ParseRule(); builder_.UpdateRuleBody(new_rule.name, new_rule.body_expr_id); + // Update the lookahead assertion + builder_.AddLookaheadAssertion(new_rule.name, new_rule.lookahead_assertion_id); ConsumeSpace(); } diff --git a/cpp/serve/grammar/grammar_parser.h b/cpp/serve/grammar/grammar_parser.h index 4d10e8eb0d..94ac3d4ce1 100644 --- a/cpp/serve/grammar/grammar_parser.h +++ b/cpp/serve/grammar/grammar_parser.h @@ -23,7 +23,7 @@ using namespace tvm::runtime; * \details This function accepts the EBNF notation defined in the W3C XML Specification * (https://www.w3.org/TR/xml/#sec-notation), which is a popular standard, with the following * changes: - * - Using # as comment mark instead of /**\/ + * - Using # as comment mark instead of C-style comments * - Accept C-style unicode escape sequence \u01AB, \U000001AB, \xAB instead of #x0123 * - Rule A-B (match A and not match B) is not supported yet * diff --git a/cpp/serve/grammar/grammar_serializer.cc b/cpp/serve/grammar/grammar_serializer.cc index c3c2c88baa..5176b9f102 100644 --- a/cpp/serve/grammar/grammar_serializer.cc +++ b/cpp/serve/grammar/grammar_serializer.cc @@ -18,7 +18,11 @@ namespace serve { using namespace tvm::runtime; std::string BNFGrammarPrinter::PrintRule(const Rule& rule) { - return rule.name + " ::= " + PrintRuleExpr(rule.body_expr_id); + std::string res = rule.name + " ::= " + PrintRuleExpr(rule.body_expr_id); + if (rule.lookahead_assertion_id != -1) { + res += " (=" + PrintRuleExpr(rule.lookahead_assertion_id) + ")"; + } + return res; } std::string BNFGrammarPrinter::PrintRule(int32_t rule_id) { @@ -28,10 +32,12 @@ std::string BNFGrammarPrinter::PrintRuleExpr(const RuleExpr& rule_expr) { std::string result; switch (rule_expr.type) { + case RuleExprType::kByteString: + return PrintByteString(rule_expr); case RuleExprType::kCharacterClass: return PrintCharacterClass(rule_expr); - case RuleExprType::kNegCharacterClass: - return PrintCharacterClass(rule_expr); + case RuleExprType::kCharacterClassStar: + return PrintCharacterClassStar(rule_expr); case RuleExprType::kEmptyStr: return PrintEmptyStr(rule_expr); case RuleExprType::kRuleRef: @@ -40,8 +46,6 @@ return PrintSequence(rule_expr); case RuleExprType::kChoices: return PrintChoices(rule_expr); - case RuleExprType::kCharacterClassStar: - return PrintCharacterClassStar(rule_expr); default: LOG(FATAL) << "Unexpected RuleExpr type: " << static_cast<int>(rule_expr.type); } @@ -51,14 +55,29 @@ std::string BNFGrammarPrinter::PrintRuleExpr(int32_t rule_expr_id) { return PrintRuleExpr(grammar_->GetRuleExpr(rule_expr_id)); } +std::string BNFGrammarPrinter::PrintByteString(const RuleExpr& rule_expr) { + std::string internal_str; + internal_str.reserve(rule_expr.data_len); + for (int i = 0; i < rule_expr.data_len; ++i) { + internal_str += static_cast<char>(rule_expr[i]); + } + auto codepoints = ParseUTF8(internal_str.c_str(), UTF8ErrorPolicy::kReturnByte); + std::string result; + for (auto codepoint : codepoints) { + result += PrintAsEscaped(codepoint); + } + return "\"" + result + "\""; +} + std::string BNFGrammarPrinter::PrintCharacterClass(const RuleExpr& rule_expr) { static const std::unordered_map<TCodepoint, std::string> kCustomEscapeMap = {{'-', "\\-"}, {']', "\\]"}}; std::string result = "["; - if (rule_expr.type == RuleExprType::kNegCharacterClass) { + bool is_negative = static_cast<bool>(rule_expr[0]); + if (is_negative) { result += "^"; } - for (auto i = 0; i < rule_expr.data_len; i += 2) { + for (auto i = 1; i < rule_expr.data_len; i += 2) { result += PrintAsEscaped(rule_expr[i], kCustomEscapeMap); if (rule_expr[i] == rule_expr[i + 1]) { continue; @@ -70,6 +89,10 @@ std::string BNFGrammarPrinter::PrintCharacterClass(const RuleExpr& rule_expr) { return result; } +std::string BNFGrammarPrinter::PrintCharacterClassStar(const RuleExpr& rule_expr) { + return PrintCharacterClass(rule_expr) + "*"; +} + std::string BNFGrammarPrinter::PrintEmptyStr(const RuleExpr& rule_expr) { return "\"\""; } std::string BNFGrammarPrinter::PrintRuleRef(const RuleExpr& rule_expr) { @@ -103,10 +126,6 @@ std::string BNFGrammarPrinter::PrintChoices(const RuleExpr& rule_expr) { return result; } -std::string BNFGrammarPrinter::PrintCharacterClassStar(const RuleExpr& rule_expr) { - return PrintRuleExpr(rule_expr[0]) + "*"; -} - std::string BNFGrammarPrinter::ToString() { std::string result; auto num_rules = grammar_->NumRules(); @@ -121,7 +140,7 @@ TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarToString").set_body_typed([](const BNFG }); std::string BNFGrammarJSONSerializer::ToString() { - picojson::object grammar_json; + picojson::object grammar_json_obj; picojson::array rules_json; for (const auto& rule : grammar_->rules_) { @@ -130,20 +149,21 @@ rule_json["body_expr_id"] = picojson::value(static_cast<int64_t>(rule.body_expr_id)); rules_json.push_back(picojson::value(rule_json)); } - grammar_json["rules"] = picojson::value(rules_json); + grammar_json_obj["rules"] = picojson::value(rules_json); picojson::array rule_expr_data_json; for (const auto& data : grammar_->rule_expr_data_) { rule_expr_data_json.push_back(picojson::value(static_cast<int64_t>(data))); } - grammar_json["rule_expr_data"] = picojson::value(rule_expr_data_json); + grammar_json_obj["rule_expr_data"] = picojson::value(rule_expr_data_json); picojson::array rule_expr_indptr_json; for (const auto& index_ptr : grammar_->rule_expr_indptr_) { rule_expr_indptr_json.push_back(picojson::value(static_cast<int64_t>(index_ptr))); } - grammar_json["rule_expr_indptr"] = picojson::value(rule_expr_indptr_json); + grammar_json_obj["rule_expr_indptr"] = picojson::value(rule_expr_indptr_json); - return picojson::value(grammar_json).serialize(prettify_); + auto grammar_json = picojson::value(grammar_json_obj); + return grammar_json.serialize(prettify_); } TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarToJSON") diff --git a/cpp/serve/grammar/grammar_serializer.h b/cpp/serve/grammar/grammar_serializer.h index 4ad5c2103b..f0837d9638 100644 --- a/cpp/serve/grammar/grammar_serializer.h +++ b/cpp/serve/grammar/grammar_serializer.h @@ -62,8 +62,12 @@ class BNFGrammarPrinter : public BNFGrammarSerializer { std::string PrintRuleExpr(int32_t rule_expr_id); private: + /*! \brief Print a RuleExpr for byte string. */ + std::string PrintByteString(const RuleExpr& rule_expr); /*! \brief Print a RuleExpr for character class. */ std::string PrintCharacterClass(const RuleExpr& rule_expr); + /*! \brief Print a RuleExpr for a star quantifier of a character class. */ + std::string PrintCharacterClassStar(const RuleExpr& rule_expr); /*! \brief Print a RuleExpr for empty string. */ std::string PrintEmptyStr(const RuleExpr& rule_expr); /*! \brief Print a RuleExpr for rule reference. */ @@ -72,8 +76,6 @@ class BNFGrammarPrinter : public BNFGrammarSerializer { std::string PrintSequence(const RuleExpr& rule_expr); /*! \brief Print a RuleExpr for rule_expr choices. */ std::string PrintChoices(const RuleExpr& rule_expr); - /*! \brief Print a RuleExpr for star quantifier. */ - std::string PrintCharacterClassStar(const RuleExpr& rule_expr); }; /*!
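Note: to see the new lookahead assertion syntax round-trip from the parser (ParseLookaheadAssertion above) through the printer (PrintRule), a minimal sketch follows. It assumes only the interfaces shown in this diff, including a BNFGrammarPrinter constructed directly from the grammar; the exact escaping in the printed output may differ.

#include <iostream>
#include "grammar.h"
#include "grammar_serializer.h"

using namespace mlc::llm::serve;

void ExampleLookaheadRoundTrip() {
  // A trailing "(= ...)" attaches a lookahead assertion to the rule;
  // PrintRule renders it back as " (=...)" after the rule body.
  BNFGrammar grammar = BNFGrammar::FromEBNFString(
      "main ::= \"a\" [b-z]* (= [0-9])", "main");
  std::cout << BNFGrammarPrinter(grammar).ToString() << std::endl;
}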
diff --git a/cpp/serve/grammar/grammar_state_matcher.cc b/cpp/serve/grammar/grammar_state_matcher.cc index 451127e746..e6e68f376f 100644 --- a/cpp/serve/grammar/grammar_state_matcher.cc +++ b/cpp/serve/grammar/grammar_state_matcher.cc @@ -2,6 +2,7 @@ * Copyright (c) 2023 by Contributors * \file serve/grammar/grammar_state_matcher.cc */ +// #define TVM_LOG_DEBUG 1 #include "grammar_state_matcher.h" #include @@ -123,13 +124,15 @@ class GrammarStateMatcherNodeImpl : public GrammarStateMatcherNode, public GrammarStateMatcherBase private: using RuleExpr = BNFGrammarNode::RuleExpr; using RuleExprType = BNFGrammarNode::RuleExprType; + using SaveType = CatagorizedTokens::SaveType; public: GrammarStateMatcherNodeImpl(std::shared_ptr<GrammarStateInitContext> init_ctx, int max_rollback_steps = 0) : GrammarStateMatcherBase(init_ctx->grammar), init_ctx_(init_ctx), - max_rollback_steps_(max_rollback_steps) {} + max_rollback_steps_(max_rollback_steps), + tmp_accepted_bitset_(init_ctx_->vocab_size) {} bool AcceptToken(int32_t token_id) final; @@ -143,8 +146,8 @@ class GrammarStateMatcherNodeImpl : public GrammarStateMatcherNode, public GrammarStateMatcherBase void ResetState() final { stack_tops_history_.Reset(); - token_size_history_.clear(); - InitStackState(); + token_length_history.clear(); + PushInitialState(kInvalidRulePosition, true); } private: @@ -160,14 +163,8 @@ class GrammarStateMatcherNodeImpl : public GrammarStateMatcherNode, public GrammarStateMatcherBase const std::vector<bool>& uncertain_tokens_bitset); /*! \brief Set the acceptable next token in next_token_bitmask. */ - void SetTokenBitmask(DLTensor* next_token_bitmask, std::vector<int32_t>& accepted_indices, - std::vector<int32_t>& rejected_indices, bool can_reach_end); - - /*! \brief Check if a token is a stop token. */ - bool IsStopToken(int32_t token_id) const { - return std::find(init_ctx_->stop_token_ids.begin(), init_ctx_->stop_token_ids.end(), - token_id) != init_ctx_->stop_token_ids.end(); - } + void SetTokenBitmask(DLTensor* next_token_bitmask, const DynamicBitset& accepted_bitset, + const std::vector<int32_t>& rejected_indices, bool can_reach_end); /*! * \brief Accept the stop token and terminates the matcher. @@ -180,14 +177,12 @@ class GrammarStateMatcherNodeImpl : public GrammarStateMatcherNode, public GrammarStateMatcherBase std::shared_ptr<GrammarStateInitContext> init_ctx_; int max_rollback_steps_; - std::deque<int> token_size_history_; + std::deque<int> token_length_history; // Temporary data for FindNextTokenBitmask. They are stored here to avoid repeated allocation.
- std::vector tmp_accepted_indices_; + DynamicBitset tmp_accepted_bitset_; std::vector tmp_rejected_indices_; - std::vector tmp_accepted_indices_delta_; std::vector tmp_rejected_indices_delta_; - std::vector tmp_uncertain_tokens_bitset_; }; bool GrammarStateMatcherNodeImpl::AcceptStopToken() { @@ -204,23 +199,31 @@ bool GrammarStateMatcherNodeImpl::AcceptToken(int32_t token_id) { "accept another token id " << token_id; + CHECK(token_id >= 0 && token_id < init_ctx_->vocab_size) + << "Invalid token id " << token_id << " for GrammarStateMatcher"; + // Handle the stop token - if (IsStopToken(token_id)) { + if (std::find(init_ctx_->stop_token_ids.begin(), init_ctx_->stop_token_ids.end(), token_id) != + init_ctx_->stop_token_ids.end()) { return AcceptStopToken(); } - CHECK(init_ctx_->id_to_token_codepoints.count(token_id) > 0) - << "Token id " << token_id << " is not supported in generation"; - const auto& token = init_ctx_->id_to_token_codepoints[token_id].token; - for (auto codepoint : token) { - if (!AcceptCodepoint(codepoint, false)) { + if (init_ctx_->special_token_ids.count(token_id) > 0) { + LOG(FATAL) + << "Token id " << token_id << ": " << init_ctx_->token_table[token_id] + << " is regarded as a special token, and cannot be accepted by the GrammarStateMatcher"; + } + + const auto& token = init_ctx_->token_table[token_id]; + for (auto char_value : token) { + if (!AcceptChar(char_value, false)) { return false; } } - token_size_history_.push_back(token.size()); - if (token_size_history_.size() > max_rollback_steps_) { - DiscardEarliestCodepoints(token_size_history_.front()); - token_size_history_.pop_front(); + token_length_history.push_back(token.size()); + if (token_length_history.size() > max_rollback_steps_) { + DiscardEarliestChars(token_length_history.front()); + token_length_history.pop_front(); } return true; } @@ -229,7 +232,7 @@ void GrammarStateMatcherNodeImpl::FindNextTokenBitmask(DLTensor* next_token_bitm CHECK(!IsTerminated()) << "GrammarStateMatcher has terminated after accepting the stop token, but is trying to " "find the next token mask"; - const auto& sorted_token_codepoints = init_ctx_->sorted_token_codepoints; + const auto& sorted_token_table = init_ctx_->sorted_token_table; const auto& catagorized_tokens_for_grammar = init_ctx_->catagorized_tokens_for_grammar; const auto& latest_stack_tops = stack_tops_history_.GetLatest(); @@ -238,113 +241,132 @@ void GrammarStateMatcherNodeImpl::FindNextTokenBitmask(DLTensor* next_token_bitm // The final accepted token set is the union of the accepted token sets of all stacks. // The final rejected token set is the intersection of the rejected token sets of all stacks. - // Note these indices store the indices in sorted_token_codepoints, instead of the token ids. - tmp_accepted_indices_.clear(); + // Note these indices store the indices in sorted_token_table, instead of the token ids. + tmp_accepted_bitset_.Reset(); // {-1} means the universal set, i.e. all tokens initially tmp_rejected_indices_.assign({-1}); + // std::chrono::microseconds time_unc(0); + // std::chrono::microseconds time_idx(0); + int check_cnt = 0; + for (auto top : latest_stack_tops) { - // Step 1. 
Find the current catagorized_tokens auto cur_rule_position = tree_[top]; - auto current_sequence = grammar_->GetRuleExpr(cur_rule_position.sequence_id); - if (cur_rule_position.parent_id == RulePosition::kNoParent && - cur_rule_position.element_id == current_sequence.size()) { + if (tree_.IsEndPosition(cur_rule_position)) { continue; } - const auto& catagorized_tokens = catagorized_tokens_for_grammar.at( - {cur_rule_position.sequence_id, cur_rule_position.element_id}); + const auto& catagorized_tokens = catagorized_tokens_for_grammar.at(cur_rule_position); + + // auto start = std::chrono::high_resolution_clock::now(); // For each stack, we will check every uncertain token and put them into the accepted or // rejected list. - // If the accepted tokens are saved, it means it is likely to be smaller than the rejected - // tokens, so we will just find the accepted tokens, and vice versa. - bool is_find_accept_mode = - catagorized_tokens.not_saved_index != CatagorizedTokens::NotSavedIndex::kAccepted; - - // If uncertain tokens are saved, we will iterate over the uncertain tokens. - // Otherwise, we will iterate over all_tokens - accepted_tokens - rejected_tokens. - bool is_uncertain_saved = - catagorized_tokens.not_saved_index != CatagorizedTokens::NotSavedIndex::kUncertain; // Step 2. Update the accepted tokens in accepted_indices_delta, or the rejected tokens in // rejected_indices_delta. - // Examine only the current one stack - stack_tops_history_.PushHistory({tree_.NewNode(cur_rule_position)}); - - const std::vector* prev_token = nullptr; - int prev_matched_size = 0; + // If the accepted tokens are saved, it means it is likely to be smaller than the rejected + // tokens, so we will just find the accepted tokens, and vice versa. - tmp_accepted_indices_delta_.clear(); tmp_rejected_indices_delta_.clear(); - if (!is_uncertain_saved) { - // unc_tokens = all_tokens - accepted_tokens - rejected_tokens - tmp_uncertain_tokens_bitset_.assign(sorted_token_codepoints.size(), true); - for (auto idx : catagorized_tokens.accepted_indices) { - tmp_uncertain_tokens_bitset_[idx] = false; - } - for (auto idx : catagorized_tokens.rejected_indices) { - tmp_uncertain_tokens_bitset_[idx] = false; - } - } + // Examine only the current one stack + stack_tops_history_.PushHistory({tree_.NewNode(cur_rule_position)}); - int iterator_uncertain = -1; + const std::string* prev_token = nullptr; + int prev_matched_size = 0; - while (true) { - // Step 2.1. Find the current token. 
- auto idx = - GetNextUncertainToken(is_uncertain_saved, &iterator_uncertain, - catagorized_tokens.uncertain_indices, tmp_uncertain_tokens_bitset_); - if (idx == -1) { - break; - } - const auto& cur_token = sorted_token_codepoints[idx].token; + // std::cout << tree_.PrintNode(top) << std::endl; + + // std::cout << "Accepted count: " << catagorized_tokens.accepted_indices.size() + // << ", rejected count: " << catagorized_tokens.rejected_indices.size() + // << ", uncertain count: " << catagorized_tokens.uncertain_indices.size() + // << ", save type: " << static_cast<int>(catagorized_tokens.save_type) << std::endl; + + // if (catagorized_tokens.accepted_indices.size() < 200) { + // std::cout << "Accepted: "; + // for (int i = 0; i < catagorized_tokens.accepted_indices.size(); ++i) { + // std::cout << "<" + // << PrintAsEscaped( + // sorted_token_table[catagorized_tokens.accepted_indices[i]].second) + // << "> "; + // } + // std::cout << "\n"; + // } + + // if (catagorized_tokens.uncertain_indices.size() > 100) { + // std::cout << "Uncertain: "; + // for (int i = 0; i < catagorized_tokens.uncertain_indices.size(); ++i) { + // std::cout << "<" + // << PrintAsEscaped( + // sorted_token_table[catagorized_tokens.uncertain_indices[i]].second) + // << "> "; + // } + // std::cout << "\n"; + // } + + for (auto cur_token_idx : catagorized_tokens.uncertain_indices) { + const auto& cur_token = sorted_token_table[cur_token_idx].second; + bool accepted = true; - // Step 2.2. Find the longest common prefix with the accepted part of the previous token. + // Step 2.1. Find the longest common prefix with the accepted part of the previous token. // We can reuse the previous matched size to avoid unnecessary matching. - int prev_useful_size = 0; if (prev_token) { - prev_useful_size = std::min(prev_matched_size, static_cast<int>(cur_token.size())); - for (int j = 0; j < prev_useful_size; ++j) { - if (cur_token[j] != (*prev_token)[j]) { - prev_useful_size = j; - break; - } + int lcp_len = std::mismatch(cur_token.begin(), cur_token.end(), prev_token->begin(), + prev_token->end()) + .first - + cur_token.begin(); + if (lcp_len > prev_matched_size) { + accepted = false; + } else if (lcp_len < prev_matched_size) { + RollbackChars(prev_matched_size - lcp_len); } - RollbackCodepoints(prev_matched_size - prev_useful_size); + prev_matched_size = std::min(prev_matched_size, lcp_len); } - // Step 2.3. Find if the current token is accepted or rejected. - bool accepted = true; - prev_matched_size = prev_useful_size; - - for (int j = prev_useful_size; j < cur_token.size(); ++j) { - if (!AcceptCodepoint(cur_token[j], false)) { - accepted = false; - break; + // Step 2.2. Find if the current token is accepted or rejected. + if (accepted) { + for (int j = prev_matched_size; j < cur_token.size(); ++j) { + ++check_cnt; + if (!AcceptChar(cur_token[j], false)) { + accepted = false; + break; + } + prev_matched_size = j + 1; } - prev_matched_size = j + 1; } - // Step 2.4. Push the result to the delta list. - if (accepted && is_find_accept_mode) { - tmp_accepted_indices_delta_.push_back(idx); - } else if (!accepted && !is_find_accept_mode) { - tmp_rejected_indices_delta_.push_back(idx); + // Step 2.3. Push the result to the delta list.
+ if (catagorized_tokens.save_type == SaveType::kAcceptedBitset || + catagorized_tokens.save_type == SaveType::kAccepted) { + if (accepted) { + tmp_accepted_bitset_.Set(sorted_token_table[cur_token_idx].first, true); + } + } else { + if (!accepted) { + tmp_rejected_indices_delta_.push_back(cur_token_idx); + } } prev_token = &cur_token; } - RollbackCodepoints(prev_matched_size + 1); + RollbackChars(prev_matched_size + 1); + + // auto end = std::chrono::high_resolution_clock::now(); + + // time_unc += std::chrono::duration_cast(end - start); + + // start = std::chrono::high_resolution_clock::now(); // Step 3. Update the accepted_indices and rejected_indices - if (is_find_accept_mode) { - // accepted_indices += catagorized_tokens.accepted_indices + accepted_indices_delta - IntsetUnion(&tmp_accepted_indices_delta_, catagorized_tokens.accepted_indices); - IntsetUnion(&tmp_accepted_indices_, tmp_accepted_indices_delta_); + if (catagorized_tokens.save_type == SaveType::kAcceptedBitset) { + tmp_accepted_bitset_ |= catagorized_tokens.accepted_bitset; + } else if (catagorized_tokens.save_type == SaveType::kAccepted) { + for (auto idx : catagorized_tokens.accepted_indices) { + tmp_accepted_bitset_.Set(sorted_token_table[idx].first, true); + } } else { // rejected_indices = Intersect( // rejected_indices, @@ -352,72 +374,81 @@ void GrammarStateMatcherNodeImpl::FindNextTokenBitmask(DLTensor* next_token_bitm IntsetUnion(&tmp_rejected_indices_delta_, catagorized_tokens.rejected_indices); IntsetIntersection(&tmp_rejected_indices_, tmp_rejected_indices_delta_); } + // end = std::chrono::high_resolution_clock::now(); + // time_idx += std::chrono::duration_cast(end - start); } // Finally update the rejected_ids bitset + // auto start = std::chrono::high_resolution_clock::now(); bool can_reach_end = CanReachEnd(); - SetTokenBitmask(next_token_bitmask, tmp_accepted_indices_, tmp_rejected_indices_, can_reach_end); + SetTokenBitmask(next_token_bitmask, tmp_accepted_bitset_, tmp_rejected_indices_, can_reach_end); + // auto end = std::chrono::high_resolution_clock::now(); + // time_idx += std::chrono::duration_cast(end - start); + // std::cout << "Time for uncertain: " << time_unc.count() + // << "us, time for index: " << time_idx.count() << "us" << std::endl; + // std::cout << "Check cnt " << check_cnt << std::endl; } void GrammarStateMatcherNodeImpl::Rollback(int num_tokens) { - CHECK(num_tokens <= token_size_history_.size()) + CHECK(num_tokens <= token_length_history.size()) << "Intended to rollback " << num_tokens << " tokens, but only the last " - << token_size_history_.size() << " steps of history are saved"; + << token_length_history.size() << " steps of history are saved"; while (num_tokens > 0) { - int steps = token_size_history_.back(); - RollbackCodepoints(steps); - token_size_history_.pop_back(); + int steps = token_length_history.back(); + RollbackChars(steps); + token_length_history.pop_back(); --num_tokens; } } void GrammarStateMatcherNodeImpl::SetTokenBitmask(DLTensor* next_token_bitmask, - std::vector& accepted_indices, - std::vector& rejected_indices, + const DynamicBitset& accepted_bitset, + const std::vector& rejected_indices, bool can_reach_end) { - // accepted_ids = Union(accepted_indices, all_tokens - rejected_indices) - // rejected_ids = Intersect(all_tokens - accepted_indices, rejected_indices) + // next_token_bitmask = set(all accepted tokens) = + // 1. all_tokens - (rejected_ids / accepted_ids) + // (when rejected_ids != {-1}, i.e. rejected_ids is not the universal set) + // 2. 
accepted_ids + // (otherwise, when rejected_ids is the universal set) CHECK(next_token_bitmask->dtype.code == kDLUInt && next_token_bitmask->dtype.bits == 32 && next_token_bitmask->data && next_token_bitmask->ndim == 1 && next_token_bitmask->shape) << "The provied bitmask's shape or dtype is not valid."; + CHECK(next_token_bitmask->shape[0] >= DynamicBitset::CalculateBufferSize(init_ctx_->vocab_size)) + << "The provided bitmask is not large enough to store the token set. The length should be " + << DynamicBitset::CalculateBufferSize(init_ctx_->vocab_size) << " at least"; - BitsetManager next_token_bitset(reinterpret_cast(next_token_bitmask->data), - next_token_bitmask->shape[0], init_ctx_->vocab_size); + DynamicBitset next_token_bitset(init_ctx_->vocab_size, + reinterpret_cast(next_token_bitmask->data)); + const auto& sorted_token_table = init_ctx_->sorted_token_table; if (rejected_indices.size() == 1 && rejected_indices[0] == -1) { // If rejected_indices is the universal set, the final accepted token set is just // accepted_indices - next_token_bitset.Reset(false); - for (int idx : accepted_indices) { - next_token_bitset.Set(init_ctx_->sorted_token_codepoints[idx].id, true); - } + next_token_bitset = accepted_bitset; if (can_reach_end) { // add end tokens - for (int idx : init_ctx_->stop_token_ids) { - next_token_bitset.Set(idx, true); + for (int id : init_ctx_->stop_token_ids) { + next_token_bitset.Set(id, true); } } } else { // Otherwise, the final rejected token set is (rejected_indices \ accepted_indices) - next_token_bitset.Reset(true); + next_token_bitset.Set(); - auto it_acc = accepted_indices.begin(); for (auto i : rejected_indices) { - while (it_acc != accepted_indices.end() && *it_acc < i) { - ++it_acc; - } - if (it_acc == accepted_indices.end() || *it_acc != i) { - next_token_bitset.Set(init_ctx_->sorted_token_codepoints[i].id, false); + auto id = sorted_token_table[i].first; + if (!accepted_bitset[id]) { + next_token_bitset.Set(id, false); } } - for (int idx : init_ctx_->special_token_ids) { - next_token_bitset.Set(idx, false); + for (int id : init_ctx_->special_token_ids) { + next_token_bitset.Set(id, false); } if (!can_reach_end) { - for (int idx : init_ctx_->stop_token_ids) { - next_token_bitset.Set(idx, false); + for (int id : init_ctx_->stop_token_ids) { + next_token_bitset.Set(id, false); } } } @@ -452,16 +483,24 @@ GrammarStateMatcher::GrammarStateMatcher(std::shared_ptr tokenizer, int max_rollback_steps) { + .set_body_typed([](BNFGrammar grammar, Optional tokenizer, int max_rollback_steps, + String token_table_postproc_method) { auto preproc_start = std::chrono::high_resolution_clock::now(); - auto init_ctx = GrammarStateMatcher::CreateInitContext( - grammar, tokenizer ? 
tokenizer.value()->TokenTable() : std::vector()); + std::shared_ptr init_ctx; + if (tokenizer) { + auto token_table = Tokenizer::PostProcessTokenTable(tokenizer.value()->TokenTable(), + token_table_postproc_method); + init_ctx = GrammarStateMatcher::CreateInitContext(grammar, token_table); + } else { + init_ctx = GrammarStateMatcher::CreateInitContext(grammar, {}); + } + auto preproc_end = std::chrono::high_resolution_clock::now(); - std::cerr << "Preprocess takes " + LOG(INFO) << "GrammarStateMatcher preprocess takes " << std::chrono::duration_cast(preproc_end - preproc_start) .count() - << "us" << std::endl; + << "us"; return GrammarStateMatcher(init_ctx, max_rollback_steps); }); #endif @@ -479,11 +518,11 @@ TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherFromTokenTable") *rv = GrammarStateMatcher(init_ctx, max_rollback_steps); }); -TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherDebugAcceptCodepoint") - .set_body_typed([](GrammarStateMatcher matcher, int32_t codepoint) { +TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherDebugAcceptChar") + .set_body_typed([](GrammarStateMatcher matcher, int32_t codepoint, bool verbose) { auto mutable_node = const_cast(matcher.as()); - return mutable_node->AcceptCodepoint(codepoint); + return mutable_node->AcceptChar(codepoint, verbose); }); TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherAcceptToken") @@ -507,32 +546,43 @@ TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherResetState") /*! \brief Check if a matcher can accept the complete string, and then reach the end of the * grammar. Does not change the state of the GrammarStateMatcher. For test purpose. */ -bool MatchCompleteString(GrammarStateMatcher matcher, String str) { +bool MatchCompleteString(GrammarStateMatcher matcher, String str, bool verbose) { auto mutable_node = const_cast(matcher.as()); - auto codepoints = ParseUTF8(str.c_str()); int accepted_cnt = 0; - for (auto codepoint : codepoints) { - if (!mutable_node->AcceptCodepoint(codepoint, false)) { - mutable_node->RollbackCodepoints(accepted_cnt); + for (auto char_value : str.operator std::string()) { + if (!mutable_node->AcceptChar(char_value, verbose)) { + if (verbose) { + LOG(INFO) << "Matching failed after accepting " << accepted_cnt << " characters"; + } + mutable_node->RollbackChars(accepted_cnt); return false; } ++accepted_cnt; } auto accepted = mutable_node->CanReachEnd(); - mutable_node->RollbackCodepoints(accepted_cnt); + if (verbose) { + if (accepted) { + LOG(INFO) << "Matching succeed after accepting " << accepted_cnt << " characters"; + } else { + LOG(INFO) << "Matching failed due to the end state not reached after all " << accepted_cnt + << " characters are accepted"; + } + } + mutable_node->RollbackChars(accepted_cnt); return accepted; } TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherDebugMatchCompleteString") - .set_body_typed([](GrammarStateMatcher matcher, String str) { - return MatchCompleteString(matcher, str); + .set_body_typed([](GrammarStateMatcher matcher, String str, bool verbose) { + return MatchCompleteString(matcher, str, verbose); }); /*! \brief Print the accepted and rejected tokens stored in the bitset. For debug purposes. 
*/ -void PrintAcceptedRejectedTokens( +std::string PrintAcceptedRejectedTokens( const std::shared_ptr& init_ctx, - const BitsetManager& bitset, int threshold = 500) { + const DynamicBitset& bitset, int threshold = 300) { + std::stringstream ss; auto vocab_size = init_ctx->vocab_size; std::vector accepted_ids; std::vector rejected_ids; @@ -544,42 +594,27 @@ void PrintAcceptedRejectedTokens( } } - if (accepted_ids.size() < threshold) { - std::cerr << "Accepted: "; - for (auto id : accepted_ids) { - std::cerr << "<"; - auto token = init_ctx->token_table[id]; - if (token.size() == 1 && (static_cast(token[0]) >= 128 || token[0] == 0)) { - // First cast to unsigned, then cast to int - std::cerr << static_cast(static_cast(token[0])); - } else { - auto codepoints = ParseUTF8(token.c_str()); - for (auto c : codepoints) { - std::cerr << PrintAsEscaped(c); - } - } - std::cerr << "> "; - } - std::cerr << "\n"; + ss << "Accepted: "; + auto end_it = + accepted_ids.size() > threshold ? accepted_ids.begin() + threshold : accepted_ids.end(); + for (auto it = accepted_ids.begin(); it != end_it; ++it) { + ss << "<" << PrintAsEscaped(init_ctx->token_table[*it]) << "> "; + } + if (accepted_ids.size() > threshold) { + ss << "..."; } + ss << "\n"; - if (rejected_ids.size() < threshold) { - std::cerr << "Rejected: "; - for (auto id : rejected_ids) { - std::cerr << "<"; - auto token = init_ctx->token_table[id]; - if (token.size() == 1 && ((unsigned char)token[0] >= 128 || token[0] == 0)) { - std::cerr << (int)(unsigned char)token[0]; - } else { - auto codepoints = ParseUTF8(token.c_str()); - for (auto c : codepoints) { - std::cerr << PrintAsEscaped(c); - } - } - std::cerr << "> "; - } - std::cerr << "\n"; + ss << "Rejected: "; + end_it = rejected_ids.size() > threshold ? rejected_ids.begin() + threshold : rejected_ids.end(); + for (auto it = rejected_ids.begin(); it != end_it; ++it) { + ss << "<" << PrintAsEscaped(init_ctx->token_table[*it]) << "> "; + } + if (rejected_ids.size() > threshold) { + ss << "..."; } + ss << "\n"; + return ss.str(); } /*! 
@@ -591,7 +626,7 @@ void PrintAcceptedRejectedTokens( IntTuple FindNextRejectedTokens(GrammarStateMatcher matcher, bool verbose = false) { auto init_ctx = matcher.as()->init_ctx_; auto vocab_size = init_ctx->vocab_size; - auto bitset_size = BitsetManager::CalculateBufferSize(vocab_size); + auto bitset_size = DynamicBitset::CalculateBufferSize(vocab_size); auto ndarray = NDArray::Empty(ShapeTuple{static_cast(bitset_size)}, DLDataType{kDLUInt, 32, 1}, DLDevice{kDLCPU, 0}); auto dltensor = const_cast(ndarray.operator->()); @@ -605,7 +640,7 @@ IntTuple FindNextRejectedTokens(GrammarStateMatcher matcher, bool verbose = fals end = std::chrono::high_resolution_clock::now(); } - auto bitset = BitsetManager(reinterpret_cast(dltensor->data), bitset_size, vocab_size); + auto bitset = DynamicBitset(vocab_size, reinterpret_cast(dltensor->data)); std::vector rejected_ids; for (int i = 0; i < vocab_size; i++) { if (bitset[i] == 0) { @@ -614,10 +649,10 @@ IntTuple FindNextRejectedTokens(GrammarStateMatcher matcher, bool verbose = fals } if (verbose) { - std::cerr << "FindNextTokenBitmask takes " + LOG(INFO) << "FindNextTokenBitmask takes " << std::chrono::duration_cast(end - start).count() << "us" << ", found accepted: " << vocab_size - rejected_ids.size() - << ", rejected: " << rejected_ids.size() << std::endl; + << ", rejected: " << rejected_ids.size(); } auto ret = IntTuple(rejected_ids); @@ -634,7 +669,7 @@ TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherFindNextRejectedTokens") NDArray FindNextTokenBitmaskAsNDArray(GrammarStateMatcher matcher) { auto init_ctx = matcher.as()->init_ctx_; auto vocab_size = init_ctx->vocab_size; - auto bitset_size = BitsetManager::CalculateBufferSize(vocab_size); + auto bitset_size = DynamicBitset::CalculateBufferSize(vocab_size); auto bitmask = NDArray::Empty(ShapeTuple{static_cast(bitset_size)}, DLDataType{kDLUInt, 32, 1}, DLDevice{kDLCPU, 0}); auto dltensor = const_cast(bitmask.operator->()); diff --git a/cpp/serve/grammar/grammar_state_matcher.h b/cpp/serve/grammar/grammar_state_matcher.h index eceaa75d07..eedf7a1989 100644 --- a/cpp/serve/grammar/grammar_state_matcher.h +++ b/cpp/serve/grammar/grammar_state_matcher.h @@ -130,14 +130,13 @@ class GrammarStateMatcher : public ObjectRef { }; /*! - * \brief Helper class to get the grammar state init context for grammars or schemas. This class - * maintains cache internally, so the same grammar or schema will not be preprocessed multiple - * times. + * \brief A cache to get the grammar state init context for grammar or schema. This class avoids + * redundant preprocessing of the grammar or schema when constructing a GrammarStateInitContext. * \note This class is associated with a token table when constructed. The token table is used to * create every grammar state init context. If multiple toke tables are used to create init * contexts, an instance of this class for each token table should be created. */ -class GrammarInitContextStorageNode : public Object { +class GrammarInitContextCacheNode : public Object { public: /*! \brief Get the init context for pure JSON. */ virtual std::shared_ptr GetInitContextForJSON() = 0; @@ -147,25 +146,25 @@ class GrammarInitContextStorageNode : public Object { const std::string& schema) = 0; /*! \brief Clear the interal cache of init contexts. 
*/ - virtual void ClearCache() = 0; + virtual void Clear() = 0; - static constexpr const char* _type_key = "mlc.serve.GrammarInitContextStorageNode"; + static constexpr const char* _type_key = "mlc.serve.GrammarInitContextCacheNode"; static constexpr const bool _type_has_method_sequal_reduce = false; static constexpr const bool _type_has_method_shash_reduce = false; - TVM_DECLARE_BASE_OBJECT_INFO(GrammarInitContextStorageNode, Object); + TVM_DECLARE_BASE_OBJECT_INFO(GrammarInitContextCacheNode, Object); }; -class GrammarInitContextStorage : public ObjectRef { +class GrammarInitContextCache : public ObjectRef { public: /*! - * \brief Construct a GrammarInitContextStorage with a token table. This class will always create + * \brief Construct a GrammarInitContextCache with a token table. This class will always create * grammar state init contexts with this token table. * \param token_table The token table that the grammar will use. */ - GrammarInitContextStorage(const std::vector& token_table); + GrammarInitContextCache(const std::vector& token_table); - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(GrammarInitContextStorage, ObjectRef, - GrammarInitContextStorageNode); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(GrammarInitContextCache, ObjectRef, + GrammarInitContextCacheNode); }; } // namespace serve diff --git a/cpp/serve/grammar/grammar_state_matcher_base.h b/cpp/serve/grammar/grammar_state_matcher_base.h index 5b774d33a4..1241e7307a 100644 --- a/cpp/serve/grammar/grammar_state_matcher_base.h +++ b/cpp/serve/grammar/grammar_state_matcher_base.h @@ -32,95 +32,172 @@ class GrammarStateMatcherBase { * \param grammar The grammar to match. * \param init_rule_position The initial rule position. If not specified, the main rule will be * used. + * \param expand_init_rule_position Whether to expand the initial rule position to all possible + * locations. See ExpandRulePosition. */ - GrammarStateMatcherBase(const BNFGrammar& grammar, RulePosition init_rule_position = {}) + GrammarStateMatcherBase(const BNFGrammar& grammar, + RulePosition init_rule_position = kInvalidRulePosition, + bool expand_init_rule_position = true) : grammar_(grammar), tree_(grammar), stack_tops_history_(&tree_) { - InitStackState(init_rule_position); + PushInitialState(init_rule_position, expand_init_rule_position); } - /*! \brief Accept one codepoint. */ - bool AcceptCodepoint(TCodepoint codepoint, bool verbose = false); + /*! \brief Accept one character. */ + bool AcceptChar(uint8_t char_value, bool verbose = false); /*! \brief Check if the end of the main rule is reached. If so, the stop token can be accepted. */ bool CanReachEnd() const; - /*! \brief Rollback the matcher to a previous state. */ - void RollbackCodepoints(int rollback_codepoint_cnt); + /*! \brief Rollback the matcher to a previous state by the number of characters. */ + void RollbackChars(int rollback_cnt); - /*! \brief Discard the earliest history. */ - void DiscardEarliestCodepoints(int discard_codepoint_cnt); + /*! \brief Discard the earliest history by the number of characters. */ + void DiscardEarliestChars(int discard_cnt); /*! \brief Print the stack state. */ std::string PrintStackState(int steps_behind_latest = 0) const; protected: - // Init the stack state according to the given rule position. - // If init_rule_position is {}, init the stack with the main rule. - void InitStackState(RulePosition init_rule_position = {}); + // Push an initial stack state according to the given rule position. 
+ // If init_rule_position is kInvalidRulePosition, init the stack with the main rule. + void PushInitialState(RulePosition init_rule_position, bool expand_init_rule_position); - // Update the char_class_star_id field of the given rule_position, if it refers to a character - // class star rule. - void UpdateCharClassStarId(RulePosition* rule_position) const; + // Check if the character is accepted by the current rule position. + bool CheckIfAccepted(const RulePosition& rule_position, uint8_t char_value) const; /*! * \brief Find the next position in the rule. If the next position is at the end of the rule, - * the result depends on the consider_parent parameter: - * - false: kInvalidRulePosition will be returned. - * - true: the next position of the parent rule will be returned. If the current rule is the root - * rule, the RulePosition will be returned as is to indicate the end of the grammar. + * and consider_parent is true, will iteratively find the next position in the parent rule. * \param rule_position The current position. - * \param consider_parent Whether to consider the parent position if the current position is at - * the end of the rule. + * \param consider_parent Whether to consider the parent position if the current position is + * at the end of the rule. + * \returns (success, next_rule_position), indicating if the iteration is successful and the + * next rule position. */ - RulePosition IterateToNextPosition(const RulePosition& rule_position, bool consider_parent) const; + std::pair GetNextPositionInSequence(const RulePosition& rule_position, + bool consider_parent) const; + + // Return the updated rule position after accepting the char + RulePosition UpdatePositionWithChar(const RulePosition& rule_position, uint8_t char_value) const; /*! - * \brief Expand the given rule position (may be a RuleRef element) s.t. every new position is a - * CharacterClass or refers to a CharacterClassStar rule. Push all new positions into - * new_stack_tops. - * \details This method will start from cur_rule_position and continuously iterate to the next - * position as long as the current position can be empty (e.g. the current position is a - * reference to an rule that can be empty, or to a character class star rule). If the current - * position can not be empty, stop expanding. All positions collected will be pushed into - * new_stack_tops. + * \brief Expand the given rule position to all possible positions approachable in the grammar. + * The expanded positions must refers to an element (CharacterClass or CharacterClassStar or + * ByteString) in a rule. Push all new positions into new_stack_tops. + * \example + * A ::= "a" B [a-z]* "c" + * B ::= "b" | "" * - * If the end of the current rule is reached: - * - If is_outmost_level is true, we can go to the next position in the parent rule. - * - Otherwise, stop iteration. + * Input position: (rule=A, position=B) + * Approachable positions: (rule=B, position="b"), (rule=A, position=[a-z]*), + * (rule=A, position="c"), since B and [a-z]* can be empty. * \param cur_rule_position The current rule position. * \param new_stack_tops The vector to store the new stack tops. - * \param is_outmost_level Whether the current position is the outmost level of the rule. - * \param first_id_if_inserted Being not -1 means the first node is already inserted. This is the - * id of the first node. This is used to avoid inserting the same node twice. - * \return Whether the end of the rule can be reached. Used as the condition of recursion. 
+ * \param consider_parent Whether consider expanding the elements in the parent rule. Useful for + * inner recursion. + * \param first_id_if_inserted An optimization. When cur_rule_position is already inserted to + * the state tree, pass its id to avoid inserting it again. -1 (ignore it) by default. + * \return Whether the end of the rule can be reached. Useful for inner recursion. */ bool ExpandRulePosition(RulePosition cur_rule_position, std::vector* new_stack_tops, - bool is_outmost_level, int32_t first_id_if_inserted = -1); + bool consider_parent = true, int32_t first_id_if_inserted = -1); + // The matched grammar. BNFGrammar grammar_; + // The tree storing all states RulePositionTree tree_; + // The tracked history of stack tops (each stack top refers to a node in the tree). + // We store the stack tops in different steps in the history to support rollback. StackTopsHistory stack_tops_history_; - // Temporary data for AcceptCodepoint. + // Temporary data for AcceptChar. std::vector tmp_new_stack_tops_; }; /*! \brief Check the codepoint is contained in the character class. */ -inline bool CharacterClassContains(const BNFGrammarNode::RuleExpr& rule_expr, - TCodepoint codepoint) { - DCHECK(rule_expr.type == BNFGrammarNode::RuleExprType::kCharacterClass || - rule_expr.type == BNFGrammarNode::RuleExprType::kNegCharacterClass); - for (int i = 0; i < rule_expr.size(); i += 2) { - if (rule_expr.data[i] <= codepoint && codepoint <= rule_expr.data[i + 1]) { - return rule_expr.type == BNFGrammarNode::RuleExprType::kCharacterClass; +inline bool GrammarStateMatcherBase::CheckIfAccepted(const RulePosition& rule_position, + uint8_t char_value) const { + auto current_sequence = grammar_->GetRuleExpr(rule_position.sequence_id); + auto current_element = grammar_->GetRuleExpr(current_sequence[rule_position.element_id]); + if (current_element.type == RuleExprType::kCharacterClass || + current_element.type == RuleExprType::kCharacterClassStar) { + if (rule_position.left_utf8_bytes > 0) { + return (char_value & 0xC0) == 0x80; + } + auto [accepted, num_bytes, codepoint] = HandleUTF8FirstByte(char_value); + if (!accepted) { + return false; + } + bool is_negative = static_cast(current_element[0]); + if (num_bytes > 1) { + return is_negative; + } + for (int i = 1; i < current_element.size(); i += 2) { + if (current_element[i] <= char_value && char_value <= current_element[i + 1]) { + return !is_negative; + } + } + return is_negative; + } else if (current_element.type == RuleExprType::kByteString) { + return current_element[rule_position.element_in_string] == char_value; + } else { + LOG(FATAL) << "Unexpected RuleExprType in CheckIfAccepted: " + << static_cast(current_element.type); + } +} + +inline RulePosition GrammarStateMatcherBase::UpdatePositionWithChar( + const RulePosition& rule_position, uint8_t char_value) const { + auto current_sequence = grammar_->GetRuleExpr(rule_position.sequence_id); + auto current_element = grammar_->GetRuleExpr(current_sequence[rule_position.element_id]); + RulePosition new_rule_position = rule_position; + switch (current_element.type) { + case RuleExprType::kCharacterClass: { + if (rule_position.left_utf8_bytes > 1) { + new_rule_position.left_utf8_bytes -= 1; + return new_rule_position; + } else if (rule_position.left_utf8_bytes == 1) { + return GetNextPositionInSequence(rule_position, true).second; + } + // If no left utf8 bytes, check the first byte to find the left bytes needed. 
+ DCHECK(rule_position.left_utf8_bytes == 0); + auto [accepted, num_bytes, codepoint] = HandleUTF8FirstByte(char_value); + DCHECK(accepted); + if (num_bytes > 1) { + new_rule_position.left_utf8_bytes = num_bytes - 1; + return new_rule_position; + } + return GetNextPositionInSequence(rule_position, true).second; + } + case RuleExprType::kCharacterClassStar: { + if (rule_position.left_utf8_bytes >= 1) { + new_rule_position.left_utf8_bytes -= 1; + } else { + DCHECK(rule_position.left_utf8_bytes == 0); + auto [accepted, num_bytes, codepoint] = HandleUTF8FirstByte(char_value); + DCHECK(accepted); + new_rule_position.left_utf8_bytes = num_bytes - 1; + } + return new_rule_position; + } + case RuleExprType::kByteString: { + if (rule_position.element_in_string + 1 < current_element.size()) { + new_rule_position.element_in_string += 1; + return new_rule_position; + } + return GetNextPositionInSequence(rule_position, true).second; } + default: + LOG(FATAL) << "Unexpected RuleExprType in UpdatePositionWithChar: " + << static_cast(current_element.type); } - return rule_expr.type == BNFGrammarNode::RuleExprType::kNegCharacterClass; } -inline bool GrammarStateMatcherBase::AcceptCodepoint(TCodepoint codepoint, bool verbose) { +inline bool GrammarStateMatcherBase::AcceptChar(uint8_t char_value, bool verbose) { if (verbose) { - std::cout << "Stack before accepting: " << PrintStackState() << std::endl; + LOG(INFO) << "Matching char: " << static_cast(char_value) << " \"" + << PrintAsEscaped(char_value) << "\""; + LOG(INFO) << "Previous stack: " << PrintStackState(); } const auto& prev_stack_tops = stack_tops_history_.GetLatest(); @@ -135,37 +212,31 @@ inline bool GrammarStateMatcherBase::AcceptCodepoint(TCodepoint codepoint, bool continue; } - auto current_char_class = - cur_rule_position.char_class_star_id != -1 - ? 
grammar_->GetRuleExpr(cur_rule_position.char_class_star_id) - : grammar_->GetRuleExpr(current_sequence[cur_rule_position.element_id]); - DCHECK(current_char_class.type == RuleExprType::kCharacterClass || - current_char_class.type == RuleExprType::kNegCharacterClass); - auto ok = CharacterClassContains(current_char_class, codepoint); - if (!ok) { + auto accepted = CheckIfAccepted(cur_rule_position, char_value); + if (!accepted) { continue; } - if (cur_rule_position.char_class_star_id == -1) { - auto next_rule_position = IterateToNextPosition(cur_rule_position, true); - DCHECK(next_rule_position != kInvalidRulePosition); - ExpandRulePosition(next_rule_position, &tmp_new_stack_tops_, true); + auto new_rule_position = UpdatePositionWithChar(cur_rule_position, char_value); + + if (new_rule_position == cur_rule_position) { + ExpandRulePosition(new_rule_position, &tmp_new_stack_tops_, true, prev_top); } else { - ExpandRulePosition(cur_rule_position, &tmp_new_stack_tops_, true, prev_top); + ExpandRulePosition(new_rule_position, &tmp_new_stack_tops_, true); } } if (tmp_new_stack_tops_.empty()) { if (verbose) { - std::cout << "Codepoint: " << codepoint << " \"" << PrintAsEscaped(codepoint) << "\" Rejected" - << std::endl; + LOG(INFO) << "Character " << static_cast(char_value) << " \"" + << PrintAsEscaped(char_value) << "\" Rejected"; } return false; } stack_tops_history_.PushHistory(tmp_new_stack_tops_); if (verbose) { - std::cout << "Codepoint: " << codepoint << " \"" << PrintAsEscaped(codepoint) << "\" Accepted" - << std::endl; - std::cout << "Stack after accepting: " << PrintStackState() << std::endl; + LOG(INFO) << "Character: " << static_cast(char_value) << " \"" + << PrintAsEscaped(char_value) << "\" Accepted"; + LOG(INFO) << "New stack after acceptance: " << PrintStackState(); } #if TVM_LOG_DEBUG stack_tops_history_.CheckWellFormed(); @@ -179,80 +250,92 @@ inline bool GrammarStateMatcherBase::CanReachEnd() const { [&](int32_t id) { return tree_.IsEndPosition(tree_[id]); }); } -inline void GrammarStateMatcherBase::RollbackCodepoints(int rollback_codepoint_cnt) { - stack_tops_history_.Rollback(rollback_codepoint_cnt); +inline void GrammarStateMatcherBase::RollbackChars(int rollback_cnt) { + stack_tops_history_.Rollback(rollback_cnt); } -inline void GrammarStateMatcherBase::DiscardEarliestCodepoints(int discard_codepoint_cnt) { - stack_tops_history_.DiscardEarliest(discard_codepoint_cnt); +inline void GrammarStateMatcherBase::DiscardEarliestChars(int discard_cnt) { + stack_tops_history_.DiscardEarliest(discard_cnt); } inline std::string GrammarStateMatcherBase::PrintStackState(int steps_behind_latest) const { return stack_tops_history_.PrintHistory(steps_behind_latest); } -inline void GrammarStateMatcherBase::InitStackState(RulePosition init_rule_position) { +inline void GrammarStateMatcherBase::PushInitialState(RulePosition init_rule_position, + bool expand_init_rule_position) { if (init_rule_position == kInvalidRulePosition) { // Initialize the stack with the main rule. 
auto main_rule = grammar_->GetMainRule(); auto main_rule_body = grammar_->GetRuleExpr(main_rule.body_expr_id); - std::vector new_stack_tops; + std::vector stack_tops; for (auto i : main_rule_body) { auto init_rule_position = RulePosition(0, i, 0, RulePosition::kNoParent); - UpdateCharClassStarId(&init_rule_position); - ExpandRulePosition(init_rule_position, &new_stack_tops, true); + if (expand_init_rule_position) { + ExpandRulePosition(init_rule_position, &stack_tops, true); + } else { + stack_tops.push_back(tree_.NewNode(init_rule_position)); + } } - stack_tops_history_.PushHistory(new_stack_tops); + stack_tops_history_.PushHistory(stack_tops); } else { - stack_tops_history_.PushHistory({tree_.NewNode(init_rule_position)}); - } -} - -inline void GrammarStateMatcherBase::UpdateCharClassStarId(RulePosition* rule_position) const { - auto rule_expr = grammar_->GetRuleExpr(rule_position->sequence_id); - auto element = grammar_->GetRuleExpr(rule_expr[rule_position->element_id]); - if (element.type == RuleExprType::kRuleRef) { - auto sub_rule_body = grammar_->GetRuleExpr(grammar_->GetRule(element[0]).body_expr_id); - if (sub_rule_body.type == RuleExprType::kCharacterClassStar) { - rule_position->char_class_star_id = sub_rule_body[0]; + if (expand_init_rule_position) { + std::vector stack_tops; + ExpandRulePosition(init_rule_position, &stack_tops, true); + stack_tops_history_.PushHistory(stack_tops); + } else { + stack_tops_history_.PushHistory({tree_.NewNode(init_rule_position)}); } } } -inline RulePosition GrammarStateMatcherBase::IterateToNextPosition( +inline std::pair GrammarStateMatcherBase::GetNextPositionInSequence( const RulePosition& rule_position, bool consider_parent) const { - auto next_position = RulePosition(rule_position.rule_id, rule_position.sequence_id, - rule_position.element_id + 1, rule_position.parent_id); - auto rule_expr = grammar_->GetRuleExpr(rule_position.sequence_id); - auto current_sequence_length = rule_expr.size(); - DCHECK(next_position.element_id <= current_sequence_length); - - if (next_position.element_id < current_sequence_length) { - // Update char_class_star_id if the position refers to a character class star rule. 
- UpdateCharClassStarId(&next_position); - return next_position; + auto sequence = grammar_->GetRuleExpr(rule_position.sequence_id); + + auto next_position = rule_position; + next_position.element_id += 1; + next_position.element_in_string = 0; + next_position.left_utf8_bytes = 0; + + DCHECK(next_position.element_id <= sequence.size()); + + if (next_position.element_id < sequence.size()) { + return {true, next_position}; } if (!consider_parent) { - return kInvalidRulePosition; + return {false, kInvalidRulePosition}; } - if (next_position.parent_id == RulePosition::kNoParent) { - return next_position; - } else { - auto parent_rule_position = tree_[next_position.parent_id]; - return IterateToNextPosition(parent_rule_position, true); + // Find the next position in the parent rule + while (next_position.parent_id != RulePosition::kNoParent) { + next_position = tree_[next_position.parent_id]; + next_position.element_id += 1; + DCHECK(next_position.element_in_string == 0); + DCHECK(next_position.left_utf8_bytes == 0); + + sequence = grammar_->GetRuleExpr(next_position.sequence_id); + DCHECK(next_position.element_id <= sequence.size()); + + if (next_position.element_id < sequence.size()) { + break; + } } + + return {true, next_position}; } inline bool GrammarStateMatcherBase::ExpandRulePosition(RulePosition cur_rule_position, std::vector* new_stack_tops, - bool is_outmost_level, + bool consider_parent, int32_t first_id_if_inserted) { bool is_first = false; + bool is_iteration_successful = true; - for (; cur_rule_position != kInvalidRulePosition; - cur_rule_position = IterateToNextPosition(cur_rule_position, is_outmost_level)) { + for (; is_iteration_successful; + std::tie(is_iteration_successful, cur_rule_position) = + GetNextPositionInSequence(cur_rule_position, consider_parent)) { // Insert the node to the tree, if not inserted before. int32_t new_node_id; if (is_first && first_id_if_inserted != -1) { @@ -263,7 +346,7 @@ inline bool GrammarStateMatcherBase::ExpandRulePosition(RulePosition cur_rule_po is_first = false; // Case 1. The current position points to the end of the grammar. - if (is_outmost_level) { + if (consider_parent) { if (tree_.IsEndPosition(cur_rule_position)) { new_stack_tops->push_back(new_node_id); return true; @@ -272,42 +355,39 @@ inline bool GrammarStateMatcherBase::ExpandRulePosition(RulePosition cur_rule_po DCHECK(!tree_.IsEndPosition(cur_rule_position)); } - // Case 2. The current position refers to a character class star rule. It can be empty. - if (cur_rule_position.char_class_star_id != -1) { - new_stack_tops->push_back(new_node_id); - continue; - } - - // Case 3. Character class: cannot be empty. auto sequence = grammar_->GetRuleExpr(cur_rule_position.sequence_id); auto element = grammar_->GetRuleExpr(sequence[cur_rule_position.element_id]); - if (element.type == RuleExprType::kCharacterClass || - element.type == RuleExprType::kNegCharacterClass) { - new_stack_tops->push_back(new_node_id); - return false; - } - - // Case 4. The current position refers to a normal rule, i.e. a rule of choices of sequences. 
- DCHECK(element.type == RuleExprType::kRuleRef); - auto sub_rule_id = element[0]; - auto sub_rule = grammar_->GetRule(sub_rule_id); - auto sub_rule_body = grammar_->GetRuleExpr(sub_rule.body_expr_id); - DCHECK(sub_rule_body.type == RuleExprType::kChoices); - - bool contain_empty = false; - - for (auto sequence_id : sub_rule_body) { - auto sequence = grammar_->GetRuleExpr(sequence_id); - if (sequence.type == RuleExprType::kEmptyStr) { - contain_empty = true; - continue; + bool can_be_empty = false; + + if (element.type == RuleExprType::kRuleRef) { + // Case 2. The current position refers to another rule. + auto ref_rule = grammar_->GetRule(element[0]); + auto ref_rule_body = grammar_->GetRuleExpr(ref_rule.body_expr_id); + DCHECK(ref_rule_body.type == RuleExprType::kChoices); + + for (auto sequence_id : ref_rule_body) { + auto ref_rule_sequence = grammar_->GetRuleExpr(sequence_id); + if (ref_rule_sequence.type == RuleExprType::kEmptyStr) { + can_be_empty = true; + continue; + } + auto ref_rule_position = RulePosition(element[0], sequence_id, 0, new_node_id); + // Find the positions in every choice of the referred rule + can_be_empty |= ExpandRulePosition(ref_rule_position, new_stack_tops, false); } - auto sub_rule_position = RulePosition(sub_rule_id, sequence_id, 0, new_node_id); - UpdateCharClassStarId(&sub_rule_position); - contain_empty |= ExpandRulePosition(sub_rule_position, new_stack_tops, false); + } else if (element.type == RuleExprType::kCharacterClass || + element.type == RuleExprType::kByteString) { + // Case 3. Character class or byte string. cannot be empty. + new_stack_tops->push_back(new_node_id); + can_be_empty = false; + } else { + DCHECK(element.type == RuleExprType::kCharacterClassStar); + // Case 4. Character class star. Might be empty. + new_stack_tops->push_back(new_node_id); + can_be_empty = cur_rule_position.left_utf8_bytes == 0; } - if (!contain_empty) { + if (!can_be_empty) { return false; } } diff --git a/cpp/serve/grammar/grammar_state_matcher_preproc.h b/cpp/serve/grammar/grammar_state_matcher_preproc.h index f63eee2c5c..dc9fb9646e 100644 --- a/cpp/serve/grammar/grammar_state_matcher_preproc.h +++ b/cpp/serve/grammar/grammar_state_matcher_preproc.h @@ -9,6 +9,7 @@ #include #include "../../support/encoding.h" +#include "../../support/utils.h" #include "grammar.h" #include "grammar_state_matcher_base.h" @@ -18,34 +19,47 @@ namespace serve { using namespace tvm::runtime; -/*! \brief A token and its id. */ -struct TokenAndId { - std::vector token; - int32_t id; - /*! \brief Compare tokens by their unicode codepoint sequence. */ - bool operator<(const TokenAndId& other) const; -}; - /*! - * \brief Preprocessed information, for a given specific rule and position, divides the token set + * \brief Preprocessed information, for a given specific RulePosition, divides the token set * into three categories: accepted, rejected, and uncertain. - * \note Since the union of these three sets is the whole token set, we only need to store the - * smaller two sets. The unsaved set is specified by not_saved_index. - * \note These indices are the indices of sorted_token_codepoints in the GrammarStateInitContext + * Accepted: tokens that can be determined by the current RulePosition to be acceptable + * Rejected: tokens that can be determined by the current RulePosition to be unacceptable + * Uncertain: tokens that need the state of the parent RulePositions to determine if acceptable + * + * \note uncertain indices are stored directly. 
Accepted / rejected indices have three ways to + * store to reduce memory and computation usage. See SaveType. + * \note These indices are the indices of sorted_token_table in the GrammarStateInitContext * object, instead of the token ids. That helps the matching process. */ struct CatagorizedTokens { + enum class SaveType { + // Only store all accepted token indices. Then rejected indices = all_indices - accepted_indices + // - uncertain_indices. This is useful when |accepted_indices| < |rejected_indices|. + kAccepted = 0, + // Only store all accepted token indices. Then accepted indices = all_indices - rejected_indices + // - uncertain_indices. This is useful when |accepted_indices| > |rejected_indices|. + kRejected = 1, + // Store all accepted token indices in a bitset. This is useful when both |accepted_indices| and + // |rejected_indices| are large. + kAcceptedBitset = 2 + }; + SaveType save_type; + + static constexpr int USE_BITSET_THRESHOLD = 200; + std::vector accepted_indices; std::vector rejected_indices; + DynamicBitset accepted_bitset; + std::vector uncertain_indices; - enum class NotSavedIndex { kAccepted = 0, kRejected = 1, kUncertain = 2 }; - NotSavedIndex not_saved_index; CatagorizedTokens() = default; - CatagorizedTokens(std::vector&& accepted_indices, - std::vector&& rejected_indices, - std::vector&& uncertain_indices); + CatagorizedTokens(int vocab_size, + const std::vector>& sorted_token_table, + const std::vector& accepted_indices, + const std::vector& rejected_indices, + const std::vector& uncertain_indices); }; /*! @@ -57,189 +71,227 @@ class GrammarStateInitContext { public: /******************* Information about the tokenizer *******************/ - /*! \brief The token table. Now only used for debug purpose. */ - std::vector token_table; - /*! \brief The vocabulary size of the tokenizer. */ + /*! \brief The vocabulary size of the tokenizer. Special tokens are included. */ size_t vocab_size; - /*! \brief All tokens represented by the id and codepoints of each. The tokens are sorted by - * codepoint values to reuse the common prefix during matching. */ - std::vector sorted_token_codepoints; - /*! \brief The mapping from token id to token represented by codepoints. Only contains - * non-special and non-stop tokens. */ - std::unordered_map id_to_token_codepoints; - /*! \brief The stop tokens. They can be accepted iff GramamrMatcher can reach the end of the - * grammar. */ + /*! \brief The token table. Special tokens are included. */ + std::vector token_table; + /*! \brief All (id, token) pairs sorted in lexicographic order. This sorting is done to + * maximize prefix reuse during matching. Special tokens and stop tokens are not included. */ + std::vector> sorted_token_table; + /*! \brief The stop tokens. When the GrammarStateMatcher can reach the end of the= grammar, + * stop tokens can be accepted. */ std::vector stop_token_ids; - /*! \brief The special tokens. Currently we will ignore these tokens during grammar-guided - * matching. */ - std::vector special_token_ids; + /*! \brief The special tokens. These tokens are ignored (masked out) during the grammar-guided + * generation. */ + std::unordered_set special_token_ids; /******************* Information about the grammar *******************/ + /*! \brief The grammar for the GrammarStateMatcher. */ BNFGrammar grammar; /******************* Grammar-specific tokenizer information *******************/ - /*! \brief A sequence id and its position. 
*/ - struct SequenceIdAndPosition { - int32_t sequence_id; - int32_t element_id; - bool operator==(const SequenceIdAndPosition& other) const { - return sequence_id == other.sequence_id && element_id == other.element_id; + struct RulePositionEqual { + std::size_t operator()(const RulePosition& lhs, const RulePosition& rhs) const noexcept { + return lhs.sequence_id == rhs.sequence_id && lhs.element_id == rhs.element_id && + lhs.left_utf8_bytes == rhs.left_utf8_bytes && + lhs.element_in_string == rhs.element_in_string; } }; - /*! \brief Hash function for SequenceIdAndPosition. */ - struct SequenceIdAndPositionHash { - std::size_t operator()(const SequenceIdAndPosition& k) const { - return std::hash()(k.sequence_id) ^ (std::hash()(k.element_id) << 1); + struct RulePositionHash { + std::size_t operator()(const RulePosition& rule_position) const noexcept { + return HashCombine(rule_position.sequence_id, rule_position.element_id, + rule_position.left_utf8_bytes, rule_position.element_in_string); } }; - /*! \brief Mapping from sequence id and its position to the catagorized tokens. */ - std::unordered_map + /*! \brief Mapping from RulePositions to the catagorized tokens. */ + std::unordered_map catagorized_tokens_for_grammar; }; -/* \brief The concrete implementation of GrammarStateMatcherNode. */ +/*! \brief The concrete implementation of GrammarStateMatcherNode. */ class GrammarStateMatcherForInitContext : public GrammarStateMatcherBase { public: + // Do not expand the initial rule position: we want to find the accepted/rejected tokens + // that exactly start from the initial rule position. GrammarStateMatcherForInitContext(const BNFGrammar& grammar, RulePosition init_rule_position) - : GrammarStateMatcherBase(grammar, init_rule_position) {} - - CatagorizedTokens GetCatagorizedTokens(const std::vector& sorted_token_codepoints, - bool is_main_rule); + : GrammarStateMatcherBase(grammar, init_rule_position, false), + init_rule_id(init_rule_position.rule_id) {} + + /*! + * \brief Get the catagorized tokens for the given RulePosition. + * \param consider_parent_rule Whether to consider the parent rule. If false, there will be + * no uncertain tokens. Useful for the main rule. + */ + CatagorizedTokens GetCatagorizedTokens( + int vocab_size, const std::vector>& sorted_token_table, + bool consider_parent_rule); private: using RuleExpr = BNFGrammarNode::RuleExpr; using RuleExprType = BNFGrammarNode::RuleExprType; + /*! \brief Check if a token can pass the lookahead assertion. */ + bool IsTokenPassLookaheadAssertion(const std::string& token, + const std::vector& can_reach_end_stack); + + // The id of the initial rule. + int32_t init_rule_id; + // Temporary data for GetCatagorizedTokens. 
std::vector tmp_accepted_indices_; std::vector tmp_rejected_indices_; std::vector tmp_uncertain_indices_; - std::vector tmp_can_see_end_stack_; + std::vector tmp_can_reach_end_stack_; + std::vector tmp_can_reach_end_prefix_or_stack_; }; -inline bool TokenAndId::operator<(const TokenAndId& other) const { - for (size_t i = 0; i < token.size(); ++i) { - if (i >= other.token.size()) { - return false; - } - if (token[i] < other.token[i]) { - return true; - } else if (token[i] > other.token[i]) { - return false; +inline CatagorizedTokens::CatagorizedTokens( + int vocab_size, const std::vector>& sorted_token_table, + const std::vector& accepted_indices, const std::vector& rejected_indices, + const std::vector& uncertain_indices) { + auto size_acc = accepted_indices.size(); + auto size_rej = rejected_indices.size(); + + save_type = size_acc >= USE_BITSET_THRESHOLD && size_rej >= USE_BITSET_THRESHOLD + ? SaveType::kAcceptedBitset + : size_acc < size_rej ? SaveType::kAccepted + : SaveType::kRejected; + + if (save_type == SaveType::kAcceptedBitset) { + accepted_bitset = DynamicBitset(vocab_size); + for (auto idx : accepted_indices) { + accepted_bitset.Set(sorted_token_table[idx].first, true); } + } else if (save_type == SaveType::kAccepted) { + this->accepted_indices = accepted_indices; + } else { + this->rejected_indices = rejected_indices; } - return token.size() < other.token.size(); + + this->uncertain_indices = uncertain_indices; } -inline CatagorizedTokens::CatagorizedTokens(std::vector&& accepted_indices, - std::vector&& rejected_indices, - std::vector&& uncertain_indices) { - auto size_acc = accepted_indices.size(); - auto size_rej = rejected_indices.size(); - auto size_unc = uncertain_indices.size(); - not_saved_index = - (size_acc >= size_rej && size_acc >= size_unc) - ? NotSavedIndex::kAccepted - : (size_rej >= size_unc ? NotSavedIndex::kRejected : NotSavedIndex::kUncertain); - - if (not_saved_index != NotSavedIndex::kAccepted) { - this->accepted_indices = std::move(accepted_indices); +bool GrammarStateMatcherForInitContext::IsTokenPassLookaheadAssertion( + const std::string& token, const std::vector& can_reach_end_stack) { + auto lookahead_assertion_id = grammar_->GetRule(init_rule_id).lookahead_assertion_id; + if (lookahead_assertion_id == -1) { + return true; } - if (not_saved_index != NotSavedIndex::kRejected) { - this->rejected_indices = std::move(rejected_indices); - } - if (not_saved_index != NotSavedIndex::kUncertain) { - this->uncertain_indices = std::move(uncertain_indices); + auto lookahead_rule_position = RulePosition(-1, lookahead_assertion_id, 0); + PushInitialState(lookahead_rule_position, true); + int token_len = token.size(); + + // Find all positions that can come to and end. Then check if the suffix from that position + // can be accepted by the lookahead assertion. + for (int i = static_cast(can_reach_end_stack.size()); i >= 0; --i) { + if (!can_reach_end_stack[i]) { + continue; + } + int last_accept_pos = i - 1; + for (int pos = i; pos < token_len; ++pos) { + if (!AcceptChar(token[pos])) { + break; + } + last_accept_pos = pos; + // Case 1. The whole rule is finished. + if (CanReachEnd()) { + // accepted chars: pos - i + 1 + // we need to rollback the pushed initial state as well + RollbackChars(pos - i + 2); + return true; + } + } + // Case 2. The whole token is accepted + if (last_accept_pos == token_len - 1) { + RollbackChars(last_accept_pos - i + 2); + return true; + } + // Case 3. The token is not accepted. Check the next position. 
+ RollbackChars(last_accept_pos - i + 1); } + + RollbackChars(1); + return false; } inline CatagorizedTokens GrammarStateMatcherForInitContext::GetCatagorizedTokens( - const std::vector& sorted_token_codepoints, bool is_main_rule) { - // Support the current stack contains only one stack with one RulePosition. - // Iterate over all tokens. Split them into three categories: - // - accepted_indices: If a token is accepted by current rule - // - rejected_indices: If a token is rejected by current rule - // - uncertain_indices: If a prefix of a token is accepted by current rule and comes to the end - // of the rule. - - // Note many tokens may contain the same prefix, so we will avoid unnecessary matching - + int vocab_size, const std::vector>& sorted_token_table, + bool consider_parent_rule) { tmp_accepted_indices_.clear(); tmp_rejected_indices_.clear(); tmp_uncertain_indices_.clear(); + // For every character in the current token, stores whether it is possible to reach the end of - // the rule when matching until this character. Useful for rollback. - tmp_can_see_end_stack_.assign({CanReachEnd()}); + // the rule when matching until this character. Store it in a stack for later rollback. + tmp_can_reach_end_stack_.assign({CanReachEnd()}); + tmp_can_reach_end_prefix_or_stack_.assign({tmp_can_reach_end_stack_.back()}); int prev_matched_size = 0; - for (int i = 0; i < static_cast(sorted_token_codepoints.size()); ++i) { - const auto& token = sorted_token_codepoints[i].token; - const auto* prev_token = i > 0 ? &sorted_token_codepoints[i - 1].token : nullptr; - - // Find the longest common prefix with the accepted part of the previous token. - auto prev_useful_size = 0; - if (prev_token) { - prev_useful_size = std::min(prev_matched_size, static_cast(token.size())); - for (int j = 0; j < prev_useful_size; ++j) { - if (token[j] != (*prev_token)[j]) { - prev_useful_size = j; - break; - } - } - RollbackCodepoints(prev_matched_size - prev_useful_size); - tmp_can_see_end_stack_.erase( - tmp_can_see_end_stack_.end() - (prev_matched_size - prev_useful_size), - tmp_can_see_end_stack_.end()); - } + for (int i = 0; i < static_cast(sorted_token_table.size()); ++i) { + const auto& token = sorted_token_table[i].second; - // Find if the current token is accepted or rejected or uncertain. bool accepted = true; - bool can_see_end = tmp_can_see_end_stack_.back(); - prev_matched_size = prev_useful_size; - for (int j = prev_useful_size; j < token.size(); ++j) { - if (!AcceptCodepoint(token[j], false)) { + + // Many tokens may contain the same prefix, so we will avoid unnecessary matching + // by finding the longest common prefix with the previous token. + if (i > 0) { + const auto& prev_token = sorted_token_table[i - 1].second; + int lcp_len = + std::mismatch(token.begin(), token.end(), prev_token.begin(), prev_token.end()).first - + token.begin(); + if (lcp_len > prev_matched_size) { + // Case 1. The common prefix is rejected by the matcher in the last token. Reject directly. accepted = false; - break; + } else if (lcp_len < prev_matched_size) { + // Case 2. The common prefix is shorter than the previous matched size. Rollback + // the non-common part. 
+ RollbackChars(prev_matched_size - lcp_len); + tmp_can_reach_end_stack_.erase( + tmp_can_reach_end_stack_.end() - (prev_matched_size - lcp_len), + tmp_can_reach_end_stack_.end()); + tmp_can_reach_end_prefix_or_stack_.erase( + tmp_can_reach_end_prefix_or_stack_.end() - (prev_matched_size - lcp_len), + tmp_can_reach_end_prefix_or_stack_.end()); } - if (CanReachEnd()) { - can_see_end = true; + prev_matched_size = std::min(prev_matched_size, lcp_len); + } + + if (accepted) { + // Accept the rest chars one by one + for (int j = prev_matched_size; j < token.size(); ++j) { + if (!AcceptChar(token[j], false)) { + accepted = false; + break; + } + tmp_can_reach_end_stack_.push_back(CanReachEnd()); + tmp_can_reach_end_prefix_or_stack_.push_back(tmp_can_reach_end_stack_.back() || + tmp_can_reach_end_prefix_or_stack_.back()); + prev_matched_size = j + 1; } - tmp_can_see_end_stack_.push_back(can_see_end); - prev_matched_size = j + 1; } + + bool can_reach_end = tmp_can_reach_end_prefix_or_stack_.back(); + if (accepted) { tmp_accepted_indices_.push_back(i); - } else if (can_see_end && !is_main_rule) { - // If the current rule is the main rule, there will be no uncertain indices since we will - // never consider its parent rule. Unaccepted tokens are just rejected. + } else if (can_reach_end && consider_parent_rule && + IsTokenPassLookaheadAssertion(token, tmp_can_reach_end_stack_)) { + // 1. If the current rule is the main rule (consider_parent_rule=false), there are no + // uncertain tokens. Not accepted tokens are just rejected. + // 2. If a token cannot pass the lookahead assertion, it is rejected. tmp_uncertain_indices_.push_back(i); } else { tmp_rejected_indices_.push_back(i); } } - RollbackCodepoints(prev_matched_size); - return CatagorizedTokens(std::move(tmp_accepted_indices_), std::move(tmp_rejected_indices_), - std::move(tmp_uncertain_indices_)); -} - -inline std::string ReplaceUnderscoreWithSpace(const std::string& str, - const std::string& kSpecialUnderscore) { - std::string res; - size_t pos = 0; - while (pos < str.size()) { - size_t found = str.find(kSpecialUnderscore, pos); - if (found == std::string::npos) { - res += str.substr(pos); - break; - } - res += str.substr(pos, found - pos) + " "; - pos = found + kSpecialUnderscore.size(); - } - return res; + // Rollback the last matched part + RollbackChars(prev_matched_size); + return CatagorizedTokens(vocab_size, sorted_token_table, tmp_accepted_indices_, + tmp_rejected_indices_, tmp_uncertain_indices_); } inline std::shared_ptr GrammarStateMatcher::CreateInitContext( @@ -248,87 +300,94 @@ inline std::shared_ptr GrammarStateMatcher::CreateInitC auto ptr = std::make_shared(); ptr->grammar = grammar; - ptr->token_table = token_table; ptr->vocab_size = token_table.size(); + ptr->token_table = token_table; if (ptr->vocab_size == 0) { return ptr; } for (int i = 0; i < token_table.size(); ++i) { - auto token = token_table[i]; - if (token == "" || token == "" || token == "") { - ptr->special_token_ids.push_back(i); - } else if (token == "") { + const auto& token = token_table[i]; + // LLaMA2: + // LLaMA3: <|end_of_text|>, <|eot_id|> + // Phi-2: <|endoftext|> + // Gemma: , + if (token == "" || token == "<|end_of_text|>" || token == "<|eot_id|>" || + token == "<|endoftext|>" || token == "" || token == "") { ptr->stop_token_ids.push_back(i); - } else if (token.size() == 1 && - (static_cast(token[0]) >= 128 || token[0] == 0)) { - // Currently we consider all tokens with one character that >= 128 as special tokens, - // and will ignore generating them 
during grammar-guided generation. - ptr->special_token_ids.push_back(i); + } else if ((token[0] == '<' && token[token.size() - 1] == '>' && token.size() >= 3) || + token == "[@BOS@]") { + // gemma treats [@BOS@] as a special token + ptr->special_token_ids.insert(i); } else { - // First replace the special underscore with space. - auto codepoints = ParseUTF8(token.c_str()); - DCHECK(!codepoints.empty() && - codepoints[0] != static_cast(CharHandlingError::kInvalidUtf8)) - << "Invalid token: " << token; - ptr->sorted_token_codepoints.push_back({codepoints, i}); - ptr->id_to_token_codepoints[i] = {codepoints, i}; + ptr->sorted_token_table.push_back({i, token}); } } - std::sort(ptr->sorted_token_codepoints.begin(), ptr->sorted_token_codepoints.end()); + + auto f_compare_token = [](const std::pair& a, + const std::pair& b) { + return a.second < b.second; + }; + std::sort(ptr->sorted_token_table.begin(), ptr->sorted_token_table.end(), f_compare_token); // Find the corresponding catagorized tokens for: - // 1. All character elements in the grammar - // 2. All RuleRef elements that refers to a rule containing a CharacterClassStar RuleExpr. - for (int i = 0; i < static_cast(grammar->NumRules()); ++i) { - auto rule = grammar->GetRule(i); - auto rule_expr = grammar->GetRuleExpr(rule.body_expr_id); - // Skip CharacterClassStar since we just handle it at the reference element during matching. - if (rule_expr.type == RuleExprType::kCharacterClassStar) { - continue; - } - DCHECK(rule_expr.type == RuleExprType::kChoices); - for (auto sequence_id : rule_expr) { - auto sequence_expr = grammar->GetRuleExpr(sequence_id); - if (sequence_expr.type == RuleExprType::kEmptyStr) { + // 1. All character class or character class star (with last_utf8_bytes=0, 1, 2, 3) + // 2. All byte strings (with element_in_string=0, 1, 2, ...) + auto main_rule_id = grammar->GetMainRuleId(); + for (int rule_id = 0; rule_id < static_cast(grammar->NumRules()); ++rule_id) { + auto rule = grammar->GetRule(rule_id); + auto rule_body = grammar->GetRuleExpr(rule.body_expr_id); + DCHECK(rule_body.type == RuleExprType::kChoices); + for (auto sequence_id : rule_body) { + auto sequence = grammar->GetRuleExpr(sequence_id); + if (sequence.type == RuleExprType::kEmptyStr) { continue; } - DCHECK(sequence_expr.type == RuleExprType::kSequence); - for (int element_id = 0; element_id < sequence_expr.size(); ++element_id) { - auto element_expr = grammar->GetRuleExpr(sequence_expr[element_id]); - auto cur_rule_position = RulePosition{i, sequence_id, element_id}; - if (element_expr.type == RuleExprType::kRuleRef) { - auto ref_rule = grammar->GetRule(element_expr[0]); - auto ref_rule_expr = grammar->GetRuleExpr(ref_rule.body_expr_id); - if (ref_rule_expr.type == RuleExprType::kChoices) { - continue; - } else { - // Reference to a CharacterClassStar of a character class. 
- cur_rule_position.char_class_star_id = ref_rule_expr[0]; - } + DCHECK(sequence.type == RuleExprType::kSequence); + for (int element_id = 0; element_id < sequence.size(); ++element_id) { + auto element = grammar->GetRuleExpr(sequence[element_id]); + if (element.type == RuleExprType::kRuleRef) { + continue; } - auto grammar_state_matcher = GrammarStateMatcherForInitContext(grammar, cur_rule_position); - auto cur_catagorized_tokens_for_grammar = - grammar_state_matcher.GetCatagorizedTokens(ptr->sorted_token_codepoints, i == 0); - ptr->catagorized_tokens_for_grammar[{sequence_id, element_id}] = - cur_catagorized_tokens_for_grammar; + auto add_catagorized_tokens = [&](const RulePosition& rule_position) { + auto grammar_state_matcher = GrammarStateMatcherForInitContext(grammar, rule_position); + auto cur_catagorized_tokens_for_grammar = grammar_state_matcher.GetCatagorizedTokens( + ptr->vocab_size, ptr->sorted_token_table, rule_id != main_rule_id); + ptr->catagorized_tokens_for_grammar[rule_position] = cur_catagorized_tokens_for_grammar; + }; + + auto cur_rule_position = RulePosition(rule_id, sequence_id, element_id); + if (element.type == RuleExprType::kByteString) { + for (int idx = 0; idx < element.size(); ++idx) { + cur_rule_position.element_in_string = idx; + add_catagorized_tokens(cur_rule_position); + } + } else { + DCHECK(element.type == RuleExprType::kCharacterClassStar || + element.type == RuleExprType::kCharacterClass); + for (int left_utf8_bytes = 0; left_utf8_bytes <= 3; ++left_utf8_bytes) { + cur_rule_position.left_utf8_bytes = left_utf8_bytes; + add_catagorized_tokens(cur_rule_position); + } + } } } } return ptr; } -class GrammarInitContextStorageImpl : public GrammarInitContextStorageNode { +class GrammarInitContextCacheImpl : public GrammarInitContextCacheNode { public: - GrammarInitContextStorageImpl(const std::vector& token_table); + GrammarInitContextCacheImpl(const std::vector& token_table); - std::shared_ptr GetInitContextForJSONSchema(const std::string& schema); + std::shared_ptr GetInitContextForJSONSchema( + const std::string& schema) final; - std::shared_ptr GetInitContextForJSON(); + std::shared_ptr GetInitContextForJSON() final; - void ClearCache(); + void Clear() final; private: /*! \brief The token table associated with this storage class. 
*/ @@ -340,7 +399,7 @@ class GrammarInitContextStorageImpl : public GrammarInitContextStorageNode { std::shared_ptr init_ctx_for_json_; }; -inline GrammarInitContextStorageImpl::GrammarInitContextStorageImpl( +inline GrammarInitContextCacheImpl::GrammarInitContextCacheImpl( const std::vector& token_table) : token_table_(token_table) { init_ctx_for_json_ = @@ -348,7 +407,7 @@ inline GrammarInitContextStorageImpl::GrammarInitContextStorageImpl( } inline std::shared_ptr -GrammarInitContextStorageImpl::GetInitContextForJSONSchema(const std::string& schema) { +GrammarInitContextCacheImpl::GetInitContextForJSONSchema(const std::string& schema) { auto it = init_ctx_for_schema_cache_.find(schema); if (it != init_ctx_for_schema_cache_.end()) { return it->second; @@ -360,14 +419,14 @@ GrammarInitContextStorageImpl::GetInitContextForJSONSchema(const std::string& sc } inline std::shared_ptr -GrammarInitContextStorageImpl::GetInitContextForJSON() { +GrammarInitContextCacheImpl::GetInitContextForJSON() { return init_ctx_for_json_; } -inline void GrammarInitContextStorageImpl::ClearCache() { init_ctx_for_schema_cache_.clear(); } +inline void GrammarInitContextCacheImpl::Clear() { init_ctx_for_schema_cache_.clear(); } -GrammarInitContextStorage::GrammarInitContextStorage(const std::vector& token_table) - : ObjectRef(make_object(token_table)) {} +GrammarInitContextCache::GrammarInitContextCache(const std::vector& token_table) + : ObjectRef(make_object(token_table)) {} } // namespace serve } // namespace llm diff --git a/cpp/serve/grammar/grammar_state_matcher_state.h b/cpp/serve/grammar/grammar_state_matcher_state.h index 47f3e11c7b..1b8a34074f 100644 --- a/cpp/serve/grammar/grammar_state_matcher_state.h +++ b/cpp/serve/grammar/grammar_state_matcher_state.h @@ -20,18 +20,20 @@ using namespace tvm::runtime; /*! \brief Specifies a position in a rule. */ struct RulePosition { - /*! \brief The rule's id. */ + /*! \brief The rule's id. Used for debug purposes. */ int32_t rule_id = -1; /*! \brief Which choice in this rule is selected. */ int32_t sequence_id = -1; - /*! \brief Which element of the choice sequence is being visited. */ + /*! \brief Which element of the choice sequence is to be visited. */ int32_t element_id = -1; - /*! - * \brief If the element refers to another rule, and the body of another rule is a - * CharacterClassStar RuleExpr, this field will be set to the id of the character class. - * This is for the special support of CharacterClassStar. - */ - int32_t char_class_star_id = -1; + + /*! \brief The number of left utf8 bytes in the current element. Used when the element is + * a character class or a character class star. */ + int32_t left_utf8_bytes = 0; + /*! \brief The next position to match in the current byte string. Used when the element is + * a byte string. */ + int32_t element_in_string = 0; + /*! \brief The id of the parent node in the RulePositionTree. */ int32_t parent_id = -1; /*! \brief The reference count of this RulePosition. 
If reduces to zero, the node will be @@ -43,24 +45,21 @@ struct RulePosition { constexpr RulePosition() = default; constexpr RulePosition(int32_t rule_id, int32_t sequence_id, int32_t element_id, - int32_t parent_id = kNoParent, int32_t char_class_star_id = -1) - : rule_id(rule_id), - sequence_id(sequence_id), - element_id(element_id), - char_class_star_id(char_class_star_id), - parent_id(parent_id) {} + int32_t parent_id = kNoParent) + : rule_id(rule_id), sequence_id(sequence_id), element_id(element_id), parent_id(parent_id) {} + + // The position is invalid when sequence_id is -1. + bool IsInvalid() const { return sequence_id == -1; } bool operator==(const RulePosition& other) const { return rule_id == other.rule_id && sequence_id == other.sequence_id && - element_id == other.element_id && char_class_star_id == other.char_class_star_id && - parent_id == other.parent_id; + element_id == other.element_id && parent_id == other.parent_id && + left_utf8_bytes == other.left_utf8_bytes && element_in_string == other.element_in_string; } - - bool operator!=(const RulePosition& other) const { return !(*this == other); } }; /*! \brief A special value for invalid RulePosition. */ -inline constexpr RulePosition kInvalidRulePosition(-1, -1, -1, -1, -1); +inline constexpr RulePosition kInvalidRulePosition(-1, -1, -1, -1); /*! \brief A buffer to manage all RulePositions. */ class RulePositionBuffer { @@ -76,7 +75,7 @@ class RulePositionBuffer { id = buffer_.size() - 1; } else { id = free_nodes_.back(); - DCHECK(buffer_[id] == kInvalidRulePosition); + DCHECK(buffer_[id].IsInvalid()); free_nodes_.pop_back(); } rule_position.reference_count = 0; @@ -86,7 +85,7 @@ class RulePositionBuffer { /*! \brief Free the RulePosition with the given id. */ void Free(int32_t id) { - DCHECK(buffer_[id] != kInvalidRulePosition); + DCHECK(!buffer_[id].IsInvalid()); buffer_[id] = kInvalidRulePosition; free_nodes_.push_back(id); } @@ -102,11 +101,13 @@ class RulePositionBuffer { /*! \brief Get the RulePosition with the given id. */ RulePosition& operator[](int32_t id) { - DCHECK(id < static_cast(buffer_.size()) && buffer_[id] != kInvalidRulePosition); + DCHECK(id >= 0 && id < static_cast(buffer_.size())); + DCHECK(!buffer_[id].IsInvalid()); return buffer_[id]; } const RulePosition& operator[](int32_t id) const { - DCHECK(id < static_cast(buffer_.size()) && buffer_[id] != kInvalidRulePosition); + DCHECK(id >= 0 && id < static_cast(buffer_.size())); + DCHECK(!buffer_[id].IsInvalid()); return buffer_[id]; } @@ -145,7 +146,7 @@ class RulePositionTree { auto id = node_buffer_.Allocate(rule_position); if (rule_position.parent_id != RulePosition::kNoParent) { DCHECK(rule_position.parent_id < static_cast(node_buffer_.Capacity()) && - node_buffer_[rule_position.parent_id] != kInvalidRulePosition); + !node_buffer_[rule_position.parent_id].IsInvalid()); node_buffer_[rule_position.parent_id].reference_count++; } return id; @@ -183,7 +184,7 @@ class RulePositionTree { /*! \brief Get the RulePosition with the given id. 
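The RulePosition change above replaces the char_class_star_id special case with two progress counters, so a position can describe a partially matched byte string or a partially consumed UTF-8 character. A simplified sketch of the assumed semantics (not the matcher's real advance logic):

#include <string>

// element_in_string: offset of the next byte to match in a kByteString element.
// left_utf8_bytes: continuation bytes still expected inside a character class
// after the lead byte of a multi-byte character was accepted; e.g. U+4E2D
// (0xE4 0xB8 0xAD) steps left_utf8_bytes through 2 -> 1 -> 0.
struct PositionSketch {
  int element_in_string = 0;
  int left_utf8_bytes = 0;
};

// Advance through a byte-string element; returns false on mismatch.
bool AdvanceByteString(PositionSketch* pos, const std::string& element, char next_byte) {
  if (element[pos->element_in_string] != next_byte) {
    return false;  // reject: the byte does not continue this element
  }
  ++pos->element_in_string;
  return true;  // element is finished once element_in_string == element.size()
}

Because both counters take part in operator==, every distinct mid-element state is a distinct key when categorized tokens are precomputed per position.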
*/ const RulePosition& operator[](int32_t id) const { DCHECK(id != RulePosition::kNoParent); - DCHECK(node_buffer_[id] != kInvalidRulePosition); + DCHECK(!node_buffer_[id].IsInvalid()); return node_buffer_[id]; } @@ -331,15 +332,26 @@ inline std::string RulePositionTree::PrintNode(int32_t id) const { inline std::string RulePositionTree::PrintNode(const RulePosition& rule_position) const { std::stringstream ss; - ss << "RulePosition: rule " << rule_position.rule_id << ": " - << grammar_->GetRule(rule_position.rule_id).name; + ss << "RulePosition: rule " << rule_position.rule_id; + if (rule_position.rule_id != -1) { + ss << ": " << grammar_->GetRule(rule_position.rule_id).name; + } ss << ", sequence " << rule_position.sequence_id << ": " << BNFGrammarPrinter(grammar_).PrintRuleExpr(rule_position.sequence_id); ss << ", element id: " << rule_position.element_id; - if (rule_position.char_class_star_id != -1) { - ss << ", char class " << rule_position.char_class_star_id << ": " - << BNFGrammarPrinter(grammar_).PrintRuleExpr(rule_position.char_class_star_id) << "*"; + + auto sequence = grammar_->GetRuleExpr(rule_position.sequence_id); + if (rule_position.element_id < static_cast(sequence.size())) { + auto element = grammar_->GetRuleExpr(sequence[rule_position.element_id]); + if (element.type == BNFGrammarNode::RuleExprType::kByteString) { + ss << ", element in string: " << rule_position.element_in_string; + } else { + DCHECK(element.type == BNFGrammarNode::RuleExprType::kCharacterClass || + element.type == BNFGrammarNode::RuleExprType::kCharacterClassStar); + ss << ", left utf8 bytes: " << rule_position.left_utf8_bytes; + } } + ss << ", parent id: " << rule_position.parent_id << ", ref count: " << rule_position.reference_count; return ss.str(); @@ -370,7 +382,7 @@ inline void RulePositionTree::CheckWellFormed(const std::vector& outsid std::queue visit_queue; for (auto id : outside_pointers) { CHECK(id >= 0 && id < buffer_size); - CHECK(buffer[id] != kInvalidRulePosition); + CHECK(!buffer[id].IsInvalid()); new_reference_counter[id]++; if (visited[id] == false) { visited[id] = true; @@ -383,7 +395,7 @@ inline void RulePositionTree::CheckWellFormed(const std::vector& outsid const auto& rule_position = buffer[cur_id]; if (rule_position.parent_id != RulePosition::kNoParent) { CHECK(rule_position.parent_id >= 0 && rule_position.parent_id < buffer_size); - CHECK(buffer[rule_position.parent_id] != kInvalidRulePosition); + CHECK(!buffer[rule_position.parent_id].IsInvalid()); new_reference_counter[rule_position.parent_id]++; if (visited[rule_position.parent_id] == false) { visited[rule_position.parent_id] = true; @@ -394,11 +406,11 @@ inline void RulePositionTree::CheckWellFormed(const std::vector& outsid for (int i = 0; i < static_cast(buffer.size()); ++i) { if (free_nodes_set.count(i)) { - CHECK(buffer[i] == kInvalidRulePosition); + CHECK(buffer[i].IsInvalid()); CHECK(visited[i] == false); } else { CHECK(visited[i] == true); - CHECK(buffer[i] != kInvalidRulePosition); + CHECK(!buffer[i].IsInvalid()); CHECK(new_reference_counter[i] == buffer[i].reference_count) << "Reference counters unmatch for node #" << i << ": Updated " << new_reference_counter[i] << ", Original " << buffer[i].reference_count; diff --git a/cpp/serve/grammar/json_schema_converter.cc b/cpp/serve/grammar/json_schema_converter.cc index 83be710cf5..e0c465ba9e 100644 --- a/cpp/serve/grammar/json_schema_converter.cc +++ b/cpp/serve/grammar/json_schema_converter.cc @@ -385,9 +385,9 @@ void JSONSchemaToEBNFConverter::AddBasicRules() { void 
JSONSchemaToEBNFConverter::AddHelperRules() { rules_.push_back(std::make_pair( kBasicEscape, "[\"\\\\/bfnrt] | \"u\" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]")); - rules_.push_back(std::make_pair(kBasicStringSub, "\"\" | [^\"\\\\\\r\\n] " + kBasicStringSub + - " | \"\\\\\" " + kBasicEscape + " " + - kBasicStringSub)); + rules_.push_back(std::make_pair( + kBasicStringSub, "(\"\\\"\" | [^\"\\\\\\r\\n] " + kBasicStringSub + " | \"\\\\\" " + + kBasicEscape + " " + kBasicStringSub + ") (= [ \\n\\t]* [,}\\]:])")); } void JSONSchemaToEBNFConverter::CreateBasicRule(const picojson::value& schema, @@ -648,7 +648,7 @@ std::string JSONSchemaToEBNFConverter::VisitString(const picojson::object& schem "pattern", "format", }); - return "[\"] " + kBasicStringSub + " [\"]"; + return "[\"] " + kBasicStringSub; } std::string JSONSchemaToEBNFConverter::VisitBoolean(const picojson::object& schema, diff --git a/cpp/serve/grammar/support.h b/cpp/serve/grammar/support.h index fb9002dbac..c8b3f34344 100644 --- a/cpp/serve/grammar/support.h +++ b/cpp/serve/grammar/support.h @@ -8,30 +8,72 @@ #include +#include #include #include +#include namespace mlc { namespace llm { namespace serve { -/*! \brief Manages a segment of externally provided memory and use it as a bitset. */ -class BitsetManager { +/*! \brief A bitset with runtime specified length. It manages memory internally or the memory + * provided externally with enough size. */ +class DynamicBitset { public: - BitsetManager(uint32_t* data, int buffer_size, int element_cnt) - : data_(data), buffer_size_(buffer_size), element_cnt_(element_cnt) { - DCHECK(buffer_size >= CalculateBufferSize(element_cnt)); + static int CalculateBufferSize(int element_size) { return (element_size + 31) / 32; } + + DynamicBitset() : size_(0), buffer_size_(0), data_(nullptr), is_internal_(true) {} + + DynamicBitset(int size, uint32_t* data = nullptr) + : size_(size), buffer_size_(CalculateBufferSize(size)) { + if (data == nullptr) { + internal_buffer_.resize(buffer_size_, 0); + data_ = internal_buffer_.data(); + is_internal_ = true; + } else { + data_ = data; + is_internal_ = false; + } } - static int CalculateBufferSize(int element_cnt) { return (element_cnt + 31) / 32; } + DynamicBitset& operator=(const DynamicBitset& other) { + DCHECK(is_internal_ || size_ >= other.size_) << "Expanding bitset size is not allowed when the " + "memory of the bitset is externally managed"; + size_ = other.size_; + buffer_size_ = other.buffer_size_; + if (is_internal_) { + internal_buffer_.reserve(buffer_size_); + data_ = internal_buffer_.data(); + } + if (data_ != other.data_) { + std::memcpy(data_, other.data_, buffer_size_ * sizeof(uint32_t)); + } + return *this; + } + + DynamicBitset& operator=(DynamicBitset&& other) { + size_ = other.size_; + buffer_size_ = other.buffer_size_; + is_internal_ = other.is_internal_; + if (is_internal_) { + internal_buffer_ = std::move(other.internal_buffer_); + data_ = internal_buffer_.data(); + } else { + data_ = other.data_; + } + return *this; + } bool operator[](int index) const { - DCHECK(index >= 0 && index < element_cnt_); + DCHECK(data_ && index >= 0 && index < size_); return (data_[index / 32] >> (index % 32)) & 1; } + int Size() const { return size_; } + void Set(int index, bool value) { - DCHECK(index >= 0 && index < element_cnt_); + DCHECK(data_ && index >= 0 && index < size_); if (value) { data_[index / 32] |= 1 << (index % 32); } else { @@ -39,14 +81,30 @@ class BitsetManager { } } - void Reset(bool value) { std::memset(data_, value ? 
0xFF : 0, buffer_size_ * sizeof(uint32_t)); } + void Set() { + DCHECK(data_); + std::memset(data_, 0xFF, buffer_size_ * sizeof(uint32_t)); + } + + void Reset() { + DCHECK(data_); + std::memset(data_, 0, buffer_size_ * sizeof(uint32_t)); + } - int GetElementCnt() const { return element_cnt_; } + DynamicBitset& operator|=(const DynamicBitset& other) { + DCHECK(buffer_size_ <= other.buffer_size_); + for (int i = 0; i < buffer_size_; ++i) { + data_[i] |= other.data_[i]; + } + return *this; + } private: - uint32_t* const data_; - const int buffer_size_; - const int element_cnt_; + int size_; + int buffer_size_; + uint32_t* data_; + std::vector internal_buffer_; + bool is_internal_; }; /*! diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index e16432c222..89b25827b8 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -67,11 +67,7 @@ class ModelImpl : public ModelObj { // Step 3. Reset this->Reset(); // Step 4. Set model type - if (json::Lookup(model_config, "model_type").find("rwkv") != std::string::npos) { - this->kind = KVStateKind::kRNNState; - } else { - this->kind = KVStateKind::kKVCache; - } + this->kind = GetMetadata().kv_state_kind; } /*********************** Model Computation ***********************/ @@ -149,6 +145,21 @@ class ModelImpl : public ModelObj { return logits; } + Array GetMultiStepLogits(const ObjectRef& hidden_states) final { + NVTXScopedRange nvtx_scope("GetMultiStepLogits"); + CHECK(ft_.get_logits_func_.defined()) << "`get_logits` function is not found in the model."; + + ObjectRef hidden_states_dref_or_nd{nullptr}; + ObjectRef ret = ft_.get_logits_func_(hidden_states, params_); + Array logits{nullptr}; + if (ft_.use_disco) { + logits = Downcast(ret)->DebugGetFromRemote(0); + } else { + logits = Downcast>(ret); + } + return logits; + } + ObjectRef FuseEmbedHidden(const ObjectRef& embeddings, const ObjectRef& previous_hidden_states, int batch_size, int seq_len) final { NVTXScopedRange nvtx_scope("FuseEmbedHidden"); @@ -563,8 +574,9 @@ class ModelImpl : public ModelObj { /*********************** KV Cache Management ***********************/ void CreateKVCache(int page_size, int max_num_sequence, int max_total_sequence_length, - int prefill_chunk_size, int max_history_size, - KVStateKind kv_state_kind) final { + int prefill_chunk_size, int max_history_size) final { + // KVStateKind kv_state_kind) final { + KVStateKind kv_state_kind = GetMetadata().kv_state_kind; if (kv_state_kind == KVStateKind::kKVCache) { IntTuple max_num_sequence_tuple{max_num_sequence}; IntTuple max_total_sequence_length_tuple{max_total_sequence_length}; @@ -576,30 +588,51 @@ class ModelImpl : public ModelObj { support_sliding_window); local_kv_cache_ = ft_.use_disco ? Downcast(kv_cache_)->DebugGetFromRemote(0) : kv_cache_; - } else { + } else if (kv_state_kind == KVStateKind::kRNNState) { IntTuple max_num_sequence_tuple{max_num_sequence}; IntTuple max_history_size_tuple = {std::max(max_history_size, 1)}; kv_cache_ = ft_.create_kv_cache_func_(max_num_sequence_tuple, max_history_size_tuple); local_kv_cache_ = ft_.use_disco ? 
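A short usage sketch of the DynamicBitset that replaces BitsetManager in cpp/serve/grammar/support.h above (hypothetical call site; semantics follow the class as defined there):

void TokenMaskSketch() {
  DynamicBitset accepted(/*size=*/1000);  // internally managed buffer, zero-initialized
  accepted.Set(42, true);                 // token 42 becomes acceptable
  DynamicBitset rejected(1000);
  rejected.Set();                         // fill: every bit set
  rejected.Set(42, false);
  DynamicBitset merged(1000);
  merged |= accepted;                     // operator|= requires the source buffer
  merged |= rejected;                     // to be at least as large as this one
  bool covered = merged[7] && merged[42];  // true: the union spans both masks
  (void)covered;
}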
Downcast(kv_cache_)->DebugGetFromRemote(0) : kv_cache_; + } else if (kv_state_kind == KVStateKind::kNone) { + // Do nothing + } else { + LOG(FATAL) << "Unknown kv_state_kind: " << static_cast(kv_state_kind); } } - void AddNewSequence(int64_t seq_id) final { ft_.kv_cache_add_sequence_func_(kv_cache_, seq_id); } + void AddNewSequence(int64_t seq_id) final { + if (ft_.model_metadata_.kv_state_kind == KVStateKind::kNone) { + return; + } + ft_.kv_cache_add_sequence_func_(kv_cache_, seq_id); + } void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id, int64_t fork_pos) final { + if (ft_.model_metadata_.kv_state_kind == KVStateKind::kNone) { + return; + } ft_.kv_cache_fork_sequence_func_(kv_cache_, parent_seq_id, child_seq_id, fork_pos); } void RemoveSequence(int64_t seq_id) final { + if (this->kind == KVStateKind::kNone) { + return; + } ft_.kv_cache_remove_sequence_func_(kv_cache_, seq_id); } void PopNFromKVCache(int64_t seq_id, int num_tokens) final { + if (this->kind == KVStateKind::kNone) { + return; + } ft_.kv_cache_popn_func_(kv_cache_, seq_id, num_tokens); } void EnableSlidingWindowForSeq(int64_t seq_id) final { + if (this->kind == KVStateKind::kNone) { + return; + } if (sliding_window_size_ != -1) { ft_.kv_cache_enable_sliding_window_for_seq_(kv_cache_, seq_id, sliding_window_size_, attention_sink_size_); @@ -620,7 +653,7 @@ class ModelImpl : public ModelObj { } int GetCurrentTotalSequenceLength() const final { - if (this->kind == KVStateKind::kRNNState) { + if (this->kind == KVStateKind::kRNNState || this->kind == KVStateKind::kNone) { // RNNState does not have a total sequence length limit return 0; } else { @@ -670,6 +703,9 @@ class ModelImpl : public ModelObj { } ObjectRef AllocEmbeddingTensor() final { + if (!ft_.alloc_embedding_tensor_func_.defined()) { + return ObjectRef{nullptr}; + } // Allocate the embedding tensor. ObjectRef embedding = ft_.alloc_embedding_tensor_func_(); // Get the shape of the embedding tensor for hidden size. @@ -690,6 +726,9 @@ class ModelImpl : public ModelObj { } ObjectRef AllocHiddenStatesTensor() final { + if (!ft_.alloc_embedding_tensor_func_.defined()) { + return ObjectRef{nullptr}; + } // Allocate the hidden_states tensor. // Use the same function as embeddings. ObjectRef hidden_states = ft_.alloc_embedding_tensor_func_(); @@ -778,6 +817,17 @@ class ModelImpl : public ModelObj { ft_.scatter_probs_func_(input, indices_device, *dst); } + Array GetMedusaLogits(const ObjectRef& hidden_states) { + ObjectRef result = ft_.get_logits_func_(hidden_states); + Array logits{nullptr}; + if (ft_.use_disco) { + logits = Downcast(result)->DebugGetFromRemote(0); + } else { + logits = Downcast>(result); + } + return logits; + } + /************** Debug/Profile **************/ void DebugCallFuncOnAllAllWorker(const String& func_name) final { diff --git a/cpp/serve/model.h b/cpp/serve/model.h index 96d2ecb401..41fccf8d0b 100644 --- a/cpp/serve/model.h +++ b/cpp/serve/model.h @@ -139,6 +139,8 @@ class ModelObj : public Object { */ virtual NDArray GetLogits(const ObjectRef& last_hidden_states) = 0; + virtual Array GetMultiStepLogits(const ObjectRef& last_hidden_states) = 0; + /*! * \brief Batch prefill function. Embedding in, logits out. * The embedding order of sequences in `embedding_arr` follows @@ -224,11 +226,9 @@ class ModelObj : public Object { * are allowed to exist in the KV cache at any time. * \param max_history_size The maximum history size for RNN state to roll back. * The KV cache does not need this. - * \param kv_state_kind The kind of cache. 
It can be KV cache or RNN state. */ virtual void CreateKVCache(int page_size, int max_num_sequence, int max_total_sequence_length, - int prefill_chunk_size, int max_history_size, - KVStateKind kv_state_kind) = 0; + int prefill_chunk_size, int max_history_size) = 0; /*! \brief Add a new sequence with the given sequence id to the KV cache. */ virtual void AddNewSequence(int64_t seq_id) = 0; diff --git a/cpp/support/encoding.cc b/cpp/support/encoding.cc index d9420bbbd5..9f33f98a7e 100644 --- a/cpp/support/encoding.cc +++ b/cpp/support/encoding.cc @@ -36,14 +36,15 @@ std::string PrintAsUTF8(TCodepoint codepoint) { return utf8; } -std::string PrintAsEscaped(TCodepoint codepoint, - const std::unordered_map& custom_escape_map) { +std::string PrintAsEscaped( + TCodepoint codepoint, + const std::unordered_map& additional_escape_map) { static const std::unordered_map kCodepointToEscape = { {'\'', "\\\'"}, {'\"', "\\\""}, {'\?', "\\\?"}, {'\\', "\\\\"}, {'\a', "\\a"}, {'\b', "\\b"}, {'\f', "\\f"}, {'\n', "\\n"}, {'\r', "\\r"}, {'\t', "\\t"}, {'\v', "\\v"}, {'\0', "\\0"}, {'\x1B', "\\e"}}; - if (auto it = custom_escape_map.find(codepoint); it != custom_escape_map.end()) { + if (auto it = additional_escape_map.find(codepoint); it != additional_escape_map.end()) { return it->second; } @@ -56,14 +57,24 @@ std::string PrintAsEscaped(TCodepoint codepoint, } // convert codepoint to hex - int width = codepoint <= 0xFFFF ? 4 : 8; + char prefix = codepoint <= 0xFF ? 'x' : codepoint <= 0xFFFF ? 'u' : 'U'; + int width = codepoint <= 0xFF ? 2 : codepoint <= 0xFFFF ? 4 : 8; std::stringstream ss; ss << std::setfill('0') << std::setw(width) << std::hex << codepoint; auto hex = ss.str(); - return codepoint <= 0xFFFF ? "\\u" + hex : "\\U" + hex; + return std::string("\\") + prefix + hex; } -std::pair ParseNextUTF8(const char* utf8) { +std::string PrintAsEscaped(std::string raw_str) { + std::string res; + auto codepoints = ParseUTF8(raw_str.c_str(), UTF8ErrorPolicy::kReturnByte); + for (auto c : codepoints) { + res += PrintAsEscaped(c); + } + return res; +} + +std::tuple HandleUTF8FirstByte(uint8_t byte) { static const std::array kFirstByteMask = {0x00, 0x7F, 0x1F, 0x0F, 0x07}; // clang-format off static const std::array kUtf8Bytes = { @@ -85,30 +96,44 @@ std::pair ParseNextUTF8(const char* utf8) { 4, 4, 4, 4, 4, 4, 4, 4, -1, -1, -1, -1, -1, -1, -1, -1, }; // clang-format on + auto num_bytes = kUtf8Bytes[static_cast(byte)]; + if (num_bytes == -1) { + return {false, 0, 0}; + } + return {true, num_bytes, byte & kFirstByteMask[num_bytes]}; +} - auto bytes = kUtf8Bytes[static_cast(utf8[0])]; - if (bytes == -1) { - // invalid utf8 - return {static_cast(CharHandlingError::kInvalidUtf8), utf8}; +std::pair ParseNextUTF8(const char* utf8, UTF8ErrorPolicy error_policy) { + auto [accepted, num_bytes, res] = HandleUTF8FirstByte(utf8[0]); + if (accepted) { + for (int i = 1; i < num_bytes; ++i) { + if (utf8[i] == 0 || (static_cast(utf8[i]) & 0xC0) != 0x80) { + // invalid utf8 + accepted = false; + break; + } + res = (res << 6) | (static_cast(utf8[i]) & 0x3F); + } } - TCodepoint res = static_cast(utf8[0]) & kFirstByteMask[bytes]; - for (int i = 1; i < bytes; ++i) { - if (utf8[i] == 0 || (static_cast(utf8[i]) & 0xC0) != 0x80) { - // invalid utf8 - return {static_cast(CharHandlingError::kInvalidUtf8), 0}; + if (!accepted) { + // invalid utf8 + if (error_policy == UTF8ErrorPolicy::kReturnInvalid) { + return {CharHandlingError::kInvalidUTF8, utf8}; + } else { + return {static_cast(utf8[0]), utf8 + 1}; } - res = (res << 6) | 
(static_cast(utf8[i]) & 0x3F); } - return {res, utf8 + bytes}; + + return {res, utf8 + num_bytes}; } -std::vector ParseUTF8(const char* utf8) { +std::vector ParseUTF8(const char* utf8, UTF8ErrorPolicy error_policy) { std::vector codepoints; while (*utf8 != 0) { TCodepoint codepoint; - std::tie(codepoint, utf8) = ParseNextUTF8(utf8); - if (codepoint == static_cast(CharHandlingError::kInvalidUtf8)) { + std::tie(codepoint, utf8) = ParseNextUTF8(utf8, error_policy); + if (codepoint == CharHandlingError::kInvalidUTF8) { return {codepoint}; } codepoints.push_back(codepoint); @@ -129,17 +154,17 @@ inline int HexCharToInt(char c) { } std::pair ParseNextUTF8OrEscaped( - const char* utf8, const std::unordered_map& custom_escape_map) { + const char* utf8, const std::unordered_map& additional_escape_map) { static const std::unordered_map kEscapeToCodepoint = { {"\\\'", '\''}, {"\\\"", '\"'}, {"\\\?", '\?'}, {"\\\\", '\\'}, {"\\a", '\a'}, {"\\b", '\b'}, {"\\f", '\f'}, {"\\n", '\n'}, {"\\r", '\r'}, {"\\t", '\t'}, {"\\v", '\v'}, {"\\0", '\0'}, {"\\e", '\x1B'}}; if (utf8[0] != '\\') { - return ParseNextUTF8(utf8); + return ParseNextUTF8(utf8, UTF8ErrorPolicy::kReturnInvalid); } auto escape_sequence = std::string(utf8, 2); - if (auto it = custom_escape_map.find(escape_sequence); it != custom_escape_map.end()) { + if (auto it = additional_escape_map.find(escape_sequence); it != additional_escape_map.end()) { return {it->second, utf8 + 2}; } if (auto it = kEscapeToCodepoint.find(escape_sequence); it != kEscapeToCodepoint.end()) { @@ -159,7 +184,7 @@ std::pair ParseNextUTF8OrEscaped( ++len; } if (len == 0) { - return {static_cast(CharHandlingError::kInvalidEscape), utf8}; + return {CharHandlingError::kInvalidEscape, utf8}; } return {codepoint, utf8 + len + 2}; } else if (utf8[1] == 'u' || utf8[1] == 'U') { @@ -170,13 +195,13 @@ std::pair ParseNextUTF8OrEscaped( for (int i = 0; i < len; ++i) { auto digit = HexCharToInt(utf8[i + 2]); if (digit == -1) { - return {static_cast(CharHandlingError::kInvalidEscape), utf8}; + return {CharHandlingError::kInvalidEscape, utf8}; } codepoint = codepoint * 16 + digit; } return {codepoint, utf8 + len + 2}; } else { - return {static_cast(CharHandlingError::kInvalidEscape), utf8}; + return {CharHandlingError::kInvalidEscape, utf8}; } } diff --git a/cpp/support/encoding.h b/cpp/support/encoding.h index 790040e97e..0b18c43b0d 100644 --- a/cpp/support/encoding.h +++ b/cpp/support/encoding.h @@ -17,59 +17,89 @@ namespace llm { using TCodepoint = int32_t; /*! - * \brief Convert a codepoint to a UTF-8 string. + * \brief Handle the utf-8 first byte. + * \returns (is_valid, total_number_of_bytes, initial_codepoint). + */ +std::tuple HandleUTF8FirstByte(uint8_t byte); + +/*! + * \brief Print a codepoint to a UTF-8 string. * \param codepoint The codepoint. * \return The UTF-8 string. */ std::string PrintAsUTF8(TCodepoint codepoint); /*! - * \brief Convert a codepoint to a printable string. If the codepoint is not printable, it will be + * \brief Print a codepoint to a escaped string. If the codepoint is not printable, it will be * escaped. By default the function support escape sequences in C ("\n", "\t", "\u0123"). User can - * specify more escape sequences using custom_escape_map. + * specify more escape sequences using additional_escape_map. * \param codepoint The codepoint. - * \param custom_escape_map A map from codepoint to escape sequence. If the codepoint is in the map, - * it will be escaped using the corresponding escape sequence. e.g. {{'-', "\\-"}}. 
- * \return The printable string. + * \param additional_escape_map A map from codepoint to escape sequence. If the codepoint is in the + * map, it will be escaped using the corresponding escape sequence. e.g. {{'-', "\\-"}}. + * \return The printable string. */ std::string PrintAsEscaped( TCodepoint codepoint, - const std::unordered_map& custom_escape_map = {}); + const std::unordered_map& additional_escape_map = {}); + +/*! + * \brief Print the given string as an escaped string. + * \return The escaped string. + */ +std::string PrintAsEscaped(std::string raw_str); /*! * \brief Represents an error when handling characters. Will be returned as a special TCodepoint * value. */ -enum class CharHandlingError : TCodepoint { +enum CharHandlingError : TCodepoint { /*! \brief The UTF-8 string is invalid. */ - kInvalidUtf8 = -10, + kInvalidUTF8 = -10, /*! \brief The escape sequence is invalid. */ kInvalidEscape = -11, }; /*! - * \brief Convert a UTF-8 string to a codepoint. + * \brief The policy for handling invalid UTF-8 sequences. + */ +enum class UTF8ErrorPolicy { + /*! \brief Return an error codepoint when an error is encountered. */ + kReturnInvalid, + /*! \brief Fall back to the raw byte and continue parsing. */ + kReturnByte, +}; + +/*! + * \brief Parse the first codepoint in a UTF-8 string. * \param utf8 The UTF-8 string. - * \return The codepoint and the number of bytes consumed. If the UTF-8 string is invalid, the - * function returns (CharHandlingError::kInvalidUtf8, 0). + * \return The codepoint and the new pointer. If the UTF-8 string is invalid, and the error policy is + * kReturnInvalid, the function returns (CharHandlingError::kInvalidUTF8, input char pointer). */ -std::pair ParseNextUTF8(const char* utf8); +std::pair ParseNextUTF8( + const char* utf8, UTF8ErrorPolicy error_policy = UTF8ErrorPolicy::kReturnInvalid); -std::vector ParseUTF8(const char* utf8); +/*! + * \brief Parse all codepoints in a UTF-8 string. + * \param utf8 The UTF-8 string. + * \return All codepoints. If the UTF-8 string is invalid, and the error policy is + * kReturnInvalid, the function returns {CharHandlingError::kInvalidUTF8}. + */ +std::vector ParseUTF8(const char* utf8, + UTF8ErrorPolicy error_policy = UTF8ErrorPolicy::kReturnInvalid); /*! - * \brief Convert a UTF-8 string or an escape sequence to a codepoint. By default the function - * supports escape sequences in C ("\n", "\t", "\u0123"). User can specify more escape sequences - * using custom_escape_map. + * \brief Parse the first codepoint from a UTF-8 string. Also checks escape sequences and converts + * the escaped char to its original value. * \param utf8 The UTF-8 string or the escape sequence. - * \param custom_escape_map A map from escape sequence to codepoint. If the escape sequence is in - * the map, it will be converted to the corresponding codepoint. e.g. {{"\\-", '-'}}. - * \return The codepoint and the number of bytes consumed. If the UTF-8 string or the escape - * sequence is invalid, the function returns - * (CharHandlingError::kInvalidUtf8 or CharHandlingError::kInvalidEscape, 0). + * \param additional_escape_map A map from escape sequence to codepoint. If the escape sequence is + * in the map, it will be converted to the corresponding codepoint. e.g. {{"\\-", '-'}}. + * \return The codepoint and the new pointer. If the UTF-8 string or the escape sequence is + * invalid, and the error policy is kReturnInvalid, the function returns + * (CharHandlingError::kInvalidUTF8, input char pointer).
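A self-contained sketch of the parsing contract documented above, with a simplified lead-byte decoder standing in for the real 256-entry lookup table (illustrative, not the library's code):

#include <cstdint>
#include <utility>

using TCodepoint = int32_t;
constexpr TCodepoint kSketchInvalidUTF8 = -10;
enum class SketchPolicy { kReturnInvalid, kReturnByte };

std::pair<TCodepoint, const char*> ParseNextSketch(const char* s, SketchPolicy policy) {
  uint8_t b = static_cast<uint8_t>(s[0]);
  // The lead byte decides the length: 0xxxxxxx / 110xxxxx / 1110xxxx / 11110xxx.
  int num_bytes = b < 0x80 ? 1
                  : (b >> 5) == 0x06 ? 2
                  : (b >> 4) == 0x0E ? 3
                  : (b >> 3) == 0x1E ? 4 : -1;
  static const uint8_t kMask[] = {0, 0x7F, 0x1F, 0x0F, 0x07};
  bool ok = num_bytes > 0;
  TCodepoint res = ok ? (b & kMask[num_bytes]) : 0;
  for (int i = 1; ok && i < num_bytes; ++i) {
    uint8_t c = static_cast<uint8_t>(s[i]);
    if (c == 0 || (c & 0xC0) != 0x80) {
      ok = false;  // truncated sequence or bad continuation byte
    } else {
      res = (res << 6) | (c & 0x3F);
    }
  }
  if (!ok) {
    // kReturnByte degrades gracefully: emit the raw byte, advance one position.
    return policy == SketchPolicy::kReturnInvalid
               ? std::make_pair(kSketchInvalidUTF8, s)
               : std::make_pair(static_cast<TCodepoint>(b), s + 1);
  }
  return {res, s + num_bytes};
}
// Example: ParseNextSketch("\xE4\xB8\xAD", ...) yields {0x4E2D, s + 3}.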
*/ std::pair ParseNextUTF8OrEscaped( - const char* utf8, const std::unordered_map& custom_escape_map = {}); + const char* utf8, + const std::unordered_map& additional_escape_map = {}); } // namespace llm } // namespace mlc diff --git a/cpp/support/utils.h b/cpp/support/utils.h index 6c53e35715..2789654a88 100644 --- a/cpp/support/utils.h +++ b/cpp/support/utils.h @@ -37,5 +37,23 @@ inline bool StartsWith(const std::string& str, const char* prefix) { return prefix[n] == '\0'; } +/*! + * \brief Hash and combine value into seed. + * \ref https://www.boost.org/doc/libs/1_84_0/boost/intrusive/detail/hash_combine.hpp + */ +inline void HashCombineBinary(uint32_t& seed, uint32_t value) { + seed ^= value + 0x9e3779b9 + (seed << 6) + (seed >> 2); +} + +/*! + * \brief Find the hash sum of several uint32_t args. + */ +template +uint32_t HashCombine(Args... args) { + uint32_t seed = 0; + (..., HashCombineBinary(seed, args)); + return seed; +} + } // namespace llm } // namespace mlc diff --git a/cpp/tokenizers.cc b/cpp/tokenizers.cc index 6fe9217520..cc1c172697 100644 --- a/cpp/tokenizers.cc +++ b/cpp/tokenizers.cc @@ -152,7 +152,8 @@ inline std::string ByteLevelDecoder(const std::string& token) { }; // clang-format on - auto unicode_codepoints = ParseUTF8(token.c_str()); + auto unicode_codepoints = ParseUTF8(token.c_str(), UTF8ErrorPolicy::kReturnInvalid); + ICHECK(unicode_codepoints.size() != 1 || unicode_codepoints[0] != kInvalidUTF8); std::string decoded; for (auto unicode_codepoint : unicode_codepoints) { diff --git a/ios/MLCEngineExample/MLCEngineExample/MLCEngineExampleApp.swift b/ios/MLCEngineExample/MLCEngineExample/MLCEngineExampleApp.swift index 26361977ce..991149be2b 100644 --- a/ios/MLCEngineExample/MLCEngineExample/MLCEngineExampleApp.swift +++ b/ios/MLCEngineExample/MLCEngineExample/MLCEngineExampleApp.swift @@ -38,29 +38,20 @@ class AppState: ObservableObject { // Step 0: load the engine await engine.reload(modelPath: modelLocalPath, modelLib: modelLib) - // TODO(mlc-team) update request so it is also structure based - // as in open ai api - // sent a request - let jsonRequest = """ - { - "model": "llama3", - "messages": [ - { - "role": "user", - "content": [ - { "type": "text", "text": "What is the meaning of life?" } - ] - } - ] - } - """ // run chat completion as in OpenAI API style - for await res in await engine.chatCompletion(jsonRequest: jsonRequest) { + for await res in await engine.chatCompletion( + messages: [ + ChatCompletionMessage( + role: .user, + content: "What is the meaning of life?" + ) + ] + ) { // publish at main event loop DispatchQueue.main.async { // parse the result content in structured form // and stream back to the display - self.displayText += res.choices[0].delta.content![0]["text"]! 
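One motivation for the HashCombine helper added to cpp/support/utils.h above: the matcher now keys its categorized-token cache directly by RulePosition, which needs a combined hash over the same fields that operator== compares. An illustrative std::hash-style functor under that assumption (not the project's actual specialization):

#include <cstdint>
#include <functional>

struct PosKey {  // stand-in mirroring RulePosition's compared fields
  int32_t rule_id, sequence_id, element_id, parent_id;
  int32_t left_utf8_bytes, element_in_string;
  bool operator==(const PosKey& o) const {
    return rule_id == o.rule_id && sequence_id == o.sequence_id &&
           element_id == o.element_id && parent_id == o.parent_id &&
           left_utf8_bytes == o.left_utf8_bytes && element_in_string == o.element_in_string;
  }
};

namespace std {
template <>
struct hash<PosKey> {
  size_t operator()(const PosKey& p) const {
    // Every field participating in operator== must feed the combined hash;
    // HashCombine is the fold-expression helper from cpp/support/utils.h.
    return mlc::llm::HashCombine(p.rule_id, p.sequence_id, p.element_id, p.parent_id,
                                 p.left_utf8_bytes, p.element_in_string);
  }
};
}  // namespace std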
+ self.displayText += res.choices[0].delta.content!.asText() } } } diff --git a/ios/MLCSwift/Sources/Swift/LLMEngine.swift b/ios/MLCSwift/Sources/Swift/LLMEngine.swift index 91a4d20b81..ce167b7dd3 100644 --- a/ios/MLCSwift/Sources/Swift/LLMEngine.swift +++ b/ios/MLCSwift/Sources/Swift/LLMEngine.swift @@ -61,8 +61,55 @@ public actor MLCEngine { jsonFFIEngine.unload() } - // TODO(mlc-team) turn into a structured interface - public func chatCompletion(jsonRequest: String) -> AsyncStream { + // offer a direct convenient method to pass in messages + public func chatCompletion( + messages: [ChatCompletionMessage], + model: Optional = nil, + frequency_penalty: Optional = nil, + presence_penalty: Optional = nil, + logprobs: Bool = false, + top_logprobs: Int = 0, + logit_bias: Optional<[Int : Float]> = nil, + max_tokens: Optional = nil, + n: Int = 1, + seed: Optional = nil, + stop: Optional<[String]> = nil, + stream: Bool = false, + temperature: Optional = nil, + top_p: Optional = nil, + tools: Optional<[ChatTool]> = nil, + user: Optional = nil, + response_format: Optional = nil + ) -> AsyncStream { + let request = ChatCompletionRequest( + messages: messages, + model: model, + frequency_penalty: frequency_penalty, + presence_penalty: presence_penalty, + logprobs: logprobs, + top_logprobs: top_logprobs, + logit_bias: logit_bias, + max_tokens: max_tokens, + n: n, + seed: seed, + stop: stop, + stream: stream, + temperature: temperature, + top_p: top_p, + tools: tools, + user: user, + response_format: response_format + ) + return self.chatCompletion(request: request) + } + + // completion function + public func chatCompletion( + request: ChatCompletionRequest + ) -> AsyncStream { + let encoder = JSONEncoder() + let data = try! encoder.encode(request) + let jsonRequest = String(data: data, encoding: .utf8)! 
// generate a UUID for the request let requestID = UUID().uuidString let stream = AsyncStream(ChatCompletionStreamResponse.self) { continuation in diff --git a/ios/MLCSwift/Sources/Swift/OpenAIProtocol.swift b/ios/MLCSwift/Sources/Swift/OpenAIProtocol.swift index 1f36933a15..edb0fa5211 100644 --- a/ios/MLCSwift/Sources/Swift/OpenAIProtocol.swift +++ b/ios/MLCSwift/Sources/Swift/OpenAIProtocol.swift @@ -5,14 +5,14 @@ import Foundation // API reference: https://platform.openai.com/docs/api-reference/chat/create public struct TopLogProbs : Codable { - public let token: String - public let logprob: Float - public let bytes: Optional<[Int]> + public var token: String + public var logprob: Float + public var bytes: Optional<[Int]> } public struct LogProbsContent : Codable { - public let token: String - public let logprob: Float + public var token: String + public var logprob: Float public var bytes: Optional<[Int]> = nil public var top_logprobs: [TopLogProbs] = [] } @@ -22,49 +22,225 @@ public struct LogProbs : Codable { } public struct ChatFunction : Codable { - public let name: String + public var name: String public var description: Optional = nil - public let parameters: [String: String] + public var parameters: [String: String] + + public init( + name: String, + description: Optional = nil, + parameters: [String : String] + ) { + self.name = name + self.description = description + self.parameters = parameters + } } public struct ChatTool : Codable { public var type: String = "function" public let function: ChatFunction + + public init(type: String, function: ChatFunction) { + self.type = type + self.function = function + } } public struct ChatFunctionCall : Codable { - public let name: String + public var name: String // NOTE: arguments shold be dict str to any codable // for now only allow string output due to typing issues public var arguments: Optional<[String: String]> = nil + + public init(name: String, arguments: Optional<[String : String]> = nil) { + self.name = name + self.arguments = arguments + } } public struct ChatToolCall : Codable { public var id: String = UUID().uuidString public var type: String = "function" - public let function: ChatFunctionCall + public var function: ChatFunctionCall + + public init( + id: String = UUID().uuidString, + type: String = "function", + function: ChatFunctionCall + ) { + self.id = id + self.type = type + self.function = function + } +} + +public enum ChatCompletionRole: String, Codable { + case system = "system" + case user = "user" + case assistant = "assistant" + case tool = "tool" +} + +public enum ChatCompletionMessageContent: Codable { + case text(String) + case parts([[String: String]]) + + public init(from decoder: Decoder) throws { + let container = try decoder.singleValueContainer() + if let text = try? container.decode(String.self) { + self = .text(text) + } else { + let parts = try container.decode([[String: String]].self) + self = .parts(parts) + } + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.singleValueContainer() + switch self { + case .text(let text): try container.encode(text) + case .parts(let parts): try container.encode(parts) + } + } + + public func asText() -> String { + switch (self) { + case .text(let text): return text + case .parts(let parts): + var res = "" + for item in parts { + if item["type"]! == "text" { + res += item["text"]! 
+ } + } + return res + } + } } -public struct ChatCompletionMessage : Codable { - public let role: String - public var content: Optional<[[String: String]]> = nil +public struct ChatCompletionMessage: Codable { + public var role: ChatCompletionRole + public var content: Optional = nil public var name: Optional = nil public var tool_calls: Optional<[ChatToolCall]> = nil public var tool_call_id: Optional = nil + + // more complicated content construction + public init( + role: ChatCompletionRole, + content: Optional<[[String : String]]> = nil, + name: Optional = nil, + tool_calls: Optional<[ChatToolCall]> = nil, + tool_call_id: Optional = nil + ) { + self.role = role + if let cvalue = content { + self.content = .parts(cvalue) + } else { + self.content = nil + } + self.name = name + self.tool_calls = tool_calls + self.tool_call_id = tool_call_id + } + + // convenient method to construct content from string + public init( + role: ChatCompletionRole, + content: String, + name: Optional = nil, + tool_calls: Optional<[ChatToolCall]> = nil, + tool_call_id: Optional = nil + ) { + self.role = role + self.content = .text(content) + self.name = name + self.tool_calls = tool_calls + self.tool_call_id = tool_call_id + } } public struct ChatCompletionStreamResponseChoice: Codable { public var finish_reason: Optional = nil - public let index: Int - public let delta: ChatCompletionMessage + public var index: Int + public var delta: ChatCompletionMessage public var lobprobs: Optional = nil } public struct ChatCompletionStreamResponse: Codable { - public let id : String + public var id : String public var choices: [ChatCompletionStreamResponseChoice] = [] public var created: Optional = nil public var model: Optional = nil - public let system_fingerprint: String + public var system_fingerprint: String public var object: Optional = nil } + +public struct ResponseFormat: Codable { + public var type: String + public var schema: Optional = nil + + public init(type: String, schema: Optional = nil) { + self.type = type + self.schema = schema + } +} + +public struct ChatCompletionRequest: Codable { + public var messages: [ChatCompletionMessage] + public var model: Optional = nil + public var frequency_penalty: Optional = nil + public var presence_penalty: Optional = nil + public var logprobs: Bool = false + public var top_logprobs: Int = 0 + public var logit_bias: Optional<[Int: Float]> = nil + public var max_tokens: Optional = nil + public var n: Int = 1 + public var seed: Optional = nil + public var stop: Optional<[String]> = nil + public var stream: Bool = false + public var temperature: Optional = nil + public var top_p: Optional = nil + public var tools: Optional<[ChatTool]> = nil + public var user: Optional = nil + public var response_format: Optional = nil + + public init( + messages: [ChatCompletionMessage], + model: Optional = nil, + frequency_penalty: Optional = nil, + presence_penalty: Optional = nil, + logprobs: Bool = false, + top_logprobs: Int = 0, + logit_bias: Optional<[Int : Float]> = nil, + max_tokens: Optional = nil, + n: Int = 1, + seed: Optional = nil, + stop: Optional<[String]> = nil, + stream: Bool = false, + temperature: Optional = nil, + top_p: Optional = nil, + tools: Optional<[ChatTool]> = nil, + user: Optional = nil, + response_format: Optional = nil + ) { + self.messages = messages + self.model = model + self.frequency_penalty = frequency_penalty + self.presence_penalty = presence_penalty + self.logprobs = logprobs + self.top_logprobs = top_logprobs + self.logit_bias = logit_bias + 
self.max_tokens = max_tokens + self.n = n + self.seed = seed + self.stop = stop + self.stream = stream + self.temperature = temperature + self.top_p = top_p + self.tools = tools + self.user = user + self.response_format = response_format + } +} diff --git a/python/mlc_llm/cli/serve.py b/python/mlc_llm/cli/serve.py index d776ed146b..c6314f2c04 100644 --- a/python/mlc_llm/cli/serve.py +++ b/python/mlc_llm/cli/serve.py @@ -50,7 +50,7 @@ def main(argv): parser.add_argument( "--speculative-mode", type=str, - choices=["disable", "small_draft", "eagle"], + choices=["disable", "small_draft", "eagle", "medusa"], default="disable", help=HELP["speculative_mode_serve"] + ' (default: "%(default)s")', ) diff --git a/python/mlc_llm/compiler_pass/cublas_dispatch.py b/python/mlc_llm/compiler_pass/cublas_dispatch.py index f5af94cc4b..b8e461e945 100644 --- a/python/mlc_llm/compiler_pass/cublas_dispatch.py +++ b/python/mlc_llm/compiler_pass/cublas_dispatch.py @@ -20,7 +20,8 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR model_names = [ gv.name_hint for gv, func in mod.functions.items() if isinstance(func, relax.Function) ] - model_names = [name for name in model_names if "batch" not in name] + # exclude single batch decode + model_names = [name for name in model_names if "batch" in name or "decode" not in name] mod = tvm.transform.Sequential( [ relax.transform.FuseOpsByPattern( diff --git a/python/mlc_llm/interface/compile.py b/python/mlc_llm/interface/compile.py index 7aafc64738..a8a170c3ad 100644 --- a/python/mlc_llm/interface/compile.py +++ b/python/mlc_llm/interface/compile.py @@ -85,6 +85,14 @@ def _apply_preproc_to_params( return extra_tirs +def _infer_kv_state_kind(model_type) -> str: + if "rwkv" in model_type: + return "rnn_state" + if "medusa" in model_type: + return "none" + return "kv_cache" + + def _compile(args: CompileArgs, model_config: ConfigBase): def _get_variable_bounds(model_config) -> Dict[str, int]: if hasattr(model_config, "sliding_window_size"): @@ -178,6 +186,7 @@ def _find_kv_cache_bytes(model: nn.Module, model_config) -> int: "prefill_chunk_size": model_config.prefill_chunk_size, # type: ignore "tensor_parallel_shards": model_config.tensor_parallel_shards, # type: ignore "kv_cache_bytes": kv_cache_bytes, + "kv_state_kind": _infer_kv_state_kind(args.model.name), } logger.info("Registering metadata: %s", metadata) metadata["params"] = [_get_param_metadata(name, param) for name, param in named_params] diff --git a/python/mlc_llm/interface/serve.py b/python/mlc_llm/interface/serve.py index d1cde12678..acf6ead514 100644 --- a/python/mlc_llm/interface/serve.py +++ b/python/mlc_llm/interface/serve.py @@ -23,7 +23,7 @@ def serve( prefill_chunk_size: Optional[int], max_history_size: Optional[int], gpu_memory_utilization: Optional[float], - speculative_mode: Literal["disable", "small_draft", "eagle"], + speculative_mode: Literal["disable", "small_draft", "eagle", "medusa"], spec_draft_length: int, enable_tracing: bool, host: str, diff --git a/python/mlc_llm/json_ffi/engine.py b/python/mlc_llm/json_ffi/engine.py index 237319a926..9a95d4b0a4 100644 --- a/python/mlc_llm/json_ffi/engine.py +++ b/python/mlc_llm/json_ffi/engine.py @@ -1,5 +1,6 @@ # pylint: disable=chained-comparison,missing-docstring,too-few-public-methods,too-many-instance-attributes # pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable +import json import queue import threading from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Union @@ 
-20,17 +21,15 @@ class EngineState: sync_queue: queue.Queue - def get_request_stream_callback(self) -> Callable[[List[str]], None]: + def get_request_stream_callback(self) -> Callable[[str], None]: # ChatCompletionStreamResponse - def _callback(chat_completion_stream_responses_json_str: List[str]) -> None: + def _callback(chat_completion_stream_responses_json_str: str) -> None: self._sync_request_stream_callback(chat_completion_stream_responses_json_str) return _callback - def _sync_request_stream_callback( - self, chat_completion_stream_responses_json_str: List[str] - ) -> None: + def _sync_request_stream_callback(self, chat_completion_stream_responses_json_str: str) -> None: # Put the delta outputs to the queue in the unblocking way. self.sync_queue.put_nowait(chat_completion_stream_responses_json_str) @@ -125,7 +124,9 @@ def _background_stream_back_loop(): verbose=False, ) - self._ffi["init_background_engine"](device, self.state.get_request_stream_callback(), None) + self._ffi["init_background_engine"]( + device.device_type, device.device_id, self.state.get_request_stream_callback() + ) self._ffi["reload"](self.engine_config.asjson()) def terminate(self): @@ -210,11 +211,12 @@ def _handle_chat_completion( try: while num_unfinished_requests > 0: - chat_completion_stream_responses_json_str = self.state.sync_queue.get() - for chat_completion_response_json_str in chat_completion_stream_responses_json_str: + chat_completion_responses_json_str = self.state.sync_queue.get() + chat_completion_responses_list = json.loads(chat_completion_responses_json_str) + for chat_completion_response_json_dict in chat_completion_responses_list: chat_completion_response = ( - openai_api_protocol.ChatCompletionStreamResponse.model_validate_json( - chat_completion_response_json_str + openai_api_protocol.ChatCompletionStreamResponse.model_validate( + chat_completion_response_json_dict ) ) for choice in chat_completion_response.choices: diff --git a/python/mlc_llm/model/medusa/__init__.py b/python/mlc_llm/model/medusa/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/mlc_llm/model/medusa/medusa_loader.py b/python/mlc_llm/model/medusa/medusa_loader.py new file mode 100644 index 0000000000..41bef4d98d --- /dev/null +++ b/python/mlc_llm/model/medusa/medusa_loader.py @@ -0,0 +1,51 @@ +""" +This file specifies how MLC's Medusa parameter maps from other formats, for example HuggingFace +PyTorch, HuggingFace safetensors. +""" +import functools + +from mlc_llm.loader import ExternMapping +from mlc_llm.quantization import Quantization + +from .medusa_model import MedusaConfig, MedusaModel + + +def huggingface(model_config: MedusaConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of HuggingFace PyTorch parameters. + + Parameters + ---------- + model_config : MedusaConfig + The configuration of the Medusa model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. 
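The engine.py hunk above tracks a protocol change in the JSON FFI: the stream callback now delivers one JSON string containing an array of responses, which the Python side splits with json.loads, instead of a list of per-response JSON strings. A sketch of the implied producer side (illustrative; the C++ json_ffi serialization code is not part of this diff):

#include <string>
#include <vector>
#include <picojson.h>

std::string PackResponses(const std::vector<picojson::object>& responses) {
  picojson::array arr;
  for (const auto& response : responses) {
    arr.push_back(picojson::value(response));  // one ChatCompletionStreamResponse each
  }
  return picojson::value(arr).serialize();  // single payload per callback invocation
}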
+ """ + model = MedusaModel(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), + allow_extern=True, + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=mlc_param.dtype, + ), + ) + return mapping diff --git a/python/mlc_llm/model/medusa/medusa_model.py b/python/mlc_llm/model/medusa/medusa_model.py new file mode 100644 index 0000000000..af21164421 --- /dev/null +++ b/python/mlc_llm/model/medusa/medusa_model.py @@ -0,0 +1,83 @@ +"""Medusa model definition.""" +import dataclasses +from typing import Any, Dict, Optional + +from tvm.relax.frontend import nn + +from mlc_llm.support import logging +from mlc_llm.support.config import ConfigBase + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class MedusaConfig(ConfigBase): # pylint: disable=too-many-instance-attributes + """Configuration of the Llama model.""" + + medusa_num_heads: int + medusa_num_layers: int + hidden_size: int + vocab_size: int + max_batch_size: int = 1 + tensor_parallel_shards: int = 1 + + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + # Unused parameters. Kept for compatibility with the compilation flow. + prefill_chunk_size: int = -1 + context_window_size: int = -1 + + +# pylint: disable=missing-docstring + + +class ResBlock(nn.Module): + """Residual block with SiLU activation.""" + + def __init__(self, hidden_size): + super().__init__() + self.linear = nn.Linear(hidden_size, hidden_size) + self.act = nn.SiLU() + + def forward(self, x): + return x + self.act(self.linear(x)) + + +class MedusaModel(nn.Module): + """Medusa model definition.""" + + def __init__(self, config: MedusaConfig): + self.hidden_size = config.hidden_size + self.dtype = "float32" + self.medusa_head = nn.ModuleList( + [ + nn.ModuleList( + [ResBlock(config.hidden_size) for _ in range(config.medusa_num_layers)] + + [nn.Linear(config.hidden_size, config.vocab_size, bias=False)] + ) + for _ in range(config.medusa_num_heads) + ] + ) + + def get_default_spec(self): + mod_spec = { + "get_logits": { + "hidden_states": nn.spec.Tensor(["batch_size", self.hidden_size], self.dtype), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + } + return nn.spec.ModuleSpec.from_raw(mod_spec, self) + + def get_logits(self, hidden_states: nn.Tensor): + logits = [] + for head in self.medusa_head: + logits.append(head(hidden_states).astype("float32")) + return logits + + def to(self, dtype: Optional[str] = None): + super().to(dtype=dtype) + if dtype is not None: + self.dtype = dtype diff --git a/python/mlc_llm/model/medusa/medusa_quantization.py b/python/mlc_llm/model/medusa/medusa_quantization.py new file mode 100644 index 0000000000..9fb2b6c255 --- /dev/null +++ b/python/mlc_llm/model/medusa/medusa_quantization.py @@ -0,0 +1,20 @@ +"""This file specifies how MLC's Medusa parameters are quantized.""" +from typing import Tuple + +from tvm.relax.frontend import nn + +from mlc_llm.loader import QuantizeMapping +from mlc_llm.quantization import NoQuantize + +from .medusa_model import MedusaConfig, MedusaModel + + +def no_quant( + model_config: MedusaConfig, + quantization: NoQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Llama2 model without 
quantization.""" + model: nn.Module = MedusaModel(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + return model, quant_map diff --git a/python/mlc_llm/model/model.py b/python/mlc_llm/model/model.py index 08d272f409..042bd7ceaa 100644 --- a/python/mlc_llm/model/model.py +++ b/python/mlc_llm/model/model.py @@ -19,6 +19,7 @@ from .internlm import internlm_loader, internlm_model, internlm_quantization from .llama import llama_loader, llama_model, llama_quantization from .llava import llava_loader, llava_model, llava_quantization +from .medusa import medusa_loader, medusa_model, medusa_quantization from .mistral import mistral_loader, mistral_model, mistral_quantization from .mixtral import mixtral_loader, mixtral_model, mixtral_quantization from .orion import orion_loader, orion_model, orion_quantization @@ -385,4 +386,16 @@ class Model: "ft-quant": bert_quantization.ft_quant, }, ), + "medusa": Model( + name="medusa", + model=medusa_model.MedusaModel, + config=medusa_model.MedusaConfig, + source={ + "huggingface-torch": medusa_loader.huggingface, + "huggingface-safetensor": medusa_loader.huggingface, + }, + quantize={ + "no-quant": medusa_quantization.no_quant, + }, + ), } diff --git a/python/mlc_llm/serve/config.py b/python/mlc_llm/serve/config.py index 916403839a..2dbaaf36a6 100644 --- a/python/mlc_llm/serve/config.py +++ b/python/mlc_llm/serve/config.py @@ -194,11 +194,12 @@ class EngineConfig: # pylint: disable=too-many-instance-attributes kv_state_kind: Optional[Literal["kv_cache", "rnn_state"]] The kind of cache. - speculative_mode : Literal["disable", "small_draft", "eagle"] + speculative_mode : Literal["disable", "small_draft", "eagle", "medusa"] The speculative mode. "disable" means speculative decoding is disabled. "small_draft" means the normal speculative decoding (small draft) mode. "eagle" means the eagle-style speculative decoding. + "medusa" means the medusa-style speculative decoding. spec_draft_length : int The number of tokens to generate in speculative proposal (draft). @@ -220,7 +221,7 @@ class EngineConfig: # pylint: disable=too-many-instance-attributes prefill_chunk_size: Optional[int] = None max_history_size: Optional[int] = None kv_state_kind: Optional[Literal["kv_cache", "rnn_state"]] = None - speculative_mode: Literal["disable", "small_draft", "eagle"] = "disable" + speculative_mode: Literal["disable", "small_draft", "eagle", "medusa"] = "disable" spec_draft_length: int = 4 verbose: bool = True diff --git a/python/mlc_llm/serve/engine.py b/python/mlc_llm/serve/engine.py index c99dbd4794..896930e684 100644 --- a/python/mlc_llm/serve/engine.py +++ b/python/mlc_llm/serve/engine.py @@ -827,11 +827,12 @@ class AsyncMLCEngine(engine_base.MLCEngineBase): significantly smaller than this number. Under mode "server", the actual memory usage may be slightly larger than this number. - speculative_mode : Literal["disable", "small_draft", "eagle"] + speculative_mode : Literal["disable", "small_draft", "eagle", "medusa"] The speculative mode. "disable" means speculative decoding is disabled. "small_draft" means the normal speculative decoding (small draft) mode. "eagle" means the eagle-style speculative decoding. + "medusa" means the medusa-style speculative decoding. spec_draft_length : int The number of tokens to generate in speculative proposal (draft). 
@@ -856,7 +857,7 @@ def __init__( # pylint: disable=too-many-arguments prefill_chunk_size: Optional[int] = None, max_history_size: Optional[int] = None, gpu_memory_utilization: Optional[float] = None, - speculative_mode: Literal["disable", "small_draft", "eagle"] = "disable", + speculative_mode: Literal["disable", "small_draft", "eagle", "medusa"] = "disable", spec_draft_length: int = 4, enable_tracing: bool = False, verbose: bool = True, diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index 641c8f6ed5..12b495dfca 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -425,7 +425,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals prefill_chunk_size: Optional[int], max_history_size: Optional[int], gpu_memory_utilization: Optional[float], - speculative_mode: Literal["disable", "small_draft", "eagle"], + speculative_mode: Literal["disable", "small_draft", "eagle", "medusa"], spec_draft_length: int, enable_tracing: bool, verbose: bool, diff --git a/python/mlc_llm/serve/grammar.py b/python/mlc_llm/serve/grammar.py index cf491884c2..8b5b7d9649 100644 --- a/python/mlc_llm/serve/grammar.py +++ b/python/mlc_llm/serve/grammar.py @@ -1,6 +1,6 @@ """Classes handling the grammar guided generation of MLC LLM serving""" -from typing import List, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union import tvm import tvm._ffi @@ -22,19 +22,20 @@ class BNFGrammar(Object): def from_ebnf_string( ebnf_string: str, main_rule: str = "main", - normalize: bool = True, - simplify: bool = True, ) -> "BNFGrammar": - r"""Parse a BNF grammar from a string in BNF/EBNF format. - - This method accepts the EBNF notation from the W3C XML Specification - (https://www.w3.org/TR/xml/#sec-notation), which is a popular standard, with the following - changes: - - Using # as comment mark instead of /**/ - - Using C-style unicode escape sequence \u01AB, \U000001AB, \xAB instead of #x0123 - - Do not support A-B (match A and not match B) yet - - See tests/python/serve/json.ebnf for an example. + r"""Construct a BNF grammar from an EBNF-formatted string. The grammar will be normalized + (simplified) by default. + + EBNF grammar: see https://www.w3.org/TR/xml/#sec-notation. Note: + 1. Use # as the comment mark + 2. Use C-style unicode escape sequence \u01AB, \U000001AB, \xAB + 3. A-B (match A and not match B) is not supported yet + 4. Lookahead assertion can be added at the end of a rule to speed up matching. E.g. + ``` + main ::= "ab" a [a-z] + a ::= "cd" (=[a-z]) + ``` + The assertion (=[a-z]) means a must be followed by [a-z]. Parameters ---------- @@ -44,28 +45,13 @@ def from_ebnf_string( main_rule : str The name of the main rule. Default: "main". - normalize : bool - Whether to normalize the grammar. Default: true. Only set to false for the purpose of - testing. - - In The normalized form of a BNF grammar, every rule is in the form: - `rule_name ::= ("" | (element1_1 element1_2 ...) | (element2_1 element2_2 ...) | ...)`. - - I.e. a list of choices, each choice is a sequence of elements. Elements can be a - character class or a rule reference. And if the rule can be empty, the first choice - will be an empty string. - - simplify : bool - Whether to simplify the grammar to make matching more efficient. Default: true. Not - implemented yet. - Returns ------- grammar : BNFGrammar The parsed BNF grammar.
""" return _ffi_api.BNFGrammarFromEBNFString( # type: ignore # pylint: disable=no-member - ebnf_string, main_rule, normalize, simplify + ebnf_string, main_rule ) def to_string(self) -> str: @@ -167,6 +153,31 @@ def get_grammar_of_json() -> "BNFGrammar": """ return _ffi_api.BNFGrammarGetGrammarOfJSON() # type: ignore # pylint: disable=no-member + @staticmethod + def debug_from_ebnf_string_no_normalize( + ebnf_string: str, + main_rule: str = "main", + ) -> "BNFGrammar": + r"""Construct a BNF grammar with a EBNF-formatted string, but not normalize it. + For test purposes. + + Parameters + ---------- + ebnf_string : str + The grammar string. + + main_rule : str + The name of the main rule. Default: "main". + + Returns + ------- + grammar : BNFGrammar + The parsed BNF grammar. + """ + return _ffi_api.BNFGrammarDebugFromEBNFStringNoNormalize( # type: ignore # pylint: disable=no-member + ebnf_string, main_rule + ) + @staticmethod def debug_json_schema_to_ebnf( schema: str, @@ -235,6 +246,11 @@ class GrammarStateMatcher(Object): max_rollback_steps : int The maximum number of steps to rollback when backtracking. Default: 0. + + token_table_postproc_method : Literal["byte_fallback", "byte_level"] + A helper parameter for the tokenizer. Only useful when the tokenizer is specified. + The method to postprocess the token table. For LLaMA and LLaMA-2 tokenizer, use + "byte_fallback"; for LLaMA-3 tokenizer, use "byte_level". Default: "byte_fallback". """ def __init__( @@ -242,6 +258,7 @@ def __init__( grammar: BNFGrammar, tokenizer: Union[None, Tokenizer, List[str]] = None, max_rollback_steps: int = 0, + token_table_postproc_method: Literal["byte_fallback", "byte_level"] = "byte_fallback", ): if isinstance(tokenizer, list): self.__init_handle_by_constructor__( @@ -256,6 +273,7 @@ def __init__( grammar, tokenizer, max_rollback_steps, + token_table_postproc_method, ) def accept_token(self, token_id: int) -> bool: @@ -346,7 +364,7 @@ def is_terminated(self) -> bool: """ return _ffi_api.GrammarStateMatcherIsTerminated(self) # type: ignore # pylint: disable=no-member - def debug_accept_char(self, codepoint: int) -> bool: + def debug_accept_char(self, codepoint: int, verbose: bool = False) -> bool: """Accept one unicode codepoint to the current state. For test purposes. Parameters @@ -354,11 +372,11 @@ def debug_accept_char(self, codepoint: int) -> bool: codepoint : int The unicode codepoint of the character to be accepted. """ - return _ffi_api.GrammarStateMatcherDebugAcceptCodepoint( # type: ignore # pylint: disable=no-member - self, codepoint + return _ffi_api.GrammarStateMatcherDebugAcceptChar( # type: ignore # pylint: disable=no-member + self, codepoint, verbose ) - def debug_match_complete_string(self, string: str) -> bool: + def debug_match_complete_string(self, string: str, verbose: bool = False) -> bool: """Check if the matcher can accept the complete string, and then reach the end of the grammar. Does not change the state of the GrammarStateMatcher. For test purposes. @@ -367,4 +385,4 @@ def debug_match_complete_string(self, string: str) -> bool: string : str The string to be matched. 
""" - return _ffi_api.GrammarStateMatcherDebugMatchCompleteString(self, string) # type: ignore # pylint: disable=no-member + return _ffi_api.GrammarStateMatcherDebugMatchCompleteString(self, string, verbose) # type: ignore # pylint: disable=no-member diff --git a/python/mlc_llm/testing/debug_chat.py b/python/mlc_llm/testing/debug_chat.py index 8ff370e9d9..fee8cb8867 100644 --- a/python/mlc_llm/testing/debug_chat.py +++ b/python/mlc_llm/testing/debug_chat.py @@ -351,7 +351,9 @@ def _sample_token_from_logits( if presence_penalty != 0.0 or frequency_penalty != 0.0: self._apply_presence_and_freq_penalty(logits_np, presence_penalty, frequency_penalty) - self._softmax_with_temperature(logits_np, temperature) + logits_np = self._softmax_with_temperature(logits_np, temperature) + np.savez(self.instrument.debug_out / "logits.npz", logits_np) + logits = logits.copyfrom(logits_np) next_token = self.sample_topp_from_prob_func(logits, top_p, random.random()) return next_token diff --git a/tests/python/json_ffi/test_json_ffi_engine.py b/tests/python/json_ffi/test_json_ffi_engine.py index b438c2a352..ca2e7deb98 100644 --- a/tests/python/json_ffi/test_json_ffi_engine.py +++ b/tests/python/json_ffi/test_json_ffi_engine.py @@ -66,9 +66,8 @@ def run_chat_completion( ): for choice in response.choices: assert choice.delta.role == "assistant" - assert isinstance(choice.delta.content[0], Dict) - assert choice.delta.content[0]["type"] == "text" - output_texts[rid][choice.index] += choice.delta.content[0]["text"] + assert isinstance(choice.delta.content, str) + output_texts[rid][choice.index] += choice.delta.content # Print output. print("Chat completion all finished") @@ -83,7 +82,7 @@ def run_chat_completion( def test_chat_completion(): # Create engine. - model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" + model = "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC" engine = JSONFFIEngine( model, max_total_sequence_length=1024, @@ -101,7 +100,7 @@ def test_chat_completion(): def test_reload_reset_unload(): # Create engine. 
-    model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC"
+    model = "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC"
     engine = JSONFFIEngine(
         model,
         max_total_sequence_length=1024,
@@ -136,4 +135,4 @@ def test_function_calling():

 if __name__ == "__main__":
     test_chat_completion()
     test_reload_reset_unload()
-    test_function_calling()
+    # test_function_calling()
diff --git a/tests/python/serve/test_grammar_parser.py b/tests/python/serve/test_grammar_parser.py
index 10eacdf9b9..5e335e15c7 100644
--- a/tests/python/serve/test_grammar_parser.py
+++ b/tests/python/serve/test_grammar_parser.py
@@ -1,4 +1,5 @@
 # pylint: disable=missing-module-docstring,missing-function-docstring
+import json
 import os

 import pytest
@@ -14,11 +15,11 @@ def test_bnf_simple():
     c ::= "c"
     """
     expected = """main ::= ((b c))
-b ::= (([b]))
-c ::= (([c]))
+b ::= (("b"))
+c ::= (("c"))
 """
-    bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False)
+    bnf_grammar = BNFGrammar.from_ebnf_string(before, "main")
     after = bnf_grammar.to_string()
     assert after == expected


@@ -32,11 +35,11 @@ def test_ebnf():
 b ::= ((b_1))
 c ::= ((c_1))
 d ::= ((d_1))
-b_1 ::= ("" | ([a] [b] b_1))
+b_1 ::= ("" | ("ab" b_1))
 c_1 ::= (([acep-z] c_1) | ([acep-z]))
-d_1 ::= ("" | ([d]))
+d_1 ::= ("" | ("d"))
 """
-    bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False)
+    bnf_grammar = BNFGrammar.from_ebnf_string(before, "main")
     after = bnf_grammar.to_string()
     assert after == expected


@@ -49,18 +52,33 @@ def test_star_quantifier():
     e ::= [e]* [f]* | [g]*
     """
     expected = """main ::= ((b c d))
-b ::= [b]*
+b ::= (([b]*))
 c ::= ((c_1))
 d ::= ((d_1))
-e ::= ((e_star e_star_1) | (e_star_2))
-c_1 ::= ("" | ([b] c_1))
+e ::= (([e]* [f]*) | ([g]*))
+c_1 ::= ("" | ("b" c_1))
 d_1 ::= ("" | (d_1_choice d_1))
-e_star ::= [e]*
-e_star_1 ::= [f]*
-e_star_2 ::= [g]*
-d_1_choice ::= (([b] [c] [d]) | ([p] [q]))
+d_1_choice ::= (("bcd") | ("pq"))
+"""
+    bnf_grammar = BNFGrammar.from_ebnf_string(before, "main")
+    after = bnf_grammar.to_string()
+    assert after == expected
+
+
+def test_lookahead_assertion():
+    before = """main ::= ((b c d))
+b ::= (("abc" [a-z])) (=("abc"))
+c ::= (("a") | ("b")) (=([a-z] "b"))
+d ::= (("ac") | ("b" d_choice)) (=("abc"))
+d_choice ::= (("e") | ("d"))
+"""
+    expected = """main ::= ((b c d))
+b ::= (("abc" [a-z])) (=("abc"))
+c ::= (("a") | ("b")) (=([a-z] "b"))
+d ::= (("ac") | ("b" d_choice)) (=("abc"))
+d_choice ::= (("e") | ("d"))
 """
-    bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False)
+    bnf_grammar = BNFGrammar.from_ebnf_string(before, "main")
     after = bnf_grammar.to_string()
     assert after == expected


@@ -68,14 +86,13 @@ def test_char():
     before = r"""main ::= [a-z] [A-z] "\u0234" "\U00000345\xff" [-A-Z] [--] [^a] rest
 rest ::= [a-zA-Z0-9-] [\u0234-\U00000345] [测-试] [\--\]] rest1
-rest1 ::= "\?\"\'测试あc" "👀" ""
+rest1 ::= "\?\"\'测试あc" "👀" "" [a-a] [b-b]
 """
-    expected = r"""main ::= (([a-z] [A-z] ([\u0234]) ([\u0345] [\u00ff]) [\-A-Z] [\-\-] [^a] rest))
+    expected = r"""main ::= (([a-z] [A-z] "\u0234\u0345\u00ff" [\-A-Z] [\-\-] [^a] rest))
 rest ::= (([a-zA-Z0-9\-] [\u0234-\u0345] [\u6d4b-\u8bd5] [\--\]] rest1))
-rest1 ::= ((([\?] [\"] [\'] [\u6d4b] [\u8bd5] [\u3042] [c]) ([\U0001f440]) ""))
+rest1 ::= (("\?\"\'\u6d4b\u8bd5\u3042c\U0001f440ab"))
 """
-    # Disable unwrap_nesting_rules to expose the result before unwrapping.
-    bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", False, False)
+    bnf_grammar = BNFGrammar.from_ebnf_string(before, "main")
     after = bnf_grammar.to_string()
     assert after == expected


@@ -88,9 +106,9 @@ def test_space():
   "f" | "g"
 """
-    expected = """main ::= (([a] [b] [c] [d] [e]) | ([f]) | ([g]))
+    expected = """main ::= (("abcde") | ("f") | ("g"))
 """
-    bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False)
+    bnf_grammar = BNFGrammar.from_ebnf_string(before, "main")
     after = bnf_grammar.to_string()
     assert after == expected


@@ -98,10 +116,10 @@ def test_nest():
     before = """main::= "a" ("b" | "c" "d") | (("e" "f"))
     """
-    expected = """main ::= (([a] main_choice) | ([e] [f]))
-main_choice ::= (([b]) | ([c] [d]))
+    expected = """main ::= (("a" main_choice) | ("ef"))
+main_choice ::= (("b") | ("cd"))
 """
-    bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False)
+    bnf_grammar = BNFGrammar.from_ebnf_string(before, "main")
     after = bnf_grammar.to_string()
     assert after == expected


@@ -115,15 +133,15 @@ def test_flatten():
     empty_test ::= "d" | (("" | "" "") "" | "a" "") | ("" ("" | "")) "" ""
     """
     expected = """main ::= ((or_test sequence_test nested_test empty_test))
-or_test ::= ("" | ([a]) | ([b]) | ([d] [e]) | (or_test) | ([^a-z]))
-sequence_test ::= (([a] [a] [b] sequence_test_choice [d] [e] sequence_test))
-nested_test ::= (([a] [b] [c] [d]) | ([a]) | ([b]) | ([c]) | (nested_rest))
-nested_rest ::= (([a]) | ([b] [c]) | ([d]) | ([e] [f]) | ([g]))
-empty_test ::= ("" | ([d]) | ([a]))
-sequence_test_choice ::= (([c]) | ([d]))
+or_test ::= ("" | ("a") | ("b") | ("de") | (or_test) | ([^a-z]))
+sequence_test ::= (("aab" sequence_test_choice "de" sequence_test))
+nested_test ::= (("abcd") | ("a") | ("b") | ("c") | (nested_rest))
+nested_rest ::= (("a") | ("bc") | ("d") | ("ef") | ("g"))
+empty_test ::= ("" | ("d") | ("a"))
+sequence_test_choice ::= (("c") | ("d"))
 """
-    bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False)
+    bnf_grammar = BNFGrammar.from_ebnf_string(before, "main")
     after = bnf_grammar.to_string()
     assert after == expected


@@ -135,51 +154,52 @@ def test_json():
         before = file.read()

     expected = r"""main ::= ((element))
-value ::= ((object) | (array) | (string) | (number) | ([t] [r] [u] [e]) | ([f] [a] [l] [s] [e]) | ([n] [u] [l] [l]))
-object ::= (([{] ws [}]) | ([{] members [}]))
-members ::= ((member) | (member [,] members))
-member ::= ((ws string ws [:] element))
-array ::= (([[] ws [\]]) | ([[] elements [\]]))
-elements ::= ((element) | (element [,] elements))
+value ::= ((object) | (array) | (string) | (number) | ("true") | ("false") | ("null"))
+object ::= (("{" ws "}") | ("{" members "}"))
+members ::= ((member) | (member "," members))
+member ::= ((ws string ws ":" element))
+array ::= (("[" ws "]") | ("[" elements "]"))
+elements ::= ((element) | (element "," elements))
 element ::= ((ws value ws))
-string ::= (([\"] characters [\"]))
+string ::= (("\"" characters "\""))
 characters ::= ("" | (character characters))
-character ::= (([^\"\\]) | ([\\] escape))
-escape ::= (([\"]) | ([\\]) | ([/]) | ([b]) | ([f]) | ([n]) | ([r]) | ([t]) | ([u] hex hex hex hex))
+character ::= (([^\"\\]) | ("\\" escape))
+escape ::= (("\"") | ("\\") | ("/") | ("b") | ("f") | ("n") | ("r") | ("t") | ("u" hex hex hex hex))
 hex ::= (([A-Fa-f0-9]))
 number ::= ((integer fraction exponent))
-integer ::= ((digit) | (onenine digits) | ([\-] digit) | ([\-] onenine digits))
+integer ::= ((digit) | (onenine digits) | ("-" digit) | ("-" onenine digits))
 digits ::= ((digit) | (digit digits))
 digit ::= (([0-9]))
 onenine ::= (([1-9]))
-fraction ::= ("" | ([.] digits))
+fraction ::= ("" | ("." digits))
 exponent ::= ("" | (exponent_choice exponent_choice_1 digits))
-ws ::= ("" | ([ ] ws) | ([\n] ws) | ([\r] ws) | ([\t] ws))
-exponent_choice ::= (([e]) | ([E]))
-exponent_choice_1 ::= ("" | ([+]) | ([\-]))
+ws ::= ("" | (" " ws) | ("\n" ws) | ("\r" ws) | ("\t" ws))
+exponent_choice ::= (("e") | ("E"))
+exponent_choice_1 ::= ("" | ("+") | ("-"))
 """
-    bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False)
+    bnf_grammar = BNFGrammar.from_ebnf_string(before, "main")
     after = bnf_grammar.to_string()
     assert after == expected


 def test_to_string_roundtrip():
     """Checks the printed result can be parsed, and the parsing-printing process is idempotent."""
-    before = r"""main ::= (b c) | (b main)
-b ::= b_1 d
-c ::= c_1
-d ::= d_1
-b_1 ::= ([b] b_1) | ""
-c_1 ::= (c_2 c_1) | c_2
-c_2 ::= [acep-z]
-d_1 ::= [d] | ""
+    before = r"""main ::= ((b c) | (b main))
+b ::= ((b_1 d))
+c ::= ((c_1))
+d ::= ((d_1))
+b_1 ::= ("" | ("b" b_1))
+c_1 ::= ((c_2 c_1) | (c_2)) (=("abc" [a-z]))
+c_2 ::= (([acep-z]))
+d_1 ::= ("" | ("d"))
 """
-    bnf_grammar_1 = BNFGrammar.from_ebnf_string(before, "main", True, False)
+    bnf_grammar_1 = BNFGrammar.from_ebnf_string(before, "main")
     output_string_1 = bnf_grammar_1.to_string()
-    bnf_grammar_2 = BNFGrammar.from_ebnf_string(output_string_1, "main", True, False)
+    bnf_grammar_2 = BNFGrammar.from_ebnf_string(output_string_1, "main")
     output_string_2 = bnf_grammar_2.to_string()
+    assert before == output_string_1
     assert output_string_1 == output_string_2

@@ -245,34 +266,49 @@ def test_error():
     ):
         BNFGrammar.from_ebnf_string('a ::= "a"')

+    with pytest.raises(
+        TVMError,
+        match="TVMError: EBNF parse error at line 1, column 21: Unexpected lookahead assertion",
+    ):
+        BNFGrammar.from_ebnf_string('main ::= "a" (="a") (="b")')
+

 def test_to_json():
     before = """main ::= b c | b main
 b ::= "bcd"
 c ::= [a-z]
 """
-    expected = (
-        '{"rule_expr_indptr":[0,3,6,10,13,16,20,24,28,32,36,41,44,48,51],"rule_expr_data"'
-        ":[3,1,1,3,1,2,4,2,0,1,3,1,1,3,1,0,4,2,3,4,5,2,2,5,0,2,98,98,0,2,99,99,0,2,100,100,"
-        '4,3,7,8,9,5,1,10,0,2,97,122,4,1,12,5,1,13],"rules":[{"body_expr_id":6,"name":"main"},'
-        '{"body_expr_id":11,"name":"b"},{"body_expr_id":14,"name":"c"}]}'
-    )
-    bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False)
-    after = bnf_grammar.to_json(False)
-    assert after == expected
+    expected_obj = {
+        "rules": [
+            {"body_expr_id": 6, "name": "main"},
+            {"body_expr_id": 9, "name": "b"},
+            {"body_expr_id": 12, "name": "c"},
+        ],
+        "rule_expr_indptr": [0, 3, 6, 10, 13, 16, 20, 24, 29, 32, 35, 40, 43],
+        "rule_expr_data": [
+            # fmt: off
+            4,1,1,4,1,2,5,2,0,1,4,1,1,4,1,0,5,2,3,4,6,2,2,5,0,3,98,99,
+            100,5,1,7,6,1,8,1,3,0,97,122,5,1,10,6,1,11
+            # fmt: on
+        ],
+    }
+    bnf_grammar = BNFGrammar.from_ebnf_string(before, "main")
+    after_str = bnf_grammar.to_json(False)
+    after_obj = json.loads(after_str)
+    assert after_obj == expected_obj


 def test_to_json_roundtrip():
     before = r"""main ::= ((b c) | (b main))
-b ::= ((b_1 d))
+b ::= ((b_1 d [a]*))
 c ::= ((c_1))
 d ::= ((d_1))
-b_1 ::= ("" | ([b] b_1))
+b_1 ::= ("" | ("b" b_1))
 c_1 ::= ((c_2 c_1) | (c_2))
 c_2 ::= (([acep-z]))
-d_1 ::= ("" | ([d]))
+d_1 ::= ("" | ("d"))
 """
-    bnf_grammar_1 = BNFGrammar.from_ebnf_string(before, "main", True, False)
+    bnf_grammar_1 = BNFGrammar.from_ebnf_string(before, "main")
     output_json_1 = bnf_grammar_1.to_json(False)
     bnf_grammar_2 = BNFGrammar.from_json(output_json_1)
     output_json_2 = bnf_grammar_2.to_json(False)
diff --git a/tests/python/serve/test_grammar_state_matcher_custom.py b/tests/python/serve/test_grammar_state_matcher_custom.py
index 6fc48705d1..6ad6294d77 100644
--- a/tests/python/serve/test_grammar_state_matcher_custom.py
+++ b/tests/python/serve/test_grammar_state_matcher_custom.py
@@ -40,6 +40,20 @@ def json_grammar():
     return get_json_grammar()


+def test_simple():
+    grammar_str = """main ::= rule1 rule2
+rule1 ::= (rule2 | rule3) "a"
+rule2 ::= "b"
+rule3 ::= "c"
+"""
+
+    grammar = BNFGrammar.from_ebnf_string(grammar_str)
+    matcher = GrammarStateMatcher(grammar)
+    assert matcher.debug_match_complete_string("bab")
+    assert not matcher.debug_match_complete_string("abb")
+    assert matcher.debug_match_complete_string("cab")
+
+
 (json_input_accepted,) = tvm.testing.parameters(
     ('{"name": "John"}',),
     ('{ "name" : "John" }',),
@@ -241,8 +255,8 @@ def test_json_pressure(json_grammar: BNFGrammar, json_input_pressure):
     '{"id": 1,"name": "Example"}',
     [
         # fmt: off
-        31989, 31912, 299, 299, 299, 31973, 31846, 31846, 31948, 31915, 299, 299, 299, 299,
-        299, 31973, 31846, 31846, 292, 292, 292, 292, 292, 292, 292, 292, 31974, 31999
+        31989, 31912, 272, 272, 272, 31973, 31846, 31846, 31948, 31915, 272, 272, 272, 272,
+        272, 31973, 31846, 31846, 265, 265, 265, 265, 265, 265, 265, 265, 31974, 31999
         # fmt: on
     ],
 ),
@@ -258,15 +272,15 @@ def test_json_pressure(json_grammar: BNFGrammar, json_input_pressure):
 }""",
     [
         # fmt: off
-        31989, 31912, 31912, 299, 299, 299, 31973, 31846, 31846, 31948, 31915, 31915, 299, 299,
-        299, 31973, 31846, 31846, 292, 292, 292, 31974, 31915, 31915, 299, 299, 299, 31973,
-        31846, 31846, 31997, 31997, 31998, 31974, 31915, 31915, 299, 299, 31973, 31846, 31846,
-        31840, 291, 291, 291, 31969, 31846, 31846, 291, 291, 291, 31969, 31974, 31915, 31915,
-        299, 299, 299, 31973, 31846, 31846, 31908, 299, 299, 299, 299, 31973, 31846, 31846,
-        31906, 299, 299, 299, 299, 31973, 31846, 31846, 291, 291, 291, 31968, 31970, 31915,
-        31915, 299, 299, 299, 299, 31973, 31846, 31846, 31840, 31943, 31846, 31846, 31943,
-        31846, 31846, 31943, 31970, 31974, 31915, 31915, 299, 299, 299, 299, 31973, 31846,
-        31846, 292, 292, 292, 292, 31974, 31974, 31999
+        31989, 31912, 31912, 272, 272, 272, 31973, 31846, 31846, 31948, 31915, 31915, 272, 272,
+        272, 31973, 31846, 31846, 265, 265, 265, 31974, 31915, 31915, 272, 272, 272, 31973,
+        31846, 31846, 31997, 31997, 31998, 31974, 31915, 31915, 272, 272, 31973, 31846, 31846,
+        31840, 264, 264, 264, 31969, 31846, 31846, 264, 264, 264, 31969, 31974, 31915, 31915,
+        272, 272, 272, 31973, 31846, 31846, 31908, 272, 272, 272, 272, 31973, 31846, 31846,
+        31906, 272, 272, 272, 272, 31973, 31846, 31846, 264, 264, 264, 31968, 31970, 31915,
+        31915, 272, 272, 272, 272, 31973, 31846, 31846, 31840, 31943, 31846, 31846, 31943,
+        31846, 31846, 31943, 31970, 31974, 31915, 31915, 272, 272, 272, 272, 31973, 31846,
+        31846, 265, 265, 265, 265, 31974, 31974, 31999
         # fmt: on
     ],
 ),
@@ -395,5 +409,6 @@ class MainModel(BaseModel):

 if __name__ == "__main__":
     # Run a benchmark to show the performance before running tests
     test_find_next_rejected_tokens(get_json_grammar(), '{"id": 1,"name": "Example"}')
+    test_find_next_rejected_tokens_schema()

     tvm.testing.main()
diff --git a/tests/python/serve/test_grammar_state_matcher_json.py b/tests/python/serve/test_grammar_state_matcher_json.py
index fc0f79a041..51737e1435 100644
--- a/tests/python/serve/test_grammar_state_matcher_json.py
+++ b/tests/python/serve/test_grammar_state_matcher_json.py
@@ -2,7 +2,7 @@
 # pylint: disable=redefined-outer-name,unbalanced-tuple-unpacking
 """This test uses the optimized JSON grammar provided by the grammar library."""
 import sys
-from typing import List, Optional
+from typing import List, Literal, Optional

 import pytest
 import tvm
@@ -213,19 +213,40 @@ def test_json_pressure(json_grammar: BNFGrammar, json_input_pressure):
     assert GrammarStateMatcher(json_grammar).debug_match_complete_string(json_input_pressure)


-(input_find_rejected_tokens, expected_rejected_sizes) = tvm.testing.parameters(
+(
+    tokenizer_path,
+    input_find_rejected_tokens,
+    expected_rejected_sizes,
+    token_table_postproc_method,
+) = tvm.testing.parameters(
     (
         # short test
+        "dist/Llama-2-7b-chat-hf-q4f16_1-MLC",
         '{"id": 1,"name": "Example"}',
         [
             # fmt: off
-            31989, 31912, 299, 299, 299, 31973, 31846, 31846, 31948, 31915, 299, 299, 299, 299,
-            299, 31973, 31846, 31846, 292, 292, 292, 292, 292, 292, 292, 292, 31974, 31999
+            31989, 31912, 272, 272, 272, 31973, 31846, 31846, 31948, 31915, 272, 272, 272, 272,
+            272, 31973, 31846, 31846, 265, 265, 265, 265, 265, 265, 265, 265, 31974, 31999
             # fmt: on
         ],
+        "byte_fallback",
+    ),
+    (
+        # short test
+        "dist/Meta-Llama-3-8B-Instruct-q4f16_1-MLC",
+        '{"id": 1,"name": "Example哈哈"}',
+        [
+            # fmt: off
+            128235, 127497, 5002, 5002, 5002, 127849, 126399, 126399, 126760, 127499, 5002, 5002,
+            5002, 5002, 5002, 127849, 126399, 126399, 4952, 4952, 4952, 4952, 4952, 4952, 4952,
+            4952, 128066, 128111, 4952, 128066, 128111, 4952, 127873, 128254
+            # fmt: on
+        ],
+        "byte_level",
     ),
     (
         # long test
+        "dist/Llama-2-7b-chat-hf-q4f16_1-MLC",
         """{
 "id": 1,
 "na": "ex",
@@ -236,40 +257,51 @@ def test_json_pressure(json_grammar: BNFGrammar, json_input_pressure):
 }""",
     [
         # fmt: off
-        31989, 31912, 31912, 299, 299, 299, 31973, 31846, 31846, 31948, 31915, 31915, 299, 299,
-        299, 31973, 31846, 31846, 292, 292, 292, 31974, 31915, 31915, 299, 299, 299, 31973,
-        31846, 31846, 31997, 31997, 31998, 31974, 31915, 31915, 299, 299, 31973, 31846, 31846,
-        31840, 291, 291, 291, 31969, 31846, 31846, 291, 291, 291, 31969, 31974, 31915, 31915,
-        299, 299, 299, 31973, 31846, 31846, 31908, 299, 299, 299, 299, 31973, 31846, 31846,
-        31906, 299, 299, 299, 299, 31973, 31846, 31846, 291, 291, 291, 31968, 31970, 31915,
-        31915, 299, 299, 299, 299, 31973, 31846, 31846, 31840, 31943, 31846, 31846, 31943,
-        31846, 31846, 31943, 31970, 31974, 31915, 31915, 299, 299, 299, 299, 31973, 31846,
-        31846, 292, 292, 292, 292, 31974, 31974, 31999
+        31989, 31912, 31912, 272, 272, 272, 31973, 31846, 31846, 31948, 31915, 31915, 272, 272,
+        272, 31973, 31846, 31846, 265, 265, 265, 31974, 31915, 31915, 272, 272, 272, 31973,
+        31846, 31846, 31997, 31997, 31998, 31974, 31915, 31915, 272, 272, 31973, 31846, 31846,
+        31840, 264, 264, 264, 31969, 31846, 31846, 264, 264, 264, 31969, 31974, 31915, 31915,
+        272, 272, 272, 31973, 31846, 31846, 31908, 272, 272, 272, 272, 31973, 31846, 31846,
+        31906, 272, 272, 272, 272, 31973, 31846, 31846, 264, 264, 264, 31968, 31970, 31915,
+        31915, 272, 272, 272, 272, 31973, 31846, 31846, 31840, 31943, 31846, 31846, 31943,
+        31846, 31846, 31943, 31970, 31974, 31915, 31915, 272, 272, 272, 272, 31973, 31846,
+        31846, 265, 265, 265, 265, 31974, 31974, 31999
         # fmt: on
     ],
+        "byte_fallback",
     ),
 )


 def test_find_next_rejected_tokens(
     json_grammar: BNFGrammar,
+    tokenizer_path: str,
     input_find_rejected_tokens: str,
-    expected_rejected_sizes: Optional[List[int]] = None,
+    expected_rejected_sizes: Optional[List[int]],
+    token_table_postproc_method: Literal["byte_fallback", "byte_level"],
 ):
-    tokenizer_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC"
     tokenizer = Tokenizer(tokenizer_path)
-    grammar_state_matcher = GrammarStateMatcher(json_grammar, tokenizer)
+    grammar_state_matcher = GrammarStateMatcher(
+        json_grammar, tokenizer, token_table_postproc_method=token_table_postproc_method
+    )
+
+    input_bytes = input_find_rejected_tokens.encode("utf-8")
+    rejected_sizes = []

-    real_sizes = []
-    for c in input_find_rejected_tokens:
+    for i, c in enumerate(input_bytes):
         rejected_token_ids = grammar_state_matcher.find_next_rejected_tokens(True)
-        real_sizes.append(len(rejected_token_ids))
-        print("Accepting char:", c, file=sys.stderr)
-        assert grammar_state_matcher.debug_accept_char(ord(c))
+        rejected_sizes.append(len(rejected_token_ids))
+        if expected_rejected_sizes is not None:
+            assert rejected_sizes[-1] == expected_rejected_sizes[i], (
+                len(rejected_token_ids),
+                expected_rejected_sizes[i],
+            )
+        print("Accepting char:", c, bytes([c]), file=sys.stderr)
+        assert grammar_state_matcher.debug_accept_char(c)
+
     rejected_token_ids = grammar_state_matcher.find_next_rejected_tokens(True)
-    real_sizes.append(len(rejected_token_ids))
+    rejected_sizes.append(len(rejected_token_ids))
     if expected_rejected_sizes is not None:
-        assert real_sizes == expected_rejected_sizes
+        assert rejected_sizes[-1] == expected_rejected_sizes[-1]


 def test_token_based_operations(json_grammar: BNFGrammar):
@@ -305,7 +337,7 @@ def test_token_based_operations(json_grammar: BNFGrammar):
         accepted = list(set(range(len(token_table))) - set(rejected))
         accepted_tokens = [token_table[i] for i in accepted]
         result.append(accepted_tokens)
-        assert id in accepted
+        assert id in accepted, token_table[id]
         assert grammar_state_matcher.accept_token(id)

     rejected = grammar_state_matcher.find_next_rejected_tokens()
@@ -407,6 +439,20 @@ def test_termination(json_grammar: BNFGrammar):

 if __name__ == "__main__":
     # Run a benchmark to show the performance before running tests
-    test_find_next_rejected_tokens(BNFGrammar.get_grammar_of_json(), '{"id": 1,"name": "Example"}')
+    test_find_next_rejected_tokens(
+        BNFGrammar.get_grammar_of_json(),
+        "dist/Llama-2-7b-chat-hf-q4f16_1-MLC",
+        '{"id": 1,"name": "Example"}',
+        None,
+        "byte_fallback",
+    )
+
+    test_find_next_rejected_tokens(
+        BNFGrammar.get_grammar_of_json(),
+        "dist/Meta-Llama-3-8B-Instruct-q4f16_1-MLC",
+        '{"id": 1,"name": "Example哈哈"}',
+        None,
+        "byte_level",
+    )

     tvm.testing.main()
diff --git a/tests/python/serve/test_serve_engine_grammar.py b/tests/python/serve/test_serve_engine_grammar.py
index 2b3ce29c7f..8bd86a25a1 100644
--- a/tests/python/serve/test_serve_engine_grammar.py
+++ b/tests/python/serve/test_serve_engine_grammar.py
@@ -13,7 +13,7 @@

 prompts_list = [
     "Generate a JSON string containing 20 objects:",
-    "Generate a JSON containing a list:",
+    "Generate a JSON containing a non-empty list:",
     "Generate a JSON with 5 elements:",
 ]
 model_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC"
diff --git a/web/emcc/mlc_wasm_runtime.cc b/web/emcc/mlc_wasm_runtime.cc
index b9a7f55bfa..6ba914ee9f 100644
--- a/web/emcc/mlc_wasm_runtime.cc
+++ b/web/emcc/mlc_wasm_runtime.cc
@@ -36,9 +36,9 @@

 // Grammar related
 #include "serve/grammar/grammar.cc"
+#include "serve/grammar/grammar_functor.cc"
 #include "serve/grammar/grammar_parser.cc"
 #include "serve/grammar/grammar_serializer.cc"
-#include "serve/grammar/grammar_simplifier.cc"
 #include "serve/grammar/grammar_state_matcher.cc"
#include "serve/grammar/json_schema_converter.cc" #include "support/encoding.cc"