From dbe777d3feadcf0b2930992881318b035e2084a9 Mon Sep 17 00:00:00 2001 From: Jhen-Jie Hong Date: Sun, 17 Nov 2024 11:45:30 +0800 Subject: [PATCH] feat: sync llama.cpp --- android/src/main/CMakeLists.txt | 8 +- cpp/common.cpp | 216 +- cpp/common.h | 17 +- cpp/ggml-aarch64.c | 3440 +-- cpp/ggml-aarch64.h | 20 - cpp/ggml-backend-reg.cpp | 204 + cpp/ggml-backend.cpp | 956 +- cpp/ggml-backend.h | 33 +- cpp/ggml-cpu-aarch64.c | 3560 +++ cpp/ggml-cpu-aarch64.h | 30 + cpp/ggml-cpu-impl.h | 243 - cpp/ggml-cpu-quants.c | 10822 +++++++ cpp/ggml-cpu-quants.h | 63 + cpp/ggml-cpu.c | 13975 +++++++++ cpp/ggml-cpu.cpp | 663 + cpp/ggml-cpu.h | 177 + cpp/ggml-impl.h | 357 +- cpp/ggml-metal.h | 16 +- cpp/ggml-metal.m | 756 +- cpp/ggml-quants.c | 10877 +------ cpp/ggml-quants.h | 203 +- cpp/ggml-threading.cpp | 12 + cpp/ggml-threading.h | 12 + cpp/ggml.c | 23997 +++------------- cpp/ggml.h | 199 +- cpp/llama-sampling.cpp | 7 +- cpp/llama.cpp | 230 +- cpp/llama.h | 7 +- cpp/sgemm.cpp | 608 + example/ios/.xcode.env.local | 2 +- llama.cpp | 2 +- scripts/bootstrap.sh | 45 +- scripts/common.h.patch | 8 +- ...d.cpp.patch => ggml-backend-reg.cpp.patch} | 24 +- scripts/ggml-cpu-aarch64.c.patch | 11 + scripts/ggml-metal.m.patch | 8 +- scripts/ggml-quants.c.patch | 11 + scripts/ggml.c.patch | 6 +- scripts/sgemm.cpp.patch | 12 + 39 files changed, 36208 insertions(+), 35629 deletions(-) create mode 100644 cpp/ggml-backend-reg.cpp create mode 100644 cpp/ggml-cpu-aarch64.c create mode 100644 cpp/ggml-cpu-aarch64.h create mode 100644 cpp/ggml-cpu-quants.c create mode 100644 cpp/ggml-cpu-quants.h create mode 100644 cpp/ggml-cpu.c create mode 100644 cpp/ggml-cpu.cpp create mode 100644 cpp/ggml-cpu.h create mode 100644 cpp/ggml-threading.cpp create mode 100644 cpp/ggml-threading.h rename scripts/{ggml-backend.cpp.patch => ggml-backend-reg.cpp.patch} (50%) create mode 100644 scripts/ggml-cpu-aarch64.c.patch create mode 100644 scripts/ggml-quants.c.patch create mode 100644 scripts/sgemm.cpp.patch diff --git a/android/src/main/CMakeLists.txt b/android/src/main/CMakeLists.txt index b539a678..40fbb68b 100644 --- a/android/src/main/CMakeLists.txt +++ b/android/src/main/CMakeLists.txt @@ -17,7 +17,13 @@ set( ${RNLLAMA_LIB_DIR}/ggml-aarch64.c ${RNLLAMA_LIB_DIR}/ggml-alloc.c ${RNLLAMA_LIB_DIR}/ggml-backend.cpp + ${RNLLAMA_LIB_DIR}/ggml-backend-reg.cpp ${RNLLAMA_LIB_DIR}/ggml.c + ${RNLLAMA_LIB_DIR}/ggml-cpu.c + ${RNLLAMA_LIB_DIR}/ggml-cpu.cpp + ${RNLLAMA_LIB_DIR}/ggml-cpu-aarch64.c + ${RNLLAMA_LIB_DIR}/ggml-cpu-quants.c + ${RNLLAMA_LIB_DIR}/ggml-threading.cpp ${RNLLAMA_LIB_DIR}/ggml-quants.c ${RNLLAMA_LIB_DIR}/common.cpp ${RNLLAMA_LIB_DIR}/json.hpp @@ -77,7 +83,7 @@ if (${ANDROID_ABI} STREQUAL "arm64-v8a") # https://github.com/ggerganov/llama.cpp/blob/master/docs/android.md#cross-compile-using-android-ndk # llama.cpp will deal with the cpu features - # build_library("rnllama_v8_7" "-march=armv8.7-a") + # build_library("rnllama_v8_7" "-march=armv8.7-a") # TODO: Add support runtime check for cpu features # At the moment runtime check is failing. diff --git a/cpp/common.cpp b/cpp/common.cpp index dfaa0378..39f23482 100644 --- a/cpp/common.cpp +++ b/cpp/common.cpp @@ -1014,6 +1014,9 @@ static lm_ggml_type kv_cache_type_from_str(const std::string & s) { if (s == "f16") { return LM_GGML_TYPE_F16; } + if (s == "bf16") { + return LM_GGML_TYPE_BF16; + } if (s == "q8_0") { return LM_GGML_TYPE_Q8_0; } @@ -1898,216 +1901,3 @@ common_control_vector_data common_control_vector_load(const std::vector & data) { - if (data.empty()) { - fprintf(stream, "%s:\n", prop_name); - return; - } - - fprintf(stream, "%s: [", prop_name); - for (size_t i = 0; i < data.size() - 1; ++i) { - fprintf(stream, "%e, ", data[i]); - } - fprintf(stream, "%e]\n", data.back()); -} - -void yaml_dump_vector_int(FILE * stream, const char * prop_name, const std::vector & data) { - if (data.empty()) { - fprintf(stream, "%s:\n", prop_name); - return; - } - - fprintf(stream, "%s: [", prop_name); - for (size_t i = 0; i < data.size() - 1; ++i) { - fprintf(stream, "%d, ", data[i]); - } - fprintf(stream, "%d]\n", data.back()); -} - -void yaml_dump_string_multiline(FILE * stream, const char * prop_name, const char * data) { - std::string data_str(data == NULL ? "" : data); - - if (data_str.empty()) { - fprintf(stream, "%s:\n", prop_name); - return; - } - - size_t pos_start = 0; - size_t pos_found = 0; - - if (std::isspace(data_str[0]) || std::isspace(data_str.back())) { - data_str = std::regex_replace(data_str, std::regex("\n"), "\\n"); - data_str = std::regex_replace(data_str, std::regex("\""), "\\\""); - data_str = std::regex_replace(data_str, std::regex(R"(\\[^n"])"), R"(\$&)"); - data_str = "\"" + data_str + "\""; - fprintf(stream, "%s: %s\n", prop_name, data_str.c_str()); - return; - } - - if (data_str.find('\n') == std::string::npos) { - fprintf(stream, "%s: %s\n", prop_name, data_str.c_str()); - return; - } - - fprintf(stream, "%s: |\n", prop_name); - while ((pos_found = data_str.find('\n', pos_start)) != std::string::npos) { - fprintf(stream, " %s\n", data_str.substr(pos_start, pos_found-pos_start).c_str()); - pos_start = pos_found + 1; - } -} - -void yaml_dump_non_result_info(FILE * stream, const common_params & params, const llama_context * lctx, - const std::string & timestamp, const std::vector & prompt_tokens, const char * model_desc) { - const auto & sparams = params.sparams; - - fprintf(stream, "build_commit: %s\n", LLAMA_COMMIT); - fprintf(stream, "build_number: %d\n", LLAMA_BUILD_NUMBER); - fprintf(stream, "cpu_has_arm_fma: %s\n", lm_ggml_cpu_has_arm_fma() ? "true" : "false"); - fprintf(stream, "cpu_has_avx: %s\n", lm_ggml_cpu_has_avx() ? "true" : "false"); - fprintf(stream, "cpu_has_avx_vnni: %s\n", lm_ggml_cpu_has_avx_vnni() ? "true" : "false"); - fprintf(stream, "cpu_has_avx2: %s\n", lm_ggml_cpu_has_avx2() ? "true" : "false"); - fprintf(stream, "cpu_has_avx512: %s\n", lm_ggml_cpu_has_avx512() ? "true" : "false"); - fprintf(stream, "cpu_has_avx512_vbmi: %s\n", lm_ggml_cpu_has_avx512_vbmi() ? "true" : "false"); - fprintf(stream, "cpu_has_avx512_vnni: %s\n", lm_ggml_cpu_has_avx512_vnni() ? "true" : "false"); - fprintf(stream, "cpu_has_cuda: %s\n", lm_ggml_cpu_has_cuda() ? "true" : "false"); - fprintf(stream, "cpu_has_vulkan: %s\n", lm_ggml_cpu_has_vulkan() ? "true" : "false"); - fprintf(stream, "cpu_has_kompute: %s\n", lm_ggml_cpu_has_kompute() ? "true" : "false"); - fprintf(stream, "cpu_has_fma: %s\n", lm_ggml_cpu_has_fma() ? "true" : "false"); - fprintf(stream, "cpu_has_gpublas: %s\n", lm_ggml_cpu_has_gpublas() ? "true" : "false"); - fprintf(stream, "cpu_has_neon: %s\n", lm_ggml_cpu_has_neon() ? "true" : "false"); - fprintf(stream, "cpu_has_sve: %s\n", lm_ggml_cpu_has_sve() ? "true" : "false"); - fprintf(stream, "cpu_has_f16c: %s\n", lm_ggml_cpu_has_f16c() ? "true" : "false"); - fprintf(stream, "cpu_has_fp16_va: %s\n", lm_ggml_cpu_has_fp16_va() ? "true" : "false"); - fprintf(stream, "cpu_has_riscv_v: %s\n", lm_ggml_cpu_has_riscv_v() ? "true" : "false"); - fprintf(stream, "cpu_has_wasm_simd: %s\n", lm_ggml_cpu_has_wasm_simd() ? "true" : "false"); - fprintf(stream, "cpu_has_blas: %s\n", lm_ggml_cpu_has_blas() ? "true" : "false"); - fprintf(stream, "cpu_has_sse3: %s\n", lm_ggml_cpu_has_sse3() ? "true" : "false"); - fprintf(stream, "cpu_has_vsx: %s\n", lm_ggml_cpu_has_vsx() ? "true" : "false"); - fprintf(stream, "cpu_has_matmul_int8: %s\n", lm_ggml_cpu_has_matmul_int8() ? "true" : "false"); - -#ifdef NDEBUG - fprintf(stream, "debug: false\n"); -#else - fprintf(stream, "debug: true\n"); -#endif // NDEBUG - - fprintf(stream, "model_desc: %s\n", model_desc); - fprintf(stream, "n_vocab: %d # output size of the final layer, 32001 for some models\n", llama_n_vocab(llama_get_model(lctx))); - -#ifdef __OPTIMIZE__ - fprintf(stream, "optimize: true\n"); -#else - fprintf(stream, "optimize: false\n"); -#endif // __OPTIMIZE__ - - fprintf(stream, "time: %s\n", timestamp.c_str()); - - fprintf(stream, "\n"); - fprintf(stream, "###############\n"); - fprintf(stream, "# User Inputs #\n"); - fprintf(stream, "###############\n"); - fprintf(stream, "\n"); - - fprintf(stream, "alias: %s # default: unknown\n", params.model_alias.c_str()); - fprintf(stream, "batch_size: %d # default: 512\n", params.n_batch); - fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks); - fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false"); - fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx); - fprintf(stream, "dry_allowed_length: %d # default: 2\n", sparams.dry_allowed_length); - fprintf(stream, "dry_base: %.2f # default: 1.75\n", sparams.dry_base); - fprintf(stream, "dry_multiplier: %.1f # default: 0.0\n", sparams.dry_multiplier); - fprintf(stream, "dry_penalty_last_n: %d # default: -1 (0 = disable, -1 = context size)\n", sparams.dry_penalty_last_n); - fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false"); - fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n"); - fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq); - yaml_dump_string_multiline(stream, "grammar", sparams.grammar.c_str()); - fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n"); - fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false"); - fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks); - fprintf(stream, "ignore_eos: %s # default: false\n", sparams.ignore_eos ? "true" : "false"); - - yaml_dump_string_multiline(stream, "in_prefix", params.input_prefix.c_str()); - fprintf(stream, "in_prefix_bos: %s # default: false\n", params.input_prefix_bos ? "true" : "false"); - yaml_dump_string_multiline(stream, "in_suffix", params.input_prefix.c_str()); - fprintf(stream, "interactive: %s # default: false\n", params.interactive ? "true" : "false"); - fprintf(stream, "interactive_first: %s # default: false\n", params.interactive_first ? "true" : "false"); - fprintf(stream, "keep: %d # default: 0\n", params.n_keep); - fprintf(stream, "logdir: %s # default: unset (no logging)\n", params.logdir.c_str()); - - fprintf(stream, "logit_bias:\n"); - for (const auto & logit_bias : sparams.logit_bias) { - fprintf(stream, " %d: %f", logit_bias.token, logit_bias.bias); - } - - fprintf(stream, "lora:\n"); - for (auto & la : params.lora_adapters) { - if (la.scale == 1.0f) { - fprintf(stream, " - %s\n", la.path.c_str()); - } - } - fprintf(stream, "lora_scaled:\n"); - for (auto & la : params.lora_adapters) { - if (la.scale != 1.0f) { - fprintf(stream, " - %s: %f\n", la.path.c_str(), la.scale); - } - } - fprintf(stream, "lora_init_without_apply: %s # default: false\n", params.lora_init_without_apply ? "true" : "false"); - fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu); - fprintf(stream, "min_keep: %d # default: 0 (disabled)\n", sparams.min_keep); - fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat); - fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau); - fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta); - fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false"); - fprintf(stream, "model: %s # default: %s\n", params.model.c_str(), DEFAULT_MODEL_PATH); - fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str()); - fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false"); - fprintf(stream, "n_gpu_layers: %d # default: -1\n", params.n_gpu_layers); - fprintf(stream, "n_predict: %d # default: -1 (unlimited)\n", params.n_predict); - fprintf(stream, "n_probs: %d # only used by server binary, default: 0\n", sparams.n_probs); - fprintf(stream, "no_mmap: %s # default: false\n", !params.use_mmap ? "true" : "false"); - fprintf(stream, "penalize_nl: %s # default: false\n", sparams.penalize_nl ? "true" : "false"); - fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type); - fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride); - fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.penalty_present); - yaml_dump_string_multiline(stream, "prompt", params.prompt.c_str()); - fprintf(stream, "prompt_cache: %s\n", params.path_prompt_cache.c_str()); - fprintf(stream, "prompt_cache_all: %s # default: false\n", params.prompt_cache_all ? "true" : "false"); - fprintf(stream, "prompt_cache_ro: %s # default: false\n", params.prompt_cache_ro ? "true" : "false"); - yaml_dump_vector_int(stream, "prompt_tokens", prompt_tokens); - fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.penalty_repeat); - - fprintf(stream, "reverse_prompt:\n"); - for (std::string ap : params.antiprompt) { - size_t pos = 0; - while ((pos = ap.find('\n', pos)) != std::string::npos) { - ap.replace(pos, 1, "\\n"); - pos += 1; - } - - fprintf(stream, " - %s\n", ap.c_str()); - } - - fprintf(stream, "rope_freq_base: %f # default: 10000.0\n", params.rope_freq_base); - fprintf(stream, "rope_freq_scale: %f # default: 1.0\n", params.rope_freq_scale); - fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false"); - fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); - fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false"); - fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); - - const std::vector tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices()); - yaml_dump_vector_float(stream, "tensor_split", tensor_split_vector); - - fprintf(stream, "threads: %d # default: %u\n", params.cpuparams.n_threads, std::thread::hardware_concurrency()); - fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k); - fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p); - fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p); - fprintf(stream, "xtc_probability: %f # default: 0.0\n", sparams.xtc_probability); - fprintf(stream, "xtc_threshold: %f # default: 0.1\n", sparams.xtc_threshold); - fprintf(stream, "typ_p: %f # default: 1.0\n", sparams.typ_p); - fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false"); - fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false"); -} diff --git a/cpp/common.h b/cpp/common.h index c60e5bef..4070df28 100644 --- a/cpp/common.h +++ b/cpp/common.h @@ -167,7 +167,7 @@ struct common_sampler_params { struct common_params { bool vocab_only = false; int32_t n_predict = -1; // new tokens to predict - int32_t n_ctx = 0; // context size + int32_t n_ctx = 4096; // context size int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS) int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS) int32_t n_keep = 0; // number of tokens to keep from initial prompt @@ -190,7 +190,7 @@ struct common_params { float yarn_beta_fast = 32.0f; // YaRN low correction dim float yarn_beta_slow = 1.0f; // YaRN high correction dim int32_t yarn_orig_ctx = 0; // YaRN original context length - float defrag_thold = -1.0f; // KV cache defragmentation threshold + float defrag_thold = 0.1f; // KV cache defragmentation threshold struct cpu_params cpuparams; struct cpu_params cpuparams_batch; @@ -221,7 +221,6 @@ struct common_params { std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state // NOLINT std::string input_prefix = ""; // string to prefix user inputs with // NOLINT std::string input_suffix = ""; // string to suffix user inputs with // NOLINT - std::string logdir = ""; // directory in which to save YAML log files // NOLINT std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT std::string logits_file = ""; // file for saving *all* logits // NOLINT @@ -599,15 +598,3 @@ common_control_vector_data common_control_vector_load(const std::vector & data); -void yaml_dump_vector_int (FILE * stream, const char * prop_name, const std::vector & data); -void yaml_dump_string_multiline(FILE * stream, const char * prop_name, const char * data); - -void yaml_dump_non_result_info( - FILE * stream, const common_params & params, const llama_context * lctx, - const std::string & timestamp, const std::vector & prompt_tokens, const char * model_desc); diff --git a/cpp/ggml-aarch64.c b/cpp/ggml-aarch64.c index d5f94193..2e396a91 100644 --- a/cpp/ggml-aarch64.c +++ b/cpp/ggml-aarch64.c @@ -1,204 +1,49 @@ -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates -// SPDX-License-Identifier: MIT -// - -#define LM_GGML_COMMON_IMPL_C +#define LM_GGML_COMMON_DECL_C #include "ggml-common.h" -#include "ggml-quants.h" +#include "ggml-aarch64.h" #include "ggml-impl.h" -#include "ggml-cpu-impl.h" - -#include -#include +#include "ggml-quants.h" #include -#include -#include // for qsort -#include // for LM_GGML_ASSERT - -#include "ggml-aarch64.h" - -#if defined(__GNUC__) -#pragma GCC diagnostic ignored "-Woverlength-strings" -#elif defined(_MSC_VER) -#pragma warning(disable: 4244 4267) // possible loss of data -#endif #define UNUSED LM_GGML_UNUSED -// Functions to create the interleaved data layout formats - -// interleave 4 block_q4_0s in blocks of blck_size_interleave -// returns an interleaved block_q4_0x4 -// in the interleaved block_q4_0x4, place deltas for 4 block_q4_0 blocks -// first, then interleave quants from 4 block_q4_0s in blocks of blck_size_interleave -// -// - in : an array of block_q4_0 pointers -// - blck_size_interleave : the block_q4_0 quants bytes are interleaved in blocks of -// blck_size_interleave bytes -// - xor_mask : the mask to convert the nibbles in block_q4_0 quants bytes -// from bias offset form to pure sign form (this saves subtract -// operations durin unpacking) -// -#if defined(__AVX__) -#if defined(__F16C__) -#if defined(__AVX512F__) -#define LM_GGML_F32Cx8x2_LOAD(x, y) _mm512_cvtph_ps(_mm256_set_m128i(_mm_loadu_si128((const __m128i *)(y)), _mm_loadu_si128((const __m128i *)(x)))) -#define LM_GGML_F32Cx16_REPEAT_LOAD(x) _mm512_cvtph_ps(_mm256_set_m128i(x, x)) -#endif -// the _mm256_cvt intrinsics require F16C -#define LM_GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x))) -#define LM_GGML_F32Cx8_REPEAT_LOAD(x, loadMask) _mm256_cvtph_ps(_mm_shuffle_epi32(_mm_maskload_epi32((int const*)(x), loadMask), 68)) -#define LM_GGML_F32Cx8_REARRANGE_LOAD(x, arrangeMask) _mm256_cvtph_ps(_mm_shuffle_epi8(_mm_loadu_si128((const __m128i *) x), arrangeMask)) -#else -#if defined(__AVX512F__) -static inline __m512 __avx512_f32cx8x2_load(lm_ggml_fp16_t *x, lm_ggml_fp16_t *y) { - float tmp[16]; - - for (int i = 0; i < 8; i++) { - tmp[i] = LM_GGML_FP16_TO_FP32(x[i]); - } - - for (int i = 0; i < 8; i++) { - tmp[i + 8] = LM_GGML_FP16_TO_FP32(y[i]); - } - - return _mm512_loadu_ps(tmp); -} -static inline __m512 __avx512_repeat_f32cx16_load(__m128i x) { - float tmp[16]; - uint16_t tmphalf[8]; - _mm_storeu_si128((__m128i*)tmphalf, x); - - for (int i = 0; i < 4; i++) { - tmp[i] = LM_GGML_FP16_TO_FP32(tmphalf[i]); - tmp[i + 4] = LM_GGML_FP16_TO_FP32(tmphalf[i]); - tmp[i + 8] = LM_GGML_FP16_TO_FP32(tmphalf[i]); - tmp[i + 12] = LM_GGML_FP16_TO_FP32(tmphalf[i]); - } - - return _mm512_loadu_ps(tmp); -} -#endif -static inline __m256 __avx_f32cx8_load(lm_ggml_fp16_t *x) { - float tmp[8]; - - for (int i = 0; i < 8; i++) { - tmp[i] = LM_GGML_FP16_TO_FP32(x[i]); - } - - return _mm256_loadu_ps(tmp); -} -static inline __m256 __avx_repeat_f32cx8_load(lm_ggml_fp16_t *x) { - float tmp[8]; - - for (int i = 0; i < 4; i++) { - tmp[i] = LM_GGML_FP16_TO_FP32(x[i]); - tmp[i + 4] = LM_GGML_FP16_TO_FP32(x[i]); - } - - return _mm256_loadu_ps(tmp); -} -static inline __m256 __avx_rearranged_f32cx8_load(lm_ggml_fp16_t *x, __m128i arrangeMask) { - uint16_t tmphalf[8]; - float tmp[8]; - - _mm_storeu_si128((__m128i*)tmphalf, _mm_shuffle_epi8(_mm_loadu_si128((const __m128i *) x), arrangeMask)); - for (int i = 0; i < 8; i++) { - tmp[i] = LM_GGML_FP16_TO_FP32(tmphalf[i]); - } - - return _mm256_loadu_ps(tmp); -} - -#define LM_GGML_F32Cx8_LOAD(x) __avx_f32cx8_load(x) -#define LM_GGML_F32Cx8_REPEAT_LOAD(x, loadMask) __avx_repeat_f32cx8_load(x) -#define LM_GGML_F32Cx8_REARRANGE_LOAD(x, arrangeMask) __avx_rearranged_f32cx8_load(x, arrangeMask) -#if defined(__AVX512F__) -#define LM_GGML_F32Cx8x2_LOAD(x, y) __avx512_f32cx8x2_load(x, y) -#define LM_GGML_F32Cx16_REPEAT_LOAD(x) __avx512_repeat_f32cx16_load(x) -#endif -#endif -#endif - - -#if defined(__AVX2__) || defined(__AVX512F__) -#if defined(__AVX512F__) -// add int16_t pairwise and return as 512 bit int vector -static inline __m512i sum_i16_pairs_int_32x16(const __m512i x) { - const __m512i ones = _mm512_set1_epi16(1); - return _mm512_madd_epi16(ones, x); -} - -static inline __m512i mul_sum_us8_pairs_int32x16(const __m512i ax, const __m512i sy) { -#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) - const __m512i zero = _mm512_setzero_si512(); - return _mm512_dpbusd_epi32(zero, ax, sy); -#else - // Perform multiplication and create 16-bit values - const __m512i dot = _mm512_maddubs_epi16(ax, sy); - return sum_i16_pairs_int_32x16(dot); -#endif -} - -// multiply int8_t, add results pairwise twice and return as 512 bit int vector -static inline __m512i mul_sum_i8_pairs_int32x16(const __m512i x, const __m512i y) { - const __m512i zero = _mm512_setzero_si512(); - // Get absolute values of x vectors - const __m512i ax = _mm512_abs_epi8(x); - // Sign the values of the y vectors - __mmask64 blt0 = _mm512_movepi8_mask(x); - const __m512i sy = _mm512_mask_sub_epi8(y, blt0, zero, y); - return mul_sum_us8_pairs_int32x16(ax, sy); -} -#endif - -// add int16_t pairwise and return as 256 bit int vector -static inline __m256i sum_i16_pairs_int32x8(const __m256i x) { - const __m256i ones = _mm256_set1_epi16(1); - return _mm256_madd_epi16(ones, x); -} - -static inline __m256i mul_sum_us8_pairs_int32x8(const __m256i ax, const __m256i sy) { -#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) - const __m256i zero = _mm256_setzero_si256(); - return _mm256_dpbusd_epi32(zero, ax, sy); -#else - // Perform multiplication and create 16-bit values - const __m256i dot = _mm256_maddubs_epi16(ax, sy); - return sum_i16_pairs_int32x8(dot); -#endif -} - -// Integer variant of the function defined in ggml-quants.c -// multiply int8_t, add results pairwise twice and return as 256 bit int vector -static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y) { -#if __AVXVNNIINT8__ - const __m256i zero = _mm256_setzero_si256(); - return _mm256_dpbssd_epi32(zero, x, y); -#else - // Get absolute values of x vectors - const __m256i ax = _mm256_sign_epi8(x, x); - // Sign the values of the y vectors - const __m256i sy = _mm256_sign_epi8(y, x); - return mul_sum_us8_pairs_int32x8(ax, sy); -#endif -} -#endif - -static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) { +static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) { block_q4_0x4 out; for (int i = 0; i < 4; i++) { out.d[i] = in[i].d; } - for (int i = 0; i < QK4_0 * 2; i++) { - int src_offset = (i / (4 * blck_size_interleave)) * blck_size_interleave; - int src_id = (i % (4 * blck_size_interleave)) / blck_size_interleave; - src_offset += (i % blck_size_interleave); + const int end = QK4_0 * 2 / blck_size_interleave; + + if (blck_size_interleave == 8) { + const uint64_t xor_mask = 0x8888888888888888ULL; + for (int i = 0; i < end; ++i) { + int src_id = i % 4; + int src_offset = (i / 4) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + uint64_t elems; + // Using memcpy to avoid unaligned memory accesses + memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t)); + elems ^= xor_mask; + memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t)); + } + } else if (blck_size_interleave == 4) { + const uint32_t xor_mask = 0x88888888; + for (int i = 0; i < end; ++i) { + int src_id = i % 4; + int src_offset = (i / 4) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; - out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask; + uint32_t elems; + memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint32_t)); + elems ^= xor_mask; + memcpy(&out.qs[dst_offset], &elems, sizeof(uint32_t)); + } + } else { + LM_GGML_ASSERT(false); } return out; @@ -208,343 +53,28 @@ static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_in // returns an interleaved block_q4_0x8 // in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks // first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave -static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) { +static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave) { block_q4_0x8 out; for (int i = 0; i < 8; i++) { out.d[i] = in[i].d; } - for (int i = 0; i < QK4_0 * 4; i++) { - int src_offset = (i / (8 * blck_size_interleave)) * blck_size_interleave; - int src_id = (i % (8 * blck_size_interleave)) / blck_size_interleave; - src_offset += (i % blck_size_interleave); - - out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask; - } - - return out; -} - -void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int64_t k) { - assert(QK8_0 == 32); - assert(k % QK8_0 == 0); - const int nb = k / QK8_0; - - block_q8_0x4 * restrict y = (block_q8_0x4 *) vy; - -#if defined(__ARM_NEON) - float32x4_t srcv[4][8]; - float id[4]; - - for (int i = 0; i < nb; i++) { - float32x4_t asrcv[8]; - float32x4_t amaxv[8]; - - for (int row_iter = 0; row_iter < 4; row_iter++) { - for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j); - for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]); - - for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]); - for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]); - for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]); - - const float amax = vmaxvq_f32(amaxv[0]); + const int end = QK4_0 * 4 / blck_size_interleave; + const uint64_t xor_mask = 0x8888888888888888ULL; - const float d = amax / ((1 << 7) - 1); - id[row_iter] = d ? 1.0f / d : 0.0f; + for (int i = 0; i < end; ++i) { + int src_id = i % 8; + int src_offset = (i / 8) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; - y[i].d[row_iter] = LM_GGML_FP32_TO_FP16(d); - } - - for (int j = 0; j < 8; j++) { - float32x4_t v = vmulq_n_f32(srcv[0][j], id[0]); - int32x4_t vi = vcvtnq_s32_f32(v); - y[i].qs[16 * j + 0] = vgetq_lane_s32(vi, 0); - y[i].qs[16 * j + 1] = vgetq_lane_s32(vi, 1); - y[i].qs[16 * j + 2] = vgetq_lane_s32(vi, 2); - y[i].qs[16 * j + 3] = vgetq_lane_s32(vi, 3); - - v = vmulq_n_f32(srcv[1][j], id[1]); - vi = vcvtnq_s32_f32(v); - y[i].qs[16 * j + 4] = vgetq_lane_s32(vi, 0); - y[i].qs[16 * j + 5] = vgetq_lane_s32(vi, 1); - y[i].qs[16 * j + 6] = vgetq_lane_s32(vi, 2); - y[i].qs[16 * j + 7] = vgetq_lane_s32(vi, 3); - - v = vmulq_n_f32(srcv[2][j], id[2]); - vi = vcvtnq_s32_f32(v); - y[i].qs[16 * j + 8] = vgetq_lane_s32(vi, 0); - y[i].qs[16 * j + 9] = vgetq_lane_s32(vi, 1); - y[i].qs[16 * j + 10] = vgetq_lane_s32(vi, 2); - y[i].qs[16 * j + 11] = vgetq_lane_s32(vi, 3); - - v = vmulq_n_f32(srcv[3][j], id[3]); - vi = vcvtnq_s32_f32(v); - y[i].qs[16 * j + 12] = vgetq_lane_s32(vi, 0); - y[i].qs[16 * j + 13] = vgetq_lane_s32(vi, 1); - y[i].qs[16 * j + 14] = vgetq_lane_s32(vi, 2); - y[i].qs[16 * j + 15] = vgetq_lane_s32(vi, 3); - } + uint64_t elems; + memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t)); + elems ^= xor_mask; + memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t)); } -#else - // scalar - const int blck_size_interleave = 4; - float srcv[4][QK8_0]; - float id[4]; - - for (int i = 0; i < nb; i++) { - for (int row_iter = 0; row_iter < 4; row_iter++) { - float amax = 0.0f; // absolute max - - for (int j = 0; j < QK8_0; j++) { - srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j]; - amax = MAX(amax, fabsf(srcv[row_iter][j])); - } - - const float d = amax / ((1 << 7) - 1); - id[row_iter] = d ? 1.0f / d : 0.0f; - - y[i].d[row_iter] = LM_GGML_FP32_TO_FP16(d); - } - - for (int j = 0; j < QK8_0 * 4; j++) { - int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave; - int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; - src_offset += (j % blck_size_interleave); - - float x0 = srcv[src_id][src_offset] * id[src_id]; - y[i].qs[j] = roundf(x0); - } - } -#endif -} - -void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int64_t k) { - assert(QK8_0 == 32); - assert(k % QK8_0 == 0); - const int nb = k / QK8_0; - - block_q8_0x4 * restrict y = (block_q8_0x4 *) vy; - -#if defined(__ARM_NEON) - float32x4_t srcv[4][8]; - float id[4]; - - for (int i = 0; i < nb; i++) { - float32x4_t asrcv[8]; - float32x4_t amaxv[8]; - - for (int row_iter = 0; row_iter < 4; row_iter++) { - for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j); - for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]); - - for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]); - for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]); - for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]); - - const float amax = vmaxvq_f32(amaxv[0]); - - const float d = amax / ((1 << 7) - 1); - id[row_iter] = d ? 1.0f / d : 0.0f; - - y[i].d[row_iter] = LM_GGML_FP32_TO_FP16(d); - } - - for (int j = 0; j < 4; j++) { - float32x4_t v = vmulq_n_f32(srcv[0][2 * j], id[0]); - int32x4_t vi = vcvtnq_s32_f32(v); - y[i].qs[32 * j + 0] = vgetq_lane_s32(vi, 0); - y[i].qs[32 * j + 1] = vgetq_lane_s32(vi, 1); - y[i].qs[32 * j + 2] = vgetq_lane_s32(vi, 2); - y[i].qs[32 * j + 3] = vgetq_lane_s32(vi, 3); - v = vmulq_n_f32(srcv[0][2 * j + 1], id[0]); - vi = vcvtnq_s32_f32(v); - y[i].qs[32 * j + 4] = vgetq_lane_s32(vi, 0); - y[i].qs[32 * j + 5] = vgetq_lane_s32(vi, 1); - y[i].qs[32 * j + 6] = vgetq_lane_s32(vi, 2); - y[i].qs[32 * j + 7] = vgetq_lane_s32(vi, 3); - - v = vmulq_n_f32(srcv[1][2 * j], id[1]); - vi = vcvtnq_s32_f32(v); - y[i].qs[32 * j + 8] = vgetq_lane_s32(vi, 0); - y[i].qs[32 * j + 9] = vgetq_lane_s32(vi, 1); - y[i].qs[32 * j + 10] = vgetq_lane_s32(vi, 2); - y[i].qs[32 * j + 11] = vgetq_lane_s32(vi, 3); - v = vmulq_n_f32(srcv[1][2 * j + 1], id[1]); - vi = vcvtnq_s32_f32(v); - y[i].qs[32 * j + 12] = vgetq_lane_s32(vi, 0); - y[i].qs[32 * j + 13] = vgetq_lane_s32(vi, 1); - y[i].qs[32 * j + 14] = vgetq_lane_s32(vi, 2); - y[i].qs[32 * j + 15] = vgetq_lane_s32(vi, 3); - - v = vmulq_n_f32(srcv[2][2 * j], id[2]); - vi = vcvtnq_s32_f32(v); - y[i].qs[32 * j + 16] = vgetq_lane_s32(vi, 0); - y[i].qs[32 * j + 17] = vgetq_lane_s32(vi, 1); - y[i].qs[32 * j + 18] = vgetq_lane_s32(vi, 2); - y[i].qs[32 * j + 19] = vgetq_lane_s32(vi, 3); - v = vmulq_n_f32(srcv[2][2 * j + 1], id[2]); - vi = vcvtnq_s32_f32(v); - y[i].qs[32 * j + 20] = vgetq_lane_s32(vi, 0); - y[i].qs[32 * j + 21] = vgetq_lane_s32(vi, 1); - y[i].qs[32 * j + 22] = vgetq_lane_s32(vi, 2); - y[i].qs[32 * j + 23] = vgetq_lane_s32(vi, 3); - - v = vmulq_n_f32(srcv[3][2 * j], id[3]); - vi = vcvtnq_s32_f32(v); - y[i].qs[32 * j + 24] = vgetq_lane_s32(vi, 0); - y[i].qs[32 * j + 25] = vgetq_lane_s32(vi, 1); - y[i].qs[32 * j + 26] = vgetq_lane_s32(vi, 2); - y[i].qs[32 * j + 27] = vgetq_lane_s32(vi, 3); - v = vmulq_n_f32(srcv[3][2 * j + 1], id[3]); - vi = vcvtnq_s32_f32(v); - y[i].qs[32 * j + 28] = vgetq_lane_s32(vi, 0); - y[i].qs[32 * j + 29] = vgetq_lane_s32(vi, 1); - y[i].qs[32 * j + 30] = vgetq_lane_s32(vi, 2); - y[i].qs[32 * j + 31] = vgetq_lane_s32(vi, 3); - } - } -#elif defined(__AVX2__) || defined(__AVX__) - float id[4]; - __m256 srcv[4][4]; - __m256 idvec[4]; - - for (int i = 0; i < nb; i++) { - for (int row_iter = 0; row_iter < 4; row_iter++) { - // Load elements into 4 AVX vectors - __m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 32 ); - __m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 32 + 8 ); - __m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 32 + 16 ); - __m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 32 + 24 ); - - // Compute max(abs(e)) for the block - const __m256 signBit = _mm256_set1_ps( -0.0f ); - __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); - - __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); - max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); - max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); - const float maxScalar = _mm_cvtss_f32( max4 ); - - // Divided by 127.f to mirror results in quantize_row_q8_0 - const float d = maxScalar / 127.f; - id[row_iter] = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; //d ? 1.0f / d : 0.0f; - - // Store the scale for the individual block - y[i].d[row_iter] = LM_GGML_FP32_TO_FP16(d); - - // Store the values in blocks of eight values - Aim is to use these later for block interleaving - srcv[row_iter][0] = v0; - srcv[row_iter][1] = v1; - srcv[row_iter][2] = v2; - srcv[row_iter][3] = v3; - idvec[row_iter] = _mm256_set1_ps(id[row_iter]); - } - - // The loop iterates four times - The aim is to get 4 corresponding chunks of eight bytes from the original weight blocks that are interleaved - for (int j = 0; j < 4; j++) { - // Apply the multiplier - __m256 v0 = _mm256_mul_ps(srcv[0][j], idvec[0]); - __m256 v1 = _mm256_mul_ps(srcv[1][j], idvec[1]); - __m256 v2 = _mm256_mul_ps(srcv[2][j], idvec[2]); - __m256 v3 = _mm256_mul_ps(srcv[3][j], idvec[3]); - // Round to nearest integer - v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); - v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); - v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); - v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); - - // Convert floats to integers - __m256i i0 = _mm256_cvtps_epi32( v0 ); - __m256i i1 = _mm256_cvtps_epi32( v1 ); - __m256i i2 = _mm256_cvtps_epi32( v2 ); - __m256i i3 = _mm256_cvtps_epi32( v3 ); - -#if defined(__AVX2__) - // Convert int32 to int16 - i0 = _mm256_packs_epi32( i0, i1 ); - i2 = _mm256_packs_epi32( i2, i3 ); - // Convert int16 to int8 - i0 = _mm256_packs_epi16( i0, i2 ); - - // Permute and store the quantized weights in the required order after the pack instruction - const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); - i0 = _mm256_permutevar8x32_epi32( i0, perm ); - - _mm256_storeu_si256((__m256i *)(y[i].qs + 32 * j), i0); -#else - // Since we don't have in AVX some necessary functions, - // we split the registers in half and call AVX2 analogs from SSE - __m128i ni0 = _mm256_castsi256_si128( i0 ); - __m128i ni1 = _mm256_extractf128_si256( i0, 1); - __m128i ni2 = _mm256_castsi256_si128( i1 ); - __m128i ni3 = _mm256_extractf128_si256( i1, 1); - __m128i ni4 = _mm256_castsi256_si128( i2 ); - __m128i ni5 = _mm256_extractf128_si256( i2, 1); - __m128i ni6 = _mm256_castsi256_si128( i3 ); - __m128i ni7 = _mm256_extractf128_si256( i3, 1); - - // Convert int32 to int16 - ni0 = _mm_packs_epi32( ni0, ni1 ); - ni2 = _mm_packs_epi32( ni2, ni3 ); - ni4 = _mm_packs_epi32( ni4, ni5 ); - ni6 = _mm_packs_epi32( ni6, ni7 ); - // Convert int16 to int8 - ni0 = _mm_packs_epi16( ni0, ni2 ); - ni4 = _mm_packs_epi16( ni4, ni6 ); - _mm_storeu_si128((__m128i *)(y[i].qs + 32 * j), ni0); - _mm_storeu_si128((__m128i *)(y[i].qs + 32 * j + 16), ni4); -#endif - } - } -#else - // scalar - const int blck_size_interleave = 8; - float srcv[4][QK8_0]; - float id[4]; - - for (int i = 0; i < nb; i++) { - for (int row_iter = 0; row_iter < 4; row_iter++) { - float amax = 0.0f; // absolute max - - for (int j = 0; j < QK8_0; j++) { - srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j]; - amax = MAX(amax, fabsf(srcv[row_iter][j])); - } - - const float d = amax / ((1 << 7) - 1); - id[row_iter] = d ? 1.0f / d : 0.0f; - - y[i].d[row_iter] = LM_GGML_FP32_TO_FP16(d); - } - - for (int j = 0; j < QK8_0 * 4; j++) { - int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave; - int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; - src_offset += (j % blck_size_interleave); - - float x0 = srcv[src_id][src_offset] * id[src_id]; - y[i].qs[j] = roundf(x0); - } - } -#endif -} - -void quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) { - assert(nrow == 4); - UNUSED(nrow); - if (blck_size_interleave == 4) { - quantize_q8_0_4x4(x, vy, n_per_row); - } else if (blck_size_interleave == 8) { - quantize_q8_0_4x8(x, vy, n_per_row); - } else { - assert(false); - } + return out; } static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, int nrows_interleaved, int blck_size_interleave) { @@ -570,11 +100,11 @@ static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict ds } if (nrows_interleaved == 8) { - *(block_q4_0x8 *) out_ptr = make_block_q4_0x8(dst_tmp, blck_size_interleave, 0x88); + *(block_q4_0x8 *) out_ptr = make_block_q4_0x8(dst_tmp, blck_size_interleave); out_ptr = (block_q4_0x8 *) out_ptr + 1; } else if (nrows_interleaved == 4) { - *(block_q4_0x4 *) out_ptr = make_block_q4_0x4(dst_tmp, blck_size_interleave, 0x88); + *(block_q4_0x4 *) out_ptr = make_block_q4_0x4(dst_tmp, blck_size_interleave); out_ptr = (block_q4_0x4 *) out_ptr + 1; } } @@ -597,2881 +127,3 @@ size_t quantize_q4_0_8x8(const float * restrict src, void * restrict dst, int64_ UNUSED(quant_weights); return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 8, 8); } - -void lm_ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 4; - - assert (n % qk == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - -#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) - if (lm_ggml_cpu_has_neon()) { - const void * b_ptr = vx; - const void * a_ptr = vy; - float * res_ptr = s; - - __asm__ __volatile__( - "movi v31.16b, #0x4\n" - "movi v30.16b, #0xf0\n" - "add %x[b_ptr], %x[b_ptr], #0x8\n" - "1:" // Column loop - "add x22, %x[a_ptr], #0x2\n" - "movi v29.16b, #0x0\n" - "mov x21, %x[nb]\n" - "2:" // Block loop - "ldr q28, [%x[b_ptr], #0x0]\n" - "ldr q27, [x22, #0x0]\n" - "movi v26.4s, #0x0\n" - "sub x20, x22, #0x2\n" - "ldr q25, [x22, #0x10]\n" - "ldr q24, [%x[b_ptr], #0x10]\n" - "sub x21, x21, #0x1\n" - "add x22, x22, #0x22\n" - "ldr q23, [%x[b_ptr], #0x20]\n" - "ldr q22, [%x[b_ptr], #0x30]\n" - "ld1r { v21.8h }, [x20]\n" - "ldr q20, [%x[b_ptr], #-0x8]\n" - "sshl v16.16b, v28.16b, v31.16b\n" - "and v28.16b, v28.16b, v30.16b\n" - "sshl v19.16b, v24.16b, v31.16b\n" - "and v24.16b, v24.16b, v30.16b\n" - "add %x[b_ptr], %x[b_ptr], #0x48\n" - "sshl v18.16b, v23.16b, v31.16b\n" - "and v23.16b, v23.16b, v30.16b\n" - ".inst 0x4f9be21a // sdot v26.4s, v16.16b, v27.4b[0]\n" - "sshl v17.16b, v22.16b, v31.16b\n" - "and v22.16b, v22.16b, v30.16b\n" - "fcvtl v21.4s, v21.4h\n" - "fcvtl v16.4s, v20.4h\n" - ".inst 0x4f99e39a // sdot v26.4s, v28.16b, v25.4b[0]\n" - "fmul v16.4s, v16.4s, v21.4s\n" - ".inst 0x4fbbe27a // sdot v26.4s, v19.16b, v27.4b[1]\n" - ".inst 0x4fb9e31a // sdot v26.4s, v24.16b, v25.4b[1]\n" - ".inst 0x4f9bea5a // sdot v26.4s, v18.16b, v27.4b[2]\n" - ".inst 0x4f99eafa // sdot v26.4s, v23.16b, v25.4b[2]\n" - ".inst 0x4fbbea3a // sdot v26.4s, v17.16b, v27.4b[3]\n" - ".inst 0x4fb9eada // sdot v26.4s, v22.16b, v25.4b[3]\n" - "scvtf v26.4s, v26.4s, #0x4\n" - "fmla v29.4s, v26.4s, v16.4s\n" - "cbnz x21, 2b\n" - "sub %x[nc], %x[nc], #0x4\n" - "str q29, [%x[res_ptr], #0x0]\n" - "add %x[res_ptr], %x[res_ptr], #0x10\n" - "cbnz %x[nc], 1b\n" - : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc) - : [a_ptr] "r" (a_ptr), [nb] "r" (nb) - : "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22" - ); - return; - } -#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) - float sumf[4]; - int sumi; - - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); - - for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); - sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; - } - sumf[j] += sumi * LM_GGML_FP16_TO_FP32(b_ptr[l].d[j]) * LM_GGML_FP16_TO_FP32(a_ptr[l].d); - } - } - } - for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; - } -} - -void lm_ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 8; - - assert (n % qk == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - -#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) - if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_matmul_int8()) { - const void * b_ptr = vx; - const void * a_ptr = vy; - float * res_ptr = s; - - __asm__ __volatile__( - "movi v2.16b, #0x4\n" - "movi v1.16b, #0xf0\n" - "add %x[b_ptr], %x[b_ptr], #0x8\n" - "1:" // Column loop - "add x23, %x[a_ptr], #0x2\n" - "movi v0.16b, #0x0\n" - "mov x22, %x[nb]\n" - "2:" // Block loop - "ldr q31, [%x[b_ptr], #0x0]\n" - "ldr q30, [%x[b_ptr], #0x10]\n" - "mov x21, x23\n" - "movi v29.4s, #0x0\n" - "ldr q28, [%x[b_ptr], #0x20]\n" - "ldr q27, [%x[b_ptr], #0x30]\n" - "movi v26.4s, #0x0\n" - "sub x20, x23, #0x2\n" - "ld1r { v25.8h }, [x20]\n" - "ldr q24, [%x[b_ptr], #-0x8]\n" - "sub x22, x22, #0x1\n" - "add x23, x23, #0x22\n" - "ld1r { v23.2d }, [x21], #0x8\n" - "sshl v22.16b, v31.16b, v2.16b\n" - "sshl v16.16b, v30.16b, v2.16b\n" - "add %x[b_ptr], %x[b_ptr], #0x48\n" - "ld1r { v21.2d }, [x21], #0x8\n" - "sshl v20.16b, v28.16b, v2.16b\n" - "sshl v19.16b, v27.16b, v2.16b\n" - "ld1r { v18.2d }, [x21], #0x8\n" - "ld1r { v17.2d }, [x21], #0x8\n" - "and v31.16b, v31.16b, v1.16b\n" - "and v30.16b, v30.16b, v1.16b\n" - ".inst 0x4e9796dd // sdot v29.4s, v22.16b, v23.16b\n" - ".inst 0x4e97961a // sdot v26.4s, v16.16b, v23.16b\n" - "and v28.16b, v28.16b, v1.16b\n" - "and v27.16b, v27.16b, v1.16b\n" - "fcvtl v25.4s, v25.4h\n" - "fcvtl v16.4s, v24.4h\n" - ".inst 0x4e95969d // sdot v29.4s, v20.16b, v21.16b\n" - ".inst 0x4e95967a // sdot v26.4s, v19.16b, v21.16b\n" - "fmul v16.4s, v16.4s, v25.4s\n" - ".inst 0x4e9297fd // sdot v29.4s, v31.16b, v18.16b\n" - ".inst 0x4e9297da // sdot v26.4s, v30.16b, v18.16b\n" - ".inst 0x4e91979d // sdot v29.4s, v28.16b, v17.16b\n" - ".inst 0x4e91977a // sdot v26.4s, v27.16b, v17.16b\n" - "addp v29.4s, v29.4s, v26.4s\n" - "scvtf v29.4s, v29.4s, #0x4\n" - "fmla v0.4s, v29.4s, v16.4s\n" - "cbnz x22, 2b\n" - "sub %x[nc], %x[nc], #0x4\n" - "str q0, [%x[res_ptr], #0x0]\n" - "add %x[res_ptr], %x[res_ptr], #0x10\n" - "cbnz %x[nc], 1b\n" - : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc) - : [a_ptr] "r" (a_ptr), [nb] "r" (nb) - : "memory", "v0", "v1", "v2", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23" - ); - return; - } -#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) - float sumf[4]; - int sumi; - - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); - - for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); - sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; - } - sumf[j] += sumi * LM_GGML_FP16_TO_FP32(b_ptr[l].d[j]) * LM_GGML_FP16_TO_FP32(a_ptr[l].d); - } - } - } - for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; - } -} - -void lm_ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 8; - const int blocklen = 8; - - assert (n % qk == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - -#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) -#if defined(__ARM_FEATURE_SVE) - if (lm_ggml_cpu_has_sve() && lm_ggml_cpu_get_sve_cnt() == QK8_0) { - const void * b_ptr = vx; - const void * a_ptr = vy; - float * res_ptr = s; - - __asm__ __volatile__( - "ptrue p0.b\n" - "add %x[b_ptr], %x[b_ptr], #0x10\n" - "1:" // Column loop - "add x22, %x[a_ptr], #0x2\n" - "mov z31.b, #0x0\n" - "mov x21, %x[nb]\n" - "2:" // Block loop - "ld1b { z30.b }, p0/Z, [%x[b_ptr]]\n" - "ld1b { z29.b }, p0/Z, [%x[b_ptr], #1, MUL VL]\n" - "mov z28.s, #0x0\n" - "mov z27.s, #0x0\n" - "ld1rd { z26.d }, p0/Z, [x22]\n" - "ld1b { z25.b }, p0/Z, [%x[b_ptr], #2, MUL VL]\n" - "sub x20, x22, #0x2\n" - "sub x21, x21, #0x1\n" - "ld1b { z24.b }, p0/Z, [%x[b_ptr], #3, MUL VL]\n" - "ld1rd { z23.d }, p0/Z, [x22, #8]\n" - "lsl z22.b, z30.b, #0x4\n" - "lsl z16.b, z29.b, #0x4\n" - "and z30.b, z30.b, #0xf0\n" - "and z29.b, z29.b, #0xf0\n" - "ld1rd { z21.d }, p0/Z, [x22, #16]\n" - "ld1rd { z20.d }, p0/Z, [x22, #24]\n" - "lsl z19.b, z25.b, #0x4\n" - "and z25.b, z25.b, #0xf0\n" - "ld1rh { z17.h }, p0/Z, [x20]\n" - "ld1h { z18.s }, p0/Z, [%x[b_ptr], #-1, MUL VL]\n" - "sdot z28.s, z22.b, z26.b\n" - "sdot z27.s, z16.b, z26.b\n" - "lsl z16.b, z24.b, #0x4\n" - "add x22, x22, #0x22\n" - "and z24.b, z24.b, #0xf0\n" - "add %x[b_ptr], %x[b_ptr], #0x90\n" - "fcvt z17.s, p0/m, z17.h\n" - "fcvt z18.s, p0/m, z18.h\n" - "sdot z28.s, z19.b, z23.b\n" - "sdot z27.s, z16.b, z23.b\n" - "fmul z18.s, z18.s, z17.s\n" - "sdot z28.s, z30.b, z21.b\n" - "sdot z27.s, z29.b, z21.b\n" - "sdot z28.s, z25.b, z20.b\n" - "sdot z27.s, z24.b, z20.b\n" - "uzp1 z17.s, z28.s, z27.s\n" - "uzp2 z16.s, z28.s, z27.s\n" - "add z17.s, z17.s, z16.s\n" - "asr z17.s, z17.s, #0x4\n" - "scvtf z17.s, p0/m, z17.s\n" - "fmla z31.s, p0/M, z17.s, z18.s\n" - "cbnz x21, 2b\n" - "sub %x[nc], %x[nc], #0x8\n" - "st1w { z31.s }, p0, [%x[res_ptr]]\n" - "add %x[res_ptr], %x[res_ptr], #0x20\n" - "cbnz %x[nc], 1b\n" - : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc) - : [a_ptr] "r" (a_ptr), [nb] "r" (nb) - : "memory", "p0", "x20", "x21", "x22", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" - ); - return; - } -#endif // #if defined(__ARM_FEATURE_SVE) -#elif defined(__AVX2__) - // Lookup table to convert signed nibbles to signed bytes - __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0)); - signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0); - __m128i changemask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0); - __m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0); - - // Permute mask used for easier vector processing at later stages - const __m256i m4b = _mm256_set1_epi8(0x0F); - - int64_t b_nb = n / QK4_0; - - const block_q4_0x8 * b_ptr_start = (const block_q4_0x8 *)vx; - const block_q8_0 * a_ptr_start = (const block_q8_0 *)vy; - - // Process Q8_0 blocks one by one - for (int64_t y = 0; y < nr; y++) { - - // Pointers to LHS blocks of block_q8_0 format - const block_q8_0 * a_ptr = a_ptr_start + (y * nb); - - // Take group of eight block_q4_0x8 structures at each pass of the loop and perform dot product operation - for (int64_t x = 0; x < nc / 8; x++) { - - // Pointers to RHS blocks - const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb); - - // Master FP accumulator - __m256 acc_row = _mm256_setzero_ps(); - - for (int64_t b = 0; b < nb; b++) { - // Load 8 blocks of Q4_0 interleaved as 8 bytes (B0 - B7) - const __m256i rhs_raw_vec_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs)); - const __m256i rhs_raw_vec_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 1); - const __m256i rhs_raw_vec_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 2); - const __m256i rhs_raw_vec_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 3); - - // 4-bit -> 8-bit - Sign is maintained - const __m256i rhs_vec_0123_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_0123_0, m4b)); // B0(0-7) B1(0-7) B2(0-7) B3(0-7) - const __m256i rhs_vec_4567_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_4567_0, m4b)); // B4(0-7) B5(0-7) B6(0-7) B7(0-7) - const __m256i rhs_vec_0123_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_0123_1, m4b)); // B0(8-15) B1(8-15) B2(8-15) B3(8-15) - const __m256i rhs_vec_4567_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_4567_1, m4b)); // B0(8-15) B1(8-15) B2(8-15) B3(8-15) - - const __m256i rhs_vec_0123_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 4), m4b)); // B0(16-23) B1(16-23) B2(16-23) B3(16-23) - const __m256i rhs_vec_4567_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 4), m4b)); // B4(16-23) B5(16-23) B6(16-23) B7(16-23) - const __m256i rhs_vec_0123_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 4), m4b)); // B0(24-31) B1(24-31) B2(24-31) B3(24-31) - const __m256i rhs_vec_4567_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 4), m4b)); // B4(24-31) B5(24-31) B6(24-31) B7(24-31) - - // Load the scale values for the 8 blocks interleaved in block_q4_0x8 - const __m256 col_scale_f32 = LM_GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, changemask); - - // Load and convert to FP32 scale from block_q8_0 - const __m256 row_scale_f32 = _mm256_set1_ps(LM_GGML_FP16_TO_FP32(a_ptr[b].d)); - - // Load the block values in block_q8_0 in batches of 16 bytes and replicate the same across 256 bit vector - __m256i lhs_vec_0 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)a_ptr[b].qs)); - __m256i lhs_vec_1 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 16))); - - lhs_vec_0 = _mm256_permute2f128_si256(lhs_vec_0, lhs_vec_0, 0); // A0 (0-15) A0(0-15) - lhs_vec_1 = _mm256_permute2f128_si256(lhs_vec_1, lhs_vec_1, 0); // A0 (16-31) A0(16-31)) - - __m256i iacc = _mm256_setzero_si256(); - - // Dot product done within 32 bit lanes and accumulated in the same vector - // B0(0-3) B4(0-3) B1(0-3) B5(0-3) B2(0-3) B6(0-3) B3(0-3) B7(0-3) with A0(0-3) - // B0(4-7) B4(4-7) B1(4-7) B5(4-7) B2(4-7) B6(4-7) B3(4-7) B7(4-7) with A0(4-7) - // ........................................................................... - // B0(28-31) B4(28-31) B1(28-31) B5(28-31) B2(28-31) B6(28-31) B3(28-31) B7(28-31) with A0(28-31) - - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_0 ,_mm256_shuffle_epi32(rhs_vec_4567_0, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 0))); - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_0, 177) ,rhs_vec_4567_0, 170), _mm256_shuffle_epi32(lhs_vec_0, 85))); - - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_1 ,_mm256_shuffle_epi32(rhs_vec_4567_1, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 170))); - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_1, 177) ,rhs_vec_4567_1, 170), _mm256_shuffle_epi32(lhs_vec_0, 255))); - - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_2 ,_mm256_shuffle_epi32(rhs_vec_4567_2, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 0))); - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_2, 177) ,rhs_vec_4567_2, 170), _mm256_shuffle_epi32(lhs_vec_1, 85))); - - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_3 ,_mm256_shuffle_epi32(rhs_vec_4567_3, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 170))); - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_3, 177) ,rhs_vec_4567_3, 170), _mm256_shuffle_epi32(lhs_vec_1, 255))); - - // Accumulated values multipled with appropriate scales - acc_row = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc), _mm256_mul_ps(col_scale_f32, row_scale_f32), acc_row); - } - - // Accumulated output values permuted so as to be stored in appropriate order post accumulation - acc_row = _mm256_permutevar8x32_ps(acc_row, finalpermutemask); - _mm256_storeu_ps(s + (y * nr + x * 8), acc_row); - } - } - return; -#elif defined(__riscv_v_intrinsic) - if (__riscv_vlenb() >= QK4_0) { - const size_t vl = QK4_0; - - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); - - vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); - for (int l = 0; l < nb; l++) { - const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[0]; - const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[8]; - const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[16]; - const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[24]; - __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment - const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, vl / 4)); - const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, vl / 4)); - const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, vl / 4)); - const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, vl / 4)); - - const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4); - const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4); - const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4); - const vint8m2_t rhs_vec_lo_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 0); - const vint8m2_t rhs_vec_lo_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 1); - const vint8m2_t rhs_vec_hi_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 0); - const vint8m2_t rhs_vec_hi_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 1); - - const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); - const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); - const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); - const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); - - const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_hi_m)); - const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); - const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); - const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); - const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); - const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); - const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); - const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); - const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); - const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); - const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); - const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); - const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); - - // vector version needs Zvfhmin extension - const float a_scale = LM_GGML_FP16_TO_FP32(a_ptr[l].d); - const float b_scales[8] = { - LM_GGML_FP16_TO_FP32(b_ptr[l].d[0]), - LM_GGML_FP16_TO_FP32(b_ptr[l].d[1]), - LM_GGML_FP16_TO_FP32(b_ptr[l].d[2]), - LM_GGML_FP16_TO_FP32(b_ptr[l].d[3]), - LM_GGML_FP16_TO_FP32(b_ptr[l].d[4]), - LM_GGML_FP16_TO_FP32(b_ptr[l].d[5]), - LM_GGML_FP16_TO_FP32(b_ptr[l].d[6]), - LM_GGML_FP16_TO_FP32(b_ptr[l].d[7]) - }; - const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4); - const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scale, vl / 4); - sumf = __riscv_vfmacc_vv_f32m1(sumf, tmp1, b_scales_vec, vl / 4); - } - __riscv_vse32_v_f32m1(s + x * ncols_interleaved, sumf, vl / 4); - } - return; - } -#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) - { - float sumf[8]; - int sumi; - - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); - - for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); - sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; - } - sumf[j] += sumi * LM_GGML_FP16_TO_FP32(b_ptr[l].d[j]) * LM_GGML_FP16_TO_FP32(a_ptr[l].d); - } - } - } - for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; - } - } -} - -void lm_ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 4; - - assert (n % qk == 0); - assert (nr % 4 == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - -#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) - if (lm_ggml_cpu_has_neon()) { - const void * b_ptr = vx; - const void * a_ptr = vy; - float * res_ptr = s; - size_t res_stride = bs * sizeof(float); - - __asm__ __volatile__( - "mov x10, %x[nr]\n" - "mov x9, #0x88\n" - "cmp x10, #0x10\n" - "mul x9, %x[nb], x9\n" - "blt 4f\n" - "1:" // Row loop - "add x28, %x[b_ptr], #0x8\n" - "mov x27, %x[nc]\n" - "add x26, %x[res_ptr], %x[res_stride], LSL #4\n" - "2:" // Column loop - "add x25, %x[a_ptr], #0x8\n" - "movi v15.16b, #0x0\n" - "movi v19.16b, #0x0\n" - "mov x24, %x[nb]\n" - "add x23, x25, x9\n" - "movi v18.16b, #0x0\n" - "movi v14.16b, #0x0\n" - "add x22, x23, x9\n" - "movi v11.16b, #0x0\n" - "movi v13.16b, #0x0\n" - "add x21, x22, x9\n" - "movi v23.16b, #0x0\n" - "movi v16.16b, #0x0\n" - "movi v25.16b, #0x0\n" - "movi v7.16b, #0x0\n" - "movi v0.16b, #0x0\n" - "movi v4.16b, #0x0\n" - "movi v5.16b, #0x0\n" - "movi v21.16b, #0x0\n" - "movi v8.16b, #0x0\n" - "movi v1.16b, #0x0\n" - "3:" // Block loop - "ldr q3, [x28, #0x0]\n" - "ldr q31, [x25, #0x0]\n" - "movi v28.16b, #0x4\n" - "movi v10.4s, #0x0\n" - "ldr q22, [x28, #0x10]\n" - "ldr q6, [x25, #0x10]\n" - "movi v29.4s, #0x0\n" - "movi v9.4s, #0x0\n" - "ldr q27, [x28, #0x20]\n" - "ldr q30, [x28, #0x30]\n" - "movi v20.4s, #0x0\n" - "movi v24.16b, #0xf0\n" - "ldr d2, [x25, #-0x8]\n" - "ldr d26, [x23, #-0x8]\n" - "sshl v12.16b, v3.16b, v28.16b\n" - "sub x20, x28, #0x8\n" - "ldr d17, [x20, #0x0]\n" - "and v3.16b, v3.16b, v24.16b\n" - "subs x24, x24, #0x1\n" - "add x28, x28, #0x48\n" - ".inst 0x4f9fe18a // sdot v10.4s, v12.16b, v31.4b[0]\n" - ".inst 0x4fbfe19d // sdot v29.4s, v12.16b, v31.4b[1]\n" - ".inst 0x4f9fe989 // sdot v9.4s, v12.16b, v31.4b[2]\n" - ".inst 0x4fbfe994 // sdot v20.4s, v12.16b, v31.4b[3]\n" - "sshl v31.16b, v22.16b, v28.16b\n" - "and v22.16b, v22.16b, v24.16b\n" - "fcvtl v17.4s, v17.4h\n" - "fcvtl v2.4s, v2.4h\n" - "fcvtl v26.4s, v26.4h\n" - ".inst 0x4f86e3ea // sdot v10.4s, v31.16b, v6.4b[0]\n" - ".inst 0x4fa6e3fd // sdot v29.4s, v31.16b, v6.4b[1]\n" - ".inst 0x4f86ebe9 // sdot v9.4s, v31.16b, v6.4b[2]\n" - ".inst 0x4fa6ebf4 // sdot v20.4s, v31.16b, v6.4b[3]\n" - "sshl v6.16b, v27.16b, v28.16b\n" - "sshl v28.16b, v30.16b, v28.16b\n" - "and v27.16b, v27.16b, v24.16b\n" - "and v30.16b, v30.16b, v24.16b\n" - "ldr q24, [x25, #0x20]\n" - ".inst 0x4f98e0ca // sdot v10.4s, v6.16b, v24.4b[0]\n" - ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n" - ".inst 0x4f98e8c9 // sdot v9.4s, v6.16b, v24.4b[2]\n" - ".inst 0x4fb8e8d4 // sdot v20.4s, v6.16b, v24.4b[3]\n" - "ldr q24, [x25, #0x30]\n" - ".inst 0x4f98e38a // sdot v10.4s, v28.16b, v24.4b[0]\n" - ".inst 0x4fb8e39d // sdot v29.4s, v28.16b, v24.4b[1]\n" - ".inst 0x4f98eb89 // sdot v9.4s, v28.16b, v24.4b[2]\n" - ".inst 0x4fb8eb94 // sdot v20.4s, v28.16b, v24.4b[3]\n" - "ldr q24, [x25, #0x40]\n" - ".inst 0x4f98e06a // sdot v10.4s, v3.16b, v24.4b[0]\n" - ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n" - ".inst 0x4f98e869 // sdot v9.4s, v3.16b, v24.4b[2]\n" - ".inst 0x4fb8e874 // sdot v20.4s, v3.16b, v24.4b[3]\n" - "ldr q24, [x25, #0x50]\n" - ".inst 0x4f98e2ca // sdot v10.4s, v22.16b, v24.4b[0]\n" - ".inst 0x4fb8e2dd // sdot v29.4s, v22.16b, v24.4b[1]\n" - ".inst 0x4f98eac9 // sdot v9.4s, v22.16b, v24.4b[2]\n" - ".inst 0x4fb8ead4 // sdot v20.4s, v22.16b, v24.4b[3]\n" - "ldr q24, [x25, #0x60]\n" - ".inst 0x4f98e36a // sdot v10.4s, v27.16b, v24.4b[0]\n" - ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n" - ".inst 0x4f98eb69 // sdot v9.4s, v27.16b, v24.4b[2]\n" - ".inst 0x4fb8eb74 // sdot v20.4s, v27.16b, v24.4b[3]\n" - "ldr q24, [x25, #0x70]\n" - "add x25, x25, #0x88\n" - ".inst 0x4f98e3ca // sdot v10.4s, v30.16b, v24.4b[0]\n" - ".inst 0x4fb8e3dd // sdot v29.4s, v30.16b, v24.4b[1]\n" - ".inst 0x4f98ebc9 // sdot v9.4s, v30.16b, v24.4b[2]\n" - ".inst 0x4fb8ebd4 // sdot v20.4s, v30.16b, v24.4b[3]\n" - "fmul v24.4s, v17.4s, v2.s[0]\n" - "scvtf v10.4s, v10.4s, #0x4\n" - "scvtf v29.4s, v29.4s, #0x4\n" - "scvtf v9.4s, v9.4s, #0x4\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "fmla v15.4s, v10.4s, v24.4s\n" - "ldr q24, [x23, #0x0]\n" - "fmul v10.4s, v17.4s, v2.s[1]\n" - "fmla v19.4s, v29.4s, v10.4s\n" - "ldr q10, [x23, #0x10]\n" - "fmul v29.4s, v17.4s, v2.s[2]\n" - "fmul v2.4s, v17.4s, v2.s[3]\n" - "fmla v18.4s, v9.4s, v29.4s\n" - "movi v9.4s, #0x0\n" - "movi v29.4s, #0x0\n" - ".inst 0x4f98e189 // sdot v9.4s, v12.16b, v24.4b[0]\n" - ".inst 0x4fb8e19d // sdot v29.4s, v12.16b, v24.4b[1]\n" - "fmla v14.4s, v20.4s, v2.4s\n" - "movi v20.4s, #0x0\n" - "movi v2.4s, #0x0\n" - ".inst 0x4f98e994 // sdot v20.4s, v12.16b, v24.4b[2]\n" - ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n" - "ldr q24, [x23, #0x20]\n" - ".inst 0x4f8ae3e9 // sdot v9.4s, v31.16b, v10.4b[0]\n" - ".inst 0x4faae3fd // sdot v29.4s, v31.16b, v10.4b[1]\n" - ".inst 0x4f8aebf4 // sdot v20.4s, v31.16b, v10.4b[2]\n" - ".inst 0x4faaebe2 // sdot v2.4s, v31.16b, v10.4b[3]\n" - "ldr q10, [x23, #0x30]\n" - ".inst 0x4f98e0c9 // sdot v9.4s, v6.16b, v24.4b[0]\n" - ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n" - ".inst 0x4f98e8d4 // sdot v20.4s, v6.16b, v24.4b[2]\n" - ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n" - "ldr q24, [x23, #0x40]\n" - ".inst 0x4f8ae389 // sdot v9.4s, v28.16b, v10.4b[0]\n" - ".inst 0x4faae39d // sdot v29.4s, v28.16b, v10.4b[1]\n" - ".inst 0x4f8aeb94 // sdot v20.4s, v28.16b, v10.4b[2]\n" - ".inst 0x4faaeb82 // sdot v2.4s, v28.16b, v10.4b[3]\n" - "ldr q10, [x23, #0x50]\n" - ".inst 0x4f98e069 // sdot v9.4s, v3.16b, v24.4b[0]\n" - ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n" - ".inst 0x4f98e874 // sdot v20.4s, v3.16b, v24.4b[2]\n" - ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n" - "ldr q24, [x23, #0x60]\n" - ".inst 0x4f8ae2c9 // sdot v9.4s, v22.16b, v10.4b[0]\n" - ".inst 0x4faae2dd // sdot v29.4s, v22.16b, v10.4b[1]\n" - ".inst 0x4f8aead4 // sdot v20.4s, v22.16b, v10.4b[2]\n" - ".inst 0x4faaeac2 // sdot v2.4s, v22.16b, v10.4b[3]\n" - "ldr q10, [x23, #0x70]\n" - "add x23, x23, #0x88\n" - ".inst 0x4f98e369 // sdot v9.4s, v27.16b, v24.4b[0]\n" - ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n" - ".inst 0x4f98eb74 // sdot v20.4s, v27.16b, v24.4b[2]\n" - ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n" - "ldr q24, [x22, #0x0]\n" - ".inst 0x4f8ae3c9 // sdot v9.4s, v30.16b, v10.4b[0]\n" - ".inst 0x4faae3dd // sdot v29.4s, v30.16b, v10.4b[1]\n" - ".inst 0x4f8aebd4 // sdot v20.4s, v30.16b, v10.4b[2]\n" - ".inst 0x4faaebc2 // sdot v2.4s, v30.16b, v10.4b[3]\n" - "fmul v10.4s, v17.4s, v26.s[0]\n" - "scvtf v9.4s, v9.4s, #0x4\n" - "scvtf v29.4s, v29.4s, #0x4\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "scvtf v2.4s, v2.4s, #0x4\n" - "fmla v11.4s, v9.4s, v10.4s\n" - "ldr q9, [x22, #0x10]\n" - "fmul v10.4s, v17.4s, v26.s[1]\n" - "fmla v13.4s, v29.4s, v10.4s\n" - "ldr d29, [x22, #-0x8]\n" - "fmul v10.4s, v17.4s, v26.s[2]\n" - "fmul v26.4s, v17.4s, v26.s[3]\n" - "fcvtl v29.4s, v29.4h\n" - "fmla v23.4s, v20.4s, v10.4s\n" - "movi v20.4s, #0x0\n" - "movi v10.4s, #0x0\n" - "fmla v16.4s, v2.4s, v26.4s\n" - "movi v26.4s, #0x0\n" - "movi v2.4s, #0x0\n" - ".inst 0x4f98e194 // sdot v20.4s, v12.16b, v24.4b[0]\n" - ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n" - ".inst 0x4f98e99a // sdot v26.4s, v12.16b, v24.4b[2]\n" - ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n" - "ldr q24, [x22, #0x20]\n" - ".inst 0x4f89e3f4 // sdot v20.4s, v31.16b, v9.4b[0]\n" - ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n" - ".inst 0x4f89ebfa // sdot v26.4s, v31.16b, v9.4b[2]\n" - ".inst 0x4fa9ebe2 // sdot v2.4s, v31.16b, v9.4b[3]\n" - "ldr q9, [x22, #0x30]\n" - ".inst 0x4f98e0d4 // sdot v20.4s, v6.16b, v24.4b[0]\n" - ".inst 0x4fb8e0ca // sdot v10.4s, v6.16b, v24.4b[1]\n" - ".inst 0x4f98e8da // sdot v26.4s, v6.16b, v24.4b[2]\n" - ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n" - "ldr q24, [x22, #0x40]\n" - ".inst 0x4f89e394 // sdot v20.4s, v28.16b, v9.4b[0]\n" - ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n" - ".inst 0x4f89eb9a // sdot v26.4s, v28.16b, v9.4b[2]\n" - ".inst 0x4fa9eb82 // sdot v2.4s, v28.16b, v9.4b[3]\n" - "ldr q9, [x22, #0x50]\n" - ".inst 0x4f98e074 // sdot v20.4s, v3.16b, v24.4b[0]\n" - ".inst 0x4fb8e06a // sdot v10.4s, v3.16b, v24.4b[1]\n" - ".inst 0x4f98e87a // sdot v26.4s, v3.16b, v24.4b[2]\n" - ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n" - "ldr q24, [x22, #0x60]\n" - ".inst 0x4f89e2d4 // sdot v20.4s, v22.16b, v9.4b[0]\n" - ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n" - ".inst 0x4f89eada // sdot v26.4s, v22.16b, v9.4b[2]\n" - ".inst 0x4fa9eac2 // sdot v2.4s, v22.16b, v9.4b[3]\n" - "ldr q9, [x22, #0x70]\n" - "add x22, x22, #0x88\n" - ".inst 0x4f98e374 // sdot v20.4s, v27.16b, v24.4b[0]\n" - ".inst 0x4fb8e36a // sdot v10.4s, v27.16b, v24.4b[1]\n" - ".inst 0x4f98eb7a // sdot v26.4s, v27.16b, v24.4b[2]\n" - ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n" - "ldr q24, [x21, #0x0]\n" - ".inst 0x4f89e3d4 // sdot v20.4s, v30.16b, v9.4b[0]\n" - ".inst 0x4fa9e3ca // sdot v10.4s, v30.16b, v9.4b[1]\n" - ".inst 0x4f89ebda // sdot v26.4s, v30.16b, v9.4b[2]\n" - ".inst 0x4fa9ebc2 // sdot v2.4s, v30.16b, v9.4b[3]\n" - "fmul v9.4s, v17.4s, v29.s[0]\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "scvtf v10.4s, v10.4s, #0x4\n" - "scvtf v26.4s, v26.4s, #0x4\n" - "scvtf v2.4s, v2.4s, #0x4\n" - "fmla v25.4s, v20.4s, v9.4s\n" - "ldr q9, [x21, #0x10]\n" - "fmul v20.4s, v17.4s, v29.s[1]\n" - "fmla v7.4s, v10.4s, v20.4s\n" - "ldr d20, [x21, #-0x8]\n" - "fmul v10.4s, v17.4s, v29.s[2]\n" - "fmul v29.4s, v17.4s, v29.s[3]\n" - "fcvtl v20.4s, v20.4h\n" - "fmla v0.4s, v26.4s, v10.4s\n" - "movi v26.4s, #0x0\n" - "movi v10.4s, #0x0\n" - "fmla v4.4s, v2.4s, v29.4s\n" - "movi v2.4s, #0x0\n" - "movi v29.4s, #0x0\n" - ".inst 0x4f98e19a // sdot v26.4s, v12.16b, v24.4b[0]\n" - ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n" - ".inst 0x4f98e982 // sdot v2.4s, v12.16b, v24.4b[2]\n" - ".inst 0x4fb8e99d // sdot v29.4s, v12.16b, v24.4b[3]\n" - "ldr q12, [x21, #0x20]\n" - "fmul v24.4s, v17.4s, v20.s[0]\n" - ".inst 0x4f89e3fa // sdot v26.4s, v31.16b, v9.4b[0]\n" - ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n" - ".inst 0x4f89ebe2 // sdot v2.4s, v31.16b, v9.4b[2]\n" - ".inst 0x4fa9ebfd // sdot v29.4s, v31.16b, v9.4b[3]\n" - "ldr q9, [x21, #0x30]\n" - "fmul v31.4s, v17.4s, v20.s[1]\n" - ".inst 0x4f8ce0da // sdot v26.4s, v6.16b, v12.4b[0]\n" - ".inst 0x4face0ca // sdot v10.4s, v6.16b, v12.4b[1]\n" - ".inst 0x4f8ce8c2 // sdot v2.4s, v6.16b, v12.4b[2]\n" - ".inst 0x4face8dd // sdot v29.4s, v6.16b, v12.4b[3]\n" - "ldr q12, [x21, #0x40]\n" - "fmul v6.4s, v17.4s, v20.s[2]\n" - "fmul v20.4s, v17.4s, v20.s[3]\n" - ".inst 0x4f89e39a // sdot v26.4s, v28.16b, v9.4b[0]\n" - ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n" - ".inst 0x4f89eb82 // sdot v2.4s, v28.16b, v9.4b[2]\n" - ".inst 0x4fa9eb9d // sdot v29.4s, v28.16b, v9.4b[3]\n" - "ldr q9, [x21, #0x50]\n" - ".inst 0x4f8ce07a // sdot v26.4s, v3.16b, v12.4b[0]\n" - ".inst 0x4face06a // sdot v10.4s, v3.16b, v12.4b[1]\n" - ".inst 0x4f8ce862 // sdot v2.4s, v3.16b, v12.4b[2]\n" - ".inst 0x4face87d // sdot v29.4s, v3.16b, v12.4b[3]\n" - "ldr q12, [x21, #0x60]\n" - ".inst 0x4f89e2da // sdot v26.4s, v22.16b, v9.4b[0]\n" - ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n" - ".inst 0x4f89eac2 // sdot v2.4s, v22.16b, v9.4b[2]\n" - ".inst 0x4fa9eadd // sdot v29.4s, v22.16b, v9.4b[3]\n" - "ldr q17, [x21, #0x70]\n" - "add x21, x21, #0x88\n" - ".inst 0x4f8ce37a // sdot v26.4s, v27.16b, v12.4b[0]\n" - ".inst 0x4face36a // sdot v10.4s, v27.16b, v12.4b[1]\n" - ".inst 0x4f8ceb62 // sdot v2.4s, v27.16b, v12.4b[2]\n" - ".inst 0x4faceb7d // sdot v29.4s, v27.16b, v12.4b[3]\n" - ".inst 0x4f91e3da // sdot v26.4s, v30.16b, v17.4b[0]\n" - ".inst 0x4fb1e3ca // sdot v10.4s, v30.16b, v17.4b[1]\n" - ".inst 0x4f91ebc2 // sdot v2.4s, v30.16b, v17.4b[2]\n" - ".inst 0x4fb1ebdd // sdot v29.4s, v30.16b, v17.4b[3]\n" - "scvtf v26.4s, v26.4s, #0x4\n" - "scvtf v10.4s, v10.4s, #0x4\n" - "fmla v5.4s, v26.4s, v24.4s\n" - "scvtf v2.4s, v2.4s, #0x4\n" - "scvtf v29.4s, v29.4s, #0x4\n" - "fmla v21.4s, v10.4s, v31.4s\n" - "fmla v8.4s, v2.4s, v6.4s\n" - "fmla v1.4s, v29.4s, v20.4s\n" - "bgt 3b\n" - "mov x20, %x[res_ptr]\n" - "subs x27, x27, #0x4\n" - "add %x[res_ptr], %x[res_ptr], #0x10\n" - "str q15, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q19, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q18, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q14, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q11, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q13, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q23, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q16, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q25, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q7, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q0, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q4, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q5, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q21, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q8, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q1, [x20, #0x0]\n" - "bne 2b\n" - "mov x20, #0x4\n" - "sub x10, x10, #0x10\n" - "cmp x10, #0x10\n" - "mov %x[res_ptr], x26\n" - "madd %x[a_ptr], x20, x9, %x[a_ptr]\n" - "bge 1b\n" - "4:" // Row loop skip - "cbz x10, 9f\n" - "5:" // Row tail: Row loop - "add x24, %x[b_ptr], #0x8\n" - "mov x23, %x[nc]\n" - "add x22, %x[res_ptr], %x[res_stride], LSL #2\n" - "6:" // Row tail: Column loop - "movi v15.16b, #0x0\n" - "movi v19.16b, #0x0\n" - "add x25, %x[a_ptr], #0x8\n" - "mov x21, %x[nb]\n" - "movi v18.16b, #0x0\n" - "movi v14.16b, #0x0\n" - "7:" // Row tail: Block loop - "ldr q7, [x24, #0x0]\n" - "ldr q5, [x25, #0x0]\n" - "movi v9.16b, #0x4\n" - "movi v4.4s, #0x0\n" - "ldr q3, [x24, #0x10]\n" - "ldr q2, [x25, #0x10]\n" - "movi v1.4s, #0x0\n" - "movi v0.4s, #0x0\n" - "ldr q13, [x24, #0x20]\n" - "ldr q31, [x25, #0x20]\n" - "movi v30.4s, #0x0\n" - "movi v29.16b, #0xf0\n" - "ldr q28, [x24, #0x30]\n" - "ldr q27, [x25, #0x30]\n" - "sshl v20.16b, v7.16b, v9.16b\n" - "sub x20, x24, #0x8\n" - "ldr q26, [x25, #0x40]\n" - "ldr q25, [x25, #0x50]\n" - "sshl v17.16b, v3.16b, v9.16b\n" - "and v7.16b, v7.16b, v29.16b\n" - "ldr q24, [x25, #0x60]\n" - "ldr q16, [x25, #0x70]\n" - "sshl v22.16b, v13.16b, v9.16b\n" - "and v3.16b, v3.16b, v29.16b\n" - "ldr d21, [x20, #0x0]\n" - "ldr d12, [x25, #-0x8]\n" - ".inst 0x4f85e284 // sdot v4.4s, v20.16b, v5.4b[0]\n" - ".inst 0x4fa5e281 // sdot v1.4s, v20.16b, v5.4b[1]\n" - ".inst 0x4f85ea80 // sdot v0.4s, v20.16b, v5.4b[2]\n" - ".inst 0x4fa5ea9e // sdot v30.4s, v20.16b, v5.4b[3]\n" - "sshl v9.16b, v28.16b, v9.16b\n" - "subs x21, x21, #0x1\n" - "and v13.16b, v13.16b, v29.16b\n" - "and v28.16b, v28.16b, v29.16b\n" - "add x25, x25, #0x88\n" - "add x24, x24, #0x48\n" - "fcvtl v21.4s, v21.4h\n" - "fcvtl v12.4s, v12.4h\n" - ".inst 0x4f82e224 // sdot v4.4s, v17.16b, v2.4b[0]\n" - ".inst 0x4fa2e221 // sdot v1.4s, v17.16b, v2.4b[1]\n" - ".inst 0x4f82ea20 // sdot v0.4s, v17.16b, v2.4b[2]\n" - ".inst 0x4fa2ea3e // sdot v30.4s, v17.16b, v2.4b[3]\n" - "fmul v11.4s, v21.4s, v12.s[0]\n" - "fmul v23.4s, v21.4s, v12.s[1]\n" - "fmul v17.4s, v21.4s, v12.s[2]\n" - ".inst 0x4f9fe2c4 // sdot v4.4s, v22.16b, v31.4b[0]\n" - "fmul v6.4s, v21.4s, v12.s[3]\n" - ".inst 0x4fbfe2c1 // sdot v1.4s, v22.16b, v31.4b[1]\n" - ".inst 0x4f9feac0 // sdot v0.4s, v22.16b, v31.4b[2]\n" - ".inst 0x4fbfeade // sdot v30.4s, v22.16b, v31.4b[3]\n" - ".inst 0x4f9be124 // sdot v4.4s, v9.16b, v27.4b[0]\n" - ".inst 0x4fbbe121 // sdot v1.4s, v9.16b, v27.4b[1]\n" - ".inst 0x4f9be920 // sdot v0.4s, v9.16b, v27.4b[2]\n" - ".inst 0x4fbbe93e // sdot v30.4s, v9.16b, v27.4b[3]\n" - ".inst 0x4f9ae0e4 // sdot v4.4s, v7.16b, v26.4b[0]\n" - ".inst 0x4fbae0e1 // sdot v1.4s, v7.16b, v26.4b[1]\n" - ".inst 0x4f9ae8e0 // sdot v0.4s, v7.16b, v26.4b[2]\n" - ".inst 0x4fbae8fe // sdot v30.4s, v7.16b, v26.4b[3]\n" - ".inst 0x4f99e064 // sdot v4.4s, v3.16b, v25.4b[0]\n" - ".inst 0x4fb9e061 // sdot v1.4s, v3.16b, v25.4b[1]\n" - ".inst 0x4f99e860 // sdot v0.4s, v3.16b, v25.4b[2]\n" - ".inst 0x4fb9e87e // sdot v30.4s, v3.16b, v25.4b[3]\n" - ".inst 0x4f98e1a4 // sdot v4.4s, v13.16b, v24.4b[0]\n" - ".inst 0x4fb8e1a1 // sdot v1.4s, v13.16b, v24.4b[1]\n" - ".inst 0x4f98e9a0 // sdot v0.4s, v13.16b, v24.4b[2]\n" - ".inst 0x4fb8e9be // sdot v30.4s, v13.16b, v24.4b[3]\n" - ".inst 0x4f90e384 // sdot v4.4s, v28.16b, v16.4b[0]\n" - ".inst 0x4fb0e381 // sdot v1.4s, v28.16b, v16.4b[1]\n" - ".inst 0x4f90eb80 // sdot v0.4s, v28.16b, v16.4b[2]\n" - ".inst 0x4fb0eb9e // sdot v30.4s, v28.16b, v16.4b[3]\n" - "scvtf v4.4s, v4.4s, #0x4\n" - "scvtf v1.4s, v1.4s, #0x4\n" - "scvtf v0.4s, v0.4s, #0x4\n" - "fmla v15.4s, v4.4s, v11.4s\n" - "scvtf v30.4s, v30.4s, #0x4\n" - "fmla v19.4s, v1.4s, v23.4s\n" - "fmla v18.4s, v0.4s, v17.4s\n" - "fmla v14.4s, v30.4s, v6.4s\n" - "bgt 7b\n" - "mov x20, %x[res_ptr]\n" - "cmp x10, #0x1\n" - "str q15, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "cmp x10, #0x2\n" - "str q19, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "cmp x10, #0x3\n" - "str q18, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "str q14, [x20, #0x0]\n" - "8:" // Row tail: Accumulator store skip - "subs x23, x23, #0x4\n" - "add %x[res_ptr], %x[res_ptr], #0x10\n" - "bne 6b\n" - "subs x10, x10, #0x4\n" - "add %x[a_ptr], %x[a_ptr], x9\n" - "mov %x[res_ptr], x22\n" - "bgt 5b\n" - "9:" // Row tail: Row loop skip - : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr) - : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc) - : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" - ); - return; - } -#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) - { - float sumf[4][4]; - int sumi; - - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; - } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); - sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + - (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; - } - sumf[m][j] += sumi * LM_GGML_FP16_TO_FP32(b_ptr[l].d[j]) * LM_GGML_FP16_TO_FP32(a_ptr[l].d[m]); - } - } - } - } - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; - } - } - } - } -} - -void lm_ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 8; - - assert (n % qk == 0); - assert (nr % 4 == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - -#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) - if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_matmul_int8()) { - const void * b_ptr = vx; - const void * a_ptr = vy; - float * res_ptr = s; - size_t res_stride = bs * sizeof(float); - - __asm__ __volatile__( - "mov x10, %x[nr]\n" - "mov x9, #0x88\n" - "cmp x10, #0x10\n" - "mul x9, %x[nb], x9\n" - "blt 4f\n" - "1:" // Row loop - "add x28, %x[b_ptr], #0x8\n" - "mov x27, %x[nc]\n" - "add x26, %x[res_ptr], %x[res_stride], LSL #4\n" - "2:" // Column loop - "add x25, %x[a_ptr], #0x8\n" - "movi v2.16b, #0x0\n" - "movi v10.16b, #0x0\n" - "mov x24, %x[nb]\n" - "add x23, x25, x9\n" - "movi v12.16b, #0x0\n" - "movi v28.16b, #0x0\n" - "add x22, x23, x9\n" - "movi v11.16b, #0x0\n" - "movi v13.16b, #0x0\n" - "add x21, x22, x9\n" - "movi v22.16b, #0x0\n" - "movi v23.16b, #0x0\n" - "movi v25.16b, #0x0\n" - "movi v5.16b, #0x0\n" - "movi v7.16b, #0x0\n" - "movi v4.16b, #0x0\n" - "movi v6.16b, #0x0\n" - "movi v30.16b, #0x0\n" - "movi v24.16b, #0x0\n" - "movi v14.16b, #0x0\n" - "3:" // Block loop - "ldr q21, [x28, #0x0]\n" - "ldr q16, [x28, #0x10]\n" - "movi v1.16b, #0x4\n" - "movi v19.4s, #0x0\n" - "ldr q27, [x25, #0x0]\n" - "ldr q15, [x25, #0x10]\n" - "movi v26.4s, #0x0\n" - "movi v18.4s, #0x0\n" - "ldr q29, [x28, #0x20]\n" - "ldr q3, [x28, #0x30]\n" - "movi v17.4s, #0x0\n" - "movi v0.16b, #0xf0\n" - "ldr d20, [x25, #-0x8]\n" - "ldr d9, [x23, #-0x8]\n" - "sshl v8.16b, v21.16b, v1.16b\n" - "sshl v31.16b, v16.16b, v1.16b\n" - "and v21.16b, v21.16b, v0.16b\n" - "and v16.16b, v16.16b, v0.16b\n" - "sub x20, x28, #0x8\n" - "subs x24, x24, #0x1\n" - "add x28, x28, #0x48\n" - ".inst 0x4e88a773 // smmla v19.4s, v27.16b, v8.16b\n" - ".inst 0x4e9fa77a // smmla v26.4s, v27.16b, v31.16b\n" - "ldr q27, [x25, #0x20]\n" - ".inst 0x4e88a5f2 // smmla v18.4s, v15.16b, v8.16b\n" - ".inst 0x4e9fa5f1 // smmla v17.4s, v15.16b, v31.16b\n" - "sshl v15.16b, v29.16b, v1.16b\n" - "sshl v1.16b, v3.16b, v1.16b\n" - "and v29.16b, v29.16b, v0.16b\n" - "and v3.16b, v3.16b, v0.16b\n" - "ldr q0, [x25, #0x30]\n" - "fcvtl v20.4s, v20.4h\n" - ".inst 0x4e8fa773 // smmla v19.4s, v27.16b, v15.16b\n" - "fcvtl v9.4s, v9.4h\n" - ".inst 0x4e81a77a // smmla v26.4s, v27.16b, v1.16b\n" - "ldr q27, [x25, #0x40]\n" - ".inst 0x4e8fa412 // smmla v18.4s, v0.16b, v15.16b\n" - ".inst 0x4e81a411 // smmla v17.4s, v0.16b, v1.16b\n" - "ldr q0, [x25, #0x50]\n" - ".inst 0x4e95a773 // smmla v19.4s, v27.16b, v21.16b\n" - ".inst 0x4e90a77a // smmla v26.4s, v27.16b, v16.16b\n" - "ldr q27, [x25, #0x60]\n" - ".inst 0x4e95a412 // smmla v18.4s, v0.16b, v21.16b\n" - ".inst 0x4e90a411 // smmla v17.4s, v0.16b, v16.16b\n" - "ldr q0, [x25, #0x70]\n" - "add x25, x25, #0x88\n" - ".inst 0x4e9da773 // smmla v19.4s, v27.16b, v29.16b\n" - ".inst 0x4e83a77a // smmla v26.4s, v27.16b, v3.16b\n" - "ldr d27, [x20, #0x0]\n" - ".inst 0x4e9da412 // smmla v18.4s, v0.16b, v29.16b\n" - ".inst 0x4e83a411 // smmla v17.4s, v0.16b, v3.16b\n" - "fcvtl v27.4s, v27.4h\n" - "uzp1 v0.2d, v19.2d, v26.2d\n" - "uzp2 v26.2d, v19.2d, v26.2d\n" - "fmul v19.4s, v27.4s, v20.s[0]\n" - "scvtf v0.4s, v0.4s, #0x4\n" - "scvtf v26.4s, v26.4s, #0x4\n" - "fmla v2.4s, v0.4s, v19.4s\n" - "ldr q19, [x23, #0x0]\n" - "uzp1 v0.2d, v18.2d, v17.2d\n" - "uzp2 v18.2d, v18.2d, v17.2d\n" - "fmul v17.4s, v27.4s, v20.s[1]\n" - "scvtf v0.4s, v0.4s, #0x4\n" - "scvtf v18.4s, v18.4s, #0x4\n" - "fmla v10.4s, v26.4s, v17.4s\n" - "ldr q17, [x23, #0x10]\n" - "fmul v26.4s, v27.4s, v20.s[2]\n" - "fmul v20.4s, v27.4s, v20.s[3]\n" - "fmla v12.4s, v0.4s, v26.4s\n" - "ldr d0, [x22, #-0x8]\n" - "ldr d26, [x21, #-0x8]\n" - "fcvtl v0.4s, v0.4h\n" - "fmla v28.4s, v18.4s, v20.4s\n" - "movi v20.4s, #0x0\n" - "movi v18.4s, #0x0\n" - ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n" - ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n" - "ldr q19, [x23, #0x20]\n" - "fcvtl v26.4s, v26.4h\n" - ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n" - ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n" - "ldr q19, [x23, #0x40]\n" - ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n" - ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n" - "ldr q19, [x23, #0x60]\n" - ".inst 0x4e9da674 // smmla v20.4s, v19.16b, v29.16b\n" - ".inst 0x4e83a672 // smmla v18.4s, v19.16b, v3.16b\n" - "uzp1 v19.2d, v20.2d, v18.2d\n" - "scvtf v19.4s, v19.4s, #0x4\n" - "uzp2 v20.2d, v20.2d, v18.2d\n" - "fmul v18.4s, v27.4s, v9.s[0]\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "fmla v11.4s, v19.4s, v18.4s\n" - "ldr q18, [x22, #0x0]\n" - "fmul v19.4s, v27.4s, v9.s[1]\n" - "fmla v13.4s, v20.4s, v19.4s\n" - "movi v19.4s, #0x0\n" - "movi v20.4s, #0x0\n" - ".inst 0x4e88a633 // smmla v19.4s, v17.16b, v8.16b\n" - ".inst 0x4e9fa634 // smmla v20.4s, v17.16b, v31.16b\n" - "ldr q17, [x23, #0x30]\n" - ".inst 0x4e8fa633 // smmla v19.4s, v17.16b, v15.16b\n" - ".inst 0x4e81a634 // smmla v20.4s, v17.16b, v1.16b\n" - "ldr q17, [x23, #0x50]\n" - ".inst 0x4e95a633 // smmla v19.4s, v17.16b, v21.16b\n" - ".inst 0x4e90a634 // smmla v20.4s, v17.16b, v16.16b\n" - "ldr q17, [x23, #0x70]\n" - "add x23, x23, #0x88\n" - ".inst 0x4e9da633 // smmla v19.4s, v17.16b, v29.16b\n" - ".inst 0x4e83a634 // smmla v20.4s, v17.16b, v3.16b\n" - "uzp1 v17.2d, v19.2d, v20.2d\n" - "scvtf v17.4s, v17.4s, #0x4\n" - "uzp2 v20.2d, v19.2d, v20.2d\n" - "fmul v19.4s, v27.4s, v9.s[2]\n" - "fmul v9.4s, v27.4s, v9.s[3]\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "fmla v22.4s, v17.4s, v19.4s\n" - "ldr q17, [x22, #0x10]\n" - "movi v19.4s, #0x0\n" - ".inst 0x4e88a653 // smmla v19.4s, v18.16b, v8.16b\n" - "fmla v23.4s, v20.4s, v9.4s\n" - "movi v20.4s, #0x0\n" - "movi v9.4s, #0x0\n" - ".inst 0x4e9fa654 // smmla v20.4s, v18.16b, v31.16b\n" - "ldr q18, [x22, #0x20]\n" - ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n" - ".inst 0x4e8fa653 // smmla v19.4s, v18.16b, v15.16b\n" - ".inst 0x4e81a654 // smmla v20.4s, v18.16b, v1.16b\n" - "ldr q18, [x22, #0x40]\n" - ".inst 0x4e95a653 // smmla v19.4s, v18.16b, v21.16b\n" - ".inst 0x4e90a654 // smmla v20.4s, v18.16b, v16.16b\n" - "ldr q18, [x22, #0x60]\n" - ".inst 0x4e9da653 // smmla v19.4s, v18.16b, v29.16b\n" - ".inst 0x4e83a654 // smmla v20.4s, v18.16b, v3.16b\n" - "movi v18.4s, #0x0\n" - ".inst 0x4e9fa632 // smmla v18.4s, v17.16b, v31.16b\n" - "ldr q17, [x22, #0x30]\n" - ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n" - ".inst 0x4e81a632 // smmla v18.4s, v17.16b, v1.16b\n" - "ldr q17, [x22, #0x50]\n" - ".inst 0x4e95a629 // smmla v9.4s, v17.16b, v21.16b\n" - ".inst 0x4e90a632 // smmla v18.4s, v17.16b, v16.16b\n" - "ldr q17, [x22, #0x70]\n" - "add x22, x22, #0x88\n" - ".inst 0x4e9da629 // smmla v9.4s, v17.16b, v29.16b\n" - ".inst 0x4e83a632 // smmla v18.4s, v17.16b, v3.16b\n" - "uzp1 v17.2d, v19.2d, v20.2d\n" - "uzp2 v20.2d, v19.2d, v20.2d\n" - "fmul v19.4s, v27.4s, v0.s[0]\n" - "scvtf v17.4s, v17.4s, #0x4\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "fmla v25.4s, v17.4s, v19.4s\n" - "ldr q19, [x21, #0x0]\n" - "fmul v17.4s, v27.4s, v0.s[1]\n" - "fmla v5.4s, v20.4s, v17.4s\n" - "ldr q17, [x21, #0x10]\n" - "uzp1 v20.2d, v9.2d, v18.2d\n" - "uzp2 v9.2d, v9.2d, v18.2d\n" - "fmul v18.4s, v27.4s, v0.s[2]\n" - "fmul v0.4s, v27.4s, v0.s[3]\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "scvtf v9.4s, v9.4s, #0x4\n" - "fmla v7.4s, v20.4s, v18.4s\n" - "movi v20.4s, #0x0\n" - "movi v18.4s, #0x0\n" - ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n" - ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n" - "ldr q19, [x21, #0x20]\n" - "fmla v4.4s, v9.4s, v0.4s\n" - "movi v9.4s, #0x0\n" - "movi v0.4s, #0x0\n" - ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n" - "fmul v8.4s, v27.4s, v26.s[0]\n" - ".inst 0x4e9fa620 // smmla v0.4s, v17.16b, v31.16b\n" - "ldr q17, [x21, #0x30]\n" - ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n" - "fmul v31.4s, v27.4s, v26.s[1]\n" - ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n" - "ldr q19, [x21, #0x40]\n" - ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n" - "fmul v15.4s, v27.4s, v26.s[2]\n" - "fmul v27.4s, v27.4s, v26.s[3]\n" - ".inst 0x4e81a620 // smmla v0.4s, v17.16b, v1.16b\n" - "ldr q1, [x21, #0x50]\n" - ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n" - ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n" - "ldr q26, [x21, #0x60]\n" - ".inst 0x4e95a429 // smmla v9.4s, v1.16b, v21.16b\n" - ".inst 0x4e90a420 // smmla v0.4s, v1.16b, v16.16b\n" - "ldr q21, [x21, #0x70]\n" - "add x21, x21, #0x88\n" - ".inst 0x4e9da754 // smmla v20.4s, v26.16b, v29.16b\n" - ".inst 0x4e83a752 // smmla v18.4s, v26.16b, v3.16b\n" - ".inst 0x4e9da6a9 // smmla v9.4s, v21.16b, v29.16b\n" - ".inst 0x4e83a6a0 // smmla v0.4s, v21.16b, v3.16b\n" - "uzp1 v29.2d, v20.2d, v18.2d\n" - "uzp2 v21.2d, v20.2d, v18.2d\n" - "scvtf v29.4s, v29.4s, #0x4\n" - "uzp1 v18.2d, v9.2d, v0.2d\n" - "uzp2 v16.2d, v9.2d, v0.2d\n" - "scvtf v21.4s, v21.4s, #0x4\n" - "fmla v6.4s, v29.4s, v8.4s\n" - "scvtf v18.4s, v18.4s, #0x4\n" - "scvtf v16.4s, v16.4s, #0x4\n" - "fmla v30.4s, v21.4s, v31.4s\n" - "fmla v24.4s, v18.4s, v15.4s\n" - "fmla v14.4s, v16.4s, v27.4s\n" - "bgt 3b\n" - "mov x20, %x[res_ptr]\n" - "subs x27, x27, #0x4\n" - "add %x[res_ptr], %x[res_ptr], #0x10\n" - "str q2, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q10, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q12, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q28, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q11, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q13, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q22, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q23, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q25, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q5, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q7, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q4, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q6, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q30, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q24, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q14, [x20, #0x0]\n" - "bne 2b\n" - "mov x20, #0x4\n" - "sub x10, x10, #0x10\n" - "cmp x10, #0x10\n" - "mov %x[res_ptr], x26\n" - "madd %x[a_ptr], x20, x9, %x[a_ptr]\n" - "bge 1b\n" - "4:" // Row loop skip - "cbz x10, 9f\n" - "5:" // Row tail: Row loop - "add x24, %x[b_ptr], #0x8\n" - "mov x23, %x[nc]\n" - "add x22, %x[res_ptr], %x[res_stride], LSL #2\n" - "6:" // Row tail: Column loop - "movi v2.16b, #0x0\n" - "movi v10.16b, #0x0\n" - "add x25, %x[a_ptr], #0x8\n" - "mov x21, %x[nb]\n" - "movi v12.16b, #0x0\n" - "movi v28.16b, #0x0\n" - "7:" // Row tail: Block loop - "ldr q6, [x24, #0x0]\n" - "ldr q5, [x24, #0x10]\n" - "movi v17.16b, #0x4\n" - "movi v8.4s, #0x0\n" - "ldr q4, [x25, #0x0]\n" - "ldr q13, [x25, #0x10]\n" - "movi v27.4s, #0x0\n" - "movi v0.4s, #0x0\n" - "ldr q31, [x24, #0x20]\n" - "ldr q14, [x24, #0x30]\n" - "movi v29.4s, #0x0\n" - "movi v22.16b, #0xf0\n" - "ldr q11, [x25, #0x20]\n" - "ldr q23, [x25, #0x30]\n" - "sshl v21.16b, v6.16b, v17.16b\n" - "sshl v16.16b, v5.16b, v17.16b\n" - "ldr q20, [x25, #0x40]\n" - "ldr q26, [x25, #0x50]\n" - "and v6.16b, v6.16b, v22.16b\n" - "and v5.16b, v5.16b, v22.16b\n" - "ldr q25, [x25, #0x60]\n" - "ldr q3, [x25, #0x70]\n" - "sshl v19.16b, v31.16b, v17.16b\n" - "sshl v18.16b, v14.16b, v17.16b\n" - "ldr d17, [x25, #-0x8]\n" - ".inst 0x4e95a488 // smmla v8.4s, v4.16b, v21.16b\n" - ".inst 0x4e90a49b // smmla v27.4s, v4.16b, v16.16b\n" - "and v31.16b, v31.16b, v22.16b\n" - ".inst 0x4e95a5a0 // smmla v0.4s, v13.16b, v21.16b\n" - ".inst 0x4e90a5bd // smmla v29.4s, v13.16b, v16.16b\n" - "and v14.16b, v14.16b, v22.16b\n" - "sub x20, x24, #0x8\n" - "ldr d16, [x20, #0x0]\n" - "subs x21, x21, #0x1\n" - "add x25, x25, #0x88\n" - "fcvtl v17.4s, v17.4h\n" - "add x24, x24, #0x48\n" - ".inst 0x4e93a568 // smmla v8.4s, v11.16b, v19.16b\n" - ".inst 0x4e92a57b // smmla v27.4s, v11.16b, v18.16b\n" - ".inst 0x4e93a6e0 // smmla v0.4s, v23.16b, v19.16b\n" - ".inst 0x4e92a6fd // smmla v29.4s, v23.16b, v18.16b\n" - "fcvtl v16.4s, v16.4h\n" - ".inst 0x4e86a688 // smmla v8.4s, v20.16b, v6.16b\n" - ".inst 0x4e85a69b // smmla v27.4s, v20.16b, v5.16b\n" - "fmul v23.4s, v16.4s, v17.s[0]\n" - "fmul v21.4s, v16.4s, v17.s[1]\n" - "fmul v1.4s, v16.4s, v17.s[2]\n" - "fmul v20.4s, v16.4s, v17.s[3]\n" - ".inst 0x4e86a740 // smmla v0.4s, v26.16b, v6.16b\n" - ".inst 0x4e85a75d // smmla v29.4s, v26.16b, v5.16b\n" - ".inst 0x4e9fa728 // smmla v8.4s, v25.16b, v31.16b\n" - ".inst 0x4e8ea73b // smmla v27.4s, v25.16b, v14.16b\n" - ".inst 0x4e9fa460 // smmla v0.4s, v3.16b, v31.16b\n" - ".inst 0x4e8ea47d // smmla v29.4s, v3.16b, v14.16b\n" - "uzp1 v19.2d, v8.2d, v27.2d\n" - "uzp2 v18.2d, v8.2d, v27.2d\n" - "scvtf v19.4s, v19.4s, #0x4\n" - "uzp1 v17.2d, v0.2d, v29.2d\n" - "uzp2 v16.2d, v0.2d, v29.2d\n" - "scvtf v18.4s, v18.4s, #0x4\n" - "fmla v2.4s, v19.4s, v23.4s\n" - "scvtf v17.4s, v17.4s, #0x4\n" - "scvtf v16.4s, v16.4s, #0x4\n" - "fmla v10.4s, v18.4s, v21.4s\n" - "fmla v12.4s, v17.4s, v1.4s\n" - "fmla v28.4s, v16.4s, v20.4s\n" - "bgt 7b\n" - "mov x20, %x[res_ptr]\n" - "cmp x10, #0x1\n" - "str q2, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "cmp x10, #0x2\n" - "str q10, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "cmp x10, #0x3\n" - "str q12, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "str q28, [x20, #0x0]\n" - "8:" // Row tail: Accumulator store skip - "subs x23, x23, #0x4\n" - "add %x[res_ptr], %x[res_ptr], #0x10\n" - "bne 6b\n" - "subs x10, x10, #0x4\n" - "add %x[a_ptr], %x[a_ptr], x9\n" - "mov %x[res_ptr], x22\n" - "bgt 5b\n" - "9:" // Row tail: Row loop skip - : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr) - : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc) - : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" - ); - return; - } -#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) - float sumf[4][4]; - int sumi; - - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; - } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); - sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + - (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; - } - sumf[m][j] += sumi * LM_GGML_FP16_TO_FP32(b_ptr[l].d[j]) * LM_GGML_FP16_TO_FP32(a_ptr[l].d[m]); - } - } - } - } - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; - } - } - } -} - -void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 8; - const int blocklen = 8; - - assert (n % qk == 0); - assert (nr % 4 == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - -#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) -#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) - if (lm_ggml_cpu_has_sve() && lm_ggml_cpu_has_matmul_int8() && lm_ggml_cpu_get_sve_cnt() == QK8_0) { - const void * b_ptr = vx; - const void * a_ptr = vy; - float * res_ptr = s; - size_t res_stride = bs * sizeof(float); - - __asm__ __volatile__( - "mov x20, #0x4\n" - "mov x13, %x[nr]\n" - "mov z28.s, #-0x4\n" - "mov x12, #0x88\n" - "ptrue p1.b\n" - "whilelt p0.s, XZR, x20\n" - "cmp x13, #0x10\n" - "mul x12, %x[nb], x12\n" - "blt 4f\n" - "1:" // Row loop - "add x11, %x[b_ptr], #0x10\n" - "mov x10, %x[nc]\n" - "add x9, %x[res_ptr], %x[res_stride], LSL #4\n" - "2:" // Column loop - "add x28, %x[a_ptr], #0x8\n" - "mov z24.b, #0x0\n" - "mov z15.b, #0x0\n" - "mov x27, %x[nb]\n" - "add x26, x28, x12\n" - "mov z12.b, #0x0\n" - "mov z0.b, #0x0\n" - "add x25, x26, x12\n" - "mov z13.b, #0x0\n" - "mov z1.b, #0x0\n" - "add x24, x25, x12\n" - "mov z20.b, #0x0\n" - "mov z25.b, #0x0\n" - "mov z11.b, #0x0\n" - "mov z16.b, #0x0\n" - "mov z19.b, #0x0\n" - "mov z26.b, #0x0\n" - "mov z8.b, #0x0\n" - "mov z29.b, #0x0\n" - "mov z27.b, #0x0\n" - "mov z10.b, #0x0\n" - "3:" // Block loop - "ld1b { z30.b }, p1/Z, [x11]\n" - "ld1b { z21.b }, p1/Z, [x11, #1, MUL VL]\n" - "mov z18.s, #0x0\n" - "mov z7.s, #0x0\n" - "ld1rqb { z3.b }, p1/Z, [x28]\n" - "ld1rqb { z5.b }, p1/Z, [x28, #16]\n" - "mov z9.s, #0x0\n" - "mov z22.s, #0x0\n" - "ld1b { z4.b }, p1/Z, [x11, #2, MUL VL]\n" - "ld1b { z17.b }, p1/Z, [x11, #3, MUL VL]\n" - "sub x20, x11, #0x10\n" - "sub x23, x28, #0x8\n" - "lsl z31.b, z30.b, #0x4\n" - "lsl z6.b, z21.b, #0x4\n" - "ld1h { z23.s }, p1/Z, [x20]\n" - "sub x22, x26, #0x8\n" - "and z30.b, z30.b, #0xf0\n" - "and z21.b, z21.b, #0xf0\n" - "sub x21, x25, #0x8\n" - "sub x20, x24, #0x8\n" - "lsl z14.b, z4.b, #0x4\n" - "lsl z2.b, z17.b, #0x4\n" - "subs x27, x27, #0x1\n" - "add x11, x11, #0x90\n" - ".inst 0x451f9872 // smmla z18.s, z3.b, z31.b\n" - ".inst 0x45069867 // smmla z7.s, z3.b, z6.b\n" - "ld1rqb { z3.b }, p1/Z, [x28, #32]\n" - "and z4.b, z4.b, #0xf0\n" - ".inst 0x451f98a9 // smmla z9.s, z5.b, z31.b\n" - ".inst 0x450698b6 // smmla z22.s, z5.b, z6.b\n" - "ld1rqb { z5.b }, p1/Z, [x28, #48]\n" - "and z17.b, z17.b, #0xf0\n" - "fcvt z23.s, p1/m, z23.h\n" - ".inst 0x450e9872 // smmla z18.s, z3.b, z14.b\n" - ".inst 0x45029867 // smmla z7.s, z3.b, z2.b\n" - "ld1rqb { z3.b }, p1/Z, [x28, #64]\n" - ".inst 0x450e98a9 // smmla z9.s, z5.b, z14.b\n" - ".inst 0x450298b6 // smmla z22.s, z5.b, z2.b\n" - "ld1rqb { z5.b }, p1/Z, [x28, #80]\n" - "fscale z23.s, p1/m, z23.s, z28.s\n" - ".inst 0x451e9872 // smmla z18.s, z3.b, z30.b\n" - ".inst 0x45159867 // smmla z7.s, z3.b, z21.b\n" - "ld1rqb { z3.b }, p1/Z, [x28, #96]\n" - ".inst 0x451e98a9 // smmla z9.s, z5.b, z30.b\n" - ".inst 0x451598b6 // smmla z22.s, z5.b, z21.b\n" - "ld1rqb { z5.b }, p1/Z, [x28, #112]\n" - "add x28, x28, #0x88\n" - ".inst 0x45049872 // smmla z18.s, z3.b, z4.b\n" - ".inst 0x45119867 // smmla z7.s, z3.b, z17.b\n" - "ld1h { z3.s }, p0/Z, [x23]\n" - ".inst 0x450498a9 // smmla z9.s, z5.b, z4.b\n" - ".inst 0x451198b6 // smmla z22.s, z5.b, z17.b\n" - "fcvt z3.s, p1/m, z3.h\n" - "uzp1 z5.d, z18.d, z7.d\n" - "uzp2 z18.d, z18.d, z7.d\n" - "mov z3.q, z3.q[0]\n" - "uzp1 z7.d, z9.d, z22.d\n" - "uzp2 z22.d, z9.d, z22.d\n" - "fmul z9.s, z23.s, z3.s[0]\n" - "scvtf z5.s, p1/m, z5.s\n" - "scvtf z18.s, p1/m, z18.s\n" - "scvtf z7.s, p1/m, z7.s\n" - "scvtf z22.s, p1/m, z22.s\n" - "fmla z24.s, p1/M, z5.s, z9.s\n" - "ld1rqb { z5.b }, p1/Z, [x26]\n" - "fmul z9.s, z23.s, z3.s[1]\n" - "fmla z15.s, p1/M, z18.s, z9.s\n" - "ld1rqb { z18.b }, p1/Z, [x26, #16]\n" - "fmul z9.s, z23.s, z3.s[2]\n" - "fmul z3.s, z23.s, z3.s[3]\n" - "fmla z12.s, p1/M, z7.s, z9.s\n" - "mov z9.s, #0x0\n" - "ld1h { z7.s }, p0/Z, [x22]\n" - ".inst 0x451f98a9 // smmla z9.s, z5.b, z31.b\n" - "fmla z0.s, p1/M, z22.s, z3.s\n" - "mov z22.s, #0x0\n" - "ld1h { z3.s }, p0/Z, [x21]\n" - ".inst 0x450698b6 // smmla z22.s, z5.b, z6.b\n" - "ld1rqb { z5.b }, p1/Z, [x26, #32]\n" - "fcvt z7.s, p1/m, z7.h\n" - "fcvt z3.s, p1/m, z3.h\n" - ".inst 0x450e98a9 // smmla z9.s, z5.b, z14.b\n" - ".inst 0x450298b6 // smmla z22.s, z5.b, z2.b\n" - "ld1rqb { z5.b }, p1/Z, [x26, #64]\n" - "mov z7.q, z7.q[0]\n" - "mov z3.q, z3.q[0]\n" - ".inst 0x451e98a9 // smmla z9.s, z5.b, z30.b\n" - ".inst 0x451598b6 // smmla z22.s, z5.b, z21.b\n" - "ld1rqb { z5.b }, p1/Z, [x26, #96]\n" - ".inst 0x450498a9 // smmla z9.s, z5.b, z4.b\n" - ".inst 0x451198b6 // smmla z22.s, z5.b, z17.b\n" - "uzp1 z5.d, z9.d, z22.d\n" - "scvtf z5.s, p1/m, z5.s\n" - "uzp2 z22.d, z9.d, z22.d\n" - "fmul z9.s, z23.s, z7.s[0]\n" - "scvtf z22.s, p1/m, z22.s\n" - "fmla z13.s, p1/M, z5.s, z9.s\n" - "ld1rqb { z9.b }, p1/Z, [x25]\n" - "fmul z5.s, z23.s, z7.s[1]\n" - "fmla z1.s, p1/M, z22.s, z5.s\n" - "mov z5.s, #0x0\n" - "mov z22.s, #0x0\n" - ".inst 0x451f9a45 // smmla z5.s, z18.b, z31.b\n" - ".inst 0x45069a56 // smmla z22.s, z18.b, z6.b\n" - "ld1rqb { z18.b }, p1/Z, [x26, #48]\n" - ".inst 0x450e9a45 // smmla z5.s, z18.b, z14.b\n" - ".inst 0x45029a56 // smmla z22.s, z18.b, z2.b\n" - "ld1rqb { z18.b }, p1/Z, [x26, #80]\n" - ".inst 0x451e9a45 // smmla z5.s, z18.b, z30.b\n" - ".inst 0x45159a56 // smmla z22.s, z18.b, z21.b\n" - "ld1rqb { z18.b }, p1/Z, [x26, #112]\n" - "add x26, x26, #0x88\n" - ".inst 0x45049a45 // smmla z5.s, z18.b, z4.b\n" - ".inst 0x45119a56 // smmla z22.s, z18.b, z17.b\n" - "uzp1 z18.d, z5.d, z22.d\n" - "scvtf z18.s, p1/m, z18.s\n" - "uzp2 z22.d, z5.d, z22.d\n" - "fmul z5.s, z23.s, z7.s[2]\n" - "fmul z7.s, z23.s, z7.s[3]\n" - "scvtf z22.s, p1/m, z22.s\n" - "fmla z20.s, p1/M, z18.s, z5.s\n" - "ld1rqb { z18.b }, p1/Z, [x25, #16]\n" - "ld1h { z5.s }, p0/Z, [x20]\n" - "fcvt z5.s, p1/m, z5.h\n" - "fmla z25.s, p1/M, z22.s, z7.s\n" - "mov z22.s, #0x0\n" - "mov z7.s, #0x0\n" - ".inst 0x451f9936 // smmla z22.s, z9.b, z31.b\n" - ".inst 0x45069927 // smmla z7.s, z9.b, z6.b\n" - "ld1rqb { z9.b }, p1/Z, [x25, #32]\n" - "mov z5.q, z5.q[0]\n" - ".inst 0x450e9936 // smmla z22.s, z9.b, z14.b\n" - ".inst 0x45029927 // smmla z7.s, z9.b, z2.b\n" - "ld1rqb { z9.b }, p1/Z, [x25, #64]\n" - ".inst 0x451e9936 // smmla z22.s, z9.b, z30.b\n" - ".inst 0x45159927 // smmla z7.s, z9.b, z21.b\n" - "ld1rqb { z9.b }, p1/Z, [x25, #96]\n" - ".inst 0x45049936 // smmla z22.s, z9.b, z4.b\n" - ".inst 0x45119927 // smmla z7.s, z9.b, z17.b\n" - "uzp1 z9.d, z22.d, z7.d\n" - "scvtf z9.s, p1/m, z9.s\n" - "uzp2 z22.d, z22.d, z7.d\n" - "fmul z7.s, z23.s, z3.s[0]\n" - "scvtf z22.s, p1/m, z22.s\n" - "fmla z11.s, p1/M, z9.s, z7.s\n" - "ld1rqb { z9.b }, p1/Z, [x24]\n" - "fmul z7.s, z23.s, z3.s[1]\n" - "fmla z16.s, p1/M, z22.s, z7.s\n" - "mov z22.s, #0x0\n" - "mov z7.s, #0x0\n" - ".inst 0x451f9a56 // smmla z22.s, z18.b, z31.b\n" - ".inst 0x45069a47 // smmla z7.s, z18.b, z6.b\n" - "ld1rqb { z18.b }, p1/Z, [x25, #48]\n" - ".inst 0x450e9a56 // smmla z22.s, z18.b, z14.b\n" - ".inst 0x45029a47 // smmla z7.s, z18.b, z2.b\n" - "ld1rqb { z18.b }, p1/Z, [x25, #80]\n" - ".inst 0x451e9a56 // smmla z22.s, z18.b, z30.b\n" - ".inst 0x45159a47 // smmla z7.s, z18.b, z21.b\n" - "ld1rqb { z18.b }, p1/Z, [x25, #112]\n" - "add x25, x25, #0x88\n" - ".inst 0x45049a56 // smmla z22.s, z18.b, z4.b\n" - ".inst 0x45119a47 // smmla z7.s, z18.b, z17.b\n" - "uzp1 z18.d, z22.d, z7.d\n" - "scvtf z18.s, p1/m, z18.s\n" - "uzp2 z7.d, z22.d, z7.d\n" - "fmul z22.s, z23.s, z3.s[2]\n" - "fmul z3.s, z23.s, z3.s[3]\n" - "scvtf z7.s, p1/m, z7.s\n" - "fmla z19.s, p1/M, z18.s, z22.s\n" - "ld1rqb { z18.b }, p1/Z, [x24, #16]\n" - "fmul z22.s, z23.s, z5.s[0]\n" - "fmla z26.s, p1/M, z7.s, z3.s\n" - "mov z3.s, #0x0\n" - "mov z7.s, #0x0\n" - ".inst 0x451f9923 // smmla z3.s, z9.b, z31.b\n" - ".inst 0x45069927 // smmla z7.s, z9.b, z6.b\n" - "ld1rqb { z9.b }, p1/Z, [x24, #32]\n" - ".inst 0x450e9923 // smmla z3.s, z9.b, z14.b\n" - ".inst 0x45029927 // smmla z7.s, z9.b, z2.b\n" - "mov z9.s, #0x0\n" - ".inst 0x451f9a49 // smmla z9.s, z18.b, z31.b\n" - "mov z31.s, #0x0\n" - ".inst 0x45069a5f // smmla z31.s, z18.b, z6.b\n" - "ld1rqb { z6.b }, p1/Z, [x24, #48]\n" - "ld1rqb { z18.b }, p1/Z, [x24, #64]\n" - ".inst 0x450e98c9 // smmla z9.s, z6.b, z14.b\n" - "fmul z14.s, z23.s, z5.s[1]\n" - ".inst 0x450298df // smmla z31.s, z6.b, z2.b\n" - "ld1rqb { z6.b }, p1/Z, [x24, #80]\n" - "fmul z2.s, z23.s, z5.s[2]\n" - "fmul z23.s, z23.s, z5.s[3]\n" - ".inst 0x451e9a43 // smmla z3.s, z18.b, z30.b\n" - ".inst 0x45159a47 // smmla z7.s, z18.b, z21.b\n" - "ld1rqb { z5.b }, p1/Z, [x24, #96]\n" - ".inst 0x451e98c9 // smmla z9.s, z6.b, z30.b\n" - ".inst 0x451598df // smmla z31.s, z6.b, z21.b\n" - "ld1rqb { z18.b }, p1/Z, [x24, #112]\n" - "add x24, x24, #0x88\n" - ".inst 0x450498a3 // smmla z3.s, z5.b, z4.b\n" - ".inst 0x451198a7 // smmla z7.s, z5.b, z17.b\n" - ".inst 0x45049a49 // smmla z9.s, z18.b, z4.b\n" - ".inst 0x45119a5f // smmla z31.s, z18.b, z17.b\n" - "uzp1 z18.d, z3.d, z7.d\n" - "uzp2 z5.d, z3.d, z7.d\n" - "scvtf z18.s, p1/m, z18.s\n" - "uzp1 z6.d, z9.d, z31.d\n" - "uzp2 z9.d, z9.d, z31.d\n" - "scvtf z5.s, p1/m, z5.s\n" - "fmla z8.s, p1/M, z18.s, z22.s\n" - "scvtf z6.s, p1/m, z6.s\n" - "scvtf z9.s, p1/m, z9.s\n" - "fmla z29.s, p1/M, z5.s, z14.s\n" - "fmla z27.s, p1/M, z6.s, z2.s\n" - "fmla z10.s, p1/M, z9.s, z23.s\n" - "bgt 3b\n" - "mov x20, %x[res_ptr]\n" - "subs x10, x10, #0x8\n" - "add %x[res_ptr], %x[res_ptr], #0x20\n" - "st1w { z24.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z15.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z12.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z0.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z13.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z1.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z20.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z25.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z11.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z16.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z19.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z26.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z8.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z29.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z27.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z10.s }, p1, [x20]\n" - "bne 2b\n" - "mov x20, #0x4\n" - "sub x13, x13, #0x10\n" - "cmp x13, #0x10\n" - "mov %x[res_ptr], x9\n" - "madd %x[a_ptr], x20, x12, %x[a_ptr]\n" - "bge 1b\n" - "4:" // Row loop skip - "cbz x13, 9f\n" - "5:" // Row tail: Row loop - "add x25, %x[b_ptr], #0x10\n" - "mov x24, %x[nc]\n" - "add x23, %x[res_ptr], %x[res_stride], LSL #2\n" - "6:" // Row tail: Column loop - "mov z24.b, #0x0\n" - "mov z15.b, #0x0\n" - "add x28, %x[a_ptr], #0x8\n" - "mov x22, %x[nb]\n" - "mov z12.b, #0x0\n" - "mov z0.b, #0x0\n" - "7:" // Row tail: Block loop - "ld1b { z3.b }, p1/Z, [x25]\n" - "ld1b { z6.b }, p1/Z, [x25, #1, MUL VL]\n" - "mov z2.s, #0x0\n" - "mov z25.s, #0x0\n" - "ld1rqb { z26.b }, p1/Z, [x28]\n" - "ld1rqb { z21.b }, p1/Z, [x28, #16]\n" - "mov z27.s, #0x0\n" - "mov z19.s, #0x0\n" - "ld1b { z29.b }, p1/Z, [x25, #2, MUL VL]\n" - "ld1b { z16.b }, p1/Z, [x25, #3, MUL VL]\n" - "sub x21, x25, #0x10\n" - "sub x20, x28, #0x8\n" - "lsl z20.b, z3.b, #0x4\n" - "lsl z4.b, z6.b, #0x4\n" - "ld1rqb { z10.b }, p1/Z, [x28, #32]\n" - "ld1rqb { z23.b }, p1/Z, [x28, #48]\n" - "and z3.b, z3.b, #0xf0\n" - "and z6.b, z6.b, #0xf0\n" - "ld1rqb { z11.b }, p1/Z, [x28, #64]\n" - "ld1rqb { z7.b }, p1/Z, [x28, #80]\n" - "lsl z8.b, z29.b, #0x4\n" - "lsl z14.b, z16.b, #0x4\n" - "ld1rqb { z18.b }, p1/Z, [x28, #96]\n" - "ld1rqb { z30.b }, p1/Z, [x28, #112]\n" - ".inst 0x45149b42 // smmla z2.s, z26.b, z20.b\n" - ".inst 0x45049b59 // smmla z25.s, z26.b, z4.b\n" - "and z29.b, z29.b, #0xf0\n" - "ld1h { z17.s }, p1/Z, [x21]\n" - ".inst 0x45149abb // smmla z27.s, z21.b, z20.b\n" - ".inst 0x45049ab3 // smmla z19.s, z21.b, z4.b\n" - "and z16.b, z16.b, #0xf0\n" - "ld1h { z4.s }, p0/Z, [x20]\n" - "subs x22, x22, #0x1\n" - "add x28, x28, #0x88\n" - "fcvt z17.s, p1/m, z17.h\n" - "add x25, x25, #0x90\n" - ".inst 0x45089942 // smmla z2.s, z10.b, z8.b\n" - ".inst 0x450e9959 // smmla z25.s, z10.b, z14.b\n" - "fcvt z4.s, p1/m, z4.h\n" - ".inst 0x45089afb // smmla z27.s, z23.b, z8.b\n" - ".inst 0x450e9af3 // smmla z19.s, z23.b, z14.b\n" - "fscale z17.s, p1/m, z17.s, z28.s\n" - "mov z4.q, z4.q[0]\n" - ".inst 0x45039962 // smmla z2.s, z11.b, z3.b\n" - ".inst 0x45069979 // smmla z25.s, z11.b, z6.b\n" - "fmul z23.s, z17.s, z4.s[0]\n" - "fmul z9.s, z17.s, z4.s[1]\n" - "fmul z21.s, z17.s, z4.s[2]\n" - "fmul z4.s, z17.s, z4.s[3]\n" - ".inst 0x450398fb // smmla z27.s, z7.b, z3.b\n" - ".inst 0x450698f3 // smmla z19.s, z7.b, z6.b\n" - ".inst 0x451d9a42 // smmla z2.s, z18.b, z29.b\n" - ".inst 0x45109a59 // smmla z25.s, z18.b, z16.b\n" - ".inst 0x451d9bdb // smmla z27.s, z30.b, z29.b\n" - ".inst 0x45109bd3 // smmla z19.s, z30.b, z16.b\n" - "uzp1 z31.d, z2.d, z25.d\n" - "uzp2 z13.d, z2.d, z25.d\n" - "scvtf z31.s, p1/m, z31.s\n" - "uzp1 z17.d, z27.d, z19.d\n" - "uzp2 z18.d, z27.d, z19.d\n" - "scvtf z13.s, p1/m, z13.s\n" - "fmla z24.s, p1/M, z31.s, z23.s\n" - "scvtf z17.s, p1/m, z17.s\n" - "scvtf z18.s, p1/m, z18.s\n" - "fmla z15.s, p1/M, z13.s, z9.s\n" - "fmla z12.s, p1/M, z17.s, z21.s\n" - "fmla z0.s, p1/M, z18.s, z4.s\n" - "bgt 7b\n" - "mov x20, %x[res_ptr]\n" - "cmp x13, #0x1\n" - "st1w { z24.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "cmp x13, #0x2\n" - "st1w { z15.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "cmp x13, #0x3\n" - "st1w { z12.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "st1w { z0.s }, p1, [x20]\n" - "8:" // Row tail: Accumulator store skip - "subs x24, x24, #0x8\n" - "add %x[res_ptr], %x[res_ptr], #0x20\n" - "bne 6b\n" - "subs x13, x13, #0x4\n" - "add %x[a_ptr], %x[a_ptr], x12\n" - "mov %x[res_ptr], x23\n" - "bgt 5b\n" - "9:" // Row tail: Row loop skip - : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr) - : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc) - : "cc", "memory", "p0", "p1", "x9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" - ); - return; - } -#endif // #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) -#elif defined(__AVX2__) || defined(__AVX512F__) - { - const block_q4_0x8 * b_ptr_start = (const block_q4_0x8 *)vx; - const block_q8_0x4 * a_ptr_start = (const block_q8_0x4 *)vy; - int64_t b_nb = n / QK4_0; - int64_t y = 0; - // Mask to mask out nibbles from packed bytes - const __m256i m4b = _mm256_set1_epi8(0x0F); - const __m128i loadMask = _mm_blend_epi32(_mm_setzero_si128(), _mm_set1_epi32(0xFFFFFFFF), 3); - // Lookup table to convert signed nibbles to signed bytes - __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0)); - signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0); - // Permute mask used for easier vector processing at later stages - __m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4); - int64_t xstart = 0; - int anr = nr - nr%16; // Used to align nr with boundary of 16 - #ifdef __AVX512F__ - int anc = nc - nc%16; // Used to align nc with boundary of 16 - // Mask to mask out nibbles from packed bytes expanded to 512 bit length - const __m512i m4bexpanded = _mm512_set1_epi8(0x0F); - // Lookup table to convert signed nibbles to signed bytes expanded to 512 bit length - __m512i signextendlutexpanded = _mm512_inserti32x8(_mm512_castsi256_si512(signextendlut), signextendlut, 1); - - // Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation - for (; y < anr / 4; y += 4) { - - const block_q8_0x4 * a_ptrs[4]; - - a_ptrs[0] = a_ptr_start + (y * nb); - for (int i = 0; i < 3; ++i) { - a_ptrs[i + 1] = a_ptrs[i] + nb; - } - - // Take group of two block_q4_0x8 structures at each pass of the loop and perform dot product operation - for (int64_t x = 0; x < anc / 8; x += 2) { - - const block_q4_0x8 * b_ptr_0 = b_ptr_start + ((x) * b_nb); - const block_q4_0x8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb); - - // Master FP accumulators - __m512 acc_rows[16]; - for (int i = 0; i < 16; i++) { - acc_rows[i] = _mm512_setzero_ps(); - } - - for (int64_t b = 0; b < nb; b++) { - // Load the sixteen block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....BE,BF - const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs)); - const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 32)); - const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 64)); - const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 96)); - - const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs)); - const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 32)); - const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 64)); - const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 96)); - - // Save the values in the following vectors in the formats B0B1B4B5B8B9BCBD, B2B3B6B7BABBBEBF for further processing and storing of values - const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240); - const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240); - const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240); - const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240); - - const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240); - const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240); - const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240); - const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240); - - const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1); - const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1); - const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1); - const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1); - - // 4-bit -> 8-bit - Sign is maintained - const __m512i rhs_mat_014589CD_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) B8(0-7) B9(0-7) BC(0-7) BD(0-7) - const __m512i rhs_mat_2367ABEF_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) BA(0-7) BB(0-7) BE(0-7) BF(0-7) - - const __m512i rhs_mat_014589CD_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) B8(8-15) B9(8-15) BC(8-15) BD(8-15) - const __m512i rhs_mat_2367ABEF_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) BA(8-15) BB(8-15) BE(8-15) BF(8-15) - - const __m512i rhs_mat_014589CD_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) B8(16-23) B9(16-23) BC(16-23) BD(16-23) - const __m512i rhs_mat_2367ABEF_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) BA(16-23) BB(16-23) BE(16-23) BF(16-23) - - const __m512i rhs_mat_014589CD_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) B8(24-31) B9(24-31) BC(24-31) BD(24-31) - const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31) - - // Shuffle pattern one - right side input - const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3) - const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3) - - const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11) - const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11) - - const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19) - const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19) - - const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27) - const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27) - - // Shuffle pattern two - right side input - - const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7) - const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7) - - const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15) - const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15) - - const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23) - const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23) - - const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31) - const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31) - - // Scale values - Load the weight scale values of two block_q4_0x8 - const __m512 col_scale_f32 = LM_GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d); - - // Process LHS in pairs of rows - for (int rp = 0; rp < 4; rp++) { - - // Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3 - // Loaded as set of 128 bit vectors and repeated and stored into a 256 bit vector before again repeating into 512 bit vector - __m256i lhs_mat_ymm_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs))); - __m256i lhs_mat_ymm_01_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 0); - __m256i lhs_mat_ymm_23_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 17); - __m256i lhs_mat_ymm_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32))); - __m256i lhs_mat_ymm_01_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 0); - __m256i lhs_mat_ymm_23_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 17); - __m256i lhs_mat_ymm_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64))); - __m256i lhs_mat_ymm_01_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 0); - __m256i lhs_mat_ymm_23_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 17); - __m256i lhs_mat_ymm_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96))); - __m256i lhs_mat_ymm_01_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 0); - __m256i lhs_mat_ymm_23_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 17); - - __m512i lhs_mat_01_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_0), lhs_mat_ymm_01_0, 1); - __m512i lhs_mat_23_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_0), lhs_mat_ymm_23_0, 1); - __m512i lhs_mat_01_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_1), lhs_mat_ymm_01_1, 1); - __m512i lhs_mat_23_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_1), lhs_mat_ymm_23_1, 1); - __m512i lhs_mat_01_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_2), lhs_mat_ymm_01_2, 1); - __m512i lhs_mat_23_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_2), lhs_mat_ymm_23_2, 1); - __m512i lhs_mat_01_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_3), lhs_mat_ymm_01_3, 1); - __m512i lhs_mat_23_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_3), lhs_mat_ymm_23_3, 1); - - // Shuffle pattern one - left side input - - const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) - const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) - - const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) - const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) - - const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) - const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) - - const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) - const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) - - // Shuffle pattern two - left side input - - const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) - const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) - - const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) - const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) - - const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) - const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) - - const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) - const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) - - // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane - // Resembles MMLAs into 2x2 matrices in ARM Version - __m512i iacc_mat_00_sp1 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1)); - __m512i iacc_mat_01_sp1 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1)); - __m512i iacc_mat_10_sp1 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1)); - __m512i iacc_mat_11_sp1 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1)); - __m512i iacc_mat_00_sp2 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2)); - __m512i iacc_mat_01_sp2 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2)); - __m512i iacc_mat_10_sp2 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2)); - __m512i iacc_mat_11_sp2 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2)); - - // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block - __m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); - __m512i iacc_mat_01 = _mm512_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2); - __m512i iacc_mat_10 = _mm512_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2); - __m512i iacc_mat_11 = _mm512_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2); - - - // Straighten out to make 4 row vectors - __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, 78)); - __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01); - __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, 78)); - __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11); - - // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes - const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptrs[rp][b].d), loadMask), 68); - const __m512 row_scale_f32 = LM_GGML_F32Cx16_REPEAT_LOAD(row_scale_f16); - - // Multiply with appropiate scales and accumulate - acc_rows[rp * 4] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]); - acc_rows[rp * 4 + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]); - acc_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]); - acc_rows[rp * 4 + 3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]); - } - } - - // Store the accumulated values - for (int i = 0; i < 16; i++) { - _mm512_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]); - } - } - } - // Take a block_q8_0x4 structures at each pass of the loop and perform dot product operation - for (; y < nr / 4; y ++) { - - const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb); - - // Take group of two block_q4_0x8 structures at each pass of the loop and perform dot product operation - for (int64_t x = 0; x < anc / 8; x += 2) { - - const block_q4_0x8 * b_ptr_0 = b_ptr_start + ((x) * b_nb); - const block_q4_0x8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb); - - // Master FP accumulators - __m512 acc_rows[4]; - for (int i = 0; i < 4; i++) { - acc_rows[i] = _mm512_setzero_ps(); - } - - for (int64_t b = 0; b < nb; b++) { - // Load the sixteen block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....BE,BF - const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs)); - const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 32)); - const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 64)); - const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 96)); - - const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs)); - const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 32)); - const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 64)); - const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 96)); - - // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of valuess - const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240); - const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240); - const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240); - const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240); - - const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240); - const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240); - const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240); - const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240); - - const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1); - const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1); - const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1); - const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1); - - // 4-bit -> 8-bit - Sign is maintained - const __m512i rhs_mat_014589CD_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) B8(0-7) B9(0-7) BC(0-7) BD(0-7) - const __m512i rhs_mat_2367ABEF_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) BA(0-7) BB(0-7) BE(0-7) BF(0-7) - - const __m512i rhs_mat_014589CD_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) B8(8-15) B9(8-15) BC(8-15) BD(8-15) - const __m512i rhs_mat_2367ABEF_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) BA(8-15) BB(8-15) BE(8-15) BF(8-15) - - const __m512i rhs_mat_014589CD_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) B8(16-23) B9(16-23) BC(16-23) BD(16-23) - const __m512i rhs_mat_2367ABEF_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) BA(16-23) BB(16-23) BE(16-23) BF(16-23) - - const __m512i rhs_mat_014589CD_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) B8(24-31) B9(24-31) BC(24-31) BD(24-31) - const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31) - - // Shuffle pattern one - right side input - const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3) - const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3) - - const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11) - const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11) - - const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19) - const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19) - - const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27) - const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27) - - // Shuffle pattern two - right side input - - const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7) - const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7) - - const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15) - const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15) - - const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23) - const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23) - - const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31) - const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31) - - - // Scale values - Load the weight scale values of two block_q4_0x8 - const __m512 col_scale_f32 = LM_GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d); - - // Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3 - // Loaded as set of 128 bit vectors and repeated and stored into a 256 bit vector before again repeating into 512 bit vector - __m256i lhs_mat_ymm_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs))); - __m256i lhs_mat_ymm_01_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 0); - __m256i lhs_mat_ymm_23_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 17); - __m256i lhs_mat_ymm_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32))); - __m256i lhs_mat_ymm_01_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 0); - __m256i lhs_mat_ymm_23_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 17); - __m256i lhs_mat_ymm_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64))); - __m256i lhs_mat_ymm_01_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 0); - __m256i lhs_mat_ymm_23_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 17); - __m256i lhs_mat_ymm_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96))); - __m256i lhs_mat_ymm_01_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 0); - __m256i lhs_mat_ymm_23_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 17); - - __m512i lhs_mat_01_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_0), lhs_mat_ymm_01_0, 1); - __m512i lhs_mat_23_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_0), lhs_mat_ymm_23_0, 1); - __m512i lhs_mat_01_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_1), lhs_mat_ymm_01_1, 1); - __m512i lhs_mat_23_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_1), lhs_mat_ymm_23_1, 1); - __m512i lhs_mat_01_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_2), lhs_mat_ymm_01_2, 1); - __m512i lhs_mat_23_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_2), lhs_mat_ymm_23_2, 1); - __m512i lhs_mat_01_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_3), lhs_mat_ymm_01_3, 1); - __m512i lhs_mat_23_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_3), lhs_mat_ymm_23_3, 1); - - // Shuffle pattern one - left side input - - const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) - const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) - - const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) - const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) - - const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) - const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) - - const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) - const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) - - // Shuffle pattern two - left side input - - const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) - const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) - - const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) - const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) - - const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) - const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) - - const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) - const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) - - // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane - // Resembles MMLAs into 2x2 matrices in ARM Version - __m512i iacc_mat_00_sp1 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1)); - __m512i iacc_mat_01_sp1 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1)); - __m512i iacc_mat_10_sp1 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1)); - __m512i iacc_mat_11_sp1 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1)); - __m512i iacc_mat_00_sp2 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2)); - __m512i iacc_mat_01_sp2 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2)); - __m512i iacc_mat_10_sp2 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2)); - __m512i iacc_mat_11_sp2 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2)); - - // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block - __m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); - __m512i iacc_mat_01 = _mm512_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2); - __m512i iacc_mat_10 = _mm512_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2); - __m512i iacc_mat_11 = _mm512_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2); - - - // Straighten out to make 4 row vectors - __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, 78)); - __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01); - __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, 78)); - __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11); - - // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes - const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptr[b].d), loadMask), 68); - const __m512 row_scale_f32 = LM_GGML_F32Cx16_REPEAT_LOAD(row_scale_f16); - - // Multiply with appropiate scales and accumulate - acc_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]); - acc_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]); - acc_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]); - acc_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]); - } - - // Store the accumulated values - for (int i = 0; i < 4; i++) { - _mm512_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]); - } - } - } - if (anc != nc) { - xstart = anc/8; - y = 0; - } - #endif // __AVX512F__ - - // Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation - - for (; y < anr / 4; y += 4) { - const block_q8_0x4 * a_ptrs[4]; - - a_ptrs[0] = a_ptr_start + (y * nb); - for (int i = 0; i < 3; ++i) { - a_ptrs[i + 1] = a_ptrs[i] + nb; - } - - // Take group of eight block_q4_0x8 structures at each pass of the loop and perform dot product operation - for (int64_t x = xstart; x < nc / 8; x++) { - - const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb); - - // Master FP accumulators - __m256 acc_rows[16]; - for (int i = 0; i < 16; i++) { - acc_rows[i] = _mm256_setzero_ps(); - } - - for (int64_t b = 0; b < nb; b++) { - // Load the eight block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7 - const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs)); - const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32)); - const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64)); - const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96)); - - // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values - const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240); - const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240); - const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240); - const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240); - - // 4-bit -> 8-bit - Sign is maintained - const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) - const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) - - const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) - const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) - - const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) - const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) - - const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) - const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) - - // Shuffle pattern one - right side input - const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) - const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) - - const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) - const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) - - const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) - const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) - - const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) - const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) - - // Shuffle pattern two - right side input - - const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) - const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) - - const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) - const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) - - const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) - const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) - - const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) - const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) - - // Scale values - Load the wight scale values of block_q4_0x8 - const __m256 col_scale_f32 = LM_GGML_F32Cx8_LOAD(b_ptr[b].d); - - // Process LHS in groups of four - for (int rp = 0; rp < 4; rp++) { - // Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3 - // Loaded as set of 128 bit vectors and repeated into a 256 bit vector - __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs))); - __m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0); - __m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17); - __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32))); - __m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0); - __m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17); - __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64))); - __m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0); - __m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17); - __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96))); - __m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0); - __m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17); - - // Shuffle pattern one - left side input - const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) - const __m256i lhs_mat_23_0_sp1 = _mm256_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) - - const __m256i lhs_mat_01_1_sp1 = _mm256_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) - const __m256i lhs_mat_23_1_sp1 = _mm256_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) - - const __m256i lhs_mat_01_2_sp1 = _mm256_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) - const __m256i lhs_mat_23_2_sp1 = _mm256_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) - - const __m256i lhs_mat_01_3_sp1 = _mm256_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) - const __m256i lhs_mat_23_3_sp1 = _mm256_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) - - // Shuffle pattern two - left side input - const __m256i lhs_mat_01_0_sp2 = _mm256_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) - const __m256i lhs_mat_23_0_sp2 = _mm256_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) - - const __m256i lhs_mat_01_1_sp2 = _mm256_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) - const __m256i lhs_mat_23_1_sp2 = _mm256_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) - - const __m256i lhs_mat_01_2_sp2 = _mm256_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) - const __m256i lhs_mat_23_2_sp2 = _mm256_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) - - const __m256i lhs_mat_01_3_sp2 = _mm256_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) - const __m256i lhs_mat_23_3_sp2 = _mm256_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) - - // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane - // Resembles MMLAs into 2x2 matrices in ARM Version - __m256i iacc_mat_00_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1)); - __m256i iacc_mat_01_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1)); - __m256i iacc_mat_10_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1)); - __m256i iacc_mat_11_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1)); - __m256i iacc_mat_00_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2)); - __m256i iacc_mat_01_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2)); - __m256i iacc_mat_10_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2)); - __m256i iacc_mat_11_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2)); - - // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block - __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); - __m256i iacc_mat_01 = _mm256_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2); - __m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2); - __m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2); - - // Straighten out to make 4 row vectors - __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204); - __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204); - __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204); - __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204); - - // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes - const __m256 row_scale_f32 = LM_GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask); - - // Multiply with appropiate scales and accumulate - acc_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]); - acc_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]); - acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]); - acc_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]); - } - } - - // Store the accumulated values - for (int i = 0; i < 16; i++) { - _mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]); - } - } - } - - // Take a block_q8_0x4 structures at each pass of the loop and perform dot product operation - for (; y < nr / 4; y ++) { - - const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb); - - // Load the eight block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7 - for (int64_t x = xstart; x < nc / 8; x++) { - - const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb); - - // Master FP accumulators - __m256 acc_rows[4]; - for (int i = 0; i < 4; i++) { - acc_rows[i] = _mm256_setzero_ps(); - } - - for (int64_t b = 0; b < nb; b++) { - // Load the eight block_q8_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7 - const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs)); - const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32)); - const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64)); - const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96)); - - // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of valuess - const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240); - const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240); - const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240); - const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240); - - // 4-bit -> 8-bit - Sign is maintained - const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) - const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) - - const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) - const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) - - const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) - const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) - - const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) - const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) - - // Shuffle pattern one - right side input - const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) - const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) - - const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) - const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) - - const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) - const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) - - const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) - const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) - - // Shuffle pattern two - right side input - - const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) - const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) - - const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) - const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) - - const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) - const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) - - const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) - const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) - - // Scale values - Load the wight scale values of block_q4_0x8 - const __m256 col_scale_f32 = LM_GGML_F32Cx8_LOAD(b_ptr[b].d); - - // Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3 - // Loaded as set of 128 bit vectors and repeated into a 256 bit vector - __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs))); - __m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0); - __m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17); - __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32))); - __m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0); - __m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17); - __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64))); - __m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0); - __m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17); - __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96))); - __m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0); - __m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17); - - // Shuffle pattern one - left side input - - const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) - const __m256i lhs_mat_23_0_sp1 = _mm256_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) - - const __m256i lhs_mat_01_1_sp1 = _mm256_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) - const __m256i lhs_mat_23_1_sp1 = _mm256_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) - - const __m256i lhs_mat_01_2_sp1 = _mm256_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) - const __m256i lhs_mat_23_2_sp1 = _mm256_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) - - const __m256i lhs_mat_01_3_sp1 = _mm256_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) - const __m256i lhs_mat_23_3_sp1 = _mm256_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) - - // Shuffle pattern two - left side input - - const __m256i lhs_mat_01_0_sp2 = _mm256_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) - const __m256i lhs_mat_23_0_sp2 = _mm256_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) - - const __m256i lhs_mat_01_1_sp2 = _mm256_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) - const __m256i lhs_mat_23_1_sp2 = _mm256_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) - - const __m256i lhs_mat_01_2_sp2 = _mm256_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) - const __m256i lhs_mat_23_2_sp2 = _mm256_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) - - const __m256i lhs_mat_01_3_sp2 = _mm256_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) - const __m256i lhs_mat_23_3_sp2 = _mm256_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) - - // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane - // Resembles MMLAs into 2x2 matrices in ARM Version - __m256i iacc_mat_00_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1)); - __m256i iacc_mat_01_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1)); - __m256i iacc_mat_10_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1)); - __m256i iacc_mat_11_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1)); - __m256i iacc_mat_00_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2)); - __m256i iacc_mat_01_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2)); - __m256i iacc_mat_10_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2)); - __m256i iacc_mat_11_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2)); - - // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block - __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); - __m256i iacc_mat_01 = _mm256_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2); - __m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2); - __m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2); - - - // Straighten out to make 4 row vectors - __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204); - __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204); - __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204); - __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204); - - // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes - const __m256 row_scale_f32 = LM_GGML_F32Cx8_REPEAT_LOAD(a_ptr[b].d, loadMask); - - // Multiply with appropiate scales and accumulate - acc_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]); - acc_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]); - acc_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]); - acc_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]); - } - - // Store the accumulated values - for (int i = 0; i < 4; i++) { - _mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]); - } - } - } - return; - } -#elif defined(__riscv_v_intrinsic) - if (__riscv_vlenb() >= QK4_0) { - const size_t vl = QK4_0; - - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); - vfloat32m1_t sumf0 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); - vfloat32m1_t sumf1 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); - vfloat32m1_t sumf2 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); - vfloat32m1_t sumf3 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); - for (int l = 0; l < nb; l++) { - const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4); - const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4); - const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4); - const vint8m2_t rhs_vec_lo_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 0); - const vint8m2_t rhs_vec_lo_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 1); - const vint8m2_t rhs_vec_hi_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 0); - const vint8m2_t rhs_vec_hi_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 1); - - // vector version needs Zvfhmin extension - const float a_scales[4] = { - LM_GGML_FP16_TO_FP32(a_ptr[l].d[0]), - LM_GGML_FP16_TO_FP32(a_ptr[l].d[1]), - LM_GGML_FP16_TO_FP32(a_ptr[l].d[2]), - LM_GGML_FP16_TO_FP32(a_ptr[l].d[3]) - }; - const float b_scales[8] = { - LM_GGML_FP16_TO_FP32(b_ptr[l].d[0]), - LM_GGML_FP16_TO_FP32(b_ptr[l].d[1]), - LM_GGML_FP16_TO_FP32(b_ptr[l].d[2]), - LM_GGML_FP16_TO_FP32(b_ptr[l].d[3]), - LM_GGML_FP16_TO_FP32(b_ptr[l].d[4]), - LM_GGML_FP16_TO_FP32(b_ptr[l].d[5]), - LM_GGML_FP16_TO_FP32(b_ptr[l].d[6]), - LM_GGML_FP16_TO_FP32(b_ptr[l].d[7]) - }; - const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4); - - const int64_t A0 = *(const int64_t *)&a_ptr[l].qs[0]; - const int64_t A4 = *(const int64_t *)&a_ptr[l].qs[32]; - const int64_t A8 = *(const int64_t *)&a_ptr[l].qs[64]; - const int64_t Ac = *(const int64_t *)&a_ptr[l].qs[96]; - __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment - vint16m4_t sumi_l0; - { - const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A0, vl / 4)); - const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A4, vl / 4)); - const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A8, vl / 4)); - const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ac, vl / 4)); - const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); - const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); - const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); - const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); - - sumi_l0 = sumi_hi_m; - } - - { - const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l0)); - const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); - const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); - const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); - const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); - const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); - const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); - const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); - const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); - const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); - const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); - const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); - const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); - - const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[0], vl / 4); - sumf0 = __riscv_vfmacc_vv_f32m1(sumf0, tmp1, b_scales_vec, vl / 4); - } - - const int64_t A1 = *(const int64_t *)&a_ptr[l].qs[8]; - const int64_t A5 = *(const int64_t *)&a_ptr[l].qs[40]; - const int64_t A9 = *(const int64_t *)&a_ptr[l].qs[72]; - const int64_t Ad = *(const int64_t *)&a_ptr[l].qs[104]; - __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment - vint16m4_t sumi_l1; - { - const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A1, vl / 4)); - const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A5, vl / 4)); - const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A9, vl / 4)); - const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ad, vl / 4)); - const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); - const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); - const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); - const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); - - sumi_l1 = sumi_hi_m; - } - - { - const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l1)); - const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); - const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); - const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); - const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); - const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); - const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); - const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); - const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); - const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); - const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); - const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); - const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); - - const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[1], vl / 4); - sumf1 = __riscv_vfmacc_vv_f32m1(sumf1, tmp1, b_scales_vec, vl / 4); - } - - const int64_t A2 = *(const int64_t *)&a_ptr[l].qs[16]; - const int64_t A6 = *(const int64_t *)&a_ptr[l].qs[48]; - const int64_t Aa = *(const int64_t *)&a_ptr[l].qs[80]; - const int64_t Ae = *(const int64_t *)&a_ptr[l].qs[112]; - __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment - vint16m4_t sumi_l2; - { - const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A2, vl / 4)); - const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A6, vl / 4)); - const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Aa, vl / 4)); - const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ae, vl / 4)); - const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); - const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); - const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); - const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); - - sumi_l2 = sumi_hi_m; - } - - { - const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l2)); - const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); - const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); - const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); - const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); - const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); - const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); - const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); - const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); - const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); - const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); - const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); - const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); - - const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[2], vl / 4); - sumf2 = __riscv_vfmacc_vv_f32m1(sumf2, tmp1, b_scales_vec, vl / 4); - } - - const int64_t A3 = *(const int64_t *)&a_ptr[l].qs[24]; - const int64_t A7 = *(const int64_t *)&a_ptr[l].qs[56]; - const int64_t Ab = *(const int64_t *)&a_ptr[l].qs[88]; - const int64_t Af = *(const int64_t *)&a_ptr[l].qs[120]; - __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment - vint16m4_t sumi_l3; - { - const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A3, vl / 4)); - const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A7, vl / 4)); - const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ab, vl / 4)); - const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Af, vl / 4)); - const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); - const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); - const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); - const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); - - sumi_l3 = sumi_hi_m; - } - - { - const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l3)); - const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); - const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); - const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); - const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); - const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); - const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); - const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); - const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); - const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); - const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); - const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); - const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); - - const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[3], vl / 4); - sumf3 = __riscv_vfmacc_vv_f32m1(sumf3, tmp1, b_scales_vec, vl / 4); - } - } - __riscv_vse32_v_f32m1(&s[(y * 4 + 0) * bs + x * ncols_interleaved], sumf0, vl / 4); - __riscv_vse32_v_f32m1(&s[(y * 4 + 1) * bs + x * ncols_interleaved], sumf1, vl / 4); - __riscv_vse32_v_f32m1(&s[(y * 4 + 2) * bs + x * ncols_interleaved], sumf2, vl / 4); - __riscv_vse32_v_f32m1(&s[(y * 4 + 3) * bs + x * ncols_interleaved], sumf3, vl / 4); - } - } - - return; - } -#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) - float sumf[4][8]; - int sumi; - - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; - } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); - sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + - (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; - } - sumf[m][j] += sumi * LM_GGML_FP16_TO_FP32(b_ptr[l].d[j]) * LM_GGML_FP16_TO_FP32(a_ptr[l].d[m]); - } - } - } - } - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; - } - } - } -} diff --git a/cpp/ggml-aarch64.h b/cpp/ggml-aarch64.h index 90a23d77..3b8c8b37 100644 --- a/cpp/ggml-aarch64.h +++ b/cpp/ggml-aarch64.h @@ -1,9 +1,5 @@ -// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd. #pragma once -#define LM_GGML_COMMON_DECL_C -#include "ggml-common.h" - #include "ggml.h" // GGML internal header @@ -12,27 +8,11 @@ extern "C" { #endif -// Quantization -void quantize_q8_0_4x4(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); -void quantize_q8_0_4x8(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); - -void quantize_mat_q8_0(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t nrows, int64_t n_per_row, int64_t blck_size_interleave); - // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization") size_t quantize_q4_0_4x4(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q4_0_4x8(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q4_0_8x8(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -// GEMV -void lm_ggml_gemv_q4_0_4x4_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc); -void lm_ggml_gemv_q4_0_4x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc); -void lm_ggml_gemv_q4_0_8x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc); - -// GEMM -void lm_ggml_gemm_q4_0_4x4_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc); -void lm_ggml_gemm_q4_0_4x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc); -void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc); - #ifdef __cplusplus } #endif diff --git a/cpp/ggml-backend-reg.cpp b/cpp/ggml-backend-reg.cpp new file mode 100644 index 00000000..4009d5fc --- /dev/null +++ b/cpp/ggml-backend-reg.cpp @@ -0,0 +1,204 @@ +#include "ggml-backend-impl.h" +#include "ggml-backend.h" +#include "ggml-cpu.h" +#include "ggml-impl.h" +#include +#include + +// Backend registry + +#ifdef LM_GGML_USE_CUDA +#include "ggml-cuda.h" +#endif + +#ifdef LM_GGML_USE_METAL +#include + +#if !TARGET_OS_SIMULATOR +#include "ggml-metal.h" +#endif + +#endif + +#ifdef LM_GGML_USE_SYCL +#include "ggml-sycl.h" +#endif + +#ifdef LM_GGML_USE_VULKAN +#include "ggml-vulkan.h" +#endif + +#ifdef LM_GGML_USE_BLAS +#include "ggml-blas.h" +#endif + +#ifdef LM_GGML_USE_RPC +#include "ggml-rpc.h" +#endif + +#ifdef LM_GGML_USE_AMX +# include "ggml-amx.h" +#endif + +#ifdef LM_GGML_USE_CANN +#include "ggml-cann.h" +#endif + +#ifdef LM_GGML_USE_KOMPUTE +#include "ggml-kompute.h" +#endif + +struct lm_ggml_backend_registry { + std::vector backends; + std::vector devices; + + lm_ggml_backend_registry() { +#ifdef LM_GGML_USE_CUDA + register_backend(lm_ggml_backend_cuda_reg()); +#endif +#ifdef LM_GGML_USE_METAL + +#if !TARGET_OS_SIMULATOR + register_backend(lm_ggml_backend_metal_reg()); +#endif + +#endif +#ifdef LM_GGML_USE_SYCL + register_backend(lm_ggml_backend_sycl_reg()); +#endif +#ifdef LM_GGML_USE_VULKAN + register_backend(lm_ggml_backend_vk_reg()); +#endif +#ifdef LM_GGML_USE_CANN + register_backend(lm_ggml_backend_cann_reg()); +#endif +#ifdef LM_GGML_USE_BLAS + register_backend(lm_ggml_backend_blas_reg()); +#endif +#ifdef LM_GGML_USE_RPC + register_backend(lm_ggml_backend_rpc_reg()); +#endif +#ifdef LM_GGML_USE_AMX + register_backend(lm_ggml_backend_amx_reg()); +#endif +#ifdef LM_GGML_USE_KOMPUTE + register_backend(lm_ggml_backend_kompute_reg()); +#endif + + register_backend(lm_ggml_backend_cpu_reg()); + } + + void register_backend(lm_ggml_backend_reg_t reg) { + if (!reg) { + return; + } + +#ifndef NDEBUG + LM_GGML_LOG_DEBUG("%s: registered backend %s (%zu devices)\n", + __func__, lm_ggml_backend_reg_name(reg), lm_ggml_backend_reg_dev_count(reg)); +#endif + backends.push_back(reg); + for (size_t i = 0; i < lm_ggml_backend_reg_dev_count(reg); i++) { + register_device(lm_ggml_backend_reg_dev_get(reg, i)); + } + } + + void register_device(lm_ggml_backend_dev_t device) { +#ifndef NDEBUG + LM_GGML_LOG_DEBUG("%s: registered device %s (%s)\n", __func__, lm_ggml_backend_dev_name(device), lm_ggml_backend_dev_description(device)); +#endif + devices.push_back(device); + } +}; + +static lm_ggml_backend_registry & get_reg() { + static lm_ggml_backend_registry reg; + return reg; +} + +// Internal API +void lm_ggml_backend_register(lm_ggml_backend_reg_t reg) { + get_reg().register_backend(reg); +} + +void lm_ggml_backend_device_register(lm_ggml_backend_dev_t device) { + get_reg().register_device(device); +} + +// Backend (reg) enumeration +size_t lm_ggml_backend_reg_count() { + return get_reg().backends.size(); +} + +lm_ggml_backend_reg_t lm_ggml_backend_reg_get(size_t index) { + LM_GGML_ASSERT(index < lm_ggml_backend_reg_count()); + return get_reg().backends[index]; +} + +lm_ggml_backend_reg_t lm_ggml_backend_reg_by_name(const char * name) { + for (size_t i = 0; i < lm_ggml_backend_reg_count(); i++) { + lm_ggml_backend_reg_t reg = lm_ggml_backend_reg_get(i); + if (std::strcmp(lm_ggml_backend_reg_name(reg), name) == 0) { + return reg; + } + } + return NULL; +} + +// Device enumeration +size_t lm_ggml_backend_dev_count() { + return get_reg().devices.size(); +} + +lm_ggml_backend_dev_t lm_ggml_backend_dev_get(size_t index) { + LM_GGML_ASSERT(index < lm_ggml_backend_dev_count()); + return get_reg().devices[index]; +} + +lm_ggml_backend_dev_t lm_ggml_backend_dev_by_name(const char * name) { + for (size_t i = 0; i < lm_ggml_backend_dev_count(); i++) { + lm_ggml_backend_dev_t dev = lm_ggml_backend_dev_get(i); + if (strcmp(lm_ggml_backend_dev_name(dev), name) == 0) { + return dev; + } + } + return NULL; +} + +lm_ggml_backend_dev_t lm_ggml_backend_dev_by_type(enum lm_ggml_backend_dev_type type) { + for (size_t i = 0; i < lm_ggml_backend_dev_count(); i++) { + lm_ggml_backend_dev_t dev = lm_ggml_backend_dev_get(i); + if (lm_ggml_backend_dev_type(dev) == type) { + return dev; + } + } + return NULL; +} + +// Convenience functions +lm_ggml_backend_t lm_ggml_backend_init_by_name(const char * name, const char * params) { + lm_ggml_backend_dev_t dev = lm_ggml_backend_dev_by_name(name); + if (!dev) { + return NULL; + } + return lm_ggml_backend_dev_init(dev, params); +} + +lm_ggml_backend_t lm_ggml_backend_init_by_type(enum lm_ggml_backend_dev_type type, const char * params) { + lm_ggml_backend_dev_t dev = lm_ggml_backend_dev_by_type(type); + if (!dev) { + return NULL; + } + return lm_ggml_backend_dev_init(dev, params); +} + +lm_ggml_backend_t lm_ggml_backend_init_best(void) { + lm_ggml_backend_dev_t dev = lm_ggml_backend_dev_by_type(LM_GGML_BACKEND_DEVICE_TYPE_GPU); + if (!dev) { + dev = lm_ggml_backend_dev_by_type(LM_GGML_BACKEND_DEVICE_TYPE_CPU); + } + if (!dev) { + return NULL; + } + return lm_ggml_backend_dev_init(dev, NULL); +} diff --git a/cpp/ggml-backend.cpp b/cpp/ggml-backend.cpp index d8cb76d0..3f45df78 100644 --- a/cpp/ggml-backend.cpp +++ b/cpp/ggml-backend.cpp @@ -8,6 +8,7 @@ #include #endif +#include "ggml-backend.h" #include "ggml-backend-impl.h" #include "ggml-alloc.h" #include "ggml-impl.h" @@ -524,808 +525,6 @@ void * lm_ggml_backend_reg_get_proc_address(lm_ggml_backend_reg_t reg, const cha return reg->iface.get_proc_address(reg, name); } -// Backend registry - -#ifdef LM_GGML_USE_CUDA -#include "ggml-cuda.h" -#endif - -#ifdef LM_GGML_USE_METAL -#include "ggml-metal.h" -#endif - -#ifdef LM_GGML_USE_SYCL -#include "ggml-sycl.h" -#endif - -#ifdef LM_GGML_USE_VULKAN -#include "ggml-vulkan.h" -#endif - -#ifdef LM_GGML_USE_BLAS -#include "ggml-blas.h" -#endif - -#ifdef LM_GGML_USE_RPC -#include "ggml-rpc.h" -#endif - -#ifndef __AMX_INT8__ -#undef LM_GGML_USE_AMX -#endif - -#ifdef LM_GGML_USE_AMX -# include "ggml-amx.h" -#endif - -#ifdef LM_GGML_USE_CANN -#include "ggml-cann.h" -#endif - -#ifdef LM_GGML_USE_KOMPUTE -#include "ggml-kompute.h" -#endif - -struct lm_ggml_backend_registry { - std::vector backends; - std::vector devices; - - lm_ggml_backend_registry() { -#ifdef LM_GGML_USE_CUDA - register_backend(lm_ggml_backend_cuda_reg()); -#endif -#ifdef LM_GGML_USE_METAL -#include -#if !TARGET_OS_SIMULATOR - register_backend(lm_ggml_backend_metal_reg()); -#endif -#endif -#ifdef LM_GGML_USE_SYCL - register_backend(lm_ggml_backend_sycl_reg()); -#endif -#ifdef LM_GGML_USE_VULKAN - register_backend(lm_ggml_backend_vk_reg()); -#endif -#ifdef LM_GGML_USE_CANN - register_backend(lm_ggml_backend_cann_reg()); -#endif -#ifdef LM_GGML_USE_BLAS - register_backend(lm_ggml_backend_blas_reg()); -#endif -#ifdef LM_GGML_USE_RPC - register_backend(lm_ggml_backend_rpc_reg()); -#endif -#ifdef LM_GGML_USE_AMX - register_backend(lm_ggml_backend_amx_reg()); -#endif -#ifdef LM_GGML_USE_KOMPUTE - register_backend(lm_ggml_backend_kompute_reg()); -#endif - - register_backend(lm_ggml_backend_cpu_reg()); - } - - void register_backend(lm_ggml_backend_reg_t reg) { -#ifndef NDEBUG - LM_GGML_LOG_DEBUG("%s: registered backend %s (%zu devices)\n", - __func__, lm_ggml_backend_reg_name(reg), lm_ggml_backend_reg_dev_count(reg)); -#endif - backends.push_back(reg); - for (size_t i = 0; i < lm_ggml_backend_reg_dev_count(reg); i++) { - register_device(lm_ggml_backend_reg_dev_get(reg, i)); - } - } - - void register_device(lm_ggml_backend_dev_t device) { -#ifndef NDEBUG - LM_GGML_LOG_DEBUG("%s: registered device %s (%s)\n", __func__, lm_ggml_backend_dev_name(device), lm_ggml_backend_dev_description(device)); -#endif - devices.push_back(device); - } -}; - -static lm_ggml_backend_registry & get_reg() { - static lm_ggml_backend_registry reg; - return reg; -} - -// Internal API -void lm_ggml_backend_register(lm_ggml_backend_reg_t reg) { - get_reg().register_backend(reg); -} - -void lm_ggml_backend_device_register(lm_ggml_backend_dev_t device) { - get_reg().register_device(device); -} - -// Backend (reg) enumeration -size_t lm_ggml_backend_reg_count() { - return get_reg().backends.size(); -} - -lm_ggml_backend_reg_t lm_ggml_backend_reg_get(size_t index) { - LM_GGML_ASSERT(index < lm_ggml_backend_reg_count()); - return get_reg().backends[index]; -} - -lm_ggml_backend_reg_t lm_ggml_backend_reg_by_name(const char * name) { - for (size_t i = 0; i < lm_ggml_backend_reg_count(); i++) { - lm_ggml_backend_reg_t reg = lm_ggml_backend_reg_get(i); - if (strcmp(lm_ggml_backend_reg_name(reg), name) == 0) { - return reg; - } - } - return NULL; -} - -// Device enumeration -size_t lm_ggml_backend_dev_count() { - return get_reg().devices.size(); -} - -lm_ggml_backend_dev_t lm_ggml_backend_dev_get(size_t index) { - LM_GGML_ASSERT(index < lm_ggml_backend_dev_count()); - return get_reg().devices[index]; -} - -lm_ggml_backend_dev_t lm_ggml_backend_dev_by_name(const char * name) { - for (size_t i = 0; i < lm_ggml_backend_dev_count(); i++) { - lm_ggml_backend_dev_t dev = lm_ggml_backend_dev_get(i); - if (strcmp(lm_ggml_backend_dev_name(dev), name) == 0) { - return dev; - } - } - return NULL; -} - -lm_ggml_backend_dev_t lm_ggml_backend_dev_by_type(enum lm_ggml_backend_dev_type type) { - for (size_t i = 0; i < lm_ggml_backend_dev_count(); i++) { - lm_ggml_backend_dev_t dev = lm_ggml_backend_dev_get(i); - if (lm_ggml_backend_dev_type(dev) == type) { - return dev; - } - } - return NULL; -} - -// Convenience functions -lm_ggml_backend_t lm_ggml_backend_init_by_name(const char * name, const char * params) { - lm_ggml_backend_dev_t dev = lm_ggml_backend_dev_by_name(name); - if (!dev) { - return NULL; - } - return lm_ggml_backend_dev_init(dev, params); -} - -lm_ggml_backend_t lm_ggml_backend_init_by_type(enum lm_ggml_backend_dev_type type, const char * params) { - lm_ggml_backend_dev_t dev = lm_ggml_backend_dev_by_type(type); - if (!dev) { - return NULL; - } - return lm_ggml_backend_dev_init(dev, params); -} - -lm_ggml_backend_t lm_ggml_backend_init_best(void) { - lm_ggml_backend_dev_t dev = lm_ggml_backend_dev_by_type(LM_GGML_BACKEND_DEVICE_TYPE_GPU); - if (!dev) { - dev = lm_ggml_backend_dev_by_type(LM_GGML_BACKEND_DEVICE_TYPE_CPU); - } - if (!dev) { - return NULL; - } - return lm_ggml_backend_dev_init(dev, NULL); -} - -// CPU backend - buffer - -static void * lm_ggml_backend_cpu_buffer_get_base(lm_ggml_backend_buffer_t buffer) { - uintptr_t data = (uintptr_t)buffer->context; - - // align the buffer - if (data % TENSOR_ALIGNMENT != 0) { - data = LM_GGML_PAD(data, TENSOR_ALIGNMENT); - } - - return (void *)data; -} - -static void lm_ggml_backend_cpu_buffer_free_buffer(lm_ggml_backend_buffer_t buffer) { - lm_ggml_aligned_free(buffer->context, buffer->size); -} - -static void lm_ggml_backend_cpu_buffer_memset_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { - memset((char *)tensor->data + offset, value, size); - - LM_GGML_UNUSED(buffer); -} - -static void lm_ggml_backend_cpu_buffer_set_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - memcpy((char *)tensor->data + offset, data, size); - - LM_GGML_UNUSED(buffer); -} - -static void lm_ggml_backend_cpu_buffer_get_tensor(lm_ggml_backend_buffer_t buffer, const struct lm_ggml_tensor * tensor, void * data, size_t offset, size_t size) { - memcpy(data, (const char *)tensor->data + offset, size); - - LM_GGML_UNUSED(buffer); -} - -static bool lm_ggml_backend_cpu_buffer_cpy_tensor(lm_ggml_backend_buffer_t buffer, const struct lm_ggml_tensor * src, struct lm_ggml_tensor * dst) { - if (lm_ggml_backend_buffer_is_host(src->buffer)) { - memcpy(dst->data, src->data, lm_ggml_nbytes(src)); - return true; - } - return false; - - LM_GGML_UNUSED(buffer); -} - -static void lm_ggml_backend_cpu_buffer_clear(lm_ggml_backend_buffer_t buffer, uint8_t value) { - memset(buffer->context, value, buffer->size); -} - -static const struct lm_ggml_backend_buffer_i lm_ggml_backend_cpu_buffer_i = { - /* .free_buffer = */ lm_ggml_backend_cpu_buffer_free_buffer, - /* .get_base = */ lm_ggml_backend_cpu_buffer_get_base, - /* .init_tensor = */ NULL, // no initialization required - /* .memset_tensor = */ lm_ggml_backend_cpu_buffer_memset_tensor, - /* .set_tensor = */ lm_ggml_backend_cpu_buffer_set_tensor, - /* .get_tensor = */ lm_ggml_backend_cpu_buffer_get_tensor, - /* .cpy_tensor = */ lm_ggml_backend_cpu_buffer_cpy_tensor, - /* .clear = */ lm_ggml_backend_cpu_buffer_clear, - /* .reset = */ NULL, -}; - -static const struct lm_ggml_backend_buffer_i lm_ggml_backend_cpu_buffer_from_ptr_i = { - /* .free_buffer = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed - /* .get_base = */ lm_ggml_backend_cpu_buffer_get_base, - /* .init_tensor = */ NULL, // no initialization required - /* .memset_tensor = */ lm_ggml_backend_cpu_buffer_memset_tensor, - /* .set_tensor = */ lm_ggml_backend_cpu_buffer_set_tensor, - /* .get_tensor = */ lm_ggml_backend_cpu_buffer_get_tensor, - /* .cpy_tensor = */ lm_ggml_backend_cpu_buffer_cpy_tensor, - /* .clear = */ lm_ggml_backend_cpu_buffer_clear, - /* .reset = */ NULL, -}; - -// CPU backend - buffer type - -static const char * lm_ggml_backend_cpu_buffer_type_get_name(lm_ggml_backend_buffer_type_t buft) { - return "CPU"; - - LM_GGML_UNUSED(buft); -} - -static lm_ggml_backend_buffer_t lm_ggml_backend_cpu_buffer_type_alloc_buffer(lm_ggml_backend_buffer_type_t buft, size_t size) { - void * data = lm_ggml_aligned_malloc(size); - - if (data == NULL) { - LM_GGML_LOG_ERROR("%s: failed to allocate buffer of size %zu\n", __func__, size); - return NULL; - } - - return lm_ggml_backend_buffer_init(buft, lm_ggml_backend_cpu_buffer_i, data, size); -} - -static size_t lm_ggml_backend_cpu_buffer_type_get_alignment(lm_ggml_backend_buffer_type_t buft) { - return TENSOR_ALIGNMENT; - - LM_GGML_UNUSED(buft); -} - -static bool lm_ggml_backend_cpu_buffer_type_is_host(lm_ggml_backend_buffer_type_t buft) { - return true; - - LM_GGML_UNUSED(buft); -} - -lm_ggml_backend_buffer_type_t lm_ggml_backend_cpu_buffer_type(void) { - static struct lm_ggml_backend_buffer_type lm_ggml_backend_cpu_buffer_type = { - /* .iface = */ { - /* .get_name = */ lm_ggml_backend_cpu_buffer_type_get_name, - /* .alloc_buffer = */ lm_ggml_backend_cpu_buffer_type_alloc_buffer, - /* .get_alignment = */ lm_ggml_backend_cpu_buffer_type_get_alignment, - /* .get_max_size = */ NULL, // defaults to SIZE_MAX - /* .get_alloc_size = */ NULL, // defaults to lm_ggml_nbytes - /* .is_host = */ lm_ggml_backend_cpu_buffer_type_is_host, - }, - /* .device = */ lm_ggml_backend_reg_dev_get(lm_ggml_backend_cpu_reg(), 0), - /* .context = */ NULL, - }; - - return &lm_ggml_backend_cpu_buffer_type; -} - -static const char * lm_ggml_backend_cpu_buffer_from_ptr_type_get_name(lm_ggml_backend_buffer_type_t buft) { - return "CPU_Mapped"; - - LM_GGML_UNUSED(buft); -} - -static lm_ggml_backend_buffer_type_t lm_ggml_backend_cpu_buffer_from_ptr_type(void) { - static struct lm_ggml_backend_buffer_type lm_ggml_backend_cpu_buffer_type = { - /* .iface = */ { - /* .get_name = */ lm_ggml_backend_cpu_buffer_from_ptr_type_get_name, - /* .alloc_buffer = */ lm_ggml_backend_cpu_buffer_type_alloc_buffer, - /* .get_alignment = */ lm_ggml_backend_cpu_buffer_type_get_alignment, - /* .get_max_size = */ NULL, // defaults to SIZE_MAX - /* .get_alloc_size = */ NULL, // defaults to lm_ggml_nbytes - /* .is_host = */ lm_ggml_backend_cpu_buffer_type_is_host, - }, - /* .device = */ lm_ggml_backend_reg_dev_get(lm_ggml_backend_cpu_reg(), 0), - /* .context = */ NULL, - }; - - return &lm_ggml_backend_cpu_buffer_type; -} - -#ifdef LM_GGML_USE_CPU_HBM - -// buffer type HBM - -#include - -static const char * lm_ggml_backend_cpu_hbm_buffer_type_get_name(lm_ggml_backend_buffer_type_t buft) { - return "CPU_HBM"; - - LM_GGML_UNUSED(buft); -} - -static void lm_ggml_backend_cpu_hbm_buffer_free_buffer(lm_ggml_backend_buffer_t buffer) { - hbw_free(buffer->context); -} - -static lm_ggml_backend_buffer_t lm_ggml_backend_cpu_hbm_buffer_type_alloc_buffer(lm_ggml_backend_buffer_type_t buft, size_t size) { - void * ptr; - int result = hbw_posix_memalign(&ptr, lm_ggml_backend_cpu_buffer_type_get_alignment(buft), size); - if (result != 0) { - LM_GGML_LOG_ERROR("failed to allocate HBM buffer of size %zu\n", size); - return NULL; - } - - lm_ggml_backend_buffer_t buffer = lm_ggml_backend_cpu_buffer_from_ptr(ptr, size); - buffer->buft = buft; - buffer->iface.free_buffer = lm_ggml_backend_cpu_hbm_buffer_free_buffer; - - return buffer; -} - -lm_ggml_backend_buffer_type_t lm_ggml_backend_cpu_hbm_buffer_type(void) { - static struct lm_ggml_backend_buffer_type lm_ggml_backend_cpu_buffer_type_hbm = { - /* .iface = */ { - /* .get_name = */ lm_ggml_backend_cpu_hbm_buffer_type_get_name, - /* .alloc_buffer = */ lm_ggml_backend_cpu_hbm_buffer_type_alloc_buffer, - /* .get_alignment = */ lm_ggml_backend_cpu_buffer_type_get_alignment, - /* .get_max_size = */ NULL, // defaults to SIZE_MAX - /* .get_alloc_size = */ NULL, // defaults to lm_ggml_nbytes - /* .is_host = */ lm_ggml_backend_cpu_buffer_type_is_host, - }, - /* .context = */ NULL, - }; - - return &lm_ggml_backend_cpu_buffer_type_hbm; -} -#endif - -static lm_ggml_backend_buffer_type_t * lm_ggml_backend_cpu_get_extra_bufts(lm_ggml_backend_dev_t device) { - static lm_ggml_backend_buffer_type_t bufts[] = { -#ifdef LM_GGML_USE_CPU_HBM - lm_ggml_backend_cpu_hbm_buffer_type(), -#endif - NULL - }; - - return bufts; - - LM_GGML_UNUSED(device); -} - -// CPU backend - backend (stream) - -struct lm_ggml_backend_cpu_context { - int n_threads; - lm_ggml_threadpool_t threadpool; - - uint8_t * work_data; - size_t work_size; - - lm_ggml_abort_callback abort_callback; - void * abort_callback_data; -}; - -static const char * lm_ggml_backend_cpu_get_name(lm_ggml_backend_t backend) { - return "CPU"; - - LM_GGML_UNUSED(backend); -} - -static void lm_ggml_backend_cpu_free(lm_ggml_backend_t backend) { - struct lm_ggml_backend_cpu_context * cpu_ctx = (struct lm_ggml_backend_cpu_context *)backend->context; - delete[] cpu_ctx->work_data; - delete cpu_ctx; - delete backend; -} - -struct lm_ggml_backend_plan_cpu { - struct lm_ggml_cplan cplan; - struct lm_ggml_cgraph cgraph; -}; - -static lm_ggml_backend_graph_plan_t lm_ggml_backend_cpu_graph_plan_create(lm_ggml_backend_t backend, const struct lm_ggml_cgraph * cgraph) { - struct lm_ggml_backend_cpu_context * cpu_ctx = (struct lm_ggml_backend_cpu_context *)backend->context; - - struct lm_ggml_backend_plan_cpu * cpu_plan = new lm_ggml_backend_plan_cpu; - - cpu_plan->cplan = lm_ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool); - cpu_plan->cgraph = *cgraph; // FIXME: deep copy - - if (cpu_plan->cplan.work_size > 0) { - cpu_plan->cplan.work_data = new uint8_t[cpu_plan->cplan.work_size]; - if (cpu_plan->cplan.work_data == NULL) { - delete cpu_plan; - return NULL; - } - } - - cpu_plan->cplan.abort_callback = cpu_ctx->abort_callback; - cpu_plan->cplan.abort_callback_data = cpu_ctx->abort_callback_data; - - return cpu_plan; -} - -static void lm_ggml_backend_cpu_graph_plan_free(lm_ggml_backend_t backend, lm_ggml_backend_graph_plan_t plan) { - struct lm_ggml_backend_plan_cpu * cpu_plan = (struct lm_ggml_backend_plan_cpu *)plan; - - delete[] cpu_plan->cplan.work_data; - delete cpu_plan; - - LM_GGML_UNUSED(backend); -} - -static enum lm_ggml_status lm_ggml_backend_cpu_graph_plan_compute(lm_ggml_backend_t backend, lm_ggml_backend_graph_plan_t plan) { - struct lm_ggml_backend_plan_cpu * cpu_plan = (struct lm_ggml_backend_plan_cpu *)plan; - - return lm_ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan); - - LM_GGML_UNUSED(backend); -} - -static enum lm_ggml_status lm_ggml_backend_cpu_graph_compute(lm_ggml_backend_t backend, struct lm_ggml_cgraph * cgraph) { - struct lm_ggml_backend_cpu_context * cpu_ctx = (struct lm_ggml_backend_cpu_context *)backend->context; - - struct lm_ggml_cplan cplan = lm_ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool); - - if (cpu_ctx->work_size < cplan.work_size) { - delete[] cpu_ctx->work_data; - cpu_ctx->work_data = new uint8_t[cplan.work_size]; - if (cpu_ctx->work_data == NULL) { - cpu_ctx->work_size = 0; - return LM_GGML_STATUS_ALLOC_FAILED; - } - cpu_ctx->work_size = cplan.work_size; - } - cplan.work_data = (uint8_t *)cpu_ctx->work_data; - - cplan.abort_callback = cpu_ctx->abort_callback; - cplan.abort_callback_data = cpu_ctx->abort_callback_data; - - return lm_ggml_graph_compute(cgraph, &cplan); -} - -static const struct lm_ggml_backend_i lm_ggml_backend_cpu_i = { - /* .get_name = */ lm_ggml_backend_cpu_get_name, - /* .free = */ lm_ggml_backend_cpu_free, - /* .set_tensor_async = */ NULL, - /* .get_tensor_async = */ NULL, - /* .cpy_tensor_async = */ NULL, - /* .synchronize = */ NULL, - /* .graph_plan_create = */ lm_ggml_backend_cpu_graph_plan_create, - /* .graph_plan_free = */ lm_ggml_backend_cpu_graph_plan_free, - /* .graph_plan_update = */ NULL, - /* .graph_plan_compute = */ lm_ggml_backend_cpu_graph_plan_compute, - /* .graph_compute = */ lm_ggml_backend_cpu_graph_compute, - /* .event_record = */ NULL, - /* .event_wait = */ NULL, -}; - -static lm_ggml_guid_t lm_ggml_backend_cpu_guid(void) { - static lm_ggml_guid guid = { 0xaa, 0x67, 0xc7, 0x43, 0x96, 0xe6, 0xa3, 0x8a, 0xe3, 0xaf, 0xea, 0x92, 0x36, 0xbc, 0xfc, 0x89 }; - return &guid; -} - -lm_ggml_backend_t lm_ggml_backend_cpu_init(void) { - struct lm_ggml_backend_cpu_context * ctx = new lm_ggml_backend_cpu_context; - if (ctx == NULL) { - return NULL; - } - - ctx->n_threads = LM_GGML_DEFAULT_N_THREADS; - ctx->threadpool = NULL; - ctx->work_data = NULL; - ctx->work_size = 0; - ctx->abort_callback = NULL; - ctx->abort_callback_data = NULL; - - lm_ggml_backend_t cpu_backend = new lm_ggml_backend { - /* .guid = */ lm_ggml_backend_cpu_guid(), - /* .interface = */ lm_ggml_backend_cpu_i, - /* .device = */ lm_ggml_backend_reg_dev_get(lm_ggml_backend_cpu_reg(), 0), - /* .context = */ ctx, - }; - - if (cpu_backend == NULL) { - delete ctx; - return NULL; - } - - return cpu_backend; -} - -bool lm_ggml_backend_is_cpu(lm_ggml_backend_t backend) { - return backend != NULL && lm_ggml_guid_matches(backend->guid, lm_ggml_backend_cpu_guid()); -} - -void lm_ggml_backend_cpu_set_n_threads(lm_ggml_backend_t backend_cpu, int n_threads) { - LM_GGML_ASSERT(lm_ggml_backend_is_cpu(backend_cpu)); - - struct lm_ggml_backend_cpu_context * ctx = (struct lm_ggml_backend_cpu_context *)backend_cpu->context; - ctx->n_threads = n_threads; -} - -void lm_ggml_backend_cpu_set_threadpool(lm_ggml_backend_t backend_cpu, lm_ggml_threadpool_t threadpool) { - LM_GGML_ASSERT(lm_ggml_backend_is_cpu(backend_cpu)); - - struct lm_ggml_backend_cpu_context * ctx = (struct lm_ggml_backend_cpu_context *)backend_cpu->context; - - if (ctx->threadpool && ctx->threadpool != threadpool) { - // already had a different threadpool, pause/suspend it before switching - lm_ggml_threadpool_pause(ctx->threadpool); - } - ctx->threadpool = threadpool; -} - -void lm_ggml_backend_cpu_set_abort_callback(lm_ggml_backend_t backend_cpu, lm_ggml_abort_callback abort_callback, void * abort_callback_data) { - LM_GGML_ASSERT(lm_ggml_backend_is_cpu(backend_cpu)); - - struct lm_ggml_backend_cpu_context * ctx = (struct lm_ggml_backend_cpu_context *)backend_cpu->context; - ctx->abort_callback = abort_callback; - ctx->abort_callback_data = abort_callback_data; -} - -lm_ggml_backend_buffer_t lm_ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) { - LM_GGML_ASSERT((uintptr_t)ptr % TENSOR_ALIGNMENT == 0 && "buffer pointer must be aligned"); - return lm_ggml_backend_buffer_init(lm_ggml_backend_cpu_buffer_from_ptr_type(), lm_ggml_backend_cpu_buffer_from_ptr_i, ptr, size); -} - -// CPU backend - device - -struct lm_ggml_backend_cpu_device_context { - std::string description = "CPU"; - - lm_ggml_backend_cpu_device_context() { -#ifdef __APPLE__ - size_t len = 0; - if (!sysctlbyname("machdep.cpu.brand_string", NULL, &len, NULL, 0)) { - description.resize(len); - sysctlbyname("machdep.cpu.brand_string", &description[0], &len, NULL, 0); // NOLINT - } -#elif defined(__linux__) - FILE * f = fopen("/proc/cpuinfo", "r"); - if (f) { - char buf[1024]; - while (fgets(buf, sizeof(buf), f)) { - if (strncmp(buf, "model name", 10) == 0) { - char * p = strchr(buf, ':'); - if (p) { - p++; - while (std::isspace(*p)) { - p++; - } - while (std::isspace(p[strlen(p) - 1])) { - p[strlen(p) - 1] = '\0'; - } - description = p; - break; - } - } - } - fclose(f); - } -#elif defined(_WIN32) - HKEY hKey; - if (RegOpenKeyEx(HKEY_LOCAL_MACHINE, - TEXT("HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0"), - 0, - KEY_READ, - &hKey) == ERROR_SUCCESS) { - DWORD cpu_brand_size = 0; - if (RegQueryValueExA(hKey, - TEXT("ProcessorNameString"), - NULL, - NULL, - NULL, - &cpu_brand_size) == ERROR_SUCCESS) { - description.resize(cpu_brand_size); - if (RegQueryValueExA(hKey, - TEXT("ProcessorNameString"), - NULL, - NULL, - (LPBYTE)&description[0], // NOLINT - &cpu_brand_size) == ERROR_SUCCESS) { - if (description.find('\0') != std::string::npos) { - description.resize(description.find('\0')); - } - } - } - RegCloseKey(hKey); - } -#endif - } -}; - -static const char * lm_ggml_backend_cpu_device_get_name(lm_ggml_backend_dev_t dev) { - return "CPU"; - - LM_GGML_UNUSED(dev); -} - -static const char * lm_ggml_backend_cpu_device_get_description(lm_ggml_backend_dev_t dev) { - struct lm_ggml_backend_cpu_device_context * ctx = (struct lm_ggml_backend_cpu_device_context *)dev->context; - - return ctx->description.c_str(); -} - -static void lm_ggml_backend_cpu_device_get_memory(lm_ggml_backend_dev_t dev, size_t * free, size_t * total) { - // TODO - *free = 0; - *total = 0; - - LM_GGML_UNUSED(dev); -} - -static enum lm_ggml_backend_dev_type lm_ggml_backend_cpu_device_get_type(lm_ggml_backend_dev_t dev) { - return LM_GGML_BACKEND_DEVICE_TYPE_CPU; - - LM_GGML_UNUSED(dev); -} - -static void lm_ggml_backend_cpu_device_get_props(lm_ggml_backend_dev_t dev, struct lm_ggml_backend_dev_props * props) { - props->name = lm_ggml_backend_cpu_device_get_name(dev); - props->description = lm_ggml_backend_cpu_device_get_description(dev); - props->type = lm_ggml_backend_cpu_device_get_type(dev); - lm_ggml_backend_cpu_device_get_memory(dev, &props->memory_free, &props->memory_total); - props->caps = { - /* .async = */ false, - /* .host_buffer = */ false, - /* .buffer_from_host_ptr = */ true, - /* .events = */ false, - }; -} - -static lm_ggml_backend_t lm_ggml_backend_cpu_device_init_backend(lm_ggml_backend_dev_t dev, const char * params) { - return lm_ggml_backend_cpu_init(); - - LM_GGML_UNUSED(dev); - LM_GGML_UNUSED(params); -} - -static lm_ggml_backend_buffer_type_t lm_ggml_backend_cpu_device_get_buffer_type(lm_ggml_backend_dev_t dev) { - return lm_ggml_backend_cpu_buffer_type(); - - LM_GGML_UNUSED(dev); -} - -static lm_ggml_backend_buffer_t lm_ggml_backend_cpu_device_buffer_from_host_ptr(lm_ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { - return lm_ggml_backend_cpu_buffer_from_ptr(ptr, size); - - LM_GGML_UNUSED(dev); - LM_GGML_UNUSED(max_tensor_size); -} - -static bool lm_ggml_backend_cpu_device_supports_op(lm_ggml_backend_dev_t dev, const struct lm_ggml_tensor * op) { - switch (op->op) { - case LM_GGML_OP_CPY: - return - op->type != LM_GGML_TYPE_IQ2_XXS && - op->type != LM_GGML_TYPE_IQ2_XS && - op->type != LM_GGML_TYPE_IQ1_S && - op->type != LM_GGML_TYPE_IQ1_M; // missing type_traits.from_float - case LM_GGML_OP_MUL_MAT: - return op->src[1]->type == LM_GGML_TYPE_F32 || op->src[1]->type == lm_ggml_get_type_traits(op->src[0]->type)->vec_dot_type; - case LM_GGML_OP_ROPE_BACK: - return op->src[2] == NULL && (op->op_params[2] & 4) == 0; - case LM_GGML_OP_IM2COL_BACK: - return op->src[0]->type == LM_GGML_TYPE_F32 && op->src[1]->type == LM_GGML_TYPE_F32; - case LM_GGML_OP_OUT_PROD: - return (op->src[0]->type == LM_GGML_TYPE_F32 || lm_ggml_is_quantized(op->src[0]->type)) && op->src[1]->type == LM_GGML_TYPE_F32; - default: - return true; - } - - LM_GGML_UNUSED(dev); -} - -static bool lm_ggml_backend_cpu_device_supports_buft(lm_ggml_backend_dev_t dev, lm_ggml_backend_buffer_type_t buft) { - return lm_ggml_backend_buft_is_host(buft); - - LM_GGML_UNUSED(dev); -} - -static const struct lm_ggml_backend_device_i lm_ggml_backend_cpu_device_i = { - /* .get_name = */ lm_ggml_backend_cpu_device_get_name, - /* .get_description = */ lm_ggml_backend_cpu_device_get_description, - /* .get_memory = */ lm_ggml_backend_cpu_device_get_memory, - /* .get_type = */ lm_ggml_backend_cpu_device_get_type, - /* .get_props = */ lm_ggml_backend_cpu_device_get_props, - /* .init_backend = */ lm_ggml_backend_cpu_device_init_backend, - /* .get_buffer_type = */ lm_ggml_backend_cpu_device_get_buffer_type, - /* .get_host_buffer_type = */ NULL, - /* .buffer_from_host_ptr = */ lm_ggml_backend_cpu_device_buffer_from_host_ptr, - /* .supports_op = */ lm_ggml_backend_cpu_device_supports_op, - /* .supports_buft = */ lm_ggml_backend_cpu_device_supports_buft, - /* .offload_op = */ NULL, - /* .event_new = */ NULL, - /* .event_free = */ NULL, - /* .event_synchronize = */ NULL, -}; - -// CPU backend - backend (reg) - -static const char * lm_ggml_backend_cpu_reg_get_name(lm_ggml_backend_reg_t reg) { - return "CPU"; - - LM_GGML_UNUSED(reg); -} - -static size_t lm_ggml_backend_cpu_reg_get_device_count(lm_ggml_backend_reg_t reg) { - return 1; - - LM_GGML_UNUSED(reg); -} - -static lm_ggml_backend_dev_t lm_ggml_backend_cpu_reg_get_device(lm_ggml_backend_reg_t reg, size_t index) { - LM_GGML_ASSERT(index == 0); - - static lm_ggml_backend_cpu_device_context ctx; - static lm_ggml_backend_device lm_ggml_backend_cpu_device = { - /* .iface = */ lm_ggml_backend_cpu_device_i, - /* .reg = */ reg, - /* .context = */ &ctx, - }; - - return &lm_ggml_backend_cpu_device; -} - -static void * lm_ggml_backend_cpu_get_proc_address(lm_ggml_backend_reg_t reg, const char * name) { - if (strcmp(name, "lm_ggml_backend_set_n_threads") == 0) { - return (void *)lm_ggml_backend_cpu_set_n_threads; - } - if (strcmp(name, "lm_ggml_backend_dev_get_extra_bufts") == 0) { - return (void *)lm_ggml_backend_cpu_get_extra_bufts; - } - - return NULL; - - LM_GGML_UNUSED(reg); -} - -static const struct lm_ggml_backend_reg_i lm_ggml_backend_cpu_reg_i = { - /* .get_name = */ lm_ggml_backend_cpu_reg_get_name, - /* .get_device_count = */ lm_ggml_backend_cpu_reg_get_device_count, - /* .get_device = */ lm_ggml_backend_cpu_reg_get_device, - /* .get_proc_address = */ lm_ggml_backend_cpu_get_proc_address, -}; - -lm_ggml_backend_reg_t lm_ggml_backend_cpu_reg(void) { - static struct lm_ggml_backend_reg lm_ggml_backend_cpu_reg = { - /* .iface = */ lm_ggml_backend_cpu_reg_i, - /* .context = */ NULL, - }; - - return &lm_ggml_backend_cpu_reg; -} - // multi-buffer buffer struct lm_ggml_backend_multi_buffer_context { @@ -2250,7 +1449,7 @@ lm_ggml_backend_sched_t lm_ggml_backend_sched_new( bool parallel) { LM_GGML_ASSERT(n_backends > 0); LM_GGML_ASSERT(n_backends <= LM_GGML_SCHED_MAX_BACKENDS); - LM_GGML_ASSERT(lm_ggml_backend_is_cpu(backends[n_backends - 1])); // last backend must be CPU + LM_GGML_ASSERT(lm_ggml_backend_dev_type(lm_ggml_backend_get_device(backends[n_backends - 1])) == LM_GGML_BACKEND_DEVICE_TYPE_CPU); struct lm_ggml_backend_sched * sched = (lm_ggml_backend_sched *) calloc(1, sizeof(struct lm_ggml_backend_sched)); @@ -2645,3 +1844,154 @@ bool lm_ggml_backend_compare_graph_backend(lm_ggml_backend_t backend1, lm_ggml_b return true; } + +// CPU backend - buffer + +static void * lm_ggml_backend_cpu_buffer_get_base(lm_ggml_backend_buffer_t buffer) { + uintptr_t data = (uintptr_t)buffer->context; + + // align the buffer + if (data % TENSOR_ALIGNMENT != 0) { + data = LM_GGML_PAD(data, TENSOR_ALIGNMENT); + } + + return (void *)data; +} + +static void lm_ggml_backend_cpu_buffer_free_buffer(lm_ggml_backend_buffer_t buffer) { + lm_ggml_aligned_free(buffer->context, buffer->size); +} + +static void lm_ggml_backend_cpu_buffer_memset_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + memset((char *)tensor->data + offset, value, size); + + LM_GGML_UNUSED(buffer); +} + +static void lm_ggml_backend_cpu_buffer_set_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + memcpy((char *)tensor->data + offset, data, size); + + LM_GGML_UNUSED(buffer); +} + +static void lm_ggml_backend_cpu_buffer_get_tensor(lm_ggml_backend_buffer_t buffer, const struct lm_ggml_tensor * tensor, void * data, size_t offset, size_t size) { + memcpy(data, (const char *)tensor->data + offset, size); + + LM_GGML_UNUSED(buffer); +} + +static bool lm_ggml_backend_cpu_buffer_cpy_tensor(lm_ggml_backend_buffer_t buffer, const struct lm_ggml_tensor * src, struct lm_ggml_tensor * dst) { + if (lm_ggml_backend_buffer_is_host(src->buffer)) { + memcpy(dst->data, src->data, lm_ggml_nbytes(src)); + return true; + } + return false; + + LM_GGML_UNUSED(buffer); +} + +static void lm_ggml_backend_cpu_buffer_clear(lm_ggml_backend_buffer_t buffer, uint8_t value) { + memset(buffer->context, value, buffer->size); +} + +static const struct lm_ggml_backend_buffer_i lm_ggml_backend_cpu_buffer_i = { + /* .free_buffer = */ lm_ggml_backend_cpu_buffer_free_buffer, + /* .get_base = */ lm_ggml_backend_cpu_buffer_get_base, + /* .init_tensor = */ NULL, // no initialization required + /* .memset_tensor = */ lm_ggml_backend_cpu_buffer_memset_tensor, + /* .set_tensor = */ lm_ggml_backend_cpu_buffer_set_tensor, + /* .get_tensor = */ lm_ggml_backend_cpu_buffer_get_tensor, + /* .cpy_tensor = */ lm_ggml_backend_cpu_buffer_cpy_tensor, + /* .clear = */ lm_ggml_backend_cpu_buffer_clear, + /* .reset = */ NULL, +}; + +static const struct lm_ggml_backend_buffer_i lm_ggml_backend_cpu_buffer_from_ptr_i = { + /* .free_buffer = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed + /* .get_base = */ lm_ggml_backend_cpu_buffer_get_base, + /* .init_tensor = */ NULL, // no initialization required + /* .memset_tensor = */ lm_ggml_backend_cpu_buffer_memset_tensor, + /* .set_tensor = */ lm_ggml_backend_cpu_buffer_set_tensor, + /* .get_tensor = */ lm_ggml_backend_cpu_buffer_get_tensor, + /* .cpy_tensor = */ lm_ggml_backend_cpu_buffer_cpy_tensor, + /* .clear = */ lm_ggml_backend_cpu_buffer_clear, + /* .reset = */ NULL, +}; + +// CPU backend buffer type + +// this buffer type is defined here to make it available to all backends + +static const char * lm_ggml_backend_cpu_buffer_type_get_name(lm_ggml_backend_buffer_type_t buft) { + return "CPU"; + + LM_GGML_UNUSED(buft); +} + +static lm_ggml_backend_buffer_t lm_ggml_backend_cpu_buffer_type_alloc_buffer(lm_ggml_backend_buffer_type_t buft, size_t size) { + void * data = lm_ggml_aligned_malloc(size); + + if (data == NULL) { + LM_GGML_LOG_ERROR("%s: failed to allocate buffer of size %zu\n", __func__, size); + return NULL; + } + + return lm_ggml_backend_buffer_init(buft, lm_ggml_backend_cpu_buffer_i, data, size); +} + +static size_t lm_ggml_backend_cpu_buffer_type_get_alignment(lm_ggml_backend_buffer_type_t buft) { + return TENSOR_ALIGNMENT; + + LM_GGML_UNUSED(buft); +} + +static bool lm_ggml_backend_cpu_buffer_type_is_host(lm_ggml_backend_buffer_type_t buft) { + return true; + + LM_GGML_UNUSED(buft); +} + +lm_ggml_backend_buffer_type_t lm_ggml_backend_cpu_buffer_type(void) { + static struct lm_ggml_backend_buffer_type lm_ggml_backend_cpu_buffer_type = { + /* .iface = */ { + /* .get_name = */ lm_ggml_backend_cpu_buffer_type_get_name, + /* .alloc_buffer = */ lm_ggml_backend_cpu_buffer_type_alloc_buffer, + /* .get_alignment = */ lm_ggml_backend_cpu_buffer_type_get_alignment, + /* .get_max_size = */ NULL, // defaults to SIZE_MAX + /* .get_alloc_size = */ NULL, // defaults to lm_ggml_nbytes + /* .is_host = */ lm_ggml_backend_cpu_buffer_type_is_host, + }, + /* .device = */ NULL, // FIXME lm_ggml_backend_reg_dev_get(lm_ggml_backend_cpu_reg(), 0), + /* .context = */ NULL, + }; + + return &lm_ggml_backend_cpu_buffer_type; +} + +static const char * lm_ggml_backend_cpu_buffer_from_ptr_type_get_name(lm_ggml_backend_buffer_type_t buft) { + return "CPU_Mapped"; + + LM_GGML_UNUSED(buft); +} + +static lm_ggml_backend_buffer_type_t lm_ggml_backend_cpu_buffer_from_ptr_type(void) { + static struct lm_ggml_backend_buffer_type lm_ggml_backend_cpu_buffer_type = { + /* .iface = */ { + /* .get_name = */ lm_ggml_backend_cpu_buffer_from_ptr_type_get_name, + /* .alloc_buffer = */ lm_ggml_backend_cpu_buffer_type_alloc_buffer, + /* .get_alignment = */ lm_ggml_backend_cpu_buffer_type_get_alignment, + /* .get_max_size = */ NULL, // defaults to SIZE_MAX + /* .get_alloc_size = */ NULL, // defaults to lm_ggml_nbytes + /* .is_host = */ lm_ggml_backend_cpu_buffer_type_is_host, + }, + /* .device = */ NULL, // FIXME lm_ggml_backend_reg_dev_get(lm_ggml_backend_cpu_reg(), 0), + /* .context = */ NULL, + }; + + return &lm_ggml_backend_cpu_buffer_type; +} + +lm_ggml_backend_buffer_t lm_ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) { + LM_GGML_ASSERT((uintptr_t)ptr % TENSOR_ALIGNMENT == 0 && "buffer pointer must be aligned"); + return lm_ggml_backend_buffer_init(lm_ggml_backend_cpu_buffer_from_ptr_type(), lm_ggml_backend_cpu_buffer_from_ptr_i, ptr, size); +} diff --git a/cpp/ggml-backend.h b/cpp/ggml-backend.h index e85bdca8..6a5851c7 100644 --- a/cpp/ggml-backend.h +++ b/cpp/ggml-backend.h @@ -3,6 +3,20 @@ #include "ggml.h" #include "ggml-alloc.h" +#ifdef LM_GGML_BACKEND_SHARED +# if defined(_WIN32) && !defined(__MINGW32__) +# ifdef LM_GGML_BACKEND_BUILD +# define LM_GGML_BACKEND_API __declspec(dllexport) extern +# else +# define LM_GGML_BACKEND_API __declspec(dllimport) extern +# endif +# else +# define LM_GGML_BACKEND_API __attribute__ ((visibility ("default"))) extern +# endif +#else +# define LM_GGML_BACKEND_API extern +#endif + #ifdef __cplusplus extern "C" { #endif @@ -305,27 +319,10 @@ extern "C" { LM_GGML_API void lm_ggml_backend_tensor_alloc(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor, void * addr); LM_GGML_API void lm_ggml_backend_view_init(struct lm_ggml_tensor * tensor); - // - // CPU backend - // - - LM_GGML_API lm_ggml_backend_t lm_ggml_backend_cpu_init(void); - - LM_GGML_API bool lm_ggml_backend_is_cpu (lm_ggml_backend_t backend); - LM_GGML_API void lm_ggml_backend_cpu_set_n_threads (lm_ggml_backend_t backend_cpu, int n_threads); - LM_GGML_API void lm_ggml_backend_cpu_set_threadpool (lm_ggml_backend_t backend_cpu, lm_ggml_threadpool_t threadpool); - LM_GGML_API void lm_ggml_backend_cpu_set_abort_callback(lm_ggml_backend_t backend_cpu, lm_ggml_abort_callback abort_callback, void * abort_callback_data); - - // Create a backend buffer from an existing pointer + // CPU buffer types are always available LM_GGML_API lm_ggml_backend_buffer_t lm_ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size); LM_GGML_API lm_ggml_backend_buffer_type_t lm_ggml_backend_cpu_buffer_type(void); - LM_GGML_API lm_ggml_backend_reg_t lm_ggml_backend_cpu_reg(void); - -#ifdef LM_GGML_USE_CPU_HBM - LM_GGML_API lm_ggml_backend_buffer_type_t lm_ggml_backend_cpu_hbm_buffer_type(void); -#endif - #ifdef __cplusplus } #endif diff --git a/cpp/ggml-cpu-aarch64.c b/cpp/ggml-cpu-aarch64.c new file mode 100644 index 00000000..40e033b9 --- /dev/null +++ b/cpp/ggml-cpu-aarch64.c @@ -0,0 +1,3560 @@ +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT +// + +#define LM_GGML_COMMON_IMPL_C +#include "ggml-common.h" + +#include "ggml-quants.h" +#include "ggml-impl.h" +#include "ggml-cpu.h" +#include "ggml-cpu-impl.h" + +#include +#include +#include +#include +#include // for qsort +#include // for LM_GGML_ASSERT + +#include "ggml-cpu-aarch64.h" + +#if defined(__GNUC__) +#pragma GCC diagnostic ignored "-Woverlength-strings" +#elif defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + +#define UNUSED LM_GGML_UNUSED + +// Functions to create the interleaved data layout formats + +// interleave 4 block_q4_0s in blocks of blck_size_interleave +// returns an interleaved block_q4_0x4 +// in the interleaved block_q4_0x4, place deltas for 4 block_q4_0 blocks +// first, then interleave quants from 4 block_q4_0s in blocks of blck_size_interleave +// +// - in : an array of block_q4_0 pointers +// - blck_size_interleave : the block_q4_0 quants bytes are interleaved in blocks of +// blck_size_interleave bytes +// - xor_mask : the mask to convert the nibbles in block_q4_0 quants bytes +// from bias offset form to pure sign form (this saves subtract +// operations durin unpacking) +// +#if defined(__AVX__) +#if defined(__F16C__) +#if defined(__AVX512F__) +#define LM_GGML_F32Cx8x2_LOAD(x, y) _mm512_cvtph_ps(_mm256_set_m128i(_mm_loadu_si128((const __m128i *)(y)), _mm_loadu_si128((const __m128i *)(x)))) +#define LM_GGML_F32Cx16_REPEAT_LOAD(x) _mm512_cvtph_ps(_mm256_set_m128i(x, x)) +#endif +// the _mm256_cvt intrinsics require F16C +#define LM_GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x))) +#define LM_GGML_F32Cx8_REPEAT_LOAD(x, loadMask) _mm256_cvtph_ps(_mm_shuffle_epi32(_mm_maskload_epi32((int const*)(x), loadMask), 68)) +#define LM_GGML_F32Cx8_REARRANGE_LOAD(x, arrangeMask) _mm256_cvtph_ps(_mm_shuffle_epi8(_mm_loadu_si128((const __m128i *) x), arrangeMask)) +#else +#if defined(__AVX512F__) +static inline __m512 __avx512_f32cx8x2_load(lm_ggml_fp16_t *x, lm_ggml_fp16_t *y) { + float tmp[16]; + + for (int i = 0; i < 8; i++) { + tmp[i] = LM_GGML_FP16_TO_FP32(x[i]); + } + + for (int i = 0; i < 8; i++) { + tmp[i + 8] = LM_GGML_FP16_TO_FP32(y[i]); + } + + return _mm512_loadu_ps(tmp); +} +static inline __m512 __avx512_repeat_f32cx16_load(__m128i x) { + float tmp[16]; + uint16_t tmphalf[8]; + _mm_storeu_si128((__m128i*)tmphalf, x); + + for (int i = 0; i < 4; i++) { + tmp[i] = LM_GGML_FP16_TO_FP32(tmphalf[i]); + tmp[i + 4] = LM_GGML_FP16_TO_FP32(tmphalf[i]); + tmp[i + 8] = LM_GGML_FP16_TO_FP32(tmphalf[i]); + tmp[i + 12] = LM_GGML_FP16_TO_FP32(tmphalf[i]); + } + + return _mm512_loadu_ps(tmp); +} +#endif +static inline __m256 __avx_f32cx8_load(lm_ggml_fp16_t *x) { + float tmp[8]; + + for (int i = 0; i < 8; i++) { + tmp[i] = LM_GGML_FP16_TO_FP32(x[i]); + } + + return _mm256_loadu_ps(tmp); +} +static inline __m256 __avx_repeat_f32cx8_load(lm_ggml_fp16_t *x) { + float tmp[8]; + + for (int i = 0; i < 4; i++) { + tmp[i] = LM_GGML_FP16_TO_FP32(x[i]); + tmp[i + 4] = LM_GGML_FP16_TO_FP32(x[i]); + } + + return _mm256_loadu_ps(tmp); +} +static inline __m256 __avx_rearranged_f32cx8_load(lm_ggml_fp16_t *x, __m128i arrangeMask) { + uint16_t tmphalf[8]; + float tmp[8]; + + _mm_storeu_si128((__m128i*)tmphalf, _mm_shuffle_epi8(_mm_loadu_si128((const __m128i *) x), arrangeMask)); + for (int i = 0; i < 8; i++) { + tmp[i] = LM_GGML_FP16_TO_FP32(tmphalf[i]); + } + + return _mm256_loadu_ps(tmp); +} + +#define LM_GGML_F32Cx8_LOAD(x) __avx_f32cx8_load(x) +#define LM_GGML_F32Cx8_REPEAT_LOAD(x, loadMask) __avx_repeat_f32cx8_load(x) +#define LM_GGML_F32Cx8_REARRANGE_LOAD(x, arrangeMask) __avx_rearranged_f32cx8_load(x, arrangeMask) +#if defined(__AVX512F__) +#define LM_GGML_F32Cx8x2_LOAD(x, y) __avx512_f32cx8x2_load(x, y) +#define LM_GGML_F32Cx16_REPEAT_LOAD(x) __avx512_repeat_f32cx16_load(x) +#endif +#endif +#endif + + +#if defined(__AVX2__) || defined(__AVX512F__) +#if defined(__AVX512F__) +// add int16_t pairwise and return as 512 bit int vector +static inline __m512i sum_i16_pairs_int_32x16(const __m512i x) { + const __m512i ones = _mm512_set1_epi16(1); + return _mm512_madd_epi16(ones, x); +} + +static inline __m512i mul_sum_us8_pairs_int32x16(const __m512i ax, const __m512i sy) { +#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) + const __m512i zero = _mm512_setzero_si512(); + return _mm512_dpbusd_epi32(zero, ax, sy); +#else + // Perform multiplication and create 16-bit values + const __m512i dot = _mm512_maddubs_epi16(ax, sy); + return sum_i16_pairs_int_32x16(dot); +#endif +} + +// multiply int8_t, add results pairwise twice and return as 512 bit int vector +static inline __m512i mul_sum_i8_pairs_int32x16(const __m512i x, const __m512i y) { + const __m512i zero = _mm512_setzero_si512(); + // Get absolute values of x vectors + const __m512i ax = _mm512_abs_epi8(x); + // Sign the values of the y vectors + __mmask64 blt0 = _mm512_movepi8_mask(x); + const __m512i sy = _mm512_mask_sub_epi8(y, blt0, zero, y); + return mul_sum_us8_pairs_int32x16(ax, sy); +} +#endif + +// add int16_t pairwise and return as 256 bit int vector +static inline __m256i sum_i16_pairs_int32x8(const __m256i x) { + const __m256i ones = _mm256_set1_epi16(1); + return _mm256_madd_epi16(ones, x); +} + +static inline __m256i mul_sum_us8_pairs_int32x8(const __m256i ax, const __m256i sy) { +#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) + const __m256i zero = _mm256_setzero_si256(); + return _mm256_dpbusd_epi32(zero, ax, sy); +#else + // Perform multiplication and create 16-bit values + const __m256i dot = _mm256_maddubs_epi16(ax, sy); + return sum_i16_pairs_int32x8(dot); +#endif +} + +// Integer variant of the function defined in ggml-quants.c +// multiply int8_t, add results pairwise twice and return as 256 bit int vector +static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y) { +#if __AVXVNNIINT8__ + const __m256i zero = _mm256_setzero_si256(); + return _mm256_dpbssd_epi32(zero, x, y); +#else + // Get absolute values of x vectors + const __m256i ax = _mm256_sign_epi8(x, x); + // Sign the values of the y vectors + const __m256i sy = _mm256_sign_epi8(y, x); + return mul_sum_us8_pairs_int32x8(ax, sy); +#endif +} +#endif + +static void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int64_t k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0x4 * restrict y = (block_q8_0x4 *) vy; + +#if defined(__ARM_NEON) + float32x4_t srcv[4][8]; + float id[4]; + + for (int i = 0; i < nb; i++) { + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int row_iter = 0; row_iter < 4; row_iter++) { + for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j); + for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]); + + for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]); + for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]); + for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + id[row_iter] = d ? 1.0f / d : 0.0f; + + y[i].d[row_iter] = LM_GGML_FP32_TO_FP16(d); + } + + for (int j = 0; j < 8; j++) { + float32x4_t v = vmulq_n_f32(srcv[0][j], id[0]); + int32x4_t vi = vcvtnq_s32_f32(v); + y[i].qs[16 * j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[16 * j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[16 * j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[16 * j + 3] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[1][j], id[1]); + vi = vcvtnq_s32_f32(v); + y[i].qs[16 * j + 4] = vgetq_lane_s32(vi, 0); + y[i].qs[16 * j + 5] = vgetq_lane_s32(vi, 1); + y[i].qs[16 * j + 6] = vgetq_lane_s32(vi, 2); + y[i].qs[16 * j + 7] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[2][j], id[2]); + vi = vcvtnq_s32_f32(v); + y[i].qs[16 * j + 8] = vgetq_lane_s32(vi, 0); + y[i].qs[16 * j + 9] = vgetq_lane_s32(vi, 1); + y[i].qs[16 * j + 10] = vgetq_lane_s32(vi, 2); + y[i].qs[16 * j + 11] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[3][j], id[3]); + vi = vcvtnq_s32_f32(v); + y[i].qs[16 * j + 12] = vgetq_lane_s32(vi, 0); + y[i].qs[16 * j + 13] = vgetq_lane_s32(vi, 1); + y[i].qs[16 * j + 14] = vgetq_lane_s32(vi, 2); + y[i].qs[16 * j + 15] = vgetq_lane_s32(vi, 3); + } + } +#else + // scalar + const int blck_size_interleave = 4; + float srcv[4][QK8_0]; + float id[4]; + + for (int i = 0; i < nb; i++) { + for (int row_iter = 0; row_iter < 4; row_iter++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j]; + amax = MAX(amax, fabsf(srcv[row_iter][j])); + } + + const float d = amax / ((1 << 7) - 1); + id[row_iter] = d ? 1.0f / d : 0.0f; + + y[i].d[row_iter] = LM_GGML_FP32_TO_FP16(d); + } + + for (int j = 0; j < QK8_0 * 4; j++) { + int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave; + int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; + src_offset += (j % blck_size_interleave); + + float x0 = srcv[src_id][src_offset] * id[src_id]; + y[i].qs[j] = roundf(x0); + } + } +#endif +} + +static void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int64_t k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0x4 * restrict y = (block_q8_0x4 *) vy; + +#if defined(__ARM_NEON) + float32x4_t srcv[4][8]; + float id[4]; + + for (int i = 0; i < nb; i++) { + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int row_iter = 0; row_iter < 4; row_iter++) { + for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j); + for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]); + + for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]); + for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]); + for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + id[row_iter] = d ? 1.0f / d : 0.0f; + + y[i].d[row_iter] = LM_GGML_FP32_TO_FP16(d); + } + + for (int j = 0; j < 4; j++) { + float32x4_t v = vmulq_n_f32(srcv[0][2 * j], id[0]); + int32x4_t vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 3] = vgetq_lane_s32(vi, 3); + v = vmulq_n_f32(srcv[0][2 * j + 1], id[0]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 4] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 5] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 6] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 7] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[1][2 * j], id[1]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 8] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 9] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 10] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 11] = vgetq_lane_s32(vi, 3); + v = vmulq_n_f32(srcv[1][2 * j + 1], id[1]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 12] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 13] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 14] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 15] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[2][2 * j], id[2]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 16] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 17] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 18] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 19] = vgetq_lane_s32(vi, 3); + v = vmulq_n_f32(srcv[2][2 * j + 1], id[2]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 20] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 21] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 22] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 23] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[3][2 * j], id[3]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 24] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 25] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 26] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 27] = vgetq_lane_s32(vi, 3); + v = vmulq_n_f32(srcv[3][2 * j + 1], id[3]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 28] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 29] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 30] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 31] = vgetq_lane_s32(vi, 3); + } + } +#elif defined(__AVX2__) || defined(__AVX__) + float id[4]; + __m256 srcv[4][4]; + __m256 idvec[4]; + + for (int i = 0; i < nb; i++) { + for (int row_iter = 0; row_iter < 4; row_iter++) { + // Load elements into 4 AVX vectors + __m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 32 ); + __m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 32 + 8 ); + __m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 32 + 16 ); + __m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 32 + 24 ); + + // Compute max(abs(e)) for the block + const __m256 signBit = _mm256_set1_ps( -0.0f ); + __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); + + __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); + max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); + const float maxScalar = _mm_cvtss_f32( max4 ); + + // Divided by 127.f to mirror results in quantize_row_q8_0 + const float d = maxScalar / 127.f; + id[row_iter] = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; //d ? 1.0f / d : 0.0f; + + // Store the scale for the individual block + y[i].d[row_iter] = LM_GGML_FP32_TO_FP16(d); + + // Store the values in blocks of eight values - Aim is to use these later for block interleaving + srcv[row_iter][0] = v0; + srcv[row_iter][1] = v1; + srcv[row_iter][2] = v2; + srcv[row_iter][3] = v3; + idvec[row_iter] = _mm256_set1_ps(id[row_iter]); + } + + // The loop iterates four times - The aim is to get 4 corresponding chunks of eight bytes from the original weight blocks that are interleaved + for (int j = 0; j < 4; j++) { + // Apply the multiplier + __m256 v0 = _mm256_mul_ps(srcv[0][j], idvec[0]); + __m256 v1 = _mm256_mul_ps(srcv[1][j], idvec[1]); + __m256 v2 = _mm256_mul_ps(srcv[2][j], idvec[2]); + __m256 v3 = _mm256_mul_ps(srcv[3][j], idvec[3]); + + // Round to nearest integer + v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); + v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); + v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); + v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); + + // Convert floats to integers + __m256i i0 = _mm256_cvtps_epi32( v0 ); + __m256i i1 = _mm256_cvtps_epi32( v1 ); + __m256i i2 = _mm256_cvtps_epi32( v2 ); + __m256i i3 = _mm256_cvtps_epi32( v3 ); + +#if defined(__AVX2__) + // Convert int32 to int16 + i0 = _mm256_packs_epi32( i0, i1 ); + i2 = _mm256_packs_epi32( i2, i3 ); + // Convert int16 to int8 + i0 = _mm256_packs_epi16( i0, i2 ); + + // Permute and store the quantized weights in the required order after the pack instruction + const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + + _mm256_storeu_si256((__m256i *)(y[i].qs + 32 * j), i0); +#else + // Since we don't have in AVX some necessary functions, + // we split the registers in half and call AVX2 analogs from SSE + __m128i ni0 = _mm256_castsi256_si128( i0 ); + __m128i ni1 = _mm256_extractf128_si256( i0, 1); + __m128i ni2 = _mm256_castsi256_si128( i1 ); + __m128i ni3 = _mm256_extractf128_si256( i1, 1); + __m128i ni4 = _mm256_castsi256_si128( i2 ); + __m128i ni5 = _mm256_extractf128_si256( i2, 1); + __m128i ni6 = _mm256_castsi256_si128( i3 ); + __m128i ni7 = _mm256_extractf128_si256( i3, 1); + + // Convert int32 to int16 + ni0 = _mm_packs_epi32( ni0, ni1 ); + ni2 = _mm_packs_epi32( ni2, ni3 ); + ni4 = _mm_packs_epi32( ni4, ni5 ); + ni6 = _mm_packs_epi32( ni6, ni7 ); + // Convert int16 to int8 + ni0 = _mm_packs_epi16( ni0, ni2 ); + ni4 = _mm_packs_epi16( ni4, ni6 ); + _mm_storeu_si128((__m128i *)(y[i].qs + 32 * j), ni0); + _mm_storeu_si128((__m128i *)(y[i].qs + 32 * j + 16), ni4); +#endif + } + } +#else + // scalar + const int blck_size_interleave = 8; + float srcv[4][QK8_0]; + float id[4]; + + for (int i = 0; i < nb; i++) { + for (int row_iter = 0; row_iter < 4; row_iter++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j]; + amax = MAX(amax, fabsf(srcv[row_iter][j])); + } + + const float d = amax / ((1 << 7) - 1); + id[row_iter] = d ? 1.0f / d : 0.0f; + + y[i].d[row_iter] = LM_GGML_FP32_TO_FP16(d); + } + + for (int j = 0; j < QK8_0 * 4; j++) { + int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave; + int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; + src_offset += (j % blck_size_interleave); + + float x0 = srcv[src_id][src_offset] * id[src_id]; + y[i].qs[j] = roundf(x0); + } + } +#endif +} + +void quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) { + assert(nrow == 4); + UNUSED(nrow); + if (blck_size_interleave == 4) { + quantize_q8_0_4x4(x, vy, n_per_row); + } else if (blck_size_interleave == 8) { + quantize_q8_0_4x8(x, vy, n_per_row); + } else { + assert(false); + } +} + +void lm_ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) + if (lm_ggml_cpu_has_neon()) { + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + + __asm__ __volatile__( + "movi v31.16b, #0x4\n" + "movi v30.16b, #0xf0\n" + "add %x[b_ptr], %x[b_ptr], #0x8\n" + "1:" // Column loop + "add x22, %x[a_ptr], #0x2\n" + "movi v29.16b, #0x0\n" + "mov x21, %x[nb]\n" + "2:" // Block loop + "ldr q28, [%x[b_ptr], #0x0]\n" + "ldr q27, [x22, #0x0]\n" + "movi v26.4s, #0x0\n" + "sub x20, x22, #0x2\n" + "ldr q25, [x22, #0x10]\n" + "ldr q24, [%x[b_ptr], #0x10]\n" + "sub x21, x21, #0x1\n" + "add x22, x22, #0x22\n" + "ldr q23, [%x[b_ptr], #0x20]\n" + "ldr q22, [%x[b_ptr], #0x30]\n" + "ld1r { v21.8h }, [x20]\n" + "ldr q20, [%x[b_ptr], #-0x8]\n" + "sshl v16.16b, v28.16b, v31.16b\n" + "and v28.16b, v28.16b, v30.16b\n" + "sshl v19.16b, v24.16b, v31.16b\n" + "and v24.16b, v24.16b, v30.16b\n" + "add %x[b_ptr], %x[b_ptr], #0x48\n" + "sshl v18.16b, v23.16b, v31.16b\n" + "and v23.16b, v23.16b, v30.16b\n" + ".inst 0x4f9be21a // sdot v26.4s, v16.16b, v27.4b[0]\n" + "sshl v17.16b, v22.16b, v31.16b\n" + "and v22.16b, v22.16b, v30.16b\n" + "fcvtl v21.4s, v21.4h\n" + "fcvtl v16.4s, v20.4h\n" + ".inst 0x4f99e39a // sdot v26.4s, v28.16b, v25.4b[0]\n" + "fmul v16.4s, v16.4s, v21.4s\n" + ".inst 0x4fbbe27a // sdot v26.4s, v19.16b, v27.4b[1]\n" + ".inst 0x4fb9e31a // sdot v26.4s, v24.16b, v25.4b[1]\n" + ".inst 0x4f9bea5a // sdot v26.4s, v18.16b, v27.4b[2]\n" + ".inst 0x4f99eafa // sdot v26.4s, v23.16b, v25.4b[2]\n" + ".inst 0x4fbbea3a // sdot v26.4s, v17.16b, v27.4b[3]\n" + ".inst 0x4fb9eada // sdot v26.4s, v22.16b, v25.4b[3]\n" + "scvtf v26.4s, v26.4s, #0x4\n" + "fmla v29.4s, v26.4s, v16.4s\n" + "cbnz x21, 2b\n" + "sub %x[nc], %x[nc], #0x4\n" + "str q29, [%x[res_ptr], #0x0]\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "cbnz %x[nc], 1b\n" + : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc) + : [a_ptr] "r" (a_ptr), [nb] "r" (nb) + : "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22" + ); + return; + } +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) + float sumf[4]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; + } + sumf[j] += sumi * LM_GGML_FP16_TO_FP32(b_ptr[l].d[j]) * LM_GGML_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +} + +void lm_ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_matmul_int8()) { + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + + __asm__ __volatile__( + "movi v2.16b, #0x4\n" + "movi v1.16b, #0xf0\n" + "add %x[b_ptr], %x[b_ptr], #0x8\n" + "1:" // Column loop + "add x23, %x[a_ptr], #0x2\n" + "movi v0.16b, #0x0\n" + "mov x22, %x[nb]\n" + "2:" // Block loop + "ldr q31, [%x[b_ptr], #0x0]\n" + "ldr q30, [%x[b_ptr], #0x10]\n" + "mov x21, x23\n" + "movi v29.4s, #0x0\n" + "ldr q28, [%x[b_ptr], #0x20]\n" + "ldr q27, [%x[b_ptr], #0x30]\n" + "movi v26.4s, #0x0\n" + "sub x20, x23, #0x2\n" + "ld1r { v25.8h }, [x20]\n" + "ldr q24, [%x[b_ptr], #-0x8]\n" + "sub x22, x22, #0x1\n" + "add x23, x23, #0x22\n" + "ld1r { v23.2d }, [x21], #0x8\n" + "sshl v22.16b, v31.16b, v2.16b\n" + "sshl v16.16b, v30.16b, v2.16b\n" + "add %x[b_ptr], %x[b_ptr], #0x48\n" + "ld1r { v21.2d }, [x21], #0x8\n" + "sshl v20.16b, v28.16b, v2.16b\n" + "sshl v19.16b, v27.16b, v2.16b\n" + "ld1r { v18.2d }, [x21], #0x8\n" + "ld1r { v17.2d }, [x21], #0x8\n" + "and v31.16b, v31.16b, v1.16b\n" + "and v30.16b, v30.16b, v1.16b\n" + ".inst 0x4e9796dd // sdot v29.4s, v22.16b, v23.16b\n" + ".inst 0x4e97961a // sdot v26.4s, v16.16b, v23.16b\n" + "and v28.16b, v28.16b, v1.16b\n" + "and v27.16b, v27.16b, v1.16b\n" + "fcvtl v25.4s, v25.4h\n" + "fcvtl v16.4s, v24.4h\n" + ".inst 0x4e95969d // sdot v29.4s, v20.16b, v21.16b\n" + ".inst 0x4e95967a // sdot v26.4s, v19.16b, v21.16b\n" + "fmul v16.4s, v16.4s, v25.4s\n" + ".inst 0x4e9297fd // sdot v29.4s, v31.16b, v18.16b\n" + ".inst 0x4e9297da // sdot v26.4s, v30.16b, v18.16b\n" + ".inst 0x4e91979d // sdot v29.4s, v28.16b, v17.16b\n" + ".inst 0x4e91977a // sdot v26.4s, v27.16b, v17.16b\n" + "addp v29.4s, v29.4s, v26.4s\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "fmla v0.4s, v29.4s, v16.4s\n" + "cbnz x22, 2b\n" + "sub %x[nc], %x[nc], #0x4\n" + "str q0, [%x[res_ptr], #0x0]\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "cbnz %x[nc], 1b\n" + : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc) + : [a_ptr] "r" (a_ptr), [nb] "r" (nb) + : "memory", "v0", "v1", "v2", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23" + ); + return; + } +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + float sumf[4]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; + } + sumf[j] += sumi * LM_GGML_FP16_TO_FP32(b_ptr[l].d[j]) * LM_GGML_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +} + +void lm_ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) +#if defined(__ARM_FEATURE_SVE) + if (lm_ggml_cpu_has_sve() && lm_ggml_cpu_get_sve_cnt() == QK8_0) { + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + + __asm__ __volatile__( + "ptrue p0.b\n" + "add %x[b_ptr], %x[b_ptr], #0x10\n" + "1:" // Column loop + "add x22, %x[a_ptr], #0x2\n" + "mov z31.b, #0x0\n" + "mov x21, %x[nb]\n" + "2:" // Block loop + "ld1b { z30.b }, p0/Z, [%x[b_ptr]]\n" + "ld1b { z29.b }, p0/Z, [%x[b_ptr], #1, MUL VL]\n" + "mov z28.s, #0x0\n" + "mov z27.s, #0x0\n" + "ld1rd { z26.d }, p0/Z, [x22]\n" + "ld1b { z25.b }, p0/Z, [%x[b_ptr], #2, MUL VL]\n" + "sub x20, x22, #0x2\n" + "sub x21, x21, #0x1\n" + "ld1b { z24.b }, p0/Z, [%x[b_ptr], #3, MUL VL]\n" + "ld1rd { z23.d }, p0/Z, [x22, #8]\n" + "lsl z22.b, z30.b, #0x4\n" + "lsl z16.b, z29.b, #0x4\n" + "and z30.b, z30.b, #0xf0\n" + "and z29.b, z29.b, #0xf0\n" + "ld1rd { z21.d }, p0/Z, [x22, #16]\n" + "ld1rd { z20.d }, p0/Z, [x22, #24]\n" + "lsl z19.b, z25.b, #0x4\n" + "and z25.b, z25.b, #0xf0\n" + "ld1rh { z17.h }, p0/Z, [x20]\n" + "ld1h { z18.s }, p0/Z, [%x[b_ptr], #-1, MUL VL]\n" + "sdot z28.s, z22.b, z26.b\n" + "sdot z27.s, z16.b, z26.b\n" + "lsl z16.b, z24.b, #0x4\n" + "add x22, x22, #0x22\n" + "and z24.b, z24.b, #0xf0\n" + "add %x[b_ptr], %x[b_ptr], #0x90\n" + "fcvt z17.s, p0/m, z17.h\n" + "fcvt z18.s, p0/m, z18.h\n" + "sdot z28.s, z19.b, z23.b\n" + "sdot z27.s, z16.b, z23.b\n" + "fmul z18.s, z18.s, z17.s\n" + "sdot z28.s, z30.b, z21.b\n" + "sdot z27.s, z29.b, z21.b\n" + "sdot z28.s, z25.b, z20.b\n" + "sdot z27.s, z24.b, z20.b\n" + "uzp1 z17.s, z28.s, z27.s\n" + "uzp2 z16.s, z28.s, z27.s\n" + "add z17.s, z17.s, z16.s\n" + "asr z17.s, z17.s, #0x4\n" + "scvtf z17.s, p0/m, z17.s\n" + "fmla z31.s, p0/M, z17.s, z18.s\n" + "cbnz x21, 2b\n" + "sub %x[nc], %x[nc], #0x8\n" + "st1w { z31.s }, p0, [%x[res_ptr]]\n" + "add %x[res_ptr], %x[res_ptr], #0x20\n" + "cbnz %x[nc], 1b\n" + : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc) + : [a_ptr] "r" (a_ptr), [nb] "r" (nb) + : "memory", "p0", "x20", "x21", "x22", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); + return; + } +#endif // #if defined(__ARM_FEATURE_SVE) +#elif defined(__AVX2__) + // Lookup table to convert signed nibbles to signed bytes + __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0)); + signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0); + __m128i changemask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0); + __m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0); + + // Permute mask used for easier vector processing at later stages + const __m256i m4b = _mm256_set1_epi8(0x0F); + + int64_t b_nb = n / QK4_0; + + const block_q4_0x8 * b_ptr_start = (const block_q4_0x8 *)vx; + const block_q8_0 * a_ptr_start = (const block_q8_0 *)vy; + + // Process Q8_0 blocks one by one + for (int64_t y = 0; y < nr; y++) { + + // Pointers to LHS blocks of block_q8_0 format + const block_q8_0 * a_ptr = a_ptr_start + (y * nb); + + // Take group of eight block_q4_0x8 structures at each pass of the loop and perform dot product operation + for (int64_t x = 0; x < nc / 8; x++) { + + // Pointers to RHS blocks + const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb); + + // Master FP accumulator + __m256 acc_row = _mm256_setzero_ps(); + + for (int64_t b = 0; b < nb; b++) { + // Load 8 blocks of Q4_0 interleaved as 8 bytes (B0 - B7) + const __m256i rhs_raw_vec_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs)); + const __m256i rhs_raw_vec_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 1); + const __m256i rhs_raw_vec_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 2); + const __m256i rhs_raw_vec_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 3); + + // 4-bit -> 8-bit - Sign is maintained + const __m256i rhs_vec_0123_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_0123_0, m4b)); // B0(0-7) B1(0-7) B2(0-7) B3(0-7) + const __m256i rhs_vec_4567_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_4567_0, m4b)); // B4(0-7) B5(0-7) B6(0-7) B7(0-7) + const __m256i rhs_vec_0123_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_0123_1, m4b)); // B0(8-15) B1(8-15) B2(8-15) B3(8-15) + const __m256i rhs_vec_4567_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_4567_1, m4b)); // B0(8-15) B1(8-15) B2(8-15) B3(8-15) + + const __m256i rhs_vec_0123_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 4), m4b)); // B0(16-23) B1(16-23) B2(16-23) B3(16-23) + const __m256i rhs_vec_4567_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 4), m4b)); // B4(16-23) B5(16-23) B6(16-23) B7(16-23) + const __m256i rhs_vec_0123_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 4), m4b)); // B0(24-31) B1(24-31) B2(24-31) B3(24-31) + const __m256i rhs_vec_4567_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 4), m4b)); // B4(24-31) B5(24-31) B6(24-31) B7(24-31) + + // Load the scale values for the 8 blocks interleaved in block_q4_0x8 + const __m256 col_scale_f32 = LM_GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, changemask); + + // Load and convert to FP32 scale from block_q8_0 + const __m256 row_scale_f32 = _mm256_set1_ps(LM_GGML_FP16_TO_FP32(a_ptr[b].d)); + + // Load the block values in block_q8_0 in batches of 16 bytes and replicate the same across 256 bit vector + __m256i lhs_vec_0 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)a_ptr[b].qs)); + __m256i lhs_vec_1 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 16))); + + lhs_vec_0 = _mm256_permute2f128_si256(lhs_vec_0, lhs_vec_0, 0); // A0 (0-15) A0(0-15) + lhs_vec_1 = _mm256_permute2f128_si256(lhs_vec_1, lhs_vec_1, 0); // A0 (16-31) A0(16-31)) + + __m256i iacc = _mm256_setzero_si256(); + + // Dot product done within 32 bit lanes and accumulated in the same vector + // B0(0-3) B4(0-3) B1(0-3) B5(0-3) B2(0-3) B6(0-3) B3(0-3) B7(0-3) with A0(0-3) + // B0(4-7) B4(4-7) B1(4-7) B5(4-7) B2(4-7) B6(4-7) B3(4-7) B7(4-7) with A0(4-7) + // ........................................................................... + // B0(28-31) B4(28-31) B1(28-31) B5(28-31) B2(28-31) B6(28-31) B3(28-31) B7(28-31) with A0(28-31) + + iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_0 ,_mm256_shuffle_epi32(rhs_vec_4567_0, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 0))); + iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_0, 177) ,rhs_vec_4567_0, 170), _mm256_shuffle_epi32(lhs_vec_0, 85))); + + iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_1 ,_mm256_shuffle_epi32(rhs_vec_4567_1, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 170))); + iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_1, 177) ,rhs_vec_4567_1, 170), _mm256_shuffle_epi32(lhs_vec_0, 255))); + + iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_2 ,_mm256_shuffle_epi32(rhs_vec_4567_2, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 0))); + iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_2, 177) ,rhs_vec_4567_2, 170), _mm256_shuffle_epi32(lhs_vec_1, 85))); + + iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_3 ,_mm256_shuffle_epi32(rhs_vec_4567_3, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 170))); + iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_3, 177) ,rhs_vec_4567_3, 170), _mm256_shuffle_epi32(lhs_vec_1, 255))); + + // Accumulated values multipled with appropriate scales + acc_row = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc), _mm256_mul_ps(col_scale_f32, row_scale_f32), acc_row); + } + + // Accumulated output values permuted so as to be stored in appropriate order post accumulation + acc_row = _mm256_permutevar8x32_ps(acc_row, finalpermutemask); + _mm256_storeu_ps(s + (y * nr + x * 8), acc_row); + } + } + return; +#elif defined(__riscv_v_intrinsic) + if (__riscv_vlenb() >= QK4_0) { + const size_t vl = QK4_0; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); + + vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); + for (int l = 0; l < nb; l++) { + const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[0]; + const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[8]; + const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[16]; + const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[24]; + __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment + const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, vl / 4)); + const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, vl / 4)); + const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, vl / 4)); + const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, vl / 4)); + + const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4); + const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4); + const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4); + const vint8m2_t rhs_vec_lo_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 0); + const vint8m2_t rhs_vec_lo_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 1); + const vint8m2_t rhs_vec_hi_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 0); + const vint8m2_t rhs_vec_hi_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 1); + + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); + const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); + + const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_hi_m)); + const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); + const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); + const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); + const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); + const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); + const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); + const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); + const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); + const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); + const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); + const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); + + // vector version needs Zvfhmin extension + const float a_scale = LM_GGML_FP16_TO_FP32(a_ptr[l].d); + const float b_scales[8] = { + LM_GGML_FP16_TO_FP32(b_ptr[l].d[0]), + LM_GGML_FP16_TO_FP32(b_ptr[l].d[1]), + LM_GGML_FP16_TO_FP32(b_ptr[l].d[2]), + LM_GGML_FP16_TO_FP32(b_ptr[l].d[3]), + LM_GGML_FP16_TO_FP32(b_ptr[l].d[4]), + LM_GGML_FP16_TO_FP32(b_ptr[l].d[5]), + LM_GGML_FP16_TO_FP32(b_ptr[l].d[6]), + LM_GGML_FP16_TO_FP32(b_ptr[l].d[7]) + }; + const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4); + const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scale, vl / 4); + sumf = __riscv_vfmacc_vv_f32m1(sumf, tmp1, b_scales_vec, vl / 4); + } + __riscv_vse32_v_f32m1(s + x * ncols_interleaved, sumf, vl / 4); + } + return; + } +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) + { + float sumf[8]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; + } + sumf[j] += sumi * LM_GGML_FP16_TO_FP32(b_ptr[l].d[j]) * LM_GGML_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } + } +} + +void lm_ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) + if (lm_ggml_cpu_has_neon()) { + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + size_t res_stride = bs * sizeof(float); + + __asm__ __volatile__( + "mov x10, %x[nr]\n" + "mov x9, #0x88\n" + "cmp x10, #0x10\n" + "mul x9, %x[nb], x9\n" + "blt 4f\n" + "1:" // Row loop + "add x28, %x[b_ptr], #0x8\n" + "mov x27, %x[nc]\n" + "add x26, %x[res_ptr], %x[res_stride], LSL #4\n" + "2:" // Column loop + "add x25, %x[a_ptr], #0x8\n" + "movi v15.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "mov x24, %x[nb]\n" + "add x23, x25, x9\n" + "movi v18.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "add x22, x23, x9\n" + "movi v11.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "add x21, x22, x9\n" + "movi v23.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v7.16b, #0x0\n" + "movi v0.16b, #0x0\n" + "movi v4.16b, #0x0\n" + "movi v5.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v8.16b, #0x0\n" + "movi v1.16b, #0x0\n" + "3:" // Block loop + "ldr q3, [x28, #0x0]\n" + "ldr q31, [x25, #0x0]\n" + "movi v28.16b, #0x4\n" + "movi v10.4s, #0x0\n" + "ldr q22, [x28, #0x10]\n" + "ldr q6, [x25, #0x10]\n" + "movi v29.4s, #0x0\n" + "movi v9.4s, #0x0\n" + "ldr q27, [x28, #0x20]\n" + "ldr q30, [x28, #0x30]\n" + "movi v20.4s, #0x0\n" + "movi v24.16b, #0xf0\n" + "ldr d2, [x25, #-0x8]\n" + "ldr d26, [x23, #-0x8]\n" + "sshl v12.16b, v3.16b, v28.16b\n" + "sub x20, x28, #0x8\n" + "ldr d17, [x20, #0x0]\n" + "and v3.16b, v3.16b, v24.16b\n" + "subs x24, x24, #0x1\n" + "add x28, x28, #0x48\n" + ".inst 0x4f9fe18a // sdot v10.4s, v12.16b, v31.4b[0]\n" + ".inst 0x4fbfe19d // sdot v29.4s, v12.16b, v31.4b[1]\n" + ".inst 0x4f9fe989 // sdot v9.4s, v12.16b, v31.4b[2]\n" + ".inst 0x4fbfe994 // sdot v20.4s, v12.16b, v31.4b[3]\n" + "sshl v31.16b, v22.16b, v28.16b\n" + "and v22.16b, v22.16b, v24.16b\n" + "fcvtl v17.4s, v17.4h\n" + "fcvtl v2.4s, v2.4h\n" + "fcvtl v26.4s, v26.4h\n" + ".inst 0x4f86e3ea // sdot v10.4s, v31.16b, v6.4b[0]\n" + ".inst 0x4fa6e3fd // sdot v29.4s, v31.16b, v6.4b[1]\n" + ".inst 0x4f86ebe9 // sdot v9.4s, v31.16b, v6.4b[2]\n" + ".inst 0x4fa6ebf4 // sdot v20.4s, v31.16b, v6.4b[3]\n" + "sshl v6.16b, v27.16b, v28.16b\n" + "sshl v28.16b, v30.16b, v28.16b\n" + "and v27.16b, v27.16b, v24.16b\n" + "and v30.16b, v30.16b, v24.16b\n" + "ldr q24, [x25, #0x20]\n" + ".inst 0x4f98e0ca // sdot v10.4s, v6.16b, v24.4b[0]\n" + ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n" + ".inst 0x4f98e8c9 // sdot v9.4s, v6.16b, v24.4b[2]\n" + ".inst 0x4fb8e8d4 // sdot v20.4s, v6.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x30]\n" + ".inst 0x4f98e38a // sdot v10.4s, v28.16b, v24.4b[0]\n" + ".inst 0x4fb8e39d // sdot v29.4s, v28.16b, v24.4b[1]\n" + ".inst 0x4f98eb89 // sdot v9.4s, v28.16b, v24.4b[2]\n" + ".inst 0x4fb8eb94 // sdot v20.4s, v28.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x40]\n" + ".inst 0x4f98e06a // sdot v10.4s, v3.16b, v24.4b[0]\n" + ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n" + ".inst 0x4f98e869 // sdot v9.4s, v3.16b, v24.4b[2]\n" + ".inst 0x4fb8e874 // sdot v20.4s, v3.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x50]\n" + ".inst 0x4f98e2ca // sdot v10.4s, v22.16b, v24.4b[0]\n" + ".inst 0x4fb8e2dd // sdot v29.4s, v22.16b, v24.4b[1]\n" + ".inst 0x4f98eac9 // sdot v9.4s, v22.16b, v24.4b[2]\n" + ".inst 0x4fb8ead4 // sdot v20.4s, v22.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x60]\n" + ".inst 0x4f98e36a // sdot v10.4s, v27.16b, v24.4b[0]\n" + ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n" + ".inst 0x4f98eb69 // sdot v9.4s, v27.16b, v24.4b[2]\n" + ".inst 0x4fb8eb74 // sdot v20.4s, v27.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x70]\n" + "add x25, x25, #0x88\n" + ".inst 0x4f98e3ca // sdot v10.4s, v30.16b, v24.4b[0]\n" + ".inst 0x4fb8e3dd // sdot v29.4s, v30.16b, v24.4b[1]\n" + ".inst 0x4f98ebc9 // sdot v9.4s, v30.16b, v24.4b[2]\n" + ".inst 0x4fb8ebd4 // sdot v20.4s, v30.16b, v24.4b[3]\n" + "fmul v24.4s, v17.4s, v2.s[0]\n" + "scvtf v10.4s, v10.4s, #0x4\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "scvtf v9.4s, v9.4s, #0x4\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "fmla v15.4s, v10.4s, v24.4s\n" + "ldr q24, [x23, #0x0]\n" + "fmul v10.4s, v17.4s, v2.s[1]\n" + "fmla v19.4s, v29.4s, v10.4s\n" + "ldr q10, [x23, #0x10]\n" + "fmul v29.4s, v17.4s, v2.s[2]\n" + "fmul v2.4s, v17.4s, v2.s[3]\n" + "fmla v18.4s, v9.4s, v29.4s\n" + "movi v9.4s, #0x0\n" + "movi v29.4s, #0x0\n" + ".inst 0x4f98e189 // sdot v9.4s, v12.16b, v24.4b[0]\n" + ".inst 0x4fb8e19d // sdot v29.4s, v12.16b, v24.4b[1]\n" + "fmla v14.4s, v20.4s, v2.4s\n" + "movi v20.4s, #0x0\n" + "movi v2.4s, #0x0\n" + ".inst 0x4f98e994 // sdot v20.4s, v12.16b, v24.4b[2]\n" + ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n" + "ldr q24, [x23, #0x20]\n" + ".inst 0x4f8ae3e9 // sdot v9.4s, v31.16b, v10.4b[0]\n" + ".inst 0x4faae3fd // sdot v29.4s, v31.16b, v10.4b[1]\n" + ".inst 0x4f8aebf4 // sdot v20.4s, v31.16b, v10.4b[2]\n" + ".inst 0x4faaebe2 // sdot v2.4s, v31.16b, v10.4b[3]\n" + "ldr q10, [x23, #0x30]\n" + ".inst 0x4f98e0c9 // sdot v9.4s, v6.16b, v24.4b[0]\n" + ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n" + ".inst 0x4f98e8d4 // sdot v20.4s, v6.16b, v24.4b[2]\n" + ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n" + "ldr q24, [x23, #0x40]\n" + ".inst 0x4f8ae389 // sdot v9.4s, v28.16b, v10.4b[0]\n" + ".inst 0x4faae39d // sdot v29.4s, v28.16b, v10.4b[1]\n" + ".inst 0x4f8aeb94 // sdot v20.4s, v28.16b, v10.4b[2]\n" + ".inst 0x4faaeb82 // sdot v2.4s, v28.16b, v10.4b[3]\n" + "ldr q10, [x23, #0x50]\n" + ".inst 0x4f98e069 // sdot v9.4s, v3.16b, v24.4b[0]\n" + ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n" + ".inst 0x4f98e874 // sdot v20.4s, v3.16b, v24.4b[2]\n" + ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n" + "ldr q24, [x23, #0x60]\n" + ".inst 0x4f8ae2c9 // sdot v9.4s, v22.16b, v10.4b[0]\n" + ".inst 0x4faae2dd // sdot v29.4s, v22.16b, v10.4b[1]\n" + ".inst 0x4f8aead4 // sdot v20.4s, v22.16b, v10.4b[2]\n" + ".inst 0x4faaeac2 // sdot v2.4s, v22.16b, v10.4b[3]\n" + "ldr q10, [x23, #0x70]\n" + "add x23, x23, #0x88\n" + ".inst 0x4f98e369 // sdot v9.4s, v27.16b, v24.4b[0]\n" + ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n" + ".inst 0x4f98eb74 // sdot v20.4s, v27.16b, v24.4b[2]\n" + ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n" + "ldr q24, [x22, #0x0]\n" + ".inst 0x4f8ae3c9 // sdot v9.4s, v30.16b, v10.4b[0]\n" + ".inst 0x4faae3dd // sdot v29.4s, v30.16b, v10.4b[1]\n" + ".inst 0x4f8aebd4 // sdot v20.4s, v30.16b, v10.4b[2]\n" + ".inst 0x4faaebc2 // sdot v2.4s, v30.16b, v10.4b[3]\n" + "fmul v10.4s, v17.4s, v26.s[0]\n" + "scvtf v9.4s, v9.4s, #0x4\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "scvtf v2.4s, v2.4s, #0x4\n" + "fmla v11.4s, v9.4s, v10.4s\n" + "ldr q9, [x22, #0x10]\n" + "fmul v10.4s, v17.4s, v26.s[1]\n" + "fmla v13.4s, v29.4s, v10.4s\n" + "ldr d29, [x22, #-0x8]\n" + "fmul v10.4s, v17.4s, v26.s[2]\n" + "fmul v26.4s, v17.4s, v26.s[3]\n" + "fcvtl v29.4s, v29.4h\n" + "fmla v23.4s, v20.4s, v10.4s\n" + "movi v20.4s, #0x0\n" + "movi v10.4s, #0x0\n" + "fmla v16.4s, v2.4s, v26.4s\n" + "movi v26.4s, #0x0\n" + "movi v2.4s, #0x0\n" + ".inst 0x4f98e194 // sdot v20.4s, v12.16b, v24.4b[0]\n" + ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n" + ".inst 0x4f98e99a // sdot v26.4s, v12.16b, v24.4b[2]\n" + ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n" + "ldr q24, [x22, #0x20]\n" + ".inst 0x4f89e3f4 // sdot v20.4s, v31.16b, v9.4b[0]\n" + ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n" + ".inst 0x4f89ebfa // sdot v26.4s, v31.16b, v9.4b[2]\n" + ".inst 0x4fa9ebe2 // sdot v2.4s, v31.16b, v9.4b[3]\n" + "ldr q9, [x22, #0x30]\n" + ".inst 0x4f98e0d4 // sdot v20.4s, v6.16b, v24.4b[0]\n" + ".inst 0x4fb8e0ca // sdot v10.4s, v6.16b, v24.4b[1]\n" + ".inst 0x4f98e8da // sdot v26.4s, v6.16b, v24.4b[2]\n" + ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n" + "ldr q24, [x22, #0x40]\n" + ".inst 0x4f89e394 // sdot v20.4s, v28.16b, v9.4b[0]\n" + ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n" + ".inst 0x4f89eb9a // sdot v26.4s, v28.16b, v9.4b[2]\n" + ".inst 0x4fa9eb82 // sdot v2.4s, v28.16b, v9.4b[3]\n" + "ldr q9, [x22, #0x50]\n" + ".inst 0x4f98e074 // sdot v20.4s, v3.16b, v24.4b[0]\n" + ".inst 0x4fb8e06a // sdot v10.4s, v3.16b, v24.4b[1]\n" + ".inst 0x4f98e87a // sdot v26.4s, v3.16b, v24.4b[2]\n" + ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n" + "ldr q24, [x22, #0x60]\n" + ".inst 0x4f89e2d4 // sdot v20.4s, v22.16b, v9.4b[0]\n" + ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n" + ".inst 0x4f89eada // sdot v26.4s, v22.16b, v9.4b[2]\n" + ".inst 0x4fa9eac2 // sdot v2.4s, v22.16b, v9.4b[3]\n" + "ldr q9, [x22, #0x70]\n" + "add x22, x22, #0x88\n" + ".inst 0x4f98e374 // sdot v20.4s, v27.16b, v24.4b[0]\n" + ".inst 0x4fb8e36a // sdot v10.4s, v27.16b, v24.4b[1]\n" + ".inst 0x4f98eb7a // sdot v26.4s, v27.16b, v24.4b[2]\n" + ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n" + "ldr q24, [x21, #0x0]\n" + ".inst 0x4f89e3d4 // sdot v20.4s, v30.16b, v9.4b[0]\n" + ".inst 0x4fa9e3ca // sdot v10.4s, v30.16b, v9.4b[1]\n" + ".inst 0x4f89ebda // sdot v26.4s, v30.16b, v9.4b[2]\n" + ".inst 0x4fa9ebc2 // sdot v2.4s, v30.16b, v9.4b[3]\n" + "fmul v9.4s, v17.4s, v29.s[0]\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "scvtf v10.4s, v10.4s, #0x4\n" + "scvtf v26.4s, v26.4s, #0x4\n" + "scvtf v2.4s, v2.4s, #0x4\n" + "fmla v25.4s, v20.4s, v9.4s\n" + "ldr q9, [x21, #0x10]\n" + "fmul v20.4s, v17.4s, v29.s[1]\n" + "fmla v7.4s, v10.4s, v20.4s\n" + "ldr d20, [x21, #-0x8]\n" + "fmul v10.4s, v17.4s, v29.s[2]\n" + "fmul v29.4s, v17.4s, v29.s[3]\n" + "fcvtl v20.4s, v20.4h\n" + "fmla v0.4s, v26.4s, v10.4s\n" + "movi v26.4s, #0x0\n" + "movi v10.4s, #0x0\n" + "fmla v4.4s, v2.4s, v29.4s\n" + "movi v2.4s, #0x0\n" + "movi v29.4s, #0x0\n" + ".inst 0x4f98e19a // sdot v26.4s, v12.16b, v24.4b[0]\n" + ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n" + ".inst 0x4f98e982 // sdot v2.4s, v12.16b, v24.4b[2]\n" + ".inst 0x4fb8e99d // sdot v29.4s, v12.16b, v24.4b[3]\n" + "ldr q12, [x21, #0x20]\n" + "fmul v24.4s, v17.4s, v20.s[0]\n" + ".inst 0x4f89e3fa // sdot v26.4s, v31.16b, v9.4b[0]\n" + ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n" + ".inst 0x4f89ebe2 // sdot v2.4s, v31.16b, v9.4b[2]\n" + ".inst 0x4fa9ebfd // sdot v29.4s, v31.16b, v9.4b[3]\n" + "ldr q9, [x21, #0x30]\n" + "fmul v31.4s, v17.4s, v20.s[1]\n" + ".inst 0x4f8ce0da // sdot v26.4s, v6.16b, v12.4b[0]\n" + ".inst 0x4face0ca // sdot v10.4s, v6.16b, v12.4b[1]\n" + ".inst 0x4f8ce8c2 // sdot v2.4s, v6.16b, v12.4b[2]\n" + ".inst 0x4face8dd // sdot v29.4s, v6.16b, v12.4b[3]\n" + "ldr q12, [x21, #0x40]\n" + "fmul v6.4s, v17.4s, v20.s[2]\n" + "fmul v20.4s, v17.4s, v20.s[3]\n" + ".inst 0x4f89e39a // sdot v26.4s, v28.16b, v9.4b[0]\n" + ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n" + ".inst 0x4f89eb82 // sdot v2.4s, v28.16b, v9.4b[2]\n" + ".inst 0x4fa9eb9d // sdot v29.4s, v28.16b, v9.4b[3]\n" + "ldr q9, [x21, #0x50]\n" + ".inst 0x4f8ce07a // sdot v26.4s, v3.16b, v12.4b[0]\n" + ".inst 0x4face06a // sdot v10.4s, v3.16b, v12.4b[1]\n" + ".inst 0x4f8ce862 // sdot v2.4s, v3.16b, v12.4b[2]\n" + ".inst 0x4face87d // sdot v29.4s, v3.16b, v12.4b[3]\n" + "ldr q12, [x21, #0x60]\n" + ".inst 0x4f89e2da // sdot v26.4s, v22.16b, v9.4b[0]\n" + ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n" + ".inst 0x4f89eac2 // sdot v2.4s, v22.16b, v9.4b[2]\n" + ".inst 0x4fa9eadd // sdot v29.4s, v22.16b, v9.4b[3]\n" + "ldr q17, [x21, #0x70]\n" + "add x21, x21, #0x88\n" + ".inst 0x4f8ce37a // sdot v26.4s, v27.16b, v12.4b[0]\n" + ".inst 0x4face36a // sdot v10.4s, v27.16b, v12.4b[1]\n" + ".inst 0x4f8ceb62 // sdot v2.4s, v27.16b, v12.4b[2]\n" + ".inst 0x4faceb7d // sdot v29.4s, v27.16b, v12.4b[3]\n" + ".inst 0x4f91e3da // sdot v26.4s, v30.16b, v17.4b[0]\n" + ".inst 0x4fb1e3ca // sdot v10.4s, v30.16b, v17.4b[1]\n" + ".inst 0x4f91ebc2 // sdot v2.4s, v30.16b, v17.4b[2]\n" + ".inst 0x4fb1ebdd // sdot v29.4s, v30.16b, v17.4b[3]\n" + "scvtf v26.4s, v26.4s, #0x4\n" + "scvtf v10.4s, v10.4s, #0x4\n" + "fmla v5.4s, v26.4s, v24.4s\n" + "scvtf v2.4s, v2.4s, #0x4\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "fmla v21.4s, v10.4s, v31.4s\n" + "fmla v8.4s, v2.4s, v6.4s\n" + "fmla v1.4s, v29.4s, v20.4s\n" + "bgt 3b\n" + "mov x20, %x[res_ptr]\n" + "subs x27, x27, #0x4\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "str q15, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q19, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q18, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q14, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q11, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q13, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q23, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q16, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q25, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q7, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q0, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q4, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q5, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q21, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q8, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q1, [x20, #0x0]\n" + "bne 2b\n" + "mov x20, #0x4\n" + "sub x10, x10, #0x10\n" + "cmp x10, #0x10\n" + "mov %x[res_ptr], x26\n" + "madd %x[a_ptr], x20, x9, %x[a_ptr]\n" + "bge 1b\n" + "4:" // Row loop skip + "cbz x10, 9f\n" + "5:" // Row tail: Row loop + "add x24, %x[b_ptr], #0x8\n" + "mov x23, %x[nc]\n" + "add x22, %x[res_ptr], %x[res_stride], LSL #2\n" + "6:" // Row tail: Column loop + "movi v15.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "add x25, %x[a_ptr], #0x8\n" + "mov x21, %x[nb]\n" + "movi v18.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "7:" // Row tail: Block loop + "ldr q7, [x24, #0x0]\n" + "ldr q5, [x25, #0x0]\n" + "movi v9.16b, #0x4\n" + "movi v4.4s, #0x0\n" + "ldr q3, [x24, #0x10]\n" + "ldr q2, [x25, #0x10]\n" + "movi v1.4s, #0x0\n" + "movi v0.4s, #0x0\n" + "ldr q13, [x24, #0x20]\n" + "ldr q31, [x25, #0x20]\n" + "movi v30.4s, #0x0\n" + "movi v29.16b, #0xf0\n" + "ldr q28, [x24, #0x30]\n" + "ldr q27, [x25, #0x30]\n" + "sshl v20.16b, v7.16b, v9.16b\n" + "sub x20, x24, #0x8\n" + "ldr q26, [x25, #0x40]\n" + "ldr q25, [x25, #0x50]\n" + "sshl v17.16b, v3.16b, v9.16b\n" + "and v7.16b, v7.16b, v29.16b\n" + "ldr q24, [x25, #0x60]\n" + "ldr q16, [x25, #0x70]\n" + "sshl v22.16b, v13.16b, v9.16b\n" + "and v3.16b, v3.16b, v29.16b\n" + "ldr d21, [x20, #0x0]\n" + "ldr d12, [x25, #-0x8]\n" + ".inst 0x4f85e284 // sdot v4.4s, v20.16b, v5.4b[0]\n" + ".inst 0x4fa5e281 // sdot v1.4s, v20.16b, v5.4b[1]\n" + ".inst 0x4f85ea80 // sdot v0.4s, v20.16b, v5.4b[2]\n" + ".inst 0x4fa5ea9e // sdot v30.4s, v20.16b, v5.4b[3]\n" + "sshl v9.16b, v28.16b, v9.16b\n" + "subs x21, x21, #0x1\n" + "and v13.16b, v13.16b, v29.16b\n" + "and v28.16b, v28.16b, v29.16b\n" + "add x25, x25, #0x88\n" + "add x24, x24, #0x48\n" + "fcvtl v21.4s, v21.4h\n" + "fcvtl v12.4s, v12.4h\n" + ".inst 0x4f82e224 // sdot v4.4s, v17.16b, v2.4b[0]\n" + ".inst 0x4fa2e221 // sdot v1.4s, v17.16b, v2.4b[1]\n" + ".inst 0x4f82ea20 // sdot v0.4s, v17.16b, v2.4b[2]\n" + ".inst 0x4fa2ea3e // sdot v30.4s, v17.16b, v2.4b[3]\n" + "fmul v11.4s, v21.4s, v12.s[0]\n" + "fmul v23.4s, v21.4s, v12.s[1]\n" + "fmul v17.4s, v21.4s, v12.s[2]\n" + ".inst 0x4f9fe2c4 // sdot v4.4s, v22.16b, v31.4b[0]\n" + "fmul v6.4s, v21.4s, v12.s[3]\n" + ".inst 0x4fbfe2c1 // sdot v1.4s, v22.16b, v31.4b[1]\n" + ".inst 0x4f9feac0 // sdot v0.4s, v22.16b, v31.4b[2]\n" + ".inst 0x4fbfeade // sdot v30.4s, v22.16b, v31.4b[3]\n" + ".inst 0x4f9be124 // sdot v4.4s, v9.16b, v27.4b[0]\n" + ".inst 0x4fbbe121 // sdot v1.4s, v9.16b, v27.4b[1]\n" + ".inst 0x4f9be920 // sdot v0.4s, v9.16b, v27.4b[2]\n" + ".inst 0x4fbbe93e // sdot v30.4s, v9.16b, v27.4b[3]\n" + ".inst 0x4f9ae0e4 // sdot v4.4s, v7.16b, v26.4b[0]\n" + ".inst 0x4fbae0e1 // sdot v1.4s, v7.16b, v26.4b[1]\n" + ".inst 0x4f9ae8e0 // sdot v0.4s, v7.16b, v26.4b[2]\n" + ".inst 0x4fbae8fe // sdot v30.4s, v7.16b, v26.4b[3]\n" + ".inst 0x4f99e064 // sdot v4.4s, v3.16b, v25.4b[0]\n" + ".inst 0x4fb9e061 // sdot v1.4s, v3.16b, v25.4b[1]\n" + ".inst 0x4f99e860 // sdot v0.4s, v3.16b, v25.4b[2]\n" + ".inst 0x4fb9e87e // sdot v30.4s, v3.16b, v25.4b[3]\n" + ".inst 0x4f98e1a4 // sdot v4.4s, v13.16b, v24.4b[0]\n" + ".inst 0x4fb8e1a1 // sdot v1.4s, v13.16b, v24.4b[1]\n" + ".inst 0x4f98e9a0 // sdot v0.4s, v13.16b, v24.4b[2]\n" + ".inst 0x4fb8e9be // sdot v30.4s, v13.16b, v24.4b[3]\n" + ".inst 0x4f90e384 // sdot v4.4s, v28.16b, v16.4b[0]\n" + ".inst 0x4fb0e381 // sdot v1.4s, v28.16b, v16.4b[1]\n" + ".inst 0x4f90eb80 // sdot v0.4s, v28.16b, v16.4b[2]\n" + ".inst 0x4fb0eb9e // sdot v30.4s, v28.16b, v16.4b[3]\n" + "scvtf v4.4s, v4.4s, #0x4\n" + "scvtf v1.4s, v1.4s, #0x4\n" + "scvtf v0.4s, v0.4s, #0x4\n" + "fmla v15.4s, v4.4s, v11.4s\n" + "scvtf v30.4s, v30.4s, #0x4\n" + "fmla v19.4s, v1.4s, v23.4s\n" + "fmla v18.4s, v0.4s, v17.4s\n" + "fmla v14.4s, v30.4s, v6.4s\n" + "bgt 7b\n" + "mov x20, %x[res_ptr]\n" + "cmp x10, #0x1\n" + "str q15, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x10, #0x2\n" + "str q19, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x10, #0x3\n" + "str q18, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "str q14, [x20, #0x0]\n" + "8:" // Row tail: Accumulator store skip + "subs x23, x23, #0x4\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "bne 6b\n" + "subs x10, x10, #0x4\n" + "add %x[a_ptr], %x[a_ptr], x9\n" + "mov %x[res_ptr], x22\n" + "bgt 5b\n" + "9:" // Row tail: Row loop skip + : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr) + : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); + return; + } +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) + { + float sumf[4][4]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + } + sumf[m][j] += sumi * LM_GGML_FP16_TO_FP32(b_ptr[l].d[j]) * LM_GGML_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } + } +} + +void lm_ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_matmul_int8()) { + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + size_t res_stride = bs * sizeof(float); + + __asm__ __volatile__( + "mov x10, %x[nr]\n" + "mov x9, #0x88\n" + "cmp x10, #0x10\n" + "mul x9, %x[nb], x9\n" + "blt 4f\n" + "1:" // Row loop + "add x28, %x[b_ptr], #0x8\n" + "mov x27, %x[nc]\n" + "add x26, %x[res_ptr], %x[res_stride], LSL #4\n" + "2:" // Column loop + "add x25, %x[a_ptr], #0x8\n" + "movi v2.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "mov x24, %x[nb]\n" + "add x23, x25, x9\n" + "movi v12.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "add x22, x23, x9\n" + "movi v11.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "add x21, x22, x9\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v5.16b, #0x0\n" + "movi v7.16b, #0x0\n" + "movi v4.16b, #0x0\n" + "movi v6.16b, #0x0\n" + "movi v30.16b, #0x0\n" + "movi v24.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "3:" // Block loop + "ldr q21, [x28, #0x0]\n" + "ldr q16, [x28, #0x10]\n" + "movi v1.16b, #0x4\n" + "movi v19.4s, #0x0\n" + "ldr q27, [x25, #0x0]\n" + "ldr q15, [x25, #0x10]\n" + "movi v26.4s, #0x0\n" + "movi v18.4s, #0x0\n" + "ldr q29, [x28, #0x20]\n" + "ldr q3, [x28, #0x30]\n" + "movi v17.4s, #0x0\n" + "movi v0.16b, #0xf0\n" + "ldr d20, [x25, #-0x8]\n" + "ldr d9, [x23, #-0x8]\n" + "sshl v8.16b, v21.16b, v1.16b\n" + "sshl v31.16b, v16.16b, v1.16b\n" + "and v21.16b, v21.16b, v0.16b\n" + "and v16.16b, v16.16b, v0.16b\n" + "sub x20, x28, #0x8\n" + "subs x24, x24, #0x1\n" + "add x28, x28, #0x48\n" + ".inst 0x4e88a773 // smmla v19.4s, v27.16b, v8.16b\n" + ".inst 0x4e9fa77a // smmla v26.4s, v27.16b, v31.16b\n" + "ldr q27, [x25, #0x20]\n" + ".inst 0x4e88a5f2 // smmla v18.4s, v15.16b, v8.16b\n" + ".inst 0x4e9fa5f1 // smmla v17.4s, v15.16b, v31.16b\n" + "sshl v15.16b, v29.16b, v1.16b\n" + "sshl v1.16b, v3.16b, v1.16b\n" + "and v29.16b, v29.16b, v0.16b\n" + "and v3.16b, v3.16b, v0.16b\n" + "ldr q0, [x25, #0x30]\n" + "fcvtl v20.4s, v20.4h\n" + ".inst 0x4e8fa773 // smmla v19.4s, v27.16b, v15.16b\n" + "fcvtl v9.4s, v9.4h\n" + ".inst 0x4e81a77a // smmla v26.4s, v27.16b, v1.16b\n" + "ldr q27, [x25, #0x40]\n" + ".inst 0x4e8fa412 // smmla v18.4s, v0.16b, v15.16b\n" + ".inst 0x4e81a411 // smmla v17.4s, v0.16b, v1.16b\n" + "ldr q0, [x25, #0x50]\n" + ".inst 0x4e95a773 // smmla v19.4s, v27.16b, v21.16b\n" + ".inst 0x4e90a77a // smmla v26.4s, v27.16b, v16.16b\n" + "ldr q27, [x25, #0x60]\n" + ".inst 0x4e95a412 // smmla v18.4s, v0.16b, v21.16b\n" + ".inst 0x4e90a411 // smmla v17.4s, v0.16b, v16.16b\n" + "ldr q0, [x25, #0x70]\n" + "add x25, x25, #0x88\n" + ".inst 0x4e9da773 // smmla v19.4s, v27.16b, v29.16b\n" + ".inst 0x4e83a77a // smmla v26.4s, v27.16b, v3.16b\n" + "ldr d27, [x20, #0x0]\n" + ".inst 0x4e9da412 // smmla v18.4s, v0.16b, v29.16b\n" + ".inst 0x4e83a411 // smmla v17.4s, v0.16b, v3.16b\n" + "fcvtl v27.4s, v27.4h\n" + "uzp1 v0.2d, v19.2d, v26.2d\n" + "uzp2 v26.2d, v19.2d, v26.2d\n" + "fmul v19.4s, v27.4s, v20.s[0]\n" + "scvtf v0.4s, v0.4s, #0x4\n" + "scvtf v26.4s, v26.4s, #0x4\n" + "fmla v2.4s, v0.4s, v19.4s\n" + "ldr q19, [x23, #0x0]\n" + "uzp1 v0.2d, v18.2d, v17.2d\n" + "uzp2 v18.2d, v18.2d, v17.2d\n" + "fmul v17.4s, v27.4s, v20.s[1]\n" + "scvtf v0.4s, v0.4s, #0x4\n" + "scvtf v18.4s, v18.4s, #0x4\n" + "fmla v10.4s, v26.4s, v17.4s\n" + "ldr q17, [x23, #0x10]\n" + "fmul v26.4s, v27.4s, v20.s[2]\n" + "fmul v20.4s, v27.4s, v20.s[3]\n" + "fmla v12.4s, v0.4s, v26.4s\n" + "ldr d0, [x22, #-0x8]\n" + "ldr d26, [x21, #-0x8]\n" + "fcvtl v0.4s, v0.4h\n" + "fmla v28.4s, v18.4s, v20.4s\n" + "movi v20.4s, #0x0\n" + "movi v18.4s, #0x0\n" + ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n" + ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n" + "ldr q19, [x23, #0x20]\n" + "fcvtl v26.4s, v26.4h\n" + ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n" + ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n" + "ldr q19, [x23, #0x40]\n" + ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n" + ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n" + "ldr q19, [x23, #0x60]\n" + ".inst 0x4e9da674 // smmla v20.4s, v19.16b, v29.16b\n" + ".inst 0x4e83a672 // smmla v18.4s, v19.16b, v3.16b\n" + "uzp1 v19.2d, v20.2d, v18.2d\n" + "scvtf v19.4s, v19.4s, #0x4\n" + "uzp2 v20.2d, v20.2d, v18.2d\n" + "fmul v18.4s, v27.4s, v9.s[0]\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "fmla v11.4s, v19.4s, v18.4s\n" + "ldr q18, [x22, #0x0]\n" + "fmul v19.4s, v27.4s, v9.s[1]\n" + "fmla v13.4s, v20.4s, v19.4s\n" + "movi v19.4s, #0x0\n" + "movi v20.4s, #0x0\n" + ".inst 0x4e88a633 // smmla v19.4s, v17.16b, v8.16b\n" + ".inst 0x4e9fa634 // smmla v20.4s, v17.16b, v31.16b\n" + "ldr q17, [x23, #0x30]\n" + ".inst 0x4e8fa633 // smmla v19.4s, v17.16b, v15.16b\n" + ".inst 0x4e81a634 // smmla v20.4s, v17.16b, v1.16b\n" + "ldr q17, [x23, #0x50]\n" + ".inst 0x4e95a633 // smmla v19.4s, v17.16b, v21.16b\n" + ".inst 0x4e90a634 // smmla v20.4s, v17.16b, v16.16b\n" + "ldr q17, [x23, #0x70]\n" + "add x23, x23, #0x88\n" + ".inst 0x4e9da633 // smmla v19.4s, v17.16b, v29.16b\n" + ".inst 0x4e83a634 // smmla v20.4s, v17.16b, v3.16b\n" + "uzp1 v17.2d, v19.2d, v20.2d\n" + "scvtf v17.4s, v17.4s, #0x4\n" + "uzp2 v20.2d, v19.2d, v20.2d\n" + "fmul v19.4s, v27.4s, v9.s[2]\n" + "fmul v9.4s, v27.4s, v9.s[3]\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "fmla v22.4s, v17.4s, v19.4s\n" + "ldr q17, [x22, #0x10]\n" + "movi v19.4s, #0x0\n" + ".inst 0x4e88a653 // smmla v19.4s, v18.16b, v8.16b\n" + "fmla v23.4s, v20.4s, v9.4s\n" + "movi v20.4s, #0x0\n" + "movi v9.4s, #0x0\n" + ".inst 0x4e9fa654 // smmla v20.4s, v18.16b, v31.16b\n" + "ldr q18, [x22, #0x20]\n" + ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n" + ".inst 0x4e8fa653 // smmla v19.4s, v18.16b, v15.16b\n" + ".inst 0x4e81a654 // smmla v20.4s, v18.16b, v1.16b\n" + "ldr q18, [x22, #0x40]\n" + ".inst 0x4e95a653 // smmla v19.4s, v18.16b, v21.16b\n" + ".inst 0x4e90a654 // smmla v20.4s, v18.16b, v16.16b\n" + "ldr q18, [x22, #0x60]\n" + ".inst 0x4e9da653 // smmla v19.4s, v18.16b, v29.16b\n" + ".inst 0x4e83a654 // smmla v20.4s, v18.16b, v3.16b\n" + "movi v18.4s, #0x0\n" + ".inst 0x4e9fa632 // smmla v18.4s, v17.16b, v31.16b\n" + "ldr q17, [x22, #0x30]\n" + ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n" + ".inst 0x4e81a632 // smmla v18.4s, v17.16b, v1.16b\n" + "ldr q17, [x22, #0x50]\n" + ".inst 0x4e95a629 // smmla v9.4s, v17.16b, v21.16b\n" + ".inst 0x4e90a632 // smmla v18.4s, v17.16b, v16.16b\n" + "ldr q17, [x22, #0x70]\n" + "add x22, x22, #0x88\n" + ".inst 0x4e9da629 // smmla v9.4s, v17.16b, v29.16b\n" + ".inst 0x4e83a632 // smmla v18.4s, v17.16b, v3.16b\n" + "uzp1 v17.2d, v19.2d, v20.2d\n" + "uzp2 v20.2d, v19.2d, v20.2d\n" + "fmul v19.4s, v27.4s, v0.s[0]\n" + "scvtf v17.4s, v17.4s, #0x4\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "fmla v25.4s, v17.4s, v19.4s\n" + "ldr q19, [x21, #0x0]\n" + "fmul v17.4s, v27.4s, v0.s[1]\n" + "fmla v5.4s, v20.4s, v17.4s\n" + "ldr q17, [x21, #0x10]\n" + "uzp1 v20.2d, v9.2d, v18.2d\n" + "uzp2 v9.2d, v9.2d, v18.2d\n" + "fmul v18.4s, v27.4s, v0.s[2]\n" + "fmul v0.4s, v27.4s, v0.s[3]\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "scvtf v9.4s, v9.4s, #0x4\n" + "fmla v7.4s, v20.4s, v18.4s\n" + "movi v20.4s, #0x0\n" + "movi v18.4s, #0x0\n" + ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n" + ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n" + "ldr q19, [x21, #0x20]\n" + "fmla v4.4s, v9.4s, v0.4s\n" + "movi v9.4s, #0x0\n" + "movi v0.4s, #0x0\n" + ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n" + "fmul v8.4s, v27.4s, v26.s[0]\n" + ".inst 0x4e9fa620 // smmla v0.4s, v17.16b, v31.16b\n" + "ldr q17, [x21, #0x30]\n" + ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n" + "fmul v31.4s, v27.4s, v26.s[1]\n" + ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n" + "ldr q19, [x21, #0x40]\n" + ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n" + "fmul v15.4s, v27.4s, v26.s[2]\n" + "fmul v27.4s, v27.4s, v26.s[3]\n" + ".inst 0x4e81a620 // smmla v0.4s, v17.16b, v1.16b\n" + "ldr q1, [x21, #0x50]\n" + ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n" + ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n" + "ldr q26, [x21, #0x60]\n" + ".inst 0x4e95a429 // smmla v9.4s, v1.16b, v21.16b\n" + ".inst 0x4e90a420 // smmla v0.4s, v1.16b, v16.16b\n" + "ldr q21, [x21, #0x70]\n" + "add x21, x21, #0x88\n" + ".inst 0x4e9da754 // smmla v20.4s, v26.16b, v29.16b\n" + ".inst 0x4e83a752 // smmla v18.4s, v26.16b, v3.16b\n" + ".inst 0x4e9da6a9 // smmla v9.4s, v21.16b, v29.16b\n" + ".inst 0x4e83a6a0 // smmla v0.4s, v21.16b, v3.16b\n" + "uzp1 v29.2d, v20.2d, v18.2d\n" + "uzp2 v21.2d, v20.2d, v18.2d\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "uzp1 v18.2d, v9.2d, v0.2d\n" + "uzp2 v16.2d, v9.2d, v0.2d\n" + "scvtf v21.4s, v21.4s, #0x4\n" + "fmla v6.4s, v29.4s, v8.4s\n" + "scvtf v18.4s, v18.4s, #0x4\n" + "scvtf v16.4s, v16.4s, #0x4\n" + "fmla v30.4s, v21.4s, v31.4s\n" + "fmla v24.4s, v18.4s, v15.4s\n" + "fmla v14.4s, v16.4s, v27.4s\n" + "bgt 3b\n" + "mov x20, %x[res_ptr]\n" + "subs x27, x27, #0x4\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "str q2, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q10, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q12, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q28, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q11, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q13, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q22, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q23, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q25, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q5, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q7, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q4, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q6, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q30, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q24, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q14, [x20, #0x0]\n" + "bne 2b\n" + "mov x20, #0x4\n" + "sub x10, x10, #0x10\n" + "cmp x10, #0x10\n" + "mov %x[res_ptr], x26\n" + "madd %x[a_ptr], x20, x9, %x[a_ptr]\n" + "bge 1b\n" + "4:" // Row loop skip + "cbz x10, 9f\n" + "5:" // Row tail: Row loop + "add x24, %x[b_ptr], #0x8\n" + "mov x23, %x[nc]\n" + "add x22, %x[res_ptr], %x[res_stride], LSL #2\n" + "6:" // Row tail: Column loop + "movi v2.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "add x25, %x[a_ptr], #0x8\n" + "mov x21, %x[nb]\n" + "movi v12.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "7:" // Row tail: Block loop + "ldr q6, [x24, #0x0]\n" + "ldr q5, [x24, #0x10]\n" + "movi v17.16b, #0x4\n" + "movi v8.4s, #0x0\n" + "ldr q4, [x25, #0x0]\n" + "ldr q13, [x25, #0x10]\n" + "movi v27.4s, #0x0\n" + "movi v0.4s, #0x0\n" + "ldr q31, [x24, #0x20]\n" + "ldr q14, [x24, #0x30]\n" + "movi v29.4s, #0x0\n" + "movi v22.16b, #0xf0\n" + "ldr q11, [x25, #0x20]\n" + "ldr q23, [x25, #0x30]\n" + "sshl v21.16b, v6.16b, v17.16b\n" + "sshl v16.16b, v5.16b, v17.16b\n" + "ldr q20, [x25, #0x40]\n" + "ldr q26, [x25, #0x50]\n" + "and v6.16b, v6.16b, v22.16b\n" + "and v5.16b, v5.16b, v22.16b\n" + "ldr q25, [x25, #0x60]\n" + "ldr q3, [x25, #0x70]\n" + "sshl v19.16b, v31.16b, v17.16b\n" + "sshl v18.16b, v14.16b, v17.16b\n" + "ldr d17, [x25, #-0x8]\n" + ".inst 0x4e95a488 // smmla v8.4s, v4.16b, v21.16b\n" + ".inst 0x4e90a49b // smmla v27.4s, v4.16b, v16.16b\n" + "and v31.16b, v31.16b, v22.16b\n" + ".inst 0x4e95a5a0 // smmla v0.4s, v13.16b, v21.16b\n" + ".inst 0x4e90a5bd // smmla v29.4s, v13.16b, v16.16b\n" + "and v14.16b, v14.16b, v22.16b\n" + "sub x20, x24, #0x8\n" + "ldr d16, [x20, #0x0]\n" + "subs x21, x21, #0x1\n" + "add x25, x25, #0x88\n" + "fcvtl v17.4s, v17.4h\n" + "add x24, x24, #0x48\n" + ".inst 0x4e93a568 // smmla v8.4s, v11.16b, v19.16b\n" + ".inst 0x4e92a57b // smmla v27.4s, v11.16b, v18.16b\n" + ".inst 0x4e93a6e0 // smmla v0.4s, v23.16b, v19.16b\n" + ".inst 0x4e92a6fd // smmla v29.4s, v23.16b, v18.16b\n" + "fcvtl v16.4s, v16.4h\n" + ".inst 0x4e86a688 // smmla v8.4s, v20.16b, v6.16b\n" + ".inst 0x4e85a69b // smmla v27.4s, v20.16b, v5.16b\n" + "fmul v23.4s, v16.4s, v17.s[0]\n" + "fmul v21.4s, v16.4s, v17.s[1]\n" + "fmul v1.4s, v16.4s, v17.s[2]\n" + "fmul v20.4s, v16.4s, v17.s[3]\n" + ".inst 0x4e86a740 // smmla v0.4s, v26.16b, v6.16b\n" + ".inst 0x4e85a75d // smmla v29.4s, v26.16b, v5.16b\n" + ".inst 0x4e9fa728 // smmla v8.4s, v25.16b, v31.16b\n" + ".inst 0x4e8ea73b // smmla v27.4s, v25.16b, v14.16b\n" + ".inst 0x4e9fa460 // smmla v0.4s, v3.16b, v31.16b\n" + ".inst 0x4e8ea47d // smmla v29.4s, v3.16b, v14.16b\n" + "uzp1 v19.2d, v8.2d, v27.2d\n" + "uzp2 v18.2d, v8.2d, v27.2d\n" + "scvtf v19.4s, v19.4s, #0x4\n" + "uzp1 v17.2d, v0.2d, v29.2d\n" + "uzp2 v16.2d, v0.2d, v29.2d\n" + "scvtf v18.4s, v18.4s, #0x4\n" + "fmla v2.4s, v19.4s, v23.4s\n" + "scvtf v17.4s, v17.4s, #0x4\n" + "scvtf v16.4s, v16.4s, #0x4\n" + "fmla v10.4s, v18.4s, v21.4s\n" + "fmla v12.4s, v17.4s, v1.4s\n" + "fmla v28.4s, v16.4s, v20.4s\n" + "bgt 7b\n" + "mov x20, %x[res_ptr]\n" + "cmp x10, #0x1\n" + "str q2, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x10, #0x2\n" + "str q10, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x10, #0x3\n" + "str q12, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "str q28, [x20, #0x0]\n" + "8:" // Row tail: Accumulator store skip + "subs x23, x23, #0x4\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "bne 6b\n" + "subs x10, x10, #0x4\n" + "add %x[a_ptr], %x[a_ptr], x9\n" + "mov %x[res_ptr], x22\n" + "bgt 5b\n" + "9:" // Row tail: Row loop skip + : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr) + : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); + return; + } +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + float sumf[4][4]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + } + sumf[m][j] += sumi * LM_GGML_FP16_TO_FP32(b_ptr[l].d[j]) * LM_GGML_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +} + +void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) +#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) + if (lm_ggml_cpu_has_sve() && lm_ggml_cpu_has_matmul_int8() && lm_ggml_cpu_get_sve_cnt() == QK8_0) { + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + size_t res_stride = bs * sizeof(float); + + __asm__ __volatile__( + "mov x20, #0x4\n" + "mov x13, %x[nr]\n" + "mov z28.s, #-0x4\n" + "mov x12, #0x88\n" + "ptrue p1.b\n" + "whilelt p0.s, XZR, x20\n" + "cmp x13, #0x10\n" + "mul x12, %x[nb], x12\n" + "blt 4f\n" + "1:" // Row loop + "add x11, %x[b_ptr], #0x10\n" + "mov x10, %x[nc]\n" + "add x9, %x[res_ptr], %x[res_stride], LSL #4\n" + "2:" // Column loop + "add x28, %x[a_ptr], #0x8\n" + "mov z24.b, #0x0\n" + "mov z15.b, #0x0\n" + "mov x27, %x[nb]\n" + "add x26, x28, x12\n" + "mov z12.b, #0x0\n" + "mov z0.b, #0x0\n" + "add x25, x26, x12\n" + "mov z13.b, #0x0\n" + "mov z1.b, #0x0\n" + "add x24, x25, x12\n" + "mov z20.b, #0x0\n" + "mov z25.b, #0x0\n" + "mov z11.b, #0x0\n" + "mov z16.b, #0x0\n" + "mov z19.b, #0x0\n" + "mov z26.b, #0x0\n" + "mov z8.b, #0x0\n" + "mov z29.b, #0x0\n" + "mov z27.b, #0x0\n" + "mov z10.b, #0x0\n" + "3:" // Block loop + "ld1b { z30.b }, p1/Z, [x11]\n" + "ld1b { z21.b }, p1/Z, [x11, #1, MUL VL]\n" + "mov z18.s, #0x0\n" + "mov z7.s, #0x0\n" + "ld1rqb { z3.b }, p1/Z, [x28]\n" + "ld1rqb { z5.b }, p1/Z, [x28, #16]\n" + "mov z9.s, #0x0\n" + "mov z22.s, #0x0\n" + "ld1b { z4.b }, p1/Z, [x11, #2, MUL VL]\n" + "ld1b { z17.b }, p1/Z, [x11, #3, MUL VL]\n" + "sub x20, x11, #0x10\n" + "sub x23, x28, #0x8\n" + "lsl z31.b, z30.b, #0x4\n" + "lsl z6.b, z21.b, #0x4\n" + "ld1h { z23.s }, p1/Z, [x20]\n" + "sub x22, x26, #0x8\n" + "and z30.b, z30.b, #0xf0\n" + "and z21.b, z21.b, #0xf0\n" + "sub x21, x25, #0x8\n" + "sub x20, x24, #0x8\n" + "lsl z14.b, z4.b, #0x4\n" + "lsl z2.b, z17.b, #0x4\n" + "subs x27, x27, #0x1\n" + "add x11, x11, #0x90\n" + ".inst 0x451f9872 // smmla z18.s, z3.b, z31.b\n" + ".inst 0x45069867 // smmla z7.s, z3.b, z6.b\n" + "ld1rqb { z3.b }, p1/Z, [x28, #32]\n" + "and z4.b, z4.b, #0xf0\n" + ".inst 0x451f98a9 // smmla z9.s, z5.b, z31.b\n" + ".inst 0x450698b6 // smmla z22.s, z5.b, z6.b\n" + "ld1rqb { z5.b }, p1/Z, [x28, #48]\n" + "and z17.b, z17.b, #0xf0\n" + "fcvt z23.s, p1/m, z23.h\n" + ".inst 0x450e9872 // smmla z18.s, z3.b, z14.b\n" + ".inst 0x45029867 // smmla z7.s, z3.b, z2.b\n" + "ld1rqb { z3.b }, p1/Z, [x28, #64]\n" + ".inst 0x450e98a9 // smmla z9.s, z5.b, z14.b\n" + ".inst 0x450298b6 // smmla z22.s, z5.b, z2.b\n" + "ld1rqb { z5.b }, p1/Z, [x28, #80]\n" + "fscale z23.s, p1/m, z23.s, z28.s\n" + ".inst 0x451e9872 // smmla z18.s, z3.b, z30.b\n" + ".inst 0x45159867 // smmla z7.s, z3.b, z21.b\n" + "ld1rqb { z3.b }, p1/Z, [x28, #96]\n" + ".inst 0x451e98a9 // smmla z9.s, z5.b, z30.b\n" + ".inst 0x451598b6 // smmla z22.s, z5.b, z21.b\n" + "ld1rqb { z5.b }, p1/Z, [x28, #112]\n" + "add x28, x28, #0x88\n" + ".inst 0x45049872 // smmla z18.s, z3.b, z4.b\n" + ".inst 0x45119867 // smmla z7.s, z3.b, z17.b\n" + "ld1h { z3.s }, p0/Z, [x23]\n" + ".inst 0x450498a9 // smmla z9.s, z5.b, z4.b\n" + ".inst 0x451198b6 // smmla z22.s, z5.b, z17.b\n" + "fcvt z3.s, p1/m, z3.h\n" + "uzp1 z5.d, z18.d, z7.d\n" + "uzp2 z18.d, z18.d, z7.d\n" + "mov z3.q, z3.q[0]\n" + "uzp1 z7.d, z9.d, z22.d\n" + "uzp2 z22.d, z9.d, z22.d\n" + "fmul z9.s, z23.s, z3.s[0]\n" + "scvtf z5.s, p1/m, z5.s\n" + "scvtf z18.s, p1/m, z18.s\n" + "scvtf z7.s, p1/m, z7.s\n" + "scvtf z22.s, p1/m, z22.s\n" + "fmla z24.s, p1/M, z5.s, z9.s\n" + "ld1rqb { z5.b }, p1/Z, [x26]\n" + "fmul z9.s, z23.s, z3.s[1]\n" + "fmla z15.s, p1/M, z18.s, z9.s\n" + "ld1rqb { z18.b }, p1/Z, [x26, #16]\n" + "fmul z9.s, z23.s, z3.s[2]\n" + "fmul z3.s, z23.s, z3.s[3]\n" + "fmla z12.s, p1/M, z7.s, z9.s\n" + "mov z9.s, #0x0\n" + "ld1h { z7.s }, p0/Z, [x22]\n" + ".inst 0x451f98a9 // smmla z9.s, z5.b, z31.b\n" + "fmla z0.s, p1/M, z22.s, z3.s\n" + "mov z22.s, #0x0\n" + "ld1h { z3.s }, p0/Z, [x21]\n" + ".inst 0x450698b6 // smmla z22.s, z5.b, z6.b\n" + "ld1rqb { z5.b }, p1/Z, [x26, #32]\n" + "fcvt z7.s, p1/m, z7.h\n" + "fcvt z3.s, p1/m, z3.h\n" + ".inst 0x450e98a9 // smmla z9.s, z5.b, z14.b\n" + ".inst 0x450298b6 // smmla z22.s, z5.b, z2.b\n" + "ld1rqb { z5.b }, p1/Z, [x26, #64]\n" + "mov z7.q, z7.q[0]\n" + "mov z3.q, z3.q[0]\n" + ".inst 0x451e98a9 // smmla z9.s, z5.b, z30.b\n" + ".inst 0x451598b6 // smmla z22.s, z5.b, z21.b\n" + "ld1rqb { z5.b }, p1/Z, [x26, #96]\n" + ".inst 0x450498a9 // smmla z9.s, z5.b, z4.b\n" + ".inst 0x451198b6 // smmla z22.s, z5.b, z17.b\n" + "uzp1 z5.d, z9.d, z22.d\n" + "scvtf z5.s, p1/m, z5.s\n" + "uzp2 z22.d, z9.d, z22.d\n" + "fmul z9.s, z23.s, z7.s[0]\n" + "scvtf z22.s, p1/m, z22.s\n" + "fmla z13.s, p1/M, z5.s, z9.s\n" + "ld1rqb { z9.b }, p1/Z, [x25]\n" + "fmul z5.s, z23.s, z7.s[1]\n" + "fmla z1.s, p1/M, z22.s, z5.s\n" + "mov z5.s, #0x0\n" + "mov z22.s, #0x0\n" + ".inst 0x451f9a45 // smmla z5.s, z18.b, z31.b\n" + ".inst 0x45069a56 // smmla z22.s, z18.b, z6.b\n" + "ld1rqb { z18.b }, p1/Z, [x26, #48]\n" + ".inst 0x450e9a45 // smmla z5.s, z18.b, z14.b\n" + ".inst 0x45029a56 // smmla z22.s, z18.b, z2.b\n" + "ld1rqb { z18.b }, p1/Z, [x26, #80]\n" + ".inst 0x451e9a45 // smmla z5.s, z18.b, z30.b\n" + ".inst 0x45159a56 // smmla z22.s, z18.b, z21.b\n" + "ld1rqb { z18.b }, p1/Z, [x26, #112]\n" + "add x26, x26, #0x88\n" + ".inst 0x45049a45 // smmla z5.s, z18.b, z4.b\n" + ".inst 0x45119a56 // smmla z22.s, z18.b, z17.b\n" + "uzp1 z18.d, z5.d, z22.d\n" + "scvtf z18.s, p1/m, z18.s\n" + "uzp2 z22.d, z5.d, z22.d\n" + "fmul z5.s, z23.s, z7.s[2]\n" + "fmul z7.s, z23.s, z7.s[3]\n" + "scvtf z22.s, p1/m, z22.s\n" + "fmla z20.s, p1/M, z18.s, z5.s\n" + "ld1rqb { z18.b }, p1/Z, [x25, #16]\n" + "ld1h { z5.s }, p0/Z, [x20]\n" + "fcvt z5.s, p1/m, z5.h\n" + "fmla z25.s, p1/M, z22.s, z7.s\n" + "mov z22.s, #0x0\n" + "mov z7.s, #0x0\n" + ".inst 0x451f9936 // smmla z22.s, z9.b, z31.b\n" + ".inst 0x45069927 // smmla z7.s, z9.b, z6.b\n" + "ld1rqb { z9.b }, p1/Z, [x25, #32]\n" + "mov z5.q, z5.q[0]\n" + ".inst 0x450e9936 // smmla z22.s, z9.b, z14.b\n" + ".inst 0x45029927 // smmla z7.s, z9.b, z2.b\n" + "ld1rqb { z9.b }, p1/Z, [x25, #64]\n" + ".inst 0x451e9936 // smmla z22.s, z9.b, z30.b\n" + ".inst 0x45159927 // smmla z7.s, z9.b, z21.b\n" + "ld1rqb { z9.b }, p1/Z, [x25, #96]\n" + ".inst 0x45049936 // smmla z22.s, z9.b, z4.b\n" + ".inst 0x45119927 // smmla z7.s, z9.b, z17.b\n" + "uzp1 z9.d, z22.d, z7.d\n" + "scvtf z9.s, p1/m, z9.s\n" + "uzp2 z22.d, z22.d, z7.d\n" + "fmul z7.s, z23.s, z3.s[0]\n" + "scvtf z22.s, p1/m, z22.s\n" + "fmla z11.s, p1/M, z9.s, z7.s\n" + "ld1rqb { z9.b }, p1/Z, [x24]\n" + "fmul z7.s, z23.s, z3.s[1]\n" + "fmla z16.s, p1/M, z22.s, z7.s\n" + "mov z22.s, #0x0\n" + "mov z7.s, #0x0\n" + ".inst 0x451f9a56 // smmla z22.s, z18.b, z31.b\n" + ".inst 0x45069a47 // smmla z7.s, z18.b, z6.b\n" + "ld1rqb { z18.b }, p1/Z, [x25, #48]\n" + ".inst 0x450e9a56 // smmla z22.s, z18.b, z14.b\n" + ".inst 0x45029a47 // smmla z7.s, z18.b, z2.b\n" + "ld1rqb { z18.b }, p1/Z, [x25, #80]\n" + ".inst 0x451e9a56 // smmla z22.s, z18.b, z30.b\n" + ".inst 0x45159a47 // smmla z7.s, z18.b, z21.b\n" + "ld1rqb { z18.b }, p1/Z, [x25, #112]\n" + "add x25, x25, #0x88\n" + ".inst 0x45049a56 // smmla z22.s, z18.b, z4.b\n" + ".inst 0x45119a47 // smmla z7.s, z18.b, z17.b\n" + "uzp1 z18.d, z22.d, z7.d\n" + "scvtf z18.s, p1/m, z18.s\n" + "uzp2 z7.d, z22.d, z7.d\n" + "fmul z22.s, z23.s, z3.s[2]\n" + "fmul z3.s, z23.s, z3.s[3]\n" + "scvtf z7.s, p1/m, z7.s\n" + "fmla z19.s, p1/M, z18.s, z22.s\n" + "ld1rqb { z18.b }, p1/Z, [x24, #16]\n" + "fmul z22.s, z23.s, z5.s[0]\n" + "fmla z26.s, p1/M, z7.s, z3.s\n" + "mov z3.s, #0x0\n" + "mov z7.s, #0x0\n" + ".inst 0x451f9923 // smmla z3.s, z9.b, z31.b\n" + ".inst 0x45069927 // smmla z7.s, z9.b, z6.b\n" + "ld1rqb { z9.b }, p1/Z, [x24, #32]\n" + ".inst 0x450e9923 // smmla z3.s, z9.b, z14.b\n" + ".inst 0x45029927 // smmla z7.s, z9.b, z2.b\n" + "mov z9.s, #0x0\n" + ".inst 0x451f9a49 // smmla z9.s, z18.b, z31.b\n" + "mov z31.s, #0x0\n" + ".inst 0x45069a5f // smmla z31.s, z18.b, z6.b\n" + "ld1rqb { z6.b }, p1/Z, [x24, #48]\n" + "ld1rqb { z18.b }, p1/Z, [x24, #64]\n" + ".inst 0x450e98c9 // smmla z9.s, z6.b, z14.b\n" + "fmul z14.s, z23.s, z5.s[1]\n" + ".inst 0x450298df // smmla z31.s, z6.b, z2.b\n" + "ld1rqb { z6.b }, p1/Z, [x24, #80]\n" + "fmul z2.s, z23.s, z5.s[2]\n" + "fmul z23.s, z23.s, z5.s[3]\n" + ".inst 0x451e9a43 // smmla z3.s, z18.b, z30.b\n" + ".inst 0x45159a47 // smmla z7.s, z18.b, z21.b\n" + "ld1rqb { z5.b }, p1/Z, [x24, #96]\n" + ".inst 0x451e98c9 // smmla z9.s, z6.b, z30.b\n" + ".inst 0x451598df // smmla z31.s, z6.b, z21.b\n" + "ld1rqb { z18.b }, p1/Z, [x24, #112]\n" + "add x24, x24, #0x88\n" + ".inst 0x450498a3 // smmla z3.s, z5.b, z4.b\n" + ".inst 0x451198a7 // smmla z7.s, z5.b, z17.b\n" + ".inst 0x45049a49 // smmla z9.s, z18.b, z4.b\n" + ".inst 0x45119a5f // smmla z31.s, z18.b, z17.b\n" + "uzp1 z18.d, z3.d, z7.d\n" + "uzp2 z5.d, z3.d, z7.d\n" + "scvtf z18.s, p1/m, z18.s\n" + "uzp1 z6.d, z9.d, z31.d\n" + "uzp2 z9.d, z9.d, z31.d\n" + "scvtf z5.s, p1/m, z5.s\n" + "fmla z8.s, p1/M, z18.s, z22.s\n" + "scvtf z6.s, p1/m, z6.s\n" + "scvtf z9.s, p1/m, z9.s\n" + "fmla z29.s, p1/M, z5.s, z14.s\n" + "fmla z27.s, p1/M, z6.s, z2.s\n" + "fmla z10.s, p1/M, z9.s, z23.s\n" + "bgt 3b\n" + "mov x20, %x[res_ptr]\n" + "subs x10, x10, #0x8\n" + "add %x[res_ptr], %x[res_ptr], #0x20\n" + "st1w { z24.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z15.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z12.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z0.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z13.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z1.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z20.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z25.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z11.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z16.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z19.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z26.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z8.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z29.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z27.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z10.s }, p1, [x20]\n" + "bne 2b\n" + "mov x20, #0x4\n" + "sub x13, x13, #0x10\n" + "cmp x13, #0x10\n" + "mov %x[res_ptr], x9\n" + "madd %x[a_ptr], x20, x12, %x[a_ptr]\n" + "bge 1b\n" + "4:" // Row loop skip + "cbz x13, 9f\n" + "5:" // Row tail: Row loop + "add x25, %x[b_ptr], #0x10\n" + "mov x24, %x[nc]\n" + "add x23, %x[res_ptr], %x[res_stride], LSL #2\n" + "6:" // Row tail: Column loop + "mov z24.b, #0x0\n" + "mov z15.b, #0x0\n" + "add x28, %x[a_ptr], #0x8\n" + "mov x22, %x[nb]\n" + "mov z12.b, #0x0\n" + "mov z0.b, #0x0\n" + "7:" // Row tail: Block loop + "ld1b { z3.b }, p1/Z, [x25]\n" + "ld1b { z6.b }, p1/Z, [x25, #1, MUL VL]\n" + "mov z2.s, #0x0\n" + "mov z25.s, #0x0\n" + "ld1rqb { z26.b }, p1/Z, [x28]\n" + "ld1rqb { z21.b }, p1/Z, [x28, #16]\n" + "mov z27.s, #0x0\n" + "mov z19.s, #0x0\n" + "ld1b { z29.b }, p1/Z, [x25, #2, MUL VL]\n" + "ld1b { z16.b }, p1/Z, [x25, #3, MUL VL]\n" + "sub x21, x25, #0x10\n" + "sub x20, x28, #0x8\n" + "lsl z20.b, z3.b, #0x4\n" + "lsl z4.b, z6.b, #0x4\n" + "ld1rqb { z10.b }, p1/Z, [x28, #32]\n" + "ld1rqb { z23.b }, p1/Z, [x28, #48]\n" + "and z3.b, z3.b, #0xf0\n" + "and z6.b, z6.b, #0xf0\n" + "ld1rqb { z11.b }, p1/Z, [x28, #64]\n" + "ld1rqb { z7.b }, p1/Z, [x28, #80]\n" + "lsl z8.b, z29.b, #0x4\n" + "lsl z14.b, z16.b, #0x4\n" + "ld1rqb { z18.b }, p1/Z, [x28, #96]\n" + "ld1rqb { z30.b }, p1/Z, [x28, #112]\n" + ".inst 0x45149b42 // smmla z2.s, z26.b, z20.b\n" + ".inst 0x45049b59 // smmla z25.s, z26.b, z4.b\n" + "and z29.b, z29.b, #0xf0\n" + "ld1h { z17.s }, p1/Z, [x21]\n" + ".inst 0x45149abb // smmla z27.s, z21.b, z20.b\n" + ".inst 0x45049ab3 // smmla z19.s, z21.b, z4.b\n" + "and z16.b, z16.b, #0xf0\n" + "ld1h { z4.s }, p0/Z, [x20]\n" + "subs x22, x22, #0x1\n" + "add x28, x28, #0x88\n" + "fcvt z17.s, p1/m, z17.h\n" + "add x25, x25, #0x90\n" + ".inst 0x45089942 // smmla z2.s, z10.b, z8.b\n" + ".inst 0x450e9959 // smmla z25.s, z10.b, z14.b\n" + "fcvt z4.s, p1/m, z4.h\n" + ".inst 0x45089afb // smmla z27.s, z23.b, z8.b\n" + ".inst 0x450e9af3 // smmla z19.s, z23.b, z14.b\n" + "fscale z17.s, p1/m, z17.s, z28.s\n" + "mov z4.q, z4.q[0]\n" + ".inst 0x45039962 // smmla z2.s, z11.b, z3.b\n" + ".inst 0x45069979 // smmla z25.s, z11.b, z6.b\n" + "fmul z23.s, z17.s, z4.s[0]\n" + "fmul z9.s, z17.s, z4.s[1]\n" + "fmul z21.s, z17.s, z4.s[2]\n" + "fmul z4.s, z17.s, z4.s[3]\n" + ".inst 0x450398fb // smmla z27.s, z7.b, z3.b\n" + ".inst 0x450698f3 // smmla z19.s, z7.b, z6.b\n" + ".inst 0x451d9a42 // smmla z2.s, z18.b, z29.b\n" + ".inst 0x45109a59 // smmla z25.s, z18.b, z16.b\n" + ".inst 0x451d9bdb // smmla z27.s, z30.b, z29.b\n" + ".inst 0x45109bd3 // smmla z19.s, z30.b, z16.b\n" + "uzp1 z31.d, z2.d, z25.d\n" + "uzp2 z13.d, z2.d, z25.d\n" + "scvtf z31.s, p1/m, z31.s\n" + "uzp1 z17.d, z27.d, z19.d\n" + "uzp2 z18.d, z27.d, z19.d\n" + "scvtf z13.s, p1/m, z13.s\n" + "fmla z24.s, p1/M, z31.s, z23.s\n" + "scvtf z17.s, p1/m, z17.s\n" + "scvtf z18.s, p1/m, z18.s\n" + "fmla z15.s, p1/M, z13.s, z9.s\n" + "fmla z12.s, p1/M, z17.s, z21.s\n" + "fmla z0.s, p1/M, z18.s, z4.s\n" + "bgt 7b\n" + "mov x20, %x[res_ptr]\n" + "cmp x13, #0x1\n" + "st1w { z24.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x13, #0x2\n" + "st1w { z15.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x13, #0x3\n" + "st1w { z12.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "st1w { z0.s }, p1, [x20]\n" + "8:" // Row tail: Accumulator store skip + "subs x24, x24, #0x8\n" + "add %x[res_ptr], %x[res_ptr], #0x20\n" + "bne 6b\n" + "subs x13, x13, #0x4\n" + "add %x[a_ptr], %x[a_ptr], x12\n" + "mov %x[res_ptr], x23\n" + "bgt 5b\n" + "9:" // Row tail: Row loop skip + : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr) + : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc) + : "cc", "memory", "p0", "p1", "x9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); + return; + } +#endif // #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) +#elif defined(__AVX2__) || defined(__AVX512F__) + { + const block_q4_0x8 * b_ptr_start = (const block_q4_0x8 *)vx; + const block_q8_0x4 * a_ptr_start = (const block_q8_0x4 *)vy; + int64_t b_nb = n / QK4_0; + int64_t y = 0; + // Mask to mask out nibbles from packed bytes + const __m256i m4b = _mm256_set1_epi8(0x0F); + const __m128i loadMask = _mm_blend_epi32(_mm_setzero_si128(), _mm_set1_epi32(0xFFFFFFFF), 3); + // Lookup table to convert signed nibbles to signed bytes + __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0)); + signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0); + // Permute mask used for easier vector processing at later stages + __m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4); + int64_t xstart = 0; + int anr = nr - nr%16; // Used to align nr with boundary of 16 + #ifdef __AVX512F__ + int anc = nc - nc%16; // Used to align nc with boundary of 16 + // Mask to mask out nibbles from packed bytes expanded to 512 bit length + const __m512i m4bexpanded = _mm512_set1_epi8(0x0F); + // Lookup table to convert signed nibbles to signed bytes expanded to 512 bit length + __m512i signextendlutexpanded = _mm512_inserti32x8(_mm512_castsi256_si512(signextendlut), signextendlut, 1); + + // Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation + for (; y < anr / 4; y += 4) { + + const block_q8_0x4 * a_ptrs[4]; + + a_ptrs[0] = a_ptr_start + (y * nb); + for (int i = 0; i < 3; ++i) { + a_ptrs[i + 1] = a_ptrs[i] + nb; + } + + // Take group of two block_q4_0x8 structures at each pass of the loop and perform dot product operation + for (int64_t x = 0; x < anc / 8; x += 2) { + + const block_q4_0x8 * b_ptr_0 = b_ptr_start + ((x) * b_nb); + const block_q4_0x8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb); + + // Master FP accumulators + __m512 acc_rows[16]; + for (int i = 0; i < 16; i++) { + acc_rows[i] = _mm512_setzero_ps(); + } + + for (int64_t b = 0; b < nb; b++) { + // Load the sixteen block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....BE,BF + const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs)); + const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 32)); + const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 64)); + const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 96)); + + const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs)); + const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 32)); + const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 64)); + const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 96)); + + // Save the values in the following vectors in the formats B0B1B4B5B8B9BCBD, B2B3B6B7BABBBEBF for further processing and storing of values + const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240); + const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240); + + const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240); + const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240); + + const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1); + const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1); + const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1); + const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1); + + // 4-bit -> 8-bit - Sign is maintained + const __m512i rhs_mat_014589CD_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) B8(0-7) B9(0-7) BC(0-7) BD(0-7) + const __m512i rhs_mat_2367ABEF_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) BA(0-7) BB(0-7) BE(0-7) BF(0-7) + + const __m512i rhs_mat_014589CD_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) B8(8-15) B9(8-15) BC(8-15) BD(8-15) + const __m512i rhs_mat_2367ABEF_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) BA(8-15) BB(8-15) BE(8-15) BF(8-15) + + const __m512i rhs_mat_014589CD_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) B8(16-23) B9(16-23) BC(16-23) BD(16-23) + const __m512i rhs_mat_2367ABEF_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) BA(16-23) BB(16-23) BE(16-23) BF(16-23) + + const __m512i rhs_mat_014589CD_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) B8(24-31) B9(24-31) BC(24-31) BD(24-31) + const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31) + + // Shuffle pattern one - right side input + const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3) + const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3) + + const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11) + const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11) + + const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19) + const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19) + + const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27) + const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27) + + // Shuffle pattern two - right side input + + const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7) + const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7) + + const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15) + const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15) + + const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23) + const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23) + + const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31) + const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31) + + // Scale values - Load the weight scale values of two block_q4_0x8 + const __m512 col_scale_f32 = LM_GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d); + + // Process LHS in pairs of rows + for (int rp = 0; rp < 4; rp++) { + + // Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3 + // Loaded as set of 128 bit vectors and repeated and stored into a 256 bit vector before again repeating into 512 bit vector + __m256i lhs_mat_ymm_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs))); + __m256i lhs_mat_ymm_01_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 0); + __m256i lhs_mat_ymm_23_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 17); + __m256i lhs_mat_ymm_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32))); + __m256i lhs_mat_ymm_01_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 0); + __m256i lhs_mat_ymm_23_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 17); + __m256i lhs_mat_ymm_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64))); + __m256i lhs_mat_ymm_01_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 0); + __m256i lhs_mat_ymm_23_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 17); + __m256i lhs_mat_ymm_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96))); + __m256i lhs_mat_ymm_01_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 0); + __m256i lhs_mat_ymm_23_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 17); + + __m512i lhs_mat_01_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_0), lhs_mat_ymm_01_0, 1); + __m512i lhs_mat_23_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_0), lhs_mat_ymm_23_0, 1); + __m512i lhs_mat_01_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_1), lhs_mat_ymm_01_1, 1); + __m512i lhs_mat_23_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_1), lhs_mat_ymm_23_1, 1); + __m512i lhs_mat_01_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_2), lhs_mat_ymm_01_2, 1); + __m512i lhs_mat_23_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_2), lhs_mat_ymm_23_2, 1); + __m512i lhs_mat_01_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_3), lhs_mat_ymm_01_3, 1); + __m512i lhs_mat_23_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_3), lhs_mat_ymm_23_3, 1); + + // Shuffle pattern one - left side input + + const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) + const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) + + const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) + const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) + + const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) + const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) + + const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) + const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) + + // Shuffle pattern two - left side input + + const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) + const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) + + const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) + const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) + + const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) + const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) + + const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) + const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) + + // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane + // Resembles MMLAs into 2x2 matrices in ARM Version + __m512i iacc_mat_00_sp1 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1)); + __m512i iacc_mat_01_sp1 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1)); + __m512i iacc_mat_10_sp1 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1)); + __m512i iacc_mat_11_sp1 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1)); + __m512i iacc_mat_00_sp2 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2)); + __m512i iacc_mat_01_sp2 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2)); + __m512i iacc_mat_10_sp2 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2)); + __m512i iacc_mat_11_sp2 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2)); + + // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block + __m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); + __m512i iacc_mat_01 = _mm512_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2); + __m512i iacc_mat_10 = _mm512_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2); + __m512i iacc_mat_11 = _mm512_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2); + + + // Straighten out to make 4 row vectors + __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, 78)); + __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01); + __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, 78)); + __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11); + + // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes + const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptrs[rp][b].d), loadMask), 68); + const __m512 row_scale_f32 = LM_GGML_F32Cx16_REPEAT_LOAD(row_scale_f16); + + // Multiply with appropiate scales and accumulate + acc_rows[rp * 4] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]); + acc_rows[rp * 4 + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]); + acc_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]); + acc_rows[rp * 4 + 3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]); + } + } + + // Store the accumulated values + for (int i = 0; i < 16; i++) { + _mm512_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]); + } + } + } + // Take a block_q8_0x4 structures at each pass of the loop and perform dot product operation + for (; y < nr / 4; y ++) { + + const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb); + + // Take group of two block_q4_0x8 structures at each pass of the loop and perform dot product operation + for (int64_t x = 0; x < anc / 8; x += 2) { + + const block_q4_0x8 * b_ptr_0 = b_ptr_start + ((x) * b_nb); + const block_q4_0x8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb); + + // Master FP accumulators + __m512 acc_rows[4]; + for (int i = 0; i < 4; i++) { + acc_rows[i] = _mm512_setzero_ps(); + } + + for (int64_t b = 0; b < nb; b++) { + // Load the sixteen block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....BE,BF + const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs)); + const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 32)); + const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 64)); + const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 96)); + + const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs)); + const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 32)); + const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 64)); + const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 96)); + + // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of valuess + const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240); + const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240); + + const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240); + const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240); + + const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1); + const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1); + const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1); + const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1); + + // 4-bit -> 8-bit - Sign is maintained + const __m512i rhs_mat_014589CD_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) B8(0-7) B9(0-7) BC(0-7) BD(0-7) + const __m512i rhs_mat_2367ABEF_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) BA(0-7) BB(0-7) BE(0-7) BF(0-7) + + const __m512i rhs_mat_014589CD_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) B8(8-15) B9(8-15) BC(8-15) BD(8-15) + const __m512i rhs_mat_2367ABEF_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) BA(8-15) BB(8-15) BE(8-15) BF(8-15) + + const __m512i rhs_mat_014589CD_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) B8(16-23) B9(16-23) BC(16-23) BD(16-23) + const __m512i rhs_mat_2367ABEF_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) BA(16-23) BB(16-23) BE(16-23) BF(16-23) + + const __m512i rhs_mat_014589CD_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) B8(24-31) B9(24-31) BC(24-31) BD(24-31) + const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31) + + // Shuffle pattern one - right side input + const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3) + const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3) + + const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11) + const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11) + + const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19) + const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19) + + const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27) + const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27) + + // Shuffle pattern two - right side input + + const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7) + const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7) + + const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15) + const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15) + + const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23) + const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23) + + const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31) + const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31) + + + // Scale values - Load the weight scale values of two block_q4_0x8 + const __m512 col_scale_f32 = LM_GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d); + + // Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3 + // Loaded as set of 128 bit vectors and repeated and stored into a 256 bit vector before again repeating into 512 bit vector + __m256i lhs_mat_ymm_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs))); + __m256i lhs_mat_ymm_01_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 0); + __m256i lhs_mat_ymm_23_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 17); + __m256i lhs_mat_ymm_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32))); + __m256i lhs_mat_ymm_01_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 0); + __m256i lhs_mat_ymm_23_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 17); + __m256i lhs_mat_ymm_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64))); + __m256i lhs_mat_ymm_01_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 0); + __m256i lhs_mat_ymm_23_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 17); + __m256i lhs_mat_ymm_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96))); + __m256i lhs_mat_ymm_01_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 0); + __m256i lhs_mat_ymm_23_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 17); + + __m512i lhs_mat_01_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_0), lhs_mat_ymm_01_0, 1); + __m512i lhs_mat_23_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_0), lhs_mat_ymm_23_0, 1); + __m512i lhs_mat_01_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_1), lhs_mat_ymm_01_1, 1); + __m512i lhs_mat_23_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_1), lhs_mat_ymm_23_1, 1); + __m512i lhs_mat_01_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_2), lhs_mat_ymm_01_2, 1); + __m512i lhs_mat_23_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_2), lhs_mat_ymm_23_2, 1); + __m512i lhs_mat_01_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_3), lhs_mat_ymm_01_3, 1); + __m512i lhs_mat_23_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_3), lhs_mat_ymm_23_3, 1); + + // Shuffle pattern one - left side input + + const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) + const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) + + const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) + const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) + + const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) + const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) + + const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) + const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) + + // Shuffle pattern two - left side input + + const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) + const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) + + const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) + const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) + + const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) + const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) + + const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) + const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) + + // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane + // Resembles MMLAs into 2x2 matrices in ARM Version + __m512i iacc_mat_00_sp1 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1)); + __m512i iacc_mat_01_sp1 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1)); + __m512i iacc_mat_10_sp1 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1)); + __m512i iacc_mat_11_sp1 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1)); + __m512i iacc_mat_00_sp2 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2)); + __m512i iacc_mat_01_sp2 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2)); + __m512i iacc_mat_10_sp2 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2)); + __m512i iacc_mat_11_sp2 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2)); + + // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block + __m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); + __m512i iacc_mat_01 = _mm512_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2); + __m512i iacc_mat_10 = _mm512_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2); + __m512i iacc_mat_11 = _mm512_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2); + + + // Straighten out to make 4 row vectors + __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, 78)); + __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01); + __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, 78)); + __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11); + + // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes + const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptr[b].d), loadMask), 68); + const __m512 row_scale_f32 = LM_GGML_F32Cx16_REPEAT_LOAD(row_scale_f16); + + // Multiply with appropiate scales and accumulate + acc_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]); + acc_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]); + acc_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]); + acc_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]); + } + + // Store the accumulated values + for (int i = 0; i < 4; i++) { + _mm512_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]); + } + } + } + if (anc != nc) { + xstart = anc/8; + y = 0; + } + #endif // __AVX512F__ + + // Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation + + for (; y < anr / 4; y += 4) { + const block_q8_0x4 * a_ptrs[4]; + + a_ptrs[0] = a_ptr_start + (y * nb); + for (int i = 0; i < 3; ++i) { + a_ptrs[i + 1] = a_ptrs[i] + nb; + } + + // Take group of eight block_q4_0x8 structures at each pass of the loop and perform dot product operation + for (int64_t x = xstart; x < nc / 8; x++) { + + const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb); + + // Master FP accumulators + __m256 acc_rows[16]; + for (int i = 0; i < 16; i++) { + acc_rows[i] = _mm256_setzero_ps(); + } + + for (int64_t b = 0; b < nb; b++) { + // Load the eight block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7 + const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs)); + const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32)); + const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64)); + const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96)); + + // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values + const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240); + const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240); + + // 4-bit -> 8-bit - Sign is maintained + const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) + const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) + + const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) + const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) + + const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) + const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) + + const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) + const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) + + // Shuffle pattern one - right side input + const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) + const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) + + const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) + const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) + + const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) + const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) + + const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) + const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) + + // Shuffle pattern two - right side input + + const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) + const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) + + const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) + const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) + + const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) + const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) + + const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) + const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) + + // Scale values - Load the wight scale values of block_q4_0x8 + const __m256 col_scale_f32 = LM_GGML_F32Cx8_LOAD(b_ptr[b].d); + + // Process LHS in groups of four + for (int rp = 0; rp < 4; rp++) { + // Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3 + // Loaded as set of 128 bit vectors and repeated into a 256 bit vector + __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs))); + __m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0); + __m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17); + __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32))); + __m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0); + __m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17); + __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64))); + __m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0); + __m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17); + __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96))); + __m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0); + __m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17); + + // Shuffle pattern one - left side input + const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) + const __m256i lhs_mat_23_0_sp1 = _mm256_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) + + const __m256i lhs_mat_01_1_sp1 = _mm256_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) + const __m256i lhs_mat_23_1_sp1 = _mm256_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) + + const __m256i lhs_mat_01_2_sp1 = _mm256_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) + const __m256i lhs_mat_23_2_sp1 = _mm256_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) + + const __m256i lhs_mat_01_3_sp1 = _mm256_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) + const __m256i lhs_mat_23_3_sp1 = _mm256_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) + + // Shuffle pattern two - left side input + const __m256i lhs_mat_01_0_sp2 = _mm256_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) + const __m256i lhs_mat_23_0_sp2 = _mm256_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) + + const __m256i lhs_mat_01_1_sp2 = _mm256_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) + const __m256i lhs_mat_23_1_sp2 = _mm256_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) + + const __m256i lhs_mat_01_2_sp2 = _mm256_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) + const __m256i lhs_mat_23_2_sp2 = _mm256_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) + + const __m256i lhs_mat_01_3_sp2 = _mm256_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) + const __m256i lhs_mat_23_3_sp2 = _mm256_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) + + // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane + // Resembles MMLAs into 2x2 matrices in ARM Version + __m256i iacc_mat_00_sp1 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1)); + __m256i iacc_mat_01_sp1 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1)); + __m256i iacc_mat_10_sp1 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1)); + __m256i iacc_mat_11_sp1 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1)); + __m256i iacc_mat_00_sp2 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2)); + __m256i iacc_mat_01_sp2 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2)); + __m256i iacc_mat_10_sp2 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2)); + __m256i iacc_mat_11_sp2 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2)); + + // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block + __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); + __m256i iacc_mat_01 = _mm256_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2); + __m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2); + __m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2); + + // Straighten out to make 4 row vectors + __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204); + __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204); + __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204); + __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204); + + // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes + const __m256 row_scale_f32 = LM_GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask); + + // Multiply with appropiate scales and accumulate + acc_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]); + acc_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]); + acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]); + acc_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]); + } + } + + // Store the accumulated values + for (int i = 0; i < 16; i++) { + _mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]); + } + } + } + + // Take a block_q8_0x4 structures at each pass of the loop and perform dot product operation + for (; y < nr / 4; y ++) { + + const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb); + + // Load the eight block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7 + for (int64_t x = xstart; x < nc / 8; x++) { + + const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb); + + // Master FP accumulators + __m256 acc_rows[4]; + for (int i = 0; i < 4; i++) { + acc_rows[i] = _mm256_setzero_ps(); + } + + for (int64_t b = 0; b < nb; b++) { + // Load the eight block_q8_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7 + const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs)); + const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32)); + const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64)); + const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96)); + + // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of valuess + const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240); + const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240); + + // 4-bit -> 8-bit - Sign is maintained + const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) + const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) + + const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) + const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) + + const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) + const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) + + const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) + const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) + + // Shuffle pattern one - right side input + const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) + const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) + + const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) + const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) + + const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) + const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) + + const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) + const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) + + // Shuffle pattern two - right side input + + const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) + const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) + + const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) + const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) + + const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) + const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) + + const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) + const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) + + // Scale values - Load the wight scale values of block_q4_0x8 + const __m256 col_scale_f32 = LM_GGML_F32Cx8_LOAD(b_ptr[b].d); + + // Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3 + // Loaded as set of 128 bit vectors and repeated into a 256 bit vector + __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs))); + __m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0); + __m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17); + __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32))); + __m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0); + __m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17); + __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64))); + __m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0); + __m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17); + __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96))); + __m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0); + __m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17); + + // Shuffle pattern one - left side input + + const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) + const __m256i lhs_mat_23_0_sp1 = _mm256_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) + + const __m256i lhs_mat_01_1_sp1 = _mm256_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) + const __m256i lhs_mat_23_1_sp1 = _mm256_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) + + const __m256i lhs_mat_01_2_sp1 = _mm256_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) + const __m256i lhs_mat_23_2_sp1 = _mm256_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) + + const __m256i lhs_mat_01_3_sp1 = _mm256_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) + const __m256i lhs_mat_23_3_sp1 = _mm256_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) + + // Shuffle pattern two - left side input + + const __m256i lhs_mat_01_0_sp2 = _mm256_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) + const __m256i lhs_mat_23_0_sp2 = _mm256_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) + + const __m256i lhs_mat_01_1_sp2 = _mm256_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) + const __m256i lhs_mat_23_1_sp2 = _mm256_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) + + const __m256i lhs_mat_01_2_sp2 = _mm256_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) + const __m256i lhs_mat_23_2_sp2 = _mm256_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) + + const __m256i lhs_mat_01_3_sp2 = _mm256_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) + const __m256i lhs_mat_23_3_sp2 = _mm256_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) + + // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane + // Resembles MMLAs into 2x2 matrices in ARM Version + __m256i iacc_mat_00_sp1 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1)); + __m256i iacc_mat_01_sp1 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1)); + __m256i iacc_mat_10_sp1 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1)); + __m256i iacc_mat_11_sp1 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1)); + __m256i iacc_mat_00_sp2 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2)); + __m256i iacc_mat_01_sp2 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2)); + __m256i iacc_mat_10_sp2 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2)); + __m256i iacc_mat_11_sp2 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2)); + + // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block + __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); + __m256i iacc_mat_01 = _mm256_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2); + __m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2); + __m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2); + + + // Straighten out to make 4 row vectors + __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204); + __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204); + __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204); + __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204); + + // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes + const __m256 row_scale_f32 = LM_GGML_F32Cx8_REPEAT_LOAD(a_ptr[b].d, loadMask); + + // Multiply with appropiate scales and accumulate + acc_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]); + acc_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]); + acc_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]); + acc_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]); + } + + // Store the accumulated values + for (int i = 0; i < 4; i++) { + _mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]); + } + } + } + return; + } +#elif defined(__riscv_v_intrinsic) + if (__riscv_vlenb() >= QK4_0) { + const size_t vl = QK4_0; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); + vfloat32m1_t sumf0 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); + vfloat32m1_t sumf1 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); + vfloat32m1_t sumf2 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); + vfloat32m1_t sumf3 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); + for (int l = 0; l < nb; l++) { + const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4); + const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4); + const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4); + const vint8m2_t rhs_vec_lo_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 0); + const vint8m2_t rhs_vec_lo_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 1); + const vint8m2_t rhs_vec_hi_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 0); + const vint8m2_t rhs_vec_hi_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 1); + + // vector version needs Zvfhmin extension + const float a_scales[4] = { + LM_GGML_FP16_TO_FP32(a_ptr[l].d[0]), + LM_GGML_FP16_TO_FP32(a_ptr[l].d[1]), + LM_GGML_FP16_TO_FP32(a_ptr[l].d[2]), + LM_GGML_FP16_TO_FP32(a_ptr[l].d[3]) + }; + const float b_scales[8] = { + LM_GGML_FP16_TO_FP32(b_ptr[l].d[0]), + LM_GGML_FP16_TO_FP32(b_ptr[l].d[1]), + LM_GGML_FP16_TO_FP32(b_ptr[l].d[2]), + LM_GGML_FP16_TO_FP32(b_ptr[l].d[3]), + LM_GGML_FP16_TO_FP32(b_ptr[l].d[4]), + LM_GGML_FP16_TO_FP32(b_ptr[l].d[5]), + LM_GGML_FP16_TO_FP32(b_ptr[l].d[6]), + LM_GGML_FP16_TO_FP32(b_ptr[l].d[7]) + }; + const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4); + + const int64_t A0 = *(const int64_t *)&a_ptr[l].qs[0]; + const int64_t A4 = *(const int64_t *)&a_ptr[l].qs[32]; + const int64_t A8 = *(const int64_t *)&a_ptr[l].qs[64]; + const int64_t Ac = *(const int64_t *)&a_ptr[l].qs[96]; + __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment + vint16m4_t sumi_l0; + { + const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A0, vl / 4)); + const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A4, vl / 4)); + const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A8, vl / 4)); + const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ac, vl / 4)); + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); + const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); + + sumi_l0 = sumi_hi_m; + } + + { + const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l0)); + const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); + const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); + const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); + const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); + const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); + const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); + const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); + const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); + const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); + const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); + const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); + + const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[0], vl / 4); + sumf0 = __riscv_vfmacc_vv_f32m1(sumf0, tmp1, b_scales_vec, vl / 4); + } + + const int64_t A1 = *(const int64_t *)&a_ptr[l].qs[8]; + const int64_t A5 = *(const int64_t *)&a_ptr[l].qs[40]; + const int64_t A9 = *(const int64_t *)&a_ptr[l].qs[72]; + const int64_t Ad = *(const int64_t *)&a_ptr[l].qs[104]; + __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment + vint16m4_t sumi_l1; + { + const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A1, vl / 4)); + const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A5, vl / 4)); + const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A9, vl / 4)); + const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ad, vl / 4)); + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); + const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); + + sumi_l1 = sumi_hi_m; + } + + { + const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l1)); + const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); + const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); + const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); + const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); + const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); + const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); + const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); + const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); + const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); + const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); + const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); + + const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[1], vl / 4); + sumf1 = __riscv_vfmacc_vv_f32m1(sumf1, tmp1, b_scales_vec, vl / 4); + } + + const int64_t A2 = *(const int64_t *)&a_ptr[l].qs[16]; + const int64_t A6 = *(const int64_t *)&a_ptr[l].qs[48]; + const int64_t Aa = *(const int64_t *)&a_ptr[l].qs[80]; + const int64_t Ae = *(const int64_t *)&a_ptr[l].qs[112]; + __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment + vint16m4_t sumi_l2; + { + const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A2, vl / 4)); + const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A6, vl / 4)); + const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Aa, vl / 4)); + const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ae, vl / 4)); + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); + const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); + + sumi_l2 = sumi_hi_m; + } + + { + const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l2)); + const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); + const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); + const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); + const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); + const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); + const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); + const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); + const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); + const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); + const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); + const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); + + const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[2], vl / 4); + sumf2 = __riscv_vfmacc_vv_f32m1(sumf2, tmp1, b_scales_vec, vl / 4); + } + + const int64_t A3 = *(const int64_t *)&a_ptr[l].qs[24]; + const int64_t A7 = *(const int64_t *)&a_ptr[l].qs[56]; + const int64_t Ab = *(const int64_t *)&a_ptr[l].qs[88]; + const int64_t Af = *(const int64_t *)&a_ptr[l].qs[120]; + __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment + vint16m4_t sumi_l3; + { + const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A3, vl / 4)); + const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A7, vl / 4)); + const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ab, vl / 4)); + const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Af, vl / 4)); + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); + const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); + + sumi_l3 = sumi_hi_m; + } + + { + const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l3)); + const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); + const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); + const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); + const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); + const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); + const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); + const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); + const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); + const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); + const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); + const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); + + const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[3], vl / 4); + sumf3 = __riscv_vfmacc_vv_f32m1(sumf3, tmp1, b_scales_vec, vl / 4); + } + } + __riscv_vse32_v_f32m1(&s[(y * 4 + 0) * bs + x * ncols_interleaved], sumf0, vl / 4); + __riscv_vse32_v_f32m1(&s[(y * 4 + 1) * bs + x * ncols_interleaved], sumf1, vl / 4); + __riscv_vse32_v_f32m1(&s[(y * 4 + 2) * bs + x * ncols_interleaved], sumf2, vl / 4); + __riscv_vse32_v_f32m1(&s[(y * 4 + 3) * bs + x * ncols_interleaved], sumf3, vl / 4); + } + } + + return; + } +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) + float sumf[4][8]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + } + sumf[m][j] += sumi * LM_GGML_FP16_TO_FP32(b_ptr[l].d[j]) * LM_GGML_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +} + +// FIXME: this code is duplicated from ggml-aarch64.c +static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) { + block_q4_0x4 out; + + for (int i = 0; i < 4; i++) { + out.d[i] = in[i].d; + } + + const int end = QK4_0 * 2 / blck_size_interleave; + + if (blck_size_interleave == 8) { + const uint64_t xor_mask = 0x8888888888888888ULL; + for (int i = 0; i < end; ++i) { + int src_id = i % 4; + int src_offset = (i / 4) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + uint64_t elems; + // Using memcpy to avoid unaligned memory accesses + memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t)); + elems ^= xor_mask; + memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t)); + } + } else if (blck_size_interleave == 4) { + const uint32_t xor_mask = 0x88888888; + for (int i = 0; i < end; ++i) { + int src_id = i % 4; + int src_offset = (i / 4) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + uint32_t elems; + memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint32_t)); + elems ^= xor_mask; + memcpy(&out.qs[dst_offset], &elems, sizeof(uint32_t)); + } + } else { + LM_GGML_ASSERT(false); + } + + return out; +} + +// interleave 8 block_q4_0s in blocks of blck_size_interleave +// returns an interleaved block_q4_0x8 +// in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks +// first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave +static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave) { + block_q4_0x8 out; + + for (int i = 0; i < 8; i++) { + out.d[i] = in[i].d; + } + + const int end = QK4_0 * 4 / blck_size_interleave; + const uint64_t xor_mask = 0x8888888888888888ULL; + + for (int i = 0; i < end; ++i) { + int src_id = i % 8; + int src_offset = (i / 8) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + uint64_t elems; + memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t)); + elems ^= xor_mask; + memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t)); + } + + return out; +} + +static int repack_q4_0_to_q4_0_4_bl(struct lm_ggml_tensor * t, int interleave_block, const void * restrict data, size_t data_size) { + LM_GGML_ASSERT(t->type == LM_GGML_TYPE_Q4_0); + LM_GGML_ASSERT(interleave_block == 4 || interleave_block == 8); + + block_q4_0x4 * dst = (block_q4_0x4 *)t->data; + const block_q4_0 * src = (const block_q4_0 *)data; + block_q4_0 dst_tmp[4]; + int nrow = t->ne[1]; // Number of rows + int nrows_interleaved = 4; + int nblocks = t->ne[0] / QK4_0; + + LM_GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + + if (nrow % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_0x4(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + LM_GGML_UNUSED(data_size); +} + +static int repack_q4_0_to_q4_0_8_bl(struct lm_ggml_tensor *t, int interleave_block, const void * restrict data, size_t data_size) { + LM_GGML_ASSERT(t->type == LM_GGML_TYPE_Q4_0); + LM_GGML_ASSERT(interleave_block == 8); + + block_q4_0x8 * dst = (block_q4_0x8*)t->data; + const block_q4_0 * src = (const block_q4_0*) data; + block_q4_0 dst_tmp[8]; + int nrow = t->ne[1]; // Number of rows + int nrows_interleaved = 8; + int nblocks = t->ne[0] / QK4_0; + + LM_GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + + if (nrow % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++ ) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_0x8(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + LM_GGML_UNUSED(data_size); +} + +// Prepare for optimized kernels if applicable +void lm_ggml_aarch64_repack_tensor(struct lm_ggml_tensor * cur, enum lm_ggml_type repack_type, const void * restrict data, size_t data_size) { + if (cur->type == repack_type) { + memcpy(cur->data, data, data_size); + return; + } + + LM_GGML_ASSERT(cur->type == LM_GGML_TYPE_Q4_0); + + switch (repack_type) { + case LM_GGML_TYPE_Q4_0_8_8: + repack_q4_0_to_q4_0_8_bl(cur, 8, data, data_size); + break; + case LM_GGML_TYPE_Q4_0_4_8: + repack_q4_0_to_q4_0_4_bl(cur, 8, data, data_size); + break; + case LM_GGML_TYPE_Q4_0_4_4: + repack_q4_0_to_q4_0_4_bl(cur, 4, data, data_size); + break; + default: + LM_GGML_ABORT("Unsupported type"); + } +} + +enum lm_ggml_type lm_ggml_aarch64_get_optimal_repack_type(const struct lm_ggml_tensor * cur) { + if (cur->type == LM_GGML_TYPE_Q4_0) { + // TODO: enable for AVX2 - currently disabled due to bad gemv performance + if (/* lm_ggml_cpu_has_avx2() || */ (lm_ggml_cpu_has_sve() && lm_ggml_cpu_has_matmul_int8() && lm_ggml_cpu_get_sve_cnt() == QK8_0)) { + return LM_GGML_TYPE_Q4_0_8_8; + } + if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_matmul_int8()) { + return LM_GGML_TYPE_Q4_0_4_8; + } + if (lm_ggml_cpu_has_neon()) { + return LM_GGML_TYPE_Q4_0_4_4; + } + } + + return cur->type; +} diff --git a/cpp/ggml-cpu-aarch64.h b/cpp/ggml-cpu-aarch64.h new file mode 100644 index 00000000..dd7b10fc --- /dev/null +++ b/cpp/ggml-cpu-aarch64.h @@ -0,0 +1,30 @@ +#pragma once + +#include "ggml.h" + +// GGML internal header + +#ifdef __cplusplus +extern "C" { +#endif + +// Quantization +void quantize_mat_q8_0(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t nrows, int64_t n_per_row, int64_t blck_size_interleave); + +// GEMV +void lm_ggml_gemv_q4_0_4x4_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc); +void lm_ggml_gemv_q4_0_4x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc); +void lm_ggml_gemv_q4_0_8x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc); + +// GEMM +void lm_ggml_gemm_q4_0_4x4_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc); +void lm_ggml_gemm_q4_0_4x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc); +void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc); + +void lm_ggml_aarch64_repack_tensor(struct lm_ggml_tensor * cur, enum lm_ggml_type repack_type, const void * data, size_t data_size); +enum lm_ggml_type lm_ggml_aarch64_get_optimal_repack_type(const struct lm_ggml_tensor * cur); + +#ifdef __cplusplus +} +#endif + diff --git a/cpp/ggml-cpu-impl.h b/cpp/ggml-cpu-impl.h index 760deb2f..47f0cad5 100644 --- a/cpp/ggml-cpu-impl.h +++ b/cpp/ggml-cpu-impl.h @@ -27,80 +27,6 @@ extern "C" { #endif -/** - * Converts brain16 to float32. - * - * The bfloat16 floating point format has the following structure: - * - * ┌sign - * │ - * │ ┌exponent - * │ │ - * │ │ ┌mantissa - * │ │ │ - * │┌──┴───┐┌─┴───┐ - * 0b0000000000000000 brain16 - * - * Since bf16 has the same number of exponent bits as a 32bit float, - * encoding and decoding numbers becomes relatively straightforward. - * - * ┌sign - * │ - * │ ┌exponent - * │ │ - * │ │ ┌mantissa - * │ │ │ - * │┌──┴───┐┌─┴───────────────────┐ - * 0b00000000000000000000000000000000 IEEE binary32 - * - * For comparison, the standard fp16 format has fewer exponent bits. - * - * ┌sign - * │ - * │ ┌exponent - * │ │ - * │ │ ┌mantissa - * │ │ │ - * │┌─┴─┐┌─┴──────┐ - * 0b0000000000000000 IEEE binary16 - * - * @see IEEE 754-2008 - */ -static inline float lm_ggml_compute_bf16_to_fp32(lm_ggml_bf16_t h) { - union { - float f; - uint32_t i; - } u; - u.i = (uint32_t)h.bits << 16; - return u.f; -} - -/** - * Converts float32 to brain16. - * - * This is binary identical with Google Brain float conversion. - * Floats shall round to nearest even, and NANs shall be quiet. - * Subnormals aren't flushed to zero, except perhaps when used. - * This code should vectorize nicely if using modern compilers. - */ -static inline lm_ggml_bf16_t lm_ggml_compute_fp32_to_bf16(float s) { - lm_ggml_bf16_t h; - union { - float f; - uint32_t i; - } u; - u.f = s; - if ((u.i & 0x7fffffff) > 0x7f800000) { /* nan */ - h.bits = (u.i >> 16) | 64; /* force to quiet */ - return h; - } - h.bits = (u.i + (0x7fff + ((u.i >> 16) & 1))) >> 16; - return h; -} - -#define LM_GGML_FP32_TO_BF16(x) lm_ggml_compute_fp32_to_bf16(x) -#define LM_GGML_BF16_TO_FP32(x) lm_ggml_compute_bf16_to_fp32(x) - // __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512 #if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)) #ifndef __FMA__ @@ -388,28 +314,6 @@ inline static int32x4_t lm_ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t #endif // defined(__ARM_NEON) -#if defined(__ARM_NEON) && !defined(_MSC_VER) - -#define LM_GGML_COMPUTE_FP16_TO_FP32(x) lm_ggml_compute_fp16_to_fp32(x) -#define LM_GGML_COMPUTE_FP32_TO_FP16(x) lm_ggml_compute_fp32_to_fp16(x) - -#define LM_GGML_FP16_TO_FP32(x) lm_ggml_compute_fp16_to_fp32(x) - -static inline float lm_ggml_compute_fp16_to_fp32(lm_ggml_fp16_t h) { - lm_ggml_fp16_internal_t tmp; - memcpy(&tmp, &h, sizeof(lm_ggml_fp16_t)); - return (float)tmp; -} - -static inline lm_ggml_fp16_t lm_ggml_compute_fp32_to_fp16(float f) { - lm_ggml_fp16_t res; - lm_ggml_fp16_internal_t tmp = f; - memcpy(&res, &tmp, sizeof(lm_ggml_fp16_t)); - return res; -} - -#else - #ifdef __wasm_simd128__ #include #else @@ -462,153 +366,6 @@ static __m256 __lasx_xvreplfr2vr_s(float val) { } #endif -#ifdef __F16C__ - -#ifdef _MSC_VER -#define LM_GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x))) -#define LM_GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0) -#else -#define LM_GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x) -#define LM_GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0) -#endif - -#elif defined(__POWER9_VECTOR__) - -#define LM_GGML_COMPUTE_FP16_TO_FP32(x) lm_ggml_compute_fp16_to_fp32(x) -#define LM_GGML_COMPUTE_FP32_TO_FP16(x) lm_ggml_compute_fp32_to_fp16(x) -/* the inline asm below is about 12% faster than the lookup method */ -#define LM_GGML_FP16_TO_FP32(x) LM_GGML_COMPUTE_FP16_TO_FP32(x) -#define LM_GGML_FP32_TO_FP16(x) LM_GGML_COMPUTE_FP32_TO_FP16(x) - -static inline float lm_ggml_compute_fp16_to_fp32(lm_ggml_fp16_t h) { - register float f; - register double d; - __asm__( - "mtfprd %0,%2\n" - "xscvhpdp %0,%0\n" - "frsp %1,%0\n" : - /* temp */ "=d"(d), - /* out */ "=f"(f): - /* in */ "r"(h)); - return f; -} - -static inline lm_ggml_fp16_t lm_ggml_compute_fp32_to_fp16(float f) { - register double d; - register lm_ggml_fp16_t r; - __asm__( /* xscvdphp can work on double or single precision */ - "xscvdphp %0,%2\n" - "mffprd %1,%0\n" : - /* temp */ "=d"(d), - /* out */ "=r"(r): - /* in */ "f"(f)); - return r; -} - -#else - -// FP16 <-> FP32 -// ref: https://github.com/Maratyszcza/FP16 - -static inline float fp32_from_bits(uint32_t w) { - union { - uint32_t as_bits; - float as_value; - } fp32; - fp32.as_bits = w; - return fp32.as_value; -} - -static inline uint32_t fp32_to_bits(float f) { - union { - float as_value; - uint32_t as_bits; - } fp32; - fp32.as_value = f; - return fp32.as_bits; -} - -static inline float lm_ggml_compute_fp16_to_fp32(lm_ggml_fp16_t h) { - const uint32_t w = (uint32_t) h << 16; - const uint32_t sign = w & UINT32_C(0x80000000); - const uint32_t two_w = w + w; - - const uint32_t exp_offset = UINT32_C(0xE0) << 23; -#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) - const float exp_scale = 0x1.0p-112f; -#else - const float exp_scale = fp32_from_bits(UINT32_C(0x7800000)); -#endif - const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; - - const uint32_t magic_mask = UINT32_C(126) << 23; - const float magic_bias = 0.5f; - const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; - - const uint32_t denormalized_cutoff = UINT32_C(1) << 27; - const uint32_t result = sign | - (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value)); - return fp32_from_bits(result); -} - -static inline lm_ggml_fp16_t lm_ggml_compute_fp32_to_fp16(float f) { -#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) - const float scale_to_inf = 0x1.0p+112f; - const float scale_to_zero = 0x1.0p-110f; -#else - const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000)); - const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000)); -#endif - float base = (fabsf(f) * scale_to_inf) * scale_to_zero; - - const uint32_t w = fp32_to_bits(f); - const uint32_t shl1_w = w + w; - const uint32_t sign = w & UINT32_C(0x80000000); - uint32_t bias = shl1_w & UINT32_C(0xFF000000); - if (bias < UINT32_C(0x71000000)) { - bias = UINT32_C(0x71000000); - } - - base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; - const uint32_t bits = fp32_to_bits(base); - const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); - const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); - const uint32_t nonsign = exp_bits + mantissa_bits; - return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign); -} - -#define LM_GGML_COMPUTE_FP16_TO_FP32(x) lm_ggml_compute_fp16_to_fp32(x) -#define LM_GGML_COMPUTE_FP32_TO_FP16(x) lm_ggml_compute_fp32_to_fp16(x) - -#endif // __F16C__ - -#endif // defined(__ARM_NEON) && (!defined(__MSC_VER) - -#ifdef __ARM_FEATURE_SVE -#include -#endif // __ARM_FEATURE_SVE - -// precomputed f32 table for f16 (256 KB) -// defined in ggml.c, initialized in lm_ggml_init() -extern float lm_ggml_table_f32_f16[1 << 16]; - -// On ARM NEON, it's quicker to directly convert x -> x instead of calling into lm_ggml_lookup_fp16_to_fp32, -// so we define LM_GGML_FP16_TO_FP32 and LM_GGML_FP32_TO_FP16 elsewhere for NEON. -// This is also true for POWER9. -#if !defined(LM_GGML_FP16_TO_FP32) -inline static float lm_ggml_lookup_fp16_to_fp32(lm_ggml_fp16_t f) { - uint16_t s; - memcpy(&s, &f, sizeof(uint16_t)); - return lm_ggml_table_f32_f16[s]; -} - -#define LM_GGML_FP16_TO_FP32(x) lm_ggml_lookup_fp16_to_fp32(x) -#endif - -#if !defined(LM_GGML_FP32_TO_FP16) -#define LM_GGML_FP32_TO_FP16(x) LM_GGML_COMPUTE_FP32_TO_FP16(x) -#endif - #ifdef __cplusplus } #endif diff --git a/cpp/ggml-cpu-quants.c b/cpp/ggml-cpu-quants.c new file mode 100644 index 00000000..3a2cfe34 --- /dev/null +++ b/cpp/ggml-cpu-quants.c @@ -0,0 +1,10822 @@ +#define LM_GGML_COMMON_IMPL_C +#include "ggml-common.h" + +#include "ggml-quants.h" +#include "ggml-cpu-quants.h" +#include "ggml-impl.h" +#include "ggml-cpu-impl.h" +#include "ggml-cpu.h" + +#include +#include +#include +#include +#include // for qsort +#include // for LM_GGML_ASSERT + +#define GROUP_MAX_EPS 1e-15f +#define GROUP_MAX_EPS_IQ3_XXS 1e-8f +#define GROUP_MAX_EPS_IQ2_S 1e-8f +#define GROUP_MAX_EPS_IQ1_M 1e-7f +#define GROUP_MAX_EPS_IQ1_S 1e-12f + +#if defined(_MSC_VER) +// disable "possible loss of data" to avoid warnings for hundreds of casts +// we should just be careful :) +#pragma warning(disable: 4244 4267) +#endif + +#define UNUSED LM_GGML_UNUSED + +// some compilers don't provide _mm256_set_m128i, e.g. gcc 7 +#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) + +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) +// multiply int8_t, add results pairwise twice +static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) { + // Get absolute values of x vectors + const __m128i ax = _mm_sign_epi8(x, x); + // Sign the values of the y vectors + const __m128i sy = _mm_sign_epi8(y, x); + // Perform multiplication and create 16-bit values + const __m128i dot = _mm_maddubs_epi16(ax, sy); + const __m128i ones = _mm_set1_epi16(1); + return _mm_madd_epi16(ones, dot); +} + +#if __AVX__ || __AVX2__ || __AVX512F__ +// horizontally add 8 floats +static inline float hsum_float_8(const __m256 x) { + __m128 res = _mm256_extractf128_ps(x, 1); + res = _mm_add_ps(res, _mm256_castps256_ps128(x)); + res = _mm_add_ps(res, _mm_movehl_ps(res, res)); + res = _mm_add_ss(res, _mm_movehdup_ps(res)); + return _mm_cvtss_f32(res); +} + +// horizontally add 8 int32_t +static inline int hsum_i32_8(const __m256i a) { + const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1)); + const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); + const __m128i sum64 = _mm_add_epi32(hi64, sum128); + const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); +} + +// horizontally add 4 int32_t +static inline int hsum_i32_4(const __m128i a) { + const __m128i hi64 = _mm_unpackhi_epi64(a, a); + const __m128i sum64 = _mm_add_epi32(hi64, a); + const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); +} + +#if defined(__AVX2__) || defined(__AVX512F__) +// spread 32 bits to 32 bytes { 0x00, 0xFF } +static inline __m256i bytes_from_bits_32(const uint8_t * x) { + uint32_t x32; + memcpy(&x32, x, sizeof(uint32_t)); + const __m256i shuf_mask = _mm256_set_epi64x( + 0x0303030303030303, 0x0202020202020202, + 0x0101010101010101, 0x0000000000000000); + __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask); + const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe); + bytes = _mm256_or_si256(bytes, bit_mask); + return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1)); +} + +// Unpack 32 4-bit fields into 32 bytes +// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval +static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) +{ + const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi); + const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp); + const __m256i lowMask = _mm256_set1_epi8( 0xF ); + return _mm256_and_si256(lowMask, bytes); +} + +// add int16_t pairwise and return as float vector +static inline __m256 sum_i16_pairs_float(const __m256i x) { + const __m256i ones = _mm256_set1_epi16(1); + const __m256i summed_pairs = _mm256_madd_epi16(ones, x); + return _mm256_cvtepi32_ps(summed_pairs); +} + +static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { +#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) + const __m256i zero = _mm256_setzero_si256(); + const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy); + return _mm256_cvtepi32_ps(summed_pairs); +#else + // Perform multiplication and create 16-bit values + const __m256i dot = _mm256_maddubs_epi16(ax, sy); + return sum_i16_pairs_float(dot); +#endif +} + +// multiply int8_t, add results pairwise twice and return as float vector +static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { +#if __AVXVNNIINT8__ + const __m256i zero = _mm256_setzero_si256(); + const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y); + return _mm256_cvtepi32_ps(summed_pairs); +#else + // Get absolute values of x vectors + const __m256i ax = _mm256_sign_epi8(x, x); + // Sign the values of the y vectors + const __m256i sy = _mm256_sign_epi8(y, x); + return mul_sum_us8_pairs_float(ax, sy); +#endif +} + +static inline __m128i packNibbles( __m256i bytes ) +{ + // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh +#if __AVX512F__ + const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000 + bytes = _mm256_or_si256(bytes, bytes_srli_4); // 0000_abcd_abcd_efgh + return _mm256_cvtepi16_epi8(bytes); // abcd_efgh +#else + const __m256i lowByte = _mm256_set1_epi16( 0xFF ); + __m256i high = _mm256_andnot_si256( lowByte, bytes ); + __m256i low = _mm256_and_si256( lowByte, bytes ); + high = _mm256_srli_epi16( high, 4 ); + bytes = _mm256_or_si256( low, high ); + + // Compress uint16_t lanes into bytes + __m128i r0 = _mm256_castsi256_si128( bytes ); + __m128i r1 = _mm256_extracti128_si256( bytes, 1 ); + return _mm_packus_epi16( r0, r1 ); +#endif +} +#elif defined(__AVX__) +static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 ) +{ + // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh + const __m128i lowByte = _mm_set1_epi16( 0xFF ); + __m128i high = _mm_andnot_si128( lowByte, bytes1 ); + __m128i low = _mm_and_si128( lowByte, bytes1 ); + high = _mm_srli_epi16( high, 4 ); + bytes1 = _mm_or_si128( low, high ); + high = _mm_andnot_si128( lowByte, bytes2 ); + low = _mm_and_si128( lowByte, bytes2 ); + high = _mm_srli_epi16( high, 4 ); + bytes2 = _mm_or_si128( low, high ); + + return _mm_packus_epi16( bytes1, bytes2); +} + +static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) { + const __m128i ax = _mm_sign_epi8(x, x); + const __m128i sy = _mm_sign_epi8(y, x); + return _mm_maddubs_epi16(ax, sy); +} + +// spread 32 bits to 32 bytes { 0x00, 0xFF } +static inline __m256i bytes_from_bits_32(const uint8_t * x) { + uint32_t x32; + memcpy(&x32, x, sizeof(uint32_t)); + const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000); + const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202); + __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl); + __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh); + const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe); + bytesl = _mm_or_si128(bytesl, bit_mask); + bytesh = _mm_or_si128(bytesh, bit_mask); + bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1)); + bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1)); + return MM256_SET_M128I(bytesh, bytesl); +} + +// Unpack 32 4-bit fields into 32 bytes +// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval +static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) +{ + // Load 16 bytes from memory + __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi); + __m128i tmph = _mm_srli_epi16(tmpl, 4); + const __m128i lowMask = _mm_set1_epi8(0xF); + tmpl = _mm_and_si128(lowMask, tmpl); + tmph = _mm_and_si128(lowMask, tmph); + return MM256_SET_M128I(tmph, tmpl); +} + +// add int16_t pairwise and return as float vector +static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) { + const __m128i ones = _mm_set1_epi16(1); + const __m128i summed_pairsl = _mm_madd_epi16(ones, xl); + const __m128i summed_pairsh = _mm_madd_epi16(ones, xh); + const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl); + return _mm256_cvtepi32_ps(summed_pairs); +} + +static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { + const __m128i axl = _mm256_castsi256_si128(ax); + const __m128i axh = _mm256_extractf128_si256(ax, 1); + const __m128i syl = _mm256_castsi256_si128(sy); + const __m128i syh = _mm256_extractf128_si256(sy, 1); + // Perform multiplication and create 16-bit values + const __m128i dotl = _mm_maddubs_epi16(axl, syl); + const __m128i doth = _mm_maddubs_epi16(axh, syh); + return sum_i16_pairs_float(doth, dotl); +} + +// multiply int8_t, add results pairwise twice and return as float vector +static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { + const __m128i xl = _mm256_castsi256_si128(x); + const __m128i xh = _mm256_extractf128_si256(x, 1); + const __m128i yl = _mm256_castsi256_si128(y); + const __m128i yh = _mm256_extractf128_si256(y, 1); + // Get absolute values of x vectors + const __m128i axl = _mm_sign_epi8(xl, xl); + const __m128i axh = _mm_sign_epi8(xh, xh); + // Sign the values of the y vectors + const __m128i syl = _mm_sign_epi8(yl, xl); + const __m128i syh = _mm_sign_epi8(yh, xh); + // Perform multiplication and create 16-bit values + const __m128i dotl = _mm_maddubs_epi16(axl, syl); + const __m128i doth = _mm_maddubs_epi16(axh, syh); + return sum_i16_pairs_float(doth, dotl); +} + +// larger version of mul_sum_i8_pairs_float where x and y are each represented by four 128-bit vectors +static inline __m256 mul_sum_i8_quad_float(const __m128i x_1_0, const __m128i x_1_1, const __m128i x_2_0, const __m128i x_2_1, + const __m128i y_1_0, const __m128i y_1_1, const __m128i y_2_0, const __m128i y_2_1) { + const __m128i mone = _mm_set1_epi16(1); + + const __m128i p16_1_0 = mul_add_epi8_sse(x_1_0, y_1_0); + const __m128i p16_1_1 = mul_add_epi8_sse(x_1_1, y_1_1); + const __m128i p16_2_0 = mul_add_epi8_sse(x_2_0, y_2_0); + const __m128i p16_2_1 = mul_add_epi8_sse(x_2_1, y_2_1); + const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, mone); + const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, mone); + const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, mone); + const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, mone); + const __m128i p_1 = _mm_add_epi32(p_1_0, p_1_1); + const __m128i p_2 = _mm_add_epi32(p_2_0, p_2_1); + return _mm256_cvtepi32_ps(MM256_SET_M128I(p_2, p_1)); +} + +// quad fp16 delta calculation +static inline __m256 quad_fp16_delta_float(const float x0, const float y0, const float x1, const float y1) { + // LM_GGML_FP16_TO_FP32 is faster than Intel F16C + return _mm256_set_m128(_mm_set1_ps(LM_GGML_FP16_TO_FP32(x1) * LM_GGML_FP16_TO_FP32(y1)), + _mm_set1_ps(LM_GGML_FP16_TO_FP32(x0) * LM_GGML_FP16_TO_FP32(y0))); +} +#endif +#elif defined(__SSSE3__) +// horizontally add 4x4 floats +static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) { + __m128 res_0 =_mm_hadd_ps(a, b); + __m128 res_1 =_mm_hadd_ps(c, d); + __m128 res =_mm_hadd_ps(res_0, res_1); + res =_mm_hadd_ps(res, res); + res =_mm_hadd_ps(res, res); + + return _mm_cvtss_f32(res); +} +#endif // __AVX__ || __AVX2__ || __AVX512F__ +#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) + +#if defined(__ARM_NEON) || defined(__wasm_simd128__) || defined(__POWER9_VECTOR__) +#define B1(c,s,n) 0x ## n ## c , 0x ## n ## s +#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s) +#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s) +#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s) +#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s) +#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s) +#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s) +#define B8(c,s ) B7(c,s, c), B7(c,s, s) + +// precomputed tables for expanding 8bits to 8 bytes: +static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4 +static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4 +#endif + +#if defined(__loongarch_asx) + +#ifdef __clang__ +#define VREGS_PREFIX "$vr" +#define XREGS_PREFIX "$xr" +#else // GCC +#define VREGS_PREFIX "$f" +#define XREGS_PREFIX "$f" +#endif +#define __ALL_REGS "0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31" +// Convert __m128i to __m256i +static inline __m256i ____m256i(__m128i in) { + __m256i out = __lasx_xvldi(0); + __asm__ volatile ( + ".irp i," __ALL_REGS "\n\t" + " .ifc %[out], " XREGS_PREFIX"\\i \n\t" + " .irp j," __ALL_REGS "\n\t" + " .ifc %[in], " VREGS_PREFIX "\\j \n\t" + " xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t" + " .endif \n\t" + " .endr \n\t" + " .endif \n\t" + ".endr \n\t" + : [out] "+f" (out) : [in] "f" (in) + ); + return out; +} +// Convert two __m128i to __m256i +static inline __m256i lasx_set_q(__m128i inhi, __m128i inlo) { + __m256i out; + __asm__ volatile ( + ".irp i," __ALL_REGS "\n\t" + " .ifc %[hi], " VREGS_PREFIX "\\i \n\t" + " .irp j," __ALL_REGS "\n\t" + " .ifc %[lo], " VREGS_PREFIX "\\j \n\t" + " xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t" + " .endif \n\t" + " .endr \n\t" + " .endif \n\t" + ".endr \n\t" + ".ifnc %[out], %[hi] \n\t" + ".irp i," __ALL_REGS "\n\t" + " .ifc %[out], " XREGS_PREFIX "\\i \n\t" + " .irp j," __ALL_REGS "\n\t" + " .ifc %[hi], " VREGS_PREFIX "\\j \n\t" + " xvori.b $xr\\i, $xr\\j, 0 \n\t" + " .endif \n\t" + " .endr \n\t" + " .endif \n\t" + ".endr \n\t" + ".endif \n\t" + : [out] "=f" (out), [hi] "+f" (inhi) + : [lo] "f" (inlo) + ); + return out; +} +// Convert __m256i low part to __m128i +static inline __m128i lasx_extracti128_lo(__m256i in) { + __m128i out; + __asm__ volatile ( + ".ifnc %[out], %[in] \n\t" + ".irp i," __ALL_REGS "\n\t" + " .ifc %[out], " VREGS_PREFIX "\\i \n\t" + " .irp j," __ALL_REGS "\n\t" + " .ifc %[in], " XREGS_PREFIX "\\j \n\t" + " vori.b $vr\\i, $vr\\j, 0 \n\t" + " .endif \n\t" + " .endr \n\t" + " .endif \n\t" + ".endr \n\t" + ".endif \n\t" + : [out] "=f" (out) : [in] "f" (in) + ); + return out; +} +// Convert __m256i high part to __m128i +static inline __m128i lasx_extracti128_hi(__m256i in) { + __m128i out; + __asm__ volatile ( + ".irp i," __ALL_REGS "\n\t" + " .ifc %[out], " VREGS_PREFIX "\\i \n\t" + " .irp j," __ALL_REGS "\n\t" + " .ifc %[in], " XREGS_PREFIX "\\j \n\t" + " xvpermi.q $xr\\i, $xr\\j, 0x11 \n\t" + " .endif \n\t" + " .endr \n\t" + " .endif \n\t" + ".endr \n\t" + : [out] "=f" (out) : [in] "f" (in) + ); + return out; +} + +static __m256i lasx_set_w(int e7, int e6, int e5, int e4, int e3, int e2, int e1, int e0) { + v8i32 __ret = {e0, e1, e2, e3, e4, e5, e6, e7}; + return (__m256i)__ret; +} + +static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) { + v4i32 __ret = {d, c, b, a}; + return (__m128i)__ret; +} + +static __m256i lasx_set_d(int64_t a, int64_t b, int64_t c, int64_t d) { + v4i64 __ret = {d, c, b, a}; + return (__m256i)__ret; +} + +static __m256i lasx_insertf128( __m128i x, __m128i y) { + return lasx_set_q(x, y); +} + +static __m128i lsx_shuffle_b(__m128i a, __m128i b) { + __m128i mask_f, zero, tmp0, tmp2, mask; + int f = 0x8f; + mask_f = __lsx_vreplgr2vr_b(f); + zero = __lsx_vldi(0); + tmp0 = __lsx_vand_v(b, mask_f); // get mask with low 4 bit and sign bits + tmp0 = __lsx_vori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive + mask = __lsx_vsle_b(zero, tmp0); // if mask >= 0, set mask + tmp2 = __lsx_vand_v(tmp0, mask); // maskout the in2 < ones + return __lsx_vshuf_b(a, zero, tmp2); +} + +static __m256i lasx_shuffle_b(__m256i a, __m256i b) { + __m256i mask_f, zero, tmp0, tmp2, mask; + int f = 0x8f; + mask_f = __lasx_xvreplgr2vr_b(f); + zero = __lasx_xvldi(0); + tmp0 = __lasx_xvand_v(b, mask_f); // get mask with low 4 bit and sign bits + tmp0 = __lasx_xvori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive + mask = __lasx_xvsle_b(zero, tmp0); // if mask >= 0, set mask + tmp2 = __lasx_xvand_v(tmp0, mask); // maskout the in2 < ones + return __lasx_xvshuf_b(a, zero, tmp2); +} + +static __m256i lasx_extu8_16(__m128i a) { + __m128i zero = __lsx_vldi(0); + __m128i vlo = __lsx_vilvl_b(zero, a); + __m128i vhi = __lsx_vilvh_b(zero, a); + return lasx_set_q(vhi, vlo); +} + +static __m256i lasx_ext8_16(__m128i a) { + __m128i sign = __lsx_vslti_b(a, 0); + __m128i vlo = __lsx_vilvl_b(sign, a); + __m128i vhi = __lsx_vilvh_b(sign, a); + return lasx_set_q(vhi, vlo); +} + +static __m256i lasx_ext16_32(__m128i a) { + __m256i tmp1; + tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 0), 0); + tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 1), 1); + tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 2), 2); + tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 3), 3); + tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 4), 4); + tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 5), 5); + tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 6), 6); + tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 7), 7); + return tmp1; +} + +static __m128i lasx_extracti128( __m256i a, int pos) { + __m128i ret; + if( pos == 0) + { + ret = lasx_extracti128_lo(a); + } else { + ret = lasx_extracti128_hi(a); + } + return ret; +} + +static __m128 lasx_extractf128( __m256 a, int pos) { + __m128 ret; + if( pos == 0) + { + ret = (__m128)lasx_extracti128_lo((__m256i)a); + } else { + ret = (__m128)lasx_extracti128_hi((__m256i)a); + } + return ret; +} + +static __m128i lsx_hadd_h(__m128i a, __m128i b) { + __m128i tmp1 = __lsx_vpickev_h(b, a); + __m128i tmp2 = __lsx_vpickod_h(b, a); + return __lsx_vadd_h(tmp1, tmp2); +} + +static __m128i lsx_hadd_w(__m128i a, __m128i b) { + __m128i tmp1 = __lsx_vpickev_w(b, a); + __m128i tmp2 = __lsx_vpickod_w(b, a); + return __lsx_vadd_w(tmp1, tmp2); +} + +static __m128 lsx_hadd_s(__m128 a, __m128 b) { + __m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a); + __m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a); + + return __lsx_vfadd_s(tmp1, tmp2); +} + +static __m256i lasx_maddubs_h(__m256i a, __m256i b) { + __m256i tmp1, tmp2; + tmp1 = __lasx_xvmulwev_h_b(a, b); + tmp2 = __lasx_xvmulwod_h_b(a, b); + return __lasx_xvsadd_h(tmp1, tmp2); +} + +static __m256i lasx_madd_h(__m256i a, __m256i b) { + __m256i tmp1, tmp2; + tmp1 = __lasx_xvmulwev_w_h(a, b); + tmp2 = __lasx_xvmulwod_w_h(a, b); + return __lasx_xvadd_w(tmp1, tmp2); +} + +static __m256i lasx_packs_w(__m256i a, __m256i b) { + __m256i tmp, tmp1; + tmp = __lasx_xvsat_w(a, 15); + tmp1 = __lasx_xvsat_w(b, 15); + return __lasx_xvpickev_h(tmp1, tmp); +} + +static __m256i lasx_packs_h(__m256i a, __m256i b) { + __m256i tmp, tmp1; + tmp = __lasx_xvsat_h(a, 7); + tmp1 = __lasx_xvsat_h(b, 7); + return __lasx_xvpickev_b(tmp1, tmp); +} + +static __m128i lsx_packs_w(__m128i a, __m128i b) { + __m128i tmp, tmp1; + tmp = __lsx_vsat_w(a, 15); + tmp1 = __lsx_vsat_w(b, 15); + return __lsx_vpickev_h(tmp1, tmp); +} + +static __m128i lsx_packs_h(__m128i a, __m128i b) { + __m128i tmp, tmp1; + tmp = __lsx_vsat_h(a, 7); + tmp1 = __lsx_vsat_h(b, 7); + return __lsx_vpickev_b(tmp1, tmp); +} + +static __m128i lsx_packus_h(__m128i a, __m128i b) { + __m128i tmp, tmp1; + tmp = __lsx_vsat_hu(a, 7); + tmp1 = __lsx_vsat_hu(b, 7); + return __lsx_vpickev_b(tmp1, tmp); +} + + +static __m128i lsx_maddubs_h(__m128i a, __m128i b) { + __m128i tmp1, tmp2; + tmp1 = __lsx_vmulwev_h_b(a, b); + tmp2 = __lsx_vmulwod_h_b(a, b); + return __lsx_vsadd_h(tmp1, tmp2); +} + +static __m128i lsx_madd_h(__m128i a, __m128i b) { + __m128i tmp1, tmp2; + tmp1 = __lsx_vmulwev_w_h(a, b); + tmp2 = __lsx_vmulwod_w_h(a, b); + return __lsx_vadd_w(tmp1, tmp2); +} + +// multiply int8_t, add results pairwise twice +static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) { + // Get absolute values of x vectors + const __m128i ax = __lsx_vsigncov_b(x, x); + // Sign the values of the y vectors + const __m128i sy = __lsx_vsigncov_b(x, y); + // Perform multiplication and create 16-bit values + const __m128i dot = lsx_maddubs_h(ax, sy); + const __m128i ones = __lsx_vreplgr2vr_h(1); + return lsx_madd_h(ones, dot); +} + +// horizontally add 8 floats +static inline float hsum_float_8(const __m256 x) { + __m128 res = lasx_extractf128(x, 1); + ft_union tmp; + res = __lsx_vfadd_s(res, lasx_extractf128(x, 0)); + res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res)); + res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0)); + tmp.i = __lsx_vpickve2gr_w(res, 0); + return tmp.f; +} + +// horizontally add 8 int32_t +static inline int hsum_i32_8(const __m256i a) { + + __m256i tmp1 = __lasx_xvpermi_q(a, a, 0x11); + __m256i tmp2 = __lasx_xvpermi_q(a, a, 0x00); + + __m128i tmp1_128 = lasx_extracti128_lo(tmp1); + __m128i tmp2_128 = lasx_extracti128_lo(tmp2); + + __m128i sum128 = __lsx_vadd_w(tmp1_128, tmp2_128); + + __m128i ev = __lsx_vpickev_w(sum128, sum128); + __m128i od = __lsx_vpickod_w(sum128, sum128); + __m128i sum64 = __lsx_vadd_w(ev, od); + + int sum64_1, sum64_2; + sum64_1 = __lsx_vpickve2gr_w(sum64, 0); + sum64_2 = __lsx_vpickve2gr_w(sum64, 1); + + return sum64_1 + sum64_2; +} + +// horizontally add 4 int32_t +static inline int hsum_i32_4(const __m128i a) { + __m128i ev = __lsx_vpickev_w(a, a); + __m128i od = __lsx_vpickod_w(a, a); + __m128i sum64 = __lsx_vadd_w(ev, od); + + int sum64_1, sum64_2; + sum64_1 = __lsx_vpickve2gr_w(sum64, 0); + sum64_2 = __lsx_vpickve2gr_w(sum64, 1); + + return sum64_1 + sum64_2; +} + +// spread 32 bits to 32 bytes { 0x00, 0xFF } +static inline __m256i bytes_from_bits_32(const uint8_t * x) { + + uint32_t x32; + memcpy(&x32, x, sizeof(uint32_t)); + const __m256i shuf_mask = lasx_set_d( + 0x0303030303030303, 0x0202020202020202, + 0x0101010101010101, 0x0000000000000000); + + __m256i bytes = lasx_shuffle_b(__lasx_xvreplgr2vr_w(x32), shuf_mask); + const __m256i bit_mask = __lasx_xvreplgr2vr_d(0x7fbfdfeff7fbfdfe); + bytes = __lasx_xvor_v(bytes, bit_mask); + return __lasx_xvseq_b(bytes, __lasx_xvreplgr2vr_d(-1)); +} + +// Unpack 32 4-bit fields into 32 bytes +// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval +static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) { + const __m128i lo = __lsx_vld((const __m128i *)rsi, 0); + __m128i hi = __lsx_vsrli_h(lo, 4); + return __lasx_xvandi_b(lasx_insertf128(hi, lo), 0xf); +} + +// add int16_t pairwise and return as float vector +static inline __m256 sum_i16_pairs_float(const __m256i x) { + __m256i v = __lasx_xvpackod_h(x, x); + __m256i summed_pairs = __lasx_xvaddwev_w_h(x, v); + return __lasx_xvffint_s_w(summed_pairs); +} + +static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { + // Perform multiplication and create 16-bit values + const __m256i dot = lasx_maddubs_h(ax, sy); + return sum_i16_pairs_float(dot); +} + +// multiply int8_t, add results pairwise twice and return as float vector +static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { + + // Get absolute values of x vectors + const __m256i ax = __lasx_xvsigncov_b(x, x); + // Sign the values of the y vectors + const __m256i sy = __lasx_xvsigncov_b(x, y); + + return mul_sum_us8_pairs_float(ax, sy); +} + +static inline __m128i packNibbles( __m256i bytes ) { + // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh + const __m256i lowByte = __lasx_xvreplgr2vr_h(0xFF); + __m256i high = __lasx_xvandn_v(lowByte, bytes); + __m256i low = __lasx_xvand_v(lowByte, bytes); + high = __lasx_xvsrli_h(high, 4); + bytes = __lasx_xvor_v(low, high); + // Compress uint16_t lanes into bytes + __m128i *r0 = (__m128i *)&bytes; + __m256i tmp_h128 = __lasx_xvpermi_q(bytes, bytes, 0x11); + __m128i *r1 = (__m128i *)&tmp_h128; + + __m128i zero = __lsx_vldi(0); + __m128i tmp, tmp2, tmp3; + + tmp = __lsx_vmax_h(zero, *r0); + tmp2 = __lsx_vsat_hu(tmp, 7); + + tmp = __lsx_vmax_h(zero, *r1); + tmp3 = __lsx_vsat_hu(tmp, 7); + return __lsx_vpickev_b(tmp3, tmp2); +} +#endif //__loongarch_asx + +void quantize_row_q4_0(const float * restrict x, void * restrict y, int64_t k) { + quantize_row_q4_0_ref(x, y, k); +} + +void quantize_row_q4_1(const float * restrict x, void * restrict y, int64_t k) { + quantize_row_q4_1_ref(x, y, k); +} + +void quantize_row_q5_0(const float * restrict x, void * restrict y, int64_t k) { + quantize_row_q5_0_ref(x, y, k); +} + +void quantize_row_q5_1(const float * restrict x, void * restrict y, int64_t k) { + quantize_row_q5_1_ref(x, y, k); +} + +void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0 * restrict y = vy; + +#if defined(__ARM_NEON) + for (int i = 0; i < nb; i++) { + float32x4_t srcv [8]; + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = LM_GGML_FP32_TO_FP16(d); + + for (int j = 0; j < 8; j++) { + const float32x4_t v = vmulq_n_f32(srcv[j], id); + const int32x4_t vi = vcvtnq_s32_f32(v); + + y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); + } + } +#elif defined(__wasm_simd128__) + for (int i = 0; i < nb; i++) { + v128_t srcv [8]; + v128_t asrcv[8]; + v128_t amaxv[8]; + + for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]); + + const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0), + wasm_f32x4_extract_lane(amaxv[0], 1)), + MAX(wasm_f32x4_extract_lane(amaxv[0], 2), + wasm_f32x4_extract_lane(amaxv[0], 3))); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = LM_GGML_FP32_TO_FP16(d); + + for (int j = 0; j < 8; j++) { + const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id)); + const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v); + + y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0); + y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1); + y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2); + y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3); + } + } +#elif defined(__AVX2__) || defined(__AVX__) + for (int i = 0; i < nb; i++) { + // Load elements into 4 AVX vectors + __m256 v0 = _mm256_loadu_ps( x ); + __m256 v1 = _mm256_loadu_ps( x + 8 ); + __m256 v2 = _mm256_loadu_ps( x + 16 ); + __m256 v3 = _mm256_loadu_ps( x + 24 ); + x += 32; + + // Compute max(abs(e)) for the block + const __m256 signBit = _mm256_set1_ps( -0.0f ); + __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); + + __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); + max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); + const float maxScalar = _mm_cvtss_f32( max4 ); + + // Quantize these floats + const float d = maxScalar / 127.f; + y[i].d = LM_GGML_FP32_TO_FP16(d); + const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; + const __m256 mul = _mm256_set1_ps( id ); + + // Apply the multiplier + v0 = _mm256_mul_ps( v0, mul ); + v1 = _mm256_mul_ps( v1, mul ); + v2 = _mm256_mul_ps( v2, mul ); + v3 = _mm256_mul_ps( v3, mul ); + + // Round to nearest integer + v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); + v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); + v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); + v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); + + // Convert floats to integers + __m256i i0 = _mm256_cvtps_epi32( v0 ); + __m256i i1 = _mm256_cvtps_epi32( v1 ); + __m256i i2 = _mm256_cvtps_epi32( v2 ); + __m256i i3 = _mm256_cvtps_epi32( v3 ); + +#if defined(__AVX2__) + // Convert int32 to int16 + i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + // Convert int16 to int8 + i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + + // We got our precious signed bytes, but the order is now wrong + // These AVX2 pack instructions process 16-byte pieces independently + // The following instruction is fixing the order + const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + + _mm256_storeu_si256((__m256i *)y[i].qs, i0); +#else + // Since we don't have in AVX some necessary functions, + // we split the registers in half and call AVX2 analogs from SSE + __m128i ni0 = _mm256_castsi256_si128( i0 ); + __m128i ni1 = _mm256_extractf128_si256( i0, 1); + __m128i ni2 = _mm256_castsi256_si128( i1 ); + __m128i ni3 = _mm256_extractf128_si256( i1, 1); + __m128i ni4 = _mm256_castsi256_si128( i2 ); + __m128i ni5 = _mm256_extractf128_si256( i2, 1); + __m128i ni6 = _mm256_castsi256_si128( i3 ); + __m128i ni7 = _mm256_extractf128_si256( i3, 1); + + // Convert int32 to int16 + ni0 = _mm_packs_epi32( ni0, ni1 ); + ni2 = _mm_packs_epi32( ni2, ni3 ); + ni4 = _mm_packs_epi32( ni4, ni5 ); + ni6 = _mm_packs_epi32( ni6, ni7 ); + // Convert int16 to int8 + ni0 = _mm_packs_epi16( ni0, ni2 ); + ni4 = _mm_packs_epi16( ni4, ni6 ); + + _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0); + _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4); +#endif + } +#elif defined(__riscv_v_intrinsic) + + size_t vl = __riscv_vsetvl_e32m4(QK8_0); + + for (int i = 0; i < nb; i++) { + // load elements + vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_0, vl); + + vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl); + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl); + float amax = __riscv_vfmv_f_s_f32m1_f32(vmax); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = LM_GGML_FP32_TO_FP16(d); + + vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl); + + // convert to integer + vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl); + vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl); + + // store result + __riscv_vse8_v_i8m1(y[i].qs , vs, vl); + } + +#elif defined(__POWER9_VECTOR__) + for (int i = 0; i < nb; i++) { + vector float srcv [8]; + vector float asrcv[8]; + vector float amaxv[8]; + vector signed int vi[8]; + + for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]); + + const float amax = MAX(MAX(vec_extract(amaxv[0], 0), + vec_extract(amaxv[0], 1)), + MAX(vec_extract(amaxv[0], 2), + vec_extract(amaxv[0], 3))); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + const vector float vid = vec_splats(id); + + y[i].d = LM_GGML_FP32_TO_FP16(d); + + for (int j = 0; j < 8; j++) { + const vector float v = vec_round(vec_mul(srcv[j], vid)); + vi[j] = vec_cts(v, 0); + } + vec_xst(vec_pack(vec_pack(vi[0], vi[1]), vec_pack(vi[2], vi[3])), 0, &y[i].qs[0]); + vec_xst(vec_pack(vec_pack(vi[4], vi[5]), vec_pack(vi[6], vi[7])), 16, &y[i].qs[0]); + } + +#elif defined(__loongarch_asx) + for (int i = 0; i < nb; i++) { + ft_union fi; + __m256 v0 = (__m256)__lasx_xvld( x , 0); + __m256 v1 = (__m256)__lasx_xvld( x , 32); + __m256 v2 = (__m256)__lasx_xvld( x , 64); + __m256 v3 = (__m256)__lasx_xvld( x , 96); + x += 32; + + // Compute max(abs(e)) for the block + const __m256 sign_bit = __lasx_xvreplfr2vr_s( -0.0f ); + __m256 max_abs = (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v0 ); + max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v1 ) ); + max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v2 ) ); + max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v3 ) ); + + __m128 max4 = __lsx_vfmax_s( lasx_extractf128( max_abs, 1 ), lasx_extractf128( max_abs , 0) ); + max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) ); + __m128 tmp = max4; + max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vinsgr2vr_w(tmp, __lsx_vpickve2gr_w( max4, 1 ), 0 )); + fi.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 ); + const float max_scalar = fi.f; + + // Quantize these floats + const float d = max_scalar / 127.f; + y[i].d = LM_GGML_FP32_TO_FP16(d); + const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f; + const __m256 mul = (__m256)__lasx_xvreplfr2vr_s( id ); + + // Apply the multiplier + v0 = __lasx_xvfmul_s( v0, mul ); + v1 = __lasx_xvfmul_s( v1, mul ); + v2 = __lasx_xvfmul_s( v2, mul ); + v3 = __lasx_xvfmul_s( v3, mul ); + + // Round to nearest integer + __m256i i0 = __lasx_xvftintrne_w_s( v0 ); + __m256i i1 = __lasx_xvftintrne_w_s( v1 ); + __m256i i2 = __lasx_xvftintrne_w_s( v2 ); + __m256i i3 = __lasx_xvftintrne_w_s( v3 ); + + __m128i ni0 = lasx_extracti128( i0, 0 ); + __m128i ni1 = lasx_extracti128( i0, 1); + __m128i ni2 = lasx_extracti128( i1, 0); + __m128i ni3 = lasx_extracti128( i1, 1); + __m128i ni4 = lasx_extracti128( i2, 0); + __m128i ni5 = lasx_extracti128( i2, 1); + __m128i ni6 = lasx_extracti128( i3, 0); + __m128i ni7 = lasx_extracti128( i3, 1); + + // Convert int32 to int16 + ni0 = lsx_packs_w( ni0, ni1 ); + ni2 = lsx_packs_w( ni2, ni3 ); + ni4 = lsx_packs_w( ni4, ni5 ); + ni6 = lsx_packs_w( ni6, ni7 ); + // Convert int16 to int8 + ni0 = lsx_packs_h( ni0, ni2 ); + ni4 = lsx_packs_h( ni4, ni6 ); + + __lsx_vst(ni0, (__m128i *)(y[i].qs + 0), 0); + __lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0); + + } +#else + LM_GGML_UNUSED(nb); + // scalar + quantize_row_q8_0_ref(x, y, k); +#endif +} + +void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) { + assert(k % QK8_1 == 0); + const int nb = k / QK8_1; + + block_q8_1 * restrict y = vy; + +#if defined(__ARM_NEON) + for (int i = 0; i < nb; i++) { + float32x4_t srcv [8]; + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = LM_GGML_FP32_TO_FP16(d); + + int32x4_t accv = vdupq_n_s32(0); + + for (int j = 0; j < 8; j++) { + const float32x4_t v = vmulq_n_f32(srcv[j], id); + const int32x4_t vi = vcvtnq_s32_f32(v); + + y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); + + accv = vaddq_s32(accv, vi); + } + + y[i].s = LM_GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); + } +#elif defined(__wasm_simd128__) + for (int i = 0; i < nb; i++) { + v128_t srcv [8]; + v128_t asrcv[8]; + v128_t amaxv[8]; + + for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]); + + const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0), + wasm_f32x4_extract_lane(amaxv[0], 1)), + MAX(wasm_f32x4_extract_lane(amaxv[0], 2), + wasm_f32x4_extract_lane(amaxv[0], 3))); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = LM_GGML_FP32_TO_FP16(d); + + v128_t accv = wasm_i32x4_splat(0); + + for (int j = 0; j < 8; j++) { + const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id)); + const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v); + + y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0); + y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1); + y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2); + y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3); + + accv = wasm_i32x4_add(accv, vi); + } + + y[i].s = LM_GGML_FP32_TO_FP16( + d * (wasm_i32x4_extract_lane(accv, 0) + + wasm_i32x4_extract_lane(accv, 1) + + wasm_i32x4_extract_lane(accv, 2) + + wasm_i32x4_extract_lane(accv, 3))); + } +#elif defined(__AVX2__) || defined(__AVX__) + for (int i = 0; i < nb; i++) { + // Load elements into 4 AVX vectors + __m256 v0 = _mm256_loadu_ps( x ); + __m256 v1 = _mm256_loadu_ps( x + 8 ); + __m256 v2 = _mm256_loadu_ps( x + 16 ); + __m256 v3 = _mm256_loadu_ps( x + 24 ); + x += 32; + + // Compute max(abs(e)) for the block + const __m256 signBit = _mm256_set1_ps( -0.0f ); + __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); + + __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); + max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); + const float max_scalar = _mm_cvtss_f32( max4 ); + + // Quantize these floats + const float d = max_scalar / 127.f; + y[i].d = LM_GGML_FP32_TO_FP16(d); + const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f; + const __m256 mul = _mm256_set1_ps( id ); + + // Apply the multiplier + v0 = _mm256_mul_ps( v0, mul ); + v1 = _mm256_mul_ps( v1, mul ); + v2 = _mm256_mul_ps( v2, mul ); + v3 = _mm256_mul_ps( v3, mul ); + + // Round to nearest integer + v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); + v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); + v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); + v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); + + // Convert floats to integers + __m256i i0 = _mm256_cvtps_epi32( v0 ); + __m256i i1 = _mm256_cvtps_epi32( v1 ); + __m256i i2 = _mm256_cvtps_epi32( v2 ); + __m256i i3 = _mm256_cvtps_epi32( v3 ); + +#if defined(__AVX2__) + // Compute the sum of the quants and set y[i].s + y[i].s = LM_GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)))); + + // Convert int32 to int16 + i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + // Convert int16 to int8 + i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + + // We got our precious signed bytes, but the order is now wrong + // These AVX2 pack instructions process 16-byte pieces independently + // The following instruction is fixing the order + const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + + _mm256_storeu_si256((__m256i *)y[i].qs, i0); +#else + // Since we don't have in AVX some necessary functions, + // we split the registers in half and call AVX2 analogs from SSE + __m128i ni0 = _mm256_castsi256_si128( i0 ); + __m128i ni1 = _mm256_extractf128_si256( i0, 1); + __m128i ni2 = _mm256_castsi256_si128( i1 ); + __m128i ni3 = _mm256_extractf128_si256( i1, 1); + __m128i ni4 = _mm256_castsi256_si128( i2 ); + __m128i ni5 = _mm256_extractf128_si256( i2, 1); + __m128i ni6 = _mm256_castsi256_si128( i3 ); + __m128i ni7 = _mm256_extractf128_si256( i3, 1); + + // Compute the sum of the quants and set y[i].s + const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3)); + const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7)); + y[i].s = LM_GGML_FP32_TO_FP16(d * hsum_i32_4(_mm_add_epi32(s0, s1))); + + // Convert int32 to int16 + ni0 = _mm_packs_epi32( ni0, ni1 ); + ni2 = _mm_packs_epi32( ni2, ni3 ); + ni4 = _mm_packs_epi32( ni4, ni5 ); + ni6 = _mm_packs_epi32( ni6, ni7 ); + // Convert int16 to int8 + ni0 = _mm_packs_epi16( ni0, ni2 ); + ni4 = _mm_packs_epi16( ni4, ni6 ); + + _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0); + _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4); +#endif + } +#elif defined(__riscv_v_intrinsic) + + size_t vl = __riscv_vsetvl_e32m4(QK8_1); + + for (int i = 0; i < nb; i++) { + // load elements + vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_1, vl); + + vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl); + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0, vl); + vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl); + float amax = __riscv_vfmv_f_s_f32m1_f32(vmax); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = LM_GGML_FP32_TO_FP16(d); + + vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl); + + // convert to integer + vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl); + vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl); + + // store result + __riscv_vse8_v_i8m1(y[i].qs , vs, vl); + + // compute sum for y[i].s + vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t vwrs = __riscv_vwredsum_vs_i8m1_i16m1(vs, tmp2, vl); + + // set y[i].s + int sum = __riscv_vmv_x_s_i16m1_i16(vwrs); + y[i].s = LM_GGML_FP32_TO_FP16(sum*d); + } + +#elif defined(__POWER9_VECTOR__) + for (int i = 0; i < nb; i++) { + vector float srcv [8]; + vector float asrcv[8]; + vector float amaxv[8]; + vector signed int vi[8]; + + for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]); + + const float amax = MAX(MAX(vec_extract(amaxv[0], 0), + vec_extract(amaxv[0], 1)), + MAX(vec_extract(amaxv[0], 2), + vec_extract(amaxv[0], 3))); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + const vector float vid = vec_splats(id); + + y[i].d = LM_GGML_FP32_TO_FP16(d); + + vector int accv = vec_splats(0); + + for (int j = 0; j < 8; j++) { + const vector float v = vec_round(vec_mul(srcv[j], vid)); + vi[j] = vec_cts(v, 0); + + accv = vec_add(accv, vi[j]); + } + vec_xst(vec_pack(vec_pack(vi[0], vi[1]), vec_pack(vi[2], vi[3])), 0, &y[i].qs[0]); + vec_xst(vec_pack(vec_pack(vi[4], vi[5]), vec_pack(vi[6], vi[7])), 16, &y[i].qs[0]); + + accv = vec_add(accv, vec_sld(accv, accv, 4)); + accv = vec_add(accv, vec_sld(accv, accv, 8)); + y[i].s = LM_GGML_FP32_TO_FP16(d * vec_extract(accv, 0)); + } + +#elif defined(__loongarch_asx) + for (int i = 0; i < nb; i++) { + ft_union ft; + __m256 v0 = (__m256)__lasx_xvld( x , 0 ); + __m256 v1 = (__m256)__lasx_xvld( x , 32 ); + __m256 v2 = (__m256)__lasx_xvld( x , 64 ); + __m256 v3 = (__m256)__lasx_xvld( x , 96 ); + x += 32; + + // Compute max(abs(e)) for the block + const __m256 sign_bit = __lasx_xvreplfr2vr_s( -0.0f ); + __m256 max_abs = (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v0 ); + max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v1 ) ); + max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v2 ) ); + max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v3 ) ); + + __m128 max4 = __lsx_vfmax_s( lasx_extractf128( max_abs, 1 ), lasx_extractf128( max_abs, 0) ); + max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) ); + __m128 tmp = max4; + max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x10 )); + ft.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 ); + const float max_scalar = ft.f; + + // Quantize these floats + const float d = max_scalar / 127.f; + y[i].d = LM_GGML_FP32_TO_FP16(d); + const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f; + const __m256 mul = __lasx_xvreplfr2vr_s( id ); + + // Apply the multiplier + v0 = __lasx_xvfmul_s( v0, mul ); + v1 = __lasx_xvfmul_s( v1, mul ); + v2 = __lasx_xvfmul_s( v2, mul ); + v3 = __lasx_xvfmul_s( v3, mul ); + + // Round to nearest integer + __m256i i0 = __lasx_xvftintrne_w_s( v0 ); + __m256i i1 = __lasx_xvftintrne_w_s( v1 ); + __m256i i2 = __lasx_xvftintrne_w_s( v2 ); + __m256i i3 = __lasx_xvftintrne_w_s( v3 ); + + __m128i ni0 = lasx_extracti128(i0, 0); + __m128i ni1 = lasx_extracti128( i0, 1); + __m128i ni2 = lasx_extracti128( i1, 0); + __m128i ni3 = lasx_extracti128( i1, 1); + __m128i ni4 = lasx_extracti128( i2, 0 ); + __m128i ni5 = lasx_extracti128( i2, 1); + __m128i ni6 = lasx_extracti128( i3, 0); + __m128i ni7 = lasx_extracti128( i3, 1); + + // Compute the sum of the quants and set y[i].s + const __m128i s0 = __lsx_vadd_w(__lsx_vadd_w(ni0, ni1), __lsx_vadd_w(ni2, ni3)); + const __m128i s1 = __lsx_vadd_w(__lsx_vadd_w(ni4, ni5), __lsx_vadd_w(ni6, ni7)); + y[i].s = LM_GGML_FP32_TO_FP16(d * hsum_i32_4(__lsx_vadd_w(s0, s1))); + + // Convert int32 to int16 + ni0 = lsx_packs_w( ni0, ni1 ); + ni2 = lsx_packs_w( ni2, ni3 ); + ni4 = lsx_packs_w( ni4, ni5 ); + ni6 = lsx_packs_w( ni6, ni7 ); + // Convert int16 to int8 + ni0 = lsx_packs_h( ni0, ni2 ); + ni4 = lsx_packs_h( ni4, ni6 ); + + __lsx_vst(ni0, (__m128i *)(y[i].qs + 0), 0); + __lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0); + } +#else + LM_GGML_UNUSED(nb); + // scalar + quantize_row_q8_1_ref(x, y, k); +#endif +} + +// +// 2-6 bit quantization in super-blocks +// + +// +// ===================== Helper functions +// +static inline int nearest_int(float fval) { + assert(fabsf(fval) <= 4194303.f); + float val = fval + 12582912.f; + int i; memcpy(&i, &val, sizeof(int)); + return (i & 0x007fffff) - 0x00400000; +} + +static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type, + const float * restrict qw) { + float max = 0; + float amax = 0; + for (int i = 0; i < n; ++i) { + float ax = fabsf(x[i]); + if (ax > amax) { amax = ax; max = x[i]; } + } + if (amax < GROUP_MAX_EPS) { // all zero + for (int i = 0; i < n; ++i) { + L[i] = 0; + } + return 0.f; + } + float iscale = -nmax / max; + if (rmse_type == 0) { + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + L[i] = nmax + MAX(-nmax, MIN(nmax-1, l)); + } + return 1/iscale; + } + bool return_early = false; + if (rmse_type < 0) { + rmse_type = -rmse_type; + return_early = true; + } + float sumlx = 0; + float suml2 = 0; +#ifdef HAVE_BUGGY_APPLE_LINKER + // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7 + for (volatile int i = 0; i < n; ++i) { +#else + for (int i = 0; i < n; ++i) { +#endif + int l = nearest_int(iscale * x[i]); + l = MAX(-nmax, MIN(nmax-1, l)); + L[i] = l + nmax; + float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i])); + sumlx += w*x[i]*l; + suml2 += w*l*l; + } + float scale = suml2 ? sumlx/suml2 : 0.0f; + if (return_early) return suml2 > 0 ? 0.5f*(scale + 1/iscale) : 1/iscale; + float best = scale * sumlx; + for (int is = -9; is <= 9; ++is) { + if (is == 0) { + continue; + } + iscale = -(nmax + 0.1f*is) / max; + sumlx = suml2 = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + l = MAX(-nmax, MIN(nmax-1, l)); + float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i])); + sumlx += w*x[i]*l; + suml2 += w*l*l; + } + if (suml2 > 0 && sumlx*sumlx > best*suml2) { + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + L[i] = nmax + MAX(-nmax, MIN(nmax-1, l)); + } + scale = sumlx/suml2; best = scale*sumlx; + } + } + return scale; +} + +static float make_q3_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, bool do_rmse) { + float max = 0; + float amax = 0; + for (int i = 0; i < n; ++i) { + float ax = fabsf(x[i]); + if (ax > amax) { amax = ax; max = x[i]; } + } + if (amax < GROUP_MAX_EPS) { // all zero + for (int i = 0; i < n; ++i) { L[i] = 0; } + return 0.f; + } + float iscale = -nmax / max; + if (do_rmse) { + float sumlx = 0; + float suml2 = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + l = MAX(-nmax, MIN(nmax-1, l)); + L[i] = l; + float w = x[i]*x[i]; + sumlx += w*x[i]*l; + suml2 += w*l*l; + } + for (int itry = 0; itry < 5; ++itry) { + int n_changed = 0; + for (int i = 0; i < n; ++i) { + float w = x[i]*x[i]; + float slx = sumlx - w*x[i]*L[i]; + if (slx > 0) { + float sl2 = suml2 - w*L[i]*L[i]; + int new_l = nearest_int(x[i] * sl2 / slx); + new_l = MAX(-nmax, MIN(nmax-1, new_l)); + if (new_l != L[i]) { + slx += w*x[i]*new_l; + sl2 += w*new_l*new_l; + if (sl2 > 0 && slx*slx*suml2 > sumlx*sumlx*sl2) { + L[i] = new_l; sumlx = slx; suml2 = sl2; + ++n_changed; + } + } + } + } + if (!n_changed) { + break; + } + } + for (int i = 0; i < n; ++i) { + L[i] += nmax; + } + return sumlx / suml2; + } + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + l = MAX(-nmax, MIN(nmax-1, l)); + L[i] = l + nmax; + } + return 1/iscale; +} + +static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min, + int ntry, float alpha) { + float min = x[0]; + float max = x[0]; + for (int i = 1; i < n; ++i) { + if (x[i] < min) min = x[i]; + if (x[i] > max) max = x[i]; + } + if (max == min) { + for (int i = 0; i < n; ++i) L[i] = 0; + *the_min = 0; + return 0.f; + } + if (min > 0) min = 0; + float iscale = nmax/(max - min); + float scale = 1/iscale; + for (int itry = 0; itry < ntry; ++itry) { + float sumlx = 0; int suml2 = 0; + bool did_change = false; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale*(x[i] - min)); + l = MAX(0, MIN(nmax, l)); + if (l != L[i]) { + L[i] = l; + did_change = true; + } + sumlx += (x[i] - min)*l; + suml2 += l*l; + } + scale = sumlx/suml2; + float sum = 0; + for (int i = 0; i < n; ++i) { + sum += x[i] - scale*L[i]; + } + min = alpha*min + (1 - alpha)*sum/n; + if (min > 0) min = 0; + iscale = 1/scale; + if (!did_change) break; + } + *the_min = -min; + return scale; +} + +static float make_qkx2_quants(int n, int nmax, const float * restrict x, const float * restrict weights, + uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux, + float rmin, float rdelta, int nstep, bool use_mad) { + float min = x[0]; + float max = x[0]; + float sum_w = weights[0]; + float sum_x = sum_w * x[0]; +#ifdef HAVE_BUGGY_APPLE_LINKER + // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7 + for (volatile int i = 1; i < n; ++i) { +#else + for (int i = 1; i < n; ++i) { +#endif + if (x[i] < min) min = x[i]; + if (x[i] > max) max = x[i]; + float w = weights[i]; + sum_w += w; + sum_x += w * x[i]; + } + if (min > 0) min = 0; + if (max == min) { + for (int i = 0; i < n; ++i) L[i] = 0; + *the_min = -min; + return 0.f; + } + float iscale = nmax/(max - min); + float scale = 1/iscale; + float best_mad = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale*(x[i] - min)); + L[i] = MAX(0, MIN(nmax, l)); + float diff = scale * L[i] + min - x[i]; + diff = use_mad ? fabsf(diff) : diff * diff; + float w = weights[i]; + best_mad += w * diff; + } + if (nstep < 1) { + *the_min = -min; + return scale; + } + for (int is = 0; is <= nstep; ++is) { + iscale = (rmin + rdelta*is + nmax)/(max - min); + float sum_l = 0, sum_l2 = 0, sum_xl = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale*(x[i] - min)); + l = MAX(0, MIN(nmax, l)); + Laux[i] = l; + float w = weights[i]; + sum_l += w*l; + sum_l2 += w*l*l; + sum_xl += w*l*x[i]; + } + float D = sum_w * sum_l2 - sum_l * sum_l; + if (D > 0) { + float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D; + float this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D; + if (this_min > 0) { + this_min = 0; + this_scale = sum_xl / sum_l2; + } + float mad = 0; + for (int i = 0; i < n; ++i) { + float diff = this_scale * Laux[i] + this_min - x[i]; + diff = use_mad ? fabsf(diff) : diff * diff; + float w = weights[i]; + mad += w * diff; + } + if (mad < best_mad) { + for (int i = 0; i < n; ++i) { + L[i] = Laux[i]; + } + best_mad = mad; + scale = this_scale; + min = this_min; + } + } + } + *the_min = -min; + return scale; +} + +static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) { + if (j < 4) { + *d = q[j] & 63; *m = q[j + 4] & 63; + } else { + *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); + *m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); + } +} + +//========================- 2-bit (de)-quantization + +void quantize_row_q2_K(const float * restrict x, void * restrict vy, int64_t k) { + quantize_row_q2_K_ref(x, vy, k); +} + +//========================= 3-bit (de)-quantization + +void quantize_row_q3_K(const float * restrict x, void * restrict vy, int64_t k) { + quantize_row_q3_K_ref(x, vy, k); +} + +// ====================== 4-bit (de)-quantization + +void quantize_row_q4_K(const float * restrict x, void * restrict vy, int64_t k) { + assert(k % QK_K == 0); + block_q4_K * restrict y = vy; + quantize_row_q4_K_ref(x, y, k); +} + +// ====================== 5-bit (de)-quantization + +void quantize_row_q5_K(const float * restrict x, void * restrict vy, int64_t k) { + assert(k % QK_K == 0); + block_q5_K * restrict y = vy; + quantize_row_q5_K_ref(x, y, k); +} + +// ====================== 6-bit (de)-quantization + +void quantize_row_q6_K(const float * restrict x, void * restrict vy, int64_t k) { + assert(k % QK_K == 0); + block_q6_K * restrict y = vy; + quantize_row_q6_K_ref(x, y, k); +} + +// ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs) + +void quantize_row_tq1_0(const float * restrict x, void * restrict vy, int64_t k) { + assert(k % QK_K == 0); + block_tq1_0 * restrict y = vy; + quantize_row_tq1_0_ref(x, y, k); +} + +void quantize_row_tq2_0(const float * restrict x, void * restrict vy, int64_t k) { + assert(k % QK_K == 0); + block_tq2_0 * restrict y = vy; + quantize_row_tq2_0_ref(x, y, k); +} + +static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; + +//===================================== Q8_K ============================================== + +void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) { + quantize_row_q8_K_ref(x, y, k); +} + +//===================================== Dot products ================================= + +// +// Helper functions +// +#if __AVX__ || __AVX2__ || __AVX512F__ + +// shuffles to pick the required scales in dot products +static inline __m256i get_scale_shuffle_q3k(int i) { + static const uint8_t k_shuffle[128] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, + 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, + 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, + 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15, + }; + return _mm256_loadu_si256((const __m256i*)k_shuffle + i); +} +static inline __m256i get_scale_shuffle_k4(int i) { + static const uint8_t k_shuffle[256] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, + 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, + 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, + 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, + 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, + 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, + 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, + 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15 + }; + return _mm256_loadu_si256((const __m256i*)k_shuffle + i); +} +static inline __m128i get_scale_shuffle(int i) { + static const uint8_t k_shuffle[128] = { + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, + 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, + 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, + 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, + 10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11, + 12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13, + 14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15 + }; + return _mm_loadu_si128((const __m128i*)k_shuffle + i); +} +#elif defined(__loongarch_asx) +// shuffles to pick the required scales in dot products +static inline __m256i get_scale_shuffle_q3k(int i) { + static const uint8_t k_shuffle[128] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, + 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, + 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, + 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15, + }; + return __lasx_xvld((const __m256i*)k_shuffle + i, 0); +} +static inline __m256i get_scale_shuffle_k4(int i) { + static const uint8_t k_shuffle[256] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, + 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, + 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, + 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, + 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, + 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, + 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, + 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15 + }; + return __lasx_xvld((const __m256i*)k_shuffle + i, 0); +} +static inline __m128i get_scale_shuffle(int i) { + static const uint8_t k_shuffle[128] = { + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, + 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, + 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, + 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, + 10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11, + 12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13, + 14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15 + }; + return __lsx_vld((const __m128i*)k_shuffle + i, 0); +} +#endif + +void lm_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; + + assert(n % qk == 0); +#if defined(__ARM_FEATURE_MATMUL_INT8) + assert((nrc == 2) || (nrc == 1)); +#else + assert(nrc == 1); +#endif + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_0 * restrict x = vx; + const block_q8_0 * restrict y = vy; + +#if defined(__ARM_FEATURE_MATMUL_INT8) + if (nrc == 2) { + const block_q4_0 * restrict vx0 = vx; + const block_q4_0 * restrict vx1 = (const block_q4_0 *) ((const uint8_t*)vx + bx); + const block_q8_0 * restrict vy0 = vy; + const block_q8_0 * restrict vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by); + + float32x4_t sumv0 = vdupq_n_f32(0.0f); + + for (int i = 0; i < nb; i++) { + const block_q4_0 * restrict b_x0 = &vx0[i]; + const block_q4_0 * restrict b_x1 = &vx1[i]; + const block_q8_0 * restrict b_y0 = &vy0[i]; + const block_q8_0 * restrict b_y1 = &vy1[i]; + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + const int8x16_t s8b = vdupq_n_s8(0x8); + + const uint8x16_t v0_0 = vld1q_u8(b_x0->qs); + const uint8x16_t v0_1 = vld1q_u8(b_x1->qs); + + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // sub 8 + const int8x16_t x0_l = vsubq_s8(v0_0l, s8b); + const int8x16_t x0_h = vsubq_s8(v0_0h, s8b); + const int8x16_t x1_l = vsubq_s8(v0_1l, s8b); + const int8x16_t x1_h = vsubq_s8(v0_1h, s8b); + + // load y + const int8x16_t y0_l = vld1q_s8(b_y0->qs); + const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16); + const int8x16_t y1_l = vld1q_s8(b_y1->qs); + const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16); + + float32_t _scale[4] = { LM_GGML_FP16_TO_FP32(b_x0->d)*LM_GGML_FP16_TO_FP32(b_y0->d), + LM_GGML_FP16_TO_FP32(b_x0->d)*LM_GGML_FP16_TO_FP32(b_y1->d), + LM_GGML_FP16_TO_FP32(b_x1->d)*LM_GGML_FP16_TO_FP32(b_y0->d), + LM_GGML_FP16_TO_FP32(b_x1->d)*LM_GGML_FP16_TO_FP32(b_y1->d)}; + + float32x4_t scale = vld1q_f32(_scale); + + int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); + int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); + + int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); + int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); + + int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); + int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); + + int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); + int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); + + sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)), + l1, r1)), l2, r2)), l3, r3))), scale); + } + float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2); + float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1); + + vst1_f32(s, vget_low_f32(sumv2)); + vst1_f32(s + bs, vget_high_f32(sumv2)); + return; + } +#endif + + int ib = 0; + float sumf = 0; + +#if defined(__ARM_FEATURE_SVE) + svfloat32_t sumv0 = svdup_n_f32(0.0f); + svfloat32_t sumv1 = svdup_n_f32(0.0f); + + const int vector_length = lm_ggml_cpu_get_sve_cnt()*8; + + // VLA Implementation using switch case + switch (vector_length) { + case 128: + { + // predicate for activating higher lanes for 4 float32 elements + const svbool_t ph4 = svptrue_pat_b32(SV_VL4); + + for (; ib + 1 < nb; ib += 2) { + const block_q4_0 * restrict x0 = &x[ib + 0]; + const block_q4_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; + + // load x + const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs); + const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs); + + // 4-bit -> 8-bit + const svint8_t qx0l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx0r, 0x0F)); + const svint8_t qx0h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx0r, 0x04)); + const svint8_t qx1l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx1r, 0x0F)); + const svint8_t qx1h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx1r, 0x04)); + + // sub 8 + const svint8_t qx0ls = svsub_n_s8_x(svptrue_b8(), qx0h, 8); + const svint8_t qx0hs = svsub_n_s8_x(svptrue_b8(), qx0l, 8); + const svint8_t qx1ls = svsub_n_s8_x(svptrue_b8(), qx1h, 8); + const svint8_t qx1hs = svsub_n_s8_x(svptrue_b8(), qx1l, 8); + + // load y + const svint8_t qy0h = svld1_s8(svptrue_b8(), y0->qs); + const svint8_t qy0l = svld1_s8(svptrue_b8(), y0->qs + 16); + const svint8_t qy1h = svld1_s8(svptrue_b8(), y1->qs); + const svint8_t qy1l = svld1_s8(svptrue_b8(), y1->qs + 16); + + // dot product + sumv0 = svmla_n_f32_x(ph4, sumv0, svcvt_f32_s32_x(ph4, svadd_x(ph4, + svdot_s32(svdup_n_s32(0), qx0ls, qy0l), + svdot_s32(svdup_n_s32(0), qx0hs, qy0h))), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(ph4, sumv1, svcvt_f32_s32_x(ph4, svadd_x(ph4, + svdot_s32(svdup_n_s32(0), qx1ls, qy1l), + svdot_s32(svdup_n_s32(0), qx1hs, qy1h))), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d)); + } + + sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); + } break; + case 256: + { + // predicate for activating higher lanes for 16 int8 elements + const svbool_t ph16 = svptrue_pat_b8(SV_VL16); + // predicate for activating lower lanes for 16 int8 elements + const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16); + + for (; ib + 1 < nb; ib += 2) { + const block_q4_0 * restrict x0 = &x[ib + 0]; + const block_q4_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; + + // load x + const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs); + const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs); + + // 4-bit -> 8-bit + const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04)); + const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04)); + + // sub 8 + const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8); + const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8); + + // load y + const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); + const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); + + // dot product + sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), + svdot_s32(svdup_n_s32(0), qx0s, qy0)), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), + svdot_s32(svdup_n_s32(0), qx1s, qy1)), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d)); + } + + sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); + } break; + case 512: + { + // predicate for activating higher lanes for 32 int8 elements + const svbool_t ph32 = svptrue_pat_b8(SV_VL32); + + // predicate for activating higher lanes for 16 int8 elements + const svbool_t ph16 = svptrue_pat_b8(SV_VL16); + // predicate for activating lower lanes for 16 int8 elements from first 32 int8 activated lanes + const svbool_t pl16 = svnot_b_z(ph32, ph16); + + for (; ib + 1 < nb; ib += 2) { + const block_q4_0 * restrict x0 = &x[ib + 0]; + const block_q4_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; + + // load x + const svuint8_t qx0r = svld1rq_u8(ph32, x0->qs); + const svuint8_t qx1r = svld1rq_u8(ph32, x1->qs); + + // 4-bit -> 8-bit + const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04)); + const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04)); + + // sub 8 + const svint8_t qx0s = svsub_n_s8_x(ph32, qx0, 8); + const svint8_t qx1s = svsub_n_s8_x(ph32, qx1, 8); + + // load y + const svint8_t qy0 = svld1_s8(ph32, y0->qs); + const svint8_t qy1 = svld1_s8(ph32, y1->qs); + + // dot product + sumv0 = svmla_n_f32_x(ph32, sumv0, svcvt_f32_s32_x(ph32, + svdot_s32(svdup_n_s32(0), qx0s, qy0)), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(ph32, sumv1, svcvt_f32_s32_x(ph32, + svdot_s32(svdup_n_s32(0), qx1s, qy1)), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d)); + } + + sumf = svaddv_f32(ph32, svadd_f32_x(ph32, sumv0, sumv1)); + } break; + default: + assert(false && "Unsupported vector length"); + break; + } + +#elif defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + for (; ib + 1 < nb; ib += 2) { + const block_q4_0 * restrict x0 = &x[ib + 0]; + const block_q4_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + const int8x16_t s8b = vdupq_n_s8(0x8); + + const uint8x16_t v0_0 = vld1q_u8(x0->qs); + const uint8x16_t v0_1 = vld1q_u8(x1->qs); + + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // sub 8 + const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b); + const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b); + const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b); + const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); + + // load y + const int8x16_t v1_0l = vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1_1l = vld1q_s8(y1->qs); + const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + + // dot product into int32x4_t + const int32x4_t p_0 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h); + const int32x4_t p_1 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d)); + } + + sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#elif defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (; ib < nb; ++ib) { + /* Compute combined scale for the block */ + const __m256 d = _mm256_set1_ps( LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d) ); + + __m256i qx = bytes_from_nibbles_32(x[ib].qs); + + // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. + const __m256i off = _mm256_set1_epi8( 8 ); + qx = _mm256_sub_epi8( qx, off ); + + __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs); + + const __m256 q = mul_sum_i8_pairs_float(qx, qy); + + /* Multiply q with scale and accumulate */ + acc = _mm256_fmadd_ps( d, q, acc ); + } + + sumf = hsum_float_8(acc); +#elif defined(__AVX__) + __m256 accum = _mm256_setzero_ps(); + for (; ib + 1 < nb; ib += 2) { + const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs); + const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs); + const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs); + const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1); + const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs); + const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1); + + const __m128i q4b_1_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_1), _mm_set1_epi8(8)); + const __m128i q4b_1_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_1, 4)), _mm_set1_epi8(8)); + const __m128i q4b_2_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_2), _mm_set1_epi8(8)); + const __m128i q4b_2_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_2, 4)), _mm_set1_epi8(8)); + + const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0); + const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1); + const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0); + const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1); + const __m128i p_1 = _mm_add_epi16(p16_1_0, p16_1_1); + const __m128i p_2 = _mm_add_epi16(p16_2_0, p16_2_1); + const __m256 p = sum_i16_pairs_float(p_2, p_1); + + const __m256 deltas = quad_fp16_delta_float(x[ib].d, y[ib].d, x[ib + 1].d, y[ib + 1].d); + accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum); + } + + sumf = hsum_float_8(accum); +#elif defined(__SSSE3__) + // set constants + const __m128i lowMask = _mm_set1_epi8(0xF); + const __m128i off = _mm_set1_epi8(8); + + // Initialize accumulator with zeros + __m128 acc_0 = _mm_setzero_ps(); + __m128 acc_1 = _mm_setzero_ps(); + __m128 acc_2 = _mm_setzero_ps(); + __m128 acc_3 = _mm_setzero_ps(); + + for (; ib + 1 < nb; ib += 2) { + _mm_prefetch(&x[ib] + sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[ib] + sizeof(block_q8_0), _MM_HINT_T0); + + // Compute combined scale for the block 0 and 1 + const __m128 d_0_1 = _mm_set1_ps( LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d) ); + + const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[ib].qs); + + __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1); + __m128i by_0 = _mm_loadu_si128((const __m128i *)y[ib].qs); + bx_0 = _mm_sub_epi8(bx_0, off); + const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); + + __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4)); + __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[ib].qs + 16)); + bx_1 = _mm_sub_epi8(bx_1, off); + const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); + + _mm_prefetch(&x[ib] + 2 * sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0); + + // Compute combined scale for the block 2 and 3 + const __m128 d_2_3 = _mm_set1_ps( LM_GGML_FP16_TO_FP32(x[ib + 1].d) * LM_GGML_FP16_TO_FP32(y[ib + 1].d) ); + + const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs); + + __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3); + __m128i by_2 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs); + bx_2 = _mm_sub_epi8(bx_2, off); + const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); + + __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4)); + __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[ib + 1].qs + 16)); + bx_3 = _mm_sub_epi8(bx_3, off); + const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); + + // Convert int32_t to float + __m128 p0 = _mm_cvtepi32_ps(i32_0); + __m128 p1 = _mm_cvtepi32_ps(i32_1); + __m128 p2 = _mm_cvtepi32_ps(i32_2); + __m128 p3 = _mm_cvtepi32_ps(i32_3); + + // Apply the scale + __m128 p0_d = _mm_mul_ps( d_0_1, p0 ); + __m128 p1_d = _mm_mul_ps( d_0_1, p1 ); + __m128 p2_d = _mm_mul_ps( d_2_3, p2 ); + __m128 p3_d = _mm_mul_ps( d_2_3, p3 ); + + // Acummulate + acc_0 = _mm_add_ps(p0_d, acc_0); + acc_1 = _mm_add_ps(p1_d, acc_1); + acc_2 = _mm_add_ps(p2_d, acc_2); + acc_3 = _mm_add_ps(p3_d, acc_3); + } + + sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); +#elif defined(__riscv_v_intrinsic) + size_t vl = __riscv_vsetvl_e8m1(qk/2); + + for (; ib < nb; ++ib) { + // load elements + vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl); + + vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl); + vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl); + + // mask and store lower part of x, and then upper part + vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); + vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); + + vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); + vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); + + // subtract offset + vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl); + vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl); + + vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); + vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); + + vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); + + vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); + + int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); + + sumf += sumi*LM_GGML_FP16_TO_FP32(x[ib].d)*LM_GGML_FP16_TO_FP32(y[ib].d); + } + +#elif defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector signed int v0 = vec_splats((int32_t)0); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + const vector signed char v8 = vec_splats((signed char)0x8); + + vector float vsumf0 = vec_splats(0.0f); + +#pragma GCC unroll 8 + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); + + vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[ib].d)); + vector float vyd = vec_splats(LM_GGML_FP16_TO_FP32(y[ib].d)); + vector float vd = vec_mul(vxd, vyd); + + vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); + vector signed char q8y0 = vec_xl( 0, y[ib].qs); + vector signed char q8y1 = vec_xl(16, y[ib].qs); + + vector signed char q4x0 = vec_and(qxs, lowMask); + vector signed char q4x1 = vec_sr(qxs, v4); + + q4x0 = vec_sub(q4x0, v8); + q4x1 = vec_sub(q4x1, v8); + + vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1)); + + vector signed int vsumi0 = v0; + + vsumi0 = vec_sum4s(qv0, vsumi0); + vsumi0 = vec_sum4s(qv1, vsumi0); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + } + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + sumf = vec_extract(vsumf0, 0); + +#elif defined(__loongarch_asx) + // Initialize accumulator with zeros + __m256 acc = (__m256)__lasx_xvldi(0); + + // Main loop + for (; ib < nb; ++ib) { + /* Compute combined scale for the block */ + const __m256 d = __lasx_xvreplfr2vr_s( LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d) ); + + __m256i qx = bytes_from_nibbles_32(x[ib].qs); + + // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. + const __m256i off = __lasx_xvreplgr2vr_b( 8 ); + qx = __lasx_xvsub_b( qx, off ); + + __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0); + + const __m256 q = mul_sum_i8_pairs_float(qx, qy); + + /* Multiply q with scale and accumulate */ + acc = __lasx_xvfmadd_s( d, q, acc ); + } + + sumf = hsum_float_8(acc); +#elif defined(__loongarch_sx) + // set constants + const __m128i low_mask = __lsx_vreplgr2vr_b(0xF); + const __m128i off = __lsx_vreplgr2vr_b(8); + + // Initialize accumulator with zeros + __m128 acc_0 = __lsx_vldi(0); + __m128 acc_1 = __lsx_vldi(0); + __m128 acc_2 = __lsx_vldi(0); + __m128 acc_3 = __lsx_vldi(0); + + for (; ib + 1 < nb; ib += 2) { + + // Compute combined scale for the block 0 and 1 + const __m128 d_0_1 = __lsx_vreplgr2vr_w( LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d) ); + + const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[ib].qs, 0); + + __m128i bx_0 = __lsx_vand_v(low_mask, tmp_0_1); + __m128i by_0 = __lsx_vld((const __m128i *)y[ib].qs, 0); + bx_0 = __lsx_vsub_b(bx_0, off); + const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); + + __m128i bx_1 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_0_1, 4)); + __m128i by_1 = __lsx_vld((const __m128i *)(y[ib].qs + 16), 0); + bx_1 = __lsx_vsub_b(bx_1, off); + const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); + + //_mm_prefetch(&x[ib] + 2 * sizeof(block_q4_0), _MM_HINT_T0); + //_mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0); + + // Compute combined scale for the block 2 and 3 + const __m128 d_2_3 = __lsx_vreplgr2vr_w( LM_GGML_FP16_TO_FP32(x[ib + 1].d) * LM_GGML_FP16_TO_FP32(y[ib + 1].d) ); + + const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[ib + 1].qs, 0); + + __m128i bx_2 = __lsx_vand_v(low_mask, tmp_2_3); + __m128i by_2 = __lsx_vld((const __m128i *)y[ib + 1].qs, 0); + bx_2 = __lsx_vsub_b(bx_2, off); + const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); + + __m128i bx_3 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_2_3, 4)); + __m128i by_3 = __lsx_vld((const __m128i *)(y[ib + 1].qs + 16), 0); + bx_3 = __lsx_vsub_b(bx_3, off); + const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); + + // Convert int32_t to float + __m128 p0 = __lsx_vffint_s_w(i32_0); + __m128 p1 = __lsx_vffint_s_w(i32_1); + __m128 p2 = __lsx_vffint_s_w(i32_2); + __m128 p3 = __lsx_vffint_s_w(i32_3); + + // Apply the scale + __m128 p0_d = __lsx_vfmul_s( d_0_1, p0 ); + __m128 p1_d = __lsx_vfmul_s( d_0_1, p1 ); + __m128 p2_d = __lsx_vfmul_s( d_2_3, p2 ); + __m128 p3_d = __lsx_vfmul_s( d_2_3, p3 ); + + // Acummulate + acc_0 = __lsx_vfadd_s(p0_d, acc_0); + acc_1 = __lsx_vfadd_s(p1_d, acc_1); + acc_2 = __lsx_vfadd_s(p2_d, acc_2); + acc_3 = __lsx_vfadd_s(p3_d, acc_3); + } + + sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); +#endif + for (; ib < nb; ++ib) { + int sumi0 = 0; + int sumi1 = 0; + + for (int j = 0; j < qk/2; ++j) { + const int v0 = (x[ib].qs[j] & 0x0F) - 8; + const int v1 = (x[ib].qs[j] >> 4) - 8; + + sumi0 += (v0 * y[ib].qs[j]); + sumi1 += (v1 * y[ib].qs[j + qk/2]); + } + + int sumi = sumi0 + sumi1; + sumf += sumi*LM_GGML_FP16_TO_FP32(x[ib].d)*LM_GGML_FP16_TO_FP32(y[ib].d); + } + + *s = sumf; +} + +void lm_ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + const int qk = QK8_1; + const int nb = n / qk; + + assert(n % qk == 0); +#if defined(__ARM_FEATURE_MATMUL_INT8) + assert((nrc == 2) || (nrc == 1)); +#else + assert(nrc == 1); +#endif + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_1 * restrict x = vx; + const block_q8_1 * restrict y = vy; + +#if defined(__ARM_FEATURE_MATMUL_INT8) + if (nrc == 2) { + const block_q4_1 * restrict vx0 = vx; + const block_q4_1 * restrict vx1 = (const block_q4_1 *) ((const uint8_t*)vx + bx); + const block_q8_1 * restrict vy0 = vy; + const block_q8_1 * restrict vy1 = (const block_q8_1 *) ((const uint8_t*)vy + by); + + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t summs0 = vdupq_n_f32(0.0f); + + for (int i = 0; i < nb; i++) { + const block_q4_1 * restrict b_x0 = &vx0[i]; + const block_q4_1 * restrict b_x1 = &vx1[i]; + const block_q8_1 * restrict b_y0 = &vy0[i]; + const block_q8_1 * restrict b_y1 = &vy1[i]; + + float32_t summs_t[4] = {LM_GGML_FP16_TO_FP32(b_x0->m) * LM_GGML_FP16_TO_FP32(b_y0->s), + LM_GGML_FP16_TO_FP32(b_x1->m) * LM_GGML_FP16_TO_FP32(b_y0->s), + LM_GGML_FP16_TO_FP32(b_x0->m) * LM_GGML_FP16_TO_FP32(b_y1->s), + LM_GGML_FP16_TO_FP32(b_x1->m) * LM_GGML_FP16_TO_FP32(b_y1->s)}; + summs0 = vaddq_f32(summs0, vld1q_f32(summs_t)); + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + + const uint8x16_t v0_0 = vld1q_u8(b_x0->qs); + const uint8x16_t v0_1 = vld1q_u8(b_x1->qs); + + // 4-bit -> 8-bit + const int8x16_t x0_l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t x0_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t x1_l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t x1_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // load y + const int8x16_t y0_l = vld1q_s8(b_y0->qs); + const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16); + const int8x16_t y1_l = vld1q_s8(b_y1->qs); + const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16); + + // mmla into int32x4_t + float32_t _scale[4] = {LM_GGML_FP16_TO_FP32(b_x0->d)*b_y0->d, + LM_GGML_FP16_TO_FP32(b_x0->d)*b_y1->d, + LM_GGML_FP16_TO_FP32(b_x1->d)*b_y0->d, + LM_GGML_FP16_TO_FP32(b_x1->d)*b_y1->d}; + float32x4_t scale = vld1q_f32(_scale); + + int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); + int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); + + int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); + int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); + + int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); + int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); + + int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); + int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); + sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)), + l1, r1)), l2, r2)), l3, r3))), scale); + } + + float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2); + float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1); + sumv2 = vaddq_f32(sumv2, summs0); + + vst1_f32(s, vget_low_f32 (sumv2)); + vst1_f32(s + bs, vget_high_f32(sumv2)); + return; + } +#endif + + int ib = 0; + float sumf = 0; + + // TODO: add WASM SIMD +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + float summs = 0; + + for (; ib + 1 < nb; ib += 2) { + const block_q4_1 * restrict x0 = &x[ib + 0]; + const block_q4_1 * restrict x1 = &x[ib + 1]; + const block_q8_1 * restrict y0 = &y[ib + 0]; + const block_q8_1 * restrict y1 = &y[ib + 1]; + + summs += LM_GGML_FP16_TO_FP32(x0->m) * LM_GGML_FP16_TO_FP32(y0->s) + LM_GGML_FP16_TO_FP32(x1->m) * LM_GGML_FP16_TO_FP32(y1->s); + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + + const uint8x16_t v0_0 = vld1q_u8(x0->qs); + const uint8x16_t v0_1 = vld1q_u8(x1->qs); + + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // load y + const int8x16_t v1_0l = vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1_1l = vld1q_s8(y1->qs); + const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + + // dot product into int32x4_t + const int32x4_t p_0 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h); + const int32x4_t p_1 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d)); + } + + sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs; +#elif defined(__AVX2__) || defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + float summs = 0; + + // Main loop + for (; ib < nb; ++ib) { + const float d0 = LM_GGML_FP16_TO_FP32(x[ib].d); + const float d1 = LM_GGML_FP16_TO_FP32(y[ib].d); + + summs += LM_GGML_FP16_TO_FP32(x[ib].m) * LM_GGML_FP16_TO_FP32(y[ib].s); + + const __m256 d0v = _mm256_set1_ps( d0 ); + const __m256 d1v = _mm256_set1_ps( d1 ); + + // Compute combined scales + const __m256 d0d1 = _mm256_mul_ps( d0v, d1v ); + + // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes + const __m256i qx = bytes_from_nibbles_32(x[ib].qs); + const __m256i qy = _mm256_loadu_si256( (const __m256i *)y[ib].qs ); + + const __m256 xy = mul_sum_us8_pairs_float(qx, qy); + + // Accumulate d0*d1*x*y +#if defined(__AVX2__) + acc = _mm256_fmadd_ps( d0d1, xy, acc ); +#else + acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc ); +#endif + } + + sumf = hsum_float_8(acc) + summs; +#elif defined(__riscv_v_intrinsic) + size_t vl = __riscv_vsetvl_e8m1(qk/2); + + for (; ib < nb; ++ib) { + // load elements + vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl); + + vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl); + vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl); + + // mask and store lower part of x, and then upper part + vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); + vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); + + vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); + vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); + + vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); + vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); + + vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); + + vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); + + int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); + + sumf += (LM_GGML_FP16_TO_FP32(x[ib].d)*LM_GGML_FP16_TO_FP32(y[ib].d))*sumi + LM_GGML_FP16_TO_FP32(x[ib].m)*LM_GGML_FP16_TO_FP32(y[ib].s); + } + +#elif defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector signed int v0 = vec_splats((int32_t)0); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + + vector float vsumf0 = vec_splats(0.0f); + +#pragma GCC unroll 4 + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); + + vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[ib].d)); + vector float vyd = vec_splats(LM_GGML_FP16_TO_FP32(y[ib].d)); + vector float vd = vec_mul(vxd, vyd); + + vector float vxmin = vec_splats(LM_GGML_FP16_TO_FP32(x[ib].m)); + vector float vys = {LM_GGML_FP16_TO_FP32(y[ib].s), 0.0f, 0.0f, 0.0f}; + vsumf0 = vec_madd(vxmin, vys, vsumf0); + + vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); + vector signed char q8y0 = vec_xl( 0, y[ib].qs); + vector signed char q8y1 = vec_xl(16, y[ib].qs); + + vector unsigned char q4x0 = (vector unsigned char)vec_and(qxs, lowMask); + vector unsigned char q4x1 = (vector unsigned char)vec_sr(qxs, v4); + + vector signed int vsumi0 = v0; + + vsumi0 = vec_msum(q8y0, q4x0, vsumi0); + vsumi0 = vec_msum(q8y1, q4x1, vsumi0); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + } + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + sumf = vec_extract(vsumf0, 0); + +#elif defined(__loongarch_asx) + // Initialize accumulator with zeros + __m256 acc = (__m256)__lasx_xvldi(0); + + float summs = 0; + + // Main loop + for (; ib < nb; ++ib) { + const float d0 = LM_GGML_FP16_TO_FP32(x[ib].d); + const float d1 = LM_GGML_FP16_TO_FP32(y[ib].d); + + summs += LM_GGML_FP16_TO_FP32(x[ib].m) * LM_GGML_FP16_TO_FP32(y[ib].s); + + const __m256 d0v = __lasx_xvreplfr2vr_s( d0 ); + const __m256 d1v = __lasx_xvreplfr2vr_s( d1 ); + + // Compute combined scales + const __m256 d0d1 = __lasx_xvfmul_s( d0v, d1v ); + + // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes + const __m256i qx = bytes_from_nibbles_32(x[ib].qs); + const __m256i qy = __lasx_xvld( (const __m256i *)y[ib].qs, 0); + + const __m256 xy = mul_sum_us8_pairs_float(qx, qy); + + // Accumulate d0*d1*x*y + acc = __lasx_xvfmadd_s( d0d1, xy, acc ); + } + + sumf = hsum_float_8(acc) + summs; +#endif + for (; ib < nb; ++ib) { + int sumi0 = 0; + int sumi1 = 0; + + for (int j = 0; j < qk/2; ++j) { + const int v0 = (x[ib].qs[j] & 0x0F); + const int v1 = (x[ib].qs[j] >> 4); + + sumi0 += (v0 * y[ib].qs[j]); + sumi1 += (v1 * y[ib].qs[j + qk/2]); + } + + int sumi = sumi0 + sumi1; + sumf += (LM_GGML_FP16_TO_FP32(x[ib].d)*LM_GGML_FP16_TO_FP32(y[ib].d))*sumi + LM_GGML_FP16_TO_FP32(x[ib].m)*LM_GGML_FP16_TO_FP32(y[ib].s); + } + + *s = sumf; +} + +void lm_ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; + + int ib = 0; + float sumf = 0; + + assert(n % qk == 0); + assert(qk == QK5_0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q5_0 * restrict x = vx; + const block_q8_0 * restrict y = vy; + +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + uint32_t qh0; + uint32_t qh1; + + uint64_t tmp0[4]; + uint64_t tmp1[4]; + + for (; ib + 1 < nb; ib += 2) { + const block_q5_0 * restrict x0 = &x[ib]; + const block_q5_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib]; + const block_q8_0 * restrict y1 = &y[ib + 1]; + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + + // extract the 5th bit via lookup table ((!b) << 4) + memcpy(&qh0, x0->qh, sizeof(qh0)); + memcpy(&qh1, x1->qh, sizeof(qh1)); + + tmp0[0] = table_b2b_1[(qh0 >> 0) & 0xFF]; + tmp0[1] = table_b2b_1[(qh0 >> 8) & 0xFF]; + tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF]; + tmp0[3] = table_b2b_1[(qh0 >> 24) ]; + + tmp1[0] = table_b2b_1[(qh1 >> 0) & 0xFF]; + tmp1[1] = table_b2b_1[(qh1 >> 8) & 0xFF]; + tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF]; + tmp1[3] = table_b2b_1[(qh1 >> 24) ]; + + const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0)); + const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2)); + const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0)); + const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2)); + + const uint8x16_t v0_0 = vld1q_u8(x0->qs); + const uint8x16_t v0_1 = vld1q_u8(x1->qs); + + // 4-bit -> 8-bit + int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero) + const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0); + const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0); + const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1); + const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1); + + // load y + const int8x16_t v1_0l = vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1_1l = vld1q_s8(y1->qs); + const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( + lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), + lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( + lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), + lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d)); + } + + sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#elif defined(__wasm_simd128__) + v128_t sumv = wasm_f32x4_splat(0.0f); + + uint32_t qh; + uint64_t tmp[4]; + + // TODO: check if unrolling this is better + for (; ib < nb; ++ib) { + const block_q5_0 * restrict x0 = &x[ib]; + const block_q8_0 * restrict y0 = &y[ib]; + + const v128_t m4b = wasm_i8x16_splat(0x0F); + + // extract the 5th bit + memcpy(&qh, x0->qh, sizeof(qh)); + + tmp[0] = table_b2b_1[(qh >> 0) & 0xFF]; + tmp[1] = table_b2b_1[(qh >> 8) & 0xFF]; + tmp[2] = table_b2b_1[(qh >> 16) & 0xFF]; + tmp[3] = table_b2b_1[(qh >> 24) ]; + + const v128_t qhl = wasm_v128_load(tmp + 0); + const v128_t qhh = wasm_v128_load(tmp + 2); + + const v128_t v0 = wasm_v128_load(x0->qs); + + // 4-bit -> 8-bit + const v128_t v0l = wasm_v128_and (v0, m4b); + const v128_t v0h = wasm_u8x16_shr(v0, 4); + + // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero) + const v128_t v0lf = wasm_i8x16_sub(v0l, qhl); + const v128_t v0hf = wasm_i8x16_sub(v0h, qhh); + + // load y + const v128_t v1l = wasm_v128_load(y0->qs); + const v128_t v1h = wasm_v128_load(y0->qs + 16); + + // int8x16 -> int16x8 + const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf); + const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf); + const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf); + const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf); + + const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l); + const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l); + const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h); + const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h); + + // dot product + sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4( + wasm_i32x4_add( + wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll), + wasm_i32x4_dot_i16x8(v0lfh, v1lh)), + wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl), + wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), + wasm_f32x4_splat(LM_GGML_FP16_TO_FP32(x0->d) * LM_GGML_FP16_TO_FP32(y0->d)))); + } + + sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + + wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3); +#elif defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (; ib < nb; ++ib) { + /* Compute combined scale for the block */ + const __m256 d = _mm256_set1_ps(LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d)); + + __m256i qx = bytes_from_nibbles_32(x[ib].qs); + __m256i bxhi = bytes_from_bits_32(x[ib].qh); + bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0)); + qx = _mm256_or_si256(qx, bxhi); + + __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs); + + const __m256 q = mul_sum_i8_pairs_float(qx, qy); + + /* Multiply q with scale and accumulate */ + acc = _mm256_fmadd_ps(d, q, acc); + } + + sumf = hsum_float_8(acc); +#elif defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + __m128i mask = _mm_set1_epi8((char)0xF0); + + // Main loop + for (; ib < nb; ++ib) { + /* Compute combined scale for the block */ + const __m256 d = _mm256_set1_ps(LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d)); + + __m256i bx_0 = bytes_from_nibbles_32(x[ib].qs); + const __m256i bxhi = bytes_from_bits_32(x[ib].qh); + __m128i bxhil = _mm256_castsi256_si128(bxhi); + __m128i bxhih = _mm256_extractf128_si256(bxhi, 1); + bxhil = _mm_andnot_si128(bxhil, mask); + bxhih = _mm_andnot_si128(bxhih, mask); + __m128i bxl = _mm256_castsi256_si128(bx_0); + __m128i bxh = _mm256_extractf128_si256(bx_0, 1); + bxl = _mm_or_si128(bxl, bxhil); + bxh = _mm_or_si128(bxh, bxhih); + bx_0 = MM256_SET_M128I(bxh, bxl); + + const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[ib].qs); + + const __m256 q = mul_sum_i8_pairs_float(bx_0, by_0); + + /* Multiply q with scale and accumulate */ + acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc); + } + + sumf = hsum_float_8(acc); +#elif defined(__riscv_v_intrinsic) + uint32_t qh; + + size_t vl = __riscv_vsetvl_e8m1(qk/2); + + // These temporary registers are for masking and shift operations + vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl); + vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl); + + vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl); + vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl); + + for (; ib < nb; ++ib) { + memcpy(&qh, x[ib].qh, sizeof(uint32_t)); + + // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; + vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl); + vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl); + vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl); + + // ((qh & (1u << (j + 16))) >> (j + 12)); + vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl); + vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl); + + // narrowing + vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl); + vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl); + + vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl); + vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl); + + // load + vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl); + + vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl); + vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl); + + vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); + vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); + + vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl); + vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl); + + vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); + vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); + + vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl); + vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl); + + vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); + vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); + + vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); + + vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); + + int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); + + sumf += (LM_GGML_FP16_TO_FP32(x[ib].d)*LM_GGML_FP16_TO_FP32(y[ib].d)) * sumi; + } + +#elif defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector unsigned char v4 = vec_splats((unsigned char)4); + + vector float vsumf0 = vec_splats(0.0f); + +#pragma GCC unroll 4 + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); + + vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[ib].d)); + vector float vyd = vec_splats(LM_GGML_FP16_TO_FP32(y[ib].d)); + vector float vd = vec_mul(vxd, vyd); + + vector signed long long aux64x2_0 = {(uint64_t)(table_b2b_1[x[ib].qh[0]]), (uint64_t)(table_b2b_1[x[ib].qh[1]])}; + vector signed long long aux64x2_1 = {(uint64_t)(table_b2b_1[x[ib].qh[2]]), (uint64_t)(table_b2b_1[x[ib].qh[3]])}; + + vector signed char qh0 = (vector signed char)aux64x2_0; + vector signed char qh1 = (vector signed char)aux64x2_1; + + vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); + + vector signed char q5x0 = vec_sub(vec_and (qxs, lowMask), qh0); + vector signed char q5x1 = vec_sub(vec_sr(qxs, v4), qh1); + + vector signed char q8y0 = vec_xl( 0, y[ib].qs); + vector signed char q8y1 = vec_xl( 16, y[ib].qs); + + vector signed short qv0 = vec_add(vec_mule(q5x0, q8y0), vec_mulo(q5x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q5x1, q8y1), vec_mulo(q5x1, q8y1)); + + qv0 = vec_add(qv0, qv1); + + vector signed int vsumi0 = vec_add(vec_unpackh(qv0), vec_unpackl(qv0)); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + } + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + sumf = vec_extract(vsumf0, 0); + +#elif defined(__loongarch_asx) + // Initialize accumulator with zeros + __m256 acc = (__m256)__lasx_xvldi(0); + + // Main loop + for (; ib < nb; ++ib) { + /* Compute combined scale for the block */ + const __m256 d = __lasx_xvreplfr2vr_s(LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d)); //FIXME + + __m256i qx = bytes_from_nibbles_32(x[ib].qs); + __m256i bxhi = bytes_from_bits_32(x[ib].qh); + bxhi = __lasx_xvandn_v(bxhi, __lasx_xvreplgr2vr_b((char)0xF0)); + qx = __lasx_xvor_v(qx, bxhi); + + __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0); + + const __m256 q = mul_sum_i8_pairs_float(qx, qy); + + /* Multiply q with scale and accumulate */ + acc = __lasx_xvfmadd_s(d, q, acc); + } + + sumf = hsum_float_8(acc); +#endif + for (; ib < nb; ++ib) { + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + int sumi0 = 0; + int sumi1 = 0; + + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; + const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12)); + + const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16); + const int32_t x1 = (int8_t)(((x[ib].qs[j] >> 4) | xh_1) - 16); + + sumi0 += (x0 * y[ib].qs[j]); + sumi1 += (x1 * y[ib].qs[j + qk/2]); + } + + int sumi = sumi0 + sumi1; + sumf += (LM_GGML_FP16_TO_FP32(x[ib].d)*LM_GGML_FP16_TO_FP32(y[ib].d)) * sumi; + } + + *s = sumf; +} + +void lm_ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + const int qk = QK8_1; + const int nb = n / qk; + + int ib = 0; + float sumf = 0; + + assert(n % qk == 0); + assert(qk == QK5_1); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q5_1 * restrict x = vx; + const block_q8_1 * restrict y = vy; + +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + float summs0 = 0.0f; + float summs1 = 0.0f; + + uint32_t qh0; + uint32_t qh1; + + uint64_t tmp0[4]; + uint64_t tmp1[4]; + + for (; ib + 1 < nb; ib += 2) { + const block_q5_1 * restrict x0 = &x[ib]; + const block_q5_1 * restrict x1 = &x[ib + 1]; + const block_q8_1 * restrict y0 = &y[ib]; + const block_q8_1 * restrict y1 = &y[ib + 1]; + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + + summs0 += LM_GGML_FP16_TO_FP32(x0->m) * LM_GGML_FP16_TO_FP32(y0->s); + summs1 += LM_GGML_FP16_TO_FP32(x1->m) * LM_GGML_FP16_TO_FP32(y1->s); + + // extract the 5th bit via lookup table ((b) << 4) + memcpy(&qh0, x0->qh, sizeof(qh0)); + memcpy(&qh1, x1->qh, sizeof(qh1)); + + tmp0[0] = table_b2b_0[(qh0 >> 0) & 0xFF]; + tmp0[1] = table_b2b_0[(qh0 >> 8) & 0xFF]; + tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF]; + tmp0[3] = table_b2b_0[(qh0 >> 24) ]; + + tmp1[0] = table_b2b_0[(qh1 >> 0) & 0xFF]; + tmp1[1] = table_b2b_0[(qh1 >> 8) & 0xFF]; + tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF]; + tmp1[3] = table_b2b_0[(qh1 >> 24) ]; + + const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0)); + const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2)); + const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0)); + const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2)); + + const uint8x16_t v0_0 = vld1q_u8(x0->qs); + const uint8x16_t v0_1 = vld1q_u8(x1->qs); + + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // add high bit + const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0); + const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0); + const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1); + const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1); + + // load y + const int8x16_t v1_0l = vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1_1l = vld1q_s8(y1->qs); + const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( + lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), + lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( + lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), + lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d)); + } + + sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1; +#elif defined(__wasm_simd128__) + v128_t sumv = wasm_f32x4_splat(0.0f); + + float summs = 0.0f; + + uint32_t qh; + uint64_t tmp[4]; + + // TODO: check if unrolling this is better + for (; ib < nb; ++ib) { + const block_q5_1 * restrict x0 = &x[ib]; + const block_q8_1 * restrict y0 = &y[ib]; + + summs += LM_GGML_FP16_TO_FP32(x0->m) * LM_GGML_FP16_TO_FP32(y0->s); + + const v128_t m4b = wasm_i8x16_splat(0x0F); + + // extract the 5th bit + memcpy(&qh, x0->qh, sizeof(qh)); + + tmp[0] = table_b2b_0[(qh >> 0) & 0xFF]; + tmp[1] = table_b2b_0[(qh >> 8) & 0xFF]; + tmp[2] = table_b2b_0[(qh >> 16) & 0xFF]; + tmp[3] = table_b2b_0[(qh >> 24) ]; + + const v128_t qhl = wasm_v128_load(tmp + 0); + const v128_t qhh = wasm_v128_load(tmp + 2); + + const v128_t v0 = wasm_v128_load(x0->qs); + + // 4-bit -> 8-bit + const v128_t v0l = wasm_v128_and (v0, m4b); + const v128_t v0h = wasm_u8x16_shr(v0, 4); + + // add high bit + const v128_t v0lf = wasm_v128_or(v0l, qhl); + const v128_t v0hf = wasm_v128_or(v0h, qhh); + + // load y + const v128_t v1l = wasm_v128_load(y0->qs); + const v128_t v1h = wasm_v128_load(y0->qs + 16); + + // int8x16 -> int16x8 + const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf); + const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf); + const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf); + const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf); + + const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l); + const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l); + const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h); + const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h); + + // dot product + sumv = wasm_f32x4_add(sumv, + wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add( + wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll), + wasm_i32x4_dot_i16x8(v0lfh, v1lh)), + wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl), + wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), + wasm_f32x4_splat(LM_GGML_FP16_TO_FP32(x0->d) * LM_GGML_FP16_TO_FP32(y0->d)))); + } + + sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + + wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs; +#elif defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + float summs = 0.0f; + + // Main loop + for (; ib < nb; ++ib) { + const __m256 dx = _mm256_set1_ps(LM_GGML_FP16_TO_FP32(x[ib].d)); + + summs += LM_GGML_FP16_TO_FP32(x[ib].m) * LM_GGML_FP16_TO_FP32(y[ib].s); + + __m256i qx = bytes_from_nibbles_32(x[ib].qs); + __m256i bxhi = bytes_from_bits_32(x[ib].qh); + bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10)); + qx = _mm256_or_si256(qx, bxhi); + + const __m256 dy = _mm256_set1_ps(LM_GGML_FP16_TO_FP32(y[ib].d)); + const __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs); + + const __m256 q = mul_sum_us8_pairs_float(qx, qy); + + acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc); + } + + sumf = hsum_float_8(acc) + summs; +#elif defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + __m128i mask = _mm_set1_epi8(0x10); + + float summs = 0.0f; + + // Main loop + for (; ib < nb; ++ib) { + const __m256 dx = _mm256_set1_ps(LM_GGML_FP16_TO_FP32(x[ib].d)); + + summs += LM_GGML_FP16_TO_FP32(x[ib].m) * LM_GGML_FP16_TO_FP32(y[ib].s); + + __m256i bx_0 = bytes_from_nibbles_32(x[ib].qs); + const __m256i bxhi = bytes_from_bits_32(x[ib].qh); + __m128i bxhil = _mm256_castsi256_si128(bxhi); + __m128i bxhih = _mm256_extractf128_si256(bxhi, 1); + bxhil = _mm_and_si128(bxhil, mask); + bxhih = _mm_and_si128(bxhih, mask); + __m128i bxl = _mm256_castsi256_si128(bx_0); + __m128i bxh = _mm256_extractf128_si256(bx_0, 1); + bxl = _mm_or_si128(bxl, bxhil); + bxh = _mm_or_si128(bxh, bxhih); + bx_0 = MM256_SET_M128I(bxh, bxl); + + const __m256 dy = _mm256_set1_ps(LM_GGML_FP16_TO_FP32(y[ib].d)); + const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[ib].qs); + + const __m256 q = mul_sum_us8_pairs_float(bx_0, by_0); + + acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc); + } + + sumf = hsum_float_8(acc) + summs; +#elif defined(__riscv_v_intrinsic) + uint32_t qh; + + size_t vl = __riscv_vsetvl_e8m1(qk/2); + + // temporary registers for shift operations + vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl); + vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl); + + for (; ib < nb; ++ib) { + memcpy(&qh, x[ib].qh, sizeof(uint32_t)); + + // load qh + vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl); + + // ((qh >> (j + 0)) << 4) & 0x10; + vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl); + vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl); + vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl); + + // ((qh >> (j + 12)) ) & 0x10; + vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl); + vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl); + + // narrowing + vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl); + vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl); + + vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl); + vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl); + + // load + vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl); + + vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl); + vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl); + + vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); + vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); + + vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl); + vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl); + + vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); + vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); + + vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); + vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); + + vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); + + vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); + + int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); + + sumf += (LM_GGML_FP16_TO_FP32(x[ib].d)*LM_GGML_FP16_TO_FP32(y[ib].d))*sumi + LM_GGML_FP16_TO_FP32(x[ib].m)*LM_GGML_FP16_TO_FP32(y[ib].s); + } + +#elif defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector signed int v0 = vec_splats((int32_t)0); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + + vector float vsumf0 = vec_splats(0.0f); + +#pragma GCC unroll 4 + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); + + vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[ib].d)); + vector float vyd = vec_splats(LM_GGML_FP16_TO_FP32(y[ib].d)); + vector float vd = vec_mul(vxd, vyd); + + vector float vxmin = vec_splats(LM_GGML_FP16_TO_FP32(x[ib].m)); + vector float vys = {LM_GGML_FP16_TO_FP32(y[ib].s), 0.f, 0.f, 0.f}; + vsumf0 = vec_madd(vxmin, vys, vsumf0); + + vector unsigned long long aux64x2_0 = {(uint64_t)(table_b2b_0[x[ib].qh[0]]), (uint64_t)(table_b2b_0[x[ib].qh[1]])}; + vector unsigned long long aux64x2_1 = {(uint64_t)(table_b2b_0[x[ib].qh[2]]), (uint64_t)(table_b2b_0[x[ib].qh[3]])}; + + vector signed char qh0 = (vector signed char)aux64x2_0; + vector signed char qh1 = (vector signed char)aux64x2_1; + + vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); + + vector unsigned char q5x0 = (vector unsigned char)vec_or(vec_and(qxs, lowMask), qh0); + vector unsigned char q5x1 = (vector unsigned char)vec_or(vec_sr(qxs, v4), qh1); + + vector signed char q8y0 = vec_xl( 0, y[ib].qs); + vector signed char q8y1 = vec_xl( 16, y[ib].qs); + + vector signed int vsumi0 = v0; + + vsumi0 = vec_msum(q8y0, q5x0, vsumi0); + vsumi0 = vec_msum(q8y1, q5x1, vsumi0); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + } + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + sumf = vec_extract(vsumf0, 0); + +#elif defined(__loongarch_asx) + // Initialize accumulator with zeros + __m256 acc = (__m256)__lasx_xvldi(0); + + float summs = 0.0f; + + // Main loop + for (; ib < nb; ++ib) { + const __m256 dx = __lasx_xvreplfr2vr_s(LM_GGML_FP16_TO_FP32(x[ib].d)); + + summs += LM_GGML_FP16_TO_FP32(x[ib].m) * LM_GGML_FP16_TO_FP32(y[ib].s); + + __m256i qx = bytes_from_nibbles_32(x[ib].qs); + __m256i bxhi = bytes_from_bits_32(x[ib].qh); + bxhi = __lasx_xvand_v(bxhi, __lasx_xvreplgr2vr_b(0x10)); + qx = __lasx_xvor_v(qx, bxhi); + + const __m256 dy = __lasx_xvreplfr2vr_s(LM_GGML_FP16_TO_FP32(y[ib].d)); + const __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0); + + const __m256 q = mul_sum_us8_pairs_float(qx, qy); + + acc = __lasx_xvfmadd_s(q, __lasx_xvfmul_s(dx, dy), acc); + } + + sumf = hsum_float_8(acc) + summs; +#endif + for (; ib < nb; ++ib) { + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + int sumi0 = 0; + int sumi1 = 0; + + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; + const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; + + const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0; + const int32_t x1 = (x[ib].qs[j] >> 4) | xh_1; + + sumi0 += (x0 * y[ib].qs[j]); + sumi1 += (x1 * y[ib].qs[j + qk/2]); + } + + int sumi = sumi0 + sumi1; + sumf += (LM_GGML_FP16_TO_FP32(x[ib].d)*LM_GGML_FP16_TO_FP32(y[ib].d))*sumi + LM_GGML_FP16_TO_FP32(x[ib].m)*LM_GGML_FP16_TO_FP32(y[ib].s); + } + + *s = sumf; +} + +void lm_ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; + + assert(n % qk == 0); +#if defined(__ARM_FEATURE_MATMUL_INT8) + assert((nrc == 2) || (nrc == 1)); +#else + assert(nrc == 1); +#endif + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q8_0 * restrict x = vx; + const block_q8_0 * restrict y = vy; + +#if defined(__ARM_FEATURE_MATMUL_INT8) + if (nrc == 2) { + const block_q8_0 * restrict vx0 = vx; + const block_q8_0 * restrict vx1 = (const block_q8_0 *) ((const uint8_t*)vx + bx); + const block_q8_0 * restrict vy0 = vy; + const block_q8_0 * restrict vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by); + + float32x4_t sumv0 = vdupq_n_f32(0.0f); + + for (int i = 0; i < nb; i++) { + const block_q8_0 * restrict b_x0 = &vx0[i]; + const block_q8_0 * restrict b_y0 = &vy0[i]; + + const block_q8_0 * restrict b_x1 = &vx1[i]; + const block_q8_0 * restrict b_y1 = &vy1[i]; + + const int8x16_t x0_l = vld1q_s8(b_x0->qs); + const int8x16_t x0_h = vld1q_s8(b_x0->qs + 16); + const int8x16_t x1_l = vld1q_s8(b_x1->qs); + const int8x16_t x1_h = vld1q_s8(b_x1->qs + 16); + + // load y + const int8x16_t y0_l = vld1q_s8(b_y0->qs); + const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16); + const int8x16_t y1_l = vld1q_s8(b_y1->qs); + const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16); + + float32_t _scale[4] = {LM_GGML_FP16_TO_FP32(b_x0->d)*LM_GGML_FP16_TO_FP32(b_y0->d), + LM_GGML_FP16_TO_FP32(b_x0->d)*LM_GGML_FP16_TO_FP32(b_y1->d), + LM_GGML_FP16_TO_FP32(b_x1->d)*LM_GGML_FP16_TO_FP32(b_y0->d), + LM_GGML_FP16_TO_FP32(b_x1->d)*LM_GGML_FP16_TO_FP32(b_y1->d)}; + float32x4_t scale = vld1q_f32(_scale); + + int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); + int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); + + int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); + int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); + + int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); + int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); + + int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); + int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); + + sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)), + l1, r1)), l2, r2)), l3, r3))), scale); + } + float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2); + float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1); + + vst1_f32(s, vget_low_f32(sumv2)); + vst1_f32(s + bs, vget_high_f32(sumv2)); + return; + } +#endif + + int ib = 0; + float sumf = 0; + +#if defined(__ARM_FEATURE_SVE) + svfloat32_t sumv0 = svdup_n_f32(0.0f); + svfloat32_t sumv1 = svdup_n_f32(0.0f); + + const int vector_length = lm_ggml_cpu_get_sve_cnt()*8; + + //VLA Implemenation for SVE + switch (vector_length) { + case 128: + { + // predicate for activating lanes for 16 Int8 elements + const svbool_t ph16 = svptrue_pat_b8 (SV_VL16); + const svbool_t pl16 = svptrue_pat_b32(SV_VL4); + + for (; ib + 1 < nb; ib += 2) { + const block_q8_0 * restrict x0 = &x[ib + 0]; + const block_q8_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; + + // load x + const svint8_t qx0_0 = svld1_s8(ph16, x0->qs); + const svint8_t qx0_1 = svld1_s8(ph16, x0->qs+16); + const svint8_t qx1_0 = svld1_s8(ph16, x1->qs); + const svint8_t qx1_1 = svld1_s8(ph16, x1->qs+16); + + // load y + const svint8_t qy0_0 = svld1_s8(ph16, y0->qs); + const svint8_t qy0_1 = svld1_s8(ph16, y0->qs+16); + const svint8_t qy1_0 = svld1_s8(ph16, y1->qs); + const svint8_t qy1_1 = svld1_s8(ph16, y1->qs+16); + + sumv0 = svmla_n_f32_x(pl16, sumv0, svcvt_f32_s32_x(pl16, svadd_x(pl16, + svdot_s32(svdup_n_s32(0), qx0_0, qy0_0), + svdot_s32(svdup_n_s32(0), qx0_1, qy0_1))), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(pl16, sumv1, svcvt_f32_s32_x(pl16, svadd_x(pl16, + svdot_s32(svdup_n_s32(0), qx1_0, qy1_0), + svdot_s32(svdup_n_s32(0), qx1_1, qy1_1))), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d)); + } + + sumf = svaddv_f32(pl16, svadd_f32_x(pl16, sumv0, sumv1)); + } break; + case 256: + { + //printf("sve256"); + for (; ib + 1 < nb; ib += 2) { + const block_q8_0 * restrict x0 = &x[ib + 0]; + const block_q8_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; + + // load x + const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs); + const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs); + + // load y + const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); + const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); + + sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), + svdot_s32(svdup_n_s32(0), qx0, qy0)), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), + svdot_s32(svdup_n_s32(0), qx1, qy1)), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d)); + } + + sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); + } break; + case 512: + { + // predicate for activating high 256 bit + const svbool_t ph32 = svptrue_pat_b8(SV_VL32); + // predicate for activating low 256 bit + const svbool_t pl32 = svnot_b_z(svptrue_b8(), ph32); + + // predicate for activating high lanes for 8 float32 elements + const svbool_t ph8 = svptrue_pat_b32(SV_VL8); + // predicate for activating low lanes for 8 float32 elements + const svbool_t pl8 = svnot_b_z(svptrue_b32(), ph8); + + svfloat32_t sumv00 = svdup_n_f32(0.0f); + + for (; ib + 1 < nb; ib += 2) { + const block_q8_0 * restrict x0 = &x[ib + 0]; + const block_q8_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; + + //load 32 int8_t in first half of vector and put another 32 int8_t in second vector lower bits + // and add them to make one 64 element vector + // load x + const svint8_t qx_32 = svld1_s8(ph32, x0->qs); + svint8_t qx_64 = svld1_s8(pl32, x0->qs + 2); + + qx_64 = svadd_s8_x(svptrue_b8(), qx_32, qx_64); + + // load y + const svint8_t qy_32 = svld1_s8(ph32, y0->qs); + svint8_t qy_64 = svld1_s8(pl32, y0->qs + 2); + + qy_64 = svadd_s8_x(svptrue_b8(), qy_32, qy_64); + + // scale creation + const float32_t deq1 = LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d); + const float32_t deq2 = LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d); + + // duplicate deq1 in first half of vector and deq2 in second half of vector + const svfloat32_t temp = svdup_f32_m(svdup_f32_z(ph8, deq1), pl8, deq2); + + const svfloat32_t sumvt = svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx_64, qy_64)); + + sumv00 = svmla_f32_m(svptrue_b32(), sumv00, sumvt, temp); + } + + sumf = svaddv_f32(svptrue_b32(), sumv00); + break; + } + default: + assert(false && "Unsupported vector length"); + break; + } +#elif defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + for (; ib + 1 < nb; ib += 2) { + const block_q8_0 * restrict x0 = &x[ib + 0]; + const block_q8_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; + + const int8x16_t x0_0 = vld1q_s8(x0->qs); + const int8x16_t x0_1 = vld1q_s8(x0->qs + 16); + const int8x16_t x1_0 = vld1q_s8(x1->qs); + const int8x16_t x1_1 = vld1q_s8(x1->qs + 16); + + // load y + const int8x16_t y0_0 = vld1q_s8(y0->qs); + const int8x16_t y0_1 = vld1q_s8(y0->qs + 16); + const int8x16_t y1_0 = vld1q_s8(y1->qs); + const int8x16_t y1_1 = vld1q_s8(y1->qs + 16); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( + lm_ggml_vdotq_s32(vdupq_n_s32(0), x0_0, y0_0), + lm_ggml_vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d)); + + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( + lm_ggml_vdotq_s32(vdupq_n_s32(0), x1_0, y1_0), + lm_ggml_vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d)); + } + + sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#elif defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (; ib < nb; ++ib) { + // Compute combined scale for the block + const __m256 d = _mm256_set1_ps(LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d)); + __m256i qx = _mm256_loadu_si256((const __m256i *)x[ib].qs); + __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs); + + const __m256 q = mul_sum_i8_pairs_float(qx, qy); + + // Multiply q with scale and accumulate + acc = _mm256_fmadd_ps( d, q, acc ); + } + + sumf = hsum_float_8(acc); +#elif defined(__AVX__) + __m256 accum = _mm256_setzero_ps(); + + for (; ib + 1 < nb; ib += 2) { + const __m128i qx_1_0 = _mm_loadu_si128((const __m128i *)x[ib].qs); + const __m128i qx_1_1 = _mm_loadu_si128((const __m128i *)x[ib].qs + 1); + const __m128i qx_2_0 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs); + const __m128i qx_2_1 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs + 1); + const __m128i qy_1_0 = _mm_loadu_si128((const __m128i *)y[ib].qs); + const __m128i qy_1_1 = _mm_loadu_si128((const __m128i *)y[ib].qs + 1); + const __m128i qy_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs); + const __m128i qy_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1); + + const __m256 p = mul_sum_i8_quad_float(qx_1_0, qx_1_1, qx_2_0, qx_2_1, qy_1_0, qy_1_1, qy_2_0, qy_2_1); + const __m256 deltas = quad_fp16_delta_float(x[ib].d, y[ib].d, x[ib + 1].d, y[ib + 1].d); + accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum); + } + + sumf = hsum_float_8(accum); +#elif defined(__riscv_v_intrinsic) + size_t vl = __riscv_vsetvl_e8m1(qk); + + for (; ib < nb; ++ib) { + // load elements + vint8m1_t bx_0 = __riscv_vle8_v_i8m1(x[ib].qs, vl); + vint8m1_t by_0 = __riscv_vle8_v_i8m1(y[ib].qs, vl); + + vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx_0, by_0, vl); + + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl); + + int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum); + + sumf += sumi*(LM_GGML_FP16_TO_FP32(x[ib].d)*LM_GGML_FP16_TO_FP32(y[ib].d)); + } +#elif defined(__POWER9_VECTOR__) + const vector signed int v0 = vec_splats((int32_t)0); + vector float vsumf0 = vec_splats(0.0f); + +#pragma GCC unroll 8 + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); + + vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[ib].d)); + vector float vyd = vec_splats(LM_GGML_FP16_TO_FP32(y[ib].d)); + vector float vd = vec_mul(vxd, vyd); + + vector signed char q8x0 = vec_xl( 0, x[ib].qs); + vector signed char q8x1 = vec_xl(16, x[ib].qs); + vector signed char q8y0 = vec_xl( 0, y[ib].qs); + vector signed char q8y1 = vec_xl(16, y[ib].qs); + + vector signed short qv0 = vec_mule(q8x0, q8y0); + vector signed short qv1 = vec_mulo(q8x0, q8y0); + vector signed short qv2 = vec_mule(q8x1, q8y1); + vector signed short qv3 = vec_mulo(q8x1, q8y1); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + + vsumi0 = vec_sum4s(qv0, vsumi0); + vsumi1 = vec_sum4s(qv1, vsumi1); + vsumi0 = vec_sum4s(qv2, vsumi0); + vsumi1 = vec_sum4s(qv3, vsumi1); + + vsumi0 = vec_add(vsumi0, vsumi1); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + } + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + sumf = vec_extract(vsumf0, 0); + +#elif defined(__loongarch_asx) + // Initialize accumulator with zeros + __m256 acc = (__m256)__lasx_xvldi(0); + + // Main loop + for (; ib < nb; ++ib) { + // Compute combined scale for the block + const __m256 d = __lasx_xvreplfr2vr_s(LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d)); + __m256i qx = __lasx_xvld((const __m256i *)x[ib].qs, 0); + __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0); + + const __m256 q = mul_sum_i8_pairs_float(qx, qy); + + // Multiply q with scale and accumulate + acc = __lasx_xvfmadd_s( d, q, acc ); + } + + sumf = hsum_float_8(acc); +#endif + for (; ib < nb; ++ib) { + int sumi = 0; + + for (int j = 0; j < qk; j++) { + sumi += x[ib].qs[j]*y[ib].qs[j]; + } + + sumf += sumi*(LM_GGML_FP16_TO_FP32(x[ib].d)*LM_GGML_FP16_TO_FP32(y[ib].d)); + } + + *s = sumf; +} + +void lm_ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_tq1_0 * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#if defined(__ARM_NEON) + float sumf = 0.0f; + + uint8_t k_shift[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27}; + + const uint8x16_t shift = vld1q_u8(k_shift); + + for (int i = 0; i < nb; ++i) { +#if defined(__ARM_FEATURE_DOTPROD) + int32x4_t sumi0 = vdupq_n_s32(0); + int32x4_t sumi1 = vdupq_n_s32(0); +#else + int16x8_t sumi0 = vdupq_n_s16(0); + int16x8_t sumi1 = vdupq_n_s16(0); +#endif + + // first 32 bytes of 5 elements + { + uint8x16_t qx0 = vld1q_u8(x[i].qs + 0); + uint8x16_t qx1 = vld1q_u8(x[i].qs + 16); + uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(3)); + uint8x16_t qx3 = vmulq_u8(qx1, vdupq_n_u8(3)); + uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(9)); + uint8x16_t qx5 = vmulq_u8(qx1, vdupq_n_u8(9)); + uint8x16_t qx6 = vmulq_u8(qx0, vdupq_n_u8(27)); + uint8x16_t qx7 = vmulq_u8(qx1, vdupq_n_u8(27)); + uint8x16_t qx8 = vmulq_u8(qx0, vdupq_n_u8(81)); + uint8x16_t qx9 = vmulq_u8(qx1, vdupq_n_u8(81)); + + // multiply by 3 and keep the 2 bits above 8 bits + int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6)); + int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6)); + int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6)); + int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6)); + int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6)); + int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6)); + int8x16_t sqx6 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx6, vshrq_n_u8(qx6, 1)), 6)); + int8x16_t sqx7 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx7, vshrq_n_u8(qx7, 1)), 6)); + int8x16_t sqx8 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx8, vshrq_n_u8(qx8, 1)), 6)); + int8x16_t sqx9 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx9, vshrq_n_u8(qx9, 1)), 6)); + + const int8x16_t qy0 = vld1q_s8(y[i].qs + 0); + const int8x16_t qy1 = vld1q_s8(y[i].qs + 16); + const int8x16_t qy2 = vld1q_s8(y[i].qs + 32); + const int8x16_t qy3 = vld1q_s8(y[i].qs + 48); + const int8x16_t qy4 = vld1q_s8(y[i].qs + 64); + const int8x16_t qy5 = vld1q_s8(y[i].qs + 80); + const int8x16_t qy6 = vld1q_s8(y[i].qs + 96); + const int8x16_t qy7 = vld1q_s8(y[i].qs + 112); + const int8x16_t qy8 = vld1q_s8(y[i].qs + 128); + const int8x16_t qy9 = vld1q_s8(y[i].qs + 144); + +#if defined(__ARM_FEATURE_DOTPROD) + sumi0 = vdotq_s32(sumi0, sqx0, qy0); + sumi1 = vdotq_s32(sumi1, sqx1, qy1); + sumi0 = vdotq_s32(sumi0, sqx2, qy2); + sumi1 = vdotq_s32(sumi1, sqx3, qy3); + sumi0 = vdotq_s32(sumi0, sqx4, qy4); + sumi1 = vdotq_s32(sumi1, sqx5, qy5); + sumi0 = vdotq_s32(sumi0, sqx6, qy6); + sumi1 = vdotq_s32(sumi1, sqx7, qy7); + sumi0 = vdotq_s32(sumi0, sqx8, qy8); + sumi1 = vdotq_s32(sumi1, sqx9, qy9); +#else + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx8), vget_low_s8(qy8)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx8), vget_high_s8(qy8)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx9), vget_low_s8(qy9)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx9), vget_high_s8(qy9)); +#endif + } + + // last 16 bytes of 5-element, along with the 4 bytes of 4 elements + { + uint8x16_t qx0 = vld1q_u8(x[i].qs + 32); + uint8x16_t qx1 = vmulq_u8(qx0, vdupq_n_u8(3)); + uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(9)); + uint8x16_t qx3 = vmulq_u8(qx0, vdupq_n_u8(27)); + uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(81)); + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned + uint8x16_t qx5 = vreinterpretq_u8_u32(vdupq_n_u32(qh)); + qx5 = vmulq_u8(qx5, shift); + + // multiply by 3 and keep the 2 bits above 8 bits + int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6)); + int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6)); + int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6)); + int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6)); + int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6)); + int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6)); + + const int8x16_t qy0 = vld1q_s8(y[i].qs + 160); + const int8x16_t qy1 = vld1q_s8(y[i].qs + 176); + const int8x16_t qy2 = vld1q_s8(y[i].qs + 192); + const int8x16_t qy3 = vld1q_s8(y[i].qs + 208); + const int8x16_t qy4 = vld1q_s8(y[i].qs + 224); + const int8x16_t qy5 = vld1q_s8(y[i].qs + 240); + +#if defined(__ARM_FEATURE_DOTPROD) + sumi0 = vdotq_s32(sumi0, sqx0, qy0); + sumi1 = vdotq_s32(sumi1, sqx1, qy1); + sumi0 = vdotq_s32(sumi0, sqx2, qy2); + sumi1 = vdotq_s32(sumi1, sqx3, qy3); + sumi0 = vdotq_s32(sumi0, sqx4, qy4); + sumi1 = vdotq_s32(sumi1, sqx5, qy5); +#else + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5)); +#endif + } + + const int16x8_t ysum0 = vld1q_s16(y[i].bsums); + const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8); + + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + +#if defined(__ARM_FEATURE_DOTPROD) + sumi0 = vaddq_s32(sumi0, sumi1); + sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1))); + + sumf += d * (float) vaddvq_s32(sumi0); +#else + sumi0 = vaddq_s16(sumi0, sumi1); + sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1)); + + sumf += d * (float) vaddlvq_s16(sumi0); +#endif + } + + *s = sumf; + +#elif defined(__AVX2__) + __m256 sumf = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + // 16-bit sums + __m256i sumi0 = _mm256_setzero_si256(); + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + + // first 32 bytes of 5 elements + { + __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs)); + // 8-bit multiplies with shifts, masks and adds + __m256i qx1 = _mm256_add_epi8(qx0, _mm256_add_epi8(qx0, qx0)); // 1 * 3 + __m256i qx2 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx0, 3), _mm256_set1_epi8(-8)), qx0); // 1 * 9 + __m256i qx3 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx1, 3), _mm256_set1_epi8(-8)), qx1); // 3 * 9 + __m256i qx4 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx2, 3), _mm256_set1_epi8(-8)), qx2); // 9 * 9 + + // TODO: can _mm256_mulhi_epu16 be faster even if 16-bits? + + // Cancel the +1 from avg so that it behaves like a halving add + qx0 = _mm256_subs_epu8(qx0, _mm256_set1_epi8(1)); + qx1 = _mm256_subs_epu8(qx1, _mm256_set1_epi8(1)); + qx2 = _mm256_subs_epu8(qx2, _mm256_set1_epi8(1)); + qx3 = _mm256_subs_epu8(qx3, _mm256_set1_epi8(1)); + qx4 = _mm256_subs_epu8(qx4, _mm256_set1_epi8(1)); + // Multiply by 3 and get the top 2 bits + qx0 = _mm256_avg_epu8(qx0, _mm256_avg_epu8(qx0, _mm256_setzero_si256())); + qx1 = _mm256_avg_epu8(qx1, _mm256_avg_epu8(qx1, _mm256_setzero_si256())); + qx2 = _mm256_avg_epu8(qx2, _mm256_avg_epu8(qx2, _mm256_setzero_si256())); + qx3 = _mm256_avg_epu8(qx3, _mm256_avg_epu8(qx3, _mm256_setzero_si256())); + qx4 = _mm256_avg_epu8(qx4, _mm256_avg_epu8(qx4, _mm256_setzero_si256())); + qx0 = _mm256_and_si256(_mm256_srli_epi16(qx0, 6), _mm256_set1_epi8(3)); + qx1 = _mm256_and_si256(_mm256_srli_epi16(qx1, 6), _mm256_set1_epi8(3)); + qx2 = _mm256_and_si256(_mm256_srli_epi16(qx2, 6), _mm256_set1_epi8(3)); + qx3 = _mm256_and_si256(_mm256_srli_epi16(qx3, 6), _mm256_set1_epi8(3)); + qx4 = _mm256_and_si256(_mm256_srli_epi16(qx4, 6), _mm256_set1_epi8(3)); + + const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 0)); + const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 32)); + const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 64)); + const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 96)); + const __m256i qy4 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 128)); + + qx0 = _mm256_maddubs_epi16(qx0, qy0); + qx1 = _mm256_maddubs_epi16(qx1, qy1); + qx2 = _mm256_maddubs_epi16(qx2, qy2); + qx3 = _mm256_maddubs_epi16(qx3, qy3); + qx4 = _mm256_maddubs_epi16(qx4, qy4); + + sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1)); + sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3)); + sumi2 = _mm256_add_epi16(sumi2, qx4); + } + + // last 16 bytes of 5-element, along with the 4 bytes of 4 elements + { + __m128i qx0 = _mm_loadu_si128((const __m128i *) (x[i].qs + 32)); + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned + __m256i qx5_l = _mm256_cvtepu8_epi16(_mm_set1_epi32(qh)); + __m128i qx1 = _mm_add_epi8(qx0, _mm_add_epi8(qx0, qx0)); // 1 * 3 + __m128i qx2 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx0, 3), _mm_set1_epi8(-8)), qx0); // 1 * 9 + __m128i qx3 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx1, 3), _mm_set1_epi8(-8)), qx1); // 3 * 9 + __m128i qx4 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx2, 3), _mm_set1_epi8(-8)), qx2); // 9 * 9 + __m256i qx01 = MM256_SET_M128I(qx1, qx0); + __m256i qx23 = MM256_SET_M128I(qx3, qx2); + + // avx2 does not have 8-bit multiplies, so 16-bit it is. + qx5_l = _mm256_mullo_epi16(qx5_l, _mm256_set_epi16(27, 27, 27, 27, 9, 9, 9, 9, 3, 3, 3, 3, 1, 1, 1, 1)); + qx5_l = _mm256_and_si256(qx5_l, _mm256_set1_epi16(0xFF)); + __m128i qx5 = _mm_packus_epi16(_mm256_castsi256_si128(qx5_l), _mm256_extracti128_si256(qx5_l, 1)); + + __m256i qx45 = MM256_SET_M128I(qx5, qx4); + + // Cancel the +1 from avg so that it behaves like a halving add + qx01 = _mm256_subs_epu8(qx01, _mm256_set1_epi8(1)); + qx23 = _mm256_subs_epu8(qx23, _mm256_set1_epi8(1)); + qx45 = _mm256_subs_epu8(qx45, _mm256_set1_epi8(1)); + // Multiply by 3 and get the top 2 bits + qx01 = _mm256_avg_epu8(qx01, _mm256_avg_epu8(qx01, _mm256_setzero_si256())); + qx23 = _mm256_avg_epu8(qx23, _mm256_avg_epu8(qx23, _mm256_setzero_si256())); + qx45 = _mm256_avg_epu8(qx45, _mm256_avg_epu8(qx45, _mm256_setzero_si256())); + qx01 = _mm256_and_si256(_mm256_srli_epi16(qx01, 6), _mm256_set1_epi8(3)); + qx23 = _mm256_and_si256(_mm256_srli_epi16(qx23, 6), _mm256_set1_epi8(3)); + qx45 = _mm256_and_si256(_mm256_srli_epi16(qx45, 6), _mm256_set1_epi8(3)); + + const __m256i qy01 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 160)); + const __m256i qy23 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 192)); + const __m256i qy45 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 224)); + + qx01 = _mm256_maddubs_epi16(qx01, qy01); + qx23 = _mm256_maddubs_epi16(qx23, qy23); + qx45 = _mm256_maddubs_epi16(qx45, qy45); + + sumi0 = _mm256_add_epi16(sumi0, qx01); + sumi1 = _mm256_add_epi16(sumi1, qx23); + sumi2 = _mm256_add_epi16(sumi2, qx45); + } + + const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums); + const __m256 d = _mm256_set1_ps(y[i].d * LM_GGML_FP16_TO_FP32(x[i].d)); + + sumi0 = _mm256_sub_epi16(sumi0, ysum); + sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(sumi1, sumi2)); + sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1)); + + sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf); + } + + *s = hsum_float_8(sumf); + +#else + const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243}; + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + int sum = 0; + + for (size_t j = 0; j < sizeof(x->qs) - sizeof(x->qs) % 32; j += 32) { + for (size_t l = 0; l < 5; ++l) { + for (size_t m = 0; m < 32; ++m) { + uint8_t q = x[i].qs[j + m] * pow3[l]; + uint16_t xi = ((uint16_t) q * 3) >> 8; + sum += (xi - 1) * y[i].qs[j*5 + l*32 + m]; + } + } + } + for (size_t j = sizeof(x->qs) - sizeof(x->qs) % 32; j < sizeof(x->qs); j += 16) { + for (size_t l = 0; l < 5; ++l) { + for (size_t m = 0; m < 16; ++m) { + uint8_t q = x[i].qs[j + m] * pow3[l]; + uint16_t xi = ((uint16_t) q * 3) >> 8; + sum += (xi - 1) * y[i].qs[j*5 + l*16 + m]; + } + } + } + + for (size_t l = 0; l < 4; ++l) { + for (size_t j = 0; j < sizeof(x->qh); ++j) { + uint8_t q = x[i].qh[j] * pow3[l]; + uint16_t xi = ((uint16_t) q * 3) >> 8; + sum += (xi - 1) * y[i].qs[sizeof(x->qs)*5 + l*sizeof(x->qh) + j]; + } + } + + sumf += (float) sum * (LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d); + } + + *s = sumf; +#endif +} + +void lm_ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_tq2_0 * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#if defined(__ARM_NEON) + float sumf = 0.0f; + + const uint8x16_t m3 = vdupq_n_u8(3); + + for (int i = 0; i < nb; ++i) { +#if defined(__ARM_FEATURE_DOTPROD) + int32x4_t sumi0 = vdupq_n_s32(0); + int32x4_t sumi1 = vdupq_n_s32(0); +#else + int16x8_t sumi0 = vdupq_n_s16(0); + int16x8_t sumi1 = vdupq_n_s16(0); +#endif + + for (size_t j = 0; j < sizeof(x->qs); j += 32) { + uint8x16_t qx0 = vld1q_u8(x[i].qs + j); + uint8x16_t qx1 = vld1q_u8(x[i].qs + j + 16); + uint8x16_t qx2 = vshrq_n_u8(qx0, 2); + uint8x16_t qx3 = vshrq_n_u8(qx1, 2); + uint8x16_t qx4 = vshrq_n_u8(qx0, 4); + uint8x16_t qx5 = vshrq_n_u8(qx1, 4); + uint8x16_t qx6 = vshrq_n_u8(qx0, 6); + uint8x16_t qx7 = vshrq_n_u8(qx1, 6); + + int8x16_t sqx0 = vreinterpretq_s8_u8(vandq_u8(qx0, m3)); + int8x16_t sqx1 = vreinterpretq_s8_u8(vandq_u8(qx1, m3)); + int8x16_t sqx2 = vreinterpretq_s8_u8(vandq_u8(qx2, m3)); + int8x16_t sqx3 = vreinterpretq_s8_u8(vandq_u8(qx3, m3)); + int8x16_t sqx4 = vreinterpretq_s8_u8(vandq_u8(qx4, m3)); + int8x16_t sqx5 = vreinterpretq_s8_u8(vandq_u8(qx5, m3)); + int8x16_t sqx6 = vreinterpretq_s8_u8(vandq_u8(qx6, m3)); + int8x16_t sqx7 = vreinterpretq_s8_u8(vandq_u8(qx7, m3)); + + const int8x16_t qy0 = vld1q_s8(y[i].qs + j*4 + 0); + const int8x16_t qy1 = vld1q_s8(y[i].qs + j*4 + 16); + const int8x16_t qy2 = vld1q_s8(y[i].qs + j*4 + 32); + const int8x16_t qy3 = vld1q_s8(y[i].qs + j*4 + 48); + const int8x16_t qy4 = vld1q_s8(y[i].qs + j*4 + 64); + const int8x16_t qy5 = vld1q_s8(y[i].qs + j*4 + 80); + const int8x16_t qy6 = vld1q_s8(y[i].qs + j*4 + 96); + const int8x16_t qy7 = vld1q_s8(y[i].qs + j*4 + 112); + +#if defined(__ARM_FEATURE_DOTPROD) + sumi0 = vdotq_s32(sumi0, sqx0, qy0); + sumi1 = vdotq_s32(sumi1, sqx1, qy1); + sumi0 = vdotq_s32(sumi0, sqx2, qy2); + sumi1 = vdotq_s32(sumi1, sqx3, qy3); + sumi0 = vdotq_s32(sumi0, sqx4, qy4); + sumi1 = vdotq_s32(sumi1, sqx5, qy5); + sumi0 = vdotq_s32(sumi0, sqx6, qy6); + sumi1 = vdotq_s32(sumi1, sqx7, qy7); +#else + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7)); +#endif + } + + const int16x8_t ysum0 = vld1q_s16(y[i].bsums); + const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8); + + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + +#if defined(__ARM_FEATURE_DOTPROD) + sumi0 = vaddq_s32(sumi0, sumi1); + sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1))); + + sumf += d * (float) vaddvq_s32(sumi0); +#else + sumi0 = vaddq_s16(sumi0, sumi1); + sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1)); + + sumf += d * (float) vaddlvq_s16(sumi0); +#endif + } + + *s = sumf; + +#elif defined(__AVX2__) + __m256 sumf = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + // 16-bit sums, because 256*127 still fits + __m256i sumi0 = _mm256_setzero_si256(); + __m256i sumi1 = _mm256_setzero_si256(); + + for (size_t j = 0; j < sizeof(x->qs); j += 32) { + __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs + j)); + __m256i qx1 = _mm256_srli_epi16(qx0, 2); + __m256i qx2 = _mm256_srli_epi16(qx0, 4); + __m256i qx3 = _mm256_srli_epi16(qx0, 6); + + // 0, 1, 2 (should not be 3) + qx0 = _mm256_and_si256(qx0, _mm256_set1_epi8(3)); + qx1 = _mm256_and_si256(qx1, _mm256_set1_epi8(3)); + qx2 = _mm256_and_si256(qx2, _mm256_set1_epi8(3)); + qx3 = _mm256_and_si256(qx3, _mm256_set1_epi8(3)); + + const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 0)); + const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 32)); + const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 64)); + const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 96)); + + qx0 = _mm256_maddubs_epi16(qx0, qy0); + qx1 = _mm256_maddubs_epi16(qx1, qy1); + qx2 = _mm256_maddubs_epi16(qx2, qy2); + qx3 = _mm256_maddubs_epi16(qx3, qy3); + + sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1)); + sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3)); + } + + const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums); + const __m256 d = _mm256_set1_ps(y[i].d * LM_GGML_FP16_TO_FP32(x[i].d)); + + sumi0 = _mm256_add_epi16(sumi0, sumi1); + sumi0 = _mm256_sub_epi16(sumi0, ysum); + sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1)); + + sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf); + } + + *s = hsum_float_8(sumf); + +#else + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + int32_t sumi = 0; + + for (size_t j = 0; j < sizeof(x->qs); j += 32) { + for (size_t l = 0; l < 4; ++l) { + for (size_t k = 0; k < 32; ++k) { + sumi += y[i].qs[j*4 + l*32 + k] * (((x[i].qs[j + k] >> (l*2)) & 3) - 1); + } + } + } + + const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); + + sumf += (float) sumi * d; + } + + *s = sumf; +#endif +} + +void lm_ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q2_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + const uint8x16_t m3 = vdupq_n_u8(0x3); + const uint8x16_t m4 = vdupq_n_u8(0xF); + + const int32x4_t vzero = vdupq_n_s32(0); + + lm_ggml_int8x16x2_t q2bytes; + uint8_t aux[16]; + + float sum = 0; + + for (int i = 0; i < nb; ++i) { + const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + const uint8_t * restrict sc = x[i].scales; + + const uint8x16_t mins_and_scales = vld1q_u8(sc); + const uint8x16_t scales = vandq_u8(mins_and_scales, m4); + vst1q_u8(aux, scales); + + const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4); + const lm_ggml_int16x8x2_t q8sums = lm_ggml_vld1q_s16_x2(y[i].bsums); + const lm_ggml_int16x8x2_t mins16 = {{vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}}; + const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])), + vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0]))); + const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])), + vmull_s16(vget_high_s16(mins16.val[1]), vget_high_s16(q8sums.val[1]))); + sum += dmin * vaddvq_s32(vaddq_s32(s0, s1)); + + int isum = 0; + int is = 0; + +// We use this macro instead of a function call because for some reason +// the code runs 2-3% slower, even if the function is declared inline +#define MULTIPLY_ACCUM_WITH_SCALE(index)\ + isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\ + isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)]; + +#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\ + q8bytes = lm_ggml_vld1q_s8_x2(q8); q8 += 32;\ + q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\ + q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\ + MULTIPLY_ACCUM_WITH_SCALE((index)); + + for (int j = 0; j < QK_K/128; ++j) { + const lm_ggml_uint8x16x2_t q2bits = lm_ggml_vld1q_u8_x2(q2); q2 += 32; + + lm_ggml_int8x16x2_t q8bytes = lm_ggml_vld1q_s8_x2(q8); q8 += 32; + q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3)); + q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3)); + + MULTIPLY_ACCUM_WITH_SCALE(0); + + SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2); + SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4); + SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6); + + is += 8; + } + + sum += d * isum; + } + + *s = sum; + +#elif defined __AVX2__ + + const __m256i m3 = _mm256_set1_epi8(3); + const __m128i m4 = _mm_set1_epi8(0xF); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales); + const __m128i scales8 = _mm_and_si128(mins_and_scales, m4); + const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4); + const __m256i mins = _mm256_cvtepi8_epi16(mins8); + const __m256i prod = _mm256_madd_epi16(mins, _mm256_loadu_si256((const __m256i*)y[i].bsums)); + + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc); + + const __m256i all_scales = _mm256_cvtepi8_epi16(scales8); + const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0); + const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1); + const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)}; + + __m256i sumi = _mm256_setzero_si256(); + + for (int j = 0; j < QK_K/128; ++j) { + + const __m256i q2bits = _mm256_loadu_si256((const __m256i*)q2); q2 += 32; + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + const __m256i q2_0 = _mm256_and_si256(q2bits, m3); + const __m256i q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3); + const __m256i q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3); + const __m256i q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3); + + __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0); + __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1); + __m256i p2 = _mm256_maddubs_epi16(q2_2, q8_2); + __m256i p3 = _mm256_maddubs_epi16(q2_3, q8_3); + + p0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(0)), p0); + p1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(1)), p1); + p2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(2)), p2); + p3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(3)), p3); + + p0 = _mm256_add_epi32(p0, p1); + p2 = _mm256_add_epi32(p2, p3); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2)); + } + + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + + } + + *s = hsum_float_8(acc); + +#elif defined __AVX__ + + const __m128i m3 = _mm_set1_epi8(0x3); + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i m2 = _mm_set1_epi8(0x2); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float dall = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + // load mins and scales from block_q2_K.scales[QK_K/16] + const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales); + const __m128i scales16 = _mm_and_si128(mins_and_scales, m4); + const __m128i mins16 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4); + const __m128i mins_0 = _mm_cvtepi8_epi16(mins16); + const __m128i mins_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(mins16, mins16)); + + // summs = y[i].bsums * (x[i].scales >> 4) in 16bits*8*2 to 32bits*4*2 + const __m128i summs_0 = _mm_madd_epi16(mins_0, _mm_loadu_si128((const __m128i*)&y[i].bsums[0])); + const __m128i summs_1 = _mm_madd_epi16(mins_1, _mm_loadu_si128((const __m128i*)&y[i].bsums[8])); + + // sumf += -dmin * summs in 32bits*8 + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(MM256_SET_M128I(summs_1, summs_0))), acc); + + const __m128i scales_0 = _mm_cvtepi8_epi16(scales16); + const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales16, scales16)); + const __m128i scales[2] = { scales_0, scales_1 }; + + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + for (int j = 0; j < QK_K/128; ++j) { + + // load Q8 quants int8*16*8 from block_q8_K.qs[QK_K] + const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + + // load 2bits*16*8 from block_q2_K.qs[QK_K/4] + __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16; + const __m128i q2_0 = _mm_and_si128(q2bits, m3); + const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3); + const __m128i q2_4 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3); + const __m128i q2_6 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3); + q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16; + const __m128i q2_1 = _mm_and_si128(q2bits, m3); + const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3); + const __m128i q2_5 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3); + const __m128i q2_7 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3); + + // isuml = q8[l] * ((q2[l] >> shift) & 3) in 8bits*16*8 to 16bits*8*8 + __m128i p0 = _mm_maddubs_epi16(q2_0, q8_0); + __m128i p1 = _mm_maddubs_epi16(q2_1, q8_1); + __m128i p2 = _mm_maddubs_epi16(q2_2, q8_2); + __m128i p3 = _mm_maddubs_epi16(q2_3, q8_3); + __m128i p4 = _mm_maddubs_epi16(q2_4, q8_4); + __m128i p5 = _mm_maddubs_epi16(q2_5, q8_5); + __m128i p6 = _mm_maddubs_epi16(q2_6, q8_6); + __m128i p7 = _mm_maddubs_epi16(q2_7, q8_7); + + // isum += (x[i].scales[is++] & 0xF) * isuml in 16bits*8*8 to 32bits*4*8 + __m128i shuffle = _mm_set1_epi16(0x0100); + p0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p0); + shuffle = _mm_add_epi16(shuffle, m2); + p1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p1); + shuffle = _mm_add_epi16(shuffle, m2); + p2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p2); + shuffle = _mm_add_epi16(shuffle, m2); + p3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p3); + shuffle = _mm_add_epi16(shuffle, m2); + p4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p4); + shuffle = _mm_add_epi16(shuffle, m2); + p5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p5); + shuffle = _mm_add_epi16(shuffle, m2); + p6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p6); + shuffle = _mm_add_epi16(shuffle, m2); + p7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p7); + + p0 = _mm_add_epi32(p0, p1); + p2 = _mm_add_epi32(p2, p3); + p4 = _mm_add_epi32(p4, p5); + p6 = _mm_add_epi32(p6, p7); + + // isum in 32bits*4*2 + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p0, p2)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p4, p6)); + } + + // sumf += dall * isum - dmin * summs in 32bits + __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dall), _mm256_cvtepi32_ps(sumi)), acc); + } + + *s = hsum_float_8(acc); + +#elif defined __riscv_v_intrinsic + + float sumf = 0; + uint8_t temp_01[32] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + + const float dall = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin); + + size_t vl = 16; + + vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl); + vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl); + + vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl); + + vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl); + vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl); + vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); + vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl); + vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + + sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums); + + vl = 32; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl); + + uint8_t is=0; + int isum=0; + + for (int j = 0; j < QK_K/128; ++j) { + // load Q2 + vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl); + + vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl); + vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03 , vl); + vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03 , vl); + vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03 , vl); + + // duplicate scale elements for product + vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0+is, vl), vl); + vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2+is, vl), vl); + vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4+is, vl), vl); + vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6+is, vl), vl); + + vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl)); + vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl)); + vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl)); + vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl)); + + // load Q8 + vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); + vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8+64, vl); + vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8+96, vl); + + vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl); + vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl); + vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl); + vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl); + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl); + + isum += __riscv_vmv_x_s_i32m1_i32(isum1); + + q2+=32; q8+=128; is=8; + + } + + sumf += dall * isum; + + } + + *s = sumf; + +#elif defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0x3); + const vector signed char lowScaleMask = vec_splats((signed char)0xF); + const vector int v0 = vec_splats((int32_t)0); + const vector unsigned char v2 = vec_splats((unsigned char)0x2); + const vector unsigned char v6 = vec_splats((unsigned char)0x6); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector float vxmin = vec_splats(LM_GGML_FP16_TO_FP32(x[i].dmin)); + vector float vdmin = vec_mul(vxmin, vyd); + + vector signed short q8ysums0 = vec_xl( 0, y[i].bsums); + vector signed short q8ysums1 = vec_xl(16, y[i].bsums); + + vector signed char q2xmins = (vector signed char)vec_xl( 0, x[i].scales); + vector signed char vscales = vec_and(q2xmins, lowScaleMask); + + q2xmins = vec_sr(q2xmins, v4); + vector signed short q2xmins0 = vec_unpackh(q2xmins); + vector signed short q2xmins1 = vec_unpackl(q2xmins); + + vector signed int prod0 = vec_mule(q2xmins0, q8ysums0); + vector signed int prod1 = vec_mulo(q2xmins0, q8ysums0); + vector signed int prod2 = vec_mule(q2xmins1, q8ysums1); + vector signed int prod3 = vec_mulo(q2xmins1, q8ysums1); + + vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0); + vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1); + vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2); + vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + vector signed int vsumi4 = v0; + vector signed int vsumi5 = v0; + vector signed int vsumi6 = v0; + vector signed int vsumi7 = v0; + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + for (int j = 0; j < QK_K/128; ++j) { + __builtin_prefetch(q2, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed char qxs0 = (vector signed char)vec_xl( 0, q2); + vector signed char qxs1 = (vector signed char)vec_xl(16, q2); + q2 += 32; + + vector unsigned char q2x00 = (vector unsigned char)vec_and(qxs0, lowMask); + vector unsigned char q2x01 = (vector unsigned char)vec_and(vec_sr(qxs0, v2), lowMask); + vector unsigned char q2x02 = (vector unsigned char)vec_and(vec_sr(qxs0, v4), lowMask); + vector unsigned char q2x03 = (vector unsigned char)vec_and(vec_sr(qxs0, v6), lowMask); + vector unsigned char q2x10 = (vector unsigned char)vec_and(qxs1, lowMask); + vector unsigned char q2x11 = (vector unsigned char)vec_and(vec_sr(qxs1, v2), lowMask); + vector unsigned char q2x12 = (vector unsigned char)vec_and(vec_sr(qxs1, v4), lowMask); + vector unsigned char q2x13 = (vector unsigned char)vec_and(vec_sr(qxs1, v6), lowMask); + + vector signed char q8y00 = vec_xl( 0, q8); + vector signed char q8y10 = vec_xl( 16, q8); + vector signed char q8y01 = vec_xl( 32, q8); + vector signed char q8y11 = vec_xl( 48, q8); + vector signed char q8y02 = vec_xl( 64, q8); + vector signed char q8y12 = vec_xl( 80, q8); + vector signed char q8y03 = vec_xl( 96, q8); + vector signed char q8y13 = vec_xl(112, q8); + q8 += 128; + + vector signed int qv0 = vec_msum(q8y00, q2x00, v0); + vector signed int qv1 = vec_msum(q8y01, q2x01, v0); + vector signed int qv2 = vec_msum(q8y02, q2x02, v0); + vector signed int qv3 = vec_msum(q8y03, q2x03, v0); + vector signed int qv4 = vec_msum(q8y10, q2x10, v0); + vector signed int qv5 = vec_msum(q8y11, q2x11, v0); + vector signed int qv6 = vec_msum(q8y12, q2x12, v0); + vector signed int qv7 = vec_msum(q8y13, q2x13, v0); + + vector signed short vscales_07 = vec_unpackh(vscales); + vector signed int vscales_03 = vec_unpackh(vscales_07); + vector signed int vscales_47 = vec_unpackl(vscales_07); + vector signed int vs0 = vec_splat(vscales_03, 0); + vector signed int vs1 = vec_splat(vscales_03, 1); + vector signed int vs2 = vec_splat(vscales_03, 2); + vector signed int vs3 = vec_splat(vscales_03, 3); + vector signed int vs4 = vec_splat(vscales_47, 0); + vector signed int vs5 = vec_splat(vscales_47, 1); + vector signed int vs6 = vec_splat(vscales_47, 2); + vector signed int vs7 = vec_splat(vscales_47, 3); + vscales = vec_sld(vscales, vscales, 8); + + vsumi0 = vec_add(vec_mul(qv0, vs0), vsumi0); + vsumi1 = vec_add(vec_mul(qv1, vs2), vsumi1); + vsumi2 = vec_add(vec_mul(qv2, vs4), vsumi2); + vsumi3 = vec_add(vec_mul(qv3, vs6), vsumi3); + vsumi4 = vec_add(vec_mul(qv4, vs1), vsumi4); + vsumi5 = vec_add(vec_mul(qv5, vs3), vsumi5); + vsumi6 = vec_add(vec_mul(qv6, vs5), vsumi6); + vsumi7 = vec_add(vec_mul(qv7, vs7), vsumi7); + } + + vsumi0 = vec_add(vsumi0, vsumi4); + vsumi1 = vec_add(vsumi1, vsumi5); + vsumi2 = vec_add(vsumi2, vsumi6); + vsumi3 = vec_add(vsumi3, vsumi7); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + +#elif defined __loongarch_asx + + const __m256i m3 = __lasx_xvreplgr2vr_b(3); + const __m128i m4 = __lsx_vreplgr2vr_b(0xF); + + __m256 acc = (__m256)__lasx_xvldi(0); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const __m128i mins_and_scales = __lsx_vld((const __m128i*)x[i].scales, 0); + const __m128i scales8 = __lsx_vand_v(mins_and_scales, m4); + const __m128i mins8 = __lsx_vand_v(__lsx_vsrli_h(mins_and_scales, 4), m4); + const __m256i mins = lasx_ext8_16(mins8); + const __m256i prod = lasx_madd_h(mins, __lasx_xvld((const __m256i*)y[i].bsums, 0)); + + acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(dmin), __lasx_xvffint_s_w(prod), acc); + + const __m256i all_scales = lasx_ext8_16(scales8); + const __m128i l_scales = lasx_extracti128(all_scales, 0); + const __m128i h_scales = lasx_extracti128(all_scales, 1); + const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)}; + + __m256i sumi = __lasx_xvldi(0); + + for (int j = 0; j < QK_K/128; ++j) { + + const __m256i q2bits = __lasx_xvld((const __m256i*)q2, 0); q2 += 32; + + const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + + const __m256i q2_0 = __lasx_xvand_v(q2bits, m3); + const __m256i q2_1 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 2), m3); + const __m256i q2_2 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 4), m3); + const __m256i q2_3 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 6), m3); + + __m256i p0 = lasx_maddubs_h(q2_0, q8_0); + __m256i p1 = lasx_maddubs_h(q2_1, q8_1); + __m256i p2 = lasx_maddubs_h(q2_2, q8_2); + __m256i p3 = lasx_maddubs_h(q2_3, q8_3); + + p0 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(0)), p0); + p1 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(1)), p1); + p2 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(2)), p2); + p3 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(3)), p3); + + p0 = __lasx_xvadd_w(p0, p1); + p2 = __lasx_xvadd_w(p2, p3); + + sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p0, p2)); + } + + acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc); + + } + + *s = hsum_float_8(acc); + +#else + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + + int summs = 0; + for (int j = 0; j < 16; ++j) { + summs += y[i].bsums[j] * (sc[j] >> 4); + } + + const float dall = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin); + + int isum = 0; + int is = 0; + int d; + for (int k = 0; k < QK_K/128; ++k) { + int shift = 0; + for (int j = 0; j < 4; ++j) { + d = sc[is++] & 0xF; + int isuml = 0; + for (int l = 0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); + isum += d * isuml; + d = sc[is++] & 0xF; + isuml = 0; + for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); + isum += d * isuml; + shift += 2; + q8 += 32; + } + q2 += 32; + } + sumf += dall * isum - dmin * summs; + } + *s = sumf; +#endif +} + +void lm_ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + + uint32_t aux[3]; + uint32_t utmp[4]; + + const uint8x16_t m3b = vdupq_n_u8(0x3); + const int32x4_t vzero = vdupq_n_s32(0); + + const uint8x16_t m0 = vdupq_n_u8(1); + const uint8x16_t m1 = vshlq_n_u8(m0, 1); + const uint8x16_t m2 = vshlq_n_u8(m0, 2); + const uint8x16_t m3 = vshlq_n_u8(m0, 3); + const int8_t m32 = 32; + + lm_ggml_int8x16x4_t q3bytes; + + float sum = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict qh = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + + lm_ggml_uint8x16x2_t qhbits = lm_ggml_vld1q_u8_x2(qh); + + lm_ggml_uint8x16x4_t q3h; + + int32_t isum = 0; + + // Set up scales + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + int8_t * scale = (int8_t *)utmp; + for (int j = 0; j < 16; ++j) scale[j] -= m32; + + for (int j = 0; j < QK_K/128; ++j) { + + const lm_ggml_uint8x16x2_t q3bits = lm_ggml_vld1q_u8_x2(q3); q3 += 32; + const lm_ggml_int8x16x4_t q8bytes_1 = lm_ggml_vld1q_s8_x4(q8); q8 += 64; + const lm_ggml_int8x16x4_t q8bytes_2 = lm_ggml_vld1q_s8_x4(q8); q8 += 64; + + q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2); + q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2); + q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1); + q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1); + + q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), vreinterpretq_s8_u8(q3h.val[0])); + q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), vreinterpretq_s8_u8(q3h.val[1])); + q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2])); + q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3])); + + isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0]; + isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1]; + isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2]; + isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3]; + + scale += 4; + + q3h.val[0] = vbicq_u8(m2, qhbits.val[0]); + q3h.val[1] = vbicq_u8(m2, qhbits.val[1]); + q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1); + q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1); + + q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), vreinterpretq_s8_u8(q3h.val[0])); + q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), vreinterpretq_s8_u8(q3h.val[1])); + q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2])); + q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3])); + + isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0]; + isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1]; + isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2]; + isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3]; + + scale += 4; + + if (j == 0) { + qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4); + qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4); + } + + } + sum += d * isum; + + } + + *s = sum; + +#elif defined __AVX2__ + + const __m256i m3 = _mm256_set1_epi8(3); + const __m256i mone = _mm256_set1_epi8(1); + const __m128i m32 = _mm_set1_epi8(32); + + __m256 acc = _mm256_setzero_ps(); + + uint32_t aux[3]; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + // Set up scales + memcpy(aux, x[i].scales, 12); + __m128i scales128 = _mm_set_epi32( + ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4), + ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4), + (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), + (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); + scales128 = _mm_sub_epi8(scales128, m32); + const __m256i all_scales = _mm256_cvtepi8_epi16(scales128); + const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0); + const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1); + const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)}; + + // high bit + const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask); + + // integer accumulator + __m256i sumi = _mm256_setzero_si256(); + + int bit = 0; + int is = 0; + + for (int j = 0; j < QK_K/128; ++j) { + // load low 2 bits + const __m256i q3bits = _mm256_loadu_si256((const __m256i*)q3); q3 += 32; + + // prepare low and high bits + const __m256i q3l_0 = _mm256_and_si256(q3bits, m3); + const __m256i q3h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3); + const __m256i q3h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + const __m256i q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3); + const __m256i q3h_2 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + const __m256i q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3); + const __m256i q3h_3 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + // load Q8 quants + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, + // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, + // and 2 if the high bit was set) + __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0); + __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1); + __m256i q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2); + __m256i q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3); + + __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1); + __m256i p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2); + __m256i p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3); + + p16_0 = _mm256_sub_epi16(p16_0, q8s_0); + p16_1 = _mm256_sub_epi16(p16_1, q8s_1); + p16_2 = _mm256_sub_epi16(p16_2, q8s_2); + p16_3 = _mm256_sub_epi16(p16_3, q8s_3); + + // multiply with scales + p16_0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0); + p16_1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1); + p16_2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2); + p16_3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3); + + // accumulate + p16_0 = _mm256_add_epi32(p16_0, p16_1); + p16_2 = _mm256_add_epi32(p16_2, p16_3); + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2)); + + } + + // multiply with block scale and accumulate + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + + } + + *s = hsum_float_8(acc); + +#elif defined __AVX__ + + const __m128i m3 = _mm_set1_epi8(3); + const __m128i mone = _mm_set1_epi8(1); + const __m128i m32 = _mm_set1_epi8(32); + const __m128i m2 = _mm_set1_epi8(2); + + __m256 acc = _mm256_setzero_ps(); + + const uint32_t *aux; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + // Set up scales + aux = (const uint32_t *)x[i].scales; + __m128i scales128 = _mm_set_epi32( + ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4), + ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4), + (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), + (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); + scales128 = _mm_sub_epi8(scales128, m32); + const __m128i scales_0 = _mm_cvtepi8_epi16(scales128); + const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales128, scales128)); + const __m128i scales[2] = { scales_0, scales_1 }; + + // high bit *128*2 from block_q3_K.hmask[QK_K/8] + const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].hmask[0]); + const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].hmask[16]); + + // integer accumulator + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + for (int j = 0; j < QK_K/128; ++j) { + // load low 2 bits *64*2 from block_q3_K.qs[QK_K/4] + const __m128i q3bits_0 = _mm_loadu_si128((const __m128i*)q3); q3 += 16; + const __m128i q3bits_1 = _mm_loadu_si128((const __m128i*)q3); q3 += 16; + + // prepare low and high bits + const int bit = j << 2; + + const __m128i q3l_0 = _mm_and_si128(q3bits_0, m3); + const __m128i q3l_1 = _mm_and_si128(q3bits_1, m3); + const __m128i q3h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit)), bit), 2); + const __m128i q3h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit)), bit), 2); + + const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 2), m3); + const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 2), m3); + const __m128i q3h_2 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+1)), bit+1), 2); + const __m128i q3h_3 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+1)), bit+1), 2); + + const __m128i q3l_4 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 4), m3); + const __m128i q3l_5 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 4), m3); + const __m128i q3h_4 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+2)), bit+2), 2); + const __m128i q3h_5 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+2)), bit+2), 2); + + const __m128i q3l_6 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 6), m3); + const __m128i q3l_7 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 6), m3); + const __m128i q3h_6 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+3)), bit+3), 2); + const __m128i q3h_7 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+3)), bit+3), 2); + + // load Q8 quants from block_q8_K.qs[QK_K] + const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + + // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, + // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, + // and 2 if the high bit was set) + __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, q8_0); + __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, q8_1); + __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, q8_2); + __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, q8_3); + __m128i q8s_4 = _mm_maddubs_epi16(q3h_4, q8_4); + __m128i q8s_5 = _mm_maddubs_epi16(q3h_5, q8_5); + __m128i q8s_6 = _mm_maddubs_epi16(q3h_6, q8_6); + __m128i q8s_7 = _mm_maddubs_epi16(q3h_7, q8_7); + + __m128i p16_0 = _mm_maddubs_epi16(q3l_0, q8_0); + __m128i p16_1 = _mm_maddubs_epi16(q3l_1, q8_1); + __m128i p16_2 = _mm_maddubs_epi16(q3l_2, q8_2); + __m128i p16_3 = _mm_maddubs_epi16(q3l_3, q8_3); + __m128i p16_4 = _mm_maddubs_epi16(q3l_4, q8_4); + __m128i p16_5 = _mm_maddubs_epi16(q3l_5, q8_5); + __m128i p16_6 = _mm_maddubs_epi16(q3l_6, q8_6); + __m128i p16_7 = _mm_maddubs_epi16(q3l_7, q8_7); + + p16_0 = _mm_sub_epi16(p16_0, q8s_0); + p16_1 = _mm_sub_epi16(p16_1, q8s_1); + p16_2 = _mm_sub_epi16(p16_2, q8s_2); + p16_3 = _mm_sub_epi16(p16_3, q8s_3); + p16_4 = _mm_sub_epi16(p16_4, q8s_4); + p16_5 = _mm_sub_epi16(p16_5, q8s_5); + p16_6 = _mm_sub_epi16(p16_6, q8s_6); + p16_7 = _mm_sub_epi16(p16_7, q8s_7); + + // multiply with scales + __m128i shuffle = _mm_set1_epi16(0x0100); + p16_0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_0); + shuffle = _mm_add_epi16(shuffle, m2); + p16_1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_1); + shuffle = _mm_add_epi16(shuffle, m2); + p16_2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_2); + shuffle = _mm_add_epi16(shuffle, m2); + p16_3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_3); + shuffle = _mm_add_epi16(shuffle, m2); + p16_4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_4); + shuffle = _mm_add_epi16(shuffle, m2); + p16_5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_5); + shuffle = _mm_add_epi16(shuffle, m2); + p16_6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_6); + shuffle = _mm_add_epi16(shuffle, m2); + p16_7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_7); + + // accumulate + p16_0 = _mm_add_epi32(p16_0, p16_1); + p16_2 = _mm_add_epi32(p16_2, p16_3); + p16_4 = _mm_add_epi32(p16_4, p16_5); + p16_6 = _mm_add_epi32(p16_6, p16_7); + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_4, p16_6)); + + } + + // multiply with block scale and accumulate + __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc); + + } + + *s = hsum_float_8(acc); + +#elif defined __riscv_v_intrinsic + + uint32_t aux[3]; + uint32_t utmp[4]; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict qh = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + int8_t * scale = (int8_t *)utmp; + for (int j = 0; j < 16; ++j) scale[j] -= 32; + + + size_t vl = 32; + uint8_t m = 1; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl); + + int sum_t = 0; + + for (int j = 0; j < QK_K; j += 128) { + + vl = 32; + + // load Q3 + vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl); + + vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl)); + vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl)); + vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl)); + vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl)); + + // compute mask for subtraction + vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl); + vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl); + m <<= 1; + + vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl); + vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl); + m <<= 1; + + vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl); + vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl); + m <<= 1; + + vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl); + vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl); + m <<= 1; + + // load Q8 and take product with Q3 + vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl); + vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl); + vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl); + vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl); + + vl = 16; + + // retrieve lane to multiply with scale + vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl); + vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl); + vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl); + vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl); + vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl); + vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl); + vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl); + vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl); + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl); + vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl); + vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl); + + sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + + q3 += 32; q8 += 128; scale += 8; + + } + + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + + sumf += d*sum_t; + + } + + *s = sumf; + +#elif defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0x3); + const vector signed char lowMask1 = vec_splats((int8_t)0xf); + const vector signed char lowMask2 = vec_splats((int8_t)0x30); + const vector int v0 = vec_splats((int32_t)0); + const vector signed char v1 = vec_splats((signed char)0x1); + const vector unsigned char v2 = vec_splats((unsigned char)0x2); + const vector unsigned char v3 = vec_splats((unsigned char)0x3); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + const vector unsigned char v6 = vec_splats((unsigned char)0x6); + const vector signed char off = vec_splats((signed char)0x20); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + UNUSED(kmask1); + UNUSED(kmask2); + + vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8); + vector signed char u1 = vec_and(u0, lowMask1); + vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4); + vector signed char u3 = (vector signed char)vec_mergeh((vector signed int)u2, (vector signed int)vec_sr(u2, v2)); + vector signed char u30 = vec_sl(vec_and(u3, lowMask), v4); + vector signed char u31 = vec_and(u3, lowMask2); + + u1 = vec_or(u1, u30); + u2 = vec_or(vec_sr(u0, v4), u31); + + vector signed char vscales = (vector signed char)vec_mergeh((vector signed long long)u1, (vector signed long long)u2); + vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].hmask); + vector signed char qxhs1 = (vector signed char)vec_xl(16, x[i].hmask); + + vscales = vec_sub(vscales, off); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + vector signed int vsumi4 = v0; + vector signed int vsumi5 = v0; + vector signed int vsumi6 = v0; + vector signed int vsumi7 = v0; + + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + for (int j = 0; j < QK_K/128; ++j) { + __builtin_prefetch(q3, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed char qxs0 = (vector signed char)vec_xl( 0, q3); + vector signed char qxs1 = (vector signed char)vec_xl(16, q3); + q3 += 32; + + //the low 2 bits + vector signed char qxs00 = vec_and(qxs0, lowMask); + vector signed char qxs01 = vec_and(vec_sr(qxs0, v2), lowMask); + vector signed char qxs02 = vec_and(vec_sr(qxs0, v4), lowMask); + vector signed char qxs03 = vec_and(vec_sr(qxs0, v6), lowMask); + vector signed char qxs10 = vec_and(qxs1, lowMask); + vector signed char qxs11 = vec_and(vec_sr(qxs1, v2), lowMask); + vector signed char qxs12 = vec_and(vec_sr(qxs1, v4), lowMask); + vector signed char qxs13 = vec_and(vec_sr(qxs1, v6), lowMask); + + //the 3rd bit + vector signed char qxh00 = vec_sl(vec_andc(v1, qxhs0), v2); + vector signed char qxh01 = vec_sl(vec_andc(v1, vec_sr(qxhs0, (vector unsigned char)v1)), v2); + vector signed char qxh02 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v2)), v2); + vector signed char qxh03 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v3)), v2); + vector signed char qxh10 = vec_sl(vec_andc(v1, qxhs1), v2); + vector signed char qxh11 = vec_sl(vec_andc(v1, vec_sr(qxhs1, (vector unsigned char)v1)), v2); + vector signed char qxh12 = vec_sl(vec_andc(v1, vec_sr(qxhs1, v2)), v2); + vector signed char qxh13 = vec_sl(vec_andc(v1, vec_sr(qxhs1, v3)), v2); + qxhs0 = vec_sr(qxhs0, v4); + qxhs1 = vec_sr(qxhs1, v4); + + vector signed char q3x00 = vec_sub(qxs00, qxh00); + vector signed char q3x01 = vec_sub(qxs01, qxh01); + vector signed char q3x02 = vec_sub(qxs02, qxh02); + vector signed char q3x03 = vec_sub(qxs03, qxh03); + vector signed char q3x10 = vec_sub(qxs10, qxh10); + vector signed char q3x11 = vec_sub(qxs11, qxh11); + vector signed char q3x12 = vec_sub(qxs12, qxh12); + vector signed char q3x13 = vec_sub(qxs13, qxh13); + + vector signed char q8y00 = vec_xl( 0, q8); + vector signed char q8y10 = vec_xl( 16, q8); + vector signed char q8y01 = vec_xl( 32, q8); + vector signed char q8y11 = vec_xl( 48, q8); + vector signed char q8y02 = vec_xl( 64, q8); + vector signed char q8y12 = vec_xl( 80, q8); + vector signed char q8y03 = vec_xl( 96, q8); + vector signed char q8y13 = vec_xl(112, q8); + q8 += 128; + + vector signed short vscales_h = vec_unpackh(vscales); + vector signed short vs0 = vec_splat(vscales_h, 0); + vector signed short vs1 = vec_splat(vscales_h, 1); + vector signed short vs2 = vec_splat(vscales_h, 2); + vector signed short vs3 = vec_splat(vscales_h, 3); + vector signed short vs4 = vec_splat(vscales_h, 4); + vector signed short vs5 = vec_splat(vscales_h, 5); + vector signed short vs6 = vec_splat(vscales_h, 6); + vector signed short vs7 = vec_splat(vscales_h, 7); + vscales = vec_sld(vscales, vscales, 8); + + vector signed short qv00 = vec_add(vec_mule(q3x00, q8y00), vec_mulo(q3x00, q8y00)); + vector signed short qv01 = vec_add(vec_mule(q3x01, q8y01), vec_mulo(q3x01, q8y01)); + vector signed short qv02 = vec_add(vec_mule(q3x02, q8y02), vec_mulo(q3x02, q8y02)); + vector signed short qv03 = vec_add(vec_mule(q3x03, q8y03), vec_mulo(q3x03, q8y03)); + vector signed short qv10 = vec_add(vec_mule(q3x10, q8y10), vec_mulo(q3x10, q8y10)); + vector signed short qv11 = vec_add(vec_mule(q3x11, q8y11), vec_mulo(q3x11, q8y11)); + vector signed short qv12 = vec_add(vec_mule(q3x12, q8y12), vec_mulo(q3x12, q8y12)); + vector signed short qv13 = vec_add(vec_mule(q3x13, q8y13), vec_mulo(q3x13, q8y13)); + + vsumi0 = vec_msum(qv00, vs0, vsumi0); + vsumi1 = vec_msum(qv01, vs2, vsumi1); + vsumi2 = vec_msum(qv02, vs4, vsumi2); + vsumi3 = vec_msum(qv03, vs6, vsumi3); + vsumi4 = vec_msum(qv10, vs1, vsumi4); + vsumi5 = vec_msum(qv11, vs3, vsumi5); + vsumi6 = vec_msum(qv12, vs5, vsumi6); + vsumi7 = vec_msum(qv13, vs7, vsumi7); + } + + vsumi0 = vec_add(vsumi0, vsumi4); + vsumi1 = vec_add(vsumi1, vsumi5); + vsumi2 = vec_add(vsumi2, vsumi6); + vsumi3 = vec_add(vsumi3, vsumi7); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + +#elif defined __loongarch_asx + + const __m256i m3 = __lasx_xvreplgr2vr_b(3); + const __m256i mone = __lasx_xvreplgr2vr_b(1); + const __m128i m32 = __lsx_vreplgr2vr_b(32); + + __m256 acc = (__m256)__lasx_xvldi(0); + + uint32_t aux[3]; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + // Set up scales + memcpy(aux, x[i].scales, 12); + __m128i scales128 = lsx_set_w( + ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4), + ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4), + (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), + (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); + scales128 = __lsx_vsub_b(scales128, m32); + const __m256i all_scales = lasx_ext8_16(scales128); + const __m128i l_scales = lasx_extracti128(all_scales, 0); + const __m128i h_scales = lasx_extracti128(all_scales, 1); + const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)}; + + // high bit + const __m256i hbits = __lasx_xvld((const __m256i*)x[i].hmask, 0); + + // integer accumulator + __m256i sumi = __lasx_xvldi(0); + + int bit = 0; + int is = 0; + __m256i xvbit; + + + for (int j = 0; j < QK_K/128; ++j) { + // load low 2 bits + const __m256i q3bits = __lasx_xvld((const __m256i*)q3, 0); q3 += 32; + + xvbit = __lasx_xvreplgr2vr_h(bit); + // prepare low and high bits + const __m256i q3l_0 = __lasx_xvand_v(q3bits, m3); + const __m256i q3h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2); + ++bit; + + xvbit = __lasx_xvreplgr2vr_h(bit); + const __m256i q3l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 2), m3); + const __m256i q3h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2); + ++bit; + + xvbit = __lasx_xvreplgr2vr_h(bit); + const __m256i q3l_2 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 4), m3); + const __m256i q3h_2 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2); + ++bit; + + xvbit = __lasx_xvreplgr2vr_h(bit); + const __m256i q3l_3 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 6), m3); + const __m256i q3h_3 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2); + ++bit; + + // load Q8 quants + const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + + // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use lasx_maddubs_h, + // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, + // and 2 if the high bit was set) + __m256i q8s_0 = lasx_maddubs_h(q3h_0, q8_0); + __m256i q8s_1 = lasx_maddubs_h(q3h_1, q8_1); + __m256i q8s_2 = lasx_maddubs_h(q3h_2, q8_2); + __m256i q8s_3 = lasx_maddubs_h(q3h_3, q8_3); + + __m256i p16_0 = lasx_maddubs_h(q3l_0, q8_0); + __m256i p16_1 = lasx_maddubs_h(q3l_1, q8_1); + __m256i p16_2 = lasx_maddubs_h(q3l_2, q8_2); + __m256i p16_3 = lasx_maddubs_h(q3l_3, q8_3); + + p16_0 = __lasx_xvsub_h(p16_0, q8s_0); + p16_1 = __lasx_xvsub_h(p16_1, q8s_1); + p16_2 = __lasx_xvsub_h(p16_2, q8s_2); + p16_3 = __lasx_xvsub_h(p16_3, q8s_3); + + // multiply with scales + p16_0 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0); + p16_1 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1); + p16_2 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2); + p16_3 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3); + + // accumulate + p16_0 = __lasx_xvadd_w(p16_0, p16_1); + p16_2 = __lasx_xvadd_w(p16_2, p16_3); + sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_2)); + } + // multiply with block scale and accumulate + acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);//FIXME + } + + *s = hsum_float_8(acc); + +#else + // scalar version + // This function is written like this so the compiler can manage to vectorize most of it + // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the + // manually vectorized version above. Every other version I tried would run at least 4 times slower. + // The ideal situation would be if we could just write the code once, and the compiler would + // automatically produce the best possible set of machine instructions, instead of us having to manually + // write vectorized versions for AVX, ARM_NEON, etc. + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + uint32_t auxs[4]; + const int8_t * scales = (const int8_t*)auxs; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict hm = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * restrict a = aux8; + uint8_t m = 1; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + q3 += 32; + } + a = aux8; + + memcpy(auxs, x[i].scales, 12); + uint32_t tmp = auxs[2]; + auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + for (int j = 0; j < QK_K/16; ++j) { + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; + q8 += 8; a += 8; + } + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; + +#endif + +} + +void lm_ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; + +#ifdef __ARM_NEON + const uint8x16_t m4b = vdupq_n_u8(0xf); + const int32x4_t mzero = vdupq_n_s32(0); + + lm_ggml_int8x16x2_t q4bytes; + lm_ggml_int8x16x2_t q8bytes; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin); + + const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); + + memcpy(utmp, x[i].scales, 12); + + uint32x2_t mins8 = { 0 }; + mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0); + mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1); + + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[0] &= kmask1; + + const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8))); + const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), + vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); + sumf -= dmin * vaddvq_s32(prod); + + const uint8_t * scales = (const uint8_t *)utmp; + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + int32_t sumi1 = 0; + int32_t sumi2 = 0; + + for (int j = 0; j < QK_K/64; ++j) { + const lm_ggml_uint8x16x2_t q4bits = lm_ggml_vld1q_u8_x2(q4); q4 += 32; + + q8bytes = lm_ggml_vld1q_s8_x2(q8); q8 += 32; + q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); + q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); + + const int32x4_t p1 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); + sumi1 += vaddvq_s32(p1) * scales[2*j+0]; + + q8bytes = lm_ggml_vld1q_s8_x2(q8); q8 += 32; + q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); + q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); + + const int32x4_t p2 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); + + sumi2 += vaddvq_s32(p2) * scales[2*j+1]; + } + + sumf += d * (sumi1 + sumi2); + + } + + *s = sumf; + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + + __m256 acc = _mm256_setzero_ps(); + __m128 acc_m = _mm_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); + + const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums); + const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); + const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s); + acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m); + + const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0); + const __m256i scales = MM256_SET_M128I(sc128, sc128); + + __m256i sumi = _mm256_setzero_si256(); + + for (int j = 0; j < QK_K/64; ++j) { + + const __m256i scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0)); + const __m256i scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1)); + + const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; + const __m256i q4l = _mm256_and_si256(q4bits, m4); + const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4); + + const __m256i q8l = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + __m256i p16l = _mm256_maddubs_epi16(q4l, q8l); + p16l = _mm256_madd_epi16(scale_l, p16l); + + const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + __m256i p16h = _mm256_maddubs_epi16(q4h, q8h); + p16h = _mm256_madd_epi16(scale_h, p16h); + const __m256i sumj = _mm256_add_epi32(p16l, p16h); + + sumi = _mm256_add_epi32(sumi, sumj); + } + + __m256 vd = _mm256_set1_ps(d); + acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc); + + } + + acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m)); + acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m)); + + *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m); + +#elif defined __AVX__ + + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i m2 = _mm_set1_epi8(0x2); + + __m256 acc = _mm256_setzero_ps(); + __m128 acc_m = _mm_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]); + const __m128i scales = _mm_cvtepu8_epi16(utmps); + const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps)); + + const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]); + const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]); + const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1); + const __m128i prod = _mm_madd_epi16(mins, q8s); + acc_m = _mm_add_ps(_mm_mul_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod)), acc_m); + + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + __m128i shuffle = _mm_set1_epi16(0x0100); + for (int j = 0; j < QK_K/64; ++j) { + + const __m128i scale_l = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi16(shuffle, m2); + const __m128i scale_h = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi16(shuffle, m2); + + __m128i q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4l_0 = _mm_and_si128(q4bits, m4); + const __m128i q4h_0 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4); + q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4l_1 = _mm_and_si128(q4bits, m4); + const __m128i q4h_1 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4); + + const __m128i q8l_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i p16l = _mm_maddubs_epi16(q4l_0, q8l_0); + p16l = _mm_madd_epi16(scale_l, p16l); + sumi_0 = _mm_add_epi32(sumi_0, p16l); + const __m128i q8l_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + p16l = _mm_maddubs_epi16(q4l_1, q8l_1); + p16l = _mm_madd_epi16(scale_l, p16l); + sumi_1 = _mm_add_epi32(sumi_1, p16l); + + const __m128i q8h_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i p16h = _mm_maddubs_epi16(q4h_0, q8h_0); + p16h = _mm_madd_epi16(scale_h, p16h); + sumi_0 = _mm_add_epi32(sumi_0, p16h); + const __m128i q8h_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + p16h = _mm_maddubs_epi16(q4h_1, q8h_1); + p16h = _mm_madd_epi16(scale_h, p16h); + sumi_1 = _mm_add_epi32(sumi_1, p16h); + + } + + __m256 vd = _mm256_set1_ps(d); + __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); + acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc); + + } + + acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m)); + acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m)); + + *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m); + +#elif defined __riscv_v_intrinsic + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + size_t vl = 8; + + const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin); + + vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); + vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); + vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); + vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); + vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); + + vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + vl = 32; + + int32_t sum_1 = 0; + int32_t sum_2 = 0; + + vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); + + for (int j = 0; j < QK_K/64; ++j) { + // load Q4 + vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); + + // load Q8 and multiply it with lower Q4 nibble + vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); + vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl); + vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl); + + sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0]; + + // load Q8 and multiply it with upper Q4 nibble + vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); + vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); + vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl); + vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl); + + sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1]; + + q4 += 32; q8 += 64; + + } + + sumf += d*(sum_1 + sum_2); + + } + + *s = sumf; + +#elif defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector signed char lowMask1 = vec_splats((int8_t)0x3f); + const vector signed char lowMask2 = vec_splats((int8_t)0x30); + const vector int v0 = vec_splats((int32_t)0); + const vector unsigned char v2 = vec_splats((uint8_t)2); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector float vxmin = vec_splats(LM_GGML_FP16_TO_FP32(x[i].dmin)); + vector float vdmin = vec_mul(vxmin, vyd); + + vector signed short q8ysums0 = vec_xl( 0, y[i].bsums); + vector signed short q8ysums1 = vec_xl(16, y[i].bsums); + + UNUSED(kmask1); + UNUSED(kmask2); + UNUSED(kmask3); + UNUSED(utmp); + + vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8); + vector signed char u1 = vec_and(vec_sr(u0, v2), lowMask2); + vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4); + vector signed char u3 = vec_sr(u2, v4); + + vector signed char u30 = u1; + vector signed char u31 = (vector signed char)vec_mergeh((vector signed int)vec_and(u2, lowMask), (vector signed int)u3); + + u1 = vec_and(u0, lowMask1); + u2 = vec_or(u30, u31); + + vector signed char utmps = (vector signed char)vec_mergeh((vector signed int)u1, (vector signed int)u2); + + vector signed short vscales = vec_unpackh(utmps); + vector signed short q4xmins = vec_unpackl(utmps); + vector signed short q4xmins0 = vec_mergeh(q4xmins, q4xmins); + vector signed short q4xmins1 = vec_mergel(q4xmins, q4xmins); + + vector signed int prod0 = vec_mule(q4xmins0, q8ysums0); + vector signed int prod1 = vec_mule(q4xmins1, q8ysums1); + vector signed int prod2 = vec_mulo(q4xmins0, q8ysums0); + vector signed int prod3 = vec_mulo(q4xmins1, q8ysums1); + + vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0); + vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1); + vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2); + vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + for (int j = 0; j < QK_K/64; j+=2) { + __builtin_prefetch(q4, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed char qxs0 = (vector signed char)vec_xl( 0, q4); + vector signed char qxs1 = (vector signed char)vec_xl(16, q4); + vector signed char qxs2 = (vector signed char)vec_xl(32, q4); + vector signed char qxs3 = (vector signed char)vec_xl(48, q4); + q4 += 64; + + vector unsigned char q4x00 = (vector unsigned char)vec_and(qxs0, lowMask); + vector unsigned char q4x01 = (vector unsigned char)vec_sr(qxs0, v4); + vector unsigned char q4x10 = (vector unsigned char)vec_and(qxs1, lowMask); + vector unsigned char q4x11 = (vector unsigned char)vec_sr(qxs1, v4); + vector unsigned char q4x20 = (vector unsigned char)vec_and(qxs2, lowMask); + vector unsigned char q4x21 = (vector unsigned char)vec_sr(qxs2, v4); + vector unsigned char q4x30 = (vector unsigned char)vec_and(qxs3, lowMask); + vector unsigned char q4x31 = (vector unsigned char)vec_sr(qxs3, v4); + + vector signed char q8y00 = vec_xl( 0, q8); + vector signed char q8y10 = vec_xl( 16, q8); + vector signed char q8y01 = vec_xl( 32, q8); + vector signed char q8y11 = vec_xl( 48, q8); + vector signed char q8y20 = vec_xl( 64, q8); + vector signed char q8y30 = vec_xl( 80, q8); + vector signed char q8y21 = vec_xl( 96, q8); + vector signed char q8y31 = vec_xl(112, q8); + q8 += 128; + + vector signed int qv00 = vec_msum(q8y00, q4x00, v0); + vector signed int qv01 = vec_msum(q8y01, q4x01, v0); + vector signed int qv10 = vec_msum(q8y10, q4x10, v0); + vector signed int qv11 = vec_msum(q8y11, q4x11, v0); + vector signed int qv20 = vec_msum(q8y20, q4x20, v0); + vector signed int qv21 = vec_msum(q8y21, q4x21, v0); + vector signed int qv30 = vec_msum(q8y30, q4x30, v0); + vector signed int qv31 = vec_msum(q8y31, q4x31, v0); + + vector signed int vscales_h = vec_unpackh(vscales); + vector signed int vs0 = vec_splat(vscales_h, 0); + vector signed int vs1 = vec_splat(vscales_h, 1); + vector signed int vs2 = vec_splat(vscales_h, 2); + vector signed int vs3 = vec_splat(vscales_h, 3); + vscales = vec_sld(vscales, vscales, 8); + + vsumi0 = vec_add(vec_mul(qv00, vs0), vsumi0); + vsumi1 = vec_add(vec_mul(qv01, vs1), vsumi1); + vsumi2 = vec_add(vec_mul(qv20, vs2), vsumi2); + vsumi3 = vec_add(vec_mul(qv21, vs3), vsumi3); + + vsumi0 = vec_add(vec_mul(qv10, vs0), vsumi0); + vsumi1 = vec_add(vec_mul(qv11, vs1), vsumi1); + vsumi2 = vec_add(vec_mul(qv30, vs2), vsumi2); + vsumi3 = vec_add(vec_mul(qv31, vs3), vsumi3); + } + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + +#elif defined __loongarch_asx + LM_GGML_UNUSED(kmask1); + LM_GGML_UNUSED(kmask2); + LM_GGML_UNUSED(kmask3); + + const __m256i m4 = __lasx_xvreplgr2vr_b(0xF); + + __m256 acc = (__m256)__lasx_xvldi(0); + __m128 acc_m = (__m128)__lsx_vldi(0); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0])); + + const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0); + const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1)); + const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s); + acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m); + + const __m128i sc128 = lasx_extracti128(mins_and_scales, 0); + const __m256i scales = lasx_insertf128(sc128, sc128); + + __m256i sumi = __lasx_xvldi(0); + + for (int j = 0; j < QK_K/64; ++j) { + + const __m256i scale_l = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0)); + const __m256i scale_h = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1)); + + const __m256i q4bits = __lasx_xvld((const __m256i*)q4, 0); q4 += 32; + const __m256i q4l = __lasx_xvand_v(q4bits, m4); + const __m256i q4h = __lasx_xvand_v(__lasx_xvsrli_h(q4bits, 4), m4); + + const __m256i q8l = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + __m256i p16l = lasx_maddubs_h(q4l, q8l); + p16l = lasx_madd_h(scale_l, p16l); + + const __m256i q8h = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + __m256i p16h = lasx_maddubs_h(q4h, q8h); + p16h = lasx_madd_h(scale_h, p16h); + const __m256i sumj = __lasx_xvadd_w(p16l, p16h); + + sumi = __lasx_xvadd_w(sumi, sumj); + } + + __m256 vd = __lasx_xvreplfr2vr_s(d); + acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc); + + } + + acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vpermi_w((__m128i)acc_m, (__m128i)acc_m, 0xee)); + __m128i tmp1 = __lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w((__m128i)acc_m, 1), 0); + acc_m = __lsx_vfadd_s(acc_m, (__m128)tmp1); + + + ft_union fi; + fi.i = __lsx_vpickve2gr_w(acc_m, 0); + *s = hsum_float_8(acc) + fi.f ; +#else + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * restrict a = aux8; + for (int j = 0; j < QK_K/64; ++j) { + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); + a += 32; + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); + a += 32; q4 += 32; + } + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + int sumi = 0; + for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/32; ++j) { + int32_t scale = scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + const float dmin = LM_GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + sumf -= dmin * sumi; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + +void lm_ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q5_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; + +#ifdef __ARM_NEON + const uint8x16_t m4b = vdupq_n_u8(0xf); + const uint8x16_t mone = vdupq_n_u8(1); + const uint8x16_t mtwo = vdupq_n_u8(2); + const int32x4_t mzero = vdupq_n_s32(0); + + lm_ggml_int8x16x4_t q5bytes; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin); + + const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8); + const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8)); + const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), + vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); + int32_t sumi_mins = vaddvq_s32(prod); + + const uint8_t * scales = (const uint8_t *)utmp; + + const uint8_t * restrict q5 = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + lm_ggml_uint8x16x2_t qhbits = lm_ggml_vld1q_u8_x2(qh); + + lm_ggml_uint8x16x4_t q5h; + + int32_t sumi = 0; + + for (int j = 0; j < QK_K/64; ++j) { + + const lm_ggml_uint8x16x2_t q5bits = lm_ggml_vld1q_u8_x2(q5); q5 += 32; + const lm_ggml_int8x16x4_t q8bytes = lm_ggml_vld1q_s8_x4(q8); q8 += 64; + + q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); + q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); + q5h.val[2] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[0]), 3); + q5h.val[3] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[1]), 3); + qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 2); + qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 2); + + q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0])); + q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1])); + q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2])); + q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3])); + + sumi += vaddvq_s32(lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++; + sumi += vaddvq_s32(lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++; + } + + sumf += d * sumi - dmin * sumi_mins; + } + + *s = sumf; + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + const __m128i mzero = _mm_setzero_si128(); + const __m256i mone = _mm256_set1_epi8(1); + + __m256 acc = _mm256_setzero_ps(); + + float summs = 0.f; + + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q5 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); + + const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums); + const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); + const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s); + const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero); + summs += dmin * _mm_extract_epi32(hsum, 0); + + const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0); + const __m256i scales = MM256_SET_M128I(sc128, sc128); + + const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].qh); + __m256i hmask = mone; + + __m256i sumi = _mm256_setzero_si256(); + + int bit = 0; + + for (int j = 0; j < QK_K/64; ++j) { + + const __m256i scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0)); + const __m256i scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1)); + + const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32; + + const __m256i q5l_0 = _mm256_and_si256(q5bits, m4); + const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4); + const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0); + hmask = _mm256_slli_epi16(hmask, 1); + + const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4); + const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4); + const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1); + hmask = _mm256_slli_epi16(hmask, 1); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + __m256i p16_0 = _mm256_maddubs_epi16(q5_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1); + + p16_0 = _mm256_madd_epi16(scale_0, p16_0); + p16_1 = _mm256_madd_epi16(scale_1, p16_1); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); + + } + + __m256 vd = _mm256_set1_ps(d); + acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc); + + } + + *s = hsum_float_8(acc) + summs; + +#elif defined __AVX__ + + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i mzero = _mm_setzero_si128(); + const __m128i mone = _mm_set1_epi8(1); + const __m128i m2 = _mm_set1_epi8(2); + + __m256 acc = _mm256_setzero_ps(); + + float summs = 0.f; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * restrict q5 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]); + const __m128i scales = _mm_cvtepu8_epi16(utmps); + const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps)); + + const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]); + const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]); + const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1); + const __m128i prod = _mm_madd_epi16(mins, q8s); + const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero); + summs += dmin * _mm_extract_epi32(hsum, 0); + + const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].qh[0]); + const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].qh[16]); + __m128i hmask = mone; + + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + int bit = 0; + + __m128i shuffle = _mm_set1_epi16(0x0100); + for (int j = 0; j < QK_K/64; ++j) { + + const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi16(shuffle, m2); + const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi16(shuffle, m2); + + const __m128i q5bits_0 = _mm_loadu_si128((const __m128i*)q5); q5 += 16; + const __m128i q5bits_1 = _mm_loadu_si128((const __m128i*)q5); q5 += 16; + + __m128i q5l_0 = _mm_and_si128(q5bits_0, m4); + __m128i q5l_1 = _mm_and_si128(q5bits_1, m4); + __m128i q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4); + __m128i q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4); + __m128i q5_0 = _mm_add_epi8(q5l_0, q5h_0); + __m128i q5_1 = _mm_add_epi8(q5l_1, q5h_1); + hmask = _mm_slli_epi16(hmask, 1); + + __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i p16_0 = _mm_maddubs_epi16(q5_0, q8_0); + __m128i p16_1 = _mm_maddubs_epi16(q5_1, q8_1); + p16_0 = _mm_madd_epi16(scale_0, p16_0); + p16_1 = _mm_madd_epi16(scale_0, p16_1); + + q5l_0 = _mm_and_si128(_mm_srli_epi16(q5bits_0, 4), m4); + q5l_1 = _mm_and_si128(_mm_srli_epi16(q5bits_1, 4), m4); + q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4); + q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4); + q5_0 = _mm_add_epi8(q5l_0, q5h_0); + q5_1 = _mm_add_epi8(q5l_1, q5h_1); + hmask = _mm_slli_epi16(hmask, 1); + + q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i p16_2 = _mm_maddubs_epi16(q5_0, q8_0); + __m128i p16_3 = _mm_maddubs_epi16(q5_1, q8_1); + p16_2 = _mm_madd_epi16(scale_1, p16_2); + p16_3 = _mm_madd_epi16(scale_1, p16_3); + + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); + + } + + __m256 vd = _mm256_set1_ps(d); + __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); + acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc); + + } + + *s = hsum_float_8(acc) + summs; + +#elif defined __riscv_v_intrinsic + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + float sumf = 0; + float sums = 0.0; + + size_t vl; + + for (int i = 0; i < nb; ++i) { + + vl = 8; + + const uint8_t * restrict q5 = x[i].qs; + const uint8_t * restrict hm = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const float dmin = LM_GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + + vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); + vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); + vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); + vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); + vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); + + vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); + + vl = 32; + int32_t aux32 = 0; + int is = 0; + + uint8_t m = 1; + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m1_t vqh = __riscv_vle8_v_u8m1(hm, vl); + + for (int j = 0; j < QK_K/64; ++j) { + // load Q5 and Q8 + vuint8m1_t q5_x = __riscv_vle8_v_u8m1(q5, vl); + vint8m1_t q8_y1 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q8_y2 = __riscv_vle8_v_i8m1(q8+32, vl); + + // compute mask for addition + vint8m1_t q5_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q5_x, 0x0F, vl)); + vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_1 = __riscv_vmsne_vx_u8m1_b8(qh_m1, 0, vl); + vint8m1_t q5_m1 = __riscv_vadd_vx_i8m1_mu(vmask_1, q5_a, q5_a, 16, vl); + m <<= 1; + + vint8m1_t q5_l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q5_x, 0x04, vl)); + vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_2 = __riscv_vmsne_vx_u8m1_b8(qh_m2, 0, vl); + vint8m1_t q5_m2 = __riscv_vadd_vx_i8m1_mu(vmask_2, q5_l, q5_l, 16, vl); + m <<= 1; + + vint16m2_t v0 = __riscv_vwmul_vv_i16m2(q5_m1, q8_y1, vl); + vint16m2_t v1 = __riscv_vwmul_vv_i16m2(q5_m2, q8_y2, vl); + + vint32m4_t vs1 = __riscv_vwmul_vx_i32m4(v0, scales[is++], vl); + vint32m4_t vs2 = __riscv_vwmul_vx_i32m4(v1, scales[is++], vl); + + vint32m1_t vacc1 = __riscv_vredsum_vs_i32m4_i32m1(vs1, vzero, vl); + vint32m1_t vacc2 = __riscv_vredsum_vs_i32m4_i32m1(vs2, vzero, vl); + + aux32 += __riscv_vmv_x_s_i32m1_i32(vacc1) + __riscv_vmv_x_s_i32m1_i32(vacc2); + q5 += 32; q8 += 64; + + } + + vfloat32m1_t vaux = __riscv_vfmul_vf_f32m1(__riscv_vfmv_v_f_f32m1(aux32, 1), d, 1); + sums += __riscv_vfmv_f_s_f32m1_f32(vaux); + + } + + *s = sumf+sums; + +#elif defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector signed char lowMask1 = vec_splats((int8_t)0x3f); + const vector signed char lowMask2 = vec_splats((int8_t)0x30); + const vector int v0 = vec_splats((int32_t)0); + const vector unsigned char v1 = vec_splats((unsigned char)0x1); + const vector unsigned char v2 = vec_splats((unsigned char)0x2); + const vector unsigned char v3 = vec_splats((unsigned char)0x3); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector float vxmin = vec_splats(LM_GGML_FP16_TO_FP32(x[i].dmin)); + vector float vdmin = vec_mul(vxmin, vyd); + + UNUSED(kmask1); + UNUSED(kmask2); + UNUSED(kmask3); + UNUSED(utmp); + + vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8); + vector signed char u1 = vec_and(vec_sr(u0, v2), lowMask2); + vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4); + vector signed char u3 = vec_sr(u2, v4); + + vector signed char u30 = u1; + vector signed char u31 = (vector signed char)vec_mergeh((vector signed int)vec_and(u2, lowMask), (vector signed int)u3); + + u1 = vec_and(u0, lowMask1); + u2 = vec_or(u30, u31); + + vector signed char utmps = (vector signed char)vec_mergeh((vector signed int)u1, (vector signed int)u2); + + vector signed short q8ysums0 = vec_xl( 0, y[i].bsums); + vector signed short q8ysums1 = vec_xl(16, y[i].bsums); + + vector signed short vscales = vec_unpackh(utmps); + + vector signed short q5xmins = vec_unpackl(utmps); + vector signed short q5xmins0 = vec_mergeh(q5xmins, q5xmins); + vector signed short q5xmins1 = vec_mergel(q5xmins, q5xmins); + + vector signed int prod0 = vec_mule(q5xmins0, q8ysums0); + vector signed int prod1 = vec_mule(q5xmins1, q8ysums1); + vector signed int prod2 = vec_mulo(q5xmins0, q8ysums0); + vector signed int prod3 = vec_mulo(q5xmins1, q8ysums1); + + vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0); + vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1); + vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2); + vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3); + + vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].qh); + vector signed char qxhs1 = (vector signed char)vec_xl(16, x[i].qh); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + + const uint8_t * restrict q5 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + for (int j = 0; j < QK_K/64; ++j) { + __builtin_prefetch(q5, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed char qxs0 = (vector signed char)vec_xl( 0, q5); + vector signed char qxs1 = (vector signed char)vec_xl(16, q5); + q5 += 32; + + vector signed char qxs00 = vec_and(qxs0, lowMask); + vector signed char qxs01 = vec_sr(qxs0, v4); + vector signed char qxs10 = vec_and(qxs1, lowMask); + vector signed char qxs11 = vec_sr(qxs1, v4); + + vector signed char q5h00 = vec_sl(vec_and((vector signed char)v1, qxhs0), v4); + vector signed char q5h01 = vec_sl(vec_and((vector signed char)v2, qxhs0), v3); + vector signed char q5h10 = vec_sl(vec_and((vector signed char)v1, qxhs1), v4); + vector signed char q5h11 = vec_sl(vec_and((vector signed char)v2, qxhs1), v3); + qxhs0 = vec_sr(qxhs0, v2); + qxhs1 = vec_sr(qxhs1, v2); + + vector unsigned char q5x00 = (vector unsigned char)vec_or(q5h00, qxs00); + vector unsigned char q5x01 = (vector unsigned char)vec_or(q5h01, qxs01); + vector unsigned char q5x10 = (vector unsigned char)vec_or(q5h10, qxs10); + vector unsigned char q5x11 = (vector unsigned char)vec_or(q5h11, qxs11); + + vector signed char q8y00 = vec_xl( 0, q8); + vector signed char q8y10 = vec_xl(16, q8); + vector signed char q8y01 = vec_xl(32, q8); + vector signed char q8y11 = vec_xl(48, q8); + q8 += 64; + + vector signed int qv00 = vec_msum(q8y00, q5x00, v0); + vector signed int qv01 = vec_msum(q8y01, q5x01, v0); + vector signed int qv10 = vec_msum(q8y10, q5x10, v0); + vector signed int qv11 = vec_msum(q8y11, q5x11, v0); + + vector signed int vscales_h = vec_unpackh(vscales); + vector signed int vs0 = vec_splat(vscales_h, 0); + vector signed int vs1 = vec_splat(vscales_h, 1); + vscales = vec_sld(vscales, vscales, 12); + + vsumi0 = vec_add(vec_mul(qv00, vs0), vsumi0); + vsumi1 = vec_add(vec_mul(qv10, vs0), vsumi1); + vsumi2 = vec_add(vec_mul(qv01, vs1), vsumi2); + vsumi3 = vec_add(vec_mul(qv11, vs1), vsumi3); + } + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + +#elif defined __loongarch_asx + LM_GGML_UNUSED(kmask1); + LM_GGML_UNUSED(kmask2); + LM_GGML_UNUSED(kmask3); + + const __m256i m4 = __lasx_xvreplgr2vr_b(0xF); + const __m128i mzero = __lsx_vldi(0); + const __m256i mone = __lasx_xvreplgr2vr_b(1); + + __m256 acc = (__m256)__lasx_xvldi(0); + + float summs = 0.f; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q5 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0])); + + const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0); + const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1)); + const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s); + const __m128i hsum = lsx_hadd_w(lsx_hadd_w(prod, mzero), mzero); + summs += dmin * __lsx_vpickve2gr_w(hsum, 0); //TODO check + + const __m128i sc128 = lasx_extracti128(mins_and_scales, 0); + const __m256i scales = lasx_insertf128(sc128, sc128); + + const __m256i hbits = __lasx_xvld((const __m256i*)x[i].qh, 0); + __m256i hmask = mone; + + __m256i sumi = __lasx_xvldi(0); + + int bit = 0; + __m256i xvbit; + + for (int j = 0; j < QK_K/64; ++j) { + + const __m256i scale_0 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0)); + const __m256i scale_1 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1)); + + const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0); q5 += 32; + + xvbit = __lasx_xvreplgr2vr_h(bit++); + const __m256i q5l_0 = __lasx_xvand_v(q5bits, m4); + const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4); + const __m256i q5_0 = __lasx_xvadd_b(q5l_0, q5h_0); + hmask = __lasx_xvslli_h(hmask, 1); + + xvbit = __lasx_xvreplgr2vr_h(bit++); + const __m256i q5l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q5bits, 4), m4); + const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4); + const __m256i q5_1 = __lasx_xvadd_b(q5l_1, q5h_1); + hmask = __lasx_xvslli_h(hmask, 1); + + const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + + __m256i p16_0 = lasx_maddubs_h(q5_0, q8_0); + __m256i p16_1 = lasx_maddubs_h(q5_1, q8_1); + + p16_0 = lasx_madd_h(scale_0, p16_0); + p16_1 = lasx_madd_h(scale_1, p16_1); + + sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1)); + + } + + __m256 vd = __lasx_xvreplfr2vr_s(d); + acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc); + + } + + *s = hsum_float_8(acc) + summs; + +#else + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].qs; + const uint8_t * restrict hm = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * restrict a = aux8; + uint8_t m = 1; + for (int j = 0; j < QK_K/64; ++j) { + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); + for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); + for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); + a += 32; m <<= 1; + q4 += 32; + } + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + int sumi = 0; + for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/32; ++j) { + int32_t scale = scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + const float dmin = LM_GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + sumf -= dmin * sumi; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + +void lm_ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q6_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + float sum = 0; + + const uint8x16_t m4b = vdupq_n_u8(0xF); + const int32x4_t vzero = vdupq_n_s32(0); + //const int8x16_t m32s = vdupq_n_s8(32); + + const uint8x16_t mone = vdupq_n_u8(3); + + lm_ggml_int8x16x4_t q6bytes; + lm_ggml_uint8x16x4_t q6h; + + for (int i = 0; i < nb; ++i) { + + const float d_all = LM_GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const int8_t * restrict scale = x[i].scales; + + const lm_ggml_int16x8x2_t q8sums = lm_ggml_vld1q_s16_x2(y[i].bsums); + const int8x16_t scales = vld1q_s8(scale); + const lm_ggml_int16x8x2_t q6scales = {{vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}}; + + const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])), + vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))), + vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[1]), vget_low_s16 (q6scales.val[1])), + vmull_s16(vget_high_s16(q8sums.val[1]), vget_high_s16(q6scales.val[1])))); + int32_t isum_mins = vaddvq_s32(prod); + + int32_t isum = 0; + + for (int j = 0; j < QK_K/128; ++j) { + + lm_ggml_uint8x16x2_t qhbits = lm_ggml_vld1q_u8_x2(qh); qh += 32; + lm_ggml_uint8x16x4_t q6bits = lm_ggml_vld1q_u8_x4(q6); q6 += 64; + lm_ggml_int8x16x4_t q8bytes = lm_ggml_vld1q_s8_x4(q8); q8 += 64; + + q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); + q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); + uint8x16_t shifted = vshrq_n_u8(qhbits.val[0], 2); + q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[1], 2); + q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + + //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s); + //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s); + //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])), m32s); + //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])), m32s); + q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])); + q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])); + q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])); + q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])); + + isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + + vaddvq_s32(lm_ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + + vaddvq_s32(lm_ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + + vaddvq_s32(lm_ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; + + scale += 4; + + q8bytes = lm_ggml_vld1q_s8_x4(q8); q8 += 64; + + shifted = vshrq_n_u8(qhbits.val[0], 4); + q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[1], 4); + q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[0], 6); + q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[1], 6); + q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + + //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])), m32s); + //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])), m32s); + //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])), m32s); + //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])), m32s); + q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])); + q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])); + q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])); + q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])); + + isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + + vaddvq_s32(lm_ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + + vaddvq_s32(lm_ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + + vaddvq_s32(lm_ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; + scale += 4; + } + //sum += isum * d_all * y[i].d; + sum += d_all * y[i].d * (isum - 32 * isum_mins); + + } + *s = sum; + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + const __m256i m2 = _mm256_set1_epi8(3); + const __m256i m32s = _mm256_set1_epi8(32); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); + + __m256i sumi = _mm256_setzero_si256(); + + int is = 0; + + for (int j = 0; j < QK_K/128; ++j) { + + const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0)); + const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1)); + const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2)); + const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3)); + is += 4; + + const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; + const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; + const __m256i q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32; + + const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4); + const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4); + const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4); + const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4); + + const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0); + const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1); + const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2); + const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0); + __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1); + __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2); + __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3); + + __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1); + __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2); + __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3); + + p16_0 = _mm256_sub_epi16(p16_0, q8s_0); + p16_1 = _mm256_sub_epi16(p16_1, q8s_1); + p16_2 = _mm256_sub_epi16(p16_2, q8s_2); + p16_3 = _mm256_sub_epi16(p16_3, q8s_3); + + p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0); + p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1); + p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2); + p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3)); + + } + + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + } + + *s = hsum_float_8(acc); + +#elif defined __AVX__ + + const __m128i m3 = _mm_set1_epi8(3); + const __m128i m15 = _mm_set1_epi8(15); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + // handle the q6_k -32 offset separately using bsums + const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)y[i].bsums); + const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)y[i].bsums + 1); + const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); + const __m128i scales_16_0 = _mm_cvtepi8_epi16(scales); + const __m128i scales_16_1 = _mm_cvtepi8_epi16(_mm_bsrli_si128(scales, 8)); + const __m128i q8sclsub_0 = _mm_slli_epi32(_mm_madd_epi16(q8sums_0, scales_16_0), 5); + const __m128i q8sclsub_1 = _mm_slli_epi32(_mm_madd_epi16(q8sums_1, scales_16_1), 5); + + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + int is = 0; + + for (int j = 0; j < QK_K/128; ++j) { + + const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16; + const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i*)qh); qh += 16; + + const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4); + const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4); + const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(12)), 2); + const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(12)), 2); + const __m128i q4h_4 = _mm_and_si128(q4bitsH_0, _mm_set1_epi8(48)); + const __m128i q4h_5 = _mm_and_si128(q4bitsH_1, _mm_set1_epi8(48)); + const __m128i q4h_6 = _mm_srli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(-64)), 2); + const __m128i q4h_7 = _mm_srli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(-64)), 2); + + const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + + const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m15), q4h_0); + const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m15), q4h_1); + const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m15), q4h_2); + const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m15), q4h_3); + const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m15), q4h_4); + const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m15), q4h_5); + const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m15), q4h_6); + const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m15), q4h_7); + + const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + + __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0); + __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1); + __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2); + __m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3); + __m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4); + __m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5); + __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6); + __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7); + + const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0)); + const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1)); + const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2)); + const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3)); + is += 4; + + p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0); + p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_0, 8)), p16_1); + p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2); + p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_1, 8)), p16_3); + p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4); + p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_2, 8)), p16_5); + p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6); + p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_3, 8)), p16_7); + + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7)); + + } + + sumi_0 = _mm_sub_epi32(sumi_0, q8sclsub_0); + sumi_1 = _mm_sub_epi32(sumi_1, q8sclsub_1); + const __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi)), acc); + } + + *s = hsum_float_8(acc); + +#elif defined __riscv_v_intrinsic + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const int8_t * restrict scale = x[i].scales; + + size_t vl; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + + int sum_t = 0; + int is = 0; + + for (int j = 0; j < QK_K/128; ++j) { + + vl = 32; + + // load qh + vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl); + + // load Q6 + vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl); + vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl); + + vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl); + vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl); + vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl); + vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl); + + vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl); + vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl); + vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl); + vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl); + + vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl); + vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl); + vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl); + vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl); + + vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl); + vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl); + vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl); + vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl); + + // load Q8 and take product + vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl); + vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl); + vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl); + vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl); + + vl = 16; + + vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl); + vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl); + vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl); + vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl); + vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl); + vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl); + vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl); + vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl); + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); + vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl); + vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl); + + sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + + q6 += 64; qh += 32; q8 += 128; is=8; + + } + + sumf += d * sum_t; + + } + + *s = sumf; + +#elif defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector int v0 = vec_splats((int32_t)0); + const vector unsigned char v2 = vec_splats((unsigned char)0x2); + const vector unsigned char v3 = vec_splats((unsigned char)0x3); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + const vector unsigned char v6 = vec_splats((unsigned char)0x6); + const vector signed char off = vec_splats((signed char)0x20); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + vector signed int vsumi4 = v0; + vector signed int vsumi5 = v0; + vector signed int vsumi6 = v0; + vector signed int vsumi7 = v0; + + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict qs = x[i].scales; + const int8_t * restrict q8 = y[i].qs; + + for (int j = 0; j < QK_K/128; ++j) { + __builtin_prefetch(q6, 0, 0); + __builtin_prefetch(qh, 0, 0); + __builtin_prefetch(q8, 0, 0); + + vector signed char qxs0 = (vector signed char)vec_xl( 0, q6); + vector signed char qxs1 = (vector signed char)vec_xl(16, q6); + vector signed char qxs2 = (vector signed char)vec_xl(32, q6); + vector signed char qxs3 = (vector signed char)vec_xl(48, q6); + q6 += 64; + + vector signed char qxs00 = vec_and(qxs0, lowMask); + vector signed char qxs01 = vec_sr(qxs0, v4); + vector signed char qxs10 = vec_and(qxs1, lowMask); + vector signed char qxs11 = vec_sr(qxs1, v4); + vector signed char qxs20 = vec_and(qxs2, lowMask); + vector signed char qxs21 = vec_sr(qxs2, v4); + vector signed char qxs30 = vec_and(qxs3, lowMask); + vector signed char qxs31 = vec_sr(qxs3, v4); + + vector signed char qxhs0 = (vector signed char)vec_xl( 0, qh); + vector signed char qxhs1 = (vector signed char)vec_xl(16, qh); + qh += 32; + + vector signed char qxh00 = vec_sl(vec_and((vector signed char)v3, qxhs0), v4); + vector signed char qxh01 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v4)), v4); + vector signed char qxh10 = vec_sl(vec_and((vector signed char)v3, qxhs1), v4); + vector signed char qxh11 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v4)), v4); + vector signed char qxh20 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v2)), v4); + vector signed char qxh21 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v6)), v4); + vector signed char qxh30 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v2)), v4); + vector signed char qxh31 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v6)), v4); + + vector signed char q6x00 = vec_sub(vec_or(qxh00, qxs00), off); + vector signed char q6x01 = vec_sub(vec_or(qxh01, qxs01), off); + vector signed char q6x10 = vec_sub(vec_or(qxh10, qxs10), off); + vector signed char q6x11 = vec_sub(vec_or(qxh11, qxs11), off); + vector signed char q6x20 = vec_sub(vec_or(qxh20, qxs20), off); + vector signed char q6x21 = vec_sub(vec_or(qxh21, qxs21), off); + vector signed char q6x30 = vec_sub(vec_or(qxh30, qxs30), off); + vector signed char q6x31 = vec_sub(vec_or(qxh31, qxs31), off); + + vector signed char q8y00 = vec_xl( 0, q8); + vector signed char q8y10 = vec_xl( 16, q8); + vector signed char q8y20 = vec_xl( 32, q8); + vector signed char q8y30 = vec_xl( 48, q8); + vector signed char q8y01 = vec_xl( 64, q8); + vector signed char q8y11 = vec_xl( 80, q8); + vector signed char q8y21 = vec_xl( 96, q8); + vector signed char q8y31 = vec_xl(112, q8); + q8 += 128; + + vector signed short qv00 = vec_add(vec_mule(q6x00, q8y00), vec_mulo(q6x00, q8y00)); + vector signed short qv10 = vec_add(vec_mule(q6x10, q8y10), vec_mulo(q6x10, q8y10)); + vector signed short qv20 = vec_add(vec_mule(q6x20, q8y20), vec_mulo(q6x20, q8y20)); + vector signed short qv30 = vec_add(vec_mule(q6x30, q8y30), vec_mulo(q6x30, q8y30)); + vector signed short qv01 = vec_add(vec_mule(q6x01, q8y01), vec_mulo(q6x01, q8y01)); + vector signed short qv11 = vec_add(vec_mule(q6x11, q8y11), vec_mulo(q6x11, q8y11)); + vector signed short qv21 = vec_add(vec_mule(q6x21, q8y21), vec_mulo(q6x21, q8y21)); + vector signed short qv31 = vec_add(vec_mule(q6x31, q8y31), vec_mulo(q6x31, q8y31)); + + vector signed short vscales = vec_unpackh(vec_xl_len(qs, 8)); + qs += 8; + + vector signed short vs0 = vec_splat(vscales, 0); + vector signed short vs1 = vec_splat(vscales, 1); + vector signed short vs2 = vec_splat(vscales, 2); + vector signed short vs3 = vec_splat(vscales, 3); + vector signed short vs4 = vec_splat(vscales, 4); + vector signed short vs5 = vec_splat(vscales, 5); + vector signed short vs6 = vec_splat(vscales, 6); + vector signed short vs7 = vec_splat(vscales, 7); + + vsumi0 = vec_msum(qv00, vs0, vsumi0); + vsumi1 = vec_msum(qv01, vs4, vsumi1); + vsumi2 = vec_msum(qv10, vs1, vsumi2); + vsumi3 = vec_msum(qv11, vs5, vsumi3); + vsumi4 = vec_msum(qv20, vs2, vsumi4); + vsumi5 = vec_msum(qv21, vs6, vsumi5); + vsumi6 = vec_msum(qv30, vs3, vsumi6); + vsumi7 = vec_msum(qv31, vs7, vsumi7); + } + + vsumi0 = vec_add(vsumi0, vsumi4); + vsumi1 = vec_add(vsumi1, vsumi5); + vsumi2 = vec_add(vsumi2, vsumi6); + vsumi3 = vec_add(vsumi3, vsumi7); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + +#elif defined __loongarch_asx + + const __m256i m4 = __lasx_xvreplgr2vr_b(0xF); + const __m256i m2 = __lasx_xvreplgr2vr_b(3); + const __m256i m32s = __lasx_xvreplgr2vr_b(32); + + __m256 acc = (__m256)__lasx_xvldi(0); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const __m128i scales = __lsx_vld((const __m128i*)x[i].scales, 0); + + __m256i sumi = __lasx_xvldi(0); + + int is = 0; + + for (int j = 0; j < QK_K/128; ++j) { + + const __m128i scale_0 = lsx_shuffle_b(scales, get_scale_shuffle(is + 0)); + const __m128i scale_1 = lsx_shuffle_b(scales, get_scale_shuffle(is + 1)); + const __m128i scale_2 = lsx_shuffle_b(scales, get_scale_shuffle(is + 2)); + const __m128i scale_3 = lsx_shuffle_b(scales, get_scale_shuffle(is + 3)); + is += 4; + + const __m256i q4bits1 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32; + const __m256i q4bits2 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32; + const __m256i q4bitsH = __lasx_xvld((const __m256i*)qh, 0); qh += 32; + + const __m256i q4h_0 = __lasx_xvslli_h(__lasx_xvand_v(q4bitsH, m2), 4); + const __m256i q4h_1 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 2), m2), 4); + const __m256i q4h_2 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 4), m2), 4); + const __m256i q4h_3 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 6), m2), 4); + + const __m256i q4_0 = __lasx_xvor_v(__lasx_xvand_v(q4bits1, m4), q4h_0); + const __m256i q4_1 = __lasx_xvor_v(__lasx_xvand_v(q4bits2, m4), q4h_1); + const __m256i q4_2 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits1, 4), m4), q4h_2); + const __m256i q4_3 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits2, 4), m4), q4h_3); + + const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + + __m256i q8s_0 = lasx_maddubs_h(m32s, q8_0); + __m256i q8s_1 = lasx_maddubs_h(m32s, q8_1); + __m256i q8s_2 = lasx_maddubs_h(m32s, q8_2); + __m256i q8s_3 = lasx_maddubs_h(m32s, q8_3); + + __m256i p16_0 = lasx_maddubs_h(q4_0, q8_0); + __m256i p16_1 = lasx_maddubs_h(q4_1, q8_1); + __m256i p16_2 = lasx_maddubs_h(q4_2, q8_2); + __m256i p16_3 = lasx_maddubs_h(q4_3, q8_3); + + p16_0 = __lasx_xvsub_h(p16_0, q8s_0); + p16_1 = __lasx_xvsub_h(p16_1, q8s_1); + p16_2 = __lasx_xvsub_h(p16_2, q8s_2); + p16_3 = __lasx_xvsub_h(p16_3, q8s_3); + + p16_0 = lasx_madd_h(lasx_ext8_16(scale_0), p16_0); + p16_1 = lasx_madd_h(lasx_ext8_16(scale_1), p16_1); + p16_2 = lasx_madd_h(lasx_ext8_16(scale_2), p16_2); + p16_3 = lasx_madd_h(lasx_ext8_16(scale_3), p16_3); + + sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1)); + sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_2, p16_3)); + } + + acc = __lasx_xvfmadd_s((__m256)__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc); + } + + *s = hsum_float_8(acc); + +#else + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * restrict a = aux8; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + } + a += 128; + q4 += 64; + qh += 32; + } + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/16; ++j) { + int scale = x[i].scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + +#if defined (__AVX__) || defined (__AVX2__) || defined (__ARM_NEON) || defined (__POWER9_VECTOR__) || defined(__loongarch_asx) +static const int8_t keven_signs_q2xs[1024] = { + 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, + 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1, + 1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1, + 1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1, + 1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1, + 1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1, + 1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1, + 1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1, + 1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1, + 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1, + 1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1, + 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1, + 1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1, + 1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1, + 1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1, + 1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1, + 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1, + 1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1, + 1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1, + 1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1, + 1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1, + 1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1, + 1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, + 1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1, + 1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1, + 1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1, + 1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1, + 1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1, + 1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1, + 1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1, + 1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1, + 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, +}; +#endif + +void lm_ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq2_xxs * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#if defined(__ARM_NEON) + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[4]; + const uint8_t * aux8 = (const uint8_t *)aux32; + + lm_ggml_int8x16x4_t q2u; + lm_ggml_int8x16x4_t q2s; + lm_ggml_int8x16x4_t q8b; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + float sumf1 = 0, sumf2 = 0; + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + q8b = lm_ggml_vld1q_s8_x4(q8); q8 += 64; + memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8; + q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 0])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 1]))); + q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 2])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 3]))); + q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 8])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 9]))); + q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[10])), vld1_s8((const void *)(iq2xxs_grid + aux8[11]))); + q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127)))); + q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127)))); + q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 7) & 127)))); + q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 21) & 127)))); + q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]); + q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]); + q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]); + q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]); + const int32x4_t p1 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]), q2u.val[1], q8b.val[1]); + const int32x4_t p2 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]), q2u.val[3], q8b.val[3]); + sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[1] >> 28)); + sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[3] >> 28)); + } + sumf += d*(sumf1 + sumf2); + } + *s = 0.25f * sumf; + +#elif defined(__AVX2__) + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[4]; + const uint8_t * aux8 = (const uint8_t *)aux32; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8; + const __m256i q2_1 = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]); + const __m256i q2_2 = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]); + const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127], + signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); + const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127], + signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]); + const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1); + const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2); + const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); + const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); + const uint16_t ls1 = aux32[1] >> 28; + const uint16_t ls2 = aux32[3] >> 28; + const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1)); + const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1)); + sumi1 = _mm256_add_epi32(sumi1, p1); + sumi2 = _mm256_add_epi32(sumi2, p2); + } + + accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); + + } + + *s = 0.125f * hsum_float_8(accumf); + +#elif defined(__AVX__) + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[4]; + const uint8_t * aux8 = (const uint8_t *)aux32; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + __m128i sumi2_0 = _mm_setzero_si128(); + __m128i sumi2_1 = _mm_setzero_si128(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8; + const __m128i q2_1_0 = _mm_set_epi64x(iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]); + const __m128i q2_1_1 = _mm_set_epi64x(iq2xxs_grid[aux8[3]], iq2xxs_grid[aux8[2]]); + const __m128i q2_2_0 = _mm_set_epi64x(iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]); + const __m128i q2_2_1 = _mm_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]]); + const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); + const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]); + const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]); + const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127]); + const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0); + const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1); + const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0); + const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1); + const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); + const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); + const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); + const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); + const uint16_t ls1 = aux32[1] >> 28; + const uint16_t ls2 = aux32[3] >> 28; + const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1)); + const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1)); + const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1)); + const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1)); + sumi1_0 = _mm_add_epi32(sumi1_0, p1_0); + sumi1_1 = _mm_add_epi32(sumi1_1, p1_1); + sumi2_0 = _mm_add_epi32(sumi2_0, p2_0); + sumi2_1 = _mm_add_epi32(sumi2_1, p2_1); + } + + accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); + + } + + *s = 0.125f * hsum_float_8(accumf); + +#elif defined(__POWER9_VECTOR__) + const vector int v0 = vec_splats((int32_t)0); + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + for (int j = 0; j < QK_K/32; j += 2) { + __builtin_prefetch(q2, 0, 1); + __builtin_prefetch(q8, 0, 1); + + uint32_t aux32[4]; + const uint8_t * aux8 = (const uint8_t *)aux32; + + memcpy(aux32, q2, 4*sizeof(uint32_t)); + q2 += 8; + + vector signed long long aux64x2_0 = {*(const int64_t *)(iq2xxs_grid + aux8[ 0]), *(const int64_t *)(iq2xxs_grid + aux8[ 1])}; + vector signed long long aux64x2_1 = {*(const int64_t *)(iq2xxs_grid + aux8[ 2]), *(const int64_t *)(iq2xxs_grid + aux8[ 3])}; + vector signed long long aux64x2_2 = {*(const int64_t *)(iq2xxs_grid + aux8[ 8]), *(const int64_t *)(iq2xxs_grid + aux8[ 9])}; + vector signed long long aux64x2_3 = {*(const int64_t *)(iq2xxs_grid + aux8[10]), *(const int64_t *)(iq2xxs_grid + aux8[11])}; + + vector signed long long vsigns0 = {*(const int64_t *)(signs64 + ((aux32[1] >> 0) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 7) & 127))}; + vector signed long long vsigns1 = {*(const int64_t *)(signs64 + ((aux32[1] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 21) & 127))}; + vector signed long long vsigns2 = {*(const int64_t *)(signs64 + ((aux32[3] >> 0) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 7) & 127))}; + vector signed long long vsigns3 = {*(const int64_t *)(signs64 + ((aux32[3] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 21) & 127))}; + + vector signed char q2x0 = (vector signed char)vec_mul((vector signed char)vsigns0, (vector signed char)aux64x2_0); + vector signed char q2x1 = (vector signed char)vec_mul((vector signed char)vsigns1, (vector signed char)aux64x2_1); + vector signed char q2x2 = (vector signed char)vec_mul((vector signed char)vsigns2, (vector signed char)aux64x2_2); + vector signed char q2x3 = (vector signed char)vec_mul((vector signed char)vsigns3, (vector signed char)aux64x2_3); + + vector signed char q8y0 = vec_xl( 0, q8); + vector signed char q8y1 = vec_xl(16, q8); + vector signed char q8y2 = vec_xl(32, q8); + vector signed char q8y3 = vec_xl(48, q8); + q8 += 64; + + vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q2x1, q8y1), vec_mulo(q2x1, q8y1)); + vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2)); + vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3)); + + const uint16_t ls0 = aux32[1] >> 28; + const uint16_t ls1 = aux32[3] >> 28; + + vector signed short vscales01 = vec_splats((int16_t)(2*ls0+1)); + vector signed short vscales23 = vec_splats((int16_t)(2*ls1+1)); + + vsumi0 = vec_msum(qv0, vscales01, vsumi0); + vsumi1 = vec_msum(qv1, vscales01, vsumi1); + vsumi2 = vec_msum(qv2, vscales23, vsumi2); + vsumi3 = vec_msum(qv3, vscales23, vsumi3); + } + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = 0.125f * vec_extract(vsumf0, 0); + +#elif defined(__loongarch_asx) + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[4]; + const uint8_t * aux8 = (const uint8_t *)aux32; + + __m256 accumf = (__m256)__lasx_xvldi(0); + for (int i = 0; i < nb; ++i) { + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + __m256i sumi1 = __lasx_xvldi(0); + __m256i sumi2 = __lasx_xvldi(0); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8; + + const __m256i q2_1 = lasx_set_d(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]); + const __m256i q2_2 = lasx_set_d(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]); + const __m256i s2_1 = lasx_set_d(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127], + signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); + const __m256i s2_2 = lasx_set_d(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127], + signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]); + const __m256i q8s_1 = __lasx_xvsigncov_b(s2_1, q8_1); + const __m256i q8s_2 = __lasx_xvsigncov_b(s2_2, q8_2); + const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); + const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); + const uint16_t ls1 = aux32[1] >> 28; + const uint16_t ls2 = aux32[3] >> 28; + const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1)); + const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1)); + sumi1 = __lasx_xvadd_w(sumi1, p1); + sumi2 = __lasx_xvadd_w(sumi2, p2); + } + + accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf); + } + + *s = 0.125f * hsum_float_8(accumf); + +#else + + uint32_t aux32[2]; + const uint8_t * aux8 = (const uint8_t *)aux32; + + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + memcpy(aux32, q2, 2*sizeof(uint32_t)); + q2 += 4; + const uint32_t ls = 2*(aux32[1] >> 28) + 1; + int32_t sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]); + const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127]; + for (int j = 0; j < 8; ++j) { + sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += sumi * ls; + } + sumf += d * bsum; + } + *s = 0.125f * sumf; +#endif +} + +void lm_ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq2_xs * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#if defined(__ARM_NEON) + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + lm_ggml_int8x16x4_t q2u; + lm_ggml_int8x16x4_t q2s; + lm_ggml_int8x16x4_t q8b; + + int32x4x4_t scales32; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + const uint8x8_t scales8 = vld1_u8(x[i].scales); + const uint8x8_t scales_l = vand_u8(scales8, vdup_n_u8(0xf)); + const uint8x8_t scales_h = vshr_n_u8(scales8, 4); + uint8x16_t scales = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h)); + scales = vaddq_u8(vshlq_n_u8(scales, 1), vdupq_n_u8(1)); + const uint16x8_t scales1 = vmovl_u8(vget_low_u8(scales)); + const uint16x8_t scales2 = vmovl_u8(vget_high_u8(scales)); + scales32.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales1))); + scales32.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales1))); + scales32.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales2))); + scales32.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales2))); + int32x4_t sumi = vdupq_n_s32(0); + for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { + q8b = lm_ggml_vld1q_s8_x4(q8); q8 += 64; + q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[0] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[1] & 511)))); + q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[2] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[3] & 511)))); + q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[4] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[5] & 511)))); + q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[6] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[7] & 511)))); + q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[0] >> 9))), vld1_s8((const void *)(signs64 + (q2[1] >> 9)))); + q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[2] >> 9))), vld1_s8((const void *)(signs64 + (q2[3] >> 9)))); + q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[4] >> 9))), vld1_s8((const void *)(signs64 + (q2[5] >> 9)))); + q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[6] >> 9))), vld1_s8((const void *)(signs64 + (q2[7] >> 9)))); + q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]); + q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]); + q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]); + q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]); + const int32x4_t p1 = lm_ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]); + const int32x4_t p2 = lm_ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[1], q8b.val[1]); + const int32x4_t p3 = lm_ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]); + const int32x4_t p4 = lm_ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[3], q8b.val[3]); + const int32x4_t p = vpaddq_s32(vpaddq_s32(p1, p2), vpaddq_s32(p3, p4)); + sumi = vmlaq_s32(sumi, p, scales32.val[ib64]); + q2 += 8; + } + sumf += d*vaddvq_s32(sumi); + } + *s = 0.125f * sumf; + +#elif defined(__AVX2__) + + const __m256i mone = _mm256_set1_epi8(1); + static const char block_sign_shuffle_mask_1[32] = { + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, + 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, + }; + static const char block_sign_shuffle_mask_2[32] = { + 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, + 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, + }; + static const uint8_t bit_selector_mask_bytes[32] = { + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + const __m256i bit_selector_mask = _mm256_loadu_si256((const __m256i*)bit_selector_mask_bytes); + const __m256i block_sign_shuffle_1 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_1); + const __m256i block_sign_shuffle_2 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_2); + + static const uint8_t k_bit_helper[32] = { + 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, + 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, + }; + const __m256i bit_helper = _mm256_loadu_si256((const __m256i*)k_bit_helper); + const __m256i m511 = _mm256_set1_epi16(511); + const __m128i m4 = _mm_set1_epi8(0xf); + const __m128i m1 = _mm_set1_epi8(1); + + uint64_t aux64; + + // somewhat hacky, but gives a significant boost in performance + __m256i aux_gindex; + const uint16_t * gindex = (const uint16_t *)&aux_gindex; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + memcpy(&aux64, x[i].scales, 8); + __m128i stmp = _mm_set1_epi64x(aux64); + stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4)); + const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1); + + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) { + + const __m256i q2_data = _mm256_loadu_si256((const __m256i*)q2); q2 += 16; + aux_gindex = _mm256_and_si256(q2_data, m511); + + const __m256i partial_sign_bits = _mm256_srli_epi16(q2_data, 9); + const __m256i partial_sign_bits_upper = _mm256_srli_epi16(q2_data, 13); + const __m256i partial_sign_bits_for_counting = _mm256_xor_si256(partial_sign_bits, partial_sign_bits_upper); + + const __m256i odd_bits = _mm256_shuffle_epi8(bit_helper, partial_sign_bits_for_counting); + const __m256i full_sign_bits = _mm256_or_si256(partial_sign_bits, odd_bits); + + const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_3 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_4 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + + const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]], + iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]); + const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]], + iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]); + const __m256i q2_3 = _mm256_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]], + iq2xs_grid[gindex[ 9]], iq2xs_grid[gindex[ 8]]); + const __m256i q2_4 = _mm256_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]], + iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]); + + const __m128i full_signs_l = _mm256_castsi256_si128(full_sign_bits); + const __m128i full_signs_h = _mm256_extractf128_si256(full_sign_bits, 1); + const __m256i full_signs_1 = MM256_SET_M128I(full_signs_l, full_signs_l); + const __m256i full_signs_2 = MM256_SET_M128I(full_signs_h, full_signs_h); + + __m256i signs; + signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_1); + signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone)); + + signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_2); + signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone)); + + signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_1); + signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_3 = _mm256_sign_epi8(q8_3, _mm256_or_si256(signs, mone)); + + signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_2); + signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_4 = _mm256_sign_epi8(q8_4, _mm256_or_si256(signs, mone)); + + const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); + const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); + const __m256i dot3 = _mm256_maddubs_epi16(q2_3, q8s_3); + const __m256i dot4 = _mm256_maddubs_epi16(q2_4, q8s_4); + + const __m256i sc1 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0))); + const __m256i sc2 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1))); + const __m256i sc3 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2))); + const __m256i sc4 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3))); + + sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot1, sc1)); + sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot2, sc2)); + sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot3, sc3)); + sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot4, sc4)); + } + + accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); + + } + + *s = 0.125f * hsum_float_8(accumf); + +#elif defined(__AVX__) + const __m128i mone = _mm_set1_epi8(1); + static const char block_sign_shuffle_mask_1[32] = { + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, + 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, + }; + static const char block_sign_shuffle_mask_2[32] = { + 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, + 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, + }; + static const uint8_t bit_selector_mask_bytes[32] = { + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + const __m128i bit_selector_mask_0 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes); + const __m128i bit_selector_mask_1 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes + 1); + const __m128i block_sign_shuffle_1_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1); + const __m128i block_sign_shuffle_1_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1 + 1); + const __m128i block_sign_shuffle_2_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2); + const __m128i block_sign_shuffle_2_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2 + 1); + + static const uint8_t k_bit_helper[32] = { + 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, + 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, + }; + const __m128i bit_helper_0 = _mm_loadu_si128((const __m128i*)k_bit_helper); + const __m128i bit_helper_1 = _mm_loadu_si128((const __m128i*)k_bit_helper + 1); + const __m128i m511 = _mm_set1_epi16(511); + const __m128i m4 = _mm_set1_epi8(0xf); + const __m128i m1 = _mm_set1_epi8(1); + + uint64_t aux64; + + // somewhat hacky, but gives a significant boost in performance + __m256i aux_gindex; + const uint16_t * gindex = (const uint16_t *)&aux_gindex; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + memcpy(&aux64, x[i].scales, 8); + __m128i stmp = _mm_set1_epi64x(aux64); + stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4)); + const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1); + + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + __m128i sumi2_0 = _mm_setzero_si128(); + __m128i sumi2_1 = _mm_setzero_si128(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) { + + const __m128i q2_data_0 = _mm_loadu_si128((const __m128i*)q2); + const __m128i q2_data_1 = _mm_loadu_si128((const __m128i*)q2 + 1); q2 += 16; + aux_gindex = MM256_SET_M128I(_mm_and_si128(q2_data_1, m511), _mm_and_si128(q2_data_0, m511)); + + const __m128i partial_sign_bits_0 = _mm_srli_epi16(q2_data_0, 9); + const __m128i partial_sign_bits_1 = _mm_srli_epi16(q2_data_1, 9); + const __m128i partial_sign_bits_upper_0 = _mm_srli_epi16(q2_data_0, 13); + const __m128i partial_sign_bits_upper_1 = _mm_srli_epi16(q2_data_1, 13); + const __m128i partial_sign_bits_for_counting_0 = _mm_xor_si128(partial_sign_bits_0, partial_sign_bits_upper_0); + const __m128i partial_sign_bits_for_counting_1 = _mm_xor_si128(partial_sign_bits_1, partial_sign_bits_upper_1); + + const __m128i odd_bits_0 = _mm_shuffle_epi8(bit_helper_0, partial_sign_bits_for_counting_0); + const __m128i odd_bits_1 = _mm_shuffle_epi8(bit_helper_1, partial_sign_bits_for_counting_1); + const __m128i full_sign_bits_0 = _mm_or_si128(partial_sign_bits_0, odd_bits_0); + const __m128i full_sign_bits_1 = _mm_or_si128(partial_sign_bits_1, odd_bits_1); + + const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_3_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_3_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_4_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_4_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + + const __m128i q2_1_0 = _mm_set_epi64x(iq2xs_grid[gindex[1]], iq2xs_grid[gindex[0]]); + const __m128i q2_1_1 = _mm_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]]); + const __m128i q2_2_0 = _mm_set_epi64x(iq2xs_grid[gindex[5]], iq2xs_grid[gindex[4]]); + const __m128i q2_2_1 = _mm_set_epi64x(iq2xs_grid[gindex[7]], iq2xs_grid[gindex[6]]); + const __m128i q2_3_0 = _mm_set_epi64x(iq2xs_grid[gindex[9]], iq2xs_grid[gindex[8]]); + const __m128i q2_3_1 = _mm_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]]); + const __m128i q2_4_0 = _mm_set_epi64x(iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]); + const __m128i q2_4_1 = _mm_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]]); + + // AVX2 full_signs_1 is full_sign_bits_0 here + // AVX2 full_signs_2 is full_sign_bits_1 here + __m128i signs_0, signs_1; + signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_0); + signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_1); + signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0); + signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1); + const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, _mm_or_si128(signs_0, mone)); + const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, _mm_or_si128(signs_1, mone)); + + signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_0); + signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_1); + signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0); + signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1); + const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, _mm_or_si128(signs_0, mone)); + const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, _mm_or_si128(signs_1, mone)); + + signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_0); + signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_1); + signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0); + signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1); + const __m128i q8s_3_0 = _mm_sign_epi8(q8_3_0, _mm_or_si128(signs_0, mone)); + const __m128i q8s_3_1 = _mm_sign_epi8(q8_3_1, _mm_or_si128(signs_1, mone)); + + signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_0); + signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_1); + signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0); + signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1); + const __m128i q8s_4_0 = _mm_sign_epi8(q8_4_0, _mm_or_si128(signs_0, mone)); + const __m128i q8s_4_1 = _mm_sign_epi8(q8_4_1, _mm_or_si128(signs_1, mone)); + + const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); + const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); + const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); + const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); + const __m128i dot3_0 = _mm_maddubs_epi16(q2_3_0, q8s_3_0); + const __m128i dot3_1 = _mm_maddubs_epi16(q2_3_1, q8s_3_1); + const __m128i dot4_0 = _mm_maddubs_epi16(q2_4_0, q8s_4_0); + const __m128i dot4_1 = _mm_maddubs_epi16(q2_4_1, q8s_4_1); + + __m128i sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0)); + const __m128i sc1_0 = _mm_cvtepi8_epi16(sc_tmp); + const __m128i sc1_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8)); + sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1)); + const __m128i sc2_0 = _mm_cvtepi8_epi16(sc_tmp); + const __m128i sc2_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8)); + sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2)); + const __m128i sc3_0 = _mm_cvtepi8_epi16(sc_tmp); + const __m128i sc3_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8)); + sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3)); + const __m128i sc4_0 = _mm_cvtepi8_epi16(sc_tmp); + const __m128i sc4_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8)); + + sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot1_0, sc1_0)); + sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot1_1, sc1_1)); + sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot2_0, sc2_0)); + sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot2_1, sc2_1)); + sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot3_0, sc3_0)); + sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot3_1, sc3_1)); + sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot4_0, sc4_0)); + sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot4_1, sc4_1)); + } + + accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); + + } + + *s = 0.125f * hsum_float_8(accumf); + +#elif defined(__loongarch_asx) + + const __m256i mone = __lasx_xvreplgr2vr_b(1); + static const char block_sign_shuffle_mask_1[32] = { + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, + 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, + }; + static const char block_sign_shuffle_mask_2[32] = { + 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, + 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, + }; + static const uint8_t bit_selector_mask_bytes[32] = { + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + const __m256i bit_selector_mask = __lasx_xvld((const __m256i*)bit_selector_mask_bytes, 0); + const __m256i block_sign_shuffle_1 = __lasx_xvld((const __m256i*)block_sign_shuffle_mask_1, 0); + const __m256i block_sign_shuffle_2 = __lasx_xvld((const __m256i*)block_sign_shuffle_mask_2, 0); + + static const uint8_t k_bit_helper[32] = { + 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, + 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, + }; + const __m256i bit_helper = __lasx_xvld((const __m256i*)k_bit_helper, 0); + const __m256i m511 = __lasx_xvreplgr2vr_h(511); + const __m128i m4 = __lsx_vreplgr2vr_b(0xf); + const __m128i m1 = __lsx_vreplgr2vr_b(1); + + uint64_t aux64; + + // somewhat hacky, but gives a significant boost in performance + __m256i aux_gindex; + const uint16_t * gindex = (const uint16_t *)&aux_gindex; + + __m256 accumf = (__m256)__lasx_xvldi(0); + for (int i = 0; i < nb; ++i) { + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + memcpy(&aux64, x[i].scales, 8); + __m128i stmp = __lsx_vreplgr2vr_d(aux64); + stmp = __lsx_vilvl_b( __lsx_vand_v(__lsx_vsrli_h(stmp, 4), m4), __lsx_vand_v(stmp, m4)); + const __m128i scales = __lsx_vadd_b(__lsx_vslli_h(stmp, 1), m1); + + __m256i sumi1 = __lasx_xvldi(0); + __m256i sumi2 = __lasx_xvldi(0); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) { + + const __m256i q2_data = __lasx_xvld((const __m256i*)q2, 0); q2 += 16; + aux_gindex = __lasx_xvand_v(q2_data, m511); + + const __m256i partial_sign_bits = __lasx_xvsrli_h(q2_data, 9); + const __m256i partial_sign_bits_upper = __lasx_xvsrli_h(q2_data, 13); + const __m256i partial_sign_bits_for_counting = __lasx_xvxor_v(partial_sign_bits, partial_sign_bits_upper); + + const __m256i odd_bits = lasx_shuffle_b(bit_helper, partial_sign_bits_for_counting); + const __m256i full_sign_bits = __lasx_xvor_v(partial_sign_bits, odd_bits); + + const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q8_3 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q8_4 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + + const __m256i q2_1 = lasx_set_d(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]], + iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]); + const __m256i q2_2 = lasx_set_d(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]], + iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]); + const __m256i q2_3 = lasx_set_d(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]], + iq2xs_grid[gindex[ 9]], iq2xs_grid[gindex[ 8]]); + const __m256i q2_4 = lasx_set_d(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]], + iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]); + + const __m128i full_signs_l = lasx_extracti128(full_sign_bits, 0); + const __m128i full_signs_h = lasx_extracti128(full_sign_bits, 1); + const __m256i full_signs_1 = lasx_insertf128(full_signs_l, full_signs_l); + const __m256i full_signs_2 = lasx_insertf128(full_signs_h, full_signs_h); + + __m256i signs; + signs = lasx_shuffle_b(full_signs_1, block_sign_shuffle_1); + signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_1 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_1); + + signs = lasx_shuffle_b(full_signs_1, block_sign_shuffle_2); + signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_2 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_2); + + signs = lasx_shuffle_b(full_signs_2, block_sign_shuffle_1); + signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_3 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_3); + + signs = lasx_shuffle_b(full_signs_2, block_sign_shuffle_2); + signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_4 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_4); + + const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); + const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); + const __m256i dot3 = lasx_maddubs_h(q2_3, q8s_3); + const __m256i dot4 = lasx_maddubs_h(q2_4, q8s_4); + + const __m256i sc1 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+0))); + const __m256i sc2 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+1))); + const __m256i sc3 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+2))); + const __m256i sc4 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+3))); + + sumi1 = __lasx_xvadd_w(sumi1, lasx_madd_h(dot1, sc1)); + sumi2 = __lasx_xvadd_w(sumi2, lasx_madd_h(dot2, sc2)); + sumi1 = __lasx_xvadd_w(sumi1, lasx_madd_h(dot3, sc3)); + sumi2 = __lasx_xvadd_w(sumi2, lasx_madd_h(dot4, sc4)); + } + + accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf); + + } + + *s = 0.125f * hsum_float_8(accumf); +#elif defined(__POWER9_VECTOR__) + const vector int v0 = vec_splats((int32_t)0); + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + + const uint16_t * restrict q2 = x[i].qs; + const uint8_t * restrict sc = x[i].scales; + const int8_t * restrict q8 = y[i].qs; + + for (int j = 0; j < QK_K/64; ++j) { + __builtin_prefetch(q2, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed long long aux64x2_0 = {*(const int64_t *)(iq2xs_grid + (q2[0] & 511)), *(const int64_t *)(iq2xs_grid + (q2[1] & 511))}; + vector signed long long aux64x2_1 = {*(const int64_t *)(iq2xs_grid + (q2[2] & 511)), *(const int64_t *)(iq2xs_grid + (q2[3] & 511))}; + vector signed long long aux64x2_2 = {*(const int64_t *)(iq2xs_grid + (q2[4] & 511)), *(const int64_t *)(iq2xs_grid + (q2[5] & 511))}; + vector signed long long aux64x2_3 = {*(const int64_t *)(iq2xs_grid + (q2[6] & 511)), *(const int64_t *)(iq2xs_grid + (q2[7] & 511))}; + + vector signed long long vsigns0 = {*(const int64_t *)(signs64 + ((q2[0] >> 9))), *(const int64_t *)(signs64 + ((q2[1] >> 9)))}; + vector signed long long vsigns1 = {*(const int64_t *)(signs64 + ((q2[2] >> 9))), *(const int64_t *)(signs64 + ((q2[3] >> 9)))}; + vector signed long long vsigns2 = {*(const int64_t *)(signs64 + ((q2[4] >> 9))), *(const int64_t *)(signs64 + ((q2[5] >> 9)))}; + vector signed long long vsigns3 = {*(const int64_t *)(signs64 + ((q2[6] >> 9))), *(const int64_t *)(signs64 + ((q2[7] >> 9)))}; + q2 += 8; + + vector signed char q2x0 = (vector signed char)vec_mul((vector signed char)vsigns0, (vector signed char)aux64x2_0); + vector signed char q2x1 = (vector signed char)vec_mul((vector signed char)vsigns1, (vector signed char)aux64x2_1); + vector signed char q2x2 = (vector signed char)vec_mul((vector signed char)vsigns2, (vector signed char)aux64x2_2); + vector signed char q2x3 = (vector signed char)vec_mul((vector signed char)vsigns3, (vector signed char)aux64x2_3); + + vector signed char q8y0 = vec_xl( 0, q8); + vector signed char q8y1 = vec_xl(16, q8); + vector signed char q8y2 = vec_xl(32, q8); + vector signed char q8y3 = vec_xl(48, q8); + q8 += 64; + + vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q2x1, q8y1), vec_mulo(q2x1, q8y1)); + vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2)); + vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3)); + + const uint16_t ls0 = (uint16_t)(sc[0] & 0xf); + const uint16_t ls1 = (uint16_t)(sc[0] >> 4); + const uint16_t ls2 = (uint16_t)(sc[1] & 0xf); + const uint16_t ls3 = (uint16_t)(sc[1] >> 4); + sc += 2; + + vector signed short vscales0 = vec_splats((int16_t)(2*ls0+1)); + vector signed short vscales1 = vec_splats((int16_t)(2*ls1+1)); + vector signed short vscales2 = vec_splats((int16_t)(2*ls2+1)); + vector signed short vscales3 = vec_splats((int16_t)(2*ls3+1)); + + vsumi0 = vec_msum(qv0, vscales0, vsumi0); + vsumi1 = vec_msum(qv1, vscales1, vsumi1); + vsumi2 = vec_msum(qv2, vscales2, vsumi2); + vsumi3 = vec_msum(qv3, vscales3, vsumi3); + } + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = 0.125f * vec_extract(vsumf0, 0); +#else + + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const uint8_t * restrict sc = x[i].scales; + const int8_t * restrict q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + const uint16_t ls1 = 2*(sc[ib32] & 0xf) + 1; + const uint16_t ls2 = 2*(sc[ib32] >> 4) + 1; + int32_t sumi = 0; + for (int l = 0; l < 2; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511)); + const uint8_t signs = ksigns_iq2xs[q2[l] >> 9]; + for (int j = 0; j < 8; ++j) { + sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += sumi * ls1; + sumi = 0; + for (int l = 2; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511)); + const uint8_t signs = ksigns_iq2xs[q2[l] >> 9]; + for (int j = 0; j < 8; ++j) { + sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += sumi * ls2; + q2 += 4; + } + sumf += d * bsum; + } + *s = 0.125f * sumf; +#endif +} + +void lm_ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq2_s * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#if defined(__ARM_NEON) + + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; + + const lm_ggml_uint8x16x2_t mask1 = lm_ggml_vld1q_u8_x2(k_mask1); + const uint8x16_t mask2 = vld1q_u8(k_mask2); + const uint8x16_t m1 = vdupq_n_u8(1); + const int32x4_t vzero = vdupq_n_s32(0); + + uint8x16x2_t vs; + lm_ggml_int8x16x4_t q2s; + lm_ggml_int8x16x4_t q8b; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * restrict qs = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8); + const int8_t * restrict q8 = y[i].qs; + + int sumi1 = 0, sumi2 = 0; + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + q8b = lm_ggml_vld1q_s8_x4(q8); q8 += 64; + q2s.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[0] | ((qh[ib32+0] << 8) & 0x300)))), + vld1_s8((const int8_t *)(iq2s_grid + (qs[1] | ((qh[ib32+0] << 6) & 0x300))))); + q2s.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[2] | ((qh[ib32+0] << 4) & 0x300)))), + vld1_s8((const int8_t *)(iq2s_grid + (qs[3] | ((qh[ib32+0] << 2) & 0x300))))); + q2s.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[4] | ((qh[ib32+1] << 8) & 0x300)))), + vld1_s8((const int8_t *)(iq2s_grid + (qs[5] | ((qh[ib32+1] << 6) & 0x300))))); + q2s.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[6] | ((qh[ib32+1] << 4) & 0x300)))), + vld1_s8((const int8_t *)(iq2s_grid + (qs[7] | ((qh[ib32+1] << 2) & 0x300))))); + qs += 8; + + vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16))); + vs.val[1] = vandq_u8(lm_ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); + vs.val[0] = vandq_u8(lm_ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); + vs.val[0] = vceqq_u8(vs.val[0], mask2); + vs.val[1] = vceqq_u8(vs.val[1], mask2); + + q2s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[0]); + q2s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[1]); + + vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16))); + vs.val[1] = vandq_u8(lm_ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); + vs.val[0] = vandq_u8(lm_ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); + vs.val[0] = vceqq_u8(vs.val[0], mask2); + vs.val[1] = vceqq_u8(vs.val[1], mask2); + + signs += 4; + + q2s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[2]); + q2s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[3]); + + const int32x4_t p1 = lm_ggml_vdotq_s32(vzero, q2s.val[0], q8b.val[0]); + const int32x4_t p2 = lm_ggml_vdotq_s32(vzero, q2s.val[1], q8b.val[1]); + const int32x4_t p3 = lm_ggml_vdotq_s32(vzero, q2s.val[2], q8b.val[2]); + const int32x4_t p4 = lm_ggml_vdotq_s32(vzero, q2s.val[3], q8b.val[3]); + + sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32+0] & 0xf)); + sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32+0] >> 4)); + sumi1 += vaddvq_s32(p3) * (1 + 2*(x[i].scales[ib32+1] & 0xf)); + sumi2 += vaddvq_s32(p4) * (1 + 2*(x[i].scales[ib32+1] >> 4)); + } + sumf += d*(sumi1 + sumi2); + } + + *s = 0.125f * sumf; + +#elif defined(__AVX2__) + + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + const __m128i m4 = _mm_set1_epi8(0xf); + const __m128i m1 = _mm_set1_epi8(1); + + const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1); + const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2); + + uint64_t aux64; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict qs = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8); + const int8_t * restrict q8 = y[i].qs; + + memcpy(&aux64, x[i].scales, 8); + const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1); + const __m256i scales16 = _mm256_cvtepi8_epi16(scales8); // 0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15 + + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q2_1 = _mm256_set_epi64x(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)], + iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)], + iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)], + iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]); + const __m256i q2_2 = _mm256_set_epi64x(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)], + iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)], + iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)], + iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]); + qs += 8; + + __m256i aux256 = _mm256_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16)); + aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); + const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2); + const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1); + + aux256 = _mm256_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16)); + aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); + const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2); + const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2); + + signs += 4; + + const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); // blocks 2*ib32+0, 2*ib32+1 + const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); // blocks 2*ib32+2, 2*ib32+3 + + const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_shuffle_epi8(scales16, get_scale_shuffle_k4(ib32+0))); + const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_shuffle_epi8(scales16, get_scale_shuffle_k4(ib32+1))); + sumi1 = _mm256_add_epi32(sumi1, p1); + sumi2 = _mm256_add_epi32(sumi2, p2); + } + + accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); + + } + + *s = 0.125f * hsum_float_8(accumf); + +#elif defined(__AVX__) + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + const __m128i m4 = _mm_set1_epi8(0xf); + const __m128i m1 = _mm_set1_epi8(1); + + const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1); + const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1); + const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2); + const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1); + + uint64_t aux64; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict qs = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8); + const int8_t * restrict q8 = y[i].qs; + + memcpy(&aux64, x[i].scales, 8); + const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1); + const __m128i scales16_0 = _mm_cvtepi8_epi16(scales8); + const __m128i scales16_1 = _mm_cvtepi8_epi16(_mm_srli_si128(scales8, 8)); + + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + __m128i sumi2_0 = _mm_setzero_si128(); + __m128i sumi2_1 = _mm_setzero_si128(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q2_1_0 = _mm_set_epi64x(iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)], + iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]); + const __m128i q2_1_1 = _mm_set_epi64x(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)], + iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)]); + const __m128i q2_2_0 = _mm_set_epi64x(iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)], + iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]); + const __m128i q2_2_1 = _mm_set_epi64x(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)], + iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)]); + qs += 8; + + __m128i aux128_0 = _mm_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16)); + __m128i aux128_1 = aux128_0; + aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0); + aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1); + const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0); + const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1); + const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0); + const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1); + + aux128_0 = _mm_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16)); + aux128_1 = aux128_0; + aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0); + aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1); + const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0); + const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1); + const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0); + const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1); + + signs += 4; + + const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); + const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); + const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); + const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); + + const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 0))); + const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 1))); + const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 0))); + const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 1))); + sumi1_0 = _mm_add_epi32(sumi1_0, p1_0); + sumi1_1 = _mm_add_epi32(sumi1_1, p1_1); + sumi2_0 = _mm_add_epi32(sumi2_0, p2_0); + sumi2_1 = _mm_add_epi32(sumi2_1, p2_1); + } + + accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); + + } + + *s = 0.125f * hsum_float_8(accumf); + +#elif defined(__POWER9_VECTOR__) + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; + + const vector int v0 = vec_splats((int32_t)0); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + const vector unsigned char mask0 = vec_xl( 0, k_mask1); + const vector unsigned char mask1 = vec_xl(16, k_mask1); + const vector signed char mask2 = (vector signed char)vec_xl( 0, k_mask2); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + + const uint8_t * restrict q2 = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8); + const uint8_t * restrict sc = x[i].scales; + const int8_t * restrict q8 = y[i].qs; + + for (int j = 0; j < QK_K/32; j += 2) { + __builtin_prefetch(q2, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed long long aux64x2_0 = {*(const int64_t *)(iq2s_grid + (q2[0] | ((qh[0] << 8) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[1] | ((qh[0] << 6) & 0x300)))}; + vector signed long long aux64x2_1 = {*(const int64_t *)(iq2s_grid + (q2[2] | ((qh[0] << 4) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[3] | ((qh[0] << 2) & 0x300)))}; + vector signed long long aux64x2_2 = {*(const int64_t *)(iq2s_grid + (q2[4] | ((qh[1] << 8) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[5] | ((qh[1] << 6) & 0x300)))}; + vector signed long long aux64x2_3 = {*(const int64_t *)(iq2s_grid + (q2[6] | ((qh[1] << 4) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[7] | ((qh[1] << 2) & 0x300)))}; + q2 += 8; + qh += 2; + + vector signed char vsigns01 = (vector signed char)vec_splats(*(const uint32_t *)&signs[0]); + vector signed char vsigns23 = (vector signed char)vec_splats(*(const uint32_t *)&signs[2]); + signs += 4; + + vector signed char vsigns0 = vec_perm(vsigns01, vsigns01, mask0); + vector signed char vsigns1 = vec_perm(vsigns01, vsigns01, mask1); + vector signed char vsigns2 = vec_perm(vsigns23, vsigns23, mask0); + vector signed char vsigns3 = vec_perm(vsigns23, vsigns23, mask1); + + vsigns0 = (vector signed char)vec_cmpeq(vec_and(vsigns0, mask2), mask2); + vsigns1 = (vector signed char)vec_cmpeq(vec_and(vsigns1, mask2), mask2); + vsigns2 = (vector signed char)vec_cmpeq(vec_and(vsigns2, mask2), mask2); + vsigns3 = (vector signed char)vec_cmpeq(vec_and(vsigns3, mask2), mask2); + + vector signed char q2x0 = vec_sub(vec_xor(vsigns0, (vector signed char)aux64x2_0), vsigns0); + vector signed char q2x1 = vec_sub(vec_xor(vsigns1, (vector signed char)aux64x2_1), vsigns1); + vector signed char q2x2 = vec_sub(vec_xor(vsigns2, (vector signed char)aux64x2_2), vsigns2); + vector signed char q2x3 = vec_sub(vec_xor(vsigns3, (vector signed char)aux64x2_3), vsigns3); + + vector signed char q8y0 = vec_xl( 0, q8); + vector signed char q8y1 = vec_xl(16, q8); + vector signed char q8y2 = vec_xl(32, q8); + vector signed char q8y3 = vec_xl(48, q8); + q8 += 64; + + vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q2x1, q8y1), vec_mulo(q2x1, q8y1)); + vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2)); + vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3)); + + const uint16_t ls0 = (uint16_t)(sc[0] & 0xf); + const uint16_t ls1 = (uint16_t)(sc[0] >> 4); + const uint16_t ls2 = (uint16_t)(sc[1] & 0xf); + const uint16_t ls3 = (uint16_t)(sc[1] >> 4); + sc += 2; + + vector signed short vscales0 = vec_splats((int16_t)(2*ls0+1)); + vector signed short vscales1 = vec_splats((int16_t)(2*ls1+1)); + vector signed short vscales2 = vec_splats((int16_t)(2*ls2+1)); + vector signed short vscales3 = vec_splats((int16_t)(2*ls3+1)); + + vsumi0 = vec_msum(qv0, vscales0, vsumi0); + vsumi1 = vec_msum(qv1, vscales1, vsumi1); + vsumi2 = vec_msum(qv2, vscales2, vsumi2); + vsumi3 = vec_msum(qv3, vscales3, vsumi3); + } + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = 0.125f * vec_extract(vsumf0, 0); + +#elif defined(__loongarch_asx) + + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + + const __m128i m4 = __lsx_vreplgr2vr_b(0xf); + const __m128i m1 = __lsx_vreplgr2vr_b(1); + + const __m256i mask1 = __lasx_xvld((const __m256i*)k_mask1, 0); + const __m256i mask2 = __lasx_xvld((const __m256i*)k_mask2, 0); + uint64_t aux64; + + __m256 accumf = (__m256)__lasx_xvldi(0); + for (int i = 0; i < nb; ++i) { + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict qs = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8); + const int8_t * restrict q8 = y[i].qs; + + __m128i tmp1; + memcpy(&aux64, x[i].scales, 8); + tmp1 = __lsx_vinsgr2vr_d(tmp1, aux64, 0); + tmp1 = __lsx_vinsgr2vr_d(tmp1, aux64 >> 4, 1); + const __m128i scales8 = __lsx_vadd_b(__lsx_vslli_h(__lsx_vand_v(tmp1, m4), 1), m1); + const __m256i scales16 = lasx_ext8_16(scales8); // 0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15 + + __m256i sumi1 = __lasx_xvldi(0); + __m256i sumi2 = __lasx_xvldi(0); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q2_1 = lasx_set_d(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)], + iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)], + iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)], + iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]); + const __m256i q2_2 = lasx_set_d(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)], + iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)], + iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)], + iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]); + qs += 8; + + __m256i aux256 = __lasx_xvreplgr2vr_w(signs[0] | ((uint32_t) signs[1] << 16)); + aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2); + const __m256i s2_1 = __lasx_xvseq_b(aux256, mask2); + const __m256i q8s_1 = __lasx_xvsub_b(__lasx_xvxor_v(s2_1, q8_1), s2_1); + + aux256 = __lasx_xvreplgr2vr_w(signs[2] | ((uint32_t) signs[3] << 16)); + aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2); + const __m256i s2_2 = __lasx_xvseq_b(aux256, mask2); + const __m256i q8s_2 = __lasx_xvsub_b(__lasx_xvxor_v(s2_2, q8_2), s2_2); + + signs += 4; + + const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); // blocks 2*ib32+0, 2*ib32+1 + const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); // blocks 2*ib32+2, 2*ib32+3 + + const __m256i p1 = lasx_madd_h(dot1, lasx_shuffle_b(scales16, get_scale_shuffle_k4(ib32+0))); + const __m256i p2 = lasx_madd_h(dot2, lasx_shuffle_b(scales16, get_scale_shuffle_k4(ib32+1))); + sumi1 = __lasx_xvadd_w(sumi1, p1); + sumi2 = __lasx_xvadd_w(sumi2, p2); + } + + accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf); + } + + *s = 0.125f * hsum_float_8(accumf); + +#else + + float sumf = 0; + for (int i = 0; i < nb; i++) { + + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint8_t * signs = qs + QK_K/8; + + int bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + int ls1 = 1 + 2*(x[i].scales[ib32] & 0xf); + int ls2 = 1 + 2*(x[i].scales[ib32] >> 4); + int sumi1 = 0, sumi2 = 0; + for (int l = 0; l < 2; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300))); + for (int j = 0; j < 8; ++j) { + sumi1 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + for (int l = 2; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300))); + for (int j = 0; j < 8; ++j) { + sumi2 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += ls1 * sumi1 + ls2 * sumi2; + qs += 4; + signs += 4; + } + + sumf += d * bsum; + } + + *s = 0.125f * sumf; + +#endif + +} + +void lm_ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq3_xxs * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#if defined(__ARM_NEON) + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[2]; + + lm_ggml_int8x16x4_t q3s; + lm_ggml_int8x16x4_t q8b; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict gas = x[i].qs + QK_K/4; + const int8_t * restrict q8 = y[i].qs; + float sumf1 = 0, sumf2 = 0; + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + q8b = lm_ggml_vld1q_s8_x4(q8); q8 += 64; + memcpy(aux32, gas, 2*sizeof(uint32_t)); gas += 2*sizeof(uint32_t); + const uint32x4_t aux32x4_0 = lm_ggml_vld1q_u32(iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]); + const uint32x4_t aux32x4_1 = lm_ggml_vld1q_u32(iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]); + const uint32x4_t aux32x4_2 = lm_ggml_vld1q_u32(iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]); + const uint32x4_t aux32x4_3 = lm_ggml_vld1q_u32(iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]); + q3 += 16; + q3s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 7) & 127)))); + q3s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 21) & 127)))); + q3s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127)))); + q3s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127)))); + q3s.val[0] = vmulq_s8(q3s.val[0], vreinterpretq_s8_u32(aux32x4_0)); + q3s.val[1] = vmulq_s8(q3s.val[1], vreinterpretq_s8_u32(aux32x4_1)); + q3s.val[2] = vmulq_s8(q3s.val[2], vreinterpretq_s8_u32(aux32x4_2)); + q3s.val[3] = vmulq_s8(q3s.val[3], vreinterpretq_s8_u32(aux32x4_3)); + const int32x4_t p1 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]); + const int32x4_t p2 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]); + sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[0] >> 28)); + sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[1] >> 28)); + } + sumf += d*(sumf1 + sumf2); + } + *s = 0.5f * sumf; + +#elif defined(__AVX2__) + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[2]; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict gas = x[i].qs + QK_K/4; + const int8_t * restrict q8 = y[i].qs; + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q2_1 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]], + iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); + q3 += 8; + const __m256i q2_2 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]], + iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); + q3 += 8; + memcpy(aux32, gas, 8); gas += 8; + const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127], + signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]); + const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127], + signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); + const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1); + const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2); + const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); + const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); + const uint16_t ls1 = aux32[0] >> 28; + const uint16_t ls2 = aux32[1] >> 28; + const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1)); + const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1)); + sumi1 = _mm256_add_epi32(sumi1, p1); + sumi2 = _mm256_add_epi32(sumi2, p2); + } + + accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); + + } + + *s = 0.25f * hsum_float_8(accumf); + +#elif defined(__AVX__) + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[2]; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict gas = x[i].qs + QK_K/4; + const int8_t * restrict q8 = y[i].qs; + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + __m128i sumi2_0 = _mm_setzero_si128(); + __m128i sumi2_1 = _mm_setzero_si128(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q2_1_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); + const __m128i q2_1_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]); + q3 += 8; + const __m128i q2_2_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); + const __m128i q2_2_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]); + q3 += 8; + memcpy(aux32, gas, 8); gas += 8; + const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]); + const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127]); + const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); + const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]); + const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0); + const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1); + const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0); + const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1); + const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); + const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); + const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); + const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); + const uint16_t ls1 = aux32[0] >> 28; + const uint16_t ls2 = aux32[1] >> 28; + const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1)); + const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1)); + const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1)); + const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1)); + sumi1_0 = _mm_add_epi32(sumi1_0, p1_0); + sumi1_1 = _mm_add_epi32(sumi1_1, p1_1); + sumi2_0 = _mm_add_epi32(sumi2_0, p2_0); + sumi2_1 = _mm_add_epi32(sumi2_1, p2_1); + } + + accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); + + } + + *s = 0.25f * hsum_float_8(accumf); + +#elif defined(__POWER9_VECTOR__) + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + const vector int v0 = vec_splats((int32_t)0); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + + const uint8_t * restrict q3 = x[i].qs; + const uint32_t * restrict signs = (const uint32_t *)(x[i].qs + QK_K/4); + const int8_t * restrict q8 = y[i].qs; + +#pragma GCC unroll 1 + for (int j = 0; j < QK_K/32; j += 2) { + __builtin_prefetch(q3, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector unsigned int aux32x4_0 = {iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]}; + vector unsigned int aux32x4_1 = {iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]}; + vector unsigned int aux32x4_2 = {iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]}; + vector unsigned int aux32x4_3 = {iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]}; + q3 += 16; + + vector unsigned long long aux64x2_0 = {(uint64_t)(signs64[(signs[0] >> 0) & 127]), (uint64_t)(signs64[(signs[0] >> 7) & 127])}; + vector unsigned long long aux64x2_1 = {(uint64_t)(signs64[(signs[0] >> 14) & 127]), (uint64_t)(signs64[(signs[0] >> 21) & 127])}; + vector unsigned long long aux64x2_2 = {(uint64_t)(signs64[(signs[1] >> 0) & 127]), (uint64_t)(signs64[(signs[1] >> 7) & 127])}; + vector unsigned long long aux64x2_3 = {(uint64_t)(signs64[(signs[1] >> 14) & 127]), (uint64_t)(signs64[(signs[1] >> 21) & 127])}; + + vector signed char q3x0 = vec_mul((vector signed char)aux64x2_0, (vector signed char)aux32x4_0); + vector signed char q3x1 = vec_mul((vector signed char)aux64x2_1, (vector signed char)aux32x4_1); + vector signed char q3x2 = vec_mul((vector signed char)aux64x2_2, (vector signed char)aux32x4_2); + vector signed char q3x3 = vec_mul((vector signed char)aux64x2_3, (vector signed char)aux32x4_3); + + vector signed char q8y0 = vec_xl( 0, q8); + vector signed char q8y1 = vec_xl(16, q8); + vector signed char q8y2 = vec_xl(32, q8); + vector signed char q8y3 = vec_xl(48, q8); + q8 += 64; + + vector signed short qv0 = vec_add(vec_mule(q3x0, q8y0), vec_mulo(q3x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q3x1, q8y1), vec_mulo(q3x1, q8y1)); + vector signed short qv2 = vec_add(vec_mule(q3x2, q8y2), vec_mulo(q3x2, q8y2)); + vector signed short qv3 = vec_add(vec_mule(q3x3, q8y3), vec_mulo(q3x3, q8y3)); + + const uint16_t ls0 = (uint16_t)(signs[0] >> 28); + const uint16_t ls1 = (uint16_t)(signs[1] >> 28); + signs += 2; + + vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1)); + vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1)); + + vsumi0 = vec_msum(qv0, vscales01, vsumi0); + vsumi1 = vec_msum(qv1, vscales01, vsumi1); + vsumi2 = vec_msum(qv2, vscales23, vsumi2); + vsumi3 = vec_msum(qv3, vscales23, vsumi3); + } + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = 0.25f * vec_extract(vsumf0, 0); + +#elif defined(__loongarch_asx) + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[2]; + + __m256 accumf = (__m256)__lasx_xvldi(0); + for (int i = 0; i < nb; ++i) { + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict gas = x[i].qs + QK_K/4; + const int8_t * restrict q8 = y[i].qs; + __m256i sumi1 = __lasx_xvldi(0); + __m256i sumi2 = __lasx_xvldi(0); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q2_1 = lasx_set_w(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]], + iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); + q3 += 8; + const __m256i q2_2 = lasx_set_w(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]], + iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); + q3 += 8; + memcpy(aux32, gas, 8); gas += 8; + + const __m256i s2_1 = lasx_set_d(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127], + signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]); + const __m256i s2_2 = lasx_set_d(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127], + signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); + const __m256i q8s_1 = __lasx_xvsigncov_b(s2_1, q8_1); + const __m256i q8s_2 = __lasx_xvsigncov_b(s2_2, q8_2); + const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); + const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); + const uint16_t ls1 = aux32[0] >> 28; + const uint16_t ls2 = aux32[1] >> 28; + + const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1)); + const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1)); + sumi1 = __lasx_xvadd_w(sumi1, p1); + sumi2 = __lasx_xvadd_w(sumi2, p2); + } + + accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf); + } + + *s = 0.25f * hsum_float_8(accumf); + +#else + + uint32_t aux32; + + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict gas = x[i].qs + QK_K/4; + const int8_t * restrict q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + memcpy(&aux32, gas, sizeof(uint32_t)); gas += sizeof(uint32_t); + const uint32_t ls = 2*(aux32 >> 28) + 1; + int32_t sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*l+0]); + const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*l+1]); + const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127]; + for (int j = 0; j < 4; ++j) { + sumi += grid1[j] * q8[j+0] * (signs & kmask_iq2xs[j+0] ? -1 : 1); + sumi += grid2[j] * q8[j+4] * (signs & kmask_iq2xs[j+4] ? -1 : 1); + } + q8 += 8; + } + q3 += 8; + bsum += sumi * ls; + } + sumf += d * bsum; + } + *s = 0.25f * sumf; +#endif +} + +void lm_ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq3_s * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#if defined(__ARM_NEON) + + typedef union { + uint16x8_t vec_index; + uint16_t index[8]; + } vec_index_t; + + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; + + static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1}; + + const lm_ggml_uint8x16x2_t mask1 = lm_ggml_vld1q_u8_x2(k_mask1); + const uint8x16_t mask2 = vld1q_u8(k_mask2); + + const int16x8_t hshift = vld1q_s16(k_shift); + const uint16x8_t m256 = vdupq_n_u16(256); + const uint8x16_t m1 = vdupq_n_u8(1); + + uint8x16x2_t vs; + lm_ggml_int8x16x4_t q3s; + lm_ggml_int8x16x4_t q8b; + vec_index_t idx; + + uint32_t scales32[2]; + const uint8_t * scales8 = (const uint8_t *)scales32; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict qs = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint16_t * restrict signs = (const uint16_t *)x[i].signs; + const int8_t * restrict q8 = y[i].qs; + + memcpy(scales32, x[i].scales, 4); + scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101; + scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101; + + int sumi1 = 0, sumi2 = 0; + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + q8b = lm_ggml_vld1q_s8_x4(q8); q8 += 64; + + const uint8x16_t idx_l = vld1q_u8(qs); qs += 16; + idx.vec_index = vorrq_u16(vmovl_u8(vget_low_u8 (idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+0]), hshift), m256)); + const uint32x4_t aux32x4_0 = lm_ggml_vld1q_u32(iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]], + iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]); + const uint32x4_t aux32x4_1 = lm_ggml_vld1q_u32(iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]], + iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]); + idx.vec_index = vorrq_u16(vmovl_u8(vget_high_u8(idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+1]), hshift), m256)); + const uint32x4_t aux32x4_2 = lm_ggml_vld1q_u32(iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]], + iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]); + const uint32x4_t aux32x4_3 = lm_ggml_vld1q_u32(iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]], + iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]); + + + vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16))); + vs.val[1] = vandq_u8(lm_ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); + vs.val[0] = vandq_u8(lm_ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); + vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1); + vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1); + + q3s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_0)); + q3s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_1)); + + vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16))); + vs.val[1] = vandq_u8(lm_ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); + vs.val[0] = vandq_u8(lm_ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); + vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1); + vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1); + + signs += 4; + + q3s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_2)); + q3s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_3)); + + const int32x4_t p1 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]); + const int32x4_t p2 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]); + + sumi1 += vaddvq_s32(p1) * scales8[ib32/2+0]; + sumi2 += vaddvq_s32(p2) * scales8[ib32/2+4]; + } + sumf += d*(sumi1 + sumi2); + } + *s = sumf; + +#elif defined(__AVX2__) + + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1); + const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2); + + const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8); + const __m256i idx_mask = _mm256_set1_epi32(256); + + typedef union { + __m256i vec[2]; + uint32_t index[16]; + } index_t; + + index_t idx; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict qs = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint16_t * restrict signs = (const uint16_t *)x[i].signs; + const int8_t * restrict q8 = y[i].qs; + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i idx_l = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)qs)); qs += 16; + idx.vec[0] = _mm256_set1_epi32(qh[ib32+0]); + idx.vec[1] = _mm256_set1_epi32(qh[ib32+1]); + idx.vec[0] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[0], idx_shift), idx_mask); + idx.vec[1] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[1], idx_shift), idx_mask); + idx.vec[0] = _mm256_or_si256(idx.vec[0], _mm256_cvtepi16_epi32(_mm256_castsi256_si128(idx_l))); + idx.vec[1] = _mm256_or_si256(idx.vec[1], _mm256_cvtepi16_epi32(_mm256_extractf128_si256(idx_l, 1))); + + // At leat on my CPU (Ryzen 7950X), using _mm256_i32gather_epi32 is slower than _mm256_set_epi32. Strange. + //const __m256i q2_1 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[0], 4); + //const __m256i q2_2 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[1], 4); + const __m256i q2_1 = _mm256_set_epi32( + iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]], + iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]] + ); + const __m256i q2_2 = _mm256_set_epi32( + iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]], + iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[ 9]], iq3s_grid[idx.index[ 8]] + ); + + __m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16)); + aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); + const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2); + const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1); + + aux256 = _mm256_set1_epi32(signs[2] | (signs[3] << 16)); + aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); + const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2); + const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2); + + signs += 4; + + const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); + const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); + const uint16_t ls1 = x[i].scales[ib32/2] & 0xf; + const uint16_t ls2 = x[i].scales[ib32/2] >> 4; + const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1)); + const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1)); + sumi1 = _mm256_add_epi32(sumi1, p1); + sumi2 = _mm256_add_epi32(sumi2, p2); + } + + accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); + + } + + *s = hsum_float_8(accumf); + +#elif defined(__AVX__) + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1); + const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1); + const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2); + const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1); + + const __m128i idx_mul_0 = _mm_set_epi32(32, 64, 128, 256); + const __m128i idx_mul_1 = _mm_set_epi32(2, 4, 8, 16); + const __m128i idx_mask = _mm_set1_epi32(256); + + typedef union { + __m128i vec[4]; + uint32_t index[16]; + } index_t; + + index_t idx; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict qs = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint16_t * restrict signs = (const uint16_t *)x[i].signs; + const int8_t * restrict q8 = y[i].qs; + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + __m128i sumi2_0 = _mm_setzero_si128(); + __m128i sumi2_1 = _mm_setzero_si128(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i qs_tmp = _mm_loadu_si128((const __m128i *)qs); + const __m128i idx_l_0 = _mm_cvtepu8_epi16(qs_tmp); + const __m128i idx_l_1 = _mm_cvtepu8_epi16(_mm_srli_si128(qs_tmp, 8)); qs += 16; + idx.vec[0] = _mm_set1_epi32(qh[ib32+0]); + idx.vec[1] = idx.vec[0]; + idx.vec[2] = _mm_set1_epi32(qh[ib32+1]); + idx.vec[3] = idx.vec[2]; + + idx.vec[0] = _mm_and_si128(_mm_mullo_epi32(idx.vec[0], idx_mul_0), idx_mask); + idx.vec[1] = _mm_and_si128(_mm_mullo_epi32(idx.vec[1], idx_mul_1), idx_mask); + idx.vec[2] = _mm_and_si128(_mm_mullo_epi32(idx.vec[2], idx_mul_0), idx_mask); + idx.vec[3] = _mm_and_si128(_mm_mullo_epi32(idx.vec[3], idx_mul_1), idx_mask); + + idx.vec[0] = _mm_or_si128(idx.vec[0], _mm_cvtepi16_epi32(idx_l_0)); + idx.vec[1] = _mm_or_si128(idx.vec[1], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_0, 8))); + idx.vec[2] = _mm_or_si128(idx.vec[2], _mm_cvtepi16_epi32(idx_l_1)); + idx.vec[3] = _mm_or_si128(idx.vec[3], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_1, 8))); + + const __m128i q2_1_0 = _mm_set_epi32(iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]); + const __m128i q2_1_1 = _mm_set_epi32(iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]]); + const __m128i q2_2_0 = _mm_set_epi32(iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[9]], iq3s_grid[idx.index[8]]); + const __m128i q2_2_1 = _mm_set_epi32(iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]]); + + __m128i aux128_0 = _mm_set1_epi32(signs[0] | (signs[1] << 16)); + __m128i aux128_1 = aux128_0; + aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0); + aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1); + const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0); + const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1); + const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0); + const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1); + + aux128_0 = _mm_set1_epi32(signs[2] | (signs[3] << 16)); + aux128_1 = aux128_0; + aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0); + aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1); + const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0); + const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1); + const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0); + const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1); + + signs += 4; + + const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); + const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); + const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); + const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); + const uint16_t ls1 = x[i].scales[ib32/2] & 0xf; + const uint16_t ls2 = x[i].scales[ib32/2] >> 4; + const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1)); + const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1)); + const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1)); + const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1)); + sumi1_0 = _mm_add_epi32(sumi1_0, p1_0); + sumi1_1 = _mm_add_epi32(sumi1_1, p1_1); + sumi2_0 = _mm_add_epi32(sumi2_0, p2_0); + sumi2_1 = _mm_add_epi32(sumi2_1, p2_1); + } + + accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); + + } + + *s = hsum_float_8(accumf); + +#elif defined(__POWER9_VECTOR__) + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; + + const vector int v0 = vec_splats((int32_t)0); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + const vector unsigned char mask0 = vec_xl( 0, k_mask1); + const vector unsigned char mask1 = vec_xl(16, k_mask1); + const vector signed char mask2 = (vector signed char)vec_xl( 0, k_mask2); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint16_t * restrict signs = (const uint16_t *)(x[i].signs); + const uint8_t * restrict sc = x[i].scales; + const int8_t * restrict q8 = y[i].qs; + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + + for (int j = 0; j < QK_K/32; j += 2) { + __builtin_prefetch(q3, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector unsigned int aux32x4_0 = {iq3s_grid[q3[ 0] | ((qh[0] << 8) & 256)], iq3s_grid[q3[ 1] | ((qh[0] << 7) & 256)], + iq3s_grid[q3[ 2] | ((qh[0] << 6) & 256)], iq3s_grid[q3[ 3] | ((qh[0] << 5) & 256)]}; + vector unsigned int aux32x4_1 = {iq3s_grid[q3[ 4] | ((qh[0] << 4) & 256)], iq3s_grid[q3[ 5] | ((qh[0] << 3) & 256)], + iq3s_grid[q3[ 6] | ((qh[0] << 2) & 256)], iq3s_grid[q3[ 7] | ((qh[0] << 1) & 256)]}; + vector unsigned int aux32x4_2 = {iq3s_grid[q3[ 8] | ((qh[1] << 8) & 256)], iq3s_grid[q3[ 9] | ((qh[1] << 7) & 256)], + iq3s_grid[q3[10] | ((qh[1] << 6) & 256)], iq3s_grid[q3[11] | ((qh[1] << 5) & 256)]}; + vector unsigned int aux32x4_3 = {iq3s_grid[q3[12] | ((qh[1] << 4) & 256)], iq3s_grid[q3[13] | ((qh[1] << 3) & 256)], + iq3s_grid[q3[14] | ((qh[1] << 2) & 256)], iq3s_grid[q3[15] | ((qh[1] << 1) & 256)]}; + q3 += 16; + qh += 2; + + vector signed char vsigns01 = (vector signed char)vec_splats(*(const uint32_t *)&signs[0]); + vector signed char vsigns02 = (vector signed char)vec_splats(*(const uint32_t *)&signs[2]); + signs += 4; + + vector signed char vsigns0 = vec_perm(vsigns01, vsigns01, mask0); + vector signed char vsigns1 = vec_perm(vsigns01, vsigns01, mask1); + vector signed char vsigns2 = vec_perm(vsigns02, vsigns02, mask0); + vector signed char vsigns3 = vec_perm(vsigns02, vsigns02, mask1); + + vsigns0 = (vector signed char)vec_cmpeq(vec_and(vsigns0, mask2), mask2); + vsigns1 = (vector signed char)vec_cmpeq(vec_and(vsigns1, mask2), mask2); + vsigns2 = (vector signed char)vec_cmpeq(vec_and(vsigns2, mask2), mask2); + vsigns3 = (vector signed char)vec_cmpeq(vec_and(vsigns3, mask2), mask2); + + vector signed char q3x0 = vec_sub(vec_xor(vsigns0, (vector signed char)aux32x4_0), vsigns0); + vector signed char q3x1 = vec_sub(vec_xor(vsigns1, (vector signed char)aux32x4_1), vsigns1); + vector signed char q3x2 = vec_sub(vec_xor(vsigns2, (vector signed char)aux32x4_2), vsigns2); + vector signed char q3x3 = vec_sub(vec_xor(vsigns3, (vector signed char)aux32x4_3), vsigns3); + + vector signed char q8y0 = vec_xl( 0, q8); + vector signed char q8y1 = vec_xl(16, q8); + vector signed char q8y2 = vec_xl(32, q8); + vector signed char q8y3 = vec_xl(48, q8); + q8 += 64; + + vector signed short qv0 = vec_add(vec_mule(q3x0, q8y0), vec_mulo(q3x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q3x1, q8y1), vec_mulo(q3x1, q8y1)); + vector signed short qv2 = vec_add(vec_mule(q3x2, q8y2), vec_mulo(q3x2, q8y2)); + vector signed short qv3 = vec_add(vec_mule(q3x3, q8y3), vec_mulo(q3x3, q8y3)); + + const uint16_t ls0 = (uint16_t)(sc[0] & 0xf); + const uint16_t ls1 = (uint16_t)(sc[0] >> 4); + sc ++; + + vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1)); + vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1)); + + vsumi0 = vec_msum(qv0, vscales01, vsumi0); + vsumi1 = vec_msum(qv1, vscales01, vsumi1); + vsumi2 = vec_msum(qv2, vscales23, vsumi2); + vsumi3 = vec_msum(qv3, vscales23, vsumi3); + } + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + +#elif defined(__loongarch_asx) + + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + const __m256i mask1 = __lasx_xvld((const __m256i*)k_mask1, 0); + const __m256i mask2 = __lasx_xvld((const __m256i*)k_mask2, 0); + + __m256i idx_shift = lasx_set_w(1, 2, 3, 4, 5, 6, 7, 8); + const __m256i idx_mask = __lasx_xvreplgr2vr_w(256); + + typedef union { + __m256i vec[2]; + uint32_t index[16]; + } index_t; + + index_t idx; + + __m256 accumf = (__m256)__lasx_xvldi(0); + for (int i = 0; i < nb; ++i) { + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict qs = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint16_t * restrict signs = (const uint16_t *)x[i].signs; + const int8_t * restrict q8 = y[i].qs; + __m256i sumi1 = __lasx_xvldi(0); + __m256i sumi2 = __lasx_xvldi(0); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i idx_l = lasx_extu8_16(__lsx_vld(qs, 0)); qs += 16; + idx.vec[0] = __lasx_xvreplgr2vr_w(qh[ib32+0]); + idx.vec[1] = __lasx_xvreplgr2vr_w(qh[ib32+1]); + idx.vec[0] = __lasx_xvand_v(__lasx_xvsll_w(idx.vec[0], idx_shift), idx_mask); + idx.vec[1] = __lasx_xvand_v(__lasx_xvsll_w(idx.vec[1], idx_shift), idx_mask); + idx.vec[0] = __lasx_xvor_v(idx.vec[0], lasx_ext16_32(lasx_extracti128(idx_l, 0))); + idx.vec[1] = __lasx_xvor_v(idx.vec[1], lasx_ext16_32(lasx_extracti128(idx_l, 1))); + + // At leat on my CPU (Ryzen 7950X), using _mm256_i32gather_epi32 is slower than _mm256_set_epi32. Strange. + //const __m256i q2_1 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[0], 4); + //const __m256i q2_2 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[1], 4); + const __m256i q2_1 = lasx_set_w( + iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]], + iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]] + ); + const __m256i q2_2 = lasx_set_w( + iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]], + iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[ 9]], iq3s_grid[idx.index[ 8]] + ); + + __m256i aux256 = __lasx_xvreplgr2vr_w(signs[0] | (signs[1] << 16)); + aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2); + const __m256i s2_1 = __lasx_xvseq_b(aux256, mask2); + const __m256i q8s_1 = __lasx_xvsub_b(__lasx_xvxor_v(s2_1, q8_1), s2_1); + + aux256 = __lasx_xvreplgr2vr_w(signs[2] | (signs[3] << 16)); + aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2); + const __m256i s2_2 = __lasx_xvseq_b(aux256, mask2); + const __m256i q8s_2 = __lasx_xvsub_b(__lasx_xvxor_v(s2_2, q8_2), s2_2); + + signs += 4; + + const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); + const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); + const uint16_t ls1 = x[i].scales[ib32/2] & 0xf; + const uint16_t ls2 = x[i].scales[ib32/2] >> 4; + const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1)); + const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1)); + sumi1 = __lasx_xvadd_w(sumi1, p1); + sumi2 = __lasx_xvadd_w(sumi2, p2); + } + + accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf); + } + + *s = hsum_float_8(accumf); + +#else + + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict qs = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint8_t * restrict signs = x[i].signs; + const int8_t * restrict q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const uint32_t ls1 = 2*(x[i].scales[ib32/2] & 0xf) + 1; + const uint32_t ls2 = 2*(x[i].scales[ib32/2] >> 4) + 1; + int32_t sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256))); + for (int j = 0; j < 4; ++j) { + sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1); + sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1); + } + q8 += 8; + } + qs += 8; + signs += 4; + bsum += sumi * ls1; + sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256))); + for (int j = 0; j < 4; ++j) { + sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1); + sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1); + } + q8 += 8; + } + qs += 8; + signs += 4; + bsum += sumi * ls2; + } + sumf += d * bsum; + } + *s = sumf; +#endif +} + +#if defined(__AVX2__) +static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) { + const __m256i ax = _mm256_sign_epi8(x, x); + const __m256i sy = _mm256_sign_epi8(y, x); + return _mm256_maddubs_epi16(ax, sy); +} +#elif defined(__loongarch_asx) +static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) { + const __m256i ax = __lasx_xvsigncov_b(x, x); + const __m256i sy = __lasx_xvsigncov_b(x, y); + __m256i tmp1, tmp2, tmp3; + tmp1 = __lasx_xvmulwev_h_bu_b(ax, sy); + tmp2 = __lasx_xvmulwod_h_bu_b(ax, sy); + tmp3 = __lasx_xvadd_h(tmp1, tmp2); + return __lasx_xvsat_h(tmp3, 15); +} +#endif + +void lm_ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_s * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#if defined __ARM_NEON + + lm_ggml_int8x16x4_t q1b; + lm_ggml_int8x16x4_t q8b; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint16_t * qh = x[i].qh; + + int sumi1 = 0, sumi2 = 0, sumi3 = 0; + + for (int ib = 0; ib < QK_K/32; ib += 2) { + + q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[ib+0] << 8) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[ib+0] << 5) & 0x700))))); + q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[ib+0] << 2) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[ib+0] >> 1) & 0x700))))); + q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[ib+1] << 8) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[ib+1] << 5) & 0x700))))); + q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[ib+1] << 2) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[ib+1] >> 1) & 0x700))))); + qs += 8; + + q8b = lm_ggml_vld1q_s8_x4(q8); q8 += 64; + + const int32x4_t p1 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[0], q8b.val[0]), q1b.val[1], q8b.val[1]); + const int32x4_t p2 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[2], q8b.val[2]), q1b.val[3], q8b.val[3]); + + const int ls1 = 2*((qh[ib+0] >> 12) & 7) + 1; + const int ls2 = 2*((qh[ib+1] >> 12) & 7) + 1; + sumi1 += vaddvq_s32(p1) * ls1; + sumi2 += vaddvq_s32(p2) * ls2; + sumi3 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * ls1 * (qh[ib+0] & 0x8000 ? -1 : 1) + + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * ls2 * (qh[ib+1] & 0x8000 ? -1 : 1); + + } + + sumf += y[i].d * LM_GGML_FP16_TO_FP32(x[i].d) * (sumi1 + sumi2 + IQ1S_DELTA * sumi3); + } + + *s = sumf; + +#elif defined __AVX2__ + + __m256 accum = _mm256_setzero_ps(); + float accum1 = 0; + for (int i = 0; i < nb; ++i) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint16_t * qh = x[i].qh; + + __m256i sumi = _mm256_setzero_si256(); + int sumi1 = 0; + for (int ib = 0; ib < QK_K/32; ib += 2) { + const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)], + iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]); + const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)], + iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]); + qs += 8; + const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1); + const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2); + const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1; + const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1; + const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(ls1)); + const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(ls2)); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p1, p2)); + sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1 + + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2; + } + + const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); + accum = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi), accum); + accum1 += d * sumi1; + + } + + *s = hsum_float_8(accum) + IQ1S_DELTA * accum1; + +#elif defined __AVX__ + __m256 accum = _mm256_setzero_ps(); + float accum1 = 0; + for (int i = 0; i < nb; ++i) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint16_t * qh = x[i].qh; + + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + int sumi1 = 0; + for (int ib = 0; ib < QK_K/32; ib += 2) { + const __m128i q1b_1_0 = _mm_set_epi64x(iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]); + const __m128i q1b_1_1 = _mm_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)]); + const __m128i q1b_2_0 = _mm_set_epi64x(iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]); + const __m128i q1b_2_1 = _mm_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)]); + qs += 8; + const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + + const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0); + const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1); + const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0); + const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1); + const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1; + const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1; + const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(ls1)); + const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(ls1)); + const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(ls2)); + const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(ls2)); + + sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0)); + sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1)); + sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1 + + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2; + } + + const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); + accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum); + accum1 += d * sumi1; + + } + + *s = hsum_float_8(accum) + IQ1S_DELTA * accum1; + +#elif defined(__POWER9_VECTOR__) + const vector unsigned char v0 = vec_splats((unsigned char)0x0); + const vector unsigned short vsign = vec_splats((unsigned short)0x8000); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector signed int vsumi0 = vec_splats((int32_t)0); + vector signed int vsumi1 = vec_splats((int32_t)0); + vector signed int vsumi2 = vec_splats((int32_t)0); + vector signed int vsumi3 = vec_splats((int32_t)0); + vector signed int vsumi8 = vec_splats((int32_t)0); + + const uint8_t * restrict q1 = x[i].qs; + const uint16_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + const int16_t * restrict qs = y[i].bsums; + + for (int j = 0; j < QK_K/32; j += 2) { + __builtin_prefetch(q1, 0, 1); + __builtin_prefetch(qh, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed long long aux64x2_0 = {*(const int64_t *)(iq1s_grid + (q1[0] | ((qh[0] << 8) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[1] | ((qh[0] << 5) & 0x700)))}; + vector signed long long aux64x2_1 = {*(const int64_t *)(iq1s_grid + (q1[2] | ((qh[0] << 2) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[3] | ((qh[0] >> 1) & 0x700)))}; + vector signed long long aux64x2_2 = {*(const int64_t *)(iq1s_grid + (q1[4] | ((qh[1] << 8) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[5] | ((qh[1] << 5) & 0x700)))}; + vector signed long long aux64x2_3 = {*(const int64_t *)(iq1s_grid + (q1[6] | ((qh[1] << 2) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[7] | ((qh[1] >> 1) & 0x700)))}; + q1 += 8; + + vector signed char q1x0 = (vector signed char)aux64x2_0; + vector signed char q1x1 = (vector signed char)aux64x2_1; + vector signed char q1x2 = (vector signed char)aux64x2_2; + vector signed char q1x3 = (vector signed char)aux64x2_3; + + vector signed char q8y0 = vec_xl( 0, q8); + vector signed char q8y1 = vec_xl(16, q8); + vector signed char q8y2 = vec_xl(32, q8); + vector signed char q8y3 = vec_xl(48, q8); + q8 += 64; + + vector signed short qv0 = vec_add(vec_mule(q1x0, q8y0), vec_mulo(q1x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q1x1, q8y1), vec_mulo(q1x1, q8y1)); + vector signed short qv2 = vec_add(vec_mule(q1x2, q8y2), vec_mulo(q1x2, q8y2)); + vector signed short qv3 = vec_add(vec_mule(q1x3, q8y3), vec_mulo(q1x3, q8y3)); + + const uint16_t ls0 = (uint16_t)((qh[0] >> 12) & 7); + const uint16_t ls1 = (uint16_t)((qh[1] >> 12) & 7); + + vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1)); + vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1)); + vector signed short vscales = vec_sld(vscales23, vscales01, 8); + + vsumi0 = vec_msum(qv0, vscales01, vsumi0); + vsumi1 = vec_msum(qv1, vscales01, vsumi1); + vsumi2 = vec_msum(qv2, vscales23, vsumi2); + vsumi3 = vec_msum(qv3, vscales23, vsumi3); + + vector signed short q8ysums = vec_xl_len(qs, 8); + qs += 4; + q8ysums = vec_mergeh(q8ysums, (vector signed short)v0); + + vector signed short qxh = (vector signed short)vec_sld(vec_splats(qh[1]), vec_splats(qh[0]), 8); + qh += 2; + vector __bool short vsel = vec_cmpge(qxh, (vector signed short)v0); + + vector signed short q8ysum = vec_sel((vector signed short)vec_xor((vector unsigned short)q8ysums, vsign), q8ysums, vsel); + + vsumi8 = vec_add(vec_mule(q8ysum, vscales), vsumi8); + } + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + + vsumf0 = vec_madd(vec_ctf(vsumi8, 0), vec_mul(vd, vec_splats(IQ1S_DELTA)), vsumf0); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + +#elif defined(__loongarch_asx) + + __m256 accum = (__m256)__lasx_xvldi(0); + float accum1 = 0; + for (int i = 0; i < nb; ++i) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint16_t * qh = x[i].qh; + + __m256i sumi = __lasx_xvldi(0); + int sumi1 = 0; + for (int ib = 0; ib < QK_K/32; ib += 2) { + __m256i q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)], 0); + q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], 1); + q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)], 2); + q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], 3); + + __m256i q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)], 0); + q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], 1); + q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)], 2); + q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], 3); + + qs += 8; + const __m256i q8b_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8b_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + + const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1); + const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2); + const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1; + const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1; + + __m256i tmp1, tmp5, tmp6; + tmp1 = __lasx_xvreplgr2vr_h(ls1); + tmp5 = __lasx_xvmulwev_w_h(dot1, tmp1); + tmp6 = __lasx_xvmulwod_w_h(dot1, tmp1); + const __m256i p1 = __lasx_xvadd_w(tmp5, tmp6); + + tmp1 = __lasx_xvreplgr2vr_h(ls2); + tmp5 = __lasx_xvmulwev_w_h(dot2, tmp1); + tmp6 = __lasx_xvmulwod_w_h(dot2, tmp1); + const __m256i p2 = __lasx_xvadd_w(tmp5, tmp6); + + sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p1, p2)); + sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1 + + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2; + } + + const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); + accum = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), accum); + accum1 += d * sumi1; + } + + *s = hsum_float_8(accum) + IQ1S_DELTA * accum1; + +#else + + float sumf = 0; + for (int i = 0; i < nb; i++) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint16_t * qh = x[i].qh; + + int sumi = 0, sumi1 = 0; + for (int ib = 0; ib < QK_K/32; ++ib) { + const int ls = 2*((qh[ib] >> 12) & 7) + 1; + const int delta = qh[ib] & 0x8000 ? -1 : 1; + int lsum = 0; + for (int l = 0; l < 4; ++l) { + const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8))); + for (int j = 0; j < 8; ++j) { + lsum += q8[j] * grid[j]; + } + q8 += 8; + } + sumi += ls * lsum; + sumi1 += ls * delta * (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]); + qs += 4; + } + + sumf += LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1); + } + + *s = sumf; + +#endif +} + +void lm_ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_m * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + + iq1m_scale_t scale; + +#if defined __ARM_NEON + const int32x4_t mask = vdupq_n_s32(0x7); + const int32x4_t mone = vdupq_n_s32(1); + const int32x4_t mzero = vdupq_n_s32(0); + + lm_ggml_int8x16x4_t deltas; + deltas.val[0] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(+1)); + deltas.val[1] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(+1)); + deltas.val[2] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(-1)); + deltas.val[3] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(-1)); + + lm_ggml_int8x16x4_t q1b; + lm_ggml_int8x16x4_t q8b; + + uint32_t aux32; + const uint8_t * aux8 = (const uint8_t *)&aux32; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; + + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + int32x4_t sumi1 = mzero; + int32x4_t sumi2 = mzero; + + for (int ib = 0; ib < QK_K/32; ib += 2) { + + q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[0] << 8) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[0] << 4) & 0x700))))); + q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[1] << 8) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[1] << 4) & 0x700))))); + q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[2] << 8) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[2] << 4) & 0x700))))); + q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[3] << 8) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[3] << 4) & 0x700))))); + + q8b = lm_ggml_vld1q_s8_x4(q8); q8 += 64; + + const int32x4_t p1 = vpaddq_s32(lm_ggml_vdotq_s32(mzero, q1b.val[0], q8b.val[0]), lm_ggml_vdotq_s32(mzero, q1b.val[1], q8b.val[1])); + const int32x4_t p2 = vpaddq_s32(lm_ggml_vdotq_s32(mzero, q1b.val[2], q8b.val[2]), lm_ggml_vdotq_s32(mzero, q1b.val[3], q8b.val[3])); + const int32x4_t p12 = vpaddq_s32(p1, p2); + + const uint32_t * qh32 = (const uint32_t *)qh; // we are 4-byte aligned, so we can do that + aux32 = ((qh32[0] >> 3) & 0x01010101) | ((qh32[0] >> 6) & 0x02020202); + + const int32x4_t p3 = vpaddq_s32(lm_ggml_vdotq_s32(mzero, deltas.val[aux8[0]], q8b.val[0]), lm_ggml_vdotq_s32(mzero, deltas.val[aux8[1]], q8b.val[1])); + const int32x4_t p4 = vpaddq_s32(lm_ggml_vdotq_s32(mzero, deltas.val[aux8[2]], q8b.val[2]), lm_ggml_vdotq_s32(mzero, deltas.val[aux8[3]], q8b.val[3])); + const int32x4_t p34 = vpaddq_s32(p3, p4); + + int32x4_t scales_4 = lm_ggml_vld1q_u32(sc[ib/2] >> 0, sc[ib/2] >> 3, sc[ib/2] >> 6, sc[ib/2] >> 9); + + scales_4 = vaddq_s32(vshlq_n_s32(vandq_s32(scales_4, mask), 1), mone); + + sumi1 = vmlaq_s32(sumi1, scales_4, p12); + sumi2 = vmlaq_s32(sumi2, scales_4, p34); + + qs += 8; qh += 4; + + } + + sumf += y[i].d * LM_GGML_FP16_TO_FP32(scale.f16) * (vaddvq_s32(sumi1) + IQ1M_DELTA * vaddvq_s32(sumi2)); + } + + *s = sumf; + +#elif defined __AVX2__ + + const __m256i mask = _mm256_set1_epi16(0x7); + const __m256i mone = _mm256_set1_epi16(1); + + __m256 accum1 = _mm256_setzero_ps(); + __m256 accum2 = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; + + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib = 0; ib < QK_K/32; ib += 2) { + const __m256i q1b_1 = _mm256_set_epi64x( + iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)], + iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)] + ); + const __m256i q1b_2 = _mm256_set_epi64x( + iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)], + iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)] + ); + const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1); + const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2); + + const __m256i delta1 = _mm256_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101, + qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); + const __m256i delta2 = _mm256_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101, + qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); + + const __m256i dot3 = mul_add_epi8(delta1, q8b_1); + const __m256i dot4 = mul_add_epi8(delta2, q8b_2); + + __m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 3), _mm_set1_epi16(sc[ib/2] >> 0)); + __m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 9), _mm_set1_epi16(sc[ib/2] >> 6)); + + scale1 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale1, mask), 1), mone); + scale2 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale2, mask), 1), mone); + const __m256i p1 = _mm256_madd_epi16(dot1, scale1); + const __m256i p2 = _mm256_madd_epi16(dot2, scale2); + const __m256i p3 = _mm256_madd_epi16(dot3, scale1); + const __m256i p4 = _mm256_madd_epi16(dot4, scale2); + + sumi1 = _mm256_add_epi32(sumi1, _mm256_add_epi32(p1, p2)); + sumi2 = _mm256_add_epi32(sumi2, _mm256_add_epi32(p3, p4)); + + qs += 8; qh += 4; + } + + const __m256 d = _mm256_set1_ps(y[i].d * LM_GGML_FP16_TO_FP32(scale.f16)); + + accum1 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi1), accum1); + accum2 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi2), accum2); + } + + *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2); + +#elif defined __AVX__ + const __m128i mask = _mm_set1_epi16(0x7); + const __m128i mone = _mm_set1_epi16(1); + + __m256 accum1 = _mm256_setzero_ps(); + __m256 accum2 = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; + + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + __m128i sumi2_0 = _mm_setzero_si128(); + __m128i sumi2_1 = _mm_setzero_si128(); + for (int ib = 0; ib < QK_K/32; ib += 2) { + const __m128i q1b_1_0 = _mm_set_epi64x( + iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)]); + const __m128i q1b_1_1 = _mm_set_epi64x( + iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)]); + const __m128i q1b_2_0 = _mm_set_epi64x( + iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)]); + const __m128i q1b_2_1 = _mm_set_epi64x( + iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)]); + const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + + const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0); + const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1); + const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0); + const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1); + + const __m128i delta1_0 = _mm_set_epi64x(qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); + const __m128i delta1_1 = _mm_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); + const __m128i delta2_0 = _mm_set_epi64x(qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); + const __m128i delta2_1 = _mm_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); + + const __m128i dot3_0 = mul_add_epi8_sse(delta1_0, q8b_1_0); + const __m128i dot3_1 = mul_add_epi8_sse(delta1_1, q8b_1_1); + const __m128i dot4_0 = mul_add_epi8_sse(delta2_0, q8b_2_0); + const __m128i dot4_1 = mul_add_epi8_sse(delta2_1, q8b_2_1); + + __m128i scale1_0 = _mm_set1_epi16(sc[ib/2] >> 0); + __m128i scale1_1 = _mm_set1_epi16(sc[ib/2] >> 3); + __m128i scale2_0 = _mm_set1_epi16(sc[ib/2] >> 6); + __m128i scale2_1 = _mm_set1_epi16(sc[ib/2] >> 9); + + scale1_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_0, mask), 1), mone); + scale1_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_1, mask), 1), mone); + scale2_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_0, mask), 1), mone); + scale2_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_1, mask), 1), mone); + const __m128i p1_0 = _mm_madd_epi16(dot1_0, scale1_0); + const __m128i p1_1 = _mm_madd_epi16(dot1_1, scale1_1); + const __m128i p2_0 = _mm_madd_epi16(dot2_0, scale2_0); + const __m128i p2_1 = _mm_madd_epi16(dot2_1, scale2_1); + const __m128i p3_0 = _mm_madd_epi16(dot3_0, scale1_0); + const __m128i p3_1 = _mm_madd_epi16(dot3_1, scale1_1); + const __m128i p4_0 = _mm_madd_epi16(dot4_0, scale2_0); + const __m128i p4_1 = _mm_madd_epi16(dot4_1, scale2_1); + + sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0)); + sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1)); + sumi2_0 = _mm_add_epi32(sumi2_0, _mm_add_epi32(p3_0, p4_0)); + sumi2_1 = _mm_add_epi32(sumi2_1, _mm_add_epi32(p3_1, p4_1)); + + qs += 8; qh += 4; + } + + const __m256 d = _mm256_set1_ps(y[i].d * LM_GGML_FP16_TO_FP32(scale.f16)); + + accum1 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum1); + accum2 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi2_1, sumi2_0))), accum2); + } + + *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2); + +#else + + int sum1[2], sum2[2], delta[4]; + + float sumf = 0; + for (int i = 0; i < nb; i++) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; + + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + int sumi1 = 0, sumi2 = 0; + for (int ib = 0; ib < QK_K/32; ++ib) { + delta[0] = qh[0] & 0x08 ? -1 : 1; + delta[1] = qh[0] & 0x80 ? -1 : 1; + delta[2] = qh[1] & 0x08 ? -1 : 1; + delta[3] = qh[1] & 0x80 ? -1 : 1; + sum1[0] = sum1[1] = sum2[0] = sum2[1] = 0; + for (int l = 0; l < 4; ++l) { + const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((uint16_t)qh[l/2] << (8 - 4*(l%2))) & 0x700))); + int lsum1 = 0, lsum2 = 0; + for (int j = 0; j < 8; ++j) { + lsum1 += q8[j] * grid[j]; + lsum2 += q8[j]; + } + q8 += 8; + sum1[l/2] += lsum1; + sum2[l/2] += lsum2*delta[l]; + } + + const int ls1 = 2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1; + const int ls2 = 2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1; + + sumi1 += sum1[0] * ls1 + sum1[1] * ls2; + sumi2 += sum2[0] * ls1 + sum2[1] * ls2; + qs += 4; + qh += 2; + } + + sumf += LM_GGML_FP16_TO_FP32(scale.f16) * y[i].d * (sumi1 + IQ1M_DELTA * sumi2); + } + + *s = sumf; + +#endif +} + +void lm_ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK4_NL == 0); + static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same"); + + const block_iq4_nl * restrict x = vx; + const block_q8_0 * restrict y = vy; + + const int nb = n / QK4_NL; + + int ib = 0; + float sumf = 0; + +#if defined __ARM_NEON + const int8x16_t values = vld1q_s8(kvalues_iq4nl); + const uint8x16_t m4b = vdupq_n_u8(0x0f); + uint8x16x2_t q4bits; + int8x16x4_t q4b; + int8x16x4_t q8b; + int32x4_t prod_1, prod_2; + + for (; ib + 1 < nb; ib += 2) { + + q4bits.val[0] = vld1q_u8(x[ib + 0].qs); + q4bits.val[1] = vld1q_u8(x[ib + 1].qs); + q8b.val[0] = vld1q_s8(y[ib + 0].qs); + q8b.val[1] = vld1q_s8(y[ib + 0].qs + 16); + q8b.val[2] = vld1q_s8(y[ib + 1].qs); + q8b.val[3] = vld1q_s8(y[ib + 1].qs + 16); + + q4b.val[0] = lm_ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b)); + q4b.val[1] = lm_ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4)); + q4b.val[2] = lm_ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b)); + q4b.val[3] = lm_ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4)); + + prod_1 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]); + prod_2 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]); + + sumf += + LM_GGML_FP16_TO_FP32(x[ib+0].d) * LM_GGML_FP16_TO_FP32(y[ib + 0].d) * vaddvq_s32(prod_1) + + LM_GGML_FP16_TO_FP32(x[ib+1].d) * LM_GGML_FP16_TO_FP32(y[ib + 1].d) * vaddvq_s32(prod_2); + } + +#elif defined __AVX2__ + + const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl); + const __m128i m4b = _mm_set1_epi8(0x0f); + const __m256i mone = _mm256_set1_epi16(1); + + __m256 accum1 = _mm256_setzero_ps(); + __m256 accum2 = _mm256_setzero_ps(); + for (; ib + 1 < nb; ib += 2) { + const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs); + const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs); + const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs); + const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs); + const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)), + _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b))); + const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)), + _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b))); + const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1); + const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); + const __m256i p_1 = _mm256_madd_epi16(p16_1, mone); + const __m256i p_2 = _mm256_madd_epi16(p16_2, mone); + accum1 = _mm256_fmadd_ps(_mm256_set1_ps(LM_GGML_FP16_TO_FP32(y[ib + 0].d)*LM_GGML_FP16_TO_FP32(x[ib + 0].d)), + _mm256_cvtepi32_ps(p_1), accum1); + accum2 = _mm256_fmadd_ps(_mm256_set1_ps(LM_GGML_FP16_TO_FP32(y[ib + 1].d)*LM_GGML_FP16_TO_FP32(x[ib + 1].d)), + _mm256_cvtepi32_ps(p_2), accum2); + } + + sumf = hsum_float_8(_mm256_add_ps(accum1, accum2)); + +#elif defined __AVX__ + const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl); + const __m128i m4b = _mm_set1_epi8(0x0f); + + __m256 accum = _mm256_setzero_ps(); + for (; ib + 1 < nb; ib += 2) { + const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs); + const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs); + const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs); + const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1); + const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs); + const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1); + + const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)); + const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)); + const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)); + const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)); + + const __m256 p = mul_sum_i8_quad_float(q4b_1_0, q4b_1_1, q4b_2_0, q4b_2_1, q8b_1_0, q8b_1_1, q8b_2_0, q8b_2_1); + const __m256 deltas = quad_fp16_delta_float(x[ib].d, y[ib].d, x[ib + 1].d, y[ib + 1].d); + accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum); + } + + sumf = hsum_float_8(accum); + +#elif defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector signed int v0 = vec_splats((int32_t)0); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + + const vector signed char values = vec_xl( 0, kvalues_iq4nl); + +#pragma GCC unroll 4 + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); + + + vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[ib].d)); + vector float vyd = vec_splats(LM_GGML_FP16_TO_FP32(y[ib].d)); + vector float vd = vec_mul(vxd, vyd); + + vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); + vector signed char q4x0 = vec_and(qxs, lowMask); + vector signed char q4x1 = vec_sr(qxs, v4); + + q4x0 = vec_perm(values, values, (vector unsigned char)q4x0); + q4x1 = vec_perm(values, values, (vector unsigned char)q4x1); + + vector signed char q8y0 = vec_xl( 0, y[ib].qs); + vector signed char q8y1 = vec_xl(16, y[ib].qs); + + vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1)); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + + vsumi0 = vec_sum4s(qv0, vsumi0); + vsumi1 = vec_sum4s(qv1, vsumi1); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + } + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + sumf = vec_extract(vsumf0, 0); + +#elif defined (__loongarch_asx) + + const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0); + const __m128i m4b = __lsx_vreplgr2vr_b(0x0f); + const __m256i mone = __lasx_xvreplgr2vr_h(1); + + __m256 accum1 = (__m256)__lasx_xvldi(0); + __m256 accum2 = (__m256)__lasx_xvldi(0); + for (; ib + 1 < nb; ib += 2) { + const __m128i q4bits_1 = __lsx_vld((const __m128i*)x[ib + 0].qs, 0); + const __m128i q4bits_2 = __lsx_vld((const __m128i*)x[ib + 1].qs, 0); + const __m256i q8b_1 = __lasx_xvld((const __m256i *)y[ib + 0].qs, 0); + const __m256i q8b_2 = __lasx_xvld((const __m256i *)y[ib + 1].qs, 0); + const __m256i q4b_1 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b)), + lsx_shuffle_b(values128, __lsx_vand_v(q4bits_1, m4b))); + const __m256i q4b_2 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b)), + lsx_shuffle_b(values128, __lsx_vand_v(q4bits_2, m4b))); + const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1); + const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); + const __m256i p_1 = lasx_madd_h(p16_1, mone); + const __m256i p_2 = lasx_madd_h(p16_2, mone); + accum1 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(LM_GGML_FP16_TO_FP32(y[ib + 0].d)*LM_GGML_FP16_TO_FP32(x[ib + 0].d)), + __lasx_xvffint_s_w(p_1), accum1); + accum2 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(LM_GGML_FP16_TO_FP32(y[ib + 1].d)*LM_GGML_FP16_TO_FP32(x[ib + 1].d)), + __lasx_xvffint_s_w(p_2), accum2); + } + + sumf = hsum_float_8(__lasx_xvfadd_s(accum1, accum2)); + +#endif + for (; ib < nb; ++ib) { + const float d = LM_GGML_FP16_TO_FP32(y[ib].d)*LM_GGML_FP16_TO_FP32(x[ib].d); + int sumi1 = 0, sumi2 = 0; + for (int j = 0; j < QK4_NL/2; ++j) { + sumi1 += y[ib].qs[j+ 0] * kvalues_iq4nl[x[ib].qs[j] & 0xf]; + sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >> 4]; + } + sumf += d * (sumi1 + sumi2); + } + *s = sumf; +} + +void lm_ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_K == 0); + + const block_iq4_xs * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#if defined __ARM_NEON + const int8x16_t values = vld1q_s8(kvalues_iq4nl); + const uint8x16_t m4b = vdupq_n_u8(0x0f); + lm_ggml_uint8x16x2_t q4bits; + lm_ggml_int8x16x4_t q4b; + lm_ggml_int8x16x4_t q8b; + int32x4_t prod_1, prod_2; + + float sumf = 0; + + for (int ibl = 0; ibl < nb; ++ibl) { + + const int8_t * q8 = y[ibl].qs; + const uint8_t * q4 = x[ibl].qs; + uint16_t h = x[ibl].scales_h; + + int sumi1 = 0, sumi2 = 0; + for (int ib = 0; ib < QK_K/64; ++ib) { + + q4bits = lm_ggml_vld1q_u8_x2(q4); q4 += 32; + q8b = lm_ggml_vld1q_s8_x4(q8); q8 += 64; + + q4b.val[0] = lm_ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b)); + q4b.val[1] = lm_ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4)); + q4b.val[2] = lm_ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b)); + q4b.val[3] = lm_ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4)); + + prod_1 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]); + prod_2 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]); + + int ls1 = ((x[ibl].scales_l[ib] & 0xf) | ((h << 4) & 0x30)) - 32; + int ls2 = ((x[ibl].scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32; + h >>= 4; + sumi1 += vaddvq_s32(prod_1) * ls1; + sumi2 += vaddvq_s32(prod_2) * ls2; + + } + + sumf += LM_GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2); + } + + *s = sumf; + +#elif defined __AVX2__ + + const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl); + const __m128i m4b = _mm_set1_epi8(0x0f); + + __m256 accum = _mm256_setzero_ps(); + for (int ibl = 0; ibl < nb; ++ibl) { + const uint8_t * qs = x[ibl].qs; + const int8_t * q8 = y[ibl].qs; + uint16_t sh = x[ibl].scales_h; + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib = 0; ib < QK_K/32; ib += 2) { + const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)qs); qs += 16; + const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)qs); qs += 16; + const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)), + _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b))); + const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)), + _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b))); + const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1); + const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); + const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32; + const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32; + sh >>= 4; + const __m256i p_1 = _mm256_madd_epi16(p16_1, _mm256_set1_epi16(ls1)); + const __m256i p_2 = _mm256_madd_epi16(p16_2, _mm256_set1_epi16(ls2)); + sumi1 = _mm256_add_epi32(p_1, sumi1); + sumi2 = _mm256_add_epi32(p_2, sumi2); + } + accum = _mm256_fmadd_ps(_mm256_set1_ps(LM_GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d), + _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accum); + } + + *s = hsum_float_8(accum); + +#elif defined __AVX__ + const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl); + const __m128i m4b = _mm_set1_epi8(0x0f); + + __m256 accum = _mm256_setzero_ps(); + for (int ibl = 0; ibl < nb; ++ibl) { + const uint8_t * qs = x[ibl].qs; + const int8_t * q8 = y[ibl].qs; + uint16_t sh = x[ibl].scales_h; + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + __m128i sumi2_0 = _mm_setzero_si128(); + __m128i sumi2_1 = _mm_setzero_si128(); + for (int ib = 0; ib < QK_K/32; ib += 2) { + const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)qs); qs += 16; + const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)qs); qs += 16; + const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)); + const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)); + const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)); + const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)); + const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0); + const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1); + const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0); + const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1); + const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32; + const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32; + sh >>= 4; + const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, _mm_set1_epi16(ls1)); + const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, _mm_set1_epi16(ls1)); + const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, _mm_set1_epi16(ls2)); + const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, _mm_set1_epi16(ls2)); + sumi1_0 = _mm_add_epi32(p_1_0, sumi1_0); + sumi1_1 = _mm_add_epi32(p_1_1, sumi1_1); + sumi2_0 = _mm_add_epi32(p_2_0, sumi2_0); + sumi2_1 = _mm_add_epi32(p_2_1, sumi2_1); + } + __m128i sumi12_0 = _mm_add_epi32(sumi1_0, sumi2_0); + __m128i sumi12_1 = _mm_add_epi32(sumi1_1, sumi2_1); + accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(LM_GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d), + _mm256_cvtepi32_ps(MM256_SET_M128I(sumi12_1, sumi12_0))), accum); + } + + *s = hsum_float_8(accum); + +#elif defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector int v0 = vec_splats((int32_t)0); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + const vector signed char values = vec_xl( 0, kvalues_iq4nl); + + for (int ibl = 0; ibl < nb; ++ibl) { + + vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[ibl].d)); + vector float vyd = vec_splats(y[ibl].d); + vector float vd = vec_mul(vxd, vyd); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + + uint16_t h = x[ibl].scales_h; + + const uint8_t * restrict q4 = x[ibl].qs; + const uint8_t * restrict sc = x[ibl].scales_l; + const int8_t * restrict q8 = y[ibl].qs; + + for (int ib = 0; ib < QK_K/64; ib ++ ) { + __builtin_prefetch(q4, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed char qxs0 = (vector signed char)vec_xl( 0, q4); + vector signed char qxs1 = (vector signed char)vec_xl(16, q4); + q4 += 32; + + vector signed char q4x00 = (vector signed char)vec_and(qxs0, lowMask); + vector signed char q4x01 = (vector signed char)vec_sr(qxs0, v4); + vector signed char q4x10 = (vector signed char)vec_and(qxs1, lowMask); + vector signed char q4x11 = (vector signed char)vec_sr(qxs1, v4); + + q4x00 = vec_perm(values, values, (vector unsigned char)q4x00); + q4x01 = vec_perm(values, values, (vector unsigned char)q4x01); + q4x10 = vec_perm(values, values, (vector unsigned char)q4x10); + q4x11 = vec_perm(values, values, (vector unsigned char)q4x11); + + vector signed char q8y0 = vec_xl( 0, q8); + vector signed char q8y1 = vec_xl(16, q8); + vector signed char q8y2 = vec_xl(32, q8); + vector signed char q8y3 = vec_xl(48, q8); + q8 += 64; + + vector signed short qv0 = vec_add(vec_mule(q4x00, q8y0), vec_mulo(q4x00, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q4x01, q8y1), vec_mulo(q4x01, q8y1)); + vector signed short qv2 = vec_add(vec_mule(q4x10, q8y2), vec_mulo(q4x10, q8y2)); + vector signed short qv3 = vec_add(vec_mule(q4x11, q8y3), vec_mulo(q4x11, q8y3)); + + const uint16_t ls0 = (uint16_t)(((sc[0] & 0xf) | ((h << 4) & 0x30)) - 32); + const uint16_t ls1 = (uint16_t)(((sc[0] >> 4) | ((h << 2) & 0x30)) - 32); + h >>= 4; + sc ++; + + vector signed short vscales01 = vec_splats((int16_t)ls0); + vector signed short vscales23 = vec_splats((int16_t)ls1); + + vsumi0 = vec_msum(qv0, vscales01, vsumi0); + vsumi1 = vec_msum(qv1, vscales01, vsumi1); + vsumi2 = vec_msum(qv2, vscales23, vsumi2); + vsumi3 = vec_msum(qv3, vscales23, vsumi3); + } + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + +#elif defined(__loongarch_asx) + + const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0); + const __m128i m4b = __lsx_vreplgr2vr_b(0x0f); + + __m256 accum = (__m256)__lasx_xvldi(0); + __m256i tmp1; + __m128i tmp0, tmp2, tmp3, tmp4, mask_8f, mask; + + mask_8f = __lsx_vreplgr2vr_b(0x8f); + for (int ibl = 0; ibl < nb; ++ibl) { + const uint8_t * qs = x[ibl].qs; + const int8_t * q8 = y[ibl].qs; + uint16_t sh = x[ibl].scales_h; + __m256i sumi1 = __lasx_xvldi(0); + __m256i sumi2 = __lasx_xvldi(0); + __m128i zero = __lsx_vldi(0); + for (int ib = 0; ib < QK_K/32; ib += 2) { + const __m128i q4bits_1 = __lsx_vld((const __m128i*)qs, 0); qs += 16; + const __m128i q4bits_2 = __lsx_vld((const __m128i*)qs, 0); qs += 16; + const __m256i q8b_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q8b_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b), mask_8f); + tmp0 = __lsx_vori_b(tmp2, 0x10); + mask = __lsx_vsle_b(zero, tmp2); + tmp3 = __lsx_vand_v(tmp0, mask); + tmp3 = __lsx_vshuf_b(values128, zero, tmp3); + + tmp2 = __lsx_vand_v(__lsx_vand_v(q4bits_1, m4b), mask_8f); + tmp0 = __lsx_vori_b(tmp2, 0x10); + mask = __lsx_vsle_b(zero, tmp2); + tmp4 = __lsx_vand_v(tmp0, mask); + tmp4 = __lsx_vshuf_b(values128, zero, tmp4); + + const __m256i q4b_1 = lasx_insertf128(tmp3, tmp4); + + tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b), mask_8f); + tmp0 = __lsx_vori_b(tmp2, 0x10); + mask = __lsx_vsle_b(zero, tmp2); + tmp3 = __lsx_vand_v(tmp0, mask); + tmp3 = __lsx_vshuf_b(values128, zero, tmp3); + + tmp2 = __lsx_vand_v(__lsx_vand_v(q4bits_2, m4b), mask_8f); + tmp0 = __lsx_vori_b(tmp2, 0x10); + mask = __lsx_vsle_b(zero, tmp2); + tmp4 = __lsx_vand_v(tmp0, mask); + tmp4 = __lsx_vshuf_b(values128, zero, tmp4); + + const __m256i q4b_2 = lasx_insertf128(tmp3, tmp4); + + const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1); + const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); + const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32; + const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32; + sh >>= 4; + __m256i tmp5, tmp6; + tmp1 = __lasx_xvreplgr2vr_h(ls1); + tmp5 = __lasx_xvmulwev_w_h(p16_1, tmp1); + tmp6 = __lasx_xvmulwod_w_h(p16_1, tmp1); + const __m256i p_1 = __lasx_xvadd_w(tmp5, tmp6); + tmp1 = __lasx_xvreplgr2vr_h(ls2); + tmp5 = __lasx_xvmulwev_w_h(p16_2, tmp1); + tmp6 = __lasx_xvmulwod_w_h(p16_2, tmp1); + const __m256i p_2 = __lasx_xvadd_w(tmp5, tmp6); + sumi1 = __lasx_xvadd_w(p_1, sumi1); + sumi2 = __lasx_xvadd_w(p_2, sumi2); + } + accum = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(LM_GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d), + __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accum); + } + + *s = hsum_float_8(accum); + +#else + float sumf = 0; + for (int ibl = 0; ibl < nb; ++ibl) { + const float d4d8 = LM_GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d; + uint16_t h = x[ibl].scales_h; + const uint8_t * qs = x[ibl].qs; + const int8_t * q8 = y[ibl].qs; + for (int ib = 0; ib < QK_K/32; ib += 2) { + const uint8_t ls1 = (x[ibl].scales_l[ib/2] & 0xf) | ((h << 4) & 0x30); + const uint8_t ls2 = (x[ibl].scales_l[ib/2] >> 4) | ((h << 2) & 0x30); + h >>= 4; + const float d1 = d4d8*(ls1 - 32); + const float d2 = d4d8*(ls2 - 32); + int sumi1 = 0, sumi2 = 0; + for (int j = 0; j < 16; ++j) { + sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf]; + sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4]; + } + sumf += d1 * (sumi1 + sumi2); + qs += 16; + q8 += 32; + sumi1 = sumi2 = 0; + for (int j = 0; j < 16; ++j) { + sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf]; + sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4]; + } + sumf += d2 * (sumi1 + sumi2); + qs += 16; + q8 += 32; + } + } + *s = sumf; +#endif +} + +// ============================ 4-bit non-linear quants + +void quantize_row_iq4_nl(const float * restrict x, void * restrict y, int64_t k) { + assert(k % QK4_NL == 0); + quantize_row_iq4_nl_ref(x, y, k); +} + +void quantize_row_iq4_xs(const float * restrict x, void * restrict y, int64_t k) { + assert(k % QK_K == 0); + quantize_iq4_xs(x, y, 1, k, NULL); +} diff --git a/cpp/ggml-cpu-quants.h b/cpp/ggml-cpu-quants.h new file mode 100644 index 00000000..44980bc4 --- /dev/null +++ b/cpp/ggml-cpu-quants.h @@ -0,0 +1,63 @@ +#pragma once + +#define LM_GGML_COMMON_DECL_C +#include "ggml-common.h" + +#include "ggml.h" + +// GGML CPU internal header + +#ifdef __cplusplus +extern "C" { +#endif + +// Quantization +void quantize_row_q4_0(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); +void quantize_row_q4_1(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); +void quantize_row_q5_0(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); +void quantize_row_q5_1(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); +void quantize_row_q8_0(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); +void quantize_row_q8_1(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); + +void quantize_row_q2_K(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); +void quantize_row_q3_K(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); +void quantize_row_q4_K(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); +void quantize_row_q5_K(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); +void quantize_row_q6_K(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); +void quantize_row_q8_K(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); + +void quantize_row_tq1_0(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); +void quantize_row_tq2_0(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); + +void quantize_row_iq4_nl (const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); +void quantize_row_iq4_xs (const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); + +// Dot product +void lm_ggml_vec_dot_q4_0_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); +void lm_ggml_vec_dot_q4_1_q8_1(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); +void lm_ggml_vec_dot_q5_0_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); +void lm_ggml_vec_dot_q5_1_q8_1(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); +void lm_ggml_vec_dot_q8_0_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); + +void lm_ggml_vec_dot_q2_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); +void lm_ggml_vec_dot_q3_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); +void lm_ggml_vec_dot_q4_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); +void lm_ggml_vec_dot_q5_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); +void lm_ggml_vec_dot_q6_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); + +void lm_ggml_vec_dot_tq1_0_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); +void lm_ggml_vec_dot_tq2_0_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); + +void lm_ggml_vec_dot_iq2_xxs_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); +void lm_ggml_vec_dot_iq2_xs_q8_K (int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); +void lm_ggml_vec_dot_iq2_s_q8_K (int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); +void lm_ggml_vec_dot_iq3_xxs_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); +void lm_ggml_vec_dot_iq1_s_q8_K (int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); +void lm_ggml_vec_dot_iq1_m_q8_K (int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); +void lm_ggml_vec_dot_iq4_nl_q8_0 (int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); +void lm_ggml_vec_dot_iq4_xs_q8_K (int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); +void lm_ggml_vec_dot_iq3_s_q8_K (int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); + +#ifdef __cplusplus +} +#endif diff --git a/cpp/ggml-cpu.c b/cpp/ggml-cpu.c new file mode 100644 index 00000000..406fb85a --- /dev/null +++ b/cpp/ggml-cpu.c @@ -0,0 +1,13975 @@ +#define _CRT_SECURE_NO_DEPRECATE // Disables "unsafe" warnings on Windows +#define _USE_MATH_DEFINES // For M_PI on MSVC + +#include "ggml-backend-impl.h" +#include "ggml-backend.h" +#include "ggml-cpu-aarch64.h" +#include "ggml-cpu-impl.h" +#include "ggml-cpu.h" +#include "ggml-impl.h" +#include "ggml-quants.h" +#include "ggml-cpu-quants.h" +#include "ggml-threading.h" +#include "ggml.h" + +#if defined(_MSC_VER) || defined(__MINGW32__) +#include // using malloc.h with MSC/MINGW +#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__) +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if defined(__gnu_linux__) +#include +#endif + +#ifdef LM_GGML_USE_OPENMP +#include +#endif + +#if defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_MATMUL_INT8) +#undef LM_GGML_USE_LLAMAFILE +#endif + +#ifdef LM_GGML_USE_LLAMAFILE +#include "llamafile/sgemm.h" +#endif + +#if defined(_MSC_VER) +// disable "possible loss of data" to avoid hundreds of casts +// we should just be careful :) +#pragma warning(disable: 4244 4267) + +// disable POSIX deprecation warnings +// these functions are never going away, anyway +#pragma warning(disable: 4996) + +// unreachable code because of multiple instances of code after LM_GGML_ABORT +#pragma warning(disable: 4702) +#endif + +// Note: once we move threading into a separate C++ file +// will use std::hardware_destructive_interference_size instead of hardcoding it here +// and we'll use C++ attribute syntax. +#define LM_GGML_CACHE_LINE 64 + +#if defined(__clang__) || defined(__GNUC__) +#define LM_GGML_CACHE_ALIGN __attribute__((aligned(LM_GGML_CACHE_LINE))) +#endif + +#if defined(__has_feature) +#if __has_feature(thread_sanitizer) +#define LM_GGML_TSAN_ENABLED 1 +#endif +#else // __has_feature +#if defined(__SANITIZE_THREAD__) +#define LM_GGML_TSAN_ENABLED 1 +#endif +#endif // __has_feature + +#define UNUSED LM_GGML_UNUSED +#define SWAP(x, y, T) do { T SWAP = x; (x) = y; (y) = SWAP; } while (0) + +#if defined(LM_GGML_USE_ACCELERATE) +#include +#endif + +// floating point type used to accumulate sums +typedef double lm_ggml_float; + +#define LM_GGML_GELU_FP16 +#define LM_GGML_GELU_QUICK_FP16 + +#define LM_GGML_SOFT_MAX_UNROLL 4 +#define LM_GGML_VEC_DOT_UNROLL 2 +#define LM_GGML_VEC_MAD_UNROLL 32 + +// +// global data +// + +// precomputed gelu table for f16 (128 KB) +static lm_ggml_fp16_t lm_ggml_table_gelu_f16[1 << 16]; + +// precomputed quick gelu table for f16 (128 KB) +static lm_ggml_fp16_t lm_ggml_table_gelu_quick_f16[1 << 16]; + +#if defined(__ARM_ARCH) +struct lm_ggml_arm_arch_features_type { + int has_neon; + int has_i8mm; + int has_sve; + int sve_cnt; +} lm_ggml_arm_arch_features = {-1, -1, -1, 0}; +#endif + + +#if defined(_WIN32) + +#define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX + #define NOMINMAX +#endif +#include + + +#if !defined(__clang__) +#define LM_GGML_CACHE_ALIGN __declspec(align(LM_GGML_CACHE_LINE)) + +typedef volatile LONG atomic_int; +typedef atomic_int atomic_bool; +typedef atomic_int atomic_flag; + +#define ATOMIC_FLAG_INIT 0 + +typedef enum { + memory_order_relaxed, + memory_order_consume, + memory_order_acquire, + memory_order_release, + memory_order_acq_rel, + memory_order_seq_cst +} memory_order; + +static void atomic_store(atomic_int * ptr, LONG val) { + InterlockedExchange(ptr, val); +} +static void atomic_store_explicit(atomic_int * ptr, LONG val, memory_order mo) { + // TODO: add support for explicit memory order + InterlockedExchange(ptr, val); +} +static LONG atomic_load(atomic_int * ptr) { + return InterlockedCompareExchange(ptr, 0, 0); +} +static LONG atomic_load_explicit(atomic_int * ptr, memory_order mo) { + // TODO: add support for explicit memory order + return InterlockedCompareExchange(ptr, 0, 0); +} +static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) { + return InterlockedExchangeAdd(ptr, inc); +} +static LONG atomic_fetch_add_explicit(atomic_int * ptr, LONG inc, memory_order mo) { + // TODO: add support for explicit memory order + return InterlockedExchangeAdd(ptr, inc); +} +static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) { + return InterlockedExchange(ptr, 1); +} +static void atomic_flag_clear(atomic_flag * ptr) { + InterlockedExchange(ptr, 0); +} +static void atomic_thread_fence(memory_order mo) { + MemoryBarrier(); +} +#else // clang +#include +#endif + +typedef HANDLE pthread_t; + +typedef DWORD thread_ret_t; +static int pthread_create(pthread_t * out, void * unused, thread_ret_t(*func)(void *), void * arg) { + (void) unused; + HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL); + if (handle == NULL) + { + return EAGAIN; + } + + *out = handle; + return 0; +} + +static int pthread_join(pthread_t thread, void * unused) { + (void) unused; + int ret = (int) WaitForSingleObject(thread, INFINITE); + CloseHandle(thread); + return ret; +} + +static int sched_yield (void) { + Sleep (0); + return 0; +} +#else + +#include +#include +#include +#if defined(__FreeBSD__) +#include +#endif + +typedef void * thread_ret_t; + +#include +#include +#include + +#endif + +typedef pthread_t lm_ggml_thread_t; + +#ifdef LM_GGML_USE_CPU_HBM +#include +#endif + +#if defined(__APPLE__) +#include +#include +#include +#endif + +// +// cache line +// + +#if defined(__cpp_lib_hardware_interference_size) +#define CACHE_LINE_SIZE hardware_destructive_interference_size +#else +#if defined(__POWER9_VECTOR__) +#define CACHE_LINE_SIZE 128 +#else +#define CACHE_LINE_SIZE 64 +#endif +#endif + +static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); + + +static void lm_ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc); +static void lm_ggml_vec_dot_f16(int n, float * restrict s, size_t bs, lm_ggml_fp16_t * restrict x, size_t bx, lm_ggml_fp16_t * restrict y, size_t by, int nrc); +static void lm_ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, lm_ggml_bf16_t * restrict x, size_t bx, lm_ggml_bf16_t * restrict y, size_t by, int nrc); + +static const struct lm_ggml_type_traits_cpu type_traits_cpu[LM_GGML_TYPE_COUNT] = { + [LM_GGML_TYPE_F32] = { + .vec_dot = (lm_ggml_vec_dot_t) lm_ggml_vec_dot_f32, + .vec_dot_type = LM_GGML_TYPE_F32, + .nrows = 1, + }, + [LM_GGML_TYPE_F16] = { + .from_float = (lm_ggml_from_float_t) lm_ggml_fp32_to_fp16_row, + .vec_dot = (lm_ggml_vec_dot_t) lm_ggml_vec_dot_f16, + .vec_dot_type = LM_GGML_TYPE_F16, + .nrows = 1, + }, + [LM_GGML_TYPE_Q4_0] = { + .from_float = quantize_row_q4_0, + .vec_dot = lm_ggml_vec_dot_q4_0_q8_0, + .vec_dot_type = LM_GGML_TYPE_Q8_0, +#if defined (__ARM_FEATURE_MATMUL_INT8) + .nrows = 2, +#else + .nrows = 1, +#endif + }, + [LM_GGML_TYPE_Q4_1] = { + .from_float = quantize_row_q4_1, + .vec_dot = lm_ggml_vec_dot_q4_1_q8_1, + .vec_dot_type = LM_GGML_TYPE_Q8_1, +#if defined (__ARM_FEATURE_MATMUL_INT8) + .nrows = 2, +#else + .nrows = 1, +#endif + }, + [LM_GGML_TYPE_Q5_0] = { + .from_float = quantize_row_q5_0, + .vec_dot = lm_ggml_vec_dot_q5_0_q8_0, + .vec_dot_type = LM_GGML_TYPE_Q8_0, + .nrows = 1, + }, + [LM_GGML_TYPE_Q5_1] = { + .from_float = quantize_row_q5_1, + .vec_dot = lm_ggml_vec_dot_q5_1_q8_1, + .vec_dot_type = LM_GGML_TYPE_Q8_1, + .nrows = 1, + }, + [LM_GGML_TYPE_Q8_0] = { + .from_float = quantize_row_q8_0, + .from_float_to_mat = quantize_mat_q8_0, + .vec_dot = lm_ggml_vec_dot_q8_0_q8_0, + .vec_dot_type = LM_GGML_TYPE_Q8_0, +#if defined (__ARM_FEATURE_MATMUL_INT8) + .nrows = 2, +#else + .nrows = 1, +#endif + }, + [LM_GGML_TYPE_Q8_1] = { + .from_float = quantize_row_q8_1, + .vec_dot_type = LM_GGML_TYPE_Q8_1, + .nrows = 1, + }, + [LM_GGML_TYPE_Q2_K] = { + .from_float = quantize_row_q2_K, + .vec_dot = lm_ggml_vec_dot_q2_K_q8_K, + .vec_dot_type = LM_GGML_TYPE_Q8_K, + .nrows = 1, + }, + [LM_GGML_TYPE_Q3_K] = { + .from_float = quantize_row_q3_K, + .vec_dot = lm_ggml_vec_dot_q3_K_q8_K, + .vec_dot_type = LM_GGML_TYPE_Q8_K, + .nrows = 1, + }, + [LM_GGML_TYPE_Q4_K] = { + .from_float = quantize_row_q4_K, + .vec_dot = lm_ggml_vec_dot_q4_K_q8_K, + .vec_dot_type = LM_GGML_TYPE_Q8_K, + .nrows = 1, + }, + [LM_GGML_TYPE_Q5_K] = { + .from_float = quantize_row_q5_K, + .vec_dot = lm_ggml_vec_dot_q5_K_q8_K, + .vec_dot_type = LM_GGML_TYPE_Q8_K, + .nrows = 1, + }, + [LM_GGML_TYPE_Q6_K] = { + .from_float = quantize_row_q6_K, + .vec_dot = lm_ggml_vec_dot_q6_K_q8_K, + .vec_dot_type = LM_GGML_TYPE_Q8_K, + .nrows = 1, + }, + [LM_GGML_TYPE_IQ2_XXS] = { + .from_float = NULL, + .vec_dot = lm_ggml_vec_dot_iq2_xxs_q8_K, + .vec_dot_type = LM_GGML_TYPE_Q8_K, + .nrows = 1, + }, + [LM_GGML_TYPE_IQ2_XS] = { + .from_float = NULL, + .vec_dot = lm_ggml_vec_dot_iq2_xs_q8_K, + .vec_dot_type = LM_GGML_TYPE_Q8_K, + .nrows = 1, + }, + [LM_GGML_TYPE_IQ3_XXS] = { + // NOTE: from_float for iq3 and iq2_s was removed because these quants require initialization in lm_ggml_quantize_init + //.from_float = quantize_row_iq3_xxs, + .vec_dot = lm_ggml_vec_dot_iq3_xxs_q8_K, + .vec_dot_type = LM_GGML_TYPE_Q8_K, + .nrows = 1, + }, + [LM_GGML_TYPE_IQ3_S] = { + //.from_float = quantize_row_iq3_s, + .vec_dot = lm_ggml_vec_dot_iq3_s_q8_K, + .vec_dot_type = LM_GGML_TYPE_Q8_K, + .nrows = 1, + }, + [LM_GGML_TYPE_IQ2_S] = { + //.from_float = quantize_row_iq2_s, + .vec_dot = lm_ggml_vec_dot_iq2_s_q8_K, + .vec_dot_type = LM_GGML_TYPE_Q8_K, + .nrows = 1, + }, + [LM_GGML_TYPE_IQ1_S] = { + .from_float = NULL, + .vec_dot = lm_ggml_vec_dot_iq1_s_q8_K, + .vec_dot_type = LM_GGML_TYPE_Q8_K, + .nrows = 1, + }, + [LM_GGML_TYPE_IQ1_M] = { + .from_float = NULL, + .vec_dot = lm_ggml_vec_dot_iq1_m_q8_K, + .vec_dot_type = LM_GGML_TYPE_Q8_K, + .nrows = 1, + }, + [LM_GGML_TYPE_IQ4_NL] = { + .from_float = quantize_row_iq4_nl, + .vec_dot = lm_ggml_vec_dot_iq4_nl_q8_0, + .vec_dot_type = LM_GGML_TYPE_Q8_0, + .nrows = 1, + }, + [LM_GGML_TYPE_IQ4_XS] = { + .from_float = quantize_row_iq4_xs, + .vec_dot = lm_ggml_vec_dot_iq4_xs_q8_K, + .vec_dot_type = LM_GGML_TYPE_Q8_K, + .nrows = 1, + }, + [LM_GGML_TYPE_Q8_K] = { + .from_float = quantize_row_q8_K, + }, + [LM_GGML_TYPE_BF16] = { + .from_float = (lm_ggml_from_float_t) lm_ggml_fp32_to_bf16_row, + .vec_dot = (lm_ggml_vec_dot_t) lm_ggml_vec_dot_bf16, + .vec_dot_type = LM_GGML_TYPE_BF16, + .nrows = 1, + }, + [LM_GGML_TYPE_Q4_0_4_4] = { + .from_float = NULL, + .vec_dot = NULL, + .vec_dot_type = LM_GGML_TYPE_Q8_0, + .nrows = 1, + .ncols = 4, + .gemv = lm_ggml_gemv_q4_0_4x4_q8_0, + .gemm = lm_ggml_gemm_q4_0_4x4_q8_0, + }, + [LM_GGML_TYPE_Q4_0_4_8] = { + .from_float = NULL, + .vec_dot = NULL, + .vec_dot_type = LM_GGML_TYPE_Q8_0, + .nrows = 1, + .ncols = 4, + .gemv = lm_ggml_gemv_q4_0_4x8_q8_0, + .gemm = lm_ggml_gemm_q4_0_4x8_q8_0, + }, + [LM_GGML_TYPE_Q4_0_8_8] = { + .from_float = NULL, + .vec_dot = NULL, + .vec_dot_type = LM_GGML_TYPE_Q8_0, + .nrows = 1, + .ncols = 8, + .gemv = lm_ggml_gemv_q4_0_8x8_q8_0, + .gemm = lm_ggml_gemm_q4_0_8x8_q8_0, + }, + [LM_GGML_TYPE_TQ1_0] = { + .from_float = quantize_row_tq1_0, + .vec_dot = lm_ggml_vec_dot_tq1_0_q8_K, + .vec_dot_type = LM_GGML_TYPE_Q8_K, + .nrows = 1, + }, + [LM_GGML_TYPE_TQ2_0] = { + .from_float = quantize_row_tq2_0, + .vec_dot = lm_ggml_vec_dot_tq2_0_q8_K, + .vec_dot_type = LM_GGML_TYPE_Q8_K, + .nrows = 1, + }, +}; + +const struct lm_ggml_type_traits_cpu * lm_ggml_get_type_traits_cpu(enum lm_ggml_type type) { + return &type_traits_cpu[type]; +} + +// +// simd mappings +// + +// we define a common set of C macros which map to specific intrinsics based on the current architecture +// we then implement the fundamental computation operations below using only these macros +// adding support for new architectures requires to define the corresponding SIMD macros +// +// LM_GGML_F32_STEP / LM_GGML_F16_STEP +// number of elements to process in a single step +// +// LM_GGML_F32_EPR / LM_GGML_F16_EPR +// number of elements to fit in a single register +// + +#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA) + +#define LM_GGML_SIMD + +// F32 NEON + +#define LM_GGML_F32_STEP 16 +#define LM_GGML_F32_EPR 4 + +#define LM_GGML_F32x4 float32x4_t +#define LM_GGML_F32x4_ZERO vdupq_n_f32(0.0f) +#define LM_GGML_F32x4_SET1(x) vdupq_n_f32(x) +#define LM_GGML_F32x4_LOAD vld1q_f32 +#define LM_GGML_F32x4_STORE vst1q_f32 +#define LM_GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c) +#define LM_GGML_F32x4_ADD vaddq_f32 +#define LM_GGML_F32x4_MUL vmulq_f32 +#define LM_GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x) +#define LM_GGML_F32x4_REDUCE(res, x) \ +{ \ + int offset = LM_GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \ + } \ + (res) = LM_GGML_F32x4_REDUCE_ONE((x)[0]); \ +} + +#define LM_GGML_F32_VEC LM_GGML_F32x4 +#define LM_GGML_F32_VEC_ZERO LM_GGML_F32x4_ZERO +#define LM_GGML_F32_VEC_SET1 LM_GGML_F32x4_SET1 +#define LM_GGML_F32_VEC_LOAD LM_GGML_F32x4_LOAD +#define LM_GGML_F32_VEC_STORE LM_GGML_F32x4_STORE +#define LM_GGML_F32_VEC_FMA LM_GGML_F32x4_FMA +#define LM_GGML_F32_VEC_ADD LM_GGML_F32x4_ADD +#define LM_GGML_F32_VEC_MUL LM_GGML_F32x4_MUL +#define LM_GGML_F32_VEC_REDUCE LM_GGML_F32x4_REDUCE + +// F16 NEON + +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) + #define LM_GGML_F16_STEP 32 + #define LM_GGML_F16_EPR 8 + + #define LM_GGML_F16x8 float16x8_t + #define LM_GGML_F16x8_ZERO vdupq_n_f16(0.0f) + #define LM_GGML_F16x8_SET1(x) vdupq_n_f16(x) + #define LM_GGML_F16x8_LOAD(x) vld1q_f16((const lm_ggml_fp16_internal_t *)(x)) + #define LM_GGML_F16x8_STORE vst1q_f16 + #define LM_GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c) + #define LM_GGML_F16x8_ADD vaddq_f16 + #define LM_GGML_F16x8_MUL vmulq_f16 + #define LM_GGML_F16x8_REDUCE(res, x) \ + do { \ + int offset = LM_GGML_F16_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + (x)[i] = vaddq_f16((x)[i], (x)[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + (x)[i] = vaddq_f16((x)[i], (x)[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + (x)[i] = vaddq_f16((x)[i], (x)[offset+i]); \ + } \ + const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 ((x)[0])); \ + const float32x4_t t1 = vcvt_f32_f16(vget_high_f16((x)[0])); \ + (res) = (lm_ggml_float) vaddvq_f32(vaddq_f32(t0, t1)); \ + } while (0) + + #define LM_GGML_F16_VEC LM_GGML_F16x8 + #define LM_GGML_F16_VEC_ZERO LM_GGML_F16x8_ZERO + #define LM_GGML_F16_VEC_SET1 LM_GGML_F16x8_SET1 + #define LM_GGML_F16_VEC_LOAD(p, i) LM_GGML_F16x8_LOAD(p) + #define LM_GGML_F16_VEC_STORE(p, r, i) LM_GGML_F16x8_STORE((lm_ggml_fp16_internal_t *)(p), (r)[i]) + #define LM_GGML_F16_VEC_FMA LM_GGML_F16x8_FMA + #define LM_GGML_F16_VEC_ADD LM_GGML_F16x8_ADD + #define LM_GGML_F16_VEC_MUL LM_GGML_F16x8_MUL + #define LM_GGML_F16_VEC_REDUCE LM_GGML_F16x8_REDUCE +#else + // if FP16 vector arithmetic is not supported, we use FP32 instead + // and take advantage of the vcvt_ functions to convert to/from FP16 + + #define LM_GGML_F16_STEP 16 + #define LM_GGML_F16_EPR 4 + + #define LM_GGML_F32Cx4 float32x4_t + #define LM_GGML_F32Cx4_ZERO vdupq_n_f32(0.0f) + #define LM_GGML_F32Cx4_SET1(x) vdupq_n_f32(x) + #define LM_GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16((const lm_ggml_fp16_internal_t *)(x))) + #define LM_GGML_F32Cx4_STORE(x, y) vst1_f16(x, vcvt_f16_f32(y)) + #define LM_GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c) + #define LM_GGML_F32Cx4_ADD vaddq_f32 + #define LM_GGML_F32Cx4_MUL vmulq_f32 + #define LM_GGML_F32Cx4_REDUCE LM_GGML_F32x4_REDUCE + + #define LM_GGML_F16_VEC LM_GGML_F32Cx4 + #define LM_GGML_F16_VEC_ZERO LM_GGML_F32Cx4_ZERO + #define LM_GGML_F16_VEC_SET1 LM_GGML_F32Cx4_SET1 + #define LM_GGML_F16_VEC_LOAD(p, i) LM_GGML_F32Cx4_LOAD(p) + #define LM_GGML_F16_VEC_STORE(p, r, i) LM_GGML_F32Cx4_STORE((lm_ggml_fp16_internal_t *)(p), r[i]) + #define LM_GGML_F16_VEC_FMA LM_GGML_F32Cx4_FMA + #define LM_GGML_F16_VEC_ADD LM_GGML_F32Cx4_ADD + #define LM_GGML_F16_VEC_MUL LM_GGML_F32Cx4_MUL + #define LM_GGML_F16_VEC_REDUCE LM_GGML_F32Cx4_REDUCE +#endif + +#elif defined(__AVX512F__) + +#define LM_GGML_SIMD + +// F32 AVX512 + +#define LM_GGML_F32_STEP 64 +#define LM_GGML_F32_EPR 16 + +#define LM_GGML_F32x16 __m512 +#define LM_GGML_F32x16_ZERO _mm512_setzero_ps() +#define LM_GGML_F32x16_SET1(x) _mm512_set1_ps(x) +#define LM_GGML_F32x16_LOAD _mm512_loadu_ps +#define LM_GGML_F32x16_STORE _mm512_storeu_ps +// _mm512_fmadd_ps is defined in AVX512F so no guard is required +#define LM_GGML_F32x16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a) +#define LM_GGML_F32x16_ADD _mm512_add_ps +#define LM_GGML_F32x16_MUL _mm512_mul_ps +#define LM_GGML_F32x16_REDUCE(res, x) \ +do { \ + int offset = LM_GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm512_add_ps(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm512_add_ps(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm512_add_ps(x[i], x[offset+i]); \ + } \ + res = _mm512_reduce_add_ps(x[0]); \ +} while (0) + +// TODO: is this optimal ? + +#define LM_GGML_F32_VEC LM_GGML_F32x16 +#define LM_GGML_F32_VEC_ZERO LM_GGML_F32x16_ZERO +#define LM_GGML_F32_VEC_SET1 LM_GGML_F32x16_SET1 +#define LM_GGML_F32_VEC_LOAD LM_GGML_F32x16_LOAD +#define LM_GGML_F32_VEC_STORE LM_GGML_F32x16_STORE +#define LM_GGML_F32_VEC_FMA LM_GGML_F32x16_FMA +#define LM_GGML_F32_VEC_ADD LM_GGML_F32x16_ADD +#define LM_GGML_F32_VEC_MUL LM_GGML_F32x16_MUL +#define LM_GGML_F32_VEC_REDUCE LM_GGML_F32x16_REDUCE + +// F16 AVX512 + +// F16 AVX + +#define LM_GGML_F16_STEP 64 +#define LM_GGML_F16_EPR 16 + +// AVX512 has FP16 extension (AVX512_FP16) but I don't have it on my machine so I use FP32 instead + +#define LM_GGML_F32Cx16 __m512 +#define LM_GGML_F32Cx16_ZERO _mm512_setzero_ps() +#define LM_GGML_F32Cx16_SET1(x) _mm512_set1_ps(x) + +// unlike _mm256_cvt intrinsics that require F16C, _mm512_cvt is defined in AVX512F +// so F16C guard isn't required +#define LM_GGML_F32Cx16_LOAD(x) _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x))) +#define LM_GGML_F32Cx16_STORE(x, y) _mm256_storeu_si256((__m256i *)(x), _mm512_cvtps_ph(y, 0)) + +#define LM_GGML_F32Cx16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a) +#define LM_GGML_F32Cx16_ADD _mm512_add_ps +#define LM_GGML_F32Cx16_MUL _mm512_mul_ps +#define LM_GGML_F32Cx16_REDUCE(res, x) \ +do { \ + int offset = LM_GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm512_add_ps(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm512_add_ps(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm512_add_ps(x[i], x[offset+i]); \ + } \ + res = _mm512_reduce_add_ps(x[0]); \ +} while (0) + +#define LM_GGML_F16_VEC LM_GGML_F32Cx16 +#define LM_GGML_F16_VEC_ZERO LM_GGML_F32Cx16_ZERO +#define LM_GGML_F16_VEC_SET1 LM_GGML_F32Cx16_SET1 +#define LM_GGML_F16_VEC_LOAD(p, i) LM_GGML_F32Cx16_LOAD(p) +#define LM_GGML_F16_VEC_STORE(p, r, i) LM_GGML_F32Cx16_STORE(p, r[i]) +#define LM_GGML_F16_VEC_FMA LM_GGML_F32Cx16_FMA +#define LM_GGML_F16_VEC_ADD LM_GGML_F32Cx16_ADD +#define LM_GGML_F16_VEC_MUL LM_GGML_F32Cx16_MUL +#define LM_GGML_F16_VEC_REDUCE LM_GGML_F32Cx16_REDUCE + +#elif defined(__AVX__) + +#define LM_GGML_SIMD + +// F32 AVX + +#define LM_GGML_F32_STEP 32 +#define LM_GGML_F32_EPR 8 + +#define LM_GGML_F32x8 __m256 +#define LM_GGML_F32x8_ZERO _mm256_setzero_ps() +#define LM_GGML_F32x8_SET1(x) _mm256_set1_ps(x) +#define LM_GGML_F32x8_LOAD _mm256_loadu_ps +#define LM_GGML_F32x8_STORE _mm256_storeu_ps +#if defined(__FMA__) + #define LM_GGML_F32x8_FMA(a, b, c) _mm256_fmadd_ps(b, c, a) +#else + #define LM_GGML_F32x8_FMA(a, b, c) _mm256_add_ps(_mm256_mul_ps(b, c), a) +#endif +#define LM_GGML_F32x8_ADD _mm256_add_ps +#define LM_GGML_F32x8_MUL _mm256_mul_ps +#define LM_GGML_F32x8_REDUCE(res, x) \ +do { \ + int offset = LM_GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm256_add_ps(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm256_add_ps(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm256_add_ps(x[i], x[offset+i]); \ + } \ + const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), \ + _mm256_extractf128_ps(x[0], 1)); \ + const __m128 t1 = _mm_hadd_ps(t0, t0); \ + res = (lm_ggml_float) _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); \ +} while (0) +// TODO: is this optimal ? + +#define LM_GGML_F32_VEC LM_GGML_F32x8 +#define LM_GGML_F32_VEC_ZERO LM_GGML_F32x8_ZERO +#define LM_GGML_F32_VEC_SET1 LM_GGML_F32x8_SET1 +#define LM_GGML_F32_VEC_LOAD LM_GGML_F32x8_LOAD +#define LM_GGML_F32_VEC_STORE LM_GGML_F32x8_STORE +#define LM_GGML_F32_VEC_FMA LM_GGML_F32x8_FMA +#define LM_GGML_F32_VEC_ADD LM_GGML_F32x8_ADD +#define LM_GGML_F32_VEC_MUL LM_GGML_F32x8_MUL +#define LM_GGML_F32_VEC_REDUCE LM_GGML_F32x8_REDUCE + +// F16 AVX + +#define LM_GGML_F16_STEP 32 +#define LM_GGML_F16_EPR 8 + +// F16 arithmetic is not supported by AVX, so we use F32 instead + +#define LM_GGML_F32Cx8 __m256 +#define LM_GGML_F32Cx8_ZERO _mm256_setzero_ps() +#define LM_GGML_F32Cx8_SET1(x) _mm256_set1_ps(x) + +#if defined(__F16C__) +// the _mm256_cvt intrinsics require F16C +#define LM_GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x))) +#define LM_GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0)) +#else +static inline __m256 __avx_f32cx8_load(lm_ggml_fp16_t *x) { + float tmp[8]; + + for (int i = 0; i < 8; i++) { + tmp[i] = LM_GGML_FP16_TO_FP32(x[i]); + } + + return _mm256_loadu_ps(tmp); +} +static inline void __avx_f32cx8_store(lm_ggml_fp16_t *x, __m256 y) { + float arr[8]; + + _mm256_storeu_ps(arr, y); + + for (int i = 0; i < 8; i++) + x[i] = LM_GGML_FP32_TO_FP16(arr[i]); +} +#define LM_GGML_F32Cx8_LOAD(x) __avx_f32cx8_load(x) +#define LM_GGML_F32Cx8_STORE(x, y) __avx_f32cx8_store(x, y) +#endif + +#define LM_GGML_F32Cx8_FMA LM_GGML_F32x8_FMA +#define LM_GGML_F32Cx8_ADD _mm256_add_ps +#define LM_GGML_F32Cx8_MUL _mm256_mul_ps +#define LM_GGML_F32Cx8_REDUCE LM_GGML_F32x8_REDUCE + +#define LM_GGML_F16_VEC LM_GGML_F32Cx8 +#define LM_GGML_F16_VEC_ZERO LM_GGML_F32Cx8_ZERO +#define LM_GGML_F16_VEC_SET1 LM_GGML_F32Cx8_SET1 +#define LM_GGML_F16_VEC_LOAD(p, i) LM_GGML_F32Cx8_LOAD(p) +#define LM_GGML_F16_VEC_STORE(p, r, i) LM_GGML_F32Cx8_STORE(p, r[i]) +#define LM_GGML_F16_VEC_FMA LM_GGML_F32Cx8_FMA +#define LM_GGML_F16_VEC_ADD LM_GGML_F32Cx8_ADD +#define LM_GGML_F16_VEC_MUL LM_GGML_F32Cx8_MUL +#define LM_GGML_F16_VEC_REDUCE LM_GGML_F32Cx8_REDUCE + +#elif defined(__POWER9_VECTOR__) + +#define LM_GGML_SIMD + +// F32 POWER9 + +#define LM_GGML_F32_STEP 32 +#define LM_GGML_F32_EPR 4 + +#define LM_GGML_F32x4 vector float +#define LM_GGML_F32x4_ZERO 0.0f +#define LM_GGML_F32x4_SET1 vec_splats +#define LM_GGML_F32x4_LOAD(p) vec_xl(0, p) +#define LM_GGML_F32x4_STORE(p, r) vec_xst(r, 0, p) +#define LM_GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a) +#define LM_GGML_F32x4_ADD vec_add +#define LM_GGML_F32x4_MUL vec_mul +#define LM_GGML_F32x4_REDUCE(res, x) \ +{ \ + int offset = LM_GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vec_add(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vec_add(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vec_add(x[i], x[offset+i]); \ + } \ + res = vec_extract(x[0], 0) + \ + vec_extract(x[0], 1) + \ + vec_extract(x[0], 2) + \ + vec_extract(x[0], 3); \ +} + +#define LM_GGML_F32_VEC LM_GGML_F32x4 +#define LM_GGML_F32_VEC_ZERO LM_GGML_F32x4_ZERO +#define LM_GGML_F32_VEC_SET1 LM_GGML_F32x4_SET1 +#define LM_GGML_F32_VEC_LOAD LM_GGML_F32x4_LOAD +#define LM_GGML_F32_VEC_STORE LM_GGML_F32x4_STORE +#define LM_GGML_F32_VEC_FMA LM_GGML_F32x4_FMA +#define LM_GGML_F32_VEC_ADD LM_GGML_F32x4_ADD +#define LM_GGML_F32_VEC_MUL LM_GGML_F32x4_MUL +#define LM_GGML_F32_VEC_REDUCE LM_GGML_F32x4_REDUCE + +// F16 POWER9 +#define LM_GGML_F16_STEP LM_GGML_F32_STEP +#define LM_GGML_F16_EPR LM_GGML_F32_EPR +#define LM_GGML_F16_VEC LM_GGML_F32x4 +#define LM_GGML_F16_VEC_ZERO LM_GGML_F32x4_ZERO +#define LM_GGML_F16_VEC_SET1 LM_GGML_F32x4_SET1 +#define LM_GGML_F16_VEC_FMA LM_GGML_F32x4_FMA +#define LM_GGML_F16_VEC_ADD LM_GGML_F32x4_ADD +#define LM_GGML_F16_VEC_MUL LM_GGML_F32x4_MUL +#define LM_GGML_F16_VEC_REDUCE LM_GGML_F32x4_REDUCE +// Use vec_xl, not vec_ld, in case the load address is not aligned. +#define LM_GGML_F16_VEC_LOAD(p, i) (i & 0x1) ? \ + vec_extract_fp32_from_shorth(vec_xl(0, p - LM_GGML_F16_EPR)) : \ + vec_extract_fp32_from_shortl(vec_xl(0, p)) +#define LM_GGML_ENDIAN_BYTE(i) ((unsigned char *)&(uint16_t){1})[i] +#define LM_GGML_F16_VEC_STORE(p, r, i) \ + if (i & 0x1) \ + vec_xst(vec_pack_to_short_fp32(r[i - LM_GGML_ENDIAN_BYTE(1)], \ + r[i - LM_GGML_ENDIAN_BYTE(0)]), \ + 0, p - LM_GGML_F16_EPR) + +#elif defined(__wasm_simd128__) + +#define LM_GGML_SIMD + +// F32 WASM + +#define LM_GGML_F32_STEP 16 +#define LM_GGML_F32_EPR 4 + +#define LM_GGML_F32x4 v128_t +#define LM_GGML_F32x4_ZERO wasm_f32x4_splat(0.0f) +#define LM_GGML_F32x4_SET1(x) wasm_f32x4_splat(x) +#define LM_GGML_F32x4_LOAD wasm_v128_load +#define LM_GGML_F32x4_STORE wasm_v128_store +#define LM_GGML_F32x4_FMA(a, b, c) wasm_f32x4_add(wasm_f32x4_mul(b, c), a) +#define LM_GGML_F32x4_ADD wasm_f32x4_add +#define LM_GGML_F32x4_MUL wasm_f32x4_mul +#define LM_GGML_F32x4_REDUCE(res, x) \ +{ \ + int offset = LM_GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ + } \ + res = wasm_f32x4_extract_lane(x[0], 0) + \ + wasm_f32x4_extract_lane(x[0], 1) + \ + wasm_f32x4_extract_lane(x[0], 2) + \ + wasm_f32x4_extract_lane(x[0], 3); \ +} + +#define LM_GGML_F32_VEC LM_GGML_F32x4 +#define LM_GGML_F32_VEC_ZERO LM_GGML_F32x4_ZERO +#define LM_GGML_F32_VEC_SET1 LM_GGML_F32x4_SET1 +#define LM_GGML_F32_VEC_LOAD LM_GGML_F32x4_LOAD +#define LM_GGML_F32_VEC_STORE LM_GGML_F32x4_STORE +#define LM_GGML_F32_VEC_FMA LM_GGML_F32x4_FMA +#define LM_GGML_F32_VEC_ADD LM_GGML_F32x4_ADD +#define LM_GGML_F32_VEC_MUL LM_GGML_F32x4_MUL +#define LM_GGML_F32_VEC_REDUCE LM_GGML_F32x4_REDUCE + +// F16 WASM + +#define LM_GGML_F16_STEP 16 +#define LM_GGML_F16_EPR 4 + +inline static v128_t __wasm_f16x4_load(const lm_ggml_fp16_t * p) { + float tmp[4]; + + tmp[0] = LM_GGML_FP16_TO_FP32(p[0]); + tmp[1] = LM_GGML_FP16_TO_FP32(p[1]); + tmp[2] = LM_GGML_FP16_TO_FP32(p[2]); + tmp[3] = LM_GGML_FP16_TO_FP32(p[3]); + + return wasm_v128_load(tmp); +} + +inline static void __wasm_f16x4_store(lm_ggml_fp16_t * p, v128_t x) { + float tmp[4]; + + wasm_v128_store(tmp, x); + + p[0] = LM_GGML_FP32_TO_FP16(tmp[0]); + p[1] = LM_GGML_FP32_TO_FP16(tmp[1]); + p[2] = LM_GGML_FP32_TO_FP16(tmp[2]); + p[3] = LM_GGML_FP32_TO_FP16(tmp[3]); +} + +#define LM_GGML_F16x4 v128_t +#define LM_GGML_F16x4_ZERO wasm_f32x4_splat(0.0f) +#define LM_GGML_F16x4_SET1(x) wasm_f32x4_splat(x) +#define LM_GGML_F16x4_LOAD(x) __wasm_f16x4_load(x) +#define LM_GGML_F16x4_STORE(x, y) __wasm_f16x4_store(x, y) +#define LM_GGML_F16x4_FMA LM_GGML_F32x4_FMA +#define LM_GGML_F16x4_ADD wasm_f32x4_add +#define LM_GGML_F16x4_MUL wasm_f32x4_mul +#define LM_GGML_F16x4_REDUCE(res, x) \ +{ \ + int offset = LM_GGML_F16_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ + } \ + res = wasm_f32x4_extract_lane(x[0], 0) + \ + wasm_f32x4_extract_lane(x[0], 1) + \ + wasm_f32x4_extract_lane(x[0], 2) + \ + wasm_f32x4_extract_lane(x[0], 3); \ +} + +#define LM_GGML_F16_VEC LM_GGML_F16x4 +#define LM_GGML_F16_VEC_ZERO LM_GGML_F16x4_ZERO +#define LM_GGML_F16_VEC_SET1 LM_GGML_F16x4_SET1 +#define LM_GGML_F16_VEC_LOAD(p, i) LM_GGML_F16x4_LOAD(p) +#define LM_GGML_F16_VEC_STORE(p, r, i) LM_GGML_F16x4_STORE(p, r[i]) +#define LM_GGML_F16_VEC_FMA LM_GGML_F16x4_FMA +#define LM_GGML_F16_VEC_ADD LM_GGML_F16x4_ADD +#define LM_GGML_F16_VEC_MUL LM_GGML_F16x4_MUL +#define LM_GGML_F16_VEC_REDUCE LM_GGML_F16x4_REDUCE + +#elif defined(__SSE3__) + +#define LM_GGML_SIMD + +// F32 SSE + +#define LM_GGML_F32_STEP 32 +#define LM_GGML_F32_EPR 4 + +#define LM_GGML_F32x4 __m128 +#define LM_GGML_F32x4_ZERO _mm_setzero_ps() +#define LM_GGML_F32x4_SET1(x) _mm_set1_ps(x) +#define LM_GGML_F32x4_LOAD _mm_loadu_ps +#define LM_GGML_F32x4_STORE _mm_storeu_ps +#if defined(__FMA__) + // TODO: Does this work? + #define LM_GGML_F32x4_FMA(a, b, c) _mm_fmadd_ps(b, c, a) +#else + #define LM_GGML_F32x4_FMA(a, b, c) _mm_add_ps(_mm_mul_ps(b, c), a) +#endif +#define LM_GGML_F32x4_ADD _mm_add_ps +#define LM_GGML_F32x4_MUL _mm_mul_ps +#define LM_GGML_F32x4_REDUCE(res, x) \ +{ \ + int offset = LM_GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm_add_ps(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm_add_ps(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm_add_ps(x[i], x[offset+i]); \ + } \ + const __m128 t0 = _mm_hadd_ps(x[0], x[0]); \ + res = (lm_ggml_float) _mm_cvtss_f32(_mm_hadd_ps(t0, t0)); \ +} +// TODO: is this optimal ? + +#define LM_GGML_F32_VEC LM_GGML_F32x4 +#define LM_GGML_F32_VEC_ZERO LM_GGML_F32x4_ZERO +#define LM_GGML_F32_VEC_SET1 LM_GGML_F32x4_SET1 +#define LM_GGML_F32_VEC_LOAD LM_GGML_F32x4_LOAD +#define LM_GGML_F32_VEC_STORE LM_GGML_F32x4_STORE +#define LM_GGML_F32_VEC_FMA LM_GGML_F32x4_FMA +#define LM_GGML_F32_VEC_ADD LM_GGML_F32x4_ADD +#define LM_GGML_F32_VEC_MUL LM_GGML_F32x4_MUL +#define LM_GGML_F32_VEC_REDUCE LM_GGML_F32x4_REDUCE + +// F16 SSE + +#define LM_GGML_F16_STEP 32 +#define LM_GGML_F16_EPR 4 + +static inline __m128 __sse_f16x4_load(lm_ggml_fp16_t *x) { + float tmp[4]; + + tmp[0] = LM_GGML_FP16_TO_FP32(x[0]); + tmp[1] = LM_GGML_FP16_TO_FP32(x[1]); + tmp[2] = LM_GGML_FP16_TO_FP32(x[2]); + tmp[3] = LM_GGML_FP16_TO_FP32(x[3]); + + return _mm_loadu_ps(tmp); +} + +static inline void __sse_f16x4_store(lm_ggml_fp16_t *x, __m128 y) { + float arr[4]; + + _mm_storeu_ps(arr, y); + + x[0] = LM_GGML_FP32_TO_FP16(arr[0]); + x[1] = LM_GGML_FP32_TO_FP16(arr[1]); + x[2] = LM_GGML_FP32_TO_FP16(arr[2]); + x[3] = LM_GGML_FP32_TO_FP16(arr[3]); +} + +#define LM_GGML_F32Cx4 __m128 +#define LM_GGML_F32Cx4_ZERO _mm_setzero_ps() +#define LM_GGML_F32Cx4_SET1(x) _mm_set1_ps(x) +#define LM_GGML_F32Cx4_LOAD(x) __sse_f16x4_load(x) +#define LM_GGML_F32Cx4_STORE(x, y) __sse_f16x4_store(x, y) +#define LM_GGML_F32Cx4_FMA LM_GGML_F32x4_FMA +#define LM_GGML_F32Cx4_ADD _mm_add_ps +#define LM_GGML_F32Cx4_MUL _mm_mul_ps +#define LM_GGML_F32Cx4_REDUCE LM_GGML_F32x4_REDUCE + +#define LM_GGML_F16_VEC LM_GGML_F32Cx4 +#define LM_GGML_F16_VEC_ZERO LM_GGML_F32Cx4_ZERO +#define LM_GGML_F16_VEC_SET1 LM_GGML_F32Cx4_SET1 +#define LM_GGML_F16_VEC_LOAD(p, i) LM_GGML_F32Cx4_LOAD(p) +#define LM_GGML_F16_VEC_STORE(p, r, i) LM_GGML_F32Cx4_STORE(p, r[i]) +#define LM_GGML_F16_VEC_FMA LM_GGML_F32Cx4_FMA +#define LM_GGML_F16_VEC_ADD LM_GGML_F32Cx4_ADD +#define LM_GGML_F16_VEC_MUL LM_GGML_F32Cx4_MUL +#define LM_GGML_F16_VEC_REDUCE LM_GGML_F32Cx4_REDUCE + +#elif defined(__loongarch_asx) + +#define LM_GGML_SIMD + +// F32 LASX +#define LM_GGML_F32_STEP 32 +#define LM_GGML_F32_EPR 8 + +#define LM_GGML_F32x8 __m256 +#define LM_GGML_F32x8_ZERO (__m256)__lasx_xvldi(0) +#define LM_GGML_F32x8_SET1(x) (__m256)__lasx_xvreplfr2vr_s((x)) +#define LM_GGML_F32x8_LOAD(x) (__m256)__lasx_xvld((x), 0) +#define LM_GGML_F32x8_STORE(x,y) __lasx_xvst((y), (x), 0) +#define LM_GGML_F32x8_FMA(a, b, c) __lasx_xvfmadd_s(b, c, a) +#define LM_GGML_F32x8_ADD __lasx_xvfadd_s +#define LM_GGML_F32x8_MUL __lasx_xvfmul_s +#define LM_GGML_F32x8_REDUCE(res, x) \ +do { \ + int offset = LM_GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \ + } \ + float *tmp_p = (float *)&x[0]; \ + res = tmp_p[0] + tmp_p[1] + tmp_p[2] + tmp_p[3] + tmp_p[4] + tmp_p[5] + tmp_p[6] + tmp_p[7]; \ +} while (0) +// TODO: is this optimal ? + +#define LM_GGML_F32_VEC LM_GGML_F32x8 +#define LM_GGML_F32_VEC_ZERO LM_GGML_F32x8_ZERO +#define LM_GGML_F32_VEC_SET1 LM_GGML_F32x8_SET1 +#define LM_GGML_F32_VEC_LOAD LM_GGML_F32x8_LOAD +#define LM_GGML_F32_VEC_STORE LM_GGML_F32x8_STORE +#define LM_GGML_F32_VEC_FMA LM_GGML_F32x8_FMA +#define LM_GGML_F32_VEC_ADD LM_GGML_F32x8_ADD +#define LM_GGML_F32_VEC_MUL LM_GGML_F32x8_MUL +#define LM_GGML_F32_VEC_REDUCE LM_GGML_F32x8_REDUCE + +// F16 LASX + +#define LM_GGML_F16_STEP 32 +#define LM_GGML_F16_EPR 8 + +// F16 arithmetic is not supported by AVX, so we use F32 instead + +#define LM_GGML_F32Cx8 __m256 +#define LM_GGML_F32Cx8_ZERO (__m256)__lasx_xvldi(0) +#define LM_GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x)) + +static inline __m256 __lasx_f32cx8_load(const lm_ggml_fp16_t * x) { + float tmp[8]; + + for (int i = 0; i < 8; i++) { + tmp[i] = LM_GGML_FP16_TO_FP32(x[i]); + } + + return (__m256)__lasx_xvld(tmp, 0); +} +static inline void __lasx_f32cx8_store(lm_ggml_fp16_t * x, __m256 y) { + float arr[8]; + + __lasx_xvst(y, arr, 0); + + for (int i = 0; i < 8; i++) { + x[i] = LM_GGML_FP32_TO_FP16(arr[i]); + } +} +#define LM_GGML_F32Cx8_LOAD(x) __lasx_f32cx8_load(x) +#define LM_GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y) + +#define LM_GGML_F32Cx8_FMA LM_GGML_F32x8_FMA +#define LM_GGML_F32Cx8_ADD __lasx_xvfadd_s +#define LM_GGML_F32Cx8_MUL __lasx_xvfmul_s +#define LM_GGML_F32Cx8_REDUCE LM_GGML_F32x8_REDUCE + +#define LM_GGML_F16_VEC LM_GGML_F32Cx8 +#define LM_GGML_F16_VEC_ZERO LM_GGML_F32Cx8_ZERO +#define LM_GGML_F16_VEC_SET1 LM_GGML_F32Cx8_SET1 +#define LM_GGML_F16_VEC_LOAD(p, i) LM_GGML_F32Cx8_LOAD(p) +#define LM_GGML_F16_VEC_STORE(p, r, i) LM_GGML_F32Cx8_STORE(p, r[i]) +#define LM_GGML_F16_VEC_FMA LM_GGML_F32Cx8_FMA +#define LM_GGML_F16_VEC_ADD LM_GGML_F32Cx8_ADD +#define LM_GGML_F16_VEC_MUL LM_GGML_F32Cx8_MUL +#define LM_GGML_F16_VEC_REDUCE LM_GGML_F32Cx8_REDUCE + +#elif defined(__loongarch_sx) + +#define LM_GGML_SIMD + +// F32 LSX + +#define LM_GGML_F32_STEP 32 +#define LM_GGML_F32_EPR 4 + +#define LM_GGML_F32x4 __m128 +#define LM_GGML_F32x4_ZERO __lsx_vldi(0) +#define LM_GGML_F32x4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0) +#define LM_GGML_F32x4_LOAD(x) __lsx_vld((x), 0) +#define LM_GGML_F32x4_STORE((x),(y)) __lsx_vst((y), (x), 0) +#define LM_GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a) +#define LM_GGML_F32x4_ADD __lsx_vfadd_s +#define LM_GGML_F32x4_MUL __lsx_vfmul_s +#define LM_GGML_F32x4_REDUCE(res, x) \ +{ \ + int offset = LM_GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \ + } \ + __m128i tmp = __lsx_vsrli_d((__m128i)x[0], 32); \ + tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, x[0]); \ + tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \ + const __m128 t0 = __lsx_vshuf4i_w(tmp, 0x88); \ + tmp = __lsx_vsrli_d((__m128i)t0, 32); \ + tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, t0); \ + tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \ + res = (lm_ggml_float) __lsx_vpickve2gr_w(__lsx_vshuf4i_w(tmp, 0x88), 0); \ +} + +#define LM_GGML_F32_VEC LM_GGML_F32x4 +#define LM_GGML_F32_VEC_ZERO LM_GGML_F32x4_ZERO +#define LM_GGML_F32_VEC_SET1 LM_GGML_F32x4_SET1 +#define LM_GGML_F32_VEC_LOAD LM_GGML_F32x4_LOAD +#define LM_GGML_F32_VEC_STORE LM_GGML_F32x4_STORE +#define LM_GGML_F32_VEC_FMA LM_GGML_F32x4_FMA +#define LM_GGML_F32_VEC_ADD LM_GGML_F32x4_ADD +#define LM_GGML_F32_VEC_MUL LM_GGML_F32x4_MUL +#define LM_GGML_F32_VEC_REDUCE LM_GGML_F32x4_REDUCE + +// F16 LSX + +#define LM_GGML_F16_STEP 32 +#define LM_GGML_F16_EPR 4 + +static inline __m128 __lsx_f16x4_load(const lm_ggml_fp16_t * x) { + float tmp[4]; + + tmp[0] = LM_GGML_FP16_TO_FP32(x[0]); + tmp[1] = LM_GGML_FP16_TO_FP32(x[1]); + tmp[2] = LM_GGML_FP16_TO_FP32(x[2]); + tmp[3] = LM_GGML_FP16_TO_FP32(x[3]); + + return __lsx_vld(tmp, 0); +} + +static inline void __lsx_f16x4_store(lm_ggml_fp16_t * x, __m128 y) { + float arr[4]; + + __lsx_vst(y, arr, 0); + + x[0] = LM_GGML_FP32_TO_FP16(arr[0]); + x[1] = LM_GGML_FP32_TO_FP16(arr[1]); + x[2] = LM_GGML_FP32_TO_FP16(arr[2]); + x[3] = LM_GGML_FP32_TO_FP16(arr[3]); +} + +#define LM_GGML_F32Cx4 __m128 +#define LM_GGML_F32Cx4_ZERO __lsx_vldi(0) +#define LM_GGML_F32Cx4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0) +#define LM_GGML_F32Cx4_LOAD(x) __lsx_f16x4_load(x) +#define LM_GGML_F32Cx4_STORE(x, y) __lsx_f16x4_store(x, y) +#define LM_GGML_F32Cx4_FMA LM_GGML_F32x4_FMA +#define LM_GGML_F32Cx4_ADD __lsx_vfadd_s +#define LM_GGML_F32Cx4_MUL __lsx_vfmul_s +#define LM_GGML_F32Cx4_REDUCE LM_GGML_F32x4_REDUCE + +#define LM_GGML_F16_VEC LM_GGML_F32Cx4 +#define LM_GGML_F16_VEC_ZERO LM_GGML_F32Cx4_ZERO +#define LM_GGML_F16_VEC_SET1 LM_GGML_F32Cx4_SET1 +#define LM_GGML_F16_VEC_LOAD(p, i) LM_GGML_F32Cx4_LOAD(p) +#define LM_GGML_F16_VEC_STORE(p, r, i) LM_GGML_F32Cx4_STORE(p, r[i]) +#define LM_GGML_F16_VEC_FMA LM_GGML_F32Cx4_FMA +#define LM_GGML_F16_VEC_ADD LM_GGML_F32Cx4_ADD +#define LM_GGML_F16_VEC_MUL LM_GGML_F32Cx4_MUL +#define LM_GGML_F16_VEC_REDUCE LM_GGML_F32Cx4_REDUCE + +#endif + +// LM_GGML_F32_ARR / LM_GGML_F16_ARR +// number of registers to use per step +#ifdef LM_GGML_SIMD +#define LM_GGML_F32_ARR (LM_GGML_F32_STEP/LM_GGML_F32_EPR) +#define LM_GGML_F16_ARR (LM_GGML_F16_STEP/LM_GGML_F16_EPR) +#endif + +// +// Threading defs +// + +typedef pthread_t lm_ggml_thread_t; + +#if defined(_WIN32) + +typedef CONDITION_VARIABLE lm_ggml_cond_t; +typedef SRWLOCK lm_ggml_mutex_t; + +#define lm_ggml_mutex_init(m) InitializeSRWLock(m) +#define lm_ggml_mutex_destroy(m) +#define lm_ggml_mutex_lock(m) AcquireSRWLockExclusive(m) +#define lm_ggml_mutex_unlock(m) ReleaseSRWLockExclusive(m) +#define lm_ggml_mutex_lock_shared(m) AcquireSRWLockShared(m) +#define lm_ggml_mutex_unlock_shared(m) ReleaseSRWLockShared(m) + +#define lm_ggml_cond_init(c) InitializeConditionVariable(c) +#define lm_ggml_cond_destroy(c) +#define lm_ggml_cond_wait(c, m) SleepConditionVariableSRW(c, m, INFINITE, CONDITION_VARIABLE_LOCKMODE_SHARED) +#define lm_ggml_cond_broadcast(c) WakeAllConditionVariable(c) + +#define lm_ggml_thread_create pthread_create +#define lm_ggml_thread_join pthread_join + +#else + +typedef pthread_cond_t lm_ggml_cond_t; +typedef pthread_mutex_t lm_ggml_mutex_t; + +#define lm_ggml_mutex_init(m) pthread_mutex_init(m, NULL) +#define lm_ggml_mutex_destroy(m) pthread_mutex_destroy(m) +#define lm_ggml_mutex_lock(m) pthread_mutex_lock(m) +#define lm_ggml_mutex_unlock(m) pthread_mutex_unlock(m) +#define lm_ggml_mutex_lock_shared(m) pthread_mutex_lock(m) +#define lm_ggml_mutex_unlock_shared(m) pthread_mutex_unlock(m) + +#define lm_ggml_lock_init(x) UNUSED(x) +#define lm_ggml_lock_destroy(x) UNUSED(x) +#if defined(__x86_64__) || (defined(_MSC_VER) && defined(_M_AMD64)) +#define lm_ggml_lock_lock(x) _mm_pause() +#else +#define lm_ggml_lock_lock(x) UNUSED(x) +#endif +#define lm_ggml_lock_unlock(x) UNUSED(x) + +#define LM_GGML_LOCK_INITIALIZER 0 +#define lm_ggml_cond_init(c) pthread_cond_init(c, NULL) +#define lm_ggml_cond_destroy(c) pthread_cond_destroy(c) +#define lm_ggml_cond_wait(c, m) pthread_cond_wait(c, m) +#define lm_ggml_cond_broadcast(c) pthread_cond_broadcast(c) + +#define lm_ggml_thread_create pthread_create +#define lm_ggml_thread_join pthread_join + +#endif + +// Threadpool def +struct lm_ggml_threadpool { + lm_ggml_mutex_t mutex; // mutex for cond.var + lm_ggml_cond_t cond; // cond.var for waiting for new work + + struct lm_ggml_cgraph * cgraph; + struct lm_ggml_cplan * cplan; + + // synchronization primitives + atomic_int n_graph; // incremented when there is work to be done (i.e each graph) + atomic_int LM_GGML_CACHE_ALIGN n_barrier; + atomic_int LM_GGML_CACHE_ALIGN n_barrier_passed; + atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads. + + // these are atomic as an annotation for thread-sanitizer + atomic_bool stop; // Used for stopping the threadpool altogether + atomic_bool pause; // Used for pausing the threadpool or individual threads + atomic_bool abort; // Used for aborting processing of a graph + + struct lm_ggml_compute_state * workers; // per thread state + int n_threads_max; // number of threads in the pool + atomic_int n_threads_cur; // number of threads used in the current graph + + int32_t prio; // Scheduling priority + uint32_t poll; // Polling level (0 - no polling) + + enum lm_ggml_status ec; +}; + +// Per-thread state +struct lm_ggml_compute_state { +#ifndef LM_GGML_USE_OPENMP + lm_ggml_thread_t thrd; + bool cpumask[LM_GGML_MAX_N_THREADS]; + int last_graph; + bool pending; +#endif + struct lm_ggml_threadpool * threadpool; + int ith; +}; + +struct lm_ggml_compute_params { + // ith = thread index, nth = number of threads + int ith, nth; + + // work buffer for all threads + size_t wsize; + void * wdata; + + struct lm_ggml_threadpool * threadpool; +}; + +// +// fundamental operations +// + +inline static void lm_ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; } + +inline static void lm_ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } + +inline static void lm_ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } + +inline static void lm_ggml_vec_set_f16(const int n, lm_ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } + +inline static void lm_ggml_vec_set_bf16(const int n, lm_ggml_bf16_t * x, const lm_ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } + +inline static void lm_ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; } +inline static void lm_ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; } +inline static void lm_ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; } +inline static void lm_ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; } +inline static void lm_ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; } +inline static void lm_ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; } +inline static void lm_ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; } +inline static void lm_ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; } +inline static void lm_ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; } +inline static void lm_ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; } + +static void lm_ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + +#if defined(LM_GGML_SIMD) + float sumf = 0.0f; + const int np = (n & ~(LM_GGML_F32_STEP - 1)); + + LM_GGML_F32_VEC sum[LM_GGML_F32_ARR] = { LM_GGML_F32_VEC_ZERO }; + + LM_GGML_F32_VEC ax[LM_GGML_F32_ARR]; + LM_GGML_F32_VEC ay[LM_GGML_F32_ARR]; + + for (int i = 0; i < np; i += LM_GGML_F32_STEP) { + for (int j = 0; j < LM_GGML_F32_ARR; j++) { + ax[j] = LM_GGML_F32_VEC_LOAD(x + i + j*LM_GGML_F32_EPR); + ay[j] = LM_GGML_F32_VEC_LOAD(y + i + j*LM_GGML_F32_EPR); + + sum[j] = LM_GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]); + } + } + + // reduce sum0..sum3 to sum0 + LM_GGML_F32_VEC_REDUCE(sumf, sum); + + // leftovers + for (int i = np; i < n; ++i) { + sumf += x[i]*y[i]; + } +#else + // scalar + lm_ggml_float sumf = 0.0; + for (int i = 0; i < n; ++i) { + sumf += (lm_ggml_float)(x[i]*y[i]); + } +#endif + + *s = sumf; +} + +static void lm_ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, lm_ggml_bf16_t * restrict x, size_t bx, lm_ggml_bf16_t * restrict y, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + int i = 0; + lm_ggml_float sumf = 0; + +#if defined(__AVX512BF16__) + __m512 c1 = _mm512_setzero_ps(); + __m512 c2 = _mm512_setzero_ps(); + for (; i + 64 <= n; i += 64) { + c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))), + m512bh(_mm512_loadu_si512((y + i)))); + c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))), + m512bh(_mm512_loadu_si512((y + i + 32)))); + } + sumf += (lm_ggml_float)_mm512_reduce_add_ps(c1); + sumf += (lm_ggml_float)_mm512_reduce_add_ps(c2); + +#elif defined(__AVX512F__) +#define LOAD(p) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(p))), 16)) + __m512 c1 = _mm512_setzero_ps(); + __m512 c2 = _mm512_setzero_ps(); + for (; i + 32 <= n; i += 32) { + c1 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i), LOAD(y + i)), c1); + c2 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c2); + } + sumf += (lm_ggml_float)_mm512_reduce_add_ps(c1); + sumf += (lm_ggml_float)_mm512_reduce_add_ps(c2); + +#undef LOAD +#elif defined(__AVX2__) || defined(__AVX__) +#if defined(__AVX2__) +#define LOAD(p) _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16)) +#else +#define LOAD(p) _mm256_castsi256_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16)), (_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_bsrli_si128(_mm_loadu_si128((const __m128i *)(p)), 8)), 16)), 1)) +#endif + __m256 c1 = _mm256_setzero_ps(); + __m256 c2 = _mm256_setzero_ps(); + __m256 c3 = _mm256_setzero_ps(); + __m256 c4 = _mm256_setzero_ps(); + for (; i + 32 <= n; i += 32) { + c1 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i), LOAD(y + i)), c1); + c2 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 8), LOAD(y + i + 8)), c2); + c3 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c3); + c4 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 24), LOAD(y + i + 24)), c4); + } + __m128 g; + c1 = _mm256_add_ps(_mm256_add_ps(c1, c3), + _mm256_add_ps(c2, c4)); + g = _mm_add_ps(_mm256_extractf128_ps(c1, 1), + _mm256_castps256_ps128(c1)); + g = _mm_add_ps(g, _mm_movehl_ps(g, g)); + g = _mm_add_ss(g, _mm_movehdup_ps(g)); + sumf += (lm_ggml_float)_mm_cvtss_f32(g); + +#undef LOAD +#endif + + for (; i < n; ++i) { + sumf += (lm_ggml_float)(LM_GGML_BF16_TO_FP32(x[i]) * + LM_GGML_BF16_TO_FP32(y[i])); + } + *s = sumf; +} + +static void lm_ggml_vec_dot_f16(int n, float * restrict s, size_t bs, lm_ggml_fp16_t * restrict x, size_t bx, lm_ggml_fp16_t * restrict y, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + lm_ggml_float sumf = 0.0; + +#if defined(LM_GGML_SIMD) + const int np = (n & ~(LM_GGML_F16_STEP - 1)); + + LM_GGML_F16_VEC sum[LM_GGML_F16_ARR] = { LM_GGML_F16_VEC_ZERO }; + + LM_GGML_F16_VEC ax[LM_GGML_F16_ARR]; + LM_GGML_F16_VEC ay[LM_GGML_F16_ARR]; + + for (int i = 0; i < np; i += LM_GGML_F16_STEP) { + for (int j = 0; j < LM_GGML_F16_ARR; j++) { + ax[j] = LM_GGML_F16_VEC_LOAD(x + i + j*LM_GGML_F16_EPR, j); + ay[j] = LM_GGML_F16_VEC_LOAD(y + i + j*LM_GGML_F16_EPR, j); + + sum[j] = LM_GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]); + } + } + + // reduce sum0..sum3 to sum0 + LM_GGML_F16_VEC_REDUCE(sumf, sum); + + // leftovers + for (int i = np; i < n; ++i) { + sumf += (lm_ggml_float)(LM_GGML_FP16_TO_FP32(x[i])*LM_GGML_FP16_TO_FP32(y[i])); + } +#else + for (int i = 0; i < n; ++i) { + sumf += (lm_ggml_float)(LM_GGML_FP16_TO_FP32(x[i])*LM_GGML_FP16_TO_FP32(y[i])); + } +#endif + + *s = sumf; +} + +// compute LM_GGML_VEC_DOT_UNROLL dot products at once +// xs - x row stride in bytes +inline static void lm_ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, lm_ggml_fp16_t * restrict y) { + lm_ggml_float sumf[LM_GGML_VEC_DOT_UNROLL] = { 0.0 }; + + lm_ggml_fp16_t * restrict x[LM_GGML_VEC_DOT_UNROLL]; + + for (int i = 0; i < LM_GGML_VEC_DOT_UNROLL; ++i) { + x[i] = (lm_ggml_fp16_t *) ((char *) xv + i*xs); + } + +#if defined(LM_GGML_SIMD) + const int np = (n & ~(LM_GGML_F16_STEP - 1)); + + LM_GGML_F16_VEC sum[LM_GGML_VEC_DOT_UNROLL][LM_GGML_F16_ARR] = { { LM_GGML_F16_VEC_ZERO } }; + + LM_GGML_F16_VEC ax[LM_GGML_F16_ARR]; + LM_GGML_F16_VEC ay[LM_GGML_F16_ARR]; + + for (int i = 0; i < np; i += LM_GGML_F16_STEP) { + for (int j = 0; j < LM_GGML_F16_ARR; j++) { + ay[j] = LM_GGML_F16_VEC_LOAD(y + i + j*LM_GGML_F16_EPR, j); + + for (int k = 0; k < LM_GGML_VEC_DOT_UNROLL; ++k) { + ax[j] = LM_GGML_F16_VEC_LOAD(x[k] + i + j*LM_GGML_F16_EPR, j); + + sum[k][j] = LM_GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]); + } + } + } + + // reduce sum0..sum3 to sum0 + for (int k = 0; k < LM_GGML_VEC_DOT_UNROLL; ++k) { + LM_GGML_F16_VEC_REDUCE(sumf[k], sum[k]); + } + + // leftovers + for (int i = np; i < n; ++i) { + for (int j = 0; j < LM_GGML_VEC_DOT_UNROLL; ++j) { + sumf[j] += (lm_ggml_float)(LM_GGML_FP16_TO_FP32(x[j][i])*LM_GGML_FP16_TO_FP32(y[i])); + } + } +#else + for (int i = 0; i < n; ++i) { + for (int j = 0; j < LM_GGML_VEC_DOT_UNROLL; ++j) { + sumf[j] += (lm_ggml_float)(LM_GGML_FP16_TO_FP32(x[j][i])*LM_GGML_FP16_TO_FP32(y[i])); + } + } +#endif + + for (int i = 0; i < LM_GGML_VEC_DOT_UNROLL; ++i) { + s[i] = sumf[i]; + } +} + +inline static void lm_ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) { +#if defined(LM_GGML_SIMD) + const int np = (n & ~(LM_GGML_F32_STEP - 1)); + + LM_GGML_F32_VEC vx = LM_GGML_F32_VEC_SET1(v); + + LM_GGML_F32_VEC ax[LM_GGML_F32_ARR]; + LM_GGML_F32_VEC ay[LM_GGML_F32_ARR]; + + for (int i = 0; i < np; i += LM_GGML_F32_STEP) { + for (int j = 0; j < LM_GGML_F32_ARR; j++) { + ax[j] = LM_GGML_F32_VEC_LOAD(x + i + j*LM_GGML_F32_EPR); + ay[j] = LM_GGML_F32_VEC_LOAD(y + i + j*LM_GGML_F32_EPR); + ay[j] = LM_GGML_F32_VEC_FMA(ay[j], ax[j], vx); + + LM_GGML_F32_VEC_STORE(y + i + j*LM_GGML_F32_EPR, ay[j]); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] += x[i]*v; + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] += x[i]*v; + } +#endif +} + +inline static void lm_ggml_vec_mad_f16(const int n, lm_ggml_fp16_t * restrict y, const lm_ggml_fp16_t * restrict x, const float v) { +#if defined(LM_GGML_SIMD) + const int np = (n & ~(LM_GGML_F16_STEP - 1)); + + LM_GGML_F16_VEC vx = LM_GGML_F16_VEC_SET1(v); + + LM_GGML_F16_VEC ax[LM_GGML_F16_ARR]; + LM_GGML_F16_VEC ay[LM_GGML_F16_ARR]; + + for (int i = 0; i < np; i += LM_GGML_F16_STEP) { + for (int j = 0; j < LM_GGML_F16_ARR; j++) { + ax[j] = LM_GGML_F16_VEC_LOAD(x + i + j*LM_GGML_F16_EPR, j); + ay[j] = LM_GGML_F16_VEC_LOAD(y + i + j*LM_GGML_F16_EPR, j); + ay[j] = LM_GGML_F16_VEC_FMA(ay[j], ax[j], vx); + + LM_GGML_F16_VEC_STORE(y + i + j*LM_GGML_F16_EPR, ay, j); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] = LM_GGML_FP32_TO_FP16(LM_GGML_FP16_TO_FP32(y[i]) + LM_GGML_FP16_TO_FP32(x[i])*v); + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] = LM_GGML_FP32_TO_FP16(LM_GGML_FP16_TO_FP32(y[i]) + LM_GGML_FP16_TO_FP32(x[i])*v); + } +#endif +} + +// xs and vs are byte strides of x and v +inline static void lm_ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) { + + const float * restrict x[LM_GGML_VEC_MAD_UNROLL]; + const float * restrict v[LM_GGML_VEC_MAD_UNROLL]; + + for (int i = 0; i < LM_GGML_VEC_MAD_UNROLL; ++i) { + x[i] = (const float *) ((const char *) xv + i*xs); + v[i] = (const float *) ((const char *) vv + i*vs); + } + +#if defined(LM_GGML_SIMD) + const int np = (n & ~(LM_GGML_F32_STEP - 1)); + + LM_GGML_F32_VEC vx[LM_GGML_VEC_MAD_UNROLL]; + + for (int k = 0; k < LM_GGML_VEC_MAD_UNROLL; ++k) { + vx[k] = LM_GGML_F32_VEC_SET1(v[k][0]); + } + + LM_GGML_F32_VEC ax[LM_GGML_VEC_MAD_UNROLL][LM_GGML_F32_ARR]; + LM_GGML_F32_VEC ay[LM_GGML_F32_ARR]; + + for (int i = 0; i < np; i += LM_GGML_F32_STEP) { + for (int j = 0; j < LM_GGML_F32_ARR; j++) { + ay[j] = LM_GGML_F32_VEC_LOAD(y + i + j*LM_GGML_F32_EPR); + + for (int k = 0; k < LM_GGML_VEC_MAD_UNROLL; ++k) { + ax[k][j] = LM_GGML_F32_VEC_LOAD(x[k] + i + j*LM_GGML_F32_EPR); + ay[j] = LM_GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]); + } + + LM_GGML_F32_VEC_STORE(y + i + j*LM_GGML_F32_EPR, ay[j]); + } + } + + // leftovers + for (int k = 0; k < LM_GGML_VEC_MAD_UNROLL; ++k) { + for (int i = np; i < n; ++i) { + y[i] += x[k][i]*v[k][0]; + } + } +#else + // scalar + for (int k = 0; k < LM_GGML_VEC_MAD_UNROLL; ++k) { + for (int i = 0; i < n; ++i) { + y[i] += x[k][i]*v[k][0]; + } + } +#endif +} + +//inline static void lm_ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; } +inline static void lm_ggml_vec_scale_f32(const int n, float * y, const float v) { +#if defined(LM_GGML_USE_ACCELERATE) + vDSP_vsmul(y, 1, &v, y, 1, n); +#elif defined(LM_GGML_SIMD) + const int np = (n & ~(LM_GGML_F32_STEP - 1)); + + LM_GGML_F32_VEC vx = LM_GGML_F32_VEC_SET1(v); + + LM_GGML_F32_VEC ay[LM_GGML_F32_ARR]; + + for (int i = 0; i < np; i += LM_GGML_F32_STEP) { + for (int j = 0; j < LM_GGML_F32_ARR; j++) { + ay[j] = LM_GGML_F32_VEC_LOAD(y + i + j*LM_GGML_F32_EPR); + ay[j] = LM_GGML_F32_VEC_MUL(ay[j], vx); + + LM_GGML_F32_VEC_STORE(y + i + j*LM_GGML_F32_EPR, ay[j]); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] *= v; + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] *= v; + } +#endif +} + +inline static void lm_ggml_vec_scale_f16(const int n, lm_ggml_fp16_t * y, const float v) { +#if defined(LM_GGML_SIMD) + const int np = (n & ~(LM_GGML_F16_STEP - 1)); + + LM_GGML_F16_VEC vx = LM_GGML_F16_VEC_SET1(v); + + LM_GGML_F16_VEC ay[LM_GGML_F16_ARR]; + + for (int i = 0; i < np; i += LM_GGML_F16_STEP) { + for (int j = 0; j < LM_GGML_F16_ARR; j++) { + ay[j] = LM_GGML_F16_VEC_LOAD(y + i + j*LM_GGML_F16_EPR, j); + ay[j] = LM_GGML_F16_VEC_MUL(ay[j], vx); + + LM_GGML_F16_VEC_STORE(y + i + j*LM_GGML_F16_EPR, ay, j); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] = LM_GGML_FP32_TO_FP16(LM_GGML_FP16_TO_FP32(y[i])*v); + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] = LM_GGML_FP32_TO_FP16(LM_GGML_FP16_TO_FP32(y[i])*v); + } +#endif +} + +inline static void lm_ggml_vec_norm_f32 (const int n, float * s, const float * x) { lm_ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); } +inline static void lm_ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; } +inline static void lm_ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); } +inline static void lm_ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); } +inline static void lm_ggml_vec_sin_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sinf(x[i]); } +inline static void lm_ggml_vec_cos_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = cosf(x[i]); } +inline static void lm_ggml_vec_abs_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); } +inline static void lm_ggml_vec_sgn_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); } +inline static void lm_ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; } +inline static void lm_ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); } +inline static void lm_ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expm1f(x[i]); } +inline static void lm_ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; } +inline static void lm_ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); } +inline static void lm_ggml_vec_sigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + expf(-x[i])); } +// TODO: optimize performance +inline static void lm_ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); } +inline static void lm_ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); } +inline static void lm_ggml_vec_exp_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = expf(x[i]); } + +static const float GELU_COEF_A = 0.044715f; +static const float GELU_QUICK_COEF = -1.702f; +static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + +inline static float lm_ggml_gelu_f32(float x) { + return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +inline static void lm_ggml_vec_gelu_f16(const int n, lm_ggml_fp16_t * y, const lm_ggml_fp16_t * x) { + const uint16_t * i16 = (const uint16_t *) x; + for (int i = 0; i < n; ++i) { + y[i] = lm_ggml_table_gelu_f16[i16[i]]; + } +} + +#ifdef LM_GGML_GELU_FP16 +inline static void lm_ggml_vec_gelu_f32(const int n, float * y, const float * x) { + uint16_t t; + for (int i = 0; i < n; ++i) { + if (x[i] <= -10.0f) { + y[i] = 0.0f; + } else if (x[i] >= 10.0f) { + y[i] = x[i]; + } else { + lm_ggml_fp16_t fp16 = LM_GGML_FP32_TO_FP16(x[i]); + memcpy(&t, &fp16, sizeof(uint16_t)); + y[i] = LM_GGML_FP16_TO_FP32(lm_ggml_table_gelu_f16[t]); + } + } +} +#else +inline static void lm_ggml_vec_gelu_f32(const int n, float * y, const float * x) { + for (int i = 0; i < n; ++i) { + y[i] = lm_ggml_gelu_f32(x[i]); + } +} +#endif + +inline static float lm_ggml_gelu_quick_f32(float x) { + return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x))); +} + +//inline static void lm_ggml_vec_gelu_quick_f16(const int n, lm_ggml_fp16_t * y, const lm_ggml_fp16_t * x) { +// const uint16_t * i16 = (const uint16_t *) x; +// for (int i = 0; i < n; ++i) { +// y[i] = lm_ggml_table_gelu_quick_f16[i16[i]]; +// } +//} + +#ifdef LM_GGML_GELU_QUICK_FP16 +inline static void lm_ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) { + uint16_t t; + for (int i = 0; i < n; ++i) { + lm_ggml_fp16_t fp16 = LM_GGML_FP32_TO_FP16(x[i]); + memcpy(&t, &fp16, sizeof(uint16_t)); + y[i] = LM_GGML_FP16_TO_FP32(lm_ggml_table_gelu_quick_f16[t]); + } +} +#else +inline static void lm_ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) { + for (int i = 0; i < n; ++i) { + y[i] = lm_ggml_gelu_quick_f32(x[i]); + } +} +#endif + +// Sigmoid Linear Unit (SiLU) function +inline static float lm_ggml_silu_f32(float x) { + return x/(1.0f + expf(-x)); +} + +#if __FINITE_MATH_ONLY__ +#error "some routines in ggml.c require non-finite math arithmetics -- pass -fno-finite-math-only to the compiler to fix" +#error "ref: https://github.com/ggerganov/llama.cpp/pull/7154#issuecomment-2143844461" +#endif + +#if defined(__ARM_NEON) && defined(__aarch64__) + +// adapted from arm limited optimized routine +// the maximum error is 1.45358 plus 0.5 ulps +// numbers above 88.38 will flush to infinity +// numbers beneath -103.97 will flush to zero +inline static float32x4_t lm_ggml_v_expf(float32x4_t x) { + const float32x4_t r = vdupq_n_f32(0x1.8p23f); + const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f)); + const float32x4_t n = vsubq_f32(z, r); + const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n, + vdupq_n_f32(0x1.7f7d1cp-20f)); + const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23); + const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1)))); + const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126)); + const float32x4_t u = vmulq_f32(b, b); + const float32x4_t j = vfmaq_f32( + vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b), + vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b), + vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u); + if (!vpaddd_u64(vreinterpretq_u64_u32(c))) + return vfmaq_f32(k, j, k); + const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000)); + const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000))); + const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d)); + return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1), + vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j))); +} + +// computes silu x/(1+exp(-x)) in single precision vector +inline static float32x4_t lm_ggml_v_silu(float32x4_t x) { + const float32x4_t one = vdupq_n_f32(1.0f); + const float32x4_t zero = vdupq_n_f32(0.0f); + const float32x4_t neg_x = vsubq_f32(zero, x); + const float32x4_t exp_neg_x = lm_ggml_v_expf(neg_x); + const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x); + return vdivq_f32(x, one_plus_exp_neg_x); +} + +#elif defined(__AVX512F__) && defined(__AVX512DQ__) + +// adapted from arm limited optimized routine +// the maximum error is 1.45358 plus 0.5 ulps +// numbers above 88.38 will flush to infinity +// numbers beneath -103.97 will flush to zero +inline static __m512 lm_ggml_v_expf(__m512 x) { + const __m512 r = _mm512_set1_ps(0x1.8p23f); + const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r); + const __m512 n = _mm512_sub_ps(z, r); + const __m512 b = + _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f), + _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x)); + const __mmask16 d = + _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ); + const __m512 u = _mm512_mul_ps(b, b); + const __m512 j = _mm512_fmadd_ps( + _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b, + _mm512_set1_ps(0x1.573e2ep-5f)), + u, + _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b, + _mm512_set1_ps(0x1.fffdb6p-2f))), + u, + _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F))); + const __m512 res = _mm512_scalef_ps(j, n); + if (_mm512_kortestz(d, d)) + return res; + const __m512 zero = _mm512_setzero_ps(); + const __m512 alt = _mm512_mask_blend_ps( + _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero); + return _mm512_mask_blend_ps(d, res, alt); +} + +// computes silu x/(1+exp(-x)) in single precision vector +inline static __m512 lm_ggml_v_silu(__m512 x) { + const __m512 one = _mm512_set1_ps(1); + const __m512 zero = _mm512_setzero_ps(); + const __m512 neg_x = _mm512_sub_ps(zero, x); + const __m512 exp_neg_x = lm_ggml_v_expf(neg_x); + const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x); + return _mm512_div_ps(x, one_plus_exp_neg_x); +} + +#elif defined(__AVX2__) && defined(__FMA__) + +// adapted from arm limited optimized routine +// the maximum error is 1.45358 plus 0.5 ulps +// numbers above 88.38 will flush to infinity +// numbers beneath -103.97 will flush to zero +inline static __m256 lm_ggml_v_expf(__m256 x) { + const __m256 r = _mm256_set1_ps(0x1.8p23f); + const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r); + const __m256 n = _mm256_sub_ps(z, r); + const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f), + _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x)); + const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23); + const __m256 k = _mm256_castsi256_ps( + _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1)))); + const __m256i c = _mm256_castps_si256( + _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n), + _mm256_set1_ps(126), _CMP_GT_OQ)); + const __m256 u = _mm256_mul_ps(b, b); + const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b, + _mm256_set1_ps(0x1.573e2ep-5f)), u, + _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b, + _mm256_set1_ps(0x1.fffdb6p-2f))), + u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b)); + if (!_mm256_movemask_ps(_mm256_castsi256_ps(c))) + return _mm256_fmadd_ps(j, k, k); + const __m256i g = _mm256_and_si256( + _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)), + _mm256_set1_epi32(0x82000000u)); + const __m256 s1 = + _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u))); + const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g)); + const __m256i d = _mm256_castps_si256( + _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n), + _mm256_set1_ps(192), _CMP_GT_OQ)); + return _mm256_or_ps( + _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)), + _mm256_andnot_ps( + _mm256_castsi256_ps(d), + _mm256_or_ps( + _mm256_and_ps(_mm256_castsi256_ps(c), + _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)), + _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k))))); +} + +// computes silu x/(1+exp(-x)) in single precision vector +inline static __m256 lm_ggml_v_silu(__m256 x) { + const __m256 one = _mm256_set1_ps(1); + const __m256 zero = _mm256_setzero_ps(); + const __m256 neg_x = _mm256_sub_ps(zero, x); + const __m256 exp_neg_x = lm_ggml_v_expf(neg_x); + const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x); + return _mm256_div_ps(x, one_plus_exp_neg_x); +} + +#elif defined(__SSE2__) // __AVX2__ / __ARM_NEON + +#if defined(__FMA__) +#define MADD128(x, y, z) _mm_fmadd_ps(x, y, z) +#define NMADD128(x, y, z) _mm_fnmadd_ps(x, y, z) +#else +#define MADD128(x, y, z) _mm_add_ps(_mm_mul_ps(x, y), z) +#define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y)) +#endif + +// adapted from arm limited optimized routine +// the maximum error is 1.45358 plus 0.5 ulps +// numbers above 88.38 will flush to infinity +// numbers beneath -103.97 will flush to zero +inline static __m128 lm_ggml_v_expf(__m128 x) { + const __m128 r = _mm_set1_ps(0x1.8p23f); + const __m128 z = MADD128(x, _mm_set1_ps(0x1.715476p+0f), r); + const __m128 n = _mm_sub_ps(z, r); + const __m128 b = + NMADD128(n, _mm_set1_ps(0x1.7f7d1cp-20f), NMADD128(n, _mm_set1_ps(0x1.62e4p-1f), x)); + const __m128i e = _mm_slli_epi32(_mm_castps_si128(z), 23); + const __m128 k = _mm_castsi128_ps(_mm_add_epi32(e, _mm_castps_si128(_mm_set1_ps(1)))); + const __m128i c = + _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(126))); + const __m128 u = _mm_mul_ps(b, b); + const __m128 j = + MADD128(MADD128(MADD128(_mm_set1_ps(0x1.0e4020p-7f), b, _mm_set1_ps(0x1.573e2ep-5f)), u, + MADD128(_mm_set1_ps(0x1.555e66p-3f), b, _mm_set1_ps(0x1.fffdb6p-2f))), + u, _mm_mul_ps(_mm_set1_ps(0x1.ffffecp-1f), b)); + if (!_mm_movemask_epi8(c)) + return MADD128(j, k, k); + const __m128i g = _mm_and_si128(_mm_castps_si128(_mm_cmple_ps(n, _mm_setzero_ps())), + _mm_set1_epi32(0x82000000u)); + const __m128 s1 = _mm_castsi128_ps(_mm_add_epi32(g, _mm_set1_epi32(0x7f000000u))); + const __m128 s2 = _mm_castsi128_ps(_mm_sub_epi32(e, g)); + const __m128i d = + _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(192))); + return _mm_or_ps( + _mm_and_ps(_mm_castsi128_ps(d), _mm_mul_ps(s1, s1)), + _mm_andnot_ps(_mm_castsi128_ps(d), + _mm_or_ps(_mm_and_ps(_mm_castsi128_ps(c), _mm_mul_ps(MADD128(s2, j, s2), s1)), + _mm_andnot_ps(_mm_castsi128_ps(c), MADD128(k, j, k))))); +} + +// computes silu x/(1+exp(-x)) in single precision vector +inline static __m128 lm_ggml_v_silu(__m128 x) { + const __m128 one = _mm_set1_ps(1); + const __m128 zero = _mm_setzero_ps(); + const __m128 neg_x = _mm_sub_ps(zero, x); + const __m128 exp_neg_x = lm_ggml_v_expf(neg_x); + const __m128 one_plus_exp_neg_x = _mm_add_ps(one, exp_neg_x); + return _mm_div_ps(x, one_plus_exp_neg_x); +} + +#endif // __ARM_NEON / __AVX2__ / __SSE2__ + +static void lm_ggml_vec_silu_f32(const int n, float * y, const float * x) { + int i = 0; +#if defined(__AVX512F__) && defined(__AVX512DQ__) + for (; i + 15 < n; i += 16) { + _mm512_storeu_ps(y + i, lm_ggml_v_silu(_mm512_loadu_ps(x + i))); + } +#elif defined(__AVX2__) && defined(__FMA__) + for (; i + 7 < n; i += 8) { + _mm256_storeu_ps(y + i, lm_ggml_v_silu(_mm256_loadu_ps(x + i))); + } +#elif defined(__SSE2__) + for (; i + 3 < n; i += 4) { + _mm_storeu_ps(y + i, lm_ggml_v_silu(_mm_loadu_ps(x + i))); + } +#elif defined(__ARM_NEON) && defined(__aarch64__) + for (; i + 3 < n; i += 4) { + vst1q_f32(y + i, lm_ggml_v_silu(vld1q_f32(x + i))); + } +#endif + for (; i < n; ++i) { + y[i] = lm_ggml_silu_f32(x[i]); + } +} + +static lm_ggml_float lm_ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) { + int i = 0; + lm_ggml_float sum = 0; +#if defined(__AVX512F__) && defined(__AVX512DQ__) + for (; i + 15 < n; i += 16) { + __m512 val = lm_ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i), + _mm512_set1_ps(max))); + _mm512_storeu_ps(y + i, val); + sum += (lm_ggml_float)_mm512_reduce_add_ps(val); + } +#elif defined(__AVX2__) && defined(__FMA__) + for (; i + 7 < n; i += 8) { + __m256 val = lm_ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i), + _mm256_set1_ps(max))); + _mm256_storeu_ps(y + i, val); + __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1), + _mm256_castps256_ps128(val)); + val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2)); + val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2)); + sum += (lm_ggml_float)_mm_cvtss_f32(val2); + } +#elif defined(__SSE2__) + for (; i + 3 < n; i += 4) { + __m128 val = lm_ggml_v_expf(_mm_sub_ps(_mm_loadu_ps(x + i), + _mm_set1_ps(max))); + _mm_storeu_ps(y + i, val); +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) + val = _mm_add_ps(val, _mm_movehl_ps(val, val)); + val = _mm_add_ss(val, _mm_movehdup_ps(val)); +#else + __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1)); + val = _mm_add_ps(val, tmp); + tmp = _mm_movehl_ps(tmp, val); + val = _mm_add_ss(val, tmp); +#endif + sum += (lm_ggml_float)_mm_cvtss_f32(val); + } +#elif defined(__ARM_NEON) && defined(__aarch64__) + for (; i + 3 < n; i += 4) { + float32x4_t val = lm_ggml_v_expf(vsubq_f32(vld1q_f32(x + i), + vdupq_n_f32(max))); + vst1q_f32(y + i, val); + sum += (lm_ggml_float)vaddvq_f32(val); + } +#endif + for (; i < n; ++i) { + float val = expf(x[i] - max); + sum += (lm_ggml_float)val; + y[i] = val; + } + return sum; +} + +static lm_ggml_float lm_ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max) { + // log(soft_max) = log(soft_max_i / soft_max_sum) = log(soft_max_i) - log(soft_max_sum) = (logit_i - max) - log(soft_max_i) + + int i = 0; + lm_ggml_float sum = 0; + for (; i < n; ++i) { + float val = x[i] - max; + y[i] = val; + sum += (lm_ggml_float)expf(val); + } + return sum = (lm_ggml_float)logf(sum); +} + +inline static float lm_ggml_silu_backward_f32(float x, float dy) { + const float s = 1.0f/(1.0f + expf(-x)); + return dy*s*(1.0f + x*(1.0f - s)); +} + +inline static void lm_ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) { + for (int i = 0; i < n; ++i) { + dx[i] = lm_ggml_silu_backward_f32(x[i], dy[i]); + } +} + +inline static void lm_ggml_vec_sum_f32(const int n, float * s, const float * x) { +#ifndef LM_GGML_USE_ACCELERATE + lm_ggml_float sum = 0.0; + for (int i = 0; i < n; ++i) { + sum += (lm_ggml_float)x[i]; + } + *s = sum; +#else + vDSP_sve(x, 1, s, n); +#endif +} + +inline static void lm_ggml_vec_sum_f32_ggf(const int n, lm_ggml_float * s, const float * x) { + lm_ggml_float sum = 0.0; + for (int i = 0; i < n; ++i) { + sum += (lm_ggml_float)x[i]; + } + *s = sum; +} + +inline static void lm_ggml_vec_sum_f16_ggf(const int n, float * s, const lm_ggml_fp16_t * x) { + float sum = 0.0f; + for (int i = 0; i < n; ++i) { + sum += LM_GGML_FP16_TO_FP32(x[i]); + } + *s = sum; +} + +inline static void lm_ggml_vec_sum_bf16_ggf(const int n, float * s, const lm_ggml_bf16_t * x) { + float sum = 0.0f; + for (int i = 0; i < n; ++i) { + sum += LM_GGML_BF16_TO_FP32(x[i]); + } + *s = sum; +} + +inline static void lm_ggml_vec_max_f32(const int n, float * s, const float * x) { +#ifndef LM_GGML_USE_ACCELERATE + float max = -INFINITY; + for (int i = 0; i < n; ++i) { + max = MAX(max, x[i]); + } + *s = max; +#else + vDSP_maxv(x, 1, s, n); +#endif +} + +inline static void lm_ggml_vec_norm_inv_f32(const int n, float * s, const float * x) { + lm_ggml_vec_norm_f32(n, s, x); + *s = 1.f/(*s); +} + +inline static void lm_ggml_vec_argmax_f32(const int n, int * s, const float * x) { + float max = -INFINITY; + int idx = 0; + for (int i = 0; i < n; ++i) { + max = MAX(max, x[i]); + if (max == x[i]) { idx = i; } + } + *s = idx; +} + +// Helpers for polling loops +#if defined(__aarch64__) && ( defined(__clang__) || defined(__GNUC__) ) +static inline void lm_ggml_thread_cpu_relax(void) { + __asm__ volatile("yield" ::: "memory"); +} +#elif defined(__x86_64__) +static inline void lm_ggml_thread_cpu_relax(void) { + _mm_pause(); +} +#else +static inline void lm_ggml_thread_cpu_relax(void) {;} +#endif + +// +// NUMA support +// + +#define LM_GGML_NUMA_MAX_NODES 8 +#define LM_GGML_NUMA_MAX_CPUS 512 + +struct lm_ggml_numa_node { + uint32_t cpus[LM_GGML_NUMA_MAX_CPUS]; // hardware threads on this node + uint32_t n_cpus; +}; + +struct lm_ggml_numa_nodes { + enum lm_ggml_numa_strategy numa_strategy; + struct lm_ggml_numa_node nodes[LM_GGML_NUMA_MAX_NODES]; + uint32_t n_nodes; + uint32_t total_cpus; // hardware threads on system + uint32_t current_node; // node on which main process is execting +#if defined(__gnu_linux__) + cpu_set_t cpuset; // cpuset from numactl +#else + uint32_t cpuset; // no NUMA support outside of Linux at this time. Use a portable datatype +#endif +}; + +// +// ggml state +// + +struct lm_ggml_state { + struct lm_ggml_numa_nodes numa; +}; + +static struct lm_ggml_state g_state = {0}; + +static void lm_ggml_barrier(struct lm_ggml_threadpool * tp) { + int n_threads = atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed); + if (n_threads == 1) { + return; + } + +#ifdef LM_GGML_USE_OPENMP + #pragma omp barrier +#else + int n_passed = atomic_load_explicit(&tp->n_barrier_passed, memory_order_relaxed); + + // enter barrier (full seq-cst fence) + int n_barrier = atomic_fetch_add_explicit(&tp->n_barrier, 1, memory_order_seq_cst); + + if (n_barrier == (n_threads - 1)) { + // last thread + atomic_store_explicit(&tp->n_barrier, 0, memory_order_relaxed); + + // exit barrier (fill seq-cst fence) + atomic_fetch_add_explicit(&tp->n_barrier_passed, 1, memory_order_seq_cst); + return; + } + + // wait for other threads + while (atomic_load_explicit(&tp->n_barrier_passed, memory_order_relaxed) == n_passed) { + lm_ggml_thread_cpu_relax(); + } + + // exit barrier (full seq-cst fence) + // TSAN doesn't support standalone fence yet, we use a dummy read-modify-write instead + #ifdef LM_GGML_TSAN_ENABLED + atomic_fetch_add_explicit(&tp->n_barrier_passed, 0, memory_order_seq_cst); + #else + atomic_thread_fence(memory_order_seq_cst); + #endif +#endif +} + +#if defined(__gnu_linux__) +static cpu_set_t lm_ggml_get_numa_affinity(void) { + cpu_set_t cpuset; + pthread_t thread; + thread = pthread_self(); + CPU_ZERO(&cpuset); + pthread_getaffinity_np(thread, sizeof(cpu_set_t), &cpuset); + return cpuset; +} +#else +static uint32_t lm_ggml_get_numa_affinity(void) { + return 0; // no NUMA support +} +#endif + +void lm_ggml_numa_init(enum lm_ggml_numa_strategy numa_flag) { + if (g_state.numa.n_nodes > 0) { + fprintf(stderr, "lm_ggml_numa_init: NUMA already initialized\n"); + + return; + } + +#if defined(__gnu_linux__) + struct stat st; + char path[256]; + int rv; + + // set numa scheme + g_state.numa.numa_strategy = numa_flag; + + LM_GGML_PRINT_DEBUG("numa strategy %u\n",g_state.numa.numa_strategy); + + g_state.numa.cpuset = lm_ggml_get_numa_affinity(); + + // enumerate nodes + while (g_state.numa.n_nodes < LM_GGML_NUMA_MAX_NODES) { + rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u", g_state.numa.n_nodes); + LM_GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path)); + if (stat(path, &st) != 0) { break; } + ++g_state.numa.n_nodes; + } + + // enumerate CPUs + while (g_state.numa.total_cpus < LM_GGML_NUMA_MAX_CPUS) { + rv = snprintf(path, sizeof(path), "/sys/devices/system/cpu/cpu%u", g_state.numa.total_cpus); + LM_GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path)); + if (stat(path, &st) != 0) { break; } + ++g_state.numa.total_cpus; + } + + LM_GGML_PRINT_DEBUG("found %u numa nodes, %u CPUs\n", g_state.numa.n_nodes, g_state.numa.total_cpus); + + // figure out which node we're on + uint current_cpu; + int getcpu_ret = 0; +#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 28) || defined(__COSMOPOLITAN__) + getcpu_ret = getcpu(¤t_cpu, &g_state.numa.current_node); +#else + // old glibc doesn't have a wrapper for this call. Fall back on direct syscall +# if !defined(SYS_getcpu) && defined(SYS_get_cpu) +# define SYS_getcpu SYS_get_cpu // some older glibc versions use this name +# endif + getcpu_ret = syscall(SYS_getcpu, ¤t_cpu, &g_state.numa.current_node); +#endif + + if (g_state.numa.n_nodes < 1 || g_state.numa.total_cpus < 1 || getcpu_ret != 0) { + g_state.numa.n_nodes = 0; + return; + } + + LM_GGML_PRINT_DEBUG("found our process on numa node %u, CPU %u\n", g_state.numa.current_node, current_cpu); + + for (uint32_t n = 0; n < g_state.numa.n_nodes; ++n) { + struct lm_ggml_numa_node * node = &g_state.numa.nodes[n]; + LM_GGML_PRINT_DEBUG("CPUs on node %u:", n); + node->n_cpus = 0; + for (uint32_t c = 0; c < g_state.numa.total_cpus; ++c) { + rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u/cpu%u", n, c); + LM_GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path)); + if (stat(path, &st) == 0) { + node->cpus[node->n_cpus++] = c; + LM_GGML_PRINT_DEBUG(" %u", c); + } + } + LM_GGML_PRINT_DEBUG("\n"); + } + + if (lm_ggml_is_numa()) { + FILE *fptr = fopen("/proc/sys/kernel/numa_balancing", "r"); + if (fptr != NULL) { + char buf[42]; + if (fgets(buf, sizeof(buf), fptr) && strncmp(buf, "0\n", sizeof(buf)) != 0) { + LM_GGML_LOG_WARN("/proc/sys/kernel/numa_balancing is enabled, this has been observed to impair performance\n"); + } + fclose(fptr); + } + } +#else + UNUSED(numa_flag); + // TODO +#endif +} + +bool lm_ggml_is_numa(void) { + return g_state.numa.n_nodes > 1; +} + +#if defined(__ARM_ARCH) + +#if defined(__linux__) && defined(__aarch64__) +#include +#elif defined(__APPLE__) +#include +#endif + +#if !defined(HWCAP2_I8MM) +#define HWCAP2_I8MM 0 +#endif + +static void lm_ggml_init_arm_arch_features(void) { +#if defined(__linux__) && defined(__aarch64__) + uint32_t hwcap = getauxval(AT_HWCAP); + uint32_t hwcap2 = getauxval(AT_HWCAP2); + + lm_ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD); + lm_ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM); + lm_ggml_arm_arch_features.has_sve = !!(hwcap & HWCAP_SVE); + +#if defined(__ARM_FEATURE_SVE) + lm_ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL); +#endif +#elif defined(__APPLE__) + int oldp = 0; + size_t size = sizeof(oldp); + if (sysctlbyname("hw.optional.AdvSIMD", &oldp, &size, NULL, 0) != 0) { + oldp = 0; + } + lm_ggml_arm_arch_features.has_neon = oldp; + + if (sysctlbyname("hw.optional.arm.FEAT_I8MM", &oldp, &size, NULL, 0) != 0) { + oldp = 0; + } + lm_ggml_arm_arch_features.has_i8mm = oldp; + + lm_ggml_arm_arch_features.has_sve = 0; + lm_ggml_arm_arch_features.sve_cnt = 0; +#else +// Run-time CPU feature detection not implemented for this platform, fallback to compile time +#if defined(__ARM_NEON) + lm_ggml_arm_arch_features.has_neon = 1; +#else + lm_ggml_arm_arch_features.has_neon = 0; +#endif + +#if defined(__ARM_FEATURE_MATMUL_INT8) + lm_ggml_arm_arch_features.has_i8mm = 1; +#else + lm_ggml_arm_arch_features.has_i8mm = 0; +#endif + +#if defined(__ARM_FEATURE_SVE) + lm_ggml_arm_arch_features.has_sve = 1; + lm_ggml_arm_arch_features.sve_cnt = 16; +#else + lm_ggml_arm_arch_features.has_sve = 0; + lm_ggml_arm_arch_features.sve_cnt = 0; +#endif +#endif +} +#endif + +struct lm_ggml_tensor * lm_ggml_new_i32(struct lm_ggml_context * ctx, int32_t value) { + LM_GGML_ASSERT(!lm_ggml_get_no_alloc(ctx)); + + struct lm_ggml_tensor * result = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_I32, 1); + + lm_ggml_set_i32(result, value); + + return result; +} + +struct lm_ggml_tensor * lm_ggml_new_f32(struct lm_ggml_context * ctx, float value) { + LM_GGML_ASSERT(!lm_ggml_get_no_alloc(ctx)); + + struct lm_ggml_tensor * result = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_F32, 1); + + lm_ggml_set_f32(result, value); + + return result; +} + +struct lm_ggml_tensor * lm_ggml_set_i32 (struct lm_ggml_tensor * tensor, int32_t value) { + const int n = lm_ggml_nrows(tensor); + const int nc = tensor->ne[0]; + const size_t n1 = tensor->nb[1]; + + char * const data = tensor->data; + + switch (tensor->type) { + case LM_GGML_TYPE_I8: + { + assert(tensor->nb[0] == sizeof(int8_t)); + for (int i = 0; i < n; i++) { + lm_ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value); + } + } break; + case LM_GGML_TYPE_I16: + { + assert(tensor->nb[0] == sizeof(int16_t)); + for (int i = 0; i < n; i++) { + lm_ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value); + } + } break; + case LM_GGML_TYPE_I32: + { + assert(tensor->nb[0] == sizeof(int32_t)); + for (int i = 0; i < n; i++) { + lm_ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value); + } + } break; + case LM_GGML_TYPE_F16: + { + assert(tensor->nb[0] == sizeof(lm_ggml_fp16_t)); + for (int i = 0; i < n; i++) { + lm_ggml_vec_set_f16(nc, (lm_ggml_fp16_t *)(data + i*n1), LM_GGML_FP32_TO_FP16(value)); + } + } break; + case LM_GGML_TYPE_BF16: + { + assert(tensor->nb[0] == sizeof(lm_ggml_fp16_t)); + for (int i = 0; i < n; i++) { + lm_ggml_vec_set_bf16(nc, (lm_ggml_bf16_t *)(data + i*n1), LM_GGML_FP32_TO_BF16(value)); + } + } break; + case LM_GGML_TYPE_F32: + { + assert(tensor->nb[0] == sizeof(float)); + for (int i = 0; i < n; i++) { + lm_ggml_vec_set_f32(nc, (float *)(data + i*n1), value); + } + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } + + return tensor; +} + +struct lm_ggml_tensor * lm_ggml_set_f32(struct lm_ggml_tensor * tensor, float value) { + const int n = lm_ggml_nrows(tensor); + const int nc = tensor->ne[0]; + const size_t n1 = tensor->nb[1]; + + char * const data = tensor->data; + + switch (tensor->type) { + case LM_GGML_TYPE_I8: + { + assert(tensor->nb[0] == sizeof(int8_t)); + for (int i = 0; i < n; i++) { + lm_ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value); + } + } break; + case LM_GGML_TYPE_I16: + { + assert(tensor->nb[0] == sizeof(int16_t)); + for (int i = 0; i < n; i++) { + lm_ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value); + } + } break; + case LM_GGML_TYPE_I32: + { + assert(tensor->nb[0] == sizeof(int32_t)); + for (int i = 0; i < n; i++) { + lm_ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value); + } + } break; + case LM_GGML_TYPE_F16: + { + assert(tensor->nb[0] == sizeof(lm_ggml_fp16_t)); + for (int i = 0; i < n; i++) { + lm_ggml_vec_set_f16(nc, (lm_ggml_fp16_t *)(data + i*n1), LM_GGML_FP32_TO_FP16(value)); + } + } break; + case LM_GGML_TYPE_BF16: + { + assert(tensor->nb[0] == sizeof(lm_ggml_bf16_t)); + for (int i = 0; i < n; i++) { + lm_ggml_vec_set_bf16(nc, (lm_ggml_bf16_t *)(data + i*n1), LM_GGML_FP32_TO_BF16(value)); + } + } break; + case LM_GGML_TYPE_F32: + { + assert(tensor->nb[0] == sizeof(float)); + for (int i = 0; i < n; i++) { + lm_ggml_vec_set_f32(nc, (float *)(data + i*n1), value); + } + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } + + return tensor; +} + +int32_t lm_ggml_get_i32_1d(const struct lm_ggml_tensor * tensor, int i) { + if (!lm_ggml_is_contiguous(tensor)) { + int64_t id[4] = { 0, 0, 0, 0 }; + lm_ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]); + return lm_ggml_get_i32_nd(tensor, id[0], id[1], id[2], id[3]); + } + switch (tensor->type) { + case LM_GGML_TYPE_I8: + { + LM_GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); + return ((int8_t *)(tensor->data))[i]; + } + case LM_GGML_TYPE_I16: + { + LM_GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); + return ((int16_t *)(tensor->data))[i]; + } + case LM_GGML_TYPE_I32: + { + LM_GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); + return ((int32_t *)(tensor->data))[i]; + } + case LM_GGML_TYPE_F16: + { + LM_GGML_ASSERT(tensor->nb[0] == sizeof(lm_ggml_fp16_t)); + return LM_GGML_FP16_TO_FP32(((lm_ggml_fp16_t *)(tensor->data))[i]); + } + case LM_GGML_TYPE_BF16: + { + LM_GGML_ASSERT(tensor->nb[0] == sizeof(lm_ggml_bf16_t)); + return LM_GGML_BF16_TO_FP32(((lm_ggml_bf16_t *)(tensor->data))[i]); + } + case LM_GGML_TYPE_F32: + { + LM_GGML_ASSERT(tensor->nb[0] == sizeof(float)); + return ((float *)(tensor->data))[i]; + } + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +void lm_ggml_set_i32_1d(const struct lm_ggml_tensor * tensor, int i, int32_t value) { + if (!lm_ggml_is_contiguous(tensor)) { + int64_t id[4] = { 0, 0, 0, 0 }; + lm_ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]); + lm_ggml_set_i32_nd(tensor, id[0], id[1], id[2], id[3], value); + return; + } + switch (tensor->type) { + case LM_GGML_TYPE_I8: + { + LM_GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); + ((int8_t *)(tensor->data))[i] = value; + } break; + case LM_GGML_TYPE_I16: + { + LM_GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); + ((int16_t *)(tensor->data))[i] = value; + } break; + case LM_GGML_TYPE_I32: + { + LM_GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); + ((int32_t *)(tensor->data))[i] = value; + } break; + case LM_GGML_TYPE_F16: + { + LM_GGML_ASSERT(tensor->nb[0] == sizeof(lm_ggml_fp16_t)); + ((lm_ggml_fp16_t *)(tensor->data))[i] = LM_GGML_FP32_TO_FP16(value); + } break; + case LM_GGML_TYPE_BF16: + { + LM_GGML_ASSERT(tensor->nb[0] == sizeof(lm_ggml_bf16_t)); + ((lm_ggml_bf16_t *)(tensor->data))[i] = LM_GGML_FP32_TO_BF16(value); + } break; + case LM_GGML_TYPE_F32: + { + LM_GGML_ASSERT(tensor->nb[0] == sizeof(float)); + ((float *)(tensor->data))[i] = value; + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +int32_t lm_ggml_get_i32_nd(const struct lm_ggml_tensor * tensor, int i0, int i1, int i2, int i3) { + void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]; + switch (tensor->type) { + case LM_GGML_TYPE_I8: + return ((int8_t *) data)[0]; + case LM_GGML_TYPE_I16: + return ((int16_t *) data)[0]; + case LM_GGML_TYPE_I32: + return ((int32_t *) data)[0]; + case LM_GGML_TYPE_F16: + return LM_GGML_FP16_TO_FP32(((lm_ggml_fp16_t *) data)[0]); + case LM_GGML_TYPE_BF16: + return LM_GGML_BF16_TO_FP32(((lm_ggml_bf16_t *) data)[0]); + case LM_GGML_TYPE_F32: + return ((float *) data)[0]; + default: + LM_GGML_ABORT("fatal error"); + } +} + +void lm_ggml_set_i32_nd(const struct lm_ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value) { + void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]; + switch (tensor->type) { + case LM_GGML_TYPE_I8: + { + ((int8_t *)(data))[0] = value; + } break; + case LM_GGML_TYPE_I16: + { + ((int16_t *)(data))[0] = value; + } break; + case LM_GGML_TYPE_I32: + { + ((int32_t *)(data))[0] = value; + } break; + case LM_GGML_TYPE_F16: + { + ((lm_ggml_fp16_t *)(data))[0] = LM_GGML_FP32_TO_FP16(value); + } break; + case LM_GGML_TYPE_BF16: + { + ((lm_ggml_bf16_t *)(data))[0] = LM_GGML_FP32_TO_BF16(value); + } break; + case LM_GGML_TYPE_F32: + { + ((float *)(data))[0] = value; + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +float lm_ggml_get_f32_1d(const struct lm_ggml_tensor * tensor, int i) { + if (!lm_ggml_is_contiguous(tensor)) { + int64_t id[4] = { 0, 0, 0, 0 }; + lm_ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]); + return lm_ggml_get_f32_nd(tensor, id[0], id[1], id[2], id[3]); + } + switch (tensor->type) { + case LM_GGML_TYPE_I8: + { + return ((int8_t *)(tensor->data))[i]; + } + case LM_GGML_TYPE_I16: + { + return ((int16_t *)(tensor->data))[i]; + } + case LM_GGML_TYPE_I32: + { + return ((int32_t *)(tensor->data))[i]; + } + case LM_GGML_TYPE_F16: + { + return LM_GGML_FP16_TO_FP32(((lm_ggml_fp16_t *)(tensor->data))[i]); + } + case LM_GGML_TYPE_BF16: + { + return LM_GGML_BF16_TO_FP32(((lm_ggml_bf16_t *)(tensor->data))[i]); + } + case LM_GGML_TYPE_F32: + { + return ((float *)(tensor->data))[i]; + } + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +void lm_ggml_set_f32_1d(const struct lm_ggml_tensor * tensor, int i, float value) { + if (!lm_ggml_is_contiguous(tensor)) { + int64_t id[4] = { 0, 0, 0, 0 }; + lm_ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]); + lm_ggml_set_f32_nd(tensor, id[0], id[1], id[2], id[3], value); + return; + } + switch (tensor->type) { + case LM_GGML_TYPE_I8: + { + ((int8_t *)(tensor->data))[i] = value; + } break; + case LM_GGML_TYPE_I16: + { + ((int16_t *)(tensor->data))[i] = value; + } break; + case LM_GGML_TYPE_I32: + { + ((int32_t *)(tensor->data))[i] = value; + } break; + case LM_GGML_TYPE_F16: + { + ((lm_ggml_fp16_t *)(tensor->data))[i] = LM_GGML_FP32_TO_FP16(value); + } break; + case LM_GGML_TYPE_BF16: + { + ((lm_ggml_bf16_t *)(tensor->data))[i] = LM_GGML_FP32_TO_BF16(value); + } break; + case LM_GGML_TYPE_F32: + { + ((float *)(tensor->data))[i] = value; + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +float lm_ggml_get_f32_nd(const struct lm_ggml_tensor * tensor, int i0, int i1, int i2, int i3) { + void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]; + switch (tensor->type) { + case LM_GGML_TYPE_I8: + return ((int8_t *) data)[0]; + case LM_GGML_TYPE_I16: + return ((int16_t *) data)[0]; + case LM_GGML_TYPE_I32: + return ((int32_t *) data)[0]; + case LM_GGML_TYPE_F16: + return LM_GGML_FP16_TO_FP32(((lm_ggml_fp16_t *) data)[0]); + case LM_GGML_TYPE_BF16: + return LM_GGML_BF16_TO_FP32(((lm_ggml_bf16_t *) data)[0]); + case LM_GGML_TYPE_F32: + return ((float *) data)[0]; + default: + LM_GGML_ABORT("fatal error"); + } +} + +void lm_ggml_set_f32_nd(const struct lm_ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value) { + void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]; + switch (tensor->type) { + case LM_GGML_TYPE_I8: + { + ((int8_t *)(data))[0] = value; + } break; + case LM_GGML_TYPE_I16: + { + ((int16_t *)(data))[0] = value; + } break; + case LM_GGML_TYPE_I32: + { + ((int32_t *)(data))[0] = value; + } break; + case LM_GGML_TYPE_F16: + { + ((lm_ggml_fp16_t *)(data))[0] = LM_GGML_FP32_TO_FP16(value); + } break; + case LM_GGML_TYPE_BF16: + { + ((lm_ggml_bf16_t *)(data))[0] = LM_GGML_FP32_TO_BF16(value); + } break; + case LM_GGML_TYPE_F32: + { + ((float *)(data))[0] = value; + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +//////////////////////////////////////////////////////////////////////////////// + +// lm_ggml_compute_forward_dup + +static void lm_ggml_compute_forward_dup_same_cont( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + LM_GGML_ASSERT(lm_ggml_nelements(dst) == lm_ggml_nelements(src0)); + LM_GGML_ASSERT(lm_ggml_is_contiguous(dst) && lm_ggml_is_contiguous(src0)); + LM_GGML_ASSERT(src0->type == dst->type); + + const size_t nb0 = lm_ggml_type_size(src0->type); + + const int ith = params->ith; // thread index + const int nth = params->nth; // number of threads + + // parallelize by elements + const int ne = lm_ggml_nelements(dst); + const int dr = (ne + nth - 1) / nth; + const int ie0 = dr * ith; + const int ie1 = MIN(ie0 + dr, ne); + + if (ie0 < ie1) { + memcpy( + ((char *) dst->data + ie0*nb0), + ((char *) src0->data + ie0*nb0), + (ie1 - ie0) * nb0); + } +} + +static void lm_ggml_compute_forward_dup_f16( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + LM_GGML_ASSERT(lm_ggml_nelements(dst) == lm_ggml_nelements(src0)); + + LM_GGML_TENSOR_UNARY_OP_LOCALS + + const int ith = params->ith; // thread index + const int nth = params->nth; // number of threads + + // parallelize by rows + const int nr = ne01; + // number of rows per thread + const int dr = (nr + nth - 1) / nth; + // row range for this thread + const int ir0 = dr * ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (src0->type == dst->type && + ne00 == ne0 && + nb00 == lm_ggml_type_size(src0->type) && nb0 == lm_ggml_type_size(dst->type)) { + // copy by rows + const size_t rs = ne00*nb00; + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ir0; i01 < ir1; i01++) { + memcpy( + ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), + ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), + rs); + } + } + } + return; + } + + // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy + + if (lm_ggml_is_contiguous(dst)) { + if (nb00 == sizeof(lm_ggml_fp16_t)) { + if (dst->type == LM_GGML_TYPE_F16) { + size_t id = 0; + const size_t rs = ne00 * nb00; + char * dst_ptr = (char *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + memcpy(dst_ptr + id, src0_ptr, rs); + id += rs; + } + id += rs * (ne01 - ir1); + } + } + } else if (dst->type == LM_GGML_TYPE_F32) { + size_t id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const lm_ggml_fp16_t * src0_ptr = (lm_ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + for (int i00 = 0; i00 < ne00; i00++) { + dst_ptr[id] = LM_GGML_FP16_TO_FP32(src0_ptr[i00]); + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else if (lm_ggml_get_type_traits_cpu(dst->type)->from_float) { + lm_ggml_from_float_t const quantize_row_q = lm_ggml_get_type_traits_cpu(dst->type)->from_float; + float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; + + size_t id = 0; + size_t rs = nb0 * (ne00 / lm_ggml_blck_size(dst->type)); + char * dst_ptr = (char *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const lm_ggml_fp16_t * src0_ptr = (lm_ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + + for (int i00 = 0; i00 < ne00; i00++) { + src0_f32[i00] = LM_GGML_FP16_TO_FP32(src0_ptr[i00]); + } + + quantize_row_q(src0_f32, dst_ptr + id, ne00); + id += rs; + } + id += rs * (ne01 - ir1); + } + } + } else { + LM_GGML_ABORT("fatal error"); // TODO: implement + } + } else { + //printf("%s: this is not optimal - fix me\n", __func__); + + if (dst->type == LM_GGML_TYPE_F32) { + size_t id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const lm_ggml_fp16_t * src0_ptr = (lm_ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = LM_GGML_FP16_TO_FP32(*src0_ptr); + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else if (dst->type == LM_GGML_TYPE_F16) { + size_t id = 0; + lm_ggml_fp16_t * dst_ptr = (lm_ggml_fp16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const lm_ggml_fp16_t * src0_ptr = (lm_ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = *src0_ptr; + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else { + LM_GGML_ABORT("fatal error"); // TODO: implement + } + } + return; + } + + // dst counters + int64_t i10 = 0; + int64_t i11 = 0; + int64_t i12 = 0; + int64_t i13 = 0; + + if (dst->type == LM_GGML_TYPE_F16) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + memcpy(dst_ptr, src0_ptr, sizeof(lm_ggml_fp16_t)); + + if (++i10 == ne00) { + i10 = 0; + if (++i11 == ne01) { + i11 = 0; + if (++i12 == ne02) { + i12 = 0; + if (++i13 == ne03) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + } else if (dst->type == LM_GGML_TYPE_F32) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + *(float *) dst_ptr = LM_GGML_FP16_TO_FP32(*(const lm_ggml_fp16_t *) src0_ptr); + + if (++i10 == ne0) { + i10 = 0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + } else { + LM_GGML_ABORT("fatal error"); // TODO: implement + } +} + +static void lm_ggml_compute_forward_dup_bf16( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + LM_GGML_ASSERT(lm_ggml_nelements(dst) == lm_ggml_nelements(src0)); + + LM_GGML_TENSOR_UNARY_OP_LOCALS + + const int ith = params->ith; // thread index + const int nth = params->nth; // number of threads + + // parallelize by rows + const int nr = ne01; + // number of rows per thread + const int dr = (nr + nth - 1) / nth; + // row range for this thread + const int ir0 = dr * ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (src0->type == dst->type && + ne00 == ne0 && + nb00 == lm_ggml_type_size(src0->type) && nb0 == lm_ggml_type_size(dst->type)) { + // copy by rows + const size_t rs = ne00*nb00; + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ir0; i01 < ir1; i01++) { + memcpy( + ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), + ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), + rs); + } + } + } + return; + } + + // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy + + if (lm_ggml_is_contiguous(dst)) { + if (nb00 == sizeof(lm_ggml_bf16_t)) { + if (dst->type == LM_GGML_TYPE_BF16) { + size_t id = 0; + const size_t rs = ne00 * nb00; + char * dst_ptr = (char *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + memcpy(dst_ptr + id, src0_ptr, rs); + id += rs; + } + id += rs * (ne01 - ir1); + } + } + } else if (dst->type == LM_GGML_TYPE_F16) { + size_t id = 0; + lm_ggml_fp16_t * dst_ptr = (lm_ggml_fp16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const lm_ggml_bf16_t * src0_ptr = (lm_ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + for (int i00 = 0; i00 < ne00; i00++) { + dst_ptr[id] = LM_GGML_FP32_TO_FP16(LM_GGML_BF16_TO_FP32(src0_ptr[i00])); + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else if (dst->type == LM_GGML_TYPE_F32) { + size_t id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const lm_ggml_bf16_t * src0_ptr = (lm_ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + for (int i00 = 0; i00 < ne00; i00++) { + dst_ptr[id] = LM_GGML_BF16_TO_FP32(src0_ptr[i00]); + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else if (lm_ggml_get_type_traits_cpu(dst->type)->from_float) { + lm_ggml_from_float_t const quantize_row_q = lm_ggml_get_type_traits_cpu(dst->type)->from_float; + float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; + + size_t id = 0; + size_t rs = nb0 * (ne00 / lm_ggml_blck_size(dst->type)); + char * dst_ptr = (char *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const lm_ggml_bf16_t * src0_ptr = (lm_ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + + for (int i00 = 0; i00 < ne00; i00++) { + src0_f32[i00] = LM_GGML_BF16_TO_FP32(src0_ptr[i00]); + } + + quantize_row_q(src0_f32, dst_ptr + id, ne00); + id += rs; + } + id += rs * (ne01 - ir1); + } + } + } else { + LM_GGML_ABORT("fatal error"); // TODO: implement + } + } else { + //printf("%s: this is not optimal - fix me\n", __func__); + + if (dst->type == LM_GGML_TYPE_F32) { + size_t id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const lm_ggml_bf16_t * src0_ptr = (lm_ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = LM_GGML_BF16_TO_FP32(*src0_ptr); + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else if (dst->type == LM_GGML_TYPE_BF16) { + size_t id = 0; + lm_ggml_bf16_t * dst_ptr = (lm_ggml_bf16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const lm_ggml_bf16_t * src0_ptr = (lm_ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = *src0_ptr; + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else if (dst->type == LM_GGML_TYPE_F16) { + size_t id = 0; + lm_ggml_fp16_t * dst_ptr = (lm_ggml_fp16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const lm_ggml_bf16_t * src0_ptr = (lm_ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = LM_GGML_FP32_TO_FP16(LM_GGML_BF16_TO_FP32(*src0_ptr)); + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else { + LM_GGML_ABORT("fatal error"); // TODO: implement + } + } + return; + } + + // dst counters + int64_t i10 = 0; + int64_t i11 = 0; + int64_t i12 = 0; + int64_t i13 = 0; + + if (dst->type == LM_GGML_TYPE_BF16) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + memcpy(dst_ptr, src0_ptr, sizeof(lm_ggml_bf16_t)); + + if (++i10 == ne00) { + i10 = 0; + if (++i11 == ne01) { + i11 = 0; + if (++i12 == ne02) { + i12 = 0; + if (++i13 == ne03) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + } else if (dst->type == LM_GGML_TYPE_F16) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + *(lm_ggml_fp16_t *) dst_ptr = LM_GGML_FP32_TO_FP16(LM_GGML_BF16_TO_FP32(*(const lm_ggml_bf16_t *) src0_ptr)); + + if (++i10 == ne0) { + i10 = 0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + } else if (dst->type == LM_GGML_TYPE_F32) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + *(float *) dst_ptr = LM_GGML_BF16_TO_FP32(*(const lm_ggml_bf16_t *) src0_ptr); + + if (++i10 == ne0) { + i10 = 0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + } else { + LM_GGML_ABORT("fatal error"); // TODO: implement + } +} + +static void lm_ggml_compute_forward_dup_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + LM_GGML_ASSERT(lm_ggml_nelements(dst) == lm_ggml_nelements(src0)); + + LM_GGML_TENSOR_UNARY_OP_LOCALS + + const int ith = params->ith; // thread index + const int nth = params->nth; // number of threads + + // parallelize by rows + const int nr = ne01; + // number of rows per thread + const int dr = (nr + nth - 1) / nth; + // row range for this thread + const int ir0 = dr * ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (src0->type == dst->type && + ne00 == ne0 && + nb00 == lm_ggml_type_size(src0->type) && nb0 == lm_ggml_type_size(dst->type)) { + // copy by rows + const size_t rs = ne00*nb00; + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ir0; i01 < ir1; i01++) { + memcpy( + ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), + ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), + rs); + } + } + } + return; + } + + if (lm_ggml_is_contiguous(dst)) { + // TODO: simplify + if (nb00 == sizeof(float)) { + if (dst->type == LM_GGML_TYPE_F32) { + size_t id = 0; + const size_t rs = ne00 * nb00; + char * dst_ptr = (char *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + memcpy(dst_ptr + id, src0_ptr, rs); + id += rs; + } + id += rs * (ne01 - ir1); + } + } + } else if (lm_ggml_get_type_traits_cpu(dst->type)->from_float) { + lm_ggml_from_float_t const quantize_row_q = lm_ggml_get_type_traits_cpu(dst->type)->from_float; + + size_t id = 0; + size_t rs = nb0 * (ne00 / lm_ggml_blck_size(dst->type)); + char * dst_ptr = (char *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + quantize_row_q(src0_ptr, dst_ptr + id, ne00); + id += rs; + } + id += rs * (ne01 - ir1); + } + } + } else { + LM_GGML_ABORT("fatal error"); // TODO: implement + } + } else { + //printf("%s: this is not optimal - fix me\n", __func__); + + if (dst->type == LM_GGML_TYPE_F32) { + size_t id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = *src0_ptr; + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else if (dst->type == LM_GGML_TYPE_F16) { + size_t id = 0; + lm_ggml_fp16_t * dst_ptr = (lm_ggml_fp16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = LM_GGML_FP32_TO_FP16(*src0_ptr); + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else if (dst->type == LM_GGML_TYPE_BF16) { + size_t id = 0; + lm_ggml_bf16_t * dst_ptr = (lm_ggml_bf16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = LM_GGML_FP32_TO_BF16(*src0_ptr); + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else { + LM_GGML_ABORT("fatal error"); // TODO: implement + } + } + + return; + } + + // dst counters + + int64_t i10 = 0; + int64_t i11 = 0; + int64_t i12 = 0; + int64_t i13 = 0; + + if (dst->type == LM_GGML_TYPE_F32) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + memcpy(dst_ptr, src0_ptr, sizeof(float)); + + if (++i10 == ne0) { + i10 = 0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + } else if (dst->type == LM_GGML_TYPE_F16) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + *(lm_ggml_fp16_t *) dst_ptr = LM_GGML_FP32_TO_FP16(*(const float *) src0_ptr); + + if (++i10 == ne0) { + i10 = 0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + } else if (dst->type == LM_GGML_TYPE_BF16) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + *(lm_ggml_bf16_t *) dst_ptr = LM_GGML_FP32_TO_BF16(*(const float *) src0_ptr); + + if (++i10 == ne0) { + i10 = 0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + } else { + LM_GGML_ABORT("fatal error"); // TODO: implement + } +} + +// A simplified version of lm_ggml_compute_forward_dup that doesn't do float upcasting, and just plain old memcpy. +static void lm_ggml_compute_forward_dup_bytes( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + LM_GGML_ASSERT(lm_ggml_nelements(dst) == lm_ggml_nelements(src0)); + LM_GGML_ASSERT(src0->type == dst->type); + + LM_GGML_TENSOR_UNARY_OP_LOCALS; + + if (lm_ggml_is_contiguous(src0) && lm_ggml_is_contiguous(dst)) { + lm_ggml_compute_forward_dup_same_cont(params, dst); + return; + } + + const size_t type_size = lm_ggml_type_size(src0->type); + const int ith = params->ith; // thread index + const int nth = params->nth; // number of threads + + + // parallelize by rows + const int nr = ne01; + // number of rows per thread + const int dr = (nr + nth - 1) / nth; + // row range for this thread + const int ir0 = dr * ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (src0->type == dst->type && + ne00 == ne0 && + nb00 == type_size && nb0 == type_size) { + // copy by rows + const size_t rs = ne00 * type_size; + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ir0; i01 < ir1; i01++) { + memcpy( + ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), + ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), + rs); + } + } + } + return; + } + + if (lm_ggml_is_contiguous(dst)) { + size_t id = 0; + char * dst_ptr = (char *) dst->data; + const size_t rs = ne00 * type_size; + + if (nb00 == type_size) { + // src0 is contigous on first dimension, copy by rows + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int64_t i01 = ir0; i01 < ir1; i01++) { + const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + memcpy(dst_ptr + id, src0_ptr, rs); + id += rs; + } + id += rs * (ne01 - ir1); + } + } + } else { + //printf("%s: this is not optimal - fix me\n", __func__); + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03; + memcpy(dst_ptr + id, src0_ptr, type_size); + + id += type_size; + } + } + id += rs * (ne01 - ir1); + } + } + } + + return; + } + + // dst counters + + int64_t i10 = 0; + int64_t i11 = 0; + int64_t i12 = 0; + int64_t i13 = 0; + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + memcpy(dst_ptr, src0_ptr, type_size); + + if (++i10 == ne0) { + i10 = 0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } +} + +static void lm_ggml_compute_forward_dup( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + if (src0->type == dst->type) { + lm_ggml_compute_forward_dup_bytes(params, dst); + return; + } + + switch (src0->type) { + case LM_GGML_TYPE_F16: + { + lm_ggml_compute_forward_dup_f16(params, dst); + } break; + case LM_GGML_TYPE_BF16: + { + lm_ggml_compute_forward_dup_bf16(params, dst); + } break; + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_dup_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_add + +static void lm_ggml_compute_forward_add_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_ASSERT(lm_ggml_can_repeat(src1, src0) && lm_ggml_are_same_shape(src0, dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = lm_ggml_nrows(src0); + + LM_GGML_TENSOR_BINARY_OP_LOCALS + + LM_GGML_ASSERT( nb0 == sizeof(float)); + LM_GGML_ASSERT(nb00 == sizeof(float)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (nb10 == sizeof(float)) { + for (int ir = ir0; ir < ir1; ++ir) { + // src1 is broadcastable across src0 and dst in i1, i2, i3 + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + const int64_t nr0 = ne00 / ne10; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + + for (int64_t r = 0; r < nr0; ++r) { +#ifdef LM_GGML_USE_ACCELERATE + vDSP_vadd(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10); +#else + lm_ggml_vec_add_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); +#endif + } + } + } else { + // src1 is not contiguous + for (int ir = ir0; ir < ir1; ++ir) { + // src1 is broadcastable across src0 and dst in i1, i2, i3 + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + + for (int64_t i0 = 0; i0 < ne0; ++i0) { + const int64_t i10 = i0 % ne10; + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10); + + dst_ptr[i0] = src0_ptr[i0] + *src1_ptr; + } + } + } +} + +static void lm_ggml_compute_forward_add_f16_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, src1) && lm_ggml_are_same_shape(src0, dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = lm_ggml_nrows(src0); + + LM_GGML_TENSOR_BINARY_OP_LOCALS + + LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F16); + LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); + + if (dst->type == LM_GGML_TYPE_F32) { + LM_GGML_ASSERT( nb0 == sizeof(float)); + } + else { + LM_GGML_ASSERT(dst->type == LM_GGML_TYPE_F16); + LM_GGML_ASSERT( nb0 == sizeof(lm_ggml_fp16_t)); + } + + LM_GGML_ASSERT(nb00 == sizeof(lm_ggml_fp16_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (nb10 == sizeof(float)) { + if (dst->type == LM_GGML_TYPE_F16) { + for (int ir = ir0; ir < ir1; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + lm_ggml_fp16_t * dst_ptr = (lm_ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + lm_ggml_fp16_t * src0_ptr = (lm_ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); + + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = LM_GGML_FP32_TO_FP16(LM_GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]); + } + } + } else { + for (int ir = ir0; ir < ir1; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + lm_ggml_fp16_t * src0_ptr = (lm_ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); + + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = LM_GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]; + } + } + } + } + else { + // src1 is not contiguous + LM_GGML_ABORT("fatal error"); + } +} + +static void lm_ggml_compute_forward_add_bf16_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, src1) && lm_ggml_are_same_shape(src0, dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = lm_ggml_nrows(src0); + + LM_GGML_TENSOR_BINARY_OP_LOCALS + + LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_BF16); + LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); + + if (dst->type == LM_GGML_TYPE_F32) { + LM_GGML_ASSERT( nb0 == sizeof(float)); + } + else { + LM_GGML_ASSERT(dst->type == LM_GGML_TYPE_BF16); + LM_GGML_ASSERT( nb0 == sizeof(lm_ggml_bf16_t)); + } + + LM_GGML_ASSERT(nb00 == sizeof(lm_ggml_bf16_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (nb10 == sizeof(float)) { + if (dst->type == LM_GGML_TYPE_BF16) { + for (int ir = ir0; ir < ir1; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + lm_ggml_bf16_t * dst_ptr = (lm_ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + lm_ggml_bf16_t * src0_ptr = (lm_ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); + + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = LM_GGML_FP32_TO_BF16(LM_GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i]); + } + } + } else { + for (int ir = ir0; ir < ir1; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + lm_ggml_bf16_t * src0_ptr = (lm_ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); + + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = LM_GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i]; + } + } + } + } + else { + // src1 is not contiguous + LM_GGML_ABORT("fatal error"); + } +} + +static void lm_ggml_compute_forward_add_f16_f16( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, src1) && lm_ggml_are_same_shape(src0, dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = lm_ggml_nrows(src0); + + LM_GGML_TENSOR_BINARY_OP_LOCALS + + LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F16); + LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F16); + LM_GGML_ASSERT(dst->type == LM_GGML_TYPE_F16); + + LM_GGML_ASSERT( nb0 == sizeof(lm_ggml_fp16_t)); + LM_GGML_ASSERT(nb00 == sizeof(lm_ggml_fp16_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (nb10 == sizeof(lm_ggml_fp16_t)) { + for (int ir = ir0; ir < ir1; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + lm_ggml_fp16_t * dst_ptr = (lm_ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + lm_ggml_fp16_t * src0_ptr = (lm_ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + lm_ggml_fp16_t * src1_ptr = (lm_ggml_fp16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); + + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = LM_GGML_FP32_TO_FP16(LM_GGML_FP16_TO_FP32(src0_ptr[i]) + LM_GGML_FP16_TO_FP32(src1_ptr[i])); + } + } + } + else { + // src1 is not contiguous + LM_GGML_ABORT("fatal error"); + } +} + +static void lm_ggml_compute_forward_add_bf16_bf16( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, src1) && lm_ggml_are_same_shape(src0, dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = lm_ggml_nrows(src0); + + LM_GGML_TENSOR_BINARY_OP_LOCALS + + LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_BF16); + LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_BF16); + LM_GGML_ASSERT(dst->type == LM_GGML_TYPE_BF16); + + LM_GGML_ASSERT( nb0 == sizeof(lm_ggml_bf16_t)); + LM_GGML_ASSERT(nb00 == sizeof(lm_ggml_bf16_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (nb10 == sizeof(lm_ggml_bf16_t)) { + for (int ir = ir0; ir < ir1; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + lm_ggml_bf16_t * dst_ptr = (lm_ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + lm_ggml_bf16_t * src0_ptr = (lm_ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + lm_ggml_bf16_t * src1_ptr = (lm_ggml_bf16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); + + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = LM_GGML_FP32_TO_BF16(LM_GGML_BF16_TO_FP32(src0_ptr[i]) + LM_GGML_BF16_TO_FP32(src1_ptr[i])); + } + } + } + else { + // src1 is not contiguous + LM_GGML_ABORT("fatal error"); + } +} + +static void lm_ggml_compute_forward_add_q_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, src1) && lm_ggml_are_same_shape(src0, dst)); + + const int nr = lm_ggml_nrows(src0); + + LM_GGML_TENSOR_BINARY_OP_LOCALS + + const int ith = params->ith; + const int nth = params->nth; + + const enum lm_ggml_type type = src0->type; + const enum lm_ggml_type dtype = dst->type; + lm_ggml_to_float_t const dequantize_row_q = lm_ggml_get_type_traits(type)->to_float; + lm_ggml_from_float_t const quantize_row_q = lm_ggml_get_type_traits_cpu(dtype)->from_float; + + // we don't support permuted src0 or src1 + LM_GGML_ASSERT(nb00 == lm_ggml_type_size(type)); + LM_GGML_ASSERT(nb10 == sizeof(float)); + + // dst cannot be transposed or permuted + LM_GGML_ASSERT(nb0 <= nb1); + LM_GGML_ASSERT(nb1 <= nb2); + LM_GGML_ASSERT(nb2 <= nb3); + + LM_GGML_ASSERT(lm_ggml_is_quantized(src0->type)); + LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 indices + const int i03 = ir/(ne02*ne01); + const int i02 = (ir - i03*ne02*ne01)/ne01; + const int i01 = (ir - i03*ne02*ne01 - i02*ne01); + + // src1 and dst are same shape as src0 => same indices + const int i13 = i03; + const int i12 = i02; + const int i11 = i01; + + const int i3 = i03; + const int i2 = i02; + const int i1 = i01; + + void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); + float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13)); + void * dst_row = (void *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + + assert(ne00 % 32 == 0); + + // unquantize row from src0 to temp buffer + dequantize_row_q(src0_row, wdata, ne00); + // add src1 + lm_ggml_vec_acc_f32(ne00, wdata, src1_row); + // quantize row to dst + if (quantize_row_q != NULL) { + quantize_row_q(wdata, dst_row, ne00); + } else { + memcpy(dst_row, wdata, ne0*nb0); + } + } +} + +static void lm_ggml_compute_forward_add( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + if (src1->type == LM_GGML_TYPE_F32) { + lm_ggml_compute_forward_add_f32(params, dst); + } + else { + LM_GGML_ABORT("fatal error"); + } + } break; + case LM_GGML_TYPE_F16: + { + if (src1->type == LM_GGML_TYPE_F16) { + lm_ggml_compute_forward_add_f16_f16(params, dst); + } + else if (src1->type == LM_GGML_TYPE_F32) { + lm_ggml_compute_forward_add_f16_f32(params, dst); + } + else { + LM_GGML_ABORT("fatal error"); + } + } break; + case LM_GGML_TYPE_BF16: + { + if (src1->type == LM_GGML_TYPE_BF16) { + lm_ggml_compute_forward_add_bf16_bf16(params, dst); + } + else if (src1->type == LM_GGML_TYPE_F32) { + lm_ggml_compute_forward_add_bf16_f32(params, dst); + } + else { + LM_GGML_ABORT("fatal error"); + } + } break; + case LM_GGML_TYPE_Q4_0: + case LM_GGML_TYPE_Q4_1: + case LM_GGML_TYPE_Q5_0: + case LM_GGML_TYPE_Q5_1: + case LM_GGML_TYPE_Q8_0: + case LM_GGML_TYPE_Q2_K: + case LM_GGML_TYPE_Q3_K: + case LM_GGML_TYPE_Q4_K: + case LM_GGML_TYPE_Q5_K: + case LM_GGML_TYPE_Q6_K: + case LM_GGML_TYPE_TQ1_0: + case LM_GGML_TYPE_TQ2_0: + case LM_GGML_TYPE_IQ2_XXS: + case LM_GGML_TYPE_IQ2_XS: + case LM_GGML_TYPE_IQ3_XXS: + case LM_GGML_TYPE_IQ1_S: + case LM_GGML_TYPE_IQ1_M: + case LM_GGML_TYPE_IQ4_NL: + case LM_GGML_TYPE_IQ4_XS: + case LM_GGML_TYPE_IQ3_S: + case LM_GGML_TYPE_IQ2_S: + case LM_GGML_TYPE_Q4_0_4_4: + case LM_GGML_TYPE_Q4_0_4_8: + case LM_GGML_TYPE_Q4_0_8_8: + { + lm_ggml_compute_forward_add_q_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_add1 + +static void lm_ggml_compute_forward_add1_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); + LM_GGML_ASSERT(lm_ggml_is_scalar(src1)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = lm_ggml_nrows(src0); + + LM_GGML_TENSOR_UNARY_OP_LOCALS + + LM_GGML_ASSERT( nb0 == sizeof(float)); + LM_GGML_ASSERT(nb00 == sizeof(float)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + +#ifdef LM_GGML_USE_ACCELERATE + UNUSED(lm_ggml_vec_add1_f32); + + vDSP_vadd( + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1, + (float *) ((char *) src1->data), 0, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1, + ne0); +#else + lm_ggml_vec_add1_f32(ne0, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), + *(float *) src1->data); +#endif + } +} + +static void lm_ggml_compute_forward_add1_f16_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); + LM_GGML_ASSERT(lm_ggml_is_scalar(src1)); + + // scalar to add + const float v = *(float *) src1->data; + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = lm_ggml_nrows(src0); + + LM_GGML_TENSOR_UNARY_OP_LOCALS + + LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F16); + LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT(dst->type == LM_GGML_TYPE_F16); + + LM_GGML_ASSERT( nb0 == sizeof(lm_ggml_fp16_t)); + LM_GGML_ASSERT(nb00 == sizeof(lm_ggml_fp16_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + lm_ggml_fp16_t * dst_ptr = (lm_ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); + lm_ggml_fp16_t * src0_ptr = (lm_ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = LM_GGML_FP32_TO_FP16(LM_GGML_FP16_TO_FP32(src0_ptr[i]) + v); + } + } +} + +static void lm_ggml_compute_forward_add1_f16_f16( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); + LM_GGML_ASSERT(lm_ggml_is_scalar(src1)); + + // scalar to add + const float v = LM_GGML_FP16_TO_FP32(*(lm_ggml_fp16_t *) src1->data); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = lm_ggml_nrows(src0); + + LM_GGML_TENSOR_UNARY_OP_LOCALS + + LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F16); + LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F16); + LM_GGML_ASSERT(dst->type == LM_GGML_TYPE_F16); + + LM_GGML_ASSERT( nb0 == sizeof(lm_ggml_fp16_t)); + LM_GGML_ASSERT(nb00 == sizeof(lm_ggml_fp16_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + lm_ggml_fp16_t * dst_ptr = (lm_ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); + lm_ggml_fp16_t * src0_ptr = (lm_ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = LM_GGML_FP32_TO_FP16(LM_GGML_FP16_TO_FP32(src0_ptr[i]) + v); + } + } +} + +static void lm_ggml_compute_forward_add1_q_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); + LM_GGML_ASSERT(lm_ggml_is_scalar(src1)); + + // scalar to add + const float v = *(float *) src1->data; + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = lm_ggml_nrows(src0); + + LM_GGML_TENSOR_UNARY_OP_LOCALS + + const enum lm_ggml_type type = src0->type; + lm_ggml_to_float_t const dequantize_row_q = lm_ggml_get_type_traits(type)->to_float; + lm_ggml_from_float_t const quantize_row_q = lm_ggml_get_type_traits_cpu(type)->from_float; + + // we don't support permuted src0 + LM_GGML_ASSERT(nb00 == lm_ggml_type_size(type)); + + // dst cannot be transposed or permuted + LM_GGML_ASSERT(nb0 <= nb1); + LM_GGML_ASSERT(nb1 <= nb2); + LM_GGML_ASSERT(nb2 <= nb3); + + LM_GGML_ASSERT(lm_ggml_is_quantized(src0->type)); + LM_GGML_ASSERT(dst->type == src0->type); + LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith; + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + void * src0_row = (void *) ((char *) src0->data + (i1*nb01 + i2*nb02 + i3*nb03)); + void * dst_row = (void *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb0 )); + + assert(ne0 % 32 == 0); + + // unquantize row from src0 to temp buffer + dequantize_row_q(src0_row, wdata, ne0); + // add src1 + lm_ggml_vec_acc1_f32(ne0, wdata, v); + // quantize row to dst + quantize_row_q(wdata, dst_row, ne0); + } +} + +static void lm_ggml_compute_forward_add1_bf16_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); + LM_GGML_ASSERT(lm_ggml_is_scalar(src1)); + + // scalar to add + const float v = *(float *) src1->data; + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = lm_ggml_nrows(src0); + + LM_GGML_TENSOR_UNARY_OP_LOCALS + + LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_BF16); + LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT(dst->type == LM_GGML_TYPE_BF16); + + LM_GGML_ASSERT( nb0 == sizeof(lm_ggml_bf16_t)); + LM_GGML_ASSERT(nb00 == sizeof(lm_ggml_bf16_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + lm_ggml_bf16_t * dst_ptr = (lm_ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); + lm_ggml_bf16_t * src0_ptr = (lm_ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = LM_GGML_FP32_TO_BF16(LM_GGML_BF16_TO_FP32(src0_ptr[i]) + v); + } + } +} + +static void lm_ggml_compute_forward_add1_bf16_bf16( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); + LM_GGML_ASSERT(lm_ggml_is_scalar(src1)); + + // scalar to add + const float v = LM_GGML_BF16_TO_FP32(*(lm_ggml_bf16_t *) src1->data); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = lm_ggml_nrows(src0); + + LM_GGML_TENSOR_UNARY_OP_LOCALS + + LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_BF16); + LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_BF16); + LM_GGML_ASSERT(dst->type == LM_GGML_TYPE_BF16); + + LM_GGML_ASSERT( nb0 == sizeof(lm_ggml_bf16_t)); + LM_GGML_ASSERT(nb00 == sizeof(lm_ggml_bf16_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + lm_ggml_bf16_t * dst_ptr = (lm_ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); + lm_ggml_bf16_t * src0_ptr = (lm_ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = LM_GGML_FP32_TO_BF16(LM_GGML_BF16_TO_FP32(src0_ptr[i]) + v); + } + } +} + +static void lm_ggml_compute_forward_add1( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_add1_f32(params, dst); + } break; + case LM_GGML_TYPE_F16: + { + if (src1->type == LM_GGML_TYPE_F16) { + lm_ggml_compute_forward_add1_f16_f16(params, dst); + } + else if (src1->type == LM_GGML_TYPE_F32) { + lm_ggml_compute_forward_add1_f16_f32(params, dst); + } + else { + LM_GGML_ABORT("fatal error"); + } + } break; + case LM_GGML_TYPE_BF16: + { + if (src1->type == LM_GGML_TYPE_BF16) { + lm_ggml_compute_forward_add1_bf16_bf16(params, dst); + } + else if (src1->type == LM_GGML_TYPE_F32) { + lm_ggml_compute_forward_add1_bf16_f32(params, dst); + } + else { + LM_GGML_ABORT("fatal error"); + } + } break; + case LM_GGML_TYPE_Q4_0: + case LM_GGML_TYPE_Q4_1: + case LM_GGML_TYPE_Q5_0: + case LM_GGML_TYPE_Q5_1: + case LM_GGML_TYPE_Q8_0: + case LM_GGML_TYPE_Q8_1: + case LM_GGML_TYPE_Q2_K: + case LM_GGML_TYPE_Q3_K: + case LM_GGML_TYPE_Q4_K: + case LM_GGML_TYPE_Q5_K: + case LM_GGML_TYPE_Q6_K: + case LM_GGML_TYPE_TQ1_0: + case LM_GGML_TYPE_TQ2_0: + case LM_GGML_TYPE_IQ2_XXS: + case LM_GGML_TYPE_IQ2_XS: + case LM_GGML_TYPE_IQ3_XXS: + case LM_GGML_TYPE_IQ1_S: + case LM_GGML_TYPE_IQ1_M: + case LM_GGML_TYPE_IQ4_NL: + case LM_GGML_TYPE_IQ4_XS: + case LM_GGML_TYPE_IQ3_S: + case LM_GGML_TYPE_IQ2_S: + case LM_GGML_TYPE_Q4_0_4_4: + case LM_GGML_TYPE_Q4_0_4_8: + case LM_GGML_TYPE_Q4_0_8_8: + { + lm_ggml_compute_forward_add1_q_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_acc + +static void lm_ggml_compute_forward_acc_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); + LM_GGML_ASSERT(lm_ggml_is_contiguous(dst) && lm_ggml_is_contiguous(src0)); + + // view src0 and dst with these strides and data offset inbytes during acc + // nb0 is implicitly element_size because src0 and dst are contiguous + size_t nb1 = ((int32_t *) dst->op_params)[0]; + size_t nb2 = ((int32_t *) dst->op_params)[1]; + size_t nb3 = ((int32_t *) dst->op_params)[2]; + size_t offset = ((int32_t *) dst->op_params)[3]; + bool inplace = (bool) ((int32_t *) dst->op_params)[4]; + + if (!inplace) { + if (params->ith == 0) { + // memcpy needs to be synchronized across threads to avoid race conditions. + // => do it in INIT phase + memcpy( + ((char *) dst->data), + ((char *) src0->data), + lm_ggml_nbytes(dst)); + } + lm_ggml_barrier(params->threadpool); + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = lm_ggml_nrows(src1); + const int nc = src1->ne[0]; + + LM_GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) + LM_GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) + + // src0 and dst as viewed during acc + const size_t nb0 = lm_ggml_element_size(src0); + + const size_t nb00 = nb0; + const size_t nb01 = nb1; + const size_t nb02 = nb2; + const size_t nb03 = nb3; + + LM_GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb0 + (ne11 == 0 ? 0 : ne11-1)*nb1 + (ne12 == 0 ? 0 : ne12-1)*nb2 + (ne13 == 0 ? 0 : ne13-1)*nb3 < lm_ggml_nbytes(dst)); + LM_GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb00 + (ne11 == 0 ? 0 : ne11-1)*nb01 + (ne12 == 0 ? 0 : ne12-1)*nb02 + (ne13 == 0 ? 0 : ne13-1)*nb03 < lm_ggml_nbytes(src0)); + + LM_GGML_ASSERT(nb10 == sizeof(float)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are viewed with shape of src1 and offset + // => same indices + const int i3 = ir/(ne12*ne11); + const int i2 = (ir - i3*ne12*ne11)/ne11; + const int i1 = (ir - i3*ne12*ne11 - i2*ne11); + +#ifdef LM_GGML_USE_ACCELERATE + vDSP_vadd( + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset), 1, + (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset), 1, nc); +#else + lm_ggml_vec_add_f32(nc, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset), + (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11)); +#endif + } +} + +static void lm_ggml_compute_forward_acc( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_acc_f32(params, dst); + } break; + case LM_GGML_TYPE_F16: + case LM_GGML_TYPE_BF16: + case LM_GGML_TYPE_Q4_0: + case LM_GGML_TYPE_Q4_1: + case LM_GGML_TYPE_Q5_0: + case LM_GGML_TYPE_Q5_1: + case LM_GGML_TYPE_Q8_0: + case LM_GGML_TYPE_Q8_1: + case LM_GGML_TYPE_Q2_K: + case LM_GGML_TYPE_Q3_K: + case LM_GGML_TYPE_Q4_K: + case LM_GGML_TYPE_Q5_K: + case LM_GGML_TYPE_Q6_K: + case LM_GGML_TYPE_TQ1_0: + case LM_GGML_TYPE_TQ2_0: + case LM_GGML_TYPE_IQ2_XXS: + case LM_GGML_TYPE_IQ2_XS: + case LM_GGML_TYPE_IQ3_XXS: + case LM_GGML_TYPE_IQ1_S: + case LM_GGML_TYPE_IQ1_M: + case LM_GGML_TYPE_IQ4_NL: + case LM_GGML_TYPE_IQ4_XS: + case LM_GGML_TYPE_IQ3_S: + case LM_GGML_TYPE_IQ2_S: + case LM_GGML_TYPE_Q4_0_4_4: + case LM_GGML_TYPE_Q4_0_4_8: + case LM_GGML_TYPE_Q4_0_8_8: + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_sub + +static void lm_ggml_compute_forward_sub_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + assert(lm_ggml_can_repeat(src1, src0) && lm_ggml_are_same_shape(src0, dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = lm_ggml_nrows(src0); + + LM_GGML_TENSOR_BINARY_OP_LOCALS + + LM_GGML_ASSERT( nb0 == sizeof(float)); + LM_GGML_ASSERT(nb00 == sizeof(float)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (nb10 == sizeof(float)) { + for (int ir = ir0; ir < ir1; ++ir) { + // src1 is broadcastable across src0 and dst in i1, i2, i3 + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + const int64_t nr0 = ne00 / ne10; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + + for (int64_t r = 0; r < nr0; ++r) { +#ifdef LM_GGML_USE_ACCELERATE + vDSP_vsub(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10); +#else + lm_ggml_vec_sub_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); +#endif + } + } + } else { + // src1 is not contiguous + for (int ir = ir0; ir < ir1; ++ir) { + // src1 is broadcastable across src0 and dst in i1, i2, i3 + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + + for (int64_t i0 = 0; i0 < ne0; ++i0) { + const int64_t i10 = i0 % ne10; + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10); + + dst_ptr[i0] = src0_ptr[i0] - *src1_ptr; + } + } + } +} + +static void lm_ggml_compute_forward_sub( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_sub_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_mul + +static void lm_ggml_compute_forward_mul_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_ASSERT(lm_ggml_can_repeat(src1, src0) && lm_ggml_are_same_shape(src0, dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t nr = lm_ggml_nrows(src0); + + LM_GGML_TENSOR_BINARY_OP_LOCALS + + LM_GGML_ASSERT( nb0 == sizeof(float)); + LM_GGML_ASSERT(nb00 == sizeof(float)); + + if (nb10 == sizeof(float)) { + for (int64_t ir = ith; ir < nr; ir += nth) { + // src0 and dst are same shape => same indices + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + const int64_t nr0 = ne00 / ne10; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + + for (int64_t r = 0 ; r < nr0; ++r) { +#ifdef LM_GGML_USE_ACCELERATE + UNUSED(lm_ggml_vec_mul_f32); + + vDSP_vmul(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10); +#else + lm_ggml_vec_mul_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); +#endif + } + } + } else { + // src1 is not contiguous + for (int64_t ir = ith; ir < nr; ir += nth) { + // src0 and dst are same shape => same indices + // src1 is broadcastable across src0 and dst in i1, i2, i3 + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + + for (int64_t i0 = 0; i0 < ne00; ++i0) { + const int64_t i10 = i0 % ne10; + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10); + + dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr); + } + } + } +} + +static void lm_ggml_compute_forward_mul( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32 && "only f32 src1 supported for now"); + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_mul_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_div + +static void lm_ggml_compute_forward_div_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_ASSERT(lm_ggml_can_repeat(src1, src0) && lm_ggml_are_same_shape(src0, dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t nr = lm_ggml_nrows(src0); + + LM_GGML_TENSOR_BINARY_OP_LOCALS + + LM_GGML_ASSERT( nb0 == sizeof(float)); + LM_GGML_ASSERT(nb00 == sizeof(float)); + + if (nb10 == sizeof(float)) { + for (int64_t ir = ith; ir < nr; ir += nth) { + // src0 and dst are same shape => same indices + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + const int64_t nr0 = ne00 / ne10; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + + for (int64_t r = 0; r < nr0; ++r) { +#ifdef LM_GGML_USE_ACCELERATE + UNUSED(lm_ggml_vec_div_f32); + + vDSP_vdiv(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10); +#else + lm_ggml_vec_div_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); +#endif + } + } + } else { + // src1 is not contiguous + for (int64_t ir = ith; ir < nr; ir += nth) { + // src0 and dst are same shape => same indices + // src1 is broadcastable across src0 and dst in i1, i2, i3 + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + + for (int64_t i0 = 0; i0 < ne00; ++i0) { + const int64_t i10 = i0 % ne10; + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10); + + dst_ptr[i0] = src0_ptr[i0] / (*src1_ptr); + } + } + } +} + +static void lm_ggml_compute_forward_div( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_div_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_sqr + +static void lm_ggml_compute_forward_sqr_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(lm_ggml_are_same_shape(src0, dst)); + + const int n = lm_ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert( dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + lm_ggml_vec_sqr_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void lm_ggml_compute_forward_sqr( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_sqr_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_sqrt + +static void lm_ggml_compute_forward_sqrt_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(lm_ggml_are_same_shape(src0, dst)); + + const int n = lm_ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert( dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + lm_ggml_vec_sqrt_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void lm_ggml_compute_forward_sqrt( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_sqrt_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_log + +static void lm_ggml_compute_forward_log_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); + + const int n = lm_ggml_nrows(src0); + const int nc = src0->ne[0]; + + LM_GGML_ASSERT( dst->nb[0] == sizeof(float)); + LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + lm_ggml_vec_log_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void lm_ggml_compute_forward_log( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_log_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_sin + +static void lm_ggml_compute_forward_sin_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); + + const int n = lm_ggml_nrows(src0); + const int nc = src0->ne[0]; + + LM_GGML_ASSERT( dst->nb[0] == sizeof(float)); + LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + lm_ggml_vec_sin_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void lm_ggml_compute_forward_sin( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_sin_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_cos + +static void lm_ggml_compute_forward_cos_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); + + const int n = lm_ggml_nrows(src0); + const int nc = src0->ne[0]; + + LM_GGML_ASSERT( dst->nb[0] == sizeof(float)); + LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + lm_ggml_vec_cos_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void lm_ggml_compute_forward_cos( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_cos_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_sum + +static void lm_ggml_compute_forward_sum_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(lm_ggml_is_scalar(dst)); + assert(src0->nb[0] == sizeof(float)); + + LM_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + LM_GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + + lm_ggml_float sum = 0; + lm_ggml_float row_sum = 0; + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + lm_ggml_vec_sum_f32_ggf(ne00, + &row_sum, + (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03)); + sum += row_sum; + } + } + } + ((float *) dst->data)[0] = sum; +} + +static void lm_ggml_compute_forward_sum_f16( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(lm_ggml_is_scalar(dst)); + + assert(src0->nb[0] == sizeof(lm_ggml_fp16_t)); + + LM_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + LM_GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + + float sum = 0; + float row_sum = 0; + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + lm_ggml_vec_sum_f16_ggf(ne00, + &row_sum, + (lm_ggml_fp16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03)); + sum += row_sum; + } + } + } + ((lm_ggml_fp16_t *) dst->data)[0] = LM_GGML_FP32_TO_FP16(sum); +} + +static void lm_ggml_compute_forward_sum_bf16( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(lm_ggml_is_scalar(dst)); + + assert(src0->nb[0] == sizeof(lm_ggml_bf16_t)); + + LM_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + LM_GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + + float sum = 0; + float row_sum = 0; + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + lm_ggml_vec_sum_bf16_ggf(ne00, + &row_sum, + (lm_ggml_bf16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03)); + sum += row_sum; + } + } + } + ((lm_ggml_bf16_t *) dst->data)[0] = LM_GGML_FP32_TO_BF16(sum); +} + +static void lm_ggml_compute_forward_sum( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_sum_f32(params, dst); + } break; + case LM_GGML_TYPE_F16: + { + lm_ggml_compute_forward_sum_f16(params, dst); + } break; + case LM_GGML_TYPE_BF16: + { + lm_ggml_compute_forward_sum_bf16(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_sum_rows + +static void lm_ggml_compute_forward_sum_rows_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); + LM_GGML_ASSERT(dst->nb[0] == sizeof(float)); + + LM_GGML_TENSOR_UNARY_OP_LOCALS + + LM_GGML_ASSERT(ne0 == 1); + LM_GGML_ASSERT(ne1 == ne01); + LM_GGML_ASSERT(ne2 == ne02); + LM_GGML_ASSERT(ne3 == ne03); + + for (int64_t i3 = 0; i3 < ne03; i3++) { + for (int64_t i2 = 0; i2 < ne02; i2++) { + for (int64_t i1 = 0; i1 < ne01; i1++) { + float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03); + float * dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3); + float row_sum = 0; + lm_ggml_vec_sum_f32(ne00, &row_sum, src_row); + dst_row[0] = row_sum; + } + } + } +} + +static void lm_ggml_compute_forward_sum_rows( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_sum_rows_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_mean + +static void lm_ggml_compute_forward_mean_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(src0->nb[0] == sizeof(float)); + + LM_GGML_TENSOR_UNARY_OP_LOCALS + + assert(ne0 == 1); + assert(ne1 == ne01); + assert(ne2 == ne02); + assert(ne3 == ne03); + + UNUSED(ne0); + UNUSED(ne1); + UNUSED(ne2); + UNUSED(ne3); + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + lm_ggml_vec_sum_f32(ne00, + (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), + (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03)); + + *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) /= (float) ne00; + } + } + } +} + +static void lm_ggml_compute_forward_mean( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_mean_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_argmax + +static void lm_ggml_compute_forward_argmax_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(src0->nb[0] == sizeof(float)); + assert(dst->nb[0] == sizeof(float)); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + + const size_t nb01 = src0->nb[1]; + const size_t nb0 = dst->nb[0]; + + for (int64_t i1 = 0; i1 < ne01; i1++) { + float * src = (float *) ((char *) src0->data + i1*nb01); + int32_t * dst_ = (int32_t *) ((char *) dst->data + i1*nb0); + int v = 0; + lm_ggml_vec_argmax_f32(ne00, &v, src); + dst_[0] = v; + } +} + +static void lm_ggml_compute_forward_argmax( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_argmax_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_count_equal + +static void lm_ggml_compute_forward_count_equal_i32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_TENSOR_BINARY_OP_LOCALS; + + LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_I32); + LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_I32); + LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, src1)); + LM_GGML_ASSERT(lm_ggml_is_scalar(dst)); + LM_GGML_ASSERT(dst->type == LM_GGML_TYPE_I64); + + const int64_t nr = lm_ggml_nrows(src0); + + const int ith = params->ith; + const int nth = params->nth; + + int64_t * sums = (int64_t *) params->wdata; + int64_t sum_thread = 0; + + // rows per thread + const int64_t dr = (nr + nth - 1)/nth; + + // row range for this thread + const int64_t ir0 = dr*ith; + const int64_t ir1 = MIN(ir0 + dr, nr); + + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir / (ne02*ne01); + const int64_t i02 = (ir - i03*ne03) / ne01; + const int64_t i01 = ir - i03*ne03 - i02*ne02; + + const char * data0 = (const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01; + const char * data1 = (const char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11; + + for (int64_t i00 = 0; i00 < ne00; ++i00) { + const int32_t val0 = *((const int32_t *) (data0 + i00*nb00)); + const int32_t val1 = *((const int32_t *) (data1 + i00*nb10)); + + sum_thread += val0 == val1; + } + } + if (ith != 0) { + sums[ith] = sum_thread; + } + lm_ggml_barrier(params->threadpool); + + if (ith != 0) { + return; + } + + for (int ith_other = 1; ith_other < nth; ++ith_other) { + sum_thread += sums[ith_other]; + } + *((int64_t *) dst->data) = sum_thread; +} + +static void lm_ggml_compute_forward_count_equal( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_I32: + { + lm_ggml_compute_forward_count_equal_i32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_repeat + +static void lm_ggml_compute_forward_repeat_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + LM_GGML_ASSERT(lm_ggml_can_repeat(src0, dst)); + + LM_GGML_TENSOR_UNARY_OP_LOCALS + + // guaranteed to be an integer due to the check in lm_ggml_can_repeat + const int nr0 = (int)(ne0/ne00); + const int nr1 = (int)(ne1/ne01); + const int nr2 = (int)(ne2/ne02); + const int nr3 = (int)(ne3/ne03); + + // TODO: support for transposed / permuted tensors + LM_GGML_ASSERT(nb0 == sizeof(float)); + LM_GGML_ASSERT(nb00 == sizeof(float)); + + // TODO: maybe this is not optimal? + for (int i3 = 0; i3 < nr3; i3++) { + for (int k3 = 0; k3 < ne03; k3++) { + for (int i2 = 0; i2 < nr2; i2++) { + for (int k2 = 0; k2 < ne02; k2++) { + for (int i1 = 0; i1 < nr1; i1++) { + for (int k1 = 0; k1 < ne01; k1++) { + for (int i0 = 0; i0 < nr0; i0++) { + lm_ggml_vec_cpy_f32(ne00, + (float *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0), + (float *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01)); + } + } + } + } + } + } + } +} + +static void lm_ggml_compute_forward_repeat_f16( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + LM_GGML_ASSERT(lm_ggml_can_repeat(src0, dst)); + + LM_GGML_TENSOR_UNARY_OP_LOCALS + + // guaranteed to be an integer due to the check in lm_ggml_can_repeat + const int nr0 = (int)(ne0/ne00); + const int nr1 = (int)(ne1/ne01); + const int nr2 = (int)(ne2/ne02); + const int nr3 = (int)(ne3/ne03); + + // TODO: support for transposed / permuted tensors + LM_GGML_ASSERT(nb0 == sizeof(lm_ggml_fp16_t)); + LM_GGML_ASSERT(nb00 == sizeof(lm_ggml_fp16_t)); + + // TODO: maybe this is not optimal? + for (int i3 = 0; i3 < nr3; i3++) { + for (int k3 = 0; k3 < ne03; k3++) { + for (int i2 = 0; i2 < nr2; i2++) { + for (int k2 = 0; k2 < ne02; k2++) { + for (int i1 = 0; i1 < nr1; i1++) { + for (int k1 = 0; k1 < ne01; k1++) { + for (int i0 = 0; i0 < nr0; i0++) { + lm_ggml_fp16_t * y = (lm_ggml_fp16_t *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0); + lm_ggml_fp16_t * x = (lm_ggml_fp16_t *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01); + // lm_ggml_vec_cpy_f16(ne00, y, x) + for (int i = 0; i < ne00; ++i) { + y[i] = x[i]; + } + } + } + } + } + } + } + } +} + +static void lm_ggml_compute_forward_repeat( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F16: + case LM_GGML_TYPE_BF16: + case LM_GGML_TYPE_I16: + { + lm_ggml_compute_forward_repeat_f16(params, dst); + } break; + case LM_GGML_TYPE_F32: + case LM_GGML_TYPE_I32: + { + lm_ggml_compute_forward_repeat_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_repeat_back + +static void lm_ggml_compute_forward_repeat_back_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + LM_GGML_ASSERT(lm_ggml_can_repeat(dst, src0)); + + LM_GGML_TENSOR_UNARY_OP_LOCALS + + // guaranteed to be an integer due to the check in lm_ggml_can_repeat + const int nr0 = (int)(ne00/ne0); + const int nr1 = (int)(ne01/ne1); + const int nr2 = (int)(ne02/ne2); + const int nr3 = (int)(ne03/ne3); + + // TODO: support for transposed / permuted tensors + LM_GGML_ASSERT(nb0 == sizeof(float)); + LM_GGML_ASSERT(nb00 == sizeof(float)); + + if (lm_ggml_is_contiguous(dst)) { + lm_ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); + } else { + for (int k3 = 0; k3 < ne3; k3++) { + for (int k2 = 0; k2 < ne2; k2++) { + for (int k1 = 0; k1 < ne1; k1++) { + lm_ggml_vec_set_f32(ne0, + (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3), + 0); + } + } + } + } + + // TODO: maybe this is not optimal? + for (int i3 = 0; i3 < nr3; i3++) { + for (int k3 = 0; k3 < ne3; k3++) { + for (int i2 = 0; i2 < nr2; i2++) { + for (int k2 = 0; k2 < ne2; k2++) { + for (int i1 = 0; i1 < nr1; i1++) { + for (int k1 = 0; k1 < ne1; k1++) { + for (int i0 = 0; i0 < nr0; i0++) { + lm_ggml_vec_acc_f32(ne0, + (float *) ((char *) dst->data + ( k3)*nb3 + ( k2)*nb2 + ( k1)*nb1), + (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00)); + } + } + } + } + } + } + } +} + +static void lm_ggml_compute_forward_repeat_back( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_repeat_back_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_concat + +static void lm_ggml_compute_forward_concat_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + LM_GGML_TENSOR_BINARY_OP_LOCALS + + const int32_t dim = lm_ggml_get_op_params_i32(dst, 0); + + LM_GGML_ASSERT(dim >= 0 && dim < 4); + + int64_t o[4] = {0, 0, 0, 0}; + o[dim] = src0->ne[dim]; + + const float * x; + + // TODO: smarter multi-theading + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = ith; i2 < ne2; i2 += nth) { + for (int i1 = 0; i1 < ne1; i1++) { + for (int i0 = 0; i0 < ne0; i0++) { + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + x = (const float *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03); + } else { + x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13); + } + + float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3); + + *y = *x; + } + } + } + } +} + +static void lm_ggml_compute_forward_concat( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + case LM_GGML_TYPE_I32: + { + lm_ggml_compute_forward_concat_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_abs + +static void lm_ggml_compute_forward_abs_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(lm_ggml_is_contiguous_1(src0)); + assert(lm_ggml_is_contiguous_1(dst)); + assert(lm_ggml_are_same_shape(src0, dst)); + + const int n = lm_ggml_nrows(src0); + const int nc = src0->ne[0]; + + for (int i = 0; i < n; i++) { + lm_ggml_vec_abs_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void lm_ggml_compute_forward_abs( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_abs_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_sgn + +static void lm_ggml_compute_forward_sgn_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(lm_ggml_is_contiguous_1(src0)); + assert(lm_ggml_is_contiguous_1(dst)); + assert(lm_ggml_are_same_shape(src0, dst)); + + const int n = lm_ggml_nrows(src0); + const int nc = src0->ne[0]; + + for (int i = 0; i < n; i++) { + lm_ggml_vec_sgn_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void lm_ggml_compute_forward_sgn( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_sgn_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_neg + +static void lm_ggml_compute_forward_neg_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(lm_ggml_is_contiguous_1(src0)); + assert(lm_ggml_is_contiguous_1(dst)); + assert(lm_ggml_are_same_shape(src0, dst)); + + const int n = lm_ggml_nrows(src0); + const int nc = src0->ne[0]; + + for (int i = 0; i < n; i++) { + lm_ggml_vec_neg_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void lm_ggml_compute_forward_neg( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_neg_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_step + +static void lm_ggml_compute_forward_step_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(lm_ggml_is_contiguous_1(src0)); + assert(lm_ggml_is_contiguous_1(dst)); + assert(lm_ggml_are_same_shape(src0, dst)); + + const int n = lm_ggml_nrows(src0); + const int nc = src0->ne[0]; + + for (int i = 0; i < n; i++) { + lm_ggml_vec_step_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void lm_ggml_compute_forward_step( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_step_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_tanh + +static void lm_ggml_compute_forward_tanh_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(lm_ggml_is_contiguous_1(src0)); + assert(lm_ggml_is_contiguous_1(dst)); + assert(lm_ggml_are_same_shape(src0, dst)); + + const int n = lm_ggml_nrows(src0); + const int nc = src0->ne[0]; + + for (int i = 0; i < n; i++) { + lm_ggml_vec_tanh_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void lm_ggml_compute_forward_tanh( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_tanh_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_elu + +static void lm_ggml_compute_forward_elu_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(lm_ggml_is_contiguous_1(src0)); + assert(lm_ggml_is_contiguous_1(dst)); + assert(lm_ggml_are_same_shape(src0, dst)); + + const int n = lm_ggml_nrows(src0); + const int nc = src0->ne[0]; + + for (int i = 0; i < n; i++) { + lm_ggml_vec_elu_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void lm_ggml_compute_forward_elu( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_elu_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_relu + +static void lm_ggml_compute_forward_relu_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(lm_ggml_is_contiguous_1(src0)); + assert(lm_ggml_is_contiguous_1(dst)); + assert(lm_ggml_are_same_shape(src0, dst)); + + const int n = lm_ggml_nrows(src0); + const int nc = src0->ne[0]; + + for (int i = 0; i < n; i++) { + lm_ggml_vec_relu_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void lm_ggml_compute_forward_relu( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_relu_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_sigmoid + +static void lm_ggml_compute_forward_sigmoid_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(lm_ggml_is_contiguous_1(src0)); + assert(lm_ggml_is_contiguous_1(dst)); + assert(lm_ggml_are_same_shape(src0, dst)); + + const int n = lm_ggml_nrows(src0); + const int nc = src0->ne[0]; + + for (int i = 0; i < n; i++) { + lm_ggml_vec_sigmoid_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void lm_ggml_compute_forward_sigmoid( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_sigmoid_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_gelu + +static void lm_ggml_compute_forward_gelu_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + assert(lm_ggml_is_contiguous_1(src0)); + assert(lm_ggml_is_contiguous_1(dst)); + assert(lm_ggml_are_same_shape(src0, dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = lm_ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + lm_ggml_vec_gelu_f32(nc, + (float *) ((char *) dst->data + i1*( dst->nb[1])), + (float *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void lm_ggml_compute_forward_gelu( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_gelu_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_gelu_quick + +static void lm_ggml_compute_forward_gelu_quick_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + assert(lm_ggml_is_contiguous_1(src0)); + assert(lm_ggml_is_contiguous_1(dst)); + assert(lm_ggml_are_same_shape(src0, dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = lm_ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + lm_ggml_vec_gelu_quick_f32(nc, + (float *) ((char *) dst->data + i1*( dst->nb[1])), + (float *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void lm_ggml_compute_forward_gelu_quick( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_gelu_quick_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_silu + +static void lm_ggml_compute_forward_silu_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + assert(lm_ggml_is_contiguous_1(src0)); + assert(lm_ggml_is_contiguous_1(dst)); + assert(lm_ggml_are_same_shape(src0, dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = lm_ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + lm_ggml_vec_silu_f32(nc, + (float *) ((char *) dst->data + i1*( dst->nb[1])), + (float *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k]; + UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void lm_ggml_compute_forward_silu( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_silu_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} +// lm_ggml_compute_forward_leaky_relu + +static void lm_ggml_compute_forward_leaky_relu_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(lm_ggml_is_contiguous_1(src0)); + assert(lm_ggml_is_contiguous_1(dst)); + assert(lm_ggml_are_same_shape(src0, dst)); + + const int n = lm_ggml_nrows(src0); + const int nc = src0->ne[0]; + + float negative_slope; + memcpy(&negative_slope, dst->op_params, sizeof(float)); + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + lm_ggml_vec_leaky_relu_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1])), negative_slope); + } +} + +static void lm_ggml_compute_forward_leaky_relu( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_leaky_relu_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_silu_back + +static void lm_ggml_compute_forward_silu_back_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * grad = dst->src[1]; + + assert(lm_ggml_is_contiguous_1(grad)); + assert(lm_ggml_is_contiguous_1(src0)); + assert(lm_ggml_is_contiguous_1(dst)); + assert(lm_ggml_are_same_shape(src0, dst)); + assert(lm_ggml_are_same_shape(src0, grad)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = lm_ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + lm_ggml_vec_silu_backward_f32(nc, + (float *) ((char *) dst->data + i1*( dst->nb[1])), + (float *) ((char *) src0->data + i1*(src0->nb[1])), + (float *) ((char *) grad->data + i1*(grad->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void lm_ggml_compute_forward_silu_back( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_silu_back_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + + +static void lm_ggml_compute_forward_hardswish_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(lm_ggml_is_contiguous_1(src0)); + assert(lm_ggml_is_contiguous_1(dst)); + assert(lm_ggml_are_same_shape(src0, dst)); + + const int n = lm_ggml_nrows(src0); + const int nc = src0->ne[0]; + + for (int i = 0; i < n; i++) { + lm_ggml_vec_hardswish_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} +static void lm_ggml_compute_forward_hardswish( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_hardswish_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +static void lm_ggml_compute_forward_hardsigmoid_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(lm_ggml_is_contiguous_1(src0)); + assert(lm_ggml_is_contiguous_1(dst)); + assert(lm_ggml_are_same_shape(src0, dst)); + + const int n = lm_ggml_nrows(src0); + const int nc = src0->ne[0]; + + for (int i = 0; i < n; i++) { + lm_ggml_vec_hardsigmoid_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void lm_ggml_compute_forward_hardsigmoid( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_hardsigmoid_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +static void lm_ggml_compute_forward_exp_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(lm_ggml_is_contiguous_1(src0)); + assert(lm_ggml_is_contiguous_1(dst)); + assert(lm_ggml_are_same_shape(src0, dst)); + + const int n = lm_ggml_nrows(src0); + const int nc = src0->ne[0]; + + for (int i = 0; i < n; i++) { + lm_ggml_vec_exp_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void lm_ggml_compute_forward_exp( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_exp_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + + +// lm_ggml_compute_forward_norm + +static void lm_ggml_compute_forward_norm_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); + + LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + LM_GGML_TENSOR_UNARY_OP_LOCALS + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + LM_GGML_ASSERT(eps > 0.0f); + + // TODO: optimize + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ith; i01 < ne01; i01 += nth) { + const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + + lm_ggml_float sum = 0.0; + for (int64_t i00 = 0; i00 < ne00; i00++) { + sum += (lm_ggml_float)x[i00]; + } + + float mean = sum/ne00; + + float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + + lm_ggml_float sum2 = 0.0; + for (int64_t i00 = 0; i00 < ne00; i00++) { + float v = x[i00] - mean; + y[i00] = v; + sum2 += (lm_ggml_float)(v*v); + } + + float variance = sum2/ne00; + const float scale = 1.0f/sqrtf(variance + eps); + + lm_ggml_vec_scale_f32(ne00, y, scale); + } + } + } +} + +static void lm_ggml_compute_forward_norm( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_norm_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_group_rms_norm + +static void lm_ggml_compute_forward_rms_norm_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); + + LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + LM_GGML_TENSOR_UNARY_OP_LOCALS + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + LM_GGML_ASSERT(eps > 0.0f); + + // TODO: optimize + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ith; i01 < ne01; i01 += nth) { + const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + + lm_ggml_float sum = 0.0; + for (int64_t i00 = 0; i00 < ne00; i00++) { + sum += (lm_ggml_float)(x[i00] * x[i00]); + } + + const float mean = sum/ne00; + + float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + + memcpy(y, x, ne00 * sizeof(float)); + // for (int i00 = 0; i00 < ne00; i00++) { + // y[i00] = x[i00]; + // } + + const float scale = 1.0f/sqrtf(mean + eps); + + lm_ggml_vec_scale_f32(ne00, y, scale); + } + } + } +} + +static void lm_ggml_compute_forward_rms_norm( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_rms_norm_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +static void lm_ggml_compute_forward_rms_norm_back_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst) && lm_ggml_are_same_shape(src0, src1)); + + LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + LM_GGML_TENSOR_BINARY_OP_LOCALS + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + // TODO: optimize + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ith; i01 < ne01; i01 += nth) { + // src1 is same shape as src0 => same indices + const int64_t i11 = i01; + const int64_t i12 = i02; + const int64_t i13 = i03; + + const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + const float * dz = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13); + + lm_ggml_float sum_xx = 0.0; + lm_ggml_float sum_xdz = 0.0; + + for (int64_t i00 = 0; i00 < ne00; i00++) { + sum_xx += (lm_ggml_float)(x[i00] * x[i00]); + sum_xdz += (lm_ggml_float)(x[i00] * dz[i00]); + } + + //const float mean = (float)(sum_xx)/ne00; + const float mean_eps = (float)(sum_xx)/ne00 + eps; + const float sum_eps = (float)(sum_xx) + eps*ne00; + //const float mean_xdz = (float)(sum_xdz)/ne00; + // we could cache rms from forward pass to improve performance. + // to do this implement lm_ggml_rms and compose lm_ggml_rms_norm using lm_ggml_rms. + //const float rms = sqrtf(mean_eps); + const float rrms = 1.0f / sqrtf(mean_eps); + //const float scale = -rrms/(ne00 * mean_eps); // -1/(n*rms**3) + + { + // z = rms_norm(x) + // + // rms_norm(src0) = + // scale( + // src0, + // div( + // 1, + // sqrt( + // add( + // scale( + // sum( + // sqr( + // src0)), + // (1.0/N)), + // eps)))); + + // postorder: + // ## op args grad + // 00 param src0 grad[#00] + // 01 const 1 + // 02 sqr (#00) grad[#02] + // 03 sum (#02) grad[#03] + // 04 const 1/N + // 05 scale (#03, #04) grad[#05] + // 06 const eps + // 07 add (#05, #06) grad[#07] + // 08 sqrt (#07) grad[#08] + // 09 div (#01,#08) grad[#09] + // 10 scale (#00,#09) grad[#10] + // + // backward pass, given grad[#10] + // #10: scale + // grad[#00] += scale(grad[#10],#09) + // grad[#09] += sum(mul(grad[#10],#00)) + // #09: div + // grad[#08] += neg(mul(grad[#09], div(#09,#08))) + // #08: sqrt + // grad[#07] += mul(grad[#08], div(0.5, #08)) + // #07: add + // grad[#05] += grad[#07] + // #05: scale + // grad[#03] += scale(grad[#05],#04) + // #03: sum + // grad[#02] += repeat(grad[#03], #02) + // #02: + // grad[#00] += scale(mul(#00, grad[#02]), 2.0) + // + // substitute and simplify: + // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0) + // grad[#02] = repeat(grad[#03], #02) + // grad[#02] = repeat(scale(grad[#05],#04), #02) + // grad[#02] = repeat(scale(grad[#07],#04), #02) + // grad[#02] = repeat(scale(mul(grad[#08], div(0.5, #08)),#04), #02) + // grad[#02] = repeat(scale(mul(neg(mul(grad[#09], div(#09,#08))), div(0.5, #08)),#04), #02) + // grad[#02] = repeat(scale(mul(neg(mul(sum(mul(grad[#10],#00)), div(#09,#08))), div(0.5, #08)),#04), #02) + // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(#09,#08) * div(0.5, #08) * (1/N)), #02) + // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(div(#01,#08),#08) * div(0.5, #08) * (1/N)), #02) + // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#08*#08) * div(0.5, #08) * (1/N)), #02) + // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02) + // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0) + // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)), 2.0) + // grad[#00] = scale(grad(#10), #09) + scale(scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N))), 2.0) + // grad[#00] = scale(grad(#10), #09) + scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(1,#08) * (1/N))) + // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N)) + // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N)) + // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,mean_eps*rms) * (-1/N)) + // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*mean_eps)) + // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*(sum_xx/N+eps))) + // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*sum_xx+rms*N*eps)) + // grad[#00] = scale(dz, rrms) + scale(x, sum(mul(dz,x)) * div(-1,rms*N*mean_eps)) + // grad[#00] = scale(dz, rrms) + scale(x, sum_xdz * div(-1,rms*N*mean_eps)) + // a = b*c + d*e + // a = b*c*f/f + d*e*f/f + // a = (b*c*f + d*e*f)*(1/f) + // a = (b*c*(1/c) + d*e*(1/c))*(1/(1/c)) + // a = (b + d*e/c)*c + // b = dz, c = rrms, d = x, e = sum_xdz * div(-1,rms*N*mean_eps) + // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)/rrms)*rrms + // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)*rms)*rrms + // a = (dz + x*sum_xdz * div(-rms,rms*N*mean_eps))*rrms + // a = (dz + x*sum_xdz * div(-1,N*mean_eps))*rrms + // a = (dz + x*div(-sum_xdz,N*mean_eps))*rrms + // a = (dz + x*div(-mean_xdz,mean_eps))*rrms + // grad[#00] = scale(dz + scale(x, div(-mean_xdz,mean_eps)),rrms) + // grad[#00] = scale(dz + scale(x, -mean_xdz/mean_eps),rrms) + // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms) + } + // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms) + // post-order: + // dx := x + // dx := scale(dx,-mean_xdz/mean_eps) + // dx := add(dx, dz) + // dx := scale(dx, rrms) + float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + + lm_ggml_vec_cpy_f32 (ne00, dx, x); + // lm_ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps); + lm_ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps); + lm_ggml_vec_acc_f32 (ne00, dx, dz); + lm_ggml_vec_scale_f32(ne00, dx, rrms); + } + } + } +} + +static void lm_ggml_compute_forward_rms_norm_back( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_rms_norm_back_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_group_norm + +static void lm_ggml_compute_forward_group_norm_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); + + LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + LM_GGML_TENSOR_UNARY_OP_LOCALS + + // TODO: optimize + + float eps; + memcpy(&eps, dst->op_params + 1, sizeof(float)); + + int n_channels = src0->ne[2]; + int n_groups = dst->op_params[0]; + int n_channels_per_group = (n_channels + n_groups - 1) / n_groups; + for (int i = ith; i < n_groups; i += nth) { + int start = i * n_channels_per_group; + int end = start + n_channels_per_group; + if (end > n_channels) { + end = n_channels; + } + int step = end - start; + + for (int64_t i03 = 0; i03 < ne03; i03++) { + lm_ggml_float sum = 0.0; + for (int64_t i02 = start; i02 < end; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03); + + lm_ggml_float sumr = 0.0; + for (int64_t i00 = 0; i00 < ne00; i00++) { + sumr += (lm_ggml_float)x[i00]; + } + sum += sumr; + } + } + const float mean = sum / (ne00 * ne01 * step); + + lm_ggml_float sum2 = 0.0; + for (int64_t i02 = start; i02 < end; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03); + + float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3); + + lm_ggml_float sumr = 0.0; + for (int64_t i00 = 0; i00 < ne00; i00++) { + float v = x[i00] - mean; + y[i00] = v; + sumr += (lm_ggml_float)(v * v); + } + sum2 += sumr; + } + } + const float variance = sum2 / (ne00 * ne01 * step); + const float scale = 1.0f / sqrtf(variance + eps); + + for (int64_t i02 = start; i02 < end; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3); + lm_ggml_vec_scale_f32(ne00, y, scale); + } + } + } + } +} + +static void lm_ggml_compute_forward_group_norm( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_group_norm_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_mul_mat + +static void lm_ggml_compute_forward_mul_mat_one_chunk( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst, + const enum lm_ggml_type type, + const int64_t num_rows_per_vec_dot, + const int64_t ir0_start, + const int64_t ir0_end, + const int64_t ir1_start, + const int64_t ir1_end) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_TENSOR_BINARY_OP_LOCALS + + const bool src1_cont = lm_ggml_is_contiguous(src1); + + lm_ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot; + enum lm_ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type; + + // broadcast factors + const int64_t r2 = ne12 / ne02; + const int64_t r3 = ne13 / ne03; + + //printf("ir0_start = %6lld, ir0_end = %6lld, ir1_start = %6lld, ir1_end = %6lld\n", ir0_start, ir0_end, ir1_start, ir1_end); + + // threads with no work simply yield (not sure if it helps) + if (ir0_start >= ir0_end || ir1_start >= ir1_end) { + return; + } + + const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; + const size_t row_size = lm_ggml_row_size(vec_dot_type, ne10); + + assert(ne12 % ne02 == 0); + assert(ne13 % ne03 == 0); + + // block-tiling attempt + const int64_t blck_0 = 16; + const int64_t blck_1 = 16; + + const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11; + + // attempt to reduce false-sharing (does not seem to make a difference) + // 16 * 2, accounting for mmla kernels + float tmp[32]; + + for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) { + for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) { + for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) { + const int64_t i13 = (ir1 / (ne12 * ne1)); + const int64_t i12 = (ir1 - i13 * ne12 * ne1) / ne1; + const int64_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1); + + // broadcast src0 into src1 + const int64_t i03 = i13 / r3; + const int64_t i02 = i12 / r2; + + const int64_t i1 = i11; + const int64_t i2 = i12; + const int64_t i3 = i13; + + const char * src0_row = (const char*)src0->data + (0 + i02 * nb02 + i03 * nb03); + + // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides + // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using + // the original src1 data pointer, so we should index using the indices directly + // TODO: this is a bit of a hack, we should probably have a better way to handle this + const char * src1_col = (const char*)wdata + + (src1_cont || src1->type != vec_dot_type + ? (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size + : (i11 * nb11 + i12 * nb12 + i13 * nb13)); + float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3)); + + //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) { + // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col); + //} + + for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) { + vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot); + } + + for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) { + memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float)); + } + } + } + } +} + +static void lm_ggml_compute_forward_mul_mat( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_TENSOR_BINARY_OP_LOCALS + + const int ith = params->ith; + const int nth = params->nth; + + enum lm_ggml_type type = src0->type; + + if (src0->buffer && lm_ggml_backend_cpu_buft_is_aarch64(src0->buffer->buft)) { + type = (enum lm_ggml_type)(intptr_t)src0->extra; + } + + enum lm_ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type; + lm_ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float; + lm_ggml_from_float_to_mat_t const from_float_to_mat = type_traits_cpu[vec_dot_type].from_float_to_mat; + int64_t const vec_dot_num_rows = type_traits_cpu[type].nrows; + int64_t const matmul_num_cols = type_traits_cpu[type].ncols; + int64_t const blck_size_interleave = lm_ggml_get_type_traits(type)->blck_size_interleave; + lm_ggml_gemv_t const gemv = type_traits_cpu[type].gemv; + lm_ggml_gemm_t const gemm = type_traits_cpu[type].gemm; + + LM_GGML_ASSERT(ne0 == ne01); + LM_GGML_ASSERT(ne1 == ne11); + LM_GGML_ASSERT(ne2 == ne12); + LM_GGML_ASSERT(ne3 == ne13); + + // we don't support permuted src0 or src1 + LM_GGML_ASSERT(nb00 == lm_ggml_type_size(type)); + LM_GGML_ASSERT(nb10 == lm_ggml_type_size(src1->type)); + + // dst cannot be transposed or permuted + LM_GGML_ASSERT(nb0 == sizeof(float)); + LM_GGML_ASSERT(nb0 <= nb1); + LM_GGML_ASSERT(nb1 <= nb2); + LM_GGML_ASSERT(nb2 <= nb3); + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + +#if LM_GGML_USE_LLAMAFILE + // broadcast factors + const int64_t r2 = ne12 / ne02; + const int64_t r3 = ne13 / ne03; + + const bool src1_cont = lm_ggml_is_contiguous(src1); + + if (src1_cont) { + for (int64_t i13 = 0; i13 < ne13; i13++) + for (int64_t i12 = 0; i12 < ne12; i12++) + if (!llamafile_sgemm(ne01, ne11, ne00/lm_ggml_blck_size(type), + (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, + nb01/lm_ggml_type_size(type), + (const char *)src1->data + i12*nb12 + i13*nb13, + nb11/lm_ggml_type_size(src1->type), + (char *)dst->data + i12*nb2 + i13*nb3, + nb1/lm_ggml_type_size(dst->type), + ith, nth, + type, + src1->type, + dst->type)) + goto UseGgmlGemm1; + return; + } +UseGgmlGemm1:; +#endif + + if (src1->type != vec_dot_type) { + char * wdata = params->wdata; + + const size_t nbw1 = lm_ggml_row_size(vec_dot_type, ne10); + const size_t nbw2 = nbw1*ne11; + const size_t nbw3 = nbw2*ne12; + + assert(params->wsize >= ne13*nbw3); + LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); + + for (int64_t i13 = 0; i13 < ne13; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + int64_t i11_processed = 0; + if ((lm_ggml_n_dims(src1) == 2) && from_float_to_mat && gemm) { + for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) { + from_float_to_mat((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), + (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), + 4, ne10, blck_size_interleave); + } + i11_processed = ne11 - ne11 % 4; + } + for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) { + from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), + (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), + ne10); + } + } + } + } + + if (ith == 0) { + // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start. + atomic_store_explicit(¶ms->threadpool->current_chunk, nth, memory_order_relaxed); + } + + lm_ggml_barrier(params->threadpool); + +#if LM_GGML_USE_LLAMAFILE + if (src1->type != vec_dot_type) { + const void* wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; + const size_t row_size = lm_ggml_row_size(vec_dot_type, ne10); + + for (int64_t i13 = 0; i13 < ne13; i13++) + for (int64_t i12 = 0; i12 < ne12; i12++) + if (!llamafile_sgemm(ne01, ne11, ne00/lm_ggml_blck_size(type), + (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, + nb01/lm_ggml_type_size(type), + (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, + row_size/lm_ggml_type_size(vec_dot_type), + (char *)dst->data + i12*nb2 + i13*nb3, + nb1/lm_ggml_type_size(dst->type), + ith, nth, + type, + vec_dot_type, + dst->type)) + goto UseGgmlGemm2; + return; + } +UseGgmlGemm2:; +#endif + + // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers) + const int64_t nr0 = ne0; + + // This is the size of the rest of the dimensions of the result + const int64_t nr1 = ne1 * ne2 * ne3; + + // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols + int64_t num_rows_per_vec_dot = vec_dot_num_rows; + // TODO: currently the mmla kernels support only even numbered rows/cols. + // this check can be removed once they are extended to support odd numbered rows/cols too + if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) { + num_rows_per_vec_dot = 1; + } + + // Now select a reasonable chunk size. + int chunk_size = 16; + + // We need to step up the size if it's small + if (nr0 == 1 || nr1 == 1) { + chunk_size = 64; + } + + // distribute the work across the inner or outer loop based on which one is larger + // The number of chunks in the 0/1 dim. + // CEIL(nr0/chunk_size) + int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size; + int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size; + + // If the chunking is poor for the number of threads on this setup, scrap the whole plan. Re-chunk it by thread. + // Also, chunking by thread was measured to have perform better on NUMA systems. See https://github.com/ggerganov/llama.cpp/pull/6915 + // In theory, chunking should be just as useful on NUMA and non NUMA systems, but testing disagreed with that. + if (nchunk0 * nchunk1 < nth * 4 || lm_ggml_is_numa()) { + // distribute the thread work across the inner or outer loop based on which one is larger + nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows + nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows + } + + // The number of elements in each chunk + const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0; + const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1; + + if ((lm_ggml_n_dims(src0) == 2) && gemv) { + const void * src1_wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; + const size_t src1_col_stride = lm_ggml_is_contiguous(src1) || src1->type != vec_dot_type ? lm_ggml_row_size(vec_dot_type, ne10) : nb11; + int64_t src0_start = (ith * ne01) / nth; + int64_t src0_end = ((ith + 1) * ne01) / nth; + src0_start = (src0_start % matmul_num_cols) ? src0_start + matmul_num_cols - (src0_start % matmul_num_cols): src0_start; + src0_end = (src0_end % matmul_num_cols) ? src0_end + matmul_num_cols - (src0_end % matmul_num_cols): src0_end; + if (src0_start >= src0_end) return; + + // If there are more than three rows in src1, use gemm; otherwise, use gemv. + if (gemm && (ne11 > 3)) { + gemm(ne00, (float *)((char *) dst->data) + src0_start, ne01, (const char *) src0->data + src0_start * nb01, + (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start); + } + for (int iter = gemm ? ne11 - ne11 % 4 : 0; iter < ne11; iter++) { + gemv(ne00, (float *)((char *) dst->data + (iter * nb1)) + src0_start, ne01, + (const char *) src0->data + src0_start * nb01, (const char *) src1_wdata + (src1_col_stride * iter), 1, + src0_end - src0_start); + } + return; + } + + // The first chunk comes from our thread_id, the rest will get auto-assigned. + int current_chunk = ith; + + while (current_chunk < nchunk0 * nchunk1) { + const int64_t ith0 = current_chunk % nchunk0; + const int64_t ith1 = current_chunk / nchunk0; + + const int64_t ir0_start = dr0 * ith0; + const int64_t ir0_end = MIN(ir0_start + dr0, nr0); + + const int64_t ir1_start = dr1 * ith1; + const int64_t ir1_end = MIN(ir1_start + dr1, nr1); + + lm_ggml_compute_forward_mul_mat_one_chunk(params, dst, type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end); + + if (nth >= nchunk0 * nchunk1) { + break; + } + + current_chunk = atomic_fetch_add_explicit(¶ms->threadpool->current_chunk, 1, memory_order_relaxed); + } +} + +// lm_ggml_compute_forward_mul_mat_id + +static void lm_ggml_compute_forward_mul_mat_id( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + const struct lm_ggml_tensor * ids = dst->src[2]; + + LM_GGML_TENSOR_BINARY_OP_LOCALS + + const int ith = params->ith; + const int nth = params->nth; + + const enum lm_ggml_type type = src0->type; + + const bool src1_cont = lm_ggml_is_contiguous(src1); + + lm_ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot; + enum lm_ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type; + lm_ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float; + int64_t const matmul_num_cols = type_traits_cpu[type].ncols; + lm_ggml_gemv_t const gemv = type_traits_cpu[type].gemv; + + // we don't support permuted src0 or src1 + LM_GGML_ASSERT(nb00 == lm_ggml_type_size(type)); + LM_GGML_ASSERT(nb10 == lm_ggml_type_size(src1->type)); + + // dst cannot be transposed or permuted + LM_GGML_ASSERT(nb0 == sizeof(float)); + LM_GGML_ASSERT(nb0 <= nb1); + LM_GGML_ASSERT(nb1 <= nb2); + LM_GGML_ASSERT(nb2 <= nb3); + + // row groups + const int n_ids = ids->ne[0]; // n_expert_used + const int n_as = ne02; // n_expert + + char * wdata_src1_end = (src1->type == vec_dot_type) ? + (char *) params->wdata : + (char *) params->wdata + LM_GGML_PAD(lm_ggml_row_size(vec_dot_type, lm_ggml_nelements(src1)), sizeof(int64_t)); + + struct mmid_row_mapping { + int32_t i1; + int32_t i2; + }; + + int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as] + struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11] + + if (src1->type != vec_dot_type) { + char * wdata = params->wdata; + + const size_t nbw1 = lm_ggml_row_size(vec_dot_type, ne10); + const size_t nbw2 = nbw1*ne11; + const size_t nbw3 = nbw2*ne12; + + assert(params->wsize >= ne13*nbw3); + LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); + + for (int64_t i13 = 0; i13 < ne13; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = ith; i11 < ne11; i11 += nth) { + from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), + (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), + ne10); + } + } + } + } + +#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)] + + if (ith == 0) { + // initialize matrix_row_counts + memset(matrix_row_counts, 0, n_as*sizeof(int64_t)); + + // group rows by src0 matrix + for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { + for (int id = 0; id < n_ids; ++id) { + const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]); + + assert(i02 >= 0 && i02 < n_as); + + MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1}; + matrix_row_counts[i02] += 1; + } + } + } + + lm_ggml_barrier(params->threadpool); + + // compute each matrix multiplication in sequence + for (int cur_a = 0; cur_a < n_as; ++cur_a) { + const int64_t cne1 = matrix_row_counts[cur_a]; + + if (cne1 == 0) { + continue; + } + + const char * src0_cur = (const char *) src0->data + cur_a*nb02; + + const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; + const size_t row_size = lm_ggml_row_size(vec_dot_type, ne10); + + const int64_t nr0 = ne01; // src0 rows + const int64_t nr1 = cne1; // src1 rows + + if (((lm_ggml_n_dims(src0) - 1) == 2) && gemv) { + int64_t src0_cur_start = (ith * ne01) / nth; + int64_t src0_cur_end = ((ith + 1) * ne01) / nth; + src0_cur_start = (src0_cur_start % matmul_num_cols) ? src0_cur_start + matmul_num_cols - (src0_cur_start % matmul_num_cols): src0_cur_start; + src0_cur_end = (src0_cur_end % matmul_num_cols) ? src0_cur_end + matmul_num_cols - (src0_cur_end % matmul_num_cols): src0_cur_end; + if (src0_cur_start >= src0_cur_end) return; + + for (int ir1 = 0; ir1 < nr1; ir1++) { + struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1); + const int id = row_mapping.i1; // selected expert index + + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; // row index in src1 + + const int64_t i1 = id; // selected expert index + const int64_t i2 = i12; // row + + const char * src1_col = (const char *) wdata + + (src1_cont || src1->type != vec_dot_type + ? (i11 + i12 * ne11) * row_size + : (i11 * nb11 + i12 * nb12)); + + gemv(ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01, + (const char *) src0_cur + src0_cur_start * nb01, src1_col, 1, src0_cur_end - src0_cur_start); + } + continue; + } + + // distribute the thread work across the inner or outer loop based on which one is larger + + const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows + const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows + + const int64_t ith0 = ith % nth0; + const int64_t ith1 = ith / nth0; + + const int64_t dr0 = (nr0 + nth0 - 1)/nth0; + const int64_t dr1 = (nr1 + nth1 - 1)/nth1; + + const int64_t ir010 = dr0*ith0; + const int64_t ir011 = MIN(ir010 + dr0, nr0); + + const int64_t ir110 = dr1*ith1; + const int64_t ir111 = MIN(ir110 + dr1, nr1); + + // threads with no work simply yield (not sure if it helps) + //if (ir010 >= ir011 || ir110 >= ir111) { + // sched_yield(); + // continue; + //} + + // block-tiling attempt + const int64_t blck_0 = 16; + const int64_t blck_1 = 16; + + // attempt to reduce false-sharing (does not seem to make a difference) + float tmp[16]; + + for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) { + for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) { + for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) { + const int64_t _i12 = ir1; // logical row index for this expert + + struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12); + const int id = row_mapping.i1; // selected expert index + + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; // row index in src1 + + const int64_t i1 = id; // selected expert index + const int64_t i2 = i12; // row + + // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides + // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using + // the original src1 data pointer, so we should index using the indices directly + // TODO: this is a bit of a hack, we should probably have a better way to handle this + const char * src1_col = (const char *) wdata + + (src1_cont || src1->type != vec_dot_type + ? (i11 + i12*ne11)*row_size + : (i11*nb11 + i12*nb12)); + + float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2)); + + //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { + // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col); + //} + + for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { + vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1); + } + + memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float)); + } + } + } + } + +#undef MMID_MATRIX_ROW +} + +// lm_ggml_compute_forward_out_prod + +static void lm_ggml_compute_forward_out_prod_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_TENSOR_BINARY_OP_LOCALS + + LM_GGML_ASSERT(dst->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); + + const int ith = params->ith; + const int nth = params->nth; + + LM_GGML_ASSERT(ne0 == ne00); + LM_GGML_ASSERT(ne1 == ne10); + LM_GGML_ASSERT(ne2 == ne02); + LM_GGML_ASSERT(ne02 == ne12); + LM_GGML_ASSERT(ne3 == ne13); + LM_GGML_ASSERT(ne03 == ne13); + + // we don't support permuted src0 or src1 + LM_GGML_ASSERT(nb00 == sizeof(float)); + + // dst cannot be transposed or permuted + LM_GGML_ASSERT(nb0 == sizeof(float)); + // LM_GGML_ASSERT(nb0 <= nb1); + // LM_GGML_ASSERT(nb1 <= nb2); + // LM_GGML_ASSERT(nb2 <= nb3); + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + + if (ith == 0) { + lm_ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); + } + lm_ggml_barrier(params->threadpool); + + // dst[:,:,:,:] = 0 + // for i2,i3: + // for i1: + // for i01: + // for i0: + // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3] + + // parallelize by last three dimensions + + // total rows in dst + const int64_t nr = ne1*ne2*ne3; + + // rows per thread + const int64_t dr = (nr + nth - 1)/nth; + + // row range for this thread + const int64_t ir0 = dr*ith; + const int64_t ir1 = MIN(ir0 + dr, nr); + + // block-tiling attempt + const int64_t blck_0 = MAX(LM_GGML_VEC_MAD_UNROLL, 32); + const int64_t blck_1 = 16; + + for (int64_t bir = ir0; bir < ir1; bir += blck_1) { + const int64_t bir1 = MIN(bir + blck_1, ir1); + for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) { + const int64_t bne01 = MIN(bi01 + blck_0, ne01); + for (int64_t ir = bir; ir < bir1; ++ir) { + // dst indices + const int64_t i3 = ir/(ne2*ne1); + const int64_t i2 = (ir - i3*ne2*ne1)/ne1; + const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1); + + const int64_t i02 = i2; + const int64_t i03 = i3; + + //const int64_t i10 = i1; + const int64_t i12 = i2; + const int64_t i13 = i3; + +#if LM_GGML_VEC_MAD_UNROLL > 2 + const int64_t bne01_unroll = bne01 - (bne01 % LM_GGML_VEC_MAD_UNROLL); + for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += LM_GGML_VEC_MAD_UNROLL) { + const int64_t i11 = i01; + + float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); + float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + + lm_ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1); + } + for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) { + const int64_t i11 = i01; + + float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); + float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + + lm_ggml_vec_mad_f32(ne0, d, s0, *s1); + } +#else + for (int64_t i01 = bi01; i01 < bne01; ++i01) { + const int64_t i11 = i01; + + float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); + float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + + lm_ggml_vec_mad_f32(ne0, d, s0, *s1); + } +#endif + } + } + } +} + +static void lm_ggml_compute_forward_out_prod_q_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_TENSOR_BINARY_OP_LOCALS; + + const int ith = params->ith; + const int nth = params->nth; + + const enum lm_ggml_type type = src0->type; + lm_ggml_to_float_t const dequantize_row_q = lm_ggml_get_type_traits(type)->to_float; + + LM_GGML_ASSERT(ne02 == ne12); + LM_GGML_ASSERT(ne03 == ne13); + LM_GGML_ASSERT(ne2 == ne12); + LM_GGML_ASSERT(ne3 == ne13); + + // we don't support permuted src0 dim0 + LM_GGML_ASSERT(nb00 == lm_ggml_type_size(type)); + + // dst dim0 cannot be transposed or permuted + LM_GGML_ASSERT(nb0 == sizeof(float)); + // LM_GGML_ASSERT(nb0 <= nb1); + // LM_GGML_ASSERT(nb1 <= nb2); + // LM_GGML_ASSERT(nb2 <= nb3); + + LM_GGML_ASSERT(ne0 == ne00); + LM_GGML_ASSERT(ne1 == ne10); + LM_GGML_ASSERT(ne2 == ne02); + LM_GGML_ASSERT(ne3 == ne03); + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + + if (ith == 0) { + lm_ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); + } + lm_ggml_barrier(params->threadpool); + + // parallelize by last three dimensions + + // total rows in dst + const int64_t nr = ne1*ne2*ne3; + + // rows per thread + const int64_t dr = (nr + nth - 1)/nth; + + // row range for this thread + const int64_t ir0 = dr*ith; + const int64_t ir1 = MIN(ir0 + dr, nr); + + // dst[:,:,:,:] = 0 + // for i2,i3: + // for i1: + // for i01: + // for i0: + // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3] + + float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith; + + for (int64_t ir = ir0; ir < ir1; ++ir) { + // dst indices + const int64_t i3 = ir/(ne2*ne1); + const int64_t i2 = (ir - i3*ne2*ne1)/ne1; + const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1); + + const int64_t i02 = i2; + const int64_t i03 = i3; + + //const int64_t i10 = i1; + const int64_t i12 = i2; + const int64_t i13 = i3; + + for (int64_t i01 = 0; i01 < ne01; ++i01) { + const int64_t i11 = i01; + + float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); + float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + + dequantize_row_q(s0, wdata, ne0); + lm_ggml_vec_mad_f32(ne0, d, wdata, *s1); + } + } +} + +static void lm_ggml_compute_forward_out_prod( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_Q4_0: + case LM_GGML_TYPE_Q4_1: + case LM_GGML_TYPE_Q5_0: + case LM_GGML_TYPE_Q5_1: + case LM_GGML_TYPE_Q8_0: + case LM_GGML_TYPE_Q2_K: + case LM_GGML_TYPE_Q3_K: + case LM_GGML_TYPE_Q4_K: + case LM_GGML_TYPE_Q5_K: + case LM_GGML_TYPE_Q6_K: + case LM_GGML_TYPE_TQ1_0: + case LM_GGML_TYPE_TQ2_0: + case LM_GGML_TYPE_IQ2_XXS: + case LM_GGML_TYPE_IQ2_XS: + case LM_GGML_TYPE_IQ3_XXS: + case LM_GGML_TYPE_IQ1_S: + case LM_GGML_TYPE_IQ1_M: + case LM_GGML_TYPE_IQ4_NL: + case LM_GGML_TYPE_IQ4_XS: + case LM_GGML_TYPE_IQ3_S: + case LM_GGML_TYPE_IQ2_S: + case LM_GGML_TYPE_Q4_0_4_4: + case LM_GGML_TYPE_Q4_0_4_8: + case LM_GGML_TYPE_Q4_0_8_8: + { + lm_ggml_compute_forward_out_prod_q_f32(params, dst); + } break; + case LM_GGML_TYPE_F16: + { + LM_GGML_ABORT("fatal error"); // todo + // lm_ggml_compute_forward_out_prod_f16_f32(params, dst); + } + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_out_prod_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_scale + +static void lm_ggml_compute_forward_scale_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + LM_GGML_ASSERT(lm_ggml_is_contiguous(src0)); + LM_GGML_ASSERT(lm_ggml_is_contiguous(dst)); + LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); + + // scale factor + float v; + memcpy(&v, dst->op_params, sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = lm_ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + const size_t nb01 = src0->nb[1]; + + const size_t nb1 = dst->nb[1]; + + for (int i1 = ir0; i1 < ir1; i1++) { + if (dst->data != src0->data) { + // src0 is same shape as dst => same indices + memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float)); + } + lm_ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v); + } +} + +static void lm_ggml_compute_forward_scale( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_scale_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_set + +static void lm_ggml_compute_forward_set_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); + LM_GGML_ASSERT(lm_ggml_is_contiguous(dst) && lm_ggml_is_contiguous(src0)); + + // view src0 and dst with these strides and data offset inbytes during set + // nb0 is implicitly element_size because src0 and dst are contiguous + size_t nb1 = ((int32_t *) dst->op_params)[0]; + size_t nb2 = ((int32_t *) dst->op_params)[1]; + size_t nb3 = ((int32_t *) dst->op_params)[2]; + size_t offset = ((int32_t *) dst->op_params)[3]; + bool inplace = (bool) ((int32_t *) dst->op_params)[4]; + + if (!inplace) { + if (params->ith == 0) { + // memcpy needs to be synchronized across threads to avoid race conditions. + // => do it in INIT phase + memcpy( + ((char *) dst->data), + ((char *) src0->data), + lm_ggml_nbytes(dst)); + } + lm_ggml_barrier(params->threadpool); + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = lm_ggml_nrows(src1); + const int nc = src1->ne[0]; + + LM_GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) + LM_GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) + + // src0 and dst as viewed during set + const size_t nb0 = lm_ggml_element_size(src0); + + const int im0 = (ne10 == 0 ? 0 : ne10-1); + const int im1 = (ne11 == 0 ? 0 : ne11-1); + const int im2 = (ne12 == 0 ? 0 : ne12-1); + const int im3 = (ne13 == 0 ? 0 : ne13-1); + + LM_GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <= lm_ggml_nbytes(dst)); + + LM_GGML_ASSERT(nb10 == sizeof(float)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are viewed with shape of src1 and offset + // => same indices + const int i3 = ir/(ne12*ne11); + const int i2 = (ir - i3*ne12*ne11)/ne11; + const int i1 = (ir - i3*ne12*ne11 - i2*ne11); + + lm_ggml_vec_cpy_f32(nc, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset), + (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11)); + } +} + +static void lm_ggml_compute_forward_set( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_set_f32(params, dst); + } break; + case LM_GGML_TYPE_F16: + case LM_GGML_TYPE_BF16: + case LM_GGML_TYPE_Q4_0: + case LM_GGML_TYPE_Q4_1: + case LM_GGML_TYPE_Q5_0: + case LM_GGML_TYPE_Q5_1: + case LM_GGML_TYPE_Q8_0: + case LM_GGML_TYPE_Q8_1: + case LM_GGML_TYPE_Q2_K: + case LM_GGML_TYPE_Q3_K: + case LM_GGML_TYPE_Q4_K: + case LM_GGML_TYPE_Q5_K: + case LM_GGML_TYPE_Q6_K: + case LM_GGML_TYPE_TQ1_0: + case LM_GGML_TYPE_TQ2_0: + case LM_GGML_TYPE_IQ2_XXS: + case LM_GGML_TYPE_IQ2_XS: + case LM_GGML_TYPE_IQ3_XXS: + case LM_GGML_TYPE_IQ1_S: + case LM_GGML_TYPE_IQ1_M: + case LM_GGML_TYPE_IQ4_NL: + case LM_GGML_TYPE_IQ4_XS: + case LM_GGML_TYPE_IQ3_S: + case LM_GGML_TYPE_IQ2_S: + case LM_GGML_TYPE_Q4_0_4_4: + case LM_GGML_TYPE_Q4_0_4_8: + case LM_GGML_TYPE_Q4_0_8_8: + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_cpy + +static void lm_ggml_compute_forward_cpy( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + lm_ggml_compute_forward_dup(params, dst); +} + +// lm_ggml_compute_forward_cont + +static void lm_ggml_compute_forward_cont( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + lm_ggml_compute_forward_dup(params, dst); +} + +// lm_ggml_compute_forward_reshape + +static void lm_ggml_compute_forward_reshape( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + // NOP + UNUSED(params); + UNUSED(dst); +} + +// lm_ggml_compute_forward_view + +static void lm_ggml_compute_forward_view( + const struct lm_ggml_compute_params * params, + const struct lm_ggml_tensor * dst) { + // NOP + UNUSED(params); + UNUSED(dst); +} + +// lm_ggml_compute_forward_permute + +static void lm_ggml_compute_forward_permute( + const struct lm_ggml_compute_params * params, + const struct lm_ggml_tensor * dst) { + // NOP + UNUSED(params); + UNUSED(dst); +} + +// lm_ggml_compute_forward_transpose + +static void lm_ggml_compute_forward_transpose( + const struct lm_ggml_compute_params * params, + const struct lm_ggml_tensor * dst) { + // NOP + UNUSED(params); + UNUSED(dst); +} + +// lm_ggml_compute_forward_get_rows + +static void lm_ggml_compute_forward_get_rows_q( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_TENSOR_BINARY_OP_LOCALS + + const int64_t nc = ne00; + const int64_t nr = lm_ggml_nelements(src1); + + const enum lm_ggml_type type = src0->type; + lm_ggml_to_float_t const dequantize_row_q = lm_ggml_get_type_traits(type)->to_float; + + assert(ne0 == nc); + assert(ne02 == ne11); + assert(nb00 == lm_ggml_type_size(type)); + assert(lm_ggml_nrows(dst) == nr); + + const int ith = params->ith; + const int nth = params->nth; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int64_t i = ir0; i < ir1; ++i) { + const int64_t i12 = i/(ne11*ne10); + const int64_t i11 = (i - i12*ne11*ne10)/ne10; + const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); + const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); + + LM_GGML_ASSERT(i01 >= 0 && i01 < ne01); + + dequantize_row_q( + (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), + (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); + } +} + +static void lm_ggml_compute_forward_get_rows_f16( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_TENSOR_BINARY_OP_LOCALS + + const int64_t nc = ne00; + const int64_t nr = lm_ggml_nelements(src1); + + assert(ne0 == nc); + assert(ne02 == ne11); + assert(nb00 == sizeof(lm_ggml_fp16_t)); + assert(lm_ggml_nrows(dst) == nr); + + const int ith = params->ith; + const int nth = params->nth; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int64_t i = ir0; i < ir1; ++i) { + const int64_t i12 = i/(ne11*ne10); + const int64_t i11 = (i - i12*ne11*ne10)/ne10; + const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); + const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); + + LM_GGML_ASSERT(i01 >= 0 && i01 < ne01); + + lm_ggml_fp16_to_fp32_row( + (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), + (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); + } +} + +static void lm_ggml_compute_forward_get_rows_bf16( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_TENSOR_BINARY_OP_LOCALS + + const int64_t nc = ne00; + const int64_t nr = lm_ggml_nelements(src1); + + assert(ne0 == nc); + assert(ne02 == ne11); + assert(nb00 == sizeof(lm_ggml_bf16_t)); + assert(lm_ggml_nrows(dst) == nr); + + const int ith = params->ith; + const int nth = params->nth; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int64_t i = ir0; i < ir1; ++i) { + const int64_t i12 = i/(ne11*ne10); + const int64_t i11 = (i - i12*ne11*ne10)/ne10; + const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); + const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); + + LM_GGML_ASSERT(i01 >= 0 && i01 < ne01); + + lm_ggml_bf16_to_fp32_row( + (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), + (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); + } +} + +static void lm_ggml_compute_forward_get_rows_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_TENSOR_BINARY_OP_LOCALS + + const int64_t nc = ne00; + const int64_t nr = lm_ggml_nelements(src1); + + assert(ne0 == nc); + assert(ne02 == ne11); + assert(nb00 == sizeof(float)); + assert(lm_ggml_nrows(dst) == nr); + + const int ith = params->ith; + const int nth = params->nth; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int64_t i = ir0; i < ir1; ++i) { + const int64_t i12 = i/(ne11*ne10); + const int64_t i11 = (i - i12*ne11*ne10)/ne10; + const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); + const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); + + LM_GGML_ASSERT(i01 >= 0 && i01 < ne01); + + lm_ggml_vec_cpy_f32(nc, + (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), + (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03)); + } +} + +static void lm_ggml_compute_forward_get_rows( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_Q4_0: + case LM_GGML_TYPE_Q4_1: + case LM_GGML_TYPE_Q5_0: + case LM_GGML_TYPE_Q5_1: + case LM_GGML_TYPE_Q8_0: + case LM_GGML_TYPE_Q8_1: + case LM_GGML_TYPE_Q2_K: + case LM_GGML_TYPE_Q3_K: + case LM_GGML_TYPE_Q4_K: + case LM_GGML_TYPE_Q5_K: + case LM_GGML_TYPE_Q6_K: + case LM_GGML_TYPE_TQ1_0: + case LM_GGML_TYPE_TQ2_0: + case LM_GGML_TYPE_IQ2_XXS: + case LM_GGML_TYPE_IQ2_XS: + case LM_GGML_TYPE_IQ3_XXS: + case LM_GGML_TYPE_IQ1_S: + case LM_GGML_TYPE_IQ1_M: + case LM_GGML_TYPE_IQ4_NL: + case LM_GGML_TYPE_IQ4_XS: + case LM_GGML_TYPE_IQ3_S: + case LM_GGML_TYPE_IQ2_S: + case LM_GGML_TYPE_Q4_0_4_4: + case LM_GGML_TYPE_Q4_0_4_8: + case LM_GGML_TYPE_Q4_0_8_8: + { + lm_ggml_compute_forward_get_rows_q(params, dst); + } break; + case LM_GGML_TYPE_F16: + { + lm_ggml_compute_forward_get_rows_f16(params, dst); + } break; + case LM_GGML_TYPE_BF16: + { + lm_ggml_compute_forward_get_rows_bf16(params, dst); + } break; + case LM_GGML_TYPE_F32: + case LM_GGML_TYPE_I32: + { + lm_ggml_compute_forward_get_rows_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } + + //static bool first = true; + //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]); + //if (first) { + // first = false; + //} else { + // for (int k = 0; k < dst->ne[1]; ++k) { + // for (int j = 0; j < dst->ne[0]/16; ++j) { + // for (int i = 0; i < 16; ++i) { + // printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]); + // } + // printf("\n"); + // } + // printf("\n"); + // } + // printf("\n"); + // exit(0); + //} +} + +// lm_ggml_compute_forward_get_rows_back + +static void lm_ggml_compute_forward_get_rows_back_f32_f16( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + if (params->ith != 0) { + return; + } + + LM_GGML_ASSERT(lm_ggml_is_contiguous(dst)); + + // lm_ggml_compute_forward_dup_same_cont(params, opt0, dst); + + memset(dst->data, 0, lm_ggml_nbytes(dst)); + + const int nc = src0->ne[0]; + const int nr = lm_ggml_nelements(src1); + + LM_GGML_ASSERT( dst->ne[0] == nc); + LM_GGML_ASSERT(src0->nb[0] == sizeof(lm_ggml_fp16_t)); + + for (int i = 0; i < nr; ++i) { + const int r = ((int32_t *) src1->data)[i]; + + for (int j = 0; j < nc; ++j) { + lm_ggml_fp16_t v = ((lm_ggml_fp16_t *) ((char *) src0->data + i*src0->nb[1]))[j]; + ((float *) ((char *) dst->data + r*dst->nb[1]))[j] += LM_GGML_FP16_TO_FP32(v); + } + } +} + +static void lm_ggml_compute_forward_get_rows_back_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + if (params->ith != 0) { + return; + } + + LM_GGML_ASSERT(lm_ggml_is_contiguous(dst)); + + // lm_ggml_compute_forward_dup_same_cont(params, opt0, dst); + + memset(dst->data, 0, lm_ggml_nbytes(dst)); + + const int nc = src0->ne[0]; + const int nr = lm_ggml_nelements(src1); + + LM_GGML_ASSERT( dst->ne[0] == nc); + LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < nr; ++i) { + const int r = ((int32_t *) src1->data)[i]; + + lm_ggml_vec_add_f32(nc, + (float *) ((char *) dst->data + r*dst->nb[1]), + (float *) ((char *) dst->data + r*dst->nb[1]), + (float *) ((char *) src0->data + i*src0->nb[1])); + } +} + +static void lm_ggml_compute_forward_get_rows_back( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F16: + { + lm_ggml_compute_forward_get_rows_back_f32_f16(params, dst); + } break; + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_get_rows_back_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } + + //static bool first = true; + //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]); + //if (first) { + // first = false; + //} else { + // for (int k = 0; k < dst->ne[1]; ++k) { + // for (int j = 0; j < dst->ne[0]/16; ++j) { + // for (int i = 0; i < 16; ++i) { + // printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]); + // } + // printf("\n"); + // } + // printf("\n"); + // } + // printf("\n"); + // exit(0); + //} +} + +// lm_ggml_compute_forward_diag + +static void lm_ggml_compute_forward_diag_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + // TODO: handle transposed/permuted matrices + + LM_GGML_TENSOR_UNARY_OP_LOCALS + + LM_GGML_ASSERT(ne00 == ne0); + LM_GGML_ASSERT(ne00 == ne1); + LM_GGML_ASSERT(ne01 == 1); + LM_GGML_ASSERT(ne02 == ne2); + LM_GGML_ASSERT(ne03 == ne3); + + LM_GGML_ASSERT(nb00 == sizeof(float)); + LM_GGML_ASSERT(nb0 == sizeof(float)); + + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = 0; i2 < ne2; i2++) { + for (int i1 = 0; i1 < ne1; i1++) { + float * d = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + float * s = (float *)((char *) src0->data + i3*nb03 + i2*nb02); + for (int i0 = 0; i0 < i1; i0++) { + d[i0] = 0; + } + d[i1] = s[i1]; + for (int i0 = i1+1; i0 < ne0; i0++) { + d[i0] = 0; + } + } + } + } +} + +static void lm_ggml_compute_forward_diag( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_diag_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_diag_mask_inf + +static void lm_ggml_compute_forward_diag_mask_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst, + const float value) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + const int ith = params->ith; + const int nth = params->nth; + + const int n_past = ((int32_t *) dst->op_params)[0]; + const bool inplace = src0->data == dst->data; + + LM_GGML_ASSERT(n_past >= 0); + + if (!inplace) { + if (ith == 0) { + // memcpy needs to be synchronized across threads to avoid race conditions. + // => do it in INIT phase + LM_GGML_ASSERT(lm_ggml_nelements(dst) == lm_ggml_nelements(src0)); + LM_GGML_ASSERT(lm_ggml_is_contiguous(dst) && lm_ggml_is_contiguous(src0)); + memcpy( + ((char *) dst->data), + ((char *) src0->data), + lm_ggml_nbytes(dst)); + } + lm_ggml_barrier(params->threadpool); + } + + // TODO: handle transposed/permuted matrices + + const int n = lm_ggml_nrows(src0); + const int nc = src0->ne[0]; + const int nr = src0->ne[1]; + const int nz = n/nr; + + LM_GGML_ASSERT( dst->nb[0] == sizeof(float)); + LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); + + for (int k = 0; k < nz; k++) { + for (int j = ith; j < nr; j += nth) { + for (int i = n_past; i < nc; i++) { + if (i > n_past + j) { + *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = value; + } + } + } + } +} + +static void lm_ggml_compute_forward_diag_mask_inf( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_diag_mask_f32(params, dst, -INFINITY); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +static void lm_ggml_compute_forward_diag_mask_zero( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_diag_mask_f32(params, dst, 0); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_soft_max + +static void lm_ggml_compute_forward_soft_max_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + assert(lm_ggml_is_contiguous(dst)); + assert(lm_ggml_are_same_shape(src0, dst)); + + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); + + // TODO: handle transposed/permuted matrices + + const int ith = params->ith; + const int nth = params->nth; + + LM_GGML_TENSOR_UNARY_OP_LOCALS + + //const int64_t ne11 = src1 ? src1->ne[1] : 1; + + // TODO: is this supposed to be ceil instead of floor? + // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370 + const uint32_t n_head = ne02; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + const int nc = src0->ne[0]; + const int nr = lm_ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith; + + const bool use_f16 = (src1 && src1->type == LM_GGML_TYPE_F16); + + for (int i1 = ir0; i1 < ir1; i1++) { + // ALiBi + const uint32_t h = (i1/ne01)%ne02; // head + const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f; + + float * sp = (float *)((char *) src0->data + i1*src0->nb[1]); + float * dp = (float *)((char *) dst->data + i1*dst->nb[1]); + + // broadcast the mask across rows + lm_ggml_fp16_t * mp_f16 = src1 ? (lm_ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; + float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; + + lm_ggml_vec_cpy_f32 (nc, wp, sp); + lm_ggml_vec_scale_f32(nc, wp, scale); + if (mp_f32) { + if (use_f16) { + for (int i = 0; i < nc; ++i) { + wp[i] += slope*LM_GGML_FP16_TO_FP32(mp_f16[i]); + } + } else { + for (int i = 0; i < nc; ++i) { + wp[i] += slope*mp_f32[i]; + } + } + } + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + //printf("p[%d] = %f\n", i, p[i]); + assert(!isnan(wp[i])); + } +#endif + + float max = -INFINITY; + lm_ggml_vec_max_f32(nc, &max, wp); + + lm_ggml_float sum = lm_ggml_vec_soft_max_f32(nc, dp, wp, max); + assert(sum > 0.0); + + sum = 1.0/sum; + lm_ggml_vec_scale_f32(nc, dp, sum); + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + assert(!isnan(dp[i])); + assert(!isinf(dp[i])); + } +#endif + } +} + +static void lm_ggml_compute_forward_soft_max( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_soft_max_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + + +// lm_ggml_compute_forward_soft_max_back + +static void lm_ggml_compute_forward_soft_max_back_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_ASSERT(lm_ggml_is_contiguous(src0)); + LM_GGML_ASSERT(lm_ggml_is_contiguous(src1)); + LM_GGML_ASSERT(lm_ggml_is_contiguous(dst)); + LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); + LM_GGML_ASSERT(lm_ggml_are_same_shape(src1, dst)); + + // TODO: handle transposed/permuted matrices + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = lm_ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float *dy = (float *)((char *) src0->data + i1*src0->nb[1]); + float *y = (float *)((char *) src1->data + i1*src1->nb[1]); + float *dx = (float *)((char *) dst->data + i1*dst->nb[1]); + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + //printf("p[%d] = %f\n", i, p[i]); + assert(!isnan(dy[i])); + assert(!isnan(y[i])); + } +#endif + // Jii = yi - yi*yi + // Jij = -yi*yj + // J = diag(y)-y.T*y + // dx = J * dy + // dxk = sum_i(Jki * dyi) + // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk + // dxk = sum_i(-yk*yi * dyi) + yk*yk*dyk + yk*dyk - yk*yk*dyk + // dxk = sum_i(-yk*yi * dyi) + yk*dyk + // dxk = -yk * sum_i(yi * dyi) + yk*dyk + // dxk = -yk * dot(y, dy) + yk*dyk + // dxk = yk * (- dot(y, dy) + dyk) + // dxk = yk * (dyk - dot(y, dy)) + // + // post-order: + // dot_y_dy := dot(y, dy) + // dx := dy + // dx := dx - dot_y_dy + // dx := dx * y + + // linear runtime, no additional memory + float dot_y_dy = 0; + lm_ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1); + lm_ggml_vec_cpy_f32 (nc, dx, dy); + lm_ggml_vec_acc1_f32(nc, dx, -dot_y_dy); + lm_ggml_vec_mul_f32 (nc, dx, dx, y); + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + assert(!isnan(dx[i])); + assert(!isinf(dx[i])); + } +#endif + } +} + +static void lm_ggml_compute_forward_soft_max_back( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_soft_max_back_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_clamp + +static void lm_ggml_compute_forward_clamp_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + float min; + float max; + memcpy(&min, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&max, (float *) dst->op_params + 1, sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const int n = lm_ggml_nrows(src0); + const int nc = src0->ne[0]; + + const size_t nb00 = src0->nb[0]; + const size_t nb01 = src0->nb[1]; + + const size_t nb0 = dst->nb[0]; + const size_t nb1 = dst->nb[1]; + + LM_GGML_ASSERT( nb0 == sizeof(float)); + LM_GGML_ASSERT(nb00 == sizeof(float)); + + for (int j = ith; j < n; j += nth) { + float * dst_ptr = (float *) ((char *) dst->data + j*nb1); + float * src0_ptr = (float *) ((char *) src0->data + j*nb01); + + for (int i = 0; i < nc; i++) { + dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min); + } + } +} + +static void lm_ggml_compute_forward_clamp( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_clamp_f32(params, dst); + } break; + case LM_GGML_TYPE_F16: + case LM_GGML_TYPE_BF16: + case LM_GGML_TYPE_Q4_0: + case LM_GGML_TYPE_Q4_1: + case LM_GGML_TYPE_Q5_0: + case LM_GGML_TYPE_Q5_1: + case LM_GGML_TYPE_Q8_0: + case LM_GGML_TYPE_Q8_1: + case LM_GGML_TYPE_Q2_K: + case LM_GGML_TYPE_Q3_K: + case LM_GGML_TYPE_Q4_K: + case LM_GGML_TYPE_Q5_K: + case LM_GGML_TYPE_Q6_K: + case LM_GGML_TYPE_TQ1_0: + case LM_GGML_TYPE_TQ2_0: + case LM_GGML_TYPE_IQ2_XXS: + case LM_GGML_TYPE_IQ2_XS: + case LM_GGML_TYPE_IQ3_XXS: + case LM_GGML_TYPE_IQ1_S: + case LM_GGML_TYPE_IQ1_M: + case LM_GGML_TYPE_IQ4_NL: + case LM_GGML_TYPE_IQ4_XS: + case LM_GGML_TYPE_IQ3_S: + case LM_GGML_TYPE_IQ2_S: + case LM_GGML_TYPE_Q8_K: + case LM_GGML_TYPE_Q4_0_4_4: + case LM_GGML_TYPE_Q4_0_4_8: + case LM_GGML_TYPE_Q4_0_8_8: + case LM_GGML_TYPE_I8: + case LM_GGML_TYPE_I16: + case LM_GGML_TYPE_I32: + case LM_GGML_TYPE_I64: + case LM_GGML_TYPE_F64: + case LM_GGML_TYPE_COUNT: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_rope + +static float rope_yarn_ramp(const float low, const float high, const int i0) { + const float y = (i0 / 2 - low) / MAX(0.001f, high - low); + return 1 - MIN(1, MAX(0, y)); +} + +// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn +// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. +static void rope_yarn( + float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, + float * cos_theta, float * sin_theta) { + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = freq_scale * theta_extrap; + float theta = theta_interp; + if (ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; + theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale); + } + *cos_theta = cosf(theta) * mscale; + *sin_theta = sinf(theta) * mscale; +} + +static void lm_ggml_rope_cache_init( + float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale, + float * cache, float sin_sign, float theta_scale) { + // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py + float theta = theta_base; + for (int64_t i0 = 0; i0 < ne0; i0 += 2) { + const float ff = freq_factors ? freq_factors[i0/2] : 1.0f; + rope_yarn( + theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1] + ); + cache[i0 + 1] *= sin_sign; + + theta *= theta_scale; + } +} + +static void lm_ggml_compute_forward_rope_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst, + const bool forward) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + const struct lm_ggml_tensor * src2 = dst->src[2]; + + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + + //const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + //const int n_ctx = ((int32_t *) dst->op_params)[3]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; + + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + + LM_GGML_TENSOR_UNARY_OP_LOCALS + + //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); + //printf("n_past = %d, ne2 = %d\n", n_past, ne2); + + LM_GGML_ASSERT(nb00 == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = lm_ggml_nrows(dst); + + LM_GGML_ASSERT(n_dims <= ne0); + LM_GGML_ASSERT(n_dims % 2 == 0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + // row index used to determine which thread to use + int ir = 0; + + const float theta_scale = powf(freq_base, -2.0f/n_dims); + + float corr_dims[2]; + lm_ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + + const bool is_neox = mode & LM_GGML_ROPE_TYPE_NEOX; + + const float * freq_factors = NULL; + if (src2 != NULL) { + LM_GGML_ASSERT(src2->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT(src2->ne[0] >= n_dims / 2); + freq_factors = (const float *) src2->data; + } + + // backward process uses inverse rotation by cos and sin. + // cos and sin build a rotation matrix, where the inverse is the transpose. + // this essentially just switches the sign of sin. + const float sin_sign = forward ? 1.0f : -1.0f; + + const int32_t * pos = (const int32_t *) src1->data; + + for (int64_t i3 = 0; i3 < ne3; i3++) { + for (int64_t i2 = 0; i2 < ne2; i2++) { + const int64_t p = pos[i2]; + + float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith; + lm_ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); + + for (int64_t i1 = 0; i1 < ne1; i1++) { + if (ir++ < ir0) continue; + if (ir > ir1) break; + + if (!is_neox) { + for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { + const float cos_theta = cache[i0 + 0]; + const float sin_theta = cache[i0 + 1]; + + const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float x0 = src[0]; + const float x1 = src[1]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[1] = x0*sin_theta + x1*cos_theta; + } + } else { + for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { + const int64_t ic = i0/2; + + const float cos_theta = cache[i0 + 0]; + const float sin_theta = cache[i0 + 1]; + + const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + } + } + + for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) { + const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } + } + } +} + +// TODO: deduplicate f16/f32 code +static void lm_ggml_compute_forward_rope_f16( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst, + const bool forward) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + const struct lm_ggml_tensor * src2 = dst->src[2]; + + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + + //const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + //const int n_ctx = ((int32_t *) dst->op_params)[3]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + + LM_GGML_TENSOR_UNARY_OP_LOCALS + + //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); + //printf("n_past = %d, ne2 = %d\n", n_past, ne2); + + LM_GGML_ASSERT(nb0 == sizeof(lm_ggml_fp16_t)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = lm_ggml_nrows(dst); + + LM_GGML_ASSERT(n_dims <= ne0); + LM_GGML_ASSERT(n_dims % 2 == 0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + // row index used to determine which thread to use + int ir = 0; + + const float theta_scale = powf(freq_base, -2.0f/n_dims); + + float corr_dims[2]; + lm_ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + + const bool is_neox = mode & LM_GGML_ROPE_TYPE_NEOX; + + const float * freq_factors = NULL; + if (src2 != NULL) { + LM_GGML_ASSERT(src2->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT(src2->ne[0] >= n_dims / 2); + freq_factors = (const float *) src2->data; + } + + // backward process uses inverse rotation by cos and sin. + // cos and sin build a rotation matrix, where the inverse is the transpose. + // this essentially just switches the sign of sin. + const float sin_sign = forward ? 1.0f : -1.0f; + + const int32_t * pos = (const int32_t *) src1->data; + + for (int64_t i3 = 0; i3 < ne3; i3++) { + for (int64_t i2 = 0; i2 < ne2; i2++) { + const int64_t p = pos[i2]; + + float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith; + lm_ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); + + for (int64_t i1 = 0; i1 < ne1; i1++) { + if (ir++ < ir0) continue; + if (ir > ir1) break; + + if (!is_neox) { + for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { + const float cos_theta = cache[i0 + 0]; + const float sin_theta = cache[i0 + 1]; + + const lm_ggml_fp16_t * const src = (lm_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + lm_ggml_fp16_t * dst_data = (lm_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float x0 = LM_GGML_FP16_TO_FP32(src[0]); + const float x1 = LM_GGML_FP16_TO_FP32(src[1]); + + dst_data[0] = LM_GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); + dst_data[1] = LM_GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); + } + } else { + for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { + const int64_t ic = i0/2; + + const float cos_theta = cache[i0 + 0]; + const float sin_theta = cache[i0 + 1]; + + const lm_ggml_fp16_t * const src = (lm_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + lm_ggml_fp16_t * dst_data = (lm_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = LM_GGML_FP16_TO_FP32(src[0]); + const float x1 = LM_GGML_FP16_TO_FP32(src[n_dims/2]); + + dst_data[0] = LM_GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); + dst_data[n_dims/2] = LM_GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); + } + } + + for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) { + const lm_ggml_fp16_t * const src = (lm_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + lm_ggml_fp16_t * dst_data = (lm_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } + } + } +} + +static void lm_ggml_compute_forward_rope( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F16: + { + lm_ggml_compute_forward_rope_f16(params, dst, true); + } break; + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_rope_f32(params, dst, true); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_rope_back + +static void lm_ggml_compute_forward_rope_back( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F16: + { + lm_ggml_compute_forward_rope_f16(params, dst, false); + } break; + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_rope_f32(params, dst, false); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_conv_transpose_1d + +static void lm_ggml_compute_forward_conv_transpose_1d_f16_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F16); + LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT( dst->type == LM_GGML_TYPE_F32); + + LM_GGML_TENSOR_BINARY_OP_LOCALS + + const int ith = params->ith; + const int nth = params->nth; + + const int nk = ne00*ne01*ne02; + + LM_GGML_ASSERT(nb00 == sizeof(lm_ggml_fp16_t)); + LM_GGML_ASSERT(nb10 == sizeof(float)); + + if (ith == 0) { + memset(params->wdata, 0, params->wsize); + + // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout) + { + lm_ggml_fp16_t * const wdata = (lm_ggml_fp16_t *) params->wdata + 0; + + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + const lm_ggml_fp16_t * const src = (lm_ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01); + lm_ggml_fp16_t * dst_data = wdata + i01*ne00*ne02; + for (int64_t i00 = 0; i00 < ne00; i00++) { + dst_data[i00*ne02 + i02] = src[i00]; + } + } + } + } + + // permute source data (src1) from (L x Cin) to (Cin x L) + { + lm_ggml_fp16_t * const wdata = (lm_ggml_fp16_t *) params->wdata + nk; + lm_ggml_fp16_t * dst_data = wdata; + + for (int64_t i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i11*nb11); + for (int64_t i10 = 0; i10 < ne10; i10++) { + dst_data[i10*ne11 + i11] = LM_GGML_FP32_TO_FP16(src[i10]); + } + } + } + + // need to zero dst since we are accumulating into it + memset(dst->data, 0, lm_ggml_nbytes(dst)); + } + lm_ggml_barrier(params->threadpool); + + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + + // total rows in dst + const int nr = ne1; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + lm_ggml_fp16_t * const wdata = (lm_ggml_fp16_t *) params->wdata + 0; + lm_ggml_fp16_t * const wdata_src = wdata + nk; + + for (int i1 = ir0; i1 < ir1; i1++) { + float * dst_data = (float *)((char *) dst->data + i1*nb1); + lm_ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00; + for (int i10 = 0; i10 < ne10; i10++) { + const int i1n = i10*ne11; + for (int i00 = 0; i00 < ne00; i00++) { + float v = 0; + lm_ggml_vec_dot_f16(ne02, &v, 0, + (lm_ggml_fp16_t *) wdata_src + i1n, 0, + (lm_ggml_fp16_t *) wdata_kernel + i00*ne02, 0, 1); + dst_data[i10*s0 + i00] += v; + } + } + } +} + +static void lm_ggml_compute_forward_conv_transpose_1d_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT( dst->type == LM_GGML_TYPE_F32); + + LM_GGML_TENSOR_BINARY_OP_LOCALS + + const int ith = params->ith; + const int nth = params->nth; + + const int nk = ne00*ne01*ne02; + + LM_GGML_ASSERT(nb00 == sizeof(float)); + LM_GGML_ASSERT(nb10 == sizeof(float)); + + if (ith == 0) { + memset(params->wdata, 0, params->wsize); + + // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout) + { + float * const wdata = (float *) params->wdata + 0; + + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01); + float * dst_data = wdata + i01*ne00*ne02; + for (int64_t i00 = 0; i00 < ne00; i00++) { + dst_data[i00*ne02 + i02] = src[i00]; + } + } + } + } + + // prepare source data (src1) + { + float * const wdata = (float *) params->wdata + nk; + float * dst_data = wdata; + + for (int64_t i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i11*nb11); + for (int64_t i10 = 0; i10 < ne10; i10++) { + dst_data[i10*ne11 + i11] = src[i10]; + } + } + } + + // need to zero dst since we are accumulating into it + memset(dst->data, 0, lm_ggml_nbytes(dst)); + } + lm_ggml_barrier(params->threadpool); + + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + + // total rows in dst + const int nr = ne1; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + float * const wdata = (float *) params->wdata + 0; + float * const wdata_src = wdata + nk; + + for (int i1 = ir0; i1 < ir1; i1++) { + float * dst_data = (float *)((char *) dst->data + i1*nb1); + float * wdata_kernel = wdata + i1*ne02*ne00; + for (int i10 = 0; i10 < ne10; i10++) { + const int i1n = i10*ne11; + for (int i00 = 0; i00 < ne00; i00++) { + float v = 0; + lm_ggml_vec_dot_f32(ne02, &v, 0, + wdata_src + i1n, 0, + wdata_kernel + i00*ne02, 0, 1); + dst_data[i10*s0 + i00] += v; + } + } + } +} + +static void lm_ggml_compute_forward_conv_transpose_1d( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F16: + { + lm_ggml_compute_forward_conv_transpose_1d_f16_f32(params, dst); + } break; + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_conv_transpose_1d_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_im2col_f32 +// src0: kernel [OC, IC, KH, KW] +// src1: image [N, IC, IH, IW] +// dst: result [N, OH, OW, IC*KH*KW] +static void lm_ggml_compute_forward_im2col_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT( dst->type == LM_GGML_TYPE_F32); + + LM_GGML_TENSOR_BINARY_OP_LOCALS; + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; + const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t N = is_2D ? ne13 : ne12; + const int64_t IC = is_2D ? ne12 : ne11; + const int64_t IH = is_2D ? ne11 : 1; + const int64_t IW = ne10; + + const int64_t KH = is_2D ? ne01 : 1; + const int64_t KW = ne00; + + const int64_t OH = is_2D ? ne2 : 1; + const int64_t OW = ne1; + + int ofs0 = is_2D ? nb13 : nb12; + int ofs1 = is_2D ? nb12 : nb11; + + LM_GGML_ASSERT(nb10 == sizeof(float)); + + // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] + { + float * const wdata = (float *) dst->data; + + for (int64_t in = 0; in < N; in++) { + for (int64_t ioh = 0; ioh < OH; ioh++) { // 1 + for (int64_t iow = 0; iow < OW; iow++) { + for (int64_t iic = ith; iic < IC; iic += nth) { + + // micro kernel + float * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW] + const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW] + + for (int64_t ikh = 0; ikh < KH; ikh++) { // 1 + for (int64_t ikw = 0; ikw < KW; ikw++) { + const int64_t iiw = iow*s0 + ikw*d0 - p0; + const int64_t iih = ioh*s1 + ikh*d1 - p1; + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0; + } else { + dst_data[iic*(KH*KW) + ikh*KW + ikw] = (src_data[iih*IW + iiw]); + } + } + } + } + } + } + } + } +} + + +// lm_ggml_compute_forward_im2col_f16 +// src0: kernel [OC, IC, KH, KW] +// src1: image [N, IC, IH, IW] +// dst: result [N, OH, OW, IC*KH*KW] +static void lm_ggml_compute_forward_im2col_f16( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F16); + LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT( dst->type == LM_GGML_TYPE_F16); + + LM_GGML_TENSOR_BINARY_OP_LOCALS; + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; + const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t N = is_2D ? ne13 : ne12; + const int64_t IC = is_2D ? ne12 : ne11; + const int64_t IH = is_2D ? ne11 : 1; + const int64_t IW = ne10; + + const int64_t KH = is_2D ? ne01 : 1; + const int64_t KW = ne00; + + const int64_t OH = is_2D ? ne2 : 1; + const int64_t OW = ne1; + + int ofs0 = is_2D ? nb13 : nb12; + int ofs1 = is_2D ? nb12 : nb11; + + LM_GGML_ASSERT(nb00 == sizeof(lm_ggml_fp16_t)); + LM_GGML_ASSERT(nb10 == sizeof(float)); + + // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] + { + lm_ggml_fp16_t * const wdata = (lm_ggml_fp16_t *) dst->data; + + for (int64_t in = 0; in < N; in++) { + for (int64_t ioh = 0; ioh < OH; ioh++) { // 1 + for (int64_t iow = 0; iow < OW; iow++) { + for (int64_t iic = ith; iic < IC; iic += nth) { + + // micro kernel + lm_ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW] + const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW] + + for (int64_t ikh = 0; ikh < KH; ikh++) { // 1 + for (int64_t ikw = 0; ikw < KW; ikw++) { + const int64_t iiw = iow*s0 + ikw*d0 - p0; + const int64_t iih = ioh*s1 + ikh*d1 - p1; + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0; + } else { + dst_data[iic*(KH*KW) + ikh*KW + ikw] = LM_GGML_FP32_TO_FP16(src_data[iih*IW + iiw]); + } + } + } + } + } + } + } + } +} + +static void lm_ggml_compute_forward_im2col( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + switch (dst->type) { + case LM_GGML_TYPE_F16: + { + lm_ggml_compute_forward_im2col_f16(params, dst); + } break; + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_im2col_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_im2col_back_f32 + +static void lm_ggml_compute_forward_im2col_back_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT( dst->type == LM_GGML_TYPE_F32); + + LM_GGML_TENSOR_BINARY_OP_LOCALS; + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; + const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t N = is_2D ? ne3 : ne2; + const int64_t IC = is_2D ? ne2 : ne1; + const int64_t IH = is_2D ? ne1 : 1; + const int64_t IW = ne0; + + const int64_t KH = is_2D ? ne01 : 1; + const int64_t KW = ne00; + + const int64_t OH = is_2D ? ne12 : 1; + const int64_t OW = ne11; + + int ofs0 = is_2D ? nb3 : nb2; + int ofs1 = is_2D ? nb2 : nb1; + + LM_GGML_ASSERT(nb0 == sizeof(float)); + + // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] + { + float * const wdata = (float *) dst->data; + + for (int64_t in = 0; in < N; in++) { + for (int64_t iic = ith; iic < IC; iic += nth) { + for (int64_t iih = 0; iih < IH; iih++) { + for (int64_t iiw = 0; iiw < IW; iiw++) { + + // micro kernel + float grad = 0.0f; + for (int64_t ikh = 0; ikh < KH; ikh++) { + for (int64_t ikw = 0; ikw < KW; ikw++) { + // For s0 > 1 some values were skipped over in the forward pass. + // These values have tmpw % s0 != 0 and need to be skipped in the backwards pass as well. + const int64_t tmpw = (iiw + p0 - ikw*d0); + if (tmpw % s0 != 0) { + continue; + } + const int64_t iow = tmpw / s0; + + // Equivalent logic as above except for s1. + int64_t ioh; + if (is_2D) { + const int64_t tmph = iih + p1 - ikh*d1; + + if (tmph % s1 != 0) { + continue; + } + + ioh = tmph / s1; + } else { + ioh = 0; + } + + if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) { + continue; + } + + const float * const src_data = (const float *) src1->data + + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW] + grad += src_data[iic*(KH*KW) + ikh*KW + ikw]; + } + } + float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW] + dst_data[iih*IW + iiw] = grad; + } + } + } + } + } +} + +// lm_ggml_compute_forward_conv_transpose_2d + +static void lm_ggml_compute_forward_conv_transpose_2d( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F16); + LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT( dst->type == LM_GGML_TYPE_F32); + + LM_GGML_TENSOR_BINARY_OP_LOCALS + + const int ith = params->ith; + const int nth = params->nth; + + const int nk = ne00*ne01*ne02*ne03; + + LM_GGML_ASSERT(nb00 == sizeof(lm_ggml_fp16_t)); + LM_GGML_ASSERT(nb10 == sizeof(float)); + + if (ith == 0) { + memset(params->wdata, 0, params->wsize); + + // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout) + { + lm_ggml_fp16_t * const wdata = (lm_ggml_fp16_t *) params->wdata + 0; + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + const lm_ggml_fp16_t * const src = (lm_ggml_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02); + lm_ggml_fp16_t * dst_data = wdata + i02*ne01*ne00*ne03; + for (int64_t i01 = 0; i01 < ne01; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + dst_data[i01*ne00*ne03 + i00*ne03 + i03] = src[i01 * ne00 + i00]; + } + } + } + } + } + + // permute source data (src1) from (Sw x Sh x Cin) to (Cin x Sw x Sh) + { + lm_ggml_fp16_t * const wdata = (lm_ggml_fp16_t *) params->wdata + nk; + for (int i12 = 0; i12 < ne12; i12++) { + for (int i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11); + lm_ggml_fp16_t * dst_data = wdata + i11*ne10*ne12; + for (int i10 = 0; i10 < ne10; i10++) { + dst_data[i10*ne12 + i12] = LM_GGML_FP32_TO_FP16(src[i10]); + } + } + } + } + + memset(dst->data, 0, lm_ggml_nbytes(dst)); + } + lm_ggml_barrier(params->threadpool); + + const int32_t stride = lm_ggml_get_op_params_i32(dst, 0); + + // total patches in dst + const int np = ne2; + + // patches per thread + const int dp = (np + nth - 1)/nth; + + // patch range for this thread + const int ip0 = dp*ith; + const int ip1 = MIN(ip0 + dp, np); + + lm_ggml_fp16_t * const wdata = (lm_ggml_fp16_t *) params->wdata + 0; + lm_ggml_fp16_t * const wdata_src = wdata + nk; + + for (int i2 = ip0; i2 < ip1; i2++) { // Cout + float * dst_data = (float *)((char *) dst->data + i2*nb2); + lm_ggml_fp16_t * wdata_kernel = wdata + i2*ne01*ne00*ne03; + for (int i11 = 0; i11 < ne11; i11++) { + for (int i10 = 0; i10 < ne10; i10++) { + const int i1n = i11*ne10*ne12 + i10*ne12; + for (int i01 = 0; i01 < ne01; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + float v = 0; + lm_ggml_vec_dot_f16(ne03, &v, 0, + wdata_src + i1n, 0, + wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1); + dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v; + } + } + } + } + } +} + +// lm_ggml_compute_forward_pool_1d_sk_p0 + +static void lm_ggml_compute_forward_pool_1d_sk_p0( + const struct lm_ggml_compute_params * params, + const enum lm_ggml_op_pool op, + const int k, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src = dst->src[0]; + + assert(src->type == LM_GGML_TYPE_F32 || src->type == LM_GGML_TYPE_F16); + + if (params->ith != 0) { + return; + } + + const char * cdata = (const char *)src->data; + const char * const data_end = cdata + lm_ggml_nbytes(src); + float * drow = (float *)dst->data; + + const int64_t rs = dst->ne[0]; + + while (cdata < data_end) { + const void * srow = (const void *)cdata; + int j = 0; + for (int64_t i = 0; i < rs; ++i) { + switch (op) { + case LM_GGML_OP_POOL_AVG: drow[i] = 0; break; + case LM_GGML_OP_POOL_MAX: drow[i] = -FLT_MAX; break; + case LM_GGML_OP_POOL_COUNT: LM_GGML_ABORT("fatal error"); + } + for (int ki = 0; ki < k; ++ki) { + const float srow_j = (src->type == LM_GGML_TYPE_F32) ? ((const float*)srow)[j] : LM_GGML_FP16_TO_FP32(((const lm_ggml_fp16_t*)srow)[j]); + switch (op) { + case LM_GGML_OP_POOL_AVG: drow[i] += srow_j; break; + case LM_GGML_OP_POOL_MAX: if (srow_j > drow[i]) drow[i] = srow_j; break; + case LM_GGML_OP_POOL_COUNT: LM_GGML_ABORT("fatal error"); + } + ++j; + } + switch (op) { + case LM_GGML_OP_POOL_AVG: drow[i] /= k; break; + case LM_GGML_OP_POOL_MAX: break; + case LM_GGML_OP_POOL_COUNT: LM_GGML_ABORT("fatal error"); + } + } + + cdata += src->nb[1]; + drow += rs; + } +} + +// lm_ggml_compute_forward_pool_1d + +static void lm_ggml_compute_forward_pool_1d( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const int32_t * opts = (const int32_t *)dst->op_params; + enum lm_ggml_op_pool op = opts[0]; + const int k0 = opts[1]; + const int s0 = opts[2]; + const int p0 = opts[3]; + LM_GGML_ASSERT(p0 == 0); // padding not supported + LM_GGML_ASSERT(k0 == s0); // only s = k supported + + lm_ggml_compute_forward_pool_1d_sk_p0(params, op, k0, dst); +} + +// lm_ggml_compute_forward_pool_2d + +static void lm_ggml_compute_forward_pool_2d( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src = dst->src[0]; + + assert(src->type == LM_GGML_TYPE_F32 || src->type == LM_GGML_TYPE_F16); + + if (params->ith != 0) { + return; + } + + const int32_t * opts = (const int32_t *)dst->op_params; + enum lm_ggml_op_pool op = opts[0]; + const int k0 = opts[1]; + const int k1 = opts[2]; + const int s0 = opts[3]; + const int s1 = opts[4]; + const int p0 = opts[5]; + const int p1 = opts[6]; + const char * cdata = (const char*)src->data; + const char * const data_end = cdata + lm_ggml_nbytes(src); + + const int64_t px = dst->ne[0]; + const int64_t py = dst->ne[1]; + const int64_t pa = px * py; + + float * dplane = (float *)dst->data; + + const int ka = k0 * k1; + const int offset0 = -p0; + const int offset1 = -p1; + + while (cdata < data_end) { + for (int oy = 0; oy < py; ++oy) { + float * const drow = dplane + oy * px; + for (int ox = 0; ox < px; ++ox) { + float * const out = drow + ox; + switch (op) { + case LM_GGML_OP_POOL_AVG: *out = 0; break; + case LM_GGML_OP_POOL_MAX: *out = -FLT_MAX; break; + case LM_GGML_OP_POOL_COUNT: LM_GGML_ABORT("fatal error"); + } + + const int ix = offset0 + ox * s0; + const int iy = offset1 + oy * s1; + + for (int ky = 0; ky < k1; ++ky) { + if (iy + ky < 0 || iy + ky >= src->ne[1]) continue; + const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky)); + for (int kx = 0; kx < k0; ++kx) { + int j = ix + kx; + if (j < 0 || j >= src->ne[0]) continue; + const float srow_j = (src->type == LM_GGML_TYPE_F32) ? ((const float*)srow)[j] : LM_GGML_FP16_TO_FP32(((const lm_ggml_fp16_t*)srow)[j]); + switch (op) { + case LM_GGML_OP_POOL_AVG: *out += srow_j; break; + case LM_GGML_OP_POOL_MAX: if (srow_j > *out) *out = srow_j; break; + case LM_GGML_OP_POOL_COUNT: LM_GGML_ABORT("fatal error"); + } + } + } + switch (op) { + case LM_GGML_OP_POOL_AVG: *out /= ka; break; + case LM_GGML_OP_POOL_MAX: break; + case LM_GGML_OP_POOL_COUNT: LM_GGML_ABORT("fatal error"); + } + } + } + + cdata += src->nb[2]; + dplane += pa; + } +} + +// lm_ggml_compute_forward_pool_2d_back + +static void lm_ggml_compute_forward_pool_2d_back( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src = dst->src[0]; + const struct lm_ggml_tensor * dstf = dst->src[1]; // forward tensor of dst + + assert(dst->type == LM_GGML_TYPE_F32 || dst->type == LM_GGML_TYPE_F16); + + if (params->ith != 0) { + return; + } + + const int32_t * opts = (const int32_t *)dst->op_params; + enum lm_ggml_op_pool op = opts[0]; + const int k0 = opts[1]; + const int k1 = opts[2]; + const int s0 = opts[3]; + const int s1 = opts[4]; + const int p0 = opts[5]; + const int p1 = opts[6]; + + char * cdata = (char *) dst->data; + const char * cdataf = (const char *) dstf->data; + const char * const data_end = cdata + lm_ggml_nbytes(dst); + + LM_GGML_ASSERT(params->ith == 0); + memset(cdata, 0, lm_ggml_nbytes(dst)); + + const int64_t px = src->ne[0]; + const int64_t py = src->ne[1]; + const int64_t pa = px * py; + + const float * splane = (const float *) src->data; + + const int ka = k0 * k1; + const int offset0 = -p0; + const int offset1 = -p1; + + while (cdata < data_end) { + for (int oy = 0; oy < py; ++oy) { + const float * const srow = splane + oy * px; + for (int ox = 0; ox < px; ++ox) { + const float grad0 = srow[ox]; + + const int ix = offset0 + ox * s0; + const int iy = offset1 + oy * s1; + + if (op == LM_GGML_OP_POOL_MAX) { + float maxval = -FLT_MAX; + int kxmax = -1; + int kymax = -1; + + for (int ky = 0; ky < k1; ++ky) { + if (iy + ky < 0 || iy + ky >= dst->ne[1]) { + continue; + } + const void * drowf = (const void *)(cdataf + dst->nb[1] * (iy + ky)); + for (int kx = 0; kx < k0; ++kx) { + int j = ix + kx; + if (j < 0 || j >= dst->ne[0]) { + continue; + } + + const float val = dst->type == LM_GGML_TYPE_F32 ? + ((const float *) drowf)[j] : LM_GGML_FP16_TO_FP32(((const lm_ggml_fp16_t *) drowf)[j]); + if (val <= maxval) { + continue; + } + + maxval = val; + kxmax = kx; + kymax = ky; + } + } + + if (kxmax == -1 || kymax == -1) { + continue; + } + + void * drow = (void *)(cdata + dst->nb[1] * (iy + kymax)); + const int j = ix + kxmax; + if (dst->type == LM_GGML_TYPE_F32) { + ((float *) drow)[j] += grad0; + } else { + ((lm_ggml_fp16_t *) drow)[j] = LM_GGML_FP32_TO_FP16(grad0 + LM_GGML_FP16_TO_FP32(((const lm_ggml_fp16_t *) drow)[j])); + } + } else if (op == LM_GGML_OP_POOL_AVG) { + const float grad = grad0 / ka; + + for (int ky = 0; ky < k1; ++ky) { + if (iy + ky < 0 || iy + ky >= dst->ne[1]) { + continue; + } + void * drow = (void *)(cdata + dst->nb[1] * (iy + ky)); + for (int kx = 0; kx < k0; ++kx) { + int j = ix + kx; + if (j < 0 || j >= dst->ne[0]) { + continue; + } + + if (dst->type == LM_GGML_TYPE_F32) { + ((float *) drow)[j] += grad; + } else { + ((lm_ggml_fp16_t *) drow)[j] += LM_GGML_FP32_TO_FP16(grad); + } + } + } + } else { + LM_GGML_ASSERT(false); + } + } + } + + cdata += dst->nb[2]; + cdataf += dst->nb[2]; + splane += pa; + } +} + +// lm_ggml_compute_forward_upscale + +static void lm_ggml_compute_forward_upscale_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32); + + const int ith = params->ith; + const int nth = params->nth; + + LM_GGML_TENSOR_UNARY_OP_LOCALS + + const float sf0 = (float)ne0/src0->ne[0]; + const float sf1 = (float)ne1/src0->ne[1]; + const float sf2 = (float)ne2/src0->ne[2]; + const float sf3 = (float)ne3/src0->ne[3]; + + // TODO: optimize + + for (int64_t i3 = 0; i3 < ne3; i3++) { + const int64_t i03 = i3 / sf3; + for (int64_t i2 = ith; i2 < ne2; i2 += nth) { + const int64_t i02 = i2 / sf2; + for (int64_t i1 = 0; i1 < ne1; i1++) { + const int64_t i01 = i1 / sf1; + for (int64_t i0 = 0; i0 < ne0; i0++) { + const int64_t i00 = i0 / sf0; + + const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3); + + *y = *x; + } + } + } + } +} + +static void lm_ggml_compute_forward_upscale( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_upscale_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + + +// lm_ggml_compute_forward_pad + +static void lm_ggml_compute_forward_pad_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); + LM_GGML_ASSERT( dst->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + LM_GGML_TENSOR_UNARY_OP_LOCALS + + float * dst_ptr = (float *) dst->data; + + // TODO: optimize + + for (int64_t i2 = 0; i2 < ne2; ++i2) { + for (int64_t i1 = ith; i1 < ne1; i1 += nth) { + for (int64_t i0 = 0; i0 < ne0; ++i0) { + for (int64_t i3 = 0; i3 < ne3; ++i3) { + const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0; + + const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + dst_ptr[dst_idx] = *src_ptr; + } else { + dst_ptr[dst_idx] = 0; + } + } + } + } + } +} + +static void lm_ggml_compute_forward_pad( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_pad_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + + +// lm_ggml_compute_forward_arange + +static void lm_ggml_compute_forward_arange_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + LM_GGML_ASSERT(dst->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const float start = lm_ggml_get_op_params_f32(dst, 0); + const float stop = lm_ggml_get_op_params_f32(dst, 1); + const float step = lm_ggml_get_op_params_f32(dst, 2); + + const int64_t steps = (int64_t) ceilf((stop - start) / step); + + LM_GGML_ASSERT(lm_ggml_nelements(dst) == steps); + + for (int64_t i = ith; i < steps; i+= nth) { + float value = start + step * i; + ((float *)dst->data)[i] = value; + } +} + +static void lm_ggml_compute_forward_arange( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + switch (dst->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_arange_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +static void lm_ggml_compute_forward_timestep_embedding_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + LM_GGML_TENSOR_UNARY_OP_LOCALS + + const int dim = lm_ggml_get_op_params_i32(dst, 0); + const int max_period = lm_ggml_get_op_params_i32(dst, 1); + + int half = dim / 2; + + for (int64_t i = 0; i < ne00; i++) { + float * embed_data = (float *)((char *) dst->data + i*nb1); + for (int64_t j = ith; j < half; j += nth) { + float timestep = ((float *)src0->data)[i]; + float freq = (float)expf(-logf(max_period) * j / half); + float arg = timestep * freq; + embed_data[j] = cosf(arg); + embed_data[j + half] = sinf(arg); + } + if (dim % 2 != 0 && ith == 0) { + embed_data[dim] = 0.f; + } + } +} + +static void lm_ggml_compute_forward_timestep_embedding( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_timestep_embedding_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_argsort + +static void lm_ggml_compute_forward_argsort_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + LM_GGML_TENSOR_UNARY_OP_LOCALS + + LM_GGML_ASSERT(nb0 == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t nr = lm_ggml_nrows(src0); + + enum lm_ggml_sort_order order = (enum lm_ggml_sort_order) lm_ggml_get_op_params_i32(dst, 0); + + for (int64_t i = ith; i < nr; i += nth) { + int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1); + const float * src_data = (float *)((char *) src0->data + i*nb01); + + for (int64_t j = 0; j < ne0; j++) { + dst_data[j] = j; + } + + // C doesn't have a functional sort, so we do a bubble sort instead + for (int64_t j = 0; j < ne0; j++) { + for (int64_t k = j + 1; k < ne0; k++) { + if ((order == LM_GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) || + (order == LM_GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) { + int32_t tmp = dst_data[j]; + dst_data[j] = dst_data[k]; + dst_data[k] = tmp; + } + } + } + } +} + +static void lm_ggml_compute_forward_argsort( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_argsort_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_flash_attn_ext + +static void lm_ggml_compute_forward_flash_attn_ext_f16( + const struct lm_ggml_compute_params * params, + const struct lm_ggml_tensor * q, + const struct lm_ggml_tensor * k, + const struct lm_ggml_tensor * v, + const struct lm_ggml_tensor * mask, + struct lm_ggml_tensor * dst) { + + LM_GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + LM_GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + LM_GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + LM_GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + LM_GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + LM_GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + LM_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + LM_GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t D = neq0; + const int64_t N = neq1; + + LM_GGML_ASSERT(ne0 == D); + LM_GGML_ASSERT(ne2 == N); + + // input tensor rows must be contiguous + LM_GGML_ASSERT(nbq0 == lm_ggml_type_size(q->type)); + LM_GGML_ASSERT(nbk0 == lm_ggml_type_size(k->type)); + LM_GGML_ASSERT(nbv0 == lm_ggml_type_size(v->type)); + + LM_GGML_ASSERT(neq0 == D); + LM_GGML_ASSERT(nek0 == D); + LM_GGML_ASSERT(nev0 == D); + + LM_GGML_ASSERT(neq1 == N); + LM_GGML_ASSERT(nev0 == D); + + // dst cannot be transposed or permuted + LM_GGML_ASSERT(nb0 == sizeof(float)); + LM_GGML_ASSERT(nb0 <= nb1); + LM_GGML_ASSERT(nb1 <= nb2); + LM_GGML_ASSERT(nb2 <= nb3); + + // broadcast factors + const int64_t rk2 = neq2/nek2; + const int64_t rk3 = neq3/nek3; + + const int64_t rv2 = neq2/nev2; + const int64_t rv3 = neq3/nev3; + + // parallelize by q rows using lm_ggml_vec_dot_f32 + + // total rows in q + const int nr = neq1*neq2*neq3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float)); + + if (logit_softcap != 0) { + scale /= logit_softcap; + } + + const uint32_t n_head = neq2; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + enum lm_ggml_type const k_vec_dot_type = type_traits_cpu[k->type].vec_dot_type; + lm_ggml_from_float_t const q_to_vec_dot = type_traits_cpu[k_vec_dot_type].from_float; + lm_ggml_vec_dot_t const kq_vec_dot = type_traits_cpu[k->type].vec_dot; + lm_ggml_to_float_t const v_to_float = lm_ggml_get_type_traits(v->type)->to_float; + + LM_GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type"); + LM_GGML_ASSERT(v_to_float && "fattn: unsupported V-type"); + + // loop over n_batch and n_head + for (int ir = ir0; ir < ir1; ++ir) { + // q indices + const int iq3 = ir/(neq2*neq1); + const int iq2 = (ir - iq3*neq2*neq1)/neq1; + const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); + + const uint32_t h = iq2; // head index + const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f; + + float S = 0.0f; // sum + float M = -INFINITY; // maximum KQ value + + float * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator + float * V32 = (VKQ32 + 1*D); // (temporary) FP32 V buffer + lm_ggml_fp16_t * VKQ16 = (lm_ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator + lm_ggml_fp16_t * Q_q = (lm_ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16 + + if (v->type == LM_GGML_TYPE_F16) { + memset(VKQ16, 0, D*sizeof(lm_ggml_fp16_t)); + } else { + memset(VKQ32, 0, D*sizeof(float)); + } + + const lm_ggml_fp16_t * mp = mask ? (lm_ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL; + + // k indices + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; + + // v indices + const int iv3 = iq3 / rv3; + const int iv2 = iq2 / rv2; + + const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)); + q_to_vec_dot(pq, Q_q, D); + + // online softmax / attention + // loop over n_kv and n_head_kv + // ref: https://arxiv.org/pdf/2112.05682.pdf + for (int64_t ic = 0; ic < nek1; ++ic) { + const float mv = mp ? slope*LM_GGML_FP16_TO_FP32(mp[ic]) : 0.0f; + if (mv == -INFINITY) { + continue; + } + + float s; // KQ value + + const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3); + kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1); + + s = s*scale; // scale KQ value + + if (logit_softcap != 0.0f) { + s = logit_softcap*tanhf(s); + } + + s += mv; // apply mask + + const float Mold = M; + + float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value + float vs = 1.0f; // post-softmax KQ value, expf(s - M) + + const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3)); + + if (v->type == LM_GGML_TYPE_F16) { + if (s > M) { + // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f + M = s; + ms = expf(Mold - M); + + // V = V*expf(Mold - M) + lm_ggml_vec_scale_f16(D, VKQ16, ms); + } else { + // no new maximum, ms == 1.0f, vs != 1.0f + vs = expf(s - M); + } + + // V += v*expf(s - M) + lm_ggml_vec_mad_f16(D, VKQ16, (const lm_ggml_fp16_t *) v_data, vs); + } else { + if (s > M) { + // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f + M = s; + ms = expf(Mold - M); + + // V = V*expf(Mold - M) + lm_ggml_vec_scale_f32(D, VKQ32, ms); + } else { + // no new maximum, ms == 1.0f, vs != 1.0f + vs = expf(s - M); + } + + v_to_float(v_data, V32, D); + + // V += v*expf(s - M) + lm_ggml_vec_mad_f32(D, VKQ32, V32, vs); + } + + S = S*ms + vs; // scale and increment sum with partial sum + } + + if (v->type == LM_GGML_TYPE_F16) { + for (int64_t d = 0; d < D; ++d) { + VKQ32[d] = LM_GGML_FP16_TO_FP32(VKQ16[d]); + } + } + + // V /= S + const float S_inv = 1.0f/S; + lm_ggml_vec_scale_f32(D, VKQ32, S_inv); + + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + // original + //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float)); + + // permute(0, 2, 1, 3) + memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1); + } +} + +static void lm_ggml_compute_forward_flash_attn_ext( + const struct lm_ggml_compute_params * params, + const struct lm_ggml_tensor * q, + const struct lm_ggml_tensor * k, + const struct lm_ggml_tensor * v, + const struct lm_ggml_tensor * mask, + struct lm_ggml_tensor * dst) { + switch (dst->op_params[3]) { + case LM_GGML_PREC_DEFAULT: + case LM_GGML_PREC_F32: + { + // uses F32 accumulators + lm_ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_flash_attn_back + +static void lm_ggml_compute_forward_flash_attn_back_f32( + const struct lm_ggml_compute_params * params, + const bool masked, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * q = dst->src[0]; + const struct lm_ggml_tensor * k = dst->src[1]; + const struct lm_ggml_tensor * v = dst->src[2]; + const struct lm_ggml_tensor * d = dst->src[3]; + + LM_GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + LM_GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + LM_GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + LM_GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + LM_GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + LM_GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + LM_GGML_TENSOR_LOCALS(int64_t, ned, d, ne) + LM_GGML_TENSOR_LOCALS(size_t, nbd, d, nb) + LM_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + LM_GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t D = neq0; + const int64_t N = neq1; + const int64_t P = nek1 - N; + const int64_t M = P + N; + + const int Mup = lm_ggml_up(M, LM_GGML_SOFT_MAX_UNROLL); + const int mxDM = MAX(D, Mup); + + // LM_GGML_ASSERT(ne0 == D); + // LM_GGML_ASSERT(ne1 == N); + LM_GGML_ASSERT(P >= 0); + + LM_GGML_ASSERT(nbq0 == sizeof(float)); + LM_GGML_ASSERT(nbk0 == sizeof(float)); + LM_GGML_ASSERT(nbv0 == sizeof(float)); + + LM_GGML_ASSERT(neq0 == D); + LM_GGML_ASSERT(nek0 == D); + LM_GGML_ASSERT(nev1 == D); + LM_GGML_ASSERT(ned0 == D); + + LM_GGML_ASSERT(neq1 == N); + LM_GGML_ASSERT(nek1 == N + P); + LM_GGML_ASSERT(nev1 == D); + LM_GGML_ASSERT(ned1 == N); + + // dst cannot be transposed or permuted + LM_GGML_ASSERT(nb0 == sizeof(float)); + LM_GGML_ASSERT(nb0 <= nb1); + LM_GGML_ASSERT(nb1 <= nb2); + LM_GGML_ASSERT(nb2 <= nb3); + + if (ith == 0) { + memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3); + } + lm_ggml_barrier(params->threadpool); + + const int64_t elem_q = lm_ggml_nelements(q); + const int64_t elem_k = lm_ggml_nelements(k); + + enum lm_ggml_type result_type = dst->type; + LM_GGML_ASSERT(lm_ggml_blck_size(result_type) == 1); + const size_t tsize = lm_ggml_type_size(result_type); + + const size_t offs_q = 0; + const size_t offs_k = offs_q + LM_GGML_PAD(elem_q * tsize, LM_GGML_MEM_ALIGN); + const size_t offs_v = offs_k + LM_GGML_PAD(elem_k * tsize, LM_GGML_MEM_ALIGN); + + void * grad_q = (char *) dst->data; + void * grad_k = (char *) dst->data + offs_k; + void * grad_v = (char *) dst->data + offs_v; + + const size_t nbgq1 = nb0*neq0; + const size_t nbgq2 = nb0*neq0*neq1; + const size_t nbgq3 = nb0*neq0*neq1*neq2; + + const size_t nbgk1 = nb0*nek0; + const size_t nbgk2 = nb0*nek0*nek1; + const size_t nbgk3 = nb0*nek0*nek1*neq2; + + const size_t nbgv1 = nb0*nev0; + const size_t nbgv2 = nb0*nev0*nev1; + const size_t nbgv3 = nb0*nev0*nev1*neq2; + + // parallelize by k rows using lm_ggml_vec_dot_f32 + + // total rows in k + const int nr = nek2*nek3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + const float scale = 1.0f/sqrtf(D); + + //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); + + // how often k2 (and v2) is repeated in q2 + int nrep = neq2/nek2; + + for (int ir = ir0; ir < ir1; ++ir) { + // q indices + const int ik3 = ir/(nek2); + const int ik2 = ir - ik3*nek2; + + const int iq3 = ik3; + const int id3 = ik3; + const int iv3 = ik3; + const int iv2 = ik2; + + for (int irep = 0; irep < nrep; ++irep) { + const int iq2 = ik2 + irep*nek2; + const int id2 = iq2; + + // (ik2 + irep*nek2) % nek2 == ik2 + for (int iq1 = 0; iq1 < neq1; ++iq1) { + const int id1 = iq1; + + // not sure about CACHE_LINE_SIZE_F32.. + // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset? + float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32); + float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32); + + for (int i = M; i < Mup; ++i) { + S[i] = -INFINITY; + } + + const int64_t masked_begin = masked ? (P + iq1 + 1) : M; + for (int64_t ic = 0; ic < masked_begin; ++ic) { + // k indices + const int ik1 = ic; + + // S indices + const int i1 = ik1; + + lm_ggml_vec_dot_f32(neq0, + S + i1, 0, + (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0, + (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1); + } + + // scale + lm_ggml_vec_scale_f32(masked_begin, S, scale); + + for (int64_t i = masked_begin; i < M; i++) { + S[i] = -INFINITY; + } + + // softmax + // exclude known -INF S[..] values from max and loop + // dont forget to set their SM values to zero + { + float max = -INFINITY; + lm_ggml_vec_max_f32(masked_begin, &max, S); + + lm_ggml_float sum = 0.0; + { +#ifdef LM_GGML_SOFT_MAX_ACCELERATE + max = -max; + vDSP_vsadd(SM, 1, &max, SM, 1, Mup); + vvexpf(SM, SM, &Mup); + lm_ggml_vec_sum_f32(Mup, &sum, SM); +#else + sum = lm_ggml_vec_soft_max_f32(Mup, SM, S, max); +#endif + } + + assert(sum > 0.0); + + sum = 1.0/sum; + lm_ggml_vec_scale_f32(masked_begin, SM, sum); + + } + + // step-by-step explanation + { + // forward-process shape grads from backward process + // parallel_for ik2,ik3: + // for irep: + // iq2 = ik2 + irep*nek2 + // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,ik2,ik3] += grad[kcur] + // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur] + // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iv2,iv3] += grad[vcur] + // for iq1: + // kcur = k[:D,:M,ik2,ik3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur + // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur + // vcur = v[:M,:D,iv2,iv3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4 + // S0 = -Inf [D,1,1,1] + // ~S1[i] = dot(kcur[:D,i], qcur) + // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale + // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P) + // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) + // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur + // ~S5[i] = dot(vcur[:,i], S4) + // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,id1,id2,id3] + // ~dst[i,iq1,iq2,iq3] = S5[i] ^ + // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3] + // dst backward-/ grad[dst] = d + // + // output gradients with their dependencies: + // + // grad[kcur] = grad[S1].T @ qcur + // grad[S1] = diag_mask_zero(grad[S3], P) * scale + // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) + // grad[S4] = grad[S5] @ vcur + // grad[S4] = d[:D,id1,id2,id3] @ vcur + // grad[qcur] = grad[S1] @ kcur + // grad[vcur] = grad[S5].T @ S4 + // grad[vcur] = d[:D,id1,id2,id3].T @ S4 + // + // in post-order: + // + // S1 = qcur @ kcur.T + // S2 = S1 * scale + // S3 = diag_mask_inf(S2, P) + // S4 = softmax(S3) + // grad[S4] = d[:D,id1,id2,id3] @ vcur + // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) + // grad[S1] = diag_mask_zero(grad[S3], P) * scale + // grad[qcur] = grad[S1] @ kcur + // grad[kcur] = grad[S1].T @ qcur + // grad[vcur] = d[:D,id1,id2,id3].T @ S4 + // + // using less variables (SM=S4): + // + // S = diag_mask_inf(qcur @ kcur.T * scale, P) + // SM = softmax(S) + // S = d[:D,iq1,iq2,iq3] @ vcur + // dot_SM_gradSM = dot(SM, S) + // S = SM * (S - dot(SM, S)) + // S = diag_mask_zero(S, P) * scale + // + // grad[q][:D,iq1,iq2,iq3] += S @ kcur + // grad[k][:D,:M,ik2,ik3] += S.T @ qcur + // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM + } + + // S = gradSM = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3] + // S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3] + // for ic: + // S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3] + // exclude known future zero S[..] values from operation + lm_ggml_vec_set_f32(masked_begin, S, 0); + for (int64_t ic = 0; ic < D; ++ic) { + lm_ggml_vec_mad_f32(masked_begin, + S, + (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), + *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3))); + } + + // S = SM * (S - dot(SM, S)) + float dot_SM_gradSM = 0; + lm_ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, 0, SM, 0, S, 0, 1); + lm_ggml_vec_acc1_f32(M, S, -dot_SM_gradSM); + lm_ggml_vec_mul_f32 (masked_begin, S, S, SM); + + // S = diag_mask_zero(S, P) * scale + // already done by above lm_ggml_vec_set_f32 + + // exclude known zero S[..] values from operation + lm_ggml_vec_scale_f32(masked_begin, S, scale); + + // S shape [M,1] + // SM shape [M,1] + // kcur shape [D,M] + // qcur shape [D,1] + // vcur shape [M,D] + + // grad[q][:D,iq1,iq2,iq3] += S @ kcur + // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M] + // for ic: + // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3] + // exclude known zero S[..] values from loop + for (int64_t ic = 0; ic < masked_begin; ++ic) { + lm_ggml_vec_mad_f32(D, + (float *) ((char *) grad_q + (iq1*nbgq1 + iq2*nbgq2 + iq3*nbgq3)), + (float *) ((char *) k->data + (ic*nbk1 + ik2*nbk2 + ik3*nbk3)), + S[ic]); + } + + // grad[k][:D,:M,iq2,iq3] += S.T @ qcur + // for ic: + // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0] + // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0] + // exclude known zero S[..] values from loop + for (int64_t ic = 0; ic < masked_begin; ++ic) { + lm_ggml_vec_mad_f32(D, + (float *) ((char *) grad_k + (ic*nbgk1 + ik2*nbgk2 + ik3*nbgk3)), + (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), + S[ic]); + } + + // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM + // for ic: + // grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M] + // grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3] * SM[:M] + // exclude known zero SM[..] values from mad + for (int64_t ic = 0; ic < D; ++ic) { + lm_ggml_vec_mad_f32(masked_begin, + (float *) ((char *) grad_v + ( ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)), + SM, + *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3))); + } + } + } + } +} + +static void lm_ggml_compute_forward_flash_attn_back( + const struct lm_ggml_compute_params * params, + const bool masked, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * q = dst->src[0]; + + switch (q->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_flash_attn_back_f32(params, masked, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_ssm_conv + +static void lm_ggml_compute_forward_ssm_conv_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + const struct lm_ggml_tensor * src0 = dst->src[0]; // conv_x + const struct lm_ggml_tensor * src1 = dst->src[1]; // conv1d.weight + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src1->ne[0]; // d_conv + const int ncs = src0->ne[0]; // d_conv - 1 + n_t + const int nr = src0->ne[1]; // d_inner + const int n_t = dst->ne[1]; // tokens per sequence + const int n_s = dst->ne[2]; // number of sequences in the batch + + LM_GGML_ASSERT( dst->ne[0] == nr); + LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); + LM_GGML_ASSERT(src1->nb[0] == sizeof(float)); + LM_GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + const int ir = ir1 - ir0; + + for (int i3 = 0; i3 < n_s; ++i3) { + for (int i2 = 0; i2 < n_t; ++i2) { + // {d_conv - 1 + n_t, d_inner, n_seqs} + // sliding window + const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s} + const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner} + float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s} + + // TODO: transpose the output for smaller strides for big batches? + // d_inner + for (int i1 = 0; i1 < ir; ++i1) { + // rowwise dot product + // NOTE: not using lm_ggml_vec_dot_f32, because its sum is in double precision + float sumf = 0.0f; + + // d_conv + for (int i0 = 0; i0 < nc; ++i0) { + sumf += s[i0 + i1*ncs] * c[i0 + i1*nc]; + } + x[i1] = sumf; + } + } + } +} + +static void lm_ggml_compute_forward_ssm_conv( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + switch (dst->src[0]->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_ssm_conv_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_ssm_scan + +static void lm_ggml_compute_forward_ssm_scan_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + const struct lm_ggml_tensor * src0 = dst->src[0]; // s + const struct lm_ggml_tensor * src1 = dst->src[1]; // x + const struct lm_ggml_tensor * src2 = dst->src[2]; // dt + const struct lm_ggml_tensor * src3 = dst->src[3]; // A + const struct lm_ggml_tensor * src4 = dst->src[4]; // B + const struct lm_ggml_tensor * src5 = dst->src[5]; // C + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t nc = src0->ne[0]; // d_state + const int64_t nr = src0->ne[1]; // d_inner + const int64_t n_t = src1->ne[1]; // number of tokens per sequence + const int64_t n_s = src0->ne[2]; // number of sequences in the batch + + LM_GGML_ASSERT(lm_ggml_nelements(src1) + lm_ggml_nelements(src0) == lm_ggml_nelements(dst)); + LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); + LM_GGML_ASSERT(src1->nb[0] == sizeof(float)); + LM_GGML_ASSERT(src2->nb[0] == sizeof(float)); + LM_GGML_ASSERT(src3->nb[0] == sizeof(float)); + LM_GGML_ASSERT(src4->nb[0] == sizeof(float)); + LM_GGML_ASSERT(src5->nb[0] == sizeof(float)); + // required for the dot product between s and C + LM_GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); + // required for per-sequence offsets for states + LM_GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float)); + // required to get correct offset for state destination (i.e. src1->nb[3]) + LM_GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + const int ir = ir1 - ir0; + + for (int i3 = 0; i3 < n_s; ++i3) { + for (int i2 = 0; i2 < n_t; ++i2) { + const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s} + const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} + const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s} + const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} + const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} + const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s} + float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} + float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s} + + // use the output as the source for the next token-wise iterations + if (i2 > 0) { s0 = s; } + + // d_inner + for (int i1 = 0; i1 < ir; ++i1) { + // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78 + float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; + float x_dt = x[i1] * dt_soft_plus; + float sumf = 0.0f; + // d_state + for (int i0 = 0; i0 < nc; ++i0) { + int i = i0 + i1*nc; + // state = prev_state * dA + dB * x + float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); + // y = rowwise_dotprod(state, C) + sumf += state * C[i0]; + s[i] = state; + } + y[i1] = sumf; + } + } + } +} + +static void lm_ggml_compute_forward_ssm_scan( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + switch (dst->src[0]->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_ssm_scan_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_win_part + +static void lm_ggml_compute_forward_win_part_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + UNUSED(params); + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + LM_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + LM_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + + const int32_t nep0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t nep1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t w = ((const int32_t *)(dst->op_params))[2]; + + assert(ne00 == ne0); + assert(ne3 == nep0*nep1); + + // TODO: optimize / multi-thread + for (int py = 0; py < nep1; ++py) { + for (int px = 0; px < nep0; ++px) { + const int64_t i3 = py*nep0 + px; + for (int64_t i2 = 0; i2 < ne2; ++i2) { + for (int64_t i1 = 0; i1 < ne1; ++i1) { + for (int64_t i0 = 0; i0 < ne0; ++i0) { + const int64_t i02 = py*w + i2; + const int64_t i01 = px*w + i1; + const int64_t i00 = i0; + + const int64_t i = i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + i0; + const int64_t j = i02*ne01*ne00 + i01*ne00 + i00; + + if (py*w + i2 >= ne02 || px*w + i1 >= ne01) { + ((float *) dst->data)[i] = 0.0f; + } else { + ((float *) dst->data)[i] = ((float *) src0->data)[j]; + } + } + } + } + } + } +} + +static void lm_ggml_compute_forward_win_part( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_win_part_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_win_unpart + +static void lm_ggml_compute_forward_win_unpart_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + UNUSED(params); + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + LM_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + LM_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + + const int32_t w = ((const int32_t *)(dst->op_params))[0]; + + // padding + const int px = (w - ne1%w)%w; + //const int py = (w - ne2%w)%w; + + const int npx = (px + ne1)/w; + //const int npy = (py + ne2)/w; + + assert(ne0 == ne00); + + // TODO: optimize / multi-thread + for (int64_t i2 = 0; i2 < ne2; ++i2) { + for (int64_t i1 = 0; i1 < ne1; ++i1) { + for (int64_t i0 = 0; i0 < ne0; ++i0) { + const int ip2 = i2/w; + const int ip1 = i1/w; + + const int64_t i02 = i2%w; + const int64_t i01 = i1%w; + const int64_t i00 = i0; + + const int64_t i = (ip2*npx + ip1)*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00 + i00; + const int64_t j = i2*ne1*ne0 + i1*ne0 + i0; + + ((float *) dst->data)[j] = ((float *) src0->data)[i]; + } + } + } +} + +static void lm_ggml_compute_forward_win_unpart( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_win_unpart_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +//gmml_compute_forward_unary + +static void lm_ggml_compute_forward_unary( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const enum lm_ggml_unary_op op = lm_ggml_get_unary_op(dst); + + switch (op) { + case LM_GGML_UNARY_OP_ABS: + { + lm_ggml_compute_forward_abs(params, dst); + } break; + case LM_GGML_UNARY_OP_SGN: + { + lm_ggml_compute_forward_sgn(params, dst); + } break; + case LM_GGML_UNARY_OP_NEG: + { + lm_ggml_compute_forward_neg(params, dst); + } break; + case LM_GGML_UNARY_OP_STEP: + { + lm_ggml_compute_forward_step(params, dst); + } break; + case LM_GGML_UNARY_OP_TANH: + { + lm_ggml_compute_forward_tanh(params, dst); + } break; + case LM_GGML_UNARY_OP_ELU: + { + lm_ggml_compute_forward_elu(params, dst); + } break; + case LM_GGML_UNARY_OP_RELU: + { + lm_ggml_compute_forward_relu(params, dst); + } break; + case LM_GGML_UNARY_OP_SIGMOID: + { + lm_ggml_compute_forward_sigmoid(params, dst); + } break; + case LM_GGML_UNARY_OP_GELU: + { + lm_ggml_compute_forward_gelu(params, dst); + } break; + case LM_GGML_UNARY_OP_GELU_QUICK: + { + lm_ggml_compute_forward_gelu_quick(params, dst); + } break; + case LM_GGML_UNARY_OP_SILU: + { + lm_ggml_compute_forward_silu(params, dst); + } break; + case LM_GGML_UNARY_OP_HARDSWISH: + { + lm_ggml_compute_forward_hardswish(params, dst); + } break; + case LM_GGML_UNARY_OP_HARDSIGMOID: + { + lm_ggml_compute_forward_hardsigmoid(params, dst); + } break; + case LM_GGML_UNARY_OP_EXP: + { + lm_ggml_compute_forward_exp(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_get_rel_pos + +static void lm_ggml_compute_forward_get_rel_pos_f16( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + UNUSED(params); + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322 + + LM_GGML_TENSOR_UNARY_OP_LOCALS + + const int64_t w = ne1; + + lm_ggml_fp16_t * src0_data = (lm_ggml_fp16_t *) src0->data; + lm_ggml_fp16_t * dst_data = (lm_ggml_fp16_t *) dst->data; + + for (int64_t i2 = 0; i2 < ne2; ++i2) { + for (int64_t i1 = 0; i1 < ne1; ++i1) { + const int64_t pos = (w - i1 - 1) + i2; + for (int64_t i0 = 0; i0 < ne0; ++i0) { + dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0]; + } + } + } +} + +static void lm_ggml_compute_forward_get_rel_pos( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F16: + case LM_GGML_TYPE_BF16: + { + lm_ggml_compute_forward_get_rel_pos_f16(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_add_rel_pos + +static void lm_ggml_compute_forward_add_rel_pos_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + const struct lm_ggml_tensor * src2 = dst->src[2]; + + const bool inplace = (bool) ((int32_t *) dst->op_params)[0]; + if (!inplace) { + if (params->ith == 0) { + memcpy((char *) dst->data, (char *) src0->data, lm_ggml_nbytes(dst)); + } + lm_ggml_barrier(params->threadpool); + } + // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L357-L359 + + float * src1_data = (float *) src1->data; + float * src2_data = (float *) src2->data; + float * dst_data = (float *) dst->data; + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + const int64_t ne12 = src1->ne[2]; + const int64_t ne13 = src1->ne[3]; + + const int ith = params->ith; + const int nth = params->nth; + + // total patches in dst + const int np = ne13; + + // patches per thread + const int dp = (np + nth - 1)/nth; + + // patch range for this thread + const int ip0 = dp*ith; + const int ip1 = MIN(ip0 + dp, np); + + for (int64_t i13 = ip0; i13 < ip1; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + const int64_t jp1 = i13*ne12*ne11*ne10 + i12*ne11*ne10 + i11*ne10; + for (int64_t i10 = 0; i10 < ne10; ++i10) { + const int64_t jp0 = jp1 + i10; + const float src1_e = src1_data[jp0]; + const float src2_e = src2_data[jp0]; + + const int64_t jdh = jp0 * ne10; + const int64_t jdw = jdh - (ne10 - 1) * i10; + + for (int64_t j = 0; j < ne10; ++j) { + dst_data[jdh + j ] += src2_e; + dst_data[jdw + j*ne10] += src1_e; + } + } + } + } + } +} + +static void lm_ggml_compute_forward_add_rel_pos( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_add_rel_pos_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_rwkv_wkv6 + +static void lm_ggml_compute_forward_rwkv_wkv6_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + const int64_t T = dst->src[1]->ne[3]; + const int64_t C = dst->ne[0]; + const int64_t HEADS = dst->src[1]->ne[2]; + const int64_t n_seqs = dst->src[5]->ne[1]; + const int64_t head_size = C / HEADS; + + float * dst_data = (float *) dst->data; + float * state = ((float *) dst->data) + C * T; + + const int ith = params->ith; + const int nth = params->nth; + + if (ith >= HEADS) { + return; + } + + const int h_start = (HEADS * ith) / nth; + const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? + (HEADS * (ith + 1)) / nth : HEADS; + + float * k = (float *) dst->src[0]->data; + float * v = (float *) dst->src[1]->data; + float * r = (float *) dst->src[2]->data; + float * time_faaaa = (float *) dst->src[3]->data; + float * time_decay = (float *) dst->src[4]->data; + + size_t t_stride = HEADS * head_size; // Same to C + + size_t h_stride = C / HEADS; + LM_GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS + size_t h_stride_2d = head_size * head_size; + + if (ith == 0) { + memset(dst_data, 0, T * C * sizeof(float)); + } + lm_ggml_barrier(params->threadpool); + + + #if defined(__AVX__) && !defined(__AVX512F__) + #define LM_GGML_F32X LM_GGML_F32x8 + #define LM_GGML_F32X_SET1 LM_GGML_F32x8_SET1 + #define LM_GGML_F32X_LOAD LM_GGML_F32x8_LOAD + #define LM_GGML_F32X_STORE LM_GGML_F32x8_STORE + #define LM_GGML_F32X_MUL LM_GGML_F32x8_MUL + #define LM_GGML_F32X_FMA LM_GGML_F32x8_FMA + #define WKV_VECTOR_SIZE 8 + #elif defined(__AVX512F__) + #define LM_GGML_F32X LM_GGML_F32x16 + #define LM_GGML_F32X_SET1 LM_GGML_F32x16_SET1 + #define LM_GGML_F32X_LOAD LM_GGML_F32x16_LOAD + #define LM_GGML_F32X_STORE LM_GGML_F32x16_STORE + #define LM_GGML_F32X_MUL LM_GGML_F32x16_MUL + #define LM_GGML_F32X_FMA LM_GGML_F32x16_FMA + #define WKV_VECTOR_SIZE 16 + #elif defined(__ARM_NEON) && defined(__aarch64__) + #define LM_GGML_F32X LM_GGML_F32x4 + #define LM_GGML_F32X_SET1 LM_GGML_F32x4_SET1 + #define LM_GGML_F32X_LOAD LM_GGML_F32x4_LOAD + #define LM_GGML_F32X_STORE LM_GGML_F32x4_STORE + #define LM_GGML_F32X_MUL LM_GGML_F32x4_MUL + #define LM_GGML_F32X_FMA LM_GGML_F32x4_FMA + #define WKV_VECTOR_SIZE 4 + #endif + + #ifdef WKV_VECTOR_SIZE + const int64_t vec_count = head_size / WKV_VECTOR_SIZE; + + for (int64_t t = 0; t < T; t++) { + size_t t_offset = t * t_stride; + size_t state_offset = head_size * C * (t / (T / n_seqs)); + float * state_cur = state + state_offset; + float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset; + + for (int64_t h = h_start; h < h_end; h++) { + size_t h_offset = h * h_stride; + size_t t_h_offset = t_offset + h_offset; + size_t h_2d_offset = h * h_stride_2d; + + for (int64_t i = 0; i < head_size; i++) { + size_t t_h_i_offset = t_h_offset + i; + size_t h_i_offset = h_offset + i; + size_t h_2d_i_offset = h_2d_offset + i * h_stride; + + float k_val = k[t_h_i_offset]; + float r_val = r[t_h_i_offset]; + float time_faaaa_val = time_faaaa[h_i_offset]; + float time_decay_val = time_decay[t_h_i_offset]; + + // Broadcast scalar values to vectors + LM_GGML_F32X k_vec = LM_GGML_F32X_SET1(k_val); + LM_GGML_F32X r_vec = LM_GGML_F32X_SET1(r_val); + LM_GGML_F32X time_faaaa_vec = LM_GGML_F32X_SET1(time_faaaa_val); + LM_GGML_F32X time_decay_vec = LM_GGML_F32X_SET1(time_decay_val); + + for (int64_t j = 0; j < vec_count; j++) { + size_t base_j = j * WKV_VECTOR_SIZE; + size_t t_h_j_offset = t_h_offset + base_j; + size_t h_2d_i_j_offset = h_2d_i_offset + base_j; + + // Load x elements at once + LM_GGML_F32X v_vec = LM_GGML_F32X_LOAD(&v[t_h_j_offset]); + LM_GGML_F32X prev_state_vec = LM_GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]); + LM_GGML_F32X dst_vec = LM_GGML_F32X_LOAD(&dst_data[t_h_j_offset]); + + // Compute kv = v * k + LM_GGML_F32X kv_vec = LM_GGML_F32X_MUL(v_vec, k_vec); + + // Compute temp = kv * time_faaaa + prev_state + LM_GGML_F32X temp_vec = LM_GGML_F32X_FMA(prev_state_vec, kv_vec, time_faaaa_vec); + + // Update dst: dst += temp * r + dst_vec = LM_GGML_F32X_FMA(dst_vec, temp_vec, r_vec); + LM_GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec); + + // Update state: state = prev_state * time_decay + kv + LM_GGML_F32X new_state_vec = LM_GGML_F32X_FMA(kv_vec, prev_state_vec, time_decay_vec); + LM_GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], new_state_vec); + } + + // Handle remaining elements, this will not be used. + for (int64_t j = vec_count * WKV_VECTOR_SIZE; j < head_size; j++) { + size_t t_h_j_offset = t_h_offset + j; + size_t h_2d_i_j_offset = h_2d_i_offset + j; + float v_val = v[t_h_j_offset]; + float kv_val = v_val * k_val; + float prev_state_val = state_prev[h_2d_i_j_offset]; + float temp_val = kv_val * time_faaaa_val + prev_state_val; + dst_data[t_h_j_offset] += temp_val * r_val; + state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val; + } + } + } + } + + #else + // basically fused operations: + // dst = r @ (time_faaaa * (k @ v) + state), + // state = time_decay * state + (k @ v), + // recursive through each token + for (int64_t t = 0; t < T; t++) { + size_t t_offset = t * t_stride; + size_t state_offset = head_size * C * (t / (T / n_seqs)); + float * state_cur = state + state_offset; + float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset; + + for (int64_t h = h_start; h < h_end; h++) { + size_t h_offset = h * h_stride; + size_t t_h_offset = t_offset + h_offset; + size_t h_2d_offset = h * h_stride_2d; + + for (int64_t i = 0; i < head_size; i++) { + size_t t_h_i_offset = t_h_offset + i; + size_t h_i_offset = h_offset + i; + size_t h_2d_i_offset = h_2d_offset + i * h_stride; + + float k_val = k[t_h_i_offset]; + float r_val = r[t_h_i_offset]; + float time_faaaa_val = time_faaaa[h_i_offset]; + // RWKV v6: different time_decay for each token. + float time_decay_val = time_decay[t_h_i_offset]; + + for (int64_t j = 0; j < head_size; j++) { + size_t t_h_j_offset = t_h_offset + j; + size_t h_2d_i_j_offset = h_2d_i_offset + j; + + float v_val = v[t_h_j_offset]; + float kv_val = v_val * k_val; + float prev_state_val = state_prev[h_2d_i_j_offset]; + float temp_val = kv_val * time_faaaa_val + prev_state_val; + dst_data[t_h_j_offset] += temp_val * r_val; + state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val; + } + } + } + } + #endif +} + + +static void lm_ggml_compute_forward_rwkv_wkv6( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_rwkv_wkv6_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_map_unary + +static void lm_ggml_compute_forward_map_unary_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst, + const lm_ggml_unary_op_f32_t fun) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(lm_ggml_is_contiguous_1(src0)); + assert(lm_ggml_is_contiguous_1(dst)); + assert(lm_ggml_are_same_shape(src0, dst)); + + const int n = lm_ggml_nrows(src0); + const int nc = src0->ne[0]; + + for (int i = 0; i < n; i++) { + fun(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void lm_ggml_compute_forward_map_unary( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst, + const lm_ggml_unary_op_f32_t fun) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_map_unary_f32(params, dst, fun); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_map_binary + +static void lm_ggml_compute_forward_map_binary_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst, + const lm_ggml_binary_op_f32_t fun) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + if (params->ith != 0) { + return; + } + + assert(lm_ggml_is_contiguous_1(src0)); + assert(lm_ggml_is_contiguous_1(src1)); + assert(lm_ggml_is_contiguous_1(dst)); + assert(lm_ggml_are_same_shape(src0, src1) && lm_ggml_are_same_shape(src0, dst)); + + const int n = lm_ggml_nrows(src0); + const int nc = src0->ne[0]; + + for (int i = 0; i < n; i++) { + fun(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1])), + (float *) ((char *) src1->data + i*(src1->nb[1]))); + } +} + +static void lm_ggml_compute_forward_map_binary( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst, + const lm_ggml_binary_op_f32_t fun) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_map_binary_f32(params, dst, fun); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_map_custom1 + +static void lm_ggml_compute_forward_map_custom1_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst, + const lm_ggml_custom1_op_f32_t fun) { + + const struct lm_ggml_tensor * a = dst->src[0]; + + if (params->ith != 0) { + return; + } + + fun(dst, a); +} + +// lm_ggml_compute_forward_map_custom2 + +static void lm_ggml_compute_forward_map_custom2_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst, + const lm_ggml_custom2_op_f32_t fun) { + + const struct lm_ggml_tensor * a = dst->src[0]; + const struct lm_ggml_tensor * b = dst->src[1]; + + if (params->ith != 0) { + return; + } + + fun(dst, a, b); +} + +// lm_ggml_compute_forward_map_custom3 + +static void lm_ggml_compute_forward_map_custom3_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst, + const lm_ggml_custom3_op_f32_t fun) { + + const struct lm_ggml_tensor * a = dst->src[0]; + const struct lm_ggml_tensor * b = dst->src[1]; + const struct lm_ggml_tensor * c = dst->src[1]; + + if (params->ith != 0) { + return; + } + + fun(dst, a, b, c); +} + +// lm_ggml_compute_forward_map_custom1 + +static void lm_ggml_compute_forward_map_custom1( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * a = dst->src[0]; + + struct lm_ggml_map_custom1_op_params p; + memcpy(&p, dst->op_params, sizeof(p)); + + p.fun(dst, a, params->ith, params->nth, p.userdata); +} + +// lm_ggml_compute_forward_map_custom2 + +static void lm_ggml_compute_forward_map_custom2( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * a = dst->src[0]; + const struct lm_ggml_tensor * b = dst->src[1]; + + struct lm_ggml_map_custom2_op_params p; + memcpy(&p, dst->op_params, sizeof(p)); + + p.fun(dst, a, b, params->ith, params->nth, p.userdata); +} + +// lm_ggml_compute_forward_map_custom3 + +static void lm_ggml_compute_forward_map_custom3( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * a = dst->src[0]; + const struct lm_ggml_tensor * b = dst->src[1]; + const struct lm_ggml_tensor * c = dst->src[2]; + + struct lm_ggml_map_custom3_op_params p; + memcpy(&p, dst->op_params, sizeof(p)); + + p.fun(dst, a, b, c, params->ith, params->nth, p.userdata); +} + +// lm_ggml_compute_forward_cross_entropy_loss + +static void lm_ggml_compute_forward_cross_entropy_loss_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT(src0->nb[0] == lm_ggml_type_size(src0->type)); + LM_GGML_ASSERT(src1->nb[0] == lm_ggml_type_size(src1->type)); + LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, src1)); + LM_GGML_ASSERT(lm_ggml_is_scalar(dst)); + LM_GGML_ASSERT(dst->type == LM_GGML_TYPE_F32); + + // TODO: handle transposed/permuted matrices + const int64_t nc = src0->ne[0]; + const int64_t nr = lm_ggml_nrows(src0); + + const int ith = params->ith; + const int nth = params->nth; + + float * sums = (float *) params->wdata; + float * st = ((float *) params->wdata) + nth + ith*nc; + float sum_thread = 0.0f; + + LM_GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc)); + + // rows per thread + const int64_t dr = (nr + nth - 1)/nth; + + // row range for this thread + const int64_t ir0 = dr*ith; + const int64_t ir1 = MIN(ir0 + dr, nr); + + for (int64_t i1 = ir0; i1 < ir1; ++i1) { + const float * s0 = (const float *)((const char *) src0->data + i1*src0->nb[1]); + const float * s1 = (const float *)((const char *) src1->data + i1*src1->nb[1]); + +#ifndef NDEBUG + for (int64_t i = 0; i < nc; ++i) { + //printf("p[%d] = %f\n", i, p[i]); + assert(!isnan(s0[i])); + assert(!isnan(s1[i])); + } +#endif + + float max = -INFINITY; + lm_ggml_vec_max_f32(nc, &max, s0); + const lm_ggml_float sum_softmax = lm_ggml_vec_log_soft_max_f32(nc, st, s0, max); + assert(sum_softmax >= 0.0); + + lm_ggml_vec_add1_f32(nc, st, st, -sum_softmax); + lm_ggml_vec_mul_f32(nc, st, st, s1); + + float sum_st = 0.0f; + lm_ggml_vec_sum_f32(nc, &sum_st, st); + sum_thread += sum_st; + +#ifndef NDEBUG + for (int64_t i = 0; i < nc; ++i) { + assert(!isnan(st[i])); + assert(!isinf(st[i])); + } +#endif + } + sums[ith] = sum_thread; + lm_ggml_barrier(params->threadpool); + + if (ith == 0) { + float * dp = (float *) dst->data; + lm_ggml_vec_sum_f32(nth, dp, sums); + dp[0] *= -1.0f / (float) nr; + } +} + +static void lm_ggml_compute_forward_cross_entropy_loss( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_cross_entropy_loss_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// lm_ggml_compute_forward_cross_entropy_loss_back + +static void lm_ggml_compute_forward_cross_entropy_loss_back_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + const struct lm_ggml_tensor * opt0 = dst->src[2]; + + LM_GGML_ASSERT(lm_ggml_is_contiguous(dst)); + LM_GGML_ASSERT(lm_ggml_is_contiguous(src0)); + LM_GGML_ASSERT(lm_ggml_is_contiguous(src1)); + LM_GGML_ASSERT(lm_ggml_is_contiguous(opt0)); + LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, src1) && lm_ggml_are_same_shape(src0, dst)); + + const int64_t ith = params->ith; + const int64_t nth = params->nth; + + // TODO: handle transposed/permuted matrices + const int64_t nc = src0->ne[0]; + const int64_t nr = lm_ggml_nrows(src0); + + // rows per thread + const int64_t dr = (nr + nth - 1)/nth; + + // row range for this thread + const int64_t ir0 = dr*ith; + const int64_t ir1 = MIN(ir0 + dr, nr); + + const float d_by_nr = ((const float *) opt0->data)[0] / (float) nr; + + for (int64_t i1 = ir0; i1 < ir1; i1++) { + float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]); + float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]); + float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]); + +#ifndef NDEBUG + for (int64_t i = 0; i < nc; ++i) { + //printf("p[%d] = %f\n", i, p[i]); + assert(!isnan(s0[i])); + assert(!isnan(s1[i])); + } +#endif + + // soft_max + float max = -INFINITY; + lm_ggml_vec_max_f32(nc, &max, s0); + lm_ggml_float sum = lm_ggml_vec_soft_max_f32(nc, ds0, s0, max); + assert(sum > 0.0); + lm_ggml_vec_scale_f32(nc, ds0, 1.0/sum); + + // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr + lm_ggml_vec_sub_f32(nc, ds0, ds0, s1); + lm_ggml_vec_scale_f32(nc, ds0, d_by_nr); + +#ifndef NDEBUG + for (int64_t i = 0; i < nc; ++i) { + assert(!isnan(ds0[i])); + assert(!isinf(ds0[i])); + } +#endif + } +} + +static void lm_ggml_compute_forward_cross_entropy_loss_back( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_cross_entropy_loss_back_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +static void lm_ggml_compute_forward_opt_step_adamw_f32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src0_grad = dst->src[1]; + const struct lm_ggml_tensor * src0_grad_m = dst->src[2]; + const struct lm_ggml_tensor * src0_grad_v = dst->src[3]; + LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, src0_grad)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = lm_ggml_nrows(src0); + + LM_GGML_TENSOR_UNARY_OP_LOCALS + LM_GGML_ASSERT(nb00 == sizeof(float)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + /* const float gnorm = 1.0f; */ + int64_t iter; memcpy(&iter, &dst->op_params[0], sizeof(int64_t)); + const float alpha = lm_ggml_get_op_params_f32(dst, 2); + const float beta1 = lm_ggml_get_op_params_f32(dst, 3); + const float beta2 = lm_ggml_get_op_params_f32(dst, 4); + const float eps = lm_ggml_get_op_params_f32(dst, 5); + const float wd = lm_ggml_get_op_params_f32(dst, 6); + + const float beta1h = alpha/(1.0f - powf(beta1, iter)); + const float beta2h = 1.0f/(1.0f - powf(beta2, iter)); + + for (int ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const size_t offset = i03*nb03 + i02*nb02 + i01*nb01; + + float * w = (float *) ((char *) src0->data + offset); // weight + const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad + float * m = (float *) ((char *) src0_grad_m->data + offset); + float * v = (float *) ((char *) src0_grad_v->data + offset); + + for (int i00 = 0; i00 < ne00; ++i00) { + m[i00] = m[i00]*beta1 + g[i00]*(1.0f - beta1); + v[i00] = v[i00]*beta2 + g[i00]*g[i00]*(1.0f - beta2); + + const float mh = m[i00]*beta1h; + const float vh = sqrtf(v[i00]*beta2h) + eps; + + // The weight decay is applied independently of the Adam momenta m and v. + // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss. + // See: https://arxiv.org/pdf/1711.05101v3.pdf + w[i00] = w[i00]*(1.0f - alpha*wd) - mh/vh; + } + } + + lm_ggml_barrier(params->threadpool); + if (ith != 0) { + return; + } + + iter++; + memcpy(&dst->op_params[0], &iter, sizeof(int64_t)); +} + +static void lm_ggml_compute_forward_opt_step_adamw( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case LM_GGML_TYPE_F32: + { + lm_ggml_compute_forward_opt_step_adamw_f32(params, dst); + } break; + default: + { + LM_GGML_ABORT("fatal error"); + } + } +} +///////////////////////////////// + +static void lm_ggml_compute_forward(struct lm_ggml_compute_params * params, struct lm_ggml_tensor * tensor) { + LM_GGML_ASSERT(params); + + if (tensor->op == LM_GGML_OP_NONE || lm_ggml_is_empty(tensor)) { + return; + } + + switch (tensor->op) { + case LM_GGML_OP_DUP: + { + lm_ggml_compute_forward_dup(params, tensor); + } break; + case LM_GGML_OP_ADD: + { + lm_ggml_compute_forward_add(params, tensor); + } break; + case LM_GGML_OP_ADD1: + { + lm_ggml_compute_forward_add1(params, tensor); + } break; + case LM_GGML_OP_ACC: + { + lm_ggml_compute_forward_acc(params, tensor); + } break; + case LM_GGML_OP_SUB: + { + lm_ggml_compute_forward_sub(params, tensor); + } break; + case LM_GGML_OP_MUL: + { + lm_ggml_compute_forward_mul(params, tensor); + } break; + case LM_GGML_OP_DIV: + { + lm_ggml_compute_forward_div(params, tensor); + } break; + case LM_GGML_OP_SQR: + { + lm_ggml_compute_forward_sqr(params, tensor); + } break; + case LM_GGML_OP_SQRT: + { + lm_ggml_compute_forward_sqrt(params, tensor); + } break; + case LM_GGML_OP_LOG: + { + lm_ggml_compute_forward_log(params, tensor); + } break; + case LM_GGML_OP_SIN: + { + lm_ggml_compute_forward_sin(params, tensor); + } break; + case LM_GGML_OP_COS: + { + lm_ggml_compute_forward_cos(params, tensor); + } break; + case LM_GGML_OP_SUM: + { + lm_ggml_compute_forward_sum(params, tensor); + } break; + case LM_GGML_OP_SUM_ROWS: + { + lm_ggml_compute_forward_sum_rows(params, tensor); + } break; + case LM_GGML_OP_MEAN: + { + lm_ggml_compute_forward_mean(params, tensor); + } break; + case LM_GGML_OP_ARGMAX: + { + lm_ggml_compute_forward_argmax(params, tensor); + } break; + case LM_GGML_OP_COUNT_EQUAL: + { + lm_ggml_compute_forward_count_equal(params, tensor); + } break; + case LM_GGML_OP_REPEAT: + { + lm_ggml_compute_forward_repeat(params, tensor); + } break; + case LM_GGML_OP_REPEAT_BACK: + { + lm_ggml_compute_forward_repeat_back(params, tensor); + } break; + case LM_GGML_OP_CONCAT: + { + lm_ggml_compute_forward_concat(params, tensor); + } break; + case LM_GGML_OP_SILU_BACK: + { + lm_ggml_compute_forward_silu_back(params, tensor); + } break; + case LM_GGML_OP_NORM: + { + lm_ggml_compute_forward_norm(params, tensor); + } break; + case LM_GGML_OP_RMS_NORM: + { + lm_ggml_compute_forward_rms_norm(params, tensor); + } break; + case LM_GGML_OP_RMS_NORM_BACK: + { + lm_ggml_compute_forward_rms_norm_back(params, tensor); + } break; + case LM_GGML_OP_GROUP_NORM: + { + lm_ggml_compute_forward_group_norm(params, tensor); + } break; + case LM_GGML_OP_MUL_MAT: + { + lm_ggml_compute_forward_mul_mat(params, tensor); + } break; + case LM_GGML_OP_MUL_MAT_ID: + { + lm_ggml_compute_forward_mul_mat_id(params, tensor); + } break; + case LM_GGML_OP_OUT_PROD: + { + lm_ggml_compute_forward_out_prod(params, tensor); + } break; + case LM_GGML_OP_SCALE: + { + lm_ggml_compute_forward_scale(params, tensor); + } break; + case LM_GGML_OP_SET: + { + lm_ggml_compute_forward_set(params, tensor); + } break; + case LM_GGML_OP_CPY: + { + lm_ggml_compute_forward_cpy(params, tensor); + } break; + case LM_GGML_OP_CONT: + { + lm_ggml_compute_forward_cont(params, tensor); + } break; + case LM_GGML_OP_RESHAPE: + { + lm_ggml_compute_forward_reshape(params, tensor); + } break; + case LM_GGML_OP_VIEW: + { + lm_ggml_compute_forward_view(params, tensor); + } break; + case LM_GGML_OP_PERMUTE: + { + lm_ggml_compute_forward_permute(params, tensor); + } break; + case LM_GGML_OP_TRANSPOSE: + { + lm_ggml_compute_forward_transpose(params, tensor); + } break; + case LM_GGML_OP_GET_ROWS: + { + lm_ggml_compute_forward_get_rows(params, tensor); + } break; + case LM_GGML_OP_GET_ROWS_BACK: + { + lm_ggml_compute_forward_get_rows_back(params, tensor); + } break; + case LM_GGML_OP_DIAG: + { + lm_ggml_compute_forward_diag(params, tensor); + } break; + case LM_GGML_OP_DIAG_MASK_INF: + { + lm_ggml_compute_forward_diag_mask_inf(params, tensor); + } break; + case LM_GGML_OP_DIAG_MASK_ZERO: + { + lm_ggml_compute_forward_diag_mask_zero(params, tensor); + } break; + case LM_GGML_OP_SOFT_MAX: + { + lm_ggml_compute_forward_soft_max(params, tensor); + } break; + case LM_GGML_OP_SOFT_MAX_BACK: + { + lm_ggml_compute_forward_soft_max_back(params, tensor); + } break; + case LM_GGML_OP_ROPE: + { + lm_ggml_compute_forward_rope(params, tensor); + } break; + case LM_GGML_OP_ROPE_BACK: + { + lm_ggml_compute_forward_rope_back(params, tensor); + } break; + case LM_GGML_OP_CLAMP: + { + lm_ggml_compute_forward_clamp(params, tensor); + } break; + case LM_GGML_OP_CONV_TRANSPOSE_1D: + { + lm_ggml_compute_forward_conv_transpose_1d(params, tensor); + } break; + case LM_GGML_OP_IM2COL: + { + lm_ggml_compute_forward_im2col(params, tensor); + } break; + case LM_GGML_OP_IM2COL_BACK: + { + lm_ggml_compute_forward_im2col_back_f32(params, tensor); + } break; + case LM_GGML_OP_CONV_TRANSPOSE_2D: + { + lm_ggml_compute_forward_conv_transpose_2d(params, tensor); + } break; + case LM_GGML_OP_POOL_1D: + { + lm_ggml_compute_forward_pool_1d(params, tensor); + } break; + case LM_GGML_OP_POOL_2D: + { + lm_ggml_compute_forward_pool_2d(params, tensor); + } break; + case LM_GGML_OP_POOL_2D_BACK: + { + lm_ggml_compute_forward_pool_2d_back(params, tensor); + } break; + case LM_GGML_OP_UPSCALE: + { + lm_ggml_compute_forward_upscale(params, tensor); + } break; + case LM_GGML_OP_PAD: + { + lm_ggml_compute_forward_pad(params, tensor); + } break; + case LM_GGML_OP_ARANGE: + { + lm_ggml_compute_forward_arange(params, tensor); + } break; + case LM_GGML_OP_TIMESTEP_EMBEDDING: + { + lm_ggml_compute_forward_timestep_embedding(params, tensor); + } break; + case LM_GGML_OP_ARGSORT: + { + lm_ggml_compute_forward_argsort(params, tensor); + } break; + case LM_GGML_OP_LEAKY_RELU: + { + lm_ggml_compute_forward_leaky_relu(params, tensor); + } break; + case LM_GGML_OP_FLASH_ATTN_EXT: + { + lm_ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor); + } break; + case LM_GGML_OP_FLASH_ATTN_BACK: + { + int32_t t = lm_ggml_get_op_params_i32(tensor, 0); + LM_GGML_ASSERT(t == 0 || t == 1); + bool masked = t != 0; + lm_ggml_compute_forward_flash_attn_back(params, masked, tensor); + } break; + case LM_GGML_OP_SSM_CONV: + { + lm_ggml_compute_forward_ssm_conv(params, tensor); + } break; + case LM_GGML_OP_SSM_SCAN: + { + lm_ggml_compute_forward_ssm_scan(params, tensor); + } break; + case LM_GGML_OP_WIN_PART: + { + lm_ggml_compute_forward_win_part(params, tensor); + } break; + case LM_GGML_OP_WIN_UNPART: + { + lm_ggml_compute_forward_win_unpart(params, tensor); + } break; + case LM_GGML_OP_UNARY: + { + lm_ggml_compute_forward_unary(params, tensor); + } break; + case LM_GGML_OP_GET_REL_POS: + { + lm_ggml_compute_forward_get_rel_pos(params, tensor); + } break; + case LM_GGML_OP_ADD_REL_POS: + { + lm_ggml_compute_forward_add_rel_pos(params, tensor); + } break; + case LM_GGML_OP_RWKV_WKV6: + { + lm_ggml_compute_forward_rwkv_wkv6(params, tensor); + } break; + case LM_GGML_OP_MAP_UNARY: + { + lm_ggml_unary_op_f32_t fun; + memcpy(&fun, tensor->op_params, sizeof(fun)); + lm_ggml_compute_forward_map_unary(params, tensor, fun); + } + break; + case LM_GGML_OP_MAP_BINARY: + { + lm_ggml_binary_op_f32_t fun; + memcpy(&fun, tensor->op_params, sizeof(fun)); + lm_ggml_compute_forward_map_binary(params, tensor, fun); + } + break; + case LM_GGML_OP_MAP_CUSTOM1_F32: + { + lm_ggml_custom1_op_f32_t fun; + memcpy(&fun, tensor->op_params, sizeof(fun)); + lm_ggml_compute_forward_map_custom1_f32(params, tensor, fun); + } + break; + case LM_GGML_OP_MAP_CUSTOM2_F32: + { + lm_ggml_custom2_op_f32_t fun; + memcpy(&fun, tensor->op_params, sizeof(fun)); + lm_ggml_compute_forward_map_custom2_f32(params, tensor, fun); + } + break; + case LM_GGML_OP_MAP_CUSTOM3_F32: + { + lm_ggml_custom3_op_f32_t fun; + memcpy(&fun, tensor->op_params, sizeof(fun)); + lm_ggml_compute_forward_map_custom3_f32(params, tensor, fun); + } + break; + case LM_GGML_OP_MAP_CUSTOM1: + { + lm_ggml_compute_forward_map_custom1(params, tensor); + } + break; + case LM_GGML_OP_MAP_CUSTOM2: + { + lm_ggml_compute_forward_map_custom2(params, tensor); + } + break; + case LM_GGML_OP_MAP_CUSTOM3: + { + lm_ggml_compute_forward_map_custom3(params, tensor); + } + break; + case LM_GGML_OP_CROSS_ENTROPY_LOSS: + { + lm_ggml_compute_forward_cross_entropy_loss(params, tensor); + } + break; + case LM_GGML_OP_CROSS_ENTROPY_LOSS_BACK: + { + lm_ggml_compute_forward_cross_entropy_loss_back(params, tensor); + } + break; + case LM_GGML_OP_OPT_STEP_ADAMW: + { + lm_ggml_compute_forward_opt_step_adamw(params, tensor); + } + break; + case LM_GGML_OP_NONE: + { + // nop + } break; + case LM_GGML_OP_COUNT: + { + LM_GGML_ABORT("fatal error"); + } + } +} + +// Android's libc implementation "bionic" does not support setting affinity +#if defined(__gnu_linux__) +static void set_numa_thread_affinity(int thread_n) { + if (!lm_ggml_is_numa()) { + return; + } + + int node_num; + int rv; + size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus); + + switch(g_state.numa.numa_strategy) { + case LM_GGML_NUMA_STRATEGY_DISTRIBUTE: + // run thread on node_num thread_n / (threads per node) + node_num = thread_n % g_state.numa.n_nodes; + break; + case LM_GGML_NUMA_STRATEGY_ISOLATE: + // run thread on current_node + node_num = g_state.numa.current_node; + break; + case LM_GGML_NUMA_STRATEGY_NUMACTL: + // use the cpuset that numactl gave us + rv = pthread_setaffinity_np(pthread_self(), setsize, &g_state.numa.cpuset); + if (rv) { + fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n",strerror(rv)); + } + return; + default: + return; + } + + struct lm_ggml_numa_node * node = &g_state.numa.nodes[node_num]; + + cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus); + CPU_ZERO_S(setsize, cpus); + for (size_t i = 0; i < node->n_cpus; ++i) { + CPU_SET_S(node->cpus[i], setsize, cpus); + } + + rv = pthread_setaffinity_np(pthread_self(), setsize, cpus); + if (rv) { + fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n", strerror(rv)); + } + + CPU_FREE(cpus); +} + +static void clear_numa_thread_affinity(void) { + if (!lm_ggml_is_numa()) { + return; + } + + size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus); + + cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus); + CPU_ZERO_S(setsize, cpus); + for (unsigned i = 0; i < g_state.numa.total_cpus; ++i) { + CPU_SET_S(i, setsize, cpus); + } + + int rv = pthread_setaffinity_np(pthread_self(), setsize, cpus); + if (rv) { + fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n", strerror(rv)); + } + + CPU_FREE(cpus); +} +#else +// TODO: Windows etc. +// (the linux implementation may also work on BSD, someone should test) +static void set_numa_thread_affinity(int thread_n) { UNUSED(thread_n); } +static void clear_numa_thread_affinity(void) {} +#endif + +static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads) { + int n_tasks = 0; + + if (lm_ggml_is_empty(node)) { + // no need to multi-thread a no-op + n_tasks = 1; + return n_tasks; + } + + switch (node->op) { + case LM_GGML_OP_CPY: + case LM_GGML_OP_DUP: + case LM_GGML_OP_CONT: + case LM_GGML_OP_ADD: + case LM_GGML_OP_ADD1: + case LM_GGML_OP_ACC: + { + n_tasks = n_threads; + } break; + case LM_GGML_OP_SUB: + case LM_GGML_OP_SQR: + case LM_GGML_OP_SQRT: + case LM_GGML_OP_LOG: + case LM_GGML_OP_SIN: + case LM_GGML_OP_COS: + case LM_GGML_OP_SUM: + case LM_GGML_OP_SUM_ROWS: + case LM_GGML_OP_MEAN: + case LM_GGML_OP_ARGMAX: + { + n_tasks = 1; + } break; + case LM_GGML_OP_COUNT_EQUAL: + { + n_tasks = n_threads; + } break; + case LM_GGML_OP_REPEAT: + case LM_GGML_OP_REPEAT_BACK: + case LM_GGML_OP_LEAKY_RELU: + { + n_tasks = 1; + } break; + case LM_GGML_OP_UNARY: + switch (lm_ggml_get_unary_op(node)) { + case LM_GGML_UNARY_OP_ABS: + case LM_GGML_UNARY_OP_SGN: + case LM_GGML_UNARY_OP_NEG: + case LM_GGML_UNARY_OP_STEP: + case LM_GGML_UNARY_OP_TANH: + case LM_GGML_UNARY_OP_ELU: + case LM_GGML_UNARY_OP_RELU: + case LM_GGML_UNARY_OP_SIGMOID: + case LM_GGML_UNARY_OP_HARDSWISH: + case LM_GGML_UNARY_OP_HARDSIGMOID: + case LM_GGML_UNARY_OP_EXP: + { + n_tasks = 1; + } break; + + case LM_GGML_UNARY_OP_GELU: + case LM_GGML_UNARY_OP_GELU_QUICK: + case LM_GGML_UNARY_OP_SILU: + { + n_tasks = n_threads; + } break; + default: + LM_GGML_ABORT("fatal error"); + } + break; + case LM_GGML_OP_SILU_BACK: + case LM_GGML_OP_MUL: + case LM_GGML_OP_DIV: + case LM_GGML_OP_NORM: + case LM_GGML_OP_RMS_NORM: + case LM_GGML_OP_RMS_NORM_BACK: + case LM_GGML_OP_GROUP_NORM: + case LM_GGML_OP_CONCAT: + case LM_GGML_OP_MUL_MAT: + case LM_GGML_OP_MUL_MAT_ID: + case LM_GGML_OP_OUT_PROD: + { + n_tasks = n_threads; + } break; + case LM_GGML_OP_GET_ROWS: + { + // FIXME: get_rows can use additional threads, but the cost of launching additional threads + // decreases performance with GPU offloading + //n_tasks = n_threads; + n_tasks = 1; + } break; + case LM_GGML_OP_SCALE: + case LM_GGML_OP_SET: + case LM_GGML_OP_RESHAPE: + case LM_GGML_OP_VIEW: + case LM_GGML_OP_PERMUTE: + case LM_GGML_OP_TRANSPOSE: + case LM_GGML_OP_GET_ROWS_BACK: + case LM_GGML_OP_DIAG: + { + n_tasks = 1; + } break; + case LM_GGML_OP_DIAG_MASK_ZERO: + case LM_GGML_OP_DIAG_MASK_INF: + case LM_GGML_OP_SOFT_MAX_BACK: + case LM_GGML_OP_ROPE: + case LM_GGML_OP_ROPE_BACK: + case LM_GGML_OP_ADD_REL_POS: + { + n_tasks = n_threads; + } break; + case LM_GGML_OP_CLAMP: + { + n_tasks = 1; //TODO + } break; + case LM_GGML_OP_SOFT_MAX: + { + n_tasks = MIN(n_threads, lm_ggml_nrows(node->src[0])); + } break; + case LM_GGML_OP_IM2COL: + case LM_GGML_OP_IM2COL_BACK: + case LM_GGML_OP_CONV_TRANSPOSE_1D: + case LM_GGML_OP_CONV_TRANSPOSE_2D: + { + n_tasks = n_threads; + } break; + case LM_GGML_OP_POOL_1D: + case LM_GGML_OP_POOL_2D: + case LM_GGML_OP_POOL_2D_BACK: + { + n_tasks = 1; + } break; + case LM_GGML_OP_UPSCALE: + case LM_GGML_OP_PAD: + case LM_GGML_OP_ARANGE: + case LM_GGML_OP_TIMESTEP_EMBEDDING: + case LM_GGML_OP_ARGSORT: + case LM_GGML_OP_FLASH_ATTN_EXT: + case LM_GGML_OP_FLASH_ATTN_BACK: + case LM_GGML_OP_SSM_CONV: + case LM_GGML_OP_SSM_SCAN: + { + n_tasks = n_threads; + } break; + case LM_GGML_OP_WIN_PART: + case LM_GGML_OP_WIN_UNPART: + case LM_GGML_OP_GET_REL_POS: + case LM_GGML_OP_RWKV_WKV6: + case LM_GGML_OP_MAP_UNARY: + case LM_GGML_OP_MAP_BINARY: + case LM_GGML_OP_MAP_CUSTOM1_F32: + case LM_GGML_OP_MAP_CUSTOM2_F32: + case LM_GGML_OP_MAP_CUSTOM3_F32: + { + n_tasks = 1; + } break; + case LM_GGML_OP_MAP_CUSTOM1: + { + struct lm_ggml_map_custom1_op_params p; + memcpy(&p, node->op_params, sizeof(p)); + if (p.n_tasks == LM_GGML_N_TASKS_MAX) { + n_tasks = n_threads; + } else { + n_tasks = MIN(p.n_tasks, n_threads); + } + } break; + case LM_GGML_OP_MAP_CUSTOM2: + { + struct lm_ggml_map_custom2_op_params p; + memcpy(&p, node->op_params, sizeof(p)); + if (p.n_tasks == LM_GGML_N_TASKS_MAX) { + n_tasks = n_threads; + } else { + n_tasks = MIN(p.n_tasks, n_threads); + } + } break; + case LM_GGML_OP_MAP_CUSTOM3: + { + struct lm_ggml_map_custom3_op_params p; + memcpy(&p, node->op_params, sizeof(p)); + if (p.n_tasks == LM_GGML_N_TASKS_MAX) { + n_tasks = n_threads; + } else { + n_tasks = MIN(p.n_tasks, n_threads); + } + } break; + case LM_GGML_OP_CROSS_ENTROPY_LOSS: + case LM_GGML_OP_CROSS_ENTROPY_LOSS_BACK: + case LM_GGML_OP_OPT_STEP_ADAMW: + { + n_tasks = n_threads; + } break; + case LM_GGML_OP_NONE: + { + n_tasks = 1; + } break; + case LM_GGML_OP_COUNT: + { + LM_GGML_ABORT("fatal error"); + } + default: + { + fprintf(stderr, "%s: op not implemented: ", __func__); + if (node->op < LM_GGML_OP_COUNT) { + fprintf(stderr, "%s\n", lm_ggml_op_name(node->op)); + } else { + fprintf(stderr, "%d\n", node->op); + } + LM_GGML_ABORT("fatal error"); + } + } + + assert(n_tasks > 0); + + return n_tasks; +} + +static thread_ret_t lm_ggml_graph_compute_secondary_thread(void* data); + +#if defined(_WIN32) +#include "windows.h" + +// TODO: support > 64 CPUs +bool lm_ggml_thread_apply_affinity(bool * mask) { + HANDLE h = GetCurrentThread(); + uint64_t bitmask = 0ULL; + + assert(LM_GGML_MAX_N_THREADS >= 64); + + for (int32_t i = 0; i < 8; i++) { + int32_t idx = i * 8; + uint8_t val = 0; + val |= mask[idx + 0] << 0; + val |= mask[idx + 1] << 1; + val |= mask[idx + 2] << 2; + val |= mask[idx + 3] << 3; + val |= mask[idx + 4] << 4; + val |= mask[idx + 5] << 5; + val |= mask[idx + 6] << 6; + val |= mask[idx + 7] << 7; + bitmask |= (uint64_t)val << idx; + } + + for (int32_t i = 64; i < LM_GGML_MAX_N_THREADS; i++) { + if (mask[i]) { + fprintf(stderr, "warn: setting thread-affinity for > 64 CPUs isn't supported on windows!\n"); + break; + } + } + + DWORD_PTR m = (DWORD_PTR)bitmask; + + m = SetThreadAffinityMask(h, m); + + return m != 0; +} + +static bool lm_ggml_thread_apply_priority(int32_t prio) { + // Note that on Windows the Process Priority Class must be updated in order to set Thread priority. + // This is up to the applications. + DWORD p = THREAD_PRIORITY_NORMAL; + switch (prio) { + case LM_GGML_SCHED_PRIO_NORMAL: p = THREAD_PRIORITY_NORMAL; break; + case LM_GGML_SCHED_PRIO_MEDIUM: p = THREAD_PRIORITY_ABOVE_NORMAL; break; + case LM_GGML_SCHED_PRIO_HIGH: p = THREAD_PRIORITY_HIGHEST; break; + case LM_GGML_SCHED_PRIO_REALTIME: p = THREAD_PRIORITY_TIME_CRITICAL; break; + } + + if (prio == LM_GGML_SCHED_PRIO_NORMAL) { + // Keep inherited policy/priority + return true; + } + + if (!SetThreadPriority(GetCurrentThread(), p)) { + fprintf(stderr, "warn: failed to set thread priority %d : (%d)\n", prio, (int) GetLastError()); + return false; + } + + return true; +} + +#elif defined(__APPLE__) +#include +#include + +static bool lm_ggml_thread_apply_affinity(const bool * mask) { + // Not supported on Apple platforms + UNUSED(mask); + return true; +} + +static bool lm_ggml_thread_apply_priority(int32_t prio) { + struct sched_param p; + int32_t policy = SCHED_OTHER; + switch (prio) { + case LM_GGML_SCHED_PRIO_NORMAL: policy = SCHED_OTHER; p.sched_priority = 0; break; + case LM_GGML_SCHED_PRIO_MEDIUM: policy = SCHED_FIFO; p.sched_priority = 40; break; + case LM_GGML_SCHED_PRIO_HIGH: policy = SCHED_FIFO; p.sched_priority = 80; break; + case LM_GGML_SCHED_PRIO_REALTIME: policy = SCHED_FIFO; p.sched_priority = 90; break; + } + + if (prio == LM_GGML_SCHED_PRIO_NORMAL) { + // Keep inherited policy/priority + return true; + } + + int32_t err = pthread_setschedparam(pthread_self(), policy, &p); + if (err != 0) { + fprintf(stderr, "warn: failed to set thread priority %d : %s (%d)\n", prio, strerror(err), err); + return false; + } + + return true; +} + +#elif defined(__gnu_linux__) +// TODO: this may not work on BSD, to be verified + +static bool lm_ggml_thread_apply_affinity(const bool * mask) { + cpu_set_t cpuset; + int err; + + CPU_ZERO(&cpuset); + + for (uint32_t i = 0; i < LM_GGML_MAX_N_THREADS; i++) { + if (mask[i]) { + LM_GGML_PRINT_DEBUG("Thread %lx: adding %d to cpuset\n", pthread_self(), i); + CPU_SET(i, &cpuset); + } + } + +#ifdef __ANDROID__ + err = sched_setaffinity(0, sizeof(cpuset), &cpuset); + if (err < 0) { + err = errno; + } +#else + err = pthread_setaffinity_np(pthread_self(), sizeof(cpuset), &cpuset); +#endif + if (err != 0) { + fprintf(stderr, "warn: failed to set affinity mask 0x%llx : %s (%d)\n", (unsigned long long)mask, strerror(err), err); + return false; + } + + return true; +} + +static bool lm_ggml_thread_apply_priority(int32_t prio) { + struct sched_param p; + int32_t policy = SCHED_OTHER; + switch (prio) { + case LM_GGML_SCHED_PRIO_NORMAL: policy = SCHED_OTHER; p.sched_priority = 0; break; + case LM_GGML_SCHED_PRIO_MEDIUM: policy = SCHED_FIFO; p.sched_priority = 40; break; + case LM_GGML_SCHED_PRIO_HIGH: policy = SCHED_FIFO; p.sched_priority = 80; break; + case LM_GGML_SCHED_PRIO_REALTIME: policy = SCHED_FIFO; p.sched_priority = 90; break; + } + + if (prio == LM_GGML_SCHED_PRIO_NORMAL) { + // Keep inherited policy/priority + return true; + } + + int32_t err = pthread_setschedparam(pthread_self(), policy, &p); + if (err != 0) { + fprintf(stderr, "warn: failed to set thread priority %d : %s (%d)\n", prio, strerror(err), err); + return false; + } + + return true; +} + +#else // unsupported platforms + +static bool lm_ggml_thread_apply_affinity(const bool * mask) { + UNUSED(mask); + return true; +} + +static bool lm_ggml_thread_apply_priority(int32_t prio) { + UNUSED(prio); + return true; +} + +#endif + +static bool lm_ggml_thread_cpumask_is_valid(const bool * mask) { + for (int i = 0; i < LM_GGML_MAX_N_THREADS; i++) { + if (mask[i]) { return true; } + } + return false; +} + +static void lm_ggml_thread_cpumask_next(const bool * global_mask, bool * local_mask, bool strict, int32_t* iter) { + if (!strict) { + memcpy(local_mask, global_mask, LM_GGML_MAX_N_THREADS); + return; + } else { + memset(local_mask, 0, LM_GGML_MAX_N_THREADS); + int32_t base_idx = *iter; + for (int32_t i = 0; i < LM_GGML_MAX_N_THREADS; i++) { + int32_t idx = base_idx + i; + if (idx >= LM_GGML_MAX_N_THREADS) { + // Just a cheaper modulo + idx -= LM_GGML_MAX_N_THREADS; + } + if (global_mask[idx]) { + local_mask[idx] = 1; + *iter = idx + 1; + return; + } + } + } +} + +void lm_ggml_threadpool_free(struct lm_ggml_threadpool* threadpool) { + if (!threadpool) return; + + const int n_threads = threadpool->n_threads_max; + +#ifndef LM_GGML_USE_OPENMP + struct lm_ggml_compute_state* workers = threadpool->workers; + + lm_ggml_mutex_lock(&threadpool->mutex); + + threadpool->stop = true; + threadpool->pause = false; + + lm_ggml_cond_broadcast(&threadpool->cond); + lm_ggml_mutex_unlock(&threadpool->mutex); + + for (int j = 1; j < n_threads; j++) { + int32_t rc = lm_ggml_thread_join(workers[j].thrd, NULL); + LM_GGML_ASSERT(rc == LM_GGML_EXIT_SUCCESS || rc == LM_GGML_EXIT_ABORTED); + UNUSED(rc); + } + + lm_ggml_mutex_destroy(&threadpool->mutex); + lm_ggml_cond_destroy(&threadpool->cond); +#endif // LM_GGML_USE_OPENMP + + const size_t workers_size = sizeof(struct lm_ggml_compute_state) * n_threads; + lm_ggml_aligned_free(threadpool->workers, workers_size); + lm_ggml_aligned_free(threadpool, sizeof(struct lm_ggml_threadpool)); +} + +#ifndef LM_GGML_USE_OPENMP +// pause/resume must be called under mutex +static void lm_ggml_threadpool_pause_locked(struct lm_ggml_threadpool * threadpool) { + LM_GGML_PRINT_DEBUG("Pausing threadpool\n"); + threadpool->pause = true; + lm_ggml_cond_broadcast(&threadpool->cond); +} + +static void lm_ggml_threadpool_resume_locked(struct lm_ggml_threadpool * threadpool) { + LM_GGML_PRINT_DEBUG("Resuming threadpool\n"); + threadpool->pause = false; + lm_ggml_cond_broadcast(&threadpool->cond); +} +#endif + +void lm_ggml_threadpool_pause(struct lm_ggml_threadpool * threadpool) { +#ifndef LM_GGML_USE_OPENMP + lm_ggml_mutex_lock(&threadpool->mutex); + if (!threadpool->pause) { + lm_ggml_threadpool_pause_locked(threadpool); + } + lm_ggml_mutex_unlock(&threadpool->mutex); +#else + UNUSED(threadpool); +#endif +} + +void lm_ggml_threadpool_resume(struct lm_ggml_threadpool * threadpool) { +#ifndef LM_GGML_USE_OPENMP + lm_ggml_mutex_lock(&threadpool->mutex); + if (threadpool->pause) { + lm_ggml_threadpool_resume_locked(threadpool); + } + lm_ggml_mutex_unlock(&threadpool->mutex); +#else + UNUSED(threadpool); +#endif +} + +struct lm_ggml_cplan lm_ggml_graph_plan( + const struct lm_ggml_cgraph * cgraph, + int n_threads, + struct lm_ggml_threadpool * threadpool) { + + if (threadpool == NULL) { + //LM_GGML_PRINT_DEBUG("Threadpool is not specified. Will create a disposable threadpool : n_threads %d\n", n_threads); + } + if (n_threads <= 0) { + n_threads = threadpool ? threadpool->n_threads_max : LM_GGML_DEFAULT_N_THREADS; + } + + size_t work_size = 0; + + struct lm_ggml_cplan cplan; + memset(&cplan, 0, sizeof(struct lm_ggml_cplan)); + + int max_tasks = 1; + + // thread scheduling for the different operations + work buffer size estimation + for (int i = 0; i < cgraph->n_nodes; i++) { + struct lm_ggml_tensor * node = cgraph->nodes[i]; + + const int n_tasks = lm_ggml_get_n_tasks(node, n_threads); + + max_tasks = MAX(max_tasks, n_tasks); + + size_t cur = 0; + + switch (node->op) { + case LM_GGML_OP_CPY: + case LM_GGML_OP_DUP: + { + if (lm_ggml_is_quantized(node->type) || + // F16 -> BF16 and BF16 -> F16 copies go through intermediate F32 + (node->src[0]->type == LM_GGML_TYPE_F16 && node->src[1] && node->src[1]->type == LM_GGML_TYPE_BF16) || + (node->src[0]->type == LM_GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == LM_GGML_TYPE_F16)) { + cur = lm_ggml_type_size(LM_GGML_TYPE_F32) * node->ne[0] * n_tasks; + } + } break; + case LM_GGML_OP_ADD: + case LM_GGML_OP_ADD1: + { + if (lm_ggml_is_quantized(node->src[0]->type)) { + cur = lm_ggml_type_size(LM_GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks; + } + } break; + case LM_GGML_OP_ACC: + { + if (lm_ggml_is_quantized(node->src[0]->type)) { + cur = lm_ggml_type_size(LM_GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks; + } + } break; + case LM_GGML_OP_COUNT_EQUAL: + { + cur = lm_ggml_type_size(node->type)*n_tasks; + } break; + case LM_GGML_OP_MUL_MAT: + { + const enum lm_ggml_type vec_dot_type = type_traits_cpu[node->src[0]->type].vec_dot_type; + + if (node->src[1]->type != vec_dot_type) { + cur = lm_ggml_row_size(vec_dot_type, lm_ggml_nelements(node->src[1])); + } + } break; + case LM_GGML_OP_MUL_MAT_ID: + { + cur = 0; + const struct lm_ggml_tensor * src0 = node->src[0]; + const struct lm_ggml_tensor * src1 = node->src[1]; + const enum lm_ggml_type vec_dot_type = type_traits_cpu[src0->type].vec_dot_type; + if (src1->type != vec_dot_type) { + cur += lm_ggml_row_size(vec_dot_type, lm_ggml_nelements(src1)); + } + const int n_as = src0->ne[2]; + cur += LM_GGML_PAD(cur, sizeof(int64_t)); // align + cur += n_as * sizeof(int64_t); // matrix_row_counts + cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows + } break; + case LM_GGML_OP_OUT_PROD: + { + if (lm_ggml_is_quantized(node->src[0]->type)) { + cur = lm_ggml_type_size(LM_GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks; + } + } break; + case LM_GGML_OP_SOFT_MAX: + case LM_GGML_OP_ROPE: + { + cur = lm_ggml_type_size(LM_GGML_TYPE_F32) * node->ne[0] * n_tasks; + } break; + case LM_GGML_OP_CONV_TRANSPOSE_1D: + { + LM_GGML_ASSERT(node->src[0]->ne[3] == 1); + LM_GGML_ASSERT(node->src[1]->ne[2] == 1); + LM_GGML_ASSERT(node->src[1]->ne[3] == 1); + + const int64_t ne00 = node->src[0]->ne[0]; // K + const int64_t ne01 = node->src[0]->ne[1]; // Cout + const int64_t ne02 = node->src[0]->ne[2]; // Cin + + const int64_t ne10 = node->src[1]->ne[0]; // L + const int64_t ne11 = node->src[1]->ne[1]; // Cin + + if ((node->src[0]->type == LM_GGML_TYPE_F16 || + node->src[0]->type == LM_GGML_TYPE_BF16) && + node->src[1]->type == LM_GGML_TYPE_F32) { + cur += sizeof(lm_ggml_fp16_t)*ne00*ne01*ne02; + cur += sizeof(lm_ggml_fp16_t)*ne10*ne11; + } else if (node->src[0]->type == LM_GGML_TYPE_F32 && + node->src[1]->type == LM_GGML_TYPE_F32) { + cur += sizeof(float)*ne00*ne01*ne02; + cur += sizeof(float)*ne10*ne11; + } else { + LM_GGML_ABORT("fatal error"); + } + } break; + case LM_GGML_OP_CONV_TRANSPOSE_2D: + { + const int64_t ne00 = node->src[0]->ne[0]; // W + const int64_t ne01 = node->src[0]->ne[1]; // H + const int64_t ne02 = node->src[0]->ne[2]; // Channels Out + const int64_t ne03 = node->src[0]->ne[3]; // Channels In + + const int64_t ne10 = node->src[1]->ne[0]; // W + const int64_t ne11 = node->src[1]->ne[1]; // H + const int64_t ne12 = node->src[1]->ne[2]; // Channels In + + cur += sizeof(lm_ggml_fp16_t)*ne00*ne01*ne02*ne03; + cur += sizeof(lm_ggml_fp16_t)*ne10*ne11*ne12; + } break; + case LM_GGML_OP_FLASH_ATTN_EXT: + { + const int64_t ne00 = node->src[0]->ne[0]; // D + + cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread + } break; + case LM_GGML_OP_FLASH_ATTN_BACK: + { + const int64_t D = node->src[0]->ne[0]; + const int64_t ne11 = lm_ggml_up(node->src[1]->ne[1], LM_GGML_SOFT_MAX_UNROLL); + const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in lm_ggml_compute_forward_flash_attn_back + if (node->src[1]->type == LM_GGML_TYPE_F32) { + cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 + } else if (node->src[1]->type == LM_GGML_TYPE_F16) { + cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 + } else if (node->src[1]->type == LM_GGML_TYPE_BF16) { + cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 + } + } break; + + case LM_GGML_OP_CROSS_ENTROPY_LOSS: + { + cur = lm_ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks); + } break; + case LM_GGML_OP_COUNT: + { + LM_GGML_ABORT("fatal error"); + } + default: + break; + } + + work_size = MAX(work_size, cur); + } + + if (work_size > 0) { + work_size += CACHE_LINE_SIZE*(n_threads); + } + + cplan.threadpool = threadpool; + cplan.n_threads = MIN(max_tasks, n_threads); + cplan.work_size = work_size; + cplan.work_data = NULL; + + return cplan; +} + +static thread_ret_t lm_ggml_graph_compute_thread(void * data) { + struct lm_ggml_compute_state * state = (struct lm_ggml_compute_state *) data; + struct lm_ggml_threadpool * tp = state->threadpool; + + const struct lm_ggml_cgraph * cgraph = tp->cgraph; + const struct lm_ggml_cplan * cplan = tp->cplan; + + set_numa_thread_affinity(state->ith); + + struct lm_ggml_compute_params params = { + /*.ith =*/ state->ith, + /*.nth =*/ atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed), + /*.wsize =*/ cplan->work_size, + /*.wdata =*/ cplan->work_data, + /*.threadpool=*/ tp, + }; + + for (int node_n = 0; node_n < cgraph->n_nodes && !tp->abort; node_n++) { + struct lm_ggml_tensor * node = cgraph->nodes[node_n]; + + lm_ggml_compute_forward(¶ms, node); + + if (state->ith == 0 && cplan->abort_callback && + cplan->abort_callback(cplan->abort_callback_data)) { + tp->abort = true; + tp->ec = LM_GGML_STATUS_ABORTED; + } + + lm_ggml_barrier(state->threadpool); + } + + return 0; +} + +#ifndef LM_GGML_USE_OPENMP + +// check if thread is active +static inline bool lm_ggml_graph_compute_thread_active(struct lm_ggml_compute_state * state) { + struct lm_ggml_threadpool * threadpool = state->threadpool; + int n_threads = atomic_load_explicit(&threadpool->n_threads_cur, memory_order_relaxed); + return (state->ith < n_threads); +} + +// check if thread is ready to proceed (exit from polling or sleeping) +static inline bool lm_ggml_graph_compute_thread_ready(struct lm_ggml_compute_state * state) { + struct lm_ggml_threadpool * threadpool = state->threadpool; + + if (state->pending || threadpool->stop || threadpool->pause) { return true; } + + // check for new graph/work + int new_graph = atomic_load_explicit(&threadpool->n_graph, memory_order_relaxed); + if (new_graph != state->last_graph) { + state->pending = lm_ggml_graph_compute_thread_active(state); + state->last_graph = new_graph; + } + + return state->pending; +} + +// sync thread state after polling +static inline void lm_ggml_graph_compute_thread_sync(struct lm_ggml_compute_state * state) { + // TSAN doesn't support standalone fence yet, we use a dummy read-modify-write instead + #ifdef LM_GGML_TSAN_ENABLED + atomic_fetch_add_explicit(&state->threadpool->n_graph, 0, memory_order_seq_cst); + #else + atomic_thread_fence(memory_order_seq_cst); + #endif + UNUSED(state); +} + +static inline bool lm_ggml_graph_compute_poll_for_work(struct lm_ggml_compute_state * state) { + struct lm_ggml_threadpool * threadpool = state->threadpool; + + // Skip polling for unused threads + if (!lm_ggml_graph_compute_thread_active(state)) { + return state->pending; + } + + // This seems to make 0 ... 100 a decent range for polling level across modern processors. + // Perhaps, we can adjust it dynamically based on load and things. + const uint64_t n_rounds = 1024UL * 128 * threadpool->poll; + + for (uint64_t i=0; !lm_ggml_graph_compute_thread_ready(state) && i < n_rounds; i++) { + // No new work. Keep polling. + lm_ggml_thread_cpu_relax(); + } + + return state->pending; +} + +static inline bool lm_ggml_graph_compute_check_for_work(struct lm_ggml_compute_state * state) { + struct lm_ggml_threadpool * threadpool = state->threadpool; + + if (lm_ggml_graph_compute_poll_for_work(state)) { + lm_ggml_graph_compute_thread_sync(state); + return state->pending; + } + + lm_ggml_mutex_lock_shared(&threadpool->mutex); + while (!lm_ggml_graph_compute_thread_ready(state)) { + // No new work. Wait for the signal. + LM_GGML_PRINT_DEBUG("thread #%d waiting for work (sleeping)\n", state->ith); + lm_ggml_cond_wait(&threadpool->cond, &threadpool->mutex); + } + lm_ggml_mutex_unlock_shared(&threadpool->mutex); + + return state->pending; +} + +static thread_ret_t lm_ggml_graph_compute_secondary_thread(void* data) { + struct lm_ggml_compute_state * state = (struct lm_ggml_compute_state *) data; + struct lm_ggml_threadpool * threadpool = state->threadpool; + + lm_ggml_thread_apply_priority(threadpool->prio); + if (lm_ggml_thread_cpumask_is_valid(state->cpumask)) { + lm_ggml_thread_apply_affinity(state->cpumask); + } + + while (true) { + // Check if we need to sleep + while (threadpool->pause) { + LM_GGML_PRINT_DEBUG("thread #%d inside pause loop\n", state->ith); + lm_ggml_mutex_lock_shared(&threadpool->mutex); + if (threadpool->pause) { + lm_ggml_cond_wait(&threadpool->cond, &threadpool->mutex); + } + LM_GGML_PRINT_DEBUG("thread #%d resuming after wait\n", state->ith); + lm_ggml_mutex_unlock_shared(&threadpool->mutex); + } + + // This needs to be checked for after the cond_wait + if (threadpool->stop) break; + + // Check if there is new work + // The main thread is the only one that can dispatch new work + + lm_ggml_graph_compute_check_for_work(state); + if (state->pending) { + state->pending = false; + + lm_ggml_graph_compute_thread(state); + } + } + + return (thread_ret_t) 0; +} + +// Start processing new graph +static void lm_ggml_graph_compute_kickoff(struct lm_ggml_threadpool * threadpool, int n_threads) +{ + // Always take the mutex here because the worker threads are doing hybrid poll/wait + + lm_ggml_mutex_lock(&threadpool->mutex); + + LM_GGML_PRINT_DEBUG("threadpool: n_threads_cur %d n_threads %d\n", threadpool->n_threads_cur, n_threads); + + // Update the number of active threads + atomic_store_explicit(&threadpool->n_threads_cur, n_threads, memory_order_relaxed); + + // Indicate the graph is ready to be processed + // We need the full seq-cst fence here because of the polling threads (used in thread_sync) + atomic_fetch_add_explicit(&threadpool->n_graph, 1, memory_order_seq_cst); + + if (threadpool->pause) { + // Update main thread prio and affinity to match the threadpool settings + lm_ggml_thread_apply_priority(threadpool->prio); + if (lm_ggml_thread_cpumask_is_valid(threadpool->workers[0].cpumask)) { + lm_ggml_thread_apply_affinity(threadpool->workers[0].cpumask); + } + + // resume does cond broadcast + lm_ggml_threadpool_resume_locked(threadpool); + } else { + lm_ggml_cond_broadcast(&threadpool->cond); + } + + lm_ggml_mutex_unlock(&threadpool->mutex); +} + +#endif // LM_GGML_USE_OPENMP + +void lm_ggml_threadpool_params_init(struct lm_ggml_threadpool_params * p, int n_threads) { + p->n_threads = n_threads; + p->prio = 0; // default priority (usually means normal or inherited) + p->poll = 50; // hybrid-polling enabled + p->strict_cpu = false; // no strict placement (all threads share same cpumask) + p->paused = false; // threads are ready to go + memset(p->cpumask, 0, LM_GGML_MAX_N_THREADS); // all-zero means use the default affinity (usually inherited) +} + +struct lm_ggml_threadpool_params lm_ggml_threadpool_params_default(int n_threads) { + struct lm_ggml_threadpool_params p; + lm_ggml_threadpool_params_init(&p, n_threads); + return p; +} + +bool lm_ggml_threadpool_params_match(const struct lm_ggml_threadpool_params * p0, const struct lm_ggml_threadpool_params * p1) { + if (p0->n_threads != p1->n_threads ) return false; + if (p0->prio != p1->prio ) return false; + if (p0->poll != p1->poll ) return false; + if (p0->strict_cpu != p1->strict_cpu ) return false; + return memcmp(p0->cpumask, p1->cpumask, LM_GGML_MAX_N_THREADS) == 0; +} + +static struct lm_ggml_threadpool * lm_ggml_threadpool_new_impl( + struct lm_ggml_threadpool_params * tpp, + struct lm_ggml_cgraph * cgraph, + struct lm_ggml_cplan * cplan) { + + struct lm_ggml_threadpool * threadpool = + lm_ggml_aligned_malloc(sizeof(struct lm_ggml_threadpool)); + { + threadpool->cgraph = cgraph; + threadpool->cplan = cplan; + threadpool->n_graph = 0; + threadpool->n_barrier = 0; + threadpool->n_barrier_passed = 0; + threadpool->current_chunk = 0; + threadpool->stop = false; + threadpool->pause = tpp->paused; + threadpool->abort = false; + threadpool->workers = NULL; + threadpool->n_threads_max = tpp->n_threads; + threadpool->n_threads_cur = tpp->n_threads; + threadpool->poll = tpp->poll; + threadpool->prio = tpp->prio; + threadpool->ec = LM_GGML_STATUS_SUCCESS; + } + + // Allocate and init workers state + const size_t workers_size = sizeof(struct lm_ggml_compute_state) * tpp->n_threads; + struct lm_ggml_compute_state * workers = lm_ggml_aligned_malloc(workers_size); + + memset(workers, 0, workers_size); + for (int j = 0; j < tpp->n_threads; j++) { + workers[j].threadpool = threadpool; + workers[j].ith = j; + } + + threadpool->workers = workers; + +#ifndef LM_GGML_USE_OPENMP + lm_ggml_mutex_init(&threadpool->mutex); + lm_ggml_cond_init(&threadpool->cond); + + // Spin the threads for all workers, and update CPU placements. + // Place the main thread last (towards the higher numbered CPU cores). + + int32_t cpumask_iter = 0; + + for (int j = 1; j < tpp->n_threads; j++) { + lm_ggml_thread_cpumask_next(tpp->cpumask, workers[j].cpumask, tpp->strict_cpu, &cpumask_iter); + + int32_t rc = lm_ggml_thread_create(&workers[j].thrd, NULL, lm_ggml_graph_compute_secondary_thread, &workers[j]); + LM_GGML_ASSERT(rc == 0); + } + + lm_ggml_thread_cpumask_next(tpp->cpumask, workers[0].cpumask, tpp->strict_cpu, &cpumask_iter); + + if (!threadpool->pause) { + // Update main thread prio and affinity at the start, otherwise we'll do it in resume + lm_ggml_thread_apply_priority(threadpool->prio); + if (lm_ggml_thread_cpumask_is_valid(threadpool->workers[0].cpumask)) { + lm_ggml_thread_apply_affinity(threadpool->workers[0].cpumask); + } + } +#endif // LM_GGML_USE_OPENMP + + return threadpool; +} + +struct lm_ggml_threadpool * lm_ggml_threadpool_new(struct lm_ggml_threadpool_params * tpp) { + return lm_ggml_threadpool_new_impl(tpp, NULL, NULL); +} + +enum lm_ggml_status lm_ggml_graph_compute(struct lm_ggml_cgraph * cgraph, struct lm_ggml_cplan * cplan) { + lm_ggml_cpu_init(); + + LM_GGML_ASSERT(cplan); + LM_GGML_ASSERT(cplan->n_threads > 0); + LM_GGML_ASSERT(cplan->work_size == 0 || cplan->work_data != NULL); + + int n_threads = cplan->n_threads; + struct lm_ggml_threadpool * threadpool = cplan->threadpool; + + bool disposable_threadpool = false; + + if (threadpool == NULL) { + //LM_GGML_PRINT_DEBUG("Threadpool is not specified. Will create a disposable threadpool : n_threads %d\n", n_threads); + disposable_threadpool = true; + + struct lm_ggml_threadpool_params ttp = lm_ggml_threadpool_params_default(n_threads); + threadpool = lm_ggml_threadpool_new_impl(&ttp, cgraph, cplan); + } else { + // Reset some of the parameters that need resetting + // No worker threads should be accessing the parameters below at this stage + threadpool->cgraph = cgraph; + threadpool->cplan = cplan; + threadpool->current_chunk = 0; + threadpool->abort = false; + threadpool->ec = LM_GGML_STATUS_SUCCESS; + } + +#ifdef LM_GGML_USE_OPENMP + if (n_threads > 1) { + #pragma omp parallel num_threads(n_threads) + { + #pragma omp single + { + // update the number of threads from the actual number of threads that we got from OpenMP + n_threads = omp_get_num_threads(); + atomic_store_explicit(&threadpool->n_threads_cur, n_threads, memory_order_relaxed); + } + + lm_ggml_graph_compute_thread(&threadpool->workers[omp_get_thread_num()]); + } + } else { + atomic_store_explicit(&threadpool->n_threads_cur, 1, memory_order_relaxed); + lm_ggml_graph_compute_thread(&threadpool->workers[0]); + } +#else + if (n_threads > threadpool->n_threads_max) { + LM_GGML_LOG_WARN("cplan requested more threads (%d) than available (%d)\n", n_threads, threadpool->n_threads_max); + n_threads = threadpool->n_threads_max; + } + + // Kick all threads to start the new graph + lm_ggml_graph_compute_kickoff(threadpool, n_threads); + + // This is a work thread too + lm_ggml_graph_compute_thread(&threadpool->workers[0]); +#endif + + // don't leave affinity set on the main thread + clear_numa_thread_affinity(); + + enum lm_ggml_status ret = threadpool->ec; + + if (disposable_threadpool) { + lm_ggml_threadpool_free(threadpool); + } + + return ret; +} + +enum lm_ggml_status lm_ggml_graph_compute_with_ctx(struct lm_ggml_context * ctx, struct lm_ggml_cgraph * cgraph, int n_threads) { + struct lm_ggml_cplan cplan = lm_ggml_graph_plan(cgraph, n_threads, NULL); + + cplan.work_data = (uint8_t *)lm_ggml_new_buffer(ctx, cplan.work_size); + + return lm_ggml_graph_compute(cgraph, &cplan); +} + + +int lm_ggml_cpu_has_avx(void) { +#if defined(__AVX__) + return 1; +#else + return 0; +#endif +} + +int lm_ggml_cpu_has_avx_vnni(void) { +#if defined(__AVXVNNI__) + return 1; +#else + return 0; +#endif +} + +int lm_ggml_cpu_has_avx2(void) { +#if defined(__AVX2__) + return 1; +#else + return 0; +#endif +} + +int lm_ggml_cpu_has_avx512(void) { +#if defined(__AVX512F__) + return 1; +#else + return 0; +#endif +} + +int lm_ggml_cpu_has_avx512_vbmi(void) { +#if defined(__AVX512VBMI__) + return 1; +#else + return 0; +#endif +} + +int lm_ggml_cpu_has_avx512_vnni(void) { +#if defined(__AVX512VNNI__) + return 1; +#else + return 0; +#endif +} + +int lm_ggml_cpu_has_avx512_bf16(void) { +#if defined(__AVX512BF16__) + return 1; +#else + return 0; +#endif +} + +int lm_ggml_cpu_has_amx_int8(void) { +#if defined(__AMX_INT8__) + return 1; +#else + return 0; +#endif +} + +int lm_ggml_cpu_has_fma(void) { +#if defined(__FMA__) + return 1; +#else + return 0; +#endif +} + +int lm_ggml_cpu_has_arm_fma(void) { +#if defined(__ARM_FEATURE_FMA) + return 1; +#else + return 0; +#endif +} + +int lm_ggml_cpu_has_riscv_v(void) { +#if defined(__riscv_v_intrinsic) + return 1; +#else + return 0; +#endif +} + +int lm_ggml_cpu_has_f16c(void) { +#if defined(__F16C__) + return 1; +#else + return 0; +#endif +} + +int lm_ggml_cpu_has_fp16_va(void) { +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) + return 1; +#else + return 0; +#endif +} + +int lm_ggml_cpu_has_wasm_simd(void) { +#if defined(__wasm_simd128__) + return 1; +#else + return 0; +#endif +} + +int lm_ggml_cpu_has_llamafile(void) { +#if defined(LM_GGML_USE_LLAMAFILE) + return 1; +#else + return 0; +#endif +} + +int lm_ggml_cpu_has_sse3(void) { +#if defined(__SSE3__) + return 1; +#else + return 0; +#endif +} + +int lm_ggml_cpu_has_ssse3(void) { +#if defined(__SSSE3__) + return 1; +#else + return 0; +#endif +} + +int lm_ggml_cpu_has_vsx(void) { +#if defined(__POWER9_VECTOR__) + return 1; +#else + return 0; +#endif +} + +int lm_ggml_cpu_has_neon(void) { +#if defined(__ARM_ARCH) + return lm_ggml_arm_arch_features.has_neon; +#else + return 0; +#endif +} + +int lm_ggml_cpu_has_sve(void) { +#if defined(__ARM_ARCH) + return lm_ggml_arm_arch_features.has_sve; +#else + return 0; +#endif +} + +int lm_ggml_cpu_has_matmul_int8(void) { +#if defined(__ARM_ARCH) + return lm_ggml_arm_arch_features.has_i8mm; +#else + return 0; +#endif +} + +int lm_ggml_cpu_get_sve_cnt(void) { +#if defined(__ARM_ARCH) + return lm_ggml_arm_arch_features.sve_cnt; +#else + return 0; +#endif +} + +void lm_ggml_cpu_init(void) { + // needed to initialize f16 tables + { + struct lm_ggml_init_params params = { 0, NULL, false }; + struct lm_ggml_context * ctx = lm_ggml_init(params); + lm_ggml_free(ctx); + } + + lm_ggml_critical_section_start(); + + static bool is_first_call = true; + + if (is_first_call) { + // initialize GELU, Quick GELU, SILU and EXP F32 tables + { + const uint64_t t_start = lm_ggml_time_us(); UNUSED(t_start); + + for (int i = 0; i < (1 << 16); ++i) { + union { + uint16_t u16; + lm_ggml_fp16_t fp16; + } u = {i}; + float f = LM_GGML_FP16_TO_FP32(u.fp16); + lm_ggml_table_gelu_f16[i] = LM_GGML_FP32_TO_FP16(lm_ggml_gelu_f32(f)); + lm_ggml_table_gelu_quick_f16[i] = LM_GGML_FP32_TO_FP16(lm_ggml_gelu_quick_f32(f)); + } + + const uint64_t t_end = lm_ggml_time_us(); UNUSED(t_end); + + LM_GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0); + } + +#if defined(__ARM_ARCH) + lm_ggml_init_arm_arch_features(); +#endif + + is_first_call = false; + } + + lm_ggml_critical_section_end(); +} diff --git a/cpp/ggml-cpu.cpp b/cpp/ggml-cpu.cpp new file mode 100644 index 00000000..56a2527b --- /dev/null +++ b/cpp/ggml-cpu.cpp @@ -0,0 +1,663 @@ +#include "ggml-backend.h" +#include "ggml-backend-impl.h" +#include "ggml-cpu.h" +#include "ggml-cpu-aarch64.h" +#include "ggml-impl.h" +#include +#include +#include + +#if defined(__APPLE__) +#include +#include +#endif + +#if defined(_WIN32) +#define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX + #define NOMINMAX +#endif +#include +#endif + +// ggml-backend interface + +#ifdef LM_GGML_USE_CPU_HBM + +// buffer type HBM + +#include + +static const char * lm_ggml_backend_cpu_hbm_buffer_type_get_name(lm_ggml_backend_buffer_type_t buft) { + return "CPU_HBM"; + + LM_GGML_UNUSED(buft); +} + +static void lm_ggml_backend_cpu_hbm_buffer_free_buffer(lm_ggml_backend_buffer_t buffer) { + hbw_free(buffer->context); +} + +static lm_ggml_backend_buffer_t lm_ggml_backend_cpu_hbm_buffer_type_alloc_buffer(lm_ggml_backend_buffer_type_t buft, size_t size) { + void * ptr; + int result = hbw_posix_memalign(&ptr, lm_ggml_backend_cpu_buffer_type_get_alignment(buft), size); + if (result != 0) { + LM_GGML_LOG_ERROR("failed to allocate HBM buffer of size %zu\n", size); + return NULL; + } + + lm_ggml_backend_buffer_t buffer = lm_ggml_backend_cpu_buffer_from_ptr(ptr, size); + buffer->buft = buft; + buffer->iface.free_buffer = lm_ggml_backend_cpu_hbm_buffer_free_buffer; + + return buffer; +} + +lm_ggml_backend_buffer_type_t lm_ggml_backend_cpu_hbm_buffer_type(void) { + static struct lm_ggml_backend_buffer_type lm_ggml_backend_cpu_buffer_type_hbm = { + /* .iface = */ { + /* .get_name = */ lm_ggml_backend_cpu_hbm_buffer_type_get_name, + /* .alloc_buffer = */ lm_ggml_backend_cpu_hbm_buffer_type_alloc_buffer, + /* .get_alignment = */ lm_ggml_backend_cpu_buffer_type_get_alignment, + /* .get_max_size = */ NULL, // defaults to SIZE_MAX + /* .get_alloc_size = */ NULL, // defaults to lm_ggml_nbytes + /* .is_host = */ lm_ggml_backend_cpu_buffer_type_is_host, + }, + /* .context = */ NULL, + }; + + return &lm_ggml_backend_cpu_buffer_type_hbm; +} +#endif + +// buffer type AARCH64 + +static void lm_ggml_backend_cpu_aarch64_buffer_init_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor) { + tensor->extra = (void *)lm_ggml_aarch64_get_optimal_repack_type(tensor); // NOLINT + + LM_GGML_UNUSED(buffer); +} + +static void lm_ggml_backend_cpu_aarch64_buffer_set_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + LM_GGML_ASSERT(offset == 0); + LM_GGML_ASSERT(size == lm_ggml_nbytes(tensor)); + + enum lm_ggml_type repack_type = (enum lm_ggml_type)(intptr_t)tensor->extra; + + lm_ggml_aarch64_repack_tensor(tensor, repack_type, data, size); + + LM_GGML_UNUSED(buffer); +} + +static const char * lm_ggml_backend_cpu_aarch64_buffer_type_get_name(lm_ggml_backend_buffer_type_t buft) { + return "CPU_AARCH64"; + + LM_GGML_UNUSED(buft); +} + +static lm_ggml_backend_buffer_t lm_ggml_backend_cpu_aarch64_buffer_type_alloc_buffer(lm_ggml_backend_buffer_type_t buft, size_t size) { + auto * buffer = lm_ggml_backend_buft_alloc_buffer(lm_ggml_backend_cpu_buffer_type(), size); + + if (buffer == NULL) { + return NULL; + } + + buffer->buft = buft; + buffer->iface.init_tensor = lm_ggml_backend_cpu_aarch64_buffer_init_tensor; + buffer->iface.set_tensor = lm_ggml_backend_cpu_aarch64_buffer_set_tensor; + + return buffer; +} + +lm_ggml_backend_buffer_type_t lm_ggml_backend_cpu_aarch64_buffer_type(void) { + static struct lm_ggml_backend_buffer_type lm_ggml_backend_cpu_buffer_type_aarch64 = { + /* .iface = */ { + /* .get_name = */ lm_ggml_backend_cpu_aarch64_buffer_type_get_name, + /* .alloc_buffer = */ lm_ggml_backend_cpu_aarch64_buffer_type_alloc_buffer, + /* .get_alignment = */ lm_ggml_backend_cpu_buffer_type()->iface.get_alignment, + /* .get_max_size = */ NULL, // defaults to SIZE_MAX + /* .get_alloc_size = */ NULL, // defaults to lm_ggml_nbytes + /* .is_host = */ NULL, + }, + /* .device = */ lm_ggml_backend_reg_dev_get(lm_ggml_backend_cpu_reg(), 0), + /* .context = */ NULL, + }; + + return &lm_ggml_backend_cpu_buffer_type_aarch64; +} + +bool lm_ggml_backend_cpu_buft_is_aarch64(lm_ggml_backend_buffer_type_t buft) { + return buft == lm_ggml_backend_cpu_aarch64_buffer_type(); +} + +static lm_ggml_backend_buffer_type_t * lm_ggml_backend_cpu_get_extra_bufts(lm_ggml_backend_dev_t device) { + static std::vector bufts = []() { + std::vector bufts; + +#ifdef LM_GGML_USE_CPU_HBM + bufts.push_back(lm_ggml_backend_cpu_hbm_buffer_type()); +#endif + +#ifdef LM_GGML_USE_CPU_AARCH64 + bufts.push_back(lm_ggml_backend_cpu_aarch64_buffer_type()); +#endif + + bufts.push_back(NULL); + + return bufts; + }(); + + return bufts.data(); + + LM_GGML_UNUSED(device); +} + +// CPU backend - backend (stream) + +struct lm_ggml_backend_cpu_context { + int n_threads; + lm_ggml_threadpool_t threadpool; + + uint8_t * work_data; + size_t work_size; + + lm_ggml_abort_callback abort_callback; + void * abort_callback_data; +}; + +static const char * lm_ggml_backend_cpu_get_name(lm_ggml_backend_t backend) { + return "CPU"; + + LM_GGML_UNUSED(backend); +} + +static void lm_ggml_backend_cpu_free(lm_ggml_backend_t backend) { + struct lm_ggml_backend_cpu_context * cpu_ctx = (struct lm_ggml_backend_cpu_context *)backend->context; + delete[] cpu_ctx->work_data; + delete cpu_ctx; + delete backend; +} + +struct lm_ggml_backend_plan_cpu { + struct lm_ggml_cplan cplan; + struct lm_ggml_cgraph cgraph; +}; + +static lm_ggml_backend_graph_plan_t lm_ggml_backend_cpu_graph_plan_create(lm_ggml_backend_t backend, const struct lm_ggml_cgraph * cgraph) { + struct lm_ggml_backend_cpu_context * cpu_ctx = (struct lm_ggml_backend_cpu_context *)backend->context; + + struct lm_ggml_backend_plan_cpu * cpu_plan = new lm_ggml_backend_plan_cpu; + + cpu_plan->cplan = lm_ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool); + cpu_plan->cgraph = *cgraph; // FIXME: deep copy + + if (cpu_plan->cplan.work_size > 0) { + cpu_plan->cplan.work_data = new uint8_t[cpu_plan->cplan.work_size]; + if (cpu_plan->cplan.work_data == NULL) { + delete cpu_plan; + return NULL; + } + } + + cpu_plan->cplan.abort_callback = cpu_ctx->abort_callback; + cpu_plan->cplan.abort_callback_data = cpu_ctx->abort_callback_data; + + return cpu_plan; +} + +static void lm_ggml_backend_cpu_graph_plan_free(lm_ggml_backend_t backend, lm_ggml_backend_graph_plan_t plan) { + struct lm_ggml_backend_plan_cpu * cpu_plan = (struct lm_ggml_backend_plan_cpu *)plan; + + delete[] cpu_plan->cplan.work_data; + delete cpu_plan; + + LM_GGML_UNUSED(backend); +} + +static enum lm_ggml_status lm_ggml_backend_cpu_graph_plan_compute(lm_ggml_backend_t backend, lm_ggml_backend_graph_plan_t plan) { + struct lm_ggml_backend_plan_cpu * cpu_plan = (struct lm_ggml_backend_plan_cpu *)plan; + + return lm_ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan); + + LM_GGML_UNUSED(backend); +} + +static enum lm_ggml_status lm_ggml_backend_cpu_graph_compute(lm_ggml_backend_t backend, struct lm_ggml_cgraph * cgraph) { + struct lm_ggml_backend_cpu_context * cpu_ctx = (struct lm_ggml_backend_cpu_context *)backend->context; + + struct lm_ggml_cplan cplan = lm_ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool); + + if (cpu_ctx->work_size < cplan.work_size) { + delete[] cpu_ctx->work_data; + cpu_ctx->work_data = new uint8_t[cplan.work_size]; + if (cpu_ctx->work_data == NULL) { + cpu_ctx->work_size = 0; + return LM_GGML_STATUS_ALLOC_FAILED; + } + cpu_ctx->work_size = cplan.work_size; + } + cplan.work_data = (uint8_t *)cpu_ctx->work_data; + + cplan.abort_callback = cpu_ctx->abort_callback; + cplan.abort_callback_data = cpu_ctx->abort_callback_data; + + return lm_ggml_graph_compute(cgraph, &cplan); +} + +static const struct lm_ggml_backend_i lm_ggml_backend_cpu_i = { + /* .get_name = */ lm_ggml_backend_cpu_get_name, + /* .free = */ lm_ggml_backend_cpu_free, + /* .set_tensor_async = */ NULL, + /* .get_tensor_async = */ NULL, + /* .cpy_tensor_async = */ NULL, + /* .synchronize = */ NULL, + /* .graph_plan_create = */ lm_ggml_backend_cpu_graph_plan_create, + /* .graph_plan_free = */ lm_ggml_backend_cpu_graph_plan_free, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ lm_ggml_backend_cpu_graph_plan_compute, + /* .graph_compute = */ lm_ggml_backend_cpu_graph_compute, + /* .event_record = */ NULL, + /* .event_wait = */ NULL, +}; + +static lm_ggml_guid_t lm_ggml_backend_cpu_guid(void) { + static lm_ggml_guid guid = { 0xaa, 0x67, 0xc7, 0x43, 0x96, 0xe6, 0xa3, 0x8a, 0xe3, 0xaf, 0xea, 0x92, 0x36, 0xbc, 0xfc, 0x89 }; + return &guid; +} + +lm_ggml_backend_t lm_ggml_backend_cpu_init(void) { + // initialize CPU backend now to avoid slowing the first graph computation + lm_ggml_cpu_init(); + + struct lm_ggml_backend_cpu_context * ctx = new lm_ggml_backend_cpu_context; + if (ctx == NULL) { + return NULL; + } + + ctx->n_threads = LM_GGML_DEFAULT_N_THREADS; + ctx->threadpool = NULL; + ctx->work_data = NULL; + ctx->work_size = 0; + ctx->abort_callback = NULL; + ctx->abort_callback_data = NULL; + + lm_ggml_backend_t cpu_backend = new lm_ggml_backend { + /* .guid = */ lm_ggml_backend_cpu_guid(), + /* .interface = */ lm_ggml_backend_cpu_i, + /* .device = */ lm_ggml_backend_reg_dev_get(lm_ggml_backend_cpu_reg(), 0), + /* .context = */ ctx, + }; + + if (cpu_backend == NULL) { + delete ctx; + return NULL; + } + + return cpu_backend; +} + +bool lm_ggml_backend_is_cpu(lm_ggml_backend_t backend) { + return backend != NULL && lm_ggml_guid_matches(backend->guid, lm_ggml_backend_cpu_guid()); +} + +void lm_ggml_backend_cpu_set_n_threads(lm_ggml_backend_t backend_cpu, int n_threads) { + LM_GGML_ASSERT(lm_ggml_backend_is_cpu(backend_cpu)); + + struct lm_ggml_backend_cpu_context * ctx = (struct lm_ggml_backend_cpu_context *)backend_cpu->context; + ctx->n_threads = n_threads; +} + +void lm_ggml_backend_cpu_set_threadpool(lm_ggml_backend_t backend_cpu, lm_ggml_threadpool_t threadpool) { + LM_GGML_ASSERT(lm_ggml_backend_is_cpu(backend_cpu)); + + struct lm_ggml_backend_cpu_context * ctx = (struct lm_ggml_backend_cpu_context *)backend_cpu->context; + + if (ctx->threadpool && ctx->threadpool != threadpool) { + // already had a different threadpool, pause/suspend it before switching + lm_ggml_threadpool_pause(ctx->threadpool); + } + ctx->threadpool = threadpool; +} + +void lm_ggml_backend_cpu_set_abort_callback(lm_ggml_backend_t backend_cpu, lm_ggml_abort_callback abort_callback, void * abort_callback_data) { + LM_GGML_ASSERT(lm_ggml_backend_is_cpu(backend_cpu)); + + struct lm_ggml_backend_cpu_context * ctx = (struct lm_ggml_backend_cpu_context *)backend_cpu->context; + ctx->abort_callback = abort_callback; + ctx->abort_callback_data = abort_callback_data; +} + +// CPU backend - device + +struct lm_ggml_backend_cpu_device_context { + std::string description = "CPU"; + + lm_ggml_backend_cpu_device_context() { +#ifdef __APPLE__ + size_t len = 0; + if (!sysctlbyname("machdep.cpu.brand_string", NULL, &len, NULL, 0)) { + description.resize(len); + sysctlbyname("machdep.cpu.brand_string", &description[0], &len, NULL, 0); // NOLINT + } +#elif defined(__linux__) + FILE * f = fopen("/proc/cpuinfo", "r"); + if (f) { + char buf[1024]; + while (fgets(buf, sizeof(buf), f)) { + if (strncmp(buf, "model name", 10) == 0) { + char * p = strchr(buf, ':'); + if (p) { + p++; + while (std::isspace(*p)) { + p++; + } + while (std::isspace(p[strlen(p) - 1])) { + p[strlen(p) - 1] = '\0'; + } + description = p; + break; + } + } + } + fclose(f); + } +#elif defined(_WIN32) + HKEY hKey; + if (RegOpenKeyEx(HKEY_LOCAL_MACHINE, + TEXT("HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0"), + 0, + KEY_READ, + &hKey) == ERROR_SUCCESS) { + DWORD cpu_brand_size = 0; + if (RegQueryValueExA(hKey, + TEXT("ProcessorNameString"), + NULL, + NULL, + NULL, + &cpu_brand_size) == ERROR_SUCCESS) { + description.resize(cpu_brand_size); + if (RegQueryValueExA(hKey, + TEXT("ProcessorNameString"), + NULL, + NULL, + (LPBYTE)&description[0], // NOLINT + &cpu_brand_size) == ERROR_SUCCESS) { + if (description.find('\0') != std::string::npos) { + description.resize(description.find('\0')); + } + } + } + RegCloseKey(hKey); + } +#endif + } +}; + +static const char * lm_ggml_backend_cpu_device_get_name(lm_ggml_backend_dev_t dev) { + return "CPU"; + + LM_GGML_UNUSED(dev); +} + +static const char * lm_ggml_backend_cpu_device_get_description(lm_ggml_backend_dev_t dev) { + struct lm_ggml_backend_cpu_device_context * ctx = (struct lm_ggml_backend_cpu_device_context *)dev->context; + + return ctx->description.c_str(); +} + +static void lm_ggml_backend_cpu_device_get_memory(lm_ggml_backend_dev_t dev, size_t * free, size_t * total) { + // TODO + *free = 0; + *total = 0; + + LM_GGML_UNUSED(dev); +} + +static enum lm_ggml_backend_dev_type lm_ggml_backend_cpu_device_get_type(lm_ggml_backend_dev_t dev) { + return LM_GGML_BACKEND_DEVICE_TYPE_CPU; + + LM_GGML_UNUSED(dev); +} + +static void lm_ggml_backend_cpu_device_get_props(lm_ggml_backend_dev_t dev, struct lm_ggml_backend_dev_props * props) { + props->name = lm_ggml_backend_cpu_device_get_name(dev); + props->description = lm_ggml_backend_cpu_device_get_description(dev); + props->type = lm_ggml_backend_cpu_device_get_type(dev); + lm_ggml_backend_cpu_device_get_memory(dev, &props->memory_free, &props->memory_total); + props->caps = { + /* .async = */ false, + /* .host_buffer = */ false, + /* .buffer_from_host_ptr = */ true, + /* .events = */ false, + }; +} + +static lm_ggml_backend_t lm_ggml_backend_cpu_device_init_backend(lm_ggml_backend_dev_t dev, const char * params) { + return lm_ggml_backend_cpu_init(); + + LM_GGML_UNUSED(dev); + LM_GGML_UNUSED(params); +} + +static lm_ggml_backend_buffer_type_t lm_ggml_backend_cpu_device_get_buffer_type(lm_ggml_backend_dev_t dev) { + return lm_ggml_backend_cpu_buffer_type(); + + LM_GGML_UNUSED(dev); +} + +static lm_ggml_backend_buffer_t lm_ggml_backend_cpu_device_buffer_from_host_ptr(lm_ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { + return lm_ggml_backend_cpu_buffer_from_ptr(ptr, size); + + LM_GGML_UNUSED(dev); + LM_GGML_UNUSED(max_tensor_size); +} + +static bool lm_ggml_backend_cpu_device_supports_op(lm_ggml_backend_dev_t dev, const struct lm_ggml_tensor * op) { + const struct lm_ggml_tensor * src0 = op->src[0]; + const struct lm_ggml_tensor * src1 = op->src[1]; + + if (src0 && src0->buffer && lm_ggml_backend_cpu_buft_is_aarch64(src0->buffer->buft)) { + if (op->op != LM_GGML_OP_MUL_MAT || src0->type != LM_GGML_TYPE_Q4_0 || lm_ggml_aarch64_get_optimal_repack_type(src0) == LM_GGML_TYPE_Q4_0) { + return false; + } + } + + for (int i = 1; i < LM_GGML_MAX_SRC; i++) { + if (op->src[i] && op->src[i]->buffer && lm_ggml_backend_cpu_buft_is_aarch64(op->src[i]->buffer->buft)) { + return false; + } + } + + switch (op->op) { + case LM_GGML_OP_CPY: + return + op->type != LM_GGML_TYPE_IQ2_XXS && + op->type != LM_GGML_TYPE_IQ2_XS && + op->type != LM_GGML_TYPE_IQ1_S && + op->type != LM_GGML_TYPE_IQ1_M; // missing type_traits.from_float + case LM_GGML_OP_MUL_MAT: + return src1->type == LM_GGML_TYPE_F32 || src1->type == lm_ggml_get_type_traits_cpu(src0->type)->vec_dot_type; + case LM_GGML_OP_ROPE_BACK: + return op->src[2] == NULL && (op->op_params[2] & 4) == 0; + case LM_GGML_OP_IM2COL_BACK: + return src0->type == LM_GGML_TYPE_F32 && src1->type == LM_GGML_TYPE_F32; + case LM_GGML_OP_OUT_PROD: + return (src0->type == LM_GGML_TYPE_F32 || lm_ggml_is_quantized(src0->type)) && src1->type == LM_GGML_TYPE_F32; + default: + return true; + } + + LM_GGML_UNUSED(dev); +} + +static bool lm_ggml_backend_cpu_device_supports_buft(lm_ggml_backend_dev_t dev, lm_ggml_backend_buffer_type_t buft) { + return lm_ggml_backend_buft_is_host(buft) || lm_ggml_backend_cpu_buft_is_aarch64(buft); + + LM_GGML_UNUSED(dev); +} + +static const struct lm_ggml_backend_device_i lm_ggml_backend_cpu_device_i = { + /* .get_name = */ lm_ggml_backend_cpu_device_get_name, + /* .get_description = */ lm_ggml_backend_cpu_device_get_description, + /* .get_memory = */ lm_ggml_backend_cpu_device_get_memory, + /* .get_type = */ lm_ggml_backend_cpu_device_get_type, + /* .get_props = */ lm_ggml_backend_cpu_device_get_props, + /* .init_backend = */ lm_ggml_backend_cpu_device_init_backend, + /* .get_buffer_type = */ lm_ggml_backend_cpu_device_get_buffer_type, + /* .get_host_buffer_type = */ NULL, + /* .buffer_from_host_ptr = */ lm_ggml_backend_cpu_device_buffer_from_host_ptr, + /* .supports_op = */ lm_ggml_backend_cpu_device_supports_op, + /* .supports_buft = */ lm_ggml_backend_cpu_device_supports_buft, + /* .offload_op = */ NULL, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_synchronize = */ NULL, +}; + +// CPU backend - backend (reg) + +static const char * lm_ggml_backend_cpu_reg_get_name(lm_ggml_backend_reg_t reg) { + return "CPU"; + + LM_GGML_UNUSED(reg); +} + +static size_t lm_ggml_backend_cpu_reg_get_device_count(lm_ggml_backend_reg_t reg) { + return 1; + + LM_GGML_UNUSED(reg); +} + +static lm_ggml_backend_dev_t lm_ggml_backend_cpu_reg_get_device(lm_ggml_backend_reg_t reg, size_t index) { + LM_GGML_ASSERT(index == 0); + + static lm_ggml_backend_cpu_device_context ctx; + static lm_ggml_backend_device lm_ggml_backend_cpu_device = { + /* .iface = */ lm_ggml_backend_cpu_device_i, + /* .reg = */ reg, + /* .context = */ &ctx, + }; + + return &lm_ggml_backend_cpu_device; +} + +struct lm_ggml_backend_feature { + const char * name; + const char * value; +}; + +// Not used yet +// This is intended to replace the the lm_ggml_cpu_has_* functions when loading the CPU backend dynamically, +// and additionally to allow other backends to expose their own list of features that applications can query using the same API. +static lm_ggml_backend_feature * lm_ggml_backend_cpu_get_features(lm_ggml_backend_reg_t reg) { + static std::vector features = []() { + std::vector features; + if (lm_ggml_cpu_has_sse3()) { + features.push_back({ "SSE3", "1" }); + } + if (lm_ggml_cpu_has_ssse3()) { + features.push_back({ "SSSE3", "1" }); + } + if (lm_ggml_cpu_has_avx()) { + features.push_back({ "AVX", "1" }); + } + if (lm_ggml_cpu_has_avx2()) { + features.push_back({ "AVX2", "1" }); + } + if (lm_ggml_cpu_has_f16c()) { + features.push_back({ "F16C", "1" }); + } + if (lm_ggml_cpu_has_fma()) { + features.push_back({ "FMA", "1" }); + } + if (lm_ggml_cpu_has_avx_vnni()) { + features.push_back({ "AVX_VNNI", "1" }); + } + if (lm_ggml_cpu_has_avx512()) { + features.push_back({ "AVX512", "1" }); + } + if (lm_ggml_cpu_has_avx512_vbmi()) { + features.push_back({ "AVX512_VBMI", "1" }); + } + if (lm_ggml_cpu_has_avx512_vnni()) { + features.push_back({ "AVX512_VNNI", "1" }); + } + if (lm_ggml_cpu_has_avx512_bf16()) { + features.push_back({ "AVX512_BF16", "1" }); + } + if (lm_ggml_cpu_has_amx_int8()) { + features.push_back({ "AMX_INT8", "1" }); + } + if (lm_ggml_cpu_has_neon()) { + features.push_back({ "NEON", "1" }); + } + if (lm_ggml_cpu_has_arm_fma()) { + features.push_back({ "ARM_FMA", "1" }); + } + if (lm_ggml_cpu_has_fp16_va()) { + features.push_back({ "FP16_VA", "1" }); + } + if (lm_ggml_cpu_has_matmul_int8()) { + features.push_back({ "MATMUL_INT8", "1" }); + } + if (lm_ggml_cpu_has_sve()) { + features.push_back({ "SVE", "1" }); + } + if (lm_ggml_cpu_get_sve_cnt() > 0) { + static std::string sve_cnt = std::to_string(lm_ggml_cpu_get_sve_cnt()); + features.push_back({ "SVE_CNT", sve_cnt.c_str() }); + } + if (lm_ggml_cpu_has_riscv_v()) { + features.push_back({ "RISCV_V", "1" }); + } + if (lm_ggml_cpu_has_vsx()) { + features.push_back({ "VSX", "1" }); + } + if (lm_ggml_cpu_has_wasm_simd()) { + features.push_back({ "WASM_SIMD", "1" }); + } + if (lm_ggml_cpu_has_llamafile()) { + features.push_back({ "LLAMAFILE", "1" }); + } + + features.push_back({ nullptr, nullptr }); + + return features; + }(); + + return features.data(); + + LM_GGML_UNUSED(reg); +} + +static void * lm_ggml_backend_cpu_get_proc_address(lm_ggml_backend_reg_t reg, const char * name) { + if (strcmp(name, "lm_ggml_backend_set_n_threads") == 0) { + return (void *)lm_ggml_backend_cpu_set_n_threads; + } + if (strcmp(name, "lm_ggml_backend_dev_get_extra_bufts") == 0) { + return (void *)lm_ggml_backend_cpu_get_extra_bufts; + } + + return NULL; + + LM_GGML_UNUSED(reg); +} + +static const struct lm_ggml_backend_reg_i lm_ggml_backend_cpu_reg_i = { + /* .get_name = */ lm_ggml_backend_cpu_reg_get_name, + /* .get_device_count = */ lm_ggml_backend_cpu_reg_get_device_count, + /* .get_device = */ lm_ggml_backend_cpu_reg_get_device, + /* .get_proc_address = */ lm_ggml_backend_cpu_get_proc_address, +}; + +lm_ggml_backend_reg_t lm_ggml_backend_cpu_reg(void) { + // init CPU feature detection + lm_ggml_cpu_init(); + + static struct lm_ggml_backend_reg lm_ggml_backend_cpu_reg = { + /* .iface = */ lm_ggml_backend_cpu_reg_i, + /* .context = */ NULL, + }; + + return &lm_ggml_backend_cpu_reg; +} diff --git a/cpp/ggml-cpu.h b/cpp/ggml-cpu.h new file mode 100644 index 00000000..a49af8ff --- /dev/null +++ b/cpp/ggml-cpu.h @@ -0,0 +1,177 @@ +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + +#ifdef __cplusplus +extern "C" { +#endif + + // Scheduling priorities + enum lm_ggml_sched_priority { + LM_GGML_SCHED_PRIO_NORMAL, + LM_GGML_SCHED_PRIO_MEDIUM, + LM_GGML_SCHED_PRIO_HIGH, + LM_GGML_SCHED_PRIO_REALTIME + }; + + // Threadpool params + // Use lm_ggml_threadpool_params_default() or lm_ggml_threadpool_params_init() to populate the defaults + struct lm_ggml_threadpool_params { + bool cpumask[LM_GGML_MAX_N_THREADS]; // mask of cpu cores (all-zeros means use default affinity settings) + int n_threads; // number of threads + enum lm_ggml_sched_priority prio; // thread priority + uint32_t poll; // polling level (0 - no polling, 100 - aggressive polling) + bool strict_cpu; // strict cpu placement + bool paused; // start in paused state + }; + + struct lm_ggml_threadpool; // forward declaration, see ggml.c + + typedef struct lm_ggml_threadpool * lm_ggml_threadpool_t; + + // the compute plan that needs to be prepared for lm_ggml_graph_compute() + // since https://github.com/ggerganov/ggml/issues/287 + struct lm_ggml_cplan { + size_t work_size; // size of work buffer, calculated by `lm_ggml_graph_plan()` + uint8_t * work_data; // work buffer, to be allocated by caller before calling to `lm_ggml_graph_compute()` + + int n_threads; + struct lm_ggml_threadpool * threadpool; + + // abort lm_ggml_graph_compute when true + lm_ggml_abort_callback abort_callback; + void * abort_callback_data; + }; + + // numa strategies + enum lm_ggml_numa_strategy { + LM_GGML_NUMA_STRATEGY_DISABLED = 0, + LM_GGML_NUMA_STRATEGY_DISTRIBUTE = 1, + LM_GGML_NUMA_STRATEGY_ISOLATE = 2, + LM_GGML_NUMA_STRATEGY_NUMACTL = 3, + LM_GGML_NUMA_STRATEGY_MIRROR = 4, + LM_GGML_NUMA_STRATEGY_COUNT + }; + + LM_GGML_BACKEND_API void lm_ggml_numa_init(enum lm_ggml_numa_strategy numa); // call once for better performance on NUMA systems + LM_GGML_BACKEND_API bool lm_ggml_is_numa(void); // true if init detected that system has >1 NUMA node + + LM_GGML_BACKEND_API struct lm_ggml_tensor * lm_ggml_new_i32(struct lm_ggml_context * ctx, int32_t value); + LM_GGML_BACKEND_API struct lm_ggml_tensor * lm_ggml_new_f32(struct lm_ggml_context * ctx, float value); + + LM_GGML_BACKEND_API struct lm_ggml_tensor * lm_ggml_set_i32 (struct lm_ggml_tensor * tensor, int32_t value); + LM_GGML_BACKEND_API struct lm_ggml_tensor * lm_ggml_set_f32 (struct lm_ggml_tensor * tensor, float value); + + LM_GGML_BACKEND_API int32_t lm_ggml_get_i32_1d(const struct lm_ggml_tensor * tensor, int i); + LM_GGML_BACKEND_API void lm_ggml_set_i32_1d(const struct lm_ggml_tensor * tensor, int i, int32_t value); + + LM_GGML_BACKEND_API int32_t lm_ggml_get_i32_nd(const struct lm_ggml_tensor * tensor, int i0, int i1, int i2, int i3); + LM_GGML_BACKEND_API void lm_ggml_set_i32_nd(const struct lm_ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value); + + LM_GGML_BACKEND_API float lm_ggml_get_f32_1d(const struct lm_ggml_tensor * tensor, int i); + LM_GGML_BACKEND_API void lm_ggml_set_f32_1d(const struct lm_ggml_tensor * tensor, int i, float value); + + LM_GGML_BACKEND_API float lm_ggml_get_f32_nd(const struct lm_ggml_tensor * tensor, int i0, int i1, int i2, int i3); + LM_GGML_BACKEND_API void lm_ggml_set_f32_nd(const struct lm_ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value); + + LM_GGML_BACKEND_API struct lm_ggml_threadpool_params lm_ggml_threadpool_params_default(int n_threads); + LM_GGML_BACKEND_API void lm_ggml_threadpool_params_init (struct lm_ggml_threadpool_params * p, int n_threads); + LM_GGML_BACKEND_API bool lm_ggml_threadpool_params_match (const struct lm_ggml_threadpool_params * p0, const struct lm_ggml_threadpool_params * p1); + LM_GGML_BACKEND_API struct lm_ggml_threadpool * lm_ggml_threadpool_new (struct lm_ggml_threadpool_params * params); + LM_GGML_BACKEND_API void lm_ggml_threadpool_free (struct lm_ggml_threadpool * threadpool); + LM_GGML_BACKEND_API int lm_ggml_threadpool_get_n_threads(struct lm_ggml_threadpool * threadpool); + LM_GGML_BACKEND_API void lm_ggml_threadpool_pause (struct lm_ggml_threadpool * threadpool); + LM_GGML_BACKEND_API void lm_ggml_threadpool_resume (struct lm_ggml_threadpool * threadpool); + + // lm_ggml_graph_plan() has to be called before lm_ggml_graph_compute() + // when plan.work_size > 0, caller must allocate memory for plan.work_data + LM_GGML_BACKEND_API struct lm_ggml_cplan lm_ggml_graph_plan( + const struct lm_ggml_cgraph * cgraph, + int n_threads, /* = LM_GGML_DEFAULT_N_THREADS */ + struct lm_ggml_threadpool * threadpool /* = NULL */ ); + LM_GGML_BACKEND_API enum lm_ggml_status lm_ggml_graph_compute(struct lm_ggml_cgraph * cgraph, struct lm_ggml_cplan * cplan); + + // same as lm_ggml_graph_compute() but the work data is allocated as a part of the context + // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data + LM_GGML_BACKEND_API enum lm_ggml_status lm_ggml_graph_compute_with_ctx(struct lm_ggml_context * ctx, struct lm_ggml_cgraph * cgraph, int n_threads); + + // + // system info + // + + // x86 + LM_GGML_BACKEND_API int lm_ggml_cpu_has_sse3 (void); + LM_GGML_BACKEND_API int lm_ggml_cpu_has_ssse3 (void); + LM_GGML_BACKEND_API int lm_ggml_cpu_has_avx (void); + LM_GGML_BACKEND_API int lm_ggml_cpu_has_avx2 (void); + LM_GGML_BACKEND_API int lm_ggml_cpu_has_f16c (void); + LM_GGML_BACKEND_API int lm_ggml_cpu_has_fma (void); + LM_GGML_BACKEND_API int lm_ggml_cpu_has_avx_vnni (void); + LM_GGML_BACKEND_API int lm_ggml_cpu_has_avx512 (void); + LM_GGML_BACKEND_API int lm_ggml_cpu_has_avx512_vbmi(void); + LM_GGML_BACKEND_API int lm_ggml_cpu_has_avx512_vnni(void); + LM_GGML_BACKEND_API int lm_ggml_cpu_has_avx512_bf16(void); + LM_GGML_BACKEND_API int lm_ggml_cpu_has_amx_int8 (void); + // ARM + LM_GGML_BACKEND_API int lm_ggml_cpu_has_neon (void); + LM_GGML_BACKEND_API int lm_ggml_cpu_has_arm_fma (void); + LM_GGML_BACKEND_API int lm_ggml_cpu_has_fp16_va (void); + LM_GGML_BACKEND_API int lm_ggml_cpu_has_matmul_int8(void); + LM_GGML_BACKEND_API int lm_ggml_cpu_has_sve (void); + LM_GGML_BACKEND_API int lm_ggml_cpu_get_sve_cnt (void); // sve vector length in bytes + // other + LM_GGML_BACKEND_API int lm_ggml_cpu_has_riscv_v (void); + LM_GGML_BACKEND_API int lm_ggml_cpu_has_vsx (void); + LM_GGML_BACKEND_API int lm_ggml_cpu_has_wasm_simd (void); + LM_GGML_BACKEND_API int lm_ggml_cpu_has_llamafile (void); + + // Internal types and functions exposed for tests and benchmarks + + typedef void (*lm_ggml_from_float_to_mat_t) + (const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t nr, int64_t k, int64_t bs); + typedef void (*lm_ggml_vec_dot_t) (int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT x, size_t bx, + const void * LM_GGML_RESTRICT y, size_t by, int nrc); + typedef void (*lm_ggml_gemv_t) (int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT x, + const void * LM_GGML_RESTRICT y, int nr, int nc); + typedef void (*lm_ggml_gemm_t) (int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT x, + const void * LM_GGML_RESTRICT y, int nr, int nc); + + struct lm_ggml_type_traits_cpu { + lm_ggml_from_float_t from_float; + lm_ggml_from_float_to_mat_t from_float_to_mat; + lm_ggml_vec_dot_t vec_dot; + enum lm_ggml_type vec_dot_type; + int64_t nrows; // number of rows to process simultaneously + int64_t ncols; // number of columns to process simultaneously + lm_ggml_gemv_t gemv; + lm_ggml_gemm_t gemm; + }; + + LM_GGML_BACKEND_API const struct lm_ggml_type_traits_cpu * lm_ggml_get_type_traits_cpu(enum lm_ggml_type type); + + LM_GGML_BACKEND_API void lm_ggml_cpu_init(void); + + // + // CPU backend + // + + LM_GGML_BACKEND_API lm_ggml_backend_t lm_ggml_backend_cpu_init(void); + + LM_GGML_BACKEND_API bool lm_ggml_backend_is_cpu (lm_ggml_backend_t backend); + LM_GGML_BACKEND_API void lm_ggml_backend_cpu_set_n_threads (lm_ggml_backend_t backend_cpu, int n_threads); + LM_GGML_BACKEND_API void lm_ggml_backend_cpu_set_threadpool (lm_ggml_backend_t backend_cpu, lm_ggml_threadpool_t threadpool); + LM_GGML_BACKEND_API void lm_ggml_backend_cpu_set_abort_callback(lm_ggml_backend_t backend_cpu, lm_ggml_abort_callback abort_callback, void * abort_callback_data); + + LM_GGML_BACKEND_API lm_ggml_backend_reg_t lm_ggml_backend_cpu_reg(void); + +#ifdef LM_GGML_USE_CPU_HBM + LM_GGML_BACKEND_API lm_ggml_backend_buffer_type_t lm_ggml_backend_cpu_hbm_buffer_type(void); +#endif + + LM_GGML_BACKEND_API lm_ggml_backend_buffer_type_t lm_ggml_backend_cpu_aarch64_buffer_type(void); + LM_GGML_BACKEND_API bool lm_ggml_backend_cpu_buft_is_aarch64(lm_ggml_backend_buffer_type_t buft); + +#ifdef __cplusplus +} +#endif diff --git a/cpp/ggml-impl.h b/cpp/ggml-impl.h index d25c8680..e01aedf1 100644 --- a/cpp/ggml-impl.h +++ b/cpp/ggml-impl.h @@ -3,11 +3,28 @@ // GGML internal header #include "ggml.h" - #include +#include #include // load `stdlib.h` before other headers to work around MinGW bug: https://sourceforge.net/p/mingw-w64/bugs/192/ #include #include +#include + +#ifdef __ARM_FEATURE_SVE +#include +#endif // __ARM_FEATURE_SVE + +#if defined(__ARM_NEON) +// if YCM cannot find , make a symbolic link to it, for example: +// +// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/ +// +#include +#endif + +#if defined(__F16C__) +#include +#endif #ifdef __cplusplus extern "C" { @@ -27,15 +44,29 @@ extern "C" { // if C99 - static_assert is noop // ref: https://stackoverflow.com/a/53923785/4039976 #ifndef __cplusplus -#ifndef static_assert -#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L) -#define static_assert(cond, msg) _Static_assert(cond, msg) -#else -#define static_assert(cond, msg) struct global_scope_noop_trick -#endif -#endif + #ifndef static_assert + #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L) + #define static_assert(cond, msg) _Static_assert(cond, msg) + #else + #define static_assert(cond, msg) struct global_scope_noop_trick + #endif + #endif #endif +static inline int lm_ggml_up32(int n) { + return (n + 31) & ~31; +} + +//static inline int lm_ggml_up64(int n) { +// return (n + 63) & ~63; +//} + +static inline int lm_ggml_up(int n, int m) { + // assert m is a power of 2 + LM_GGML_ASSERT((m & (m - 1)) == 0); + return (n + m - 1) & ~(m - 1); +} + // // logging // @@ -51,6 +82,72 @@ void lm_ggml_log_callback_default(enum lm_ggml_log_level level, const char * tex #define LM_GGML_LOG_DEBUG(...) lm_ggml_log_internal(LM_GGML_LOG_LEVEL_DEBUG, __VA_ARGS__) #define LM_GGML_LOG_CONT(...) lm_ggml_log_internal(LM_GGML_LOG_LEVEL_CONT , __VA_ARGS__) +#define LM_GGML_DEBUG 0 + +#if (LM_GGML_DEBUG >= 1) +#define LM_GGML_PRINT_DEBUG(...) LM_GGML_LOG_DEBUG(__VA_ARGS__) +#else +#define LM_GGML_PRINT_DEBUG(...) +#endif + +#if (LM_GGML_DEBUG >= 5) +#define LM_GGML_PRINT_DEBUG_5(...) LM_GGML_LOG_DEBUG(__VA_ARGS__) +#else +#define LM_GGML_PRINT_DEBUG_5(...) +#endif + +#if (LM_GGML_DEBUG >= 10) +#define LM_GGML_PRINT_DEBUG_10(...) LM_GGML_LOG_DEBUG(__VA_ARGS__) +#else +#define LM_GGML_PRINT_DEBUG_10(...) +#endif + +// tensor params + +static void lm_ggml_set_op_params(struct lm_ggml_tensor * tensor, const void * params, size_t params_size) { + LM_GGML_ASSERT(tensor != NULL); // silence -Warray-bounds warnings + assert(params_size <= LM_GGML_MAX_OP_PARAMS); + memcpy(tensor->op_params, params, params_size); +} + +static int32_t lm_ggml_get_op_params_i32(const struct lm_ggml_tensor * tensor, uint32_t i) { + assert(i < LM_GGML_MAX_OP_PARAMS / sizeof(int32_t)); + return ((const int32_t *)(tensor->op_params))[i]; +} + +static float lm_ggml_get_op_params_f32(const struct lm_ggml_tensor * tensor, uint32_t i) { + assert(i < LM_GGML_MAX_OP_PARAMS / sizeof(float)); + return ((const float *)(tensor->op_params))[i]; +} + +static void lm_ggml_set_op_params_i32(struct lm_ggml_tensor * tensor, uint32_t i, int32_t value) { + assert(i < LM_GGML_MAX_OP_PARAMS / sizeof(int32_t)); + ((int32_t *)(tensor->op_params))[i] = value; +} + +static void lm_ggml_set_op_params_f32(struct lm_ggml_tensor * tensor, uint32_t i, float value) { + assert(i < LM_GGML_MAX_OP_PARAMS / sizeof(float)); + ((float *)(tensor->op_params))[i] = value; +} + +struct lm_ggml_map_custom1_op_params { + lm_ggml_custom1_op_t fun; + int n_tasks; + void * userdata; +}; + +struct lm_ggml_map_custom2_op_params { + lm_ggml_custom2_op_t fun; + int n_tasks; + void * userdata; +}; + +struct lm_ggml_map_custom3_op_params { + lm_ggml_custom3_op_t fun; + int n_tasks; + void * userdata; +}; + // bitset typedef uint32_t lm_ggml_bitset_t; @@ -204,6 +301,250 @@ struct lm_ggml_cgraph lm_ggml_graph_view(struct lm_ggml_cgraph * cgraph, int i0, void * lm_ggml_aligned_malloc(size_t size); void lm_ggml_aligned_free(void * ptr, size_t size); +// FP16 to FP32 conversion + +#if defined(__ARM_NEON) + #ifdef _MSC_VER + typedef uint16_t lm_ggml_fp16_internal_t; + #else + typedef __fp16 lm_ggml_fp16_internal_t; + #endif +#endif + +#if defined(__ARM_NEON) && !defined(_MSC_VER) + #define LM_GGML_COMPUTE_FP16_TO_FP32(x) lm_ggml_compute_fp16_to_fp32(x) + #define LM_GGML_COMPUTE_FP32_TO_FP16(x) lm_ggml_compute_fp32_to_fp16(x) + + #define LM_GGML_FP16_TO_FP32(x) lm_ggml_compute_fp16_to_fp32(x) + + static inline float lm_ggml_compute_fp16_to_fp32(lm_ggml_fp16_t h) { + lm_ggml_fp16_internal_t tmp; + memcpy(&tmp, &h, sizeof(lm_ggml_fp16_t)); + return (float)tmp; + } + + static inline lm_ggml_fp16_t lm_ggml_compute_fp32_to_fp16(float f) { + lm_ggml_fp16_t res; + lm_ggml_fp16_internal_t tmp = f; + memcpy(&res, &tmp, sizeof(lm_ggml_fp16_t)); + return res; + } + +#elif defined(__F16C__) + + #ifdef _MSC_VER + #define LM_GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x))) + #define LM_GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0) + #else + #define LM_GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x) + #define LM_GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0) + #endif + +#elif defined(__POWER9_VECTOR__) + + #define LM_GGML_COMPUTE_FP16_TO_FP32(x) lm_ggml_compute_fp16_to_fp32(x) + #define LM_GGML_COMPUTE_FP32_TO_FP16(x) lm_ggml_compute_fp32_to_fp16(x) + /* the inline asm below is about 12% faster than the lookup method */ + #define LM_GGML_FP16_TO_FP32(x) LM_GGML_COMPUTE_FP16_TO_FP32(x) + #define LM_GGML_FP32_TO_FP16(x) LM_GGML_COMPUTE_FP32_TO_FP16(x) + + static inline float lm_ggml_compute_fp16_to_fp32(lm_ggml_fp16_t h) { + register float f; + register double d; + __asm__( + "mtfprd %0,%2\n" + "xscvhpdp %0,%0\n" + "frsp %1,%0\n" : + /* temp */ "=d"(d), + /* out */ "=f"(f): + /* in */ "r"(h)); + return f; + } + + static inline lm_ggml_fp16_t lm_ggml_compute_fp32_to_fp16(float f) { + register double d; + register lm_ggml_fp16_t r; + __asm__( /* xscvdphp can work on double or single precision */ + "xscvdphp %0,%2\n" + "mffprd %1,%0\n" : + /* temp */ "=d"(d), + /* out */ "=r"(r): + /* in */ "f"(f)); + return r; + } + +#else + + // FP16 <-> FP32 + // ref: https://github.com/Maratyszcza/FP16 + + static inline float fp32_from_bits(uint32_t w) { + union { + uint32_t as_bits; + float as_value; + } fp32; + fp32.as_bits = w; + return fp32.as_value; + } + + static inline uint32_t fp32_to_bits(float f) { + union { + float as_value; + uint32_t as_bits; + } fp32; + fp32.as_value = f; + return fp32.as_bits; + } + + static inline float lm_ggml_compute_fp16_to_fp32(lm_ggml_fp16_t h) { + const uint32_t w = (uint32_t) h << 16; + const uint32_t sign = w & UINT32_C(0x80000000); + const uint32_t two_w = w + w; + + const uint32_t exp_offset = UINT32_C(0xE0) << 23; + #if (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)) && (!defined(__cplusplus) || __cplusplus >= 201703L) + const float exp_scale = 0x1.0p-112f; + #else + const float exp_scale = fp32_from_bits(UINT32_C(0x7800000)); + #endif + const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; + + const uint32_t magic_mask = UINT32_C(126) << 23; + const float magic_bias = 0.5f; + const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; + + const uint32_t denormalized_cutoff = UINT32_C(1) << 27; + const uint32_t result = sign | + (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value)); + return fp32_from_bits(result); + } + + static inline lm_ggml_fp16_t lm_ggml_compute_fp32_to_fp16(float f) { + #if (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)) && (!defined(__cplusplus) || __cplusplus >= 201703L) + const float scale_to_inf = 0x1.0p+112f; + const float scale_to_zero = 0x1.0p-110f; + #else + const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000)); + const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000)); + #endif + float base = (fabsf(f) * scale_to_inf) * scale_to_zero; + + const uint32_t w = fp32_to_bits(f); + const uint32_t shl1_w = w + w; + const uint32_t sign = w & UINT32_C(0x80000000); + uint32_t bias = shl1_w & UINT32_C(0xFF000000); + if (bias < UINT32_C(0x71000000)) { + bias = UINT32_C(0x71000000); + } + + base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; + const uint32_t bits = fp32_to_bits(base); + const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); + const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); + const uint32_t nonsign = exp_bits + mantissa_bits; + return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign); + } + + #define LM_GGML_COMPUTE_FP16_TO_FP32(x) lm_ggml_compute_fp16_to_fp32(x) + #define LM_GGML_COMPUTE_FP32_TO_FP16(x) lm_ggml_compute_fp32_to_fp16(x) + +#endif // defined(__ARM_NEON) && (!defined(__MSC_VER) + +// precomputed f32 table for f16 (256 KB) +// defined in ggml.c, initialized in lm_ggml_init() +LM_GGML_API float lm_ggml_table_f32_f16[1 << 16]; + +// On ARM NEON, it's quicker to directly convert x -> x instead of calling into lm_ggml_lookup_fp16_to_fp32, +// so we define LM_GGML_FP16_TO_FP32 and LM_GGML_FP32_TO_FP16 elsewhere for NEON. +// This is also true for POWER9. +#if !defined(LM_GGML_FP16_TO_FP32) +inline static float lm_ggml_lookup_fp16_to_fp32(lm_ggml_fp16_t f) { + uint16_t s; + memcpy(&s, &f, sizeof(uint16_t)); + return lm_ggml_table_f32_f16[s]; +} + +#define LM_GGML_FP16_TO_FP32(x) lm_ggml_lookup_fp16_to_fp32(x) +#endif + +#if !defined(LM_GGML_FP32_TO_FP16) +#define LM_GGML_FP32_TO_FP16(x) LM_GGML_COMPUTE_FP32_TO_FP16(x) +#endif + +/** + * Converts brain16 to float32. + * + * The bfloat16 floating point format has the following structure: + * + * ┌sign + * │ + * │ ┌exponent + * │ │ + * │ │ ┌mantissa + * │ │ │ + * │┌──┴───┐┌─┴───┐ + * 0b0000000000000000 brain16 + * + * Since bf16 has the same number of exponent bits as a 32bit float, + * encoding and decoding numbers becomes relatively straightforward. + * + * ┌sign + * │ + * │ ┌exponent + * │ │ + * │ │ ┌mantissa + * │ │ │ + * │┌──┴───┐┌─┴───────────────────┐ + * 0b00000000000000000000000000000000 IEEE binary32 + * + * For comparison, the standard fp16 format has fewer exponent bits. + * + * ┌sign + * │ + * │ ┌exponent + * │ │ + * │ │ ┌mantissa + * │ │ │ + * │┌─┴─┐┌─┴──────┐ + * 0b0000000000000000 IEEE binary16 + * + * @see IEEE 754-2008 + */ +static inline float lm_ggml_compute_bf16_to_fp32(lm_ggml_bf16_t h) { + union { + float f; + uint32_t i; + } u; + u.i = (uint32_t)h.bits << 16; + return u.f; +} + +/** + * Converts float32 to brain16. + * + * This is binary identical with Google Brain float conversion. + * Floats shall round to nearest even, and NANs shall be quiet. + * Subnormals aren't flushed to zero, except perhaps when used. + * This code should vectorize nicely if using modern compilers. + */ +static inline lm_ggml_bf16_t lm_ggml_compute_fp32_to_bf16(float s) { + lm_ggml_bf16_t h; + union { + float f; + uint32_t i; + } u; + u.f = s; + if ((u.i & 0x7fffffff) > 0x7f800000) { /* nan */ + h.bits = (u.i >> 16) | 64; /* force to quiet */ + return h; + } + h.bits = (u.i + (0x7fff + ((u.i >> 16) & 1))) >> 16; + return h; +} + +#define LM_GGML_FP32_TO_BF16(x) lm_ggml_compute_fp32_to_bf16(x) +#define LM_GGML_BF16_TO_FP32(x) lm_ggml_compute_bf16_to_fp32(x) + #ifdef __cplusplus } #endif diff --git a/cpp/ggml-metal.h b/cpp/ggml-metal.h index 28fc8193..b12d2bdb 100644 --- a/cpp/ggml-metal.h +++ b/cpp/ggml-metal.h @@ -39,27 +39,27 @@ extern "C" { // user-code should use only these functions // -LM_GGML_API lm_ggml_backend_t lm_ggml_backend_metal_init(void); +LM_GGML_BACKEND_API lm_ggml_backend_t lm_ggml_backend_metal_init(void); -LM_GGML_API bool lm_ggml_backend_is_metal(lm_ggml_backend_t backend); +LM_GGML_BACKEND_API bool lm_ggml_backend_is_metal(lm_ggml_backend_t backend); LM_GGML_DEPRECATED( - LM_GGML_API lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size), + LM_GGML_BACKEND_API lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size), "obsoleted by the new device interface - https://github.com/ggerganov/llama.cpp/pull/9713"); -LM_GGML_API void lm_ggml_backend_metal_set_abort_callback(lm_ggml_backend_t backend, lm_ggml_abort_callback abort_callback, void * user_data); +LM_GGML_BACKEND_API void lm_ggml_backend_metal_set_abort_callback(lm_ggml_backend_t backend, lm_ggml_abort_callback abort_callback, void * user_data); -LM_GGML_API lm_ggml_backend_buffer_type_t lm_ggml_backend_metal_buffer_type(void); +LM_GGML_BACKEND_API lm_ggml_backend_buffer_type_t lm_ggml_backend_metal_buffer_type(void); // helper to check if the device supports a specific family // ideally, the user code should be doing these checks // ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf -LM_GGML_API bool lm_ggml_backend_metal_supports_family(lm_ggml_backend_t backend, int family); +LM_GGML_BACKEND_API bool lm_ggml_backend_metal_supports_family(lm_ggml_backend_t backend, int family); // capture all command buffers committed the next time `lm_ggml_backend_graph_compute` is called -LM_GGML_API void lm_ggml_backend_metal_capture_next_compute(lm_ggml_backend_t backend); +LM_GGML_BACKEND_API void lm_ggml_backend_metal_capture_next_compute(lm_ggml_backend_t backend); -LM_GGML_API lm_ggml_backend_reg_t lm_ggml_backend_metal_reg(void); +LM_GGML_BACKEND_API lm_ggml_backend_reg_t lm_ggml_backend_metal_reg(void); #ifdef __cplusplus } diff --git a/cpp/ggml-metal.m b/cpp/ggml-metal.m index d0a084e9..8385ad91 100644 --- a/cpp/ggml-metal.m +++ b/cpp/ggml-metal.m @@ -36,16 +36,20 @@ id mtl_device; int mtl_device_ref_count; - bool support_simdgroup_reduction; - bool support_simdgroup_mm; + bool has_simdgroup_reduction; + bool has_simdgroup_mm; + bool has_bfloat; + bool use_bfloat; char name[128]; } g_lm_ggml_ctx_dev_main = { - /*.mtl_device =*/ nil, - /*.mtl_device_ref_count =*/ 0, - /*.support_simdgroup_reduction =*/ false, - /*.support_simdgroup_mm =*/ false, - /*.name =*/ "", + /*.mtl_device =*/ nil, + /*.mtl_device_ref_count =*/ 0, + /*.has_simdgroup_reduction =*/ false, + /*.has_simdgroup_mm =*/ false, + /*.has_bfloat =*/ false, + /*.use_bfloat =*/ false, + /*.name =*/ "", }; // acquire @@ -55,10 +59,19 @@ if (ctx->mtl_device == nil) { ctx->mtl_device = MTLCreateSystemDefaultDevice(); - ctx->support_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7]; - ctx->support_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML]; + ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7]; + ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML]; - ctx->support_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7]; + ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7]; + + ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML]; + ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6]; + +#if defined(LM_GGML_METAL_USE_BF16) + ctx->use_bfloat = ctx->has_bfloat; +#else + ctx->use_bfloat = false; +#endif strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1); } @@ -120,6 +133,7 @@ static void lm_ggml_backend_metal_device_rel(struct lm_ggml_backend_metal_device LM_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, + LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, @@ -146,10 +160,14 @@ static void lm_ggml_backend_metal_device_rel(struct lm_ggml_backend_metal_device LM_GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, LM_GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, - LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, @@ -170,10 +188,11 @@ static void lm_ggml_backend_metal_device_rel(struct lm_ggml_backend_metal_device LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, - //LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, //LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, //LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, + //LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, @@ -195,6 +214,7 @@ static void lm_ggml_backend_metal_device_rel(struct lm_ggml_backend_metal_device LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, + LM_GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, @@ -216,6 +236,7 @@ static void lm_ggml_backend_metal_device_rel(struct lm_ggml_backend_metal_device LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, + LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, @@ -255,13 +276,64 @@ static void lm_ggml_backend_metal_device_rel(struct lm_ggml_backend_metal_device LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, - //LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261 + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, - //LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261 + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, + LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32, LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F16, + LM_GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F16, LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F32, + LM_GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, + LM_GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, @@ -440,7 +512,15 @@ @implementation LMGGMLMetalClass // dictionary of preprocessor macros NSMutableDictionary * prep = [NSMutableDictionary dictionary]; - MTLCompileOptions* options = [MTLCompileOptions new]; + if (ctx_dev->use_bfloat) { + [prep setObject:@"1" forKey:@"LM_GGML_METAL_USE_BF16"]; + } + +#if LM_GGML_METAL_EMBED_LIBRARY + [prep setObject:@"1" forKey:@"LM_GGML_METAL_EMBED_LIBRARY"]; +#endif + + MTLCompileOptions * options = [MTLCompileOptions new]; options.preprocessorMacros = prep; //[options setFastMathEnabled:false]; @@ -450,7 +530,14 @@ @implementation LMGGMLMetalClass LM_GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); return NULL; } + +#if !__has_feature(objc_arc) + [options release]; +#endif } +#if LM_GGML_METAL_EMBED_LIBRARY + [src release]; +#endif // LM_GGML_METAL_EMBED_LIBRARY } } @@ -483,9 +570,11 @@ @implementation LMGGMLMetalClass } } - LM_GGML_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx_dev->support_simdgroup_reduction ? "true" : "false"); - LM_GGML_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx_dev->support_simdgroup_mm ? "true" : "false"); - LM_GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false"); + LM_GGML_LOG_INFO("%s: simdgroup reduction = %s\n", __func__, ctx_dev->has_simdgroup_reduction ? "true" : "false"); + LM_GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, ctx_dev->has_simdgroup_mm ? "true" : "false"); + LM_GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, ctx_dev->has_bfloat ? "true" : "false"); + LM_GGML_LOG_INFO("%s: use bfloat = %s\n", __func__, ctx_dev->use_bfloat ? "true" : "false"); + LM_GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false"); ctx->capture_next_compute = false; ctx->capture_started = false; @@ -511,16 +600,14 @@ @implementation LMGGMLMetalClass ctx->kernels[i].pipeline = nil; } - /* - LM_GGML_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \ - (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \ - (int) kernel->pipeline.threadExecutionWidth); \ - */ #define LM_GGML_METAL_ADD_KERNEL(e, name, supported) \ if (supported) { \ struct lm_ggml_metal_kernel * kernel = &ctx->kernels[e]; \ id metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \ kernel->pipeline = [device newComputePipelineStateWithFunction:metal_function error:&error]; \ + LM_GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \ + (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \ + (int) kernel->pipeline.threadExecutionWidth); \ [metal_function release]; \ if (error) { \ LM_GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \ @@ -531,8 +618,9 @@ @implementation LMGGMLMetalClass LM_GGML_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \ } - const bool support_simdgroup_mm = ctx_dev->support_simdgroup_mm; - const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction; + const bool has_simdgroup_mm = ctx_dev->has_simdgroup_mm; + const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction; + const bool use_bfloat = ctx_dev->use_bfloat; // simd_sum and simd_max requires MTLGPUFamilyApple7 @@ -560,14 +648,15 @@ @implementation LMGGMLMetalClass LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SILU, silu, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, has_simdgroup_reduction); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, use_bfloat); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); @@ -588,101 +677,108 @@ @implementation LMGGMLMetalClass LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, support_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_NORM, norm, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, support_simdgroup_reduction); - //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, support_simdgroup_reduction); - //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, support_simdgroup_reduction); - //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, support_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, has_simdgroup_reduction); + //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, has_simdgroup_reduction); + //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, has_simdgroup_reduction); + //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, has_simdgroup_reduction && use_bfloat); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && use_bfloat); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, has_simdgroup_mm && use_bfloat); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, has_simdgroup_mm); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true); @@ -698,18 +794,69 @@ @implementation LMGGMLMetalClass LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, support_simdgroup_mm); - //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, support_simdgroup_mm); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, support_simdgroup_reduction); - //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, support_simdgroup_reduction); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, flash_attn_ext_bf16_h112, has_simdgroup_mm && use_bfloat); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, flash_attn_ext_bf16_h128, has_simdgroup_mm && use_bfloat); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && use_bfloat); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, flash_attn_ext_vec_bf16_h256, has_simdgroup_reduction && use_bfloat); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); - LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, use_bfloat); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, use_bfloat); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); @@ -799,15 +946,18 @@ static void lm_ggml_metal_free(struct lm_ggml_backend_metal_context * ctx) { } static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_context * ctx_dev, const struct lm_ggml_tensor * op) { - for (size_t i = 0, n = 3; i < n; ++i) { - if (op->src[i] != NULL && op->src[i]->type == LM_GGML_TYPE_BF16) { - return false; + const bool has_simdgroup_mm = ctx_dev->has_simdgroup_mm; + const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction; + const bool use_bfloat = ctx_dev->use_bfloat; + + if (!use_bfloat) { + for (size_t i = 0, n = 3; i < n; ++i) { + if (op->src[i] != NULL && op->src[i]->type == LM_GGML_TYPE_BF16) { + return false; + } } } - const bool support_simdgroup_mm = ctx_dev->support_simdgroup_mm; - const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction; - switch (op->op) { case LM_GGML_OP_UNARY: switch (lm_ggml_get_unary_op(op)) { @@ -845,7 +995,7 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_ case LM_GGML_OP_SOFT_MAX: case LM_GGML_OP_RMS_NORM: case LM_GGML_OP_GROUP_NORM: - return support_simdgroup_reduction; + return has_simdgroup_reduction; case LM_GGML_OP_NORM: case LM_GGML_OP_ROPE: return true; @@ -862,22 +1012,16 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_ case LM_GGML_OP_LEAKY_RELU: return true; case LM_GGML_OP_FLASH_ATTN_EXT: - if (op->src[1]->type != LM_GGML_TYPE_F16) { - return false; - } - if (op->src[2]->type != LM_GGML_TYPE_F16) { - return false; - } - if (op->src[0]->ne[0] == 256) { + if (op->src[1]->type != op->src[2]->type) { return false; } - return support_simdgroup_mm; // TODO: over-restricted for vec-kernels + return has_simdgroup_mm; // TODO: over-restricted for vec-kernels case LM_GGML_OP_SSM_CONV: case LM_GGML_OP_SSM_SCAN: return true; case LM_GGML_OP_MUL_MAT: case LM_GGML_OP_MUL_MAT_ID: - return support_simdgroup_reduction && + return has_simdgroup_reduction && (op->src[0]->type != LM_GGML_TYPE_F32 || op->src[1]->type == LM_GGML_TYPE_F32); case LM_GGML_OP_CPY: case LM_GGML_OP_DUP: @@ -888,6 +1032,7 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_ switch (op->type) { case LM_GGML_TYPE_F32: case LM_GGML_TYPE_F16: + case LM_GGML_TYPE_BF16: case LM_GGML_TYPE_Q8_0: case LM_GGML_TYPE_Q4_0: case LM_GGML_TYPE_Q4_1: @@ -900,10 +1045,18 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_ } case LM_GGML_TYPE_F16: switch (op->type) { - case LM_GGML_TYPE_F32: - case LM_GGML_TYPE_F16: + case LM_GGML_TYPE_F32: + case LM_GGML_TYPE_F16: return true; - default: + default: + return false; + } + case LM_GGML_TYPE_BF16: + switch (op->type) { + case LM_GGML_TYPE_F32: + case LM_GGML_TYPE_BF16: + return true; + default: return false; } default: @@ -989,7 +1142,7 @@ static void lm_ggml_metal_encode_node( const uint64_t nb20 = src2 ? src2->nb[0] : 0; LM_GGML_UNUSED(nb20); const uint64_t nb21 = src2 ? src2->nb[1] : 0; const uint64_t nb22 = src2 ? src2->nb[2] : 0; - const uint64_t nb23 = src2 ? src2->nb[3] : 0; + const uint64_t nb23 = src2 ? src2->nb[3] : 0; LM_GGML_UNUSED(nb23); const int64_t ne0 = dst ? dst->ne[0] : 0; const int64_t ne1 = dst ? dst->ne[1] : 0; @@ -1774,6 +1927,7 @@ static void lm_ggml_metal_encode_node( switch (src0->type) { case LM_GGML_TYPE_F32: LM_GGML_ASSERT(nb01 % 16 == 0); break; case LM_GGML_TYPE_F16: LM_GGML_ASSERT(nb01 % 8 == 0); break; + case LM_GGML_TYPE_BF16: LM_GGML_ASSERT(nb01 % 8 == 0); break; default: break; } @@ -1782,6 +1936,7 @@ static void lm_ggml_metal_encode_node( switch (src0->type) { case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break; case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break; + case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline; break; case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break; case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break; case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break; @@ -1859,6 +2014,25 @@ static void lm_ggml_metal_encode_node( nrows = 4; } } break; + case LM_GGML_TYPE_BF16: + { + nth0 = 32; + nth1 = 1; + if (src1t == LM_GGML_TYPE_F32) { + if (ne11 * ne12 < 4) { + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline; + } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline; + nrows = ne11; + } else { + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline; + nrows = 4; + } + } else { + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline; + nrows = 4; + } + } break; case LM_GGML_TYPE_Q4_0: { nth0 = 8; @@ -2077,12 +2251,12 @@ static void lm_ggml_metal_encode_node( if ([device supportsFamily:MTLGPUFamilyApple7] && ne00 % 32 == 0 && ne00 >= 64 && dst_rows > dst_rows_min) { - // some Metal matrix data types require aligned pointers // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) switch (src0->type) { - case LM_GGML_TYPE_F32: LM_GGML_ASSERT(nb01 % 16 == 0); break; - case LM_GGML_TYPE_F16: LM_GGML_ASSERT(nb01 % 8 == 0); break; + case LM_GGML_TYPE_F32: LM_GGML_ASSERT(nb01 % 16 == 0); break; + case LM_GGML_TYPE_F16: LM_GGML_ASSERT(nb01 % 8 == 0); break; + case LM_GGML_TYPE_BF16: LM_GGML_ASSERT(nb01 % 8 == 0); break; default: break; } @@ -2091,6 +2265,7 @@ static void lm_ggml_metal_encode_node( switch (src0->type) { case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break; case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break; + case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32 ].pipeline; break; case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break; case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break; case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break; @@ -2160,6 +2335,13 @@ static void lm_ggml_metal_encode_node( nth1 = 1; pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline; } break; + case LM_GGML_TYPE_BF16: + { + LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32); + nth0 = 32; + nth1 = 1; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline; + } break; case LM_GGML_TYPE_Q4_0: { nth0 = 8; @@ -2357,6 +2539,7 @@ static void lm_ggml_metal_encode_node( switch (src0->type) { case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break; case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break; + case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16 ].pipeline; break; case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break; case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break; case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break; @@ -2815,6 +2998,7 @@ static void lm_ggml_metal_encode_node( LM_GGML_ASSERT(ne11 % 32 == 0); LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT(src1->type == src2->type); LM_GGML_ASSERT(lm_ggml_are_same_shape (src1, src2)); @@ -2861,27 +3045,176 @@ static void lm_ggml_metal_encode_node( bool use_vec_kernel = false; + // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0) + // for now avoiding mainly to keep the number of templates/kernels a bit lower if (ne01 >= 4 || (ne00%128 != 0)) { - switch (ne00) { - case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; - case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; - //case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; + switch (src1->type) { + case LM_GGML_TYPE_F16: + { + switch (ne00) { + case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; + case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; + case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; + default: + { + LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + LM_GGML_LOG_ERROR("add template specialization for this size\n"); + LM_GGML_ABORT("add template specialization for this size"); + } + } + } break; + case LM_GGML_TYPE_BF16: + { + switch (ne00) { + case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break; + case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break; + case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break; + default: + { + LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + LM_GGML_LOG_ERROR("add template specialization for this size\n"); + LM_GGML_ABORT("add template specialization for this size"); + } + } + } break; + case LM_GGML_TYPE_Q4_0: + { + switch (ne00) { + case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline; break; + case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline; break; + case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline; break; + default: + { + LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + LM_GGML_LOG_ERROR("add template specialization for this size\n"); + LM_GGML_ABORT("add template specialization for this size"); + } + } + } break; + case LM_GGML_TYPE_Q4_1: + { + switch (ne00) { + case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline; break; + case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline; break; + case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline; break; + default: + { + LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + LM_GGML_LOG_ERROR("add template specialization for this size\n"); + LM_GGML_ABORT("add template specialization for this size"); + } + } + } break; + case LM_GGML_TYPE_Q5_0: + { + switch (ne00) { + case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline; break; + case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline; break; + case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline; break; + default: + { + LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + LM_GGML_LOG_ERROR("add template specialization for this size\n"); + LM_GGML_ABORT("add template specialization for this size"); + } + } + } break; + case LM_GGML_TYPE_Q5_1: + { + switch (ne00) { + case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline; break; + case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline; break; + case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline; break; + default: + { + LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + LM_GGML_LOG_ERROR("add template specialization for this size\n"); + LM_GGML_ABORT("add template specialization for this size"); + } + } + } break; + case LM_GGML_TYPE_Q8_0: + { + switch (ne00) { + case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break; + case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break; + case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break; + default: + { + LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + LM_GGML_LOG_ERROR("add template specialization for this size\n"); + LM_GGML_ABORT("add template specialization for this size"); + } + } + } break; default: - { - LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - LM_GGML_LOG_ERROR("add template specialization for this size\n"); - LM_GGML_ABORT("add template specialization for this size"); - } + { + LM_GGML_LOG_ERROR("unsupported type: %d\n", src1->type); + LM_GGML_LOG_ERROR("add template specialization for this type\n"); + LM_GGML_ABORT("add template specialization for this type"); + } } } else { use_vec_kernel = true; switch (ne00) { - case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; - //case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; + case 128: + { + switch (src1->type) { + case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; + case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128].pipeline; break; + case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break; + case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break; + case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break; + case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128].pipeline; break; + case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128].pipeline; break; + default: + { + LM_GGML_LOG_ERROR("unsupported type: %d\n", src1->type); + LM_GGML_LOG_ERROR("add template specialization for this type\n"); + LM_GGML_ABORT("add template specialization for this type"); + } + } + } break; + case 256: + { + switch (src1->type) { + case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; + case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256].pipeline; break; + case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break; + case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break; + case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break; + case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256].pipeline; break; + case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256].pipeline; break; + default: + { + LM_GGML_LOG_ERROR("unsupported type: %d\n", src1->type); + LM_GGML_LOG_ERROR("add template specialization for this type\n"); + LM_GGML_ABORT("add template specialization for this type"); + } + } + } break; default: { LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00); @@ -2913,18 +3246,15 @@ static void lm_ggml_metal_encode_node( [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14]; [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15]; [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19]; - [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22]; - [encoder setBytes:&scale length:sizeof( float) atIndex:23]; - [encoder setBytes:&max_bias length:sizeof( float) atIndex:24]; - [encoder setBytes:&m0 length:sizeof(m0) atIndex:25]; - [encoder setBytes:&m1 length:sizeof(m1) atIndex:26]; - [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27]; - [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:28]; + [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:18]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:19]; + [encoder setBytes:&scale length:sizeof( float) atIndex:20]; + [encoder setBytes:&max_bias length:sizeof( float) atIndex:21]; + [encoder setBytes:&m0 length:sizeof(m0) atIndex:22]; + [encoder setBytes:&m1 length:sizeof(m1) atIndex:23]; + [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:24]; + [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:25]; if (!use_vec_kernel) { // half8x8 kernel @@ -2935,10 +3265,19 @@ static void lm_ggml_metal_encode_node( LM_GGML_ASSERT(nqptg % 8 == 0); LM_GGML_ASSERT(ncpsg % 32 == 0); + // 2*(2*ncpsg + nqptg)*(nsg) + // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float) + // + // 16*32*(nsg) + // the shared memory needed for the simdgroups to load the KV cache + // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG + // +#define FATTN_SMEM(nsg) (LM_GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16)) + int64_t nsgmax = 2; while (true) { - const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2); + const size_t smem = FATTN_SMEM(nsgmax); if (smem > device.maxThreadgroupMemoryLength) { break; } @@ -2949,16 +3288,15 @@ static void lm_ggml_metal_encode_node( // simdgroups per threadgroup (a.k.a. warps) const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; - const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2); + const size_t smem = FATTN_SMEM(nsg); - //printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength); + //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg); LM_GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength); - - [encoder setThreadgroupMemoryLength:LM_GGML_PAD(smem, 16) atIndex:0]; - + [encoder setThreadgroupMemoryLength:smem atIndex:0]; +#undef FATTN_SMEM [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; } else { - // half1x4 kernel + // half4x4 kernel const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !! const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! @@ -2966,8 +3304,28 @@ static void lm_ggml_metal_encode_node( LM_GGML_ASSERT(nqptg % 1 == 0); LM_GGML_ASSERT(ncpsg % 32 == 0); + // ne00 + 2*ncpsg*(nsg) + // for each query, we load it as f16 in shared memory (ne00) + // and store the soft_max values and the mask + // + // ne00*(nsg) + // each simdgroup has a full f16 head vector in shared mem to accumulate results + // +#define FATTN_SMEM(nsg) (LM_GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + ne00*(nsg))*(sizeof(float)/2), 16)) + + int64_t nsgmax = 2; + + while (true) { + const size_t smem = FATTN_SMEM(nsgmax); + if (smem > device.maxThreadgroupMemoryLength) { + break; + } + nsgmax *= 2; + } + nsgmax /= 2; + // simdgroups per threadgroup (a.k.a. warps) - const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); + const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))); int64_t nsg = 1; while (nsg <= nsgt) { @@ -2975,12 +3333,12 @@ static void lm_ggml_metal_encode_node( } nsg /= 2; - const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2); + const size_t smem = FATTN_SMEM(nsg); - //printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength); + //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg); LM_GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:LM_GGML_PAD(smem, 16) atIndex:0]; - + [encoder setThreadgroupMemoryLength:smem atIndex:0]; +#undef FATTN_SMEM [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; } } break; @@ -3002,6 +3360,7 @@ static void lm_ggml_metal_encode_node( switch (dstt) { case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break; case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break; + case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_BF16].pipeline; break; case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break; case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break; case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break; @@ -3019,6 +3378,14 @@ static void lm_ggml_metal_encode_node( default: LM_GGML_ABORT("not implemented"); }; } break; + case LM_GGML_TYPE_BF16: + { + switch (dstt) { + case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break; + case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline; break; + default: LM_GGML_ASSERT(false && "not implemented"); + }; + } break; default: LM_GGML_ABORT("not implemented"); } @@ -3837,7 +4204,7 @@ static lm_ggml_backend_buffer_t lm_ggml_backend_metal_device_buffer_from_ptr(lm_ } } - return lm_ggml_backend_buffer_init(lm_ggml_backend_metal_buffer_type(), lm_ggml_backend_metal_buffer_i, ctx, size); + return lm_ggml_backend_buffer_init(lm_ggml_backend_metal_buffer_from_ptr_type(), lm_ggml_backend_metal_buffer_i, ctx, size); } static bool lm_ggml_backend_metal_device_supports_op(lm_ggml_backend_dev_t dev, const struct lm_ggml_tensor * op) { @@ -3847,7 +4214,8 @@ static bool lm_ggml_backend_metal_device_supports_op(lm_ggml_backend_dev_t dev, } static bool lm_ggml_backend_metal_device_supports_buft(lm_ggml_backend_dev_t dev, lm_ggml_backend_buffer_type_t buft) { - return buft->iface.get_name == lm_ggml_backend_metal_buffer_type_get_name; + return buft->iface.get_name == lm_ggml_backend_metal_buffer_type_get_name || + buft->iface.get_name == lm_ggml_backend_metal_buffer_from_ptr_type_get_name; UNUSED(dev); } diff --git a/cpp/ggml-quants.c b/cpp/ggml-quants.c index e44c3531..7b4f460f 100644 --- a/cpp/ggml-quants.c +++ b/cpp/ggml-quants.c @@ -4,7 +4,7 @@ #include "ggml-quants.h" #include "ggml-impl.h" #include "ggml-cpu-impl.h" - +#include "ggml-cpu.h" #include #include @@ -27,643 +27,6 @@ #define UNUSED LM_GGML_UNUSED -// some compilers don't provide _mm256_set_m128i, e.g. gcc 7 -#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) - -#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) -// multiply int8_t, add results pairwise twice -static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) { - // Get absolute values of x vectors - const __m128i ax = _mm_sign_epi8(x, x); - // Sign the values of the y vectors - const __m128i sy = _mm_sign_epi8(y, x); - // Perform multiplication and create 16-bit values - const __m128i dot = _mm_maddubs_epi16(ax, sy); - const __m128i ones = _mm_set1_epi16(1); - return _mm_madd_epi16(ones, dot); -} - -#if __AVX__ || __AVX2__ || __AVX512F__ -// horizontally add 8 floats -static inline float hsum_float_8(const __m256 x) { - __m128 res = _mm256_extractf128_ps(x, 1); - res = _mm_add_ps(res, _mm256_castps256_ps128(x)); - res = _mm_add_ps(res, _mm_movehl_ps(res, res)); - res = _mm_add_ss(res, _mm_movehdup_ps(res)); - return _mm_cvtss_f32(res); -} - -// horizontally add 8 int32_t -static inline int hsum_i32_8(const __m256i a) { - const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1)); - const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); - const __m128i sum64 = _mm_add_epi32(hi64, sum128); - const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); - return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); -} - -// horizontally add 4 int32_t -static inline int hsum_i32_4(const __m128i a) { - const __m128i hi64 = _mm_unpackhi_epi64(a, a); - const __m128i sum64 = _mm_add_epi32(hi64, a); - const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); - return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); -} - -#if defined(__AVX2__) || defined(__AVX512F__) -// spread 32 bits to 32 bytes { 0x00, 0xFF } -static inline __m256i bytes_from_bits_32(const uint8_t * x) { - uint32_t x32; - memcpy(&x32, x, sizeof(uint32_t)); - const __m256i shuf_mask = _mm256_set_epi64x( - 0x0303030303030303, 0x0202020202020202, - 0x0101010101010101, 0x0000000000000000); - __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask); - const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe); - bytes = _mm256_or_si256(bytes, bit_mask); - return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1)); -} - -// Unpack 32 4-bit fields into 32 bytes -// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval -static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) -{ - const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi); - const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp); - const __m256i lowMask = _mm256_set1_epi8( 0xF ); - return _mm256_and_si256(lowMask, bytes); -} - -// add int16_t pairwise and return as float vector -static inline __m256 sum_i16_pairs_float(const __m256i x) { - const __m256i ones = _mm256_set1_epi16(1); - const __m256i summed_pairs = _mm256_madd_epi16(ones, x); - return _mm256_cvtepi32_ps(summed_pairs); -} - -static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { -#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) - const __m256i zero = _mm256_setzero_si256(); - const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy); - return _mm256_cvtepi32_ps(summed_pairs); -#else - // Perform multiplication and create 16-bit values - const __m256i dot = _mm256_maddubs_epi16(ax, sy); - return sum_i16_pairs_float(dot); -#endif -} - -// multiply int8_t, add results pairwise twice and return as float vector -static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { -#if __AVXVNNIINT8__ - const __m256i zero = _mm256_setzero_si256(); - const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y); - return _mm256_cvtepi32_ps(summed_pairs); -#else - // Get absolute values of x vectors - const __m256i ax = _mm256_sign_epi8(x, x); - // Sign the values of the y vectors - const __m256i sy = _mm256_sign_epi8(y, x); - return mul_sum_us8_pairs_float(ax, sy); -#endif -} - -static inline __m128i packNibbles( __m256i bytes ) -{ - // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh -#if __AVX512F__ - const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000 - bytes = _mm256_or_si256(bytes, bytes_srli_4); // 0000_abcd_abcd_efgh - return _mm256_cvtepi16_epi8(bytes); // abcd_efgh -#else - const __m256i lowByte = _mm256_set1_epi16( 0xFF ); - __m256i high = _mm256_andnot_si256( lowByte, bytes ); - __m256i low = _mm256_and_si256( lowByte, bytes ); - high = _mm256_srli_epi16( high, 4 ); - bytes = _mm256_or_si256( low, high ); - - // Compress uint16_t lanes into bytes - __m128i r0 = _mm256_castsi256_si128( bytes ); - __m128i r1 = _mm256_extracti128_si256( bytes, 1 ); - return _mm_packus_epi16( r0, r1 ); -#endif -} -#elif defined(__AVX__) -// spread 32 bits to 32 bytes { 0x00, 0xFF } -static inline __m256i bytes_from_bits_32(const uint8_t * x) { - uint32_t x32; - memcpy(&x32, x, sizeof(uint32_t)); - const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000); - const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202); - __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl); - __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh); - const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe); - bytesl = _mm_or_si128(bytesl, bit_mask); - bytesh = _mm_or_si128(bytesh, bit_mask); - bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1)); - bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1)); - return MM256_SET_M128I(bytesh, bytesl); -} - -// Unpack 32 4-bit fields into 32 bytes -// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval -static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) -{ - // Load 16 bytes from memory - __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi); - __m128i tmph = _mm_srli_epi16(tmpl, 4); - const __m128i lowMask = _mm_set1_epi8(0xF); - tmpl = _mm_and_si128(lowMask, tmpl); - tmph = _mm_and_si128(lowMask, tmph); - return MM256_SET_M128I(tmph, tmpl); -} - -// add int16_t pairwise and return as float vector -static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) { - const __m128i ones = _mm_set1_epi16(1); - const __m128i summed_pairsl = _mm_madd_epi16(ones, xl); - const __m128i summed_pairsh = _mm_madd_epi16(ones, xh); - const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl); - return _mm256_cvtepi32_ps(summed_pairs); -} - -static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { - const __m128i axl = _mm256_castsi256_si128(ax); - const __m128i axh = _mm256_extractf128_si256(ax, 1); - const __m128i syl = _mm256_castsi256_si128(sy); - const __m128i syh = _mm256_extractf128_si256(sy, 1); - // Perform multiplication and create 16-bit values - const __m128i dotl = _mm_maddubs_epi16(axl, syl); - const __m128i doth = _mm_maddubs_epi16(axh, syh); - return sum_i16_pairs_float(doth, dotl); -} - -// multiply int8_t, add results pairwise twice and return as float vector -static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { - const __m128i xl = _mm256_castsi256_si128(x); - const __m128i xh = _mm256_extractf128_si256(x, 1); - const __m128i yl = _mm256_castsi256_si128(y); - const __m128i yh = _mm256_extractf128_si256(y, 1); - // Get absolute values of x vectors - const __m128i axl = _mm_sign_epi8(xl, xl); - const __m128i axh = _mm_sign_epi8(xh, xh); - // Sign the values of the y vectors - const __m128i syl = _mm_sign_epi8(yl, xl); - const __m128i syh = _mm_sign_epi8(yh, xh); - // Perform multiplication and create 16-bit values - const __m128i dotl = _mm_maddubs_epi16(axl, syl); - const __m128i doth = _mm_maddubs_epi16(axh, syh); - return sum_i16_pairs_float(doth, dotl); -} - -static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 ) -{ - // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh - const __m128i lowByte = _mm_set1_epi16( 0xFF ); - __m128i high = _mm_andnot_si128( lowByte, bytes1 ); - __m128i low = _mm_and_si128( lowByte, bytes1 ); - high = _mm_srli_epi16( high, 4 ); - bytes1 = _mm_or_si128( low, high ); - high = _mm_andnot_si128( lowByte, bytes2 ); - low = _mm_and_si128( lowByte, bytes2 ); - high = _mm_srli_epi16( high, 4 ); - bytes2 = _mm_or_si128( low, high ); - - return _mm_packus_epi16( bytes1, bytes2); -} - -static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) { - const __m128i ax = _mm_sign_epi8(x, x); - const __m128i sy = _mm_sign_epi8(y, x); - return _mm_maddubs_epi16(ax, sy); -} -#endif -#elif defined(__SSSE3__) -// horizontally add 4x4 floats -static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) { - __m128 res_0 =_mm_hadd_ps(a, b); - __m128 res_1 =_mm_hadd_ps(c, d); - __m128 res =_mm_hadd_ps(res_0, res_1); - res =_mm_hadd_ps(res, res); - res =_mm_hadd_ps(res, res); - - return _mm_cvtss_f32(res); -} -#endif // __AVX__ || __AVX2__ || __AVX512F__ -#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) - -#if defined(__ARM_NEON) || defined(__wasm_simd128__) || defined(__POWER9_VECTOR__) -#define B1(c,s,n) 0x ## n ## c , 0x ## n ## s -#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s) -#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s) -#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s) -#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s) -#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s) -#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s) -#define B8(c,s ) B7(c,s, c), B7(c,s, s) - -// precomputed tables for expanding 8bits to 8 bytes: -static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4 -static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4 -#endif - -#if defined(__loongarch_asx) - -#ifdef __clang__ -#define VREGS_PREFIX "$vr" -#define XREGS_PREFIX "$xr" -#else // GCC -#define VREGS_PREFIX "$f" -#define XREGS_PREFIX "$f" -#endif -#define __ALL_REGS "0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31" -// Convert __m128i to __m256i -static inline __m256i ____m256i(__m128i in) { - __m256i out = __lasx_xvldi(0); - __asm__ volatile ( - ".irp i," __ALL_REGS "\n\t" - " .ifc %[out], " XREGS_PREFIX"\\i \n\t" - " .irp j," __ALL_REGS "\n\t" - " .ifc %[in], " VREGS_PREFIX "\\j \n\t" - " xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t" - " .endif \n\t" - " .endr \n\t" - " .endif \n\t" - ".endr \n\t" - : [out] "+f" (out) : [in] "f" (in) - ); - return out; -} -// Convert two __m128i to __m256i -static inline __m256i lasx_set_q(__m128i inhi, __m128i inlo) { - __m256i out; - __asm__ volatile ( - ".irp i," __ALL_REGS "\n\t" - " .ifc %[hi], " VREGS_PREFIX "\\i \n\t" - " .irp j," __ALL_REGS "\n\t" - " .ifc %[lo], " VREGS_PREFIX "\\j \n\t" - " xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t" - " .endif \n\t" - " .endr \n\t" - " .endif \n\t" - ".endr \n\t" - ".ifnc %[out], %[hi] \n\t" - ".irp i," __ALL_REGS "\n\t" - " .ifc %[out], " XREGS_PREFIX "\\i \n\t" - " .irp j," __ALL_REGS "\n\t" - " .ifc %[hi], " VREGS_PREFIX "\\j \n\t" - " xvori.b $xr\\i, $xr\\j, 0 \n\t" - " .endif \n\t" - " .endr \n\t" - " .endif \n\t" - ".endr \n\t" - ".endif \n\t" - : [out] "=f" (out), [hi] "+f" (inhi) - : [lo] "f" (inlo) - ); - return out; -} -// Convert __m256i low part to __m128i -static inline __m128i lasx_extracti128_lo(__m256i in) { - __m128i out; - __asm__ volatile ( - ".ifnc %[out], %[in] \n\t" - ".irp i," __ALL_REGS "\n\t" - " .ifc %[out], " VREGS_PREFIX "\\i \n\t" - " .irp j," __ALL_REGS "\n\t" - " .ifc %[in], " XREGS_PREFIX "\\j \n\t" - " vori.b $vr\\i, $vr\\j, 0 \n\t" - " .endif \n\t" - " .endr \n\t" - " .endif \n\t" - ".endr \n\t" - ".endif \n\t" - : [out] "=f" (out) : [in] "f" (in) - ); - return out; -} -// Convert __m256i high part to __m128i -static inline __m128i lasx_extracti128_hi(__m256i in) { - __m128i out; - __asm__ volatile ( - ".irp i," __ALL_REGS "\n\t" - " .ifc %[out], " VREGS_PREFIX "\\i \n\t" - " .irp j," __ALL_REGS "\n\t" - " .ifc %[in], " XREGS_PREFIX "\\j \n\t" - " xvpermi.q $xr\\i, $xr\\j, 0x11 \n\t" - " .endif \n\t" - " .endr \n\t" - " .endif \n\t" - ".endr \n\t" - : [out] "=f" (out) : [in] "f" (in) - ); - return out; -} - -static __m256i lasx_set_w(int e7, int e6, int e5, int e4, int e3, int e2, int e1, int e0) { - v8i32 __ret = {e0, e1, e2, e3, e4, e5, e6, e7}; - return (__m256i)__ret; -} - -static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) { - v4i32 __ret = {d, c, b, a}; - return (__m128i)__ret; -} - -static __m256i lasx_set_d(int64_t a, int64_t b, int64_t c, int64_t d) { - v4i64 __ret = {d, c, b, a}; - return (__m256i)__ret; -} - -static __m256i lasx_insertf128( __m128i x, __m128i y) { - return lasx_set_q(x, y); -} - -static __m128i lsx_shuffle_b(__m128i a, __m128i b) { - __m128i mask_f, zero, tmp0, tmp2, mask; - int f = 0x8f; - mask_f = __lsx_vreplgr2vr_b(f); - zero = __lsx_vldi(0); - tmp0 = __lsx_vand_v(b, mask_f); // get mask with low 4 bit and sign bits - tmp0 = __lsx_vori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive - mask = __lsx_vsle_b(zero, tmp0); // if mask >= 0, set mask - tmp2 = __lsx_vand_v(tmp0, mask); // maskout the in2 < ones - return __lsx_vshuf_b(a, zero, tmp2); -} - -static __m256i lasx_shuffle_b(__m256i a, __m256i b) { - __m256i mask_f, zero, tmp0, tmp2, mask; - int f = 0x8f; - mask_f = __lasx_xvreplgr2vr_b(f); - zero = __lasx_xvldi(0); - tmp0 = __lasx_xvand_v(b, mask_f); // get mask with low 4 bit and sign bits - tmp0 = __lasx_xvori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive - mask = __lasx_xvsle_b(zero, tmp0); // if mask >= 0, set mask - tmp2 = __lasx_xvand_v(tmp0, mask); // maskout the in2 < ones - return __lasx_xvshuf_b(a, zero, tmp2); -} - -static __m256i lasx_extu8_16(__m128i a) { - __m128i zero = __lsx_vldi(0); - __m128i vlo = __lsx_vilvl_b(zero, a); - __m128i vhi = __lsx_vilvh_b(zero, a); - return lasx_set_q(vhi, vlo); -} - -static __m256i lasx_ext8_16(__m128i a) { - __m128i sign = __lsx_vslti_b(a, 0); - __m128i vlo = __lsx_vilvl_b(sign, a); - __m128i vhi = __lsx_vilvh_b(sign, a); - return lasx_set_q(vhi, vlo); -} - -static __m256i lasx_ext16_32(__m128i a) { - __m256i tmp1; - tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 0), 0); - tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 1), 1); - tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 2), 2); - tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 3), 3); - tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 4), 4); - tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 5), 5); - tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 6), 6); - tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 7), 7); - return tmp1; -} - -static __m128i lasx_extracti128( __m256i a, int pos) { - __m128i ret; - if( pos == 0) - { - ret = lasx_extracti128_lo(a); - } else { - ret = lasx_extracti128_hi(a); - } - return ret; -} - -static __m128 lasx_extractf128( __m256 a, int pos) { - __m128 ret; - if( pos == 0) - { - ret = (__m128)lasx_extracti128_lo((__m256i)a); - } else { - ret = (__m128)lasx_extracti128_hi((__m256i)a); - } - return ret; -} - -static __m128i lsx_hadd_h(__m128i a, __m128i b) { - __m128i tmp1 = __lsx_vpickev_h(b, a); - __m128i tmp2 = __lsx_vpickod_h(b, a); - return __lsx_vadd_h(tmp1, tmp2); -} - -static __m128i lsx_hadd_w(__m128i a, __m128i b) { - __m128i tmp1 = __lsx_vpickev_w(b, a); - __m128i tmp2 = __lsx_vpickod_w(b, a); - return __lsx_vadd_w(tmp1, tmp2); -} - -static __m128 lsx_hadd_s(__m128 a, __m128 b) { - __m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a); - __m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a); - - return __lsx_vfadd_s(tmp1, tmp2); -} - -static __m256i lasx_maddubs_h(__m256i a, __m256i b) { - __m256i tmp1, tmp2; - tmp1 = __lasx_xvmulwev_h_b(a, b); - tmp2 = __lasx_xvmulwod_h_b(a, b); - return __lasx_xvsadd_h(tmp1, tmp2); -} - -static __m256i lasx_madd_h(__m256i a, __m256i b) { - __m256i tmp1, tmp2; - tmp1 = __lasx_xvmulwev_w_h(a, b); - tmp2 = __lasx_xvmulwod_w_h(a, b); - return __lasx_xvadd_w(tmp1, tmp2); -} - -static __m256i lasx_packs_w(__m256i a, __m256i b) { - __m256i tmp, tmp1; - tmp = __lasx_xvsat_w(a, 15); - tmp1 = __lasx_xvsat_w(b, 15); - return __lasx_xvpickev_h(tmp1, tmp); -} - -static __m256i lasx_packs_h(__m256i a, __m256i b) { - __m256i tmp, tmp1; - tmp = __lasx_xvsat_h(a, 7); - tmp1 = __lasx_xvsat_h(b, 7); - return __lasx_xvpickev_b(tmp1, tmp); -} - -static __m128i lsx_packs_w(__m128i a, __m128i b) { - __m128i tmp, tmp1; - tmp = __lsx_vsat_w(a, 15); - tmp1 = __lsx_vsat_w(b, 15); - return __lsx_vpickev_h(tmp1, tmp); -} - -static __m128i lsx_packs_h(__m128i a, __m128i b) { - __m128i tmp, tmp1; - tmp = __lsx_vsat_h(a, 7); - tmp1 = __lsx_vsat_h(b, 7); - return __lsx_vpickev_b(tmp1, tmp); -} - -static __m128i lsx_packus_h(__m128i a, __m128i b) { - __m128i tmp, tmp1; - tmp = __lsx_vsat_hu(a, 7); - tmp1 = __lsx_vsat_hu(b, 7); - return __lsx_vpickev_b(tmp1, tmp); -} - - -static __m128i lsx_maddubs_h(__m128i a, __m128i b) { - __m128i tmp1, tmp2; - tmp1 = __lsx_vmulwev_h_b(a, b); - tmp2 = __lsx_vmulwod_h_b(a, b); - return __lsx_vsadd_h(tmp1, tmp2); -} - -static __m128i lsx_madd_h(__m128i a, __m128i b) { - __m128i tmp1, tmp2; - tmp1 = __lsx_vmulwev_w_h(a, b); - tmp2 = __lsx_vmulwod_w_h(a, b); - return __lsx_vadd_w(tmp1, tmp2); -} - -// multiply int8_t, add results pairwise twice -static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) { - // Get absolute values of x vectors - const __m128i ax = __lsx_vsigncov_b(x, x); - // Sign the values of the y vectors - const __m128i sy = __lsx_vsigncov_b(x, y); - // Perform multiplication and create 16-bit values - const __m128i dot = lsx_maddubs_h(ax, sy); - const __m128i ones = __lsx_vreplgr2vr_h(1); - return lsx_madd_h(ones, dot); -} - -// horizontally add 8 floats -static inline float hsum_float_8(const __m256 x) { - __m128 res = lasx_extractf128(x, 1); - ft_union tmp; - res = __lsx_vfadd_s(res, lasx_extractf128(x, 0)); - res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res)); - res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0)); - tmp.i = __lsx_vpickve2gr_w(res, 0); - return tmp.f; -} - -// horizontally add 8 int32_t -static inline int hsum_i32_8(const __m256i a) { - - __m256i tmp1 = __lasx_xvpermi_q(a, a, 0x11); - __m256i tmp2 = __lasx_xvpermi_q(a, a, 0x00); - - __m128i tmp1_128 = lasx_extracti128_lo(tmp1); - __m128i tmp2_128 = lasx_extracti128_lo(tmp2); - - __m128i sum128 = __lsx_vadd_w(tmp1_128, tmp2_128); - - __m128i ev = __lsx_vpickev_w(sum128, sum128); - __m128i od = __lsx_vpickod_w(sum128, sum128); - __m128i sum64 = __lsx_vadd_w(ev, od); - - int sum64_1, sum64_2; - sum64_1 = __lsx_vpickve2gr_w(sum64, 0); - sum64_2 = __lsx_vpickve2gr_w(sum64, 1); - - return sum64_1 + sum64_2; -} - -// horizontally add 4 int32_t -static inline int hsum_i32_4(const __m128i a) { - __m128i ev = __lsx_vpickev_w(a, a); - __m128i od = __lsx_vpickod_w(a, a); - __m128i sum64 = __lsx_vadd_w(ev, od); - - int sum64_1, sum64_2; - sum64_1 = __lsx_vpickve2gr_w(sum64, 0); - sum64_2 = __lsx_vpickve2gr_w(sum64, 1); - - return sum64_1 + sum64_2; -} - -// spread 32 bits to 32 bytes { 0x00, 0xFF } -static inline __m256i bytes_from_bits_32(const uint8_t * x) { - - uint32_t x32; - memcpy(&x32, x, sizeof(uint32_t)); - const __m256i shuf_mask = lasx_set_d( - 0x0303030303030303, 0x0202020202020202, - 0x0101010101010101, 0x0000000000000000); - - __m256i bytes = lasx_shuffle_b(__lasx_xvreplgr2vr_w(x32), shuf_mask); - const __m256i bit_mask = __lasx_xvreplgr2vr_d(0x7fbfdfeff7fbfdfe); - bytes = __lasx_xvor_v(bytes, bit_mask); - return __lasx_xvseq_b(bytes, __lasx_xvreplgr2vr_d(-1)); -} - -// Unpack 32 4-bit fields into 32 bytes -// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval -static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) { - const __m128i lo = __lsx_vld((const __m128i *)rsi, 0); - __m128i hi = __lsx_vsrli_h(lo, 4); - return __lasx_xvandi_b(lasx_insertf128(hi, lo), 0xf); -} - -// add int16_t pairwise and return as float vector -static inline __m256 sum_i16_pairs_float(const __m256i x) { - __m256i v = __lasx_xvpackod_h(x, x); - __m256i summed_pairs = __lasx_xvaddwev_w_h(x, v); - return __lasx_xvffint_s_w(summed_pairs); -} - -static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { - // Perform multiplication and create 16-bit values - const __m256i dot = lasx_maddubs_h(ax, sy); - return sum_i16_pairs_float(dot); -} - -// multiply int8_t, add results pairwise twice and return as float vector -static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { - - // Get absolute values of x vectors - const __m256i ax = __lasx_xvsigncov_b(x, x); - // Sign the values of the y vectors - const __m256i sy = __lasx_xvsigncov_b(x, y); - - return mul_sum_us8_pairs_float(ax, sy); -} - -static inline __m128i packNibbles( __m256i bytes ) { - // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh - const __m256i lowByte = __lasx_xvreplgr2vr_h(0xFF); - __m256i high = __lasx_xvandn_v(lowByte, bytes); - __m256i low = __lasx_xvand_v(lowByte, bytes); - high = __lasx_xvsrli_h(high, 4); - bytes = __lasx_xvor_v(low, high); - // Compress uint16_t lanes into bytes - __m128i *r0 = (__m128i *)&bytes; - __m256i tmp_h128 = __lasx_xvpermi_q(bytes, bytes, 0x11); - __m128i *r1 = (__m128i *)&tmp_h128; - - __m128i zero = __lsx_vldi(0); - __m128i tmp, tmp2, tmp3; - - tmp = __lsx_vmax_h(zero, *r0); - tmp2 = __lsx_vsat_hu(tmp, 7); - - tmp = __lsx_vmax_h(zero, *r1); - tmp3 = __lsx_vsat_hu(tmp, 7); - return __lsx_vpickev_b(tmp3, tmp2); -} -#endif //__loongarch_asx - // reference implementation for deterministic creation of model files void quantize_row_q4_0_ref(const float * restrict x, block_q4_0 * restrict y, int64_t k) { static const int qk = QK4_0; @@ -702,11 +65,6 @@ void quantize_row_q4_0_ref(const float * restrict x, block_q4_0 * restrict y, in } } -void quantize_row_q4_0(const float * restrict x, void * restrict y, int64_t k) { - quantize_row_q4_0_ref(x, y, k); -} - - void quantize_row_q4_1_ref(const float * restrict x, block_q4_1 * restrict y, int64_t k) { const int qk = QK4_1; @@ -744,10 +102,6 @@ void quantize_row_q4_1_ref(const float * restrict x, block_q4_1 * restrict y, in } } -void quantize_row_q4_1(const float * restrict x, void * restrict y, int64_t k) { - quantize_row_q4_1_ref(x, y, k); -} - void quantize_row_q5_0_ref(const float * restrict x, block_q5_0 * restrict y, int64_t k) { static const int qk = QK5_0; @@ -792,10 +146,6 @@ void quantize_row_q5_0_ref(const float * restrict x, block_q5_0 * restrict y, in } } -void quantize_row_q5_0(const float * restrict x, void * restrict y, int64_t k) { - quantize_row_q5_0_ref(x, y, k); -} - void quantize_row_q5_1_ref(const float * restrict x, block_q5_1 * restrict y, int64_t k) { const int qk = QK5_1; @@ -840,10 +190,6 @@ void quantize_row_q5_1_ref(const float * restrict x, block_q5_1 * restrict y, in } } -void quantize_row_q5_1(const float * restrict x, void * restrict y, int64_t k) { - quantize_row_q5_1_ref(x, y, k); -} - // reference implementation for deterministic creation of model files void quantize_row_q8_0_ref(const float * restrict x, block_q8_0 * restrict y, int64_t k) { assert(k % QK8_0 == 0); @@ -870,291 +216,6 @@ void quantize_row_q8_0_ref(const float * restrict x, block_q8_0 * restrict y, in } } -void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) { - assert(QK8_0 == 32); - assert(k % QK8_0 == 0); - const int nb = k / QK8_0; - - block_q8_0 * restrict y = vy; - -#if defined(__ARM_NEON) - for (int i = 0; i < nb; i++) { - float32x4_t srcv [8]; - float32x4_t asrcv[8]; - float32x4_t amaxv[8]; - - for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j); - for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]); - - for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]); - for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]); - for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]); - - const float amax = vmaxvq_f32(amaxv[0]); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = LM_GGML_FP32_TO_FP16(d); - - for (int j = 0; j < 8; j++) { - const float32x4_t v = vmulq_n_f32(srcv[j], id); - const int32x4_t vi = vcvtnq_s32_f32(v); - - y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); - y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); - y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); - y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); - } - } -#elif defined(__wasm_simd128__) - for (int i = 0; i < nb; i++) { - v128_t srcv [8]; - v128_t asrcv[8]; - v128_t amaxv[8]; - - for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j); - for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]); - - for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]); - for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]); - for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]); - - const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0), - wasm_f32x4_extract_lane(amaxv[0], 1)), - MAX(wasm_f32x4_extract_lane(amaxv[0], 2), - wasm_f32x4_extract_lane(amaxv[0], 3))); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = LM_GGML_FP32_TO_FP16(d); - - for (int j = 0; j < 8; j++) { - const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id)); - const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v); - - y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0); - y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1); - y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2); - y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3); - } - } -#elif defined(__AVX2__) || defined(__AVX__) - for (int i = 0; i < nb; i++) { - // Load elements into 4 AVX vectors - __m256 v0 = _mm256_loadu_ps( x ); - __m256 v1 = _mm256_loadu_ps( x + 8 ); - __m256 v2 = _mm256_loadu_ps( x + 16 ); - __m256 v3 = _mm256_loadu_ps( x + 24 ); - x += 32; - - // Compute max(abs(e)) for the block - const __m256 signBit = _mm256_set1_ps( -0.0f ); - __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); - - __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); - max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); - max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); - const float maxScalar = _mm_cvtss_f32( max4 ); - - // Quantize these floats - const float d = maxScalar / 127.f; - y[i].d = LM_GGML_FP32_TO_FP16(d); - const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; - const __m256 mul = _mm256_set1_ps( id ); - - // Apply the multiplier - v0 = _mm256_mul_ps( v0, mul ); - v1 = _mm256_mul_ps( v1, mul ); - v2 = _mm256_mul_ps( v2, mul ); - v3 = _mm256_mul_ps( v3, mul ); - - // Round to nearest integer - v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); - v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); - v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); - v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); - - // Convert floats to integers - __m256i i0 = _mm256_cvtps_epi32( v0 ); - __m256i i1 = _mm256_cvtps_epi32( v1 ); - __m256i i2 = _mm256_cvtps_epi32( v2 ); - __m256i i3 = _mm256_cvtps_epi32( v3 ); - -#if defined(__AVX2__) - // Convert int32 to int16 - i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 - i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 - // Convert int16 to int8 - i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 - - // We got our precious signed bytes, but the order is now wrong - // These AVX2 pack instructions process 16-byte pieces independently - // The following instruction is fixing the order - const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); - i0 = _mm256_permutevar8x32_epi32( i0, perm ); - - _mm256_storeu_si256((__m256i *)y[i].qs, i0); -#else - // Since we don't have in AVX some necessary functions, - // we split the registers in half and call AVX2 analogs from SSE - __m128i ni0 = _mm256_castsi256_si128( i0 ); - __m128i ni1 = _mm256_extractf128_si256( i0, 1); - __m128i ni2 = _mm256_castsi256_si128( i1 ); - __m128i ni3 = _mm256_extractf128_si256( i1, 1); - __m128i ni4 = _mm256_castsi256_si128( i2 ); - __m128i ni5 = _mm256_extractf128_si256( i2, 1); - __m128i ni6 = _mm256_castsi256_si128( i3 ); - __m128i ni7 = _mm256_extractf128_si256( i3, 1); - - // Convert int32 to int16 - ni0 = _mm_packs_epi32( ni0, ni1 ); - ni2 = _mm_packs_epi32( ni2, ni3 ); - ni4 = _mm_packs_epi32( ni4, ni5 ); - ni6 = _mm_packs_epi32( ni6, ni7 ); - // Convert int16 to int8 - ni0 = _mm_packs_epi16( ni0, ni2 ); - ni4 = _mm_packs_epi16( ni4, ni6 ); - - _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0); - _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4); -#endif - } -#elif defined(__riscv_v_intrinsic) - - size_t vl = __riscv_vsetvl_e32m4(QK8_0); - - for (int i = 0; i < nb; i++) { - // load elements - vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_0, vl); - - vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl); - vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); - vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl); - float amax = __riscv_vfmv_f_s_f32m1_f32(vmax); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = LM_GGML_FP32_TO_FP16(d); - - vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl); - - // convert to integer - vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl); - vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl); - - // store result - __riscv_vse8_v_i8m1(y[i].qs , vs, vl); - } - -#elif defined(__POWER9_VECTOR__) - for (int i = 0; i < nb; i++) { - vector float srcv [8]; - vector float asrcv[8]; - vector float amaxv[8]; - vector signed int vi[8]; - - for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j); - for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]); - - for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]); - for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]); - for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]); - - const float amax = MAX(MAX(vec_extract(amaxv[0], 0), - vec_extract(amaxv[0], 1)), - MAX(vec_extract(amaxv[0], 2), - vec_extract(amaxv[0], 3))); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - const vector float vid = vec_splats(id); - - y[i].d = LM_GGML_FP32_TO_FP16(d); - - for (int j = 0; j < 8; j++) { - const vector float v = vec_round(vec_mul(srcv[j], vid)); - vi[j] = vec_cts(v, 0); - } - vec_xst(vec_pack(vec_pack(vi[0], vi[1]), vec_pack(vi[2], vi[3])), 0, &y[i].qs[0]); - vec_xst(vec_pack(vec_pack(vi[4], vi[5]), vec_pack(vi[6], vi[7])), 16, &y[i].qs[0]); - } - -#elif defined(__loongarch_asx) - for (int i = 0; i < nb; i++) { - ft_union fi; - __m256 v0 = (__m256)__lasx_xvld( x , 0); - __m256 v1 = (__m256)__lasx_xvld( x , 32); - __m256 v2 = (__m256)__lasx_xvld( x , 64); - __m256 v3 = (__m256)__lasx_xvld( x , 96); - x += 32; - - // Compute max(abs(e)) for the block - const __m256 sign_bit = __lasx_xvreplfr2vr_s( -0.0f ); - __m256 max_abs = (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v0 ); - max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v1 ) ); - max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v2 ) ); - max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v3 ) ); - - __m128 max4 = __lsx_vfmax_s( lasx_extractf128( max_abs, 1 ), lasx_extractf128( max_abs , 0) ); - max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) ); - __m128 tmp = max4; - max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vinsgr2vr_w(tmp, __lsx_vpickve2gr_w( max4, 1 ), 0 )); - fi.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 ); - const float max_scalar = fi.f; - - // Quantize these floats - const float d = max_scalar / 127.f; - y[i].d = LM_GGML_FP32_TO_FP16(d); - const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f; - const __m256 mul = (__m256)__lasx_xvreplfr2vr_s( id ); - - // Apply the multiplier - v0 = __lasx_xvfmul_s( v0, mul ); - v1 = __lasx_xvfmul_s( v1, mul ); - v2 = __lasx_xvfmul_s( v2, mul ); - v3 = __lasx_xvfmul_s( v3, mul ); - - // Round to nearest integer - __m256i i0 = __lasx_xvftintrne_w_s( v0 ); - __m256i i1 = __lasx_xvftintrne_w_s( v1 ); - __m256i i2 = __lasx_xvftintrne_w_s( v2 ); - __m256i i3 = __lasx_xvftintrne_w_s( v3 ); - - __m128i ni0 = lasx_extracti128( i0, 0 ); - __m128i ni1 = lasx_extracti128( i0, 1); - __m128i ni2 = lasx_extracti128( i1, 0); - __m128i ni3 = lasx_extracti128( i1, 1); - __m128i ni4 = lasx_extracti128( i2, 0); - __m128i ni5 = lasx_extracti128( i2, 1); - __m128i ni6 = lasx_extracti128( i3, 0); - __m128i ni7 = lasx_extracti128( i3, 1); - - // Convert int32 to int16 - ni0 = lsx_packs_w( ni0, ni1 ); - ni2 = lsx_packs_w( ni2, ni3 ); - ni4 = lsx_packs_w( ni4, ni5 ); - ni6 = lsx_packs_w( ni6, ni7 ); - // Convert int16 to int8 - ni0 = lsx_packs_h( ni0, ni2 ); - ni4 = lsx_packs_h( ni4, ni6 ); - - __lsx_vst(ni0, (__m128i *)(y[i].qs + 0), 0); - __lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0); - - } -#else - LM_GGML_UNUSED(nb); - // scalar - quantize_row_q8_0_ref(x, y, k); -#endif -} - // reference implementation for deterministic creation of model files void quantize_row_q8_1_ref(const float * restrict x, block_q8_1 * restrict y, int64_t k) { assert(QK8_1 == 32); @@ -1191,381 +252,53 @@ void quantize_row_q8_1_ref(const float * restrict x, block_q8_1 * restrict y, in } } -void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) { - assert(k % QK8_1 == 0); - const int nb = k / QK8_1; +void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int64_t k) { + static const int qk = QK4_0; + + assert(k % qk == 0); - block_q8_1 * restrict y = vy; + const int nb = k / qk; -#if defined(__ARM_NEON) for (int i = 0; i < nb; i++) { - float32x4_t srcv [8]; - float32x4_t asrcv[8]; - float32x4_t amaxv[8]; - - for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j); - for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]); + const float d = LM_GGML_FP16_TO_FP32(x[i].d); - for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]); - for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]); - for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]); + for (int j = 0; j < qk/2; ++j) { + const int x0 = (x[i].qs[j] & 0x0F) - 8; + const int x1 = (x[i].qs[j] >> 4) - 8; - const float amax = vmaxvq_f32(amaxv[0]); + y[i*qk + j + 0 ] = x0*d; + y[i*qk + j + qk/2] = x1*d; + } + } +} - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; +void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int64_t k) { + static const int qk = QK4_1; - y[i].d = LM_GGML_FP32_TO_FP16(d); + assert(k % qk == 0); - int32x4_t accv = vdupq_n_s32(0); + const int nb = k / qk; - for (int j = 0; j < 8; j++) { - const float32x4_t v = vmulq_n_f32(srcv[j], id); - const int32x4_t vi = vcvtnq_s32_f32(v); + for (int i = 0; i < nb; i++) { + const float d = LM_GGML_FP16_TO_FP32(x[i].d); + const float m = LM_GGML_FP16_TO_FP32(x[i].m); - y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); - y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); - y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); - y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); + for (int j = 0; j < qk/2; ++j) { + const int x0 = (x[i].qs[j] & 0x0F); + const int x1 = (x[i].qs[j] >> 4); - accv = vaddq_s32(accv, vi); + y[i*qk + j + 0 ] = x0*d + m; + y[i*qk + j + qk/2] = x1*d + m; } - - y[i].s = LM_GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); } -#elif defined(__wasm_simd128__) - for (int i = 0; i < nb; i++) { - v128_t srcv [8]; - v128_t asrcv[8]; - v128_t amaxv[8]; +} - for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j); - for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]); +void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int64_t k) { + static const int qk = QK5_0; - for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]); - for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]); - for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]); + assert(k % qk == 0); - const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0), - wasm_f32x4_extract_lane(amaxv[0], 1)), - MAX(wasm_f32x4_extract_lane(amaxv[0], 2), - wasm_f32x4_extract_lane(amaxv[0], 3))); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = LM_GGML_FP32_TO_FP16(d); - - v128_t accv = wasm_i32x4_splat(0); - - for (int j = 0; j < 8; j++) { - const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id)); - const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v); - - y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0); - y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1); - y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2); - y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3); - - accv = wasm_i32x4_add(accv, vi); - } - - y[i].s = LM_GGML_FP32_TO_FP16( - d * (wasm_i32x4_extract_lane(accv, 0) + - wasm_i32x4_extract_lane(accv, 1) + - wasm_i32x4_extract_lane(accv, 2) + - wasm_i32x4_extract_lane(accv, 3))); - } -#elif defined(__AVX2__) || defined(__AVX__) - for (int i = 0; i < nb; i++) { - // Load elements into 4 AVX vectors - __m256 v0 = _mm256_loadu_ps( x ); - __m256 v1 = _mm256_loadu_ps( x + 8 ); - __m256 v2 = _mm256_loadu_ps( x + 16 ); - __m256 v3 = _mm256_loadu_ps( x + 24 ); - x += 32; - - // Compute max(abs(e)) for the block - const __m256 signBit = _mm256_set1_ps( -0.0f ); - __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); - - __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); - max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); - max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); - const float max_scalar = _mm_cvtss_f32( max4 ); - - // Quantize these floats - const float d = max_scalar / 127.f; - y[i].d = LM_GGML_FP32_TO_FP16(d); - const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f; - const __m256 mul = _mm256_set1_ps( id ); - - // Apply the multiplier - v0 = _mm256_mul_ps( v0, mul ); - v1 = _mm256_mul_ps( v1, mul ); - v2 = _mm256_mul_ps( v2, mul ); - v3 = _mm256_mul_ps( v3, mul ); - - // Round to nearest integer - v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); - v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); - v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); - v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); - - // Convert floats to integers - __m256i i0 = _mm256_cvtps_epi32( v0 ); - __m256i i1 = _mm256_cvtps_epi32( v1 ); - __m256i i2 = _mm256_cvtps_epi32( v2 ); - __m256i i3 = _mm256_cvtps_epi32( v3 ); - -#if defined(__AVX2__) - // Compute the sum of the quants and set y[i].s - y[i].s = LM_GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)))); - - // Convert int32 to int16 - i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 - i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 - // Convert int16 to int8 - i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 - - // We got our precious signed bytes, but the order is now wrong - // These AVX2 pack instructions process 16-byte pieces independently - // The following instruction is fixing the order - const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); - i0 = _mm256_permutevar8x32_epi32( i0, perm ); - - _mm256_storeu_si256((__m256i *)y[i].qs, i0); -#else - // Since we don't have in AVX some necessary functions, - // we split the registers in half and call AVX2 analogs from SSE - __m128i ni0 = _mm256_castsi256_si128( i0 ); - __m128i ni1 = _mm256_extractf128_si256( i0, 1); - __m128i ni2 = _mm256_castsi256_si128( i1 ); - __m128i ni3 = _mm256_extractf128_si256( i1, 1); - __m128i ni4 = _mm256_castsi256_si128( i2 ); - __m128i ni5 = _mm256_extractf128_si256( i2, 1); - __m128i ni6 = _mm256_castsi256_si128( i3 ); - __m128i ni7 = _mm256_extractf128_si256( i3, 1); - - // Compute the sum of the quants and set y[i].s - const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3)); - const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7)); - y[i].s = LM_GGML_FP32_TO_FP16(d * hsum_i32_4(_mm_add_epi32(s0, s1))); - - // Convert int32 to int16 - ni0 = _mm_packs_epi32( ni0, ni1 ); - ni2 = _mm_packs_epi32( ni2, ni3 ); - ni4 = _mm_packs_epi32( ni4, ni5 ); - ni6 = _mm_packs_epi32( ni6, ni7 ); - // Convert int16 to int8 - ni0 = _mm_packs_epi16( ni0, ni2 ); - ni4 = _mm_packs_epi16( ni4, ni6 ); - - _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0); - _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4); -#endif - } -#elif defined(__riscv_v_intrinsic) - - size_t vl = __riscv_vsetvl_e32m4(QK8_1); - - for (int i = 0; i < nb; i++) { - // load elements - vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_1, vl); - - vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl); - vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0, vl); - vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl); - float amax = __riscv_vfmv_f_s_f32m1_f32(vmax); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = LM_GGML_FP32_TO_FP16(d); - - vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl); - - // convert to integer - vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl); - vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl); - - // store result - __riscv_vse8_v_i8m1(y[i].qs , vs, vl); - - // compute sum for y[i].s - vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl); - vint16m1_t vwrs = __riscv_vwredsum_vs_i8m1_i16m1(vs, tmp2, vl); - - // set y[i].s - int sum = __riscv_vmv_x_s_i16m1_i16(vwrs); - y[i].s = LM_GGML_FP32_TO_FP16(sum*d); - } - -#elif defined(__POWER9_VECTOR__) - for (int i = 0; i < nb; i++) { - vector float srcv [8]; - vector float asrcv[8]; - vector float amaxv[8]; - vector signed int vi[8]; - - for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j); - for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]); - - for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]); - for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]); - for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]); - - const float amax = MAX(MAX(vec_extract(amaxv[0], 0), - vec_extract(amaxv[0], 1)), - MAX(vec_extract(amaxv[0], 2), - vec_extract(amaxv[0], 3))); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - const vector float vid = vec_splats(id); - - y[i].d = LM_GGML_FP32_TO_FP16(d); - - vector int accv = vec_splats(0); - - for (int j = 0; j < 8; j++) { - const vector float v = vec_round(vec_mul(srcv[j], vid)); - vi[j] = vec_cts(v, 0); - - accv = vec_add(accv, vi[j]); - } - vec_xst(vec_pack(vec_pack(vi[0], vi[1]), vec_pack(vi[2], vi[3])), 0, &y[i].qs[0]); - vec_xst(vec_pack(vec_pack(vi[4], vi[5]), vec_pack(vi[6], vi[7])), 16, &y[i].qs[0]); - - accv = vec_add(accv, vec_sld(accv, accv, 4)); - accv = vec_add(accv, vec_sld(accv, accv, 8)); - y[i].s = LM_GGML_FP32_TO_FP16(d * vec_extract(accv, 0)); - } - -#elif defined(__loongarch_asx) - for (int i = 0; i < nb; i++) { - ft_union ft; - __m256 v0 = (__m256)__lasx_xvld( x , 0 ); - __m256 v1 = (__m256)__lasx_xvld( x , 32 ); - __m256 v2 = (__m256)__lasx_xvld( x , 64 ); - __m256 v3 = (__m256)__lasx_xvld( x , 96 ); - x += 32; - - // Compute max(abs(e)) for the block - const __m256 sign_bit = __lasx_xvreplfr2vr_s( -0.0f ); - __m256 max_abs = (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v0 ); - max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v1 ) ); - max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v2 ) ); - max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v3 ) ); - - __m128 max4 = __lsx_vfmax_s( lasx_extractf128( max_abs, 1 ), lasx_extractf128( max_abs, 0) ); - max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) ); - __m128 tmp = max4; - max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x10 )); - ft.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 ); - const float max_scalar = ft.f; - - // Quantize these floats - const float d = max_scalar / 127.f; - y[i].d = LM_GGML_FP32_TO_FP16(d); - const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f; - const __m256 mul = __lasx_xvreplfr2vr_s( id ); - - // Apply the multiplier - v0 = __lasx_xvfmul_s( v0, mul ); - v1 = __lasx_xvfmul_s( v1, mul ); - v2 = __lasx_xvfmul_s( v2, mul ); - v3 = __lasx_xvfmul_s( v3, mul ); - - // Round to nearest integer - __m256i i0 = __lasx_xvftintrne_w_s( v0 ); - __m256i i1 = __lasx_xvftintrne_w_s( v1 ); - __m256i i2 = __lasx_xvftintrne_w_s( v2 ); - __m256i i3 = __lasx_xvftintrne_w_s( v3 ); - - __m128i ni0 = lasx_extracti128(i0, 0); - __m128i ni1 = lasx_extracti128( i0, 1); - __m128i ni2 = lasx_extracti128( i1, 0); - __m128i ni3 = lasx_extracti128( i1, 1); - __m128i ni4 = lasx_extracti128( i2, 0 ); - __m128i ni5 = lasx_extracti128( i2, 1); - __m128i ni6 = lasx_extracti128( i3, 0); - __m128i ni7 = lasx_extracti128( i3, 1); - - // Compute the sum of the quants and set y[i].s - const __m128i s0 = __lsx_vadd_w(__lsx_vadd_w(ni0, ni1), __lsx_vadd_w(ni2, ni3)); - const __m128i s1 = __lsx_vadd_w(__lsx_vadd_w(ni4, ni5), __lsx_vadd_w(ni6, ni7)); - y[i].s = LM_GGML_FP32_TO_FP16(d * hsum_i32_4(__lsx_vadd_w(s0, s1))); - - // Convert int32 to int16 - ni0 = lsx_packs_w( ni0, ni1 ); - ni2 = lsx_packs_w( ni2, ni3 ); - ni4 = lsx_packs_w( ni4, ni5 ); - ni6 = lsx_packs_w( ni6, ni7 ); - // Convert int16 to int8 - ni0 = lsx_packs_h( ni0, ni2 ); - ni4 = lsx_packs_h( ni4, ni6 ); - - __lsx_vst(ni0, (__m128i *)(y[i].qs + 0), 0); - __lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0); - } -#else - LM_GGML_UNUSED(nb); - // scalar - quantize_row_q8_1_ref(x, y, k); -#endif -} - -void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int64_t k) { - static const int qk = QK4_0; - - assert(k % qk == 0); - - const int nb = k / qk; - - for (int i = 0; i < nb; i++) { - const float d = LM_GGML_FP16_TO_FP32(x[i].d); - - for (int j = 0; j < qk/2; ++j) { - const int x0 = (x[i].qs[j] & 0x0F) - 8; - const int x1 = (x[i].qs[j] >> 4) - 8; - - y[i*qk + j + 0 ] = x0*d; - y[i*qk + j + qk/2] = x1*d; - } - } -} - -void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int64_t k) { - static const int qk = QK4_1; - - assert(k % qk == 0); - - const int nb = k / qk; - - for (int i = 0; i < nb; i++) { - const float d = LM_GGML_FP16_TO_FP32(x[i].d); - const float m = LM_GGML_FP16_TO_FP32(x[i].m); - - for (int j = 0; j < qk/2; ++j) { - const int x0 = (x[i].qs[j] & 0x0F); - const int x1 = (x[i].qs[j] >> 4); - - y[i*qk + j + 0 ] = x0*d + m; - y[i*qk + j + qk/2] = x1*d + m; - } - } -} - -void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int64_t k) { - static const int qk = QK5_0; - - assert(k % qk == 0); - - const int nb = k / qk; + const int nb = k / qk; for (int i = 0; i < nb; i++) { const float d = LM_GGML_FP16_TO_FP32(x[i].d); @@ -2008,10 +741,6 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int6 } } -void quantize_row_q2_K(const float * restrict x, void * restrict vy, int64_t k) { - quantize_row_q2_K_ref(x, vy, k); -} - static float make_qkx3_quants(int n, int nmax, const float * restrict x, const float * restrict weights, uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux, float rmin, float rdelta, int nstep, bool use_mad) { @@ -2374,10 +1103,6 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int6 } } -void quantize_row_q3_K(const float * restrict x, void * restrict vy, int64_t k) { - quantize_row_q3_K_ref(x, vy, k); -} - static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restrict y, int64_t n_per_row, const float * restrict quant_weights) { assert(n_per_row % QK_K == 0); const int nb = n_per_row / QK_K; @@ -2576,12 +1301,6 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int6 } } -void quantize_row_q4_K(const float * restrict x, void * restrict vy, int64_t k) { - assert(k % QK_K == 0); - block_q4_K * restrict y = vy; - quantize_row_q4_K_ref(x, y, k); -} - static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restrict y, int64_t n_per_row, const float * quant_weights) { assert(n_per_row % QK_K == 0); const int64_t nb = n_per_row / QK_K; @@ -2787,12 +1506,6 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int6 } } -void quantize_row_q5_K(const float * restrict x, void * restrict vy, int64_t k) { - assert(k % QK_K == 0); - block_q5_K * restrict y = vy; - quantize_row_q5_K_ref(x, y, k); -} - static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restrict y, int64_t n_per_row, const float * quant_weights) { assert(n_per_row % QK_K == 0); const int64_t nb = n_per_row / QK_K; @@ -3005,12 +1718,6 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int6 } } -void quantize_row_q6_K(const float * restrict x, void * restrict vy, int64_t k) { - assert(k % QK_K == 0); - block_q6_K * restrict y = vy; - quantize_row_q6_K_ref(x, y, k); -} - static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restrict y, int64_t n_per_row, const float * quant_weights) { assert(n_per_row % QK_K == 0); const int64_t nb = n_per_row / QK_K; @@ -3413,33 +2120,20 @@ void quantize_row_tq2_0_ref(const float * restrict x, block_tq2_0 * restrict y, } } -void quantize_row_tq1_0(const float * restrict x, void * restrict vy, int64_t k) { - assert(k % QK_K == 0); - block_tq1_0 * restrict y = vy; - quantize_row_tq1_0_ref(x, y, k); -} - -void quantize_row_tq2_0(const float * restrict x, void * restrict vy, int64_t k) { - assert(k % QK_K == 0); - block_tq2_0 * restrict y = vy; - quantize_row_tq2_0_ref(x, y, k); -} - size_t quantize_tq1_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { (void)quant_weights; // not used const size_t row_size = lm_ggml_row_size(LM_GGML_TYPE_TQ1_0, n_per_row); - quantize_row_tq1_0(src, dst, (int64_t)nrow*n_per_row); + quantize_row_tq1_0_ref(src, dst, (int64_t)nrow*n_per_row); return nrow * row_size; } size_t quantize_tq2_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { (void)quant_weights; // not used const size_t row_size = lm_ggml_row_size(LM_GGML_TYPE_TQ2_0, n_per_row); - quantize_row_tq2_0(src, dst, (int64_t)nrow*n_per_row); + quantize_row_tq2_0_ref(src, dst, (int64_t)nrow*n_per_row); return nrow * row_size; } - void dequantize_row_tq1_0(const block_tq1_0 * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); const int64_t nb = k / QK_K; @@ -3615,9394 +2309,221 @@ void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y qs += 8; } } -} - -// ====================== 3.3125 bpw (de)-quantization - -void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, int64_t k) { - assert(k % QK_K == 0); - const int64_t nb = k / QK_K; - - for (int i = 0; i < nb; i++) { - - const float d = LM_GGML_FP16_TO_FP32(x[i].d); - const uint8_t * qs = x[i].qs; - const uint8_t * qh = x[i].qh; - const uint8_t * signs = x[i].signs; - - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const float db1 = d * (1 + 2*(x[i].scales[ib32/2] & 0xf)); - const float db2 = d * (1 + 2*(x[i].scales[ib32/2] >> 4)); - for (int l = 0; l < 4; ++l) { - const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))); - const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))); - for (int j = 0; j < 4; ++j) { - y[j+0] = db1 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f); - y[j+4] = db1 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f); - } - y += 8; - } - qs += 8; - signs += 4; - for (int l = 0; l < 4; ++l) { - const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[1] << (8-2*l)) & 256))); - const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[1] << (7-2*l)) & 256))); - for (int j = 0; j < 4; ++j) { - y[j+0] = db2 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f); - y[j+4] = db2 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f); - } - y += 8; - } - qh += 2; - qs += 8; - signs += 4; - } - } -} - -// ====================== 1.5625 bpw (de)-quantization - -void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, int64_t k) { - assert(k % QK_K == 0); - const int64_t nb = k / QK_K; - - for (int i = 0; i < nb; i++) { - - const float d = LM_GGML_FP16_TO_FP32(x[i].d); - const uint8_t * qs = x[i].qs; - const uint16_t * qh = x[i].qh; - - for (int ib = 0; ib < QK_K/32; ++ib) { - const float dl = d * (2*((qh[ib] >> 12) & 7) + 1); - const float delta = qh[ib] & 0x8000 ? -IQ1S_DELTA : IQ1S_DELTA; - for (int l = 0; l < 4; ++l) { - const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8))); - for (int j = 0; j < 8; ++j) { - y[j] = dl * (grid[j] + delta); - } - y += 8; - } - qs += 4; - } - } -} - -void dequantize_row_iq1_m(const block_iq1_m * restrict x, float * restrict y, int64_t k) { - assert(k % QK_K == 0); - const int64_t nb = k / QK_K; - - float delta[4]; - uint16_t idx[4]; - - iq1m_scale_t scale; - - for (int i = 0; i < nb; i++) { - - const uint16_t * sc = (const uint16_t *)x[i].scales; - scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); - const float d = LM_GGML_FP16_TO_FP32(scale.f16); - - const uint8_t * qs = x[i].qs; - const uint8_t * qh = x[i].qh; - - for (int ib = 0; ib < QK_K/32; ++ib) { - const float dl1 = d * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1); - const float dl2 = d * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1); - - idx[0] = qs[0] | ((qh[0] << 8) & 0x700); - idx[1] = qs[1] | ((qh[0] << 4) & 0x700); - idx[2] = qs[2] | ((qh[1] << 8) & 0x700); - idx[3] = qs[3] | ((qh[1] << 4) & 0x700); - delta[0] = qh[0] & 0x08 ? -IQ1S_DELTA : IQ1S_DELTA; - delta[1] = qh[0] & 0x80 ? -IQ1S_DELTA : IQ1S_DELTA; - delta[2] = qh[1] & 0x08 ? -IQ1S_DELTA : IQ1S_DELTA; - delta[3] = qh[1] & 0x80 ? -IQ1S_DELTA : IQ1S_DELTA; - for (int l = 0; l < 2; ++l) { - const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]); - for (int j = 0; j < 8; ++j) { - y[j] = dl1 * (grid[j] + delta[l]); - } - y += 8; - } - for (int l = 2; l < 4; ++l) { - const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]); - for (int j = 0; j < 8; ++j) { - y[j] = dl2 * (grid[j] + delta[l]); - } - y += 8; - } - qs += 4; - qh += 2; - } - } -} - -static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; - -void dequantize_row_iq4_nl(const block_iq4_nl * restrict x, float * restrict y, int64_t k) { - assert(k % QK4_NL == 0); - const int64_t nb = k / QK4_NL; - - for (int i = 0; i < nb; i++) { - - const uint8_t * qs = x[i].qs; - - const float d = LM_GGML_FP16_TO_FP32(x[i].d); - for (int j = 0; j < QK4_NL/2; ++j) { - y[j+ 0] = d * kvalues_iq4nl[qs[j] & 0xf]; - y[j+QK4_NL/2] = d * kvalues_iq4nl[qs[j] >> 4]; - } - y += QK4_NL; - qs += QK4_NL/2; - } -} - -void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y, int64_t k) { - assert(k % QK_K == 0); - const int64_t nb = k / QK_K; - - for (int i = 0; i < nb; i++) { - - const uint8_t * qs = x[i].qs; - - const float d = LM_GGML_FP16_TO_FP32(x[i].d); - - for (int ib = 0; ib < QK_K/32; ++ib) { - const int ls = ((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4); - const float dl = d * (ls - 32); - for (int j = 0; j < 16; ++j) { - y[j+ 0] = dl * kvalues_iq4nl[qs[j] & 0xf]; - y[j+16] = dl * kvalues_iq4nl[qs[j] >> 4]; - } - y += 32; - qs += 16; - } - } -} - -//===================================== Q8_K ============================================== - -void quantize_row_q8_K_ref(const float * restrict x, block_q8_K * restrict y, int64_t k) { - assert(k % QK_K == 0); - const int64_t nb = k / QK_K; - - for (int i = 0; i < nb; i++) { - - float max = 0; - float amax = 0; - for (int j = 0; j < QK_K; ++j) { - float ax = fabsf(x[j]); - if (ax > amax) { - amax = ax; max = x[j]; - } - } - if (!amax) { - y[i].d = 0; - memset(y[i].qs, 0, QK_K); - x += QK_K; - continue; - } - //const float iscale = -128.f/max; - // We need this change for IQ2_XXS, else the AVX implementation becomes very awkward - const float iscale = -127.f/max; - for (int j = 0; j < QK_K; ++j) { - int v = nearest_int(iscale*x[j]); - y[i].qs[j] = MIN(127, v); - } - for (int j = 0; j < QK_K/16; ++j) { - int sum = 0; - for (int ii = 0; ii < 16; ++ii) { - sum += y[i].qs[j*16 + ii]; - } - y[i].bsums[j] = sum; - } - y[i].d = 1/iscale; - x += QK_K; - } -} - -void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int64_t k) { - assert(k % QK_K == 0); - const int64_t nb = k / QK_K; - - for (int i = 0; i < nb; i++) { - for (int j = 0; j < QK_K; ++j) { - *y++ = x[i].d * x[i].qs[j]; - } - } -} - -void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) { - quantize_row_q8_K_ref(x, y, k); -} - -//===================================== Dot products ================================= - -// -// Helper functions -// -#if __AVX__ || __AVX2__ || __AVX512F__ - -// shuffles to pick the required scales in dot products -static inline __m256i get_scale_shuffle_q3k(int i) { - static const uint8_t k_shuffle[128] = { - 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, - 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, - 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, - 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15, - }; - return _mm256_loadu_si256((const __m256i*)k_shuffle + i); -} -static inline __m256i get_scale_shuffle_k4(int i) { - static const uint8_t k_shuffle[256] = { - 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, - 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, - 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, - 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, - 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, - 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, - 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, - 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15 - }; - return _mm256_loadu_si256((const __m256i*)k_shuffle + i); -} -static inline __m128i get_scale_shuffle(int i) { - static const uint8_t k_shuffle[128] = { - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, - 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, - 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, - 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, - 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, - 10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11, - 12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13, - 14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15 - }; - return _mm_loadu_si128((const __m128i*)k_shuffle + i); -} -#elif defined(__loongarch_asx) -// shuffles to pick the required scales in dot products -static inline __m256i get_scale_shuffle_q3k(int i) { - static const uint8_t k_shuffle[128] = { - 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, - 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, - 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, - 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15, - }; - return __lasx_xvld((const __m256i*)k_shuffle + i, 0); -} -static inline __m256i get_scale_shuffle_k4(int i) { - static const uint8_t k_shuffle[256] = { - 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, - 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, - 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, - 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, - 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, - 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, - 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, - 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15 - }; - return __lasx_xvld((const __m256i*)k_shuffle + i, 0); -} -static inline __m128i get_scale_shuffle(int i) { - static const uint8_t k_shuffle[128] = { - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, - 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, - 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, - 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, - 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, - 10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11, - 12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13, - 14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15 - }; - return __lsx_vld((const __m128i*)k_shuffle + i, 0); -} -#endif - -void lm_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - const int qk = QK8_0; - const int nb = n / qk; - - assert(n % qk == 0); -#if defined(__ARM_FEATURE_MATMUL_INT8) - assert((nrc == 2) || (nrc == 1)); -#else - assert(nrc == 1); -#endif - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_q4_0 * restrict x = vx; - const block_q8_0 * restrict y = vy; - -#if defined(__ARM_FEATURE_MATMUL_INT8) - if (nrc == 2) { - const block_q4_0 * restrict vx0 = vx; - const block_q4_0 * restrict vx1 = (const block_q4_0 *) ((const uint8_t*)vx + bx); - const block_q8_0 * restrict vy0 = vy; - const block_q8_0 * restrict vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by); - - float32x4_t sumv0 = vdupq_n_f32(0.0f); - - for (int i = 0; i < nb; i++) { - const block_q4_0 * restrict b_x0 = &vx0[i]; - const block_q4_0 * restrict b_x1 = &vx1[i]; - const block_q8_0 * restrict b_y0 = &vy0[i]; - const block_q8_0 * restrict b_y1 = &vy1[i]; - - const uint8x16_t m4b = vdupq_n_u8(0x0F); - const int8x16_t s8b = vdupq_n_s8(0x8); - - const uint8x16_t v0_0 = vld1q_u8(b_x0->qs); - const uint8x16_t v0_1 = vld1q_u8(b_x1->qs); - - // 4-bit -> 8-bit - const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); - const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - - // sub 8 - const int8x16_t x0_l = vsubq_s8(v0_0l, s8b); - const int8x16_t x0_h = vsubq_s8(v0_0h, s8b); - const int8x16_t x1_l = vsubq_s8(v0_1l, s8b); - const int8x16_t x1_h = vsubq_s8(v0_1h, s8b); - - // load y - const int8x16_t y0_l = vld1q_s8(b_y0->qs); - const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16); - const int8x16_t y1_l = vld1q_s8(b_y1->qs); - const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16); - - float32_t _scale[4] = { LM_GGML_FP16_TO_FP32(b_x0->d)*LM_GGML_FP16_TO_FP32(b_y0->d), - LM_GGML_FP16_TO_FP32(b_x0->d)*LM_GGML_FP16_TO_FP32(b_y1->d), - LM_GGML_FP16_TO_FP32(b_x1->d)*LM_GGML_FP16_TO_FP32(b_y0->d), - LM_GGML_FP16_TO_FP32(b_x1->d)*LM_GGML_FP16_TO_FP32(b_y1->d)}; - - float32x4_t scale = vld1q_f32(_scale); - - int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); - int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); - - int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); - int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); - - int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); - int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); - - int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); - int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); - - sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)), - l1, r1)), l2, r2)), l3, r3))), scale); - } - float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2); - float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1); - - vst1_f32(s, vget_low_f32(sumv2)); - vst1_f32(s + bs, vget_high_f32(sumv2)); - return; - } -#endif - - int ib = 0; - float sumf = 0; - -#if defined(__ARM_FEATURE_SVE) - svfloat32_t sumv0 = svdup_n_f32(0.0f); - svfloat32_t sumv1 = svdup_n_f32(0.0f); - - const int vector_length = lm_ggml_cpu_get_sve_cnt()*8; - - // VLA Implementation using switch case - switch (vector_length) { - case 128: - { - // predicate for activating higher lanes for 4 float32 elements - const svbool_t ph4 = svptrue_pat_b32(SV_VL4); - - for (; ib + 1 < nb; ib += 2) { - const block_q4_0 * restrict x0 = &x[ib + 0]; - const block_q4_0 * restrict x1 = &x[ib + 1]; - const block_q8_0 * restrict y0 = &y[ib + 0]; - const block_q8_0 * restrict y1 = &y[ib + 1]; - - // load x - const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs); - const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs); - - // 4-bit -> 8-bit - const svint8_t qx0l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx0r, 0x0F)); - const svint8_t qx0h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx0r, 0x04)); - const svint8_t qx1l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx1r, 0x0F)); - const svint8_t qx1h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx1r, 0x04)); - - // sub 8 - const svint8_t qx0ls = svsub_n_s8_x(svptrue_b8(), qx0h, 8); - const svint8_t qx0hs = svsub_n_s8_x(svptrue_b8(), qx0l, 8); - const svint8_t qx1ls = svsub_n_s8_x(svptrue_b8(), qx1h, 8); - const svint8_t qx1hs = svsub_n_s8_x(svptrue_b8(), qx1l, 8); - - // load y - const svint8_t qy0h = svld1_s8(svptrue_b8(), y0->qs); - const svint8_t qy0l = svld1_s8(svptrue_b8(), y0->qs + 16); - const svint8_t qy1h = svld1_s8(svptrue_b8(), y1->qs); - const svint8_t qy1l = svld1_s8(svptrue_b8(), y1->qs + 16); - - // dot product - sumv0 = svmla_n_f32_x(ph4, sumv0, svcvt_f32_s32_x(ph4, svadd_x(ph4, - svdot_s32(svdup_n_s32(0), qx0ls, qy0l), - svdot_s32(svdup_n_s32(0), qx0hs, qy0h))), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d)); - sumv1 = svmla_n_f32_x(ph4, sumv1, svcvt_f32_s32_x(ph4, svadd_x(ph4, - svdot_s32(svdup_n_s32(0), qx1ls, qy1l), - svdot_s32(svdup_n_s32(0), qx1hs, qy1h))), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d)); - } - - sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); - } break; - case 256: - { - // predicate for activating higher lanes for 16 int8 elements - const svbool_t ph16 = svptrue_pat_b8(SV_VL16); - // predicate for activating lower lanes for 16 int8 elements - const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16); - - for (; ib + 1 < nb; ib += 2) { - const block_q4_0 * restrict x0 = &x[ib + 0]; - const block_q4_0 * restrict x1 = &x[ib + 1]; - const block_q8_0 * restrict y0 = &y[ib + 0]; - const block_q8_0 * restrict y1 = &y[ib + 1]; - - // load x - const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs); - const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs); - - // 4-bit -> 8-bit - const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04)); - const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04)); - - // sub 8 - const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8); - const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8); - - // load y - const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); - const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); - - // dot product - sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), - svdot_s32(svdup_n_s32(0), qx0s, qy0)), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d)); - sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), - svdot_s32(svdup_n_s32(0), qx1s, qy1)), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d)); - } - - sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); - } break; - case 512: - { - // predicate for activating higher lanes for 32 int8 elements - const svbool_t ph32 = svptrue_pat_b8(SV_VL32); - - // predicate for activating higher lanes for 16 int8 elements - const svbool_t ph16 = svptrue_pat_b8(SV_VL16); - // predicate for activating lower lanes for 16 int8 elements from first 32 int8 activated lanes - const svbool_t pl16 = svnot_b_z(ph32, ph16); - - for (; ib + 1 < nb; ib += 2) { - const block_q4_0 * restrict x0 = &x[ib + 0]; - const block_q4_0 * restrict x1 = &x[ib + 1]; - const block_q8_0 * restrict y0 = &y[ib + 0]; - const block_q8_0 * restrict y1 = &y[ib + 1]; - - // load x - const svuint8_t qx0r = svld1rq_u8(ph32, x0->qs); - const svuint8_t qx1r = svld1rq_u8(ph32, x1->qs); - - // 4-bit -> 8-bit - const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04)); - const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04)); - - // sub 8 - const svint8_t qx0s = svsub_n_s8_x(ph32, qx0, 8); - const svint8_t qx1s = svsub_n_s8_x(ph32, qx1, 8); - - // load y - const svint8_t qy0 = svld1_s8(ph32, y0->qs); - const svint8_t qy1 = svld1_s8(ph32, y1->qs); - - // dot product - sumv0 = svmla_n_f32_x(ph32, sumv0, svcvt_f32_s32_x(ph32, - svdot_s32(svdup_n_s32(0), qx0s, qy0)), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d)); - sumv1 = svmla_n_f32_x(ph32, sumv1, svcvt_f32_s32_x(ph32, - svdot_s32(svdup_n_s32(0), qx1s, qy1)), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d)); - } - - sumf = svaddv_f32(ph32, svadd_f32_x(ph32, sumv0, sumv1)); - } break; - default: - assert(false && "Unsupported vector length"); - break; - } - -#elif defined(__ARM_NEON) - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t sumv1 = vdupq_n_f32(0.0f); - - for (; ib + 1 < nb; ib += 2) { - const block_q4_0 * restrict x0 = &x[ib + 0]; - const block_q4_0 * restrict x1 = &x[ib + 1]; - const block_q8_0 * restrict y0 = &y[ib + 0]; - const block_q8_0 * restrict y1 = &y[ib + 1]; - - const uint8x16_t m4b = vdupq_n_u8(0x0F); - const int8x16_t s8b = vdupq_n_s8(0x8); - - const uint8x16_t v0_0 = vld1q_u8(x0->qs); - const uint8x16_t v0_1 = vld1q_u8(x1->qs); - - // 4-bit -> 8-bit - const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); - const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - - // sub 8 - const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b); - const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b); - const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b); - const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); - - // load y - const int8x16_t v1_0l = vld1q_s8(y0->qs); - const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); - const int8x16_t v1_1l = vld1q_s8(y1->qs); - const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); - - // dot product into int32x4_t - const int32x4_t p_0 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h); - const int32x4_t p_1 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d)); - } - - sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); -#elif defined(__AVX2__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - - // Main loop - for (; ib < nb; ++ib) { - /* Compute combined scale for the block */ - const __m256 d = _mm256_set1_ps( LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d) ); - - __m256i qx = bytes_from_nibbles_32(x[ib].qs); - - // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. - const __m256i off = _mm256_set1_epi8( 8 ); - qx = _mm256_sub_epi8( qx, off ); - - __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs); - - const __m256 q = mul_sum_i8_pairs_float(qx, qy); - - /* Multiply q with scale and accumulate */ - acc = _mm256_fmadd_ps( d, q, acc ); - } - - sumf = hsum_float_8(acc); -#elif defined(__AVX__) - const __m128i mone = _mm_set1_epi16(1); - - __m256 accum1 = _mm256_setzero_ps(); - __m256 accum2 = _mm256_setzero_ps(); - for (; ib + 1 < nb; ib += 2) { - const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs); - const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs); - const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs); - const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1); - const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs); - const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1); - - const __m128i q4b_1_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_1), _mm_set1_epi8(8)); - const __m128i q4b_1_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_1, 4)), _mm_set1_epi8(8)); - const __m128i q4b_2_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_2), _mm_set1_epi8(8)); - const __m128i q4b_2_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_2, 4)), _mm_set1_epi8(8)); - const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0); - const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1); - const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0); - const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1); - const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, mone); - const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, mone); - const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, mone); - const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, mone); - accum1 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(LM_GGML_FP16_TO_FP32(y[ib + 0].d)*LM_GGML_FP16_TO_FP32(x[ib + 0].d)), - _mm256_cvtepi32_ps(MM256_SET_M128I(p_1_1, p_1_0))), accum1); - accum2 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(LM_GGML_FP16_TO_FP32(y[ib + 1].d)*LM_GGML_FP16_TO_FP32(x[ib + 1].d)), - _mm256_cvtepi32_ps(MM256_SET_M128I(p_2_1, p_2_0))), accum2); - } - - sumf = hsum_float_8(_mm256_add_ps(accum1, accum2)); -#elif defined(__SSSE3__) - // set constants - const __m128i lowMask = _mm_set1_epi8(0xF); - const __m128i off = _mm_set1_epi8(8); - - // Initialize accumulator with zeros - __m128 acc_0 = _mm_setzero_ps(); - __m128 acc_1 = _mm_setzero_ps(); - __m128 acc_2 = _mm_setzero_ps(); - __m128 acc_3 = _mm_setzero_ps(); - - for (; ib + 1 < nb; ib += 2) { - _mm_prefetch(&x[ib] + sizeof(block_q4_0), _MM_HINT_T0); - _mm_prefetch(&y[ib] + sizeof(block_q8_0), _MM_HINT_T0); - - // Compute combined scale for the block 0 and 1 - const __m128 d_0_1 = _mm_set1_ps( LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d) ); - - const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[ib].qs); - - __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1); - __m128i by_0 = _mm_loadu_si128((const __m128i *)y[ib].qs); - bx_0 = _mm_sub_epi8(bx_0, off); - const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); - - __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4)); - __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[ib].qs + 16)); - bx_1 = _mm_sub_epi8(bx_1, off); - const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); - - _mm_prefetch(&x[ib] + 2 * sizeof(block_q4_0), _MM_HINT_T0); - _mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0); - - // Compute combined scale for the block 2 and 3 - const __m128 d_2_3 = _mm_set1_ps( LM_GGML_FP16_TO_FP32(x[ib + 1].d) * LM_GGML_FP16_TO_FP32(y[ib + 1].d) ); - - const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs); - - __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3); - __m128i by_2 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs); - bx_2 = _mm_sub_epi8(bx_2, off); - const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); - - __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4)); - __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[ib + 1].qs + 16)); - bx_3 = _mm_sub_epi8(bx_3, off); - const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); - - // Convert int32_t to float - __m128 p0 = _mm_cvtepi32_ps(i32_0); - __m128 p1 = _mm_cvtepi32_ps(i32_1); - __m128 p2 = _mm_cvtepi32_ps(i32_2); - __m128 p3 = _mm_cvtepi32_ps(i32_3); - - // Apply the scale - __m128 p0_d = _mm_mul_ps( d_0_1, p0 ); - __m128 p1_d = _mm_mul_ps( d_0_1, p1 ); - __m128 p2_d = _mm_mul_ps( d_2_3, p2 ); - __m128 p3_d = _mm_mul_ps( d_2_3, p3 ); - - // Acummulate - acc_0 = _mm_add_ps(p0_d, acc_0); - acc_1 = _mm_add_ps(p1_d, acc_1); - acc_2 = _mm_add_ps(p2_d, acc_2); - acc_3 = _mm_add_ps(p3_d, acc_3); - } - - sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); -#elif defined(__riscv_v_intrinsic) - size_t vl = __riscv_vsetvl_e8m1(qk/2); - - for (; ib < nb; ++ib) { - // load elements - vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl); - - vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl); - vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl); - - // mask and store lower part of x, and then upper part - vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); - vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); - - vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); - vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); - - // subtract offset - vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl); - vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl); - - vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); - vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); - - vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - - vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); - - int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); - - sumf += sumi*LM_GGML_FP16_TO_FP32(x[ib].d)*LM_GGML_FP16_TO_FP32(y[ib].d); - } - -#elif defined(__POWER9_VECTOR__) - const vector signed char lowMask = vec_splats((signed char)0xF); - const vector signed int v0 = vec_splats((int32_t)0); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - const vector signed char v8 = vec_splats((signed char)0x8); - - vector float vsumf0 = vec_splats(0.0f); - -#pragma GCC unroll 8 - for (; ib < nb; ++ib) { - __builtin_prefetch(x[ib].qs, 0, 1); - __builtin_prefetch(y[ib].qs, 0, 1); - - vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[ib].d)); - vector float vyd = vec_splats(LM_GGML_FP16_TO_FP32(y[ib].d)); - vector float vd = vec_mul(vxd, vyd); - - vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); - vector signed char q8y0 = vec_xl( 0, y[ib].qs); - vector signed char q8y1 = vec_xl(16, y[ib].qs); - - vector signed char q4x0 = vec_and(qxs, lowMask); - vector signed char q4x1 = vec_sr(qxs, v4); - - q4x0 = vec_sub(q4x0, v8); - q4x1 = vec_sub(q4x1, v8); - - vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0)); - vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1)); - - vector signed int vsumi0 = v0; - - vsumi0 = vec_sum4s(qv0, vsumi0); - vsumi0 = vec_sum4s(qv1, vsumi0); - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - } - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - sumf = vec_extract(vsumf0, 0); - -#elif defined(__loongarch_asx) - // Initialize accumulator with zeros - __m256 acc = (__m256)__lasx_xvldi(0); - - // Main loop - for (; ib < nb; ++ib) { - /* Compute combined scale for the block */ - const __m256 d = __lasx_xvreplfr2vr_s( LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d) ); - - __m256i qx = bytes_from_nibbles_32(x[ib].qs); - - // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. - const __m256i off = __lasx_xvreplgr2vr_b( 8 ); - qx = __lasx_xvsub_b( qx, off ); - - __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0); - - const __m256 q = mul_sum_i8_pairs_float(qx, qy); - - /* Multiply q with scale and accumulate */ - acc = __lasx_xvfmadd_s( d, q, acc ); - } - - sumf = hsum_float_8(acc); -#elif defined(__loongarch_sx) - // set constants - const __m128i low_mask = __lsx_vreplgr2vr_b(0xF); - const __m128i off = __lsx_vreplgr2vr_b(8); - - // Initialize accumulator with zeros - __m128 acc_0 = __lsx_vldi(0); - __m128 acc_1 = __lsx_vldi(0); - __m128 acc_2 = __lsx_vldi(0); - __m128 acc_3 = __lsx_vldi(0); - - for (; ib + 1 < nb; ib += 2) { - - // Compute combined scale for the block 0 and 1 - const __m128 d_0_1 = __lsx_vreplgr2vr_w( LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d) ); - - const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[ib].qs, 0); - - __m128i bx_0 = __lsx_vand_v(low_mask, tmp_0_1); - __m128i by_0 = __lsx_vld((const __m128i *)y[ib].qs, 0); - bx_0 = __lsx_vsub_b(bx_0, off); - const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); - - __m128i bx_1 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_0_1, 4)); - __m128i by_1 = __lsx_vld((const __m128i *)(y[ib].qs + 16), 0); - bx_1 = __lsx_vsub_b(bx_1, off); - const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); - - //_mm_prefetch(&x[ib] + 2 * sizeof(block_q4_0), _MM_HINT_T0); - //_mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0); - - // Compute combined scale for the block 2 and 3 - const __m128 d_2_3 = __lsx_vreplgr2vr_w( LM_GGML_FP16_TO_FP32(x[ib + 1].d) * LM_GGML_FP16_TO_FP32(y[ib + 1].d) ); - - const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[ib + 1].qs, 0); - - __m128i bx_2 = __lsx_vand_v(low_mask, tmp_2_3); - __m128i by_2 = __lsx_vld((const __m128i *)y[ib + 1].qs, 0); - bx_2 = __lsx_vsub_b(bx_2, off); - const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); - - __m128i bx_3 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_2_3, 4)); - __m128i by_3 = __lsx_vld((const __m128i *)(y[ib + 1].qs + 16), 0); - bx_3 = __lsx_vsub_b(bx_3, off); - const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); - - // Convert int32_t to float - __m128 p0 = __lsx_vffint_s_w(i32_0); - __m128 p1 = __lsx_vffint_s_w(i32_1); - __m128 p2 = __lsx_vffint_s_w(i32_2); - __m128 p3 = __lsx_vffint_s_w(i32_3); - - // Apply the scale - __m128 p0_d = __lsx_vfmul_s( d_0_1, p0 ); - __m128 p1_d = __lsx_vfmul_s( d_0_1, p1 ); - __m128 p2_d = __lsx_vfmul_s( d_2_3, p2 ); - __m128 p3_d = __lsx_vfmul_s( d_2_3, p3 ); - - // Acummulate - acc_0 = __lsx_vfadd_s(p0_d, acc_0); - acc_1 = __lsx_vfadd_s(p1_d, acc_1); - acc_2 = __lsx_vfadd_s(p2_d, acc_2); - acc_3 = __lsx_vfadd_s(p3_d, acc_3); - } - - sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); -#endif - for (; ib < nb; ++ib) { - int sumi0 = 0; - int sumi1 = 0; - - for (int j = 0; j < qk/2; ++j) { - const int v0 = (x[ib].qs[j] & 0x0F) - 8; - const int v1 = (x[ib].qs[j] >> 4) - 8; - - sumi0 += (v0 * y[ib].qs[j]); - sumi1 += (v1 * y[ib].qs[j + qk/2]); - } - - int sumi = sumi0 + sumi1; - sumf += sumi*LM_GGML_FP16_TO_FP32(x[ib].d)*LM_GGML_FP16_TO_FP32(y[ib].d); - } - - *s = sumf; -} - -void lm_ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - const int qk = QK8_1; - const int nb = n / qk; - - assert(n % qk == 0); -#if defined(__ARM_FEATURE_MATMUL_INT8) - assert((nrc == 2) || (nrc == 1)); -#else - assert(nrc == 1); -#endif - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_q4_1 * restrict x = vx; - const block_q8_1 * restrict y = vy; - -#if defined(__ARM_FEATURE_MATMUL_INT8) - if (nrc == 2) { - const block_q4_1 * restrict vx0 = vx; - const block_q4_1 * restrict vx1 = (const block_q4_1 *) ((const uint8_t*)vx + bx); - const block_q8_1 * restrict vy0 = vy; - const block_q8_1 * restrict vy1 = (const block_q8_1 *) ((const uint8_t*)vy + by); - - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t summs0 = vdupq_n_f32(0.0f); - - for (int i = 0; i < nb; i++) { - const block_q4_1 * restrict b_x0 = &vx0[i]; - const block_q4_1 * restrict b_x1 = &vx1[i]; - const block_q8_1 * restrict b_y0 = &vy0[i]; - const block_q8_1 * restrict b_y1 = &vy1[i]; - - float32_t summs_t[4] = {LM_GGML_FP16_TO_FP32(b_x0->m) * LM_GGML_FP16_TO_FP32(b_y0->s), - LM_GGML_FP16_TO_FP32(b_x1->m) * LM_GGML_FP16_TO_FP32(b_y0->s), - LM_GGML_FP16_TO_FP32(b_x0->m) * LM_GGML_FP16_TO_FP32(b_y1->s), - LM_GGML_FP16_TO_FP32(b_x1->m) * LM_GGML_FP16_TO_FP32(b_y1->s)}; - summs0 = vaddq_f32(summs0, vld1q_f32(summs_t)); - - const uint8x16_t m4b = vdupq_n_u8(0x0F); - - const uint8x16_t v0_0 = vld1q_u8(b_x0->qs); - const uint8x16_t v0_1 = vld1q_u8(b_x1->qs); - - // 4-bit -> 8-bit - const int8x16_t x0_l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); - const int8x16_t x0_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - const int8x16_t x1_l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - const int8x16_t x1_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - - // load y - const int8x16_t y0_l = vld1q_s8(b_y0->qs); - const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16); - const int8x16_t y1_l = vld1q_s8(b_y1->qs); - const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16); - - // mmla into int32x4_t - float32_t _scale[4] = {LM_GGML_FP16_TO_FP32(b_x0->d)*b_y0->d, - LM_GGML_FP16_TO_FP32(b_x0->d)*b_y1->d, - LM_GGML_FP16_TO_FP32(b_x1->d)*b_y0->d, - LM_GGML_FP16_TO_FP32(b_x1->d)*b_y1->d}; - float32x4_t scale = vld1q_f32(_scale); - - int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); - int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); - - int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); - int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); - - int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); - int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); - - int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); - int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); - sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)), - l1, r1)), l2, r2)), l3, r3))), scale); - } - - float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2); - float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1); - sumv2 = vaddq_f32(sumv2, summs0); - - vst1_f32(s, vget_low_f32 (sumv2)); - vst1_f32(s + bs, vget_high_f32(sumv2)); - return; - } -#endif - - int ib = 0; - float sumf = 0; - - // TODO: add WASM SIMD -#if defined(__ARM_NEON) - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t sumv1 = vdupq_n_f32(0.0f); - - float summs = 0; - - for (; ib + 1 < nb; ib += 2) { - const block_q4_1 * restrict x0 = &x[ib + 0]; - const block_q4_1 * restrict x1 = &x[ib + 1]; - const block_q8_1 * restrict y0 = &y[ib + 0]; - const block_q8_1 * restrict y1 = &y[ib + 1]; - - summs += LM_GGML_FP16_TO_FP32(x0->m) * LM_GGML_FP16_TO_FP32(y0->s) + LM_GGML_FP16_TO_FP32(x1->m) * LM_GGML_FP16_TO_FP32(y1->s); - - const uint8x16_t m4b = vdupq_n_u8(0x0F); - - const uint8x16_t v0_0 = vld1q_u8(x0->qs); - const uint8x16_t v0_1 = vld1q_u8(x1->qs); - - // 4-bit -> 8-bit - const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); - const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - - // load y - const int8x16_t v1_0l = vld1q_s8(y0->qs); - const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); - const int8x16_t v1_1l = vld1q_s8(y1->qs); - const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); - - // dot product into int32x4_t - const int32x4_t p_0 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h); - const int32x4_t p_1 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d)); - } - - sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs; -#elif defined(__AVX2__) || defined(__AVX__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - - float summs = 0; - - // Main loop - for (; ib < nb; ++ib) { - const float d0 = LM_GGML_FP16_TO_FP32(x[ib].d); - const float d1 = LM_GGML_FP16_TO_FP32(y[ib].d); - - summs += LM_GGML_FP16_TO_FP32(x[ib].m) * LM_GGML_FP16_TO_FP32(y[ib].s); - - const __m256 d0v = _mm256_set1_ps( d0 ); - const __m256 d1v = _mm256_set1_ps( d1 ); - - // Compute combined scales - const __m256 d0d1 = _mm256_mul_ps( d0v, d1v ); - - // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes - const __m256i qx = bytes_from_nibbles_32(x[ib].qs); - const __m256i qy = _mm256_loadu_si256( (const __m256i *)y[ib].qs ); - - const __m256 xy = mul_sum_us8_pairs_float(qx, qy); - - // Accumulate d0*d1*x*y -#if defined(__AVX2__) - acc = _mm256_fmadd_ps( d0d1, xy, acc ); -#else - acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc ); -#endif - } - - sumf = hsum_float_8(acc) + summs; -#elif defined(__riscv_v_intrinsic) - size_t vl = __riscv_vsetvl_e8m1(qk/2); - - for (; ib < nb; ++ib) { - // load elements - vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl); - - vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl); - vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl); - - // mask and store lower part of x, and then upper part - vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); - vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); - - vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); - vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); - - vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); - vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); - - vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - - vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); - - int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); - - sumf += (LM_GGML_FP16_TO_FP32(x[ib].d)*LM_GGML_FP16_TO_FP32(y[ib].d))*sumi + LM_GGML_FP16_TO_FP32(x[ib].m)*LM_GGML_FP16_TO_FP32(y[ib].s); - } - -#elif defined(__POWER9_VECTOR__) - const vector signed char lowMask = vec_splats((signed char)0xF); - const vector signed int v0 = vec_splats((int32_t)0); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - - vector float vsumf0 = vec_splats(0.0f); - -#pragma GCC unroll 4 - for (; ib < nb; ++ib) { - __builtin_prefetch(x[ib].qs, 0, 1); - __builtin_prefetch(y[ib].qs, 0, 1); - - vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[ib].d)); - vector float vyd = vec_splats(LM_GGML_FP16_TO_FP32(y[ib].d)); - vector float vd = vec_mul(vxd, vyd); - - vector float vxmin = vec_splats(LM_GGML_FP16_TO_FP32(x[ib].m)); - vector float vys = {LM_GGML_FP16_TO_FP32(y[ib].s), 0.0f, 0.0f, 0.0f}; - vsumf0 = vec_madd(vxmin, vys, vsumf0); - - vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); - vector signed char q8y0 = vec_xl( 0, y[ib].qs); - vector signed char q8y1 = vec_xl(16, y[ib].qs); - - vector unsigned char q4x0 = (vector unsigned char)vec_and(qxs, lowMask); - vector unsigned char q4x1 = (vector unsigned char)vec_sr(qxs, v4); - - vector signed int vsumi0 = v0; - - vsumi0 = vec_msum(q8y0, q4x0, vsumi0); - vsumi0 = vec_msum(q8y1, q4x1, vsumi0); - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - } - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - sumf = vec_extract(vsumf0, 0); - -#elif defined(__loongarch_asx) - // Initialize accumulator with zeros - __m256 acc = (__m256)__lasx_xvldi(0); - - float summs = 0; - - // Main loop - for (; ib < nb; ++ib) { - const float d0 = LM_GGML_FP16_TO_FP32(x[ib].d); - const float d1 = LM_GGML_FP16_TO_FP32(y[ib].d); - - summs += LM_GGML_FP16_TO_FP32(x[ib].m) * LM_GGML_FP16_TO_FP32(y[ib].s); - - const __m256 d0v = __lasx_xvreplfr2vr_s( d0 ); - const __m256 d1v = __lasx_xvreplfr2vr_s( d1 ); - - // Compute combined scales - const __m256 d0d1 = __lasx_xvfmul_s( d0v, d1v ); - - // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes - const __m256i qx = bytes_from_nibbles_32(x[ib].qs); - const __m256i qy = __lasx_xvld( (const __m256i *)y[ib].qs, 0); - - const __m256 xy = mul_sum_us8_pairs_float(qx, qy); - - // Accumulate d0*d1*x*y - acc = __lasx_xvfmadd_s( d0d1, xy, acc ); - } - - sumf = hsum_float_8(acc) + summs; -#endif - for (; ib < nb; ++ib) { - int sumi0 = 0; - int sumi1 = 0; - - for (int j = 0; j < qk/2; ++j) { - const int v0 = (x[ib].qs[j] & 0x0F); - const int v1 = (x[ib].qs[j] >> 4); - - sumi0 += (v0 * y[ib].qs[j]); - sumi1 += (v1 * y[ib].qs[j + qk/2]); - } - - int sumi = sumi0 + sumi1; - sumf += (LM_GGML_FP16_TO_FP32(x[ib].d)*LM_GGML_FP16_TO_FP32(y[ib].d))*sumi + LM_GGML_FP16_TO_FP32(x[ib].m)*LM_GGML_FP16_TO_FP32(y[ib].s); - } - - *s = sumf; -} - -void lm_ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - const int qk = QK8_0; - const int nb = n / qk; - - int ib = 0; - float sumf = 0; - - assert(n % qk == 0); - assert(qk == QK5_0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_q5_0 * restrict x = vx; - const block_q8_0 * restrict y = vy; - -#if defined(__ARM_NEON) - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t sumv1 = vdupq_n_f32(0.0f); - - uint32_t qh0; - uint32_t qh1; - - uint64_t tmp0[4]; - uint64_t tmp1[4]; - - for (; ib + 1 < nb; ib += 2) { - const block_q5_0 * restrict x0 = &x[ib]; - const block_q5_0 * restrict x1 = &x[ib + 1]; - const block_q8_0 * restrict y0 = &y[ib]; - const block_q8_0 * restrict y1 = &y[ib + 1]; - - const uint8x16_t m4b = vdupq_n_u8(0x0F); - - // extract the 5th bit via lookup table ((!b) << 4) - memcpy(&qh0, x0->qh, sizeof(qh0)); - memcpy(&qh1, x1->qh, sizeof(qh1)); - - tmp0[0] = table_b2b_1[(qh0 >> 0) & 0xFF]; - tmp0[1] = table_b2b_1[(qh0 >> 8) & 0xFF]; - tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF]; - tmp0[3] = table_b2b_1[(qh0 >> 24) ]; - - tmp1[0] = table_b2b_1[(qh1 >> 0) & 0xFF]; - tmp1[1] = table_b2b_1[(qh1 >> 8) & 0xFF]; - tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF]; - tmp1[3] = table_b2b_1[(qh1 >> 24) ]; - - const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0)); - const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2)); - const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0)); - const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2)); - - const uint8x16_t v0_0 = vld1q_u8(x0->qs); - const uint8x16_t v0_1 = vld1q_u8(x1->qs); - - // 4-bit -> 8-bit - int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); - int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - - // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero) - const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0); - const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0); - const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1); - const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1); - - // load y - const int8x16_t v1_0l = vld1q_s8(y0->qs); - const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); - const int8x16_t v1_1l = vld1q_s8(y1->qs); - const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( - lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), - lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( - lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), - lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d)); - } - - sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); -#elif defined(__wasm_simd128__) - v128_t sumv = wasm_f32x4_splat(0.0f); - - uint32_t qh; - uint64_t tmp[4]; - - // TODO: check if unrolling this is better - for (; ib < nb; ++ib) { - const block_q5_0 * restrict x0 = &x[ib]; - const block_q8_0 * restrict y0 = &y[ib]; - - const v128_t m4b = wasm_i8x16_splat(0x0F); - - // extract the 5th bit - memcpy(&qh, x0->qh, sizeof(qh)); - - tmp[0] = table_b2b_1[(qh >> 0) & 0xFF]; - tmp[1] = table_b2b_1[(qh >> 8) & 0xFF]; - tmp[2] = table_b2b_1[(qh >> 16) & 0xFF]; - tmp[3] = table_b2b_1[(qh >> 24) ]; - - const v128_t qhl = wasm_v128_load(tmp + 0); - const v128_t qhh = wasm_v128_load(tmp + 2); - - const v128_t v0 = wasm_v128_load(x0->qs); - - // 4-bit -> 8-bit - const v128_t v0l = wasm_v128_and (v0, m4b); - const v128_t v0h = wasm_u8x16_shr(v0, 4); - - // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero) - const v128_t v0lf = wasm_i8x16_sub(v0l, qhl); - const v128_t v0hf = wasm_i8x16_sub(v0h, qhh); - - // load y - const v128_t v1l = wasm_v128_load(y0->qs); - const v128_t v1h = wasm_v128_load(y0->qs + 16); - - // int8x16 -> int16x8 - const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf); - const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf); - const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf); - const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf); - - const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l); - const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l); - const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h); - const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h); - - // dot product - sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4( - wasm_i32x4_add( - wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll), - wasm_i32x4_dot_i16x8(v0lfh, v1lh)), - wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl), - wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), - wasm_f32x4_splat(LM_GGML_FP16_TO_FP32(x0->d) * LM_GGML_FP16_TO_FP32(y0->d)))); - } - - sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + - wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3); -#elif defined(__AVX2__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - - // Main loop - for (; ib < nb; ++ib) { - /* Compute combined scale for the block */ - const __m256 d = _mm256_set1_ps(LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d)); - - __m256i qx = bytes_from_nibbles_32(x[ib].qs); - __m256i bxhi = bytes_from_bits_32(x[ib].qh); - bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0)); - qx = _mm256_or_si256(qx, bxhi); - - __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs); - - const __m256 q = mul_sum_i8_pairs_float(qx, qy); - - /* Multiply q with scale and accumulate */ - acc = _mm256_fmadd_ps(d, q, acc); - } - - sumf = hsum_float_8(acc); -#elif defined(__AVX__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - __m128i mask = _mm_set1_epi8((char)0xF0); - - // Main loop - for (; ib < nb; ++ib) { - /* Compute combined scale for the block */ - const __m256 d = _mm256_set1_ps(LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d)); - - __m256i bx_0 = bytes_from_nibbles_32(x[ib].qs); - const __m256i bxhi = bytes_from_bits_32(x[ib].qh); - __m128i bxhil = _mm256_castsi256_si128(bxhi); - __m128i bxhih = _mm256_extractf128_si256(bxhi, 1); - bxhil = _mm_andnot_si128(bxhil, mask); - bxhih = _mm_andnot_si128(bxhih, mask); - __m128i bxl = _mm256_castsi256_si128(bx_0); - __m128i bxh = _mm256_extractf128_si256(bx_0, 1); - bxl = _mm_or_si128(bxl, bxhil); - bxh = _mm_or_si128(bxh, bxhih); - bx_0 = MM256_SET_M128I(bxh, bxl); - - const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[ib].qs); - - const __m256 q = mul_sum_i8_pairs_float(bx_0, by_0); - - /* Multiply q with scale and accumulate */ - acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc); - } - - sumf = hsum_float_8(acc); -#elif defined(__riscv_v_intrinsic) - uint32_t qh; - - size_t vl = __riscv_vsetvl_e8m1(qk/2); - - // These temporary registers are for masking and shift operations - vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl); - vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl); - - vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl); - vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl); - - for (; ib < nb; ++ib) { - memcpy(&qh, x[ib].qh, sizeof(uint32_t)); - - // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; - vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl); - vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl); - vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl); - - // ((qh & (1u << (j + 16))) >> (j + 12)); - vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl); - vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl); - - // narrowing - vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl); - vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl); - - vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl); - vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl); - - // load - vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl); - - vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl); - vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl); - - vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); - vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); - - vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl); - vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl); - - vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); - vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); - - vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl); - vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl); - - vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); - vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); - - vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - - vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); - - int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); - - sumf += (LM_GGML_FP16_TO_FP32(x[ib].d)*LM_GGML_FP16_TO_FP32(y[ib].d)) * sumi; - } - -#elif defined(__POWER9_VECTOR__) - const vector signed char lowMask = vec_splats((signed char)0xF); - const vector unsigned char v4 = vec_splats((unsigned char)4); - - vector float vsumf0 = vec_splats(0.0f); - -#pragma GCC unroll 4 - for (; ib < nb; ++ib) { - __builtin_prefetch(x[ib].qs, 0, 1); - __builtin_prefetch(y[ib].qs, 0, 1); - - vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[ib].d)); - vector float vyd = vec_splats(LM_GGML_FP16_TO_FP32(y[ib].d)); - vector float vd = vec_mul(vxd, vyd); - - vector signed long long aux64x2_0 = {(uint64_t)(table_b2b_1[x[ib].qh[0]]), (uint64_t)(table_b2b_1[x[ib].qh[1]])}; - vector signed long long aux64x2_1 = {(uint64_t)(table_b2b_1[x[ib].qh[2]]), (uint64_t)(table_b2b_1[x[ib].qh[3]])}; - - vector signed char qh0 = (vector signed char)aux64x2_0; - vector signed char qh1 = (vector signed char)aux64x2_1; - - vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); - - vector signed char q5x0 = vec_sub(vec_and (qxs, lowMask), qh0); - vector signed char q5x1 = vec_sub(vec_sr(qxs, v4), qh1); - - vector signed char q8y0 = vec_xl( 0, y[ib].qs); - vector signed char q8y1 = vec_xl( 16, y[ib].qs); - - vector signed short qv0 = vec_add(vec_mule(q5x0, q8y0), vec_mulo(q5x0, q8y0)); - vector signed short qv1 = vec_add(vec_mule(q5x1, q8y1), vec_mulo(q5x1, q8y1)); - - qv0 = vec_add(qv0, qv1); - - vector signed int vsumi0 = vec_add(vec_unpackh(qv0), vec_unpackl(qv0)); - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - } - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - sumf = vec_extract(vsumf0, 0); - -#elif defined(__loongarch_asx) - // Initialize accumulator with zeros - __m256 acc = (__m256)__lasx_xvldi(0); - - // Main loop - for (; ib < nb; ++ib) { - /* Compute combined scale for the block */ - const __m256 d = __lasx_xvreplfr2vr_s(LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d)); //FIXME - - __m256i qx = bytes_from_nibbles_32(x[ib].qs); - __m256i bxhi = bytes_from_bits_32(x[ib].qh); - bxhi = __lasx_xvandn_v(bxhi, __lasx_xvreplgr2vr_b((char)0xF0)); - qx = __lasx_xvor_v(qx, bxhi); - - __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0); - - const __m256 q = mul_sum_i8_pairs_float(qx, qy); - - /* Multiply q with scale and accumulate */ - acc = __lasx_xvfmadd_s(d, q, acc); - } - - sumf = hsum_float_8(acc); -#endif - for (; ib < nb; ++ib) { - uint32_t qh; - memcpy(&qh, x[ib].qh, sizeof(qh)); - - int sumi0 = 0; - int sumi1 = 0; - - for (int j = 0; j < qk/2; ++j) { - const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; - const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12)); - - const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16); - const int32_t x1 = (int8_t)(((x[ib].qs[j] >> 4) | xh_1) - 16); - - sumi0 += (x0 * y[ib].qs[j]); - sumi1 += (x1 * y[ib].qs[j + qk/2]); - } - - int sumi = sumi0 + sumi1; - sumf += (LM_GGML_FP16_TO_FP32(x[ib].d)*LM_GGML_FP16_TO_FP32(y[ib].d)) * sumi; - } - - *s = sumf; -} - -void lm_ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - const int qk = QK8_1; - const int nb = n / qk; - - int ib = 0; - float sumf = 0; - - assert(n % qk == 0); - assert(qk == QK5_1); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_q5_1 * restrict x = vx; - const block_q8_1 * restrict y = vy; - -#if defined(__ARM_NEON) - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t sumv1 = vdupq_n_f32(0.0f); - - float summs0 = 0.0f; - float summs1 = 0.0f; - - uint32_t qh0; - uint32_t qh1; - - uint64_t tmp0[4]; - uint64_t tmp1[4]; - - for (; ib + 1 < nb; ib += 2) { - const block_q5_1 * restrict x0 = &x[ib]; - const block_q5_1 * restrict x1 = &x[ib + 1]; - const block_q8_1 * restrict y0 = &y[ib]; - const block_q8_1 * restrict y1 = &y[ib + 1]; - - const uint8x16_t m4b = vdupq_n_u8(0x0F); - - summs0 += LM_GGML_FP16_TO_FP32(x0->m) * LM_GGML_FP16_TO_FP32(y0->s); - summs1 += LM_GGML_FP16_TO_FP32(x1->m) * LM_GGML_FP16_TO_FP32(y1->s); - - // extract the 5th bit via lookup table ((b) << 4) - memcpy(&qh0, x0->qh, sizeof(qh0)); - memcpy(&qh1, x1->qh, sizeof(qh1)); - - tmp0[0] = table_b2b_0[(qh0 >> 0) & 0xFF]; - tmp0[1] = table_b2b_0[(qh0 >> 8) & 0xFF]; - tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF]; - tmp0[3] = table_b2b_0[(qh0 >> 24) ]; - - tmp1[0] = table_b2b_0[(qh1 >> 0) & 0xFF]; - tmp1[1] = table_b2b_0[(qh1 >> 8) & 0xFF]; - tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF]; - tmp1[3] = table_b2b_0[(qh1 >> 24) ]; - - const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0)); - const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2)); - const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0)); - const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2)); - - const uint8x16_t v0_0 = vld1q_u8(x0->qs); - const uint8x16_t v0_1 = vld1q_u8(x1->qs); - - // 4-bit -> 8-bit - const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); - const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - - // add high bit - const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0); - const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0); - const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1); - const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1); - - // load y - const int8x16_t v1_0l = vld1q_s8(y0->qs); - const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); - const int8x16_t v1_1l = vld1q_s8(y1->qs); - const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( - lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), - lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( - lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), - lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d)); - } - - sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1; -#elif defined(__wasm_simd128__) - v128_t sumv = wasm_f32x4_splat(0.0f); - - float summs = 0.0f; - - uint32_t qh; - uint64_t tmp[4]; - - // TODO: check if unrolling this is better - for (; ib < nb; ++ib) { - const block_q5_1 * restrict x0 = &x[ib]; - const block_q8_1 * restrict y0 = &y[ib]; - - summs += LM_GGML_FP16_TO_FP32(x0->m) * LM_GGML_FP16_TO_FP32(y0->s); - - const v128_t m4b = wasm_i8x16_splat(0x0F); - - // extract the 5th bit - memcpy(&qh, x0->qh, sizeof(qh)); - - tmp[0] = table_b2b_0[(qh >> 0) & 0xFF]; - tmp[1] = table_b2b_0[(qh >> 8) & 0xFF]; - tmp[2] = table_b2b_0[(qh >> 16) & 0xFF]; - tmp[3] = table_b2b_0[(qh >> 24) ]; - - const v128_t qhl = wasm_v128_load(tmp + 0); - const v128_t qhh = wasm_v128_load(tmp + 2); - - const v128_t v0 = wasm_v128_load(x0->qs); - - // 4-bit -> 8-bit - const v128_t v0l = wasm_v128_and (v0, m4b); - const v128_t v0h = wasm_u8x16_shr(v0, 4); - - // add high bit - const v128_t v0lf = wasm_v128_or(v0l, qhl); - const v128_t v0hf = wasm_v128_or(v0h, qhh); - - // load y - const v128_t v1l = wasm_v128_load(y0->qs); - const v128_t v1h = wasm_v128_load(y0->qs + 16); - - // int8x16 -> int16x8 - const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf); - const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf); - const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf); - const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf); - - const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l); - const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l); - const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h); - const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h); - - // dot product - sumv = wasm_f32x4_add(sumv, - wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add( - wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll), - wasm_i32x4_dot_i16x8(v0lfh, v1lh)), - wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl), - wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), - wasm_f32x4_splat(LM_GGML_FP16_TO_FP32(x0->d) * LM_GGML_FP16_TO_FP32(y0->d)))); - } - - sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + - wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs; -#elif defined(__AVX2__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - - float summs = 0.0f; - - // Main loop - for (; ib < nb; ++ib) { - const __m256 dx = _mm256_set1_ps(LM_GGML_FP16_TO_FP32(x[ib].d)); - - summs += LM_GGML_FP16_TO_FP32(x[ib].m) * LM_GGML_FP16_TO_FP32(y[ib].s); - - __m256i qx = bytes_from_nibbles_32(x[ib].qs); - __m256i bxhi = bytes_from_bits_32(x[ib].qh); - bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10)); - qx = _mm256_or_si256(qx, bxhi); - - const __m256 dy = _mm256_set1_ps(LM_GGML_FP16_TO_FP32(y[ib].d)); - const __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs); - - const __m256 q = mul_sum_us8_pairs_float(qx, qy); - - acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc); - } - - sumf = hsum_float_8(acc) + summs; -#elif defined(__AVX__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - __m128i mask = _mm_set1_epi8(0x10); - - float summs = 0.0f; - - // Main loop - for (; ib < nb; ++ib) { - const __m256 dx = _mm256_set1_ps(LM_GGML_FP16_TO_FP32(x[ib].d)); - - summs += LM_GGML_FP16_TO_FP32(x[ib].m) * LM_GGML_FP16_TO_FP32(y[ib].s); - - __m256i bx_0 = bytes_from_nibbles_32(x[ib].qs); - const __m256i bxhi = bytes_from_bits_32(x[ib].qh); - __m128i bxhil = _mm256_castsi256_si128(bxhi); - __m128i bxhih = _mm256_extractf128_si256(bxhi, 1); - bxhil = _mm_and_si128(bxhil, mask); - bxhih = _mm_and_si128(bxhih, mask); - __m128i bxl = _mm256_castsi256_si128(bx_0); - __m128i bxh = _mm256_extractf128_si256(bx_0, 1); - bxl = _mm_or_si128(bxl, bxhil); - bxh = _mm_or_si128(bxh, bxhih); - bx_0 = MM256_SET_M128I(bxh, bxl); - - const __m256 dy = _mm256_set1_ps(LM_GGML_FP16_TO_FP32(y[ib].d)); - const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[ib].qs); - - const __m256 q = mul_sum_us8_pairs_float(bx_0, by_0); - - acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc); - } - - sumf = hsum_float_8(acc) + summs; -#elif defined(__riscv_v_intrinsic) - uint32_t qh; - - size_t vl = __riscv_vsetvl_e8m1(qk/2); - - // temporary registers for shift operations - vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl); - vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl); - - for (; ib < nb; ++ib) { - memcpy(&qh, x[ib].qh, sizeof(uint32_t)); - - // load qh - vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl); - - // ((qh >> (j + 0)) << 4) & 0x10; - vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl); - vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl); - vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl); - - // ((qh >> (j + 12)) ) & 0x10; - vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl); - vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl); - - // narrowing - vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl); - vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl); - - vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl); - vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl); - - // load - vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl); - - vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl); - vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl); - - vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); - vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); - - vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl); - vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl); - - vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); - vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); - - vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); - vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); - - vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - - vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); - - int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); - - sumf += (LM_GGML_FP16_TO_FP32(x[ib].d)*LM_GGML_FP16_TO_FP32(y[ib].d))*sumi + LM_GGML_FP16_TO_FP32(x[ib].m)*LM_GGML_FP16_TO_FP32(y[ib].s); - } - -#elif defined(__POWER9_VECTOR__) - const vector signed char lowMask = vec_splats((signed char)0xF); - const vector signed int v0 = vec_splats((int32_t)0); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - - vector float vsumf0 = vec_splats(0.0f); - -#pragma GCC unroll 4 - for (; ib < nb; ++ib) { - __builtin_prefetch(x[ib].qs, 0, 1); - __builtin_prefetch(y[ib].qs, 0, 1); - - vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[ib].d)); - vector float vyd = vec_splats(LM_GGML_FP16_TO_FP32(y[ib].d)); - vector float vd = vec_mul(vxd, vyd); - - vector float vxmin = vec_splats(LM_GGML_FP16_TO_FP32(x[ib].m)); - vector float vys = {LM_GGML_FP16_TO_FP32(y[ib].s), 0.f, 0.f, 0.f}; - vsumf0 = vec_madd(vxmin, vys, vsumf0); - - vector unsigned long long aux64x2_0 = {(uint64_t)(table_b2b_0[x[ib].qh[0]]), (uint64_t)(table_b2b_0[x[ib].qh[1]])}; - vector unsigned long long aux64x2_1 = {(uint64_t)(table_b2b_0[x[ib].qh[2]]), (uint64_t)(table_b2b_0[x[ib].qh[3]])}; - - vector signed char qh0 = (vector signed char)aux64x2_0; - vector signed char qh1 = (vector signed char)aux64x2_1; - - vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); - - vector unsigned char q5x0 = (vector unsigned char)vec_or(vec_and(qxs, lowMask), qh0); - vector unsigned char q5x1 = (vector unsigned char)vec_or(vec_sr(qxs, v4), qh1); - - vector signed char q8y0 = vec_xl( 0, y[ib].qs); - vector signed char q8y1 = vec_xl( 16, y[ib].qs); - - vector signed int vsumi0 = v0; - - vsumi0 = vec_msum(q8y0, q5x0, vsumi0); - vsumi0 = vec_msum(q8y1, q5x1, vsumi0); - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - } - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - sumf = vec_extract(vsumf0, 0); - -#elif defined(__loongarch_asx) - // Initialize accumulator with zeros - __m256 acc = (__m256)__lasx_xvldi(0); - - float summs = 0.0f; - - // Main loop - for (; ib < nb; ++ib) { - const __m256 dx = __lasx_xvreplfr2vr_s(LM_GGML_FP16_TO_FP32(x[ib].d)); - - summs += LM_GGML_FP16_TO_FP32(x[ib].m) * LM_GGML_FP16_TO_FP32(y[ib].s); - - __m256i qx = bytes_from_nibbles_32(x[ib].qs); - __m256i bxhi = bytes_from_bits_32(x[ib].qh); - bxhi = __lasx_xvand_v(bxhi, __lasx_xvreplgr2vr_b(0x10)); - qx = __lasx_xvor_v(qx, bxhi); - - const __m256 dy = __lasx_xvreplfr2vr_s(LM_GGML_FP16_TO_FP32(y[ib].d)); - const __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0); - - const __m256 q = mul_sum_us8_pairs_float(qx, qy); - - acc = __lasx_xvfmadd_s(q, __lasx_xvfmul_s(dx, dy), acc); - } - - sumf = hsum_float_8(acc) + summs; -#endif - for (; ib < nb; ++ib) { - uint32_t qh; - memcpy(&qh, x[ib].qh, sizeof(qh)); - - int sumi0 = 0; - int sumi1 = 0; - - for (int j = 0; j < qk/2; ++j) { - const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; - const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; - - const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0; - const int32_t x1 = (x[ib].qs[j] >> 4) | xh_1; - - sumi0 += (x0 * y[ib].qs[j]); - sumi1 += (x1 * y[ib].qs[j + qk/2]); - } - - int sumi = sumi0 + sumi1; - sumf += (LM_GGML_FP16_TO_FP32(x[ib].d)*LM_GGML_FP16_TO_FP32(y[ib].d))*sumi + LM_GGML_FP16_TO_FP32(x[ib].m)*LM_GGML_FP16_TO_FP32(y[ib].s); - } - - *s = sumf; -} - -void lm_ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - const int qk = QK8_0; - const int nb = n / qk; - - assert(n % qk == 0); -#if defined(__ARM_FEATURE_MATMUL_INT8) - assert((nrc == 2) || (nrc == 1)); -#else - assert(nrc == 1); -#endif - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_q8_0 * restrict x = vx; - const block_q8_0 * restrict y = vy; - -#if defined(__ARM_FEATURE_MATMUL_INT8) - if (nrc == 2) { - const block_q8_0 * restrict vx0 = vx; - const block_q8_0 * restrict vx1 = (const block_q8_0 *) ((const uint8_t*)vx + bx); - const block_q8_0 * restrict vy0 = vy; - const block_q8_0 * restrict vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by); - - float32x4_t sumv0 = vdupq_n_f32(0.0f); - - for (int i = 0; i < nb; i++) { - const block_q8_0 * restrict b_x0 = &vx0[i]; - const block_q8_0 * restrict b_y0 = &vy0[i]; - - const block_q8_0 * restrict b_x1 = &vx1[i]; - const block_q8_0 * restrict b_y1 = &vy1[i]; - - const int8x16_t x0_l = vld1q_s8(b_x0->qs); - const int8x16_t x0_h = vld1q_s8(b_x0->qs + 16); - const int8x16_t x1_l = vld1q_s8(b_x1->qs); - const int8x16_t x1_h = vld1q_s8(b_x1->qs + 16); - - // load y - const int8x16_t y0_l = vld1q_s8(b_y0->qs); - const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16); - const int8x16_t y1_l = vld1q_s8(b_y1->qs); - const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16); - - float32_t _scale[4] = {LM_GGML_FP16_TO_FP32(b_x0->d)*LM_GGML_FP16_TO_FP32(b_y0->d), - LM_GGML_FP16_TO_FP32(b_x0->d)*LM_GGML_FP16_TO_FP32(b_y1->d), - LM_GGML_FP16_TO_FP32(b_x1->d)*LM_GGML_FP16_TO_FP32(b_y0->d), - LM_GGML_FP16_TO_FP32(b_x1->d)*LM_GGML_FP16_TO_FP32(b_y1->d)}; - float32x4_t scale = vld1q_f32(_scale); - - int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); - int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); - - int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); - int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); - - int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); - int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); - - int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); - int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); - - sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)), - l1, r1)), l2, r2)), l3, r3))), scale); - } - float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2); - float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1); - - vst1_f32(s, vget_low_f32(sumv2)); - vst1_f32(s + bs, vget_high_f32(sumv2)); - return; - } -#endif - - int ib = 0; - float sumf = 0; - -#if defined(__ARM_FEATURE_SVE) - svfloat32_t sumv0 = svdup_n_f32(0.0f); - svfloat32_t sumv1 = svdup_n_f32(0.0f); - - const int vector_length = lm_ggml_cpu_get_sve_cnt()*8; - - //VLA Implemenation for SVE - switch (vector_length) { - case 128: - { - // predicate for activating lanes for 16 Int8 elements - const svbool_t ph16 = svptrue_pat_b8 (SV_VL16); - const svbool_t pl16 = svptrue_pat_b32(SV_VL4); - - for (; ib + 1 < nb; ib += 2) { - const block_q8_0 * restrict x0 = &x[ib + 0]; - const block_q8_0 * restrict x1 = &x[ib + 1]; - const block_q8_0 * restrict y0 = &y[ib + 0]; - const block_q8_0 * restrict y1 = &y[ib + 1]; - - // load x - const svint8_t qx0_0 = svld1_s8(ph16, x0->qs); - const svint8_t qx0_1 = svld1_s8(ph16, x0->qs+16); - const svint8_t qx1_0 = svld1_s8(ph16, x1->qs); - const svint8_t qx1_1 = svld1_s8(ph16, x1->qs+16); - - // load y - const svint8_t qy0_0 = svld1_s8(ph16, y0->qs); - const svint8_t qy0_1 = svld1_s8(ph16, y0->qs+16); - const svint8_t qy1_0 = svld1_s8(ph16, y1->qs); - const svint8_t qy1_1 = svld1_s8(ph16, y1->qs+16); - - sumv0 = svmla_n_f32_x(pl16, sumv0, svcvt_f32_s32_x(pl16, svadd_x(pl16, - svdot_s32(svdup_n_s32(0), qx0_0, qy0_0), - svdot_s32(svdup_n_s32(0), qx0_1, qy0_1))), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d)); - sumv1 = svmla_n_f32_x(pl16, sumv1, svcvt_f32_s32_x(pl16, svadd_x(pl16, - svdot_s32(svdup_n_s32(0), qx1_0, qy1_0), - svdot_s32(svdup_n_s32(0), qx1_1, qy1_1))), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d)); - } - - sumf = svaddv_f32(pl16, svadd_f32_x(pl16, sumv0, sumv1)); - } break; - case 256: - { - //printf("sve256"); - for (; ib + 1 < nb; ib += 2) { - const block_q8_0 * restrict x0 = &x[ib + 0]; - const block_q8_0 * restrict x1 = &x[ib + 1]; - const block_q8_0 * restrict y0 = &y[ib + 0]; - const block_q8_0 * restrict y1 = &y[ib + 1]; - - // load x - const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs); - const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs); - - // load y - const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); - const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); - - sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), - svdot_s32(svdup_n_s32(0), qx0, qy0)), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d)); - sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), - svdot_s32(svdup_n_s32(0), qx1, qy1)), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d)); - } - - sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); - } break; - case 512: - { - // predicate for activating high 256 bit - const svbool_t ph32 = svptrue_pat_b8(SV_VL32); - // predicate for activating low 256 bit - const svbool_t pl32 = svnot_b_z(svptrue_b8(), ph32); - - // predicate for activating high lanes for 8 float32 elements - const svbool_t ph8 = svptrue_pat_b32(SV_VL8); - // predicate for activating low lanes for 8 float32 elements - const svbool_t pl8 = svnot_b_z(svptrue_b32(), ph8); - - svfloat32_t sumv00 = svdup_n_f32(0.0f); - - for (; ib + 1 < nb; ib += 2) { - const block_q8_0 * restrict x0 = &x[ib + 0]; - const block_q8_0 * restrict x1 = &x[ib + 1]; - const block_q8_0 * restrict y0 = &y[ib + 0]; - const block_q8_0 * restrict y1 = &y[ib + 1]; - - //load 32 int8_t in first half of vector and put another 32 int8_t in second vector lower bits - // and add them to make one 64 element vector - // load x - const svint8_t qx_32 = svld1_s8(ph32, x0->qs); - svint8_t qx_64 = svld1_s8(pl32, x0->qs + 2); - - qx_64 = svadd_s8_x(svptrue_b8(), qx_32, qx_64); - - // load y - const svint8_t qy_32 = svld1_s8(ph32, y0->qs); - svint8_t qy_64 = svld1_s8(pl32, y0->qs + 2); - - qy_64 = svadd_s8_x(svptrue_b8(), qy_32, qy_64); - - // scale creation - const float32_t deq1 = LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d); - const float32_t deq2 = LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d); - - // duplicate deq1 in first half of vector and deq2 in second half of vector - const svfloat32_t temp = svdup_f32_m(svdup_f32_z(ph8, deq1), pl8, deq2); - - const svfloat32_t sumvt = svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx_64, qy_64)); - - sumv00 = svmla_f32_m(svptrue_b32(), sumv00, sumvt, temp); - } - - sumf = svaddv_f32(svptrue_b32(), sumv00); - break; - } - default: - assert(false && "Unsupported vector length"); - break; - } -#elif defined(__ARM_NEON) - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t sumv1 = vdupq_n_f32(0.0f); - - for (; ib + 1 < nb; ib += 2) { - const block_q8_0 * restrict x0 = &x[ib + 0]; - const block_q8_0 * restrict x1 = &x[ib + 1]; - const block_q8_0 * restrict y0 = &y[ib + 0]; - const block_q8_0 * restrict y1 = &y[ib + 1]; - - const int8x16_t x0_0 = vld1q_s8(x0->qs); - const int8x16_t x0_1 = vld1q_s8(x0->qs + 16); - const int8x16_t x1_0 = vld1q_s8(x1->qs); - const int8x16_t x1_1 = vld1q_s8(x1->qs + 16); - - // load y - const int8x16_t y0_0 = vld1q_s8(y0->qs); - const int8x16_t y0_1 = vld1q_s8(y0->qs + 16); - const int8x16_t y1_0 = vld1q_s8(y1->qs); - const int8x16_t y1_1 = vld1q_s8(y1->qs + 16); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( - lm_ggml_vdotq_s32(vdupq_n_s32(0), x0_0, y0_0), - lm_ggml_vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d)); - - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( - lm_ggml_vdotq_s32(vdupq_n_s32(0), x1_0, y1_0), - lm_ggml_vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d)); - } - - sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); -#elif defined(__AVX2__) || defined(__AVX__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - - // Main loop - for (; ib < nb; ++ib) { - // Compute combined scale for the block - const __m256 d = _mm256_set1_ps(LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d)); - __m256i qx = _mm256_loadu_si256((const __m256i *)x[ib].qs); - __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs); - - const __m256 q = mul_sum_i8_pairs_float(qx, qy); - - // Multiply q with scale and accumulate -#if defined(__AVX2__) - acc = _mm256_fmadd_ps( d, q, acc ); -#else - acc = _mm256_add_ps( _mm256_mul_ps( d, q ), acc ); -#endif - } - - sumf = hsum_float_8(acc); -#elif defined(__riscv_v_intrinsic) - size_t vl = __riscv_vsetvl_e8m1(qk); - - for (; ib < nb; ++ib) { - // load elements - vint8m1_t bx_0 = __riscv_vle8_v_i8m1(x[ib].qs, vl); - vint8m1_t by_0 = __riscv_vle8_v_i8m1(y[ib].qs, vl); - - vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx_0, by_0, vl); - - vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl); - vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl); - - int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum); - - sumf += sumi*(LM_GGML_FP16_TO_FP32(x[ib].d)*LM_GGML_FP16_TO_FP32(y[ib].d)); - } -#elif defined(__POWER9_VECTOR__) - const vector signed int v0 = vec_splats((int32_t)0); - vector float vsumf0 = vec_splats(0.0f); - -#pragma GCC unroll 8 - for (; ib < nb; ++ib) { - __builtin_prefetch(x[ib].qs, 0, 1); - __builtin_prefetch(y[ib].qs, 0, 1); - - vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[ib].d)); - vector float vyd = vec_splats(LM_GGML_FP16_TO_FP32(y[ib].d)); - vector float vd = vec_mul(vxd, vyd); - - vector signed char q8x0 = vec_xl( 0, x[ib].qs); - vector signed char q8x1 = vec_xl(16, x[ib].qs); - vector signed char q8y0 = vec_xl( 0, y[ib].qs); - vector signed char q8y1 = vec_xl(16, y[ib].qs); - - vector signed short qv0 = vec_mule(q8x0, q8y0); - vector signed short qv1 = vec_mulo(q8x0, q8y0); - vector signed short qv2 = vec_mule(q8x1, q8y1); - vector signed short qv3 = vec_mulo(q8x1, q8y1); - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - - vsumi0 = vec_sum4s(qv0, vsumi0); - vsumi1 = vec_sum4s(qv1, vsumi1); - vsumi0 = vec_sum4s(qv2, vsumi0); - vsumi1 = vec_sum4s(qv3, vsumi1); - - vsumi0 = vec_add(vsumi0, vsumi1); - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - } - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - sumf = vec_extract(vsumf0, 0); - -#elif defined(__loongarch_asx) - // Initialize accumulator with zeros - __m256 acc = (__m256)__lasx_xvldi(0); - - // Main loop - for (; ib < nb; ++ib) { - // Compute combined scale for the block - const __m256 d = __lasx_xvreplfr2vr_s(LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d)); - __m256i qx = __lasx_xvld((const __m256i *)x[ib].qs, 0); - __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0); - - const __m256 q = mul_sum_i8_pairs_float(qx, qy); - - // Multiply q with scale and accumulate - acc = __lasx_xvfmadd_s( d, q, acc ); - } - - sumf = hsum_float_8(acc); -#endif - for (; ib < nb; ++ib) { - int sumi = 0; - - for (int j = 0; j < qk; j++) { - sumi += x[ib].qs[j]*y[ib].qs[j]; - } - - sumf += sumi*(LM_GGML_FP16_TO_FP32(x[ib].d)*LM_GGML_FP16_TO_FP32(y[ib].d)); - } - - *s = sumf; -} - -void lm_ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_tq1_0 * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - -#if defined(__ARM_NEON) - float sumf = 0.0f; - - uint8_t k_shift[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27}; - - const uint8x16_t shift = vld1q_u8(k_shift); - - for (int i = 0; i < nb; ++i) { -#if defined(__ARM_FEATURE_DOTPROD) - int32x4_t sumi0 = vdupq_n_s32(0); - int32x4_t sumi1 = vdupq_n_s32(0); -#else - int16x8_t sumi0 = vdupq_n_s16(0); - int16x8_t sumi1 = vdupq_n_s16(0); -#endif - - // first 32 bytes of 5 elements - { - uint8x16_t qx0 = vld1q_u8(x[i].qs + 0); - uint8x16_t qx1 = vld1q_u8(x[i].qs + 16); - uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(3)); - uint8x16_t qx3 = vmulq_u8(qx1, vdupq_n_u8(3)); - uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(9)); - uint8x16_t qx5 = vmulq_u8(qx1, vdupq_n_u8(9)); - uint8x16_t qx6 = vmulq_u8(qx0, vdupq_n_u8(27)); - uint8x16_t qx7 = vmulq_u8(qx1, vdupq_n_u8(27)); - uint8x16_t qx8 = vmulq_u8(qx0, vdupq_n_u8(81)); - uint8x16_t qx9 = vmulq_u8(qx1, vdupq_n_u8(81)); - - // multiply by 3 and keep the 2 bits above 8 bits - int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6)); - int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6)); - int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6)); - int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6)); - int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6)); - int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6)); - int8x16_t sqx6 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx6, vshrq_n_u8(qx6, 1)), 6)); - int8x16_t sqx7 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx7, vshrq_n_u8(qx7, 1)), 6)); - int8x16_t sqx8 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx8, vshrq_n_u8(qx8, 1)), 6)); - int8x16_t sqx9 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx9, vshrq_n_u8(qx9, 1)), 6)); - - const int8x16_t qy0 = vld1q_s8(y[i].qs + 0); - const int8x16_t qy1 = vld1q_s8(y[i].qs + 16); - const int8x16_t qy2 = vld1q_s8(y[i].qs + 32); - const int8x16_t qy3 = vld1q_s8(y[i].qs + 48); - const int8x16_t qy4 = vld1q_s8(y[i].qs + 64); - const int8x16_t qy5 = vld1q_s8(y[i].qs + 80); - const int8x16_t qy6 = vld1q_s8(y[i].qs + 96); - const int8x16_t qy7 = vld1q_s8(y[i].qs + 112); - const int8x16_t qy8 = vld1q_s8(y[i].qs + 128); - const int8x16_t qy9 = vld1q_s8(y[i].qs + 144); - -#if defined(__ARM_FEATURE_DOTPROD) - sumi0 = vdotq_s32(sumi0, sqx0, qy0); - sumi1 = vdotq_s32(sumi1, sqx1, qy1); - sumi0 = vdotq_s32(sumi0, sqx2, qy2); - sumi1 = vdotq_s32(sumi1, sqx3, qy3); - sumi0 = vdotq_s32(sumi0, sqx4, qy4); - sumi1 = vdotq_s32(sumi1, sqx5, qy5); - sumi0 = vdotq_s32(sumi0, sqx6, qy6); - sumi1 = vdotq_s32(sumi1, sqx7, qy7); - sumi0 = vdotq_s32(sumi0, sqx8, qy8); - sumi1 = vdotq_s32(sumi1, sqx9, qy9); -#else - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx8), vget_low_s8(qy8)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx8), vget_high_s8(qy8)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx9), vget_low_s8(qy9)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx9), vget_high_s8(qy9)); -#endif - } - - // last 16 bytes of 5-element, along with the 4 bytes of 4 elements - { - uint8x16_t qx0 = vld1q_u8(x[i].qs + 32); - uint8x16_t qx1 = vmulq_u8(qx0, vdupq_n_u8(3)); - uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(9)); - uint8x16_t qx3 = vmulq_u8(qx0, vdupq_n_u8(27)); - uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(81)); - uint32_t qh; - memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned - uint8x16_t qx5 = vreinterpretq_u8_u32(vdupq_n_u32(qh)); - qx5 = vmulq_u8(qx5, shift); - - // multiply by 3 and keep the 2 bits above 8 bits - int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6)); - int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6)); - int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6)); - int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6)); - int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6)); - int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6)); - - const int8x16_t qy0 = vld1q_s8(y[i].qs + 160); - const int8x16_t qy1 = vld1q_s8(y[i].qs + 176); - const int8x16_t qy2 = vld1q_s8(y[i].qs + 192); - const int8x16_t qy3 = vld1q_s8(y[i].qs + 208); - const int8x16_t qy4 = vld1q_s8(y[i].qs + 224); - const int8x16_t qy5 = vld1q_s8(y[i].qs + 240); - -#if defined(__ARM_FEATURE_DOTPROD) - sumi0 = vdotq_s32(sumi0, sqx0, qy0); - sumi1 = vdotq_s32(sumi1, sqx1, qy1); - sumi0 = vdotq_s32(sumi0, sqx2, qy2); - sumi1 = vdotq_s32(sumi1, sqx3, qy3); - sumi0 = vdotq_s32(sumi0, sqx4, qy4); - sumi1 = vdotq_s32(sumi1, sqx5, qy5); -#else - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5)); -#endif - } - - const int16x8_t ysum0 = vld1q_s16(y[i].bsums); - const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8); - - const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; - -#if defined(__ARM_FEATURE_DOTPROD) - sumi0 = vaddq_s32(sumi0, sumi1); - sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1))); - - sumf += d * (float) vaddvq_s32(sumi0); -#else - sumi0 = vaddq_s16(sumi0, sumi1); - sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1)); - - sumf += d * (float) vaddlvq_s16(sumi0); -#endif - } - - *s = sumf; - -#elif defined(__AVX2__) - __m256 sumf = _mm256_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - // 16-bit sums - __m256i sumi0 = _mm256_setzero_si256(); - __m256i sumi1 = _mm256_setzero_si256(); - __m256i sumi2 = _mm256_setzero_si256(); - - // first 32 bytes of 5 elements - { - __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs)); - // 8-bit multiplies with shifts, masks and adds - __m256i qx1 = _mm256_add_epi8(qx0, _mm256_add_epi8(qx0, qx0)); // 1 * 3 - __m256i qx2 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx0, 3), _mm256_set1_epi8(-8)), qx0); // 1 * 9 - __m256i qx3 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx1, 3), _mm256_set1_epi8(-8)), qx1); // 3 * 9 - __m256i qx4 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx2, 3), _mm256_set1_epi8(-8)), qx2); // 9 * 9 - - // TODO: can _mm256_mulhi_epu16 be faster even if 16-bits? - - // Cancel the +1 from avg so that it behaves like a halving add - qx0 = _mm256_subs_epu8(qx0, _mm256_set1_epi8(1)); - qx1 = _mm256_subs_epu8(qx1, _mm256_set1_epi8(1)); - qx2 = _mm256_subs_epu8(qx2, _mm256_set1_epi8(1)); - qx3 = _mm256_subs_epu8(qx3, _mm256_set1_epi8(1)); - qx4 = _mm256_subs_epu8(qx4, _mm256_set1_epi8(1)); - // Multiply by 3 and get the top 2 bits - qx0 = _mm256_avg_epu8(qx0, _mm256_avg_epu8(qx0, _mm256_setzero_si256())); - qx1 = _mm256_avg_epu8(qx1, _mm256_avg_epu8(qx1, _mm256_setzero_si256())); - qx2 = _mm256_avg_epu8(qx2, _mm256_avg_epu8(qx2, _mm256_setzero_si256())); - qx3 = _mm256_avg_epu8(qx3, _mm256_avg_epu8(qx3, _mm256_setzero_si256())); - qx4 = _mm256_avg_epu8(qx4, _mm256_avg_epu8(qx4, _mm256_setzero_si256())); - qx0 = _mm256_and_si256(_mm256_srli_epi16(qx0, 6), _mm256_set1_epi8(3)); - qx1 = _mm256_and_si256(_mm256_srli_epi16(qx1, 6), _mm256_set1_epi8(3)); - qx2 = _mm256_and_si256(_mm256_srli_epi16(qx2, 6), _mm256_set1_epi8(3)); - qx3 = _mm256_and_si256(_mm256_srli_epi16(qx3, 6), _mm256_set1_epi8(3)); - qx4 = _mm256_and_si256(_mm256_srli_epi16(qx4, 6), _mm256_set1_epi8(3)); - - const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 0)); - const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 32)); - const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 64)); - const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 96)); - const __m256i qy4 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 128)); - - qx0 = _mm256_maddubs_epi16(qx0, qy0); - qx1 = _mm256_maddubs_epi16(qx1, qy1); - qx2 = _mm256_maddubs_epi16(qx2, qy2); - qx3 = _mm256_maddubs_epi16(qx3, qy3); - qx4 = _mm256_maddubs_epi16(qx4, qy4); - - sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1)); - sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3)); - sumi2 = _mm256_add_epi16(sumi2, qx4); - } - - // last 16 bytes of 5-element, along with the 4 bytes of 4 elements - { - __m128i qx0 = _mm_loadu_si128((const __m128i *) (x[i].qs + 32)); - uint32_t qh; - memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned - __m256i qx5_l = _mm256_cvtepu8_epi16(_mm_set1_epi32(qh)); - __m128i qx1 = _mm_add_epi8(qx0, _mm_add_epi8(qx0, qx0)); // 1 * 3 - __m128i qx2 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx0, 3), _mm_set1_epi8(-8)), qx0); // 1 * 9 - __m128i qx3 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx1, 3), _mm_set1_epi8(-8)), qx1); // 3 * 9 - __m128i qx4 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx2, 3), _mm_set1_epi8(-8)), qx2); // 9 * 9 - __m256i qx01 = MM256_SET_M128I(qx1, qx0); - __m256i qx23 = MM256_SET_M128I(qx3, qx2); - - // avx2 does not have 8-bit multiplies, so 16-bit it is. - qx5_l = _mm256_mullo_epi16(qx5_l, _mm256_set_epi16(27, 27, 27, 27, 9, 9, 9, 9, 3, 3, 3, 3, 1, 1, 1, 1)); - qx5_l = _mm256_and_si256(qx5_l, _mm256_set1_epi16(0xFF)); - __m128i qx5 = _mm_packus_epi16(_mm256_castsi256_si128(qx5_l), _mm256_extracti128_si256(qx5_l, 1)); - - __m256i qx45 = MM256_SET_M128I(qx5, qx4); - - // Cancel the +1 from avg so that it behaves like a halving add - qx01 = _mm256_subs_epu8(qx01, _mm256_set1_epi8(1)); - qx23 = _mm256_subs_epu8(qx23, _mm256_set1_epi8(1)); - qx45 = _mm256_subs_epu8(qx45, _mm256_set1_epi8(1)); - // Multiply by 3 and get the top 2 bits - qx01 = _mm256_avg_epu8(qx01, _mm256_avg_epu8(qx01, _mm256_setzero_si256())); - qx23 = _mm256_avg_epu8(qx23, _mm256_avg_epu8(qx23, _mm256_setzero_si256())); - qx45 = _mm256_avg_epu8(qx45, _mm256_avg_epu8(qx45, _mm256_setzero_si256())); - qx01 = _mm256_and_si256(_mm256_srli_epi16(qx01, 6), _mm256_set1_epi8(3)); - qx23 = _mm256_and_si256(_mm256_srli_epi16(qx23, 6), _mm256_set1_epi8(3)); - qx45 = _mm256_and_si256(_mm256_srli_epi16(qx45, 6), _mm256_set1_epi8(3)); - - const __m256i qy01 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 160)); - const __m256i qy23 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 192)); - const __m256i qy45 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 224)); - - qx01 = _mm256_maddubs_epi16(qx01, qy01); - qx23 = _mm256_maddubs_epi16(qx23, qy23); - qx45 = _mm256_maddubs_epi16(qx45, qy45); - - sumi0 = _mm256_add_epi16(sumi0, qx01); - sumi1 = _mm256_add_epi16(sumi1, qx23); - sumi2 = _mm256_add_epi16(sumi2, qx45); - } - - const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums); - const __m256 d = _mm256_set1_ps(y[i].d * LM_GGML_FP16_TO_FP32(x[i].d)); - - sumi0 = _mm256_sub_epi16(sumi0, ysum); - sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(sumi1, sumi2)); - sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1)); - - sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf); - } - - *s = hsum_float_8(sumf); - -#else - const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243}; - - float sumf = 0.0f; - - for (int i = 0; i < nb; ++i) { - int sum = 0; - - for (size_t j = 0; j < sizeof(x->qs) - sizeof(x->qs) % 32; j += 32) { - for (size_t l = 0; l < 5; ++l) { - for (size_t m = 0; m < 32; ++m) { - uint8_t q = x[i].qs[j + m] * pow3[l]; - uint16_t xi = ((uint16_t) q * 3) >> 8; - sum += (xi - 1) * y[i].qs[j*5 + l*32 + m]; - } - } - } - for (size_t j = sizeof(x->qs) - sizeof(x->qs) % 32; j < sizeof(x->qs); j += 16) { - for (size_t l = 0; l < 5; ++l) { - for (size_t m = 0; m < 16; ++m) { - uint8_t q = x[i].qs[j + m] * pow3[l]; - uint16_t xi = ((uint16_t) q * 3) >> 8; - sum += (xi - 1) * y[i].qs[j*5 + l*16 + m]; - } - } - } - - for (size_t l = 0; l < 4; ++l) { - for (size_t j = 0; j < sizeof(x->qh); ++j) { - uint8_t q = x[i].qh[j] * pow3[l]; - uint16_t xi = ((uint16_t) q * 3) >> 8; - sum += (xi - 1) * y[i].qs[sizeof(x->qs)*5 + l*sizeof(x->qh) + j]; - } - } - - sumf += (float) sum * (LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d); - } - - *s = sumf; -#endif -} - -void lm_ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_tq2_0 * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - -#if defined(__ARM_NEON) - float sumf = 0.0f; - - const uint8x16_t m3 = vdupq_n_u8(3); - - for (int i = 0; i < nb; ++i) { -#if defined(__ARM_FEATURE_DOTPROD) - int32x4_t sumi0 = vdupq_n_s32(0); - int32x4_t sumi1 = vdupq_n_s32(0); -#else - int16x8_t sumi0 = vdupq_n_s16(0); - int16x8_t sumi1 = vdupq_n_s16(0); -#endif - - for (size_t j = 0; j < sizeof(x->qs); j += 32) { - uint8x16_t qx0 = vld1q_u8(x[i].qs + j); - uint8x16_t qx1 = vld1q_u8(x[i].qs + j + 16); - uint8x16_t qx2 = vshrq_n_u8(qx0, 2); - uint8x16_t qx3 = vshrq_n_u8(qx1, 2); - uint8x16_t qx4 = vshrq_n_u8(qx0, 4); - uint8x16_t qx5 = vshrq_n_u8(qx1, 4); - uint8x16_t qx6 = vshrq_n_u8(qx0, 6); - uint8x16_t qx7 = vshrq_n_u8(qx1, 6); - - int8x16_t sqx0 = vreinterpretq_s8_u8(vandq_u8(qx0, m3)); - int8x16_t sqx1 = vreinterpretq_s8_u8(vandq_u8(qx1, m3)); - int8x16_t sqx2 = vreinterpretq_s8_u8(vandq_u8(qx2, m3)); - int8x16_t sqx3 = vreinterpretq_s8_u8(vandq_u8(qx3, m3)); - int8x16_t sqx4 = vreinterpretq_s8_u8(vandq_u8(qx4, m3)); - int8x16_t sqx5 = vreinterpretq_s8_u8(vandq_u8(qx5, m3)); - int8x16_t sqx6 = vreinterpretq_s8_u8(vandq_u8(qx6, m3)); - int8x16_t sqx7 = vreinterpretq_s8_u8(vandq_u8(qx7, m3)); - - const int8x16_t qy0 = vld1q_s8(y[i].qs + j*4 + 0); - const int8x16_t qy1 = vld1q_s8(y[i].qs + j*4 + 16); - const int8x16_t qy2 = vld1q_s8(y[i].qs + j*4 + 32); - const int8x16_t qy3 = vld1q_s8(y[i].qs + j*4 + 48); - const int8x16_t qy4 = vld1q_s8(y[i].qs + j*4 + 64); - const int8x16_t qy5 = vld1q_s8(y[i].qs + j*4 + 80); - const int8x16_t qy6 = vld1q_s8(y[i].qs + j*4 + 96); - const int8x16_t qy7 = vld1q_s8(y[i].qs + j*4 + 112); - -#if defined(__ARM_FEATURE_DOTPROD) - sumi0 = vdotq_s32(sumi0, sqx0, qy0); - sumi1 = vdotq_s32(sumi1, sqx1, qy1); - sumi0 = vdotq_s32(sumi0, sqx2, qy2); - sumi1 = vdotq_s32(sumi1, sqx3, qy3); - sumi0 = vdotq_s32(sumi0, sqx4, qy4); - sumi1 = vdotq_s32(sumi1, sqx5, qy5); - sumi0 = vdotq_s32(sumi0, sqx6, qy6); - sumi1 = vdotq_s32(sumi1, sqx7, qy7); -#else - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7)); -#endif - } - - const int16x8_t ysum0 = vld1q_s16(y[i].bsums); - const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8); - - const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; - -#if defined(__ARM_FEATURE_DOTPROD) - sumi0 = vaddq_s32(sumi0, sumi1); - sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1))); - - sumf += d * (float) vaddvq_s32(sumi0); -#else - sumi0 = vaddq_s16(sumi0, sumi1); - sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1)); - - sumf += d * (float) vaddlvq_s16(sumi0); -#endif - } - - *s = sumf; - -#elif defined(__AVX2__) - __m256 sumf = _mm256_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - // 16-bit sums, because 256*127 still fits - __m256i sumi0 = _mm256_setzero_si256(); - __m256i sumi1 = _mm256_setzero_si256(); - - for (size_t j = 0; j < sizeof(x->qs); j += 32) { - __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs + j)); - __m256i qx1 = _mm256_srli_epi16(qx0, 2); - __m256i qx2 = _mm256_srli_epi16(qx0, 4); - __m256i qx3 = _mm256_srli_epi16(qx0, 6); - - // 0, 1, 2 (should not be 3) - qx0 = _mm256_and_si256(qx0, _mm256_set1_epi8(3)); - qx1 = _mm256_and_si256(qx1, _mm256_set1_epi8(3)); - qx2 = _mm256_and_si256(qx2, _mm256_set1_epi8(3)); - qx3 = _mm256_and_si256(qx3, _mm256_set1_epi8(3)); - - const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 0)); - const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 32)); - const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 64)); - const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 96)); - - qx0 = _mm256_maddubs_epi16(qx0, qy0); - qx1 = _mm256_maddubs_epi16(qx1, qy1); - qx2 = _mm256_maddubs_epi16(qx2, qy2); - qx3 = _mm256_maddubs_epi16(qx3, qy3); - - sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1)); - sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3)); - } - - const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums); - const __m256 d = _mm256_set1_ps(y[i].d * LM_GGML_FP16_TO_FP32(x[i].d)); - - sumi0 = _mm256_add_epi16(sumi0, sumi1); - sumi0 = _mm256_sub_epi16(sumi0, ysum); - sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1)); - - sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf); - } - - *s = hsum_float_8(sumf); - -#else - float sumf = 0.0f; - - for (int i = 0; i < nb; ++i) { - int32_t sumi = 0; - - for (size_t j = 0; j < sizeof(x->qs); j += 32) { - for (size_t l = 0; l < 4; ++l) { - for (size_t k = 0; k < 32; ++k) { - sumi += y[i].qs[j*4 + l*32 + k] * (((x[i].qs[j + k] >> (l*2)) & 3) - 1); - } - } - } - - const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); - - sumf += (float) sumi * d; - } - - *s = sumf; -#endif -} - -void lm_ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_q2_K * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - -#ifdef __ARM_NEON - const uint8x16_t m3 = vdupq_n_u8(0x3); - const uint8x16_t m4 = vdupq_n_u8(0xF); - - const int32x4_t vzero = vdupq_n_s32(0); - - lm_ggml_int8x16x2_t q2bytes; - uint8_t aux[16]; - - float sum = 0; - - for (int i = 0; i < nb; ++i) { - const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin); - - const uint8_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - const uint8_t * restrict sc = x[i].scales; - - const uint8x16_t mins_and_scales = vld1q_u8(sc); - const uint8x16_t scales = vandq_u8(mins_and_scales, m4); - vst1q_u8(aux, scales); - - const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4); - const lm_ggml_int16x8x2_t q8sums = lm_ggml_vld1q_s16_x2(y[i].bsums); - const lm_ggml_int16x8x2_t mins16 = {{vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}}; - const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])), - vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0]))); - const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])), - vmull_s16(vget_high_s16(mins16.val[1]), vget_high_s16(q8sums.val[1]))); - sum += dmin * vaddvq_s32(vaddq_s32(s0, s1)); - - int isum = 0; - int is = 0; - -// We use this macro instead of a function call because for some reason -// the code runs 2-3% slower, even if the function is declared inline -#define MULTIPLY_ACCUM_WITH_SCALE(index)\ - isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\ - isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)]; - -#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\ - q8bytes = lm_ggml_vld1q_s8_x2(q8); q8 += 32;\ - q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\ - q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\ - MULTIPLY_ACCUM_WITH_SCALE((index)); - - for (int j = 0; j < QK_K/128; ++j) { - const lm_ggml_uint8x16x2_t q2bits = lm_ggml_vld1q_u8_x2(q2); q2 += 32; - - lm_ggml_int8x16x2_t q8bytes = lm_ggml_vld1q_s8_x2(q8); q8 += 32; - q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3)); - q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3)); - - MULTIPLY_ACCUM_WITH_SCALE(0); - - SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2); - SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4); - SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6); - - is += 8; - } - - sum += d * isum; - } - - *s = sum; - -#elif defined __AVX2__ - - const __m256i m3 = _mm256_set1_epi8(3); - const __m128i m4 = _mm_set1_epi8(0xF); - - __m256 acc = _mm256_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin); - - const uint8_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales); - const __m128i scales8 = _mm_and_si128(mins_and_scales, m4); - const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4); - const __m256i mins = _mm256_cvtepi8_epi16(mins8); - const __m256i prod = _mm256_madd_epi16(mins, _mm256_loadu_si256((const __m256i*)y[i].bsums)); - - acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc); - - const __m256i all_scales = _mm256_cvtepi8_epi16(scales8); - const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0); - const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1); - const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)}; - - __m256i sumi = _mm256_setzero_si256(); - - for (int j = 0; j < QK_K/128; ++j) { - - const __m256i q2bits = _mm256_loadu_si256((const __m256i*)q2); q2 += 32; - - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - - const __m256i q2_0 = _mm256_and_si256(q2bits, m3); - const __m256i q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3); - const __m256i q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3); - const __m256i q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3); - - __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0); - __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1); - __m256i p2 = _mm256_maddubs_epi16(q2_2, q8_2); - __m256i p3 = _mm256_maddubs_epi16(q2_3, q8_3); - - p0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(0)), p0); - p1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(1)), p1); - p2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(2)), p2); - p3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(3)), p3); - - p0 = _mm256_add_epi32(p0, p1); - p2 = _mm256_add_epi32(p2, p3); - - sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2)); - } - - acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); - - } - - *s = hsum_float_8(acc); - -#elif defined __AVX__ - - const __m128i m3 = _mm_set1_epi8(0x3); - const __m128i m4 = _mm_set1_epi8(0xF); - const __m128i m2 = _mm_set1_epi8(0x2); - - __m256 acc = _mm256_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - - const float dall = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin); - - const uint8_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - // load mins and scales from block_q2_K.scales[QK_K/16] - const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales); - const __m128i scales16 = _mm_and_si128(mins_and_scales, m4); - const __m128i mins16 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4); - const __m128i mins_0 = _mm_cvtepi8_epi16(mins16); - const __m128i mins_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(mins16, mins16)); - - // summs = y[i].bsums * (x[i].scales >> 4) in 16bits*8*2 to 32bits*4*2 - const __m128i summs_0 = _mm_madd_epi16(mins_0, _mm_loadu_si128((const __m128i*)&y[i].bsums[0])); - const __m128i summs_1 = _mm_madd_epi16(mins_1, _mm_loadu_si128((const __m128i*)&y[i].bsums[8])); - - // sumf += -dmin * summs in 32bits*8 - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(MM256_SET_M128I(summs_1, summs_0))), acc); - - const __m128i scales_0 = _mm_cvtepi8_epi16(scales16); - const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales16, scales16)); - const __m128i scales[2] = { scales_0, scales_1 }; - - __m128i sumi_0 = _mm_setzero_si128(); - __m128i sumi_1 = _mm_setzero_si128(); - - for (int j = 0; j < QK_K/128; ++j) { - - // load Q8 quants int8*16*8 from block_q8_K.qs[QK_K] - const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - - // load 2bits*16*8 from block_q2_K.qs[QK_K/4] - __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16; - const __m128i q2_0 = _mm_and_si128(q2bits, m3); - const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3); - const __m128i q2_4 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3); - const __m128i q2_6 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3); - q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16; - const __m128i q2_1 = _mm_and_si128(q2bits, m3); - const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3); - const __m128i q2_5 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3); - const __m128i q2_7 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3); - - // isuml = q8[l] * ((q2[l] >> shift) & 3) in 8bits*16*8 to 16bits*8*8 - __m128i p0 = _mm_maddubs_epi16(q2_0, q8_0); - __m128i p1 = _mm_maddubs_epi16(q2_1, q8_1); - __m128i p2 = _mm_maddubs_epi16(q2_2, q8_2); - __m128i p3 = _mm_maddubs_epi16(q2_3, q8_3); - __m128i p4 = _mm_maddubs_epi16(q2_4, q8_4); - __m128i p5 = _mm_maddubs_epi16(q2_5, q8_5); - __m128i p6 = _mm_maddubs_epi16(q2_6, q8_6); - __m128i p7 = _mm_maddubs_epi16(q2_7, q8_7); - - // isum += (x[i].scales[is++] & 0xF) * isuml in 16bits*8*8 to 32bits*4*8 - __m128i shuffle = _mm_set1_epi16(0x0100); - p0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p0); - shuffle = _mm_add_epi16(shuffle, m2); - p1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p1); - shuffle = _mm_add_epi16(shuffle, m2); - p2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p2); - shuffle = _mm_add_epi16(shuffle, m2); - p3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p3); - shuffle = _mm_add_epi16(shuffle, m2); - p4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p4); - shuffle = _mm_add_epi16(shuffle, m2); - p5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p5); - shuffle = _mm_add_epi16(shuffle, m2); - p6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p6); - shuffle = _mm_add_epi16(shuffle, m2); - p7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p7); - - p0 = _mm_add_epi32(p0, p1); - p2 = _mm_add_epi32(p2, p3); - p4 = _mm_add_epi32(p4, p5); - p6 = _mm_add_epi32(p6, p7); - - // isum in 32bits*4*2 - sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p0, p2)); - sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p4, p6)); - } - - // sumf += dall * isum - dmin * summs in 32bits - __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dall), _mm256_cvtepi32_ps(sumi)), acc); - } - - *s = hsum_float_8(acc); - -#elif defined __riscv_v_intrinsic - - float sumf = 0; - uint8_t temp_01[32] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; - - for (int i = 0; i < nb; ++i) { - - const uint8_t * q2 = x[i].qs; - const int8_t * q8 = y[i].qs; - const uint8_t * sc = x[i].scales; - - const float dall = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin); - - size_t vl = 16; - - vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl); - vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl); - - vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl); - - vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl); - vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl); - vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); - vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl); - vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); - - sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums); - - vl = 32; - - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl); - - uint8_t is=0; - int isum=0; - - for (int j = 0; j < QK_K/128; ++j) { - // load Q2 - vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl); - - vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl); - vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03 , vl); - vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03 , vl); - vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03 , vl); - - // duplicate scale elements for product - vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0+is, vl), vl); - vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2+is, vl), vl); - vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4+is, vl), vl); - vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6+is, vl), vl); - - vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl)); - vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl)); - vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl)); - vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl)); - - // load Q8 - vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); - vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); - vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8+64, vl); - vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8+96, vl); - - vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl); - vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl); - vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl); - vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl); - - vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl); - - isum += __riscv_vmv_x_s_i32m1_i32(isum1); - - q2+=32; q8+=128; is=8; - - } - - sumf += dall * isum; - - } - - *s = sumf; - -#elif defined(__POWER9_VECTOR__) - const vector signed char lowMask = vec_splats((signed char)0x3); - const vector signed char lowScaleMask = vec_splats((signed char)0xF); - const vector int v0 = vec_splats((int32_t)0); - const vector unsigned char v2 = vec_splats((unsigned char)0x2); - const vector unsigned char v6 = vec_splats((unsigned char)0x6); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - - for (int i = 0; i < nb; ++i) { - vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(y[i].d); - vector float vd = vec_mul(vxd, vyd); - - vector float vxmin = vec_splats(LM_GGML_FP16_TO_FP32(x[i].dmin)); - vector float vdmin = vec_mul(vxmin, vyd); - - vector signed short q8ysums0 = vec_xl( 0, y[i].bsums); - vector signed short q8ysums1 = vec_xl(16, y[i].bsums); - - vector signed char q2xmins = (vector signed char)vec_xl( 0, x[i].scales); - vector signed char vscales = vec_and(q2xmins, lowScaleMask); - - q2xmins = vec_sr(q2xmins, v4); - vector signed short q2xmins0 = vec_unpackh(q2xmins); - vector signed short q2xmins1 = vec_unpackl(q2xmins); - - vector signed int prod0 = vec_mule(q2xmins0, q8ysums0); - vector signed int prod1 = vec_mulo(q2xmins0, q8ysums0); - vector signed int prod2 = vec_mule(q2xmins1, q8ysums1); - vector signed int prod3 = vec_mulo(q2xmins1, q8ysums1); - - vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0); - vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1); - vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2); - vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3); - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - vector signed int vsumi2 = v0; - vector signed int vsumi3 = v0; - vector signed int vsumi4 = v0; - vector signed int vsumi5 = v0; - vector signed int vsumi6 = v0; - vector signed int vsumi7 = v0; - - const uint8_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - for (int j = 0; j < QK_K/128; ++j) { - __builtin_prefetch(q2, 0, 1); - __builtin_prefetch(q8, 0, 1); - - vector signed char qxs0 = (vector signed char)vec_xl( 0, q2); - vector signed char qxs1 = (vector signed char)vec_xl(16, q2); - q2 += 32; - - vector unsigned char q2x00 = (vector unsigned char)vec_and(qxs0, lowMask); - vector unsigned char q2x01 = (vector unsigned char)vec_and(vec_sr(qxs0, v2), lowMask); - vector unsigned char q2x02 = (vector unsigned char)vec_and(vec_sr(qxs0, v4), lowMask); - vector unsigned char q2x03 = (vector unsigned char)vec_and(vec_sr(qxs0, v6), lowMask); - vector unsigned char q2x10 = (vector unsigned char)vec_and(qxs1, lowMask); - vector unsigned char q2x11 = (vector unsigned char)vec_and(vec_sr(qxs1, v2), lowMask); - vector unsigned char q2x12 = (vector unsigned char)vec_and(vec_sr(qxs1, v4), lowMask); - vector unsigned char q2x13 = (vector unsigned char)vec_and(vec_sr(qxs1, v6), lowMask); - - vector signed char q8y00 = vec_xl( 0, q8); - vector signed char q8y10 = vec_xl( 16, q8); - vector signed char q8y01 = vec_xl( 32, q8); - vector signed char q8y11 = vec_xl( 48, q8); - vector signed char q8y02 = vec_xl( 64, q8); - vector signed char q8y12 = vec_xl( 80, q8); - vector signed char q8y03 = vec_xl( 96, q8); - vector signed char q8y13 = vec_xl(112, q8); - q8 += 128; - - vector signed int qv0 = vec_msum(q8y00, q2x00, v0); - vector signed int qv1 = vec_msum(q8y01, q2x01, v0); - vector signed int qv2 = vec_msum(q8y02, q2x02, v0); - vector signed int qv3 = vec_msum(q8y03, q2x03, v0); - vector signed int qv4 = vec_msum(q8y10, q2x10, v0); - vector signed int qv5 = vec_msum(q8y11, q2x11, v0); - vector signed int qv6 = vec_msum(q8y12, q2x12, v0); - vector signed int qv7 = vec_msum(q8y13, q2x13, v0); - - vector signed short vscales_07 = vec_unpackh(vscales); - vector signed int vscales_03 = vec_unpackh(vscales_07); - vector signed int vscales_47 = vec_unpackl(vscales_07); - vector signed int vs0 = vec_splat(vscales_03, 0); - vector signed int vs1 = vec_splat(vscales_03, 1); - vector signed int vs2 = vec_splat(vscales_03, 2); - vector signed int vs3 = vec_splat(vscales_03, 3); - vector signed int vs4 = vec_splat(vscales_47, 0); - vector signed int vs5 = vec_splat(vscales_47, 1); - vector signed int vs6 = vec_splat(vscales_47, 2); - vector signed int vs7 = vec_splat(vscales_47, 3); - vscales = vec_sld(vscales, vscales, 8); - - vsumi0 = vec_add(vec_mul(qv0, vs0), vsumi0); - vsumi1 = vec_add(vec_mul(qv1, vs2), vsumi1); - vsumi2 = vec_add(vec_mul(qv2, vs4), vsumi2); - vsumi3 = vec_add(vec_mul(qv3, vs6), vsumi3); - vsumi4 = vec_add(vec_mul(qv4, vs1), vsumi4); - vsumi5 = vec_add(vec_mul(qv5, vs3), vsumi5); - vsumi6 = vec_add(vec_mul(qv6, vs5), vsumi6); - vsumi7 = vec_add(vec_mul(qv7, vs7), vsumi7); - } - - vsumi0 = vec_add(vsumi0, vsumi4); - vsumi1 = vec_add(vsumi1, vsumi5); - vsumi2 = vec_add(vsumi2, vsumi6); - vsumi3 = vec_add(vsumi3, vsumi7); - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = vec_extract(vsumf0, 0); - -#elif defined __loongarch_asx - - const __m256i m3 = __lasx_xvreplgr2vr_b(3); - const __m128i m4 = __lsx_vreplgr2vr_b(0xF); - - __m256 acc = (__m256)__lasx_xvldi(0); - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin); - - const uint8_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const __m128i mins_and_scales = __lsx_vld((const __m128i*)x[i].scales, 0); - const __m128i scales8 = __lsx_vand_v(mins_and_scales, m4); - const __m128i mins8 = __lsx_vand_v(__lsx_vsrli_h(mins_and_scales, 4), m4); - const __m256i mins = lasx_ext8_16(mins8); - const __m256i prod = lasx_madd_h(mins, __lasx_xvld((const __m256i*)y[i].bsums, 0)); - - acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(dmin), __lasx_xvffint_s_w(prod), acc); - - const __m256i all_scales = lasx_ext8_16(scales8); - const __m128i l_scales = lasx_extracti128(all_scales, 0); - const __m128i h_scales = lasx_extracti128(all_scales, 1); - const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)}; - - __m256i sumi = __lasx_xvldi(0); - - for (int j = 0; j < QK_K/128; ++j) { - - const __m256i q2bits = __lasx_xvld((const __m256i*)q2, 0); q2 += 32; - - const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - - const __m256i q2_0 = __lasx_xvand_v(q2bits, m3); - const __m256i q2_1 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 2), m3); - const __m256i q2_2 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 4), m3); - const __m256i q2_3 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 6), m3); - - __m256i p0 = lasx_maddubs_h(q2_0, q8_0); - __m256i p1 = lasx_maddubs_h(q2_1, q8_1); - __m256i p2 = lasx_maddubs_h(q2_2, q8_2); - __m256i p3 = lasx_maddubs_h(q2_3, q8_3); - - p0 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(0)), p0); - p1 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(1)), p1); - p2 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(2)), p2); - p3 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(3)), p3); - - p0 = __lasx_xvadd_w(p0, p1); - p2 = __lasx_xvadd_w(p2, p3); - - sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p0, p2)); - } - - acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc); - - } - - *s = hsum_float_8(acc); - -#else - - float sumf = 0; - - for (int i = 0; i < nb; ++i) { - - const uint8_t * q2 = x[i].qs; - const int8_t * q8 = y[i].qs; - const uint8_t * sc = x[i].scales; - - int summs = 0; - for (int j = 0; j < 16; ++j) { - summs += y[i].bsums[j] * (sc[j] >> 4); - } - - const float dall = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin); - - int isum = 0; - int is = 0; - int d; - for (int k = 0; k < QK_K/128; ++k) { - int shift = 0; - for (int j = 0; j < 4; ++j) { - d = sc[is++] & 0xF; - int isuml = 0; - for (int l = 0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); - isum += d * isuml; - d = sc[is++] & 0xF; - isuml = 0; - for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); - isum += d * isuml; - shift += 2; - q8 += 32; - } - q2 += 32; - } - sumf += dall * isum - dmin * summs; - } - *s = sumf; -#endif -} - -void lm_ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const uint32_t kmask1 = 0x03030303; - const uint32_t kmask2 = 0x0f0f0f0f; - - const block_q3_K * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - -#ifdef __ARM_NEON - - uint32_t aux[3]; - uint32_t utmp[4]; - - const uint8x16_t m3b = vdupq_n_u8(0x3); - const int32x4_t vzero = vdupq_n_s32(0); - - const uint8x16_t m0 = vdupq_n_u8(1); - const uint8x16_t m1 = vshlq_n_u8(m0, 1); - const uint8x16_t m2 = vshlq_n_u8(m0, 2); - const uint8x16_t m3 = vshlq_n_u8(m0, 3); - const int8_t m32 = 32; - - lm_ggml_int8x16x4_t q3bytes; - - float sum = 0; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); - - const uint8_t * restrict q3 = x[i].qs; - const uint8_t * restrict qh = x[i].hmask; - const int8_t * restrict q8 = y[i].qs; - - lm_ggml_uint8x16x2_t qhbits = lm_ggml_vld1q_u8_x2(qh); - - lm_ggml_uint8x16x4_t q3h; - - int32_t isum = 0; - - // Set up scales - memcpy(aux, x[i].scales, 12); - utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); - utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); - utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); - utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); - - int8_t * scale = (int8_t *)utmp; - for (int j = 0; j < 16; ++j) scale[j] -= m32; - - for (int j = 0; j < QK_K/128; ++j) { - - const lm_ggml_uint8x16x2_t q3bits = lm_ggml_vld1q_u8_x2(q3); q3 += 32; - const lm_ggml_int8x16x4_t q8bytes_1 = lm_ggml_vld1q_s8_x4(q8); q8 += 64; - const lm_ggml_int8x16x4_t q8bytes_2 = lm_ggml_vld1q_s8_x4(q8); q8 += 64; - - q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2); - q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2); - q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1); - q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1); - - q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), vreinterpretq_s8_u8(q3h.val[0])); - q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), vreinterpretq_s8_u8(q3h.val[1])); - q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2])); - q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3])); - - isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0]; - isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1]; - isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2]; - isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3]; - - scale += 4; - - q3h.val[0] = vbicq_u8(m2, qhbits.val[0]); - q3h.val[1] = vbicq_u8(m2, qhbits.val[1]); - q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1); - q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1); - - q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), vreinterpretq_s8_u8(q3h.val[0])); - q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), vreinterpretq_s8_u8(q3h.val[1])); - q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2])); - q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3])); - - isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0]; - isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1]; - isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2]; - isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3]; - - scale += 4; - - if (j == 0) { - qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4); - qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4); - } - - } - sum += d * isum; - - } - - *s = sum; - -#elif defined __AVX2__ - - const __m256i m3 = _mm256_set1_epi8(3); - const __m256i mone = _mm256_set1_epi8(1); - const __m128i m32 = _mm_set1_epi8(32); - - __m256 acc = _mm256_setzero_ps(); - - uint32_t aux[3]; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); - - const uint8_t * restrict q3 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - // Set up scales - memcpy(aux, x[i].scales, 12); - __m128i scales128 = _mm_set_epi32( - ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4), - ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4), - (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), - (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); - scales128 = _mm_sub_epi8(scales128, m32); - const __m256i all_scales = _mm256_cvtepi8_epi16(scales128); - const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0); - const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1); - const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)}; - - // high bit - const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask); - - // integer accumulator - __m256i sumi = _mm256_setzero_si256(); - - int bit = 0; - int is = 0; - - for (int j = 0; j < QK_K/128; ++j) { - // load low 2 bits - const __m256i q3bits = _mm256_loadu_si256((const __m256i*)q3); q3 += 32; - - // prepare low and high bits - const __m256i q3l_0 = _mm256_and_si256(q3bits, m3); - const __m256i q3h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); - ++bit; - - const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3); - const __m256i q3h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); - ++bit; - - const __m256i q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3); - const __m256i q3h_2 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); - ++bit; - - const __m256i q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3); - const __m256i q3h_3 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); - ++bit; - - // load Q8 quants - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - - // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, - // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, - // and 2 if the high bit was set) - __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0); - __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1); - __m256i q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2); - __m256i q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3); - - __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0); - __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1); - __m256i p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2); - __m256i p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3); - - p16_0 = _mm256_sub_epi16(p16_0, q8s_0); - p16_1 = _mm256_sub_epi16(p16_1, q8s_1); - p16_2 = _mm256_sub_epi16(p16_2, q8s_2); - p16_3 = _mm256_sub_epi16(p16_3, q8s_3); - - // multiply with scales - p16_0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0); - p16_1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1); - p16_2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2); - p16_3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3); - - // accumulate - p16_0 = _mm256_add_epi32(p16_0, p16_1); - p16_2 = _mm256_add_epi32(p16_2, p16_3); - sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2)); - - } - - // multiply with block scale and accumulate - acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); - - } - - *s = hsum_float_8(acc); - -#elif defined __AVX__ - - const __m128i m3 = _mm_set1_epi8(3); - const __m128i mone = _mm_set1_epi8(1); - const __m128i m32 = _mm_set1_epi8(32); - const __m128i m2 = _mm_set1_epi8(2); - - __m256 acc = _mm256_setzero_ps(); - - const uint32_t *aux; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); - - const uint8_t * restrict q3 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - // Set up scales - aux = (const uint32_t *)x[i].scales; - __m128i scales128 = _mm_set_epi32( - ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4), - ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4), - (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), - (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); - scales128 = _mm_sub_epi8(scales128, m32); - const __m128i scales_0 = _mm_cvtepi8_epi16(scales128); - const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales128, scales128)); - const __m128i scales[2] = { scales_0, scales_1 }; - - // high bit *128*2 from block_q3_K.hmask[QK_K/8] - const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].hmask[0]); - const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].hmask[16]); - - // integer accumulator - __m128i sumi_0 = _mm_setzero_si128(); - __m128i sumi_1 = _mm_setzero_si128(); - - for (int j = 0; j < QK_K/128; ++j) { - // load low 2 bits *64*2 from block_q3_K.qs[QK_K/4] - const __m128i q3bits_0 = _mm_loadu_si128((const __m128i*)q3); q3 += 16; - const __m128i q3bits_1 = _mm_loadu_si128((const __m128i*)q3); q3 += 16; - - // prepare low and high bits - const int bit = j << 2; - - const __m128i q3l_0 = _mm_and_si128(q3bits_0, m3); - const __m128i q3l_1 = _mm_and_si128(q3bits_1, m3); - const __m128i q3h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit)), bit), 2); - const __m128i q3h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit)), bit), 2); - - const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 2), m3); - const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 2), m3); - const __m128i q3h_2 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+1)), bit+1), 2); - const __m128i q3h_3 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+1)), bit+1), 2); - - const __m128i q3l_4 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 4), m3); - const __m128i q3l_5 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 4), m3); - const __m128i q3h_4 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+2)), bit+2), 2); - const __m128i q3h_5 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+2)), bit+2), 2); - - const __m128i q3l_6 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 6), m3); - const __m128i q3l_7 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 6), m3); - const __m128i q3h_6 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+3)), bit+3), 2); - const __m128i q3h_7 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+3)), bit+3), 2); - - // load Q8 quants from block_q8_K.qs[QK_K] - const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - - // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, - // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, - // and 2 if the high bit was set) - __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, q8_0); - __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, q8_1); - __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, q8_2); - __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, q8_3); - __m128i q8s_4 = _mm_maddubs_epi16(q3h_4, q8_4); - __m128i q8s_5 = _mm_maddubs_epi16(q3h_5, q8_5); - __m128i q8s_6 = _mm_maddubs_epi16(q3h_6, q8_6); - __m128i q8s_7 = _mm_maddubs_epi16(q3h_7, q8_7); - - __m128i p16_0 = _mm_maddubs_epi16(q3l_0, q8_0); - __m128i p16_1 = _mm_maddubs_epi16(q3l_1, q8_1); - __m128i p16_2 = _mm_maddubs_epi16(q3l_2, q8_2); - __m128i p16_3 = _mm_maddubs_epi16(q3l_3, q8_3); - __m128i p16_4 = _mm_maddubs_epi16(q3l_4, q8_4); - __m128i p16_5 = _mm_maddubs_epi16(q3l_5, q8_5); - __m128i p16_6 = _mm_maddubs_epi16(q3l_6, q8_6); - __m128i p16_7 = _mm_maddubs_epi16(q3l_7, q8_7); - - p16_0 = _mm_sub_epi16(p16_0, q8s_0); - p16_1 = _mm_sub_epi16(p16_1, q8s_1); - p16_2 = _mm_sub_epi16(p16_2, q8s_2); - p16_3 = _mm_sub_epi16(p16_3, q8s_3); - p16_4 = _mm_sub_epi16(p16_4, q8s_4); - p16_5 = _mm_sub_epi16(p16_5, q8s_5); - p16_6 = _mm_sub_epi16(p16_6, q8s_6); - p16_7 = _mm_sub_epi16(p16_7, q8s_7); - - // multiply with scales - __m128i shuffle = _mm_set1_epi16(0x0100); - p16_0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_0); - shuffle = _mm_add_epi16(shuffle, m2); - p16_1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_1); - shuffle = _mm_add_epi16(shuffle, m2); - p16_2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_2); - shuffle = _mm_add_epi16(shuffle, m2); - p16_3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_3); - shuffle = _mm_add_epi16(shuffle, m2); - p16_4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_4); - shuffle = _mm_add_epi16(shuffle, m2); - p16_5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_5); - shuffle = _mm_add_epi16(shuffle, m2); - p16_6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_6); - shuffle = _mm_add_epi16(shuffle, m2); - p16_7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_7); - - // accumulate - p16_0 = _mm_add_epi32(p16_0, p16_1); - p16_2 = _mm_add_epi32(p16_2, p16_3); - p16_4 = _mm_add_epi32(p16_4, p16_5); - p16_6 = _mm_add_epi32(p16_6, p16_7); - sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); - sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_4, p16_6)); - - } - - // multiply with block scale and accumulate - __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc); - - } - - *s = hsum_float_8(acc); - -#elif defined __riscv_v_intrinsic - - uint32_t aux[3]; - uint32_t utmp[4]; - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - - const uint8_t * restrict q3 = x[i].qs; - const uint8_t * restrict qh = x[i].hmask; - const int8_t * restrict q8 = y[i].qs; - - memcpy(aux, x[i].scales, 12); - utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); - utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); - utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); - utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); - - int8_t * scale = (int8_t *)utmp; - for (int j = 0; j < 16; ++j) scale[j] -= 32; - - - size_t vl = 32; - uint8_t m = 1; - - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl); - - int sum_t = 0; - - for (int j = 0; j < QK_K; j += 128) { - - vl = 32; - - // load Q3 - vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl); - - vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl)); - vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl)); - vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl)); - vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl)); - - // compute mask for subtraction - vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl); - vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl); - m <<= 1; - - vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl); - vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl); - m <<= 1; - - vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl); - vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl); - m <<= 1; - - vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl); - vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl); - m <<= 1; - - // load Q8 and take product with Q3 - vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl); - vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl); - vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl); - vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl); - - vl = 16; - - // retrieve lane to multiply with scale - vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl); - vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl); - vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl); - vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl); - vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl); - vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl); - vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl); - vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl); - - vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl); - vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl); - vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl); - - sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); - - q3 += 32; q8 += 128; scale += 8; - - } - - const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; - - sumf += d*sum_t; - - } - - *s = sumf; - -#elif defined(__POWER9_VECTOR__) - const vector signed char lowMask = vec_splats((signed char)0x3); - const vector signed char lowMask1 = vec_splats((int8_t)0xf); - const vector signed char lowMask2 = vec_splats((int8_t)0x30); - const vector int v0 = vec_splats((int32_t)0); - const vector signed char v1 = vec_splats((signed char)0x1); - const vector unsigned char v2 = vec_splats((unsigned char)0x2); - const vector unsigned char v3 = vec_splats((unsigned char)0x3); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - const vector unsigned char v6 = vec_splats((unsigned char)0x6); - const vector signed char off = vec_splats((signed char)0x20); - - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - - for (int i = 0; i < nb; ++i) { - vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(y[i].d); - vector float vd = vec_mul(vxd, vyd); - - UNUSED(kmask1); - UNUSED(kmask2); - - vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8); - vector signed char u1 = vec_and(u0, lowMask1); - vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4); - vector signed char u3 = (vector signed char)vec_mergeh((vector signed int)u2, (vector signed int)vec_sr(u2, v2)); - vector signed char u30 = vec_sl(vec_and(u3, lowMask), v4); - vector signed char u31 = vec_and(u3, lowMask2); - - u1 = vec_or(u1, u30); - u2 = vec_or(vec_sr(u0, v4), u31); - - vector signed char vscales = (vector signed char)vec_mergeh((vector signed long long)u1, (vector signed long long)u2); - vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].hmask); - vector signed char qxhs1 = (vector signed char)vec_xl(16, x[i].hmask); - - vscales = vec_sub(vscales, off); - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - vector signed int vsumi2 = v0; - vector signed int vsumi3 = v0; - vector signed int vsumi4 = v0; - vector signed int vsumi5 = v0; - vector signed int vsumi6 = v0; - vector signed int vsumi7 = v0; - - const uint8_t * restrict q3 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - for (int j = 0; j < QK_K/128; ++j) { - __builtin_prefetch(q3, 0, 1); - __builtin_prefetch(q8, 0, 1); - - vector signed char qxs0 = (vector signed char)vec_xl( 0, q3); - vector signed char qxs1 = (vector signed char)vec_xl(16, q3); - q3 += 32; - - //the low 2 bits - vector signed char qxs00 = vec_and(qxs0, lowMask); - vector signed char qxs01 = vec_and(vec_sr(qxs0, v2), lowMask); - vector signed char qxs02 = vec_and(vec_sr(qxs0, v4), lowMask); - vector signed char qxs03 = vec_and(vec_sr(qxs0, v6), lowMask); - vector signed char qxs10 = vec_and(qxs1, lowMask); - vector signed char qxs11 = vec_and(vec_sr(qxs1, v2), lowMask); - vector signed char qxs12 = vec_and(vec_sr(qxs1, v4), lowMask); - vector signed char qxs13 = vec_and(vec_sr(qxs1, v6), lowMask); - - //the 3rd bit - vector signed char qxh00 = vec_sl(vec_andc(v1, qxhs0), v2); - vector signed char qxh01 = vec_sl(vec_andc(v1, vec_sr(qxhs0, (vector unsigned char)v1)), v2); - vector signed char qxh02 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v2)), v2); - vector signed char qxh03 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v3)), v2); - vector signed char qxh10 = vec_sl(vec_andc(v1, qxhs1), v2); - vector signed char qxh11 = vec_sl(vec_andc(v1, vec_sr(qxhs1, (vector unsigned char)v1)), v2); - vector signed char qxh12 = vec_sl(vec_andc(v1, vec_sr(qxhs1, v2)), v2); - vector signed char qxh13 = vec_sl(vec_andc(v1, vec_sr(qxhs1, v3)), v2); - qxhs0 = vec_sr(qxhs0, v4); - qxhs1 = vec_sr(qxhs1, v4); - - vector signed char q3x00 = vec_sub(qxs00, qxh00); - vector signed char q3x01 = vec_sub(qxs01, qxh01); - vector signed char q3x02 = vec_sub(qxs02, qxh02); - vector signed char q3x03 = vec_sub(qxs03, qxh03); - vector signed char q3x10 = vec_sub(qxs10, qxh10); - vector signed char q3x11 = vec_sub(qxs11, qxh11); - vector signed char q3x12 = vec_sub(qxs12, qxh12); - vector signed char q3x13 = vec_sub(qxs13, qxh13); - - vector signed char q8y00 = vec_xl( 0, q8); - vector signed char q8y10 = vec_xl( 16, q8); - vector signed char q8y01 = vec_xl( 32, q8); - vector signed char q8y11 = vec_xl( 48, q8); - vector signed char q8y02 = vec_xl( 64, q8); - vector signed char q8y12 = vec_xl( 80, q8); - vector signed char q8y03 = vec_xl( 96, q8); - vector signed char q8y13 = vec_xl(112, q8); - q8 += 128; - - vector signed short vscales_h = vec_unpackh(vscales); - vector signed short vs0 = vec_splat(vscales_h, 0); - vector signed short vs1 = vec_splat(vscales_h, 1); - vector signed short vs2 = vec_splat(vscales_h, 2); - vector signed short vs3 = vec_splat(vscales_h, 3); - vector signed short vs4 = vec_splat(vscales_h, 4); - vector signed short vs5 = vec_splat(vscales_h, 5); - vector signed short vs6 = vec_splat(vscales_h, 6); - vector signed short vs7 = vec_splat(vscales_h, 7); - vscales = vec_sld(vscales, vscales, 8); - - vector signed short qv00 = vec_add(vec_mule(q3x00, q8y00), vec_mulo(q3x00, q8y00)); - vector signed short qv01 = vec_add(vec_mule(q3x01, q8y01), vec_mulo(q3x01, q8y01)); - vector signed short qv02 = vec_add(vec_mule(q3x02, q8y02), vec_mulo(q3x02, q8y02)); - vector signed short qv03 = vec_add(vec_mule(q3x03, q8y03), vec_mulo(q3x03, q8y03)); - vector signed short qv10 = vec_add(vec_mule(q3x10, q8y10), vec_mulo(q3x10, q8y10)); - vector signed short qv11 = vec_add(vec_mule(q3x11, q8y11), vec_mulo(q3x11, q8y11)); - vector signed short qv12 = vec_add(vec_mule(q3x12, q8y12), vec_mulo(q3x12, q8y12)); - vector signed short qv13 = vec_add(vec_mule(q3x13, q8y13), vec_mulo(q3x13, q8y13)); - - vsumi0 = vec_msum(qv00, vs0, vsumi0); - vsumi1 = vec_msum(qv01, vs2, vsumi1); - vsumi2 = vec_msum(qv02, vs4, vsumi2); - vsumi3 = vec_msum(qv03, vs6, vsumi3); - vsumi4 = vec_msum(qv10, vs1, vsumi4); - vsumi5 = vec_msum(qv11, vs3, vsumi5); - vsumi6 = vec_msum(qv12, vs5, vsumi6); - vsumi7 = vec_msum(qv13, vs7, vsumi7); - } - - vsumi0 = vec_add(vsumi0, vsumi4); - vsumi1 = vec_add(vsumi1, vsumi5); - vsumi2 = vec_add(vsumi2, vsumi6); - vsumi3 = vec_add(vsumi3, vsumi7); - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = vec_extract(vsumf0, 0); - -#elif defined __loongarch_asx - - const __m256i m3 = __lasx_xvreplgr2vr_b(3); - const __m256i mone = __lasx_xvreplgr2vr_b(1); - const __m128i m32 = __lsx_vreplgr2vr_b(32); - - __m256 acc = (__m256)__lasx_xvldi(0); - - uint32_t aux[3]; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); - const uint8_t * restrict q3 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - // Set up scales - memcpy(aux, x[i].scales, 12); - __m128i scales128 = lsx_set_w( - ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4), - ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4), - (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), - (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); - scales128 = __lsx_vsub_b(scales128, m32); - const __m256i all_scales = lasx_ext8_16(scales128); - const __m128i l_scales = lasx_extracti128(all_scales, 0); - const __m128i h_scales = lasx_extracti128(all_scales, 1); - const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)}; - - // high bit - const __m256i hbits = __lasx_xvld((const __m256i*)x[i].hmask, 0); - - // integer accumulator - __m256i sumi = __lasx_xvldi(0); - - int bit = 0; - int is = 0; - __m256i xvbit; - - - for (int j = 0; j < QK_K/128; ++j) { - // load low 2 bits - const __m256i q3bits = __lasx_xvld((const __m256i*)q3, 0); q3 += 32; - - xvbit = __lasx_xvreplgr2vr_h(bit); - // prepare low and high bits - const __m256i q3l_0 = __lasx_xvand_v(q3bits, m3); - const __m256i q3h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2); - ++bit; - - xvbit = __lasx_xvreplgr2vr_h(bit); - const __m256i q3l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 2), m3); - const __m256i q3h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2); - ++bit; - - xvbit = __lasx_xvreplgr2vr_h(bit); - const __m256i q3l_2 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 4), m3); - const __m256i q3h_2 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2); - ++bit; - - xvbit = __lasx_xvreplgr2vr_h(bit); - const __m256i q3l_3 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 6), m3); - const __m256i q3h_3 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2); - ++bit; - - // load Q8 quants - const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - - // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use lasx_maddubs_h, - // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, - // and 2 if the high bit was set) - __m256i q8s_0 = lasx_maddubs_h(q3h_0, q8_0); - __m256i q8s_1 = lasx_maddubs_h(q3h_1, q8_1); - __m256i q8s_2 = lasx_maddubs_h(q3h_2, q8_2); - __m256i q8s_3 = lasx_maddubs_h(q3h_3, q8_3); - - __m256i p16_0 = lasx_maddubs_h(q3l_0, q8_0); - __m256i p16_1 = lasx_maddubs_h(q3l_1, q8_1); - __m256i p16_2 = lasx_maddubs_h(q3l_2, q8_2); - __m256i p16_3 = lasx_maddubs_h(q3l_3, q8_3); - - p16_0 = __lasx_xvsub_h(p16_0, q8s_0); - p16_1 = __lasx_xvsub_h(p16_1, q8s_1); - p16_2 = __lasx_xvsub_h(p16_2, q8s_2); - p16_3 = __lasx_xvsub_h(p16_3, q8s_3); - - // multiply with scales - p16_0 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0); - p16_1 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1); - p16_2 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2); - p16_3 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3); - - // accumulate - p16_0 = __lasx_xvadd_w(p16_0, p16_1); - p16_2 = __lasx_xvadd_w(p16_2, p16_3); - sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_2)); - } - // multiply with block scale and accumulate - acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);//FIXME - } - - *s = hsum_float_8(acc); - -#else - // scalar version - // This function is written like this so the compiler can manage to vectorize most of it - // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the - // manually vectorized version above. Every other version I tried would run at least 4 times slower. - // The ideal situation would be if we could just write the code once, and the compiler would - // automatically produce the best possible set of machine instructions, instead of us having to manually - // write vectorized versions for AVX, ARM_NEON, etc. - - int8_t aux8[QK_K]; - int16_t aux16[8]; - float sums [8]; - int32_t aux32[8]; - memset(sums, 0, 8*sizeof(float)); - - uint32_t auxs[4]; - const int8_t * scales = (const int8_t*)auxs; - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q3 = x[i].qs; - const uint8_t * restrict hm = x[i].hmask; - const int8_t * restrict q8 = y[i].qs; - memset(aux32, 0, 8*sizeof(int32_t)); - int8_t * restrict a = aux8; - uint8_t m = 1; - for (int j = 0; j < QK_K; j += 128) { - for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3; - for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); - a += 32; m <<= 1; - for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3; - for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); - a += 32; m <<= 1; - for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3; - for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); - a += 32; m <<= 1; - for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3; - for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); - a += 32; m <<= 1; - q3 += 32; - } - a = aux8; - - memcpy(auxs, x[i].scales, 12); - uint32_t tmp = auxs[2]; - auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); - auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); - auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); - auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); - for (int j = 0; j < QK_K/16; ++j) { - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; - q8 += 8; a += 8; - } - const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; - for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; - } - for (int l = 0; l < 8; ++l) sumf += sums[l]; - *s = sumf; - -#endif - -} - -void lm_ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_q4_K * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - - static const uint32_t kmask1 = 0x3f3f3f3f; - static const uint32_t kmask2 = 0x0f0f0f0f; - static const uint32_t kmask3 = 0x03030303; - - uint32_t utmp[4]; - -#ifdef __ARM_NEON - const uint8x16_t m4b = vdupq_n_u8(0xf); - const int32x4_t mzero = vdupq_n_s32(0); - - lm_ggml_int8x16x2_t q4bytes; - lm_ggml_int8x16x2_t q8bytes; - - float sumf = 0; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin); - - const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); - - memcpy(utmp, x[i].scales, 12); - - uint32x2_t mins8 = { 0 }; - mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0); - mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1); - - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[0] &= kmask1; - - const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8))); - const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), - vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); - sumf -= dmin * vaddvq_s32(prod); - - const uint8_t * scales = (const uint8_t *)utmp; - - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - int32_t sumi1 = 0; - int32_t sumi2 = 0; - - for (int j = 0; j < QK_K/64; ++j) { - const lm_ggml_uint8x16x2_t q4bits = lm_ggml_vld1q_u8_x2(q4); q4 += 32; - - q8bytes = lm_ggml_vld1q_s8_x2(q8); q8 += 32; - q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); - q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); - - const int32x4_t p1 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); - sumi1 += vaddvq_s32(p1) * scales[2*j+0]; - - q8bytes = lm_ggml_vld1q_s8_x2(q8); q8 += 32; - q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); - q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); - - const int32x4_t p2 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); - - sumi2 += vaddvq_s32(p2) * scales[2*j+1]; - } - - sumf += d * (sumi1 + sumi2); - - } - - *s = sumf; - -#elif defined __AVX2__ - - const __m256i m4 = _mm256_set1_epi8(0xF); - - __m256 acc = _mm256_setzero_ps(); - __m128 acc_m = _mm_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin); - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); - - const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums); - const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); - const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s); - acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m); - - const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0); - const __m256i scales = MM256_SET_M128I(sc128, sc128); - - __m256i sumi = _mm256_setzero_si256(); - - for (int j = 0; j < QK_K/64; ++j) { - - const __m256i scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0)); - const __m256i scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1)); - - const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; - const __m256i q4l = _mm256_and_si256(q4bits, m4); - const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4); - - const __m256i q8l = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - __m256i p16l = _mm256_maddubs_epi16(q4l, q8l); - p16l = _mm256_madd_epi16(scale_l, p16l); - - const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - __m256i p16h = _mm256_maddubs_epi16(q4h, q8h); - p16h = _mm256_madd_epi16(scale_h, p16h); - const __m256i sumj = _mm256_add_epi32(p16l, p16h); - - sumi = _mm256_add_epi32(sumi, sumj); - } - - __m256 vd = _mm256_set1_ps(d); - acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc); - - } - - acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m)); - acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m)); - - *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m); - -#elif defined __AVX__ - - const __m128i m4 = _mm_set1_epi8(0xF); - const __m128i m2 = _mm_set1_epi8(0x2); - - __m256 acc = _mm256_setzero_ps(); - __m128 acc_m = _mm_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin); - - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]); - const __m128i scales = _mm_cvtepu8_epi16(utmps); - const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps)); - - const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]); - const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]); - const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1); - const __m128i prod = _mm_madd_epi16(mins, q8s); - acc_m = _mm_add_ps(_mm_mul_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod)), acc_m); - - __m128i sumi_0 = _mm_setzero_si128(); - __m128i sumi_1 = _mm_setzero_si128(); - - __m128i shuffle = _mm_set1_epi16(0x0100); - for (int j = 0; j < QK_K/64; ++j) { - - const __m128i scale_l = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi16(shuffle, m2); - const __m128i scale_h = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi16(shuffle, m2); - - __m128i q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16; - const __m128i q4l_0 = _mm_and_si128(q4bits, m4); - const __m128i q4h_0 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4); - q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16; - const __m128i q4l_1 = _mm_and_si128(q4bits, m4); - const __m128i q4h_1 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4); - - const __m128i q8l_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - __m128i p16l = _mm_maddubs_epi16(q4l_0, q8l_0); - p16l = _mm_madd_epi16(scale_l, p16l); - sumi_0 = _mm_add_epi32(sumi_0, p16l); - const __m128i q8l_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - p16l = _mm_maddubs_epi16(q4l_1, q8l_1); - p16l = _mm_madd_epi16(scale_l, p16l); - sumi_1 = _mm_add_epi32(sumi_1, p16l); - - const __m128i q8h_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - __m128i p16h = _mm_maddubs_epi16(q4h_0, q8h_0); - p16h = _mm_madd_epi16(scale_h, p16h); - sumi_0 = _mm_add_epi32(sumi_0, p16h); - const __m128i q8h_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - p16h = _mm_maddubs_epi16(q4h_1, q8h_1); - p16h = _mm_madd_epi16(scale_h, p16h); - sumi_1 = _mm_add_epi32(sumi_1, p16h); - - } - - __m256 vd = _mm256_set1_ps(d); - __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); - acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc); - - } - - acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m)); - acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m)); - - *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m); - -#elif defined __riscv_v_intrinsic - - const uint8_t * scales = (const uint8_t*)&utmp[0]; - const uint8_t * mins = (const uint8_t*)&utmp[2]; - - float sumf = 0; - - for (int i = 0; i < nb; ++i) { - - size_t vl = 8; - - const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin); - - vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); - vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); - vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); - vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); - vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); - - vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); - sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); - - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - vl = 32; - - int32_t sum_1 = 0; - int32_t sum_2 = 0; - - vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); - - for (int j = 0; j < QK_K/64; ++j) { - // load Q4 - vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); - - // load Q8 and multiply it with lower Q4 nibble - vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); - vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); - vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl); - vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl); - - sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0]; - - // load Q8 and multiply it with upper Q4 nibble - vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); - vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); - vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl); - vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl); - - sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1]; - - q4 += 32; q8 += 64; - - } - - sumf += d*(sum_1 + sum_2); - - } - - *s = sumf; - -#elif defined(__POWER9_VECTOR__) - const vector signed char lowMask = vec_splats((signed char)0xF); - const vector signed char lowMask1 = vec_splats((int8_t)0x3f); - const vector signed char lowMask2 = vec_splats((int8_t)0x30); - const vector int v0 = vec_splats((int32_t)0); - const vector unsigned char v2 = vec_splats((uint8_t)2); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - - for (int i = 0; i < nb; ++i) { - vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(y[i].d); - vector float vd = vec_mul(vxd, vyd); - - vector float vxmin = vec_splats(LM_GGML_FP16_TO_FP32(x[i].dmin)); - vector float vdmin = vec_mul(vxmin, vyd); - - vector signed short q8ysums0 = vec_xl( 0, y[i].bsums); - vector signed short q8ysums1 = vec_xl(16, y[i].bsums); - - UNUSED(kmask1); - UNUSED(kmask2); - UNUSED(kmask3); - UNUSED(utmp); - - vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8); - vector signed char u1 = vec_and(vec_sr(u0, v2), lowMask2); - vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4); - vector signed char u3 = vec_sr(u2, v4); - - vector signed char u30 = u1; - vector signed char u31 = (vector signed char)vec_mergeh((vector signed int)vec_and(u2, lowMask), (vector signed int)u3); - - u1 = vec_and(u0, lowMask1); - u2 = vec_or(u30, u31); - - vector signed char utmps = (vector signed char)vec_mergeh((vector signed int)u1, (vector signed int)u2); - - vector signed short vscales = vec_unpackh(utmps); - vector signed short q4xmins = vec_unpackl(utmps); - vector signed short q4xmins0 = vec_mergeh(q4xmins, q4xmins); - vector signed short q4xmins1 = vec_mergel(q4xmins, q4xmins); - - vector signed int prod0 = vec_mule(q4xmins0, q8ysums0); - vector signed int prod1 = vec_mule(q4xmins1, q8ysums1); - vector signed int prod2 = vec_mulo(q4xmins0, q8ysums0); - vector signed int prod3 = vec_mulo(q4xmins1, q8ysums1); - - vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0); - vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1); - vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2); - vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3); - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - vector signed int vsumi2 = v0; - vector signed int vsumi3 = v0; - - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - for (int j = 0; j < QK_K/64; j+=2) { - __builtin_prefetch(q4, 0, 1); - __builtin_prefetch(q8, 0, 1); - - vector signed char qxs0 = (vector signed char)vec_xl( 0, q4); - vector signed char qxs1 = (vector signed char)vec_xl(16, q4); - vector signed char qxs2 = (vector signed char)vec_xl(32, q4); - vector signed char qxs3 = (vector signed char)vec_xl(48, q4); - q4 += 64; - - vector unsigned char q4x00 = (vector unsigned char)vec_and(qxs0, lowMask); - vector unsigned char q4x01 = (vector unsigned char)vec_sr(qxs0, v4); - vector unsigned char q4x10 = (vector unsigned char)vec_and(qxs1, lowMask); - vector unsigned char q4x11 = (vector unsigned char)vec_sr(qxs1, v4); - vector unsigned char q4x20 = (vector unsigned char)vec_and(qxs2, lowMask); - vector unsigned char q4x21 = (vector unsigned char)vec_sr(qxs2, v4); - vector unsigned char q4x30 = (vector unsigned char)vec_and(qxs3, lowMask); - vector unsigned char q4x31 = (vector unsigned char)vec_sr(qxs3, v4); - - vector signed char q8y00 = vec_xl( 0, q8); - vector signed char q8y10 = vec_xl( 16, q8); - vector signed char q8y01 = vec_xl( 32, q8); - vector signed char q8y11 = vec_xl( 48, q8); - vector signed char q8y20 = vec_xl( 64, q8); - vector signed char q8y30 = vec_xl( 80, q8); - vector signed char q8y21 = vec_xl( 96, q8); - vector signed char q8y31 = vec_xl(112, q8); - q8 += 128; - - vector signed int qv00 = vec_msum(q8y00, q4x00, v0); - vector signed int qv01 = vec_msum(q8y01, q4x01, v0); - vector signed int qv10 = vec_msum(q8y10, q4x10, v0); - vector signed int qv11 = vec_msum(q8y11, q4x11, v0); - vector signed int qv20 = vec_msum(q8y20, q4x20, v0); - vector signed int qv21 = vec_msum(q8y21, q4x21, v0); - vector signed int qv30 = vec_msum(q8y30, q4x30, v0); - vector signed int qv31 = vec_msum(q8y31, q4x31, v0); - - vector signed int vscales_h = vec_unpackh(vscales); - vector signed int vs0 = vec_splat(vscales_h, 0); - vector signed int vs1 = vec_splat(vscales_h, 1); - vector signed int vs2 = vec_splat(vscales_h, 2); - vector signed int vs3 = vec_splat(vscales_h, 3); - vscales = vec_sld(vscales, vscales, 8); - - vsumi0 = vec_add(vec_mul(qv00, vs0), vsumi0); - vsumi1 = vec_add(vec_mul(qv01, vs1), vsumi1); - vsumi2 = vec_add(vec_mul(qv20, vs2), vsumi2); - vsumi3 = vec_add(vec_mul(qv21, vs3), vsumi3); - - vsumi0 = vec_add(vec_mul(qv10, vs0), vsumi0); - vsumi1 = vec_add(vec_mul(qv11, vs1), vsumi1); - vsumi2 = vec_add(vec_mul(qv30, vs2), vsumi2); - vsumi3 = vec_add(vec_mul(qv31, vs3), vsumi3); - } - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = vec_extract(vsumf0, 0); - -#elif defined __loongarch_asx - LM_GGML_UNUSED(kmask1); - LM_GGML_UNUSED(kmask2); - LM_GGML_UNUSED(kmask3); - - const __m256i m4 = __lasx_xvreplgr2vr_b(0xF); - - __m256 acc = (__m256)__lasx_xvldi(0); - __m128 acc_m = (__m128)__lsx_vldi(0); - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin); - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0])); - - const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0); - const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1)); - const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s); - acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m); - - const __m128i sc128 = lasx_extracti128(mins_and_scales, 0); - const __m256i scales = lasx_insertf128(sc128, sc128); - - __m256i sumi = __lasx_xvldi(0); - - for (int j = 0; j < QK_K/64; ++j) { - - const __m256i scale_l = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0)); - const __m256i scale_h = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1)); - - const __m256i q4bits = __lasx_xvld((const __m256i*)q4, 0); q4 += 32; - const __m256i q4l = __lasx_xvand_v(q4bits, m4); - const __m256i q4h = __lasx_xvand_v(__lasx_xvsrli_h(q4bits, 4), m4); - - const __m256i q8l = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - __m256i p16l = lasx_maddubs_h(q4l, q8l); - p16l = lasx_madd_h(scale_l, p16l); - - const __m256i q8h = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - __m256i p16h = lasx_maddubs_h(q4h, q8h); - p16h = lasx_madd_h(scale_h, p16h); - const __m256i sumj = __lasx_xvadd_w(p16l, p16h); - - sumi = __lasx_xvadd_w(sumi, sumj); - } - - __m256 vd = __lasx_xvreplfr2vr_s(d); - acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc); - - } - - acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vpermi_w((__m128i)acc_m, (__m128i)acc_m, 0xee)); - __m128i tmp1 = __lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w((__m128i)acc_m, 1), 0); - acc_m = __lsx_vfadd_s(acc_m, (__m128)tmp1); - - - ft_union fi; - fi.i = __lsx_vpickve2gr_w(acc_m, 0); - *s = hsum_float_8(acc) + fi.f ; -#else - - const uint8_t * scales = (const uint8_t*)&utmp[0]; - const uint8_t * mins = (const uint8_t*)&utmp[2]; - - int8_t aux8[QK_K]; - int16_t aux16[8]; - float sums [8]; - int32_t aux32[8]; - memset(sums, 0, 8*sizeof(float)); - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - memset(aux32, 0, 8*sizeof(int32_t)); - int8_t * restrict a = aux8; - for (int j = 0; j < QK_K/64; ++j) { - for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); - a += 32; - for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); - a += 32; q4 += 32; - } - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - int sumi = 0; - for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; - a = aux8; - int is = 0; - for (int j = 0; j < QK_K/32; ++j) { - int32_t scale = scales[is++]; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - } - const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; - for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; - const float dmin = LM_GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; - sumf -= dmin * sumi; - } - for (int l = 0; l < 8; ++l) sumf += sums[l]; - *s = sumf; -#endif -} - -void lm_ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_q5_K * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - - static const uint32_t kmask1 = 0x3f3f3f3f; - static const uint32_t kmask2 = 0x0f0f0f0f; - static const uint32_t kmask3 = 0x03030303; - - uint32_t utmp[4]; - -#ifdef __ARM_NEON - const uint8x16_t m4b = vdupq_n_u8(0xf); - const uint8x16_t mone = vdupq_n_u8(1); - const uint8x16_t mtwo = vdupq_n_u8(2); - const int32x4_t mzero = vdupq_n_s32(0); - - lm_ggml_int8x16x4_t q5bytes; - - float sumf = 0; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin); - - const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8); - const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8)); - const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), - vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); - int32_t sumi_mins = vaddvq_s32(prod); - - const uint8_t * scales = (const uint8_t *)utmp; - - const uint8_t * restrict q5 = x[i].qs; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - - lm_ggml_uint8x16x2_t qhbits = lm_ggml_vld1q_u8_x2(qh); - - lm_ggml_uint8x16x4_t q5h; - - int32_t sumi = 0; - - for (int j = 0; j < QK_K/64; ++j) { - - const lm_ggml_uint8x16x2_t q5bits = lm_ggml_vld1q_u8_x2(q5); q5 += 32; - const lm_ggml_int8x16x4_t q8bytes = lm_ggml_vld1q_s8_x4(q8); q8 += 64; - - q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); - q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); - q5h.val[2] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[0]), 3); - q5h.val[3] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[1]), 3); - qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 2); - qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 2); - - q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0])); - q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1])); - q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2])); - q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3])); - - sumi += vaddvq_s32(lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++; - sumi += vaddvq_s32(lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++; - } - - sumf += d * sumi - dmin * sumi_mins; - } - - *s = sumf; - -#elif defined __AVX2__ - - const __m256i m4 = _mm256_set1_epi8(0xF); - const __m128i mzero = _mm_setzero_si128(); - const __m256i mone = _mm256_set1_epi8(1); - - __m256 acc = _mm256_setzero_ps(); - - float summs = 0.f; - - for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q5 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin); - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); - - const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums); - const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); - const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s); - const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero); - summs += dmin * _mm_extract_epi32(hsum, 0); - - const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0); - const __m256i scales = MM256_SET_M128I(sc128, sc128); - - const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].qh); - __m256i hmask = mone; - - __m256i sumi = _mm256_setzero_si256(); - - int bit = 0; - - for (int j = 0; j < QK_K/64; ++j) { - - const __m256i scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0)); - const __m256i scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1)); - - const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32; - - const __m256i q5l_0 = _mm256_and_si256(q5bits, m4); - const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4); - const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0); - hmask = _mm256_slli_epi16(hmask, 1); - - const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4); - const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4); - const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1); - hmask = _mm256_slli_epi16(hmask, 1); - - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - - __m256i p16_0 = _mm256_maddubs_epi16(q5_0, q8_0); - __m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1); - - p16_0 = _mm256_madd_epi16(scale_0, p16_0); - p16_1 = _mm256_madd_epi16(scale_1, p16_1); - - sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); - - } - - __m256 vd = _mm256_set1_ps(d); - acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc); - - } - - *s = hsum_float_8(acc) + summs; - -#elif defined __AVX__ - - const __m128i m4 = _mm_set1_epi8(0xF); - const __m128i mzero = _mm_setzero_si128(); - const __m128i mone = _mm_set1_epi8(1); - const __m128i m2 = _mm_set1_epi8(2); - - __m256 acc = _mm256_setzero_ps(); - - float summs = 0.f; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin); - - const uint8_t * restrict q5 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]); - const __m128i scales = _mm_cvtepu8_epi16(utmps); - const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps)); - - const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]); - const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]); - const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1); - const __m128i prod = _mm_madd_epi16(mins, q8s); - const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero); - summs += dmin * _mm_extract_epi32(hsum, 0); - - const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].qh[0]); - const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].qh[16]); - __m128i hmask = mone; - - __m128i sumi_0 = _mm_setzero_si128(); - __m128i sumi_1 = _mm_setzero_si128(); - - int bit = 0; - - __m128i shuffle = _mm_set1_epi16(0x0100); - for (int j = 0; j < QK_K/64; ++j) { - - const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi16(shuffle, m2); - const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi16(shuffle, m2); - - const __m128i q5bits_0 = _mm_loadu_si128((const __m128i*)q5); q5 += 16; - const __m128i q5bits_1 = _mm_loadu_si128((const __m128i*)q5); q5 += 16; - - __m128i q5l_0 = _mm_and_si128(q5bits_0, m4); - __m128i q5l_1 = _mm_and_si128(q5bits_1, m4); - __m128i q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4); - __m128i q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4); - __m128i q5_0 = _mm_add_epi8(q5l_0, q5h_0); - __m128i q5_1 = _mm_add_epi8(q5l_1, q5h_1); - hmask = _mm_slli_epi16(hmask, 1); - - __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - __m128i p16_0 = _mm_maddubs_epi16(q5_0, q8_0); - __m128i p16_1 = _mm_maddubs_epi16(q5_1, q8_1); - p16_0 = _mm_madd_epi16(scale_0, p16_0); - p16_1 = _mm_madd_epi16(scale_0, p16_1); - - q5l_0 = _mm_and_si128(_mm_srli_epi16(q5bits_0, 4), m4); - q5l_1 = _mm_and_si128(_mm_srli_epi16(q5bits_1, 4), m4); - q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4); - q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4); - q5_0 = _mm_add_epi8(q5l_0, q5h_0); - q5_1 = _mm_add_epi8(q5l_1, q5h_1); - hmask = _mm_slli_epi16(hmask, 1); - - q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - __m128i p16_2 = _mm_maddubs_epi16(q5_0, q8_0); - __m128i p16_3 = _mm_maddubs_epi16(q5_1, q8_1); - p16_2 = _mm_madd_epi16(scale_1, p16_2); - p16_3 = _mm_madd_epi16(scale_1, p16_3); - - sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); - sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); - - } - - __m256 vd = _mm256_set1_ps(d); - __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); - acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc); - - } - - *s = hsum_float_8(acc) + summs; - -#elif defined __riscv_v_intrinsic - - const uint8_t * scales = (const uint8_t*)&utmp[0]; - const uint8_t * mins = (const uint8_t*)&utmp[2]; - - float sumf = 0; - float sums = 0.0; - - size_t vl; - - for (int i = 0; i < nb; ++i) { - - vl = 8; - - const uint8_t * restrict q5 = x[i].qs; - const uint8_t * restrict hm = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - - const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const float dmin = LM_GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; - - vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); - vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); - vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); - vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); - vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); - - vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); - sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); - - vl = 32; - int32_t aux32 = 0; - int is = 0; - - uint8_t m = 1; - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m1_t vqh = __riscv_vle8_v_u8m1(hm, vl); - - for (int j = 0; j < QK_K/64; ++j) { - // load Q5 and Q8 - vuint8m1_t q5_x = __riscv_vle8_v_u8m1(q5, vl); - vint8m1_t q8_y1 = __riscv_vle8_v_i8m1(q8, vl); - vint8m1_t q8_y2 = __riscv_vle8_v_i8m1(q8+32, vl); - - // compute mask for addition - vint8m1_t q5_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q5_x, 0x0F, vl)); - vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_1 = __riscv_vmsne_vx_u8m1_b8(qh_m1, 0, vl); - vint8m1_t q5_m1 = __riscv_vadd_vx_i8m1_mu(vmask_1, q5_a, q5_a, 16, vl); - m <<= 1; - - vint8m1_t q5_l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q5_x, 0x04, vl)); - vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_2 = __riscv_vmsne_vx_u8m1_b8(qh_m2, 0, vl); - vint8m1_t q5_m2 = __riscv_vadd_vx_i8m1_mu(vmask_2, q5_l, q5_l, 16, vl); - m <<= 1; - - vint16m2_t v0 = __riscv_vwmul_vv_i16m2(q5_m1, q8_y1, vl); - vint16m2_t v1 = __riscv_vwmul_vv_i16m2(q5_m2, q8_y2, vl); - - vint32m4_t vs1 = __riscv_vwmul_vx_i32m4(v0, scales[is++], vl); - vint32m4_t vs2 = __riscv_vwmul_vx_i32m4(v1, scales[is++], vl); - - vint32m1_t vacc1 = __riscv_vredsum_vs_i32m4_i32m1(vs1, vzero, vl); - vint32m1_t vacc2 = __riscv_vredsum_vs_i32m4_i32m1(vs2, vzero, vl); - - aux32 += __riscv_vmv_x_s_i32m1_i32(vacc1) + __riscv_vmv_x_s_i32m1_i32(vacc2); - q5 += 32; q8 += 64; - - } - - vfloat32m1_t vaux = __riscv_vfmul_vf_f32m1(__riscv_vfmv_v_f_f32m1(aux32, 1), d, 1); - sums += __riscv_vfmv_f_s_f32m1_f32(vaux); - - } - - *s = sumf+sums; - -#elif defined(__POWER9_VECTOR__) - const vector signed char lowMask = vec_splats((signed char)0xF); - const vector signed char lowMask1 = vec_splats((int8_t)0x3f); - const vector signed char lowMask2 = vec_splats((int8_t)0x30); - const vector int v0 = vec_splats((int32_t)0); - const vector unsigned char v1 = vec_splats((unsigned char)0x1); - const vector unsigned char v2 = vec_splats((unsigned char)0x2); - const vector unsigned char v3 = vec_splats((unsigned char)0x3); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - - for (int i = 0; i < nb; ++i) { - vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(y[i].d); - vector float vd = vec_mul(vxd, vyd); - - vector float vxmin = vec_splats(LM_GGML_FP16_TO_FP32(x[i].dmin)); - vector float vdmin = vec_mul(vxmin, vyd); - - UNUSED(kmask1); - UNUSED(kmask2); - UNUSED(kmask3); - UNUSED(utmp); - - vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8); - vector signed char u1 = vec_and(vec_sr(u0, v2), lowMask2); - vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4); - vector signed char u3 = vec_sr(u2, v4); - - vector signed char u30 = u1; - vector signed char u31 = (vector signed char)vec_mergeh((vector signed int)vec_and(u2, lowMask), (vector signed int)u3); - - u1 = vec_and(u0, lowMask1); - u2 = vec_or(u30, u31); - - vector signed char utmps = (vector signed char)vec_mergeh((vector signed int)u1, (vector signed int)u2); - - vector signed short q8ysums0 = vec_xl( 0, y[i].bsums); - vector signed short q8ysums1 = vec_xl(16, y[i].bsums); - - vector signed short vscales = vec_unpackh(utmps); - - vector signed short q5xmins = vec_unpackl(utmps); - vector signed short q5xmins0 = vec_mergeh(q5xmins, q5xmins); - vector signed short q5xmins1 = vec_mergel(q5xmins, q5xmins); - - vector signed int prod0 = vec_mule(q5xmins0, q8ysums0); - vector signed int prod1 = vec_mule(q5xmins1, q8ysums1); - vector signed int prod2 = vec_mulo(q5xmins0, q8ysums0); - vector signed int prod3 = vec_mulo(q5xmins1, q8ysums1); - - vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0); - vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1); - vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2); - vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3); - - vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].qh); - vector signed char qxhs1 = (vector signed char)vec_xl(16, x[i].qh); - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - vector signed int vsumi2 = v0; - vector signed int vsumi3 = v0; - - const uint8_t * restrict q5 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - for (int j = 0; j < QK_K/64; ++j) { - __builtin_prefetch(q5, 0, 1); - __builtin_prefetch(q8, 0, 1); - - vector signed char qxs0 = (vector signed char)vec_xl( 0, q5); - vector signed char qxs1 = (vector signed char)vec_xl(16, q5); - q5 += 32; - - vector signed char qxs00 = vec_and(qxs0, lowMask); - vector signed char qxs01 = vec_sr(qxs0, v4); - vector signed char qxs10 = vec_and(qxs1, lowMask); - vector signed char qxs11 = vec_sr(qxs1, v4); - - vector signed char q5h00 = vec_sl(vec_and((vector signed char)v1, qxhs0), v4); - vector signed char q5h01 = vec_sl(vec_and((vector signed char)v2, qxhs0), v3); - vector signed char q5h10 = vec_sl(vec_and((vector signed char)v1, qxhs1), v4); - vector signed char q5h11 = vec_sl(vec_and((vector signed char)v2, qxhs1), v3); - qxhs0 = vec_sr(qxhs0, v2); - qxhs1 = vec_sr(qxhs1, v2); - - vector unsigned char q5x00 = (vector unsigned char)vec_or(q5h00, qxs00); - vector unsigned char q5x01 = (vector unsigned char)vec_or(q5h01, qxs01); - vector unsigned char q5x10 = (vector unsigned char)vec_or(q5h10, qxs10); - vector unsigned char q5x11 = (vector unsigned char)vec_or(q5h11, qxs11); - - vector signed char q8y00 = vec_xl( 0, q8); - vector signed char q8y10 = vec_xl(16, q8); - vector signed char q8y01 = vec_xl(32, q8); - vector signed char q8y11 = vec_xl(48, q8); - q8 += 64; - - vector signed int qv00 = vec_msum(q8y00, q5x00, v0); - vector signed int qv01 = vec_msum(q8y01, q5x01, v0); - vector signed int qv10 = vec_msum(q8y10, q5x10, v0); - vector signed int qv11 = vec_msum(q8y11, q5x11, v0); - - vector signed int vscales_h = vec_unpackh(vscales); - vector signed int vs0 = vec_splat(vscales_h, 0); - vector signed int vs1 = vec_splat(vscales_h, 1); - vscales = vec_sld(vscales, vscales, 12); - - vsumi0 = vec_add(vec_mul(qv00, vs0), vsumi0); - vsumi1 = vec_add(vec_mul(qv10, vs0), vsumi1); - vsumi2 = vec_add(vec_mul(qv01, vs1), vsumi2); - vsumi3 = vec_add(vec_mul(qv11, vs1), vsumi3); - } - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = vec_extract(vsumf0, 0); - -#elif defined __loongarch_asx - LM_GGML_UNUSED(kmask1); - LM_GGML_UNUSED(kmask2); - LM_GGML_UNUSED(kmask3); - - const __m256i m4 = __lasx_xvreplgr2vr_b(0xF); - const __m128i mzero = __lsx_vldi(0); - const __m256i mone = __lasx_xvreplgr2vr_b(1); - - __m256 acc = (__m256)__lasx_xvldi(0); - - float summs = 0.f; - - for (int i = 0; i < nb; ++i) { - - const uint8_t * restrict q5 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin); - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0])); - - const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0); - const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1)); - const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s); - const __m128i hsum = lsx_hadd_w(lsx_hadd_w(prod, mzero), mzero); - summs += dmin * __lsx_vpickve2gr_w(hsum, 0); //TODO check - - const __m128i sc128 = lasx_extracti128(mins_and_scales, 0); - const __m256i scales = lasx_insertf128(sc128, sc128); - - const __m256i hbits = __lasx_xvld((const __m256i*)x[i].qh, 0); - __m256i hmask = mone; - - __m256i sumi = __lasx_xvldi(0); - - int bit = 0; - __m256i xvbit; - - for (int j = 0; j < QK_K/64; ++j) { - - const __m256i scale_0 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0)); - const __m256i scale_1 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1)); - - const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0); q5 += 32; - - xvbit = __lasx_xvreplgr2vr_h(bit++); - const __m256i q5l_0 = __lasx_xvand_v(q5bits, m4); - const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4); - const __m256i q5_0 = __lasx_xvadd_b(q5l_0, q5h_0); - hmask = __lasx_xvslli_h(hmask, 1); - - xvbit = __lasx_xvreplgr2vr_h(bit++); - const __m256i q5l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q5bits, 4), m4); - const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4); - const __m256i q5_1 = __lasx_xvadd_b(q5l_1, q5h_1); - hmask = __lasx_xvslli_h(hmask, 1); - - const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - - __m256i p16_0 = lasx_maddubs_h(q5_0, q8_0); - __m256i p16_1 = lasx_maddubs_h(q5_1, q8_1); - - p16_0 = lasx_madd_h(scale_0, p16_0); - p16_1 = lasx_madd_h(scale_1, p16_1); - - sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1)); - - } - - __m256 vd = __lasx_xvreplfr2vr_s(d); - acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc); - - } - - *s = hsum_float_8(acc) + summs; - -#else - - const uint8_t * scales = (const uint8_t*)&utmp[0]; - const uint8_t * mins = (const uint8_t*)&utmp[2]; - - int8_t aux8[QK_K]; - int16_t aux16[8]; - float sums [8]; - int32_t aux32[8]; - memset(sums, 0, 8*sizeof(float)); - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q4 = x[i].qs; - const uint8_t * restrict hm = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - memset(aux32, 0, 8*sizeof(int32_t)); - int8_t * restrict a = aux8; - uint8_t m = 1; - for (int j = 0; j < QK_K/64; ++j) { - for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); - for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); - a += 32; m <<= 1; - for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); - for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); - a += 32; m <<= 1; - q4 += 32; - } - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - int sumi = 0; - for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; - a = aux8; - int is = 0; - for (int j = 0; j < QK_K/32; ++j) { - int32_t scale = scales[is++]; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - } - const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; - for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; - const float dmin = LM_GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; - sumf -= dmin * sumi; - } - for (int l = 0; l < 8; ++l) sumf += sums[l]; - *s = sumf; -#endif -} - -void lm_ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_q6_K * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - -#ifdef __ARM_NEON - float sum = 0; - - const uint8x16_t m4b = vdupq_n_u8(0xF); - const int32x4_t vzero = vdupq_n_s32(0); - //const int8x16_t m32s = vdupq_n_s8(32); - - const uint8x16_t mone = vdupq_n_u8(3); - - lm_ggml_int8x16x4_t q6bytes; - lm_ggml_uint8x16x4_t q6h; - - for (int i = 0; i < nb; ++i) { - - const float d_all = LM_GGML_FP16_TO_FP32(x[i].d); - - const uint8_t * restrict q6 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - - const int8_t * restrict scale = x[i].scales; - - const lm_ggml_int16x8x2_t q8sums = lm_ggml_vld1q_s16_x2(y[i].bsums); - const int8x16_t scales = vld1q_s8(scale); - const lm_ggml_int16x8x2_t q6scales = {{vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}}; - - const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])), - vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))), - vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[1]), vget_low_s16 (q6scales.val[1])), - vmull_s16(vget_high_s16(q8sums.val[1]), vget_high_s16(q6scales.val[1])))); - int32_t isum_mins = vaddvq_s32(prod); - - int32_t isum = 0; - - for (int j = 0; j < QK_K/128; ++j) { - - lm_ggml_uint8x16x2_t qhbits = lm_ggml_vld1q_u8_x2(qh); qh += 32; - lm_ggml_uint8x16x4_t q6bits = lm_ggml_vld1q_u8_x4(q6); q6 += 64; - lm_ggml_int8x16x4_t q8bytes = lm_ggml_vld1q_s8_x4(q8); q8 += 64; - - q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); - q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); - uint8x16_t shifted = vshrq_n_u8(qhbits.val[0], 2); - q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - shifted = vshrq_n_u8(qhbits.val[1], 2); - q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - - //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s); - //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s); - //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])), m32s); - //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])), m32s); - q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])); - q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])); - q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])); - q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])); - - isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + - vaddvq_s32(lm_ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + - vaddvq_s32(lm_ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + - vaddvq_s32(lm_ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; - - scale += 4; - - q8bytes = lm_ggml_vld1q_s8_x4(q8); q8 += 64; - - shifted = vshrq_n_u8(qhbits.val[0], 4); - q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - shifted = vshrq_n_u8(qhbits.val[1], 4); - q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - shifted = vshrq_n_u8(qhbits.val[0], 6); - q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - shifted = vshrq_n_u8(qhbits.val[1], 6); - q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - - //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])), m32s); - //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])), m32s); - //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])), m32s); - //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])), m32s); - q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])); - q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])); - q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])); - q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])); - - isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + - vaddvq_s32(lm_ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + - vaddvq_s32(lm_ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + - vaddvq_s32(lm_ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; - scale += 4; - } - //sum += isum * d_all * y[i].d; - sum += d_all * y[i].d * (isum - 32 * isum_mins); - - } - *s = sum; - -#elif defined __AVX2__ - - const __m256i m4 = _mm256_set1_epi8(0xF); - const __m256i m2 = _mm256_set1_epi8(3); - const __m256i m32s = _mm256_set1_epi8(32); - - __m256 acc = _mm256_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); - - const uint8_t * restrict q4 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - - const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); - - __m256i sumi = _mm256_setzero_si256(); - - int is = 0; - - for (int j = 0; j < QK_K/128; ++j) { - - const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0)); - const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1)); - const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2)); - const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3)); - is += 4; - - const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; - const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; - const __m256i q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32; - - const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4); - const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4); - const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4); - const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4); - - const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0); - const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1); - const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2); - const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3); - - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - - __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0); - __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1); - __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2); - __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3); - - __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0); - __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1); - __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2); - __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3); - - p16_0 = _mm256_sub_epi16(p16_0, q8s_0); - p16_1 = _mm256_sub_epi16(p16_1, q8s_1); - p16_2 = _mm256_sub_epi16(p16_2, q8s_2); - p16_3 = _mm256_sub_epi16(p16_3, q8s_3); - - p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0); - p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1); - p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2); - p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3); - - sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); - sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3)); - - } - - acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); - } - - *s = hsum_float_8(acc); - -#elif defined __AVX__ - - const __m128i m4 = _mm_set1_epi8(0xF); - const __m128i m3 = _mm_set1_epi8(3); - const __m128i m32s = _mm_set1_epi8(32); - const __m128i m2 = _mm_set1_epi8(2); - - __m256 acc = _mm256_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); - - const uint8_t * restrict q4 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - - const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); - - __m128i sumi_0 = _mm_setzero_si128(); - __m128i sumi_1 = _mm_setzero_si128(); - - __m128i shuffle = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000); - for (int j = 0; j < QK_K/128; ++j) { - - const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16; - const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i*)qh); qh += 16; - - const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4); - const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4); - const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 2), m3), 4); - const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 2), m3), 4); - const __m128i q4h_4 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 4), m3), 4); - const __m128i q4h_5 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 4), m3), 4); - const __m128i q4h_6 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 6), m3), 4); - const __m128i q4h_7 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 6), m3), 4); - - const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; - const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; - const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; - const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; - - const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m4), q4h_0); - const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m4), q4h_1); - const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m4), q4h_2); - const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m4), q4h_3); - const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m4), q4h_4); - const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m4), q4h_5); - const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m4), q4h_6); - const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m4), q4h_7); - - const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - - __m128i q8s_0 = _mm_maddubs_epi16(m32s, q8_0); - __m128i q8s_1 = _mm_maddubs_epi16(m32s, q8_1); - __m128i q8s_2 = _mm_maddubs_epi16(m32s, q8_2); - __m128i q8s_3 = _mm_maddubs_epi16(m32s, q8_3); - __m128i q8s_4 = _mm_maddubs_epi16(m32s, q8_4); - __m128i q8s_5 = _mm_maddubs_epi16(m32s, q8_5); - __m128i q8s_6 = _mm_maddubs_epi16(m32s, q8_6); - __m128i q8s_7 = _mm_maddubs_epi16(m32s, q8_7); - - __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0); - __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1); - __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2); - __m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3); - __m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4); - __m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5); - __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6); - __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7); - - p16_0 = _mm_sub_epi16(p16_0, q8s_0); - p16_1 = _mm_sub_epi16(p16_1, q8s_1); - p16_2 = _mm_sub_epi16(p16_2, q8s_2); - p16_3 = _mm_sub_epi16(p16_3, q8s_3); - p16_4 = _mm_sub_epi16(p16_4, q8s_4); - p16_5 = _mm_sub_epi16(p16_5, q8s_5); - p16_6 = _mm_sub_epi16(p16_6, q8s_6); - p16_7 = _mm_sub_epi16(p16_7, q8s_7); - - const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi8(shuffle, m2); - const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi8(shuffle, m2); - const __m128i scale_2 = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi8(shuffle, m2); - const __m128i scale_3 = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi8(shuffle, m2); - - p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0); - p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1); - p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2); - p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3); - p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4); - p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_2, scale_2)), p16_5); - p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6); - p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_3, scale_3)), p16_7); - - sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); - sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); - sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6)); - sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7)); - - } - - __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc); - } - - *s = hsum_float_8(acc); - -#elif defined __riscv_v_intrinsic - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - - const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; - - const uint8_t * restrict q6 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - - const int8_t * restrict scale = x[i].scales; - - size_t vl; - - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - - int sum_t = 0; - int is = 0; - - for (int j = 0; j < QK_K/128; ++j) { - - vl = 32; - - // load qh - vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl); - - // load Q6 - vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl); - vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl); - - vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl); - vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl); - vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl); - vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl); - - vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl); - vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl); - vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl); - vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl); - - vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl); - vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl); - vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl); - vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl); - - vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl); - vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl); - vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl); - vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl); - - // load Q8 and take product - vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl); - vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl); - vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl); - vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl); - - vl = 16; - - vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl); - vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl); - vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl); - vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl); - vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl); - vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl); - vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl); - vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl); - - vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); - vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl); - vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl); - - sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); - - q6 += 64; qh += 32; q8 += 128; is=8; - - } - - sumf += d * sum_t; - - } - - *s = sumf; - -#elif defined(__POWER9_VECTOR__) - const vector signed char lowMask = vec_splats((signed char)0xF); - const vector int v0 = vec_splats((int32_t)0); - const vector unsigned char v2 = vec_splats((unsigned char)0x2); - const vector unsigned char v3 = vec_splats((unsigned char)0x3); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - const vector unsigned char v6 = vec_splats((unsigned char)0x6); - const vector signed char off = vec_splats((signed char)0x20); - - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - - for (int i = 0; i < nb; ++i) { - vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(y[i].d); - vector float vd = vec_mul(vxd, vyd); - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - vector signed int vsumi2 = v0; - vector signed int vsumi3 = v0; - vector signed int vsumi4 = v0; - vector signed int vsumi5 = v0; - vector signed int vsumi6 = v0; - vector signed int vsumi7 = v0; - - const uint8_t * restrict q6 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict qs = x[i].scales; - const int8_t * restrict q8 = y[i].qs; - - for (int j = 0; j < QK_K/128; ++j) { - __builtin_prefetch(q6, 0, 0); - __builtin_prefetch(qh, 0, 0); - __builtin_prefetch(q8, 0, 0); - - vector signed char qxs0 = (vector signed char)vec_xl( 0, q6); - vector signed char qxs1 = (vector signed char)vec_xl(16, q6); - vector signed char qxs2 = (vector signed char)vec_xl(32, q6); - vector signed char qxs3 = (vector signed char)vec_xl(48, q6); - q6 += 64; - - vector signed char qxs00 = vec_and(qxs0, lowMask); - vector signed char qxs01 = vec_sr(qxs0, v4); - vector signed char qxs10 = vec_and(qxs1, lowMask); - vector signed char qxs11 = vec_sr(qxs1, v4); - vector signed char qxs20 = vec_and(qxs2, lowMask); - vector signed char qxs21 = vec_sr(qxs2, v4); - vector signed char qxs30 = vec_and(qxs3, lowMask); - vector signed char qxs31 = vec_sr(qxs3, v4); - - vector signed char qxhs0 = (vector signed char)vec_xl( 0, qh); - vector signed char qxhs1 = (vector signed char)vec_xl(16, qh); - qh += 32; - - vector signed char qxh00 = vec_sl(vec_and((vector signed char)v3, qxhs0), v4); - vector signed char qxh01 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v4)), v4); - vector signed char qxh10 = vec_sl(vec_and((vector signed char)v3, qxhs1), v4); - vector signed char qxh11 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v4)), v4); - vector signed char qxh20 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v2)), v4); - vector signed char qxh21 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v6)), v4); - vector signed char qxh30 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v2)), v4); - vector signed char qxh31 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v6)), v4); - - vector signed char q6x00 = vec_sub(vec_or(qxh00, qxs00), off); - vector signed char q6x01 = vec_sub(vec_or(qxh01, qxs01), off); - vector signed char q6x10 = vec_sub(vec_or(qxh10, qxs10), off); - vector signed char q6x11 = vec_sub(vec_or(qxh11, qxs11), off); - vector signed char q6x20 = vec_sub(vec_or(qxh20, qxs20), off); - vector signed char q6x21 = vec_sub(vec_or(qxh21, qxs21), off); - vector signed char q6x30 = vec_sub(vec_or(qxh30, qxs30), off); - vector signed char q6x31 = vec_sub(vec_or(qxh31, qxs31), off); - - vector signed char q8y00 = vec_xl( 0, q8); - vector signed char q8y10 = vec_xl( 16, q8); - vector signed char q8y20 = vec_xl( 32, q8); - vector signed char q8y30 = vec_xl( 48, q8); - vector signed char q8y01 = vec_xl( 64, q8); - vector signed char q8y11 = vec_xl( 80, q8); - vector signed char q8y21 = vec_xl( 96, q8); - vector signed char q8y31 = vec_xl(112, q8); - q8 += 128; - - vector signed short qv00 = vec_add(vec_mule(q6x00, q8y00), vec_mulo(q6x00, q8y00)); - vector signed short qv10 = vec_add(vec_mule(q6x10, q8y10), vec_mulo(q6x10, q8y10)); - vector signed short qv20 = vec_add(vec_mule(q6x20, q8y20), vec_mulo(q6x20, q8y20)); - vector signed short qv30 = vec_add(vec_mule(q6x30, q8y30), vec_mulo(q6x30, q8y30)); - vector signed short qv01 = vec_add(vec_mule(q6x01, q8y01), vec_mulo(q6x01, q8y01)); - vector signed short qv11 = vec_add(vec_mule(q6x11, q8y11), vec_mulo(q6x11, q8y11)); - vector signed short qv21 = vec_add(vec_mule(q6x21, q8y21), vec_mulo(q6x21, q8y21)); - vector signed short qv31 = vec_add(vec_mule(q6x31, q8y31), vec_mulo(q6x31, q8y31)); - - vector signed short vscales = vec_unpackh(vec_xl_len(qs, 8)); - qs += 8; - - vector signed short vs0 = vec_splat(vscales, 0); - vector signed short vs1 = vec_splat(vscales, 1); - vector signed short vs2 = vec_splat(vscales, 2); - vector signed short vs3 = vec_splat(vscales, 3); - vector signed short vs4 = vec_splat(vscales, 4); - vector signed short vs5 = vec_splat(vscales, 5); - vector signed short vs6 = vec_splat(vscales, 6); - vector signed short vs7 = vec_splat(vscales, 7); - - vsumi0 = vec_msum(qv00, vs0, vsumi0); - vsumi1 = vec_msum(qv01, vs4, vsumi1); - vsumi2 = vec_msum(qv10, vs1, vsumi2); - vsumi3 = vec_msum(qv11, vs5, vsumi3); - vsumi4 = vec_msum(qv20, vs2, vsumi4); - vsumi5 = vec_msum(qv21, vs6, vsumi5); - vsumi6 = vec_msum(qv30, vs3, vsumi6); - vsumi7 = vec_msum(qv31, vs7, vsumi7); - } - - vsumi0 = vec_add(vsumi0, vsumi4); - vsumi1 = vec_add(vsumi1, vsumi5); - vsumi2 = vec_add(vsumi2, vsumi6); - vsumi3 = vec_add(vsumi3, vsumi7); - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = vec_extract(vsumf0, 0); - -#elif defined __loongarch_asx - - const __m256i m4 = __lasx_xvreplgr2vr_b(0xF); - const __m256i m2 = __lasx_xvreplgr2vr_b(3); - const __m256i m32s = __lasx_xvreplgr2vr_b(32); - - __m256 acc = (__m256)__lasx_xvldi(0); - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); - - const uint8_t * restrict q4 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - - const __m128i scales = __lsx_vld((const __m128i*)x[i].scales, 0); - - __m256i sumi = __lasx_xvldi(0); - - int is = 0; - - for (int j = 0; j < QK_K/128; ++j) { - - const __m128i scale_0 = lsx_shuffle_b(scales, get_scale_shuffle(is + 0)); - const __m128i scale_1 = lsx_shuffle_b(scales, get_scale_shuffle(is + 1)); - const __m128i scale_2 = lsx_shuffle_b(scales, get_scale_shuffle(is + 2)); - const __m128i scale_3 = lsx_shuffle_b(scales, get_scale_shuffle(is + 3)); - is += 4; - - const __m256i q4bits1 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32; - const __m256i q4bits2 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32; - const __m256i q4bitsH = __lasx_xvld((const __m256i*)qh, 0); qh += 32; - - const __m256i q4h_0 = __lasx_xvslli_h(__lasx_xvand_v(q4bitsH, m2), 4); - const __m256i q4h_1 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 2), m2), 4); - const __m256i q4h_2 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 4), m2), 4); - const __m256i q4h_3 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 6), m2), 4); - - const __m256i q4_0 = __lasx_xvor_v(__lasx_xvand_v(q4bits1, m4), q4h_0); - const __m256i q4_1 = __lasx_xvor_v(__lasx_xvand_v(q4bits2, m4), q4h_1); - const __m256i q4_2 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits1, 4), m4), q4h_2); - const __m256i q4_3 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits2, 4), m4), q4h_3); - - const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - - __m256i q8s_0 = lasx_maddubs_h(m32s, q8_0); - __m256i q8s_1 = lasx_maddubs_h(m32s, q8_1); - __m256i q8s_2 = lasx_maddubs_h(m32s, q8_2); - __m256i q8s_3 = lasx_maddubs_h(m32s, q8_3); - - __m256i p16_0 = lasx_maddubs_h(q4_0, q8_0); - __m256i p16_1 = lasx_maddubs_h(q4_1, q8_1); - __m256i p16_2 = lasx_maddubs_h(q4_2, q8_2); - __m256i p16_3 = lasx_maddubs_h(q4_3, q8_3); - - p16_0 = __lasx_xvsub_h(p16_0, q8s_0); - p16_1 = __lasx_xvsub_h(p16_1, q8s_1); - p16_2 = __lasx_xvsub_h(p16_2, q8s_2); - p16_3 = __lasx_xvsub_h(p16_3, q8s_3); - - p16_0 = lasx_madd_h(lasx_ext8_16(scale_0), p16_0); - p16_1 = lasx_madd_h(lasx_ext8_16(scale_1), p16_1); - p16_2 = lasx_madd_h(lasx_ext8_16(scale_2), p16_2); - p16_3 = lasx_madd_h(lasx_ext8_16(scale_3), p16_3); - - sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1)); - sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_2, p16_3)); - } - - acc = __lasx_xvfmadd_s((__m256)__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc); - } - - *s = hsum_float_8(acc); - -#else - - int8_t aux8[QK_K]; - int16_t aux16[8]; - float sums [8]; - int32_t aux32[8]; - memset(sums, 0, 8*sizeof(float)); - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q4 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - memset(aux32, 0, 8*sizeof(int32_t)); - int8_t * restrict a = aux8; - for (int j = 0; j < QK_K; j += 128) { - for (int l = 0; l < 32; ++l) { - a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; - a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; - a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; - a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; - } - a += 128; - q4 += 64; - qh += 32; - } - a = aux8; - int is = 0; - for (int j = 0; j < QK_K/16; ++j) { - int scale = x[i].scales[is++]; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - } - const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; - for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; - } - for (int l = 0; l < 8; ++l) sumf += sums[l]; - *s = sumf; -#endif -} - -#if defined (__AVX__) || defined (__AVX2__) || defined (__ARM_NEON) || defined (__POWER9_VECTOR__) || defined(__loongarch_asx) -static const int8_t keven_signs_q2xs[1024] = { - 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, - 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1, - 1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1, - 1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1, - 1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1, - 1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1, - 1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1, - 1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1, - 1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1, - 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1, - 1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1, - 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1, - 1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1, - 1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1, - 1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1, - 1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1, - 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1, - 1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1, - 1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1, - 1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1, - 1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1, - 1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1, - 1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, - 1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1, - 1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1, - 1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1, - 1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1, - 1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1, - 1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1, - 1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1, - 1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1, - 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -}; -#endif - -void lm_ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_iq2_xxs * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - -#if defined(__ARM_NEON) - - const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - - uint32_t aux32[4]; - const uint8_t * aux8 = (const uint8_t *)aux32; - - lm_ggml_int8x16x4_t q2u; - lm_ggml_int8x16x4_t q2s; - lm_ggml_int8x16x4_t q8b; - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint16_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - float sumf1 = 0, sumf2 = 0; - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - q8b = lm_ggml_vld1q_s8_x4(q8); q8 += 64; - memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8; - q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 0])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 1]))); - q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 2])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 3]))); - q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 8])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 9]))); - q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[10])), vld1_s8((const void *)(iq2xxs_grid + aux8[11]))); - q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127)))); - q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127)))); - q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 7) & 127)))); - q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 21) & 127)))); - q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]); - q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]); - q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]); - q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]); - const int32x4_t p1 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]), q2u.val[1], q8b.val[1]); - const int32x4_t p2 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]), q2u.val[3], q8b.val[3]); - sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[1] >> 28)); - sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[3] >> 28)); - } - sumf += d*(sumf1 + sumf2); - } - *s = 0.25f * sumf; - -#elif defined(__AVX2__) - - const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - - uint32_t aux32[4]; - const uint8_t * aux8 = (const uint8_t *)aux32; - - __m256 accumf = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { - const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint16_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - __m256i sumi1 = _mm256_setzero_si256(); - __m256i sumi2 = _mm256_setzero_si256(); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8; - const __m256i q2_1 = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]); - const __m256i q2_2 = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]); - const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127], - signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); - const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127], - signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]); - const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1); - const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2); - const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); - const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); - const uint16_t ls1 = aux32[1] >> 28; - const uint16_t ls2 = aux32[3] >> 28; - const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1)); - const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1)); - sumi1 = _mm256_add_epi32(sumi1, p1); - sumi2 = _mm256_add_epi32(sumi2, p2); - } - - accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); - - } - - *s = 0.125f * hsum_float_8(accumf); - -#elif defined(__AVX__) - const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - - uint32_t aux32[4]; - const uint8_t * aux8 = (const uint8_t *)aux32; - - __m256 accumf = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { - const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint16_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - __m128i sumi1_0 = _mm_setzero_si128(); - __m128i sumi1_1 = _mm_setzero_si128(); - __m128i sumi2_0 = _mm_setzero_si128(); - __m128i sumi2_1 = _mm_setzero_si128(); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8; - const __m128i q2_1_0 = _mm_set_epi64x(iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]); - const __m128i q2_1_1 = _mm_set_epi64x(iq2xxs_grid[aux8[3]], iq2xxs_grid[aux8[2]]); - const __m128i q2_2_0 = _mm_set_epi64x(iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]); - const __m128i q2_2_1 = _mm_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]]); - const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); - const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]); - const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]); - const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127]); - const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0); - const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1); - const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0); - const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1); - const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); - const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); - const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); - const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); - const uint16_t ls1 = aux32[1] >> 28; - const uint16_t ls2 = aux32[3] >> 28; - const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1)); - const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1)); - const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1)); - const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1)); - sumi1_0 = _mm_add_epi32(sumi1_0, p1_0); - sumi1_1 = _mm_add_epi32(sumi1_1, p1_1); - sumi2_0 = _mm_add_epi32(sumi2_0, p2_0); - sumi2_1 = _mm_add_epi32(sumi2_1, p2_1); - } - - accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); - - } - - *s = 0.125f * hsum_float_8(accumf); - -#elif defined(__POWER9_VECTOR__) - const vector int v0 = vec_splats((int32_t)0); - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - - const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - - for (int i = 0; i < nb; ++i) { - vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(y[i].d); - vector float vd = vec_mul(vxd, vyd); - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - vector signed int vsumi2 = v0; - vector signed int vsumi3 = v0; - - const uint16_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - for (int j = 0; j < QK_K/32; j += 2) { - __builtin_prefetch(q2, 0, 1); - __builtin_prefetch(q8, 0, 1); - - uint32_t aux32[4]; - const uint8_t * aux8 = (const uint8_t *)aux32; - - memcpy(aux32, q2, 4*sizeof(uint32_t)); - q2 += 8; - - vector signed long long aux64x2_0 = {*(const int64_t *)(iq2xxs_grid + aux8[ 0]), *(const int64_t *)(iq2xxs_grid + aux8[ 1])}; - vector signed long long aux64x2_1 = {*(const int64_t *)(iq2xxs_grid + aux8[ 2]), *(const int64_t *)(iq2xxs_grid + aux8[ 3])}; - vector signed long long aux64x2_2 = {*(const int64_t *)(iq2xxs_grid + aux8[ 8]), *(const int64_t *)(iq2xxs_grid + aux8[ 9])}; - vector signed long long aux64x2_3 = {*(const int64_t *)(iq2xxs_grid + aux8[10]), *(const int64_t *)(iq2xxs_grid + aux8[11])}; - - vector signed long long vsigns0 = {*(const int64_t *)(signs64 + ((aux32[1] >> 0) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 7) & 127))}; - vector signed long long vsigns1 = {*(const int64_t *)(signs64 + ((aux32[1] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 21) & 127))}; - vector signed long long vsigns2 = {*(const int64_t *)(signs64 + ((aux32[3] >> 0) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 7) & 127))}; - vector signed long long vsigns3 = {*(const int64_t *)(signs64 + ((aux32[3] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 21) & 127))}; - - vector signed char q2x0 = (vector signed char)vec_mul((vector signed char)vsigns0, (vector signed char)aux64x2_0); - vector signed char q2x1 = (vector signed char)vec_mul((vector signed char)vsigns1, (vector signed char)aux64x2_1); - vector signed char q2x2 = (vector signed char)vec_mul((vector signed char)vsigns2, (vector signed char)aux64x2_2); - vector signed char q2x3 = (vector signed char)vec_mul((vector signed char)vsigns3, (vector signed char)aux64x2_3); - - vector signed char q8y0 = vec_xl( 0, q8); - vector signed char q8y1 = vec_xl(16, q8); - vector signed char q8y2 = vec_xl(32, q8); - vector signed char q8y3 = vec_xl(48, q8); - q8 += 64; - - vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0)); - vector signed short qv1 = vec_add(vec_mule(q2x1, q8y1), vec_mulo(q2x1, q8y1)); - vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2)); - vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3)); - - const uint16_t ls0 = aux32[1] >> 28; - const uint16_t ls1 = aux32[3] >> 28; - - vector signed short vscales01 = vec_splats((int16_t)(2*ls0+1)); - vector signed short vscales23 = vec_splats((int16_t)(2*ls1+1)); - - vsumi0 = vec_msum(qv0, vscales01, vsumi0); - vsumi1 = vec_msum(qv1, vscales01, vsumi1); - vsumi2 = vec_msum(qv2, vscales23, vsumi2); - vsumi3 = vec_msum(qv3, vscales23, vsumi3); - } - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = 0.125f * vec_extract(vsumf0, 0); - -#elif defined(__loongarch_asx) - - const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - - uint32_t aux32[4]; - const uint8_t * aux8 = (const uint8_t *)aux32; - - __m256 accumf = (__m256)__lasx_xvldi(0); - for (int i = 0; i < nb; ++i) { - const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint16_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - __m256i sumi1 = __lasx_xvldi(0); - __m256i sumi2 = __lasx_xvldi(0); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8; - - const __m256i q2_1 = lasx_set_d(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]); - const __m256i q2_2 = lasx_set_d(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]); - const __m256i s2_1 = lasx_set_d(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127], - signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); - const __m256i s2_2 = lasx_set_d(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127], - signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]); - const __m256i q8s_1 = __lasx_xvsigncov_b(s2_1, q8_1); - const __m256i q8s_2 = __lasx_xvsigncov_b(s2_2, q8_2); - const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); - const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); - const uint16_t ls1 = aux32[1] >> 28; - const uint16_t ls2 = aux32[3] >> 28; - const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1)); - const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1)); - sumi1 = __lasx_xvadd_w(sumi1, p1); - sumi2 = __lasx_xvadd_w(sumi2, p2); - } - - accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf); - } - - *s = 0.125f * hsum_float_8(accumf); - -#else - - uint32_t aux32[2]; - const uint8_t * aux8 = (const uint8_t *)aux32; - - float sumf = 0.f; - for (int i = 0; i < nb; ++i) { - const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint16_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - int32_t bsum = 0; - for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { - memcpy(aux32, q2, 2*sizeof(uint32_t)); - q2 += 4; - const uint32_t ls = 2*(aux32[1] >> 28) + 1; - int32_t sumi = 0; - for (int l = 0; l < 4; ++l) { - const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]); - const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127]; - for (int j = 0; j < 8; ++j) { - sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); - } - q8 += 8; - } - bsum += sumi * ls; - } - sumf += d * bsum; - } - *s = 0.125f * sumf; -#endif -} - -void lm_ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_iq2_xs * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - -#if defined(__ARM_NEON) - - const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - - lm_ggml_int8x16x4_t q2u; - lm_ggml_int8x16x4_t q2s; - lm_ggml_int8x16x4_t q8b; - - int32x4x4_t scales32; - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint16_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - const uint8x8_t scales8 = vld1_u8(x[i].scales); - const uint8x8_t scales_l = vand_u8(scales8, vdup_n_u8(0xf)); - const uint8x8_t scales_h = vshr_n_u8(scales8, 4); - uint8x16_t scales = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h)); - scales = vaddq_u8(vshlq_n_u8(scales, 1), vdupq_n_u8(1)); - const uint16x8_t scales1 = vmovl_u8(vget_low_u8(scales)); - const uint16x8_t scales2 = vmovl_u8(vget_high_u8(scales)); - scales32.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales1))); - scales32.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales1))); - scales32.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales2))); - scales32.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales2))); - int32x4_t sumi = vdupq_n_s32(0); - for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { - q8b = lm_ggml_vld1q_s8_x4(q8); q8 += 64; - q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[0] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[1] & 511)))); - q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[2] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[3] & 511)))); - q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[4] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[5] & 511)))); - q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[6] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[7] & 511)))); - q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[0] >> 9))), vld1_s8((const void *)(signs64 + (q2[1] >> 9)))); - q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[2] >> 9))), vld1_s8((const void *)(signs64 + (q2[3] >> 9)))); - q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[4] >> 9))), vld1_s8((const void *)(signs64 + (q2[5] >> 9)))); - q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[6] >> 9))), vld1_s8((const void *)(signs64 + (q2[7] >> 9)))); - q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]); - q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]); - q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]); - q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]); - const int32x4_t p1 = lm_ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]); - const int32x4_t p2 = lm_ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[1], q8b.val[1]); - const int32x4_t p3 = lm_ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]); - const int32x4_t p4 = lm_ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[3], q8b.val[3]); - const int32x4_t p = vpaddq_s32(vpaddq_s32(p1, p2), vpaddq_s32(p3, p4)); - sumi = vmlaq_s32(sumi, p, scales32.val[ib64]); - q2 += 8; - } - sumf += d*vaddvq_s32(sumi); - } - *s = 0.125f * sumf; - -#elif defined(__AVX2__) - - const __m256i mone = _mm256_set1_epi8(1); - static const char block_sign_shuffle_mask_1[32] = { - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, - 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, - }; - static const char block_sign_shuffle_mask_2[32] = { - 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, - 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, - }; - static const uint8_t bit_selector_mask_bytes[32] = { - 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - }; - - const __m256i bit_selector_mask = _mm256_loadu_si256((const __m256i*)bit_selector_mask_bytes); - const __m256i block_sign_shuffle_1 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_1); - const __m256i block_sign_shuffle_2 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_2); - - static const uint8_t k_bit_helper[32] = { - 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, - 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, - }; - const __m256i bit_helper = _mm256_loadu_si256((const __m256i*)k_bit_helper); - const __m256i m511 = _mm256_set1_epi16(511); - const __m128i m4 = _mm_set1_epi8(0xf); - const __m128i m1 = _mm_set1_epi8(1); - - uint64_t aux64; - - // somewhat hacky, but gives a significant boost in performance - __m256i aux_gindex; - const uint16_t * gindex = (const uint16_t *)&aux_gindex; - - __m256 accumf = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { - const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint16_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - memcpy(&aux64, x[i].scales, 8); - __m128i stmp = _mm_set1_epi64x(aux64); - stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4)); - const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1); - - __m256i sumi1 = _mm256_setzero_si256(); - __m256i sumi2 = _mm256_setzero_si256(); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) { - - const __m256i q2_data = _mm256_loadu_si256((const __m256i*)q2); q2 += 16; - aux_gindex = _mm256_and_si256(q2_data, m511); - - const __m256i partial_sign_bits = _mm256_srli_epi16(q2_data, 9); - const __m256i partial_sign_bits_upper = _mm256_srli_epi16(q2_data, 13); - const __m256i partial_sign_bits_for_counting = _mm256_xor_si256(partial_sign_bits, partial_sign_bits_upper); - - const __m256i odd_bits = _mm256_shuffle_epi8(bit_helper, partial_sign_bits_for_counting); - const __m256i full_sign_bits = _mm256_or_si256(partial_sign_bits, odd_bits); - - const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i q8_3 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i q8_4 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - - const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]], - iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]); - const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]], - iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]); - const __m256i q2_3 = _mm256_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]], - iq2xs_grid[gindex[ 9]], iq2xs_grid[gindex[ 8]]); - const __m256i q2_4 = _mm256_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]], - iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]); - - const __m128i full_signs_l = _mm256_castsi256_si128(full_sign_bits); - const __m128i full_signs_h = _mm256_extractf128_si256(full_sign_bits, 1); - const __m256i full_signs_1 = MM256_SET_M128I(full_signs_l, full_signs_l); - const __m256i full_signs_2 = MM256_SET_M128I(full_signs_h, full_signs_h); - - __m256i signs; - signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_1); - signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); - const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone)); - - signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_2); - signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); - const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone)); - - signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_1); - signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); - const __m256i q8s_3 = _mm256_sign_epi8(q8_3, _mm256_or_si256(signs, mone)); - - signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_2); - signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); - const __m256i q8s_4 = _mm256_sign_epi8(q8_4, _mm256_or_si256(signs, mone)); - - const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); - const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); - const __m256i dot3 = _mm256_maddubs_epi16(q2_3, q8s_3); - const __m256i dot4 = _mm256_maddubs_epi16(q2_4, q8s_4); - - const __m256i sc1 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0))); - const __m256i sc2 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1))); - const __m256i sc3 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2))); - const __m256i sc4 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3))); - - sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot1, sc1)); - sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot2, sc2)); - sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot3, sc3)); - sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot4, sc4)); - } - - accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); - - } - - *s = 0.125f * hsum_float_8(accumf); - -#elif defined(__AVX__) - const __m128i mone = _mm_set1_epi8(1); - static const char block_sign_shuffle_mask_1[32] = { - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, - 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, - }; - static const char block_sign_shuffle_mask_2[32] = { - 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, - 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, - }; - static const uint8_t bit_selector_mask_bytes[32] = { - 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - }; - - const __m128i bit_selector_mask_0 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes); - const __m128i bit_selector_mask_1 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes + 1); - const __m128i block_sign_shuffle_1_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1); - const __m128i block_sign_shuffle_1_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1 + 1); - const __m128i block_sign_shuffle_2_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2); - const __m128i block_sign_shuffle_2_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2 + 1); - - static const uint8_t k_bit_helper[32] = { - 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, - 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, - }; - const __m128i bit_helper_0 = _mm_loadu_si128((const __m128i*)k_bit_helper); - const __m128i bit_helper_1 = _mm_loadu_si128((const __m128i*)k_bit_helper + 1); - const __m128i m511 = _mm_set1_epi16(511); - const __m128i m4 = _mm_set1_epi8(0xf); - const __m128i m1 = _mm_set1_epi8(1); - - uint64_t aux64; - - // somewhat hacky, but gives a significant boost in performance - __m256i aux_gindex; - const uint16_t * gindex = (const uint16_t *)&aux_gindex; - - __m256 accumf = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { - const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint16_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - memcpy(&aux64, x[i].scales, 8); - __m128i stmp = _mm_set1_epi64x(aux64); - stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4)); - const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1); - - __m128i sumi1_0 = _mm_setzero_si128(); - __m128i sumi1_1 = _mm_setzero_si128(); - __m128i sumi2_0 = _mm_setzero_si128(); - __m128i sumi2_1 = _mm_setzero_si128(); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) { - - const __m128i q2_data_0 = _mm_loadu_si128((const __m128i*)q2); - const __m128i q2_data_1 = _mm_loadu_si128((const __m128i*)q2 + 1); q2 += 16; - aux_gindex = MM256_SET_M128I(_mm_and_si128(q2_data_1, m511), _mm_and_si128(q2_data_0, m511)); - - const __m128i partial_sign_bits_0 = _mm_srli_epi16(q2_data_0, 9); - const __m128i partial_sign_bits_1 = _mm_srli_epi16(q2_data_1, 9); - const __m128i partial_sign_bits_upper_0 = _mm_srli_epi16(q2_data_0, 13); - const __m128i partial_sign_bits_upper_1 = _mm_srli_epi16(q2_data_1, 13); - const __m128i partial_sign_bits_for_counting_0 = _mm_xor_si128(partial_sign_bits_0, partial_sign_bits_upper_0); - const __m128i partial_sign_bits_for_counting_1 = _mm_xor_si128(partial_sign_bits_1, partial_sign_bits_upper_1); - - const __m128i odd_bits_0 = _mm_shuffle_epi8(bit_helper_0, partial_sign_bits_for_counting_0); - const __m128i odd_bits_1 = _mm_shuffle_epi8(bit_helper_1, partial_sign_bits_for_counting_1); - const __m128i full_sign_bits_0 = _mm_or_si128(partial_sign_bits_0, odd_bits_0); - const __m128i full_sign_bits_1 = _mm_or_si128(partial_sign_bits_1, odd_bits_1); - - const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_3_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_3_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_4_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_4_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - - const __m128i q2_1_0 = _mm_set_epi64x(iq2xs_grid[gindex[1]], iq2xs_grid[gindex[0]]); - const __m128i q2_1_1 = _mm_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]]); - const __m128i q2_2_0 = _mm_set_epi64x(iq2xs_grid[gindex[5]], iq2xs_grid[gindex[4]]); - const __m128i q2_2_1 = _mm_set_epi64x(iq2xs_grid[gindex[7]], iq2xs_grid[gindex[6]]); - const __m128i q2_3_0 = _mm_set_epi64x(iq2xs_grid[gindex[9]], iq2xs_grid[gindex[8]]); - const __m128i q2_3_1 = _mm_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]]); - const __m128i q2_4_0 = _mm_set_epi64x(iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]); - const __m128i q2_4_1 = _mm_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]]); - - // AVX2 full_signs_1 is full_sign_bits_0 here - // AVX2 full_signs_2 is full_sign_bits_1 here - __m128i signs_0, signs_1; - signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_0); - signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_1); - signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0); - signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1); - const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, _mm_or_si128(signs_0, mone)); - const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, _mm_or_si128(signs_1, mone)); - - signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_0); - signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_1); - signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0); - signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1); - const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, _mm_or_si128(signs_0, mone)); - const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, _mm_or_si128(signs_1, mone)); - - signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_0); - signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_1); - signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0); - signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1); - const __m128i q8s_3_0 = _mm_sign_epi8(q8_3_0, _mm_or_si128(signs_0, mone)); - const __m128i q8s_3_1 = _mm_sign_epi8(q8_3_1, _mm_or_si128(signs_1, mone)); - - signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_0); - signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_1); - signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0); - signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1); - const __m128i q8s_4_0 = _mm_sign_epi8(q8_4_0, _mm_or_si128(signs_0, mone)); - const __m128i q8s_4_1 = _mm_sign_epi8(q8_4_1, _mm_or_si128(signs_1, mone)); - - const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); - const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); - const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); - const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); - const __m128i dot3_0 = _mm_maddubs_epi16(q2_3_0, q8s_3_0); - const __m128i dot3_1 = _mm_maddubs_epi16(q2_3_1, q8s_3_1); - const __m128i dot4_0 = _mm_maddubs_epi16(q2_4_0, q8s_4_0); - const __m128i dot4_1 = _mm_maddubs_epi16(q2_4_1, q8s_4_1); - - __m128i sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0)); - const __m128i sc1_0 = _mm_cvtepi8_epi16(sc_tmp); - const __m128i sc1_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8)); - sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1)); - const __m128i sc2_0 = _mm_cvtepi8_epi16(sc_tmp); - const __m128i sc2_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8)); - sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2)); - const __m128i sc3_0 = _mm_cvtepi8_epi16(sc_tmp); - const __m128i sc3_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8)); - sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3)); - const __m128i sc4_0 = _mm_cvtepi8_epi16(sc_tmp); - const __m128i sc4_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8)); - - sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot1_0, sc1_0)); - sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot1_1, sc1_1)); - sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot2_0, sc2_0)); - sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot2_1, sc2_1)); - sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot3_0, sc3_0)); - sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot3_1, sc3_1)); - sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot4_0, sc4_0)); - sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot4_1, sc4_1)); - } - - accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); - - } - - *s = 0.125f * hsum_float_8(accumf); - -#elif defined(__loongarch_asx) - - const __m256i mone = __lasx_xvreplgr2vr_b(1); - static const char block_sign_shuffle_mask_1[32] = { - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, - 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, - }; - static const char block_sign_shuffle_mask_2[32] = { - 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, - 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, - }; - static const uint8_t bit_selector_mask_bytes[32] = { - 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - }; - - const __m256i bit_selector_mask = __lasx_xvld((const __m256i*)bit_selector_mask_bytes, 0); - const __m256i block_sign_shuffle_1 = __lasx_xvld((const __m256i*)block_sign_shuffle_mask_1, 0); - const __m256i block_sign_shuffle_2 = __lasx_xvld((const __m256i*)block_sign_shuffle_mask_2, 0); - - static const uint8_t k_bit_helper[32] = { - 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, - 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, - }; - const __m256i bit_helper = __lasx_xvld((const __m256i*)k_bit_helper, 0); - const __m256i m511 = __lasx_xvreplgr2vr_h(511); - const __m128i m4 = __lsx_vreplgr2vr_b(0xf); - const __m128i m1 = __lsx_vreplgr2vr_b(1); - - uint64_t aux64; - - // somewhat hacky, but gives a significant boost in performance - __m256i aux_gindex; - const uint16_t * gindex = (const uint16_t *)&aux_gindex; - - __m256 accumf = (__m256)__lasx_xvldi(0); - for (int i = 0; i < nb; ++i) { - const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint16_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - memcpy(&aux64, x[i].scales, 8); - __m128i stmp = __lsx_vreplgr2vr_d(aux64); - stmp = __lsx_vilvl_b( __lsx_vand_v(__lsx_vsrli_h(stmp, 4), m4), __lsx_vand_v(stmp, m4)); - const __m128i scales = __lsx_vadd_b(__lsx_vslli_h(stmp, 1), m1); - - __m256i sumi1 = __lasx_xvldi(0); - __m256i sumi2 = __lasx_xvldi(0); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) { - - const __m256i q2_data = __lasx_xvld((const __m256i*)q2, 0); q2 += 16; - aux_gindex = __lasx_xvand_v(q2_data, m511); - - const __m256i partial_sign_bits = __lasx_xvsrli_h(q2_data, 9); - const __m256i partial_sign_bits_upper = __lasx_xvsrli_h(q2_data, 13); - const __m256i partial_sign_bits_for_counting = __lasx_xvxor_v(partial_sign_bits, partial_sign_bits_upper); - - const __m256i odd_bits = lasx_shuffle_b(bit_helper, partial_sign_bits_for_counting); - const __m256i full_sign_bits = __lasx_xvor_v(partial_sign_bits, odd_bits); - - const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - const __m256i q8_3 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - const __m256i q8_4 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - - const __m256i q2_1 = lasx_set_d(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]], - iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]); - const __m256i q2_2 = lasx_set_d(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]], - iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]); - const __m256i q2_3 = lasx_set_d(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]], - iq2xs_grid[gindex[ 9]], iq2xs_grid[gindex[ 8]]); - const __m256i q2_4 = lasx_set_d(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]], - iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]); - - const __m128i full_signs_l = lasx_extracti128(full_sign_bits, 0); - const __m128i full_signs_h = lasx_extracti128(full_sign_bits, 1); - const __m256i full_signs_1 = lasx_insertf128(full_signs_l, full_signs_l); - const __m256i full_signs_2 = lasx_insertf128(full_signs_h, full_signs_h); - - __m256i signs; - signs = lasx_shuffle_b(full_signs_1, block_sign_shuffle_1); - signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask); - const __m256i q8s_1 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_1); - - signs = lasx_shuffle_b(full_signs_1, block_sign_shuffle_2); - signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask); - const __m256i q8s_2 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_2); - - signs = lasx_shuffle_b(full_signs_2, block_sign_shuffle_1); - signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask); - const __m256i q8s_3 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_3); - - signs = lasx_shuffle_b(full_signs_2, block_sign_shuffle_2); - signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask); - const __m256i q8s_4 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_4); - - const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); - const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); - const __m256i dot3 = lasx_maddubs_h(q2_3, q8s_3); - const __m256i dot4 = lasx_maddubs_h(q2_4, q8s_4); - - const __m256i sc1 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+0))); - const __m256i sc2 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+1))); - const __m256i sc3 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+2))); - const __m256i sc4 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+3))); - - sumi1 = __lasx_xvadd_w(sumi1, lasx_madd_h(dot1, sc1)); - sumi2 = __lasx_xvadd_w(sumi2, lasx_madd_h(dot2, sc2)); - sumi1 = __lasx_xvadd_w(sumi1, lasx_madd_h(dot3, sc3)); - sumi2 = __lasx_xvadd_w(sumi2, lasx_madd_h(dot4, sc4)); - } - - accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf); - - } - - *s = 0.125f * hsum_float_8(accumf); -#elif defined(__POWER9_VECTOR__) - const vector int v0 = vec_splats((int32_t)0); - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - - const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - - for (int i = 0; i < nb; ++i) { - vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(y[i].d); - vector float vd = vec_mul(vxd, vyd); - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - vector signed int vsumi2 = v0; - vector signed int vsumi3 = v0; - - const uint16_t * restrict q2 = x[i].qs; - const uint8_t * restrict sc = x[i].scales; - const int8_t * restrict q8 = y[i].qs; - - for (int j = 0; j < QK_K/64; ++j) { - __builtin_prefetch(q2, 0, 1); - __builtin_prefetch(q8, 0, 1); - - vector signed long long aux64x2_0 = {*(const int64_t *)(iq2xs_grid + (q2[0] & 511)), *(const int64_t *)(iq2xs_grid + (q2[1] & 511))}; - vector signed long long aux64x2_1 = {*(const int64_t *)(iq2xs_grid + (q2[2] & 511)), *(const int64_t *)(iq2xs_grid + (q2[3] & 511))}; - vector signed long long aux64x2_2 = {*(const int64_t *)(iq2xs_grid + (q2[4] & 511)), *(const int64_t *)(iq2xs_grid + (q2[5] & 511))}; - vector signed long long aux64x2_3 = {*(const int64_t *)(iq2xs_grid + (q2[6] & 511)), *(const int64_t *)(iq2xs_grid + (q2[7] & 511))}; - - vector signed long long vsigns0 = {*(const int64_t *)(signs64 + ((q2[0] >> 9))), *(const int64_t *)(signs64 + ((q2[1] >> 9)))}; - vector signed long long vsigns1 = {*(const int64_t *)(signs64 + ((q2[2] >> 9))), *(const int64_t *)(signs64 + ((q2[3] >> 9)))}; - vector signed long long vsigns2 = {*(const int64_t *)(signs64 + ((q2[4] >> 9))), *(const int64_t *)(signs64 + ((q2[5] >> 9)))}; - vector signed long long vsigns3 = {*(const int64_t *)(signs64 + ((q2[6] >> 9))), *(const int64_t *)(signs64 + ((q2[7] >> 9)))}; - q2 += 8; - - vector signed char q2x0 = (vector signed char)vec_mul((vector signed char)vsigns0, (vector signed char)aux64x2_0); - vector signed char q2x1 = (vector signed char)vec_mul((vector signed char)vsigns1, (vector signed char)aux64x2_1); - vector signed char q2x2 = (vector signed char)vec_mul((vector signed char)vsigns2, (vector signed char)aux64x2_2); - vector signed char q2x3 = (vector signed char)vec_mul((vector signed char)vsigns3, (vector signed char)aux64x2_3); - - vector signed char q8y0 = vec_xl( 0, q8); - vector signed char q8y1 = vec_xl(16, q8); - vector signed char q8y2 = vec_xl(32, q8); - vector signed char q8y3 = vec_xl(48, q8); - q8 += 64; - - vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0)); - vector signed short qv1 = vec_add(vec_mule(q2x1, q8y1), vec_mulo(q2x1, q8y1)); - vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2)); - vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3)); - - const uint16_t ls0 = (uint16_t)(sc[0] & 0xf); - const uint16_t ls1 = (uint16_t)(sc[0] >> 4); - const uint16_t ls2 = (uint16_t)(sc[1] & 0xf); - const uint16_t ls3 = (uint16_t)(sc[1] >> 4); - sc += 2; - - vector signed short vscales0 = vec_splats((int16_t)(2*ls0+1)); - vector signed short vscales1 = vec_splats((int16_t)(2*ls1+1)); - vector signed short vscales2 = vec_splats((int16_t)(2*ls2+1)); - vector signed short vscales3 = vec_splats((int16_t)(2*ls3+1)); - - vsumi0 = vec_msum(qv0, vscales0, vsumi0); - vsumi1 = vec_msum(qv1, vscales1, vsumi1); - vsumi2 = vec_msum(qv2, vscales2, vsumi2); - vsumi3 = vec_msum(qv3, vscales3, vsumi3); - } - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = 0.125f * vec_extract(vsumf0, 0); -#else - - float sumf = 0.f; - for (int i = 0; i < nb; ++i) { - const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint16_t * restrict q2 = x[i].qs; - const uint8_t * restrict sc = x[i].scales; - const int8_t * restrict q8 = y[i].qs; - int32_t bsum = 0; - for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { - const uint16_t ls1 = 2*(sc[ib32] & 0xf) + 1; - const uint16_t ls2 = 2*(sc[ib32] >> 4) + 1; - int32_t sumi = 0; - for (int l = 0; l < 2; ++l) { - const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511)); - const uint8_t signs = ksigns_iq2xs[q2[l] >> 9]; - for (int j = 0; j < 8; ++j) { - sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); - } - q8 += 8; - } - bsum += sumi * ls1; - sumi = 0; - for (int l = 2; l < 4; ++l) { - const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511)); - const uint8_t signs = ksigns_iq2xs[q2[l] >> 9]; - for (int j = 0; j < 8; ++j) { - sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); - } - q8 += 8; - } - bsum += sumi * ls2; - q2 += 4; - } - sumf += d * bsum; - } - *s = 0.125f * sumf; -#endif -} - -void lm_ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_iq2_s * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - -#if defined(__ARM_NEON) - - static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, - 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 - }; - - static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; - - const lm_ggml_uint8x16x2_t mask1 = lm_ggml_vld1q_u8_x2(k_mask1); - const uint8x16_t mask2 = vld1q_u8(k_mask2); - const uint8x16_t m1 = vdupq_n_u8(1); - const int32x4_t vzero = vdupq_n_s32(0); - - uint8x16x2_t vs; - lm_ggml_int8x16x4_t q2s; - lm_ggml_int8x16x4_t q8b; - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - - const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; - - const uint8_t * restrict qs = x[i].qs; - const uint8_t * restrict qh = x[i].qh; - const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8); - const int8_t * restrict q8 = y[i].qs; - - int sumi1 = 0, sumi2 = 0; - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - q8b = lm_ggml_vld1q_s8_x4(q8); q8 += 64; - q2s.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[0] | ((qh[ib32+0] << 8) & 0x300)))), - vld1_s8((const int8_t *)(iq2s_grid + (qs[1] | ((qh[ib32+0] << 6) & 0x300))))); - q2s.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[2] | ((qh[ib32+0] << 4) & 0x300)))), - vld1_s8((const int8_t *)(iq2s_grid + (qs[3] | ((qh[ib32+0] << 2) & 0x300))))); - q2s.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[4] | ((qh[ib32+1] << 8) & 0x300)))), - vld1_s8((const int8_t *)(iq2s_grid + (qs[5] | ((qh[ib32+1] << 6) & 0x300))))); - q2s.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[6] | ((qh[ib32+1] << 4) & 0x300)))), - vld1_s8((const int8_t *)(iq2s_grid + (qs[7] | ((qh[ib32+1] << 2) & 0x300))))); - qs += 8; - - vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16))); - vs.val[1] = vandq_u8(lm_ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); - vs.val[0] = vandq_u8(lm_ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); - vs.val[0] = vceqq_u8(vs.val[0], mask2); - vs.val[1] = vceqq_u8(vs.val[1], mask2); - - q2s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[0]); - q2s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[1]); - - vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16))); - vs.val[1] = vandq_u8(lm_ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); - vs.val[0] = vandq_u8(lm_ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); - vs.val[0] = vceqq_u8(vs.val[0], mask2); - vs.val[1] = vceqq_u8(vs.val[1], mask2); - - signs += 4; - - q2s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[2]); - q2s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[3]); - - const int32x4_t p1 = lm_ggml_vdotq_s32(vzero, q2s.val[0], q8b.val[0]); - const int32x4_t p2 = lm_ggml_vdotq_s32(vzero, q2s.val[1], q8b.val[1]); - const int32x4_t p3 = lm_ggml_vdotq_s32(vzero, q2s.val[2], q8b.val[2]); - const int32x4_t p4 = lm_ggml_vdotq_s32(vzero, q2s.val[3], q8b.val[3]); - - sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32+0] & 0xf)); - sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32+0] >> 4)); - sumi1 += vaddvq_s32(p3) * (1 + 2*(x[i].scales[ib32+1] & 0xf)); - sumi2 += vaddvq_s32(p4) * (1 + 2*(x[i].scales[ib32+1] >> 4)); - } - sumf += d*(sumi1 + sumi2); - } - - *s = 0.125f * sumf; - -#elif defined(__AVX2__) - - static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, - 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 - }; - - static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - }; - - const __m128i m4 = _mm_set1_epi8(0xf); - const __m128i m1 = _mm_set1_epi8(1); - - const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1); - const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2); - - uint64_t aux64; - - __m256 accumf = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { - const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * restrict qs = x[i].qs; - const uint8_t * restrict qh = x[i].qh; - const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8); - const int8_t * restrict q8 = y[i].qs; - - memcpy(&aux64, x[i].scales, 8); - const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1); - const __m256i scales16 = _mm256_cvtepi8_epi16(scales8); // 0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15 - - __m256i sumi1 = _mm256_setzero_si256(); - __m256i sumi2 = _mm256_setzero_si256(); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i q2_1 = _mm256_set_epi64x(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)], - iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)], - iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)], - iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]); - const __m256i q2_2 = _mm256_set_epi64x(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)], - iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)], - iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)], - iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]); - qs += 8; - - __m256i aux256 = _mm256_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16)); - aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); - const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2); - const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1); - - aux256 = _mm256_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16)); - aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); - const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2); - const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2); - - signs += 4; - - const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); // blocks 2*ib32+0, 2*ib32+1 - const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); // blocks 2*ib32+2, 2*ib32+3 - - const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_shuffle_epi8(scales16, get_scale_shuffle_k4(ib32+0))); - const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_shuffle_epi8(scales16, get_scale_shuffle_k4(ib32+1))); - sumi1 = _mm256_add_epi32(sumi1, p1); - sumi2 = _mm256_add_epi32(sumi2, p2); - } - - accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); - - } - - *s = 0.125f * hsum_float_8(accumf); - -#elif defined(__AVX__) - static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, - 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 - }; - - static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - }; - - const __m128i m4 = _mm_set1_epi8(0xf); - const __m128i m1 = _mm_set1_epi8(1); - - const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1); - const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1); - const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2); - const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1); - - uint64_t aux64; - - __m256 accumf = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { - const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * restrict qs = x[i].qs; - const uint8_t * restrict qh = x[i].qh; - const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8); - const int8_t * restrict q8 = y[i].qs; - - memcpy(&aux64, x[i].scales, 8); - const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1); - const __m128i scales16_0 = _mm_cvtepi8_epi16(scales8); - const __m128i scales16_1 = _mm_cvtepi8_epi16(_mm_srli_si128(scales8, 8)); - - __m128i sumi1_0 = _mm_setzero_si128(); - __m128i sumi1_1 = _mm_setzero_si128(); - __m128i sumi2_0 = _mm_setzero_si128(); - __m128i sumi2_1 = _mm_setzero_si128(); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q2_1_0 = _mm_set_epi64x(iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)], - iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]); - const __m128i q2_1_1 = _mm_set_epi64x(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)], - iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)]); - const __m128i q2_2_0 = _mm_set_epi64x(iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)], - iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]); - const __m128i q2_2_1 = _mm_set_epi64x(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)], - iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)]); - qs += 8; - - __m128i aux128_0 = _mm_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16)); - __m128i aux128_1 = aux128_0; - aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0); - aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1); - const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0); - const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1); - const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0); - const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1); - - aux128_0 = _mm_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16)); - aux128_1 = aux128_0; - aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0); - aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1); - const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0); - const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1); - const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0); - const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1); - - signs += 4; - - const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); - const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); - const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); - const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); - - const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 0))); - const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 1))); - const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 0))); - const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 1))); - sumi1_0 = _mm_add_epi32(sumi1_0, p1_0); - sumi1_1 = _mm_add_epi32(sumi1_1, p1_1); - sumi2_0 = _mm_add_epi32(sumi2_0, p2_0); - sumi2_1 = _mm_add_epi32(sumi2_1, p2_1); - } - - accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); - - } - - *s = 0.125f * hsum_float_8(accumf); - -#elif defined(__POWER9_VECTOR__) - static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, - 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 - }; - - static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; - - const vector int v0 = vec_splats((int32_t)0); - - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - - const vector unsigned char mask0 = vec_xl( 0, k_mask1); - const vector unsigned char mask1 = vec_xl(16, k_mask1); - const vector signed char mask2 = (vector signed char)vec_xl( 0, k_mask2); - - for (int i = 0; i < nb; ++i) { - vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(y[i].d); - vector float vd = vec_mul(vxd, vyd); - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - vector signed int vsumi2 = v0; - vector signed int vsumi3 = v0; - - const uint8_t * restrict q2 = x[i].qs; - const uint8_t * restrict qh = x[i].qh; - const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8); - const uint8_t * restrict sc = x[i].scales; - const int8_t * restrict q8 = y[i].qs; - - for (int j = 0; j < QK_K/32; j += 2) { - __builtin_prefetch(q2, 0, 1); - __builtin_prefetch(q8, 0, 1); - - vector signed long long aux64x2_0 = {*(const int64_t *)(iq2s_grid + (q2[0] | ((qh[0] << 8) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[1] | ((qh[0] << 6) & 0x300)))}; - vector signed long long aux64x2_1 = {*(const int64_t *)(iq2s_grid + (q2[2] | ((qh[0] << 4) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[3] | ((qh[0] << 2) & 0x300)))}; - vector signed long long aux64x2_2 = {*(const int64_t *)(iq2s_grid + (q2[4] | ((qh[1] << 8) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[5] | ((qh[1] << 6) & 0x300)))}; - vector signed long long aux64x2_3 = {*(const int64_t *)(iq2s_grid + (q2[6] | ((qh[1] << 4) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[7] | ((qh[1] << 2) & 0x300)))}; - q2 += 8; - qh += 2; - - vector signed char vsigns01 = (vector signed char)vec_splats(*(const uint32_t *)&signs[0]); - vector signed char vsigns23 = (vector signed char)vec_splats(*(const uint32_t *)&signs[2]); - signs += 4; - - vector signed char vsigns0 = vec_perm(vsigns01, vsigns01, mask0); - vector signed char vsigns1 = vec_perm(vsigns01, vsigns01, mask1); - vector signed char vsigns2 = vec_perm(vsigns23, vsigns23, mask0); - vector signed char vsigns3 = vec_perm(vsigns23, vsigns23, mask1); - - vsigns0 = (vector signed char)vec_cmpeq(vec_and(vsigns0, mask2), mask2); - vsigns1 = (vector signed char)vec_cmpeq(vec_and(vsigns1, mask2), mask2); - vsigns2 = (vector signed char)vec_cmpeq(vec_and(vsigns2, mask2), mask2); - vsigns3 = (vector signed char)vec_cmpeq(vec_and(vsigns3, mask2), mask2); - - vector signed char q2x0 = vec_sub(vec_xor(vsigns0, (vector signed char)aux64x2_0), vsigns0); - vector signed char q2x1 = vec_sub(vec_xor(vsigns1, (vector signed char)aux64x2_1), vsigns1); - vector signed char q2x2 = vec_sub(vec_xor(vsigns2, (vector signed char)aux64x2_2), vsigns2); - vector signed char q2x3 = vec_sub(vec_xor(vsigns3, (vector signed char)aux64x2_3), vsigns3); - - vector signed char q8y0 = vec_xl( 0, q8); - vector signed char q8y1 = vec_xl(16, q8); - vector signed char q8y2 = vec_xl(32, q8); - vector signed char q8y3 = vec_xl(48, q8); - q8 += 64; - - vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0)); - vector signed short qv1 = vec_add(vec_mule(q2x1, q8y1), vec_mulo(q2x1, q8y1)); - vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2)); - vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3)); - - const uint16_t ls0 = (uint16_t)(sc[0] & 0xf); - const uint16_t ls1 = (uint16_t)(sc[0] >> 4); - const uint16_t ls2 = (uint16_t)(sc[1] & 0xf); - const uint16_t ls3 = (uint16_t)(sc[1] >> 4); - sc += 2; - - vector signed short vscales0 = vec_splats((int16_t)(2*ls0+1)); - vector signed short vscales1 = vec_splats((int16_t)(2*ls1+1)); - vector signed short vscales2 = vec_splats((int16_t)(2*ls2+1)); - vector signed short vscales3 = vec_splats((int16_t)(2*ls3+1)); - - vsumi0 = vec_msum(qv0, vscales0, vsumi0); - vsumi1 = vec_msum(qv1, vscales1, vsumi1); - vsumi2 = vec_msum(qv2, vscales2, vsumi2); - vsumi3 = vec_msum(qv3, vscales3, vsumi3); - } - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = 0.125f * vec_extract(vsumf0, 0); - -#elif defined(__loongarch_asx) - - static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, - 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 - }; - - static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - }; - - - const __m128i m4 = __lsx_vreplgr2vr_b(0xf); - const __m128i m1 = __lsx_vreplgr2vr_b(1); - - const __m256i mask1 = __lasx_xvld((const __m256i*)k_mask1, 0); - const __m256i mask2 = __lasx_xvld((const __m256i*)k_mask2, 0); - uint64_t aux64; - - __m256 accumf = (__m256)__lasx_xvldi(0); - for (int i = 0; i < nb; ++i) { - const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * restrict qs = x[i].qs; - const uint8_t * restrict qh = x[i].qh; - const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8); - const int8_t * restrict q8 = y[i].qs; - - __m128i tmp1; - memcpy(&aux64, x[i].scales, 8); - tmp1 = __lsx_vinsgr2vr_d(tmp1, aux64, 0); - tmp1 = __lsx_vinsgr2vr_d(tmp1, aux64 >> 4, 1); - const __m128i scales8 = __lsx_vadd_b(__lsx_vslli_h(__lsx_vand_v(tmp1, m4), 1), m1); - const __m256i scales16 = lasx_ext8_16(scales8); // 0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15 - - __m256i sumi1 = __lasx_xvldi(0); - __m256i sumi2 = __lasx_xvldi(0); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - const __m256i q2_1 = lasx_set_d(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)], - iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)], - iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)], - iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]); - const __m256i q2_2 = lasx_set_d(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)], - iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)], - iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)], - iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]); - qs += 8; - - __m256i aux256 = __lasx_xvreplgr2vr_w(signs[0] | ((uint32_t) signs[1] << 16)); - aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2); - const __m256i s2_1 = __lasx_xvseq_b(aux256, mask2); - const __m256i q8s_1 = __lasx_xvsub_b(__lasx_xvxor_v(s2_1, q8_1), s2_1); - - aux256 = __lasx_xvreplgr2vr_w(signs[2] | ((uint32_t) signs[3] << 16)); - aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2); - const __m256i s2_2 = __lasx_xvseq_b(aux256, mask2); - const __m256i q8s_2 = __lasx_xvsub_b(__lasx_xvxor_v(s2_2, q8_2), s2_2); - - signs += 4; - - const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); // blocks 2*ib32+0, 2*ib32+1 - const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); // blocks 2*ib32+2, 2*ib32+3 - - const __m256i p1 = lasx_madd_h(dot1, lasx_shuffle_b(scales16, get_scale_shuffle_k4(ib32+0))); - const __m256i p2 = lasx_madd_h(dot2, lasx_shuffle_b(scales16, get_scale_shuffle_k4(ib32+1))); - sumi1 = __lasx_xvadd_w(sumi1, p1); - sumi2 = __lasx_xvadd_w(sumi2, p2); - } - - accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf); - } - - *s = 0.125f * hsum_float_8(accumf); - -#else - - float sumf = 0; - for (int i = 0; i < nb; i++) { - - const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const int8_t * q8 = y[i].qs; - const uint8_t * qs = x[i].qs; - const uint8_t * qh = x[i].qh; - const uint8_t * signs = qs + QK_K/8; - - int bsum = 0; - for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { - int ls1 = 1 + 2*(x[i].scales[ib32] & 0xf); - int ls2 = 1 + 2*(x[i].scales[ib32] >> 4); - int sumi1 = 0, sumi2 = 0; - for (int l = 0; l < 2; ++l) { - const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300))); - for (int j = 0; j < 8; ++j) { - sumi1 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1); - } - q8 += 8; - } - for (int l = 2; l < 4; ++l) { - const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300))); - for (int j = 0; j < 8; ++j) { - sumi2 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1); - } - q8 += 8; - } - bsum += ls1 * sumi1 + ls2 * sumi2; - qs += 4; - signs += 4; - } - - sumf += d * bsum; - } - - *s = 0.125f * sumf; - -#endif - -} - -void lm_ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_iq3_xxs * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - -#if defined(__ARM_NEON) - - const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - - uint32_t aux32[2]; - - lm_ggml_int8x16x4_t q3s; - lm_ggml_int8x16x4_t q8b; - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * restrict q3 = x[i].qs; - const uint8_t * restrict gas = x[i].qs + QK_K/4; - const int8_t * restrict q8 = y[i].qs; - float sumf1 = 0, sumf2 = 0; - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - q8b = lm_ggml_vld1q_s8_x4(q8); q8 += 64; - memcpy(aux32, gas, 2*sizeof(uint32_t)); gas += 2*sizeof(uint32_t); - const uint32x4_t aux32x4_0 = lm_ggml_vld1q_u32(iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]); - const uint32x4_t aux32x4_1 = lm_ggml_vld1q_u32(iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]); - const uint32x4_t aux32x4_2 = lm_ggml_vld1q_u32(iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]); - const uint32x4_t aux32x4_3 = lm_ggml_vld1q_u32(iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]); - q3 += 16; - q3s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 7) & 127)))); - q3s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 21) & 127)))); - q3s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127)))); - q3s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127)))); - q3s.val[0] = vmulq_s8(q3s.val[0], vreinterpretq_s8_u32(aux32x4_0)); - q3s.val[1] = vmulq_s8(q3s.val[1], vreinterpretq_s8_u32(aux32x4_1)); - q3s.val[2] = vmulq_s8(q3s.val[2], vreinterpretq_s8_u32(aux32x4_2)); - q3s.val[3] = vmulq_s8(q3s.val[3], vreinterpretq_s8_u32(aux32x4_3)); - const int32x4_t p1 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]); - const int32x4_t p2 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]); - sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[0] >> 28)); - sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[1] >> 28)); - } - sumf += d*(sumf1 + sumf2); - } - *s = 0.5f * sumf; - -#elif defined(__AVX2__) - - const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - - uint32_t aux32[2]; - - __m256 accumf = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { - const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * restrict q3 = x[i].qs; - const uint8_t * restrict gas = x[i].qs + QK_K/4; - const int8_t * restrict q8 = y[i].qs; - __m256i sumi1 = _mm256_setzero_si256(); - __m256i sumi2 = _mm256_setzero_si256(); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i q2_1 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]], - iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); - q3 += 8; - const __m256i q2_2 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]], - iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); - q3 += 8; - memcpy(aux32, gas, 8); gas += 8; - const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127], - signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]); - const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127], - signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); - const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1); - const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2); - const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); - const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); - const uint16_t ls1 = aux32[0] >> 28; - const uint16_t ls2 = aux32[1] >> 28; - const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1)); - const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1)); - sumi1 = _mm256_add_epi32(sumi1, p1); - sumi2 = _mm256_add_epi32(sumi2, p2); - } - - accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); - - } - - *s = 0.25f * hsum_float_8(accumf); - -#elif defined(__AVX__) - const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - - uint32_t aux32[2]; - - __m256 accumf = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { - const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * restrict q3 = x[i].qs; - const uint8_t * restrict gas = x[i].qs + QK_K/4; - const int8_t * restrict q8 = y[i].qs; - __m128i sumi1_0 = _mm_setzero_si128(); - __m128i sumi1_1 = _mm_setzero_si128(); - __m128i sumi2_0 = _mm_setzero_si128(); - __m128i sumi2_1 = _mm_setzero_si128(); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q2_1_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); - const __m128i q2_1_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]); - q3 += 8; - const __m128i q2_2_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); - const __m128i q2_2_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]); - q3 += 8; - memcpy(aux32, gas, 8); gas += 8; - const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]); - const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127]); - const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); - const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]); - const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0); - const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1); - const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0); - const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1); - const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); - const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); - const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); - const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); - const uint16_t ls1 = aux32[0] >> 28; - const uint16_t ls2 = aux32[1] >> 28; - const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1)); - const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1)); - const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1)); - const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1)); - sumi1_0 = _mm_add_epi32(sumi1_0, p1_0); - sumi1_1 = _mm_add_epi32(sumi1_1, p1_1); - sumi2_0 = _mm_add_epi32(sumi2_0, p2_0); - sumi2_1 = _mm_add_epi32(sumi2_1, p2_1); - } - - accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); - - } - - *s = 0.25f * hsum_float_8(accumf); - -#elif defined(__POWER9_VECTOR__) - const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - - const vector int v0 = vec_splats((int32_t)0); - - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - - for (int i = 0; i < nb; ++i) { - vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(y[i].d); - vector float vd = vec_mul(vxd, vyd); - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - vector signed int vsumi2 = v0; - vector signed int vsumi3 = v0; - - const uint8_t * restrict q3 = x[i].qs; - const uint32_t * restrict signs = (const uint32_t *)(x[i].qs + QK_K/4); - const int8_t * restrict q8 = y[i].qs; - -#pragma GCC unroll 1 - for (int j = 0; j < QK_K/32; j += 2) { - __builtin_prefetch(q3, 0, 1); - __builtin_prefetch(q8, 0, 1); - - vector unsigned int aux32x4_0 = {iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]}; - vector unsigned int aux32x4_1 = {iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]}; - vector unsigned int aux32x4_2 = {iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]}; - vector unsigned int aux32x4_3 = {iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]}; - q3 += 16; - - vector unsigned long long aux64x2_0 = {(uint64_t)(signs64[(signs[0] >> 0) & 127]), (uint64_t)(signs64[(signs[0] >> 7) & 127])}; - vector unsigned long long aux64x2_1 = {(uint64_t)(signs64[(signs[0] >> 14) & 127]), (uint64_t)(signs64[(signs[0] >> 21) & 127])}; - vector unsigned long long aux64x2_2 = {(uint64_t)(signs64[(signs[1] >> 0) & 127]), (uint64_t)(signs64[(signs[1] >> 7) & 127])}; - vector unsigned long long aux64x2_3 = {(uint64_t)(signs64[(signs[1] >> 14) & 127]), (uint64_t)(signs64[(signs[1] >> 21) & 127])}; - - vector signed char q3x0 = vec_mul((vector signed char)aux64x2_0, (vector signed char)aux32x4_0); - vector signed char q3x1 = vec_mul((vector signed char)aux64x2_1, (vector signed char)aux32x4_1); - vector signed char q3x2 = vec_mul((vector signed char)aux64x2_2, (vector signed char)aux32x4_2); - vector signed char q3x3 = vec_mul((vector signed char)aux64x2_3, (vector signed char)aux32x4_3); - - vector signed char q8y0 = vec_xl( 0, q8); - vector signed char q8y1 = vec_xl(16, q8); - vector signed char q8y2 = vec_xl(32, q8); - vector signed char q8y3 = vec_xl(48, q8); - q8 += 64; - - vector signed short qv0 = vec_add(vec_mule(q3x0, q8y0), vec_mulo(q3x0, q8y0)); - vector signed short qv1 = vec_add(vec_mule(q3x1, q8y1), vec_mulo(q3x1, q8y1)); - vector signed short qv2 = vec_add(vec_mule(q3x2, q8y2), vec_mulo(q3x2, q8y2)); - vector signed short qv3 = vec_add(vec_mule(q3x3, q8y3), vec_mulo(q3x3, q8y3)); - - const uint16_t ls0 = (uint16_t)(signs[0] >> 28); - const uint16_t ls1 = (uint16_t)(signs[1] >> 28); - signs += 2; - - vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1)); - vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1)); - - vsumi0 = vec_msum(qv0, vscales01, vsumi0); - vsumi1 = vec_msum(qv1, vscales01, vsumi1); - vsumi2 = vec_msum(qv2, vscales23, vsumi2); - vsumi3 = vec_msum(qv3, vscales23, vsumi3); - } - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = 0.25f * vec_extract(vsumf0, 0); - -#elif defined(__loongarch_asx) - - const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - - uint32_t aux32[2]; - - __m256 accumf = (__m256)__lasx_xvldi(0); - for (int i = 0; i < nb; ++i) { - const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * restrict q3 = x[i].qs; - const uint8_t * restrict gas = x[i].qs + QK_K/4; - const int8_t * restrict q8 = y[i].qs; - __m256i sumi1 = __lasx_xvldi(0); - __m256i sumi2 = __lasx_xvldi(0); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - const __m256i q2_1 = lasx_set_w(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]], - iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); - q3 += 8; - const __m256i q2_2 = lasx_set_w(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]], - iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); - q3 += 8; - memcpy(aux32, gas, 8); gas += 8; - - const __m256i s2_1 = lasx_set_d(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127], - signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]); - const __m256i s2_2 = lasx_set_d(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127], - signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); - const __m256i q8s_1 = __lasx_xvsigncov_b(s2_1, q8_1); - const __m256i q8s_2 = __lasx_xvsigncov_b(s2_2, q8_2); - const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); - const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); - const uint16_t ls1 = aux32[0] >> 28; - const uint16_t ls2 = aux32[1] >> 28; - - const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1)); - const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1)); - sumi1 = __lasx_xvadd_w(sumi1, p1); - sumi2 = __lasx_xvadd_w(sumi2, p2); - } - - accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf); - } - - *s = 0.25f * hsum_float_8(accumf); - -#else - - uint32_t aux32; - - float sumf = 0.f; - for (int i = 0; i < nb; ++i) { - const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * restrict q3 = x[i].qs; - const uint8_t * restrict gas = x[i].qs + QK_K/4; - const int8_t * restrict q8 = y[i].qs; - int32_t bsum = 0; - for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { - memcpy(&aux32, gas, sizeof(uint32_t)); gas += sizeof(uint32_t); - const uint32_t ls = 2*(aux32 >> 28) + 1; - int32_t sumi = 0; - for (int l = 0; l < 4; ++l) { - const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*l+0]); - const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*l+1]); - const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127]; - for (int j = 0; j < 4; ++j) { - sumi += grid1[j] * q8[j+0] * (signs & kmask_iq2xs[j+0] ? -1 : 1); - sumi += grid2[j] * q8[j+4] * (signs & kmask_iq2xs[j+4] ? -1 : 1); - } - q8 += 8; - } - q3 += 8; - bsum += sumi * ls; - } - sumf += d * bsum; - } - *s = 0.25f * sumf; -#endif -} - -void lm_ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_iq3_s * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - -#if defined(__ARM_NEON) - - typedef union { - uint16x8_t vec_index; - uint16_t index[8]; - } vec_index_t; - - static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, - 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 - }; - - static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; - - static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1}; - - const lm_ggml_uint8x16x2_t mask1 = lm_ggml_vld1q_u8_x2(k_mask1); - const uint8x16_t mask2 = vld1q_u8(k_mask2); - - const int16x8_t hshift = vld1q_s16(k_shift); - const uint16x8_t m256 = vdupq_n_u16(256); - const uint8x16_t m1 = vdupq_n_u8(1); - - uint8x16x2_t vs; - lm_ggml_int8x16x4_t q3s; - lm_ggml_int8x16x4_t q8b; - vec_index_t idx; - - uint32_t scales32[2]; - const uint8_t * scales8 = (const uint8_t *)scales32; - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * restrict qs = x[i].qs; - const uint8_t * restrict qh = x[i].qh; - const uint16_t * restrict signs = (const uint16_t *)x[i].signs; - const int8_t * restrict q8 = y[i].qs; - - memcpy(scales32, x[i].scales, 4); - scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101; - scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101; - - int sumi1 = 0, sumi2 = 0; - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - q8b = lm_ggml_vld1q_s8_x4(q8); q8 += 64; - - const uint8x16_t idx_l = vld1q_u8(qs); qs += 16; - idx.vec_index = vorrq_u16(vmovl_u8(vget_low_u8 (idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+0]), hshift), m256)); - const uint32x4_t aux32x4_0 = lm_ggml_vld1q_u32(iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]], - iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]); - const uint32x4_t aux32x4_1 = lm_ggml_vld1q_u32(iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]], - iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]); - idx.vec_index = vorrq_u16(vmovl_u8(vget_high_u8(idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+1]), hshift), m256)); - const uint32x4_t aux32x4_2 = lm_ggml_vld1q_u32(iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]], - iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]); - const uint32x4_t aux32x4_3 = lm_ggml_vld1q_u32(iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]], - iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]); - - - vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16))); - vs.val[1] = vandq_u8(lm_ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); - vs.val[0] = vandq_u8(lm_ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); - vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1); - vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1); - - q3s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_0)); - q3s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_1)); - - vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16))); - vs.val[1] = vandq_u8(lm_ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); - vs.val[0] = vandq_u8(lm_ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); - vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1); - vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1); - - signs += 4; - - q3s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_2)); - q3s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_3)); - - const int32x4_t p1 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]); - const int32x4_t p2 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]); - - sumi1 += vaddvq_s32(p1) * scales8[ib32/2+0]; - sumi2 += vaddvq_s32(p2) * scales8[ib32/2+4]; - } - sumf += d*(sumi1 + sumi2); - } - *s = sumf; - -#elif defined(__AVX2__) - - static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, - 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 - }; - - static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - }; - - const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1); - const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2); - - const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8); - const __m256i idx_mask = _mm256_set1_epi32(256); - - typedef union { - __m256i vec[2]; - uint32_t index[16]; - } index_t; - - index_t idx; - - __m256 accumf = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { - const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * restrict qs = x[i].qs; - const uint8_t * restrict qh = x[i].qh; - const uint16_t * restrict signs = (const uint16_t *)x[i].signs; - const int8_t * restrict q8 = y[i].qs; - __m256i sumi1 = _mm256_setzero_si256(); - __m256i sumi2 = _mm256_setzero_si256(); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i idx_l = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)qs)); qs += 16; - idx.vec[0] = _mm256_set1_epi32(qh[ib32+0]); - idx.vec[1] = _mm256_set1_epi32(qh[ib32+1]); - idx.vec[0] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[0], idx_shift), idx_mask); - idx.vec[1] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[1], idx_shift), idx_mask); - idx.vec[0] = _mm256_or_si256(idx.vec[0], _mm256_cvtepi16_epi32(_mm256_castsi256_si128(idx_l))); - idx.vec[1] = _mm256_or_si256(idx.vec[1], _mm256_cvtepi16_epi32(_mm256_extractf128_si256(idx_l, 1))); - - // At leat on my CPU (Ryzen 7950X), using _mm256_i32gather_epi32 is slower than _mm256_set_epi32. Strange. - //const __m256i q2_1 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[0], 4); - //const __m256i q2_2 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[1], 4); - const __m256i q2_1 = _mm256_set_epi32( - iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]], - iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]] - ); - const __m256i q2_2 = _mm256_set_epi32( - iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]], - iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[ 9]], iq3s_grid[idx.index[ 8]] - ); - - __m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16)); - aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); - const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2); - const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1); - - aux256 = _mm256_set1_epi32(signs[2] | (signs[3] << 16)); - aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); - const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2); - const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2); - - signs += 4; - - const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); - const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); - const uint16_t ls1 = x[i].scales[ib32/2] & 0xf; - const uint16_t ls2 = x[i].scales[ib32/2] >> 4; - const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1)); - const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1)); - sumi1 = _mm256_add_epi32(sumi1, p1); - sumi2 = _mm256_add_epi32(sumi2, p2); - } - - accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); - - } - - *s = hsum_float_8(accumf); - -#elif defined(__AVX__) - static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, - 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 - }; - - static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - }; - - const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1); - const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1); - const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2); - const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1); - - const __m128i idx_mul_0 = _mm_set_epi32(32, 64, 128, 256); - const __m128i idx_mul_1 = _mm_set_epi32(2, 4, 8, 16); - const __m128i idx_mask = _mm_set1_epi32(256); - - typedef union { - __m128i vec[4]; - uint32_t index[16]; - } index_t; - - index_t idx; - - __m256 accumf = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { - const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * restrict qs = x[i].qs; - const uint8_t * restrict qh = x[i].qh; - const uint16_t * restrict signs = (const uint16_t *)x[i].signs; - const int8_t * restrict q8 = y[i].qs; - __m128i sumi1_0 = _mm_setzero_si128(); - __m128i sumi1_1 = _mm_setzero_si128(); - __m128i sumi2_0 = _mm_setzero_si128(); - __m128i sumi2_1 = _mm_setzero_si128(); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i qs_tmp = _mm_loadu_si128((const __m128i *)qs); - const __m128i idx_l_0 = _mm_cvtepu8_epi16(qs_tmp); - const __m128i idx_l_1 = _mm_cvtepu8_epi16(_mm_srli_si128(qs_tmp, 8)); qs += 16; - idx.vec[0] = _mm_set1_epi32(qh[ib32+0]); - idx.vec[1] = idx.vec[0]; - idx.vec[2] = _mm_set1_epi32(qh[ib32+1]); - idx.vec[3] = idx.vec[2]; - - idx.vec[0] = _mm_and_si128(_mm_mullo_epi32(idx.vec[0], idx_mul_0), idx_mask); - idx.vec[1] = _mm_and_si128(_mm_mullo_epi32(idx.vec[1], idx_mul_1), idx_mask); - idx.vec[2] = _mm_and_si128(_mm_mullo_epi32(idx.vec[2], idx_mul_0), idx_mask); - idx.vec[3] = _mm_and_si128(_mm_mullo_epi32(idx.vec[3], idx_mul_1), idx_mask); - - idx.vec[0] = _mm_or_si128(idx.vec[0], _mm_cvtepi16_epi32(idx_l_0)); - idx.vec[1] = _mm_or_si128(idx.vec[1], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_0, 8))); - idx.vec[2] = _mm_or_si128(idx.vec[2], _mm_cvtepi16_epi32(idx_l_1)); - idx.vec[3] = _mm_or_si128(idx.vec[3], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_1, 8))); - - const __m128i q2_1_0 = _mm_set_epi32(iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]); - const __m128i q2_1_1 = _mm_set_epi32(iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]]); - const __m128i q2_2_0 = _mm_set_epi32(iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[9]], iq3s_grid[idx.index[8]]); - const __m128i q2_2_1 = _mm_set_epi32(iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]]); - - __m128i aux128_0 = _mm_set1_epi32(signs[0] | (signs[1] << 16)); - __m128i aux128_1 = aux128_0; - aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0); - aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1); - const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0); - const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1); - const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0); - const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1); - - aux128_0 = _mm_set1_epi32(signs[2] | (signs[3] << 16)); - aux128_1 = aux128_0; - aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0); - aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1); - const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0); - const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1); - const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0); - const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1); - - signs += 4; - - const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); - const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); - const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); - const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); - const uint16_t ls1 = x[i].scales[ib32/2] & 0xf; - const uint16_t ls2 = x[i].scales[ib32/2] >> 4; - const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1)); - const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1)); - const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1)); - const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1)); - sumi1_0 = _mm_add_epi32(sumi1_0, p1_0); - sumi1_1 = _mm_add_epi32(sumi1_1, p1_1); - sumi2_0 = _mm_add_epi32(sumi2_0, p2_0); - sumi2_1 = _mm_add_epi32(sumi2_1, p2_1); - } - - accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); - - } - - *s = hsum_float_8(accumf); - -#elif defined(__POWER9_VECTOR__) - static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, - 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 - }; - - static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; - - const vector int v0 = vec_splats((int32_t)0); - - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - - const vector unsigned char mask0 = vec_xl( 0, k_mask1); - const vector unsigned char mask1 = vec_xl(16, k_mask1); - const vector signed char mask2 = (vector signed char)vec_xl( 0, k_mask2); - - for (int i = 0; i < nb; ++i) { - vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(y[i].d); - vector float vd = vec_mul(vxd, vyd); - - const uint8_t * restrict q3 = x[i].qs; - const uint8_t * restrict qh = x[i].qh; - const uint16_t * restrict signs = (const uint16_t *)(x[i].signs); - const uint8_t * restrict sc = x[i].scales; - const int8_t * restrict q8 = y[i].qs; - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - vector signed int vsumi2 = v0; - vector signed int vsumi3 = v0; - - for (int j = 0; j < QK_K/32; j += 2) { - __builtin_prefetch(q3, 0, 1); - __builtin_prefetch(q8, 0, 1); - - vector unsigned int aux32x4_0 = {iq3s_grid[q3[ 0] | ((qh[0] << 8) & 256)], iq3s_grid[q3[ 1] | ((qh[0] << 7) & 256)], - iq3s_grid[q3[ 2] | ((qh[0] << 6) & 256)], iq3s_grid[q3[ 3] | ((qh[0] << 5) & 256)]}; - vector unsigned int aux32x4_1 = {iq3s_grid[q3[ 4] | ((qh[0] << 4) & 256)], iq3s_grid[q3[ 5] | ((qh[0] << 3) & 256)], - iq3s_grid[q3[ 6] | ((qh[0] << 2) & 256)], iq3s_grid[q3[ 7] | ((qh[0] << 1) & 256)]}; - vector unsigned int aux32x4_2 = {iq3s_grid[q3[ 8] | ((qh[1] << 8) & 256)], iq3s_grid[q3[ 9] | ((qh[1] << 7) & 256)], - iq3s_grid[q3[10] | ((qh[1] << 6) & 256)], iq3s_grid[q3[11] | ((qh[1] << 5) & 256)]}; - vector unsigned int aux32x4_3 = {iq3s_grid[q3[12] | ((qh[1] << 4) & 256)], iq3s_grid[q3[13] | ((qh[1] << 3) & 256)], - iq3s_grid[q3[14] | ((qh[1] << 2) & 256)], iq3s_grid[q3[15] | ((qh[1] << 1) & 256)]}; - q3 += 16; - qh += 2; - - vector signed char vsigns01 = (vector signed char)vec_splats(*(const uint32_t *)&signs[0]); - vector signed char vsigns02 = (vector signed char)vec_splats(*(const uint32_t *)&signs[2]); - signs += 4; - - vector signed char vsigns0 = vec_perm(vsigns01, vsigns01, mask0); - vector signed char vsigns1 = vec_perm(vsigns01, vsigns01, mask1); - vector signed char vsigns2 = vec_perm(vsigns02, vsigns02, mask0); - vector signed char vsigns3 = vec_perm(vsigns02, vsigns02, mask1); - - vsigns0 = (vector signed char)vec_cmpeq(vec_and(vsigns0, mask2), mask2); - vsigns1 = (vector signed char)vec_cmpeq(vec_and(vsigns1, mask2), mask2); - vsigns2 = (vector signed char)vec_cmpeq(vec_and(vsigns2, mask2), mask2); - vsigns3 = (vector signed char)vec_cmpeq(vec_and(vsigns3, mask2), mask2); - - vector signed char q3x0 = vec_sub(vec_xor(vsigns0, (vector signed char)aux32x4_0), vsigns0); - vector signed char q3x1 = vec_sub(vec_xor(vsigns1, (vector signed char)aux32x4_1), vsigns1); - vector signed char q3x2 = vec_sub(vec_xor(vsigns2, (vector signed char)aux32x4_2), vsigns2); - vector signed char q3x3 = vec_sub(vec_xor(vsigns3, (vector signed char)aux32x4_3), vsigns3); - - vector signed char q8y0 = vec_xl( 0, q8); - vector signed char q8y1 = vec_xl(16, q8); - vector signed char q8y2 = vec_xl(32, q8); - vector signed char q8y3 = vec_xl(48, q8); - q8 += 64; - - vector signed short qv0 = vec_add(vec_mule(q3x0, q8y0), vec_mulo(q3x0, q8y0)); - vector signed short qv1 = vec_add(vec_mule(q3x1, q8y1), vec_mulo(q3x1, q8y1)); - vector signed short qv2 = vec_add(vec_mule(q3x2, q8y2), vec_mulo(q3x2, q8y2)); - vector signed short qv3 = vec_add(vec_mule(q3x3, q8y3), vec_mulo(q3x3, q8y3)); - - const uint16_t ls0 = (uint16_t)(sc[0] & 0xf); - const uint16_t ls1 = (uint16_t)(sc[0] >> 4); - sc ++; - - vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1)); - vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1)); - - vsumi0 = vec_msum(qv0, vscales01, vsumi0); - vsumi1 = vec_msum(qv1, vscales01, vsumi1); - vsumi2 = vec_msum(qv2, vscales23, vsumi2); - vsumi3 = vec_msum(qv3, vscales23, vsumi3); - } - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = vec_extract(vsumf0, 0); - -#elif defined(__loongarch_asx) - - static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, - 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 - }; - - static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - }; - - const __m256i mask1 = __lasx_xvld((const __m256i*)k_mask1, 0); - const __m256i mask2 = __lasx_xvld((const __m256i*)k_mask2, 0); - - __m256i idx_shift = lasx_set_w(1, 2, 3, 4, 5, 6, 7, 8); - const __m256i idx_mask = __lasx_xvreplgr2vr_w(256); - - typedef union { - __m256i vec[2]; - uint32_t index[16]; - } index_t; - - index_t idx; - - __m256 accumf = (__m256)__lasx_xvldi(0); - for (int i = 0; i < nb; ++i) { - const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * restrict qs = x[i].qs; - const uint8_t * restrict qh = x[i].qh; - const uint16_t * restrict signs = (const uint16_t *)x[i].signs; - const int8_t * restrict q8 = y[i].qs; - __m256i sumi1 = __lasx_xvldi(0); - __m256i sumi2 = __lasx_xvldi(0); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - const __m256i idx_l = lasx_extu8_16(__lsx_vld(qs, 0)); qs += 16; - idx.vec[0] = __lasx_xvreplgr2vr_w(qh[ib32+0]); - idx.vec[1] = __lasx_xvreplgr2vr_w(qh[ib32+1]); - idx.vec[0] = __lasx_xvand_v(__lasx_xvsll_w(idx.vec[0], idx_shift), idx_mask); - idx.vec[1] = __lasx_xvand_v(__lasx_xvsll_w(idx.vec[1], idx_shift), idx_mask); - idx.vec[0] = __lasx_xvor_v(idx.vec[0], lasx_ext16_32(lasx_extracti128(idx_l, 0))); - idx.vec[1] = __lasx_xvor_v(idx.vec[1], lasx_ext16_32(lasx_extracti128(idx_l, 1))); - - // At leat on my CPU (Ryzen 7950X), using _mm256_i32gather_epi32 is slower than _mm256_set_epi32. Strange. - //const __m256i q2_1 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[0], 4); - //const __m256i q2_2 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[1], 4); - const __m256i q2_1 = lasx_set_w( - iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]], - iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]] - ); - const __m256i q2_2 = lasx_set_w( - iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]], - iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[ 9]], iq3s_grid[idx.index[ 8]] - ); - - __m256i aux256 = __lasx_xvreplgr2vr_w(signs[0] | (signs[1] << 16)); - aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2); - const __m256i s2_1 = __lasx_xvseq_b(aux256, mask2); - const __m256i q8s_1 = __lasx_xvsub_b(__lasx_xvxor_v(s2_1, q8_1), s2_1); - - aux256 = __lasx_xvreplgr2vr_w(signs[2] | (signs[3] << 16)); - aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2); - const __m256i s2_2 = __lasx_xvseq_b(aux256, mask2); - const __m256i q8s_2 = __lasx_xvsub_b(__lasx_xvxor_v(s2_2, q8_2), s2_2); - - signs += 4; - - const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); - const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); - const uint16_t ls1 = x[i].scales[ib32/2] & 0xf; - const uint16_t ls2 = x[i].scales[ib32/2] >> 4; - const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1)); - const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1)); - sumi1 = __lasx_xvadd_w(sumi1, p1); - sumi2 = __lasx_xvadd_w(sumi2, p2); - } - - accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf); - } - - *s = hsum_float_8(accumf); - -#else - - float sumf = 0.f; - for (int i = 0; i < nb; ++i) { - const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * restrict qs = x[i].qs; - const uint8_t * restrict qh = x[i].qh; - const uint8_t * restrict signs = x[i].signs; - const int8_t * restrict q8 = y[i].qs; - int32_t bsum = 0; - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const uint32_t ls1 = 2*(x[i].scales[ib32/2] & 0xf) + 1; - const uint32_t ls2 = 2*(x[i].scales[ib32/2] >> 4) + 1; - int32_t sumi = 0; - for (int l = 0; l < 4; ++l) { - const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256))); - const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256))); - for (int j = 0; j < 4; ++j) { - sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1); - sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1); - } - q8 += 8; - } - qs += 8; - signs += 4; - bsum += sumi * ls1; - sumi = 0; - for (int l = 0; l < 4; ++l) { - const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256))); - const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256))); - for (int j = 0; j < 4; ++j) { - sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1); - sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1); - } - q8 += 8; - } - qs += 8; - signs += 4; - bsum += sumi * ls2; - } - sumf += d * bsum; - } - *s = sumf; -#endif -} - -#if defined(__AVX2__) -static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) { - const __m256i ax = _mm256_sign_epi8(x, x); - const __m256i sy = _mm256_sign_epi8(y, x); - return _mm256_maddubs_epi16(ax, sy); -} -#elif defined(__loongarch_asx) -static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) { - const __m256i ax = __lasx_xvsigncov_b(x, x); - const __m256i sy = __lasx_xvsigncov_b(x, y); - __m256i tmp1, tmp2, tmp3; - tmp1 = __lasx_xvmulwev_h_bu_b(ax, sy); - tmp2 = __lasx_xvmulwod_h_bu_b(ax, sy); - tmp3 = __lasx_xvadd_h(tmp1, tmp2); - return __lasx_xvsat_h(tmp3, 15); -} -#endif - -void lm_ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_iq1_s * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - -#if defined __ARM_NEON - - lm_ggml_int8x16x4_t q1b; - lm_ggml_int8x16x4_t q8b; - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - - const int8_t * q8 = y[i].qs; - const uint8_t * qs = x[i].qs; - const uint16_t * qh = x[i].qh; - - int sumi1 = 0, sumi2 = 0, sumi3 = 0; - - for (int ib = 0; ib < QK_K/32; ib += 2) { - - q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[ib+0] << 8) & 0x700)))), - vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[ib+0] << 5) & 0x700))))); - q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[ib+0] << 2) & 0x700)))), - vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[ib+0] >> 1) & 0x700))))); - q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[ib+1] << 8) & 0x700)))), - vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[ib+1] << 5) & 0x700))))); - q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[ib+1] << 2) & 0x700)))), - vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[ib+1] >> 1) & 0x700))))); - qs += 8; - - q8b = lm_ggml_vld1q_s8_x4(q8); q8 += 64; - - const int32x4_t p1 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[0], q8b.val[0]), q1b.val[1], q8b.val[1]); - const int32x4_t p2 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[2], q8b.val[2]), q1b.val[3], q8b.val[3]); - - const int ls1 = 2*((qh[ib+0] >> 12) & 7) + 1; - const int ls2 = 2*((qh[ib+1] >> 12) & 7) + 1; - sumi1 += vaddvq_s32(p1) * ls1; - sumi2 += vaddvq_s32(p2) * ls2; - sumi3 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * ls1 * (qh[ib+0] & 0x8000 ? -1 : 1) - + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * ls2 * (qh[ib+1] & 0x8000 ? -1 : 1); - - } - - sumf += y[i].d * LM_GGML_FP16_TO_FP32(x[i].d) * (sumi1 + sumi2 + IQ1S_DELTA * sumi3); - } - - *s = sumf; - -#elif defined __AVX2__ - - __m256 accum = _mm256_setzero_ps(); - float accum1 = 0; - for (int i = 0; i < nb; ++i) { - - const int8_t * q8 = y[i].qs; - const uint8_t * qs = x[i].qs; - const uint16_t * qh = x[i].qh; - - __m256i sumi = _mm256_setzero_si256(); - int sumi1 = 0; - for (int ib = 0; ib < QK_K/32; ib += 2) { - const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)], - iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]); - const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)], - iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]); - qs += 8; - const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - - const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1); - const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2); - const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1; - const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1; - const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(ls1)); - const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(ls2)); - - sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p1, p2)); - sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1 - + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2; - } - - const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); - accum = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi), accum); - accum1 += d * sumi1; - - } - - *s = hsum_float_8(accum) + IQ1S_DELTA * accum1; - -#elif defined __AVX__ - __m256 accum = _mm256_setzero_ps(); - float accum1 = 0; - for (int i = 0; i < nb; ++i) { - - const int8_t * q8 = y[i].qs; - const uint8_t * qs = x[i].qs; - const uint16_t * qh = x[i].qh; - - __m128i sumi1_0 = _mm_setzero_si128(); - __m128i sumi1_1 = _mm_setzero_si128(); - int sumi1 = 0; - for (int ib = 0; ib < QK_K/32; ib += 2) { - const __m128i q1b_1_0 = _mm_set_epi64x(iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]); - const __m128i q1b_1_1 = _mm_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)]); - const __m128i q1b_2_0 = _mm_set_epi64x(iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]); - const __m128i q1b_2_1 = _mm_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)]); - qs += 8; - const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - - const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0); - const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1); - const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0); - const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1); - const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1; - const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1; - const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(ls1)); - const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(ls1)); - const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(ls2)); - const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(ls2)); - - sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0)); - sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1)); - sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1 - + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2; - } - - const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); - accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum); - accum1 += d * sumi1; - - } - - *s = hsum_float_8(accum) + IQ1S_DELTA * accum1; - -#elif defined(__POWER9_VECTOR__) - const vector unsigned char v0 = vec_splats((unsigned char)0x0); - const vector unsigned short vsign = vec_splats((unsigned short)0x8000); - - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - - for (int i = 0; i < nb; ++i) { - vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(y[i].d); - vector float vd = vec_mul(vxd, vyd); - - vector signed int vsumi0 = vec_splats((int32_t)0); - vector signed int vsumi1 = vec_splats((int32_t)0); - vector signed int vsumi2 = vec_splats((int32_t)0); - vector signed int vsumi3 = vec_splats((int32_t)0); - vector signed int vsumi8 = vec_splats((int32_t)0); - - const uint8_t * restrict q1 = x[i].qs; - const uint16_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - const int16_t * restrict qs = y[i].bsums; - - for (int j = 0; j < QK_K/32; j += 2) { - __builtin_prefetch(q1, 0, 1); - __builtin_prefetch(qh, 0, 1); - __builtin_prefetch(q8, 0, 1); - - vector signed long long aux64x2_0 = {*(const int64_t *)(iq1s_grid + (q1[0] | ((qh[0] << 8) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[1] | ((qh[0] << 5) & 0x700)))}; - vector signed long long aux64x2_1 = {*(const int64_t *)(iq1s_grid + (q1[2] | ((qh[0] << 2) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[3] | ((qh[0] >> 1) & 0x700)))}; - vector signed long long aux64x2_2 = {*(const int64_t *)(iq1s_grid + (q1[4] | ((qh[1] << 8) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[5] | ((qh[1] << 5) & 0x700)))}; - vector signed long long aux64x2_3 = {*(const int64_t *)(iq1s_grid + (q1[6] | ((qh[1] << 2) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[7] | ((qh[1] >> 1) & 0x700)))}; - q1 += 8; - - vector signed char q1x0 = (vector signed char)aux64x2_0; - vector signed char q1x1 = (vector signed char)aux64x2_1; - vector signed char q1x2 = (vector signed char)aux64x2_2; - vector signed char q1x3 = (vector signed char)aux64x2_3; - - vector signed char q8y0 = vec_xl( 0, q8); - vector signed char q8y1 = vec_xl(16, q8); - vector signed char q8y2 = vec_xl(32, q8); - vector signed char q8y3 = vec_xl(48, q8); - q8 += 64; - - vector signed short qv0 = vec_add(vec_mule(q1x0, q8y0), vec_mulo(q1x0, q8y0)); - vector signed short qv1 = vec_add(vec_mule(q1x1, q8y1), vec_mulo(q1x1, q8y1)); - vector signed short qv2 = vec_add(vec_mule(q1x2, q8y2), vec_mulo(q1x2, q8y2)); - vector signed short qv3 = vec_add(vec_mule(q1x3, q8y3), vec_mulo(q1x3, q8y3)); - - const uint16_t ls0 = (uint16_t)((qh[0] >> 12) & 7); - const uint16_t ls1 = (uint16_t)((qh[1] >> 12) & 7); - - vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1)); - vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1)); - vector signed short vscales = vec_sld(vscales23, vscales01, 8); - - vsumi0 = vec_msum(qv0, vscales01, vsumi0); - vsumi1 = vec_msum(qv1, vscales01, vsumi1); - vsumi2 = vec_msum(qv2, vscales23, vsumi2); - vsumi3 = vec_msum(qv3, vscales23, vsumi3); - - vector signed short q8ysums = vec_xl_len(qs, 8); - qs += 4; - q8ysums = vec_mergeh(q8ysums, (vector signed short)v0); - - vector signed short qxh = (vector signed short)vec_sld(vec_splats(qh[1]), vec_splats(qh[0]), 8); - qh += 2; - vector __bool short vsel = vec_cmpge(qxh, (vector signed short)v0); - - vector signed short q8ysum = vec_sel((vector signed short)vec_xor((vector unsigned short)q8ysums, vsign), q8ysums, vsel); - - vsumi8 = vec_add(vec_mule(q8ysum, vscales), vsumi8); - } - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - - vsumf0 = vec_madd(vec_ctf(vsumi8, 0), vec_mul(vd, vec_splats(IQ1S_DELTA)), vsumf0); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = vec_extract(vsumf0, 0); - -#elif defined(__loongarch_asx) +} - __m256 accum = (__m256)__lasx_xvldi(0); - float accum1 = 0; - for (int i = 0; i < nb; ++i) { +// ====================== 3.3125 bpw (de)-quantization - const int8_t * q8 = y[i].qs; - const uint8_t * qs = x[i].qs; - const uint16_t * qh = x[i].qh; +void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; - __m256i sumi = __lasx_xvldi(0); - int sumi1 = 0; - for (int ib = 0; ib < QK_K/32; ib += 2) { - __m256i q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)], 0); - q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], 1); - q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)], 2); - q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], 3); + for (int i = 0; i < nb; i++) { - __m256i q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)], 0); - q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], 1); - q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)], 2); - q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], 3); + const float d = LM_GGML_FP16_TO_FP32(x[i].d); + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint8_t * signs = x[i].signs; + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const float db1 = d * (1 + 2*(x[i].scales[ib32/2] & 0xf)); + const float db2 = d * (1 + 2*(x[i].scales[ib32/2] >> 4)); + for (int l = 0; l < 4; ++l) { + const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))); + for (int j = 0; j < 4; ++j) { + y[j+0] = db1 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f); + y[j+4] = db1 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f); + } + y += 8; + } qs += 8; - const __m256i q8b_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - const __m256i q8b_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - - const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1); - const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2); - const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1; - const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1; - - __m256i tmp1, tmp5, tmp6; - tmp1 = __lasx_xvreplgr2vr_h(ls1); - tmp5 = __lasx_xvmulwev_w_h(dot1, tmp1); - tmp6 = __lasx_xvmulwod_w_h(dot1, tmp1); - const __m256i p1 = __lasx_xvadd_w(tmp5, tmp6); - - tmp1 = __lasx_xvreplgr2vr_h(ls2); - tmp5 = __lasx_xvmulwev_w_h(dot2, tmp1); - tmp6 = __lasx_xvmulwod_w_h(dot2, tmp1); - const __m256i p2 = __lasx_xvadd_w(tmp5, tmp6); - - sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p1, p2)); - sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1 - + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2; + signs += 4; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[1] << (8-2*l)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[1] << (7-2*l)) & 256))); + for (int j = 0; j < 4; ++j) { + y[j+0] = db2 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f); + y[j+4] = db2 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f); + } + y += 8; + } + qh += 2; + qs += 8; + signs += 4; } - - const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d); - accum = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), accum); - accum1 += d * sumi1; } +} - *s = hsum_float_8(accum) + IQ1S_DELTA * accum1; +// ====================== 1.5625 bpw (de)-quantization -#else +void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; - float sumf = 0; for (int i = 0; i < nb; i++) { - const int8_t * q8 = y[i].qs; + const float d = LM_GGML_FP16_TO_FP32(x[i].d); const uint8_t * qs = x[i].qs; const uint16_t * qh = x[i].qh; - int sumi = 0, sumi1 = 0; for (int ib = 0; ib < QK_K/32; ++ib) { - const int ls = 2*((qh[ib] >> 12) & 7) + 1; - const int delta = qh[ib] & 0x8000 ? -1 : 1; - int lsum = 0; + const float dl = d * (2*((qh[ib] >> 12) & 7) + 1); + const float delta = qh[ib] & 0x8000 ? -IQ1S_DELTA : IQ1S_DELTA; for (int l = 0; l < 4; ++l) { const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8))); for (int j = 0; j < 8; ++j) { - lsum += q8[j] * grid[j]; + y[j] = dl * (grid[j] + delta); } - q8 += 8; + y += 8; } - sumi += ls * lsum; - sumi1 += ls * delta * (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]); qs += 4; } - - sumf += LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1); } - - *s = sumf; - -#endif } -void lm_ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_iq1_m * restrict x = vx; - const block_q8_K * restrict y = vy; +void dequantize_row_iq1_m(const block_iq1_m * restrict x, float * restrict y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; - const int nb = n / QK_K; + float delta[4]; + uint16_t idx[4]; iq1m_scale_t scale; -#if defined __ARM_NEON - const int32x4_t mask = vdupq_n_s32(0x7); - const int32x4_t mone = vdupq_n_s32(1); - const int32x4_t mzero = vdupq_n_s32(0); - - lm_ggml_int8x16x4_t deltas; - deltas.val[0] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(+1)); - deltas.val[1] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(+1)); - deltas.val[2] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(-1)); - deltas.val[3] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(-1)); - - lm_ggml_int8x16x4_t q1b; - lm_ggml_int8x16x4_t q8b; - - uint32_t aux32; - const uint8_t * aux8 = (const uint8_t *)&aux32; - - float sumf = 0; - for (int i = 0; i < nb; ++i) { + for (int i = 0; i < nb; i++) { - const int8_t * q8 = y[i].qs; - const uint8_t * qs = x[i].qs; - const uint8_t * qh = x[i].qh; const uint16_t * sc = (const uint16_t *)x[i].scales; - scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + const float d = LM_GGML_FP16_TO_FP32(scale.f16); - int32x4_t sumi1 = mzero; - int32x4_t sumi2 = mzero; - - for (int ib = 0; ib < QK_K/32; ib += 2) { - - q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[0] << 8) & 0x700)))), - vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[0] << 4) & 0x700))))); - q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[1] << 8) & 0x700)))), - vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[1] << 4) & 0x700))))); - q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[2] << 8) & 0x700)))), - vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[2] << 4) & 0x700))))); - q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[3] << 8) & 0x700)))), - vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[3] << 4) & 0x700))))); - - q8b = lm_ggml_vld1q_s8_x4(q8); q8 += 64; - - const int32x4_t p1 = vpaddq_s32(lm_ggml_vdotq_s32(mzero, q1b.val[0], q8b.val[0]), lm_ggml_vdotq_s32(mzero, q1b.val[1], q8b.val[1])); - const int32x4_t p2 = vpaddq_s32(lm_ggml_vdotq_s32(mzero, q1b.val[2], q8b.val[2]), lm_ggml_vdotq_s32(mzero, q1b.val[3], q8b.val[3])); - const int32x4_t p12 = vpaddq_s32(p1, p2); - - const uint32_t * qh32 = (const uint32_t *)qh; // we are 4-byte aligned, so we can do that - aux32 = ((qh32[0] >> 3) & 0x01010101) | ((qh32[0] >> 6) & 0x02020202); - - const int32x4_t p3 = vpaddq_s32(lm_ggml_vdotq_s32(mzero, deltas.val[aux8[0]], q8b.val[0]), lm_ggml_vdotq_s32(mzero, deltas.val[aux8[1]], q8b.val[1])); - const int32x4_t p4 = vpaddq_s32(lm_ggml_vdotq_s32(mzero, deltas.val[aux8[2]], q8b.val[2]), lm_ggml_vdotq_s32(mzero, deltas.val[aux8[3]], q8b.val[3])); - const int32x4_t p34 = vpaddq_s32(p3, p4); - - int32x4_t scales_4 = lm_ggml_vld1q_u32(sc[ib/2] >> 0, sc[ib/2] >> 3, sc[ib/2] >> 6, sc[ib/2] >> 9); - - scales_4 = vaddq_s32(vshlq_n_s32(vandq_s32(scales_4, mask), 1), mone); - - sumi1 = vmlaq_s32(sumi1, scales_4, p12); - sumi2 = vmlaq_s32(sumi2, scales_4, p34); + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; - qs += 8; qh += 4; + for (int ib = 0; ib < QK_K/32; ++ib) { + const float dl1 = d * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1); + const float dl2 = d * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1); + idx[0] = qs[0] | ((qh[0] << 8) & 0x700); + idx[1] = qs[1] | ((qh[0] << 4) & 0x700); + idx[2] = qs[2] | ((qh[1] << 8) & 0x700); + idx[3] = qs[3] | ((qh[1] << 4) & 0x700); + delta[0] = qh[0] & 0x08 ? -IQ1S_DELTA : IQ1S_DELTA; + delta[1] = qh[0] & 0x80 ? -IQ1S_DELTA : IQ1S_DELTA; + delta[2] = qh[1] & 0x08 ? -IQ1S_DELTA : IQ1S_DELTA; + delta[3] = qh[1] & 0x80 ? -IQ1S_DELTA : IQ1S_DELTA; + for (int l = 0; l < 2; ++l) { + const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]); + for (int j = 0; j < 8; ++j) { + y[j] = dl1 * (grid[j] + delta[l]); + } + y += 8; + } + for (int l = 2; l < 4; ++l) { + const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]); + for (int j = 0; j < 8; ++j) { + y[j] = dl2 * (grid[j] + delta[l]); + } + y += 8; + } + qs += 4; + qh += 2; } - - sumf += y[i].d * LM_GGML_FP16_TO_FP32(scale.f16) * (vaddvq_s32(sumi1) + IQ1M_DELTA * vaddvq_s32(sumi2)); } +} - *s = sumf; - -#elif defined __AVX2__ - - const __m256i mask = _mm256_set1_epi16(0x7); - const __m256i mone = _mm256_set1_epi16(1); - - __m256 accum1 = _mm256_setzero_ps(); - __m256 accum2 = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { - - const int8_t * q8 = y[i].qs; - const uint8_t * qs = x[i].qs; - const uint8_t * qh = x[i].qh; - const uint16_t * sc = (const uint16_t *)x[i].scales; - - scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); - - __m256i sumi1 = _mm256_setzero_si256(); - __m256i sumi2 = _mm256_setzero_si256(); - for (int ib = 0; ib < QK_K/32; ib += 2) { - const __m256i q1b_1 = _mm256_set_epi64x( - iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)], - iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)] - ); - const __m256i q1b_2 = _mm256_set_epi64x( - iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)], - iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)] - ); - const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - - const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1); - const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2); - - const __m256i delta1 = _mm256_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, - qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101, - qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, - qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); - const __m256i delta2 = _mm256_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, - qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101, - qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, - qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); - - const __m256i dot3 = mul_add_epi8(delta1, q8b_1); - const __m256i dot4 = mul_add_epi8(delta2, q8b_2); +static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; - __m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 3), _mm_set1_epi16(sc[ib/2] >> 0)); - __m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 9), _mm_set1_epi16(sc[ib/2] >> 6)); +void dequantize_row_iq4_nl(const block_iq4_nl * restrict x, float * restrict y, int64_t k) { + assert(k % QK4_NL == 0); + const int64_t nb = k / QK4_NL; - scale1 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale1, mask), 1), mone); - scale2 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale2, mask), 1), mone); - const __m256i p1 = _mm256_madd_epi16(dot1, scale1); - const __m256i p2 = _mm256_madd_epi16(dot2, scale2); - const __m256i p3 = _mm256_madd_epi16(dot3, scale1); - const __m256i p4 = _mm256_madd_epi16(dot4, scale2); + for (int i = 0; i < nb; i++) { - sumi1 = _mm256_add_epi32(sumi1, _mm256_add_epi32(p1, p2)); - sumi2 = _mm256_add_epi32(sumi2, _mm256_add_epi32(p3, p4)); + const uint8_t * qs = x[i].qs; - qs += 8; qh += 4; + const float d = LM_GGML_FP16_TO_FP32(x[i].d); + for (int j = 0; j < QK4_NL/2; ++j) { + y[j+ 0] = d * kvalues_iq4nl[qs[j] & 0xf]; + y[j+QK4_NL/2] = d * kvalues_iq4nl[qs[j] >> 4]; } - - const __m256 d = _mm256_set1_ps(y[i].d * LM_GGML_FP16_TO_FP32(scale.f16)); - - accum1 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi1), accum1); - accum2 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi2), accum2); + y += QK4_NL; + qs += QK4_NL/2; } +} - *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2); - -#elif defined __AVX__ - const __m128i mask = _mm_set1_epi16(0x7); - const __m128i mone = _mm_set1_epi16(1); - - __m256 accum1 = _mm256_setzero_ps(); - __m256 accum2 = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { - - const int8_t * q8 = y[i].qs; - const uint8_t * qs = x[i].qs; - const uint8_t * qh = x[i].qh; - const uint16_t * sc = (const uint16_t *)x[i].scales; - - scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); - - __m128i sumi1_0 = _mm_setzero_si128(); - __m128i sumi1_1 = _mm_setzero_si128(); - __m128i sumi2_0 = _mm_setzero_si128(); - __m128i sumi2_1 = _mm_setzero_si128(); - for (int ib = 0; ib < QK_K/32; ib += 2) { - const __m128i q1b_1_0 = _mm_set_epi64x( - iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)]); - const __m128i q1b_1_1 = _mm_set_epi64x( - iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)]); - const __m128i q1b_2_0 = _mm_set_epi64x( - iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)]); - const __m128i q1b_2_1 = _mm_set_epi64x( - iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)]); - const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - - const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0); - const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1); - const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0); - const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1); - - const __m128i delta1_0 = _mm_set_epi64x(qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, - qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); - const __m128i delta1_1 = _mm_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, - qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); - const __m128i delta2_0 = _mm_set_epi64x(qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, - qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); - const __m128i delta2_1 = _mm_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, - qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); - - const __m128i dot3_0 = mul_add_epi8_sse(delta1_0, q8b_1_0); - const __m128i dot3_1 = mul_add_epi8_sse(delta1_1, q8b_1_1); - const __m128i dot4_0 = mul_add_epi8_sse(delta2_0, q8b_2_0); - const __m128i dot4_1 = mul_add_epi8_sse(delta2_1, q8b_2_1); - - __m128i scale1_0 = _mm_set1_epi16(sc[ib/2] >> 0); - __m128i scale1_1 = _mm_set1_epi16(sc[ib/2] >> 3); - __m128i scale2_0 = _mm_set1_epi16(sc[ib/2] >> 6); - __m128i scale2_1 = _mm_set1_epi16(sc[ib/2] >> 9); - - scale1_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_0, mask), 1), mone); - scale1_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_1, mask), 1), mone); - scale2_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_0, mask), 1), mone); - scale2_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_1, mask), 1), mone); - const __m128i p1_0 = _mm_madd_epi16(dot1_0, scale1_0); - const __m128i p1_1 = _mm_madd_epi16(dot1_1, scale1_1); - const __m128i p2_0 = _mm_madd_epi16(dot2_0, scale2_0); - const __m128i p2_1 = _mm_madd_epi16(dot2_1, scale2_1); - const __m128i p3_0 = _mm_madd_epi16(dot3_0, scale1_0); - const __m128i p3_1 = _mm_madd_epi16(dot3_1, scale1_1); - const __m128i p4_0 = _mm_madd_epi16(dot4_0, scale2_0); - const __m128i p4_1 = _mm_madd_epi16(dot4_1, scale2_1); - - sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0)); - sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1)); - sumi2_0 = _mm_add_epi32(sumi2_0, _mm_add_epi32(p3_0, p4_0)); - sumi2_1 = _mm_add_epi32(sumi2_1, _mm_add_epi32(p3_1, p4_1)); - - qs += 8; qh += 4; - } - - const __m256 d = _mm256_set1_ps(y[i].d * LM_GGML_FP16_TO_FP32(scale.f16)); - - accum1 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum1); - accum2 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi2_1, sumi2_0))), accum2); - } - - *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2); - -#else - - int sum1[2], sum2[2], delta[4]; +void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; - float sumf = 0; for (int i = 0; i < nb; i++) { - const int8_t * q8 = y[i].qs; - const uint8_t * qs = x[i].qs; - const uint8_t * qh = x[i].qh; - const uint16_t * sc = (const uint16_t *)x[i].scales; + const uint8_t * qs = x[i].qs; - scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + const float d = LM_GGML_FP16_TO_FP32(x[i].d); - int sumi1 = 0, sumi2 = 0; for (int ib = 0; ib < QK_K/32; ++ib) { - delta[0] = qh[0] & 0x08 ? -1 : 1; - delta[1] = qh[0] & 0x80 ? -1 : 1; - delta[2] = qh[1] & 0x08 ? -1 : 1; - delta[3] = qh[1] & 0x80 ? -1 : 1; - sum1[0] = sum1[1] = sum2[0] = sum2[1] = 0; - for (int l = 0; l < 4; ++l) { - const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((uint16_t)qh[l/2] << (8 - 4*(l%2))) & 0x700))); - int lsum1 = 0, lsum2 = 0; - for (int j = 0; j < 8; ++j) { - lsum1 += q8[j] * grid[j]; - lsum2 += q8[j]; - } - q8 += 8; - sum1[l/2] += lsum1; - sum2[l/2] += lsum2*delta[l]; + const int ls = ((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4); + const float dl = d * (ls - 32); + for (int j = 0; j < 16; ++j) { + y[j+ 0] = dl * kvalues_iq4nl[qs[j] & 0xf]; + y[j+16] = dl * kvalues_iq4nl[qs[j] >> 4]; } - - const int ls1 = 2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1; - const int ls2 = 2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1; - - sumi1 += sum1[0] * ls1 + sum1[1] * ls2; - sumi2 += sum2[0] * ls1 + sum2[1] * ls2; - qs += 4; - qh += 2; + y += 32; + qs += 16; } - - sumf += LM_GGML_FP16_TO_FP32(scale.f16) * y[i].d * (sumi1 + IQ1M_DELTA * sumi2); } - - *s = sumf; - -#endif } -void lm_ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - assert(n % QK4_NL == 0); - static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same"); - - const block_iq4_nl * restrict x = vx; - const block_q8_0 * restrict y = vy; - - const int nb = n / QK4_NL; - - int ib = 0; - float sumf = 0; - -#if defined __ARM_NEON - const int8x16_t values = vld1q_s8(kvalues_iq4nl); - const uint8x16_t m4b = vdupq_n_u8(0x0f); - uint8x16x2_t q4bits; - int8x16x4_t q4b; - int8x16x4_t q8b; - int32x4_t prod_1, prod_2; - - for (; ib + 1 < nb; ib += 2) { - - q4bits.val[0] = vld1q_u8(x[ib + 0].qs); - q4bits.val[1] = vld1q_u8(x[ib + 1].qs); - q8b.val[0] = vld1q_s8(y[ib + 0].qs); - q8b.val[1] = vld1q_s8(y[ib + 0].qs + 16); - q8b.val[2] = vld1q_s8(y[ib + 1].qs); - q8b.val[3] = vld1q_s8(y[ib + 1].qs + 16); - - q4b.val[0] = lm_ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b)); - q4b.val[1] = lm_ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4)); - q4b.val[2] = lm_ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b)); - q4b.val[3] = lm_ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4)); - - prod_1 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]); - prod_2 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]); - - sumf += - LM_GGML_FP16_TO_FP32(x[ib+0].d) * LM_GGML_FP16_TO_FP32(y[ib + 0].d) * vaddvq_s32(prod_1) + - LM_GGML_FP16_TO_FP32(x[ib+1].d) * LM_GGML_FP16_TO_FP32(y[ib + 1].d) * vaddvq_s32(prod_2); - } - -#elif defined __AVX2__ - - const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl); - const __m128i m4b = _mm_set1_epi8(0x0f); - const __m256i mone = _mm256_set1_epi16(1); - - __m256 accum1 = _mm256_setzero_ps(); - __m256 accum2 = _mm256_setzero_ps(); - for (; ib + 1 < nb; ib += 2) { - const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs); - const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs); - const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs); - const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs); - const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)), - _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b))); - const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)), - _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b))); - const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1); - const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); - const __m256i p_1 = _mm256_madd_epi16(p16_1, mone); - const __m256i p_2 = _mm256_madd_epi16(p16_2, mone); - accum1 = _mm256_fmadd_ps(_mm256_set1_ps(LM_GGML_FP16_TO_FP32(y[ib + 0].d)*LM_GGML_FP16_TO_FP32(x[ib + 0].d)), - _mm256_cvtepi32_ps(p_1), accum1); - accum2 = _mm256_fmadd_ps(_mm256_set1_ps(LM_GGML_FP16_TO_FP32(y[ib + 1].d)*LM_GGML_FP16_TO_FP32(x[ib + 1].d)), - _mm256_cvtepi32_ps(p_2), accum2); - } - - sumf = hsum_float_8(_mm256_add_ps(accum1, accum2)); - -#elif defined __AVX__ - const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl); - const __m128i m4b = _mm_set1_epi8(0x0f); - const __m128i mone = _mm_set1_epi16(1); - - __m256 accum1 = _mm256_setzero_ps(); - __m256 accum2 = _mm256_setzero_ps(); - for (; ib + 1 < nb; ib += 2) { - const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs); - const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs); - const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs); - const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1); - const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs); - const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1); - - const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)); - const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)); - const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)); - const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)); - const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0); - const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1); - const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0); - const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1); - const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, mone); - const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, mone); - const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, mone); - const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, mone); - accum1 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(LM_GGML_FP16_TO_FP32(y[ib + 0].d)*LM_GGML_FP16_TO_FP32(x[ib + 0].d)), - _mm256_cvtepi32_ps(MM256_SET_M128I(p_1_1, p_1_0))), accum1); - accum2 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(LM_GGML_FP16_TO_FP32(y[ib + 1].d)*LM_GGML_FP16_TO_FP32(x[ib + 1].d)), - _mm256_cvtepi32_ps(MM256_SET_M128I(p_2_1, p_2_0))), accum2); - } - - sumf = hsum_float_8(_mm256_add_ps(accum1, accum2)); - -#elif defined(__POWER9_VECTOR__) - const vector signed char lowMask = vec_splats((signed char)0xF); - const vector signed int v0 = vec_splats((int32_t)0); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - - const vector signed char values = vec_xl( 0, kvalues_iq4nl); - -#pragma GCC unroll 4 - for (; ib < nb; ++ib) { - __builtin_prefetch(x[ib].qs, 0, 1); - __builtin_prefetch(y[ib].qs, 0, 1); - - - vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[ib].d)); - vector float vyd = vec_splats(LM_GGML_FP16_TO_FP32(y[ib].d)); - vector float vd = vec_mul(vxd, vyd); - - vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); - vector signed char q4x0 = vec_and(qxs, lowMask); - vector signed char q4x1 = vec_sr(qxs, v4); - - q4x0 = vec_perm(values, values, (vector unsigned char)q4x0); - q4x1 = vec_perm(values, values, (vector unsigned char)q4x1); - - vector signed char q8y0 = vec_xl( 0, y[ib].qs); - vector signed char q8y1 = vec_xl(16, y[ib].qs); - - vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0)); - vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1)); - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - - vsumi0 = vec_sum4s(qv0, vsumi0); - vsumi1 = vec_sum4s(qv1, vsumi1); - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - } - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - sumf = vec_extract(vsumf0, 0); - -#elif defined (__loongarch_asx) - - const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0); - const __m128i m4b = __lsx_vreplgr2vr_b(0x0f); - const __m256i mone = __lasx_xvreplgr2vr_h(1); - - __m256 accum1 = (__m256)__lasx_xvldi(0); - __m256 accum2 = (__m256)__lasx_xvldi(0); - for (; ib + 1 < nb; ib += 2) { - const __m128i q4bits_1 = __lsx_vld((const __m128i*)x[ib + 0].qs, 0); - const __m128i q4bits_2 = __lsx_vld((const __m128i*)x[ib + 1].qs, 0); - const __m256i q8b_1 = __lasx_xvld((const __m256i *)y[ib + 0].qs, 0); - const __m256i q8b_2 = __lasx_xvld((const __m256i *)y[ib + 1].qs, 0); - const __m256i q4b_1 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b)), - lsx_shuffle_b(values128, __lsx_vand_v(q4bits_1, m4b))); - const __m256i q4b_2 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b)), - lsx_shuffle_b(values128, __lsx_vand_v(q4bits_2, m4b))); - const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1); - const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); - const __m256i p_1 = lasx_madd_h(p16_1, mone); - const __m256i p_2 = lasx_madd_h(p16_2, mone); - accum1 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(LM_GGML_FP16_TO_FP32(y[ib + 0].d)*LM_GGML_FP16_TO_FP32(x[ib + 0].d)), - __lasx_xvffint_s_w(p_1), accum1); - accum2 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(LM_GGML_FP16_TO_FP32(y[ib + 1].d)*LM_GGML_FP16_TO_FP32(x[ib + 1].d)), - __lasx_xvffint_s_w(p_2), accum2); - } +//===================================== Q8_K ============================================== + +void quantize_row_q8_K_ref(const float * restrict x, block_q8_K * restrict y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; - sumf = hsum_float_8(__lasx_xvfadd_s(accum1, accum2)); + for (int i = 0; i < nb; i++) { -#endif - for (; ib < nb; ++ib) { - const float d = LM_GGML_FP16_TO_FP32(y[ib].d)*LM_GGML_FP16_TO_FP32(x[ib].d); - int sumi1 = 0, sumi2 = 0; - for (int j = 0; j < QK4_NL/2; ++j) { - sumi1 += y[ib].qs[j+ 0] * kvalues_iq4nl[x[ib].qs[j] & 0xf]; - sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >> 4]; + float max = 0; + float amax = 0; + for (int j = 0; j < QK_K; ++j) { + float ax = fabsf(x[j]); + if (ax > amax) { + amax = ax; max = x[j]; + } + } + if (!amax) { + y[i].d = 0; + memset(y[i].qs, 0, QK_K); + x += QK_K; + continue; + } + //const float iscale = -128.f/max; + // We need this change for IQ2_XXS, else the AVX implementation becomes very awkward + const float iscale = -127.f/max; + for (int j = 0; j < QK_K; ++j) { + int v = nearest_int(iscale*x[j]); + y[i].qs[j] = MIN(127, v); + } + for (int j = 0; j < QK_K/16; ++j) { + int sum = 0; + for (int ii = 0; ii < 16; ++ii) { + sum += y[i].qs[j*16 + ii]; + } + y[i].bsums[j] = sum; } - sumf += d * (sumi1 + sumi2); + y[i].d = 1/iscale; + x += QK_K; } - *s = sumf; } -void lm_ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - assert(n % QK_K == 0); - - const block_iq4_xs * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - -#if defined __ARM_NEON - const int8x16_t values = vld1q_s8(kvalues_iq4nl); - const uint8x16_t m4b = vdupq_n_u8(0x0f); - lm_ggml_uint8x16x2_t q4bits; - lm_ggml_int8x16x4_t q4b; - lm_ggml_int8x16x4_t q8b; - int32x4_t prod_1, prod_2; - - float sumf = 0; - - for (int ibl = 0; ibl < nb; ++ibl) { - - const int8_t * q8 = y[ibl].qs; - const uint8_t * q4 = x[ibl].qs; - uint16_t h = x[ibl].scales_h; - - int sumi1 = 0, sumi2 = 0; - for (int ib = 0; ib < QK_K/64; ++ib) { - - q4bits = lm_ggml_vld1q_u8_x2(q4); q4 += 32; - q8b = lm_ggml_vld1q_s8_x4(q8); q8 += 64; - - q4b.val[0] = lm_ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b)); - q4b.val[1] = lm_ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4)); - q4b.val[2] = lm_ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b)); - q4b.val[3] = lm_ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4)); - - prod_1 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]); - prod_2 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]); - - int ls1 = ((x[ibl].scales_l[ib] & 0xf) | ((h << 4) & 0x30)) - 32; - int ls2 = ((x[ibl].scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32; - h >>= 4; - sumi1 += vaddvq_s32(prod_1) * ls1; - sumi2 += vaddvq_s32(prod_2) * ls2; - - } - - sumf += LM_GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2); - } - - *s = sumf; - -#elif defined __AVX2__ - - const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl); - const __m128i m4b = _mm_set1_epi8(0x0f); - - __m256 accum = _mm256_setzero_ps(); - for (int ibl = 0; ibl < nb; ++ibl) { - const uint8_t * qs = x[ibl].qs; - const int8_t * q8 = y[ibl].qs; - uint16_t sh = x[ibl].scales_h; - __m256i sumi1 = _mm256_setzero_si256(); - __m256i sumi2 = _mm256_setzero_si256(); - for (int ib = 0; ib < QK_K/32; ib += 2) { - const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)qs); qs += 16; - const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)qs); qs += 16; - const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)), - _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b))); - const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)), - _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b))); - const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1); - const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); - const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32; - const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32; - sh >>= 4; - const __m256i p_1 = _mm256_madd_epi16(p16_1, _mm256_set1_epi16(ls1)); - const __m256i p_2 = _mm256_madd_epi16(p16_2, _mm256_set1_epi16(ls2)); - sumi1 = _mm256_add_epi32(p_1, sumi1); - sumi2 = _mm256_add_epi32(p_2, sumi2); - } - accum = _mm256_fmadd_ps(_mm256_set1_ps(LM_GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d), - _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accum); - } - - *s = hsum_float_8(accum); - -#elif defined __AVX__ - const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl); - const __m128i m4b = _mm_set1_epi8(0x0f); - - __m256 accum = _mm256_setzero_ps(); - for (int ibl = 0; ibl < nb; ++ibl) { - const uint8_t * qs = x[ibl].qs; - const int8_t * q8 = y[ibl].qs; - uint16_t sh = x[ibl].scales_h; - __m128i sumi1_0 = _mm_setzero_si128(); - __m128i sumi1_1 = _mm_setzero_si128(); - __m128i sumi2_0 = _mm_setzero_si128(); - __m128i sumi2_1 = _mm_setzero_si128(); - for (int ib = 0; ib < QK_K/32; ib += 2) { - const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)qs); qs += 16; - const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)qs); qs += 16; - const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)); - const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)); - const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)); - const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)); - const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0); - const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1); - const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0); - const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1); - const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32; - const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32; - sh >>= 4; - const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, _mm_set1_epi16(ls1)); - const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, _mm_set1_epi16(ls1)); - const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, _mm_set1_epi16(ls2)); - const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, _mm_set1_epi16(ls2)); - sumi1_0 = _mm_add_epi32(p_1_0, sumi1_0); - sumi1_1 = _mm_add_epi32(p_1_1, sumi1_1); - sumi2_0 = _mm_add_epi32(p_2_0, sumi2_0); - sumi2_1 = _mm_add_epi32(p_2_1, sumi2_1); - } - __m128i sumi12_0 = _mm_add_epi32(sumi1_0, sumi2_0); - __m128i sumi12_1 = _mm_add_epi32(sumi1_1, sumi2_1); - accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(LM_GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d), - _mm256_cvtepi32_ps(MM256_SET_M128I(sumi12_1, sumi12_0))), accum); - } - - *s = hsum_float_8(accum); - -#elif defined(__POWER9_VECTOR__) - const vector signed char lowMask = vec_splats((signed char)0xF); - const vector int v0 = vec_splats((int32_t)0); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - - const vector signed char values = vec_xl( 0, kvalues_iq4nl); - - for (int ibl = 0; ibl < nb; ++ibl) { - - vector float vxd = vec_splats(LM_GGML_FP16_TO_FP32(x[ibl].d)); - vector float vyd = vec_splats(y[ibl].d); - vector float vd = vec_mul(vxd, vyd); - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - vector signed int vsumi2 = v0; - vector signed int vsumi3 = v0; - - uint16_t h = x[ibl].scales_h; - - const uint8_t * restrict q4 = x[ibl].qs; - const uint8_t * restrict sc = x[ibl].scales_l; - const int8_t * restrict q8 = y[ibl].qs; - - for (int ib = 0; ib < QK_K/64; ib ++ ) { - __builtin_prefetch(q4, 0, 1); - __builtin_prefetch(q8, 0, 1); - - vector signed char qxs0 = (vector signed char)vec_xl( 0, q4); - vector signed char qxs1 = (vector signed char)vec_xl(16, q4); - q4 += 32; - - vector signed char q4x00 = (vector signed char)vec_and(qxs0, lowMask); - vector signed char q4x01 = (vector signed char)vec_sr(qxs0, v4); - vector signed char q4x10 = (vector signed char)vec_and(qxs1, lowMask); - vector signed char q4x11 = (vector signed char)vec_sr(qxs1, v4); - - q4x00 = vec_perm(values, values, (vector unsigned char)q4x00); - q4x01 = vec_perm(values, values, (vector unsigned char)q4x01); - q4x10 = vec_perm(values, values, (vector unsigned char)q4x10); - q4x11 = vec_perm(values, values, (vector unsigned char)q4x11); - - vector signed char q8y0 = vec_xl( 0, q8); - vector signed char q8y1 = vec_xl(16, q8); - vector signed char q8y2 = vec_xl(32, q8); - vector signed char q8y3 = vec_xl(48, q8); - q8 += 64; - - vector signed short qv0 = vec_add(vec_mule(q4x00, q8y0), vec_mulo(q4x00, q8y0)); - vector signed short qv1 = vec_add(vec_mule(q4x01, q8y1), vec_mulo(q4x01, q8y1)); - vector signed short qv2 = vec_add(vec_mule(q4x10, q8y2), vec_mulo(q4x10, q8y2)); - vector signed short qv3 = vec_add(vec_mule(q4x11, q8y3), vec_mulo(q4x11, q8y3)); - - const uint16_t ls0 = (uint16_t)(((sc[0] & 0xf) | ((h << 4) & 0x30)) - 32); - const uint16_t ls1 = (uint16_t)(((sc[0] >> 4) | ((h << 2) & 0x30)) - 32); - h >>= 4; - sc ++; - - vector signed short vscales01 = vec_splats((int16_t)ls0); - vector signed short vscales23 = vec_splats((int16_t)ls1); - - vsumi0 = vec_msum(qv0, vscales01, vsumi0); - vsumi1 = vec_msum(qv1, vscales01, vsumi1); - vsumi2 = vec_msum(qv2, vscales23, vsumi2); - vsumi3 = vec_msum(qv3, vscales23, vsumi3); - } - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = vec_extract(vsumf0, 0); - -#elif defined(__loongarch_asx) - - const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0); - const __m128i m4b = __lsx_vreplgr2vr_b(0x0f); - - __m256 accum = (__m256)__lasx_xvldi(0); - __m256i tmp1; - __m128i tmp0, tmp2, tmp3, tmp4, mask_8f, mask; - - mask_8f = __lsx_vreplgr2vr_b(0x8f); - for (int ibl = 0; ibl < nb; ++ibl) { - const uint8_t * qs = x[ibl].qs; - const int8_t * q8 = y[ibl].qs; - uint16_t sh = x[ibl].scales_h; - __m256i sumi1 = __lasx_xvldi(0); - __m256i sumi2 = __lasx_xvldi(0); - __m128i zero = __lsx_vldi(0); - for (int ib = 0; ib < QK_K/32; ib += 2) { - const __m128i q4bits_1 = __lsx_vld((const __m128i*)qs, 0); qs += 16; - const __m128i q4bits_2 = __lsx_vld((const __m128i*)qs, 0); qs += 16; - const __m256i q8b_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - const __m256i q8b_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b), mask_8f); - tmp0 = __lsx_vori_b(tmp2, 0x10); - mask = __lsx_vsle_b(zero, tmp2); - tmp3 = __lsx_vand_v(tmp0, mask); - tmp3 = __lsx_vshuf_b(values128, zero, tmp3); - - tmp2 = __lsx_vand_v(__lsx_vand_v(q4bits_1, m4b), mask_8f); - tmp0 = __lsx_vori_b(tmp2, 0x10); - mask = __lsx_vsle_b(zero, tmp2); - tmp4 = __lsx_vand_v(tmp0, mask); - tmp4 = __lsx_vshuf_b(values128, zero, tmp4); - - const __m256i q4b_1 = lasx_insertf128(tmp3, tmp4); - - tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b), mask_8f); - tmp0 = __lsx_vori_b(tmp2, 0x10); - mask = __lsx_vsle_b(zero, tmp2); - tmp3 = __lsx_vand_v(tmp0, mask); - tmp3 = __lsx_vshuf_b(values128, zero, tmp3); - - tmp2 = __lsx_vand_v(__lsx_vand_v(q4bits_2, m4b), mask_8f); - tmp0 = __lsx_vori_b(tmp2, 0x10); - mask = __lsx_vsle_b(zero, tmp2); - tmp4 = __lsx_vand_v(tmp0, mask); - tmp4 = __lsx_vshuf_b(values128, zero, tmp4); - - const __m256i q4b_2 = lasx_insertf128(tmp3, tmp4); - - const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1); - const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); - const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32; - const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32; - sh >>= 4; - __m256i tmp5, tmp6; - tmp1 = __lasx_xvreplgr2vr_h(ls1); - tmp5 = __lasx_xvmulwev_w_h(p16_1, tmp1); - tmp6 = __lasx_xvmulwod_w_h(p16_1, tmp1); - const __m256i p_1 = __lasx_xvadd_w(tmp5, tmp6); - tmp1 = __lasx_xvreplgr2vr_h(ls2); - tmp5 = __lasx_xvmulwev_w_h(p16_2, tmp1); - tmp6 = __lasx_xvmulwod_w_h(p16_2, tmp1); - const __m256i p_2 = __lasx_xvadd_w(tmp5, tmp6); - sumi1 = __lasx_xvadd_w(p_1, sumi1); - sumi2 = __lasx_xvadd_w(p_2, sumi2); - } - accum = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(LM_GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d), - __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accum); - } - - *s = hsum_float_8(accum); +void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; -#else - float sumf = 0; - for (int ibl = 0; ibl < nb; ++ibl) { - const float d4d8 = LM_GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d; - uint16_t h = x[ibl].scales_h; - const uint8_t * qs = x[ibl].qs; - const int8_t * q8 = y[ibl].qs; - for (int ib = 0; ib < QK_K/32; ib += 2) { - const uint8_t ls1 = (x[ibl].scales_l[ib/2] & 0xf) | ((h << 4) & 0x30); - const uint8_t ls2 = (x[ibl].scales_l[ib/2] >> 4) | ((h << 2) & 0x30); - h >>= 4; - const float d1 = d4d8*(ls1 - 32); - const float d2 = d4d8*(ls2 - 32); - int sumi1 = 0, sumi2 = 0; - for (int j = 0; j < 16; ++j) { - sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf]; - sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4]; - } - sumf += d1 * (sumi1 + sumi2); - qs += 16; - q8 += 32; - sumi1 = sumi2 = 0; - for (int j = 0; j < 16; ++j) { - sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf]; - sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4]; - } - sumf += d2 * (sumi1 + sumi2); - qs += 16; - q8 += 32; + for (int i = 0; i < nb; i++) { + for (int j = 0; j < QK_K; ++j) { + *y++ = x[i].d * x[i].qs[j]; } } - *s = sumf; -#endif } // ================================ IQ2 quantization ============================================= @@ -14249,12 +3770,6 @@ size_t quantize_iq3_xxs(const float * restrict src, void * restrict dst, int64_t return nrow * nblock * sizeof(block_iq3_xxs); } -void quantize_row_iq3_xxs(const float * restrict x, void * restrict vy, int64_t k) { - assert(k % QK_K == 0); - block_iq3_xxs * restrict y = vy; - quantize_row_iq3_xxs_ref(x, y, k); -} - void quantize_row_iq3_xxs_ref(const float * restrict x, block_iq3_xxs * restrict y, int64_t k) { assert(k % QK_K == 0); quantize_row_iq3_xxs_impl(256, x, y, k, NULL); @@ -14465,12 +3980,6 @@ size_t quantize_iq3_s(const float * restrict src, void * restrict dst, int64_t n return nrow * nblock * sizeof(block_iq3_s); } -void quantize_row_iq3_s(const float * restrict x, void * restrict vy, int64_t k) { - assert(k % QK_K == 0); - block_iq3_s * restrict y = vy; - quantize_row_iq3_s_ref(x, y, k); -} - void quantize_row_iq3_s_ref(const float * restrict x, block_iq3_s * restrict y, int64_t k) { assert(k % QK_K == 0); quantize_iq3_s(x, y, 1, k, NULL); @@ -15194,7 +4703,8 @@ size_t quantize_iq4_nl(const float * restrict src, void * restrict dst, int64_t return nrow * nblock * sizeof(block_iq4_nl); } -void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int64_t k) { +//void quantize_row_iq4_nl_ref(const float * restrict x, void * restrict vy, int64_t k) { +void quantize_row_iq4_nl_ref(const float * restrict x, block_iq4_nl * restrict y, int64_t k) { LM_GGML_ASSERT(k%QK4_NL == 0); int64_t nblock = k/QK4_NL; uint8_t L[QK4_NL]; @@ -15202,18 +4712,13 @@ void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int64_t k uint16_t unused_h; uint8_t * unused_l = NULL; float scale; - block_iq4_nl * iq4 = (block_iq4_nl *)vy; + block_iq4_nl * iq4 = y; for (int ibl = 0; ibl < nblock; ++ibl) { quantize_row_iq4_nl_impl(QK4_NL, 32, x + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, &unused_h, unused_l, &scale, weight, L, kvalues_iq4nl, NULL, -1); } } -void quantize_row_iq4_nl_ref(const float * restrict x, block_iq4_nl * restrict y, int64_t k) { - assert(k % QK4_NL == 0); - quantize_row_iq4_nl(x, y, k); -} - size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { LM_GGML_ASSERT(n_per_row%QK_K == 0); int64_t nblock = n_per_row/QK_K; @@ -15234,12 +4739,6 @@ size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int64_t return nrow * nblock * sizeof(block_iq4_xs); } -void quantize_row_iq4_xs(const float * restrict x, void * restrict vy, int64_t k) { - assert(k % QK_K == 0); - block_iq4_xs * restrict y = vy; - quantize_row_iq4_xs_ref(x, y, k); -} - void quantize_row_iq4_xs_ref(const float * restrict x, block_iq4_xs * restrict y, int64_t k) { assert(k % QK_K == 0); quantize_iq4_xs(x, y, 1, k, NULL); @@ -15432,11 +4931,7 @@ void quantize_row_iq2_s_ref(const float * restrict x, block_iq2_s * restrict y, quantize_iq2_s(x, y, 1, k, NULL); } -void quantize_row_iq2_s(const float * restrict x, void * restrict vy, int64_t k) { - assert(k % QK_K == 0); - block_iq2_s * restrict y = vy; - quantize_row_iq2_s_ref(x, y, k); -} +// =============================== data validation static bool validate_float(float f, size_t i) { if (isinf(f)) { diff --git a/cpp/ggml-quants.h b/cpp/ggml-quants.h index 31613bbe..23309cd0 100644 --- a/cpp/ggml-quants.h +++ b/cpp/ggml-quants.h @@ -11,136 +11,89 @@ extern "C" { #endif +// NOTE: these functions are defined as LM_GGML_API because they used by the CPU backend + // Quantization -void quantize_row_q4_0_ref(const float * LM_GGML_RESTRICT x, block_q4_0 * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_q4_1_ref(const float * LM_GGML_RESTRICT x, block_q4_1 * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_q5_0_ref(const float * LM_GGML_RESTRICT x, block_q5_0 * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_q5_1_ref(const float * LM_GGML_RESTRICT x, block_q5_1 * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_q8_0_ref(const float * LM_GGML_RESTRICT x, block_q8_0 * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_q8_1_ref(const float * LM_GGML_RESTRICT x, block_q8_1 * LM_GGML_RESTRICT y, int64_t k); - -void quantize_row_q2_K_ref(const float * LM_GGML_RESTRICT x, block_q2_K * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_q3_K_ref(const float * LM_GGML_RESTRICT x, block_q3_K * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_q4_K_ref(const float * LM_GGML_RESTRICT x, block_q4_K * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_q5_K_ref(const float * LM_GGML_RESTRICT x, block_q5_K * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_q6_K_ref(const float * LM_GGML_RESTRICT x, block_q6_K * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_q8_K_ref(const float * LM_GGML_RESTRICT x, block_q8_K * LM_GGML_RESTRICT y, int64_t k); - -void quantize_row_tq1_0_ref(const float * LM_GGML_RESTRICT x, block_tq1_0 * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_tq2_0_ref(const float * LM_GGML_RESTRICT x, block_tq2_0 * LM_GGML_RESTRICT y, int64_t k); - -void quantize_row_iq3_xxs_ref(const float * LM_GGML_RESTRICT x, block_iq3_xxs * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_iq4_nl_ref (const float * LM_GGML_RESTRICT x, block_iq4_nl * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_iq4_xs_ref (const float * LM_GGML_RESTRICT x, block_iq4_xs * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_iq3_s_ref (const float * LM_GGML_RESTRICT x, block_iq3_s * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_iq2_s_ref (const float * LM_GGML_RESTRICT x, block_iq2_s * LM_GGML_RESTRICT y, int64_t k); - -void quantize_row_q4_0(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_q4_1(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_q5_0(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_q5_1(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_q8_0(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_q8_1(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); - -void quantize_row_q2_K(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_q3_K(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_q4_K(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_q5_K(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_q6_K(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_q8_K(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); - -void quantize_row_tq1_0(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_tq2_0(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); - -void quantize_row_iq3_xxs(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_iq4_nl (const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_iq4_xs (const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_iq3_s (const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); -void quantize_row_iq2_s (const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); +LM_GGML_API void quantize_row_q4_0_ref(const float * LM_GGML_RESTRICT x, block_q4_0 * LM_GGML_RESTRICT y, int64_t k); +LM_GGML_API void quantize_row_q4_1_ref(const float * LM_GGML_RESTRICT x, block_q4_1 * LM_GGML_RESTRICT y, int64_t k); +LM_GGML_API void quantize_row_q5_0_ref(const float * LM_GGML_RESTRICT x, block_q5_0 * LM_GGML_RESTRICT y, int64_t k); +LM_GGML_API void quantize_row_q5_1_ref(const float * LM_GGML_RESTRICT x, block_q5_1 * LM_GGML_RESTRICT y, int64_t k); +LM_GGML_API void quantize_row_q8_0_ref(const float * LM_GGML_RESTRICT x, block_q8_0 * LM_GGML_RESTRICT y, int64_t k); +LM_GGML_API void quantize_row_q8_1_ref(const float * LM_GGML_RESTRICT x, block_q8_1 * LM_GGML_RESTRICT y, int64_t k); + +LM_GGML_API void quantize_row_q2_K_ref(const float * LM_GGML_RESTRICT x, block_q2_K * LM_GGML_RESTRICT y, int64_t k); +LM_GGML_API void quantize_row_q3_K_ref(const float * LM_GGML_RESTRICT x, block_q3_K * LM_GGML_RESTRICT y, int64_t k); +LM_GGML_API void quantize_row_q4_K_ref(const float * LM_GGML_RESTRICT x, block_q4_K * LM_GGML_RESTRICT y, int64_t k); +LM_GGML_API void quantize_row_q5_K_ref(const float * LM_GGML_RESTRICT x, block_q5_K * LM_GGML_RESTRICT y, int64_t k); +LM_GGML_API void quantize_row_q6_K_ref(const float * LM_GGML_RESTRICT x, block_q6_K * LM_GGML_RESTRICT y, int64_t k); +LM_GGML_API void quantize_row_q8_K_ref(const float * LM_GGML_RESTRICT x, block_q8_K * LM_GGML_RESTRICT y, int64_t k); + +LM_GGML_API void quantize_row_tq1_0_ref(const float * LM_GGML_RESTRICT x, block_tq1_0 * LM_GGML_RESTRICT y, int64_t k); +LM_GGML_API void quantize_row_tq2_0_ref(const float * LM_GGML_RESTRICT x, block_tq2_0 * LM_GGML_RESTRICT y, int64_t k); + +LM_GGML_API void quantize_row_iq3_xxs_ref(const float * LM_GGML_RESTRICT x, block_iq3_xxs * LM_GGML_RESTRICT y, int64_t k); +LM_GGML_API void quantize_row_iq4_nl_ref (const float * LM_GGML_RESTRICT x, block_iq4_nl * LM_GGML_RESTRICT y, int64_t k); +LM_GGML_API void quantize_row_iq4_xs_ref (const float * LM_GGML_RESTRICT x, block_iq4_xs * LM_GGML_RESTRICT y, int64_t k); +LM_GGML_API void quantize_row_iq3_s_ref (const float * LM_GGML_RESTRICT x, block_iq3_s * LM_GGML_RESTRICT y, int64_t k); +LM_GGML_API void quantize_row_iq2_s_ref (const float * LM_GGML_RESTRICT x, block_iq2_s * LM_GGML_RESTRICT y, int64_t k); // Dequantization -void dequantize_row_q4_0(const block_q4_0 * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); -void dequantize_row_q4_1(const block_q4_1 * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); -void dequantize_row_q5_0(const block_q5_0 * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); -void dequantize_row_q5_1(const block_q5_1 * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); -void dequantize_row_q8_0(const block_q8_0 * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); -//void dequantize_row_q8_1(const block_q8_1 * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); - -void dequantize_row_q2_K(const block_q2_K * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); -void dequantize_row_q3_K(const block_q3_K * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); -void dequantize_row_q4_K(const block_q4_K * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); -void dequantize_row_q5_K(const block_q5_K * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); -void dequantize_row_q6_K(const block_q6_K * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); -void dequantize_row_q8_K(const block_q8_K * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); - -void dequantize_row_tq1_0(const block_tq1_0 * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); -void dequantize_row_tq2_0(const block_tq2_0 * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); - -void dequantize_row_iq2_xxs(const block_iq2_xxs * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); -void dequantize_row_iq2_xs (const block_iq2_xs * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); -void dequantize_row_iq2_s (const block_iq2_s * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); -void dequantize_row_iq3_xxs(const block_iq3_xxs * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); -void dequantize_row_iq1_s (const block_iq1_s * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); -void dequantize_row_iq1_m (const block_iq1_m * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); -void dequantize_row_iq4_nl (const block_iq4_nl * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); -void dequantize_row_iq4_xs (const block_iq4_xs * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); -void dequantize_row_iq3_s (const block_iq3_s * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); - -// Dot product -void lm_ggml_vec_dot_q4_0_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); -void lm_ggml_vec_dot_q4_1_q8_1(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); -void lm_ggml_vec_dot_q5_0_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); -void lm_ggml_vec_dot_q5_1_q8_1(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); -void lm_ggml_vec_dot_q8_0_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); - -void lm_ggml_vec_dot_q2_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); -void lm_ggml_vec_dot_q3_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); -void lm_ggml_vec_dot_q4_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); -void lm_ggml_vec_dot_q5_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); -void lm_ggml_vec_dot_q6_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); - -void lm_ggml_vec_dot_tq1_0_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); -void lm_ggml_vec_dot_tq2_0_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); - -void lm_ggml_vec_dot_iq2_xxs_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); -void lm_ggml_vec_dot_iq2_xs_q8_K (int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); -void lm_ggml_vec_dot_iq2_s_q8_K (int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); -void lm_ggml_vec_dot_iq3_xxs_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); -void lm_ggml_vec_dot_iq1_s_q8_K (int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); -void lm_ggml_vec_dot_iq1_m_q8_K (int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); -void lm_ggml_vec_dot_iq4_nl_q8_0 (int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); -void lm_ggml_vec_dot_iq4_xs_q8_K (int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); -void lm_ggml_vec_dot_iq3_s_q8_K (int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc); +LM_GGML_API void dequantize_row_q4_0(const block_q4_0 * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); +LM_GGML_API void dequantize_row_q4_1(const block_q4_1 * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); +LM_GGML_API void dequantize_row_q5_0(const block_q5_0 * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); +LM_GGML_API void dequantize_row_q5_1(const block_q5_1 * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); +LM_GGML_API void dequantize_row_q8_0(const block_q8_0 * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); +//LM_GGML_API void dequantize_row_q8_1(const block_q8_1 * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); + +LM_GGML_API void dequantize_row_q2_K(const block_q2_K * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); +LM_GGML_API void dequantize_row_q3_K(const block_q3_K * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); +LM_GGML_API void dequantize_row_q4_K(const block_q4_K * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); +LM_GGML_API void dequantize_row_q5_K(const block_q5_K * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); +LM_GGML_API void dequantize_row_q6_K(const block_q6_K * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); +LM_GGML_API void dequantize_row_q8_K(const block_q8_K * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); + +LM_GGML_API void dequantize_row_tq1_0(const block_tq1_0 * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); +LM_GGML_API void dequantize_row_tq2_0(const block_tq2_0 * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); + +LM_GGML_API void dequantize_row_iq2_xxs(const block_iq2_xxs * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); +LM_GGML_API void dequantize_row_iq2_xs (const block_iq2_xs * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); +LM_GGML_API void dequantize_row_iq2_s (const block_iq2_s * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); +LM_GGML_API void dequantize_row_iq3_xxs(const block_iq3_xxs * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); +LM_GGML_API void dequantize_row_iq1_s (const block_iq1_s * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); +LM_GGML_API void dequantize_row_iq1_m (const block_iq1_m * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); +LM_GGML_API void dequantize_row_iq4_nl (const block_iq4_nl * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); +LM_GGML_API void dequantize_row_iq4_xs (const block_iq4_xs * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); +LM_GGML_API void dequantize_row_iq3_s (const block_iq3_s * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization") -size_t quantize_iq2_xxs(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_iq2_xs (const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_iq2_s (const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_iq3_xxs(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_iq1_s (const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_iq1_m (const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_iq4_nl (const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_iq4_xs (const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_iq3_s (const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); - -size_t quantize_tq1_0(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_tq2_0(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); - -size_t quantize_q2_K(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_q3_K(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_q4_K(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_q5_K(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_q6_K(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_q4_0(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_q4_1(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_q5_0(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_q5_1(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_q8_0(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); - -void lm_iq2xs_init_impl(enum lm_ggml_type type); -void lm_iq2xs_free_impl(enum lm_ggml_type type); -void lm_iq3xs_init_impl(int grid_size); -void lm_iq3xs_free_impl(int grid_size); +LM_GGML_API size_t quantize_iq2_xxs(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +LM_GGML_API size_t quantize_iq2_xs (const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +LM_GGML_API size_t quantize_iq2_s (const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +LM_GGML_API size_t quantize_iq3_xxs(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +LM_GGML_API size_t quantize_iq1_s (const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +LM_GGML_API size_t quantize_iq1_m (const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +LM_GGML_API size_t quantize_iq4_nl (const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +LM_GGML_API size_t quantize_iq4_xs (const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +LM_GGML_API size_t quantize_iq3_s (const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); + +LM_GGML_API size_t quantize_tq1_0(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +LM_GGML_API size_t quantize_tq2_0(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); + +LM_GGML_API size_t quantize_q2_K(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +LM_GGML_API size_t quantize_q3_K(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +LM_GGML_API size_t quantize_q4_K(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +LM_GGML_API size_t quantize_q5_K(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +LM_GGML_API size_t quantize_q6_K(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +LM_GGML_API size_t quantize_q4_0(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +LM_GGML_API size_t quantize_q4_1(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +LM_GGML_API size_t quantize_q5_0(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +LM_GGML_API size_t quantize_q5_1(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +LM_GGML_API size_t quantize_q8_0(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); + +LM_GGML_API void lm_iq2xs_init_impl(enum lm_ggml_type type); +LM_GGML_API void lm_iq2xs_free_impl(enum lm_ggml_type type); +LM_GGML_API void lm_iq3xs_init_impl(int grid_size); +LM_GGML_API void lm_iq3xs_free_impl(int grid_size); #ifdef __cplusplus } diff --git a/cpp/ggml-threading.cpp b/cpp/ggml-threading.cpp new file mode 100644 index 00000000..b8078778 --- /dev/null +++ b/cpp/ggml-threading.cpp @@ -0,0 +1,12 @@ +#include "ggml-threading.h" +#include + +std::mutex lm_ggml_critical_section_mutex; + +void lm_ggml_critical_section_start() { + lm_ggml_critical_section_mutex.lock(); +} + +void lm_ggml_critical_section_end(void) { + lm_ggml_critical_section_mutex.unlock(); +} diff --git a/cpp/ggml-threading.h b/cpp/ggml-threading.h new file mode 100644 index 00000000..d453d269 --- /dev/null +++ b/cpp/ggml-threading.h @@ -0,0 +1,12 @@ +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +void lm_ggml_critical_section_start(void); +void lm_ggml_critical_section_end(void); + +#ifdef __cplusplus +} +#endif diff --git a/cpp/ggml.c b/cpp/ggml.c index 1f2a84bd..876dd87e 100644 --- a/cpp/ggml.c +++ b/cpp/ggml.c @@ -1,11 +1,13 @@ -#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows +#define _CRT_SECURE_NO_DEPRECATE // Disables "unsafe" warnings on Windows #define _USE_MATH_DEFINES // For M_PI on MSVC #include "ggml-backend.h" #include "ggml-impl.h" -#include "ggml-cpu-impl.h" -#include "ggml-quants.h" +#include "ggml-threading.h" #include "ggml.h" + +// FIXME: required here for quantization functions +#include "ggml-quants.h" #include "ggml-aarch64.h" #if defined(_MSC_VER) || defined(__MINGW32__) @@ -31,168 +33,38 @@ #include #endif -#ifdef LM_GGML_USE_OPENMP -#include -#endif - -#if defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_MATMUL_INT8) -#undef LM_GGML_USE_LLAMAFILE -#endif - -#ifdef LM_GGML_USE_LLAMAFILE -#include -#endif - -#if defined(_MSC_VER) -// disable "possible loss of data" to avoid hundreds of casts -// we should just be careful :) -#pragma warning(disable: 4244 4267) - -// disable POSIX deprecation warnings -// these functions are never going away, anyway -#pragma warning(disable: 4996) - -// unreachable code because of multiple instances of code after LM_GGML_ABORT -#pragma warning(disable: 4702) -#endif - -// Note: once we move threading into a separate C++ file -// will use std::hardware_destructive_interference_size instead of hardcoding it here -// and we'll use C++ attribute syntax. -#define LM_GGML_CACHE_LINE 64 - -#if defined(__clang__) || defined(__GNUC__) -#define LM_GGML_CACHE_ALIGN __attribute__((aligned(LM_GGML_CACHE_LINE))) -#endif - -#if defined(__has_feature) -#if __has_feature(thread_sanitizer) -#define LM_GGML_TSAN_ENABLED 1 -#endif -#else // __has_feature -#if defined(__SANITIZE_THREAD__) -#define LM_GGML_TSAN_ENABLED 1 +#if defined(__APPLE__) +#include +#include +#include #endif -#endif // __has_feature #if defined(_WIN32) - #define WIN32_LEAN_AND_MEAN #ifndef NOMINMAX #define NOMINMAX #endif #include - -#if !defined(__clang__) -#define LM_GGML_CACHE_ALIGN __declspec(align(LM_GGML_CACHE_LINE)) - -typedef volatile LONG atomic_int; -typedef atomic_int atomic_bool; -typedef atomic_int atomic_flag; - -#define ATOMIC_FLAG_INIT 0 - -typedef enum { - memory_order_relaxed, - memory_order_consume, - memory_order_acquire, - memory_order_release, - memory_order_acq_rel, - memory_order_seq_cst -} memory_order; - -static void atomic_store(atomic_int * ptr, LONG val) { - InterlockedExchange(ptr, val); -} -static void atomic_store_explicit(atomic_int * ptr, LONG val, memory_order mo) { - // TODO: add support for explicit memory order - InterlockedExchange(ptr, val); -} -static LONG atomic_load(atomic_int * ptr) { - return InterlockedCompareExchange(ptr, 0, 0); -} -static LONG atomic_load_explicit(atomic_int * ptr, memory_order mo) { - // TODO: add support for explicit memory order - return InterlockedCompareExchange(ptr, 0, 0); -} -static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) { - return InterlockedExchangeAdd(ptr, inc); -} -static LONG atomic_fetch_add_explicit(atomic_int * ptr, LONG inc, memory_order mo) { - // TODO: add support for explicit memory order - return InterlockedExchangeAdd(ptr, inc); -} -static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) { - return InterlockedExchange(ptr, 1); -} -static void atomic_flag_clear(atomic_flag * ptr) { - InterlockedExchange(ptr, 0); -} -static void atomic_thread_fence(memory_order mo) { - MemoryBarrier(); -} -#else // clang -#include #endif -typedef HANDLE pthread_t; - -typedef DWORD thread_ret_t; -static int pthread_create(pthread_t * out, void * unused, thread_ret_t(*func)(void *), void * arg) { - (void) unused; - HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL); - if (handle == NULL) - { - return EAGAIN; - } - - *out = handle; - return 0; -} - -static int pthread_join(pthread_t thread, void * unused) { - (void) unused; - int ret = (int) WaitForSingleObject(thread, INFINITE); - CloseHandle(thread); - return ret; -} +#define UNUSED LM_GGML_UNUSED -static int sched_yield (void) { - Sleep (0); - return 0; -} +#if defined(_MSC_VER) +#define m512bh(p) p +#define m512i(p) p #else - -#include -#include -#include -#if defined(__FreeBSD__) -#include +#define m512bh(p) (__m512bh)(p) +#define m512i(p) (__m512i)(p) #endif -typedef void * thread_ret_t; - -#include -#include -#include - -#endif - -typedef pthread_t lm_ggml_thread_t; - -#ifdef LM_GGML_USE_CPU_HBM -#include -#endif - -#if defined(__APPLE__) -#include -#include -#include -#endif +// precomputed f32 table for f16 (256 KB) (ggml-impl.h) +float lm_ggml_table_f32_f16[1 << 16]; #if (defined(__linux__) || defined(__APPLE__) || defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__)) && \ (!defined(TARGET_OS_TV) && !defined(TARGET_OS_WATCH)) - +#include +#include +#include #include #if defined(__ANDROID__) @@ -305,15 +177,6 @@ void lm_ggml_abort(const char * file, int line, const char * fmt, ...) { abort(); } -#define LM_GGML_DEBUG 0 - -#define LM_GGML_GELU_FP16 -#define LM_GGML_GELU_QUICK_FP16 - -#define LM_GGML_SOFT_MAX_UNROLL 4 -#define LM_GGML_VEC_DOT_UNROLL 2 -#define LM_GGML_VEC_MAD_UNROLL 32 - // // logging // @@ -358,24 +221,6 @@ void lm_ggml_log_callback_default(enum lm_ggml_log_level level, const char * tex fflush(stderr); } -#if (LM_GGML_DEBUG >= 1) -#define LM_GGML_PRINT_DEBUG(...) LM_GGML_LOG_DEBUG(__VA_ARGS__) -#else -#define LM_GGML_PRINT_DEBUG(...) -#endif - -#if (LM_GGML_DEBUG >= 5) -#define LM_GGML_PRINT_DEBUG_5(...) LM_GGML_LOG_DEBUG(__VA_ARGS__) -#else -#define LM_GGML_PRINT_DEBUG_5(...) -#endif - -#if (LM_GGML_DEBUG >= 10) -#define LM_GGML_PRINT_DEBUG_10(...) LM_GGML_LOG_DEBUG(__VA_ARGS__) -#else -#define LM_GGML_PRINT_DEBUG_10(...) -#endif - // // end of logging block // @@ -388,17 +233,20 @@ void lm_ggml_log_callback_default(enum lm_ggml_log_level level, const char * tex void * lm_ggml_aligned_malloc(size_t size) { + const int alignment = 64; + #if defined(_MSC_VER) || defined(__MINGW32__) - return _aligned_malloc(size, TENSOR_ALIGNMENT); + return _aligned_malloc(size, alignment); #else if (size == 0) { LM_GGML_LOG_WARN("Behavior may be unexpected when allocating 0 bytes for lm_ggml_aligned_malloc!\n"); return NULL; } void * aligned_memory = NULL; -#ifdef LM_GGML_USE_CPU_HBM - int result = hbw_posix_memalign(&aligned_memory, TENSOR_ALIGNMENT, size); -#elif TARGET_OS_OSX + #ifdef LM_GGML_USE_CPU_HBM + int result = hbw_posix_memalign(&aligned_memory, alignment, size); + #elif TARGET_OS_OSX + LM_GGML_UNUSED(alignment); kern_return_t alloc_status = vm_allocate((vm_map_t) mach_task_self(), (vm_address_t *) &aligned_memory, size, VM_FLAGS_ANYWHERE); int result = EFAULT; switch (alloc_status) { @@ -415,12 +263,9 @@ void * lm_ggml_aligned_malloc(size_t size) { result = EFAULT; break; } -#elif LM_GGML_USE_METAL - const long page_size = sysconf(_SC_PAGESIZE); - int result = posix_memalign(&aligned_memory, MAX(TENSOR_ALIGNMENT, page_size), size); -#else - int result = posix_memalign(&aligned_memory, TENSOR_ALIGNMENT, size); -#endif + #else + int result = posix_memalign(&aligned_memory, alignment, size); + #endif if (result != 0) { // Handle allocation failure const char *error_desc = "unknown allocation error"; @@ -433,7 +278,6 @@ void * lm_ggml_aligned_malloc(size_t size) { break; } LM_GGML_LOG_ERROR("%s: %s (attempted to allocate %6.2f MB)\n", __func__, error_desc, size/(1024.0*1024.0)); - LM_GGML_ABORT("fatal error"); return NULL; } return aligned_memory; @@ -490,44 +334,6 @@ inline static void * lm_ggml_calloc(size_t num, size_t size) { #define LM_GGML_FREE(ptr) free(ptr) -#define UNUSED LM_GGML_UNUSED -#define SWAP(x, y, T) do { T SWAP = x; (x) = y; (y) = SWAP; } while (0) - -#if defined(LM_GGML_USE_ACCELERATE) -#include -#endif - -// floating point type used to accumulate sums -typedef double lm_ggml_float; - -#undef MIN -#undef MAX - -#define MIN(a, b) ((a) < (b) ? (a) : (b)) -#define MAX(a, b) ((a) > (b) ? (a) : (b)) - -// -// global data -// - -// precomputed gelu table for f16 (128 KB) -static lm_ggml_fp16_t lm_ggml_table_gelu_f16[1 << 16]; - -// precomputed quick gelu table for f16 (128 KB) -static lm_ggml_fp16_t lm_ggml_table_gelu_quick_f16[1 << 16]; - -// precomputed f32 table for f16 (256 KB) (ggml-impl.h) -float lm_ggml_table_f32_f16[1 << 16]; - -#if defined(__ARM_ARCH) -struct lm_ggml_arm_arch_features_type { - int has_neon; - int has_i8mm; - int has_sve; - int sve_cnt; -} lm_ggml_arm_arch_features = {-1, -1, -1, 0}; -#endif - const char * lm_ggml_status_to_string(enum lm_ggml_status status) { switch (status) { case LM_GGML_STATUS_ALLOC_FAILED: return "GGML status: error (failed to allocate memory)"; @@ -565,19 +371,23 @@ void lm_ggml_fp16_to_fp32_row(const lm_ggml_fp16_t * x, float * y, int64_t n) { } } +// FIXME: these functions must detect the instruction set at runtime, since they are part of the core ggml library +// currently, the lm_ggml_cpu_has_* functions are entirely compile-time void lm_ggml_fp32_to_fp16_row(const float * x, lm_ggml_fp16_t * y, int64_t n) { int64_t i = 0; #if defined(__F16C__) - for (; i + 7 < n; i += 8) { - __m256 x_vec = _mm256_loadu_ps(x + i); - __m128i y_vec = _mm256_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT); - _mm_storeu_si128((__m128i *)(y + i), y_vec); - } - for(; i + 3 < n; i += 4) { - __m128 x_vec = _mm_loadu_ps(x + i); - __m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT); - _mm_storel_epi64((__m128i *)(y + i), y_vec); - } + //if (lm_ggml_cpu_has_f16c()) { + for (; i + 7 < n; i += 8) { + __m256 x_vec = _mm256_loadu_ps(x + i); + __m128i y_vec = _mm256_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT); + _mm_storeu_si128((__m128i *)(y + i), y_vec); + } + for(; i + 3 < n; i += 4) { + __m128 x_vec = _mm_loadu_ps(x + i); + __m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT); + _mm_storel_epi64((__m128i *)(y + i), y_vec); + } + //} #endif for (; i < n; i++) { y[i] = LM_GGML_FP32_TO_FP16(x[i]); @@ -587,25 +397,30 @@ void lm_ggml_fp32_to_fp16_row(const float * x, lm_ggml_fp16_t * y, int64_t n) { void lm_ggml_bf16_to_fp32_row(const lm_ggml_bf16_t * x, float * y, int64_t n) { int64_t i = 0; #if defined(__AVX512F__) - for (; i + 16 <= n; i += 16) { - _mm512_storeu_ps(y + i, - _mm512_castsi512_ps( - _mm512_slli_epi32( - _mm512_cvtepu16_epi32( - _mm256_loadu_si256( - (const __m256i *)(x + i))), - 16))); - } -#elif defined(__AVX2__) - for (; i + 8 <= n; i += 8) { - _mm256_storeu_ps(y + i, - _mm256_castsi256_ps( - _mm256_slli_epi32( - _mm256_cvtepu16_epi32( - _mm_loadu_si128( - (const __m128i *)(x + i))), - 16))); - } + //if (lm_ggml_cpu_has_avx512()) { + for (; i + 16 <= n; i += 16) { + _mm512_storeu_ps(y + i, + _mm512_castsi512_ps( + _mm512_slli_epi32( + _mm512_cvtepu16_epi32( + _mm256_loadu_si256( + (const __m256i *)(x + i))), + 16))); + } + //} +#endif +#if defined(__AVX2__) + //if (lm_ggml_cpu_has_avx2()) { + for (; i + 8 <= n; i += 8) { + _mm256_storeu_ps(y + i, + _mm256_castsi256_ps( + _mm256_slli_epi32( + _mm256_cvtepu16_epi32( + _mm_loadu_si128( + (const __m128i *)(x + i))), + 16))); + } + //} #endif for (; i < n; i++) { y[i] = LM_GGML_BF16_TO_FP32(x[i]); @@ -737,24 +552,8 @@ FILE * lm_ggml_fopen(const char * fname, const char * mode) { #else return fopen(fname, mode); #endif -} - -// -// cache line -// - -#if defined(__cpp_lib_hardware_interference_size) -#define CACHE_LINE_SIZE hardware_destructive_interference_size -#else -#if defined(__POWER9_VECTOR__) -#define CACHE_LINE_SIZE 128 -#else -#define CACHE_LINE_SIZE 64 -#endif -#endif - -static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); +} static void lm_ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc); static void lm_ggml_vec_dot_f16(int n, float * restrict s, size_t bs, lm_ggml_fp16_t * restrict x, size_t bx, lm_ggml_fp16_t * restrict y, size_t by, int nrc); static void lm_ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, lm_ggml_bf16_t * restrict x, size_t bx, lm_ggml_bf16_t * restrict y, size_t by, int nrc); @@ -789,16 +588,12 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = { .blck_size = 1, .type_size = sizeof(double), .is_quantized = false, - .nrows = 1, }, [LM_GGML_TYPE_F32] = { .type_name = "f32", .blck_size = 1, .type_size = sizeof(float), .is_quantized = false, - .vec_dot = (lm_ggml_vec_dot_t) lm_ggml_vec_dot_f32, - .vec_dot_type = LM_GGML_TYPE_F32, - .nrows = 1, }, [LM_GGML_TYPE_F16] = { .type_name = "f16", @@ -806,11 +601,7 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = { .type_size = sizeof(lm_ggml_fp16_t), .is_quantized = false, .to_float = (lm_ggml_to_float_t) lm_ggml_fp16_to_fp32_row, - .from_float = (lm_ggml_from_float_t) lm_ggml_fp32_to_fp16_row, .from_float_ref = (lm_ggml_from_float_t) lm_ggml_fp32_to_fp16_row, - .vec_dot = (lm_ggml_vec_dot_t) lm_ggml_vec_dot_f16, - .vec_dot_type = LM_GGML_TYPE_F16, - .nrows = 1, }, [LM_GGML_TYPE_Q4_0] = { .type_name = "q4_0", @@ -818,15 +609,7 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = { .type_size = sizeof(block_q4_0), .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_q4_0, - .from_float = quantize_row_q4_0, .from_float_ref = (lm_ggml_from_float_t) quantize_row_q4_0_ref, - .vec_dot = lm_ggml_vec_dot_q4_0_q8_0, - .vec_dot_type = LM_GGML_TYPE_Q8_0, -#if defined (__ARM_FEATURE_MATMUL_INT8) - .nrows = 2, -#else - .nrows = 1, -#endif }, [LM_GGML_TYPE_Q4_1] = { .type_name = "q4_1", @@ -834,39 +617,19 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = { .type_size = sizeof(block_q4_1), .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_q4_1, - .from_float = quantize_row_q4_1, .from_float_ref = (lm_ggml_from_float_t) quantize_row_q4_1_ref, - .vec_dot = lm_ggml_vec_dot_q4_1_q8_1, - .vec_dot_type = LM_GGML_TYPE_Q8_1, -#if defined (__ARM_FEATURE_MATMUL_INT8) - .nrows = 2, -#else - .nrows = 1, -#endif }, [4] = { // LM_GGML_TYPE_Q4_2 .type_name = "DEPRECATED", .blck_size = 0, .type_size = 0, .is_quantized = false, - .to_float = NULL, - .from_float = NULL, - .from_float_ref = NULL, - .vec_dot = NULL, - .vec_dot_type = LM_GGML_TYPE_COUNT, - .nrows = 1, }, [5] = { // LM_GGML_TYPE_Q4_3 .type_name = "DEPRECATED", .blck_size = 0, .type_size = 0, .is_quantized = false, - .to_float = NULL, - .from_float = NULL, - .from_float_ref = NULL, - .vec_dot = NULL, - .vec_dot_type = LM_GGML_TYPE_COUNT, - .nrows = 1, }, [LM_GGML_TYPE_Q5_0] = { .type_name = "q5_0", @@ -874,11 +637,7 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = { .type_size = sizeof(block_q5_0), .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_q5_0, - .from_float = quantize_row_q5_0, .from_float_ref = (lm_ggml_from_float_t) quantize_row_q5_0_ref, - .vec_dot = lm_ggml_vec_dot_q5_0_q8_0, - .vec_dot_type = LM_GGML_TYPE_Q8_0, - .nrows = 1, }, [LM_GGML_TYPE_Q5_1] = { .type_name = "q5_1", @@ -886,11 +645,7 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = { .type_size = sizeof(block_q5_1), .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_q5_1, - .from_float = quantize_row_q5_1, .from_float_ref = (lm_ggml_from_float_t) quantize_row_q5_1_ref, - .vec_dot = lm_ggml_vec_dot_q5_1_q8_1, - .vec_dot_type = LM_GGML_TYPE_Q8_1, - .nrows = 1, }, [LM_GGML_TYPE_Q8_0] = { .type_name = "q8_0", @@ -898,26 +653,14 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = { .type_size = sizeof(block_q8_0), .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_q8_0, - .from_float = quantize_row_q8_0, .from_float_ref = (lm_ggml_from_float_t) quantize_row_q8_0_ref, - .from_float_to_mat = quantize_mat_q8_0, - .vec_dot = lm_ggml_vec_dot_q8_0_q8_0, - .vec_dot_type = LM_GGML_TYPE_Q8_0, -#if defined (__ARM_FEATURE_MATMUL_INT8) - .nrows = 2, -#else - .nrows = 1, -#endif }, [LM_GGML_TYPE_Q8_1] = { .type_name = "q8_1", .blck_size = QK8_1, .type_size = sizeof(block_q8_1), .is_quantized = true, - .from_float = quantize_row_q8_1, .from_float_ref = (lm_ggml_from_float_t) quantize_row_q8_1_ref, - .vec_dot_type = LM_GGML_TYPE_Q8_1, - .nrows = 1, }, [LM_GGML_TYPE_Q2_K] = { .type_name = "q2_K", @@ -925,11 +668,7 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = { .type_size = sizeof(block_q2_K), .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_q2_K, - .from_float = quantize_row_q2_K, .from_float_ref = (lm_ggml_from_float_t) quantize_row_q2_K_ref, - .vec_dot = lm_ggml_vec_dot_q2_K_q8_K, - .vec_dot_type = LM_GGML_TYPE_Q8_K, - .nrows = 1, }, [LM_GGML_TYPE_Q3_K] = { .type_name = "q3_K", @@ -937,11 +676,7 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = { .type_size = sizeof(block_q3_K), .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_q3_K, - .from_float = quantize_row_q3_K, .from_float_ref = (lm_ggml_from_float_t) quantize_row_q3_K_ref, - .vec_dot = lm_ggml_vec_dot_q3_K_q8_K, - .vec_dot_type = LM_GGML_TYPE_Q8_K, - .nrows = 1, }, [LM_GGML_TYPE_Q4_K] = { .type_name = "q4_K", @@ -949,11 +684,7 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = { .type_size = sizeof(block_q4_K), .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_q4_K, - .from_float = quantize_row_q4_K, .from_float_ref = (lm_ggml_from_float_t) quantize_row_q4_K_ref, - .vec_dot = lm_ggml_vec_dot_q4_K_q8_K, - .vec_dot_type = LM_GGML_TYPE_Q8_K, - .nrows = 1, }, [LM_GGML_TYPE_Q5_K] = { .type_name = "q5_K", @@ -961,11 +692,7 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = { .type_size = sizeof(block_q5_K), .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_q5_K, - .from_float = quantize_row_q5_K, .from_float_ref = (lm_ggml_from_float_t) quantize_row_q5_K_ref, - .vec_dot = lm_ggml_vec_dot_q5_K_q8_K, - .vec_dot_type = LM_GGML_TYPE_Q8_K, - .nrows = 1, }, [LM_GGML_TYPE_Q6_K] = { .type_name = "q6_K", @@ -973,11 +700,7 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = { .type_size = sizeof(block_q6_K), .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_q6_K, - .from_float = quantize_row_q6_K, .from_float_ref = (lm_ggml_from_float_t) quantize_row_q6_K_ref, - .vec_dot = lm_ggml_vec_dot_q6_K_q8_K, - .vec_dot_type = LM_GGML_TYPE_Q8_K, - .nrows = 1, }, [LM_GGML_TYPE_IQ2_XXS] = { .type_name = "iq2_xxs", @@ -985,11 +708,7 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = { .type_size = sizeof(block_iq2_xxs), .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_iq2_xxs, - .from_float = NULL, .from_float_ref = NULL, - .vec_dot = lm_ggml_vec_dot_iq2_xxs_q8_K, - .vec_dot_type = LM_GGML_TYPE_Q8_K, - .nrows = 1, }, [LM_GGML_TYPE_IQ2_XS] = { .type_name = "iq2_xs", @@ -997,11 +716,7 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = { .type_size = sizeof(block_iq2_xs), .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_iq2_xs, - .from_float = NULL, .from_float_ref = NULL, - .vec_dot = lm_ggml_vec_dot_iq2_xs_q8_K, - .vec_dot_type = LM_GGML_TYPE_Q8_K, - .nrows = 1, }, [LM_GGML_TYPE_IQ3_XXS] = { .type_name = "iq3_xxs", @@ -1009,11 +724,7 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = { .type_size = sizeof(block_iq3_xxs), .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_iq3_xxs, - .from_float = quantize_row_iq3_xxs, .from_float_ref = (lm_ggml_from_float_t)quantize_row_iq3_xxs_ref, - .vec_dot = lm_ggml_vec_dot_iq3_xxs_q8_K, - .vec_dot_type = LM_GGML_TYPE_Q8_K, - .nrows = 1, }, [LM_GGML_TYPE_IQ3_S] = { .type_name = "iq3_s", @@ -1021,11 +732,7 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = { .type_size = sizeof(block_iq3_s), .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_iq3_s, - .from_float = quantize_row_iq3_s, .from_float_ref = (lm_ggml_from_float_t)quantize_row_iq3_s_ref, - .vec_dot = lm_ggml_vec_dot_iq3_s_q8_K, - .vec_dot_type = LM_GGML_TYPE_Q8_K, - .nrows = 1, }, [LM_GGML_TYPE_IQ2_S] = { .type_name = "iq2_s", @@ -1033,11 +740,7 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = { .type_size = sizeof(block_iq2_s), .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_iq2_s, - .from_float = quantize_row_iq2_s, .from_float_ref = (lm_ggml_from_float_t)quantize_row_iq2_s_ref, - .vec_dot = lm_ggml_vec_dot_iq2_s_q8_K, - .vec_dot_type = LM_GGML_TYPE_Q8_K, - .nrows = 1, }, [LM_GGML_TYPE_IQ1_S] = { .type_name = "iq1_s", @@ -1045,11 +748,7 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = { .type_size = sizeof(block_iq1_s), .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_iq1_s, - .from_float = NULL, .from_float_ref = NULL, - .vec_dot = lm_ggml_vec_dot_iq1_s_q8_K, - .vec_dot_type = LM_GGML_TYPE_Q8_K, - .nrows = 1, }, [LM_GGML_TYPE_IQ1_M] = { .type_name = "iq1_m", @@ -1057,11 +756,7 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = { .type_size = sizeof(block_iq1_m), .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_iq1_m, - .from_float = NULL, .from_float_ref = NULL, - .vec_dot = lm_ggml_vec_dot_iq1_m_q8_K, - .vec_dot_type = LM_GGML_TYPE_Q8_K, - .nrows = 1, }, [LM_GGML_TYPE_IQ4_NL] = { .type_name = "iq4_nl", @@ -1069,11 +764,7 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = { .type_size = sizeof(block_iq4_nl), .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_iq4_nl, - .from_float = quantize_row_iq4_nl, .from_float_ref = (lm_ggml_from_float_t)quantize_row_iq4_nl_ref, - .vec_dot = lm_ggml_vec_dot_iq4_nl_q8_0, - .vec_dot_type = LM_GGML_TYPE_Q8_0, - .nrows = 1, }, [LM_GGML_TYPE_IQ4_XS] = { .type_name = "iq4_xs", @@ -1081,18 +772,13 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = { .type_size = sizeof(block_iq4_xs), .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_iq4_xs, - .from_float = quantize_row_iq4_xs, .from_float_ref = (lm_ggml_from_float_t)quantize_row_iq4_xs_ref, - .vec_dot = lm_ggml_vec_dot_iq4_xs_q8_K, - .vec_dot_type = LM_GGML_TYPE_Q8_K, - .nrows = 1, }, [LM_GGML_TYPE_Q8_K] = { .type_name = "q8_K", .blck_size = QK_K, .type_size = sizeof(block_q8_K), .is_quantized = true, - .from_float = quantize_row_q8_K, }, [LM_GGML_TYPE_BF16] = { .type_name = "bf16", @@ -1100,11 +786,7 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = { .type_size = sizeof(lm_ggml_bf16_t), .is_quantized = false, .to_float = (lm_ggml_to_float_t) lm_ggml_bf16_to_fp32_row, - .from_float = (lm_ggml_from_float_t) lm_ggml_fp32_to_bf16_row, .from_float_ref = (lm_ggml_from_float_t) lm_ggml_fp32_to_bf16_row_ref, - .vec_dot = (lm_ggml_vec_dot_t) lm_ggml_vec_dot_bf16, - .vec_dot_type = LM_GGML_TYPE_BF16, - .nrows = 1, }, [LM_GGML_TYPE_Q4_0_4_4] = { .type_name = "q4_0_4x4", @@ -1113,14 +795,7 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = { .type_size = sizeof(block_q4_0), .is_quantized = true, .to_float = NULL, - .from_float = NULL, .from_float_ref = NULL, - .vec_dot = NULL, - .vec_dot_type = LM_GGML_TYPE_Q8_0, - .nrows = 1, - .ncols = 4, - .gemv = lm_ggml_gemv_q4_0_4x4_q8_0, - .gemm = lm_ggml_gemm_q4_0_4x4_q8_0, }, [LM_GGML_TYPE_Q4_0_4_8] = { .type_name = "q4_0_4x8", @@ -1129,14 +804,7 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = { .type_size = sizeof(block_q4_0), .is_quantized = true, .to_float = NULL, - .from_float = NULL, .from_float_ref = NULL, - .vec_dot = NULL, - .vec_dot_type = LM_GGML_TYPE_Q8_0, - .nrows = 1, - .ncols = 4, - .gemv = lm_ggml_gemv_q4_0_4x8_q8_0, - .gemm = lm_ggml_gemm_q4_0_4x8_q8_0, }, [LM_GGML_TYPE_Q4_0_8_8] = { .type_name = "q4_0_8x8", @@ -1145,14 +813,7 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = { .type_size = sizeof(block_q4_0), .is_quantized = true, .to_float = NULL, - .from_float = NULL, .from_float_ref = NULL, - .vec_dot = NULL, - .vec_dot_type = LM_GGML_TYPE_Q8_0, - .nrows = 1, - .ncols = 8, - .gemv = lm_ggml_gemv_q4_0_8x8_q8_0, - .gemm = lm_ggml_gemm_q4_0_8x8_q8_0, }, [LM_GGML_TYPE_TQ1_0] = { .type_name = "tq1_0", @@ -1160,11 +821,7 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = { .type_size = sizeof(block_tq1_0), .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_tq1_0, - .from_float = quantize_row_tq1_0, .from_float_ref = (lm_ggml_from_float_t) quantize_row_tq1_0_ref, - .vec_dot = lm_ggml_vec_dot_tq1_0_q8_K, - .vec_dot_type = LM_GGML_TYPE_Q8_K, - .nrows = 1, }, [LM_GGML_TYPE_TQ2_0] = { .type_name = "tq2_0", @@ -1172,826 +829,15 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = { .type_size = sizeof(block_tq2_0), .is_quantized = true, .to_float = (lm_ggml_to_float_t) dequantize_row_tq2_0, - .from_float = quantize_row_tq2_0, .from_float_ref = (lm_ggml_from_float_t) quantize_row_tq2_0_ref, - .vec_dot = lm_ggml_vec_dot_tq2_0_q8_K, - .vec_dot_type = LM_GGML_TYPE_Q8_K, - .nrows = 1, }, }; -// For internal test use const struct lm_ggml_type_traits * lm_ggml_get_type_traits(enum lm_ggml_type type) { LM_GGML_ASSERT(type < LM_GGML_TYPE_COUNT); return &type_traits[type]; } -// -// simd mappings -// - -// we define a common set of C macros which map to specific intrinsics based on the current architecture -// we then implement the fundamental computation operations below using only these macros -// adding support for new architectures requires to define the corresponding SIMD macros -// -// LM_GGML_F32_STEP / LM_GGML_F16_STEP -// number of elements to process in a single step -// -// LM_GGML_F32_EPR / LM_GGML_F16_EPR -// number of elements to fit in a single register -// - -#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA) - -#define LM_GGML_SIMD - -// F32 NEON - -#define LM_GGML_F32_STEP 16 -#define LM_GGML_F32_EPR 4 - -#define LM_GGML_F32x4 float32x4_t -#define LM_GGML_F32x4_ZERO vdupq_n_f32(0.0f) -#define LM_GGML_F32x4_SET1(x) vdupq_n_f32(x) -#define LM_GGML_F32x4_LOAD vld1q_f32 -#define LM_GGML_F32x4_STORE vst1q_f32 -#define LM_GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c) -#define LM_GGML_F32x4_ADD vaddq_f32 -#define LM_GGML_F32x4_MUL vmulq_f32 -#define LM_GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x) -#define LM_GGML_F32x4_REDUCE(res, x) \ -{ \ - int offset = LM_GGML_F32_ARR >> 1; \ - for (int i = 0; i < offset; ++i) { \ - (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \ - } \ - (res) = LM_GGML_F32x4_REDUCE_ONE((x)[0]); \ -} - -#define LM_GGML_F32_VEC LM_GGML_F32x4 -#define LM_GGML_F32_VEC_ZERO LM_GGML_F32x4_ZERO -#define LM_GGML_F32_VEC_SET1 LM_GGML_F32x4_SET1 -#define LM_GGML_F32_VEC_LOAD LM_GGML_F32x4_LOAD -#define LM_GGML_F32_VEC_STORE LM_GGML_F32x4_STORE -#define LM_GGML_F32_VEC_FMA LM_GGML_F32x4_FMA -#define LM_GGML_F32_VEC_ADD LM_GGML_F32x4_ADD -#define LM_GGML_F32_VEC_MUL LM_GGML_F32x4_MUL -#define LM_GGML_F32_VEC_REDUCE LM_GGML_F32x4_REDUCE - -// F16 NEON - -#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) - #define LM_GGML_F16_STEP 32 - #define LM_GGML_F16_EPR 8 - - #define LM_GGML_F16x8 float16x8_t - #define LM_GGML_F16x8_ZERO vdupq_n_f16(0.0f) - #define LM_GGML_F16x8_SET1(x) vdupq_n_f16(x) - #define LM_GGML_F16x8_LOAD(x) vld1q_f16((const lm_ggml_fp16_internal_t *)(x)) - #define LM_GGML_F16x8_STORE vst1q_f16 - #define LM_GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c) - #define LM_GGML_F16x8_ADD vaddq_f16 - #define LM_GGML_F16x8_MUL vmulq_f16 - #define LM_GGML_F16x8_REDUCE(res, x) \ - do { \ - int offset = LM_GGML_F16_ARR >> 1; \ - for (int i = 0; i < offset; ++i) { \ - (x)[i] = vaddq_f16((x)[i], (x)[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - (x)[i] = vaddq_f16((x)[i], (x)[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - (x)[i] = vaddq_f16((x)[i], (x)[offset+i]); \ - } \ - const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 ((x)[0])); \ - const float32x4_t t1 = vcvt_f32_f16(vget_high_f16((x)[0])); \ - (res) = (lm_ggml_float) vaddvq_f32(vaddq_f32(t0, t1)); \ - } while (0) - - #define LM_GGML_F16_VEC LM_GGML_F16x8 - #define LM_GGML_F16_VEC_ZERO LM_GGML_F16x8_ZERO - #define LM_GGML_F16_VEC_SET1 LM_GGML_F16x8_SET1 - #define LM_GGML_F16_VEC_LOAD(p, i) LM_GGML_F16x8_LOAD(p) - #define LM_GGML_F16_VEC_STORE(p, r, i) LM_GGML_F16x8_STORE((lm_ggml_fp16_internal_t *)(p), (r)[i]) - #define LM_GGML_F16_VEC_FMA LM_GGML_F16x8_FMA - #define LM_GGML_F16_VEC_ADD LM_GGML_F16x8_ADD - #define LM_GGML_F16_VEC_MUL LM_GGML_F16x8_MUL - #define LM_GGML_F16_VEC_REDUCE LM_GGML_F16x8_REDUCE -#else - // if FP16 vector arithmetic is not supported, we use FP32 instead - // and take advantage of the vcvt_ functions to convert to/from FP16 - - #define LM_GGML_F16_STEP 16 - #define LM_GGML_F16_EPR 4 - - #define LM_GGML_F32Cx4 float32x4_t - #define LM_GGML_F32Cx4_ZERO vdupq_n_f32(0.0f) - #define LM_GGML_F32Cx4_SET1(x) vdupq_n_f32(x) - #define LM_GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16((const lm_ggml_fp16_internal_t *)(x))) - #define LM_GGML_F32Cx4_STORE(x, y) vst1_f16(x, vcvt_f16_f32(y)) - #define LM_GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c) - #define LM_GGML_F32Cx4_ADD vaddq_f32 - #define LM_GGML_F32Cx4_MUL vmulq_f32 - #define LM_GGML_F32Cx4_REDUCE LM_GGML_F32x4_REDUCE - - #define LM_GGML_F16_VEC LM_GGML_F32Cx4 - #define LM_GGML_F16_VEC_ZERO LM_GGML_F32Cx4_ZERO - #define LM_GGML_F16_VEC_SET1 LM_GGML_F32Cx4_SET1 - #define LM_GGML_F16_VEC_LOAD(p, i) LM_GGML_F32Cx4_LOAD(p) - #define LM_GGML_F16_VEC_STORE(p, r, i) LM_GGML_F32Cx4_STORE((lm_ggml_fp16_internal_t *)(p), r[i]) - #define LM_GGML_F16_VEC_FMA LM_GGML_F32Cx4_FMA - #define LM_GGML_F16_VEC_ADD LM_GGML_F32Cx4_ADD - #define LM_GGML_F16_VEC_MUL LM_GGML_F32Cx4_MUL - #define LM_GGML_F16_VEC_REDUCE LM_GGML_F32Cx4_REDUCE -#endif - -#elif defined(__AVX512F__) - -#define LM_GGML_SIMD - -// F32 AVX512 - -#define LM_GGML_F32_STEP 64 -#define LM_GGML_F32_EPR 16 - -#define LM_GGML_F32x16 __m512 -#define LM_GGML_F32x16_ZERO _mm512_setzero_ps() -#define LM_GGML_F32x16_SET1(x) _mm512_set1_ps(x) -#define LM_GGML_F32x16_LOAD _mm512_loadu_ps -#define LM_GGML_F32x16_STORE _mm512_storeu_ps -// _mm512_fmadd_ps is defined in AVX512F so no guard is required -#define LM_GGML_F32x16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a) -#define LM_GGML_F32x16_ADD _mm512_add_ps -#define LM_GGML_F32x16_MUL _mm512_mul_ps -#define LM_GGML_F32x16_REDUCE(res, x) \ -do { \ - int offset = LM_GGML_F32_ARR >> 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = _mm512_add_ps(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = _mm512_add_ps(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = _mm512_add_ps(x[i], x[offset+i]); \ - } \ - res = _mm512_reduce_add_ps(x[0]); \ -} while (0) - -// TODO: is this optimal ? - -#define LM_GGML_F32_VEC LM_GGML_F32x16 -#define LM_GGML_F32_VEC_ZERO LM_GGML_F32x16_ZERO -#define LM_GGML_F32_VEC_SET1 LM_GGML_F32x16_SET1 -#define LM_GGML_F32_VEC_LOAD LM_GGML_F32x16_LOAD -#define LM_GGML_F32_VEC_STORE LM_GGML_F32x16_STORE -#define LM_GGML_F32_VEC_FMA LM_GGML_F32x16_FMA -#define LM_GGML_F32_VEC_ADD LM_GGML_F32x16_ADD -#define LM_GGML_F32_VEC_MUL LM_GGML_F32x16_MUL -#define LM_GGML_F32_VEC_REDUCE LM_GGML_F32x16_REDUCE - -// F16 AVX512 - -// F16 AVX - -#define LM_GGML_F16_STEP 64 -#define LM_GGML_F16_EPR 16 - -// AVX512 has FP16 extension (AVX512_FP16) but I don't have it on my machine so I use FP32 instead - -#define LM_GGML_F32Cx16 __m512 -#define LM_GGML_F32Cx16_ZERO _mm512_setzero_ps() -#define LM_GGML_F32Cx16_SET1(x) _mm512_set1_ps(x) - -// unlike _mm256_cvt intrinsics that require F16C, _mm512_cvt is defined in AVX512F -// so F16C guard isn't required -#define LM_GGML_F32Cx16_LOAD(x) _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x))) -#define LM_GGML_F32Cx16_STORE(x, y) _mm256_storeu_si256((__m256i *)(x), _mm512_cvtps_ph(y, 0)) - -#define LM_GGML_F32Cx16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a) -#define LM_GGML_F32Cx16_ADD _mm512_add_ps -#define LM_GGML_F32Cx16_MUL _mm512_mul_ps -#define LM_GGML_F32Cx16_REDUCE(res, x) \ -do { \ - int offset = LM_GGML_F32_ARR >> 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = _mm512_add_ps(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = _mm512_add_ps(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = _mm512_add_ps(x[i], x[offset+i]); \ - } \ - res = _mm512_reduce_add_ps(x[0]); \ -} while (0) - -#define LM_GGML_F16_VEC LM_GGML_F32Cx16 -#define LM_GGML_F16_VEC_ZERO LM_GGML_F32Cx16_ZERO -#define LM_GGML_F16_VEC_SET1 LM_GGML_F32Cx16_SET1 -#define LM_GGML_F16_VEC_LOAD(p, i) LM_GGML_F32Cx16_LOAD(p) -#define LM_GGML_F16_VEC_STORE(p, r, i) LM_GGML_F32Cx16_STORE(p, r[i]) -#define LM_GGML_F16_VEC_FMA LM_GGML_F32Cx16_FMA -#define LM_GGML_F16_VEC_ADD LM_GGML_F32Cx16_ADD -#define LM_GGML_F16_VEC_MUL LM_GGML_F32Cx16_MUL -#define LM_GGML_F16_VEC_REDUCE LM_GGML_F32Cx16_REDUCE - -#elif defined(__AVX__) - -#define LM_GGML_SIMD - -// F32 AVX - -#define LM_GGML_F32_STEP 32 -#define LM_GGML_F32_EPR 8 - -#define LM_GGML_F32x8 __m256 -#define LM_GGML_F32x8_ZERO _mm256_setzero_ps() -#define LM_GGML_F32x8_SET1(x) _mm256_set1_ps(x) -#define LM_GGML_F32x8_LOAD _mm256_loadu_ps -#define LM_GGML_F32x8_STORE _mm256_storeu_ps -#if defined(__FMA__) - #define LM_GGML_F32x8_FMA(a, b, c) _mm256_fmadd_ps(b, c, a) -#else - #define LM_GGML_F32x8_FMA(a, b, c) _mm256_add_ps(_mm256_mul_ps(b, c), a) -#endif -#define LM_GGML_F32x8_ADD _mm256_add_ps -#define LM_GGML_F32x8_MUL _mm256_mul_ps -#define LM_GGML_F32x8_REDUCE(res, x) \ -do { \ - int offset = LM_GGML_F32_ARR >> 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = _mm256_add_ps(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = _mm256_add_ps(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = _mm256_add_ps(x[i], x[offset+i]); \ - } \ - const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), \ - _mm256_extractf128_ps(x[0], 1)); \ - const __m128 t1 = _mm_hadd_ps(t0, t0); \ - res = (lm_ggml_float) _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); \ -} while (0) -// TODO: is this optimal ? - -#define LM_GGML_F32_VEC LM_GGML_F32x8 -#define LM_GGML_F32_VEC_ZERO LM_GGML_F32x8_ZERO -#define LM_GGML_F32_VEC_SET1 LM_GGML_F32x8_SET1 -#define LM_GGML_F32_VEC_LOAD LM_GGML_F32x8_LOAD -#define LM_GGML_F32_VEC_STORE LM_GGML_F32x8_STORE -#define LM_GGML_F32_VEC_FMA LM_GGML_F32x8_FMA -#define LM_GGML_F32_VEC_ADD LM_GGML_F32x8_ADD -#define LM_GGML_F32_VEC_MUL LM_GGML_F32x8_MUL -#define LM_GGML_F32_VEC_REDUCE LM_GGML_F32x8_REDUCE - -// F16 AVX - -#define LM_GGML_F16_STEP 32 -#define LM_GGML_F16_EPR 8 - -// F16 arithmetic is not supported by AVX, so we use F32 instead - -#define LM_GGML_F32Cx8 __m256 -#define LM_GGML_F32Cx8_ZERO _mm256_setzero_ps() -#define LM_GGML_F32Cx8_SET1(x) _mm256_set1_ps(x) - -#if defined(__F16C__) -// the _mm256_cvt intrinsics require F16C -#define LM_GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x))) -#define LM_GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0)) -#else -static inline __m256 __avx_f32cx8_load(lm_ggml_fp16_t *x) { - float tmp[8]; - - for (int i = 0; i < 8; i++) { - tmp[i] = LM_GGML_FP16_TO_FP32(x[i]); - } - - return _mm256_loadu_ps(tmp); -} -static inline void __avx_f32cx8_store(lm_ggml_fp16_t *x, __m256 y) { - float arr[8]; - - _mm256_storeu_ps(arr, y); - - for (int i = 0; i < 8; i++) - x[i] = LM_GGML_FP32_TO_FP16(arr[i]); -} -#define LM_GGML_F32Cx8_LOAD(x) __avx_f32cx8_load(x) -#define LM_GGML_F32Cx8_STORE(x, y) __avx_f32cx8_store(x, y) -#endif - -#define LM_GGML_F32Cx8_FMA LM_GGML_F32x8_FMA -#define LM_GGML_F32Cx8_ADD _mm256_add_ps -#define LM_GGML_F32Cx8_MUL _mm256_mul_ps -#define LM_GGML_F32Cx8_REDUCE LM_GGML_F32x8_REDUCE - -#define LM_GGML_F16_VEC LM_GGML_F32Cx8 -#define LM_GGML_F16_VEC_ZERO LM_GGML_F32Cx8_ZERO -#define LM_GGML_F16_VEC_SET1 LM_GGML_F32Cx8_SET1 -#define LM_GGML_F16_VEC_LOAD(p, i) LM_GGML_F32Cx8_LOAD(p) -#define LM_GGML_F16_VEC_STORE(p, r, i) LM_GGML_F32Cx8_STORE(p, r[i]) -#define LM_GGML_F16_VEC_FMA LM_GGML_F32Cx8_FMA -#define LM_GGML_F16_VEC_ADD LM_GGML_F32Cx8_ADD -#define LM_GGML_F16_VEC_MUL LM_GGML_F32Cx8_MUL -#define LM_GGML_F16_VEC_REDUCE LM_GGML_F32Cx8_REDUCE - -#elif defined(__POWER9_VECTOR__) - -#define LM_GGML_SIMD - -// F32 POWER9 - -#define LM_GGML_F32_STEP 32 -#define LM_GGML_F32_EPR 4 - -#define LM_GGML_F32x4 vector float -#define LM_GGML_F32x4_ZERO 0.0f -#define LM_GGML_F32x4_SET1 vec_splats -#define LM_GGML_F32x4_LOAD(p) vec_xl(0, p) -#define LM_GGML_F32x4_STORE(p, r) vec_xst(r, 0, p) -#define LM_GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a) -#define LM_GGML_F32x4_ADD vec_add -#define LM_GGML_F32x4_MUL vec_mul -#define LM_GGML_F32x4_REDUCE(res, x) \ -{ \ - int offset = LM_GGML_F32_ARR >> 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = vec_add(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = vec_add(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = vec_add(x[i], x[offset+i]); \ - } \ - res = vec_extract(x[0], 0) + \ - vec_extract(x[0], 1) + \ - vec_extract(x[0], 2) + \ - vec_extract(x[0], 3); \ -} - -#define LM_GGML_F32_VEC LM_GGML_F32x4 -#define LM_GGML_F32_VEC_ZERO LM_GGML_F32x4_ZERO -#define LM_GGML_F32_VEC_SET1 LM_GGML_F32x4_SET1 -#define LM_GGML_F32_VEC_LOAD LM_GGML_F32x4_LOAD -#define LM_GGML_F32_VEC_STORE LM_GGML_F32x4_STORE -#define LM_GGML_F32_VEC_FMA LM_GGML_F32x4_FMA -#define LM_GGML_F32_VEC_ADD LM_GGML_F32x4_ADD -#define LM_GGML_F32_VEC_MUL LM_GGML_F32x4_MUL -#define LM_GGML_F32_VEC_REDUCE LM_GGML_F32x4_REDUCE - -// F16 POWER9 -#define LM_GGML_F16_STEP LM_GGML_F32_STEP -#define LM_GGML_F16_EPR LM_GGML_F32_EPR -#define LM_GGML_F16_VEC LM_GGML_F32x4 -#define LM_GGML_F16_VEC_ZERO LM_GGML_F32x4_ZERO -#define LM_GGML_F16_VEC_SET1 LM_GGML_F32x4_SET1 -#define LM_GGML_F16_VEC_FMA LM_GGML_F32x4_FMA -#define LM_GGML_F16_VEC_ADD LM_GGML_F32x4_ADD -#define LM_GGML_F16_VEC_MUL LM_GGML_F32x4_MUL -#define LM_GGML_F16_VEC_REDUCE LM_GGML_F32x4_REDUCE -// Use vec_xl, not vec_ld, in case the load address is not aligned. -#define LM_GGML_F16_VEC_LOAD(p, i) (i & 0x1) ? \ - vec_extract_fp32_from_shorth(vec_xl(0, p - LM_GGML_F16_EPR)) : \ - vec_extract_fp32_from_shortl(vec_xl(0, p)) -#define LM_GGML_ENDIAN_BYTE(i) ((unsigned char *)&(uint16_t){1})[i] -#define LM_GGML_F16_VEC_STORE(p, r, i) \ - if (i & 0x1) \ - vec_xst(vec_pack_to_short_fp32(r[i - LM_GGML_ENDIAN_BYTE(1)], \ - r[i - LM_GGML_ENDIAN_BYTE(0)]), \ - 0, p - LM_GGML_F16_EPR) - -#elif defined(__wasm_simd128__) - -#define LM_GGML_SIMD - -// F32 WASM - -#define LM_GGML_F32_STEP 16 -#define LM_GGML_F32_EPR 4 - -#define LM_GGML_F32x4 v128_t -#define LM_GGML_F32x4_ZERO wasm_f32x4_splat(0.0f) -#define LM_GGML_F32x4_SET1(x) wasm_f32x4_splat(x) -#define LM_GGML_F32x4_LOAD wasm_v128_load -#define LM_GGML_F32x4_STORE wasm_v128_store -#define LM_GGML_F32x4_FMA(a, b, c) wasm_f32x4_add(wasm_f32x4_mul(b, c), a) -#define LM_GGML_F32x4_ADD wasm_f32x4_add -#define LM_GGML_F32x4_MUL wasm_f32x4_mul -#define LM_GGML_F32x4_REDUCE(res, x) \ -{ \ - int offset = LM_GGML_F32_ARR >> 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ - } \ - res = wasm_f32x4_extract_lane(x[0], 0) + \ - wasm_f32x4_extract_lane(x[0], 1) + \ - wasm_f32x4_extract_lane(x[0], 2) + \ - wasm_f32x4_extract_lane(x[0], 3); \ -} - -#define LM_GGML_F32_VEC LM_GGML_F32x4 -#define LM_GGML_F32_VEC_ZERO LM_GGML_F32x4_ZERO -#define LM_GGML_F32_VEC_SET1 LM_GGML_F32x4_SET1 -#define LM_GGML_F32_VEC_LOAD LM_GGML_F32x4_LOAD -#define LM_GGML_F32_VEC_STORE LM_GGML_F32x4_STORE -#define LM_GGML_F32_VEC_FMA LM_GGML_F32x4_FMA -#define LM_GGML_F32_VEC_ADD LM_GGML_F32x4_ADD -#define LM_GGML_F32_VEC_MUL LM_GGML_F32x4_MUL -#define LM_GGML_F32_VEC_REDUCE LM_GGML_F32x4_REDUCE - -// F16 WASM - -#define LM_GGML_F16_STEP 16 -#define LM_GGML_F16_EPR 4 - -inline static v128_t __wasm_f16x4_load(const lm_ggml_fp16_t * p) { - float tmp[4]; - - tmp[0] = LM_GGML_FP16_TO_FP32(p[0]); - tmp[1] = LM_GGML_FP16_TO_FP32(p[1]); - tmp[2] = LM_GGML_FP16_TO_FP32(p[2]); - tmp[3] = LM_GGML_FP16_TO_FP32(p[3]); - - return wasm_v128_load(tmp); -} - -inline static void __wasm_f16x4_store(lm_ggml_fp16_t * p, v128_t x) { - float tmp[4]; - - wasm_v128_store(tmp, x); - - p[0] = LM_GGML_FP32_TO_FP16(tmp[0]); - p[1] = LM_GGML_FP32_TO_FP16(tmp[1]); - p[2] = LM_GGML_FP32_TO_FP16(tmp[2]); - p[3] = LM_GGML_FP32_TO_FP16(tmp[3]); -} - -#define LM_GGML_F16x4 v128_t -#define LM_GGML_F16x4_ZERO wasm_f32x4_splat(0.0f) -#define LM_GGML_F16x4_SET1(x) wasm_f32x4_splat(x) -#define LM_GGML_F16x4_LOAD(x) __wasm_f16x4_load(x) -#define LM_GGML_F16x4_STORE(x, y) __wasm_f16x4_store(x, y) -#define LM_GGML_F16x4_FMA LM_GGML_F32x4_FMA -#define LM_GGML_F16x4_ADD wasm_f32x4_add -#define LM_GGML_F16x4_MUL wasm_f32x4_mul -#define LM_GGML_F16x4_REDUCE(res, x) \ -{ \ - int offset = LM_GGML_F16_ARR >> 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ - } \ - res = wasm_f32x4_extract_lane(x[0], 0) + \ - wasm_f32x4_extract_lane(x[0], 1) + \ - wasm_f32x4_extract_lane(x[0], 2) + \ - wasm_f32x4_extract_lane(x[0], 3); \ -} - -#define LM_GGML_F16_VEC LM_GGML_F16x4 -#define LM_GGML_F16_VEC_ZERO LM_GGML_F16x4_ZERO -#define LM_GGML_F16_VEC_SET1 LM_GGML_F16x4_SET1 -#define LM_GGML_F16_VEC_LOAD(p, i) LM_GGML_F16x4_LOAD(p) -#define LM_GGML_F16_VEC_STORE(p, r, i) LM_GGML_F16x4_STORE(p, r[i]) -#define LM_GGML_F16_VEC_FMA LM_GGML_F16x4_FMA -#define LM_GGML_F16_VEC_ADD LM_GGML_F16x4_ADD -#define LM_GGML_F16_VEC_MUL LM_GGML_F16x4_MUL -#define LM_GGML_F16_VEC_REDUCE LM_GGML_F16x4_REDUCE - -#elif defined(__SSE3__) - -#define LM_GGML_SIMD - -// F32 SSE - -#define LM_GGML_F32_STEP 32 -#define LM_GGML_F32_EPR 4 - -#define LM_GGML_F32x4 __m128 -#define LM_GGML_F32x4_ZERO _mm_setzero_ps() -#define LM_GGML_F32x4_SET1(x) _mm_set1_ps(x) -#define LM_GGML_F32x4_LOAD _mm_loadu_ps -#define LM_GGML_F32x4_STORE _mm_storeu_ps -#if defined(__FMA__) - // TODO: Does this work? - #define LM_GGML_F32x4_FMA(a, b, c) _mm_fmadd_ps(b, c, a) -#else - #define LM_GGML_F32x4_FMA(a, b, c) _mm_add_ps(_mm_mul_ps(b, c), a) -#endif -#define LM_GGML_F32x4_ADD _mm_add_ps -#define LM_GGML_F32x4_MUL _mm_mul_ps -#define LM_GGML_F32x4_REDUCE(res, x) \ -{ \ - int offset = LM_GGML_F32_ARR >> 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = _mm_add_ps(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = _mm_add_ps(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = _mm_add_ps(x[i], x[offset+i]); \ - } \ - const __m128 t0 = _mm_hadd_ps(x[0], x[0]); \ - res = (lm_ggml_float) _mm_cvtss_f32(_mm_hadd_ps(t0, t0)); \ -} -// TODO: is this optimal ? - -#define LM_GGML_F32_VEC LM_GGML_F32x4 -#define LM_GGML_F32_VEC_ZERO LM_GGML_F32x4_ZERO -#define LM_GGML_F32_VEC_SET1 LM_GGML_F32x4_SET1 -#define LM_GGML_F32_VEC_LOAD LM_GGML_F32x4_LOAD -#define LM_GGML_F32_VEC_STORE LM_GGML_F32x4_STORE -#define LM_GGML_F32_VEC_FMA LM_GGML_F32x4_FMA -#define LM_GGML_F32_VEC_ADD LM_GGML_F32x4_ADD -#define LM_GGML_F32_VEC_MUL LM_GGML_F32x4_MUL -#define LM_GGML_F32_VEC_REDUCE LM_GGML_F32x4_REDUCE - -// F16 SSE - -#define LM_GGML_F16_STEP 32 -#define LM_GGML_F16_EPR 4 - -static inline __m128 __sse_f16x4_load(lm_ggml_fp16_t *x) { - float tmp[4]; - - tmp[0] = LM_GGML_FP16_TO_FP32(x[0]); - tmp[1] = LM_GGML_FP16_TO_FP32(x[1]); - tmp[2] = LM_GGML_FP16_TO_FP32(x[2]); - tmp[3] = LM_GGML_FP16_TO_FP32(x[3]); - - return _mm_loadu_ps(tmp); -} - -static inline void __sse_f16x4_store(lm_ggml_fp16_t *x, __m128 y) { - float arr[4]; - - _mm_storeu_ps(arr, y); - - x[0] = LM_GGML_FP32_TO_FP16(arr[0]); - x[1] = LM_GGML_FP32_TO_FP16(arr[1]); - x[2] = LM_GGML_FP32_TO_FP16(arr[2]); - x[3] = LM_GGML_FP32_TO_FP16(arr[3]); -} - -#define LM_GGML_F32Cx4 __m128 -#define LM_GGML_F32Cx4_ZERO _mm_setzero_ps() -#define LM_GGML_F32Cx4_SET1(x) _mm_set1_ps(x) -#define LM_GGML_F32Cx4_LOAD(x) __sse_f16x4_load(x) -#define LM_GGML_F32Cx4_STORE(x, y) __sse_f16x4_store(x, y) -#define LM_GGML_F32Cx4_FMA LM_GGML_F32x4_FMA -#define LM_GGML_F32Cx4_ADD _mm_add_ps -#define LM_GGML_F32Cx4_MUL _mm_mul_ps -#define LM_GGML_F32Cx4_REDUCE LM_GGML_F32x4_REDUCE - -#define LM_GGML_F16_VEC LM_GGML_F32Cx4 -#define LM_GGML_F16_VEC_ZERO LM_GGML_F32Cx4_ZERO -#define LM_GGML_F16_VEC_SET1 LM_GGML_F32Cx4_SET1 -#define LM_GGML_F16_VEC_LOAD(p, i) LM_GGML_F32Cx4_LOAD(p) -#define LM_GGML_F16_VEC_STORE(p, r, i) LM_GGML_F32Cx4_STORE(p, r[i]) -#define LM_GGML_F16_VEC_FMA LM_GGML_F32Cx4_FMA -#define LM_GGML_F16_VEC_ADD LM_GGML_F32Cx4_ADD -#define LM_GGML_F16_VEC_MUL LM_GGML_F32Cx4_MUL -#define LM_GGML_F16_VEC_REDUCE LM_GGML_F32Cx4_REDUCE - -#elif defined(__loongarch_asx) - -#define LM_GGML_SIMD - -// F32 LASX -#define LM_GGML_F32_STEP 32 -#define LM_GGML_F32_EPR 8 - -#define LM_GGML_F32x8 __m256 -#define LM_GGML_F32x8_ZERO (__m256)__lasx_xvldi(0) -#define LM_GGML_F32x8_SET1(x) (__m256)__lasx_xvreplfr2vr_s((x)) -#define LM_GGML_F32x8_LOAD(x) (__m256)__lasx_xvld((x), 0) -#define LM_GGML_F32x8_STORE(x,y) __lasx_xvst((y), (x), 0) -#define LM_GGML_F32x8_FMA(a, b, c) __lasx_xvfmadd_s(b, c, a) -#define LM_GGML_F32x8_ADD __lasx_xvfadd_s -#define LM_GGML_F32x8_MUL __lasx_xvfmul_s -#define LM_GGML_F32x8_REDUCE(res, x) \ -do { \ - int offset = LM_GGML_F32_ARR >> 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \ - } \ - float *tmp_p = (float *)&x[0]; \ - res = tmp_p[0] + tmp_p[1] + tmp_p[2] + tmp_p[3] + tmp_p[4] + tmp_p[5] + tmp_p[6] + tmp_p[7]; \ -} while (0) -// TODO: is this optimal ? - -#define LM_GGML_F32_VEC LM_GGML_F32x8 -#define LM_GGML_F32_VEC_ZERO LM_GGML_F32x8_ZERO -#define LM_GGML_F32_VEC_SET1 LM_GGML_F32x8_SET1 -#define LM_GGML_F32_VEC_LOAD LM_GGML_F32x8_LOAD -#define LM_GGML_F32_VEC_STORE LM_GGML_F32x8_STORE -#define LM_GGML_F32_VEC_FMA LM_GGML_F32x8_FMA -#define LM_GGML_F32_VEC_ADD LM_GGML_F32x8_ADD -#define LM_GGML_F32_VEC_MUL LM_GGML_F32x8_MUL -#define LM_GGML_F32_VEC_REDUCE LM_GGML_F32x8_REDUCE - -// F16 LASX - -#define LM_GGML_F16_STEP 32 -#define LM_GGML_F16_EPR 8 - -// F16 arithmetic is not supported by AVX, so we use F32 instead - -#define LM_GGML_F32Cx8 __m256 -#define LM_GGML_F32Cx8_ZERO (__m256)__lasx_xvldi(0) -#define LM_GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x)) - -static inline __m256 __lasx_f32cx8_load(const lm_ggml_fp16_t * x) { - float tmp[8]; - - for (int i = 0; i < 8; i++) { - tmp[i] = LM_GGML_FP16_TO_FP32(x[i]); - } - - return (__m256)__lasx_xvld(tmp, 0); -} -static inline void __lasx_f32cx8_store(lm_ggml_fp16_t * x, __m256 y) { - float arr[8]; - - __lasx_xvst(y, arr, 0); - - for (int i = 0; i < 8; i++) { - x[i] = LM_GGML_FP32_TO_FP16(arr[i]); - } -} -#define LM_GGML_F32Cx8_LOAD(x) __lasx_f32cx8_load(x) -#define LM_GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y) - -#define LM_GGML_F32Cx8_FMA LM_GGML_F32x8_FMA -#define LM_GGML_F32Cx8_ADD __lasx_xvfadd_s -#define LM_GGML_F32Cx8_MUL __lasx_xvfmul_s -#define LM_GGML_F32Cx8_REDUCE LM_GGML_F32x8_REDUCE - -#define LM_GGML_F16_VEC LM_GGML_F32Cx8 -#define LM_GGML_F16_VEC_ZERO LM_GGML_F32Cx8_ZERO -#define LM_GGML_F16_VEC_SET1 LM_GGML_F32Cx8_SET1 -#define LM_GGML_F16_VEC_LOAD(p, i) LM_GGML_F32Cx8_LOAD(p) -#define LM_GGML_F16_VEC_STORE(p, r, i) LM_GGML_F32Cx8_STORE(p, r[i]) -#define LM_GGML_F16_VEC_FMA LM_GGML_F32Cx8_FMA -#define LM_GGML_F16_VEC_ADD LM_GGML_F32Cx8_ADD -#define LM_GGML_F16_VEC_MUL LM_GGML_F32Cx8_MUL -#define LM_GGML_F16_VEC_REDUCE LM_GGML_F32Cx8_REDUCE - -#elif defined(__loongarch_sx) - -#define LM_GGML_SIMD - -// F32 LSX - -#define LM_GGML_F32_STEP 32 -#define LM_GGML_F32_EPR 4 - -#define LM_GGML_F32x4 __m128 -#define LM_GGML_F32x4_ZERO __lsx_vldi(0) -#define LM_GGML_F32x4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0) -#define LM_GGML_F32x4_LOAD(x) __lsx_vld((x), 0) -#define LM_GGML_F32x4_STORE((x),(y)) __lsx_vst((y), (x), 0) -#define LM_GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a) -#define LM_GGML_F32x4_ADD __lsx_vfadd_s -#define LM_GGML_F32x4_MUL __lsx_vfmul_s -#define LM_GGML_F32x4_REDUCE(res, x) \ -{ \ - int offset = LM_GGML_F32_ARR >> 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \ - } \ - __m128i tmp = __lsx_vsrli_d((__m128i)x[0], 32); \ - tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, x[0]); \ - tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \ - const __m128 t0 = __lsx_vshuf4i_w(tmp, 0x88); \ - tmp = __lsx_vsrli_d((__m128i)t0, 32); \ - tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, t0); \ - tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \ - res = (lm_ggml_float) __lsx_vpickve2gr_w(__lsx_vshuf4i_w(tmp, 0x88), 0); \ -} - -#define LM_GGML_F32_VEC LM_GGML_F32x4 -#define LM_GGML_F32_VEC_ZERO LM_GGML_F32x4_ZERO -#define LM_GGML_F32_VEC_SET1 LM_GGML_F32x4_SET1 -#define LM_GGML_F32_VEC_LOAD LM_GGML_F32x4_LOAD -#define LM_GGML_F32_VEC_STORE LM_GGML_F32x4_STORE -#define LM_GGML_F32_VEC_FMA LM_GGML_F32x4_FMA -#define LM_GGML_F32_VEC_ADD LM_GGML_F32x4_ADD -#define LM_GGML_F32_VEC_MUL LM_GGML_F32x4_MUL -#define LM_GGML_F32_VEC_REDUCE LM_GGML_F32x4_REDUCE - -// F16 LSX - -#define LM_GGML_F16_STEP 32 -#define LM_GGML_F16_EPR 4 - -static inline __m128 __lsx_f16x4_load(const lm_ggml_fp16_t * x) { - float tmp[4]; - - tmp[0] = LM_GGML_FP16_TO_FP32(x[0]); - tmp[1] = LM_GGML_FP16_TO_FP32(x[1]); - tmp[2] = LM_GGML_FP16_TO_FP32(x[2]); - tmp[3] = LM_GGML_FP16_TO_FP32(x[3]); - - return __lsx_vld(tmp, 0); -} - -static inline void __lsx_f16x4_store(lm_ggml_fp16_t * x, __m128 y) { - float arr[4]; - - __lsx_vst(y, arr, 0); - - x[0] = LM_GGML_FP32_TO_FP16(arr[0]); - x[1] = LM_GGML_FP32_TO_FP16(arr[1]); - x[2] = LM_GGML_FP32_TO_FP16(arr[2]); - x[3] = LM_GGML_FP32_TO_FP16(arr[3]); -} - -#define LM_GGML_F32Cx4 __m128 -#define LM_GGML_F32Cx4_ZERO __lsx_vldi(0) -#define LM_GGML_F32Cx4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0) -#define LM_GGML_F32Cx4_LOAD(x) __lsx_f16x4_load(x) -#define LM_GGML_F32Cx4_STORE(x, y) __lsx_f16x4_store(x, y) -#define LM_GGML_F32Cx4_FMA LM_GGML_F32x4_FMA -#define LM_GGML_F32Cx4_ADD __lsx_vfadd_s -#define LM_GGML_F32Cx4_MUL __lsx_vfmul_s -#define LM_GGML_F32Cx4_REDUCE LM_GGML_F32x4_REDUCE - -#define LM_GGML_F16_VEC LM_GGML_F32Cx4 -#define LM_GGML_F16_VEC_ZERO LM_GGML_F32Cx4_ZERO -#define LM_GGML_F16_VEC_SET1 LM_GGML_F32Cx4_SET1 -#define LM_GGML_F16_VEC_LOAD(p, i) LM_GGML_F32Cx4_LOAD(p) -#define LM_GGML_F16_VEC_STORE(p, r, i) LM_GGML_F32Cx4_STORE(p, r[i]) -#define LM_GGML_F16_VEC_FMA LM_GGML_F32Cx4_FMA -#define LM_GGML_F16_VEC_ADD LM_GGML_F32Cx4_ADD -#define LM_GGML_F16_VEC_MUL LM_GGML_F32Cx4_MUL -#define LM_GGML_F16_VEC_REDUCE LM_GGML_F32Cx4_REDUCE - -#endif - -// LM_GGML_F32_ARR / LM_GGML_F16_ARR -// number of registers to use per step -#ifdef LM_GGML_SIMD -#define LM_GGML_F32_ARR (LM_GGML_F32_STEP/LM_GGML_F32_EPR) -#define LM_GGML_F16_ARR (LM_GGML_F16_STEP/LM_GGML_F16_EPR) -#endif - // // ggml object // @@ -2032,19734 +878,5876 @@ struct lm_ggml_context_container { }; // -// Threading defs +// data types // -typedef pthread_t lm_ggml_thread_t; +static const char * LM_GGML_OP_NAME[LM_GGML_OP_COUNT] = { + "NONE", -#if defined(_WIN32) + "DUP", + "ADD", + "ADD1", + "ACC", + "SUB", + "MUL", + "DIV", + "SQR", + "SQRT", + "LOG", + "SIN", + "COS", + "SUM", + "SUM_ROWS", + "MEAN", + "ARGMAX", + "COUNT_EQUAL", + "REPEAT", + "REPEAT_BACK", + "CONCAT", + "SILU_BACK", + "NORM", + "RMS_NORM", + "RMS_NORM_BACK", + "GROUP_NORM", -typedef CONDITION_VARIABLE lm_ggml_cond_t; -typedef SRWLOCK lm_ggml_mutex_t; + "MUL_MAT", + "MUL_MAT_ID", + "OUT_PROD", -#define lm_ggml_mutex_init(m) InitializeSRWLock(m) -#define lm_ggml_mutex_destroy(m) -#define lm_ggml_mutex_lock(m) AcquireSRWLockExclusive(m) -#define lm_ggml_mutex_unlock(m) ReleaseSRWLockExclusive(m) -#define lm_ggml_mutex_lock_shared(m) AcquireSRWLockShared(m) -#define lm_ggml_mutex_unlock_shared(m) ReleaseSRWLockShared(m) + "SCALE", + "SET", + "CPY", + "CONT", + "RESHAPE", + "VIEW", + "PERMUTE", + "TRANSPOSE", + "GET_ROWS", + "GET_ROWS_BACK", + "DIAG", + "DIAG_MASK_INF", + "DIAG_MASK_ZERO", + "SOFT_MAX", + "SOFT_MAX_BACK", + "ROPE", + "ROPE_BACK", + "CLAMP", + "CONV_TRANSPOSE_1D", + "IM2COL", + "IM2COL_BACK", + "CONV_TRANSPOSE_2D", + "POOL_1D", + "POOL_2D", + "POOL_2D_BACK", + "UPSCALE", + "PAD", + "ARANGE", + "TIMESTEP_EMBEDDING", + "ARGSORT", + "LEAKY_RELU", -#define lm_ggml_cond_init(c) InitializeConditionVariable(c) -#define lm_ggml_cond_destroy(c) -#define lm_ggml_cond_wait(c, m) SleepConditionVariableSRW(c, m, INFINITE, CONDITION_VARIABLE_LOCKMODE_SHARED) -#define lm_ggml_cond_broadcast(c) WakeAllConditionVariable(c) + "FLASH_ATTN_EXT", + "FLASH_ATTN_BACK", + "SSM_CONV", + "SSM_SCAN", + "WIN_PART", + "WIN_UNPART", + "GET_REL_POS", + "ADD_REL_POS", + "RWKV_WKV6", -#define lm_ggml_thread_create pthread_create -#define lm_ggml_thread_join pthread_join + "UNARY", -#else + "MAP_UNARY", + "MAP_BINARY", -typedef pthread_cond_t lm_ggml_cond_t; -typedef pthread_mutex_t lm_ggml_mutex_t; + "MAP_CUSTOM1_F32", + "MAP_CUSTOM2_F32", + "MAP_CUSTOM3_F32", -#define lm_ggml_mutex_init(m) pthread_mutex_init(m, NULL) -#define lm_ggml_mutex_destroy(m) pthread_mutex_destroy(m) -#define lm_ggml_mutex_lock(m) pthread_mutex_lock(m) -#define lm_ggml_mutex_unlock(m) pthread_mutex_unlock(m) -#define lm_ggml_mutex_lock_shared(m) pthread_mutex_lock(m) -#define lm_ggml_mutex_unlock_shared(m) pthread_mutex_unlock(m) + "MAP_CUSTOM1", + "MAP_CUSTOM2", + "MAP_CUSTOM3", -#define lm_ggml_lock_init(x) UNUSED(x) -#define lm_ggml_lock_destroy(x) UNUSED(x) -#if defined(__x86_64__) || (defined(_MSC_VER) && defined(_M_AMD64)) -#define lm_ggml_lock_lock(x) _mm_pause() -#else -#define lm_ggml_lock_lock(x) UNUSED(x) -#endif -#define lm_ggml_lock_unlock(x) UNUSED(x) - -#define LM_GGML_LOCK_INITIALIZER 0 -#define lm_ggml_cond_init(c) pthread_cond_init(c, NULL) -#define lm_ggml_cond_destroy(c) pthread_cond_destroy(c) -#define lm_ggml_cond_wait(c, m) pthread_cond_wait(c, m) -#define lm_ggml_cond_broadcast(c) pthread_cond_broadcast(c) - -#define lm_ggml_thread_create pthread_create -#define lm_ggml_thread_join pthread_join - -#endif + "CROSS_ENTROPY_LOSS", + "CROSS_ENTROPY_LOSS_BACK", + "OPT_STEP_ADAMW", +}; -// Threadpool def -struct lm_ggml_threadpool { - lm_ggml_mutex_t mutex; // mutex for cond.var - lm_ggml_cond_t cond; // cond.var for waiting for new work +static_assert(LM_GGML_OP_COUNT == 81, "LM_GGML_OP_COUNT != 81"); - struct lm_ggml_cgraph * cgraph; - struct lm_ggml_cplan * cplan; +static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = { + "none", - // synchronization primitives - atomic_int n_graph; // incremented when there is work to be done (i.e each graph) - atomic_int LM_GGML_CACHE_ALIGN n_barrier; - atomic_int LM_GGML_CACHE_ALIGN n_barrier_passed; - atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads. + "x", + "x+y", + "x+y", + "view(x,nb,offset)+=y->x", + "x-y", + "x*y", + "x/y", + "x^2", + "√x", + "log(x)", + "sin(x)", + "cos(x)", + "Σx", + "Σx_k", + "Σx/n", + "argmax(x)", + "count_equal(x)", + "repeat(x)", + "repeat_back(x)", + "concat(x, y)", + "silu_back(x)", + "norm(x)", + "rms_norm(x)", + "rms_norm_back(x)", + "group_norm(x)", - // these are atomic as an annotation for thread-sanitizer - atomic_bool stop; // Used for stopping the threadpool altogether - atomic_bool pause; // Used for pausing the threadpool or individual threads - atomic_bool abort; // Used for aborting processing of a graph + "X*Y", + "X[i]*Y", + "X*Y", - struct lm_ggml_compute_state * workers; // per thread state - int n_threads_max; // number of threads in the pool - atomic_int n_threads_cur; // number of threads used in the current graph + "x*v", + "y-\\>view(x)", + "x-\\>y", + "cont(x)", + "reshape(x)", + "view(x)", + "permute(x)", + "transpose(x)", + "get_rows(x)", + "get_rows_back(x)", + "diag(x)", + "diag_mask_inf(x)", + "diag_mask_zero(x)", + "soft_max(x)", + "soft_max_back(x)", + "rope(x)", + "rope_back(x)", + "clamp(x)", + "conv_transpose_1d(x)", + "im2col(x)", + "im2col_back(x)", + "conv_transpose_2d(x)", + "pool_1d(x)", + "pool_2d(x)", + "pool_2d_back(x)", + "upscale(x)", + "pad(x)", + "arange(start, stop, step)", + "timestep_embedding(timesteps, dim, max_period)", + "argsort(x)", + "leaky_relu(x)", - int32_t prio; // Scheduling priority - uint32_t poll; // Polling level (0 - no polling) + "flash_attn_ext(x)", + "flash_attn_back(x)", + "ssm_conv(x)", + "ssm_scan(x)", + "win_part(x)", + "win_unpart(x)", + "get_rel_pos(x)", + "add_rel_pos(x)", + "rwkv_wkv6(k, v, r, tf, td, s)", - enum lm_ggml_status ec; -}; + "unary(x)", -// Per-thread state -struct lm_ggml_compute_state { -#ifndef LM_GGML_USE_OPENMP - lm_ggml_thread_t thrd; - bool cpumask[LM_GGML_MAX_N_THREADS]; - int last_graph; - bool pending; -#endif - struct lm_ggml_threadpool * threadpool; - int ith; -}; + "f(x)", + "f(x,y)", -struct lm_ggml_compute_params { - // ith = thread index, nth = number of threads - int ith, nth; + "custom_f32(x)", + "custom_f32(x,y)", + "custom_f32(x,y,z)", - // work buffer for all threads - size_t wsize; - void * wdata; + "custom(x)", + "custom(x,y)", + "custom(x,y,z)", - struct lm_ggml_threadpool * threadpool; + "cross_entropy_loss(x,y)", + "cross_entropy_loss_back(x,y)", + "adamw(x)", }; -// -// fundamental operations -// - -inline static void lm_ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - -inline static void lm_ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } +static_assert(LM_GGML_OP_COUNT == 81, "LM_GGML_OP_COUNT != 81"); -inline static void lm_ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } +static_assert(LM_GGML_OP_POOL_COUNT == 2, "LM_GGML_OP_POOL_COUNT != 2"); -inline static void lm_ggml_vec_set_f16(const int n, lm_ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } -inline static void lm_ggml_vec_set_bf16(const int n, lm_ggml_bf16_t * x, const lm_ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } +static const char * LM_GGML_UNARY_OP_NAME[LM_GGML_UNARY_OP_COUNT] = { + "ABS", + "SGN", + "NEG", + "STEP", + "TANH", + "ELU", + "RELU", + "SIGMOID", + "GELU", + "GELU_QUICK", + "SILU", + "HARDSWISH", + "HARDSIGMOID", + "EXP", +}; -inline static void lm_ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; } -inline static void lm_ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; } -inline static void lm_ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; } -inline static void lm_ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; } -inline static void lm_ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; } -inline static void lm_ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; } -inline static void lm_ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; } -inline static void lm_ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; } -inline static void lm_ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; } -inline static void lm_ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; } +static_assert(LM_GGML_UNARY_OP_COUNT == 14, "LM_GGML_UNARY_OP_COUNT != 14"); -static void lm_ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); -#if defined(LM_GGML_SIMD) - float sumf = 0.0f; - const int np = (n & ~(LM_GGML_F32_STEP - 1)); +static_assert(sizeof(struct lm_ggml_object)%LM_GGML_MEM_ALIGN == 0, "lm_ggml_object size must be a multiple of LM_GGML_MEM_ALIGN"); +static_assert(sizeof(struct lm_ggml_tensor)%LM_GGML_MEM_ALIGN == 0, "lm_ggml_tensor size must be a multiple of LM_GGML_MEM_ALIGN"); - LM_GGML_F32_VEC sum[LM_GGML_F32_ARR] = { LM_GGML_F32_VEC_ZERO }; - LM_GGML_F32_VEC ax[LM_GGML_F32_ARR]; - LM_GGML_F32_VEC ay[LM_GGML_F32_ARR]; +//////////////////////////////////////////////////////////////////////////////// - for (int i = 0; i < np; i += LM_GGML_F32_STEP) { - for (int j = 0; j < LM_GGML_F32_ARR; j++) { - ax[j] = LM_GGML_F32_VEC_LOAD(x + i + j*LM_GGML_F32_EPR); - ay[j] = LM_GGML_F32_VEC_LOAD(y + i + j*LM_GGML_F32_EPR); +void lm_ggml_print_object(const struct lm_ggml_object * obj) { + LM_GGML_LOG_INFO(" - lm_ggml_object: type = %d, offset = %zu, size = %zu, next = %p\n", + obj->type, obj->offs, obj->size, (const void *) obj->next); +} - sum[j] = LM_GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]); - } - } +void lm_ggml_print_objects(const struct lm_ggml_context * ctx) { + struct lm_ggml_object * obj = ctx->objects_begin; - // reduce sum0..sum3 to sum0 - LM_GGML_F32_VEC_REDUCE(sumf, sum); + LM_GGML_LOG_INFO("%s: objects in context %p:\n", __func__, (const void *) ctx); - // leftovers - for (int i = np; i < n; ++i) { - sumf += x[i]*y[i]; - } -#else - // scalar - lm_ggml_float sumf = 0.0; - for (int i = 0; i < n; ++i) { - sumf += (lm_ggml_float)(x[i]*y[i]); + while (obj != NULL) { + lm_ggml_print_object(obj); + obj = obj->next; } -#endif - *s = sumf; + LM_GGML_LOG_INFO("%s: --- end ---\n", __func__); } -static void lm_ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, lm_ggml_bf16_t * restrict x, size_t bx, lm_ggml_bf16_t * restrict y, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - int i = 0; - lm_ggml_float sumf = 0; - -#if defined(__AVX512BF16__) - __m512 c1 = _mm512_setzero_ps(); - __m512 c2 = _mm512_setzero_ps(); - for (; i + 64 <= n; i += 64) { - c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))), - m512bh(_mm512_loadu_si512((y + i)))); - c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))), - m512bh(_mm512_loadu_si512((y + i + 32)))); - } - sumf += (lm_ggml_float)_mm512_reduce_add_ps(c1); - sumf += (lm_ggml_float)_mm512_reduce_add_ps(c2); - -#elif defined(__AVX512F__) -#define LOAD(p) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(p))), 16)) - __m512 c1 = _mm512_setzero_ps(); - __m512 c2 = _mm512_setzero_ps(); - for (; i + 32 <= n; i += 32) { - c1 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i), LOAD(y + i)), c1); - c2 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c2); - } - sumf += (lm_ggml_float)_mm512_reduce_add_ps(c1); - sumf += (lm_ggml_float)_mm512_reduce_add_ps(c2); - -#undef LOAD -#elif defined(__AVX2__) -#define LOAD(p) _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16)) - __m256 c1 = _mm256_setzero_ps(); - __m256 c2 = _mm256_setzero_ps(); - __m256 c3 = _mm256_setzero_ps(); - __m256 c4 = _mm256_setzero_ps(); - for (; i + 32 <= n; i += 32) { - c1 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i), LOAD(y + i)), c1); - c2 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 8), LOAD(y + i + 8)), c2); - c3 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c3); - c4 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 24), LOAD(y + i + 24)), c4); - } - __m128 g; - c1 = _mm256_add_ps(_mm256_add_ps(c1, c3), - _mm256_add_ps(c2, c4)); - g = _mm_add_ps(_mm256_extractf128_ps(c1, 1), - _mm256_castps256_ps128(c1)); - g = _mm_add_ps(g, _mm_movehl_ps(g, g)); - g = _mm_add_ss(g, _mm_movehdup_ps(g)); - sumf += (lm_ggml_float)_mm_cvtss_f32(g); - -#undef LOAD -#endif +int64_t lm_ggml_nelements(const struct lm_ggml_tensor * tensor) { + static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); - for (; i < n; ++i) { - sumf += (lm_ggml_float)(LM_GGML_BF16_TO_FP32(x[i]) * - LM_GGML_BF16_TO_FP32(y[i])); - } - *s = sumf; + return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3]; } -static void lm_ggml_vec_dot_f16(int n, float * restrict s, size_t bs, lm_ggml_fp16_t * restrict x, size_t bx, lm_ggml_fp16_t * restrict y, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - lm_ggml_float sumf = 0.0; - -#if defined(LM_GGML_SIMD) - const int np = (n & ~(LM_GGML_F16_STEP - 1)); - - LM_GGML_F16_VEC sum[LM_GGML_F16_ARR] = { LM_GGML_F16_VEC_ZERO }; - - LM_GGML_F16_VEC ax[LM_GGML_F16_ARR]; - LM_GGML_F16_VEC ay[LM_GGML_F16_ARR]; +int64_t lm_ggml_nrows(const struct lm_ggml_tensor * tensor) { + static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); - for (int i = 0; i < np; i += LM_GGML_F16_STEP) { - for (int j = 0; j < LM_GGML_F16_ARR; j++) { - ax[j] = LM_GGML_F16_VEC_LOAD(x + i + j*LM_GGML_F16_EPR, j); - ay[j] = LM_GGML_F16_VEC_LOAD(y + i + j*LM_GGML_F16_EPR, j); + return tensor->ne[1]*tensor->ne[2]*tensor->ne[3]; +} - sum[j] = LM_GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]); +size_t lm_ggml_nbytes(const struct lm_ggml_tensor * tensor) { + size_t nbytes; + const size_t blck_size = lm_ggml_blck_size(tensor->type); + if (blck_size == 1) { + nbytes = lm_ggml_type_size(tensor->type); + for (int i = 0; i < LM_GGML_MAX_DIMS; ++i) { + nbytes += (tensor->ne[i] - 1)*tensor->nb[i]; } } - - // reduce sum0..sum3 to sum0 - LM_GGML_F16_VEC_REDUCE(sumf, sum); - - // leftovers - for (int i = np; i < n; ++i) { - sumf += (lm_ggml_float)(LM_GGML_FP16_TO_FP32(x[i])*LM_GGML_FP16_TO_FP32(y[i])); - } -#else - for (int i = 0; i < n; ++i) { - sumf += (lm_ggml_float)(LM_GGML_FP16_TO_FP32(x[i])*LM_GGML_FP16_TO_FP32(y[i])); + else { + nbytes = tensor->ne[0]*tensor->nb[0]/blck_size; + for (int i = 1; i < LM_GGML_MAX_DIMS; ++i) { + nbytes += (tensor->ne[i] - 1)*tensor->nb[i]; + } } -#endif - *s = sumf; + return nbytes; } -// compute LM_GGML_VEC_DOT_UNROLL dot products at once -// xs - x row stride in bytes -inline static void lm_ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, lm_ggml_fp16_t * restrict y) { - lm_ggml_float sumf[LM_GGML_VEC_DOT_UNROLL] = { 0.0 }; - - lm_ggml_fp16_t * restrict x[LM_GGML_VEC_DOT_UNROLL]; +size_t lm_ggml_nbytes_pad(const struct lm_ggml_tensor * tensor) { + return LM_GGML_PAD(lm_ggml_nbytes(tensor), LM_GGML_MEM_ALIGN); +} - for (int i = 0; i < LM_GGML_VEC_DOT_UNROLL; ++i) { - x[i] = (lm_ggml_fp16_t *) ((char *) xv + i*xs); - } +int64_t lm_ggml_blck_size(enum lm_ggml_type type) { + return type_traits[type].blck_size; +} -#if defined(LM_GGML_SIMD) - const int np = (n & ~(LM_GGML_F16_STEP - 1)); +size_t lm_ggml_type_size(enum lm_ggml_type type) { + return type_traits[type].type_size; +} - LM_GGML_F16_VEC sum[LM_GGML_VEC_DOT_UNROLL][LM_GGML_F16_ARR] = { { LM_GGML_F16_VEC_ZERO } }; +size_t lm_ggml_row_size(enum lm_ggml_type type, int64_t ne) { + assert(ne % lm_ggml_blck_size(type) == 0); + return lm_ggml_type_size(type)*ne/lm_ggml_blck_size(type); +} - LM_GGML_F16_VEC ax[LM_GGML_F16_ARR]; - LM_GGML_F16_VEC ay[LM_GGML_F16_ARR]; +double lm_ggml_type_sizef(enum lm_ggml_type type) { + return ((double)(type_traits[type].type_size))/type_traits[type].blck_size; +} - for (int i = 0; i < np; i += LM_GGML_F16_STEP) { - for (int j = 0; j < LM_GGML_F16_ARR; j++) { - ay[j] = LM_GGML_F16_VEC_LOAD(y + i + j*LM_GGML_F16_EPR, j); +const char * lm_ggml_type_name(enum lm_ggml_type type) { + return type < LM_GGML_TYPE_COUNT ? type_traits[type].type_name : "NONE"; +} - for (int k = 0; k < LM_GGML_VEC_DOT_UNROLL; ++k) { - ax[j] = LM_GGML_F16_VEC_LOAD(x[k] + i + j*LM_GGML_F16_EPR, j); +bool lm_ggml_is_quantized(enum lm_ggml_type type) { + return type_traits[type].is_quantized; +} - sum[k][j] = LM_GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]); - } - } - } +const char * lm_ggml_op_name(enum lm_ggml_op op) { + return LM_GGML_OP_NAME[op]; +} - // reduce sum0..sum3 to sum0 - for (int k = 0; k < LM_GGML_VEC_DOT_UNROLL; ++k) { - LM_GGML_F16_VEC_REDUCE(sumf[k], sum[k]); - } +const char * lm_ggml_op_symbol(enum lm_ggml_op op) { + return LM_GGML_OP_SYMBOL[op]; +} - // leftovers - for (int i = np; i < n; ++i) { - for (int j = 0; j < LM_GGML_VEC_DOT_UNROLL; ++j) { - sumf[j] += (lm_ggml_float)(LM_GGML_FP16_TO_FP32(x[j][i])*LM_GGML_FP16_TO_FP32(y[i])); - } - } -#else - for (int i = 0; i < n; ++i) { - for (int j = 0; j < LM_GGML_VEC_DOT_UNROLL; ++j) { - sumf[j] += (lm_ggml_float)(LM_GGML_FP16_TO_FP32(x[j][i])*LM_GGML_FP16_TO_FP32(y[i])); - } - } -#endif +const char * lm_ggml_unary_op_name(enum lm_ggml_unary_op op) { + return LM_GGML_UNARY_OP_NAME[op]; +} - for (int i = 0; i < LM_GGML_VEC_DOT_UNROLL; ++i) { - s[i] = sumf[i]; +const char * lm_ggml_op_desc(const struct lm_ggml_tensor * t) { + if (t->op == LM_GGML_OP_UNARY) { + enum lm_ggml_unary_op uop = lm_ggml_get_unary_op(t); + return lm_ggml_unary_op_name(uop); } + return lm_ggml_op_name(t->op); } -inline static void lm_ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) { -#if defined(LM_GGML_SIMD) - const int np = (n & ~(LM_GGML_F32_STEP - 1)); - - LM_GGML_F32_VEC vx = LM_GGML_F32_VEC_SET1(v); +size_t lm_ggml_element_size(const struct lm_ggml_tensor * tensor) { + return lm_ggml_type_size(tensor->type); +} - LM_GGML_F32_VEC ax[LM_GGML_F32_ARR]; - LM_GGML_F32_VEC ay[LM_GGML_F32_ARR]; +bool lm_ggml_is_scalar(const struct lm_ggml_tensor * tensor) { + static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); - for (int i = 0; i < np; i += LM_GGML_F32_STEP) { - for (int j = 0; j < LM_GGML_F32_ARR; j++) { - ax[j] = LM_GGML_F32_VEC_LOAD(x + i + j*LM_GGML_F32_EPR); - ay[j] = LM_GGML_F32_VEC_LOAD(y + i + j*LM_GGML_F32_EPR); - ay[j] = LM_GGML_F32_VEC_FMA(ay[j], ax[j], vx); + return tensor->ne[0] == 1 && tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1; +} - LM_GGML_F32_VEC_STORE(y + i + j*LM_GGML_F32_EPR, ay[j]); - } - } +bool lm_ggml_is_vector(const struct lm_ggml_tensor * tensor) { + static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); - // leftovers - for (int i = np; i < n; ++i) { - y[i] += x[i]*v; - } -#else - // scalar - for (int i = 0; i < n; ++i) { - y[i] += x[i]*v; - } -#endif + return tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1; } -inline static void lm_ggml_vec_mad_f16(const int n, lm_ggml_fp16_t * restrict y, const lm_ggml_fp16_t * restrict x, const float v) { -#if defined(LM_GGML_SIMD) - const int np = (n & ~(LM_GGML_F16_STEP - 1)); - - LM_GGML_F16_VEC vx = LM_GGML_F16_VEC_SET1(v); +bool lm_ggml_is_matrix(const struct lm_ggml_tensor * tensor) { + static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); - LM_GGML_F16_VEC ax[LM_GGML_F16_ARR]; - LM_GGML_F16_VEC ay[LM_GGML_F16_ARR]; + return tensor->ne[2] == 1 && tensor->ne[3] == 1; +} - for (int i = 0; i < np; i += LM_GGML_F16_STEP) { - for (int j = 0; j < LM_GGML_F16_ARR; j++) { - ax[j] = LM_GGML_F16_VEC_LOAD(x + i + j*LM_GGML_F16_EPR, j); - ay[j] = LM_GGML_F16_VEC_LOAD(y + i + j*LM_GGML_F16_EPR, j); - ay[j] = LM_GGML_F16_VEC_FMA(ay[j], ax[j], vx); +bool lm_ggml_is_3d(const struct lm_ggml_tensor * tensor) { + return tensor->ne[3] == 1; +} - LM_GGML_F16_VEC_STORE(y + i + j*LM_GGML_F16_EPR, ay, j); +int lm_ggml_n_dims(const struct lm_ggml_tensor * tensor) { + for (int i = LM_GGML_MAX_DIMS - 1; i >= 1; --i) { + if (tensor->ne[i] > 1) { + return i + 1; } } - - // leftovers - for (int i = np; i < n; ++i) { - y[i] = LM_GGML_FP32_TO_FP16(LM_GGML_FP16_TO_FP32(y[i]) + LM_GGML_FP16_TO_FP32(x[i])*v); - } -#else - // scalar - for (int i = 0; i < n; ++i) { - y[i] = LM_GGML_FP32_TO_FP16(LM_GGML_FP16_TO_FP32(y[i]) + LM_GGML_FP16_TO_FP32(x[i])*v); - } -#endif + return 1; } -// xs and vs are byte strides of x and v -inline static void lm_ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) { - - const float * restrict x[LM_GGML_VEC_MAD_UNROLL]; - const float * restrict v[LM_GGML_VEC_MAD_UNROLL]; +enum lm_ggml_type lm_ggml_ftype_to_lm_ggml_type(enum lm_ggml_ftype ftype) { + enum lm_ggml_type wtype = LM_GGML_TYPE_COUNT; - for (int i = 0; i < LM_GGML_VEC_MAD_UNROLL; ++i) { - x[i] = (const float *) ((const char *) xv + i*xs); - v[i] = (const float *) ((const char *) vv + i*vs); + switch (ftype) { + case LM_GGML_FTYPE_ALL_F32: wtype = LM_GGML_TYPE_F32; break; + case LM_GGML_FTYPE_MOSTLY_F16: wtype = LM_GGML_TYPE_F16; break; + case LM_GGML_FTYPE_MOSTLY_BF16: wtype = LM_GGML_TYPE_BF16; break; + case LM_GGML_FTYPE_MOSTLY_Q4_0: wtype = LM_GGML_TYPE_Q4_0; break; + case LM_GGML_FTYPE_MOSTLY_Q4_1: wtype = LM_GGML_TYPE_Q4_1; break; + case LM_GGML_FTYPE_MOSTLY_Q5_0: wtype = LM_GGML_TYPE_Q5_0; break; + case LM_GGML_FTYPE_MOSTLY_Q5_1: wtype = LM_GGML_TYPE_Q5_1; break; + case LM_GGML_FTYPE_MOSTLY_Q8_0: wtype = LM_GGML_TYPE_Q8_0; break; + case LM_GGML_FTYPE_MOSTLY_Q2_K: wtype = LM_GGML_TYPE_Q2_K; break; + case LM_GGML_FTYPE_MOSTLY_Q3_K: wtype = LM_GGML_TYPE_Q3_K; break; + case LM_GGML_FTYPE_MOSTLY_Q4_K: wtype = LM_GGML_TYPE_Q4_K; break; + case LM_GGML_FTYPE_MOSTLY_Q5_K: wtype = LM_GGML_TYPE_Q5_K; break; + case LM_GGML_FTYPE_MOSTLY_Q6_K: wtype = LM_GGML_TYPE_Q6_K; break; + case LM_GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = LM_GGML_TYPE_IQ2_XXS; break; + case LM_GGML_FTYPE_MOSTLY_IQ2_XS: wtype = LM_GGML_TYPE_IQ2_XS; break; + case LM_GGML_FTYPE_MOSTLY_IQ3_XXS: wtype = LM_GGML_TYPE_IQ3_XXS; break; + case LM_GGML_FTYPE_MOSTLY_IQ1_S: wtype = LM_GGML_TYPE_IQ1_S; break; + case LM_GGML_FTYPE_MOSTLY_IQ1_M: wtype = LM_GGML_TYPE_IQ1_M; break; + case LM_GGML_FTYPE_MOSTLY_IQ4_NL: wtype = LM_GGML_TYPE_IQ4_NL; break; + case LM_GGML_FTYPE_MOSTLY_IQ4_XS: wtype = LM_GGML_TYPE_IQ4_XS; break; + case LM_GGML_FTYPE_MOSTLY_IQ3_S: wtype = LM_GGML_TYPE_IQ3_S; break; + case LM_GGML_FTYPE_MOSTLY_IQ2_S: wtype = LM_GGML_TYPE_IQ2_S; break; + case LM_GGML_FTYPE_MOSTLY_Q4_0_4_4: wtype = LM_GGML_TYPE_Q4_0_4_4; break; + case LM_GGML_FTYPE_MOSTLY_Q4_0_4_8: wtype = LM_GGML_TYPE_Q4_0_4_8; break; + case LM_GGML_FTYPE_MOSTLY_Q4_0_8_8: wtype = LM_GGML_TYPE_Q4_0_8_8; break; + case LM_GGML_FTYPE_UNKNOWN: wtype = LM_GGML_TYPE_COUNT; break; + case LM_GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = LM_GGML_TYPE_COUNT; break; } -#if defined(LM_GGML_SIMD) - const int np = (n & ~(LM_GGML_F32_STEP - 1)); - - LM_GGML_F32_VEC vx[LM_GGML_VEC_MAD_UNROLL]; + LM_GGML_ASSERT(wtype != LM_GGML_TYPE_COUNT); - for (int k = 0; k < LM_GGML_VEC_MAD_UNROLL; ++k) { - vx[k] = LM_GGML_F32_VEC_SET1(v[k][0]); - } + return wtype; +} - LM_GGML_F32_VEC ax[LM_GGML_VEC_MAD_UNROLL][LM_GGML_F32_ARR]; - LM_GGML_F32_VEC ay[LM_GGML_F32_ARR]; +size_t lm_ggml_tensor_overhead(void) { + return LM_GGML_OBJECT_SIZE + LM_GGML_TENSOR_SIZE; +} - for (int i = 0; i < np; i += LM_GGML_F32_STEP) { - for (int j = 0; j < LM_GGML_F32_ARR; j++) { - ay[j] = LM_GGML_F32_VEC_LOAD(y + i + j*LM_GGML_F32_EPR); +bool lm_ggml_is_transposed(const struct lm_ggml_tensor * tensor) { + return tensor->nb[0] > tensor->nb[1]; +} - for (int k = 0; k < LM_GGML_VEC_MAD_UNROLL; ++k) { - ax[k][j] = LM_GGML_F32_VEC_LOAD(x[k] + i + j*LM_GGML_F32_EPR); - ay[j] = LM_GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]); +static bool lm_ggml_is_contiguous_n(const struct lm_ggml_tensor * tensor, int n) { + size_t next_nb = lm_ggml_type_size(tensor->type); + if (tensor->ne[0] != lm_ggml_blck_size(tensor->type) && tensor->nb[0] != next_nb) { + return false; + } + next_nb *= tensor->ne[0]/lm_ggml_blck_size(tensor->type); + for (int i = 1; i < LM_GGML_MAX_DIMS; i++) { + if (tensor->ne[i] != 1) { + if (i > n) { + if (tensor->nb[i] != next_nb) { + return false; + } + next_nb *= tensor->ne[i]; + } else { + // this dimension does not need to be contiguous + next_nb = tensor->ne[i]*tensor->nb[i]; } - - LM_GGML_F32_VEC_STORE(y + i + j*LM_GGML_F32_EPR, ay[j]); } } + return true; +} - // leftovers - for (int k = 0; k < LM_GGML_VEC_MAD_UNROLL; ++k) { - for (int i = np; i < n; ++i) { - y[i] += x[k][i]*v[k][0]; - } - } -#else - // scalar - for (int k = 0; k < LM_GGML_VEC_MAD_UNROLL; ++k) { - for (int i = 0; i < n; ++i) { - y[i] += x[k][i]*v[k][0]; - } - } -#endif +bool lm_ggml_is_contiguous(const struct lm_ggml_tensor * tensor) { + return lm_ggml_is_contiguous_0(tensor); +} + +bool lm_ggml_is_contiguous_0(const struct lm_ggml_tensor * tensor) { + return lm_ggml_is_contiguous_n(tensor, 0); +} + +bool lm_ggml_is_contiguous_1(const struct lm_ggml_tensor * tensor) { + return lm_ggml_is_contiguous_n(tensor, 1); } -//inline static void lm_ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; } -inline static void lm_ggml_vec_scale_f32(const int n, float * y, const float v) { -#if defined(LM_GGML_USE_ACCELERATE) - vDSP_vsmul(y, 1, &v, y, 1, n); -#elif defined(LM_GGML_SIMD) - const int np = (n & ~(LM_GGML_F32_STEP - 1)); +bool lm_ggml_is_contiguous_2(const struct lm_ggml_tensor * tensor) { + return lm_ggml_is_contiguous_n(tensor, 2); +} - LM_GGML_F32_VEC vx = LM_GGML_F32_VEC_SET1(v); +bool lm_ggml_is_permuted(const struct lm_ggml_tensor * tensor) { + static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); - LM_GGML_F32_VEC ay[LM_GGML_F32_ARR]; + return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3]; +} - for (int i = 0; i < np; i += LM_GGML_F32_STEP) { - for (int j = 0; j < LM_GGML_F32_ARR; j++) { - ay[j] = LM_GGML_F32_VEC_LOAD(y + i + j*LM_GGML_F32_EPR); - ay[j] = LM_GGML_F32_VEC_MUL(ay[j], vx); +static inline bool lm_ggml_is_padded_1d(const struct lm_ggml_tensor * tensor) { + static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); - LM_GGML_F32_VEC_STORE(y + i + j*LM_GGML_F32_EPR, ay[j]); - } - } + return + tensor->nb[0] == lm_ggml_type_size(tensor->type) && + tensor->nb[2] == tensor->nb[1]*tensor->ne[1] && + tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; +} - // leftovers - for (int i = np; i < n; ++i) { - y[i] *= v; - } -#else - // scalar - for (int i = 0; i < n; ++i) { - y[i] *= v; +bool lm_ggml_is_empty(const struct lm_ggml_tensor * tensor) { + for (int i = 0; i < LM_GGML_MAX_DIMS; ++i) { + if (tensor->ne[i] == 0) { + // empty if any dimension has no elements + return true; + } } -#endif + return false; } -inline static void lm_ggml_vec_scale_f16(const int n, lm_ggml_fp16_t * y, const float v) { -#if defined(LM_GGML_SIMD) - const int np = (n & ~(LM_GGML_F16_STEP - 1)); +bool lm_ggml_are_same_shape(const struct lm_ggml_tensor * t0, const struct lm_ggml_tensor * t1) { + static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); - LM_GGML_F16_VEC vx = LM_GGML_F16_VEC_SET1(v); + return + (t0->ne[0] == t1->ne[0]) && + (t0->ne[1] == t1->ne[1]) && + (t0->ne[2] == t1->ne[2]) && + (t0->ne[3] == t1->ne[3]); +} - LM_GGML_F16_VEC ay[LM_GGML_F16_ARR]; +bool lm_ggml_are_same_stride(const struct lm_ggml_tensor * t0, const struct lm_ggml_tensor * t1) { + static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); - for (int i = 0; i < np; i += LM_GGML_F16_STEP) { - for (int j = 0; j < LM_GGML_F16_ARR; j++) { - ay[j] = LM_GGML_F16_VEC_LOAD(y + i + j*LM_GGML_F16_EPR, j); - ay[j] = LM_GGML_F16_VEC_MUL(ay[j], vx); + return + (t0->nb[0] == t1->nb[0]) && + (t0->nb[1] == t1->nb[1]) && + (t0->nb[2] == t1->nb[2]) && + (t0->nb[3] == t1->nb[3]); +} - LM_GGML_F16_VEC_STORE(y + i + j*LM_GGML_F16_EPR, ay, j); - } - } +// check if t1 can be represented as a repeatition of t0 +bool lm_ggml_can_repeat(const struct lm_ggml_tensor * t0, const struct lm_ggml_tensor * t1) { + static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); - // leftovers - for (int i = np; i < n; ++i) { - y[i] = LM_GGML_FP32_TO_FP16(LM_GGML_FP16_TO_FP32(y[i])*v); - } -#else - // scalar - for (int i = 0; i < n; ++i) { - y[i] = LM_GGML_FP32_TO_FP16(LM_GGML_FP16_TO_FP32(y[i])*v); - } -#endif + return lm_ggml_is_empty(t0) ? lm_ggml_is_empty(t1) : + (t1->ne[0]%t0->ne[0] == 0) && + (t1->ne[1]%t0->ne[1] == 0) && + (t1->ne[2]%t0->ne[2] == 0) && + (t1->ne[3]%t0->ne[3] == 0); } -inline static void lm_ggml_vec_norm_f32 (const int n, float * s, const float * x) { lm_ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); } -inline static void lm_ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; } -inline static void lm_ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); } -inline static void lm_ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); } -inline static void lm_ggml_vec_sin_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sinf(x[i]); } -inline static void lm_ggml_vec_cos_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = cosf(x[i]); } -inline static void lm_ggml_vec_abs_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); } -inline static void lm_ggml_vec_sgn_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); } -inline static void lm_ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; } -inline static void lm_ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); } -inline static void lm_ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expm1f(x[i]); } -inline static void lm_ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; } -inline static void lm_ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); } -inline static void lm_ggml_vec_sigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + expf(-x[i])); } -// TODO: optimize performance -inline static void lm_ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); } -inline static void lm_ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); } -inline static void lm_ggml_vec_exp_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = expf(x[i]); } - -static const float GELU_COEF_A = 0.044715f; -static const float GELU_QUICK_COEF = -1.702f; -static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; - -inline static float lm_ggml_gelu_f32(float x) { - return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); -} - -inline static void lm_ggml_vec_gelu_f16(const int n, lm_ggml_fp16_t * y, const lm_ggml_fp16_t * x) { - const uint16_t * i16 = (const uint16_t *) x; - for (int i = 0; i < n; ++i) { - y[i] = lm_ggml_table_gelu_f16[i16[i]]; - } +static inline bool lm_ggml_can_repeat_rows(const struct lm_ggml_tensor * t0, const struct lm_ggml_tensor * t1) { + static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); + + return (t0->ne[0] == t1->ne[0]) && lm_ggml_can_repeat(t0, t1); } -#ifdef LM_GGML_GELU_FP16 -inline static void lm_ggml_vec_gelu_f32(const int n, float * y, const float * x) { - uint16_t t; - for (int i = 0; i < n; ++i) { - if (x[i] <= -10.0f) { - y[i] = 0.0f; - } else if (x[i] >= 10.0f) { - y[i] = x[i]; - } else { - lm_ggml_fp16_t fp16 = LM_GGML_FP32_TO_FP16(x[i]); - memcpy(&t, &fp16, sizeof(uint16_t)); - y[i] = LM_GGML_FP16_TO_FP32(lm_ggml_table_gelu_f16[t]); +// assert that pointer is aligned to LM_GGML_MEM_ALIGN +#define LM_GGML_ASSERT_ALIGNED(ptr) \ + LM_GGML_ASSERT(((uintptr_t) (ptr))%LM_GGML_MEM_ALIGN == 0) + +//////////////////////////////////////////////////////////////////////////////// + +struct lm_ggml_context * lm_ggml_init(struct lm_ggml_init_params params) { + static bool is_first_call = true; + + lm_ggml_critical_section_start(); + + if (is_first_call) { + // initialize time system (required on Windows) + lm_ggml_time_init(); + + for (int i = 0; i < (1 << 16); ++i) { + union { + uint16_t u16; + lm_ggml_fp16_t fp16; + } u = {i}; + lm_ggml_table_f32_f16[i] = LM_GGML_COMPUTE_FP16_TO_FP32(u.fp16); } + + is_first_call = false; } -} -#else -inline static void lm_ggml_vec_gelu_f32(const int n, float * y, const float * x) { - for (int i = 0; i < n; ++i) { - y[i] = lm_ggml_gelu_f32(x[i]); - } -} -#endif -inline static float lm_ggml_gelu_quick_f32(float x) { - return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x))); -} + lm_ggml_critical_section_end(); -//inline static void lm_ggml_vec_gelu_quick_f16(const int n, lm_ggml_fp16_t * y, const lm_ggml_fp16_t * x) { -// const uint16_t * i16 = (const uint16_t *) x; -// for (int i = 0; i < n; ++i) { -// y[i] = lm_ggml_table_gelu_quick_f16[i16[i]]; -// } -//} + struct lm_ggml_context * ctx = LM_GGML_MALLOC(sizeof(struct lm_ggml_context)); -#ifdef LM_GGML_GELU_QUICK_FP16 -inline static void lm_ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) { - uint16_t t; - for (int i = 0; i < n; ++i) { - lm_ggml_fp16_t fp16 = LM_GGML_FP32_TO_FP16(x[i]); - memcpy(&t, &fp16, sizeof(uint16_t)); - y[i] = LM_GGML_FP16_TO_FP32(lm_ggml_table_gelu_quick_f16[t]); - } -} -#else -inline static void lm_ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) { - for (int i = 0; i < n; ++i) { - y[i] = lm_ggml_gelu_quick_f32(x[i]); + // allow to call lm_ggml_init with 0 size + if (params.mem_size == 0) { + params.mem_size = LM_GGML_MEM_ALIGN; } -} -#endif -// Sigmoid Linear Unit (SiLU) function -inline static float lm_ggml_silu_f32(float x) { - return x/(1.0f + expf(-x)); -} + const size_t mem_size = params.mem_buffer ? params.mem_size : LM_GGML_PAD(params.mem_size, LM_GGML_MEM_ALIGN); -#if __FINITE_MATH_ONLY__ -#error "some routines in ggml.c require non-finite math arithmetics -- pass -fno-finite-math-only to the compiler to fix" -#error "ref: https://github.com/ggerganov/llama.cpp/pull/7154#issuecomment-2143844461" -#endif + *ctx = (struct lm_ggml_context) { + /*.mem_size =*/ mem_size, + /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : lm_ggml_aligned_malloc(mem_size), + /*.mem_buffer_owned =*/ params.mem_buffer ? false : true, + /*.no_alloc =*/ params.no_alloc, + /*.n_objects =*/ 0, + /*.objects_begin =*/ NULL, + /*.objects_end =*/ NULL, + }; -#if defined(__ARM_NEON) && defined(__aarch64__) - -// adapted from arm limited optimized routine -// the maximum error is 1.45358 plus 0.5 ulps -// numbers above 88.38 will flush to infinity -// numbers beneath -103.97 will flush to zero -inline static float32x4_t lm_ggml_v_expf(float32x4_t x) { - const float32x4_t r = vdupq_n_f32(0x1.8p23f); - const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f)); - const float32x4_t n = vsubq_f32(z, r); - const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n, - vdupq_n_f32(0x1.7f7d1cp-20f)); - const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23); - const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1)))); - const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126)); - const float32x4_t u = vmulq_f32(b, b); - const float32x4_t j = vfmaq_f32( - vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b), - vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b), - vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u); - if (!vpaddd_u64(vreinterpretq_u64_u32(c))) - return vfmaq_f32(k, j, k); - const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000)); - const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000))); - const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d)); - return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1), - vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j))); -} - -// computes silu x/(1+exp(-x)) in single precision vector -inline static float32x4_t lm_ggml_v_silu(float32x4_t x) { - const float32x4_t one = vdupq_n_f32(1.0f); - const float32x4_t zero = vdupq_n_f32(0.0f); - const float32x4_t neg_x = vsubq_f32(zero, x); - const float32x4_t exp_neg_x = lm_ggml_v_expf(neg_x); - const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x); - return vdivq_f32(x, one_plus_exp_neg_x); -} - -#elif defined(__AVX512F__) && defined(__AVX512DQ__) - -// adapted from arm limited optimized routine -// the maximum error is 1.45358 plus 0.5 ulps -// numbers above 88.38 will flush to infinity -// numbers beneath -103.97 will flush to zero -inline static __m512 lm_ggml_v_expf(__m512 x) { - const __m512 r = _mm512_set1_ps(0x1.8p23f); - const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r); - const __m512 n = _mm512_sub_ps(z, r); - const __m512 b = - _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f), - _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x)); - const __mmask16 d = - _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ); - const __m512 u = _mm512_mul_ps(b, b); - const __m512 j = _mm512_fmadd_ps( - _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b, - _mm512_set1_ps(0x1.573e2ep-5f)), - u, - _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b, - _mm512_set1_ps(0x1.fffdb6p-2f))), - u, - _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F))); - const __m512 res = _mm512_scalef_ps(j, n); - if (_mm512_kortestz(d, d)) - return res; - const __m512 zero = _mm512_setzero_ps(); - const __m512 alt = _mm512_mask_blend_ps( - _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero); - return _mm512_mask_blend_ps(d, res, alt); -} - -// computes silu x/(1+exp(-x)) in single precision vector -inline static __m512 lm_ggml_v_silu(__m512 x) { - const __m512 one = _mm512_set1_ps(1); - const __m512 zero = _mm512_setzero_ps(); - const __m512 neg_x = _mm512_sub_ps(zero, x); - const __m512 exp_neg_x = lm_ggml_v_expf(neg_x); - const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x); - return _mm512_div_ps(x, one_plus_exp_neg_x); -} - -#elif defined(__AVX2__) && defined(__FMA__) - -// adapted from arm limited optimized routine -// the maximum error is 1.45358 plus 0.5 ulps -// numbers above 88.38 will flush to infinity -// numbers beneath -103.97 will flush to zero -inline static __m256 lm_ggml_v_expf(__m256 x) { - const __m256 r = _mm256_set1_ps(0x1.8p23f); - const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r); - const __m256 n = _mm256_sub_ps(z, r); - const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f), - _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x)); - const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23); - const __m256 k = _mm256_castsi256_ps( - _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1)))); - const __m256i c = _mm256_castps_si256( - _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n), - _mm256_set1_ps(126), _CMP_GT_OQ)); - const __m256 u = _mm256_mul_ps(b, b); - const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b, - _mm256_set1_ps(0x1.573e2ep-5f)), u, - _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b, - _mm256_set1_ps(0x1.fffdb6p-2f))), - u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b)); - if (!_mm256_movemask_ps(_mm256_castsi256_ps(c))) - return _mm256_fmadd_ps(j, k, k); - const __m256i g = _mm256_and_si256( - _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)), - _mm256_set1_epi32(0x82000000u)); - const __m256 s1 = - _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u))); - const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g)); - const __m256i d = _mm256_castps_si256( - _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n), - _mm256_set1_ps(192), _CMP_GT_OQ)); - return _mm256_or_ps( - _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)), - _mm256_andnot_ps( - _mm256_castsi256_ps(d), - _mm256_or_ps( - _mm256_and_ps(_mm256_castsi256_ps(c), - _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)), - _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k))))); -} - -// computes silu x/(1+exp(-x)) in single precision vector -inline static __m256 lm_ggml_v_silu(__m256 x) { - const __m256 one = _mm256_set1_ps(1); - const __m256 zero = _mm256_setzero_ps(); - const __m256 neg_x = _mm256_sub_ps(zero, x); - const __m256 exp_neg_x = lm_ggml_v_expf(neg_x); - const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x); - return _mm256_div_ps(x, one_plus_exp_neg_x); -} - -#elif defined(__SSE2__) // __AVX2__ / __ARM_NEON - -#if defined(__FMA__) -#define MADD128(x, y, z) _mm_fmadd_ps(x, y, z) -#define NMADD128(x, y, z) _mm_fnmadd_ps(x, y, z) -#else -#define MADD128(x, y, z) _mm_add_ps(_mm_mul_ps(x, y), z) -#define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y)) -#endif + LM_GGML_ASSERT(ctx->mem_buffer != NULL); -// adapted from arm limited optimized routine -// the maximum error is 1.45358 plus 0.5 ulps -// numbers above 88.38 will flush to infinity -// numbers beneath -103.97 will flush to zero -inline static __m128 lm_ggml_v_expf(__m128 x) { - const __m128 r = _mm_set1_ps(0x1.8p23f); - const __m128 z = MADD128(x, _mm_set1_ps(0x1.715476p+0f), r); - const __m128 n = _mm_sub_ps(z, r); - const __m128 b = - NMADD128(n, _mm_set1_ps(0x1.7f7d1cp-20f), NMADD128(n, _mm_set1_ps(0x1.62e4p-1f), x)); - const __m128i e = _mm_slli_epi32(_mm_castps_si128(z), 23); - const __m128 k = _mm_castsi128_ps(_mm_add_epi32(e, _mm_castps_si128(_mm_set1_ps(1)))); - const __m128i c = - _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(126))); - const __m128 u = _mm_mul_ps(b, b); - const __m128 j = - MADD128(MADD128(MADD128(_mm_set1_ps(0x1.0e4020p-7f), b, _mm_set1_ps(0x1.573e2ep-5f)), u, - MADD128(_mm_set1_ps(0x1.555e66p-3f), b, _mm_set1_ps(0x1.fffdb6p-2f))), - u, _mm_mul_ps(_mm_set1_ps(0x1.ffffecp-1f), b)); - if (!_mm_movemask_epi8(c)) - return MADD128(j, k, k); - const __m128i g = _mm_and_si128(_mm_castps_si128(_mm_cmple_ps(n, _mm_setzero_ps())), - _mm_set1_epi32(0x82000000u)); - const __m128 s1 = _mm_castsi128_ps(_mm_add_epi32(g, _mm_set1_epi32(0x7f000000u))); - const __m128 s2 = _mm_castsi128_ps(_mm_sub_epi32(e, g)); - const __m128i d = - _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(192))); - return _mm_or_ps( - _mm_and_ps(_mm_castsi128_ps(d), _mm_mul_ps(s1, s1)), - _mm_andnot_ps(_mm_castsi128_ps(d), - _mm_or_ps(_mm_and_ps(_mm_castsi128_ps(c), _mm_mul_ps(MADD128(s2, j, s2), s1)), - _mm_andnot_ps(_mm_castsi128_ps(c), MADD128(k, j, k))))); -} - -// computes silu x/(1+exp(-x)) in single precision vector -inline static __m128 lm_ggml_v_silu(__m128 x) { - const __m128 one = _mm_set1_ps(1); - const __m128 zero = _mm_setzero_ps(); - const __m128 neg_x = _mm_sub_ps(zero, x); - const __m128 exp_neg_x = lm_ggml_v_expf(neg_x); - const __m128 one_plus_exp_neg_x = _mm_add_ps(one, exp_neg_x); - return _mm_div_ps(x, one_plus_exp_neg_x); -} - -#endif // __ARM_NEON / __AVX2__ / __SSE2__ - -static void lm_ggml_vec_silu_f32(const int n, float * y, const float * x) { - int i = 0; -#if defined(__AVX512F__) && defined(__AVX512DQ__) - for (; i + 15 < n; i += 16) { - _mm512_storeu_ps(y + i, lm_ggml_v_silu(_mm512_loadu_ps(x + i))); - } -#elif defined(__AVX2__) && defined(__FMA__) - for (; i + 7 < n; i += 8) { - _mm256_storeu_ps(y + i, lm_ggml_v_silu(_mm256_loadu_ps(x + i))); - } -#elif defined(__SSE2__) - for (; i + 3 < n; i += 4) { - _mm_storeu_ps(y + i, lm_ggml_v_silu(_mm_loadu_ps(x + i))); - } -#elif defined(__ARM_NEON) && defined(__aarch64__) - for (; i + 3 < n; i += 4) { - vst1q_f32(y + i, lm_ggml_v_silu(vld1q_f32(x + i))); - } -#endif - for (; i < n; ++i) { - y[i] = lm_ggml_silu_f32(x[i]); - } -} + LM_GGML_ASSERT_ALIGNED(ctx->mem_buffer); -static lm_ggml_float lm_ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) { - int i = 0; - lm_ggml_float sum = 0; -#if defined(__AVX512F__) && defined(__AVX512DQ__) - for (; i + 15 < n; i += 16) { - __m512 val = lm_ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i), - _mm512_set1_ps(max))); - _mm512_storeu_ps(y + i, val); - sum += (lm_ggml_float)_mm512_reduce_add_ps(val); - } -#elif defined(__AVX2__) && defined(__FMA__) - for (; i + 7 < n; i += 8) { - __m256 val = lm_ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i), - _mm256_set1_ps(max))); - _mm256_storeu_ps(y + i, val); - __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1), - _mm256_castps256_ps128(val)); - val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2)); - val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2)); - sum += (lm_ggml_float)_mm_cvtss_f32(val2); - } -#elif defined(__SSE2__) - for (; i + 3 < n; i += 4) { - __m128 val = lm_ggml_v_expf(_mm_sub_ps(_mm_loadu_ps(x + i), - _mm_set1_ps(max))); - _mm_storeu_ps(y + i, val); -#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) - val = _mm_add_ps(val, _mm_movehl_ps(val, val)); - val = _mm_add_ss(val, _mm_movehdup_ps(val)); -#else - __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1)); - val = _mm_add_ps(val, tmp); - tmp = _mm_movehl_ps(tmp, val); - val = _mm_add_ss(val, tmp); -#endif - sum += (lm_ggml_float)_mm_cvtss_f32(val); - } -#elif defined(__ARM_NEON) && defined(__aarch64__) - for (; i + 3 < n; i += 4) { - float32x4_t val = lm_ggml_v_expf(vsubq_f32(vld1q_f32(x + i), - vdupq_n_f32(max))); - vst1q_f32(y + i, val); - sum += (lm_ggml_float)vaddvq_f32(val); - } -#endif - for (; i < n; ++i) { - float val = expf(x[i] - max); - sum += (lm_ggml_float)val; - y[i] = val; - } - return sum; -} + LM_GGML_PRINT_DEBUG("%s: context initialized\n", __func__); -static lm_ggml_float lm_ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max) { - // log(soft_max) = log(soft_max_i / soft_max_sum) = log(soft_max_i) - log(soft_max_sum) = (logit_i - max) - log(soft_max_i) + return ctx; +} - int i = 0; - lm_ggml_float sum = 0; - for (; i < n; ++i) { - float val = x[i] - max; - y[i] = val; - sum += (lm_ggml_float)expf(val); +void lm_ggml_reset(struct lm_ggml_context * ctx) { + if (ctx == NULL) { + return; } - return sum = (lm_ggml_float)logf(sum); -} -inline static float lm_ggml_silu_backward_f32(float x, float dy) { - const float s = 1.0f/(1.0f + expf(-x)); - return dy*s*(1.0f + x*(1.0f - s)); + ctx->n_objects = 0; + ctx->objects_begin = NULL; + ctx->objects_end = NULL; } -inline static void lm_ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) { - for (int i = 0; i < n; ++i) { - dx[i] = lm_ggml_silu_backward_f32(x[i], dy[i]); +void lm_ggml_free(struct lm_ggml_context * ctx) { + if (ctx == NULL) { + return; } -} -inline static void lm_ggml_vec_sum_f32(const int n, float * s, const float * x) { -#ifndef LM_GGML_USE_ACCELERATE - lm_ggml_float sum = 0.0; - for (int i = 0; i < n; ++i) { - sum += (lm_ggml_float)x[i]; + if (ctx->mem_buffer_owned) { + lm_ggml_aligned_free(ctx->mem_buffer, ctx->mem_size); } - *s = sum; -#else - vDSP_sve(x, 1, s, n); -#endif + + LM_GGML_FREE(ctx); } -inline static void lm_ggml_vec_sum_f32_ggf(const int n, lm_ggml_float * s, const float * x) { - lm_ggml_float sum = 0.0; - for (int i = 0; i < n; ++i) { - sum += (lm_ggml_float)x[i]; - } - *s = sum; +size_t lm_ggml_used_mem(const struct lm_ggml_context * ctx) { + return ctx->objects_end == NULL ? 0 : ctx->objects_end->offs + ctx->objects_end->size; } -inline static void lm_ggml_vec_sum_f16_ggf(const int n, float * s, const lm_ggml_fp16_t * x) { - float sum = 0.0f; - for (int i = 0; i < n; ++i) { - sum += LM_GGML_FP16_TO_FP32(x[i]); - } - *s = sum; +bool lm_ggml_get_no_alloc(struct lm_ggml_context * ctx) { + return ctx->no_alloc; } -inline static void lm_ggml_vec_sum_bf16_ggf(const int n, float * s, const lm_ggml_bf16_t * x) { - float sum = 0.0f; - for (int i = 0; i < n; ++i) { - sum += LM_GGML_BF16_TO_FP32(x[i]); - } - *s = sum; +void lm_ggml_set_no_alloc(struct lm_ggml_context * ctx, bool no_alloc) { + ctx->no_alloc = no_alloc; } -inline static void lm_ggml_vec_max_f32(const int n, float * s, const float * x) { -#ifndef LM_GGML_USE_ACCELERATE - float max = -INFINITY; - for (int i = 0; i < n; ++i) { - max = MAX(max, x[i]); - } - *s = max; -#else - vDSP_maxv(x, 1, s, n); -#endif +void * lm_ggml_get_mem_buffer(const struct lm_ggml_context * ctx) { + return ctx->mem_buffer; } -inline static void lm_ggml_vec_norm_inv_f32(const int n, float * s, const float * x) { - lm_ggml_vec_norm_f32(n, s, x); - *s = 1.f/(*s); +size_t lm_ggml_get_mem_size(const struct lm_ggml_context * ctx) { + return ctx->mem_size; } -inline static void lm_ggml_vec_argmax_f32(const int n, int * s, const float * x) { - float max = -INFINITY; - int idx = 0; - for (int i = 0; i < n; ++i) { - max = MAX(max, x[i]); - if (max == x[i]) { idx = i; } +size_t lm_ggml_get_max_tensor_size(const struct lm_ggml_context * ctx) { + size_t max_size = 0; + + for (struct lm_ggml_tensor * tensor = lm_ggml_get_first_tensor(ctx); tensor != NULL; tensor = lm_ggml_get_next_tensor(ctx, tensor)) { + size_t bytes = lm_ggml_nbytes(tensor); + max_size = MAX(max_size, bytes); } - *s = idx; -} -// -// data types -// + return max_size; +} -static const char * LM_GGML_OP_NAME[LM_GGML_OP_COUNT] = { - "NONE", +//////////////////////////////////////////////////////////////////////////////// - "DUP", - "ADD", - "ADD1", - "ACC", - "SUB", - "MUL", - "DIV", - "SQR", - "SQRT", - "LOG", - "SIN", - "COS", - "SUM", - "SUM_ROWS", - "MEAN", - "ARGMAX", - "COUNT_EQUAL", - "REPEAT", - "REPEAT_BACK", - "CONCAT", - "SILU_BACK", - "NORM", - "RMS_NORM", - "RMS_NORM_BACK", - "GROUP_NORM", +static struct lm_ggml_object * lm_ggml_new_object(struct lm_ggml_context * ctx, enum lm_ggml_object_type type, size_t size) { + // always insert objects at the end of the context's memory pool + struct lm_ggml_object * obj_cur = ctx->objects_end; - "MUL_MAT", - "MUL_MAT_ID", - "OUT_PROD", + const size_t cur_offs = obj_cur == NULL ? 0 : obj_cur->offs; + const size_t cur_size = obj_cur == NULL ? 0 : obj_cur->size; + const size_t cur_end = cur_offs + cur_size; - "SCALE", - "SET", - "CPY", - "CONT", - "RESHAPE", - "VIEW", - "PERMUTE", - "TRANSPOSE", - "GET_ROWS", - "GET_ROWS_BACK", - "DIAG", - "DIAG_MASK_INF", - "DIAG_MASK_ZERO", - "SOFT_MAX", - "SOFT_MAX_BACK", - "ROPE", - "ROPE_BACK", - "CLAMP", - "CONV_TRANSPOSE_1D", - "IM2COL", - "IM2COL_BACK", - "CONV_TRANSPOSE_2D", - "POOL_1D", - "POOL_2D", - "POOL_2D_BACK", - "UPSCALE", - "PAD", - "ARANGE", - "TIMESTEP_EMBEDDING", - "ARGSORT", - "LEAKY_RELU", + // align to LM_GGML_MEM_ALIGN + size_t size_needed = LM_GGML_PAD(size, LM_GGML_MEM_ALIGN); - "FLASH_ATTN_EXT", - "FLASH_ATTN_BACK", - "SSM_CONV", - "SSM_SCAN", - "WIN_PART", - "WIN_UNPART", - "GET_REL_POS", - "ADD_REL_POS", - "RWKV_WKV", + char * const mem_buffer = ctx->mem_buffer; + struct lm_ggml_object * const obj_new = (struct lm_ggml_object *)(mem_buffer + cur_end); - "UNARY", + if (cur_end + size_needed + LM_GGML_OBJECT_SIZE > ctx->mem_size) { + LM_GGML_LOG_WARN("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n", + __func__, cur_end + size_needed + LM_GGML_OBJECT_SIZE, ctx->mem_size); +#ifndef NDEBUG + LM_GGML_ABORT("not enough space in the context's memory pool"); +#endif + return NULL; + } - "MAP_UNARY", - "MAP_BINARY", + *obj_new = (struct lm_ggml_object) { + .offs = cur_end + LM_GGML_OBJECT_SIZE, + .size = size_needed, + .next = NULL, + .type = type, + }; - "MAP_CUSTOM1_F32", - "MAP_CUSTOM2_F32", - "MAP_CUSTOM3_F32", + LM_GGML_ASSERT_ALIGNED(mem_buffer + obj_new->offs); - "MAP_CUSTOM1", - "MAP_CUSTOM2", - "MAP_CUSTOM3", + if (obj_cur != NULL) { + obj_cur->next = obj_new; + } else { + // this is the first object in this context + ctx->objects_begin = obj_new; + } - "CROSS_ENTROPY_LOSS", - "CROSS_ENTROPY_LOSS_BACK", - "OPT_STEP_ADAMW", -}; + ctx->objects_end = obj_new; -static_assert(LM_GGML_OP_COUNT == 81, "LM_GGML_OP_COUNT != 81"); + //printf("%s: inserted new object at %zu, size = %zu\n", __func__, cur_end, obj_new->size); -static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = { - "none", + return obj_new; +} - "x", - "x+y", - "x+y", - "view(x,nb,offset)+=y->x", - "x-y", - "x*y", - "x/y", - "x^2", - "√x", - "log(x)", - "sin(x)", - "cos(x)", - "Σx", - "Σx_k", - "Σx/n", - "argmax(x)", - "count_equal(x)", - "repeat(x)", - "repeat_back(x)", - "concat(x, y)", - "silu_back(x)", - "norm(x)", - "rms_norm(x)", - "rms_norm_back(x)", - "group_norm(x)", +static struct lm_ggml_tensor * lm_ggml_new_tensor_impl( + struct lm_ggml_context * ctx, + enum lm_ggml_type type, + int n_dims, + const int64_t * ne, + struct lm_ggml_tensor * view_src, + size_t view_offs) { - "X*Y", - "X[i]*Y", - "X*Y", + LM_GGML_ASSERT(type >= 0 && type < LM_GGML_TYPE_COUNT); + LM_GGML_ASSERT(n_dims >= 1 && n_dims <= LM_GGML_MAX_DIMS); - "x*v", - "y-\\>view(x)", - "x-\\>y", - "cont(x)", - "reshape(x)", - "view(x)", - "permute(x)", - "transpose(x)", - "get_rows(x)", - "get_rows_back(x)", - "diag(x)", - "diag_mask_inf(x)", - "diag_mask_zero(x)", - "soft_max(x)", - "soft_max_back(x)", - "rope(x)", - "rope_back(x)", - "clamp(x)", - "conv_transpose_1d(x)", - "im2col(x)", - "im2col_back(x)", - "conv_transpose_2d(x)", - "pool_1d(x)", - "pool_2d(x)", - "pool_2d_back(x)", - "upscale(x)", - "pad(x)", - "arange(start, stop, step)", - "timestep_embedding(timesteps, dim, max_period)", - "argsort(x)", - "leaky_relu(x)", + // find the base tensor and absolute offset + if (view_src != NULL && view_src->view_src != NULL) { + view_offs += view_src->view_offs; + view_src = view_src->view_src; + } - "flash_attn_ext(x)", - "flash_attn_back(x)", - "ssm_conv(x)", - "ssm_scan(x)", - "win_part(x)", - "win_unpart(x)", - "get_rel_pos(x)", - "add_rel_pos(x)", - "rwkv_wkv(k, v, r, tf, td, s)", + size_t data_size = lm_ggml_row_size(type, ne[0]); + for (int i = 1; i < n_dims; i++) { + data_size *= ne[i]; + } - "unary(x)", + LM_GGML_ASSERT(view_src == NULL || data_size == 0 || data_size + view_offs <= lm_ggml_nbytes(view_src)); - "f(x)", - "f(x,y)", + void * data = view_src != NULL ? view_src->data : NULL; + if (data != NULL) { + data = (char *) data + view_offs; + } - "custom_f32(x)", - "custom_f32(x,y)", - "custom_f32(x,y,z)", + size_t obj_alloc_size = 0; - "custom(x)", - "custom(x,y)", - "custom(x,y,z)", + if (view_src == NULL && !ctx->no_alloc) { + // allocate tensor data in the context's memory pool + obj_alloc_size = data_size; + } - "cross_entropy_loss(x,y)", - "cross_entropy_loss_back(x,y)", - "adamw(x)", -}; + struct lm_ggml_object * const obj_new = lm_ggml_new_object(ctx, LM_GGML_OBJECT_TYPE_TENSOR, LM_GGML_TENSOR_SIZE + obj_alloc_size); + LM_GGML_ASSERT(obj_new); -static_assert(LM_GGML_OP_COUNT == 81, "LM_GGML_OP_COUNT != 81"); + struct lm_ggml_tensor * const result = (struct lm_ggml_tensor *)((char *)ctx->mem_buffer + obj_new->offs); -static_assert(LM_GGML_OP_POOL_COUNT == 2, "LM_GGML_OP_POOL_COUNT != 2"); +#ifdef __clang__ + // temporary until lm_ggml_tensor::backend is removed + #pragma clang diagnostic push + #pragma clang diagnostic ignored "-Wdeprecated-declarations" +#endif + *result = (struct lm_ggml_tensor) { + /*.type =*/ type, + /*.backend =*/ LM_GGML_BACKEND_TYPE_CPU, + /*.buffer =*/ NULL, + /*.ne =*/ { 1, 1, 1, 1 }, + /*.nb =*/ { 0, 0, 0, 0 }, + /*.op =*/ LM_GGML_OP_NONE, + /*.op_params =*/ { 0 }, + /*.flags =*/ 0, + /*.grad =*/ NULL, + /*.src =*/ { NULL }, + /*.view_src =*/ view_src, + /*.view_offs =*/ view_offs, + /*.data =*/ obj_alloc_size > 0 ? (void *)(result + 1) : data, + /*.name =*/ { 0 }, + /*.extra =*/ NULL, + ///*.padding =*/ { 0 }, + }; -static const char * LM_GGML_UNARY_OP_NAME[LM_GGML_UNARY_OP_COUNT] = { - "ABS", - "SGN", - "NEG", - "STEP", - "TANH", - "ELU", - "RELU", - "SIGMOID", - "GELU", - "GELU_QUICK", - "SILU", - "HARDSWISH", - "HARDSIGMOID", - "EXP", -}; +#ifdef __clang__ + #pragma clang diagnostic pop +#endif -static_assert(LM_GGML_UNARY_OP_COUNT == 14, "LM_GGML_UNARY_OP_COUNT != 14"); + // TODO: this should not be needed as long as we don't rely on aligned SIMD loads + //LM_GGML_ASSERT_ALIGNED(result->data); + for (int i = 0; i < n_dims; i++) { + result->ne[i] = ne[i]; + } -static_assert(sizeof(struct lm_ggml_object)%LM_GGML_MEM_ALIGN == 0, "lm_ggml_object size must be a multiple of LM_GGML_MEM_ALIGN"); -static_assert(sizeof(struct lm_ggml_tensor)%LM_GGML_MEM_ALIGN == 0, "lm_ggml_tensor size must be a multiple of LM_GGML_MEM_ALIGN"); + result->nb[0] = lm_ggml_type_size(type); + result->nb[1] = result->nb[0]*(result->ne[0]/lm_ggml_blck_size(type)); + for (int i = 2; i < LM_GGML_MAX_DIMS; i++) { + result->nb[i] = result->nb[i - 1]*result->ne[i - 1]; + } -// Helpers for polling loops -#if defined(__aarch64__) && ( defined(__clang__) || defined(__GNUC__) ) -static inline void lm_ggml_thread_cpu_relax(void) { - __asm__ volatile("yield" ::: "memory"); -} -#elif defined(__x86_64__) -static inline void lm_ggml_thread_cpu_relax(void) { - _mm_pause(); -} -#else -static inline void lm_ggml_thread_cpu_relax(void) {;} -#endif + ctx->n_objects++; -// -// NUMA support -// + return result; +} -#define LM_GGML_NUMA_MAX_NODES 8 -#define LM_GGML_NUMA_MAX_CPUS 512 +struct lm_ggml_tensor * lm_ggml_new_tensor( + struct lm_ggml_context * ctx, + enum lm_ggml_type type, + int n_dims, + const int64_t * ne) { + return lm_ggml_new_tensor_impl(ctx, type, n_dims, ne, NULL, 0); +} -struct lm_ggml_numa_node { - uint32_t cpus[LM_GGML_NUMA_MAX_CPUS]; // hardware threads on this node - uint32_t n_cpus; -}; +struct lm_ggml_tensor * lm_ggml_new_tensor_1d( + struct lm_ggml_context * ctx, + enum lm_ggml_type type, + int64_t ne0) { + return lm_ggml_new_tensor(ctx, type, 1, &ne0); +} -struct lm_ggml_numa_nodes { - enum lm_ggml_numa_strategy numa_strategy; - struct lm_ggml_numa_node nodes[LM_GGML_NUMA_MAX_NODES]; - uint32_t n_nodes; - uint32_t total_cpus; // hardware threads on system - uint32_t current_node; // node on which main process is execting -#if defined(__gnu_linux__) - cpu_set_t cpuset; // cpuset from numactl -#else - uint32_t cpuset; // no NUMA support outside of Linux at this time. Use a portable datatype -#endif -}; +struct lm_ggml_tensor * lm_ggml_new_tensor_2d( + struct lm_ggml_context * ctx, + enum lm_ggml_type type, + int64_t ne0, + int64_t ne1) { + const int64_t ne[2] = { ne0, ne1 }; + return lm_ggml_new_tensor(ctx, type, 2, ne); +} -// -// ggml state -// +struct lm_ggml_tensor * lm_ggml_new_tensor_3d( + struct lm_ggml_context * ctx, + enum lm_ggml_type type, + int64_t ne0, + int64_t ne1, + int64_t ne2) { + const int64_t ne[3] = { ne0, ne1, ne2 }; + return lm_ggml_new_tensor(ctx, type, 3, ne); +} -struct lm_ggml_state { - struct lm_ggml_numa_nodes numa; -}; +struct lm_ggml_tensor * lm_ggml_new_tensor_4d( + struct lm_ggml_context * ctx, + enum lm_ggml_type type, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3) { + const int64_t ne[4] = { ne0, ne1, ne2, ne3 }; + return lm_ggml_new_tensor(ctx, type, 4, ne); +} -// global state -static struct lm_ggml_state g_state; -static atomic_flag g_state_critical = ATOMIC_FLAG_INIT; +void * lm_ggml_new_buffer(struct lm_ggml_context * ctx, size_t nbytes) { + struct lm_ggml_object * obj = lm_ggml_new_object(ctx, LM_GGML_OBJECT_TYPE_WORK_BUFFER, nbytes); -// critical section via spin lock -inline static void lm_ggml_critical_section_start(void) { - while (atomic_flag_test_and_set(&g_state_critical)) { - // spin - sched_yield(); - } + return (uint8_t *)ctx->mem_buffer + obj->offs; } -static void lm_ggml_barrier(struct lm_ggml_threadpool * tp) { - int n_threads = atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed); - if (n_threads == 1) { - return; - } - -#ifdef LM_GGML_USE_OPENMP - #pragma omp barrier -#else - int n_passed = atomic_load_explicit(&tp->n_barrier_passed, memory_order_relaxed); +struct lm_ggml_tensor * lm_ggml_dup_tensor(struct lm_ggml_context * ctx, const struct lm_ggml_tensor * src) { + return lm_ggml_new_tensor(ctx, src->type, LM_GGML_MAX_DIMS, src->ne); +} - // enter barrier (full seq-cst fence) - int n_barrier = atomic_fetch_add_explicit(&tp->n_barrier, 1, memory_order_seq_cst); +void lm_ggml_unravel_index(const struct lm_ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3) { + const int64_t ne2 = tensor->ne[2]; + const int64_t ne1 = tensor->ne[1]; + const int64_t ne0 = tensor->ne[0]; - if (n_barrier == (n_threads - 1)) { - // last thread - atomic_store_explicit(&tp->n_barrier, 0, memory_order_relaxed); + const int64_t i3_ = (i/(ne2*ne1*ne0)); + const int64_t i2_ = (i - i3_*ne2*ne1*ne0)/(ne1*ne0); + const int64_t i1_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0)/ne0; + const int64_t i0_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0 - i1_*ne0); - // exit barrier (fill seq-cst fence) - atomic_fetch_add_explicit(&tp->n_barrier_passed, 1, memory_order_seq_cst); - return; + if (i0) { + * i0 = i0_; } - - // wait for other threads - while (atomic_load_explicit(&tp->n_barrier_passed, memory_order_relaxed) == n_passed) { - lm_ggml_thread_cpu_relax(); + if (i1) { + * i1 = i1_; + } + if (i2) { + * i2 = i2_; + } + if (i3) { + * i3 = i3_; } - - // exit barrier (full seq-cst fence) - // TSAN doesn't support standalone fence yet, we use a dummy read-modify-write instead - #ifdef LM_GGML_TSAN_ENABLED - atomic_fetch_add_explicit(&tp->n_barrier_passed, 0, memory_order_seq_cst); - #else - atomic_thread_fence(memory_order_seq_cst); - #endif -#endif } -// TODO: make this somehow automatically executed -// some sort of "sentry" mechanism -inline static void lm_ggml_critical_section_end(void) { - atomic_flag_clear(&g_state_critical); +void * lm_ggml_get_data(const struct lm_ggml_tensor * tensor) { + return tensor->data; } -#if defined(__gnu_linux__) -static cpu_set_t lm_ggml_get_numa_affinity(void) { - cpu_set_t cpuset; - pthread_t thread; - thread = pthread_self(); - CPU_ZERO(&cpuset); - pthread_getaffinity_np(thread, sizeof(cpu_set_t), &cpuset); - return cpuset; +float * lm_ggml_get_data_f32(const struct lm_ggml_tensor * tensor) { + assert(tensor->type == LM_GGML_TYPE_F32); + return (float *)(tensor->data); } -#else -static uint32_t lm_ggml_get_numa_affinity(void) { - return 0; // no NUMA support + +enum lm_ggml_unary_op lm_ggml_get_unary_op(const struct lm_ggml_tensor * tensor) { + LM_GGML_ASSERT(tensor->op == LM_GGML_OP_UNARY); + return (enum lm_ggml_unary_op) lm_ggml_get_op_params_i32(tensor, 0); } -#endif -void lm_ggml_numa_init(enum lm_ggml_numa_strategy numa_flag) { - if (g_state.numa.n_nodes > 0) { - fprintf(stderr, "lm_ggml_numa_init: NUMA already initialized\n"); +const char * lm_ggml_get_name(const struct lm_ggml_tensor * tensor) { + return tensor->name; +} - return; +struct lm_ggml_tensor * lm_ggml_set_name(struct lm_ggml_tensor * tensor, const char * name) { + size_t i; + for (i = 0; i < sizeof(tensor->name) - 1 && name[i] != '\0'; i++) { + tensor->name[i] = name[i]; } + tensor->name[i] = '\0'; + return tensor; +} -#if defined(__gnu_linux__) - struct stat st; - char path[256]; - int rv; - - // set numa scheme - g_state.numa.numa_strategy = numa_flag; - - LM_GGML_PRINT_DEBUG("numa strategy %u\n",g_state.numa.numa_strategy); - - g_state.numa.cpuset = lm_ggml_get_numa_affinity(); +struct lm_ggml_tensor * lm_ggml_format_name(struct lm_ggml_tensor * tensor, const char * fmt, ...) { + va_list args; + va_start(args, fmt); + vsnprintf(tensor->name, sizeof(tensor->name), fmt, args); + va_end(args); + return tensor; +} - // enumerate nodes - while (g_state.numa.n_nodes < LM_GGML_NUMA_MAX_NODES) { - rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u", g_state.numa.n_nodes); - LM_GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path)); - if (stat(path, &st) != 0) { break; } - ++g_state.numa.n_nodes; - } +struct lm_ggml_tensor * lm_ggml_view_tensor( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * src) { + struct lm_ggml_tensor * result = lm_ggml_new_tensor_impl(ctx, src->type, LM_GGML_MAX_DIMS, src->ne, src, 0); + lm_ggml_format_name(result, "%s (view)", src->name); - // enumerate CPUs - while (g_state.numa.total_cpus < LM_GGML_NUMA_MAX_CPUS) { - rv = snprintf(path, sizeof(path), "/sys/devices/system/cpu/cpu%u", g_state.numa.total_cpus); - LM_GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path)); - if (stat(path, &st) != 0) { break; } - ++g_state.numa.total_cpus; + for (int i = 0; i < LM_GGML_MAX_DIMS; i++) { + result->nb[i] = src->nb[i]; } - LM_GGML_PRINT_DEBUG("found %u numa nodes, %u CPUs\n", g_state.numa.n_nodes, g_state.numa.total_cpus); + return result; +} - // figure out which node we're on - uint current_cpu; - int getcpu_ret = 0; -#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 28) || defined(__COSMOPOLITAN__) - getcpu_ret = getcpu(¤t_cpu, &g_state.numa.current_node); -#else - // old glibc doesn't have a wrapper for this call. Fall back on direct syscall -# if !defined(SYS_getcpu) && defined(SYS_get_cpu) -# define SYS_getcpu SYS_get_cpu // some older glibc versions use this name -# endif - getcpu_ret = syscall(SYS_getcpu, ¤t_cpu, &g_state.numa.current_node); -#endif +struct lm_ggml_tensor * lm_ggml_get_first_tensor(const struct lm_ggml_context * ctx) { + struct lm_ggml_object * obj = ctx->objects_begin; - if (g_state.numa.n_nodes < 1 || g_state.numa.total_cpus < 1 || getcpu_ret != 0) { - g_state.numa.n_nodes = 0; - return; - } + char * const mem_buffer = ctx->mem_buffer; - LM_GGML_PRINT_DEBUG("found our process on numa node %u, CPU %u\n", g_state.numa.current_node, current_cpu); - - for (uint32_t n = 0; n < g_state.numa.n_nodes; ++n) { - struct lm_ggml_numa_node * node = &g_state.numa.nodes[n]; - LM_GGML_PRINT_DEBUG("CPUs on node %u:", n); - node->n_cpus = 0; - for (uint32_t c = 0; c < g_state.numa.total_cpus; ++c) { - rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u/cpu%u", n, c); - LM_GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path)); - if (stat(path, &st) == 0) { - node->cpus[node->n_cpus++] = c; - LM_GGML_PRINT_DEBUG(" %u", c); - } + while (obj != NULL) { + if (obj->type == LM_GGML_OBJECT_TYPE_TENSOR) { + return (struct lm_ggml_tensor *)(mem_buffer + obj->offs); } - LM_GGML_PRINT_DEBUG("\n"); - } - if (lm_ggml_is_numa()) { - FILE *fptr = fopen("/proc/sys/kernel/numa_balancing", "r"); - if (fptr != NULL) { - char buf[42]; - if (fgets(buf, sizeof(buf), fptr) && strncmp(buf, "0\n", sizeof(buf)) != 0) { - LM_GGML_LOG_WARN("/proc/sys/kernel/numa_balancing is enabled, this has been observed to impair performance\n"); - } - fclose(fptr); - } + obj = obj->next; } -#else - UNUSED(numa_flag); - // TODO -#endif -} -bool lm_ggml_is_numa(void) { - return g_state.numa.n_nodes > 1; + return NULL; } -//////////////////////////////////////////////////////////////////////////////// +struct lm_ggml_tensor * lm_ggml_get_next_tensor(const struct lm_ggml_context * ctx, struct lm_ggml_tensor * tensor) { + struct lm_ggml_object * obj = (struct lm_ggml_object *) ((char *)tensor - LM_GGML_OBJECT_SIZE); + obj = obj->next; -void lm_ggml_print_object(const struct lm_ggml_object * obj) { - LM_GGML_LOG_INFO(" - lm_ggml_object: type = %d, offset = %zu, size = %zu, next = %p\n", - obj->type, obj->offs, obj->size, (const void *) obj->next); + char * const mem_buffer = ctx->mem_buffer; + + while (obj != NULL) { + if (obj->type == LM_GGML_OBJECT_TYPE_TENSOR) { + return (struct lm_ggml_tensor *)(mem_buffer + obj->offs); + } + + obj = obj->next; + } + + return NULL; } -void lm_ggml_print_objects(const struct lm_ggml_context * ctx) { +struct lm_ggml_tensor * lm_ggml_get_tensor(struct lm_ggml_context * ctx, const char * name) { struct lm_ggml_object * obj = ctx->objects_begin; - LM_GGML_LOG_INFO("%s: objects in context %p:\n", __func__, (const void *) ctx); + char * const mem_buffer = ctx->mem_buffer; while (obj != NULL) { - lm_ggml_print_object(obj); + if (obj->type == LM_GGML_OBJECT_TYPE_TENSOR) { + struct lm_ggml_tensor * cur = (struct lm_ggml_tensor *)(mem_buffer + obj->offs); + if (strcmp(cur->name, name) == 0) { + return cur; + } + } + obj = obj->next; } - LM_GGML_LOG_INFO("%s: --- end ---\n", __func__); + return NULL; } -int64_t lm_ggml_nelements(const struct lm_ggml_tensor * tensor) { - static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); - - return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3]; -} +//////////////////////////////////////////////////////////////////////////////// -int64_t lm_ggml_nrows(const struct lm_ggml_tensor * tensor) { - static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); +// lm_ggml_dup - return tensor->ne[1]*tensor->ne[2]*tensor->ne[3]; -} +static struct lm_ggml_tensor * lm_ggml_dup_impl( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + bool inplace) { + struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); -size_t lm_ggml_nbytes(const struct lm_ggml_tensor * tensor) { - size_t nbytes; - const size_t blck_size = lm_ggml_blck_size(tensor->type); - if (blck_size == 1) { - nbytes = lm_ggml_type_size(tensor->type); - for (int i = 0; i < LM_GGML_MAX_DIMS; ++i) { - nbytes += (tensor->ne[i] - 1)*tensor->nb[i]; - } - } - else { - nbytes = tensor->ne[0]*tensor->nb[0]/blck_size; - for (int i = 1; i < LM_GGML_MAX_DIMS; ++i) { - nbytes += (tensor->ne[i] - 1)*tensor->nb[i]; - } - } + result->op = LM_GGML_OP_DUP; + result->src[0] = a; - return nbytes; + return result; } -size_t lm_ggml_nbytes_pad(const struct lm_ggml_tensor * tensor) { - return LM_GGML_PAD(lm_ggml_nbytes(tensor), LM_GGML_MEM_ALIGN); +struct lm_ggml_tensor * lm_ggml_dup( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_dup_impl(ctx, a, false); } -int64_t lm_ggml_blck_size(enum lm_ggml_type type) { - return type_traits[type].blck_size; -} - -size_t lm_ggml_type_size(enum lm_ggml_type type) { - return type_traits[type].type_size; +struct lm_ggml_tensor * lm_ggml_dup_inplace( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_dup_impl(ctx, a, true); } -size_t lm_ggml_row_size(enum lm_ggml_type type, int64_t ne) { - assert(ne % lm_ggml_blck_size(type) == 0); - return lm_ggml_type_size(type)*ne/lm_ggml_blck_size(type); -} +// lm_ggml_add -double lm_ggml_type_sizef(enum lm_ggml_type type) { - return ((double)(type_traits[type].type_size))/type_traits[type].blck_size; -} +static struct lm_ggml_tensor * lm_ggml_add_impl( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + bool inplace) { + LM_GGML_ASSERT(lm_ggml_can_repeat(b, a)); -const char * lm_ggml_type_name(enum lm_ggml_type type) { - return type < LM_GGML_TYPE_COUNT ? type_traits[type].type_name : "NONE"; -} + struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); -bool lm_ggml_is_quantized(enum lm_ggml_type type) { - return type_traits[type].is_quantized; -} + result->op = LM_GGML_OP_ADD; + result->src[0] = a; + result->src[1] = b; -const char * lm_ggml_op_name(enum lm_ggml_op op) { - return LM_GGML_OP_NAME[op]; + return result; } -const char * lm_ggml_op_symbol(enum lm_ggml_op op) { - return LM_GGML_OP_SYMBOL[op]; +struct lm_ggml_tensor * lm_ggml_add( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b) { + return lm_ggml_add_impl(ctx, a, b, false); } -const char * lm_ggml_unary_op_name(enum lm_ggml_unary_op op) { - return LM_GGML_UNARY_OP_NAME[op]; +struct lm_ggml_tensor * lm_ggml_add_inplace( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b) { + return lm_ggml_add_impl(ctx, a, b, true); } -const char * lm_ggml_op_desc(const struct lm_ggml_tensor * t) { - if (t->op == LM_GGML_OP_UNARY) { - enum lm_ggml_unary_op uop = lm_ggml_get_unary_op(t); - return lm_ggml_unary_op_name(uop); - } - return lm_ggml_op_name(t->op); -} +// lm_ggml_add_cast -size_t lm_ggml_element_size(const struct lm_ggml_tensor * tensor) { - return lm_ggml_type_size(tensor->type); -} +static struct lm_ggml_tensor * lm_ggml_add_cast_impl( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + enum lm_ggml_type type) { + // TODO: support less-strict constraint + // LM_GGML_ASSERT(lm_ggml_can_repeat(b, a)); + LM_GGML_ASSERT(lm_ggml_can_repeat_rows(b, a)); -bool lm_ggml_is_scalar(const struct lm_ggml_tensor * tensor) { - static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); + // currently only supported for quantized input and f16 + LM_GGML_ASSERT(lm_ggml_is_quantized(a->type) || + a->type == LM_GGML_TYPE_F16 || + a->type == LM_GGML_TYPE_BF16); - return tensor->ne[0] == 1 && tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1; -} + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, type, LM_GGML_MAX_DIMS, a->ne); -bool lm_ggml_is_vector(const struct lm_ggml_tensor * tensor) { - static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); + result->op = LM_GGML_OP_ADD; + result->src[0] = a; + result->src[1] = b; - return tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1; + return result; } -bool lm_ggml_is_matrix(const struct lm_ggml_tensor * tensor) { - static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); - - return tensor->ne[2] == 1 && tensor->ne[3] == 1; +struct lm_ggml_tensor * lm_ggml_add_cast( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + enum lm_ggml_type type) { + return lm_ggml_add_cast_impl(ctx, a, b, type); } -bool lm_ggml_is_3d(const struct lm_ggml_tensor * tensor) { - return tensor->ne[3] == 1; -} +// lm_ggml_add1 -int lm_ggml_n_dims(const struct lm_ggml_tensor * tensor) { - for (int i = LM_GGML_MAX_DIMS - 1; i >= 1; --i) { - if (tensor->ne[i] > 1) { - return i + 1; - } - } - return 1; -} +static struct lm_ggml_tensor * lm_ggml_add1_impl( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + bool inplace) { + LM_GGML_ASSERT(lm_ggml_is_scalar(b)); + LM_GGML_ASSERT(lm_ggml_is_padded_1d(a)); -static inline bool lm_ggml_can_mul_mat(const struct lm_ggml_tensor * t0, const struct lm_ggml_tensor * t1) { - static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); + struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - return (t0->ne[0] == t1->ne[0]) && - (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable - (t1->ne[3]%t0->ne[3] == 0); + result->op = LM_GGML_OP_ADD1; + result->src[0] = a; + result->src[1] = b; + + return result; } -static inline bool lm_ggml_can_out_prod(const struct lm_ggml_tensor * t0, const struct lm_ggml_tensor * t1) { - static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); +struct lm_ggml_tensor * lm_ggml_add1( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b) { + return lm_ggml_add1_impl(ctx, a, b, false); +} - return (t0->ne[1] == t1->ne[1]) && - (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable - (t1->ne[3]%t0->ne[3] == 0); +struct lm_ggml_tensor * lm_ggml_add1_inplace( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b) { + return lm_ggml_add1_impl(ctx, a, b, true); } -enum lm_ggml_type lm_ggml_ftype_to_lm_ggml_type(enum lm_ggml_ftype ftype) { - enum lm_ggml_type wtype = LM_GGML_TYPE_COUNT; +// lm_ggml_acc - switch (ftype) { - case LM_GGML_FTYPE_ALL_F32: wtype = LM_GGML_TYPE_F32; break; - case LM_GGML_FTYPE_MOSTLY_F16: wtype = LM_GGML_TYPE_F16; break; - case LM_GGML_FTYPE_MOSTLY_BF16: wtype = LM_GGML_TYPE_BF16; break; - case LM_GGML_FTYPE_MOSTLY_Q4_0: wtype = LM_GGML_TYPE_Q4_0; break; - case LM_GGML_FTYPE_MOSTLY_Q4_1: wtype = LM_GGML_TYPE_Q4_1; break; - case LM_GGML_FTYPE_MOSTLY_Q5_0: wtype = LM_GGML_TYPE_Q5_0; break; - case LM_GGML_FTYPE_MOSTLY_Q5_1: wtype = LM_GGML_TYPE_Q5_1; break; - case LM_GGML_FTYPE_MOSTLY_Q8_0: wtype = LM_GGML_TYPE_Q8_0; break; - case LM_GGML_FTYPE_MOSTLY_Q2_K: wtype = LM_GGML_TYPE_Q2_K; break; - case LM_GGML_FTYPE_MOSTLY_Q3_K: wtype = LM_GGML_TYPE_Q3_K; break; - case LM_GGML_FTYPE_MOSTLY_Q4_K: wtype = LM_GGML_TYPE_Q4_K; break; - case LM_GGML_FTYPE_MOSTLY_Q5_K: wtype = LM_GGML_TYPE_Q5_K; break; - case LM_GGML_FTYPE_MOSTLY_Q6_K: wtype = LM_GGML_TYPE_Q6_K; break; - case LM_GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = LM_GGML_TYPE_IQ2_XXS; break; - case LM_GGML_FTYPE_MOSTLY_IQ2_XS: wtype = LM_GGML_TYPE_IQ2_XS; break; - case LM_GGML_FTYPE_MOSTLY_IQ3_XXS: wtype = LM_GGML_TYPE_IQ3_XXS; break; - case LM_GGML_FTYPE_MOSTLY_IQ1_S: wtype = LM_GGML_TYPE_IQ1_S; break; - case LM_GGML_FTYPE_MOSTLY_IQ1_M: wtype = LM_GGML_TYPE_IQ1_M; break; - case LM_GGML_FTYPE_MOSTLY_IQ4_NL: wtype = LM_GGML_TYPE_IQ4_NL; break; - case LM_GGML_FTYPE_MOSTLY_IQ4_XS: wtype = LM_GGML_TYPE_IQ4_XS; break; - case LM_GGML_FTYPE_MOSTLY_IQ3_S: wtype = LM_GGML_TYPE_IQ3_S; break; - case LM_GGML_FTYPE_MOSTLY_IQ2_S: wtype = LM_GGML_TYPE_IQ2_S; break; - case LM_GGML_FTYPE_MOSTLY_Q4_0_4_4: wtype = LM_GGML_TYPE_Q4_0_4_4; break; - case LM_GGML_FTYPE_MOSTLY_Q4_0_4_8: wtype = LM_GGML_TYPE_Q4_0_4_8; break; - case LM_GGML_FTYPE_MOSTLY_Q4_0_8_8: wtype = LM_GGML_TYPE_Q4_0_8_8; break; - case LM_GGML_FTYPE_UNKNOWN: wtype = LM_GGML_TYPE_COUNT; break; - case LM_GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = LM_GGML_TYPE_COUNT; break; - } +static struct lm_ggml_tensor * lm_ggml_acc_impl( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset, + bool inplace) { + LM_GGML_ASSERT(lm_ggml_nelements(b) <= lm_ggml_nelements(a)); + LM_GGML_ASSERT(lm_ggml_is_contiguous(a)); + LM_GGML_ASSERT(a->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT(b->type == LM_GGML_TYPE_F32); - LM_GGML_ASSERT(wtype != LM_GGML_TYPE_COUNT); + struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - return wtype; -} + int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 }; + lm_ggml_set_op_params(result, params, sizeof(params)); -size_t lm_ggml_tensor_overhead(void) { - return LM_GGML_OBJECT_SIZE + LM_GGML_TENSOR_SIZE; -} + result->op = LM_GGML_OP_ACC; + result->src[0] = a; + result->src[1] = b; -bool lm_ggml_is_transposed(const struct lm_ggml_tensor * tensor) { - return tensor->nb[0] > tensor->nb[1]; + return result; } -static bool lm_ggml_is_contiguous_n(const struct lm_ggml_tensor * tensor, int n) { - size_t next_nb = lm_ggml_type_size(tensor->type); - if (tensor->ne[0] != lm_ggml_blck_size(tensor->type) && tensor->nb[0] != next_nb) { - return false; - } - next_nb *= tensor->ne[0]/lm_ggml_blck_size(tensor->type); - for (int i = 1; i < LM_GGML_MAX_DIMS; i++) { - if (tensor->ne[i] != 1) { - if (i > n) { - if (tensor->nb[i] != next_nb) { - return false; - } - next_nb *= tensor->ne[i]; - } else { - // this dimension does not need to be contiguous - next_nb = tensor->ne[i]*tensor->nb[i]; - } - } - } - return true; +struct lm_ggml_tensor * lm_ggml_acc( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset) { + return lm_ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false); } -bool lm_ggml_is_contiguous(const struct lm_ggml_tensor * tensor) { - return lm_ggml_is_contiguous_0(tensor); +struct lm_ggml_tensor * lm_ggml_acc_inplace( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset) { + return lm_ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true); } -bool lm_ggml_is_contiguous_0(const struct lm_ggml_tensor * tensor) { - return lm_ggml_is_contiguous_n(tensor, 0); -} +// lm_ggml_sub -bool lm_ggml_is_contiguous_1(const struct lm_ggml_tensor * tensor) { - return lm_ggml_is_contiguous_n(tensor, 1); -} +static struct lm_ggml_tensor * lm_ggml_sub_impl( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + bool inplace) { + LM_GGML_ASSERT(lm_ggml_can_repeat(b, a)); -bool lm_ggml_is_contiguous_2(const struct lm_ggml_tensor * tensor) { - return lm_ggml_is_contiguous_n(tensor, 2); -} + struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); -bool lm_ggml_is_permuted(const struct lm_ggml_tensor * tensor) { - static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); + result->op = LM_GGML_OP_SUB; + result->src[0] = a; + result->src[1] = b; - return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3]; + return result; } -static inline bool lm_ggml_is_padded_1d(const struct lm_ggml_tensor * tensor) { - static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); - - return - tensor->nb[0] == lm_ggml_type_size(tensor->type) && - tensor->nb[2] == tensor->nb[1]*tensor->ne[1] && - tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; +struct lm_ggml_tensor * lm_ggml_sub( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b) { + return lm_ggml_sub_impl(ctx, a, b, false); } -bool lm_ggml_is_empty(const struct lm_ggml_tensor * tensor) { - for (int i = 0; i < LM_GGML_MAX_DIMS; ++i) { - if (tensor->ne[i] == 0) { - // empty if any dimension has no elements - return true; - } - } - return false; +struct lm_ggml_tensor * lm_ggml_sub_inplace( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b) { + return lm_ggml_sub_impl(ctx, a, b, true); } -bool lm_ggml_are_same_shape(const struct lm_ggml_tensor * t0, const struct lm_ggml_tensor * t1) { - static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); +// lm_ggml_mul - return - (t0->ne[0] == t1->ne[0]) && - (t0->ne[1] == t1->ne[1]) && - (t0->ne[2] == t1->ne[2]) && - (t0->ne[3] == t1->ne[3]); -} +static struct lm_ggml_tensor * lm_ggml_mul_impl( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + bool inplace) { + LM_GGML_ASSERT(lm_ggml_can_repeat(b, a)); -bool lm_ggml_are_same_stride(const struct lm_ggml_tensor * t0, const struct lm_ggml_tensor * t1) { - static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); + struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - return - (t0->nb[0] == t1->nb[0]) && - (t0->nb[1] == t1->nb[1]) && - (t0->nb[2] == t1->nb[2]) && - (t0->nb[3] == t1->nb[3]); + result->op = LM_GGML_OP_MUL; + result->src[0] = a; + result->src[1] = b; + + return result; } -// check if t1 can be represented as a repeatition of t0 -bool lm_ggml_can_repeat(const struct lm_ggml_tensor * t0, const struct lm_ggml_tensor * t1) { - static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); +struct lm_ggml_tensor * lm_ggml_mul( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b) { + return lm_ggml_mul_impl(ctx, a, b, false); +} - return lm_ggml_is_empty(t0) ? lm_ggml_is_empty(t1) : - (t1->ne[0]%t0->ne[0] == 0) && - (t1->ne[1]%t0->ne[1] == 0) && - (t1->ne[2]%t0->ne[2] == 0) && - (t1->ne[3]%t0->ne[3] == 0); +struct lm_ggml_tensor * lm_ggml_mul_inplace( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b) { + return lm_ggml_mul_impl(ctx, a, b, true); } -static inline bool lm_ggml_can_repeat_rows(const struct lm_ggml_tensor * t0, const struct lm_ggml_tensor * t1) { - static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); +// lm_ggml_div - return (t0->ne[0] == t1->ne[0]) && lm_ggml_can_repeat(t0, t1); -} +static struct lm_ggml_tensor * lm_ggml_div_impl( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + bool inplace) { + LM_GGML_ASSERT(lm_ggml_can_repeat(b, a)); -static inline int lm_ggml_up32(int n) { - return (n + 31) & ~31; -} + struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); -//static inline int lm_ggml_up64(int n) { -// return (n + 63) & ~63; -//} + result->op = LM_GGML_OP_DIV; + result->src[0] = a; + result->src[1] = b; -static inline int lm_ggml_up(int n, int m) { - // assert m is a power of 2 - LM_GGML_ASSERT((m & (m - 1)) == 0); - return (n + m - 1) & ~(m - 1); + return result; } -// assert that pointer is aligned to LM_GGML_MEM_ALIGN -#define LM_GGML_ASSERT_ALIGNED(ptr) \ - LM_GGML_ASSERT(((uintptr_t) (ptr))%LM_GGML_MEM_ALIGN == 0) +struct lm_ggml_tensor * lm_ggml_div( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b) { + return lm_ggml_div_impl(ctx, a, b, false); +} -//////////////////////////////////////////////////////////////////////////////// +struct lm_ggml_tensor * lm_ggml_div_inplace( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b) { + return lm_ggml_div_impl(ctx, a, b, true); +} -#if defined(__ARM_ARCH) +// lm_ggml_sqr -#if defined(__linux__) && defined(__aarch64__) -#include -#elif defined(__APPLE__) -#include -#endif +static struct lm_ggml_tensor * lm_ggml_sqr_impl( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + bool inplace) { + struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); -#if !defined(HWCAP2_I8MM) -#define HWCAP2_I8MM 0 -#endif + result->op = LM_GGML_OP_SQR; + result->src[0] = a; -static void lm_ggml_init_arm_arch_features(void) { -#if defined(__linux__) && defined(__aarch64__) - uint32_t hwcap = getauxval(AT_HWCAP); - uint32_t hwcap2 = getauxval(AT_HWCAP2); + return result; +} - lm_ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD); - lm_ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM); - lm_ggml_arm_arch_features.has_sve = !!(hwcap & HWCAP_SVE); +struct lm_ggml_tensor * lm_ggml_sqr( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_sqr_impl(ctx, a, false); +} -#if defined(__ARM_FEATURE_SVE) - lm_ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL); -#endif -#elif defined(__APPLE__) - int oldp = 0; - size_t size = sizeof(oldp); - if (sysctlbyname("hw.optional.AdvSIMD", &oldp, &size, NULL, 0) != 0) { - oldp = 0; - } - lm_ggml_arm_arch_features.has_neon = oldp; +struct lm_ggml_tensor * lm_ggml_sqr_inplace( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_sqr_impl(ctx, a, true); +} - if (sysctlbyname("hw.optional.arm.FEAT_I8MM", &oldp, &size, NULL, 0) != 0) { - oldp = 0; - } - lm_ggml_arm_arch_features.has_i8mm = oldp; +// lm_ggml_sqrt - lm_ggml_arm_arch_features.has_sve = 0; - lm_ggml_arm_arch_features.sve_cnt = 0; -#else -// Run-time CPU feature detection not implemented for this platform, fallback to compile time -#if defined(__ARM_NEON) - lm_ggml_arm_arch_features.has_neon = 1; -#else - lm_ggml_arm_arch_features.has_neon = 0; -#endif +static struct lm_ggml_tensor * lm_ggml_sqrt_impl( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + bool inplace) { + struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); -#if defined(__ARM_FEATURE_MATMUL_INT8) - lm_ggml_arm_arch_features.has_i8mm = 1; -#else - lm_ggml_arm_arch_features.has_i8mm = 0; -#endif + result->op = LM_GGML_OP_SQRT; + result->src[0] = a; -#if defined(__ARM_FEATURE_SVE) - lm_ggml_arm_arch_features.has_sve = 1; - lm_ggml_arm_arch_features.sve_cnt = 16; -#else - lm_ggml_arm_arch_features.has_sve = 0; - lm_ggml_arm_arch_features.sve_cnt = 0; -#endif -#endif + return result; } -#endif - -struct lm_ggml_context * lm_ggml_init(struct lm_ggml_init_params params) { - // make this function thread safe - lm_ggml_critical_section_start(); - - static bool is_first_call = true; - if (is_first_call) { - // initialize time system (required on Windows) - lm_ggml_time_init(); +struct lm_ggml_tensor * lm_ggml_sqrt( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_sqrt_impl(ctx, a, false); +} - // initialize GELU, Quick GELU, SILU and EXP F32 tables - { - const uint64_t t_start = lm_ggml_time_us(); UNUSED(t_start); - - for (int i = 0; i < (1 << 16); ++i) { - union { - uint16_t u16; - lm_ggml_fp16_t fp16; - } u = {i}; - float f = lm_ggml_table_f32_f16[i] = LM_GGML_COMPUTE_FP16_TO_FP32(u.fp16); - lm_ggml_table_gelu_f16[i] = LM_GGML_FP32_TO_FP16(lm_ggml_gelu_f32(f)); - lm_ggml_table_gelu_quick_f16[i] = LM_GGML_FP32_TO_FP16(lm_ggml_gelu_quick_f32(f)); - } +struct lm_ggml_tensor * lm_ggml_sqrt_inplace( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_sqrt_impl(ctx, a, true); +} - const uint64_t t_end = lm_ggml_time_us(); UNUSED(t_end); +// lm_ggml_log - LM_GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); - } +static struct lm_ggml_tensor * lm_ggml_log_impl( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + bool inplace) { + struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - // initialize g_state - { - const uint64_t t_start = lm_ggml_time_us(); UNUSED(t_start); + result->op = LM_GGML_OP_LOG; + result->src[0] = a; - g_state = (struct lm_ggml_state) { - /*.numa =*/ { - .n_nodes = 0, - .total_cpus = 0, - }, - }; + return result; +} - const uint64_t t_end = lm_ggml_time_us(); UNUSED(t_end); +struct lm_ggml_tensor * lm_ggml_log( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_log_impl(ctx, a, false); +} - LM_GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); - } +struct lm_ggml_tensor * lm_ggml_log_inplace( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_log_impl(ctx, a, true); +} -#if defined(__ARM_ARCH) - lm_ggml_init_arm_arch_features(); -#endif +// lm_ggml_sin - is_first_call = false; - } +static struct lm_ggml_tensor * lm_ggml_sin_impl( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + bool inplace) { + struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - lm_ggml_critical_section_end(); + result->op = LM_GGML_OP_SIN; + result->src[0] = a; - struct lm_ggml_context * ctx = LM_GGML_MALLOC(sizeof(struct lm_ggml_context)); + return result; +} - // allow to call lm_ggml_init with 0 size - if (params.mem_size == 0) { - params.mem_size = LM_GGML_MEM_ALIGN; - } +struct lm_ggml_tensor * lm_ggml_sin( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_sin_impl(ctx, a, false); +} - const size_t mem_size = params.mem_buffer ? params.mem_size : LM_GGML_PAD(params.mem_size, LM_GGML_MEM_ALIGN); +struct lm_ggml_tensor * lm_ggml_sin_inplace( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_sin_impl(ctx, a, true); +} - *ctx = (struct lm_ggml_context) { - /*.mem_size =*/ mem_size, - /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : lm_ggml_aligned_malloc(mem_size), - /*.mem_buffer_owned =*/ params.mem_buffer ? false : true, - /*.no_alloc =*/ params.no_alloc, - /*.n_objects =*/ 0, - /*.objects_begin =*/ NULL, - /*.objects_end =*/ NULL, - }; +// lm_ggml_cos - LM_GGML_ASSERT(ctx->mem_buffer != NULL); +static struct lm_ggml_tensor * lm_ggml_cos_impl( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + bool inplace) { + struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - LM_GGML_ASSERT_ALIGNED(ctx->mem_buffer); + result->op = LM_GGML_OP_COS; + result->src[0] = a; - LM_GGML_PRINT_DEBUG("%s: context initialized\n", __func__); + return result; +} - return ctx; +struct lm_ggml_tensor * lm_ggml_cos( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_cos_impl(ctx, a, false); } -void lm_ggml_reset(struct lm_ggml_context * ctx) { - if (ctx == NULL) { - return; - } - - ctx->n_objects = 0; - ctx->objects_begin = NULL; - ctx->objects_end = NULL; +struct lm_ggml_tensor * lm_ggml_cos_inplace( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_cos_impl(ctx, a, true); } -void lm_ggml_free(struct lm_ggml_context * ctx) { - if (ctx == NULL) { - return; - } - - if (ctx->mem_buffer_owned) { - lm_ggml_aligned_free(ctx->mem_buffer, ctx->mem_size); - } - - LM_GGML_FREE(ctx); -} +// lm_ggml_sum -size_t lm_ggml_used_mem(const struct lm_ggml_context * ctx) { - return ctx->objects_end == NULL ? 0 : ctx->objects_end->offs + ctx->objects_end->size; -} +struct lm_ggml_tensor * lm_ggml_sum( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + struct lm_ggml_tensor * result = lm_ggml_new_tensor_1d(ctx, a->type, 1); -bool lm_ggml_get_no_alloc(struct lm_ggml_context * ctx) { - return ctx->no_alloc; -} + result->op = LM_GGML_OP_SUM; + result->src[0] = a; -void lm_ggml_set_no_alloc(struct lm_ggml_context * ctx, bool no_alloc) { - ctx->no_alloc = no_alloc; + return result; } -void * lm_ggml_get_mem_buffer(const struct lm_ggml_context * ctx) { - return ctx->mem_buffer; -} +// lm_ggml_sum_rows -size_t lm_ggml_get_mem_size(const struct lm_ggml_context * ctx) { - return ctx->mem_size; -} +struct lm_ggml_tensor * lm_ggml_sum_rows( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + int64_t ne[LM_GGML_MAX_DIMS] = { 1 }; + for (int i = 1; i < LM_GGML_MAX_DIMS; ++i) { + ne[i] = a->ne[i]; + } -size_t lm_ggml_get_max_tensor_size(const struct lm_ggml_context * ctx) { - size_t max_size = 0; + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, a->type, LM_GGML_MAX_DIMS, ne); - for (struct lm_ggml_tensor * tensor = lm_ggml_get_first_tensor(ctx); tensor != NULL; tensor = lm_ggml_get_next_tensor(ctx, tensor)) { - size_t bytes = lm_ggml_nbytes(tensor); - max_size = MAX(max_size, bytes); - } + result->op = LM_GGML_OP_SUM_ROWS; + result->src[0] = a; - return max_size; + return result; } -//////////////////////////////////////////////////////////////////////////////// +// lm_ggml_mean -static struct lm_ggml_object * lm_ggml_new_object(struct lm_ggml_context * ctx, enum lm_ggml_object_type type, size_t size) { - // always insert objects at the end of the context's memory pool - struct lm_ggml_object * obj_cur = ctx->objects_end; +struct lm_ggml_tensor * lm_ggml_mean( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + int64_t ne[4] = { 1, a->ne[1], a->ne[2], a->ne[3] }; + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); - const size_t cur_offs = obj_cur == NULL ? 0 : obj_cur->offs; - const size_t cur_size = obj_cur == NULL ? 0 : obj_cur->size; - const size_t cur_end = cur_offs + cur_size; + result->op = LM_GGML_OP_MEAN; + result->src[0] = a; - // align to LM_GGML_MEM_ALIGN - size_t size_needed = LM_GGML_PAD(size, LM_GGML_MEM_ALIGN); + return result; +} - char * const mem_buffer = ctx->mem_buffer; - struct lm_ggml_object * const obj_new = (struct lm_ggml_object *)(mem_buffer + cur_end); +// lm_ggml_argmax - if (cur_end + size_needed + LM_GGML_OBJECT_SIZE > ctx->mem_size) { - LM_GGML_LOG_WARN("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n", - __func__, cur_end + size_needed + LM_GGML_OBJECT_SIZE, ctx->mem_size); -#ifndef NDEBUG - LM_GGML_ABORT("not enough space in the context's memory pool"); -#endif - return NULL; - } +struct lm_ggml_tensor * lm_ggml_argmax( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + LM_GGML_ASSERT(lm_ggml_is_matrix(a)); - *obj_new = (struct lm_ggml_object) { - .offs = cur_end + LM_GGML_OBJECT_SIZE, - .size = size_needed, - .next = NULL, - .type = type, - }; + struct lm_ggml_tensor * result = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_I32, a->ne[1]); - LM_GGML_ASSERT_ALIGNED(mem_buffer + obj_new->offs); + result->op = LM_GGML_OP_ARGMAX; + result->src[0] = a; - if (obj_cur != NULL) { - obj_cur->next = obj_new; - } else { - // this is the first object in this context - ctx->objects_begin = obj_new; - } + return result; +} - ctx->objects_end = obj_new; +// lm_ggml_count_equal - //printf("%s: inserted new object at %zu, size = %zu\n", __func__, cur_end, obj_new->size); +struct lm_ggml_tensor * lm_ggml_count_equal( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b) { + LM_GGML_ASSERT(lm_ggml_are_same_shape(a, b)); - return obj_new; -} + struct lm_ggml_tensor * result = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_I64, 1); -static struct lm_ggml_tensor * lm_ggml_new_tensor_impl( - struct lm_ggml_context * ctx, - enum lm_ggml_type type, - int n_dims, - const int64_t * ne, - struct lm_ggml_tensor * view_src, - size_t view_offs) { + result->op = LM_GGML_OP_COUNT_EQUAL; + result->src[0] = a; + result->src[1] = b; - LM_GGML_ASSERT(type >= 0 && type < LM_GGML_TYPE_COUNT); - LM_GGML_ASSERT(n_dims >= 1 && n_dims <= LM_GGML_MAX_DIMS); + return result; +} - // find the base tensor and absolute offset - if (view_src != NULL && view_src->view_src != NULL) { - view_offs += view_src->view_offs; - view_src = view_src->view_src; - } +// lm_ggml_repeat - size_t data_size = lm_ggml_row_size(type, ne[0]); - for (int i = 1; i < n_dims; i++) { - data_size *= ne[i]; - } +struct lm_ggml_tensor * lm_ggml_repeat( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b) { + LM_GGML_ASSERT(lm_ggml_can_repeat(a, b)); - LM_GGML_ASSERT(view_src == NULL || data_size == 0 || data_size + view_offs <= lm_ggml_nbytes(view_src)); + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, a->type, LM_GGML_MAX_DIMS, b->ne); - void * data = view_src != NULL ? view_src->data : NULL; - if (data != NULL) { - data = (char *) data + view_offs; - } + result->op = LM_GGML_OP_REPEAT; + result->src[0] = a; - size_t obj_alloc_size = 0; + return result; +} - if (view_src == NULL && !ctx->no_alloc) { - // allocate tensor data in the context's memory pool - obj_alloc_size = data_size; - } +// lm_ggml_repeat_back - struct lm_ggml_object * const obj_new = lm_ggml_new_object(ctx, LM_GGML_OBJECT_TYPE_TENSOR, LM_GGML_TENSOR_SIZE + obj_alloc_size); - LM_GGML_ASSERT(obj_new); +struct lm_ggml_tensor * lm_ggml_repeat_back( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b) { + LM_GGML_ASSERT(lm_ggml_can_repeat(b, a)); - struct lm_ggml_tensor * const result = (struct lm_ggml_tensor *)((char *)ctx->mem_buffer + obj_new->offs); + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, a->type, LM_GGML_MAX_DIMS, b->ne); -#ifdef __clang__ - // temporary until lm_ggml_tensor::backend is removed - #pragma clang diagnostic push - #pragma clang diagnostic ignored "-Wdeprecated-declarations" -#endif + result->op = LM_GGML_OP_REPEAT_BACK; + result->src[0] = a; - *result = (struct lm_ggml_tensor) { - /*.type =*/ type, - /*.backend =*/ LM_GGML_BACKEND_TYPE_CPU, - /*.buffer =*/ NULL, - /*.ne =*/ { 1, 1, 1, 1 }, - /*.nb =*/ { 0, 0, 0, 0 }, - /*.op =*/ LM_GGML_OP_NONE, - /*.op_params =*/ { 0 }, - /*.flags =*/ 0, - /*.grad =*/ NULL, - /*.src =*/ { NULL }, - /*.view_src =*/ view_src, - /*.view_offs =*/ view_offs, - /*.data =*/ obj_alloc_size > 0 ? (void *)(result + 1) : data, - /*.name =*/ { 0 }, - /*.extra =*/ NULL, - ///*.padding =*/ { 0 }, - }; + return result; +} -#ifdef __clang__ - #pragma clang diagnostic pop -#endif +// lm_ggml_concat - // TODO: this should not be needed as long as we don't rely on aligned SIMD loads - //LM_GGML_ASSERT_ALIGNED(result->data); +struct lm_ggml_tensor * lm_ggml_concat( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + int dim) { + LM_GGML_ASSERT(dim >= 0 && dim < LM_GGML_MAX_DIMS); - for (int i = 0; i < n_dims; i++) { - result->ne[i] = ne[i]; + int64_t ne[LM_GGML_MAX_DIMS]; + for (int d = 0; d < LM_GGML_MAX_DIMS; ++d) { + if (d == dim) { + ne[d] = a->ne[d] + b->ne[d]; + continue; + } + LM_GGML_ASSERT(a->ne[d] == b->ne[d]); + ne[d] = a->ne[d]; } - result->nb[0] = lm_ggml_type_size(type); - result->nb[1] = result->nb[0]*(result->ne[0]/lm_ggml_blck_size(type)); - for (int i = 2; i < LM_GGML_MAX_DIMS; i++) { - result->nb[i] = result->nb[i - 1]*result->ne[i - 1]; - } + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, a->type, LM_GGML_MAX_DIMS, ne); - ctx->n_objects++; + lm_ggml_set_op_params_i32(result, 0, dim); + + result->op = LM_GGML_OP_CONCAT; + result->src[0] = a; + result->src[1] = b; return result; } -struct lm_ggml_tensor * lm_ggml_new_tensor( - struct lm_ggml_context * ctx, - enum lm_ggml_type type, - int n_dims, - const int64_t * ne) { - return lm_ggml_new_tensor_impl(ctx, type, n_dims, ne, NULL, 0); -} +// lm_ggml_abs -struct lm_ggml_tensor * lm_ggml_new_tensor_1d( +struct lm_ggml_tensor * lm_ggml_abs( struct lm_ggml_context * ctx, - enum lm_ggml_type type, - int64_t ne0) { - return lm_ggml_new_tensor(ctx, type, 1, &ne0); + struct lm_ggml_tensor * a) { + return lm_ggml_unary(ctx, a, LM_GGML_UNARY_OP_ABS); } -struct lm_ggml_tensor * lm_ggml_new_tensor_2d( +struct lm_ggml_tensor * lm_ggml_abs_inplace( struct lm_ggml_context * ctx, - enum lm_ggml_type type, - int64_t ne0, - int64_t ne1) { - const int64_t ne[2] = { ne0, ne1 }; - return lm_ggml_new_tensor(ctx, type, 2, ne); + struct lm_ggml_tensor * a) { + return lm_ggml_unary_inplace(ctx, a, LM_GGML_UNARY_OP_ABS); } -struct lm_ggml_tensor * lm_ggml_new_tensor_3d( +// lm_ggml_sgn + +struct lm_ggml_tensor * lm_ggml_sgn( struct lm_ggml_context * ctx, - enum lm_ggml_type type, - int64_t ne0, - int64_t ne1, - int64_t ne2) { - const int64_t ne[3] = { ne0, ne1, ne2 }; - return lm_ggml_new_tensor(ctx, type, 3, ne); + struct lm_ggml_tensor * a) { + return lm_ggml_unary(ctx, a, LM_GGML_UNARY_OP_SGN); } -struct lm_ggml_tensor * lm_ggml_new_tensor_4d( +struct lm_ggml_tensor * lm_ggml_sgn_inplace( struct lm_ggml_context * ctx, - enum lm_ggml_type type, - int64_t ne0, - int64_t ne1, - int64_t ne2, - int64_t ne3) { - const int64_t ne[4] = { ne0, ne1, ne2, ne3 }; - return lm_ggml_new_tensor(ctx, type, 4, ne); + struct lm_ggml_tensor * a) { + return lm_ggml_unary_inplace(ctx, a, LM_GGML_UNARY_OP_SGN); } -struct lm_ggml_tensor * lm_ggml_new_i32(struct lm_ggml_context * ctx, int32_t value) { - struct lm_ggml_tensor * result = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_I32, 1); - - lm_ggml_set_i32(result, value); +// lm_ggml_neg - return result; +struct lm_ggml_tensor * lm_ggml_neg( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_unary(ctx, a, LM_GGML_UNARY_OP_NEG); } -struct lm_ggml_tensor * lm_ggml_new_f32(struct lm_ggml_context * ctx, float value) { - struct lm_ggml_tensor * result = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_F32, 1); +struct lm_ggml_tensor * lm_ggml_neg_inplace( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_unary_inplace(ctx, a, LM_GGML_UNARY_OP_NEG); +} - lm_ggml_set_f32(result, value); +// lm_ggml_step - return result; +struct lm_ggml_tensor * lm_ggml_step( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_unary(ctx, a, LM_GGML_UNARY_OP_STEP); } -struct lm_ggml_tensor * lm_ggml_dup_tensor(struct lm_ggml_context * ctx, const struct lm_ggml_tensor * src) { - return lm_ggml_new_tensor(ctx, src->type, LM_GGML_MAX_DIMS, src->ne); +struct lm_ggml_tensor * lm_ggml_step_inplace( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_unary_inplace(ctx, a, LM_GGML_UNARY_OP_STEP); } -static void lm_ggml_set_op_params(struct lm_ggml_tensor * tensor, const void * params, size_t params_size) { - LM_GGML_ASSERT(tensor != NULL); // silence -Warray-bounds warnings - assert(params_size <= LM_GGML_MAX_OP_PARAMS); - memcpy(tensor->op_params, params, params_size); -} +// lm_ggml_tanh -static int32_t lm_ggml_get_op_params_i32(const struct lm_ggml_tensor * tensor, uint32_t i) { - assert(i < LM_GGML_MAX_OP_PARAMS / sizeof(int32_t)); - return ((const int32_t *)(tensor->op_params))[i]; +struct lm_ggml_tensor * lm_ggml_tanh( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_unary(ctx, a, LM_GGML_UNARY_OP_TANH); } -static float lm_ggml_get_op_params_f32(const struct lm_ggml_tensor * tensor, uint32_t i) { - assert(i < LM_GGML_MAX_OP_PARAMS / sizeof(float)); - return ((const float *)(tensor->op_params))[i]; +struct lm_ggml_tensor * lm_ggml_tanh_inplace( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_unary_inplace(ctx, a, LM_GGML_UNARY_OP_TANH); } -static void lm_ggml_set_op_params_i32(struct lm_ggml_tensor * tensor, uint32_t i, int32_t value) { - assert(i < LM_GGML_MAX_OP_PARAMS / sizeof(int32_t)); - ((int32_t *)(tensor->op_params))[i] = value; -} +// lm_ggml_elu -static void lm_ggml_set_op_params_f32(struct lm_ggml_tensor * tensor, uint32_t i, float value) { - assert(i < LM_GGML_MAX_OP_PARAMS / sizeof(float)); - ((float *)(tensor->op_params))[i] = value; +struct lm_ggml_tensor * lm_ggml_elu( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_unary(ctx, a, LM_GGML_UNARY_OP_ELU); } -struct lm_ggml_tensor * lm_ggml_set_zero(struct lm_ggml_tensor * tensor) { - if (lm_ggml_is_empty(tensor)) { - return tensor; - } - if (tensor->buffer) { - lm_ggml_backend_tensor_memset(tensor, 0, 0, lm_ggml_nbytes(tensor)); - } else { - LM_GGML_ASSERT(tensor->data); - memset(tensor->data, 0, lm_ggml_nbytes(tensor)); - } - return tensor; +struct lm_ggml_tensor * lm_ggml_elu_inplace( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_unary_inplace(ctx, a, LM_GGML_UNARY_OP_ELU); } -struct lm_ggml_tensor * lm_ggml_set_i32 (struct lm_ggml_tensor * tensor, int32_t value) { - const int n = lm_ggml_nrows(tensor); - const int nc = tensor->ne[0]; - const size_t n1 = tensor->nb[1]; - - char * const data = tensor->data; +// lm_ggml_relu - switch (tensor->type) { - case LM_GGML_TYPE_I8: - { - assert(tensor->nb[0] == sizeof(int8_t)); - for (int i = 0; i < n; i++) { - lm_ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value); - } - } break; - case LM_GGML_TYPE_I16: - { - assert(tensor->nb[0] == sizeof(int16_t)); - for (int i = 0; i < n; i++) { - lm_ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value); - } - } break; - case LM_GGML_TYPE_I32: - { - assert(tensor->nb[0] == sizeof(int32_t)); - for (int i = 0; i < n; i++) { - lm_ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value); - } - } break; - case LM_GGML_TYPE_F16: - { - assert(tensor->nb[0] == sizeof(lm_ggml_fp16_t)); - for (int i = 0; i < n; i++) { - lm_ggml_vec_set_f16(nc, (lm_ggml_fp16_t *)(data + i*n1), LM_GGML_FP32_TO_FP16(value)); - } - } break; - case LM_GGML_TYPE_BF16: - { - assert(tensor->nb[0] == sizeof(lm_ggml_fp16_t)); - for (int i = 0; i < n; i++) { - lm_ggml_vec_set_bf16(nc, (lm_ggml_bf16_t *)(data + i*n1), LM_GGML_FP32_TO_BF16(value)); - } - } break; - case LM_GGML_TYPE_F32: - { - assert(tensor->nb[0] == sizeof(float)); - for (int i = 0; i < n; i++) { - lm_ggml_vec_set_f32(nc, (float *)(data + i*n1), value); - } - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } +struct lm_ggml_tensor * lm_ggml_relu( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_unary(ctx, a, LM_GGML_UNARY_OP_RELU); +} - return tensor; +struct lm_ggml_tensor * lm_ggml_relu_inplace( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_unary_inplace(ctx, a, LM_GGML_UNARY_OP_RELU); } -struct lm_ggml_tensor * lm_ggml_set_f32(struct lm_ggml_tensor * tensor, float value) { - const int n = lm_ggml_nrows(tensor); - const int nc = tensor->ne[0]; - const size_t n1 = tensor->nb[1]; +// lm_ggml_leaky_relu + +struct lm_ggml_tensor * lm_ggml_leaky_relu( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + float negative_slope, + bool inplace) { + struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - char * const data = tensor->data; + lm_ggml_set_op_params(result, &negative_slope, sizeof(negative_slope)); - switch (tensor->type) { - case LM_GGML_TYPE_I8: - { - assert(tensor->nb[0] == sizeof(int8_t)); - for (int i = 0; i < n; i++) { - lm_ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value); - } - } break; - case LM_GGML_TYPE_I16: - { - assert(tensor->nb[0] == sizeof(int16_t)); - for (int i = 0; i < n; i++) { - lm_ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value); - } - } break; - case LM_GGML_TYPE_I32: - { - assert(tensor->nb[0] == sizeof(int32_t)); - for (int i = 0; i < n; i++) { - lm_ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value); - } - } break; - case LM_GGML_TYPE_F16: - { - assert(tensor->nb[0] == sizeof(lm_ggml_fp16_t)); - for (int i = 0; i < n; i++) { - lm_ggml_vec_set_f16(nc, (lm_ggml_fp16_t *)(data + i*n1), LM_GGML_FP32_TO_FP16(value)); - } - } break; - case LM_GGML_TYPE_BF16: - { - assert(tensor->nb[0] == sizeof(lm_ggml_bf16_t)); - for (int i = 0; i < n; i++) { - lm_ggml_vec_set_bf16(nc, (lm_ggml_bf16_t *)(data + i*n1), LM_GGML_FP32_TO_BF16(value)); - } - } break; - case LM_GGML_TYPE_F32: - { - assert(tensor->nb[0] == sizeof(float)); - for (int i = 0; i < n; i++) { - lm_ggml_vec_set_f32(nc, (float *)(data + i*n1), value); - } - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } + result->op = LM_GGML_OP_LEAKY_RELU; + result->src[0] = a; - return tensor; + return result; } -void lm_ggml_unravel_index(const struct lm_ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3) { - const int64_t ne2 = tensor->ne[2]; - const int64_t ne1 = tensor->ne[1]; - const int64_t ne0 = tensor->ne[0]; - - const int64_t i3_ = (i/(ne2*ne1*ne0)); - const int64_t i2_ = (i - i3_*ne2*ne1*ne0)/(ne1*ne0); - const int64_t i1_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0)/ne0; - const int64_t i0_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0 - i1_*ne0); +// lm_ggml_sigmoid - if (i0) { - * i0 = i0_; - } - if (i1) { - * i1 = i1_; - } - if (i2) { - * i2 = i2_; - } - if (i3) { - * i3 = i3_; - } +struct lm_ggml_tensor * lm_ggml_sigmoid( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_unary(ctx, a, LM_GGML_UNARY_OP_SIGMOID); } -int32_t lm_ggml_get_i32_1d(const struct lm_ggml_tensor * tensor, int i) { - if (!lm_ggml_is_contiguous(tensor)) { - int64_t id[4] = { 0, 0, 0, 0 }; - lm_ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]); - return lm_ggml_get_i32_nd(tensor, id[0], id[1], id[2], id[3]); - } - switch (tensor->type) { - case LM_GGML_TYPE_I8: - { - LM_GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); - return ((int8_t *)(tensor->data))[i]; - } - case LM_GGML_TYPE_I16: - { - LM_GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); - return ((int16_t *)(tensor->data))[i]; - } - case LM_GGML_TYPE_I32: - { - LM_GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); - return ((int32_t *)(tensor->data))[i]; - } - case LM_GGML_TYPE_F16: - { - LM_GGML_ASSERT(tensor->nb[0] == sizeof(lm_ggml_fp16_t)); - return LM_GGML_FP16_TO_FP32(((lm_ggml_fp16_t *)(tensor->data))[i]); - } - case LM_GGML_TYPE_BF16: - { - LM_GGML_ASSERT(tensor->nb[0] == sizeof(lm_ggml_bf16_t)); - return LM_GGML_BF16_TO_FP32(((lm_ggml_bf16_t *)(tensor->data))[i]); - } - case LM_GGML_TYPE_F32: - { - LM_GGML_ASSERT(tensor->nb[0] == sizeof(float)); - return ((float *)(tensor->data))[i]; - } - default: - { - LM_GGML_ABORT("fatal error"); - } - } +struct lm_ggml_tensor * lm_ggml_sigmoid_inplace( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_unary_inplace(ctx, a, LM_GGML_UNARY_OP_SIGMOID); } -void lm_ggml_set_i32_1d(const struct lm_ggml_tensor * tensor, int i, int32_t value) { - if (!lm_ggml_is_contiguous(tensor)) { - int64_t id[4] = { 0, 0, 0, 0 }; - lm_ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]); - lm_ggml_set_i32_nd(tensor, id[0], id[1], id[2], id[3], value); - return; - } - switch (tensor->type) { - case LM_GGML_TYPE_I8: - { - LM_GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); - ((int8_t *)(tensor->data))[i] = value; - } break; - case LM_GGML_TYPE_I16: - { - LM_GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); - ((int16_t *)(tensor->data))[i] = value; - } break; - case LM_GGML_TYPE_I32: - { - LM_GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); - ((int32_t *)(tensor->data))[i] = value; - } break; - case LM_GGML_TYPE_F16: - { - LM_GGML_ASSERT(tensor->nb[0] == sizeof(lm_ggml_fp16_t)); - ((lm_ggml_fp16_t *)(tensor->data))[i] = LM_GGML_FP32_TO_FP16(value); - } break; - case LM_GGML_TYPE_BF16: - { - LM_GGML_ASSERT(tensor->nb[0] == sizeof(lm_ggml_bf16_t)); - ((lm_ggml_bf16_t *)(tensor->data))[i] = LM_GGML_FP32_TO_BF16(value); - } break; - case LM_GGML_TYPE_F32: - { - LM_GGML_ASSERT(tensor->nb[0] == sizeof(float)); - ((float *)(tensor->data))[i] = value; - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } +// lm_ggml_gelu + +struct lm_ggml_tensor * lm_ggml_gelu( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_unary(ctx, a, LM_GGML_UNARY_OP_GELU); } -int32_t lm_ggml_get_i32_nd(const struct lm_ggml_tensor * tensor, int i0, int i1, int i2, int i3) { - void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]; - switch (tensor->type) { - case LM_GGML_TYPE_I8: - return ((int8_t *) data)[0]; - case LM_GGML_TYPE_I16: - return ((int16_t *) data)[0]; - case LM_GGML_TYPE_I32: - return ((int32_t *) data)[0]; - case LM_GGML_TYPE_F16: - return LM_GGML_FP16_TO_FP32(((lm_ggml_fp16_t *) data)[0]); - case LM_GGML_TYPE_BF16: - return LM_GGML_BF16_TO_FP32(((lm_ggml_bf16_t *) data)[0]); - case LM_GGML_TYPE_F32: - return ((float *) data)[0]; - default: - LM_GGML_ABORT("fatal error"); - } +struct lm_ggml_tensor * lm_ggml_gelu_inplace( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_unary_inplace(ctx, a, LM_GGML_UNARY_OP_GELU); } -void lm_ggml_set_i32_nd(const struct lm_ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value) { - void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]; - switch (tensor->type) { - case LM_GGML_TYPE_I8: - { - ((int8_t *)(data))[0] = value; - } break; - case LM_GGML_TYPE_I16: - { - ((int16_t *)(data))[0] = value; - } break; - case LM_GGML_TYPE_I32: - { - ((int32_t *)(data))[0] = value; - } break; - case LM_GGML_TYPE_F16: - { - ((lm_ggml_fp16_t *)(data))[0] = LM_GGML_FP32_TO_FP16(value); - } break; - case LM_GGML_TYPE_BF16: - { - ((lm_ggml_bf16_t *)(data))[0] = LM_GGML_FP32_TO_BF16(value); - } break; - case LM_GGML_TYPE_F32: - { - ((float *)(data))[0] = value; - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -float lm_ggml_get_f32_1d(const struct lm_ggml_tensor * tensor, int i) { - if (!lm_ggml_is_contiguous(tensor)) { - int64_t id[4] = { 0, 0, 0, 0 }; - lm_ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]); - return lm_ggml_get_f32_nd(tensor, id[0], id[1], id[2], id[3]); - } - switch (tensor->type) { - case LM_GGML_TYPE_I8: - { - return ((int8_t *)(tensor->data))[i]; - } - case LM_GGML_TYPE_I16: - { - return ((int16_t *)(tensor->data))[i]; - } - case LM_GGML_TYPE_I32: - { - return ((int32_t *)(tensor->data))[i]; - } - case LM_GGML_TYPE_F16: - { - return LM_GGML_FP16_TO_FP32(((lm_ggml_fp16_t *)(tensor->data))[i]); - } - case LM_GGML_TYPE_BF16: - { - return LM_GGML_BF16_TO_FP32(((lm_ggml_bf16_t *)(tensor->data))[i]); - } - case LM_GGML_TYPE_F32: - { - return ((float *)(tensor->data))[i]; - } - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -void lm_ggml_set_f32_1d(const struct lm_ggml_tensor * tensor, int i, float value) { - if (!lm_ggml_is_contiguous(tensor)) { - int64_t id[4] = { 0, 0, 0, 0 }; - lm_ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]); - lm_ggml_set_f32_nd(tensor, id[0], id[1], id[2], id[3], value); - return; - } - switch (tensor->type) { - case LM_GGML_TYPE_I8: - { - ((int8_t *)(tensor->data))[i] = value; - } break; - case LM_GGML_TYPE_I16: - { - ((int16_t *)(tensor->data))[i] = value; - } break; - case LM_GGML_TYPE_I32: - { - ((int32_t *)(tensor->data))[i] = value; - } break; - case LM_GGML_TYPE_F16: - { - ((lm_ggml_fp16_t *)(tensor->data))[i] = LM_GGML_FP32_TO_FP16(value); - } break; - case LM_GGML_TYPE_BF16: - { - ((lm_ggml_bf16_t *)(tensor->data))[i] = LM_GGML_FP32_TO_BF16(value); - } break; - case LM_GGML_TYPE_F32: - { - ((float *)(tensor->data))[i] = value; - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -float lm_ggml_get_f32_nd(const struct lm_ggml_tensor * tensor, int i0, int i1, int i2, int i3) { - void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]; - switch (tensor->type) { - case LM_GGML_TYPE_I8: - return ((int8_t *) data)[0]; - case LM_GGML_TYPE_I16: - return ((int16_t *) data)[0]; - case LM_GGML_TYPE_I32: - return ((int32_t *) data)[0]; - case LM_GGML_TYPE_F16: - return LM_GGML_FP16_TO_FP32(((lm_ggml_fp16_t *) data)[0]); - case LM_GGML_TYPE_BF16: - return LM_GGML_BF16_TO_FP32(((lm_ggml_bf16_t *) data)[0]); - case LM_GGML_TYPE_F32: - return ((float *) data)[0]; - default: - LM_GGML_ABORT("fatal error"); - } -} - -void lm_ggml_set_f32_nd(const struct lm_ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value) { - void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]; - switch (tensor->type) { - case LM_GGML_TYPE_I8: - { - ((int8_t *)(data))[0] = value; - } break; - case LM_GGML_TYPE_I16: - { - ((int16_t *)(data))[0] = value; - } break; - case LM_GGML_TYPE_I32: - { - ((int32_t *)(data))[0] = value; - } break; - case LM_GGML_TYPE_F16: - { - ((lm_ggml_fp16_t *)(data))[0] = LM_GGML_FP32_TO_FP16(value); - } break; - case LM_GGML_TYPE_BF16: - { - ((lm_ggml_bf16_t *)(data))[0] = LM_GGML_FP32_TO_BF16(value); - } break; - case LM_GGML_TYPE_F32: - { - ((float *)(data))[0] = value; - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} +// lm_ggml_gelu_quick -void * lm_ggml_get_data(const struct lm_ggml_tensor * tensor) { - return tensor->data; +struct lm_ggml_tensor * lm_ggml_gelu_quick( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_unary(ctx, a, LM_GGML_UNARY_OP_GELU_QUICK); } -float * lm_ggml_get_data_f32(const struct lm_ggml_tensor * tensor) { - assert(tensor->type == LM_GGML_TYPE_F32); - return (float *)(tensor->data); +struct lm_ggml_tensor * lm_ggml_gelu_quick_inplace( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_unary_inplace(ctx, a, LM_GGML_UNARY_OP_GELU_QUICK); } -enum lm_ggml_unary_op lm_ggml_get_unary_op(const struct lm_ggml_tensor * tensor) { - LM_GGML_ASSERT(tensor->op == LM_GGML_OP_UNARY); - return (enum lm_ggml_unary_op) lm_ggml_get_op_params_i32(tensor, 0); -} +// lm_ggml_silu -const char * lm_ggml_get_name(const struct lm_ggml_tensor * tensor) { - return tensor->name; +struct lm_ggml_tensor * lm_ggml_silu( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_unary(ctx, a, LM_GGML_UNARY_OP_SILU); } -struct lm_ggml_tensor * lm_ggml_set_name(struct lm_ggml_tensor * tensor, const char * name) { - size_t i; - for (i = 0; i < sizeof(tensor->name) - 1 && name[i] != '\0'; i++) { - tensor->name[i] = name[i]; - } - tensor->name[i] = '\0'; - return tensor; +struct lm_ggml_tensor * lm_ggml_silu_inplace( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_unary_inplace(ctx, a, LM_GGML_UNARY_OP_SILU); } -struct lm_ggml_tensor * lm_ggml_format_name(struct lm_ggml_tensor * tensor, const char * fmt, ...) { - va_list args; - va_start(args, fmt); - vsnprintf(tensor->name, sizeof(tensor->name), fmt, args); - va_end(args); - return tensor; -} +// lm_ggml_silu_back -struct lm_ggml_tensor * lm_ggml_view_tensor( +struct lm_ggml_tensor * lm_ggml_silu_back( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * src) { - struct lm_ggml_tensor * result = lm_ggml_new_tensor_impl(ctx, src->type, LM_GGML_MAX_DIMS, src->ne, src, 0); - lm_ggml_format_name(result, "%s (view)", src->name); + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b) { + struct lm_ggml_tensor * result = lm_ggml_dup_tensor(ctx, a); - for (int i = 0; i < LM_GGML_MAX_DIMS; i++) { - result->nb[i] = src->nb[i]; - } + result->op = LM_GGML_OP_SILU_BACK; + result->src[0] = a; + result->src[1] = b; return result; } -struct lm_ggml_tensor * lm_ggml_get_first_tensor(const struct lm_ggml_context * ctx) { - struct lm_ggml_object * obj = ctx->objects_begin; - - char * const mem_buffer = ctx->mem_buffer; - - while (obj != NULL) { - if (obj->type == LM_GGML_OBJECT_TYPE_TENSOR) { - return (struct lm_ggml_tensor *)(mem_buffer + obj->offs); - } - - obj = obj->next; - } +// ggml hardswish - return NULL; +struct lm_ggml_tensor * lm_ggml_hardswish( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_unary(ctx, a, LM_GGML_UNARY_OP_HARDSWISH); } -struct lm_ggml_tensor * lm_ggml_get_next_tensor(const struct lm_ggml_context * ctx, struct lm_ggml_tensor * tensor) { - struct lm_ggml_object * obj = (struct lm_ggml_object *) ((char *)tensor - LM_GGML_OBJECT_SIZE); - obj = obj->next; - - char * const mem_buffer = ctx->mem_buffer; - - while (obj != NULL) { - if (obj->type == LM_GGML_OBJECT_TYPE_TENSOR) { - return (struct lm_ggml_tensor *)(mem_buffer + obj->offs); - } - - obj = obj->next; - } +// ggml hardsigmoid - return NULL; +struct lm_ggml_tensor * lm_ggml_hardsigmoid( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_unary(ctx, a, LM_GGML_UNARY_OP_HARDSIGMOID); } -struct lm_ggml_tensor * lm_ggml_get_tensor(struct lm_ggml_context * ctx, const char * name) { - struct lm_ggml_object * obj = ctx->objects_begin; - - char * const mem_buffer = ctx->mem_buffer; - - while (obj != NULL) { - if (obj->type == LM_GGML_OBJECT_TYPE_TENSOR) { - struct lm_ggml_tensor * cur = (struct lm_ggml_tensor *)(mem_buffer + obj->offs); - if (strcmp(cur->name, name) == 0) { - return cur; - } - } - - obj = obj->next; - } +// ggml exp - return NULL; +struct lm_ggml_tensor * lm_ggml_exp( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_unary(ctx, a, LM_GGML_UNARY_OP_EXP); } -//////////////////////////////////////////////////////////////////////////////// +struct lm_ggml_tensor * lm_ggml_exp_inplace( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_unary_inplace(ctx, a, LM_GGML_UNARY_OP_EXP); +} -// lm_ggml_dup +// lm_ggml_norm -static struct lm_ggml_tensor * lm_ggml_dup_impl( +static struct lm_ggml_tensor * lm_ggml_norm_impl( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, + float eps, bool inplace) { struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - result->op = LM_GGML_OP_DUP; + lm_ggml_set_op_params(result, &eps, sizeof(eps)); + + result->op = LM_GGML_OP_NORM; result->src[0] = a; return result; } -struct lm_ggml_tensor * lm_ggml_dup( +struct lm_ggml_tensor * lm_ggml_norm( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_dup_impl(ctx, a, false); + struct lm_ggml_tensor * a, + float eps) { + return lm_ggml_norm_impl(ctx, a, eps, false); } -struct lm_ggml_tensor * lm_ggml_dup_inplace( +struct lm_ggml_tensor * lm_ggml_norm_inplace( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_dup_impl(ctx, a, true); + struct lm_ggml_tensor * a, + float eps) { + return lm_ggml_norm_impl(ctx, a, eps, true); } -// lm_ggml_add +// lm_ggml_rms_norm -static struct lm_ggml_tensor * lm_ggml_add_impl( +static struct lm_ggml_tensor * lm_ggml_rms_norm_impl( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, + float eps, bool inplace) { - LM_GGML_ASSERT(lm_ggml_can_repeat(b, a)); - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - result->op = LM_GGML_OP_ADD; + lm_ggml_set_op_params(result, &eps, sizeof(eps)); + + result->op = LM_GGML_OP_RMS_NORM; result->src[0] = a; - result->src[1] = b; return result; } -struct lm_ggml_tensor * lm_ggml_add( +struct lm_ggml_tensor * lm_ggml_rms_norm( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b) { - return lm_ggml_add_impl(ctx, a, b, false); -} + float eps) { + return lm_ggml_rms_norm_impl(ctx, a, eps, false); +} -struct lm_ggml_tensor * lm_ggml_add_inplace( +struct lm_ggml_tensor * lm_ggml_rms_norm_inplace( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b) { - return lm_ggml_add_impl(ctx, a, b, true); + float eps) { + return lm_ggml_rms_norm_impl(ctx, a, eps, true); } -// lm_ggml_add_cast +// lm_ggml_rms_norm_back -static struct lm_ggml_tensor * lm_ggml_add_cast_impl( +struct lm_ggml_tensor * lm_ggml_rms_norm_back( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, struct lm_ggml_tensor * b, - enum lm_ggml_type type) { - // TODO: support less-strict constraint - // LM_GGML_ASSERT(lm_ggml_can_repeat(b, a)); - LM_GGML_ASSERT(lm_ggml_can_repeat_rows(b, a)); - - // currently only supported for quantized input and f16 - LM_GGML_ASSERT(lm_ggml_is_quantized(a->type) || - a->type == LM_GGML_TYPE_F16 || - a->type == LM_GGML_TYPE_BF16); + float eps) { + struct lm_ggml_tensor * result = lm_ggml_dup_tensor(ctx, a); - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, type, LM_GGML_MAX_DIMS, a->ne); + lm_ggml_set_op_params(result, &eps, sizeof(eps)); - result->op = LM_GGML_OP_ADD; + result->op = LM_GGML_OP_RMS_NORM_BACK; result->src[0] = a; result->src[1] = b; return result; } -struct lm_ggml_tensor * lm_ggml_add_cast( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - enum lm_ggml_type type) { - return lm_ggml_add_cast_impl(ctx, a, b, type); -} - -// lm_ggml_add1 +// lm_ggml_group_norm -static struct lm_ggml_tensor * lm_ggml_add1_impl( +static struct lm_ggml_tensor * lm_ggml_group_norm_impl( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, + int n_groups, + float eps, bool inplace) { - LM_GGML_ASSERT(lm_ggml_is_scalar(b)); - LM_GGML_ASSERT(lm_ggml_is_padded_1d(a)); - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - result->op = LM_GGML_OP_ADD1; + lm_ggml_set_op_params_i32(result, 0, n_groups); + lm_ggml_set_op_params_f32(result, 1, eps); + + result->op = LM_GGML_OP_GROUP_NORM; result->src[0] = a; - result->src[1] = b; return result; } -struct lm_ggml_tensor * lm_ggml_add1( +struct lm_ggml_tensor * lm_ggml_group_norm( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b) { - return lm_ggml_add1_impl(ctx, a, b, false); + int n_groups, + float eps) { + return lm_ggml_group_norm_impl(ctx, a, n_groups, eps, false); } -struct lm_ggml_tensor * lm_ggml_add1_inplace( +struct lm_ggml_tensor * lm_ggml_group_norm_inplace( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b) { - return lm_ggml_add1_impl(ctx, a, b, true); + int n_groups, + float eps) { + return lm_ggml_group_norm_impl(ctx, a, n_groups, eps, true); } -// lm_ggml_acc +// lm_ggml_mul_mat -static struct lm_ggml_tensor * lm_ggml_acc_impl( +static inline bool lm_ggml_can_mul_mat(const struct lm_ggml_tensor * t0, const struct lm_ggml_tensor * t1) { + static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); + + return (t0->ne[0] == t1->ne[0]) && + (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable + (t1->ne[3]%t0->ne[3] == 0); +} + +struct lm_ggml_tensor * lm_ggml_mul_mat( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - size_t nb1, - size_t nb2, - size_t nb3, - size_t offset, - bool inplace) { - LM_GGML_ASSERT(lm_ggml_nelements(b) <= lm_ggml_nelements(a)); - LM_GGML_ASSERT(lm_ggml_is_contiguous(a)); - LM_GGML_ASSERT(a->type == LM_GGML_TYPE_F32); - LM_GGML_ASSERT(b->type == LM_GGML_TYPE_F32); - - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); + struct lm_ggml_tensor * b) { + LM_GGML_ASSERT(lm_ggml_can_mul_mat(a, b)); + LM_GGML_ASSERT(!lm_ggml_is_transposed(a)); - int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 }; - lm_ggml_set_op_params(result, params, sizeof(params)); + const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] }; + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); - result->op = LM_GGML_OP_ACC; + result->op = LM_GGML_OP_MUL_MAT; result->src[0] = a; result->src[1] = b; return result; } -struct lm_ggml_tensor * lm_ggml_acc( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - size_t nb1, - size_t nb2, - size_t nb3, - size_t offset) { - return lm_ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false); -} +void lm_ggml_mul_mat_set_prec( + struct lm_ggml_tensor * a, + enum lm_ggml_prec prec) { + LM_GGML_ASSERT(a->op == LM_GGML_OP_MUL_MAT); -struct lm_ggml_tensor * lm_ggml_acc_inplace( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - size_t nb1, - size_t nb2, - size_t nb3, - size_t offset) { - return lm_ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true); + const int32_t prec_i32 = (int32_t) prec; + + lm_ggml_set_op_params_i32(a, 0, prec_i32); } -// lm_ggml_sub +// lm_ggml_mul_mat_id -static struct lm_ggml_tensor * lm_ggml_sub_impl( +/* + c = lm_ggml_mul_mat_id(ctx, as, b, ids); + + as -> [cols, rows, n_expert] + ids -> [n_experts_used, n_tokens] (i32) + b -> [cols, n_expert_used, n_tokens] + c -> [rows, n_expert_used, n_tokens] + + in b, n_experts_used can be broadcasted to match the n_expert_used of ids + + c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e,t in ids +*/ +struct lm_ggml_tensor * lm_ggml_mul_mat_id( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, + struct lm_ggml_tensor * as, struct lm_ggml_tensor * b, - bool inplace) { - LM_GGML_ASSERT(lm_ggml_can_repeat(b, a)); + struct lm_ggml_tensor * ids) { + LM_GGML_ASSERT(!lm_ggml_is_transposed(as)); + LM_GGML_ASSERT(ids->type == LM_GGML_TYPE_I32); - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); + LM_GGML_ASSERT(as->ne[3] == 1); // as is 3d (one matrix per expert) + LM_GGML_ASSERT(b->ne[3] == 1); // b is 3d + LM_GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d + LM_GGML_ASSERT(ids->ne[1] == b->ne[2]); // must have an expert list per b row + LM_GGML_ASSERT(as->ne[0] == b->ne[0]); // can_mul_mat + LM_GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast - result->op = LM_GGML_OP_SUB; - result->src[0] = a; + const int64_t ne[4] = { as->ne[1], ids->ne[0], b->ne[2], 1 }; + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); + + result->op = LM_GGML_OP_MUL_MAT_ID; + result->src[0] = as; result->src[1] = b; + result->src[2] = ids; return result; } -struct lm_ggml_tensor * lm_ggml_sub( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b) { - return lm_ggml_sub_impl(ctx, a, b, false); +// lm_ggml_out_prod + +static inline bool lm_ggml_can_out_prod(const struct lm_ggml_tensor * t0, const struct lm_ggml_tensor * t1) { + static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function"); + + return (t0->ne[1] == t1->ne[1]) && + (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable + (t1->ne[3]%t0->ne[3] == 0); } -struct lm_ggml_tensor * lm_ggml_sub_inplace( +struct lm_ggml_tensor * lm_ggml_out_prod( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, struct lm_ggml_tensor * b) { - return lm_ggml_sub_impl(ctx, a, b, true); + LM_GGML_ASSERT(lm_ggml_can_out_prod(a, b)); + LM_GGML_ASSERT(!lm_ggml_is_transposed(a)); + + // a is broadcastable to b for ne[2] and ne[3] -> use b->ne[2] and b->ne[3] + const int64_t ne[4] = { a->ne[0], b->ne[0], b->ne[2], b->ne[3] }; + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); + + result->op = LM_GGML_OP_OUT_PROD; + result->src[0] = a; + result->src[1] = b; + + return result; } -// lm_ggml_mul +// lm_ggml_scale -static struct lm_ggml_tensor * lm_ggml_mul_impl( +static struct lm_ggml_tensor * lm_ggml_scale_impl( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, + float s, bool inplace) { - LM_GGML_ASSERT(lm_ggml_can_repeat(b, a)); + LM_GGML_ASSERT(lm_ggml_is_padded_1d(a)); struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - result->op = LM_GGML_OP_MUL; + lm_ggml_set_op_params(result, &s, sizeof(s)); + + result->op = LM_GGML_OP_SCALE; result->src[0] = a; - result->src[1] = b; return result; } -struct lm_ggml_tensor * lm_ggml_mul( +struct lm_ggml_tensor * lm_ggml_scale( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b) { - return lm_ggml_mul_impl(ctx, a, b, false); + float s) { + return lm_ggml_scale_impl(ctx, a, s, false); } -struct lm_ggml_tensor * lm_ggml_mul_inplace( +struct lm_ggml_tensor * lm_ggml_scale_inplace( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b) { - return lm_ggml_mul_impl(ctx, a, b, true); + float s) { + return lm_ggml_scale_impl(ctx, a, s, true); } -// lm_ggml_div +// lm_ggml_set -static struct lm_ggml_tensor * lm_ggml_div_impl( +static struct lm_ggml_tensor * lm_ggml_set_impl( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, struct lm_ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset, bool inplace) { - LM_GGML_ASSERT(lm_ggml_can_repeat(b, a)); + LM_GGML_ASSERT(lm_ggml_nelements(a) >= lm_ggml_nelements(b)); + // make a view of the destination struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - result->op = LM_GGML_OP_DIV; + LM_GGML_ASSERT(offset < (size_t)(1 << 30)); + int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 }; + lm_ggml_set_op_params(result, params, sizeof(params)); + + result->op = LM_GGML_OP_SET; result->src[0] = a; result->src[1] = b; return result; } -struct lm_ggml_tensor * lm_ggml_div( +struct lm_ggml_tensor * lm_ggml_set( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b) { - return lm_ggml_div_impl(ctx, a, b, false); + struct lm_ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset) { + return lm_ggml_set_impl(ctx, a, b, nb1, nb2, nb3, offset, false); } -struct lm_ggml_tensor * lm_ggml_div_inplace( +struct lm_ggml_tensor * lm_ggml_set_inplace( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b) { - return lm_ggml_div_impl(ctx, a, b, true); + struct lm_ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset) { + return lm_ggml_set_impl(ctx, a, b, nb1, nb2, nb3, offset, true); } -// lm_ggml_sqr - -static struct lm_ggml_tensor * lm_ggml_sqr_impl( +struct lm_ggml_tensor * lm_ggml_set_1d( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - bool inplace) { - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - - result->op = LM_GGML_OP_SQR; - result->src[0] = a; + struct lm_ggml_tensor * b, + size_t offset) { + return lm_ggml_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, false); +} - return result; +struct lm_ggml_tensor * lm_ggml_set_1d_inplace( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + size_t offset) { + return lm_ggml_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, true); } -struct lm_ggml_tensor * lm_ggml_sqr( +struct lm_ggml_tensor * lm_ggml_set_2d( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_sqr_impl(ctx, a, false); + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + size_t nb1, + size_t offset) { + return lm_ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, false); } -struct lm_ggml_tensor * lm_ggml_sqr_inplace( +struct lm_ggml_tensor * lm_ggml_set_2d_inplace( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_sqr_impl(ctx, a, true); + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + size_t nb1, + size_t offset) { + return lm_ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, true); } -// lm_ggml_sqrt +// lm_ggml_cpy -static struct lm_ggml_tensor * lm_ggml_sqrt_impl( +static struct lm_ggml_tensor * lm_ggml_cpy_impl( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - bool inplace) { - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); + struct lm_ggml_tensor * b) { + LM_GGML_ASSERT(lm_ggml_nelements(a) == lm_ggml_nelements(b)); - result->op = LM_GGML_OP_SQRT; + // make a view of the destination + struct lm_ggml_tensor * result = lm_ggml_view_tensor(ctx, b); + if (strlen(b->name) > 0) { + lm_ggml_format_name(result, "%s (copy of %s)", b->name, a->name); + } else { + lm_ggml_format_name(result, "%s (copy)", a->name); + } + + result->op = LM_GGML_OP_CPY; result->src[0] = a; + result->src[1] = b; return result; } -struct lm_ggml_tensor * lm_ggml_sqrt( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_sqrt_impl(ctx, a, false); -} - -struct lm_ggml_tensor * lm_ggml_sqrt_inplace( +struct lm_ggml_tensor * lm_ggml_cpy( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_sqrt_impl(ctx, a, true); + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b) { + return lm_ggml_cpy_impl(ctx, a, b); } -// lm_ggml_log - -static struct lm_ggml_tensor * lm_ggml_log_impl( +struct lm_ggml_tensor * lm_ggml_cast( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - bool inplace) { - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); + enum lm_ggml_type type) { + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, type, LM_GGML_MAX_DIMS, a->ne); + lm_ggml_format_name(result, "%s (copy)", a->name); - result->op = LM_GGML_OP_LOG; + result->op = LM_GGML_OP_CPY; result->src[0] = a; + result->src[1] = result; return result; } -struct lm_ggml_tensor * lm_ggml_log( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_log_impl(ctx, a, false); -} +// lm_ggml_cont -struct lm_ggml_tensor * lm_ggml_log_inplace( +static struct lm_ggml_tensor * lm_ggml_cont_impl( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a) { - return lm_ggml_log_impl(ctx, a, true); -} - -// lm_ggml_sin - -static struct lm_ggml_tensor * lm_ggml_sin_impl( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - bool inplace) { - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); + struct lm_ggml_tensor * result = lm_ggml_dup_tensor(ctx, a); + lm_ggml_format_name(result, "%s (cont)", a->name); - result->op = LM_GGML_OP_SIN; + result->op = LM_GGML_OP_CONT; result->src[0] = a; return result; } -struct lm_ggml_tensor * lm_ggml_sin( +struct lm_ggml_tensor * lm_ggml_cont( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_sin_impl(ctx, a, false); + struct lm_ggml_tensor * a) { + return lm_ggml_cont_impl(ctx, a); } -struct lm_ggml_tensor * lm_ggml_sin_inplace( +// make contiguous, with new shape +LM_GGML_API struct lm_ggml_tensor * lm_ggml_cont_1d( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_sin_impl(ctx, a, true); + struct lm_ggml_tensor * a, + int64_t ne0) { + return lm_ggml_cont_4d(ctx, a, ne0, 1, 1, 1); } -// lm_ggml_cos - -static struct lm_ggml_tensor * lm_ggml_cos_impl( +LM_GGML_API struct lm_ggml_tensor * lm_ggml_cont_2d( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - bool inplace) { - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - - result->op = LM_GGML_OP_COS; - result->src[0] = a; - - return result; + int64_t ne0, + int64_t ne1) { + return lm_ggml_cont_4d(ctx, a, ne0, ne1, 1, 1); } -struct lm_ggml_tensor * lm_ggml_cos( +LM_GGML_API struct lm_ggml_tensor * lm_ggml_cont_3d( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_cos_impl(ctx, a, false); + struct lm_ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2) { + return lm_ggml_cont_4d(ctx, a, ne0, ne1, ne2, 1); } -struct lm_ggml_tensor * lm_ggml_cos_inplace( +struct lm_ggml_tensor * lm_ggml_cont_4d( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_cos_impl(ctx, a, true); -} - -// lm_ggml_sum + struct lm_ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3) { + LM_GGML_ASSERT(lm_ggml_nelements(a) == (ne0*ne1*ne2*ne3)); -struct lm_ggml_tensor * lm_ggml_sum( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - struct lm_ggml_tensor * result = lm_ggml_new_tensor_1d(ctx, a->type, 1); + struct lm_ggml_tensor * result = lm_ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3); + lm_ggml_format_name(result, "%s (cont)", a->name); - result->op = LM_GGML_OP_SUM; + result->op = LM_GGML_OP_CONT; result->src[0] = a; return result; } -// lm_ggml_sum_rows +// lm_ggml_reshape -struct lm_ggml_tensor * lm_ggml_sum_rows( +struct lm_ggml_tensor * lm_ggml_reshape( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - int64_t ne[LM_GGML_MAX_DIMS] = { 1 }; - for (int i = 1; i < LM_GGML_MAX_DIMS; ++i) { - ne[i] = a->ne[i]; - } + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b) { + LM_GGML_ASSERT(lm_ggml_is_contiguous(a)); + // as only the shape of b is relevant, and not its memory layout, b is allowed to be non contiguous. + LM_GGML_ASSERT(lm_ggml_nelements(a) == lm_ggml_nelements(b)); - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, a->type, LM_GGML_MAX_DIMS, ne); + struct lm_ggml_tensor * result = lm_ggml_new_tensor_impl(ctx, a->type, LM_GGML_MAX_DIMS, b->ne, a, 0); + lm_ggml_format_name(result, "%s (reshaped)", a->name); - result->op = LM_GGML_OP_SUM_ROWS; + result->op = LM_GGML_OP_RESHAPE; result->src[0] = a; return result; } -// lm_ggml_mean - -struct lm_ggml_tensor * lm_ggml_mean( +struct lm_ggml_tensor * lm_ggml_reshape_1d( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - int64_t ne[4] = { 1, a->ne[1], a->ne[2], a->ne[3] }; - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); + struct lm_ggml_tensor * a, + int64_t ne0) { + LM_GGML_ASSERT(lm_ggml_is_contiguous(a)); + LM_GGML_ASSERT(lm_ggml_nelements(a) == ne0); - result->op = LM_GGML_OP_MEAN; + const int64_t ne[1] = { ne0 }; + struct lm_ggml_tensor * result = lm_ggml_new_tensor_impl(ctx, a->type, 1, ne, a, 0); + lm_ggml_format_name(result, "%s (reshaped)", a->name); + + result->op = LM_GGML_OP_RESHAPE; result->src[0] = a; return result; } -// lm_ggml_argmax - -struct lm_ggml_tensor * lm_ggml_argmax( +struct lm_ggml_tensor * lm_ggml_reshape_2d( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - LM_GGML_ASSERT(lm_ggml_is_matrix(a)); + struct lm_ggml_tensor * a, + int64_t ne0, + int64_t ne1) { + LM_GGML_ASSERT(lm_ggml_is_contiguous(a)); + LM_GGML_ASSERT(lm_ggml_nelements(a) == ne0*ne1); - struct lm_ggml_tensor * result = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_I32, a->ne[1]); + const int64_t ne[2] = { ne0, ne1 }; + struct lm_ggml_tensor * result = lm_ggml_new_tensor_impl(ctx, a->type, 2, ne, a, 0); + lm_ggml_format_name(result, "%s (reshaped)", a->name); - result->op = LM_GGML_OP_ARGMAX; + result->op = LM_GGML_OP_RESHAPE; result->src[0] = a; return result; } -// lm_ggml_count_equal - -struct lm_ggml_tensor * lm_ggml_count_equal( +struct lm_ggml_tensor * lm_ggml_reshape_3d( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b) { - LM_GGML_ASSERT(lm_ggml_are_same_shape(a, b)); + int64_t ne0, + int64_t ne1, + int64_t ne2) { + LM_GGML_ASSERT(lm_ggml_is_contiguous(a)); + LM_GGML_ASSERT(lm_ggml_nelements(a) == ne0*ne1*ne2); - struct lm_ggml_tensor * result = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_I64, 1); + const int64_t ne[3] = { ne0, ne1, ne2 }; + struct lm_ggml_tensor * result = lm_ggml_new_tensor_impl(ctx, a->type, 3, ne, a, 0); + lm_ggml_format_name(result, "%s (reshaped)", a->name); - result->op = LM_GGML_OP_COUNT_EQUAL; + result->op = LM_GGML_OP_RESHAPE; result->src[0] = a; - result->src[1] = b; return result; } -// lm_ggml_repeat - -struct lm_ggml_tensor * lm_ggml_repeat( +struct lm_ggml_tensor * lm_ggml_reshape_4d( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b) { - LM_GGML_ASSERT(lm_ggml_can_repeat(a, b)); + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3) { + LM_GGML_ASSERT(lm_ggml_is_contiguous(a)); + LM_GGML_ASSERT(lm_ggml_nelements(a) == ne0*ne1*ne2*ne3); - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, a->type, LM_GGML_MAX_DIMS, b->ne); + const int64_t ne[4] = { ne0, ne1, ne2, ne3 }; + struct lm_ggml_tensor * result = lm_ggml_new_tensor_impl(ctx, a->type, 4, ne, a, 0); + lm_ggml_format_name(result, "%s (reshaped)", a->name); - result->op = LM_GGML_OP_REPEAT; + result->op = LM_GGML_OP_RESHAPE; result->src[0] = a; return result; } -// lm_ggml_repeat_back - -struct lm_ggml_tensor * lm_ggml_repeat_back( +static struct lm_ggml_tensor * lm_ggml_view_impl( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b) { - LM_GGML_ASSERT(lm_ggml_can_repeat(b, a)); - - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, a->type, LM_GGML_MAX_DIMS, b->ne); + int n_dims, + const int64_t * ne, + size_t offset) { + struct lm_ggml_tensor * result = lm_ggml_new_tensor_impl(ctx, a->type, n_dims, ne, a, offset); + lm_ggml_format_name(result, "%s (view)", a->name); - result->op = LM_GGML_OP_REPEAT_BACK; + lm_ggml_set_op_params(result, &offset, sizeof(offset)); + + result->op = LM_GGML_OP_VIEW; result->src[0] = a; return result; } -// lm_ggml_concat +// lm_ggml_view_1d -struct lm_ggml_tensor * lm_ggml_concat( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - int dim) { - LM_GGML_ASSERT(dim >= 0 && dim < LM_GGML_MAX_DIMS); +struct lm_ggml_tensor * lm_ggml_view_1d( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + int64_t ne0, + size_t offset) { + struct lm_ggml_tensor * result = lm_ggml_view_impl(ctx, a, 1, &ne0, offset); - int64_t ne[LM_GGML_MAX_DIMS]; - for (int d = 0; d < LM_GGML_MAX_DIMS; ++d) { - if (d == dim) { - ne[d] = a->ne[d] + b->ne[d]; - continue; - } - LM_GGML_ASSERT(a->ne[d] == b->ne[d]); - ne[d] = a->ne[d]; - } + return result; +} - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, a->type, LM_GGML_MAX_DIMS, ne); +// lm_ggml_view_2d - lm_ggml_set_op_params_i32(result, 0, dim); +struct lm_ggml_tensor * lm_ggml_view_2d( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + int64_t ne0, + int64_t ne1, + size_t nb1, + size_t offset) { + const int64_t ne[2] = { ne0, ne1 }; - result->op = LM_GGML_OP_CONCAT; - result->src[0] = a; - result->src[1] = b; + struct lm_ggml_tensor * result = lm_ggml_view_impl(ctx, a, 2, ne, offset); + + result->nb[1] = nb1; + result->nb[2] = result->nb[1]*ne1; + result->nb[3] = result->nb[2]; return result; } -// lm_ggml_abs +// lm_ggml_view_3d -struct lm_ggml_tensor * lm_ggml_abs( +struct lm_ggml_tensor * lm_ggml_view_3d( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_unary(ctx, a, LM_GGML_UNARY_OP_ABS); -} + struct lm_ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + size_t nb1, + size_t nb2, + size_t offset) { + const int64_t ne[3] = { ne0, ne1, ne2 }; -struct lm_ggml_tensor * lm_ggml_abs_inplace( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_unary_inplace(ctx, a, LM_GGML_UNARY_OP_ABS); -} + struct lm_ggml_tensor * result = lm_ggml_view_impl(ctx, a, 3, ne, offset); -// lm_ggml_sgn + result->nb[1] = nb1; + result->nb[2] = nb2; + result->nb[3] = result->nb[2]*ne2; -struct lm_ggml_tensor * lm_ggml_sgn( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_unary(ctx, a, LM_GGML_UNARY_OP_SGN); + return result; } -struct lm_ggml_tensor * lm_ggml_sgn_inplace( +// lm_ggml_view_4d + +struct lm_ggml_tensor * lm_ggml_view_4d( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_unary_inplace(ctx, a, LM_GGML_UNARY_OP_SGN); -} + struct lm_ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset) { + const int64_t ne[4] = { ne0, ne1, ne2, ne3 }; -// lm_ggml_neg + struct lm_ggml_tensor * result = lm_ggml_view_impl(ctx, a, 4, ne, offset); -struct lm_ggml_tensor * lm_ggml_neg( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_unary(ctx, a, LM_GGML_UNARY_OP_NEG); -} + result->nb[1] = nb1; + result->nb[2] = nb2; + result->nb[3] = nb3; -struct lm_ggml_tensor * lm_ggml_neg_inplace( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_unary_inplace(ctx, a, LM_GGML_UNARY_OP_NEG); + return result; } -// lm_ggml_step +// lm_ggml_permute -struct lm_ggml_tensor * lm_ggml_step( +struct lm_ggml_tensor * lm_ggml_permute( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_unary(ctx, a, LM_GGML_UNARY_OP_STEP); -} + struct lm_ggml_tensor * a, + int axis0, + int axis1, + int axis2, + int axis3) { + LM_GGML_ASSERT(axis0 >= 0 && axis0 < LM_GGML_MAX_DIMS); + LM_GGML_ASSERT(axis1 >= 0 && axis1 < LM_GGML_MAX_DIMS); + LM_GGML_ASSERT(axis2 >= 0 && axis2 < LM_GGML_MAX_DIMS); + LM_GGML_ASSERT(axis3 >= 0 && axis3 < LM_GGML_MAX_DIMS); -struct lm_ggml_tensor * lm_ggml_step_inplace( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_unary_inplace(ctx, a, LM_GGML_UNARY_OP_STEP); -} + LM_GGML_ASSERT(axis0 != axis1); + LM_GGML_ASSERT(axis0 != axis2); + LM_GGML_ASSERT(axis0 != axis3); + LM_GGML_ASSERT(axis1 != axis2); + LM_GGML_ASSERT(axis1 != axis3); + LM_GGML_ASSERT(axis2 != axis3); -// lm_ggml_tanh + struct lm_ggml_tensor * result = lm_ggml_view_tensor(ctx, a); + lm_ggml_format_name(result, "%s (permuted)", a->name); -struct lm_ggml_tensor * lm_ggml_tanh( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_unary(ctx, a, LM_GGML_UNARY_OP_TANH); -} + int ne[LM_GGML_MAX_DIMS]; + int nb[LM_GGML_MAX_DIMS]; -struct lm_ggml_tensor * lm_ggml_tanh_inplace( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_unary_inplace(ctx, a, LM_GGML_UNARY_OP_TANH); -} + ne[axis0] = a->ne[0]; + ne[axis1] = a->ne[1]; + ne[axis2] = a->ne[2]; + ne[axis3] = a->ne[3]; -// lm_ggml_elu + nb[axis0] = a->nb[0]; + nb[axis1] = a->nb[1]; + nb[axis2] = a->nb[2]; + nb[axis3] = a->nb[3]; -struct lm_ggml_tensor * lm_ggml_elu( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_unary(ctx, a, LM_GGML_UNARY_OP_ELU); -} + result->ne[0] = ne[0]; + result->ne[1] = ne[1]; + result->ne[2] = ne[2]; + result->ne[3] = ne[3]; -struct lm_ggml_tensor * lm_ggml_elu_inplace( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_unary_inplace(ctx, a, LM_GGML_UNARY_OP_ELU); + result->nb[0] = nb[0]; + result->nb[1] = nb[1]; + result->nb[2] = nb[2]; + result->nb[3] = nb[3]; + + result->op = LM_GGML_OP_PERMUTE; + result->src[0] = a; + + int32_t params[] = { axis0, axis1, axis2, axis3 }; + lm_ggml_set_op_params(result, params, sizeof(params)); + + return result; } -// lm_ggml_relu +// lm_ggml_transpose -struct lm_ggml_tensor * lm_ggml_relu( +struct lm_ggml_tensor * lm_ggml_transpose( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a) { - return lm_ggml_unary(ctx, a, LM_GGML_UNARY_OP_RELU); + struct lm_ggml_tensor * result = lm_ggml_view_tensor(ctx, a); + lm_ggml_format_name(result, "%s (transposed)", a->name); + + result->ne[0] = a->ne[1]; + result->ne[1] = a->ne[0]; + + result->nb[0] = a->nb[1]; + result->nb[1] = a->nb[0]; + + result->op = LM_GGML_OP_TRANSPOSE; + result->src[0] = a; + + return result; } -struct lm_ggml_tensor * lm_ggml_relu_inplace( +// lm_ggml_get_rows + +struct lm_ggml_tensor * lm_ggml_get_rows( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_unary_inplace(ctx, a, LM_GGML_UNARY_OP_RELU); + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b) { + LM_GGML_ASSERT(a->ne[2] == b->ne[1]); + LM_GGML_ASSERT(b->ne[3] == 1); + LM_GGML_ASSERT(b->type == LM_GGML_TYPE_I32); + + // TODO: implement non F32 return + enum lm_ggml_type type = LM_GGML_TYPE_F32; + if (a->type == LM_GGML_TYPE_I32) { + type = a->type; + } + struct lm_ggml_tensor * result = lm_ggml_new_tensor_4d(ctx, type, a->ne[0], b->ne[0], b->ne[1], b->ne[2]); + + result->op = LM_GGML_OP_GET_ROWS; + result->src[0] = a; + result->src[1] = b; + + return result; } -// lm_ggml_leaky_relu +// lm_ggml_get_rows_back -struct lm_ggml_tensor * lm_ggml_leaky_relu( +struct lm_ggml_tensor * lm_ggml_get_rows_back( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - float negative_slope, - bool inplace) { - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); + struct lm_ggml_tensor * b, + struct lm_ggml_tensor * c) { + LM_GGML_ASSERT(lm_ggml_is_matrix(a) && lm_ggml_is_vector(b) && b->type == LM_GGML_TYPE_I32); + LM_GGML_ASSERT(lm_ggml_is_matrix(c) && (a->ne[0] == c->ne[0])); - lm_ggml_set_op_params(result, &negative_slope, sizeof(negative_slope)); + // TODO: implement non F32 return + //struct lm_ggml_tensor * result = lm_ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]); + struct lm_ggml_tensor * result = lm_ggml_new_tensor_2d(ctx, LM_GGML_TYPE_F32, c->ne[0], c->ne[1]); - result->op = LM_GGML_OP_LEAKY_RELU; + result->op = LM_GGML_OP_GET_ROWS_BACK; result->src[0] = a; + result->src[1] = b; return result; } -// lm_ggml_sigmoid +// lm_ggml_diag -struct lm_ggml_tensor * lm_ggml_sigmoid( +struct lm_ggml_tensor * lm_ggml_diag( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a) { - return lm_ggml_unary(ctx, a, LM_GGML_UNARY_OP_SIGMOID); -} + LM_GGML_ASSERT(a->ne[1] == 1); -struct lm_ggml_tensor * lm_ggml_sigmoid_inplace( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_unary_inplace(ctx, a, LM_GGML_UNARY_OP_SIGMOID); -} + const int64_t ne[4] = { a->ne[0], a->ne[0], a->ne[2], a->ne[3] }; + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, a->type, 4, ne); -// lm_ggml_gelu + result->op = LM_GGML_OP_DIAG; + result->src[0] = a; -struct lm_ggml_tensor * lm_ggml_gelu( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_unary(ctx, a, LM_GGML_UNARY_OP_GELU); + return result; } -struct lm_ggml_tensor * lm_ggml_gelu_inplace( +// lm_ggml_diag_mask_inf + +static struct lm_ggml_tensor * lm_ggml_diag_mask_inf_impl( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_unary_inplace(ctx, a, LM_GGML_UNARY_OP_GELU); -} + struct lm_ggml_tensor * a, + int n_past, + bool inplace) { + struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); -// lm_ggml_gelu_quick + int32_t params[] = { n_past }; + lm_ggml_set_op_params(result, params, sizeof(params)); -struct lm_ggml_tensor * lm_ggml_gelu_quick( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_unary(ctx, a, LM_GGML_UNARY_OP_GELU_QUICK); -} + result->op = LM_GGML_OP_DIAG_MASK_INF; + result->src[0] = a; -struct lm_ggml_tensor * lm_ggml_gelu_quick_inplace( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_unary_inplace(ctx, a, LM_GGML_UNARY_OP_GELU_QUICK); + return result; } -// lm_ggml_silu - -struct lm_ggml_tensor * lm_ggml_silu( +struct lm_ggml_tensor * lm_ggml_diag_mask_inf( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_unary(ctx, a, LM_GGML_UNARY_OP_SILU); + struct lm_ggml_tensor * a, + int n_past) { + return lm_ggml_diag_mask_inf_impl(ctx, a, n_past, false); } -struct lm_ggml_tensor * lm_ggml_silu_inplace( +struct lm_ggml_tensor * lm_ggml_diag_mask_inf_inplace( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_unary_inplace(ctx, a, LM_GGML_UNARY_OP_SILU); + struct lm_ggml_tensor * a, + int n_past) { + return lm_ggml_diag_mask_inf_impl(ctx, a, n_past, true); } -// lm_ggml_silu_back +// lm_ggml_diag_mask_zero -struct lm_ggml_tensor * lm_ggml_silu_back( +static struct lm_ggml_tensor * lm_ggml_diag_mask_zero_impl( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b) { - struct lm_ggml_tensor * result = lm_ggml_dup_tensor(ctx, a); + int n_past, + bool inplace) { + struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - result->op = LM_GGML_OP_SILU_BACK; + int32_t params[] = { n_past }; + lm_ggml_set_op_params(result, params, sizeof(params)); + + result->op = LM_GGML_OP_DIAG_MASK_ZERO; result->src[0] = a; - result->src[1] = b; return result; } -// ggml hardswish - -struct lm_ggml_tensor * lm_ggml_hardswish( +struct lm_ggml_tensor * lm_ggml_diag_mask_zero( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_unary(ctx, a, LM_GGML_UNARY_OP_HARDSWISH); + struct lm_ggml_tensor * a, + int n_past) { + return lm_ggml_diag_mask_zero_impl(ctx, a, n_past, false); } -// ggml hardsigmoid - -struct lm_ggml_tensor * lm_ggml_hardsigmoid( +struct lm_ggml_tensor * lm_ggml_diag_mask_zero_inplace( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_unary(ctx, a, LM_GGML_UNARY_OP_HARDSIGMOID); + struct lm_ggml_tensor * a, + int n_past) { + return lm_ggml_diag_mask_zero_impl(ctx, a, n_past, true); } -// ggml exp +// lm_ggml_soft_max -struct lm_ggml_tensor * lm_ggml_exp( +static struct lm_ggml_tensor * lm_ggml_soft_max_impl( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_unary(ctx, a, LM_GGML_UNARY_OP_EXP); -} + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * mask, + float scale, + float max_bias, + bool inplace) { + LM_GGML_ASSERT(lm_ggml_is_contiguous(a)); -struct lm_ggml_tensor * lm_ggml_exp_inplace( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_unary_inplace(ctx, a, LM_GGML_UNARY_OP_EXP); -} + if (mask) { + LM_GGML_ASSERT(mask->type == LM_GGML_TYPE_F16 || mask->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT(lm_ggml_is_contiguous(mask)); + LM_GGML_ASSERT(lm_ggml_is_matrix(mask)); + LM_GGML_ASSERT(mask->ne[0] == a->ne[0]); + LM_GGML_ASSERT(mask->ne[1] >= a->ne[1]); + } -// lm_ggml_norm + if (max_bias > 0.0f) { + LM_GGML_ASSERT(mask); + } -static struct lm_ggml_tensor * lm_ggml_norm_impl( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - float eps, - bool inplace) { struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - lm_ggml_set_op_params(result, &eps, sizeof(eps)); + float params[] = { scale, max_bias }; + lm_ggml_set_op_params(result, params, sizeof(params)); - result->op = LM_GGML_OP_NORM; + result->op = LM_GGML_OP_SOFT_MAX; result->src[0] = a; + result->src[1] = mask; return result; } -struct lm_ggml_tensor * lm_ggml_norm( +struct lm_ggml_tensor * lm_ggml_soft_max( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - float eps) { - return lm_ggml_norm_impl(ctx, a, eps, false); + struct lm_ggml_tensor * a) { + return lm_ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, false); } -struct lm_ggml_tensor * lm_ggml_norm_inplace( +struct lm_ggml_tensor * lm_ggml_soft_max_inplace( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a) { + return lm_ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, true); +} + +struct lm_ggml_tensor * lm_ggml_soft_max_ext( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - float eps) { - return lm_ggml_norm_impl(ctx, a, eps, true); + struct lm_ggml_tensor * mask, + float scale, + float max_bias) { + return lm_ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false); } -// lm_ggml_rms_norm +// lm_ggml_soft_max_back -static struct lm_ggml_tensor * lm_ggml_rms_norm_impl( +static struct lm_ggml_tensor * lm_ggml_soft_max_back_impl( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - float eps, + struct lm_ggml_tensor * b, bool inplace) { struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - lm_ggml_set_op_params(result, &eps, sizeof(eps)); - - result->op = LM_GGML_OP_RMS_NORM; + result->op = LM_GGML_OP_SOFT_MAX_BACK; result->src[0] = a; + result->src[1] = b; return result; } -struct lm_ggml_tensor * lm_ggml_rms_norm( +struct lm_ggml_tensor * lm_ggml_soft_max_back( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - float eps) { - return lm_ggml_rms_norm_impl(ctx, a, eps, false); + struct lm_ggml_tensor * b) { + return lm_ggml_soft_max_back_impl(ctx, a, b, false); } -struct lm_ggml_tensor * lm_ggml_rms_norm_inplace( +struct lm_ggml_tensor * lm_ggml_soft_max_back_inplace( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - float eps) { - return lm_ggml_rms_norm_impl(ctx, a, eps, true); + struct lm_ggml_tensor * b) { + return lm_ggml_soft_max_back_impl(ctx, a, b, true); } -// lm_ggml_rms_norm_back +// lm_ggml_rope -struct lm_ggml_tensor * lm_ggml_rms_norm_back( +static struct lm_ggml_tensor * lm_ggml_rope_impl( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, struct lm_ggml_tensor * b, - float eps) { - struct lm_ggml_tensor * result = lm_ggml_dup_tensor(ctx, a); - - lm_ggml_set_op_params(result, &eps, sizeof(eps)); - - result->op = LM_GGML_OP_RMS_NORM_BACK; - result->src[0] = a; - result->src[1] = b; + struct lm_ggml_tensor * c, + int n_dims, + int mode, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + bool inplace) { + LM_GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported"); - return result; -} + LM_GGML_ASSERT(lm_ggml_is_vector(b)); + LM_GGML_ASSERT(b->type == LM_GGML_TYPE_I32); + LM_GGML_ASSERT(a->ne[2] == b->ne[0]); -// lm_ggml_group_norm + if (c) { + LM_GGML_ASSERT(c->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT(c->ne[0] >= n_dims / 2); + } -static struct lm_ggml_tensor * lm_ggml_group_norm_impl( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - int n_groups, - float eps, - bool inplace) { struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - lm_ggml_set_op_params_i32(result, 0, n_groups); - lm_ggml_set_op_params_f32(result, 1, eps); + int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig }; + memcpy(params + 5, &freq_base, sizeof(float)); + memcpy(params + 6, &freq_scale, sizeof(float)); + memcpy(params + 7, &ext_factor, sizeof(float)); + memcpy(params + 8, &attn_factor, sizeof(float)); + memcpy(params + 9, &beta_fast, sizeof(float)); + memcpy(params + 10, &beta_slow, sizeof(float)); + lm_ggml_set_op_params(result, params, sizeof(params)); - result->op = LM_GGML_OP_GROUP_NORM; + result->op = LM_GGML_OP_ROPE; result->src[0] = a; + result->src[1] = b; + result->src[2] = c; return result; } -struct lm_ggml_tensor * lm_ggml_group_norm( +struct lm_ggml_tensor * lm_ggml_rope( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - int n_groups, - float eps) { - return lm_ggml_group_norm_impl(ctx, a, n_groups, eps, false); + struct lm_ggml_tensor * b, + int n_dims, + int mode) { + return lm_ggml_rope_impl( + ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false + ); } -struct lm_ggml_tensor * lm_ggml_group_norm_inplace( +struct lm_ggml_tensor * lm_ggml_rope_inplace( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - int n_groups, - float eps) { - return lm_ggml_group_norm_impl(ctx, a, n_groups, eps, true); + struct lm_ggml_tensor * b, + int n_dims, + int mode) { + return lm_ggml_rope_impl( + ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true + ); } -// lm_ggml_mul_mat - -struct lm_ggml_tensor * lm_ggml_mul_mat( +struct lm_ggml_tensor * lm_ggml_rope_ext( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b) { - LM_GGML_ASSERT(lm_ggml_can_mul_mat(a, b)); - LM_GGML_ASSERT(!lm_ggml_is_transposed(a)); - - const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] }; - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); - - result->op = LM_GGML_OP_MUL_MAT; - result->src[0] = a; - result->src[1] = b; - - return result; + struct lm_ggml_tensor * b, + struct lm_ggml_tensor * c, + int n_dims, + int mode, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow) { + return lm_ggml_rope_impl( + ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow, false + ); } -void lm_ggml_mul_mat_set_prec( - struct lm_ggml_tensor * a, - enum lm_ggml_prec prec) { - LM_GGML_ASSERT(a->op == LM_GGML_OP_MUL_MAT); - - const int32_t prec_i32 = (int32_t) prec; +struct lm_ggml_tensor * lm_ggml_rope_ext_inplace( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + struct lm_ggml_tensor * c, + int n_dims, + int mode, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow) { + return lm_ggml_rope_impl( + ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow, true + ); +} - lm_ggml_set_op_params_i32(a, 0, prec_i32); +struct lm_ggml_tensor * lm_ggml_rope_custom( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + int n_dims, + int mode, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow) { + return lm_ggml_rope_impl( + ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow, false + ); } -// lm_ggml_mul_mat_id +struct lm_ggml_tensor * lm_ggml_rope_custom_inplace( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + int n_dims, + int mode, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow) { + return lm_ggml_rope_impl( + ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow, true + ); +} -/* - c = lm_ggml_mul_mat_id(ctx, as, b, ids); +// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get +// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` +static float lm_ggml_rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) { + return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float)M_PI)) / (2 * logf(base)); +} - as -> [cols, rows, n_expert] - ids -> [n_experts_used, n_tokens] (i32) - b -> [cols, n_expert_used, n_tokens] - c -> [rows, n_expert_used, n_tokens] +void lm_ggml_rope_yarn_corr_dims( + int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2] +) { + // start and end correction dims + float start = floorf(lm_ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base)); + float end = ceilf(lm_ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base)); + dims[0] = MAX(0, start); + dims[1] = MIN(n_dims - 1, end); +} - in b, n_experts_used can be broadcasted to match the n_expert_used of ids +// lm_ggml_rope_back - c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e,t in ids -*/ -struct lm_ggml_tensor * lm_ggml_mul_mat_id( +struct lm_ggml_tensor * lm_ggml_rope_back( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * as, + struct lm_ggml_tensor * a, struct lm_ggml_tensor * b, - struct lm_ggml_tensor * ids) { - LM_GGML_ASSERT(!lm_ggml_is_transposed(as)); - LM_GGML_ASSERT(ids->type == LM_GGML_TYPE_I32); + struct lm_ggml_tensor * c, + int n_dims, + int mode, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow) { + LM_GGML_ASSERT(lm_ggml_is_vector(b)); + LM_GGML_ASSERT(b->type == LM_GGML_TYPE_I32); + LM_GGML_ASSERT(a->ne[2] == b->ne[0]); - LM_GGML_ASSERT(as->ne[3] == 1); // as is 3d (one matrix per expert) - LM_GGML_ASSERT(b->ne[3] == 1); // b is 3d - LM_GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d - LM_GGML_ASSERT(ids->ne[1] == b->ne[2]); // must have an expert list per b row - LM_GGML_ASSERT(as->ne[0] == b->ne[0]); // can_mul_mat - LM_GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast + struct lm_ggml_tensor * result = lm_ggml_dup_tensor(ctx, a); - const int64_t ne[4] = { as->ne[1], ids->ne[0], b->ne[2], 1 }; - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); + int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig }; + memcpy(params + 5, &freq_base, sizeof(float)); + memcpy(params + 6, &freq_scale, sizeof(float)); + memcpy(params + 7, &ext_factor, sizeof(float)); + memcpy(params + 8, &attn_factor, sizeof(float)); + memcpy(params + 9, &beta_fast, sizeof(float)); + memcpy(params + 10, &beta_slow, sizeof(float)); + lm_ggml_set_op_params(result, params, sizeof(params)); - result->op = LM_GGML_OP_MUL_MAT_ID; - result->src[0] = as; + result->op = LM_GGML_OP_ROPE_BACK; + result->src[0] = a; result->src[1] = b; - result->src[2] = ids; + result->src[2] = c; return result; } -// lm_ggml_out_prod +// lm_ggml_clamp -struct lm_ggml_tensor * lm_ggml_out_prod( +struct lm_ggml_tensor * lm_ggml_clamp( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b) { - LM_GGML_ASSERT(lm_ggml_can_out_prod(a, b)); - LM_GGML_ASSERT(!lm_ggml_is_transposed(a)); + float min, + float max) { + // TODO: when implement backward, fix this: + struct lm_ggml_tensor * result = lm_ggml_view_tensor(ctx, a); - // a is broadcastable to b for ne[2] and ne[3] -> use b->ne[2] and b->ne[3] - const int64_t ne[4] = { a->ne[0], b->ne[0], b->ne[2], b->ne[3] }; - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); + float params[] = { min, max }; + lm_ggml_set_op_params(result, params, sizeof(params)); - result->op = LM_GGML_OP_OUT_PROD; + result->op = LM_GGML_OP_CLAMP; result->src[0] = a; - result->src[1] = b; return result; } -// lm_ggml_scale +// lm_ggml_conv_1d -static struct lm_ggml_tensor * lm_ggml_scale_impl( +static int64_t lm_ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) { + return (ins + 2 * p - d * (ks - 1) - 1) / s + 1; +} + +LM_GGML_API struct lm_ggml_tensor * lm_ggml_conv_1d( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - float s, - bool inplace) { - LM_GGML_ASSERT(lm_ggml_is_padded_1d(a)); - - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); + struct lm_ggml_tensor * b, + int s0, + int p0, + int d0) { + struct lm_ggml_tensor * im2col = lm_ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false, LM_GGML_TYPE_F16); // [N, OL, IC * K] - lm_ggml_set_op_params(result, &s, sizeof(s)); + struct lm_ggml_tensor * result = + lm_ggml_mul_mat(ctx, + lm_ggml_reshape_2d(ctx, im2col, im2col->ne[0], (im2col->ne[2] * im2col->ne[1])), // [N, OL, IC * K] => [N*OL, IC * K] + lm_ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2])); // [OC,IC, K] => [OC, IC * K] - result->op = LM_GGML_OP_SCALE; - result->src[0] = a; + result = lm_ggml_reshape_3d(ctx, result, im2col->ne[1], a->ne[2], im2col->ne[2]); // [N, OC, OL] return result; } -struct lm_ggml_tensor * lm_ggml_scale( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - float s) { - return lm_ggml_scale_impl(ctx, a, s, false); -} +// lm_ggml_conv_1d_ph -struct lm_ggml_tensor * lm_ggml_scale_inplace( +struct lm_ggml_tensor* lm_ggml_conv_1d_ph( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - float s) { - return lm_ggml_scale_impl(ctx, a, s, true); + struct lm_ggml_tensor * b, + int s, + int d) { + return lm_ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d); } -// lm_ggml_set +// lm_ggml_conv_transpose_1d -static struct lm_ggml_tensor * lm_ggml_set_impl( +static int64_t lm_ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) { + return (ins - 1) * s - 2 * p + d * (ks - 1) + 1; +} + +LM_GGML_API struct lm_ggml_tensor * lm_ggml_conv_transpose_1d( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, struct lm_ggml_tensor * b, - size_t nb1, - size_t nb2, - size_t nb3, - size_t offset, - bool inplace) { - LM_GGML_ASSERT(lm_ggml_nelements(a) >= lm_ggml_nelements(b)); + int s0, + int p0, + int d0) { + LM_GGML_ASSERT(lm_ggml_is_matrix(b)); + LM_GGML_ASSERT(a->ne[2] == b->ne[1]); + LM_GGML_ASSERT(a->ne[3] == 1); - // make a view of the destination - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); + LM_GGML_ASSERT(p0 == 0); + LM_GGML_ASSERT(d0 == 1); - LM_GGML_ASSERT(offset < (size_t)(1 << 30)); - int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 }; + const int64_t ne[4] = { + lm_ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/), + a->ne[1], b->ne[2], 1, + }; + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); + + int32_t params[] = { s0, p0, d0 }; lm_ggml_set_op_params(result, params, sizeof(params)); - result->op = LM_GGML_OP_SET; + result->op = LM_GGML_OP_CONV_TRANSPOSE_1D; result->src[0] = a; result->src[1] = b; return result; } -struct lm_ggml_tensor * lm_ggml_set( +// lm_ggml_conv_depthwise + +struct lm_ggml_tensor * lm_ggml_conv_depthwise_2d( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, struct lm_ggml_tensor * b, - size_t nb1, - size_t nb2, - size_t nb3, - size_t offset) { - return lm_ggml_set_impl(ctx, a, b, nb1, nb2, nb3, offset, false); -} - -struct lm_ggml_tensor * lm_ggml_set_inplace( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - size_t nb1, - size_t nb2, - size_t nb3, - size_t offset) { - return lm_ggml_set_impl(ctx, a, b, nb1, nb2, nb3, offset, true); -} + int s0, + int s1, + int p0, + int p1, + int d0, + int d1) { + struct lm_ggml_tensor * new_a = lm_ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]); + struct lm_ggml_tensor * im2col = lm_ggml_im2col(ctx, new_a, + lm_ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]), + s0, s1, p0, p1, d0, d1, true, LM_GGML_TYPE_F16); // [N * IC, OH, OW, KH * KW] + struct lm_ggml_tensor * new_b = lm_ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3]); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW] -struct lm_ggml_tensor * lm_ggml_set_1d( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - size_t offset) { - return lm_ggml_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, false); -} + new_a = lm_ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2], new_a->ne[3], 1); // [OC,1, KH, KW] => [1, OC, 1, KH * KW] + struct lm_ggml_tensor * result = lm_ggml_mul_mat(ctx, new_a, new_b); + result = lm_ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], b->ne[2], b->ne[3]); // [N, OC, OH, OW] -struct lm_ggml_tensor * lm_ggml_set_1d_inplace( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - size_t offset) { - return lm_ggml_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, true); + return result; } +// lm_ggml_conv_2d -struct lm_ggml_tensor * lm_ggml_set_2d( +// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] +// a: [OC,IC, KH, KW] +// b: [N, IC, IH, IW] +// result: [N, OH, OW, IC*KH*KW] +struct lm_ggml_tensor * lm_ggml_im2col( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, struct lm_ggml_tensor * b, - size_t nb1, - size_t offset) { - return lm_ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, false); -} + int s0, + int s1, + int p0, + int p1, + int d0, + int d1, + bool is_2D, + enum lm_ggml_type dst_type) { + if(is_2D) { + LM_GGML_ASSERT(a->ne[2] == b->ne[2]); + } else { + LM_GGML_ASSERT(a->ne[1] == b->ne[1]); + LM_GGML_ASSERT(b->ne[3] == 1); + } -struct lm_ggml_tensor * lm_ggml_set_2d_inplace( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - size_t nb1, - size_t offset) { - return lm_ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, true); -} + const int64_t OH = is_2D ? lm_ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0; + const int64_t OW = lm_ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0); -// lm_ggml_cpy + LM_GGML_ASSERT((!is_2D || OH > 0) && "b too small compared to a"); + LM_GGML_ASSERT((OW > 0) && "b too small compared to a"); -static struct lm_ggml_tensor * lm_ggml_cpy_impl( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b) { - LM_GGML_ASSERT(lm_ggml_nelements(a) == lm_ggml_nelements(b)); + const int64_t ne[4] = { + is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0], + OW, + is_2D ? OH : b->ne[2], + is_2D ? b->ne[3] : 1, + }; - // make a view of the destination - struct lm_ggml_tensor * result = lm_ggml_view_tensor(ctx, b); - if (strlen(b->name) > 0) { - lm_ggml_format_name(result, "%s (copy of %s)", b->name, a->name); - } else { - lm_ggml_format_name(result, "%s (copy)", a->name); - } + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, dst_type, 4, ne); + int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) }; + lm_ggml_set_op_params(result, params, sizeof(params)); - result->op = LM_GGML_OP_CPY; + result->op = LM_GGML_OP_IM2COL; result->src[0] = a; result->src[1] = b; return result; } -struct lm_ggml_tensor * lm_ggml_cpy( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b) { - return lm_ggml_cpy_impl(ctx, a, b); -} - -struct lm_ggml_tensor * lm_ggml_cast( +struct lm_ggml_tensor * lm_ggml_im2col_back( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - enum lm_ggml_type type) { - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, type, LM_GGML_MAX_DIMS, a->ne); - lm_ggml_format_name(result, "%s (copy)", a->name); + struct lm_ggml_tensor * b, + int64_t * ne, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1, + bool is_2D) { + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); + int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) }; + lm_ggml_set_op_params(result, params, sizeof(params)); - result->op = LM_GGML_OP_CPY; + result->op = LM_GGML_OP_IM2COL_BACK; result->src[0] = a; - result->src[1] = result; + result->src[1] = b; return result; } -// lm_ggml_cont - -static struct lm_ggml_tensor * lm_ggml_cont_impl( +// a: [OC,IC, KH, KW] +// b: [N, IC, IH, IW] +// result: [N, OC, OH, OW] +struct lm_ggml_tensor * lm_ggml_conv_2d( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - struct lm_ggml_tensor * result = lm_ggml_dup_tensor(ctx, a); - lm_ggml_format_name(result, "%s (cont)", a->name); + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1) { + struct lm_ggml_tensor * im2col = lm_ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, a->type); // [N, OH, OW, IC * KH * KW] + + struct lm_ggml_tensor * result = + lm_ggml_mul_mat(ctx, + lm_ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N, OH, OW, IC * KH * KW] => [N*OH*OW, IC * KH * KW] + lm_ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2]), a->ne[3])); // [OC,IC, KH, KW] => [OC, IC * KH * KW] + + result = lm_ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], im2col->ne[3], a->ne[3]); // [OC, N, OH, OW] + result = lm_ggml_cont(ctx, lm_ggml_permute(ctx, result, 0, 1, 3, 2)); // [N, OC, OH, OW] - result->op = LM_GGML_OP_CONT; - result->src[0] = a; return result; } -struct lm_ggml_tensor * lm_ggml_cont( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_cont_impl(ctx, a); -} +// lm_ggml_conv_2d_sk_p0 -// make contiguous, with new shape -LM_GGML_API struct lm_ggml_tensor * lm_ggml_cont_1d( +struct lm_ggml_tensor * lm_ggml_conv_2d_sk_p0( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - int64_t ne0) { - return lm_ggml_cont_4d(ctx, a, ne0, 1, 1, 1); + struct lm_ggml_tensor * b) { + return lm_ggml_conv_2d(ctx, a, b, a->ne[0], a->ne[1], 0, 0, 1, 1); } -LM_GGML_API struct lm_ggml_tensor * lm_ggml_cont_2d( +// lm_ggml_conv_2d_s1_ph + +struct lm_ggml_tensor * lm_ggml_conv_2d_s1_ph( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - int64_t ne0, - int64_t ne1) { - return lm_ggml_cont_4d(ctx, a, ne0, ne1, 1, 1); + struct lm_ggml_tensor * b) { + return lm_ggml_conv_2d(ctx, a, b, 1, 1, a->ne[0] / 2, a->ne[1] / 2, 1, 1); } -LM_GGML_API struct lm_ggml_tensor * lm_ggml_cont_3d( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - int64_t ne0, - int64_t ne1, - int64_t ne2) { - return lm_ggml_cont_4d(ctx, a, ne0, ne1, ne2, 1); +// lm_ggml_conv_transpose_2d_p0 + +static int64_t lm_ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) { + return (ins - 1) * s - 2 * p + ks; } -struct lm_ggml_tensor * lm_ggml_cont_4d( +struct lm_ggml_tensor * lm_ggml_conv_transpose_2d_p0( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - int64_t ne0, - int64_t ne1, - int64_t ne2, - int64_t ne3) { - LM_GGML_ASSERT(lm_ggml_nelements(a) == (ne0*ne1*ne2*ne3)); + struct lm_ggml_tensor * b, + int stride) { + LM_GGML_ASSERT(a->ne[3] == b->ne[2]); - struct lm_ggml_tensor * result = lm_ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3); - lm_ggml_format_name(result, "%s (cont)", a->name); + const int64_t ne[4] = { + lm_ggml_calc_conv_transpose_output_size(b->ne[0], a->ne[0], stride, 0 /*p0*/), + lm_ggml_calc_conv_transpose_output_size(b->ne[1], a->ne[1], stride, 0 /*p1*/), + a->ne[2], b->ne[3], + }; - result->op = LM_GGML_OP_CONT; + struct lm_ggml_tensor* result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); + + lm_ggml_set_op_params_i32(result, 0, stride); + + result->op = LM_GGML_OP_CONV_TRANSPOSE_2D; result->src[0] = a; + result->src[1] = b; return result; } -// lm_ggml_reshape +// lm_ggml_pool_* -struct lm_ggml_tensor * lm_ggml_reshape( +static int64_t lm_ggml_calc_pool_output_size(int64_t ins, int ks, int s, float p) { + return (ins + 2 * p - ks) / s + 1; +} + +// lm_ggml_pool_1d + +struct lm_ggml_tensor * lm_ggml_pool_1d( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b) { - LM_GGML_ASSERT(lm_ggml_is_contiguous(a)); - // as only the shape of b is relevant, and not its memory layout, b is allowed to be non contiguous. - LM_GGML_ASSERT(lm_ggml_nelements(a) == lm_ggml_nelements(b)); + struct lm_ggml_tensor * a, + enum lm_ggml_op_pool op, + int k0, + int s0, + int p0) { + const int64_t ne[4] = { + lm_ggml_calc_pool_output_size(a->ne[0], k0, s0, p0), + a->ne[1], + a->ne[2], + a->ne[3], + }; + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); - struct lm_ggml_tensor * result = lm_ggml_new_tensor_impl(ctx, a->type, LM_GGML_MAX_DIMS, b->ne, a, 0); - lm_ggml_format_name(result, "%s (reshaped)", a->name); + int32_t params[] = { op, k0, s0, p0 }; + lm_ggml_set_op_params(result, params, sizeof(params)); - result->op = LM_GGML_OP_RESHAPE; + result->op = LM_GGML_OP_POOL_1D; result->src[0] = a; return result; } -struct lm_ggml_tensor * lm_ggml_reshape_1d( +// lm_ggml_pool_2d + +struct lm_ggml_tensor * lm_ggml_pool_2d( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - int64_t ne0) { - LM_GGML_ASSERT(lm_ggml_is_contiguous(a)); - LM_GGML_ASSERT(lm_ggml_nelements(a) == ne0); + enum lm_ggml_op_pool op, + int k0, + int k1, + int s0, + int s1, + float p0, + float p1) { + struct lm_ggml_tensor * result; + const int64_t ne[4] = { + lm_ggml_calc_pool_output_size(a->ne[0], k0, s0, p0), + lm_ggml_calc_pool_output_size(a->ne[1], k1, s1, p1), + a->ne[2], + a->ne[3], + }; + result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); - const int64_t ne[1] = { ne0 }; - struct lm_ggml_tensor * result = lm_ggml_new_tensor_impl(ctx, a->type, 1, ne, a, 0); - lm_ggml_format_name(result, "%s (reshaped)", a->name); + int32_t params[] = { op, k0, k1, s0, s1, p0, p1 }; + lm_ggml_set_op_params(result, params, sizeof(params)); - result->op = LM_GGML_OP_RESHAPE; + result->op = LM_GGML_OP_POOL_2D; result->src[0] = a; return result; } -struct lm_ggml_tensor * lm_ggml_reshape_2d( +struct lm_ggml_tensor * lm_ggml_pool_2d_back( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - int64_t ne0, - int64_t ne1) { - LM_GGML_ASSERT(lm_ggml_is_contiguous(a)); - LM_GGML_ASSERT(lm_ggml_nelements(a) == ne0*ne1); + struct lm_ggml_tensor * af, + enum lm_ggml_op_pool op, + int k0, + int k1, + int s0, + int s1, + float p0, + float p1) { + struct lm_ggml_tensor * result; + result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, af->ne); - const int64_t ne[2] = { ne0, ne1 }; - struct lm_ggml_tensor * result = lm_ggml_new_tensor_impl(ctx, a->type, 2, ne, a, 0); - lm_ggml_format_name(result, "%s (reshaped)", a->name); + int32_t params[] = { op, k0, k1, s0, s1, p0, p1 }; + lm_ggml_set_op_params(result, params, sizeof(params)); - result->op = LM_GGML_OP_RESHAPE; + result->op = LM_GGML_OP_POOL_2D_BACK; result->src[0] = a; + result->src[1] = af; return result; } -struct lm_ggml_tensor * lm_ggml_reshape_3d( +// lm_ggml_upscale + +static struct lm_ggml_tensor * lm_ggml_upscale_impl( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - int64_t ne0, - int64_t ne1, - int64_t ne2) { - LM_GGML_ASSERT(lm_ggml_is_contiguous(a)); - LM_GGML_ASSERT(lm_ggml_nelements(a) == ne0*ne1*ne2); + int ne0, + int ne1, + int ne2, + int ne3) { + LM_GGML_ASSERT(a->ne[0] <= ne0); + LM_GGML_ASSERT(a->ne[1] <= ne1); + LM_GGML_ASSERT(a->ne[2] <= ne2); + LM_GGML_ASSERT(a->ne[3] <= ne3); - const int64_t ne[3] = { ne0, ne1, ne2 }; - struct lm_ggml_tensor * result = lm_ggml_new_tensor_impl(ctx, a->type, 3, ne, a, 0); - lm_ggml_format_name(result, "%s (reshaped)", a->name); + struct lm_ggml_tensor * result = lm_ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3); - result->op = LM_GGML_OP_RESHAPE; + result->op = LM_GGML_OP_UPSCALE; result->src[0] = a; return result; } -struct lm_ggml_tensor * lm_ggml_reshape_4d( +struct lm_ggml_tensor * lm_ggml_upscale( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - int64_t ne0, - int64_t ne1, - int64_t ne2, - int64_t ne3) { - LM_GGML_ASSERT(lm_ggml_is_contiguous(a)); - LM_GGML_ASSERT(lm_ggml_nelements(a) == ne0*ne1*ne2*ne3); - - const int64_t ne[4] = { ne0, ne1, ne2, ne3 }; - struct lm_ggml_tensor * result = lm_ggml_new_tensor_impl(ctx, a->type, 4, ne, a, 0); - lm_ggml_format_name(result, "%s (reshaped)", a->name); - - result->op = LM_GGML_OP_RESHAPE; - result->src[0] = a; - - return result; + int scale_factor) { + return lm_ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3]); } -static struct lm_ggml_tensor * lm_ggml_view_impl( +struct lm_ggml_tensor * lm_ggml_upscale_ext( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - int n_dims, - const int64_t * ne, - size_t offset) { - struct lm_ggml_tensor * result = lm_ggml_new_tensor_impl(ctx, a->type, n_dims, ne, a, offset); - lm_ggml_format_name(result, "%s (view)", a->name); + int ne0, + int ne1, + int ne2, + int ne3) { + return lm_ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3); +} - lm_ggml_set_op_params(result, &offset, sizeof(offset)); +// lm_ggml_pad - result->op = LM_GGML_OP_VIEW; +struct lm_ggml_tensor * lm_ggml_pad( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + int p0, + int p1, + int p2, + int p3) { + struct lm_ggml_tensor * result = lm_ggml_new_tensor_4d(ctx, a->type, + a->ne[0] + p0, + a->ne[1] + p1, + a->ne[2] + p2, + a->ne[3] + p3); + + result->op = LM_GGML_OP_PAD; result->src[0] = a; return result; } -// lm_ggml_view_1d +// lm_ggml_arange -struct lm_ggml_tensor * lm_ggml_view_1d( +struct lm_ggml_tensor * lm_ggml_arange( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - int64_t ne0, - size_t offset) { - struct lm_ggml_tensor * result = lm_ggml_view_impl(ctx, a, 1, &ne0, offset); + float start, + float stop, + float step) { + LM_GGML_ASSERT(stop > start); + + const int64_t steps = (int64_t) ceilf((stop - start) / step); + + struct lm_ggml_tensor * result = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_F32, steps); + + lm_ggml_set_op_params_f32(result, 0, start); + lm_ggml_set_op_params_f32(result, 1, stop); + lm_ggml_set_op_params_f32(result, 2, step); + + result->op = LM_GGML_OP_ARANGE; return result; } -// lm_ggml_view_2d +// lm_ggml_timestep_embedding -struct lm_ggml_tensor * lm_ggml_view_2d( +struct lm_ggml_tensor * lm_ggml_timestep_embedding( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - int64_t ne0, - int64_t ne1, - size_t nb1, - size_t offset) { - const int64_t ne[2] = { ne0, ne1 }; + struct lm_ggml_tensor * timesteps, + int dim, + int max_period) { + int actual_dim = dim; + if (dim % 2 != 0) { + actual_dim = dim + 1; + } - struct lm_ggml_tensor * result = lm_ggml_view_impl(ctx, a, 2, ne, offset); + struct lm_ggml_tensor * result = lm_ggml_new_tensor_2d(ctx, LM_GGML_TYPE_F32, actual_dim, timesteps->ne[0]); - result->nb[1] = nb1; - result->nb[2] = result->nb[1]*ne1; - result->nb[3] = result->nb[2]; + lm_ggml_set_op_params_i32(result, 0, dim); + lm_ggml_set_op_params_i32(result, 1, max_period); + + result->op = LM_GGML_OP_TIMESTEP_EMBEDDING; + result->src[0] = timesteps; return result; } -// lm_ggml_view_3d +// lm_ggml_argsort -struct lm_ggml_tensor * lm_ggml_view_3d( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - int64_t ne0, - int64_t ne1, - int64_t ne2, - size_t nb1, - size_t nb2, - size_t offset) { - const int64_t ne[3] = { ne0, ne1, ne2 }; +struct lm_ggml_tensor * lm_ggml_argsort( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + enum lm_ggml_sort_order order) { + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_I32, LM_GGML_MAX_DIMS, a->ne); - struct lm_ggml_tensor * result = lm_ggml_view_impl(ctx, a, 3, ne, offset); + lm_ggml_set_op_params_i32(result, 0, (int32_t) order); - result->nb[1] = nb1; - result->nb[2] = nb2; - result->nb[3] = result->nb[2]*ne2; + result->op = LM_GGML_OP_ARGSORT; + result->src[0] = a; return result; } -// lm_ggml_view_4d +// lm_ggml_top_k -struct lm_ggml_tensor * lm_ggml_view_4d( +struct lm_ggml_tensor * lm_ggml_top_k( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - int64_t ne0, - int64_t ne1, - int64_t ne2, - int64_t ne3, - size_t nb1, - size_t nb2, - size_t nb3, - size_t offset) { - const int64_t ne[4] = { ne0, ne1, ne2, ne3 }; + int k) { + LM_GGML_ASSERT(a->ne[0] >= k); - struct lm_ggml_tensor * result = lm_ggml_view_impl(ctx, a, 4, ne, offset); + struct lm_ggml_tensor * result = lm_ggml_argsort(ctx, a, LM_GGML_SORT_ORDER_DESC); - result->nb[1] = nb1; - result->nb[2] = nb2; - result->nb[3] = nb3; + result = lm_ggml_view_4d(ctx, result, + k, result->ne[1], result->ne[2], result->ne[3], + result->nb[1], result->nb[2], result->nb[3], + 0); return result; } -// lm_ggml_permute +// lm_ggml_flash_attn_ext -struct lm_ggml_tensor * lm_ggml_permute( +struct lm_ggml_tensor * lm_ggml_flash_attn_ext( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - int axis0, - int axis1, - int axis2, - int axis3) { - LM_GGML_ASSERT(axis0 >= 0 && axis0 < LM_GGML_MAX_DIMS); - LM_GGML_ASSERT(axis1 >= 0 && axis1 < LM_GGML_MAX_DIMS); - LM_GGML_ASSERT(axis2 >= 0 && axis2 < LM_GGML_MAX_DIMS); - LM_GGML_ASSERT(axis3 >= 0 && axis3 < LM_GGML_MAX_DIMS); + struct lm_ggml_tensor * q, + struct lm_ggml_tensor * k, + struct lm_ggml_tensor * v, + struct lm_ggml_tensor * mask, + float scale, + float max_bias, + float logit_softcap) { + LM_GGML_ASSERT(lm_ggml_can_mul_mat(k, q)); + // TODO: check if vT can be multiplied by (k*qT) - LM_GGML_ASSERT(axis0 != axis1); - LM_GGML_ASSERT(axis0 != axis2); - LM_GGML_ASSERT(axis0 != axis3); - LM_GGML_ASSERT(axis1 != axis2); - LM_GGML_ASSERT(axis1 != axis3); - LM_GGML_ASSERT(axis2 != axis3); + if (mask) { + LM_GGML_ASSERT(lm_ggml_is_contiguous(mask)); + LM_GGML_ASSERT(mask->ne[2] == 1); + LM_GGML_ASSERT(mask->ne[3] == 1); + LM_GGML_ASSERT(mask->ne[1] >= LM_GGML_PAD(q->ne[1], LM_GGML_KQ_MASK_PAD) && + "the Flash-Attention kernel requires the mask to be padded to LM_GGML_KQ_MASK_PAD and at least n_queries big"); + //LM_GGML_ASSERT(lm_ggml_can_repeat_rows(mask, qk)); + } - struct lm_ggml_tensor * result = lm_ggml_view_tensor(ctx, a); - lm_ggml_format_name(result, "%s (permuted)", a->name); + if (max_bias > 0.0f) { + LM_GGML_ASSERT(mask); + } - int ne[LM_GGML_MAX_DIMS]; - int nb[LM_GGML_MAX_DIMS]; + bool is_node = false; - ne[axis0] = a->ne[0]; - ne[axis1] = a->ne[1]; - ne[axis2] = a->ne[2]; - ne[axis3] = a->ne[3]; + // permute(0, 2, 1, 3) + int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); - nb[axis0] = a->nb[0]; - nb[axis1] = a->nb[1]; - nb[axis2] = a->nb[2]; - nb[axis3] = a->nb[3]; + float params[] = { scale, max_bias, logit_softcap }; + lm_ggml_set_op_params(result, params, sizeof(params)); - result->ne[0] = ne[0]; - result->ne[1] = ne[1]; - result->ne[2] = ne[2]; - result->ne[3] = ne[3]; + result->op = LM_GGML_OP_FLASH_ATTN_EXT; + result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = q; + result->src[1] = k; + result->src[2] = v; + result->src[3] = mask; - result->nb[0] = nb[0]; - result->nb[1] = nb[1]; - result->nb[2] = nb[2]; - result->nb[3] = nb[3]; + return result; +} - result->op = LM_GGML_OP_PERMUTE; - result->src[0] = a; +void lm_ggml_flash_attn_ext_set_prec( + struct lm_ggml_tensor * a, + enum lm_ggml_prec prec) { + LM_GGML_ASSERT(a->op == LM_GGML_OP_FLASH_ATTN_EXT); - int32_t params[] = { axis0, axis1, axis2, axis3 }; - lm_ggml_set_op_params(result, params, sizeof(params)); + const int32_t prec_i32 = (int32_t) prec; - return result; + lm_ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second } -// lm_ggml_transpose +enum lm_ggml_prec lm_ggml_flash_attn_ext_get_prec( + const struct lm_ggml_tensor * a) { + LM_GGML_ASSERT(a->op == LM_GGML_OP_FLASH_ATTN_EXT); -struct lm_ggml_tensor * lm_ggml_transpose( + const int32_t prec_i32 = lm_ggml_get_op_params_i32(a, 3); + + return (enum lm_ggml_prec) prec_i32; +} + +// lm_ggml_flash_attn_back + +struct lm_ggml_tensor * lm_ggml_flash_attn_back( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - struct lm_ggml_tensor * result = lm_ggml_view_tensor(ctx, a); - lm_ggml_format_name(result, "%s (transposed)", a->name); + struct lm_ggml_tensor * q, + struct lm_ggml_tensor * k, + struct lm_ggml_tensor * v, + struct lm_ggml_tensor * d, + bool masked) { + LM_GGML_ABORT("TODO: adapt to lm_ggml_flash_attn_ext() changes"); - result->ne[0] = a->ne[1]; - result->ne[1] = a->ne[0]; + LM_GGML_ASSERT(lm_ggml_can_mul_mat(k, q)); + // TODO: check if vT can be multiplied by (k*qT) - result->nb[0] = a->nb[1]; - result->nb[1] = a->nb[0]; + // d shape [D,N,ne2,ne3] + // q shape [D,N,ne2,ne3] + // k shape [D,M,kvne2,ne3] + // v shape [M,D,kvne2,ne3] - result->op = LM_GGML_OP_TRANSPOSE; - result->src[0] = a; + const int64_t D = q->ne[0]; + const int64_t N = q->ne[1]; + const int64_t M = k->ne[1]; + const int64_t ne2 = q->ne[2]; + const int64_t ne3 = q->ne[3]; + const int64_t kvne2 = k->ne[2]; - return result; -} + LM_GGML_ASSERT(k->ne[0] == D); + LM_GGML_ASSERT(v->ne[0] == M); + LM_GGML_ASSERT(v->ne[1] == D); + LM_GGML_ASSERT(d->ne[0] == D); + LM_GGML_ASSERT(d->ne[1] == N); + LM_GGML_ASSERT(k->ne[2] == kvne2); + LM_GGML_ASSERT(k->ne[3] == ne3); + LM_GGML_ASSERT(v->ne[2] == kvne2); + LM_GGML_ASSERT(v->ne[3] == ne3); + LM_GGML_ASSERT(d->ne[2] == ne2); + LM_GGML_ASSERT(d->ne[3] == ne3); -// lm_ggml_get_rows + LM_GGML_ASSERT(ne2 % kvne2 == 0); -struct lm_ggml_tensor * lm_ggml_get_rows( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b) { - LM_GGML_ASSERT(a->ne[2] == b->ne[1]); - LM_GGML_ASSERT(b->ne[3] == 1); - LM_GGML_ASSERT(b->type == LM_GGML_TYPE_I32); + bool is_node = false; - // TODO: implement non F32 return - enum lm_ggml_type type = LM_GGML_TYPE_F32; - if (a->type == LM_GGML_TYPE_I32) { - type = a->type; + if (q->grad || k->grad || v->grad) { + // when using this operation (in backwards pass) these grads are set. + // we don't want to create (big) grad of our result, so is_node is false. + is_node = false; } - struct lm_ggml_tensor * result = lm_ggml_new_tensor_4d(ctx, type, a->ne[0], b->ne[0], b->ne[1], b->ne[2]); - result->op = LM_GGML_OP_GET_ROWS; - result->src[0] = a; - result->src[1] = b; + // store gradients of q, k and v as continuous tensors concatenated in result. + // note: v and gradv are actually transposed, i.e. v->ne[0] != D. + const int64_t elem_q = lm_ggml_nelements(q); + const int64_t elem_k = lm_ggml_nelements(k); + const int64_t elem_v = lm_ggml_nelements(v); + + enum lm_ggml_type result_type = LM_GGML_TYPE_F32; + LM_GGML_ASSERT(lm_ggml_blck_size(result_type) == 1); + const size_t tsize = lm_ggml_type_size(result_type); + + const size_t offs_q = 0; + const size_t offs_k = offs_q + LM_GGML_PAD(elem_q * tsize, LM_GGML_MEM_ALIGN); + const size_t offs_v = offs_k + LM_GGML_PAD(elem_k * tsize, LM_GGML_MEM_ALIGN); + const size_t end = offs_v + LM_GGML_PAD(elem_v * tsize, LM_GGML_MEM_ALIGN); + + const size_t nelements = (end + tsize - 1)/tsize; + + struct lm_ggml_tensor * result = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_F32, nelements); + + int32_t masked_i = masked ? 1 : 0; + lm_ggml_set_op_params(result, &masked_i, sizeof(masked_i)); + + result->op = LM_GGML_OP_FLASH_ATTN_BACK; + result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = q; + result->src[1] = k; + result->src[2] = v; + result->src[3] = d; return result; } -// lm_ggml_get_rows_back +// lm_ggml_ssm_conv -struct lm_ggml_tensor * lm_ggml_get_rows_back( +struct lm_ggml_tensor * lm_ggml_ssm_conv( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, + struct lm_ggml_tensor * sx, struct lm_ggml_tensor * c) { - LM_GGML_ASSERT(lm_ggml_is_matrix(a) && lm_ggml_is_vector(b) && b->type == LM_GGML_TYPE_I32); - LM_GGML_ASSERT(lm_ggml_is_matrix(c) && (a->ne[0] == c->ne[0])); + LM_GGML_ASSERT(lm_ggml_is_3d(sx)); + LM_GGML_ASSERT(lm_ggml_is_matrix(c)); - // TODO: implement non F32 return - //struct lm_ggml_tensor * result = lm_ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]); - struct lm_ggml_tensor * result = lm_ggml_new_tensor_2d(ctx, LM_GGML_TYPE_F32, c->ne[0], c->ne[1]); + const int64_t d_conv = c->ne[0]; + const int64_t d_inner = c->ne[1]; + const int64_t n_t = sx->ne[0] - d_conv + 1; // tokens per sequence + const int64_t n_s = sx->ne[2]; - result->op = LM_GGML_OP_GET_ROWS_BACK; - result->src[0] = a; - result->src[1] = b; + // TODO: maybe support other strides than 1? + // FIXME: this is always true? + LM_GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t); + LM_GGML_ASSERT(sx->ne[1] == d_inner); + LM_GGML_ASSERT(n_t >= 0); + + struct lm_ggml_tensor * result = lm_ggml_new_tensor_3d(ctx, LM_GGML_TYPE_F32, d_inner, n_t, n_s); + + result->op = LM_GGML_OP_SSM_CONV; + result->src[0] = sx; + result->src[1] = c; return result; } -// lm_ggml_diag +// lm_ggml_ssm_scan -struct lm_ggml_tensor * lm_ggml_diag( +struct lm_ggml_tensor * lm_ggml_ssm_scan( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - LM_GGML_ASSERT(a->ne[1] == 1); + struct lm_ggml_tensor * s, + struct lm_ggml_tensor * x, + struct lm_ggml_tensor * dt, + struct lm_ggml_tensor * A, + struct lm_ggml_tensor * B, + struct lm_ggml_tensor * C) { + LM_GGML_ASSERT(lm_ggml_is_contiguous(s)); + LM_GGML_ASSERT(lm_ggml_is_contiguous(x)); + LM_GGML_ASSERT(lm_ggml_is_contiguous(dt)); + LM_GGML_ASSERT(lm_ggml_is_contiguous(A)); + LM_GGML_ASSERT(lm_ggml_is_matrix(A)); + LM_GGML_ASSERT(lm_ggml_is_3d(B)); + LM_GGML_ASSERT(lm_ggml_is_3d(s)); + LM_GGML_ASSERT(B->nb[0] == lm_ggml_type_size(B->type)); + LM_GGML_ASSERT(C->nb[0] == lm_ggml_type_size(C->type)); + LM_GGML_ASSERT(lm_ggml_are_same_shape(x, dt)); + LM_GGML_ASSERT(lm_ggml_are_same_shape(B, C)); - const int64_t ne[4] = { a->ne[0], a->ne[0], a->ne[2], a->ne[3] }; - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, a->type, 4, ne); + { + const int64_t d_state = s->ne[0]; + const int64_t d_inner = s->ne[1]; + const int64_t n_seq_tokens = x->ne[1]; + const int64_t n_seqs = x->ne[2]; - result->op = LM_GGML_OP_DIAG; - result->src[0] = a; + LM_GGML_ASSERT(s->ne[2] == n_seqs); + LM_GGML_ASSERT(x->ne[0] == d_inner); + LM_GGML_ASSERT(A->ne[0] == d_state); + LM_GGML_ASSERT(A->ne[1] == d_inner); + LM_GGML_ASSERT(B->ne[0] == d_state); + LM_GGML_ASSERT(B->ne[1] == n_seq_tokens); + LM_GGML_ASSERT(B->ne[2] == n_seqs); + } + + // concatenated y + ssm_states + struct lm_ggml_tensor * result = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_F32, lm_ggml_nelements(x) + lm_ggml_nelements(s)); + + result->op = LM_GGML_OP_SSM_SCAN; + result->src[0] = s; + result->src[1] = x; + result->src[2] = dt; + result->src[3] = A; + result->src[4] = B; + result->src[5] = C; return result; } -// lm_ggml_diag_mask_inf +// lm_ggml_win_part -static struct lm_ggml_tensor * lm_ggml_diag_mask_inf_impl( +struct lm_ggml_tensor * lm_ggml_win_part( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - int n_past, - bool inplace) { - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); + int w) { + LM_GGML_ASSERT(a->ne[3] == 1); + LM_GGML_ASSERT(a->type == LM_GGML_TYPE_F32); - int32_t params[] = { n_past }; + // padding + const int px = (w - a->ne[1]%w)%w; + const int py = (w - a->ne[2]%w)%w; + + const int npx = (px + a->ne[1])/w; + const int npy = (py + a->ne[2])/w; + const int np = npx*npy; + + const int64_t ne[4] = { a->ne[0], w, w, np, }; + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); + + int32_t params[] = { npx, npy, w }; lm_ggml_set_op_params(result, params, sizeof(params)); - result->op = LM_GGML_OP_DIAG_MASK_INF; + result->op = LM_GGML_OP_WIN_PART; result->src[0] = a; return result; } -struct lm_ggml_tensor * lm_ggml_diag_mask_inf( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - int n_past) { - return lm_ggml_diag_mask_inf_impl(ctx, a, n_past, false); -} +// lm_ggml_win_unpart -struct lm_ggml_tensor * lm_ggml_diag_mask_inf_inplace( +struct lm_ggml_tensor * lm_ggml_win_unpart( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - int n_past) { - return lm_ggml_diag_mask_inf_impl(ctx, a, n_past, true); -} - -// lm_ggml_diag_mask_zero + int w0, + int h0, + int w) { + LM_GGML_ASSERT(a->type == LM_GGML_TYPE_F32); -static struct lm_ggml_tensor * lm_ggml_diag_mask_zero_impl( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - int n_past, - bool inplace) { - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); + const int64_t ne[4] = { a->ne[0], w0, h0, 1, }; + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 3, ne); - int32_t params[] = { n_past }; + int32_t params[] = { w }; lm_ggml_set_op_params(result, params, sizeof(params)); - result->op = LM_GGML_OP_DIAG_MASK_ZERO; + result->op = LM_GGML_OP_WIN_UNPART; result->src[0] = a; return result; } -struct lm_ggml_tensor * lm_ggml_diag_mask_zero( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - int n_past) { - return lm_ggml_diag_mask_zero_impl(ctx, a, n_past, false); -} +// lm_ggml_get_rel_pos -struct lm_ggml_tensor * lm_ggml_diag_mask_zero_inplace( +struct lm_ggml_tensor * lm_ggml_get_rel_pos( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - int n_past) { - return lm_ggml_diag_mask_zero_impl(ctx, a, n_past, true); + int qh, + int kh) { + LM_GGML_ASSERT(qh == kh); + LM_GGML_ASSERT(2*MAX(qh, kh) - 1 == a->ne[1]); + + const int64_t ne[4] = { a->ne[0], kh, qh, 1, }; + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F16, 3, ne); + + result->op = LM_GGML_OP_GET_REL_POS; + result->src[0] = a; + + return result; } -// lm_ggml_soft_max +// lm_ggml_add_rel_pos -static struct lm_ggml_tensor * lm_ggml_soft_max_impl( +static struct lm_ggml_tensor * lm_ggml_add_rel_pos_impl( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - struct lm_ggml_tensor * mask, - float scale, - float max_bias, + struct lm_ggml_tensor * pw, + struct lm_ggml_tensor * ph, bool inplace) { + LM_GGML_ASSERT(lm_ggml_are_same_shape(pw, ph)); LM_GGML_ASSERT(lm_ggml_is_contiguous(a)); - - if (mask) { - LM_GGML_ASSERT(mask->type == LM_GGML_TYPE_F16 || mask->type == LM_GGML_TYPE_F32); - LM_GGML_ASSERT(lm_ggml_is_contiguous(mask)); - LM_GGML_ASSERT(lm_ggml_is_matrix(mask)); - LM_GGML_ASSERT(mask->ne[0] == a->ne[0]); - LM_GGML_ASSERT(mask->ne[1] >= a->ne[1]); - } - - if (max_bias > 0.0f) { - LM_GGML_ASSERT(mask); - } + LM_GGML_ASSERT(lm_ggml_is_contiguous(pw)); + LM_GGML_ASSERT(lm_ggml_is_contiguous(ph)); + LM_GGML_ASSERT(ph->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT(pw->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT(pw->ne[3] == a->ne[2]); + LM_GGML_ASSERT(pw->ne[0]*pw->ne[0] == a->ne[0]); + LM_GGML_ASSERT(pw->ne[1]*pw->ne[2] == a->ne[1]); struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); + lm_ggml_set_op_params_i32(result, 0, inplace ? 1 : 0); - float params[] = { scale, max_bias }; - lm_ggml_set_op_params(result, params, sizeof(params)); - - result->op = LM_GGML_OP_SOFT_MAX; + result->op = LM_GGML_OP_ADD_REL_POS; result->src[0] = a; - result->src[1] = mask; + result->src[1] = pw; + result->src[2] = ph; return result; } -struct lm_ggml_tensor * lm_ggml_soft_max( +struct lm_ggml_tensor * lm_ggml_add_rel_pos( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, false); + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * pw, + struct lm_ggml_tensor * ph) { + return lm_ggml_add_rel_pos_impl(ctx, a, pw, ph, false); } -struct lm_ggml_tensor * lm_ggml_soft_max_inplace( +struct lm_ggml_tensor * lm_ggml_add_rel_pos_inplace( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a) { - return lm_ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, true); + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * pw, + struct lm_ggml_tensor * ph) { + return lm_ggml_add_rel_pos_impl(ctx, a, pw, ph, true); } -struct lm_ggml_tensor * lm_ggml_soft_max_ext( +// lm_ggml_rwkv_wkv6 + +struct lm_ggml_tensor * lm_ggml_rwkv_wkv6( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * mask, - float scale, - float max_bias) { - return lm_ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false); + struct lm_ggml_tensor * k, + struct lm_ggml_tensor * v, + struct lm_ggml_tensor * r, + struct lm_ggml_tensor * tf, + struct lm_ggml_tensor * td, + struct lm_ggml_tensor * state) { + LM_GGML_ASSERT(lm_ggml_is_contiguous(k)); + LM_GGML_ASSERT(lm_ggml_is_contiguous(v)); + LM_GGML_ASSERT(lm_ggml_is_contiguous(r)); + LM_GGML_ASSERT(lm_ggml_is_contiguous(tf)); + LM_GGML_ASSERT(lm_ggml_is_contiguous(td)); + LM_GGML_ASSERT(lm_ggml_is_contiguous(state)); + + const int64_t S = k->ne[0]; + const int64_t H = k->ne[2]; + const int64_t n_tokens = k->ne[3]; + const int64_t n_seqs = state->ne[1]; + { + LM_GGML_ASSERT(k->ne[1] == 1); + LM_GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens); + LM_GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens); + // TODO: RWKV v4 and v5 + LM_GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens); + LM_GGML_ASSERT(lm_ggml_nelements(state) == S * S * H * n_seqs); + } + + // concat output and new_state + const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 }; + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); + + result->op = LM_GGML_OP_RWKV_WKV6; + result->src[0] = k; + result->src[1] = v; + result->src[2] = r; + result->src[3] = tf; + result->src[4] = td; + result->src[5] = state; + + return result; } -// lm_ggml_soft_max_back +// lm_ggml_unary -static struct lm_ggml_tensor * lm_ggml_soft_max_back_impl( +static struct lm_ggml_tensor * lm_ggml_unary_impl( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, + enum lm_ggml_unary_op op, bool inplace) { + LM_GGML_ASSERT(lm_ggml_is_contiguous_1(a)); + struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - result->op = LM_GGML_OP_SOFT_MAX_BACK; + lm_ggml_set_op_params_i32(result, 0, (int32_t) op); + + result->op = LM_GGML_OP_UNARY; result->src[0] = a; - result->src[1] = b; return result; } -struct lm_ggml_tensor * lm_ggml_soft_max_back( +struct lm_ggml_tensor * lm_ggml_unary( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b) { - return lm_ggml_soft_max_back_impl(ctx, a, b, false); + enum lm_ggml_unary_op op) { + return lm_ggml_unary_impl(ctx, a, op, false); } -struct lm_ggml_tensor * lm_ggml_soft_max_back_inplace( +struct lm_ggml_tensor * lm_ggml_unary_inplace( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b) { - return lm_ggml_soft_max_back_impl(ctx, a, b, true); + enum lm_ggml_unary_op op) { + return lm_ggml_unary_impl(ctx, a, op, true); } -// lm_ggml_rope - -static struct lm_ggml_tensor * lm_ggml_rope_impl( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - struct lm_ggml_tensor * c, - int n_dims, - int mode, - int n_ctx_orig, - float freq_base, - float freq_scale, - float ext_factor, - float attn_factor, - float beta_fast, - float beta_slow, - bool inplace) { - LM_GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported"); - - LM_GGML_ASSERT(lm_ggml_is_vector(b)); - LM_GGML_ASSERT(b->type == LM_GGML_TYPE_I32); - LM_GGML_ASSERT(a->ne[2] == b->ne[0]); - - if (c) { - LM_GGML_ASSERT(c->type == LM_GGML_TYPE_F32); - LM_GGML_ASSERT(c->ne[0] >= n_dims / 2); - } +// lm_ggml_map_unary +static struct lm_ggml_tensor * lm_ggml_map_unary_impl_f32( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + const lm_ggml_unary_op_f32_t fun, + bool inplace) { struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig }; - memcpy(params + 5, &freq_base, sizeof(float)); - memcpy(params + 6, &freq_scale, sizeof(float)); - memcpy(params + 7, &ext_factor, sizeof(float)); - memcpy(params + 8, &attn_factor, sizeof(float)); - memcpy(params + 9, &beta_fast, sizeof(float)); - memcpy(params + 10, &beta_slow, sizeof(float)); - lm_ggml_set_op_params(result, params, sizeof(params)); + lm_ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); - result->op = LM_GGML_OP_ROPE; + result->op = LM_GGML_OP_MAP_UNARY; result->src[0] = a; - result->src[1] = b; - result->src[2] = c; return result; } -struct lm_ggml_tensor * lm_ggml_rope( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - int n_dims, - int mode) { - return lm_ggml_rope_impl( - ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false - ); +struct lm_ggml_tensor * lm_ggml_map_unary_f32( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + const lm_ggml_unary_op_f32_t fun) { + return lm_ggml_map_unary_impl_f32(ctx, a, fun, false); } -struct lm_ggml_tensor * lm_ggml_rope_inplace( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - int n_dims, - int mode) { - return lm_ggml_rope_impl( - ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true - ); +struct lm_ggml_tensor * lm_ggml_map_unary_inplace_f32( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + const lm_ggml_unary_op_f32_t fun) { + return lm_ggml_map_unary_impl_f32(ctx, a, fun, true); } -struct lm_ggml_tensor * lm_ggml_rope_ext( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - struct lm_ggml_tensor * c, - int n_dims, - int mode, - int n_ctx_orig, - float freq_base, - float freq_scale, - float ext_factor, - float attn_factor, - float beta_fast, - float beta_slow) { - return lm_ggml_rope_impl( - ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow, false - ); -} - -struct lm_ggml_tensor * lm_ggml_rope_ext_inplace( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - struct lm_ggml_tensor * c, - int n_dims, - int mode, - int n_ctx_orig, - float freq_base, - float freq_scale, - float ext_factor, - float attn_factor, - float beta_fast, - float beta_slow) { - return lm_ggml_rope_impl( - ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow, true - ); -} - -struct lm_ggml_tensor * lm_ggml_rope_custom( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - int n_dims, - int mode, - int n_ctx_orig, - float freq_base, - float freq_scale, - float ext_factor, - float attn_factor, - float beta_fast, - float beta_slow) { - return lm_ggml_rope_impl( - ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow, false - ); -} - -struct lm_ggml_tensor * lm_ggml_rope_custom_inplace( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - int n_dims, - int mode, - int n_ctx_orig, - float freq_base, - float freq_scale, - float ext_factor, - float attn_factor, - float beta_fast, - float beta_slow) { - return lm_ggml_rope_impl( - ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow, true - ); -} - -// lm_ggml_rope_back +// lm_ggml_map_binary -struct lm_ggml_tensor * lm_ggml_rope_back( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - struct lm_ggml_tensor * c, - int n_dims, - int mode, - int n_ctx_orig, - float freq_base, - float freq_scale, - float ext_factor, - float attn_factor, - float beta_fast, - float beta_slow) { - LM_GGML_ASSERT(lm_ggml_is_vector(b)); - LM_GGML_ASSERT(b->type == LM_GGML_TYPE_I32); - LM_GGML_ASSERT(a->ne[2] == b->ne[0]); +static struct lm_ggml_tensor * lm_ggml_map_binary_impl_f32( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + const lm_ggml_binary_op_f32_t fun, + bool inplace) { + LM_GGML_ASSERT(lm_ggml_are_same_shape(a, b)); - struct lm_ggml_tensor * result = lm_ggml_dup_tensor(ctx, a); + struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig }; - memcpy(params + 5, &freq_base, sizeof(float)); - memcpy(params + 6, &freq_scale, sizeof(float)); - memcpy(params + 7, &ext_factor, sizeof(float)); - memcpy(params + 8, &attn_factor, sizeof(float)); - memcpy(params + 9, &beta_fast, sizeof(float)); - memcpy(params + 10, &beta_slow, sizeof(float)); - lm_ggml_set_op_params(result, params, sizeof(params)); + lm_ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); - result->op = LM_GGML_OP_ROPE_BACK; + result->op = LM_GGML_OP_MAP_BINARY; result->src[0] = a; result->src[1] = b; - result->src[2] = c; return result; } -// lm_ggml_clamp +struct lm_ggml_tensor * lm_ggml_map_binary_f32( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + const lm_ggml_binary_op_f32_t fun) { + return lm_ggml_map_binary_impl_f32(ctx, a, b, fun, false); +} -struct lm_ggml_tensor * lm_ggml_clamp( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - float min, - float max) { - // TODO: when implement backward, fix this: - struct lm_ggml_tensor * result = lm_ggml_view_tensor(ctx, a); +struct lm_ggml_tensor * lm_ggml_map_binary_inplace_f32( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + const lm_ggml_binary_op_f32_t fun) { + return lm_ggml_map_binary_impl_f32(ctx, a, b, fun, true); +} - float params[] = { min, max }; - lm_ggml_set_op_params(result, params, sizeof(params)); +// lm_ggml_map_custom1_f32 - result->op = LM_GGML_OP_CLAMP; +static struct lm_ggml_tensor * lm_ggml_map_custom1_impl_f32( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + const lm_ggml_custom1_op_f32_t fun, + bool inplace) { + struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); + + lm_ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); + + result->op = LM_GGML_OP_MAP_CUSTOM1_F32; result->src[0] = a; return result; } -// lm_ggml_conv_1d +struct lm_ggml_tensor * lm_ggml_map_custom1_f32( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + const lm_ggml_custom1_op_f32_t fun) { + return lm_ggml_map_custom1_impl_f32(ctx, a, fun, false); +} -static int64_t lm_ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) { - return (ins + 2 * p - d * (ks - 1) - 1) / s + 1; +struct lm_ggml_tensor * lm_ggml_map_custom1_inplace_f32( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + const lm_ggml_custom1_op_f32_t fun) { + return lm_ggml_map_custom1_impl_f32(ctx, a, fun, true); } -LM_GGML_API struct lm_ggml_tensor * lm_ggml_conv_1d( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - int s0, - int p0, - int d0) { - struct lm_ggml_tensor * im2col = lm_ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false, LM_GGML_TYPE_F16); // [N, OL, IC * K] +// lm_ggml_map_custom2_f32 - struct lm_ggml_tensor * result = - lm_ggml_mul_mat(ctx, - lm_ggml_reshape_2d(ctx, im2col, im2col->ne[0], (im2col->ne[2] * im2col->ne[1])), // [N, OL, IC * K] => [N*OL, IC * K] - lm_ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2])); // [OC,IC, K] => [OC, IC * K] +static struct lm_ggml_tensor * lm_ggml_map_custom2_impl_f32( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + const lm_ggml_custom2_op_f32_t fun, + bool inplace) { + struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - result = lm_ggml_reshape_3d(ctx, result, im2col->ne[1], a->ne[2], im2col->ne[2]); // [N, OC, OL] + lm_ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); + + result->op = LM_GGML_OP_MAP_CUSTOM2_F32; + result->src[0] = a; + result->src[1] = b; return result; } -// lm_ggml_conv_1d_ph - -struct lm_ggml_tensor* lm_ggml_conv_1d_ph( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - int s, - int d) { - return lm_ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d); +struct lm_ggml_tensor * lm_ggml_map_custom2_f32( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + const lm_ggml_custom2_op_f32_t fun) { + return lm_ggml_map_custom2_impl_f32(ctx, a, b, fun, false); } -// lm_ggml_conv_transpose_1d - -static int64_t lm_ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) { - return (ins - 1) * s - 2 * p + d * (ks - 1) + 1; +struct lm_ggml_tensor * lm_ggml_map_custom2_inplace_f32( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + const lm_ggml_custom2_op_f32_t fun) { + return lm_ggml_map_custom2_impl_f32(ctx, a, b, fun, true); } -LM_GGML_API struct lm_ggml_tensor * lm_ggml_conv_transpose_1d( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - int s0, - int p0, - int d0) { - LM_GGML_ASSERT(lm_ggml_is_matrix(b)); - LM_GGML_ASSERT(a->ne[2] == b->ne[1]); - LM_GGML_ASSERT(a->ne[3] == 1); - - LM_GGML_ASSERT(p0 == 0); - LM_GGML_ASSERT(d0 == 1); +// lm_ggml_map_custom3_f32 - const int64_t ne[4] = { - lm_ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/), - a->ne[1], b->ne[2], 1, - }; - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); +static struct lm_ggml_tensor * lm_ggml_map_custom3_impl_f32( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + struct lm_ggml_tensor * c, + const lm_ggml_custom3_op_f32_t fun, + bool inplace) { + struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - int32_t params[] = { s0, p0, d0 }; - lm_ggml_set_op_params(result, params, sizeof(params)); + lm_ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); - result->op = LM_GGML_OP_CONV_TRANSPOSE_1D; + result->op = LM_GGML_OP_MAP_CUSTOM3_F32; result->src[0] = a; result->src[1] = b; + result->src[2] = c; return result; } -// lm_ggml_conv_depthwise +struct lm_ggml_tensor * lm_ggml_map_custom3_f32( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + struct lm_ggml_tensor * c, + const lm_ggml_custom3_op_f32_t fun) { + return lm_ggml_map_custom3_impl_f32(ctx, a, b, c, fun, false); +} -struct lm_ggml_tensor * lm_ggml_conv_depthwise_2d( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - int s0, - int s1, - int p0, - int p1, - int d0, - int d1) { - struct lm_ggml_tensor * new_a = lm_ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]); - struct lm_ggml_tensor * im2col = lm_ggml_im2col(ctx, new_a, - lm_ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]), - s0, s1, p0, p1, d0, d1, true, LM_GGML_TYPE_F16); // [N * IC, OH, OW, KH * KW] - struct lm_ggml_tensor * new_b = lm_ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3]); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW] - - new_a = lm_ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2], new_a->ne[3], 1); // [OC,1, KH, KW] => [1, OC, 1, KH * KW] - struct lm_ggml_tensor * result = lm_ggml_mul_mat(ctx, new_a, new_b); - result = lm_ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], b->ne[2], b->ne[3]); // [N, OC, OH, OW] - - return result; +struct lm_ggml_tensor * lm_ggml_map_custom3_inplace_f32( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + struct lm_ggml_tensor * c, + const lm_ggml_custom3_op_f32_t fun) { + return lm_ggml_map_custom3_impl_f32(ctx, a, b, c, fun, true); } -// lm_ggml_conv_2d -// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] -// a: [OC,IC, KH, KW] -// b: [N, IC, IH, IW] -// result: [N, OH, OW, IC*KH*KW] -struct lm_ggml_tensor * lm_ggml_im2col( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - int s0, - int s1, - int p0, - int p1, - int d0, - int d1, - bool is_2D, - enum lm_ggml_type dst_type) { - if(is_2D) { - LM_GGML_ASSERT(a->ne[2] == b->ne[2]); - } else { - LM_GGML_ASSERT(a->ne[1] == b->ne[1]); - LM_GGML_ASSERT(b->ne[3] == 1); - } +// lm_ggml_map_custom1 - const int64_t OH = is_2D ? lm_ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0; - const int64_t OW = lm_ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0); +static struct lm_ggml_tensor * lm_ggml_map_custom1_impl( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + const lm_ggml_custom1_op_t fun, + int n_tasks, + void * userdata, + bool inplace) { + LM_GGML_ASSERT(n_tasks == LM_GGML_N_TASKS_MAX || n_tasks > 0); - LM_GGML_ASSERT((!is_2D || OH > 0) && "b too small compared to a"); - LM_GGML_ASSERT((OW > 0) && "b too small compared to a"); + struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - const int64_t ne[4] = { - is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0], - OW, - is_2D ? OH : b->ne[2], - is_2D ? b->ne[3] : 1, + struct lm_ggml_map_custom1_op_params params = { + /*.fun =*/ fun, + /*.n_tasks =*/ n_tasks, + /*.userdata =*/ userdata }; + lm_ggml_set_op_params(result, (const void *) ¶ms, sizeof(params)); - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, dst_type, 4, ne); - int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) }; - lm_ggml_set_op_params(result, params, sizeof(params)); - - result->op = LM_GGML_OP_IM2COL; + result->op = LM_GGML_OP_MAP_CUSTOM1; result->src[0] = a; - result->src[1] = b; return result; } -struct lm_ggml_tensor * lm_ggml_im2col_back( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - int64_t * ne, - int s0, - int s1, - int p0, - int p1, - int d0, - int d1, - bool is_2D) { - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); - int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) }; - lm_ggml_set_op_params(result, params, sizeof(params)); - - result->op = LM_GGML_OP_IM2COL_BACK; - result->src[0] = a; - result->src[1] = b; +struct lm_ggml_tensor * lm_ggml_map_custom1( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + const lm_ggml_custom1_op_t fun, + int n_tasks, + void * userdata) { + return lm_ggml_map_custom1_impl(ctx, a, fun, n_tasks, userdata, false); +} - return result; +struct lm_ggml_tensor * lm_ggml_map_custom1_inplace( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + const lm_ggml_custom1_op_t fun, + int n_tasks, + void * userdata) { + return lm_ggml_map_custom1_impl(ctx, a, fun, n_tasks, userdata, true); } -// a: [OC,IC, KH, KW] -// b: [N, IC, IH, IW] -// result: [N, OC, OH, OW] -struct lm_ggml_tensor * lm_ggml_conv_2d( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - int s0, - int s1, - int p0, - int p1, - int d0, - int d1) { - struct lm_ggml_tensor * im2col = lm_ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, a->type); // [N, OH, OW, IC * KH * KW] +// lm_ggml_map_custom2 - struct lm_ggml_tensor * result = - lm_ggml_mul_mat(ctx, - lm_ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N, OH, OW, IC * KH * KW] => [N*OH*OW, IC * KH * KW] - lm_ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2]), a->ne[3])); // [OC,IC, KH, KW] => [OC, IC * KH * KW] +static struct lm_ggml_tensor * lm_ggml_map_custom2_impl( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + const lm_ggml_custom2_op_t fun, + int n_tasks, + void * userdata, + bool inplace) { + LM_GGML_ASSERT(n_tasks == LM_GGML_N_TASKS_MAX || n_tasks > 0); - result = lm_ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], im2col->ne[3], a->ne[3]); // [OC, N, OH, OW] - result = lm_ggml_cont(ctx, lm_ggml_permute(ctx, result, 0, 1, 3, 2)); // [N, OC, OH, OW] + struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); + + struct lm_ggml_map_custom2_op_params params = { + /*.fun =*/ fun, + /*.n_tasks =*/ n_tasks, + /*.userdata =*/ userdata + }; + lm_ggml_set_op_params(result, (const void *) ¶ms, sizeof(params)); + result->op = LM_GGML_OP_MAP_CUSTOM2; + result->src[0] = a; + result->src[1] = b; return result; } -// lm_ggml_conv_2d_sk_p0 - -struct lm_ggml_tensor * lm_ggml_conv_2d_sk_p0( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b) { - return lm_ggml_conv_2d(ctx, a, b, a->ne[0], a->ne[1], 0, 0, 1, 1); +struct lm_ggml_tensor * lm_ggml_map_custom2( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + const lm_ggml_custom2_op_t fun, + int n_tasks, + void * userdata) { + return lm_ggml_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, false); } -// lm_ggml_conv_2d_s1_ph - -struct lm_ggml_tensor * lm_ggml_conv_2d_s1_ph( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b) { - return lm_ggml_conv_2d(ctx, a, b, 1, 1, a->ne[0] / 2, a->ne[1] / 2, 1, 1); +struct lm_ggml_tensor * lm_ggml_map_custom2_inplace( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + const lm_ggml_custom2_op_t fun, + int n_tasks, + void * userdata) { + return lm_ggml_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, true); } -// lm_ggml_conv_transpose_2d_p0 +// lm_ggml_map_custom3 -static int64_t lm_ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) { - return (ins - 1) * s - 2 * p + ks; -} +static struct lm_ggml_tensor * lm_ggml_map_custom3_impl( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + struct lm_ggml_tensor * c, + const lm_ggml_custom3_op_t fun, + int n_tasks, + void * userdata, + bool inplace) { + LM_GGML_ASSERT(n_tasks == LM_GGML_N_TASKS_MAX || n_tasks > 0); -struct lm_ggml_tensor * lm_ggml_conv_transpose_2d_p0( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - int stride) { - LM_GGML_ASSERT(a->ne[3] == b->ne[2]); + struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - const int64_t ne[4] = { - lm_ggml_calc_conv_transpose_output_size(b->ne[0], a->ne[0], stride, 0 /*p0*/), - lm_ggml_calc_conv_transpose_output_size(b->ne[1], a->ne[1], stride, 0 /*p1*/), - a->ne[2], b->ne[3], + struct lm_ggml_map_custom3_op_params params = { + /*.fun =*/ fun, + /*.n_tasks =*/ n_tasks, + /*.userdata =*/ userdata }; + lm_ggml_set_op_params(result, (const void *) ¶ms, sizeof(params)); - struct lm_ggml_tensor* result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); - - lm_ggml_set_op_params_i32(result, 0, stride); - - result->op = LM_GGML_OP_CONV_TRANSPOSE_2D; + result->op = LM_GGML_OP_MAP_CUSTOM3; result->src[0] = a; result->src[1] = b; + result->src[2] = c; return result; } -// lm_ggml_pool_* +struct lm_ggml_tensor * lm_ggml_map_custom3( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + struct lm_ggml_tensor * c, + const lm_ggml_custom3_op_t fun, + int n_tasks, + void * userdata) { + return lm_ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, false); +} -static int64_t lm_ggml_calc_pool_output_size(int64_t ins, int ks, int s, float p) { - return (ins + 2 * p - ks) / s + 1; +struct lm_ggml_tensor * lm_ggml_map_custom3_inplace( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + struct lm_ggml_tensor * c, + const lm_ggml_custom3_op_t fun, + int n_tasks, + void * userdata) { + return lm_ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, true); } -// lm_ggml_pool_1d +// lm_ggml_cross_entropy_loss -struct lm_ggml_tensor * lm_ggml_pool_1d( +struct lm_ggml_tensor * lm_ggml_cross_entropy_loss( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - enum lm_ggml_op_pool op, - int k0, - int s0, - int p0) { - const int64_t ne[4] = { - lm_ggml_calc_pool_output_size(a->ne[0], k0, s0, p0), - a->ne[1], - a->ne[2], - a->ne[3], - }; - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); + struct lm_ggml_tensor * b) { + LM_GGML_ASSERT(lm_ggml_are_same_shape(a, b)); - int32_t params[] = { op, k0, s0, p0 }; - lm_ggml_set_op_params(result, params, sizeof(params)); + struct lm_ggml_tensor * result = lm_ggml_new_tensor_1d(ctx, a->type, 1); - result->op = LM_GGML_OP_POOL_1D; + result->op = LM_GGML_OP_CROSS_ENTROPY_LOSS; result->src[0] = a; + result->src[1] = b; return result; } -// lm_ggml_pool_2d +// lm_ggml_cross_entropy_loss_back -struct lm_ggml_tensor * lm_ggml_pool_2d( +struct lm_ggml_tensor * lm_ggml_cross_entropy_loss_back( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - enum lm_ggml_op_pool op, - int k0, - int k1, - int s0, - int s1, - float p0, - float p1) { - struct lm_ggml_tensor * result; - const int64_t ne[4] = { - lm_ggml_calc_pool_output_size(a->ne[0], k0, s0, p0), - lm_ggml_calc_pool_output_size(a->ne[1], k1, s1, p1), - a->ne[2], - a->ne[3], - }; - result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); + struct lm_ggml_tensor * b, + struct lm_ggml_tensor * c) { + LM_GGML_ASSERT(lm_ggml_are_same_shape(a, b)); + LM_GGML_ASSERT(lm_ggml_is_scalar(c)); - int32_t params[] = { op, k0, k1, s0, s1, p0, p1 }; - lm_ggml_set_op_params(result, params, sizeof(params)); + struct lm_ggml_tensor * result = lm_ggml_dup_tensor(ctx, a); - result->op = LM_GGML_OP_POOL_2D; + result->op = LM_GGML_OP_CROSS_ENTROPY_LOSS_BACK; result->src[0] = a; + result->src[1] = b; + result->src[2] = c; return result; } -struct lm_ggml_tensor * lm_ggml_pool_2d_back( +// opt_step_adamw + +struct lm_ggml_tensor * lm_ggml_opt_step_adamw( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, - struct lm_ggml_tensor * af, - enum lm_ggml_op_pool op, - int k0, - int k1, - int s0, - int s1, - float p0, - float p1) { - struct lm_ggml_tensor * result; - result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, af->ne); + struct lm_ggml_tensor * grad, + float alpha, + float beta1, + float beta2, + float eps, + float wd) { + LM_GGML_ASSERT(a->flags & LM_GGML_TENSOR_FLAG_PARAM); + LM_GGML_ASSERT(lm_ggml_are_same_shape(a, grad)); + LM_GGML_ASSERT(alpha > 0.0f); + LM_GGML_ASSERT(beta1 >= 0.0f && beta1 <= 1.0f); + LM_GGML_ASSERT(beta2 >= 0.0f && beta2 <= 1.0f); + LM_GGML_ASSERT(eps >= 0.0f); + LM_GGML_ASSERT(wd >= 0.0f && wd <= 1.0f); - int32_t params[] = { op, k0, k1, s0, s1, p0, p1 }; - lm_ggml_set_op_params(result, params, sizeof(params)); + struct lm_ggml_tensor * result = lm_ggml_view_tensor(ctx, a); - result->op = LM_GGML_OP_POOL_2D_BACK; + const int64_t iter = 1; + memcpy(&result->op_params[0], &iter, sizeof(int64_t)); + lm_ggml_set_op_params_f32(result, 2, alpha); + lm_ggml_set_op_params_f32(result, 3, beta1); + lm_ggml_set_op_params_f32(result, 4, beta2); + lm_ggml_set_op_params_f32(result, 5, eps); + lm_ggml_set_op_params_f32(result, 6, wd); + + result->op = LM_GGML_OP_OPT_STEP_ADAMW; result->src[0] = a; - result->src[1] = af; + result->src[1] = grad; + result->src[2] = lm_ggml_dup_tensor(ctx, grad); + result->src[3] = lm_ggml_dup_tensor(ctx, grad); return result; } -// lm_ggml_upscale - -static struct lm_ggml_tensor * lm_ggml_upscale_impl( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - int ne0, - int ne1, - int ne2, - int ne3) { - LM_GGML_ASSERT(a->ne[0] <= ne0); - LM_GGML_ASSERT(a->ne[1] <= ne1); - LM_GGML_ASSERT(a->ne[2] <= ne2); - LM_GGML_ASSERT(a->ne[3] <= ne3); - - struct lm_ggml_tensor * result = lm_ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3); - - result->op = LM_GGML_OP_UPSCALE; - result->src[0] = a; +//////////////////////////////////////////////////////////////////////////////// +struct lm_ggml_hash_set lm_ggml_hash_set_new(size_t size) { + size = lm_ggml_hash_size(size); + struct lm_ggml_hash_set result; + result.size = size; + result.keys = LM_GGML_MALLOC(sizeof(struct lm_ggml_tensor *) * size); + result.used = LM_GGML_CALLOC(lm_ggml_bitset_size(size), sizeof(lm_ggml_bitset_t)); return result; } -struct lm_ggml_tensor * lm_ggml_upscale( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - int scale_factor) { - return lm_ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3]); +void lm_ggml_hash_set_reset(struct lm_ggml_hash_set * hash_set) { + memset(hash_set->used, 0, sizeof(lm_ggml_bitset_t) * lm_ggml_bitset_size(hash_set->size)); } -struct lm_ggml_tensor * lm_ggml_upscale_ext( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - int ne0, - int ne1, - int ne2, - int ne3) { - return lm_ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3); +void lm_ggml_hash_set_free(struct lm_ggml_hash_set * hash_set) { + LM_GGML_FREE(hash_set->used); + LM_GGML_FREE(hash_set->keys); } -// lm_ggml_pad +size_t lm_ggml_hash_size(size_t min_sz) { + // next primes after powers of two + static const size_t primes[] = { + 2, 3, 5, 11, 17, 37, 67, 131, 257, 521, 1031, + 2053, 4099, 8209, 16411, 32771, 65537, 131101, + 262147, 524309, 1048583, 2097169, 4194319, 8388617, + 16777259, 33554467, 67108879, 134217757, 268435459, + 536870923, 1073741827, 2147483659 + }; + static const size_t n_primes = sizeof(primes)/sizeof(primes[0]); -struct lm_ggml_tensor * lm_ggml_pad( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - int p0, - int p1, - int p2, - int p3) { - struct lm_ggml_tensor * result = lm_ggml_new_tensor_4d(ctx, a->type, - a->ne[0] + p0, - a->ne[1] + p1, - a->ne[2] + p2, - a->ne[3] + p3); + // find the smallest prime that is larger or equal than min_sz + size_t l = 0; + size_t r = n_primes; + while (l < r) { + size_t m = (l + r)/2; + if (primes[m] < min_sz) { + l = m + 1; + } else { + r = m; + } + } + size_t sz = l < n_primes ? primes[l] : min_sz | 1; + return sz; +} - result->op = LM_GGML_OP_PAD; - result->src[0] = a; +struct hash_map { + struct lm_ggml_hash_set set; + struct lm_ggml_tensor ** vals; +}; +static struct hash_map * lm_ggml_new_hash_map(size_t size) { + struct hash_map * result = LM_GGML_MALLOC(sizeof(struct hash_map)); + result->set = lm_ggml_hash_set_new(size); + result->vals = LM_GGML_CALLOC(result->set.size, sizeof(struct lm_ggml_tensor *)); return result; } -// lm_ggml_arange - -struct lm_ggml_tensor * lm_ggml_arange( - struct lm_ggml_context * ctx, - float start, - float stop, - float step) { - LM_GGML_ASSERT(stop > start); - - const int64_t steps = (int64_t) ceilf((stop - start) / step); - - struct lm_ggml_tensor * result = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_F32, steps); - - lm_ggml_set_op_params_f32(result, 0, start); - lm_ggml_set_op_params_f32(result, 1, stop); - lm_ggml_set_op_params_f32(result, 2, step); - - result->op = LM_GGML_OP_ARANGE; - - return result; +static void lm_ggml_hash_map_free(struct hash_map * map) { + lm_ggml_hash_set_free(&map->set); + LM_GGML_FREE(map->vals); + LM_GGML_FREE(map); } -// lm_ggml_timestep_embedding +// gradient checkpointing -struct lm_ggml_tensor * lm_ggml_timestep_embedding( +static struct lm_ggml_tensor * lm_ggml_recompute_graph_node( struct lm_ggml_context * ctx, - struct lm_ggml_tensor * timesteps, - int dim, - int max_period) { - int actual_dim = dim; - if (dim % 2 != 0) { - actual_dim = dim + 1; + struct lm_ggml_cgraph * graph, + struct hash_map * replacements, + struct lm_ggml_tensor * node) { + + if (node == NULL) { + return NULL; } - struct lm_ggml_tensor * result = lm_ggml_new_tensor_2d(ctx, LM_GGML_TYPE_F32, actual_dim, timesteps->ne[0]); + if (node->flags & LM_GGML_TENSOR_FLAG_PARAM) { + return node; + } - lm_ggml_set_op_params_i32(result, 0, dim); - lm_ggml_set_op_params_i32(result, 1, max_period); + if (!lm_ggml_hash_contains(&graph->visited_hash_set, node)) { + return node; + } - result->op = LM_GGML_OP_TIMESTEP_EMBEDDING; - result->src[0] = timesteps; + int count_children = 0; + for (int k = 0; k < LM_GGML_MAX_SRC; ++k) { + if (node->src[k]) { + ++count_children; + } + } - return result; -} + if (count_children == 0) { + return node; + } -// lm_ggml_argsort + size_t i = lm_ggml_hash_find(&replacements->set, node); + LM_GGML_ASSERT(i != LM_GGML_HASHSET_FULL); // assert that not full + if (replacements->set.keys[i] == node) { + return replacements->vals[i]; + } -struct lm_ggml_tensor * lm_ggml_argsort( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - enum lm_ggml_sort_order order) { - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_I32, LM_GGML_MAX_DIMS, a->ne); + struct lm_ggml_tensor * clone = lm_ggml_new_tensor(ctx, node->type, LM_GGML_MAX_DIMS, node->ne); - lm_ggml_set_op_params_i32(result, 0, (int32_t) order); + // insert clone into replacements + LM_GGML_ASSERT(replacements->set.keys[i] == NULL); // assert that we don't overwrite + replacements->set.keys[i] = node; + replacements->vals[i] = clone; - result->op = LM_GGML_OP_ARGSORT; - result->src[0] = a; - - return result; -} - -// lm_ggml_top_k - -struct lm_ggml_tensor * lm_ggml_top_k( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - int k) { - LM_GGML_ASSERT(a->ne[0] >= k); - - struct lm_ggml_tensor * result = lm_ggml_argsort(ctx, a, LM_GGML_SORT_ORDER_DESC); + clone->op = node->op; + clone->grad = node->grad; + clone->flags = node->flags; + clone->extra = node->extra; + for (int k = 0; k < LM_GGML_MAX_DIMS; ++k) { + clone->nb[k] = node->nb[k]; + } + for (int k = 0; k < LM_GGML_MAX_SRC; ++k) { + clone->src[k] = lm_ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]); + } + if (node->view_src != NULL) { + clone->data = (node->view_src->data == NULL) + ? NULL // view_src not yet allocated + : (char *) node->view_src->data // view_src already allocated + + node->view_offs; + clone->view_src = node->view_src; + clone->view_offs = node->view_offs; + } - result = lm_ggml_view_4d(ctx, result, - k, result->ne[1], result->ne[2], result->ne[3], - result->nb[1], result->nb[2], result->nb[3], - 0); + LM_GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (LM_GGML_MAX_OP_PARAMS / sizeof(int32_t))); + LM_GGML_ASSERT(sizeof(node->name) == LM_GGML_MAX_NAME); + memcpy(clone->op_params, node->op_params, sizeof(node->op_params)); + lm_ggml_format_name(clone, "%s (clone)", lm_ggml_get_name(node)); - return result; + return clone; } -// lm_ggml_flash_attn_ext +void lm_ggml_build_backward_gradient_checkpointing( + struct lm_ggml_context * ctx, + struct lm_ggml_cgraph * gf, + struct lm_ggml_cgraph * gb, + struct lm_ggml_cgraph * gb_tmp, + struct lm_ggml_tensor * * checkpoints, + int n_checkpoints) { + lm_ggml_graph_cpy(gf, gb_tmp); + lm_ggml_build_backward_expand(ctx, gf, gb_tmp, false); -struct lm_ggml_tensor * lm_ggml_flash_attn_ext( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * q, - struct lm_ggml_tensor * k, - struct lm_ggml_tensor * v, - struct lm_ggml_tensor * mask, - float scale, - float max_bias, - float logit_softcap) { - LM_GGML_ASSERT(lm_ggml_can_mul_mat(k, q)); - // TODO: check if vT can be multiplied by (k*qT) + if (n_checkpoints <= 0) { + lm_ggml_graph_cpy(gb_tmp, gb); + return; + } - if (mask) { - LM_GGML_ASSERT(lm_ggml_is_contiguous(mask)); - LM_GGML_ASSERT(mask->ne[2] == 1); - LM_GGML_ASSERT(mask->ne[3] == 1); - LM_GGML_ASSERT(mask->ne[1] >= LM_GGML_PAD(q->ne[1], LM_GGML_KQ_MASK_PAD) && - "the Flash-Attention kernel requires the mask to be padded to LM_GGML_KQ_MASK_PAD and at least n_queries big"); - //LM_GGML_ASSERT(lm_ggml_can_repeat_rows(mask, qk)); + struct hash_map * replacements = lm_ggml_new_hash_map(gf->n_nodes + gf->n_leafs + n_checkpoints); + + // insert checkpoints in replacements + for (int i = 0; i < n_checkpoints; ++i) { + size_t k = lm_ggml_hash_find(&replacements->set, checkpoints[i]); + LM_GGML_ASSERT(k != LM_GGML_HASHSET_FULL); // assert that not full + LM_GGML_ASSERT(replacements->set.keys[k] == NULL); // assert that we don't overwrite + replacements->set.keys[k] = checkpoints[i]; + replacements->vals[k] = checkpoints[i]; } - if (max_bias > 0.0f) { - LM_GGML_ASSERT(mask); + lm_ggml_graph_cpy(gf, gb); + // rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes], + // replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]), + // by recomputing them from checkpoints + for (int i = gf->n_nodes; in_nodes; ++i) { + struct lm_ggml_tensor * node = gb_tmp->nodes[i]; + for (int k = 0; k < LM_GGML_MAX_SRC; ++k) { + // insert new tensors recomputing src, reusing already made replacements, + // remember replacements: remember new tensors with mapping from corresponding gf nodes + // recurse for input tensors, + // unless (i.e. terminating when) input tensors are replacements (like checkpoints) + node->src[k] = lm_ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]); + } + // insert rewritten backward node with replacements made into resulting backward graph gb + lm_ggml_build_forward_expand(gb, node); } - bool is_node = false; + lm_ggml_hash_map_free(replacements); +} - // permute(0, 2, 1, 3) - int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); +// utility functions to change gradients +// if a is in acc_table, modify gradients in-place and mark result as gradient accumulator +// else if a is in zero_table, replace a +// else, just add/subtract/etc. the gradients - float params[] = { scale, max_bias, logit_softcap }; - lm_ggml_set_op_params(result, params, sizeof(params)); +static struct lm_ggml_tensor * lm_ggml_add_or_set( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + struct lm_ggml_hash_set * zero_table, + struct lm_ggml_hash_set * acc_table) { + if (lm_ggml_hash_contains(acc_table, a)) { + struct lm_ggml_tensor * ret = lm_ggml_add_impl(ctx, a, b, true); + const size_t insert_result = lm_ggml_hash_insert(acc_table, ret); + LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL); + LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS); + return ret; + } + if (lm_ggml_hash_contains(zero_table, a)) { + return b; + } + return lm_ggml_add_impl(ctx, a, b, false); +} - result->op = LM_GGML_OP_FLASH_ATTN_EXT; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = q; - result->src[1] = k; - result->src[2] = v; - result->src[3] = mask; +static struct lm_ggml_tensor * lm_ggml_acc_or_set( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + const size_t nb1, + const size_t nb2, + const size_t nb3, + const size_t offset, + struct lm_ggml_hash_set * zero_table, + struct lm_ggml_hash_set * acc_table) { + if (lm_ggml_hash_contains(acc_table, a)) { + struct lm_ggml_tensor * ret = lm_ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true); + const size_t insert_result = lm_ggml_hash_insert(acc_table, ret); + LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL); + LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS); + return ret; + } + if (lm_ggml_hash_contains(zero_table, a)) { + struct lm_ggml_tensor * a_zero = lm_ggml_scale(ctx, a, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN + return lm_ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false); + } + return lm_ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false); +} - return result; +static struct lm_ggml_tensor * lm_ggml_add1_or_set( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + struct lm_ggml_hash_set * zero_table, + struct lm_ggml_hash_set * acc_table) { + if (lm_ggml_hash_contains(acc_table, a)) { + struct lm_ggml_tensor * ret = lm_ggml_add1_impl(ctx, a, b, true); + const size_t insert_result = lm_ggml_hash_insert(acc_table, ret); + LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL); + LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS); + return ret; + } + if (lm_ggml_hash_contains(zero_table, a)) { + return lm_ggml_repeat(ctx, b, a); + } + return lm_ggml_add1_impl(ctx, a, b, false); } -void lm_ggml_flash_attn_ext_set_prec( - struct lm_ggml_tensor * a, - enum lm_ggml_prec prec) { - LM_GGML_ASSERT(a->op == LM_GGML_OP_FLASH_ATTN_EXT); +static struct lm_ggml_tensor * lm_ggml_sub_or_set( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + struct lm_ggml_hash_set * zero_table, + struct lm_ggml_hash_set * acc_table) { + if (lm_ggml_hash_contains(acc_table, a)) { + struct lm_ggml_tensor * ret = lm_ggml_sub_impl(ctx, a, b, true); + const size_t insert_result = lm_ggml_hash_insert(acc_table, ret); + LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL); + LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS); + return ret; + } + if (lm_ggml_hash_contains(zero_table, a)) { + return lm_ggml_neg(ctx, b); + } + return lm_ggml_sub_impl(ctx, a, b, false); +} - const int32_t prec_i32 = (int32_t) prec; +static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggml_tensor * tensor, struct lm_ggml_hash_set * zero_table, struct lm_ggml_hash_set * acc_table) { + struct lm_ggml_tensor * src0 = tensor->src[0]; + struct lm_ggml_tensor * src1 = tensor->src[1]; + struct lm_ggml_tensor * src2 = tensor->src[2]; - lm_ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second -} + switch (tensor->op) { + case LM_GGML_OP_DUP: + { + if (src0->grad) { + src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); + } + } break; + case LM_GGML_OP_ADD: + { + if (src0->grad) { + src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); + } + if (src1->grad) { + if (lm_ggml_are_same_shape(src0, src1)) { + src1->grad = lm_ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table, acc_table); + } else { + src1->grad = lm_ggml_add_or_set(ctx, src1->grad, lm_ggml_repeat_back(ctx, tensor->grad, src1), zero_table, acc_table); + } + } + } break; + case LM_GGML_OP_ADD1: + { + if (src0->grad) { + src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); + } + if (src1->grad) { + src1->grad = lm_ggml_add_or_set(ctx, + src1->grad, + lm_ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean + zero_table, acc_table); + } + } break; + case LM_GGML_OP_ACC: + { + if (src0->grad) { + src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); + } + if (src1->grad) { + const size_t nb1 = ((int32_t *) tensor->op_params)[0]; + const size_t nb2 = ((int32_t *) tensor->op_params)[1]; + const size_t nb3 = ((int32_t *) tensor->op_params)[2]; + const size_t offset = ((int32_t *) tensor->op_params)[3]; -// lm_ggml_flash_attn_back + struct lm_ggml_tensor * tensor_grad_view = lm_ggml_view_4d(ctx, + tensor->grad, + src1->grad->ne[0], + src1->grad->ne[1], + src1->grad->ne[2], + src1->grad->ne[3], + nb1, nb2, nb3, offset); -struct lm_ggml_tensor * lm_ggml_flash_attn_back( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * q, - struct lm_ggml_tensor * k, - struct lm_ggml_tensor * v, - struct lm_ggml_tensor * d, - bool masked) { - LM_GGML_ABORT("TODO: adapt to lm_ggml_flash_attn_ext() changes"); - - LM_GGML_ASSERT(lm_ggml_can_mul_mat(k, q)); - // TODO: check if vT can be multiplied by (k*qT) - - // d shape [D,N,ne2,ne3] - // q shape [D,N,ne2,ne3] - // k shape [D,M,kvne2,ne3] - // v shape [M,D,kvne2,ne3] - - const int64_t D = q->ne[0]; - const int64_t N = q->ne[1]; - const int64_t M = k->ne[1]; - const int64_t ne2 = q->ne[2]; - const int64_t ne3 = q->ne[3]; - const int64_t kvne2 = k->ne[2]; - - LM_GGML_ASSERT(k->ne[0] == D); - LM_GGML_ASSERT(v->ne[0] == M); - LM_GGML_ASSERT(v->ne[1] == D); - LM_GGML_ASSERT(d->ne[0] == D); - LM_GGML_ASSERT(d->ne[1] == N); - LM_GGML_ASSERT(k->ne[2] == kvne2); - LM_GGML_ASSERT(k->ne[3] == ne3); - LM_GGML_ASSERT(v->ne[2] == kvne2); - LM_GGML_ASSERT(v->ne[3] == ne3); - LM_GGML_ASSERT(d->ne[2] == ne2); - LM_GGML_ASSERT(d->ne[3] == ne3); - - LM_GGML_ASSERT(ne2 % kvne2 == 0); - - bool is_node = false; - - if (q->grad || k->grad || v->grad) { - // when using this operation (in backwards pass) these grads are set. - // we don't want to create (big) grad of our result, so is_node is false. - is_node = false; - } - - // store gradients of q, k and v as continuous tensors concatenated in result. - // note: v and gradv are actually transposed, i.e. v->ne[0] != D. - const int64_t elem_q = lm_ggml_nelements(q); - const int64_t elem_k = lm_ggml_nelements(k); - const int64_t elem_v = lm_ggml_nelements(v); - - enum lm_ggml_type result_type = LM_GGML_TYPE_F32; - LM_GGML_ASSERT(lm_ggml_blck_size(result_type) == 1); - const size_t tsize = lm_ggml_type_size(result_type); - - const size_t offs_q = 0; - const size_t offs_k = offs_q + LM_GGML_PAD(elem_q * tsize, LM_GGML_MEM_ALIGN); - const size_t offs_v = offs_k + LM_GGML_PAD(elem_k * tsize, LM_GGML_MEM_ALIGN); - const size_t end = offs_v + LM_GGML_PAD(elem_v * tsize, LM_GGML_MEM_ALIGN); - - const size_t nelements = (end + tsize - 1)/tsize; - - struct lm_ggml_tensor * result = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_F32, nelements); - - int32_t masked_i = masked ? 1 : 0; - lm_ggml_set_op_params(result, &masked_i, sizeof(masked_i)); - - result->op = LM_GGML_OP_FLASH_ATTN_BACK; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = q; - result->src[1] = k; - result->src[2] = v; - result->src[3] = d; - - return result; -} - -// lm_ggml_ssm_conv - -struct lm_ggml_tensor * lm_ggml_ssm_conv( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * sx, - struct lm_ggml_tensor * c) { - LM_GGML_ASSERT(lm_ggml_is_3d(sx)); - LM_GGML_ASSERT(lm_ggml_is_matrix(c)); - - const int64_t d_conv = c->ne[0]; - const int64_t d_inner = c->ne[1]; - const int64_t n_t = sx->ne[0] - d_conv + 1; // tokens per sequence - const int64_t n_s = sx->ne[2]; - - // TODO: maybe support other strides than 1? - // FIXME: this is always true? - LM_GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t); - LM_GGML_ASSERT(sx->ne[1] == d_inner); - LM_GGML_ASSERT(n_t >= 0); - - struct lm_ggml_tensor * result = lm_ggml_new_tensor_3d(ctx, LM_GGML_TYPE_F32, d_inner, n_t, n_s); - - result->op = LM_GGML_OP_SSM_CONV; - result->src[0] = sx; - result->src[1] = c; - - return result; -} - -// lm_ggml_ssm_scan - -struct lm_ggml_tensor * lm_ggml_ssm_scan( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * s, - struct lm_ggml_tensor * x, - struct lm_ggml_tensor * dt, - struct lm_ggml_tensor * A, - struct lm_ggml_tensor * B, - struct lm_ggml_tensor * C) { - LM_GGML_ASSERT(lm_ggml_is_contiguous(s)); - LM_GGML_ASSERT(lm_ggml_is_contiguous(x)); - LM_GGML_ASSERT(lm_ggml_is_contiguous(dt)); - LM_GGML_ASSERT(lm_ggml_is_contiguous(A)); - LM_GGML_ASSERT(lm_ggml_is_matrix(A)); - LM_GGML_ASSERT(lm_ggml_is_3d(B)); - LM_GGML_ASSERT(lm_ggml_is_3d(s)); - LM_GGML_ASSERT(B->nb[0] == lm_ggml_type_size(B->type)); - LM_GGML_ASSERT(C->nb[0] == lm_ggml_type_size(C->type)); - LM_GGML_ASSERT(lm_ggml_are_same_shape(x, dt)); - LM_GGML_ASSERT(lm_ggml_are_same_shape(B, C)); - - { - const int64_t d_state = s->ne[0]; - const int64_t d_inner = s->ne[1]; - const int64_t n_seq_tokens = x->ne[1]; - const int64_t n_seqs = x->ne[2]; - - LM_GGML_ASSERT(s->ne[2] == n_seqs); - LM_GGML_ASSERT(x->ne[0] == d_inner); - LM_GGML_ASSERT(A->ne[0] == d_state); - LM_GGML_ASSERT(A->ne[1] == d_inner); - LM_GGML_ASSERT(B->ne[0] == d_state); - LM_GGML_ASSERT(B->ne[1] == n_seq_tokens); - LM_GGML_ASSERT(B->ne[2] == n_seqs); - } - - // concatenated y + ssm_states - struct lm_ggml_tensor * result = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_F32, lm_ggml_nelements(x) + lm_ggml_nelements(s)); - - result->op = LM_GGML_OP_SSM_SCAN; - result->src[0] = s; - result->src[1] = x; - result->src[2] = dt; - result->src[3] = A; - result->src[4] = B; - result->src[5] = C; - - return result; -} - -// lm_ggml_win_part - -struct lm_ggml_tensor * lm_ggml_win_part( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - int w) { - LM_GGML_ASSERT(a->ne[3] == 1); - LM_GGML_ASSERT(a->type == LM_GGML_TYPE_F32); - - // padding - const int px = (w - a->ne[1]%w)%w; - const int py = (w - a->ne[2]%w)%w; - - const int npx = (px + a->ne[1])/w; - const int npy = (py + a->ne[2])/w; - const int np = npx*npy; - - const int64_t ne[4] = { a->ne[0], w, w, np, }; - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); - - int32_t params[] = { npx, npy, w }; - lm_ggml_set_op_params(result, params, sizeof(params)); - - result->op = LM_GGML_OP_WIN_PART; - result->src[0] = a; - - return result; -} - -// lm_ggml_win_unpart - -struct lm_ggml_tensor * lm_ggml_win_unpart( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - int w0, - int h0, - int w) { - LM_GGML_ASSERT(a->type == LM_GGML_TYPE_F32); - - const int64_t ne[4] = { a->ne[0], w0, h0, 1, }; - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 3, ne); - - int32_t params[] = { w }; - lm_ggml_set_op_params(result, params, sizeof(params)); - - result->op = LM_GGML_OP_WIN_UNPART; - result->src[0] = a; - - return result; -} - -// lm_ggml_get_rel_pos - -struct lm_ggml_tensor * lm_ggml_get_rel_pos( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - int qh, - int kh) { - LM_GGML_ASSERT(qh == kh); - LM_GGML_ASSERT(2*MAX(qh, kh) - 1 == a->ne[1]); - - const int64_t ne[4] = { a->ne[0], kh, qh, 1, }; - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F16, 3, ne); - - result->op = LM_GGML_OP_GET_REL_POS; - result->src[0] = a; - - return result; -} - -// lm_ggml_add_rel_pos - -static struct lm_ggml_tensor * lm_ggml_add_rel_pos_impl( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * pw, - struct lm_ggml_tensor * ph, - bool inplace) { - LM_GGML_ASSERT(lm_ggml_are_same_shape(pw, ph)); - LM_GGML_ASSERT(lm_ggml_is_contiguous(a)); - LM_GGML_ASSERT(lm_ggml_is_contiguous(pw)); - LM_GGML_ASSERT(lm_ggml_is_contiguous(ph)); - LM_GGML_ASSERT(ph->type == LM_GGML_TYPE_F32); - LM_GGML_ASSERT(pw->type == LM_GGML_TYPE_F32); - LM_GGML_ASSERT(pw->ne[3] == a->ne[2]); - LM_GGML_ASSERT(pw->ne[0]*pw->ne[0] == a->ne[0]); - LM_GGML_ASSERT(pw->ne[1]*pw->ne[2] == a->ne[1]); - - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - lm_ggml_set_op_params_i32(result, 0, inplace ? 1 : 0); - - result->op = LM_GGML_OP_ADD_REL_POS; - result->src[0] = a; - result->src[1] = pw; - result->src[2] = ph; - - return result; -} - -struct lm_ggml_tensor * lm_ggml_add_rel_pos( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * pw, - struct lm_ggml_tensor * ph) { - return lm_ggml_add_rel_pos_impl(ctx, a, pw, ph, false); -} - -struct lm_ggml_tensor * lm_ggml_add_rel_pos_inplace( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * pw, - struct lm_ggml_tensor * ph) { - return lm_ggml_add_rel_pos_impl(ctx, a, pw, ph, true); -} - -// lm_ggml_rwkv_wkv - -struct lm_ggml_tensor * lm_ggml_rwkv_wkv( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * k, - struct lm_ggml_tensor * v, - struct lm_ggml_tensor * r, - struct lm_ggml_tensor * tf, - struct lm_ggml_tensor * td, - struct lm_ggml_tensor * state) { - LM_GGML_ASSERT(lm_ggml_is_contiguous(k)); - LM_GGML_ASSERT(lm_ggml_is_contiguous(v)); - LM_GGML_ASSERT(lm_ggml_is_contiguous(r)); - LM_GGML_ASSERT(lm_ggml_is_contiguous(tf)); - LM_GGML_ASSERT(lm_ggml_is_contiguous(td)); - LM_GGML_ASSERT(lm_ggml_is_contiguous(state)); - - const int64_t S = k->ne[0]; - const int64_t H = k->ne[2]; - const int64_t n_tokens = k->ne[3]; - const int64_t n_seqs = state->ne[1]; - { - LM_GGML_ASSERT(k->ne[1] == 1); - LM_GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens); - LM_GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens); - // TODO: RWKV v4 and v5 - LM_GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens); - LM_GGML_ASSERT(lm_ggml_nelements(state) == S * S * H * n_seqs); - } - - // concat output and new_state - const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 }; - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); - - result->op = LM_GGML_OP_RWKV_WKV; - result->src[0] = k; - result->src[1] = v; - result->src[2] = r; - result->src[3] = tf; - result->src[4] = td; - result->src[5] = state; - - return result; -} - -// lm_ggml_unary - -static struct lm_ggml_tensor * lm_ggml_unary_impl( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - enum lm_ggml_unary_op op, - bool inplace) { - LM_GGML_ASSERT(lm_ggml_is_contiguous_1(a)); - - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - - lm_ggml_set_op_params_i32(result, 0, (int32_t) op); - - result->op = LM_GGML_OP_UNARY; - result->src[0] = a; - - return result; -} - -struct lm_ggml_tensor * lm_ggml_unary( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - enum lm_ggml_unary_op op) { - return lm_ggml_unary_impl(ctx, a, op, false); -} - -struct lm_ggml_tensor * lm_ggml_unary_inplace( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - enum lm_ggml_unary_op op) { - return lm_ggml_unary_impl(ctx, a, op, true); -} - -// lm_ggml_map_unary - -static struct lm_ggml_tensor * lm_ggml_map_unary_impl_f32( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - const lm_ggml_unary_op_f32_t fun, - bool inplace) { - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - - lm_ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); - - result->op = LM_GGML_OP_MAP_UNARY; - result->src[0] = a; - - return result; -} - -struct lm_ggml_tensor * lm_ggml_map_unary_f32( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - const lm_ggml_unary_op_f32_t fun) { - return lm_ggml_map_unary_impl_f32(ctx, a, fun, false); -} - -struct lm_ggml_tensor * lm_ggml_map_unary_inplace_f32( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - const lm_ggml_unary_op_f32_t fun) { - return lm_ggml_map_unary_impl_f32(ctx, a, fun, true); -} - -// lm_ggml_map_binary - -static struct lm_ggml_tensor * lm_ggml_map_binary_impl_f32( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - const lm_ggml_binary_op_f32_t fun, - bool inplace) { - LM_GGML_ASSERT(lm_ggml_are_same_shape(a, b)); - - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - - lm_ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); - - result->op = LM_GGML_OP_MAP_BINARY; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -struct lm_ggml_tensor * lm_ggml_map_binary_f32( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - const lm_ggml_binary_op_f32_t fun) { - return lm_ggml_map_binary_impl_f32(ctx, a, b, fun, false); -} - -struct lm_ggml_tensor * lm_ggml_map_binary_inplace_f32( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - const lm_ggml_binary_op_f32_t fun) { - return lm_ggml_map_binary_impl_f32(ctx, a, b, fun, true); -} - -// lm_ggml_map_custom1_f32 - -static struct lm_ggml_tensor * lm_ggml_map_custom1_impl_f32( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - const lm_ggml_custom1_op_f32_t fun, - bool inplace) { - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - - lm_ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); - - result->op = LM_GGML_OP_MAP_CUSTOM1_F32; - result->src[0] = a; - - return result; -} - -struct lm_ggml_tensor * lm_ggml_map_custom1_f32( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - const lm_ggml_custom1_op_f32_t fun) { - return lm_ggml_map_custom1_impl_f32(ctx, a, fun, false); -} - -struct lm_ggml_tensor * lm_ggml_map_custom1_inplace_f32( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - const lm_ggml_custom1_op_f32_t fun) { - return lm_ggml_map_custom1_impl_f32(ctx, a, fun, true); -} - -// lm_ggml_map_custom2_f32 - -static struct lm_ggml_tensor * lm_ggml_map_custom2_impl_f32( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - const lm_ggml_custom2_op_f32_t fun, - bool inplace) { - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - - lm_ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); - - result->op = LM_GGML_OP_MAP_CUSTOM2_F32; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -struct lm_ggml_tensor * lm_ggml_map_custom2_f32( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - const lm_ggml_custom2_op_f32_t fun) { - return lm_ggml_map_custom2_impl_f32(ctx, a, b, fun, false); -} - -struct lm_ggml_tensor * lm_ggml_map_custom2_inplace_f32( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - const lm_ggml_custom2_op_f32_t fun) { - return lm_ggml_map_custom2_impl_f32(ctx, a, b, fun, true); -} - -// lm_ggml_map_custom3_f32 - -static struct lm_ggml_tensor * lm_ggml_map_custom3_impl_f32( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - struct lm_ggml_tensor * c, - const lm_ggml_custom3_op_f32_t fun, - bool inplace) { - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - - lm_ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); - - result->op = LM_GGML_OP_MAP_CUSTOM3_F32; - result->src[0] = a; - result->src[1] = b; - result->src[2] = c; - - return result; -} - -struct lm_ggml_tensor * lm_ggml_map_custom3_f32( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - struct lm_ggml_tensor * c, - const lm_ggml_custom3_op_f32_t fun) { - return lm_ggml_map_custom3_impl_f32(ctx, a, b, c, fun, false); -} - -struct lm_ggml_tensor * lm_ggml_map_custom3_inplace_f32( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - struct lm_ggml_tensor * c, - const lm_ggml_custom3_op_f32_t fun) { - return lm_ggml_map_custom3_impl_f32(ctx, a, b, c, fun, true); -} - -// lm_ggml_map_custom1 -struct lm_ggml_map_custom1_op_params { - lm_ggml_custom1_op_t fun; - int n_tasks; - void * userdata; -}; - -static struct lm_ggml_tensor * lm_ggml_map_custom1_impl( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - const lm_ggml_custom1_op_t fun, - int n_tasks, - void * userdata, - bool inplace) { - LM_GGML_ASSERT(n_tasks == LM_GGML_N_TASKS_MAX || n_tasks > 0); - - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - - struct lm_ggml_map_custom1_op_params params = { - /*.fun =*/ fun, - /*.n_tasks =*/ n_tasks, - /*.userdata =*/ userdata - }; - lm_ggml_set_op_params(result, (const void *) ¶ms, sizeof(params)); - - result->op = LM_GGML_OP_MAP_CUSTOM1; - result->src[0] = a; - - return result; -} - -struct lm_ggml_tensor * lm_ggml_map_custom1( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - const lm_ggml_custom1_op_t fun, - int n_tasks, - void * userdata) { - return lm_ggml_map_custom1_impl(ctx, a, fun, n_tasks, userdata, false); -} - -struct lm_ggml_tensor * lm_ggml_map_custom1_inplace( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - const lm_ggml_custom1_op_t fun, - int n_tasks, - void * userdata) { - return lm_ggml_map_custom1_impl(ctx, a, fun, n_tasks, userdata, true); -} - -// lm_ggml_map_custom2 - -struct lm_ggml_map_custom2_op_params { - lm_ggml_custom2_op_t fun; - int n_tasks; - void * userdata; -}; - -static struct lm_ggml_tensor * lm_ggml_map_custom2_impl( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - const lm_ggml_custom2_op_t fun, - int n_tasks, - void * userdata, - bool inplace) { - LM_GGML_ASSERT(n_tasks == LM_GGML_N_TASKS_MAX || n_tasks > 0); - - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - - struct lm_ggml_map_custom2_op_params params = { - /*.fun =*/ fun, - /*.n_tasks =*/ n_tasks, - /*.userdata =*/ userdata - }; - lm_ggml_set_op_params(result, (const void *) ¶ms, sizeof(params)); - - result->op = LM_GGML_OP_MAP_CUSTOM2; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -struct lm_ggml_tensor * lm_ggml_map_custom2( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - const lm_ggml_custom2_op_t fun, - int n_tasks, - void * userdata) { - return lm_ggml_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, false); -} - -struct lm_ggml_tensor * lm_ggml_map_custom2_inplace( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - const lm_ggml_custom2_op_t fun, - int n_tasks, - void * userdata) { - return lm_ggml_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, true); -} - -// lm_ggml_map_custom3 - -struct lm_ggml_map_custom3_op_params { - lm_ggml_custom3_op_t fun; - int n_tasks; - void * userdata; -}; - -static struct lm_ggml_tensor * lm_ggml_map_custom3_impl( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - struct lm_ggml_tensor * c, - const lm_ggml_custom3_op_t fun, - int n_tasks, - void * userdata, - bool inplace) { - LM_GGML_ASSERT(n_tasks == LM_GGML_N_TASKS_MAX || n_tasks > 0); - - struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - - struct lm_ggml_map_custom3_op_params params = { - /*.fun =*/ fun, - /*.n_tasks =*/ n_tasks, - /*.userdata =*/ userdata - }; - lm_ggml_set_op_params(result, (const void *) ¶ms, sizeof(params)); - - result->op = LM_GGML_OP_MAP_CUSTOM3; - result->src[0] = a; - result->src[1] = b; - result->src[2] = c; - - return result; -} - -struct lm_ggml_tensor * lm_ggml_map_custom3( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - struct lm_ggml_tensor * c, - const lm_ggml_custom3_op_t fun, - int n_tasks, - void * userdata) { - return lm_ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, false); -} - -struct lm_ggml_tensor * lm_ggml_map_custom3_inplace( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - struct lm_ggml_tensor * c, - const lm_ggml_custom3_op_t fun, - int n_tasks, - void * userdata) { - return lm_ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, true); -} - -// lm_ggml_cross_entropy_loss - -struct lm_ggml_tensor * lm_ggml_cross_entropy_loss( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b) { - LM_GGML_ASSERT(lm_ggml_are_same_shape(a, b)); - - struct lm_ggml_tensor * result = lm_ggml_new_tensor_1d(ctx, a->type, 1); - - result->op = LM_GGML_OP_CROSS_ENTROPY_LOSS; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -// lm_ggml_cross_entropy_loss_back - -struct lm_ggml_tensor * lm_ggml_cross_entropy_loss_back( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - struct lm_ggml_tensor * c) { - LM_GGML_ASSERT(lm_ggml_are_same_shape(a, b)); - LM_GGML_ASSERT(lm_ggml_is_scalar(c)); - - struct lm_ggml_tensor * result = lm_ggml_dup_tensor(ctx, a); - - result->op = LM_GGML_OP_CROSS_ENTROPY_LOSS_BACK; - result->src[0] = a; - result->src[1] = b; - result->src[2] = c; - - return result; -} - -// opt_step_adamw - -struct lm_ggml_tensor * lm_ggml_opt_step_adamw( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * grad, - float alpha, - float beta1, - float beta2, - float eps, - float wd) { - LM_GGML_ASSERT(a->flags & LM_GGML_TENSOR_FLAG_PARAM); - LM_GGML_ASSERT(lm_ggml_are_same_shape(a, grad)); - LM_GGML_ASSERT(alpha > 0.0f); - LM_GGML_ASSERT(beta1 >= 0.0f && beta1 <= 1.0f); - LM_GGML_ASSERT(beta2 >= 0.0f && beta2 <= 1.0f); - LM_GGML_ASSERT(eps >= 0.0f); - LM_GGML_ASSERT(wd >= 0.0f && wd <= 1.0f); - - struct lm_ggml_tensor * result = lm_ggml_view_tensor(ctx, a); - - const int64_t iter = 1; - memcpy(&result->op_params[0], &iter, sizeof(int64_t)); - lm_ggml_set_op_params_f32(result, 2, alpha); - lm_ggml_set_op_params_f32(result, 3, beta1); - lm_ggml_set_op_params_f32(result, 4, beta2); - lm_ggml_set_op_params_f32(result, 5, eps); - lm_ggml_set_op_params_f32(result, 6, wd); - - result->op = LM_GGML_OP_OPT_STEP_ADAMW; - result->src[0] = a; - result->src[1] = grad; - result->src[2] = lm_ggml_dup_tensor(ctx, grad); - result->src[3] = lm_ggml_dup_tensor(ctx, grad); - - return result; -} - -//////////////////////////////////////////////////////////////////////////////// - -// lm_ggml_compute_forward_dup - -static void lm_ggml_compute_forward_dup_same_cont( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - LM_GGML_ASSERT(lm_ggml_nelements(dst) == lm_ggml_nelements(src0)); - LM_GGML_ASSERT(lm_ggml_is_contiguous(dst) && lm_ggml_is_contiguous(src0)); - LM_GGML_ASSERT(src0->type == dst->type); - - const size_t nb0 = lm_ggml_type_size(src0->type); - - const int ith = params->ith; // thread index - const int nth = params->nth; // number of threads - - // parallelize by elements - const int ne = lm_ggml_nelements(dst); - const int dr = (ne + nth - 1) / nth; - const int ie0 = dr * ith; - const int ie1 = MIN(ie0 + dr, ne); - - if (ie0 < ie1) { - memcpy( - ((char *) dst->data + ie0*nb0), - ((char *) src0->data + ie0*nb0), - (ie1 - ie0) * nb0); - } -} - -static void lm_ggml_compute_forward_dup_f16( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - LM_GGML_ASSERT(lm_ggml_nelements(dst) == lm_ggml_nelements(src0)); - - LM_GGML_TENSOR_UNARY_OP_LOCALS - - const int ith = params->ith; // thread index - const int nth = params->nth; // number of threads - - // parallelize by rows - const int nr = ne01; - // number of rows per thread - const int dr = (nr + nth - 1) / nth; - // row range for this thread - const int ir0 = dr * ith; - const int ir1 = MIN(ir0 + dr, nr); - - if (src0->type == dst->type && - ne00 == ne0 && - nb00 == lm_ggml_type_size(src0->type) && nb0 == lm_ggml_type_size(dst->type)) { - // copy by rows - const size_t rs = ne00*nb00; - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = ir0; i01 < ir1; i01++) { - memcpy( - ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), - ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), - rs); - } - } - } - return; - } - - // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy - - if (lm_ggml_is_contiguous(dst)) { - if (nb00 == sizeof(lm_ggml_fp16_t)) { - if (dst->type == LM_GGML_TYPE_F16) { - size_t id = 0; - const size_t rs = ne00 * nb00; - char * dst_ptr = (char *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += rs * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; - memcpy(dst_ptr + id, src0_ptr, rs); - id += rs; - } - id += rs * (ne01 - ir1); - } - } - } else if (dst->type == LM_GGML_TYPE_F32) { - size_t id = 0; - float * dst_ptr = (float *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - const lm_ggml_fp16_t * src0_ptr = (lm_ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - for (int i00 = 0; i00 < ne00; i00++) { - dst_ptr[id] = LM_GGML_FP16_TO_FP32(src0_ptr[i00]); - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else if (type_traits[dst->type].from_float) { - lm_ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float; - float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; - - size_t id = 0; - size_t rs = nb0 * (ne00 / lm_ggml_blck_size(dst->type)); - char * dst_ptr = (char *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += rs * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - const lm_ggml_fp16_t * src0_ptr = (lm_ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - - for (int i00 = 0; i00 < ne00; i00++) { - src0_f32[i00] = LM_GGML_FP16_TO_FP32(src0_ptr[i00]); - } - - quantize_row_q(src0_f32, dst_ptr + id, ne00); - id += rs; - } - id += rs * (ne01 - ir1); - } - } - } else { - LM_GGML_ABORT("fatal error"); // TODO: implement - } - } else { - //printf("%s: this is not optimal - fix me\n", __func__); - - if (dst->type == LM_GGML_TYPE_F32) { - size_t id = 0; - float * dst_ptr = (float *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const lm_ggml_fp16_t * src0_ptr = (lm_ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = LM_GGML_FP16_TO_FP32(*src0_ptr); - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else if (dst->type == LM_GGML_TYPE_F16) { - size_t id = 0; - lm_ggml_fp16_t * dst_ptr = (lm_ggml_fp16_t *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const lm_ggml_fp16_t * src0_ptr = (lm_ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = *src0_ptr; - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else { - LM_GGML_ABORT("fatal error"); // TODO: implement - } - } - return; - } - - // dst counters - int64_t i10 = 0; - int64_t i11 = 0; - int64_t i12 = 0; - int64_t i13 = 0; - - if (dst->type == LM_GGML_TYPE_F16) { - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - i10 += ne00 * ir0; - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - - memcpy(dst_ptr, src0_ptr, sizeof(lm_ggml_fp16_t)); - - if (++i10 == ne00) { - i10 = 0; - if (++i11 == ne01) { - i11 = 0; - if (++i12 == ne02) { - i12 = 0; - if (++i13 == ne03) { - i13 = 0; - } - } - } - } - } - } - i10 += ne00 * (ne01 - ir1); - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - } else if (dst->type == LM_GGML_TYPE_F32) { - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - i10 += ne00 * ir0; - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - - *(float *) dst_ptr = LM_GGML_FP16_TO_FP32(*(const lm_ggml_fp16_t *) src0_ptr); - - if (++i10 == ne0) { - i10 = 0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - i10 += ne00 * (ne01 - ir1); - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - } else { - LM_GGML_ABORT("fatal error"); // TODO: implement - } -} - -static void lm_ggml_compute_forward_dup_bf16( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - LM_GGML_ASSERT(lm_ggml_nelements(dst) == lm_ggml_nelements(src0)); - - LM_GGML_TENSOR_UNARY_OP_LOCALS - - const int ith = params->ith; // thread index - const int nth = params->nth; // number of threads - - // parallelize by rows - const int nr = ne01; - // number of rows per thread - const int dr = (nr + nth - 1) / nth; - // row range for this thread - const int ir0 = dr * ith; - const int ir1 = MIN(ir0 + dr, nr); - - if (src0->type == dst->type && - ne00 == ne0 && - nb00 == lm_ggml_type_size(src0->type) && nb0 == lm_ggml_type_size(dst->type)) { - // copy by rows - const size_t rs = ne00*nb00; - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = ir0; i01 < ir1; i01++) { - memcpy( - ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), - ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), - rs); - } - } - } - return; - } - - // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy - - if (lm_ggml_is_contiguous(dst)) { - if (nb00 == sizeof(lm_ggml_bf16_t)) { - if (dst->type == LM_GGML_TYPE_BF16) { - size_t id = 0; - const size_t rs = ne00 * nb00; - char * dst_ptr = (char *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += rs * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; - memcpy(dst_ptr + id, src0_ptr, rs); - id += rs; - } - id += rs * (ne01 - ir1); - } - } - } else if (dst->type == LM_GGML_TYPE_F16) { - size_t id = 0; - lm_ggml_fp16_t * dst_ptr = (lm_ggml_fp16_t *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - const lm_ggml_bf16_t * src0_ptr = (lm_ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - for (int i00 = 0; i00 < ne00; i00++) { - dst_ptr[id] = LM_GGML_FP32_TO_FP16(LM_GGML_BF16_TO_FP32(src0_ptr[i00])); - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else if (dst->type == LM_GGML_TYPE_F32) { - size_t id = 0; - float * dst_ptr = (float *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - const lm_ggml_bf16_t * src0_ptr = (lm_ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - for (int i00 = 0; i00 < ne00; i00++) { - dst_ptr[id] = LM_GGML_BF16_TO_FP32(src0_ptr[i00]); - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else if (type_traits[dst->type].from_float) { - lm_ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float; - float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; - - size_t id = 0; - size_t rs = nb0 * (ne00 / lm_ggml_blck_size(dst->type)); - char * dst_ptr = (char *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += rs * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - const lm_ggml_bf16_t * src0_ptr = (lm_ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - - for (int i00 = 0; i00 < ne00; i00++) { - src0_f32[i00] = LM_GGML_BF16_TO_FP32(src0_ptr[i00]); - } - - quantize_row_q(src0_f32, dst_ptr + id, ne00); - id += rs; - } - id += rs * (ne01 - ir1); - } - } - } else { - LM_GGML_ABORT("fatal error"); // TODO: implement - } - } else { - //printf("%s: this is not optimal - fix me\n", __func__); - - if (dst->type == LM_GGML_TYPE_F32) { - size_t id = 0; - float * dst_ptr = (float *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const lm_ggml_bf16_t * src0_ptr = (lm_ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = LM_GGML_BF16_TO_FP32(*src0_ptr); - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else if (dst->type == LM_GGML_TYPE_BF16) { - size_t id = 0; - lm_ggml_bf16_t * dst_ptr = (lm_ggml_bf16_t *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const lm_ggml_bf16_t * src0_ptr = (lm_ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = *src0_ptr; - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else if (dst->type == LM_GGML_TYPE_F16) { - size_t id = 0; - lm_ggml_fp16_t * dst_ptr = (lm_ggml_fp16_t *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const lm_ggml_bf16_t * src0_ptr = (lm_ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = LM_GGML_FP32_TO_FP16(LM_GGML_BF16_TO_FP32(*src0_ptr)); - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else { - LM_GGML_ABORT("fatal error"); // TODO: implement - } - } - return; - } - - // dst counters - int64_t i10 = 0; - int64_t i11 = 0; - int64_t i12 = 0; - int64_t i13 = 0; - - if (dst->type == LM_GGML_TYPE_BF16) { - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - i10 += ne00 * ir0; - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - - memcpy(dst_ptr, src0_ptr, sizeof(lm_ggml_bf16_t)); - - if (++i10 == ne00) { - i10 = 0; - if (++i11 == ne01) { - i11 = 0; - if (++i12 == ne02) { - i12 = 0; - if (++i13 == ne03) { - i13 = 0; - } - } - } - } - } - } - i10 += ne00 * (ne01 - ir1); - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - } else if (dst->type == LM_GGML_TYPE_F16) { - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - i10 += ne00 * ir0; - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - - *(lm_ggml_fp16_t *) dst_ptr = LM_GGML_FP32_TO_FP16(LM_GGML_BF16_TO_FP32(*(const lm_ggml_bf16_t *) src0_ptr)); - - if (++i10 == ne0) { - i10 = 0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - i10 += ne00 * (ne01 - ir1); - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - } else if (dst->type == LM_GGML_TYPE_F32) { - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - i10 += ne00 * ir0; - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - - *(float *) dst_ptr = LM_GGML_BF16_TO_FP32(*(const lm_ggml_bf16_t *) src0_ptr); - - if (++i10 == ne0) { - i10 = 0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - i10 += ne00 * (ne01 - ir1); - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - } else { - LM_GGML_ABORT("fatal error"); // TODO: implement - } -} - -static void lm_ggml_compute_forward_dup_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - LM_GGML_ASSERT(lm_ggml_nelements(dst) == lm_ggml_nelements(src0)); - - LM_GGML_TENSOR_UNARY_OP_LOCALS - - const int ith = params->ith; // thread index - const int nth = params->nth; // number of threads - - // parallelize by rows - const int nr = ne01; - // number of rows per thread - const int dr = (nr + nth - 1) / nth; - // row range for this thread - const int ir0 = dr * ith; - const int ir1 = MIN(ir0 + dr, nr); - - if (src0->type == dst->type && - ne00 == ne0 && - nb00 == lm_ggml_type_size(src0->type) && nb0 == lm_ggml_type_size(dst->type)) { - // copy by rows - const size_t rs = ne00*nb00; - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = ir0; i01 < ir1; i01++) { - memcpy( - ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), - ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), - rs); - } - } - } - return; - } - - if (lm_ggml_is_contiguous(dst)) { - // TODO: simplify - if (nb00 == sizeof(float)) { - if (dst->type == LM_GGML_TYPE_F32) { - size_t id = 0; - const size_t rs = ne00 * nb00; - char * dst_ptr = (char *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += rs * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; - memcpy(dst_ptr + id, src0_ptr, rs); - id += rs; - } - id += rs * (ne01 - ir1); - } - } - } else if (type_traits[dst->type].from_float) { - lm_ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float; - - size_t id = 0; - size_t rs = nb0 * (ne00 / lm_ggml_blck_size(dst->type)); - char * dst_ptr = (char *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += rs * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - quantize_row_q(src0_ptr, dst_ptr + id, ne00); - id += rs; - } - id += rs * (ne01 - ir1); - } - } - } else { - LM_GGML_ABORT("fatal error"); // TODO: implement - } - } else { - //printf("%s: this is not optimal - fix me\n", __func__); - - if (dst->type == LM_GGML_TYPE_F32) { - size_t id = 0; - float * dst_ptr = (float *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = *src0_ptr; - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else if (dst->type == LM_GGML_TYPE_F16) { - size_t id = 0; - lm_ggml_fp16_t * dst_ptr = (lm_ggml_fp16_t *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = LM_GGML_FP32_TO_FP16(*src0_ptr); - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else if (dst->type == LM_GGML_TYPE_BF16) { - size_t id = 0; - lm_ggml_bf16_t * dst_ptr = (lm_ggml_bf16_t *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = LM_GGML_FP32_TO_BF16(*src0_ptr); - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else { - LM_GGML_ABORT("fatal error"); // TODO: implement - } - } - - return; - } - - // dst counters - - int64_t i10 = 0; - int64_t i11 = 0; - int64_t i12 = 0; - int64_t i13 = 0; - - if (dst->type == LM_GGML_TYPE_F32) { - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - i10 += ne00 * ir0; - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - - memcpy(dst_ptr, src0_ptr, sizeof(float)); - - if (++i10 == ne0) { - i10 = 0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - i10 += ne00 * (ne01 - ir1); - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - } else if (dst->type == LM_GGML_TYPE_F16) { - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - i10 += ne00 * ir0; - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - - *(lm_ggml_fp16_t *) dst_ptr = LM_GGML_FP32_TO_FP16(*(const float *) src0_ptr); - - if (++i10 == ne0) { - i10 = 0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - i10 += ne00 * (ne01 - ir1); - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - } else if (dst->type == LM_GGML_TYPE_BF16) { - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - i10 += ne00 * ir0; - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - - *(lm_ggml_bf16_t *) dst_ptr = LM_GGML_FP32_TO_BF16(*(const float *) src0_ptr); - - if (++i10 == ne0) { - i10 = 0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - i10 += ne00 * (ne01 - ir1); - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - } else { - LM_GGML_ABORT("fatal error"); // TODO: implement - } -} - -// A simplified version of lm_ggml_compute_forward_dup that doesn't do float upcasting, and just plain old memcpy. -static void lm_ggml_compute_forward_dup_bytes( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - LM_GGML_ASSERT(lm_ggml_nelements(dst) == lm_ggml_nelements(src0)); - LM_GGML_ASSERT(src0->type == dst->type); - - LM_GGML_TENSOR_UNARY_OP_LOCALS; - - if (lm_ggml_is_contiguous(src0) && lm_ggml_is_contiguous(dst)) { - lm_ggml_compute_forward_dup_same_cont(params, dst); - return; - } - - const size_t type_size = lm_ggml_type_size(src0->type); - const int ith = params->ith; // thread index - const int nth = params->nth; // number of threads - - - // parallelize by rows - const int nr = ne01; - // number of rows per thread - const int dr = (nr + nth - 1) / nth; - // row range for this thread - const int ir0 = dr * ith; - const int ir1 = MIN(ir0 + dr, nr); - - if (src0->type == dst->type && - ne00 == ne0 && - nb00 == type_size && nb0 == type_size) { - // copy by rows - const size_t rs = ne00 * type_size; - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = ir0; i01 < ir1; i01++) { - memcpy( - ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), - ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), - rs); - } - } - } - return; - } - - if (lm_ggml_is_contiguous(dst)) { - size_t id = 0; - char * dst_ptr = (char *) dst->data; - const size_t rs = ne00 * type_size; - - if (nb00 == type_size) { - // src0 is contigous on first dimension, copy by rows - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - id += rs * ir0; - for (int64_t i01 = ir0; i01 < ir1; i01++) { - const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; - memcpy(dst_ptr + id, src0_ptr, rs); - id += rs; - } - id += rs * (ne01 - ir1); - } - } - } else { - //printf("%s: this is not optimal - fix me\n", __func__); - - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - id += rs * ir0; - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03; - memcpy(dst_ptr + id, src0_ptr, type_size); - - id += type_size; - } - } - id += rs * (ne01 - ir1); - } - } - } - - return; - } - - // dst counters - - int64_t i10 = 0; - int64_t i11 = 0; - int64_t i12 = 0; - int64_t i13 = 0; - - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - i10 += ne00 * ir0; - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - - memcpy(dst_ptr, src0_ptr, type_size); - - if (++i10 == ne0) { - i10 = 0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - i10 += ne00 * (ne01 - ir1); - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } -} - -static void lm_ggml_compute_forward_dup( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - if (src0->type == dst->type) { - lm_ggml_compute_forward_dup_bytes(params, dst); - return; - } - - switch (src0->type) { - case LM_GGML_TYPE_F16: - { - lm_ggml_compute_forward_dup_f16(params, dst); - } break; - case LM_GGML_TYPE_BF16: - { - lm_ggml_compute_forward_dup_bf16(params, dst); - } break; - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_dup_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_add - -static void lm_ggml_compute_forward_add_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - LM_GGML_ASSERT(lm_ggml_can_repeat(src1, src0) && lm_ggml_are_same_shape(src0, dst)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = lm_ggml_nrows(src0); - - LM_GGML_TENSOR_BINARY_OP_LOCALS - - LM_GGML_ASSERT( nb0 == sizeof(float)); - LM_GGML_ASSERT(nb00 == sizeof(float)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - if (nb10 == sizeof(float)) { - for (int ir = ir0; ir < ir1; ++ir) { - // src1 is broadcastable across src0 and dst in i1, i2, i3 - const int64_t i03 = ir/(ne02*ne01); - const int64_t i02 = (ir - i03*ne02*ne01)/ne01; - const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - const int64_t nr0 = ne00 / ne10; - - float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); - float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); - - for (int64_t r = 0; r < nr0; ++r) { -#ifdef LM_GGML_USE_ACCELERATE - vDSP_vadd(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10); -#else - lm_ggml_vec_add_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); -#endif - } - } - } else { - // src1 is not contiguous - for (int ir = ir0; ir < ir1; ++ir) { - // src1 is broadcastable across src0 and dst in i1, i2, i3 - const int64_t i03 = ir/(ne02*ne01); - const int64_t i02 = (ir - i03*ne02*ne01)/ne01; - const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); - - for (int64_t i0 = 0; i0 < ne0; ++i0) { - const int64_t i10 = i0 % ne10; - float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10); - - dst_ptr[i0] = src0_ptr[i0] + *src1_ptr; - } - } - } -} - -static void lm_ggml_compute_forward_add_f16_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, src1) && lm_ggml_are_same_shape(src0, dst)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = lm_ggml_nrows(src0); - - LM_GGML_TENSOR_BINARY_OP_LOCALS - - LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F16); - LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); - - if (dst->type == LM_GGML_TYPE_F32) { - LM_GGML_ASSERT( nb0 == sizeof(float)); - } - else { - LM_GGML_ASSERT(dst->type == LM_GGML_TYPE_F16); - LM_GGML_ASSERT( nb0 == sizeof(lm_ggml_fp16_t)); - } - - LM_GGML_ASSERT(nb00 == sizeof(lm_ggml_fp16_t)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - if (nb10 == sizeof(float)) { - if (dst->type == LM_GGML_TYPE_F16) { - for (int ir = ir0; ir < ir1; ++ir) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - lm_ggml_fp16_t * dst_ptr = (lm_ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); - lm_ggml_fp16_t * src0_ptr = (lm_ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); - - for (int i = 0; i < ne0; i++) { - dst_ptr[i] = LM_GGML_FP32_TO_FP16(LM_GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]); - } - } - } else { - for (int ir = ir0; ir < ir1; ++ir) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); - lm_ggml_fp16_t * src0_ptr = (lm_ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); - - for (int i = 0; i < ne0; i++) { - dst_ptr[i] = LM_GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]; - } - } - } - } - else { - // src1 is not contiguous - LM_GGML_ABORT("fatal error"); - } -} - -static void lm_ggml_compute_forward_add_bf16_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, src1) && lm_ggml_are_same_shape(src0, dst)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = lm_ggml_nrows(src0); - - LM_GGML_TENSOR_BINARY_OP_LOCALS - - LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_BF16); - LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); - - if (dst->type == LM_GGML_TYPE_F32) { - LM_GGML_ASSERT( nb0 == sizeof(float)); - } - else { - LM_GGML_ASSERT(dst->type == LM_GGML_TYPE_BF16); - LM_GGML_ASSERT( nb0 == sizeof(lm_ggml_bf16_t)); - } - - LM_GGML_ASSERT(nb00 == sizeof(lm_ggml_bf16_t)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - if (nb10 == sizeof(float)) { - if (dst->type == LM_GGML_TYPE_BF16) { - for (int ir = ir0; ir < ir1; ++ir) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - lm_ggml_bf16_t * dst_ptr = (lm_ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); - lm_ggml_bf16_t * src0_ptr = (lm_ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); - - for (int i = 0; i < ne0; i++) { - dst_ptr[i] = LM_GGML_FP32_TO_BF16(LM_GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i]); - } - } - } else { - for (int ir = ir0; ir < ir1; ++ir) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); - lm_ggml_bf16_t * src0_ptr = (lm_ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); - - for (int i = 0; i < ne0; i++) { - dst_ptr[i] = LM_GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i]; - } - } - } - } - else { - // src1 is not contiguous - LM_GGML_ABORT("fatal error"); - } -} - -static void lm_ggml_compute_forward_add_f16_f16( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, src1) && lm_ggml_are_same_shape(src0, dst)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = lm_ggml_nrows(src0); - - LM_GGML_TENSOR_BINARY_OP_LOCALS - - LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F16); - LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F16); - LM_GGML_ASSERT(dst->type == LM_GGML_TYPE_F16); - - LM_GGML_ASSERT( nb0 == sizeof(lm_ggml_fp16_t)); - LM_GGML_ASSERT(nb00 == sizeof(lm_ggml_fp16_t)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - if (nb10 == sizeof(lm_ggml_fp16_t)) { - for (int ir = ir0; ir < ir1; ++ir) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - lm_ggml_fp16_t * dst_ptr = (lm_ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); - lm_ggml_fp16_t * src0_ptr = (lm_ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - lm_ggml_fp16_t * src1_ptr = (lm_ggml_fp16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); - - for (int i = 0; i < ne0; i++) { - dst_ptr[i] = LM_GGML_FP32_TO_FP16(LM_GGML_FP16_TO_FP32(src0_ptr[i]) + LM_GGML_FP16_TO_FP32(src1_ptr[i])); - } - } - } - else { - // src1 is not contiguous - LM_GGML_ABORT("fatal error"); - } -} - -static void lm_ggml_compute_forward_add_bf16_bf16( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, src1) && lm_ggml_are_same_shape(src0, dst)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = lm_ggml_nrows(src0); - - LM_GGML_TENSOR_BINARY_OP_LOCALS - - LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_BF16); - LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_BF16); - LM_GGML_ASSERT(dst->type == LM_GGML_TYPE_BF16); - - LM_GGML_ASSERT( nb0 == sizeof(lm_ggml_bf16_t)); - LM_GGML_ASSERT(nb00 == sizeof(lm_ggml_bf16_t)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - if (nb10 == sizeof(lm_ggml_bf16_t)) { - for (int ir = ir0; ir < ir1; ++ir) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - lm_ggml_bf16_t * dst_ptr = (lm_ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); - lm_ggml_bf16_t * src0_ptr = (lm_ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - lm_ggml_bf16_t * src1_ptr = (lm_ggml_bf16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); - - for (int i = 0; i < ne0; i++) { - dst_ptr[i] = LM_GGML_FP32_TO_BF16(LM_GGML_BF16_TO_FP32(src0_ptr[i]) + LM_GGML_BF16_TO_FP32(src1_ptr[i])); - } - } - } - else { - // src1 is not contiguous - LM_GGML_ABORT("fatal error"); - } -} - -static void lm_ggml_compute_forward_add_q_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, src1) && lm_ggml_are_same_shape(src0, dst)); - - const int nr = lm_ggml_nrows(src0); - - LM_GGML_TENSOR_BINARY_OP_LOCALS - - const int ith = params->ith; - const int nth = params->nth; - - const enum lm_ggml_type type = src0->type; - const enum lm_ggml_type dtype = dst->type; - lm_ggml_to_float_t const dequantize_row_q = type_traits[type].to_float; - lm_ggml_from_float_t const quantize_row_q = type_traits[dtype].from_float; - - // we don't support permuted src0 or src1 - LM_GGML_ASSERT(nb00 == lm_ggml_type_size(type)); - LM_GGML_ASSERT(nb10 == sizeof(float)); - - // dst cannot be transposed or permuted - LM_GGML_ASSERT(nb0 <= nb1); - LM_GGML_ASSERT(nb1 <= nb2); - LM_GGML_ASSERT(nb2 <= nb3); - - LM_GGML_ASSERT(lm_ggml_is_quantized(src0->type)); - LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; - - for (int ir = ir0; ir < ir1; ++ir) { - // src0 indices - const int i03 = ir/(ne02*ne01); - const int i02 = (ir - i03*ne02*ne01)/ne01; - const int i01 = (ir - i03*ne02*ne01 - i02*ne01); - - // src1 and dst are same shape as src0 => same indices - const int i13 = i03; - const int i12 = i02; - const int i11 = i01; - - const int i3 = i03; - const int i2 = i02; - const int i1 = i01; - - void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); - float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13)); - void * dst_row = (void *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); - - assert(ne00 % 32 == 0); - - // unquantize row from src0 to temp buffer - dequantize_row_q(src0_row, wdata, ne00); - // add src1 - lm_ggml_vec_acc_f32(ne00, wdata, src1_row); - // quantize row to dst - if (quantize_row_q != NULL) { - quantize_row_q(wdata, dst_row, ne00); - } else { - memcpy(dst_row, wdata, ne0*nb0); - } - } -} - -static void lm_ggml_compute_forward_add( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - if (src1->type == LM_GGML_TYPE_F32) { - lm_ggml_compute_forward_add_f32(params, dst); - } - else { - LM_GGML_ABORT("fatal error"); - } - } break; - case LM_GGML_TYPE_F16: - { - if (src1->type == LM_GGML_TYPE_F16) { - lm_ggml_compute_forward_add_f16_f16(params, dst); - } - else if (src1->type == LM_GGML_TYPE_F32) { - lm_ggml_compute_forward_add_f16_f32(params, dst); - } - else { - LM_GGML_ABORT("fatal error"); - } - } break; - case LM_GGML_TYPE_BF16: - { - if (src1->type == LM_GGML_TYPE_BF16) { - lm_ggml_compute_forward_add_bf16_bf16(params, dst); - } - else if (src1->type == LM_GGML_TYPE_F32) { - lm_ggml_compute_forward_add_bf16_f32(params, dst); - } - else { - LM_GGML_ABORT("fatal error"); - } - } break; - case LM_GGML_TYPE_Q4_0: - case LM_GGML_TYPE_Q4_1: - case LM_GGML_TYPE_Q5_0: - case LM_GGML_TYPE_Q5_1: - case LM_GGML_TYPE_Q8_0: - case LM_GGML_TYPE_Q2_K: - case LM_GGML_TYPE_Q3_K: - case LM_GGML_TYPE_Q4_K: - case LM_GGML_TYPE_Q5_K: - case LM_GGML_TYPE_Q6_K: - case LM_GGML_TYPE_TQ1_0: - case LM_GGML_TYPE_TQ2_0: - case LM_GGML_TYPE_IQ2_XXS: - case LM_GGML_TYPE_IQ2_XS: - case LM_GGML_TYPE_IQ3_XXS: - case LM_GGML_TYPE_IQ1_S: - case LM_GGML_TYPE_IQ1_M: - case LM_GGML_TYPE_IQ4_NL: - case LM_GGML_TYPE_IQ4_XS: - case LM_GGML_TYPE_IQ3_S: - case LM_GGML_TYPE_IQ2_S: - case LM_GGML_TYPE_Q4_0_4_4: - case LM_GGML_TYPE_Q4_0_4_8: - case LM_GGML_TYPE_Q4_0_8_8: - { - lm_ggml_compute_forward_add_q_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_add1 - -static void lm_ggml_compute_forward_add1_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); - LM_GGML_ASSERT(lm_ggml_is_scalar(src1)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = lm_ggml_nrows(src0); - - LM_GGML_TENSOR_UNARY_OP_LOCALS - - LM_GGML_ASSERT( nb0 == sizeof(float)); - LM_GGML_ASSERT(nb00 == sizeof(float)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int ir = ir0; ir < ir1; ++ir) { - // src0 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - -#ifdef LM_GGML_USE_ACCELERATE - UNUSED(lm_ggml_vec_add1_f32); - - vDSP_vadd( - (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1, - (float *) ((char *) src1->data), 0, - (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1, - ne0); -#else - lm_ggml_vec_add1_f32(ne0, - (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), - (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), - *(float *) src1->data); -#endif - } -} - -static void lm_ggml_compute_forward_add1_f16_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); - LM_GGML_ASSERT(lm_ggml_is_scalar(src1)); - - // scalar to add - const float v = *(float *) src1->data; - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = lm_ggml_nrows(src0); - - LM_GGML_TENSOR_UNARY_OP_LOCALS - - LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F16); - LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); - LM_GGML_ASSERT(dst->type == LM_GGML_TYPE_F16); - - LM_GGML_ASSERT( nb0 == sizeof(lm_ggml_fp16_t)); - LM_GGML_ASSERT(nb00 == sizeof(lm_ggml_fp16_t)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int ir = ir0; ir < ir1; ++ir) { - // src0 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - lm_ggml_fp16_t * dst_ptr = (lm_ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); - lm_ggml_fp16_t * src0_ptr = (lm_ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - for (int i = 0; i < ne0; i++) { - dst_ptr[i] = LM_GGML_FP32_TO_FP16(LM_GGML_FP16_TO_FP32(src0_ptr[i]) + v); - } - } -} - -static void lm_ggml_compute_forward_add1_f16_f16( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); - LM_GGML_ASSERT(lm_ggml_is_scalar(src1)); - - // scalar to add - const float v = LM_GGML_FP16_TO_FP32(*(lm_ggml_fp16_t *) src1->data); - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = lm_ggml_nrows(src0); - - LM_GGML_TENSOR_UNARY_OP_LOCALS - - LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F16); - LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F16); - LM_GGML_ASSERT(dst->type == LM_GGML_TYPE_F16); - - LM_GGML_ASSERT( nb0 == sizeof(lm_ggml_fp16_t)); - LM_GGML_ASSERT(nb00 == sizeof(lm_ggml_fp16_t)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int ir = ir0; ir < ir1; ++ir) { - // src0 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - lm_ggml_fp16_t * dst_ptr = (lm_ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); - lm_ggml_fp16_t * src0_ptr = (lm_ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - for (int i = 0; i < ne0; i++) { - dst_ptr[i] = LM_GGML_FP32_TO_FP16(LM_GGML_FP16_TO_FP32(src0_ptr[i]) + v); - } - } -} - -static void lm_ggml_compute_forward_add1_q_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); - LM_GGML_ASSERT(lm_ggml_is_scalar(src1)); - - // scalar to add - const float v = *(float *) src1->data; - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = lm_ggml_nrows(src0); - - LM_GGML_TENSOR_UNARY_OP_LOCALS - - const enum lm_ggml_type type = src0->type; - lm_ggml_to_float_t const dequantize_row_q = type_traits[type].to_float; - lm_ggml_from_float_t const quantize_row_q = type_traits[type].from_float; - - // we don't support permuted src0 - LM_GGML_ASSERT(nb00 == lm_ggml_type_size(type)); - - // dst cannot be transposed or permuted - LM_GGML_ASSERT(nb0 <= nb1); - LM_GGML_ASSERT(nb1 <= nb2); - LM_GGML_ASSERT(nb2 <= nb3); - - LM_GGML_ASSERT(lm_ggml_is_quantized(src0->type)); - LM_GGML_ASSERT(dst->type == src0->type); - LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith; - - for (int ir = ir0; ir < ir1; ++ir) { - // src0 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - void * src0_row = (void *) ((char *) src0->data + (i1*nb01 + i2*nb02 + i3*nb03)); - void * dst_row = (void *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb0 )); - - assert(ne0 % 32 == 0); - - // unquantize row from src0 to temp buffer - dequantize_row_q(src0_row, wdata, ne0); - // add src1 - lm_ggml_vec_acc1_f32(ne0, wdata, v); - // quantize row to dst - quantize_row_q(wdata, dst_row, ne0); - } -} - -static void lm_ggml_compute_forward_add1_bf16_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); - LM_GGML_ASSERT(lm_ggml_is_scalar(src1)); - - // scalar to add - const float v = *(float *) src1->data; - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = lm_ggml_nrows(src0); - - LM_GGML_TENSOR_UNARY_OP_LOCALS - - LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_BF16); - LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); - LM_GGML_ASSERT(dst->type == LM_GGML_TYPE_BF16); - - LM_GGML_ASSERT( nb0 == sizeof(lm_ggml_bf16_t)); - LM_GGML_ASSERT(nb00 == sizeof(lm_ggml_bf16_t)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int ir = ir0; ir < ir1; ++ir) { - // src0 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - lm_ggml_bf16_t * dst_ptr = (lm_ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); - lm_ggml_bf16_t * src0_ptr = (lm_ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - for (int i = 0; i < ne0; i++) { - dst_ptr[i] = LM_GGML_FP32_TO_BF16(LM_GGML_BF16_TO_FP32(src0_ptr[i]) + v); - } - } -} - -static void lm_ggml_compute_forward_add1_bf16_bf16( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); - LM_GGML_ASSERT(lm_ggml_is_scalar(src1)); - - // scalar to add - const float v = LM_GGML_BF16_TO_FP32(*(lm_ggml_bf16_t *) src1->data); - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = lm_ggml_nrows(src0); - - LM_GGML_TENSOR_UNARY_OP_LOCALS - - LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_BF16); - LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_BF16); - LM_GGML_ASSERT(dst->type == LM_GGML_TYPE_BF16); - - LM_GGML_ASSERT( nb0 == sizeof(lm_ggml_bf16_t)); - LM_GGML_ASSERT(nb00 == sizeof(lm_ggml_bf16_t)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int ir = ir0; ir < ir1; ++ir) { - // src0 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - lm_ggml_bf16_t * dst_ptr = (lm_ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); - lm_ggml_bf16_t * src0_ptr = (lm_ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - for (int i = 0; i < ne0; i++) { - dst_ptr[i] = LM_GGML_FP32_TO_BF16(LM_GGML_BF16_TO_FP32(src0_ptr[i]) + v); - } - } -} - -static void lm_ggml_compute_forward_add1( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_add1_f32(params, dst); - } break; - case LM_GGML_TYPE_F16: - { - if (src1->type == LM_GGML_TYPE_F16) { - lm_ggml_compute_forward_add1_f16_f16(params, dst); - } - else if (src1->type == LM_GGML_TYPE_F32) { - lm_ggml_compute_forward_add1_f16_f32(params, dst); - } - else { - LM_GGML_ABORT("fatal error"); - } - } break; - case LM_GGML_TYPE_BF16: - { - if (src1->type == LM_GGML_TYPE_BF16) { - lm_ggml_compute_forward_add1_bf16_bf16(params, dst); - } - else if (src1->type == LM_GGML_TYPE_F32) { - lm_ggml_compute_forward_add1_bf16_f32(params, dst); - } - else { - LM_GGML_ABORT("fatal error"); - } - } break; - case LM_GGML_TYPE_Q4_0: - case LM_GGML_TYPE_Q4_1: - case LM_GGML_TYPE_Q5_0: - case LM_GGML_TYPE_Q5_1: - case LM_GGML_TYPE_Q8_0: - case LM_GGML_TYPE_Q8_1: - case LM_GGML_TYPE_Q2_K: - case LM_GGML_TYPE_Q3_K: - case LM_GGML_TYPE_Q4_K: - case LM_GGML_TYPE_Q5_K: - case LM_GGML_TYPE_Q6_K: - case LM_GGML_TYPE_TQ1_0: - case LM_GGML_TYPE_TQ2_0: - case LM_GGML_TYPE_IQ2_XXS: - case LM_GGML_TYPE_IQ2_XS: - case LM_GGML_TYPE_IQ3_XXS: - case LM_GGML_TYPE_IQ1_S: - case LM_GGML_TYPE_IQ1_M: - case LM_GGML_TYPE_IQ4_NL: - case LM_GGML_TYPE_IQ4_XS: - case LM_GGML_TYPE_IQ3_S: - case LM_GGML_TYPE_IQ2_S: - case LM_GGML_TYPE_Q4_0_4_4: - case LM_GGML_TYPE_Q4_0_4_8: - case LM_GGML_TYPE_Q4_0_8_8: - { - lm_ggml_compute_forward_add1_q_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_acc - -static void lm_ggml_compute_forward_acc_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); - LM_GGML_ASSERT(lm_ggml_is_contiguous(dst) && lm_ggml_is_contiguous(src0)); - - // view src0 and dst with these strides and data offset inbytes during acc - // nb0 is implicitly element_size because src0 and dst are contiguous - size_t nb1 = ((int32_t *) dst->op_params)[0]; - size_t nb2 = ((int32_t *) dst->op_params)[1]; - size_t nb3 = ((int32_t *) dst->op_params)[2]; - size_t offset = ((int32_t *) dst->op_params)[3]; - bool inplace = (bool) ((int32_t *) dst->op_params)[4]; - - if (!inplace) { - if (params->ith == 0) { - // memcpy needs to be synchronized across threads to avoid race conditions. - // => do it in INIT phase - memcpy( - ((char *) dst->data), - ((char *) src0->data), - lm_ggml_nbytes(dst)); - } - lm_ggml_barrier(params->threadpool); - } - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = lm_ggml_nrows(src1); - const int nc = src1->ne[0]; - - LM_GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) - LM_GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) - - // src0 and dst as viewed during acc - const size_t nb0 = lm_ggml_element_size(src0); - - const size_t nb00 = nb0; - const size_t nb01 = nb1; - const size_t nb02 = nb2; - const size_t nb03 = nb3; - - LM_GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb0 + (ne11 == 0 ? 0 : ne11-1)*nb1 + (ne12 == 0 ? 0 : ne12-1)*nb2 + (ne13 == 0 ? 0 : ne13-1)*nb3 < lm_ggml_nbytes(dst)); - LM_GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb00 + (ne11 == 0 ? 0 : ne11-1)*nb01 + (ne12 == 0 ? 0 : ne12-1)*nb02 + (ne13 == 0 ? 0 : ne13-1)*nb03 < lm_ggml_nbytes(src0)); - - LM_GGML_ASSERT(nb10 == sizeof(float)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int ir = ir0; ir < ir1; ++ir) { - // src0 and dst are viewed with shape of src1 and offset - // => same indices - const int i3 = ir/(ne12*ne11); - const int i2 = (ir - i3*ne12*ne11)/ne11; - const int i1 = (ir - i3*ne12*ne11 - i2*ne11); - -#ifdef LM_GGML_USE_ACCELERATE - vDSP_vadd( - (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset), 1, - (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1, - (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset), 1, nc); -#else - lm_ggml_vec_add_f32(nc, - (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset), - (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset), - (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11)); -#endif - } -} - -static void lm_ggml_compute_forward_acc( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_acc_f32(params, dst); - } break; - case LM_GGML_TYPE_F16: - case LM_GGML_TYPE_BF16: - case LM_GGML_TYPE_Q4_0: - case LM_GGML_TYPE_Q4_1: - case LM_GGML_TYPE_Q5_0: - case LM_GGML_TYPE_Q5_1: - case LM_GGML_TYPE_Q8_0: - case LM_GGML_TYPE_Q8_1: - case LM_GGML_TYPE_Q2_K: - case LM_GGML_TYPE_Q3_K: - case LM_GGML_TYPE_Q4_K: - case LM_GGML_TYPE_Q5_K: - case LM_GGML_TYPE_Q6_K: - case LM_GGML_TYPE_TQ1_0: - case LM_GGML_TYPE_TQ2_0: - case LM_GGML_TYPE_IQ2_XXS: - case LM_GGML_TYPE_IQ2_XS: - case LM_GGML_TYPE_IQ3_XXS: - case LM_GGML_TYPE_IQ1_S: - case LM_GGML_TYPE_IQ1_M: - case LM_GGML_TYPE_IQ4_NL: - case LM_GGML_TYPE_IQ4_XS: - case LM_GGML_TYPE_IQ3_S: - case LM_GGML_TYPE_IQ2_S: - case LM_GGML_TYPE_Q4_0_4_4: - case LM_GGML_TYPE_Q4_0_4_8: - case LM_GGML_TYPE_Q4_0_8_8: - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_sub - -static void lm_ggml_compute_forward_sub_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - assert(lm_ggml_can_repeat(src1, src0) && lm_ggml_are_same_shape(src0, dst)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = lm_ggml_nrows(src0); - - LM_GGML_TENSOR_BINARY_OP_LOCALS - - LM_GGML_ASSERT( nb0 == sizeof(float)); - LM_GGML_ASSERT(nb00 == sizeof(float)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - if (nb10 == sizeof(float)) { - for (int ir = ir0; ir < ir1; ++ir) { - // src1 is broadcastable across src0 and dst in i1, i2, i3 - const int64_t i03 = ir/(ne02*ne01); - const int64_t i02 = (ir - i03*ne02*ne01)/ne01; - const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - const int64_t nr0 = ne00 / ne10; - - float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); - float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); - - for (int64_t r = 0; r < nr0; ++r) { -#ifdef LM_GGML_USE_ACCELERATE - vDSP_vsub(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10); -#else - lm_ggml_vec_sub_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); -#endif - } - } - } else { - // src1 is not contiguous - for (int ir = ir0; ir < ir1; ++ir) { - // src1 is broadcastable across src0 and dst in i1, i2, i3 - const int64_t i03 = ir/(ne02*ne01); - const int64_t i02 = (ir - i03*ne02*ne01)/ne01; - const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); - - for (int64_t i0 = 0; i0 < ne0; ++i0) { - const int64_t i10 = i0 % ne10; - float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10); - - dst_ptr[i0] = src0_ptr[i0] - *src1_ptr; - } - } - } -} - -static void lm_ggml_compute_forward_sub( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_sub_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_mul - -static void lm_ggml_compute_forward_mul_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - LM_GGML_ASSERT(lm_ggml_can_repeat(src1, src0) && lm_ggml_are_same_shape(src0, dst)); - - const int ith = params->ith; - const int nth = params->nth; - - const int64_t nr = lm_ggml_nrows(src0); - - LM_GGML_TENSOR_BINARY_OP_LOCALS - - LM_GGML_ASSERT( nb0 == sizeof(float)); - LM_GGML_ASSERT(nb00 == sizeof(float)); - - if (nb10 == sizeof(float)) { - for (int64_t ir = ith; ir < nr; ir += nth) { - // src0 and dst are same shape => same indices - const int64_t i03 = ir/(ne02*ne01); - const int64_t i02 = (ir - i03*ne02*ne01)/ne01; - const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - const int64_t nr0 = ne00 / ne10; - - float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); - float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); - - for (int64_t r = 0 ; r < nr0; ++r) { -#ifdef LM_GGML_USE_ACCELERATE - UNUSED(lm_ggml_vec_mul_f32); - - vDSP_vmul(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10); -#else - lm_ggml_vec_mul_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); -#endif - } - } - } else { - // src1 is not contiguous - for (int64_t ir = ith; ir < nr; ir += nth) { - // src0 and dst are same shape => same indices - // src1 is broadcastable across src0 and dst in i1, i2, i3 - const int64_t i03 = ir/(ne02*ne01); - const int64_t i02 = (ir - i03*ne02*ne01)/ne01; - const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); - - for (int64_t i0 = 0; i0 < ne00; ++i0) { - const int64_t i10 = i0 % ne10; - float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10); - - dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr); - } - } - } -} - -static void lm_ggml_compute_forward_mul( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32 && "only f32 src1 supported for now"); - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_mul_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_div - -static void lm_ggml_compute_forward_div_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - LM_GGML_ASSERT(lm_ggml_can_repeat(src1, src0) && lm_ggml_are_same_shape(src0, dst)); - - const int ith = params->ith; - const int nth = params->nth; - - const int64_t nr = lm_ggml_nrows(src0); - - LM_GGML_TENSOR_BINARY_OP_LOCALS - - LM_GGML_ASSERT( nb0 == sizeof(float)); - LM_GGML_ASSERT(nb00 == sizeof(float)); - - if (nb10 == sizeof(float)) { - for (int64_t ir = ith; ir < nr; ir += nth) { - // src0 and dst are same shape => same indices - const int64_t i03 = ir/(ne02*ne01); - const int64_t i02 = (ir - i03*ne02*ne01)/ne01; - const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - const int64_t nr0 = ne00 / ne10; - - float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); - float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); - - for (int64_t r = 0; r < nr0; ++r) { -#ifdef LM_GGML_USE_ACCELERATE - UNUSED(lm_ggml_vec_div_f32); - - vDSP_vdiv(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10); -#else - lm_ggml_vec_div_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); -#endif - } - } - } else { - // src1 is not contiguous - for (int64_t ir = ith; ir < nr; ir += nth) { - // src0 and dst are same shape => same indices - // src1 is broadcastable across src0 and dst in i1, i2, i3 - const int64_t i03 = ir/(ne02*ne01); - const int64_t i02 = (ir - i03*ne02*ne01)/ne01; - const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); - - for (int64_t i0 = 0; i0 < ne00; ++i0) { - const int64_t i10 = i0 % ne10; - float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10); - - dst_ptr[i0] = src0_ptr[i0] / (*src1_ptr); - } - } - } -} - -static void lm_ggml_compute_forward_div( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_div_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_sqr - -static void lm_ggml_compute_forward_sqr_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(lm_ggml_are_same_shape(src0, dst)); - - const int n = lm_ggml_nrows(src0); - const int nc = src0->ne[0]; - - assert( dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - lm_ggml_vec_sqr_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void lm_ggml_compute_forward_sqr( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_sqr_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_sqrt - -static void lm_ggml_compute_forward_sqrt_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(lm_ggml_are_same_shape(src0, dst)); - - const int n = lm_ggml_nrows(src0); - const int nc = src0->ne[0]; - - assert( dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - lm_ggml_vec_sqrt_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void lm_ggml_compute_forward_sqrt( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_sqrt_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_log - -static void lm_ggml_compute_forward_log_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); - - const int n = lm_ggml_nrows(src0); - const int nc = src0->ne[0]; - - LM_GGML_ASSERT( dst->nb[0] == sizeof(float)); - LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - lm_ggml_vec_log_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void lm_ggml_compute_forward_log( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_log_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_sin - -static void lm_ggml_compute_forward_sin_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); - - const int n = lm_ggml_nrows(src0); - const int nc = src0->ne[0]; - - LM_GGML_ASSERT( dst->nb[0] == sizeof(float)); - LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - lm_ggml_vec_sin_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void lm_ggml_compute_forward_sin( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_sin_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_cos - -static void lm_ggml_compute_forward_cos_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); - - const int n = lm_ggml_nrows(src0); - const int nc = src0->ne[0]; - - LM_GGML_ASSERT( dst->nb[0] == sizeof(float)); - LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - lm_ggml_vec_cos_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void lm_ggml_compute_forward_cos( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_cos_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_sum - -static void lm_ggml_compute_forward_sum_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(lm_ggml_is_scalar(dst)); - assert(src0->nb[0] == sizeof(float)); - - LM_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) - LM_GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) - - lm_ggml_float sum = 0; - lm_ggml_float row_sum = 0; - - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - lm_ggml_vec_sum_f32_ggf(ne00, - &row_sum, - (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03)); - sum += row_sum; - } - } - } - ((float *) dst->data)[0] = sum; -} - -static void lm_ggml_compute_forward_sum_f16( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(lm_ggml_is_scalar(dst)); - - assert(src0->nb[0] == sizeof(lm_ggml_fp16_t)); - - LM_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) - LM_GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) - - float sum = 0; - float row_sum = 0; - - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - lm_ggml_vec_sum_f16_ggf(ne00, - &row_sum, - (lm_ggml_fp16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03)); - sum += row_sum; - } - } - } - ((lm_ggml_fp16_t *) dst->data)[0] = LM_GGML_FP32_TO_FP16(sum); -} - -static void lm_ggml_compute_forward_sum_bf16( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(lm_ggml_is_scalar(dst)); - - assert(src0->nb[0] == sizeof(lm_ggml_bf16_t)); - - LM_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) - LM_GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) - - float sum = 0; - float row_sum = 0; - - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - lm_ggml_vec_sum_bf16_ggf(ne00, - &row_sum, - (lm_ggml_bf16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03)); - sum += row_sum; - } - } - } - ((lm_ggml_bf16_t *) dst->data)[0] = LM_GGML_FP32_TO_BF16(sum); -} - -static void lm_ggml_compute_forward_sum( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_sum_f32(params, dst); - } break; - case LM_GGML_TYPE_F16: - { - lm_ggml_compute_forward_sum_f16(params, dst); - } break; - case LM_GGML_TYPE_BF16: - { - lm_ggml_compute_forward_sum_bf16(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_sum_rows - -static void lm_ggml_compute_forward_sum_rows_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); - LM_GGML_ASSERT(dst->nb[0] == sizeof(float)); - - LM_GGML_TENSOR_UNARY_OP_LOCALS - - LM_GGML_ASSERT(ne0 == 1); - LM_GGML_ASSERT(ne1 == ne01); - LM_GGML_ASSERT(ne2 == ne02); - LM_GGML_ASSERT(ne3 == ne03); - - for (int64_t i3 = 0; i3 < ne03; i3++) { - for (int64_t i2 = 0; i2 < ne02; i2++) { - for (int64_t i1 = 0; i1 < ne01; i1++) { - float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03); - float * dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3); - float row_sum = 0; - lm_ggml_vec_sum_f32(ne00, &row_sum, src_row); - dst_row[0] = row_sum; - } - } - } -} - -static void lm_ggml_compute_forward_sum_rows( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_sum_rows_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_mean - -static void lm_ggml_compute_forward_mean_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(src0->nb[0] == sizeof(float)); - - LM_GGML_TENSOR_UNARY_OP_LOCALS - - assert(ne0 == 1); - assert(ne1 == ne01); - assert(ne2 == ne02); - assert(ne3 == ne03); - - UNUSED(ne0); - UNUSED(ne1); - UNUSED(ne2); - UNUSED(ne3); - - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - lm_ggml_vec_sum_f32(ne00, - (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), - (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03)); - - *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) /= (float) ne00; - } - } - } -} - -static void lm_ggml_compute_forward_mean( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_mean_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_argmax - -static void lm_ggml_compute_forward_argmax_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(src0->nb[0] == sizeof(float)); - assert(dst->nb[0] == sizeof(float)); - - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - - const size_t nb01 = src0->nb[1]; - const size_t nb0 = dst->nb[0]; - - for (int64_t i1 = 0; i1 < ne01; i1++) { - float * src = (float *) ((char *) src0->data + i1*nb01); - int32_t * dst_ = (int32_t *) ((char *) dst->data + i1*nb0); - int v = 0; - lm_ggml_vec_argmax_f32(ne00, &v, src); - dst_[0] = v; - } -} - -static void lm_ggml_compute_forward_argmax( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_argmax_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_count_equal - -static void lm_ggml_compute_forward_count_equal_i32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - LM_GGML_TENSOR_BINARY_OP_LOCALS; - - LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_I32); - LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_I32); - LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, src1)); - LM_GGML_ASSERT(lm_ggml_is_scalar(dst)); - LM_GGML_ASSERT(dst->type == LM_GGML_TYPE_I64); - - const int64_t nr = lm_ggml_nrows(src0); - - const int ith = params->ith; - const int nth = params->nth; - - int64_t * sums = (int64_t *) params->wdata; - int64_t sum_thread = 0; - - // rows per thread - const int64_t dr = (nr + nth - 1)/nth; - - // row range for this thread - const int64_t ir0 = dr*ith; - const int64_t ir1 = MIN(ir0 + dr, nr); - - for (int64_t ir = ir0; ir < ir1; ++ir) { - const int64_t i03 = ir / (ne02*ne01); - const int64_t i02 = (ir - i03*ne03) / ne01; - const int64_t i01 = ir - i03*ne03 - i02*ne02; - - const char * data0 = (const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01; - const char * data1 = (const char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11; - - for (int64_t i00 = 0; i00 < ne00; ++i00) { - const int32_t val0 = *((const int32_t *) (data0 + i00*nb00)); - const int32_t val1 = *((const int32_t *) (data1 + i00*nb10)); - - sum_thread += val0 == val1; - } - } - if (ith != 0) { - sums[ith] = sum_thread; - } - lm_ggml_barrier(params->threadpool); - - if (ith != 0) { - return; - } - - for (int ith_other = 1; ith_other < nth; ++ith_other) { - sum_thread += sums[ith_other]; - } - *((int64_t *) dst->data) = sum_thread; -} - -static void lm_ggml_compute_forward_count_equal( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_I32: - { - lm_ggml_compute_forward_count_equal_i32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_repeat - -static void lm_ggml_compute_forward_repeat_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - LM_GGML_ASSERT(lm_ggml_can_repeat(src0, dst)); - - LM_GGML_TENSOR_UNARY_OP_LOCALS - - // guaranteed to be an integer due to the check in lm_ggml_can_repeat - const int nr0 = (int)(ne0/ne00); - const int nr1 = (int)(ne1/ne01); - const int nr2 = (int)(ne2/ne02); - const int nr3 = (int)(ne3/ne03); - - // TODO: support for transposed / permuted tensors - LM_GGML_ASSERT(nb0 == sizeof(float)); - LM_GGML_ASSERT(nb00 == sizeof(float)); - - // TODO: maybe this is not optimal? - for (int i3 = 0; i3 < nr3; i3++) { - for (int k3 = 0; k3 < ne03; k3++) { - for (int i2 = 0; i2 < nr2; i2++) { - for (int k2 = 0; k2 < ne02; k2++) { - for (int i1 = 0; i1 < nr1; i1++) { - for (int k1 = 0; k1 < ne01; k1++) { - for (int i0 = 0; i0 < nr0; i0++) { - lm_ggml_vec_cpy_f32(ne00, - (float *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0), - (float *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01)); - } - } - } - } - } - } - } -} - -static void lm_ggml_compute_forward_repeat_f16( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - LM_GGML_ASSERT(lm_ggml_can_repeat(src0, dst)); - - LM_GGML_TENSOR_UNARY_OP_LOCALS - - // guaranteed to be an integer due to the check in lm_ggml_can_repeat - const int nr0 = (int)(ne0/ne00); - const int nr1 = (int)(ne1/ne01); - const int nr2 = (int)(ne2/ne02); - const int nr3 = (int)(ne3/ne03); - - // TODO: support for transposed / permuted tensors - LM_GGML_ASSERT(nb0 == sizeof(lm_ggml_fp16_t)); - LM_GGML_ASSERT(nb00 == sizeof(lm_ggml_fp16_t)); - - // TODO: maybe this is not optimal? - for (int i3 = 0; i3 < nr3; i3++) { - for (int k3 = 0; k3 < ne03; k3++) { - for (int i2 = 0; i2 < nr2; i2++) { - for (int k2 = 0; k2 < ne02; k2++) { - for (int i1 = 0; i1 < nr1; i1++) { - for (int k1 = 0; k1 < ne01; k1++) { - for (int i0 = 0; i0 < nr0; i0++) { - lm_ggml_fp16_t * y = (lm_ggml_fp16_t *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0); - lm_ggml_fp16_t * x = (lm_ggml_fp16_t *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01); - // lm_ggml_vec_cpy_f16(ne00, y, x) - for (int i = 0; i < ne00; ++i) { - y[i] = x[i]; - } - } - } - } - } - } - } - } -} - -static void lm_ggml_compute_forward_repeat( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F16: - case LM_GGML_TYPE_BF16: - case LM_GGML_TYPE_I16: - { - lm_ggml_compute_forward_repeat_f16(params, dst); - } break; - case LM_GGML_TYPE_F32: - case LM_GGML_TYPE_I32: - { - lm_ggml_compute_forward_repeat_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_repeat_back - -static void lm_ggml_compute_forward_repeat_back_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - LM_GGML_ASSERT(lm_ggml_can_repeat(dst, src0)); - - LM_GGML_TENSOR_UNARY_OP_LOCALS - - // guaranteed to be an integer due to the check in lm_ggml_can_repeat - const int nr0 = (int)(ne00/ne0); - const int nr1 = (int)(ne01/ne1); - const int nr2 = (int)(ne02/ne2); - const int nr3 = (int)(ne03/ne3); - - // TODO: support for transposed / permuted tensors - LM_GGML_ASSERT(nb0 == sizeof(float)); - LM_GGML_ASSERT(nb00 == sizeof(float)); - - if (lm_ggml_is_contiguous(dst)) { - lm_ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); - } else { - for (int k3 = 0; k3 < ne3; k3++) { - for (int k2 = 0; k2 < ne2; k2++) { - for (int k1 = 0; k1 < ne1; k1++) { - lm_ggml_vec_set_f32(ne0, - (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3), - 0); - } - } - } - } - - // TODO: maybe this is not optimal? - for (int i3 = 0; i3 < nr3; i3++) { - for (int k3 = 0; k3 < ne3; k3++) { - for (int i2 = 0; i2 < nr2; i2++) { - for (int k2 = 0; k2 < ne2; k2++) { - for (int i1 = 0; i1 < nr1; i1++) { - for (int k1 = 0; k1 < ne1; k1++) { - for (int i0 = 0; i0 < nr0; i0++) { - lm_ggml_vec_acc_f32(ne0, - (float *) ((char *) dst->data + ( k3)*nb3 + ( k2)*nb2 + ( k1)*nb1), - (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00)); - } - } - } - } - } - } - } -} - -static void lm_ggml_compute_forward_repeat_back( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_repeat_back_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_concat - -static void lm_ggml_compute_forward_concat_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - LM_GGML_TENSOR_BINARY_OP_LOCALS - - const int32_t dim = lm_ggml_get_op_params_i32(dst, 0); - - LM_GGML_ASSERT(dim >= 0 && dim < 4); - - int64_t o[4] = {0, 0, 0, 0}; - o[dim] = src0->ne[dim]; - - const float * x; - - // TODO: smarter multi-theading - for (int i3 = 0; i3 < ne3; i3++) { - for (int i2 = ith; i2 < ne2; i2 += nth) { - for (int i1 = 0; i1 < ne1; i1++) { - for (int i0 = 0; i0 < ne0; i0++) { - if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { - x = (const float *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03); - } else { - x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13); - } - - float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3); - - *y = *x; - } - } - } - } -} - -static void lm_ggml_compute_forward_concat( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - case LM_GGML_TYPE_I32: - { - lm_ggml_compute_forward_concat_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_abs - -static void lm_ggml_compute_forward_abs_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(lm_ggml_is_contiguous_1(src0)); - assert(lm_ggml_is_contiguous_1(dst)); - assert(lm_ggml_are_same_shape(src0, dst)); - - const int n = lm_ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - lm_ggml_vec_abs_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void lm_ggml_compute_forward_abs( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_abs_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_sgn - -static void lm_ggml_compute_forward_sgn_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(lm_ggml_is_contiguous_1(src0)); - assert(lm_ggml_is_contiguous_1(dst)); - assert(lm_ggml_are_same_shape(src0, dst)); - - const int n = lm_ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - lm_ggml_vec_sgn_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void lm_ggml_compute_forward_sgn( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_sgn_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_neg - -static void lm_ggml_compute_forward_neg_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(lm_ggml_is_contiguous_1(src0)); - assert(lm_ggml_is_contiguous_1(dst)); - assert(lm_ggml_are_same_shape(src0, dst)); - - const int n = lm_ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - lm_ggml_vec_neg_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void lm_ggml_compute_forward_neg( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_neg_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_step - -static void lm_ggml_compute_forward_step_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(lm_ggml_is_contiguous_1(src0)); - assert(lm_ggml_is_contiguous_1(dst)); - assert(lm_ggml_are_same_shape(src0, dst)); - - const int n = lm_ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - lm_ggml_vec_step_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void lm_ggml_compute_forward_step( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_step_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_tanh - -static void lm_ggml_compute_forward_tanh_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(lm_ggml_is_contiguous_1(src0)); - assert(lm_ggml_is_contiguous_1(dst)); - assert(lm_ggml_are_same_shape(src0, dst)); - - const int n = lm_ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - lm_ggml_vec_tanh_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void lm_ggml_compute_forward_tanh( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_tanh_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_elu - -static void lm_ggml_compute_forward_elu_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(lm_ggml_is_contiguous_1(src0)); - assert(lm_ggml_is_contiguous_1(dst)); - assert(lm_ggml_are_same_shape(src0, dst)); - - const int n = lm_ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - lm_ggml_vec_elu_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void lm_ggml_compute_forward_elu( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_elu_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_relu - -static void lm_ggml_compute_forward_relu_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(lm_ggml_is_contiguous_1(src0)); - assert(lm_ggml_is_contiguous_1(dst)); - assert(lm_ggml_are_same_shape(src0, dst)); - - const int n = lm_ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - lm_ggml_vec_relu_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void lm_ggml_compute_forward_relu( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_relu_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_sigmoid - -static void lm_ggml_compute_forward_sigmoid_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(lm_ggml_is_contiguous_1(src0)); - assert(lm_ggml_is_contiguous_1(dst)); - assert(lm_ggml_are_same_shape(src0, dst)); - - const int n = lm_ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - lm_ggml_vec_sigmoid_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void lm_ggml_compute_forward_sigmoid( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_sigmoid_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_gelu - -static void lm_ggml_compute_forward_gelu_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - assert(lm_ggml_is_contiguous_1(src0)); - assert(lm_ggml_is_contiguous_1(dst)); - assert(lm_ggml_are_same_shape(src0, dst)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nc = src0->ne[0]; - const int nr = lm_ggml_nrows(src0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int i1 = ir0; i1 < ir1; i1++) { - lm_ggml_vec_gelu_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); - -#ifndef NDEBUG - for (int k = 0; k < nc; k++) { - const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; - UNUSED(x); - assert(!isnan(x)); - assert(!isinf(x)); - } -#endif - } -} - -static void lm_ggml_compute_forward_gelu( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_gelu_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_gelu_quick - -static void lm_ggml_compute_forward_gelu_quick_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - assert(lm_ggml_is_contiguous_1(src0)); - assert(lm_ggml_is_contiguous_1(dst)); - assert(lm_ggml_are_same_shape(src0, dst)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nc = src0->ne[0]; - const int nr = lm_ggml_nrows(src0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int i1 = ir0; i1 < ir1; i1++) { - lm_ggml_vec_gelu_quick_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); - -#ifndef NDEBUG - for (int k = 0; k < nc; k++) { - const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; - UNUSED(x); - assert(!isnan(x)); - assert(!isinf(x)); - } -#endif - } -} - -static void lm_ggml_compute_forward_gelu_quick( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_gelu_quick_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_silu - -static void lm_ggml_compute_forward_silu_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - assert(lm_ggml_is_contiguous_1(src0)); - assert(lm_ggml_is_contiguous_1(dst)); - assert(lm_ggml_are_same_shape(src0, dst)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nc = src0->ne[0]; - const int nr = lm_ggml_nrows(src0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int i1 = ir0; i1 < ir1; i1++) { - lm_ggml_vec_silu_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); - -#ifndef NDEBUG - for (int k = 0; k < nc; k++) { - const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k]; - UNUSED(x); - assert(!isnan(x)); - assert(!isinf(x)); - } -#endif - } -} - -static void lm_ggml_compute_forward_silu( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_silu_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} -// lm_ggml_compute_forward_leaky_relu - -static void lm_ggml_compute_forward_leaky_relu_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(lm_ggml_is_contiguous_1(src0)); - assert(lm_ggml_is_contiguous_1(dst)); - assert(lm_ggml_are_same_shape(src0, dst)); - - const int n = lm_ggml_nrows(src0); - const int nc = src0->ne[0]; - - float negative_slope; - memcpy(&negative_slope, dst->op_params, sizeof(float)); - - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - lm_ggml_vec_leaky_relu_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1])), negative_slope); - } -} - -static void lm_ggml_compute_forward_leaky_relu( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_leaky_relu_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_silu_back - -static void lm_ggml_compute_forward_silu_back_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * grad = dst->src[1]; - - assert(lm_ggml_is_contiguous_1(grad)); - assert(lm_ggml_is_contiguous_1(src0)); - assert(lm_ggml_is_contiguous_1(dst)); - assert(lm_ggml_are_same_shape(src0, dst)); - assert(lm_ggml_are_same_shape(src0, grad)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nc = src0->ne[0]; - const int nr = lm_ggml_nrows(src0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int i1 = ir0; i1 < ir1; i1++) { - lm_ggml_vec_silu_backward_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1])), - (float *) ((char *) grad->data + i1*(grad->nb[1]))); - -#ifndef NDEBUG - for (int k = 0; k < nc; k++) { - const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; - UNUSED(x); - assert(!isnan(x)); - assert(!isinf(x)); - } -#endif - } -} - -static void lm_ggml_compute_forward_silu_back( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_silu_back_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - - -static void lm_ggml_compute_forward_hardswish_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(lm_ggml_is_contiguous_1(src0)); - assert(lm_ggml_is_contiguous_1(dst)); - assert(lm_ggml_are_same_shape(src0, dst)); - - const int n = lm_ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - lm_ggml_vec_hardswish_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} -static void lm_ggml_compute_forward_hardswish( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_hardswish_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -static void lm_ggml_compute_forward_hardsigmoid_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(lm_ggml_is_contiguous_1(src0)); - assert(lm_ggml_is_contiguous_1(dst)); - assert(lm_ggml_are_same_shape(src0, dst)); - - const int n = lm_ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - lm_ggml_vec_hardsigmoid_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void lm_ggml_compute_forward_hardsigmoid( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_hardsigmoid_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -static void lm_ggml_compute_forward_exp_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(lm_ggml_is_contiguous_1(src0)); - assert(lm_ggml_is_contiguous_1(dst)); - assert(lm_ggml_are_same_shape(src0, dst)); - - const int n = lm_ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - lm_ggml_vec_exp_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void lm_ggml_compute_forward_exp( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_exp_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - - -// lm_ggml_compute_forward_norm - -static void lm_ggml_compute_forward_norm_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); - - LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - LM_GGML_TENSOR_UNARY_OP_LOCALS - - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); - - LM_GGML_ASSERT(eps > 0.0f); - - // TODO: optimize - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = ith; i01 < ne01; i01 += nth) { - const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - - lm_ggml_float sum = 0.0; - for (int64_t i00 = 0; i00 < ne00; i00++) { - sum += (lm_ggml_float)x[i00]; - } - - float mean = sum/ne00; - - float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); - - lm_ggml_float sum2 = 0.0; - for (int64_t i00 = 0; i00 < ne00; i00++) { - float v = x[i00] - mean; - y[i00] = v; - sum2 += (lm_ggml_float)(v*v); - } - - float variance = sum2/ne00; - const float scale = 1.0f/sqrtf(variance + eps); - - lm_ggml_vec_scale_f32(ne00, y, scale); - } - } - } -} - -static void lm_ggml_compute_forward_norm( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_norm_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_group_rms_norm - -static void lm_ggml_compute_forward_rms_norm_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); - - LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - LM_GGML_TENSOR_UNARY_OP_LOCALS - - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); - - LM_GGML_ASSERT(eps > 0.0f); - - // TODO: optimize - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = ith; i01 < ne01; i01 += nth) { - const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - - lm_ggml_float sum = 0.0; - for (int64_t i00 = 0; i00 < ne00; i00++) { - sum += (lm_ggml_float)(x[i00] * x[i00]); - } - - const float mean = sum/ne00; - - float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); - - memcpy(y, x, ne00 * sizeof(float)); - // for (int i00 = 0; i00 < ne00; i00++) { - // y[i00] = x[i00]; - // } - - const float scale = 1.0f/sqrtf(mean + eps); - - lm_ggml_vec_scale_f32(ne00, y, scale); - } - } - } -} - -static void lm_ggml_compute_forward_rms_norm( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_rms_norm_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -static void lm_ggml_compute_forward_rms_norm_back_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst) && lm_ggml_are_same_shape(src0, src1)); - - LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - LM_GGML_TENSOR_BINARY_OP_LOCALS - - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); - - // TODO: optimize - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = ith; i01 < ne01; i01 += nth) { - // src1 is same shape as src0 => same indices - const int64_t i11 = i01; - const int64_t i12 = i02; - const int64_t i13 = i03; - - const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - const float * dz = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13); - - lm_ggml_float sum_xx = 0.0; - lm_ggml_float sum_xdz = 0.0; - - for (int64_t i00 = 0; i00 < ne00; i00++) { - sum_xx += (lm_ggml_float)(x[i00] * x[i00]); - sum_xdz += (lm_ggml_float)(x[i00] * dz[i00]); - } - - //const float mean = (float)(sum_xx)/ne00; - const float mean_eps = (float)(sum_xx)/ne00 + eps; - const float sum_eps = (float)(sum_xx) + eps*ne00; - //const float mean_xdz = (float)(sum_xdz)/ne00; - // we could cache rms from forward pass to improve performance. - // to do this implement lm_ggml_rms and compose lm_ggml_rms_norm using lm_ggml_rms. - //const float rms = sqrtf(mean_eps); - const float rrms = 1.0f / sqrtf(mean_eps); - //const float scale = -rrms/(ne00 * mean_eps); // -1/(n*rms**3) - - { - // z = rms_norm(x) - // - // rms_norm(src0) = - // scale( - // src0, - // div( - // 1, - // sqrt( - // add( - // scale( - // sum( - // sqr( - // src0)), - // (1.0/N)), - // eps)))); - - // postorder: - // ## op args grad - // 00 param src0 grad[#00] - // 01 const 1 - // 02 sqr (#00) grad[#02] - // 03 sum (#02) grad[#03] - // 04 const 1/N - // 05 scale (#03, #04) grad[#05] - // 06 const eps - // 07 add (#05, #06) grad[#07] - // 08 sqrt (#07) grad[#08] - // 09 div (#01,#08) grad[#09] - // 10 scale (#00,#09) grad[#10] - // - // backward pass, given grad[#10] - // #10: scale - // grad[#00] += scale(grad[#10],#09) - // grad[#09] += sum(mul(grad[#10],#00)) - // #09: div - // grad[#08] += neg(mul(grad[#09], div(#09,#08))) - // #08: sqrt - // grad[#07] += mul(grad[#08], div(0.5, #08)) - // #07: add - // grad[#05] += grad[#07] - // #05: scale - // grad[#03] += scale(grad[#05],#04) - // #03: sum - // grad[#02] += repeat(grad[#03], #02) - // #02: - // grad[#00] += scale(mul(#00, grad[#02]), 2.0) - // - // substitute and simplify: - // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0) - // grad[#02] = repeat(grad[#03], #02) - // grad[#02] = repeat(scale(grad[#05],#04), #02) - // grad[#02] = repeat(scale(grad[#07],#04), #02) - // grad[#02] = repeat(scale(mul(grad[#08], div(0.5, #08)),#04), #02) - // grad[#02] = repeat(scale(mul(neg(mul(grad[#09], div(#09,#08))), div(0.5, #08)),#04), #02) - // grad[#02] = repeat(scale(mul(neg(mul(sum(mul(grad[#10],#00)), div(#09,#08))), div(0.5, #08)),#04), #02) - // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(#09,#08) * div(0.5, #08) * (1/N)), #02) - // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(div(#01,#08),#08) * div(0.5, #08) * (1/N)), #02) - // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#08*#08) * div(0.5, #08) * (1/N)), #02) - // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02) - // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0) - // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)), 2.0) - // grad[#00] = scale(grad(#10), #09) + scale(scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N))), 2.0) - // grad[#00] = scale(grad(#10), #09) + scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(1,#08) * (1/N))) - // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N)) - // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N)) - // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,mean_eps*rms) * (-1/N)) - // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*mean_eps)) - // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*(sum_xx/N+eps))) - // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*sum_xx+rms*N*eps)) - // grad[#00] = scale(dz, rrms) + scale(x, sum(mul(dz,x)) * div(-1,rms*N*mean_eps)) - // grad[#00] = scale(dz, rrms) + scale(x, sum_xdz * div(-1,rms*N*mean_eps)) - // a = b*c + d*e - // a = b*c*f/f + d*e*f/f - // a = (b*c*f + d*e*f)*(1/f) - // a = (b*c*(1/c) + d*e*(1/c))*(1/(1/c)) - // a = (b + d*e/c)*c - // b = dz, c = rrms, d = x, e = sum_xdz * div(-1,rms*N*mean_eps) - // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)/rrms)*rrms - // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)*rms)*rrms - // a = (dz + x*sum_xdz * div(-rms,rms*N*mean_eps))*rrms - // a = (dz + x*sum_xdz * div(-1,N*mean_eps))*rrms - // a = (dz + x*div(-sum_xdz,N*mean_eps))*rrms - // a = (dz + x*div(-mean_xdz,mean_eps))*rrms - // grad[#00] = scale(dz + scale(x, div(-mean_xdz,mean_eps)),rrms) - // grad[#00] = scale(dz + scale(x, -mean_xdz/mean_eps),rrms) - // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms) - } - // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms) - // post-order: - // dx := x - // dx := scale(dx,-mean_xdz/mean_eps) - // dx := add(dx, dz) - // dx := scale(dx, rrms) - float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); - - lm_ggml_vec_cpy_f32 (ne00, dx, x); - // lm_ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps); - lm_ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps); - lm_ggml_vec_acc_f32 (ne00, dx, dz); - lm_ggml_vec_scale_f32(ne00, dx, rrms); - } - } - } -} - -static void lm_ggml_compute_forward_rms_norm_back( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_rms_norm_back_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_group_norm - -static void lm_ggml_compute_forward_group_norm_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); - - LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - LM_GGML_TENSOR_UNARY_OP_LOCALS - - // TODO: optimize - - float eps; - memcpy(&eps, dst->op_params + 1, sizeof(float)); - - int n_channels = src0->ne[2]; - int n_groups = dst->op_params[0]; - int n_channels_per_group = (n_channels + n_groups - 1) / n_groups; - for (int i = ith; i < n_groups; i += nth) { - int start = i * n_channels_per_group; - int end = start + n_channels_per_group; - if (end > n_channels) { - end = n_channels; - } - int step = end - start; - - for (int64_t i03 = 0; i03 < ne03; i03++) { - lm_ggml_float sum = 0.0; - for (int64_t i02 = start; i02 < end; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03); - - lm_ggml_float sumr = 0.0; - for (int64_t i00 = 0; i00 < ne00; i00++) { - sumr += (lm_ggml_float)x[i00]; - } - sum += sumr; - } - } - const float mean = sum / (ne00 * ne01 * step); - - lm_ggml_float sum2 = 0.0; - for (int64_t i02 = start; i02 < end; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03); - - float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3); - - lm_ggml_float sumr = 0.0; - for (int64_t i00 = 0; i00 < ne00; i00++) { - float v = x[i00] - mean; - y[i00] = v; - sumr += (lm_ggml_float)(v * v); - } - sum2 += sumr; - } - } - const float variance = sum2 / (ne00 * ne01 * step); - const float scale = 1.0f / sqrtf(variance + eps); - - for (int64_t i02 = start; i02 < end; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3); - lm_ggml_vec_scale_f32(ne00, y, scale); - } - } - } - } -} - -static void lm_ggml_compute_forward_group_norm( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_group_norm_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_mul_mat - -static void lm_ggml_compute_forward_mul_mat_one_chunk( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst, - const int64_t num_rows_per_vec_dot, - const int64_t ir0_start, - const int64_t ir0_end, - const int64_t ir1_start, - const int64_t ir1_end) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - LM_GGML_TENSOR_BINARY_OP_LOCALS - - const enum lm_ggml_type type = src0->type; - - const bool src1_cont = lm_ggml_is_contiguous(src1); - - lm_ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot; - enum lm_ggml_type const vec_dot_type = type_traits[type].vec_dot_type; - - // broadcast factors - const int64_t r2 = ne12 / ne02; - const int64_t r3 = ne13 / ne03; - - //printf("ir0_start = %6lld, ir0_end = %6lld, ir1_start = %6lld, ir1_end = %6lld\n", ir0_start, ir0_end, ir1_start, ir1_end); - - // threads with no work simply yield (not sure if it helps) - if (ir0_start >= ir0_end || ir1_start >= ir1_end) { - return; - } - - const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; - const size_t row_size = lm_ggml_row_size(vec_dot_type, ne10); - - assert(ne12 % ne02 == 0); - assert(ne13 % ne03 == 0); - - // block-tiling attempt - const int64_t blck_0 = 16; - const int64_t blck_1 = 16; - - const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11; - - // attempt to reduce false-sharing (does not seem to make a difference) - // 16 * 2, accounting for mmla kernels - float tmp[32]; - - for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) { - for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) { - for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) { - const int64_t i13 = (ir1 / (ne12 * ne1)); - const int64_t i12 = (ir1 - i13 * ne12 * ne1) / ne1; - const int64_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1); - - // broadcast src0 into src1 - const int64_t i03 = i13 / r3; - const int64_t i02 = i12 / r2; - - const int64_t i1 = i11; - const int64_t i2 = i12; - const int64_t i3 = i13; - - const char * src0_row = (const char*)src0->data + (0 + i02 * nb02 + i03 * nb03); - - // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides - // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using - // the original src1 data pointer, so we should index using the indices directly - // TODO: this is a bit of a hack, we should probably have a better way to handle this - const char * src1_col = (const char*)wdata + - (src1_cont || src1->type != vec_dot_type - ? (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size - : (i11 * nb11 + i12 * nb12 + i13 * nb13)); - float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3)); - - //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) { - // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col); - //} - - for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) { - vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot); - } - - for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) { - memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float)); - } - } - } - } -} - -static void lm_ggml_compute_forward_mul_mat( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - LM_GGML_TENSOR_BINARY_OP_LOCALS - - const int ith = params->ith; - const int nth = params->nth; - - const enum lm_ggml_type type = src0->type; - - enum lm_ggml_type const vec_dot_type = type_traits[type].vec_dot_type; - lm_ggml_from_float_t const from_float = type_traits[vec_dot_type].from_float; - lm_ggml_from_float_to_mat_t const from_float_to_mat = type_traits[vec_dot_type].from_float_to_mat; - int64_t const vec_dot_num_rows = type_traits[type].nrows; - int64_t const matmul_num_cols = type_traits[type].ncols; - int64_t const blck_size_interleave = type_traits[type].blck_size_interleave; - lm_ggml_gemv_t const gemv = type_traits[type].gemv; - lm_ggml_gemm_t const gemm = type_traits[type].gemm; - - LM_GGML_ASSERT(ne0 == ne01); - LM_GGML_ASSERT(ne1 == ne11); - LM_GGML_ASSERT(ne2 == ne12); - LM_GGML_ASSERT(ne3 == ne13); - - // we don't support permuted src0 or src1 - LM_GGML_ASSERT(nb00 == lm_ggml_type_size(type)); - LM_GGML_ASSERT(nb10 == lm_ggml_type_size(src1->type)); - - // dst cannot be transposed or permuted - LM_GGML_ASSERT(nb0 == sizeof(float)); - LM_GGML_ASSERT(nb0 <= nb1); - LM_GGML_ASSERT(nb1 <= nb2); - LM_GGML_ASSERT(nb2 <= nb3); - - // nb01 >= nb00 - src0 is not transposed - // compute by src0 rows - -#if LM_GGML_USE_LLAMAFILE - // broadcast factors - const int64_t r2 = ne12 / ne02; - const int64_t r3 = ne13 / ne03; - - const bool src1_cont = lm_ggml_is_contiguous(src1); - - if (src1_cont) { - for (int64_t i13 = 0; i13 < ne13; i13++) - for (int64_t i12 = 0; i12 < ne12; i12++) - if (!llamafile_sgemm(ne01, ne11, ne00/lm_ggml_blck_size(src0->type), - (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, - nb01/lm_ggml_type_size(src0->type), - (const char *)src1->data + i12*nb12 + i13*nb13, - nb11/lm_ggml_type_size(src1->type), - (char *)dst->data + i12*nb2 + i13*nb3, - nb1/lm_ggml_type_size(dst->type), - ith, nth, - src0->type, - src1->type, - dst->type)) - goto UseGgmlGemm1; - return; - } -UseGgmlGemm1:; -#endif - - if (src1->type != vec_dot_type) { - char * wdata = params->wdata; - - const size_t nbw1 = lm_ggml_row_size(vec_dot_type, ne10); - const size_t nbw2 = nbw1*ne11; - const size_t nbw3 = nbw2*ne12; - - assert(params->wsize >= ne13*nbw3); - LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); - - for (int64_t i13 = 0; i13 < ne13; ++i13) { - for (int64_t i12 = 0; i12 < ne12; ++i12) { - int64_t i11_processed = 0; - if ((lm_ggml_n_dims(src1) == 2) && from_float_to_mat && gemm) { - for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) { - from_float_to_mat((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), - (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), - 4, ne10, blck_size_interleave); - } - i11_processed = ne11 - ne11 % 4; - } - for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) { - from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), - (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), - ne10); - } - } - } - } - - if (ith == 0) { - // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start. - atomic_store_explicit(¶ms->threadpool->current_chunk, nth, memory_order_relaxed); - } - - lm_ggml_barrier(params->threadpool); - -#if LM_GGML_USE_LLAMAFILE - if (src1->type != vec_dot_type) { - const void* wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; - const size_t row_size = lm_ggml_row_size(vec_dot_type, ne10); - - for (int64_t i13 = 0; i13 < ne13; i13++) - for (int64_t i12 = 0; i12 < ne12; i12++) - if (!llamafile_sgemm(ne01, ne11, ne00/lm_ggml_blck_size(src0->type), - (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, - nb01/lm_ggml_type_size(src0->type), - (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, - row_size/lm_ggml_type_size(vec_dot_type), - (char *)dst->data + i12*nb2 + i13*nb3, - nb1/lm_ggml_type_size(dst->type), - ith, nth, - src0->type, - vec_dot_type, - dst->type)) - goto UseGgmlGemm2; - return; - } -UseGgmlGemm2:; -#endif - - // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers) - const int64_t nr0 = ne0; - - // This is the size of the rest of the dimensions of the result - const int64_t nr1 = ne1 * ne2 * ne3; - - // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols - int64_t num_rows_per_vec_dot = vec_dot_num_rows; - // TODO: currently the mmla kernels support only even numbered rows/cols. - // this check can be removed once they are extended to support odd numbered rows/cols too - if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) { - num_rows_per_vec_dot = 1; - } - - // Now select a reasonable chunk size. - int chunk_size = 16; - - // We need to step up the size if it's small - if (nr0 == 1 || nr1 == 1) { - chunk_size = 64; - } - - // distribute the work across the inner or outer loop based on which one is larger - // The number of chunks in the 0/1 dim. - // CEIL(nr0/chunk_size) - int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size; - int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size; - - // If the chunking is poor for the number of threads on this setup, scrap the whole plan. Re-chunk it by thread. - // Also, chunking by thread was measured to have perform better on NUMA systems. See https://github.com/ggerganov/llama.cpp/pull/6915 - // In theory, chunking should be just as useful on NUMA and non NUMA systems, but testing disagreed with that. - if (nchunk0 * nchunk1 < nth * 4 || lm_ggml_is_numa()) { - // distribute the thread work across the inner or outer loop based on which one is larger - nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows - nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows - } - - // The number of elements in each chunk - const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0; - const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1; - - if ((lm_ggml_n_dims(src0) == 2) && gemv) { - const void * src1_wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; - const size_t src1_col_stride = lm_ggml_is_contiguous(src1) || src1->type != vec_dot_type ? lm_ggml_row_size(vec_dot_type, ne10) : nb11; - int64_t src0_start = (ith * ne01) / nth; - int64_t src0_end = ((ith + 1) * ne01) / nth; - src0_start = (src0_start % matmul_num_cols) ? src0_start + matmul_num_cols - (src0_start % matmul_num_cols): src0_start; - src0_end = (src0_end % matmul_num_cols) ? src0_end + matmul_num_cols - (src0_end % matmul_num_cols): src0_end; - if (src0_start >= src0_end) return; - - // If there are more than three rows in src1, use gemm; otherwise, use gemv. - if (gemm && (ne11 > 3)) { - gemm(ne00, (float *)((char *) dst->data) + src0_start, ne01, (const char *) src0->data + src0_start * nb01, - (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start); - } - for (int iter = gemm ? ne11 - ne11 % 4 : 0; iter < ne11; iter++) { - gemv(ne00, (float *)((char *) dst->data + (iter * nb1)) + src0_start, ne01, - (const char *) src0->data + src0_start * nb01, (const char *) src1_wdata + (src1_col_stride * iter), 1, - src0_end - src0_start); - } - return; - } - - // The first chunk comes from our thread_id, the rest will get auto-assigned. - int current_chunk = ith; - - while (current_chunk < nchunk0 * nchunk1) { - const int64_t ith0 = current_chunk % nchunk0; - const int64_t ith1 = current_chunk / nchunk0; - - const int64_t ir0_start = dr0 * ith0; - const int64_t ir0_end = MIN(ir0_start + dr0, nr0); - - const int64_t ir1_start = dr1 * ith1; - const int64_t ir1_end = MIN(ir1_start + dr1, nr1); - - lm_ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end); - - if (nth >= nchunk0 * nchunk1) { - break; - } - - current_chunk = atomic_fetch_add_explicit(¶ms->threadpool->current_chunk, 1, memory_order_relaxed); - } -} - -// lm_ggml_compute_forward_mul_mat_id - -static void lm_ggml_compute_forward_mul_mat_id( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - const struct lm_ggml_tensor * ids = dst->src[2]; - - LM_GGML_TENSOR_BINARY_OP_LOCALS - - const int ith = params->ith; - const int nth = params->nth; - - const enum lm_ggml_type type = src0->type; - - const bool src1_cont = lm_ggml_is_contiguous(src1); - - lm_ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot; - enum lm_ggml_type const vec_dot_type = type_traits[type].vec_dot_type; - lm_ggml_from_float_t const from_float = type_traits[vec_dot_type].from_float; - int64_t const matmul_num_cols = type_traits[type].ncols; - lm_ggml_gemv_t const gemv = type_traits[type].gemv; - - // we don't support permuted src0 or src1 - LM_GGML_ASSERT(nb00 == lm_ggml_type_size(type)); - LM_GGML_ASSERT(nb10 == lm_ggml_type_size(src1->type)); - - // dst cannot be transposed or permuted - LM_GGML_ASSERT(nb0 == sizeof(float)); - LM_GGML_ASSERT(nb0 <= nb1); - LM_GGML_ASSERT(nb1 <= nb2); - LM_GGML_ASSERT(nb2 <= nb3); - - // row groups - const int n_ids = ids->ne[0]; // n_expert_used - const int n_as = ne02; // n_expert - - char * wdata_src1_end = (src1->type == vec_dot_type) ? - (char *) params->wdata : - (char *) params->wdata + LM_GGML_PAD(lm_ggml_row_size(vec_dot_type, lm_ggml_nelements(src1)), sizeof(int64_t)); - - struct mmid_row_mapping { - int32_t i1; - int32_t i2; - }; - - int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as] - struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11] - - if (src1->type != vec_dot_type) { - char * wdata = params->wdata; - - const size_t nbw1 = lm_ggml_row_size(vec_dot_type, ne10); - const size_t nbw2 = nbw1*ne11; - const size_t nbw3 = nbw2*ne12; - - assert(params->wsize >= ne13*nbw3); - LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); - - for (int64_t i13 = 0; i13 < ne13; ++i13) { - for (int64_t i12 = 0; i12 < ne12; ++i12) { - for (int64_t i11 = ith; i11 < ne11; i11 += nth) { - from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), - (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), - ne10); - } - } - } - } - -#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)] - - if (ith == 0) { - // initialize matrix_row_counts - memset(matrix_row_counts, 0, n_as*sizeof(int64_t)); - - // group rows by src0 matrix - for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { - for (int id = 0; id < n_ids; ++id) { - const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]); - - assert(i02 >= 0 && i02 < n_as); - - MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1}; - matrix_row_counts[i02] += 1; - } - } - } - - lm_ggml_barrier(params->threadpool); - - // compute each matrix multiplication in sequence - for (int cur_a = 0; cur_a < n_as; ++cur_a) { - const int64_t cne1 = matrix_row_counts[cur_a]; - - if (cne1 == 0) { - continue; - } - - const char * src0_cur = (const char *) src0->data + cur_a*nb02; - - const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; - const size_t row_size = lm_ggml_row_size(vec_dot_type, ne10); - - const int64_t nr0 = ne01; // src0 rows - const int64_t nr1 = cne1; // src1 rows - - if (((lm_ggml_n_dims(src0) - 1) == 2) && gemv) { - int64_t src0_cur_start = (ith * ne01) / nth; - int64_t src0_cur_end = ((ith + 1) * ne01) / nth; - src0_cur_start = (src0_cur_start % matmul_num_cols) ? src0_cur_start + matmul_num_cols - (src0_cur_start % matmul_num_cols): src0_cur_start; - src0_cur_end = (src0_cur_end % matmul_num_cols) ? src0_cur_end + matmul_num_cols - (src0_cur_end % matmul_num_cols): src0_cur_end; - if (src0_cur_start >= src0_cur_end) return; - - for (int ir1 = 0; ir1 < nr1; ir1++) { - struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1); - const int id = row_mapping.i1; // selected expert index - - const int64_t i11 = id % ne11; - const int64_t i12 = row_mapping.i2; // row index in src1 - - const int64_t i1 = id; // selected expert index - const int64_t i2 = i12; // row - - const char * src1_col = (const char *) wdata + - (src1_cont || src1->type != vec_dot_type - ? (i11 + i12 * ne11) * row_size - : (i11 * nb11 + i12 * nb12)); - - gemv(ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01, - (const char *) src0_cur + src0_cur_start * nb01, src1_col, 1, src0_cur_end - src0_cur_start); - } - continue; - } - - // distribute the thread work across the inner or outer loop based on which one is larger - - const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows - const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows - - const int64_t ith0 = ith % nth0; - const int64_t ith1 = ith / nth0; - - const int64_t dr0 = (nr0 + nth0 - 1)/nth0; - const int64_t dr1 = (nr1 + nth1 - 1)/nth1; - - const int64_t ir010 = dr0*ith0; - const int64_t ir011 = MIN(ir010 + dr0, nr0); - - const int64_t ir110 = dr1*ith1; - const int64_t ir111 = MIN(ir110 + dr1, nr1); - - // threads with no work simply yield (not sure if it helps) - //if (ir010 >= ir011 || ir110 >= ir111) { - // sched_yield(); - // continue; - //} - - // block-tiling attempt - const int64_t blck_0 = 16; - const int64_t blck_1 = 16; - - // attempt to reduce false-sharing (does not seem to make a difference) - float tmp[16]; - - for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) { - for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) { - for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) { - const int64_t _i12 = ir1; // logical row index for this expert - - struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12); - const int id = row_mapping.i1; // selected expert index - - const int64_t i11 = id % ne11; - const int64_t i12 = row_mapping.i2; // row index in src1 - - const int64_t i1 = id; // selected expert index - const int64_t i2 = i12; // row - - // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides - // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using - // the original src1 data pointer, so we should index using the indices directly - // TODO: this is a bit of a hack, we should probably have a better way to handle this - const char * src1_col = (const char *) wdata + - (src1_cont || src1->type != vec_dot_type - ? (i11 + i12*ne11)*row_size - : (i11*nb11 + i12*nb12)); - - float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2)); - - //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { - // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col); - //} - - for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { - vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1); - } - - memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float)); - } - } - } - } - -#undef MMID_MATRIX_ROW -} - -// lm_ggml_compute_forward_out_prod - -static void lm_ggml_compute_forward_out_prod_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - LM_GGML_TENSOR_BINARY_OP_LOCALS - - LM_GGML_ASSERT(dst->type == LM_GGML_TYPE_F32); - LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32); - LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); - - const int ith = params->ith; - const int nth = params->nth; - - LM_GGML_ASSERT(ne0 == ne00); - LM_GGML_ASSERT(ne1 == ne10); - LM_GGML_ASSERT(ne2 == ne02); - LM_GGML_ASSERT(ne02 == ne12); - LM_GGML_ASSERT(ne3 == ne13); - LM_GGML_ASSERT(ne03 == ne13); - - // we don't support permuted src0 or src1 - LM_GGML_ASSERT(nb00 == sizeof(float)); - - // dst cannot be transposed or permuted - LM_GGML_ASSERT(nb0 == sizeof(float)); - // LM_GGML_ASSERT(nb0 <= nb1); - // LM_GGML_ASSERT(nb1 <= nb2); - // LM_GGML_ASSERT(nb2 <= nb3); - - // nb01 >= nb00 - src0 is not transposed - // compute by src0 rows - - if (ith == 0) { - lm_ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); - } - lm_ggml_barrier(params->threadpool); - - // dst[:,:,:,:] = 0 - // for i2,i3: - // for i1: - // for i01: - // for i0: - // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3] - - // parallelize by last three dimensions - - // total rows in dst - const int64_t nr = ne1*ne2*ne3; - - // rows per thread - const int64_t dr = (nr + nth - 1)/nth; - - // row range for this thread - const int64_t ir0 = dr*ith; - const int64_t ir1 = MIN(ir0 + dr, nr); - - // block-tiling attempt - const int64_t blck_0 = MAX(LM_GGML_VEC_MAD_UNROLL, 32); - const int64_t blck_1 = 16; - - for (int64_t bir = ir0; bir < ir1; bir += blck_1) { - const int64_t bir1 = MIN(bir + blck_1, ir1); - for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) { - const int64_t bne01 = MIN(bi01 + blck_0, ne01); - for (int64_t ir = bir; ir < bir1; ++ir) { - // dst indices - const int64_t i3 = ir/(ne2*ne1); - const int64_t i2 = (ir - i3*ne2*ne1)/ne1; - const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1); - - const int64_t i02 = i2; - const int64_t i03 = i3; - - //const int64_t i10 = i1; - const int64_t i12 = i2; - const int64_t i13 = i3; - -#if LM_GGML_VEC_MAD_UNROLL > 2 - const int64_t bne01_unroll = bne01 - (bne01 % LM_GGML_VEC_MAD_UNROLL); - for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += LM_GGML_VEC_MAD_UNROLL) { - const int64_t i11 = i01; - - float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); - float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); - float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); - - lm_ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1); - } - for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) { - const int64_t i11 = i01; - - float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); - float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); - float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); - - lm_ggml_vec_mad_f32(ne0, d, s0, *s1); - } -#else - for (int64_t i01 = bi01; i01 < bne01; ++i01) { - const int64_t i11 = i01; - - float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); - float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); - float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); - - lm_ggml_vec_mad_f32(ne0, d, s0, *s1); - } -#endif - } - } - } -} - -static void lm_ggml_compute_forward_out_prod_q_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - LM_GGML_TENSOR_BINARY_OP_LOCALS; - - const int ith = params->ith; - const int nth = params->nth; - - const enum lm_ggml_type type = src0->type; - lm_ggml_to_float_t const dequantize_row_q = type_traits[type].to_float; - - LM_GGML_ASSERT(ne02 == ne12); - LM_GGML_ASSERT(ne03 == ne13); - LM_GGML_ASSERT(ne2 == ne12); - LM_GGML_ASSERT(ne3 == ne13); - - // we don't support permuted src0 dim0 - LM_GGML_ASSERT(nb00 == lm_ggml_type_size(type)); - - // dst dim0 cannot be transposed or permuted - LM_GGML_ASSERT(nb0 == sizeof(float)); - // LM_GGML_ASSERT(nb0 <= nb1); - // LM_GGML_ASSERT(nb1 <= nb2); - // LM_GGML_ASSERT(nb2 <= nb3); - - LM_GGML_ASSERT(ne0 == ne00); - LM_GGML_ASSERT(ne1 == ne10); - LM_GGML_ASSERT(ne2 == ne02); - LM_GGML_ASSERT(ne3 == ne03); - - // nb01 >= nb00 - src0 is not transposed - // compute by src0 rows - - if (ith == 0) { - lm_ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); - } - lm_ggml_barrier(params->threadpool); - - // parallelize by last three dimensions - - // total rows in dst - const int64_t nr = ne1*ne2*ne3; - - // rows per thread - const int64_t dr = (nr + nth - 1)/nth; - - // row range for this thread - const int64_t ir0 = dr*ith; - const int64_t ir1 = MIN(ir0 + dr, nr); - - // dst[:,:,:,:] = 0 - // for i2,i3: - // for i1: - // for i01: - // for i0: - // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3] - - float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith; - - for (int64_t ir = ir0; ir < ir1; ++ir) { - // dst indices - const int64_t i3 = ir/(ne2*ne1); - const int64_t i2 = (ir - i3*ne2*ne1)/ne1; - const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1); - - const int64_t i02 = i2; - const int64_t i03 = i3; - - //const int64_t i10 = i1; - const int64_t i12 = i2; - const int64_t i13 = i3; - - for (int64_t i01 = 0; i01 < ne01; ++i01) { - const int64_t i11 = i01; - - float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); - float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); - float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); - - dequantize_row_q(s0, wdata, ne0); - lm_ggml_vec_mad_f32(ne0, d, wdata, *s1); - } - } -} - -static void lm_ggml_compute_forward_out_prod( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_Q4_0: - case LM_GGML_TYPE_Q4_1: - case LM_GGML_TYPE_Q5_0: - case LM_GGML_TYPE_Q5_1: - case LM_GGML_TYPE_Q8_0: - case LM_GGML_TYPE_Q2_K: - case LM_GGML_TYPE_Q3_K: - case LM_GGML_TYPE_Q4_K: - case LM_GGML_TYPE_Q5_K: - case LM_GGML_TYPE_Q6_K: - case LM_GGML_TYPE_TQ1_0: - case LM_GGML_TYPE_TQ2_0: - case LM_GGML_TYPE_IQ2_XXS: - case LM_GGML_TYPE_IQ2_XS: - case LM_GGML_TYPE_IQ3_XXS: - case LM_GGML_TYPE_IQ1_S: - case LM_GGML_TYPE_IQ1_M: - case LM_GGML_TYPE_IQ4_NL: - case LM_GGML_TYPE_IQ4_XS: - case LM_GGML_TYPE_IQ3_S: - case LM_GGML_TYPE_IQ2_S: - case LM_GGML_TYPE_Q4_0_4_4: - case LM_GGML_TYPE_Q4_0_4_8: - case LM_GGML_TYPE_Q4_0_8_8: - { - lm_ggml_compute_forward_out_prod_q_f32(params, dst); - } break; - case LM_GGML_TYPE_F16: - { - LM_GGML_ABORT("fatal error"); // todo - // lm_ggml_compute_forward_out_prod_f16_f32(params, dst); - } - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_out_prod_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_scale - -static void lm_ggml_compute_forward_scale_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - LM_GGML_ASSERT(lm_ggml_is_contiguous(src0)); - LM_GGML_ASSERT(lm_ggml_is_contiguous(dst)); - LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); - - // scale factor - float v; - memcpy(&v, dst->op_params, sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nc = src0->ne[0]; - const int nr = lm_ggml_nrows(src0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - const size_t nb01 = src0->nb[1]; - - const size_t nb1 = dst->nb[1]; - - for (int i1 = ir0; i1 < ir1; i1++) { - if (dst->data != src0->data) { - // src0 is same shape as dst => same indices - memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float)); - } - lm_ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v); - } -} - -static void lm_ggml_compute_forward_scale( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_scale_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_set - -static void lm_ggml_compute_forward_set_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); - LM_GGML_ASSERT(lm_ggml_is_contiguous(dst) && lm_ggml_is_contiguous(src0)); - - // view src0 and dst with these strides and data offset inbytes during set - // nb0 is implicitly element_size because src0 and dst are contiguous - size_t nb1 = ((int32_t *) dst->op_params)[0]; - size_t nb2 = ((int32_t *) dst->op_params)[1]; - size_t nb3 = ((int32_t *) dst->op_params)[2]; - size_t offset = ((int32_t *) dst->op_params)[3]; - bool inplace = (bool) ((int32_t *) dst->op_params)[4]; - - if (!inplace) { - if (params->ith == 0) { - // memcpy needs to be synchronized across threads to avoid race conditions. - // => do it in INIT phase - memcpy( - ((char *) dst->data), - ((char *) src0->data), - lm_ggml_nbytes(dst)); - } - lm_ggml_barrier(params->threadpool); - } - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = lm_ggml_nrows(src1); - const int nc = src1->ne[0]; - - LM_GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) - LM_GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) - - // src0 and dst as viewed during set - const size_t nb0 = lm_ggml_element_size(src0); - - const int im0 = (ne10 == 0 ? 0 : ne10-1); - const int im1 = (ne11 == 0 ? 0 : ne11-1); - const int im2 = (ne12 == 0 ? 0 : ne12-1); - const int im3 = (ne13 == 0 ? 0 : ne13-1); - - LM_GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <= lm_ggml_nbytes(dst)); - - LM_GGML_ASSERT(nb10 == sizeof(float)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int ir = ir0; ir < ir1; ++ir) { - // src0 and dst are viewed with shape of src1 and offset - // => same indices - const int i3 = ir/(ne12*ne11); - const int i2 = (ir - i3*ne12*ne11)/ne11; - const int i1 = (ir - i3*ne12*ne11 - i2*ne11); - - lm_ggml_vec_cpy_f32(nc, - (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset), - (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11)); - } -} - -static void lm_ggml_compute_forward_set( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_set_f32(params, dst); - } break; - case LM_GGML_TYPE_F16: - case LM_GGML_TYPE_BF16: - case LM_GGML_TYPE_Q4_0: - case LM_GGML_TYPE_Q4_1: - case LM_GGML_TYPE_Q5_0: - case LM_GGML_TYPE_Q5_1: - case LM_GGML_TYPE_Q8_0: - case LM_GGML_TYPE_Q8_1: - case LM_GGML_TYPE_Q2_K: - case LM_GGML_TYPE_Q3_K: - case LM_GGML_TYPE_Q4_K: - case LM_GGML_TYPE_Q5_K: - case LM_GGML_TYPE_Q6_K: - case LM_GGML_TYPE_TQ1_0: - case LM_GGML_TYPE_TQ2_0: - case LM_GGML_TYPE_IQ2_XXS: - case LM_GGML_TYPE_IQ2_XS: - case LM_GGML_TYPE_IQ3_XXS: - case LM_GGML_TYPE_IQ1_S: - case LM_GGML_TYPE_IQ1_M: - case LM_GGML_TYPE_IQ4_NL: - case LM_GGML_TYPE_IQ4_XS: - case LM_GGML_TYPE_IQ3_S: - case LM_GGML_TYPE_IQ2_S: - case LM_GGML_TYPE_Q4_0_4_4: - case LM_GGML_TYPE_Q4_0_4_8: - case LM_GGML_TYPE_Q4_0_8_8: - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_cpy - -static void lm_ggml_compute_forward_cpy( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - lm_ggml_compute_forward_dup(params, dst); -} - -// lm_ggml_compute_forward_cont - -static void lm_ggml_compute_forward_cont( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - lm_ggml_compute_forward_dup(params, dst); -} - -// lm_ggml_compute_forward_reshape - -static void lm_ggml_compute_forward_reshape( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - // NOP - UNUSED(params); - UNUSED(dst); -} - -// lm_ggml_compute_forward_view - -static void lm_ggml_compute_forward_view( - const struct lm_ggml_compute_params * params, - const struct lm_ggml_tensor * dst) { - // NOP - UNUSED(params); - UNUSED(dst); -} - -// lm_ggml_compute_forward_permute - -static void lm_ggml_compute_forward_permute( - const struct lm_ggml_compute_params * params, - const struct lm_ggml_tensor * dst) { - // NOP - UNUSED(params); - UNUSED(dst); -} - -// lm_ggml_compute_forward_transpose - -static void lm_ggml_compute_forward_transpose( - const struct lm_ggml_compute_params * params, - const struct lm_ggml_tensor * dst) { - // NOP - UNUSED(params); - UNUSED(dst); -} - -// lm_ggml_compute_forward_get_rows - -static void lm_ggml_compute_forward_get_rows_q( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - LM_GGML_TENSOR_BINARY_OP_LOCALS - - const int64_t nc = ne00; - const int64_t nr = lm_ggml_nelements(src1); - - const enum lm_ggml_type type = src0->type; - lm_ggml_to_float_t const dequantize_row_q = type_traits[type].to_float; - - assert(ne0 == nc); - assert(ne02 == ne11); - assert(nb00 == lm_ggml_type_size(type)); - assert(lm_ggml_nrows(dst) == nr); - - const int ith = params->ith; - const int nth = params->nth; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int64_t i = ir0; i < ir1; ++i) { - const int64_t i12 = i/(ne11*ne10); - const int64_t i11 = (i - i12*ne11*ne10)/ne10; - const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); - const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); - - LM_GGML_ASSERT(i01 >= 0 && i01 < ne01); - - dequantize_row_q( - (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), - (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); - } -} - -static void lm_ggml_compute_forward_get_rows_f16( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - LM_GGML_TENSOR_BINARY_OP_LOCALS - - const int64_t nc = ne00; - const int64_t nr = lm_ggml_nelements(src1); - - assert(ne0 == nc); - assert(ne02 == ne11); - assert(nb00 == sizeof(lm_ggml_fp16_t)); - assert(lm_ggml_nrows(dst) == nr); - - const int ith = params->ith; - const int nth = params->nth; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int64_t i = ir0; i < ir1; ++i) { - const int64_t i12 = i/(ne11*ne10); - const int64_t i11 = (i - i12*ne11*ne10)/ne10; - const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); - const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); - - LM_GGML_ASSERT(i01 >= 0 && i01 < ne01); - - lm_ggml_fp16_to_fp32_row( - (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), - (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); - } -} - -static void lm_ggml_compute_forward_get_rows_bf16( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - LM_GGML_TENSOR_BINARY_OP_LOCALS - - const int64_t nc = ne00; - const int64_t nr = lm_ggml_nelements(src1); - - assert(ne0 == nc); - assert(ne02 == ne11); - assert(nb00 == sizeof(lm_ggml_bf16_t)); - assert(lm_ggml_nrows(dst) == nr); - - const int ith = params->ith; - const int nth = params->nth; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int64_t i = ir0; i < ir1; ++i) { - const int64_t i12 = i/(ne11*ne10); - const int64_t i11 = (i - i12*ne11*ne10)/ne10; - const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); - const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); - - LM_GGML_ASSERT(i01 >= 0 && i01 < ne01); - - lm_ggml_bf16_to_fp32_row( - (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), - (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); - } -} - -static void lm_ggml_compute_forward_get_rows_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - LM_GGML_TENSOR_BINARY_OP_LOCALS - - const int64_t nc = ne00; - const int64_t nr = lm_ggml_nelements(src1); - - assert(ne0 == nc); - assert(ne02 == ne11); - assert(nb00 == sizeof(float)); - assert(lm_ggml_nrows(dst) == nr); - - const int ith = params->ith; - const int nth = params->nth; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int64_t i = ir0; i < ir1; ++i) { - const int64_t i12 = i/(ne11*ne10); - const int64_t i11 = (i - i12*ne11*ne10)/ne10; - const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); - const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); - - LM_GGML_ASSERT(i01 >= 0 && i01 < ne01); - - lm_ggml_vec_cpy_f32(nc, - (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), - (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03)); - } -} - -static void lm_ggml_compute_forward_get_rows( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_Q4_0: - case LM_GGML_TYPE_Q4_1: - case LM_GGML_TYPE_Q5_0: - case LM_GGML_TYPE_Q5_1: - case LM_GGML_TYPE_Q8_0: - case LM_GGML_TYPE_Q8_1: - case LM_GGML_TYPE_Q2_K: - case LM_GGML_TYPE_Q3_K: - case LM_GGML_TYPE_Q4_K: - case LM_GGML_TYPE_Q5_K: - case LM_GGML_TYPE_Q6_K: - case LM_GGML_TYPE_TQ1_0: - case LM_GGML_TYPE_TQ2_0: - case LM_GGML_TYPE_IQ2_XXS: - case LM_GGML_TYPE_IQ2_XS: - case LM_GGML_TYPE_IQ3_XXS: - case LM_GGML_TYPE_IQ1_S: - case LM_GGML_TYPE_IQ1_M: - case LM_GGML_TYPE_IQ4_NL: - case LM_GGML_TYPE_IQ4_XS: - case LM_GGML_TYPE_IQ3_S: - case LM_GGML_TYPE_IQ2_S: - case LM_GGML_TYPE_Q4_0_4_4: - case LM_GGML_TYPE_Q4_0_4_8: - case LM_GGML_TYPE_Q4_0_8_8: - { - lm_ggml_compute_forward_get_rows_q(params, dst); - } break; - case LM_GGML_TYPE_F16: - { - lm_ggml_compute_forward_get_rows_f16(params, dst); - } break; - case LM_GGML_TYPE_BF16: - { - lm_ggml_compute_forward_get_rows_bf16(params, dst); - } break; - case LM_GGML_TYPE_F32: - case LM_GGML_TYPE_I32: - { - lm_ggml_compute_forward_get_rows_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } - - //static bool first = true; - //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]); - //if (first) { - // first = false; - //} else { - // for (int k = 0; k < dst->ne[1]; ++k) { - // for (int j = 0; j < dst->ne[0]/16; ++j) { - // for (int i = 0; i < 16; ++i) { - // printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]); - // } - // printf("\n"); - // } - // printf("\n"); - // } - // printf("\n"); - // exit(0); - //} -} - -// lm_ggml_compute_forward_get_rows_back - -static void lm_ggml_compute_forward_get_rows_back_f32_f16( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - if (params->ith != 0) { - return; - } - - LM_GGML_ASSERT(lm_ggml_is_contiguous(dst)); - - // lm_ggml_compute_forward_dup_same_cont(params, opt0, dst); - - memset(dst->data, 0, lm_ggml_nbytes(dst)); - - const int nc = src0->ne[0]; - const int nr = lm_ggml_nelements(src1); - - LM_GGML_ASSERT( dst->ne[0] == nc); - LM_GGML_ASSERT(src0->nb[0] == sizeof(lm_ggml_fp16_t)); - - for (int i = 0; i < nr; ++i) { - const int r = ((int32_t *) src1->data)[i]; - - for (int j = 0; j < nc; ++j) { - lm_ggml_fp16_t v = ((lm_ggml_fp16_t *) ((char *) src0->data + i*src0->nb[1]))[j]; - ((float *) ((char *) dst->data + r*dst->nb[1]))[j] += LM_GGML_FP16_TO_FP32(v); - } - } -} - -static void lm_ggml_compute_forward_get_rows_back_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - if (params->ith != 0) { - return; - } - - LM_GGML_ASSERT(lm_ggml_is_contiguous(dst)); - - // lm_ggml_compute_forward_dup_same_cont(params, opt0, dst); - - memset(dst->data, 0, lm_ggml_nbytes(dst)); - - const int nc = src0->ne[0]; - const int nr = lm_ggml_nelements(src1); - - LM_GGML_ASSERT( dst->ne[0] == nc); - LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < nr; ++i) { - const int r = ((int32_t *) src1->data)[i]; - - lm_ggml_vec_add_f32(nc, - (float *) ((char *) dst->data + r*dst->nb[1]), - (float *) ((char *) dst->data + r*dst->nb[1]), - (float *) ((char *) src0->data + i*src0->nb[1])); - } -} - -static void lm_ggml_compute_forward_get_rows_back( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F16: - { - lm_ggml_compute_forward_get_rows_back_f32_f16(params, dst); - } break; - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_get_rows_back_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } - - //static bool first = true; - //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]); - //if (first) { - // first = false; - //} else { - // for (int k = 0; k < dst->ne[1]; ++k) { - // for (int j = 0; j < dst->ne[0]/16; ++j) { - // for (int i = 0; i < 16; ++i) { - // printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]); - // } - // printf("\n"); - // } - // printf("\n"); - // } - // printf("\n"); - // exit(0); - //} -} - -// lm_ggml_compute_forward_diag - -static void lm_ggml_compute_forward_diag_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - // TODO: handle transposed/permuted matrices - - LM_GGML_TENSOR_UNARY_OP_LOCALS - - LM_GGML_ASSERT(ne00 == ne0); - LM_GGML_ASSERT(ne00 == ne1); - LM_GGML_ASSERT(ne01 == 1); - LM_GGML_ASSERT(ne02 == ne2); - LM_GGML_ASSERT(ne03 == ne3); - - LM_GGML_ASSERT(nb00 == sizeof(float)); - LM_GGML_ASSERT(nb0 == sizeof(float)); - - for (int i3 = 0; i3 < ne3; i3++) { - for (int i2 = 0; i2 < ne2; i2++) { - for (int i1 = 0; i1 < ne1; i1++) { - float * d = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); - float * s = (float *)((char *) src0->data + i3*nb03 + i2*nb02); - for (int i0 = 0; i0 < i1; i0++) { - d[i0] = 0; - } - d[i1] = s[i1]; - for (int i0 = i1+1; i0 < ne0; i0++) { - d[i0] = 0; - } - } - } - } -} - -static void lm_ggml_compute_forward_diag( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_diag_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_diag_mask_inf - -static void lm_ggml_compute_forward_diag_mask_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst, - const float value) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - const int ith = params->ith; - const int nth = params->nth; - - const int n_past = ((int32_t *) dst->op_params)[0]; - const bool inplace = src0->data == dst->data; - - LM_GGML_ASSERT(n_past >= 0); - - if (!inplace) { - if (ith == 0) { - // memcpy needs to be synchronized across threads to avoid race conditions. - // => do it in INIT phase - LM_GGML_ASSERT(lm_ggml_nelements(dst) == lm_ggml_nelements(src0)); - LM_GGML_ASSERT(lm_ggml_is_contiguous(dst) && lm_ggml_is_contiguous(src0)); - memcpy( - ((char *) dst->data), - ((char *) src0->data), - lm_ggml_nbytes(dst)); - } - lm_ggml_barrier(params->threadpool); - } - - // TODO: handle transposed/permuted matrices - - const int n = lm_ggml_nrows(src0); - const int nc = src0->ne[0]; - const int nr = src0->ne[1]; - const int nz = n/nr; - - LM_GGML_ASSERT( dst->nb[0] == sizeof(float)); - LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); - - for (int k = 0; k < nz; k++) { - for (int j = ith; j < nr; j += nth) { - for (int i = n_past; i < nc; i++) { - if (i > n_past + j) { - *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = value; - } - } - } - } -} - -static void lm_ggml_compute_forward_diag_mask_inf( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_diag_mask_f32(params, dst, -INFINITY); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -static void lm_ggml_compute_forward_diag_mask_zero( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_diag_mask_f32(params, dst, 0); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_soft_max - -static void lm_ggml_compute_forward_soft_max_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - assert(lm_ggml_is_contiguous(dst)); - assert(lm_ggml_are_same_shape(src0, dst)); - - float scale = 1.0f; - float max_bias = 0.0f; - - memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); - memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); - - // TODO: handle transposed/permuted matrices - - const int ith = params->ith; - const int nth = params->nth; - - LM_GGML_TENSOR_UNARY_OP_LOCALS - - //const int64_t ne11 = src1 ? src1->ne[1] : 1; - - // TODO: is this supposed to be ceil instead of floor? - // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370 - const uint32_t n_head = ne02; - const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); - - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - - const int nc = src0->ne[0]; - const int nr = lm_ggml_nrows(src0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith; - - const bool use_f16 = (src1 && src1->type == LM_GGML_TYPE_F16); - - for (int i1 = ir0; i1 < ir1; i1++) { - // ALiBi - const uint32_t h = (i1/ne01)%ne02; // head - const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f; - - float * sp = (float *)((char *) src0->data + i1*src0->nb[1]); - float * dp = (float *)((char *) dst->data + i1*dst->nb[1]); - - // broadcast the mask across rows - lm_ggml_fp16_t * mp_f16 = src1 ? (lm_ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; - float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; - - lm_ggml_vec_cpy_f32 (nc, wp, sp); - lm_ggml_vec_scale_f32(nc, wp, scale); - if (mp_f32) { - if (use_f16) { - for (int i = 0; i < nc; ++i) { - wp[i] += slope*LM_GGML_FP16_TO_FP32(mp_f16[i]); - } - } else { - for (int i = 0; i < nc; ++i) { - wp[i] += slope*mp_f32[i]; - } - } - } - -#ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - //printf("p[%d] = %f\n", i, p[i]); - assert(!isnan(wp[i])); - } -#endif - - float max = -INFINITY; - lm_ggml_vec_max_f32(nc, &max, wp); - - lm_ggml_float sum = lm_ggml_vec_soft_max_f32(nc, dp, wp, max); - assert(sum > 0.0); - - sum = 1.0/sum; - lm_ggml_vec_scale_f32(nc, dp, sum); - -#ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - assert(!isnan(dp[i])); - assert(!isinf(dp[i])); - } -#endif - } -} - -static void lm_ggml_compute_forward_soft_max( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_soft_max_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - - -// lm_ggml_compute_forward_soft_max_back - -static void lm_ggml_compute_forward_soft_max_back_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - LM_GGML_ASSERT(lm_ggml_is_contiguous(src0)); - LM_GGML_ASSERT(lm_ggml_is_contiguous(src1)); - LM_GGML_ASSERT(lm_ggml_is_contiguous(dst)); - LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); - LM_GGML_ASSERT(lm_ggml_are_same_shape(src1, dst)); - - // TODO: handle transposed/permuted matrices - - const int ith = params->ith; - const int nth = params->nth; - - const int nc = src0->ne[0]; - const int nr = lm_ggml_nrows(src0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int i1 = ir0; i1 < ir1; i1++) { - float *dy = (float *)((char *) src0->data + i1*src0->nb[1]); - float *y = (float *)((char *) src1->data + i1*src1->nb[1]); - float *dx = (float *)((char *) dst->data + i1*dst->nb[1]); - -#ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - //printf("p[%d] = %f\n", i, p[i]); - assert(!isnan(dy[i])); - assert(!isnan(y[i])); - } -#endif - // Jii = yi - yi*yi - // Jij = -yi*yj - // J = diag(y)-y.T*y - // dx = J * dy - // dxk = sum_i(Jki * dyi) - // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk - // dxk = sum_i(-yk*yi * dyi) + yk*yk*dyk + yk*dyk - yk*yk*dyk - // dxk = sum_i(-yk*yi * dyi) + yk*dyk - // dxk = -yk * sum_i(yi * dyi) + yk*dyk - // dxk = -yk * dot(y, dy) + yk*dyk - // dxk = yk * (- dot(y, dy) + dyk) - // dxk = yk * (dyk - dot(y, dy)) - // - // post-order: - // dot_y_dy := dot(y, dy) - // dx := dy - // dx := dx - dot_y_dy - // dx := dx * y - - // linear runtime, no additional memory - float dot_y_dy = 0; - lm_ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1); - lm_ggml_vec_cpy_f32 (nc, dx, dy); - lm_ggml_vec_acc1_f32(nc, dx, -dot_y_dy); - lm_ggml_vec_mul_f32 (nc, dx, dx, y); - -#ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - assert(!isnan(dx[i])); - assert(!isinf(dx[i])); - } -#endif - } -} - -static void lm_ggml_compute_forward_soft_max_back( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_soft_max_back_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_clamp - -static void lm_ggml_compute_forward_clamp_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - float min; - float max; - memcpy(&min, (float *) dst->op_params + 0, sizeof(float)); - memcpy(&max, (float *) dst->op_params + 1, sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - const int n = lm_ggml_nrows(src0); - const int nc = src0->ne[0]; - - const size_t nb00 = src0->nb[0]; - const size_t nb01 = src0->nb[1]; - - const size_t nb0 = dst->nb[0]; - const size_t nb1 = dst->nb[1]; - - LM_GGML_ASSERT( nb0 == sizeof(float)); - LM_GGML_ASSERT(nb00 == sizeof(float)); - - for (int j = ith; j < n; j += nth) { - float * dst_ptr = (float *) ((char *) dst->data + j*nb1); - float * src0_ptr = (float *) ((char *) src0->data + j*nb01); - - for (int i = 0; i < nc; i++) { - dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min); - } - } -} - -static void lm_ggml_compute_forward_clamp( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_clamp_f32(params, dst); - } break; - case LM_GGML_TYPE_F16: - case LM_GGML_TYPE_BF16: - case LM_GGML_TYPE_Q4_0: - case LM_GGML_TYPE_Q4_1: - case LM_GGML_TYPE_Q5_0: - case LM_GGML_TYPE_Q5_1: - case LM_GGML_TYPE_Q8_0: - case LM_GGML_TYPE_Q8_1: - case LM_GGML_TYPE_Q2_K: - case LM_GGML_TYPE_Q3_K: - case LM_GGML_TYPE_Q4_K: - case LM_GGML_TYPE_Q5_K: - case LM_GGML_TYPE_Q6_K: - case LM_GGML_TYPE_TQ1_0: - case LM_GGML_TYPE_TQ2_0: - case LM_GGML_TYPE_IQ2_XXS: - case LM_GGML_TYPE_IQ2_XS: - case LM_GGML_TYPE_IQ3_XXS: - case LM_GGML_TYPE_IQ1_S: - case LM_GGML_TYPE_IQ1_M: - case LM_GGML_TYPE_IQ4_NL: - case LM_GGML_TYPE_IQ4_XS: - case LM_GGML_TYPE_IQ3_S: - case LM_GGML_TYPE_IQ2_S: - case LM_GGML_TYPE_Q8_K: - case LM_GGML_TYPE_Q4_0_4_4: - case LM_GGML_TYPE_Q4_0_4_8: - case LM_GGML_TYPE_Q4_0_8_8: - case LM_GGML_TYPE_I8: - case LM_GGML_TYPE_I16: - case LM_GGML_TYPE_I32: - case LM_GGML_TYPE_I64: - case LM_GGML_TYPE_F64: - case LM_GGML_TYPE_COUNT: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_rope - -static float rope_yarn_ramp(const float low, const float high, const int i0) { - const float y = (i0 / 2 - low) / MAX(0.001f, high - low); - return 1 - MIN(1, MAX(0, y)); -} - -// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn -// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. -static void rope_yarn( - float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, - float * cos_theta, float * sin_theta) { - // Get n-d rotational scaling corrected for extrapolation - float theta_interp = freq_scale * theta_extrap; - float theta = theta_interp; - if (ext_factor != 0.0f) { - float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; - theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; - - // Get n-d magnitude scaling corrected for interpolation - mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale); - } - *cos_theta = cosf(theta) * mscale; - *sin_theta = sinf(theta) * mscale; -} - -// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get -// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` -static float lm_ggml_rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) { - return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float)M_PI)) / (2 * logf(base)); -} - -static void lm_ggml_rope_cache_init( - float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale, - float * cache, float sin_sign, float theta_scale) { - // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py - float theta = theta_base; - for (int64_t i0 = 0; i0 < ne0; i0 += 2) { - const float ff = freq_factors ? freq_factors[i0/2] : 1.0f; - rope_yarn( - theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1] - ); - cache[i0 + 1] *= sin_sign; - - theta *= theta_scale; - } -} - -void lm_ggml_rope_yarn_corr_dims( - int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2] -) { - // start and end correction dims - float start = floorf(lm_ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base)); - float end = ceilf(lm_ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base)); - dims[0] = MAX(0, start); - dims[1] = MIN(n_dims - 1, end); -} - -static void lm_ggml_compute_forward_rope_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst, - const bool forward) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - const struct lm_ggml_tensor * src2 = dst->src[2]; - - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; - - //const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_dims = ((int32_t *) dst->op_params)[1]; - const int mode = ((int32_t *) dst->op_params)[2]; - //const int n_ctx = ((int32_t *) dst->op_params)[3]; - const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; - - memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); - memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); - memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); - memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); - memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); - memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); - - LM_GGML_TENSOR_UNARY_OP_LOCALS - - //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); - //printf("n_past = %d, ne2 = %d\n", n_past, ne2); - - LM_GGML_ASSERT(nb00 == sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = lm_ggml_nrows(dst); - - LM_GGML_ASSERT(n_dims <= ne0); - LM_GGML_ASSERT(n_dims % 2 == 0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - // row index used to determine which thread to use - int ir = 0; - - const float theta_scale = powf(freq_base, -2.0f/n_dims); - - float corr_dims[2]; - lm_ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); - - const bool is_neox = mode & LM_GGML_ROPE_TYPE_NEOX; - - const float * freq_factors = NULL; - if (src2 != NULL) { - LM_GGML_ASSERT(src2->type == LM_GGML_TYPE_F32); - LM_GGML_ASSERT(src2->ne[0] >= n_dims / 2); - freq_factors = (const float *) src2->data; - } - - // backward process uses inverse rotation by cos and sin. - // cos and sin build a rotation matrix, where the inverse is the transpose. - // this essentially just switches the sign of sin. - const float sin_sign = forward ? 1.0f : -1.0f; - - const int32_t * pos = (const int32_t *) src1->data; - - for (int64_t i3 = 0; i3 < ne3; i3++) { - for (int64_t i2 = 0; i2 < ne2; i2++) { - const int64_t p = pos[i2]; - - float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith; - lm_ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); - - for (int64_t i1 = 0; i1 < ne1; i1++) { - if (ir++ < ir0) continue; - if (ir > ir1) break; - - if (!is_neox) { - for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { - const float cos_theta = cache[i0 + 0]; - const float sin_theta = cache[i0 + 1]; - - const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const float x0 = src[0]; - const float x1 = src[1]; - - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[1] = x0*sin_theta + x1*cos_theta; - } - } else { - for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { - const int64_t ic = i0/2; - - const float cos_theta = cache[i0 + 0]; - const float sin_theta = cache[i0 + 1]; - - const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); - float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims/2]; - - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; - } - } - - for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) { - const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } - } - } - } -} - -// TODO: deduplicate f16/f32 code -static void lm_ggml_compute_forward_rope_f16( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst, - const bool forward) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - const struct lm_ggml_tensor * src2 = dst->src[2]; - - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; - - //const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_dims = ((int32_t *) dst->op_params)[1]; - const int mode = ((int32_t *) dst->op_params)[2]; - //const int n_ctx = ((int32_t *) dst->op_params)[3]; - const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; - memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); - memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); - memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); - memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); - memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); - memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); - - LM_GGML_TENSOR_UNARY_OP_LOCALS - - //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); - //printf("n_past = %d, ne2 = %d\n", n_past, ne2); - - LM_GGML_ASSERT(nb0 == sizeof(lm_ggml_fp16_t)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = lm_ggml_nrows(dst); - - LM_GGML_ASSERT(n_dims <= ne0); - LM_GGML_ASSERT(n_dims % 2 == 0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - // row index used to determine which thread to use - int ir = 0; - - const float theta_scale = powf(freq_base, -2.0f/n_dims); - - float corr_dims[2]; - lm_ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); - - const bool is_neox = mode & LM_GGML_ROPE_TYPE_NEOX; - - const float * freq_factors = NULL; - if (src2 != NULL) { - LM_GGML_ASSERT(src2->type == LM_GGML_TYPE_F32); - LM_GGML_ASSERT(src2->ne[0] >= n_dims / 2); - freq_factors = (const float *) src2->data; - } - - // backward process uses inverse rotation by cos and sin. - // cos and sin build a rotation matrix, where the inverse is the transpose. - // this essentially just switches the sign of sin. - const float sin_sign = forward ? 1.0f : -1.0f; - - const int32_t * pos = (const int32_t *) src1->data; - - for (int64_t i3 = 0; i3 < ne3; i3++) { - for (int64_t i2 = 0; i2 < ne2; i2++) { - const int64_t p = pos[i2]; - - float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith; - lm_ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); - - for (int64_t i1 = 0; i1 < ne1; i1++) { - if (ir++ < ir0) continue; - if (ir > ir1) break; - - if (!is_neox) { - for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { - const float cos_theta = cache[i0 + 0]; - const float sin_theta = cache[i0 + 1]; - - const lm_ggml_fp16_t * const src = (lm_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - lm_ggml_fp16_t * dst_data = (lm_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const float x0 = LM_GGML_FP16_TO_FP32(src[0]); - const float x1 = LM_GGML_FP16_TO_FP32(src[1]); - - dst_data[0] = LM_GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); - dst_data[1] = LM_GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); - } - } else { - for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { - const int64_t ic = i0/2; - - const float cos_theta = cache[i0 + 0]; - const float sin_theta = cache[i0 + 1]; - - const lm_ggml_fp16_t * const src = (lm_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); - lm_ggml_fp16_t * dst_data = (lm_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - - const float x0 = LM_GGML_FP16_TO_FP32(src[0]); - const float x1 = LM_GGML_FP16_TO_FP32(src[n_dims/2]); - - dst_data[0] = LM_GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); - dst_data[n_dims/2] = LM_GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); - } - } - - for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) { - const lm_ggml_fp16_t * const src = (lm_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - lm_ggml_fp16_t * dst_data = (lm_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } - } - } - } -} - -static void lm_ggml_compute_forward_rope( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F16: - { - lm_ggml_compute_forward_rope_f16(params, dst, true); - } break; - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_rope_f32(params, dst, true); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_rope_back - -static void lm_ggml_compute_forward_rope_back( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F16: - { - lm_ggml_compute_forward_rope_f16(params, dst, false); - } break; - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_rope_f32(params, dst, false); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_conv_transpose_1d - -static void lm_ggml_compute_forward_conv_transpose_1d_f16_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F16); - LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); - LM_GGML_ASSERT( dst->type == LM_GGML_TYPE_F32); - - LM_GGML_TENSOR_BINARY_OP_LOCALS - - const int ith = params->ith; - const int nth = params->nth; - - const int nk = ne00*ne01*ne02; - - LM_GGML_ASSERT(nb00 == sizeof(lm_ggml_fp16_t)); - LM_GGML_ASSERT(nb10 == sizeof(float)); - - if (ith == 0) { - memset(params->wdata, 0, params->wsize); - - // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout) - { - lm_ggml_fp16_t * const wdata = (lm_ggml_fp16_t *) params->wdata + 0; - - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - const lm_ggml_fp16_t * const src = (lm_ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01); - lm_ggml_fp16_t * dst_data = wdata + i01*ne00*ne02; - for (int64_t i00 = 0; i00 < ne00; i00++) { - dst_data[i00*ne02 + i02] = src[i00]; - } - } - } - } - - // permute source data (src1) from (L x Cin) to (Cin x L) - { - lm_ggml_fp16_t * const wdata = (lm_ggml_fp16_t *) params->wdata + nk; - lm_ggml_fp16_t * dst_data = wdata; - - for (int64_t i11 = 0; i11 < ne11; i11++) { - const float * const src = (float *)((char *) src1->data + i11*nb11); - for (int64_t i10 = 0; i10 < ne10; i10++) { - dst_data[i10*ne11 + i11] = LM_GGML_FP32_TO_FP16(src[i10]); - } - } - } - - // need to zero dst since we are accumulating into it - memset(dst->data, 0, lm_ggml_nbytes(dst)); - } - lm_ggml_barrier(params->threadpool); - - const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; - - // total rows in dst - const int nr = ne1; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - lm_ggml_fp16_t * const wdata = (lm_ggml_fp16_t *) params->wdata + 0; - lm_ggml_fp16_t * const wdata_src = wdata + nk; - - for (int i1 = ir0; i1 < ir1; i1++) { - float * dst_data = (float *)((char *) dst->data + i1*nb1); - lm_ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00; - for (int i10 = 0; i10 < ne10; i10++) { - const int i1n = i10*ne11; - for (int i00 = 0; i00 < ne00; i00++) { - float v = 0; - lm_ggml_vec_dot_f16(ne02, &v, 0, - (lm_ggml_fp16_t *) wdata_src + i1n, 0, - (lm_ggml_fp16_t *) wdata_kernel + i00*ne02, 0, 1); - dst_data[i10*s0 + i00] += v; - } - } - } -} - -static void lm_ggml_compute_forward_conv_transpose_1d_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32); - LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); - LM_GGML_ASSERT( dst->type == LM_GGML_TYPE_F32); - - LM_GGML_TENSOR_BINARY_OP_LOCALS - - const int ith = params->ith; - const int nth = params->nth; - - const int nk = ne00*ne01*ne02; - - LM_GGML_ASSERT(nb00 == sizeof(float)); - LM_GGML_ASSERT(nb10 == sizeof(float)); - - if (ith == 0) { - memset(params->wdata, 0, params->wsize); - - // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout) - { - float * const wdata = (float *) params->wdata + 0; - - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01); - float * dst_data = wdata + i01*ne00*ne02; - for (int64_t i00 = 0; i00 < ne00; i00++) { - dst_data[i00*ne02 + i02] = src[i00]; - } - } - } - } - - // prepare source data (src1) - { - float * const wdata = (float *) params->wdata + nk; - float * dst_data = wdata; - - for (int64_t i11 = 0; i11 < ne11; i11++) { - const float * const src = (float *)((char *) src1->data + i11*nb11); - for (int64_t i10 = 0; i10 < ne10; i10++) { - dst_data[i10*ne11 + i11] = src[i10]; - } - } - } - - // need to zero dst since we are accumulating into it - memset(dst->data, 0, lm_ggml_nbytes(dst)); - } - lm_ggml_barrier(params->threadpool); - - const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; - - // total rows in dst - const int nr = ne1; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - float * const wdata = (float *) params->wdata + 0; - float * const wdata_src = wdata + nk; - - for (int i1 = ir0; i1 < ir1; i1++) { - float * dst_data = (float *)((char *) dst->data + i1*nb1); - float * wdata_kernel = wdata + i1*ne02*ne00; - for (int i10 = 0; i10 < ne10; i10++) { - const int i1n = i10*ne11; - for (int i00 = 0; i00 < ne00; i00++) { - float v = 0; - lm_ggml_vec_dot_f32(ne02, &v, 0, - wdata_src + i1n, 0, - wdata_kernel + i00*ne02, 0, 1); - dst_data[i10*s0 + i00] += v; - } - } - } -} - -static void lm_ggml_compute_forward_conv_transpose_1d( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F16: - { - lm_ggml_compute_forward_conv_transpose_1d_f16_f32(params, dst); - } break; - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_conv_transpose_1d_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_im2col_f32 -// src0: kernel [OC, IC, KH, KW] -// src1: image [N, IC, IH, IW] -// dst: result [N, OH, OW, IC*KH*KW] -static void lm_ggml_compute_forward_im2col_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); - LM_GGML_ASSERT( dst->type == LM_GGML_TYPE_F32); - - LM_GGML_TENSOR_BINARY_OP_LOCALS; - - const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; - const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; - const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; - const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; - const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; - const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; - const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; - - const int ith = params->ith; - const int nth = params->nth; - - const int64_t N = is_2D ? ne13 : ne12; - const int64_t IC = is_2D ? ne12 : ne11; - const int64_t IH = is_2D ? ne11 : 1; - const int64_t IW = ne10; - - const int64_t KH = is_2D ? ne01 : 1; - const int64_t KW = ne00; - - const int64_t OH = is_2D ? ne2 : 1; - const int64_t OW = ne1; - - int ofs0 = is_2D ? nb13 : nb12; - int ofs1 = is_2D ? nb12 : nb11; - - LM_GGML_ASSERT(nb10 == sizeof(float)); - - // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] - { - float * const wdata = (float *) dst->data; - - for (int64_t in = 0; in < N; in++) { - for (int64_t ioh = 0; ioh < OH; ioh++) { // 1 - for (int64_t iow = 0; iow < OW; iow++) { - for (int64_t iic = ith; iic < IC; iic += nth) { - - // micro kernel - float * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW] - const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW] - - for (int64_t ikh = 0; ikh < KH; ikh++) { // 1 - for (int64_t ikw = 0; ikw < KW; ikw++) { - const int64_t iiw = iow*s0 + ikw*d0 - p0; - const int64_t iih = ioh*s1 + ikh*d1 - p1; - - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { - dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0; - } else { - dst_data[iic*(KH*KW) + ikh*KW + ikw] = (src_data[iih*IW + iiw]); - } - } - } - } - } - } - } - } -} - - -// lm_ggml_compute_forward_im2col_f16 -// src0: kernel [OC, IC, KH, KW] -// src1: image [N, IC, IH, IW] -// dst: result [N, OH, OW, IC*KH*KW] -static void lm_ggml_compute_forward_im2col_f16( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F16); - LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); - LM_GGML_ASSERT( dst->type == LM_GGML_TYPE_F16); - - LM_GGML_TENSOR_BINARY_OP_LOCALS; - - const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; - const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; - const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; - const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; - const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; - const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; - const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; - - const int ith = params->ith; - const int nth = params->nth; - - const int64_t N = is_2D ? ne13 : ne12; - const int64_t IC = is_2D ? ne12 : ne11; - const int64_t IH = is_2D ? ne11 : 1; - const int64_t IW = ne10; - - const int64_t KH = is_2D ? ne01 : 1; - const int64_t KW = ne00; - - const int64_t OH = is_2D ? ne2 : 1; - const int64_t OW = ne1; - - int ofs0 = is_2D ? nb13 : nb12; - int ofs1 = is_2D ? nb12 : nb11; - - LM_GGML_ASSERT(nb00 == sizeof(lm_ggml_fp16_t)); - LM_GGML_ASSERT(nb10 == sizeof(float)); - - // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] - { - lm_ggml_fp16_t * const wdata = (lm_ggml_fp16_t *) dst->data; - - for (int64_t in = 0; in < N; in++) { - for (int64_t ioh = 0; ioh < OH; ioh++) { // 1 - for (int64_t iow = 0; iow < OW; iow++) { - for (int64_t iic = ith; iic < IC; iic += nth) { - - // micro kernel - lm_ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW] - const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW] - - for (int64_t ikh = 0; ikh < KH; ikh++) { // 1 - for (int64_t ikw = 0; ikw < KW; ikw++) { - const int64_t iiw = iow*s0 + ikw*d0 - p0; - const int64_t iih = ioh*s1 + ikh*d1 - p1; - - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { - dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0; - } else { - dst_data[iic*(KH*KW) + ikh*KW + ikw] = LM_GGML_FP32_TO_FP16(src_data[iih*IW + iiw]); - } - } - } - } - } - } - } - } -} - -static void lm_ggml_compute_forward_im2col( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - switch (dst->type) { - case LM_GGML_TYPE_F16: - { - lm_ggml_compute_forward_im2col_f16(params, dst); - } break; - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_im2col_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_im2col_back_f32 - -static void lm_ggml_compute_forward_im2col_back_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); - LM_GGML_ASSERT( dst->type == LM_GGML_TYPE_F32); - - LM_GGML_TENSOR_BINARY_OP_LOCALS; - - const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; - const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; - const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; - const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; - const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; - const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; - const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; - - const int ith = params->ith; - const int nth = params->nth; - - const int64_t N = is_2D ? ne3 : ne2; - const int64_t IC = is_2D ? ne2 : ne1; - const int64_t IH = is_2D ? ne1 : 1; - const int64_t IW = ne0; - - const int64_t KH = is_2D ? ne01 : 1; - const int64_t KW = ne00; - - const int64_t OH = is_2D ? ne12 : 1; - const int64_t OW = ne11; - - int ofs0 = is_2D ? nb3 : nb2; - int ofs1 = is_2D ? nb2 : nb1; - - LM_GGML_ASSERT(nb0 == sizeof(float)); - - // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] - { - float * const wdata = (float *) dst->data; - - for (int64_t in = 0; in < N; in++) { - for (int64_t iic = ith; iic < IC; iic += nth) { - for (int64_t iih = 0; iih < IH; iih++) { - for (int64_t iiw = 0; iiw < IW; iiw++) { - - // micro kernel - float grad = 0.0f; - for (int64_t ikh = 0; ikh < KH; ikh++) { - for (int64_t ikw = 0; ikw < KW; ikw++) { - // For s0 > 1 some values were skipped over in the forward pass. - // These values have tmpw % s0 != 0 and need to be skipped in the backwards pass as well. - const int64_t tmpw = (iiw + p0 - ikw*d0); - if (tmpw % s0 != 0) { - continue; - } - const int64_t iow = tmpw / s0; - - // Equivalent logic as above except for s1. - int64_t ioh; - if (is_2D) { - const int64_t tmph = iih + p1 - ikh*d1; - - if (tmph % s1 != 0) { - continue; - } - - ioh = tmph / s1; - } else { - ioh = 0; - } - - if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) { - continue; - } - - const float * const src_data = (const float *) src1->data - + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW] - grad += src_data[iic*(KH*KW) + ikh*KW + ikw]; - } - } - float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW] - dst_data[iih*IW + iiw] = grad; - } - } - } - } - } -} - -// lm_ggml_compute_forward_conv_transpose_2d - -static void lm_ggml_compute_forward_conv_transpose_2d( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F16); - LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); - LM_GGML_ASSERT( dst->type == LM_GGML_TYPE_F32); - - LM_GGML_TENSOR_BINARY_OP_LOCALS - - const int ith = params->ith; - const int nth = params->nth; - - const int nk = ne00*ne01*ne02*ne03; - - LM_GGML_ASSERT(nb00 == sizeof(lm_ggml_fp16_t)); - LM_GGML_ASSERT(nb10 == sizeof(float)); - - if (ith == 0) { - memset(params->wdata, 0, params->wsize); - - // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout) - { - lm_ggml_fp16_t * const wdata = (lm_ggml_fp16_t *) params->wdata + 0; - - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - const lm_ggml_fp16_t * const src = (lm_ggml_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02); - lm_ggml_fp16_t * dst_data = wdata + i02*ne01*ne00*ne03; - for (int64_t i01 = 0; i01 < ne01; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - dst_data[i01*ne00*ne03 + i00*ne03 + i03] = src[i01 * ne00 + i00]; - } - } - } - } - } - - // permute source data (src1) from (Sw x Sh x Cin) to (Cin x Sw x Sh) - { - lm_ggml_fp16_t * const wdata = (lm_ggml_fp16_t *) params->wdata + nk; - for (int i12 = 0; i12 < ne12; i12++) { - for (int i11 = 0; i11 < ne11; i11++) { - const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11); - lm_ggml_fp16_t * dst_data = wdata + i11*ne10*ne12; - for (int i10 = 0; i10 < ne10; i10++) { - dst_data[i10*ne12 + i12] = LM_GGML_FP32_TO_FP16(src[i10]); - } - } - } - } - - memset(dst->data, 0, lm_ggml_nbytes(dst)); - } - lm_ggml_barrier(params->threadpool); - - const int32_t stride = lm_ggml_get_op_params_i32(dst, 0); - - // total patches in dst - const int np = ne2; - - // patches per thread - const int dp = (np + nth - 1)/nth; - - // patch range for this thread - const int ip0 = dp*ith; - const int ip1 = MIN(ip0 + dp, np); - - lm_ggml_fp16_t * const wdata = (lm_ggml_fp16_t *) params->wdata + 0; - lm_ggml_fp16_t * const wdata_src = wdata + nk; - - for (int i2 = ip0; i2 < ip1; i2++) { // Cout - float * dst_data = (float *)((char *) dst->data + i2*nb2); - lm_ggml_fp16_t * wdata_kernel = wdata + i2*ne01*ne00*ne03; - for (int i11 = 0; i11 < ne11; i11++) { - for (int i10 = 0; i10 < ne10; i10++) { - const int i1n = i11*ne10*ne12 + i10*ne12; - for (int i01 = 0; i01 < ne01; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - float v = 0; - lm_ggml_vec_dot_f16(ne03, &v, 0, - wdata_src + i1n, 0, - wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1); - dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v; - } - } - } - } - } -} - -// lm_ggml_compute_forward_pool_1d_sk_p0 - -static void lm_ggml_compute_forward_pool_1d_sk_p0( - const struct lm_ggml_compute_params * params, - const enum lm_ggml_op_pool op, - const int k, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src = dst->src[0]; - - assert(src->type == LM_GGML_TYPE_F32 || src->type == LM_GGML_TYPE_F16); - - if (params->ith != 0) { - return; - } - - const char * cdata = (const char *)src->data; - const char * const data_end = cdata + lm_ggml_nbytes(src); - float * drow = (float *)dst->data; - - const int64_t rs = dst->ne[0]; - - while (cdata < data_end) { - const void * srow = (const void *)cdata; - int j = 0; - for (int64_t i = 0; i < rs; ++i) { - switch (op) { - case LM_GGML_OP_POOL_AVG: drow[i] = 0; break; - case LM_GGML_OP_POOL_MAX: drow[i] = -FLT_MAX; break; - case LM_GGML_OP_POOL_COUNT: LM_GGML_ABORT("fatal error"); - } - for (int ki = 0; ki < k; ++ki) { - const float srow_j = (src->type == LM_GGML_TYPE_F32) ? ((const float*)srow)[j] : LM_GGML_FP16_TO_FP32(((const lm_ggml_fp16_t*)srow)[j]); - switch (op) { - case LM_GGML_OP_POOL_AVG: drow[i] += srow_j; break; - case LM_GGML_OP_POOL_MAX: if (srow_j > drow[i]) drow[i] = srow_j; break; - case LM_GGML_OP_POOL_COUNT: LM_GGML_ABORT("fatal error"); - } - ++j; - } - switch (op) { - case LM_GGML_OP_POOL_AVG: drow[i] /= k; break; - case LM_GGML_OP_POOL_MAX: break; - case LM_GGML_OP_POOL_COUNT: LM_GGML_ABORT("fatal error"); - } - } - - cdata += src->nb[1]; - drow += rs; - } -} - -// lm_ggml_compute_forward_pool_1d - -static void lm_ggml_compute_forward_pool_1d( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const int32_t * opts = (const int32_t *)dst->op_params; - enum lm_ggml_op_pool op = opts[0]; - const int k0 = opts[1]; - const int s0 = opts[2]; - const int p0 = opts[3]; - LM_GGML_ASSERT(p0 == 0); // padding not supported - LM_GGML_ASSERT(k0 == s0); // only s = k supported - - lm_ggml_compute_forward_pool_1d_sk_p0(params, op, k0, dst); -} - -// lm_ggml_compute_forward_pool_2d - -static void lm_ggml_compute_forward_pool_2d( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src = dst->src[0]; - - assert(src->type == LM_GGML_TYPE_F32 || src->type == LM_GGML_TYPE_F16); - - if (params->ith != 0) { - return; - } - - const int32_t * opts = (const int32_t *)dst->op_params; - enum lm_ggml_op_pool op = opts[0]; - const int k0 = opts[1]; - const int k1 = opts[2]; - const int s0 = opts[3]; - const int s1 = opts[4]; - const int p0 = opts[5]; - const int p1 = opts[6]; - const char * cdata = (const char*)src->data; - const char * const data_end = cdata + lm_ggml_nbytes(src); - - const int64_t px = dst->ne[0]; - const int64_t py = dst->ne[1]; - const int64_t pa = px * py; - - float * dplane = (float *)dst->data; - - const int ka = k0 * k1; - const int offset0 = -p0; - const int offset1 = -p1; - - while (cdata < data_end) { - for (int oy = 0; oy < py; ++oy) { - float * const drow = dplane + oy * px; - for (int ox = 0; ox < px; ++ox) { - float * const out = drow + ox; - switch (op) { - case LM_GGML_OP_POOL_AVG: *out = 0; break; - case LM_GGML_OP_POOL_MAX: *out = -FLT_MAX; break; - case LM_GGML_OP_POOL_COUNT: LM_GGML_ABORT("fatal error"); - } - - const int ix = offset0 + ox * s0; - const int iy = offset1 + oy * s1; - - for (int ky = 0; ky < k1; ++ky) { - if (iy + ky < 0 || iy + ky >= src->ne[1]) continue; - const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky)); - for (int kx = 0; kx < k0; ++kx) { - int j = ix + kx; - if (j < 0 || j >= src->ne[0]) continue; - const float srow_j = (src->type == LM_GGML_TYPE_F32) ? ((const float*)srow)[j] : LM_GGML_FP16_TO_FP32(((const lm_ggml_fp16_t*)srow)[j]); - switch (op) { - case LM_GGML_OP_POOL_AVG: *out += srow_j; break; - case LM_GGML_OP_POOL_MAX: if (srow_j > *out) *out = srow_j; break; - case LM_GGML_OP_POOL_COUNT: LM_GGML_ABORT("fatal error"); - } - } - } - switch (op) { - case LM_GGML_OP_POOL_AVG: *out /= ka; break; - case LM_GGML_OP_POOL_MAX: break; - case LM_GGML_OP_POOL_COUNT: LM_GGML_ABORT("fatal error"); - } - } - } - - cdata += src->nb[2]; - dplane += pa; - } -} - -// lm_ggml_compute_forward_pool_2d_back - -static void lm_ggml_compute_forward_pool_2d_back( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src = dst->src[0]; - const struct lm_ggml_tensor * dstf = dst->src[1]; // forward tensor of dst - - assert(dst->type == LM_GGML_TYPE_F32 || dst->type == LM_GGML_TYPE_F16); - - if (params->ith != 0) { - return; - } - - const int32_t * opts = (const int32_t *)dst->op_params; - enum lm_ggml_op_pool op = opts[0]; - const int k0 = opts[1]; - const int k1 = opts[2]; - const int s0 = opts[3]; - const int s1 = opts[4]; - const int p0 = opts[5]; - const int p1 = opts[6]; - - char * cdata = (char *) dst->data; - const char * cdataf = (const char *) dstf->data; - const char * const data_end = cdata + lm_ggml_nbytes(dst); - - LM_GGML_ASSERT(params->ith == 0); - memset(cdata, 0, lm_ggml_nbytes(dst)); - - const int64_t px = src->ne[0]; - const int64_t py = src->ne[1]; - const int64_t pa = px * py; - - const float * splane = (const float *) src->data; - - const int ka = k0 * k1; - const int offset0 = -p0; - const int offset1 = -p1; - - while (cdata < data_end) { - for (int oy = 0; oy < py; ++oy) { - const float * const srow = splane + oy * px; - for (int ox = 0; ox < px; ++ox) { - const float grad0 = srow[ox]; - - const int ix = offset0 + ox * s0; - const int iy = offset1 + oy * s1; - - if (op == LM_GGML_OP_POOL_MAX) { - float maxval = -FLT_MAX; - int kxmax = -1; - int kymax = -1; - - for (int ky = 0; ky < k1; ++ky) { - if (iy + ky < 0 || iy + ky >= dst->ne[1]) { - continue; - } - const void * drowf = (const void *)(cdataf + dst->nb[1] * (iy + ky)); - for (int kx = 0; kx < k0; ++kx) { - int j = ix + kx; - if (j < 0 || j >= dst->ne[0]) { - continue; - } - - const float val = dst->type == LM_GGML_TYPE_F32 ? - ((const float *) drowf)[j] : LM_GGML_FP16_TO_FP32(((const lm_ggml_fp16_t *) drowf)[j]); - if (val <= maxval) { - continue; - } - - maxval = val; - kxmax = kx; - kymax = ky; - } - } - - if (kxmax == -1 || kymax == -1) { - continue; - } - - void * drow = (void *)(cdata + dst->nb[1] * (iy + kymax)); - const int j = ix + kxmax; - if (dst->type == LM_GGML_TYPE_F32) { - ((float *) drow)[j] += grad0; - } else { - ((lm_ggml_fp16_t *) drow)[j] = LM_GGML_FP32_TO_FP16(grad0 + LM_GGML_FP16_TO_FP32(((const lm_ggml_fp16_t *) drow)[j])); - } - } else if (op == LM_GGML_OP_POOL_AVG) { - const float grad = grad0 / ka; - - for (int ky = 0; ky < k1; ++ky) { - if (iy + ky < 0 || iy + ky >= dst->ne[1]) { - continue; - } - void * drow = (void *)(cdata + dst->nb[1] * (iy + ky)); - for (int kx = 0; kx < k0; ++kx) { - int j = ix + kx; - if (j < 0 || j >= dst->ne[0]) { - continue; - } - - if (dst->type == LM_GGML_TYPE_F32) { - ((float *) drow)[j] += grad; - } else { - ((lm_ggml_fp16_t *) drow)[j] += LM_GGML_FP32_TO_FP16(grad); - } - } - } - } else { - LM_GGML_ASSERT(false); - } - } - } - - cdata += dst->nb[2]; - cdataf += dst->nb[2]; - splane += pa; - } -} - -// lm_ggml_compute_forward_upscale - -static void lm_ggml_compute_forward_upscale_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32); - - const int ith = params->ith; - const int nth = params->nth; - - LM_GGML_TENSOR_UNARY_OP_LOCALS - - const float sf0 = (float)ne0/src0->ne[0]; - const float sf1 = (float)ne1/src0->ne[1]; - const float sf2 = (float)ne2/src0->ne[2]; - const float sf3 = (float)ne3/src0->ne[3]; - - // TODO: optimize - - for (int64_t i3 = 0; i3 < ne3; i3++) { - const int64_t i03 = i3 / sf3; - for (int64_t i2 = ith; i2 < ne2; i2 += nth) { - const int64_t i02 = i2 / sf2; - for (int64_t i1 = 0; i1 < ne1; i1++) { - const int64_t i01 = i1 / sf1; - for (int64_t i0 = 0; i0 < ne0; i0++) { - const int64_t i00 = i0 / sf0; - - const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3); - - *y = *x; - } - } - } - } -} - -static void lm_ggml_compute_forward_upscale( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_upscale_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - - -// lm_ggml_compute_forward_pad - -static void lm_ggml_compute_forward_pad_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); - LM_GGML_ASSERT( dst->nb[0] == sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - LM_GGML_TENSOR_UNARY_OP_LOCALS - - float * dst_ptr = (float *) dst->data; - - // TODO: optimize - - for (int64_t i2 = 0; i2 < ne2; ++i2) { - for (int64_t i1 = ith; i1 < ne1; i1 += nth) { - for (int64_t i0 = 0; i0 < ne0; ++i0) { - for (int64_t i3 = 0; i3 < ne3; ++i3) { - const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0; - - const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - - if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { - dst_ptr[dst_idx] = *src_ptr; - } else { - dst_ptr[dst_idx] = 0; - } - } - } - } - } -} - -static void lm_ggml_compute_forward_pad( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_pad_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - - -// lm_ggml_compute_forward_arange - -static void lm_ggml_compute_forward_arange_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - LM_GGML_ASSERT(dst->nb[0] == sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - const float start = lm_ggml_get_op_params_f32(dst, 0); - const float stop = lm_ggml_get_op_params_f32(dst, 1); - const float step = lm_ggml_get_op_params_f32(dst, 2); - - const int64_t steps = (int64_t) ceilf((stop - start) / step); - - LM_GGML_ASSERT(lm_ggml_nelements(dst) == steps); - - for (int64_t i = ith; i < steps; i+= nth) { - float value = start + step * i; - ((float *)dst->data)[i] = value; - } -} - -static void lm_ggml_compute_forward_arange( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - switch (dst->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_arange_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -static void lm_ggml_compute_forward_timestep_embedding_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - LM_GGML_TENSOR_UNARY_OP_LOCALS - - const int dim = lm_ggml_get_op_params_i32(dst, 0); - const int max_period = lm_ggml_get_op_params_i32(dst, 1); - - int half = dim / 2; - - for (int64_t i = 0; i < ne00; i++) { - float * embed_data = (float *)((char *) dst->data + i*nb1); - for (int64_t j = ith; j < half; j += nth) { - float timestep = ((float *)src0->data)[i]; - float freq = (float)expf(-logf(max_period) * j / half); - float arg = timestep * freq; - embed_data[j] = cosf(arg); - embed_data[j + half] = sinf(arg); - } - if (dim % 2 != 0 && ith == 0) { - embed_data[dim] = 0.f; - } - } -} - -static void lm_ggml_compute_forward_timestep_embedding( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_timestep_embedding_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_argsort - -static void lm_ggml_compute_forward_argsort_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - LM_GGML_TENSOR_UNARY_OP_LOCALS - - LM_GGML_ASSERT(nb0 == sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - const int64_t nr = lm_ggml_nrows(src0); - - enum lm_ggml_sort_order order = (enum lm_ggml_sort_order) lm_ggml_get_op_params_i32(dst, 0); - - for (int64_t i = ith; i < nr; i += nth) { - int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1); - const float * src_data = (float *)((char *) src0->data + i*nb01); - - for (int64_t j = 0; j < ne0; j++) { - dst_data[j] = j; - } - - // C doesn't have a functional sort, so we do a bubble sort instead - for (int64_t j = 0; j < ne0; j++) { - for (int64_t k = j + 1; k < ne0; k++) { - if ((order == LM_GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) || - (order == LM_GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) { - int32_t tmp = dst_data[j]; - dst_data[j] = dst_data[k]; - dst_data[k] = tmp; - } - } - } - } -} - -static void lm_ggml_compute_forward_argsort( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_argsort_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_flash_attn_ext - -static void lm_ggml_compute_forward_flash_attn_ext_f16( - const struct lm_ggml_compute_params * params, - const struct lm_ggml_tensor * q, - const struct lm_ggml_tensor * k, - const struct lm_ggml_tensor * v, - const struct lm_ggml_tensor * mask, - struct lm_ggml_tensor * dst) { - - LM_GGML_TENSOR_LOCALS(int64_t, neq, q, ne) - LM_GGML_TENSOR_LOCALS(size_t, nbq, q, nb) - LM_GGML_TENSOR_LOCALS(int64_t, nek, k, ne) - LM_GGML_TENSOR_LOCALS(size_t, nbk, k, nb) - LM_GGML_TENSOR_LOCALS(int64_t, nev, v, ne) - LM_GGML_TENSOR_LOCALS(size_t, nbv, v, nb) - LM_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) - LM_GGML_TENSOR_LOCALS(size_t, nb, dst, nb) - - const int ith = params->ith; - const int nth = params->nth; - - const int64_t D = neq0; - const int64_t N = neq1; - - LM_GGML_ASSERT(ne0 == D); - LM_GGML_ASSERT(ne2 == N); - - // input tensor rows must be contiguous - LM_GGML_ASSERT(nbq0 == lm_ggml_type_size(q->type)); - LM_GGML_ASSERT(nbk0 == lm_ggml_type_size(k->type)); - LM_GGML_ASSERT(nbv0 == lm_ggml_type_size(v->type)); - - LM_GGML_ASSERT(neq0 == D); - LM_GGML_ASSERT(nek0 == D); - LM_GGML_ASSERT(nev0 == D); - - LM_GGML_ASSERT(neq1 == N); - LM_GGML_ASSERT(nev0 == D); - - // dst cannot be transposed or permuted - LM_GGML_ASSERT(nb0 == sizeof(float)); - LM_GGML_ASSERT(nb0 <= nb1); - LM_GGML_ASSERT(nb1 <= nb2); - LM_GGML_ASSERT(nb2 <= nb3); - - // broadcast factors - const int64_t rk2 = neq2/nek2; - const int64_t rk3 = neq3/nek3; - - const int64_t rv2 = neq2/nev2; - const int64_t rv3 = neq3/nev3; - - // parallelize by q rows using lm_ggml_vec_dot_f32 - - // total rows in q - const int nr = neq1*neq2*neq3; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - float scale = 1.0f; - float max_bias = 0.0f; - float logit_softcap = 0.0f; - - memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); - memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); - memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float)); - - if (logit_softcap != 0) { - scale /= logit_softcap; - } - - const uint32_t n_head = neq2; - const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); - - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - - enum lm_ggml_type const k_vec_dot_type = type_traits[k->type].vec_dot_type; - lm_ggml_from_float_t const q_to_vec_dot = type_traits[k_vec_dot_type].from_float; - lm_ggml_vec_dot_t const kq_vec_dot = type_traits[k->type].vec_dot; - lm_ggml_to_float_t const v_to_float = type_traits[v->type].to_float; - - LM_GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type"); - LM_GGML_ASSERT(v_to_float && "fattn: unsupported V-type"); - - // loop over n_batch and n_head - for (int ir = ir0; ir < ir1; ++ir) { - // q indices - const int iq3 = ir/(neq2*neq1); - const int iq2 = (ir - iq3*neq2*neq1)/neq1; - const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); - - const uint32_t h = iq2; // head index - const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f; - - float S = 0.0f; // sum - float M = -INFINITY; // maximum KQ value - - float * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator - float * V32 = (VKQ32 + 1*D); // (temporary) FP32 V buffer - lm_ggml_fp16_t * VKQ16 = (lm_ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator - lm_ggml_fp16_t * Q_q = (lm_ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16 - - if (v->type == LM_GGML_TYPE_F16) { - memset(VKQ16, 0, D*sizeof(lm_ggml_fp16_t)); - } else { - memset(VKQ32, 0, D*sizeof(float)); - } - - const lm_ggml_fp16_t * mp = mask ? (lm_ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL; - - // k indices - const int ik3 = iq3 / rk3; - const int ik2 = iq2 / rk2; - - // v indices - const int iv3 = iq3 / rv3; - const int iv2 = iq2 / rv2; - - const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)); - q_to_vec_dot(pq, Q_q, D); - - // online softmax / attention - // loop over n_kv and n_head_kv - // ref: https://arxiv.org/pdf/2112.05682.pdf - for (int64_t ic = 0; ic < nek1; ++ic) { - const float mv = mp ? slope*LM_GGML_FP16_TO_FP32(mp[ic]) : 0.0f; - if (mv == -INFINITY) { - continue; - } - - float s; // KQ value - - const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3); - kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1); - - s = s*scale; // scale KQ value - - if (logit_softcap != 0.0f) { - s = logit_softcap*tanhf(s); - } - - s += mv; // apply mask - - const float Mold = M; - - float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value - float vs = 1.0f; // post-softmax KQ value, expf(s - M) - - const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3)); - - if (v->type == LM_GGML_TYPE_F16) { - if (s > M) { - // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f - M = s; - ms = expf(Mold - M); - - // V = V*expf(Mold - M) - lm_ggml_vec_scale_f16(D, VKQ16, ms); - } else { - // no new maximum, ms == 1.0f, vs != 1.0f - vs = expf(s - M); - } - - // V += v*expf(s - M) - lm_ggml_vec_mad_f16(D, VKQ16, (const lm_ggml_fp16_t *) v_data, vs); - } else { - if (s > M) { - // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f - M = s; - ms = expf(Mold - M); - - // V = V*expf(Mold - M) - lm_ggml_vec_scale_f32(D, VKQ32, ms); - } else { - // no new maximum, ms == 1.0f, vs != 1.0f - vs = expf(s - M); - } - - v_to_float(v_data, V32, D); - - // V += v*expf(s - M) - lm_ggml_vec_mad_f32(D, VKQ32, V32, vs); - } - - S = S*ms + vs; // scale and increment sum with partial sum - } - - if (v->type == LM_GGML_TYPE_F16) { - for (int64_t d = 0; d < D; ++d) { - VKQ32[d] = LM_GGML_FP16_TO_FP32(VKQ16[d]); - } - } - - // V /= S - const float S_inv = 1.0f/S; - lm_ggml_vec_scale_f32(D, VKQ32, S_inv); - - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; - - // original - //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float)); - - // permute(0, 2, 1, 3) - memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1); - } -} - -static void lm_ggml_compute_forward_flash_attn_ext( - const struct lm_ggml_compute_params * params, - const struct lm_ggml_tensor * q, - const struct lm_ggml_tensor * k, - const struct lm_ggml_tensor * v, - const struct lm_ggml_tensor * mask, - struct lm_ggml_tensor * dst) { - switch (dst->op_params[3]) { - case LM_GGML_PREC_DEFAULT: - case LM_GGML_PREC_F32: - { - // uses F32 accumulators - lm_ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_flash_attn_back - -static void lm_ggml_compute_forward_flash_attn_back_f32( - const struct lm_ggml_compute_params * params, - const bool masked, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * q = dst->src[0]; - const struct lm_ggml_tensor * k = dst->src[1]; - const struct lm_ggml_tensor * v = dst->src[2]; - const struct lm_ggml_tensor * d = dst->src[3]; - - LM_GGML_TENSOR_LOCALS(int64_t, neq, q, ne) - LM_GGML_TENSOR_LOCALS(size_t, nbq, q, nb) - LM_GGML_TENSOR_LOCALS(int64_t, nek, k, ne) - LM_GGML_TENSOR_LOCALS(size_t, nbk, k, nb) - LM_GGML_TENSOR_LOCALS(int64_t, nev, v, ne) - LM_GGML_TENSOR_LOCALS(size_t, nbv, v, nb) - LM_GGML_TENSOR_LOCALS(int64_t, ned, d, ne) - LM_GGML_TENSOR_LOCALS(size_t, nbd, d, nb) - LM_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) - LM_GGML_TENSOR_LOCALS(size_t, nb, dst, nb) - - const int ith = params->ith; - const int nth = params->nth; - - const int64_t D = neq0; - const int64_t N = neq1; - const int64_t P = nek1 - N; - const int64_t M = P + N; - - const int Mup = lm_ggml_up(M, LM_GGML_SOFT_MAX_UNROLL); - const int mxDM = MAX(D, Mup); - - // LM_GGML_ASSERT(ne0 == D); - // LM_GGML_ASSERT(ne1 == N); - LM_GGML_ASSERT(P >= 0); - - LM_GGML_ASSERT(nbq0 == sizeof(float)); - LM_GGML_ASSERT(nbk0 == sizeof(float)); - LM_GGML_ASSERT(nbv0 == sizeof(float)); - - LM_GGML_ASSERT(neq0 == D); - LM_GGML_ASSERT(nek0 == D); - LM_GGML_ASSERT(nev1 == D); - LM_GGML_ASSERT(ned0 == D); - - LM_GGML_ASSERT(neq1 == N); - LM_GGML_ASSERT(nek1 == N + P); - LM_GGML_ASSERT(nev1 == D); - LM_GGML_ASSERT(ned1 == N); - - // dst cannot be transposed or permuted - LM_GGML_ASSERT(nb0 == sizeof(float)); - LM_GGML_ASSERT(nb0 <= nb1); - LM_GGML_ASSERT(nb1 <= nb2); - LM_GGML_ASSERT(nb2 <= nb3); - - if (ith == 0) { - memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3); - } - lm_ggml_barrier(params->threadpool); - - const int64_t elem_q = lm_ggml_nelements(q); - const int64_t elem_k = lm_ggml_nelements(k); - - enum lm_ggml_type result_type = dst->type; - LM_GGML_ASSERT(lm_ggml_blck_size(result_type) == 1); - const size_t tsize = lm_ggml_type_size(result_type); - - const size_t offs_q = 0; - const size_t offs_k = offs_q + LM_GGML_PAD(elem_q * tsize, LM_GGML_MEM_ALIGN); - const size_t offs_v = offs_k + LM_GGML_PAD(elem_k * tsize, LM_GGML_MEM_ALIGN); - - void * grad_q = (char *) dst->data; - void * grad_k = (char *) dst->data + offs_k; - void * grad_v = (char *) dst->data + offs_v; - - const size_t nbgq1 = nb0*neq0; - const size_t nbgq2 = nb0*neq0*neq1; - const size_t nbgq3 = nb0*neq0*neq1*neq2; - - const size_t nbgk1 = nb0*nek0; - const size_t nbgk2 = nb0*nek0*nek1; - const size_t nbgk3 = nb0*nek0*nek1*neq2; - - const size_t nbgv1 = nb0*nev0; - const size_t nbgv2 = nb0*nev0*nev1; - const size_t nbgv3 = nb0*nev0*nev1*neq2; - - // parallelize by k rows using lm_ggml_vec_dot_f32 - - // total rows in k - const int nr = nek2*nek3; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - const float scale = 1.0f/sqrtf(D); - - //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); - - // how often k2 (and v2) is repeated in q2 - int nrep = neq2/nek2; - - for (int ir = ir0; ir < ir1; ++ir) { - // q indices - const int ik3 = ir/(nek2); - const int ik2 = ir - ik3*nek2; - - const int iq3 = ik3; - const int id3 = ik3; - const int iv3 = ik3; - const int iv2 = ik2; - - for (int irep = 0; irep < nrep; ++irep) { - const int iq2 = ik2 + irep*nek2; - const int id2 = iq2; - - // (ik2 + irep*nek2) % nek2 == ik2 - for (int iq1 = 0; iq1 < neq1; ++iq1) { - const int id1 = iq1; - - // not sure about CACHE_LINE_SIZE_F32.. - // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset? - float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32); - float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32); - - for (int i = M; i < Mup; ++i) { - S[i] = -INFINITY; - } - - const int64_t masked_begin = masked ? (P + iq1 + 1) : M; - for (int64_t ic = 0; ic < masked_begin; ++ic) { - // k indices - const int ik1 = ic; - - // S indices - const int i1 = ik1; - - lm_ggml_vec_dot_f32(neq0, - S + i1, 0, - (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0, - (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1); - } - - // scale - lm_ggml_vec_scale_f32(masked_begin, S, scale); - - for (int64_t i = masked_begin; i < M; i++) { - S[i] = -INFINITY; - } - - // softmax - // exclude known -INF S[..] values from max and loop - // dont forget to set their SM values to zero - { - float max = -INFINITY; - lm_ggml_vec_max_f32(masked_begin, &max, S); - - lm_ggml_float sum = 0.0; - { -#ifdef LM_GGML_SOFT_MAX_ACCELERATE - max = -max; - vDSP_vsadd(SM, 1, &max, SM, 1, Mup); - vvexpf(SM, SM, &Mup); - lm_ggml_vec_sum_f32(Mup, &sum, SM); -#else - sum = lm_ggml_vec_soft_max_f32(Mup, SM, S, max); -#endif - } - - assert(sum > 0.0); - - sum = 1.0/sum; - lm_ggml_vec_scale_f32(masked_begin, SM, sum); - - } - - // step-by-step explanation - { - // forward-process shape grads from backward process - // parallel_for ik2,ik3: - // for irep: - // iq2 = ik2 + irep*nek2 - // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,ik2,ik3] += grad[kcur] - // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur] - // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iv2,iv3] += grad[vcur] - // for iq1: - // kcur = k[:D,:M,ik2,ik3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur - // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur - // vcur = v[:M,:D,iv2,iv3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4 - // S0 = -Inf [D,1,1,1] - // ~S1[i] = dot(kcur[:D,i], qcur) - // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale - // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P) - // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) - // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur - // ~S5[i] = dot(vcur[:,i], S4) - // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,id1,id2,id3] - // ~dst[i,iq1,iq2,iq3] = S5[i] ^ - // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3] - // dst backward-/ grad[dst] = d - // - // output gradients with their dependencies: - // - // grad[kcur] = grad[S1].T @ qcur - // grad[S1] = diag_mask_zero(grad[S3], P) * scale - // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) - // grad[S4] = grad[S5] @ vcur - // grad[S4] = d[:D,id1,id2,id3] @ vcur - // grad[qcur] = grad[S1] @ kcur - // grad[vcur] = grad[S5].T @ S4 - // grad[vcur] = d[:D,id1,id2,id3].T @ S4 - // - // in post-order: - // - // S1 = qcur @ kcur.T - // S2 = S1 * scale - // S3 = diag_mask_inf(S2, P) - // S4 = softmax(S3) - // grad[S4] = d[:D,id1,id2,id3] @ vcur - // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) - // grad[S1] = diag_mask_zero(grad[S3], P) * scale - // grad[qcur] = grad[S1] @ kcur - // grad[kcur] = grad[S1].T @ qcur - // grad[vcur] = d[:D,id1,id2,id3].T @ S4 - // - // using less variables (SM=S4): - // - // S = diag_mask_inf(qcur @ kcur.T * scale, P) - // SM = softmax(S) - // S = d[:D,iq1,iq2,iq3] @ vcur - // dot_SM_gradSM = dot(SM, S) - // S = SM * (S - dot(SM, S)) - // S = diag_mask_zero(S, P) * scale - // - // grad[q][:D,iq1,iq2,iq3] += S @ kcur - // grad[k][:D,:M,ik2,ik3] += S.T @ qcur - // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM - } - - // S = gradSM = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3] - // S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3] - // for ic: - // S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3] - // exclude known future zero S[..] values from operation - lm_ggml_vec_set_f32(masked_begin, S, 0); - for (int64_t ic = 0; ic < D; ++ic) { - lm_ggml_vec_mad_f32(masked_begin, - S, - (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), - *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3))); - } - - // S = SM * (S - dot(SM, S)) - float dot_SM_gradSM = 0; - lm_ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, 0, SM, 0, S, 0, 1); - lm_ggml_vec_acc1_f32(M, S, -dot_SM_gradSM); - lm_ggml_vec_mul_f32 (masked_begin, S, S, SM); - - // S = diag_mask_zero(S, P) * scale - // already done by above lm_ggml_vec_set_f32 - - // exclude known zero S[..] values from operation - lm_ggml_vec_scale_f32(masked_begin, S, scale); - - // S shape [M,1] - // SM shape [M,1] - // kcur shape [D,M] - // qcur shape [D,1] - // vcur shape [M,D] - - // grad[q][:D,iq1,iq2,iq3] += S @ kcur - // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M] - // for ic: - // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3] - // exclude known zero S[..] values from loop - for (int64_t ic = 0; ic < masked_begin; ++ic) { - lm_ggml_vec_mad_f32(D, - (float *) ((char *) grad_q + (iq1*nbgq1 + iq2*nbgq2 + iq3*nbgq3)), - (float *) ((char *) k->data + (ic*nbk1 + ik2*nbk2 + ik3*nbk3)), - S[ic]); - } - - // grad[k][:D,:M,iq2,iq3] += S.T @ qcur - // for ic: - // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0] - // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0] - // exclude known zero S[..] values from loop - for (int64_t ic = 0; ic < masked_begin; ++ic) { - lm_ggml_vec_mad_f32(D, - (float *) ((char *) grad_k + (ic*nbgk1 + ik2*nbgk2 + ik3*nbgk3)), - (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), - S[ic]); - } - - // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM - // for ic: - // grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M] - // grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3] * SM[:M] - // exclude known zero SM[..] values from mad - for (int64_t ic = 0; ic < D; ++ic) { - lm_ggml_vec_mad_f32(masked_begin, - (float *) ((char *) grad_v + ( ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)), - SM, - *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3))); - } - } - } - } -} - -static void lm_ggml_compute_forward_flash_attn_back( - const struct lm_ggml_compute_params * params, - const bool masked, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * q = dst->src[0]; - - switch (q->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_flash_attn_back_f32(params, masked, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_ssm_conv - -static void lm_ggml_compute_forward_ssm_conv_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - const struct lm_ggml_tensor * src0 = dst->src[0]; // conv_x - const struct lm_ggml_tensor * src1 = dst->src[1]; // conv1d.weight - - const int ith = params->ith; - const int nth = params->nth; - - const int nc = src1->ne[0]; // d_conv - const int ncs = src0->ne[0]; // d_conv - 1 + n_t - const int nr = src0->ne[1]; // d_inner - const int n_t = dst->ne[1]; // tokens per sequence - const int n_s = dst->ne[2]; // number of sequences in the batch - - LM_GGML_ASSERT( dst->ne[0] == nr); - LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); - LM_GGML_ASSERT(src1->nb[0] == sizeof(float)); - LM_GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - const int ir = ir1 - ir0; - - for (int i3 = 0; i3 < n_s; ++i3) { - for (int i2 = 0; i2 < n_t; ++i2) { - // {d_conv - 1 + n_t, d_inner, n_seqs} - // sliding window - const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s} - const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner} - float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s} - - // TODO: transpose the output for smaller strides for big batches? - // d_inner - for (int i1 = 0; i1 < ir; ++i1) { - // rowwise dot product - // NOTE: not using lm_ggml_vec_dot_f32, because its sum is in double precision - float sumf = 0.0f; - - // d_conv - for (int i0 = 0; i0 < nc; ++i0) { - sumf += s[i0 + i1*ncs] * c[i0 + i1*nc]; - } - x[i1] = sumf; - } - } - } -} - -static void lm_ggml_compute_forward_ssm_conv( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - switch (dst->src[0]->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_ssm_conv_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_ssm_scan - -static void lm_ggml_compute_forward_ssm_scan_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - const struct lm_ggml_tensor * src0 = dst->src[0]; // s - const struct lm_ggml_tensor * src1 = dst->src[1]; // x - const struct lm_ggml_tensor * src2 = dst->src[2]; // dt - const struct lm_ggml_tensor * src3 = dst->src[3]; // A - const struct lm_ggml_tensor * src4 = dst->src[4]; // B - const struct lm_ggml_tensor * src5 = dst->src[5]; // C - - const int ith = params->ith; - const int nth = params->nth; - - const int64_t nc = src0->ne[0]; // d_state - const int64_t nr = src0->ne[1]; // d_inner - const int64_t n_t = src1->ne[1]; // number of tokens per sequence - const int64_t n_s = src0->ne[2]; // number of sequences in the batch - - LM_GGML_ASSERT(lm_ggml_nelements(src1) + lm_ggml_nelements(src0) == lm_ggml_nelements(dst)); - LM_GGML_ASSERT(src0->nb[0] == sizeof(float)); - LM_GGML_ASSERT(src1->nb[0] == sizeof(float)); - LM_GGML_ASSERT(src2->nb[0] == sizeof(float)); - LM_GGML_ASSERT(src3->nb[0] == sizeof(float)); - LM_GGML_ASSERT(src4->nb[0] == sizeof(float)); - LM_GGML_ASSERT(src5->nb[0] == sizeof(float)); - // required for the dot product between s and C - LM_GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); - // required for per-sequence offsets for states - LM_GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float)); - // required to get correct offset for state destination (i.e. src1->nb[3]) - LM_GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - const int ir = ir1 - ir0; - - for (int i3 = 0; i3 < n_s; ++i3) { - for (int i2 = 0; i2 < n_t; ++i2) { - const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s} - const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s} - const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} - const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} - const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s} - float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s} - - // use the output as the source for the next token-wise iterations - if (i2 > 0) { s0 = s; } - - // d_inner - for (int i1 = 0; i1 < ir; ++i1) { - // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78 - float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; - float x_dt = x[i1] * dt_soft_plus; - float sumf = 0.0f; - // d_state - for (int i0 = 0; i0 < nc; ++i0) { - int i = i0 + i1*nc; - // state = prev_state * dA + dB * x - float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); - // y = rowwise_dotprod(state, C) - sumf += state * C[i0]; - s[i] = state; - } - y[i1] = sumf; - } - } - } -} - -static void lm_ggml_compute_forward_ssm_scan( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - switch (dst->src[0]->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_ssm_scan_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_win_part - -static void lm_ggml_compute_forward_win_part_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - UNUSED(params); - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - LM_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) - LM_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) - - const int32_t nep0 = ((const int32_t *)(dst->op_params))[0]; - const int32_t nep1 = ((const int32_t *)(dst->op_params))[1]; - const int32_t w = ((const int32_t *)(dst->op_params))[2]; - - assert(ne00 == ne0); - assert(ne3 == nep0*nep1); - - // TODO: optimize / multi-thread - for (int py = 0; py < nep1; ++py) { - for (int px = 0; px < nep0; ++px) { - const int64_t i3 = py*nep0 + px; - for (int64_t i2 = 0; i2 < ne2; ++i2) { - for (int64_t i1 = 0; i1 < ne1; ++i1) { - for (int64_t i0 = 0; i0 < ne0; ++i0) { - const int64_t i02 = py*w + i2; - const int64_t i01 = px*w + i1; - const int64_t i00 = i0; - - const int64_t i = i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + i0; - const int64_t j = i02*ne01*ne00 + i01*ne00 + i00; - - if (py*w + i2 >= ne02 || px*w + i1 >= ne01) { - ((float *) dst->data)[i] = 0.0f; - } else { - ((float *) dst->data)[i] = ((float *) src0->data)[j]; - } - } - } - } - } - } -} - -static void lm_ggml_compute_forward_win_part( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_win_part_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_win_unpart - -static void lm_ggml_compute_forward_win_unpart_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - UNUSED(params); - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - LM_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) - LM_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) - - const int32_t w = ((const int32_t *)(dst->op_params))[0]; - - // padding - const int px = (w - ne1%w)%w; - //const int py = (w - ne2%w)%w; - - const int npx = (px + ne1)/w; - //const int npy = (py + ne2)/w; - - assert(ne0 == ne00); - - // TODO: optimize / multi-thread - for (int64_t i2 = 0; i2 < ne2; ++i2) { - for (int64_t i1 = 0; i1 < ne1; ++i1) { - for (int64_t i0 = 0; i0 < ne0; ++i0) { - const int ip2 = i2/w; - const int ip1 = i1/w; - - const int64_t i02 = i2%w; - const int64_t i01 = i1%w; - const int64_t i00 = i0; - - const int64_t i = (ip2*npx + ip1)*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00 + i00; - const int64_t j = i2*ne1*ne0 + i1*ne0 + i0; - - ((float *) dst->data)[j] = ((float *) src0->data)[i]; - } - } - } -} - -static void lm_ggml_compute_forward_win_unpart( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_win_unpart_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -//gmml_compute_forward_unary - -static void lm_ggml_compute_forward_unary( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const enum lm_ggml_unary_op op = lm_ggml_get_unary_op(dst); - - switch (op) { - case LM_GGML_UNARY_OP_ABS: - { - lm_ggml_compute_forward_abs(params, dst); - } break; - case LM_GGML_UNARY_OP_SGN: - { - lm_ggml_compute_forward_sgn(params, dst); - } break; - case LM_GGML_UNARY_OP_NEG: - { - lm_ggml_compute_forward_neg(params, dst); - } break; - case LM_GGML_UNARY_OP_STEP: - { - lm_ggml_compute_forward_step(params, dst); - } break; - case LM_GGML_UNARY_OP_TANH: - { - lm_ggml_compute_forward_tanh(params, dst); - } break; - case LM_GGML_UNARY_OP_ELU: - { - lm_ggml_compute_forward_elu(params, dst); - } break; - case LM_GGML_UNARY_OP_RELU: - { - lm_ggml_compute_forward_relu(params, dst); - } break; - case LM_GGML_UNARY_OP_SIGMOID: - { - lm_ggml_compute_forward_sigmoid(params, dst); - } break; - case LM_GGML_UNARY_OP_GELU: - { - lm_ggml_compute_forward_gelu(params, dst); - } break; - case LM_GGML_UNARY_OP_GELU_QUICK: - { - lm_ggml_compute_forward_gelu_quick(params, dst); - } break; - case LM_GGML_UNARY_OP_SILU: - { - lm_ggml_compute_forward_silu(params, dst); - } break; - case LM_GGML_UNARY_OP_HARDSWISH: - { - lm_ggml_compute_forward_hardswish(params, dst); - } break; - case LM_GGML_UNARY_OP_HARDSIGMOID: - { - lm_ggml_compute_forward_hardsigmoid(params, dst); - } break; - case LM_GGML_UNARY_OP_EXP: - { - lm_ggml_compute_forward_exp(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_get_rel_pos - -static void lm_ggml_compute_forward_get_rel_pos_f16( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - UNUSED(params); - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322 - - LM_GGML_TENSOR_UNARY_OP_LOCALS - - const int64_t w = ne1; - - lm_ggml_fp16_t * src0_data = (lm_ggml_fp16_t *) src0->data; - lm_ggml_fp16_t * dst_data = (lm_ggml_fp16_t *) dst->data; - - for (int64_t i2 = 0; i2 < ne2; ++i2) { - for (int64_t i1 = 0; i1 < ne1; ++i1) { - const int64_t pos = (w - i1 - 1) + i2; - for (int64_t i0 = 0; i0 < ne0; ++i0) { - dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0]; - } - } - } -} - -static void lm_ggml_compute_forward_get_rel_pos( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F16: - case LM_GGML_TYPE_BF16: - { - lm_ggml_compute_forward_get_rel_pos_f16(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_add_rel_pos - -static void lm_ggml_compute_forward_add_rel_pos_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - const struct lm_ggml_tensor * src2 = dst->src[2]; - - const bool inplace = (bool) ((int32_t *) dst->op_params)[0]; - if (!inplace) { - if (params->ith == 0) { - memcpy((char *) dst->data, (char *) src0->data, lm_ggml_nbytes(dst)); - } - lm_ggml_barrier(params->threadpool); - } - // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L357-L359 - - float * src1_data = (float *) src1->data; - float * src2_data = (float *) src2->data; - float * dst_data = (float *) dst->data; - - const int64_t ne10 = src1->ne[0]; - const int64_t ne11 = src1->ne[1]; - const int64_t ne12 = src1->ne[2]; - const int64_t ne13 = src1->ne[3]; - - const int ith = params->ith; - const int nth = params->nth; - - // total patches in dst - const int np = ne13; - - // patches per thread - const int dp = (np + nth - 1)/nth; - - // patch range for this thread - const int ip0 = dp*ith; - const int ip1 = MIN(ip0 + dp, np); - - for (int64_t i13 = ip0; i13 < ip1; ++i13) { - for (int64_t i12 = 0; i12 < ne12; ++i12) { - for (int64_t i11 = 0; i11 < ne11; ++i11) { - const int64_t jp1 = i13*ne12*ne11*ne10 + i12*ne11*ne10 + i11*ne10; - for (int64_t i10 = 0; i10 < ne10; ++i10) { - const int64_t jp0 = jp1 + i10; - const float src1_e = src1_data[jp0]; - const float src2_e = src2_data[jp0]; - - const int64_t jdh = jp0 * ne10; - const int64_t jdw = jdh - (ne10 - 1) * i10; - - for (int64_t j = 0; j < ne10; ++j) { - dst_data[jdh + j ] += src2_e; - dst_data[jdw + j*ne10] += src1_e; - } - } - } - } - } -} - -static void lm_ggml_compute_forward_add_rel_pos( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_add_rel_pos_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_rwkv_wkv - -static void lm_ggml_compute_forward_rwkv_wkv_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - const size_t T = dst->src[1]->ne[3]; - const size_t C = dst->ne[0]; - const size_t H = dst->src[1]->ne[2]; - const size_t n_seqs = dst->src[5]->ne[1]; - - float * dst_data = (float *) dst->data; - float * state = ((float *) dst->data) + C * T; - - if (params->ith != 0) { - return; - } - - memset(dst_data, 0, T * C * sizeof(float)); - - float * k = (float *) dst->src[0]->data; - float * v = (float *) dst->src[1]->data; - float * r = (float *) dst->src[2]->data; - float * time_faaaa = (float *) dst->src[3]->data; - float * time_decay = (float *) dst->src[4]->data; - - size_t t_stride = H * (C / H); - - size_t h_stride = C / H; - size_t h_stride_2d = (C / H) * (C / H); - - // basically fused operations: - // dst = r @ (time_faaaa * (k @ v) + state), - // state = time_decay * state + (k @ v), - // recursive through each token - for (size_t t = 0; t < T; t++) { - size_t t_offset = t * t_stride; - size_t state_offset = (C / H) * C * (t / (T / n_seqs)); - float * state_cur = state + state_offset; - float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset; - - for (size_t h = 0; h < H; h++) { - size_t h_offset = h * h_stride; - size_t t_h_offset = t_offset + h_offset; - size_t h_2d_offset = h * h_stride_2d; - - for (size_t i = 0; i < C / H; i++) { - size_t t_h_i_offset = t_h_offset + i; - size_t h_i_offset = h_offset + i; - size_t h_2d_i_offset = h_2d_offset + i * h_stride; - - float k_val = k[t_h_i_offset]; - float r_val = r[t_h_i_offset]; - float time_faaaa_val = time_faaaa[h_i_offset]; - // RWKV v6: different time_decay for each token. - float time_decay_val = time_decay[t_h_i_offset]; - - for (size_t j = 0; j < C / H; j ++) { - size_t t_h_j_offset = t_h_offset + j; - size_t h_2d_i_j_offset = h_2d_i_offset + j; - - float v_val = v[t_h_j_offset]; - float kv_val = v_val * k_val; - float prev_state_val = state_prev[h_2d_i_j_offset]; - float temp_val = kv_val * time_faaaa_val + prev_state_val; - dst_data[t_h_j_offset] += temp_val * r_val; - state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val; - } - } - } - } -} - -static void lm_ggml_compute_forward_rwkv_wkv( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_rwkv_wkv_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_map_unary - -static void lm_ggml_compute_forward_map_unary_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst, - const lm_ggml_unary_op_f32_t fun) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(lm_ggml_is_contiguous_1(src0)); - assert(lm_ggml_is_contiguous_1(dst)); - assert(lm_ggml_are_same_shape(src0, dst)); - - const int n = lm_ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - fun(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void lm_ggml_compute_forward_map_unary( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst, - const lm_ggml_unary_op_f32_t fun) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_map_unary_f32(params, dst, fun); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_map_binary - -static void lm_ggml_compute_forward_map_binary_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst, - const lm_ggml_binary_op_f32_t fun) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - if (params->ith != 0) { - return; - } - - assert(lm_ggml_is_contiguous_1(src0)); - assert(lm_ggml_is_contiguous_1(src1)); - assert(lm_ggml_is_contiguous_1(dst)); - assert(lm_ggml_are_same_shape(src0, src1) && lm_ggml_are_same_shape(src0, dst)); - - const int n = lm_ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - fun(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1])), - (float *) ((char *) src1->data + i*(src1->nb[1]))); - } -} - -static void lm_ggml_compute_forward_map_binary( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst, - const lm_ggml_binary_op_f32_t fun) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_map_binary_f32(params, dst, fun); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_map_custom1 - -static void lm_ggml_compute_forward_map_custom1_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst, - const lm_ggml_custom1_op_f32_t fun) { - - const struct lm_ggml_tensor * a = dst->src[0]; - - if (params->ith != 0) { - return; - } - - fun(dst, a); -} - -// lm_ggml_compute_forward_map_custom2 - -static void lm_ggml_compute_forward_map_custom2_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst, - const lm_ggml_custom2_op_f32_t fun) { - - const struct lm_ggml_tensor * a = dst->src[0]; - const struct lm_ggml_tensor * b = dst->src[1]; - - if (params->ith != 0) { - return; - } - - fun(dst, a, b); -} - -// lm_ggml_compute_forward_map_custom3 - -static void lm_ggml_compute_forward_map_custom3_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst, - const lm_ggml_custom3_op_f32_t fun) { - - const struct lm_ggml_tensor * a = dst->src[0]; - const struct lm_ggml_tensor * b = dst->src[1]; - const struct lm_ggml_tensor * c = dst->src[1]; - - if (params->ith != 0) { - return; - } - - fun(dst, a, b, c); -} - -// lm_ggml_compute_forward_map_custom1 - -static void lm_ggml_compute_forward_map_custom1( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * a = dst->src[0]; - - struct lm_ggml_map_custom1_op_params p; - memcpy(&p, dst->op_params, sizeof(p)); - - p.fun(dst, a, params->ith, params->nth, p.userdata); -} - -// lm_ggml_compute_forward_map_custom2 - -static void lm_ggml_compute_forward_map_custom2( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * a = dst->src[0]; - const struct lm_ggml_tensor * b = dst->src[1]; - - struct lm_ggml_map_custom2_op_params p; - memcpy(&p, dst->op_params, sizeof(p)); - - p.fun(dst, a, b, params->ith, params->nth, p.userdata); -} - -// lm_ggml_compute_forward_map_custom3 - -static void lm_ggml_compute_forward_map_custom3( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * a = dst->src[0]; - const struct lm_ggml_tensor * b = dst->src[1]; - const struct lm_ggml_tensor * c = dst->src[2]; - - struct lm_ggml_map_custom3_op_params p; - memcpy(&p, dst->op_params, sizeof(p)); - - p.fun(dst, a, b, c, params->ith, params->nth, p.userdata); -} - -// lm_ggml_compute_forward_cross_entropy_loss - -static void lm_ggml_compute_forward_cross_entropy_loss_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - - LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32); - LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); - LM_GGML_ASSERT(src0->nb[0] == lm_ggml_type_size(src0->type)); - LM_GGML_ASSERT(src1->nb[0] == lm_ggml_type_size(src1->type)); - LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, src1)); - LM_GGML_ASSERT(lm_ggml_is_scalar(dst)); - LM_GGML_ASSERT(dst->type == LM_GGML_TYPE_F32); - - // TODO: handle transposed/permuted matrices - const int64_t nc = src0->ne[0]; - const int64_t nr = lm_ggml_nrows(src0); - - const int ith = params->ith; - const int nth = params->nth; - - float * sums = (float *) params->wdata; - float * st = ((float *) params->wdata) + nth + ith*nc; - float sum_thread = 0.0f; - - LM_GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc)); - - // rows per thread - const int64_t dr = (nr + nth - 1)/nth; - - // row range for this thread - const int64_t ir0 = dr*ith; - const int64_t ir1 = MIN(ir0 + dr, nr); - - for (int64_t i1 = ir0; i1 < ir1; ++i1) { - const float * s0 = (const float *)((const char *) src0->data + i1*src0->nb[1]); - const float * s1 = (const float *)((const char *) src1->data + i1*src1->nb[1]); - -#ifndef NDEBUG - for (int64_t i = 0; i < nc; ++i) { - //printf("p[%d] = %f\n", i, p[i]); - assert(!isnan(s0[i])); - assert(!isnan(s1[i])); - } -#endif - - float max = -INFINITY; - lm_ggml_vec_max_f32(nc, &max, s0); - const lm_ggml_float sum_softmax = lm_ggml_vec_log_soft_max_f32(nc, st, s0, max); - assert(sum_softmax >= 0.0); - - lm_ggml_vec_add1_f32(nc, st, st, -sum_softmax); - lm_ggml_vec_mul_f32(nc, st, st, s1); - - float sum_st = 0.0f; - lm_ggml_vec_sum_f32(nc, &sum_st, st); - sum_thread += sum_st; - -#ifndef NDEBUG - for (int64_t i = 0; i < nc; ++i) { - assert(!isnan(st[i])); - assert(!isinf(st[i])); - } -#endif - } - sums[ith] = sum_thread; - lm_ggml_barrier(params->threadpool); - - if (ith == 0) { - float * dp = (float *) dst->data; - lm_ggml_vec_sum_f32(nth, dp, sums); - dp[0] *= -1.0f / (float) nr; - } -} - -static void lm_ggml_compute_forward_cross_entropy_loss( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_cross_entropy_loss_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -// lm_ggml_compute_forward_cross_entropy_loss_back - -static void lm_ggml_compute_forward_cross_entropy_loss_back_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src1 = dst->src[1]; - const struct lm_ggml_tensor * opt0 = dst->src[2]; - - LM_GGML_ASSERT(lm_ggml_is_contiguous(dst)); - LM_GGML_ASSERT(lm_ggml_is_contiguous(src0)); - LM_GGML_ASSERT(lm_ggml_is_contiguous(src1)); - LM_GGML_ASSERT(lm_ggml_is_contiguous(opt0)); - LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, src1) && lm_ggml_are_same_shape(src0, dst)); - - const int64_t ith = params->ith; - const int64_t nth = params->nth; - - // TODO: handle transposed/permuted matrices - const int64_t nc = src0->ne[0]; - const int64_t nr = lm_ggml_nrows(src0); - - // rows per thread - const int64_t dr = (nr + nth - 1)/nth; - - // row range for this thread - const int64_t ir0 = dr*ith; - const int64_t ir1 = MIN(ir0 + dr, nr); - - const float d_by_nr = ((const float *) opt0->data)[0] / (float) nr; - - for (int64_t i1 = ir0; i1 < ir1; i1++) { - float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]); - float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]); - float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]); - -#ifndef NDEBUG - for (int64_t i = 0; i < nc; ++i) { - //printf("p[%d] = %f\n", i, p[i]); - assert(!isnan(s0[i])); - assert(!isnan(s1[i])); - } -#endif - - // soft_max - float max = -INFINITY; - lm_ggml_vec_max_f32(nc, &max, s0); - lm_ggml_float sum = lm_ggml_vec_soft_max_f32(nc, ds0, s0, max); - assert(sum > 0.0); - lm_ggml_vec_scale_f32(nc, ds0, 1.0/sum); - - // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr - lm_ggml_vec_sub_f32(nc, ds0, ds0, s1); - lm_ggml_vec_scale_f32(nc, ds0, d_by_nr); - -#ifndef NDEBUG - for (int64_t i = 0; i < nc; ++i) { - assert(!isnan(ds0[i])); - assert(!isinf(ds0[i])); - } -#endif - } -} - -static void lm_ggml_compute_forward_cross_entropy_loss_back( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_cross_entropy_loss_back_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -static void lm_ggml_compute_forward_opt_step_adamw_f32( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src0_grad = dst->src[1]; - const struct lm_ggml_tensor * src0_grad_m = dst->src[2]; - const struct lm_ggml_tensor * src0_grad_v = dst->src[3]; - LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, src0_grad)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = lm_ggml_nrows(src0); - - LM_GGML_TENSOR_UNARY_OP_LOCALS - LM_GGML_ASSERT(nb00 == sizeof(float)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - /* const float gnorm = 1.0f; */ - int64_t iter; memcpy(&iter, &dst->op_params[0], sizeof(int64_t)); - const float alpha = lm_ggml_get_op_params_f32(dst, 2); - const float beta1 = lm_ggml_get_op_params_f32(dst, 3); - const float beta2 = lm_ggml_get_op_params_f32(dst, 4); - const float eps = lm_ggml_get_op_params_f32(dst, 5); - const float wd = lm_ggml_get_op_params_f32(dst, 6); - - const float beta1h = alpha/(1.0f - powf(beta1, iter)); - const float beta2h = 1.0f/(1.0f - powf(beta2, iter)); - - for (int ir = ir0; ir < ir1; ++ir) { - const int64_t i03 = ir/(ne02*ne01); - const int64_t i02 = (ir - i03*ne02*ne01)/ne01; - const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const size_t offset = i03*nb03 + i02*nb02 + i01*nb01; - - float * w = (float *) ((char *) src0->data + offset); // weight - const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad - float * m = (float *) ((char *) src0_grad_m->data + offset); - float * v = (float *) ((char *) src0_grad_v->data + offset); - - for (int i00 = 0; i00 < ne00; ++i00) { - m[i00] = m[i00]*beta1 + g[i00]*(1.0f - beta1); - v[i00] = v[i00]*beta2 + g[i00]*g[i00]*(1.0f - beta2); - - const float mh = m[i00]*beta1h; - const float vh = sqrtf(v[i00]*beta2h) + eps; - - // The weight decay is applied independently of the Adam momenta m and v. - // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss. - // See: https://arxiv.org/pdf/1711.05101v3.pdf - w[i00] = w[i00]*(1.0f - alpha*wd) - mh/vh; - } - } - - lm_ggml_barrier(params->threadpool); - if (ith != 0) { - return; - } - - iter++; - memcpy(&dst->op_params[0], &iter, sizeof(int64_t)); -} - -static void lm_ggml_compute_forward_opt_step_adamw( - const struct lm_ggml_compute_params * params, - struct lm_ggml_tensor * dst) { - - const struct lm_ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case LM_GGML_TYPE_F32: - { - lm_ggml_compute_forward_opt_step_adamw_f32(params, dst); - } break; - default: - { - LM_GGML_ABORT("fatal error"); - } - } -} -///////////////////////////////// - -static void lm_ggml_compute_forward(struct lm_ggml_compute_params * params, struct lm_ggml_tensor * tensor) { - LM_GGML_ASSERT(params); - - if (tensor->op == LM_GGML_OP_NONE || lm_ggml_is_empty(tensor)) { - return; - } - - switch (tensor->op) { - case LM_GGML_OP_DUP: - { - lm_ggml_compute_forward_dup(params, tensor); - } break; - case LM_GGML_OP_ADD: - { - lm_ggml_compute_forward_add(params, tensor); - } break; - case LM_GGML_OP_ADD1: - { - lm_ggml_compute_forward_add1(params, tensor); - } break; - case LM_GGML_OP_ACC: - { - lm_ggml_compute_forward_acc(params, tensor); - } break; - case LM_GGML_OP_SUB: - { - lm_ggml_compute_forward_sub(params, tensor); - } break; - case LM_GGML_OP_MUL: - { - lm_ggml_compute_forward_mul(params, tensor); - } break; - case LM_GGML_OP_DIV: - { - lm_ggml_compute_forward_div(params, tensor); - } break; - case LM_GGML_OP_SQR: - { - lm_ggml_compute_forward_sqr(params, tensor); - } break; - case LM_GGML_OP_SQRT: - { - lm_ggml_compute_forward_sqrt(params, tensor); - } break; - case LM_GGML_OP_LOG: - { - lm_ggml_compute_forward_log(params, tensor); - } break; - case LM_GGML_OP_SIN: - { - lm_ggml_compute_forward_sin(params, tensor); - } break; - case LM_GGML_OP_COS: - { - lm_ggml_compute_forward_cos(params, tensor); - } break; - case LM_GGML_OP_SUM: - { - lm_ggml_compute_forward_sum(params, tensor); - } break; - case LM_GGML_OP_SUM_ROWS: - { - lm_ggml_compute_forward_sum_rows(params, tensor); - } break; - case LM_GGML_OP_MEAN: - { - lm_ggml_compute_forward_mean(params, tensor); - } break; - case LM_GGML_OP_ARGMAX: - { - lm_ggml_compute_forward_argmax(params, tensor); - } break; - case LM_GGML_OP_COUNT_EQUAL: - { - lm_ggml_compute_forward_count_equal(params, tensor); - } break; - case LM_GGML_OP_REPEAT: - { - lm_ggml_compute_forward_repeat(params, tensor); - } break; - case LM_GGML_OP_REPEAT_BACK: - { - lm_ggml_compute_forward_repeat_back(params, tensor); - } break; - case LM_GGML_OP_CONCAT: - { - lm_ggml_compute_forward_concat(params, tensor); - } break; - case LM_GGML_OP_SILU_BACK: - { - lm_ggml_compute_forward_silu_back(params, tensor); - } break; - case LM_GGML_OP_NORM: - { - lm_ggml_compute_forward_norm(params, tensor); - } break; - case LM_GGML_OP_RMS_NORM: - { - lm_ggml_compute_forward_rms_norm(params, tensor); - } break; - case LM_GGML_OP_RMS_NORM_BACK: - { - lm_ggml_compute_forward_rms_norm_back(params, tensor); - } break; - case LM_GGML_OP_GROUP_NORM: - { - lm_ggml_compute_forward_group_norm(params, tensor); - } break; - case LM_GGML_OP_MUL_MAT: - { - lm_ggml_compute_forward_mul_mat(params, tensor); - } break; - case LM_GGML_OP_MUL_MAT_ID: - { - lm_ggml_compute_forward_mul_mat_id(params, tensor); - } break; - case LM_GGML_OP_OUT_PROD: - { - lm_ggml_compute_forward_out_prod(params, tensor); - } break; - case LM_GGML_OP_SCALE: - { - lm_ggml_compute_forward_scale(params, tensor); - } break; - case LM_GGML_OP_SET: - { - lm_ggml_compute_forward_set(params, tensor); - } break; - case LM_GGML_OP_CPY: - { - lm_ggml_compute_forward_cpy(params, tensor); - } break; - case LM_GGML_OP_CONT: - { - lm_ggml_compute_forward_cont(params, tensor); - } break; - case LM_GGML_OP_RESHAPE: - { - lm_ggml_compute_forward_reshape(params, tensor); - } break; - case LM_GGML_OP_VIEW: - { - lm_ggml_compute_forward_view(params, tensor); - } break; - case LM_GGML_OP_PERMUTE: - { - lm_ggml_compute_forward_permute(params, tensor); - } break; - case LM_GGML_OP_TRANSPOSE: - { - lm_ggml_compute_forward_transpose(params, tensor); - } break; - case LM_GGML_OP_GET_ROWS: - { - lm_ggml_compute_forward_get_rows(params, tensor); - } break; - case LM_GGML_OP_GET_ROWS_BACK: - { - lm_ggml_compute_forward_get_rows_back(params, tensor); - } break; - case LM_GGML_OP_DIAG: - { - lm_ggml_compute_forward_diag(params, tensor); - } break; - case LM_GGML_OP_DIAG_MASK_INF: - { - lm_ggml_compute_forward_diag_mask_inf(params, tensor); - } break; - case LM_GGML_OP_DIAG_MASK_ZERO: - { - lm_ggml_compute_forward_diag_mask_zero(params, tensor); - } break; - case LM_GGML_OP_SOFT_MAX: - { - lm_ggml_compute_forward_soft_max(params, tensor); - } break; - case LM_GGML_OP_SOFT_MAX_BACK: - { - lm_ggml_compute_forward_soft_max_back(params, tensor); - } break; - case LM_GGML_OP_ROPE: - { - lm_ggml_compute_forward_rope(params, tensor); - } break; - case LM_GGML_OP_ROPE_BACK: - { - lm_ggml_compute_forward_rope_back(params, tensor); - } break; - case LM_GGML_OP_CLAMP: - { - lm_ggml_compute_forward_clamp(params, tensor); - } break; - case LM_GGML_OP_CONV_TRANSPOSE_1D: - { - lm_ggml_compute_forward_conv_transpose_1d(params, tensor); - } break; - case LM_GGML_OP_IM2COL: - { - lm_ggml_compute_forward_im2col(params, tensor); - } break; - case LM_GGML_OP_IM2COL_BACK: - { - lm_ggml_compute_forward_im2col_back_f32(params, tensor); - } break; - case LM_GGML_OP_CONV_TRANSPOSE_2D: - { - lm_ggml_compute_forward_conv_transpose_2d(params, tensor); - } break; - case LM_GGML_OP_POOL_1D: - { - lm_ggml_compute_forward_pool_1d(params, tensor); - } break; - case LM_GGML_OP_POOL_2D: - { - lm_ggml_compute_forward_pool_2d(params, tensor); - } break; - case LM_GGML_OP_POOL_2D_BACK: - { - lm_ggml_compute_forward_pool_2d_back(params, tensor); - } break; - case LM_GGML_OP_UPSCALE: - { - lm_ggml_compute_forward_upscale(params, tensor); - } break; - case LM_GGML_OP_PAD: - { - lm_ggml_compute_forward_pad(params, tensor); - } break; - case LM_GGML_OP_ARANGE: - { - lm_ggml_compute_forward_arange(params, tensor); - } break; - case LM_GGML_OP_TIMESTEP_EMBEDDING: - { - lm_ggml_compute_forward_timestep_embedding(params, tensor); - } break; - case LM_GGML_OP_ARGSORT: - { - lm_ggml_compute_forward_argsort(params, tensor); - } break; - case LM_GGML_OP_LEAKY_RELU: - { - lm_ggml_compute_forward_leaky_relu(params, tensor); - } break; - case LM_GGML_OP_FLASH_ATTN_EXT: - { - lm_ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor); - } break; - case LM_GGML_OP_FLASH_ATTN_BACK: - { - int32_t t = lm_ggml_get_op_params_i32(tensor, 0); - LM_GGML_ASSERT(t == 0 || t == 1); - bool masked = t != 0; - lm_ggml_compute_forward_flash_attn_back(params, masked, tensor); - } break; - case LM_GGML_OP_SSM_CONV: - { - lm_ggml_compute_forward_ssm_conv(params, tensor); - } break; - case LM_GGML_OP_SSM_SCAN: - { - lm_ggml_compute_forward_ssm_scan(params, tensor); - } break; - case LM_GGML_OP_WIN_PART: - { - lm_ggml_compute_forward_win_part(params, tensor); - } break; - case LM_GGML_OP_WIN_UNPART: - { - lm_ggml_compute_forward_win_unpart(params, tensor); - } break; - case LM_GGML_OP_UNARY: - { - lm_ggml_compute_forward_unary(params, tensor); - } break; - case LM_GGML_OP_GET_REL_POS: - { - lm_ggml_compute_forward_get_rel_pos(params, tensor); - } break; - case LM_GGML_OP_ADD_REL_POS: - { - lm_ggml_compute_forward_add_rel_pos(params, tensor); - } break; - case LM_GGML_OP_RWKV_WKV: - { - lm_ggml_compute_forward_rwkv_wkv(params, tensor); - } break; - case LM_GGML_OP_MAP_UNARY: - { - lm_ggml_unary_op_f32_t fun; - memcpy(&fun, tensor->op_params, sizeof(fun)); - lm_ggml_compute_forward_map_unary(params, tensor, fun); - } - break; - case LM_GGML_OP_MAP_BINARY: - { - lm_ggml_binary_op_f32_t fun; - memcpy(&fun, tensor->op_params, sizeof(fun)); - lm_ggml_compute_forward_map_binary(params, tensor, fun); - } - break; - case LM_GGML_OP_MAP_CUSTOM1_F32: - { - lm_ggml_custom1_op_f32_t fun; - memcpy(&fun, tensor->op_params, sizeof(fun)); - lm_ggml_compute_forward_map_custom1_f32(params, tensor, fun); - } - break; - case LM_GGML_OP_MAP_CUSTOM2_F32: - { - lm_ggml_custom2_op_f32_t fun; - memcpy(&fun, tensor->op_params, sizeof(fun)); - lm_ggml_compute_forward_map_custom2_f32(params, tensor, fun); - } - break; - case LM_GGML_OP_MAP_CUSTOM3_F32: - { - lm_ggml_custom3_op_f32_t fun; - memcpy(&fun, tensor->op_params, sizeof(fun)); - lm_ggml_compute_forward_map_custom3_f32(params, tensor, fun); - } - break; - case LM_GGML_OP_MAP_CUSTOM1: - { - lm_ggml_compute_forward_map_custom1(params, tensor); - } - break; - case LM_GGML_OP_MAP_CUSTOM2: - { - lm_ggml_compute_forward_map_custom2(params, tensor); - } - break; - case LM_GGML_OP_MAP_CUSTOM3: - { - lm_ggml_compute_forward_map_custom3(params, tensor); - } - break; - case LM_GGML_OP_CROSS_ENTROPY_LOSS: - { - lm_ggml_compute_forward_cross_entropy_loss(params, tensor); - } - break; - case LM_GGML_OP_CROSS_ENTROPY_LOSS_BACK: - { - lm_ggml_compute_forward_cross_entropy_loss_back(params, tensor); - } - break; - case LM_GGML_OP_OPT_STEP_ADAMW: - { - lm_ggml_compute_forward_opt_step_adamw(params, tensor); - } - break; - case LM_GGML_OP_NONE: - { - // nop - } break; - case LM_GGML_OP_COUNT: - { - LM_GGML_ABORT("fatal error"); - } - } -} - -//////////////////////////////////////////////////////////////////////////////// - -struct lm_ggml_hash_set lm_ggml_hash_set_new(size_t size) { - size = lm_ggml_hash_size(size); - struct lm_ggml_hash_set result; - result.size = size; - result.keys = LM_GGML_MALLOC(sizeof(struct lm_ggml_tensor *) * size); - result.used = LM_GGML_CALLOC(lm_ggml_bitset_size(size), sizeof(lm_ggml_bitset_t)); - return result; -} - -void lm_ggml_hash_set_reset(struct lm_ggml_hash_set * hash_set) { - memset(hash_set->used, 0, sizeof(lm_ggml_bitset_t) * lm_ggml_bitset_size(hash_set->size)); -} - -void lm_ggml_hash_set_free(struct lm_ggml_hash_set * hash_set) { - LM_GGML_FREE(hash_set->used); - LM_GGML_FREE(hash_set->keys); -} - -size_t lm_ggml_hash_size(size_t min_sz) { - // next primes after powers of two - static const size_t primes[] = { - 2, 3, 5, 11, 17, 37, 67, 131, 257, 521, 1031, - 2053, 4099, 8209, 16411, 32771, 65537, 131101, - 262147, 524309, 1048583, 2097169, 4194319, 8388617, - 16777259, 33554467, 67108879, 134217757, 268435459, - 536870923, 1073741827, 2147483659 - }; - static const size_t n_primes = sizeof(primes)/sizeof(primes[0]); - - // find the smallest prime that is larger or equal than min_sz - size_t l = 0; - size_t r = n_primes; - while (l < r) { - size_t m = (l + r)/2; - if (primes[m] < min_sz) { - l = m + 1; - } else { - r = m; - } - } - size_t sz = l < n_primes ? primes[l] : min_sz | 1; - return sz; -} - -struct hash_map { - struct lm_ggml_hash_set set; - struct lm_ggml_tensor ** vals; -}; - -static struct hash_map * lm_ggml_new_hash_map(size_t size) { - struct hash_map * result = LM_GGML_MALLOC(sizeof(struct hash_map)); - result->set = lm_ggml_hash_set_new(size); - result->vals = LM_GGML_CALLOC(result->set.size, sizeof(struct lm_ggml_tensor *)); - return result; -} - -static void lm_ggml_hash_map_free(struct hash_map * map) { - lm_ggml_hash_set_free(&map->set); - LM_GGML_FREE(map->vals); - LM_GGML_FREE(map); -} - -// gradient checkpointing - -static struct lm_ggml_tensor * lm_ggml_recompute_graph_node( - struct lm_ggml_context * ctx, - struct lm_ggml_cgraph * graph, - struct hash_map * replacements, - struct lm_ggml_tensor * node) { - - if (node == NULL) { - return NULL; - } - - if (node->flags & LM_GGML_TENSOR_FLAG_PARAM) { - return node; - } - - if (!lm_ggml_hash_contains(&graph->visited_hash_set, node)) { - return node; - } - - int count_children = 0; - for (int k = 0; k < LM_GGML_MAX_SRC; ++k) { - if (node->src[k]) { - ++count_children; - } - } - - if (count_children == 0) { - return node; - } - - size_t i = lm_ggml_hash_find(&replacements->set, node); - LM_GGML_ASSERT(i != LM_GGML_HASHSET_FULL); // assert that not full - if (replacements->set.keys[i] == node) { - return replacements->vals[i]; - } - - struct lm_ggml_tensor * clone = lm_ggml_new_tensor(ctx, node->type, LM_GGML_MAX_DIMS, node->ne); - - // insert clone into replacements - LM_GGML_ASSERT(replacements->set.keys[i] == NULL); // assert that we don't overwrite - replacements->set.keys[i] = node; - replacements->vals[i] = clone; - - clone->op = node->op; - clone->grad = node->grad; - clone->flags = node->flags; - clone->extra = node->extra; - for (int k = 0; k < LM_GGML_MAX_DIMS; ++k) { - clone->nb[k] = node->nb[k]; - } - for (int k = 0; k < LM_GGML_MAX_SRC; ++k) { - clone->src[k] = lm_ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]); - } - if (node->view_src != NULL) { - clone->data = (node->view_src->data == NULL) - ? NULL // view_src not yet allocated - : (char *) node->view_src->data // view_src already allocated - + node->view_offs; - clone->view_src = node->view_src; - clone->view_offs = node->view_offs; - } - - LM_GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (LM_GGML_MAX_OP_PARAMS / sizeof(int32_t))); - LM_GGML_ASSERT(sizeof(node->name) == LM_GGML_MAX_NAME); - memcpy(clone->op_params, node->op_params, sizeof(node->op_params)); - lm_ggml_format_name(clone, "%s (clone)", lm_ggml_get_name(node)); - - return clone; -} - -void lm_ggml_build_backward_gradient_checkpointing( - struct lm_ggml_context * ctx, - struct lm_ggml_cgraph * gf, - struct lm_ggml_cgraph * gb, - struct lm_ggml_cgraph * gb_tmp, - struct lm_ggml_tensor * * checkpoints, - int n_checkpoints) { - lm_ggml_graph_cpy(gf, gb_tmp); - lm_ggml_build_backward_expand(ctx, gf, gb_tmp, false); - - if (n_checkpoints <= 0) { - lm_ggml_graph_cpy(gb_tmp, gb); - return; - } - - struct hash_map * replacements = lm_ggml_new_hash_map(gf->n_nodes + gf->n_leafs + n_checkpoints); - - // insert checkpoints in replacements - for (int i = 0; i < n_checkpoints; ++i) { - size_t k = lm_ggml_hash_find(&replacements->set, checkpoints[i]); - LM_GGML_ASSERT(k != LM_GGML_HASHSET_FULL); // assert that not full - LM_GGML_ASSERT(replacements->set.keys[k] == NULL); // assert that we don't overwrite - replacements->set.keys[k] = checkpoints[i]; - replacements->vals[k] = checkpoints[i]; - } - - lm_ggml_graph_cpy(gf, gb); - // rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes], - // replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]), - // by recomputing them from checkpoints - for (int i = gf->n_nodes; in_nodes; ++i) { - struct lm_ggml_tensor * node = gb_tmp->nodes[i]; - for (int k = 0; k < LM_GGML_MAX_SRC; ++k) { - // insert new tensors recomputing src, reusing already made replacements, - // remember replacements: remember new tensors with mapping from corresponding gf nodes - // recurse for input tensors, - // unless (i.e. terminating when) input tensors are replacements (like checkpoints) - node->src[k] = lm_ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]); - } - // insert rewritten backward node with replacements made into resulting backward graph gb - lm_ggml_build_forward_expand(gb, node); - } - - lm_ggml_hash_map_free(replacements); -} - -// utility functions to change gradients -// if a is in acc_table, modify gradients in-place and mark result as gradient accumulator -// else if a is in zero_table, replace a -// else, just add/subtract/etc. the gradients - -static struct lm_ggml_tensor * lm_ggml_add_or_set( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - struct lm_ggml_hash_set * zero_table, - struct lm_ggml_hash_set * acc_table) { - if (lm_ggml_hash_contains(acc_table, a)) { - struct lm_ggml_tensor * ret = lm_ggml_add_impl(ctx, a, b, true); - const size_t insert_result = lm_ggml_hash_insert(acc_table, ret); - LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL); - LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS); - return ret; - } - if (lm_ggml_hash_contains(zero_table, a)) { - return b; - } - return lm_ggml_add_impl(ctx, a, b, false); -} - -static struct lm_ggml_tensor * lm_ggml_acc_or_set( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - const size_t nb1, - const size_t nb2, - const size_t nb3, - const size_t offset, - struct lm_ggml_hash_set * zero_table, - struct lm_ggml_hash_set * acc_table) { - if (lm_ggml_hash_contains(acc_table, a)) { - struct lm_ggml_tensor * ret = lm_ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true); - const size_t insert_result = lm_ggml_hash_insert(acc_table, ret); - LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL); - LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS); - return ret; - } - if (lm_ggml_hash_contains(zero_table, a)) { - struct lm_ggml_tensor * a_zero = lm_ggml_scale(ctx, a, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN - return lm_ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false); - } - return lm_ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false); -} - -static struct lm_ggml_tensor * lm_ggml_add1_or_set( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - struct lm_ggml_hash_set * zero_table, - struct lm_ggml_hash_set * acc_table) { - if (lm_ggml_hash_contains(acc_table, a)) { - struct lm_ggml_tensor * ret = lm_ggml_add1_impl(ctx, a, b, true); - const size_t insert_result = lm_ggml_hash_insert(acc_table, ret); - LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL); - LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS); - return ret; - } - if (lm_ggml_hash_contains(zero_table, a)) { - return lm_ggml_repeat(ctx, b, a); - } - return lm_ggml_add1_impl(ctx, a, b, false); -} - -static struct lm_ggml_tensor * lm_ggml_sub_or_set( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - struct lm_ggml_hash_set * zero_table, - struct lm_ggml_hash_set * acc_table) { - if (lm_ggml_hash_contains(acc_table, a)) { - struct lm_ggml_tensor * ret = lm_ggml_sub_impl(ctx, a, b, true); - const size_t insert_result = lm_ggml_hash_insert(acc_table, ret); - LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL); - LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS); - return ret; - } - if (lm_ggml_hash_contains(zero_table, a)) { - return lm_ggml_neg(ctx, b); - } - return lm_ggml_sub_impl(ctx, a, b, false); -} - -static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggml_tensor * tensor, struct lm_ggml_hash_set * zero_table, struct lm_ggml_hash_set * acc_table) { - struct lm_ggml_tensor * src0 = tensor->src[0]; - struct lm_ggml_tensor * src1 = tensor->src[1]; - struct lm_ggml_tensor * src2 = tensor->src[2]; - - switch (tensor->op) { - case LM_GGML_OP_DUP: - { - if (src0->grad) { - src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); - } - } break; - case LM_GGML_OP_ADD: - { - if (src0->grad) { - src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); - } - if (src1->grad) { - if (lm_ggml_are_same_shape(src0, src1)) { - src1->grad = lm_ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table, acc_table); - } else { - src1->grad = lm_ggml_add_or_set(ctx, src1->grad, lm_ggml_repeat_back(ctx, tensor->grad, src1), zero_table, acc_table); - } - } - } break; - case LM_GGML_OP_ADD1: - { - if (src0->grad) { - src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); - } - if (src1->grad) { - src1->grad = lm_ggml_add_or_set(ctx, - src1->grad, - lm_ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean - zero_table, acc_table); - } - } break; - case LM_GGML_OP_ACC: - { - if (src0->grad) { - src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); - } - if (src1->grad) { - const size_t nb1 = ((int32_t *) tensor->op_params)[0]; - const size_t nb2 = ((int32_t *) tensor->op_params)[1]; - const size_t nb3 = ((int32_t *) tensor->op_params)[2]; - const size_t offset = ((int32_t *) tensor->op_params)[3]; - - struct lm_ggml_tensor * tensor_grad_view = lm_ggml_view_4d(ctx, - tensor->grad, - src1->grad->ne[0], - src1->grad->ne[1], - src1->grad->ne[2], - src1->grad->ne[3], - nb1, nb2, nb3, offset); - - src1->grad = - lm_ggml_add_or_set(ctx, - src1->grad, - lm_ggml_reshape(ctx, - lm_ggml_cont(ctx, tensor_grad_view), - src1->grad), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_SUB: - { - if (src0->grad) { - src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); - } - if (src1->grad) { - src1->grad = lm_ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table, acc_table); - } - } break; - case LM_GGML_OP_MUL: - { - if (src0->grad) { - src0->grad = - lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_mul(ctx, src1, tensor->grad), - zero_table, acc_table); - } - if (src1->grad) { - src1->grad = - lm_ggml_add_or_set(ctx, - src1->grad, - lm_ggml_mul(ctx, src0, tensor->grad), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_DIV: - { - if (src0->grad) { - src0->grad = - lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_div(ctx, tensor->grad, src1), - zero_table, acc_table); - } - if (src1->grad) { - src1->grad = - lm_ggml_sub_or_set(ctx, - src1->grad, - lm_ggml_mul(ctx, - tensor->grad, - lm_ggml_div(ctx, tensor, src1)), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_SQR: - { - if (src0->grad) { - src0->grad = - lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_scale(ctx, - lm_ggml_mul(ctx, src0, tensor->grad), - 2.0f), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_SQRT: - { - if (src0->grad) { - src0->grad = - lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_scale(ctx, - lm_ggml_div(ctx, - tensor->grad, - tensor), - 0.5f), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_LOG: - { - if (src0->grad) { - src0->grad = - lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_div(ctx, - tensor->grad, - src0), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_SIN: - { - if (src0->grad) { - src0->grad = - lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_mul(ctx, - tensor->grad, - lm_ggml_cos(ctx, src0)), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_COS: - { - if (src0->grad) { - src0->grad = - lm_ggml_sub_or_set(ctx, - src0->grad, - lm_ggml_mul(ctx, - tensor->grad, - lm_ggml_sin(ctx, src0)), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_SUM: - { - if (src0->grad) { - src0->grad = - lm_ggml_add1_or_set(ctx, - src0->grad, - tensor->grad, - zero_table, acc_table); - } - } break; - case LM_GGML_OP_SUM_ROWS: - { - if (src0->grad) { - src0->grad = - lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_repeat(ctx, - tensor->grad, - src0->grad), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_MEAN: - case LM_GGML_OP_ARGMAX: - case LM_GGML_OP_COUNT_EQUAL: - { - LM_GGML_ABORT("fatal error"); // TODO: implement - } - case LM_GGML_OP_REPEAT: - { - // necessary for llama - if (src0->grad) { - src0->grad = lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_repeat_back(ctx, tensor->grad, src0->grad), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_REPEAT_BACK: - { - if (src0->grad) { - // TODO: test this - src0->grad = lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_repeat(ctx, tensor->grad, src0->grad), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_CONCAT: - { - LM_GGML_ABORT("fatal error"); // TODO: implement - } - case LM_GGML_OP_SILU_BACK: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented - } - case LM_GGML_OP_NORM: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented - } - case LM_GGML_OP_RMS_NORM: - { - // necessary for llama - if (src0->grad) { - float eps; - memcpy(&eps, tensor->op_params, sizeof(float)); - - src0->grad = lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_rms_norm_back(ctx, src0, tensor->grad, eps), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_RMS_NORM_BACK: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented - } - case LM_GGML_OP_GROUP_NORM: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented - } - case LM_GGML_OP_MUL_MAT: - { - // https://cs231n.github.io/optimization-2/#staged - // # forward pass - // s0 = np.random.randn(5, 10) - // s1 = np.random.randn(10, 3) - // t = s0.dot(s1) - - // # now suppose we had the gradient on t from above in the circuit - // dt = np.random.randn(*t.shape) # same shape as t - // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix - // ds1 = t.T.dot(dt) - - // tensor.shape [m,p,qq,rr] - // src0.shape [n,m,q1,r1] - // src1.shape [n,p,qq,rr] - - // necessary for llama - if (src0->grad) { - struct lm_ggml_tensor * s1_tg = - lm_ggml_out_prod(ctx, // [n,m,qq,rr] - src1, // [n,p,qq,rr] - tensor->grad); // [m,p,qq,rr] - const int64_t qq = s1_tg->ne[2]; - const int64_t rr = s1_tg->ne[3]; - const int64_t q1 = src0->ne[2]; - const int64_t r1 = src0->ne[3]; - const bool ne2_broadcasted = qq > q1; - const bool ne3_broadcasted = rr > r1; - if (ne2_broadcasted || ne3_broadcasted) { - // sum broadcast repetitions of s1_tg into shape of src0 - s1_tg = lm_ggml_repeat_back(ctx, s1_tg, src0); - } - src0->grad = - lm_ggml_add_or_set(ctx, - src0->grad, // [n,m,q1,r1] - s1_tg, // [n,m,q1,r1] - zero_table, acc_table); - } - if (src1->grad) { - src1->grad = - lm_ggml_add_or_set(ctx, - src1->grad, // [n,p,qq,rr] - // lm_ggml_mul_mat(ctx, // [n,p,qq,rr] - // lm_ggml_cont(ctx, // [m,n,q1,r1] - // lm_ggml_transpose(ctx, src0)), // [m,n,q1,r1] - // tensor->grad), // [m,p,qq,rr] - - // // when src0 is bigger than tensor->grad (this is mostly the case in llama), - // // avoid transpose of src0, rather transpose smaller tensor->grad - // // and then use lm_ggml_out_prod - lm_ggml_out_prod(ctx, // [n,p,qq,rr] - src0, // [n,m,q1,r1] - lm_ggml_transpose(ctx, // [p,m,qq,rr] - tensor->grad)), // [m,p,qq,rr] - zero_table, acc_table); - } - } break; - case LM_GGML_OP_MUL_MAT_ID: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented - } - case LM_GGML_OP_OUT_PROD: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented - } - case LM_GGML_OP_SCALE: - { - // necessary for llama - if (src0->grad) { - float s; - memcpy(&s, tensor->op_params, sizeof(float)); - - src0->grad = - lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_scale_impl(ctx, tensor->grad, s, false), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_SET: - { - const size_t nb1 = ((int32_t *) tensor->op_params)[0]; - const size_t nb2 = ((int32_t *) tensor->op_params)[1]; - const size_t nb3 = ((int32_t *) tensor->op_params)[2]; - const size_t offset = ((int32_t *) tensor->op_params)[3]; - - struct lm_ggml_tensor * tensor_grad_view = NULL; - - if (src0->grad || src1->grad) { - LM_GGML_ASSERT(src0->type == tensor->type); - LM_GGML_ASSERT(tensor->grad->type == tensor->type); - LM_GGML_ASSERT(!src1->grad || src1->grad->type == tensor->grad->type); - - tensor_grad_view = lm_ggml_view_4d(ctx, - tensor->grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], - nb1, nb2, nb3, offset); - } - - if (src0->grad) { - src0->grad = lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_acc_impl(ctx, - tensor->grad, - lm_ggml_neg(ctx, tensor_grad_view), - nb1, nb2, nb3, offset, false), - zero_table, acc_table); - } - - if (src1->grad) { - src1->grad = - lm_ggml_add_or_set(ctx, - src1->grad, - lm_ggml_reshape(ctx, - lm_ggml_cont(ctx, tensor_grad_view), - src1->grad), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_CPY: - { - // necessary for llama - // cpy overwrites value of src1 by src0 and returns view(src1) - // the overwriting is mathematically equivalent to: - // tensor = src0 * 1 + src1 * 0 - if (src0->grad) { - // dsrc0 = dtensor * 1 - src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); - } - if (src1->grad) { - // dsrc1 = dtensor * 0 -> noop - } - } break; - case LM_GGML_OP_CONT: - { - // same as cpy - if (src0->grad) { - LM_GGML_ASSERT(lm_ggml_is_contiguous(src0->grad)); - LM_GGML_ASSERT(lm_ggml_is_contiguous(tensor->grad)); - src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); - } - } break; - case LM_GGML_OP_RESHAPE: - { - // necessary for llama - if (src0->grad) { - src0->grad = - lm_ggml_add_or_set(ctx, src0->grad, - lm_ggml_reshape(ctx, - lm_ggml_is_contiguous(tensor->grad) - ? tensor->grad - : lm_ggml_cont(ctx, tensor->grad), - src0->grad), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_VIEW: - { - // necessary for llama - if (src0->grad) { - size_t offset; - - memcpy(&offset, tensor->op_params, sizeof(offset)); - - size_t nb1 = tensor->nb[1]; - size_t nb2 = tensor->nb[2]; - size_t nb3 = tensor->nb[3]; - - if (src0->type != src0->grad->type) { - // gradient is typically F32, but src0 could be other type - size_t ng = lm_ggml_element_size(src0->grad); - size_t n0 = lm_ggml_element_size(src0); - LM_GGML_ASSERT(offset % n0 == 0); - LM_GGML_ASSERT(nb1 % n0 == 0); - LM_GGML_ASSERT(nb2 % n0 == 0); - LM_GGML_ASSERT(nb3 % n0 == 0); - offset = (offset / n0) * ng; - nb1 = (nb1 / n0) * ng; - nb2 = (nb2 / n0) * ng; - nb3 = (nb3 / n0) * ng; - } - - src0->grad = lm_ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table, acc_table); - } - } break; - case LM_GGML_OP_PERMUTE: - { - // necessary for llama - if (src0->grad) { - int32_t * axes = (int32_t *) tensor->op_params; - int axis0 = axes[0] & 0x3; - int axis1 = axes[1] & 0x3; - int axis2 = axes[2] & 0x3; - int axis3 = axes[3] & 0x3; - int axes_backward[4] = {0,0,0,0}; - axes_backward[axis0] = 0; - axes_backward[axis1] = 1; - axes_backward[axis2] = 2; - axes_backward[axis3] = 3; - src0->grad = - lm_ggml_add_or_set(ctx, src0->grad, - lm_ggml_permute(ctx, - tensor->grad, - axes_backward[0], - axes_backward[1], - axes_backward[2], - axes_backward[3]), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_TRANSPOSE: - { - // necessary for llama - if (src0->grad) { - src0->grad = - lm_ggml_add_or_set(ctx, src0->grad, - lm_ggml_transpose(ctx, tensor->grad), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_GET_ROWS: - { - // necessary for llama (only for tokenizer) - if (src0->grad) { - src0->grad = - lm_ggml_add_or_set(ctx, src0->grad, - // last lm_ggml_get_rows_back argument src0->grad is only - // necessary to setup correct output shape - lm_ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad), - zero_table, acc_table); - } - if (src1->grad) { - // noop - } - } break; - case LM_GGML_OP_GET_ROWS_BACK: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented - } - case LM_GGML_OP_DIAG: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented - } - case LM_GGML_OP_DIAG_MASK_INF: - { - // necessary for llama - if (src0->grad) { - const int n_past = ((int32_t *) tensor->op_params)[0]; - src0->grad = - lm_ggml_add_or_set(ctx, src0->grad, - /* lm_ggml_diag_mask_inf_impl() shouldn't be here */ - /* ref: https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */ - lm_ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_DIAG_MASK_ZERO: - { - // necessary for llama - if (src0->grad) { - const int n_past = ((int32_t *) tensor->op_params)[0]; - src0->grad = - lm_ggml_add_or_set(ctx, src0->grad, - lm_ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_SOFT_MAX: - { - // necessary for llama - if (src0->grad) { - src0->grad = - lm_ggml_add_or_set(ctx, src0->grad, - lm_ggml_soft_max_back(ctx, tensor->grad, tensor), - zero_table, acc_table); - } - LM_GGML_ASSERT((!src1 || !src1->grad) && "backward pass for softmax mask not implemented"); - } break; - case LM_GGML_OP_SOFT_MAX_BACK: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented - } - case LM_GGML_OP_ROPE: - { - // necessary for llama - if (src0->grad) { - //const int n_past = ((int32_t *) tensor->op_params)[0]; - const int n_dims = ((int32_t *) tensor->op_params)[1]; - const int mode = ((int32_t *) tensor->op_params)[2]; - //const int n_ctx = ((int32_t *) tensor->op_params)[3]; - const int n_ctx_orig = ((int32_t *) tensor->op_params)[4]; - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; - - memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float)); - memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float)); - memcpy(&ext_factor, (int32_t *) tensor->op_params + 7, sizeof(float)); - memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float)); - memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float)); - memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float)); - - src0->grad = lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_rope_back(ctx, - tensor->grad, - src1, - src2, - n_dims, - mode, - n_ctx_orig, - freq_base, - freq_scale, - ext_factor, - attn_factor, - beta_fast, - beta_slow), - zero_table, acc_table); - } - LM_GGML_ASSERT((!src2 || !src2->grad) && "gradients for freq factors not implemented"); - } break; - case LM_GGML_OP_ROPE_BACK: - { - if (src0->grad) { - //const int n_past = ((int32_t *) tensor->op_params)[0]; - const int n_dims = ((int32_t *) tensor->op_params)[1]; - const int mode = ((int32_t *) tensor->op_params)[2]; - //const int n_ctx = ((int32_t *) tensor->op_params)[3]; - const int n_ctx_orig = ((int32_t *) tensor->op_params)[4]; - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; - - memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float)); - memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float)); - memcpy(&ext_factor, (int32_t *) tensor->op_params + 7, sizeof(float)); - memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float)); - memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float)); - memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float)); - - src0->grad = lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_rope_impl(ctx, - tensor->grad, - src1, - src2, - n_dims, - mode, - n_ctx_orig, - freq_base, - freq_scale, - ext_factor, - attn_factor, - beta_fast, - beta_slow, - false), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_CLAMP: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented - } - case LM_GGML_OP_CONV_TRANSPOSE_1D: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented - } - case LM_GGML_OP_IM2COL: - { - if (src1->grad) { - const int32_t s0 = lm_ggml_get_op_params_i32(tensor, 0); - const int32_t s1 = lm_ggml_get_op_params_i32(tensor, 1); - const int32_t p0 = lm_ggml_get_op_params_i32(tensor, 2); - const int32_t p1 = lm_ggml_get_op_params_i32(tensor, 3); - const int32_t d0 = lm_ggml_get_op_params_i32(tensor, 4); - const int32_t d1 = lm_ggml_get_op_params_i32(tensor, 5); - const bool is_2D = lm_ggml_get_op_params_i32(tensor, 6) == 1; - - src1->grad = lm_ggml_add_or_set(ctx, - src1->grad, - lm_ggml_im2col_back(ctx, src0, tensor->grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_IM2COL_BACK: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented - } - case LM_GGML_OP_CONV_TRANSPOSE_2D: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented - } - case LM_GGML_OP_POOL_1D: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented - } - case LM_GGML_OP_POOL_2D: - { - if (src0->grad) { - const enum lm_ggml_op_pool op = lm_ggml_get_op_params_i32(tensor, 0); - const int32_t k0 = lm_ggml_get_op_params_i32(tensor, 1); - const int32_t k1 = lm_ggml_get_op_params_i32(tensor, 2); - const int32_t s0 = lm_ggml_get_op_params_i32(tensor, 3); - const int32_t s1 = lm_ggml_get_op_params_i32(tensor, 4); - const int32_t p0 = lm_ggml_get_op_params_i32(tensor, 5); - const int32_t p1 = lm_ggml_get_op_params_i32(tensor, 6); - - src0->grad = lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_pool_2d_back(ctx, tensor->grad, src0, op, k0, k1, s0, s1, p0, p1), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_POOL_2D_BACK: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented - } - case LM_GGML_OP_UPSCALE: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented - } - case LM_GGML_OP_PAD: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented - } - case LM_GGML_OP_ARANGE: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented - } - case LM_GGML_OP_TIMESTEP_EMBEDDING: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented - } - case LM_GGML_OP_ARGSORT: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented - } - case LM_GGML_OP_LEAKY_RELU: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented - } - case LM_GGML_OP_FLASH_ATTN_EXT: - { - LM_GGML_ABORT("FA backward pass not adapted after rework"); - struct lm_ggml_tensor * flash_grad = NULL; - if (src0->grad || src1->grad || tensor->src[2]->grad) { - int32_t t = lm_ggml_get_op_params_i32(tensor, 0); - LM_GGML_ASSERT(t == 0 || t == 1); - bool masked = t != 0; - flash_grad = - lm_ggml_flash_attn_back(ctx, - src0, - src1, - tensor->src[2], - tensor->grad, - masked); - } - - const int64_t elem_q = lm_ggml_nelements(src0); - const int64_t elem_k = lm_ggml_nelements(src1); - const int64_t elem_v = lm_ggml_nelements(src2); - - enum lm_ggml_type result_type = flash_grad->type; - LM_GGML_ASSERT(lm_ggml_blck_size(result_type) == 1); - const size_t tsize = lm_ggml_type_size(result_type); - - const size_t offs_q = 0; - const size_t offs_k = offs_q + LM_GGML_PAD(elem_q * tsize, LM_GGML_MEM_ALIGN); - const size_t offs_v = offs_k + LM_GGML_PAD(elem_k * tsize, LM_GGML_MEM_ALIGN); - - if (src0->grad) { - struct lm_ggml_tensor * view_q = lm_ggml_view_1d(ctx, flash_grad, elem_q, offs_q); - struct lm_ggml_tensor * grad_q = lm_ggml_reshape(ctx, view_q, src0); - src0->grad = lm_ggml_add_or_set(ctx, - src0->grad, - grad_q, - zero_table, acc_table); - } - if (src1->grad) { - struct lm_ggml_tensor * view_k = lm_ggml_view_1d(ctx, flash_grad, elem_k, offs_k); - struct lm_ggml_tensor * grad_k = lm_ggml_reshape(ctx, view_k, src1); - src1->grad = lm_ggml_add_or_set(ctx, - src1->grad, - grad_k, - zero_table, acc_table); - } - if (src2->grad) { - struct lm_ggml_tensor * view_v = lm_ggml_view_1d(ctx, flash_grad, elem_v, offs_v); - struct lm_ggml_tensor * grad_v = lm_ggml_reshape(ctx, view_v, src2); - src2->grad = lm_ggml_add_or_set(ctx, - src2->grad, - grad_v, - zero_table, acc_table); - } - } break; - case LM_GGML_OP_FLASH_ATTN_BACK: - { - LM_GGML_ABORT("fatal error"); // not supported - } - case LM_GGML_OP_SSM_CONV: - case LM_GGML_OP_SSM_SCAN: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented - } - case LM_GGML_OP_WIN_PART: - case LM_GGML_OP_WIN_UNPART: - case LM_GGML_OP_UNARY: - { - switch (lm_ggml_get_unary_op(tensor)) { - case LM_GGML_UNARY_OP_ABS: - { - if (src0->grad) { - src0->grad = - lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_mul(ctx, - lm_ggml_sgn(ctx, src0), - tensor->grad), - zero_table, acc_table); - } - } break; - case LM_GGML_UNARY_OP_SGN: - { - if (src0->grad) { - // noop - } - } break; - case LM_GGML_UNARY_OP_NEG: - { - if (src0->grad) { - src0->grad = lm_ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); - } - } break; - case LM_GGML_UNARY_OP_STEP: - { - if (src0->grad) { - // noop - } - } break; - case LM_GGML_UNARY_OP_TANH: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented - } - case LM_GGML_UNARY_OP_ELU: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented - } - case LM_GGML_UNARY_OP_RELU: - { - if (src0->grad) { - src0->grad = lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_mul(ctx, - lm_ggml_step(ctx, src0), - tensor->grad), - zero_table, acc_table); - } - } break; - case LM_GGML_UNARY_OP_SIGMOID: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented - } - case LM_GGML_UNARY_OP_GELU: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented - } - case LM_GGML_UNARY_OP_GELU_QUICK: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented - } - case LM_GGML_UNARY_OP_SILU: - { - // necessary for llama - if (src0->grad) { - src0->grad = lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_silu_back(ctx, src0, tensor->grad), - zero_table, acc_table); - } - } break; - case LM_GGML_UNARY_OP_EXP: - { - if (src0->grad) { - src0->grad = lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_mul(ctx, tensor, tensor->grad), - zero_table, acc_table); - } - } break; - default: - LM_GGML_ABORT("fatal error"); - } - } break; - case LM_GGML_OP_GET_REL_POS: - case LM_GGML_OP_ADD_REL_POS: - case LM_GGML_OP_RWKV_WKV: - case LM_GGML_OP_MAP_UNARY: - case LM_GGML_OP_MAP_BINARY: - case LM_GGML_OP_MAP_CUSTOM1_F32: - case LM_GGML_OP_MAP_CUSTOM2_F32: - case LM_GGML_OP_MAP_CUSTOM3_F32: - case LM_GGML_OP_MAP_CUSTOM1: - case LM_GGML_OP_MAP_CUSTOM2: - case LM_GGML_OP_MAP_CUSTOM3: - { - LM_GGML_ABORT("fatal error"); // not supported - } - case LM_GGML_OP_CROSS_ENTROPY_LOSS: - { - if (src0->grad) { - src0->grad = lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_cross_entropy_loss_back(ctx, - src0, - src1, - tensor->grad), - zero_table, acc_table); - } - LM_GGML_ASSERT(!src1->grad && "backward pass for labels not implemented"); - } break; - case LM_GGML_OP_CROSS_ENTROPY_LOSS_BACK: - { - LM_GGML_ABORT("fatal error"); // not supported - } - case LM_GGML_OP_OPT_STEP_ADAMW: - { - LM_GGML_ABORT("fatal error"); // not supported - } - case LM_GGML_OP_NONE: - { - // nop - } break; - case LM_GGML_OP_COUNT: - { - LM_GGML_ABORT("fatal error"); - } - } - - for (int i = 0; i < LM_GGML_MAX_SRC; ++i) { - if (tensor->src[i] && tensor->src[i]->grad) { - LM_GGML_ASSERT(lm_ggml_are_same_shape(tensor->src[i], tensor->src[i]->grad)); - } - } -} - -static void lm_ggml_visit_parents(struct lm_ggml_cgraph * cgraph, struct lm_ggml_tensor * node) { - if (node->grad == NULL) { - // this usually happens when we generate intermediate nodes from constants in the backward pass - // it can also happen during forward pass, if the user performs computations with constants - if (node->op != LM_GGML_OP_NONE) { - //LM_GGML_PRINT_DEBUG("%s: warning: node %p has no grad, but op %d\n", __func__, (void *) node, node->op); - } - } - - // check if already visited - if (lm_ggml_hash_insert(&cgraph->visited_hash_set, node) == LM_GGML_HASHSET_ALREADY_EXISTS) { - return; - } - - for (int i = 0; i < LM_GGML_MAX_SRC; ++i) { - const int k = - (cgraph->order == LM_GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i : - (cgraph->order == LM_GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? (LM_GGML_MAX_SRC-1-i) : - /* unknown order, just fall back to using i*/ i; - if (node->src[k]) { - lm_ggml_visit_parents(cgraph, node->src[k]); - } - } - - if (node->op == LM_GGML_OP_NONE && !(node->flags & LM_GGML_TENSOR_FLAG_PARAM)) { - // reached a leaf node, not part of the gradient graph (e.g. a constant) - LM_GGML_ASSERT(cgraph->n_leafs < cgraph->size); - - if (strlen(node->name) == 0) { - lm_ggml_format_name(node, "leaf_%d", cgraph->n_leafs); - } - - cgraph->leafs[cgraph->n_leafs] = node; - cgraph->n_leafs++; - } else { - LM_GGML_ASSERT(cgraph->n_nodes < cgraph->size); - - if (strlen(node->name) == 0) { - lm_ggml_format_name(node, "node_%d", cgraph->n_nodes); - } - - cgraph->nodes[cgraph->n_nodes] = node; - cgraph->n_nodes++; - } -} - -static void lm_ggml_build_forward_impl(struct lm_ggml_cgraph * cgraph, struct lm_ggml_tensor * tensor, bool expand) { - if (!expand) { - // TODO: this branch isn't accessible anymore, maybe move this to lm_ggml_build_forward_expand - lm_ggml_graph_clear(cgraph); - } - - const int n0 = cgraph->n_nodes; - - lm_ggml_visit_parents(cgraph, tensor); - - const int n_new = cgraph->n_nodes - n0; - LM_GGML_PRINT_DEBUG("%s: visited %d new nodes\n", __func__, n_new); - - if (n_new > 0) { - // the last added node should always be starting point - LM_GGML_ASSERT(cgraph->nodes[cgraph->n_nodes - 1] == tensor); - } -} - -void lm_ggml_build_forward_expand(struct lm_ggml_cgraph * cgraph, struct lm_ggml_tensor * tensor) { - lm_ggml_build_forward_impl(cgraph, tensor, true); -} - -void lm_ggml_build_backward_expand(struct lm_ggml_context * ctx, struct lm_ggml_cgraph * gf, struct lm_ggml_cgraph * gb, bool accumulate) { - LM_GGML_ASSERT(gf->n_nodes > 0); - LM_GGML_ASSERT(gf->grads); - - for (int i = 0; i < gf->n_nodes; ++i) { - struct lm_ggml_tensor * node = gf->nodes[i]; - - if (node->type == LM_GGML_TYPE_I32) { - continue; - } - - bool needs_grad = node->flags & LM_GGML_TENSOR_FLAG_PARAM; - bool ignore_src[LM_GGML_MAX_SRC] = {false}; - switch (node->op) { - // gradients in node->src[0] for one reason or another have no effect on output gradients - case LM_GGML_OP_IM2COL: // only used for its shape - case LM_GGML_OP_IM2COL_BACK: // same as IM2COL - ignore_src[0] = true; - break; - case LM_GGML_OP_UNARY: { - const enum lm_ggml_unary_op uop = lm_ggml_get_unary_op(node); - // SGN and STEP unary ops are piecewise constant - if (uop == LM_GGML_UNARY_OP_SGN || uop == LM_GGML_UNARY_OP_STEP) { - ignore_src[0] = true; - } - } break; - - // gradients in node->src[1] for one reason or another have no effect on output gradients - case LM_GGML_OP_CPY: // gradients in CPY target are irrelevant - case LM_GGML_OP_GET_ROWS: // row indices not differentiable - case LM_GGML_OP_GET_ROWS_BACK: // same as for GET_ROWS - case LM_GGML_OP_ROPE: // positions not differentiable - ignore_src[1] = true; - break; - - default: - break; - } - for (int j = 0; j < LM_GGML_MAX_SRC; ++j) { - if (!node->src[j] || !node->src[j]->grad || ignore_src[j]) { - continue; - } - LM_GGML_ASSERT(node->src[j]->type == LM_GGML_TYPE_F32 || node->src[j]->type == LM_GGML_TYPE_F16); - needs_grad = true; - break; - } - if (!needs_grad) { - continue; - } - - // inplace operations are currently not supported - LM_GGML_ASSERT(!node->view_src || node->op == LM_GGML_OP_CPY || node->op == LM_GGML_OP_VIEW || - node->op == LM_GGML_OP_RESHAPE || node->op == LM_GGML_OP_PERMUTE || node->op == LM_GGML_OP_TRANSPOSE); - - // create a new tensor with the same type and shape as the node and set it as grad - node->grad = lm_ggml_dup_tensor(ctx, node); - } - - // keep tables of original gradients for replacement/accumulation logic - struct lm_ggml_hash_set zero_table = lm_ggml_hash_set_new(gf->size); - struct lm_ggml_hash_set acc_table = lm_ggml_hash_set_new(gf->size); - for (int i = 0; i < gf->n_nodes; i++) { - struct lm_ggml_tensor * node = gf->nodes[i]; - - if (node->grad) { - { - const size_t insert_result = lm_ggml_hash_insert(&zero_table, node->grad); - LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL); - LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS); - } - - // only gradients of trainable parameters should be accumulated - if (accumulate && (node->flags & LM_GGML_TENSOR_FLAG_PARAM)) { - const size_t insert_result = lm_ggml_hash_insert(&acc_table, node->grad); - LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL); - LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS); - } - } - } - - for (int i = gf->n_nodes - 1; i >= 0; i--) { - struct lm_ggml_tensor * node = gf->nodes[i]; - - // inplace operations to add gradients are not created by lm_ggml_compute_backward except for gradient accumulation - // use allocator to automatically make inplace operations - if (node->grad) { - lm_ggml_compute_backward(ctx, node, &zero_table, &acc_table); - } - } - - for (int i = 0; i < gf->n_nodes; i++) { - struct lm_ggml_tensor * node = gf->nodes[i]; - - if (node->flags & LM_GGML_TENSOR_FLAG_PARAM) { - LM_GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node); - lm_ggml_build_forward_expand(gb, node->grad); - } - } - - lm_ggml_hash_set_free(&zero_table); - lm_ggml_hash_set_free(&acc_table); -} - -void lm_ggml_build_opt_adamw( - struct lm_ggml_context * ctx, - struct lm_ggml_cgraph * gf, - struct lm_ggml_cgraph * gb, - float alpha, - float beta1, - float beta2, - float eps, - float wd) { - for (int i = 0; i < gf->n_nodes; i++) { - struct lm_ggml_tensor * node = gf->nodes[i]; - - if (node->flags & LM_GGML_TENSOR_FLAG_PARAM) { - LM_GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node); - struct lm_ggml_tensor * opt_step = lm_ggml_opt_step_adamw(ctx, node, node->grad, alpha, beta1, beta2, eps, wd); - lm_ggml_build_forward_expand(gb, opt_step); - } - } -} - - -static void * incr_ptr_aligned(void ** p, size_t size, size_t align) { - void * ptr = *p; - ptr = (void *) LM_GGML_PAD((uintptr_t) ptr, align); - *p = (void *) ((char *) ptr + size); - return ptr; -} - -static size_t lm_ggml_graph_nbytes(size_t size, bool grads) { - size_t hash_size = lm_ggml_hash_size(size * 2); - void * p = 0; - incr_ptr_aligned(&p, sizeof(struct lm_ggml_cgraph), 1); - incr_ptr_aligned(&p, size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *)); // nodes - incr_ptr_aligned(&p, size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *)); // leafs - incr_ptr_aligned(&p, hash_size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *)); // hash keys - if (grads) { - incr_ptr_aligned(&p, size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *)); // grads - } - incr_ptr_aligned(&p, lm_ggml_bitset_size(hash_size) * sizeof(lm_ggml_bitset_t), sizeof(lm_ggml_bitset_t)); - - size_t nbytes = (size_t) p; - return nbytes; -} - -size_t lm_ggml_graph_overhead_custom(size_t size, bool grads) { - return LM_GGML_OBJECT_SIZE + LM_GGML_PAD(lm_ggml_graph_nbytes(size, grads), LM_GGML_MEM_ALIGN); -} - -size_t lm_ggml_graph_overhead(void) { - return lm_ggml_graph_overhead_custom(LM_GGML_DEFAULT_GRAPH_SIZE, false); -} - -struct lm_ggml_cgraph * lm_ggml_new_graph_custom(struct lm_ggml_context * ctx, size_t size, bool grads) { - const size_t obj_size = lm_ggml_graph_nbytes(size, grads); - struct lm_ggml_object * obj = lm_ggml_new_object(ctx, LM_GGML_OBJECT_TYPE_GRAPH, obj_size); - struct lm_ggml_cgraph * cgraph = (struct lm_ggml_cgraph *) ((char *) ctx->mem_buffer + obj->offs); - - // the size of the hash table is doubled since it needs to hold both nodes and leafs - size_t hash_size = lm_ggml_hash_size(size * 2); - - void * p = cgraph + 1; - - struct lm_ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *)); - struct lm_ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *)); - struct lm_ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *)); - struct lm_ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *)) : NULL; - lm_ggml_bitset_t * hash_used = incr_ptr_aligned(&p, lm_ggml_bitset_size(hash_size) * sizeof(lm_ggml_bitset_t), sizeof(lm_ggml_bitset_t)); - - // check that we allocated the correct amount of memory - assert(obj_size == (size_t)((char *)p - (char *)cgraph)); - - *cgraph = (struct lm_ggml_cgraph) { - /*.size =*/ size, - /*.n_nodes =*/ 0, - /*.n_leafs =*/ 0, - /*.nodes =*/ nodes_ptr, - /*.grads =*/ grads_ptr, - /*.leafs =*/ leafs_ptr, - /*.hash_table =*/ { hash_size, hash_used, hash_keys_ptr }, - /*.order =*/ LM_GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT, - }; - - lm_ggml_hash_set_reset(&cgraph->visited_hash_set); - - return cgraph; -} - -struct lm_ggml_cgraph * lm_ggml_new_graph(struct lm_ggml_context * ctx) { - return lm_ggml_new_graph_custom(ctx, LM_GGML_DEFAULT_GRAPH_SIZE, false); -} - -struct lm_ggml_cgraph lm_ggml_graph_view(struct lm_ggml_cgraph * cgraph0, int i0, int i1) { - struct lm_ggml_cgraph cgraph = { - /*.size =*/ 0, - /*.n_nodes =*/ i1 - i0, - /*.n_leafs =*/ 0, - /*.nodes =*/ cgraph0->nodes + i0, - /*.grads =*/ cgraph0->grads ? cgraph0->grads + i0 : NULL, - /*.leafs =*/ NULL, - /*.hash_table =*/ { 0, NULL, NULL }, - /*.order =*/ cgraph0->order, - }; - - return cgraph; -} - -void lm_ggml_graph_cpy(struct lm_ggml_cgraph * src, struct lm_ggml_cgraph * dst) { - LM_GGML_ASSERT(dst->size >= src->n_leafs); - LM_GGML_ASSERT(dst->size >= src->n_nodes); - LM_GGML_ASSERT(dst->visited_hash_set.size >= src->visited_hash_set.size); - - dst->n_leafs = src->n_leafs; - dst->n_nodes = src->n_nodes; - dst->order = src->order; - - for (int i = 0; i < src->n_leafs; ++i) { - dst->leafs[i] = src->leafs[i]; - } - - for (int i = 0; i < src->n_nodes; ++i) { - dst->nodes[i] = src->nodes[i]; - } - - if (src->grads) { - LM_GGML_ASSERT(dst->grads != NULL); - for (int i = 0; i < src->n_nodes; ++i) { - dst->grads[i] = src->grads[i]; - } - } - - for (size_t i = 0; i < src->visited_hash_set.size; ++i) { - // copy all hashset keys (tensors) that are in use - if (lm_ggml_bitset_get(src->visited_hash_set.used, i)) { - lm_ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]); - } - } -} - -struct lm_ggml_cgraph * lm_ggml_graph_dup(struct lm_ggml_context * ctx, struct lm_ggml_cgraph * cgraph) { - struct lm_ggml_cgraph * result = lm_ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads != NULL); - lm_ggml_graph_cpy(cgraph, result); - return result; -} - -void lm_ggml_graph_reset(struct lm_ggml_cgraph * cgraph) { - LM_GGML_ASSERT(cgraph->grads != NULL); - - for (int i = 0; i < cgraph->n_nodes; i++) { - struct lm_ggml_tensor * node = cgraph->nodes[i]; - - // initial gradients of loss should be 1, 0 otherwise - if (node->grad) { - if (node->flags & LM_GGML_TENSOR_FLAG_LOSS) { - LM_GGML_ASSERT(node->grad->buffer); - LM_GGML_ASSERT(node->type == LM_GGML_TYPE_F32); - LM_GGML_ASSERT(lm_ggml_is_scalar(node)); - - const float onef = 1.0f; - lm_ggml_backend_tensor_set(node->grad, &onef, 0, lm_ggml_nbytes(node->grad)); - } else { - lm_ggml_set_zero(node->grad); - } - } - - LM_GGML_ASSERT(node); - if (node->op == LM_GGML_OP_OPT_STEP_ADAMW) { - // set iteration to 1 and clear momenta - lm_ggml_set_op_params_i32(node, 0, 1); - lm_ggml_set_zero(node->src[2]); - lm_ggml_set_zero(node->src[3]); - } - } -} - -void lm_ggml_graph_clear(struct lm_ggml_cgraph * cgraph) { - cgraph->n_leafs = 0; - cgraph->n_nodes = 0; - lm_ggml_hash_set_reset(&cgraph->visited_hash_set); -} - -int lm_ggml_graph_size(struct lm_ggml_cgraph * cgraph) { - return cgraph->size; -} - -struct lm_ggml_tensor * lm_ggml_graph_node(struct lm_ggml_cgraph * cgraph, int i) { - if (i < 0) { - LM_GGML_ASSERT(cgraph->n_nodes + i >= 0); - return cgraph->nodes[cgraph->n_nodes + i]; - } - - LM_GGML_ASSERT(i < cgraph->n_nodes); - return cgraph->nodes[i]; -} - -struct lm_ggml_tensor ** lm_ggml_graph_nodes(struct lm_ggml_cgraph * cgraph) { - return cgraph->nodes; -} - -int lm_ggml_graph_n_nodes(struct lm_ggml_cgraph * cgraph) { - return cgraph->n_nodes; -} - -void lm_ggml_graph_add_node(struct lm_ggml_cgraph * cgraph, struct lm_ggml_tensor * tensor) { - LM_GGML_ASSERT(cgraph->size > cgraph->n_nodes); - cgraph->nodes[cgraph->n_nodes] = tensor; - cgraph->n_nodes++; -} - -// Android's libc implementation "bionic" does not support setting affinity -#if defined(__gnu_linux__) -static void set_numa_thread_affinity(int thread_n) { - if (!lm_ggml_is_numa()) { - return; - } - - int node_num; - int rv; - size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus); - - switch(g_state.numa.numa_strategy) { - case LM_GGML_NUMA_STRATEGY_DISTRIBUTE: - // run thread on node_num thread_n / (threads per node) - node_num = thread_n % g_state.numa.n_nodes; - break; - case LM_GGML_NUMA_STRATEGY_ISOLATE: - // run thread on current_node - node_num = g_state.numa.current_node; - break; - case LM_GGML_NUMA_STRATEGY_NUMACTL: - // use the cpuset that numactl gave us - rv = pthread_setaffinity_np(pthread_self(), setsize, &g_state.numa.cpuset); - if (rv) { - fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n",strerror(rv)); - } - return; - default: - return; - } - - struct lm_ggml_numa_node * node = &g_state.numa.nodes[node_num]; - - cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus); - CPU_ZERO_S(setsize, cpus); - for (size_t i = 0; i < node->n_cpus; ++i) { - CPU_SET_S(node->cpus[i], setsize, cpus); - } - - rv = pthread_setaffinity_np(pthread_self(), setsize, cpus); - if (rv) { - fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n", strerror(rv)); - } - - CPU_FREE(cpus); -} - -static void clear_numa_thread_affinity(void) { - if (!lm_ggml_is_numa()) { - return; - } - - size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus); - - cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus); - CPU_ZERO_S(setsize, cpus); - for (unsigned i = 0; i < g_state.numa.total_cpus; ++i) { - CPU_SET_S(i, setsize, cpus); - } - - int rv = pthread_setaffinity_np(pthread_self(), setsize, cpus); - if (rv) { - fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n", strerror(rv)); - } - - CPU_FREE(cpus); -} -#else -// TODO: Windows etc. -// (the linux implementation may also work on BSD, someone should test) -static void set_numa_thread_affinity(int thread_n) { UNUSED(thread_n); } -static void clear_numa_thread_affinity(void) {} -#endif - -static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads) { - int n_tasks = 0; - - if (lm_ggml_is_empty(node)) { - // no need to multi-thread a no-op - n_tasks = 1; - return n_tasks; - } - - switch (node->op) { - case LM_GGML_OP_CPY: - case LM_GGML_OP_DUP: - case LM_GGML_OP_CONT: - case LM_GGML_OP_ADD: - case LM_GGML_OP_ADD1: - case LM_GGML_OP_ACC: - { - n_tasks = n_threads; + src1->grad = + lm_ggml_add_or_set(ctx, + src1->grad, + lm_ggml_reshape(ctx, + lm_ggml_cont(ctx, tensor_grad_view), + src1->grad), + zero_table, acc_table); + } } break; case LM_GGML_OP_SUB: - case LM_GGML_OP_SQR: - case LM_GGML_OP_SQRT: - case LM_GGML_OP_LOG: - case LM_GGML_OP_SIN: - case LM_GGML_OP_COS: - case LM_GGML_OP_SUM: - case LM_GGML_OP_SUM_ROWS: - case LM_GGML_OP_MEAN: - case LM_GGML_OP_ARGMAX: - { - n_tasks = 1; - } break; - case LM_GGML_OP_COUNT_EQUAL: - { - n_tasks = n_threads; - } break; - case LM_GGML_OP_REPEAT: - case LM_GGML_OP_REPEAT_BACK: - case LM_GGML_OP_LEAKY_RELU: { - n_tasks = 1; + if (src0->grad) { + src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); + } + if (src1->grad) { + src1->grad = lm_ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table, acc_table); + } } break; - case LM_GGML_OP_UNARY: - switch (lm_ggml_get_unary_op(node)) { - case LM_GGML_UNARY_OP_ABS: - case LM_GGML_UNARY_OP_SGN: - case LM_GGML_UNARY_OP_NEG: - case LM_GGML_UNARY_OP_STEP: - case LM_GGML_UNARY_OP_TANH: - case LM_GGML_UNARY_OP_ELU: - case LM_GGML_UNARY_OP_RELU: - case LM_GGML_UNARY_OP_SIGMOID: - case LM_GGML_UNARY_OP_HARDSWISH: - case LM_GGML_UNARY_OP_HARDSIGMOID: - case LM_GGML_UNARY_OP_EXP: - { - n_tasks = 1; - } break; - - case LM_GGML_UNARY_OP_GELU: - case LM_GGML_UNARY_OP_GELU_QUICK: - case LM_GGML_UNARY_OP_SILU: - { - n_tasks = n_threads; - } break; - default: - LM_GGML_ABORT("fatal error"); - } - break; - case LM_GGML_OP_SILU_BACK: case LM_GGML_OP_MUL: - case LM_GGML_OP_DIV: - case LM_GGML_OP_NORM: - case LM_GGML_OP_RMS_NORM: - case LM_GGML_OP_RMS_NORM_BACK: - case LM_GGML_OP_GROUP_NORM: - case LM_GGML_OP_CONCAT: - case LM_GGML_OP_MUL_MAT: - case LM_GGML_OP_MUL_MAT_ID: - case LM_GGML_OP_OUT_PROD: - { - n_tasks = n_threads; - } break; - case LM_GGML_OP_GET_ROWS: - { - // FIXME: get_rows can use additional threads, but the cost of launching additional threads - // decreases performance with GPU offloading - //n_tasks = n_threads; - n_tasks = 1; - } break; - case LM_GGML_OP_SCALE: - case LM_GGML_OP_SET: - case LM_GGML_OP_RESHAPE: - case LM_GGML_OP_VIEW: - case LM_GGML_OP_PERMUTE: - case LM_GGML_OP_TRANSPOSE: - case LM_GGML_OP_GET_ROWS_BACK: - case LM_GGML_OP_DIAG: - { - n_tasks = 1; - } break; - case LM_GGML_OP_DIAG_MASK_ZERO: - case LM_GGML_OP_DIAG_MASK_INF: - case LM_GGML_OP_SOFT_MAX_BACK: - case LM_GGML_OP_ROPE: - case LM_GGML_OP_ROPE_BACK: - case LM_GGML_OP_ADD_REL_POS: - { - n_tasks = n_threads; - } break; - case LM_GGML_OP_CLAMP: - { - n_tasks = 1; //TODO - } break; - case LM_GGML_OP_SOFT_MAX: - { - n_tasks = MIN(n_threads, lm_ggml_nrows(node->src[0])); - } break; - case LM_GGML_OP_IM2COL: - case LM_GGML_OP_IM2COL_BACK: - case LM_GGML_OP_CONV_TRANSPOSE_1D: - case LM_GGML_OP_CONV_TRANSPOSE_2D: { - n_tasks = n_threads; + if (src0->grad) { + src0->grad = + lm_ggml_add_or_set(ctx, + src0->grad, + lm_ggml_mul(ctx, src1, tensor->grad), + zero_table, acc_table); + } + if (src1->grad) { + src1->grad = + lm_ggml_add_or_set(ctx, + src1->grad, + lm_ggml_mul(ctx, src0, tensor->grad), + zero_table, acc_table); + } } break; - case LM_GGML_OP_POOL_1D: - case LM_GGML_OP_POOL_2D: - case LM_GGML_OP_POOL_2D_BACK: + case LM_GGML_OP_DIV: { - n_tasks = 1; + if (src0->grad) { + src0->grad = + lm_ggml_add_or_set(ctx, + src0->grad, + lm_ggml_div(ctx, tensor->grad, src1), + zero_table, acc_table); + } + if (src1->grad) { + src1->grad = + lm_ggml_sub_or_set(ctx, + src1->grad, + lm_ggml_mul(ctx, + tensor->grad, + lm_ggml_div(ctx, tensor, src1)), + zero_table, acc_table); + } } break; - case LM_GGML_OP_UPSCALE: - case LM_GGML_OP_PAD: - case LM_GGML_OP_ARANGE: - case LM_GGML_OP_TIMESTEP_EMBEDDING: - case LM_GGML_OP_ARGSORT: - case LM_GGML_OP_FLASH_ATTN_EXT: - case LM_GGML_OP_FLASH_ATTN_BACK: - case LM_GGML_OP_SSM_CONV: - case LM_GGML_OP_SSM_SCAN: + case LM_GGML_OP_SQR: { - n_tasks = n_threads; + if (src0->grad) { + src0->grad = + lm_ggml_add_or_set(ctx, + src0->grad, + lm_ggml_scale(ctx, + lm_ggml_mul(ctx, src0, tensor->grad), + 2.0f), + zero_table, acc_table); + } } break; - case LM_GGML_OP_WIN_PART: - case LM_GGML_OP_WIN_UNPART: - case LM_GGML_OP_GET_REL_POS: - case LM_GGML_OP_RWKV_WKV: - case LM_GGML_OP_MAP_UNARY: - case LM_GGML_OP_MAP_BINARY: - case LM_GGML_OP_MAP_CUSTOM1_F32: - case LM_GGML_OP_MAP_CUSTOM2_F32: - case LM_GGML_OP_MAP_CUSTOM3_F32: + case LM_GGML_OP_SQRT: { - n_tasks = 1; + if (src0->grad) { + src0->grad = + lm_ggml_add_or_set(ctx, + src0->grad, + lm_ggml_scale(ctx, + lm_ggml_div(ctx, + tensor->grad, + tensor), + 0.5f), + zero_table, acc_table); + } } break; - case LM_GGML_OP_MAP_CUSTOM1: + case LM_GGML_OP_LOG: { - struct lm_ggml_map_custom1_op_params p; - memcpy(&p, node->op_params, sizeof(p)); - if (p.n_tasks == LM_GGML_N_TASKS_MAX) { - n_tasks = n_threads; - } else { - n_tasks = MIN(p.n_tasks, n_threads); + if (src0->grad) { + src0->grad = + lm_ggml_add_or_set(ctx, + src0->grad, + lm_ggml_div(ctx, + tensor->grad, + src0), + zero_table, acc_table); } } break; - case LM_GGML_OP_MAP_CUSTOM2: + case LM_GGML_OP_SIN: { - struct lm_ggml_map_custom2_op_params p; - memcpy(&p, node->op_params, sizeof(p)); - if (p.n_tasks == LM_GGML_N_TASKS_MAX) { - n_tasks = n_threads; - } else { - n_tasks = MIN(p.n_tasks, n_threads); + if (src0->grad) { + src0->grad = + lm_ggml_add_or_set(ctx, + src0->grad, + lm_ggml_mul(ctx, + tensor->grad, + lm_ggml_cos(ctx, src0)), + zero_table, acc_table); } } break; - case LM_GGML_OP_MAP_CUSTOM3: + case LM_GGML_OP_COS: { - struct lm_ggml_map_custom3_op_params p; - memcpy(&p, node->op_params, sizeof(p)); - if (p.n_tasks == LM_GGML_N_TASKS_MAX) { - n_tasks = n_threads; - } else { - n_tasks = MIN(p.n_tasks, n_threads); + if (src0->grad) { + src0->grad = + lm_ggml_sub_or_set(ctx, + src0->grad, + lm_ggml_mul(ctx, + tensor->grad, + lm_ggml_sin(ctx, src0)), + zero_table, acc_table); } } break; - case LM_GGML_OP_CROSS_ENTROPY_LOSS: - case LM_GGML_OP_CROSS_ENTROPY_LOSS_BACK: - case LM_GGML_OP_OPT_STEP_ADAMW: + case LM_GGML_OP_SUM: { - n_tasks = n_threads; + if (src0->grad) { + src0->grad = + lm_ggml_add1_or_set(ctx, + src0->grad, + tensor->grad, + zero_table, acc_table); + } } break; - case LM_GGML_OP_NONE: + case LM_GGML_OP_SUM_ROWS: { - n_tasks = 1; + if (src0->grad) { + src0->grad = + lm_ggml_add_or_set(ctx, + src0->grad, + lm_ggml_repeat(ctx, + tensor->grad, + src0->grad), + zero_table, acc_table); + } } break; - case LM_GGML_OP_COUNT: + case LM_GGML_OP_MEAN: + case LM_GGML_OP_ARGMAX: + case LM_GGML_OP_COUNT_EQUAL: { - LM_GGML_ABORT("fatal error"); + LM_GGML_ABORT("fatal error"); // TODO: implement } - default: + case LM_GGML_OP_REPEAT: { - fprintf(stderr, "%s: op not implemented: ", __func__); - if (node->op < LM_GGML_OP_COUNT) { - fprintf(stderr, "%s\n", lm_ggml_op_name(node->op)); - } else { - fprintf(stderr, "%d\n", node->op); + // necessary for llama + if (src0->grad) { + src0->grad = lm_ggml_add_or_set(ctx, + src0->grad, + lm_ggml_repeat_back(ctx, tensor->grad, src0->grad), + zero_table, acc_table); } - LM_GGML_ABORT("fatal error"); - } - } - - assert(n_tasks > 0); - - return n_tasks; -} - -static thread_ret_t lm_ggml_graph_compute_secondary_thread(void* data); - -#if defined(_WIN32) -#include "windows.h" - -// TODO: support > 64 CPUs -bool lm_ggml_thread_apply_affinity(bool * mask) { - HANDLE h = GetCurrentThread(); - uint64_t bitmask = 0ULL; - - assert(LM_GGML_MAX_N_THREADS >= 64); - - for (int32_t i = 0; i < 8; i++) { - int32_t idx = i * 8; - uint8_t val = 0; - val |= mask[idx + 0] << 0; - val |= mask[idx + 1] << 1; - val |= mask[idx + 2] << 2; - val |= mask[idx + 3] << 3; - val |= mask[idx + 4] << 4; - val |= mask[idx + 5] << 5; - val |= mask[idx + 6] << 6; - val |= mask[idx + 7] << 7; - bitmask |= (uint64_t)val << idx; - } - - for (int32_t i = 64; i < LM_GGML_MAX_N_THREADS; i++) { - if (mask[i]) { - fprintf(stderr, "warn: setting thread-affinity for > 64 CPUs isn't supported on windows!\n"); - break; - } - } - - DWORD_PTR m = (DWORD_PTR)bitmask; - - m = SetThreadAffinityMask(h, m); - - return m != 0; -} - -static bool lm_ggml_thread_apply_priority(int32_t prio) { - // Note that on Windows the Process Priority Class must be updated in order to set Thread priority. - // This is up to the applications. - DWORD p = THREAD_PRIORITY_NORMAL; - switch (prio) { - case LM_GGML_SCHED_PRIO_NORMAL: p = THREAD_PRIORITY_NORMAL; break; - case LM_GGML_SCHED_PRIO_MEDIUM: p = THREAD_PRIORITY_ABOVE_NORMAL; break; - case LM_GGML_SCHED_PRIO_HIGH: p = THREAD_PRIORITY_HIGHEST; break; - case LM_GGML_SCHED_PRIO_REALTIME: p = THREAD_PRIORITY_TIME_CRITICAL; break; - } - - if (prio == LM_GGML_SCHED_PRIO_NORMAL) { - // Keep inherited policy/priority - return true; - } - - if (!SetThreadPriority(GetCurrentThread(), p)) { - fprintf(stderr, "warn: failed to set thread priority %d : (%d)\n", prio, (int) GetLastError()); - return false; - } - - return true; -} - -#elif defined(__APPLE__) -#include -#include - -static bool lm_ggml_thread_apply_affinity(const bool * mask) { - // Not supported on Apple platforms - UNUSED(mask); - return true; -} - -static bool lm_ggml_thread_apply_priority(int32_t prio) { - struct sched_param p; - int32_t policy = SCHED_OTHER; - switch (prio) { - case LM_GGML_SCHED_PRIO_NORMAL: policy = SCHED_OTHER; p.sched_priority = 0; break; - case LM_GGML_SCHED_PRIO_MEDIUM: policy = SCHED_FIFO; p.sched_priority = 40; break; - case LM_GGML_SCHED_PRIO_HIGH: policy = SCHED_FIFO; p.sched_priority = 80; break; - case LM_GGML_SCHED_PRIO_REALTIME: policy = SCHED_FIFO; p.sched_priority = 90; break; - } - - if (prio == LM_GGML_SCHED_PRIO_NORMAL) { - // Keep inherited policy/priority - return true; - } - - int32_t err = pthread_setschedparam(pthread_self(), policy, &p); - if (err != 0) { - fprintf(stderr, "warn: failed to set thread priority %d : %s (%d)\n", prio, strerror(err), err); - return false; - } - - return true; -} - -#elif defined(__gnu_linux__) -// TODO: this may not work on BSD, to be verified - -static bool lm_ggml_thread_apply_affinity(const bool * mask) { - cpu_set_t cpuset; - int err; - - CPU_ZERO(&cpuset); - - for (uint32_t i = 0; i < LM_GGML_MAX_N_THREADS; i++) { - if (mask[i]) { - LM_GGML_PRINT_DEBUG("Thread %lx: adding %d to cpuset\n", pthread_self(), i); - CPU_SET(i, &cpuset); - } - } - -#ifdef __ANDROID__ - err = sched_setaffinity(0, sizeof(cpuset), &cpuset); - if (err < 0) { - err = errno; - } -#else - err = pthread_setaffinity_np(pthread_self(), sizeof(cpuset), &cpuset); -#endif - if (err != 0) { - fprintf(stderr, "warn: failed to set affinity mask 0x%llx : %s (%d)\n", (unsigned long long)mask, strerror(err), err); - return false; - } - - return true; -} - -static bool lm_ggml_thread_apply_priority(int32_t prio) { - struct sched_param p; - int32_t policy = SCHED_OTHER; - switch (prio) { - case LM_GGML_SCHED_PRIO_NORMAL: policy = SCHED_OTHER; p.sched_priority = 0; break; - case LM_GGML_SCHED_PRIO_MEDIUM: policy = SCHED_FIFO; p.sched_priority = 40; break; - case LM_GGML_SCHED_PRIO_HIGH: policy = SCHED_FIFO; p.sched_priority = 80; break; - case LM_GGML_SCHED_PRIO_REALTIME: policy = SCHED_FIFO; p.sched_priority = 90; break; - } - - if (prio == LM_GGML_SCHED_PRIO_NORMAL) { - // Keep inherited policy/priority - return true; - } - - int32_t err = pthread_setschedparam(pthread_self(), policy, &p); - if (err != 0) { - fprintf(stderr, "warn: failed to set thread priority %d : %s (%d)\n", prio, strerror(err), err); - return false; - } - - return true; -} - -#else // unsupported platforms - -static bool lm_ggml_thread_apply_affinity(const bool * mask) { - UNUSED(mask); - return true; -} - -static bool lm_ggml_thread_apply_priority(int32_t prio) { - UNUSED(prio); - return true; -} - -#endif - -static bool lm_ggml_thread_cpumask_is_valid(const bool * mask) { - for (int i = 0; i < LM_GGML_MAX_N_THREADS; i++) { - if (mask[i]) { return true; } - } - return false; -} - -static void lm_ggml_thread_cpumask_next(const bool * global_mask, bool * local_mask, bool strict, int32_t* iter) { - if (!strict) { - memcpy(local_mask, global_mask, LM_GGML_MAX_N_THREADS); - return; - } else { - memset(local_mask, 0, LM_GGML_MAX_N_THREADS); - int32_t base_idx = *iter; - for (int32_t i = 0; i < LM_GGML_MAX_N_THREADS; i++) { - int32_t idx = base_idx + i; - if (idx >= LM_GGML_MAX_N_THREADS) { - // Just a cheaper modulo - idx -= LM_GGML_MAX_N_THREADS; - } - if (global_mask[idx]) { - local_mask[idx] = 1; - *iter = idx + 1; - return; - } - } - } -} - -void lm_ggml_threadpool_free(struct lm_ggml_threadpool* threadpool) { - if (!threadpool) return; - - const int n_threads = threadpool->n_threads_max; - -#ifndef LM_GGML_USE_OPENMP - struct lm_ggml_compute_state* workers = threadpool->workers; - - lm_ggml_mutex_lock(&threadpool->mutex); - - threadpool->stop = true; - threadpool->pause = false; - - lm_ggml_cond_broadcast(&threadpool->cond); - lm_ggml_mutex_unlock(&threadpool->mutex); - - for (int j = 1; j < n_threads; j++) { - int32_t rc = lm_ggml_thread_join(workers[j].thrd, NULL); - LM_GGML_ASSERT(rc == LM_GGML_EXIT_SUCCESS || rc == LM_GGML_EXIT_ABORTED); - UNUSED(rc); - } - - lm_ggml_mutex_destroy(&threadpool->mutex); - lm_ggml_cond_destroy(&threadpool->cond); -#endif // LM_GGML_USE_OPENMP - - const size_t workers_size = sizeof(struct lm_ggml_compute_state) * n_threads; - lm_ggml_aligned_free(threadpool->workers, workers_size); - lm_ggml_aligned_free(threadpool, sizeof(struct lm_ggml_threadpool)); -} - -#ifndef LM_GGML_USE_OPENMP -// pause/resume must be called under mutex -static void lm_ggml_threadpool_pause_locked(struct lm_ggml_threadpool * threadpool) { - LM_GGML_PRINT_DEBUG("Pausing threadpool\n"); - threadpool->pause = true; - lm_ggml_cond_broadcast(&threadpool->cond); -} - -static void lm_ggml_threadpool_resume_locked(struct lm_ggml_threadpool * threadpool) { - LM_GGML_PRINT_DEBUG("Resuming threadpool\n"); - threadpool->pause = false; - lm_ggml_cond_broadcast(&threadpool->cond); -} -#endif - -void lm_ggml_threadpool_pause(struct lm_ggml_threadpool * threadpool) { -#ifndef LM_GGML_USE_OPENMP - lm_ggml_mutex_lock(&threadpool->mutex); - if (!threadpool->pause) { - lm_ggml_threadpool_pause_locked(threadpool); - } - lm_ggml_mutex_unlock(&threadpool->mutex); -#else - UNUSED(threadpool); -#endif -} - -void lm_ggml_threadpool_resume(struct lm_ggml_threadpool * threadpool) { -#ifndef LM_GGML_USE_OPENMP - lm_ggml_mutex_lock(&threadpool->mutex); - if (threadpool->pause) { - lm_ggml_threadpool_resume_locked(threadpool); - } - lm_ggml_mutex_unlock(&threadpool->mutex); -#else - UNUSED(threadpool); -#endif -} - -struct lm_ggml_cplan lm_ggml_graph_plan( - const struct lm_ggml_cgraph * cgraph, - int n_threads, - struct lm_ggml_threadpool * threadpool) { - - if (threadpool == NULL) { - LM_GGML_PRINT_DEBUG("Threadpool is not specified. Will create a disposable threadpool : n_threads %d\n", n_threads); - } - if (n_threads <= 0) { - n_threads = threadpool ? threadpool->n_threads_max : LM_GGML_DEFAULT_N_THREADS; - } - - size_t work_size = 0; - - struct lm_ggml_cplan cplan; - memset(&cplan, 0, sizeof(struct lm_ggml_cplan)); - - int max_tasks = 1; - - // thread scheduling for the different operations + work buffer size estimation - for (int i = 0; i < cgraph->n_nodes; i++) { - struct lm_ggml_tensor * node = cgraph->nodes[i]; - - const int n_tasks = lm_ggml_get_n_tasks(node, n_threads); - - max_tasks = MAX(max_tasks, n_tasks); - - size_t cur = 0; - - switch (node->op) { - case LM_GGML_OP_CPY: - case LM_GGML_OP_DUP: - { - if (lm_ggml_is_quantized(node->type) || - // F16 -> BF16 and BF16 -> F16 copies go through intermediate F32 - (node->src[0]->type == LM_GGML_TYPE_F16 && node->src[1] && node->src[1]->type == LM_GGML_TYPE_BF16) || - (node->src[0]->type == LM_GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == LM_GGML_TYPE_F16)) { - cur = lm_ggml_type_size(LM_GGML_TYPE_F32) * node->ne[0] * n_tasks; - } - } break; - case LM_GGML_OP_ADD: - case LM_GGML_OP_ADD1: - { - if (lm_ggml_is_quantized(node->src[0]->type)) { - cur = lm_ggml_type_size(LM_GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks; - } - } break; - case LM_GGML_OP_ACC: - { - if (lm_ggml_is_quantized(node->src[0]->type)) { - cur = lm_ggml_type_size(LM_GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks; - } - } break; - case LM_GGML_OP_COUNT_EQUAL: - { - cur = lm_ggml_type_size(node->type)*n_tasks; - } break; - case LM_GGML_OP_MUL_MAT: - { - const enum lm_ggml_type vec_dot_type = type_traits[node->src[0]->type].vec_dot_type; - - if (node->src[1]->type != vec_dot_type) { - cur = lm_ggml_row_size(vec_dot_type, lm_ggml_nelements(node->src[1])); - } - } break; - case LM_GGML_OP_MUL_MAT_ID: - { - cur = 0; - const struct lm_ggml_tensor * src0 = node->src[0]; - const struct lm_ggml_tensor * src1 = node->src[1]; - const enum lm_ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type; - if (src1->type != vec_dot_type) { - cur += lm_ggml_row_size(vec_dot_type, lm_ggml_nelements(src1)); - } - const int n_as = src0->ne[2]; - cur += LM_GGML_PAD(cur, sizeof(int64_t)); // align - cur += n_as * sizeof(int64_t); // matrix_row_counts - cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows - } break; - case LM_GGML_OP_OUT_PROD: - { - if (lm_ggml_is_quantized(node->src[0]->type)) { - cur = lm_ggml_type_size(LM_GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks; - } - } break; - case LM_GGML_OP_SOFT_MAX: - case LM_GGML_OP_ROPE: - { - cur = lm_ggml_type_size(LM_GGML_TYPE_F32) * node->ne[0] * n_tasks; - } break; - case LM_GGML_OP_CONV_TRANSPOSE_1D: - { - LM_GGML_ASSERT(node->src[0]->ne[3] == 1); - LM_GGML_ASSERT(node->src[1]->ne[2] == 1); - LM_GGML_ASSERT(node->src[1]->ne[3] == 1); - - const int64_t ne00 = node->src[0]->ne[0]; // K - const int64_t ne01 = node->src[0]->ne[1]; // Cout - const int64_t ne02 = node->src[0]->ne[2]; // Cin - - const int64_t ne10 = node->src[1]->ne[0]; // L - const int64_t ne11 = node->src[1]->ne[1]; // Cin - - if ((node->src[0]->type == LM_GGML_TYPE_F16 || - node->src[0]->type == LM_GGML_TYPE_BF16) && - node->src[1]->type == LM_GGML_TYPE_F32) { - cur += sizeof(lm_ggml_fp16_t)*ne00*ne01*ne02; - cur += sizeof(lm_ggml_fp16_t)*ne10*ne11; - } else if (node->src[0]->type == LM_GGML_TYPE_F32 && - node->src[1]->type == LM_GGML_TYPE_F32) { - cur += sizeof(float)*ne00*ne01*ne02; - cur += sizeof(float)*ne10*ne11; - } else { - LM_GGML_ABORT("fatal error"); - } - } break; - case LM_GGML_OP_CONV_TRANSPOSE_2D: - { - const int64_t ne00 = node->src[0]->ne[0]; // W - const int64_t ne01 = node->src[0]->ne[1]; // H - const int64_t ne02 = node->src[0]->ne[2]; // Channels Out - const int64_t ne03 = node->src[0]->ne[3]; // Channels In - - const int64_t ne10 = node->src[1]->ne[0]; // W - const int64_t ne11 = node->src[1]->ne[1]; // H - const int64_t ne12 = node->src[1]->ne[2]; // Channels In - - cur += sizeof(lm_ggml_fp16_t)*ne00*ne01*ne02*ne03; - cur += sizeof(lm_ggml_fp16_t)*ne10*ne11*ne12; - } break; - case LM_GGML_OP_FLASH_ATTN_EXT: - { - const int64_t ne00 = node->src[0]->ne[0]; // D - - cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread - } break; - case LM_GGML_OP_FLASH_ATTN_BACK: - { - const int64_t D = node->src[0]->ne[0]; - const int64_t ne11 = lm_ggml_up(node->src[1]->ne[1], LM_GGML_SOFT_MAX_UNROLL); - const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in lm_ggml_compute_forward_flash_attn_back - if (node->src[1]->type == LM_GGML_TYPE_F32) { - cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 - } else if (node->src[1]->type == LM_GGML_TYPE_F16) { - cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 - } else if (node->src[1]->type == LM_GGML_TYPE_BF16) { - cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 - } - } break; - - case LM_GGML_OP_CROSS_ENTROPY_LOSS: - { - cur = lm_ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks); - } break; - case LM_GGML_OP_COUNT: - { - LM_GGML_ABORT("fatal error"); + } break; + case LM_GGML_OP_REPEAT_BACK: + { + if (src0->grad) { + // TODO: test this + src0->grad = lm_ggml_add_or_set(ctx, + src0->grad, + lm_ggml_repeat(ctx, tensor->grad, src0->grad), + zero_table, acc_table); } - default: - break; - } - - work_size = MAX(work_size, cur); - } - - if (work_size > 0) { - work_size += CACHE_LINE_SIZE*(n_threads); - } - - cplan.threadpool = threadpool; - cplan.n_threads = MIN(max_tasks, n_threads); - cplan.work_size = work_size; - cplan.work_data = NULL; - - return cplan; -} - -static thread_ret_t lm_ggml_graph_compute_thread(void * data) { - struct lm_ggml_compute_state * state = (struct lm_ggml_compute_state *) data; - struct lm_ggml_threadpool * tp = state->threadpool; - - const struct lm_ggml_cgraph * cgraph = tp->cgraph; - const struct lm_ggml_cplan * cplan = tp->cplan; - - set_numa_thread_affinity(state->ith); - - struct lm_ggml_compute_params params = { - /*.ith =*/ state->ith, - /*.nth =*/ atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed), - /*.wsize =*/ cplan->work_size, - /*.wdata =*/ cplan->work_data, - /*.threadpool=*/ tp, - }; - - for (int node_n = 0; node_n < cgraph->n_nodes && !tp->abort; node_n++) { - struct lm_ggml_tensor * node = cgraph->nodes[node_n]; - - lm_ggml_compute_forward(¶ms, node); - - if (state->ith == 0 && cplan->abort_callback && - cplan->abort_callback(cplan->abort_callback_data)) { - tp->abort = true; - tp->ec = LM_GGML_STATUS_ABORTED; - } - - lm_ggml_barrier(state->threadpool); - } - - return 0; -} - -#ifndef LM_GGML_USE_OPENMP - -// check if thread is active -static inline bool lm_ggml_graph_compute_thread_active(struct lm_ggml_compute_state * state) { - struct lm_ggml_threadpool * threadpool = state->threadpool; - int n_threads = atomic_load_explicit(&threadpool->n_threads_cur, memory_order_relaxed); - return (state->ith < n_threads); -} - -// check if thread is ready to proceed (exit from polling or sleeping) -static inline bool lm_ggml_graph_compute_thread_ready(struct lm_ggml_compute_state * state) { - struct lm_ggml_threadpool * threadpool = state->threadpool; - - if (state->pending || threadpool->stop || threadpool->pause) { return true; } - - // check for new graph/work - int new_graph = atomic_load_explicit(&threadpool->n_graph, memory_order_relaxed); - if (new_graph != state->last_graph) { - state->pending = lm_ggml_graph_compute_thread_active(state); - state->last_graph = new_graph; - } - - return state->pending; -} - -// sync thread state after polling -static inline void lm_ggml_graph_compute_thread_sync(struct lm_ggml_compute_state * state) { - // TSAN doesn't support standalone fence yet, we use a dummy read-modify-write instead - #ifdef LM_GGML_TSAN_ENABLED - atomic_fetch_add_explicit(&state->threadpool->n_graph, 0, memory_order_seq_cst); - #else - atomic_thread_fence(memory_order_seq_cst); - #endif - UNUSED(state); -} - -static inline bool lm_ggml_graph_compute_poll_for_work(struct lm_ggml_compute_state * state) { - struct lm_ggml_threadpool * threadpool = state->threadpool; - - // Skip polling for unused threads - if (!lm_ggml_graph_compute_thread_active(state)) { - return state->pending; - } - - // This seems to make 0 ... 100 a decent range for polling level across modern processors. - // Perhaps, we can adjust it dynamically based on load and things. - const uint64_t n_rounds = 1024UL * 128 * threadpool->poll; - - for (uint64_t i=0; !lm_ggml_graph_compute_thread_ready(state) && i < n_rounds; i++) { - // No new work. Keep polling. - lm_ggml_thread_cpu_relax(); - } - - return state->pending; -} - -static inline bool lm_ggml_graph_compute_check_for_work(struct lm_ggml_compute_state * state) { - struct lm_ggml_threadpool * threadpool = state->threadpool; - - if (lm_ggml_graph_compute_poll_for_work(state)) { - lm_ggml_graph_compute_thread_sync(state); - return state->pending; - } - - lm_ggml_mutex_lock_shared(&threadpool->mutex); - while (!lm_ggml_graph_compute_thread_ready(state)) { - // No new work. Wait for the signal. - LM_GGML_PRINT_DEBUG("thread #%d waiting for work (sleeping)\n", state->ith); - lm_ggml_cond_wait(&threadpool->cond, &threadpool->mutex); - } - lm_ggml_mutex_unlock_shared(&threadpool->mutex); - - return state->pending; -} - -static thread_ret_t lm_ggml_graph_compute_secondary_thread(void* data) { - struct lm_ggml_compute_state * state = (struct lm_ggml_compute_state *) data; - struct lm_ggml_threadpool * threadpool = state->threadpool; - - lm_ggml_thread_apply_priority(threadpool->prio); - if (lm_ggml_thread_cpumask_is_valid(state->cpumask)) { - lm_ggml_thread_apply_affinity(state->cpumask); - } - - while (true) { - // Check if we need to sleep - while (threadpool->pause) { - LM_GGML_PRINT_DEBUG("thread #%d inside pause loop\n", state->ith); - lm_ggml_mutex_lock_shared(&threadpool->mutex); - if (threadpool->pause) { - lm_ggml_cond_wait(&threadpool->cond, &threadpool->mutex); + } break; + case LM_GGML_OP_CONCAT: + { + LM_GGML_ABORT("fatal error"); // TODO: implement } - LM_GGML_PRINT_DEBUG("thread #%d resuming after wait\n", state->ith); - lm_ggml_mutex_unlock_shared(&threadpool->mutex); - } - - // This needs to be checked for after the cond_wait - if (threadpool->stop) break; - - // Check if there is new work - // The main thread is the only one that can dispatch new work - - lm_ggml_graph_compute_check_for_work(state); - if (state->pending) { - state->pending = false; - - lm_ggml_graph_compute_thread(state); - } - } - - return (thread_ret_t) 0; -} - -// Start processing new graph -static void lm_ggml_graph_compute_kickoff(struct lm_ggml_threadpool * threadpool, int n_threads) -{ - // Always take the mutex here because the worker threads are doing hybrid poll/wait - - lm_ggml_mutex_lock(&threadpool->mutex); - - LM_GGML_PRINT_DEBUG("threadpool: n_threads_cur %d n_threads %d\n", threadpool->n_threads_cur, n_threads); - - // Update the number of active threads - atomic_store_explicit(&threadpool->n_threads_cur, n_threads, memory_order_relaxed); - - // Indicate the graph is ready to be processed - // We need the full seq-cst fence here because of the polling threads (used in thread_sync) - atomic_fetch_add_explicit(&threadpool->n_graph, 1, memory_order_seq_cst); - - if (threadpool->pause) { - // Update main thread prio and affinity to match the threadpool settings - lm_ggml_thread_apply_priority(threadpool->prio); - if (lm_ggml_thread_cpumask_is_valid(threadpool->workers[0].cpumask)) { - lm_ggml_thread_apply_affinity(threadpool->workers[0].cpumask); - } - - // resume does cond broadcast - lm_ggml_threadpool_resume_locked(threadpool); - } else { - lm_ggml_cond_broadcast(&threadpool->cond); - } - - lm_ggml_mutex_unlock(&threadpool->mutex); -} - -#endif // LM_GGML_USE_OPENMP - -void lm_ggml_threadpool_params_init(struct lm_ggml_threadpool_params * p, int n_threads) { - p->n_threads = n_threads; - p->prio = 0; // default priority (usually means normal or inherited) - p->poll = 50; // hybrid-polling enabled - p->strict_cpu = false; // no strict placement (all threads share same cpumask) - p->paused = false; // threads are ready to go - memset(p->cpumask, 0, LM_GGML_MAX_N_THREADS); // all-zero means use the default affinity (usually inherited) -} - -struct lm_ggml_threadpool_params lm_ggml_threadpool_params_default(int n_threads) { - struct lm_ggml_threadpool_params p; - lm_ggml_threadpool_params_init(&p, n_threads); - return p; -} - -bool lm_ggml_threadpool_params_match(const struct lm_ggml_threadpool_params * p0, const struct lm_ggml_threadpool_params * p1) { - if (p0->n_threads != p1->n_threads ) return false; - if (p0->prio != p1->prio ) return false; - if (p0->poll != p1->poll ) return false; - if (p0->strict_cpu != p1->strict_cpu ) return false; - return memcmp(p0->cpumask, p1->cpumask, LM_GGML_MAX_N_THREADS) == 0; -} - -static struct lm_ggml_threadpool * lm_ggml_threadpool_new_impl( - struct lm_ggml_threadpool_params * tpp, - struct lm_ggml_cgraph * cgraph, - struct lm_ggml_cplan * cplan) { - - struct lm_ggml_threadpool * threadpool = - lm_ggml_aligned_malloc(sizeof(struct lm_ggml_threadpool)); - { - threadpool->cgraph = cgraph; - threadpool->cplan = cplan; - threadpool->n_graph = 0; - threadpool->n_barrier = 0; - threadpool->n_barrier_passed = 0; - threadpool->current_chunk = 0; - threadpool->stop = false; - threadpool->pause = tpp->paused; - threadpool->abort = false; - threadpool->workers = NULL; - threadpool->n_threads_max = tpp->n_threads; - threadpool->n_threads_cur = tpp->n_threads; - threadpool->poll = tpp->poll; - threadpool->prio = tpp->prio; - threadpool->ec = LM_GGML_STATUS_SUCCESS; - } - - // Allocate and init workers state - const size_t workers_size = sizeof(struct lm_ggml_compute_state) * tpp->n_threads; - struct lm_ggml_compute_state * workers = lm_ggml_aligned_malloc(workers_size); - - memset(workers, 0, workers_size); - for (int j = 0; j < tpp->n_threads; j++) { - workers[j].threadpool = threadpool; - workers[j].ith = j; - } - - threadpool->workers = workers; - -#ifndef LM_GGML_USE_OPENMP - lm_ggml_mutex_init(&threadpool->mutex); - lm_ggml_cond_init(&threadpool->cond); - - // Spin the threads for all workers, and update CPU placements. - // Place the main thread last (towards the higher numbered CPU cores). - - int32_t cpumask_iter = 0; - - for (int j = 1; j < tpp->n_threads; j++) { - lm_ggml_thread_cpumask_next(tpp->cpumask, workers[j].cpumask, tpp->strict_cpu, &cpumask_iter); - - int32_t rc = lm_ggml_thread_create(&workers[j].thrd, NULL, lm_ggml_graph_compute_secondary_thread, &workers[j]); - LM_GGML_ASSERT(rc == 0); - } - - lm_ggml_thread_cpumask_next(tpp->cpumask, workers[0].cpumask, tpp->strict_cpu, &cpumask_iter); - - if (!threadpool->pause) { - // Update main thread prio and affinity at the start, otherwise we'll do it in resume - lm_ggml_thread_apply_priority(threadpool->prio); - if (lm_ggml_thread_cpumask_is_valid(threadpool->workers[0].cpumask)) { - lm_ggml_thread_apply_affinity(threadpool->workers[0].cpumask); - } - } -#endif // LM_GGML_USE_OPENMP - - return threadpool; -} - -struct lm_ggml_threadpool * lm_ggml_threadpool_new(struct lm_ggml_threadpool_params * tpp) { - return lm_ggml_threadpool_new_impl(tpp, NULL, NULL); -} - -enum lm_ggml_status lm_ggml_graph_compute(struct lm_ggml_cgraph * cgraph, struct lm_ggml_cplan * cplan) { - LM_GGML_ASSERT(cplan); - LM_GGML_ASSERT(cplan->n_threads > 0); - LM_GGML_ASSERT(cplan->work_size == 0 || cplan->work_data != NULL); - - int n_threads = cplan->n_threads; - struct lm_ggml_threadpool * threadpool = cplan->threadpool; - - bool disposable_threadpool = false; - - if (threadpool == NULL) { - LM_GGML_PRINT_DEBUG("Threadpool is not specified. Will create a disposable threadpool : n_threads %d\n", n_threads); - disposable_threadpool = true; - - struct lm_ggml_threadpool_params ttp = lm_ggml_threadpool_params_default(n_threads); - threadpool = lm_ggml_threadpool_new_impl(&ttp, cgraph, cplan); - } else { - // Reset some of the parameters that need resetting - // No worker threads should be accessing the parameters below at this stage - threadpool->cgraph = cgraph; - threadpool->cplan = cplan; - threadpool->current_chunk = 0; - threadpool->abort = false; - threadpool->ec = LM_GGML_STATUS_SUCCESS; - } - -#ifdef LM_GGML_USE_OPENMP - if (n_threads > 1) { - #pragma omp parallel num_threads(n_threads) - { - #pragma omp single + case LM_GGML_OP_SILU_BACK: { - // update the number of threads from the actual number of threads that we got from OpenMP - n_threads = omp_get_num_threads(); - atomic_store_explicit(&threadpool->n_threads_cur, n_threads, memory_order_relaxed); + LM_GGML_ABORT("fatal error"); // TODO: not implemented } + case LM_GGML_OP_NORM: + { + LM_GGML_ABORT("fatal error"); // TODO: not implemented + } + case LM_GGML_OP_RMS_NORM: + { + // necessary for llama + if (src0->grad) { + float eps; + memcpy(&eps, tensor->op_params, sizeof(float)); - lm_ggml_graph_compute_thread(&threadpool->workers[omp_get_thread_num()]); - } - } else { - atomic_store_explicit(&threadpool->n_threads_cur, 1, memory_order_relaxed); - lm_ggml_graph_compute_thread(&threadpool->workers[0]); - } -#else - if (n_threads > threadpool->n_threads_max) { - LM_GGML_LOG_WARN("cplan requested more threads (%d) than available (%d)\n", n_threads, threadpool->n_threads_max); - n_threads = threadpool->n_threads_max; - } - - // Kick all threads to start the new graph - lm_ggml_graph_compute_kickoff(threadpool, n_threads); - - // This is a work thread too - lm_ggml_graph_compute_thread(&threadpool->workers[0]); -#endif - - // don't leave affinity set on the main thread - clear_numa_thread_affinity(); - - enum lm_ggml_status ret = threadpool->ec; - - if (disposable_threadpool) { - lm_ggml_threadpool_free(threadpool); - } - - return ret; -} - -enum lm_ggml_status lm_ggml_graph_compute_with_ctx(struct lm_ggml_context * ctx, struct lm_ggml_cgraph * cgraph, int n_threads) { - struct lm_ggml_cplan cplan = lm_ggml_graph_plan(cgraph, n_threads, NULL); - - struct lm_ggml_object * obj = lm_ggml_new_object(ctx, LM_GGML_OBJECT_TYPE_WORK_BUFFER, cplan.work_size); - - cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs; - - return lm_ggml_graph_compute(cgraph, &cplan); -} - -struct lm_ggml_tensor * lm_ggml_graph_get_tensor(struct lm_ggml_cgraph * cgraph, const char * name) { - for (int i = 0; i < cgraph->n_leafs; i++) { - struct lm_ggml_tensor * leaf = cgraph->leafs[i]; - - if (strcmp(leaf->name, name) == 0) { - return leaf; - } - } - - for (int i = 0; i < cgraph->n_nodes; i++) { - struct lm_ggml_tensor * node = cgraph->nodes[i]; - - if (strcmp(node->name, name) == 0) { - return node; - } - } - - return NULL; -} - -static void lm_ggml_graph_export_leaf(const struct lm_ggml_tensor * tensor, FILE * fout) { - const int64_t * ne = tensor->ne; - const size_t * nb = tensor->nb; - - fprintf(fout, "%-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %16p %32s\n", - lm_ggml_type_name(tensor->type), - lm_ggml_op_name (tensor->op), - lm_ggml_n_dims(tensor), - ne[0], ne[1], ne[2], ne[3], - nb[0], nb[1], nb[2], nb[3], - tensor->data, - tensor->name); -} - -static void lm_ggml_graph_export_node(const struct lm_ggml_tensor * tensor, const char * arg, FILE * fout) { - const int64_t * ne = tensor->ne; - const size_t * nb = tensor->nb; - - fprintf(fout, "%-6s %-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %16p %32s\n", - arg, - lm_ggml_type_name(tensor->type), - lm_ggml_op_name (tensor->op), - lm_ggml_n_dims(tensor), - ne[0], ne[1], ne[2], ne[3], - nb[0], nb[1], nb[2], nb[3], - tensor->data, - tensor->name); -} - -void lm_ggml_graph_export(const struct lm_ggml_cgraph * cgraph, const char * fname) { - uint64_t size_eval = 0; + src0->grad = lm_ggml_add_or_set(ctx, + src0->grad, + lm_ggml_rms_norm_back(ctx, src0, tensor->grad, eps), + zero_table, acc_table); + } + } break; + case LM_GGML_OP_RMS_NORM_BACK: + { + LM_GGML_ABORT("fatal error"); // TODO: not implemented + } + case LM_GGML_OP_GROUP_NORM: + { + LM_GGML_ABORT("fatal error"); // TODO: not implemented + } + case LM_GGML_OP_MUL_MAT: + { + // https://cs231n.github.io/optimization-2/#staged + // # forward pass + // s0 = np.random.randn(5, 10) + // s1 = np.random.randn(10, 3) + // t = s0.dot(s1) - // compute size of intermediate results - for (int i = 0; i < cgraph->n_nodes; ++i) { - size_eval += lm_ggml_nbytes_pad(cgraph->nodes[i]); - } + // # now suppose we had the gradient on t from above in the circuit + // dt = np.random.randn(*t.shape) # same shape as t + // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix + // ds1 = t.T.dot(dt) - // print - { - FILE * fout = stdout; - - fprintf(fout, "\n"); - fprintf(fout, "%-16s %8x\n", "magic", LM_GGML_FILE_MAGIC); - fprintf(fout, "%-16s %8d\n", "version", LM_GGML_FILE_VERSION); - fprintf(fout, "%-16s %8d\n", "leafs", cgraph->n_leafs); - fprintf(fout, "%-16s %8d\n", "nodes", cgraph->n_nodes); - fprintf(fout, "%-16s %" PRIu64 "\n", "eval", size_eval); - - // header - fprintf(fout, "\n"); - fprintf(fout, "%-6s %-12s %8s %8s %8s %8s %8s %16s %16s %16s %16s %16s %16s\n", - "TYPE", "OP", "NDIMS", "NE0", "NE1", "NE2", "NE3", "NB0", "NB1", "NB2", "NB3", "DATA", "NAME"); - - for (int i = 0; i < cgraph->n_leafs; ++i) { - lm_ggml_graph_export_leaf(cgraph->leafs[i], fout); - - LM_GGML_ASSERT(cgraph->leafs[i]->op == LM_GGML_OP_NONE); - LM_GGML_ASSERT(cgraph->leafs[i]->src[0] == NULL); - LM_GGML_ASSERT(cgraph->leafs[i]->src[1] == NULL); - } + // tensor.shape [m,p,qq,rr] + // src0.shape [n,m,q1,r1] + // src1.shape [n,p,qq,rr] - // header - fprintf(fout, "\n"); - fprintf(fout, "%-6s %-6s %-12s %8s %8s %8s %8s %8s %16s %16s %16s %16s %8s %16s %16s\n", - "ARG", "TYPE", "OP", "NDIMS", "NE0", "NE1", "NE2", "NE3", "NB0", "NB1", "NB2", "NB3", "NTASKS", "DATA", "NAME"); + // necessary for llama + if (src0->grad) { + struct lm_ggml_tensor * s1_tg = + lm_ggml_out_prod(ctx, // [n,m,qq,rr] + src1, // [n,p,qq,rr] + tensor->grad); // [m,p,qq,rr] + const int64_t qq = s1_tg->ne[2]; + const int64_t rr = s1_tg->ne[3]; + const int64_t q1 = src0->ne[2]; + const int64_t r1 = src0->ne[3]; + const bool ne2_broadcasted = qq > q1; + const bool ne3_broadcasted = rr > r1; + if (ne2_broadcasted || ne3_broadcasted) { + // sum broadcast repetitions of s1_tg into shape of src0 + s1_tg = lm_ggml_repeat_back(ctx, s1_tg, src0); + } + src0->grad = + lm_ggml_add_or_set(ctx, + src0->grad, // [n,m,q1,r1] + s1_tg, // [n,m,q1,r1] + zero_table, acc_table); + } + if (src1->grad) { + src1->grad = + lm_ggml_add_or_set(ctx, + src1->grad, // [n,p,qq,rr] + // lm_ggml_mul_mat(ctx, // [n,p,qq,rr] + // lm_ggml_cont(ctx, // [m,n,q1,r1] + // lm_ggml_transpose(ctx, src0)), // [m,n,q1,r1] + // tensor->grad), // [m,p,qq,rr] - for (int i = 0; i < cgraph->n_nodes; ++i) { - lm_ggml_graph_export_node(cgraph->nodes[i], "DST", fout); + // // when src0 is bigger than tensor->grad (this is mostly the case in llama), + // // avoid transpose of src0, rather transpose smaller tensor->grad + // // and then use lm_ggml_out_prod + lm_ggml_out_prod(ctx, // [n,p,qq,rr] + src0, // [n,m,q1,r1] + lm_ggml_transpose(ctx, // [p,m,qq,rr] + tensor->grad)), // [m,p,qq,rr] + zero_table, acc_table); + } + } break; + case LM_GGML_OP_MUL_MAT_ID: + { + LM_GGML_ABORT("fatal error"); // TODO: not implemented + } + case LM_GGML_OP_OUT_PROD: + { + LM_GGML_ABORT("fatal error"); // TODO: not implemented + } + case LM_GGML_OP_SCALE: + { + // necessary for llama + if (src0->grad) { + float s; + memcpy(&s, tensor->op_params, sizeof(float)); - for (int j = 0; j < LM_GGML_MAX_SRC; ++j) { - if (cgraph->nodes[i]->src[j]) { - lm_ggml_graph_export_node(cgraph->nodes[i]->src[j], "SRC", fout); + src0->grad = + lm_ggml_add_or_set(ctx, + src0->grad, + lm_ggml_scale_impl(ctx, tensor->grad, s, false), + zero_table, acc_table); } - } - - fprintf(fout, "\n"); - } + } break; + case LM_GGML_OP_SET: + { + const size_t nb1 = ((int32_t *) tensor->op_params)[0]; + const size_t nb2 = ((int32_t *) tensor->op_params)[1]; + const size_t nb3 = ((int32_t *) tensor->op_params)[2]; + const size_t offset = ((int32_t *) tensor->op_params)[3]; - fprintf(fout, "\n"); - } + struct lm_ggml_tensor * tensor_grad_view = NULL; - // write binary data - { - FILE * fout = lm_ggml_fopen(fname, "wb"); + if (src0->grad || src1->grad) { + LM_GGML_ASSERT(src0->type == tensor->type); + LM_GGML_ASSERT(tensor->grad->type == tensor->type); + LM_GGML_ASSERT(!src1->grad || src1->grad->type == tensor->grad->type); - if (!fout) { - fprintf(stderr, "%s: failed to open %s: %s\n", __func__, fname, strerror(errno)); - return; - } + tensor_grad_view = lm_ggml_view_4d(ctx, + tensor->grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + nb1, nb2, nb3, offset); + } - // header - { - const uint32_t magic = LM_GGML_FILE_MAGIC; - const uint32_t version = LM_GGML_FILE_VERSION; - const uint32_t n_leafs = cgraph->n_leafs; - const uint32_t n_nodes = cgraph->n_nodes; - - fwrite(&magic, sizeof(uint32_t), 1, fout); - fwrite(&version, sizeof(uint32_t), 1, fout); - fwrite(&n_leafs, sizeof(uint32_t), 1, fout); - fwrite(&n_nodes, sizeof(uint32_t), 1, fout); - fwrite(&size_eval, sizeof(uint64_t), 1, fout); - } + if (src0->grad) { + src0->grad = lm_ggml_add_or_set(ctx, + src0->grad, + lm_ggml_acc_impl(ctx, + tensor->grad, + lm_ggml_neg(ctx, tensor_grad_view), + nb1, nb2, nb3, offset, false), + zero_table, acc_table); + } - // leafs - { - for (int i = 0; i < cgraph->n_leafs; ++i) { - const struct lm_ggml_tensor * tensor = cgraph->leafs[i]; + if (src1->grad) { + src1->grad = + lm_ggml_add_or_set(ctx, + src1->grad, + lm_ggml_reshape(ctx, + lm_ggml_cont(ctx, tensor_grad_view), + src1->grad), + zero_table, acc_table); + } + } break; + case LM_GGML_OP_CPY: + { + // necessary for llama + // cpy overwrites value of src1 by src0 and returns view(src1) + // the overwriting is mathematically equivalent to: + // tensor = src0 * 1 + src1 * 0 + if (src0->grad) { + // dsrc0 = dtensor * 1 + src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); + } + if (src1->grad) { + // dsrc1 = dtensor * 0 -> noop + } + } break; + case LM_GGML_OP_CONT: + { + // same as cpy + if (src0->grad) { + LM_GGML_ASSERT(lm_ggml_is_contiguous(src0->grad)); + LM_GGML_ASSERT(lm_ggml_is_contiguous(tensor->grad)); + src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); + } + } break; + case LM_GGML_OP_RESHAPE: + { + // necessary for llama + if (src0->grad) { + src0->grad = + lm_ggml_add_or_set(ctx, src0->grad, + lm_ggml_reshape(ctx, + lm_ggml_is_contiguous(tensor->grad) + ? tensor->grad + : lm_ggml_cont(ctx, tensor->grad), + src0->grad), + zero_table, acc_table); + } + } break; + case LM_GGML_OP_VIEW: + { + // necessary for llama + if (src0->grad) { + size_t offset; - const uint32_t type = tensor->type; - const uint32_t op = tensor->op; - const int32_t flags = tensor->flags; + memcpy(&offset, tensor->op_params, sizeof(offset)); - fwrite(&type, sizeof(uint32_t), 1, fout); - fwrite(&op, sizeof(uint32_t), 1, fout); - fwrite(&flags, sizeof(int32_t), 1, fout); + size_t nb1 = tensor->nb[1]; + size_t nb2 = tensor->nb[2]; + size_t nb3 = tensor->nb[3]; - for (int j = 0; j < LM_GGML_MAX_DIMS; ++j) { - const uint64_t ne = tensor->ne[j]; - const uint64_t nb = tensor->nb[j]; + if (src0->type != src0->grad->type) { + // gradient is typically F32, but src0 could be other type + size_t ng = lm_ggml_element_size(src0->grad); + size_t n0 = lm_ggml_element_size(src0); + LM_GGML_ASSERT(offset % n0 == 0); + LM_GGML_ASSERT(nb1 % n0 == 0); + LM_GGML_ASSERT(nb2 % n0 == 0); + LM_GGML_ASSERT(nb3 % n0 == 0); + offset = (offset / n0) * ng; + nb1 = (nb1 / n0) * ng; + nb2 = (nb2 / n0) * ng; + nb3 = (nb3 / n0) * ng; + } - fwrite(&ne, sizeof(uint64_t), 1, fout); - fwrite(&nb, sizeof(uint64_t), 1, fout); + src0->grad = lm_ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table, acc_table); } - - fwrite(tensor->name, sizeof(char), LM_GGML_MAX_NAME, fout); - fwrite(tensor->op_params, sizeof(char), LM_GGML_MAX_OP_PARAMS, fout); - - // dump the data - // TODO: pad this to 32 byte boundary - { - const size_t size = lm_ggml_nbytes(tensor); - - fwrite(tensor->data, sizeof(char), size, fout); + } break; + case LM_GGML_OP_PERMUTE: + { + // necessary for llama + if (src0->grad) { + int32_t * axes = (int32_t *) tensor->op_params; + int axis0 = axes[0] & 0x3; + int axis1 = axes[1] & 0x3; + int axis2 = axes[2] & 0x3; + int axis3 = axes[3] & 0x3; + int axes_backward[4] = {0,0,0,0}; + axes_backward[axis0] = 0; + axes_backward[axis1] = 1; + axes_backward[axis2] = 2; + axes_backward[axis3] = 3; + src0->grad = + lm_ggml_add_or_set(ctx, src0->grad, + lm_ggml_permute(ctx, + tensor->grad, + axes_backward[0], + axes_backward[1], + axes_backward[2], + axes_backward[3]), + zero_table, acc_table); + } + } break; + case LM_GGML_OP_TRANSPOSE: + { + // necessary for llama + if (src0->grad) { + src0->grad = + lm_ggml_add_or_set(ctx, src0->grad, + lm_ggml_transpose(ctx, tensor->grad), + zero_table, acc_table); + } + } break; + case LM_GGML_OP_GET_ROWS: + { + // necessary for llama (only for tokenizer) + if (src0->grad) { + src0->grad = + lm_ggml_add_or_set(ctx, src0->grad, + // last lm_ggml_get_rows_back argument src0->grad is only + // necessary to setup correct output shape + lm_ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad), + zero_table, acc_table); + } + if (src1->grad) { + // noop } + } break; + case LM_GGML_OP_GET_ROWS_BACK: + { + LM_GGML_ABORT("fatal error"); // TODO: not implemented } - } - - // nodes - { - for (int i = 0; i < cgraph->n_nodes; ++i) { - const struct lm_ggml_tensor * tensor = cgraph->nodes[i]; + case LM_GGML_OP_DIAG: + { + LM_GGML_ABORT("fatal error"); // TODO: not implemented + } + case LM_GGML_OP_DIAG_MASK_INF: + { + // necessary for llama + if (src0->grad) { + const int n_past = ((int32_t *) tensor->op_params)[0]; + src0->grad = + lm_ggml_add_or_set(ctx, src0->grad, + /* lm_ggml_diag_mask_inf_impl() shouldn't be here */ + /* ref: https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */ + lm_ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false), + zero_table, acc_table); + } + } break; + case LM_GGML_OP_DIAG_MASK_ZERO: + { + // necessary for llama + if (src0->grad) { + const int n_past = ((int32_t *) tensor->op_params)[0]; + src0->grad = + lm_ggml_add_or_set(ctx, src0->grad, + lm_ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false), + zero_table, acc_table); + } + } break; + case LM_GGML_OP_SOFT_MAX: + { + // necessary for llama + if (src0->grad) { + src0->grad = + lm_ggml_add_or_set(ctx, src0->grad, + lm_ggml_soft_max_back(ctx, tensor->grad, tensor), + zero_table, acc_table); + } + LM_GGML_ASSERT((!src1 || !src1->grad) && "backward pass for softmax mask not implemented"); + } break; + case LM_GGML_OP_SOFT_MAX_BACK: + { + LM_GGML_ABORT("fatal error"); // TODO: not implemented + } + case LM_GGML_OP_ROPE: + { + // necessary for llama + if (src0->grad) { + //const int n_past = ((int32_t *) tensor->op_params)[0]; + const int n_dims = ((int32_t *) tensor->op_params)[1]; + const int mode = ((int32_t *) tensor->op_params)[2]; + //const int n_ctx = ((int32_t *) tensor->op_params)[3]; + const int n_ctx_orig = ((int32_t *) tensor->op_params)[4]; + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; - const uint32_t type = tensor->type; - const uint32_t op = tensor->op; - const int32_t flags = tensor->flags; + memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) tensor->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float)); - fwrite(&type, sizeof(uint32_t), 1, fout); - fwrite(&op, sizeof(uint32_t), 1, fout); - fwrite(&flags, sizeof(int32_t), 1, fout); + src0->grad = lm_ggml_add_or_set(ctx, + src0->grad, + lm_ggml_rope_back(ctx, + tensor->grad, + src1, + src2, + n_dims, + mode, + n_ctx_orig, + freq_base, + freq_scale, + ext_factor, + attn_factor, + beta_fast, + beta_slow), + zero_table, acc_table); + } + LM_GGML_ASSERT((!src2 || !src2->grad) && "gradients for freq factors not implemented"); + } break; + case LM_GGML_OP_ROPE_BACK: + { + if (src0->grad) { + //const int n_past = ((int32_t *) tensor->op_params)[0]; + const int n_dims = ((int32_t *) tensor->op_params)[1]; + const int mode = ((int32_t *) tensor->op_params)[2]; + //const int n_ctx = ((int32_t *) tensor->op_params)[3]; + const int n_ctx_orig = ((int32_t *) tensor->op_params)[4]; + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; - for (int j = 0; j < LM_GGML_MAX_DIMS; ++j) { - const uint64_t ne = tensor->ne[j]; - const uint64_t nb = tensor->nb[j]; + memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) tensor->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float)); - fwrite(&ne, sizeof(uint64_t), 1, fout); - fwrite(&nb, sizeof(uint64_t), 1, fout); + src0->grad = lm_ggml_add_or_set(ctx, + src0->grad, + lm_ggml_rope_impl(ctx, + tensor->grad, + src1, + src2, + n_dims, + mode, + n_ctx_orig, + freq_base, + freq_scale, + ext_factor, + attn_factor, + beta_fast, + beta_slow, + false), + zero_table, acc_table); } + } break; + case LM_GGML_OP_CLAMP: + { + LM_GGML_ABORT("fatal error"); // TODO: not implemented + } + case LM_GGML_OP_CONV_TRANSPOSE_1D: + { + LM_GGML_ABORT("fatal error"); // TODO: not implemented + } + case LM_GGML_OP_IM2COL: + { + if (src1->grad) { + const int32_t s0 = lm_ggml_get_op_params_i32(tensor, 0); + const int32_t s1 = lm_ggml_get_op_params_i32(tensor, 1); + const int32_t p0 = lm_ggml_get_op_params_i32(tensor, 2); + const int32_t p1 = lm_ggml_get_op_params_i32(tensor, 3); + const int32_t d0 = lm_ggml_get_op_params_i32(tensor, 4); + const int32_t d1 = lm_ggml_get_op_params_i32(tensor, 5); + const bool is_2D = lm_ggml_get_op_params_i32(tensor, 6) == 1; - fwrite(tensor->name, sizeof(char), LM_GGML_MAX_NAME, fout); - fwrite(tensor->op_params, sizeof(char), LM_GGML_MAX_OP_PARAMS, fout); - - // output the op arguments - { - struct lm_ggml_tensor * args[LM_GGML_MAX_SRC] = { NULL }; - - for (int j = 0; j < LM_GGML_MAX_SRC; ++j) { - args[j] = tensor->src[j]; - } - - for (int j = 0; j < LM_GGML_MAX_SRC; ++j) { - if (args[j]) { - int32_t idx = -1; - - // check if leaf - { - for (int k = 0; k < cgraph->n_leafs; ++k) { - if (args[j] == cgraph->leafs[k]) { - idx = k; - break; - } - } - } - - // check if node - if (idx == -1) { - for (int k = 0; k < cgraph->n_nodes; ++k) { - if (args[j] == cgraph->nodes[k]) { - idx = cgraph->n_leafs + k; - break; - } - } - } - - if (idx == -1) { - fprintf(stderr, "%s: failed to find tensor, arg = %d, node = %d\n", __func__, j, i); - fclose(fout); - return; - } - - fwrite(&idx, sizeof(int32_t), 1, fout); - } else { - const int32_t nul = -1; - - fwrite(&nul, sizeof(int32_t), 1, fout); - } - } + src1->grad = lm_ggml_add_or_set(ctx, + src1->grad, + lm_ggml_im2col_back(ctx, src0, tensor->grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D), + zero_table, acc_table); } + } break; + case LM_GGML_OP_IM2COL_BACK: + { + LM_GGML_ABORT("fatal error"); // TODO: not implemented + } + case LM_GGML_OP_CONV_TRANSPOSE_2D: + { + LM_GGML_ABORT("fatal error"); // TODO: not implemented + } + case LM_GGML_OP_POOL_1D: + { + LM_GGML_ABORT("fatal error"); // TODO: not implemented + } + case LM_GGML_OP_POOL_2D: + { + if (src0->grad) { + const enum lm_ggml_op_pool op = lm_ggml_get_op_params_i32(tensor, 0); + const int32_t k0 = lm_ggml_get_op_params_i32(tensor, 1); + const int32_t k1 = lm_ggml_get_op_params_i32(tensor, 2); + const int32_t s0 = lm_ggml_get_op_params_i32(tensor, 3); + const int32_t s1 = lm_ggml_get_op_params_i32(tensor, 4); + const int32_t p0 = lm_ggml_get_op_params_i32(tensor, 5); + const int32_t p1 = lm_ggml_get_op_params_i32(tensor, 6); - // dump the data - // TODO: pad this to 32 byte boundary - if ((flags & LM_GGML_TENSOR_FLAG_PARAM)) { - const size_t size = lm_ggml_nbytes(tensor); - - fwrite(tensor->data, sizeof(char), size, fout); + src0->grad = lm_ggml_add_or_set(ctx, + src0->grad, + lm_ggml_pool_2d_back(ctx, tensor->grad, src0, op, k0, k1, s0, s1, p0, p1), + zero_table, acc_table); } + } break; + case LM_GGML_OP_POOL_2D_BACK: + { + LM_GGML_ABORT("fatal error"); // TODO: not implemented } - } - - fclose(fout); - } -} - -struct lm_ggml_cgraph * lm_ggml_graph_import(const char * fname, struct lm_ggml_context ** ctx_data, struct lm_ggml_context ** ctx_eval) { - assert(*ctx_data == NULL); - assert(*ctx_eval == NULL); - - struct lm_ggml_cgraph * result = NULL; - - struct lm_ggml_tensor * data = NULL; - - // read file into data - { - FILE * fin = lm_ggml_fopen(fname, "rb"); - if (!fin) { - fprintf(stderr, "%s: failed to open %s: %s\n", __func__, fname, strerror(errno)); - return result; - } - - size_t fsize = 0; - - fseek(fin, 0, SEEK_END); - fsize = ftell(fin); - fseek(fin, 0, SEEK_SET); - - // create the data context - { - const size_t overhead = 1*lm_ggml_tensor_overhead(); - - struct lm_ggml_init_params params = { - .mem_size = fsize + overhead, - .mem_buffer = NULL, - .no_alloc = false, - }; - - *ctx_data = lm_ggml_init(params); - - if (!*ctx_data) { - fprintf(stderr, "%s: failed to create ggml context\n", __func__); - fclose(fin); - return result; + case LM_GGML_OP_UPSCALE: + { + LM_GGML_ABORT("fatal error"); // TODO: not implemented } - } - - data = lm_ggml_new_tensor_1d(*ctx_data, LM_GGML_TYPE_I8, fsize); - - { - const size_t ret = fread(data->data, sizeof(char), fsize, fin); - if (ret != fsize) { - fprintf(stderr, "%s: failed to read %s\n", __func__, fname); - fclose(fin); - return result; + case LM_GGML_OP_PAD: + { + LM_GGML_ABORT("fatal error"); // TODO: not implemented } - } - - fclose(fin); - } - - // populate result - { - char * ptr = (char *) data->data; - - const uint32_t magic = *(const uint32_t *) ptr; ptr += sizeof(magic); - - if (magic != LM_GGML_FILE_MAGIC) { - fprintf(stderr, "%s: invalid magic number, got %08x\n", __func__, magic); - return result; - } - - const uint32_t version = *(const uint32_t *) ptr; ptr += sizeof(version); - - if (version != LM_GGML_FILE_VERSION) { - fprintf(stderr, "%s: invalid version number\n", __func__); - return result; - } - - const uint32_t n_leafs = *(const uint32_t *) ptr; ptr += sizeof(n_leafs); - const uint32_t n_nodes = *(const uint32_t *) ptr; ptr += sizeof(n_nodes); - const uint64_t size_eval = *(const uint64_t *) ptr; ptr += sizeof(size_eval); - const int graph_size = MAX(n_leafs, n_nodes); - - // create the data context - { - const size_t overhead = (n_leafs + n_nodes)*lm_ggml_tensor_overhead() + lm_ggml_graph_overhead_custom(graph_size, false); - - struct lm_ggml_init_params params = { - .mem_size = size_eval + overhead, - .mem_buffer = NULL, - .no_alloc = true, - }; - - *ctx_eval = lm_ggml_init(params); - - if (!*ctx_eval) { - fprintf(stderr, "%s: failed to create ggml context\n", __func__); - return result; + case LM_GGML_OP_ARANGE: + { + LM_GGML_ABORT("fatal error"); // TODO: not implemented } - } - - result = lm_ggml_new_graph_custom(*ctx_eval, graph_size, false); - - result->n_leafs = n_leafs; - result->n_nodes = n_nodes; - - - // leafs - { - uint32_t type; - uint32_t op; - int32_t flags; - - for (uint32_t i = 0; i < n_leafs; ++i) { - type = *(const uint32_t *) ptr; ptr += sizeof(type); - op = *(const uint32_t *) ptr; ptr += sizeof(op); - flags = *(const int32_t *) ptr; ptr += sizeof(flags); - - int64_t ne[LM_GGML_MAX_DIMS]; - size_t nb[LM_GGML_MAX_DIMS]; - - for (int j = 0; j < LM_GGML_MAX_DIMS; ++j) { - uint64_t ne_cur; - uint64_t nb_cur; - - ne_cur = *(const uint64_t *) ptr; ptr += sizeof(ne_cur); - nb_cur = *(const uint64_t *) ptr; ptr += sizeof(nb_cur); - - ne[j] = ne_cur; - nb[j] = nb_cur; + case LM_GGML_OP_TIMESTEP_EMBEDDING: + { + LM_GGML_ABORT("fatal error"); // TODO: not implemented + } + case LM_GGML_OP_ARGSORT: + { + LM_GGML_ABORT("fatal error"); // TODO: not implemented + } + case LM_GGML_OP_LEAKY_RELU: + { + LM_GGML_ABORT("fatal error"); // TODO: not implemented + } + case LM_GGML_OP_FLASH_ATTN_EXT: + { + LM_GGML_ABORT("FA backward pass not adapted after rework"); + struct lm_ggml_tensor * flash_grad = NULL; + if (src0->grad || src1->grad || tensor->src[2]->grad) { + int32_t t = lm_ggml_get_op_params_i32(tensor, 0); + LM_GGML_ASSERT(t == 0 || t == 1); + bool masked = t != 0; + flash_grad = + lm_ggml_flash_attn_back(ctx, + src0, + src1, + tensor->src[2], + tensor->grad, + masked); } - struct lm_ggml_tensor * tensor = lm_ggml_new_tensor(*ctx_eval, (enum lm_ggml_type) type, LM_GGML_MAX_DIMS, ne); + const int64_t elem_q = lm_ggml_nelements(src0); + const int64_t elem_k = lm_ggml_nelements(src1); + const int64_t elem_v = lm_ggml_nelements(src2); - tensor->op = (enum lm_ggml_op) op; - tensor->flags = flags; + enum lm_ggml_type result_type = flash_grad->type; + LM_GGML_ASSERT(lm_ggml_blck_size(result_type) == 1); + const size_t tsize = lm_ggml_type_size(result_type); - memcpy(tensor->name, ptr, LM_GGML_MAX_NAME); ptr += LM_GGML_MAX_NAME; - memcpy(tensor->op_params, ptr, LM_GGML_MAX_OP_PARAMS); ptr += LM_GGML_MAX_OP_PARAMS; + const size_t offs_q = 0; + const size_t offs_k = offs_q + LM_GGML_PAD(elem_q * tsize, LM_GGML_MEM_ALIGN); + const size_t offs_v = offs_k + LM_GGML_PAD(elem_k * tsize, LM_GGML_MEM_ALIGN); - for (int j = 0; j < LM_GGML_MAX_DIMS; ++j) { - tensor->nb[j] = nb[j]; + if (src0->grad) { + struct lm_ggml_tensor * view_q = lm_ggml_view_1d(ctx, flash_grad, elem_q, offs_q); + struct lm_ggml_tensor * grad_q = lm_ggml_reshape(ctx, view_q, src0); + src0->grad = lm_ggml_add_or_set(ctx, + src0->grad, + grad_q, + zero_table, acc_table); } - - tensor->data = (void *) ptr; ptr += lm_ggml_nbytes(tensor); - - result->leafs[i] = tensor; - - fprintf(stderr, "%s: loaded leaf %u: '%16s', %9zu bytes\n", __func__, i, tensor->name, lm_ggml_nbytes(tensor)); - } - } - - lm_ggml_set_no_alloc(*ctx_eval, false); - - // nodes - { - uint32_t type; - uint32_t op; - int32_t flags; - - for (uint32_t i = 0; i < n_nodes; ++i) { - type = *(const uint32_t *) ptr; ptr += sizeof(type); - op = *(const uint32_t *) ptr; ptr += sizeof(op); - flags = *(const int32_t *) ptr; ptr += sizeof(flags); - - enum lm_ggml_op eop = (enum lm_ggml_op) op; - - int64_t ne[LM_GGML_MAX_DIMS]; - size_t nb[LM_GGML_MAX_DIMS]; - - for (int j = 0; j < LM_GGML_MAX_DIMS; ++j) { - uint64_t ne_cur; - uint64_t nb_cur; - - ne_cur = *(const uint64_t *) ptr; ptr += sizeof(ne_cur); - nb_cur = *(const uint64_t *) ptr; ptr += sizeof(nb_cur); - - ne[j] = ne_cur; - nb[j] = nb_cur; + if (src1->grad) { + struct lm_ggml_tensor * view_k = lm_ggml_view_1d(ctx, flash_grad, elem_k, offs_k); + struct lm_ggml_tensor * grad_k = lm_ggml_reshape(ctx, view_k, src1); + src1->grad = lm_ggml_add_or_set(ctx, + src1->grad, + grad_k, + zero_table, acc_table); } - - const char * ptr_name = ptr; ptr += LM_GGML_MAX_NAME; - const char * ptr_op_params = ptr; ptr += LM_GGML_MAX_OP_PARAMS; - - const int32_t * ptr_arg_idx = (const int32_t *) ptr; ptr += LM_GGML_MAX_SRC*sizeof(int32_t); - - struct lm_ggml_tensor * args[LM_GGML_MAX_SRC] = { NULL }; - - // parse args - for (int j = 0; j < LM_GGML_MAX_SRC; ++j) { - const int32_t arg_idx = ptr_arg_idx[j]; - - if (arg_idx == -1) { - continue; - } - - if (arg_idx < result->n_leafs) { - args[j] = result->leafs[arg_idx]; - } else { - args[j] = result->nodes[arg_idx - result->n_leafs]; - } + if (src2->grad) { + struct lm_ggml_tensor * view_v = lm_ggml_view_1d(ctx, flash_grad, elem_v, offs_v); + struct lm_ggml_tensor * grad_v = lm_ggml_reshape(ctx, view_v, src2); + src2->grad = lm_ggml_add_or_set(ctx, + src2->grad, + grad_v, + zero_table, acc_table); } - - // create the tensor - // "view" operations are handled differently - // TODO: handle inplace ops - currently a copy is always made - - struct lm_ggml_tensor * tensor = NULL; - - switch (eop) { - // TODO: implement other view ops - case LM_GGML_OP_RESHAPE: + } break; + case LM_GGML_OP_FLASH_ATTN_BACK: + { + LM_GGML_ABORT("fatal error"); // not supported + } + case LM_GGML_OP_SSM_CONV: + case LM_GGML_OP_SSM_SCAN: + { + LM_GGML_ABORT("fatal error"); // TODO: not implemented + } + case LM_GGML_OP_WIN_PART: + case LM_GGML_OP_WIN_UNPART: + case LM_GGML_OP_UNARY: + { + switch (lm_ggml_get_unary_op(tensor)) { + case LM_GGML_UNARY_OP_ABS: + { + if (src0->grad) { + src0->grad = + lm_ggml_add_or_set(ctx, + src0->grad, + lm_ggml_mul(ctx, + lm_ggml_sgn(ctx, src0), + tensor->grad), + zero_table, acc_table); + } + } break; + case LM_GGML_UNARY_OP_SGN: { - tensor = lm_ggml_reshape_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3]); + if (src0->grad) { + // noop + } } break; - case LM_GGML_OP_VIEW: + case LM_GGML_UNARY_OP_NEG: { - tensor = lm_ggml_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0); - - size_t offs; - memcpy(&offs, ptr_op_params, sizeof(offs)); - - tensor->data = ((char *) tensor->data) + offs; + if (src0->grad) { + src0->grad = lm_ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); + } } break; - case LM_GGML_OP_TRANSPOSE: + case LM_GGML_UNARY_OP_STEP: { - tensor = lm_ggml_transpose(*ctx_eval, args[0]); + if (src0->grad) { + // noop + } } break; - case LM_GGML_OP_PERMUTE: + case LM_GGML_UNARY_OP_TANH: + { + LM_GGML_ABORT("fatal error"); // TODO: not implemented + } + case LM_GGML_UNARY_OP_ELU: + { + LM_GGML_ABORT("fatal error"); // TODO: not implemented + } + case LM_GGML_UNARY_OP_RELU: { - tensor = lm_ggml_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0); + if (src0->grad) { + src0->grad = lm_ggml_add_or_set(ctx, + src0->grad, + lm_ggml_mul(ctx, + lm_ggml_step(ctx, src0), + tensor->grad), + zero_table, acc_table); + } } break; - default: + case LM_GGML_UNARY_OP_SIGMOID: { - tensor = lm_ggml_new_tensor(*ctx_eval, (enum lm_ggml_type) type, LM_GGML_MAX_DIMS, ne); - - tensor->op = eop; + LM_GGML_ABORT("fatal error"); // TODO: not implemented + } + case LM_GGML_UNARY_OP_GELU: + { + LM_GGML_ABORT("fatal error"); // TODO: not implemented + } + case LM_GGML_UNARY_OP_GELU_QUICK: + { + LM_GGML_ABORT("fatal error"); // TODO: not implemented + } + case LM_GGML_UNARY_OP_SILU: + { + // necessary for llama + if (src0->grad) { + src0->grad = lm_ggml_add_or_set(ctx, + src0->grad, + lm_ggml_silu_back(ctx, src0, tensor->grad), + zero_table, acc_table); + } } break; + case LM_GGML_UNARY_OP_EXP: + { + if (src0->grad) { + src0->grad = lm_ggml_add_or_set(ctx, + src0->grad, + lm_ggml_mul(ctx, tensor, tensor->grad), + zero_table, acc_table); + } + } break; + default: + LM_GGML_ABORT("fatal error"); } - - memcpy(tensor->name, ptr_name, LM_GGML_MAX_NAME); - memcpy(tensor->op_params, ptr_op_params, LM_GGML_MAX_OP_PARAMS); - - for (int j = 0; j < LM_GGML_MAX_DIMS; ++j) { - tensor->nb[j] = nb[j]; - } - - for (int j = 0; j < LM_GGML_MAX_SRC; ++j) { - tensor->src[j] = args[j]; - } - - result->nodes[i] = tensor; - - // TODO tensor data is be duplicated due to lm_ggml_new_tensor call above - if (flags & LM_GGML_TENSOR_FLAG_PARAM) { - tensor->data = (void *) ptr; ptr += lm_ggml_nbytes(tensor); + } break; + case LM_GGML_OP_GET_REL_POS: + case LM_GGML_OP_ADD_REL_POS: + case LM_GGML_OP_RWKV_WKV6: + case LM_GGML_OP_MAP_UNARY: + case LM_GGML_OP_MAP_BINARY: + case LM_GGML_OP_MAP_CUSTOM1_F32: + case LM_GGML_OP_MAP_CUSTOM2_F32: + case LM_GGML_OP_MAP_CUSTOM3_F32: + case LM_GGML_OP_MAP_CUSTOM1: + case LM_GGML_OP_MAP_CUSTOM2: + case LM_GGML_OP_MAP_CUSTOM3: + { + LM_GGML_ABORT("fatal error"); // not supported + } + case LM_GGML_OP_CROSS_ENTROPY_LOSS: + { + if (src0->grad) { + src0->grad = lm_ggml_add_or_set(ctx, + src0->grad, + lm_ggml_cross_entropy_loss_back(ctx, + src0, + src1, + tensor->grad), + zero_table, acc_table); } - - fprintf(stderr, "%s: loaded node %u: '%16s', %9zu bytes\n", __func__, i, tensor->name, lm_ggml_nbytes(tensor)); + LM_GGML_ASSERT(!src1->grad && "backward pass for labels not implemented"); + } break; + case LM_GGML_OP_CROSS_ENTROPY_LOSS_BACK: + { + LM_GGML_ABORT("fatal error"); // not supported + } + case LM_GGML_OP_OPT_STEP_ADAMW: + { + LM_GGML_ABORT("fatal error"); // not supported + } + case LM_GGML_OP_NONE: + { + // nop + } break; + case LM_GGML_OP_COUNT: + { + LM_GGML_ABORT("fatal error"); } - } - } - - return result; -} - -void lm_ggml_graph_print(const struct lm_ggml_cgraph * cgraph) { - LM_GGML_LOG_INFO("=== GRAPH ===\n"); - - LM_GGML_LOG_INFO("n_nodes = %d\n", cgraph->n_nodes); - for (int i = 0; i < cgraph->n_nodes; i++) { - struct lm_ggml_tensor * node = cgraph->nodes[i]; - - LM_GGML_LOG_INFO(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s\n", - i, - node->ne[0], node->ne[1], node->ne[2], - lm_ggml_op_name(node->op), (node->flags & LM_GGML_TENSOR_FLAG_PARAM) ? "x" : node->grad ? "g" : " "); - } - - LM_GGML_LOG_INFO("n_leafs = %d\n", cgraph->n_leafs); - for (int i = 0; i < cgraph->n_leafs; i++) { - struct lm_ggml_tensor * node = cgraph->leafs[i]; - - LM_GGML_LOG_INFO(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s %16s\n", - i, - node->ne[0], node->ne[1], - lm_ggml_op_name(node->op), - lm_ggml_get_name(node)); - } - - LM_GGML_LOG_INFO("========================================\n"); -} - -// check if node is part of the graph -static bool lm_ggml_graph_find(const struct lm_ggml_cgraph * cgraph, const struct lm_ggml_tensor * node) { - if (cgraph == NULL) { - return true; } - for (int i = 0; i < cgraph->n_nodes; i++) { - if (cgraph->nodes[i] == node) { - return true; + for (int i = 0; i < LM_GGML_MAX_SRC; ++i) { + if (tensor->src[i] && tensor->src[i]->grad) { + LM_GGML_ASSERT(lm_ggml_are_same_shape(tensor->src[i], tensor->src[i]->grad)); } } - - return false; } -static struct lm_ggml_tensor * lm_ggml_graph_get_parent(const struct lm_ggml_cgraph * cgraph, const struct lm_ggml_tensor * node) { - for (int i = 0; i < cgraph->n_nodes; i++) { - struct lm_ggml_tensor * parent = cgraph->nodes[i]; - - if (parent->grad == node) { - return parent; +static void lm_ggml_visit_parents(struct lm_ggml_cgraph * cgraph, struct lm_ggml_tensor * node) { + if (node->grad == NULL) { + // this usually happens when we generate intermediate nodes from constants in the backward pass + // it can also happen during forward pass, if the user performs computations with constants + if (node->op != LM_GGML_OP_NONE) { + //LM_GGML_PRINT_DEBUG("%s: warning: node %p has no grad, but op %d\n", __func__, (void *) node, node->op); } } - return NULL; -} - -static void lm_ggml_graph_dump_dot_node_edge(FILE * fp, const struct lm_ggml_cgraph * gb, struct lm_ggml_tensor * node, struct lm_ggml_tensor * parent, const char * label) { - struct lm_ggml_tensor * gparent = lm_ggml_graph_get_parent(gb, node); - struct lm_ggml_tensor * gparent0 = lm_ggml_graph_get_parent(gb, parent); - fprintf(fp, " \"%p\":%s -> \"%p\":%s [ arrowhead = %s; style = %s; label = \"%s\"; ]\n", - gparent0 ? (void *) gparent0 : (void *) parent, - gparent0 ? "g" : "x", - gparent ? (void *) gparent : (void *) node, - gparent ? "g" : "x", - gparent ? "empty" : "vee", - gparent ? "dashed" : "solid", - label); -} - -static void lm_ggml_graph_dump_dot_leaf_edge(FILE * fp, struct lm_ggml_tensor * node, struct lm_ggml_tensor * parent, const char * label) { - fprintf(fp, " \"%p\":%s -> \"%p\":%s [ label = \"%s\"; ]\n", - (void *) parent, "x", - (void *) node, "x", - label); -} - -void lm_ggml_graph_dump_dot(const struct lm_ggml_cgraph * gb, const struct lm_ggml_cgraph * gf, const char * filename) { - char color[16]; - - FILE * fp = lm_ggml_fopen(filename, "w"); - LM_GGML_ASSERT(fp); - - fprintf(fp, "digraph G {\n"); - fprintf(fp, " newrank = true;\n"); - fprintf(fp, " rankdir = TB;\n"); - - for (int i = 0; i < gb->n_nodes; i++) { - struct lm_ggml_tensor * node = gb->nodes[i]; - - if (lm_ggml_graph_get_parent(gb, node) != NULL) { - continue; - } + // check if already visited + if (lm_ggml_hash_insert(&cgraph->visited_hash_set, node) == LM_GGML_HASHSET_ALREADY_EXISTS) { + return; + } - if (node->flags & LM_GGML_TENSOR_FLAG_PARAM) { - snprintf(color, sizeof(color), "yellow"); - } else if (node->grad) { - if (lm_ggml_graph_find(gf, node)) { - snprintf(color, sizeof(color), "green"); - } else { - snprintf(color, sizeof(color), "lightblue"); - } - } else { - snprintf(color, sizeof(color), "white"); + for (int i = 0; i < LM_GGML_MAX_SRC; ++i) { + const int k = + (cgraph->order == LM_GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i : + (cgraph->order == LM_GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? (LM_GGML_MAX_SRC-1-i) : + /* unknown order, just fall back to using i*/ i; + if (node->src[k]) { + lm_ggml_visit_parents(cgraph, node->src[k]); } + } - fprintf(fp, " \"%p\" [ " - "style = filled; fillcolor = %s; shape = record; " - "label=\"", - (void *) node, color); - - if (strlen(node->name) > 0) { - fprintf(fp, "%s (%s)|", node->name, lm_ggml_type_name(node->type)); - } else { - fprintf(fp, "(%s)|", lm_ggml_type_name(node->type)); - } + if (node->op == LM_GGML_OP_NONE && !(node->flags & LM_GGML_TENSOR_FLAG_PARAM)) { + // reached a leaf node, not part of the gradient graph (e.g. a constant) + LM_GGML_ASSERT(cgraph->n_leafs < cgraph->size); - if (lm_ggml_is_matrix(node)) { - fprintf(fp, "%d [%" PRId64 ", %" PRId64 "] | %s", i, node->ne[0], node->ne[1], lm_ggml_op_symbol(node->op)); - } else { - fprintf(fp, "%d [%" PRId64 ", %" PRId64 ", %" PRId64 "] | %s", i, node->ne[0], node->ne[1], node->ne[2], lm_ggml_op_symbol(node->op)); + if (strlen(node->name) == 0) { + lm_ggml_format_name(node, "leaf_%d", cgraph->n_leafs); } - if (node->grad) { - fprintf(fp, " | %s\"; ]\n", lm_ggml_op_symbol(node->grad->op)); - } else { - fprintf(fp, "\"; ]\n"); + cgraph->leafs[cgraph->n_leafs] = node; + cgraph->n_leafs++; + } else { + LM_GGML_ASSERT(cgraph->n_nodes < cgraph->size); + + if (strlen(node->name) == 0) { + lm_ggml_format_name(node, "node_%d", cgraph->n_nodes); } + + cgraph->nodes[cgraph->n_nodes] = node; + cgraph->n_nodes++; } +} - for (int i = 0; i < gb->n_leafs; i++) { - struct lm_ggml_tensor * node = gb->leafs[i]; +static void lm_ggml_build_forward_impl(struct lm_ggml_cgraph * cgraph, struct lm_ggml_tensor * tensor, bool expand) { + if (!expand) { + // TODO: this branch isn't accessible anymore, maybe move this to lm_ggml_build_forward_expand + lm_ggml_graph_clear(cgraph); + } - snprintf(color, sizeof(color), "pink"); + const int n0 = cgraph->n_nodes; - fprintf(fp, " \"%p\" [ " - "style = filled; fillcolor = %s; shape = record; " - "label=\"", - (void *) node, color); + lm_ggml_visit_parents(cgraph, tensor); - if (strlen(node->name) > 0) { - fprintf(fp, "%s (%s)|", node->name, lm_ggml_type_name(node->type)); - } else { - fprintf(fp, "(%s)|", lm_ggml_type_name(node->type)); - } + const int n_new = cgraph->n_nodes - n0; + LM_GGML_PRINT_DEBUG("%s: visited %d new nodes\n", __func__, n_new); - fprintf(fp, "CONST %d [%" PRId64 ", %" PRId64 "]", i, node->ne[0], node->ne[1]); - if (lm_ggml_nelements(node) < 5 && node->data != NULL) { - fprintf(fp, " | ("); - for (int j = 0; j < lm_ggml_nelements(node); j++) { - if (node->type == LM_GGML_TYPE_I8 || node->type == LM_GGML_TYPE_I16 || node->type == LM_GGML_TYPE_I32) { - fprintf(fp, "%d", lm_ggml_get_i32_1d(node, j)); - } - else if (node->type == LM_GGML_TYPE_F32 || - node->type == LM_GGML_TYPE_F16 || - node->type == LM_GGML_TYPE_BF16) { - fprintf(fp, "%.1e", (double)lm_ggml_get_f32_1d(node, j)); - } - else { - fprintf(fp, "#"); - } - if (j < lm_ggml_nelements(node) - 1) { - fprintf(fp, ", "); - } - } - fprintf(fp, ")"); - } - fprintf(fp, "\"; ]\n"); + if (n_new > 0) { + // the last added node should always be starting point + LM_GGML_ASSERT(cgraph->nodes[cgraph->n_nodes - 1] == tensor); } +} - for (int i = 0; i < gb->n_nodes; i++) { - struct lm_ggml_tensor * node = gb->nodes[i]; +void lm_ggml_build_forward_expand(struct lm_ggml_cgraph * cgraph, struct lm_ggml_tensor * tensor) { + lm_ggml_build_forward_impl(cgraph, tensor, true); +} - for (int j = 0; j < LM_GGML_MAX_SRC; j++) { - if (node->src[j]) { - char label[16]; - snprintf(label, sizeof(label), "src %d", j); - lm_ggml_graph_dump_dot_node_edge(fp, gb, node, node->src[j], label); - } +void lm_ggml_build_backward_expand(struct lm_ggml_context * ctx, struct lm_ggml_cgraph * gf, struct lm_ggml_cgraph * gb, bool accumulate) { + LM_GGML_ASSERT(gf->n_nodes > 0); + LM_GGML_ASSERT(gf->grads); + + for (int i = 0; i < gf->n_nodes; ++i) { + struct lm_ggml_tensor * node = gf->nodes[i]; + + if (node->type == LM_GGML_TYPE_I32) { + continue; } - } - for (int i = 0; i < gb->n_leafs; i++) { - struct lm_ggml_tensor * node = gb->leafs[i]; + bool needs_grad = node->flags & LM_GGML_TENSOR_FLAG_PARAM; + bool ignore_src[LM_GGML_MAX_SRC] = {false}; + switch (node->op) { + // gradients in node->src[0] for one reason or another have no effect on output gradients + case LM_GGML_OP_IM2COL: // only used for its shape + case LM_GGML_OP_IM2COL_BACK: // same as IM2COL + ignore_src[0] = true; + break; + case LM_GGML_OP_UNARY: { + const enum lm_ggml_unary_op uop = lm_ggml_get_unary_op(node); + // SGN and STEP unary ops are piecewise constant + if (uop == LM_GGML_UNARY_OP_SGN || uop == LM_GGML_UNARY_OP_STEP) { + ignore_src[0] = true; + } + } break; - for (int j = 0; j < LM_GGML_MAX_SRC; j++) { - if (node->src[j]) { - char label[16]; - snprintf(label, sizeof(label), "src %d", j); - lm_ggml_graph_dump_dot_leaf_edge(fp, node, node->src[j], label); + // gradients in node->src[1] for one reason or another have no effect on output gradients + case LM_GGML_OP_CPY: // gradients in CPY target are irrelevant + case LM_GGML_OP_GET_ROWS: // row indices not differentiable + case LM_GGML_OP_GET_ROWS_BACK: // same as for GET_ROWS + case LM_GGML_OP_ROPE: // positions not differentiable + ignore_src[1] = true; + break; + + default: + break; + } + for (int j = 0; j < LM_GGML_MAX_SRC; ++j) { + if (!node->src[j] || !node->src[j]->grad || ignore_src[j]) { + continue; } + LM_GGML_ASSERT(node->src[j]->type == LM_GGML_TYPE_F32 || node->src[j]->type == LM_GGML_TYPE_F16); + needs_grad = true; + break; + } + if (!needs_grad) { + continue; } - } - fprintf(fp, "}\n"); + // inplace operations are currently not supported + LM_GGML_ASSERT(!node->view_src || node->op == LM_GGML_OP_CPY || node->op == LM_GGML_OP_VIEW || + node->op == LM_GGML_OP_RESHAPE || node->op == LM_GGML_OP_PERMUTE || node->op == LM_GGML_OP_TRANSPOSE); - fclose(fp); + // create a new tensor with the same type and shape as the node and set it as grad + node->grad = lm_ggml_dup_tensor(ctx, node); + } - LM_GGML_LOG_INFO("%s: dot -Tpng %s -o %s.png && open %s.png\n", __func__, filename, filename, filename); -} + // keep tables of original gradients for replacement/accumulation logic + struct lm_ggml_hash_set zero_table = lm_ggml_hash_set_new(gf->size); + struct lm_ggml_hash_set acc_table = lm_ggml_hash_set_new(gf->size); + for (int i = 0; i < gf->n_nodes; i++) { + struct lm_ggml_tensor * node = gf->nodes[i]; -//////////////////////////////////////////////////////////////////////////////// + if (node->grad) { + { + const size_t insert_result = lm_ggml_hash_insert(&zero_table, node->grad); + LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL); + LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS); + } -static void lm_ggml_opt_set_params(int np, struct lm_ggml_tensor * const ps[], const float * x) { - int i = 0; - for (int p = 0; p < np; ++p) { - const int64_t ne = lm_ggml_nelements(ps[p]) ; - // TODO: add function to set tensor from array - for (int64_t j = 0; j < ne; ++j) { - lm_ggml_set_f32_1d(ps[p], j, x[i++]); + // only gradients of trainable parameters should be accumulated + if (accumulate && (node->flags & LM_GGML_TENSOR_FLAG_PARAM)) { + const size_t insert_result = lm_ggml_hash_insert(&acc_table, node->grad); + LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL); + LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS); + } } } -} -static void lm_ggml_opt_get_params(int np, struct lm_ggml_tensor * const ps[], float * x) { - int i = 0; - for (int p = 0; p < np; ++p) { - const int64_t ne = lm_ggml_nelements(ps[p]) ; - // TODO: add function to get all elements at once - for (int64_t j = 0; j < ne; ++j) { - x[i++] = lm_ggml_get_f32_1d(ps[p], j); - } - } -} + for (int i = gf->n_nodes - 1; i >= 0; i--) { + struct lm_ggml_tensor * node = gf->nodes[i]; -static void lm_ggml_opt_get_grad(int np, struct lm_ggml_tensor * const ps[], float * g) { - int64_t i = 0; - for (int p = 0; p < np; ++p) { - const int64_t ne = lm_ggml_nelements(ps[p]) ; - // TODO: add function to get all elements at once - for (int64_t j = 0; j < ne; ++j) { - g[i++] = lm_ggml_get_f32_1d(ps[p]->grad, j); + // inplace operations to add gradients are not created by lm_ggml_compute_backward except for gradient accumulation + // use allocator to automatically make inplace operations + if (node->grad) { + lm_ggml_compute_backward(ctx, node, &zero_table, &acc_table); } } -} -static void lm_ggml_opt_acc_grad(int np, struct lm_ggml_tensor * const ps[], float * g, float scale) { - int64_t i = 0; - for (int p = 0; p < np; ++p) { - const int64_t ne = lm_ggml_nelements(ps[p]) ; - // TODO: add function to get all elements at once - for (int64_t j = 0; j < ne; ++j) { - g[i++] += lm_ggml_get_f32_1d(ps[p]->grad, j) * scale; + for (int i = 0; i < gf->n_nodes; i++) { + struct lm_ggml_tensor * node = gf->nodes[i]; + + if (node->flags & LM_GGML_TENSOR_FLAG_PARAM) { + LM_GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node); + lm_ggml_build_forward_expand(gb, node->grad); } } -} -// -// Using AdamW - ref: https://arxiv.org/pdf/1711.05101v3.pdf -// -// (Original Adam - ref: https://arxiv.org/pdf/1412.6980.pdf) -// + lm_ggml_hash_set_free(&zero_table); + lm_ggml_hash_set_free(&acc_table); +} -static enum lm_ggml_opt_result lm_ggml_opt_adam( +void lm_ggml_build_opt_adamw( struct lm_ggml_context * ctx, - struct lm_ggml_opt_context * opt, - struct lm_ggml_opt_params params, - struct lm_ggml_tensor * f, - struct lm_ggml_cgraph * gf, - struct lm_ggml_cgraph * gb, - lm_ggml_opt_callback callback, - void * callback_data) { - LM_GGML_ASSERT(lm_ggml_is_scalar(f)); - LM_GGML_ASSERT(f->type == LM_GGML_TYPE_F32); - - // these will store the parameters we want to optimize - struct lm_ggml_tensor * ps[LM_GGML_MAX_PARAMS]; - - int np = 0; - int64_t nx = 0; - for (int i = 0; i < gf->n_nodes; ++i) { - if (gf->nodes[i]->flags & LM_GGML_TENSOR_FLAG_PARAM) { - LM_GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op); - - LM_GGML_ASSERT(np < LM_GGML_MAX_PARAMS); + struct lm_ggml_cgraph * gf, + struct lm_ggml_cgraph * gb, + float alpha, + float beta1, + float beta2, + float eps, + float wd) { + for (int i = 0; i < gf->n_nodes; i++) { + struct lm_ggml_tensor * node = gf->nodes[i]; - ps[np++] = gf->nodes[i]; - nx += lm_ggml_nelements(gf->nodes[i]); + if (node->flags & LM_GGML_TENSOR_FLAG_PARAM) { + LM_GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node); + struct lm_ggml_tensor * opt_step = lm_ggml_opt_step_adamw(ctx, node, node->grad, alpha, beta1, beta2, eps, wd); + lm_ggml_build_forward_expand(gb, opt_step); } } +} - if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past)) { - int iter = opt->iter; - lm_ggml_opt_init(opt->ctx, opt, params, nx); - opt->iter = iter; - } - - // constants - float sched = params.adam.sched; - const float alpha = params.adam.alpha; - const float decay = params.adam.decay * alpha; - const float beta1 = params.adam.beta1; - const float beta2 = params.adam.beta2; - const float eps = params.adam.eps; - const float gclip = params.adam.gclip; - const int decay_min_ndim = params.adam.decay_min_ndim; - const int n_accum = MAX(1, params.n_gradient_accumulation); - const float accum_norm = 1.0f / (float) n_accum; - - float * g = opt->adam.g->data; // gradients - float * m = opt->adam.m->data; // first moment - float * v = opt->adam.v->data; // second moment - - float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values - - struct lm_ggml_cplan cplan = lm_ggml_graph_plan(gb, params.n_threads, NULL); - struct lm_ggml_object * obj = lm_ggml_new_object(ctx, LM_GGML_OBJECT_TYPE_WORK_BUFFER, cplan.work_size); - cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs; - - bool cancel = false; - - // compute the function value - float fx = 0; - lm_ggml_set_zero(opt->adam.g); - for (int accum_step = 0; accum_step < n_accum; ++accum_step) { - if (callback) { - callback(callback_data, accum_step, &sched, &cancel); - if (cancel) { - return LM_GGML_OPT_RESULT_CANCEL; - } - } - // lm_ggml_graph_reset (gf); - lm_ggml_set_f32 (f->grad, 1.0f); - lm_ggml_graph_compute(gb, &cplan); - lm_ggml_opt_acc_grad(np, ps, g, accum_norm); - fx += lm_ggml_get_f32_1d(f, 0); - } - fx *= accum_norm; +static void * incr_ptr_aligned(void ** p, size_t size, size_t align) { + void * ptr = *p; + ptr = (void *) LM_GGML_PAD((uintptr_t) ptr, align); + *p = (void *) ((char *) ptr + size); + return ptr; +} - opt->adam.fx_prev = fx; - opt->adam.fx_best = opt->adam.fx_prev; - if (pf) { - pf[opt->iter % params.past] = opt->adam.fx_prev; +static size_t lm_ggml_graph_nbytes(size_t size, bool grads) { + size_t hash_size = lm_ggml_hash_size(size * 2); + void * p = 0; + incr_ptr_aligned(&p, sizeof(struct lm_ggml_cgraph), 1); + incr_ptr_aligned(&p, size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *)); // nodes + incr_ptr_aligned(&p, size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *)); // leafs + incr_ptr_aligned(&p, hash_size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *)); // hash keys + if (grads) { + incr_ptr_aligned(&p, size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *)); // grads } + incr_ptr_aligned(&p, lm_ggml_bitset_size(hash_size) * sizeof(lm_ggml_bitset_t), sizeof(lm_ggml_bitset_t)); - opt->loss_before = opt->adam.fx_prev; - opt->loss_after = opt->adam.fx_prev; - - // initialize - if (opt->just_initialized) { - opt->adam.n_no_improvement = 0; - opt->just_initialized = false; - } + size_t nbytes = (size_t) p; + return nbytes; +} - float * fx_best = &opt->adam.fx_best; - float * fx_prev = &opt->adam.fx_prev; - int * n_no_improvement = &opt->adam.n_no_improvement; +size_t lm_ggml_graph_overhead_custom(size_t size, bool grads) { + return LM_GGML_OBJECT_SIZE + LM_GGML_PAD(lm_ggml_graph_nbytes(size, grads), LM_GGML_MEM_ALIGN); +} - int iter0 = opt->iter; +size_t lm_ggml_graph_overhead(void) { + return lm_ggml_graph_overhead_custom(LM_GGML_DEFAULT_GRAPH_SIZE, false); +} - // run the optimizer - for (int t = 0; t < params.adam.n_iter; ++t) { - opt->iter = iter0 + t + 1; - LM_GGML_PRINT_DEBUG ("=== iter %d ===\n", t); +struct lm_ggml_cgraph * lm_ggml_new_graph_custom(struct lm_ggml_context * ctx, size_t size, bool grads) { + const size_t obj_size = lm_ggml_graph_nbytes(size, grads); + struct lm_ggml_object * obj = lm_ggml_new_object(ctx, LM_GGML_OBJECT_TYPE_GRAPH, obj_size); + struct lm_ggml_cgraph * cgraph = (struct lm_ggml_cgraph *) ((char *) ctx->mem_buffer + obj->offs); - LM_GGML_PRINT_DEBUG ("f = %10.6f\n", lm_ggml_get_f32_1d(f, 0)); - LM_GGML_PRINT_DEBUG_5("df/dx0 = %10.6f\n", lm_ggml_get_f32_1d(ps[0]->grad, 0)); - LM_GGML_PRINT_DEBUG_5("df/dx1 = %10.6f\n", lm_ggml_get_f32_1d(ps[1]->grad, 0)); + // the size of the hash table is doubled since it needs to hold both nodes and leafs + size_t hash_size = lm_ggml_hash_size(size * 2); - for (int i = 0; i < np; ++i) { - LM_GGML_PRINT_DEBUG("param %d: %10.6f, g = %10.6f\n", i, - lm_ggml_get_f32_1d(ps[i], 0), lm_ggml_get_f32_1d(ps[i]->grad, 0)); - } + void * p = cgraph + 1; - const int64_t t_start_wall = lm_ggml_time_us(); - const int64_t t_start_cpu = lm_ggml_cycles(); - UNUSED(t_start_wall); - UNUSED(t_start_cpu); - - { - float gnorm = 1.0f; - if (gclip > 0.0f) { - // gradient clipping - lm_ggml_float sum = 0.0; - for (int64_t i = 0; i < nx; ++i) { - sum += (lm_ggml_float)(g[i]*g[i]); - } - lm_ggml_float norm = sqrt(sum); - if (norm > (lm_ggml_float) gclip) { - gnorm = (float) ((lm_ggml_float) gclip / norm); - } - } - const float beta1h = alpha*sched/(1.0f - powf(beta1, opt->iter)); - const float beta2h = 1.0f/(1.0f - powf(beta2, opt->iter)); - int64_t i = 0; - for (int p = 0; p < np; ++p) { - const int64_t ne = lm_ggml_nelements(ps[p]); - const float p_decay = ((lm_ggml_n_dims(ps[p]) >= decay_min_ndim) ? decay : 0.0f) * sched; - for (int64_t j = 0; j < ne; ++j) { - float x = lm_ggml_get_f32_1d(ps[p], j); - float g_ = g[i]*gnorm; - m[i] = m[i]*beta1 + g_*(1.0f - beta1); - v[i] = v[i]*beta2 + g_*g_*(1.0f - beta2); - float mh = m[i]*beta1h; - float vh = v[i]*beta2h; - vh = sqrtf(vh) + eps; - x = x*(1.0f - p_decay) - mh/vh; - lm_ggml_set_f32_1d(ps[p], j, x); - ++i; - } - } - } + struct lm_ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *)); + struct lm_ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *)); + struct lm_ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *)); + struct lm_ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *)) : NULL; + lm_ggml_bitset_t * hash_used = incr_ptr_aligned(&p, lm_ggml_bitset_size(hash_size) * sizeof(lm_ggml_bitset_t), sizeof(lm_ggml_bitset_t)); - fx = 0; - lm_ggml_set_zero(opt->adam.g); - for (int accum_step = 0; accum_step < n_accum; ++accum_step) { - if (callback) { - callback(callback_data, accum_step, &sched, &cancel); - if (cancel) { - return LM_GGML_OPT_RESULT_CANCEL;; - } - } - // lm_ggml_graph_reset (gf); - lm_ggml_set_f32 (f->grad, 1.0f); - lm_ggml_graph_compute(gb, &cplan); - lm_ggml_opt_acc_grad(np, ps, g, accum_norm); - fx += lm_ggml_get_f32_1d(f, 0); - } - fx *= accum_norm; + // check that we allocated the correct amount of memory + assert(obj_size == (size_t)((char *)p - (char *)cgraph)); - opt->loss_after = fx; + *cgraph = (struct lm_ggml_cgraph) { + /*.size =*/ size, + /*.n_nodes =*/ 0, + /*.n_leafs =*/ 0, + /*.nodes =*/ nodes_ptr, + /*.grads =*/ grads_ptr, + /*.leafs =*/ leafs_ptr, + /*.hash_table =*/ { hash_size, hash_used, hash_keys_ptr }, + /*.order =*/ LM_GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT, + }; - // check convergence - if (fabsf(fx - fx_prev[0])/fx < params.adam.eps_f) { - LM_GGML_PRINT_DEBUG("converged\n"); + lm_ggml_hash_set_reset(&cgraph->visited_hash_set); - return LM_GGML_OPT_RESULT_OK; - } + return cgraph; +} - // delta-based convergence test - if (pf != NULL) { - // need at least params.past iterations to start checking for convergence - if (params.past <= iter0 + t) { - const float rate = (pf[(iter0 + t)%params.past] - fx)/fx; +struct lm_ggml_cgraph * lm_ggml_new_graph(struct lm_ggml_context * ctx) { + return lm_ggml_new_graph_custom(ctx, LM_GGML_DEFAULT_GRAPH_SIZE, false); +} - if (fabsf(rate) < params.delta) { - return LM_GGML_OPT_RESULT_OK; - } - } +struct lm_ggml_cgraph lm_ggml_graph_view(struct lm_ggml_cgraph * cgraph0, int i0, int i1) { + struct lm_ggml_cgraph cgraph = { + /*.size =*/ 0, + /*.n_nodes =*/ i1 - i0, + /*.n_leafs =*/ 0, + /*.nodes =*/ cgraph0->nodes + i0, + /*.grads =*/ cgraph0->grads ? cgraph0->grads + i0 : NULL, + /*.leafs =*/ NULL, + /*.hash_table =*/ { 0, NULL, NULL }, + /*.order =*/ cgraph0->order, + }; - pf[(iter0 + t)%params.past] = fx; - } + return cgraph; +} - // check for improvement - if (params.max_no_improvement > 0) { - if (fx_best[0] > fx) { - fx_best[0] = fx; - n_no_improvement[0] = 0; - } else { - ++n_no_improvement[0]; +void lm_ggml_graph_cpy(struct lm_ggml_cgraph * src, struct lm_ggml_cgraph * dst) { + LM_GGML_ASSERT(dst->size >= src->n_leafs); + LM_GGML_ASSERT(dst->size >= src->n_nodes); + LM_GGML_ASSERT(dst->visited_hash_set.size >= src->visited_hash_set.size); - if (n_no_improvement[0] >= params.max_no_improvement) { - return LM_GGML_OPT_RESULT_OK; - } - } - } + dst->n_leafs = src->n_leafs; + dst->n_nodes = src->n_nodes; + dst->order = src->order; - fx_prev[0] = fx; + for (int i = 0; i < src->n_leafs; ++i) { + dst->leafs[i] = src->leafs[i]; + } - { - const int64_t t_end_cpu = lm_ggml_cycles(); - LM_GGML_PRINT_DEBUG("time iter: %5.3f s\n", ((float)(t_end_cpu - t_start_cpu))/CLOCKS_PER_SEC); - UNUSED(t_end_cpu); + for (int i = 0; i < src->n_nodes; ++i) { + dst->nodes[i] = src->nodes[i]; + } - const int64_t t_end_wall = lm_ggml_time_us(); - LM_GGML_PRINT_DEBUG("wall time iter: %5.3f s\n", (t_end_wall - t_start_wall)/1e6); - UNUSED(t_end_wall); + if (src->grads) { + LM_GGML_ASSERT(dst->grads != NULL); + for (int i = 0; i < src->n_nodes; ++i) { + dst->grads[i] = src->grads[i]; } } - return LM_GGML_OPT_RESULT_DID_NOT_CONVERGE; + for (size_t i = 0; i < src->visited_hash_set.size; ++i) { + // copy all hashset keys (tensors) that are in use + if (lm_ggml_bitset_get(src->visited_hash_set.used, i)) { + lm_ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]); + } + } } -// -// L-BFGS -// -// the L-BFGS implementation below is based on the following implementation: -// -// https://github.com/chokkan/liblbfgs -// - -struct lm_ggml_lbfgs_iteration_data { - float alpha; - float ys; - float * s; - float * y; -}; +struct lm_ggml_cgraph * lm_ggml_graph_dup(struct lm_ggml_context * ctx, struct lm_ggml_cgraph * cgraph) { + struct lm_ggml_cgraph * result = lm_ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads != NULL); + lm_ggml_graph_cpy(cgraph, result); + return result; +} -static enum lm_ggml_opt_result linesearch_backtracking( - const struct lm_ggml_opt_params * params, - int nx, - float * x, - float * fx, - float * g, - float * d, - float * step, - const float * xp, - struct lm_ggml_tensor * f, - struct lm_ggml_cgraph * gb, - struct lm_ggml_cplan * cplan, - const int np, - struct lm_ggml_tensor * ps[], - bool * cancel, - lm_ggml_opt_callback callback, - void * callback_data) { - int count = 0; - - float width = 0.0f; - float dg = 0.0f; - float finit = 0.0f; - float dginit = 0.0f; - float dgtest = 0.0f; - - const float dec = 0.5f; - const float inc = 2.1f; - - const int n_accum = MAX(1, params->n_gradient_accumulation); - const float accum_norm = 1.0f / (float) n_accum; - - if (*step <= 0.f) { - return LM_GGML_LINESEARCH_INVALID_PARAMETERS; +struct lm_ggml_tensor * lm_ggml_set_zero(struct lm_ggml_tensor * tensor) { + if (lm_ggml_is_empty(tensor)) { + return tensor; } - - // compute the initial gradient in the search direction - lm_ggml_vec_dot_f32(nx, &dginit, 0, g, 0, d, 0, 1); - - // make sure that d points to a descent direction - if (0 < dginit) { - return LM_GGML_LINESEARCH_FAIL; + if (tensor->buffer) { + lm_ggml_backend_tensor_memset(tensor, 0, 0, lm_ggml_nbytes(tensor)); + } else { + LM_GGML_ASSERT(tensor->data); + memset(tensor->data, 0, lm_ggml_nbytes(tensor)); } + return tensor; +} - // initialize local variables - finit = *fx; - dgtest = params->lbfgs.ftol*dginit; - - while (true) { - lm_ggml_vec_cpy_f32(nx, x, xp); - lm_ggml_vec_mad_f32(nx, x, d, *step); - - // evaluate the function and gradient values - { - lm_ggml_opt_set_params(np, ps, x); - - *fx = 0; - memset(g, 0, sizeof(float)*nx); - for (int accum_step = 0; accum_step < n_accum; ++accum_step) { - if (callback) { - // LBFG-S does not support learning rate -> ignore learning schedule - float sched = 0; - callback(callback_data, accum_step, &sched, cancel); - if (*cancel) { - return LM_GGML_OPT_RESULT_CANCEL; - } - } - // lm_ggml_graph_reset (gf); - lm_ggml_set_f32 (f->grad, 1.0f); - lm_ggml_graph_compute(gb, cplan); - lm_ggml_opt_acc_grad(np, ps, g, accum_norm); - *fx += lm_ggml_get_f32_1d(f, 0); - } - *fx *= accum_norm; - - } - - ++count; +void lm_ggml_graph_reset(struct lm_ggml_cgraph * cgraph) { + LM_GGML_ASSERT(cgraph->grads != NULL); - if (*fx > finit + (*step)*dgtest) { - width = dec; - } else { - // Armijo condition is satisfied - if (params->lbfgs.linesearch == LM_GGML_LINESEARCH_BACKTRACKING_ARMIJO) { - return count; - } + for (int i = 0; i < cgraph->n_nodes; i++) { + struct lm_ggml_tensor * node = cgraph->nodes[i]; - lm_ggml_vec_dot_f32(nx, &dg, 0, g, 0, d, 0, 1); + // initial gradients of loss should be 1, 0 otherwise + if (node->grad) { + if (node->flags & LM_GGML_TENSOR_FLAG_LOSS) { + LM_GGML_ASSERT(node->grad->buffer); + LM_GGML_ASSERT(node->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT(lm_ggml_is_scalar(node)); - // check the Wolfe condition - if (dg < params->lbfgs.wolfe * dginit) { - width = inc; + const float onef = 1.0f; + lm_ggml_backend_tensor_set(node->grad, &onef, 0, lm_ggml_nbytes(node->grad)); } else { - if(params->lbfgs.linesearch == LM_GGML_LINESEARCH_BACKTRACKING_WOLFE) { - // regular Wolfe conditions - return count; - } - - if(dg > -params->lbfgs.wolfe*dginit) { - width = dec; - } else { - // strong Wolfe condition (LM_GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) - return count; - } + lm_ggml_set_zero(node->grad); } } - if (*step < params->lbfgs.min_step) { - return LM_GGML_LINESEARCH_MINIMUM_STEP; - } - if (*step > params->lbfgs.max_step) { - return LM_GGML_LINESEARCH_MAXIMUM_STEP; - } - if (params->lbfgs.max_linesearch <= count) { - return LM_GGML_LINESEARCH_MAXIMUM_ITERATIONS; + LM_GGML_ASSERT(node); + if (node->op == LM_GGML_OP_OPT_STEP_ADAMW) { + // set iteration to 1 and clear momenta + lm_ggml_set_op_params_i32(node, 0, 1); + lm_ggml_set_zero(node->src[2]); + lm_ggml_set_zero(node->src[3]); } - - (*step) *= width; } +} - LM_GGML_ABORT("line search failed"); +void lm_ggml_graph_clear(struct lm_ggml_cgraph * cgraph) { + cgraph->n_leafs = 0; + cgraph->n_nodes = 0; + lm_ggml_hash_set_reset(&cgraph->visited_hash_set); +} - //return LM_GGML_LINESEARCH_FAIL; +int lm_ggml_graph_size(struct lm_ggml_cgraph * cgraph) { + return cgraph->size; } -static enum lm_ggml_opt_result lm_ggml_opt_lbfgs( - struct lm_ggml_context * ctx, - struct lm_ggml_opt_context * opt, - struct lm_ggml_opt_params params, - struct lm_ggml_tensor * f, - struct lm_ggml_cgraph * gf, - struct lm_ggml_cgraph * gb, - lm_ggml_opt_callback callback, - void * callback_data) { - if (params.lbfgs.linesearch == LM_GGML_LINESEARCH_BACKTRACKING_WOLFE || - params.lbfgs.linesearch == LM_GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) { - if (params.lbfgs.wolfe <= params.lbfgs.ftol || 1.f <= params.lbfgs.wolfe) { - return LM_GGML_OPT_RESULT_INVALID_WOLFE; - } +struct lm_ggml_tensor * lm_ggml_graph_node(struct lm_ggml_cgraph * cgraph, int i) { + if (i < 0) { + LM_GGML_ASSERT(cgraph->n_nodes + i >= 0); + return cgraph->nodes[cgraph->n_nodes + i]; } - const int m = params.lbfgs.m; + LM_GGML_ASSERT(i < cgraph->n_nodes); + return cgraph->nodes[i]; +} + +struct lm_ggml_tensor ** lm_ggml_graph_nodes(struct lm_ggml_cgraph * cgraph) { + return cgraph->nodes; +} - // these will store the parameters we want to optimize - struct lm_ggml_tensor * ps[LM_GGML_MAX_PARAMS]; +int lm_ggml_graph_n_nodes(struct lm_ggml_cgraph * cgraph) { + return cgraph->n_nodes; +} - int np = 0; - int nx = 0; - for (int i = 0; i < gf->n_nodes; ++i) { - if (gf->nodes[i]->flags & LM_GGML_TENSOR_FLAG_PARAM) { - LM_GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op); +void lm_ggml_graph_add_node(struct lm_ggml_cgraph * cgraph, struct lm_ggml_tensor * tensor) { + LM_GGML_ASSERT(cgraph->size > cgraph->n_nodes); + cgraph->nodes[cgraph->n_nodes] = tensor; + cgraph->n_nodes++; +} - LM_GGML_ASSERT(np < LM_GGML_MAX_PARAMS); +struct lm_ggml_tensor * lm_ggml_graph_get_tensor(struct lm_ggml_cgraph * cgraph, const char * name) { + for (int i = 0; i < cgraph->n_leafs; i++) { + struct lm_ggml_tensor * leaf = cgraph->leafs[i]; - ps[np++] = gf->nodes[i]; - nx += lm_ggml_nelements(gf->nodes[i]); + if (strcmp(leaf->name, name) == 0) { + return leaf; } } - if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past) || (opt->params.lbfgs.m != params.lbfgs.m)) { - int iter = opt->iter; - lm_ggml_opt_init(ctx, opt, params, nx); - opt->iter = iter; - } - - struct lm_ggml_cplan cplan = lm_ggml_graph_plan(gb, params.n_threads, NULL); - struct lm_ggml_object * obj = lm_ggml_new_object(ctx, LM_GGML_OBJECT_TYPE_WORK_BUFFER, cplan.work_size); - cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs; - - float * x = opt->lbfgs.x->data; // current parameters - float * xp = opt->lbfgs.xp->data; // previous parameters - float * g = opt->lbfgs.g->data; // current gradient - float * gp = opt->lbfgs.gp->data; // previous gradient - float * d = opt->lbfgs.d->data; // search direction - - float * pf = params.past > 0 ? opt->lbfgs.pf->data : NULL; // past function values - - const int n_accum = MAX(1, params.n_gradient_accumulation); - const float accum_norm = 1.0f / (float) n_accum; - - float fx = 0.0f; // cost function value - float xnorm = 0.0f; // ||x|| - float gnorm = 0.0f; // ||g|| + for (int i = 0; i < cgraph->n_nodes; i++) { + struct lm_ggml_tensor * node = cgraph->nodes[i]; - // initialize x from the graph nodes - lm_ggml_opt_get_params(np, ps, x); + if (strcmp(node->name, name) == 0) { + return node; + } + } - // the L-BFGS memory - float * lm_alpha = opt->lbfgs.lmal->data; - float * lm_ys = opt->lbfgs.lmys->data; - float * lm_s = opt->lbfgs.lms->data; - float * lm_y = opt->lbfgs.lmy->data; + return NULL; +} - bool cancel = false; +void lm_ggml_graph_print(const struct lm_ggml_cgraph * cgraph) { + LM_GGML_LOG_INFO("=== GRAPH ===\n"); - // evaluate the function value and its gradient - { - lm_ggml_opt_set_params(np, ps, x); - - fx = 0; - memset(g, 0, sizeof(float)*nx); - for (int accum_step = 0; accum_step < n_accum; ++accum_step) { - if (callback) { - // LBFG-S does not support learning rate -> ignore learning schedule - float sched = 0; - callback(callback_data, accum_step, &sched, &cancel); - if (cancel) { - return LM_GGML_OPT_RESULT_CANCEL; - } - } - // lm_ggml_graph_reset (gf); - lm_ggml_set_f32 (f->grad, 1.0f); - lm_ggml_graph_compute(gb, &cplan); - lm_ggml_opt_acc_grad(np, ps, g, accum_norm); - fx += lm_ggml_get_f32_1d(f, 0); - } - fx *= accum_norm; + LM_GGML_LOG_INFO("n_nodes = %d\n", cgraph->n_nodes); + for (int i = 0; i < cgraph->n_nodes; i++) { + struct lm_ggml_tensor * node = cgraph->nodes[i]; - opt->loss_before = fx; - opt->loss_after = fx; + LM_GGML_LOG_INFO(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s\n", + i, + node->ne[0], node->ne[1], node->ne[2], + lm_ggml_op_name(node->op), (node->flags & LM_GGML_TENSOR_FLAG_PARAM) ? "x" : node->grad ? "g" : " "); } - // search direction = -gradient - lm_ggml_vec_neg_f32(nx, d, g); - - // ||x||, ||g|| - lm_ggml_vec_norm_f32(nx, &xnorm, x); - lm_ggml_vec_norm_f32(nx, &gnorm, g); + LM_GGML_LOG_INFO("n_leafs = %d\n", cgraph->n_leafs); + for (int i = 0; i < cgraph->n_leafs; i++) { + struct lm_ggml_tensor * node = cgraph->leafs[i]; - if (xnorm < 1.0f) { - xnorm = 1.0f; + LM_GGML_LOG_INFO(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s %16s\n", + i, + node->ne[0], node->ne[1], + lm_ggml_op_name(node->op), + lm_ggml_get_name(node)); } - // already optimized - if (gnorm/xnorm <= params.lbfgs.eps) { - return LM_GGML_OPT_RESULT_OK; + LM_GGML_LOG_INFO("========================================\n"); +} + +// check if node is part of the graph +static bool lm_ggml_graph_find(const struct lm_ggml_cgraph * cgraph, const struct lm_ggml_tensor * node) { + if (cgraph == NULL) { + return true; } - if (opt->just_initialized) { - if (pf) { - pf[0] = fx; + for (int i = 0; i < cgraph->n_nodes; i++) { + if (cgraph->nodes[i] == node) { + return true; } - opt->lbfgs.fx_best = fx; - - // initial step - lm_ggml_vec_norm_inv_f32(nx, &opt->lbfgs.step, d); - opt->lbfgs.j = 0; - opt->lbfgs.k = 1; - opt->lbfgs.end = 0; - opt->lbfgs.n_no_improvement = 0; - opt->just_initialized = false; } - float * fx_best = &opt->lbfgs.fx_best; - float * step = &opt->lbfgs.step; - int * j = &opt->lbfgs.j; - int * k = &opt->lbfgs.k; - int * end = &opt->lbfgs.end; - int * n_no_improvement = &opt->lbfgs.n_no_improvement; - - int ls = 0; - int bound = 0; - - float ys = 0.0f; - float yy = 0.0f; - float beta = 0.0f; - - int it = 0; - - while (true) { - // store the current position and gradient vectors - lm_ggml_vec_cpy_f32(nx, xp, x); - lm_ggml_vec_cpy_f32(nx, gp, g); - - // TODO: instead of passing &cancel here, use the return code of the linesearch - // to determine if the optimization should be cancelled - // this is a simple change, but not doing this atm, since I don't have a nice - // way to test and don't want to break something with so many changes lined up - ls = linesearch_backtracking(¶ms, nx, x, &fx, g, d, step, xp, f, gb, &cplan, np, ps, &cancel, callback, callback_data); - if (cancel) { - return LM_GGML_OPT_RESULT_CANCEL; - } + return false; +} - if (ls < 0) { - // linesearch failed - go back to the previous point and return - lm_ggml_vec_cpy_f32(nx, x, xp); - lm_ggml_vec_cpy_f32(nx, g, gp); +static struct lm_ggml_tensor * lm_ggml_graph_get_parent(const struct lm_ggml_cgraph * cgraph, const struct lm_ggml_tensor * node) { + for (int i = 0; i < cgraph->n_nodes; i++) { + struct lm_ggml_tensor * parent = cgraph->nodes[i]; - return ls; + if (parent->grad == node) { + return parent; } + } + + return NULL; +} - opt->loss_after = fx; +static void lm_ggml_graph_dump_dot_node_edge(FILE * fp, const struct lm_ggml_cgraph * gb, struct lm_ggml_tensor * node, struct lm_ggml_tensor * parent, const char * label) { + struct lm_ggml_tensor * gparent = lm_ggml_graph_get_parent(gb, node); + struct lm_ggml_tensor * gparent0 = lm_ggml_graph_get_parent(gb, parent); + fprintf(fp, " \"%p\":%s -> \"%p\":%s [ arrowhead = %s; style = %s; label = \"%s\"; ]\n", + gparent0 ? (void *) gparent0 : (void *) parent, + gparent0 ? "g" : "x", + gparent ? (void *) gparent : (void *) node, + gparent ? "g" : "x", + gparent ? "empty" : "vee", + gparent ? "dashed" : "solid", + label); +} - lm_ggml_vec_norm_f32(nx, &xnorm, x); - lm_ggml_vec_norm_f32(nx, &gnorm, g); +static void lm_ggml_graph_dump_dot_leaf_edge(FILE * fp, struct lm_ggml_tensor * node, struct lm_ggml_tensor * parent, const char * label) { + fprintf(fp, " \"%p\":%s -> \"%p\":%s [ label = \"%s\"; ]\n", + (void *) parent, "x", + (void *) node, "x", + label); +} - LM_GGML_PRINT_DEBUG("f = %10.6f\n", lm_ggml_get_f32_1d(f, 0)); +void lm_ggml_graph_dump_dot(const struct lm_ggml_cgraph * gb, const struct lm_ggml_cgraph * gf, const char * filename) { + char color[16]; - if (xnorm < 1.0f) { - xnorm = 1.0f; - } - if (gnorm/xnorm <= params.lbfgs.eps) { - // converged - return LM_GGML_OPT_RESULT_OK; - } + FILE * fp = lm_ggml_fopen(filename, "w"); + LM_GGML_ASSERT(fp); - // delta-based convergence test - if (pf != NULL) { - // need at least params.past iterations to start checking for convergence - if (params.past <= k[0]) { - const float rate = (pf[k[0]%params.past] - fx)/fx; + fprintf(fp, "digraph G {\n"); + fprintf(fp, " newrank = true;\n"); + fprintf(fp, " rankdir = TB;\n"); - if (fabsf(rate) < params.delta) { - return LM_GGML_OPT_RESULT_OK; - } - } + for (int i = 0; i < gb->n_nodes; i++) { + struct lm_ggml_tensor * node = gb->nodes[i]; - pf[k[0]%params.past] = fx; + if (lm_ggml_graph_get_parent(gb, node) != NULL) { + continue; } - // check for improvement - if (params.max_no_improvement > 0) { - if (fx < fx_best[0]) { - fx_best[0] = fx; - n_no_improvement[0] = 0; + if (node->flags & LM_GGML_TENSOR_FLAG_PARAM) { + snprintf(color, sizeof(color), "yellow"); + } else if (node->grad) { + if (lm_ggml_graph_find(gf, node)) { + snprintf(color, sizeof(color), "green"); } else { - n_no_improvement[0]++; - - if (n_no_improvement[0] >= params.max_no_improvement) { - return LM_GGML_OPT_RESULT_OK; - } + snprintf(color, sizeof(color), "lightblue"); } + } else { + snprintf(color, sizeof(color), "white"); } - if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < it + 1) { - // reached the maximum number of iterations - return LM_GGML_OPT_RESULT_DID_NOT_CONVERGE; + fprintf(fp, " \"%p\" [ " + "style = filled; fillcolor = %s; shape = record; " + "label=\"", + (void *) node, color); + + if (strlen(node->name) > 0) { + fprintf(fp, "%s (%s)|", node->name, lm_ggml_type_name(node->type)); + } else { + fprintf(fp, "(%s)|", lm_ggml_type_name(node->type)); } - // update vectors s and y: - // s_{k+1} = x_{k+1} - x_{k} = \step * d_{k}. - // y_{k+1} = g_{k+1} - g_{k}. - // - lm_ggml_vec_sub_f32(nx, &lm_s[end[0]*nx], x, xp); - lm_ggml_vec_sub_f32(nx, &lm_y[end[0]*nx], g, gp); - - // compute scalars ys and yy: - // ys = y^t \cdot s -> 1 / \rho. - // yy = y^t \cdot y. - // - lm_ggml_vec_dot_f32(nx, &ys, 0, &lm_y[end[0]*nx], 0, &lm_s[end[0]*nx], 0, 1); - lm_ggml_vec_dot_f32(nx, &yy, 0, &lm_y[end[0]*nx], 0, &lm_y[end[0]*nx], 0, 1); - - lm_ys[end[0]] = ys; - - // find new search direction - // ref: https://en.wikipedia.org/wiki/Limited-memory_BFGS - - bound = (m <= k[0]) ? m : k[0]; - k[0]++; - it++; - end[0] = (end[0] + 1)%m; - - // initialize search direction with -g - lm_ggml_vec_neg_f32(nx, d, g); - - j[0] = end[0]; - for (int i = 0; i < bound; ++i) { - j[0] = (j[0] + m - 1) % m; - // \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1} - lm_ggml_vec_dot_f32(nx, &lm_alpha[j[0]], 0, &lm_s[j[0]*nx], 0, d, 0, 1); - lm_alpha[j[0]] /= lm_ys[j[0]]; - // q_{i} = q_{i+1} - \alpha_{i} y_{i} - lm_ggml_vec_mad_f32(nx, d, &lm_y[j[0]*nx], -lm_alpha[j[0]]); + if (lm_ggml_is_matrix(node)) { + fprintf(fp, "%d [%" PRId64 ", %" PRId64 "] | %s", i, node->ne[0], node->ne[1], lm_ggml_op_symbol(node->op)); + } else { + fprintf(fp, "%d [%" PRId64 ", %" PRId64 ", %" PRId64 "] | %s", i, node->ne[0], node->ne[1], node->ne[2], lm_ggml_op_symbol(node->op)); } - lm_ggml_vec_scale_f32(nx, d, ys/yy); - - for (int i = 0; i < bound; ++i) { - // \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i} - lm_ggml_vec_dot_f32(nx, &beta, 0, &lm_y[j[0]*nx], 0, d, 0, 1); - beta /= lm_ys[j[0]]; - // \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j} - lm_ggml_vec_mad_f32(nx, d, &lm_s[j[0]*nx], lm_alpha[j[0]] - beta); - j[0] = (j[0] + 1)%m; + if (node->grad) { + fprintf(fp, " | %s\"; ]\n", lm_ggml_op_symbol(node->grad->op)); + } else { + fprintf(fp, "\"; ]\n"); } - - step[0] = 1.0; } - LM_GGML_ABORT("lbfgs failed"); - - //return LM_GGML_OPT_RESULT_DID_NOT_CONVERGE; -} - -struct lm_ggml_opt_params lm_ggml_opt_default_params(enum lm_ggml_opt_type type) { - struct lm_ggml_opt_params result; + for (int i = 0; i < gb->n_leafs; i++) { + struct lm_ggml_tensor * node = gb->leafs[i]; - switch (type) { - case LM_GGML_OPT_TYPE_ADAM: - { - result = (struct lm_ggml_opt_params) { - .type = LM_GGML_OPT_TYPE_ADAM, - .graph_size = LM_GGML_DEFAULT_GRAPH_SIZE, - .n_threads = 1, // FIXME: LM_GGML_DEFAULT_N_THREADS ? - .past = 0, - .delta = 1e-5f, - - .max_no_improvement = 100, - - .print_forward_graph = true, - .print_backward_graph = true, - - .n_gradient_accumulation = 1, - - .adam = { - .n_iter = 10000, - .sched = 1.000f, - .decay = 0.0f, - .decay_min_ndim = 2, - .alpha = 0.001f, - .beta1 = 0.9f, - .beta2 = 0.999f, - .eps = 1e-8f, - .eps_f = 1e-5f, - .eps_g = 1e-3f, - .gclip = 0.0f, - }, - }; - } break; - case LM_GGML_OPT_TYPE_LBFGS: - { - result = (struct lm_ggml_opt_params) { - .type = LM_GGML_OPT_TYPE_LBFGS, - .graph_size = LM_GGML_DEFAULT_GRAPH_SIZE, - .n_threads = 1, - .past = 0, - .delta = 1e-5f, - - .max_no_improvement = 0, - - .print_forward_graph = true, - .print_backward_graph = true, - - .n_gradient_accumulation = 1, - - .lbfgs = { - .m = 6, - .n_iter = 100, - .max_linesearch = 20, - - .eps = 1e-5f, - .ftol = 1e-4f, - .wolfe = 0.9f, - .min_step = 1e-20f, - .max_step = 1e+20f, - - .linesearch = LM_GGML_LINESEARCH_DEFAULT, - }, - }; - } break; - } + snprintf(color, sizeof(color), "pink"); - return result; -} + fprintf(fp, " \"%p\" [ " + "style = filled; fillcolor = %s; shape = record; " + "label=\"", + (void *) node, color); -LM_GGML_API void lm_ggml_opt_init( - struct lm_ggml_context * ctx, - struct lm_ggml_opt_context * opt, - struct lm_ggml_opt_params params, - int64_t nx) { - opt->ctx = ctx; - opt->params = params; - opt->iter = 0; - opt->nx = nx; - opt->just_initialized = true; - if (opt->ctx == NULL) { - struct lm_ggml_init_params ctx_opt_params; - if (opt->params.type == LM_GGML_OPT_TYPE_ADAM) { - ctx_opt_params.mem_size = LM_GGML_MEM_ALIGN*3 + lm_ggml_tensor_overhead()*3 + lm_ggml_type_size(LM_GGML_TYPE_F32)*nx*3; - if (opt->params.past > 0) { - ctx_opt_params.mem_size += LM_GGML_MEM_ALIGN + lm_ggml_tensor_overhead() + lm_ggml_type_size(LM_GGML_TYPE_F32)*opt->params.past; - } - } else if (opt->params.type == LM_GGML_OPT_TYPE_LBFGS) { - ctx_opt_params.mem_size = LM_GGML_MEM_ALIGN*9 + lm_ggml_tensor_overhead()*9 + lm_ggml_type_size(LM_GGML_TYPE_F32)*(nx*5 + opt->params.lbfgs.m*2 + nx*opt->params.lbfgs.m*2); - if (opt->params.past > 0) { - ctx_opt_params.mem_size += LM_GGML_MEM_ALIGN + lm_ggml_tensor_overhead() + lm_ggml_type_size(LM_GGML_TYPE_F32)*opt->params.past; - } + if (strlen(node->name) > 0) { + fprintf(fp, "%s (%s)|", node->name, lm_ggml_type_name(node->type)); + } else { + fprintf(fp, "(%s)|", lm_ggml_type_name(node->type)); } - ctx_opt_params.mem_buffer = NULL; - ctx_opt_params.no_alloc = false; - opt->ctx = lm_ggml_init(ctx_opt_params); - } - switch (opt->params.type) { - case LM_GGML_OPT_TYPE_ADAM: - { - opt->adam.g = lm_ggml_new_tensor_1d(opt->ctx, LM_GGML_TYPE_F32, nx); - opt->adam.m = lm_ggml_new_tensor_1d(opt->ctx, LM_GGML_TYPE_F32, nx); - opt->adam.v = lm_ggml_new_tensor_1d(opt->ctx, LM_GGML_TYPE_F32, nx); - opt->adam.pf = params.past > 0 - ? lm_ggml_new_tensor_1d(opt->ctx, LM_GGML_TYPE_F32, params.past) - : NULL; - lm_ggml_set_zero(opt->adam.m); - lm_ggml_set_zero(opt->adam.v); - if (opt->adam.pf) { - lm_ggml_set_zero(opt->adam.pf); + fprintf(fp, "CONST %d [%" PRId64 ", %" PRId64 "]", i, node->ne[0], node->ne[1]); + if (lm_ggml_nelements(node) < 5 && node->data != NULL) { + fprintf(fp, " | ("); + for (int j = 0; j < lm_ggml_nelements(node); j++) { + // FIXME: use ggml-backend to obtain the tensor data + //if (node->type == LM_GGML_TYPE_I8 || node->type == LM_GGML_TYPE_I16 || node->type == LM_GGML_TYPE_I32) { + // fprintf(fp, "%d", lm_ggml_get_i32_1d(node, j)); + //} + //else if (node->type == LM_GGML_TYPE_F32 || + // node->type == LM_GGML_TYPE_F16 || + // node->type == LM_GGML_TYPE_BF16) { + // fprintf(fp, "%.1e", (double)lm_ggml_get_f32_1d(node, j)); + //} + //else + { + fprintf(fp, "#"); } - } break; - case LM_GGML_OPT_TYPE_LBFGS: - { - opt->lbfgs.x = lm_ggml_new_tensor_1d(opt->ctx, LM_GGML_TYPE_F32, nx); - opt->lbfgs.xp = lm_ggml_new_tensor_1d(opt->ctx, LM_GGML_TYPE_F32, nx); - opt->lbfgs.g = lm_ggml_new_tensor_1d(opt->ctx, LM_GGML_TYPE_F32, nx); - opt->lbfgs.gp = lm_ggml_new_tensor_1d(opt->ctx, LM_GGML_TYPE_F32, nx); - opt->lbfgs.d = lm_ggml_new_tensor_1d(opt->ctx, LM_GGML_TYPE_F32, nx); - opt->lbfgs.pf = params.past > 0 - ? lm_ggml_new_tensor_1d(opt->ctx, LM_GGML_TYPE_F32, params.past) - : NULL; - opt->lbfgs.lmal = lm_ggml_new_tensor_1d(opt->ctx, LM_GGML_TYPE_F32, params.lbfgs.m); - opt->lbfgs.lmys = lm_ggml_new_tensor_1d(opt->ctx, LM_GGML_TYPE_F32, params.lbfgs.m); - opt->lbfgs.lms = lm_ggml_new_tensor_2d(opt->ctx, LM_GGML_TYPE_F32, nx, params.lbfgs.m); - opt->lbfgs.lmy = lm_ggml_new_tensor_2d(opt->ctx, LM_GGML_TYPE_F32, nx, params.lbfgs.m); - lm_ggml_set_zero(opt->lbfgs.x); - lm_ggml_set_zero(opt->lbfgs.xp); - lm_ggml_set_zero(opt->lbfgs.g); - lm_ggml_set_zero(opt->lbfgs.gp); - lm_ggml_set_zero(opt->lbfgs.d); - if (opt->lbfgs.pf) { - lm_ggml_set_zero(opt->lbfgs.pf); + if (j < lm_ggml_nelements(node) - 1) { + fprintf(fp, ", "); } - lm_ggml_set_zero(opt->lbfgs.lmal); - lm_ggml_set_zero(opt->lbfgs.lmys); - lm_ggml_set_zero(opt->lbfgs.lms); - lm_ggml_set_zero(opt->lbfgs.lmy); - } break; - } -} - -enum lm_ggml_opt_result lm_ggml_opt( - struct lm_ggml_context * ctx, - struct lm_ggml_opt_params params, - struct lm_ggml_tensor * f) { - bool free_ctx = false; - if (ctx == NULL) { - struct lm_ggml_init_params params_ctx = { - .mem_size = 16*1024*1024, - .mem_buffer = NULL, - .no_alloc = false, - }; - - ctx = lm_ggml_init(params_ctx); - if (ctx == NULL) { - return LM_GGML_OPT_RESULT_NO_CONTEXT; + } + fprintf(fp, ")"); } - - free_ctx = true; + fprintf(fp, "\"; ]\n"); } - enum lm_ggml_opt_result result = LM_GGML_OPT_RESULT_OK; - - struct lm_ggml_opt_context * opt = (struct lm_ggml_opt_context *) alloca(sizeof(struct lm_ggml_opt_context)); - - lm_ggml_opt_init(ctx, opt, params, 0); - result = lm_ggml_opt_resume(ctx, opt, f); + for (int i = 0; i < gb->n_nodes; i++) { + struct lm_ggml_tensor * node = gb->nodes[i]; - if (free_ctx) { - lm_ggml_free(ctx); + for (int j = 0; j < LM_GGML_MAX_SRC; j++) { + if (node->src[j]) { + char label[16]; + snprintf(label, sizeof(label), "src %d", j); + lm_ggml_graph_dump_dot_node_edge(fp, gb, node, node->src[j], label); + } + } } - return result; -} - -enum lm_ggml_opt_result lm_ggml_opt_resume( - struct lm_ggml_context * ctx, - struct lm_ggml_opt_context * opt, - struct lm_ggml_tensor * f) { - - // build forward + backward compute graphs - struct lm_ggml_cgraph * gf = lm_ggml_new_graph_custom(ctx, opt->params.graph_size, true); - lm_ggml_build_forward_expand(gf, f); - - struct lm_ggml_cgraph * gb = lm_ggml_graph_dup(ctx, gf); - lm_ggml_build_backward_expand(ctx, gf, gb, false); - - return lm_ggml_opt_resume_g(ctx, opt, f, gf, gb, NULL, NULL); -} - -enum lm_ggml_opt_result lm_ggml_opt_resume_g( - struct lm_ggml_context * ctx, - struct lm_ggml_opt_context * opt, - struct lm_ggml_tensor * f, - struct lm_ggml_cgraph * gf, - struct lm_ggml_cgraph * gb, - lm_ggml_opt_callback callback, - void * callback_data) { - - LM_GGML_ASSERT(f->grad && "lm_ggml_set_param must be called for at least one ancestor"); - - // build forward + backward compute graphs - enum lm_ggml_opt_result result = LM_GGML_OPT_RESULT_OK; + for (int i = 0; i < gb->n_leafs; i++) { + struct lm_ggml_tensor * node = gb->leafs[i]; - switch (opt->params.type) { - case LM_GGML_OPT_TYPE_ADAM: - { - result = lm_ggml_opt_adam(ctx, opt, opt->params, f, gf, gb, callback, callback_data); - } break; - case LM_GGML_OPT_TYPE_LBFGS: - { - result = lm_ggml_opt_lbfgs(ctx, opt, opt->params, f, gf, gb, callback, callback_data); - } break; + for (int j = 0; j < LM_GGML_MAX_SRC; j++) { + if (node->src[j]) { + char label[16]; + snprintf(label, sizeof(label), "src %d", j); + lm_ggml_graph_dump_dot_leaf_edge(fp, node, node->src[j], label); + } + } } - if (opt->params.print_forward_graph) { - lm_ggml_graph_print (gf); - lm_ggml_graph_dump_dot(gf, NULL, "opt-forward.dot"); - } + fprintf(fp, "}\n"); - if (opt->params.print_backward_graph) { - lm_ggml_graph_print (gb); - lm_ggml_graph_dump_dot(gb, gf, "opt-backward.dot"); - } + fclose(fp); - return result; + LM_GGML_LOG_INFO("%s: dot -Tpng %s -o %s.png && open %s.png\n", __func__, filename, filename, filename); } //////////////////////////////////////////////////////////////////////////////// @@ -23173,254 +8161,7 @@ void lm_gguf_get_meta_data(const struct lm_gguf_context * ctx, void * data) { lm_gguf_buf_free(buf); } -//////////////////////////////////////////////////////////////////////////////// - -int lm_ggml_cpu_has_avx(void) { -#if defined(__AVX__) - return 1; -#else - return 0; -#endif -} - -int lm_ggml_cpu_has_avx_vnni(void) { -#if defined(__AVXVNNI__) - return 1; -#else - return 0; -#endif -} - -int lm_ggml_cpu_has_avx2(void) { -#if defined(__AVX2__) - return 1; -#else - return 0; -#endif -} - -int lm_ggml_cpu_has_avx512(void) { -#if defined(__AVX512F__) - return 1; -#else - return 0; -#endif -} - -int lm_ggml_cpu_has_avx512_vbmi(void) { -#if defined(__AVX512VBMI__) - return 1; -#else - return 0; -#endif -} - -int lm_ggml_cpu_has_avx512_vnni(void) { -#if defined(__AVX512VNNI__) - return 1; -#else - return 0; -#endif -} - -int lm_ggml_cpu_has_avx512_bf16(void) { -#if defined(__AVX512BF16__) - return 1; -#else - return 0; -#endif -} - -int lm_ggml_cpu_has_amx_int8(void) { -#if defined(__AMX_INT8__) - return 1; -#else - return 0; -#endif -} - -int lm_ggml_cpu_has_fma(void) { -#if defined(__FMA__) - return 1; -#else - return 0; -#endif -} - -int lm_ggml_cpu_has_neon(void) { -#if defined(__ARM_ARCH) - return lm_ggml_arm_arch_features.has_neon; -#else - return 0; -#endif -} - -int lm_ggml_cpu_has_sve(void) { -#if defined(__ARM_ARCH) - return lm_ggml_arm_arch_features.has_sve; -#else - return 0; -#endif -} - -int lm_ggml_cpu_has_arm_fma(void) { -#if defined(__ARM_FEATURE_FMA) - return 1; -#else - return 0; -#endif -} - -int lm_ggml_cpu_has_riscv_v(void) { -#if defined(__riscv_v_intrinsic) - return 1; -#else - return 0; -#endif -} - -int lm_ggml_cpu_has_metal(void) { -#if defined(LM_GGML_USE_METAL) - return 1; -#else - return 0; -#endif -} - -int lm_ggml_cpu_has_f16c(void) { -#if defined(__F16C__) - return 1; -#else - return 0; -#endif -} - -int lm_ggml_cpu_has_fp16_va(void) { -#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) - return 1; -#else - return 0; -#endif -} - -int lm_ggml_cpu_has_wasm_simd(void) { -#if defined(__wasm_simd128__) - return 1; -#else - return 0; -#endif -} - -int lm_ggml_cpu_has_blas(void) { -#if defined(LM_GGML_USE_BLAS) || defined(LM_GGML_USE_CUDA) || defined(LM_GGML_USE_VULKAN) || defined(LM_GGML_USE_SYCL) - return 1; -#else - return 0; -#endif -} - -int lm_ggml_cpu_has_cuda(void) { -#if defined(LM_GGML_USE_CUDA) - return 1; -#else - return 0; -#endif -} - -int lm_ggml_cpu_has_vulkan(void) { -#if defined(LM_GGML_USE_VULKAN) - return 1; -#else - return 0; -#endif -} - -int lm_ggml_cpu_has_kompute(void) { -#if defined(LM_GGML_USE_KOMPUTE) - return 1; -#else - return 0; -#endif -} - -int lm_ggml_cpu_has_sycl(void) { -#if defined(LM_GGML_USE_SYCL) - return 1; -#else - return 0; -#endif -} - -int lm_ggml_cpu_has_rpc(void) { -#if defined(LM_GGML_USE_RPC) - return 1; -#else - return 0; -#endif -} - -int lm_ggml_cpu_has_cann(void) { -#if defined(LM_GGML_USE_CANN) - return 1; -#else - return 0; -#endif -} - -int lm_ggml_cpu_has_llamafile(void) { -#if defined(LM_GGML_USE_LLAMAFILE) - return 1; -#else - return 0; -#endif -} - -int lm_ggml_cpu_has_gpublas(void) { - return lm_ggml_cpu_has_cuda() || lm_ggml_cpu_has_vulkan() || lm_ggml_cpu_has_kompute() || lm_ggml_cpu_has_sycl(); -} - -int lm_ggml_cpu_has_sse3(void) { -#if defined(__SSE3__) - return 1; -#else - return 0; -#endif -} - -int lm_ggml_cpu_has_ssse3(void) { -#if defined(__SSSE3__) - return 1; -#else - return 0; -#endif -} - -int lm_ggml_cpu_has_vsx(void) { -#if defined(__POWER9_VECTOR__) - return 1; -#else - return 0; -#endif -} - -int lm_ggml_cpu_has_matmul_int8(void) { -#if defined(__ARM_ARCH) - return lm_ggml_arm_arch_features.has_i8mm; -#else - return 0; -#endif -} - -int lm_ggml_cpu_get_sve_cnt(void) { -#if defined(__ARM_ARCH) - return lm_ggml_arm_arch_features.sve_cnt; -#else - return 0; -#endif -} - void lm_ggml_log_set(lm_ggml_log_callback log_callback, void * user_data) { g_logger_state.log_callback = log_callback ? log_callback : lm_ggml_log_callback_default; g_logger_state.log_callback_user_data = user_data; } -//////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/ggml.h b/cpp/ggml.h index e89b832e..677b6d36 100644 --- a/cpp/ggml.h +++ b/cpp/ggml.h @@ -176,15 +176,15 @@ #ifdef LM_GGML_SHARED # if defined(_WIN32) && !defined(__MINGW32__) # ifdef LM_GGML_BUILD -# define LM_GGML_API __declspec(dllexport) +# define LM_GGML_API __declspec(dllexport) extern # else -# define LM_GGML_API __declspec(dllimport) +# define LM_GGML_API __declspec(dllimport) extern # endif # else -# define LM_GGML_API __attribute__ ((visibility ("default"))) +# define LM_GGML_API __attribute__ ((visibility ("default"))) extern # endif #else -# define LM_GGML_API +# define LM_GGML_API extern #endif // TODO: support for clang @@ -509,7 +509,7 @@ extern "C" { LM_GGML_OP_WIN_UNPART, LM_GGML_OP_GET_REL_POS, LM_GGML_OP_ADD_REL_POS, - LM_GGML_OP_RWKV_WKV, + LM_GGML_OP_RWKV_WKV6, LM_GGML_OP_UNARY, @@ -573,6 +573,13 @@ extern "C" { LM_GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up) }; + struct lm_ggml_init_params { + // memory pool + size_t mem_size; // bytes + void * mem_buffer; // if NULL, memory will be allocated internally + bool no_alloc; // don't allocate memory for the tensor data + }; + // n-dimensional tensor struct lm_ggml_tensor { enum lm_ggml_type type; @@ -618,59 +625,6 @@ extern "C" { // If it returns true, the computation is aborted typedef bool (*lm_ggml_abort_callback)(void * data); - // Scheduling priorities - enum lm_ggml_sched_priority { - LM_GGML_SCHED_PRIO_NORMAL, - LM_GGML_SCHED_PRIO_MEDIUM, - LM_GGML_SCHED_PRIO_HIGH, - LM_GGML_SCHED_PRIO_REALTIME - }; - - // Threadpool params - // Use lm_ggml_threadpool_params_default() or lm_ggml_threadpool_params_init() to populate the defaults - struct lm_ggml_threadpool_params { - bool cpumask[LM_GGML_MAX_N_THREADS]; // mask of cpu cores (all-zeros means use default affinity settings) - int n_threads; // number of threads - enum lm_ggml_sched_priority prio; // thread priority - uint32_t poll; // polling level (0 - no polling, 100 - aggressive polling) - bool strict_cpu; // strict cpu placement - bool paused; // start in paused state - }; - - struct lm_ggml_threadpool; // forward declaration, see ggml.c - - typedef struct lm_ggml_threadpool * lm_ggml_threadpool_t; - - // the compute plan that needs to be prepared for lm_ggml_graph_compute() - // since https://github.com/ggerganov/ggml/issues/287 - struct lm_ggml_cplan { - size_t work_size; // size of work buffer, calculated by `lm_ggml_graph_plan()` - uint8_t * work_data; // work buffer, to be allocated by caller before calling to `lm_ggml_graph_compute()` - - int n_threads; - struct lm_ggml_threadpool * threadpool; - - // abort lm_ggml_graph_compute when true - lm_ggml_abort_callback abort_callback; - void * abort_callback_data; - }; - - struct lm_ggml_init_params { - // memory pool - size_t mem_size; // bytes - void * mem_buffer; // if NULL, memory will be allocated internally - bool no_alloc; // don't allocate memory for the tensor data - }; - - // numa strategies - enum lm_ggml_numa_strategy { - LM_GGML_NUMA_STRATEGY_DISABLED = 0, - LM_GGML_NUMA_STRATEGY_DISTRIBUTE = 1, - LM_GGML_NUMA_STRATEGY_ISOLATE = 2, - LM_GGML_NUMA_STRATEGY_NUMACTL = 3, - LM_GGML_NUMA_STRATEGY_MIRROR = 4, - LM_GGML_NUMA_STRATEGY_COUNT - }; // // GUID @@ -693,9 +647,6 @@ extern "C" { // accepts a UTF-8 path, even on Windows LM_GGML_API FILE * lm_ggml_fopen(const char * fname, const char * mode); - LM_GGML_API void lm_ggml_numa_init(enum lm_ggml_numa_strategy numa); // call once for better performance on NUMA systems - LM_GGML_API bool lm_ggml_is_numa(void); // true if init detected that system has >1 NUMA node - LM_GGML_API void lm_ggml_print_object (const struct lm_ggml_object * obj); LM_GGML_API void lm_ggml_print_objects(const struct lm_ggml_context * ctx); @@ -797,8 +748,7 @@ extern "C" { int64_t ne2, int64_t ne3); - LM_GGML_API struct lm_ggml_tensor * lm_ggml_new_i32(struct lm_ggml_context * ctx, int32_t value); - LM_GGML_API struct lm_ggml_tensor * lm_ggml_new_f32(struct lm_ggml_context * ctx, float value); + LM_GGML_API void * lm_ggml_new_buffer(struct lm_ggml_context * ctx, size_t nbytes); LM_GGML_API struct lm_ggml_tensor * lm_ggml_dup_tensor (struct lm_ggml_context * ctx, const struct lm_ggml_tensor * src); LM_GGML_API struct lm_ggml_tensor * lm_ggml_view_tensor(struct lm_ggml_context * ctx, struct lm_ggml_tensor * src); @@ -808,35 +758,25 @@ extern "C" { LM_GGML_API struct lm_ggml_tensor * lm_ggml_get_next_tensor (const struct lm_ggml_context * ctx, struct lm_ggml_tensor * tensor); LM_GGML_API struct lm_ggml_tensor * lm_ggml_get_tensor(struct lm_ggml_context * ctx, const char * name); - LM_GGML_API struct lm_ggml_tensor * lm_ggml_set_zero(struct lm_ggml_tensor * tensor); - LM_GGML_API struct lm_ggml_tensor * lm_ggml_set_i32 (struct lm_ggml_tensor * tensor, int32_t value); - LM_GGML_API struct lm_ggml_tensor * lm_ggml_set_f32 (struct lm_ggml_tensor * tensor, float value); - // Converts a flat index into coordinates - LM_GGML_API void lm_ggml_unravel_index(const struct lm_ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3); - - LM_GGML_API int32_t lm_ggml_get_i32_1d(const struct lm_ggml_tensor * tensor, int i); - LM_GGML_API void lm_ggml_set_i32_1d(const struct lm_ggml_tensor * tensor, int i, int32_t value); + LM_GGML_API void lm_ggml_unravel_index(const struct lm_ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3); - LM_GGML_API int32_t lm_ggml_get_i32_nd(const struct lm_ggml_tensor * tensor, int i0, int i1, int i2, int i3); - LM_GGML_API void lm_ggml_set_i32_nd(const struct lm_ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value); - - LM_GGML_API float lm_ggml_get_f32_1d(const struct lm_ggml_tensor * tensor, int i); - LM_GGML_API void lm_ggml_set_f32_1d(const struct lm_ggml_tensor * tensor, int i, float value); - - LM_GGML_API float lm_ggml_get_f32_nd(const struct lm_ggml_tensor * tensor, int i0, int i1, int i2, int i3); - LM_GGML_API void lm_ggml_set_f32_nd(const struct lm_ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value); + LM_GGML_API enum lm_ggml_unary_op lm_ggml_get_unary_op(const struct lm_ggml_tensor * tensor); LM_GGML_API void * lm_ggml_get_data (const struct lm_ggml_tensor * tensor); LM_GGML_API float * lm_ggml_get_data_f32(const struct lm_ggml_tensor * tensor); - LM_GGML_API enum lm_ggml_unary_op lm_ggml_get_unary_op(const struct lm_ggml_tensor * tensor); - LM_GGML_API const char * lm_ggml_get_name (const struct lm_ggml_tensor * tensor); LM_GGML_API struct lm_ggml_tensor * lm_ggml_set_name ( struct lm_ggml_tensor * tensor, const char * name); LM_GGML_ATTRIBUTE_FORMAT(2, 3) LM_GGML_API struct lm_ggml_tensor * lm_ggml_format_name( struct lm_ggml_tensor * tensor, const char * fmt, ...); + // Tensor flags + LM_GGML_API void lm_ggml_set_input(struct lm_ggml_tensor * tensor); + LM_GGML_API void lm_ggml_set_output(struct lm_ggml_tensor * tensor); + LM_GGML_API void lm_ggml_set_param(struct lm_ggml_context * ctx, struct lm_ggml_tensor * tensor); + LM_GGML_API void lm_ggml_set_loss(struct lm_ggml_tensor * tensor); + // // operations on tensors with backpropagation // @@ -1550,7 +1490,7 @@ extern "C" { "use lm_ggml_rope_ext_inplace instead"); // compute correction dims for YaRN RoPE scaling - void lm_ggml_rope_yarn_corr_dims( + LM_GGML_API void lm_ggml_rope_yarn_corr_dims( int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]); // rotary position embedding backward, i.e compute dx from dy @@ -1806,6 +1746,9 @@ extern "C" { struct lm_ggml_tensor * a, enum lm_ggml_prec prec); + LM_GGML_API enum lm_ggml_prec lm_ggml_flash_attn_ext_get_prec( + const struct lm_ggml_tensor * a); + // TODO: needs to be adapted to lm_ggml_flash_attn_ext LM_GGML_API struct lm_ggml_tensor * lm_ggml_flash_attn_back( struct lm_ggml_context * ctx, @@ -1879,7 +1822,7 @@ extern "C" { struct lm_ggml_tensor * pw, struct lm_ggml_tensor * ph); - LM_GGML_API struct lm_ggml_tensor * lm_ggml_rwkv_wkv( + LM_GGML_API struct lm_ggml_tensor * lm_ggml_rwkv_wkv6( struct lm_ggml_context * ctx, struct lm_ggml_tensor * k, struct lm_ggml_tensor * v, @@ -2052,9 +1995,6 @@ extern "C" { // automatic differentiation // - LM_GGML_API void lm_ggml_set_param(struct lm_ggml_context * ctx, struct lm_ggml_tensor * tensor); - LM_GGML_API void lm_ggml_set_loss(struct lm_ggml_tensor * tensor); - LM_GGML_API void lm_ggml_build_forward_expand (struct lm_ggml_cgraph * cgraph, struct lm_ggml_tensor * tensor); LM_GGML_API void lm_ggml_build_backward_expand(struct lm_ggml_context * ctx, struct lm_ggml_cgraph * gf, struct lm_ggml_cgraph * gb, bool accumulate); @@ -2086,27 +2026,6 @@ extern "C" { LM_GGML_API size_t lm_ggml_graph_overhead(void); LM_GGML_API size_t lm_ggml_graph_overhead_custom(size_t size, bool grads); - LM_GGML_API struct lm_ggml_threadpool_params lm_ggml_threadpool_params_default(int n_threads); - LM_GGML_API void lm_ggml_threadpool_params_init (struct lm_ggml_threadpool_params * p, int n_threads); - LM_GGML_API bool lm_ggml_threadpool_params_match (const struct lm_ggml_threadpool_params * p0, const struct lm_ggml_threadpool_params * p1); - LM_GGML_API struct lm_ggml_threadpool * lm_ggml_threadpool_new (struct lm_ggml_threadpool_params * params); - LM_GGML_API void lm_ggml_threadpool_free (struct lm_ggml_threadpool * threadpool); - LM_GGML_API int lm_ggml_threadpool_get_n_threads(struct lm_ggml_threadpool * threadpool); - LM_GGML_API void lm_ggml_threadpool_pause (struct lm_ggml_threadpool * threadpool); - LM_GGML_API void lm_ggml_threadpool_resume (struct lm_ggml_threadpool * threadpool); - - // lm_ggml_graph_plan() has to be called before lm_ggml_graph_compute() - // when plan.work_size > 0, caller must allocate memory for plan.work_data - LM_GGML_API struct lm_ggml_cplan lm_ggml_graph_plan( - const struct lm_ggml_cgraph * cgraph, - int n_threads, /* = LM_GGML_DEFAULT_N_THREADS */ - struct lm_ggml_threadpool * threadpool /* = NULL */ ); - LM_GGML_API enum lm_ggml_status lm_ggml_graph_compute(struct lm_ggml_cgraph * cgraph, struct lm_ggml_cplan * cplan); - - // same as lm_ggml_graph_compute() but the work data is allocated as a part of the context - // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data - LM_GGML_API enum lm_ggml_status lm_ggml_graph_compute_with_ctx(struct lm_ggml_context * ctx, struct lm_ggml_cgraph * cgraph, int n_threads); - LM_GGML_API struct lm_ggml_tensor * lm_ggml_graph_get_tensor(struct lm_ggml_cgraph * cgraph, const char * name); LM_GGML_API void lm_ggml_graph_export(const struct lm_ggml_cgraph * cgraph, const char * fname); @@ -2277,6 +2196,8 @@ extern "C" { } lbfgs; }; + LM_GGML_API struct lm_ggml_tensor * lm_ggml_set_zero(struct lm_ggml_tensor * tensor); + LM_GGML_API struct lm_ggml_opt_params lm_ggml_opt_default_params(enum lm_ggml_opt_type type); // optimize the function defined by the tensor f @@ -2308,12 +2229,6 @@ extern "C" { lm_ggml_opt_callback callback, void * callback_data); - // - // tensor flags - // - LM_GGML_API void lm_ggml_set_input(struct lm_ggml_tensor * tensor); - LM_GGML_API void lm_ggml_set_output(struct lm_ggml_tensor * tensor); - // // quantization // @@ -2469,48 +2384,6 @@ extern "C" { LM_GGML_API size_t lm_gguf_get_meta_size(const struct lm_gguf_context * ctx); LM_GGML_API void lm_gguf_get_meta_data(const struct lm_gguf_context * ctx, void * data); - // - // system info - // - - LM_GGML_API int lm_ggml_cpu_has_avx (void); - LM_GGML_API int lm_ggml_cpu_has_avx_vnni (void); - LM_GGML_API int lm_ggml_cpu_has_avx2 (void); - LM_GGML_API int lm_ggml_cpu_has_avx512 (void); - LM_GGML_API int lm_ggml_cpu_has_avx512_vbmi(void); - LM_GGML_API int lm_ggml_cpu_has_avx512_vnni(void); - LM_GGML_API int lm_ggml_cpu_has_avx512_bf16(void); - LM_GGML_API int lm_ggml_cpu_has_amx_int8 (void); - LM_GGML_API int lm_ggml_cpu_has_fma (void); - LM_GGML_API int lm_ggml_cpu_has_neon (void); - LM_GGML_API int lm_ggml_cpu_has_sve (void); - LM_GGML_API int lm_ggml_cpu_has_arm_fma (void); - LM_GGML_API int lm_ggml_cpu_has_metal (void); - LM_GGML_API int lm_ggml_cpu_has_f16c (void); - LM_GGML_API int lm_ggml_cpu_has_fp16_va (void); - LM_GGML_API int lm_ggml_cpu_has_wasm_simd (void); - LM_GGML_API int lm_ggml_cpu_has_blas (void); - LM_GGML_API int lm_ggml_cpu_has_cuda (void); - LM_GGML_API int lm_ggml_cpu_has_vulkan (void); - LM_GGML_API int lm_ggml_cpu_has_kompute (void); - LM_GGML_API int lm_ggml_cpu_has_gpublas (void); - LM_GGML_API int lm_ggml_cpu_has_sse3 (void); - LM_GGML_API int lm_ggml_cpu_has_ssse3 (void); - LM_GGML_API int lm_ggml_cpu_has_riscv_v (void); - LM_GGML_API int lm_ggml_cpu_has_sycl (void); - LM_GGML_API int lm_ggml_cpu_has_rpc (void); - LM_GGML_API int lm_ggml_cpu_has_vsx (void); - LM_GGML_API int lm_ggml_cpu_has_matmul_int8(void); - LM_GGML_API int lm_ggml_cpu_has_cann (void); - LM_GGML_API int lm_ggml_cpu_has_llamafile (void); - - // get the sve vector length in bytes - LM_GGML_API int lm_ggml_cpu_get_sve_cnt(void); - - // - // Internal types and functions exposed for tests and benchmarks - // - #ifdef __cplusplus // restrict not standard in C++ #define LM_GGML_RESTRICT @@ -2519,14 +2392,6 @@ extern "C" { #endif typedef void (*lm_ggml_to_float_t) (const void * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); typedef void (*lm_ggml_from_float_t)(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); - typedef void (*lm_ggml_from_float_to_mat_t) - (const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t nr, int64_t k, int64_t bs); - typedef void (*lm_ggml_vec_dot_t) (int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT x, size_t bx, - const void * LM_GGML_RESTRICT y, size_t by, int nrc); - typedef void (*lm_ggml_gemv_t) (int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT x, - const void * LM_GGML_RESTRICT y, int nr, int nc); - typedef void (*lm_ggml_gemm_t) (int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT x, - const void * LM_GGML_RESTRICT y, int nr, int nc); struct lm_ggml_type_traits { const char * type_name; @@ -2535,15 +2400,7 @@ extern "C" { size_t type_size; bool is_quantized; lm_ggml_to_float_t to_float; - lm_ggml_from_float_t from_float; lm_ggml_from_float_t from_float_ref; - lm_ggml_from_float_to_mat_t from_float_to_mat; - lm_ggml_vec_dot_t vec_dot; - enum lm_ggml_type vec_dot_type; - int64_t nrows; // number of rows to process simultaneously - int64_t ncols; // number of columns to process simultaneously - lm_ggml_gemv_t gemv; - lm_ggml_gemm_t gemm; }; LM_GGML_API const struct lm_ggml_type_traits * lm_ggml_get_type_traits(enum lm_ggml_type type); diff --git a/cpp/llama-sampling.cpp b/cpp/llama-sampling.cpp index 6d298a46..755c72b0 100644 --- a/cpp/llama-sampling.cpp +++ b/cpp/llama-sampling.cpp @@ -1876,8 +1876,11 @@ static void llama_sampler_dry_reset(struct llama_sampler * smpl) { static struct llama_sampler * llama_sampler_dry_clone(const struct llama_sampler * smpl) { const auto * ctx = (llama_sampler_dry *) smpl->ctx; - // nullptr is passed as vocab because it is only needed for raw sequence breaker processing, which we have already done and will be copying - auto * result = llama_sampler_init_dry(nullptr, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0); + llama_vocab dummy_vocab; + + // dummy vocab is passed because it is only needed for raw sequence breaker processing, which we have already done and will simply be copying + auto * result = llama_sampler_init_dry_impl(dummy_vocab, ctx->total_context_size, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0); + // Copy the state, including the processed breakers { auto * result_ctx = (llama_sampler_dry *) result->ctx; diff --git a/cpp/llama.cpp b/cpp/llama.cpp index 6296abe7..d5c006d3 100644 --- a/cpp/llama.cpp +++ b/cpp/llama.cpp @@ -2312,6 +2312,7 @@ enum e_model { MODEL_1B, MODEL_1_3B, MODEL_1_4B, + MODEL_1_5B, MODEL_1_6B, MODEL_2B, MODEL_2_8B, @@ -2917,9 +2918,15 @@ struct llama_model { // for quantize-stats only std::vector> tensors_by_name; - int64_t t_load_us = 0; + int64_t t_load_us = 0; int64_t t_start_us = 0; + // total number of parameters in the model + uint64_t n_elements = 0; + + // total size of all the tensors in the model in bytes + size_t n_bytes = 0; + // keep track of loaded lora adapters std::set lora_adapters; @@ -3512,11 +3519,24 @@ static bool llama_kv_cache_init( return true; } +// a structure holds information about the slot found in llama_kv_cache_find_slot +struct llama_kv_cache_slot_info { + std::pair boundaries; // slot boundaries [begin, end) + bool found = false; // the slot was found + + explicit llama_kv_cache_slot_info(bool found_) : found{found_} {} + llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {} + + operator bool() const { return found; } +}; +static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false}; + // find an empty slot of size "n_tokens" in the cache // updates the cache head +// returns a structure holding information about the slot found // Note: On success, it's important that cache.head points // to the first cell of the slot. -static bool llama_kv_cache_find_slot( +static struct llama_kv_cache_slot_info llama_kv_cache_find_slot( struct llama_kv_cache & cache, const struct llama_ubatch & batch) { const uint32_t n_tokens = batch.n_tokens; @@ -3544,7 +3564,7 @@ static bool llama_kv_cache_find_slot( // too big seq_id // TODO: would it be possible to resize the cache instead? LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size); - return false; + return llama_kv_cache_slot_info_failed; } if (j > 0) { llama_kv_cell & seq = cache.cells[seq_id]; @@ -3679,15 +3699,17 @@ static bool llama_kv_cache_find_slot( // allow getting the range of used cells, from head to head + n cache.head = min; cache.n = max - min + 1; + cache.used = std::count_if(cache.cells.begin(), cache.cells.end(), + [](const llama_kv_cell& cell){ return !cell.is_empty(); }); // sanity check - return cache.n >= n_seqs; + return llama_kv_cache_slot_info(cache.n >= n_seqs); } // otherwise, one cell per token. if (n_tokens > cache.size) { LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size); - return false; + return llama_kv_cache_slot_info_failed; } uint32_t n_tested = 0; @@ -3715,7 +3737,7 @@ static bool llama_kv_cache_find_slot( if (n_tested >= cache.size) { //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); - return false; + return llama_kv_cache_slot_info_failed; } } @@ -3732,7 +3754,7 @@ static bool llama_kv_cache_find_slot( cache.used += n_tokens; - return true; + return llama_kv_cache_slot_info(cache.head, cache.head + n_tokens); } // find how many cells are currently in use @@ -4008,6 +4030,53 @@ static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams) return cparams.flash_attn ? 256u : 32u; } +// saves the kv_cache state for future recovery. +// used to rollback llama_kv_cache_find_slot changes. +struct llama_kv_slot_restorer { + struct llama_kv_cache_state { + uint32_t head = 0; + uint32_t n = 0; + } old_state; + + // for non-recurrent models only + // list of slots to restore + std::vector> slot_boundaries; + + bool do_restore = false; + + explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) { + old_state.head = cache.head; + old_state.n = cache.n; + } + + // saves a slot information for future restoration + void save(const struct llama_kv_cache_slot_info & slot) { + if (slot) { + do_restore = true; + if (slot.boundaries.first != slot.boundaries.second) { + slot_boundaries.push_back(slot.boundaries); + } + } + } + + // must be explicitly called to restore the kv_cache state + // and rollback changes from all llama_kv_cache_find_slot calls + void restore(struct llama_kv_cache & cache) { + if (do_restore) { + cache.head = old_state.head; + cache.n = old_state.n; + + if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased + llama_kv_cache_seq_rm(cache, -1, -1, -1); + } else { + for (auto & slot : slot_boundaries) { + llama_kv_cache_seq_rm(cache, -1, slot.first, slot.second); + } + } + } + } +}; + // // model loading and saving // @@ -4223,8 +4292,8 @@ struct llama_model_loader { int n_tensors = 0; int n_created = 0; - int64_t n_elements = 0; - size_t n_bytes = 0; + uint64_t n_elements = 0; + size_t n_bytes = 0; bool use_mmap = false; bool check_tensors; @@ -5238,6 +5307,7 @@ static const char * llama_model_type_name(e_model type) { case MODEL_1B: return "1B"; case MODEL_1_3B: return "1.3B"; case MODEL_1_4B: return "1.4B"; + case MODEL_1_5B: return "1.5B"; case MODEL_1_6B: return "1.6B"; case MODEL_2B: return "2B"; case MODEL_2_8B: return "2.8B"; @@ -5291,6 +5361,11 @@ static const char * llama_model_vocab_type_name(enum llama_vocab_type type){ } } +static void llm_load_stats(llama_model_loader & ml, llama_model & model) { + model.n_elements = ml.n_elements; + model.n_bytes = ml.n_bytes; +} + static void llm_load_arch(llama_model_loader & ml, llama_model & model) { model.arch = ml.get_arch(); if (model.arch == LLM_ARCH_UNKNOWN) { @@ -5609,6 +5684,7 @@ static void llm_load_hparams( ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { case 24: model.type = hparams.n_embd == 1024 ? e_model::MODEL_0_5B : e_model::MODEL_1B; break; + case 28: model.type = hparams.n_embd == 1536 ? e_model::MODEL_1_5B : e_model::MODEL_7B; break; case 32: model.type = e_model::MODEL_7B; break; case 40: model.type = hparams.n_head() == 20 ? e_model::MODEL_4B : e_model::MODEL_13B; break; case 80: model.type = e_model::MODEL_70B; break; @@ -7022,7 +7098,7 @@ static const std::map llm_tensor_info_mapping = { {LLM_TENSOR_TIME_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_LERP_G, {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_DECAY, {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_ADD}}, - {LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_RWKV_WKV}}, + {LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_RWKV_WKV6}}, {LLM_TENSOR_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL}}, {LLM_TENSOR_ATTN_NORM_2, {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL}}, {LLM_TENSOR_ATTN_OUT_NORM, {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL}}, @@ -7138,7 +7214,7 @@ static bool weight_buft_supported(const llama_hparams & hparams, lm_ggml_tensor lm_ggml_tensor * C = lm_ggml_new_tensor_3d(ctx, LM_GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs); op_tensor = lm_ggml_ssm_scan(ctx, s, x, dt, w, B, C); } break; - case LM_GGML_OP_RWKV_WKV: + case LM_GGML_OP_RWKV_WKV6: { // FIXME const int64_t S = 123; @@ -7151,7 +7227,7 @@ static bool weight_buft_supported(const llama_hparams & hparams, lm_ggml_tensor lm_ggml_tensor * tf = w; lm_ggml_tensor * td = lm_ggml_new_tensor_4d(ctx, LM_GGML_TYPE_F32, 1, S, H, n_tokens); lm_ggml_tensor * state = lm_ggml_new_tensor_4d(ctx, LM_GGML_TYPE_F32, S, n_seqs, S, H); - op_tensor = lm_ggml_rwkv_wkv(ctx, k, v, r, tf, td, state); + op_tensor = lm_ggml_rwkv_wkv6(ctx, k, v, r, tf, td, state); } break; default: LM_GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, lm_ggml_op_name(op), w->name); @@ -7200,7 +7276,7 @@ static llama_model::buft_list_t make_cpu_buft_list(llama_model & model) { auto * cpu_dev = lm_ggml_backend_dev_by_type(LM_GGML_BACKEND_DEVICE_TYPE_CPU); auto * cpu_reg = lm_ggml_backend_dev_backend_reg(cpu_dev); auto lm_ggml_backend_dev_get_extra_bufts_fn = (lm_ggml_backend_dev_get_extra_bufts_t) - lm_ggml_backend_reg_get_proc_address(cpu_reg, "lm_ggml_backend_cpu_get_extra_bufts"); + lm_ggml_backend_reg_get_proc_address(cpu_reg, "lm_ggml_backend_dev_get_extra_bufts"); if (lm_ggml_backend_dev_get_extra_bufts_fn) { lm_ggml_backend_buffer_type_t * extra_bufts = lm_ggml_backend_dev_get_extra_bufts_fn(cpu_dev); while (extra_bufts && *extra_bufts) { @@ -7467,7 +7543,7 @@ static bool llm_load_tensors( // avoid using a host buffer when using mmap auto * buft_dev = lm_ggml_backend_buft_get_device(buft); - if (ml.use_mmap && buft == lm_ggml_backend_dev_host_buffer_type(buft_dev)) { + if (ml.use_mmap && buft_dev && buft == lm_ggml_backend_dev_host_buffer_type(buft_dev)) { auto * cpu_dev = lm_ggml_backend_dev_by_type(LM_GGML_BACKEND_DEVICE_TYPE_CPU); buft = lm_ggml_backend_dev_buffer_type(cpu_dev); } @@ -9074,6 +9150,10 @@ static bool llm_load_tensors( // check if it is possible to use buffer_from_host_ptr with this buffer type lm_ggml_backend_dev_t dev = lm_ggml_backend_buft_get_device(buft); + if (!dev) { + // FIXME: workaround for CPU backend buft having a NULL device + dev = lm_ggml_backend_reg_dev_get(lm_ggml_backend_cpu_reg(), 0); + } lm_ggml_backend_dev_props props; lm_ggml_backend_dev_get_props(dev, &props); bool buffer_from_host_ptr_supported = props.caps.buffer_from_host_ptr; @@ -9145,7 +9225,7 @@ static bool llm_load_tensors( // print memory requirements per buffer type for (auto & buf : model.bufs) { - LLAMA_LOG_INFO("%s: %10s model buffer size = %8.2f MiB\n", __func__, lm_ggml_backend_buffer_name(buf.get()), lm_ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0); + LLAMA_LOG_INFO("%s: %12s model buffer size = %8.2f MiB\n", __func__, lm_ggml_backend_buffer_name(buf.get()), lm_ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0); } // populate tensors_by_name @@ -9198,6 +9278,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam throw std::runtime_error("error loading model vocabulary: " + std::string(e.what())); } + llm_load_stats(ml, model); llm_load_print_meta(ml, model); if (model.vocab.type != LLAMA_VOCAB_TYPE_NONE && @@ -10094,7 +10175,7 @@ static struct lm_ggml_tensor * llm_build_rwkv6_time_mix( v = lm_ggml_transpose(ctx, v); r = lm_ggml_transpose(ctx, r); - struct lm_ggml_tensor * wkv_output = lm_ggml_rwkv_wkv(ctx, k, v, r, layer->time_mix_first, w, *wkv_state); + struct lm_ggml_tensor * wkv_output = lm_ggml_rwkv_wkv6(ctx, k, v, r, layer->time_mix_first, w, *wkv_state); cur = lm_ggml_view_1d(ctx, wkv_output, n_embd * n_tokens, 0); *wkv_state = lm_ggml_view_1d(ctx, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float)); @@ -17189,7 +17270,8 @@ static void llama_output_reorder(struct llama_context * ctx) { } } -static void llama_graph_compute( +// returns the result of lm_ggml_backend_sched_graph_compute_async execution +static enum lm_ggml_status llama_graph_compute( llama_context & lctx, lm_ggml_cgraph * gf, int n_threads, @@ -17204,15 +17286,20 @@ static void llama_graph_compute( set_n_threads_fn.second(set_n_threads_fn.first, n_threads); } - auto err = lm_ggml_backend_sched_graph_compute_async(lctx.sched.get(), gf); - if (err != LM_GGML_STATUS_SUCCESS) { - LLAMA_LOG_ERROR("%s: lm_ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, err); + auto status = lm_ggml_backend_sched_graph_compute_async(lctx.sched.get(), gf); + if (status != LM_GGML_STATUS_SUCCESS) { + LLAMA_LOG_ERROR("%s: lm_ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status); } // fprintf(stderr, "splits: %d\n", lm_ggml_backend_sched_get_n_splits(lctx.sched)); + + return status; } // decode a batch of tokens by evaluating the transformer +// in case of unsuccessful decoding (error or warning), +// the kv_cache state will be returned to its original state +// (for non-recurrent models) or cleaned (for recurrent models) // // - lctx: llama context // - batch: batch to evaluate @@ -17262,6 +17349,7 @@ static int llama_decode_internal( lctx.n_queued_tokens += n_tokens_all; auto & kv_self = lctx.kv_self; + llama_kv_slot_restorer kv_slot_restorer(kv_self); const int64_t n_embd = hparams.n_embd; const int64_t n_vocab = hparams.n_vocab; @@ -17346,9 +17434,11 @@ static int llama_decode_internal( kv_self.head = 0; } - if (!llama_kv_cache_find_slot(kv_self, ubatch)) { + const auto slot = llama_kv_cache_find_slot(kv_self, ubatch); + if (!slot) { return 1; } + kv_slot_restorer.save(slot); if (!kv_self.recurrent) { // a heuristic, to avoid attending the full cache if it is not yet utilized @@ -17395,7 +17485,19 @@ static int llama_decode_internal( llama_set_inputs(lctx, ubatch); - llama_graph_compute(lctx, gf, n_threads, threadpool); + const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool); + if (compute_status != LM_GGML_STATUS_SUCCESS) { + kv_slot_restorer.restore(kv_self); + switch (compute_status) { + case LM_GGML_STATUS_ABORTED: + return 2; + case LM_GGML_STATUS_ALLOC_FAILED: + return -2; + case LM_GGML_STATUS_FAILED: + default: + return -3; + } + } // update the kv ring buffer { @@ -17632,7 +17734,18 @@ static int llama_encode_internal( llama_set_inputs(lctx, ubatch); - llama_graph_compute(lctx, gf, n_threads, threadpool); + const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool); + switch (compute_status) { + case LM_GGML_STATUS_SUCCESS: + break; + case LM_GGML_STATUS_ABORTED: + return 2; + case LM_GGML_STATUS_ALLOC_FAILED: + return -2; + case LM_GGML_STATUS_FAILED: + default: + return -3; + } // extract embeddings if (embd) { @@ -18511,6 +18624,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s llama_model model; llm_load_arch(ml, model); llm_load_hparams(ml, model); + llm_load_stats(ml, model); struct quantize_state_internal qs(model, params); @@ -19451,12 +19565,26 @@ struct llama_context * llama_new_context_with_model( cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL; } - LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx); - LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch); - LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); - LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn); - LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); - LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); + const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max; + + LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max); + LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx); + LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n", __func__, n_ctx_per_seq); + LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch); + LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); + LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn); + LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); + LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); + + if (n_ctx_per_seq < hparams.n_ctx_train) { + LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n", + __func__, n_ctx_per_seq, hparams.n_ctx_train); + } + + if (n_ctx_per_seq > hparams.n_ctx_train) { + LLAMA_LOG_WARN("%s: n_ctx_pre_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n", + __func__, n_ctx_per_seq, hparams.n_ctx_train); + } ctx->abort_callback = params.abort_callback; ctx->abort_callback_data = params.abort_callback_data; @@ -19849,19 +19977,11 @@ int32_t llama_model_desc(const struct llama_model * model, char * buf, size_t bu } uint64_t llama_model_size(const struct llama_model * model) { - uint64_t size = 0; - for (const auto & it : model->tensors_by_name) { - size += lm_ggml_nbytes(it.second); - } - return size; + return model->n_bytes; } uint64_t llama_model_n_params(const struct llama_model * model) { - uint64_t nparams = 0; - for (const auto & it : model->tensors_by_name) { - nparams += lm_ggml_nelements(it.second); - } - return nparams; + return model->n_elements; } struct lm_ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name) { @@ -21796,8 +21916,11 @@ static int32_t llama_chat_apply_template_internal( // IBM Granite template for (const auto & message : chat) { std::string role(message->role); - ss << "<|start_of_role|>" << role << "<|end_of_role|>" - << message->content << "<|end_of_text|>\n"; + ss << "<|start_of_role|>" << role << "<|end_of_role|>"; + if (role == "assistant_tool_call") { + ss << "<|tool_call|>"; + } + ss << message->content << "<|end_of_text|>\n"; } if (add_ass) { ss << "<|start_of_role|>assistant<|end_of_role|>\n"; @@ -21897,6 +22020,8 @@ int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int } const char * llama_print_system_info(void) { + lm_ggml_cpu_init(); // some ARM features are detected at runtime + static std::string s; s = ""; @@ -21916,7 +22041,6 @@ const char * llama_print_system_info(void) { s += "FP16_VA = " + std::to_string(lm_ggml_cpu_has_fp16_va()) + " | "; s += "RISCV_VECT = " + std::to_string(lm_ggml_cpu_has_riscv_v()) + " | "; s += "WASM_SIMD = " + std::to_string(lm_ggml_cpu_has_wasm_simd()) + " | "; - s += "BLAS = " + std::to_string(lm_ggml_cpu_has_blas()) + " | "; s += "SSE3 = " + std::to_string(lm_ggml_cpu_has_sse3()) + " | "; s += "SSSE3 = " + std::to_string(lm_ggml_cpu_has_ssse3()) + " | "; s += "VSX = " + std::to_string(lm_ggml_cpu_has_vsx()) + " | "; @@ -21962,28 +22086,6 @@ void llama_perf_context_reset(struct llama_context * ctx) { ctx->t_p_eval_us = ctx->n_p_eval = 0; } -void llama_perf_dump_yaml(FILE * stream, const llama_context * ctx) { - fprintf(stream, "\n"); - fprintf(stream, "###########\n"); - fprintf(stream, "# Timings #\n"); - fprintf(stream, "###########\n"); - fprintf(stream, "\n"); - - fprintf(stream, "mst_eval: %.2f # ms / token during generation\n", - 1.0e-3 * ctx->t_eval_us / ctx->n_eval); - fprintf(stream, "mst_p_eval: %.2f # ms / token during prompt processing\n", - 1.0e-3 * ctx->t_p_eval_us / ctx->n_p_eval); - fprintf(stream, "n_eval: %d # number of tokens generated (excluding the first one)\n", ctx->n_eval); - fprintf(stream, "n_p_eval: %d # number of tokens processed in batches at the beginning\n", ctx->n_p_eval); - fprintf(stream, "t_eval_us: %" PRId64 " # total microseconds spent generating tokens\n", ctx->t_eval_us); - fprintf(stream, "t_load_us: %" PRId64 " # total microseconds spent loading the model\n", ctx->t_load_us); - fprintf(stream, "t_p_eval_us: %" PRId64 " # total microseconds spent prompt processing\n", ctx->t_p_eval_us); - fprintf(stream, "ts_eval: %.2f # tokens / second during generation\n", - 1.0e6 * ctx->n_eval / ctx->t_eval_us); - fprintf(stream, "ts_p_eval: %.2f # tokens / second during prompt processing\n", - 1.0e6 * ctx->n_p_eval / ctx->t_p_eval_us); -} - // For internal test use const std::vector> & llama_internal_get_tensor_map( struct llama_context * ctx diff --git a/cpp/llama.h b/cpp/llama.h index 3a6ae914..9d627af1 100644 --- a/cpp/llama.h +++ b/cpp/llama.h @@ -2,6 +2,7 @@ #define LLAMA_H #include "ggml.h" +#include "ggml-cpu.h" #include "ggml-backend.h" #include @@ -796,7 +797,7 @@ extern "C" { // Processes a batch of tokens with the ecoder part of the encoder-decoder model. // Stores the encoder output internally for later use by the decoder cross-attention layers. // 0 - success - // < 0 - error + // < 0 - error. the KV cache state is restored to the state before this call LLAMA_API int32_t llama_encode( struct llama_context * ctx, struct llama_batch batch); @@ -804,7 +805,7 @@ extern "C" { // Positive return values does not mean a fatal error, but rather a warning. // 0 - success // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) - // < 0 - error + // < 0 - error. the KV cache state is restored to the state before this call LLAMA_API int32_t llama_decode( struct llama_context * ctx, struct llama_batch batch); @@ -1243,8 +1244,6 @@ extern "C" { LLAMA_API void llama_perf_sampler_print(const struct llama_sampler * chain); LLAMA_API void llama_perf_sampler_reset( struct llama_sampler * chain); - LLAMA_API void llama_perf_dump_yaml(FILE * stream, const struct llama_context * ctx); - #ifdef __cplusplus } #endif diff --git a/cpp/sgemm.cpp b/cpp/sgemm.cpp index 41dbb452..5f8cc2d5 100644 --- a/cpp/sgemm.cpp +++ b/cpp/sgemm.cpp @@ -106,6 +106,10 @@ inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); } inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); } #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#if defined(__MMA__) +typedef vector unsigned char vec_t; +typedef __vector_quad acc_t; +#endif //////////////////////////////////////////////////////////////////////////////////////////////////// // VECTORIZED FUSED MULTIPLY ADD @@ -1026,6 +1030,600 @@ class tinyBLAS_Q0_AVX { }; #endif // __AVX__ +//PPC Implementation +#if defined(__MMA__) + +#define SAVE_ACC(ACC, ii, jj) \ + __builtin_mma_disassemble_acc(vec_C, ACC); \ + for (int I = 0; I < 4; I++) { \ + for (int J = 0; J < 4; J++) { \ + *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J); \ + } \ + } \ + +template +class tinyBLAS_PPC { + public: + tinyBLAS_PPC(int64_t k, + const TA *A, int64_t lda, + const TB *B, int64_t ldb, + TC *C, int64_t ldc, + int ith, int nth) + : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { + } + + void matmul(int64_t m, int64_t n) { + mnpack(0, m, 0, n); + } + + private: + + void (tinyBLAS_PPC::*kernel)(int64_t, int64_t); + + void READ_BLOCK(const float* a, int64_t lda, int rows, int cols, float* vec) { + int64_t i, j; + float *aoffset = NULL, *boffset = NULL; + float *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL; + float *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL; + + aoffset = const_cast(a); + boffset = vec; + j = (rows >> 3); + if (j > 0) { + do { + aoffset1 = aoffset; + aoffset2 = aoffset1 + lda; + aoffset3 = aoffset2 + lda; + aoffset4 = aoffset3 + lda; + aoffset5 = aoffset4 + lda; + aoffset6 = aoffset5 + lda; + aoffset7 = aoffset6 + lda; + aoffset8 = aoffset7 + lda; + aoffset += 8 * lda; + i = (cols >> 3); + if (i > 0) { + __vector_pair C1, C2, C3, C4, C5, C6, C7, C8; + vector float c1[2], c2[2], c3[2], c4[2], c5[2], c6[2], c7[2], c8[2]; + vector float t1, t2, t3, t4, t5, t6, t7, t8; + do { + C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1); + C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2); + C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3); + C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4); + C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5); + C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6); + C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7); + C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8); + __builtin_vsx_disassemble_pair(c1, &C1); + __builtin_vsx_disassemble_pair(c2, &C2); + __builtin_vsx_disassemble_pair(c3, &C3); + __builtin_vsx_disassemble_pair(c4, &C4); + __builtin_vsx_disassemble_pair(c5, &C5); + __builtin_vsx_disassemble_pair(c6, &C6); + __builtin_vsx_disassemble_pair(c7, &C7); + __builtin_vsx_disassemble_pair(c8, &C8); + + t1 = vec_mergeh(c1[0], c2[0]); + t2 = vec_mergeh(c3[0], c4[0]); + t3 = vec_mergeh(c5[0], c6[0]); + t4 = vec_mergeh(c7[0], c8[0]); + t5 = vec_xxpermdi(t1, t2, 0); + t6 = vec_xxpermdi(t3, t4, 0); + t7 = vec_xxpermdi(t1, t2, 3); + t8 = vec_xxpermdi(t3, t4, 3); + vec_xst(t5, 0, boffset); + vec_xst(t6, 0, boffset+4); + vec_xst(t7, 0, boffset+8); + vec_xst(t8, 0, boffset+12); + + t1 = vec_mergel(c1[0], c2[0]); + t2 = vec_mergel(c3[0], c4[0]); + t3 = vec_mergel(c5[0], c6[0]); + t4 = vec_mergel(c7[0], c8[0]); + t5 = vec_xxpermdi(t1, t2, 0); + t6 = vec_xxpermdi(t3, t4, 0); + t7 = vec_xxpermdi(t1, t2, 3); + t8 = vec_xxpermdi(t3, t4, 3); + vec_xst(t5, 0, boffset+16); + vec_xst(t6, 0, boffset+20); + vec_xst(t7, 0, boffset+24); + vec_xst(t8, 0, boffset+28); + + t1 = vec_mergeh(c1[1], c2[1]); + t2 = vec_mergeh(c3[1], c4[1]); + t3 = vec_mergeh(c5[1], c6[1]); + t4 = vec_mergeh(c7[1], c8[1]); + t5 = vec_xxpermdi(t1, t2, 0); + t6 = vec_xxpermdi(t3, t4, 0); + t7 = vec_xxpermdi(t1, t2, 3); + t8 = vec_xxpermdi(t3, t4, 3); + vec_xst(t5, 0, boffset+32); + vec_xst(t6, 0, boffset+36); + vec_xst(t7, 0, boffset+40); + vec_xst(t8, 0, boffset+44); + + t1 = vec_mergel(c1[1], c2[1]); + t2 = vec_mergel(c3[1], c4[1]); + t3 = vec_mergel(c5[1], c6[1]); + t4 = vec_mergel(c7[1], c8[1]); + t5 = vec_xxpermdi(t1, t2, 0); + t6 = vec_xxpermdi(t3, t4, 0); + t7 = vec_xxpermdi(t1, t2, 3); + t8 = vec_xxpermdi(t3, t4, 3); + vec_xst(t5, 0, boffset+48); + vec_xst(t6, 0, boffset+52); + vec_xst(t7, 0, boffset+56); + vec_xst(t8, 0, boffset+60); + + aoffset1 += 8*lda; + aoffset2 += 8*lda; + aoffset3 += 8*lda; + aoffset4 += 8*lda; + boffset += 64; + i--; + } while(i > 0); + } + if (cols & 4) { + vector float c1, c2, c3, c4, c5, c6, c7, c8; + vector float t1, t2, t3, t4, t5, t6, t7, t8; + c1 = vec_xl(0, aoffset1); + c2 = vec_xl(0, aoffset2); + c3 = vec_xl(0, aoffset3); + c4 = vec_xl(0, aoffset4); + c5 = vec_xl(0, aoffset5); + c6 = vec_xl(0, aoffset6); + c7 = vec_xl(0, aoffset7); + c8 = vec_xl(0, aoffset8); + + t1 = vec_mergeh(c1, c2); + t2 = vec_mergeh(c3, c4); + t3 = vec_mergeh(c5, c6); + t4 = vec_mergeh(c7, c8); + t5 = vec_xxpermdi(t1, t2, 0); + t6 = vec_xxpermdi(t3, t4, 0); + t7 = vec_xxpermdi(t1, t2, 3); + t8 = vec_xxpermdi(t3, t4, 3); + vec_xst(t5, 0, boffset); + vec_xst(t6, 0, boffset+4); + vec_xst(t7, 0, boffset+8); + vec_xst(t8, 0, boffset+12); + + t1 = vec_mergel(c1, c2); + t2 = vec_mergel(c3, c4); + t3 = vec_mergel(c5, c6); + t4 = vec_mergel(c7, c8); + t5 = vec_xxpermdi(t1, t2, 0); + t6 = vec_xxpermdi(t3, t4, 0); + t7 = vec_xxpermdi(t1, t2, 3); + t8 = vec_xxpermdi(t3, t4, 3); + vec_xst(t5, 0, boffset+16); + vec_xst(t6, 0, boffset+20); + vec_xst(t7, 0, boffset+24); + vec_xst(t8, 0, boffset+28); + } + j--; + } while(j > 0); + } + + if (rows & 4) { + aoffset1 = aoffset; + aoffset2 = aoffset1 + lda; + aoffset3 = aoffset2 + lda; + aoffset4 = aoffset3 + lda; + aoffset += 4 * lda; + i = (cols >> 3); + if (i > 0) { + __vector_pair C1, C2, C3, C4; + vector float c1[2], c2[2], c3[2], c4[2]; + vector float t1, t2, t3, t4, t5, t6, t7, t8; + do { + C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1); + C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2); + C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3); + C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4); + __builtin_vsx_disassemble_pair(c1, &C1); + __builtin_vsx_disassemble_pair(c2, &C2); + __builtin_vsx_disassemble_pair(c3, &C3); + __builtin_vsx_disassemble_pair(c4, &C4); + + t1 = vec_mergeh(c1[0], c2[0]); + t2 = vec_mergeh(c3[0], c4[0]); + t3 = vec_mergel(c1[0], c2[0]); + t4 = vec_mergel(c3[0], c4[0]); + t5 = vec_xxpermdi(t1, t2, 0); + t6 = vec_xxpermdi(t1, t2, 3); + t7 = vec_xxpermdi(t3, t4, 0); + t8 = vec_xxpermdi(t3, t4, 3); + vec_xst(t5, 0, boffset); + vec_xst(t6, 0, boffset+4); + vec_xst(t7, 0, boffset+8); + vec_xst(t8, 0, boffset+12); + + t1 = vec_mergeh(c1[1], c2[1]); + t2 = vec_mergeh(c3[1], c4[1]); + t3 = vec_mergel(c1[1], c2[1]); + t4 = vec_mergel(c3[1], c4[1]); + t5 = vec_xxpermdi(t1, t2, 0); + t6 = vec_xxpermdi(t1, t2, 3); + t7 = vec_xxpermdi(t3, t4, 0); + t8 = vec_xxpermdi(t3, t4, 3); + vec_xst(t5, 0, boffset+16); + vec_xst(t6, 0, boffset+20); + vec_xst(t7, 0, boffset+24); + vec_xst(t8, 0, boffset+28); + + aoffset1 += 8*lda; + aoffset2 += 8*lda; + aoffset3 += 8*lda; + aoffset4 += 8*lda; + boffset += 32; + i--; + } while(i > 0); + } + + if (cols & 4) { + vector float c1, c2, c3, c4; + vector float t1, t2, t3, t4; + c1 = vec_xl(0, aoffset1); + c2 = vec_xl(0, aoffset2); + c3 = vec_xl(0, aoffset3); + c4 = vec_xl(0, aoffset4); + + t1 = vec_mergeh(c1, c2); + t2 = vec_mergeh(c3, c4); + t3 = vec_xxpermdi(t1, t2, 0); + t4 = vec_xxpermdi(t1, t2, 3); + vec_xst(t3, 0, boffset); + vec_xst(t4, 0, boffset+4); + + t1 = vec_mergel(c1, c2); + t2 = vec_mergel(c3, c4); + t3 = vec_xxpermdi(t1, t2, 0); + t4 = vec_xxpermdi(t1, t2, 3); + vec_xst(t3, 0, boffset+8); + vec_xst(t4, 0, boffset+12); + } + } + if (rows & 3) { + aoffset1 = aoffset; + aoffset2 = aoffset1 + lda; + aoffset3 = aoffset2 + lda; + if (cols & 4) { + vector float c1, c2, c3, c4 = {0}; + vector float t1, t2, t3, t4; + c1 = vec_xl(0, aoffset1); + c2 = vec_xl(0, aoffset2); + c3 = vec_xl(0, aoffset3); + + t1 = vec_mergeh(c1, c2); + t2 = vec_mergeh(c3, c4); + t3 = vec_xxpermdi(t1, t2, 0); + t4 = vec_xxpermdi(t1, t2, 3); + vec_xst(t3, 0, boffset); + vec_xst(t4, 0, boffset+4); + + t1 = vec_mergel(c1, c2); + t2 = vec_mergel(c3, c4); + t3 = vec_xxpermdi(t1, t2, 0); + t4 = vec_xxpermdi(t1, t2, 3); + vec_xst(t3, 0, boffset+8); + vec_xst(t4, 0, boffset+12); + } + } + } + + void KERNEL_4x4(int64_t ii, int64_t jj) { + vec_t vec_A[4], vec_B[4], vec_C[4]; + acc_t acc_0; + __builtin_mma_xxsetaccz(&acc_0); + for (int l = 0; l < k; l+=4) { + READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A); + READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B); + __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]); + __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]); + __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]); + __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]); + } + SAVE_ACC(&acc_0, ii, jj); + } + + void KERNEL_4x8(int64_t ii, int64_t jj) { + vec_t vec_A[4], vec_B[8], vec_C[4]; + acc_t acc_0, acc_1; + __builtin_mma_xxsetaccz(&acc_0); + __builtin_mma_xxsetaccz(&acc_1); + for (int64_t l = 0; l < k; l+=4) { + READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A); + READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B); + __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]); + __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]); + __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]); + __builtin_mma_xvf32gerpp(&acc_1, vec_A[1], (vec_t)vec_B[3]); + __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], (vec_t)vec_B[4]); + __builtin_mma_xvf32gerpp(&acc_1, vec_A[2], (vec_t)vec_B[5]); + __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], (vec_t)vec_B[6]); + __builtin_mma_xvf32gerpp(&acc_1, vec_A[3], (vec_t)vec_B[7]); + } + SAVE_ACC(&acc_0, ii, jj); + SAVE_ACC(&acc_1, ii, jj+4); + } + + void KERNEL_8x4(int64_t ii, int64_t jj) { + vec_t vec_A[8], vec_B[4], vec_C[4]; + acc_t acc_0, acc_1; + __builtin_mma_xxsetaccz(&acc_0); + __builtin_mma_xxsetaccz(&acc_1); + for (int64_t l = 0; l < k; l+=4) { + READ_BLOCK(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A); + READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B); + __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]); + __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]); + __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]); + __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[3], vec_B[1]); + __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[4], vec_B[2]); + __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[5], vec_B[2]); + __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[6], vec_B[3]); + __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[7], vec_B[3]); + } + SAVE_ACC(&acc_0, ii, jj); + SAVE_ACC(&acc_1, ii+4, jj); + } + + void KERNEL_8x8(int64_t ii, int64_t jj) { + vec_t vec_A[16], vec_B[16], vec_C[4]; + acc_t acc_0, acc_1, acc_2, acc_3; + __builtin_mma_xxsetaccz(&acc_0); + __builtin_mma_xxsetaccz(&acc_1); + __builtin_mma_xxsetaccz(&acc_2); + __builtin_mma_xxsetaccz(&acc_3); + for (int l = 0; l < k; l+=8) { + READ_BLOCK(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A); + READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B); + for(int x = 0; x < 16; x+=2) { + __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]); + __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]); + __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x+1], vec_B[x]); + __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x+1], vec_B[x+1]); + } + } + SAVE_ACC(&acc_0, ii, jj); + SAVE_ACC(&acc_1, ii, jj+4); + SAVE_ACC(&acc_2, ii+4, jj); + SAVE_ACC(&acc_3, ii+4, jj+4); + } + + void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t mc, nc, mp, np; + int m_rem = MIN(m - m0, 16); + int n_rem = MIN(n - n0, 16); + if (m_rem >= 16 && n_rem >= 8) { + mc = 8; + nc = 8; + gemm<8,8>(m0, m, n0, n); + } else if(m_rem >= 8 && n_rem >= 16) { + mc = 8; + nc = 8; + gemm<8,8>(m0, m, n0, n); + } else if (m_rem >= 8 && n_rem >= 8) { + mc = 8; + nc = 8; + gemm<8,8>(m0, m, n0, n); + } else if (m_rem >= 4 && n_rem >= 8) { + mc = 4; + nc = 8; + gemm<4,8>(m0, m, n0, n); + } else if (m_rem >= 8 && n_rem >= 4) { + mc = 8; + nc = 4; + gemm<8,4>(m0, m, n0, n); + } else if (m_rem >= 4 && n_rem >= 4) { + mc = 4; + nc = 4; + gemm<4,4>(m0, m, n0, n); + } else if ((m_rem < 4) && (n_rem > 4)) { + nc = 4; + switch(m_rem) { + case 1: + mc = 1; + gemm_small(m0, m, n0, n, mc, nc); + break; + case 2: + mc = 2; + gemm_small(m0, m, n0, n, mc, nc); + break; + case 3: + mc = 3; + gemm_small(m0, m, n0, n, mc, nc); + break; + default: + return; + } + } else if ((m_rem > 4) && (n_rem < 4)) { + mc = 4; + switch(n_rem) { + case 1: + nc = 1; + gemm_small(m0, m, n0, n, mc, nc); + break; + case 2: + nc = 2; + gemm_small(m0, m, n0, n, mc, nc); + break; + case 3: + nc = 3; + gemm_small(m0, m, n0, n, mc, nc); + break; + default: + return; + } + } else { + switch((m_rem << 4) | n_rem) { + case 0x43: + mc = 4; + nc = 3; + gemm_small(m0, m, n0, n, mc, nc); + break; + case 0x42: + mc = 4; + nc = 2; + gemm_small(m0, m, n0, n, mc, nc); + break; + case 0x41: + mc = 4; + nc = 1; + gemm_small(m0, m, n0, n, mc, nc); + break; + case 0x34: + mc = 3; + nc = 4; + gemm_small(m0, m, n0, n, mc, nc); + break; + case 0x33: + mc = 3; + nc = 3; + gemm_small(m0, m, n0, n, mc, nc); + break; + case 0x32: + mc = 3; + nc = 2; + gemm_small(m0, m, n0, n, mc, nc); + break; + case 0x31: + mc = 3; + nc = 1; + gemm_small(m0, m, n0, n, mc, nc); + break; + case 0x24: + mc = 2; + nc = 4; + gemm_small(m0, m, n0, n, mc, nc); + break; + case 0x23: + mc = 2; + nc = 3; + gemm_small(m0, m, n0, n, mc, nc); + break; + case 0x22: + mc = 2; + nc = 2; + gemm_small(m0, m, n0, n, mc, nc); + break; + case 0x21: + mc = 2; + nc = 1; + gemm_small(m0, m, n0, n, mc, nc); + break; + case 0x14: + mc = 1; + nc = 4; + gemm_small(m0, m, n0, n, mc, nc); + break; + case 0x13: + mc = 1; + nc = 3; + gemm_small(m0, m, n0, n, mc, nc); + break; + case 0x12: + mc = 1; + nc = 2; + gemm_small(m0, m, n0, n, mc, nc); + break; + case 0x11: + mc = 1; + nc = 1; + gemm_small(m0, m, n0, n, mc, nc); + break; + default: + return; + } + } + mp = m0 + (m - m0) / mc * mc; + np = n0 + (n - n0) / nc * nc; + mnpack(mp, m, n0, np); + mnpack(m0, m, np, n); + } + + void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) { + int64_t ytiles = (m - m0) / RM; + int64_t xtiles = (n - n0) / RN; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; + if (end > tiles) + end = tiles; + for (int64_t job = start; job < end; ++job) { + int64_t ii = m0 + job / xtiles * RM; + int64_t jj = n0 + job % xtiles * RN; + vec_t vec_C[4]; + acc_t acc_0; + __builtin_mma_xxsetaccz(&acc_0); + vec_t vec_A[4], vec_B[4]; + for (int l=0; l= 4 && RM == 1) { + float* a = const_cast(A+(ii)*lda+l); + READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B); + vec_A[0] = (vec_t)vec_xl(0,a); + vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1)); + vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2)); + vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3)); + } else { + READ_BLOCK(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A); + READ_BLOCK(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B); + } + __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]); + __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]); + __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]); + __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]); + } + __builtin_mma_disassemble_acc(vec_C, &acc_0); + for (int I = 0; I < RM; I++) { + for (int J = 0; J < RN; J++) { + *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J); + } + } + } + } + + template + NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t ytiles = (m - m0) / RM; + int64_t xtiles = (n - n0) / RN; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; + if (RM == 4 && RN == 4) { + kernel = &tinyBLAS_PPC::KERNEL_4x4; + } else if (RM == 4 && RN == 8) { + kernel = &tinyBLAS_PPC::KERNEL_4x8; + } else if (RM == 8 && RN == 4) { + kernel = &tinyBLAS_PPC::KERNEL_8x4; + } else if (RM == 8 && RN == 8) { + kernel = &tinyBLAS_PPC::KERNEL_8x8; + } + if (end > tiles) + end = tiles; + for (int64_t job = start; job < end; ++job) { + int64_t ii = m0 + job / xtiles * RM; + int64_t jj = n0 + job % xtiles * RN; + (this->*kernel)(ii, jj); + } + } + + const TA *const A; + const TB *const B; + TC *C; + TA *At; + TB *Bt; + const int64_t k; + const int64_t lda; + const int64_t ldb; + const int64_t ldc; + const int ith; + const int nth; +}; +#endif } // namespace /** @@ -1114,6 +1712,16 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda ith, nth}; tb.matmul(m, n); return true; +#elif defined(__MMA__) + if (k % 8) + return false; + tinyBLAS_PPC tb{ + k, (const float *)A, lda, + (const float *)B, ldb, + (float *)C, ldc, + ith, nth}; + tb.matmul(m, n); + return true; #else return false; #endif diff --git a/example/ios/.xcode.env.local b/example/ios/.xcode.env.local index 12b13b28..20e57253 100644 --- a/example/ios/.xcode.env.local +++ b/example/ios/.xcode.env.local @@ -1 +1 @@ -export NODE_BINARY=/var/folders/4z/1d45cfts3936kdm7v9jl349r0000gn/T/yarn--1730859907282-0.2156595720676977/node +export NODE_BINARY=/var/folders/4z/1d45cfts3936kdm7v9jl349r0000gn/T/yarn--1731817237858-0.44085333624502776/node diff --git a/llama.cpp b/llama.cpp index a6744e43..0fff7fd7 160000 --- a/llama.cpp +++ b/llama.cpp @@ -1 +1 @@ -Subproject commit a6744e43e80f4be6398fc7733a01642c846dce1d +Subproject commit 0fff7fd79818980763a601660f25b01a0cf4b87a diff --git a/scripts/bootstrap.sh b/scripts/bootstrap.sh index 6afbb02b..db30c27b 100755 --- a/scripts/bootstrap.sh +++ b/scripts/bootstrap.sh @@ -6,26 +6,38 @@ git submodule update --recursive cp ./llama.cpp/include/llama.h ./cpp/llama.h cp ./llama.cpp/ggml/include/ggml.h ./cpp/ggml.h -cp ./llama.cpp/ggml/include/ggml-cpp.h ./cpp/ggml-cpp.h cp ./llama.cpp/ggml/include/ggml-alloc.h ./cpp/ggml-alloc.h cp ./llama.cpp/ggml/include/ggml-backend.h ./cpp/ggml-backend.h +cp ./llama.cpp/ggml/include/ggml-cpu.h ./cpp/ggml-cpu.h +cp ./llama.cpp/ggml/include/ggml-cpp.h ./cpp/ggml-cpp.h cp ./llama.cpp/ggml/include/ggml-metal.h ./cpp/ggml-metal.h cp ./llama.cpp/ggml/src/ggml.c ./cpp/ggml.c -cp ./llama.cpp/ggml/src/ggml-metal.m ./cpp/ggml-metal.m +cp ./llama.cpp/ggml/src/ggml-impl.h ./cpp/ggml-impl.h cp ./llama.cpp/ggml/src/ggml-alloc.c ./cpp/ggml-alloc.c cp ./llama.cpp/ggml/src/ggml-backend.cpp ./cpp/ggml-backend.cpp cp ./llama.cpp/ggml/src/ggml-backend-impl.h ./cpp/ggml-backend-impl.h -cp ./llama.cpp/ggml/src/ggml-impl.h ./cpp/ggml-impl.h -cp ./llama.cpp/ggml/src/ggml-cpu-impl.h ./cpp/ggml-cpu-impl.h +cp ./llama.cpp/ggml/src/ggml-backend-reg.cpp ./cpp/ggml-backend-reg.cpp cp ./llama.cpp/ggml/src/ggml-common.h ./cpp/ggml-common.h cp ./llama.cpp/ggml/src/ggml-quants.h ./cpp/ggml-quants.h cp ./llama.cpp/ggml/src/ggml-quants.c ./cpp/ggml-quants.c cp ./llama.cpp/ggml/src/ggml-aarch64.c ./cpp/ggml-aarch64.c cp ./llama.cpp/ggml/src/ggml-aarch64.h ./cpp/ggml-aarch64.h +cp ./llama.cpp/ggml/src/ggml-threading.cpp ./cpp/ggml-threading.cpp +cp ./llama.cpp/ggml/src/ggml-threading.h ./cpp/ggml-threading.h -cp ./llama.cpp/ggml/src/llamafile/sgemm.h ./cpp/sgemm.h -cp ./llama.cpp/ggml/src/llamafile/sgemm.cpp ./cpp/sgemm.cpp +cp ./llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c ./cpp/ggml-cpu.c +cp ./llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp ./cpp/ggml-cpu.cpp +cp ./llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h ./cpp/ggml-cpu-impl.h +cp ./llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h ./cpp/ggml-cpu-aarch64.h +cp ./llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c ./cpp/ggml-cpu-aarch64.c +cp ./llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h ./cpp/ggml-cpu-quants.h +cp ./llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c ./cpp/ggml-cpu-quants.c + +cp ./llama.cpp/ggml/src/ggml-metal/ggml-metal.m ./cpp/ggml-metal.m + +cp ./llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h ./cpp/sgemm.h +cp ./llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp ./cpp/sgemm.cpp cp ./llama.cpp/src/llama.cpp ./cpp/llama.cpp cp ./llama.cpp/src/llama-impl.h ./cpp/llama-impl.h @@ -65,6 +77,7 @@ files_add_lm_prefix=( "./cpp/log.cpp" "./cpp/ggml.h" "./cpp/ggml.c" + "./cpp/ggml-impl.h" "./cpp/common.h" "./cpp/common.cpp" "./cpp/ggml-cpp.h" @@ -80,8 +93,17 @@ files_add_lm_prefix=( "./cpp/ggml-backend.h" "./cpp/ggml-backend.cpp" "./cpp/ggml-backend-impl.h" - "./cpp/ggml-impl.h" + "./cpp/ggml-backend-reg.cpp" "./cpp/ggml-cpu-impl.h" + "./cpp/ggml-cpu.h" + "./cpp/ggml-cpu.c" + "./cpp/ggml-cpu.cpp" + "./cpp/ggml-cpu-aarch64.h" + "./cpp/ggml-cpu-aarch64.c" + "./cpp/ggml-cpu-quants.h" + "./cpp/ggml-cpu-quants.c" + "./cpp/ggml-threading.h" + "./cpp/ggml-threading.cpp" "./cpp/ggml-common.h" "./cpp/sgemm.h" "./cpp/sgemm.cpp" @@ -140,17 +162,20 @@ patch -p0 -d ./cpp < ./scripts/common.cpp.patch patch -p0 -d ./cpp < ./scripts/log.cpp.patch patch -p0 -d ./cpp < ./scripts/llama.cpp.patch patch -p0 -d ./cpp < ./scripts/ggml-metal.m.patch -patch -p0 -d ./cpp < ./scripts/ggml-backend.cpp.patch +patch -p0 -d ./cpp < ./scripts/ggml-backend-reg.cpp.patch patch -p0 -d ./cpp < ./scripts/ggml.c.patch +patch -p0 -d ./cpp < ./scripts/ggml-quants.c.patch +patch -p0 -d ./cpp < ./scripts/ggml-cpu-aarch64.c.patch +patch -p0 -d ./cpp < ./scripts/sgemm.cpp.patch if [ "$OS" = "Darwin" ]; then # Build metallib (~1.4MB) - cd llama.cpp/ggml/src/ + cd llama.cpp/ggml/src/ggml-metal xcrun --sdk iphoneos metal -c ggml-metal.metal -o ggml-metal.air xcrun --sdk iphoneos metallib ggml-metal.air -o ggml-llama.metallib rm ggml-metal.air - cp ./ggml-llama.metallib ../../../cpp/ggml-llama.metallib + cp ./ggml-llama.metallib ../../../../cpp/ggml-llama.metallib cd - diff --git a/scripts/common.h.patch b/scripts/common.h.patch index 87b54b8e..9a2b365c 100644 --- a/scripts/common.h.patch +++ b/scripts/common.h.patch @@ -1,5 +1,5 @@ ---- common.h.orig 2024-11-04 12:59:08 -+++ common.h 2024-11-04 12:58:24 +--- common.h.orig 2024-11-17 11:56:40 ++++ common.h 2024-11-17 11:56:41 @@ -41,6 +41,17 @@ struct common_control_vector_load_info; @@ -24,9 +24,9 @@ struct common_params { + bool vocab_only = false; int32_t n_predict = -1; // new tokens to predict - int32_t n_ctx = 0; // context size + int32_t n_ctx = 4096; // context size int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS) -@@ -271,6 +283,9 @@ +@@ -270,6 +282,9 @@ bool warmup = true; // warmup run bool check_tensors = false; // validate tensor data diff --git a/scripts/ggml-backend.cpp.patch b/scripts/ggml-backend-reg.cpp.patch similarity index 50% rename from scripts/ggml-backend.cpp.patch rename to scripts/ggml-backend-reg.cpp.patch index 48757f45..8b80d697 100644 --- a/scripts/ggml-backend.cpp.patch +++ b/scripts/ggml-backend-reg.cpp.patch @@ -1,13 +1,29 @@ ---- ggml-backend.cpp.orig 2024-11-02 18:37:57 -+++ ggml-backend.cpp 2024-11-02 18:39:36 -@@ -575,8 +575,11 @@ - register_backend(lm_ggml_backend_cuda_reg()); +--- ggml-backend-reg.cpp.orig 2024-11-17 11:53:44 ++++ ggml-backend-reg.cpp 2024-11-17 11:53:17 +@@ -12,9 +12,14 @@ #endif + #ifdef LM_GGML_USE_METAL +#include ++ ++#if !TARGET_OS_SIMULATOR + #include "ggml-metal.h" + #endif + ++#endif ++ + #ifdef LM_GGML_USE_SYCL + #include "ggml-sycl.h" + #endif +@@ -52,8 +57,12 @@ + register_backend(lm_ggml_backend_cuda_reg()); + #endif + #ifdef LM_GGML_USE_METAL ++ +#if !TARGET_OS_SIMULATOR register_backend(lm_ggml_backend_metal_reg()); #endif ++ +#endif #ifdef LM_GGML_USE_SYCL register_backend(lm_ggml_backend_sycl_reg()); diff --git a/scripts/ggml-cpu-aarch64.c.patch b/scripts/ggml-cpu-aarch64.c.patch new file mode 100644 index 00000000..5cadaa07 --- /dev/null +++ b/scripts/ggml-cpu-aarch64.c.patch @@ -0,0 +1,11 @@ +--- ggml-cpu-aarch64.c.orig 2024-11-17 12:15:45 ++++ ggml-cpu-aarch64.c 2024-11-17 12:15:56 +@@ -8,7 +8,7 @@ + #include "ggml-quants.h" + #include "ggml-impl.h" + #include "ggml-cpu.h" +-#include "ggml-cpu/ggml-cpu-impl.h" ++#include "ggml-cpu-impl.h" + + #include + #include diff --git a/scripts/ggml-metal.m.patch b/scripts/ggml-metal.m.patch index 4f3821a4..b2c8d86c 100644 --- a/scripts/ggml-metal.m.patch +++ b/scripts/ggml-metal.m.patch @@ -1,9 +1,9 @@ ---- ggml-metal.m.orig -+++ ggml-metal.m -@@ -389,7 +389,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend +--- ggml-metal.m.orig 2024-11-17 11:52:03 ++++ ggml-metal.m 2024-11-17 11:52:05 +@@ -461,7 +461,7 @@ const bool try_metallib = true; #endif - + - NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"]; + NSString * path_lib = [bundle pathForResource:@"ggml-llama" ofType:@"metallib"]; if (try_metallib && path_lib != nil) { diff --git a/scripts/ggml-quants.c.patch b/scripts/ggml-quants.c.patch new file mode 100644 index 00000000..5563f841 --- /dev/null +++ b/scripts/ggml-quants.c.patch @@ -0,0 +1,11 @@ +--- ggml-quants.c.orig 2024-11-17 12:15:22 ++++ ggml-quants.c 2024-11-17 12:14:57 +@@ -3,7 +3,7 @@ + + #include "ggml-quants.h" + #include "ggml-impl.h" +-#include "ggml-cpu/ggml-cpu-impl.h" ++#include "ggml-cpu-impl.h" + #include "ggml-cpu.h" + + #include diff --git a/scripts/ggml.c.patch b/scripts/ggml.c.patch index 25744ff2..f104777b 100644 --- a/scripts/ggml.c.patch +++ b/scripts/ggml.c.patch @@ -1,6 +1,6 @@ ---- ggml.c.orig -+++ ggml.c -@@ -242,9 +242,9 @@ static void lm_ggml_print_backtrace_symbols(void) { +--- ggml.c.orig 2024-11-17 12:20:04 ++++ ggml.c 2024-11-17 12:20:05 +@@ -114,9 +114,9 @@ #elif defined(__linux__) && defined(__GLIBC__) #include static void lm_ggml_print_backtrace_symbols(void) { diff --git a/scripts/sgemm.cpp.patch b/scripts/sgemm.cpp.patch new file mode 100644 index 00000000..590be27a --- /dev/null +++ b/scripts/sgemm.cpp.patch @@ -0,0 +1,12 @@ +--- sgemm.cpp.orig 2024-11-17 12:18:43 ++++ sgemm.cpp 2024-11-17 12:19:12 +@@ -50,8 +50,7 @@ + + #include "sgemm.h" + #include "ggml-impl.h" +-// hack until moved into the CPU backend +-#include "../ggml-cpu-impl.h" ++#include "ggml-cpu-impl.h" + #include "ggml-quants.h" + + #ifdef _MSC_VER