Skip to content

Commit

Permalink
feat!: [Python] EnumLiteral
Browse files Browse the repository at this point in the history
  • Loading branch information
qryxip committed Jan 26, 2025
1 parent 33578f6 commit 5a70552
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 81 deletions.
91 changes: 40 additions & 51 deletions crates/voicevox_core_python_api/python/voicevox_core/_models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import dataclasses
from enum import Enum
from typing import NewType
from typing import Literal, NewType, TypeAlias
from uuid import UUID

import pydantic
Expand Down Expand Up @@ -34,21 +33,19 @@
x : UUID
"""

StyleType: TypeAlias = Literal["talk", "singing_teacher", "frame_decode", "sing"]
"""
class StyleType(str, Enum):
"""**スタイル** (_style_)に対応するモデルの種類。"""

TALK = "talk"
"""音声合成クエリの作成と音声合成が可能。"""

SINGING_TEACHER = "singing_teacher"
"""歌唱音声合成用のクエリの作成が可能。"""

FRAME_DECODE = "frame_decode"
"""歌唱音声合成が可能。"""
**スタイル** (_style_)に対応するモデルの種類。
SING = "sing"
"""歌唱音声合成用のクエリの作成と歌唱音声合成が可能。"""
===================== ==================================================
値 説明
``"talk"`` 音声合成クエリの作成と音声合成が可能。
``"singing_teacher"`` 歌唱音声合成用のクエリの作成が可能。
``"frame_decode"`` 歌唱音声合成用のクエリの作成が可能。
``"sing"`` 歌唱音声合成用のクエリの作成と歌唱音声合成が可能。
===================== ==================================================
"""


@pydantic.dataclasses.dataclass
Expand All @@ -61,7 +58,7 @@ class StyleMeta:
id: StyleId
"""スタイルID。"""

type: StyleType = dataclasses.field(default=StyleType.TALK)
type: StyleType = dataclasses.field(default="talk")
"""スタイルに対応するモデルの種類。"""

order: int | None = None
Expand Down Expand Up @@ -129,21 +126,17 @@ class SupportedDevices:
"""


class AccelerationMode(str, Enum):
"""
ハードウェアアクセラレーションモードを設定する設定値。
"""

AUTO = "AUTO"
"""
実行環境に合った適切なハードウェアアクセラレーションモードを選択する。
"""

CPU = "CPU"
"""ハードウェアアクセラレーションモードを"CPU"に設定する。"""

GPU = "GPU"
"""ハードウェアアクセラレーションモードを"GPU"に設定する。"""
AccelerationMode: TypeAlias = Literal["AUTO", "CPU", "GPU"]
"""
ハードウェアアクセラレーションモードを設定する設定値。
========== ======================================================================
値 説明
``"AUTO"`` 実行環境に合った適切なハードウェアアクセラレーションモードを選択する。
``"CPU"`` ハードウェアアクセラレーションモードを"CPU"に設定する。
``"GPU"`` ハードウェアアクセラレーションモードを"GPU"に設定する。
========== ======================================================================
"""


@pydantic.dataclasses.dataclass
Expand Down Expand Up @@ -232,23 +225,21 @@ class AudioQuery:
"""


class UserDictWordType(str, Enum):
"""ユーザー辞書の単語の品詞。"""

PROPER_NOUN = "PROPER_NOUN"
"""固有名詞。"""

COMMON_NOUN = "COMMON_NOUN"
"""一般名詞。"""

VERB = "VERB"
"""動詞。"""

ADJECTIVE = "ADJECTIVE"
"""形容詞。"""

SUFFIX = "SUFFIX"
"""語尾。"""
UserDictWordType: TypeAlias = Literal[
"PROPER_NOUN", "COMMON_NOUN", "VERB", "ADJECTIVE", "SUFFIX"
]
"""
ユーザー辞書の単語の品詞。
================= ==========
値 説明
``"PROPER_NOUN"`` 固有名詞。
``"COMMON_NOUN"`` 一般名詞。
``"VERB"`` 動詞。
``"ADJECTIVE"`` 形容詞。
``"SUFFIX"`` 語尾。
================= ==========
"""


@pydantic.dataclasses.dataclass
Expand All @@ -272,9 +263,7 @@ class UserDictWord:
音が下がる場所を指す。
"""

word_type: UserDictWordType = dataclasses.field(
default=UserDictWordType.COMMON_NOUN
)
word_type: UserDictWordType = dataclasses.field(default="COMMON_NOUN")
"""品詞。"""

priority: int = dataclasses.field(default=5)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from os import PathLike
from typing import TYPE_CHECKING, Literal, Union
from typing import TYPE_CHECKING, Union
from uuid import UUID

if TYPE_CHECKING:
Expand Down Expand Up @@ -175,8 +175,7 @@ class Synthesizer:
self,
onnxruntime: Onnxruntime,
open_jtalk: OpenJtalk,
acceleration_mode: AccelerationMode
| Literal["AUTO", "CPU", "GPU"] = AccelerationMode.AUTO,
acceleration_mode: AccelerationMode = "AUTO",
cpu_num_threads: int = 0,
) -> None: ...
def __repr__(self) -> str: ...
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from os import PathLike
from typing import TYPE_CHECKING, Literal, Union
from typing import TYPE_CHECKING, Union
from uuid import UUID

if TYPE_CHECKING:
Expand Down Expand Up @@ -175,8 +175,7 @@ class Synthesizer:
self,
onnxruntime: Onnxruntime,
open_jtalk: OpenJtalk,
acceleration_mode: AccelerationMode
| Literal["AUTO", "CPU", "GPU"] = AccelerationMode.AUTO,
acceleration_mode: AccelerationMode = "AUTO",
cpu_num_threads: int = 0,
) -> None: ...
def __repr__(self) -> str: ...
Expand Down
33 changes: 13 additions & 20 deletions crates/voicevox_core_python_api/src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ use camino::Utf8PathBuf;
use easy_ext::ext;
use pyo3::{
exceptions::{PyException, PyRuntimeError, PyValueError},
types::{IntoPyDict as _, PyList},
types::{IntoPyDict as _, PyList, PyString},
FromPyObject as _, IntoPy, PyAny, PyObject, PyResult, Python, ToPyObject,
};
use serde::{de::DeserializeOwned, Serialize};
use serde_json::json;
use uuid::Uuid;
use voicevox_core::{AccelerationMode, AccentPhrase, StyleId, UserDictWordType, VoiceModelMeta};
use voicevox_core::{AccelerationMode, AccentPhrase, StyleId, VoiceModelMeta};

use crate::{
AnalyzeTextError, GetSupportedDevicesError, GpuSupportError, InitInferenceRuntimeError,
Expand All @@ -21,19 +21,14 @@ use crate::{
};

pub(crate) fn from_acceleration_mode(ob: &PyAny) -> PyResult<AccelerationMode> {
let py = ob.py();

let class = py.import("voicevox_core")?.getattr("AccelerationMode")?;
let mode = class.get_item(ob)?;

if mode.eq(class.getattr("AUTO")?)? {
Ok(AccelerationMode::Auto)
} else if mode.eq(class.getattr("CPU")?)? {
Ok(AccelerationMode::Cpu)
} else if mode.eq(class.getattr("GPU")?)? {
Ok(AccelerationMode::Gpu)
} else {
unreachable!("{} should be one of {{AUTO, CPU, GPU}}", mode.repr()?);
match ob.extract::<&str>()? {
"AUTO" => Ok(AccelerationMode::Auto),
"CPU" => Ok(AccelerationMode::Cpu),
"GPU" => Ok(AccelerationMode::Gpu),
mode => Err(PyValueError::new_err(format!(
"`AccelerationMode` should be one of {{AUTO, CPU, GPU}}: {mode}",
mode = PyString::new(ob.py(), mode).repr()?,
))),
}
}

Expand Down Expand Up @@ -153,7 +148,7 @@ pub(crate) fn to_rust_user_dict_word(ob: &PyAny) -> PyResult<voicevox_core::User
ob.getattr("surface")?.extract()?,
ob.getattr("pronunciation")?.extract()?,
ob.getattr("accent_type")?.extract()?,
to_rust_word_type(ob.getattr("word_type")?.extract()?)?,
from_literal_choice(ob.getattr("word_type")?.extract()?)?,
ob.getattr("priority")?.extract()?,
)
.into_py_result(ob.py())
Expand All @@ -168,10 +163,8 @@ pub(crate) fn to_py_user_dict_word<'py>(
.downcast()?;
to_pydantic_dataclass(word, class)
}
pub(crate) fn to_rust_word_type(word_type: &PyAny) -> PyResult<UserDictWordType> {
let name = word_type.getattr("name")?.extract::<String>()?;

serde_json::from_value::<UserDictWordType>(json!(name)).into_py_value_result()
fn from_literal_choice<T: DeserializeOwned>(s: &str) -> PyResult<T> {
serde_json::from_value::<T>(json!(s)).into_py_value_result()
}

/// おおよそ以下のコードにおける`f(x)`のようなものを得る。
Expand Down
1 change: 1 addition & 0 deletions docs/guide/dev/api-design.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ VOICEVOX CORE の主要機能は Rust で実装されることを前提として
* [`StyleId`](https://voicevox.github.io/voicevox_core/apis/rust_api/voicevox_core/struct.StyleId.html)といった[newtype](https://rust-unofficial.github.io/patterns/patterns/behavioural/newtype.html)は、そのままnewtypeとして表現するべきです。
* 例えばPythonなら[`typing.NewType`](https://docs.python.org/ja/3/library/typing.html#newtype)で表現します。
* オプショナルな引数は、キーワード引数がある言語であればキーワード引数で、ビルダースタイルが一般的な言語であればビルダースタイルで表現すべきです。
* 列挙型は、PythonやTypeScriptでは文字列リテラルの合併型で表現するべきです。
* [`VoiceModelFile`](https://voicevox.github.io/voicevox_core/apis/rust_api/voicevox_core/nonblocking/struct.VoiceModelFile.html)の"close"後でも`id``metas`は利用可能であるべきです。ただしRustにおける"close"だけは、`VoiceModelFile``id``metas`に分解するような形にします。
* `Synthesizer::render``range: std::ops::Range<usize>`を引数に取っています。`Range<usize>`にあたる型が標準で存在し、かつそれが配列の範囲指定として用いられるようなものであれば、それを使うべきです。
* ただし例えばPythonでは、`slice`を引数に取るのは慣習にそぐわないため`start: int, stop: int`のようにすべきです。
Expand Down
10 changes: 8 additions & 2 deletions example/python/run-asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,18 @@ class Args:

@staticmethod
def parse_args() -> "Args":
ACCELERATION_MODE_CHOICES = ("AUTO", "CPU", "GPU")

def _(s: str):
if s in ACCELERATION_MODE_CHOICES:
_: AccelerationMode = s

argparser = ArgumentParser()
argparser.add_argument(
"--mode",
default="AUTO",
type=AccelerationMode,
help='モード ("AUTO", "CPU", "GPU")',
choices=ACCELERATION_MODE_CHOICES,
help="モード",
)
argparser.add_argument(
"vvm",
Expand Down
10 changes: 8 additions & 2 deletions example/python/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,18 @@ class Args:

@staticmethod
def parse_args() -> "Args":
ACCELERATION_MODE_CHOICES = ("AUTO", "CPU", "GPU")

def _(s: str):
if s in ACCELERATION_MODE_CHOICES:
_: AccelerationMode = s

argparser = ArgumentParser()
argparser.add_argument(
"--mode",
default="AUTO",
type=AccelerationMode,
help='モード ("AUTO", "CPU", "GPU")',
choices=ACCELERATION_MODE_CHOICES,
help="モード",
)
argparser.add_argument(
"vvm",
Expand Down

0 comments on commit 5a70552

Please sign in to comment.