-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
xiaying
committed
Oct 18, 2023
1 parent
1bbb3b1
commit 3ff49cb
Showing
160 changed files
with
12,759 additions
and
687 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# Build the MNN llm library and the llm_demo executable.

# include dir
include_directories(${CMAKE_CURRENT_LIST_DIR}/include/)

# source files
FILE(GLOB SRCS ${CMAKE_CURRENT_LIST_DIR}/src/*.cpp)

if (MSVC)
    # compile static lib, support Windows (MSVC)
    add_library(llm STATIC ${SRCS})
else()
    # compile dynamic so, support Linux/Mac
    add_library(llm SHARED ${SRCS})
    # Only takes effect when the shared library is actually a Windows DLL
    # (e.g. a MinGW build falls into this branch); ignored on Linux/Mac.
    set_target_properties(llm PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE)
endif()
# Common to both variants (was duplicated in each branch).
target_link_libraries(llm ${MNN_DEPS})
target_compile_features(llm PRIVATE cxx_std_17)

add_executable(llm_demo ${CMAKE_CURRENT_LIST_DIR}/llm_demo.cpp)
target_compile_features(llm_demo PRIVATE cxx_std_17)
target_link_libraries(llm_demo llm)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
// | ||
// llm.hpp | ||
// | ||
// Created by MNN on 2023/08/25. | ||
// ZhaodeWang | ||
// | ||
|
||
#ifndef LLM_hpp | ||
#define LLM_hpp | ||
|
||
#include <vector> | ||
#include <memory> | ||
#include <string> | ||
#include <unordered_map> | ||
#include <iostream> | ||
|
||
#include <MNN/AutoTime.hpp> | ||
#include <MNN/expr/Expr.hpp> | ||
#include <MNN/expr/Module.hpp> | ||
#include <MNN/expr/MathOp.hpp> | ||
#include <MNN/expr/NeuralNetWorkOp.hpp> | ||
#include "tokenizer.hpp" | ||
|
||
using namespace MNN; | ||
using namespace Express; | ||
|
||
class MNN_PUBLIC Llm { | ||
public: | ||
Llm() { | ||
// default tokenier is senrencepiece | ||
tokenizer_.reset(new Sentencepiece); | ||
} | ||
static Llm* createLLM(const std::string& path); | ||
VARP gen_embedding(const std::vector<int>& input_ids); | ||
void load(const std::string& model_dir); | ||
int forward(const std::vector<int>& input_ids); | ||
std::vector<int> tokenizer_encode(const std::string& input_str); | ||
std::string decode(int id); | ||
std::string response(const std::string& input_str, std::ostream* os = &std::cout); | ||
float load_progress() { return load_progress_; } | ||
void reset(); | ||
private: | ||
virtual std::vector<int> tokenizer(const std::string& query) = 0; | ||
virtual VARP gen_attention_mask(int seq_len) = 0; | ||
virtual VARP gen_position_ids(int seq_len) = 0; | ||
virtual bool is_stop(int token_id) = 0; | ||
protected: | ||
// model configs | ||
bool is_single_ = false; | ||
int layer_nums_ = 0; | ||
int hidden_size_ = 4096; | ||
std::vector<int> key_value_shape_ = {}; | ||
std::string model_name_ = ""; | ||
// gen info | ||
int gen_seq_len_ = 0; | ||
int all_seq_len_ = 0; | ||
int max_seq_len_ = 256; | ||
float load_progress_ = 0.f; | ||
// tokenizer | ||
std::unique_ptr<Tokenizer> tokenizer_; | ||
private: | ||
// MNN Modules | ||
std::shared_ptr<Executor::RuntimeManager> runtime_manager_; | ||
std::vector<std::shared_ptr<Module>> modules_; | ||
std::vector<VARP> past_key_values_; | ||
// model dir | ||
std::string model_dir_; | ||
// tokenizer | ||
std::vector<std::string> word_decoder_; | ||
std::unordered_map<std::string, int> word_encoder_; | ||
}; | ||
|
||
// some llm models | ||
class Chatglm_6b : public Llm { | ||
public: | ||
Chatglm_6b() { | ||
model_name_ = "Chatglm_6b"; | ||
layer_nums_ = 28; | ||
key_value_shape_ = {2, 0, 1, 32, 128}; | ||
} | ||
private: | ||
virtual std::vector<int> tokenizer(const std::string& query) override; | ||
virtual VARP gen_attention_mask(int seq_len) override; | ||
virtual VARP gen_position_ids(int seq_len) override; | ||
virtual bool is_stop(int token_id) override; | ||
int context_len_ = 0; | ||
}; | ||
|
||
class Chatglm2_6b : public Llm { | ||
public: | ||
Chatglm2_6b() { | ||
model_name_ = "Chatglm2_6b"; | ||
layer_nums_ = 28; | ||
key_value_shape_ = {2, 0, 1, 2, 128}; | ||
} | ||
private: | ||
virtual std::vector<int> tokenizer(const std::string& query) override; | ||
virtual VARP gen_attention_mask(int seq_len) override; | ||
virtual VARP gen_position_ids(int seq_len) override; | ||
virtual bool is_stop(int token_id) override; | ||
}; | ||
|
||
|
||
class Qwen_7b : public Llm { | ||
public: | ||
Qwen_7b() { | ||
model_name_ = "Qwen_7b"; | ||
layer_nums_ = 32; | ||
key_value_shape_ = {2, 1, 0, 32, 128}; | ||
tokenizer_.reset(new Tiktoken); | ||
} | ||
private: | ||
virtual std::vector<int> tokenizer(const std::string& query) override; | ||
virtual VARP gen_attention_mask(int seq_len) override; | ||
virtual VARP gen_position_ids(int seq_len) override; | ||
virtual bool is_stop(int token_id) override; | ||
}; | ||
|
||
class Llama2_7b : public Llm { | ||
public: | ||
Llama2_7b() { | ||
model_name_ = "Llama2_7b"; | ||
layer_nums_ = 32; | ||
key_value_shape_ = {2, 1, 32, 0, 128}; | ||
} | ||
private: | ||
virtual std::vector<int> tokenizer(const std::string& query) override; | ||
virtual VARP gen_attention_mask(int seq_len) override; | ||
virtual VARP gen_position_ids(int seq_len) override; | ||
virtual bool is_stop(int token_id) override; | ||
}; | ||
|
||
#endif // LLM_hpp |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
// | ||
// tokenizer.hpp | ||
// | ||
// Created by MNN on 2023/09/25. | ||
// ZhaodeWang | ||
// | ||
|
||
#ifndef TOKENIZER_hpp | ||
#define TOKENIZER_hpp | ||
|
||
#include <vector> | ||
#include <memory> | ||
#include <string> | ||
#include <unordered_map> | ||
#include <iostream> | ||
#include <string_view> | ||
|
||
// Abstract tokenizer interface: load a vocabulary file, encode text into
// token ids, and decode a single id back to its text piece.
class Tokenizer {
public:
    Tokenizer() = default;
    // Concrete tokenizers are destroyed through a Tokenizer* (Llm holds a
    // std::unique_ptr<Tokenizer>), so the destructor must be virtual.
    virtual ~Tokenizer() = default;
    // Load the tokenizer model/vocabulary from `filename`.
    virtual bool load(const std::string& filename) = 0;
    // Encode `str` into a sequence of token ids.
    virtual std::vector<int> encode(const std::string& str) = 0;
    // Decode one token id back to text.
    virtual std::string decode(int id) = 0;
};
|
||
class Sentencepiece : public Tokenizer { | ||
public: | ||
Sentencepiece() = default; | ||
virtual bool load(const std::string& filename) override; | ||
virtual std::vector<int> encode(const std::string& str) override; | ||
virtual std::string decode(int id) override; | ||
private: | ||
enum ModelType { | ||
UNIGRAM = 1, | ||
BPE = 2, | ||
WORD = 3, | ||
CHAR = 4 | ||
}; | ||
enum PieceType { | ||
NORMAL = 1, | ||
UNKNOWN = 2, | ||
CONTROL = 3, | ||
USER_DEFINED = 4, | ||
UNUSED = 5, | ||
BYTE = 6 | ||
}; | ||
struct SentencePiece { | ||
std::string piece; | ||
float score; | ||
PieceType type = PieceType::NORMAL; | ||
}; | ||
using EncodeResult = std::vector<std::pair<std::string_view, int>>; | ||
private: | ||
// model train type | ||
ModelType type_ = BPE; | ||
// byte fall back enable | ||
bool byte_fall_back_ = true; | ||
// unknown id. | ||
int unk_id_ = 0; | ||
// pieces from model | ||
std::vector<SentencePiece> sentence_pieces_; | ||
// piece -> id map for normal pieces | ||
std::unordered_map<std::string, int> pieces_; | ||
// piece -> id map for control, unknown, and byte pieces | ||
std::unordered_map<std::string, int> reserved_id_map_; | ||
private: | ||
float get_score(int id) const; | ||
bool is_unused(int id) const; | ||
bool is_control(int id) const; | ||
int piece_to_id(const std::string& w) const; | ||
std::string byte_to_piece(unsigned char c) const; | ||
EncodeResult bpe_encode(std::string_view str, float alpha = 0.f); | ||
}; | ||
|
||
class Tiktoken : public Tokenizer { | ||
public: | ||
Tiktoken() = default; | ||
virtual bool load(const std::string& filename) override; | ||
virtual std::vector<int> encode(const std::string& str) override; | ||
virtual std::string decode(int id) override; | ||
private: | ||
std::vector<std::string> decoder_; | ||
std::vector<int> tokens_; | ||
std::vector<int> token_ids_; | ||
}; | ||
|
||
#endif // TOKENIZER_hpp |
Oops, something went wrong.