From 818a108b41be2dd43dada04bd319fdfcdabc5c6a Mon Sep 17 00:00:00 2001
From: WeberJulian
Date: Wed, 29 Nov 2023 10:05:11 +0100
Subject: [PATCH] Adding custom model loading

---
 README.md                 | 15 ++++++++++++++-
 server/Dockerfile         |  1 +
 server/Dockerfile.cuda121 |  1 +
 server/main.py            | 17 ++++++++++++-----
 4 files changed, 28 insertions(+), 6 deletions(-)

diff --git a/README.md b/README.md
index 245385e..66e985b 100644
--- a/README.md
+++ b/README.md
@@ -46,8 +46,21 @@ docker build -t xtts-stream . -f Dockerfile.cuda121
 
 2. Run the server container:
 
 ```bash
-$ docker run --gpus=all -e COQUI_TOS_AGREED=1 --rm -p 8000:80 xtts-stream
+$ docker run --gpus all -e COQUI_TOS_AGREED=1 --rm -p 8000:80 xtts-stream
 ```
 
 Setting the `COQUI_TOS_AGREED` environment variable to `1` indicates you have read and agreed to the terms of the [CPML license](https://coqui.ai/cpml).
+
+2. (bis) Run the server container with your own model:
+
+```bash
+docker run -v /path/to/model/folder:/app/tts_models --gpus all --rm -p 8000:80 xtts-stream
+```
+
+Make sure the model folder contains the following files:
+- `config.json`
+- `model.pth`
+- `vocab.json`
+
+(Fine-tuned XTTS models are also under the [CPML license](https://coqui.ai/cpml))
\ No newline at end of file
diff --git a/server/Dockerfile b/server/Dockerfile
index 212d98b..d4a1969 100644
--- a/server/Dockerfile
+++ b/server/Dockerfile
@@ -11,6 +11,7 @@ RUN python -m pip install --use-deprecated=legacy-resolver -r requirements.txt \
     && python -m pip cache purge
 
 RUN python -m unidic download
+RUN mkdir -p /app/tts_models
 COPY main.py .
 
 ENV NVIDIA_DISABLE_REQUIRE=1
diff --git a/server/Dockerfile.cuda121 b/server/Dockerfile.cuda121
index fe5df2e..7d9d70e 100644
--- a/server/Dockerfile.cuda121
+++ b/server/Dockerfile.cuda121
@@ -11,6 +11,7 @@ RUN python -m pip install --use-deprecated=legacy-resolver -r requirements.txt \
     && python -m pip cache purge
 
 RUN python -m unidic download
+RUN mkdir -p /app/tts_models
 COPY main.py .
 
 
diff --git a/server/main.py b/server/main.py
index dfe56b8..0e83cd9 100644
--- a/server/main.py
+++ b/server/main.py
@@ -23,11 +23,18 @@ torch.set_num_threads(int(os.environ.get("NUM_THREADS", "2")))
 
 device = torch.device("cuda")
 
-model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
-print("Downloading XTTS Model:",model_name,flush=True)
-ModelManager().download_model(model_name)
-model_path = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
-print("XTTS Model downloaded",flush=True)
+custom_model_path = os.environ.get("CUSTOM_MODEL_PATH", "/app/tts_models")
+
+if os.path.exists(custom_model_path) and os.path.isfile(custom_model_path + "/config.json"):
+    model_path = custom_model_path
+    print("Loading custom model from", model_path, flush=True)
+else:
+    print("Loading default model", flush=True)
+    model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
+    print("Downloading XTTS Model:", model_name, flush=True)
+    ModelManager().download_model(model_name)
+    model_path = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
+    print("XTTS Model downloaded", flush=True)
 
 print("Loading XTTS",flush=True)
 config = XttsConfig()
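
The new logic in `server/main.py` resolves the model path in two steps: use the folder at `CUSTOM_MODEL_PATH` (default `/app/tts_models`, the directory the Dockerfiles now create) when it contains a `config.json`, otherwise download the stock XTTS-v2 model as before. Below is a minimal self-contained sketch of that flow, assuming the same `TTS` imports `main.py` already uses; the helper name `resolve_model_path` and the stricter check for all three files from the README checklist are illustrative additions, since the patch itself only tests for `config.json`:

```python
# Sketch of the model-resolution flow this patch adds to server/main.py.
# resolve_model_path and the three-file check are illustrative; the patch
# itself only tests for config.json before trusting the custom folder.
import os

from TTS.utils.generic_utils import get_user_data_dir  # same imports main.py uses
from TTS.utils.manage import ModelManager

REQUIRED_FILES = ("config.json", "model.pth", "vocab.json")  # README checklist


def resolve_model_path() -> str:
    custom_model_path = os.environ.get("CUSTOM_MODEL_PATH", "/app/tts_models")
    # Use the mounted folder only if every expected file is actually there.
    if all(os.path.isfile(os.path.join(custom_model_path, f)) for f in REQUIRED_FILES):
        print("Loading custom model from", custom_model_path, flush=True)
        return custom_model_path
    # Otherwise fall back to downloading the stock XTTS-v2 model, as before.
    model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
    ModelManager().download_model(model_name)
    return os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
```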
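
Usage-wise, the `-v /path/to/model/folder:/app/tts_models` mount from the README places the custom model at the default lookup path, so `CUSTOM_MODEL_PATH` only needs to be set (via `docker run -e CUSTOM_MODEL_PATH=...`) when the folder is mounted somewhere else. Note that if `config.json` is missing or the mount path is mistyped, the server quietly falls back to downloading the default model rather than failing, which is worth keeping in mind when a custom model does not seem to load.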