[update] support more config.
wangzhaode committed Sep 13, 2024
1 parent 39e1087 commit 14fde20
Showing 4 changed files with 454 additions and 302 deletions.
6 changes: 2 additions & 4 deletions CMakeLists.txt
@@ -6,10 +6,6 @@ option(LLM_SUPPORT_VISION "Llm model support vision input." OFF)
option(DUMP_PROFILE_INFO "Dump profile info when chat." OFF)
option(BUILD_JNI "Build JNI for android app." OFF)

if (LLM_SUPPORT_VISION)
add_definitions(-DLLM_SUPPORT_VISION)
endif()

if (DUMP_PROFILE_INFO)
add_definitions(-DDUMP_PROFILE_INFO)
endif()
@@ -25,6 +21,7 @@ if (BUILD_FOR_ANDROID)
set(MNN_ARM82 ON CACHE BOOL "Open MNN_ARM82" FORCE)
endif()
if (LLM_SUPPORT_VISION)
add_definitions(-DLLM_SUPPORT_VISION)
set(MNN_BUILD_OPENCV ON CACHE BOOL "Open MNN_BUILD_OPENCV" FORCE)
set(MNN_IMGCODECS ON CACHE BOOL "Open MNN_IMGCODECS" FORCE)
endif()
@@ -34,6 +31,7 @@ add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/MNN)
include_directories(${CMAKE_CURRENT_LIST_DIR}/include/
${CMAKE_CURRENT_LIST_DIR}/MNN/include/
${CMAKE_CURRENT_LIST_DIR}/MNN/tools/cv/include/
${CMAKE_CURRENT_LIST_DIR}/MNN/MNN/3rd_party/
)

# source files
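The CMake change above simply relocates the -DLLM_SUPPORT_VISION definition into the block that also forces MNN_BUILD_OPENCV and MNN_IMGCODECS on, and adds MNN's bundled 3rd_party headers to the include path. A minimal sketch of how such a compile-time switch is typically consumed on the C++ side follows; the function and its body are illustrative assumptions, not code from this commit.

    // Sketch of the compile-time switch that -DLLM_SUPPORT_VISION enables.
    // The guarded branch is illustrative; the real vision code in this repo
    // additionally relies on the MNN OpenCV/imgcodecs modules enabled above.
    #include <string>
    #include <vector>

    std::vector<float> load_image_pixels(const std::string& path) {
    #ifdef LLM_SUPPORT_VISION
        // Vision-enabled build: decode the image at `path` here using the
        // bundled image codecs and return its pixel data.
        std::vector<float> pixels;
        // ... decoding would populate `pixels` ...
        return pixels;
    #else
        (void)path;  // text-only build: vision input is unavailable
        return {};
    #endif
    }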
221 changes: 24 additions & 197 deletions include/llm.hpp
@@ -30,6 +30,7 @@ using namespace Express;
using json = nlohmann::json;
class Tokenizer;
class Pipeline;
class LlmConfig;

// Llm start
// llm stream buffer with callback
@@ -64,175 +65,13 @@ struct Prompt {
std::vector<int> tokens;
};

static inline bool has_suffix(const std::string& str, const std::string& suffix) {
return str.size() >= suffix.size() &&
str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0;
}

static inline std::string base_dir(const std::string& path) {
size_t pos = path.find_last_of("/\\");
if (pos == std::string::npos) {
return "./";
} else {
return path.substr(0, pos + 1);
}
}

static inline std::string file_name(const std::string& path) {
size_t pos = path.find_last_of("/\\");
if (pos == std::string::npos) {
return path;
} else {
return path.substr(pos + 1);
}
}

class LlmConfig {
public:
std::string base_dir_;
json config_, llm_config_;
LlmConfig() {}
LlmConfig(const std::string& path) {
// load config
if (has_suffix(path, ".json")) {
std::ifstream config_file(path);
if (config_file.is_open()) {
config_ = json::parse(config_file);
} else {
std::cerr << "Unable to open config file: " << path << std::endl;
}
base_dir_ = base_dir(path);
} else {
// compatibility with the original usage
if (has_suffix(path, ".mnn")) {
auto model_name = file_name(path);
config_ = {
{"llm_model", model_name},
{"llm_weight", model_name + ".weight"}
};
base_dir_ = base_dir(path);
} else {
config_ = {};
base_dir_ = path;
}
}
// using config's base_dir
base_dir_ = config_.value("base_dir", base_dir_);
// load llm_config for model info
std::ifstream llm_config_file(llm_config());
if (llm_config_file.is_open()) {
llm_config_ = json::parse(llm_config_file);
} else {
std::cerr << "Unable to open llm_config file: " << llm_config() << std::endl;
}
}

// < model file config start
std::string llm_config() const {
return base_dir_ + config_.value("llm_config", "llm_config.json");
}

std::string llm_model() const {
return base_dir_ + config_.value("llm_model", "llm.mnn");
}

std::string llm_weight() const {
return base_dir_ + config_.value("llm_weight", "llm.mnn.weight");
}

std::string block_model(int index) const {
return base_dir_ + config_.value("block_model", "block_") + std::to_string(index) + ".mnn";
}

std::string lm_model() const {
return base_dir_ + config_.value("lm_model", "lm.mnn");
}

std::string embedding_model() const {
return base_dir_ + config_.value("embedding_model", "embedding.mnn");
}

std::string embedding_file() const {
return base_dir_ + config_.value("embedding_file", "embeddings_bf16.bin");
}

std::string tokenizer_file() const {
return base_dir_ + config_.value("tokenizer_file", "tokenizer.txt");
}

std::string visual_model() const {
return base_dir_ + config_.value("visual_model", "visual.mnn");
}
// model file config end >

// < generate config start
int max_new_tokens() const {
return config_.value("max_new_tokens", 512);
}
// generate config end >

// < backend config start
std::string backend_type() const {
return config_.value("backend_type", "cpu");
}

int thread_num() const {
return config_.value("thread_num", 4);
}

std::string precision() const {
return config_.value("precision", "low");
}

std::string memory() const {
return config_.value("memory", "low");
}
// backend config end >

// < llm model config start
bool is_single() const {
return llm_config_.value("is_single", true);
}

bool is_visual() const {
return llm_config_.value("is_visual", false);
}

int hidden_size() const {
return llm_config_.value("hidden_size", 4096);
}

int layer_nums() const {
return llm_config_.value("layer_nums", 32);
}

std::vector<int> key_value_shape() const {
return llm_config_.value("key_value_shape", std::vector<int>{});
}

std::string attention_mask() const {
return llm_config_.value("attention_mask", "int");
}

std::string chat_template() const {
return llm_config_.value("chat_template", "");
}

std::string prompt_template() const {
return llm_config_.value("prompt_template", "");
}
// llm model config end >
};
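This commit removes the inline LlmConfig definition from llm.hpp, leaving only the forward declaration added near the top; the class presumably now lives in one of the other changed files, which are not shown here. The removed accessors still document the JSON keys and defaults the loader understands. Below is a self-contained sketch of such a config file, parsed with nlohmann::json the same way the original code does; the key set mirrors the defaults visible above and is not exhaustive.

    // Sketch: a config.json with the keys and defaults read by the (relocated)
    // LlmConfig accessors shown above. Parsing mirrors the original code's use
    // of nlohmann::json and its value() fallbacks.
    #include <iostream>
    #include <nlohmann/json.hpp>

    int main() {
        using json = nlohmann::json;
        json config = json::parse(R"({
            "llm_model":      "llm.mnn",
            "llm_weight":     "llm.mnn.weight",
            "tokenizer_file": "tokenizer.txt",
            "backend_type":   "cpu",
            "thread_num":     4,
            "precision":      "low",
            "memory":         "low",
            "max_new_tokens": 512
        })");
        // value() falls back to the same defaults LlmConfig uses for absent keys.
        std::cout << config.value("backend_type", "cpu") << " backend, "
                  << config.value("thread_num", 4) << " threads" << std::endl;
        return 0;
    }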

class Llm {
public:
using PromptItem = std::pair<std::string, std::string>; // <role, content>
Llm(std::shared_ptr<LlmConfig> config) : config_(config) {}
virtual ~Llm() {
modules_.clear();
runtime_manager_.reset();
}
virtual ~Llm();
void chat();
void reset();
static Llm* createLLM(const std::string& config_path);
virtual void load();
VARP forward(const std::vector<int>& input_ids);
@@ -245,56 +84,44 @@ class Llm {
std::string generate(const std::vector<int>& input_ids, std::ostream* os, const char* end_with);
std::vector<int> generate(const std::vector<int>& input_ids, int max_new_tokens = -1);
void print_speed();
// config function
std::string dump_config();
bool set_config(const std::string& content);
// lora function
size_t apply_lora(const std::string& lora_path);
Llm* create_lora(const std::string& lora_path);
bool release_module(size_t index);
bool select_module(size_t index);
friend class Pipeline;
public:
// forward info
int prompt_len_ = 0;
int gen_seq_len_ = 0;
int all_seq_len_ = 0;
std::vector<int> history_ids_;
// time
int64_t prefill_us_ = 0;
int64_t decode_us_ = 0;
float load_progress_ = 0.f;
bool is_single_ = true;
bool is_disk_embedding_ = true;
std::shared_ptr<LlmConfig> config_;
std::unique_ptr<Tokenizer> tokenizer_;
bool attention_fused_ = true;
protected:
std::shared_ptr<LlmConfig> config_;
std::shared_ptr<Tokenizer> tokenizer_;
std::vector<int> key_value_shape_ = {};
std::vector<VARP> past_key_values_;
VARP inputs_embeds_, attention_mask_, position_ids_;
std::shared_ptr<Executor::RuntimeManager> runtime_manager_;
std::vector<std::shared_ptr<Module>> modules_;
std::vector<MNN::Express::VARP> past_key_values_;
MNN::Express::VARP inputs_embeds_, attention_mask_, position_ids_;
std::shared_ptr<MNN::Express::Executor::RuntimeManager> runtime_manager_;
std::vector<std::shared_ptr<MNN::Express::Module>> modules_;
std::vector<std::shared_ptr<MNN::Express::Module>> prefill_modules_, decode_modules_, current_modules_;
const MNN::Express::Module* base_module_ = nullptr;
void init_runtime();
std::string decode(int id);
bool is_stop(int token_id);
virtual std::vector<int> tokenizer(const std::string& query);
virtual VARP embedding(const std::vector<int>& input_ids);
virtual VARP gen_attention_mask(int seq_len);
virtual VARP gen_position_ids(int seq_len);
};
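Beyond the fully qualified MNN::Express member types and the new prefill/decode module split, the public surface gains runtime configuration (dump_config, set_config) and LoRA management (apply_lora, create_lora, release_module, select_module). A hedged usage sketch against the declarations above follows; the file paths, the JSON keys passed to set_config, and the assumption that apply_lora's return value can be fed to select_module are illustrative, not documented behavior.

    // Hedged usage sketch of the Llm API declared above; paths and JSON keys
    // are placeholders for illustration only.
    #include <iostream>
    #include <memory>
    #include "llm.hpp"

    int main() {
        std::unique_ptr<Llm> llm(Llm::createLLM("model_dir/config.json"));
        llm->load();

        // New in this commit: adjust settings at runtime via a JSON snippet,
        // then inspect the merged configuration.
        llm->set_config(R"({"max_new_tokens": 256})");
        std::cout << llm->dump_config() << std::endl;

        // Also new: attach a LoRA adapter and select which module set runs.
        size_t lora_index = llm->apply_lora("model_dir/lora.mnn");
        llm->select_module(lora_index);

        llm->chat();  // interactive loop declared above
        return 0;
    }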

class Lvlm : public Llm {
public:
Lvlm(std::shared_ptr<LlmConfig> config) : Llm(config) {
image_size_ = config->llm_config_.value("image_size", image_size_);
image_pad_ = config->llm_config_.value("image_pad", image_pad_);
vision_start_ = config->llm_config_.value("vision_start", vision_start_);
vision_end_ = config->llm_config_.value("vision_end", vision_end_);
image_mean_ = config->llm_config_.value("image_mean", image_mean_);
image_norm_ = config->llm_config_.value("image_norm", image_norm_);
}
~Lvlm() { visual_module_.reset(); }
virtual void load() override;
virtual std::vector<int> tokenizer(const std::string& query) override;
virtual MNN::Express::VARP embedding(const std::vector<int>& input_ids) override;
private:
int image_size_ = 448, vision_start_ = 151857, vision_end_ = 151858, image_pad_ = 151859;
std::vector<float> image_mean_ {122.7709383 , 116.7460125 , 104.09373615};
std::vector<float> image_norm_ {0.01459843, 0.01500777, 0.01422007};
std::vector<int> image_process(const std::string& img_info);
std::shared_ptr<Module> visual_module_;
std::vector<VARP> image_embeddings_;
virtual MNN::Express::VARP embedding(const std::vector<int>& input_ids);
virtual MNN::Express::VARP gen_attention_mask(int seq_len);
virtual MNN::Express::VARP gen_position_ids(int seq_len);
};
// Llm end
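The Lvlm constructor now pulls its image preprocessing parameters (image_size, image_pad, vision_start, vision_end, image_mean, image_norm) from llm_config_ instead of hard-coding them. The per-channel defaults suggest the usual (pixel - mean) * norm normalization; whether image_process does exactly this is an assumption, since its implementation is not part of this diff. A small illustrative sketch:

    // Illustrative only: per-channel normalization using the Lvlm defaults
    // shown above; the actual Lvlm::image_process implementation may differ.
    #include <array>
    #include <cstdio>

    int main() {
        const std::array<float, 3> mean{122.7709383f, 116.7460125f, 104.09373615f};
        const std::array<float, 3> norm{0.01459843f, 0.01500777f, 0.01422007f};
        const std::array<float, 3> pixel{255.0f, 128.0f, 0.0f};  // sample RGB value

        for (int c = 0; c < 3; ++c) {
            std::printf("channel %d: %.4f\n", c, (pixel[c] - mean[c]) * norm[c]);
        }
        return 0;
    }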

