Skip to content

Commit

Permalink
fix: encodec forward pass (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
PABannier authored Oct 7, 2023
1 parent e59b55d commit 9c3316c
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 13 deletions.
23 changes: 17 additions & 6 deletions convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
```bash
python convert.py \
--dir-model ./ggml_weights/ \
--out-dir ./ggml_weights/
--out-dir ./ggml_weights/ \
--use-f16
```
"""
import argparse
Expand All @@ -39,9 +40,10 @@
parser = argparse.ArgumentParser()
parser.add_argument("--dir-model", type=str, required=True)
parser.add_argument("--out-dir", type=str, required=True)
parser.add_argument("--use-f16", type=bool, default=True)


def parse_codec_model(checkpoint, out_dir):
def parse_codec_model(checkpoint, out_dir, use_f16):
"""Load encodec model checkpoint."""
outfile = open(out_dir, "wb")
outfile.write(struct.pack("i", 0x67676d6c)) # ggml magic
Expand Down Expand Up @@ -78,14 +80,23 @@ def parse_codec_model(checkpoint, out_dir):

print(f"Processing variable: {name} with shape: {var_data.shape}")

if var_data.dtype != np.float32:
if use_f16:
if "weight" in name:
print(" Converting to float16")
var_data = var_data.astype(np.float16)
ftype_cur = 1
else:
print(" Converting to float32")
var_data = var_data.astype(np.float32)
ftype_cur = 0
else:
print(" Converting to float32")
var_data = var_data.astype(np.float32)
ftype_cur = 0

n_dims = len(var_data.shape)
encoded_name = name.encode("utf-8")
ftype = 0 # float32
outfile.write(struct.pack("iii", n_dims, len(encoded_name), ftype))
outfile.write(struct.pack("iii", n_dims, len(encoded_name), ftype_cur))

for i in range(n_dims):
outfile.write(struct.pack("i", var_data.shape[n_dims - 1 - i]))
Expand All @@ -107,6 +118,6 @@ def parse_codec_model(checkpoint, out_dir):
outfile = Path(out_dir / "ggml-model.bin")

checkpoint = torch.load(dir_model / "encodec_24khz-d7cc33bc.th", map_location="cpu")
parse_codec_model(checkpoint, outfile)
parse_codec_model(checkpoint, outfile, args.use_f16)

print("Done.")
14 changes: 8 additions & 6 deletions encodec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,10 @@ static struct ggml_tensor * forward_pass_lstm_unilayer(
struct ggml_tensor * c_t = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hidden_dim);
struct ggml_tensor * h_t = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hidden_dim);

if (is_measure) {
h_t = ggml_set_zero(h_t);
c_t = ggml_set_zero(c_t);
}
// if (!is_measure) {
// h_t = ggml_set_zero(h_t);
// c_t = ggml_set_zero(c_t);
// }

struct ggml_tensor * current = ggml_cont(ctx0, ggml_transpose(ctx0, inp));

Expand Down Expand Up @@ -697,7 +697,9 @@ static struct ggml_cgraph * encodec_build_graph(
const int n_q = codes->ne[1];

quantized_out = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hidden_dim, seq_length);
quantized_out = ggml_set_zero(quantized_out);
// if (!ggml_allocr_is_measure(ectx.allocr)) {
// quantized_out = ggml_set_zero(quantized_out);
// }

for (int i = 0; i < n_q; i++) {
encodec_quant_block block = model.quantizer.blocks[i];
Expand Down Expand Up @@ -818,7 +820,7 @@ bool encodec_reconstruct_audio(

ectx.ctx_audio = ggml_init(ggml_params);

ectx.reconstructed_audio = ggml_new_tensor_1d(ectx.ctx_audio, GGML_TYPE_F32, raw_audio.size());
ectx.reconstructed_audio = ggml_new_tensor_1d(ectx.ctx_audio, GGML_TYPE_F32, 100160);

// reconstruct the audio
ectx.buf_compute.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
Expand Down
2 changes: 1 addition & 1 deletion examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ struct encodec_params {
std::string model_path = "/Users/pbannier/Documents/encodec.cpp/ggml_weights/ggml-model.bin";

// input location
std::string original_audio_path = "/Users/pbannier/Documents/encodec/test_24k.wav";
std::string original_audio_path = "/Users/pbannier/Documents/encodec/decomp_24khz_True.wav";

// output location
std::string dest_wav_path = "output.wav";
Expand Down

0 comments on commit 9c3316c

Please sign in to comment.