diff --git a/README.md b/README.md index 73abf82..24db863 100644 --- a/README.md +++ b/README.md @@ -32,11 +32,14 @@ pip install git+https://github.com/descriptinc/descript-audio-codec ### Weights Weights are released as part of this repo under MIT license. -They are automatically downloaded when you first run `encode` or `decode` command. They can be cached locally with +We release weights for models that can natively support 24kHz and 44.1kHz sampling rates. +Weights are automatically downloaded when you first run `encode` or `decode` command. You can cache them using one of the following commands +```bash +python3 -m dac download # downloads the default 44kHz variant +python3 -m dac download --model_type 44khz # downloads the 44kHz variant +python3 -m dac download --model_type 24khz # downloads the 24kHz variant ``` -python3 -m dac download -``` -We provide a Dockerfile that installs all required dependencies for encoding and decoding. The build process caches model weights inside the image. This allows the image to be used without an internet connection. [Please refer to instructions below.](#docker-image) +We provide a Dockerfile that installs all required dependencies for encoding and decoding. The build process caches the default model weights inside the image. This allows the image to be used without an internet connection. [Please refer to instructions below.](#docker-image) ### Compress audio @@ -74,7 +77,7 @@ from audiotools import AudioSignal model = DAC() # Load compatible pre-trained model -model = load_model(dac.__model_version__) +model = load_model(tag="latest", model_type="44khz") model.eval() model.to('cuda') diff --git a/dac/__init__.py b/dac/__init__.py index e16f354..231ebbc 100644 --- a/dac/__init__.py +++ b/dac/__init__.py @@ -1,5 +1,8 @@ -__version__ = "0.0.3" -__model_version__ = "0.0.1" +__version__ = "0.0.4" + +# preserved here for legacy reasons +__model_version__ = "latest" + import audiotools audiotools.ml.BaseModel.INTERN += ["dac.**"] diff --git a/dac/utils/__init__.py b/dac/utils/__init__.py index 7ee945d..3693b82 100644 --- a/dac/utils/__init__.py +++ b/dac/utils/__init__.py @@ -1,24 +1,68 @@ from pathlib import Path +import argbind from audiotools import ml import dac - DAC = dac.model.DAC Accelerator = ml.Accelerator +__MODEL_LATEST_TAGS__ = { + "44khz": "0.0.1", + "24khz": "0.0.4", +} + +__MODEL_URLS__ = { + ( + "44khz", + "0.0.1", + ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth", + ( + "24khz", + "0.0.4", + ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth", +} + -def ensure_default_model(tag: str = dac.__model_version__): +@argbind.bind(group="download", positional=True, without_prefix=True) +def ensure_default_model(tag: str = "latest", model_type: str = "44khz"): """ - Function that downloads the weights file from URL if a local cache is not - found. + Function that downloads the weights file from URL if a local cache is not found. - Args: - tag (str): The tag of the model to download. + Parameters + ---------- + tag : str + The tag of the model to download. Defaults to "latest". + model_type : str + The type of model to download. Must be one of "44khz" or "24khz". Defaults to "44khz". + + Returns + ------- + Path + Directory path required to load model via audiotools. """ - download_link = f"https://github.com/descriptinc/descript-audio-codec/releases/download/{tag}/weights.pth" - local_path = Path.home() / ".cache" / "descript" / tag / "dac" / f"weights.pth" + model_type = model_type.lower() + tag = tag.lower() + + assert model_type in [ + "44khz", + "24khz", + ], "model_type must be one of '44khz' or '24khz'" + + if tag == "latest": + tag = __MODEL_LATEST_TAGS__[model_type] + + download_link = __MODEL_URLS__.get((model_type, tag), None) + + if download_link is None: + raise ValueError( + f"Could not find model with tag {tag} and model type {model_type}" + ) + + local_path = ( + Path.home() / ".cache" / "descript" / model_type / tag / "dac" / f"weights.pth" + ) if not local_path.exists(): local_path.parent.mkdir(parents=True, exist_ok=True) @@ -38,11 +82,12 @@ def ensure_default_model(tag: str = dac.__model_version__): def load_model( - tag: str, + tag: str = "latest", load_path: str = "", + model_type: str = "44khz", ): if not load_path: - load_path = ensure_default_model(tag) + load_path = ensure_default_model(tag, model_type) kwargs = { "folder": load_path, "map_location": "cpu", diff --git a/dac/utils/decode.py b/dac/utils/decode.py index aebc785..69bdccd 100644 --- a/dac/utils/decode.py +++ b/dac/utils/decode.py @@ -7,7 +7,6 @@ from audiotools import AudioSignal from tqdm import tqdm -import dac from dac.utils import load_model warnings.filterwarnings("ignore", category=UserWarning) @@ -99,13 +98,36 @@ def decode( input: str, output: str = "", weights_path: str = "", - model_tag: str = dac.__model_version__, + model_tag: str = "latest", preserve_sample_rate: bool = False, device: str = "cuda", + model_type: str = "44khz", ): + """Decode audio from codes. + + Parameters + ---------- + input : str + Path to input directory or file + output : str, optional + Path to output directory, by default "". + If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`. + weights_path : str, optional + Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the + model_tag and model_type. + model_tag : str, optional + Tag of the model to use, by default "latest". Ignored if `weights_path` is specified. + preserve_sample_rate : bool, optional + If True, return audio will have the same sample rate as the original + device : str, optional + Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU. + model_type : str, optional + The type of model to download. Must be one of "44khz" or "24khz". Defaults to "44khz". Ignored if `weights_path` is specified. + """ generator = load_model( tag=model_tag, load_path=weights_path, + model_type=model_type, ) generator.to(device) generator.eval() diff --git a/dac/utils/encode.py b/dac/utils/encode.py index 860a980..f45912e 100644 --- a/dac/utils/encode.py +++ b/dac/utils/encode.py @@ -9,7 +9,6 @@ from audiotools.core import util from tqdm import tqdm -import dac from dac.utils import load_model warnings.filterwarnings("ignore", category=UserWarning) @@ -124,13 +123,35 @@ def encode( input: str, output: str = "", weights_path: str = "", - model_tag: str = dac.__model_version__, + model_tag: str = "latest", n_quantizers: int = None, device: str = "cuda", + model_type: str = "44khz", ): + """Encode audio files in input path to .dac format. + + Parameters + ---------- + input : str + Path to input audio file or directory + output : str, optional + Path to output directory, by default "". If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`. + weights_path : str, optional + Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the + model_tag and model_type. + model_tag : str, optional + Tag of the model to use, by default "latest". Ignored if `weights_path` is specified. + n_quantizers : int, optional + Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate. + device : str, optional + Device to use, by default "cuda" + model_type : str, optional + The type of model to download. Must be one of "44khz" or "24khz". Defaults to "44khz". Ignored if `weights_path` is specified. + """ generator = load_model( tag=model_tag, load_path=weights_path, + model_type=model_type, ) generator.to(device) generator.eval() diff --git a/setup.py b/setup.py index b31b7a9..490d0ec 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setup( name="descript-audio-codec", - version="0.0.3", + version="0.0.4", classifiers=[ "Intended Audience :: Developers", "Natural Language :: English", diff --git a/tests/test_cli.py b/tests/test_cli.py index 8565e25..60c5215 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -6,6 +6,8 @@ import argbind import numpy as np +import pytest +import torch from audiotools import AudioSignal from dac.__main__ import run @@ -28,20 +30,23 @@ def teardown_module(module): subprocess.check_output(["rm", "-rf", f"{repo_root}/tests/assets"]) -def test_reconstruction(): +@pytest.mark.parametrize("model_type", ["44khz", "24khz"]) +def test_reconstruction(model_type): # Test encoding input_dir = Path(__file__).parent / "assets" / "input" - output_dir = input_dir.parent / "encoded_output" + output_dir = input_dir.parent / model_type / "encoded_output" args = { "input": str(input_dir), "output": str(output_dir), + "device": "cuda" if torch.cuda.is_available() else "cpu", + "model_type": model_type, } with argbind.scope(args): run("encode") # Test decoding input_dir = output_dir - output_dir = input_dir.parent / "decoded_output" + output_dir = input_dir.parent / model_type / "decoded_output" args = { "input": str(input_dir), "output": str(output_dir), @@ -54,7 +59,12 @@ def test_compression(): # Test encoding input_dir = Path(__file__).parent / "assets" / "input" output_dir = input_dir.parent / "encoded_output_quantizers" - args = {"input": str(input_dir), "output": str(output_dir), "n_quantizers": 3} + args = { + "input": str(input_dir), + "output": str(output_dir), + "n_quantizers": 3, + "device": "cuda" if torch.cuda.is_available() else "cpu", + } with argbind.scope(args): run("encode")