Skip to content

Commit

Permalink
Merge branch 'mlc-ai:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
JackWeiw authored May 1, 2024
2 parents f90bde8 + d206c44 commit 6fcd10d
Show file tree
Hide file tree
Showing 60 changed files with 1,865 additions and 1,127 deletions.
4 changes: 2 additions & 2 deletions cpp/json_ffi/json_ffi_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode {
TVM_MODULE_VTABLE_END();

void InitBackgroundEngine(JSONFFIEngineConfig json_ffi_engine_config, EngineConfig engine_config,
Optional<PackedFunc> request_stream_callback,
Device device, Optional<PackedFunc> request_stream_callback,
Optional<EventTraceRecorder> trace_recorder) {
std::optional<Conversation> conv_template =
Conversation::FromJSON(json_ffi_engine_config->conv_template, &err_);
Expand All @@ -150,7 +150,7 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode {
};

request_stream_callback = PackedFunc(frequest_stream_callback_wrapper);
this->engine_->InitBackgroundEngine(std::move(request_stream_callback),
this->engine_->InitBackgroundEngine(device, std::move(request_stream_callback),
std::move(trace_recorder));
this->engine_->Reload(std::move(engine_config));
}
Expand Down
61 changes: 53 additions & 8 deletions cpp/serve/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -244,17 +244,16 @@ String GenerationConfigNode::AsJSONString() const {
TVM_REGISTER_OBJECT_TYPE(EngineConfigNode);

EngineConfig::EngineConfig(String model, String model_lib_path, Array<String> additional_models,
Array<String> additional_model_lib_paths, DLDevice device,
int kv_cache_page_size, int max_num_sequence,
int max_total_sequence_length, int max_single_sequence_length,
int prefill_chunk_size, int max_history_size, KVStateKind kv_state_kind,
Array<String> additional_model_lib_paths, int kv_cache_page_size,
int max_num_sequence, int max_total_sequence_length,
int max_single_sequence_length, int prefill_chunk_size,
int max_history_size, KVStateKind kv_state_kind,
SpeculativeMode speculative_mode, int spec_draft_length) {
ObjectPtr<EngineConfigNode> n = make_object<EngineConfigNode>();
n->model = std::move(model);
n->model_lib_path = std::move(model_lib_path);
n->additional_models = std::move(additional_models);
n->additional_model_lib_paths = std::move(additional_model_lib_paths);
n->device = device;
n->kv_cache_page_size = kv_cache_page_size;
n->max_num_sequence = max_num_sequence;
n->max_total_sequence_length = max_total_sequence_length;
Expand All @@ -267,14 +266,60 @@ EngineConfig::EngineConfig(String model, String model_lib_path, Array<String> ad
data_ = std::move(n);
}

EngineConfig EngineConfig::FromJSONString(const std::string& json_str) {
  // Parse the raw JSON text; abort with the parser's error message on malformed input.
  picojson::value parsed;
  std::string parse_err = picojson::parse(parsed, json_str);
  if (!parse_err.empty()) {
    LOG(FATAL) << parse_err;
  }

  picojson::object obj = parsed.get<picojson::object>();

  // Scalar fields. The lookup order is kept identical to preserve which
  // missing-field error is reported first on an incomplete config.
  String model = json::Lookup<std::string>(obj, "model");
  String model_lib_path = json::Lookup<std::string>(obj, "model_lib_path");
  std::vector<String> extra_models;
  std::vector<String> extra_model_lib_paths;
  int kv_cache_page_size = json::Lookup<int64_t>(obj, "kv_cache_page_size");
  int max_num_sequence = json::Lookup<int64_t>(obj, "max_num_sequence");
  int max_total_sequence_length = json::Lookup<int64_t>(obj, "max_total_sequence_length");
  int max_single_sequence_length = json::Lookup<int64_t>(obj, "max_single_sequence_length");
  int prefill_chunk_size = json::Lookup<int64_t>(obj, "prefill_chunk_size");
  int max_history_size = json::Lookup<int64_t>(obj, "max_history_size");
  KVStateKind kv_state_kind = static_cast<KVStateKind>(json::Lookup<int64_t>(obj, "kv_state_kind"));
  SpeculativeMode speculative_mode =
      static_cast<SpeculativeMode>(json::Lookup<int64_t>(obj, "speculative_mode"));
  int spec_draft_length = json::Lookup<int64_t>(obj, "spec_draft_length");

  // Additional (e.g. draft) models and their libraries come as two parallel arrays.
  picojson::array models_arr = json::Lookup<picojson::array>(obj, "additional_models");
  picojson::array lib_paths_arr = json::Lookup<picojson::array>(obj, "additional_model_lib_paths");
  CHECK_EQ(models_arr.size(), lib_paths_arr.size())
      << "The number of additional model lib paths does not match the number of additional models";
  int num_extra = models_arr.size();
  extra_models.reserve(num_extra);
  extra_model_lib_paths.reserve(num_extra);
  for (int i = 0; i < num_extra; ++i) {
    extra_models.push_back(json::Lookup<std::string>(models_arr, i));
    extra_model_lib_paths.push_back(json::Lookup<std::string>(lib_paths_arr, i));
  }

  return EngineConfig(std::move(model), std::move(model_lib_path), extra_models,
                      extra_model_lib_paths, kv_cache_page_size, max_num_sequence,
                      max_total_sequence_length, max_single_sequence_length, prefill_chunk_size,
                      max_history_size, kv_state_kind, speculative_mode, spec_draft_length);
}

TVM_REGISTER_GLOBAL("mlc.serve.EngineConfig")
.set_body_typed([](String model, String model_lib_path, Array<String> additional_models,
Array<String> additional_model_lib_paths, DLDevice device,
int kv_cache_page_size, int max_num_sequence, int max_total_sequence_length,
Array<String> additional_model_lib_paths, int kv_cache_page_size,
int max_num_sequence, int max_total_sequence_length,
int max_single_sequence_length, int prefill_chunk_size, int max_history_size,
int kv_state_kind, int speculative_mode, int spec_draft_length) {
return EngineConfig(std::move(model), std::move(model_lib_path), std::move(additional_models),
std::move(additional_model_lib_paths), device, kv_cache_page_size,
std::move(additional_model_lib_paths), kv_cache_page_size,
max_num_sequence, max_total_sequence_length, max_single_sequence_length,
prefill_chunk_size, max_history_size, KVStateKind(kv_state_kind),
SpeculativeMode(speculative_mode), spec_draft_length);
Expand Down
12 changes: 5 additions & 7 deletions cpp/serve/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,6 @@ class EngineConfigNode : public Object {
/*! \brief The path to the additional models' libraries. */
Array<String> additional_model_lib_paths;

/*************** Device ***************/

/*! \brief The device where the models run. */
DLDevice device;

/*************** KV cache config and engine capacities ***************/

/*! \brief The number of consecutive tokens handled in each page in paged KV cache. */
Expand Down Expand Up @@ -152,12 +147,15 @@ class EngineConfigNode : public Object {
class EngineConfig : public ObjectRef {
public:
explicit EngineConfig(String model, String model_lib_path, Array<String> additional_models,
Array<String> additional_model_lib_paths, DLDevice device,
int kv_cache_page_size, int max_num_sequence, int max_total_sequence_length,
Array<String> additional_model_lib_paths, int kv_cache_page_size,
int max_num_sequence, int max_total_sequence_length,
int max_single_sequence_length, int prefill_chunk_size,
int max_history_size, KVStateKind kv_state_kind,
SpeculativeMode speculative_mode, int spec_draft_length);

/*! \brief Create EngineConfig from JSON string. */
static EngineConfig FromJSONString(const std::string& json_str);

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(EngineConfig, ObjectRef, EngineConfigNode);
};

Expand Down
54 changes: 54 additions & 0 deletions cpp/serve/draft_token_workspace_manager.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*!
* Copyright (c) 2024 by Contributors
* \file serve/draft_token_workspace_manager.cc
*/

#include "draft_token_workspace_manager.h"

#include "model.h"

namespace mlc {
namespace llm {
namespace serve {

DraftTokenWorkspaceManagerObj::DraftTokenWorkspaceManagerObj(int max_num_tokens, int vocab_size,
                                                             int hidden_size,
                                                             DLDataType hidden_states_dtype,
                                                             DLDevice device,
                                                             const FunctionTable& ft)
    : max_num_tokens_(max_num_tokens),
      vocab_size_(vocab_size),
      hidden_size_(hidden_size),
      hidden_states_dtype_(hidden_states_dtype),
      device_(device),
      ft_(ft) {
  // At construction every slot index in [0, max_num_tokens) is available.
  free_slots_.reserve(max_num_tokens);
  for (int slot = 0; slot < max_num_tokens; ++slot) {
    free_slots_.push_back(slot);
  }
}

void DraftTokenWorkspaceManagerObj::AllocSlots(int num_slots, std::vector<int>* result) {
  // Hand out the last `num_slots` entries of the free list (in reverse order)
  // and shrink the free list accordingly.
  //
  // Fix: the original built a local `std::vector<int> allocated` from the
  // *front* of `free_slots_` and never used it — dead code that also
  // contradicted the actual allocation (slots are taken from the back).
  // Also cast size() so ICHECK_LE compares like-signed values.
  ICHECK_LE(num_slots, static_cast<int>(free_slots_.size()));
  result->assign(free_slots_.rbegin(), free_slots_.rbegin() + num_slots);
  free_slots_.resize(free_slots_.size() - num_slots);
}

void DraftTokenWorkspaceManagerObj::FreeSlots(const std::vector<int>& slots) {
  // Returned slots are simply appended to the free list; no ordering or
  // de-duplication is performed here.
  free_slots_.insert(free_slots_.end(), slots.begin(), slots.end());
}

void DraftTokenWorkspaceManagerObj::AllocWorkspace(ModelWorkspace* workspace,
                                                   bool require_hidden_states) {
  // Per-draft-token probability distributions over the vocabulary
  // ([max_num_tokens_, vocab_size_], fp32), plus a same-shaped storage buffer.
  const DataType prob_dtype = DataType::Float(32);
  workspace->draft_probs = NDArray::Empty({max_num_tokens_, vocab_size_}, prob_dtype, device_);
  workspace->draft_probs_storage =
      NDArray::Empty({max_num_tokens_, vocab_size_}, prob_dtype, device_);
  if (require_hidden_states) {
    // Hidden-state storage is allocated via the function table rather than
    // NDArray::Empty — presumably so the allocation goes through the model's
    // runtime (e.g. a distributed session); confirm against FunctionTable::Empty.
    workspace->draft_hidden_states_storage =
        ft_.Empty({max_num_tokens_, hidden_size_}, hidden_states_dtype_, device_);
  }
}

} // namespace serve
} // namespace llm
} // namespace mlc
95 changes: 95 additions & 0 deletions cpp/serve/draft_token_workspace_manager.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*!
* Copyright (c) 2024 by Contributors
* \file serve/draft_token_workspace_manager.h
*/

#ifndef MLC_LLM_SERVE_DRAFT_TOKEN_WORKSPACE_MANAGER_H_
#define MLC_LLM_SERVE_DRAFT_TOKEN_WORKSPACE_MANAGER_H_
#include <tvm/runtime/device_api.h>

#include <numeric>
#include <optional>
#include <vector>

#include "data.h"
#include "function_table.h"
namespace mlc {
namespace llm {
namespace serve {

using tvm::Device;
using namespace tvm::runtime;

struct ModelWorkspace;

/*!
* \brief Managing the workspace for draft token generation.
*
* The workspace is used to store the associated states for each draft token, including the
* probability distribution of the draft token, the hidden states, etc. The workspace manager
* maintains a pool of slots for the draft tokens to store the states.
*/
class DraftTokenWorkspaceManagerObj : public Object {
 public:
  /*!
   * \brief Constructor
   * \param max_num_tokens The maximum number of draft tokens that can be stored in the workspace.
   * \param vocab_size The size of the vocabulary.
   * \param hidden_size The size of the hidden states.
   * \param hidden_states_dtype The data type of the hidden states.
   * \param device The device running the model.
   * \param ft The function table.
   * \note `ft` is stored as a reference (see `ft_` below); the caller must keep
   *       the FunctionTable alive for the lifetime of this manager.
   */
  DraftTokenWorkspaceManagerObj(int max_num_tokens, int vocab_size, int hidden_size,
                                DLDataType hidden_states_dtype, DLDevice device,
                                const FunctionTable& ft);

  /*!
   * \brief Allocate the workspace for draft tokens and update `ModelWorkspace` data structure.
   * \param workspace The object to stored the allocated draft token workspace.
   * \param require_hidden_states Whether to allocate workspace for the hidden states.
   */
  void AllocWorkspace(ModelWorkspace* workspace, bool require_hidden_states);

  /*!
   * \brief Allocate slots for the draft tokens.
   * \param num_slots The number of slots to allocate. Must not exceed the
   *        number of currently free slots.
   * \param result The vector to store the allocated slots.
   */
  void AllocSlots(int num_slots, std::vector<int>* result);

  /*!
   * \brief Free the slots.
   * \param slots The slots to free. Slots are returned to the pool and may be
   *        handed out again by subsequent AllocSlots calls.
   */
  void FreeSlots(const std::vector<int>& slots);

  static constexpr const char* _type_key = "mlc.serve.DraftTokenWorkspaceManager";

 private:
  // Pool of currently available slot indices in [0, max_num_tokens_).
  std::vector<int> free_slots_;
  // Capacity of the workspace in draft tokens.
  int max_num_tokens_;
  // Vocabulary size; determines the width of each probability row.
  int vocab_size_;
  // Hidden-state width per token.
  int hidden_size_;
  // Data type used for the hidden-states buffer.
  DataType hidden_states_dtype_;
  // Device on which all workspace buffers are allocated.
  DLDevice device_;
  // NOTE(review): reference member — the referenced FunctionTable must outlive
  // this object; verify the owner's lifetime guarantees this.
  const FunctionTable& ft_;
};

/*! \brief Managed reference to DraftTokenWorkspaceManagerObj. */
class DraftTokenWorkspaceManager : public ObjectRef {
 public:
  /*!
   * \brief Construct a manager by creating the underlying
   *        DraftTokenWorkspaceManagerObj with the given parameters.
   * \param max_num_tokens The maximum number of draft tokens the workspace can hold.
   * \param vocab_size The size of the vocabulary.
   * \param hidden_size The size of the hidden states.
   * \param hidden_states_dtype The data type of the hidden states.
   * \param device The device running the model.
   * \param ft The function table; must outlive the constructed object, which
   *        stores it by reference.
   */
  DraftTokenWorkspaceManager(int max_num_tokens, int vocab_size, int hidden_size,
                             DLDataType hidden_states_dtype, DLDevice device,
                             const FunctionTable& ft) {
    data_ = make_object<DraftTokenWorkspaceManagerObj>(max_num_tokens, vocab_size, hidden_size,
                                                       hidden_states_dtype, device, ft);
  }
  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(DraftTokenWorkspaceManager, ObjectRef,
                                        DraftTokenWorkspaceManagerObj);
};

} // namespace serve
} // namespace llm
} // namespace mlc

#endif // MLC_LLM_SERVE_DRAFT_TOKEN_WORKSPACE_MANAGER_H_
Loading

0 comments on commit 6fcd10d

Please sign in to comment.