Skip to content

Commit

Permalink
Merge pull request #40 from 6gsn/introduce/marine
Browse files Browse the repository at this point in the history
Introduce marine
  • Loading branch information
r9y9 authored Sep 19, 2022
2 parents d4c59cd + 60d85ff commit c676f56
Show file tree
Hide file tree
Showing 8 changed files with 222 additions and 60 deletions.
19 changes: 19 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,31 @@ In [3]: pyopenjtalk.g2p("こんにちは", kana=True)
Out[3]: 'コンニチワ'
```

### About `run_marine` option

After v0.3.0, the `run_marine` option has been available for estimating the Japanese accent with the DNN-based method (see [marine](https://github.com/6gsn/marine)). If you want to use the feature, please install pyopenjtalk as below;

```shell
pip install pyopenjtalk[marine]
```

And then, you can use the option as the following examples;

```python
In [1]: import pyopenjtalk

In [2]: x, sr = pyopenjtalk.tts("おめでとうございます", run_marine=True) # for TTS

In [3]: label = pyopenjtalk.extract_fullcontext("こんにちは", run_marine=True) # for text processing frontend only
```


## LICENSE

- pyopenjtalk: MIT license ([LICENSE.md](LICENSE.md))
- Open JTalk: Modified BSD license ([COPYING](https://github.com/r9y9/open_jtalk/blob/1.10/src/COPYING))
- htsvoice in this repository: Please check [pyopenjtalk/htsvoice/README.md](pyopenjtalk/htsvoice/README.md).
- marine: Apache 2.0 license ([LICENSE](https://github.com/6gsn/marine/blob/main/LICENSE))

## Acknowledgements

Expand Down
2 changes: 2 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ v0.3.0 <2022-xx-xx>

Newer numpy (>v1.20.0) is required to avoid ABI compatibility issues. Please check the updated installation guide.

* `#40`_: Introduce marine for Japanese accent estimation. Note that there could be a breakpoint regarding `run_frontend` because this PR changed the behavior of the API.
* `#35`_: Fixes for Python 3.10.

v0.2.0 <2022-02-06>
Expand Down Expand Up @@ -90,3 +91,4 @@ Initial release with OpenJTalk's text processsing functionality
.. _#27: https://github.com/r9y9/pyopenjtalk/issues/27
.. _#29: https://github.com/r9y9/pyopenjtalk/pull/29
.. _#35: https://github.com/r9y9/pyopenjtalk/pull/35
.. _#40: https://github.com/r9y9/pyopenjtalk/pull/40
2 changes: 2 additions & 0 deletions docs/pyopenjtalk.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,5 @@ Misc
----

.. autofunction:: run_frontend
.. autofunction:: make_label
.. autofunction:: estimate_accent
79 changes: 68 additions & 11 deletions pyopenjtalk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from .htsengine import HTSEngine
from .openjtalk import OpenJTalk
from .utils import merge_njd_marine_features

# Dictionary directory
# defaults to the package directory where the dictionary will be automatically downloaded
Expand All @@ -39,6 +40,8 @@
# Global instance of HTSEngine
# mei_normal.voice is used as default
_global_htsengine = None
# Global instance of Marine
_global_marine = None


# https://github.com/tqdm/tqdm#hooks-and-callbacks
Expand Down Expand Up @@ -98,18 +101,53 @@ def g2p(*args, **kwargs):
return _global_jtalk.g2p(*args, **kwargs)


def extract_fullcontext(text):
def estimate_accent(njd_features):
"""Accent estimation using marine
This function requires marine (https://github.com/6gsn/marine)
Args:
njd_result (list): features generated by OpenJTalk.
Returns:
list: features for NJDNode with estimation results by marine.
"""
global _global_marine
if _global_marine is None:
try:
from marine.predict import Predictor
except BaseException:
raise ImportError(
"Please install marine by `pip install pyopenjtalk[marine]`"
)
_global_marine = Predictor()
from marine.utils.openjtalk_util import convert_njd_feature_to_marine_feature

marine_feature = convert_njd_feature_to_marine_feature(njd_features)
marine_results = _global_marine.predict(
[marine_feature], require_open_jtalk_format=True
)
njd_features = merge_njd_marine_features(njd_features, marine_results)
return njd_features


def extract_fullcontext(text, run_marine=False):
"""Extract full-context labels from text
Args:
text (str): Input text
run_marine (bool): Whether to estimate accent using marine.
Default is False. If you want to activate this option, you need to install marine
by `pip install pyopenjtalk[marine]`
Returns:
list: List of full-context labels
"""
# note: drop first return
_, labels = run_frontend(text)
return labels

njd_features = run_frontend(text)
if run_marine:
njd_features = estimate_accent(njd_features)
return make_label(njd_features)


def synthesize(labels, speed=1.0, half_tone=0.0):
Expand All @@ -136,34 +174,53 @@ def synthesize(labels, speed=1.0, half_tone=0.0):
return _global_htsengine.synthesize(labels), sr


def tts(text, speed=1.0, half_tone=0.0):
def tts(text, speed=1.0, half_tone=0.0, run_marine=False):
"""Text-to-speech
Args:
text (str): Input text
speed (float): speech speed rate. Default is 1.0.
half_tone (float): additional half-tone. Default is 0.
run_marine (bool): Whether to estimate accent using marine.
Default is False. If you want activate this option, you need to install marine
by `pip install pyopenjtalk[marine]`
Returns:
np.ndarray: speech waveform (dtype: np.float64)
int: sampling frequency (defualt: 48000)
"""
return synthesize(extract_fullcontext(text), speed, half_tone)
return synthesize(
extract_fullcontext(text, run_marine=run_marine), speed, half_tone
)


def run_frontend(text, verbose=0):
def run_frontend(text):
"""Run OpenJTalk's text processing frontend
Args:
text (str): Unicode Japanese text.
verbose (int): Verbosity. Default is 0.
Returns:
tuple: Pair of 1) NJD_print and 2) JPCommon_make_label.
The latter is the full-context labels in HTS-style format.
list: features for NJDNode.
"""
global _global_jtalk
if _global_jtalk is None:
_lazy_init()
_global_jtalk = OpenJTalk(dn_mecab=OPEN_JTALK_DICT_DIR)
return _global_jtalk.run_frontend(text)


def make_label(njd_features):
"""Make full-context label using features
Args:
njd_features (list): features for NJDNode.
Returns:
list: full-context labels.
"""
global _global_jtalk
if _global_jtalk is None:
_lazy_init()
_global_jtalk = OpenJTalk(dn_mecab=OPEN_JTALK_DICT_DIR)
return _global_jtalk.run_frontend(text, verbose)
return _global_jtalk.make_label(njd_features)
113 changes: 74 additions & 39 deletions pyopenjtalk/openjtalk.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ cimport numpy as np
np.import_array()

cimport cython
from libc.stdlib cimport calloc

from openjtalk.mecab cimport Mecab, Mecab_initialize, Mecab_load, Mecab_analysis
from openjtalk.mecab cimport Mecab_get_feature, Mecab_get_size, Mecab_refresh, Mecab_clear
Expand Down Expand Up @@ -64,32 +65,57 @@ cdef njd_node_get_chain_flag(_njd.NJDNode* node):
return _njd.NJDNode_get_chain_flag(node)


cdef njd_node_print(_njd.NJDNode* node):
return "{},{},{},{},{},{},{},{},{},{},{}/{},{},{}".format(
njd_node_get_string(node),
njd_node_get_pos(node),
njd_node_get_pos_group1(node),
njd_node_get_pos_group2(node),
njd_node_get_pos_group3(node),
njd_node_get_ctype(node),
njd_node_get_cform(node),
njd_node_get_orig(node),
njd_node_get_read(node),
njd_node_get_pron(node),
njd_node_get_acc(node),
njd_node_get_mora_size(node),
njd_node_get_chain_rule(node),
njd_node_get_chain_flag(node)
)


cdef njd_print(_njd.NJD* njd):
cdef node2feature(_njd.NJDNode* node):
return {
"string": njd_node_get_string(node),
"pos": njd_node_get_pos(node),
"pos_group1": njd_node_get_pos_group1(node),
"pos_group2": njd_node_get_pos_group2(node),
"pos_group3": njd_node_get_pos_group3(node),
"ctype": njd_node_get_ctype(node),
"cform": njd_node_get_cform(node),
"orig": njd_node_get_orig(node),
"read": njd_node_get_read(node),
"pron": njd_node_get_pron(node),
"acc": njd_node_get_acc(node),
"mora_size": njd_node_get_mora_size(node),
"chain_rule": njd_node_get_chain_rule(node),
"chain_flag": njd_node_get_chain_flag(node),
}


cdef njd2feature(_njd.NJD* njd):
cdef _njd.NJDNode* node = njd.head
njd_results = []
features = []
while node is not NULL:
njd_results.append(njd_node_print(node))
features.append(node2feature(node))
node = node.next
return njd_results
return features


cdef feature2njd(_njd.NJD* njd, features):
cdef _njd.NJDNode* node

for feature_node in features:
node = <_njd.NJDNode *> calloc(1, sizeof(_njd.NJDNode))
_njd.NJDNode_initialize(node)
# set values
_njd.NJDNode_set_string(node, feature_node["string"].encode("utf-8"))
_njd.NJDNode_set_pos(node, feature_node["pos"].encode("utf-8"))
_njd.NJDNode_set_pos_group1(node, feature_node["pos_group1"].encode("utf-8"))
_njd.NJDNode_set_pos_group2(node, feature_node["pos_group2"].encode("utf-8"))
_njd.NJDNode_set_pos_group3(node, feature_node["pos_group3"].encode("utf-8"))
_njd.NJDNode_set_ctype(node, feature_node["ctype"].encode("utf-8"))
_njd.NJDNode_set_cform(node, feature_node["cform"].encode("utf-8"))
_njd.NJDNode_set_orig(node, feature_node["orig"].encode("utf-8"))
_njd.NJDNode_set_read(node, feature_node["read"].encode("utf-8"))
_njd.NJDNode_set_pron(node, feature_node["pron"].encode("utf-8"))
_njd.NJDNode_set_acc(node, feature_node["acc"])
_njd.NJDNode_set_mora_size(node, feature_node["mora_size"])
_njd.NJDNode_set_chain_rule(node, feature_node["chain_rule"].encode("utf-8"))
_njd.NJDNode_set_chain_flag(node, feature_node["chain_flag"])
_njd.NJD_push_node(njd, node)


cdef class OpenJTalk(object):
"""OpenJTalk
Expand Down Expand Up @@ -125,12 +151,13 @@ cdef class OpenJTalk(object):
return Mecab_load(self.mecab, dn_mecab)


def run_frontend(self, text, verbose=0):
def run_frontend(self, text):
"""Run OpenJTalk's text processing frontend
"""
if isinstance(text, str):
text = text.encode("utf-8")
cdef char buff[8192]

if isinstance(text, str):
text = text.encode("utf-8")
text2mecab(buff, text)
Mecab_analysis(self.mecab, buff)
mecab2njd(self.njd, Mecab_get_feature(self.mecab), Mecab_get_size(self.mecab))
Expand All @@ -140,7 +167,20 @@ cdef class OpenJTalk(object):
_njd.njd_set_accent_type(self.njd)
_njd.njd_set_unvoiced_vowel(self.njd)
_njd.njd_set_long_vowel(self.njd)
features = njd2feature(self.njd)

# Note that this will release memory for njd feature
NJD_refresh(self.njd)
Mecab_refresh(self.mecab)

return features

def make_label(self, features):
"""Make full-context label
"""
feature2njd(self.njd, features)
njd2jpcommon(self.jpcommon, self.njd)

JPCommon_make_label(self.jpcommon)

cdef int label_size = JPCommon_get_label_size(self.jpcommon)
Expand All @@ -153,36 +193,31 @@ cdef class OpenJTalk(object):
# http://cython.readthedocs.io/en/latest/src/tutorial/strings.html
labels.append(<unicode>label_feature[i])

njd_results = njd_print(self.njd)

if verbose > 0:
NJD_print(self.njd)

# Note that this will release memory for label feature
JPCommon_refresh(self.jpcommon)
NJD_refresh(self.njd)
Mecab_refresh(self.mecab)

return njd_results, labels
return labels

def g2p(self, text, kana=False, join=True):
"""Grapheme-to-phoeneme (G2P) conversion
"""
njd_results, labels = self.run_frontend(text)
njd_features = self.run_frontend(text)

if not kana:
labels = self.make_label(njd_features)
prons = list(map(lambda s: s.split("-")[1].split("+")[0], labels[1:-1]))
if join:
prons = " ".join(prons)
return prons

# kana
prons = []
for n in njd_results:
row = n.split(",")
if row[1] == "記号":
p = row[0]
for n in njd_features:
if n["pos"] == "記号":
p = n["string"]
else:
p = row[9]
p = n["pron"]
# remove special chars
for c in "":
p = p.replace(c,"")
Expand Down
21 changes: 21 additions & 0 deletions pyopenjtalk/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
def merge_njd_marine_features(njd_features, marine_results):
features = []

marine_accs = marine_results["accent_status"]
marine_chain_flags = marine_results["accent_phrase_boundary"]

assert (
len(njd_features) == len(marine_accs) == len(marine_chain_flags)
), "Invalid sequence sizes in njd_results, marine_results"

for node_index, njd_feature in enumerate(njd_features):
_feature = {}
for feature_key in njd_feature.keys():
if feature_key == "acc":
_feature["acc"] = int(marine_accs[node_index])
elif feature_key == "chain_flag":
_feature[feature_key] = int(marine_chain_flags[node_index])
else:
_feature[feature_key] = njd_feature[feature_key]
features.append(_feature)
return features
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ def run(self):
"types-decorator",
],
"test": ["pytest", "scipy"],
"marine": ["marine>=0.0.5"],
},
classifiers=[
"Operating System :: POSIX",
Expand Down
Loading

0 comments on commit c676f56

Please sign in to comment.