# Patch summary (leading commentary; text before the first "diff --git" header
# is not part of any hunk).
#
# NOTE(review): in this copy the patch's line breaks have been collapsed, so
# the whole diff sits on two physical lines. It has to be re-flowed into one
# diff line per row (with the original indentation restored) before it can be
# applied.
#
# What the hunks change:
#   * dac/__init__.py, setup.py: version string "0.0.2" -> "0.0.3".
#     `__model_version__` stays "0.0.1".
#   * dac/nn/quantize.py, forward(): the loop over `self.quantizers` now
#     breaks once `i >= n_quantizers` when `self.training is False`. The
#     training check comes first in the `and`, so the index comparison is only
#     evaluated outside training.
#   * dac/utils/encode.py, process(): the returned "codes" array is cast with
#     `.astype(np.uint16)`.
#   * dac/utils/decode.py, process(): `artifacts["codes"]` is cast with
#     `.astype(np.int64)` before being passed to `AudioSignal(...)`.
#   * tests/test_cli.py: adds `test_compression`, which runs "encode" with
#     `n_quantizers` set to 3, loads `sample_0.dac` from the output directory
#     and asserts `codes.shape[1] == 3` and `codes.dtype == np.uint16`.
#
# Review notes (not verifiable from the hunks shown):
#   * uint16 holds 0..65535; presumably every codebook index fits in that
#     range -- TODO confirm against the codebook size, since `astype` wraps
#     out-of-range values silently.
#   * `np` is used in encode.py, decode.py and test_cli.py, but no import line
#     appears in these hunks -- confirm numpy is already imported in each file.
#   * `np.load(..., allow_pickle=True)` unpickles the file contents; here the
#     file is produced by the test's own "encode" run. Do not copy this
#     pattern for files from untrusted sources.
diff --git a/dac/__init__.py b/dac/__init__.py index d872b73..e16f354 100644 --- a/dac/__init__.py +++ b/dac/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.0.2" +__version__ = "0.0.3" __model_version__ = "0.0.1" import audiotools diff --git a/dac/nn/quantize.py b/dac/nn/quantize.py index 0a2f438..2a0b939 100644 --- a/dac/nn/quantize.py +++ b/dac/nn/quantize.py @@ -171,6 +171,9 @@ def forward(self, z, n_quantizers: int = None): n_quantizers = n_quantizers.to(z.device) for i, quantizer in enumerate(self.quantizers): + if self.training is False and i >= n_quantizers: + break + z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( residual ) diff --git a/dac/utils/decode.py b/dac/utils/decode.py index 02cd5d9..aebc785 100644 --- a/dac/utils/decode.py +++ b/dac/utils/decode.py @@ -52,7 +52,9 @@ def process( """ if isinstance(generator, torch.nn.DataParallel): generator = generator.module - audio_signal = AudioSignal(artifacts["codes"], generator.sample_rate) + audio_signal = AudioSignal( + artifacts["codes"].astype(np.int64), generator.sample_rate + ) metadata = artifacts["metadata"] # Decode chunks diff --git a/dac/utils/encode.py b/dac/utils/encode.py index ac904b8..860a980 100644 --- a/dac/utils/encode.py +++ b/dac/utils/encode.py @@ -104,7 +104,7 @@ def process( codebook_indices = torch.cat(codebook_indices, dim=0) return { - "codes": codebook_indices.numpy(), + "codes": codebook_indices.numpy().astype(np.uint16), "metadata": { "original_db": input_db, "overlap_hop_duration": overlap_hop_duration, diff --git a/setup.py b/setup.py index a5299e0..b31b7a9 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setup( name="descript-audio-codec", - version="0.0.2", + version="0.0.3", classifiers=[ "Intended Audience :: Developers", "Natural Language :: English", diff --git a/tests/test_cli.py b/tests/test_cli.py index fa3def9..8565e25 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -50,4 +50,24 @@ def test_reconstruction(): run("decode") +def 
test_compression(): + # Test encoding + input_dir = Path(__file__).parent / "assets" / "input" + output_dir = input_dir.parent / "encoded_output_quantizers" + args = {"input": str(input_dir), "output": str(output_dir), "n_quantizers": 3} + with argbind.scope(args): + run("encode") + + # Open .dac file + dac_file = output_dir / "sample_0.dac" + artifacts = np.load(dac_file, allow_pickle=True)[()] + codes = artifacts["codes"] + + # Ensure that the number of quantizers is correct + assert codes.shape[1] == 3 + + # Ensure that dtype of compression is uint16 + assert codes.dtype == np.uint16 + + # CUDA_VISIBLE_DEVICES=0 python -m pytest tests/test_cli.py -s