From 9ccaad95a6569d9ec27b43add0db3bebad1a8fe0 Mon Sep 17 00:00:00 2001 From: Jhen-Jie Hong Date: Thu, 21 Nov 2024 11:03:14 +0800 Subject: [PATCH 1/2] feat: sync llama.cpp --- android/src/main/CMakeLists.txt | 1 + cpp/common.cpp | 6 + cpp/ggml-alloc.c | 14 +- cpp/ggml-backend.cpp | 14 +- cpp/ggml-backend.h | 26 +- cpp/ggml-cpu.c | 43 +- cpp/ggml-impl.h | 19 +- cpp/ggml-metal.m | 704 +++++++------ cpp/ggml-opt.cpp | 867 +++++++++++++++ cpp/ggml-opt.h | 216 ++++ cpp/ggml.c | 1748 ++++++++++--------------------- cpp/ggml.h | 218 +--- cpp/llama.cpp | 208 +++- cpp/llama.h | 3 + example/ios/.xcode.env.local | 2 +- llama.cpp | 2 +- scripts/bootstrap.sh | 4 + scripts/common.cpp.patch | 10 +- scripts/ggml-metal.m.patch | 6 +- scripts/llama.cpp.patch | 6 +- 20 files changed, 2349 insertions(+), 1768 deletions(-) create mode 100644 cpp/ggml-opt.cpp create mode 100644 cpp/ggml-opt.h diff --git a/android/src/main/CMakeLists.txt b/android/src/main/CMakeLists.txt index 92b5189..e9c3bb4 100644 --- a/android/src/main/CMakeLists.txt +++ b/android/src/main/CMakeLists.txt @@ -18,6 +18,7 @@ set( ${RNLLAMA_LIB_DIR}/ggml-cpu.cpp ${RNLLAMA_LIB_DIR}/ggml-cpu-aarch64.c ${RNLLAMA_LIB_DIR}/ggml-cpu-quants.c + ${RNLLAMA_LIB_DIR}/ggml-opt.cpp ${RNLLAMA_LIB_DIR}/ggml-threading.cpp ${RNLLAMA_LIB_DIR}/ggml-quants.c ${RNLLAMA_LIB_DIR}/log.cpp diff --git a/cpp/common.cpp b/cpp/common.cpp index 3b527a1..1b4f9d4 100644 --- a/cpp/common.cpp +++ b/cpp/common.cpp @@ -875,6 +875,12 @@ struct common_init_result common_init_from_params(common_params & params) { return iparams; } + if (params.ctx_shift && !llama_kv_cache_can_shift(lctx)) { + LOG_ERR("%s: KV cache shifting is not supported for this model (--no-context-shift to disable)'\n", __func__); + llama_free_model(model); + return iparams; + } + if (!params.control_vectors.empty()) { if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1; if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_n_layer(model); diff --git a/cpp/ggml-alloc.c b/cpp/ggml-alloc.c index 6843f5c..e4c00b1 100644 --- a/cpp/ggml-alloc.c +++ b/cpp/ggml-alloc.c @@ -466,18 +466,12 @@ static bool lm_ggml_gallocr_is_own(lm_ggml_gallocr_t galloc, struct lm_ggml_tens return lm_ggml_gallocr_hash_get(galloc, t)->allocated; } -static void lm_ggml_gallocr_set_node_offset(lm_ggml_gallocr_t galloc, struct lm_ggml_tensor * node, int buffer_id, size_t offset) { - struct hash_node * hn = lm_ggml_gallocr_hash_get(galloc, node); - hn->buffer_id = buffer_id; - hn->offset = offset; - hn->allocated = true; -} - static bool lm_ggml_gallocr_is_allocated(lm_ggml_gallocr_t galloc, struct lm_ggml_tensor * t) { return t->data != NULL || lm_ggml_gallocr_hash_get(galloc, t)->allocated; } static void lm_ggml_gallocr_allocate_node(lm_ggml_gallocr_t galloc, struct lm_ggml_tensor * node, int buffer_id) { + LM_GGML_ASSERT(buffer_id >= 0); struct hash_node * hn = lm_ggml_gallocr_hash_get(galloc, node); if (!lm_ggml_gallocr_is_allocated(galloc, node) && !lm_ggml_is_view(node)) { @@ -816,7 +810,11 @@ static void lm_ggml_gallocr_init_tensor(lm_ggml_gallocr_t galloc, struct lm_ggml } static bool lm_ggml_gallocr_node_needs_realloc(lm_ggml_gallocr_t galloc, struct lm_ggml_tensor * node, struct tensor_alloc * talloc) { - size_t node_size = (node->data || node->view_src) ? 
0 : lm_ggml_backend_buft_get_alloc_size(galloc->bufts[talloc->buffer_id], node); + size_t node_size = 0; + if (!node->data && !node->view_src) { + LM_GGML_ASSERT(talloc->buffer_id >= 0); // prevent segfault when misusing the API + node_size = lm_ggml_backend_buft_get_alloc_size(galloc->bufts[talloc->buffer_id], node); + } return talloc->size_max >= node_size; } diff --git a/cpp/ggml-backend.cpp b/cpp/ggml-backend.cpp index 3f45df7..27f92ad 100644 --- a/cpp/ggml-backend.cpp +++ b/cpp/ggml-backend.cpp @@ -279,7 +279,7 @@ void lm_ggml_backend_tensor_get(const struct lm_ggml_tensor * tensor, void * dat buf->iface.get_tensor(buf, tensor, data, offset, size); } -LM_GGML_API void lm_ggml_backend_tensor_memset(struct lm_ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { +void lm_ggml_backend_tensor_memset(struct lm_ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { lm_ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; if (size == 0) { @@ -689,7 +689,7 @@ static int lm_ggml_backend_sched_backend_id(lm_ggml_backend_sched_t sched, lm_gg } static int lm_ggml_backend_sched_backend_from_buffer(lm_ggml_backend_sched_t sched, const struct lm_ggml_tensor * tensor, const struct lm_ggml_tensor * op) { - lm_ggml_backend_buffer_t buffer = tensor->buffer; + lm_ggml_backend_buffer_t buffer = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; if (buffer == NULL) { return -1; } @@ -722,8 +722,6 @@ static char causes[LM_GGML_DEFAULT_GRAPH_SIZE*16 + LM_GGML_SCHED_MAX_SPLITS_DEBU // returns the backend that should be used for the node based on the current locations static int lm_ggml_backend_sched_backend_id_from_cur(lm_ggml_backend_sched_t sched, struct lm_ggml_tensor * tensor) { - // TODO: use supports_op to check if the backend supports the op - // assign pre-allocated nodes to their backend int cur_backend_id = lm_ggml_backend_sched_backend_from_buffer(sched, tensor, tensor); if (cur_backend_id != -1) { @@ -742,7 +740,7 @@ static int lm_ggml_backend_sched_backend_id_from_cur(lm_ggml_backend_sched_t sch if (tensor->buffer || (tensor->view_src && tensor->view_src->buffer)) { // since the tensor is pre-allocated, it cannot be moved to another backend - LM_GGML_ABORT("pre-allocated tensor in a backend that cannot run the operation"); + LM_GGML_ABORT("pre-allocated tensor (%s) in a backend that cannot run the operation", tensor->name); } // graph input @@ -886,6 +884,9 @@ static void lm_ggml_backend_sched_split_graph(lm_ggml_backend_sched_t sched, str for (int i = 0; i < graph->n_nodes; i++) { struct lm_ggml_tensor * node = graph->nodes[i]; int * node_backend_id = &tensor_backend_id(node); + if (lm_ggml_is_view_op(node->op)) { + continue; + } // do not overwrite user assignments if (*node_backend_id == -1) { *node_backend_id = lm_ggml_backend_sched_backend_id_from_cur(sched, node); @@ -1538,12 +1539,13 @@ bool lm_ggml_backend_sched_reserve(lm_ggml_backend_sched_t sched, struct lm_ggml lm_ggml_backend_sched_split_graph(sched, measure_graph); + lm_ggml_backend_sched_synchronize(sched); + if (!lm_ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids)) { return false; } lm_ggml_backend_sched_reset(sched); - lm_ggml_backend_sched_synchronize(sched); return true; } diff --git a/cpp/ggml-backend.h b/cpp/ggml-backend.h index 6a5851c..30f0a49 100644 --- a/cpp/ggml-backend.h +++ b/cpp/ggml-backend.h @@ -86,7 +86,7 @@ extern "C" { LM_GGML_API void lm_ggml_backend_tensor_set_async(lm_ggml_backend_t 
backend, struct lm_ggml_tensor * tensor, const void * data, size_t offset, size_t size); LM_GGML_API void lm_ggml_backend_tensor_get_async(lm_ggml_backend_t backend, const struct lm_ggml_tensor * tensor, void * data, size_t offset, size_t size); - // "offset" refers to the offset of the tensor data for setting/getting data + // "offset" refers to the offset in tensor->data for setting/getting data LM_GGML_API void lm_ggml_backend_tensor_set( struct lm_ggml_tensor * tensor, const void * data, size_t offset, size_t size); LM_GGML_API void lm_ggml_backend_tensor_get(const struct lm_ggml_tensor * tensor, void * data, size_t offset, size_t size); LM_GGML_API void lm_ggml_backend_tensor_memset( struct lm_ggml_tensor * tensor, uint8_t value, size_t offset, size_t size); @@ -242,14 +242,20 @@ extern "C" { lm_ggml_backend_sched_reserve(sched, reserve_graph); // compute - graph = build_graph(sched); - lm_ggml_backend_sched_graph_compute(sched, graph); + graph = build_graph(sched); // the graph and its tensors are single-use in terms of allocation, multi-use in terms of computation + for (int i = 0; i < 10; ++i) { + lm_ggml_backend_sched_graph_compute(sched, graph); // on the first iteration the graph is allocated automatically + } // if there are graph inputs: - lm_ggml_backend_sched_reset(sched); - lm_ggml_backend_sched_alloc_graph(sched, graph); - lm_ggml_backend_tensor_set(input_tensor, ...); - lm_ggml_backend_sched_graph_compute(sched, graph); + graph = build_graph(sched); // get a new graph that is not allocated (the metadata for the old graph is freed once lm_ggml_free is called) + lm_ggml_backend_sched_reset(sched); // clear the allocation of the previous graph + lm_ggml_backend_sched_alloc_graph(sched, graph); // explicitly allocate the new graph but do not execute it + lm_ggml_backend_tensor_set(input_tensor, ...); // copy data to the newly allocated graph tensors + lm_ggml_backend_sched_graph_compute(sched, graph); // execute the graph + + // as an alternative to the above it is also possible to assign the inputs to a dedicated context and + // allocate them statically via lm_ggml_backend_alloc_ctx_tensors } */ @@ -264,7 +270,7 @@ extern "C" { // typedef bool (*lm_ggml_backend_sched_eval_callback)(struct lm_ggml_tensor * t, bool ask, void * user_data); - // Initialize a backend scheduler + // Initialize a backend scheduler, backends with low index are given priority over backends with high index LM_GGML_API lm_ggml_backend_sched_t lm_ggml_backend_sched_new(lm_ggml_backend_t * backends, lm_ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel); LM_GGML_API void lm_ggml_backend_sched_free(lm_ggml_backend_sched_t sched); @@ -289,7 +295,9 @@ extern "C" { LM_GGML_API enum lm_ggml_status lm_ggml_backend_sched_graph_compute_async(lm_ggml_backend_sched_t sched, struct lm_ggml_cgraph * graph); LM_GGML_API void lm_ggml_backend_sched_synchronize(lm_ggml_backend_sched_t sched); - // Reset all assignments and allocators - must be called before changing the node backends + // Reset all assignments and allocators - must be called before changing the node backends or allocating a new graph. + // This in effect deallocates all tensors that were previously allocated and leaves them with dangling pointers. + // The correct way to use this API is to discard the deallocated tensors and create new ones. 
LM_GGML_API void lm_ggml_backend_sched_reset(lm_ggml_backend_sched_t sched); // Set a callback to be called for each resulting node during graph compute diff --git a/cpp/ggml-cpu.c b/cpp/ggml-cpu.c index 406fb85..51fe133 100644 --- a/cpp/ggml-cpu.c +++ b/cpp/ggml-cpu.c @@ -2369,7 +2369,7 @@ void lm_ggml_numa_init(enum lm_ggml_numa_strategy numa_flag) { // figure out which node we're on uint current_cpu; int getcpu_ret = 0; -#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 28) || defined(__COSMOPOLITAN__) +#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 33) || defined(__COSMOPOLITAN__) getcpu_ret = getcpu(¤t_cpu, &g_state.numa.current_node); #else // old glibc doesn't have a wrapper for this call. Fall back on direct syscall @@ -12216,11 +12216,16 @@ static void lm_ggml_compute_forward_opt_step_adamw_f32( const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst) { - const struct lm_ggml_tensor * src0 = dst->src[0]; - const struct lm_ggml_tensor * src0_grad = dst->src[1]; - const struct lm_ggml_tensor * src0_grad_m = dst->src[2]; - const struct lm_ggml_tensor * src0_grad_v = dst->src[3]; + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src0_grad = dst->src[1]; + const struct lm_ggml_tensor * src0_grad_m = dst->src[2]; + const struct lm_ggml_tensor * src0_grad_v = dst->src[3]; + const struct lm_ggml_tensor * adamw_params = dst->src[4]; + LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, src0_grad)); + LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, src0_grad_m)); + LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, src0_grad_v)); + LM_GGML_ASSERT(lm_ggml_nelements(adamw_params) == 7); const int ith = params->ith; const int nth = params->nth; @@ -12237,16 +12242,14 @@ static void lm_ggml_compute_forward_opt_step_adamw_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - /* const float gnorm = 1.0f; */ - int64_t iter; memcpy(&iter, &dst->op_params[0], sizeof(int64_t)); - const float alpha = lm_ggml_get_op_params_f32(dst, 2); - const float beta1 = lm_ggml_get_op_params_f32(dst, 3); - const float beta2 = lm_ggml_get_op_params_f32(dst, 4); - const float eps = lm_ggml_get_op_params_f32(dst, 5); - const float wd = lm_ggml_get_op_params_f32(dst, 6); - - const float beta1h = alpha/(1.0f - powf(beta1, iter)); - const float beta2h = 1.0f/(1.0f - powf(beta2, iter)); + const float * adamw_params_ptr = lm_ggml_get_data_f32(adamw_params); + const float alpha = adamw_params_ptr[0]; + const float beta1 = adamw_params_ptr[1]; + const float beta2 = adamw_params_ptr[2]; + const float eps = adamw_params_ptr[3]; + const float wd = adamw_params_ptr[4]; + const float beta1h = adamw_params_ptr[5]; + const float beta2h = adamw_params_ptr[6]; for (int ir = ir0; ir < ir1; ++ir) { const int64_t i03 = ir/(ne02*ne01); @@ -12270,17 +12273,9 @@ static void lm_ggml_compute_forward_opt_step_adamw_f32( // The weight decay is applied independently of the Adam momenta m and v. // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss. 
// See: https://arxiv.org/pdf/1711.05101v3.pdf - w[i00] = w[i00]*(1.0f - alpha*wd) - mh/vh; + w[i00] = w[i00]*(1.0f - alpha*wd) - alpha*mh/vh; } } - - lm_ggml_barrier(params->threadpool); - if (ith != 0) { - return; - } - - iter++; - memcpy(&dst->op_params[0], &iter, sizeof(int64_t)); } static void lm_ggml_compute_forward_opt_step_adamw( diff --git a/cpp/ggml-impl.h b/cpp/ggml-impl.h index e01aedf..17cfe65 100644 --- a/cpp/ggml-impl.h +++ b/cpp/ggml-impl.h @@ -196,7 +196,7 @@ void lm_ggml_hash_set_reset(struct lm_ggml_hash_set * hash_set); static bool lm_ggml_hash_contains(const struct lm_ggml_hash_set * hash_set, struct lm_ggml_tensor * key); // returns LM_GGML_HASHSET_FULL if table is full, otherwise the current index of the key or where it should be inserted -static size_t lm_ggml_hash_find(const struct lm_ggml_hash_set * hash_set, struct lm_ggml_tensor * key); +static size_t lm_ggml_hash_find(const struct lm_ggml_hash_set * hash_set, const struct lm_ggml_tensor * key); // returns LM_GGML_HASHSET_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full static size_t lm_ggml_hash_insert(struct lm_ggml_hash_set * hash_set, struct lm_ggml_tensor * key); @@ -210,7 +210,7 @@ static inline size_t lm_ggml_hash(const struct lm_ggml_tensor * p) { return (size_t)(uintptr_t)p >> 4; } -static size_t lm_ggml_hash_find(const struct lm_ggml_hash_set * hash_set, struct lm_ggml_tensor * key) { +static size_t lm_ggml_hash_find(const struct lm_ggml_hash_set * hash_set, const struct lm_ggml_tensor * key) { size_t h = lm_ggml_hash(key) % hash_set->size; // linear probing @@ -281,13 +281,14 @@ enum lm_ggml_cgraph_eval_order { }; struct lm_ggml_cgraph { - int size; - int n_nodes; - int n_leafs; - - struct lm_ggml_tensor ** nodes; - struct lm_ggml_tensor ** grads; - struct lm_ggml_tensor ** leafs; + int size; // maximum number of nodes/leafs/grads/grad_accs + int n_nodes; // number of nodes currently in use + int n_leafs; // number of leafs currently in use + + struct lm_ggml_tensor ** nodes; // tensors with data that can change if the graph is evaluated + struct lm_ggml_tensor ** grads; // the outputs of these tensors are the gradients of the nodes + struct lm_ggml_tensor ** grad_accs; // accumulators for node gradients + struct lm_ggml_tensor ** leafs; // tensors with constant data struct lm_ggml_hash_set visited_hash_set; diff --git a/cpp/ggml-metal.m b/cpp/ggml-metal.m index 8385ad9..468c783 100644 --- a/cpp/ggml-metal.m +++ b/cpp/ggml-metal.m @@ -2,6 +2,7 @@ #import "ggml-impl.h" #import "ggml-backend-impl.h" +#import "ggml-metal-impl.h" #import @@ -125,6 +126,7 @@ static void lm_ggml_backend_metal_device_rel(struct lm_ggml_backend_metal_device LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, LM_GGML_METAL_KERNEL_TYPE_SILU, LM_GGML_METAL_KERNEL_TYPE_SILU_4, + LM_GGML_METAL_KERNEL_TYPE_ELU, LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, @@ -648,6 +650,7 @@ @implementation LMGGMLMetalClass LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SILU, silu, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ELU, elu, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction); 
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction); @@ -967,6 +970,7 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_ case LM_GGML_UNARY_OP_GELU: case LM_GGML_UNARY_OP_GELU_QUICK: case LM_GGML_UNARY_OP_SILU: + case LM_GGML_UNARY_OP_ELU: return lm_ggml_is_contiguous(op->src[0]); default: return false; @@ -1193,35 +1197,39 @@ static void lm_ggml_metal_encode_node( const int32_t dim = ((const int32_t *) dst->op_params)[0]; + lm_ggml_metal_kargs_concat args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.dim =*/ dim, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; - [encoder setBytes:&dim length:sizeof(dim) atIndex:27]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; const int nth = MIN(1024, ne0); @@ -1239,8 +1247,6 @@ static void lm_ggml_metal_encode_node( bool bcast_row = false; - int64_t nb = ne00; // used by the "row" kernels - id pipeline = nil; if (lm_ggml_nelements(src1) == ne10 && lm_ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { @@ -1249,7 +1255,6 @@ static void lm_ggml_metal_encode_node( // src1 is a row LM_GGML_ASSERT(ne11 == 1); - nb = ne00 / 4; switch (dst->op) { case LM_GGML_OP_ADD: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break; case LM_GGML_OP_SUB: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; 
break; @@ -1269,36 +1274,39 @@ static void lm_ggml_metal_encode_node( } } + lm_ggml_metal_kargs_bin args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.offs =*/ offs, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; - [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; - [encoder setBytes:&nb length:sizeof(nb) atIndex:28]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; if (bcast_row) { const int64_t n = lm_ggml_nelements(dst)/4; @@ -1322,25 +1330,29 @@ static void lm_ggml_metal_encode_node( default: LM_GGML_ABORT("fatal error"); } + lm_ggml_metal_kargs_repeat args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 
length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); @@ -1369,25 +1381,29 @@ static void lm_ggml_metal_encode_node( const id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; + lm_ggml_metal_kargs_cpy args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); @@ -1396,35 +1412,39 @@ static void lm_ggml_metal_encode_node( const id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ADD].pipeline; + lm_ggml_metal_kargs_bin args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ pnb1, + /*.nb02 =*/ pnb2, + /*.nb03 =*/ pnb3, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ pnb1, + /*.nb2 =*/ pnb2, + /*.nb3 =*/ pnb3, + /*.offs =*/ offs, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 
atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8]; - [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9]; - [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24]; - [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25]; - [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26]; - [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); @@ -1465,10 +1485,10 @@ static void lm_ggml_metal_encode_node( memcpy(&max, ((const int32_t *) dst->op_params) + 1, sizeof(float)); [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&min length:sizeof(min) atIndex:2]; - [encoder setBytes:&max length:sizeof(max) atIndex:3]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&min length:sizeof(min) atIndex:2]; + [encoder setBytes:&max length:sizeof(max) atIndex:3]; const int64_t n = lm_ggml_nelements(dst); @@ -1572,6 +1592,18 @@ static void lm_ggml_metal_encode_node( [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; + case LM_GGML_UNARY_OP_ELU: + { + id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ELU].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = lm_ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; default: { LM_GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, lm_ggml_op_name(dst->op)); @@ -1640,6 +1672,7 @@ static void lm_ggml_metal_encode_node( id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline; + // TODO: add lm_ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -1715,6 +1748,8 @@ 
static void lm_ggml_metal_encode_node( const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + // TODO: add lm_ggml_metal_kargs struct + // TODO: optimize (see https://github.com/ggerganov/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6) [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; if (id_src1) { @@ -1731,6 +1766,7 @@ static void lm_ggml_metal_encode_node( [encoder setBytes:&m0 length:sizeof(m0) atIndex:8]; [encoder setBytes:&m1 length:sizeof(m1) atIndex:9]; [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; @@ -1747,6 +1783,7 @@ static void lm_ggml_metal_encode_node( pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline; } + // TODO: add lm_ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -1771,6 +1808,7 @@ static void lm_ggml_metal_encode_node( id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline; + // TODO: add lm_ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; @@ -1841,6 +1879,7 @@ static void lm_ggml_metal_encode_node( id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; + // TODO: add lm_ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; @@ -1959,24 +1998,29 @@ static void lm_ggml_metal_encode_node( default: LM_GGML_ABORT("MUL MAT-MAT not implemented"); } + lm_ggml_metal_kargs_mul_mm args = { + /*.ne00 =*/ ne00, + /*.ne02 =*/ ne02, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne12 =*/ ne12, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.r2 =*/ r2, + /*.r3 =*/ r3, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:7]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:9]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:10]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:11]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:12]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; - [encoder setBytes:&r2 length:sizeof(r2) atIndex:15]; - [encoder setBytes:&r3 length:sizeof(r3) atIndex:16]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + 
[encoder setThreadgroupMemoryLength:8192 atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } else { @@ -2154,28 +2198,32 @@ static void lm_ggml_metal_encode_node( } }; + lm_ggml_metal_kargs_mul_mv args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.r2 =*/ r2, + /*.r3 =*/ r3, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:13]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:14]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:15]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:16]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18]; - [encoder setBytes:&r2 length:sizeof(r2) atIndex:19]; - [encoder setBytes:&r3 length:sizeof(r3) atIndex:20]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; if (src0t == LM_GGML_TYPE_Q4_0 || src0t == LM_GGML_TYPE_Q4_1 || src0t == LM_GGML_TYPE_Q5_0 || src0t == LM_GGML_TYPE_Q5_1 || src0t == LM_GGML_TYPE_Q8_0 || src0t == LM_GGML_TYPE_Q2_K || @@ -2288,27 +2336,30 @@ static void lm_ggml_metal_encode_node( default: LM_GGML_ABORT("MUL_MAT_ID not implemented"); } + lm_ggml_metal_kargs_mul_mm_id args = { + /*.nei0 =*/ ne20, + /*.nei1 =*/ ne21, + /*.nbi1 =*/ nb21, + /*.ne00 =*/ ne00, + /*.ne02 =*/ ne02, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; - [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; - [encoder 
setBytes:&ne12 length:sizeof(ne12) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4]; [encoder setThreadgroupMemoryLength:LM_GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0]; @@ -2467,30 +2518,34 @@ static void lm_ggml_metal_encode_node( LM_GGML_ASSERT(ne00 >= nth0*nth1); } + lm_ggml_metal_kargs_mul_mv_id args = { + /*.nei0 =*/ ne20, + /*.nei1 =*/ ne21, + /*.nbi1 =*/ nb21, + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.nb1 =*/ nb1, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; - [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4]; const int64_t _ne1 = 1; const int tgz = dst_rows; @@ -2563,6 +2618,7 @@ static void lm_ggml_metal_encode_node( default: LM_GGML_ABORT("not implemented"); } + // TODO: add lm_ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; @@ -2586,20 +2642,28 @@ static void lm_ggml_metal_encode_node( float eps; memcpy(&eps, dst->op_params, sizeof(float)); + id pipeline = 
ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline; + int nth = 32; // SIMD width - while (nth < ne00/4 && nth < 1024) { + while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) { nth *= 2; } - id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline; + nth = MIN(nth, ne00/4); + + lm_ggml_metal_kargs_rms_norm args = { + /*.ne00 =*/ ne00, + /*.ne00_4 =*/ ne00/4, + /*.nb01 =*/ nb01, + /*.eps =*/ eps, + }; [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; - [encoder setBytes:&eps length:sizeof( float) atIndex:4]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; const int64_t nrows = lm_ggml_nrows(src0); @@ -2624,6 +2688,7 @@ static void lm_ggml_metal_encode_node( id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline; + // TODO: add lm_ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -2641,22 +2706,35 @@ static void lm_ggml_metal_encode_node( } break; case LM_GGML_OP_NORM: { + LM_GGML_ASSERT(ne00 % 4 == 0); LM_GGML_ASSERT(lm_ggml_is_contiguous_1(src0)); float eps; memcpy(&eps, dst->op_params, sizeof(float)); - const int nth = MIN(256, ne00); - id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_NORM].pipeline; + int nth = 32; // SIMD width + + while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) { + nth *= 2; + } + + nth = MIN(nth, ne00/4); + + lm_ggml_metal_kargs_norm args = { + /*.ne00 =*/ ne00, + /*.ne00_4 =*/ ne00/4, + /*.nb01 =*/ nb01, + /*.eps =*/ eps, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; - [encoder setBytes:&eps length:sizeof( float) atIndex:4]; - [encoder setThreadgroupMemoryLength:LM_GGML_PAD(nth*sizeof(float), 16) atIndex:0]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; const int64_t nrows = lm_ggml_nrows(src0); @@ -2706,40 +2784,44 @@ static void lm_ggml_metal_encode_node( }; } + lm_ggml_metal_kargs_rope args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.n_past =*/ n_past, + /*.n_dims =*/ n_dims, + /*.n_ctx_orig =*/ n_ctx_orig, + /*.freq_base =*/ freq_base, + /*.freq_scale =*/ freq_scale, + /*.ext_factor =*/ ext_factor, + /*.attn_factor =*/ attn_factor, + /*.beta_fast =*/ beta_fast, + /*.beta_slow =*/ beta_slow, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 
offset:offs_src1 atIndex:1]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; if (id_src2 != nil) { - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3]; } - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19]; - [encoder setBytes:&n_past length:sizeof( int) atIndex:20]; - [encoder setBytes:&n_dims length:sizeof( int) atIndex:21]; - [encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22]; - [encoder setBytes:&freq_base length:sizeof( float) atIndex:23]; - [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24]; - [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25]; - [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26]; - [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27]; - [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; @@ -2796,6 +2878,7 @@ static void lm_ggml_metal_encode_node( default: LM_GGML_ABORT("fatal error"); }; + // TODO: add lm_ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -2836,6 +2919,7 @@ static void lm_ggml_metal_encode_node( const id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline; + // TODO: add lm_ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -2870,6 +2954,7 @@ static void lm_ggml_metal_encode_node( id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline; + // TODO: add lm_ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -2906,6 +2991,7 @@ static void lm_ggml_metal_encode_node( id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline; + // TODO: add lm_ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_dst offset:offs_dst atIndex:0]; [encoder setBytes:&ne0 
length:sizeof(ne0) atIndex:1]; @@ -2927,6 +3013,7 @@ static void lm_ggml_metal_encode_node( id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline; + // TODO: add lm_ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -2965,6 +3052,7 @@ static void lm_ggml_metal_encode_node( default: LM_GGML_ABORT("fatal error"); }; + // TODO: add lm_ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -2983,6 +3071,7 @@ static void lm_ggml_metal_encode_node( id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline; + // TODO: add lm_ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -3224,37 +3313,41 @@ static void lm_ggml_metal_encode_node( } } + lm_ggml_metal_kargs_flash_attn_ext args = { + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne11 =*/ ne11, + /*.ne_12_2 =*/ ne12, + /*.ne_12_3 =*/ ne13, + /*.nb_12_1 =*/ nb11, + /*.nb_12_2 =*/ nb12, + /*.nb_12_3 =*/ nb13, + /*.nb31 =*/ nb31, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.scale =*/ scale, + /*.max_bias =*/ max_bias, + /*.m0 =*/ m0, + /*.m1 =*/ m1, + /*.n_head_log2 =*/ n_head_log2, + /*.logit_softcap =*/ logit_softcap, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; if (id_src3) { - [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:4]; } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4]; } - [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:18]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:19]; - [encoder setBytes:&scale length:sizeof( float) atIndex:20]; - [encoder setBytes:&max_bias length:sizeof( float) atIndex:21]; - [encoder setBytes:&m0 length:sizeof(m0) atIndex:22]; - [encoder setBytes:&m1 length:sizeof(m1) atIndex:23]; - [encoder 
setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:24]; - [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:25]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:5]; if (!use_vec_kernel) { // half8x8 kernel @@ -3389,25 +3482,29 @@ static void lm_ggml_metal_encode_node( default: LM_GGML_ABORT("not implemented"); } + lm_ggml_metal_kargs_cpy args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; @@ -3452,6 +3549,7 @@ static void lm_ggml_metal_encode_node( const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements); const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads; + // TODO: add lm_ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -3639,6 +3737,12 @@ static void lm_ggml_backend_metal_buffer_free_buffer(lm_ggml_backend_buffer_t bu return ctx->all_data; } +static void lm_ggml_backend_metal_buffer_memset_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + memset((char *)tensor->data + offset, value, size); + + UNUSED(buffer); +} + static void lm_ggml_backend_metal_buffer_set_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor, const void * data, size_t offset, size_t size) { memcpy((char *)tensor->data + offset, data, size); @@ -3671,7 +3775,7 @@ static void lm_ggml_backend_metal_buffer_clear(lm_ggml_backend_buffer_t buffer, /* .free_buffer = */ lm_ggml_backend_metal_buffer_free_buffer, /* .get_base = */ lm_ggml_backend_metal_buffer_get_base, /* .init_tensor = */ NULL, - /* .memset_tensor = */ NULL, + /* .memset_tensor = */ lm_ggml_backend_metal_buffer_memset_tensor, /* .set_tensor = */ 
lm_ggml_backend_metal_buffer_set_tensor, /* .get_tensor = */ lm_ggml_backend_metal_buffer_get_tensor, /* .cpy_tensor = */ lm_ggml_backend_metal_buffer_cpy_tensor, diff --git a/cpp/ggml-opt.cpp b/cpp/ggml-opt.cpp new file mode 100644 index 0000000..5204eb3 --- /dev/null +++ b/cpp/ggml-opt.cpp @@ -0,0 +1,867 @@ +#include "ggml-opt.h" + +#include "ggml.h" +#include "ggml-alloc.h" +#include "ggml-backend.h" +#include "ggml-impl.h" + +#include +#include +#include +#include +#include +#include +#include + +struct lm_ggml_opt_dataset { + struct lm_ggml_context * ctx; + lm_ggml_backend_buffer_t buf; + struct lm_ggml_tensor * data; + struct lm_ggml_tensor * labels; + + int64_t ndata; + int64_t ndata_shard; + size_t nbs_data; + size_t nbs_labels; + + std::vector permutation; +}; + +struct lm_ggml_opt_context { + lm_ggml_backend_sched_t backend_sched; + lm_ggml_cgraph * allocated_graph; + lm_ggml_cgraph * allocated_graph_copy; + struct lm_ggml_context * ctx_static; + struct lm_ggml_context * ctx_static_cpu; + struct lm_ggml_context * ctx_compute; + struct lm_ggml_context * ctx_copy; + lm_ggml_backend_buffer_t buf_static; + lm_ggml_backend_buffer_t buf_static_cpu; + std::mt19937 rng; + + struct lm_ggml_tensor * inputs; + struct lm_ggml_tensor * outputs; + struct lm_ggml_tensor * labels; + + struct lm_ggml_tensor * loss; + struct lm_ggml_tensor * pred; + struct lm_ggml_tensor * ncorrect; + + struct lm_ggml_cgraph * gf; + struct lm_ggml_cgraph * gb_grad; + struct lm_ggml_cgraph * gb_opt; + + int64_t iter; + int32_t opt_period; + int32_t opt_i; + bool loss_per_datapoint; + + lm_ggml_opt_get_optimizer_params get_opt_pars; + void * get_opt_pars_ud; + struct lm_ggml_tensor * adamw_params; +}; + +struct lm_ggml_opt_result { + int64_t ndata = 0; + std::vector loss; + std::vector pred; + int64_t ncorrect = 0; + + bool loss_per_datapoint = false; + int64_t opt_period = -1; +}; + +// ====== Dataset ====== + +lm_ggml_opt_dataset_t lm_ggml_opt_dataset_init(int64_t ne_datapoint, int64_t ne_label, int64_t ndata, int64_t ndata_shard) { + LM_GGML_ASSERT(ne_datapoint > 0); + LM_GGML_ASSERT(ne_label >= 0); + LM_GGML_ASSERT(ndata > 0); + LM_GGML_ASSERT(ndata_shard > 0); + + lm_ggml_opt_dataset_t result = new lm_ggml_opt_dataset; + result->ndata = ndata; + result->ndata_shard = ndata_shard; + + { + struct lm_ggml_init_params params = { + /*.mem_size =*/ 2*lm_ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + result->ctx = lm_ggml_init(params); + } + + result->data = lm_ggml_new_tensor_2d(result->ctx, LM_GGML_TYPE_F32, ne_datapoint, ndata); + result->nbs_data = lm_ggml_nbytes(result->data) * ndata_shard/ndata; + + if (ne_label > 0) { + result->labels = lm_ggml_new_tensor_2d(result->ctx, LM_GGML_TYPE_F32, ne_label, ndata); + result->nbs_labels = lm_ggml_nbytes(result->labels) * ndata_shard/ndata; + } else { + result->labels = nullptr; + result->nbs_labels = 0; + } + + result->buf = lm_ggml_backend_alloc_ctx_tensors_from_buft(result->ctx, lm_ggml_backend_cpu_buffer_type()); + + const int64_t nshards = ndata/ndata_shard; + result->permutation.resize(nshards); + for (int64_t i = 0; i < nshards; ++i) { + result->permutation[i] = i; + } + return result; +} + +void lm_ggml_opt_dataset_free(lm_ggml_opt_dataset_t dataset) { + lm_ggml_backend_buffer_free(dataset->buf); + lm_ggml_free(dataset->ctx); + delete dataset; +} + +struct lm_ggml_tensor * lm_ggml_opt_dataset_data(lm_ggml_opt_dataset_t dataset) { + return dataset->data; +} + +struct lm_ggml_tensor * 
lm_ggml_opt_dataset_labels(lm_ggml_opt_dataset_t dataset) { + return dataset->labels; +} + +void lm_ggml_opt_dataset_shuffle(lm_ggml_opt_context_t opt_ctx, lm_ggml_opt_dataset_t dataset, int64_t idata) { + LM_GGML_ASSERT(idata <= dataset->ndata); + + if (idata < 0) { + std::shuffle(dataset->permutation.begin(), dataset->permutation.end(), opt_ctx->rng); + return; + } + + LM_GGML_ASSERT(idata % dataset->ndata_shard == 0); + const int64_t ishard_max = idata / dataset->ndata_shard; + std::shuffle(dataset->permutation.begin(), dataset->permutation.begin() + ishard_max, opt_ctx->rng); +} + +void lm_ggml_opt_dataset_get_batch(lm_ggml_opt_dataset_t dataset, struct lm_ggml_tensor * data_batch, struct lm_ggml_tensor * labels_batch, int64_t ibatch) { + LM_GGML_ASSERT( data_batch && lm_ggml_is_contiguous(data_batch)); + LM_GGML_ASSERT(!labels_batch || lm_ggml_is_contiguous(labels_batch)); + LM_GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr)); + + const size_t nb_data_batch = lm_ggml_nbytes(data_batch); + LM_GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0); + const int64_t shards_per_batch = nb_data_batch / dataset->nbs_data; + + if (labels_batch) { + const size_t nb_labels_batch = lm_ggml_nbytes(labels_batch); + LM_GGML_ASSERT(nb_labels_batch == shards_per_batch*dataset->nbs_labels); + } + + LM_GGML_ASSERT((ibatch + 1)*shards_per_batch <= int64_t(dataset->permutation.size())); + + for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) { + const int64_t ishard = dataset->permutation[ibatch*shards_per_batch + ishard_batch]; + + const char * ptr_data = (const char *) dataset->data->data + ishard*dataset->nbs_data; + lm_ggml_backend_tensor_set(data_batch, ptr_data, ishard_batch*dataset->nbs_data, dataset->nbs_data); + + if (!labels_batch) { + continue; + } + + const char * ptr_labels = (const char *) dataset->labels->data + ishard*dataset->nbs_labels; + lm_ggml_backend_tensor_set(labels_batch, ptr_labels, ishard_batch*dataset->nbs_labels, dataset->nbs_labels); + } +} + +// ====== Model / Context ====== + +struct lm_ggml_opt_optimizer_params lm_ggml_opt_get_default_optimizer_params(void * userdata) { + LM_GGML_UNUSED(userdata); + + lm_ggml_opt_optimizer_params result; + + result.adamw.alpha = 0.001f; + result.adamw.beta1 = 0.9f; + result.adamw.beta2 = 0.999f; + result.adamw.eps = 1e-8f; + result.adamw.wd = 0.0f; + + return result; +} + +struct lm_ggml_opt_params lm_ggml_opt_default_params( + lm_ggml_backend_sched_t backend_sched, + struct lm_ggml_context * ctx_compute, + struct lm_ggml_tensor * inputs, + struct lm_ggml_tensor * outputs, + enum lm_ggml_opt_loss_type loss_type) { + return { + /*backend_sched =*/ backend_sched, + /*ctx_compute =*/ ctx_compute, + /*inputs =*/ inputs, + /*logits =*/ outputs, + /*loss_type =*/ loss_type, + /*build_type =*/ LM_GGML_OPT_BUILD_TYPE_OPT, + /*opt_period =*/ 1, + /*get_opt_pars =*/ lm_ggml_opt_get_default_optimizer_params, + /*get_opt_pars_ud =*/ nullptr, + }; +} + +static lm_ggml_tensor * map_tensor(std::map & tensor_map, lm_ggml_context * ctx, lm_ggml_tensor * tensor) { + if (!tensor) { + return nullptr; + } + + if (tensor_map.find(tensor) != tensor_map.end()) { + return tensor_map[tensor]; + } + + lm_ggml_tensor * new_tensor = lm_ggml_dup_tensor(ctx, tensor); + tensor_map[tensor] = new_tensor; + + new_tensor->op = tensor->op; + for (int i = 0; i < LM_GGML_MAX_DIMS; i++) { + new_tensor->nb[i] = tensor->nb[i]; + } + new_tensor->flags = tensor->flags; + memcpy(new_tensor->op_params, tensor->op_params, 
sizeof(tensor->op_params)); + strcpy(new_tensor->name, tensor->name); + new_tensor->data = tensor->data; + new_tensor->buffer = tensor->buffer; + new_tensor->extra = tensor->extra; + new_tensor->view_offs = tensor->view_offs; + new_tensor->view_src = map_tensor(tensor_map, ctx, tensor->view_src); + for (int i = 0; i < LM_GGML_MAX_SRC; i++) { + new_tensor->src[i] = map_tensor(tensor_map, ctx, tensor->src[i]); + } + + return new_tensor; +} + +static lm_ggml_cgraph * dup_graph(lm_ggml_context * ctx, lm_ggml_cgraph * graph) { + std::map tensor_map; + + lm_ggml_cgraph * new_graph = lm_ggml_new_graph_custom(ctx, LM_GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); + + for (int i = 0; i < graph->n_leafs; i++) { + lm_ggml_build_forward_expand(new_graph, map_tensor(tensor_map, ctx, graph->leafs[i])); + } + for (int i = 0; i < graph->n_nodes; i++) { + lm_ggml_build_forward_expand(new_graph, map_tensor(tensor_map, ctx, graph->nodes[i])); + } + for (int i = 0; i < graph->n_nodes; ++i) { + const size_t igrad_src = lm_ggml_hash_find(&graph->visited_hash_set, graph->nodes[i]); + const size_t igrad_dst = lm_ggml_hash_find(&new_graph->visited_hash_set, new_graph->nodes[i]); + graph->grads[igrad_dst] = new_graph->grads[igrad_src]; + graph->grad_accs[igrad_dst] = new_graph->grad_accs[igrad_src]; + } + + return new_graph; +} + +static void lm_ggml_opt_alloc_graph(lm_ggml_opt_context_t opt_ctx, lm_ggml_cgraph * graph) { + LM_GGML_ASSERT(graph); + if (opt_ctx->allocated_graph == graph) { + return; + } + + lm_ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph + + { + lm_ggml_init_params params = { + /*.mem_size =*/ lm_ggml_tensor_overhead() * LM_GGML_DEFAULT_GRAPH_SIZE, + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + lm_ggml_free(opt_ctx->ctx_copy); + opt_ctx->ctx_copy = lm_ggml_init(params); + } + + opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph); + + lm_ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy); + opt_ctx->allocated_graph = graph; +} + +lm_ggml_opt_context_t lm_ggml_opt_init(struct lm_ggml_opt_params params) { + lm_ggml_opt_context_t result = new struct lm_ggml_opt_context; + result->backend_sched = params.backend_sched; + result->allocated_graph = nullptr; + result->allocated_graph_copy = nullptr; + result->ctx_compute = params.ctx_compute; + result->ctx_copy = nullptr; + result->inputs = params.inputs; + result->outputs = params.outputs; + result->iter = 1; + result->opt_period = params.opt_period; + result->opt_i = 0; + result->get_opt_pars = params.get_opt_pars; + result->get_opt_pars_ud = params.get_opt_pars_ud; + + LM_GGML_ASSERT(result->inputs->data && "the inputs must be allocated statically"); + LM_GGML_ASSERT(result->opt_period >= 1); + + const bool accumulate = params.build_type == LM_GGML_OPT_BUILD_TYPE_GRAD || + (params.build_type == LM_GGML_OPT_BUILD_TYPE_OPT && result->opt_period > 1); + + lm_ggml_set_input(result->inputs); + lm_ggml_set_output(result->outputs); + + result->gf = lm_ggml_new_graph_custom(result->ctx_compute, LM_GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass. 
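    // Illustrative sketch (an assumption added for clarity, not taken from this patch): the caller-side setup
    // that the asserts above rely on. The names ctx_params, ctx_model, backend, sched, ne_in, ne_out and nbatch
    // are hypothetical; only the lm_ggml_*/LM_GGML_* identifiers refer to the actual API:
    //
    //   // statically allocated context: parameters + inputs, backed by a real backend buffer
    //   struct lm_ggml_tensor * inputs  = lm_ggml_new_tensor_2d(ctx_params, LM_GGML_TYPE_F32, ne_in, nbatch);
    //   struct lm_ggml_tensor * weights = lm_ggml_new_tensor_2d(ctx_params, LM_GGML_TYPE_F32, ne_in, ne_out);
    //   lm_ggml_set_param(ctx_params, weights);                  // sets LM_GGML_TENSOR_FLAG_PARAM, counted as n_param below
    //   lm_ggml_backend_alloc_ctx_tensors(ctx_params, backend);  // satisfies "inputs must be allocated statically"
    //
    //   // compute context created with no_alloc == true: everything between inputs and outputs
    //   struct lm_ggml_tensor * outputs = lm_ggml_mul_mat(ctx_model, weights, inputs); // shape [ne_out, nbatch]
    //
    //   struct lm_ggml_opt_params p = lm_ggml_opt_default_params(sched, ctx_model, inputs, outputs,
    //                                                            LM_GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR);
    //   lm_ggml_opt_context_t opt_ctx = lm_ggml_opt_init(p);
    //
    // With opt_period > 1, gb_grad only accumulates gradients and the AdamW step in gb_opt runs on every
    // opt_period-th call (see lm_ggml_opt_forward_backward).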
+ lm_ggml_build_forward_expand(result->gf, result->outputs); + + int n_param = 0; + for (int i = 0; i < result->gf->n_nodes; ++i) { + if (result->gf->nodes[i]->flags & LM_GGML_TENSOR_FLAG_PARAM) { + n_param++; + } + } + + { + // The static context is used for: + // - gradients (1 tensor per param if using gradient accumulation) + // - optimizer momenta (2 tensors per param) + // - labels + // - loss + its gradient (up to 5 tensors) + // - pred + // - ncorrect (2 tensors). + const size_t tensors_per_param = (accumulate ? 1 : 0) + (params.build_type == LM_GGML_OPT_BUILD_TYPE_OPT ? 2 : 0); + const size_t size_meta = (tensors_per_param*n_param + 9) * lm_ggml_tensor_overhead(); + struct lm_ggml_init_params params = { + /*.mem_size =*/ size_meta, + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + result->ctx_static = lm_ggml_init(params); + } + { + // The static cpu context is used for: + // - optimizer parameters (1 for the entire context) + const size_t size_meta = 1 * lm_ggml_tensor_overhead(); + struct lm_ggml_init_params params = { + /*.mem_size =*/ size_meta, + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + result->ctx_static_cpu = lm_ggml_init(params); + } + + + switch (params.loss_type) { + case LM_GGML_OPT_LOSS_TYPE_MEAN: { + result->labels = nullptr; + result->loss = lm_ggml_sum(result->ctx_static, result->outputs); + lm_ggml_set_name(result->loss, "loss_sum"); + const float scale = 1.0f / (result->opt_period * lm_ggml_nelements(result->outputs)); + result->loss = lm_ggml_scale(result->ctx_static, result->loss, scale); + lm_ggml_set_name(result->loss, "loss_mean"); + result->loss_per_datapoint = true; + break; + } + case LM_GGML_OPT_LOSS_TYPE_SUM: { + result->labels = nullptr; + result->loss = lm_ggml_sum(result->ctx_static, result->outputs); + lm_ggml_set_name(result->loss, "loss_sum"); + result->loss_per_datapoint = false; + break; + } + case LM_GGML_OPT_LOSS_TYPE_CROSS_ENTROPY: { + result->labels = lm_ggml_dup_tensor(result->ctx_static, result->outputs); + lm_ggml_set_input(result->labels); + lm_ggml_set_name(result->labels, "labels"); + result->loss = lm_ggml_cross_entropy_loss(result->ctx_static, result->outputs, result->labels); + lm_ggml_set_name(result->loss, "loss_cross_entropy"); + if (result->opt_period > 1) { + result->loss = lm_ggml_scale(result->ctx_static, result->loss, 1.0f / result->opt_period); + lm_ggml_set_name(result->loss, "loss_cross_entropy_scaled"); + } + result->loss_per_datapoint = true; + break; + } + case LM_GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR: { + result->labels = lm_ggml_dup_tensor(result->ctx_static, result->outputs); + lm_ggml_set_input(result->labels); + lm_ggml_set_name(result->labels, "labels"); + result->loss = lm_ggml_sub(result->ctx_static, result->outputs, result->labels); + lm_ggml_set_name(result->loss, "loss_error"); + result->loss = lm_ggml_sqr(result->ctx_static, result->loss); + lm_ggml_set_name(result->loss, "loss_squared_error"); + result->loss = lm_ggml_sum(result->ctx_static, result->loss); + lm_ggml_set_name(result->loss, "loss_sum_squared_error"); + const float scale = 1.0f / (result->opt_period * lm_ggml_nelements(result->outputs)); + result->loss = lm_ggml_scale(result->ctx_static, result->loss, scale); + lm_ggml_set_name(result->loss, "loss_mean_squared_error"); + result->loss_per_datapoint = true; + break; + } + } + lm_ggml_set_output(result->loss); + lm_ggml_set_loss(result->loss); + lm_ggml_build_forward_expand(result->gf, result->loss); + + result->pred = lm_ggml_argmax(result->ctx_static, 
result->outputs); + lm_ggml_set_name(result->pred, "pred"); + lm_ggml_set_output(result->pred); + lm_ggml_build_forward_expand(result->gf, result->pred); + + if (result->labels) { + result->ncorrect = lm_ggml_count_equal(result->ctx_static, result->pred, lm_ggml_argmax(result->ctx_static, result->labels)); + lm_ggml_set_name(result->ncorrect, "ncorrect"); + lm_ggml_set_output(result->ncorrect); + lm_ggml_build_forward_expand(result->gf, result->ncorrect); + } else { + result->ncorrect = nullptr; + } + + if (params.build_type == LM_GGML_OPT_BUILD_TYPE_FORWARD) { + result->gb_grad = nullptr; + result->gb_opt = nullptr; + + result->buf_static = lm_ggml_backend_alloc_ctx_tensors(result->ctx_static, lm_ggml_backend_sched_get_backend(result->backend_sched, 0)); + result->buf_static_cpu = nullptr; + + lm_ggml_opt_alloc_graph(result, result->gf); + + return result; + } + + // gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients. + result->gb_grad = lm_ggml_graph_dup(result->ctx_compute, result->gf); + lm_ggml_build_backward_expand(result->ctx_static, result->ctx_compute, result->gb_grad, accumulate); + + if (params.build_type == LM_GGML_OPT_BUILD_TYPE_GRAD) { + result->gb_opt = nullptr; + + result->buf_static = lm_ggml_backend_alloc_ctx_tensors(result->ctx_static, lm_ggml_backend_sched_get_backend(result->backend_sched, 0)); + result->buf_static_cpu = nullptr; + + lm_ggml_opt_alloc_graph(result, result->gb_grad); + lm_ggml_graph_reset(result->gb_grad); + + return result; + } + + LM_GGML_ASSERT(params.build_type == LM_GGML_OPT_BUILD_TYPE_OPT); + + // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step. + result->gb_opt = lm_ggml_graph_dup(result->ctx_compute, result->gb_grad); + + result->adamw_params = lm_ggml_new_tensor_1d(result->ctx_static_cpu, LM_GGML_TYPE_F32, 7); + lm_ggml_set_input(result->adamw_params); + lm_ggml_set_name(result->adamw_params, "adamw_params"); + + for (int i = result->gf->n_nodes-1; i >= 0; --i) { + struct lm_ggml_tensor * node = result->gb_opt->nodes[i]; + struct lm_ggml_tensor * grad = lm_ggml_graph_get_grad(result->gb_opt, node); + + if (node->flags & LM_GGML_TENSOR_FLAG_PARAM) { + struct lm_ggml_tensor * m = lm_ggml_dup_tensor(result->ctx_static, node); + struct lm_ggml_tensor * v = lm_ggml_dup_tensor(result->ctx_static, node); + struct lm_ggml_tensor * opt_step = lm_ggml_opt_step_adamw(result->ctx_compute, node, grad, m, v, result->adamw_params); + lm_ggml_build_forward_expand(result->gb_opt, opt_step); + } + } + + result->buf_static = lm_ggml_backend_alloc_ctx_tensors( + result->ctx_static, lm_ggml_backend_sched_get_backend(result->backend_sched, 0)); + + result->buf_static_cpu = lm_ggml_backend_alloc_ctx_tensors_from_buft(result->ctx_static_cpu, lm_ggml_backend_cpu_buffer_type()); + + lm_ggml_opt_alloc_graph(result, result->gb_opt); + lm_ggml_graph_reset(result->gb_opt); + + return result; +} + +void lm_ggml_opt_free(lm_ggml_opt_context_t opt_ctx) { + if (opt_ctx == nullptr) { + return; + } + lm_ggml_backend_buffer_free(opt_ctx->buf_static); + lm_ggml_backend_buffer_free(opt_ctx->buf_static_cpu); + lm_ggml_free(opt_ctx->ctx_static); + lm_ggml_free(opt_ctx->ctx_static_cpu); + delete opt_ctx; +} + +void lm_ggml_opt_reset(lm_ggml_opt_context_t opt_ctx, bool optimizer) { + if (optimizer) { + lm_ggml_graph_reset(opt_ctx->gb_opt); + opt_ctx->iter = 1; + } else { + lm_ggml_graph_reset(opt_ctx->gb_grad); + } +} + +struct lm_ggml_tensor * 
lm_ggml_opt_inputs(lm_ggml_opt_context_t opt_ctx) { + return opt_ctx->inputs; +} + +struct lm_ggml_tensor * lm_ggml_opt_outputs(lm_ggml_opt_context_t opt_ctx) { + return opt_ctx->outputs; +} + +struct lm_ggml_tensor * lm_ggml_opt_labels(lm_ggml_opt_context_t opt_ctx) { + return opt_ctx->labels; +} + +struct lm_ggml_tensor * lm_ggml_opt_loss(lm_ggml_opt_context_t opt_ctx) { + return opt_ctx->loss; +} + +struct lm_ggml_tensor * lm_ggml_opt_pred(lm_ggml_opt_context_t opt_ctx) { + return opt_ctx->pred; +} + +struct lm_ggml_tensor * lm_ggml_opt_ncorrect(lm_ggml_opt_context_t opt_ctx) { + return opt_ctx->ncorrect; +} + +struct lm_ggml_tensor * lm_ggml_opt_grad_acc(lm_ggml_opt_context_t opt_ctx, struct lm_ggml_tensor * node) { + return lm_ggml_graph_get_grad_acc(opt_ctx->gb_opt, node); +} + +// ====== Optimization Result ====== + +lm_ggml_opt_result_t lm_ggml_opt_result_init() { + return new lm_ggml_opt_result; +} + +void lm_ggml_opt_result_free(lm_ggml_opt_result_t result) { + delete result; +} + +void lm_ggml_opt_result_reset(lm_ggml_opt_result_t result) { + result->ndata = 0; + result->loss.clear(); + result->pred.clear(); + result->ncorrect = 0; +} + +void lm_ggml_opt_result_ndata(lm_ggml_opt_result_t result, int64_t * ndata) { + *ndata = result->ndata; +} + +void lm_ggml_opt_result_loss(lm_ggml_opt_result_t result, double * loss, double * unc) { + const int64_t nbatches = result->loss.size(); // Number of physical batches. + + if (nbatches == 0) { + *loss = 0.0; + *unc = NAN; + return; + } + + double sum = 0.0; + double sum_squared = 0.0; + + for (const float & loss : result->loss) { + // If the loss is per datapoint it was scaled by 1.0f/opt_period for each physical batch. + const float loss_scaled = result->loss_per_datapoint ? loss*result->opt_period : loss; + sum += loss_scaled; + sum_squared += loss_scaled*loss_scaled; + } + + const double mean = sum/nbatches; + *loss = result->loss_per_datapoint ? mean : sum; + + if (!unc) { + return; + } + + if (nbatches < 2) { + *unc = NAN; + return; + } + + const double var_sum = sum_squared/nbatches - mean*mean; // variance without Bessel's correction, i.e. nbatches/(nbatches-1) + *unc = result->loss_per_datapoint ? sqrt(var_sum / (nbatches - 1)) : sqrt(var_sum * nbatches/(nbatches - 1)); +} + +void lm_ggml_opt_result_pred(lm_ggml_opt_result_t result, int32_t * pred) { + for (size_t i = 0; i < result->pred.size(); ++i) { + pred[i] = result->pred[i]; + } +} + +void lm_ggml_opt_result_accuracy(lm_ggml_opt_result_t result, double * accuracy, double * unc) { + *accuracy = result->ncorrect >= 0 ? double(result->ncorrect) / double(result->ndata) : NAN; + + if (!unc) { + return; + } + + *unc = result->ncorrect >= 0 && result->ndata >= 2 ? 
+ sqrt((*accuracy) * (1.0 - (*accuracy)) / double(result->ndata - 1)) : NAN; +} + +// ====== Computation ====== + +static void lm_ggml_opt_eval_graph(lm_ggml_opt_context_t opt_ctx, lm_ggml_cgraph * graph, lm_ggml_opt_result * result) { + if (graph != opt_ctx->gf) { + struct lm_ggml_opt_optimizer_params opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud); + + LM_GGML_ASSERT(opt_pars.adamw.alpha > 0.0f); + LM_GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f); + LM_GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f); + LM_GGML_ASSERT(opt_pars.adamw.beta2 >= 0.0f); + LM_GGML_ASSERT(opt_pars.adamw.beta2 <= 1.0f); + LM_GGML_ASSERT(opt_pars.adamw.eps >= 0.0f); + LM_GGML_ASSERT(opt_pars.adamw.wd >= 0.0f); + LM_GGML_ASSERT(opt_pars.adamw.wd <= 1.0f); + + // beta1, beta2 after applying warmup + const float beta1h = 1.0f/(1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter)); + const float beta2h = 1.0f/(1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter)); + + float * adamw_par_data = lm_ggml_get_data_f32(opt_ctx->adamw_params); + adamw_par_data[0] = opt_pars.adamw.alpha; + adamw_par_data[1] = opt_pars.adamw.beta1; + adamw_par_data[2] = opt_pars.adamw.beta2; + adamw_par_data[3] = opt_pars.adamw.eps; + adamw_par_data[4] = opt_pars.adamw.wd; + adamw_par_data[5] = beta1h; + adamw_par_data[6] = beta2h; + } + + lm_ggml_opt_alloc_graph(opt_ctx, graph); + lm_ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy); + opt_ctx->iter += opt_ctx->allocated_graph == opt_ctx->gb_opt; + + if (!result) { + return; + } + + if (result->ndata == 0) { + result->loss_per_datapoint = opt_ctx->loss_per_datapoint; + result->opt_period = opt_ctx->opt_period; + } else { + LM_GGML_ASSERT(result->loss_per_datapoint == opt_ctx->loss_per_datapoint); + LM_GGML_ASSERT(result->opt_period == opt_ctx->opt_period); + } + + const int64_t ndata = opt_ctx->outputs->ne[1]; + LM_GGML_ASSERT(result->ndata == ndata*int64_t(result->loss.size()) && "varying batch size not supported"); + result->ndata += ndata; + + LM_GGML_ASSERT(lm_ggml_is_scalar(opt_ctx->loss)); + LM_GGML_ASSERT(opt_ctx->loss->type == LM_GGML_TYPE_F32); + float loss; + lm_ggml_backend_tensor_get(opt_ctx->loss, &loss, 0, lm_ggml_nbytes(opt_ctx->loss)); + result->loss.push_back(loss); + + LM_GGML_ASSERT(opt_ctx->pred->type == LM_GGML_TYPE_I32); + std::vector pred(ndata); + lm_ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, lm_ggml_nbytes(opt_ctx->pred)); + result->pred.insert(result->pred.end(), pred.begin(), pred.end()); + + if (!opt_ctx->labels || result->ncorrect < 0) { + result->ncorrect = -1; + return; + } + + LM_GGML_ASSERT(lm_ggml_is_scalar(opt_ctx->ncorrect)); + LM_GGML_ASSERT(opt_ctx->ncorrect->type == LM_GGML_TYPE_I64); + int64_t ncorrect; + lm_ggml_backend_tensor_get(opt_ctx->ncorrect, &ncorrect, 0, lm_ggml_nbytes(opt_ctx->ncorrect)); + result->ncorrect += ncorrect; +} + +void lm_ggml_opt_forward(lm_ggml_opt_context_t opt_ctx, lm_ggml_opt_result * result) { + lm_ggml_opt_eval_graph(opt_ctx, opt_ctx->gf, result); +} + +void lm_ggml_opt_forward_backward(lm_ggml_opt_context_t opt_ctx, lm_ggml_opt_result * result) { + if (opt_ctx->opt_period == 1) { + lm_ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result); + return; + } + + const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period; + if (opt_i_next == 0) { + lm_ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result); + lm_ggml_opt_reset(opt_ctx, /*optimizer =*/ false); + } else { + lm_ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_grad, result); + } + opt_ctx->opt_i = opt_i_next; +} + +// 
====== High-Level Functions ====== + +void lm_ggml_opt_epoch( + lm_ggml_opt_context_t opt_ctx, + lm_ggml_opt_dataset_t dataset, + lm_ggml_opt_result_t result_train, + lm_ggml_opt_result_t result_eval, + int64_t idata_split, + lm_ggml_opt_epoch_callback callback_train, + lm_ggml_opt_epoch_callback callback_eval) { + struct lm_ggml_tensor * inputs = lm_ggml_opt_inputs(opt_ctx); + struct lm_ggml_tensor * labels = lm_ggml_opt_labels(opt_ctx); + struct lm_ggml_tensor * data = lm_ggml_opt_dataset_data(dataset); + LM_GGML_ASSERT(data->ne[0] == inputs->ne[0]); + + const int64_t ndata = data->ne[1]; + const int64_t ndata_batch = inputs->ne[1]; + + LM_GGML_ASSERT(data->ne[1] % inputs->ne[1] == 0); + const int64_t nbatches = ndata/ndata_batch; + + idata_split = idata_split < 0 ? ndata : idata_split; + LM_GGML_ASSERT(idata_split % ndata_batch == 0); + const int64_t ibatch_split = idata_split / ndata_batch; + + int64_t ibatch = 0; + int64_t t_loop_start = lm_ggml_time_us(); + for (; ibatch < ibatch_split; ++ibatch) { + lm_ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch); + lm_ggml_opt_forward_backward(opt_ctx, result_train); + if (callback_train) { + callback_train(true, opt_ctx, dataset, result_train, ibatch+1, ibatch_split, t_loop_start); + } + } + t_loop_start = lm_ggml_time_us(); + for (; ibatch < nbatches; ++ibatch) { + lm_ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch); + lm_ggml_opt_forward(opt_ctx, result_eval); + if (callback_eval) { + callback_eval(false, opt_ctx, dataset, result_eval, ibatch+1-ibatch_split, nbatches-ibatch_split, t_loop_start); + } + } +} + +void lm_ggml_opt_epoch_callback_progress_bar( + bool train, + lm_ggml_opt_context_t opt_ctx, + lm_ggml_opt_dataset_t dataset, + lm_ggml_opt_result_t result, + int64_t ibatch, + int64_t ibatch_max, + int64_t t_start_us) { + fprintf(stderr, "%s[", train ? 
"train: " : "val: "); + + constexpr int64_t bar_length = 25; + for (int64_t j = 0; j < bar_length; ++j) { + const int64_t ibatch_j = ibatch_max * j/bar_length; + if (ibatch_j < ibatch) { + fprintf(stderr, "="); + } else if (ibatch_max * (j - 1)/bar_length < ibatch) { + fprintf(stderr, ">"); + } else { + fprintf(stderr, " "); + } + } + + const int64_t batch_size = lm_ggml_opt_inputs(opt_ctx)->ne[1]; + const int64_t idata = ibatch*batch_size; + const int64_t idata_max = ibatch_max*batch_size; + + double loss; + double loss_unc; + lm_ggml_opt_result_loss(result, &loss, &loss_unc); + + double accuracy; + double accuracy_unc; + lm_ggml_opt_result_accuracy(result, &accuracy, &accuracy_unc); + + const int64_t t_ibatch_us = lm_ggml_time_us() - t_start_us; + int64_t t_ibatch_s = t_ibatch_us / 1000000; + const int64_t t_ibatch_h = t_ibatch_s / 3600; + t_ibatch_s -= t_ibatch_h * 3600; + const int64_t t_ibatch_m = t_ibatch_s / 60; + t_ibatch_s -= t_ibatch_m * 60; + + const int64_t t_eta_us = t_ibatch_us * (ibatch_max - ibatch)/ibatch; + int64_t t_eta_s = t_eta_us / 1000000; + const int64_t t_eta_h = t_eta_s / 3600; + t_eta_s -= t_eta_h * 3600; + const int64_t t_eta_m = t_eta_s / 60; + t_eta_s -= t_eta_m * 60; + + fprintf(stderr, "| data=%06" PRId64 "/%06" PRId64 ", loss=%.6lf+-%.6lf, accuracy=%.2lf+-%.2lf%%, " + "t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 ", ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 "]\r", + idata, idata_max, loss, loss_unc, 100.0*accuracy, 100.0*accuracy_unc, + t_ibatch_h, t_ibatch_m, t_ibatch_s, t_eta_h, t_eta_m, t_eta_s); + if (ibatch == ibatch_max) { + fprintf(stderr, "\n"); + } + fflush(stderr); + + LM_GGML_UNUSED(dataset); +} + +void lm_ggml_opt_fit( + lm_ggml_backend_sched_t backend_sched, + lm_ggml_context * ctx_compute, + lm_ggml_tensor * inputs, + lm_ggml_tensor * outputs, + lm_ggml_opt_dataset_t dataset, + enum lm_ggml_opt_loss_type loss_type, + lm_ggml_opt_get_optimizer_params get_opt_pars, + int64_t nepoch, + int64_t nbatch_logical, + float val_split, + bool silent) { + lm_ggml_time_init(); + const int64_t t_start_us = lm_ggml_time_us(); + + const int64_t ndata = lm_ggml_opt_dataset_data(dataset)->ne[1]; + const int64_t nbatch_physical = inputs->ne[1]; + LM_GGML_ASSERT(ndata % nbatch_logical == 0); + LM_GGML_ASSERT(nbatch_logical % nbatch_physical == 0); + + const int64_t opt_period = nbatch_logical / nbatch_physical; + const int64_t nbatches_logical = ndata / nbatch_logical; + + LM_GGML_ASSERT(val_split >= 0.0f); + LM_GGML_ASSERT(val_split < 1.0f); + const int64_t ibatch_split = int64_t(((1.0f - val_split) * nbatches_logical)) * opt_period; // train <-> val split index (physical) + const int64_t idata_split = ibatch_split * nbatch_physical; + + int64_t epoch = 1; + + lm_ggml_opt_params params = lm_ggml_opt_default_params(backend_sched, ctx_compute, inputs, outputs, loss_type); + params.opt_period = opt_period; + params.get_opt_pars = get_opt_pars; + params.get_opt_pars_ud = &epoch; + lm_ggml_opt_context_t opt_ctx = lm_ggml_opt_init(params); + + // Shuffling the data is generally useful but there is only a point if not all data is used in a single batch. + if (nbatch_logical < ndata) { + lm_ggml_opt_dataset_shuffle(opt_ctx, dataset, -1); // Shuffle all data (train + validation). + } + + lm_ggml_opt_result_t result_train = lm_ggml_opt_result_init(); + lm_ggml_opt_result_t result_val = lm_ggml_opt_result_init(); + + lm_ggml_opt_epoch_callback epoch_callback = silent ? 
nullptr : lm_ggml_opt_epoch_callback_progress_bar; + + for (; epoch <= nepoch; ++epoch) { + if (nbatch_logical < idata_split) { + lm_ggml_opt_dataset_shuffle(opt_ctx, dataset, idata_split); + } + + lm_ggml_opt_result_reset(result_train); + lm_ggml_opt_result_reset(result_val); + + if (!silent) { + fprintf(stderr, "%s: epoch %04" PRId64 "/%04" PRId64 ":\n", __func__, epoch, nepoch); + } + lm_ggml_opt_epoch(opt_ctx, dataset, result_train, result_val, idata_split, epoch_callback, epoch_callback); + if (!silent) { + fprintf(stderr, "\n"); + } + } + + if (!silent) { + int64_t t_total_s = (lm_ggml_time_us() - t_start_us) / 1000000; + const int64_t t_total_h = t_total_s / 3600; + t_total_s -= t_total_h * 3600; + const int64_t t_total_m = t_total_s / 60; + t_total_s -= t_total_m * 60; + fprintf(stderr, "%s: training took %02" PRId64 ":%02" PRId64 ":%02" PRId64 "\n", __func__, t_total_h, t_total_m, t_total_s); + } + + lm_ggml_opt_free(opt_ctx); + lm_ggml_opt_result_free(result_train); + lm_ggml_opt_result_free(result_val); +} diff --git a/cpp/ggml-opt.h b/cpp/ggml-opt.h new file mode 100644 index 0000000..fc982da --- /dev/null +++ b/cpp/ggml-opt.h @@ -0,0 +1,216 @@ +// This file contains functionality for training models using GGML. +// It is not strictly needed vs. just vanilla GGML but it provides a more high-level interface for common needs such as datasets. +// At the bottom of this file especially there are relatively high-level functions that are suitable for use or adaptation in user code. +// +// Module maintainer: Johannes Gäßler (@JohannesGaessler, johannesg@5d6.de) + +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + +#include <stdint.h> + +#ifdef __cplusplus +extern "C" { +#endif + + struct lm_ggml_opt_dataset; + struct lm_ggml_opt_context; + struct lm_ggml_opt_result; + + typedef struct lm_ggml_opt_dataset * lm_ggml_opt_dataset_t; + typedef struct lm_ggml_opt_context * lm_ggml_opt_context_t; + typedef struct lm_ggml_opt_result * lm_ggml_opt_result_t; + + // ====== Loss ====== + + // built-in loss types, i.e.
the built-in quantities minimized by the optimizer + // custom loss types can be defined via mean or sum which simply reduce the outputs for all datapoints to a single value + enum lm_ggml_opt_loss_type { + LM_GGML_OPT_LOSS_TYPE_MEAN, + LM_GGML_OPT_LOSS_TYPE_SUM, + LM_GGML_OPT_LOSS_TYPE_CROSS_ENTROPY, + LM_GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR, + }; + + // ====== Dataset ====== + + LM_GGML_API lm_ggml_opt_dataset_t lm_ggml_opt_dataset_init( + int64_t ne_datapoint, // number of elements per datapoint + int64_t ne_label, // number of elements per label + int64_t ndata, // total number of datapoints/labels + int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied) + LM_GGML_API void lm_ggml_opt_dataset_free(lm_ggml_opt_dataset_t dataset); + + // get underlying tensors that store the data + LM_GGML_API struct lm_ggml_tensor * lm_ggml_opt_dataset_data (lm_ggml_opt_dataset_t dataset); // shape = [ne_datapoint, ndata] + LM_GGML_API struct lm_ggml_tensor * lm_ggml_opt_dataset_labels(lm_ggml_opt_dataset_t dataset); // shape = [nd_label, ndata] + + // shuffle idata first datapoints from dataset with RNG from opt_ctx, shuffle all datapoints if idata is negative + LM_GGML_API void lm_ggml_opt_dataset_shuffle(lm_ggml_opt_context_t opt_ctx, lm_ggml_opt_dataset_t dataset, int64_t idata); + + // get batch at position ibatch from dataset and copy the data to data_batch and labels_batch + LM_GGML_API void lm_ggml_opt_dataset_get_batch( + lm_ggml_opt_dataset_t dataset, + struct lm_ggml_tensor * data_batch, // shape = [ne_datapoint, ndata_batch] + struct lm_ggml_tensor * labels_batch, // shape = [ne_label, ndata_batch] + int64_t ibatch); + + // ====== Model / Context ====== + + enum lm_ggml_opt_build_type { + LM_GGML_OPT_BUILD_TYPE_FORWARD, + LM_GGML_OPT_BUILD_TYPE_GRAD, + LM_GGML_OPT_BUILD_TYPE_OPT, + }; + + // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss + struct lm_ggml_opt_optimizer_params { + // AdamW optimizer parameters + struct { + float alpha; // learning rate + float beta1; + float beta2; + float eps; // epsilon for numerical stability + float wd; // weight decay for AdamW, use 0.0f to disable + } adamw; + }; + + // callback to calculate optimizer parameters prior to a backward pass + // userdata can be used to pass arbitrary data + typedef struct lm_ggml_opt_optimizer_params (*lm_ggml_opt_get_optimizer_params)(void * userdata); + + // returns the default optimizer params (constant) + // userdata is not used + LM_GGML_API struct lm_ggml_opt_optimizer_params lm_ggml_opt_get_default_optimizer_params(void * userdata); + + // parameters for initializing a new optimization context + struct lm_ggml_opt_params { + lm_ggml_backend_sched_t backend_sched; // defines which backends are used to construct the compute graphs + + struct lm_ggml_context * ctx_compute; // created in user code, holds non-static tensors + + // the forward graph is defined by inputs and outputs + // those tensors and all tensors inbetween are not intended to be reusable between multiple optimization contexts + struct lm_ggml_tensor * inputs; + struct lm_ggml_tensor * outputs; + + enum lm_ggml_opt_loss_type loss_type; + enum lm_ggml_opt_build_type build_type; + + int32_t opt_period; // after how many gradient accumulation steps an optimizer step should be done + + lm_ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters + void * get_opt_pars_ud; // userdata for calculating optimizer 
parameters + }; + + // get parameters for an optimization context with defaults set where possible + // parameters for which no sensible defaults exist are supplied as arguments to this function + LM_GGML_API lm_ggml_opt_params lm_ggml_opt_default_params( + lm_ggml_backend_sched_t backend_sched, + struct lm_ggml_context * ctx_compute, + struct lm_ggml_tensor * inputs, + struct lm_ggml_tensor * outputs, + enum lm_ggml_opt_loss_type loss_type); + + LM_GGML_API lm_ggml_opt_context_t lm_ggml_opt_init(struct lm_ggml_opt_params params); + LM_GGML_API void lm_ggml_opt_free(lm_ggml_opt_context_t opt_ctx); + + // set gradients to zero, initilize loss, and optionally reset the optimizer + LM_GGML_API void lm_ggml_opt_reset(lm_ggml_opt_context_t opt_ctx, bool optimizer); + + // get underlying tensors that store data + LM_GGML_API struct lm_ggml_tensor * lm_ggml_opt_inputs( lm_ggml_opt_context_t opt_ctx); // forward graph input tensor + LM_GGML_API struct lm_ggml_tensor * lm_ggml_opt_outputs( lm_ggml_opt_context_t opt_ctx); // forward graph output tensor + LM_GGML_API struct lm_ggml_tensor * lm_ggml_opt_labels( lm_ggml_opt_context_t opt_ctx); // labels to compare outputs against + LM_GGML_API struct lm_ggml_tensor * lm_ggml_opt_loss( lm_ggml_opt_context_t opt_ctx); // scalar tensor that contains the loss + LM_GGML_API struct lm_ggml_tensor * lm_ggml_opt_pred( lm_ggml_opt_context_t opt_ctx); // predictions made by outputs + LM_GGML_API struct lm_ggml_tensor * lm_ggml_opt_ncorrect(lm_ggml_opt_context_t opt_ctx); // number of matching predictions between outputs and labels + + LM_GGML_API struct lm_ggml_tensor * lm_ggml_opt_grad_acc(lm_ggml_opt_context_t opt_ctx, struct lm_ggml_tensor * node); + + // ====== Optimization Result ====== + + LM_GGML_API lm_ggml_opt_result_t lm_ggml_opt_result_init(); + LM_GGML_API void lm_ggml_opt_result_free(lm_ggml_opt_result_t result); + LM_GGML_API void lm_ggml_opt_result_reset(lm_ggml_opt_result_t result); + + // get data from result, uncertainties are optional and can be ignored by passing NULL + LM_GGML_API void lm_ggml_opt_result_ndata( lm_ggml_opt_result_t result, int64_t * ndata); // writes 1 value, number of datapoints + LM_GGML_API void lm_ggml_opt_result_loss( lm_ggml_opt_result_t result, double * loss, double * unc); // writes 1 value + LM_GGML_API void lm_ggml_opt_result_pred( lm_ggml_opt_result_t result, int32_t * pred); // writes ndata values + LM_GGML_API void lm_ggml_opt_result_accuracy(lm_ggml_opt_result_t result, double * accuracy, double * unc); // writes 1 value + + // ====== Computation ====== + + // do forward pass, increment result if not NULL + LM_GGML_API void lm_ggml_opt_forward(lm_ggml_opt_context_t opt_ctx, lm_ggml_opt_result_t result); + + // do forward pass, increment result if not NULL, do backward pass + LM_GGML_API void lm_ggml_opt_forward_backward(lm_ggml_opt_context_t opt_ctx, lm_ggml_opt_result_t result); + + // ############################################################################ + // ## The high-level functions start here. They do not depend on any private ## + // ## functions or structs and can be copied to and adapted for user code. ## + // ############################################################################ + + // ====== Intended Usage ====== + // + // 1. Select the appropriate loss for your problem. + // 2. Create a dataset and set the data for the "data" tensor. Also set the "labels" tensor if your loss needs them. 
+ // Setting the shard size to 1 will be fine, it's the granularity with which data is shuffled/loaded (bigger values are faster). + // 3. Create a GGML graph for your model with no_alloc == true. Use two separate contexts for the tensors. + // The first context should contain the model parameters and inputs and be allocated statically in user code. + // The second context should contain all other tensors and will be (re)allocated automatically. + // Due to this automated allocation the data of the second context is not defined when accessed in user code. + // Note that the second dimension of the inputs/outputs are interpreted as the number of datapoints in those tensors. + // 4. Call lm_ggml_opt_fit. If you need more control you can use lm_ggml_opt_epoch instead. + + // signature for a callback while evaluating opt_ctx on dataset, called after an evaluation + typedef void (*lm_ggml_opt_epoch_callback)( + bool train, // true after training evaluation, false after validation evaluation + lm_ggml_opt_context_t opt_ctx, + lm_ggml_opt_dataset_t dataset, + lm_ggml_opt_result_t result, // result associated with the dataset subsection + int64_t ibatch, // number of batches that have been evaluated so far + int64_t ibatch_max, // total number of batches in this dataset subsection + int64_t t_start_us); // time at which the evaluation on the dataset subsection was started + + // do training on front of dataset, do evaluation only on back of dataset + LM_GGML_API void lm_ggml_opt_epoch( + lm_ggml_opt_context_t opt_ctx, + lm_ggml_opt_dataset_t dataset, + lm_ggml_opt_result_t result_train, // result to increment during training, ignored if NULL + lm_ggml_opt_result_t result_eval, // result to increment during evaluation, ignored if NULL + int64_t idata_split, // data index at which to split training and evaluation + lm_ggml_opt_epoch_callback callback_train, + lm_ggml_opt_epoch_callback callback_eval); + + // callback that prints a progress bar on stderr + LM_GGML_API void lm_ggml_opt_epoch_callback_progress_bar( + bool train, + lm_ggml_opt_context_t opt_ctx, + lm_ggml_opt_dataset_t dataset, + lm_ggml_opt_result_t result, + int64_t ibatch, + int64_t ibatch_max, + int64_t t_start_us); + + // fit model defined by inputs and outputs to dataset + LM_GGML_API void lm_ggml_opt_fit( + lm_ggml_backend_sched_t backend_sched, // backend scheduler for constructing the compute graphs + lm_ggml_context * ctx_compute, // context with temporarily allocated tensors to calculate the outputs + lm_ggml_tensor * inputs, // input tensor with shape [ne_datapoint, ndata_batch] + lm_ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used + lm_ggml_opt_dataset_t dataset, // dataset with data and optionally also labels + enum lm_ggml_opt_loss_type loss_type, // loss to minimize + lm_ggml_opt_get_optimizer_params get_opt_pars, // callback to get optimizer params, userdata is pointer to epoch (of type int64_t) + int64_t nepoch, // how many times the dataset should be iterated over + int64_t nbatch_logical, // datapoints optimizer step, must be a multiple of ndata_batch in inputs/outputs + float val_split, // fraction of the dataset to use for validation, must be in [0.0f, 1.0f) + bool silent); // whether or not info prints to stderr should be suppressed + +#ifdef __cplusplus +} +#endif diff --git a/cpp/ggml.c b/cpp/ggml.c index 876dd87..4cbe385 100644 --- a/cpp/ggml.c +++ b/cpp/ggml.c @@ -1592,14 +1592,13 @@ static struct lm_ggml_tensor * lm_ggml_new_tensor_impl( /*.op =*/ 
LM_GGML_OP_NONE, /*.op_params =*/ { 0 }, /*.flags =*/ 0, - /*.grad =*/ NULL, /*.src =*/ { NULL }, /*.view_src =*/ view_src, /*.view_offs =*/ view_offs, /*.data =*/ obj_alloc_size > 0 ? (void *)(result + 1) : data, /*.name =*/ { 0 }, /*.extra =*/ NULL, - ///*.padding =*/ { 0 }, + /*.padding =*/ { 0 }, }; #ifdef __clang__ @@ -4194,8 +4193,6 @@ struct lm_ggml_tensor * lm_ggml_flash_attn_ext( LM_GGML_ASSERT(mask); } - bool is_node = false; - // permute(0, 2, 1, 3) int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); @@ -4203,8 +4200,7 @@ struct lm_ggml_tensor * lm_ggml_flash_attn_ext( float params[] = { scale, max_bias, logit_softcap }; lm_ggml_set_op_params(result, params, sizeof(params)); - result->op = LM_GGML_OP_FLASH_ATTN_EXT; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_FLASH_ATTN_EXT; result->src[0] = q; result->src[1] = k; result->src[2] = v; @@ -4272,14 +4268,6 @@ struct lm_ggml_tensor * lm_ggml_flash_attn_back( LM_GGML_ASSERT(ne2 % kvne2 == 0); - bool is_node = false; - - if (q->grad || k->grad || v->grad) { - // when using this operation (in backwards pass) these grads are set. - // we don't want to create (big) grad of our result, so is_node is false. - is_node = false; - } - // store gradients of q, k and v as continuous tensors concatenated in result. // note: v and gradv are actually transposed, i.e. v->ne[0] != D. const int64_t elem_q = lm_ggml_nelements(q); @@ -4302,8 +4290,7 @@ struct lm_ggml_tensor * lm_ggml_flash_attn_back( int32_t masked_i = masked ? 1 : 0; lm_ggml_set_op_params(result, &masked_i, sizeof(masked_i)); - result->op = LM_GGML_OP_FLASH_ATTN_BACK; - result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL; + result->op = LM_GGML_OP_FLASH_ATTN_BACK; result->src[0] = q; result->src[1] = k; result->src[2] = v; @@ -4945,34 +4932,24 @@ struct lm_ggml_tensor * lm_ggml_opt_step_adamw( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, struct lm_ggml_tensor * grad, - float alpha, - float beta1, - float beta2, - float eps, - float wd) { + struct lm_ggml_tensor * m, + struct lm_ggml_tensor * v, + struct lm_ggml_tensor * adamw_params) { LM_GGML_ASSERT(a->flags & LM_GGML_TENSOR_FLAG_PARAM); LM_GGML_ASSERT(lm_ggml_are_same_shape(a, grad)); - LM_GGML_ASSERT(alpha > 0.0f); - LM_GGML_ASSERT(beta1 >= 0.0f && beta1 <= 1.0f); - LM_GGML_ASSERT(beta2 >= 0.0f && beta2 <= 1.0f); - LM_GGML_ASSERT(eps >= 0.0f); - LM_GGML_ASSERT(wd >= 0.0f && wd <= 1.0f); + LM_GGML_ASSERT(lm_ggml_are_same_shape(a, m)); + LM_GGML_ASSERT(lm_ggml_are_same_shape(a, v)); + LM_GGML_ASSERT(adamw_params->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT(lm_ggml_nelements(adamw_params) == 7); struct lm_ggml_tensor * result = lm_ggml_view_tensor(ctx, a); - const int64_t iter = 1; - memcpy(&result->op_params[0], &iter, sizeof(int64_t)); - lm_ggml_set_op_params_f32(result, 2, alpha); - lm_ggml_set_op_params_f32(result, 3, beta1); - lm_ggml_set_op_params_f32(result, 4, beta2); - lm_ggml_set_op_params_f32(result, 5, eps); - lm_ggml_set_op_params_f32(result, 6, wd); - result->op = LM_GGML_OP_OPT_STEP_ADAMW; result->src[0] = a; result->src[1] = grad; - result->src[2] = lm_ggml_dup_tensor(ctx, grad); - result->src[3] = lm_ggml_dup_tensor(ctx, grad); + result->src[2] = m; + result->src[3] = v; + result->src[4] = adamw_params; return result; } @@ -5041,1112 +5018,514 @@ static void lm_ggml_hash_map_free(struct hash_map * map) { LM_GGML_FREE(map); } -// gradient checkpointing - 
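// Note on the lm_ggml_opt_step_adamw() change in the hunk above (an illustrative sketch, not patch content):
// the optimizer hyperparameters now come from a 7-element F32 tensor instead of per-call floats, and the
// momenta m/v are explicit inputs. This mirrors how cpp/ggml-opt.cpp in this patch drives the op; "ctx",
// "param", "grad" and "gb_opt" below are placeholder names:
//
//   struct lm_ggml_tensor * adamw_params = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_F32, 7);
//   // layout written before each step: [0]=alpha, [1]=beta1, [2]=beta2, [3]=eps, [4]=wd,
//   //                                  [5]=beta1h, [6]=beta2h (iteration-dependent corrections from ggml-opt.cpp)
//   struct lm_ggml_tensor * m = lm_ggml_dup_tensor(ctx, param);
//   struct lm_ggml_tensor * v = lm_ggml_dup_tensor(ctx, param);
//   struct lm_ggml_tensor * opt_step = lm_ggml_opt_step_adamw(ctx, param, grad, m, v, adamw_params);
//   lm_ggml_build_forward_expand(gb_opt, opt_step);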
-static struct lm_ggml_tensor * lm_ggml_recompute_graph_node( - struct lm_ggml_context * ctx, - struct lm_ggml_cgraph * graph, - struct hash_map * replacements, - struct lm_ggml_tensor * node) { - - if (node == NULL) { - return NULL; - } - - if (node->flags & LM_GGML_TENSOR_FLAG_PARAM) { - return node; - } - - if (!lm_ggml_hash_contains(&graph->visited_hash_set, node)) { - return node; - } - - int count_children = 0; - for (int k = 0; k < LM_GGML_MAX_SRC; ++k) { - if (node->src[k]) { - ++count_children; - } - } - - if (count_children == 0) { - return node; - } - - size_t i = lm_ggml_hash_find(&replacements->set, node); - LM_GGML_ASSERT(i != LM_GGML_HASHSET_FULL); // assert that not full - if (replacements->set.keys[i] == node) { - return replacements->vals[i]; - } - - struct lm_ggml_tensor * clone = lm_ggml_new_tensor(ctx, node->type, LM_GGML_MAX_DIMS, node->ne); - - // insert clone into replacements - LM_GGML_ASSERT(replacements->set.keys[i] == NULL); // assert that we don't overwrite - replacements->set.keys[i] = node; - replacements->vals[i] = clone; - - clone->op = node->op; - clone->grad = node->grad; - clone->flags = node->flags; - clone->extra = node->extra; - for (int k = 0; k < LM_GGML_MAX_DIMS; ++k) { - clone->nb[k] = node->nb[k]; - } - for (int k = 0; k < LM_GGML_MAX_SRC; ++k) { - clone->src[k] = lm_ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]); - } - if (node->view_src != NULL) { - clone->data = (node->view_src->data == NULL) - ? NULL // view_src not yet allocated - : (char *) node->view_src->data // view_src already allocated - + node->view_offs; - clone->view_src = node->view_src; - clone->view_offs = node->view_offs; - } - - LM_GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (LM_GGML_MAX_OP_PARAMS / sizeof(int32_t))); - LM_GGML_ASSERT(sizeof(node->name) == LM_GGML_MAX_NAME); - memcpy(clone->op_params, node->op_params, sizeof(node->op_params)); - lm_ggml_format_name(clone, "%s (clone)", lm_ggml_get_name(node)); - - return clone; -} - -void lm_ggml_build_backward_gradient_checkpointing( - struct lm_ggml_context * ctx, - struct lm_ggml_cgraph * gf, - struct lm_ggml_cgraph * gb, - struct lm_ggml_cgraph * gb_tmp, - struct lm_ggml_tensor * * checkpoints, - int n_checkpoints) { - lm_ggml_graph_cpy(gf, gb_tmp); - lm_ggml_build_backward_expand(ctx, gf, gb_tmp, false); - - if (n_checkpoints <= 0) { - lm_ggml_graph_cpy(gb_tmp, gb); - return; - } - - struct hash_map * replacements = lm_ggml_new_hash_map(gf->n_nodes + gf->n_leafs + n_checkpoints); - - // insert checkpoints in replacements - for (int i = 0; i < n_checkpoints; ++i) { - size_t k = lm_ggml_hash_find(&replacements->set, checkpoints[i]); - LM_GGML_ASSERT(k != LM_GGML_HASHSET_FULL); // assert that not full - LM_GGML_ASSERT(replacements->set.keys[k] == NULL); // assert that we don't overwrite - replacements->set.keys[k] = checkpoints[i]; - replacements->vals[k] = checkpoints[i]; - } - - lm_ggml_graph_cpy(gf, gb); - // rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes], - // replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]), - // by recomputing them from checkpoints - for (int i = gf->n_nodes; in_nodes; ++i) { - struct lm_ggml_tensor * node = gb_tmp->nodes[i]; - for (int k = 0; k < LM_GGML_MAX_SRC; ++k) { - // insert new tensors recomputing src, reusing already made replacements, - // remember replacements: remember new tensors with mapping from corresponding gf nodes - // recurse for input tensors, - // unless (i.e. 
terminating when) input tensors are replacements (like checkpoints) - node->src[k] = lm_ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]); - } - // insert rewritten backward node with replacements made into resulting backward graph gb - lm_ggml_build_forward_expand(gb, node); - } - - lm_ggml_hash_map_free(replacements); -} - // utility functions to change gradients // if a is in acc_table, modify gradients in-place and mark result as gradient accumulator // else if a is in zero_table, replace a // else, just add/subtract/etc. the gradients -static struct lm_ggml_tensor * lm_ggml_add_or_set( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - struct lm_ggml_hash_set * zero_table, - struct lm_ggml_hash_set * acc_table) { - if (lm_ggml_hash_contains(acc_table, a)) { - struct lm_ggml_tensor * ret = lm_ggml_add_impl(ctx, a, b, true); - const size_t insert_result = lm_ggml_hash_insert(acc_table, ret); - LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL); - LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS); - return ret; - } - if (lm_ggml_hash_contains(zero_table, a)) { - return b; +static void lm_ggml_add_or_set( + struct lm_ggml_context * ctx, + struct lm_ggml_cgraph * cgraph, + size_t isrc, + struct lm_ggml_tensor * tensor) { + if (cgraph->grads[isrc]) { + cgraph->grads[isrc] = lm_ggml_add_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]); + } else { + cgraph->grads[isrc] = tensor; } - return lm_ggml_add_impl(ctx, a, b, false); + lm_ggml_build_forward_expand(cgraph, cgraph->grads[isrc]); } -static struct lm_ggml_tensor * lm_ggml_acc_or_set( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - const size_t nb1, - const size_t nb2, - const size_t nb3, - const size_t offset, - struct lm_ggml_hash_set * zero_table, - struct lm_ggml_hash_set * acc_table) { - if (lm_ggml_hash_contains(acc_table, a)) { - struct lm_ggml_tensor * ret = lm_ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true); - const size_t insert_result = lm_ggml_hash_insert(acc_table, ret); - LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL); - LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS); - return ret; - } - if (lm_ggml_hash_contains(zero_table, a)) { - struct lm_ggml_tensor * a_zero = lm_ggml_scale(ctx, a, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN - return lm_ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false); +static void lm_ggml_acc_or_set( + struct lm_ggml_context * ctx, + struct lm_ggml_cgraph * cgraph, + size_t isrc, + struct lm_ggml_tensor * src, + struct lm_ggml_tensor * tensor, + const size_t nb1, + const size_t nb2, + const size_t nb3, + const size_t offset) { + if (cgraph->grads[isrc]) { + cgraph->grads[isrc] = lm_ggml_acc_impl(ctx, cgraph->grads[isrc], tensor, nb1, nb2, nb3, offset, cgraph->grad_accs[isrc]); + } else { + struct lm_ggml_tensor * a_zero = lm_ggml_scale(ctx, src, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN + cgraph->grads[isrc] = lm_ggml_acc_impl(ctx, a_zero, tensor, nb1, nb2, nb3, offset, false); } - return lm_ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false); + lm_ggml_build_forward_expand(cgraph, cgraph->grads[isrc]); } -static struct lm_ggml_tensor * lm_ggml_add1_or_set( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - struct lm_ggml_hash_set * zero_table, - struct lm_ggml_hash_set * acc_table) { - if (lm_ggml_hash_contains(acc_table, a)) { - struct 
lm_ggml_tensor * ret = lm_ggml_add1_impl(ctx, a, b, true); - const size_t insert_result = lm_ggml_hash_insert(acc_table, ret); - LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL); - LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS); - return ret; - } - if (lm_ggml_hash_contains(zero_table, a)) { - return lm_ggml_repeat(ctx, b, a); +static void lm_ggml_add1_or_set( + struct lm_ggml_context * ctx, + struct lm_ggml_cgraph * cgraph, + size_t isrc, + struct lm_ggml_tensor * src, + struct lm_ggml_tensor * tensor) { + if (cgraph->grads[isrc]) { + cgraph->grads[isrc] = lm_ggml_add1_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]); + } else { + cgraph->grads[isrc] = lm_ggml_repeat(ctx, tensor, src); } - return lm_ggml_add1_impl(ctx, a, b, false); + lm_ggml_build_forward_expand(cgraph, cgraph->grads[isrc]); } -static struct lm_ggml_tensor * lm_ggml_sub_or_set( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - struct lm_ggml_hash_set * zero_table, - struct lm_ggml_hash_set * acc_table) { - if (lm_ggml_hash_contains(acc_table, a)) { - struct lm_ggml_tensor * ret = lm_ggml_sub_impl(ctx, a, b, true); - const size_t insert_result = lm_ggml_hash_insert(acc_table, ret); - LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL); - LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS); - return ret; - } - if (lm_ggml_hash_contains(zero_table, a)) { - return lm_ggml_neg(ctx, b); +static void lm_ggml_sub_or_set( + struct lm_ggml_context * ctx, + struct lm_ggml_cgraph * cgraph, + size_t isrc, + struct lm_ggml_tensor * tensor) { + if (cgraph->grads[isrc]) { + cgraph->grads[isrc] = lm_ggml_sub_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]); + } else { + cgraph->grads[isrc] = lm_ggml_neg(ctx, tensor); } - return lm_ggml_sub_impl(ctx, a, b, false); + lm_ggml_build_forward_expand(cgraph, cgraph->grads[isrc]); } -static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggml_tensor * tensor, struct lm_ggml_hash_set * zero_table, struct lm_ggml_hash_set * acc_table) { +static void lm_ggml_compute_backward( + struct lm_ggml_context * ctx, struct lm_ggml_cgraph * cgraph, int i, bool * grads_needed) { + struct lm_ggml_tensor * tensor = cgraph->nodes[i]; + struct lm_ggml_tensor * grad = lm_ggml_graph_get_grad(cgraph, tensor); + + if (!grad) { + return; + } + struct lm_ggml_tensor * src0 = tensor->src[0]; struct lm_ggml_tensor * src1 = tensor->src[1]; struct lm_ggml_tensor * src2 = tensor->src[2]; + struct lm_ggml_hash_set * hash_set = &cgraph->visited_hash_set; + const size_t isrc0 = lm_ggml_hash_find(hash_set, src0); + const size_t isrc1 = lm_ggml_hash_find(hash_set, src1); + const size_t isrc2 = lm_ggml_hash_find(hash_set, src2); + const bool src0_needs_grads = isrc0 != LM_GGML_HASHSET_FULL && lm_ggml_bitset_get(hash_set->used, isrc0) && grads_needed[isrc0]; + const bool src1_needs_grads = isrc1 != LM_GGML_HASHSET_FULL && lm_ggml_bitset_get(hash_set->used, isrc1) && grads_needed[isrc1]; + const bool src2_needs_grads = isrc2 != LM_GGML_HASHSET_FULL && lm_ggml_bitset_get(hash_set->used, isrc2) && grads_needed[isrc2]; switch (tensor->op) { - case LM_GGML_OP_DUP: - { - if (src0->grad) { - src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); - } - } break; - case LM_GGML_OP_ADD: - { - if (src0->grad) { - src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); - } - if (src1->grad) { - if (lm_ggml_are_same_shape(src0, src1)) { - src1->grad = 
lm_ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table, acc_table); - } else { - src1->grad = lm_ggml_add_or_set(ctx, src1->grad, lm_ggml_repeat_back(ctx, tensor->grad, src1), zero_table, acc_table); - } - } - } break; - case LM_GGML_OP_ADD1: - { - if (src0->grad) { - src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); - } - if (src1->grad) { - src1->grad = lm_ggml_add_or_set(ctx, - src1->grad, - lm_ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean - zero_table, acc_table); - } - } break; - case LM_GGML_OP_ACC: - { - if (src0->grad) { - src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); - } - if (src1->grad) { - const size_t nb1 = ((int32_t *) tensor->op_params)[0]; - const size_t nb2 = ((int32_t *) tensor->op_params)[1]; - const size_t nb3 = ((int32_t *) tensor->op_params)[2]; - const size_t offset = ((int32_t *) tensor->op_params)[3]; - - struct lm_ggml_tensor * tensor_grad_view = lm_ggml_view_4d(ctx, - tensor->grad, - src1->grad->ne[0], - src1->grad->ne[1], - src1->grad->ne[2], - src1->grad->ne[3], - nb1, nb2, nb3, offset); - - src1->grad = - lm_ggml_add_or_set(ctx, - src1->grad, - lm_ggml_reshape(ctx, - lm_ggml_cont(ctx, tensor_grad_view), - src1->grad), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_SUB: - { - if (src0->grad) { - src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); - } - if (src1->grad) { - src1->grad = lm_ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table, acc_table); - } - } break; - case LM_GGML_OP_MUL: - { - if (src0->grad) { - src0->grad = - lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_mul(ctx, src1, tensor->grad), - zero_table, acc_table); - } - if (src1->grad) { - src1->grad = - lm_ggml_add_or_set(ctx, - src1->grad, - lm_ggml_mul(ctx, src0, tensor->grad), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_DIV: - { - if (src0->grad) { - src0->grad = - lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_div(ctx, tensor->grad, src1), - zero_table, acc_table); - } - if (src1->grad) { - src1->grad = - lm_ggml_sub_or_set(ctx, - src1->grad, - lm_ggml_mul(ctx, - tensor->grad, - lm_ggml_div(ctx, tensor, src1)), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_SQR: - { - if (src0->grad) { - src0->grad = - lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_scale(ctx, - lm_ggml_mul(ctx, src0, tensor->grad), - 2.0f), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_SQRT: - { - if (src0->grad) { - src0->grad = - lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_scale(ctx, - lm_ggml_div(ctx, - tensor->grad, - tensor), - 0.5f), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_LOG: - { - if (src0->grad) { - src0->grad = - lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_div(ctx, - tensor->grad, - src0), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_SIN: - { - if (src0->grad) { - src0->grad = - lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_mul(ctx, - tensor->grad, - lm_ggml_cos(ctx, src0)), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_COS: - { - if (src0->grad) { - src0->grad = - lm_ggml_sub_or_set(ctx, - src0->grad, - lm_ggml_mul(ctx, - tensor->grad, - lm_ggml_sin(ctx, src0)), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_SUM: - { - if (src0->grad) { - src0->grad = - lm_ggml_add1_or_set(ctx, - src0->grad, - tensor->grad, - zero_table, acc_table); - } - } break; - case LM_GGML_OP_SUM_ROWS: - { - if (src0->grad) { - src0->grad = - 
lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_repeat(ctx, - tensor->grad, - src0->grad), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_MEAN: - case LM_GGML_OP_ARGMAX: - case LM_GGML_OP_COUNT_EQUAL: - { - LM_GGML_ABORT("fatal error"); // TODO: implement - } - case LM_GGML_OP_REPEAT: - { - // necessary for llama - if (src0->grad) { - src0->grad = lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_repeat_back(ctx, tensor->grad, src0->grad), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_REPEAT_BACK: - { - if (src0->grad) { - // TODO: test this - src0->grad = lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_repeat(ctx, tensor->grad, src0->grad), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_CONCAT: - { - LM_GGML_ABORT("fatal error"); // TODO: implement - } - case LM_GGML_OP_SILU_BACK: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented + case LM_GGML_OP_DUP: { + if (src0_needs_grads) { + lm_ggml_add_or_set(ctx, cgraph, isrc0, grad); } - case LM_GGML_OP_NORM: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case LM_GGML_OP_ADD: { + if (src0_needs_grads) { + lm_ggml_add_or_set(ctx, cgraph, isrc0, grad); } - case LM_GGML_OP_RMS_NORM: - { - // necessary for llama - if (src0->grad) { - float eps; - memcpy(&eps, tensor->op_params, sizeof(float)); - - src0->grad = lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_rms_norm_back(ctx, src0, tensor->grad, eps), - zero_table, acc_table); + if (src1_needs_grads) { + struct lm_ggml_tensor * tmp = grad; + if (!lm_ggml_are_same_shape(src0, src1)) { + tmp = lm_ggml_repeat_back(ctx, tmp, src1); } - } break; - case LM_GGML_OP_RMS_NORM_BACK: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented + lm_ggml_add_or_set(ctx, cgraph, isrc1, tmp); } - case LM_GGML_OP_GROUP_NORM: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case LM_GGML_OP_ADD1: { + if (src0_needs_grads) { + lm_ggml_add_or_set(ctx, cgraph, isrc0, grad); } - case LM_GGML_OP_MUL_MAT: - { - // https://cs231n.github.io/optimization-2/#staged - // # forward pass - // s0 = np.random.randn(5, 10) - // s1 = np.random.randn(10, 3) - // t = s0.dot(s1) - - // # now suppose we had the gradient on t from above in the circuit - // dt = np.random.randn(*t.shape) # same shape as t - // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix - // ds1 = t.T.dot(dt) - - // tensor.shape [m,p,qq,rr] - // src0.shape [n,m,q1,r1] - // src1.shape [n,p,qq,rr] - - // necessary for llama - if (src0->grad) { - struct lm_ggml_tensor * s1_tg = - lm_ggml_out_prod(ctx, // [n,m,qq,rr] - src1, // [n,p,qq,rr] - tensor->grad); // [m,p,qq,rr] - const int64_t qq = s1_tg->ne[2]; - const int64_t rr = s1_tg->ne[3]; - const int64_t q1 = src0->ne[2]; - const int64_t r1 = src0->ne[3]; - const bool ne2_broadcasted = qq > q1; - const bool ne3_broadcasted = rr > r1; - if (ne2_broadcasted || ne3_broadcasted) { - // sum broadcast repetitions of s1_tg into shape of src0 - s1_tg = lm_ggml_repeat_back(ctx, s1_tg, src0); - } - src0->grad = - lm_ggml_add_or_set(ctx, - src0->grad, // [n,m,q1,r1] - s1_tg, // [n,m,q1,r1] - zero_table, acc_table); - } - if (src1->grad) { - src1->grad = - lm_ggml_add_or_set(ctx, - src1->grad, // [n,p,qq,rr] - // lm_ggml_mul_mat(ctx, // [n,p,qq,rr] - // lm_ggml_cont(ctx, // [m,n,q1,r1] - // lm_ggml_transpose(ctx, src0)), // [m,n,q1,r1] - // tensor->grad), // [m,p,qq,rr] - - // // when src0 is bigger than tensor->grad (this is mostly the case in llama), - // // avoid transpose of src0, rather transpose smaller tensor->grad - // 
// and then use lm_ggml_out_prod - lm_ggml_out_prod(ctx, // [n,p,qq,rr] - src0, // [n,m,q1,r1] - lm_ggml_transpose(ctx, // [p,m,qq,rr] - tensor->grad)), // [m,p,qq,rr] - zero_table, acc_table); - } - } break; - case LM_GGML_OP_MUL_MAT_ID: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented + if (src1_needs_grads) { + lm_ggml_add_or_set(ctx, cgraph, isrc1, lm_ggml_mean(ctx, grad)); // TODO: should probably be sum instead of mean } - case LM_GGML_OP_OUT_PROD: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case LM_GGML_OP_ACC: { + if (src0_needs_grads) { + lm_ggml_add_or_set(ctx, cgraph, isrc0, grad); } - case LM_GGML_OP_SCALE: - { - // necessary for llama - if (src0->grad) { - float s; - memcpy(&s, tensor->op_params, sizeof(float)); - - src0->grad = - lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_scale_impl(ctx, tensor->grad, s, false), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_SET: - { - const size_t nb1 = ((int32_t *) tensor->op_params)[0]; - const size_t nb2 = ((int32_t *) tensor->op_params)[1]; - const size_t nb3 = ((int32_t *) tensor->op_params)[2]; - const size_t offset = ((int32_t *) tensor->op_params)[3]; - - struct lm_ggml_tensor * tensor_grad_view = NULL; - - if (src0->grad || src1->grad) { - LM_GGML_ASSERT(src0->type == tensor->type); - LM_GGML_ASSERT(tensor->grad->type == tensor->type); - LM_GGML_ASSERT(!src1->grad || src1->grad->type == tensor->grad->type); - - tensor_grad_view = lm_ggml_view_4d(ctx, - tensor->grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], - nb1, nb2, nb3, offset); - } + if (src1_needs_grads) { + const size_t nb1 = ((int32_t *) tensor->op_params)[0]; + const size_t nb2 = ((int32_t *) tensor->op_params)[1]; + const size_t nb3 = ((int32_t *) tensor->op_params)[2]; + const size_t offset = ((int32_t *) tensor->op_params)[3]; - if (src0->grad) { - src0->grad = lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_acc_impl(ctx, - tensor->grad, - lm_ggml_neg(ctx, tensor_grad_view), - nb1, nb2, nb3, offset, false), - zero_table, acc_table); - } + struct lm_ggml_tensor * tensor_grad_view = lm_ggml_view_4d(ctx, + grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + nb1, nb2, nb3, offset); - if (src1->grad) { - src1->grad = - lm_ggml_add_or_set(ctx, - src1->grad, - lm_ggml_reshape(ctx, - lm_ggml_cont(ctx, tensor_grad_view), - src1->grad), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_CPY: - { - // necessary for llama - // cpy overwrites value of src1 by src0 and returns view(src1) - // the overwriting is mathematically equivalent to: - // tensor = src0 * 1 + src1 * 0 - if (src0->grad) { - // dsrc0 = dtensor * 1 - src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); - } - if (src1->grad) { - // dsrc1 = dtensor * 0 -> noop - } - } break; - case LM_GGML_OP_CONT: - { - // same as cpy - if (src0->grad) { - LM_GGML_ASSERT(lm_ggml_is_contiguous(src0->grad)); - LM_GGML_ASSERT(lm_ggml_is_contiguous(tensor->grad)); - src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); - } - } break; - case LM_GGML_OP_RESHAPE: - { - // necessary for llama - if (src0->grad) { - src0->grad = - lm_ggml_add_or_set(ctx, src0->grad, - lm_ggml_reshape(ctx, - lm_ggml_is_contiguous(tensor->grad) - ? 
tensor->grad - : lm_ggml_cont(ctx, tensor->grad), - src0->grad), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_VIEW: - { - // necessary for llama - if (src0->grad) { - size_t offset; - - memcpy(&offset, tensor->op_params, sizeof(offset)); - - size_t nb1 = tensor->nb[1]; - size_t nb2 = tensor->nb[2]; - size_t nb3 = tensor->nb[3]; - - if (src0->type != src0->grad->type) { - // gradient is typically F32, but src0 could be other type - size_t ng = lm_ggml_element_size(src0->grad); - size_t n0 = lm_ggml_element_size(src0); - LM_GGML_ASSERT(offset % n0 == 0); - LM_GGML_ASSERT(nb1 % n0 == 0); - LM_GGML_ASSERT(nb2 % n0 == 0); - LM_GGML_ASSERT(nb3 % n0 == 0); - offset = (offset / n0) * ng; - nb1 = (nb1 / n0) * ng; - nb2 = (nb2 / n0) * ng; - nb3 = (nb3 / n0) * ng; - } - - src0->grad = lm_ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table, acc_table); - } - } break; - case LM_GGML_OP_PERMUTE: - { - // necessary for llama - if (src0->grad) { - int32_t * axes = (int32_t *) tensor->op_params; - int axis0 = axes[0] & 0x3; - int axis1 = axes[1] & 0x3; - int axis2 = axes[2] & 0x3; - int axis3 = axes[3] & 0x3; - int axes_backward[4] = {0,0,0,0}; - axes_backward[axis0] = 0; - axes_backward[axis1] = 1; - axes_backward[axis2] = 2; - axes_backward[axis3] = 3; - src0->grad = - lm_ggml_add_or_set(ctx, src0->grad, - lm_ggml_permute(ctx, - tensor->grad, - axes_backward[0], - axes_backward[1], - axes_backward[2], - axes_backward[3]), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_TRANSPOSE: - { - // necessary for llama - if (src0->grad) { - src0->grad = - lm_ggml_add_or_set(ctx, src0->grad, - lm_ggml_transpose(ctx, tensor->grad), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_GET_ROWS: - { - // necessary for llama (only for tokenizer) - if (src0->grad) { - src0->grad = - lm_ggml_add_or_set(ctx, src0->grad, - // last lm_ggml_get_rows_back argument src0->grad is only - // necessary to setup correct output shape - lm_ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad), - zero_table, acc_table); - } - if (src1->grad) { - // noop - } - } break; - case LM_GGML_OP_GET_ROWS_BACK: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented + lm_ggml_add_or_set(ctx, cgraph, isrc1, lm_ggml_reshape(ctx, lm_ggml_cont(ctx, tensor_grad_view), src1)); } - case LM_GGML_OP_DIAG: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case LM_GGML_OP_SUB: { + if (src0_needs_grads) { + lm_ggml_add_or_set(ctx, cgraph, isrc0, grad); } - case LM_GGML_OP_DIAG_MASK_INF: - { - // necessary for llama - if (src0->grad) { - const int n_past = ((int32_t *) tensor->op_params)[0]; - src0->grad = - lm_ggml_add_or_set(ctx, src0->grad, - /* lm_ggml_diag_mask_inf_impl() shouldn't be here */ - /* ref: https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */ - lm_ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_DIAG_MASK_ZERO: - { - // necessary for llama - if (src0->grad) { - const int n_past = ((int32_t *) tensor->op_params)[0]; - src0->grad = - lm_ggml_add_or_set(ctx, src0->grad, - lm_ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_SOFT_MAX: - { - // necessary for llama - if (src0->grad) { - src0->grad = - lm_ggml_add_or_set(ctx, src0->grad, - lm_ggml_soft_max_back(ctx, tensor->grad, tensor), - zero_table, acc_table); - } - LM_GGML_ASSERT((!src1 || !src1->grad) && "backward pass for softmax mask not 
implemented"); - } break; - case LM_GGML_OP_SOFT_MAX_BACK: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented + if (src1_needs_grads) { + lm_ggml_sub_or_set(ctx, cgraph, isrc1, grad); } - case LM_GGML_OP_ROPE: - { - // necessary for llama - if (src0->grad) { - //const int n_past = ((int32_t *) tensor->op_params)[0]; - const int n_dims = ((int32_t *) tensor->op_params)[1]; - const int mode = ((int32_t *) tensor->op_params)[2]; - //const int n_ctx = ((int32_t *) tensor->op_params)[3]; - const int n_ctx_orig = ((int32_t *) tensor->op_params)[4]; - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; - - memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float)); - memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float)); - memcpy(&ext_factor, (int32_t *) tensor->op_params + 7, sizeof(float)); - memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float)); - memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float)); - memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float)); - - src0->grad = lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_rope_back(ctx, - tensor->grad, - src1, - src2, - n_dims, - mode, - n_ctx_orig, - freq_base, - freq_scale, - ext_factor, - attn_factor, - beta_fast, - beta_slow), - zero_table, acc_table); - } - LM_GGML_ASSERT((!src2 || !src2->grad) && "gradients for freq factors not implemented"); - } break; - case LM_GGML_OP_ROPE_BACK: - { - if (src0->grad) { - //const int n_past = ((int32_t *) tensor->op_params)[0]; - const int n_dims = ((int32_t *) tensor->op_params)[1]; - const int mode = ((int32_t *) tensor->op_params)[2]; - //const int n_ctx = ((int32_t *) tensor->op_params)[3]; - const int n_ctx_orig = ((int32_t *) tensor->op_params)[4]; - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; - - memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float)); - memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float)); - memcpy(&ext_factor, (int32_t *) tensor->op_params + 7, sizeof(float)); - memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float)); - memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float)); - memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float)); - - src0->grad = lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_rope_impl(ctx, - tensor->grad, - src1, - src2, - n_dims, - mode, - n_ctx_orig, - freq_base, - freq_scale, - ext_factor, - attn_factor, - beta_fast, - beta_slow, - false), - zero_table, acc_table); + } break; + case LM_GGML_OP_MUL: { + if (src0_needs_grads) { + lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_mul(ctx, src1, grad)); + } + if (src1_needs_grads) { + struct lm_ggml_tensor * tmp = lm_ggml_mul(ctx, src0, grad); + if (!lm_ggml_are_same_shape(src0, src1)) { + tmp = lm_ggml_repeat_back(ctx, tmp, src1); } - } break; - case LM_GGML_OP_CLAMP: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented + lm_ggml_add_or_set(ctx, cgraph, isrc1, tmp); } - case LM_GGML_OP_CONV_TRANSPOSE_1D: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case LM_GGML_OP_DIV: { + if (src0_needs_grads) { + lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_div(ctx, grad, src1)); } - case LM_GGML_OP_IM2COL: - { - if (src1->grad) { - const int32_t s0 = lm_ggml_get_op_params_i32(tensor, 0); - const int32_t s1 = lm_ggml_get_op_params_i32(tensor, 1); - const int32_t p0 = lm_ggml_get_op_params_i32(tensor, 2); - const int32_t p1 = lm_ggml_get_op_params_i32(tensor, 3); - const 
int32_t d0 = lm_ggml_get_op_params_i32(tensor, 4); - const int32_t d1 = lm_ggml_get_op_params_i32(tensor, 5); - const bool is_2D = lm_ggml_get_op_params_i32(tensor, 6) == 1; - - src1->grad = lm_ggml_add_or_set(ctx, - src1->grad, - lm_ggml_im2col_back(ctx, src0, tensor->grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_IM2COL_BACK: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented + if (src1_needs_grads) { + lm_ggml_sub_or_set(ctx, cgraph, isrc1, lm_ggml_mul(ctx, grad, lm_ggml_div(ctx, tensor, src1))); } - case LM_GGML_OP_CONV_TRANSPOSE_2D: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case LM_GGML_OP_SQR: { + if (src0_needs_grads) { + lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_scale(ctx, lm_ggml_mul(ctx, src0, grad), 2.0f)); } - case LM_GGML_OP_POOL_1D: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case LM_GGML_OP_SQRT: { + if (src0_needs_grads) { + lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_scale(ctx, lm_ggml_div(ctx, grad, tensor), 0.5f)); } - case LM_GGML_OP_POOL_2D: - { - if (src0->grad) { - const enum lm_ggml_op_pool op = lm_ggml_get_op_params_i32(tensor, 0); - const int32_t k0 = lm_ggml_get_op_params_i32(tensor, 1); - const int32_t k1 = lm_ggml_get_op_params_i32(tensor, 2); - const int32_t s0 = lm_ggml_get_op_params_i32(tensor, 3); - const int32_t s1 = lm_ggml_get_op_params_i32(tensor, 4); - const int32_t p0 = lm_ggml_get_op_params_i32(tensor, 5); - const int32_t p1 = lm_ggml_get_op_params_i32(tensor, 6); - - src0->grad = lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_pool_2d_back(ctx, tensor->grad, src0, op, k0, k1, s0, s1, p0, p1), - zero_table, acc_table); - } - } break; - case LM_GGML_OP_POOL_2D_BACK: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case LM_GGML_OP_LOG: { + if (src0_needs_grads) { + lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_div(ctx, grad, src0)); } - case LM_GGML_OP_UPSCALE: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case LM_GGML_OP_SIN: { + if (src0_needs_grads) { + lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_mul(ctx, grad, lm_ggml_cos(ctx, src0))); } - case LM_GGML_OP_PAD: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case LM_GGML_OP_COS: { + if (src0_needs_grads) { + lm_ggml_sub_or_set(ctx, cgraph, isrc0, lm_ggml_mul(ctx, grad, lm_ggml_sin(ctx, src0))); } - case LM_GGML_OP_ARANGE: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case LM_GGML_OP_SUM: { + if (src0_needs_grads) { + lm_ggml_add1_or_set(ctx, cgraph, isrc0, src0, grad); } - case LM_GGML_OP_TIMESTEP_EMBEDDING: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case LM_GGML_OP_SUM_ROWS: { + if (src0_needs_grads) { + lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_repeat(ctx, grad, src0)); } - case LM_GGML_OP_ARGSORT: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case LM_GGML_OP_MEAN: { + if (src0_needs_grads) { + lm_ggml_add1_or_set(ctx, cgraph, isrc0, src0, lm_ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false)); } - case LM_GGML_OP_LEAKY_RELU: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case LM_GGML_OP_REPEAT: { + if (src0_needs_grads) { + lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_repeat_back(ctx, grad, src0)); } - case LM_GGML_OP_FLASH_ATTN_EXT: - { - LM_GGML_ABORT("FA backward pass not adapted after rework"); - struct lm_ggml_tensor * flash_grad = NULL; - if 
(src0->grad || src1->grad || tensor->src[2]->grad) { - int32_t t = lm_ggml_get_op_params_i32(tensor, 0); - LM_GGML_ASSERT(t == 0 || t == 1); - bool masked = t != 0; - flash_grad = - lm_ggml_flash_attn_back(ctx, - src0, - src1, - tensor->src[2], - tensor->grad, - masked); + } break; + case LM_GGML_OP_REPEAT_BACK: { + if (src0_needs_grads) { + lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_repeat(ctx, grad, src0)); + } + } break; + case LM_GGML_OP_RMS_NORM: { + if (src0_needs_grads) { + float eps; + memcpy(&eps, tensor->op_params, sizeof(float)); + lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_rms_norm_back(ctx, src0, grad, eps)); + } + } break; + case LM_GGML_OP_MUL_MAT: { + // https://cs231n.github.io/optimization-2/#staged + // # forward pass + // s0 = np.random.randn(5, 10) + // s1 = np.random.randn(10, 3) + // t = s0.dot(s1) + + // # now suppose we had the gradient on t from above in the circuit + // dt = np.random.randn(*t.shape) # same shape as t + // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix + // ds1 = t.T.dot(dt) + + // tensor.shape [m,p,qq,rr] + // src0.shape [n,m,q1,r1] + // src1.shape [n,p,qq,rr] + + if (src0_needs_grads) { + struct lm_ggml_tensor * s1_tg = + lm_ggml_out_prod(ctx, // [n,m,qq,rr] + src1, // [n,p,qq,rr] + grad); // [m,p,qq,rr] + const int64_t qq = s1_tg->ne[2]; + const int64_t rr = s1_tg->ne[3]; + const int64_t q1 = src0->ne[2]; + const int64_t r1 = src0->ne[3]; + const bool ne2_broadcasted = qq > q1; + const bool ne3_broadcasted = rr > r1; + if (ne2_broadcasted || ne3_broadcasted) { + // sum broadcast repetitions of s1_tg into shape of src0 + s1_tg = lm_ggml_repeat_back(ctx, s1_tg, src0); } + lm_ggml_add_or_set(ctx, cgraph, isrc0, s1_tg /*= [n,m,q1,r1]*/); + } + if (src1_needs_grads) { + lm_ggml_add_or_set(ctx, cgraph, isrc1, + // lm_ggml_mul_mat(ctx, // [n,p,qq,rr] + // lm_ggml_cont(ctx, // [m,n,q1,r1] + // lm_ggml_transpose(ctx, src0)), // [m,n,q1,r1] + // grad), // [m,p,qq,rr] + + // when src0 is bigger than tensor->grad (this is mostly the case in llama), + // avoid transpose of src0, rather transpose smaller tensor->grad + // and then use lm_ggml_out_prod + lm_ggml_out_prod(ctx, // [n,p,qq,rr] + src0, // [n,m,q1,r1] + lm_ggml_transpose(ctx, // [p,m,qq,rr] + grad))); // [m,p,qq,rr] + } + } break; + case LM_GGML_OP_SCALE: { + if (src0_needs_grads) { + float s; + memcpy(&s, tensor->op_params, sizeof(float)); + lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_scale_impl(ctx, grad, s, false)); + } + } break; + case LM_GGML_OP_SET: { + const size_t nb1 = ((const int32_t *) tensor->op_params)[0]; + const size_t nb2 = ((const int32_t *) tensor->op_params)[1]; + const size_t nb3 = ((const int32_t *) tensor->op_params)[2]; + const size_t offset = ((const int32_t *) tensor->op_params)[3]; + + struct lm_ggml_tensor * tensor_grad_view = NULL; + + if (src0_needs_grads || src1_needs_grads) { + LM_GGML_ASSERT(src0->type == tensor->type); + LM_GGML_ASSERT(!cgraph->grads[isrc0] || cgraph->grads[isrc0]->type == grad->type); + LM_GGML_ASSERT(!cgraph->grads[isrc1] || !src1_needs_grads || cgraph->grads[isrc1]->type == grad->type); + + tensor_grad_view = lm_ggml_view_4d(ctx, + grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + nb1, nb2, nb3, offset); + } - const int64_t elem_q = lm_ggml_nelements(src0); - const int64_t elem_k = lm_ggml_nelements(src1); - const int64_t elem_v = lm_ggml_nelements(src2); - - enum lm_ggml_type result_type = flash_grad->type; - LM_GGML_ASSERT(lm_ggml_blck_size(result_type) == 1); - const size_t tsize = 
lm_ggml_type_size(result_type); - - const size_t offs_q = 0; - const size_t offs_k = offs_q + LM_GGML_PAD(elem_q * tsize, LM_GGML_MEM_ALIGN); - const size_t offs_v = offs_k + LM_GGML_PAD(elem_k * tsize, LM_GGML_MEM_ALIGN); - - if (src0->grad) { - struct lm_ggml_tensor * view_q = lm_ggml_view_1d(ctx, flash_grad, elem_q, offs_q); - struct lm_ggml_tensor * grad_q = lm_ggml_reshape(ctx, view_q, src0); - src0->grad = lm_ggml_add_or_set(ctx, - src0->grad, - grad_q, - zero_table, acc_table); - } - if (src1->grad) { - struct lm_ggml_tensor * view_k = lm_ggml_view_1d(ctx, flash_grad, elem_k, offs_k); - struct lm_ggml_tensor * grad_k = lm_ggml_reshape(ctx, view_k, src1); - src1->grad = lm_ggml_add_or_set(ctx, - src1->grad, - grad_k, - zero_table, acc_table); - } - if (src2->grad) { - struct lm_ggml_tensor * view_v = lm_ggml_view_1d(ctx, flash_grad, elem_v, offs_v); - struct lm_ggml_tensor * grad_v = lm_ggml_reshape(ctx, view_v, src2); - src2->grad = lm_ggml_add_or_set(ctx, - src2->grad, - grad_v, - zero_table, acc_table); + if (src0_needs_grads) { + struct lm_ggml_tensor * tmp = lm_ggml_neg(ctx, tensor_grad_view); + lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_acc_impl(ctx, grad, tmp, nb1, nb2, nb3, offset, false)); + } + + if (src1_needs_grads) { + lm_ggml_add_or_set(ctx, cgraph, isrc1, lm_ggml_reshape(ctx, lm_ggml_cont(ctx, tensor_grad_view), src1)); + } + } break; + case LM_GGML_OP_CPY: { + // cpy overwrites value of src1 by src0 and returns view(src1) + // the overwriting is mathematically equivalent to: + // tensor = src0 * 1 + src1 * 0 + if (src0_needs_grads) { + // dsrc0 = dtensor * 1 + lm_ggml_add_or_set(ctx, cgraph, isrc0, grad); + } + if (src1_needs_grads) { + // dsrc1 = dtensor * 0 -> noop + } + } break; + case LM_GGML_OP_CONT: { + // same as cpy + if (src0_needs_grads) { + LM_GGML_ASSERT(!cgraph->grads[isrc0] || lm_ggml_is_contiguous(cgraph->grads[isrc0])); + LM_GGML_ASSERT(lm_ggml_is_contiguous(grad)); + lm_ggml_add_or_set(ctx, cgraph, isrc0, grad); + } + } break; + case LM_GGML_OP_RESHAPE: { + if (src0_needs_grads) { + struct lm_ggml_tensor * grad_cont = lm_ggml_is_contiguous(grad) ? 
grad : lm_ggml_cont(ctx, grad); + lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_reshape(ctx, grad_cont, src0)); + } + } break; + case LM_GGML_OP_VIEW: { + if (src0_needs_grads) { + size_t offset; + + memcpy(&offset, tensor->op_params, sizeof(offset)); + + size_t nb1 = tensor->nb[1]; + size_t nb2 = tensor->nb[2]; + size_t nb3 = tensor->nb[3]; + + if (cgraph->grads[isrc0] && src0->type != cgraph->grads[isrc0]->type) { + // gradient is typically F32, but src0 could be other type + size_t ng = lm_ggml_element_size(cgraph->grads[isrc0]); + size_t n0 = lm_ggml_element_size(src0); + LM_GGML_ASSERT(offset % n0 == 0); + LM_GGML_ASSERT(nb1 % n0 == 0); + LM_GGML_ASSERT(nb2 % n0 == 0); + LM_GGML_ASSERT(nb3 % n0 == 0); + offset = (offset / n0) * ng; + nb1 = (nb1 / n0) * ng; + nb2 = (nb2 / n0) * ng; + nb3 = (nb3 / n0) * ng; } - } break; - case LM_GGML_OP_FLASH_ATTN_BACK: - { - LM_GGML_ABORT("fatal error"); // not supported + + lm_ggml_acc_or_set(ctx, cgraph, isrc0, src0, grad, nb1, nb2, nb3, offset); } - case LM_GGML_OP_SSM_CONV: - case LM_GGML_OP_SSM_SCAN: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case LM_GGML_OP_PERMUTE: { + if (src0_needs_grads) { + const int32_t * axes = (const int32_t *) tensor->op_params; + const int axis0 = axes[0] & 0x3; + const int axis1 = axes[1] & 0x3; + const int axis2 = axes[2] & 0x3; + const int axis3 = axes[3] & 0x3; + int axb[4] = {0,0,0,0}; // axes backward + axb[axis0] = 0; + axb[axis1] = 1; + axb[axis2] = 2; + axb[axis3] = 3; + lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_permute(ctx, grad, axb[0], axb[1], axb[2], axb[3])); + } + } break; + case LM_GGML_OP_TRANSPOSE: { + if (src0_needs_grads) { + lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_transpose(ctx, grad)); + } + } break; + case LM_GGML_OP_GET_ROWS: { + if (src0_needs_grads) { + lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_get_rows_back(ctx, grad, src1, src0)); + } + if (src1_needs_grads) { + // noop + } + } break; + case LM_GGML_OP_DIAG_MASK_INF: { + if (src0_needs_grads) { + /* lm_ggml_diag_mask_inf_impl() shouldn't be here */ + /* ref: https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */ + const int n_past = ((const int32_t *) tensor->op_params)[0]; + lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_diag_mask_zero_impl(ctx, grad, n_past, false)); + } + } break; + case LM_GGML_OP_DIAG_MASK_ZERO: { + if (src0_needs_grads) { + const int n_past = ((const int32_t *) tensor->op_params)[0]; + lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_diag_mask_zero_impl(ctx, grad, n_past, false)); } + } break; + case LM_GGML_OP_SOFT_MAX: { + if (src0_needs_grads) { + lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_soft_max_back(ctx, grad, tensor)); + } + LM_GGML_ASSERT((!src1 || !src1_needs_grads) && "backward pass for softmax mask not implemented"); + } break; + case LM_GGML_OP_ROPE: { + if (src0_needs_grads) { + //const int n_past = ((int32_t *) tensor->op_params)[0]; + const int n_dims = ((const int32_t *) tensor->op_params)[1]; + const int mode = ((const int32_t *) tensor->op_params)[2]; + //const int n_ctx = ((int32_t *) tensor->op_params)[3]; + const int n_ctx_orig = ((const int32_t *) tensor->op_params)[4]; + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + + memcpy(&freq_base, (const float *) tensor->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (const float *) tensor->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (const float *) tensor->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (const float *) 
tensor->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (const float *) tensor->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (const float *) tensor->op_params + 10, sizeof(float)); + + lm_ggml_add_or_set(ctx, cgraph, isrc0, + lm_ggml_rope_back(ctx, grad, src1, src2, n_dims, mode, n_ctx_orig, freq_base, + freq_scale, ext_factor, attn_factor, beta_fast, beta_slow)); + } + LM_GGML_ASSERT((!src2 || !src2_needs_grads) && "gradients for freq factors not implemented"); + } break; + case LM_GGML_OP_IM2COL: { + if (src1_needs_grads) { + const int32_t s0 = lm_ggml_get_op_params_i32(tensor, 0); + const int32_t s1 = lm_ggml_get_op_params_i32(tensor, 1); + const int32_t p0 = lm_ggml_get_op_params_i32(tensor, 2); + const int32_t p1 = lm_ggml_get_op_params_i32(tensor, 3); + const int32_t d0 = lm_ggml_get_op_params_i32(tensor, 4); + const int32_t d1 = lm_ggml_get_op_params_i32(tensor, 5); + const bool is_2D = lm_ggml_get_op_params_i32(tensor, 6) == 1; + + lm_ggml_add_or_set(ctx, cgraph, isrc1, lm_ggml_im2col_back(ctx, src0, grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D)); + } + } break; + case LM_GGML_OP_POOL_2D: { + if (src0_needs_grads) { + const enum lm_ggml_op_pool op = lm_ggml_get_op_params_i32(tensor, 0); + const int32_t k0 = lm_ggml_get_op_params_i32(tensor, 1); + const int32_t k1 = lm_ggml_get_op_params_i32(tensor, 2); + const int32_t s0 = lm_ggml_get_op_params_i32(tensor, 3); + const int32_t s1 = lm_ggml_get_op_params_i32(tensor, 4); + const int32_t p0 = lm_ggml_get_op_params_i32(tensor, 5); + const int32_t p1 = lm_ggml_get_op_params_i32(tensor, 6); + + lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_pool_2d_back(ctx, grad, src0, op, k0, k1, s0, s1, p0, p1)); + } + } break; case LM_GGML_OP_WIN_PART: case LM_GGML_OP_WIN_UNPART: - case LM_GGML_OP_UNARY: - { - switch (lm_ggml_get_unary_op(tensor)) { - case LM_GGML_UNARY_OP_ABS: - { - if (src0->grad) { - src0->grad = - lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_mul(ctx, - lm_ggml_sgn(ctx, src0), - tensor->grad), - zero_table, acc_table); - } - } break; - case LM_GGML_UNARY_OP_SGN: - { - if (src0->grad) { - // noop - } - } break; - case LM_GGML_UNARY_OP_NEG: - { - if (src0->grad) { - src0->grad = lm_ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); - } - } break; - case LM_GGML_UNARY_OP_STEP: - { - if (src0->grad) { - // noop - } - } break; - case LM_GGML_UNARY_OP_TANH: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented - } - case LM_GGML_UNARY_OP_ELU: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented - } - case LM_GGML_UNARY_OP_RELU: - { - if (src0->grad) { - src0->grad = lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_mul(ctx, - lm_ggml_step(ctx, src0), - tensor->grad), - zero_table, acc_table); - } - } break; - case LM_GGML_UNARY_OP_SIGMOID: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented - } - case LM_GGML_UNARY_OP_GELU: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented - } - case LM_GGML_UNARY_OP_GELU_QUICK: - { - LM_GGML_ABORT("fatal error"); // TODO: not implemented - } - case LM_GGML_UNARY_OP_SILU: - { - // necessary for llama - if (src0->grad) { - src0->grad = lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_silu_back(ctx, src0, tensor->grad), - zero_table, acc_table); - } - } break; - case LM_GGML_UNARY_OP_EXP: - { - if (src0->grad) { - src0->grad = lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_mul(ctx, tensor, tensor->grad), - zero_table, acc_table); - } - } break; - default: - LM_GGML_ABORT("fatal error"); - } - } break; - case LM_GGML_OP_GET_REL_POS: - case 
LM_GGML_OP_ADD_REL_POS: - case LM_GGML_OP_RWKV_WKV6: - case LM_GGML_OP_MAP_UNARY: - case LM_GGML_OP_MAP_BINARY: - case LM_GGML_OP_MAP_CUSTOM1_F32: - case LM_GGML_OP_MAP_CUSTOM2_F32: - case LM_GGML_OP_MAP_CUSTOM3_F32: - case LM_GGML_OP_MAP_CUSTOM1: - case LM_GGML_OP_MAP_CUSTOM2: - case LM_GGML_OP_MAP_CUSTOM3: - { - LM_GGML_ABORT("fatal error"); // not supported - } - case LM_GGML_OP_CROSS_ENTROPY_LOSS: - { - if (src0->grad) { - src0->grad = lm_ggml_add_or_set(ctx, - src0->grad, - lm_ggml_cross_entropy_loss_back(ctx, - src0, - src1, - tensor->grad), - zero_table, acc_table); - } - LM_GGML_ASSERT(!src1->grad && "backward pass for labels not implemented"); - } break; - case LM_GGML_OP_CROSS_ENTROPY_LOSS_BACK: - { - LM_GGML_ABORT("fatal error"); // not supported + case LM_GGML_OP_UNARY: { + switch (lm_ggml_get_unary_op(tensor)) { + case LM_GGML_UNARY_OP_ABS: { + if (src0_needs_grads) { + lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_mul(ctx, lm_ggml_sgn(ctx, src0), grad)); + } + } break; + case LM_GGML_UNARY_OP_SGN: { + // noop + } break; + case LM_GGML_UNARY_OP_NEG: { + if (src0_needs_grads) { + lm_ggml_sub_or_set(ctx, cgraph, isrc0, grad); + } + } break; + case LM_GGML_UNARY_OP_STEP: { + // noop + } break; + case LM_GGML_UNARY_OP_RELU: { + if (src0_needs_grads) { + lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_mul(ctx, lm_ggml_step(ctx, src0), grad)); + } + } break; + case LM_GGML_UNARY_OP_SILU: { + if (src0_needs_grads) { + lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_silu_back(ctx, src0, grad)); + } + } break; + case LM_GGML_UNARY_OP_EXP: { + if (src0_needs_grads) { + lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_mul(ctx, tensor, grad)); + } + } break; + default: { + fprintf(stderr, "%s: unsupported unary op for backward pass: %s\n", + __func__, lm_ggml_unary_op_name(lm_ggml_get_unary_op(tensor))); + LM_GGML_ABORT("fatal error"); + } //break; } - case LM_GGML_OP_OPT_STEP_ADAMW: - { - LM_GGML_ABORT("fatal error"); // not supported + } break; + case LM_GGML_OP_CROSS_ENTROPY_LOSS: { + if (src0_needs_grads) { + lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_cross_entropy_loss_back(ctx, src0, src1, grad)); } - case LM_GGML_OP_NONE: - { - // nop - } break; + LM_GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented"); + } break; + case LM_GGML_OP_NONE: { + // noop + } break; case LM_GGML_OP_COUNT: - { - LM_GGML_ABORT("fatal error"); - } + default: { + fprintf(stderr, "%s: unsupported ggml op for backward pass: %s\n", __func__, lm_ggml_op_name(tensor->op)); + LM_GGML_ABORT("fatal error"); + } //break; } - for (int i = 0; i < LM_GGML_MAX_SRC; ++i) { - if (tensor->src[i] && tensor->src[i]->grad) { - LM_GGML_ASSERT(lm_ggml_are_same_shape(tensor->src[i], tensor->src[i]->grad)); - } - } + LM_GGML_ASSERT(!src0_needs_grads || lm_ggml_are_same_shape(src0, cgraph->grads[isrc0])); + LM_GGML_ASSERT(!src1_needs_grads || lm_ggml_are_same_shape(src1, cgraph->grads[isrc1])); + LM_GGML_ASSERT(!src2_needs_grads || lm_ggml_are_same_shape(src2, cgraph->grads[isrc2])); } static void lm_ggml_visit_parents(struct lm_ggml_cgraph * cgraph, struct lm_ggml_tensor * node) { - if (node->grad == NULL) { - // this usually happens when we generate intermediate nodes from constants in the backward pass - // it can also happen during forward pass, if the user performs computations with constants - if (node->op != LM_GGML_OP_NONE) { - //LM_GGML_PRINT_DEBUG("%s: warning: node %p has no grad, but op %d\n", __func__, (void *) node, node->op); - } - } - // check if already visited if 
(lm_ggml_hash_insert(&cgraph->visited_hash_set, node) == LM_GGML_HASHSET_ALREADY_EXISTS) { return; @@ -6207,18 +5586,42 @@ void lm_ggml_build_forward_expand(struct lm_ggml_cgraph * cgraph, struct lm_ggml lm_ggml_build_forward_impl(cgraph, tensor, true); } -void lm_ggml_build_backward_expand(struct lm_ggml_context * ctx, struct lm_ggml_cgraph * gf, struct lm_ggml_cgraph * gb, bool accumulate) { - LM_GGML_ASSERT(gf->n_nodes > 0); - LM_GGML_ASSERT(gf->grads); +void lm_ggml_build_backward_expand( + struct lm_ggml_context * ctx_static, + struct lm_ggml_context * ctx_compute, + struct lm_ggml_cgraph * cgraph, + bool accumulate) { + LM_GGML_ASSERT(cgraph->n_nodes > 0); + LM_GGML_ASSERT(cgraph->grads); + LM_GGML_ASSERT(cgraph->grad_accs); + + const int n_nodes_f = cgraph->n_nodes; - for (int i = 0; i < gf->n_nodes; ++i) { - struct lm_ggml_tensor * node = gf->nodes[i]; + const size_t hash_size = lm_ggml_hash_size(2*cgraph->size); + memset(cgraph->grads, 0, hash_size*sizeof(struct lm_ggml_tensor *)); + memset(cgraph->grad_accs, 0, hash_size*sizeof(struct lm_ggml_tensor *)); + bool * grads_needed = calloc(hash_size, sizeof(bool)); + + { + bool any_params = false; + bool any_loss = false; + for (int i = 0; i < n_nodes_f; ++i) { + struct lm_ggml_tensor * node = cgraph->nodes[i]; + any_params = any_params || (node->flags & LM_GGML_TENSOR_FLAG_PARAM); + any_loss = any_loss || (node->flags & LM_GGML_TENSOR_FLAG_LOSS); + } + LM_GGML_ASSERT(any_params && "no trainable parameters found, did you forget to call lm_ggml_set_param?"); + LM_GGML_ASSERT(any_loss && "no training loss found, did you forget to call lm_ggml_set_loss?"); + } + + for (int i = 0; i < n_nodes_f; ++i) { + struct lm_ggml_tensor * node = cgraph->nodes[i]; if (node->type == LM_GGML_TYPE_I32) { continue; } - bool needs_grad = node->flags & LM_GGML_TENSOR_FLAG_PARAM; + bool node_needs_grad = node->flags & LM_GGML_TENSOR_FLAG_PARAM; bool ignore_src[LM_GGML_MAX_SRC] = {false}; switch (node->op) { // gradients in node->src[0] for one reason or another have no effect on output gradients @@ -6246,14 +5649,14 @@ void lm_ggml_build_backward_expand(struct lm_ggml_context * ctx, struct lm_ggml_ break; } for (int j = 0; j < LM_GGML_MAX_SRC; ++j) { - if (!node->src[j] || !node->src[j]->grad || ignore_src[j]) { + if (!node->src[j] || ignore_src[j] || !grads_needed[lm_ggml_hash_find(&cgraph->visited_hash_set, node->src[j])]) { continue; } LM_GGML_ASSERT(node->src[j]->type == LM_GGML_TYPE_F32 || node->src[j]->type == LM_GGML_TYPE_F16); - needs_grad = true; + node_needs_grad = true; break; } - if (!needs_grad) { + if (!node_needs_grad) { continue; } @@ -6261,73 +5664,21 @@ void lm_ggml_build_backward_expand(struct lm_ggml_context * ctx, struct lm_ggml_ LM_GGML_ASSERT(!node->view_src || node->op == LM_GGML_OP_CPY || node->op == LM_GGML_OP_VIEW || node->op == LM_GGML_OP_RESHAPE || node->op == LM_GGML_OP_PERMUTE || node->op == LM_GGML_OP_TRANSPOSE); - // create a new tensor with the same type and shape as the node and set it as grad - node->grad = lm_ggml_dup_tensor(ctx, node); - } - - // keep tables of original gradients for replacement/accumulation logic - struct lm_ggml_hash_set zero_table = lm_ggml_hash_set_new(gf->size); - struct lm_ggml_hash_set acc_table = lm_ggml_hash_set_new(gf->size); - for (int i = 0; i < gf->n_nodes; i++) { - struct lm_ggml_tensor * node = gf->nodes[i]; - - if (node->grad) { - { - const size_t insert_result = lm_ggml_hash_insert(&zero_table, node->grad); - LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL); - 
LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS); - } - - // only gradients of trainable parameters should be accumulated - if (accumulate && (node->flags & LM_GGML_TENSOR_FLAG_PARAM)) { - const size_t insert_result = lm_ggml_hash_insert(&acc_table, node->grad); - LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL); - LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS); - } + const size_t igrad = lm_ggml_hash_find(&cgraph->visited_hash_set, node); + if ((accumulate && (node->flags & LM_GGML_TENSOR_FLAG_PARAM)) || (node->flags & LM_GGML_TENSOR_FLAG_LOSS)) { + cgraph->grads[igrad] = lm_ggml_dup_tensor(ctx_static, node); + cgraph->grad_accs[igrad] = cgraph->grads[igrad]; } + grads_needed[igrad] = true; } - for (int i = gf->n_nodes - 1; i >= 0; i--) { - struct lm_ggml_tensor * node = gf->nodes[i]; - + for (int i = n_nodes_f - 1; i >= 0; --i) { // inplace operations to add gradients are not created by lm_ggml_compute_backward except for gradient accumulation // use allocator to automatically make inplace operations - if (node->grad) { - lm_ggml_compute_backward(ctx, node, &zero_table, &acc_table); - } + lm_ggml_compute_backward(ctx_compute, cgraph, i, grads_needed); } - for (int i = 0; i < gf->n_nodes; i++) { - struct lm_ggml_tensor * node = gf->nodes[i]; - - if (node->flags & LM_GGML_TENSOR_FLAG_PARAM) { - LM_GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node); - lm_ggml_build_forward_expand(gb, node->grad); - } - } - - lm_ggml_hash_set_free(&zero_table); - lm_ggml_hash_set_free(&acc_table); -} - -void lm_ggml_build_opt_adamw( - struct lm_ggml_context * ctx, - struct lm_ggml_cgraph * gf, - struct lm_ggml_cgraph * gb, - float alpha, - float beta1, - float beta2, - float eps, - float wd) { - for (int i = 0; i < gf->n_nodes; i++) { - struct lm_ggml_tensor * node = gf->nodes[i]; - - if (node->flags & LM_GGML_TENSOR_FLAG_PARAM) { - LM_GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node); - struct lm_ggml_tensor * opt_step = lm_ggml_opt_step_adamw(ctx, node, node->grad, alpha, beta1, beta2, eps, wd); - lm_ggml_build_forward_expand(gb, opt_step); - } - } + free(grads_needed); } static void * incr_ptr_aligned(void ** p, size_t size, size_t align) { @@ -6345,7 +5696,8 @@ static size_t lm_ggml_graph_nbytes(size_t size, bool grads) { incr_ptr_aligned(&p, size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *)); // leafs incr_ptr_aligned(&p, hash_size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *)); // hash keys if (grads) { - incr_ptr_aligned(&p, size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *)); // grads + incr_ptr_aligned(&p, hash_size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *)); // grads + incr_ptr_aligned(&p, hash_size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *)); // grad_accs } incr_ptr_aligned(&p, lm_ggml_bitset_size(hash_size) * sizeof(lm_ggml_bitset_t), sizeof(lm_ggml_bitset_t)); @@ -6371,10 +5723,12 @@ struct lm_ggml_cgraph * lm_ggml_new_graph_custom(struct lm_ggml_context * ctx, s void * p = cgraph + 1; - struct lm_ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *)); - struct lm_ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *)); - struct lm_ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *)); - struct 
lm_ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *)) : NULL; + struct lm_ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *)); + struct lm_ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *)); + struct lm_ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *)); + struct lm_ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *)) : NULL; + struct lm_ggml_tensor ** grad_accs_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *)) : NULL; + lm_ggml_bitset_t * hash_used = incr_ptr_aligned(&p, lm_ggml_bitset_size(hash_size) * sizeof(lm_ggml_bitset_t), sizeof(lm_ggml_bitset_t)); // check that we allocated the correct amount of memory @@ -6386,12 +5740,17 @@ struct lm_ggml_cgraph * lm_ggml_new_graph_custom(struct lm_ggml_context * ctx, s /*.n_leafs =*/ 0, /*.nodes =*/ nodes_ptr, /*.grads =*/ grads_ptr, + /*.grad_accs =*/ grad_accs_ptr, /*.leafs =*/ leafs_ptr, /*.hash_table =*/ { hash_size, hash_used, hash_keys_ptr }, /*.order =*/ LM_GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT, }; lm_ggml_hash_set_reset(&cgraph->visited_hash_set); + if (grads) { + memset(cgraph->grads, 0, hash_size*sizeof(struct lm_ggml_tensor *)); + memset(cgraph->grad_accs, 0, hash_size*sizeof(struct lm_ggml_tensor *)); + } return cgraph; } @@ -6407,6 +5766,7 @@ struct lm_ggml_cgraph lm_ggml_graph_view(struct lm_ggml_cgraph * cgraph0, int i0 /*.n_leafs =*/ 0, /*.nodes =*/ cgraph0->nodes + i0, /*.grads =*/ cgraph0->grads ? cgraph0->grads + i0 : NULL, + /*.grad_accs =*/ cgraph0->grad_accs ? 
cgraph0->grad_accs + i0 : NULL, /*.leafs =*/ NULL, /*.hash_table =*/ { 0, NULL, NULL }, /*.order =*/ cgraph0->order, @@ -6432,19 +5792,23 @@ void lm_ggml_graph_cpy(struct lm_ggml_cgraph * src, struct lm_ggml_cgraph * dst) dst->nodes[i] = src->nodes[i]; } - if (src->grads) { - LM_GGML_ASSERT(dst->grads != NULL); - for (int i = 0; i < src->n_nodes; ++i) { - dst->grads[i] = src->grads[i]; - } - } - for (size_t i = 0; i < src->visited_hash_set.size; ++i) { // copy all hashset keys (tensors) that are in use if (lm_ggml_bitset_get(src->visited_hash_set.used, i)) { lm_ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]); } } + + if (src->grads) { + LM_GGML_ASSERT(dst->grads != NULL); + LM_GGML_ASSERT(dst->grad_accs != NULL); + for (int i = 0; i < src->n_nodes; ++i) { + const size_t igrad_src = lm_ggml_hash_find(&src->visited_hash_set, src->nodes[i]); + const size_t igrad_dst = lm_ggml_hash_find(&dst->visited_hash_set, dst->nodes[i]); + dst->grads[igrad_dst] = src->grads[igrad_src]; + dst->grad_accs[igrad_dst] = src->grad_accs[igrad_src]; + } + } } struct lm_ggml_cgraph * lm_ggml_graph_dup(struct lm_ggml_context * ctx, struct lm_ggml_cgraph * cgraph) { @@ -6470,29 +5834,36 @@ void lm_ggml_graph_reset(struct lm_ggml_cgraph * cgraph) { LM_GGML_ASSERT(cgraph->grads != NULL); for (int i = 0; i < cgraph->n_nodes; i++) { - struct lm_ggml_tensor * node = cgraph->nodes[i]; + struct lm_ggml_tensor * node = cgraph->nodes[i]; + struct lm_ggml_tensor * grad_acc = lm_ggml_graph_get_grad_acc(cgraph, node); + + if (node->op == LM_GGML_OP_OPT_STEP_ADAMW) { + // clear momenta + if (node->src[2]->data) { + lm_ggml_set_zero(node->src[2]); + } + if (node->src[3]->data) { + lm_ggml_set_zero(node->src[3]); + } + } // initial gradients of loss should be 1, 0 otherwise - if (node->grad) { + if (grad_acc) { if (node->flags & LM_GGML_TENSOR_FLAG_LOSS) { - LM_GGML_ASSERT(node->grad->buffer); - LM_GGML_ASSERT(node->type == LM_GGML_TYPE_F32); - LM_GGML_ASSERT(lm_ggml_is_scalar(node)); + LM_GGML_ASSERT(grad_acc->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT(lm_ggml_is_scalar(grad_acc)); const float onef = 1.0f; - lm_ggml_backend_tensor_set(node->grad, &onef, 0, lm_ggml_nbytes(node->grad)); + if (grad_acc->buffer) { + lm_ggml_backend_tensor_set(grad_acc, &onef, 0, sizeof(float)); + } else { + LM_GGML_ASSERT(grad_acc->data); + *((float *) grad_acc->data) = onef; + } } else { - lm_ggml_set_zero(node->grad); + lm_ggml_set_zero(grad_acc); } } - - LM_GGML_ASSERT(node); - if (node->op == LM_GGML_OP_OPT_STEP_ADAMW) { - // set iteration to 1 and clear momenta - lm_ggml_set_op_params_i32(node, 0, 1); - lm_ggml_set_zero(node->src[2]); - lm_ggml_set_zero(node->src[3]); - } } } @@ -6530,7 +5901,7 @@ void lm_ggml_graph_add_node(struct lm_ggml_cgraph * cgraph, struct lm_ggml_tenso cgraph->n_nodes++; } -struct lm_ggml_tensor * lm_ggml_graph_get_tensor(struct lm_ggml_cgraph * cgraph, const char * name) { +struct lm_ggml_tensor * lm_ggml_graph_get_tensor(const struct lm_ggml_cgraph * cgraph, const char * name) { for (int i = 0; i < cgraph->n_leafs; i++) { struct lm_ggml_tensor * leaf = cgraph->leafs[i]; @@ -6550,6 +5921,16 @@ struct lm_ggml_tensor * lm_ggml_graph_get_tensor(struct lm_ggml_cgraph * cgraph, return NULL; } +struct lm_ggml_tensor * lm_ggml_graph_get_grad(const struct lm_ggml_cgraph * cgraph, const struct lm_ggml_tensor * node) { + const size_t igrad = lm_ggml_hash_find(&cgraph->visited_hash_set, node); + return igrad != LM_GGML_HASHSET_FULL && lm_ggml_bitset_get(cgraph->visited_hash_set.used, igrad) ? 
cgraph->grads[igrad] : NULL; +} + +struct lm_ggml_tensor * lm_ggml_graph_get_grad_acc(const struct lm_ggml_cgraph * cgraph, const struct lm_ggml_tensor * node) { + const size_t igrad = lm_ggml_hash_find(&cgraph->visited_hash_set, node); + return igrad != LM_GGML_HASHSET_FULL && lm_ggml_bitset_get(cgraph->visited_hash_set.used, igrad) ? cgraph->grad_accs[igrad] : NULL; +} + void lm_ggml_graph_print(const struct lm_ggml_cgraph * cgraph) { LM_GGML_LOG_INFO("=== GRAPH ===\n"); @@ -6560,7 +5941,8 @@ void lm_ggml_graph_print(const struct lm_ggml_cgraph * cgraph) { LM_GGML_LOG_INFO(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s\n", i, node->ne[0], node->ne[1], node->ne[2], - lm_ggml_op_name(node->op), (node->flags & LM_GGML_TENSOR_FLAG_PARAM) ? "x" : node->grad ? "g" : " "); + lm_ggml_op_name(node->op), (node->flags & LM_GGML_TENSOR_FLAG_PARAM) ? "x" : + lm_ggml_graph_get_grad(cgraph, node) ? "g" : " "); } LM_GGML_LOG_INFO("n_leafs = %d\n", cgraph->n_leafs); @@ -6595,8 +5977,9 @@ static bool lm_ggml_graph_find(const struct lm_ggml_cgraph * cgraph, const struc static struct lm_ggml_tensor * lm_ggml_graph_get_parent(const struct lm_ggml_cgraph * cgraph, const struct lm_ggml_tensor * node) { for (int i = 0; i < cgraph->n_nodes; i++) { struct lm_ggml_tensor * parent = cgraph->nodes[i]; + struct lm_ggml_tensor * grad = lm_ggml_graph_get_grad(cgraph, parent); - if (parent->grad == node) { + if (grad == node) { return parent; } } @@ -6636,6 +6019,7 @@ void lm_ggml_graph_dump_dot(const struct lm_ggml_cgraph * gb, const struct lm_gg for (int i = 0; i < gb->n_nodes; i++) { struct lm_ggml_tensor * node = gb->nodes[i]; + struct lm_ggml_tensor * grad = lm_ggml_graph_get_grad(gb, node); if (lm_ggml_graph_get_parent(gb, node) != NULL) { continue; @@ -6643,7 +6027,7 @@ void lm_ggml_graph_dump_dot(const struct lm_ggml_cgraph * gb, const struct lm_gg if (node->flags & LM_GGML_TENSOR_FLAG_PARAM) { snprintf(color, sizeof(color), "yellow"); - } else if (node->grad) { + } else if (grad) { if (lm_ggml_graph_find(gf, node)) { snprintf(color, sizeof(color), "green"); } else { @@ -6670,8 +6054,8 @@ void lm_ggml_graph_dump_dot(const struct lm_ggml_cgraph * gb, const struct lm_gg fprintf(fp, "%d [%" PRId64 ", %" PRId64 ", %" PRId64 "] | %s", i, node->ne[0], node->ne[1], node->ne[2], lm_ggml_op_symbol(node->op)); } - if (node->grad) { - fprintf(fp, " | %s\"; ]\n", lm_ggml_op_symbol(node->grad->op)); + if (grad) { + fprintf(fp, " | %s\"; ]\n", lm_ggml_op_symbol(grad->op)); } else { fprintf(fp, "\"; ]\n"); } diff --git a/cpp/ggml.h b/cpp/ggml.h index 677b6d3..f86241d 100644 --- a/cpp/ggml.h +++ b/cpp/ggml.h @@ -602,7 +602,6 @@ extern "C" { int32_t flags; - struct lm_ggml_tensor * grad; struct lm_ggml_tensor * src[LM_GGML_MAX_SRC]; // source tensor and offset for views @@ -615,7 +614,7 @@ extern "C" { void * extra; // extra things e.g. 
for ggml-cuda.cu - // char padding[4]; + char padding[8]; }; static const size_t LM_GGML_TENSOR_SIZE = sizeof(struct lm_ggml_tensor); @@ -1985,28 +1984,20 @@ extern "C" { struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, struct lm_ggml_tensor * grad, - float alpha, - float beta1, - float beta2, - float eps, - float wd); // weight decay + struct lm_ggml_tensor * m, + struct lm_ggml_tensor * v, + struct lm_ggml_tensor * adamw_params); // parameters such a the learning rate // // automatic differentiation // - LM_GGML_API void lm_ggml_build_forward_expand (struct lm_ggml_cgraph * cgraph, struct lm_ggml_tensor * tensor); - LM_GGML_API void lm_ggml_build_backward_expand(struct lm_ggml_context * ctx, struct lm_ggml_cgraph * gf, struct lm_ggml_cgraph * gb, bool accumulate); - - LM_GGML_API void lm_ggml_build_opt_adamw( - struct lm_ggml_context * ctx, - struct lm_ggml_cgraph * gf, - struct lm_ggml_cgraph * gb, - float alpha, - float beta1, - float beta2, - float eps, - float wd); // weight decay + LM_GGML_API void lm_ggml_build_forward_expand(struct lm_ggml_cgraph * cgraph, struct lm_ggml_tensor * tensor); + LM_GGML_API void lm_ggml_build_backward_expand( + struct lm_ggml_context * ctx_static, // context for static gradients (loss + gradient accumulation) + struct lm_ggml_context * ctx_compute, // context for gradient computation + struct lm_ggml_cgraph * cgraph, + bool accumulate); // whether or not gradients should be accumulated, requires static allocation of tensors in ctx_static // graph allocation in a context LM_GGML_API struct lm_ggml_cgraph * lm_ggml_new_graph (struct lm_ggml_context * ctx); // size = LM_GGML_DEFAULT_GRAPH_SIZE, grads = false @@ -2026,7 +2017,9 @@ extern "C" { LM_GGML_API size_t lm_ggml_graph_overhead(void); LM_GGML_API size_t lm_ggml_graph_overhead_custom(size_t size, bool grads); - LM_GGML_API struct lm_ggml_tensor * lm_ggml_graph_get_tensor(struct lm_ggml_cgraph * cgraph, const char * name); + LM_GGML_API struct lm_ggml_tensor * lm_ggml_graph_get_tensor (const struct lm_ggml_cgraph * cgraph, const char * name); + LM_GGML_API struct lm_ggml_tensor * lm_ggml_graph_get_grad (const struct lm_ggml_cgraph * cgraph, const struct lm_ggml_tensor * node); + LM_GGML_API struct lm_ggml_tensor * lm_ggml_graph_get_grad_acc(const struct lm_ggml_cgraph * cgraph, const struct lm_ggml_tensor * node); LM_GGML_API void lm_ggml_graph_export(const struct lm_ggml_cgraph * cgraph, const char * fname); LM_GGML_API struct lm_ggml_cgraph * lm_ggml_graph_import(const char * fname, struct lm_ggml_context ** ctx_data, struct lm_ggml_context ** ctx_eval); @@ -2037,198 +2030,15 @@ extern "C" { // dump the graph into a file using the dot format LM_GGML_API void lm_ggml_graph_dump_dot(const struct lm_ggml_cgraph * gb, const struct lm_ggml_cgraph * gf, const char * filename); - // build gradient checkpointing backward graph gb for gf using provided checkpoints - // gb_tmp will contain original backward graph with rewritten backward process nodes, - // but without the second forward pass nodes. 
- LM_GGML_API void lm_ggml_build_backward_gradient_checkpointing( - struct lm_ggml_context * ctx, - struct lm_ggml_cgraph * gf, - struct lm_ggml_cgraph * gb, - struct lm_ggml_cgraph * gb_tmp, - struct lm_ggml_tensor * * checkpoints, - int n_checkpoints); - // - // optimization - // - - // optimization methods - enum lm_ggml_opt_type { - LM_GGML_OPT_TYPE_ADAM, - LM_GGML_OPT_TYPE_LBFGS, - }; - - // linesearch methods - enum lm_ggml_linesearch { - LM_GGML_LINESEARCH_DEFAULT = 1, - - LM_GGML_LINESEARCH_BACKTRACKING_ARMIJO = 0, - LM_GGML_LINESEARCH_BACKTRACKING_WOLFE = 1, - LM_GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE = 2, - }; - - // optimization return values - enum lm_ggml_opt_result { - LM_GGML_OPT_RESULT_OK = 0, - LM_GGML_OPT_RESULT_DID_NOT_CONVERGE, - LM_GGML_OPT_RESULT_NO_CONTEXT, - LM_GGML_OPT_RESULT_INVALID_WOLFE, - LM_GGML_OPT_RESULT_FAIL, - LM_GGML_OPT_RESULT_CANCEL, - - LM_GGML_LINESEARCH_FAIL = -128, - LM_GGML_LINESEARCH_MINIMUM_STEP, - LM_GGML_LINESEARCH_MAXIMUM_STEP, - LM_GGML_LINESEARCH_MAXIMUM_ITERATIONS, - LM_GGML_LINESEARCH_INVALID_PARAMETERS, - }; - - typedef void (*lm_ggml_opt_callback)(void * data, int accum_step, float * sched, bool * cancel); + // TODO these functions were sandwiched in the old optimization interface, is there a better place for them? typedef void (*lm_ggml_log_callback)(enum lm_ggml_log_level level, const char * text, void * user_data); // Set callback for all future logging events. // If this is not called, or NULL is supplied, everything is output on stderr. LM_GGML_API void lm_ggml_log_set(lm_ggml_log_callback log_callback, void * user_data); - // optimization parameters - // - // see ggml.c (lm_ggml_opt_default_params) for default values - // - struct lm_ggml_opt_params { - enum lm_ggml_opt_type type; - - size_t graph_size; - - int n_threads; - - // delta-based convergence test - // - // if past == 0 - disabled - // if past > 0: - // stop if |f(x) - f(x_past)| < delta * max(1, |f(x)|) - // - int past; - float delta; - - // maximum number of iterations without improvement - // - // if 0 - disabled - // if > 0: - // assume convergence if no cost improvement in this number of iterations - // - int max_no_improvement; - - bool print_forward_graph; - bool print_backward_graph; - - int n_gradient_accumulation; - - // ADAM parameters - struct { - int n_iter; - - float sched; // schedule multiplier (fixed, decay or warmup) - float decay; // weight decay for AdamW, use 0.0f to disable - int decay_min_ndim; // minimum number of tensor dimension to apply weight decay - float alpha; // learning rate - float beta1; - float beta2; - float eps; // epsilon for numerical stability - float eps_f; // epsilon for convergence test - float eps_g; // epsilon for convergence test - float gclip; // gradient clipping - } adam; - - // LBFGS parameters - struct { - int m; // number of corrections to approximate the inv. 
Hessian - int n_iter; - int max_linesearch; - - float eps; // convergence tolerance - float ftol; // line search tolerance - float wolfe; - float min_step; - float max_step; - - enum lm_ggml_linesearch linesearch; - } lbfgs; - }; - - struct lm_ggml_opt_context { - struct lm_ggml_context * ctx; - struct lm_ggml_opt_params params; - - int iter; - int64_t nx; // number of parameter elements - - bool just_initialized; - - float loss_before; - float loss_after; - - struct { - struct lm_ggml_tensor * g; // current gradient - struct lm_ggml_tensor * m; // first moment - struct lm_ggml_tensor * v; // second moment - struct lm_ggml_tensor * pf; // past function values - float fx_best; - float fx_prev; - int n_no_improvement; - } adam; - - struct { - struct lm_ggml_tensor * x; // current parameters - struct lm_ggml_tensor * xp; // previous parameters - struct lm_ggml_tensor * g; // current gradient - struct lm_ggml_tensor * gp; // previous gradient - struct lm_ggml_tensor * d; // search direction - struct lm_ggml_tensor * pf; // past function values - struct lm_ggml_tensor * lmal; // the L-BFGS memory alpha - struct lm_ggml_tensor * lmys; // the L-BFGS memory ys - struct lm_ggml_tensor * lms; // the L-BFGS memory s - struct lm_ggml_tensor * lmy; // the L-BFGS memory y - float fx_best; - float step; - int j; - int k; - int end; - int n_no_improvement; - } lbfgs; - }; - LM_GGML_API struct lm_ggml_tensor * lm_ggml_set_zero(struct lm_ggml_tensor * tensor); - LM_GGML_API struct lm_ggml_opt_params lm_ggml_opt_default_params(enum lm_ggml_opt_type type); - - // optimize the function defined by the tensor f - LM_GGML_API enum lm_ggml_opt_result lm_ggml_opt( - struct lm_ggml_context * ctx, - struct lm_ggml_opt_params params, - struct lm_ggml_tensor * f); - - // initialize optimizer context - LM_GGML_API void lm_ggml_opt_init( - struct lm_ggml_context * ctx, - struct lm_ggml_opt_context * opt, - struct lm_ggml_opt_params params, - int64_t nx); - - // continue optimizing the function defined by the tensor f - LM_GGML_API enum lm_ggml_opt_result lm_ggml_opt_resume( - struct lm_ggml_context * ctx, - struct lm_ggml_opt_context * opt, - struct lm_ggml_tensor * f); - - // continue optimizing the function defined by the tensor f - LM_GGML_API enum lm_ggml_opt_result lm_ggml_opt_resume_g( - struct lm_ggml_context * ctx, - struct lm_ggml_opt_context * opt, - struct lm_ggml_tensor * f, - struct lm_ggml_cgraph * gf, - struct lm_ggml_cgraph * gb, - lm_ggml_opt_callback callback, - void * callback_data); - // // quantization // diff --git a/cpp/llama.cpp b/cpp/llama.cpp index d5c006d..d6a6292 100644 --- a/cpp/llama.cpp +++ b/cpp/llama.cpp @@ -190,6 +190,7 @@ enum llm_arch { LLM_ARCH_COMMAND_R, LLM_ARCH_DBRX, LLM_ARCH_OLMO, + LLM_ARCH_OLMO_1124, LLM_ARCH_OLMOE, LLM_ARCH_OPENELM, LLM_ARCH_ARCTIC, @@ -243,6 +244,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_COMMAND_R, "command-r" }, { LLM_ARCH_DBRX, "dbrx" }, { LLM_ARCH_OLMO, "olmo" }, + { LLM_ARCH_OLMO_1124, "olmo_1124" }, { LLM_ARCH_OLMOE, "olmoe" }, { LLM_ARCH_OPENELM, "openelm" }, { LLM_ARCH_ARCTIC, "arctic" }, @@ -1218,6 +1220,25 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_OLMO_1124, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { 
LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_OLMOE, { @@ -3471,21 +3492,13 @@ static bool llama_kv_cache_init( const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); - const llama_model::buft_list_t * buft_list; + lm_ggml_backend_buffer_type_t buft; if (offload) { - buft_list = model.dev_layer.at(i).buft_list; + auto * dev = model.dev_layer.at(i).dev; + buft = lm_ggml_backend_dev_buffer_type(dev); } else { - buft_list = &model.cpu_buft_list; + buft = lm_ggml_backend_cpu_buffer_type(); } - lm_ggml_backend_buffer_type_t buft = select_buft(*buft_list, - [&](lm_ggml_context * ctx) { - lm_ggml_tensor * k = lm_ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); - if (hparams.rope_type == LLAMA_ROPE_TYPE_NONE) { - return k; - } - lm_ggml_tensor * p = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_I32, 1); - return lm_ggml_rope(ctx, k, p, hparams.n_rot, hparams.rope_type); - }); lm_ggml_context * ctx = ctx_for_buft(buft); if (!ctx) { @@ -5896,6 +5909,17 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_OLMO_1124: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 16: model.type = e_model::MODEL_1B; break; + case 32: model.type = e_model::MODEL_7B; break; + case 40: model.type = e_model::MODEL_13B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_OLMOE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -8578,6 +8602,31 @@ static bool llm_load_tensors( layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); } } break; + case LLM_ARCH_OLMO_1124: + { + model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = model.layers[i]; + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } + } break; case LLM_ARCH_OLMOE: { model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, 
"weight"), {n_embd, n_vocab}, 0); @@ -14443,6 +14492,130 @@ struct llm_build_context { return gf; } + struct lm_ggml_cgraph * build_olmo_1124() { + struct lm_ggml_cgraph * gf = lm_ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + // mutable variable, needed during the last layer of the computation to skip unused tokens + int32_t n_tokens = this->n_tokens; + + const int64_t n_embd_head = hparams.n_embd_head_v; + LM_GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + LM_GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct lm_ggml_tensor * cur; + struct lm_ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct lm_ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct lm_ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + struct lm_ggml_tensor * inpSA = inpL; + + cur = inpL; + + // self_attention + { + // compute Q and K and RoPE them + struct lm_ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + struct lm_ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + struct lm_ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(Qcur, "Qcur_normed", il); + + Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(Kcur, "Kcur_normed", il); + + Qcur = lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = lm_ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur_rope", il); + + Kcur = lm_ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur_rope", il); + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + model.layers[il].wo, NULL, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + } + + cur = llm_build_norm(ctx0, cur, hparams, + model.layers[il].attn_post_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_post_norm", il); + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct lm_ggml_tensor * inp_out_ids = build_inp_out_ids(); + n_tokens = n_outputs; + cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct lm_ggml_tensor * ffn_inp = lm_ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = llm_build_ffn(ctx0, lctx, ffn_inp, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + + cur = llm_build_norm(ctx0, cur, hparams, + model.layers[il].ffn_post_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "ffn_post_norm", -1); + + cur = lm_ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, 
cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + cb(cur, "result_output", -1); + + lm_ggml_build_forward_expand(gf, cur); + + return gf; + } + // based on the build_qwen2moe() function, changes: // * removed shared experts // * removed bias @@ -16635,6 +16808,10 @@ static struct lm_ggml_cgraph * llama_build_graph( { result = llm.build_olmo(); } break; + case LLM_ARCH_OLMO_1124: + { + result = llm.build_olmo_1124(); + } break; case LLM_ARCH_OLMOE: { result = llm.build_olmoe(); @@ -18047,7 +18224,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) { // apply K-shift if needed if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) { - if (lctx.model.arch == LLM_ARCH_DEEPSEEK2) { // not supported due to MLA + if (!llama_kv_cache_can_shift(&lctx)) { LM_GGML_ABORT("Deepseek2 does not support K-shift"); } @@ -19904,6 +20081,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_QWEN: case LLM_ARCH_QWEN2: case LLM_ARCH_QWEN2MOE: + case LLM_ARCH_OLMO_1124: case LLM_ARCH_OLMOE: case LLM_ARCH_PHI2: case LLM_ARCH_PHI3: @@ -20295,6 +20473,10 @@ void llama_kv_cache_update(struct llama_context * ctx) { llama_kv_cache_update_internal(*ctx); } +bool llama_kv_cache_can_shift(struct llama_context * ctx) { + return ctx->model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA +} + // deprecated size_t llama_get_state_size(struct llama_context * ctx) { return llama_state_get_size(ctx); diff --git a/cpp/llama.h b/cpp/llama.h index 9d627af..8d14f49 100644 --- a/cpp/llama.h +++ b/cpp/llama.h @@ -667,6 +667,9 @@ extern "C" { // Apply the KV cache updates (such as K-shifts, defragmentation, etc.) 
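The new llama_kv_cache_can_shift() (defined above and declared in llama.h just below) lets callers verify that the model's KV cache supports K-shifting (DeepSeek2 does not, due to MLA) before attempting a context shift. The sketch below shows one plausible call pattern; the sequence id and n_discard value are illustrative and not taken from this patch.

// Hedged usage sketch for llama_kv_cache_can_shift().
#include "llama.h"
#include <cstdio>

// Try to free the oldest `n_discard` tokens of sequence 0 and slide the
// remaining cells back, so generation can continue past the context limit.
bool try_context_shift(llama_context * ctx, llama_pos n_past, int n_discard) {
    if (!llama_kv_cache_can_shift(ctx)) {
        std::fprintf(stderr, "model does not support KV cache shifting\n");
        return false;
    }

    const llama_seq_id seq = 0;

    // drop the oldest tokens ...
    llama_kv_cache_seq_rm (ctx, seq, 0, n_discard);
    // ... and shift the rest of the sequence back by n_discard positions
    llama_kv_cache_seq_add(ctx, seq, n_discard, n_past, -n_discard);

    // apply the pending K-shift
    llama_kv_cache_update(ctx);
    return true;
}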
LLAMA_API void llama_kv_cache_update(struct llama_context * ctx); + // Check if the context supports KV cache shifting + LLAMA_API bool llama_kv_cache_can_shift(struct llama_context * ctx); + // // State / sessions // diff --git a/example/ios/.xcode.env.local b/example/ios/.xcode.env.local index 1fa3749..c2273fa 100644 --- a/example/ios/.xcode.env.local +++ b/example/ios/.xcode.env.local @@ -1 +1 @@ -export NODE_BINARY=/var/folders/g8/v75_3l3n23g909mshlzdj4wh0000gn/T/yarn--1731985865125-0.724061577974688/node +export NODE_BINARY=/var/folders/4z/1d45cfts3936kdm7v9jl349r0000gn/T/yarn--1732158453708-0.8549147503631913/node diff --git a/llama.cpp b/llama.cpp index 0fff7fd..9abe9ee 160000 --- a/llama.cpp +++ b/llama.cpp @@ -1 +1 @@ -Subproject commit 0fff7fd79818980763a601660f25b01a0cf4b87a +Subproject commit 9abe9eeae98b11fa93b82632b264126a010225ff diff --git a/scripts/bootstrap.sh b/scripts/bootstrap.sh index b13062a..1d194c4 100755 --- a/scripts/bootstrap.sh +++ b/scripts/bootstrap.sh @@ -10,6 +10,7 @@ cp ./llama.cpp/ggml/include/ggml-alloc.h ./cpp/ggml-alloc.h cp ./llama.cpp/ggml/include/ggml-backend.h ./cpp/ggml-backend.h cp ./llama.cpp/ggml/include/ggml-cpu.h ./cpp/ggml-cpu.h cp ./llama.cpp/ggml/include/ggml-cpp.h ./cpp/ggml-cpp.h +cp ./llama.cpp/ggml/include/ggml-opt.h ./cpp/ggml-opt.h cp ./llama.cpp/ggml/include/ggml-metal.h ./cpp/ggml-metal.h cp ./llama.cpp/ggml/src/ggml-metal/ggml-metal.m ./cpp/ggml-metal.m @@ -32,6 +33,7 @@ cp ./llama.cpp/ggml/src/ggml-backend.cpp ./cpp/ggml-backend.cpp cp ./llama.cpp/ggml/src/ggml-backend-impl.h ./cpp/ggml-backend-impl.h cp ./llama.cpp/ggml/src/ggml-backend-reg.cpp ./cpp/ggml-backend-reg.cpp cp ./llama.cpp/ggml/src/ggml-common.h ./cpp/ggml-common.h +cp ./llama.cpp/ggml/src/ggml-opt.cpp ./cpp/ggml-opt.cpp cp ./llama.cpp/ggml/src/ggml-quants.h ./cpp/ggml-quants.h cp ./llama.cpp/ggml/src/ggml-quants.c ./cpp/ggml-quants.c cp ./llama.cpp/ggml/src/ggml-aarch64.c ./cpp/ggml-aarch64.c @@ -84,6 +86,8 @@ files_add_lm_prefix=( "./cpp/ggml.c" "./cpp/ggml-impl.h" "./cpp/ggml-cpp.h" + "./cpp/ggml-opt.h" + "./cpp/ggml-opt.cpp" "./cpp/ggml-metal.h" "./cpp/ggml-metal.m" "./cpp/ggml-quants.h" diff --git a/scripts/common.cpp.patch b/scripts/common.cpp.patch index 4cc23b7..543397e 100644 --- a/scripts/common.cpp.patch +++ b/scripts/common.cpp.patch @@ -1,5 +1,5 @@ ---- common.cpp.orig 2024-11-21 10:21:53 -+++ common.cpp 2024-11-21 10:22:56 +--- common.cpp.orig 2024-11-21 11:03:19 ++++ common.cpp 2024-11-21 11:03:20 @@ -4,10 +4,6 @@ #include "common.h" @@ -33,7 +33,7 @@ // // CPU utils -@@ -979,6 +979,8 @@ +@@ -985,6 +985,8 @@ if (params.n_gpu_layers != -1) { mparams.n_gpu_layers = params.n_gpu_layers; } @@ -42,7 +42,7 @@ mparams.rpc_servers = params.rpc_servers.c_str(); mparams.main_gpu = params.main_gpu; mparams.split_mode = params.split_mode; -@@ -993,6 +995,9 @@ +@@ -999,6 +1001,9 @@ mparams.kv_overrides = params.kv_overrides.data(); } @@ -52,7 +52,7 @@ return mparams; } -@@ -1118,220 +1123,6 @@ +@@ -1124,220 +1129,6 @@ return false; } diff --git a/scripts/ggml-metal.m.patch b/scripts/ggml-metal.m.patch index b2c8d86..3da0d3d 100644 --- a/scripts/ggml-metal.m.patch +++ b/scripts/ggml-metal.m.patch @@ -1,6 +1,6 @@ ---- ggml-metal.m.orig 2024-11-17 11:52:03 -+++ ggml-metal.m 2024-11-17 11:52:05 -@@ -461,7 +461,7 @@ +--- ggml-metal.m.orig 2024-11-21 11:03:19 ++++ ggml-metal.m 2024-11-21 11:03:20 +@@ -463,7 +463,7 @@ const bool try_metallib = true; #endif diff --git a/scripts/llama.cpp.patch b/scripts/llama.cpp.patch index efed972..56e5908 100644 --- 
a/scripts/llama.cpp.patch +++ b/scripts/llama.cpp.patch @@ -1,5 +1,5 @@ ---- llama.cpp.orig 2024-11-02 12:42:13 -+++ llama.cpp 2024-11-02 13:00:37 +--- llama.cpp.orig 2024-11-21 11:03:19 ++++ llama.cpp 2024-11-21 11:03:20 @@ -80,6 +80,17 @@ #define LLAMA_MAX_LAYERS 512 #define LLAMA_MAX_EXPERTS 160 // DeepSeekV2 @@ -18,7 +18,7 @@ // // helpers // -@@ -1930,16 +1941,16 @@ +@@ -1951,16 +1962,16 @@ if (prefetch > 0) { // advise the kernel to preload the mapped memory From 96b6dd789e06fcb2942313181d8206464088d0ea Mon Sep 17 00:00:00 2001 From: Jhen-Jie Hong Date: Thu, 21 Nov 2024 11:48:58 +0800 Subject: [PATCH 2/2] fix(ios): add missing ggml-metal-impl.h --- cpp/ggml-metal-impl.h | 249 +++++++++++++++++++++++++++++++++++ example/ios/.xcode.env.local | 2 +- scripts/bootstrap.sh | 2 + 3 files changed, 252 insertions(+), 1 deletion(-) create mode 100644 cpp/ggml-metal-impl.h diff --git a/cpp/ggml-metal-impl.h b/cpp/ggml-metal-impl.h new file mode 100644 index 0000000..481e010 --- /dev/null +++ b/cpp/ggml-metal-impl.h @@ -0,0 +1,249 @@ +#ifndef LM_GGML_METAL_IMPL +#define LM_GGML_METAL_IMPL + +// kernel argument structs +// +// - element counters (e.g. ne00) typically use int32_t to reduce register usage +// however, be careful from int overflows when using those in the kernel implementation +// +// - strides (e.g. nb00) use uint64_t + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + int32_t dim; +} lm_ggml_metal_kargs_concat; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + uint64_t offs; +} lm_ggml_metal_kargs_bin; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} lm_ggml_metal_kargs_repeat; + +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int64_t ne0; + int64_t ne1; + int64_t ne2; + int64_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} lm_ggml_metal_kargs_cpy; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + int32_t n_past; + int32_t n_dims; + int32_t n_ctx_orig; + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; +} lm_ggml_metal_kargs_rope; + +typedef struct { + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne11; + int32_t ne_12_2; // assume K and V are same shape + int32_t ne_12_3; + 
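The kernel argument structs being added in ggml-metal-impl.h keep element counts (ne*) as int32_t and byte strides (nb*) as uint64_t, as the file's header comment explains. The snippet below sketches how the two families combine into a byte offset for a single element of a non-quantized tensor; it is a host-side illustration of ggml's ne/nb convention, not code from the Metal kernels, and the tensor shape used is hypothetical.

#include <cstdint>
#include <cstdio>

// Byte offset of element (i0, i1, i2, i3) given byte strides nb0..nb3.
// Counts stay 32-bit, but the multiply is done in 64 bits to avoid overflow.
static uint64_t element_offset(int32_t i0, int32_t i1, int32_t i2, int32_t i3,
                               uint64_t nb0, uint64_t nb1, uint64_t nb2, uint64_t nb3) {
    return (uint64_t)i0*nb0 + (uint64_t)i1*nb1 + (uint64_t)i2*nb2 + (uint64_t)i3*nb3;
}

int main() {
    // hypothetical contiguous f32 tensor of shape [8, 4, 2, 1]
    const uint64_t nb0 = 4, nb1 = 8*nb0, nb2 = 4*nb1, nb3 = 2*nb2;
    std::printf("offset of (1,2,1,0): %llu bytes\n",
                (unsigned long long) element_offset(1, 2, 1, 0, nb0, nb1, nb2, nb3));
    return 0;
}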
uint64_t nb_12_1; + uint64_t nb_12_2; + uint64_t nb_12_3; + uint64_t nb31; + int32_t ne1; + int32_t ne2; + float scale; + float max_bias; + float m0; + float m1; + uint16_t n_head_log2; + float logit_softcap; +} lm_ggml_metal_kargs_flash_attn_ext; + +typedef struct { + int32_t ne00; + int32_t ne02; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne12; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int16_t r2; + int16_t r3; +} lm_ggml_metal_kargs_mul_mm; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int16_t r2; + int16_t r3; +} lm_ggml_metal_kargs_mul_mv; + +typedef struct { + int32_t nei0; + int32_t nei1; + uint64_t nbi1; + int32_t ne00; + int32_t ne02; + uint64_t nb01; + uint64_t nb02; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + int32_t ne0; + int32_t ne1; +} lm_ggml_metal_kargs_mul_mm_id; + +typedef struct { + int32_t nei0; + int32_t nei1; + uint64_t nbi1; + int32_t ne00; + int32_t ne01; + int32_t ne02; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + int32_t ne0; + int32_t ne1; + uint64_t nb1; +} lm_ggml_metal_kargs_mul_mv_id; + +typedef struct { + int32_t ne00; + int32_t ne00_4; + uint64_t nb01; + float eps; +} lm_ggml_metal_kargs_norm; + +typedef struct { + int32_t ne00; + int32_t ne00_4; + uint64_t nb01; + float eps; +} lm_ggml_metal_kargs_rms_norm; + +#endif // LM_GGML_METAL_IMPL diff --git a/example/ios/.xcode.env.local b/example/ios/.xcode.env.local index c2273fa..b261bc2 100644 --- a/example/ios/.xcode.env.local +++ b/example/ios/.xcode.env.local @@ -1 +1 @@ -export NODE_BINARY=/var/folders/4z/1d45cfts3936kdm7v9jl349r0000gn/T/yarn--1732158453708-0.8549147503631913/node +export NODE_BINARY=/var/folders/4z/1d45cfts3936kdm7v9jl349r0000gn/T/yarn--1732160897686-0.7191373753799657/node diff --git a/scripts/bootstrap.sh b/scripts/bootstrap.sh index 1d194c4..d436965 100755 --- a/scripts/bootstrap.sh +++ b/scripts/bootstrap.sh @@ -14,6 +14,7 @@ cp ./llama.cpp/ggml/include/ggml-opt.h ./cpp/ggml-opt.h cp ./llama.cpp/ggml/include/ggml-metal.h ./cpp/ggml-metal.h cp ./llama.cpp/ggml/src/ggml-metal/ggml-metal.m ./cpp/ggml-metal.m +cp ./llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h ./cpp/ggml-metal-impl.h cp ./llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c ./cpp/ggml-cpu.c cp ./llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp ./cpp/ggml-cpu.cpp @@ -90,6 +91,7 @@ files_add_lm_prefix=( "./cpp/ggml-opt.cpp" "./cpp/ggml-metal.h" "./cpp/ggml-metal.m" + "./cpp/ggml-metal-impl.h" "./cpp/ggml-quants.h" "./cpp/ggml-quants.c" "./cpp/ggml-alloc.h"
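Referring back to lm_ggml_metal_kargs_rms_norm above (ne00, ne00_4, nb01, eps): eps is the usual RMS-norm stabilizer. The scalar reference below sketches the per-row computation the Metal kernel is expected to perform; it is a plain C++ illustration, not the kernel itself (ne00_4 suggests the kernel walks the row in 4-wide chunks, which this reference ignores), and it leaves out the separate weight multiply.

#include <cmath>
#include <cstdio>
#include <vector>

// Scalar reference for RMS normalization of one row of length ne00:
//   y[i] = x[i] / sqrt(mean(x^2) + eps)
static void rms_norm_row(const float * x, float * y, int ne00, float eps) {
    double sum = 0.0;
    for (int i = 0; i < ne00; ++i) {
        sum += (double) x[i] * x[i];
    }
    const float mean  = (float)(sum / ne00);
    const float scale = 1.0f / std::sqrt(mean + eps);
    for (int i = 0; i < ne00; ++i) {
        y[i] = x[i] * scale;
    }
}

int main() {
    std::vector<float> x = {1.0f, -2.0f, 3.0f, -4.0f};
    std::vector<float> y(x.size());
    rms_norm_row(x.data(), y.data(), (int) x.size(), 1e-5f);
    for (float v : y) std::printf("%f ", v);
    std::printf("\n");
    return 0;
}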