Skip to content

Commit

Permalink
BasePhoneme 不使用メソッドの削除 (#782)
Browse files Browse the repository at this point in the history
* Refactor BasePhoneme by removing unused methods

* Refactor BasePhoneme test by removing unused attr

* Refactor unused imports
  • Loading branch information
tarepan authored Nov 26, 2023
1 parent 44dc4b5 commit bad1209
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 158 deletions.
72 changes: 0 additions & 72 deletions test/test_acoustic_feature_extractor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import os
from pathlib import Path
from typing import List, Type
from unittest import TestCase

from voicevox_engine.acoustic_feature_extractor import BasePhoneme, OjtPhoneme
Expand All @@ -13,32 +10,6 @@ def setUp(self):
self.base_hello_hiho = [
BasePhoneme(s, i, i + 1) for i, s in enumerate(self.str_hello_hiho.split())
]
self.lab_str = """
0.00 1.00 pau
1.00 2.00 k
2.00 3.00 o
3.00 4.00 N
4.00 5.00 n
5.00 6.00 i
6.00 7.00 ch
7.00 8.00 i
8.00 9.00 w
9.00 10.00 a
10.00 11.00 pau
11.00 12.00 h
12.00 13.00 i
13.00 14.00 h
14.00 15.00 o
15.00 16.00 d
16.00 17.00 e
17.00 18.00 s
18.00 19.00 U
19.00 20.00 pau
""".replace(
" ", ""
)[
1:-1
] # ダブルクオーテーションx3で囲われている部分で、空白をすべて置き換え、先頭と最後の"\n"を除外する

def test_repr_(self):
self.assertEqual(
Expand All @@ -53,34 +24,6 @@ def test_convert(self):
with self.assertRaises(NotImplementedError):
BasePhoneme.convert(self.base_hello_hiho)

def test_duration(self):
self.assertEqual(self.base_hello_hiho[1].duration, 1)

def test_parse(self):
parse_str_1 = "0 1 pau"
parse_str_2 = "32.67543 33.48933 e"
parsed_base_1 = BasePhoneme.parse(parse_str_1)
parsed_base_2 = BasePhoneme.parse(parse_str_2)
self.assertEqual(parsed_base_1.phoneme, "pau")
self.assertEqual(parsed_base_1.start, 0.0)
self.assertEqual(parsed_base_1.end, 1.0)
self.assertEqual(parsed_base_2.phoneme, "e")
self.assertEqual(parsed_base_2.start, 32.68)
self.assertEqual(parsed_base_2.end, 33.49)

def lab_test_base(
self,
file_path: str,
phonemes: List["BasePhoneme"],
phoneme_class: Type["BasePhoneme"],
):
phoneme_class.save_lab_list(phonemes, Path(file_path))
with open(file_path, mode="r") as f:
self.assertEqual(f.read(), self.lab_str)
result_phoneme = phoneme_class.load_lab_list(Path(file_path))
self.assertEqual(result_phoneme, phonemes)
os.remove(file_path)


class TestOjtPhoneme(TestBasePhoneme):
def setUp(self):
Expand Down Expand Up @@ -118,10 +61,6 @@ def test_equal(self):
self.assertFalse(self.ojt_hello_hiho[9] == false_ojt_phoneme_1)
self.assertFalse(self.ojt_hello_hiho[9] == false_ojt_phoneme_2)

def test_verify(self):
for phoneme in self.ojt_hello_hiho:
phoneme.verify()

def test_phoneme_id(self):
ojt_str_hello_hiho = " ".join([str(p.phoneme_id) for p in self.ojt_hello_hiho])
self.assertEqual(
Expand Down Expand Up @@ -157,14 +96,3 @@ def test_onehot(self):
self.assertEqual(phoneme.onehot[j], True)
else:
self.assertEqual(phoneme.onehot[j], False)

def test_parse(self):
parse_str_1 = "0 1 pau"
parse_str_2 = "32.67543 33.48933 e"
parsed_ojt_1 = OjtPhoneme.parse(parse_str_1)
parsed_ojt_2 = OjtPhoneme.parse(parse_str_2)
self.assertEqual(parsed_ojt_1.phoneme_id, 0)
self.assertEqual(parsed_ojt_2.phoneme_id, 14)

def tes_lab_list(self):
self.lab_test_base("./ojt_lab_test", self.ojt_hello_hiho, OjtPhoneme)
86 changes: 0 additions & 86 deletions voicevox_engine/acoustic_feature_extractor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from abc import abstractmethod
from pathlib import Path
from typing import List, Sequence

import numpy
Expand Down Expand Up @@ -41,12 +40,6 @@ def __eq__(self, o: object):
self.phoneme == o.phoneme and self.start == o.start and self.end == o.end
)

def verify(self):
"""
音素クラスとして、データが正しいかassertする
"""
assert self.phoneme in self.phoneme_list, f"{self.phoneme} is not defined."

@property
def phoneme_id(self):
"""
Expand All @@ -58,17 +51,6 @@ def phoneme_id(self):
"""
return self.phoneme_list.index(self.phoneme)

@property
def duration(self):
"""
音素継続期間を取得する
Returns
-------
duration : int
音素継続期間を返す
"""
return self.end - self.start

@property
def onehot(self):
"""
Expand All @@ -82,79 +64,11 @@ def onehot(self):
array[self.phoneme_id] = True
return array

@classmethod
def parse(cls, s: str):
"""
文字列をパースして音素クラスを作る
Parameters
----------
s : str
パースしたい文字列
Returns
-------
phoneme : BasePhoneme
パース結果を用いた音素クラスを返す
Examples
--------
>>> BasePhoneme.parse('1.7425000 1.9125000 o:')
Phoneme(phoneme='o:', start=1.74, end=1.91)
"""
words = s.split()
return cls(
start=float(words[0]),
end=float(words[1]),
phoneme=words[2],
)

@classmethod
@abstractmethod
def convert(cls, phonemes: List["BasePhoneme"]) -> List["BasePhoneme"]:
raise NotImplementedError

@classmethod
def load_lab_list(cls, path: Path):
"""
labファイルを読み込む
Parameters
----------
path : Path
読み込みたいlabファイルのパス
Returns
-------
phonemes : List[BasePhoneme]
パース結果を用いた音素クラスを返す
"""
phonemes = [cls.parse(s) for s in path.read_text().split("\n") if len(s) > 0]
phonemes = cls.convert(phonemes)

for phoneme in phonemes:
phoneme.verify()
return phonemes

@classmethod
def save_lab_list(cls, phonemes: List["BasePhoneme"], path: Path):
"""
音素クラスのリストをlabファイル形式で保存する
Parameters
----------
phonemes : List[BasePhoneme]
保存したい音素クラスのリスト
path : Path
labファイルの保存先パス
"""
text = "\n".join(
[
f"{numpy.round(p.start, decimals=2):.2f}\t"
f"{numpy.round(p.end, decimals=2):.2f}\t"
f"{p.phoneme}"
for p in phonemes
]
)
path.write_text(text)


class OjtPhoneme(BasePhoneme):
"""
Expand Down

0 comments on commit bad1209

Please sign in to comment.