diff --git a/android/src/main/java/com/rnllama/LlamaContext.java b/android/src/main/java/com/rnllama/LlamaContext.java index 6577a3d..03cb24b 100644 --- a/android/src/main/java/com/rnllama/LlamaContext.java +++ b/android/src/main/java/com/rnllama/LlamaContext.java @@ -149,14 +149,14 @@ public WritableMap completion(ReadableMap params) { params.hasKey("n_predict") ? params.getInt("n_predict") : -1, // int n_probs, params.hasKey("n_probs") ? params.getInt("n_probs") : 0, - // int repeat_last_n, - params.hasKey("repeat_last_n") ? params.getInt("repeat_last_n") : 64, - // float repeat_penalty, - params.hasKey("repeat_penalty") ? (float) params.getDouble("repeat_penalty") : 1.10f, - // float presence_penalty, - params.hasKey("presence_penalty") ? (float) params.getDouble("presence_penalty") : 0.00f, - // float frequency_penalty, - params.hasKey("frequency_penalty") ? (float) params.getDouble("frequency_penalty") : 0.00f, + // int penalty_last_n, + params.hasKey("penalty_last_n") ? params.getInt("penalty_last_n") : 64, + // float penalty_repeat, + params.hasKey("penalty_repeat") ? (float) params.getDouble("penalty_repeat") : 1.10f, + // float penalty_freq, + params.hasKey("penalty_freq") ? (float) params.getDouble("penalty_freq") : 0.00f, + // float penalty_present, + params.hasKey("penalty_present") ? (float) params.getDouble("penalty_present") : 0.00f, // float mirostat, params.hasKey("mirostat") ? (float) params.getDouble("mirostat") : 0.00f, // float mirostat_tau, @@ -307,10 +307,10 @@ protected static native WritableMap doCompletion( int n_threads, int n_predict, int n_probs, - int repeat_last_n, - float repeat_penalty, - float presence_penalty, - float frequency_penalty, + int penalty_last_n, + float penalty_repeat, + float penalty_freq, + float penalty_present, float mirostat, float mirostat_tau, float mirostat_eta, diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp index 5d38120..fa71c8e 100644 --- a/android/src/main/jni.cpp +++ b/android/src/main/jni.cpp @@ -277,10 +277,10 @@ Java_com_rnllama_LlamaContext_doCompletion( jint n_threads, jint n_predict, jint n_probs, - jint repeat_last_n, - jfloat repeat_penalty, - jfloat presence_penalty, - jfloat frequency_penalty, + jint penalty_last_n, + jfloat penalty_repeat, + jfloat penalty_freq, + jfloat penalty_present, jfloat mirostat, jfloat mirostat_tau, jfloat mirostat_eta, @@ -301,7 +301,6 @@ Java_com_rnllama_LlamaContext_doCompletion( llama_reset_timings(llama->ctx); llama->params.prompt = env->GetStringUTFChars(prompt, nullptr); - llama->params.grammar = env->GetStringUTFChars(grammar, nullptr); int max_threads = std::thread::hardware_concurrency(); // Use 2 threads by default on 4-core devices, 4 threads on more cores @@ -311,12 +310,12 @@ Java_com_rnllama_LlamaContext_doCompletion( llama->params.n_predict = n_predict; llama->params.ignore_eos = ignore_eos; - auto & sparams = llama->params.sampling_params; + auto & sparams = llama->params.sparams; sparams.temp = temperature; - sparams.repeat_last_n = repeat_last_n; - sparams.repeat_penalty = repeat_penalty; - sparams.presence_penalty = presence_penalty; - sparams.frequency_penalty = frequency_penalty; + sparams.penalty_last_n = penalty_last_n; + sparams.penalty_repeat = penalty_repeat; + sparams.penalty_freq = penalty_freq; + sparams.penalty_present = penalty_present; sparams.mirostat = mirostat; sparams.mirostat_tau = mirostat_tau; sparams.mirostat_eta = mirostat_eta; @@ -325,6 +324,7 @@ Java_com_rnllama_LlamaContext_doCompletion( sparams.tfs_z = tfs_z; sparams.typical_p = typical_p; sparams.n_probs = n_probs; + sparams.grammar = env->GetStringUTFChars(grammar, nullptr); sparams.logit_bias.clear(); if (ignore_eos) { @@ -362,12 +362,11 @@ Java_com_rnllama_LlamaContext_doCompletion( env->ReleaseStringUTFChars(stop_str, stop_chars); } - if (!llama->loadGrammar()) { + if (!llama->initSampling()) { auto result = createWriteableMap(env); - putString(env, result, "error", "Failed to load grammar"); + putString(env, result, "error", "Failed to initialize sampling"); return reinterpret_cast(result); } - llama->loadPrompt(); llama->beginCompletion(); @@ -413,7 +412,7 @@ Java_com_rnllama_LlamaContext_doCompletion( auto tokenResult = createWriteableMap(env); putString(env, tokenResult, "token", to_send.c_str()); - if (llama->params.sampling_params.n_probs > 0) { + if (llama->params.sparams.n_probs > 0) { const std::vector to_send_toks = llama_tokenize(llama->ctx, to_send, false); size_t probs_pos = std::min(sent_token_probs_index, llama->generated_token_probs.size()); size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size()); diff --git a/cpp/build-info.h b/cpp/build-info.h index d761343..832eba8 100644 --- a/cpp/build-info.h +++ b/cpp/build-info.h @@ -1,8 +1,8 @@ #ifndef BUILD_INFO_H #define BUILD_INFO_H -#define BUILD_NUMBER 1399 -#define BUILD_COMMIT "004797f" +#define BUILD_NUMBER 1414 +#define BUILD_COMMIT "96981f3" #define BUILD_COMPILER "" #define BUILD_TARGET "unknown" diff --git a/cpp/common.cpp b/cpp/common.cpp index ce26523..a975e41 100644 --- a/cpp/common.cpp +++ b/cpp/common.cpp @@ -107,7 +107,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { std::string arg; gpt_params default_params; const std::string arg_prefix = "--"; - llama_sampling_params & sparams = params.sampling_params; + llama_sampling_params & sparams = params.sparams; for (int i = 1; i < argc; i++) { arg = argv[i]; @@ -241,25 +241,26 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { invalid_param = true; break; } - sparams.repeat_last_n = std::stoi(argv[i]); + sparams.penalty_last_n = std::stoi(argv[i]); + sparams.n_prev = std::max(sparams.n_prev, sparams.penalty_last_n); } else if (arg == "--repeat-penalty") { if (++i >= argc) { invalid_param = true; break; } - sparams.repeat_penalty = std::stof(argv[i]); + sparams.penalty_repeat = std::stof(argv[i]); } else if (arg == "--frequency-penalty") { if (++i >= argc) { invalid_param = true; break; } - sparams.frequency_penalty = std::stof(argv[i]); + sparams.penalty_freq = std::stof(argv[i]); } else if (arg == "--presence-penalty") { if (++i >= argc) { invalid_param = true; break; } - sparams.presence_penalty = std::stof(argv[i]); + sparams.penalty_present = std::stof(argv[i]); } else if (arg == "--mirostat") { if (++i >= argc) { invalid_param = true; @@ -572,7 +573,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { invalid_param = true; break; } - params.grammar = argv[i]; + sparams.grammar = argv[i]; } else if (arg == "--grammar-file") { if (++i >= argc) { invalid_param = true; @@ -587,7 +588,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { std::copy( std::istreambuf_iterator(file), std::istreambuf_iterator(), - std::back_inserter(params.grammar) + std::back_inserter(sparams.grammar) ); #ifndef LOG_DISABLE_LOGS // Parse args for logging parameters @@ -631,6 +632,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { process_escapes(params.prompt); process_escapes(params.input_prefix); process_escapes(params.input_suffix); + process_escapes(sparams.cfg_negative_prompt); for (auto & antiprompt : params.antiprompt) { process_escapes(antiprompt); } @@ -640,7 +642,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { } void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { - const llama_sampling_params & sparams = params.sampling_params; + const llama_sampling_params & sparams = params.sparams; printf("usage: %s [options]\n", argv[0]); printf("\n"); @@ -678,10 +680,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p); printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z); printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p); - printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.repeat_last_n); - printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.repeat_penalty); - printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.presence_penalty); - printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.frequency_penalty); + printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.penalty_last_n); + printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.penalty_repeat); + printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_present); + printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_freq); printf(" --mirostat N use Mirostat sampling.\n"); printf(" Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n"); printf(" (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", sparams.mirostat); @@ -878,7 +880,7 @@ std::tuple llama_init_from_gpt_par } if (params.ignore_eos) { - params.sampling_params.logit_bias[llama_token_eos(lctx)] = -INFINITY; + params.sparams.logit_bias[llama_token_eos(lctx)] = -INFINITY; } { @@ -1123,28 +1125,28 @@ std::string get_sortable_timestamp() { void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const llama_context * lctx, const std::string & timestamp, const std::vector & prompt_tokens, const char * model_desc) { - const llama_sampling_params & sparams = params.sampling_params; + const llama_sampling_params & sparams = params.sparams; fprintf(stream, "build_commit: %s\n", BUILD_COMMIT); fprintf(stream, "build_number: %d\n", BUILD_NUMBER); - fprintf(stream, "cpu_has_arm_fma: %s\n", lm_ggml_cpu_has_arm_fma() ? "true" : "false"); - fprintf(stream, "cpu_has_avx: %s\n", lm_ggml_cpu_has_avx() ? "true" : "false"); - fprintf(stream, "cpu_has_avx2: %s\n", lm_ggml_cpu_has_avx2() ? "true" : "false"); - fprintf(stream, "cpu_has_avx512: %s\n", lm_ggml_cpu_has_avx512() ? "true" : "false"); + fprintf(stream, "cpu_has_arm_fma: %s\n", lm_ggml_cpu_has_arm_fma() ? "true" : "false"); + fprintf(stream, "cpu_has_avx: %s\n", lm_ggml_cpu_has_avx() ? "true" : "false"); + fprintf(stream, "cpu_has_avx2: %s\n", lm_ggml_cpu_has_avx2() ? "true" : "false"); + fprintf(stream, "cpu_has_avx512: %s\n", lm_ggml_cpu_has_avx512() ? "true" : "false"); fprintf(stream, "cpu_has_avx512_vbmi: %s\n", lm_ggml_cpu_has_avx512_vbmi() ? "true" : "false"); fprintf(stream, "cpu_has_avx512_vnni: %s\n", lm_ggml_cpu_has_avx512_vnni() ? "true" : "false"); - fprintf(stream, "cpu_has_blas: %s\n", lm_ggml_cpu_has_blas() ? "true" : "false"); - fprintf(stream, "cpu_has_cublas: %s\n", lm_ggml_cpu_has_cublas() ? "true" : "false"); - fprintf(stream, "cpu_has_clblast: %s\n", lm_ggml_cpu_has_clblast() ? "true" : "false"); - fprintf(stream, "cpu_has_fma: %s\n", lm_ggml_cpu_has_fma() ? "true" : "false"); - fprintf(stream, "cpu_has_gpublas: %s\n", lm_ggml_cpu_has_gpublas() ? "true" : "false"); - fprintf(stream, "cpu_has_neon: %s\n", lm_ggml_cpu_has_neon() ? "true" : "false"); - fprintf(stream, "cpu_has_f16c: %s\n", lm_ggml_cpu_has_f16c() ? "true" : "false"); - fprintf(stream, "cpu_has_fp16_va: %s\n", lm_ggml_cpu_has_fp16_va() ? "true" : "false"); - fprintf(stream, "cpu_has_wasm_simd: %s\n", lm_ggml_cpu_has_wasm_simd() ? "true" : "false"); - fprintf(stream, "cpu_has_blas: %s\n", lm_ggml_cpu_has_blas() ? "true" : "false"); - fprintf(stream, "cpu_has_sse3: %s\n", lm_ggml_cpu_has_sse3() ? "true" : "false"); - fprintf(stream, "cpu_has_vsx: %s\n", lm_ggml_cpu_has_vsx() ? "true" : "false"); + fprintf(stream, "cpu_has_blas: %s\n", lm_ggml_cpu_has_blas() ? "true" : "false"); + fprintf(stream, "cpu_has_cublas: %s\n", lm_ggml_cpu_has_cublas() ? "true" : "false"); + fprintf(stream, "cpu_has_clblast: %s\n", lm_ggml_cpu_has_clblast() ? "true" : "false"); + fprintf(stream, "cpu_has_fma: %s\n", lm_ggml_cpu_has_fma() ? "true" : "false"); + fprintf(stream, "cpu_has_gpublas: %s\n", lm_ggml_cpu_has_gpublas() ? "true" : "false"); + fprintf(stream, "cpu_has_neon: %s\n", lm_ggml_cpu_has_neon() ? "true" : "false"); + fprintf(stream, "cpu_has_f16c: %s\n", lm_ggml_cpu_has_f16c() ? "true" : "false"); + fprintf(stream, "cpu_has_fp16_va: %s\n", lm_ggml_cpu_has_fp16_va() ? "true" : "false"); + fprintf(stream, "cpu_has_wasm_simd: %s\n", lm_ggml_cpu_has_wasm_simd() ? "true" : "false"); + fprintf(stream, "cpu_has_blas: %s\n", lm_ggml_cpu_has_blas() ? "true" : "false"); + fprintf(stream, "cpu_has_sse3: %s\n", lm_ggml_cpu_has_sse3() ? "true" : "false"); + fprintf(stream, "cpu_has_vsx: %s\n", lm_ggml_cpu_has_vsx() ? "true" : "false"); #ifdef NDEBUG fprintf(stream, "debug: false\n"); @@ -1178,8 +1180,8 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx); fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false"); fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n"); - fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.frequency_penalty); - dump_string_yaml_multiline(stream, "grammar", params.grammar.c_str()); + fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq); + dump_string_yaml_multiline(stream, "grammar", sparams.grammar.c_str()); fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n"); fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false"); fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks); @@ -1238,14 +1240,14 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l fprintf(stream, "numa: %s # default: false\n", params.numa ? "true" : "false"); fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type); fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride); - fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.presence_penalty); + fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.penalty_present); dump_string_yaml_multiline(stream, "prompt", params.prompt.c_str()); fprintf(stream, "prompt_cache: %s\n", params.path_prompt_cache.c_str()); fprintf(stream, "prompt_cache_all: %s # default: false\n", params.prompt_cache_all ? "true" : "false"); fprintf(stream, "prompt_cache_ro: %s # default: false\n", params.prompt_cache_ro ? "true" : "false"); dump_vector_int_yaml(stream, "prompt_tokens", prompt_tokens); fprintf(stream, "random_prompt: %s # default: false\n", params.random_prompt ? "true" : "false"); - fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.repeat_penalty); + fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.penalty_repeat); fprintf(stream, "reverse_prompt:\n"); for (std::string ap : params.antiprompt) { diff --git a/cpp/common.h b/cpp/common.h index 65d3d20..84523a4 100644 --- a/cpp/common.h +++ b/cpp/common.h @@ -56,7 +56,7 @@ struct gpt_params { float rope_freq_scale = 0.0f; // RoPE frequency scaling factor // // sampling parameters - struct llama_sampling_params sampling_params; + struct llama_sampling_params sparams; std::string model = "models/7B/ggml-model-f16.gguf"; // model path std::string model_draft = ""; // draft model for speculative decoding @@ -66,7 +66,6 @@ struct gpt_params { std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state std::string input_prefix = ""; // string to prefix user inputs with std::string input_suffix = ""; // string to suffix user inputs with - std::string grammar = ""; // optional BNF-like grammar to constrain sampling std::vector antiprompt; // string upon seeing which more user input is prompted std::string logdir = ""; // directory in which to save YAML log files diff --git a/cpp/ggml.c b/cpp/ggml.c index 17ce01b..4b783a8 100644 --- a/cpp/ggml.c +++ b/cpp/ggml.c @@ -13537,7 +13537,7 @@ static void lm_ggml_compute_forward_rope_f16( dst_data[n_dims] = LM_GGML_FP32_TO_FP16(x2*cos_block_theta - x3*sin_block_theta); dst_data[n_dims/2*3] = LM_GGML_FP32_TO_FP16(x2*sin_block_theta + x3*cos_block_theta); } - } if (!is_neox) { + } else if (!is_neox) { for (int64_t i0 = 0; i0 < ne0; i0 += 2) { const float cos_theta = cosf(theta); const float sin_theta = sinf(theta); @@ -19170,6 +19170,7 @@ void lm_ggml_graph_export(const struct lm_ggml_cgraph * cgraph, const char * fna if (idx == -1) { fprintf(stderr, "%s: failed to find tensor, arg = %d, node = %d\n", __func__, j, i); + fclose(fout); return; } @@ -20844,7 +20845,7 @@ struct gguf_kv { }; struct gguf_header { - uint32_t magic; + char magic[4]; uint32_t version; uint64_t n_tensors; // GGUFv2 uint64_t n_kv; // GGUFv2 @@ -20914,7 +20915,7 @@ static bool gguf_fread_str_v1(FILE * file, struct gguf_str * p, size_t * offset) struct gguf_context * gguf_init_empty(void) { struct gguf_context * ctx = LM_GGML_ALIGNED_MALLOC(sizeof(struct gguf_context)); - ctx->header.magic = GGUF_MAGIC; + memcpy(ctx->header.magic, GGUF_MAGIC, sizeof(ctx->header.magic)); ctx->header.version = GGUF_VERSION; ctx->header.n_tensors = 0; ctx->header.n_kv = 0; @@ -20940,16 +20941,18 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p // offset from start of file size_t offset = 0; - uint32_t magic = 0; + char magic[4]; // check the magic before making allocations { gguf_fread_el(file, &magic, sizeof(magic), &offset); - if (magic != GGUF_MAGIC) { - fprintf(stderr, "%s: invalid magic number %08x\n", __func__, magic); - fclose(file); - return NULL; + for (uint32_t i = 0; i < sizeof(magic); i++) { + if (magic[i] != GGUF_MAGIC[i]) { + fprintf(stderr, "%s: invalid magic characters %s.\n", __func__, magic); + fclose(file); + return NULL; + } } } @@ -20959,7 +20962,8 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p // read the header { - ctx->header.magic = magic; + strncpy(ctx->header.magic, magic, 4); + ctx->kv = NULL; ctx->infos = NULL; diff --git a/cpp/ggml.h b/cpp/ggml.h index 6ece44c..b937b70 100644 --- a/cpp/ggml.h +++ b/cpp/ggml.h @@ -231,8 +231,9 @@ #define LM_GGML_EXIT_SUCCESS 0 #define LM_GGML_EXIT_ABORTED 1 -#define GGUF_MAGIC 0x46554747 // "GGUF" -#define GGUF_VERSION 2 +#define GGUF_MAGIC "GGUF" + +#define GGUF_VERSION 3 #define GGUF_DEFAULT_ALIGNMENT 32 diff --git a/cpp/grammar-parser.cpp b/cpp/grammar-parser.cpp index 5a545a8..ff51cc8 100644 --- a/cpp/grammar-parser.cpp +++ b/cpp/grammar-parser.cpp @@ -399,7 +399,7 @@ namespace grammar_parser { void print_grammar(FILE * file, const parse_state & state) { try { std::map symbol_id_names; - for (auto kv : state.symbol_ids) { + for (const auto & kv : state.symbol_ids) { symbol_id_names[kv.second] = kv.first; } for (size_t i = 0, end = state.rules.size(); i < end; i++) { diff --git a/cpp/k_quants.c b/cpp/k_quants.c index eacebad..c37eb22 100644 --- a/cpp/k_quants.c +++ b/cpp/k_quants.c @@ -46,7 +46,7 @@ inline static int32_t vaddvq_s32(int32x4_t v) { #if defined(_MSC_VER) || defined(__MINGW32__) #include #else -#if !defined(__riscv) +#if !defined(__riscv) && !defined(__s390__) #include #endif #endif diff --git a/cpp/llama.cpp b/cpp/llama.cpp index e3fa6a5..e22bb7c 100644 --- a/cpp/llama.cpp +++ b/cpp/llama.cpp @@ -986,14 +986,15 @@ static void llama_nop(struct lm_ggml_tensor * tensor) { // don't offload by defa (void) tensor; } -static std::string llama_token_to_str(const struct llama_context * ctx, llama_token token) { +static std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) { std::vector result(8, 0); const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size()); if (n_tokens < 0) { result.resize(-n_tokens); int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size()); LM_GGML_ASSERT(check == -n_tokens); - } else { + } + else { result.resize(n_tokens); } @@ -1029,8 +1030,8 @@ enum e_model { }; static const size_t kB = 1024; -static const size_t MB = kB*kB; -static const size_t GB = kB*kB*kB; +static const size_t MB = 1024*kB; +static const size_t GB = 1024*MB; struct llama_hparams { bool vocab_only; @@ -1053,21 +1054,21 @@ struct llama_hparams { float f_max_alibi_bias; bool operator!=(const llama_hparams & other) const { - if (this->vocab_only != other.vocab_only) return true; - if (this->n_vocab != other.n_vocab) return true; + if (this->vocab_only != other.vocab_only) return true; + if (this->n_vocab != other.n_vocab) return true; if (this->n_ctx_train != other.n_ctx_train) return true; - if (this->n_embd != other.n_embd) return true; - if (this->n_head != other.n_head) return true; - if (this->n_head_kv != other.n_head_kv) return true; - if (this->n_layer != other.n_layer) return true; - if (this->n_rot != other.n_rot) return true; - if (this->n_ff != other.n_ff) return true; + if (this->n_embd != other.n_embd) return true; + if (this->n_head != other.n_head) return true; + if (this->n_head_kv != other.n_head_kv) return true; + if (this->n_layer != other.n_layer) return true; + if (this->n_rot != other.n_rot) return true; + if (this->n_ff != other.n_ff) return true; const float EPSILON = 1e-9; - if (!is_float_close(this->f_norm_eps, other.f_norm_eps, EPSILON)) return true; - if (!is_float_close(this->f_norm_rms_eps, other.f_norm_rms_eps, EPSILON)) return true; - if (!is_float_close(this->rope_freq_base_train, other.rope_freq_base_train, EPSILON)) return true; + if (!is_float_close(this->f_norm_eps, other.f_norm_eps, EPSILON)) return true; + if (!is_float_close(this->f_norm_rms_eps, other.f_norm_rms_eps, EPSILON)) return true; + if (!is_float_close(this->rope_freq_base_train, other.rope_freq_base_train, EPSILON)) return true; if (!is_float_close(this->rope_freq_scale_train, other.rope_freq_scale_train, EPSILON)) return true; return false; @@ -1206,17 +1207,17 @@ struct llama_vocab { id special_sep_id = -1; id special_pad_id = -1; - id linefeed_id = 13; + id linefeed_id = 13; id special_prefix_id = 32007; id special_middle_id = 32009; id special_suffix_id = 32008; - id special_eot_id = 32010; + id special_eot_id = 32010; int find_bpe_rank(std::string token_left, std::string token_right) const { - replace_all(token_left, " ", "\u0120"); - replace_all(token_left, "\n", "\u010A"); - replace_all(token_right, " ", "\u0120"); - replace_all(token_right, "\n", "\u010A"); + LM_GGML_ASSERT(token_left.find(" ") == std::string::npos); + LM_GGML_ASSERT(token_left.find("\n") == std::string::npos); + LM_GGML_ASSERT(token_right.find(" ") == std::string::npos); + LM_GGML_ASSERT(token_right.find("\n") == std::string::npos); auto it = bpe_ranks.find(std::make_pair(token_left, token_right)); if (it == bpe_ranks.end()) { @@ -1370,10 +1371,7 @@ static bool llama_kv_cache_init( cache.cells.clear(); cache.cells.resize(n_ctx); - // TODO: this should be: - // cache.buf.resize(2u*n_elements*lm_ggml_type_size(wtype) + 2u*lm_ggml_tensor_overhead()); - // change it and test that it works - cache.buf.resize(2u*n_elements*lm_ggml_type_size(wtype) + 2u*MB); + cache.buf.resize(2u*n_elements*lm_ggml_type_size(wtype) + 2u*lm_ggml_tensor_overhead()); memset(cache.buf.data, 0, cache.buf.size); struct lm_ggml_init_params params; @@ -2252,15 +2250,35 @@ static void llm_load_vocab( if (vocab.type == LLAMA_VOCAB_TYPE_SPM) { vocab.linefeed_id = llama_byte_to_token(vocab, '\n'); } else { - vocab.linefeed_id = llama_tokenize_internal(vocab, "\u010A", false)[0]; + const std::vector ids = llama_tokenize_internal(vocab, "\u010A", false); + LM_GGML_ASSERT(!ids.empty() && "model vocab missing newline token"); + vocab.linefeed_id = ids[0]; } // special tokens - GGUF_GET_KEY(ctx, vocab.special_bos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_BOS_ID)); - GGUF_GET_KEY(ctx, vocab.special_eos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_EOS_ID)); - GGUF_GET_KEY(ctx, vocab.special_unk_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_UNK_ID)); - GGUF_GET_KEY(ctx, vocab.special_sep_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_SEP_ID)); - GGUF_GET_KEY(ctx, vocab.special_pad_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_PAD_ID)); + { + const std::vector> special_token_types = { + { LLM_KV_TOKENIZER_BOS_ID, vocab.special_bos_id }, + { LLM_KV_TOKENIZER_EOS_ID, vocab.special_eos_id }, + { LLM_KV_TOKENIZER_UNK_ID, vocab.special_unk_id }, + { LLM_KV_TOKENIZER_SEP_ID, vocab.special_sep_id }, + { LLM_KV_TOKENIZER_PAD_ID, vocab.special_pad_id }, + }; + for (const auto & it : special_token_types) { + const std::string & key = kv(std::get<0>(it)); + int32_t & id = std::get<1>(it), old_id = id; + + GGUF_GET_KEY(ctx, id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, key); + // Must be >= -1 and < vocab size. Since the key is unsigned, -1 + // can only come from the default value, so there's no point in + // validating that. + if (size_t(id + 1) > vocab.id_to_token.size()) { + LLAMA_LOG_WARN("%s: bad special token: '%s' = %d, using default id %d\n", + __func__, key.c_str(), id, old_id); + id = old_id; + } + } + } // build special tokens cache { @@ -6117,11 +6135,10 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) { } static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) { + static const char * hex = "0123456789ABCDEF"; switch (llama_vocab_get_type(vocab)) { case LLAMA_VOCAB_TYPE_SPM: { - char buf[7]; - int result = snprintf(buf, sizeof(buf), "<0x%02X>", ch); - LM_GGML_ASSERT(0 <= result && result < 7); + const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 }; return vocab.token_to_id.at(buf); } case LLAMA_VOCAB_TYPE_BPE: { @@ -6335,7 +6352,6 @@ struct llm_tokenizer_bpe { llm_symbol sym; size_t char_len = std::min(word.size() - offset, (size_t) ::utf8_len(word[offset])); sym.text = word.c_str() + offset; - sym.n = 1; sym.n = char_len; offset += sym.n; sym.prev = index - 1; @@ -7065,7 +7081,7 @@ static std::vector llama_grammar_reject_candidates_for_ std::vector rejects; if (stack.empty()) { - for (auto tok : candidates) { + for (const auto & tok : candidates) { if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) { rejects.push_back(tok); } @@ -7076,7 +7092,7 @@ static std::vector llama_grammar_reject_candidates_for_ const llama_grammar_element * stack_pos = stack.back(); std::vector next_candidates; - for (auto tok : candidates) { + for (const auto & tok : candidates) { if (*tok.code_points == 0) { // reached end of full codepoints in token, reject iff it ended in a partial sequence // that cannot satisfy this position in grammar @@ -7102,7 +7118,7 @@ static std::vector llama_grammar_reject_candidates_for_ llama_grammar_advance_stack(rules, stack_after, next_stacks); auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates); - for (auto tok : next_rejects) { + for (const auto & tok : next_rejects) { rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 }); } @@ -7429,37 +7445,15 @@ void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array llama_sample_temp(ctx, candidates_p, temp); } -void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty) { - if (last_tokens_size == 0 || penalty == 1.0f) { - return; - } - - const int64_t t_start_sample_us = lm_ggml_time_us(); - - for (size_t i = 0; i < candidates->size; ++i) { - const auto * token_iter = std::find(last_tokens, last_tokens + last_tokens_size, candidates->data[i].id); - if (token_iter == last_tokens + last_tokens_size) { - continue; - } - - // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong. - // This is common fix for this problem, which is to multiply by the penalty instead of dividing. - if (candidates->data[i].logit <= 0) { - candidates->data[i].logit *= penalty; - } else { - candidates->data[i].logit /= penalty; - } - } - - candidates->sorted = false; - - if (ctx) { - ctx->t_sample_us += lm_ggml_time_us() - t_start_sample_us; - } -} - -void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence) { - if (last_tokens_size == 0 || (alpha_frequency == 0.0f && alpha_presence == 0.0f)) { +void llama_sample_repetition_penalties( + struct llama_context * ctx, + llama_token_data_array * candidates, + const llama_token * last_tokens, + size_t penalty_last_n, + float penalty_repeat, + float penalty_freq, + float penalty_present) { + if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) { return; } @@ -7467,19 +7461,28 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l // Create a frequency map to count occurrences of each token in last_tokens std::unordered_map token_count; - for (size_t i = 0; i < last_tokens_size; ++i) { - token_count[last_tokens_p[i]]++; + for (size_t i = 0; i < penalty_last_n; ++i) { + token_count[last_tokens[i]]++; } // Apply frequency and presence penalties to the candidates for (size_t i = 0; i < candidates->size; ++i) { - auto token_iter = token_count.find(candidates->data[i].id); + const auto token_iter = token_count.find(candidates->data[i].id); if (token_iter == token_count.end()) { continue; } - int count = token_iter->second; - candidates->data[i].logit -= float(count) * alpha_frequency + float(count > 0) * alpha_presence; + const int count = token_iter->second; + + // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong. + // This is common fix for this problem, which is to multiply by the penalty instead of dividing. + if (candidates->data[i].logit <= 0) { + candidates->data[i].logit *= penalty_repeat; + } else { + candidates->data[i].logit /= penalty_repeat; + } + + candidates->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present; } candidates->sorted = false; @@ -7508,7 +7511,7 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c for (size_t i = 0; i < candidates->size; ++i) { const llama_token id = candidates->data[i].id; - const std::string piece = llama_token_to_str(ctx, id); + const std::string piece = llama_token_to_piece(ctx, id); if (id == eos) { if (!allow_eos) { candidates->data[i].logit = -INFINITY; @@ -7720,7 +7723,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar LM_GGML_ASSERT(false); } - const std::string piece = llama_token_to_str(ctx, token); + const std::string piece = llama_token_to_piece(ctx, token); // Note terminating 0 in decoded string const auto decoded = decode_utf8(piece.c_str(), grammar->partial_utf8); diff --git a/cpp/llama.h b/cpp/llama.h index 5aae8f5..04ff963 100644 --- a/cpp/llama.h +++ b/cpp/llama.h @@ -560,21 +560,15 @@ extern "C" { LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed); /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. - LLAMA_API void llama_sample_repetition_penalty( - struct llama_context * ctx, - llama_token_data_array * candidates, - const llama_token * last_tokens, - size_t last_tokens_size, - float penalty); - /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. - LLAMA_API void llama_sample_frequency_and_presence_penalties( + LLAMA_API void llama_sample_repetition_penalties( struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, - size_t last_tokens_size, - float alpha_frequency, - float alpha_presence); + size_t penalty_last_n, + float penalty_repeat, + float penalty_freq, + float penalty_present); /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806 /// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted. diff --git a/cpp/rn-llama.hpp b/cpp/rn-llama.hpp index d38a94f..aeae7eb 100644 --- a/cpp/rn-llama.hpp +++ b/cpp/rn-llama.hpp @@ -139,10 +139,13 @@ struct llama_rn_context std::vector embd; + gpt_params params; + llama_model *model = nullptr; llama_context *ctx = nullptr; - gpt_params params; - llama_sampling_context *ctx_sampling; + llama_sampling_context *ctx_sampling = nullptr; + + int n_ctx; bool truncated = false; bool stopped_eos = false; @@ -173,7 +176,7 @@ struct llama_rn_context { is_interrupted = false; params.antiprompt.clear(); - params.grammar.clear(); + params.sparams.grammar.clear(); num_prompt_tokens = 0; num_tokens_predicted = 0; generated_text = ""; @@ -187,11 +190,15 @@ struct llama_rn_context multibyte_pending = 0; n_remain = 0; n_past = 0; + params.sparams.n_prev = n_ctx; + } + bool initSampling() { if (ctx_sampling != nullptr) { llama_sampling_free(ctx_sampling); } - ctx_sampling = llama_sampling_init(params); + ctx_sampling = llama_sampling_init(params.sparams); + return ctx_sampling != nullptr; } bool loadModel(gpt_params ¶ms_) @@ -203,13 +210,30 @@ struct llama_rn_context LOG_ERROR("unable to load model: %s", params_.model.c_str()); return false; } + n_ctx = llama_n_ctx(ctx); return true; } - bool loadGrammar() - { - ctx_sampling = llama_sampling_init(params); - return true; + void truncatePrompt(std::vector &prompt_tokens) { + const int n_left = n_ctx - params.n_keep; + const int n_block_size = n_left / 2; + const int erased_blocks = (prompt_tokens.size() - params.n_keep - n_block_size) / n_block_size; + + // Keep n_keep tokens at start of prompt (at most n_ctx - 4) + std::vector new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep); + + new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_block_size, prompt_tokens.end()); + + LOG_VERBOSE("input truncated, n_ctx: %d, n_keep: %d, n_left: %d, new_tokens: %s, num_prompt_tokens: %d", + n_ctx, + params.n_keep, + n_left, + tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend()).c_str(), + new_tokens.size() + ); + + truncated = true; + prompt_tokens = new_tokens; } void loadPrompt() @@ -222,28 +246,20 @@ struct llama_rn_context { params.n_keep = (int)num_prompt_tokens; } - params.n_keep = std::min(params.n_ctx - 4, params.n_keep); + params.n_keep = std::min(n_ctx - 4, params.n_keep); // if input prompt is too big, truncate like normal - if (num_prompt_tokens >= (size_t)params.n_ctx) + if (num_prompt_tokens >= (size_t) n_ctx) { - const int n_left = (params.n_ctx - params.n_keep) / 2; - std::vector new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep); - const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left; - new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end()); - std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), ctx_sampling->prev.begin()); + truncatePrompt(prompt_tokens); + num_prompt_tokens = prompt_tokens.size(); - LOG_VERBOSE("input truncated, n_ctx: %d, n_keep: %d, n_left: %d, new_tokens: %s", - params.n_ctx, params.n_keep, n_left, tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend()).c_str() - ); - truncated = true; - prompt_tokens = new_tokens; + LM_GGML_ASSERT(num_prompt_tokens < (size_t) n_ctx); } - else + // push the prompt into the sampling context (do not apply grammar) + for (auto & token : prompt_tokens) { - const size_t ps = num_prompt_tokens; - std::fill(ctx_sampling->prev.begin(), ctx_sampling->prev.end() - ps, 0); - std::copy(prompt_tokens.begin(), prompt_tokens.end(), ctx_sampling->prev.end() - ps); + llama_sampling_accept(ctx_sampling, ctx, token, false); } // compare the evaluated prompt with the new prompt @@ -346,8 +362,8 @@ struct llama_rn_context llama_token_data_array cur_p = { ctx_sampling->cur.data(), ctx_sampling->cur.size(), false }; - const int32_t n_probs = params.sampling_params.n_probs; - if (params.sampling_params.temp <= 0 && n_probs > 0) + const int32_t n_probs = params.sparams.n_probs; + if (params.sparams.temp <= 0 && n_probs > 0) { // For llama_sample_token_greedy we need to sort candidates llama_sample_softmax(ctx, &cur_p); @@ -357,7 +373,7 @@ struct llama_rn_context { result.probs.push_back({cur_p.data[i].id, cur_p.data[i].p}); } - llama_sampling_accept(ctx_sampling, ctx, result.tok); + llama_sampling_accept(ctx_sampling, ctx, result.tok, true); if (tg) { num_tokens_predicted++; } @@ -420,7 +436,7 @@ struct llama_rn_context const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_piece(ctx, token_with_probs.tok); generated_text += token_text; - if (params.sampling_params.n_probs > 0) + if (params.sparams.n_probs > 0) { generated_token_probs.push_back(token_with_probs); } diff --git a/cpp/sampling.cpp b/cpp/sampling.cpp index 0b24665..6f0af3c 100644 --- a/cpp/sampling.cpp +++ b/cpp/sampling.cpp @@ -1,9 +1,9 @@ #include "sampling.h" -struct llama_sampling_context * llama_sampling_init(const struct gpt_params & params) { +struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) { struct llama_sampling_context * result = new llama_sampling_context(); - result->params = params.sampling_params; + result->params = params; result->grammar = nullptr; // if there is a grammar, parse it @@ -23,7 +23,7 @@ struct llama_sampling_context * llama_sampling_init(const struct gpt_params & pa grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root")); } - result->prev.resize(params.n_ctx); + result->prev.resize(params.n_prev); return result; } @@ -66,25 +66,56 @@ void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * ds dst->prev = src->prev; } +llama_token llama_sampling_last(llama_sampling_context * ctx) { + return ctx->prev.back(); +} + +std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n) { + const int size = ctx_sampling->prev.size(); + + n = std::min(n, size); + + std::string result; + + for (int i = size - n; i < size; i++) { + result += llama_token_to_piece(ctx_main, ctx_sampling->prev[i]); + } + + return result; +} + +std::string llama_sampling_print(const llama_sampling_params & params) { + char result[1024]; + + snprintf(result, sizeof(result), + "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" + "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, typical_p = %.3f, temp = %.3f\n" + "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f", + params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present, + params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, + params.mirostat, params.mirostat_eta, params.mirostat_tau); + + return std::string(result); +} + llama_token llama_sampling_sample( struct llama_sampling_context * ctx_sampling, struct llama_context * ctx_main, struct llama_context * ctx_cfg, const int idx) { - const int n_ctx = llama_n_ctx(ctx_main); - const int n_vocab = llama_n_vocab(llama_get_model(ctx_main)); - const llama_sampling_params & params = ctx_sampling->params; + const int n_vocab = llama_n_vocab(llama_get_model(ctx_main)); + const float temp = params.temp; const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k; const float top_p = params.top_p; const float tfs_z = params.tfs_z; const float typical_p = params.typical_p; - const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n; - const float repeat_penalty = params.repeat_penalty; - const float alpha_presence = params.presence_penalty; - const float alpha_frequency = params.frequency_penalty; + const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n; + const float penalty_repeat = params.penalty_repeat; + const float penalty_freq = params.penalty_freq; + const float penalty_present = params.penalty_present; const int mirostat = params.mirostat; const float mirostat_tau = params.mirostat_tau; const float mirostat_eta = params.mirostat_eta; @@ -97,7 +128,7 @@ llama_token llama_sampling_sample( float * logits = llama_get_logits_ith(ctx_main, idx); - // Apply params.logit_bias map + // apply params.logit_bias map for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { logits[it->first] += it->second; } @@ -117,14 +148,10 @@ llama_token llama_sampling_sample( // apply penalties if (!prev.empty()) { const float nl_logit = logits[llama_token_nl(ctx_main)]; - const int last_n_repeat = std::min(std::min((int)prev.size(), repeat_last_n), n_ctx); - llama_sample_repetition_penalty(ctx_main, &cur_p, - prev.data() + prev.size() - last_n_repeat, - last_n_repeat, repeat_penalty); - llama_sample_frequency_and_presence_penalties(ctx_main, &cur_p, - prev.data() + prev.size() - last_n_repeat, - last_n_repeat, alpha_frequency, alpha_presence); + llama_sample_repetition_penalties(ctx_main, &cur_p, + prev.data() + prev.size() - penalty_last_n, + penalty_last_n, penalty_repeat, penalty_freq, penalty_present); if (!penalize_nl) { for (size_t idx = 0; idx < cur_p.size; idx++) { @@ -141,7 +168,7 @@ llama_token llama_sampling_sample( } if (temp <= 0) { - // Greedy sampling + // greedy sampling id = llama_sample_token_greedy(ctx_main, &cur_p); } else { if (mirostat == 1) { @@ -152,8 +179,9 @@ llama_token llama_sampling_sample( llama_sample_temp(ctx_main, &cur_p, temp); id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu); } else { - // Temperature sampling + // temperature sampling size_t min_keep = std::max(1, params.n_probs); + llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); @@ -183,11 +211,12 @@ llama_token llama_sampling_sample( void llama_sampling_accept( struct llama_sampling_context * ctx_sampling, struct llama_context * ctx_main, - llama_token id) { + llama_token id, + bool apply_grammar) { ctx_sampling->prev.erase(ctx_sampling->prev.begin()); ctx_sampling->prev.push_back(id); - if (ctx_sampling->grammar != NULL) { + if (ctx_sampling->grammar != NULL && apply_grammar) { llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id); } } diff --git a/cpp/sampling.h b/cpp/sampling.h index 50afcbc..62ea6d4 100644 --- a/cpp/sampling.h +++ b/cpp/sampling.h @@ -10,30 +10,30 @@ // sampling parameters typedef struct llama_sampling_params { + int32_t n_prev = 64; // number of previous tokens to remember + int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. int32_t top_k = 40; // <= 0 to use vocab size float top_p = 0.95f; // 1.0 = disabled float tfs_z = 1.00f; // 1.0 = disabled float typical_p = 1.00f; // 1.0 = disabled float temp = 0.80f; // 1.0 = disabled - float repeat_penalty = 1.10f; // 1.0 = disabled - int32_t repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) - float frequency_penalty = 0.00f; // 0.0 = disabled - float presence_penalty = 0.00f; // 0.0 = disabled + int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) + float penalty_repeat = 1.10f; // 1.0 = disabled + float penalty_freq = 0.00f; // 0.0 = disabled + float penalty_present = 0.00f; // 0.0 = disabled int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 float mirostat_tau = 5.00f; // target entropy float mirostat_eta = 0.10f; // learning rate - bool penalize_nl = true; // consider newlines as a repeatable token - int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. + std::string grammar; // optional BNF-like grammar to constrain sampling // Classifier-Free Guidance // https://arxiv.org/abs/2306.17806 - std::string cfg_negative_prompt; // string to help guidance - float cfg_scale = 1.f; // How strong is guidance + std::string cfg_negative_prompt; // string to help guidance + float cfg_scale = 1.f; // how strong is guidance std::unordered_map logit_bias; // logit bias for specific tokens - } llama_sampling_params; // general sampler context @@ -58,7 +58,7 @@ struct llama_sampling_context { #include "common.h" // Create a new sampling context instance. -struct llama_sampling_context * llama_sampling_init(const struct gpt_params & params); +struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params); void llama_sampling_free(struct llama_sampling_context * ctx); @@ -70,6 +70,15 @@ void llama_sampling_reset(llama_sampling_context * ctx); // Copy the sampler context void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst); +// Get the last sampled token +llama_token llama_sampling_last(llama_sampling_context * ctx); + +// Get a string representation of the last sampled tokens +std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n); + +// Print sampling parameters into a string +std::string llama_sampling_print(const llama_sampling_params & params); + // this is a common sampling function used across the examples for convenience // it can serve as a starting point for implementing your own sampling function // Note: When using multiple sequences, it is the caller's responsibility to call @@ -96,4 +105,5 @@ llama_token llama_sampling_sample( void llama_sampling_accept( struct llama_sampling_context * ctx_sampling, struct llama_context * ctx_main, - llama_token id); + llama_token id, + bool apply_grammar); diff --git a/example/src/App.tsx b/example/src/App.tsx index a38fd0d..e644e7d 100644 --- a/example/src/App.tsx +++ b/example/src/App.tsx @@ -285,14 +285,14 @@ export default function App() { prompt, n_predict: 400, temperature: 0.7, - repeat_last_n: 256, // 0 = disable penalty, -1 = context size - repeat_penalty: 1.18, // 1.0 = disabled top_k: 40, // <= 0 to use vocab size top_p: 0.5, // 1.0 = disabled tfs_z: 1.0, // 1.0 = disabled typical_p: 1.0, // 1.0 = disabled - presence_penalty: 0.0, // 0.0 = disabled - frequency_penalty: 0.0, // 0.0 = disabled + penalty_last_n: 256, // 0 = disable penalty, -1 = context size + penalty_repeat: 1.18, // 1.0 = disabled + penalty_freq: 0.0, // 0.0 = disabled + penalty_present: 0.0, // 0.0 = disabled mirostat: 0, // 0/1/2 mirostat_tau: 5, // target entropy mirostat_eta: 0.1, // learning rate diff --git a/ios/RNLlamaContext.mm b/ios/RNLlamaContext.mm index d8f1b50..29aab44 100644 --- a/ios/RNLlamaContext.mm +++ b/ios/RNLlamaContext.mm @@ -133,10 +133,6 @@ - (NSDictionary *)completion:(NSDictionary *)params llama->params.prompt = [prompt UTF8String]; - if (params[@"grammar"]) { - llama->params.grammar = [params[@"grammar"] UTF8String]; - } - if (params[@"n_threads"]) { int nThreads = params[@"n_threads"] ? [params[@"n_threads"] intValue] : llama->params.n_threads; const int maxThreads = (int) [[NSProcessInfo processInfo] processorCount]; @@ -146,16 +142,16 @@ - (NSDictionary *)completion:(NSDictionary *)params } if (params[@"n_predict"]) llama->params.n_predict = [params[@"n_predict"] intValue]; - auto & sparams = llama->params.sampling_params; + auto & sparams = llama->params.sparams; if (params[@"temperature"]) sparams.temp = [params[@"temperature"] doubleValue]; if (params[@"n_probs"]) sparams.n_probs = [params[@"n_probs"] intValue]; - if (params[@"repeat_last_n"]) sparams.repeat_last_n = [params[@"repeat_last_n"] intValue]; - if (params[@"repeat_penalty"]) sparams.repeat_penalty = [params[@"repeat_penalty"] doubleValue]; - if (params[@"presence_penalty"]) sparams.presence_penalty = [params[@"presence_penalty"] doubleValue]; - if (params[@"frequency_penalty"]) sparams.frequency_penalty = [params[@"frequency_penalty"] doubleValue]; + if (params[@"penalty_last_n"]) sparams.penalty_last_n = [params[@"penalty_last_n"] intValue]; + if (params[@"penalty_repeat"]) sparams.penalty_repeat = [params[@"penalty_repeat"] doubleValue]; + if (params[@"penalty_freq"]) sparams.penalty_freq = [params[@"penalty_freq"] doubleValue]; + if (params[@"penalty_present"]) sparams.penalty_present = [params[@"penalty_present"] doubleValue]; if (params[@"mirostat"]) sparams.mirostat = [params[@"mirostat"] intValue]; if (params[@"mirostat_tau"]) sparams.mirostat_tau = [params[@"mirostat_tau"] doubleValue]; @@ -167,6 +163,10 @@ - (NSDictionary *)completion:(NSDictionary *)params if (params[@"typical_p"]) sparams.typical_p = [params[@"typical_p"] doubleValue]; + if (params[@"grammar"]) { + sparams.grammar = [params[@"grammar"] UTF8String]; + } + llama->params.antiprompt.clear(); if (params[@"stop"]) { NSArray *stop = params[@"stop"]; @@ -197,10 +197,9 @@ - (NSDictionary *)completion:(NSDictionary *)params } } - if (!llama->loadGrammar()) { - @throw [NSException exceptionWithName:@"LlamaException" reason:@"Failed to load grammar" userInfo:nil]; + if (!llama->initSampling()) { + @throw [NSException exceptionWithName:@"LlamaException" reason:@"Failed to initialize sampling" userInfo:nil]; } - llama->loadPrompt(); llama->beginCompletion(); @@ -246,7 +245,7 @@ - (NSDictionary *)completion:(NSDictionary *)params NSMutableDictionary *tokenResult = [[NSMutableDictionary alloc] init]; tokenResult[@"token"] = [NSString stringWithUTF8String:to_send.c_str()]; - if (llama->params.sampling_params.n_probs > 0) { + if (llama->params.sparams.n_probs > 0) { const std::vector to_send_toks = llama_tokenize(llama->ctx, to_send, false); size_t probs_pos = std::min(sent_token_probs_index, llama->generated_token_probs.size()); size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size()); diff --git a/llama.cpp b/llama.cpp index 004797f..96981f3 160000 --- a/llama.cpp +++ b/llama.cpp @@ -1 +1 @@ -Subproject commit 004797f6ac135383f8c1d1f5bd415ddee2f79318 +Subproject commit 96981f37b1e3f450d9e63e571514217bf60f0a7f diff --git a/src/NativeRNLlama.ts b/src/NativeRNLlama.ts index 67a949b..896a37a 100644 --- a/src/NativeRNLlama.ts +++ b/src/NativeRNLlama.ts @@ -37,10 +37,10 @@ export type NativeCompletionParams = { temperature?: number // -> temp - repeat_last_n?: number - repeat_penalty?: number - presence_penalty?: number - frequency_penalty?: number + penalty_last_n?: number + penalty_repeat?: number + penalty_freq?: number + penalty_present?: number mirostat?: number mirostat_tau?: number mirostat_eta?: number