Qwen2-VL support added
mgonzs13 committed Jan 7, 2025
1 parent 335250b commit 6262e24
Showing 5 changed files with 161 additions and 50 deletions.
18 changes: 18 additions & 0 deletions llama_bringup/models/Qwen2-VL.yaml
@@ -0,0 +1,18 @@
use_llava: True

n_ctx: 8192
n_batch: 512
n_gpu_layers: 15
n_threads: -1
n_predict: 8192

model_repo: "bartowski/Qwen2-VL-2B-Instruct-GGUF"
model_filename: "Qwen2-VL-2B-Instruct-IQ2_M.gguf"

mmproj_repo: "bartowski/Qwen2-VL-2B-Instruct-GGUF"
mmproj_filename: "mmproj-Qwen2-VL-2B-Instruct-f16.gguf"

image_prefix: "<|vision_start|>"
image_suffix: "<|vision_end|>"

system_prompt_type: "ChatML"
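
The image_prefix and image_suffix values are Qwen2-VL's vision delimiter tokens. As a rough sketch of how a prompt built from this configuration could look (the helper name and the "<image>" placeholder are illustrative assumptions, not part of this commit), the markers simply wrap the image slot inside a ChatML turn:

#include <string>

// Illustrative only: wrap an image placeholder with the markers from the
// YAML above and append the user question in ChatML form.
std::string build_qwen2vl_prompt(const std::string &user_text) {
  const std::string image_prefix = "<|vision_start|>"; // image_prefix in the YAML
  const std::string image_suffix = "<|vision_end|>";   // image_suffix in the YAML
  return "<|im_start|>user\n" + image_prefix + "<image>" + image_suffix + "\n" +
         user_text + "<|im_end|>\n<|im_start|>assistant\n";
}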
7 changes: 4 additions & 3 deletions llama_ros/include/llama_ros/llama.hpp
@@ -169,14 +169,15 @@ using GenerateResponseCallback = std::function<void(struct CompletionOutput)>;
class Llama {

public:
Llama(const struct common_params &params, std::string system_prompt = "");
Llama(const struct common_params &params, std::string system_prompt = "",
bool initial_reset = true);
virtual ~Llama();

std::vector<llama_token> tokenize(const std::string &text, bool add_bos,
bool special = false);
std::string detokenize(const std::vector<llama_token> &tokens);

void reset();
virtual void reset();
void cancel();

std::string format_chat_prompt(std::vector<struct common_chat_msg> chat_msgs,
@@ -266,7 +267,7 @@ class Llama {
virtual bool eval_prompt();
bool eval_prompt(std::vector<llama_token> prompt_tokens);
bool eval_token(llama_token token);
bool eval(std::vector<llama_token> tokens);
virtual bool eval(std::vector<llama_token> tokens);
bool eval(struct llama_batch batch);

std::vector<struct TokenProb> get_probs();
3 changes: 3 additions & 0 deletions llama_ros/include/llava_ros/llava.hpp
@@ -50,6 +50,7 @@ class Llava : public llama_ros::Llama {
const struct LlavaParams &llava_params, std::string system_prompt = "");
~Llava();

void reset() override;
bool load_image(std::string base64_str);
struct llava_image_embed *
base64_image_to_embed(const std::string &base64_str);
@@ -59,6 +60,7 @@ class Llava : public llama_ros::Llama {
bool add_sfx) override;
bool eval_image(struct llava_image_embed *image_embed);
bool eval_prompt();
bool eval(std::vector<llama_token> tokens) override;

struct llava_image_embed *image_embed;
struct clip_ctx *ctx_clip;
@@ -67,6 +69,7 @@
private:
void free_image();
int image_pose;
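// running M-RoPE position counter used when the loaded model is Qwen2-VL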
int st_pos_id;
};

} // namespace llava_ros
86 changes: 45 additions & 41 deletions llama_ros/src/llama_ros/llama.cpp
@@ -34,7 +34,8 @@

using namespace llama_ros;

Llama::Llama(const struct common_params &params, std::string system_prompt)
Llama::Llama(const struct common_params &params, std::string system_prompt,
bool initial_reset)
: params(params), system_prompt(system_prompt) {

print_build_info();
@@ -100,7 +101,9 @@ Llama::Llama(const struct common_params &params, std::string system_prompt)
}

// set initial values
this->reset();
if (initial_reset) {
this->reset();
}

// show info
LLAMA_LOG_INFO("llama.cpp: build = %d, commit = %s", LLAMA_BUILD_NUMBER,
Expand Down Expand Up @@ -148,6 +151,38 @@ Llama::~Llama() {
this->threadpool_batch = nullptr;
}

/*
*****************************
* RESET *
* CANCEL *
*****************************
*/
void Llama::reset() {

llama_kv_cache_clear(this->ctx);

if (this->sampler != nullptr) {
common_sampler_reset(this->sampler);
}

this->canceled = false;
this->n_past = 0;
this->n_consumed = 0;
this->ga_i = 0;

this->prompt_tokens.clear();

// load system prompt
if (!this->eval_system_prompt()) {
LLAMA_LOG_ERROR("Failed to eval system prompt");
}

// number of tokens to keep when resetting context
if (this->params.n_keep < 0) {
this->params.n_keep = (int)this->prompt_tokens.size();
}
}

/*
*****************************
* METADATA *
@@ -339,38 +374,6 @@ struct Metadata Llama::get_metadata() {
return metadata;
}

/*
*****************************
* RESET *
* CANCEL *
*****************************
*/
void Llama::reset() {

llama_kv_cache_clear(this->ctx);

if (this->sampler != nullptr) {
common_sampler_reset(this->sampler);
}

this->canceled = false;
this->n_past = 0;
this->n_consumed = 0;
this->ga_i = 0;

this->prompt_tokens.clear();

// load system prompt
if (!this->eval_system_prompt()) {
LLAMA_LOG_ERROR("Failed to eval system prompt");
}

// number of tokens to keep when resetting context
if (this->params.n_keep < 0) {
this->params.n_keep = (int)this->prompt_tokens.size();
}
}

/*
*****************************
* TOKENIZE *
@@ -911,6 +914,7 @@ bool Llama::eval_prompt() { return this->eval_prompt(this->prompt_tokens); }
bool Llama::eval_prompt(std::vector<llama_token> prompt_tokens) {

std::vector<llama_token> batch;
batch.reserve(this->params.n_batch);

while (((int)prompt_tokens.size() > this->n_consumed)) {

@@ -941,13 +945,13 @@ bool Llama::eval(std::vector<llama_token> tokens) {

// create batch
struct llama_batch batch = {
int32_t(tokens.size()),
tokens.data(),
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
int32_t(tokens.size()), // n_tokens
tokens.data(), // tokens
nullptr, // embd
nullptr, // pos
nullptr, // n_seq_id
nullptr, // seq_id
nullptr, // logits
};

return this->eval(batch);
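
The field comments added above follow the llama_batch member order in llama.h at this llama.cpp revision (n_tokens, token, embd, pos, n_seq_id, seq_id, logits); the null members let llama_decode fall back to its defaults. As a side note, and only as a hedged sketch rather than anything changed by this commit, the same text-only batch can normally be produced with the helper llama.cpp already ships:

// Illustrative equivalent, assuming the current two-argument
// llama_batch_get_one(tokens, n_tokens) signature.
struct llama_batch batch = llama_batch_get_one(tokens.data(), int32_t(tokens.size()));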
97 changes: 91 additions & 6 deletions llama_ros/src/llava_ros/llava.cpp
@@ -34,19 +34,32 @@ using namespace llava_ros;

Llava::Llava(const struct common_params &params,
const struct LlavaParams &llava_params, std::string system_prompt)
: llama_ros::Llama(params, system_prompt), llava_params(llava_params) {
: llama_ros::Llama(params, system_prompt, false),
llava_params(llava_params), image_pose(0), st_pos_id(-1) {

// load clip model
const char *clip_path = this->params.mmproj.c_str();
this->ctx_clip = clip_model_load(clip_path, 1);
this->image_embed = nullptr;

// set initial values
this->reset();
}

Llava::~Llava() {
this->image_pose = 0;
this->st_pos_id = -1;
clip_free(this->ctx_clip);
this->free_image();
}

void Llava::reset() {
this->image_pose = 0;
this->st_pos_id = -1;
this->free_image();
Llama::reset();
}

/*
*****************************
* LOAD IMAGE *
@@ -150,13 +163,40 @@ bool Llava::eval_image(struct llava_image_embed *image_embed) {
int n_embd = this->get_n_embd();
bool succ = true;

for (int i = 0; i < image_embed->n_image_pos; i += this->params.n_batch) {
// for qwen2-vl
auto img_tokens = image_embed->n_image_pos;

std::vector<llama_pos> mrope_pos;
mrope_pos.resize(img_tokens * 4);

std::vector<llama_pos> batch_mrope_pos;
batch_mrope_pos.resize(img_tokens * 4);

int n_eval = image_embed->n_image_pos - i;
// fill mrope if qwen2-vl
if (clip_is_qwen2vl(this->ctx_clip)) {
auto image_size = clip_get_load_image_size(this->ctx_clip);
const int patch_size = 14 * 2;

if (n_eval > this->params.n_batch) {
n_eval = this->params.n_batch;
const int ph =
image_size->height / patch_size + (image_size->height % patch_size > 0);
const int pw =
image_size->width / patch_size + (image_size->width % patch_size > 0);

for (int y = 0; y < ph; y++) {
for (int x = 0; x < pw; x++) {
int i = y * pw + x;
mrope_pos[i] = this->st_pos_id;
mrope_pos[i + img_tokens] = this->st_pos_id + y;
mrope_pos[i + img_tokens * 2] = this->st_pos_id + x;
mrope_pos[i + img_tokens * 3] = 0;
}
}
this->st_pos_id += std::max(pw, ph);
}
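
// Worked example (illustrative numbers, not from the commit): for a 56x56
// input and the merged patch size of 28 used above, ph = pw = 2 and
// img_tokens = 4. With st_pos_id = 5 the four M-RoPE channels come out as
//   temporal: {5, 5, 5, 5}
//   height:   {5, 5, 6, 6}
//   width:    {5, 6, 5, 6}
//   unused:   {0, 0, 0, 0}
// and st_pos_id then advances by max(pw, ph) = 2.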

for (int i = 0; i < image_embed->n_image_pos; i += this->params.n_batch) {

int n_eval = std::min(this->params.n_batch, image_embed->n_image_pos - i);

struct llama_batch batch = {
int32_t(n_eval), // n_tokens
Expand All @@ -168,7 +208,19 @@ bool Llava::eval_image(struct llava_image_embed *image_embed) {
nullptr // logits
};

if (!this->eval(batch)) {
if (clip_is_qwen2vl(this->ctx_clip)) {
std::fill(batch_mrope_pos.begin(), batch_mrope_pos.end(), 0);
memcpy(batch_mrope_pos.data(), &mrope_pos[i], n_eval * sizeof(llama_pos));
memcpy(&batch_mrope_pos[n_eval * 1], &mrope_pos[img_tokens * 1 + i],
n_eval * sizeof(llama_pos));
memcpy(&batch_mrope_pos[n_eval * 2], &mrope_pos[img_tokens * 2 + i],
n_eval * sizeof(llama_pos));
memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + i],
n_eval * sizeof(llama_pos));
batch.pos = batch_mrope_pos.data();
}

if (!Llama::eval(batch)) {
LLAMA_LOG_ERROR("Failed in image eval");
succ = false;
break;
Expand Down Expand Up @@ -212,3 +264,36 @@ bool Llava::eval_prompt() {

return true;
}

bool Llava::eval(std::vector<llama_token> tokens) {

std::vector<llama_pos> pos;

// create batch
struct llama_batch batch = {
int32_t(tokens.size()), // n_tokens
tokens.data(), // tokens
nullptr, // embd
nullptr, // pos
nullptr, // n_seq_id
nullptr, // seq_id
nullptr, // logits
};

if (clip_is_qwen2vl(this->ctx_clip)) {
pos.resize(batch.n_tokens * 4);
std::fill(pos.begin(), pos.end(), 0);
for (int j = 0; j < batch.n_tokens * 3; j++) {
pos[j] = this->st_pos_id + (j % batch.n_tokens);
}
batch.pos = pos.data();
}

if (!Llama::eval(batch)) {
return false;
}

this->st_pos_id += batch.n_tokens;

return true;
}
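
To make the position bookkeeping in Llava::eval concrete, here is a small self-contained sketch of the same indexing (not code from this commit) that prints the four M-RoPE channels produced for a batch of text tokens:

#include <cstdio>
#include <vector>

int main() {
  const int n_tokens = 3; // example batch of three text tokens
  int st_pos_id = 7;      // arbitrary starting position for the example

  // channel-major layout, n_tokens * 4 entries, as in Llava::eval
  std::vector<int> pos(n_tokens * 4, 0);
  for (int j = 0; j < n_tokens * 3; j++) {
    pos[j] = st_pos_id + (j % n_tokens); // same values in the three RoPE channels
  }
  st_pos_id += n_tokens; // advances to 10 after the batch

  for (int c = 0; c < 4; c++) {
    for (int t = 0; t < n_tokens; t++) {
      std::printf("%d ", pos[c * n_tokens + t]);
    }
    std::printf("\n"); // prints: 7 8 9 / 7 8 9 / 7 8 9 / 0 0 0
  }
  return 0;
}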
