llama.cpp updated + common prefix
mgonzs13 committed Oct 11, 2024
1 parent f4da47a commit b2d7322
Showing 8 changed files with 46 additions and 46 deletions.
2 changes: 1 addition & 1 deletion llama_cpp_vendor/CMakeLists.txt
@@ -11,7 +11,7 @@ find_package(ament_cmake REQUIRED)
FetchContent_Declare(
llama
GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git
- GIT_TAG b3899
+ GIT_TAG b3906
)

FetchContent_MakeAvailable(llama)
12 changes: 6 additions & 6 deletions llama_ros/include/llama_ros/llama.hpp
@@ -84,7 +84,7 @@ using GenerateResponseCallback = std::function<void(struct completion_output)>;
class Llama {

public:
- Llama(const struct gpt_params &params, bool debug);
+ Llama(const struct common_params &params, bool debug);
virtual ~Llama();

std::vector<llama_token> tokenize(const std::string &text, bool add_bos,
@@ -94,7 +94,7 @@ class Llama {
void reset();
void cancel();

- std::string format_chat_prompt(std::vector<struct llama_chat_msg> chat_msgs,
+ std::string format_chat_prompt(std::vector<struct common_chat_msg> chat_msgs,
bool add_ass);
std::vector<struct lora> list_loras();
void update_loras(std::vector<struct lora> loras);
@@ -111,7 +111,7 @@ class Llama {
const std::vector<std::string> &documents);

response_output generate_response(const std::string &input_prompt,
- struct gpt_sampler_params sparams,
+ struct common_sampler_params sparams,
GenerateResponseCallback callbakc = nullptr,
std::vector<std::string> stop = {});
response_output generate_response(const std::string &input_prompt,
@@ -130,13 +130,13 @@ class Llama {
llama_token get_token_eos() { return llama_token_eos(this->model); }

protected:
- struct gpt_params params;
+ struct common_params params;

// model
struct llama_context *ctx;
struct llama_model *model;
- std::vector<struct llama_lora_adapter_container> lora_adapters;
- struct gpt_sampler *sampler;
+ std::vector<struct common_lora_adapter_container> lora_adapters;
+ struct common_sampler *sampler;
struct ggml_threadpool *threadpool;
struct ggml_threadpool *threadpool_batch;

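For reference, a hypothetical usage sketch (not part of this commit) of the renamed API declared above: the constructor now takes common_params instead of gpt_params. The common_params fields set below are assumptions about the llama.cpp common API at tag b3906.

#include "llama_ros/llama.hpp"

int main() {
  struct common_params params;     // was: struct gpt_params
  params.model = "llama-3.gguf";   // assumed common_params field
  params.n_ctx = 2048;             // assumed common_params field

  llama_ros::Llama llama(params, /*debug=*/false);

  // unchanged overload: only the sampler-params overload switched from
  // gpt_sampler_params to common_sampler_params
  auto output = llama.generate_response("Hello!", nullptr, {});
  return 0;
}
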
8 changes: 4 additions & 4 deletions llama_ros/include/llama_utils/llama_params.hpp
@@ -20,8 +20,8 @@
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

- #ifndef LLAMA_ROS__GPT_PARAMS_HPP
- #define LLAMA_ROS__GPT_PARAMS_HPP
+ #ifndef LLAMA_ROS__common_params_HPP
+ #define LLAMA_ROS__common_params_HPP

#include <memory>
#include <rclcpp/rclcpp.hpp>
@@ -36,7 +36,7 @@ namespace llama_utils {

struct llama_params {
bool debug;
- struct gpt_params params;
+ struct common_params params;
struct llava_ros::llava_params llava_params;
};

@@ -48,7 +48,7 @@ get_llama_params(const rclcpp_lifecycle::LifecycleNode::SharedPtr &node);

enum ggml_sched_priority parse_priority(std::string priority);

- struct gpt_sampler_params
+ struct common_sampler_params
parse_sampling_params(const llama_msgs::msg::SamplingConfig &sampling_config,
int n_vocab);

2 changes: 1 addition & 1 deletion llama_ros/include/llava_ros/llava.hpp
@@ -47,7 +47,7 @@ struct llava_params {
class Llava : public llama_ros::Llama {

public:
- Llava(const struct gpt_params &params,
+ Llava(const struct common_params &params,
const struct llava_params &llava_params, bool debug = false);
~Llava();

54 changes: 27 additions & 27 deletions llama_ros/src/llama_ros/llama.cpp
@@ -31,7 +31,7 @@

using namespace llama_ros;

- Llama::Llama(const struct gpt_params &params, bool debug)
+ Llama::Llama(const struct common_params &params, bool debug)
: params(params), debug(debug) {

if (this->debug) {
@@ -42,8 +42,7 @@ Llama::Llama(const struct gpt_params &params, bool debug)
llama_backend_init();
llama_numa_init(this->params.numa);

- struct llama_init_result llama_init =
-     llama_init_from_gpt_params(this->params);
+ struct common_init_result llama_init = common_init_from_params(this->params);
this->model = llama_init.model;
this->ctx = llama_init.context;
this->lora_adapters = llama_init.lora_adapters;
@@ -86,7 +85,7 @@ Llama::Llama(const struct gpt_params &params, bool debug)
llama_attach_threadpool(this->ctx, this->threadpool, this->threadpool_batch);

// create the sampler
- this->sampler = gpt_sampler_init(this->model, this->params.sparams);
+ this->sampler = common_sampler_init(this->model, this->params.sparams);
if (!this->sampler) {
LLAMA_LOG_ERROR("Failed to initialize sampling subsystem");
return;
@@ -105,7 +104,7 @@ Llama::Llama(const struct gpt_params &params, bool debug)
// show info
LLAMA_LOG_INFO("llama.cpp: build = %d, commit = %s", LLAMA_BUILD_NUMBER,
LLAMA_COMMIT);
LLAMA_LOG_INFO("%s", gpt_params_get_system_info(this->params).c_str());
LLAMA_LOG_INFO("%s", common_params_get_system_info(this->params).c_str());

LLAMA_LOG_INFO(
"Generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d",
@@ -136,7 +135,7 @@ Llama::~Llama() {
this->model = nullptr;

if (this->sampler != nullptr) {
- gpt_sampler_free(this->sampler);
+ common_sampler_free(this->sampler);
this->sampler = nullptr;
}
llama_backend_free();
@@ -159,7 +158,7 @@ void Llama::reset() {
llama_kv_cache_clear(this->ctx);

if (this->sampler != nullptr) {
- gpt_sampler_reset(this->sampler);
+ common_sampler_reset(this->sampler);
}

this->canceled = false;
@@ -189,7 +188,7 @@ void Llama::reset() {
std::vector<llama_token> Llama::tokenize(const std::string &text, bool add_bos,
bool special) {
std::lock_guard<std::recursive_mutex> lk(this->mutex);
- return llama_tokenize(this->ctx, text, add_bos, special);
+ return common_tokenize(this->ctx, text, add_bos, special);
}

std::string Llama::detokenize(const std::vector<llama_token> &tokens) {
Expand All @@ -198,7 +197,7 @@ std::string Llama::detokenize(const std::vector<llama_token> &tokens) {
std::string text;

for (llama_token t : tokens) {
- text.append(llama_token_to_piece(this->ctx, t));
+ text.append(common_token_to_piece(this->ctx, t));
}

return text;
@@ -237,7 +236,7 @@ Llama::generate_embeddings(const std::vector<llama_token> &tokens,
// llama eval
struct llama_batch batch = llama_batch_init(this->params.n_batch, 0, 1);
for (size_t i = 0; i < tokens.size(); i++) {
- llama_batch_add(batch, tokens[i], i, {0}, i == tokens.size() - 1);
+ common_batch_add(batch, tokens[i], i, {0}, i == tokens.size() - 1);
}

if (llama_decode(this->ctx, batch)) {
@@ -265,7 +264,7 @@ Llama::generate_embeddings(const std::vector<llama_token> &tokens,
}

if (normalize) {
- llama_embd_normalize(embd, embd_res.data(), n_embd);
+ common_embd_normalize(embd, embd_res.data(), n_embd);

} else {
for (int i = 0; i < n_embd; i++) {
@@ -366,9 +365,9 @@ Llama::rank_documents(const std::string &query,
*****************************
*/
std::string
- Llama::format_chat_prompt(std::vector<struct llama_chat_msg> chat_msgs,
+ Llama::format_chat_prompt(std::vector<struct common_chat_msg> chat_msgs,
bool add_ass) {
- return llama_chat_apply_template(this->get_model(), "", chat_msgs, add_ass);
+ return common_chat_apply_template(this->get_model(), "", chat_msgs, add_ass);
}

/*
@@ -426,7 +425,7 @@ void Llama::update_loras(std::vector<struct lora> loras) {
}
}

- llama_lora_adapters_apply(this->ctx, this->lora_adapters);
+ common_lora_adapters_apply(this->ctx, this->lora_adapters);
}

/*
@@ -437,12 +436,12 @@ void Llama::update_loras(std::vector<struct lora> loras) {
response_output Llama::generate_response(const std::string &input_prompt,
GenerateResponseCallback callback,
std::vector<std::string> stop) {
- struct gpt_sampler_params sparams;
+ struct common_sampler_params sparams;
return this->generate_response(input_prompt, sparams, callback, stop);
}

response_output Llama::generate_response(const std::string &input_prompt,
- struct gpt_sampler_params sparams,
+ struct common_sampler_params sparams,
GenerateResponseCallback callback,
std::vector<std::string> stop) {

@@ -464,10 +463,10 @@ response_output Llama::generate_response(const std::string &input_prompt,
this->params.sparams = sparams;

if (this->sampler != nullptr) {
- gpt_sampler_free(this->sampler);
+ common_sampler_free(this->sampler);
}

- this->sampler = gpt_sampler_init(this->model, this->params.sparams);
+ this->sampler = common_sampler_init(this->model, this->params.sparams);

if (this->sampler == nullptr) {
output.stop = stop_type::ABORT;
@@ -481,7 +480,7 @@ response_output Llama::generate_response(const std::string &input_prompt,
if (this->debug) {
LLAMA_LOG_INFO("Sampler params: %s", this->params.sparams.print().c_str());
LLAMA_LOG_INFO("Sampler constr: %s",
- gpt_sampler_print(this->sampler).c_str());
+ common_sampler_print(this->sampler).c_str());

LLAMA_LOG_INFO("Prompt tokens:\n%s",
this->detokenize(this->prompt_tokens).c_str());
@@ -540,7 +539,7 @@ response_output Llama::generate_response(const std::string &input_prompt,
LLAMA_LOG_INFO("Finish Response Generation");

if (this->debug) {
- gpt_perf_print(this->ctx, this->sampler);
+ common_perf_print(this->ctx, this->sampler);
}

output.completions = response;
@@ -575,7 +574,7 @@ void Llama::load_prompt(const std::string &input_prompt, bool add_pfx,

const int n_prev = 64;
const std::string last_output =
- gpt_sampler_prev_str(this->sampler, this->ctx, n_prev);
+ common_sampler_prev_str(this->sampler, this->ctx, n_prev);

// check if prefix is already added
if (last_output.find(
@@ -610,7 +609,7 @@ Llama::find_stop(std::vector<struct completion_output> completion_result_list,
// check if stopping word appear at the end of the output
const int n_prev = 32;
const std::string last_output =
- gpt_sampler_prev_str(this->sampler, this->ctx, n_prev);
+ common_sampler_prev_str(this->sampler, this->ctx, n_prev);

for (auto w : stopping_words) {
if (last_output.find(w.c_str(), last_output.length() - w.length(),
@@ -621,7 +620,7 @@ Llama::find_stop(std::vector<struct completion_output> completion_result_list,
}

// eos
- if (llama_token_is_eog(this->model, gpt_sampler_last(this->sampler))) {
+ if (llama_token_is_eog(this->model, common_sampler_last(this->sampler))) {
LLAMA_LOG_INFO("Stopping with EOS");
return FULL_STOP;
}
@@ -716,7 +715,8 @@ bool Llama::eval_prompt(std::vector<llama_token> prompt_tokens) {
((int)batch.size() < this->params.n_batch)) {

batch.push_back(prompt_tokens[this->n_consumed]);
- gpt_sampler_accept(this->sampler, prompt_tokens[this->n_consumed], false);
+ common_sampler_accept(this->sampler, prompt_tokens[this->n_consumed],
+                       false);
++this->n_consumed;
}

@@ -836,7 +836,7 @@ bool Llama::eval(struct llama_batch batch) {
std::vector<token_prob> Llama::get_probs() {
std::vector<token_prob> probs;

- const auto *cur_p = gpt_sampler_get_candidates(this->sampler);
+ const auto *cur_p = common_sampler_get_candidates(this->sampler);

const int32_t n_probs = this->params.sparams.n_probs;

@@ -853,8 +853,8 @@ struct completion_output Llama::sample() {
struct completion_output Llama::sample() {

// sample token
- llama_token id = gpt_sampler_sample(this->sampler, this->ctx, -1);
- gpt_sampler_accept(this->sampler, id, true);
+ llama_token id = common_sampler_sample(this->sampler, this->ctx, -1);
+ common_sampler_accept(this->sampler, id, true);

// create output
struct completion_output result;
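Taken together, the calls above track the upstream rename of the gpt_* helpers to a common_* prefix. A minimal sketch (not from this commit) of the renamed sampling loop, assuming the llama.cpp common API at tag b3906; the common.h / sampling.h include paths are assumptions:

#include <string>

#include "common.h"    // assumed: common_params, common_batch_add, common_token_to_piece
#include "sampling.h"  // assumed: common_sampler_*

// Sample up to n tokens from an already-evaluated context holding n_past tokens.
std::string sample_tokens(struct llama_model *model, struct llama_context *ctx,
                          const struct common_params &params, int n_past, int n) {
  // was: gpt_sampler_init
  struct common_sampler *sampler = common_sampler_init(model, params.sparams);

  std::string out;
  for (int i = 0; i < n; i++) {
    // was: gpt_sampler_sample / gpt_sampler_accept
    llama_token id = common_sampler_sample(sampler, ctx, -1);
    common_sampler_accept(sampler, id, true);
    out += common_token_to_piece(ctx, id);  // was: llama_token_to_piece

    if (llama_token_is_eog(model, id)) {
      break;
    }

    // feed the sampled token back to get logits for the next step
    struct llama_batch batch = llama_batch_init(1, 0, 1);
    common_batch_add(batch, id, n_past + i, {0}, true);  // was: llama_batch_add
    const bool failed = llama_decode(ctx, batch) != 0;
    llama_batch_free(batch);
    if (failed) {
      break;
    }
  }

  common_sampler_free(sampler);  // was: gpt_sampler_free
  return out;
}
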
6 changes: 3 additions & 3 deletions llama_ros/src/llama_ros/llama_node.cpp
@@ -279,9 +279,9 @@ void LlamaNode::format_chat_service_callback(
const std::shared_ptr<llama_msgs::srv::FormatChatMessages::Request> request,
std::shared_ptr<llama_msgs::srv::FormatChatMessages::Response> response) {

- std::vector<struct llama_chat_msg> converted_messages;
+ std::vector<struct common_chat_msg> converted_messages;
for (auto message : request->messages) {
- struct llama_chat_msg aux;
+ struct common_chat_msg aux;
aux.role = message.role.c_str();
aux.content = message.content.c_str();

@@ -401,7 +401,7 @@ void LlamaNode::execute(
this->llama->reset();
}

- // update sampling params of gpt_params
+ // update sampling params of common_params
auto sampling_config = goal_handle->get_goal()->sampling_config;
auto sparams = llama_utils::parse_sampling_params(sampling_config,
this->llama->get_n_vocab());
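The callback above maps ROS 2 chat messages onto the renamed common_chat_msg type (was llama_chat_msg). A standalone sketch (not from this commit) of the same conversion against Llama::format_chat_prompt as declared in llama.hpp:

#include <string>
#include <vector>

#include "llama_ros/llama.hpp"

std::string build_prompt(llama_ros::Llama &llama) {
  std::vector<struct common_chat_msg> msgs;

  struct common_chat_msg system_msg;
  system_msg.role = "system";
  system_msg.content = "You are a helpful assistant.";
  msgs.push_back(system_msg);

  struct common_chat_msg user_msg;
  user_msg.role = "user";
  user_msg.content = "Hello!";
  msgs.push_back(user_msg);

  // add_ass = true appends the assistant prefix of the model's chat template
  return llama.format_chat_prompt(msgs, true);
}
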
6 changes: 3 additions & 3 deletions llama_ros/src/llama_utils/llama_params.cpp
@@ -404,10 +404,10 @@ enum ggml_sched_priority llama_utils::parse_priority(std::string priority) {
return GGML_SCHED_PRIO_NORMAL;
}

- struct gpt_sampler_params llama_utils::parse_sampling_params(
+ struct common_sampler_params llama_utils::parse_sampling_params(
const llama_msgs::msg::SamplingConfig &sampling_config, int n_vocab) {

- struct gpt_sampler_params sparams;
+ struct common_sampler_params sparams;

sparams.n_prev = sampling_config.n_prev;
sparams.n_probs = sampling_config.n_probs;
@@ -432,7 +432,7 @@ struct gpt_sampler_params llama_utils::parse_sampling_params(
sparams.penalize_nl = sampling_config.penalize_nl;

sparams.samplers =
- gpt_sampler_types_from_chars(sampling_config.samplers_sequence);
+ common_sampler_types_from_chars(sampling_config.samplers_sequence);
sparams.grammar = sampling_config.grammar;

if (sparams.grammar.size() == 0 &&
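A small sketch (not from this commit) of filling the renamed common_sampler_params (was gpt_sampler_params) the way parse_sampling_params does above; only fields visible in this diff are set, and the sampling.h include path is an assumption:

#include <string>

#include "sampling.h"  // assumed location of common_sampler_params

struct common_sampler_params make_sparams(const std::string &samplers_sequence,
                                          const std::string &grammar) {
  struct common_sampler_params sparams;

  sparams.n_prev = 64;          // fields copied by parse_sampling_params above
  sparams.n_probs = 1;
  sparams.penalize_nl = false;
  sparams.grammar = grammar;

  // was: gpt_sampler_types_from_chars
  sparams.samplers = common_sampler_types_from_chars(samplers_sequence);

  return sparams;
}
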
2 changes: 1 addition & 1 deletion llama_ros/src/llava_ros/llava.cpp
@@ -30,7 +30,7 @@

using namespace llava_ros;

- Llava::Llava(const struct gpt_params &params,
+ Llava::Llava(const struct common_params &params,
const struct llava_params &llava_params, bool debug)
: llama_ros::Llama(params, debug), llava_params(llava_params) {
