diff --git a/README.md b/README.md index c69149e..76b7369 100644 --- a/README.md +++ b/README.md @@ -105,12 +105,31 @@ In [3]: pyopenjtalk.g2p("こんにちは", kana=True) Out[3]: 'コンニチワ' ``` +### About the `run_marine` option + +Since v0.3.0, the `run_marine` option is available for estimating Japanese accents with a DNN-based method (see [marine](https://github.com/6gsn/marine)). To use this feature, install pyopenjtalk with the `marine` extra as shown below: + +```shell +pip install pyopenjtalk[marine] +``` + +Then you can use the option as in the following examples: + +```python +In [1]: import pyopenjtalk + +In [2]: x, sr = pyopenjtalk.tts("おめでとうございます", run_marine=True) # for TTS + +In [3]: label = pyopenjtalk.extract_fullcontext("こんにちは", run_marine=True) # for text processing frontend only +``` + ## LICENSE - pyopenjtalk: MIT license ([LICENSE.md](LICENSE.md)) - Open JTalk: Modified BSD license ([COPYING](https://github.com/r9y9/open_jtalk/blob/1.10/src/COPYING)) - htsvoice in this repository: Please check [pyopenjtalk/htsvoice/README.md](pyopenjtalk/htsvoice/README.md). +- marine: Apache 2.0 license ([LICENSE](https://github.com/6gsn/marine/blob/main/LICENSE)) ## Acknowledgements diff --git a/docs/changelog.rst b/docs/changelog.rst index be60a9a..e75eee2 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -6,6 +6,7 @@ v0.3.0 <2022-xx-xx> Newer numpy (>v1.20.0) is required to avoid ABI compatibility issues. Please check the updated installation guide. +* `#40`_: Introduce marine for Japanese accent estimation. Note that this may be a breaking change for `run_frontend`, because this PR changed the behavior of that API. * `#35`_: Fixes for Python 3.10. v0.2.0 <2022-02-06> @@ -90,3 +91,4 @@ Initial release with OpenJTalk's text processsing functionality .. _#27: https://github.com/r9y9/pyopenjtalk/issues/27 .. _#29: https://github.com/r9y9/pyopenjtalk/pull/29 .. _#35: https://github.com/r9y9/pyopenjtalk/pull/35 +.. 
_#40: https://github.com/r9y9/pyopenjtalk/pull/40 diff --git a/docs/pyopenjtalk.rst b/docs/pyopenjtalk.rst index 370d2f5..5e03e7b 100644 --- a/docs/pyopenjtalk.rst +++ b/docs/pyopenjtalk.rst @@ -25,3 +25,5 @@ Misc ---- .. autofunction:: run_frontend +.. autofunction:: make_label +.. autofunction:: estimate_accent diff --git a/pyopenjtalk/__init__.py b/pyopenjtalk/__init__.py index a266104..299ac29 100644 --- a/pyopenjtalk/__init__.py +++ b/pyopenjtalk/__init__.py @@ -19,6 +19,7 @@ from .htsengine import HTSEngine from .openjtalk import OpenJTalk +from .utils import merge_njd_marine_features # Dictionary directory # defaults to the package directory where the dictionary will be automatically downloaded @@ -39,6 +40,8 @@ # Global instance of HTSEngine # mei_normal.voice is used as default _global_htsengine = None +# Global instance of Marine +_global_marine = None # https://github.com/tqdm/tqdm#hooks-and-callbacks @@ -98,18 +101,53 @@ def g2p(*args, **kwargs): return _global_jtalk.g2p(*args, **kwargs) -def extract_fullcontext(text): +def estimate_accent(njd_features): + """Accent estimation using marine + + This function requires marine (https://github.com/6gsn/marine) + + Args: + njd_features (list): features generated by OpenJTalk. + + Returns: + list: features for NJDNode with estimation results by marine. 
+ """ + global _global_marine + if _global_marine is None: + try: + from marine.predict import Predictor + except BaseException: + raise ImportError( + "Please install marine by `pip install pyopenjtalk[marine]`" + ) + _global_marine = Predictor() + from marine.utils.openjtalk_util import convert_njd_feature_to_marine_feature + + marine_feature = convert_njd_feature_to_marine_feature(njd_features) + marine_results = _global_marine.predict( + [marine_feature], require_open_jtalk_format=True + ) + njd_features = merge_njd_marine_features(njd_features, marine_results) + return njd_features + + +def extract_fullcontext(text, run_marine=False): """Extract full-context labels from text Args: text (str): Input text + run_marine (bool): Whether to estimate accent using marine. + Default is False. If you want to activate this option, you need to install marine + by `pip install pyopenjtalk[marine]` Returns: list: List of full-context labels """ - # note: drop first return - _, labels = run_frontend(text) - return labels + + njd_features = run_frontend(text) + if run_marine: + njd_features = estimate_accent(njd_features) + return make_label(njd_features) def synthesize(labels, speed=1.0, half_tone=0.0): @@ -136,34 +174,53 @@ def synthesize(labels, speed=1.0, half_tone=0.0): return _global_htsengine.synthesize(labels), sr -def tts(text, speed=1.0, half_tone=0.0): +def tts(text, speed=1.0, half_tone=0.0, run_marine=False): """Text-to-speech Args: text (str): Input text speed (float): speech speed rate. Default is 1.0. half_tone (float): additional half-tone. Default is 0. + run_marine (bool): Whether to estimate accent using marine. + Default is False. 
If you want to activate this option, you need to install marine + by `pip install pyopenjtalk[marine]` Returns: np.ndarray: speech waveform (dtype: np.float64) int: sampling frequency (defualt: 48000) """ - return synthesize(extract_fullcontext(text), speed, half_tone) + return synthesize( + extract_fullcontext(text, run_marine=run_marine), speed, half_tone + ) -def run_frontend(text, verbose=0): +def run_frontend(text): """Run OpenJTalk's text processing frontend Args: text (str): Unicode Japanese text. - verbose (int): Verbosity. Default is 0. Returns: - tuple: Pair of 1) NJD_print and 2) JPCommon_make_label. - The latter is the full-context labels in HTS-style format. + list: features for NJDNode. + """ + global _global_jtalk + if _global_jtalk is None: + _lazy_init() + _global_jtalk = OpenJTalk(dn_mecab=OPEN_JTALK_DICT_DIR) + return _global_jtalk.run_frontend(text) + + +def make_label(njd_features): + """Make full-context label using features + + Args: + njd_features (list): features for NJDNode. + + Returns: + list: full-context labels. 
""" global _global_jtalk if _global_jtalk is None: _lazy_init() _global_jtalk = OpenJTalk(dn_mecab=OPEN_JTALK_DICT_DIR) - return _global_jtalk.run_frontend(text, verbose) + return _global_jtalk.make_label(njd_features) diff --git a/pyopenjtalk/openjtalk.pyx b/pyopenjtalk/openjtalk.pyx index 46e6004..f28509f 100644 --- a/pyopenjtalk/openjtalk.pyx +++ b/pyopenjtalk/openjtalk.pyx @@ -8,6 +8,7 @@ cimport numpy as np np.import_array() cimport cython +from libc.stdlib cimport calloc from openjtalk.mecab cimport Mecab, Mecab_initialize, Mecab_load, Mecab_analysis from openjtalk.mecab cimport Mecab_get_feature, Mecab_get_size, Mecab_refresh, Mecab_clear @@ -64,32 +65,57 @@ cdef njd_node_get_chain_flag(_njd.NJDNode* node): return _njd.NJDNode_get_chain_flag(node) -cdef njd_node_print(_njd.NJDNode* node): - return "{},{},{},{},{},{},{},{},{},{},{}/{},{},{}".format( - njd_node_get_string(node), - njd_node_get_pos(node), - njd_node_get_pos_group1(node), - njd_node_get_pos_group2(node), - njd_node_get_pos_group3(node), - njd_node_get_ctype(node), - njd_node_get_cform(node), - njd_node_get_orig(node), - njd_node_get_read(node), - njd_node_get_pron(node), - njd_node_get_acc(node), - njd_node_get_mora_size(node), - njd_node_get_chain_rule(node), - njd_node_get_chain_flag(node) - ) - - -cdef njd_print(_njd.NJD* njd): +cdef node2feature(_njd.NJDNode* node): + return { + "string": njd_node_get_string(node), + "pos": njd_node_get_pos(node), + "pos_group1": njd_node_get_pos_group1(node), + "pos_group2": njd_node_get_pos_group2(node), + "pos_group3": njd_node_get_pos_group3(node), + "ctype": njd_node_get_ctype(node), + "cform": njd_node_get_cform(node), + "orig": njd_node_get_orig(node), + "read": njd_node_get_read(node), + "pron": njd_node_get_pron(node), + "acc": njd_node_get_acc(node), + "mora_size": njd_node_get_mora_size(node), + "chain_rule": njd_node_get_chain_rule(node), + "chain_flag": njd_node_get_chain_flag(node), + } + + +cdef njd2feature(_njd.NJD* njd): cdef _njd.NJDNode* 
node = njd.head - njd_results = [] + features = [] while node is not NULL: - njd_results.append(njd_node_print(node)) + features.append(node2feature(node)) node = node.next - return njd_results + return features + + +cdef feature2njd(_njd.NJD* njd, features): + cdef _njd.NJDNode* node + + for feature_node in features: + node = <_njd.NJDNode *> calloc(1, sizeof(_njd.NJDNode)) + _njd.NJDNode_initialize(node) + # set values + _njd.NJDNode_set_string(node, feature_node["string"].encode("utf-8")) + _njd.NJDNode_set_pos(node, feature_node["pos"].encode("utf-8")) + _njd.NJDNode_set_pos_group1(node, feature_node["pos_group1"].encode("utf-8")) + _njd.NJDNode_set_pos_group2(node, feature_node["pos_group2"].encode("utf-8")) + _njd.NJDNode_set_pos_group3(node, feature_node["pos_group3"].encode("utf-8")) + _njd.NJDNode_set_ctype(node, feature_node["ctype"].encode("utf-8")) + _njd.NJDNode_set_cform(node, feature_node["cform"].encode("utf-8")) + _njd.NJDNode_set_orig(node, feature_node["orig"].encode("utf-8")) + _njd.NJDNode_set_read(node, feature_node["read"].encode("utf-8")) + _njd.NJDNode_set_pron(node, feature_node["pron"].encode("utf-8")) + _njd.NJDNode_set_acc(node, feature_node["acc"]) + _njd.NJDNode_set_mora_size(node, feature_node["mora_size"]) + _njd.NJDNode_set_chain_rule(node, feature_node["chain_rule"].encode("utf-8")) + _njd.NJDNode_set_chain_flag(node, feature_node["chain_flag"]) + _njd.NJD_push_node(njd, node) + cdef class OpenJTalk(object): """OpenJTalk @@ -125,12 +151,13 @@ cdef class OpenJTalk(object): return Mecab_load(self.mecab, dn_mecab) - def run_frontend(self, text, verbose=0): + def run_frontend(self, text): """Run OpenJTalk's text processing frontend """ - if isinstance(text, str): - text = text.encode("utf-8") cdef char buff[8192] + + if isinstance(text, str): + text = text.encode("utf-8") text2mecab(buff, text) Mecab_analysis(self.mecab, buff) mecab2njd(self.njd, Mecab_get_feature(self.mecab), Mecab_get_size(self.mecab)) @@ -140,7 +167,20 @@ cdef 
class OpenJTalk(object): _njd.njd_set_accent_type(self.njd) _njd.njd_set_unvoiced_vowel(self.njd) _njd.njd_set_long_vowel(self.njd) + features = njd2feature(self.njd) + + # Note that this will release memory for njd feature + NJD_refresh(self.njd) + Mecab_refresh(self.mecab) + + return features + + def make_label(self, features): + """Make full-context label + """ + feature2njd(self.njd, features) njd2jpcommon(self.jpcommon, self.njd) + JPCommon_make_label(self.jpcommon) cdef int label_size = JPCommon_get_label_size(self.jpcommon) @@ -153,23 +193,19 @@ cdef class OpenJTalk(object): # http://cython.readthedocs.io/en/latest/src/tutorial/strings.html labels.append(label_feature[i]) - njd_results = njd_print(self.njd) - - if verbose > 0: - NJD_print(self.njd) - # Note that this will release memory for label feature JPCommon_refresh(self.jpcommon) NJD_refresh(self.njd) - Mecab_refresh(self.mecab) - return njd_results, labels + return labels def g2p(self, text, kana=False, join=True): """Grapheme-to-phoeneme (G2P) conversion """ - njd_results, labels = self.run_frontend(text) + njd_features = self.run_frontend(text) + if not kana: + labels = self.make_label(njd_features) prons = list(map(lambda s: s.split("-")[1].split("+")[0], labels[1:-1])) if join: prons = " ".join(prons) @@ -177,12 +213,11 @@ cdef class OpenJTalk(object): # kana prons = [] - for n in njd_results: - row = n.split(",") - if row[1] == "記号": - p = row[0] + for n in njd_features: + if n["pos"] == "記号": + p = n["string"] else: - p = row[9] + p = n["pron"] # remove special chars for c in "’": p = p.replace(c,"") diff --git a/pyopenjtalk/utils.py b/pyopenjtalk/utils.py new file mode 100644 index 0000000..7aeb1bf --- /dev/null +++ b/pyopenjtalk/utils.py @@ -0,0 +1,21 @@ +def merge_njd_marine_features(njd_features, marine_results): + features = [] + + marine_accs = marine_results["accent_status"] + marine_chain_flags = marine_results["accent_phrase_boundary"] + + assert ( + len(njd_features) == 
len(marine_accs) == len(marine_chain_flags) + ), "Invalid sequence sizes in njd_results, marine_results" + + for node_index, njd_feature in enumerate(njd_features): + _feature = {} + for feature_key in njd_feature.keys(): + if feature_key == "acc": + _feature["acc"] = int(marine_accs[node_index]) + elif feature_key == "chain_flag": + _feature[feature_key] = int(marine_chain_flags[node_index]) + else: + _feature[feature_key] = njd_feature[feature_key] + features.append(_feature) + return features diff --git a/setup.py b/setup.py index c68db9c..62ec246 100644 --- a/setup.py +++ b/setup.py @@ -303,6 +303,7 @@ def run(self): "types-decorator", ], "test": ["pytest", "scipy"], + "marine": ["marine>=0.0.5"], }, classifiers=[ "Operating System :: POSIX", diff --git a/tests/test_openjtalk.py b/tests/test_openjtalk.py index 56b4a80..0f43363 100644 --- a/tests/test_openjtalk.py +++ b/tests/test_openjtalk.py @@ -1,10 +1,9 @@ import pyopenjtalk -def _print_results(njd_results, labels): - for n in njd_results: - row = n.split(",") - s, p = row[0], row[9] +def _print_results(njd_features, labels): + for f in njd_features: + s, p = f["string"], f["pron"] print(s, p) for label in labels: @@ -12,12 +11,37 @@ def _print_results(njd_results, labels): def test_hello(): - njd_results, labels = pyopenjtalk.run_frontend("こんにちは") - _print_results(njd_results, labels) + njd_features = pyopenjtalk.run_frontend("こんにちは") + labels = pyopenjtalk.make_label(njd_features) + _print_results(njd_features, labels) + + +def test_njd_features(): + njd_features = pyopenjtalk.run_frontend("こんにちは") + expected_feature = [ + { + "string": "こんにちは", + "pos": "感動詞", + "pos_group1": "*", + "pos_group2": "*", + "pos_group3": "*", + "ctype": "*", + "cform": "*", + "orig": "こんにちは", + "read": "コンニチハ", + "pron": "コンニチワ", + "acc": 0, + "mora_size": 5, + "chain_rule": "-1", + "chain_flag": -1, + } + ] + assert njd_features == expected_feature def test_fullcontext(): - _, labels = pyopenjtalk.run_frontend("こんにちは") + 
features = pyopenjtalk.run_frontend("こんにちは") + labels = pyopenjtalk.make_label(features) labels2 = pyopenjtalk.extract_fullcontext("こんにちは") for a, b in zip(labels, labels2): assert a == b @@ -30,10 +54,11 @@ def test_jtalk(): "どんまい!", "パソコンのとりあえず知っておきたい使い方", ]: - njd_results, labels = pyopenjtalk.run_frontend(text) - _print_results(njd_results, labels) + njd_features = pyopenjtalk.run_frontend(text) + labels = pyopenjtalk.make_label(njd_features) + _print_results(njd_features, labels) - surface = "".join(map(lambda s: s.split(",")[0], njd_results)) + surface = "".join(map(lambda f: f["string"], njd_features)) assert surface == text