feat: sync llama.cpp (#29)
* feat: sync llama.cpp

* fix: handle sampling init failure

* feat: sync llama.cpp
jhen0409 authored Oct 23, 2023
1 parent dacc66f commit adb7545
Showing 18 changed files with 303 additions and 247 deletions.
24 changes: 12 additions & 12 deletions android/src/main/java/com/rnllama/LlamaContext.java
@@ -149,14 +149,14 @@ public WritableMap completion(ReadableMap params) {
params.hasKey("n_predict") ? params.getInt("n_predict") : -1,
// int n_probs,
params.hasKey("n_probs") ? params.getInt("n_probs") : 0,
// int repeat_last_n,
params.hasKey("repeat_last_n") ? params.getInt("repeat_last_n") : 64,
// float repeat_penalty,
params.hasKey("repeat_penalty") ? (float) params.getDouble("repeat_penalty") : 1.10f,
// float presence_penalty,
params.hasKey("presence_penalty") ? (float) params.getDouble("presence_penalty") : 0.00f,
// float frequency_penalty,
params.hasKey("frequency_penalty") ? (float) params.getDouble("frequency_penalty") : 0.00f,
// int penalty_last_n,
params.hasKey("penalty_last_n") ? params.getInt("penalty_last_n") : 64,
// float penalty_repeat,
params.hasKey("penalty_repeat") ? (float) params.getDouble("penalty_repeat") : 1.10f,
// float penalty_freq,
params.hasKey("penalty_freq") ? (float) params.getDouble("penalty_freq") : 0.00f,
// float penalty_present,
params.hasKey("penalty_present") ? (float) params.getDouble("penalty_present") : 0.00f,
// float mirostat,
params.hasKey("mirostat") ? (float) params.getDouble("mirostat") : 0.00f,
// float mirostat_tau,
@@ -307,10 +307,10 @@ protected static native WritableMap doCompletion(
int n_threads,
int n_predict,
int n_probs,
int repeat_last_n,
float repeat_penalty,
float presence_penalty,
float frequency_penalty,
int penalty_last_n,
float penalty_repeat,
float penalty_freq,
float penalty_present,
float mirostat,
float mirostat_tau,
float mirostat_eta,
27 changes: 13 additions & 14 deletions android/src/main/jni.cpp
@@ -277,10 +277,10 @@ Java_com_rnllama_LlamaContext_doCompletion(
jint n_threads,
jint n_predict,
jint n_probs,
jint repeat_last_n,
jfloat repeat_penalty,
jfloat presence_penalty,
jfloat frequency_penalty,
jint penalty_last_n,
jfloat penalty_repeat,
jfloat penalty_freq,
jfloat penalty_present,
jfloat mirostat,
jfloat mirostat_tau,
jfloat mirostat_eta,
@@ -301,7 +301,6 @@ Java_com_rnllama_LlamaContext_doCompletion(
llama_reset_timings(llama->ctx);

llama->params.prompt = env->GetStringUTFChars(prompt, nullptr);
llama->params.grammar = env->GetStringUTFChars(grammar, nullptr);

int max_threads = std::thread::hardware_concurrency();
// Use 2 threads by default on 4-core devices, 4 threads on more cores
@@ -311,12 +310,12 @@ Java_com_rnllama_LlamaContext_doCompletion(
llama->params.n_predict = n_predict;
llama->params.ignore_eos = ignore_eos;

auto & sparams = llama->params.sampling_params;
auto & sparams = llama->params.sparams;
sparams.temp = temperature;
sparams.repeat_last_n = repeat_last_n;
sparams.repeat_penalty = repeat_penalty;
sparams.presence_penalty = presence_penalty;
sparams.frequency_penalty = frequency_penalty;
sparams.penalty_last_n = penalty_last_n;
sparams.penalty_repeat = penalty_repeat;
sparams.penalty_freq = penalty_freq;
sparams.penalty_present = penalty_present;
sparams.mirostat = mirostat;
sparams.mirostat_tau = mirostat_tau;
sparams.mirostat_eta = mirostat_eta;
@@ -325,6 +324,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
sparams.tfs_z = tfs_z;
sparams.typical_p = typical_p;
sparams.n_probs = n_probs;
sparams.grammar = env->GetStringUTFChars(grammar, nullptr);

sparams.logit_bias.clear();
if (ignore_eos) {
@@ -362,12 +362,11 @@ Java_com_rnllama_LlamaContext_doCompletion(
env->ReleaseStringUTFChars(stop_str, stop_chars);
}

if (!llama->loadGrammar()) {
if (!llama->initSampling()) {
auto result = createWriteableMap(env);
putString(env, result, "error", "Failed to load grammar");
putString(env, result, "error", "Failed to initialize sampling");
return reinterpret_cast<jobject>(result);
}

llama->loadPrompt();
llama->beginCompletion();

@@ -413,7 +412,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
auto tokenResult = createWriteableMap(env);
putString(env, tokenResult, "token", to_send.c_str());

if (llama->params.sampling_params.n_probs > 0) {
if (llama->params.sparams.n_probs > 0) {
const std::vector<llama_token> to_send_toks = llama_tokenize(llama->ctx, to_send, false);
size_t probs_pos = std::min(sent_token_probs_index, llama->generated_token_probs.size());
size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size());
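The jni.cpp change above is the "handle sampling init failure" fix from the commit message: the old loadGrammar() call is folded into a single initSampling() guard, and a setup failure is returned to the JavaScript side as an error result instead of continuing into completion. A minimal C++ sketch of that guarded flow, using only the calls visible in the diff (the stand-in context type and the helper name are illustrative assumptions, not the real implementation):

#include <string>

// Stand-in for the rn-llama context used in jni.cpp (illustrative only;
// the real context type is defined elsewhere in this repo).
struct llama_rn_context_like {
    bool initSampling();     // sets up sampler + grammar, as shown above
    void loadPrompt();
    void beginCompletion();
};

// Guarded start: report an error message instead of generating.
static bool start_completion(llama_rn_context_like * llama, std::string & error) {
    if (!llama->initSampling()) {
        error = "Failed to initialize sampling";  // surfaced to the JS caller
        return false;
    }
    llama->loadPrompt();
    llama->beginCompletion();
    return true;
}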
4 changes: 2 additions & 2 deletions cpp/build-info.h
@@ -1,8 +1,8 @@
#ifndef BUILD_INFO_H
#define BUILD_INFO_H

#define BUILD_NUMBER 1399
#define BUILD_COMMIT "004797f"
#define BUILD_NUMBER 1414
#define BUILD_COMMIT "96981f3"
#define BUILD_COMPILER ""
#define BUILD_TARGET "unknown"

70 changes: 36 additions & 34 deletions cpp/common.cpp
@@ -107,7 +107,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
std::string arg;
gpt_params default_params;
const std::string arg_prefix = "--";
llama_sampling_params & sparams = params.sampling_params;
llama_sampling_params & sparams = params.sparams;

for (int i = 1; i < argc; i++) {
arg = argv[i];
@@ -241,25 +241,26 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
invalid_param = true;
break;
}
sparams.repeat_last_n = std::stoi(argv[i]);
sparams.penalty_last_n = std::stoi(argv[i]);
sparams.n_prev = std::max(sparams.n_prev, sparams.penalty_last_n);
} else if (arg == "--repeat-penalty") {
if (++i >= argc) {
invalid_param = true;
break;
}
sparams.repeat_penalty = std::stof(argv[i]);
sparams.penalty_repeat = std::stof(argv[i]);
} else if (arg == "--frequency-penalty") {
if (++i >= argc) {
invalid_param = true;
break;
}
sparams.frequency_penalty = std::stof(argv[i]);
sparams.penalty_freq = std::stof(argv[i]);
} else if (arg == "--presence-penalty") {
if (++i >= argc) {
invalid_param = true;
break;
}
sparams.presence_penalty = std::stof(argv[i]);
sparams.penalty_present = std::stof(argv[i]);
} else if (arg == "--mirostat") {
if (++i >= argc) {
invalid_param = true;
@@ -572,7 +573,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
invalid_param = true;
break;
}
params.grammar = argv[i];
sparams.grammar = argv[i];
} else if (arg == "--grammar-file") {
if (++i >= argc) {
invalid_param = true;
@@ -587,7 +588,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
std::copy(
std::istreambuf_iterator<char>(file),
std::istreambuf_iterator<char>(),
std::back_inserter(params.grammar)
std::back_inserter(sparams.grammar)
);
#ifndef LOG_DISABLE_LOGS
// Parse args for logging parameters
@@ -631,6 +632,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
process_escapes(params.prompt);
process_escapes(params.input_prefix);
process_escapes(params.input_suffix);
process_escapes(sparams.cfg_negative_prompt);
for (auto & antiprompt : params.antiprompt) {
process_escapes(antiprompt);
}
@@ -640,7 +642,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
}

void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
const llama_sampling_params & sparams = params.sampling_params;
const llama_sampling_params & sparams = params.sparams;

printf("usage: %s [options]\n", argv[0]);
printf("\n");
@@ -678,10 +680,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z);
printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p);
printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.repeat_last_n);
printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.repeat_penalty);
printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.presence_penalty);
printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.frequency_penalty);
printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.penalty_last_n);
printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.penalty_repeat);
printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_present);
printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_freq);
printf(" --mirostat N use Mirostat sampling.\n");
printf(" Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
printf(" (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", sparams.mirostat);
@@ -878,7 +880,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
}

if (params.ignore_eos) {
params.sampling_params.logit_bias[llama_token_eos(lctx)] = -INFINITY;
params.sparams.logit_bias[llama_token_eos(lctx)] = -INFINITY;
}

{
Expand Down Expand Up @@ -1123,28 +1125,28 @@ std::string get_sortable_timestamp() {

void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const llama_context * lctx,
const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc) {
const llama_sampling_params & sparams = params.sampling_params;
const llama_sampling_params & sparams = params.sparams;

fprintf(stream, "build_commit: %s\n", BUILD_COMMIT);
fprintf(stream, "build_number: %d\n", BUILD_NUMBER);
fprintf(stream, "cpu_has_arm_fma: %s\n", lm_ggml_cpu_has_arm_fma() ? "true" : "false");
fprintf(stream, "cpu_has_avx: %s\n", lm_ggml_cpu_has_avx() ? "true" : "false");
fprintf(stream, "cpu_has_avx2: %s\n", lm_ggml_cpu_has_avx2() ? "true" : "false");
fprintf(stream, "cpu_has_avx512: %s\n", lm_ggml_cpu_has_avx512() ? "true" : "false");
fprintf(stream, "cpu_has_arm_fma: %s\n", lm_ggml_cpu_has_arm_fma() ? "true" : "false");
fprintf(stream, "cpu_has_avx: %s\n", lm_ggml_cpu_has_avx() ? "true" : "false");
fprintf(stream, "cpu_has_avx2: %s\n", lm_ggml_cpu_has_avx2() ? "true" : "false");
fprintf(stream, "cpu_has_avx512: %s\n", lm_ggml_cpu_has_avx512() ? "true" : "false");
fprintf(stream, "cpu_has_avx512_vbmi: %s\n", lm_ggml_cpu_has_avx512_vbmi() ? "true" : "false");
fprintf(stream, "cpu_has_avx512_vnni: %s\n", lm_ggml_cpu_has_avx512_vnni() ? "true" : "false");
fprintf(stream, "cpu_has_blas: %s\n", lm_ggml_cpu_has_blas() ? "true" : "false");
fprintf(stream, "cpu_has_cublas: %s\n", lm_ggml_cpu_has_cublas() ? "true" : "false");
fprintf(stream, "cpu_has_clblast: %s\n", lm_ggml_cpu_has_clblast() ? "true" : "false");
fprintf(stream, "cpu_has_fma: %s\n", lm_ggml_cpu_has_fma() ? "true" : "false");
fprintf(stream, "cpu_has_gpublas: %s\n", lm_ggml_cpu_has_gpublas() ? "true" : "false");
fprintf(stream, "cpu_has_neon: %s\n", lm_ggml_cpu_has_neon() ? "true" : "false");
fprintf(stream, "cpu_has_f16c: %s\n", lm_ggml_cpu_has_f16c() ? "true" : "false");
fprintf(stream, "cpu_has_fp16_va: %s\n", lm_ggml_cpu_has_fp16_va() ? "true" : "false");
fprintf(stream, "cpu_has_wasm_simd: %s\n", lm_ggml_cpu_has_wasm_simd() ? "true" : "false");
fprintf(stream, "cpu_has_blas: %s\n", lm_ggml_cpu_has_blas() ? "true" : "false");
fprintf(stream, "cpu_has_sse3: %s\n", lm_ggml_cpu_has_sse3() ? "true" : "false");
fprintf(stream, "cpu_has_vsx: %s\n", lm_ggml_cpu_has_vsx() ? "true" : "false");
fprintf(stream, "cpu_has_blas: %s\n", lm_ggml_cpu_has_blas() ? "true" : "false");
fprintf(stream, "cpu_has_cublas: %s\n", lm_ggml_cpu_has_cublas() ? "true" : "false");
fprintf(stream, "cpu_has_clblast: %s\n", lm_ggml_cpu_has_clblast() ? "true" : "false");
fprintf(stream, "cpu_has_fma: %s\n", lm_ggml_cpu_has_fma() ? "true" : "false");
fprintf(stream, "cpu_has_gpublas: %s\n", lm_ggml_cpu_has_gpublas() ? "true" : "false");
fprintf(stream, "cpu_has_neon: %s\n", lm_ggml_cpu_has_neon() ? "true" : "false");
fprintf(stream, "cpu_has_f16c: %s\n", lm_ggml_cpu_has_f16c() ? "true" : "false");
fprintf(stream, "cpu_has_fp16_va: %s\n", lm_ggml_cpu_has_fp16_va() ? "true" : "false");
fprintf(stream, "cpu_has_wasm_simd: %s\n", lm_ggml_cpu_has_wasm_simd() ? "true" : "false");
fprintf(stream, "cpu_has_blas: %s\n", lm_ggml_cpu_has_blas() ? "true" : "false");
fprintf(stream, "cpu_has_sse3: %s\n", lm_ggml_cpu_has_sse3() ? "true" : "false");
fprintf(stream, "cpu_has_vsx: %s\n", lm_ggml_cpu_has_vsx() ? "true" : "false");

#ifdef NDEBUG
fprintf(stream, "debug: false\n");
@@ -1178,8 +1180,8 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.frequency_penalty);
dump_string_yaml_multiline(stream, "grammar", params.grammar.c_str());
fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq);
dump_string_yaml_multiline(stream, "grammar", sparams.grammar.c_str());
fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n");
fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks);
@@ -1238,14 +1240,14 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "numa: %s # default: false\n", params.numa ? "true" : "false");
fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type);
fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride);
fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.presence_penalty);
fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.penalty_present);
dump_string_yaml_multiline(stream, "prompt", params.prompt.c_str());
fprintf(stream, "prompt_cache: %s\n", params.path_prompt_cache.c_str());
fprintf(stream, "prompt_cache_all: %s # default: false\n", params.prompt_cache_all ? "true" : "false");
fprintf(stream, "prompt_cache_ro: %s # default: false\n", params.prompt_cache_ro ? "true" : "false");
dump_vector_int_yaml(stream, "prompt_tokens", prompt_tokens);
fprintf(stream, "random_prompt: %s # default: false\n", params.random_prompt ? "true" : "false");
fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.repeat_penalty);
fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.penalty_repeat);

fprintf(stream, "reverse_prompt:\n");
for (std::string ap : params.antiprompt) {
3 changes: 1 addition & 2 deletions cpp/common.h
@@ -56,7 +56,7 @@ struct gpt_params {
float rope_freq_scale = 0.0f; // RoPE frequency scaling factor

// // sampling parameters
struct llama_sampling_params sampling_params;
struct llama_sampling_params sparams;

std::string model = "models/7B/ggml-model-f16.gguf"; // model path
std::string model_draft = ""; // draft model for speculative decoding
@@ -66,7 +66,6 @@ struct gpt_params {
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
std::string input_prefix = ""; // string to prefix user inputs with
std::string input_suffix = ""; // string to suffix user inputs with
std::string grammar = ""; // optional BNF-like grammar to constrain sampling
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
std::string logdir = ""; // directory in which to save YAML log files

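The common.h hunk above is the heart of the rename that runs through this sync: gpt_params::sampling_params becomes gpt_params::sparams, the standalone grammar string moves into the sampling params, and the penalty fields gain a penalty_ prefix. A hedged migration sketch, with the structs reduced to just the fields touched in this diff (defaults taken from the values used in LlamaContext.java; the real definitions come from the synced llama.cpp headers):

#include <string>

// Reduced stand-ins for illustration only.
struct sampling_params_like {
    int         penalty_last_n  = 64;    // was: repeat_last_n
    float       penalty_repeat  = 1.10f; // was: repeat_penalty
    float       penalty_freq    = 0.00f; // was: frequency_penalty
    float       penalty_present = 0.00f; // was: presence_penalty
    std::string grammar;                 // was: gpt_params::grammar
};

struct gpt_params_like {
    sampling_params_like sparams;        // was: sampling_params
};

// Existing call sites migrate accordingly, for example:
//   params.sampling_params.repeat_penalty  ->  params.sparams.penalty_repeat
//   params.grammar                          ->  params.sparams.grammar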
22 changes: 13 additions & 9 deletions cpp/ggml.c
@@ -13537,7 +13537,7 @@ static void lm_ggml_compute_forward_rope_f16(
dst_data[n_dims] = LM_GGML_FP32_TO_FP16(x2*cos_block_theta - x3*sin_block_theta);
dst_data[n_dims/2*3] = LM_GGML_FP32_TO_FP16(x2*sin_block_theta + x3*cos_block_theta);
}
} if (!is_neox) {
} else if (!is_neox) {
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
const float cos_theta = cosf(theta);
const float sin_theta = sinf(theta);
@@ -19170,6 +19170,7 @@ void lm_ggml_graph_export(const struct lm_ggml_cgraph * cgraph, const char * fna

if (idx == -1) {
fprintf(stderr, "%s: failed to find tensor, arg = %d, node = %d\n", __func__, j, i);
fclose(fout);
return;
}

@@ -20844,7 +20845,7 @@ struct gguf_kv {
};

struct gguf_header {
uint32_t magic;
char magic[4];
uint32_t version;
uint64_t n_tensors; // GGUFv2
uint64_t n_kv; // GGUFv2
@@ -20914,7 +20915,7 @@ static bool gguf_fread_str_v1(FILE * file, struct gguf_str * p, size_t * offset)
struct gguf_context * gguf_init_empty(void) {
struct gguf_context * ctx = LM_GGML_ALIGNED_MALLOC(sizeof(struct gguf_context));

ctx->header.magic = GGUF_MAGIC;
memcpy(ctx->header.magic, GGUF_MAGIC, sizeof(ctx->header.magic));
ctx->header.version = GGUF_VERSION;
ctx->header.n_tensors = 0;
ctx->header.n_kv = 0;
@@ -20940,16 +20941,18 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
// offset from start of file
size_t offset = 0;

uint32_t magic = 0;
char magic[4];

// check the magic before making allocations
{
gguf_fread_el(file, &magic, sizeof(magic), &offset);

if (magic != GGUF_MAGIC) {
fprintf(stderr, "%s: invalid magic number %08x\n", __func__, magic);
fclose(file);
return NULL;
for (uint32_t i = 0; i < sizeof(magic); i++) {
if (magic[i] != GGUF_MAGIC[i]) {
fprintf(stderr, "%s: invalid magic characters %s.\n", __func__, magic);
fclose(file);
return NULL;
}
}
}

@@ -20959,7 +20962,8 @@

// read the header
{
ctx->header.magic = magic;
strncpy(ctx->header.magic, magic, 4);


ctx->kv = NULL;
ctx->infos = NULL;
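The ggml.c hunks tighten two things: the rope f16 path gains its missing else, and the GGUF header magic is handled as four characters compared byte by byte instead of a packed uint32_t, so the check no longer depends on host byte order (the graph exporter also closes its output file on the missing-tensor error path). A standalone sketch of the byte-wise check, assuming GGUF_MAGIC is the string "GGUF" as in this llama.cpp version:

#include <cstdio>

#define GGUF_MAGIC "GGUF"  // assumed definition for this sketch

// Read and validate the 4-byte GGUF magic at the current file position.
static bool gguf_check_magic(std::FILE * file) {
    char magic[4];
    if (std::fread(magic, 1, sizeof(magic), file) != sizeof(magic)) {
        return false;                    // short read: not a GGUF file
    }
    for (int i = 0; i < 4; i++) {
        if (magic[i] != GGUF_MAGIC[i]) {
            return false;                // byte-wise compare, endian-independent
        }
    }
    return true;
}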