Skip to content

Commit

Permalink
chore(maint): split encodec.cpp into smaller headers (#43)
Browse files Browse the repository at this point in the history
  • Loading branch information
PABannier authored Oct 14, 2024
1 parent 91e4cb4 commit a1aac2e
Show file tree
Hide file tree
Showing 10 changed files with 720 additions and 687 deletions.
13 changes: 12 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,18 @@ set(ENCODEC_LIB encodec)

add_subdirectory(ggml)

add_library(${ENCODEC_LIB} STATIC encodec.cpp encodec.h)
add_library(
${ENCODEC_LIB} STATIC
encodec.cpp
encodec.h
encoder.h
decoder.h
quantizer.h
ops.cpp
ops.h
utils.h
lstm.h
)

if (ENCODEC_BUILD_EXAMPLES)
add_subdirectory(examples)
Expand Down
5 changes: 2 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,15 @@ https://github.com/PABannier/encodec.cpp/assets/12958149/d11561be-98e9-4504-bba7
- [x] Support of 24Khz model
- [x] Mixed F16 / F32 precision
- [ ] 4-bit and 8-bit quantization
- [x] Metal support
- [x] cuBLAS support
- [ ] Metal support
- [ ] CoreML support

## Implementation details

- The core tensor operations are implemented in C ([ggml.h](ggml.h) / [ggml.c](ggml.c))
- The encoder-decoder architecture and the high-level C-style API are implemented in C++ ([encodec.h](encodec.h) / [encodec.cpp](encodec.cpp))
- Basic usage is demonstrated in [main.cpp](examples/main).


## Usage

Here are the steps for the encodec model.
Expand Down
113 changes: 113 additions & 0 deletions decoder.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
#pragma once

#include <vector>

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

#include "lstm.h"
#include "utils.h"


struct encodec_decoder_block {
// upsampling layers
struct ggml_tensor *us_conv_w;
struct ggml_tensor *us_conv_b;

// conv1
struct ggml_tensor *conv_1_w;
struct ggml_tensor *conv_1_b;

// conv2
struct ggml_tensor *conv_2_w;
struct ggml_tensor *conv_2_b;

// shortcut
struct ggml_tensor *conv_sc_w;
struct ggml_tensor *conv_sc_b;
};

struct encodec_decoder {
struct ggml_tensor *init_conv_w;
struct ggml_tensor *init_conv_b;

encodec_lstm lstm;

struct ggml_tensor *final_conv_w;
struct ggml_tensor *final_conv_b;

std::vector<encodec_decoder_block> blocks;
};

struct ggml_tensor *encodec_forward_decoder(
const struct encodec_decoder *decoder, struct ggml_allocr *allocr, struct ggml_context *ctx0,
struct ggml_tensor *quantized_out, const int *ratios, const int kernel_size, const int res_kernel_size,
const int stride) {

if (!quantized_out) {
fprintf(stderr, "%s: null input tensor\n", __func__);
return NULL;
}

struct ggml_tensor *inpL = strided_conv_1d(
ctx0, quantized_out, decoder->init_conv_w, decoder->init_conv_b, stride);

// lstm
{
struct ggml_tensor *cur = inpL;

const encodec_lstm lstm = decoder->lstm;

// first lstm layer
struct ggml_tensor *hs1 = forward_pass_lstm_unilayer(
ctx0, allocr, cur, lstm.l0_ih_w, lstm.l0_hh_w,
lstm.l0_ih_b, lstm.l0_hh_b);

// second lstm layer
struct ggml_tensor *out = forward_pass_lstm_unilayer(
ctx0, allocr, hs1, lstm.l1_ih_w, lstm.l1_hh_w,
lstm.l1_ih_b, lstm.l1_hh_b);

inpL = ggml_add(ctx0, inpL, out);
}

for (int layer_ix = 0; layer_ix < 4; layer_ix++) {
encodec_decoder_block block = decoder->blocks[layer_ix];

// upsampling layers
inpL = ggml_elu(ctx0, inpL);

inpL = strided_conv_transpose_1d(
ctx0, inpL, block.us_conv_w, block.us_conv_b, ratios[layer_ix]);

struct ggml_tensor *current = inpL;

// shortcut
struct ggml_tensor *shortcut = strided_conv_1d(
ctx0, inpL, block.conv_sc_w, block.conv_sc_b, stride);

// conv1
current = ggml_elu(ctx0, current);

current = strided_conv_1d(
ctx0, current, block.conv_1_w, block.conv_1_b, stride);

// conv2
current = ggml_elu(ctx0, current);

current = strided_conv_1d(
ctx0, current, block.conv_2_w, block.conv_2_b, stride);

// residual connection
inpL = ggml_add(ctx0, current, shortcut);
}

// final conv
inpL = ggml_elu(ctx0, inpL);

struct ggml_tensor *decoded_inp = strided_conv_1d(
ctx0, inpL, decoder->final_conv_w, decoder->final_conv_b, stride);

return decoded_inp;
}
Loading

0 comments on commit a1aac2e

Please sign in to comment.