From fcf6538ba6702c55eaec70da9a75c81d04900a72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 21 May 2024 19:27:12 +0200 Subject: [PATCH 01/15] CUDA: fix unused warning in mmq.cu (#7442) --- ggml-cuda/mmq.cu | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/ggml-cuda/mmq.cu b/ggml-cuda/mmq.cu index 5b540d375031b..933d799ce8bcb 100644 --- a/ggml-cuda/mmq.cu +++ b/ggml-cuda/mmq.cu @@ -1220,6 +1220,7 @@ template static __global__ void load_tiles_q4_0, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); #else + GGML_UNUSED(get_arch_config_device); GGML_UNUSED(vec_dot_q4_0_q8_1_mul_mat); NO_DEVICE_CODE; #endif // __CUDA_ARCH__ >= MIN_CC_DP4A @@ -1244,6 +1245,7 @@ template static __global__ void load_tiles_q4_1, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); #else + GGML_UNUSED(get_arch_config_device); GGML_UNUSED(vec_dot_q4_1_q8_1_mul_mat); NO_DEVICE_CODE; #endif // __CUDA_ARCH__ >= MIN_CC_DP4A @@ -1266,6 +1268,7 @@ template static __global__ void load_tiles_q5_0, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); #else + GGML_UNUSED(get_arch_config_device); GGML_UNUSED(vec_dot_q5_0_q8_1_mul_mat); NO_DEVICE_CODE; #endif // __CUDA_ARCH__ >= MIN_CC_DP4A @@ -1288,6 +1291,7 @@ mul_mat_q5_1( load_tiles_q5_1, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); #else + GGML_UNUSED(get_arch_config_device); GGML_UNUSED(vec_dot_q5_1_q8_1_mul_mat); NO_DEVICE_CODE; #endif // __CUDA_ARCH__ >= MIN_CC_DP4A @@ -1310,6 +1314,7 @@ template static __global__ void load_tiles_q8_0, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); #else + GGML_UNUSED(get_arch_config_device); GGML_UNUSED(vec_dot_q8_0_q8_1_mul_mat); NO_DEVICE_CODE; #endif // __CUDA_ARCH__ >= MIN_CC_DP4A @@ -1332,6 +1337,7 @@ mul_mat_q2_K( load_tiles_q2_K, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); #else + GGML_UNUSED(get_arch_config_device); GGML_UNUSED(vec_dot_q2_K_q8_1_mul_mat); NO_DEVICE_CODE; #endif // __CUDA_ARCH__ >= MIN_CC_DP4A @@ -1356,6 +1362,7 @@ template static __global__ void load_tiles_q3_K, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); #else + GGML_UNUSED(get_arch_config_device); GGML_UNUSED(vec_dot_q3_K_q8_1_mul_mat); NO_DEVICE_CODE; #endif // __CUDA_ARCH__ >= MIN_CC_DP4A @@ -1380,6 +1387,7 @@ template static __global__ void load_tiles_q4_K, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); #else + GGML_UNUSED(get_arch_config_device); GGML_UNUSED(vec_dot_q4_K_q8_1_mul_mat); NO_DEVICE_CODE; #endif // __CUDA_ARCH__ >= MIN_CC_DP4A @@ -1402,6 +1410,7 @@ mul_mat_q5_K( load_tiles_q5_K, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); #else + GGML_UNUSED(get_arch_config_device); GGML_UNUSED(vec_dot_q5_K_q8_1_mul_mat); NO_DEVICE_CODE; #endif // __CUDA_ARCH__ >= MIN_CC_DP4A @@ -1426,6 +1435,7 @@ template static __global__ void load_tiles_q6_K, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); #else + GGML_UNUSED(get_arch_config_device); GGML_UNUSED(vec_dot_q6_K_q8_1_mul_mat); NO_DEVICE_CODE; #endif // __CUDA_ARCH__ >= MIN_CC_DP4A From e402de364b643cb89ea9f43057733b5d36298670 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Tue, 21 May 2024 20:40:00 +0100 Subject: [PATCH 02/15] `grammars`: fix resampling logic regression (#7424) --- common/sampling.cpp | 13 +++++++------ examples/main/main.cpp | 4 ++-- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index f0f1b92d37f59..7fc2e2158d5c4 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -179,7 +179,7 @@ static llama_token llama_sampling_sample_impl( struct llama_context * ctx_main, struct llama_context * ctx_cfg, const int idx, - bool is_resampling) { // Add a parameter to indicate if we are resampling + bool is_resampling) { const llama_sampling_params & params = ctx_sampling->params; const float temp = params.temp; @@ -188,8 +188,8 @@ static llama_token llama_sampling_sample_impl( const float mirostat_eta = params.mirostat_eta; std::vector original_logits; - auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, !is_resampling, &original_logits); - if (!is_resampling) { + auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits); + if (ctx_sampling->grammar != NULL && !is_resampling) { GGML_ASSERT(!original_logits.empty()); } llama_token id = 0; @@ -252,7 +252,7 @@ static llama_token llama_sampling_sample_impl( // Restore logits from the copy std::copy(original_logits.begin(), original_logits.end(), logits); - return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, true); // Pass true for is_resampling + return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ true); } } @@ -285,7 +285,8 @@ static llama_token_data_array llama_sampling_prepare_impl( // Get a pointer to the logits float * logits = llama_get_logits_ith(ctx_main, idx); - if (apply_grammar && original_logits != NULL) { + if (ctx_sampling->grammar != NULL && !apply_grammar) { + GGML_ASSERT(original_logits != NULL); // Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this. *original_logits = {logits, logits + llama_n_vocab(llama_get_model(ctx_main))}; } @@ -342,7 +343,7 @@ llama_token llama_sampling_sample( struct llama_context * ctx_cfg, const int idx) { // Call the implementation function with is_resampling set to false by default - return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false); + return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ false); } llama_token_data_array llama_sampling_prepare( diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 9dee41001f12c..832b51ee086be 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -707,7 +707,7 @@ int main(int argc, char ** argv) { const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance); - llama_sampling_accept(ctx_sampling, ctx, id, true); + llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar= */ true); LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str()); @@ -728,7 +728,7 @@ int main(int argc, char ** argv) { // push the prompt in the sampling context in order to apply repetition penalties later // for the prompt, we don't apply grammar rules - llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false); + llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], /* apply_grammar= */ false); ++n_consumed; if ((int) embd.size() >= params.n_batch) { From 6369bf04336ab60e5c892dd77a3246df91015147 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 21 May 2024 23:03:42 +0300 Subject: [PATCH 03/15] metal : handle F16 inf values, fix FA partial offload (#7434) ggml-ci --- ggml-metal.metal | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 386e9195fcffa..cf262e8349874 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2204,11 +2204,7 @@ kernel void kernel_flash_attn_ext_f16( // pointer to the mask device const half * mp = (device const half *) (mask + iq1*nb31); - // prepare diagonal scale matrix - simdgroup_float8x8 mscale(scale); - - // prepare diagonal slope matrix - simdgroup_float8x8 mslope(1.0f); + float slope = 1.0f; // ALiBi if (max_bias > 0.0f) { @@ -2217,7 +2213,7 @@ kernel void kernel_flash_attn_ext_f16( const float base = h < n_head_log2 ? m0 : m1; const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - mslope = simdgroup_float8x8(pow(base, exph)); + slope = pow(base, exph); } // loop over the KV cache @@ -2242,18 +2238,20 @@ kernel void kernel_flash_attn_ext_f16( simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); } + simdgroup_store(mqk, ss + 8*cc, TF, 0, false); + + const short tx = tiisg%4; + const short ty = tiisg/4; + if (mask != q) { // mqk = mqk*scale + mask*slope - simdgroup_half8x8 mm; - simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false); - simdgroup_multiply(mm, mslope, mm); - simdgroup_multiply_accumulate(mqk, mqk, mscale, mm); + ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0]; + ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1]; } else { // mqk = mqk*scale - simdgroup_multiply(mqk, mscale, mqk); + ss[8*cc + ty*TF + 2*tx + 0] *= scale; + ss[8*cc + ty*TF + 2*tx + 1] *= scale; } - - simdgroup_store(mqk, ss + 8*cc, TF, 0, false); } } @@ -2816,8 +2814,7 @@ kernel void kernel_cpy_f32_f16( for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - // TODO: is there a better way to handle -INFINITY? - dst_data[i00] = src[0] == -INFINITY ? -MAXHALF : src[0]; + dst_data[i00] = src[0]; } } From 201cc11afa0a1950e1f632390b2ac6c937a0d8f0 Mon Sep 17 00:00:00 2001 From: liuwei-git <14815172+liuwei-git@users.noreply.github.com> Date: Wed, 22 May 2024 04:28:32 +0800 Subject: [PATCH 04/15] llama : add phi3 128K model support (#7225) * add phi3 128k support in convert-hf-to-gguf * add phi3 128k support in cuda * address build warnings on llama.cpp * adjust index value in cuda long rope freq factors * add long rope support in ggml cpu backend * make freq factors only depend on ctx size * remove unused rope scaling type 'su' frin gguf converter * fix flint warnings on convert-hf-to-gguf.py * set to the short freq factor when context size is small than trained context size * add one line of comments * metal : support rope freq_factors * ggml : update ggml_rope_ext API to support freq. factors * backends : add dev messages to support rope freq. factors * minor : style * tests : update to use new rope API * backends : fix pragma semicolons * minor : cleanup * llama : move rope factors from KV header to tensors * llama : remove tmp assert * cuda : fix compile warning * convert : read/write n_head_kv * llama : fix uninitialized tensors --------- Co-authored-by: Georgi Gerganov --- convert-hf-to-gguf.py | 49 +++- examples/finetune/finetune.cpp | 4 +- .../train-text-from-scratch.cpp | 4 +- ggml-cuda/rope.cu | 72 +++-- ggml-kompute.cpp | 4 + ggml-metal.m | 121 ++++---- ggml-metal.metal | 6 +- ggml-sycl.cpp | 3 + ggml-vulkan.cpp | 4 + ggml.c | 80 ++++- ggml.h | 45 ++- gguf-py/gguf/constants.py | 17 +- gguf-py/gguf/gguf_writer.py | 3 + llama.cpp | 277 +++++++++++------- tests/test-backend-ops.cpp | 16 +- 15 files changed, 478 insertions(+), 227 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 6357d40348b34..daad1c4fc7255 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -14,6 +14,7 @@ from hashlib import sha256 from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast +import math import numpy as np import torch @@ -1784,23 +1785,59 @@ def set_vocab(self): def set_gguf_parameters(self): block_count = self.find_hparam(["num_hidden_layers", "n_layer"]) - rot_pct = 1.0 n_embd = self.find_hparam(["hidden_size", "n_embd"]) n_head = self.find_hparam(["num_attention_heads", "n_head"]) + n_head_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"]) rms_eps = self.find_hparam(["rms_norm_eps"]) + max_pos_embds = self.find_hparam(["n_positions", "max_position_embeddings"]) + orig_max_pos_embds = self.find_hparam(["original_max_position_embeddings"]) + rope_dims = n_embd // n_head self.gguf_writer.add_name("Phi3") - self.gguf_writer.add_context_length(self.find_hparam(["n_positions", "max_position_embeddings"])) - + self.gguf_writer.add_context_length(max_pos_embds) + self.gguf_writer.add_rope_scaling_orig_ctx_len(orig_max_pos_embds) self.gguf_writer.add_embedding_length(n_embd) - self.gguf_writer.add_feed_forward_length(8192) + self.gguf_writer.add_feed_forward_length(self.find_hparam(["intermediate_size"])) self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_head_count(n_head) - self.gguf_writer.add_head_count_kv(n_head) + self.gguf_writer.add_head_count_kv(n_head_kv) self.gguf_writer.add_layer_norm_rms_eps(rms_eps) - self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head) + self.gguf_writer.add_rope_dimension_count(rope_dims) + self.gguf_writer.add_rope_freq_base(self.find_hparam(["rope_theta"])) self.gguf_writer.add_file_type(self.ftype) + # write rope scaling for long context (128k) model + rope_scaling = self.find_hparam(['rope_scaling'], True) + if (rope_scaling is None): + return + + scale = max_pos_embds / orig_max_pos_embds + + rope_scaling_type = rope_scaling.get('type', '').lower() + if len(rope_scaling_type) == 0: + raise KeyError('Missing the required key rope_scaling.type') + + if rope_scaling_type == 'su': + attn_factor = math.sqrt(1 + math.log(scale) / math.log(orig_max_pos_embds)) if scale > 1.0 else 1.0 + elif rope_scaling_type == 'yarn': + attn_factor = 0.1 * math.log(scale) + 1.0 if scale > 1.0 else 1.0 + else: + raise NotImplementedError(f'The rope scaling type {rope_scaling_type} is not supported yet') + + self.gguf_writer.add_rope_scaling_attn_factors(attn_factor) + + long_factors = rope_scaling.get('long_factor', None) + short_factors = rope_scaling.get('short_factor', None) + + if long_factors is None or short_factors is None: + raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor') + + if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2: + raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}') + + self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_LONG] + ".weight", np.array(long_factors, dtype=np.float32)) + self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT] + ".weight", np.array(short_factors, dtype=np.float32)) + @Model.register("PlamoForCausalLM") class PlamoModel(Model): diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp index 22743b1bf02fd..992426c1b69e2 100644 --- a/examples/finetune/finetune.cpp +++ b/examples/finetune/finetune.cpp @@ -563,8 +563,8 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs( // not capturing these, to silcence warnings const int rope_mode = 0; - return ggml_rope_custom(ctx, - t, KQ_pos, n_rot, rope_mode, n_ctx, 0, + return ggml_rope_ext(ctx, + t, KQ_pos, nullptr, n_rot, rope_mode, n_ctx, 0, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f ); }; diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 587418cc73964..45bdfa8f5d80c 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -301,8 +301,8 @@ static struct ggml_tensor * llama_build_train_graphs( // not capturing these, to silcence warnings const int rope_mode = 0; - return ggml_rope_custom( - ctx, t, KQ_pos, n_rot, rope_mode, n_ctx, 0, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f + return ggml_rope_ext( + ctx, t, KQ_pos, nullptr, n_rot, rope_mode, n_ctx, 0, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f ); }; diff --git a/ggml-cuda/rope.cu b/ggml-cuda/rope.cu index 4b0d2e5adbbc5..4a558f4b3757e 100644 --- a/ggml-cuda/rope.cu +++ b/ggml-cuda/rope.cu @@ -58,10 +58,10 @@ static __global__ void rope( dst[i + 1] = x0*sin_theta + x1*cos_theta; } -template +template static __global__ void rope_neox( const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, - float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims + float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims, const float * freq_factors ) { const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); @@ -88,7 +88,9 @@ static __global__ void rope_neox( float cur_rot = inv_ndims * ic - ib; const int p = has_pos ? pos[i2] : 0; - const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f); + const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f; + + const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f)/freq_factor; float cos_theta, sin_theta; rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta); @@ -164,7 +166,7 @@ static void rope_cuda( template static void rope_neox_cuda( const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream + float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream ) { GGML_ASSERT(ncols % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); @@ -175,15 +177,29 @@ static void rope_neox_cuda( const float inv_ndims = -1.0f / n_dims; if (pos == nullptr) { - rope_neox<<>>( - x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, - theta_scale, inv_ndims - ); + if (freq_factors == nullptr) { + rope_neox<<>>( + x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, + theta_scale, inv_ndims, freq_factors + ); + } else { + rope_neox<<>>( + x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, + theta_scale, inv_ndims, freq_factors + ); + } } else { - rope_neox<<>>( - x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, - theta_scale, inv_ndims - ); + if (freq_factors == nullptr) { + rope_neox<<>>( + x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, + theta_scale, inv_ndims, freq_factors + ); + } else { + rope_neox<<>>( + x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, + theta_scale, inv_ndims, freq_factors + ); + } } } @@ -214,24 +230,27 @@ static void rope_cuda_f32( static void rope_neox_cuda_f16( const half * x, half * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) { + float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { - rope_neox_cuda(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream); + rope_neox_cuda(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); } static void rope_neox_cuda_f32( const float * x, float * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream + float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream ) { - rope_neox_cuda(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream); + rope_neox_cuda(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); } void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * src2 = dst->src[2]; + const float * src0_d = (const float *)src0->data; const float * src1_d = (const float *)src1->data; + float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); @@ -241,7 +260,6 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; - const int64_t ne2 = dst->ne[2]; const int64_t nrows = ggml_nrows(src0); //const int n_past = ((int32_t *) dst->op_params)[0]; @@ -259,16 +277,22 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + const float * freq_factors = nullptr; const int32_t * pos = nullptr; - if ((mode & 1) == 0) { - GGML_ASSERT(src1->type == GGML_TYPE_I32); - GGML_ASSERT(src1->ne[0] == ne2); - pos = (const int32_t *) src1_d; - } const bool is_neox = mode & 2; const bool is_glm = mode & 4; + if (is_neox) { + pos = (const int32_t *) src1_d; + + if (src2 != nullptr) { + freq_factors = (const float *) src2->data; + } + } else { + GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox"); + } + rope_corr_dims corr_dims; ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v); @@ -280,12 +304,12 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { if (src0->type == GGML_TYPE_F32) { rope_neox_cuda_f32( (const float *)src0_d, (float *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, stream + attn_factor, corr_dims, freq_factors, stream ); } else if (src0->type == GGML_TYPE_F16) { rope_neox_cuda_f16( (const half *)src0_d, (half *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, stream + attn_factor, corr_dims, freq_factors, stream ); } else { GGML_ASSERT(false); diff --git a/ggml-kompute.cpp b/ggml-kompute.cpp index 3f033d58be481..6c6058b2a95b1 100644 --- a/ggml-kompute.cpp +++ b/ggml-kompute.cpp @@ -1677,6 +1677,10 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml } break; case GGML_OP_ROPE: { +#pragma message("TODO: implement phi3 frequency factors support") +#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225") + GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet"); + GGML_ASSERT(ne10 == ne02); GGML_ASSERT(src0t == dstt); // const int n_past = ((int32_t *) dst->op_params)[0]; diff --git a/ggml-metal.m b/ggml-metal.m index b0b16dbf77160..5d5ad20ada788 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -927,22 +927,32 @@ static enum ggml_status ggml_metal_graph_compute( const int64_t ne10 = src1 ? src1->ne[0] : 0; const int64_t ne11 = src1 ? src1->ne[1] : 0; const int64_t ne12 = src1 ? src1->ne[2] : 0; - const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13); + const int64_t ne13 = src1 ? src1->ne[3] : 0; const uint64_t nb10 = src1 ? src1->nb[0] : 0; const uint64_t nb11 = src1 ? src1->nb[1] : 0; const uint64_t nb12 = src1 ? src1->nb[2] : 0; - const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13); + const uint64_t nb13 = src1 ? src1->nb[3] : 0; - const int64_t ne0 = dst ? dst->ne[0] : 0; - const int64_t ne1 = dst ? dst->ne[1] : 0; - const int64_t ne2 = dst ? dst->ne[2] : 0; - const int64_t ne3 = dst ? dst->ne[3] : 0; + const int64_t ne20 = src2 ? src2->ne[0] : 0; + const int64_t ne21 = src2 ? src2->ne[1] : 0; + const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22); + const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23); - const uint64_t nb0 = dst ? dst->nb[0] : 0; - const uint64_t nb1 = dst ? dst->nb[1] : 0; - const uint64_t nb2 = dst ? dst->nb[2] : 0; - const uint64_t nb3 = dst ? dst->nb[3] : 0; + const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20); + const uint64_t nb21 = src2 ? src2->nb[1] : 0; + const uint64_t nb22 = src2 ? src2->nb[2] : 0; + const uint64_t nb23 = src2 ? src2->nb[3] : 0; + + const int64_t ne0 = dst ? dst->ne[0] : 0; + const int64_t ne1 = dst ? dst->ne[1] : 0; + const int64_t ne2 = dst ? dst->ne[2] : 0; + const int64_t ne3 = dst ? dst->ne[3] : 0; + + const uint64_t nb0 = dst ? dst->nb[0] : 0; + const uint64_t nb1 = dst ? dst->nb[1] : 0; + const uint64_t nb2 = dst ? dst->nb[2] : 0; + const uint64_t nb3 = dst ? dst->nb[3] : 0; const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; @@ -1785,16 +1795,6 @@ static enum ggml_status ggml_metal_graph_compute( const int n_as = src0->ne[2]; // src2 = ids - const int64_t ne20 = src2->ne[0]; - const int64_t ne21 = src2->ne[1]; - const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22); - const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23); - - const uint64_t nb20 = src2->nb[0]; GGML_UNUSED(nb20); - const uint64_t nb21 = src2->nb[1]; - const uint64_t nb22 = src2->nb[2]; GGML_UNUSED(nb22); - const uint64_t nb23 = src2->nb[3]; GGML_UNUSED(nb23); - const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t); GGML_ASSERT(src2t == GGML_TYPE_I32); @@ -2244,7 +2244,13 @@ static enum ggml_status ggml_metal_graph_compute( // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); @@ -2252,6 +2258,15 @@ static enum ggml_status ggml_metal_graph_compute( memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + const bool is_neox = mode & 2; + const bool is_glm = mode & 4; + + GGML_ASSERT(!is_glm && "GLM RoPE not implemented in Metal"); + + if (!is_neox) { + GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for !is_neox"); + } + id pipeline = nil; switch (src0->type) { @@ -2263,33 +2278,38 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&n_past length:sizeof( int) atIndex:19]; - [encoder setBytes:&n_dims length:sizeof( int) atIndex:20]; - [encoder setBytes:&mode length:sizeof( int) atIndex:21]; - [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22]; - [encoder setBytes:&freq_base length:sizeof( float) atIndex:23]; - [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24]; - [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25]; - [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26]; - [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27]; - [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28]; + if (id_src2 != nil) { + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2]; + } + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19]; + [encoder setBytes:&n_past length:sizeof( int) atIndex:20]; + [encoder setBytes:&n_dims length:sizeof( int) atIndex:21]; + [encoder setBytes:&mode length:sizeof( int) atIndex:22]; + [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:23]; + [encoder setBytes:&freq_base length:sizeof( float) atIndex:24]; + [encoder setBytes:&freq_scale length:sizeof( float) atIndex:25]; + [encoder setBytes:&ext_factor length:sizeof( float) atIndex:26]; + [encoder setBytes:&attn_factor length:sizeof( float) atIndex:27]; + [encoder setBytes:&beta_fast length:sizeof( float) atIndex:28]; + [encoder setBytes:&beta_slow length:sizeof( float) atIndex:29]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; @@ -2535,11 +2555,6 @@ static enum ggml_status ggml_metal_graph_compute( GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) && "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); - const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20); - const uint64_t nb21 = src2 ? src2->nb[1] : 0; - const uint64_t nb22 = src2 ? src2->nb[2] : 0; - const uint64_t nb23 = src2 ? src2->nb[3] : 0; - const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); //const int64_t ne31 = src3 ? src3->ne[1] : 0; const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); diff --git a/ggml-metal.metal b/ggml-metal.metal index cf262e8349874..c5eb252808377 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1640,6 +1640,7 @@ static void rope_yarn_corr_dims( typedef void (rope_t)( device const void * src0, device const int32_t * src1, + device const float * src2, device float * dst, constant int64_t & ne00, constant int64_t & ne01, @@ -1675,6 +1676,7 @@ template kernel void kernel_rope( device const void * src0, device const int32_t * src1, + device const float * src2, device float * dst, constant int64_t & ne00, constant int64_t & ne01, @@ -1744,8 +1746,10 @@ kernel void kernel_rope( // simplified from `(ib * n_dims + ic) * inv_ndims` const float cur_rot = inv_ndims*ic - ib; + const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f; + + const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor; - const float theta = theta_0 * pow(freq_base, cur_rot); float cos_theta, sin_theta; rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta); diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index eac8f55796735..f486b6c0a5a3b 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -14454,6 +14454,9 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, const dpct::queue_ptr &main_stream) { +#pragma message("TODO: implement phi3 frequency factors support") +#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225") + GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet"); GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index aff451b6354e5..16287a28089a0 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -4238,6 +4238,10 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx, } static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +#pragma message("TODO: implement phi3 frequency factors support") +#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225") + GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet"); + const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; // const int n_ctx = ((int32_t *) dst->op_params)[3]; diff --git a/ggml.c b/ggml.c index 4bd911528586b..37b16b7a9ce7f 100644 --- a/ggml.c +++ b/ggml.c @@ -6231,6 +6231,7 @@ static struct ggml_tensor * ggml_rope_impl( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, + struct ggml_tensor * c, int n_dims, int mode, int n_ctx, @@ -6248,6 +6249,11 @@ static struct ggml_tensor * ggml_rope_impl( GGML_ASSERT(b->type == GGML_TYPE_I32); GGML_ASSERT(a->ne[2] == b->ne[0]); + if (c) { + GGML_ASSERT(c->type == GGML_TYPE_F32); + GGML_ASSERT(c->ne[0] >= n_dims / 2); + } + bool is_node = false; if (a->grad) { @@ -6271,6 +6277,7 @@ static struct ggml_tensor * ggml_rope_impl( result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; result->src[1] = b; + result->src[2] = c; return result; } @@ -6283,7 +6290,7 @@ struct ggml_tensor * ggml_rope( int mode, int n_ctx) { return ggml_rope_impl( - ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false + ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false ); } @@ -6295,14 +6302,15 @@ struct ggml_tensor * ggml_rope_inplace( int mode, int n_ctx) { return ggml_rope_impl( - ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true + ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true ); } -struct ggml_tensor * ggml_rope_custom( +struct ggml_tensor * ggml_rope_ext( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, + struct ggml_tensor * c, int n_dims, int mode, int n_ctx, @@ -6314,15 +6322,16 @@ struct ggml_tensor * ggml_rope_custom( float beta_fast, float beta_slow) { return ggml_rope_impl( - ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, + ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false ); } -struct ggml_tensor * ggml_rope_custom_inplace( +struct ggml_tensor * ggml_rope_ext_inplace( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, + struct ggml_tensor * c, int n_dims, int mode, int n_ctx, @@ -6334,19 +6343,49 @@ struct ggml_tensor * ggml_rope_custom_inplace( float beta_fast, float beta_slow) { return ggml_rope_impl( - ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, + ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true ); } -struct ggml_tensor * ggml_rope_xpos_inplace( +struct ggml_tensor * ggml_rope_custom( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, int n_dims, - float base, - bool down) { - return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true); + int mode, + int n_ctx, + int n_orig_ctx, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow) { + return ggml_rope_impl( + ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false + ); +} + +struct ggml_tensor * ggml_rope_custom_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int n_dims, + int mode, + int n_ctx, + int n_orig_ctx, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow) { + return ggml_rope_impl( + ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true + ); } // ggml_rope_back @@ -6355,6 +6394,7 @@ struct ggml_tensor * ggml_rope_back( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, + struct ggml_tensor * c, int n_dims, int mode, int n_ctx, @@ -6370,6 +6410,7 @@ struct ggml_tensor * ggml_rope_back( GGML_ASSERT(ggml_is_vector(b)); GGML_ASSERT(b->type == GGML_TYPE_I32); GGML_ASSERT(a->ne[2] == b->ne[0]); + GGML_ASSERT(c == NULL && "freq factors not implemented yet"); GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet"); @@ -14304,6 +14345,7 @@ static void ggml_compute_forward_rope_f32( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; + const struct ggml_tensor * src2 = dst->src[2]; if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; @@ -14363,6 +14405,17 @@ static void ggml_compute_forward_rope_f32( const bool is_neox = mode & 2; const bool is_glm = mode & 4; + const float * freq_factors = NULL; + if (is_neox) { + if (src2 != NULL) { + GGML_ASSERT(src2->type == GGML_TYPE_F32); + GGML_ASSERT(src2->ne[0] >= n_dims / 2); + freq_factors = (const float *) src2->data; + } + } else { + GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for mode 1"); + } + // backward process uses inverse rotation by cos and sin. // cos and sin build a rotation matrix, where the inverse is the transpose. // this essentially just switches the sign of sin. @@ -14439,10 +14492,11 @@ static void ggml_compute_forward_rope_f32( // simplified from `(ib * n_dims + ic) * inv_ndims` float cur_rot = inv_ndims * ic - ib; + float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f; float cos_theta, sin_theta; rope_yarn( - theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, + theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta ); sin_theta *= sin_sign; @@ -18387,6 +18441,7 @@ static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct gg static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set zero_table) { struct ggml_tensor * src0 = tensor->src[0]; struct ggml_tensor * src1 = tensor->src[1]; + struct ggml_tensor * src2 = tensor->src[2]; switch (tensor->op) { case GGML_OP_DUP: @@ -18918,6 +18973,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor ggml_rope_back(ctx, tensor->grad, src1, + src2, n_dims, mode, n_ctx, @@ -18957,6 +19013,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor ggml_rope_impl(ctx, tensor->grad, src1, + src2, n_dims, mode, n_ctx, @@ -19038,7 +19095,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor masked); } - struct ggml_tensor * src2 = tensor->src[2]; const int64_t elem_q = ggml_nelements(src0); const int64_t elem_k = ggml_nelements(src1); const int64_t elem_v = ggml_nelements(src2); diff --git a/ggml.h b/ggml.h index 77475710129d7..35ac9110ceb17 100644 --- a/ggml.h +++ b/ggml.h @@ -1465,6 +1465,7 @@ extern "C" { // if mode & 4 == 1, ChatGLM style // // b is an int32 vector with size a->ne[2], it contains the positions + // c is freq factors (e.g. phi3-128k), (optional) GGML_API struct ggml_tensor * ggml_rope( struct ggml_context * ctx, struct ggml_tensor * a, @@ -1483,10 +1484,11 @@ extern "C" { int n_ctx); // custom RoPE - GGML_API struct ggml_tensor * ggml_rope_custom( + GGML_API struct ggml_tensor * ggml_rope_ext( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, + struct ggml_tensor * c, int n_dims, int mode, int n_ctx, @@ -1499,10 +1501,11 @@ extern "C" { float beta_slow); // in-place, returns view(a) - GGML_API struct ggml_tensor * ggml_rope_custom_inplace( + GGML_API struct ggml_tensor * ggml_rope_ext_inplace( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, + struct ggml_tensor * c, int n_dims, int mode, int n_ctx, @@ -1514,18 +1517,41 @@ extern "C" { float beta_fast, float beta_slow); - // compute correction dims for YaRN RoPE scaling - GGML_CALL void ggml_rope_yarn_corr_dims( - int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]); + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int n_dims, + int mode, + int n_ctx, + int n_orig_ctx, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow), + "use ggml_rope_ext instead"); - // xPos RoPE, in-place, returns view(a) - GGML_API struct ggml_tensor * ggml_rope_xpos_inplace( + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom_inplace( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, int n_dims, - float base, - bool down); + int mode, + int n_ctx, + int n_orig_ctx, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow), + "use ggml_rope_ext_inplace instead"); + + // compute correction dims for YaRN RoPE scaling + GGML_CALL void ggml_rope_yarn_corr_dims( + int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]); // rotary position embedding backward, i.e compute dx from dy // a - dy @@ -1533,6 +1559,7 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, + struct ggml_tensor * c, int n_dims, int mode, int n_ctx, diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 692120f4d64b0..42df2e4d00604 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -57,12 +57,13 @@ class Attention: CAUSAL = "{arch}.attention.causal" class Rope: - DIMENSION_COUNT = "{arch}.rope.dimension_count" - FREQ_BASE = "{arch}.rope.freq_base" - SCALING_TYPE = "{arch}.rope.scaling.type" - SCALING_FACTOR = "{arch}.rope.scaling.factor" - SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length" - SCALING_FINETUNED = "{arch}.rope.scaling.finetuned" + DIMENSION_COUNT = "{arch}.rope.dimension_count" + FREQ_BASE = "{arch}.rope.freq_base" + SCALING_TYPE = "{arch}.rope.scaling.type" + SCALING_FACTOR = "{arch}.rope.scaling.factor" + SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor" + SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length" + SCALING_FINETUNED = "{arch}.rope.scaling.finetuned" class SSM: CONV_KERNEL = "{arch}.ssm.conv_kernel" @@ -148,6 +149,8 @@ class MODEL_TENSOR(IntEnum): OUTPUT = auto() OUTPUT_NORM = auto() ROPE_FREQS = auto() + ROPE_FACTORS_LONG = auto() + ROPE_FACTORS_SHORT = auto() ATTN_Q = auto() ATTN_K = auto() ATTN_V = auto() @@ -225,6 +228,8 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM: "output_norm", MODEL_TENSOR.OUTPUT: "output", MODEL_TENSOR.ROPE_FREQS: "rope_freqs", + MODEL_TENSOR.ROPE_FACTORS_LONG: "rope_factors_long", + MODEL_TENSOR.ROPE_FACTORS_SHORT: "rope_factors_short", MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm", MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2", MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv", diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index d5e323a52ef14..8b41b54eaa5a6 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -433,6 +433,9 @@ def add_rope_scaling_type(self, value: RopeScalingType) -> None: def add_rope_scaling_factor(self, value: float) -> None: self.add_float32(Keys.Rope.SCALING_FACTOR.format(arch=self.arch), value) + def add_rope_scaling_attn_factors(self, value: Sequence[float]) -> None: + self.add_float32(Keys.Rope.SCALING_ATTN_FACTOR.format(arch=self.arch), value) + def add_rope_scaling_orig_ctx_len(self, value: int) -> None: self.add_uint32(Keys.Rope.SCALING_ORIG_CTX_LEN.format(arch=self.arch), value) diff --git a/llama.cpp b/llama.cpp index d26fe559a2051..abff8c1c03e7a 100644 --- a/llama.cpp +++ b/llama.cpp @@ -304,6 +304,7 @@ enum llm_kv { LLM_KV_ROPE_SCALE_LINEAR, LLM_KV_ROPE_SCALING_TYPE, LLM_KV_ROPE_SCALING_FACTOR, + LLM_KV_ROPE_SCALING_ATTN_FACTOR, LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, LLM_KV_ROPE_SCALING_FINETUNED, @@ -381,6 +382,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, + { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" }, { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, @@ -436,6 +438,8 @@ enum llm_tensor { LLM_TENSOR_OUTPUT, LLM_TENSOR_OUTPUT_NORM, LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_ROPE_FACTORS_LONG, + LLM_TENSOR_ROPE_FACTORS_SHORT, LLM_TENSOR_ATTN_Q, LLM_TENSOR_ATTN_K, LLM_TENSOR_ATTN_V, @@ -803,18 +807,20 @@ static const std::map> LLM_TENSOR_NA { LLM_ARCH_PHI3, { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" }, + { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, { @@ -1750,6 +1756,7 @@ struct llama_hparams { float f_norm_eps; float f_norm_rms_eps; + float rope_attn_factor = 1.0f; float rope_freq_base_train; float rope_freq_scale_train; uint32_t n_yarn_orig_ctx; @@ -1798,6 +1805,7 @@ struct llama_hparams { if (!is_float_close(this->f_norm_eps, other.f_norm_eps, EPSILON)) return true; if (!is_float_close(this->f_norm_rms_eps, other.f_norm_rms_eps, EPSILON)) return true; + if (!is_float_close(this->rope_attn_factor, other.rope_attn_factor, EPSILON)) return true; if (!is_float_close(this->rope_freq_base_train, other.rope_freq_base_train, EPSILON)) return true; if (!is_float_close(this->rope_freq_scale_train, other.rope_freq_scale_train, EPSILON)) return true; @@ -2103,6 +2111,10 @@ struct llama_model { struct ggml_tensor * output; struct ggml_tensor * output_b; + // long rope factors + struct ggml_tensor * rope_long = nullptr; + struct ggml_tensor * rope_short = nullptr; + std::vector layers; llama_split_mode split_mode; @@ -3306,6 +3318,39 @@ struct llama_model_loader { return get_arr_n(llm_kv(kid), result, required); } + template + bool get_arr(const std::string & key, std::vector & result, const bool required = true) { + const int kid = gguf_find_key(meta, key.c_str()); + + if (kid < 0) { + if (required) { + throw std::runtime_error(format("key not found in model: %s", key.c_str())); + } + return false; + } + + struct GGUFMeta::ArrayInfo arr_info = + GGUFMeta::GKV::get_kv(meta, kid); + + if (arr_info.gt != GGUF_TYPE_FLOAT32 && arr_info.gt != GGUF_TYPE_INT32) { + throw std::runtime_error(format("%s is not a float32 or int32 array", key.c_str())); + } + + // GGML_ASSERT(gguf_type_size(arr_info.gt) == sizeof(T)); + GGML_ASSERT((arr_info.gt != GGUF_TYPE_FLOAT32 || std::is_same::value)); + GGML_ASSERT((arr_info.gt != GGUF_TYPE_INT32 || std::is_same::value)); + + result.resize(arr_info.length); + result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length); + + return true; + } + + template + bool get_arr(const enum llm_kv kid, T& result, const bool required = true) { + return get_arr(llm_kv(kid), result, required); + } + template bool get_key(const std::string & key, T & result, const bool required = true) { auto it = kv_overrides.find(key); @@ -3849,6 +3894,8 @@ static void llm_load_hparams( } hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale; + ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false); + // sanity check for n_rot (optional) { hparams.n_rot = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head; @@ -4880,6 +4927,7 @@ static bool llm_load_tensors( // create tensors for the weights { const int64_t n_embd = hparams.n_embd; + const int64_t n_embd_head = n_embd / hparams.n_head; const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); const int64_t n_embd_gqa = n_embd_v_gqa; @@ -5591,6 +5639,9 @@ static bool llm_load_tensors( { model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }); + model.rope_long = ml.create_tensor(ctx_input, tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight"), { n_embd_head/2 }, false); + model.rope_short = ml.create_tensor(ctx_input, tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight"), { n_embd_head/2 }, false); + // output { model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }); @@ -5601,12 +5652,12 @@ static bool llm_load_tensors( ggml_context* ctx_layer = ctx_for_layer(i); ggml_context* ctx_split = ctx_for_layer_split(i); - auto& layer = model.layers[i]; + auto & layer = model.layers[i]; layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }); layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, false); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }); layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }); @@ -6821,17 +6872,20 @@ struct llm_build_context { cb(lctx.inp_K_shift, "K_shift", -1); ggml_set_input(lctx.inp_K_shift); + struct ggml_tensor * rope_factors = build_rope_factors(); + for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * tmp = // we rotate only the first n_rot dimensions - ggml_rope_custom_inplace(ctx0, + ggml_rope_ext_inplace(ctx0, ggml_view_3d(ctx0, kv_self.k_l[il], n_embd_head_k, n_head_kv, n_ctx, ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k), ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), 0), - lctx.inp_K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + lctx.inp_K_shift, rope_factors, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + cb(tmp, "K_shifted", il); ggml_build_forward_expand(gf, tmp); } @@ -6934,6 +6988,17 @@ struct llm_build_context { return lctx.inp_pos; } + struct ggml_tensor * build_rope_factors() { + // choose long/short freq factors based on the context size + const auto n_ctx_pre_seq = cparams.n_ctx / cparams.n_seq_max; + + if (n_ctx_pre_seq > hparams.n_yarn_orig_ctx) { + return model.rope_long; + } + + return model.rope_short; + } + struct ggml_tensor * build_inp_out_ids() { lctx.inp_out_ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs); cb(lctx.inp_out_ids, "inp_out_ids", -1); @@ -7041,15 +7106,15 @@ struct llm_build_context { cb(Vcur, "Vcur", il); } - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -7171,13 +7236,13 @@ struct llm_build_context { switch (model.type) { case MODEL_7B: - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -7283,15 +7348,15 @@ struct llm_build_context { struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -7404,14 +7469,14 @@ struct llm_build_context { Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); // using mode = 2 for neox mode - Qcur = ggml_rope_custom( - ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -7527,15 +7592,15 @@ struct llm_build_context { cb(Vcur, "Vcur", il); } - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -7679,15 +7744,15 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -8032,15 +8097,15 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -8472,15 +8537,15 @@ struct llm_build_context { } - Qcur = ggml_rope_custom( - ctx0, Qcur, inp_pos, + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, Kcur, inp_pos, + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -8592,14 +8657,14 @@ struct llm_build_context { Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); // using mode = 2 for neox mode - Qcur = ggml_rope_custom( - ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -8703,15 +8768,15 @@ struct llm_build_context { Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); cb(Vcur, "Vcur", il); - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -8817,15 +8882,15 @@ struct llm_build_context { Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); cb(Vcur, "Vcur", il); - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -8969,8 +9034,8 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_rope_custom( - ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); @@ -8980,8 +9045,8 @@ struct llm_build_context { Qcur = ggml_scale(ctx0, Qcur, 1.0f/sqrtf(float(n_embd_head))); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -9052,6 +9117,9 @@ struct llm_build_context { // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + // rope freq factors for 128k context + struct ggml_tensor * rope_factors = build_rope_factors(); + for (int il = 0; il < n_layer; ++il) { auto residual = inpL; @@ -9088,8 +9156,8 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_rope_custom( - ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); @@ -9097,8 +9165,8 @@ struct llm_build_context { Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head))); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -9204,14 +9272,14 @@ struct llm_build_context { struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens), inp_pos, nullptr, n_embd_head, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens), inp_pos, nullptr, n_embd_head, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); cb(Kcur, "Kcur", il); @@ -9412,15 +9480,15 @@ struct llm_build_context { cb(tmpk, "tmpk", il); cb(Vcur, "Vcur", il); - struct ggml_tensor * Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), inp_pos, + struct ggml_tensor * Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - struct ggml_tensor * Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos, + struct ggml_tensor * Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -9528,15 +9596,15 @@ struct llm_build_context { // cb(Vcur, "Vcur", il); // } - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -9645,15 +9713,15 @@ struct llm_build_context { cb(Vcur, "Vcur", il); } - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -9775,15 +9843,15 @@ struct llm_build_context { cb(Vcur, "Vcur", il); } - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -9895,8 +9963,8 @@ struct llm_build_context { struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr, n_embd_head_k, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); cb(Qcur, "Qcur", il); @@ -9904,8 +9972,8 @@ struct llm_build_context { Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k))); cb(Qcur, "Qcur_scaled", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr, n_embd_head_k, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); cb(Kcur, "Kcur", il); @@ -10015,15 +10083,15 @@ struct llm_build_context { cb(Vcur, "Vcur", il); } - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -10305,15 +10373,15 @@ struct llm_build_context { cb(Kcur, "Kcur", il); } - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -10436,15 +10504,15 @@ struct llm_build_context { cb(Vcur, "Vcur", il); } - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -15417,6 +15485,7 @@ struct llama_context * llama_new_context_with_model( cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f; } + cparams.yarn_attn_factor *= hparams.rope_attn_factor; cparams.causal_attn = hparams.causal_attn; if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) { diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index c74e253db4b3b..1493a7ca7c405 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1763,14 +1763,14 @@ struct test_llama : public test_llm { struct ggml_tensor * Kcur = ggml_mul_mat(ctx, wk, cur); struct ggml_tensor * Vcur = ggml_mul_mat(ctx, wv, cur); - Qcur = ggml_rope_custom( - ctx, ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head, hp.n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx, ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head, hp.n_tokens), inp_pos, nullptr, hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); - Kcur = ggml_rope_custom( - ctx, ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx, ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens), inp_pos, nullptr, hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -1889,13 +1889,13 @@ struct test_falcon : public test_llm { Kcur = ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens); // using mode = 2 for neox mode - Qcur = ggml_rope_custom( - ctx, Qcur, inp_pos, hp.n_rot, 2, 0, hp.n_orig_ctx, + Qcur = ggml_rope_ext( + ctx, Qcur, inp_pos, nullptr, hp.n_rot, 2, 0, hp.n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); - Kcur = ggml_rope_custom( - ctx, Kcur, inp_pos, hp.n_rot, 2, 0, hp.n_orig_ctx, + Kcur = ggml_rope_ext( + ctx, Kcur, inp_pos, nullptr, hp.n_rot, 2, 0, hp.n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); From 3e5faa85032ec3106a2ad831bf412be9ff139f47 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 22 May 2024 11:01:35 +0300 Subject: [PATCH 05/15] cuda : fix rope + add tests (#7452) * cuda : fix rope pos data ggml-ci * ggml : drop mode & 1 == 1 support for ggml_rope ggml-ci * ggml : support freq_factors for f16 rope (CPU) ggml-ci * tests : add rope tests using frequency factors ggml-ci --- ggml-cuda/rope.cu | 4 ++-- ggml.c | 20 +++++++++++++++++-- ggml.h | 2 +- tests/test-backend-ops.cpp | 41 ++++++++++++++++++++++++-------------- 4 files changed, 47 insertions(+), 20 deletions(-) diff --git a/ggml-cuda/rope.cu b/ggml-cuda/rope.cu index 4a558f4b3757e..50f2cf415ef60 100644 --- a/ggml-cuda/rope.cu +++ b/ggml-cuda/rope.cu @@ -283,9 +283,9 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const bool is_neox = mode & 2; const bool is_glm = mode & 4; - if (is_neox) { - pos = (const int32_t *) src1_d; + pos = (const int32_t *) src1_d; + if (is_neox) { if (src2 != nullptr) { freq_factors = (const float *) src2->data; } diff --git a/ggml.c b/ggml.c index 37b16b7a9ce7f..d316e3d316806 100644 --- a/ggml.c +++ b/ggml.c @@ -6245,6 +6245,8 @@ static struct ggml_tensor * ggml_rope_impl( float xpos_base, bool xpos_down, bool inplace) { + GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported"); + GGML_ASSERT(ggml_is_vector(b)); GGML_ASSERT(b->type == GGML_TYPE_I32); GGML_ASSERT(a->ne[2] == b->ne[0]); @@ -14413,7 +14415,7 @@ static void ggml_compute_forward_rope_f32( freq_factors = (const float *) src2->data; } } else { - GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for mode 1"); + GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox"); } // backward process uses inverse rotation by cos and sin. @@ -14529,6 +14531,7 @@ static void ggml_compute_forward_rope_f32( } } +// TODO: deduplicate f16/f32 code static void ggml_compute_forward_rope_f16( const struct ggml_compute_params * params, struct ggml_tensor * dst, @@ -14536,6 +14539,7 @@ static void ggml_compute_forward_rope_f16( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; + const struct ggml_tensor * src2 = dst->src[2]; if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; @@ -14588,6 +14592,17 @@ static void ggml_compute_forward_rope_f16( const bool is_neox = mode & 2; const bool is_glm = mode & 4; + const float * freq_factors = NULL; + if (is_neox) { + if (src2 != NULL) { + GGML_ASSERT(src2->type == GGML_TYPE_F32); + GGML_ASSERT(src2->ne[0] >= n_dims / 2); + freq_factors = (const float *) src2->data; + } + } else { + GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox"); + } + // backward process uses inverse rotation by cos and sin. // cos and sin build a rotation matrix, where the inverse is the transpose. // this essentially just switches the sign of sin. @@ -14660,10 +14675,11 @@ static void ggml_compute_forward_rope_f16( // simplified from `(ib * n_dims + ic) * inv_ndims` float cur_rot = inv_ndims * ic - ib; + float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f; float cos_theta, sin_theta; rope_yarn( - theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, + theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta ); sin_theta *= sin_sign; diff --git a/ggml.h b/ggml.h index 35ac9110ceb17..08835042c0bfd 100644 --- a/ggml.h +++ b/ggml.h @@ -1460,7 +1460,7 @@ extern "C" { struct ggml_tensor * b); // rotary position embedding - // if mode & 1 == 1, skip n_past elements (DEPRECATED) + // if mode & 1 == 1, skip n_past elements (NOT SUPPORTED) // if mode & 2 == 1, GPT-NeoX style // if mode & 4 == 1, ChatGLM style // diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 1493a7ca7c405..de74585da29dd 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1142,20 +1142,22 @@ struct test_rope : public test_case { int n_dims; int mode; int n_ctx; + bool ff; std::string vars() override { - return VARS_TO_STR5(type, ne, n_dims, mode, n_ctx); + return VARS_TO_STR6(type, ne, n_dims, mode, n_ctx, ff); } test_rope(ggml_type type = GGML_TYPE_F32, std::array ne = {10, 10, 10, 1}, - int n_dims = 10, int mode = 0, int n_ctx = 512) - : type(type), ne(ne), n_dims(n_dims), mode(mode), n_ctx(n_ctx) {} + int n_dims = 10, int mode = 0, int n_ctx = 512, bool ff = false) + : type(type), ne(ne), n_dims(n_dims), mode(mode), n_ctx(n_ctx), ff(ff) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne[2]); - ggml_tensor * out = ggml_rope(ctx, a, pos, n_dims, mode, n_ctx); + ggml_tensor * freq = ff ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_dims/2) : nullptr; + ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f); return out; } @@ -1169,7 +1171,12 @@ struct test_rope : public test_case { } ggml_backend_tensor_set(t, data.data(), 0, ne[2] * sizeof(int)); } else { - init_tensor_uniform(t); + if (t->ne[0] == n_dims/2) { + // frequency factors in the range [0.9f, 1.1f] + init_tensor_uniform(t, 0.9f, 1.1f); + } else { + init_tensor_uniform(t); + } } } } @@ -2188,16 +2195,20 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 8.0f)); for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) { - test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512)); // llama 7B - test_cases.emplace_back(new test_rope(type, {128, 40, 10, 1}, 128, 0, 512)); // llama 13B - test_cases.emplace_back(new test_rope(type, {128, 52, 10, 1}, 128, 0, 512)); // llama 30B - test_cases.emplace_back(new test_rope(type, {128, 64, 10, 1}, 128, 0, 512)); // llama 65B - test_cases.emplace_back(new test_rope(type, { 64, 1, 10, 1}, 64, 2, 512)); // neox (falcon 7B) - test_cases.emplace_back(new test_rope(type, { 64, 71, 10, 1}, 64, 2, 512)); // neox (falcon 7B) - test_cases.emplace_back(new test_rope(type, { 64, 8, 10, 1}, 64, 2, 512)); // neox (falcon 40B) - test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1}, 64, 2, 512)); // neox (falcon 40B) - test_cases.emplace_back(new test_rope(type, { 80, 32, 10, 1}, 20, 2, 512)); // neox (stablelm) - test_cases.emplace_back(new test_rope(type, { 80, 32, 10, 1}, 32, 2, 512)); // neox (phi-2) + // TODO: ff not supported yet for !neox + test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512, false)); // llama 7B + test_cases.emplace_back(new test_rope(type, {128, 40, 10, 1}, 128, 0, 512, false)); // llama 13B + test_cases.emplace_back(new test_rope(type, {128, 52, 10, 1}, 128, 0, 512, false)); // llama 30B + test_cases.emplace_back(new test_rope(type, {128, 64, 10, 1}, 128, 0, 512, false)); // llama 65B + + for (bool ff : {false, true}) { // freq_factors + test_cases.emplace_back(new test_rope(type, { 64, 1, 10, 1}, 64, 2, 512, ff)); // neox (falcon 7B) + test_cases.emplace_back(new test_rope(type, { 64, 71, 10, 1}, 64, 2, 512, ff)); // neox (falcon 7B) + test_cases.emplace_back(new test_rope(type, { 64, 8, 10, 1}, 64, 2, 512, ff)); // neox (falcon 40B) + test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1}, 64, 2, 512, ff)); // neox (falcon 40B) + test_cases.emplace_back(new test_rope(type, { 80, 32, 10, 1}, 20, 2, 512, ff)); // neox (stablelm) + test_cases.emplace_back(new test_rope(type, { 80, 32, 10, 1}, 32, 2, 512, ff)); // neox (phi-2) + } } test_cases.emplace_back(new test_concat(GGML_TYPE_F32)); From 95fb0aefab568348da159efdd370e064d1b35f97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 22 May 2024 10:24:29 +0200 Subject: [PATCH 06/15] CUDA: remove incorrect precision check (#7454) --- ggml-cuda/fattn-tile-f32.cu | 3 --- 1 file changed, 3 deletions(-) diff --git a/ggml-cuda/fattn-tile-f32.cu b/ggml-cuda/fattn-tile-f32.cu index 130e7cbdbe10d..54db765e2f8ee 100644 --- a/ggml-cuda/fattn-tile-f32.cu +++ b/ggml-cuda/fattn-tile-f32.cu @@ -286,9 +286,6 @@ void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_ten const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; - const int32_t precision = KQV->op_params[2]; - GGML_ASSERT(precision == GGML_PREC_DEFAULT); - if (Q->ne[1] <= 16) { constexpr int cols_per_block = 16; constexpr int parallel_blocks = 4; From 9b3d83318931aa98c487baaa977626931d059e6a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 22 May 2024 12:36:37 +0300 Subject: [PATCH 07/15] cuda : fix compile warning (#7454) --- ggml-cuda/fattn-tile-f32.cu | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ggml-cuda/fattn-tile-f32.cu b/ggml-cuda/fattn-tile-f32.cu index 54db765e2f8ee..b8b2f69e19edb 100644 --- a/ggml-cuda/fattn-tile-f32.cu +++ b/ggml-cuda/fattn-tile-f32.cu @@ -283,8 +283,7 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * } void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * KQV = dst; - const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * Q = dst->src[0]; if (Q->ne[1] <= 16) { constexpr int cols_per_block = 16; From 03d8900ebe062355e26a562379daee5f17ea099f Mon Sep 17 00:00:00 2001 From: Justine Tunney Date: Wed, 22 May 2024 07:08:18 -0400 Subject: [PATCH 08/15] llama : add missing model type names (#7445) --- llama.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/llama.cpp b/llama.cpp index abff8c1c03e7a..d8c6f29a536aa 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3771,14 +3771,17 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { static const char * llama_model_type_name(e_model type) { switch (type) { + case MODEL_17M: return "17M"; case MODEL_22M: return "22M"; case MODEL_33M: return "33M"; case MODEL_109M: return "109M"; case MODEL_137M: return "137M"; + case MODEL_335M: return "335M"; case MODEL_0_5B: return "0.5B"; case MODEL_1B: return "1B"; case MODEL_2B: return "2B"; case MODEL_3B: return "3B"; + case MODEL_4B: return "4B"; case MODEL_7B: return "7B"; case MODEL_8B: return "8B"; case MODEL_12B: return "12B"; From fcda1128bc5f8eb7e1811708fe9d9867b9aec815 Mon Sep 17 00:00:00 2001 From: "k.h.lai" Date: Wed, 22 May 2024 20:53:21 +0800 Subject: [PATCH 09/15] vulkan: add workaround for iterator boundary check to fix clang-cl debug build (#7426) --- CMakeLists.txt | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 9cc60039a8416..c09d834fb010d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -505,6 +505,12 @@ if (LLAMA_VULKAN) add_compile_definitions(GGML_USE_VULKAN) + # Workaround to the "can't dereference invalidated vector iterator" bug in clang-cl debug build + # Posssibly relevant: https://stackoverflow.com/questions/74748276/visual-studio-no-displays-the-correct-length-of-stdvector + if (MSVC AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + add_compile_definitions(_ITERATOR_DEBUG_LEVEL=0) + endif() + if (LLAMA_VULKAN_CHECK_RESULTS) add_compile_definitions(GGML_VULKAN_CHECK_RESULTS) endif() From b18532a4efeca8796fea8e36195c81cbfd596a4a Mon Sep 17 00:00:00 2001 From: slaren Date: Wed, 22 May 2024 16:10:46 +0200 Subject: [PATCH 10/15] phi3 : duplicate rope factors in each layer (#7447) * phi3 : duplicate rope factors in each layer phi3 : set phi-3 model type as 14B model loader : simplify the process for duplicating model tensors llama-bench : remove default pg test * replace bool parameters in llama_model_loader with named flags --- examples/llama-bench/llama-bench.cpp | 2 +- llama.cpp | 178 ++++++++++++--------------- 2 files changed, 83 insertions(+), 97 deletions(-) diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 8b965e1990ba5..6bb1f70c3c8dc 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -195,7 +195,7 @@ static const cmd_params cmd_params_defaults = { /* model */ {"models/7B/ggml-model-q4_0.gguf"}, /* n_prompt */ {512}, /* n_gen */ {128}, - /* n_pg */ {{512, 128}}, + /* n_pg */ {}, /* n_batch */ {2048}, /* n_ubatch */ {512}, /* type_k */ {GGML_TYPE_F16}, diff --git a/llama.cpp b/llama.cpp index d8c6f29a536aa..34137c7ade6b2 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1940,6 +1940,10 @@ struct llama_layer { // mamba bias struct ggml_tensor * ssm_conv1d_b; struct ggml_tensor * ssm_dt_b; + + // long rope factors + struct ggml_tensor * rope_long = nullptr; + struct ggml_tensor * rope_short = nullptr; }; struct llama_kv_cell { @@ -2111,10 +2115,6 @@ struct llama_model { struct ggml_tensor * output; struct ggml_tensor * output_b; - // long rope factors - struct ggml_tensor * rope_long = nullptr; - struct ggml_tensor * rope_short = nullptr; - std::vector layers; llama_split_mode split_mode; @@ -3425,11 +3425,15 @@ struct llama_model_loader { return get_tensor_meta(get_tensor_name(i)); } - struct ggml_tensor * create_tensor_for(struct ggml_context * ctx, const struct ggml_tensor * cur) { + struct ggml_tensor * create_tensor_for(struct ggml_context * ctx, const struct ggml_tensor * cur, bool duplicated) { struct ggml_tensor * tensor = ggml_dup_tensor(ctx, cur); ggml_set_name(tensor, ggml_get_name(cur)); - n_created++; + if (duplicated) { + size_data += ggml_nbytes(cur); + } else { + n_created++; + } return tensor; } @@ -3464,14 +3468,17 @@ struct llama_model_loader { return cur; } - struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector & ne, bool required = true) { - const struct ggml_tensor * cur = check_tensor_dims(name, ne, required); + static const int TENSOR_NOT_REQUIRED = 1; + static const int TENSOR_DUPLICATED = 2; + + struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector & ne, int flags = 0) { + const struct ggml_tensor * cur = check_tensor_dims(name, ne, !(flags & TENSOR_NOT_REQUIRED)); if (cur == NULL) { return NULL; } - return create_tensor_for(ctx, cur); + return create_tensor_for(ctx, cur, flags & TENSOR_DUPLICATED); } struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::vector & ne, size_t offset, bool required = true) { @@ -4139,6 +4146,7 @@ static void llm_load_hparams( switch (hparams.n_layer) { case 24: model.type = e_model::MODEL_1B; break; case 32: model.type = e_model::MODEL_3B; break; + case 40: model.type = e_model::MODEL_14B; break; default: model.type = e_model::MODEL_UNKNOWN; } } break; @@ -4965,12 +4973,10 @@ static bool llm_load_tensors( { model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); if (model.arch != LLM_ARCH_MINICPM){ - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed if (model.output == NULL) { - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); - ml.n_created--; // artificial tensor - ml.size_data += ggml_nbytes(model.output); + model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } } } @@ -4989,10 +4995,10 @@ static bool llm_load_tensors( layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); // optional bias tensors - layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, false); - layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, false); - layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, false); - layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, false); + layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); @@ -5003,7 +5009,7 @@ static bool llm_load_tensors( } else { layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); - layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, false); + layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED); if (layer.ffn_gate_exps) { layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}); layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}); @@ -5045,12 +5051,10 @@ static bool llm_load_tensors( // output { model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed if (model.output == NULL) { - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); - ml.n_created--; // artificial tensor - ml.size_data += ggml_nbytes(model.output); + model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } } @@ -5073,7 +5077,7 @@ static bool llm_load_tensors( layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); - layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, false); + layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED); if (layer.ffn_gate_exps) { layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}); layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}); @@ -5175,11 +5179,9 @@ static bool llm_load_tensors( model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); if (!model.output) { - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // needs to be on GPU - ml.n_created--; // artificial tensor - ml.size_data += ggml_nbytes(model.output); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // needs to be on GPU } } @@ -5192,8 +5194,8 @@ static bool llm_load_tensors( layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); - layer.attn_norm_2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, false); - layer.attn_norm_2_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, false); + layer.attn_norm_2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_norm_2_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); @@ -5211,12 +5213,10 @@ static bool llm_load_tensors( { model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); if (!model.output) { // needs to be on GPU - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); - ml.n_created--; // artificial tensor - ml.size_data += ggml_nbytes(model.output); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } } @@ -5314,14 +5314,14 @@ static bool llm_load_tensors( layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); - layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, false); - layer.attn_q_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, false); + layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_q_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}); - layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, false); - layer.attn_k_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, false); + layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_k_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); @@ -5383,18 +5383,16 @@ static bool llm_load_tensors( case LLM_ARCH_MPT: { model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); - model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_ctx_train}, false); + model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_ctx_train}, llama_model_loader::TENSOR_NOT_REQUIRED); // output { model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, false); + model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); if (!model.output) { - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // needs to be on GPU - ml.n_created--; // artificial tensor - ml.size_data += ggml_nbytes(model.output); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // needs to be on GPU } } @@ -5405,31 +5403,31 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, false); + layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); - layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, false); + layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, false); + layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, false); + layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); - layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, false); + layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); - layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, false); + layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, false); - layer.attn_q_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, false); + layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_q_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, false); - layer.attn_k_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, false); + layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_k_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); // AWQ ScaleActivation layer - layer.ffn_act = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_ACT, "scales", i), {n_ff}, false); + layer.ffn_act = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_ACT, "scales", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); } } break; case LLM_ARCH_STABLELM: @@ -5458,17 +5456,17 @@ static bool llm_load_tensors( layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); // optional bias tensors, present in Stable LM 2 1.6B - layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, false); - layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, false); - layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, false); + layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); // optional q and k layernorms, present in StableLM 2 12B - layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head}, false); - layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head_kv}, false); + layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head_kv}, llama_model_loader::TENSOR_NOT_REQUIRED); // optional FFN norm, not present in StableLM 2 12B which uses parallel residual - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, false); - layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, false); + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); @@ -5511,12 +5509,10 @@ static bool llm_load_tensors( // output { model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed if (model.output == NULL) { - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); - ml.n_created--; // artificial tensor - ml.size_data += ggml_nbytes(model.output); + model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } } @@ -5614,8 +5610,8 @@ static bool llm_load_tensors( layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); - layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, false); - layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, false); + layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); if (layer.wqkv == nullptr) { layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); @@ -5642,9 +5638,6 @@ static bool llm_load_tensors( { model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }); - model.rope_long = ml.create_tensor(ctx_input, tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight"), { n_embd_head/2 }, false); - model.rope_short = ml.create_tensor(ctx_input, tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight"), { n_embd_head/2 }, false); - // output { model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }); @@ -5659,13 +5652,16 @@ static bool llm_load_tensors( layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }); - layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, false); + layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, llama_model_loader::TENSOR_NOT_REQUIRED); layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }); layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }); layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }); layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }); + + layer.rope_long = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight"), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + layer.rope_short = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight"), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); } } break; case LLM_ARCH_PLAMO: @@ -5834,9 +5830,7 @@ static bool llm_load_tensors( // output model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // same as tok_embd, duplicated to allow offloading - ml.n_created--; // artificial tensor - ml.size_data += ggml_nbytes(model.output); + model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading const int64_t n_ff = hparams.n_ff; const int64_t n_embd_head_k = hparams.n_embd_head_k; @@ -5871,12 +5865,10 @@ static bool llm_load_tensors( model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed if (model.output == NULL) { - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); - ml.n_created--; // artificial tensor - ml.size_data += ggml_nbytes(model.output); + model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } } @@ -5927,12 +5919,10 @@ static bool llm_load_tensors( { model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed, duplicated to allow offloading if (model.output == NULL) { - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); - ml.n_created--; // artificial tensor - ml.size_data += ggml_nbytes(model.output); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } } @@ -5993,9 +5983,7 @@ static bool llm_load_tensors( { model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); // init output from the input tok embed - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); - ml.n_created--; // artificial tensor - ml.size_data += ggml_nbytes(model.output); + model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } for (int i = 0; i < n_layer; ++i) { @@ -6027,12 +6015,10 @@ static bool llm_load_tensors( // output { - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed if (model.output == NULL) { - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); - ml.n_created--; // artificial tensor - ml.size_data += ggml_nbytes(model.output); + model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } } @@ -6875,9 +6861,9 @@ struct llm_build_context { cb(lctx.inp_K_shift, "K_shift", -1); ggml_set_input(lctx.inp_K_shift); - struct ggml_tensor * rope_factors = build_rope_factors(); for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * rope_factors = build_rope_factors(il); struct ggml_tensor * tmp = // we rotate only the first n_rot dimensions ggml_rope_ext_inplace(ctx0, @@ -6991,15 +6977,15 @@ struct llm_build_context { return lctx.inp_pos; } - struct ggml_tensor * build_rope_factors() { + struct ggml_tensor * build_rope_factors(int il) { // choose long/short freq factors based on the context size const auto n_ctx_pre_seq = cparams.n_ctx / cparams.n_seq_max; if (n_ctx_pre_seq > hparams.n_yarn_orig_ctx) { - return model.rope_long; + return model.layers[il].rope_long; } - return model.rope_short; + return model.layers[il].rope_short; } struct ggml_tensor * build_inp_out_ids() { @@ -9120,14 +9106,14 @@ struct llm_build_context { // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); - // rope freq factors for 128k context - struct ggml_tensor * rope_factors = build_rope_factors(); - for (int il = 0; il < n_layer; ++il) { auto residual = inpL; // self-attention { + // rope freq factors for 128k context + struct ggml_tensor * rope_factors = build_rope_factors(il); + struct ggml_tensor* attn_norm_output = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, From 38c03478a37e460ecd3a21155b338a83bfed7f90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 22 May 2024 17:58:25 +0200 Subject: [PATCH 11/15] CUDA: fix FA out-of-bounds writes (#7465) --- ggml-cuda/fattn-tile-f16.cu | 4 ++++ ggml-cuda/fattn-tile-f32.cu | 4 ++++ ggml-cuda/fattn-vec-f16.cu | 6 +++++- ggml-cuda/fattn-vec-f32.cu | 6 +++++- 4 files changed, 18 insertions(+), 2 deletions(-) diff --git a/ggml-cuda/fattn-tile-f16.cu b/ggml-cuda/fattn-tile-f16.cu index 4a07ac6adad71..586d469c049d1 100644 --- a/ggml-cuda/fattn-tile-f16.cu +++ b/ggml-cuda/fattn-tile-f16.cu @@ -238,6 +238,10 @@ static __global__ void flash_attn_tile_ext_f16( for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) { const int j_VKQ = j_VKQ_0 + threadIdx.y; + if (ic0 + j_VKQ >= ne01) { + return; + } + half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]); kqsum_j = warp_reduce_sum(kqsum_j); diff --git a/ggml-cuda/fattn-tile-f32.cu b/ggml-cuda/fattn-tile-f32.cu index b8b2f69e19edb..b6ef8eb48d992 100644 --- a/ggml-cuda/fattn-tile-f32.cu +++ b/ggml-cuda/fattn-tile-f32.cu @@ -237,6 +237,10 @@ static __global__ void flash_attn_tile_ext_f32( for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) { const int j_VKQ = j_VKQ_0 + threadIdx.y; + if (ic0 + j_VKQ >= ne01) { + return; + } + float kqsum_j = kqsum[j_VKQ_0/nwarps]; kqsum_j = warp_reduce_sum(kqsum_j); diff --git a/ggml-cuda/fattn-vec-f16.cu b/ggml-cuda/fattn-vec-f16.cu index 54e1ac5d16050..7352dcabf6291 100644 --- a/ggml-cuda/fattn-vec-f16.cu +++ b/ggml-cuda/fattn-vec-f16.cu @@ -212,6 +212,10 @@ static __global__ void flash_attn_vec_ext_f16( #pragma unroll for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) { + if (ic0 + j_VKQ >= ne01) { + break; + } + kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x]; kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]); @@ -223,7 +227,7 @@ static __global__ void flash_attn_vec_ext_f16( dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val; } - if (parallel_blocks != 1 && tid < ncols) { + if (parallel_blocks != 1 && tid < ncols && ic0 + tid < ne01) { dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]); } #else diff --git a/ggml-cuda/fattn-vec-f32.cu b/ggml-cuda/fattn-vec-f32.cu index 5bcabd0928451..11476a6c0fbbc 100644 --- a/ggml-cuda/fattn-vec-f32.cu +++ b/ggml-cuda/fattn-vec-f32.cu @@ -200,6 +200,10 @@ static __global__ void flash_attn_vec_ext_f32( #pragma unroll for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) { + if (ic0 + j_VKQ >= ne01) { + break; + } + kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x]; kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]); @@ -211,7 +215,7 @@ static __global__ void flash_attn_vec_ext_f32( dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val; } - if (parallel_blocks != 1 && tid < ncols) { + if (parallel_blocks != 1 && tid < ncols && ic0 + tid < ne01) { dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]); } } From 6ff13987ad1a9519bee13dd98b6a21cd98979aab Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 22 May 2024 20:04:20 +0300 Subject: [PATCH 12/15] common : normalize naming style (#7462) * common : normalize naming style ggml-ci * common : match declaration / definition order * zig : try to fix build --- build.zig | 12 +- common/common.cpp | 1488 +++++++++++----------- common/common.h | 88 +- common/sampling.cpp | 83 +- common/sampling.h | 5 + common/train.cpp | 2 +- examples/batched/batched.cpp | 2 +- examples/embedding/embedding.cpp | 4 +- examples/eval-callback/eval-callback.cpp | 4 +- examples/imatrix/imatrix.cpp | 4 +- examples/infill/infill.cpp | 16 +- examples/llama-bench/llama-bench.cpp | 2 +- examples/llava/llava-cli.cpp | 2 +- examples/lookahead/lookahead.cpp | 2 +- examples/lookup/lookup.cpp | 2 +- examples/main/main.cpp | 16 +- examples/parallel/parallel.cpp | 2 +- examples/perplexity/perplexity.cpp | 14 +- examples/quantize/quantize.cpp | 2 +- examples/retrieval/retrieval.cpp | 4 +- examples/server/server.cpp | 10 +- 21 files changed, 897 insertions(+), 867 deletions(-) diff --git a/build.zig b/build.zig index 96783574fe740..267c976b14d1a 100644 --- a/build.zig +++ b/build.zig @@ -129,14 +129,14 @@ pub fn build(b: *std.build.Builder) !void { const clip = make.obj("clip", "examples/llava/clip.cpp"); const llava = make.obj("llava", "examples/llava/llava.cpp"); - _ = make.exe("main", "examples/main/main.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, sampling, console, grammar_parser }); - _ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo }); - _ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo }); - _ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo }); - _ = make.exe("finetune", "examples/finetune/finetune.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, train }); + _ = make.exe("main", "examples/main/main.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, sampling, json_schema_to_grammar, buildinfo, console, grammar_parser }); + _ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, sampling, json_schema_to_grammar, buildinfo }); + _ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, sampling, json_schema_to_grammar, buildinfo }); + _ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, sampling, json_schema_to_grammar, buildinfo }); + _ = make.exe("finetune", "examples/finetune/finetune.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, sampling, json_schema_to_grammar, buildinfo, train }); _ = make.exe("train-text-from-scratch", "examples/train-text-from-scratch/train-text-from-scratch.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, train }); - const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, sampling, grammar_parser, clip, llava }); + const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, sampling, json_schema_to_grammar, buildinfo, grammar_parser, clip, llava }); if (server.target.isWindows()) { server.linkSystemLibrary("ws2_32"); } diff --git a/common/common.cpp b/common/common.cpp index ae11650b446a4..7500e08ff1be4 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -73,7 +73,11 @@ using json = nlohmann::ordered_json; -int32_t get_num_physical_cores() { +// +// CPU utils +// + +int32_t cpu_get_num_physical_cores() { #ifdef __linux__ // enumerate the set of thread siblings, num entries is num cores std::unordered_set siblings; @@ -142,9 +146,9 @@ static bool is_running_on_efficiency_core(void) { return core_type == intel_atom; } -static int count_math_cpus(int cpu_count) { +static int cpu_count_math_cpus(int n_cpu) { int result = 0; - for (int cpu = 0; cpu < cpu_count; ++cpu) { + for (int cpu = 0; cpu < n_cpu; ++cpu) { if (pin_cpu(cpu)) { return -1; } @@ -162,16 +166,16 @@ static int count_math_cpus(int cpu_count) { /** * Returns number of CPUs on system that are useful for math. */ -int get_math_cpu_count() { +int32_t cpu_get_num_math() { #if defined(__x86_64__) && defined(__linux__) && !defined(__ANDROID__) - int cpu_count = sysconf(_SC_NPROCESSORS_ONLN); - if (cpu_count < 1) { - return get_num_physical_cores(); + int n_cpu = sysconf(_SC_NPROCESSORS_ONLN); + if (n_cpu < 1) { + return cpu_get_num_physical_cores(); } if (is_hybrid_cpu()) { cpu_set_t affinity; if (!pthread_getaffinity_np(pthread_self(), sizeof(affinity), &affinity)) { - int result = count_math_cpus(cpu_count); + int result = cpu_count_math_cpus(n_cpu); pthread_setaffinity_np(pthread_self(), sizeof(affinity), &affinity); if (result > 0) { return result; @@ -179,108 +183,103 @@ int get_math_cpu_count() { } } #endif - return get_num_physical_cores(); + return cpu_get_num_physical_cores(); } -void process_escapes(std::string & input) { - std::size_t input_len = input.length(); - std::size_t output_idx = 0; +// +// CLI argument parsing +// - for (std::size_t input_idx = 0; input_idx < input_len; ++input_idx) { - if (input[input_idx] == '\\' && input_idx + 1 < input_len) { - switch (input[++input_idx]) { - case 'n': input[output_idx++] = '\n'; break; - case 'r': input[output_idx++] = '\r'; break; - case 't': input[output_idx++] = '\t'; break; - case '\'': input[output_idx++] = '\''; break; - case '\"': input[output_idx++] = '\"'; break; - case '\\': input[output_idx++] = '\\'; break; - case 'x': - // Handle \x12, etc - if (input_idx + 2 < input_len) { - const char x[3] = { input[input_idx + 1], input[input_idx + 2], 0 }; - char *err_p = nullptr; - const long val = std::strtol(x, &err_p, 16); - if (err_p == x + 2) { - input_idx += 2; - input[output_idx++] = char(val); - break; - } - } - // fall through - default: input[output_idx++] = '\\'; - input[output_idx++] = input[input_idx]; break; +void gpt_params_handle_model_default(gpt_params & params) { + if (!params.hf_repo.empty()) { + // short-hand to avoid specifying --hf-file -> default it to --model + if (params.hf_file.empty()) { + if (params.model.empty()) { + throw std::invalid_argument("error: --hf-repo requires either --hf-file or --model\n"); } - } else { - input[output_idx++] = input[input_idx]; + params.hf_file = params.model; + } else if (params.model.empty()) { + std::string cache_directory = fs_get_cache_directory(); + const bool success = fs_create_directory_with_parents(cache_directory); + if (!success) { + throw std::runtime_error("failed to create cache directory: " + cache_directory); + } + params.model = cache_directory + string_split(params.hf_file, '/').back(); + } + } else if (!params.model_url.empty()) { + if (params.model.empty()) { + auto f = string_split(params.model_url, '#').front(); + f = string_split(f, '?').front(); + f = string_split(f, '/').back(); + params.model = "models/" + f; } + } else if (params.model.empty()) { + params.model = DEFAULT_MODEL_PATH; } +} - input.resize(output_idx); +bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { + bool invalid_param = false; + std::string arg; + const std::string arg_prefix = "--"; + llama_sampling_params & sparams = params.sparams; + + for (int i = 1; i < argc; i++) { + arg = argv[i]; + if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) { + std::replace(arg.begin(), arg.end(), '_', '-'); + } + if (!gpt_params_find_arg(argc, argv, arg, params, i, invalid_param)) { + throw std::invalid_argument("error: unknown argument: " + arg); + } + if (invalid_param) { + throw std::invalid_argument("error: invalid parameter for argument: " + arg); + } + } + + if (params.prompt_cache_all && + (params.interactive || params.interactive_first || + params.instruct)) { + + throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n"); + } + + gpt_params_handle_model_default(params); + + if (params.escape) { + string_process_escapes(params.prompt); + string_process_escapes(params.input_prefix); + string_process_escapes(params.input_suffix); + string_process_escapes(sparams.cfg_negative_prompt); + for (auto & antiprompt : params.antiprompt) { + string_process_escapes(antiprompt); + } + } + + if (!params.kv_overrides.empty()) { + params.kv_overrides.emplace_back(); + params.kv_overrides.back().key[0] = 0; + } + + return true; } bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { bool result = true; try { if (!gpt_params_parse_ex(argc, argv, params)) { - gpt_print_usage(argc, argv, gpt_params()); + gpt_params_print_usage(argc, argv, gpt_params()); exit(0); } } catch (const std::invalid_argument & ex) { fprintf(stderr, "%s\n", ex.what()); - gpt_print_usage(argc, argv, gpt_params()); + gpt_params_print_usage(argc, argv, gpt_params()); exit(1); } return result; } -bool parse_kv_override(const char * data, std::vector & overrides) { - const char * sep = strchr(data, '='); - if (sep == nullptr || sep - data >= 128) { - fprintf(stderr, "%s: malformed KV override '%s'\n", __func__, data); - return false; - } - llama_model_kv_override kvo; - std::strncpy(kvo.key, data, sep - data); - kvo.key[sep - data] = 0; - sep++; - if (strncmp(sep, "int:", 4) == 0) { - sep += 4; - kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT; - kvo.val_i64 = std::atol(sep); - } else if (strncmp(sep, "float:", 6) == 0) { - sep += 6; - kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT; - kvo.val_f64 = std::atof(sep); - } else if (strncmp(sep, "bool:", 5) == 0) { - sep += 5; - kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL; - if (std::strcmp(sep, "true") == 0) { - kvo.val_bool = true; - } else if (std::strcmp(sep, "false") == 0) { - kvo.val_bool = false; - } else { - fprintf(stderr, "%s: invalid boolean value for KV override '%s'\n", __func__, data); - return false; - } - } else if (strncmp(sep, "str:", 4) == 0) { - sep += 4; - kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR; - if (strlen(sep) > 127) { - fprintf(stderr, "%s: malformed KV override '%s', value cannot exceed 127 chars\n", __func__, data); - return false; - } - strncpy(kvo.val_str, sep, 127); - kvo.val_str[127] = '\0'; - } else { - fprintf(stderr, "%s: invalid type for KV override '%s'\n", __func__, data); - return false; - } - overrides.emplace_back(std::move(kvo)); - return true; -} - bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) { llama_sampling_params & sparams = params.sparams; @@ -546,7 +545,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } const auto sampler_names = string_split(argv[i], ';'); - sparams.samplers_sequence = sampler_types_from_names(sampler_names, true); + sparams.samplers_sequence = llama_sampling_types_from_names(sampler_names, true); return true; } if (arg == "--sampling-seq") { @@ -554,7 +553,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa invalid_param = true; return true; } - sparams.samplers_sequence = sampler_types_from_chars(argv[i]); + sparams.samplers_sequence = llama_sampling_types_from_chars(argv[i]); return true; } if (arg == "--top-p") { @@ -1240,7 +1239,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-h" || arg == "--help") { - gpt_print_usage(argc, argv, gpt_params()); + gpt_params_print_usage(argc, argv, gpt_params()); exit(0); } if (arg == "--version") { @@ -1311,7 +1310,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa invalid_param = true; return true; } - if (!parse_kv_override(argv[i], params.kv_overrides)) { + if (!string_parse_kv_override(argv[i], params.kv_overrides)) { fprintf(stderr, "error: Invalid type for KV override: %s\n", argv[i]); invalid_param = true; return true; @@ -1345,88 +1344,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return false; } -void gpt_params_handle_model_default(gpt_params & params) { - if (!params.hf_repo.empty()) { - // short-hand to avoid specifying --hf-file -> default it to --model - if (params.hf_file.empty()) { - if (params.model.empty()) { - throw std::invalid_argument("error: --hf-repo requires either --hf-file or --model\n"); - } - params.hf_file = params.model; - } else if (params.model.empty()) { - std::string cache_directory = get_cache_directory(); - const bool success = create_directory_with_parents(cache_directory); - if (!success) { - throw std::runtime_error("failed to create cache directory: " + cache_directory); - } - params.model = cache_directory + string_split(params.hf_file, '/').back(); - } - } else if (!params.model_url.empty()) { - if (params.model.empty()) { - auto f = string_split(params.model_url, '#').front(); - f = string_split(f, '?').front(); - f = string_split(f, '/').back(); - params.model = "models/" + f; - } - } else if (params.model.empty()) { - params.model = DEFAULT_MODEL_PATH; - } -} - -bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { - bool invalid_param = false; - std::string arg; - const std::string arg_prefix = "--"; - llama_sampling_params & sparams = params.sparams; - - for (int i = 1; i < argc; i++) { - arg = argv[i]; - if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) { - std::replace(arg.begin(), arg.end(), '_', '-'); - } - if (!gpt_params_find_arg(argc, argv, arg, params, i, invalid_param)) { - throw std::invalid_argument("error: unknown argument: " + arg); - } - if (invalid_param) { - throw std::invalid_argument("error: invalid parameter for argument: " + arg); - } - } - - if (params.prompt_cache_all && - (params.interactive || params.interactive_first || - params.instruct)) { - - throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n"); - } - - gpt_params_handle_model_default(params); - - if (params.escape) { - process_escapes(params.prompt); - process_escapes(params.input_prefix); - process_escapes(params.input_suffix); - process_escapes(sparams.cfg_negative_prompt); - for (auto & antiprompt : params.antiprompt) { - process_escapes(antiprompt); - } - } - - if (!params.kv_overrides.empty()) { - params.kv_overrides.emplace_back(); - params.kv_overrides.back().key[0] = 0; - } - - return true; -} - -void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { +void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { const llama_sampling_params & sparams = params.sparams; std::string sampler_type_chars; std::string sampler_type_names; for (const auto sampler_type : sparams.samplers_sequence) { sampler_type_chars += static_cast(sampler_type); - sampler_type_names += sampler_type_to_name_string(sampler_type) + ";"; + sampler_type_names += llama_sampling_type_to_str(sampler_type) + ";"; } sampler_type_names.pop_back(); @@ -1623,7 +1548,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { #endif // LOG_DISABLE_LOGS } -std::string get_system_info(const gpt_params & params) { +std::string gpt_params_get_system_info(const gpt_params & params) { std::ostringstream os; os << "system_info: n_threads = " << params.n_threads; @@ -1635,7 +1560,52 @@ std::string get_system_info(const gpt_params & params) { return os.str(); } -std::string gpt_random_prompt(std::mt19937 & rng) { +// +// String utils +// + +std::vector string_split(std::string input, char separator) { + std::vector parts; + size_t separator_pos = input.find(separator); + while (separator_pos != std::string::npos) { + std::string part = input.substr(0, separator_pos); + parts.emplace_back(part); + input = input.substr(separator_pos + 1); + separator_pos = input.find(separator); + } + parts.emplace_back(input); + return parts; +} + +std::string string_strip(const std::string & str) { + size_t start = 0; + size_t end = str.size(); + while (start < end && std::isspace(str[start])) { + start++; + } + while (end > start && std::isspace(str[end - 1])) { + end--; + } + return str.substr(start, end - start); +} + +std::string string_get_sortable_timestamp() { + using clock = std::chrono::system_clock; + + const clock::time_point current_time = clock::now(); + const time_t as_time_t = clock::to_time_t(current_time); + char timestamp_no_ns[100]; + std::strftime(timestamp_no_ns, 100, "%Y_%m_%d-%H_%M_%S", std::localtime(&as_time_t)); + + const int64_t ns = std::chrono::duration_cast( + current_time.time_since_epoch() % 1000000000).count(); + char timestamp_ns[11]; + snprintf(timestamp_ns, 11, "%09" PRId64, ns); + + return std::string(timestamp_no_ns) + "." + std::string(timestamp_ns); +} + +std::string string_random_prompt(std::mt19937 & rng) { const int r = rng() % 10; switch (r) { case 0: return "So"; @@ -1653,12 +1623,99 @@ std::string gpt_random_prompt(std::mt19937 & rng) { GGML_UNREACHABLE(); } -// Validate if a filename is safe to use -// To validate a full path, split the path by the OS-specific path separator, and validate each part with this function -bool validate_file_name(const std::string & filename) { - if (!filename.length()) { - // Empty filename invalid - return false; +void string_process_escapes(std::string & input) { + std::size_t input_len = input.length(); + std::size_t output_idx = 0; + + for (std::size_t input_idx = 0; input_idx < input_len; ++input_idx) { + if (input[input_idx] == '\\' && input_idx + 1 < input_len) { + switch (input[++input_idx]) { + case 'n': input[output_idx++] = '\n'; break; + case 'r': input[output_idx++] = '\r'; break; + case 't': input[output_idx++] = '\t'; break; + case '\'': input[output_idx++] = '\''; break; + case '\"': input[output_idx++] = '\"'; break; + case '\\': input[output_idx++] = '\\'; break; + case 'x': + // Handle \x12, etc + if (input_idx + 2 < input_len) { + const char x[3] = { input[input_idx + 1], input[input_idx + 2], 0 }; + char *err_p = nullptr; + const long val = std::strtol(x, &err_p, 16); + if (err_p == x + 2) { + input_idx += 2; + input[output_idx++] = char(val); + break; + } + } + // fall through + default: input[output_idx++] = '\\'; + input[output_idx++] = input[input_idx]; break; + } + } else { + input[output_idx++] = input[input_idx]; + } + } + + input.resize(output_idx); +} + +bool string_parse_kv_override(const char * data, std::vector & overrides) { + const char * sep = strchr(data, '='); + if (sep == nullptr || sep - data >= 128) { + fprintf(stderr, "%s: malformed KV override '%s'\n", __func__, data); + return false; + } + llama_model_kv_override kvo; + std::strncpy(kvo.key, data, sep - data); + kvo.key[sep - data] = 0; + sep++; + if (strncmp(sep, "int:", 4) == 0) { + sep += 4; + kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT; + kvo.val_i64 = std::atol(sep); + } else if (strncmp(sep, "float:", 6) == 0) { + sep += 6; + kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT; + kvo.val_f64 = std::atof(sep); + } else if (strncmp(sep, "bool:", 5) == 0) { + sep += 5; + kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL; + if (std::strcmp(sep, "true") == 0) { + kvo.val_bool = true; + } else if (std::strcmp(sep, "false") == 0) { + kvo.val_bool = false; + } else { + fprintf(stderr, "%s: invalid boolean value for KV override '%s'\n", __func__, data); + return false; + } + } else if (strncmp(sep, "str:", 4) == 0) { + sep += 4; + kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR; + if (strlen(sep) > 127) { + fprintf(stderr, "%s: malformed KV override '%s', value cannot exceed 127 chars\n", __func__, data); + return false; + } + strncpy(kvo.val_str, sep, 127); + kvo.val_str[127] = '\0'; + } else { + fprintf(stderr, "%s: invalid type for KV override '%s'\n", __func__, data); + return false; + } + overrides.emplace_back(std::move(kvo)); + return true; +} + +// +// Filesystem utils +// + +// Validate if a filename is safe to use +// To validate a full path, split the path by the OS-specific path separator, and validate each part with this function +bool fs_validate_filename(const std::string & filename) { + if (!filename.length()) { + // Empty filename invalid + return false; } if (filename.length() > 255) { // Limit at common largest possible filename on Linux filesystems @@ -1724,171 +1781,245 @@ bool validate_file_name(const std::string & filename) { return true; } -// -// String utils -// +// returns true if successful, false otherwise +bool fs_create_directory_with_parents(const std::string & path) { +#ifdef _WIN32 + std::wstring_convert> converter; + std::wstring wpath = converter.from_bytes(path); -std::vector string_split(std::string input, char separator) { - std::vector parts; - size_t separator_pos = input.find(separator); - while (separator_pos != std::string::npos) { - std::string part = input.substr(0, separator_pos); - parts.emplace_back(part); - input = input.substr(separator_pos + 1); - separator_pos = input.find(separator); + // if the path already exists, check whether it's a directory + const DWORD attributes = GetFileAttributesW(wpath.c_str()); + if ((attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_DIRECTORY)) { + return true; } - parts.emplace_back(input); - return parts; -} -std::string string_strip(const std::string & str) { - size_t start = 0; - size_t end = str.size(); - while (start < end && std::isspace(str[start])) { - start++; - } - while (end > start && std::isspace(str[end - 1])) { - end--; - } - return str.substr(start, end - start); -} + size_t pos_slash = 0; -std::vector sampler_types_from_names(const std::vector & names, bool allow_alt_names) { - std::unordered_map sampler_canonical_name_map { - {"top_k", llama_sampler_type::TOP_K}, - {"top_p", llama_sampler_type::TOP_P}, - {"typical_p", llama_sampler_type::TYPICAL_P}, - {"min_p", llama_sampler_type::MIN_P}, - {"tfs_z", llama_sampler_type::TFS_Z}, - {"temperature", llama_sampler_type::TEMPERATURE} - }; + // process path from front to back, procedurally creating directories + while ((pos_slash = path.find('\\', pos_slash)) != std::string::npos) { + const std::wstring subpath = wpath.substr(0, pos_slash); + const wchar_t * test = subpath.c_str(); - // since samplers names are written multiple ways - // make it ready for both system names and input names - std::unordered_map sampler_alt_name_map { - {"top-k", llama_sampler_type::TOP_K}, - {"top-p", llama_sampler_type::TOP_P}, - {"nucleus", llama_sampler_type::TOP_P}, - {"typical-p", llama_sampler_type::TYPICAL_P}, - {"typical", llama_sampler_type::TYPICAL_P}, - {"min-p", llama_sampler_type::MIN_P}, - {"tfs-z", llama_sampler_type::TFS_Z}, - {"tfs", llama_sampler_type::TFS_Z}, - {"temp", llama_sampler_type::TEMPERATURE} - }; + const bool success = CreateDirectoryW(test, NULL); + if (!success) { + const DWORD error = GetLastError(); - std::vector sampler_types; - sampler_types.reserve(names.size()); - for (const auto & name : names) - { - auto sampler_item = sampler_canonical_name_map.find(name); - if (sampler_item != sampler_canonical_name_map.end()) - { - sampler_types.push_back(sampler_item->second); - } - else - { - if (allow_alt_names) - { - sampler_item = sampler_alt_name_map.find(name); - if (sampler_item != sampler_alt_name_map.end()) - { - sampler_types.push_back(sampler_item->second); + // if the path already exists, ensure that it's a directory + if (error == ERROR_ALREADY_EXISTS) { + const DWORD attributes = GetFileAttributesW(subpath.c_str()); + if (attributes == INVALID_FILE_ATTRIBUTES || !(attributes & FILE_ATTRIBUTE_DIRECTORY)) { + return false; } + } else { + return false; } } + + pos_slash += 1; } - return sampler_types; -} -std::vector sampler_types_from_chars(const std::string & names_string) { - std::unordered_map sampler_name_map { - {'k', llama_sampler_type::TOP_K}, - {'p', llama_sampler_type::TOP_P}, - {'y', llama_sampler_type::TYPICAL_P}, - {'m', llama_sampler_type::MIN_P}, - {'f', llama_sampler_type::TFS_Z}, - {'t', llama_sampler_type::TEMPERATURE} - }; + return true; +#else + // if the path already exists, check whether it's a directory + struct stat info; + if (stat(path.c_str(), &info) == 0) { + return S_ISDIR(info.st_mode); + } + + size_t pos_slash = 1; // skip leading slashes for directory creation + + // process path from front to back, procedurally creating directories + while ((pos_slash = path.find('/', pos_slash)) != std::string::npos) { + const std::string subpath = path.substr(0, pos_slash); + struct stat info; - std::vector sampler_types; - sampler_types.reserve(names_string.size()); - for (const auto & c : names_string) { - const auto sampler_item = sampler_name_map.find(c); - if (sampler_item != sampler_name_map.end()) { - sampler_types.push_back(sampler_item->second); + // if the path already exists, ensure that it's a directory + if (stat(subpath.c_str(), &info) == 0) { + if (!S_ISDIR(info.st_mode)) { + return false; + } + } else { + // create parent directories + const int ret = mkdir(subpath.c_str(), 0755); + if (ret != 0) { + return false; + } } + + pos_slash += 1; } - return sampler_types; + + return true; +#endif // _WIN32 } -std::string sampler_type_to_name_string(llama_sampler_type sampler_type) { - switch (sampler_type) { - case llama_sampler_type::TOP_K: return "top_k"; - case llama_sampler_type::TFS_Z: return "tfs_z"; - case llama_sampler_type::TYPICAL_P: return "typical_p"; - case llama_sampler_type::TOP_P: return "top_p"; - case llama_sampler_type::MIN_P: return "min_p"; - case llama_sampler_type::TEMPERATURE: return "temperature"; - default : return ""; +std::string fs_get_cache_directory() { + std::string cache_directory = ""; + if (getenv("LLAMA_CACHE")) { + cache_directory = std::getenv("LLAMA_CACHE"); + if (cache_directory.back() != DIRECTORY_SEPARATOR) { + cache_directory += DIRECTORY_SEPARATOR; + } + } else { +#ifdef __linux__ + if (std::getenv("XDG_CACHE_HOME")) { + cache_directory = std::getenv("XDG_CACHE_HOME"); + } else { + cache_directory = std::getenv("HOME") + std::string("/.cache/"); + } +#elif defined(__APPLE__) + cache_directory = std::getenv("HOME") + std::string("/Library/Caches/"); +#elif defined(_WIN32) + cache_directory = std::getenv("APPDATA"); +#endif // __linux__ + cache_directory += "llama.cpp"; + cache_directory += DIRECTORY_SEPARATOR; } + return cache_directory; } + // // Model utils // -struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params) { - auto mparams = llama_model_default_params(); +std::tuple llama_init_from_gpt_params(gpt_params & params) { + auto mparams = llama_model_params_from_gpt_params(params); - if (params.n_gpu_layers != -1) { - mparams.n_gpu_layers = params.n_gpu_layers; - } - mparams.rpc_servers = params.rpc_servers.c_str(); - mparams.main_gpu = params.main_gpu; - mparams.split_mode = params.split_mode; - mparams.tensor_split = params.tensor_split; - mparams.use_mmap = params.use_mmap; - mparams.use_mlock = params.use_mlock; - mparams.check_tensors = params.check_tensors; - if (params.kv_overrides.empty()) { - mparams.kv_overrides = NULL; + llama_model * model = nullptr; + + if (!params.hf_repo.empty() && !params.hf_file.empty()) { + model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), mparams); + } else if (!params.model_url.empty()) { + model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), mparams); } else { - GGML_ASSERT(params.kv_overrides.back().key[0] == 0 && "KV overrides not terminated with empty key"); - mparams.kv_overrides = params.kv_overrides.data(); + model = llama_load_model_from_file(params.model.c_str(), mparams); } - return mparams; -} - -static ggml_type kv_cache_type_from_str(const std::string & s) { - if (s == "f32") { - return GGML_TYPE_F32; - } - if (s == "f16") { - return GGML_TYPE_F16; - } - if (s == "q8_0") { - return GGML_TYPE_Q8_0; - } - if (s == "q4_0") { - return GGML_TYPE_Q4_0; - } - if (s == "q4_1") { - return GGML_TYPE_Q4_1; - } - if (s == "iq4_nl") { - return GGML_TYPE_IQ4_NL; - } - if (s == "q5_0") { - return GGML_TYPE_Q5_0; - } - if (s == "q5_1") { - return GGML_TYPE_Q5_1; + if (model == NULL) { + fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str()); + return std::make_tuple(nullptr, nullptr); } - throw std::runtime_error("Invalid cache type: " + s); -} + auto cparams = llama_context_params_from_gpt_params(params); + + llama_context * lctx = llama_new_context_with_model(model, cparams); + if (lctx == NULL) { + fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str()); + llama_free_model(model); + return std::make_tuple(nullptr, nullptr); + } + + if (!params.control_vectors.empty()) { + if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1; + if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_n_layer(model); + + const auto cvec = llama_control_vector_load(params.control_vectors); + if (cvec.n_embd == -1) { + llama_free(lctx); + llama_free_model(model); + return std::make_tuple(nullptr, nullptr); + } + + int err = llama_control_vector_apply(lctx, + cvec.data.data(), + cvec.data.size(), + cvec.n_embd, + params.control_vector_layer_start, + params.control_vector_layer_end); + if (err) { + llama_free(lctx); + llama_free_model(model); + return std::make_tuple(nullptr, nullptr); + } + } + + for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) { + const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]); + float lora_scale = std::get<1>(params.lora_adapter[i]); + int err = llama_model_apply_lora_from_file(model, + lora_adapter.c_str(), + lora_scale, + ((i > 0) || params.lora_base.empty()) + ? NULL + : params.lora_base.c_str(), + params.n_threads); + if (err != 0) { + fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__); + llama_free(lctx); + llama_free_model(model); + return std::make_tuple(nullptr, nullptr); + } + } + + if (params.ignore_eos) { + params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY; + } + + if (params.warmup) { + LOG("warming up the model with an empty run\n"); + + std::vector tmp = { llama_token_bos(model), llama_token_eos(model), }; + llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); + llama_kv_cache_clear(lctx); + llama_synchronize(lctx); + llama_reset_timings(lctx); + } + + return std::make_tuple(model, lctx); +} + +struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params) { + auto mparams = llama_model_default_params(); + + if (params.n_gpu_layers != -1) { + mparams.n_gpu_layers = params.n_gpu_layers; + } + mparams.rpc_servers = params.rpc_servers.c_str(); + mparams.main_gpu = params.main_gpu; + mparams.split_mode = params.split_mode; + mparams.tensor_split = params.tensor_split; + mparams.use_mmap = params.use_mmap; + mparams.use_mlock = params.use_mlock; + mparams.check_tensors = params.check_tensors; + if (params.kv_overrides.empty()) { + mparams.kv_overrides = NULL; + } else { + GGML_ASSERT(params.kv_overrides.back().key[0] == 0 && "KV overrides not terminated with empty key"); + mparams.kv_overrides = params.kv_overrides.data(); + } + + return mparams; +} + +static ggml_type kv_cache_type_from_str(const std::string & s) { + if (s == "f32") { + return GGML_TYPE_F32; + } + if (s == "f16") { + return GGML_TYPE_F16; + } + if (s == "q8_0") { + return GGML_TYPE_Q8_0; + } + if (s == "q4_0") { + return GGML_TYPE_Q4_0; + } + if (s == "q4_1") { + return GGML_TYPE_Q4_1; + } + if (s == "iq4_nl") { + return GGML_TYPE_IQ4_NL; + } + if (s == "q5_0") { + return GGML_TYPE_Q5_0; + } + if (s == "q5_1") { + return GGML_TYPE_Q5_1; + } + + throw std::runtime_error("Invalid cache type: " + s); +} struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) { auto cparams = llama_context_default_params(); @@ -1923,27 +2054,6 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param return cparams; } -void llama_batch_clear(struct llama_batch & batch) { - batch.n_tokens = 0; -} - -void llama_batch_add( - struct llama_batch & batch, - llama_token id, - llama_pos pos, - const std::vector & seq_ids, - bool logits) { - batch.token [batch.n_tokens] = id; - batch.pos [batch.n_tokens] = pos; - batch.n_seq_id[batch.n_tokens] = seq_ids.size(); - for (size_t i = 0; i < seq_ids.size(); ++i) { - batch.seq_id[batch.n_tokens][i] = seq_ids[i]; - } - batch.logits [batch.n_tokens] = logits; - - batch.n_tokens++; -} - #ifdef LLAMA_USE_CURL static bool starts_with(const std::string & str, const std::string & prefix) { @@ -2274,90 +2384,29 @@ struct llama_model * llama_load_model_from_hf( #endif // LLAMA_USE_CURL -std::tuple llama_init_from_gpt_params(gpt_params & params) { - auto mparams = llama_model_params_from_gpt_params(params); - - llama_model * model = nullptr; - - if (!params.hf_repo.empty() && !params.hf_file.empty()) { - model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), mparams); - } else if (!params.model_url.empty()) { - model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), mparams); - } else { - model = llama_load_model_from_file(params.model.c_str(), mparams); - } - - if (model == NULL) { - fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str()); - return std::make_tuple(nullptr, nullptr); - } - - auto cparams = llama_context_params_from_gpt_params(params); - - llama_context * lctx = llama_new_context_with_model(model, cparams); - if (lctx == NULL) { - fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str()); - llama_free_model(model); - return std::make_tuple(nullptr, nullptr); - } - - if (!params.control_vectors.empty()) { - if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1; - if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_n_layer(model); - - const auto cvec = llama_control_vector_load(params.control_vectors); - if (cvec.n_embd == -1) { - llama_free(lctx); - llama_free_model(model); - return std::make_tuple(nullptr, nullptr); - } - - int err = llama_control_vector_apply(lctx, - cvec.data.data(), - cvec.data.size(), - cvec.n_embd, - params.control_vector_layer_start, - params.control_vector_layer_end); - if (err) { - llama_free(lctx); - llama_free_model(model); - return std::make_tuple(nullptr, nullptr); - } - } - - for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) { - const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]); - float lora_scale = std::get<1>(params.lora_adapter[i]); - int err = llama_model_apply_lora_from_file(model, - lora_adapter.c_str(), - lora_scale, - ((i > 0) || params.lora_base.empty()) - ? NULL - : params.lora_base.c_str(), - params.n_threads); - if (err != 0) { - fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__); - llama_free(lctx); - llama_free_model(model); - return std::make_tuple(nullptr, nullptr); - } - } - - if (params.ignore_eos) { - params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY; - } +// +// Batch utils +// - if (params.warmup) { - LOG("warming up the model with an empty run\n"); +void llama_batch_clear(struct llama_batch & batch) { + batch.n_tokens = 0; +} - std::vector tmp = { llama_token_bos(model), llama_token_eos(model), }; - llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); - llama_kv_cache_clear(lctx); - llama_synchronize(lctx); - llama_reset_timings(lctx); +void llama_batch_add( + struct llama_batch & batch, + llama_token id, + llama_pos pos, + const std::vector & seq_ids, + bool logits) { + batch.token [batch.n_tokens] = id; + batch.pos [batch.n_tokens] = pos; + batch.n_seq_id[batch.n_tokens] = seq_ids.size(); + for (size_t i = 0; i < seq_ids.size(); ++i) { + batch.seq_id[batch.n_tokens][i] = seq_ids[i]; } + batch.logits [batch.n_tokens] = logits; - return std::make_tuple(model, lctx); + batch.n_tokens++; } // @@ -2412,379 +2461,44 @@ std::string llama_detokenize_spm(llama_context * ctx, const std::vector & tokens) { - std::string piece; - std::string result; - - for (size_t i = 0; i < tokens.size(); ++i) { - piece = llama_token_to_piece(ctx, tokens[i]); - - result += piece; - } - - // NOTE: the original tokenizer decodes bytes after collecting the pieces. - return result; -} - -bool llama_should_add_bos_token(const llama_model * model) { - const int add_bos = llama_add_bos_token(model); - - return add_bos != -1 ? bool(add_bos) : (llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM); -} - -// -// YAML utils -// - -// returns true if successful, false otherwise -bool create_directory_with_parents(const std::string & path) { -#ifdef _WIN32 - std::wstring_convert> converter; - std::wstring wpath = converter.from_bytes(path); - - // if the path already exists, check whether it's a directory - const DWORD attributes = GetFileAttributesW(wpath.c_str()); - if ((attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_DIRECTORY)) { - return true; - } - - size_t pos_slash = 0; - - // process path from front to back, procedurally creating directories - while ((pos_slash = path.find('\\', pos_slash)) != std::string::npos) { - const std::wstring subpath = wpath.substr(0, pos_slash); - const wchar_t * test = subpath.c_str(); - - const bool success = CreateDirectoryW(test, NULL); - if (!success) { - const DWORD error = GetLastError(); - - // if the path already exists, ensure that it's a directory - if (error == ERROR_ALREADY_EXISTS) { - const DWORD attributes = GetFileAttributesW(subpath.c_str()); - if (attributes == INVALID_FILE_ATTRIBUTES || !(attributes & FILE_ATTRIBUTE_DIRECTORY)) { - return false; - } - } else { - return false; - } - } - - pos_slash += 1; - } - - return true; -#else - // if the path already exists, check whether it's a directory - struct stat info; - if (stat(path.c_str(), &info) == 0) { - return S_ISDIR(info.st_mode); - } - - size_t pos_slash = 1; // skip leading slashes for directory creation - - // process path from front to back, procedurally creating directories - while ((pos_slash = path.find('/', pos_slash)) != std::string::npos) { - const std::string subpath = path.substr(0, pos_slash); - struct stat info; - - // if the path already exists, ensure that it's a directory - if (stat(subpath.c_str(), &info) == 0) { - if (!S_ISDIR(info.st_mode)) { - return false; - } - } else { - // create parent directories - const int ret = mkdir(subpath.c_str(), 0755); - if (ret != 0) { - return false; - } - } - - pos_slash += 1; - } - - return true; -#endif // _WIN32 -} - -std::string get_cache_directory() { - std::string cache_directory = ""; - if (getenv("LLAMA_CACHE")) { - cache_directory = std::getenv("LLAMA_CACHE"); - if (cache_directory.back() != DIRECTORY_SEPARATOR) { - cache_directory += DIRECTORY_SEPARATOR; - } - } else { -#ifdef __linux__ - if (std::getenv("XDG_CACHE_HOME")) { - cache_directory = std::getenv("XDG_CACHE_HOME"); - } else { - cache_directory = std::getenv("HOME") + std::string("/.cache/"); - } -#elif defined(__APPLE__) - cache_directory = std::getenv("HOME") + std::string("/Library/Caches/"); -#elif defined(_WIN32) - cache_directory = std::getenv("APPDATA"); -#endif // __linux__ - cache_directory += "llama.cpp"; - cache_directory += DIRECTORY_SEPARATOR; - } - return cache_directory; -} - -void dump_vector_float_yaml(FILE * stream, const char * prop_name, const std::vector & data) { - if (data.empty()) { - fprintf(stream, "%s:\n", prop_name); - return; - } - - fprintf(stream, "%s: [", prop_name); - for (size_t i = 0; i < data.size() - 1; ++i) { - fprintf(stream, "%e, ", data[i]); - } - fprintf(stream, "%e]\n", data.back()); -} - -void dump_vector_int_yaml(FILE * stream, const char * prop_name, const std::vector & data) { - if (data.empty()) { - fprintf(stream, "%s:\n", prop_name); - return; - } - - fprintf(stream, "%s: [", prop_name); - for (size_t i = 0; i < data.size() - 1; ++i) { - fprintf(stream, "%d, ", data[i]); - } - fprintf(stream, "%d]\n", data.back()); -} - -void dump_string_yaml_multiline(FILE * stream, const char * prop_name, const char * data) { - std::string data_str(data == NULL ? "" : data); - - if (data_str.empty()) { - fprintf(stream, "%s:\n", prop_name); - return; - } - - size_t pos_start = 0; - size_t pos_found = 0; - - if (std::isspace(data_str[0]) || std::isspace(data_str.back())) { - data_str = std::regex_replace(data_str, std::regex("\n"), "\\n"); - data_str = std::regex_replace(data_str, std::regex("\""), "\\\""); - data_str = std::regex_replace(data_str, std::regex(R"(\\[^n"])"), R"(\$&)"); - data_str = "\"" + data_str + "\""; - fprintf(stream, "%s: %s\n", prop_name, data_str.c_str()); - return; - } - - if (data_str.find('\n') == std::string::npos) { - fprintf(stream, "%s: %s\n", prop_name, data_str.c_str()); - return; - } - - fprintf(stream, "%s: |\n", prop_name); - while ((pos_found = data_str.find('\n', pos_start)) != std::string::npos) { - fprintf(stream, " %s\n", data_str.substr(pos_start, pos_found-pos_start).c_str()); - pos_start = pos_found + 1; - } -} - -std::string get_sortable_timestamp() { - using clock = std::chrono::system_clock; - - const clock::time_point current_time = clock::now(); - const time_t as_time_t = clock::to_time_t(current_time); - char timestamp_no_ns[100]; - std::strftime(timestamp_no_ns, 100, "%Y_%m_%d-%H_%M_%S", std::localtime(&as_time_t)); - - const int64_t ns = std::chrono::duration_cast( - current_time.time_since_epoch() % 1000000000).count(); - char timestamp_ns[11]; - snprintf(timestamp_ns, 11, "%09" PRId64, ns); - - return std::string(timestamp_no_ns) + "." + std::string(timestamp_ns); -} - -void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const llama_context * lctx, - const std::string & timestamp, const std::vector & prompt_tokens, const char * model_desc) { - const llama_sampling_params & sparams = params.sparams; - - fprintf(stream, "build_commit: %s\n", LLAMA_COMMIT); - fprintf(stream, "build_number: %d\n", LLAMA_BUILD_NUMBER); - fprintf(stream, "cpu_has_arm_fma: %s\n", ggml_cpu_has_arm_fma() ? "true" : "false"); - fprintf(stream, "cpu_has_avx: %s\n", ggml_cpu_has_avx() ? "true" : "false"); - fprintf(stream, "cpu_has_avx_vnni: %s\n", ggml_cpu_has_avx_vnni() ? "true" : "false"); - fprintf(stream, "cpu_has_avx2: %s\n", ggml_cpu_has_avx2() ? "true" : "false"); - fprintf(stream, "cpu_has_avx512: %s\n", ggml_cpu_has_avx512() ? "true" : "false"); - fprintf(stream, "cpu_has_avx512_vbmi: %s\n", ggml_cpu_has_avx512_vbmi() ? "true" : "false"); - fprintf(stream, "cpu_has_avx512_vnni: %s\n", ggml_cpu_has_avx512_vnni() ? "true" : "false"); - fprintf(stream, "cpu_has_cuda: %s\n", ggml_cpu_has_cuda() ? "true" : "false"); - fprintf(stream, "cpu_has_vulkan: %s\n", ggml_cpu_has_vulkan() ? "true" : "false"); - fprintf(stream, "cpu_has_clblast: %s\n", ggml_cpu_has_clblast() ? "true" : "false"); - fprintf(stream, "cpu_has_kompute: %s\n", ggml_cpu_has_kompute() ? "true" : "false"); - fprintf(stream, "cpu_has_fma: %s\n", ggml_cpu_has_fma() ? "true" : "false"); - fprintf(stream, "cpu_has_gpublas: %s\n", ggml_cpu_has_gpublas() ? "true" : "false"); - fprintf(stream, "cpu_has_neon: %s\n", ggml_cpu_has_neon() ? "true" : "false"); - fprintf(stream, "cpu_has_f16c: %s\n", ggml_cpu_has_f16c() ? "true" : "false"); - fprintf(stream, "cpu_has_fp16_va: %s\n", ggml_cpu_has_fp16_va() ? "true" : "false"); - fprintf(stream, "cpu_has_wasm_simd: %s\n", ggml_cpu_has_wasm_simd() ? "true" : "false"); - fprintf(stream, "cpu_has_blas: %s\n", ggml_cpu_has_blas() ? "true" : "false"); - fprintf(stream, "cpu_has_sse3: %s\n", ggml_cpu_has_sse3() ? "true" : "false"); - fprintf(stream, "cpu_has_vsx: %s\n", ggml_cpu_has_vsx() ? "true" : "false"); - fprintf(stream, "cpu_has_matmul_int8: %s\n", ggml_cpu_has_matmul_int8() ? "true" : "false"); - -#ifdef NDEBUG - fprintf(stream, "debug: false\n"); -#else - fprintf(stream, "debug: true\n"); -#endif // NDEBUG - - fprintf(stream, "model_desc: %s\n", model_desc); - fprintf(stream, "n_vocab: %d # output size of the final layer, 32001 for some models\n", llama_n_vocab(llama_get_model(lctx))); - -#ifdef __OPTIMIZE__ - fprintf(stream, "optimize: true\n"); -#else - fprintf(stream, "optimize: false\n"); -#endif // __OPTIMIZE__ - - fprintf(stream, "time: %s\n", timestamp.c_str()); - - fprintf(stream, "\n"); - fprintf(stream, "###############\n"); - fprintf(stream, "# User Inputs #\n"); - fprintf(stream, "###############\n"); - fprintf(stream, "\n"); - - fprintf(stream, "alias: %s # default: unknown\n", params.model_alias.c_str()); - fprintf(stream, "batch_size: %d # default: 512\n", params.n_batch); - dump_string_yaml_multiline(stream, "cfg_negative_prompt", sparams.cfg_negative_prompt.c_str()); - fprintf(stream, "cfg_scale: %f # default: 1.0\n", sparams.cfg_scale); - fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks); - fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false"); - fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx); - fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false"); - fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n"); - fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq); - dump_string_yaml_multiline(stream, "grammar", sparams.grammar.c_str()); - fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n"); - fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false"); - fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks); - - const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(llama_get_model(lctx))); - const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY; - fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false"); - - dump_string_yaml_multiline(stream, "in_prefix", params.input_prefix.c_str()); - fprintf(stream, "in_prefix_bos: %s # default: false\n", params.input_prefix_bos ? "true" : "false"); - dump_string_yaml_multiline(stream, "in_suffix", params.input_prefix.c_str()); - fprintf(stream, "instruct: %s # default: false\n", params.instruct ? "true" : "false"); - fprintf(stream, "interactive: %s # default: false\n", params.interactive ? "true" : "false"); - fprintf(stream, "interactive_specials: %s # default: false\n", params.interactive_specials ? "true" : "false"); - fprintf(stream, "interactive_first: %s # default: false\n", params.interactive_first ? "true" : "false"); - fprintf(stream, "keep: %d # default: 0\n", params.n_keep); - fprintf(stream, "logdir: %s # default: unset (no logging)\n", params.logdir.c_str()); - - fprintf(stream, "logit_bias:\n"); - for (std::pair lb : sparams.logit_bias) { - if (ignore_eos && lb.first == logit_bias_eos->first) { - continue; - } - fprintf(stream, " %d: %f", lb.first, lb.second); - } - - fprintf(stream, "lora:\n"); - for (std::tuple la : params.lora_adapter) { - if (std::get<1>(la) != 1.0f) { - continue; - } - fprintf(stream, " - %s\n", std::get<0>(la).c_str()); - } - fprintf(stream, "lora_scaled:\n"); - for (std::tuple la : params.lora_adapter) { - if (std::get<1>(la) == 1.0f) { - continue; - } - fprintf(stream, " - %s: %f\n", std::get<0>(la).c_str(), std::get<1>(la)); - } - fprintf(stream, "lora_base: %s\n", params.lora_base.c_str()); - fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu); - fprintf(stream, "min_keep: %d # default: 0 (disabled)\n", sparams.min_keep); - fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat); - fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau); - fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta); - fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false"); - fprintf(stream, "model: %s # default: %s\n", params.model.c_str(), DEFAULT_MODEL_PATH); - fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str()); - fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false"); - fprintf(stream, "n_gpu_layers: %d # default: -1\n", params.n_gpu_layers); - fprintf(stream, "n_predict: %d # default: -1 (unlimited)\n", params.n_predict); - fprintf(stream, "n_probs: %d # only used by server binary, default: 0\n", sparams.n_probs); - fprintf(stream, "no_mmap: %s # default: false\n", !params.use_mmap ? "true" : "false"); - fprintf(stream, "penalize_nl: %s # default: false\n", sparams.penalize_nl ? "true" : "false"); - fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type); - fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride); - fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.penalty_present); - dump_string_yaml_multiline(stream, "prompt", params.prompt.c_str()); - fprintf(stream, "prompt_cache: %s\n", params.path_prompt_cache.c_str()); - fprintf(stream, "prompt_cache_all: %s # default: false\n", params.prompt_cache_all ? "true" : "false"); - fprintf(stream, "prompt_cache_ro: %s # default: false\n", params.prompt_cache_ro ? "true" : "false"); - dump_vector_int_yaml(stream, "prompt_tokens", prompt_tokens); - fprintf(stream, "random_prompt: %s # default: false\n", params.random_prompt ? "true" : "false"); - fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.penalty_repeat); - - fprintf(stream, "reverse_prompt:\n"); - for (std::string ap : params.antiprompt) { - size_t pos = 0; - while ((pos = ap.find('\n', pos)) != std::string::npos) { - ap.replace(pos, 1, "\\n"); - pos += 1; + piece = llama_token_to_piece(ctx, tokens[i]); + + // remove the leading space of the first non-BOS token + if (((tokens[0] == bos_id && i == 1) || (tokens[0] != bos_id && i == 0)) && piece[0] == ' ') { + piece = piece.substr(1); } - fprintf(stream, " - %s\n", ap.c_str()); + result += piece; } - fprintf(stream, "rope_freq_base: %f # default: 10000.0\n", params.rope_freq_base); - fprintf(stream, "rope_freq_scale: %f # default: 1.0\n", params.rope_freq_scale); - fprintf(stream, "seed: %u # default: -1 (random seed)\n", params.seed); - fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false"); - fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); - fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false"); - fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); + return result; +} - const std::vector tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices()); - dump_vector_float_yaml(stream, "tensor_split", tensor_split_vector); +std::string llama_detokenize_bpe(llama_context * ctx, const std::vector & tokens) { + std::string piece; + std::string result; - fprintf(stream, "tfs: %f # default: 1.0\n", sparams.tfs_z); - fprintf(stream, "threads: %d # default: %u\n", params.n_threads, std::thread::hardware_concurrency()); - fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k); - fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p); - fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p); - fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p); - fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false"); - fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false"); + for (size_t i = 0; i < tokens.size(); ++i) { + piece = llama_token_to_piece(ctx, tokens[i]); + + result += piece; + } + + // NOTE: the original tokenizer decodes bytes after collecting the pieces. + return result; +} + +bool llama_should_add_bos_token(const llama_model * model) { + const int add_bos = llama_add_bos_token(model); + + return add_bos != -1 ? bool(add_bos) : (llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM); } // // KV cache utils // -void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size) { +void llama_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size) { static const char slot_chars[] = ".123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz+"; printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d", @@ -2807,7 +2521,7 @@ void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size) { printf("\n=== Done dumping\n"); } -void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) { +void llama_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size) { static const char slot_chars[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d\n", @@ -2855,6 +2569,10 @@ void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) { printf("\n=== Done dumping\n"); } +// +// Embedding utils +// + void llama_embd_normalize(const float * inp, float * out, int n) { double sum = 0.0; for (int i = 0; i < n; i++) { @@ -3039,3 +2757,225 @@ llama_control_vector_data llama_control_vector_load(const std::vector & data) { + if (data.empty()) { + fprintf(stream, "%s:\n", prop_name); + return; + } + + fprintf(stream, "%s: [", prop_name); + for (size_t i = 0; i < data.size() - 1; ++i) { + fprintf(stream, "%e, ", data[i]); + } + fprintf(stream, "%e]\n", data.back()); +} + +void yaml_dump_vector_int(FILE * stream, const char * prop_name, const std::vector & data) { + if (data.empty()) { + fprintf(stream, "%s:\n", prop_name); + return; + } + + fprintf(stream, "%s: [", prop_name); + for (size_t i = 0; i < data.size() - 1; ++i) { + fprintf(stream, "%d, ", data[i]); + } + fprintf(stream, "%d]\n", data.back()); +} + +void yaml_dump_string_multiline(FILE * stream, const char * prop_name, const char * data) { + std::string data_str(data == NULL ? "" : data); + + if (data_str.empty()) { + fprintf(stream, "%s:\n", prop_name); + return; + } + + size_t pos_start = 0; + size_t pos_found = 0; + + if (std::isspace(data_str[0]) || std::isspace(data_str.back())) { + data_str = std::regex_replace(data_str, std::regex("\n"), "\\n"); + data_str = std::regex_replace(data_str, std::regex("\""), "\\\""); + data_str = std::regex_replace(data_str, std::regex(R"(\\[^n"])"), R"(\$&)"); + data_str = "\"" + data_str + "\""; + fprintf(stream, "%s: %s\n", prop_name, data_str.c_str()); + return; + } + + if (data_str.find('\n') == std::string::npos) { + fprintf(stream, "%s: %s\n", prop_name, data_str.c_str()); + return; + } + + fprintf(stream, "%s: |\n", prop_name); + while ((pos_found = data_str.find('\n', pos_start)) != std::string::npos) { + fprintf(stream, " %s\n", data_str.substr(pos_start, pos_found-pos_start).c_str()); + pos_start = pos_found + 1; + } +} + +void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const llama_context * lctx, + const std::string & timestamp, const std::vector & prompt_tokens, const char * model_desc) { + const llama_sampling_params & sparams = params.sparams; + + fprintf(stream, "build_commit: %s\n", LLAMA_COMMIT); + fprintf(stream, "build_number: %d\n", LLAMA_BUILD_NUMBER); + fprintf(stream, "cpu_has_arm_fma: %s\n", ggml_cpu_has_arm_fma() ? "true" : "false"); + fprintf(stream, "cpu_has_avx: %s\n", ggml_cpu_has_avx() ? "true" : "false"); + fprintf(stream, "cpu_has_avx_vnni: %s\n", ggml_cpu_has_avx_vnni() ? "true" : "false"); + fprintf(stream, "cpu_has_avx2: %s\n", ggml_cpu_has_avx2() ? "true" : "false"); + fprintf(stream, "cpu_has_avx512: %s\n", ggml_cpu_has_avx512() ? "true" : "false"); + fprintf(stream, "cpu_has_avx512_vbmi: %s\n", ggml_cpu_has_avx512_vbmi() ? "true" : "false"); + fprintf(stream, "cpu_has_avx512_vnni: %s\n", ggml_cpu_has_avx512_vnni() ? "true" : "false"); + fprintf(stream, "cpu_has_cuda: %s\n", ggml_cpu_has_cuda() ? "true" : "false"); + fprintf(stream, "cpu_has_vulkan: %s\n", ggml_cpu_has_vulkan() ? "true" : "false"); + fprintf(stream, "cpu_has_clblast: %s\n", ggml_cpu_has_clblast() ? "true" : "false"); + fprintf(stream, "cpu_has_kompute: %s\n", ggml_cpu_has_kompute() ? "true" : "false"); + fprintf(stream, "cpu_has_fma: %s\n", ggml_cpu_has_fma() ? "true" : "false"); + fprintf(stream, "cpu_has_gpublas: %s\n", ggml_cpu_has_gpublas() ? "true" : "false"); + fprintf(stream, "cpu_has_neon: %s\n", ggml_cpu_has_neon() ? "true" : "false"); + fprintf(stream, "cpu_has_f16c: %s\n", ggml_cpu_has_f16c() ? "true" : "false"); + fprintf(stream, "cpu_has_fp16_va: %s\n", ggml_cpu_has_fp16_va() ? "true" : "false"); + fprintf(stream, "cpu_has_wasm_simd: %s\n", ggml_cpu_has_wasm_simd() ? "true" : "false"); + fprintf(stream, "cpu_has_blas: %s\n", ggml_cpu_has_blas() ? "true" : "false"); + fprintf(stream, "cpu_has_sse3: %s\n", ggml_cpu_has_sse3() ? "true" : "false"); + fprintf(stream, "cpu_has_vsx: %s\n", ggml_cpu_has_vsx() ? "true" : "false"); + fprintf(stream, "cpu_has_matmul_int8: %s\n", ggml_cpu_has_matmul_int8() ? "true" : "false"); + +#ifdef NDEBUG + fprintf(stream, "debug: false\n"); +#else + fprintf(stream, "debug: true\n"); +#endif // NDEBUG + + fprintf(stream, "model_desc: %s\n", model_desc); + fprintf(stream, "n_vocab: %d # output size of the final layer, 32001 for some models\n", llama_n_vocab(llama_get_model(lctx))); + +#ifdef __OPTIMIZE__ + fprintf(stream, "optimize: true\n"); +#else + fprintf(stream, "optimize: false\n"); +#endif // __OPTIMIZE__ + + fprintf(stream, "time: %s\n", timestamp.c_str()); + + fprintf(stream, "\n"); + fprintf(stream, "###############\n"); + fprintf(stream, "# User Inputs #\n"); + fprintf(stream, "###############\n"); + fprintf(stream, "\n"); + + fprintf(stream, "alias: %s # default: unknown\n", params.model_alias.c_str()); + fprintf(stream, "batch_size: %d # default: 512\n", params.n_batch); + yaml_dump_string_multiline(stream, "cfg_negative_prompt", sparams.cfg_negative_prompt.c_str()); + fprintf(stream, "cfg_scale: %f # default: 1.0\n", sparams.cfg_scale); + fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks); + fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false"); + fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx); + fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false"); + fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n"); + fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq); + yaml_dump_string_multiline(stream, "grammar", sparams.grammar.c_str()); + fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n"); + fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false"); + fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks); + + const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(llama_get_model(lctx))); + const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY; + fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false"); + + yaml_dump_string_multiline(stream, "in_prefix", params.input_prefix.c_str()); + fprintf(stream, "in_prefix_bos: %s # default: false\n", params.input_prefix_bos ? "true" : "false"); + yaml_dump_string_multiline(stream, "in_suffix", params.input_prefix.c_str()); + fprintf(stream, "instruct: %s # default: false\n", params.instruct ? "true" : "false"); + fprintf(stream, "interactive: %s # default: false\n", params.interactive ? "true" : "false"); + fprintf(stream, "interactive_specials: %s # default: false\n", params.interactive_specials ? "true" : "false"); + fprintf(stream, "interactive_first: %s # default: false\n", params.interactive_first ? "true" : "false"); + fprintf(stream, "keep: %d # default: 0\n", params.n_keep); + fprintf(stream, "logdir: %s # default: unset (no logging)\n", params.logdir.c_str()); + + fprintf(stream, "logit_bias:\n"); + for (std::pair lb : sparams.logit_bias) { + if (ignore_eos && lb.first == logit_bias_eos->first) { + continue; + } + fprintf(stream, " %d: %f", lb.first, lb.second); + } + + fprintf(stream, "lora:\n"); + for (std::tuple la : params.lora_adapter) { + if (std::get<1>(la) != 1.0f) { + continue; + } + fprintf(stream, " - %s\n", std::get<0>(la).c_str()); + } + fprintf(stream, "lora_scaled:\n"); + for (std::tuple la : params.lora_adapter) { + if (std::get<1>(la) == 1.0f) { + continue; + } + fprintf(stream, " - %s: %f\n", std::get<0>(la).c_str(), std::get<1>(la)); + } + fprintf(stream, "lora_base: %s\n", params.lora_base.c_str()); + fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu); + fprintf(stream, "min_keep: %d # default: 0 (disabled)\n", sparams.min_keep); + fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat); + fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau); + fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta); + fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false"); + fprintf(stream, "model: %s # default: %s\n", params.model.c_str(), DEFAULT_MODEL_PATH); + fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str()); + fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false"); + fprintf(stream, "n_gpu_layers: %d # default: -1\n", params.n_gpu_layers); + fprintf(stream, "n_predict: %d # default: -1 (unlimited)\n", params.n_predict); + fprintf(stream, "n_probs: %d # only used by server binary, default: 0\n", sparams.n_probs); + fprintf(stream, "no_mmap: %s # default: false\n", !params.use_mmap ? "true" : "false"); + fprintf(stream, "penalize_nl: %s # default: false\n", sparams.penalize_nl ? "true" : "false"); + fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type); + fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride); + fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.penalty_present); + yaml_dump_string_multiline(stream, "prompt", params.prompt.c_str()); + fprintf(stream, "prompt_cache: %s\n", params.path_prompt_cache.c_str()); + fprintf(stream, "prompt_cache_all: %s # default: false\n", params.prompt_cache_all ? "true" : "false"); + fprintf(stream, "prompt_cache_ro: %s # default: false\n", params.prompt_cache_ro ? "true" : "false"); + yaml_dump_vector_int(stream, "prompt_tokens", prompt_tokens); + fprintf(stream, "random_prompt: %s # default: false\n", params.random_prompt ? "true" : "false"); + fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.penalty_repeat); + + fprintf(stream, "reverse_prompt:\n"); + for (std::string ap : params.antiprompt) { + size_t pos = 0; + while ((pos = ap.find('\n', pos)) != std::string::npos) { + ap.replace(pos, 1, "\\n"); + pos += 1; + } + + fprintf(stream, " - %s\n", ap.c_str()); + } + + fprintf(stream, "rope_freq_base: %f # default: 10000.0\n", params.rope_freq_base); + fprintf(stream, "rope_freq_scale: %f # default: 1.0\n", params.rope_freq_scale); + fprintf(stream, "seed: %u # default: -1 (random seed)\n", params.seed); + fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false"); + fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); + fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false"); + fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); + + const std::vector tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices()); + yaml_dump_vector_float(stream, "tensor_split", tensor_split_vector); + + fprintf(stream, "tfs: %f # default: 1.0\n", sparams.tfs_z); + fprintf(stream, "threads: %d # default: %u\n", params.n_threads, std::thread::hardware_concurrency()); + fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k); + fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p); + fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p); + fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p); + fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false"); + fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false"); +} diff --git a/common/common.h b/common/common.h index a8e5e50e6b810..f68f3c2979b94 100644 --- a/common/common.h +++ b/common/common.h @@ -27,7 +27,7 @@ #define die_fmt(fmt, ...) do { fprintf(stderr, "error: " fmt "\n", __VA_ARGS__); exit(1); } while (0) #define print_build_info() do { \ - fprintf(stderr, "%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT); \ + fprintf(stderr, "%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT); \ fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); \ } while(0) @@ -35,14 +35,18 @@ // build info extern int LLAMA_BUILD_NUMBER; -extern char const *LLAMA_COMMIT; -extern char const *LLAMA_COMPILER; -extern char const *LLAMA_BUILD_TARGET; +extern char const * LLAMA_COMMIT; +extern char const * LLAMA_COMPILER; +extern char const * LLAMA_BUILD_TARGET; struct llama_control_vector_load_info; -int get_math_cpu_count(); -int32_t get_num_physical_cores(); +// +// CPU utils +// + +int32_t cpu_get_num_physical_cores(); +int32_t cpu_get_num_math(); // // CLI argument parsing @@ -51,7 +55,7 @@ int32_t get_num_physical_cores(); struct gpt_params { uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed - int32_t n_threads = get_math_cpu_count(); + int32_t n_threads = cpu_get_num_math(); int32_t n_threads_draft = -1; int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads) int32_t n_threads_batch_draft = -1; @@ -179,33 +183,34 @@ struct gpt_params { void gpt_params_handle_model_default(gpt_params & params); -bool parse_kv_override(const char * data, std::vector & overrides); - -bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params); - -bool gpt_params_parse(int argc, char ** argv, gpt_params & params); +bool gpt_params_parse_ex (int argc, char ** argv, gpt_params & params); +bool gpt_params_parse (int argc, char ** argv, gpt_params & params); +bool gpt_params_find_arg (int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param); +void gpt_params_print_usage(int argc, char ** argv, const gpt_params & params); -void gpt_print_usage(int argc, char ** argv, const gpt_params & params); +std::string gpt_params_get_system_info(const gpt_params & params); -bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param); - -std::string get_system_info(const gpt_params & params); +// +// String utils +// -std::string gpt_random_prompt(std::mt19937 & rng); +std::vector string_split(std::string input, char separator); -void process_escapes(std::string& input); +std::string string_strip(const std::string & str); +std::string string_get_sortable_timestamp(); +std::string string_random_prompt(std::mt19937 & rng); -bool validate_file_name(const std::string & filename); +bool string_parse_kv_override(const char * data, std::vector & overrides); +void string_process_escapes(std::string & input); // -// String utils +// Filesystem utils // -std::vector sampler_types_from_names(const std::vector & names, bool allow_alt_names); -std::vector sampler_types_from_chars(const std::string & names_string); -std::vector string_split(std::string input, char separator); -std::string string_strip(const std::string & str); -std::string sampler_type_to_name_string(llama_sampler_type sampler_type); +bool fs_validate_filename(const std::string & filename); +bool fs_create_directory_with_parents(const std::string & path); + +std::string fs_get_cache_directory(); // // Model utils @@ -276,30 +281,15 @@ std::string llama_detokenize_bpe( // defaults to true when model type is SPM, otherwise false. bool llama_should_add_bos_token(const llama_model * model); -// -// YAML utils -// - -bool create_directory_with_parents(const std::string & path); -std::string get_cache_directory(); -void dump_vector_float_yaml(FILE * stream, const char * prop_name, const std::vector & data); -void dump_vector_int_yaml(FILE * stream, const char * prop_name, const std::vector & data); -void dump_string_yaml_multiline(FILE * stream, const char * prop_name, const char * data); -std::string get_sortable_timestamp(); - -void dump_non_result_info_yaml( - FILE * stream, const gpt_params & params, const llama_context * lctx, - const std::string & timestamp, const std::vector & prompt_tokens, const char * model_desc); - // // KV cache utils // // Dump the KV cache view with the number of sequences per cell. -void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80); +void llama_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size = 80); // Dump the KV cache view showing individual sequences in each cell (long output). -void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40); +void llama_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size = 40); // // Embedding utils @@ -333,6 +323,20 @@ llama_control_vector_data llama_control_vector_load(const std::vector & data); +void yaml_dump_vector_int (FILE * stream, const char * prop_name, const std::vector & data); +void yaml_dump_string_multiline(FILE * stream, const char * prop_name, const char * data); + +void yaml_dump_non_result_info( + FILE * stream, const gpt_params & params, const llama_context * lctx, + const std::string & timestamp, const std::vector & prompt_tokens, const char * model_desc); + diff --git a/common/sampling.cpp b/common/sampling.cpp index 7fc2e2158d5c4..f1f80351637f0 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -125,7 +125,7 @@ std::string llama_sampling_order_print(const llama_sampling_params & params) { std::string result = "CFG -> Penalties "; if (params.mirostat == 0) { for (auto sampler_type : params.samplers_sequence) { - const auto sampler_type_name = sampler_type_to_name_string(sampler_type); + const auto sampler_type_name = llama_sampling_type_to_str(sampler_type); if (!sampler_type_name.empty()) { result += "-> " + sampler_type_name + " "; } @@ -137,6 +137,87 @@ std::string llama_sampling_order_print(const llama_sampling_params & params) { return result; } +std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) { + switch (sampler_type) { + case llama_sampler_type::TOP_K: return "top_k"; + case llama_sampler_type::TFS_Z: return "tfs_z"; + case llama_sampler_type::TYPICAL_P: return "typical_p"; + case llama_sampler_type::TOP_P: return "top_p"; + case llama_sampler_type::MIN_P: return "min_p"; + case llama_sampler_type::TEMPERATURE: return "temperature"; + default : return ""; + } +} + +std::vector llama_sampling_types_from_names(const std::vector & names, bool allow_alt_names) { + std::unordered_map sampler_canonical_name_map { + {"top_k", llama_sampler_type::TOP_K}, + {"top_p", llama_sampler_type::TOP_P}, + {"typical_p", llama_sampler_type::TYPICAL_P}, + {"min_p", llama_sampler_type::MIN_P}, + {"tfs_z", llama_sampler_type::TFS_Z}, + {"temperature", llama_sampler_type::TEMPERATURE} + }; + + // since samplers names are written multiple ways + // make it ready for both system names and input names + std::unordered_map sampler_alt_name_map { + {"top-k", llama_sampler_type::TOP_K}, + {"top-p", llama_sampler_type::TOP_P}, + {"nucleus", llama_sampler_type::TOP_P}, + {"typical-p", llama_sampler_type::TYPICAL_P}, + {"typical", llama_sampler_type::TYPICAL_P}, + {"min-p", llama_sampler_type::MIN_P}, + {"tfs-z", llama_sampler_type::TFS_Z}, + {"tfs", llama_sampler_type::TFS_Z}, + {"temp", llama_sampler_type::TEMPERATURE} + }; + + std::vector sampler_types; + sampler_types.reserve(names.size()); + for (const auto & name : names) + { + auto sampler_item = sampler_canonical_name_map.find(name); + if (sampler_item != sampler_canonical_name_map.end()) + { + sampler_types.push_back(sampler_item->second); + } + else + { + if (allow_alt_names) + { + sampler_item = sampler_alt_name_map.find(name); + if (sampler_item != sampler_alt_name_map.end()) + { + sampler_types.push_back(sampler_item->second); + } + } + } + } + return sampler_types; +} + +std::vector llama_sampling_types_from_chars(const std::string & names_string) { + std::unordered_map sampler_name_map { + {'k', llama_sampler_type::TOP_K}, + {'p', llama_sampler_type::TOP_P}, + {'y', llama_sampler_type::TYPICAL_P}, + {'m', llama_sampler_type::MIN_P}, + {'f', llama_sampler_type::TFS_Z}, + {'t', llama_sampler_type::TEMPERATURE} + }; + + std::vector sampler_types; + sampler_types.reserve(names_string.size()); + for (const auto & c : names_string) { + const auto sampler_item = sampler_name_map.find(c); + if (sampler_item != sampler_name_map.end()) { + sampler_types.push_back(sampler_item->second); + } + } + return sampler_types; +} + // no reasons to expose this function in header static void sampler_queue( struct llama_context * ctx_main, diff --git a/common/sampling.h b/common/sampling.h index 655732ad17206..eeaa53b8bcd00 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -116,6 +116,11 @@ std::string llama_sampling_print(const llama_sampling_params & params); // Print sampling order into a string std::string llama_sampling_order_print(const llama_sampling_params & params); +std::string llama_sampling_type_to_str(llama_sampler_type sampler_type); + +std::vector llama_sampling_types_from_names(const std::vector & names, bool allow_alt_names); +std::vector llama_sampling_types_from_chars(const std::string & names_string); + // this is a common sampling function used across the examples for convenience // it can serve as a starting point for implementing your own sampling function // Note: When using multiple sequences, it is the caller's responsibility to call diff --git a/common/train.cpp b/common/train.cpp index 0dbfd24df2314..2d41a1d29a83c 100644 --- a/common/train.cpp +++ b/common/train.cpp @@ -1380,7 +1380,7 @@ bool consume_common_train_arg( void finish_processing_train_args(struct train_params_common * params) { if (params->escape) { - process_escapes(params->sample_start); + string_process_escapes(params->sample_start); } } diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index be30d20bf8194..591bc6e57645c 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -48,7 +48,7 @@ int main(int argc, char ** argv) { params.prompt = "Hello my name is"; } - process_escapes(params.prompt); + string_process_escapes(params.prompt); // init LLM diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 0c921ed69badb..004399b5f7eb8 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -80,7 +80,7 @@ int main(int argc, char ** argv) { std::mt19937 rng(params.seed); if (params.random_prompt) { - params.prompt = gpt_random_prompt(rng); + params.prompt = string_random_prompt(rng); } llama_backend_init(); @@ -107,7 +107,7 @@ int main(int argc, char ** argv) { // print system information { fprintf(stderr, "\n"); - fprintf(stderr, "%s\n", get_system_info(params).c_str()); + fprintf(stderr, "%s\n", gpt_params_get_system_info(params).c_str()); } // split the prompt into lines diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index e670d3769c7e8..51d67d6d97ae6 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -152,7 +152,7 @@ int main(int argc, char ** argv) { std::mt19937 rng(params.seed); if (params.random_prompt) { - params.prompt = gpt_random_prompt(rng); + params.prompt = string_random_prompt(rng); } llama_backend_init(); @@ -176,7 +176,7 @@ int main(int argc, char ** argv) { // print system information { fprintf(stderr, "\n"); - fprintf(stderr, "%s\n", get_system_info(params).c_str()); + fprintf(stderr, "%s\n", gpt_params_get_system_info(params).c_str()); } bool OK = run(ctx, params); diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 82b19fc4f3bae..25a2351cc64d3 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -598,7 +598,7 @@ int main(int argc, char ** argv) { std::mt19937 rng(params.seed); if (params.random_prompt) { - params.prompt = gpt_random_prompt(rng); + params.prompt = string_random_prompt(rng); } sparams.dataset = params.prompt_file; @@ -667,7 +667,7 @@ int main(int argc, char ** argv) { // print system information { fprintf(stderr, "\n"); - fprintf(stderr, "%s\n", get_system_info(params).c_str()); + fprintf(stderr, "%s\n", gpt_params_get_system_info(params).c_str()); } bool OK = compute_imatrix(ctx, params, compute_ppl, from_chunk); diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index afac145f63934..539f781847893 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -50,9 +50,9 @@ static void write_logfile( return; } - const std::string timestamp = get_sortable_timestamp(); + const std::string timestamp = string_get_sortable_timestamp(); - const bool success = create_directory_with_parents(params.logdir); + const bool success = fs_create_directory_with_parents(params.logdir); if (!success) { fprintf(stderr, "%s: warning: failed to create logdir %s, cannot write logfile\n", __func__, params.logdir.c_str()); @@ -70,7 +70,7 @@ static void write_logfile( fprintf(logfile, "binary: infill\n"); char model_desc[128]; llama_model_desc(model, model_desc, sizeof(model_desc)); - dump_non_result_info_yaml(logfile, params, ctx, timestamp, input_tokens, model_desc); + yaml_dump_non_result_info(logfile, params, ctx, timestamp, input_tokens, model_desc); fprintf(logfile, "\n"); fprintf(logfile, "######################\n"); @@ -78,8 +78,8 @@ static void write_logfile( fprintf(logfile, "######################\n"); fprintf(logfile, "\n"); - dump_string_yaml_multiline(logfile, "output", output.c_str()); - dump_vector_int_yaml(logfile, "output_tokens", output_tokens); + yaml_dump_string_multiline(logfile, "output", output.c_str()); + yaml_dump_vector_int(logfile, "output_tokens", output_tokens); llama_dump_timing_info_yaml(logfile, ctx); fclose(logfile); @@ -236,7 +236,7 @@ int main(int argc, char ** argv) { // print system information { LOG_TEE("\n"); - LOG_TEE("%s\n", get_system_info(params).c_str()); + LOG_TEE("%s\n", gpt_params_get_system_info(params).c_str()); } const bool add_bos = llama_should_add_bos_token(model); GGML_ASSERT(llama_add_eos_token(model) != 1); @@ -621,8 +621,8 @@ int main(int argc, char ** argv) { if (params.escape) { //process escape sequences, for the initial prompt this is done in common.cpp when we load the params, but for the interactive mode we need to do it here - process_escapes(params.input_prefix); - process_escapes(params.input_suffix); + string_process_escapes(params.input_prefix); + string_process_escapes(params.input_suffix); } suff_rm_leading_spc = params.escape; if (suff_rm_leading_spc && params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) { diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 6bb1f70c3c8dc..2afdb3abdc278 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -200,7 +200,7 @@ static const cmd_params cmd_params_defaults = { /* n_ubatch */ {512}, /* type_k */ {GGML_TYPE_F16}, /* type_v */ {GGML_TYPE_F16}, - /* n_threads */ {get_math_cpu_count()}, + /* n_threads */ {cpu_get_num_math()}, /* n_gpu_layers */ {99}, /* split_mode */ {LLAMA_SPLIT_MODE_LAYER}, /* main_gpu */ {0}, diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp index a6d67e5d72cd2..c974900f21e20 100644 --- a/examples/llava/llava-cli.cpp +++ b/examples/llava/llava-cli.cpp @@ -290,7 +290,7 @@ int main(int argc, char ** argv) { #endif // LOG_DISABLE_LOGS if (params.mmproj.empty() || (params.image.empty() && !prompt_contains_image(params.prompt))) { - gpt_print_usage(argc, argv, params); + gpt_params_print_usage(argc, argv, params); show_additional_info(argc, argv); return 1; } diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index 9c3540b2008c2..54f060a85b263 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -174,7 +174,7 @@ int main(int argc, char ** argv) { // debug if (dump_kv_cache) { llama_kv_cache_view_update(ctx, &kvc_view); - dump_kv_cache_view_seqs(kvc_view, 40); + llama_kv_cache_dump_view_seqs(kvc_view, 40); } // build the mask from https://lmsys.org/blog/2023-11-21-lookahead-decoding/ diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index eebbd00a58e66..83dbee91a8362 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -121,7 +121,7 @@ int main(int argc, char ** argv){ // debug if (dump_kv_cache) { llama_kv_cache_view_update(ctx, &kvc_view); - dump_kv_cache_view_seqs(kvc_view, 40); + llama_kv_cache_dump_view_seqs(kvc_view, 40); } // print current draft sequence diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 832b51ee086be..791dc61a72dda 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -60,9 +60,9 @@ static void write_logfile( return; } - const std::string timestamp = get_sortable_timestamp(); + const std::string timestamp = string_get_sortable_timestamp(); - const bool success = create_directory_with_parents(params.logdir); + const bool success = fs_create_directory_with_parents(params.logdir); if (!success) { fprintf(stderr, "%s: warning: failed to create logdir %s, cannot write logfile\n", __func__, params.logdir.c_str()); @@ -80,7 +80,7 @@ static void write_logfile( fprintf(logfile, "binary: main\n"); char model_desc[128]; llama_model_desc(model, model_desc, sizeof(model_desc)); - dump_non_result_info_yaml(logfile, params, ctx, timestamp, input_tokens, model_desc); + yaml_dump_non_result_info(logfile, params, ctx, timestamp, input_tokens, model_desc); fprintf(logfile, "\n"); fprintf(logfile, "######################\n"); @@ -88,8 +88,8 @@ static void write_logfile( fprintf(logfile, "######################\n"); fprintf(logfile, "\n"); - dump_string_yaml_multiline(logfile, "output", output.c_str()); - dump_vector_int_yaml(logfile, "output_tokens", output_tokens); + yaml_dump_string_multiline(logfile, "output", output.c_str()); + yaml_dump_vector_int(logfile, "output_tokens", output_tokens); llama_dump_timing_info_yaml(logfile, ctx); fclose(logfile); @@ -181,7 +181,7 @@ int main(int argc, char ** argv) { std::mt19937 rng(params.seed); if (params.random_prompt) { - params.prompt = gpt_random_prompt(rng); + params.prompt = string_random_prompt(rng); } LOG("%s: llama backend init\n", __func__); @@ -219,7 +219,7 @@ int main(int argc, char ** argv) { // print system information { LOG_TEE("\n"); - LOG_TEE("%s\n", get_system_info(params).c_str()); + LOG_TEE("%s\n", gpt_params_get_system_info(params).c_str()); } std::string path_session = params.path_prompt_cache; @@ -879,7 +879,7 @@ int main(int argc, char ** argv) { embd_inp.insert(embd_inp.end(), cml_pfx.begin(), cml_pfx.end()); } if (params.escape) { - process_escapes(buffer); + string_process_escapes(buffer); } const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true); diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 7c5595d6edb2d..c731abb726dc2 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -210,7 +210,7 @@ int main(int argc, char ** argv) { while (true) { if (dump_kv_cache) { llama_kv_cache_view_update(ctx, &kvc_view); - dump_kv_cache_view_seqs(kvc_view, 40); + llama_kv_cache_dump_view_seqs(kvc_view, 40); } llama_batch_clear(batch); diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index bae014e6f4c16..30e5e282ef5cf 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -44,9 +44,9 @@ static void write_logfile( return; } - const std::string timestamp = get_sortable_timestamp(); + const std::string timestamp = string_get_sortable_timestamp(); - const bool success = create_directory_with_parents(params.logdir); + const bool success = fs_create_directory_with_parents(params.logdir); if (!success) { fprintf(stderr, "%s: warning: failed to create logdir %s, cannot write logfile\n", __func__, params.logdir.c_str()); @@ -64,7 +64,7 @@ static void write_logfile( fprintf(logfile, "binary: main\n"); char model_desc[128]; llama_model_desc(model, model_desc, sizeof(model_desc)); - dump_non_result_info_yaml(logfile, params, ctx, timestamp, results.tokens, model_desc); + yaml_dump_non_result_info(logfile, params, ctx, timestamp, results.tokens, model_desc); fprintf(logfile, "\n"); fprintf(logfile, "######################\n"); @@ -72,9 +72,9 @@ static void write_logfile( fprintf(logfile, "######################\n"); fprintf(logfile, "\n"); - dump_vector_float_yaml(logfile, "logits", results.logits); + yaml_dump_vector_float(logfile, "logits", results.logits); fprintf(logfile, "ppl_value: %f\n", results.ppl_value); - dump_vector_float_yaml(logfile, "probs", results.probs); + yaml_dump_vector_float(logfile, "probs", results.probs); llama_dump_timing_info_yaml(logfile, ctx); fclose(logfile); @@ -2007,7 +2007,7 @@ int main(int argc, char ** argv) { std::mt19937 rng(params.seed); if (params.random_prompt) { - params.prompt = gpt_random_prompt(rng); + params.prompt = string_random_prompt(rng); } llama_backend_init(); @@ -2035,7 +2035,7 @@ int main(int argc, char ** argv) { // print system information { fprintf(stderr, "\n"); - fprintf(stderr, "%s\n", get_system_info(params).c_str()); + fprintf(stderr, "%s\n", gpt_params_get_system_info(params).c_str()); } struct results_perplexity results; diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index cbb452334de0d..28584e14b788c 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -259,7 +259,7 @@ int main(int argc, char ** argv) { usage(argv[0]); } } else if (strcmp(argv[arg_idx], "--override-kv") == 0) { - if (arg_idx == argc-1 || !parse_kv_override(argv[++arg_idx], kv_overrides)) { + if (arg_idx == argc-1 || !string_parse_kv_override(argv[++arg_idx], kv_overrides)) { usage(argv[0]); } } else if (strcmp(argv[arg_idx], "--allow-requantize") == 0) { diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 5ba71e76a93b4..4e7530706d4a9 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -11,7 +11,7 @@ struct retrieval_params { }; static void retrieval_params_print_usage(int argc, char ** argv, gpt_params & gpt_params, retrieval_params & params) { - gpt_print_usage(argc, argv, gpt_params); + gpt_params_print_usage(argc, argv, gpt_params); printf("retrieval options:\n"); printf(" --context-file FNAME file containing context to embed.\n"); printf(" specify multiple files by providing --context-file option multiple times.\n"); @@ -226,7 +226,7 @@ int main(int argc, char ** argv) { // print system information { fprintf(stderr, "\n"); - fprintf(stderr, "%s\n", get_system_info(params).c_str()); + fprintf(stderr, "%s\n", gpt_params_get_system_info(params).c_str()); } // max batch size diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 6af5cb96e6d13..e9904263d53c7 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1019,7 +1019,7 @@ struct server_context { sampler_names.emplace_back(sampler_name); } } - slot.sparams.samplers_sequence = sampler_types_from_names(sampler_names, false); + slot.sparams.samplers_sequence = llama_sampling_types_from_names(sampler_names, false); } else { slot.sparams.samplers_sequence = default_sparams.samplers_sequence; } @@ -1256,7 +1256,7 @@ struct server_context { std::vector samplers_sequence; samplers_sequence.reserve(slot.sparams.samplers_sequence.size()); for (const auto & sampler_type : slot.sparams.samplers_sequence) { - samplers_sequence.emplace_back(sampler_type_to_name_string(sampler_type)); + samplers_sequence.emplace_back(llama_sampling_type_to_str(sampler_type)); } return json { @@ -2852,7 +2852,7 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, invalid_param = true; break; } - if (!parse_kv_override(argv[i], params.kv_overrides)) { + if (!string_parse_kv_override(argv[i], params.kv_overrides)) { fprintf(stderr, "error: Invalid type for KV override: %s\n", argv[i]); invalid_param = true; break; @@ -3310,7 +3310,7 @@ int main(int argc, char ** argv) { const auto handle_slots_save = [&ctx_server, &res_error, &sparams](const httplib::Request & req, httplib::Response & res, int id_slot) { json request_data = json::parse(req.body); std::string filename = request_data.at("filename"); - if (!validate_file_name(filename)) { + if (!fs_validate_filename(filename)) { res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); return; } @@ -3340,7 +3340,7 @@ int main(int argc, char ** argv) { const auto handle_slots_restore = [&ctx_server, &res_error, &sparams](const httplib::Request & req, httplib::Response & res, int id_slot) { json request_data = json::parse(req.body); std::string filename = request_data.at("filename"); - if (!validate_file_name(filename)) { + if (!fs_validate_filename(filename)) { res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); return; } From 197ff91462dd05bb9a3be03578114abf0c355536 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 22 May 2024 20:05:38 +0300 Subject: [PATCH 13/15] build : remove zig (#7471) --- .github/workflows/zig-build.yml | 29 ------ build.zig | 172 -------------------------------- 2 files changed, 201 deletions(-) delete mode 100644 .github/workflows/zig-build.yml delete mode 100644 build.zig diff --git a/.github/workflows/zig-build.yml b/.github/workflows/zig-build.yml deleted file mode 100644 index 747c35cc07a96..0000000000000 --- a/.github/workflows/zig-build.yml +++ /dev/null @@ -1,29 +0,0 @@ -name: Zig CI - -on: - pull_request: - push: - branches: - - master - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} - cancel-in-progress: true - -jobs: - build: - strategy: - fail-fast: false - matrix: - runs-on: [ubuntu-latest, macos-latest, windows-latest] - runs-on: ${{ matrix.runs-on }} - steps: - - uses: actions/checkout@v4 - with: - submodules: recursive - fetch-depth: 0 - - uses: goto-bus-stop/setup-zig@v2 - with: - version: 0.11.0 - - name: Build Summary - run: zig build --summary all -freference-trace diff --git a/build.zig b/build.zig deleted file mode 100644 index 267c976b14d1a..0000000000000 --- a/build.zig +++ /dev/null @@ -1,172 +0,0 @@ -// Compatible with Zig Version 0.11.0 -const std = @import("std"); -const ArrayList = std.ArrayList; -const Compile = std.Build.Step.Compile; -const ConfigHeader = std.Build.Step.ConfigHeader; -const Mode = std.builtin.Mode; -const CrossTarget = std.zig.CrossTarget; - -const Maker = struct { - builder: *std.build.Builder, - target: CrossTarget, - optimize: Mode, - enable_lto: bool, - - include_dirs: ArrayList([]const u8), - cflags: ArrayList([]const u8), - cxxflags: ArrayList([]const u8), - objs: ArrayList(*Compile), - - fn addInclude(m: *Maker, dir: []const u8) !void { - try m.include_dirs.append(dir); - } - fn addProjectInclude(m: *Maker, path: []const []const u8) !void { - try m.addInclude(try m.builder.build_root.join(m.builder.allocator, path)); - } - fn addCFlag(m: *Maker, flag: []const u8) !void { - try m.cflags.append(flag); - } - fn addCxxFlag(m: *Maker, flag: []const u8) !void { - try m.cxxflags.append(flag); - } - fn addFlag(m: *Maker, flag: []const u8) !void { - try m.addCFlag(flag); - try m.addCxxFlag(flag); - } - - fn init(builder: *std.build.Builder) !Maker { - const target = builder.standardTargetOptions(.{}); - const zig_version = @import("builtin").zig_version_string; - const commit_hash = try std.ChildProcess.exec( - .{ .allocator = builder.allocator, .argv = &.{ "git", "rev-parse", "HEAD" } }, - ); - try std.fs.cwd().writeFile("common/build-info.cpp", builder.fmt( - \\int LLAMA_BUILD_NUMBER = {}; - \\char const *LLAMA_COMMIT = "{s}"; - \\char const *LLAMA_COMPILER = "Zig {s}"; - \\char const *LLAMA_BUILD_TARGET = "{s}"; - \\ - , .{ 0, commit_hash.stdout[0 .. commit_hash.stdout.len - 1], zig_version, try target.allocDescription(builder.allocator) })); - var m = Maker{ - .builder = builder, - .target = target, - .optimize = builder.standardOptimizeOption(.{}), - .enable_lto = false, - .include_dirs = ArrayList([]const u8).init(builder.allocator), - .cflags = ArrayList([]const u8).init(builder.allocator), - .cxxflags = ArrayList([]const u8).init(builder.allocator), - .objs = ArrayList(*Compile).init(builder.allocator), - }; - - try m.addCFlag("-std=c11"); - try m.addCxxFlag("-std=c++11"); - try m.addProjectInclude(&.{}); - try m.addProjectInclude(&.{"common"}); - return m; - } - - fn obj(m: *const Maker, name: []const u8, src: []const u8) *Compile { - const o = m.builder.addObject(.{ .name = name, .target = m.target, .optimize = m.optimize }); - if (o.target.getAbi() != .msvc) - o.defineCMacro("_GNU_SOURCE", null); - - if (std.mem.endsWith(u8, src, ".c")) { - o.addCSourceFiles(&.{src}, m.cflags.items); - o.linkLibC(); - } else { - o.addCSourceFiles(&.{src}, m.cxxflags.items); - if (o.target.getAbi() == .msvc) { - o.linkLibC(); // need winsdk + crt - } else { - // linkLibCpp already add (libc++ + libunwind + libc) - o.linkLibCpp(); - } - } - for (m.include_dirs.items) |i| o.addIncludePath(.{ .path = i }); - o.want_lto = m.enable_lto; - return o; - } - - fn exe(m: *const Maker, name: []const u8, src: []const u8, deps: []const *Compile) *Compile { - const e = m.builder.addExecutable(.{ .name = name, .target = m.target, .optimize = m.optimize }); - e.addCSourceFiles(&.{src}, m.cxxflags.items); - for (deps) |d| e.addObject(d); - for (m.objs.items) |o| e.addObject(o); - for (m.include_dirs.items) |i| e.addIncludePath(.{ .path = i }); - - // https://github.com/ziglang/zig/issues/15448 - if (e.target.getAbi() == .msvc) { - e.linkLibC(); // need winsdk + crt - } else { - // linkLibCpp already add (libc++ + libunwind + libc) - e.linkLibCpp(); - } - m.builder.installArtifact(e); - e.want_lto = m.enable_lto; - return e; - } -}; - -pub fn build(b: *std.build.Builder) !void { - var make = try Maker.init(b); - make.enable_lto = b.option(bool, "lto", "Enable LTO optimization, (default: false)") orelse false; - - const ggml = make.obj("ggml", "ggml.c"); - const sgemm = make.obj("sgemm", "sgemm.cpp"); - const ggml_alloc = make.obj("ggml-alloc", "ggml-alloc.c"); - const ggml_backend = make.obj("ggml-backend", "ggml-backend.c"); - const ggml_quants = make.obj("ggml-quants", "ggml-quants.c"); - const unicode = make.obj("unicode", "unicode.cpp"); - const unicode_data = make.obj("unicode-data", "unicode-data.cpp"); - const llama = make.obj("llama", "llama.cpp"); - const buildinfo = make.obj("common", "common/build-info.cpp"); - const common = make.obj("common", "common/common.cpp"); - const console = make.obj("console", "common/console.cpp"); - const sampling = make.obj("sampling", "common/sampling.cpp"); - const grammar_parser = make.obj("grammar-parser", "common/grammar-parser.cpp"); - const json_schema_to_grammar = make.obj("json-schema-to-grammar", "common/json-schema-to-grammar.cpp"); - const train = make.obj("train", "common/train.cpp"); - const clip = make.obj("clip", "examples/llava/clip.cpp"); - const llava = make.obj("llava", "examples/llava/llava.cpp"); - - _ = make.exe("main", "examples/main/main.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, sampling, json_schema_to_grammar, buildinfo, console, grammar_parser }); - _ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, sampling, json_schema_to_grammar, buildinfo }); - _ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, sampling, json_schema_to_grammar, buildinfo }); - _ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, sampling, json_schema_to_grammar, buildinfo }); - _ = make.exe("finetune", "examples/finetune/finetune.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, sampling, json_schema_to_grammar, buildinfo, train }); - _ = make.exe("train-text-from-scratch", "examples/train-text-from-scratch/train-text-from-scratch.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, train }); - - const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, sampling, json_schema_to_grammar, buildinfo, grammar_parser, clip, llava }); - if (server.target.isWindows()) { - server.linkSystemLibrary("ws2_32"); - } - - const server_assets = [_][]const u8{ "index.html", "index.js", "completion.js", "json-schema-to-grammar.mjs" }; - for (server_assets) |asset| { - const input_path = b.fmt("examples/server/public/{s}", .{asset}); - const output_path = b.fmt("examples/server/{s}.hpp", .{asset}); - - // Portable equivalent of `b.addSystemCommand(&.{ "xxd", "-n", asset, "-i", input_path, output_path }) })`: - - const input = try std.fs.cwd().readFileAlloc(b.allocator, input_path, std.math.maxInt(usize)); - defer b.allocator.free(input); - - var buf = std.ArrayList(u8).init(b.allocator); - defer buf.deinit(); - - for (input) |byte| { - try std.fmt.format(buf.writer(), "0x{X:0>2}, ", .{byte}); - } - - var name = try std.mem.replaceOwned(u8, b.allocator, asset, "-", "_"); - defer b.allocator.free(name); - std.mem.replaceScalar(u8, name, '.', '_'); - - try std.fs.cwd().writeFile(output_path, b.fmt( - "unsigned char {s}[] = {{{s}}};\nunsigned int {s}_len = {d};\n", - .{ name, buf.items, name, input.len }, - )); - - std.debug.print("Dumped hex of \"{s}\" ({s}) to {s}\n", .{ input_path, name, output_path }); - } -} From 1e374365d170b7f692fd7753c145e21bc14486c8 Mon Sep 17 00:00:00 2001 From: HanishKVC Date: Wed, 22 May 2024 23:23:21 +0530 Subject: [PATCH 14/15] SimpleChat: a simple and dumb web front end for testing /chat/completions and /completions end points and try chat (#7350) * SimpleChat: Add a skeletal html page Contains a div placeholder for showing chat messages till now a text-input for allowing user to enter next chat message/query to the model. a submit button to allow sending of the user entered message and chat till now to the model. * SimpleChat: A js skeleton with SimpleChat class Allows maintaining an array of chat message. Allows adding chat message (from any of the roles be it system, user, assistant, ...) Allows showing chat messages till now, in a given div element. * SimpleChat: request_json, globals, startme * SimpleChatJS: Roles Class, submitClick Define Role class with static members corresponding to the roles. Update startme to * Get hold of the ui elements. * Attach a click handler to submit button, which adds the user input to xchats array and shows the chat messages till now in chat div element. Trap DOMContentLoaded to trigger startme * SimpleChat:HTML: Bring in the js file * SimpleChat: Rather value wrt input text element * SimpleChat: Also add completions related prompt * SimpleChat: Use common helper logic wrt json data * SimpleChat: Move handling of submit request into its own func * SimpleChat: Try handshake with llm over its web service endpoint * SimpleChat:JS: Extract model response and show to user * SimpleChat:JS: Messages/Prompt, indicate working to end user * SimpleChat: Try keep input element in view * SimpleChat: Diff user/assistant msgs, Make input wider Also show a default message to user Also add some metas * SimpleChat: Move into its own sub directory to avoid confusion * SimpleChat:sh: Add simple shell script to run python3 http.server So one needs to run the llm server locally then run this script and access it using a local browser * SimpleChat:JS: Try trap enter key press wrt input text field So user can either press submit button or press enter key * SimpleChat: Allow user to select chat or completion mode * SimpleChat: Dont submit if already submitted and waiting Also make chat the default selection wrt mode * SimpleChat:JS: Handle difference in response Try read the assistance response from appropriate field in the response got. Also examples/server seems to return the response in a slightly different field, so try account for that also. * SimpleChat:JS: Force completion mode be single message by default * SimpleChat: Add a simple readme file * SimpleChat:HTML: Cleanup/structure UI a bit, Add input for system * SimpleChat:Allow system prompt to be set, if provided before user * SimpleChat: Ignore empty user input, without trimming * SimpleChat:Alert user if they provide sysprompt late or change it * SimpleChat: Move handling systemprompt into its own func * SimpleChat:HTML: Add a style for system role message * SimpleChat: Update the readme file * SimpleChat:CSS: Move style info into its own css file To keep it simple, clean and seperate so that things are not unnecessarily cluttered. * SimpleChat:CSS: Allow for chat div to be scrollable * SimpleChat:JS: Try ensure the last entry in chat is visible Needed because now only the chat div is scrollable and not the full page. In last commit the chat div size was fixed to 75% vertical height, so the full page no longer scrolls, so the old bring user-input element to view wont work, instead now the last element in the chat div should be brought into view. * SimpleChat:JS: bottom of element visible, Set focus to user input As the generated text could be multiple lines and occupy more space that the full scrollable div's vertical space, make the bottom of the last element (which can be such a generated text) in the div visible by scrolling. Ensure that the user input box has focus * SimpleChat: Update notes a bit. Try keep browser happy Avoid browser quirk mode with DOCTYPE. Help with accessibility a bit by specifying the language explicitly. Specify the char encoding explicitly, inturn utf-8 is a safe bet, even with intermixing of languages if reqd in future. Add a cache-control http-equiv meta tag, which in all probability will be ignored. Defer js loading and execution, just for fun and future, not that critical here as it stands now. * SimpleChat:HTML:Group user input+btn together; Note about multichat * SimpleChat:JS: Allow for changing system prompt anytime for future * SimpleChat:Readme: Note about handle_systemprompt begin/anytime * SimpleChat:HTML: Add viewport meta for better mobile friendliness Without this the page content may look too small. * SimpleChat:HtmlCss: Cleanup UI flow set margin wrt vmin rather than vw or vh so portrait/landscape ok. Use flex and flex-grow to put things on the same line as well as distribute available space as needed. Given two main elements/line so it remains simple. In each line have one element with grows and one sits with a basic comfortably fixed size. * SimpleChat: textarea for multiline user chat, inturn shift+enter 4 enter * SimpleChat: Make vertical layout better responsive (flex based) Also needed to make things cleaner and properly usable whether landscape or portrait, after changing to multiline textarea rather than single line user input. Avoid hardcoding the chat-till-now display area height, instead make it a flex-growable within a flex column of ui elements within a fixed vertical area. * SimpleChat: Rename simplechat.html to index.html, update readme Instead of providing a seperate shell script, update the readme wrt how to run/use this web front end. * SimpleChat: Screen fixed view and scrolling, Printing full * SimpleChat:JS:CI: Avoid space at end of jsdoc param line * SimpleChat:JS: MultiChat initial skeleton Will help maintain multiple independent chats in future * SimpleChat:JS: Move system prompt begin/anytime into SimpleChat * SimpleChat:JS:Keep MultiChatUI simple for now Worry about different chats with different servers for later. * SimpleChat:JS: Move handle submit into MultiChat, build on same Create an instance of MultiChatUI and inturn a instance of chat session, which is what the UI will inturn work on. * SimpleChat:JS: Move to dictionary of SimpleChat, instead of array * SimpleChat: Move ui elements into MultiChatUI, Update el IDs Move ui elements into MultiChatUI, so that current handleUserSubmit doesnt need to take the element arguments. Also in future, when user is allowed to switch between different chat sessions, the UI can be updated as needed by using the elements in UI already known to MultiChatUI instance. Rename the element ids' so that they follow a common convention, as well as one can identify what the element represents in a more consistant manner. * SimpleChat:MCUI:Show available chat sessions, try switch btw them Previous commits brought in / consolidated existing logic into MultiChatUI class. Now start adding logic towards multichat support * show buttons indicating available chat sessions * on sessin button click, try switch to that session * SimpleChat:MCUI: Store and use current chat session id Also allow to switch chat session optionally, wrt some of the related helpers. setup for two chat sessions by default. * SimpleChat:MCUI: Delay enabling user-input to avoid race Re-enable user-input, only after response to a user query has been updated to the chat-div. This ensures that if user tries to switch chat session, it wont be allowed till chat-request-response flow is done. * SimpleChat: Take care of system prompt Helper to get the latest system prompt and inturn use same to set the system prompt ui, when switching. Ensure that system prompt is set if and when enter key is pressed. * SimpleChat:GetSystemLatest, fix a oversight. * SimpleChat:MCUI: Allow selected chat-session btn to be highlighted Also have a general helper for setting class of children. * SimpleChat:Cleanup corners Show system prompt in chat space, when it is set by pressing enter, as a feedback to user. Alert user, if they try to switch chat session in the middle of waiting for a response from the ai model. * SimpleChat:MCUI: Ensure req-resp failure doesnt lock up things * SimpleChat:MCUI: Support for new chat sessions Also a general create button helper. * SimpleChat:MCUI: CreateSessionBtn helper, use wrt NewChat Also fix a oversight wrt using stale data wrt the list of chat sessions. * SimpleChat:MCUI: NewChat btn first before existing chat sessions * SimpleChat:MCUI:CornerCases:Skip new chat, show only if current Skip NewChat if user cancels or if one waiting for response from the ai model. Dont show a chat with newly got ai model response, if current chat session has changed, some how. Chat session shouldnt be allowed to change, if there is a pending response, but still as a additional sanity check. * SimpleChat: Update readme, title, show usage if no chat to show * SimpleChat: Cleanup the log/dialog messages a bit --- examples/server/public_simplechat/index.html | 52 ++ examples/server/public_simplechat/readme.md | 81 +++ .../server/public_simplechat/simplechat.css | 61 +++ .../server/public_simplechat/simplechat.js | 478 ++++++++++++++++++ 4 files changed, 672 insertions(+) create mode 100644 examples/server/public_simplechat/index.html create mode 100644 examples/server/public_simplechat/readme.md create mode 100644 examples/server/public_simplechat/simplechat.css create mode 100644 examples/server/public_simplechat/simplechat.js diff --git a/examples/server/public_simplechat/index.html b/examples/server/public_simplechat/index.html new file mode 100644 index 0000000000000..1eb390b85a69c --- /dev/null +++ b/examples/server/public_simplechat/index.html @@ -0,0 +1,52 @@ + + + + SimpleChat (LlamaCPP, ...) + + + + + + + + + + +
+ +
+

SimpleChat

+
+ + +
+
+ +
+ +
+
+ + +
+ +
+
+

Enter the system prompt above, before entering/submitting any user query.

+

Enter your text to the ai assistant below.

+

Use shift+enter for inserting enter.

+

Refresh the page to start over fresh.

+
+ +
+
+ + +
+ +
+ + diff --git a/examples/server/public_simplechat/readme.md b/examples/server/public_simplechat/readme.md new file mode 100644 index 0000000000000..5ac8258f21aca --- /dev/null +++ b/examples/server/public_simplechat/readme.md @@ -0,0 +1,81 @@ + +# SimpleChat + +by Humans for All. + + +## overview + +This simple web frontend, allows triggering/testing the server's /completions or /chat/completions endpoints +in a simple way with minimal code from a common code base. Inturn additionally it tries to allow single or +multiple independent back and forth chatting to an extent, with the ai llm model at a basic level, with their +own system prompts. + +The UI follows a responsive web design so that the layout can adapt to available display space in a usable +enough manner, in general. + +NOTE: Given that the idea is for basic minimal testing, it doesnt bother with any model context length and +culling of old messages from the chat. + +NOTE: It doesnt set any parameters other than temperature for now. However if someone wants they can update +the js file as needed. + + +## usage + +One could run this web frontend directly using server itself or if anyone is thinking of adding a built in web +frontend to configure the server over http(s) or so, then run this web frontend using something like python's +http module. + +### running using examples/server + +bin/server -m path/model.gguf --path ../examples/server/public_simplechat [--port PORT] + +### running using python3's server module + +first run examples/server +* bin/server -m path/model.gguf + +next run this web front end in examples/server/public_simplechat +* cd ../examples/server/public_simplechat +* python3 -m http.server PORT + +### using the front end + +Open this simple web front end from your local browser +* http://127.0.0.1:PORT/index.html + +Once inside +* Select between chat and completion mode. By default it is set to chat mode. +* If you want to provide a system prompt, then ideally enter it first, before entering any user query. + * if chat.add_system_begin is used + * you cant change the system prompt, after it is has been submitted once along with user query. + * you cant set a system prompt, after you have submitted any user query + * if chat.add_system_anytime is used + * one can change the system prompt any time during chat, by changing the contents of system prompt. + * inturn the updated/changed system prompt will be inserted into the chat session. + * this allows for the subsequent user chatting to be driven by the new system prompt set above. +* Enter your query and either press enter or click on the submit button. + If you want to insert enter (\n) as part of your chat/query to ai model, use shift+enter. +* Wait for the logic to communicate with the server and get the response. + * the user is not allowed to enter any fresh query during this time. + * the user input box will be disabled and a working message will be shown in it. +* just refresh the page, to reset wrt the chat history and or system prompt and start afresh. +* Using NewChat one can start independent chat sessions. + * two independent chat sessions are setup by default. + + +## Devel note + +Sometimes the browser may be stuborn with caching of the file, so your updates to html/css/js +may not be visible. Also remember that just refreshing/reloading page in browser or for that +matter clearing site data, dont directly override site caching in all cases. Worst case you may +have to change port. Or in dev tools of browser, you may be able to disable caching fully. + +Concept of multiple chat sessions with different servers, as well as saving and restoring of +those across browser usage sessions, can be woven around the SimpleChat/MultiChatUI class and +its instances relatively easily, however given the current goal of keeping this simple, it has +not been added, for now. + +By switching between chat.add_system_begin/anytime, one can control whether one can change +the system prompt, anytime during the conversation or only at the beginning. diff --git a/examples/server/public_simplechat/simplechat.css b/examples/server/public_simplechat/simplechat.css new file mode 100644 index 0000000000000..d45f50a957e4c --- /dev/null +++ b/examples/server/public_simplechat/simplechat.css @@ -0,0 +1,61 @@ +/** + * the styling of the simplechat web frontend + * by Humans for All + */ + +#fullbody { + height: 98vh; +} + +.heading { + background-color: lightgray; +} + +.session-selected { + background-color: lightblue; +} + +.role-system { + background-color: lightblue; +} +.role-user { + background-color: lightgray; +} + +.flex-grow { + flex-grow: 1; +} +.float-right { + float: right; +} + +#chat-div { + overflow: scroll; + flex-grow: 1; + flex-shrink: 1; + min-height: 40vh; +} +button { + min-width: 8vw; +} + +.sameline { + display: flex; + flex-direction: row; +} +.samecolumn { + display: flex; + flex-direction: column; +} + +* { + margin: 0.6vmin; +} + +@media print { + + #fullbody { + height: auto; + } + +} diff --git a/examples/server/public_simplechat/simplechat.js b/examples/server/public_simplechat/simplechat.js new file mode 100644 index 0000000000000..3fc4dbc2026fa --- /dev/null +++ b/examples/server/public_simplechat/simplechat.js @@ -0,0 +1,478 @@ +// @ts-check +// A simple completions and chat/completions test related web front end logic +// by Humans for All + +class Roles { + static System = "system"; + static User = "user"; + static Assistant = "assistant"; +} + +class ApiEP { + static Chat = "chat"; + static Completion = "completion"; +} + +let gUsageMsg = ` +

Enter the system prompt above, before entering/submitting any user query.

+

Enter your text to the ai assistant below.

+

Use shift+enter for inserting enter.

+

Refresh the page to start over fresh.

+`; + +class SimpleChat { + + constructor() { + /** + * Maintain in a form suitable for common LLM web service chat/completions' messages entry + * @type {{role: string, content: string}[]} + */ + this.xchat = []; + this.iLastSys = -1; + } + + /** + * Add an entry into xchat + * @param {string} role + * @param {string|undefined|null} content + */ + add(role, content) { + if ((content == undefined) || (content == null) || (content == "")) { + return false; + } + this.xchat.push( {role: role, content: content} ); + if (role == Roles.System) { + this.iLastSys = this.xchat.length - 1; + } + return true; + } + + /** + * Show the contents in the specified div + * @param {HTMLDivElement} div + * @param {boolean} bClear + */ + show(div, bClear=true) { + if (bClear) { + div.replaceChildren(); + } + let last = undefined; + for(const x of this.xchat) { + let entry = document.createElement("p"); + entry.className = `role-${x.role}`; + entry.innerText = `${x.role}: ${x.content}`; + div.appendChild(entry); + last = entry; + } + if (last !== undefined) { + last.scrollIntoView(false); + } else { + if (bClear) { + div.innerHTML = gUsageMsg; + } + } + } + + /** + * Add needed fields wrt json object to be sent wrt LLM web services completions endpoint + * Convert the json into string. + * @param {Object} obj + */ + request_jsonstr(obj) { + obj["temperature"] = 0.7; + return JSON.stringify(obj); + } + + /** + * Return a string form of json object suitable for chat/completions + */ + request_messages_jsonstr() { + let req = { + messages: this.xchat, + } + return this.request_jsonstr(req); + } + + /** + * Return a string form of json object suitable for /completions + */ + request_prompt_jsonstr() { + let prompt = ""; + for(const chat of this.xchat) { + prompt += `${chat.role}: ${chat.content}\n`; + } + let req = { + prompt: prompt, + } + return this.request_jsonstr(req); + } + + /** + * Allow setting of system prompt, but only at begining. + * @param {string} sysPrompt + * @param {string} msgTag + */ + add_system_begin(sysPrompt, msgTag) { + if (this.xchat.length == 0) { + if (sysPrompt.length > 0) { + return this.add(Roles.System, sysPrompt); + } + } else { + if (sysPrompt.length > 0) { + if (this.xchat[0].role !== Roles.System) { + console.error(`ERRR:SimpleChat:SC:${msgTag}:You need to specify system prompt before any user query, ignoring...`); + } else { + if (this.xchat[0].content !== sysPrompt) { + console.error(`ERRR:SimpleChat:SC:${msgTag}:You cant change system prompt, mid way through, ignoring...`); + } + } + } + } + return false; + } + + /** + * Allow setting of system prompt, at any time. + * @param {string} sysPrompt + * @param {string} msgTag + */ + add_system_anytime(sysPrompt, msgTag) { + if (sysPrompt.length <= 0) { + return false; + } + + if (this.iLastSys < 0) { + return this.add(Roles.System, sysPrompt); + } + + let lastSys = this.xchat[this.iLastSys].content; + if (lastSys !== sysPrompt) { + return this.add(Roles.System, sysPrompt); + } + return false; + } + + /** + * Retrieve the latest system prompt. + */ + get_system_latest() { + if (this.iLastSys == -1) { + return ""; + } + let sysPrompt = this.xchat[this.iLastSys].content; + return sysPrompt; + } + +} + + +let gBaseURL = "http://127.0.0.1:8080"; +let gChatURL = { + 'chat': `${gBaseURL}/chat/completions`, + 'completion': `${gBaseURL}/completions`, +} +const gbCompletionFreshChatAlways = true; + + +/** + * Set the class of the children, based on whether it is the idSelected or not. + * @param {HTMLDivElement} elBase + * @param {string} idSelected + * @param {string} classSelected + * @param {string} classUnSelected + */ +function el_children_config_class(elBase, idSelected, classSelected, classUnSelected="") { + for(let child of elBase.children) { + if (child.id == idSelected) { + child.className = classSelected; + } else { + child.className = classUnSelected; + } + } +} + +/** + * Create button and set it up. + * @param {string} id + * @param {(this: HTMLButtonElement, ev: MouseEvent) => any} callback + * @param {string | undefined} name + * @param {string | undefined} innerText + */ +function el_create_button(id, callback, name=undefined, innerText=undefined) { + if (!name) { + name = id; + } + if (!innerText) { + innerText = id; + } + let btn = document.createElement("button"); + btn.id = id; + btn.name = name; + btn.innerText = innerText; + btn.addEventListener("click", callback); + return btn; +} + + +class MultiChatUI { + + constructor() { + /** @type {Object} */ + this.simpleChats = {}; + /** @type {string} */ + this.curChatId = ""; + + // the ui elements + this.elInSystem = /** @type{HTMLInputElement} */(document.getElementById("system-in")); + this.elDivChat = /** @type{HTMLDivElement} */(document.getElementById("chat-div")); + this.elBtnUser = /** @type{HTMLButtonElement} */(document.getElementById("user-btn")); + this.elInUser = /** @type{HTMLInputElement} */(document.getElementById("user-in")); + this.elSelectApiEP = /** @type{HTMLSelectElement} */(document.getElementById("api-ep")); + this.elDivSessions = /** @type{HTMLDivElement} */(document.getElementById("sessions-div")); + + this.validate_element(this.elInSystem, "system-in"); + this.validate_element(this.elDivChat, "chat-div"); + this.validate_element(this.elInUser, "user-in"); + this.validate_element(this.elSelectApiEP, "api-ep"); + this.validate_element(this.elDivChat, "sessions-div"); + } + + /** + * Check if the element got + * @param {HTMLElement | null} el + * @param {string} msgTag + */ + validate_element(el, msgTag) { + if (el == null) { + throw Error(`ERRR:SimpleChat:MCUI:${msgTag} element missing in html...`); + } else { + console.debug(`INFO:SimpleChat:MCUI:${msgTag} Id[${el.id}] Name[${el["name"]}]`); + } + } + + /** + * Reset user input ui. + * * clear user input + * * enable user input + * * set focus to user input + */ + ui_reset_userinput() { + this.elInUser.value = ""; + this.elInUser.disabled = false; + this.elInUser.focus(); + } + + /** + * Setup the needed callbacks wrt UI, curChatId to defaultChatId and + * optionally switch to specified defaultChatId. + * @param {string} defaultChatId + * @param {boolean} bSwitchSession + */ + setup_ui(defaultChatId, bSwitchSession=false) { + + this.curChatId = defaultChatId; + if (bSwitchSession) { + this.handle_session_switch(this.curChatId); + } + + this.elBtnUser.addEventListener("click", (ev)=>{ + if (this.elInUser.disabled) { + return; + } + this.handle_user_submit(this.curChatId, this.elSelectApiEP.value).catch((/** @type{Error} */reason)=>{ + let msg = `ERRR:SimpleChat\nMCUI:HandleUserSubmit:${this.curChatId}\n${reason.name}:${reason.message}`; + console.debug(msg.replace("\n", ":")); + alert(msg); + this.ui_reset_userinput(); + }); + }); + + this.elInUser.addEventListener("keyup", (ev)=> { + // allow user to insert enter into their message using shift+enter. + // while just pressing enter key will lead to submitting. + if ((ev.key === "Enter") && (!ev.shiftKey)) { + this.elBtnUser.click(); + ev.preventDefault(); + } + }); + + this.elInSystem.addEventListener("keyup", (ev)=> { + // allow user to insert enter into the system prompt using shift+enter. + // while just pressing enter key will lead to setting the system prompt. + if ((ev.key === "Enter") && (!ev.shiftKey)) { + let chat = this.simpleChats[this.curChatId]; + chat.add_system_anytime(this.elInSystem.value, this.curChatId); + chat.show(this.elDivChat); + ev.preventDefault(); + } + }); + + } + + /** + * Setup a new chat session and optionally switch to it. + * @param {string} chatId + * @param {boolean} bSwitchSession + */ + new_chat_session(chatId, bSwitchSession=false) { + this.simpleChats[chatId] = new SimpleChat(); + if (bSwitchSession) { + this.handle_session_switch(chatId); + } + } + + /** + * Handle user query submit request, wrt specified chat session. + * @param {string} chatId + * @param {string} apiEP + */ + async handle_user_submit(chatId, apiEP) { + + let chat = this.simpleChats[chatId]; + + chat.add_system_anytime(this.elInSystem.value, chatId); + + let content = this.elInUser.value; + if (!chat.add(Roles.User, content)) { + console.debug(`WARN:SimpleChat:MCUI:${chatId}:HandleUserSubmit:Ignoring empty user input...`); + return; + } + chat.show(this.elDivChat); + + let theBody; + let theUrl = gChatURL[apiEP] + if (apiEP == ApiEP.Chat) { + theBody = chat.request_messages_jsonstr(); + } else { + theBody = chat.request_prompt_jsonstr(); + } + + this.elInUser.value = "working..."; + this.elInUser.disabled = true; + console.debug(`DBUG:SimpleChat:MCUI:${chatId}:HandleUserSubmit:${theUrl}:ReqBody:${theBody}`); + let resp = await fetch(theUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: theBody, + }); + + let respBody = await resp.json(); + console.debug(`DBUG:SimpleChat:MCUI:${chatId}:HandleUserSubmit:RespBody:${JSON.stringify(respBody)}`); + let assistantMsg; + if (apiEP == ApiEP.Chat) { + assistantMsg = respBody["choices"][0]["message"]["content"]; + } else { + try { + assistantMsg = respBody["choices"][0]["text"]; + } catch { + assistantMsg = respBody["content"]; + } + } + chat.add(Roles.Assistant, assistantMsg); + if (chatId == this.curChatId) { + chat.show(this.elDivChat); + } else { + console.debug(`DBUG:SimpleChat:MCUI:HandleUserSubmit:ChatId has changed:[${chatId}] [${this.curChatId}]`); + } + // Purposefully clear at end rather than begin of this function + // so that one can switch from chat to completion mode and sequece + // in a completion mode with multiple user-assistant chat data + // from before to be sent/occur once. + if ((apiEP == ApiEP.Completion) && (gbCompletionFreshChatAlways)) { + chat.xchat.length = 0; + } + this.ui_reset_userinput(); + } + + /** + * Show buttons for NewChat and available chat sessions, in the passed elDiv. + * If elDiv is undefined/null, then use this.elDivSessions. + * Take care of highlighting the selected chat-session's btn. + * @param {HTMLDivElement | undefined} elDiv + */ + show_sessions(elDiv=undefined) { + if (!elDiv) { + elDiv = this.elDivSessions; + } + elDiv.replaceChildren(); + // Btn for creating new chat session + let btnNew = el_create_button("New CHAT", (ev)=> { + if (this.elInUser.disabled) { + console.error(`ERRR:SimpleChat:MCUI:NewChat:Current session [${this.curChatId}] awaiting response, ignoring request...`); + alert("ERRR:SimpleChat\nMCUI:NewChat\nWait for response to pending query, before starting new chat session"); + return; + } + let chatId = `Chat${Object.keys(this.simpleChats).length}`; + let chatIdGot = prompt("INFO:SimpleChat\nMCUI:NewChat\nEnter id for new chat session", chatId); + if (!chatIdGot) { + console.error("ERRR:SimpleChat:MCUI:NewChat:Skipping based on user request..."); + return; + } + this.new_chat_session(chatIdGot, true); + this.create_session_btn(elDiv, chatIdGot); + el_children_config_class(elDiv, chatIdGot, "session-selected", ""); + }); + elDiv.appendChild(btnNew); + // Btns for existing chat sessions + let chatIds = Object.keys(this.simpleChats); + for(let cid of chatIds) { + let btn = this.create_session_btn(elDiv, cid); + if (cid == this.curChatId) { + btn.className = "session-selected"; + } + } + } + + create_session_btn(elDiv, cid) { + let btn = el_create_button(cid, (ev)=>{ + let target = /** @type{HTMLButtonElement} */(ev.target); + console.debug(`DBUG:SimpleChat:MCUI:SessionClick:${target.id}`); + if (this.elInUser.disabled) { + console.error(`ERRR:SimpleChat:MCUI:SessionClick:${target.id}:Current session [${this.curChatId}] awaiting response, ignoring switch...`); + alert("ERRR:SimpleChat\nMCUI:SessionClick\nWait for response to pending query, before switching"); + return; + } + this.handle_session_switch(target.id); + el_children_config_class(elDiv, target.id, "session-selected", ""); + }); + elDiv.appendChild(btn); + return btn; + } + + /** + * Switch ui to the specified chatId and set curChatId to same. + * @param {string} chatId + */ + async handle_session_switch(chatId) { + let chat = this.simpleChats[chatId]; + if (chat == undefined) { + console.error(`ERRR:SimpleChat:MCUI:HandleSessionSwitch:${chatId} missing...`); + return; + } + this.elInSystem.value = chat.get_system_latest(); + this.elInUser.value = ""; + chat.show(this.elDivChat); + this.elInUser.focus(); + this.curChatId = chatId; + console.log(`INFO:SimpleChat:MCUI:HandleSessionSwitch:${chatId} entered...`); + } + +} + + +let gMuitChat; +const gChatIds = [ "Default", "Other" ]; + +function startme() { + console.log("INFO:SimpleChat:StartMe:Starting..."); + gMuitChat = new MultiChatUI(); + for (let cid of gChatIds) { + gMuitChat.new_chat_session(cid); + } + gMuitChat.setup_ui(gChatIds[0]); + gMuitChat.show_sessions(); +} + +document.addEventListener("DOMContentLoaded", startme); From cd93a28cb1446319af5e2f4b416174c3a8e43546 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Thu, 23 May 2024 00:31:20 +0200 Subject: [PATCH 15/15] CUDA: fix FA out-of-bounds reads (#7479) --- ggml-cuda/fattn-tile-f16.cu | 2 +- ggml-cuda/fattn-tile-f32.cu | 2 +- ggml-cuda/fattn-vec-f16.cu | 6 +++--- ggml-cuda/fattn-vec-f32.cu | 6 +++--- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/ggml-cuda/fattn-tile-f16.cu b/ggml-cuda/fattn-tile-f16.cu index 586d469c049d1..cdb5eaff79535 100644 --- a/ggml-cuda/fattn-tile-f16.cu +++ b/ggml-cuda/fattn-tile-f16.cu @@ -83,7 +83,7 @@ static __global__ void flash_attn_tile_ext_f16( for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i]; + const float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f); Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y); } } diff --git a/ggml-cuda/fattn-tile-f32.cu b/ggml-cuda/fattn-tile-f32.cu index b6ef8eb48d992..5a3de2918c7a3 100644 --- a/ggml-cuda/fattn-tile-f32.cu +++ b/ggml-cuda/fattn-tile-f32.cu @@ -79,7 +79,7 @@ static __global__ void flash_attn_tile_ext_f32( #pragma unroll for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) { - float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x]; + float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x] : make_float2(0.0f, 0.0f); Q_f[j][i0 + 0*WARP_SIZE + threadIdx.x] = tmp.x * scale; Q_f[j][i0 + 1*WARP_SIZE + threadIdx.x] = tmp.y * scale; } diff --git a/ggml-cuda/fattn-vec-f16.cu b/ggml-cuda/fattn-vec-f16.cu index 7352dcabf6291..808e8f36246a7 100644 --- a/ggml-cuda/fattn-vec-f16.cu +++ b/ggml-cuda/fattn-vec-f16.cu @@ -94,7 +94,7 @@ static __global__ void flash_attn_vec_ext_f16( for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i]; + const float2 tmp = ncols <= 2 || ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f); Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y); } } @@ -212,7 +212,7 @@ static __global__ void flash_attn_vec_ext_f16( #pragma unroll for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) { - if (ic0 + j_VKQ >= ne01) { + if (ncols > 2 && ic0 + j_VKQ >= ne01) { break; } @@ -227,7 +227,7 @@ static __global__ void flash_attn_vec_ext_f16( dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val; } - if (parallel_blocks != 1 && tid < ncols && ic0 + tid < ne01) { + if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) { dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]); } #else diff --git a/ggml-cuda/fattn-vec-f32.cu b/ggml-cuda/fattn-vec-f32.cu index 11476a6c0fbbc..b4652301b87e0 100644 --- a/ggml-cuda/fattn-vec-f32.cu +++ b/ggml-cuda/fattn-vec-f32.cu @@ -91,7 +91,7 @@ static __global__ void flash_attn_vec_ext_f32( for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - Q_h2[j][i0/WARP_SIZE] = Q_f2[j*(nb01/sizeof(float2)) + i]; + Q_h2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f); Q_h2[j][i0/WARP_SIZE].x *= scale; Q_h2[j][i0/WARP_SIZE].y *= scale; } @@ -200,7 +200,7 @@ static __global__ void flash_attn_vec_ext_f32( #pragma unroll for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) { - if (ic0 + j_VKQ >= ne01) { + if (ncols > 2 && ic0 + j_VKQ >= ne01) { break; } @@ -215,7 +215,7 @@ static __global__ void flash_attn_vec_ext_f32( dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val; } - if (parallel_blocks != 1 && tid < ncols && ic0 + tid < ne01) { + if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) { dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]); } }