From 3c36dc46862e5c91a185606602da3d96255fe0ad Mon Sep 17 00:00:00 2001 From: tsukumi Date: Fri, 29 Nov 2024 15:54:33 +0900 Subject: [PATCH] =?UTF-8?q?Refactor:=20BytesIO=20=E3=82=92=E4=BD=BF?= =?UTF-8?q?=E3=81=84=E3=80=81=E4=B8=80=E6=99=82=E3=83=95=E3=82=A1=E3=82=A4?= =?UTF-8?q?=E3=83=AB=E3=81=AB=E6=9B=B8=E3=81=8D=E8=BE=BC=E3=81=BE=E3=81=9A?= =?UTF-8?q?=E3=82=A4=E3=83=B3=E3=83=A1=E3=83=A2=E3=83=AA=E3=81=A7=E9=9F=B3?= =?UTF-8?q?=E5=A3=B0=E5=90=88=E6=88=90=E7=B5=90=E6=9E=9C=E3=81=AE=20WAV=20?= =?UTF-8?q?=E3=82=92=E3=83=AC=E3=82=B9=E3=83=9D=E3=83=B3=E3=82=B9=E3=81=A8?= =?UTF-8?q?=E3=81=97=E3=81=A6=E8=BF=94=E3=81=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 逆に最初からなぜこの実装になっていなかったのが謎… --- voicevox_engine/app/routers/morphing.py | 31 ++--- voicevox_engine/app/routers/tts_pipeline.py | 120 +++++++++----------- 2 files changed, 63 insertions(+), 88 deletions(-) diff --git a/voicevox_engine/app/routers/morphing.py b/voicevox_engine/app/routers/morphing.py index 01c9ff8..fba340f 100644 --- a/voicevox_engine/app/routers/morphing.py +++ b/voicevox_engine/app/routers/morphing.py @@ -1,14 +1,12 @@ """モーフィング機能を提供する API Router""" +import io from functools import lru_cache -from tempfile import NamedTemporaryFile from typing import Annotated import soundfile -from fastapi import APIRouter, HTTPException, Query +from fastapi import APIRouter, HTTPException, Query, Response from pydantic.json_schema import SkipJsonSchema -from starlette.background import BackgroundTask -from starlette.responses import FileResponse from voicevox_engine.aivm_manager import AivmManager from voicevox_engine.metas.Metas import StyleId @@ -24,7 +22,6 @@ ) from voicevox_engine.morphing.morphing import synthesize_morphed_wave from voicevox_engine.tts_pipeline.tts_engine import LATEST_VERSION, TTSEngineManager -from voicevox_engine.utility.file_utility import try_delete_file # キャッシュを有効化 # モジュール側でlru_cacheを指定するとキャッシュを制御しにくいため、HTTPサーバ側で指定する @@ -72,7 +69,7 @@ def morphable_targets( @router.post( "/synthesis_morphing", - response_class=FileResponse, + response_class=Response, responses={ 200: { "content": { @@ -91,7 +88,7 @@ def _synthesis_morphing( str | SkipJsonSchema[None], Query(description="AivisSpeech Engine ではサポートされていないパラメータです (常に無視されます) 。"), ] = None, # fmt: skip # noqa - ) -> FileResponse: + ) -> Response: """ 指定された 2 種類のスタイルで音声を合成、指定した割合でモーフィングした音声を得ます。
モーフィングの割合は `morph_rate` で指定でき、0.0 でベースのスタイル、1.0 でターゲットのスタイルに近づきます。
@@ -127,18 +124,14 @@ def _synthesis_morphing( output_stereo=query.outputStereo, ) - with NamedTemporaryFile(delete=False) as f: - soundfile.write( - file=f, - data=morph_wave, - samplerate=query.outputSamplingRate, - format="WAV", - ) - - return FileResponse( - f.name, - media_type="audio/wav", - background=BackgroundTask(try_delete_file, f.name), + buffer = io.BytesIO() + soundfile.write( + file=buffer, + data=morph_wave, + samplerate=query.outputSamplingRate, + format="WAV", ) + return Response(buffer.getvalue(), media_type="audio/wav") + return router diff --git a/voicevox_engine/app/routers/tts_pipeline.py b/voicevox_engine/app/routers/tts_pipeline.py index 380600c..2499fcd 100644 --- a/voicevox_engine/app/routers/tts_pipeline.py +++ b/voicevox_engine/app/routers/tts_pipeline.py @@ -1,15 +1,13 @@ """音声合成機能を提供する API Router""" +import io import zipfile -from tempfile import NamedTemporaryFile, TemporaryFile from typing import Annotated, Self import soundfile -from fastapi import APIRouter, HTTPException, Query, Request +from fastapi import APIRouter, HTTPException, Query, Request, Response from pydantic import BaseModel, Field from pydantic.json_schema import SkipJsonSchema -from starlette.background import BackgroundTask -from starlette.responses import FileResponse from voicevox_engine.cancellable_engine import CancellableEngine from voicevox_engine.core.core_adapter import DeviceSupport @@ -32,7 +30,6 @@ Score, ) from voicevox_engine.tts_pipeline.tts_engine import LATEST_VERSION, TTSEngineManager -from voicevox_engine.utility.file_utility import try_delete_file class ParseKanaBadRequest(BaseModel): @@ -263,7 +260,7 @@ def mora_pitch( @router.post( "/synthesis", - response_class=FileResponse, + response_class=Response, responses={ 200: { "content": { @@ -285,7 +282,7 @@ def synthesis( str | SkipJsonSchema[None], Query(description="AivisSpeech Engine ではサポートされていないパラメータです (常に無視されます) 。"), ] = None, # fmt: skip # noqa - ) -> FileResponse: + ) -> Response: """ 指定されたスタイル ID に紐づく音声合成モデルを用いて音声合成を行います。 """ @@ -295,20 +292,16 @@ def synthesis( query, style_id, enable_interrogative_upspeak=enable_interrogative_upspeak ) - with NamedTemporaryFile(delete=False) as f: - soundfile.write( - file=f, data=wave, samplerate=query.outputSamplingRate, format="WAV" - ) - - return FileResponse( - f.name, - media_type="audio/wav", - background=BackgroundTask(try_delete_file, f.name), + buffer = io.BytesIO() + soundfile.write( + file=buffer, data=wave, samplerate=query.outputSamplingRate, format="WAV" ) + return Response(buffer.getvalue(), media_type="audio/wav") + @router.post( "/cancellable_synthesis", - response_class=FileResponse, + response_class=Response, responses={ 200: { "content": { @@ -328,7 +321,7 @@ def cancellable_synthesis( str | SkipJsonSchema[None], Query(description="AivisSpeech Engine ではサポートされていないパラメータです (常に無視されます) 。"), ] = None, # fmt: skip # noqa - ) -> FileResponse: + ) -> Response: raise HTTPException( status_code=501, detail="Cancelable synthesis is not supported in AivisSpeech Engine.", @@ -359,7 +352,7 @@ def cancellable_synthesis( @router.post( "/multi_synthesis", - response_class=FileResponse, + response_class=Response, responses={ 200: { "content": { @@ -379,36 +372,33 @@ def multi_synthesis( str | SkipJsonSchema[None], Query(description="AivisSpeech Engine ではサポートされていないパラメータです (常に無視されます) 。"), ] = None, # fmt: skip # noqa - ) -> FileResponse: + ) -> Response: version = core_version or LATEST_VERSION engine = tts_engines.get_engine(version) sampling_rate = queries[0].outputSamplingRate - with NamedTemporaryFile(delete=False) as f: - with zipfile.ZipFile(f, mode="a") as zip_file: - for i in range(len(queries)): - if queries[i].outputSamplingRate != sampling_rate: - raise HTTPException( - status_code=422, - detail="サンプリングレートが異なるクエリがあります", - ) - - with TemporaryFile() as wav_file: - wave = engine.synthesize_wave(queries[i], style_id) - soundfile.write( - file=wav_file, - data=wave, - samplerate=sampling_rate, - format="WAV", - ) - wav_file.seek(0) - zip_file.writestr(f"{str(i + 1).zfill(3)}.wav", wav_file.read()) + buffer = io.BytesIO() + with zipfile.ZipFile(buffer, mode="a") as zip_file: + for i in range(len(queries)): + if queries[i].outputSamplingRate != sampling_rate: + raise HTTPException( + status_code=422, + detail="サンプリングレートが異なるクエリがあります", + ) + + wav_file_buffer = io.BytesIO() + wave = engine.synthesize_wave(queries[i], style_id) + soundfile.write( + file=wav_file_buffer, + data=wave, + samplerate=sampling_rate, + format="WAV", + ) + zip_file.writestr( + f"{str(i + 1).zfill(3)}.wav", wav_file_buffer.getvalue() + ) - return FileResponse( - f.name, - media_type="application/zip", - background=BackgroundTask(try_delete_file, f.name), - ) + return Response(buffer.getvalue(), media_type="application/zip") @router.post( "/sing_frame_audio_query", @@ -483,7 +473,7 @@ def sing_frame_volume( @router.post( "/frame_synthesis", - response_class=FileResponse, + response_class=Response, responses={ 200: { "content": { @@ -501,7 +491,7 @@ def frame_synthesis( str | SkipJsonSchema[None], Query(description="AivisSpeech Engine ではサポートされていないパラメータです (常に無視されます) 。"), ] = None, # fmt: skip # noqa - ) -> FileResponse: + ) -> Response: # """ # 歌唱音声合成を行います。 # """ @@ -517,21 +507,17 @@ def frame_synthesis( except TalkSingInvalidInputError as e: raise HTTPException(status_code=400, detail=str(e)) - with NamedTemporaryFile(delete=False) as f: - soundfile.write( - file=f, data=wave, samplerate=query.outputSamplingRate, format="WAV" - ) - - return FileResponse( - f.name, - media_type="audio/wav", - background=BackgroundTask(try_delete_file, f.name), + buffer = io.BytesIO() + soundfile.write( + file=buffer, data=wave, samplerate=query.outputSamplingRate, format="WAV" ) + + return Response(buffer.getvalue(), media_type="audio/wav") """ @router.post( "/connect_waves", - response_class=FileResponse, + response_class=Response, responses={ 200: { "content": { @@ -542,7 +528,7 @@ def frame_synthesis( tags=["音声合成"], summary="base64エンコードされた複数のwavデータを一つに結合する", ) - def connect_waves(waves: list[str]) -> FileResponse: + def connect_waves(waves: list[str]) -> Response: """ base64エンコードされたwavデータを一纏めにし、wavファイルで返します。 """ @@ -551,20 +537,16 @@ def connect_waves(waves: list[str]) -> FileResponse: except ConnectBase64WavesException as err: raise HTTPException(status_code=422, detail=str(err)) - with NamedTemporaryFile(delete=False) as f: - soundfile.write( - file=f, - data=waves_nparray, - samplerate=sampling_rate, - format="WAV", - ) - - return FileResponse( - f.name, - media_type="audio/wav", - background=BackgroundTask(try_delete_file, f.name), + buffer = io.BytesIO() + soundfile.write( + file=buffer, + data=waves_nparray, + samplerate=sampling_rate, + format="WAV", ) + return Response(buffer.getvalue(), media_type="audio/wav") + @router.post( "/validate_kana", tags=["クエリ作成"],