diff --git a/CMakeLists.txt b/CMakeLists.txt index 3a39871..693921c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -23,7 +23,18 @@ set(ENCODEC_LIB encodec) add_subdirectory(ggml) -add_library(${ENCODEC_LIB} STATIC encodec.cpp encodec.h) +add_library( + ${ENCODEC_LIB} STATIC + encodec.cpp + encodec.h + encoder.h + decoder.h + quantizer.h + ops.cpp + ops.h + utils.h + lstm.h +) if (ENCODEC_BUILD_EXAMPLES) add_subdirectory(examples) diff --git a/README.md b/README.md index 67052fe..d06657b 100644 --- a/README.md +++ b/README.md @@ -20,8 +20,8 @@ https://github.com/PABannier/encodec.cpp/assets/12958149/d11561be-98e9-4504-bba7 - [x] Support of 24Khz model - [x] Mixed F16 / F32 precision - [ ] 4-bit and 8-bit quantization -- [x] Metal support -- [x] cuBLAS support +- [ ] Metal support +- [ ] CoreML support ## Implementation details @@ -29,7 +29,6 @@ https://github.com/PABannier/encodec.cpp/assets/12958149/d11561be-98e9-4504-bba7 - The encoder-decoder architecture and the high-level C-style API are implemented in C++ ([encodec.h](encodec.h) / [encodec.cpp](encodec.cpp)) - Basic usage is demonstrated in [main.cpp](examples/main). - ## Usage Here are the steps for the encodec model. diff --git a/decoder.h b/decoder.h new file mode 100644 index 0000000..e2a61e4 --- /dev/null +++ b/decoder.h @@ -0,0 +1,113 @@ +#pragma once + +#include + +#include "ggml.h" +#include "ggml-alloc.h" +#include "ggml-backend.h" + +#include "lstm.h" +#include "utils.h" + + +struct encodec_decoder_block { + // upsampling layers + struct ggml_tensor *us_conv_w; + struct ggml_tensor *us_conv_b; + + // conv1 + struct ggml_tensor *conv_1_w; + struct ggml_tensor *conv_1_b; + + // conv2 + struct ggml_tensor *conv_2_w; + struct ggml_tensor *conv_2_b; + + // shortcut + struct ggml_tensor *conv_sc_w; + struct ggml_tensor *conv_sc_b; +}; + +struct encodec_decoder { + struct ggml_tensor *init_conv_w; + struct ggml_tensor *init_conv_b; + + encodec_lstm lstm; + + struct ggml_tensor *final_conv_w; + struct ggml_tensor *final_conv_b; + + std::vector blocks; +}; + +struct ggml_tensor *encodec_forward_decoder( + const struct encodec_decoder *decoder, struct ggml_allocr *allocr, struct ggml_context *ctx0, + struct ggml_tensor *quantized_out, const int *ratios, const int kernel_size, const int res_kernel_size, + const int stride) { + + if (!quantized_out) { + fprintf(stderr, "%s: null input tensor\n", __func__); + return NULL; + } + + struct ggml_tensor *inpL = strided_conv_1d( + ctx0, quantized_out, decoder->init_conv_w, decoder->init_conv_b, stride); + + // lstm + { + struct ggml_tensor *cur = inpL; + + const encodec_lstm lstm = decoder->lstm; + + // first lstm layer + struct ggml_tensor *hs1 = forward_pass_lstm_unilayer( + ctx0, allocr, cur, lstm.l0_ih_w, lstm.l0_hh_w, + lstm.l0_ih_b, lstm.l0_hh_b); + + // second lstm layer + struct ggml_tensor *out = forward_pass_lstm_unilayer( + ctx0, allocr, hs1, lstm.l1_ih_w, lstm.l1_hh_w, + lstm.l1_ih_b, lstm.l1_hh_b); + + inpL = ggml_add(ctx0, inpL, out); + } + + for (int layer_ix = 0; layer_ix < 4; layer_ix++) { + encodec_decoder_block block = decoder->blocks[layer_ix]; + + // upsampling layers + inpL = ggml_elu(ctx0, inpL); + + inpL = strided_conv_transpose_1d( + ctx0, inpL, block.us_conv_w, block.us_conv_b, ratios[layer_ix]); + + struct ggml_tensor *current = inpL; + + // shortcut + struct ggml_tensor *shortcut = strided_conv_1d( + ctx0, inpL, block.conv_sc_w, block.conv_sc_b, stride); + + // conv1 + current = ggml_elu(ctx0, current); + + current = strided_conv_1d( + ctx0, current, block.conv_1_w, block.conv_1_b, stride); + + // conv2 + current = ggml_elu(ctx0, current); + + current = strided_conv_1d( + ctx0, current, block.conv_2_w, block.conv_2_b, stride); + + // residual connection + inpL = ggml_add(ctx0, current, shortcut); + } + + // final conv + inpL = ggml_elu(ctx0, inpL); + + struct ggml_tensor *decoded_inp = strided_conv_1d( + ctx0, inpL, decoder->final_conv_w, decoder->final_conv_b, stride); + + return decoded_inp; +} diff --git a/encodec.cpp b/encodec.cpp index d332c39..a377d3e 100644 --- a/encodec.cpp +++ b/encodec.cpp @@ -24,12 +24,23 @@ #include "encodec.h" -#define MAX(a, b) ((a) > (b) ? (a) : (b)) -#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#include "decoder.h" +#include "encoder.h" +#include "lstm.h" +#include "ops.h" +#include "utils.h" +#include "quantizer.h" #define ENCODEC_FILE_MAGIC 'ggml' -static const size_t MB = 1024 * 1024; +typedef enum { + // Run the end-to-end encoder-decoder pipeline + FULL = 0, + // Encode an audio (encoder + quantizer encode) + ENCODE = 1, + // Decode an audio from a compressed representation (quantizer decode + decoder) + DECODE = 2, +} encodec_run_mode_t; struct encodec_hparams { // The number of input channels is always 1 (mono). @@ -67,89 +78,6 @@ struct encodec_hparams { int32_t ftype; }; -// res + downsample block at some ratio -struct encodec_encoder_block { - // conv1 - struct ggml_tensor *conv_1_w; - struct ggml_tensor *conv_1_b; - - // conv2 - struct ggml_tensor *conv_2_w; - struct ggml_tensor *conv_2_b; - - // shortcut - struct ggml_tensor *conv_sc_w; - struct ggml_tensor *conv_sc_b; - - // downsampling layers - struct ggml_tensor *ds_conv_w; - struct ggml_tensor *ds_conv_b; -}; - -struct encodec_lstm { - struct ggml_tensor *l0_ih_w; - struct ggml_tensor *l0_hh_w; - - struct ggml_tensor *l0_ih_b; - struct ggml_tensor *l0_hh_b; - - struct ggml_tensor *l1_ih_w; - struct ggml_tensor *l1_hh_w; - - struct ggml_tensor *l1_ih_b; - struct ggml_tensor *l1_hh_b; -}; - -struct encodec_encoder { - struct ggml_tensor *init_conv_w; - struct ggml_tensor *init_conv_b; - - encodec_lstm lstm; - - struct ggml_tensor *final_conv_w; - struct ggml_tensor *final_conv_b; - - std::vector blocks; -}; - -struct encodec_quant_block { - struct ggml_tensor *embed; -}; - -struct encodec_quantizer { - std::vector blocks; -}; - -struct encodec_decoder_block { - // upsampling layers - struct ggml_tensor *us_conv_w; - struct ggml_tensor *us_conv_b; - - // conv1 - struct ggml_tensor *conv_1_w; - struct ggml_tensor *conv_1_b; - - // conv2 - struct ggml_tensor *conv_2_w; - struct ggml_tensor *conv_2_b; - - // shortcut - struct ggml_tensor *conv_sc_w; - struct ggml_tensor *conv_sc_b; -}; - -struct encodec_decoder { - struct ggml_tensor *init_conv_w; - struct ggml_tensor *init_conv_b; - - encodec_lstm lstm; - - struct ggml_tensor *final_conv_w; - struct ggml_tensor *final_conv_b; - - std::vector blocks; -}; - struct encodec_model { encodec_hparams hparams; @@ -189,236 +117,6 @@ struct encodec_context { encodec_statistics stats; }; -typedef enum { - // Run the end-to-end encoder-decoder pipeline - full = 0, - // Encode an audio (encoder + quantizer encode) - encode = 1, - // Decode an audio from a compressed representation (quantizer decode + decoder) - decode = 2, -} encodec_run_mode; - -template -static void read_safe(std::ifstream &infile, T &dest) { - infile.read((char *)&dest, sizeof(T)); -} - -static void ggml_log_callback_default(ggml_log_level level, const char *text, void *user_data) { - (void)level; - (void)user_data; - fputs(text, stderr); - fflush(stderr); -} - -static void encodec_sigmoid_impl( - struct ggml_tensor *dst, - const struct ggml_tensor *src, - int ith, - int nth, - void *userdata) { - GGML_ASSERT(userdata == NULL); - GGML_ASSERT(ggml_are_same_shape(dst, src)); - GGML_ASSERT(ggml_is_contiguous(dst)); - GGML_ASSERT(ggml_is_contiguous(src)); - - const float *src_data = ggml_get_data_f32(src); - float *dst_data = ggml_get_data_f32(dst); - - const int ne = (int)ggml_nelements(dst); - const int dr = (ne + nth - 1) / nth; - const int ie0 = dr * ith; - const int ie1 = std::min(ie0 + dr, ne); - - for (int i = ie0; i < ie1; ++i) { - dst_data[i] = 1.0f / (1.0f + expf(-src_data[i])); - } -} - -static struct ggml_tensor *encodec_sigmoid( - struct ggml_context *ctx, - struct ggml_tensor *x) { - return ggml_map_custom1(ctx, x, encodec_sigmoid_impl, GGML_N_TASKS_MAX, NULL); -} - -static int get_extra_padding_for_conv_1d( - struct ggml_tensor *inp, - float kernel_size, - float stride, - float padding_total) { - float length = inp->ne[0]; - float n_frames = (length - kernel_size + padding_total) / stride + 1.0f; - int ideal_length = (ceilf(n_frames) - 1) * stride + (kernel_size - padding_total); - return ideal_length - length; -} - -static struct ggml_tensor *pad_1d( - struct ggml_context *ctx0, - struct ggml_tensor *inp, - int padding_left, - int padding_right) { - int length = inp->ne[0]; - int dim = inp->ne[1]; - - const int max_pad = std::max(padding_left, padding_right); - int extra_pad = 0; - - if (length <= max_pad) { - extra_pad = max_pad - length + 1; - - // constant padding - struct ggml_tensor *out = ggml_new_tensor_2d(ctx0, inp->type, length + extra_pad, dim); - ggml_set_zero(out); - out = ggml_set_2d(ctx0, out, inp, out->nb[1], 0); - } - - struct ggml_tensor *padded = ggml_pad_reflec_1d(ctx0, inp, padding_left, padding_right); - - const int end = padded->ne[0] - extra_pad; - struct ggml_tensor *dest = ggml_view_2d(ctx0, padded, end, dim, padded->nb[1], 0); - - return dest; -} - -static int32_t get_num_codebooks(float bandwidth, int hop_length, float sample_rate) { - // The number of codebooks is determined by the bandwidth selected. - // Supported bandwidths are 1.5kbps (n_q = 2), 3 kbps (n_q = 4), 6 kbps (n_q = 8), - // 12 kbps (n_q = 16) and 24kbps (n_q = 32). - return (int32_t)ceilf(1000 * bandwidth / (ceilf(sample_rate / hop_length) * 10)); -} - -static int32_t get_bandwidth_per_quantizer(int bins, float frame_rate) { - return log2f((float)bins) * frame_rate; -} - -static int32_t get_num_quantizers_for_bandwidth(int bins, float frame_rate, float bandwidth) { - float bw_per_q = get_bandwidth_per_quantizer(bins, frame_rate); - int32_t n_q = MAX(1, floorf(bandwidth * 1000 / bw_per_q)); - return n_q; -} - -static struct ggml_tensor *unpad_1d( - struct ggml_context *ctx0, - struct ggml_tensor *inp, - int padding_left, - int padding_right) { - int length = inp->ne[0]; - int dim = inp->ne[1]; - - assert(padding_left >= 0); - assert(padding_right >= 0); - assert(padding_left + padding_right <= length); - - int end = length - padding_right; - - int offset = padding_left * inp->nb[1]; - struct ggml_tensor *dst = ggml_view_2d(ctx0, inp, end, dim, inp->nb[1], offset); - - return dst; -} - -static struct ggml_tensor *strided_conv_1d( - ggml_context *ctx0, - ggml_tensor *inp, - ggml_tensor *conv_w, - ggml_tensor *conv_b, - int stride) { - int kernel_size = conv_w->ne[0]; - int padding_total = kernel_size - stride; - int extra_padding = get_extra_padding_for_conv_1d(inp, kernel_size, stride, padding_total); - - struct ggml_tensor *padded_inp = pad_1d(ctx0, inp, padding_total, extra_padding); - struct ggml_tensor *dst = ggml_conv_1d(ctx0, conv_w, padded_inp, stride, 0, 1); - - // add bias - dst = ggml_transpose(ctx0, dst); - dst = ggml_add(ctx0, ggml_repeat(ctx0, conv_b, dst), dst); - dst = ggml_cont(ctx0, ggml_transpose(ctx0, dst)); - - return dst; -} - -static struct ggml_tensor *strided_conv_transpose_1d( - struct ggml_context *ctx0, - struct ggml_tensor *inp, - struct ggml_tensor *conv_w, - struct ggml_tensor *conv_b, - int stride) { - struct ggml_tensor *dst = ggml_conv_transpose_1d( - ctx0, conv_w, inp, stride, 0 /* p0 */, 1 /* d0 */); - - // add bias - dst = ggml_transpose(ctx0, dst); - dst = ggml_add(ctx0, ggml_repeat(ctx0, conv_b, dst), dst); - dst = ggml_cont(ctx0, ggml_transpose(ctx0, dst)); - - int kernel_size = conv_w->ne[0]; - int padding_total = kernel_size - stride; - - int padding_right = ceilf(padding_total); - int padding_left = padding_total - padding_right; - - struct ggml_tensor *unpadded = unpad_1d(ctx0, dst, padding_left, padding_right); - unpadded = ggml_cont(ctx0, unpadded); - - return unpadded; -} - -static struct ggml_tensor *forward_pass_lstm_unilayer( - struct ggml_context *ctx0, - struct ggml_allocr *allocr, - struct ggml_tensor *inp, - struct ggml_tensor *weight_ih, - struct ggml_tensor *weight_hh, - struct ggml_tensor *bias_ih, - struct ggml_tensor *bias_hh) { - const int input_dim = inp->ne[1]; - const int hidden_dim = weight_ih->ne[1] / 4; - const int seq_length = inp->ne[0]; - - struct ggml_tensor *hs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hidden_dim, seq_length); - ggml_allocr_alloc(allocr, hs); - - struct ggml_tensor *c_t = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hidden_dim); - ggml_allocr_alloc(allocr, c_t); - - struct ggml_tensor *h_t = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hidden_dim); - ggml_allocr_alloc(allocr, h_t); - - if (!ggml_allocr_is_measure(allocr)) { - h_t = ggml_set_zero(h_t); - c_t = ggml_set_zero(c_t); - } - - struct ggml_tensor *current = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); - - for (int t = 0; t < seq_length; t++) { - struct ggml_tensor *x_t = ggml_view_1d(ctx0, current, input_dim, t * current->nb[1]); - - struct ggml_tensor *inp_gates = ggml_mul_mat(ctx0, weight_ih, x_t); - inp_gates = ggml_add(ctx0, inp_gates, bias_ih); - - struct ggml_tensor *hid_gates = ggml_mul_mat(ctx0, weight_hh, h_t); - hid_gates = ggml_add(ctx0, hid_gates, bias_hh); - - struct ggml_tensor *out_gates = ggml_add(ctx0, inp_gates, hid_gates); - - struct ggml_tensor *i_t = encodec_sigmoid(ctx0, ggml_view_1d(ctx0, out_gates, hidden_dim, 0 * sizeof(float) * hidden_dim)); - struct ggml_tensor *f_t = encodec_sigmoid(ctx0, ggml_view_1d(ctx0, out_gates, hidden_dim, 1 * sizeof(float) * hidden_dim)); - struct ggml_tensor *g_t = ggml_tanh(ctx0, ggml_view_1d(ctx0, out_gates, hidden_dim, 2 * sizeof(float) * hidden_dim)); - struct ggml_tensor *o_t = encodec_sigmoid(ctx0, ggml_view_1d(ctx0, out_gates, hidden_dim, 3 * sizeof(float) * hidden_dim)); - - c_t = ggml_add(ctx0, ggml_mul(ctx0, f_t, c_t), ggml_mul(ctx0, i_t, g_t)); - - h_t = ggml_mul(ctx0, o_t, ggml_tanh(ctx0, c_t)); - - hs = ggml_set_1d(ctx0, hs, h_t, t * hs->nb[1]); - } - - hs = ggml_cont(ctx0, ggml_transpose(ctx0, hs)); - - return hs; -} - bool encodec_load_model_weights(std::ifstream &infile, encodec_model &model, int n_gpu_layers) { // verify magic (i.e. ggml signature in hex format) { @@ -847,8 +545,6 @@ bool encodec_load_model_weights(std::ifstream &infile, encodec_model &model, int ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor)); } - // printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0); - total_size += ggml_nbytes(tensor); model.n_loaded++; } @@ -862,301 +558,26 @@ bool encodec_load_model_weights(std::ifstream &infile, encodec_model &model, int return true; } -struct ggml_tensor *encodec_forward_encoder( - struct encodec_context *ectx, - struct ggml_context *ctx0, - struct ggml_tensor *inp) { - if (!inp) { - fprintf(stderr, "%s: null input tensor\n", __func__); - return NULL; - } - - const auto &model = ectx->model; - const auto &hparams = model.hparams; - const auto &allocr = ectx->allocr; - - const int *ratios = hparams.ratios; - const int kernel_size = hparams.kernel_size; - const int res_kernel_sz = hparams.residual_kernel_size; - const int stride = hparams.stride; - - struct ggml_tensor *inpL = strided_conv_1d( - ctx0, inp, model.encoder.init_conv_w, model.encoder.init_conv_b, stride); - - for (int layer_ix = 0; layer_ix < 4; layer_ix++) { - encodec_encoder_block block = model.encoder.blocks[layer_ix]; - - struct ggml_tensor *current = inpL; - - // shortcut - struct ggml_tensor *shortcut = strided_conv_1d( - ctx0, inpL, block.conv_sc_w, block.conv_sc_b, stride); - - // conv1 - current = ggml_elu(ctx0, current); - - current = strided_conv_1d( - ctx0, current, block.conv_1_w, block.conv_1_b, stride); - - // conv2 - current = ggml_elu(ctx0, current); - - current = strided_conv_1d( - ctx0, current, block.conv_2_w, block.conv_2_b, stride); - - // residual connection - inpL = ggml_add(ctx0, current, shortcut); - - // downsampling layers - inpL = ggml_elu(ctx0, inpL); - - inpL = strided_conv_1d( - ctx0, inpL, block.ds_conv_w, block.ds_conv_b, ratios[3 - layer_ix]); - } - - // lstm - { - struct ggml_tensor *cur = inpL; - - const encodec_lstm lstm = model.encoder.lstm; - - // first lstm layer - struct ggml_tensor *hs1 = forward_pass_lstm_unilayer( - ctx0, allocr, cur, lstm.l0_ih_w, lstm.l0_hh_w, - lstm.l0_ih_b, lstm.l0_hh_b); - - // second lstm layer - struct ggml_tensor *out = forward_pass_lstm_unilayer( - ctx0, allocr, hs1, lstm.l1_ih_w, lstm.l1_hh_w, - lstm.l1_ih_b, lstm.l1_hh_b); - - inpL = ggml_add(ctx0, inpL, out); - } - - // final conv - inpL = ggml_elu(ctx0, inpL); - - struct ggml_tensor *encoded_inp = strided_conv_1d( - ctx0, inpL, model.encoder.final_conv_w, model.encoder.final_conv_b, stride); - - return encoded_inp; -} - -struct ggml_tensor *encodec_forward_quantizer_encode( - struct encodec_context *ectx, - struct ggml_context *ctx0, - struct ggml_tensor *encoded_inp) { - if (!encoded_inp) { - fprintf(stderr, "%s: null input tensor\n", __func__); - return NULL; - } - - const auto &model = ectx->model; - const auto &hparams = model.hparams; - const auto &allocr = ectx->allocr; - - const int n_bins = hparams.n_bins; - const int sr = hparams.sr; - const int bandwidth = hparams.bandwidth; - const int hop_length = hparams.hop_length; - - const int frame_rate = (int)ceilf(sr / hop_length); - const int n_q = get_num_quantizers_for_bandwidth(n_bins, frame_rate, bandwidth); - - const int seq_length = encoded_inp->ne[0]; - - struct ggml_tensor *codes = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, seq_length, n_q); - ggml_allocr_alloc(allocr, codes); - - struct ggml_tensor *dist_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - ggml_allocr_alloc(allocr, dist_scale); - - if (!ggml_allocr_is_measure(allocr)) { - float s = -2.0f; - ggml_backend_tensor_set(dist_scale, &s, 0, sizeof(s)); - } - - struct ggml_tensor *inpL = ggml_cont(ctx0, ggml_transpose(ctx0, encoded_inp)); - struct ggml_tensor *residual = inpL; - struct ggml_tensor *indices; - - for (int i = 0; i < n_q; i++) { - encodec_quant_block block = model.quantizer.blocks[i]; - - // compute distance - // [seq_length, n_bins] - struct ggml_tensor *dp = ggml_scale( - ctx0, ggml_mul_mat(ctx0, block.embed, residual), dist_scale); - - // [n_bins] - struct ggml_tensor *sqr_embed = ggml_sqr(ctx0, block.embed); - struct ggml_tensor *sqr_embed_nrm = ggml_sum_rows(ctx0, sqr_embed); - - // [seq_length] - struct ggml_tensor *sqr_inp = ggml_sqr(ctx0, residual); - struct ggml_tensor *sqr_inp_nrm = ggml_sum_rows(ctx0, sqr_inp); - - // [seq_length, n_bins] - struct ggml_tensor *dist = ggml_add(ctx0, ggml_repeat(ctx0, sqr_inp_nrm, dp), dp); - dist = ggml_add(ctx0, ggml_repeat(ctx0, ggml_transpose(ctx0, sqr_embed_nrm), dist), dist); - dist = ggml_neg(ctx0, dist); - - // take the argmax over the column dimension - // [seq_length] - indices = ggml_argmax(ctx0, dist); - - // look up in embedding table - struct ggml_tensor *quantized = ggml_get_rows(ctx0, block.embed, indices); - - residual = ggml_sub(ctx0, residual, quantized); - - codes = ggml_set_1d(ctx0, codes, indices, i * codes->nb[1]); - } - - return codes; -} - -struct ggml_tensor *encodec_forward_quantizer_decode( - struct encodec_context *ectx, - struct ggml_context *ctx0, - struct ggml_tensor *codes) { - if (!codes) { - fprintf(stderr, "%s: null input tensor\n", __func__); - return NULL; - } - - const auto &model = ectx->model; - const auto &hparams = model.hparams; - const auto &allocr = ectx->allocr; - - const int hidden_dim = hparams.hidden_dim; - const int seq_length = codes->ne[0]; +struct ggml_cgraph *encodec_build_graph(struct encodec_context *ectx, + const float * inp_audio, + const int n_samples, + const encodec_run_mode_t mode) { + assert(mode == encodec_run_mode_t::FULL || mode == encodec_run_mode_t::ENCODE); - const int n_bins = hparams.n_bins; - const int sr = hparams.sr; - const int bandwidth = hparams.bandwidth; - const int hop_length = hparams.hop_length; + const auto & model = ectx->model; + const auto & hparams = model.hparams; + const auto & allocr = ectx->allocr; - const int frame_rate = (int)ceilf(sr / hop_length); - const int n_q = get_num_quantizers_for_bandwidth(n_bins, frame_rate, bandwidth); - - assert(n_q == codes->ne[1]); - - struct ggml_tensor *quantized_out = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hidden_dim, seq_length); - ggml_allocr_alloc(allocr, quantized_out); - - if (!ggml_allocr_is_measure(allocr)) { - quantized_out = ggml_set_zero(quantized_out); - } - - for (int i = 0; i < n_q; i++) { - encodec_quant_block block = model.quantizer.blocks[i]; - - struct ggml_tensor *indices = ggml_view_1d(ctx0, codes, seq_length, i * codes->nb[1]); - struct ggml_tensor *quantized = ggml_get_rows(ctx0, block.embed, indices); - - quantized_out = ggml_add(ctx0, quantized_out, quantized); - } - - quantized_out = ggml_cont(ctx0, ggml_transpose(ctx0, quantized_out)); - - return quantized_out; -} - -struct ggml_tensor *encodec_forward_decoder( - struct encodec_context *ectx, - struct ggml_context *ctx0, - struct ggml_tensor *quantized_out) { - if (!quantized_out) { - fprintf(stderr, "%s: null input tensor\n", __func__); - return NULL; - } - - const auto &model = ectx->model; - const auto &hparams = model.hparams; - const auto &allocr = ectx->allocr; - - const int *ratios = hparams.ratios; - const int kernel_size = hparams.kernel_size; + const int *ratios = hparams.ratios; + const int kernel_size = hparams.kernel_size; const int res_kernel_sz = hparams.residual_kernel_size; - const int stride = hparams.stride; - - struct ggml_tensor *inpL = strided_conv_1d( - ctx0, quantized_out, model.decoder.init_conv_w, - model.decoder.init_conv_b, stride); - - // lstm - { - struct ggml_tensor *cur = inpL; - - const encodec_lstm lstm = model.decoder.lstm; - - // first lstm layer - struct ggml_tensor *hs1 = forward_pass_lstm_unilayer( - ctx0, allocr, cur, lstm.l0_ih_w, lstm.l0_hh_w, - lstm.l0_ih_b, lstm.l0_hh_b); - - // second lstm layer - struct ggml_tensor *out = forward_pass_lstm_unilayer( - ctx0, allocr, hs1, lstm.l1_ih_w, lstm.l1_hh_w, - lstm.l1_ih_b, lstm.l1_hh_b); - - inpL = ggml_add(ctx0, inpL, out); - } - - for (int layer_ix = 0; layer_ix < 4; layer_ix++) { - encodec_decoder_block block = model.decoder.blocks[layer_ix]; - - // upsampling layers - inpL = ggml_elu(ctx0, inpL); - - inpL = strided_conv_transpose_1d( - ctx0, inpL, block.us_conv_w, block.us_conv_b, ratios[layer_ix]); - - struct ggml_tensor *current = inpL; - - // shortcut - struct ggml_tensor *shortcut = strided_conv_1d( - ctx0, inpL, block.conv_sc_w, block.conv_sc_b, stride); - - // conv1 - current = ggml_elu(ctx0, current); - - current = strided_conv_1d( - ctx0, current, block.conv_1_w, block.conv_1_b, stride); - - // conv2 - current = ggml_elu(ctx0, current); - - current = strided_conv_1d( - ctx0, current, block.conv_2_w, block.conv_2_b, stride); - - // residual connection - inpL = ggml_add(ctx0, current, shortcut); - } - - // final conv - inpL = ggml_elu(ctx0, inpL); - - struct ggml_tensor *decoded_inp = strided_conv_1d( - ctx0, inpL, model.decoder.final_conv_w, - model.decoder.final_conv_b, stride); - - return decoded_inp; -} - -struct ggml_cgraph *encodec_build_graph( - struct encodec_context *ectx, - const float * inp_audio, - const int n_samples, - const encodec_run_mode mode) { - assert(mode == encodec_run_mode::full || mode == encodec_run_mode::encode); - - const auto &model = ectx->model; - const auto &hparams = model.hparams; - const auto &allocr = ectx->allocr; - - const int n_q = hparams.n_q; + const int stride = hparams.stride; + const int n_bins = hparams.n_bins; + const int n_q = hparams.n_q; + const int sr = hparams.sr; + const int bandwidth = hparams.bandwidth; + const int hop_length = hparams.hop_length; + const int hidden_dim = hparams.hidden_dim; // since we are using ggml-alloc, this buffer only needs enough space to hold the // ggml_tensor and ggml_cgraph structs, but not the tensor data @@ -1181,19 +602,34 @@ struct ggml_cgraph *encodec_build_graph( ggml_backend_tensor_set(inp, inp_audio, 0, n_samples * ggml_element_size(inp)); } - struct ggml_tensor *encoded = encodec_forward_encoder(ectx, ctx0, inp); - struct ggml_tensor *codes = encodec_forward_quantizer_encode(ectx, ctx0, encoded); - struct ggml_tensor *quantized = encodec_forward_quantizer_decode(ectx, ctx0, codes); - struct ggml_tensor *decoded = encodec_forward_decoder(ectx, ctx0, quantized); + const struct encodec_encoder *encoder = &model.encoder; + const struct encodec_quantizer *quantizer = &model.quantizer; + const struct encodec_decoder *decoder = &model.decoder; + + struct ggml_tensor * encoded = encodec_forward_encoder( + encoder, allocr, ctx0, inp, ratios, kernel_size, res_kernel_sz, stride + ); + + struct ggml_tensor * codes = encodec_forward_quantizer_encode( + quantizer, allocr, ctx0, encoded, n_bins, sr, bandwidth, hop_length + ); + + struct ggml_tensor * quantized = encodec_forward_quantizer_decode( + quantizer, allocr, ctx0, codes, hidden_dim, n_bins, sr, bandwidth, hop_length + ); + + struct ggml_tensor * decoded = encodec_forward_decoder( + decoder, allocr, ctx0, quantized, ratios, kernel_size, res_kernel_sz, stride + ); switch (mode) { - case encodec_run_mode::full: { + case encodec_run_mode_t::FULL: { ggml_build_forward_expand(gf, decoded); } break; - case encodec_run_mode::encode: { + case encodec_run_mode_t::ENCODE: { ggml_build_forward_expand(gf, codes); } break; - case encodec_run_mode::decode: { + case encodec_run_mode_t::DECODE: { return NULL; } break; default: { @@ -1205,27 +641,29 @@ struct ggml_cgraph *encodec_build_graph( ggml_free(ctx0); ectx->encoded = encoded; - ectx->codes = codes; + ectx->codes = codes; ectx->decoded = decoded; return gf; } -struct ggml_cgraph *encodec_build_graph( - struct encodec_context *ectx, - const int32_t * codes, - const int n_codes, - const encodec_run_mode mode) { - assert(mode == encodec_run_mode::decode); - - const auto &model = ectx->model; - const auto &hparams = model.hparams; - const auto &allocr = ectx->allocr; - - const int n_bins = hparams.n_bins; - const int sr = hparams.sr; - const int bandwidth = hparams.bandwidth; - const int hop_length = hparams.hop_length; +struct ggml_cgraph *encodec_build_graph(struct encodec_context *ectx, const int32_t *codes, + const int n_codes, const encodec_run_mode_t mode) { + assert(mode == encodec_run_mode_t::DECODE); + + const auto & model = ectx->model; + const auto & hparams = model.hparams; + const auto & allocr = ectx->allocr; + + const int n_bins = hparams.n_bins; + const int sr = hparams.sr; + const int bandwidth = hparams.bandwidth; + const int hop_length = hparams.hop_length; + const int hidden_dim = hparams.hidden_dim; + const int * ratios = hparams.ratios; + const int kernel_size = hparams.kernel_size; + const int res_kernel_sz = hparams.residual_kernel_size; + const int stride = hparams.stride; const int frame_rate = (int)ceilf(sr / hop_length); const int n_q = get_num_quantizers_for_bandwidth(n_bins, frame_rate, bandwidth); @@ -1260,11 +698,19 @@ struct ggml_cgraph *encodec_build_graph( ggml_backend_tensor_set(inp_codes, codes, 0, N * n_q * ggml_element_size(inp_codes)); } - struct ggml_tensor *quantized = encodec_forward_quantizer_decode(ectx, ctx0, inp_codes); - struct ggml_tensor *decoded = encodec_forward_decoder(ectx, ctx0, quantized); + const struct encodec_quantizer *quantizer = &model.quantizer; + const struct encodec_decoder *decoder = &model.decoder; + + struct ggml_tensor *quantized = encodec_forward_quantizer_decode( + quantizer, allocr, ctx0, inp_codes, hidden_dim, n_bins, sr, bandwidth, hop_length + ); + + struct ggml_tensor *decoded = encodec_forward_decoder( + decoder, allocr, ctx0, quantized, ratios, kernel_size, res_kernel_sz, stride + ); switch (mode) { - case encodec_run_mode::decode: { + case encodec_run_mode_t::DECODE: { ggml_build_forward_expand(gf, decoded); } break; default: { @@ -1275,20 +721,17 @@ struct ggml_cgraph *encodec_build_graph( ggml_free(ctx0); - ectx->codes = inp_codes; + ectx->codes = inp_codes; ectx->decoded = decoded; return gf; } -bool encodec_eval_internal( - struct encodec_context *ectx, - const float * raw_audio, - const int n_samples, - const int n_threads, - const encodec_run_mode mode) { - auto &model = ectx->model; - auto &allocr = ectx->allocr; +bool encodec_eval_internal(struct encodec_context *ectx, const float * raw_audio, + const int n_samples, const int n_threads, + const encodec_run_mode_t mode) { + auto & model = ectx->model; + auto & allocr = ectx->allocr; // reset the allocator to free all the memory allocated during the previous inference ggml_allocr_reset(allocr); @@ -1312,14 +755,11 @@ bool encodec_eval_internal( return true; } -bool encodec_eval_internal( - struct encodec_context *ectx, - const int32_t * codes, - const int n_codes, - const int n_threads, - const encodec_run_mode mode) { - auto &model = ectx->model; - auto &allocr = ectx->allocr; +bool encodec_eval_internal(struct encodec_context *ectx, const int32_t *codes, + const int n_codes, const int n_threads, + const encodec_run_mode_t mode) { + auto & model = ectx->model; + auto & allocr = ectx->allocr; // reset the allocator to free all the memory allocated during the previous inference ggml_allocr_reset(allocr); @@ -1343,12 +783,9 @@ bool encodec_eval_internal( return true; } -bool encodec_eval( - struct encodec_context *ectx, - const float *raw_audio, - const int n_samples, - const int n_threads, - const encodec_run_mode mode) { +bool encodec_eval(struct encodec_context *ectx, const float *raw_audio, + const int n_samples, const int n_threads, + const encodec_run_mode_t mode) { const int64_t t_start_us = ggml_time_us(); // allocate the compute buffer @@ -1382,12 +819,9 @@ bool encodec_eval( return true; } -bool encodec_eval( - struct encodec_context *ectx, - const int32_t *codes, - const int n_codes, - const int n_threads, - const encodec_run_mode mode) { +bool encodec_eval(struct encodec_context *ectx, const int32_t *codes, + const int n_codes, const int n_threads, + const encodec_run_mode_t mode) { const int64_t t_start_ms = ggml_time_us(); // allocate the compute buffer @@ -1421,17 +855,14 @@ bool encodec_eval( return true; } -bool encodec_reconstruct_audio( - struct encodec_context *ectx, - const float *raw_audio, - const int n_samples, - int n_threads) { +bool encodec_reconstruct_audio(struct encodec_context *ectx, const float *raw_audio, + const int n_samples, const int n_threads) { if (raw_audio == nullptr) { fprintf(stderr, "%s: null input audio\n", __func__); return false; } - if (!encodec_eval(ectx, raw_audio, n_samples, n_threads, encodec_run_mode::full)) { + if (!encodec_eval(ectx, raw_audio, n_samples, n_threads, encodec_run_mode_t::FULL)) { fprintf(stderr, "%s: failed to run encodec eval\n", __func__); return false; } @@ -1453,12 +884,9 @@ bool encodec_reconstruct_audio( return true; } -bool encodec_compress_audio( - struct encodec_context *ectx, - const float * raw_audio, - const int n_samples, - int n_threads) { - if (!encodec_eval(ectx, raw_audio, n_samples, n_threads, encodec_run_mode::encode)) { +bool encodec_compress_audio(struct encodec_context *ectx, const float *raw_audio, + const int n_samples, const int n_threads) { + if (!encodec_eval(ectx, raw_audio, n_samples, n_threads, encodec_run_mode_t::ENCODE)) { fprintf(stderr, "%s: failed to run encodec eval\n", __func__); return false; } @@ -1480,12 +908,9 @@ bool encodec_compress_audio( return true; } -bool encodec_decompress_audio( - struct encodec_context *ectx, - const int32_t * codes, - const int n_codes, - int n_threads) { - if (!encodec_eval(ectx, codes, n_codes, n_threads, encodec_run_mode::decode)) { +bool encodec_decompress_audio(struct encodec_context *ectx, const int32_t *codes, + const int n_codes, const int n_threads) { + if (!encodec_eval(ectx, codes, n_codes, n_threads, encodec_run_mode_t::DECODE)) { fprintf(stderr, "%s: failed to run encodec eval\n", __func__); return false; } diff --git a/encoder.h b/encoder.h new file mode 100644 index 0000000..12e4413 --- /dev/null +++ b/encoder.h @@ -0,0 +1,109 @@ +#pragma once + +#include + +#include "ggml.h" +#include "lstm.h" + +// res + downsample block at some ratio +struct encodec_encoder_block { + // conv1 + struct ggml_tensor *conv_1_w; + struct ggml_tensor *conv_1_b; + + // conv2 + struct ggml_tensor *conv_2_w; + struct ggml_tensor *conv_2_b; + + // shortcut + struct ggml_tensor *conv_sc_w; + struct ggml_tensor *conv_sc_b; + + // downsampling layers + struct ggml_tensor *ds_conv_w; + struct ggml_tensor *ds_conv_b; +}; + +struct encodec_encoder { + struct ggml_tensor *init_conv_w; + struct ggml_tensor *init_conv_b; + + encodec_lstm lstm; + + struct ggml_tensor *final_conv_w; + struct ggml_tensor *final_conv_b; + + std::vector blocks; +}; + +struct ggml_tensor *encodec_forward_encoder( + const struct encodec_encoder *encoder, struct ggml_allocr *allocr, struct ggml_context *ctx0, + struct ggml_tensor *inp, const int * ratios, const int kernel_size, const int res_kernel_size, + const int stride) { + + if (!inp) { + fprintf(stderr, "%s: null input tensor\n", __func__); + return NULL; + } + + struct ggml_tensor *inpL = strided_conv_1d( + ctx0, inp, encoder->init_conv_w, encoder->init_conv_b, stride); + + for (int layer_ix = 0; layer_ix < 4; layer_ix++) { + encodec_encoder_block block = encoder->blocks[layer_ix]; + + struct ggml_tensor *current = inpL; + + // shortcut + struct ggml_tensor *shortcut = strided_conv_1d( + ctx0, inpL, block.conv_sc_w, block.conv_sc_b, stride); + + // conv1 + current = ggml_elu(ctx0, current); + + current = strided_conv_1d( + ctx0, current, block.conv_1_w, block.conv_1_b, stride); + + // conv2 + current = ggml_elu(ctx0, current); + + current = strided_conv_1d( + ctx0, current, block.conv_2_w, block.conv_2_b, stride); + + // residual connection + inpL = ggml_add(ctx0, current, shortcut); + + // downsampling layers + inpL = ggml_elu(ctx0, inpL); + + inpL = strided_conv_1d( + ctx0, inpL, block.ds_conv_w, block.ds_conv_b, ratios[3 - layer_ix]); + } + + // lstm + { + struct ggml_tensor *cur = inpL; + + const encodec_lstm lstm = encoder->lstm; + + // first lstm layer + struct ggml_tensor *hs1 = forward_pass_lstm_unilayer( + ctx0, allocr, cur, lstm.l0_ih_w, lstm.l0_hh_w, + lstm.l0_ih_b, lstm.l0_hh_b); + + // second lstm layer + struct ggml_tensor *out = forward_pass_lstm_unilayer( + ctx0, allocr, hs1, lstm.l1_ih_w, lstm.l1_hh_w, + lstm.l1_ih_b, lstm.l1_hh_b); + + inpL = ggml_add(ctx0, inpL, out); + } + + // final conv + inpL = ggml_elu(ctx0, inpL); + + struct ggml_tensor *encoded_inp = strided_conv_1d( + ctx0, inpL, encoder->final_conv_w, encoder->final_conv_b, stride); + + return encoded_inp; +} diff --git a/lstm.h b/lstm.h new file mode 100644 index 0000000..31c251d --- /dev/null +++ b/lstm.h @@ -0,0 +1,75 @@ +#pragma once + +#include "ggml.h" +#include "ggml-alloc.h" + +#include "ops.h" + +struct encodec_lstm { + struct ggml_tensor *l0_ih_w; + struct ggml_tensor *l0_hh_w; + + struct ggml_tensor *l0_ih_b; + struct ggml_tensor *l0_hh_b; + + struct ggml_tensor *l1_ih_w; + struct ggml_tensor *l1_hh_w; + + struct ggml_tensor *l1_ih_b; + struct ggml_tensor *l1_hh_b; +}; + +struct ggml_tensor *forward_pass_lstm_unilayer(struct ggml_context *ctx0, + struct ggml_allocr *allocr, + struct ggml_tensor *inp, + struct ggml_tensor *weight_ih, + struct ggml_tensor *weight_hh, + struct ggml_tensor *bias_ih, + struct ggml_tensor *bias_hh) { + const int input_dim = inp->ne[1]; + const int hidden_dim = weight_ih->ne[1] / 4; + const int seq_length = inp->ne[0]; + + struct ggml_tensor *hs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hidden_dim, seq_length); + ggml_allocr_alloc(allocr, hs); + + struct ggml_tensor *c_t = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hidden_dim); + ggml_allocr_alloc(allocr, c_t); + + struct ggml_tensor *h_t = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hidden_dim); + ggml_allocr_alloc(allocr, h_t); + + if (!ggml_allocr_is_measure(allocr)) { + h_t = ggml_set_zero(h_t); + c_t = ggml_set_zero(c_t); + } + + struct ggml_tensor *current = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); + + for (int t = 0; t < seq_length; t++) { + struct ggml_tensor *x_t = ggml_view_1d(ctx0, current, input_dim, t * current->nb[1]); + + struct ggml_tensor *inp_gates = ggml_mul_mat(ctx0, weight_ih, x_t); + inp_gates = ggml_add(ctx0, inp_gates, bias_ih); + + struct ggml_tensor *hid_gates = ggml_mul_mat(ctx0, weight_hh, h_t); + hid_gates = ggml_add(ctx0, hid_gates, bias_hh); + + struct ggml_tensor *out_gates = ggml_add(ctx0, inp_gates, hid_gates); + + struct ggml_tensor *i_t = encodec_sigmoid(ctx0, ggml_view_1d(ctx0, out_gates, hidden_dim, 0 * sizeof(float) * hidden_dim)); + struct ggml_tensor *f_t = encodec_sigmoid(ctx0, ggml_view_1d(ctx0, out_gates, hidden_dim, 1 * sizeof(float) * hidden_dim)); + struct ggml_tensor *g_t = ggml_tanh(ctx0, ggml_view_1d(ctx0, out_gates, hidden_dim, 2 * sizeof(float) * hidden_dim)); + struct ggml_tensor *o_t = encodec_sigmoid(ctx0, ggml_view_1d(ctx0, out_gates, hidden_dim, 3 * sizeof(float) * hidden_dim)); + + c_t = ggml_add(ctx0, ggml_mul(ctx0, f_t, c_t), ggml_mul(ctx0, i_t, g_t)); + + h_t = ggml_mul(ctx0, o_t, ggml_tanh(ctx0, c_t)); + + hs = ggml_set_1d(ctx0, hs, h_t, t * hs->nb[1]); + } + + hs = ggml_cont(ctx0, ggml_transpose(ctx0, hs)); + + return hs; +} diff --git a/ops.cpp b/ops.cpp new file mode 100644 index 0000000..18c0acc --- /dev/null +++ b/ops.cpp @@ -0,0 +1,123 @@ +#include +#include +#include +#include + +#include "ggml.h" + +#include "ops.h" + +static void encodec_sigmoid_impl(struct ggml_tensor *dst, const struct ggml_tensor *src, + int ith, int nth, void *userdata) { + GGML_ASSERT(userdata == NULL); + GGML_ASSERT(ggml_are_same_shape(dst, src)); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_is_contiguous(src)); + + const float *src_data = ggml_get_data_f32(src); + float *dst_data = ggml_get_data_f32(dst); + + const int ne = (int)ggml_nelements(dst); + const int dr = (ne + nth - 1) / nth; + const int ie0 = dr * ith; + const int ie1 = std::min(ie0 + dr, ne); + + for (int i = ie0; i < ie1; ++i) { + dst_data[i] = 1.0f / (1.0f + expf(-src_data[i])); + } +} + +static int get_extra_padding_for_conv_1d(struct ggml_tensor *inp, float kernel_size, + float stride, float padding_total) { + float length = inp->ne[0]; + float n_frames = (length - kernel_size + padding_total) / stride + 1.0f; + int ideal_length = (ceilf(n_frames) - 1) * stride + (kernel_size - padding_total); + return ideal_length - length; +} + +struct ggml_tensor *encodec_sigmoid(struct ggml_context *ctx, struct ggml_tensor *x) { + return ggml_map_custom1(ctx, x, encodec_sigmoid_impl, GGML_N_TASKS_MAX, NULL); +} + +struct ggml_tensor *pad_1d(struct ggml_context *ctx0, struct ggml_tensor *inp, + int padding_left, int padding_right) { + int length = inp->ne[0]; + int dim = inp->ne[1]; + + const int max_pad = std::max(padding_left, padding_right); + int extra_pad = 0; + + if (length <= max_pad) { + extra_pad = max_pad - length + 1; + + // constant padding + struct ggml_tensor *out = ggml_new_tensor_2d(ctx0, inp->type, length + extra_pad, dim); + ggml_set_zero(out); + out = ggml_set_2d(ctx0, out, inp, out->nb[1], 0); + } + + struct ggml_tensor *padded = ggml_pad_reflec_1d(ctx0, inp, padding_left, padding_right); + + const int end = padded->ne[0] - extra_pad; + struct ggml_tensor *dest = ggml_view_2d(ctx0, padded, end, dim, padded->nb[1], 0); + + return dest; +} + +struct ggml_tensor *unpad_1d(struct ggml_context *ctx0, struct ggml_tensor *inp, + int padding_left, int padding_right) { + int length = inp->ne[0]; + int dim = inp->ne[1]; + + assert(padding_left >= 0); + assert(padding_right >= 0); + assert(padding_left + padding_right <= length); + + int end = length - padding_right; + + int offset = padding_left * inp->nb[1]; + struct ggml_tensor *dst = ggml_view_2d(ctx0, inp, end, dim, inp->nb[1], offset); + + return dst; +} + +struct ggml_tensor *strided_conv_1d(struct ggml_context *ctx0, struct ggml_tensor *inp, + struct ggml_tensor *conv_w, struct ggml_tensor *conv_b, + int stride) { + int kernel_size = conv_w->ne[0]; + int padding_total = kernel_size - stride; + int extra_padding = get_extra_padding_for_conv_1d(inp, kernel_size, stride, padding_total); + + struct ggml_tensor *padded_inp = pad_1d(ctx0, inp, padding_total, extra_padding); + struct ggml_tensor *dst = ggml_conv_1d(ctx0, conv_w, padded_inp, stride, 0, 1); + + // add bias + dst = ggml_transpose(ctx0, dst); + dst = ggml_add(ctx0, ggml_repeat(ctx0, conv_b, dst), dst); + dst = ggml_cont(ctx0, ggml_transpose(ctx0, dst)); + + return dst; +} + +struct ggml_tensor *strided_conv_transpose_1d(struct ggml_context *ctx0, struct ggml_tensor *inp, + struct ggml_tensor *conv_w, struct ggml_tensor *conv_b, + int stride) { + struct ggml_tensor *dst = ggml_conv_transpose_1d( + ctx0, conv_w, inp, stride, 0 /* p0 */, 1 /* d0 */); + + // add bias + dst = ggml_transpose(ctx0, dst); + dst = ggml_add(ctx0, ggml_repeat(ctx0, conv_b, dst), dst); + dst = ggml_cont(ctx0, ggml_transpose(ctx0, dst)); + + int kernel_size = conv_w->ne[0]; + int padding_total = kernel_size - stride; + + int padding_right = ceilf(padding_total); + int padding_left = padding_total - padding_right; + + struct ggml_tensor *unpadded = unpad_1d(ctx0, dst, padding_left, padding_right); + unpadded = ggml_cont(ctx0, unpadded); + + return unpadded; +} diff --git a/ops.h b/ops.h new file mode 100644 index 0000000..891aa90 --- /dev/null +++ b/ops.h @@ -0,0 +1,19 @@ +#pragma once + +#include "ggml.h" + +struct ggml_tensor *encodec_sigmoid(struct ggml_context *ctx, struct ggml_tensor *x); + +struct ggml_tensor *pad_1d(struct ggml_context *ctx0, struct ggml_tensor *inp, + int padding_left, int padding_right); + +struct ggml_tensor *unpad_1d(struct ggml_context *ctx0, struct ggml_tensor *inp, + int padding_left, int padding_right); + +struct ggml_tensor *strided_conv_1d(struct ggml_context *ctx0, struct ggml_tensor *inp, + struct ggml_tensor *conv_w, struct ggml_tensor *conv_b, + int stride); + +struct ggml_tensor *strided_conv_transpose_1d(struct ggml_context *ctx0, struct ggml_tensor *inp, + struct ggml_tensor *conv_w, struct ggml_tensor *conv_b, + int stride); diff --git a/quantizer.h b/quantizer.h new file mode 100644 index 0000000..523c594 --- /dev/null +++ b/quantizer.h @@ -0,0 +1,122 @@ +#pragma once + +#include +#include + +#include "ggml.h" +#include "ggml-alloc.h" +#include "ggml-backend.h" + +#include "utils.h" + +struct encodec_quant_block { + struct ggml_tensor *embed; +}; + +struct encodec_quantizer { + std::vector blocks; +}; + +struct ggml_tensor *encodec_forward_quantizer_encode( + const struct encodec_quantizer *quantizer, struct ggml_allocr *allocr, struct ggml_context *ctx0, + struct ggml_tensor *encoded_inp, const int n_bins, const int sr, const int bandwidth, + const int hop_length) { + + if (!encoded_inp) { + fprintf(stderr, "%s: null input tensor\n", __func__); + return NULL; + } + + const int frame_rate = (int)ceilf(sr / hop_length); + const int n_q = get_num_quantizers_for_bandwidth(n_bins, frame_rate, bandwidth); + + const int seq_length = encoded_inp->ne[0]; + + struct ggml_tensor *codes = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, seq_length, n_q); + ggml_allocr_alloc(allocr, codes); + + struct ggml_tensor *dist_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + ggml_allocr_alloc(allocr, dist_scale); + + if (!ggml_allocr_is_measure(allocr)) { + float s = -2.0f; + ggml_backend_tensor_set(dist_scale, &s, 0, sizeof(s)); + } + + struct ggml_tensor *inpL = ggml_cont(ctx0, ggml_transpose(ctx0, encoded_inp)); + struct ggml_tensor *residual = inpL; + struct ggml_tensor *indices; + + for (int i = 0; i < n_q; i++) { + encodec_quant_block block = quantizer->blocks[i]; + + // compute distance + // [seq_length, n_bins] + struct ggml_tensor *dp = ggml_scale( + ctx0, ggml_mul_mat(ctx0, block.embed, residual), dist_scale); + + // [n_bins] + struct ggml_tensor *sqr_embed = ggml_sqr(ctx0, block.embed); + struct ggml_tensor *sqr_embed_nrm = ggml_sum_rows(ctx0, sqr_embed); + + // [seq_length] + struct ggml_tensor *sqr_inp = ggml_sqr(ctx0, residual); + struct ggml_tensor *sqr_inp_nrm = ggml_sum_rows(ctx0, sqr_inp); + + // [seq_length, n_bins] + struct ggml_tensor *dist = ggml_add(ctx0, ggml_repeat(ctx0, sqr_inp_nrm, dp), dp); + dist = ggml_add(ctx0, ggml_repeat(ctx0, ggml_transpose(ctx0, sqr_embed_nrm), dist), dist); + dist = ggml_neg(ctx0, dist); + + // take the argmax over the column dimension + // [seq_length] + indices = ggml_argmax(ctx0, dist); + + // look up in embedding table + struct ggml_tensor *quantized = ggml_get_rows(ctx0, block.embed, indices); + + residual = ggml_sub(ctx0, residual, quantized); + + codes = ggml_set_1d(ctx0, codes, indices, i * codes->nb[1]); + } + + return codes; +} + +struct ggml_tensor *encodec_forward_quantizer_decode( + const struct encodec_quantizer *quantizer, struct ggml_allocr *allocr, struct ggml_context *ctx0, + struct ggml_tensor *codes, const int hidden_dim, const int n_bins, const int sr, const int bandwidth, + const int hop_length) { + + if (!codes) { + fprintf(stderr, "%s: null input tensor\n", __func__); + return NULL; + } + + const int seq_length = codes->ne[0]; + + const int frame_rate = (int)ceilf(sr / hop_length); + const int n_q = get_num_quantizers_for_bandwidth(n_bins, frame_rate, bandwidth); + + assert(n_q == codes->ne[1]); + + struct ggml_tensor *quantized_out = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hidden_dim, seq_length); + ggml_allocr_alloc(allocr, quantized_out); + + if (!ggml_allocr_is_measure(allocr)) { + quantized_out = ggml_set_zero(quantized_out); + } + + for (int i = 0; i < n_q; i++) { + encodec_quant_block block = quantizer->blocks[i]; + + struct ggml_tensor *indices = ggml_view_1d(ctx0, codes, seq_length, i * codes->nb[1]); + struct ggml_tensor *quantized = ggml_get_rows(ctx0, block.embed, indices); + + quantized_out = ggml_add(ctx0, quantized_out, quantized); + } + + quantized_out = ggml_cont(ctx0, ggml_transpose(ctx0, quantized_out)); + + return quantized_out; +} diff --git a/utils.h b/utils.h new file mode 100644 index 0000000..a5d72fa --- /dev/null +++ b/utils.h @@ -0,0 +1,37 @@ +#pragma once + +#include + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) + +const size_t MB = 1024 * 1024; + +template +void read_safe(std::ifstream &infile, T &dest) { + infile.read((char *)&dest, sizeof(T)); +} + +int32_t get_num_codebooks(float bandwidth, int hop_length, float sample_rate) { + // The number of codebooks is determined by the bandwidth selected. + // Supported bandwidths are 1.5kbps (n_q = 2), 3 kbps (n_q = 4), 6 kbps (n_q = 8), + // 12 kbps (n_q = 16) and 24kbps (n_q = 32). + return (int32_t)ceilf(1000 * bandwidth / (ceilf(sample_rate / hop_length) * 10)); +} + +int32_t get_bandwidth_per_quantizer(int bins, float frame_rate) { + return log2f((float)bins) * frame_rate; +} + +int32_t get_num_quantizers_for_bandwidth(int bins, float frame_rate, float bandwidth) { + float bw_per_q = get_bandwidth_per_quantizer(bins, frame_rate); + int32_t n_q = MAX(1, floorf(bandwidth * 1000 / bw_per_q)); + return n_q; +} + +void ggml_log_callback_default(ggml_log_level level, const char *text, void *user_data) { + (void)level; + (void)user_data; + fputs(text, stderr); + fflush(stderr); +}