Skip to content

Commit

Permalink
fix types and format
Browse files Browse the repository at this point in the history
  • Loading branch information
tuna2134 committed Aug 18, 2023
1 parent e356ee2 commit 4d22817
Show file tree
Hide file tree
Showing 12 changed files with 78 additions and 104 deletions.
4 changes: 1 addition & 3 deletions examples/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@

async def main():
async with Client() as client:
audio_query = await client.create_audio_query(
"こんにちは!", speaker=1
)
audio_query = await client.create_audio_query("こんにちは!", speaker=1)
with open("voice.wav", "wb") as f:
f.write(await audio_query.synthesis(speaker=1))

Expand Down
4 changes: 1 addition & 3 deletions examples/multi_synthesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@ async def main():
audio_queries = []
async with Client() as client:
for text in ["おはようございます!", "こんにちは!"]:
audio_queries.append(await client.create_audio_query(
text, speaker=1
))
audio_queries.append(await client.create_audio_query(text, speaker=1))
with open("audio.zip", "wb") as f:
f.write(await client.multi_synthesis(audio_queries, speaker=1))

Expand Down
4 changes: 1 addition & 3 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,5 @@
@pytest.mark.asyncio
async def test_basic():
async with Client() as client:
audio_query = await client.create_audio_query(
"こんにちは!", speaker=1
)
audio_query = await client.create_audio_query("こんにちは!", speaker=1)
await audio_query.synthesis(speaker=1)
7 changes: 5 additions & 2 deletions voicevox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
__all__ = (
"Client",
"AudioQuery",
"HttpException", "NotfoundError",
"Speaker", "Style", "SupportedFeature"
"HttpException",
"NotfoundError",
"Speaker",
"Style",
"SupportedFeature",
)
__version__ = "0.2.1"
29 changes: 14 additions & 15 deletions voicevox/audio_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ def __init__(self, payload: MoraType):
self.text: str = payload["text"]
self.consonant: str = payload["consonant"]
self.consonant_length: int = payload["consonant_length"]
self.vowel: int = payload["vowel"]
self.vowel: str = payload["vowel"]
self.vowel_length: int = payload["vowel_length"]
self.pitch: int = payload["pitch"]
self.pitch: float = payload["pitch"]

def to_dict(self) -> dict:
return {
Expand All @@ -41,7 +41,7 @@ def to_dict(self) -> dict:
"consonant_length": self.consonant_length,
"vowel": self.vowel,
"vowel_length": self.vowel_length,
"pitch": self.pitch
"pitch": self.pitch,
}


Expand Down Expand Up @@ -74,7 +74,7 @@ def to_dict(self) -> AccentPhraseType:
payload = {
"moras": [mora.to_dict() for mora in self.moras],
"accent": self.accent,
"is_interrogative": self.is_interrogative
"is_interrogative": self.is_interrogative,
}
if self.pause_mora is not None:
payload["pause_mora"] = self.pause_mora.to_dict()
Expand Down Expand Up @@ -112,15 +112,12 @@ class AudioQuery:
[読み取り専用]AquesTalkライクな読み仮名。音声合成クエリとしては無視される
"""

def __init__(
self, http: HttpClient, payload: AudioQueryType
):
def __init__(self, http: HttpClient, payload: AudioQueryType):
self.http = http
self.__data = payload

self.accent_phrases: List[AccentPhrase] = [
AccentPhrase(accent_phrase)
for accent_phrase in payload["accent_phrases"]
AccentPhrase(accent_phrase) for accent_phrase in payload["accent_phrases"]
]
self.speed_scale: float = payload["speedScale"]
self.pitch_scale: float = payload["pitchScale"]
Expand All @@ -138,8 +135,7 @@ def kana(self) -> str:
def to_dict(self) -> AudioQueryType:
return {
"accent_phrases": [
accent_phrase.to_dict()
for accent_phrase in self.accent_phrases
accent_phrase.to_dict() for accent_phrase in self.accent_phrases
],
"speedScale": self.speed_scale,
"pitchScale": self.pitch_scale,
Expand All @@ -149,16 +145,19 @@ def to_dict(self) -> AudioQueryType:
"postPhonemeLength": self.post_phoneme_length,
"outputSamplingRate": self.output_sampling_rate,
"outputStereo": self.output_stereo,
"kana": self.kana
"kana": self.kana,
}

async def synthesis(
self, *, enable_interrogative_upspeak: bool = True,
speaker: int, core_version: Optional[str] = None
self,
*,
enable_interrogative_upspeak: bool = True,
speaker: int,
core_version: Optional[str] = None,
) -> bytes:
params = {
"speaker": speaker,
"enable_interrogative_upspeak": enable_interrogative_upspeak
"enable_interrogative_upspeak": enable_interrogative_upspeak,
}
if core_version is not None:
params["core_version"] = core_version
Expand Down
60 changes: 27 additions & 33 deletions voicevox/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ class Client:
You can customize timeout. If you use cpu mode, I recommend to use this.
"""

def __init__(self, base_url: str = "http://localhost:50021", timeout: Optional[int] = None):
def __init__(
self, base_url: str = "http://localhost:50021", timeout: Optional[int] = None
):
self.http = HttpClient(base_url=base_url, timeout=timeout)

async def close(self) -> None:
Expand Down Expand Up @@ -70,10 +72,7 @@ async def create_audio_query(
AudioQuery
Audio query, that run synthesis.
"""
params = {
"text": text,
"speaker": speaker
}
params = {"text": text, "speaker": speaker}
if core_version is not None:
params["core_version"] = core_version
audio_query = await self.http.create_audio_query(params)
Expand All @@ -83,10 +82,7 @@ async def create_audio_query(
async def create_audio_query_from_preset(
self, text: str, preset_id: int, *, core_version: Optional[str] = None
) -> AudioQuery:
params = {
"text": text,
"preset_id": preset_id
}
params = {"text": text, "preset_id": preset_id}
if core_version is not None:
params["core_version"] = core_version
audio_query = await self.http.create_audio_query_from_preset(params)
Expand Down Expand Up @@ -116,9 +112,7 @@ async def fetch_core_versions(self) -> List[str]:
"""
return await self.http.get_core_versions()

async def fetch_speakers(
self, core_version: Optional[str] = None
) -> List[Speaker]:
async def fetch_speakers(self, core_version: Optional[str] = None) -> List[Speaker]:
"""Fetch speakers
This can fetch voicevox speakers.
Expand All @@ -130,7 +124,7 @@ async def fetch_speakers(
"""
speakers = await self.http.get_speakers(core_version)
return [Speaker(speaker) for speaker in speakers]

async def fetch_speaker_info(
self, speaker_uuid: str, core_version: Optional[str] = None
) -> SpeakerInfo:
Expand All @@ -150,17 +144,22 @@ async def fetch_speaker_info(
SpeakerInfo
Contains additional information of the speaker.
"""
return SpeakerInfo(await self.http.get_speaker_info(speaker_uuid, core_version))
return SpeakerInfo(await self.http.get_speaker_info(speaker_uuid, core_version))

async def check_devices(self, core_version: Optional[str] = None) -> SupportedDevices:
async def check_devices(
self, core_version: Optional[str] = None
) -> SupportedDevices:
params = {}
if core_version:
params["core_version"] = core_version
return SupportedDevices(await self.http.supported_devices(params))

async def multi_synthesis(
self, audio_queries: List[AudioQuery], speaker: int,
*, core_version: Optional[str] = None
self,
audio_queries: List[AudioQuery],
speaker: int,
*,
core_version: Optional[str] = None
) -> bytes:
"""Multi synthe
Expand All @@ -179,20 +178,18 @@ async def multi_synthesis(
-------
bytes
Return zip file"""
params = {
"speaker": speaker
}
params = {"speaker": speaker}
if core_version is not None:
params["core_version"] = core_version
return await self.http.multi_synthesis(
params, [
audio_query.to_dict()
for audio_query in audio_queries
]
params, [audio_query.to_dict() for audio_query in audio_queries]
)

async def init_speaker(
self, speaker: int, *, skip_reinit: bool = False,
self,
speaker: int,
*,
skip_reinit: bool = False,
core_version: Optional[str] = None
) -> None:
"""Initilize speaker
Expand All @@ -210,15 +207,14 @@ async def init_speaker(
who have already been initialized
core_version: Optional[str]
core version"""
params = {
"speaker": speaker,
"skip_reinit": skip_reinit
}
params = {"speaker": speaker, "skip_reinit": skip_reinit}
if core_version is not None:
params["core_version"] = core_version
await self.http.initialize_speaker(params)

async def check_inited_speaker(self, speaker: int, *, core_version: Optional[str] = None):
async def check_inited_speaker(
self, speaker: int, *, core_version: Optional[str] = None
):
"""Check initialized speaker
Returns whether the speaker with the specified speaker_id is initialized or not.
Expand All @@ -234,9 +230,7 @@ async def check_inited_speaker(self, speaker: int, *, core_version: Optional[str
-------
bool
If initialized speaker, it return `True`."""
params = {
"speaker": speaker
}
params = {"speaker": speaker}
if core_version is not None:
params["core_version"] = core_version
return await self.http.is_initialized_speaker(params)
1 change: 1 addition & 0 deletions voicevox/errors.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# voicevox - errors


class HttpException(Exception):
pass

Expand Down
46 changes: 14 additions & 32 deletions voicevox/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,13 @@

from .errors import NotfoundError, HttpException
from .types import AudioQueryType, SpeakerType
from .types.speaker_info import SpeakerInfoType
from .types.speaker_info import SpeakerInfoType


logger = logging.getLogger(__name__)


class HttpClient:

def __init__(self, base_url: str, timeout: Optional[int] = None):
self.session = AsyncClient(base_url=base_url, timeout=timeout)
logger.debug("Start session.")
Expand All @@ -28,9 +27,10 @@ async def close(self) -> None:
async def request(self, method: str, path: str, **kwargs) -> dict:
logger.debug(f"Request: {method} Path: {path} kwargs: {kwargs}")
response = await self.session.request(method, path, **kwargs)
logger.debug("StatusCode: {0.status_code} Response: {0.content}".format(response))
if response.status_code == 200 or \
response.status_code == 204:
logger.debug(
"StatusCode: {0.status_code} Response: {0.content}".format(response)
)
if response.status_code == 200 or response.status_code == 204:
if response.headers.get("content-type") == "application/json":
return response.json()
else:
Expand All @@ -41,42 +41,26 @@ async def request(self, method: str, path: str, **kwargs) -> dict:
raise HttpException(response.json())

async def synthesis(self, params: dict, payload: dict) -> bytes:
return await self.request(
"POST", "/synthesis", params=params,
json=payload
)
return await self.request("POST", "/synthesis", params=params, json=payload)

async def multi_synthesis(
self, params: dict, payload: List[dict]
) -> bytes:
async def multi_synthesis(self, params: dict, payload: List[dict]) -> bytes:
return await self.request(
"POST", "/multi_synthesis", params=params,
json=payload
"POST", "/multi_synthesis", params=params, json=payload
)

async def create_audio_query(self, params: dict) -> AudioQueryType:
return await self.request(
"POST", "/audio_query", params=params
)
return await self.request("POST", "/audio_query", params=params)

async def create_audio_query_from_preset(
self, params: dict
) -> AudioQueryType:
return await self.request(
"POST", "/audio_query_from_preset", params=params
)
async def create_audio_query_from_preset(self, params: dict) -> AudioQueryType:
return await self.request("POST", "/audio_query_from_preset", params=params)

async def get_version(self) -> str:
return await self.request(
"GET", "/version"
)
return await self.request("GET", "/version")

async def get_core_versions(self) -> List[str]:
return await self.request("GET", "/core_versions")

async def get_speakers(
self, core_version: Optional[str]
) -> List[SpeakerType]:
async def get_speakers(self, core_version: Optional[str]) -> List[SpeakerType]:
params = {}
if core_version is not None:
params["core_version"] = core_version
Expand All @@ -85,9 +69,7 @@ async def get_speakers(
async def get_speaker_info(
self, speaker_uuid: str, core_version: Optional[str]
) -> SpeakerInfoType:
params = {
"speaker_uuid": speaker_uuid
}
params = {"speaker_uuid": speaker_uuid}
if core_version is not None:
params["core_version"] = core_version
return await self.request("GET", "/speaker_info", params=params)
Expand Down
8 changes: 4 additions & 4 deletions voicevox/speaker_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,19 @@ def __init__(self, payload: StyleInfoType):
@property
def id(self) -> int:
return self.__data["id"]

@property
def icon(self) -> str:
return self.__data["icon"]

@property
def portrait(self) -> str:
return self.__data["portrait"]

@property
def voice_samples(self) -> list[str]:
return self.__data["voice_samples"]


class SpeakerInfo:
"""Return speaker info
Expand Down
Loading

0 comments on commit 4d22817

Please sign in to comment.