diff --git a/cpp/json_ffi/json_ffi_engine.cc b/cpp/json_ffi/json_ffi_engine.cc
index 1a21c2962d..d5fc53b8fa 100644
--- a/cpp/json_ffi/json_ffi_engine.cc
+++ b/cpp/json_ffi/json_ffi_engine.cc
@@ -123,7 +123,7 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode {
   TVM_MODULE_VTABLE_END();
 
   void InitBackgroundEngine(JSONFFIEngineConfig json_ffi_engine_config, EngineConfig engine_config,
-                            Optional<PackedFunc> request_stream_callback,
+                            Device device, Optional<PackedFunc> request_stream_callback,
                             Optional<EventTraceRecorder> trace_recorder) {
     std::optional<Conversation> conv_template =
         Conversation::FromJSON(json_ffi_engine_config->conv_template, &err_);
@@ -150,7 +150,7 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode {
     };
 
     request_stream_callback = PackedFunc(frequest_stream_callback_wrapper);
-    this->engine_->InitBackgroundEngine(std::move(request_stream_callback),
+    this->engine_->InitBackgroundEngine(device, std::move(request_stream_callback),
                                         std::move(trace_recorder));
     this->engine_->Reload(std::move(engine_config));
   }
diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc
index 19f26ff624..3bb809ad67 100644
--- a/cpp/serve/config.cc
+++ b/cpp/serve/config.cc
@@ -244,17 +244,16 @@ String GenerationConfigNode::AsJSONString() const {
 TVM_REGISTER_OBJECT_TYPE(EngineConfigNode);
 
 EngineConfig::EngineConfig(String model, String model_lib_path, Array<String> additional_models,
-                           Array<String> additional_model_lib_paths, DLDevice device,
-                           int kv_cache_page_size, int max_num_sequence,
-                           int max_total_sequence_length, int max_single_sequence_length,
-                           int prefill_chunk_size, int max_history_size, KVStateKind kv_state_kind,
+                           Array<String> additional_model_lib_paths, int kv_cache_page_size,
+                           int max_num_sequence, int max_total_sequence_length,
+                           int max_single_sequence_length, int prefill_chunk_size,
+                           int max_history_size, KVStateKind kv_state_kind,
                            SpeculativeMode speculative_mode, int spec_draft_length) {
   ObjectPtr<EngineConfigNode> n = make_object<EngineConfigNode>();
   n->model = std::move(model);
   n->model_lib_path = std::move(model_lib_path);
   n->additional_models = std::move(additional_models);
   n->additional_model_lib_paths = std::move(additional_model_lib_paths);
-  n->device = device;
   n->kv_cache_page_size = kv_cache_page_size;
   n->max_num_sequence = max_num_sequence;
   n->max_total_sequence_length = max_total_sequence_length;
@@ -267,14 +266,60 @@ EngineConfig::EngineConfig(String model, String model_lib_path, Array<String> ad
   data_ = std::move(n);
 }
 
+EngineConfig EngineConfig::FromJSONString(const std::string& json_str) {
+  picojson::value config_json;
+  std::string err = picojson::parse(config_json, json_str);
+  if (!err.empty()) {
+    LOG(FATAL) << err;
+  }
+
+  // Get json fields.
+  picojson::object config = config_json.get<picojson::object>();
+  String model = json::Lookup<std::string>(config, "model");
+  String model_lib_path = json::Lookup<std::string>(config, "model_lib_path");
+  std::vector<String> additional_models;
+  std::vector<String> additional_model_lib_paths;
+  int kv_cache_page_size = json::Lookup<int64_t>(config, "kv_cache_page_size");
+  int max_num_sequence = json::Lookup<int64_t>(config, "max_num_sequence");
+  int max_total_sequence_length = json::Lookup<int64_t>(config, "max_total_sequence_length");
+  int max_single_sequence_length = json::Lookup<int64_t>(config, "max_single_sequence_length");
+  int prefill_chunk_size = json::Lookup<int64_t>(config, "prefill_chunk_size");
+  int max_history_size = json::Lookup<int64_t>(config, "max_history_size");
+  KVStateKind kv_state_kind =
+      static_cast<KVStateKind>(json::Lookup<int64_t>(config, "kv_state_kind"));
+  SpeculativeMode speculative_mode =
+      static_cast<SpeculativeMode>(json::Lookup<int64_t>(config, "speculative_mode"));
+  int spec_draft_length = json::Lookup<int64_t>(config, "spec_draft_length");
+
+  picojson::array additional_models_arr =
+      json::Lookup<picojson::array>(config, "additional_models");
+  picojson::array additional_model_lib_paths_arr =
+      json::Lookup<picojson::array>(config, "additional_model_lib_paths");
+  CHECK_EQ(additional_models_arr.size(), additional_model_lib_paths_arr.size())
+      << "The number of additional model lib paths does not match the number of additional models";
+  int num_additional_models = additional_models_arr.size();
+  additional_models.reserve(num_additional_models);
+  additional_model_lib_paths.reserve(num_additional_models);
+  for (int i = 0; i < num_additional_models; ++i) {
+    additional_models.push_back(json::Lookup<std::string>(additional_models_arr, i));
+    additional_model_lib_paths.push_back(
+        json::Lookup<std::string>(additional_model_lib_paths_arr, i));
+  }
+
+  return EngineConfig(std::move(model), std::move(model_lib_path), additional_models,
+                      additional_model_lib_paths, kv_cache_page_size, max_num_sequence,
+                      max_total_sequence_length, max_single_sequence_length, prefill_chunk_size,
+                      max_history_size, kv_state_kind, speculative_mode, spec_draft_length);
+}
+
 TVM_REGISTER_GLOBAL("mlc.serve.EngineConfig")
     .set_body_typed([](String model, String model_lib_path, Array<String> additional_models,
-                       Array<String> additional_model_lib_paths, DLDevice device,
-                       int kv_cache_page_size, int max_num_sequence, int max_total_sequence_length,
+                       Array<String> additional_model_lib_paths, int kv_cache_page_size,
+                       int max_num_sequence, int max_total_sequence_length,
                        int max_single_sequence_length, int prefill_chunk_size, int max_history_size,
                        int kv_state_kind, int speculative_mode, int spec_draft_length) {
       return EngineConfig(std::move(model), std::move(model_lib_path), std::move(additional_models),
-                          std::move(additional_model_lib_paths), device, kv_cache_page_size,
+                          std::move(additional_model_lib_paths), kv_cache_page_size,
                           max_num_sequence, max_total_sequence_length, max_single_sequence_length,
                           prefill_chunk_size, max_history_size, KVStateKind(kv_state_kind),
                           SpeculativeMode(speculative_mode), spec_draft_length);
diff --git a/cpp/serve/config.h b/cpp/serve/config.h
index 6a3bdd8997..fd76dd49f0 100644
--- a/cpp/serve/config.h
+++ b/cpp/serve/config.h
@@ -106,11 +106,6 @@ class EngineConfigNode : public Object {
   /*! \brief The path to the additional models' libraries. */
   Array<String> additional_model_lib_paths;
 
-  /*************** Device ***************/
-
-  /*! \brief The device where the models run. */
-  DLDevice device;
-
   /*************** KV cache config and engine capacities ***************/
 
   /*! \brief The number of consecutive tokens handled in each page in paged KV cache. */
@@ -152,12 +147,15 @@ class EngineConfig : public ObjectRef {
  public:
   explicit EngineConfig(String model, String model_lib_path, Array<String> additional_models,
-                        Array<String> additional_model_lib_paths, DLDevice device,
-                        int kv_cache_page_size, int max_num_sequence, int max_total_sequence_length,
+                        Array<String> additional_model_lib_paths, int kv_cache_page_size,
+                        int max_num_sequence, int max_total_sequence_length,
                         int max_single_sequence_length, int prefill_chunk_size,
                         int max_history_size, KVStateKind kv_state_kind,
                         SpeculativeMode speculative_mode, int spec_draft_length);
 
+  /*! \brief Create EngineConfig from JSON string. */
+  static EngineConfig FromJSONString(const std::string& json_str);
+
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(EngineConfig, ObjectRef, EngineConfigNode);
 };
diff --git a/cpp/serve/draft_token_workspace_manager.cc b/cpp/serve/draft_token_workspace_manager.cc
new file mode 100644
index 0000000000..d004e91ee5
--- /dev/null
+++ b/cpp/serve/draft_token_workspace_manager.cc
@@ -0,0 +1,54 @@
+/*!
+ * Copyright (c) 2024 by Contributors
+ * \file serve/draft_token_workspace_manager.cc
+ */
+
+#include "draft_token_workspace_manager.h"
+
+#include "model.h"
+
+namespace mlc {
+namespace llm {
+namespace serve {
+
+DraftTokenWorkspaceManagerObj::DraftTokenWorkspaceManagerObj(int max_num_tokens, int vocab_size,
+                                                             int hidden_size,
+                                                             DLDataType hidden_states_dtype,
+                                                             DLDevice device,
+                                                             const FunctionTable& ft)
+    : max_num_tokens_(max_num_tokens),
+      vocab_size_(vocab_size),
+      hidden_size_(hidden_size),
+      hidden_states_dtype_(hidden_states_dtype),
+      device_(device),
+      ft_(ft) {
+  free_slots_.resize(max_num_tokens);
+  std::iota(free_slots_.begin(), free_slots_.end(), 0);
+}
+
+void DraftTokenWorkspaceManagerObj::AllocSlots(int num_slots, std::vector<int>* result) {
+  ICHECK_LE(num_slots, free_slots_.size());
+  result->assign(free_slots_.rbegin(), free_slots_.rbegin() + num_slots);
+  std::vector<int> allocated(free_slots_.begin(), free_slots_.begin() + num_slots);
+  free_slots_.resize(free_slots_.size() - num_slots);
+}
+
+void DraftTokenWorkspaceManagerObj::FreeSlots(const std::vector<int>& slots) {
+  std::copy(slots.begin(), slots.end(), std::back_inserter(free_slots_));
+}
+
+void DraftTokenWorkspaceManagerObj::AllocWorkspace(ModelWorkspace* workspace,
+                                                   bool require_hidden_states) {
+  workspace->draft_probs =
+      NDArray::Empty({max_num_tokens_, vocab_size_}, DataType::Float(32), device_);
+  workspace->draft_probs_storage =
+      NDArray::Empty({max_num_tokens_, vocab_size_}, DataType::Float(32), device_);
+  if (require_hidden_states) {
+    workspace->draft_hidden_states_storage =
+        ft_.Empty({max_num_tokens_, hidden_size_}, hidden_states_dtype_, device_);
+  }
+}
+
+}  // namespace serve
+}  // namespace llm
+}  // namespace mlc
diff --git a/cpp/serve/draft_token_workspace_manager.h b/cpp/serve/draft_token_workspace_manager.h
new file mode 100644
index 0000000000..1a1dfbc8e0
--- /dev/null
+++ b/cpp/serve/draft_token_workspace_manager.h
@@ -0,0 +1,95 @@
+/*!
+ * Copyright (c) 2024 by Contributors
+ * \file serve/draft_token_workspace_manager.h
+ */
+
+#ifndef MLC_LLM_SERVE_DRAFT_TOKEN_WORKSPACE_MANAGER_H_
+#define MLC_LLM_SERVE_DRAFT_TOKEN_WORKSPACE_MANAGER_H_
+#include 
+
+#include 
+#include 
+#include 
+
+#include "data.h"
+#include "function_table.h"
+namespace mlc {
+namespace llm {
+namespace serve {
+
+using tvm::Device;
+using namespace tvm::runtime;
+
+struct ModelWorkspace;
+
+/*!
+ * \brief Managing the workspace for draft token generation.
+ *
+ * The workspace is used to store the associated states for each draft token, including the
+ * probability distribution of the draft token, the hidden states, etc. The workspace manager
+ * maintains a pool of slots for the draft tokens to store the states.
+ */
+class DraftTokenWorkspaceManagerObj : public Object {
+ public:
+  /*!
+   * \brief Constructor
+   * \param max_num_tokens The maximum number of draft tokens that can be stored in the workspace.
+   * \param vocab_size The size of the vocabulary.
+   * \param hidden_size The size of the hidden states.
+   * \param hidden_states_dtype The data type of the hidden states.
+   * \param device The device running the model.
+   * \param ft The function table.
+   */
+  DraftTokenWorkspaceManagerObj(int max_num_tokens, int vocab_size, int hidden_size,
+                                DLDataType hidden_states_dtype, DLDevice device,
+                                const FunctionTable& ft);
+
+  /*!
+   * \brief Allocate the workspace for draft tokens and update `ModelWorkspace` data structure.
+   * \param workspace The object to store the allocated draft token workspace.
+   * \param require_hidden_states Whether to allocate workspace for the hidden states.
+   */
+  void AllocWorkspace(ModelWorkspace* workspace, bool require_hidden_states);
+
+  /*!
+   * \brief Allocate slots for the draft tokens.
+   * \param num_slots The number of slots to allocate.
+   * \param result The vector to store the allocated slots.
+   */
+  void AllocSlots(int num_slots, std::vector<int>* result);
+
+  /*!
+   * \brief Free the slots.
+   * \param slots The slots to free.
+   */
+  void FreeSlots(const std::vector<int>& slots);
+
+  static constexpr const char* _type_key = "mlc.serve.DraftTokenWorkspaceManager";
+
+ private:
+  std::vector<int> free_slots_;
+  int max_num_tokens_;
+  int vocab_size_;
+  int hidden_size_;
+  DataType hidden_states_dtype_;
+  DLDevice device_;
+  const FunctionTable& ft_;
+};
+
+class DraftTokenWorkspaceManager : public ObjectRef {
+ public:
+  DraftTokenWorkspaceManager(int max_num_tokens, int vocab_size, int hidden_size,
+                             DLDataType hidden_states_dtype, DLDevice device,
+                             const FunctionTable& ft) {
+    data_ = make_object<DraftTokenWorkspaceManagerObj>(max_num_tokens, vocab_size, hidden_size,
+                                                       hidden_states_dtype, device, ft);
+  }
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(DraftTokenWorkspaceManager, ObjectRef,
+                                        DraftTokenWorkspaceManagerObj);
+};
+
+}  // namespace serve
+}  // namespace llm
+}  // namespace mlc
+
+#endif  // MLC_LLM_SERVE_DRAFT_TOKEN_WORKSPACE_MANAGER_H_
diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc
index 0348f7f40a..755af998cd 100644
--- a/cpp/serve/engine.cc
+++ b/cpp/serve/engine.cc
@@ -12,6 +12,7 @@
 
 #include 
 #include 
+#include 
 #include 
 #include 
 #include 
@@ -44,7 +45,8 @@ class EngineImpl : public Engine {
  public:
   /********************** Engine Management **********************/
 
-  explicit EngineImpl(EngineConfig engine_config, Optional<PackedFunc> request_stream_callback,
+  explicit EngineImpl(EngineConfig engine_config, DLDevice device,
+                      Optional<PackedFunc> request_stream_callback,
                       Optional<EventTraceRecorder> trace_recorder) {
     // Step 1. Initialize metadata and singleton states inside the engine
     this->estate_->Reset();
@@ -54,18 +56,25 @@ class EngineImpl : public Engine {
     }
     this->request_stream_callback_ = std::move(request_stream_callback);
     this->trace_recorder_ = trace_recorder;
-    this->tokenizer_ = Tokenizer::FromPath(engine_config->model);
-    this->token_table_ = tokenizer_->TokenTable();
-    this->grammar_init_context_storage_ = GrammarInitContextStorage(this->token_table_);
+
+    // Step 2. Initialize each model independently.
     // Create the logit processor and sampler.
this->models_.clear(); this->model_workspaces_.clear(); - auto f_create_model = [this, &engine_config, &trace_recorder](const String& model_path, - const String& model_lib_path) { - Model model = Model::Create(model_lib_path, std::move(model_path), engine_config->device, - engine_config->max_num_sequence, + std::vector model_configs; + model_configs.push_back(Model::LoadModelConfig(engine_config->model)); + for (const auto& model_path : engine_config->additional_models) { + model_configs.push_back(Model::LoadModelConfig(model_path)); + } + + Optional session = CreateDiscoSession(model_configs, device); + + auto f_create_model = [this, &engine_config, &device, &trace_recorder, &model_configs, + &session](const String& model_path, const String& model_lib_path, + int model_index) { + Model model = Model::Create(model_lib_path, std::move(model_path), model_configs[model_index], + device, engine_config->max_num_sequence, session, /*trace_enabled=*/trace_recorder.defined()); model->CreateKVCache(engine_config->kv_cache_page_size, engine_config->max_num_sequence, engine_config->max_total_sequence_length, @@ -80,53 +89,78 @@ class EngineImpl : public Engine { ModelWorkspace{model->AllocEmbeddingTensor(), model->AllocHiddenStatesTensor()}); }; - f_create_model(engine_config->model, engine_config->model_lib_path); + f_create_model(engine_config->model, engine_config->model_lib_path, /*model_index=*/0); CHECK_EQ(engine_config->additional_models.size(), engine_config->additional_model_lib_paths.size()) << "The additional model and lib path list has mismatched size."; for (int i = 0; i < static_cast(engine_config->additional_models.size()); ++i) { f_create_model(engine_config->additional_models[i], - engine_config->additional_model_lib_paths[i]); + engine_config->additional_model_lib_paths[i], /*model_index=*/i + 1); + } + + // Step 3. Initialize tokenizer and grammar + this->tokenizer_ = Tokenizer::FromPath(engine_config->model); + std::string token_table_postproc_method; + if (model_configs[0].count("token_table_postproc_method") == 0) { + // Backward compatibility: use "byte-fallback" by default + token_table_postproc_method = "byte-fallback"; + } else { + token_table_postproc_method = + model_configs[0].at("token_table_postproc_method").get(); } + this->token_table_ = + Tokenizer::PostProcessTokenTable(tokenizer_->TokenTable(), token_table_postproc_method); + this->grammar_init_context_storage_ = GrammarInitContextStorage(this->token_table_); + // Step 4. Initialize engine actions that represent state transitions. int max_num_tokens = engine_config->max_num_sequence; + DraftTokenWorkspaceManager draft_token_workspace_manager{nullptr}; if (engine_config->speculative_mode != SpeculativeMode::kDisable) { max_num_tokens *= engine_config->spec_draft_length + 1; + draft_token_workspace_manager = models_[0]->CreateDraftTokenWorkspaceManager(max_num_tokens); + draft_token_workspace_manager->AllocWorkspace( + &model_workspaces_[0], + /*require_hidden_states=*/engine_config->speculative_mode == SpeculativeMode::kEagle); } LogitProcessor logit_processor = this->models_[0]->CreateLogitProcessor(max_num_tokens, trace_recorder); Sampler sampler = this->models_[0]->CreateSampler( max_num_tokens, static_cast(this->models_.size()), trace_recorder); - // Step 3. Initialize engine actions that represent state transitions. if (engine_config->speculative_mode != SpeculativeMode::kDisable) { // Speculative decoding is only possible for more than one model. 
ICHECK_GT(this->models_.size(), 1U); switch (engine_config->speculative_mode) { case SpeculativeMode::kEagle: - this->actions_ = {EngineAction::EagleNewRequestPrefill(this->models_, // - logit_processor, // - sampler, // - this->model_workspaces_, // - engine_config, // - this->trace_recorder_), - EngineAction::EagleBatchDraft( - this->models_, logit_processor, sampler, this->model_workspaces_, - this->trace_recorder_, engine_config->spec_draft_length), - EngineAction::EagleBatchVerify(this->models_, logit_processor, sampler, - this->model_workspaces_, engine_config, - this->trace_recorder_)}; + this->actions_ = { + EngineAction::EagleNewRequestPrefill(this->models_, // + logit_processor, // + sampler, // + this->model_workspaces_, // + draft_token_workspace_manager, // + engine_config, // + this->trace_recorder_), + EngineAction::EagleBatchDraft(this->models_, logit_processor, sampler, + this->model_workspaces_, draft_token_workspace_manager, + this->trace_recorder_, + engine_config->spec_draft_length), + EngineAction::EagleBatchVerify(this->models_, logit_processor, sampler, + this->model_workspaces_, draft_token_workspace_manager, + engine_config, this->trace_recorder_)}; break; default: - this->actions_ = {EngineAction::NewRequestPrefill(this->models_, // - logit_processor, // - sampler, // - this->model_workspaces_, // - engine_config, // - this->trace_recorder_), - EngineAction::BatchDraft(this->models_, logit_processor, sampler, - this->trace_recorder_), - EngineAction::BatchVerify(this->models_, logit_processor, sampler, - engine_config, this->trace_recorder_)}; + this->actions_ = { + EngineAction::NewRequestPrefill(this->models_, // + logit_processor, // + sampler, // + this->model_workspaces_, // + engine_config, // + this->trace_recorder_), + EngineAction::BatchDraft(this->models_, logit_processor, sampler, + this->model_workspaces_, draft_token_workspace_manager, + this->trace_recorder_), + EngineAction::BatchVerify(this->models_, logit_processor, sampler, + this->model_workspaces_, draft_token_workspace_manager, + engine_config, this->trace_recorder_)}; } } else { this->actions_ = {EngineAction::NewRequestPrefill(this->models_, // @@ -286,6 +320,51 @@ class EngineImpl : public Engine { "action (e.g. prefill, decode, etc.) but it does not."; } + /************** Utility Functions **************/ + Optional CreateDiscoSession(std::vector model_configs, Device device) { + const auto& base_model_config = model_configs[0]; + + auto f_get_num_shards = [](const picojson::object& model_config) -> int { + constexpr auto kNumShardsKey = "tensor_parallel_shards"; + if (model_config.count(kNumShardsKey)) { + const auto& val = model_config.at(kNumShardsKey); + CHECK(val.is()); + return static_cast(val.get()); + } else { + LOG(FATAL) << "Key \"tensor_parallel_shards\" not found."; + } + throw; + }; + + int num_shards = std::transform_reduce( + model_configs.begin(), model_configs.end(), 1, [](int a, int b) { return std::max(a, b); }, + f_get_num_shards); + Optional session = NullOpt; + if (num_shards > 1) { + constexpr const char* f_create_process_pool = "runtime.disco.create_process_pool"; + if (Registry::Get(f_create_process_pool) == nullptr) { + LOG(FATAL) << "Cannot find process launcher `" << f_create_process_pool << "`. 
" + << "Multi-GPU inference depends on MLC LLM Python API to launch process."; + } + std::string ccl; + if (device.device_type == kDLCUDA) { + ccl = "nccl"; + } else if (device.device_type == kDLROCM) { + ccl = "rccl"; + } else { + LOG(FATAL) << "ValueError: Multi-GPU on device " << DLDeviceType2Str(device.device_type) + << " is not supported. Currently, only NCCL and RCCL are integrated."; + } + std::vector device_ids(num_shards); + for (int i = 0; i < num_shards; ++i) { + device_ids[i] = i; + } + session = Session::ProcessSession(num_shards, f_create_process_pool, "mlc_llm.cli.worker"); + session.value()->InitCCL(ccl, ShapeTuple(device_ids)); + } + return session; + } + /************** Debug/Profile **************/ void DebugCallFuncOnAllAllWorker(const String& func_name) final { @@ -339,10 +418,11 @@ class EngineImpl : public Engine { Optional trace_recorder_; }; -std::unique_ptr Engine::Create(EngineConfig engine_config, +std::unique_ptr Engine::Create(EngineConfig engine_config, Device device, Optional request_stream_callback, Optional trace_recorder) { - return std::make_unique(std::move(engine_config), std::move(request_stream_callback), + return std::make_unique(std::move(engine_config), device, + std::move(request_stream_callback), std::move(trace_recorder)); } @@ -368,10 +448,10 @@ class EngineModule : public ModuleNode { TVM_MODULE_VTABLE_END(); /*! \brief Initialize the engine with config and other fields. */ - void Init(EngineConfig engine_config, Optional request_stream_callback, + void Init(EngineConfig engine_config, Device device, Optional request_stream_callback, Optional trace_recorder) { - this->engine_ = Engine::Create(std::move(engine_config), std::move(request_stream_callback), - std::move(trace_recorder)); + this->engine_ = Engine::Create(std::move(engine_config), device, + std::move(request_stream_callback), std::move(trace_recorder)); } /*! \brief Construct an EngineModule. */ static tvm::runtime::Module Create() { return Module(make_object()); } diff --git a/cpp/serve/engine.h b/cpp/serve/engine.h index bcc1b80988..2fc0a4d730 100644 --- a/cpp/serve/engine.h +++ b/cpp/serve/engine.h @@ -51,11 +51,12 @@ class Engine { /*! * \brief Create an engine in unique pointer. * \param engine_config The engine config. + * \param device The device where the run models. * \param request_stream_callback The request stream callback function to. * \param trace_recorder Event trace recorder for requests. * \return The created Engine in pointer. */ - static std::unique_ptr Create(EngineConfig engine_config, + static std::unique_ptr Create(EngineConfig engine_config, Device device, Optional request_stream_callback, Optional trace_recorder); diff --git a/cpp/serve/engine_actions/action.h b/cpp/serve/engine_actions/action.h index 79359c5741..c69c508810 100644 --- a/cpp/serve/engine_actions/action.h +++ b/cpp/serve/engine_actions/action.h @@ -8,6 +8,7 @@ #define MLC_LLM_SERVE_ENGINE_ACTIONS_ACTION_H_ #include "../config.h" +#include "../draft_token_workspace_manager.h" #include "../engine_state.h" #include "../event_trace_recorder.h" #include "../model.h" @@ -72,15 +73,16 @@ class EngineAction : public ObjectRef { * \param logit_processor The logit processor. * \param sampler The sampler to sample new tokens. * \param model_workspaces The workspace of each model. + * \param draft_token_workspace_manager The draft token workspace manager. * \param engine_config The engine config. * \param trace_recorder The event trace recorder for requests. * \return The created action object. 
*/ - static EngineAction EagleNewRequestPrefill(Array models, LogitProcessor logit_processor, - Sampler sampler, - std::vector model_workspaces, - EngineConfig engine_config, - Optional trace_recorder); + static EngineAction EagleNewRequestPrefill( + Array models, LogitProcessor logit_processor, Sampler sampler, + std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, EngineConfig engine_config, + Optional trace_recorder); /*! * \brief Create the action that runs one-step decode for requests in the * `running_queue` of engine state. Preempt low-priority requests @@ -104,13 +106,16 @@ class EngineAction : public ObjectRef { * \param models The model to run decode in. When there are multiple * models, the `Step` function of the created action will not take effect. * \param sampler The sampler to sample new tokens. + * \param model_workspaces The workspace of each model. + * \param draft_token_workspace_manager The draft token workspace manager. * \param trace_recorder The event trace recorder for requests. * \param draft_length The number of draft proposal rounds. * \return The created action object. */ static EngineAction BatchDraft(Array models, LogitProcessor logit_processor, - Sampler sampler, Optional trace_recorder, - int draft_length = 4); + Sampler sampler, std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, + Optional trace_recorder, int draft_length = 4); /*! * \brief Create the action that runs one-step speculative draft proposal for @@ -120,12 +125,14 @@ class EngineAction : public ObjectRef { * models, the `Step` function of the created action will not take effect. * \param sampler The sampler to sample new tokens. * \param model_workspaces The workspace of each model. + * \param draft_token_workspace_manager The draft token workspace manager. * \param trace_recorder The event trace recorder for requests. * \param draft_length The number of draft proposal rounds. * \return The created action object. */ static EngineAction EagleBatchDraft(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, Optional trace_recorder, int draft_length = 4); @@ -135,13 +142,17 @@ class EngineAction : public ObjectRef { * accordingly when it is impossible to decode all the running requests. * \param models The model to run decode in. When there are multiple * models, the `Step` function of the created action will not take effect. + * \param model_workspaces The workspace of each model. + * \param draft_token_workspace_manager The draft token workspace manager. * \param sampler The sampler to sample new tokens. * \param engine_config The engine config. * \param trace_recorder The event trace recorder for requests. * \return The created action object. */ static EngineAction BatchVerify(Array models, LogitProcessor logit_processor, - Sampler sampler, EngineConfig engine_config, + Sampler sampler, std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, + EngineConfig engine_config, Optional trace_recorder); /*! @@ -152,6 +163,7 @@ class EngineAction : public ObjectRef { * models, the `Step` function of the created action will not take effect. * \param sampler The sampler to sample new tokens. * \param model_workspaces The workspace of each model. + * \param draft_token_workspace_manager The draft token workspace manager. * \param engine_config The engine config. 
* \param trace_recorder The event trace recorder for requests. * \return The created action object. @@ -159,6 +171,7 @@ class EngineAction : public ObjectRef { static EngineAction EagleBatchVerify(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, EngineConfig engine_config, Optional trace_recorder); diff --git a/cpp/serve/engine_actions/action_commons.cc b/cpp/serve/engine_actions/action_commons.cc index 6eb7a3d84a..af0dfe978d 100644 --- a/cpp/serve/engine_actions/action_commons.cc +++ b/cpp/serve/engine_actions/action_commons.cc @@ -142,9 +142,10 @@ void ActionStepPostProcess(Array requests, EngineState estate, Array& models, - Optional trace_recorder) { +RequestStateEntry PreemptLastRunningRequestStateEntry( + EngineState estate, const Array& models, + Optional draft_token_workspace_manager, + Optional trace_recorder) { ICHECK(!estate->running_queue.empty()); Request request = estate->running_queue.back(); @@ -168,8 +169,12 @@ RequestStateEntry PreemptLastRunningRequestStateEntry(EngineState estate, // - Update `inputs` for future prefill. RECORD_EVENT(trace_recorder, rsentry->request->id, "preempt"); rsentry->status = RequestStateStatus::kPending; + std::vector draft_token_slots; for (RequestModelState mstate : rsentry->mstates) { - mstate->RemoveAllDraftTokens(); + if (draft_token_workspace_manager.defined()) { + mstate->RemoveAllDraftTokens(&draft_token_slots); + draft_token_workspace_manager.value()->FreeSlots(draft_token_slots); + } std::vector committed_token_ids; committed_token_ids.reserve(mstate->committed_tokens.size()); for (const SampleResult& committed_token : mstate->committed_tokens) { diff --git a/cpp/serve/engine_actions/action_commons.h b/cpp/serve/engine_actions/action_commons.h index 78e3937d0b..07bef2d2d9 100644 --- a/cpp/serve/engine_actions/action_commons.h +++ b/cpp/serve/engine_actions/action_commons.h @@ -7,6 +7,7 @@ #define MLC_LLM_SERVE_ENGINE_ACTIONS_ACTION_COMMONS_H_ #include "../../tokenizers.h" +#include "../draft_token_workspace_manager.h" #include "../engine.h" #include "../engine_state.h" #include "../event_trace_recorder.h" @@ -52,12 +53,14 @@ void ActionStepPostProcess(Array requests, EngineState estate, Array& models, - Optional trace_recorder); +RequestStateEntry PreemptLastRunningRequestStateEntry( + EngineState estate, const Array& models, + Optional draft_token_workspace_manager, + Optional trace_recorder); /*! \brief Get the running request entries from the engine state. 
*/ inline std::vector GetRunningRequestStateEntries(const EngineState& estate) { diff --git a/cpp/serve/engine_actions/batch_decode.cc b/cpp/serve/engine_actions/batch_decode.cc index 36acc6b06e..ecff914baa 100644 --- a/cpp/serve/engine_actions/batch_decode.cc +++ b/cpp/serve/engine_actions/batch_decode.cc @@ -48,7 +48,7 @@ class BatchDecodeActionObj : public EngineActionObj { running_rsentries = GetRunningRequestStateEntries(estate); while (!CanDecode(running_rsentries.size())) { RequestStateEntry preempted = - PreemptLastRunningRequestStateEntry(estate, models_, trace_recorder_); + PreemptLastRunningRequestStateEntry(estate, models_, NullOpt, trace_recorder_); if (preempted.same_as(running_rsentries.back())) { running_rsentries.pop_back(); } diff --git a/cpp/serve/engine_actions/batch_draft.cc b/cpp/serve/engine_actions/batch_draft.cc index c1ddeb6e4e..513a0fe447 100644 --- a/cpp/serve/engine_actions/batch_draft.cc +++ b/cpp/serve/engine_actions/batch_draft.cc @@ -23,10 +23,14 @@ namespace serve { class BatchDraftActionObj : public EngineActionObj { public: explicit BatchDraftActionObj(Array models, LogitProcessor logit_processor, Sampler sampler, + std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, Optional trace_recorder, int draft_length) : models_(std::move(models)), logit_processor_(std::move(logit_processor)), sampler_(std::move(sampler)), + model_workspaces_(std::move(model_workspaces)), + draft_token_workspace_manager_(std::move(draft_token_workspace_manager)), trace_recorder_(std::move(trace_recorder)), draft_length_(draft_length) { ICHECK_GT(draft_length_, 0); @@ -41,8 +45,8 @@ class BatchDraftActionObj : public EngineActionObj { // Preempt request state entries when decode cannot apply. std::vector running_rsentries = GetRunningRequestStateEntries(estate); while (!CanDecode(running_rsentries.size())) { - RequestStateEntry preempted = - PreemptLastRunningRequestStateEntry(estate, models_, trace_recorder_); + RequestStateEntry preempted = PreemptLastRunningRequestStateEntry( + estate, models_, draft_token_workspace_manager_, trace_recorder_); if (preempted.same_as(running_rsentries.back())) { running_rsentries.pop_back(); } @@ -123,8 +127,11 @@ class BatchDraftActionObj : public EngineActionObj { ICHECK_EQ(sample_results.size(), num_rsentries); // - Add draft token to the state. + draft_token_workspace_manager_->AllocSlots(num_rsentries, &draft_token_slots_); + models_[model_id]->ScatterDraftProbs(probs_on_device, draft_token_slots_, + &model_workspaces_[0].draft_probs_storage); for (int i = 0; i < num_rsentries; ++i) { - mstates[i]->AddDraftToken(sample_results[i], prob_dist[i]); + mstates[i]->AddDraftToken(sample_results[i], draft_token_slots_[i]); estate->stats.total_draft_length += 1; } } @@ -156,18 +163,27 @@ class BatchDraftActionObj : public EngineActionObj { LogitProcessor logit_processor_; /*! \brief The sampler to sample new tokens. */ Sampler sampler_; + /*! \brief The model workspaces. */ + std::vector model_workspaces_; + /*! \brief The draft token workspace manager. */ + DraftTokenWorkspaceManager draft_token_workspace_manager_; /*! \brief Event trace recorder. */ Optional trace_recorder_; /*! \brief Draft proposal length */ int draft_length_; + /*! 
\brief Temporary buffer to store the slots of the current draft tokens */ + std::vector draft_token_slots_; }; EngineAction EngineAction::BatchDraft(Array models, LogitProcessor logit_processor, - Sampler sampler, Optional trace_recorder, + Sampler sampler, std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, + Optional trace_recorder, int draft_length) { return EngineAction(make_object( - std::move(models), std::move(logit_processor), std::move(sampler), std::move(trace_recorder), - draft_length)); + std::move(models), std::move(logit_processor), std::move(sampler), + std::move(model_workspaces), std::move(draft_token_workspace_manager), + std::move(trace_recorder), draft_length)); } } // namespace serve diff --git a/cpp/serve/engine_actions/batch_verify.cc b/cpp/serve/engine_actions/batch_verify.cc index 42c9bbe018..42524d46b2 100644 --- a/cpp/serve/engine_actions/batch_verify.cc +++ b/cpp/serve/engine_actions/batch_verify.cc @@ -28,11 +28,15 @@ namespace serve { class BatchVerifyActionObj : public EngineActionObj { public: explicit BatchVerifyActionObj(Array models, LogitProcessor logit_processor, - Sampler sampler, EngineConfig engine_config, + Sampler sampler, std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, + EngineConfig engine_config, Optional trace_recorder) : models_(std::move(models)), logit_processor_(std::move(logit_processor)), sampler_(std::move(sampler)), + model_workspaces_(std::move(model_workspaces)), + draft_token_workspace_manager_(std::move(draft_token_workspace_manager)), engine_config_(std::move(engine_config)), trace_recorder_(std::move(trace_recorder)), rng_(RandomGenerator::GetInstance()) {} @@ -61,14 +65,13 @@ class BatchVerifyActionObj : public EngineActionObj { Array generation_cfg; std::vector rngs; std::vector> draft_output_tokens; - std::vector> draft_output_prob_dist; request_internal_ids.reserve(num_rsentries); all_tokens_to_verify.reserve(total_verify_length); verify_request_mstates.reserve(num_rsentries); rngs.reserve(num_rsentries); generation_cfg.reserve(num_rsentries); draft_output_tokens.reserve(num_rsentries); - draft_output_prob_dist.reserve(num_rsentries); + draft_token_slots_.clear(); for (int i = 0; i < num_rsentries; ++i) { RequestModelState verify_mstate = rsentries[i]->mstates[verify_model_id_]; @@ -76,18 +79,22 @@ class BatchVerifyActionObj : public EngineActionObj { request_internal_ids.push_back(verify_mstate->internal_id); ICHECK(!verify_lengths.empty()); ICHECK_EQ(verify_lengths[i], draft_mstate->draft_output_tokens.size() + 1); - ICHECK_EQ(verify_lengths[i], draft_mstate->draft_output_prob_dist.size() + 1); + ICHECK_EQ(verify_lengths[i], draft_mstate->draft_token_slots.size() + 1); // the last committed token + all the draft tokens. 
+ draft_token_slots_.push_back(0); // placeholder for the last committed token all_tokens_to_verify.push_back(draft_mstate->committed_tokens.back().sampled_token_id.first); for (int j = 0; j < static_cast(draft_mstate->draft_output_tokens.size()); ++j) { all_tokens_to_verify.push_back(draft_mstate->draft_output_tokens[j].sampled_token_id.first); + draft_token_slots_.push_back(draft_mstate->draft_token_slots[j]); } verify_request_mstates.push_back(verify_mstate); generation_cfg.push_back(rsentries[i]->request->generation_cfg); rngs.push_back(&rsentries[i]->rng); draft_output_tokens.push_back(draft_mstate->draft_output_tokens); - draft_output_prob_dist.push_back(draft_mstate->draft_output_prob_dist); } + NDArray draft_probs_on_device = models_[draft_model_id_]->GatherDraftProbs( + model_workspaces_[verify_model_id_].draft_probs_storage, draft_token_slots_, + &model_workspaces_[verify_model_id_].draft_probs); RECORD_EVENT(trace_recorder_, request_ids, "start verify embedding"); ObjectRef embeddings = models_[verify_model_id_]->TokenEmbed( @@ -123,7 +130,7 @@ class BatchVerifyActionObj : public EngineActionObj { std::vector> sample_results_arr = sampler_->BatchVerifyDraftTokensWithProbAfterTopP( renormalized_probs, request_ids, cum_verify_lengths, generation_cfg, rngs, - draft_output_tokens, draft_output_prob_dist); + draft_output_tokens, draft_probs_on_device); ICHECK_EQ(sample_results_arr.size(), num_rsentries); for (int i = 0; i < num_rsentries; ++i) { @@ -134,6 +141,8 @@ class BatchVerifyActionObj : public EngineActionObj { rsentries[i]->mstates[draft_model_id_]->CommitToken(sample_result); } estate->stats.total_accepted_length += accept_length; + estate->stats.UpdateSpecDecodingStats(cum_verify_lengths[i + 1] - cum_verify_lengths[i], + accept_length); int rollback_length = std::max(cum_verify_lengths[i + 1] - cum_verify_lengths[i] - accept_length, 0); // rollback kv cache @@ -149,7 +158,8 @@ class BatchVerifyActionObj : public EngineActionObj { // clear the draft model state entries for (int i = 0; i < num_rsentries; ++i) { - rsentries[i]->mstates[draft_model_id_]->RemoveAllDraftTokens(); + rsentries[i]->mstates[draft_model_id_]->RemoveAllDraftTokens(&draft_token_slots_); + draft_token_workspace_manager_->FreeSlots(draft_token_slots_); } auto tend = std::chrono::high_resolution_clock::now(); @@ -194,8 +204,8 @@ class BatchVerifyActionObj : public EngineActionObj { total_required_pages += num_require_pages; } while (!CanVerify(total_required_pages)) { - RequestStateEntry preempted = - PreemptLastRunningRequestStateEntry(estate, models_, trace_recorder_); + RequestStateEntry preempted = PreemptLastRunningRequestStateEntry( + estate, models_, draft_token_workspace_manager_, trace_recorder_); if (preempted.same_as(running_rsentries.back())) { total_verify_length -= verify_lengths.back(); total_required_pages -= num_page_requirement.back(); @@ -222,6 +232,10 @@ class BatchVerifyActionObj : public EngineActionObj { LogitProcessor logit_processor_; /*! \brief The sampler to sample new tokens. */ Sampler sampler_; + /*! \brief The model workspaces. */ + std::vector model_workspaces_; + /*! \brief The draft token workspace manager. */ + DraftTokenWorkspaceManager draft_token_workspace_manager_; /*! \brief The engine config. */ EngineConfig engine_config_; /*! \brief Event trace recorder. */ @@ -232,14 +246,20 @@ class BatchVerifyActionObj : public EngineActionObj { const int verify_model_id_ = 0; const int draft_model_id_ = 1; const float eps_ = 1e-5; + /*! 
\brief Temporary buffer to store the slots of the current draft tokens */ + std::vector draft_token_slots_; }; EngineAction EngineAction::BatchVerify(Array models, LogitProcessor logit_processor, - Sampler sampler, EngineConfig engine_config, + Sampler sampler, + std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, + EngineConfig engine_config, Optional trace_recorder) { return EngineAction(make_object( - std::move(models), std::move(logit_processor), std::move(sampler), std::move(engine_config), - std::move(trace_recorder))); + std::move(models), std::move(logit_processor), std::move(sampler), + std::move(model_workspaces), std::move(draft_token_workspace_manager), + std::move(engine_config), std::move(trace_recorder))); } } // namespace serve diff --git a/cpp/serve/engine_actions/eagle_batch_draft.cc b/cpp/serve/engine_actions/eagle_batch_draft.cc index fde314a5c5..b4e7ec4c39 100644 --- a/cpp/serve/engine_actions/eagle_batch_draft.cc +++ b/cpp/serve/engine_actions/eagle_batch_draft.cc @@ -24,11 +24,13 @@ class EagleBatchDraftActionObj : public EngineActionObj { public: explicit EagleBatchDraftActionObj(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, Optional trace_recorder, int draft_length) : models_(std::move(models)), logit_processor_(std::move(logit_processor)), sampler_(std::move(sampler)), model_workspaces_(std::move(model_workspaces)), + draft_token_workspace_manager_(std::move(draft_token_workspace_manager)), trace_recorder_(std::move(trace_recorder)), draft_length_(draft_length) { ICHECK_GT(draft_length_, 0); @@ -43,8 +45,8 @@ class EagleBatchDraftActionObj : public EngineActionObj { // Preempt request state entries when decode cannot apply. std::vector running_rsentries = GetRunningRequestStateEntries(estate); while (!CanDecode(running_rsentries.size())) { - RequestStateEntry preempted = - PreemptLastRunningRequestStateEntry(estate, models_, trace_recorder_); + RequestStateEntry preempted = PreemptLastRunningRequestStateEntry( + estate, models_, draft_token_workspace_manager_, trace_recorder_); if (preempted.same_as(running_rsentries.back())) { running_rsentries.pop_back(); } @@ -81,21 +83,16 @@ class EagleBatchDraftActionObj : public EngineActionObj { mstates.push_back(rsentry->mstates[model_id]); } // draft_length_ rounds of draft proposal. 
- NDArray hidden_states_nd{nullptr}; - ObjectRef last_hidden_states{nullptr}; ObjectRef hidden_states = model_workspaces_[model_id].hidden_states; // Concat last hidden_states - std::vector previous_hidden_on_device; - for (int i = 0; i < num_rsentries; ++i) { - previous_hidden_on_device.push_back(mstates[i]->draft_last_hidden_on_device.back()); + draft_token_slots_.clear(); + if (draft_length_ > 1) { + for (int i = 0; i < num_rsentries; ++i) { + draft_token_slots_.push_back(mstates[i]->draft_token_slots.back()); + } + hidden_states = models_[model_id]->GatherHiddenStates( + model_workspaces_[0].draft_hidden_states_storage, draft_token_slots_, &hidden_states); } - hidden_states_nd = - models_[model_id]->ConcatLastHidden(previous_hidden_on_device, &hidden_states); - ICHECK_EQ(hidden_states_nd->ndim, 2); - ICHECK_EQ(hidden_states_nd->shape[0], num_rsentries); - hidden_states_nd = hidden_states_nd.CreateView( - {hidden_states_nd->shape[0], 1, hidden_states_nd->shape[1]}, hidden_states_nd->dtype); - last_hidden_states = hidden_states_nd; // The first draft token has been generated in prefill/verify stage for (int draft_id = 1; draft_id < draft_length_; ++draft_id) { // prepare new input tokens @@ -113,19 +110,18 @@ class EagleBatchDraftActionObj : public EngineActionObj { // - Invoke model decode. RECORD_EVENT(trace_recorder_, request_ids, "start proposal decode"); - ObjectRef fused_hidden_states = models_[model_id]->FuseEmbedHidden( - embeddings, last_hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1); - hidden_states_nd = - models_[model_id]->BatchDecodeToLastHidden(fused_hidden_states, request_internal_ids); - last_hidden_states = hidden_states_nd; + ObjectRef fused_embedding_hidden_states = models_[model_id]->FuseEmbedHidden( + embeddings, hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1); + hidden_states = models_[model_id]->BatchDecodeToLastHidden(fused_embedding_hidden_states, + request_internal_ids); NDArray logits; if (models_[model_id]->CanGetLogits()) { - logits = models_[model_id]->GetLogits(hidden_states_nd, /*batch_size*/ num_rsentries, + logits = models_[model_id]->GetLogits(hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1); } else { // - Use base model's head. logits = - models_[0]->GetLogits(hidden_states_nd, /*batch_size*/ num_rsentries, /*seq_len*/ 1); + models_[0]->GetLogits(hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1); } RECORD_EVENT(trace_recorder_, request_ids, "finish proposal decode"); ICHECK_EQ(logits->ndim, 3); @@ -144,20 +140,19 @@ class EagleBatchDraftActionObj : public EngineActionObj { // Fill range [0, num_rsentries) into `sample_indices`. std::vector sample_indices(num_rsentries); std::iota(sample_indices.begin(), sample_indices.end(), 0); - std::vector prob_dist; NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP( probs_on_device, sample_indices, request_ids, generation_cfg); std::vector sample_results = sampler_->BatchSampleTokensWithProbAfterTopP( - renormalized_probs, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); + renormalized_probs, sample_indices, request_ids, generation_cfg, rngs); ICHECK_EQ(sample_results.size(), num_rsentries); // - Add draft token to the state. 
+ draft_token_workspace_manager_->AllocSlots(num_rsentries, &draft_token_slots_); + models_[model_id]->ScatterDraftProbs(probs_on_device, draft_token_slots_, + &model_workspaces_[0].draft_probs_storage); + // No need to save hidden states as they are not used by subsequent engine actions for (int i = 0; i < num_rsentries; ++i) { - // - Slice hidden_states_for_sample - NDArray last_hidden_on_device = GetTokenHidden(hidden_states_nd, i); - CHECK(i < static_cast(prob_dist.size())); - CHECK(prob_dist[i].defined()); - mstates[i]->AddDraftToken(sample_results[i], prob_dist[i], last_hidden_on_device); + mstates[i]->AddDraftToken(sample_results[i], draft_token_slots_[i]); estate->stats.total_draft_length += 1; } } @@ -183,26 +178,6 @@ class EagleBatchDraftActionObj : public EngineActionObj { return true; } - /*! - * \brief Get one item from a hidden_states array, which corresponds to the last token. - * \param hidden_states The hidden_states of all the tokens. - * \param token_pos The desired token position in the sequence. - * \return The desired token's hidden_states - */ - NDArray GetTokenHidden(NDArray hidden_states, int token_pos) { - ICHECK_EQ(hidden_states->ndim, 3); - NDArray last_hidden_on_device = - NDArray::Empty({hidden_states->shape[2]}, hidden_states->dtype, hidden_states->device); - - int64_t ndata = hidden_states->shape[2]; - const int16_t* __restrict p_hidden = - static_cast(__builtin_assume_aligned(hidden_states->data, 2)) + - (token_pos * ndata); - - last_hidden_on_device.CopyFromBytes(p_hidden, ndata * sizeof(int16_t)); - return last_hidden_on_device; - } - /*! \brief The model to run draft generation in speculative decoding. */ Array models_; /*! \brief The logit processor. */ @@ -211,20 +186,26 @@ class EagleBatchDraftActionObj : public EngineActionObj { Sampler sampler_; /*! \brief Workspace of each model. */ std::vector model_workspaces_; + /*! \brief The draft token workspace manager. */ + DraftTokenWorkspaceManager draft_token_workspace_manager_; /*! \brief Event trace recorder. */ Optional trace_recorder_; /*! \brief Draft proposal length */ int draft_length_; + /*! 
\brief Temporary buffer to store the slots of the current draft tokens */ + std::vector draft_token_slots_; }; EngineAction EngineAction::EagleBatchDraft(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, Optional trace_recorder, int draft_length) { return EngineAction(make_object( std::move(models), std::move(logit_processor), std::move(sampler), - std::move(model_workspaces), std::move(trace_recorder), draft_length)); + std::move(model_workspaces), std::move(draft_token_workspace_manager), + std::move(trace_recorder), draft_length)); } } // namespace serve diff --git a/cpp/serve/engine_actions/eagle_batch_verify.cc b/cpp/serve/engine_actions/eagle_batch_verify.cc index b259417050..6b23035f78 100644 --- a/cpp/serve/engine_actions/eagle_batch_verify.cc +++ b/cpp/serve/engine_actions/eagle_batch_verify.cc @@ -29,12 +29,14 @@ class EagleBatchVerifyActionObj : public EngineActionObj { public: explicit EagleBatchVerifyActionObj(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, EngineConfig engine_config, Optional trace_recorder) : models_(std::move(models)), logit_processor_(std::move(logit_processor)), sampler_(std::move(sampler)), model_workspaces_(std::move(model_workspaces)), + draft_token_workspace_manager_(std::move(draft_token_workspace_manager)), engine_config_(std::move(engine_config)), trace_recorder_(std::move(trace_recorder)), rng_(RandomGenerator::GetInstance()) {} @@ -63,14 +65,13 @@ class EagleBatchVerifyActionObj : public EngineActionObj { Array generation_cfg; std::vector rngs; std::vector> draft_output_tokens; - std::vector> draft_output_prob_dist; request_internal_ids.reserve(num_rsentries); all_tokens_to_verify.reserve(total_draft_length); verify_request_mstates.reserve(num_rsentries); rngs.reserve(num_rsentries); generation_cfg.reserve(num_rsentries); draft_output_tokens.reserve(num_rsentries); - draft_output_prob_dist.reserve(num_rsentries); + draft_token_slots_.clear(); for (int i = 0; i < num_rsentries; ++i) { RequestModelState verify_mstate = rsentries[i]->mstates[verify_model_id_]; @@ -78,19 +79,24 @@ class EagleBatchVerifyActionObj : public EngineActionObj { request_internal_ids.push_back(verify_mstate->internal_id); ICHECK(!draft_lengths.empty()); ICHECK_EQ(draft_lengths[i], draft_mstate->draft_output_tokens.size()); - ICHECK_EQ(draft_lengths[i], draft_mstate->draft_output_prob_dist.size()); + ICHECK_EQ(draft_lengths[i], draft_mstate->draft_token_slots.size()); // the last committed token + all the draft tokens but the last one. 
all_tokens_to_verify.push_back(draft_mstate->committed_tokens.back().sampled_token_id.first); + draft_token_slots_.push_back(0); // placeholder for the last committed token for (int j = 0; j < static_cast(draft_mstate->draft_output_tokens.size()); ++j) { all_tokens_to_verify.push_back(draft_mstate->draft_output_tokens[j].sampled_token_id.first); + draft_token_slots_.push_back(draft_mstate->draft_token_slots[j]); } verify_request_mstates.push_back(verify_mstate); generation_cfg.push_back(rsentries[i]->request->generation_cfg); rngs.push_back(&rsentries[i]->rng); draft_output_tokens.push_back(draft_mstate->draft_output_tokens); - draft_output_prob_dist.push_back(draft_mstate->draft_output_prob_dist); } + NDArray draft_probs_on_device = models_[draft_model_id_]->GatherDraftProbs( + model_workspaces_[verify_model_id_].draft_probs_storage, draft_token_slots_, + &model_workspaces_[verify_model_id_].draft_probs); + std::vector cum_verify_lengths = {0}; cum_verify_lengths.reserve(num_rsentries + 1); std::vector verify_lengths; @@ -106,12 +112,8 @@ class EagleBatchVerifyActionObj : public EngineActionObj { RECORD_EVENT(trace_recorder_, request_ids, "finish verify embedding"); RECORD_EVENT(trace_recorder_, request_ids, "start verify"); - ObjectRef fused_hidden_states = models_[verify_model_id_]->FuseEmbedHidden( - embeddings, NDArray(), 1, cum_verify_lengths[num_rsentries]); - NDArray hidden_states = models_[verify_model_id_]->BatchVerifyToLastHidden( - fused_hidden_states, request_internal_ids, verify_lengths); - ICHECK_EQ(hidden_states->ndim, 3); - ICHECK_EQ(hidden_states->shape[0], 1); + ObjectRef hidden_states = models_[verify_model_id_]->BatchVerifyToLastHidden( + embeddings, request_internal_ids, verify_lengths); NDArray logits = models_[verify_model_id_]->GetLogits(hidden_states, 1, cum_verify_lengths[num_rsentries]); RECORD_EVENT(trace_recorder_, request_ids, "finish verify"); @@ -135,10 +137,11 @@ class EagleBatchVerifyActionObj : public EngineActionObj { std::vector> sample_results_arr = sampler_->BatchVerifyDraftTokensWithProbAfterTopP( renormalized_probs, request_ids, cum_verify_lengths, generation_cfg, rngs, - draft_output_tokens, draft_output_prob_dist); + draft_output_tokens, draft_probs_on_device); ICHECK_EQ(sample_results_arr.size(), num_rsentries); - std::vector last_hidden_states; + std::vector last_accepted_hidden_positions; + last_accepted_hidden_positions.reserve(num_rsentries); for (int i = 0; i < num_rsentries; ++i) { const std::vector& sample_results = sample_results_arr[i]; int accept_length = sample_results.size(); @@ -147,6 +150,8 @@ class EagleBatchVerifyActionObj : public EngineActionObj { rsentries[i]->mstates[verify_model_id_]->CommitToken(sample_result); rsentries[i]->mstates[draft_model_id_]->CommitToken(sample_result); } + estate->stats.UpdateSpecDecodingStats(cum_verify_lengths[i + 1] - cum_verify_lengths[i], + accept_length); estate->stats.total_accepted_length += accept_length - 1; // - Minus one because the last draft token has no kv cache entry // - Take max with 0 in case of all accepted. 
@@ -163,24 +168,19 @@ class EagleBatchVerifyActionObj : public EngineActionObj { rsentries[i]->mstates[draft_model_id_]->internal_id, rollback_length - 1); } // clear the draft model state entries - rsentries[i]->mstates[draft_model_id_]->RemoveAllDraftTokens(); - // - Slice hidden_states_for_sample - NDArray last_hidden_on_device = - GetTokenHidden(hidden_states, (cum_verify_lengths[i] + accept_length - 1)); - last_hidden_states.push_back(last_hidden_on_device); + rsentries[i]->mstates[draft_model_id_]->RemoveAllDraftTokens(&draft_token_slots_); + draft_token_workspace_manager_->FreeSlots(draft_token_slots_); + // - Slice and save hidden_states_for_sample + last_accepted_hidden_positions.push_back(cum_verify_lengths[i] + accept_length - 1); } { // One step draft for the following steps - NDArray hidden_states_nd{nullptr}; - ObjectRef next_hidden_states = model_workspaces_[draft_model_id_].hidden_states; - // Concat last hidden_states - hidden_states_nd = - models_[draft_model_id_]->ConcatLastHidden(last_hidden_states, &next_hidden_states); - ICHECK_EQ(hidden_states_nd->ndim, 2); - ICHECK_EQ(hidden_states_nd->shape[0], num_rsentries); - hidden_states_nd = hidden_states_nd.CreateView( - {hidden_states_nd->shape[0], 1, hidden_states_nd->shape[1]}, hidden_states_nd->dtype); + + // Gather hidden states for the last accepted tokens. + hidden_states = models_[draft_model_id_]->GatherHiddenStates( + hidden_states, last_accepted_hidden_positions, + &model_workspaces_[draft_model_id_].hidden_states); std::vector input_tokens; Array mstates; @@ -202,18 +202,17 @@ class EagleBatchVerifyActionObj : public EngineActionObj { // - Invoke model decode. RECORD_EVENT(trace_recorder_, request_ids, "start proposal decode"); - ObjectRef fused_hidden_states = models_[draft_model_id_]->FuseEmbedHidden( - embeddings, hidden_states_nd, /*batch_size*/ num_rsentries, /*seq_len*/ 1); - hidden_states_nd = models_[draft_model_id_]->BatchDecodeToLastHidden(fused_hidden_states, - request_internal_ids); + ObjectRef fused_embedding_hidden_states = models_[draft_model_id_]->FuseEmbedHidden( + embeddings, hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1); + hidden_states = models_[draft_model_id_]->BatchDecodeToLastHidden( + fused_embedding_hidden_states, request_internal_ids); if (models_[draft_model_id_]->CanGetLogits()) { - logits = models_[draft_model_id_]->GetLogits(hidden_states_nd, /*batch_size*/ num_rsentries, + logits = models_[draft_model_id_]->GetLogits(hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1); } else { // - Use base model's head. - logits = - models_[0]->GetLogits(hidden_states_nd, /*batch_size*/ num_rsentries, /*seq_len*/ 1); + logits = models_[0]->GetLogits(hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1); } RECORD_EVENT(trace_recorder_, request_ids, "finish proposal decode"); ICHECK_EQ(logits->ndim, 3); @@ -232,20 +231,23 @@ class EagleBatchVerifyActionObj : public EngineActionObj { // Fill range [0, num_rsentries) into `sample_indices`. 
std::vector sample_indices(num_rsentries); std::iota(sample_indices.begin(), sample_indices.end(), 0); - std::vector prob_dist; NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP( probs_on_device, sample_indices, request_ids, generation_cfg); std::vector sample_results = sampler_->BatchSampleTokensWithProbAfterTopP( - renormalized_probs, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); + renormalized_probs, sample_indices, request_ids, generation_cfg, rngs); ICHECK_EQ(sample_results.size(), num_rsentries); + // - Slice and save hidden_states_for_sample + draft_token_workspace_manager_->AllocSlots(num_rsentries, &draft_token_slots_); + models_[draft_model_id_]->ScatterDraftProbs( + renormalized_probs, draft_token_slots_, + &model_workspaces_[verify_model_id_].draft_probs_storage); + models_[draft_model_id_]->ScatterHiddenStates( + hidden_states, draft_token_slots_, + &model_workspaces_[verify_model_id_].draft_hidden_states_storage); // - Add draft token to the state. for (int i = 0; i < num_rsentries; ++i) { - // - Slice hidden_states_for_sample - NDArray last_hidden_on_device = GetTokenHidden(hidden_states_nd, i); - CHECK(i < static_cast(prob_dist.size())); - CHECK(prob_dist[i].defined()); - mstates[i]->AddDraftToken(sample_results[i], prob_dist[i], last_hidden_on_device); + mstates[i]->AddDraftToken(sample_results[i], draft_token_slots_[i]); estate->stats.total_draft_length += 1; } } @@ -292,8 +294,8 @@ class EagleBatchVerifyActionObj : public EngineActionObj { total_required_pages += num_require_pages; } while (!CanVerify(total_required_pages)) { - RequestStateEntry preempted = - PreemptLastRunningRequestStateEntry(estate, models_, trace_recorder_); + RequestStateEntry preempted = PreemptLastRunningRequestStateEntry( + estate, models_, draft_token_workspace_manager_, trace_recorder_); if (preempted.same_as(running_rsentries.back())) { total_draft_length -= draft_lengths.back(); total_required_pages -= num_page_requirement.back(); @@ -311,26 +313,6 @@ class EagleBatchVerifyActionObj : public EngineActionObj { return num_required_pages <= num_available_pages; } - /*! - * \brief Get one item from a hidden_states array, which corresponds to the last token. - * \param hidden_states The hidden_states of all the tokens. - * \param token_pos The desired token position in the sequence. - * \return The desired token's hidden_states - */ - NDArray GetTokenHidden(NDArray hidden_states, int token_pos) { - ICHECK_EQ(hidden_states->ndim, 3); - NDArray last_hidden_on_device = - NDArray::Empty({hidden_states->shape[2]}, hidden_states->dtype, hidden_states->device); - - int64_t ndata = hidden_states->shape[2]; - const int16_t* __restrict p_hidden = - static_cast(__builtin_assume_aligned(hidden_states->data, 2)) + - (token_pos * ndata); - - last_hidden_on_device.CopyFromBytes(p_hidden, ndata * sizeof(int16_t)); - return last_hidden_on_device; - } - /*! * \brief The model to run decode in. When there are multiple * models, the `Step` function of the created action will not take effect. @@ -342,6 +324,8 @@ class EagleBatchVerifyActionObj : public EngineActionObj { Sampler sampler_; /*! \brief Workspace of each model. */ std::vector model_workspaces_; + /*! \brief The draft token workspace manager. */ + DraftTokenWorkspaceManager draft_token_workspace_manager_; /*! \brief The engine config. */ EngineConfig engine_config_; /*! \brief Event trace recorder. 
*/ @@ -352,16 +336,19 @@ class EagleBatchVerifyActionObj : public EngineActionObj { const int verify_model_id_ = 0; const int draft_model_id_ = 1; const float eps_ = 1e-5; + /*! \brief Temporary buffer to store the slots of the current draft tokens */ + std::vector draft_token_slots_; }; -EngineAction EngineAction::EagleBatchVerify(Array models, LogitProcessor logit_processor, - Sampler sampler, - std::vector model_workspaces, - EngineConfig engine_config, - Optional trace_recorder) { +EngineAction EngineAction::EagleBatchVerify( + Array models, LogitProcessor logit_processor, Sampler sampler, + std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, EngineConfig engine_config, + Optional trace_recorder) { return EngineAction(make_object( std::move(models), std::move(logit_processor), std::move(sampler), - std::move(model_workspaces), std::move(engine_config), std::move(trace_recorder))); + std::move(model_workspaces), std::move(draft_token_workspace_manager), + std::move(engine_config), std::move(trace_recorder))); } } // namespace serve diff --git a/cpp/serve/engine_actions/eagle_new_request_prefill.cc b/cpp/serve/engine_actions/eagle_new_request_prefill.cc index a687e7eb7f..80de254ca8 100644 --- a/cpp/serve/engine_actions/eagle_new_request_prefill.cc +++ b/cpp/serve/engine_actions/eagle_new_request_prefill.cc @@ -24,12 +24,14 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { explicit EagleNewRequestPrefillActionObj(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, EngineConfig engine_config, Optional trace_recorder) : models_(std::move(models)), logit_processor_(std::move(logit_processor)), sampler_(std::move(sampler)), model_workspaces_(std::move(model_workspaces)), + draft_token_workspace_manager_(std::move(draft_token_workspace_manager)), engine_config_(std::move(engine_config)), trace_recorder_(std::move(trace_recorder)) {} @@ -81,8 +83,8 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { // - Get embedding and run prefill for each model. std::vector prefill_lengths; prefill_lengths.resize(/*size=*/num_rsentries, /*value=*/-1); - NDArray hidden_states_for_input{nullptr}; - NDArray hidden_states_for_sample{nullptr}; + ObjectRef hidden_states_for_input{nullptr}; + ObjectRef hidden_states_for_sample{nullptr}; NDArray logits_for_sample{nullptr}; // A map used to record the entry and child_idx pair needed to fork sequence. // The base model (id 0) should record all the pairs and all the small models @@ -107,7 +109,7 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { } ICHECK(mstate->draft_output_tokens.empty()); - ICHECK(mstate->draft_output_prob_dist.empty()); + ICHECK(mstate->draft_token_slots.empty()); if (status_before_prefill[i] == RequestStateStatus::kPending) { // Add the sequence to the model, or fork the sequence from its parent. 
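// ---------------------------------------------------------------------------
// Illustrative call shape only (the local variable names are assumptions): the
// engine now constructs these actions with the shared draft-token workspace
// manager threaded through alongside the per-model workspaces, e.g.
//
//   EngineAction prefill = EngineAction::EagleNewRequestPrefill(
//       models, logit_processor, sampler, model_workspaces,
//       draft_token_workspace_manager, engine_config, trace_recorder);
// ---------------------------------------------------------------------------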
if (rsentry->parent_idx == -1) { @@ -165,14 +167,17 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { } RECORD_EVENT(trace_recorder_, request_ids, "start prefill"); - ObjectRef fused_hidden_states = models_[model_id]->FuseEmbedHidden( - embeddings, hidden_states_for_input, /*batch_size*/ 1, /*seq_len*/ cum_prefill_length); - NDArray hidden_states = models_[model_id]->BatchPrefillToLastHidden( - fused_hidden_states, request_internal_ids, prefill_lengths); + ObjectRef embedding_or_hidden_states{nullptr}; + if (model_id == 0) { + embedding_or_hidden_states = embeddings; + } else { + embedding_or_hidden_states = models_[model_id]->FuseEmbedHidden( + embeddings, hidden_states_for_input, /*batch_size*/ 1, /*seq_len*/ cum_prefill_length); + } + // hidden_states: (b * s, h) + ObjectRef hidden_states = models_[model_id]->BatchPrefillToLastHidden( + embedding_or_hidden_states, request_internal_ids, prefill_lengths); RECORD_EVENT(trace_recorder_, request_ids, "finish prefill"); - ICHECK_EQ(hidden_states->ndim, 3); - ICHECK_EQ(hidden_states->shape[0], 1); - ICHECK_EQ(hidden_states->shape[1], cum_prefill_length); if (model_id == 0) { // We only need to sample for model 0 in prefill. @@ -181,14 +186,23 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { // Whether to use base model to get logits. int sample_model_id = !models_[model_id]->CanGetLogits() ? 0 : model_id; - hidden_states_for_sample = models_[sample_model_id]->BatchSelectLastHidden( - hidden_states, request_internal_ids, prefill_lengths); + + std::vector logit_positions; + { + // Prepare the logit positions + logit_positions.reserve(prefill_lengths.size()); + int total_len = 0; + for (int i = 0; i < prefill_lengths.size(); ++i) { + total_len += prefill_lengths[i]; + logit_positions.push_back(total_len - 1); + } + } + // hidden_states_for_sample: (b * s, h) + hidden_states_for_sample = models_[sample_model_id]->GatherHiddenStates( + hidden_states, logit_positions, &model_workspaces_[model_id].hidden_states); + // logits_for_sample: (b * s, v) logits_for_sample = models_[sample_model_id]->GetLogits(hidden_states_for_sample, 1, num_rsentries); - ICHECK_EQ(hidden_states_for_sample->ndim, 3); - ICHECK_EQ(hidden_states_for_sample->shape[0], 1); - ICHECK_EQ(hidden_states_for_sample->shape[1], num_rsentries); - // - Update logits. ICHECK(logits_for_sample.defined()); Array generation_cfg; @@ -276,18 +290,18 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { rsentry_activated.push_back(true); } } - std::vector prob_dist; + NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP( probs_on_device, sample_indices, request_ids, generation_cfg); std::vector sample_results = sampler_->BatchSampleTokensWithProbAfterTopP( - renormalized_probs, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); + renormalized_probs, sample_indices, request_ids, generation_cfg, rngs); ICHECK_EQ(sample_results.size(), rsentries_for_sample.size()); // - Update the committed tokens of states. // - If a request is first-time prefilled, set the prefill finish time. 
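// ---------------------------------------------------------------------------
// Self-contained sketch (illustration only) of the logit-position computation
// used above: with hidden states flattened to (total_tokens, hidden_size), the
// last token of request i sits at the running sum of prefill lengths minus one,
// and those positions are what GatherHiddenStates picks out.
#include <vector>

std::vector<int> LastTokenPositions(const std::vector<int>& prefill_lengths) {
  std::vector<int> positions;
  positions.reserve(prefill_lengths.size());
  int total_len = 0;
  for (int len : prefill_lengths) {
    total_len += len;
    positions.push_back(total_len - 1);  // index of this request's last token
  }
  return positions;
}
// Example: prefill lengths {3, 5, 2} give positions {2, 7, 9}.
// ---------------------------------------------------------------------------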
auto tnow = std::chrono::high_resolution_clock::now(); - for (int i = 0; i < static_cast(rsentries_for_sample.size()); ++i) { - if (model_id == 0) { + if (model_id == 0) { + for (int i = 0; i < static_cast(rsentries_for_sample.size()); ++i) { for (int mid = 0; mid < static_cast(models_.size()); ++mid) { rsentries_for_sample[i]->mstates[mid]->CommitToken(sample_results[i]); if (!rsentry_activated[i]) { @@ -301,13 +315,20 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { if (rsentries_for_sample[i]->mstates[0]->committed_tokens.size() == 1) { rsentries_for_sample[i]->tprefill_finish = tnow; } - } else { - // - Slice hidden_states_for_sample - NDArray last_hidden_on_device = GetTokenHidden(hidden_states_for_sample, i); - CHECK(i < static_cast(prob_dist.size())); - CHECK(prob_dist[i].defined()); - rsentries_for_sample[i]->mstates[model_id]->AddDraftToken(sample_results[i], prob_dist[i], - last_hidden_on_device); + } + } else { + // - Slice and save hidden_states_for_sample + draft_token_workspace_manager_->AllocSlots(rsentries_for_sample.size(), + &draft_token_slots_); + models_[model_id]->ScatterDraftProbs(renormalized_probs, draft_token_slots_, + &model_workspaces_[0].draft_probs_storage); + if (engine_config_->spec_draft_length > 1) { + models_[model_id]->ScatterHiddenStates(hidden_states_for_sample, draft_token_slots_, + &model_workspaces_[0].draft_hidden_states_storage); + } + for (int i = 0; i < static_cast(rsentries_for_sample.size()); ++i) { + rsentries_for_sample[i]->mstates[model_id]->AddDraftToken(sample_results[i], + draft_token_slots_[i]); estate->stats.total_draft_length += 1; } } @@ -554,26 +575,6 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { ICHECK(false) << "Cannot reach here"; } - /*! - * \brief Get one item from a hidden_states array, which corresponds to the last token. - * \param hidden_states The hidden_states of all the tokens. - * \param token_pos The desired token position in the sequence. - * \return The desired token's hidden_states - */ - NDArray GetTokenHidden(NDArray hidden_states, int token_pos) { - ICHECK_EQ(hidden_states->ndim, 3); - NDArray last_hidden_on_device = - NDArray::Empty({hidden_states->shape[2]}, hidden_states->dtype, hidden_states->device); - - int64_t ndata = hidden_states->shape[2]; - const int16_t* __restrict p_hidden = - static_cast(__builtin_assume_aligned(hidden_states->data, 2)) + - (token_pos * ndata); - - last_hidden_on_device.CopyFromBytes(p_hidden, ndata * sizeof(int16_t)); - return last_hidden_on_device; - } - /*! \brief The models to run prefill in. */ Array models_; /*! \brief The logit processor. */ @@ -582,20 +583,25 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { Sampler sampler_; /*! \brief Workspace of each model. */ std::vector model_workspaces_; + /*! \brief The draft token workspace manager. */ + DraftTokenWorkspaceManager draft_token_workspace_manager_; /*! \brief The engine config. */ EngineConfig engine_config_; /*! \brief Event trace recorder. */ Optional trace_recorder_; + /*! 
\brief Temporary buffer to store the slots of the current draft tokens */ + std::vector draft_token_slots_; }; -EngineAction EngineAction::EagleNewRequestPrefill(Array models, - LogitProcessor logit_processor, Sampler sampler, - std::vector model_workspaces, - EngineConfig engine_config, - Optional trace_recorder) { +EngineAction EngineAction::EagleNewRequestPrefill( + Array models, LogitProcessor logit_processor, Sampler sampler, + std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, EngineConfig engine_config, + Optional trace_recorder) { return EngineAction(make_object( std::move(models), std::move(logit_processor), std::move(sampler), - std::move(model_workspaces), std::move(engine_config), std::move(trace_recorder))); + std::move(model_workspaces), std::move(draft_token_workspace_manager), + std::move(engine_config), std::move(trace_recorder))); } } // namespace serve diff --git a/cpp/serve/engine_actions/new_request_prefill.cc b/cpp/serve/engine_actions/new_request_prefill.cc index b4192a04f1..f801b1e282 100644 --- a/cpp/serve/engine_actions/new_request_prefill.cc +++ b/cpp/serve/engine_actions/new_request_prefill.cc @@ -100,7 +100,7 @@ class NewRequestPrefillActionObj : public EngineActionObj { } ICHECK(mstate->draft_output_tokens.empty()); - ICHECK(mstate->draft_output_prob_dist.empty()); + ICHECK(mstate->draft_token_slots.empty()); if (status_before_prefill[i] == RequestStateStatus::kPending) { // Add the sequence to the model, or fork the sequence from its parent. if (rsentry->parent_idx == -1) { diff --git a/cpp/serve/engine_state.cc b/cpp/serve/engine_state.cc index 563f0e7b13..7847f53fd5 100644 --- a/cpp/serve/engine_state.cc +++ b/cpp/serve/engine_state.cc @@ -12,16 +12,25 @@ namespace serve { String EngineStats::AsJSON() const { picojson::object config; - config["single_token_prefill_latency"] = - picojson::value(request_total_prefill_time / total_prefill_length); - config["single_token_decode_latency"] = - picojson::value(request_total_decode_time / total_decode_length); + config["single_token_prefill_latency"] = picojson::value( + total_prefill_length > 0 ? request_total_prefill_time / total_prefill_length : 0.0); + config["single_token_decode_latency"] = picojson::value( + total_decode_length > 0 ? 
request_total_decode_time / total_decode_length : 0.0); config["engine_total_prefill_time"] = picojson::value(engine_total_prefill_time); config["engine_total_decode_time"] = picojson::value(engine_total_decode_time); config["total_prefill_tokens"] = picojson::value(total_prefill_length); config["total_decode_tokens"] = picojson::value(total_decode_length); config["total_accepted_tokens"] = picojson::value(total_accepted_length); config["total_draft_tokens"] = picojson::value(total_draft_length); + auto f_vector_to_array = [](const std::vector& vec) { + picojson::array arr; + for (int64_t v : vec) { + arr.push_back(picojson::value(v)); + } + return picojson::value(arr); + }; + config["accept_count"] = f_vector_to_array(accept_count); + config["draft_count"] = f_vector_to_array(draft_count); return picojson::value(config).serialize(true); } @@ -54,6 +63,19 @@ RequestState EngineStateObj::GetRequestState(Request request) { return it->second; } +void EngineStats::UpdateSpecDecodingStats(int draft_length, int accept_length) { + if (accept_count.size() < draft_length) { + this->accept_count.resize(draft_length, 0); + this->draft_count.resize(draft_length, 0); + } + for (int j = 0; j < draft_length; ++j) { + if (j < accept_length) { + this->accept_count[j]++; + } + this->draft_count[j]++; + } +} + } // namespace serve } // namespace llm } // namespace mlc diff --git a/cpp/serve/engine_state.h b/cpp/serve/engine_state.h index ff955a264f..8218cbd73d 100644 --- a/cpp/serve/engine_state.h +++ b/cpp/serve/engine_state.h @@ -34,6 +34,10 @@ struct EngineStats { int64_t total_accepted_length = 0; /*! \brief The total number of speculated draft tokens. */ int64_t total_draft_length = 0; + /*! \brief The number of accepted tokens in speculative decoding. */ + std::vector accept_count; + /*! \brief The number of draft tokens in speculative decoding. */ + std::vector draft_count; /*! * \brief Return the engine runtime statistics in JSON string. @@ -49,6 +53,14 @@ struct EngineStats { String AsJSON() const; /*! \brief Reset all the statistics. */ void Reset(); + + /*! + * \brief Update the statistics of speculative decoding. + * \param draft_length The number of draft tokens (including the last prediction by the base + * model) + * \param accept_length The number of accepted tokens in the speculative decoding. + */ + void UpdateSpecDecodingStats(int draft_length, int accept_length); }; /*! \brief The manager of internal id for requests in engine. 
*/ diff --git a/cpp/serve/event_trace_recorder.cc b/cpp/serve/event_trace_recorder.cc index 8a930002fe..e0311716fd 100644 --- a/cpp/serve/event_trace_recorder.cc +++ b/cpp/serve/event_trace_recorder.cc @@ -51,7 +51,7 @@ class EventTraceRecorderImpl : public EventTraceRecorderObj { void AddEvent(const Array& request_ids, const std::string& event) final { double event_time = std::chrono::duration_cast>( std::chrono::system_clock::now().time_since_epoch()) - .count(); + .count(); // in seconds { std::lock_guard lock(mutex_); @@ -96,16 +96,16 @@ class EventTraceRecorderImpl : public EventTraceRecorderObj { name = event; phase = "i"; } - int64_t event_time_in_ms = static_cast(event_time * 1e6); + int64_t event_time_in_us = static_cast(event_time * 1e6); picojson::object event_json; event_json["name"] = picojson::value(name); event_json["ph"] = picojson::value(phase); - event_json["ts"] = picojson::value(event_time_in_ms); + event_json["ts"] = picojson::value(event_time_in_us); event_json["pid"] = picojson::value(static_cast(1)); event_json["tid"] = picojson::value(request_id); - events_to_sort.push_back({event_time_in_ms, picojson::value(event_json)}); + events_to_sort.push_back({event_time_in_us, picojson::value(event_json)}); } std::sort(events_to_sort.begin(), events_to_sort.end(), fcmp_events); for (auto [timestamp, event] : events_to_sort) { diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc index b721eae7c3..16db4a8a03 100644 --- a/cpp/serve/function_table.cc +++ b/cpp/serve/function_table.cc @@ -69,7 +69,8 @@ PackedFunc FunctionTable::SessionFuncAsPackedFunc(Session sess, DRef sess_func, }); } -void FunctionTable::Init(String reload_lib_path, Device device, picojson::object model_config) { +void FunctionTable::Init(String reload_lib_path, Device device, picojson::object model_config, + Optional session) { local_gpu_device = device; Device null_device{DLDeviceType(0), 0}; int num_shards; @@ -85,33 +86,14 @@ void FunctionTable::Init(String reload_lib_path, Device device, picojson::object this->cached_buffers = Map(); if (num_shards > 1) { - constexpr const char* f_create_process_pool = "runtime.disco.create_process_pool"; - if (Registry::Get(f_create_process_pool) == nullptr) { - LOG(FATAL) << "Cannot find process launcher `" << f_create_process_pool << "`. " - << "Multi-GPU inference depends on MLC LLM Python API to launch process."; - } - std::string ccl; - if (device.device_type == kDLCUDA) { - ccl = "nccl"; - } else if (device.device_type == kDLROCM) { - ccl = "rccl"; - } else { - LOG(FATAL) << "ValueError: Multi-GPU on device " << DLDeviceType2Str(device.device_type) - << " is not supported. 
Currently, only NCCL and RCCL are integrated."; - } - std::vector device_ids(num_shards); - for (int i = 0; i < num_shards; ++i) { - device_ids[i] = i; - } + this->sess = session.value(); this->use_disco = true; - this->sess = Session::ProcessSession(num_shards, f_create_process_pool, "mlc_llm.cli.worker"); - this->sess->InitCCL(ccl, ShapeTuple(device_ids)); this->disco_mod = sess->CallPacked(sess->GetGlobalFunc("runtime.disco.load_vm_module"), reload_lib_path, null_device); this->mod_get_func = [this, fmodule_get_function = sess->GetGlobalFunc("runtime.ModuleGetFunction")]( const std::string& name) -> PackedFunc { - DRef func = sess->CallPacked(fmodule_get_function, this->disco_mod, name, false); + DRef func = sess->CallPacked(fmodule_get_function, this->disco_mod, name, true); bool exists = (func->DebugGetFromRemote(0).operator PackedFunc()) != nullptr; if (!exists) { return PackedFunc(nullptr); @@ -236,7 +218,7 @@ void FunctionTable::_InitFunctions() { Module mod = this->use_disco ? this->disco_mod->DebugGetFromRemote(0) : this->local_vm; this->get_logits_func_ = mod_get_func("get_logits"); this->batch_get_logits_func_ = mod_get_func("batch_get_logits"); - this->batch_select_last_hidden_func_ = mod->GetFunction("batch_select_last_hidden_states", true); + this->batch_select_last_hidden_func_ = mod_get_func("batch_select_last_hidden_states"); this->softmax_func_ = mod->GetFunction("softmax_with_temperature", true); this->apply_logit_bias_func_ = mod->GetFunction("apply_logit_bias_inplace", true); this->apply_penalty_func_ = mod->GetFunction("apply_penalty_inplace", true); @@ -277,6 +259,12 @@ void FunctionTable::_InitFunctions() { this->nd_get_shape_func_ = get_global_func("vm.builtin.shape_of"); this->nd_copy_embedding_to_offset_func_ = get_global_func("mlc.copy_embedding_to_offset"); support_backtracking_kv_ = true; + this->tuple_getitem_func_ = get_global_func("vm.builtin.tuple_getitem"); + + this->gather_probs_func_ = mod->GetFunction("gather_probs", true); + this->scatter_probs_func_ = mod->GetFunction("scatter_probs", true); + this->gather_hidden_states_func_ = mod_get_func("gather_hidden_states"); + this->scatter_hidden_states_func_ = mod_get_func("scatter_hidden_states"); } ObjectRef FunctionTable::Empty(ShapeTuple shape, DataType dtype, Device device) const { @@ -290,8 +278,8 @@ ObjectRef FunctionTable::Empty(ShapeTuple shape, DataType dtype, Device device) } ObjectRef FunctionTable::CopyToWorker0(const NDArray& host_array, String buffer_cache_key, - ShapeTuple max_reserved_shape) { - if (this->use_disco) { + ShapeTuple max_reserved_shape, bool local_only) { + if (this->use_disco && !local_only) { Device null_device{DLDeviceType(0), 0}; DRef buffer(nullptr); auto it = this->cached_buffers.find(buffer_cache_key); diff --git a/cpp/serve/function_table.h b/cpp/serve/function_table.h index b6ea3287ad..2350f3d37a 100644 --- a/cpp/serve/function_table.h +++ b/cpp/serve/function_table.h @@ -41,7 +41,8 @@ using namespace tvm::runtime; struct FunctionTable { static PackedFunc SessionFuncAsPackedFunc(Session sess, DRef sess_func, String name); - void Init(String reload_lib_path, Device device, picojson::object model_config); + void Init(String reload_lib_path, Device device, picojson::object model_config, + Optional session); ObjectRef LoadParams(const std::string& model_path, Device device); @@ -49,8 +50,18 @@ struct FunctionTable { ObjectRef Empty(ShapeTuple shape, DataType dtype, Device device) const; + /*! + * \brief Copy a host array to the worker or local gpu. 
+ * \param host_array The host array to be copied. + * \param buffer_cache_key The key to the buffer cache. + * \param max_reserved_shape The maximum shape to be reserved in the buffer cache. + * \param local_only Whether to copy the array to the local gpu only. If true, the use_disco + * flag will be ignored. This can be useful for functions that run only on the + * local gpu when disco is enabled. + * \return The array on the worker or local gpu. + */ ObjectRef CopyToWorker0(const NDArray& host_array, String buffer_cache_key, - ShapeTuple max_reserved_shape); + ShapeTuple max_reserved_shape, bool local_only = false); void DebugCallFuncOnAllAllWorker(const String& func_name) const; @@ -109,6 +120,12 @@ struct FunctionTable { PackedFunc nd_view_func_; PackedFunc nd_get_shape_func_; PackedFunc nd_copy_embedding_to_offset_func_; + PackedFunc tuple_getitem_func_; + // Auxiliary functions for speculative decoding. + PackedFunc gather_probs_func_; + PackedFunc scatter_probs_func_; + PackedFunc gather_hidden_states_func_; + PackedFunc scatter_hidden_states_func_; }; } // namespace serve diff --git a/cpp/serve/grammar/grammar_parser.cc b/cpp/serve/grammar/grammar_parser.cc index 1ece99099e..55ab0a1dff 100644 --- a/cpp/serve/grammar/grammar_parser.cc +++ b/cpp/serve/grammar/grammar_parser.cc @@ -156,14 +156,14 @@ int32_t EBNFParserImpl::ParseCharacterClass() { continue; } - auto [codepoint, len] = Utf8OrEscapeToCodepoint(cur_, kCustomEscapeMap); + auto [codepoint, new_cur] = ParseNextUTF8OrEscaped(cur_, kCustomEscapeMap); if (codepoint == static_cast(CharHandlingError::kInvalidUtf8)) { - ThrowParseError("Invalid utf8 sequence"); + ThrowParseError("Invalid UTF8 sequence"); } if (codepoint == static_cast(CharHandlingError::kInvalidEscape)) { ThrowParseError("Invalid escape sequence"); } - Consume(len); + Consume(new_cur - cur_); if (past_is_hyphen) { ICHECK(!elements.empty()); if (elements.back().lower > codepoint) { @@ -194,14 +194,15 @@ int32_t EBNFParserImpl::ParseString() { if (Peek() == '\r' || Peek() == '\n') { ThrowParseError("There should be no newline character in a string literal"); } - auto [codepoint, len] = Utf8OrEscapeToCodepoint(cur_); + + auto [codepoint, new_cur] = ParseNextUTF8OrEscaped(cur_); if (codepoint == static_cast(CharHandlingError::kInvalidUtf8)) { ThrowParseError("Invalid utf8 sequence"); } if (codepoint == static_cast(CharHandlingError::kInvalidEscape)) { ThrowParseError("Invalid escape sequence"); } - Consume(len); + Consume(new_cur - cur_); character_classes.push_back(builder_.AddCharacterClass({{codepoint, codepoint}})); } if (character_classes.empty()) { diff --git a/cpp/serve/grammar/grammar_serializer.cc b/cpp/serve/grammar/grammar_serializer.cc index fd41517863..c3c2c88baa 100644 --- a/cpp/serve/grammar/grammar_serializer.cc +++ b/cpp/serve/grammar/grammar_serializer.cc @@ -59,12 +59,12 @@ std::string BNFGrammarPrinter::PrintCharacterClass(const RuleExpr& rule_expr) { result += "^"; } for (auto i = 0; i < rule_expr.data_len; i += 2) { - result += CodepointToPrintable(rule_expr[i], kCustomEscapeMap); + result += PrintAsEscaped(rule_expr[i], kCustomEscapeMap); if (rule_expr[i] == rule_expr[i + 1]) { continue; } result += "-"; - result += CodepointToPrintable(rule_expr[i + 1], kCustomEscapeMap); + result += PrintAsEscaped(rule_expr[i + 1], kCustomEscapeMap); } result += "]"; return result; diff --git a/cpp/serve/grammar/grammar_state_matcher.cc b/cpp/serve/grammar/grammar_state_matcher.cc index 5c4ef98efe..451127e746 100644 --- 
a/cpp/serve/grammar/grammar_state_matcher.cc +++ b/cpp/serve/grammar/grammar_state_matcher.cc @@ -510,7 +510,7 @@ TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherResetState") bool MatchCompleteString(GrammarStateMatcher matcher, String str) { auto mutable_node = const_cast(matcher.as()); - auto codepoints = Utf8StringToCodepoints(str.c_str()); + auto codepoints = ParseUTF8(str.c_str()); int accepted_cnt = 0; for (auto codepoint : codepoints) { if (!mutable_node->AcceptCodepoint(codepoint, false)) { @@ -553,9 +553,9 @@ void PrintAcceptedRejectedTokens( // First cast to unsigned, then cast to int std::cerr << static_cast(static_cast(token[0])); } else { - auto codepoints = Utf8StringToCodepoints(token.c_str()); + auto codepoints = ParseUTF8(token.c_str()); for (auto c : codepoints) { - std::cerr << CodepointToPrintable(c); + std::cerr << PrintAsEscaped(c); } } std::cerr << "> "; @@ -571,9 +571,9 @@ void PrintAcceptedRejectedTokens( if (token.size() == 1 && ((unsigned char)token[0] >= 128 || token[0] == 0)) { std::cerr << (int)(unsigned char)token[0]; } else { - auto codepoints = Utf8StringToCodepoints(token.c_str()); + auto codepoints = ParseUTF8(token.c_str()); for (auto c : codepoints) { - std::cerr << CodepointToPrintable(c); + std::cerr << PrintAsEscaped(c); } } std::cerr << "> "; diff --git a/cpp/serve/grammar/grammar_state_matcher_base.h b/cpp/serve/grammar/grammar_state_matcher_base.h index 55c986bb10..5b774d33a4 100644 --- a/cpp/serve/grammar/grammar_state_matcher_base.h +++ b/cpp/serve/grammar/grammar_state_matcher_base.h @@ -156,15 +156,15 @@ inline bool GrammarStateMatcherBase::AcceptCodepoint(TCodepoint codepoint, bool } if (tmp_new_stack_tops_.empty()) { if (verbose) { - std::cout << "Codepoint: " << codepoint << " \"" << CodepointToPrintable(codepoint) - << "\" Rejected" << std::endl; + std::cout << "Codepoint: " << codepoint << " \"" << PrintAsEscaped(codepoint) << "\" Rejected" + << std::endl; } return false; } stack_tops_history_.PushHistory(tmp_new_stack_tops_); if (verbose) { - std::cout << "Codepoint: " << codepoint << " \"" << CodepointToPrintable(codepoint) - << "\" Accepted" << std::endl; + std::cout << "Codepoint: " << codepoint << " \"" << PrintAsEscaped(codepoint) << "\" Accepted" + << std::endl; std::cout << "Stack after accepting: " << PrintStackState() << std::endl; } #if TVM_LOG_DEBUG diff --git a/cpp/serve/grammar/grammar_state_matcher_preproc.h b/cpp/serve/grammar/grammar_state_matcher_preproc.h index c853ac7e04..f63eee2c5c 100644 --- a/cpp/serve/grammar/grammar_state_matcher_preproc.h +++ b/cpp/serve/grammar/grammar_state_matcher_preproc.h @@ -268,7 +268,7 @@ inline std::shared_ptr GrammarStateMatcher::CreateInitC ptr->special_token_ids.push_back(i); } else { // First replace the special underscore with space. 
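// ---------------------------------------------------------------------------
// Illustration only (not the library's actual helpers): the grammar-parser
// changes above move the codepoint parsers to a cursor-return convention --
// instead of returning (codepoint, consumed_length), they return (codepoint,
// pointer past the consumed bytes), and the caller advances with
// Consume(new_cur - cur_). A minimal stand-in showing the same calling pattern:
#include <cstdint>
#include <utility>

std::pair<int32_t, const char*> ParseAsciiDigit(const char* cur) {
  if (*cur >= '0' && *cur <= '9') {
    return {*cur - '0', cur + 1};  // decoded value, position after the byte
  }
  return {-1, cur};  // nothing consumed on error
}

inline int ConsumedBytes(const char* cur, const char* new_cur) {
  return static_cast<int>(new_cur - cur);  // what Consume(new_cur - cur_) uses
}
// ---------------------------------------------------------------------------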
- auto codepoints = Utf8StringToCodepoints(token.c_str()); + auto codepoints = ParseUTF8(token.c_str()); DCHECK(!codepoints.empty() && codepoints[0] != static_cast(CharHandlingError::kInvalidUtf8)) << "Invalid token: " << token; diff --git a/cpp/serve/logit_processor.cc b/cpp/serve/logit_processor.cc index f7190d50ac..7ce70a0d26 100644 --- a/cpp/serve/logit_processor.cc +++ b/cpp/serve/logit_processor.cc @@ -289,7 +289,7 @@ class LogitProcessorImpl : public LogitProcessorObj { p_penalties[num_token_for_penalty * 3 + 2] = generation_cfg[i]->repetition_penalty; ++num_token_for_penalty; if (j > 0) { - mstates[i]->AddDraftToken(draft_tokens->at(i)[j - 1], NDArray(), NDArray()); + mstates[i]->AddDraftToken(draft_tokens->at(i)[j - 1], /*draft_token_slot=*/-1); } } if (num_token_to_process != 1) { @@ -368,7 +368,7 @@ class LogitProcessorImpl : public LogitProcessorObj { p_seq_ids[token_start_offset + j] = 1; } if (j > 0) { - mstates[i]->AddDraftToken(draft_tokens->at(i)[j - 1], NDArray(), NDArray()); + mstates[i]->AddDraftToken(draft_tokens->at(i)[j - 1], /*draft_token_slot=*/-1); } } if (token_number != 1) { diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index 27a0043850..be76b40e2e 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -5,7 +5,6 @@ */ #include "model.h" -#include #include #include #include @@ -26,10 +25,27 @@ class ModelImpl; TVM_REGISTER_OBJECT_TYPE(ModelObj); -Model Model::Create(String reload_lib_path, String model_path, DLDevice device, - int max_num_sequence, bool trace_enabled) { - return Model( - make_object(reload_lib_path, model_path, device, max_num_sequence, trace_enabled)); +Model Model::Create(String reload_lib_path, String model_path, const picojson::object& model_config, + DLDevice device, int max_num_sequence, const Optional& session, + bool trace_enabled) { + return Model(make_object(reload_lib_path, model_path, model_config, device, + max_num_sequence, session, trace_enabled)); +} + +picojson::object Model::LoadModelConfig(const String& model_path) { + picojson::object model_config; + std::ifstream config_istream((model_path + "/mlc-chat-config.json").c_str()); + std::ostringstream config_ostream; + ICHECK(config_istream); + config_ostream << config_istream.rdbuf(); + std::string config_str = config_ostream.str(); + picojson::value config_json; + std::string err = picojson::parse(config_json, config_str); + if (!err.empty()) { + LOG(FATAL) << err; + } + picojson::object config = config_json.get(); + return config; } class ModelImpl : public ModelObj { @@ -38,23 +54,16 @@ class ModelImpl : public ModelObj { * \brief Constructor of ModelImpl. * \sa Model::Create */ - explicit ModelImpl(String reload_lib_path, String model_path, DLDevice device, - int max_num_sequence, bool trace_enabled) + explicit ModelImpl(String reload_lib_path, String model_path, picojson::object model_config, + DLDevice device, int max_num_sequence, const Optional& session, + bool trace_enabled) : device_(device) { // Step 1. Process model config json string. - picojson::object model_config; - { - std::ifstream config_istream((model_path + "/mlc-chat-config.json").c_str()); - std::ostringstream config_ostream; - ICHECK(config_istream); - config_ostream << config_istream.rdbuf(); - std::string config_str = config_ostream.str(); - model_config = LoadModelConfigJSON(config_str); - } + LoadModelConfigJSON(model_config); // Step 2. Initialize vm, we use the packed function mechanism // so there is no explicit abi dependency on these extra // classes other than basic tvm runtime. 
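// ---------------------------------------------------------------------------
// Hypothetical helper (not in the patch; names and include path are
// assumptions) showing the intended call order after this change: the caller
// loads the model config once via Model::LoadModelConfig and hands it,
// together with an optional disco session, to Model::Create, instead of each
// model re-reading mlc-chat-config.json on its own.
#include "model.h"  // Model, Model::LoadModelConfig, picojson

namespace mlc { namespace llm { namespace serve {

Model CreateModelFromPath(const String& reload_lib_path, const String& model_path,
                          DLDevice device, int max_num_sequence) {
  picojson::object model_config = Model::LoadModelConfig(model_path);
  return Model::Create(reload_lib_path, model_path, model_config, device,
                       max_num_sequence, /*session=*/NullOpt,
                       /*trace_enabled=*/false);
}

}}}  // namespace mlc::llm::serve
// ---------------------------------------------------------------------------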
- this->ft_.Init(reload_lib_path, device_, model_config); + this->ft_.Init(reload_lib_path, device_, model_config, session); // Step 3. Load params in nd-array cache. this->params_ = ft_.LoadParams(model_path, device_); // Step 4. Set max_num_sequence @@ -127,35 +136,23 @@ class ModelImpl : public ModelObj { return ft_.get_logits_func_.defined() && ft_.batch_get_logits_func_.defined(); } - NDArray GetLogits(const ObjectRef& last_hidden_states, int batch_size, int seq_len) final { + NDArray GetLogits(const ObjectRef& hidden_states, int batch_size, int seq_len) final { NVTXScopedRange nvtx_scope("GetLogits"); CHECK(ft_.get_logits_func_.defined()) << "`get_logits` function is not found in the model."; - ObjectRef hidden_states_dref_or_nd; - CHECK(!last_hidden_states->IsInstance()); - // hidden_states: (b, s, h) - NDArray hidden_states = Downcast(last_hidden_states); - ICHECK_NE(hidden_size_, -1); - ICHECK_EQ(hidden_states->ndim, 3); - ICHECK_EQ(hidden_states->shape[0], batch_size); - ICHECK_EQ(hidden_states->shape[1], seq_len); - ICHECK_EQ(hidden_states->shape[2], hidden_size_); - ICHECK_EQ(hidden_states->device.device_type, device_.device_type); - ICHECK_EQ(hidden_states->device.device_id, device_.device_id); - - hidden_states = - hidden_states.CreateView({batch_size * seq_len, hidden_size_}, hidden_states->dtype); - - // This copy can be avoided by not copying the hidden states to engine. - hidden_states_dref_or_nd = ft_.CopyToWorker0( - hidden_states, "hidden_states", {max_num_sequence_ * prefill_chunk_size_, hidden_size_}); + ObjectRef hidden_states_dref_or_nd{nullptr}; + if (!ft_.use_disco && hidden_states->IsInstance()) { + hidden_states_dref_or_nd = Downcast(hidden_states)->DebugGetFromRemote(0); + } else { + hidden_states_dref_or_nd = hidden_states; + } ObjectRef ret = ft_.get_logits_func_(hidden_states_dref_or_nd, params_); if (trace_enabled_) { TVMSynchronize(device_.device_type, device_.device_id, nullptr); } NDArray logits{nullptr}; - if (ret->IsInstance()) { + if (ft_.use_disco) { logits = Downcast(ret)->DebugGetFromRemote(0); } else { logits = Downcast(ret); @@ -167,148 +164,11 @@ class ModelImpl : public ModelObj { return logits.CreateView({batch_size, seq_len, logits->shape[1]}, logits->dtype); } - NDArray BatchGetLogits(const ObjectRef& last_hidden_states, const std::vector& seq_ids, - const std::vector& lengths) { - NVTXScopedRange nvtx_scope("BatchGetLogits"); - CHECK(!seq_ids.empty()); - CHECK_EQ(seq_ids.size(), lengths.size()); - int num_sequences = seq_ids.size(); - int total_length = 0; - - int* p_logit_pos = static_cast(logit_pos_arr_->data); - for (int i = 0; i < num_sequences; ++i) { - total_length += lengths[i]; - p_logit_pos[i] = total_length - 1; - } - NDArray logit_pos_nd = logit_pos_arr_.CreateView({num_sequences}, DataType::Int(32)); - ObjectRef logit_pos_dref_or_nd = - ft_.CopyToWorker0(logit_pos_nd, "logit_pos", {max_num_sequence_}); - - CHECK(ft_.batch_get_logits_func_.defined()) - << "`batch_get_logits` function is not found in the model."; - - ObjectRef hidden_states_dref_or_nd; - CHECK(!last_hidden_states->IsInstance()); - // hidden_states: (b, s, h) - NDArray hidden_states = Downcast(last_hidden_states); - ICHECK_NE(hidden_size_, -1); - ICHECK_EQ(hidden_states->ndim, 3); - ICHECK_EQ(hidden_states->shape[0], 1); - ICHECK_EQ(hidden_states->shape[1], total_length); - ICHECK_EQ(hidden_states->shape[2], hidden_size_); - ICHECK_EQ(hidden_states->device.device_type, device_.device_type); - ICHECK_EQ(hidden_states->device.device_id, device_.device_id); - - 
hidden_states = hidden_states.CreateView({total_length, hidden_size_}, hidden_states->dtype); - - // This copy can be avoided by not copying the hidden states to engine. - hidden_states_dref_or_nd = ft_.CopyToWorker0( - hidden_states, "hidden_states", {max_num_sequence_ * prefill_chunk_size_, hidden_size_}); - - ObjectRef ret = - ft_.batch_get_logits_func_(hidden_states_dref_or_nd, logit_pos_dref_or_nd, params_); - if (trace_enabled_) { - TVMSynchronize(device_.device_type, device_.device_id, nullptr); - } - - NDArray logits; - logits = Downcast(ret); - CHECK(logits.defined()); - // logits: (b * s, v) - ICHECK_EQ(logits->ndim, 2); - ICHECK_EQ(logits->shape[0], num_sequences); - return logits.CreateView({1, num_sequences, logits->shape[1]}, logits->dtype); - } - - NDArray BatchSelectLastHidden(const ObjectRef& last_hidden_states, - const std::vector& seq_ids, - const std::vector& lengths) { - NVTXScopedRange nvtx_scope("BatchSelectLastHidden"); - CHECK(!seq_ids.empty()); - CHECK_EQ(seq_ids.size(), lengths.size()); - int num_sequences = seq_ids.size(); - int total_length = 0; - - int* p_logit_pos = static_cast(logit_pos_arr_->data); - for (int i = 0; i < num_sequences; ++i) { - total_length += lengths[i]; - p_logit_pos[i] = total_length - 1; - } - NDArray logit_pos_nd = logit_pos_arr_.CreateView({num_sequences}, DataType::Int(32)); - - // This step runs on the engine thread. - // By temporarily turning off the disco flag, this copies the logit_pos_nd to the cached device - // tensor without actually copying to the worker. - bool use_disco = ft_.use_disco; - ft_.use_disco = false; - ObjectRef logit_pos_dref_or_nd = - ft_.CopyToWorker0(logit_pos_nd, "logit_pos", {max_num_sequence_}); - ft_.use_disco = use_disco; - - CHECK(ft_.batch_select_last_hidden_func_.defined()) - << "`batch_select_last_hidden_states` function is not found in the model."; - - ObjectRef hidden_states_dref_or_nd; - CHECK(!last_hidden_states->IsInstance()); - // hidden_states: (b, s, h) - NDArray hidden_states = Downcast(last_hidden_states); - ICHECK_NE(hidden_size_, -1); - ICHECK_EQ(hidden_states->ndim, 3); - ICHECK_EQ(hidden_states->shape[0], 1); - ICHECK_EQ(hidden_states->shape[1], total_length); - ICHECK_EQ(hidden_states->shape[2], hidden_size_); - ICHECK_EQ(hidden_states->device.device_type, device_.device_type); - ICHECK_EQ(hidden_states->device.device_id, device_.device_id); - - hidden_states_dref_or_nd = - hidden_states.CreateView({total_length, hidden_size_}, hidden_states->dtype); - - ObjectRef ret = - ft_.batch_select_last_hidden_func_(hidden_states_dref_or_nd, logit_pos_dref_or_nd); - if (trace_enabled_) { - TVMSynchronize(device_.device_type, device_.device_id, nullptr); - } - - NDArray hidden; - hidden = Downcast(ret); - // hidden: (b * s, v) - ICHECK_EQ(hidden->ndim, 2); - ICHECK_EQ(hidden->shape[0], num_sequences); - return hidden.CreateView({1, num_sequences, hidden->shape[1]}, hidden->dtype); - } - - NDArray ConcatLastHidden(std::vector& hidden_states, ObjectRef* dst) final { - NVTXScopedRange nvtx_scope("ConcatLastHidden"); - - CHECK(dst->defined()); - - int cum_length = 0; - ICHECK_GE(hidden_states.size(), 1); - for (auto hidden : hidden_states) { - ICHECK_EQ(hidden->ndim, 1); - // No ICHECK_EQ(hidden->shape[0], hidden_size_) here to allow different hidden_sizes. 
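// ---------------------------------------------------------------------------
// Conceptual CPU-side sketch (assumes a row-major (num_tokens, hidden_size)
// layout; this is not the device kernel): the removed ConcatLastHidden /
// BatchSelectLastHidden paths copied hidden states row by row through the
// host, while the replacement gather_hidden_states function selects rows by
// index on device, i.e. dst[i] = src[indices[i]].
#include <cstddef>
#include <vector>

std::vector<float> GatherRows(const std::vector<float>& src, int hidden_size,
                              const std::vector<int>& indices) {
  std::vector<float> dst(indices.size() * hidden_size);
  for (std::size_t i = 0; i < indices.size(); ++i) {
    for (int j = 0; j < hidden_size; ++j) {
      dst[i * hidden_size + j] = src[indices[i] * hidden_size + j];
    }
  }
  return dst;
}
// ---------------------------------------------------------------------------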
- hidden = hidden.CreateView({1, hidden_size_}, hidden->dtype); - // Reuse the copy embedding function - ObjectRef hidden_dref_or_nd = - ft_.CopyToWorker0(hidden, "hidden_for_concat", {1, hidden_size_}); - ft_.nd_copy_embedding_to_offset_func_(hidden_dref_or_nd, *dst, cum_length); - cum_length += 1; - } - NDArray ret{nullptr}; - if ((*dst)->IsInstance()) { - ret = Downcast(*dst)->DebugGetFromRemote(0); - } else { - ret = Downcast(*dst); - } - ret = ret.CreateView({cum_length, hidden_size_}, hidden_states[0]->dtype); - return ret; - } - ObjectRef FuseEmbedHidden(const ObjectRef& embeddings, const ObjectRef& previous_hidden_states, int batch_size, int seq_len) final { NVTXScopedRange nvtx_scope("FuseEmbedHidden"); - ObjectRef embeddings_dref_or_nd; + ObjectRef embeddings_dref_or_nd{nullptr}; if (!embeddings->IsInstance()) { // embeddings: (n, h) NDArray embeddings_nd = Downcast(embeddings); @@ -316,51 +176,33 @@ class ModelImpl : public ModelObj { ICHECK_EQ(embeddings_nd->ndim, 2); ICHECK_GE(embeddings_nd->shape[0], batch_size * seq_len); ICHECK_EQ(embeddings_nd->shape[1], hidden_size_); - ICHECK_EQ(embeddings_nd->device.device_type, device_.device_type); - ICHECK_EQ(embeddings_nd->device.device_id, device_.device_id); embeddings_dref_or_nd = embeddings_nd.CreateView({batch_size * seq_len, hidden_size_}, embeddings_nd->dtype); - - if (!ft_.fuse_embed_hidden_func_.defined() || !previous_hidden_states.defined()) { - // Model has no support for fuse_embed_hidden_states or this is the first model (base model) - return embeddings_nd.CreateView({batch_size, seq_len, hidden_size_}, embeddings_nd->dtype); - } } else { ShapeTuple embedding_shape{batch_size * seq_len, hidden_size_}; embeddings_dref_or_nd = ft_.nd_view_func_(embeddings, embedding_shape); - - if (!ft_.fuse_embed_hidden_func_.defined() || !previous_hidden_states.defined()) { - // Model has no support for fuse_embed_hidden_states or this is the first model (base model) - ShapeTuple embedding_shape{batch_size, seq_len, hidden_size_}; - return ft_.nd_view_func_(embeddings, embedding_shape); - } } - NDArray hidden_states = Downcast(previous_hidden_states); - CHECK(hidden_states.defined()); - ICHECK_EQ(hidden_states->ndim, 3); - ICHECK_EQ(hidden_states->shape[0], batch_size); - ICHECK_EQ(hidden_states->shape[1], seq_len); - ICHECK_EQ(hidden_states->shape[2], hidden_size_); - ICHECK_EQ(hidden_states->device.device_type, device_.device_type); - ICHECK_EQ(hidden_states->device.device_id, device_.device_id); - NDArray hidden_states_2d = - hidden_states.CreateView({batch_size * seq_len, hidden_size_}, hidden_states->dtype); - auto hidden_states_dref_or_nd = - ft_.CopyToWorker0(hidden_states_2d, "hidden_states_2d", - {max_num_sequence_ * prefill_chunk_size_, hidden_size_}); - - ObjectRef ret = - ft_.fuse_embed_hidden_func_(embeddings_dref_or_nd, hidden_states_dref_or_nd, params_); + ObjectRef previous_hidden_states_dref_or_nd{nullptr}; + if (!ft_.use_disco && previous_hidden_states->IsInstance()) { + previous_hidden_states_dref_or_nd = + Downcast(previous_hidden_states)->DebugGetFromRemote(0); + } else { + previous_hidden_states_dref_or_nd = previous_hidden_states; + } + ObjectRef fused = ft_.fuse_embed_hidden_func_(embeddings_dref_or_nd, + previous_hidden_states_dref_or_nd, params_); if (trace_enabled_) { TVMSynchronize(device_.device_type, device_.device_id, nullptr); } - if (!ret->IsInstance()) { - NDArray fused = Downcast(ret); - return fused.CreateView({batch_size, seq_len, hidden_size_}, fused->dtype); + ShapeTuple out_shape{batch_size, 
seq_len, hidden_size_}; + if (ft_.use_disco) { + return ft_.nd_view_func_(fused, out_shape); } else { - ShapeTuple fused_shape{batch_size, seq_len, hidden_size_}; - return ft_.nd_view_func_(ret, fused_shape); + NDArray fused_nd = Downcast(fused); + ICHECK_EQ(fused_nd->ndim, 2); + ICHECK_EQ(fused_nd->shape[0], batch_size * seq_len); + return fused_nd.CreateView(out_shape, fused_nd->dtype); } } @@ -435,9 +277,9 @@ class ModelImpl : public ModelObj { return logits; } - NDArray BatchPrefillToLastHidden(const ObjectRef& hidden_states, - const std::vector& seq_ids, - const std::vector& lengths) final { + ObjectRef BatchPrefillToLastHidden(const ObjectRef& embedding_or_hidden_states, + const std::vector& seq_ids, + const std::vector& lengths) final { NVTXScopedRange nvtx_scope("BatchPrefillToLastHidden"); CHECK(!seq_ids.empty()); CHECK_EQ(seq_ids.size(), lengths.size()); @@ -448,19 +290,15 @@ class ModelImpl : public ModelObj { total_length += lengths[i]; } - ObjectRef hidden_states_dref_or_nd; - if (!hidden_states->IsInstance()) { - // hidden_states: (1, n, h) - NDArray hidden_states_nd = Downcast(hidden_states); - ICHECK_EQ(hidden_states_nd->ndim, 3); - ICHECK_EQ(hidden_states_nd->shape[0], 1); - ICHECK_EQ(hidden_states_nd->shape[1], total_length); - ICHECK_EQ(hidden_states_nd->shape[2], hidden_size_); - hidden_states_dref_or_nd = - hidden_states_nd.CreateView({1, total_length, hidden_size_}, hidden_states_nd->dtype); + ObjectRef embedding_or_hidden_states_dref_or_nd{nullptr}; + ShapeTuple hidden_states_shape{1, total_length, hidden_size_}; + if (!ft_.use_disco) { + NDArray embedding_or_hidden_states_nd = Downcast(embedding_or_hidden_states); + embedding_or_hidden_states_dref_or_nd = embedding_or_hidden_states_nd.CreateView( + hidden_states_shape, embedding_or_hidden_states_nd->dtype); } else { - ShapeTuple hidden_states_shape{1, total_length, hidden_size_}; - hidden_states_dref_or_nd = ft_.nd_view_func_(hidden_states, hidden_states_shape); + embedding_or_hidden_states_dref_or_nd = + ft_.nd_view_func_(embedding_or_hidden_states, hidden_states_shape); } CHECK(ft_.prefill_to_last_hidden_func_.defined()) @@ -475,32 +313,34 @@ class ModelImpl : public ModelObj { ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple); // args: embeddings, logit_pos, kv_cache, params - ObjectRef ret; + ObjectRef result{nullptr}; if (seq_ids.size() == 1) { CHECK(ft_.single_batch_prefill_to_last_hidden_func_.defined()) << "`single_batch_prefill_to_last_hidden_states` function is not found in the model."; - ret = ft_.single_batch_prefill_to_last_hidden_func_(hidden_states_dref_or_nd, kv_cache_, - params_); - } else { - ret = ft_.prefill_to_last_hidden_func_(hidden_states_dref_or_nd, kv_cache_, params_); - } - NDArray last_hidden_states; - if (ft_.use_disco) { - Array result = Downcast(ret)->DebugGetFromRemote(0); - last_hidden_states = Downcast(result[0]); + result = ft_.single_batch_prefill_to_last_hidden_func_(embedding_or_hidden_states_dref_or_nd, + kv_cache_, params_); } else { - last_hidden_states = Downcast>(ret)[0]; + result = ft_.prefill_to_last_hidden_func_(embedding_or_hidden_states_dref_or_nd, kv_cache_, + params_); } + ObjectRef hidden_states = ft_.tuple_getitem_func_(result, 0); + if (trace_enabled_) { TVMSynchronize(device_.device_type, device_.device_id, nullptr); } ft_.kv_cache_end_forward_func_(kv_cache_); - // hidden_states: (1, total_length, v) - ICHECK_EQ(last_hidden_states->ndim, 3); - ICHECK_EQ(last_hidden_states->shape[0], 1); - ICHECK_EQ(last_hidden_states->shape[1], 
total_length); - return last_hidden_states; + ShapeTuple out_shape{total_length, hidden_size_}; + if (ft_.use_disco) { + return ft_.nd_view_func_(hidden_states, out_shape); + } else { + NDArray hidden_states_nd = Downcast(hidden_states); + ICHECK_EQ(hidden_states_nd->ndim, 3); + ICHECK_EQ(hidden_states_nd->shape[0], 1); + ICHECK_EQ(hidden_states_nd->shape[1], total_length); + ICHECK_EQ(hidden_states_nd->shape[2], hidden_size_); + return hidden_states_nd.CreateView(out_shape, hidden_states_nd->dtype); + } } NDArray BatchDecode(const ObjectRef& embeddings, const std::vector& seq_ids) final { @@ -563,8 +403,8 @@ class ModelImpl : public ModelObj { return logits; } - NDArray BatchDecodeToLastHidden(const ObjectRef& hidden_states, - const std::vector& seq_ids) final { + ObjectRef BatchDecodeToLastHidden(const ObjectRef& hidden_states_dref_or_nd, + const std::vector& seq_ids) final { NVTXScopedRange nvtx_scope("BatchDecodeToLastHidden"); int num_sequence = seq_ids.size(); @@ -574,21 +414,6 @@ class ModelImpl : public ModelObj { ICHECK(ft_.kv_cache_end_forward_func_.defined()); ICHECK(kv_cache_.defined()) << "KV cache has not been initialized."; - ObjectRef hidden_states_dref_or_nd; - if (!hidden_states->IsInstance()) { - // hidden_states: (1, n, h) - NDArray hidden_states_nd = Downcast(hidden_states); - ICHECK_EQ(hidden_states_nd->ndim, 3); - ICHECK_EQ(hidden_states_nd->shape[0], num_sequence); - ICHECK_EQ(hidden_states_nd->shape[1], 1); - ICHECK_EQ(hidden_states_nd->shape[2], hidden_size_); - hidden_states_dref_or_nd = - hidden_states_nd.CreateView({num_sequence, 1, hidden_size_}, hidden_states_nd->dtype); - } else { - ShapeTuple hidden_states_shape{num_sequence, 1, hidden_size_}; - hidden_states_dref_or_nd = ft_.nd_view_func_(hidden_states, hidden_states_shape); - } - // Reserve in KV cache for the lengths of the input. // Begin forward with the sequence ids and new lengths. 
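// ---------------------------------------------------------------------------
// Shape-convention note with a tiny illustrative helper (assumption: TVM's
// NDArray::CreateView, already used throughout this file, is available): the
// *ToLastHidden methods now hand hidden states back to the engine as 2-D
// (num_tokens, hidden_size) views, e.g. (total_length, h) after prefill and
// (num_sequence, h) after decode, instead of the old 3-D (1, n, h) / (b, 1, h)
// arrays, so the downstream gather/scatter utilities can index rows directly.
#include <tvm/runtime/ndarray.h>

tvm::runtime::NDArray FlattenLastHidden(tvm::runtime::NDArray hs_3d, int num_tokens,
                                        int hidden_size) {
  // hs_3d is expected to hold num_tokens rows of size hidden_size, e.g. with
  // shape (1, num_tokens, hidden_size).
  return hs_3d.CreateView({num_tokens, hidden_size}, hs_3d->dtype);
}
// ---------------------------------------------------------------------------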
IntTuple seq_ids_tuple(seq_ids); @@ -596,32 +421,34 @@ class ModelImpl : public ModelObj { ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple); // args: embeddings, kv_cache, params - ObjectRef ret; + ObjectRef result{nullptr}; if (seq_ids.size() == 1) { CHECK(ft_.single_batch_decode_to_last_hidden_func_.defined()) << "`decode_to_last_hidden_states` function is not found in the model."; - ret = ft_.single_batch_decode_to_last_hidden_func_(hidden_states_dref_or_nd, kv_cache_, - params_); + result = ft_.single_batch_decode_to_last_hidden_func_(hidden_states_dref_or_nd, kv_cache_, + params_); } else { - ret = ft_.decode_to_last_hidden_func_(hidden_states_dref_or_nd, kv_cache_, params_); - } - NDArray last_hidden_states; - if (ft_.use_disco) { - Array result = Downcast(ret)->DebugGetFromRemote(0); - last_hidden_states = Downcast(result[0]); - } else { - last_hidden_states = Downcast>(ret)[0]; + result = ft_.decode_to_last_hidden_func_(hidden_states_dref_or_nd, kv_cache_, params_); } + ft_.kv_cache_end_forward_func_(kv_cache_); + ObjectRef hidden_states = ft_.tuple_getitem_func_(result, 0); + if (trace_enabled_) { TVMSynchronize(device_.device_type, device_.device_id, nullptr); } - ft_.kv_cache_end_forward_func_(kv_cache_); - // hidden_states: (b, 1, v) - ICHECK_EQ(last_hidden_states->ndim, 3); - ICHECK_EQ(last_hidden_states->shape[0], num_sequence); - ICHECK_EQ(last_hidden_states->shape[1], 1); - return last_hidden_states; + // hidden_states: (b, 1, v) to (b, v) + ShapeTuple out_shape{num_sequence, hidden_size_}; + if (ft_.use_disco) { + return ft_.nd_view_func_(hidden_states, out_shape); + } else { + NDArray hidden_states_nd = Downcast(hidden_states); + ICHECK_EQ(hidden_states_nd->ndim, 3); + ICHECK_EQ(hidden_states_nd->shape[0], num_sequence); + ICHECK_EQ(hidden_states_nd->shape[1], 1); + ICHECK_EQ(hidden_states_nd->shape[2], hidden_size_); + return hidden_states_nd.CreateView(out_shape, hidden_states_nd->dtype); + } } NDArray BatchVerify(const ObjectRef& embeddings, const std::vector& seq_ids, @@ -684,9 +511,9 @@ class ModelImpl : public ModelObj { return logits; } - NDArray BatchVerifyToLastHidden(const ObjectRef& hidden_states, - const std::vector& seq_ids, - const std::vector& lengths) final { + ObjectRef BatchVerifyToLastHidden(const ObjectRef& embeddings, + const std::vector& seq_ids, + const std::vector& lengths) final { NVTXScopedRange nvtx_scope("BatchVerifyToLastHidden"); CHECK(!seq_ids.empty()); CHECK_EQ(seq_ids.size(), lengths.size()); @@ -702,45 +529,46 @@ class ModelImpl : public ModelObj { ICHECK(ft_.kv_cache_end_forward_func_.defined()); ICHECK(kv_cache_.defined()) << "KV cache has not been initialized."; - ObjectRef hidden_states_dref_or_nd; - if (!hidden_states->IsInstance()) { - // hidden_states: (1, n, h) - NDArray hidden_states_nd = Downcast(hidden_states); - ICHECK_EQ(hidden_states_nd->ndim, 3); - ICHECK_EQ(hidden_states_nd->shape[0], 1); - ICHECK_EQ(hidden_states_nd->shape[1], total_length); - ICHECK_EQ(hidden_states_nd->shape[2], hidden_size_); - hidden_states_dref_or_nd = - hidden_states_nd.CreateView({1, total_length, hidden_size_}, hidden_states_nd->dtype); + ObjectRef embeddings_dref_or_nd; + if (!embeddings->IsInstance()) { + // embeddings: (1, n, h) + NDArray embeddings_nd = Downcast(embeddings); + ICHECK_NE(hidden_size_, -1); + ICHECK_EQ(embeddings_nd->ndim, 2); + ICHECK_GE(embeddings_nd->shape[0], total_length); + ICHECK_EQ(embeddings_nd->shape[1], hidden_size_); + ICHECK_EQ(embeddings_nd->device.device_type, device_.device_type); + 
ICHECK_EQ(embeddings_nd->device.device_id, device_.device_id); + embeddings_dref_or_nd = + embeddings_nd.CreateView({1, total_length, hidden_size_}, embeddings_nd->dtype); } else { - ShapeTuple hidden_states_shape{1, total_length, hidden_size_}; - hidden_states_dref_or_nd = ft_.nd_view_func_(hidden_states, hidden_states_shape); + ShapeTuple embedding_shape{1, total_length, hidden_size_}; + embeddings_dref_or_nd = ft_.nd_view_func_(embeddings, embedding_shape); } - // Begin forward with the sequence ids and new lengths. IntTuple seq_ids_tuple(seq_ids); IntTuple lengths_tuple(lengths.begin(), lengths.end()); ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple); // args: embeddings, logit_pos, kv_cache, params - ObjectRef ret = ft_.verify_to_last_hidden_func_(hidden_states_dref_or_nd, kv_cache_, params_); - NDArray last_hidden_states; - if (ft_.use_disco) { - Array result = Downcast(ret)->DebugGetFromRemote(0); - last_hidden_states = Downcast(result[0]); - } else { - last_hidden_states = Downcast>(ret)[0]; - } + ObjectRef result = ft_.verify_to_last_hidden_func_(embeddings_dref_or_nd, kv_cache_, params_); + ft_.kv_cache_end_forward_func_(kv_cache_); + ObjectRef hidden_states = ft_.tuple_getitem_func_(result, 0); if (trace_enabled_) { TVMSynchronize(device_.device_type, device_.device_id, nullptr); } - ft_.kv_cache_end_forward_func_(kv_cache_); - // hidden_states: (1, total_length, v) - ICHECK_EQ(last_hidden_states->ndim, 3); - ICHECK_EQ(last_hidden_states->shape[0], 1); - ICHECK_EQ(last_hidden_states->shape[1], total_length); - return last_hidden_states; + ShapeTuple out_shape{total_length, hidden_size_}; + if (!ft_.use_disco) { + NDArray hidden_states_nd = Downcast(hidden_states); + ICHECK_EQ(hidden_states_nd->ndim, 3); + ICHECK_EQ(hidden_states_nd->shape[0], 1); + ICHECK_EQ(hidden_states_nd->shape[1], total_length); + ICHECK_EQ(hidden_states_nd->shape[2], hidden_size_); + return hidden_states_nd.CreateView(out_shape, hidden_states_nd->dtype); + } else { + return ft_.nd_view_func_(hidden_states, out_shape); + } } /*********************** KV Cache Management ***********************/ @@ -860,19 +688,19 @@ class ModelImpl : public ModelObj { // Allocate the hidden_states tensor. // Use the same function as embeddings. ObjectRef hidden_states = ft_.alloc_embedding_tensor_func_(); + NDArray hidden_states_nd{nullptr}; // Get the shape of the hidden_states tensor for hidden size. 
- ShapeTuple hidden_states_shape; if (ft_.use_disco) { ICHECK(hidden_states->IsInstance()); - ObjectRef shape_ref = ft_.nd_get_shape_func_(hidden_states); - hidden_states_shape = Downcast(shape_ref)->DebugGetFromRemote(0); + hidden_states_nd = Downcast(hidden_states)->DebugGetFromRemote(0); } else { - NDArray hidden_states_nd = Downcast(hidden_states); - hidden_states_shape = hidden_states_nd.Shape(); + hidden_states_nd = Downcast(hidden_states); } + ShapeTuple hidden_states_shape = hidden_states_nd.Shape(); ICHECK_EQ(hidden_states_shape.size(), 2); ICHECK_EQ(hidden_states_shape[0], prefill_chunk_size_); this->hidden_size_ = hidden_states_shape[1]; + this->hidden_states_dtype_ = hidden_states_nd->dtype; return hidden_states; } @@ -883,6 +711,63 @@ class ModelImpl : public ModelObj { } } + /********************** Utilities for speculative decoding **********************/ + + DraftTokenWorkspaceManager CreateDraftTokenWorkspaceManager(int max_num_tokens) { + return DraftTokenWorkspaceManager(max_num_tokens, vocab_size_, hidden_size_, + hidden_states_dtype_, device_, ft_); + } + + ObjectRef GatherHiddenStates(const ObjectRef& input, const std::vector& indices, + ObjectRef* dst) final { + ObjectRef dst_view{nullptr}; + ShapeTuple out_shape{static_cast(indices.size()), hidden_size_}; + if ((*dst)->IsInstance()) { + dst_view = ft_.nd_view_func_(*dst, out_shape); + } else { + NDArray dst_nd = Downcast(*dst); + dst_view = dst_nd.CreateView(out_shape, hidden_states_dtype_); + } + NDArray indices_nd = + logit_pos_arr_.CreateView({static_cast(indices.size())}, DataType::Int(32)); + indices_nd.CopyFromBytes(indices.data(), indices.size() * sizeof(int)); + ObjectRef indices_device = ft_.CopyToWorker0(indices_nd, "logit_pos", {max_num_sequence_}); + ft_.gather_hidden_states_func_(input, indices_device, dst_view); + return dst_view; + } + + void ScatterHiddenStates(const ObjectRef& input, const std::vector& indices, + ObjectRef* dst) final { + NDArray indices_nd = + logit_pos_arr_.CreateView({static_cast(indices.size())}, DataType::Int(32)); + indices_nd.CopyFromBytes(indices.data(), indices.size() * sizeof(int)); + ObjectRef indices_device = ft_.CopyToWorker0(indices_nd, "logit_pos", {max_num_sequence_}); + ft_.scatter_hidden_states_func_(input, indices_device, *dst); + } + + NDArray GatherDraftProbs(const NDArray& input, const std::vector& indices, + NDArray* dst) final { + NDArray dst_view = + dst->CreateView({static_cast(indices.size()), vocab_size_}, DataType::Float(32)); + NDArray indices_nd = + logit_pos_arr_.CreateView({static_cast(indices.size())}, DataType::Int(32)); + indices_nd.CopyFromBytes(indices.data(), indices.size() * sizeof(int)); + ObjectRef indices_device = + ft_.CopyToWorker0(indices_nd, "logit_pos_local", {max_num_sequence_}, /*local_only=*/true); + ft_.gather_probs_func_(input, indices_device, dst_view); + return dst_view; + } + + void ScatterDraftProbs(const NDArray& input, const std::vector& indices, + NDArray* dst) final { + NDArray indices_nd = + logit_pos_arr_.CreateView({static_cast(indices.size())}, DataType::Int(32)); + indices_nd.CopyFromBytes(indices.data(), indices.size() * sizeof(int)); + ObjectRef indices_device = + ft_.CopyToWorker0(indices_nd, "logit_pos_local", {max_num_sequence_}, /*local_only=*/true); + ft_.scatter_probs_func_(input, indices_device, *dst); + } + /************** Debug/Profile **************/ void DebugCallFuncOnAllAllWorker(const String& func_name) final { @@ -891,15 +776,7 @@ class ModelImpl : public ModelObj { private: /*! 
\brief Load model configuration from JSON. */ - picojson::object LoadModelConfigJSON(const std::string& config_str) { - picojson::value config_json; - std::string err = picojson::parse(config_json, config_str); - if (!err.empty()) { - LOG(FATAL) << err; - } - - // Get json fields. - picojson::object config = config_json.get(); + picojson::object LoadModelConfigJSON(picojson::object config) { if (config.count("context_window_size")) { CHECK(config["context_window_size"].is()); this->max_window_size_ = config["context_window_size"].get(); @@ -949,6 +826,7 @@ class ModelImpl : public ModelObj { int max_num_sequence_ = -1; int prefill_chunk_size_ = -1; int hidden_size_ = -1; + DLDataType hidden_states_dtype_; int vocab_size_ = -1; int image_embed_size_ = -1; //---------------------------- diff --git a/cpp/serve/model.h b/cpp/serve/model.h index 045daff874..f587969bfb 100644 --- a/cpp/serve/model.h +++ b/cpp/serve/model.h @@ -7,11 +7,13 @@ #ifndef MLC_LLM_SERVE_MODEL_H_ #define MLC_LLM_SERVE_MODEL_H_ +#include #include #include #include "../base.h" #include "config.h" +#include "draft_token_workspace_manager.h" #include "event_trace_recorder.h" #include "function_table.h" #include "logit_processor.h" @@ -40,10 +42,26 @@ struct ModelWorkspace { */ ObjectRef embeddings{nullptr}; /*! - * \brief The hidden_states tensor. It can be either an NDArray when tensor + * \brief The hidden_states tensor for the current batch. It can be either an NDArray when tensor * model parallelism is not enabled, or a DRef when using tensor model parallelism. */ ObjectRef hidden_states{nullptr}; + + /*! + * \brief The draft token probabilities tensor for the current batch. + */ + NDArray draft_probs{nullptr}; + + /*! + * \brief The hidden_states tensor storing the hidden_states of draft tokens of all requests. + */ + ObjectRef draft_hidden_states_storage{nullptr}; + + /*! + * \brief The draft token probabilities tensor storing the probabilities of draft tokens of all + * requests. + */ + NDArray draft_probs_storage{nullptr}; }; /*! @@ -122,35 +140,6 @@ class ModelObj : public Object { */ virtual NDArray GetLogits(const ObjectRef& last_hidden_states, int batch_size, int seq_len) = 0; - /*! - * \brief Compute logits for last hidden_states in a batch. - * \param last_hidden_states The last hidden_states to compute logits for. - * \param seq_ids The id of the sequence in the KV cache. - * \param lengths The length of each sequence to prefill. - * \return The computed logits. - */ - virtual NDArray BatchGetLogits(const ObjectRef& last_hidden_states, - const std::vector& seq_ids, - const std::vector& lengths) = 0; - - /*! - * \brief Select desired hidden_states for last hidden_states in a batch. - * \param last_hidden_states The last hidden_states to select from. - * \param seq_ids The id of the sequence in the KV cache. - * \param lengths The length of each sequence to prefill. - * \return The last hidden_states for the batch. - */ - virtual NDArray BatchSelectLastHidden(const ObjectRef& last_hidden_states, - const std::vector& seq_ids, - const std::vector& lengths) = 0; - - /*! - * \brief Concat a list of 1D hidden_states to 2D tensor. - * \param hidden_states The hidden_states to concat. - * \param dst The copy destination. - */ - virtual NDArray ConcatLastHidden(std::vector& hidden_states, ObjectRef* dst) = 0; - /*! * \brief Batch prefill function. Embedding in, logits out. 
* The embedding order of sequences in `embedding_arr` follows @@ -171,9 +160,9 @@ class ModelObj : public Object { * \param lengths The length of each sequence to prefill. * \return The hidden_states for the next token. */ - virtual NDArray BatchPrefillToLastHidden(const ObjectRef& hidden_states, - const std::vector& seq_ids, - const std::vector& lengths) = 0; + virtual ObjectRef BatchPrefillToLastHidden(const ObjectRef& hidden_states, + const std::vector& seq_ids, + const std::vector& lengths) = 0; /*! * \brief Batch decode function. Embedding in, logits out. @@ -192,8 +181,8 @@ class ModelObj : public Object { * \param seq_id The id of the sequence in the KV cache. * \return The hidden_states for the next token for each sequence in the batch. */ - virtual NDArray BatchDecodeToLastHidden(const ObjectRef& hidden_states, - const std::vector& seq_ids) = 0; + virtual ObjectRef BatchDecodeToLastHidden(const ObjectRef& hidden_states, + const std::vector& seq_ids) = 0; /*! * \brief Batch verify function. Embedding in, logits out. @@ -219,9 +208,9 @@ class ModelObj : public Object { * That is to say, it does not accept "running a verify step for a subset * of the full batch". */ - virtual NDArray BatchVerifyToLastHidden(const ObjectRef& hidden_states, - const std::vector& seq_ids, - const std::vector& lengths) = 0; + virtual ObjectRef BatchVerifyToLastHidden(const ObjectRef& hidden_states, + const std::vector& seq_ids, + const std::vector& lengths) = 0; /*********************** KV Cache Management ***********************/ @@ -302,6 +291,27 @@ class ModelObj : public Object { /*! \brief Reset the model KV cache and other statistics. */ virtual void Reset() = 0; + /*********************** Utilities for speculative decoding. ***********************/ + + virtual DraftTokenWorkspaceManager CreateDraftTokenWorkspaceManager(int max_num_token) = 0; + + /*! \brief Gather the hidden_states of the given indices and in-place update the dst tensor. */ + virtual ObjectRef GatherHiddenStates(const ObjectRef& input, const std::vector& indices, + ObjectRef* dst) = 0; + + /*! \brief Scatter the hidden_states of the given indices to the dst tensor. */ + virtual void ScatterHiddenStates(const ObjectRef& input, const std::vector& indices, + ObjectRef* dst) = 0; + + /*! \brief Gather the draft token probabilities of the given indices and in-place update the dst + * tensor. */ + virtual NDArray GatherDraftProbs(const NDArray& input, const std::vector& indices, + NDArray* dst) = 0; + + /*! \brief Scatter the draft token probabilities of the given indices to the dst tensor. */ + virtual void ScatterDraftProbs(const NDArray& input, const std::vector& indices, + NDArray* dst) = 0; + /************** Debug/Profile **************/ /*! \brief Call the given global function on all workers. Only for debug purpose. */ @@ -319,13 +329,24 @@ class Model : public ObjectRef { * \brief Create the runtime module for LLM functions. * \param reload_lib_path The model library path. * \param model_path The path to the model weight parameters. + * \param model_config The model config json object. * \param device The device to run the model on. * \param max_num_sequence The maximum number of sequences to be processed + * \param session The session to run the model on. * \param trace_enabled A boolean indicating whether tracing is enabled. * \return The created runtime module. 
*/ - TVM_DLL static Model Create(String reload_lib_path, String model_path, DLDevice device, - int max_num_sequence, bool trace_enabled); + TVM_DLL static Model Create(String reload_lib_path, String model_path, + const picojson::object& model_config, DLDevice device, + int max_num_sequence, const Optional& session, + bool trace_enabled); + + /*! + * Load the model config from the given model path. + * \param model_path The path to the model weight parameters. + * \return The model config json object. + */ + static picojson::object LoadModelConfig(const String& model_path); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Model, ObjectRef, ModelObj); }; diff --git a/cpp/serve/request_state.cc b/cpp/serve/request_state.cc index b1f5ae27a2..4c59ae52a2 100644 --- a/cpp/serve/request_state.cc +++ b/cpp/serve/request_state.cc @@ -59,11 +59,9 @@ void RequestModelStateNode::CommitToken(SampleResult sampled_token) { } } -void RequestModelStateNode::AddDraftToken(SampleResult sampled_token, NDArray prob_dist, - NDArray last_hidden_on_device) { +void RequestModelStateNode::AddDraftToken(SampleResult sampled_token, int draft_token_slot) { draft_output_tokens.push_back(std::move(sampled_token)); - draft_output_prob_dist.push_back(std::move(prob_dist)); - draft_last_hidden_on_device.push_back(std::move(last_hidden_on_device)); + draft_token_slots.push_back(draft_token_slot); appeared_token_ids[sampled_token.sampled_token_id.first] += 1; } @@ -71,14 +69,17 @@ void RequestModelStateNode::RemoveLastDraftToken() { ICHECK(!draft_output_tokens.empty()); auto it = appeared_token_ids.find(draft_output_tokens.back().sampled_token_id.first); draft_output_tokens.pop_back(); - draft_output_prob_dist.pop_back(); CHECK(it != appeared_token_ids.end()); if (--it->second == 0) { appeared_token_ids.erase(it); } } -void RequestModelStateNode::RemoveAllDraftTokens() { +void RequestModelStateNode::RemoveAllDraftTokens(std::vector* removed_draft_token_slots) { + if (removed_draft_token_slots != nullptr) { + removed_draft_token_slots->assign(draft_token_slots.begin(), draft_token_slots.end()); + } + draft_token_slots.clear(); while (!draft_output_tokens.empty()) { RemoveLastDraftToken(); } diff --git a/cpp/serve/request_state.h b/cpp/serve/request_state.h index 950bb6e290..79abcb1a24 100644 --- a/cpp/serve/request_state.h +++ b/cpp/serve/request_state.h @@ -62,20 +62,8 @@ class RequestModelStateNode : public Object { * result of speculation. */ std::vector draft_output_tokens; - /*! - * \brief The probability distribution on each position in the - * draft. We keep the distributions for stochastic sampling when merging - * speculations from multiple models. - * \note We only need this value when we have multiple parallel small models - * and draft outputs in speculative inference settings. - */ - std::vector draft_output_prob_dist; - /*! - * \brief The last hidden_states used to get probs in drafting. - * \note We only need this value when we have multiple parallel small models - * and draft outputs in speculative inference settings. - */ - std::vector draft_last_hidden_on_device; + /*! \brief The storage slots for the associated states of draft tokens. */ + std::vector draft_token_slots; /*! \brief The appeared committed and draft tokens and their occurrence times. */ std::unordered_map appeared_token_ids; @@ -101,17 +89,18 @@ class RequestModelStateNode : public Object { /*! \brief Commit a new token into committed_tokens. Update appeared_token_ids. */ void CommitToken(SampleResult sampled_token); /*! 
\brief Add a draft token into draft_output_tokens. Update appeared_token_ids. */ - void AddDraftToken(SampleResult sampled_token, NDArray prob_dist, - NDArray draft_last_hidden_on_device = NDArray()); - /*! \brief Remove the last token from draft_output_tokens. Update appeared_token_ids. */ - void RemoveLastDraftToken(); + void AddDraftToken(SampleResult sampled_token, int draft_token_slot); /*! \brief Remove all draft tokens from draft_output_tokens. Update appeared_token_ids. */ - void RemoveAllDraftTokens(); + void RemoveAllDraftTokens(std::vector* removed_draft_token_slots = nullptr); static constexpr const char* _type_key = "mlc.serve.RequestModelState"; static constexpr const bool _type_has_method_sequal_reduce = false; static constexpr const bool _type_has_method_shash_reduce = false; TVM_DECLARE_BASE_OBJECT_INFO(RequestModelStateNode, Object); + + private: + /*! \brief Remove the last token from draft_output_tokens. Update appeared_token_ids. */ + void RemoveLastDraftToken(); }; class RequestModelState : public ObjectRef { diff --git a/cpp/serve/sampler/cpu_sampler.cc b/cpp/serve/sampler/cpu_sampler.cc index 98080c979d..196a6dd695 100644 --- a/cpp/serve/sampler/cpu_sampler.cc +++ b/cpp/serve/sampler/cpu_sampler.cc @@ -430,7 +430,7 @@ class CPUSampler : public SamplerObj { const std::vector& cum_verify_lengths, const Array& generation_cfg, const std::vector& rngs, const std::vector>& draft_output_tokens, - const std::vector>& draft_output_prob_dist) final { + NDArray draft_probs_on_device) final { // probs_on_host: (n, v) RECORD_EVENT(trace_recorder_, request_ids, "start draft verification"); CHECK_EQ(probs_on_host->ndim, 2); @@ -438,8 +438,8 @@ class CPUSampler : public SamplerObj { int num_sequence = static_cast(cum_verify_lengths.size()) - 1; CHECK_EQ(rngs.size(), num_sequence); CHECK_EQ(draft_output_tokens.size(), num_sequence); - CHECK_EQ(draft_output_prob_dist.size(), num_sequence); + NDArray draft_probs_on_host = draft_probs_on_device.CopyTo(DLDevice{kDLCPU, 0}); std::vector> sample_results; sample_results.resize(num_sequence); @@ -451,6 +451,7 @@ class CPUSampler : public SamplerObj { [&](int i) { int verify_start = cum_verify_lengths[i]; int verify_end = cum_verify_lengths[i + 1]; + int cur_token_idx = 0; // Sub 1 to ignore the last prediction. 
for (; cur_token_idx < verify_end - verify_start - 1; ++cur_token_idx) { @@ -477,12 +478,9 @@ class CPUSampler : public SamplerObj { // normalize a new probability distribution double sum_v = 0.0; - NDArray q_dist = draft_output_prob_dist[i][cur_token_idx]; - ICHECK(q_dist->device.device_type == kDLCPU); - ICHECK(q_dist->ndim == 1); - ICHECK(vocab_size == q_dist->shape[q_dist->ndim - 1]); const float* __restrict p_qdist = - static_cast(__builtin_assume_aligned(q_dist->data, 4)); + static_cast(__builtin_assume_aligned(draft_probs_on_host->data, 4)) + + (verify_start + cur_token_idx + 1) * vocab_size; for (int j = 0; j < vocab_size; ++j) { p_probs[j] = std::max(p_probs[j] - p_qdist[j], 0.0f); diff --git a/cpp/serve/sampler/gpu_sampler.cc b/cpp/serve/sampler/gpu_sampler.cc index 62911a7cd1..87a9a31d30 100644 --- a/cpp/serve/sampler/gpu_sampler.cc +++ b/cpp/serve/sampler/gpu_sampler.cc @@ -51,6 +51,9 @@ class GPUSampler : public SamplerObj { ICHECK(gpu_sample_with_top_p_func_.defined()); ICHECK(gpu_sampler_take_probs_func_.defined()); + flashinfer_multinomial_sample_func_ = + Registry::Get("flashinfer.sampling.parallel_sampling_from_prob"); + DLDevice device_cpu{DLDeviceType::kDLCPU, /*device_id=*/0}; // We support at most 5 top prob results for each sequence. // Initialize auxiliary arrays on CPU. @@ -71,11 +74,11 @@ class GPUSampler : public SamplerObj { sample_indices_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); top_p_device_ = NDArray::Empty({max_num_sample}, dtype_f32_, device); top_prob_offsets_device_ = NDArray::Empty({max_num_sample * 5}, dtype_i32_, device); - draft_probs_device_ = NDArray::Empty({max_num_sample, vocab_size}, dtype_f32_, device); draft_tokens_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); token_tree_first_child_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); token_tree_next_sibling_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); token_tree_parent_ptr_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); + sampled_token_ids_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); // If the device is CUDA/ROCm, we create a standalone copy stream, in // purpose to hide the latency of auxiliary stream copy. 
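For reference, the update that the CPU loop above performs (p_probs[j] = max(p_probs[j] - p_qdist[j], 0) followed by renormalization) and that the GPU verify kernel below implements is the standard speculative-sampling accept/residual rule. A minimal NumPy sketch of that rule, using illustrative names rather than the engine's identifiers, and ignoring batching:

import numpy as np

def verify_draft_token(model_probs, draft_probs, draft_token, u):
    # Accept the draft token with probability min(1, p / q), where p and q are the
    # target-model and draft-model probabilities of the proposed token and u ~ U(0, 1).
    if u * draft_probs[draft_token] <= model_probs[draft_token]:
        return True, model_probs
    # On rejection, resample from the renormalized residual max(p - q, 0),
    # mirroring the max/renormalize steps in the samplers in this patch.
    residual = np.maximum(model_probs - draft_probs, 0.0)
    return False, residual / residual.sum()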
@@ -163,7 +166,7 @@ class GPUSampler : public SamplerObj { const std::vector& cum_verify_lengths, const Array& generation_cfg, const std::vector& rngs, const std::vector>& draft_output_tokens, - const std::vector>& draft_output_prob_dist) final { + NDArray draft_probs_on_device) final { NVTXScopedRange nvtx_scope("BatchVerifyDraftTokensWithProbAfterTopP"); std::vector> sample_results; // probs_on_device: (n, v) @@ -173,38 +176,27 @@ class GPUSampler : public SamplerObj { int num_sequence = static_cast(cum_verify_lengths.size()) - 1; CHECK_EQ(rngs.size(), num_sequence); CHECK_EQ(draft_output_tokens.size(), num_sequence); - CHECK_EQ(draft_output_prob_dist.size(), num_sequence); sample_results.resize(num_sequence); int num_nodes = cum_verify_lengths.back(); + CHECK_EQ(draft_probs_on_device->shape[0], num_nodes); NDArray uniform_samples_host = uniform_samples_host_.CreateView({num_nodes}, dtype_f32_); NDArray uniform_samples_device = uniform_samples_device_.CreateView({num_nodes}, dtype_f32_); - NDArray draft_probs_device = - draft_probs_device_.CreateView({num_nodes, vocab_size_}, dtype_f32_); NDArray draft_tokens_host = draft_tokens_host_.CreateView({num_nodes}, dtype_i32_); NDArray draft_tokens_device = draft_tokens_device_.CreateView({num_nodes}, dtype_i32_); - // Concat draft prob distributions to a ragged tensor (num_nodes, vocab_size) + // Copy draft tokens to GPU + int* p_draft_tokens_host = static_cast(draft_tokens_host->data); for (int i = 0; i < num_sequence; i++) { const std::vector& draft_output_tokens_i = draft_output_tokens[i]; - const std::vector& draft_output_prob_dist_i = draft_output_prob_dist[i]; int start = cum_verify_lengths[i]; int end = cum_verify_lengths[i + 1]; // start/end is the range of the sequence i in probs_on_device, which includes the prob dist // of the draft tokens and the last committed token ICHECK_EQ(draft_output_tokens_i.size() + 1, end - start); - ICHECK_EQ(draft_output_prob_dist_i.size() + 1, end - start); for (int j = 0; j < end - start - 1; j++) { - // Copy prob dist - ICHECK_EQ(draft_probs_device->dtype.bits, 32); - float* p_draft_probs = - static_cast(draft_probs_device->data) + - (j + start + 1) * - vocab_size_; // shift by one, q of the last committed token is undefined // Copy sampled token id - draft_output_prob_dist_i[j].CopyToBytes(p_draft_probs, vocab_size_ * sizeof(float)); - *(static_cast(draft_tokens_host->data) + j + start + 1) = - draft_output_tokens_i[j].sampled_token_id.first; + p_draft_tokens_host[start + j + 1] = draft_output_tokens_i[j].sampled_token_id.first; } } CopyArray(draft_tokens_host, draft_tokens_device, copy_stream_); @@ -258,7 +250,7 @@ class GPUSampler : public SamplerObj { SyncCopyStream(device_, compute_stream_, copy_stream_); - gpu_verify_draft_tokens_func_(draft_probs_device, draft_tokens_device, probs_on_device, + gpu_verify_draft_tokens_func_(draft_probs_on_device, draft_tokens_device, probs_on_device, token_tree_first_child_device, token_tree_next_sibling_device, uniform_samples_device, token_tree_parent_ptr_device); @@ -495,8 +487,15 @@ class GPUSampler : public SamplerObj { if (!need_top_p && !need_prob_values) { // - Short path: If top_p and prob values are not needed, we directly sample from multinomial. 
SyncCopyStream(device_, compute_stream_, copy_stream_); - sampled_token_ids_device = gpu_multinomial_from_uniform_func_( - probs_on_device, uniform_samples_device, sample_indices_device); + if (flashinfer_multinomial_sample_func_ != nullptr) { + sampled_token_ids_device = + sampled_token_ids_device_.CreateView({sample_indices_device->shape[0]}, dtype_i32_); + (*flashinfer_multinomial_sample_func_)(probs_on_device, uniform_samples_device, + sample_indices_device, sampled_token_ids_device); + } else { + sampled_token_ids_device = gpu_multinomial_from_uniform_func_( + probs_on_device, uniform_samples_device, sample_indices_device); + } return {sampled_token_ids_device, sampled_probs_device, top_prob_probs_device, top_prob_indices_device}; } @@ -531,8 +530,15 @@ class GPUSampler : public SamplerObj { uniform_samples_device, sample_indices_device, top_p_device); } else { // - Sample without top_p. - sampled_token_ids_device = gpu_multinomial_from_uniform_func_( - probs_on_device, uniform_samples_device, sample_indices_device); + if (flashinfer_multinomial_sample_func_ != nullptr) { + sampled_token_ids_device = + sampled_token_ids_device_.CreateView({sample_indices_device->shape[0]}, dtype_i32_); + (*flashinfer_multinomial_sample_func_)(probs_on_device, uniform_samples_device, + sample_indices_device, sampled_token_ids_device); + } else { + sampled_token_ids_device = gpu_multinomial_from_uniform_func_( + probs_on_device, uniform_samples_device, sample_indices_device); + } } if (need_prob_values) { @@ -604,6 +610,7 @@ class GPUSampler : public SamplerObj { PackedFunc gpu_sampler_take_probs_func_; PackedFunc gpu_verify_draft_tokens_func_; PackedFunc gpu_renormalize_by_top_p_func_; + const PackedFunc* flashinfer_multinomial_sample_func_; // Auxiliary NDArrays on CPU NDArray uniform_samples_host_; NDArray sample_indices_host_; @@ -622,11 +629,11 @@ class GPUSampler : public SamplerObj { NDArray sample_indices_device_; NDArray top_p_device_; NDArray top_prob_offsets_device_; - NDArray draft_probs_device_; NDArray draft_tokens_device_; NDArray token_tree_first_child_device_; NDArray token_tree_next_sibling_device_; NDArray token_tree_parent_ptr_device_; + NDArray sampled_token_ids_device_; // The event trace recorder for requests. Optional trace_recorder_; // The device stream for the default computation operations.
diff --git a/cpp/serve/sampler/sampler.h b/cpp/serve/sampler/sampler.h index 7943231e55..59e433ac47 100644 --- a/cpp/serve/sampler/sampler.h +++ b/cpp/serve/sampler/sampler.h @@ -108,15 +108,16 @@ class SamplerObj : public Object { * \param rngs The random number generator of each sequence. * \param draft_output_tokens The draft tokens generated by the small model for * each sequence. - * \param draft_output_prob_dist The probability distribution computed from the - * small model for each sequence. + * \param draft_probs_on_device The probability distribution computed from the + * small model for each sequence. Concatenated tensor of shape (total_verify_length, vocab_size). + * It includes the slot for the last committed token, which has an undefined probability value. * \return The list of accepted tokens for each request.
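 + * \note Rows are indexed by verify positions: row (cum_verify_lengths[i] + j + 1) holds the
 + * draft distribution of the j-th draft token of sequence i, and row cum_verify_lengths[i] is the
 + * undefined slot for that sequence's last committed token.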
*/ virtual std::vector> BatchVerifyDraftTokensWithProbAfterTopP( NDArray probs, const Array& request_ids, const std::vector& cum_verify_lengths, const Array& generation_cfg, const std::vector& rngs, const std::vector>& draft_output_tokens, - const std::vector>& draft_output_prob_dist) = 0; + NDArray draft_probs_on_device) = 0; static constexpr const char* _type_key = "mlc.serve.Sampler"; static constexpr const bool _type_has_method_sequal_reduce = false; diff --git a/cpp/serve/threaded_engine.cc b/cpp/serve/threaded_engine.cc index f234dfbbc3..080853d465 100644 --- a/cpp/serve/threaded_engine.cc +++ b/cpp/serve/threaded_engine.cc @@ -36,8 +36,9 @@ enum class InstructionKind : int { /*! \brief The implementation of ThreadedEngine. */ class ThreadedEngineImpl : public ThreadedEngine { public: - void InitBackgroundEngine(Optional request_stream_callback, + void InitBackgroundEngine(Device device, Optional request_stream_callback, Optional trace_recorder) final { + device_ = device; CHECK(request_stream_callback.defined()) << "ThreadedEngine requires request stream callback function, but it is not given."; request_stream_callback_ = request_stream_callback.value(); @@ -213,6 +214,11 @@ class ThreadedEngineImpl : public ThreadedEngine { } } + String Stats() final { + std::lock_guard lock(background_loop_mutex_); + return background_engine_->Stats(); + } + private: void EngineReloadImpl(EngineConfig engine_config) { auto frequest_stream_callback_wrapper = [this](TVMArgs args, TVMRetValue* ret) { @@ -231,7 +237,7 @@ class ThreadedEngineImpl : public ThreadedEngine { }; Optional request_stream_callback = PackedFunc(frequest_stream_callback_wrapper); - background_engine_ = Engine::Create(std::move(engine_config), + background_engine_ = Engine::Create(std::move(engine_config), device_, std::move(request_stream_callback), trace_recorder_); } @@ -247,6 +253,8 @@ class ThreadedEngineImpl : public ThreadedEngine { } } + /*! \brief The device to run models on. */ + Device device_; /*! \brief The background normal engine for request processing. */ std::unique_ptr background_engine_; /*! \brief The request stream callback. */ @@ -311,6 +319,7 @@ class ThreadedEngineModule : public ThreadedEngineImpl, public ModuleNode { TVM_MODULE_VTABLE_ENTRY("exit_background_loop", &ThreadedEngineImpl::ExitBackgroundLoop); TVM_MODULE_VTABLE_ENTRY("debug_call_func_on_all_worker", &ThreadedEngineImpl::DebugCallFuncOnAllAllWorker); + TVM_MODULE_VTABLE_ENTRY("stats", &ThreadedEngineImpl::Stats); TVM_MODULE_VTABLE_END(); }; diff --git a/cpp/serve/threaded_engine.h b/cpp/serve/threaded_engine.h index f3d9c2b70c..d0f2ebe2d7 100644 --- a/cpp/serve/threaded_engine.h +++ b/cpp/serve/threaded_engine.h @@ -35,10 +35,11 @@ class ThreadedEngine { /*! * \brief Initialize the threaded engine from packed arguments in TVMArgs. + * \param device The device where to run models. * \param request_stream_callback The request stream callback function to. * \param trace_recorder Event trace recorder for requests. */ - virtual void InitBackgroundEngine(Optional request_stream_callback, + virtual void InitBackgroundEngine(Device device, Optional request_stream_callback, Optional trace_recorder) = 0; /*! @@ -76,6 +77,9 @@ class ThreadedEngine { /*! \brief Call the given global function on all workers. Only for debug purpose. */ virtual void DebugCallFuncOnAllAllWorker(const String& func_name) = 0; + + /*! \brief Print the statistics of the engine. 
*/ + virtual String Stats() = 0; }; } // namespace serve diff --git a/cpp/support/encoding.cc b/cpp/support/encoding.cc index 0509c1eb2a..d9420bbbd5 100644 --- a/cpp/support/encoding.cc +++ b/cpp/support/encoding.cc @@ -11,7 +11,7 @@ namespace mlc { namespace llm { -std::string CodepointToUtf8(TCodepoint codepoint) { +std::string PrintAsUTF8(TCodepoint codepoint) { ICHECK(codepoint <= 0x10FFFF) << "Invalid codepoint: " << codepoint; std::string utf8; if (codepoint <= 0x7F) { @@ -36,8 +36,8 @@ std::string CodepointToUtf8(TCodepoint codepoint) { return utf8; } -std::string CodepointToPrintable( - TCodepoint codepoint, const std::unordered_map& custom_escape_map) { +std::string PrintAsEscaped(TCodepoint codepoint, + const std::unordered_map& custom_escape_map) { static const std::unordered_map kCodepointToEscape = { {'\'', "\\\'"}, {'\"', "\\\""}, {'\?', "\\\?"}, {'\\', "\\\\"}, {'\a', "\\a"}, {'\b', "\\b"}, {'\f', "\\f"}, {'\n', "\\n"}, {'\r', "\\r"}, {'\t', "\\t"}, @@ -63,10 +63,10 @@ std::string CodepointToPrintable( return codepoint <= 0xFFFF ? "\\u" + hex : "\\U" + hex; } -std::pair Utf8ToCodepoint(const char* utf8) { - const std::array kFirstByteMask = {0x00, 0x7F, 0x1F, 0x0F, 0x07}; +std::pair ParseNextUTF8(const char* utf8) { + static const std::array kFirstByteMask = {0x00, 0x7F, 0x1F, 0x0F, 0x07}; // clang-format off - const std::array kUtf8Bytes = { + static const std::array kUtf8Bytes = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, @@ -89,7 +89,7 @@ std::pair Utf8ToCodepoint(const char* utf8) { auto bytes = kUtf8Bytes[static_cast(utf8[0])]; if (bytes == -1) { // invalid utf8 - return {static_cast(CharHandlingError::kInvalidUtf8), 0}; + return {static_cast(CharHandlingError::kInvalidUtf8), utf8}; } TCodepoint res = static_cast(utf8[0]) & kFirstByteMask[bytes]; @@ -100,23 +100,23 @@ std::pair Utf8ToCodepoint(const char* utf8) { } res = (res << 6) | (static_cast(utf8[i]) & 0x3F); } - return {res, bytes}; + return {res, utf8 + bytes}; } -std::vector Utf8StringToCodepoints(const char* utf8) { +std::vector ParseUTF8(const char* utf8) { std::vector codepoints; while (*utf8 != 0) { - auto [codepoint, bytes] = Utf8ToCodepoint(utf8); + TCodepoint codepoint; + std::tie(codepoint, utf8) = ParseNextUTF8(utf8); if (codepoint == static_cast(CharHandlingError::kInvalidUtf8)) { return {codepoint}; } codepoints.push_back(codepoint); - utf8 += bytes; } return codepoints; } -int HexCharToInt(char c) { +inline int HexCharToInt(char c) { if (c >= '0' && c <= '9') { return c - '0'; } else if (c >= 'a' && c <= 'f') { @@ -128,22 +128,22 @@ int HexCharToInt(char c) { } } -std::pair Utf8OrEscapeToCodepoint( +std::pair ParseNextUTF8OrEscaped( const char* utf8, const std::unordered_map& custom_escape_map) { static const std::unordered_map kEscapeToCodepoint = { {"\\\'", '\''}, {"\\\"", '\"'}, {"\\\?", '\?'}, {"\\\\", '\\'}, {"\\a", '\a'}, {"\\b", '\b'}, {"\\f", '\f'}, {"\\n", '\n'}, {"\\r", '\r'}, {"\\t", '\t'}, {"\\v", '\v'}, {"\\0", '\0'}, {"\\e", '\x1B'}}; if (utf8[0] != '\\') { - return Utf8ToCodepoint(utf8); + return ParseNextUTF8(utf8); } auto escape_sequence = std::string(utf8, 2); if (auto it = custom_escape_map.find(escape_sequence); it != custom_escape_map.end()) { - return {it->second, 2}; + return {it->second, utf8 + 2}; } if (auto it = kEscapeToCodepoint.find(escape_sequence); it != kEscapeToCodepoint.end()) { - return {it->second, 2}; + return {it->second, utf8 + 2}; } if (utf8[1] == 'x') { @@ 
-159,9 +159,9 @@ std::pair Utf8OrEscapeToCodepoint( ++len; } if (len == 0) { - return {static_cast(CharHandlingError::kInvalidEscape), 0}; + return {static_cast(CharHandlingError::kInvalidEscape), utf8}; } - return {codepoint, len + 2}; + return {codepoint, utf8 + len + 2}; } else if (utf8[1] == 'u' || utf8[1] == 'U') { // 4- or 8-digit hex int len = utf8[1] == 'u' ? 4 : 8; @@ -170,13 +170,13 @@ std::pair Utf8OrEscapeToCodepoint( for (int i = 0; i < len; ++i) { auto digit = HexCharToInt(utf8[i + 2]); if (digit == -1) { - return {static_cast(CharHandlingError::kInvalidEscape), 0}; + return {static_cast(CharHandlingError::kInvalidEscape), utf8}; } codepoint = codepoint * 16 + digit; } - return {codepoint, len + 2}; + return {codepoint, utf8 + len + 2}; } else { - return {static_cast(CharHandlingError::kInvalidEscape), 0}; + return {static_cast(CharHandlingError::kInvalidEscape), utf8}; } } diff --git a/cpp/support/encoding.h b/cpp/support/encoding.h index f28aae6d74..790040e97e 100644 --- a/cpp/support/encoding.h +++ b/cpp/support/encoding.h @@ -21,7 +21,7 @@ using TCodepoint = int32_t; * \param codepoint The codepoint. * \return The UTF-8 string. */ -std::string CodepointToUtf8(TCodepoint codepoint); +std::string PrintAsUTF8(TCodepoint codepoint); /*! * \brief Convert a codepoint to a printable string. If the codepoint is not printable, it will be @@ -29,10 +29,10 @@ std::string CodepointToUtf8(TCodepoint codepoint); * specify more escape sequences using custom_escape_map. * \param codepoint The codepoint. * \param custom_escape_map A map from codepoint to escape sequence. If the codepoint is in the map, - * it will be escaped using the corresponding escape sequence. e.g. {'-', "\\-"}. + * it will be escaped using the corresponding escape sequence. e.g. {{'-', "\\-"}}. * \return The printable string. */ -std::string CodepointToPrintable( +std::string PrintAsEscaped( TCodepoint codepoint, const std::unordered_map& custom_escape_map = {}); @@ -53,9 +53,9 @@ enum class CharHandlingError : TCodepoint { * \return The codepoint and the number of bytes consumed. If the UTF-8 string is invalid, the * function returns (CharHandlingError::kInvalidUtf8, 0). */ -std::pair Utf8ToCodepoint(const char* utf8); +std::pair ParseNextUTF8(const char* utf8); -std::vector Utf8StringToCodepoints(const char* utf8); +std::vector ParseUTF8(const char* utf8); /*! * \brief Convert a UTF-8 string or an escape sequence to a codepoint. By default the function @@ -63,12 +63,12 @@ std::vector Utf8StringToCodepoints(const char* utf8); * using custom_escape_map. * \param utf8 The UTF-8 string or the escape sequence. * \param custom_escape_map A map from escape sequence to codepoint. If the escape sequence is in - * the map, it will be converted to the corresponding codepoint. e.g. {"\\-", '-'}. + * the map, it will be converted to the corresponding codepoint. e.g. {{"\\-", '-'}}. * \return The codepoint and the number of bytes consumed. If the UTF-8 string or the escape * sequence is invalid, the function returns * (CharHandlingError::kInvalidUtf8 or CharHandlingError::kInvalidEscape, 0). 
*/ -std::pair Utf8OrEscapeToCodepoint( +std::pair ParseNextUTF8OrEscaped( const char* utf8, const std::unordered_map& custom_escape_map = {}); } // namespace llm diff --git a/cpp/tokenizers.cc b/cpp/tokenizers.cc index ef866f3bfc..6fe9217520 100644 --- a/cpp/tokenizers.cc +++ b/cpp/tokenizers.cc @@ -9,10 +9,12 @@ #include #include +#include #include #include #include +#include "./support/encoding.h" #include "./support/load_bytes_from_file.h" namespace mlc { @@ -91,13 +93,8 @@ Tokenizer Tokenizer::FromPath(const String& _path) { LOG(FATAL) << "Cannot find any tokenizer under: " << _path; } -/*! - * \brief Post-process a raw token (which may be a raw byte or contain lower - * one eights block) to the actual token. - * We do this in order to conform with the tokenizers' setup. - */ -inline std::string PostProcessToken(std::string token) { - // 1. The token represents a byte. +/*! \brief ByteFallback decoder: transform tokens like <0x1B> to hex char byte 1B */ +inline std::string ByteFallbackDecoder(const std::string& token) { if (token.length() == 6 && token.substr(0, 3) == "<0x" && token.back() == '>') { int byte = 0; for (int i = 0; i < 2; ++i) { @@ -108,15 +105,82 @@ inline std::string PostProcessToken(std::string token) { ICHECK(byte >= 0 && byte < 256); return std::string(/*n=*/1, static_cast(byte)); } + return token; +} - // 2. The token contains "\u2581" which means space. - static const std::string& lower_one_eighth_block = "\u2581"; - size_t pos = token.find(lower_one_eighth_block); - while (pos != std::string::npos) { - token.replace(pos, /*n=*/lower_one_eighth_block.length(), /*str=*/" "); - pos = token.find(lower_one_eighth_block); +/*! \brief SpaceReplacer decoder: transform "\u2581" back to space */ +inline std::string SpaceReplacerDecoder(const std::string& token) { + // \u2581 is the unicode for "lower one eighth block" + // UTF8 encoding for \u2581 is 0xE2 0x96 0x81 + std::string result; + for (size_t i = 0; i < token.size(); ++i) { + if (i + 2 < token.size() && token[i] == char(0xE2) && token[i + 1] == char(0x96) && + token[i + 2] == char(0x81)) { + result += ' '; + i += 2; + } else { + result += token[i]; + } + } + return result; +} + +/*! \brief ByteLevel decoder: inverses the bytes-to-unicode transformation in the encoding + * process as in + * https://github.com/huggingface/transformers/blob/87be06ca77166e6a6215eee5a990ab9f07238a18/src/transformers/models/gpt2/tokenization_gpt2.py#L38-L59 + */ +inline std::string ByteLevelDecoder(const std::string& token) { + // clang-format off + // The inverse map of bytes_to_unicode. -1 means there is no mapping to this unicode. 
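+ // Bytes that are printable per the GPT-2 byte-to-unicode table map to their own codepoints;
+ // the remaining 68 non-printable bytes were shifted to codepoints 256 and above during
+ // encoding, which is why this table extends past index 255.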
+ static const std::array unicode_to_byte_map = { + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, + 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, + 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, + 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, + 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, -1, + 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, + 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, + 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, + 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, + 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 127, 128, + 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, + 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 173 + }; + // clang-format on + + auto unicode_codepoints = ParseUTF8(token.c_str()); + std::string decoded; + + for (auto unicode_codepoint : unicode_codepoints) { + ICHECK(unicode_codepoint >= 0 && + unicode_codepoint < static_cast(unicode_to_byte_map.size())); + int byte = unicode_to_byte_map[unicode_codepoint]; + if (byte == -1) { + // If there is no mapping, add the codepoint itself to the result string + // Some tokenizer like Phi-2 have raw tokens like \t\t + decoded += static_cast(unicode_codepoint); + } else { + decoded += static_cast(byte); + } + } + return decoded; +} + +/*! + * \brief Post-process a raw token to the actual token with the given post-processing method. 
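+ * \note "byte_fallback" applies ByteFallbackDecoder followed by SpaceReplacerDecoder;
+ * "byte_level" applies ByteLevelDecoder.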
+ */ +inline std::string PostProcessToken(const std::string& token, const std::string& postproc_method) { + if (postproc_method == "byte_fallback") { + return SpaceReplacerDecoder(ByteFallbackDecoder(token)); + } else if (postproc_method == "byte_level") { + return ByteLevelDecoder(token); + } else { + LOG(FATAL) << "Unknown post-processing method: " << postproc_method; } - return token; } const std::vector& TokenizerObj::TokenTable() { @@ -127,12 +191,21 @@ const std::vector& TokenizerObj::TokenTable() { int vocab_size = tokenizer->GetVocabSize(); token_table_.reserve(vocab_size); for (int32_t token_id = 0; token_id < vocab_size; ++token_id) { - std::string token = tokenizer->IdToToken(token_id); - token_table_.push_back(PostProcessToken(token)); + token_table_.push_back(tokenizer->IdToToken(token_id)); } return token_table_; } +std::vector Tokenizer::PostProcessTokenTable( + const std::vector& token_table, const std::string& postproc_method) { + std::vector postprocessed_token_table; + postprocessed_token_table.reserve(token_table.size()); + for (const std::string& token : token_table) { + postprocessed_token_table.push_back(PostProcessToken(token, postproc_method)); + } + return postprocessed_token_table; +} + TVM_REGISTER_GLOBAL("mlc.Tokenizer").set_body_typed([](const String& path) { return Tokenizer::FromPath(path); });
diff --git a/cpp/tokenizers.h b/cpp/tokenizers.h index 16d9ba456b..36fc0c23db 100644 --- a/cpp/tokenizers.h +++ b/cpp/tokenizers.h @@ -30,7 +30,7 @@ class TokenizerObj : public Object { std::vector Encode(const std::string& text) const; /*! \brief Decode token ids into text. */ std::string Decode(const std::vector& token_ids) const; - /*! \brief Return the token table of the tokenizer. */ + /*! \brief Return the token table of the tokenizer. Special tokens are included. */ const std::vector& TokenTable(); /*! @@ -64,6 +64,25 @@ class Tokenizer : public ObjectRef { /*! \brief Create a tokenizer from a directory path on disk. */ MLC_LLM_DLL static Tokenizer FromPath(const String& path); + /*! + * \brief Convert raw tokens provided by the tokenizer to their original strings to simplify + * later processing. E.g. for LLaMA-2, convert "▁of" to " of". + * + * \param token_table The raw token table. + * \param postproc_method The postprocessing method to use. Currently we only support + * "byte_fallback" and "byte_level", which refer to the type of the tokenizer's decoder. + * - "byte_fallback": Use the decoding method of the byte-fallback BPE tokenizer. This is used + * by LLaMA-2, Mixtral-7b, etc. This method 1) transforms tokens like <0x1B> to the hex char + * byte 1B (the byte-fallback step), and 2) transforms \\u2581 to a space. + * - "byte_level": Use the decoding method of the byte-level BPE tokenizer. This is used by + * LLaMA-3, GPT-2, Phi-2, etc. This method inverts the bytes-to-unicode transformation applied + * during encoding, as in + * https://github.com/huggingface/transformers/blob/87be06ca77166e6a6215eee5a990ab9f07238a18/src/transformers/models/gpt2/tokenization_gpt2.py#L38-L59 + * \returns The postprocessed token table containing the original strings.
+ */ + static std::vector PostProcessTokenTable(const std::vector& token_table, + const std::string& postproc_method); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Tokenizer, ObjectRef, TokenizerObj); private: diff --git a/python/mlc_llm/cli/lib_delivery.py b/python/mlc_llm/cli/lib_delivery.py new file mode 100644 index 0000000000..a5d678fbe2 --- /dev/null +++ b/python/mlc_llm/cli/lib_delivery.py @@ -0,0 +1,200 @@ +"""Continuous model delivery for MLC LLM models.""" + +import argparse +import dataclasses +import json +import os +import shutil +import subprocess +import sys +import tempfile +from pathlib import Path +from typing import Any, Callable, Dict, List + +from mlc_llm.support import logging +from mlc_llm.support.argparse import ArgumentParser +from mlc_llm.support.constants import MLC_TEMP_DIR +from mlc_llm.support.style import bold, green, red + +logging.enable_logging() +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class ModelInfo: # pylint: disable=too-many-instance-attributes + """Necessary information for the model delivery""" + + model_id: str + model: Path + quantization: str + device: str + # overrides the `context_window_size`, `prefill_chunk_size`, + # `sliding_window_size`, `attention_sink_size`, `max_batch_size` + # and `tensor_parallel_shards in mlc-chat-config.json + overrides: Dict[str, int] + + +class DeferredScope: + """A context manager that defers execution of functions until exiting the scope.""" + + def __init__(self): + self.deferred_functions = [] + + def add(self, func: Callable[[], None]): + """Add a function to be executed when exiting the scope.""" + self.deferred_functions.append(func) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + for func in reversed(self.deferred_functions): + func() + return False + + def create_temp_dir(self) -> Path: + """Create a temporary directory that will be deleted when exiting the scope.""" + temp_dir = tempfile.mkdtemp(dir=MLC_TEMP_DIR) + self.add(lambda: shutil.rmtree(temp_dir, ignore_errors=True)) + return Path(temp_dir) + + +def _run_compilation(model_info: ModelInfo, repo_dir: Path) -> bool: + """Run the compilation of the model library.""" + + def get_lib_ext(device: str) -> str: + if device in ["cuda", "vulkan", "metal"]: + return ".so" + if device in ["android", "ios"]: + return ".tar" + if device in ["webgpu"]: + return ".wasm" + + return "" + + succeeded = True + with tempfile.TemporaryDirectory(dir=MLC_TEMP_DIR) as temp_dir: + log_path = Path(temp_dir) / "logs.txt" + model_lib_name = f"{model_info.model_id}-{model_info.quantization}-{model_info.device}" + lib_ext = get_lib_ext(model_info.device) + if lib_ext == "": + raise ValueError(f"Unsupported device: {model_info.device}") + model_lib_name += lib_ext + with log_path.open("a", encoding="utf-8") as log_file: + overrides = ";".join(f"{key}={value}" for key, value in model_info.overrides.items()) + cmd = [ + sys.executable, + "-m", + "mlc_llm", + "compile", + str(model_info.model), + "--device", + model_info.device, + "--quantization", + model_info.quantization, + "--overrides", + overrides, + "--output", + os.path.join(temp_dir, model_lib_name), + ] + print(" ".join(cmd), file=log_file, flush=True) + subprocess.run(cmd, check=True, stdout=log_file, stderr=subprocess.STDOUT) + logger.info("[MLC] Compilation Complete!") + if not (Path(temp_dir) / model_lib_name).exists(): + logger.error( + "[%s] Model %s. Device %s. 
No compiled library found.", + red("FAILED"), + model_info.model_id, + model_info.device, + ) + succeeded = False + return succeeded + + # overwrite git repo file with the compiled library + repo_filepath = repo_dir / model_info.model_id / model_lib_name + if not repo_filepath.parent.exists(): + repo_filepath.parent.mkdir(parents=True, exist_ok=True) + # copy lib from Path(temp_dir) / model_lib_name to repo_filepath + shutil.copy(Path(temp_dir) / model_lib_name, repo_filepath) + logger.info("Saved library %s at %s", model_lib_name, repo_filepath) + return succeeded + + +def _main( # pylint: disable=too-many-locals + spec: Dict[str, Any], +): + """Compile the model libs in the spec and save them to the binary_libs_dir.""" + failed_cases: List[Any] = [] + for task_index, task in enumerate(spec["tasks"], 1): + logger.info( + bold("[{task_index}/{total_tasks}] Processing model: ").format( + task_index=task_index, + total_tasks=len(spec["tasks"]), + ) + + green(task["model_id"]) + ) + model_info = { + "model_id": task["model_id"], + "model": task["model"], + } + for compile_opt in spec["default_compile_options"] + task.get("compile_options", []): + for quantization in spec["default_quantization"] + task.get("quantization", []): + model_info["quantization"] = quantization + model_info["device"] = compile_opt["device"] + model_info["overrides"] = compile_opt.get("overrides", {}) + logger.info( + "[Config] " + + bold("model_id: ") + + model_info["model_id"] + + bold(", quantization: ") + + model_info["quantization"] + + bold(", device: ") + + model_info["device"] + + bold(", overrides: ") + + json.dumps(model_info["overrides"]) + ) + + result = _run_compilation( + ModelInfo(**model_info), + repo_dir=Path(spec["binary_libs_dir"]), + ) + if not result: + failed_cases.append(model_info) + + if failed_cases: + logger.info("Total %s %s:", len(failed_cases), red("failures")) + for case in failed_cases: + logger.info( + "model_id %s, quantization %s, device %s, overrides %s", + case["model_id"], + case["quantization"], + case["device"], + json.dumps(case["overrides"]), + ) + + +def main(): + """Entry point.""" + + def _load_spec(path_spec: str) -> Dict[str, Any]: + path = Path(path_spec) + if not path.exists(): + raise argparse.ArgumentTypeError(f"Spec file does not exist: {path}") + with path.open("r", encoding="utf-8") as i_f: + return json.load(i_f) + + parser = ArgumentParser("MLC LLM continuous library delivery") + parser.add_argument( + "--spec", + type=_load_spec, + required=True, + help="Path to the spec file", + ) + parsed = parser.parse_args() + _main( + spec=parsed.spec, + ) + + +if __name__ == "__main__": + main()
diff --git a/python/mlc_llm/compiler_pass/attach_spec_decode_aux_funcs.py b/python/mlc_llm/compiler_pass/attach_spec_decode_aux_funcs.py new file mode 100644 index 0000000000..b7cfd76fa3 --- /dev/null +++ b/python/mlc_llm/compiler_pass/attach_spec_decode_aux_funcs.py @@ -0,0 +1,66 @@ +"""The pass that attaches auxiliary TIR functions for speculative decoding to the IRModule.""" + +import tvm +from tvm import IRModule +from tvm.script import tir as T + + +@tvm.transform.module_pass(opt_level=0, name="AttachSpecDecodeAuxFuncs") +class AttachSpecDecodeAuxFuncs: # pylint: disable=too-few-public-methods + """Attach gather/scatter TIR functions for speculative decoding to IRModule.""" + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """Entrypoint""" + mod = mod.clone() + mod["scatter_probs"] = _get_scatter_2d_inplace( + dtype="float32", global_symbol="scatter_probs" + ) +
mod["gather_probs"] = _get_gather_2d_inplace(dtype="float32", global_symbol="gather_probs") + if "prefill_to_last_hidden_states" in mod: + hidden_states_struct_info = mod["prefill_to_last_hidden_states"].ret_struct_info.fields[ + 0 + ] # pylint: disable=no-member + dtype = hidden_states_struct_info.dtype + mod["scatter_hidden_states"] = _get_scatter_2d_inplace( + dtype, global_symbol="scatter_hidden_states" + ) + mod["gather_hidden_states"] = _get_gather_2d_inplace( + dtype, global_symbol="gather_hidden_states" + ) + return mod + + +def _get_scatter_2d_inplace(dtype: str, global_symbol: str): + @T.prim_func + def _scatter_2d(var_src: T.handle, var_indices: T.handle, var_dst: T.handle): + T.func_attr({"global_symbol": global_symbol, "tir.noalias": True}) + batch_size = T.int32(is_size_var=True) + m = T.int32(is_size_var=True) + n = T.int32(is_size_var=True) + src = T.match_buffer(var_src, (batch_size, n), dtype) + indices = T.match_buffer(var_indices, (batch_size,), "int32") + dst = T.match_buffer(var_dst, (m, n), dtype) + for b, j in T.grid(batch_size, n): + with T.block("scatter_2d"): + vb, vj = T.axis.remap("SS", [b, j]) + dst[indices[vb], vj] = src[vb, vj] + + return _scatter_2d + + +def _get_gather_2d_inplace(dtype: str, global_symbol: str): + @T.prim_func + def _gather_2d(var_src: T.handle, var_indices: T.handle, var_dst: T.handle): + T.func_attr({"global_symbol": global_symbol, "tir.noalias": True}) + batch_size = T.int32(is_size_var=True) + m = T.int32(is_size_var=True) + n = T.int32(is_size_var=True) + src = T.match_buffer(var_src, (m, n), dtype) + indices = T.match_buffer(var_indices, (batch_size,), "int32") + dst = T.match_buffer(var_dst, (batch_size, n), dtype) + for b, j in T.grid(batch_size, n): + with T.block("gather_2d"): + vb, vj = T.axis.remap("SS", [b, j]) + dst[vb, vj] = src[indices[vb], vj] + + return _gather_2d diff --git a/python/mlc_llm/compiler_pass/pipeline.py b/python/mlc_llm/compiler_pass/pipeline.py index 57b68f742d..3c80d2c4df 100644 --- a/python/mlc_llm/compiler_pass/pipeline.py +++ b/python/mlc_llm/compiler_pass/pipeline.py @@ -15,6 +15,7 @@ from .attach_embedding_allocator import AttachAllocEmbeddingTensorFunc from .attach_logit_processor import AttachLogitProcessFunc from .attach_sampler import AttachGPUSamplingFunc +from .attach_spec_decode_aux_funcs import AttachSpecDecodeAuxFuncs from .attach_support_info import ( AttachAdditionalPrimFuncs, AttachCUDAGraphSymbolicCaptureHints, @@ -104,6 +105,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I AttachAdditionalPrimFuncs(additional_tirs), AttachAllocEmbeddingTensorFunc(metadata), AttachGPUSamplingFunc(target, variable_bounds), + AttachSpecDecodeAuxFuncs(), AttachMemoryPlanAttr(), tvm.tir.transform.BindTarget(tvm.target.Target.current(allow_none=False)), _DebugDump("debug-phase0.py", debug_dump, show_meta=False), diff --git a/python/mlc_llm/compiler_pass/rewrite_softmax.py b/python/mlc_llm/compiler_pass/rewrite_softmax.py index 1a6e41eafc..82e6cf863b 100644 --- a/python/mlc_llm/compiler_pass/rewrite_softmax.py +++ b/python/mlc_llm/compiler_pass/rewrite_softmax.py @@ -34,7 +34,10 @@ def __init__(self, mod: IRModule, target: tvm.target.Target) -> None: def transform(self) -> IRModule: """Entry point""" - gv = self.mod.get_global_var("softmax_with_temperature") + func_name = "softmax_with_temperature" + if func_name not in self.mod: + return self.mod + gv = self.mod.get_global_var(func_name) updated_func = self.visit_expr(self.mod[gv]) self.builder_.update_func(gv, 
updated_func) return self.builder_.get() diff --git a/python/mlc_llm/interface/compile.py b/python/mlc_llm/interface/compile.py index 4e8bcabd9e..7be9dadd39 100644 --- a/python/mlc_llm/interface/compile.py +++ b/python/mlc_llm/interface/compile.py @@ -1,4 +1,5 @@ """Python entrypoint of compilation.""" + import dataclasses import math from io import StringIO @@ -162,7 +163,11 @@ def _find_kv_cache_bytes(model: nn.Module, model_config) -> int: logger.info("Running optimizations using TVM Unity") additional_tirs = _apply_preproc_to_params(named_params, model_config) variable_bounds = _get_variable_bounds(model_config) - cuda_graph_symbolic_capture_hints = {"batch_decode": ["batch_size"]} + cuda_graph_symbolic_capture_hints = { + "batch_decode": ["batch_size"], + "batch_decode_to_last_hidden_states": ["batch_size"], + "batch_verify_to_last_hidden_states": ["batch_size", "seq_len"], + } metadata = { "model_type": args.model.name, "quantization": args.quantization.name, diff --git a/python/mlc_llm/interface/gen_config.py b/python/mlc_llm/interface/gen_config.py index 8e617fc3d2..13f0e1215f 100644 --- a/python/mlc_llm/interface/gen_config.py +++ b/python/mlc_llm/interface/gen_config.py @@ -5,7 +5,7 @@ import re import shutil from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Union from mlc_llm.conversation_template import ConvTemplateRegistry from mlc_llm.model import Model @@ -51,7 +51,11 @@ class MLCChatConfig: # pylint: disable=too-many-instance-attributes pad_token_id: int = None bos_token_id: int = None eos_token_id: int = None + # Tokenizer configuration tokenizer_files: List[str] = dataclasses.field(default_factory=list) + # The method to post-process the token table. See + # cpp/tokenizers.h::Tokenizer::PostProcessTokenTable for details + token_table_postproc_method: Literal["byte_fallback", "byte_level"] = None # Version control version: str = VERSION @@ -129,6 +133,70 @@ def json2rwkv_tokenizer(vocab: Path, out: Path) -> None: msgpack.pack(idx2token, f) +def detect_token_table_postproc_method(output_path: Path) -> Literal["byte_fallback", "byte_level"]: + """Detect the token table postprocessing method from tokenizer.json that is found under + output_path. If not detected, use ByteFallback as default. + + Check the decoder field of the tokenizer. If it uses ByteFallback decoder, return + "byte_fallback". If it uses ByteLevel decoder, return "byte_level". Otherwise, use + ByteFallback as default. + + See also cpp/tokenizers.h::Tokenizer::PostProcessTokenTable. 
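+
+ For example (illustrative snippet, not a complete tokenizer.json): a decoder field of
+ {"type": "Sequence", "decoders": [{"type": "Replace"}, {"type": "ByteFallback"}, {"type": "Fuse"}]}
+ is detected as "byte_fallback", while a decoder field of {"type": "ByteLevel"} is detected as
+ "byte_level".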
+ """ + output_tokenizer_path = output_path / "tokenizer.json" + if not output_tokenizer_path.exists(): + logger.warning( + "Tokenizer token table postprocessing method is not detected as tokenizer.json " + "is not found, use ByteFallback (the same as LLaMA/LLaMA2) by default" + ) + return "byte_fallback" + + with output_tokenizer_path.open("r", encoding="utf-8") as in_file: + tokenizer_json = json.load(in_file) + + # Find all decoders in tokenizer.json + decoders = [] + + if "decoder" not in tokenizer_json: + logger.warning( + "Decoder field is not found in tokenizer.json, use ByteFallback (the same as " + "LLaMA/LLaMA2) as the token table postprocessing method by default" + ) + return "byte_fallback" + + decoders_json = tokenizer_json["decoder"] + assert "type" in decoders_json, "Decoder type is not specified in tokenizer.json" + if decoders_json["type"] == "Sequence": + assert "decoders" in decoders_json + decoders = decoders_json["decoders"] + else: + decoders = [decoders_json] + + is_byte_level = False + is_byte_fallback = False + + for decoder in decoders: + if decoder["type"] == "ByteLevel": + is_byte_level = True + if decoder["type"] == "ByteFallback": + is_byte_fallback = True + assert not ( + is_byte_level and is_byte_fallback + ), "Tokenizer decoder cannot have both type ByteLevel and type ByteFallback" + + if is_byte_level: + return "byte_level" + if is_byte_fallback: + return "byte_fallback" + + logger.warning( + "Neither ByteLevel nor ByteFallback decoder is detected in tokenizer.json, use " + "ByteFallback (the same as LLaMA/LLaMA2) as the token table postprocessing method " + "by default" + ) + return "byte_fallback" + + def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-branches,too-many-statements config: Path, model: Model, @@ -255,6 +323,10 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b except Exception: # pylint: disable=broad-exception-caught logger.exception("%s with the exception below. Skipping", FAILED) + # 3.4. Find the token table postprocessing method from tokenizer.json if it exists. If not + # detected, use "byte_fallback" as default. + mlc_chat_config.token_table_postproc_method = detect_token_table_postproc_method(output) + # Step 4. Load system default value mlc_chat_config.apply_defaults() # Step 5. Dump the configuration file to output directory diff --git a/python/mlc_llm/json_ffi/__init__.py b/python/mlc_llm/json_ffi/__init__.py new file mode 100644 index 0000000000..8a7059153d --- /dev/null +++ b/python/mlc_llm/json_ffi/__init__.py @@ -0,0 +1,8 @@ +"""JSON FFI is a pure string based interface of MLC LLM Engine. + +We build interfacing with JSON FFI for both testing purposes +and internal use. 
For most python API usage, please use MLCEngine +and MLCAsyncEngine +""" + +from .engine import JSONFFIEngine diff --git a/python/mlc_llm/json_ffi/engine.py b/python/mlc_llm/json_ffi/engine.py new file mode 100644 index 0000000000..0c604a2ef3 --- /dev/null +++ b/python/mlc_llm/json_ffi/engine.py @@ -0,0 +1,310 @@ +# pylint: disable=chained-comparison,missing-docstring,too-few-public-methods,too-many-instance-attributes +# pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable +import json +import queue +import threading +from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Union + +import tvm + +from mlc_llm.protocol import openai_api_protocol +from mlc_llm.serve import engine_utils +from mlc_llm.serve.engine_base import ( + EngineConfig, + SpeculativeMode, + _infer_kv_cache_config, + _parse_models, + _process_model_args, + detect_device, +) +from mlc_llm.tokenizer import Tokenizer + + +# TODO(mlc-team): further minimize the JSONFFIEngine +# construction to not depend on any config and directly pass in JSON +# model defined generation config should be read from the JSONFFIEngine via Reload +def create_model_defined_generation_config( + temperature: float, top_p: float, frequency_penalty: float, presence_penalty: float +) -> tvm.runtime.Object: + return tvm.get_global_func("mlc.json_ffi.ModelDefinedGenerationConfig")( + temperature, + top_p, + frequency_penalty, + presence_penalty, + ) + + +# TODO(mlc-team): further minimize the JSONFFIEngine +# Engine config should be passed as json str +# and backend should have good default +# only model and model_lib should be mandatory +def create_json_ffi_engine_config( + conv_template: str, model_generation_cfgs: Dict[str, tvm.runtime.Object] +) -> tvm.runtime.Object: + return tvm.get_global_func("mlc.json_ffi.JSONFFIEngineConfig")( + conv_template, model_generation_cfgs + ) + + +class EngineState: + sync_queue: queue.Queue + + def get_request_stream_callback(self) -> Callable[[List[str]], None]: + # ChatCompletionStreamResponse + + def _callback(chat_completion_stream_responses_json_str: List[str]) -> None: + self._sync_request_stream_callback(chat_completion_stream_responses_json_str) + + return _callback + + def _sync_request_stream_callback( + self, chat_completion_stream_responses_json_str: List[str] + ) -> None: + # Put the delta outputs to the queue in the unblocking way. + self.sync_queue.put_nowait(chat_completion_stream_responses_json_str) + + +class JSONFFIEngine: + def __init__( # pylint: disable=too-many-arguments,too-many-locals + self, + model: str, + device: Union[str, tvm.runtime.Device] = "auto", + *, + model_lib_path: Optional[str] = None, + mode: Literal["local", "interactive", "server"] = "local", + additional_models: Optional[List[str]] = None, + max_batch_size: Optional[int] = None, + max_total_sequence_length: Optional[int] = None, + max_history_size: Optional[int] = None, + prefill_chunk_size: Optional[int] = None, + speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, + spec_draft_length: int = 4, + gpu_memory_utilization: Optional[float] = None, + ) -> None: + # - Initialize model loading info. 
+ models = _parse_models(model, model_lib_path, additional_models) + if isinstance(device, str): + device = detect_device(device) + assert isinstance(device, tvm.runtime.Device) + ( + model_args, + model_config_paths, + self.conv_template, + ) = _process_model_args(models, device) + + # TODO(mlc-team) Remove the model config parsing, estimation below + # in favor of a simple direct passing of parameters into backend. + # JSONFFIEngine do not have to support automatic mode + # + # Instead, its config should default to interactive mode always + # and allow overrides of parameters through json config via reload + # + # This is to simplify the logic of users of JSONFFI + # since we won't have similar logics in android/iOS + # + # - Load the raw model config into dict + self.model_config_dicts = [] + for i, model_info in enumerate(models): + model_info.model_lib_path = model_args[i][1] + with open(model_config_paths[i], "r", encoding="utf-8") as file: + self.model_config_dicts.append(json.load(file)) + + # - Decide the KV cache config based on mode and user input. + ( + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + max_single_sequence_length, + max_history_size, + kv_state_kind, + ) = _infer_kv_cache_config( + mode, + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + max_history_size, + gpu_memory_utilization, + models, + device, + self.model_config_dicts, + model_config_paths, + ) + self.max_input_sequence_length = min(max_single_sequence_length, max_total_sequence_length) + + # - Initialize engine state and engine. + self.state = EngineState() + module = tvm.get_global_func("mlc.json_ffi.CreateJSONFFIEngine", allow_missing=False)() + self._ffi = { + key: module[key] + for key in [ + "init_background_engine", + "reload", + "unload", + "reset", + "chat_completion", + "abort", + "get_last_error", + "run_background_loop", + "run_background_stream_back_loop", + "exit_background_loop", + ] + } + self.tokenizer = Tokenizer(model_args[0][0]) + + self.engine_config = EngineConfig( + model=model_args[0][0], + model_lib_path=model_args[0][1], + additional_models=[model_arg[0] for model_arg in model_args[1:]], + additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], + kv_cache_page_size=16, + max_num_sequence=max_batch_size, + max_total_sequence_length=max_total_sequence_length, + max_single_sequence_length=max_single_sequence_length, + prefill_chunk_size=prefill_chunk_size, + max_history_size=max_history_size, + kv_state_kind=kv_state_kind, + speculative_mode=speculative_mode, + spec_draft_length=spec_draft_length, + ) + + self.json_ffi_engine_config = create_json_ffi_engine_config( + conv_template=self.conv_template.model_dump_json(), + model_generation_cfgs={ + model.model: create_model_defined_generation_config( + temperature=model_config["temperature"], + top_p=model_config["top_p"], + frequency_penalty=model_config["frequency_penalty"], + presence_penalty=model_config["presence_penalty"], + ) + for model, model_config in zip(models, self.model_config_dicts) + }, + ) + + self._ffi["init_background_engine"]( + self.json_ffi_engine_config, + self.engine_config, + device, + self.state.get_request_stream_callback(), + None, + ) + + def _background_loop(): + self._ffi["run_background_loop"]() + + def _background_stream_back_loop(): + self._ffi["run_background_stream_back_loop"]() + + # Create the background engine-driving thread and start the loop. 
+ self._background_loop_thread: threading.Thread = threading.Thread(target=_background_loop) + self._background_stream_back_loop_thread: threading.Thread = threading.Thread( + target=_background_stream_back_loop + ) + self._background_loop_thread.start() + self._background_stream_back_loop_thread.start() + self._terminated = False + + def terminate(self): + self._terminated = True + self._ffi["exit_background_loop"]() + self._background_loop_thread.join() + self._background_stream_back_loop_thread.join() + + def chat_completion( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + messages: List[Dict[str, Any]], + model: str, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: Optional[int] = None, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: bool = False, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]: + if request_id is None: + request_id = f"chatcmpl-{engine_utils.random_uuid()}" + + chatcmpl_generator = self._handle_chat_completion( + openai_api_protocol.ChatCompletionRequest( + messages=[ + openai_api_protocol.ChatCompletionMessage.model_validate(message) + for message in messages + ], + model=model, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + logprobs=logprobs, + top_logprobs=top_logprobs, + logit_bias=logit_bias, + max_tokens=max_tokens, + n=n, + seed=seed, + stop=stop, + stream=stream, + temperature=temperature, + top_p=top_p, + tools=( + [openai_api_protocol.ChatTool.model_validate(tool) for tool in tools] + if tools is not None + else None + ), + tool_choice=tool_choice, + user=user, + ignore_eos=ignore_eos, + response_format=( + openai_api_protocol.RequestResponseFormat.model_validate(response_format) + if response_format is not None + else None + ), + ).model_dump_json(), + n=n, + request_id=request_id, + ) + for response in chatcmpl_generator: + yield response + + def _handle_chat_completion( + self, request_json_str: str, n: int, request_id: str + ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]: + self.state.sync_queue = queue.Queue() + num_unfinished_requests = n + + success = bool(self._ffi["chat_completion"](request_json_str, request_id)) + + try: + while num_unfinished_requests > 0: + chat_completion_stream_responses_json_str = self.state.sync_queue.get() + for chat_completion_response_json_str in chat_completion_stream_responses_json_str: + chat_completion_response = ( + openai_api_protocol.ChatCompletionStreamResponse.model_validate_json( + chat_completion_response_json_str + ) + ) + for choice in chat_completion_response.choices: + if choice.finish_reason is not None: + num_unfinished_requests -= 1 + yield chat_completion_response + except Exception as exception: # pylint: disable=broad-exception-caught + self._ffi["abort"](request_id) + raise exception + + def _test_reload(self): + self._ffi["reload"](self.engine_config) + + def _test_reset(self): + self._ffi["reset"]() + + def _test_unload(self): + self._ffi["unload"]() diff --git 
a/python/mlc_llm/model/eagle/eagle_model.py b/python/mlc_llm/model/eagle/eagle_model.py index 355618df09..9d7820b841 100644 --- a/python/mlc_llm/model/eagle/eagle_model.py +++ b/python/mlc_llm/model/eagle/eagle_model.py @@ -190,8 +190,8 @@ def get_default_spec(self): }, }, "fuse_embed_hidden_states": { - "input_embed": nn.spec.Tensor(["length", self.hidden_size], self.dtype), - "hidden_states": nn.spec.Tensor(["length", self.hidden_size], self.dtype), + "input_embed": nn.spec.Tensor(["seq_len", self.hidden_size], self.dtype), + "hidden_states": nn.spec.Tensor(["seq_len", self.hidden_size], self.dtype), "$": { "param_mode": "packed", "effect_mode": "none", diff --git a/python/mlc_llm/model/gpt2/gpt2_model.py b/python/mlc_llm/model/gpt2/gpt2_model.py index 28c34353e2..ede9dc350f 100644 --- a/python/mlc_llm/model/gpt2/gpt2_model.py +++ b/python/mlc_llm/model/gpt2/gpt2_model.py @@ -28,7 +28,7 @@ class GPT2Config(ConfigBase): # pylint: disable=too-many-instance-attributes n_embd: int n_layer: int n_head: int - layer_norm_epsilon: int + layer_norm_epsilon: float n_inner: int = -1 context_window_size: int = 0 prefill_chunk_size: int = 0 diff --git a/python/mlc_llm/model/llama/llama_model.py b/python/mlc_llm/model/llama/llama_model.py index 18238f688e..60c8f138d1 100644 --- a/python/mlc_llm/model/llama/llama_model.py +++ b/python/mlc_llm/model/llama/llama_model.py @@ -248,16 +248,11 @@ def get_logits(self, hidden_states: Tensor): logits = logits.astype("float32") return logits - def batch_get_logits(self, hidden_states: Tensor, logit_positions: Tensor): + def batch_select_last_hidden_states(self, hidden_states: Tensor, logit_positions: Tensor): op_ext.configure() if self.tensor_parallel_shards > 1: logit_positions = op.ccl_broadcast_from_worker0(logit_positions) hidden_states = op.take(hidden_states, logit_positions, axis=0) - return self.get_logits(hidden_states) - - def batch_select_last_hidden_states(self, hidden_states: Tensor, logit_positions: Tensor): - op_ext.configure() - hidden_states = op.take(hidden_states, logit_positions, axis=0) return hidden_states def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): @@ -368,14 +363,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "batch_get_logits": { - "hidden_states": nn.spec.Tensor(["seq_len", self.hidden_size], self.dtype), - "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), - "$": { - "param_mode": "packed", - "effect_mode": "none", - }, - }, "batch_select_last_hidden_states": { "hidden_states": nn.spec.Tensor(["seq_len", self.hidden_size], self.dtype), "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), diff --git a/python/mlc_llm/op/batch_spec_verify.py b/python/mlc_llm/op/batch_spec_verify.py index 9cdbe2be21..d1a57fc71c 100644 --- a/python/mlc_llm/op/batch_spec_verify.py +++ b/python/mlc_llm/op/batch_spec_verify.py @@ -51,7 +51,7 @@ def batch_spec_verify(vocab_size): token_tree_parent_ptr: Current parent ptr state """ - TX = 128 + TX = 1024 def _var(dtype="int32"): return T.alloc_buffer((1,), dtype, scope="local") @@ -142,7 +142,6 @@ def _func( model_prob_local[0] = model_probs[parent_ptr[0], k] draft_prob_local[0] = draft_probs[child_ptr[0], k] model_prob_local[0] = T.max(model_prob_local[0] - draft_prob_local[0], 0.0) - model_probs[parent_ptr[0], k] = model_prob_local[0] psum[0] += model_prob_local[0] with T.block("block_cross_thread"): @@ -155,13 +154,21 @@ def _func( ) T.tvm_thread_allreduce(T.uint32(1), psum[0], True, t0[0], tx, dtype="handle") - # renormalize - for i in 
T.serial(T.ceildiv(vocab_size, TX)): - k = T.meta_var(i * TX + tx) - if k < vocab_size: - model_probs[parent_ptr[0], k] = model_probs[parent_ptr[0], k] / t0[0] - - child_ptr[0] = token_tree_next_sibling[child_ptr[0]] + if t0[0] < 1e-7: + # accept the proposal, we move to child + parent_ptr[0] = child_ptr[0] + child_ptr[0] = token_tree_first_child[child_ptr[0]] + else: + # renormalize + for i in T.serial(T.ceildiv(vocab_size, TX)): + k = T.meta_var(i * TX + tx) + if k < vocab_size: + model_prob_local[0] = model_probs[parent_ptr[0], k] + draft_prob_local[0] = draft_probs[child_ptr[0], k] + model_prob_local[0] = T.max(model_prob_local[0] - draft_prob_local[0], 0.0) + model_probs[parent_ptr[0], k] = model_prob_local[0] / t0[0] + + child_ptr[0] = token_tree_next_sibling[child_ptr[0]] if tx == 0: token_tree_parent_ptr[b] = parent_ptr[0] diff --git a/python/mlc_llm/serve/config.py b/python/mlc_llm/serve/config.py index 40c53e336a..6b808ac37b 100644 --- a/python/mlc_llm/serve/config.py +++ b/python/mlc_llm/serve/config.py @@ -164,9 +164,6 @@ class EngineConfig(tvm.runtime.Object): additional_model_lib_paths : List[str] The path to the additional models' libraries. - device : tvm.runtime.Device - The device where the models run. - kv_cache_page_size : int The number of consecutive tokens handled in each page in paged KV cache. @@ -203,7 +200,6 @@ def __init__( # pylint: disable=too-many-arguments model_lib_path: str, additional_models: List[str], additional_model_lib_paths: List[str], - device: tvm.runtime.Device, kv_cache_page_size: int, max_num_sequence: int, max_total_sequence_length: int, @@ -220,7 +216,6 @@ def __init__( # pylint: disable=too-many-arguments model_lib_path, additional_models, additional_model_lib_paths, - device, kv_cache_page_size, max_num_sequence, max_total_sequence_length, diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index fb0a35ddd2..7f3f7e1331 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -1066,10 +1066,12 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals "init_background_engine", "exit_background_loop", "debug_call_func_on_all_worker", + "stats", ] } self.tokenizer = Tokenizer(model_args[0][0]) self._ffi["init_background_engine"]( + device, self.state.get_request_stream_callback(kind), self.state.trace_recorder, ) @@ -1079,7 +1081,6 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals model_lib_path=model_args[0][1], additional_models=[model_arg[0] for model_arg in model_args[1:]], additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], - device=device, kv_cache_page_size=16, max_num_sequence=max_batch_size, max_total_sequence_length=max_total_sequence_length, @@ -1118,6 +1119,10 @@ def _debug_call_func_on_all_worker(self, func_name: str) -> None: """Call the given global function on all workers. 
Only for debug purpose.""" self._ffi["debug_call_func_on_all_worker"](func_name) + def stats(self): + """Get the engine stats.""" + return self._ffi["stats"]() + def process_chat_completion_request( # pylint: disable=too-many-arguments request: openai_api_protocol.ChatCompletionRequest, diff --git a/python/mlc_llm/serve/entrypoints/debug_entrypoints.py b/python/mlc_llm/serve/entrypoints/debug_entrypoints.py index af1613c027..9f6508ea42 100644 --- a/python/mlc_llm/serve/entrypoints/debug_entrypoints.py +++ b/python/mlc_llm/serve/entrypoints/debug_entrypoints.py @@ -79,3 +79,31 @@ async def debug_cuda_profiler_stop(_request: fastapi.Request): "mlc.debug_cuda_profiler_stop" ) break + + +@app.post("/debug/dump_engine_stats") +async def debug_dump_engine_stats(request: fastapi.Request): + """Dump the engine stats for the engine. Only for debug purpose.""" + # Get the raw request body as bytes + request_raw_data = await request.body() + request_json_str = request_raw_data.decode("utf-8") + try: + # Parse the JSON string + request_dict = json.loads(request_json_str) + except json.JSONDecodeError: + return error_protocol.create_error_response( + HTTPStatus.BAD_REQUEST, message=f"Invalid request {request_json_str}" + ) + if "model" not in request_dict: + return error_protocol.create_error_response( + HTTPStatus.BAD_REQUEST, message=f"Invalid request {request_json_str}" + ) + + # - Check the requested model. + model = request_dict["model"] + + server_context: ServerContext = ServerContext.current() + async_engine = server_context.get_engine(model) + res = async_engine.stats() + print(res) + return json.loads(res) diff --git a/python/mlc_llm/serve/sync_engine.py b/python/mlc_llm/serve/sync_engine.py index 7469ddc241..1be841cb08 100644 --- a/python/mlc_llm/serve/sync_engine.py +++ b/python/mlc_llm/serve/sync_engine.py @@ -166,7 +166,6 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals model_lib_path=model_args[0][1], additional_models=[model_arg[0] for model_arg in model_args[1:]], additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], - device=device, kv_cache_page_size=16, max_num_sequence=max_batch_size, max_total_sequence_length=max_total_sequence_length, @@ -177,6 +176,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals speculative_mode=speculative_mode, spec_draft_length=spec_draft_length, ), + device, request_stream_callback, self.trace_recorder, ) diff --git a/tests/python/json_ffi/_ffi_api.py b/tests/python/json_ffi/_ffi_api.py deleted file mode 100644 index 3df07d6a1f..0000000000 --- a/tests/python/json_ffi/_ffi_api.py +++ /dev/null @@ -1,6 +0,0 @@ -"""FFI APIs for mlc.json_ffi""" -import tvm._ffi - -# Exports functions registered via TVM_REGISTER_GLOBAL with the "mlc.json_ffi" prefix. -# e.g. 
TVM_REGISTER_GLOBAL("mlc.serve.TextData") -tvm._ffi._init_api("mlc.json_ffi", __name__) # pylint: disable=protected-access diff --git a/tests/python/json_ffi/test_json_ffi_engine.py b/tests/python/json_ffi/test_json_ffi_engine.py index f5235663be..c52571b522 100644 --- a/tests/python/json_ffi/test_json_ffi_engine.py +++ b/tests/python/json_ffi/test_json_ffi_engine.py @@ -1,24 +1,6 @@ -# pylint: disable=chained-comparison,line-too-long,missing-docstring, -# pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable -import json -import queue -import threading from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Union -import tvm -from tests.python.json_ffi import _ffi_api - -from mlc_llm.protocol import openai_api_protocol -from mlc_llm.serve import engine_utils -from mlc_llm.serve.engine_base import ( - EngineConfig, - SpeculativeMode, - _infer_kv_cache_config, - _parse_models, - _process_model_args, - detect_device, -) -from mlc_llm.tokenizer import Tokenizer +from mlc_llm.json_ffi import JSONFFIEngine chat_completion_prompts = [ "What is the meaning of life?", @@ -61,286 +43,6 @@ ] -@tvm._ffi.register_object( - "mlc.json_ffi.ModelDefinedGenerationConfig" -) # pylint: disable=protected-access -class ModelDefinedGenerationConfig(tvm.runtime.Object): - def __init__( # pylint: disable=too-many-arguments - self, temperature: float, top_p: float, frequency_penalty: float, presence_penalty: float - ) -> None: - self.__init_handle_by_constructor__( - _ffi_api.ModelDefinedGenerationConfig, - temperature, - top_p, - frequency_penalty, - presence_penalty, - ) - - -@tvm._ffi.register_object("mlc.json_ffi.JSONFFIEngineConfig") # pylint: disable=protected-access -class JSONFFIEngineConfig(tvm.runtime.Object): - def __init__( # pylint: disable=too-many-arguments - self, conv_template: str, model_generation_cfgs: Dict[str, ModelDefinedGenerationConfig] - ) -> None: - self.__init_handle_by_constructor__( - _ffi_api.JSONFFIEngineConfig, conv_template, model_generation_cfgs - ) - - -class EngineState: - sync_queue: queue.Queue - - def get_request_stream_callback(self) -> Callable[[List[str]], None]: - # ChatCompletionStreamResponse - - def _callback(chat_completion_stream_responses_json_str: List[str]) -> None: - self._sync_request_stream_callback(chat_completion_stream_responses_json_str) - - return _callback - - def _sync_request_stream_callback( - self, chat_completion_stream_responses_json_str: List[str] - ) -> None: - # Put the delta outputs to the queue in the unblocking way. - self.sync_queue.put_nowait(chat_completion_stream_responses_json_str) - - -class JSONFFIEngine: - def __init__( # pylint: disable=too-many-arguments,too-many-locals - self, - model: str, - device: Union[str, tvm.runtime.Device] = "auto", - *, - model_lib_path: Optional[str] = None, - mode: Literal["local", "interactive", "server"] = "local", - additional_models: Optional[List[str]] = None, - max_batch_size: Optional[int] = None, - max_total_sequence_length: Optional[int] = None, - max_history_size: Optional[int] = None, - prefill_chunk_size: Optional[int] = None, - speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, - spec_draft_length: int = 4, - gpu_memory_utilization: Optional[float] = None, - ) -> None: - # - Initialize model loading info. 
- models = _parse_models(model, model_lib_path, additional_models) - if isinstance(device, str): - device = detect_device(device) - assert isinstance(device, tvm.runtime.Device) - ( - model_args, - model_config_paths, - self.conv_template, - ) = _process_model_args(models, device) - - # - Load the raw model config into dict - self.model_config_dicts = [] - for i, model_info in enumerate(models): - model_info.model_lib_path = model_args[i][1] - with open(model_config_paths[i], "r", encoding="utf-8") as file: - self.model_config_dicts.append(json.load(file)) - - # - Decide the KV cache config based on mode and user input. - ( - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - max_single_sequence_length, - max_history_size, - kv_state_kind, - ) = _infer_kv_cache_config( - mode, - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - max_history_size, - gpu_memory_utilization, - models, - device, - self.model_config_dicts, - model_config_paths, - ) - self.max_input_sequence_length = min(max_single_sequence_length, max_total_sequence_length) - - # - Initialize engine state and engine. - self.state = EngineState() - module = tvm.get_global_func("mlc.json_ffi.CreateJSONFFIEngine", allow_missing=False)() - self._ffi = { - key: module[key] - for key in [ - "init_background_engine", - "reload", - "unload", - "reset", - "chat_completion", - "abort", - "get_last_error", - "run_background_loop", - "run_background_stream_back_loop", - "exit_background_loop", - ] - } - self.tokenizer = Tokenizer(model_args[0][0]) - - self.engine_config = EngineConfig( - model=model_args[0][0], - model_lib_path=model_args[0][1], - additional_models=[model_arg[0] for model_arg in model_args[1:]], - additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], - device=device, - kv_cache_page_size=16, - max_num_sequence=max_batch_size, - max_total_sequence_length=max_total_sequence_length, - max_single_sequence_length=max_single_sequence_length, - prefill_chunk_size=prefill_chunk_size, - max_history_size=max_history_size, - kv_state_kind=kv_state_kind, - speculative_mode=speculative_mode, - spec_draft_length=spec_draft_length, - ) - - self.json_ffi_engine_config = JSONFFIEngineConfig( - conv_template=self.conv_template.model_dump_json(), - model_generation_cfgs={ - model.model: ModelDefinedGenerationConfig( - temperature=model_config["temperature"], - top_p=model_config["top_p"], - frequency_penalty=model_config["frequency_penalty"], - presence_penalty=model_config["presence_penalty"], - ) - for model, model_config in zip(models, self.model_config_dicts) - }, - ) - - self._ffi["init_background_engine"]( - self.json_ffi_engine_config, - self.engine_config, - self.state.get_request_stream_callback(), - None, - ) - - def _background_loop(): - self._ffi["run_background_loop"]() - - def _background_stream_back_loop(): - self._ffi["run_background_stream_back_loop"]() - - # Create the background engine-driving thread and start the loop. 
- self._background_loop_thread: threading.Thread = threading.Thread(target=_background_loop) - self._background_stream_back_loop_thread: threading.Thread = threading.Thread( - target=_background_stream_back_loop - ) - self._background_loop_thread.start() - self._background_stream_back_loop_thread.start() - self._terminated = False - - def terminate(self): - self._terminated = True - self._ffi["exit_background_loop"]() - self._background_loop_thread.join() - self._background_stream_back_loop_thread.join() - - def chat_completion( # pylint: disable=too-many-arguments,too-many-locals - self, - *, - messages: List[Dict[str, Any]], - model: str, - frequency_penalty: Optional[float] = None, - presence_penalty: Optional[float] = None, - logprobs: bool = False, - top_logprobs: int = 0, - logit_bias: Optional[Dict[int, float]] = None, - max_tokens: Optional[int] = None, - n: int = 1, - seed: Optional[int] = None, - stop: Optional[Union[str, List[str]]] = None, - stream: bool = False, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - tools: Optional[List[Dict[str, Any]]] = None, - tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, - user: Optional[str] = None, - ignore_eos: bool = False, - response_format: Optional[Dict[str, Any]] = None, - request_id: Optional[str] = None, - ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]: - if request_id is None: - request_id = f"chatcmpl-{engine_utils.random_uuid()}" - - chatcmpl_generator = self._handle_chat_completion( - openai_api_protocol.ChatCompletionRequest( - messages=[ - openai_api_protocol.ChatCompletionMessage.model_validate(message) - for message in messages - ], - model=model, - frequency_penalty=frequency_penalty, - presence_penalty=presence_penalty, - logprobs=logprobs, - top_logprobs=top_logprobs, - logit_bias=logit_bias, - max_tokens=max_tokens, - n=n, - seed=seed, - stop=stop, - stream=stream, - temperature=temperature, - top_p=top_p, - tools=( - [openai_api_protocol.ChatTool.model_validate(tool) for tool in tools] - if tools is not None - else None - ), - tool_choice=tool_choice, - user=user, - ignore_eos=ignore_eos, - response_format=( - openai_api_protocol.RequestResponseFormat.model_validate(response_format) - if response_format is not None - else None - ), - ).model_dump_json(), - n=n, - request_id=request_id, - ) - for response in chatcmpl_generator: - yield response - - def _handle_chat_completion( - self, request_json_str: str, n: int, request_id: str - ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]: - self.state.sync_queue = queue.Queue() - num_unfinished_requests = n - - success = bool(self._ffi["chat_completion"](request_json_str, request_id)) - - try: - while num_unfinished_requests > 0: - chat_completion_stream_responses_json_str = self.state.sync_queue.get() - for chat_completion_response_json_str in chat_completion_stream_responses_json_str: - chat_completion_response = ( - openai_api_protocol.ChatCompletionStreamResponse.model_validate_json( - chat_completion_response_json_str - ) - ) - for choice in chat_completion_response.choices: - if choice.finish_reason is not None: - num_unfinished_requests -= 1 - yield chat_completion_response - except Exception as exception: # pylint: disable=broad-exception-caught - self._ffi["abort"](request_id) - raise exception - - def _test_reload(self): - self._ffi["reload"](self.engine_config) - - def _test_reset(self): - self._ffi["reset"]() - - def _test_unload(self): - self._ffi["unload"]() - - def run_chat_completion( 
engine: JSONFFIEngine, model: str, @@ -382,10 +84,8 @@ def run_chat_completion( def test_chat_completion(): # Create engine. model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-cuda.so" engine = JSONFFIEngine( model, - model_lib_path=model_lib_path, max_total_sequence_length=1024, ) @@ -402,10 +102,8 @@ def test_chat_completion(): def test_reload_reset_unload(): # Create engine. model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-cuda.so" engine = JSONFFIEngine( model, - model_lib_path=model_lib_path, max_total_sequence_length=1024, ) diff --git a/tests/python/op/test_batch_spec_verify.py b/tests/python/op/test_batch_spec_verify.py index 359fafdbd0..f35a39d71e 100644 --- a/tests/python/op/test_batch_spec_verify.py +++ b/tests/python/op/test_batch_spec_verify.py @@ -7,7 +7,7 @@ @pytest.mark.parametrize("nbatch", [32, 64]) -@pytest.mark.parametrize("vocab", [3, 32, 64, 32000, 33, 65, 32001]) +@pytest.mark.parametrize("vocab", [3, 32, 64, 32000, 33, 65, 32001, 128000]) @pytest.mark.parametrize("plist", [[0.5, 0.5], [1, 0], [0, 1]]) def test_batch_spec_verify(nbatch, vocab, plist): def numpy_reference( @@ -141,6 +141,20 @@ def gen_full_binary_tree(height, base): token_tree_parent_ptr, token_tree_parent_ptr_tvm.asnumpy(), rtol=0, atol=0 ) + time_evaluator = mod.time_evaluator(mod.entry_name, dev, number=10, repeat=3) + print(f"batch_size: {nbatch}, vocab_size: {vocab}, tree_structure: {plist}") + print( + time_evaluator( + draft_probs_tvm, + draft_tokens_tvm, + model_probs_tvm, + token_tree_first_child_tvm, + token_tree_next_sibling_tvm, + uniform_samples_tvm, + token_tree_parent_ptr_tvm, + ) + ) + if __name__ == "__main__": tvm.testing.main()
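
For readers less familiar with speculative verification, the rejection path that the updated `batch_spec_verify` kernel implements can be rendered in plain NumPy roughly as follows. This is a simplified, per-node sketch rather than part of the patch; the function name and shapes are illustrative:

```python
import numpy as np


def renormalize_after_reject(model_probs: np.ndarray, draft_probs: np.ndarray):
    """Residual renormalization after a draft token is rejected.

    Mirrors the updated kernel logic: subtract the draft distribution from the
    model distribution, clamp at zero, and either renormalize or, when the
    residual mass is essentially zero, signal that the proposal should be
    accepted and verification should descend to the child node.
    """
    residual = np.maximum(model_probs - draft_probs, 0.0)
    total = residual.sum()
    if total < 1e-7:
        # Near-zero residual: the model and draft distributions essentially
        # agree, so the kernel now accepts and moves to the child instead of
        # dividing by a tiny sum (the new `t0[0] < 1e-7` branch).
        return None
    # Otherwise the next sibling is verified against the renormalized residual.
    return residual / total
```

A `None` return here corresponds to the newly added accept-and-descend branch; a non-`None` return corresponds to the `else` branch, where the kernel writes the renormalized residual back into `model_probs` and advances `child_ptr` to the next sibling.
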
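With the engine now exposed as `mlc_llm.json_ffi.JSONFFIEngine`, a minimal streaming chat-completion call looks roughly like the updated test above. This is a usage sketch, not part of the patch; the model path and prompt are placeholders, and delta handling assumes string content in the stream responses:

```python
from mlc_llm.json_ffi import JSONFFIEngine

# Placeholder compiled model directory; model_lib_path is now optional and
# resolved automatically.
model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC"
engine = JSONFFIEngine(model, max_total_sequence_length=1024)

output = ""
for response in engine.chat_completion(
    messages=[{"role": "user", "content": "What is the meaning of life?"}],
    model=model,
    stream=True,
    request_id="example-0",
):
    for choice in response.choices:
        # Accumulate streamed delta text; the generator ends once all
        # choices report a finish_reason.
        output += choice.delta.content or ""

print(output)
engine.terminate()
```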