Skip to content

Commit

Permalink
✨ tts v2 api #187
Browse files Browse the repository at this point in the history
  • Loading branch information
zhzLuke96 committed Nov 10, 2024
1 parent 13d4408 commit 27f5dfc
Show file tree
Hide file tree
Showing 29 changed files with 391 additions and 203 deletions.
16 changes: 16 additions & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,22 @@ XTTS v2 系列 API 支持一些外部应用的调用,例如与 SillyTavern 的

![sillytavern_tts](./sillytavern_tts.png)

### Forge Api v2
v2 接口是为了完整提供所有功能而开发的,可以对所有可配置项进行设置;同时还支持直接传入 base64 编码的音频作为参考音频,创建临时 spk 用于推理。

```bash
curl http://localhost:7870/v2/tts \
-H "Authorization: Bearer anything_you_want" \
-H "Content-Type: application/json" \
-d '{
"text": "Today is a wonderful day to build something people love!"
}' \
--output speech.mp3
```

#### usage
> WIP
## playground

启动 api 服务之后,在 `/playground` 下有一个非 gradio 的调试页面用于接口测试
Expand Down
7 changes: 7 additions & 0 deletions modules/api/api_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
vc_api,
xtts_v2_api,
)
from modules.api.v2 import (
tts_api,
#
)
from modules.utils import env

logger = logging.getLogger(__name__)
Expand All @@ -41,6 +45,9 @@ def create_api(app: FastAPI, exclude=[]):
stt_api.setup(app_mgr)
vc_api.setup(app_mgr)

# v2 apis
tts_api.setup(app_mgr)

return app_mgr


Expand Down
59 changes: 18 additions & 41 deletions modules/api/impl/google_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from modules.core.handler.datacls.enhancer_model import EnhancerConfig
from modules.core.handler.datacls.tts_model import InferConfig, TTSConfig
from modules.core.handler.datacls.vc_model import VCConfig
from modules.core.handler.SSMLHandler import SSMLHandler
from modules.core.handler.TTSHandler import TTSHandler
from modules.core.spk.SpkMgr import spk_mgr
from modules.core.spk.TTSSpeaker import TTSSpeaker
Expand Down Expand Up @@ -139,47 +138,22 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
bitrate=audio_bitrate,
)

text_content = input.text
ssml_content = input.ssml
handler = TTSHandler(
text_content=text_content,
spk=speaker,
tts_config=tts_config,
infer_config=infer_config,
adjust_config=adjust_config,
enhancer_config=enhancer_config,
encoder_config=encoder_config,
vc_config=VCConfig(enabled=False),
)
try:
if input.text:
text_content = input.text

handler = TTSHandler(
text_content=text_content,
spk=speaker,
tts_config=tts_config,
infer_config=infer_config,
adjust_config=adjust_config,
enhancer_config=enhancer_config,
encoder_config=encoder_config,
vc_config=VCConfig(enabled=False),
)
media_type = handler.get_media_type()

base64_string = handler.enqueue_to_base64()
return {"audioContent": f"data:{media_type};base64,{base64_string}"}

elif input.ssml:
ssml_content = input.ssml

handler = SSMLHandler(
ssml_content=ssml_content,
tts_config=tts_config,
infer_config=infer_config,
adjust_config=adjust_config,
enhancer_config=enhancer_config,
encoder_config=encoder_config,
)
media_type = handler.get_media_type()

base64_string = handler.enqueue_to_base64()

return {"audioContent": f"data:{media_type};base64,{base64_string}"}

else:
raise HTTPException(
status_code=422, detail="Invalid input text or ssml specified."
)

media_type = handler.get_media_type()
base64_string = handler.enqueue_to_base64()
return {"audioContent": f"data:{media_type};base64,{base64_string}"}
except Exception as e:
import logging

Expand All @@ -206,12 +180,14 @@ def setup(app: APIManager):
- 编码格式影响的是 audioContent 的二进制格式,所以所有format都是返回带有base64数据的json
""",
tags=["Google API"],
)(google_text_synthesize)

@app.post(
"/v1/speech:recognize",
# response_model=None,
description="Performs synchronous speech recognition: receive results after all audio has been sent and processed.",
tags=["Google API"],
)
async def speech_recognize():
raise HTTPException(status_code=501, detail="Not implemented")
Expand All @@ -220,6 +196,7 @@ async def speech_recognize():
"/v1/speech:longrunningrecognize",
# response_model=None,
description="Performs asynchronous speech recognition: receive results via the google.longrunning.Operations interface.",
tags=["Google API"],
)
async def long_running_recognize():
raise HTTPException(status_code=501, detail="Not implemented")
10 changes: 7 additions & 3 deletions modules/api/impl/models_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,21 @@

def setup(app: APIManager):

@app.get("/v1/models/reload", response_model=api_utils.BaseResponse)
@app.get(
"/v1/models/reload", response_model=api_utils.BaseResponse, tags=["Models"]
)
async def reload_models():
zoo.model_zoo.reload_all_models()
return api_utils.success_response("Models reloaded")

@app.get("/v1/models/unload", response_model=api_utils.BaseResponse)
@app.get(
"/v1/models/unload", response_model=api_utils.BaseResponse, tags=["Models"]
)
async def unload_models():
zoo.model_zoo.unload_all_models()
return api_utils.success_response("Models unloaded")

@app.get("/v1/models/list", response_model=api_utils.BaseResponse)
@app.get("/v1/models/list", response_model=api_utils.BaseResponse, tags=["Models"])
async def unload_models():
model_ids = zoo.model_zoo.get_model_ids()
return api_utils.success_response(model_ids)
2 changes: 2 additions & 0 deletions modules/api/impl/openai_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def setup(app: APIManager):
> model 可填任意值
""",
tags=["OpenAI API"],
)(openai_speech_api)

def pydub_to_numpy(audio_segment: AudioSegment) -> np.ndarray:
Expand All @@ -206,6 +207,7 @@ def pydub_to_numpy(audio_segment: AudioSegment) -> np.ndarray:
# NOTE: 其实最好是不设置这个model...因为这个接口可以返回很多情况...
# response_model=TranscriptionsResponse,
description="Transcribes audio into the input language.",
tags=["OpenAI API"],
)
async def transcribe(
file: UploadFile = File(...),
Expand Down
18 changes: 11 additions & 7 deletions modules/api/impl/refiner_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,14 @@ async def text_normalize_post(request: TextNormalizeRequest):


def setup(app: APIManager):
app.post("/v1/prompt/refine", response_model=api_utils.BaseResponse)(
refiner_prompt_post
)

app.post("/v1/text/normalize", response_model=api_utils.BaseResponse)(
text_normalize_post
)
app.post(
"/v1/prompt/refine",
response_model=api_utils.BaseResponse,
tags=["Text Refiner"],
)(refiner_prompt_post)

app.post(
"/v1/text/normalize",
response_model=api_utils.BaseResponse,
tags=["Text Refiner"],
)(text_normalize_post)
24 changes: 18 additions & 6 deletions modules/api/impl/speaker_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ def setup(app: APIManager):
class SpkListParams(BaseModel):
full_data: bool = Query(False, description="Return all data")

@app.get("/v1/speakers/list", response_model=api_utils.BaseResponse)
@app.get(
"/v1/speakers/list", response_model=api_utils.BaseResponse, tags=["Speaker"]
)
async def list_speakers(
request: Request,
parmas: SpkListParams = Depends(),
Expand All @@ -50,12 +52,16 @@ async def list_speakers(

return api_utils.success_response(data)

@app.post("/v1/speakers/refresh", response_model=api_utils.BaseResponse)
@app.post(
"/v1/speakers/refresh", response_model=api_utils.BaseResponse, tags=["Speaker"]
)
async def refresh_speakers():
spk_mgr.refresh()
return api_utils.success_response(None)

@app.post("/v1/speakers/update", response_model=api_utils.BaseResponse)
@app.post(
"/v1/speakers/update", response_model=api_utils.BaseResponse, tags=["Speaker"]
)
async def update_speakers(request: SpeakersUpdate):
for config in request.speakers:
config: dict = config
Expand Down Expand Up @@ -85,7 +91,9 @@ async def update_speakers(request: SpeakersUpdate):
return api_utils.success_response(None)

# TODO 需要适配新版本 speaker
@app.post("/v1/speaker/create", response_model=api_utils.BaseResponse)
@app.post(
"/v1/speaker/create", response_model=api_utils.BaseResponse, tags=["Speaker"]
)
async def create_speaker(request: CreateSpeaker):
if (
request.tensor
Expand All @@ -111,7 +119,9 @@ async def create_speaker(request: CreateSpeaker):
spk_mgr.refresh()
return api_utils.success_response(spk.to_json())

@app.post("/v1/speaker/update", response_model=api_utils.BaseResponse)
@app.post(
"/v1/speaker/update", response_model=api_utils.BaseResponse, tags=["Speaker"]
)
async def update_speaker(request: UpdateSpeaker):
speaker = spk_mgr.get_speaker_by_id(request.id)
if speaker is None:
Expand All @@ -132,7 +142,9 @@ async def update_speaker(request: UpdateSpeaker):
spk_mgr.update_item(speaker, lambda x: x.id == speaker.id)
return api_utils.success_response(None)

@app.post("/v1/speaker/detail", response_model=api_utils.BaseResponse)
@app.post(
"/v1/speaker/detail", response_model=api_utils.BaseResponse, tags=["Speaker"]
)
async def speaker_detail(request: SpeakerDetail):
speaker = spk_mgr.get_speaker_by_id(request.id)
if speaker is None:
Expand Down
9 changes: 5 additions & 4 deletions modules/api/impl/ssml_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
)
from modules.core.handler.datacls.enhancer_model import EnhancerConfig
from modules.core.handler.datacls.tts_model import InferConfig, TTSConfig
from modules.core.handler.SSMLHandler import SSMLHandler
from modules.core.handler.TTSHandler import TTSHandler


class SSMLRequest(BaseModel):
Expand Down Expand Up @@ -76,10 +76,9 @@ async def synthesize_ssml_api(
format=AudioFormat(format),
bitrate="64k",
)
# TODO: 作为 SSML 默认值
tts_config = TTSConfig(mid=model)

handler = SSMLHandler(
handler = TTSHandler(
ssml_content=ssml,
tts_config=tts_config,
infer_config=infer_config,
Expand All @@ -102,4 +101,6 @@ async def synthesize_ssml_api(


def setup(api_manager: APIManager):
api_manager.post("/v1/ssml", response_class=FileResponse)(synthesize_ssml_api)
api_manager.post("/v1/ssml", response_class=FileResponse, tags=["SSML"])(
synthesize_ssml_api
)
2 changes: 2 additions & 0 deletions modules/api/impl/stt_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def pydub_to_numpy(audio_segment: AudioSegment) -> np.ndarray:
"/v1/stt/transcribe",
description="Transcribes audio into the input language.",
response_model=TranscriptionsResponse,
tags=["STT"],
)
async def transcribe(
form_data: TranscriptionsForm = Depends(TranscriptionsForm.as_form),
Expand Down Expand Up @@ -170,6 +171,7 @@ async def transcribe(
@app.post(
"/v1/stt/stream",
description="Transcribes audio into the input language in real-time.",
tags=["STT"],
)
async def transcribe_stream(
form_data: TranscriptionsForm = Depends(TranscriptionsForm.as_form),
Expand Down
4 changes: 3 additions & 1 deletion modules/api/impl/style_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,6 @@ async def create_style():


def setup(app: APIManager):
app.get("/v1/styles/list", response_model=api_utils.BaseResponse)(list_styles)
app.get("/v1/styles/list", response_model=api_utils.BaseResponse, tags=["Style"])(
list_styles
)
9 changes: 6 additions & 3 deletions modules/api/impl/sys_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@


def setup(app: APIManager):
@app.get("/v1/ping", response_model=api_utils.BaseResponse)

@app.get("/v1/ping", response_model=api_utils.BaseResponse, tags=["System"])
async def ping():
return api_utils.success_response("pong")

@app.get("/v1/versions", response_model=api_utils.BaseResponse)
@app.get("/v1/versions", response_model=api_utils.BaseResponse, tags=["System"])
async def get_versions():
return api_utils.success_response(config.versions.to_dict())

@app.get("/v1/audio_formats", response_model=api_utils.BaseResponse)
@app.get(
"/v1/audio_formats", response_model=api_utils.BaseResponse, tags=["System"]
)
async def get_audio_formats():
return api_utils.success_response([e.value for e in AudioFormat])
26 changes: 5 additions & 21 deletions modules/api/impl/tts_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@ class TTSParams(BaseModel):
spk: str = Query(
"female2", description="Specific speaker by speaker name or speaker seed"
)
ref_spk: Optional[str] = Query(
None, description="Specific speaker by speaker name or speaker seed"
)

style: str = Query("chat", description="Specific style by style name")
temperature: float = Query(
Expand Down Expand Up @@ -158,8 +155,6 @@ async def synthesize_tts(request: Request, params: TTSParams = Depends()):
else params.no_cache == "on"
)

ref_spk = params.ref_spk

if eos == "[uv_break]" and params.model != "chat-tts":
eos = " "

Expand Down Expand Up @@ -200,21 +195,8 @@ async def synthesize_tts(request: Request, params: TTSParams = Depends()):
bitrate=params.bitrate,
)

vc_config = VCConfig()
has_ref_spk = ref_spk is not None and isinstance(ref_spk, str) and ref_spk != ""
if has_ref_spk:
vc_config.enabled = True
try:
vc_config.spk = spk_mgr.get_speaker(ref_spk)
except Exception as e:
raise HTTPException(status_code=422, detail=str(e))

if not vc_config.spk.has_refs:
raise HTTPException(
status_code=422,
detail='Invalid "ref_spk", speaker has no refs data',
)

# NOTE: 这个接口不再支持 voice clone
vc_config = VCConfig(enabled=False)
handler = TTSHandler(
text_content=params.text,
spk=spk,
Expand All @@ -239,4 +221,6 @@ async def synthesize_tts(request: Request, params: TTSParams = Depends()):


def setup(api_manager: APIManager):
api_manager.get("/v1/tts", response_class=FileResponse)(synthesize_tts)
api_manager.get("/v1/tts", response_class=FileResponse, tags=["TTS"])(
synthesize_tts
)
Loading

0 comments on commit 27f5dfc

Please sign in to comment.