Refactor: use BytesIO to return the synthesized WAV audio as an in-memory response instead of writing it to a temporary file
If anything, it's a mystery why it wasn't implemented this way from the start…
tsukumijima committed Nov 29, 2024
1 parent 3cadf77 commit 3c36dc4
Showing 2 changed files with 63 additions and 88 deletions.
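At its core, the change replaces the NamedTemporaryFile + FileResponse + BackgroundTask(try_delete_file, ...) round-trip with an io.BytesIO buffer: soundfile writes the WAV into the buffer, and its bytes go straight into a FastAPI Response. The following is a minimal, self-contained sketch of that pattern; the endpoint path and the dummy sine-wave data are illustrative stand-ins, not code from this repository.

import io

import numpy as np
import soundfile
from fastapi import FastAPI, Response

app = FastAPI()


@app.post("/synthesis_example")
def synthesis_example() -> Response:
    # Stand-in for the engine's synthesized waveform: 1 second of a 440 Hz sine.
    sampling_rate = 24000
    t = np.linspace(0, 1, sampling_rate, endpoint=False)
    wave = 0.1 * np.sin(2 * np.pi * 440 * t)

    # Encode the waveform as WAV directly into an in-memory buffer.
    buffer = io.BytesIO()
    soundfile.write(file=buffer, data=wave, samplerate=sampling_rate, format="WAV")

    # Return the raw bytes; no temporary file on disk, no BackgroundTask cleanup.
    return Response(buffer.getvalue(), media_type="audio/wav")

Compared with FileResponse on a temporary file, this keeps the whole clip in memory (a non-issue for short TTS output) and drops the temp-file bookkeeping, including the risk of orphaned files if the background deletion task never gets to run.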
31 changes: 12 additions & 19 deletions voicevox_engine/app/routers/morphing.py
@@ -1,14 +1,12 @@
"""モーフィング機能を提供する API Router"""

import io
from functools import lru_cache
from tempfile import NamedTemporaryFile
from typing import Annotated

import soundfile
from fastapi import APIRouter, HTTPException, Query
from fastapi import APIRouter, HTTPException, Query, Response
from pydantic.json_schema import SkipJsonSchema
from starlette.background import BackgroundTask
from starlette.responses import FileResponse

from voicevox_engine.aivm_manager import AivmManager
from voicevox_engine.metas.Metas import StyleId
@@ -24,7 +22,6 @@
)
from voicevox_engine.morphing.morphing import synthesize_morphed_wave
from voicevox_engine.tts_pipeline.tts_engine import LATEST_VERSION, TTSEngineManager
from voicevox_engine.utility.file_utility import try_delete_file

# キャッシュを有効化
# モジュール側でlru_cacheを指定するとキャッシュを制御しにくいため、HTTPサーバ側で指定する
@@ -72,7 +69,7 @@ def morphable_targets(

@router.post(
"/synthesis_morphing",
response_class=FileResponse,
response_class=Response,
responses={
200: {
"content": {
@@ -91,7 +88,7 @@ def _synthesis_morphing(
str | SkipJsonSchema[None],
Query(description="AivisSpeech Engine ではサポートされていないパラメータです (常に無視されます) 。"),
] = None, # fmt: skip # noqa
) -> FileResponse:
) -> Response:
"""
指定された 2 種類のスタイルで音声を合成、指定した割合でモーフィングした音声を得ます。<br>
モーフィングの割合は `morph_rate` で指定でき、0.0 でベースのスタイル、1.0 でターゲットのスタイルに近づきます。<br>
@@ -127,18 +124,14 @@ def _synthesis_morphing(
output_stereo=query.outputStereo,
)

with NamedTemporaryFile(delete=False) as f:
soundfile.write(
file=f,
data=morph_wave,
samplerate=query.outputSamplingRate,
format="WAV",
)

return FileResponse(
f.name,
media_type="audio/wav",
background=BackgroundTask(try_delete_file, f.name),
buffer = io.BytesIO()
soundfile.write(
file=buffer,
data=morph_wave,
samplerate=query.outputSamplingRate,
format="WAV",
)

return Response(buffer.getvalue(), media_type="audio/wav")

return router
120 changes: 51 additions & 69 deletions voicevox_engine/app/routers/tts_pipeline.py
@@ -1,15 +1,13 @@
"""音声合成機能を提供する API Router"""

import io
import zipfile
from tempfile import NamedTemporaryFile, TemporaryFile
from typing import Annotated, Self

import soundfile
from fastapi import APIRouter, HTTPException, Query, Request
from fastapi import APIRouter, HTTPException, Query, Request, Response
from pydantic import BaseModel, Field
from pydantic.json_schema import SkipJsonSchema
from starlette.background import BackgroundTask
from starlette.responses import FileResponse

from voicevox_engine.cancellable_engine import CancellableEngine
from voicevox_engine.core.core_adapter import DeviceSupport
@@ -32,7 +30,6 @@
Score,
)
from voicevox_engine.tts_pipeline.tts_engine import LATEST_VERSION, TTSEngineManager
from voicevox_engine.utility.file_utility import try_delete_file


class ParseKanaBadRequest(BaseModel):
@@ -263,7 +260,7 @@ def mora_pitch(

@router.post(
"/synthesis",
response_class=FileResponse,
response_class=Response,
responses={
200: {
"content": {
@@ -285,7 +282,7 @@ def synthesis(
str | SkipJsonSchema[None],
Query(description="AivisSpeech Engine ではサポートされていないパラメータです (常に無視されます) 。"),
] = None, # fmt: skip # noqa
) -> FileResponse:
) -> Response:
"""
指定されたスタイル ID に紐づく音声合成モデルを用いて音声合成を行います。
"""
@@ -295,20 +292,16 @@ def synthesis(
query, style_id, enable_interrogative_upspeak=enable_interrogative_upspeak
)

with NamedTemporaryFile(delete=False) as f:
soundfile.write(
file=f, data=wave, samplerate=query.outputSamplingRate, format="WAV"
)

return FileResponse(
f.name,
media_type="audio/wav",
background=BackgroundTask(try_delete_file, f.name),
buffer = io.BytesIO()
soundfile.write(
file=buffer, data=wave, samplerate=query.outputSamplingRate, format="WAV"
)

return Response(buffer.getvalue(), media_type="audio/wav")

@router.post(
"/cancellable_synthesis",
response_class=FileResponse,
response_class=Response,
responses={
200: {
"content": {
@@ -328,7 +321,7 @@ def cancellable_synthesis(
str | SkipJsonSchema[None],
Query(description="AivisSpeech Engine ではサポートされていないパラメータです (常に無視されます) 。"),
] = None, # fmt: skip # noqa
) -> FileResponse:
) -> Response:
raise HTTPException(
status_code=501,
detail="Cancelable synthesis is not supported in AivisSpeech Engine.",
@@ -359,7 +352,7 @@ def cancellable_synthesis(

@router.post(
"/multi_synthesis",
response_class=FileResponse,
response_class=Response,
responses={
200: {
"content": {
@@ -379,36 +372,33 @@ def multi_synthesis(
str | SkipJsonSchema[None],
Query(description="AivisSpeech Engine ではサポートされていないパラメータです (常に無視されます) 。"),
] = None, # fmt: skip # noqa
) -> FileResponse:
) -> Response:
version = core_version or LATEST_VERSION
engine = tts_engines.get_engine(version)
sampling_rate = queries[0].outputSamplingRate

with NamedTemporaryFile(delete=False) as f:
with zipfile.ZipFile(f, mode="a") as zip_file:
for i in range(len(queries)):
if queries[i].outputSamplingRate != sampling_rate:
raise HTTPException(
status_code=422,
detail="サンプリングレートが異なるクエリがあります",
)

with TemporaryFile() as wav_file:
wave = engine.synthesize_wave(queries[i], style_id)
soundfile.write(
file=wav_file,
data=wave,
samplerate=sampling_rate,
format="WAV",
)
wav_file.seek(0)
zip_file.writestr(f"{str(i + 1).zfill(3)}.wav", wav_file.read())
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, mode="a") as zip_file:
for i in range(len(queries)):
if queries[i].outputSamplingRate != sampling_rate:
raise HTTPException(
status_code=422,
detail="サンプリングレートが異なるクエリがあります",
)

wav_file_buffer = io.BytesIO()
wave = engine.synthesize_wave(queries[i], style_id)
soundfile.write(
file=wav_file_buffer,
data=wave,
samplerate=sampling_rate,
format="WAV",
)
zip_file.writestr(
f"{str(i + 1).zfill(3)}.wav", wav_file_buffer.getvalue()
)

return FileResponse(
f.name,
media_type="application/zip",
background=BackgroundTask(try_delete_file, f.name),
)
return Response(buffer.getvalue(), media_type="application/zip")

@router.post(
"/sing_frame_audio_query",
@@ -483,7 +473,7 @@ def sing_frame_volume(

@router.post(
"/frame_synthesis",
response_class=FileResponse,
response_class=Response,
responses={
200: {
"content": {
@@ -501,7 +491,7 @@ def frame_synthesis(
str | SkipJsonSchema[None],
Query(description="AivisSpeech Engine ではサポートされていないパラメータです (常に無視されます) 。"),
] = None, # fmt: skip # noqa
) -> FileResponse:
) -> Response:
# """
# 歌唱音声合成を行います。
# """
@@ -517,21 +507,17 @@ def frame_synthesis(
except TalkSingInvalidInputError as e:
raise HTTPException(status_code=400, detail=str(e))
with NamedTemporaryFile(delete=False) as f:
soundfile.write(
file=f, data=wave, samplerate=query.outputSamplingRate, format="WAV"
)
return FileResponse(
f.name,
media_type="audio/wav",
background=BackgroundTask(try_delete_file, f.name),
buffer = io.BytesIO()
soundfile.write(
file=buffer, data=wave, samplerate=query.outputSamplingRate, format="WAV"
)
return Response(buffer.getvalue(), media_type="audio/wav")
"""

@router.post(
"/connect_waves",
response_class=FileResponse,
response_class=Response,
responses={
200: {
"content": {
@@ -542,7 +528,7 @@ def frame_synthesis(
tags=["音声合成"],
summary="base64エンコードされた複数のwavデータを一つに結合する",
)
def connect_waves(waves: list[str]) -> FileResponse:
def connect_waves(waves: list[str]) -> Response:
"""
base64エンコードされたwavデータを一纏めにし、wavファイルで返します。
"""
@@ -551,20 +537,16 @@ def connect_waves(waves: list[str]) -> FileResponse:
except ConnectBase64WavesException as err:
raise HTTPException(status_code=422, detail=str(err))

with NamedTemporaryFile(delete=False) as f:
soundfile.write(
file=f,
data=waves_nparray,
samplerate=sampling_rate,
format="WAV",
)

return FileResponse(
f.name,
media_type="audio/wav",
background=BackgroundTask(try_delete_file, f.name),
buffer = io.BytesIO()
soundfile.write(
file=buffer,
data=waves_nparray,
samplerate=sampling_rate,
format="WAV",
)

return Response(buffer.getvalue(), media_type="audio/wav")

@router.post(
"/validate_kana",
tags=["クエリ作成"],