From abcfa78cfdd8c9af4e248075f9332759bff829f7 Mon Sep 17 00:00:00 2001 From: tarepan Date: Sat, 29 Jun 2024 15:23:47 +0900 Subject: [PATCH] =?UTF-8?q?=E6=95=B4=E7=90=86:=20`GET=20/supported=5Fdevic?= =?UTF-8?q?es`=20API=20=E3=82=92=20`tts=5Fpipeline`=20router=20=E3=81=B8?= =?UTF-8?q?=E7=A7=BB=E5=8B=95=20(#1444)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit refactor: `supported_devices` API を `tts_pipeline` router へ移動 --- voicevox_engine/app/application.py | 4 +- voicevox_engine/app/routers/engine_info.py | 42 +-------------------- voicevox_engine/app/routers/tts_pipeline.py | 33 +++++++++++++++- 3 files changed, 35 insertions(+), 44 deletions(-) diff --git a/voicevox_engine/app/application.py b/voicevox_engine/app/application.py index a3c57b6af..14fc89d21 100644 --- a/voicevox_engine/app/application.py +++ b/voicevox_engine/app/application.py @@ -94,9 +94,7 @@ def _get_core_characters(version: str | None) -> list[CoreCharacter]: generate_library_router(library_manager, verify_mutability_allowed) ) app.include_router(generate_user_dict_router(user_dict, verify_mutability_allowed)) - app.include_router( - generate_engine_info_router(core_version_list, tts_engines, engine_manifest) - ) + app.include_router(generate_engine_info_router(core_version_list, engine_manifest)) app.include_router( generate_setting_router( setting_loader, engine_manifest.brand_name, verify_mutability_allowed diff --git a/voicevox_engine/app/routers/engine_info.py b/voicevox_engine/app/routers/engine_info.py index caefe3e40..294d2591e 100644 --- a/voicevox_engine/app/routers/engine_info.py +++ b/voicevox_engine/app/routers/engine_info.py @@ -1,40 +1,13 @@ """エンジンの情報機能を提供する API Router""" -from typing import Self - -from fastapi import APIRouter, HTTPException -from pydantic import BaseModel, Field -from pydantic.json_schema import SkipJsonSchema +from fastapi import APIRouter from voicevox_engine import __version__ -from voicevox_engine.core.core_adapter import DeviceSupport from voicevox_engine.engine_manifest import EngineManifest -from voicevox_engine.tts_pipeline.tts_engine import LATEST_VERSION, TTSEngineManager - - -class SupportedDevicesInfo(BaseModel): - """ - 対応しているデバイスの情報 - """ - - cpu: bool = Field(description="CPUに対応しているか") - cuda: bool = Field(description="CUDA(Nvidia GPU)に対応しているか") - dml: bool = Field(description="DirectML(Nvidia GPU/Radeon GPU等)に対応しているか") - - @classmethod - def generate_from(cls, device_support: DeviceSupport) -> Self: - """`DeviceSupport` インスタンスからこのインスタンスを生成する。""" - return cls( - cpu=device_support.cpu, - cuda=device_support.cuda, - dml=device_support.dml, - ) def generate_engine_info_router( - core_version_list: list[str], - tts_engine_manager: TTSEngineManager, - engine_manifest_data: EngineManifest, + core_version_list: list[str], engine_manifest_data: EngineManifest ) -> APIRouter: """エンジン情報 API Router を生成する""" router = APIRouter(tags=["その他"]) @@ -49,17 +22,6 @@ async def core_versions() -> list[str]: """利用可能なコアのバージョン一覧を取得します。""" return core_version_list - @router.get("/supported_devices") - def supported_devices( - core_version: str | SkipJsonSchema[None] = None, - ) -> SupportedDevicesInfo: - """対応デバイスの一覧を取得します。""" - version = core_version or LATEST_VERSION - supported_devices = tts_engine_manager.get_engine(version).supported_devices - if supported_devices is None: - raise HTTPException(status_code=422, detail="非対応の機能です。") - return SupportedDevicesInfo.generate_from(supported_devices) - @router.get("/engine_manifest") async def engine_manifest() -> EngineManifest: """エンジンマニフェストを取得します。""" diff --git a/voicevox_engine/app/routers/tts_pipeline.py b/voicevox_engine/app/routers/tts_pipeline.py index 6555c844c..989dd70c5 100644 --- a/voicevox_engine/app/routers/tts_pipeline.py +++ b/voicevox_engine/app/routers/tts_pipeline.py @@ -2,7 +2,7 @@ import zipfile from tempfile import NamedTemporaryFile, TemporaryFile -from typing import Annotated +from typing import Annotated, Self import soundfile from fastapi import APIRouter, HTTPException, Query, Request @@ -15,6 +15,7 @@ CancellableEngine, CancellableEngineInternalError, ) +from voicevox_engine.core.core_adapter import DeviceSupport from voicevox_engine.metas.Metas import StyleId from voicevox_engine.model import AudioQuery from voicevox_engine.preset.preset_manager import ( @@ -63,6 +64,25 @@ def __init__(self, err: ParseKanaError): super().__init__(text=err.text, error_name=err.errname, error_args=err.kwargs) +class SupportedDevicesInfo(BaseModel): + """ + 対応しているデバイスの情報 + """ + + cpu: bool = Field(description="CPUに対応しているか") + cuda: bool = Field(description="CUDA(Nvidia GPU)に対応しているか") + dml: bool = Field(description="DirectML(Nvidia GPU/Radeon GPU等)に対応しているか") + + @classmethod + def generate_from(cls, device_support: DeviceSupport) -> Self: + """`DeviceSupport` インスタンスからこのインスタンスを生成する。""" + return cls( + cpu=device_support.cpu, + cuda=device_support.cuda, + dml=device_support.dml, + ) + + def generate_tts_pipeline_router( tts_engines: TTSEngineManager, preset_manager: PresetManager, @@ -543,4 +563,15 @@ def is_initialized_speaker( engine = tts_engines.get_engine(version) return engine.is_synthesis_initialized(style_id) + @router.get("/supported_devices", tags=["その他"]) + def supported_devices( + core_version: str | SkipJsonSchema[None] = None, + ) -> SupportedDevicesInfo: + """対応デバイスの一覧を取得します。""" + version = core_version or LATEST_VERSION + supported_devices = tts_engines.get_engine(version).supported_devices + if supported_devices is None: + raise HTTPException(status_code=422, detail="非対応の機能です。") + return SupportedDevicesInfo.generate_from(supported_devices) + return router