forked from mlc-ai/mlc-llm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'mlc-ai:main' into main
- Loading branch information
Showing
60 changed files
with
1,865 additions
and
1,127 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
/*! | ||
* Copyright (c) 2024 by Contributors | ||
* \file serve/draft_token_workspace_manager.cc | ||
*/ | ||
|
||
#include "draft_token_workspace_manager.h" | ||
|
||
#include "model.h" | ||
|
||
namespace mlc { | ||
namespace llm { | ||
namespace serve { | ||
|
||
DraftTokenWorkspaceManagerObj::DraftTokenWorkspaceManagerObj(int max_num_tokens, int vocab_size, | ||
int hidden_size, | ||
DLDataType hidden_states_dtype, | ||
DLDevice device, | ||
const FunctionTable& ft) | ||
: max_num_tokens_(max_num_tokens), | ||
vocab_size_(vocab_size), | ||
hidden_size_(hidden_size), | ||
hidden_states_dtype_(hidden_states_dtype), | ||
device_(device), | ||
ft_(ft) { | ||
free_slots_.resize(max_num_tokens); | ||
std::iota(free_slots_.begin(), free_slots_.end(), 0); | ||
} | ||
|
||
void DraftTokenWorkspaceManagerObj::AllocSlots(int num_slots, std::vector<int>* result) { | ||
ICHECK_LE(num_slots, free_slots_.size()); | ||
result->assign(free_slots_.rbegin(), free_slots_.rbegin() + num_slots); | ||
std::vector<int> allocated(free_slots_.begin(), free_slots_.begin() + num_slots); | ||
free_slots_.resize(free_slots_.size() - num_slots); | ||
} | ||
|
||
void DraftTokenWorkspaceManagerObj::FreeSlots(const std::vector<int>& slots) { | ||
std::copy(slots.begin(), slots.end(), std::back_inserter(free_slots_)); | ||
} | ||
|
||
void DraftTokenWorkspaceManagerObj::AllocWorkspace(ModelWorkspace* workspace, | ||
bool require_hidden_states) { | ||
workspace->draft_probs = | ||
NDArray::Empty({max_num_tokens_, vocab_size_}, DataType::Float(32), device_); | ||
workspace->draft_probs_storage = | ||
NDArray::Empty({max_num_tokens_, vocab_size_}, DataType::Float(32), device_); | ||
if (require_hidden_states) { | ||
workspace->draft_hidden_states_storage = | ||
ft_.Empty({max_num_tokens_, hidden_size_}, hidden_states_dtype_, device_); | ||
} | ||
} | ||
|
||
} // namespace serve | ||
} // namespace llm | ||
} // namespace mlc |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
/*! | ||
* Copyright (c) 2024 by Contributors | ||
* \file serve/draft_token_workspace_manager.h | ||
*/ | ||
|
||
#ifndef MLC_LLM_SERVE_DRAFT_TOKEN_WORKSPACE_MANAGER_H_ | ||
#define MLC_LLM_SERVE_DRAFT_TOKEN_WORKSPACE_MANAGER_H_ | ||
#include <tvm/runtime/device_api.h> | ||
|
||
#include <numeric> | ||
#include <optional> | ||
#include <vector> | ||
|
||
#include "data.h" | ||
#include "function_table.h" | ||
namespace mlc { | ||
namespace llm { | ||
namespace serve { | ||
|
||
using tvm::Device; | ||
using namespace tvm::runtime; | ||
|
||
struct ModelWorkspace; | ||
|
||
/*! | ||
* \brief Managing the workspace for draft token generation. | ||
* | ||
* The workspace is used to store the associated states for each draft token, including the | ||
* probability distribution of the draft token, the hidden states, etc. The workspace manager | ||
* maintains a pool of slots for the draft tokens to store the states. | ||
*/ | ||
class DraftTokenWorkspaceManagerObj : public Object { | ||
public: | ||
/*! | ||
* \brief Constructor | ||
* \param max_num_tokens The maximum number of draft tokens that can be stored in the workspace. | ||
* \param vocab_size The size of the vocabulary. | ||
* \param hidden_size The size of the hidden states. | ||
* \param hidden_states_dtype The data type of the hidden states. | ||
* \param device The device running the model. | ||
* \param ft The function table. | ||
*/ | ||
DraftTokenWorkspaceManagerObj(int max_num_tokens, int vocab_size, int hidden_size, | ||
DLDataType hidden_states_dtype, DLDevice device, | ||
const FunctionTable& ft); | ||
|
||
/*! | ||
* \brief Allocate the workspace for draft tokens and update `ModelWorkspace` data structure. | ||
* \param workspace The object to stored the allocated draft token workspace. | ||
* \param require_hidden_states Whether to allocate workspace for the hidden states. | ||
*/ | ||
void AllocWorkspace(ModelWorkspace* workspace, bool require_hidden_states); | ||
|
||
/*! | ||
* \brief Allocate slots for the draft tokens. | ||
* \param num_slots The number of slots to allocate. | ||
* \param result The vector to store the allocated slots. | ||
*/ | ||
void AllocSlots(int num_slots, std::vector<int>* result); | ||
|
||
/*! | ||
* \brief Free the slots. | ||
* \param slots The slots to free. | ||
*/ | ||
void FreeSlots(const std::vector<int>& slots); | ||
|
||
static constexpr const char* _type_key = "mlc.serve.DraftTokenWorkspaceManager"; | ||
|
||
private: | ||
std::vector<int> free_slots_; | ||
int max_num_tokens_; | ||
int vocab_size_; | ||
int hidden_size_; | ||
DataType hidden_states_dtype_; | ||
DLDevice device_; | ||
const FunctionTable& ft_; | ||
}; | ||
|
||
class DraftTokenWorkspaceManager : public ObjectRef { | ||
public: | ||
DraftTokenWorkspaceManager(int max_num_tokens, int vocab_size, int hidden_size, | ||
DLDataType hidden_states_dtype, DLDevice device, | ||
const FunctionTable& ft) { | ||
data_ = make_object<DraftTokenWorkspaceManagerObj>(max_num_tokens, vocab_size, hidden_size, | ||
hidden_states_dtype, device, ft); | ||
} | ||
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(DraftTokenWorkspaceManager, ObjectRef, | ||
DraftTokenWorkspaceManagerObj); | ||
}; | ||
|
||
} // namespace serve | ||
} // namespace llm | ||
} // namespace mlc | ||
|
||
#endif // MLC_LLM_SERVE_DRAFT_TOKEN_WORKSPACE_MANAGER_H_ |
Oops, something went wrong.