feat: sync llama.cpp #29

Merged 3 commits on Oct 23, 2023
24 changes: 12 additions & 12 deletions android/src/main/java/com/rnllama/LlamaContext.java
@@ -149,14 +149,14 @@ public WritableMap completion(ReadableMap params) {
params.hasKey("n_predict") ? params.getInt("n_predict") : -1,
// int n_probs,
params.hasKey("n_probs") ? params.getInt("n_probs") : 0,
- // int repeat_last_n,
- params.hasKey("repeat_last_n") ? params.getInt("repeat_last_n") : 64,
- // float repeat_penalty,
- params.hasKey("repeat_penalty") ? (float) params.getDouble("repeat_penalty") : 1.10f,
- // float presence_penalty,
- params.hasKey("presence_penalty") ? (float) params.getDouble("presence_penalty") : 0.00f,
- // float frequency_penalty,
- params.hasKey("frequency_penalty") ? (float) params.getDouble("frequency_penalty") : 0.00f,
+ // int penalty_last_n,
+ params.hasKey("penalty_last_n") ? params.getInt("penalty_last_n") : 64,
+ // float penalty_repeat,
+ params.hasKey("penalty_repeat") ? (float) params.getDouble("penalty_repeat") : 1.10f,
+ // float penalty_freq,
+ params.hasKey("penalty_freq") ? (float) params.getDouble("penalty_freq") : 0.00f,
+ // float penalty_present,
+ params.hasKey("penalty_present") ? (float) params.getDouble("penalty_present") : 0.00f,
// float mirostat,
params.hasKey("mirostat") ? (float) params.getDouble("mirostat") : 0.00f,
// float mirostat_tau,
@@ -307,10 +307,10 @@ protected static native WritableMap doCompletion(
int n_threads,
int n_predict,
int n_probs,
- int repeat_last_n,
- float repeat_penalty,
- float presence_penalty,
- float frequency_penalty,
+ int penalty_last_n,
+ float penalty_repeat,
+ float penalty_freq,
+ float penalty_present,
float mirostat,
float mirostat_tau,
float mirostat_eta,
27 changes: 13 additions & 14 deletions android/src/main/jni.cpp
@@ -277,10 +277,10 @@ Java_com_rnllama_LlamaContext_doCompletion(
jint n_threads,
jint n_predict,
jint n_probs,
- jint repeat_last_n,
- jfloat repeat_penalty,
- jfloat presence_penalty,
- jfloat frequency_penalty,
+ jint penalty_last_n,
+ jfloat penalty_repeat,
+ jfloat penalty_freq,
+ jfloat penalty_present,
jfloat mirostat,
jfloat mirostat_tau,
jfloat mirostat_eta,
@@ -301,7 +301,6 @@ Java_com_rnllama_LlamaContext_doCompletion(
llama_reset_timings(llama->ctx);

llama->params.prompt = env->GetStringUTFChars(prompt, nullptr);
- llama->params.grammar = env->GetStringUTFChars(grammar, nullptr);

int max_threads = std::thread::hardware_concurrency();
// Use 2 threads by default on 4-core devices, 4 threads on more cores
@@ -311,12 +310,12 @@
llama->params.n_predict = n_predict;
llama->params.ignore_eos = ignore_eos;

- auto & sparams = llama->params.sampling_params;
+ auto & sparams = llama->params.sparams;
sparams.temp = temperature;
- sparams.repeat_last_n = repeat_last_n;
- sparams.repeat_penalty = repeat_penalty;
- sparams.presence_penalty = presence_penalty;
- sparams.frequency_penalty = frequency_penalty;
+ sparams.penalty_last_n = penalty_last_n;
+ sparams.penalty_repeat = penalty_repeat;
+ sparams.penalty_freq = penalty_freq;
+ sparams.penalty_present = penalty_present;
sparams.mirostat = mirostat;
sparams.mirostat_tau = mirostat_tau;
sparams.mirostat_eta = mirostat_eta;
@@ -325,6 +324,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
sparams.tfs_z = tfs_z;
sparams.typical_p = typical_p;
sparams.n_probs = n_probs;
+ sparams.grammar = env->GetStringUTFChars(grammar, nullptr);

sparams.logit_bias.clear();
if (ignore_eos) {
@@ -362,12 +362,11 @@ Java_com_rnllama_LlamaContext_doCompletion(
env->ReleaseStringUTFChars(stop_str, stop_chars);
}

- if (!llama->loadGrammar()) {
+ if (!llama->initSampling()) {
auto result = createWriteableMap(env);
putString(env, result, "error", "Failed to load grammar");
putString(env, result, "error", "Failed to initialize sampling");
return reinterpret_cast<jobject>(result);
}

llama->loadPrompt();
llama->beginCompletion();

@@ -413,7 +412,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
auto tokenResult = createWriteableMap(env);
putString(env, tokenResult, "token", to_send.c_str());

- if (llama->params.sampling_params.n_probs > 0) {
+ if (llama->params.sparams.n_probs > 0) {
const std::vector<llama_token> to_send_toks = llama_tokenize(llama->ctx, to_send, false);
size_t probs_pos = std::min(sent_token_probs_index, llama->generated_token_probs.size());
size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size());
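The JNI changes above track llama.cpp's sampling refactor: the penalty fields were renamed (`repeat_last_n` → `penalty_last_n`, `repeat_penalty` → `penalty_repeat`, `frequency_penalty` → `penalty_freq`, `presence_penalty` → `penalty_present`), the grammar string now lives on the sampling params rather than `gpt_params`, and grammar loading is folded into `initSampling()`. A minimal sketch of the new field names with the defaults `LlamaContext.java` falls back to when a key is absent (assumes the `llama_sampling_params` struct declared in `cpp/common.h` below; the helper itself is hypothetical):

```cpp
#include "common.h" // vendored llama.cpp common header (declares llama_sampling_params)

// Hypothetical helper: the renamed penalty fields, set to the same defaults
// LlamaContext.java uses when a key is missing from the params map.
static llama_sampling_params default_penalties() {
    llama_sampling_params sparams;
    sparams.penalty_last_n  = 64;    // was repeat_last_n
    sparams.penalty_repeat  = 1.10f; // was repeat_penalty
    sparams.penalty_freq    = 0.00f; // was frequency_penalty
    sparams.penalty_present = 0.00f; // was presence_penalty
    sparams.grammar         = "";    // moved here from gpt_params.grammar
    return sparams;
}
```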
4 changes: 2 additions & 2 deletions cpp/build-info.h
@@ -1,8 +1,8 @@
#ifndef BUILD_INFO_H
#define BUILD_INFO_H

- #define BUILD_NUMBER 1399
- #define BUILD_COMMIT "004797f"
+ #define BUILD_NUMBER 1414
+ #define BUILD_COMMIT "96981f3"
#define BUILD_COMPILER ""
#define BUILD_TARGET "unknown"

70 changes: 36 additions & 34 deletions cpp/common.cpp
@@ -107,7 +107,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
std::string arg;
gpt_params default_params;
const std::string arg_prefix = "--";
- llama_sampling_params & sparams = params.sampling_params;
+ llama_sampling_params & sparams = params.sparams;

for (int i = 1; i < argc; i++) {
arg = argv[i];
@@ -241,25 +241,26 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
invalid_param = true;
break;
}
- sparams.repeat_last_n = std::stoi(argv[i]);
+ sparams.penalty_last_n = std::stoi(argv[i]);
+ sparams.n_prev = std::max(sparams.n_prev, sparams.penalty_last_n);
} else if (arg == "--repeat-penalty") {
if (++i >= argc) {
invalid_param = true;
break;
}
- sparams.repeat_penalty = std::stof(argv[i]);
+ sparams.penalty_repeat = std::stof(argv[i]);
} else if (arg == "--frequency-penalty") {
if (++i >= argc) {
invalid_param = true;
break;
}
- sparams.frequency_penalty = std::stof(argv[i]);
+ sparams.penalty_freq = std::stof(argv[i]);
} else if (arg == "--presence-penalty") {
if (++i >= argc) {
invalid_param = true;
break;
}
- sparams.presence_penalty = std::stof(argv[i]);
+ sparams.penalty_present = std::stof(argv[i]);
} else if (arg == "--mirostat") {
if (++i >= argc) {
invalid_param = true;
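Only the struct members are renamed in this hunk; the user-facing flags (`--repeat-last-n`, `--repeat-penalty`, `--frequency-penalty`, `--presence-penalty`) keep their old spellings. The one behavioral addition is the `n_prev` guard: the sampler's remembered-token history must cover at least the penalty window. A sketch of that invariant (hypothetical setter, using only the fields shown in this diff):

```cpp
#include <algorithm>
#include "common.h" // declares llama_sampling_params (see cpp/common.h below)

// Hypothetical setter: widening the penalty window must also widen n_prev,
// the number of recent tokens the sampler keeps around for penalties.
static void set_penalty_window(llama_sampling_params & sparams, int last_n) {
    sparams.penalty_last_n = last_n;
    sparams.n_prev = std::max(sparams.n_prev, sparams.penalty_last_n);
}
```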
@@ -572,7 +573,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
invalid_param = true;
break;
}
- params.grammar = argv[i];
+ sparams.grammar = argv[i];
} else if (arg == "--grammar-file") {
if (++i >= argc) {
invalid_param = true;
@@ -587,7 +588,7 @@
std::copy(
std::istreambuf_iterator<char>(file),
std::istreambuf_iterator<char>(),
- std::back_inserter(params.grammar)
+ std::back_inserter(sparams.grammar)
);
#ifndef LOG_DISABLE_LOGS
// Parse args for logging parameters
@@ -631,6 +632,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
process_escapes(params.prompt);
process_escapes(params.input_prefix);
process_escapes(params.input_suffix);
+ process_escapes(sparams.cfg_negative_prompt);
for (auto & antiprompt : params.antiprompt) {
process_escapes(antiprompt);
}
@@ -640,7 +642,7 @@
}

void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
- const llama_sampling_params & sparams = params.sampling_params;
+ const llama_sampling_params & sparams = params.sparams;

printf("usage: %s [options]\n", argv[0]);
printf("\n");
@@ -678,10 +680,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z);
printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p);
printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.repeat_last_n);
printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.repeat_penalty);
printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.presence_penalty);
printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.frequency_penalty);
printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.penalty_last_n);
printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.penalty_repeat);
printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_present);
printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_freq);
printf(" --mirostat N use Mirostat sampling.\n");
printf(" Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
printf(" (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", sparams.mirostat);
@@ -878,7 +880,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
}

if (params.ignore_eos) {
- params.sampling_params.logit_bias[llama_token_eos(lctx)] = -INFINITY;
+ params.sparams.logit_bias[llama_token_eos(lctx)] = -INFINITY;
}

{
@@ -1123,28 +1125,28 @@ std::string get_sortable_timestamp() {

void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const llama_context * lctx,
const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc) {
- const llama_sampling_params & sparams = params.sampling_params;
+ const llama_sampling_params & sparams = params.sparams;

fprintf(stream, "build_commit: %s\n", BUILD_COMMIT);
fprintf(stream, "build_number: %d\n", BUILD_NUMBER);
fprintf(stream, "cpu_has_arm_fma: %s\n", lm_ggml_cpu_has_arm_fma() ? "true" : "false");
fprintf(stream, "cpu_has_avx: %s\n", lm_ggml_cpu_has_avx() ? "true" : "false");
fprintf(stream, "cpu_has_avx2: %s\n", lm_ggml_cpu_has_avx2() ? "true" : "false");
fprintf(stream, "cpu_has_avx512: %s\n", lm_ggml_cpu_has_avx512() ? "true" : "false");
fprintf(stream, "cpu_has_arm_fma: %s\n", lm_ggml_cpu_has_arm_fma() ? "true" : "false");
fprintf(stream, "cpu_has_avx: %s\n", lm_ggml_cpu_has_avx() ? "true" : "false");
fprintf(stream, "cpu_has_avx2: %s\n", lm_ggml_cpu_has_avx2() ? "true" : "false");
fprintf(stream, "cpu_has_avx512: %s\n", lm_ggml_cpu_has_avx512() ? "true" : "false");
fprintf(stream, "cpu_has_avx512_vbmi: %s\n", lm_ggml_cpu_has_avx512_vbmi() ? "true" : "false");
fprintf(stream, "cpu_has_avx512_vnni: %s\n", lm_ggml_cpu_has_avx512_vnni() ? "true" : "false");
fprintf(stream, "cpu_has_blas: %s\n", lm_ggml_cpu_has_blas() ? "true" : "false");
fprintf(stream, "cpu_has_cublas: %s\n", lm_ggml_cpu_has_cublas() ? "true" : "false");
fprintf(stream, "cpu_has_clblast: %s\n", lm_ggml_cpu_has_clblast() ? "true" : "false");
fprintf(stream, "cpu_has_fma: %s\n", lm_ggml_cpu_has_fma() ? "true" : "false");
fprintf(stream, "cpu_has_gpublas: %s\n", lm_ggml_cpu_has_gpublas() ? "true" : "false");
fprintf(stream, "cpu_has_neon: %s\n", lm_ggml_cpu_has_neon() ? "true" : "false");
fprintf(stream, "cpu_has_f16c: %s\n", lm_ggml_cpu_has_f16c() ? "true" : "false");
fprintf(stream, "cpu_has_fp16_va: %s\n", lm_ggml_cpu_has_fp16_va() ? "true" : "false");
fprintf(stream, "cpu_has_wasm_simd: %s\n", lm_ggml_cpu_has_wasm_simd() ? "true" : "false");
fprintf(stream, "cpu_has_blas: %s\n", lm_ggml_cpu_has_blas() ? "true" : "false");
fprintf(stream, "cpu_has_sse3: %s\n", lm_ggml_cpu_has_sse3() ? "true" : "false");
fprintf(stream, "cpu_has_vsx: %s\n", lm_ggml_cpu_has_vsx() ? "true" : "false");
fprintf(stream, "cpu_has_blas: %s\n", lm_ggml_cpu_has_blas() ? "true" : "false");
fprintf(stream, "cpu_has_cublas: %s\n", lm_ggml_cpu_has_cublas() ? "true" : "false");
fprintf(stream, "cpu_has_clblast: %s\n", lm_ggml_cpu_has_clblast() ? "true" : "false");
fprintf(stream, "cpu_has_fma: %s\n", lm_ggml_cpu_has_fma() ? "true" : "false");
fprintf(stream, "cpu_has_gpublas: %s\n", lm_ggml_cpu_has_gpublas() ? "true" : "false");
fprintf(stream, "cpu_has_neon: %s\n", lm_ggml_cpu_has_neon() ? "true" : "false");
fprintf(stream, "cpu_has_f16c: %s\n", lm_ggml_cpu_has_f16c() ? "true" : "false");
fprintf(stream, "cpu_has_fp16_va: %s\n", lm_ggml_cpu_has_fp16_va() ? "true" : "false");
fprintf(stream, "cpu_has_wasm_simd: %s\n", lm_ggml_cpu_has_wasm_simd() ? "true" : "false");
fprintf(stream, "cpu_has_blas: %s\n", lm_ggml_cpu_has_blas() ? "true" : "false");
fprintf(stream, "cpu_has_sse3: %s\n", lm_ggml_cpu_has_sse3() ? "true" : "false");
fprintf(stream, "cpu_has_vsx: %s\n", lm_ggml_cpu_has_vsx() ? "true" : "false");

#ifdef NDEBUG
fprintf(stream, "debug: false\n");
@@ -1178,8 +1180,8 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.frequency_penalty);
dump_string_yaml_multiline(stream, "grammar", params.grammar.c_str());
fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq);
dump_string_yaml_multiline(stream, "grammar", sparams.grammar.c_str());
fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n");
fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks);
@@ -1238,14 +1240,14 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "numa: %s # default: false\n", params.numa ? "true" : "false");
fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type);
fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride);
fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.presence_penalty);
fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.penalty_present);
dump_string_yaml_multiline(stream, "prompt", params.prompt.c_str());
fprintf(stream, "prompt_cache: %s\n", params.path_prompt_cache.c_str());
fprintf(stream, "prompt_cache_all: %s # default: false\n", params.prompt_cache_all ? "true" : "false");
fprintf(stream, "prompt_cache_ro: %s # default: false\n", params.prompt_cache_ro ? "true" : "false");
dump_vector_int_yaml(stream, "prompt_tokens", prompt_tokens);
fprintf(stream, "random_prompt: %s # default: false\n", params.random_prompt ? "true" : "false");
fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.repeat_penalty);
fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.penalty_repeat);

fprintf(stream, "reverse_prompt:\n");
for (std::string ap : params.antiprompt) {
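One hunk above is worth a note: `ignore_eos` still works by pinning the end-of-sequence token's logit bias to negative infinity, so after biasing no sampler, greedy or stochastic, can ever select it; only the path to the bias map changed (`params.sparams.logit_bias`). A minimal sketch of the same trick (hypothetical helper; `llama_token_eos(ctx)` is the context-taking form this version of llama.cpp uses):

```cpp
#include <cmath>     // INFINITY
#include "common.h"  // llama_sampling_params; pulls in llama.h

// Hypothetical helper: ban end-of-sequence by forcing its logit to -inf.
static void ban_eos(llama_sampling_params & sparams, llama_context * ctx) {
    sparams.logit_bias[llama_token_eos(ctx)] = -INFINITY;
}
```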
3 changes: 1 addition & 2 deletions cpp/common.h
@@ -56,7 +56,7 @@ struct gpt_params {
float rope_freq_scale = 0.0f; // RoPE frequency scaling factor

// // sampling parameters
- struct llama_sampling_params sampling_params;
+ struct llama_sampling_params sparams;

std::string model = "models/7B/ggml-model-f16.gguf"; // model path
std::string model_draft = ""; // draft model for speculative decoding
@@ -66,7 +66,6 @@
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
std::string input_prefix = ""; // string to prefix user inputs with
std::string input_suffix = ""; // string to suffix user inputs with
- std::string grammar = ""; // optional BNF-like grammar to constrain sampling
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
std::string logdir = ""; // directory in which to save YAML log files

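With `grammar` removed from `gpt_params`, any caller that loaded a GBNF grammar migrates by a one-field move. A sketch (the `read_file` helper and the `json.gbnf` path are illustrative, not from this repo):

```cpp
#include <fstream>
#include <sstream>
#include <string>

// Illustrative helper: slurp a GBNF grammar file into a string.
static std::string read_file(const std::string & path) {
    std::ifstream file(path);
    std::ostringstream ss;
    ss << file.rdbuf();
    return ss.str();
}

// Before this sync:  params.grammar         = read_file("json.gbnf");
// After this sync:   params.sparams.grammar = read_file("json.gbnf");
```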
22 changes: 13 additions & 9 deletions cpp/ggml.c
@@ -13537,7 +13537,7 @@ static void lm_ggml_compute_forward_rope_f16(
dst_data[n_dims] = LM_GGML_FP32_TO_FP16(x2*cos_block_theta - x3*sin_block_theta);
dst_data[n_dims/2*3] = LM_GGML_FP32_TO_FP16(x2*sin_block_theta + x3*cos_block_theta);
}
- } if (!is_neox) {
+ } else if (!is_neox) {
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
const float cos_theta = cosf(theta);
const float sin_theta = sinf(theta);
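The rope change is a one-word control-flow fix: the rotation branches are mutually exclusive, but without `else` the `!is_neox` path also ran after the preceding branch had already rotated the values, so data could be rotated twice. A trivial standalone illustration of the pattern:

```cpp
#include <cstdio>

int main() {
    const bool first = true, is_neox = false;
    if (first) {
        puts("first branch");
    } if (!is_neox) { // no `else`: this ALSO runs -- the bug pattern fixed above
        puts("second branch");
    }
    // With `} else if (!is_neox) {` only "first branch" would print.
    return 0;
}
```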
@@ -19170,6 +19170,7 @@ void lm_ggml_graph_export(const struct lm_ggml_cgraph * cgraph, const char * fna

if (idx == -1) {
fprintf(stderr, "%s: failed to find tensor, arg = %d, node = %d\n", __func__, j, i);
+ fclose(fout);
return;
}
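The added `fclose(fout)` plugs a file-handle leak on the early-return error path of `lm_ggml_graph_export`. The general shape of the fix as a standalone sketch (file name and payload are illustrative):

```cpp
#include <cstdio>

// Every early return after a successful fopen must release the handle;
// the hunk above adds exactly this to the export error path.
static bool write_file(const char * fname) {
    FILE * fout = fopen(fname, "wb");
    if (!fout) {
        return false;
    }
    if (fputs("payload\n", fout) == EOF) {
        fclose(fout); // don't leak the handle on failure
        return false;
    }
    fclose(fout);
    return true;
}
```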

@@ -20844,7 +20845,7 @@ struct gguf_kv {
};

struct gguf_header {
- uint32_t magic;
+ char magic[4];
uint32_t version;
uint64_t n_tensors; // GGUFv2
uint64_t n_kv; // GGUFv2
@@ -20914,7 +20915,7 @@ static bool gguf_fread_str_v1(FILE * file, struct gguf_str * p, size_t * offset)
struct gguf_context * gguf_init_empty(void) {
struct gguf_context * ctx = LM_GGML_ALIGNED_MALLOC(sizeof(struct gguf_context));

- ctx->header.magic = GGUF_MAGIC;
+ memcpy(ctx->header.magic, GGUF_MAGIC, sizeof(ctx->header.magic));
ctx->header.version = GGUF_VERSION;
ctx->header.n_tensors = 0;
ctx->header.n_kv = 0;
@@ -20940,16 +20941,18 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
// offset from start of file
size_t offset = 0;

- uint32_t magic = 0;
+ char magic[4];

// check the magic before making allocations
{
gguf_fread_el(file, &magic, sizeof(magic), &offset);

- if (magic != GGUF_MAGIC) {
- fprintf(stderr, "%s: invalid magic number %08x\n", __func__, magic);
- fclose(file);
- return NULL;
- }
+ for (uint32_t i = 0; i < sizeof(magic); i++) {
+ if (magic[i] != GGUF_MAGIC[i]) {
+ fprintf(stderr, "%s: invalid magic characters %s.\n", __func__, magic);
+ fclose(file);
+ return NULL;
+ }
+ }
}
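Storing the magic as `char[4]` instead of `uint32_t` makes the check byte-order independent: the bytes are compared one at a time exactly as they sit in the file, rather than being reinterpreted as an integer whose value depends on host endianness. A standalone sketch of the same validation (GGUF files begin with the ASCII bytes `G`, `G`, `U`, `F`):

```cpp
#include <cstdio>
#include <cstring>

// Minimal magic check, assuming only that GGUF files start with "GGUF".
static bool has_gguf_magic(const char * fname) {
    FILE * f = fopen(fname, "rb");
    if (!f) {
        return false;
    }
    char magic[4] = {0};
    const bool ok = fread(magic, 1, sizeof(magic), f) == sizeof(magic)
                 && memcmp(magic, "GGUF", sizeof(magic)) == 0;
    fclose(f);
    return ok;
}
```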

@@ -20959,7 +20962,8 @@

// read the header
{
- ctx->header.magic = magic;
+ strncpy(ctx->header.magic, magic, 4);


ctx->kv = NULL;
ctx->infos = NULL;