diff --git a/pyopenjtalk/htsengine.pyx b/pyopenjtalk/htsengine.pyx index 9936a6e..b8ed5e6 100644 --- a/pyopenjtalk/htsengine.pyx +++ b/pyopenjtalk/htsengine.pyx @@ -2,6 +2,9 @@ # cython: boundscheck=True, wraparound=True # cython: c_string_type=unicode, c_string_encoding=ascii +from contextlib import contextmanager +from threading import RLock + import numpy as np cimport numpy as np @@ -19,6 +22,17 @@ from .htsengine cimport ( HTS_Engine_get_generated_speech, HTS_Engine_get_nsamples ) +def _generate_lock_manager(): + lock = RLock() + + @contextmanager + def f(): + with lock: + yield + + return f + + cdef class HTSEngine(object): """HTSEngine @@ -26,6 +40,7 @@ cdef class HTSEngine(object): voice (bytes): File path of htsvoice. """ cdef HTS_Engine* engine + _lock_manager = _generate_lock_manager() def __cinit__(self, bytes voice=b"htsvoice/mei_normal.htsvoice"): self.engine = new HTS_Engine() @@ -33,24 +48,29 @@ cdef class HTSEngine(object): HTS_Engine_initialize(self.engine) if self.load(voice) != 1: - self.clear() - raise RuntimeError("Failed to initalize HTS_Engine") + self.clear() + raise RuntimeError("Failed to initalize HTS_Engine") + @_lock_manager() def load(self, bytes voice): cdef char* voices = voice cdef char ret - ret = HTS_Engine_load(self.engine, &voices, 1) + with nogil: + ret = HTS_Engine_load(self.engine, &voices, 1) return ret + @_lock_manager() def get_sampling_frequency(self): """Get sampling frequency """ return HTS_Engine_get_sampling_frequency(self.engine) + @_lock_manager() def get_fperiod(self): """Get frame period""" return HTS_Engine_get_fperiod(self.engine) + @_lock_manager() def set_speed(self, speed=1.0): """Set speed @@ -59,6 +79,7 @@ cdef class HTSEngine(object): """ HTS_Engine_set_speed(self.engine, speed) + @_lock_manager() def add_half_tone(self, half_tone=0.0): """Additional half tone in log-f0 @@ -67,6 +88,7 @@ cdef class HTSEngine(object): """ HTS_Engine_add_half_tone(self.engine, half_tone) + @_lock_manager() def 
synthesize(self, list labels): """Synthesize waveform from list of full-context labels @@ -81,34 +103,43 @@ cdef class HTSEngine(object): self.refresh() return x + @_lock_manager() def synthesize_from_strings(self, list labels): """Synthesize from strings""" cdef size_t num_lines = len(labels) cdef char **lines = malloc((num_lines + 1) * sizeof(char*)) - for n in range(len(labels)): + for n in range(num_lines): lines[n] = labels[n] - cdef char ret = HTS_Engine_synthesize_from_strings(self.engine, lines, num_lines) - free(lines) + cdef char ret + with nogil: + ret = HTS_Engine_synthesize_from_strings(self.engine, lines, num_lines) + free(lines) if ret != 1: raise RuntimeError("Failed to run synthesize_from_strings") + @_lock_manager() def get_generated_speech(self): """Get generated speech""" cdef size_t nsamples = HTS_Engine_get_nsamples(self.engine) - cdef np.ndarray speech = np.zeros([nsamples], dtype=np.float64) + cdef np.ndarray speech = np.empty([nsamples], dtype=np.float64) + cdef double[:] speech_view = speech cdef size_t index - for index in range(nsamples): - speech[index] = HTS_Engine_get_generated_speech(self.engine, index) + with (nogil, cython.boundscheck(False)): + for index in range(nsamples): + speech_view[index] = HTS_Engine_get_generated_speech(self.engine, index) return speech + @_lock_manager() def get_fullcontext_label_format(self): """Get full-context label format""" return (HTS_Engine_get_fullcontext_label_format(self.engine)).decode("utf-8") + @_lock_manager() def refresh(self): - HTS_Engine_refresh(self.engine) + HTS_Engine_refresh(self.engine) + @_lock_manager() def clear(self): HTS_Engine_clear(self.engine) diff --git a/pyopenjtalk/htsengine/__init__.pxd b/pyopenjtalk/htsengine/__init__.pxd index c24d959..d8f4fe4 100644 --- a/pyopenjtalk/htsengine/__init__.pxd +++ b/pyopenjtalk/htsengine/__init__.pxd @@ -7,16 +7,16 @@ cdef extern from "HTS_engine.h": ctypedef _HTS_Engine HTS_Engine void HTS_Engine_initialize(HTS_Engine * engine) - char 
HTS_Engine_load(HTS_Engine * engine, char **voices, size_t num_voices) + char HTS_Engine_load(HTS_Engine * engine, char **voices, size_t num_voices) nogil size_t HTS_Engine_get_sampling_frequency(HTS_Engine * engine) size_t HTS_Engine_get_fperiod(HTS_Engine * engine) void HTS_Engine_refresh(HTS_Engine * engine) void HTS_Engine_clear(HTS_Engine * engine) const char *HTS_Engine_get_fullcontext_label_format(HTS_Engine * engine) - char HTS_Engine_synthesize_from_strings(HTS_Engine * engine, char **lines, size_t num_lines) + char HTS_Engine_synthesize_from_strings(HTS_Engine * engine, char **lines, size_t num_lines) nogil char HTS_Engine_synthesize_from_fn(HTS_Engine * engine, const char *fn) - double HTS_Engine_get_generated_speech(HTS_Engine * engine, size_t index) + double HTS_Engine_get_generated_speech(HTS_Engine * engine, size_t index) nogil size_t HTS_Engine_get_nsamples(HTS_Engine * engine) void HTS_Engine_set_speed(HTS_Engine * engine, double f) - void HTS_Engine_add_half_tone(HTS_Engine * engine, double f) \ No newline at end of file + void HTS_Engine_add_half_tone(HTS_Engine * engine, double f) diff --git a/pyopenjtalk/openjtalk.pyx b/pyopenjtalk/openjtalk.pyx index 291311e..0113844 100644 --- a/pyopenjtalk/openjtalk.pyx +++ b/pyopenjtalk/openjtalk.pyx @@ -2,12 +2,9 @@ # cython: boundscheck=True, wraparound=True # cython: c_string_type=unicode, c_string_encoding=ascii -import numpy as np +from contextlib import contextmanager +from threading import Lock -cimport numpy as np -np.import_array() - -cimport cython from libc.stdlib cimport calloc from libc.string cimport strlen @@ -20,7 +17,6 @@ from .openjtalk cimport njd as _njd from .openjtalk.jpcommon cimport JPCommon, JPCommon_initialize,JPCommon_make_label from .openjtalk.jpcommon cimport JPCommon_get_label_size, JPCommon_get_label_feature from .openjtalk.jpcommon cimport JPCommon_refresh, JPCommon_clear -from .openjtalk cimport njd2jpcommon from .openjtalk.text2mecab cimport text2mecab from 
.openjtalk.mecab2njd cimport mecab2njd from .openjtalk.njd2jpcommon cimport njd2jpcommon @@ -55,48 +51,48 @@ cdef njd_node_get_read(_njd.NJDNode* node): cdef njd_node_get_pron(_njd.NJDNode* node): return ((_njd.NJDNode_get_pron(node))).decode("utf-8") -cdef njd_node_get_acc(_njd.NJDNode* node): +cdef int njd_node_get_acc(_njd.NJDNode* node) noexcept: return _njd.NJDNode_get_acc(node) -cdef njd_node_get_mora_size(_njd.NJDNode* node): +cdef int njd_node_get_mora_size(_njd.NJDNode* node) noexcept: return _njd.NJDNode_get_mora_size(node) cdef njd_node_get_chain_rule(_njd.NJDNode* node): return ((_njd.NJDNode_get_chain_rule(node))).decode("utf-8") -cdef njd_node_get_chain_flag(_njd.NJDNode* node): - return _njd.NJDNode_get_chain_flag(node) +cdef int njd_node_get_chain_flag(_njd.NJDNode* node) noexcept: + return _njd.NJDNode_get_chain_flag(node) cdef node2feature(_njd.NJDNode* node): - return { - "string": njd_node_get_string(node), - "pos": njd_node_get_pos(node), - "pos_group1": njd_node_get_pos_group1(node), - "pos_group2": njd_node_get_pos_group2(node), - "pos_group3": njd_node_get_pos_group3(node), - "ctype": njd_node_get_ctype(node), - "cform": njd_node_get_cform(node), - "orig": njd_node_get_orig(node), - "read": njd_node_get_read(node), - "pron": njd_node_get_pron(node), - "acc": njd_node_get_acc(node), - "mora_size": njd_node_get_mora_size(node), - "chain_rule": njd_node_get_chain_rule(node), - "chain_flag": njd_node_get_chain_flag(node), - } + return { + "string": njd_node_get_string(node), + "pos": njd_node_get_pos(node), + "pos_group1": njd_node_get_pos_group1(node), + "pos_group2": njd_node_get_pos_group2(node), + "pos_group3": njd_node_get_pos_group3(node), + "ctype": njd_node_get_ctype(node), + "cform": njd_node_get_cform(node), + "orig": njd_node_get_orig(node), + "read": njd_node_get_read(node), + "pron": njd_node_get_pron(node), + "acc": njd_node_get_acc(node), + "mora_size": njd_node_get_mora_size(node), + "chain_rule": njd_node_get_chain_rule(node), + 
"chain_flag": njd_node_get_chain_flag(node), + } cdef njd2feature(_njd.NJD* njd): cdef _njd.NJDNode* node = njd.head features = [] while node is not NULL: - features.append(node2feature(node)) - node = node.next + features.append(node2feature(node)) + node = node.next return features -cdef feature2njd(_njd.NJD* njd, features): +cdef void feature2njd(_njd.NJD* njd, features): cdef _njd.NJDNode* node for feature_node in features: @@ -120,7 +116,7 @@ cdef feature2njd(_njd.NJD* njd, features): _njd.NJD_push_node(njd, node) # based on Mecab_load in impl. from mecab.cpp -cdef inline int Mecab_load_with_userdic(Mecab *m, char* dicdir, char* userdic): +cdef inline int Mecab_load_with_userdic(Mecab *m, char* dicdir, char* userdic) noexcept nogil: if userdic == NULL or strlen(userdic) == 0: return Mecab_load(m, dicdir) @@ -150,6 +146,16 @@ cdef inline int Mecab_load_with_userdic(Mecab *m, char* dicdir, char* userdic): return 1 +def _generate_lock_manager(): + lock = Lock() + + @contextmanager + def f(): + with lock: + yield + + return f + cdef class OpenJTalk(object): """OpenJTalk @@ -163,31 +169,35 @@ cdef class OpenJTalk(object): cdef Mecab* mecab cdef NJD* njd cdef JPCommon* jpcommon + _lock_manager = _generate_lock_manager() def __cinit__(self, bytes dn_mecab=b"/usr/local/dic", bytes userdic=b""): + cdef char* _dn_mecab = dn_mecab + cdef char* _userdic = userdic + self.mecab = new Mecab() self.njd = new NJD() self.jpcommon = new JPCommon() - Mecab_initialize(self.mecab) - NJD_initialize(self.njd) - JPCommon_initialize(self.jpcommon) - - r = self._load(dn_mecab, userdic) - if r != 1: - self._clear() - raise RuntimeError("Failed to initalize Mecab") + with nogil: + Mecab_initialize(self.mecab) + NJD_initialize(self.njd) + JPCommon_initialize(self.jpcommon) + r = self._load(_dn_mecab, _userdic) + if r != 1: + self._clear() + raise RuntimeError("Failed to initalize Mecab") - def _clear(self): - Mecab_clear(self.mecab) - NJD_clear(self.njd) - JPCommon_clear(self.jpcommon) + 
cdef void _clear(self) noexcept nogil: + Mecab_clear(self.mecab) + NJD_clear(self.njd) + JPCommon_clear(self.jpcommon) - def _load(self, bytes dn_mecab, bytes userdic): + cdef int _load(self, char* dn_mecab, char* userdic) noexcept nogil: return Mecab_load_with_userdic(self.mecab, dn_mecab, userdic) - + @_lock_manager() def run_frontend(self, text): """Run OpenJTalk's text processing frontend """ @@ -195,15 +205,17 @@ cdef class OpenJTalk(object): if isinstance(text, str): text = text.encode("utf-8") - text2mecab(buff, text) - Mecab_analysis(self.mecab, buff) - mecab2njd(self.njd, Mecab_get_feature(self.mecab), Mecab_get_size(self.mecab)) - _njd.njd_set_pronunciation(self.njd) - _njd.njd_set_digit(self.njd) - _njd.njd_set_accent_phrase(self.njd) - _njd.njd_set_accent_type(self.njd) - _njd.njd_set_unvoiced_vowel(self.njd) - _njd.njd_set_long_vowel(self.njd) + cdef const char* _text = text + with nogil: + text2mecab(buff, _text) + Mecab_analysis(self.mecab, buff) + mecab2njd(self.njd, Mecab_get_feature(self.mecab), Mecab_get_size(self.mecab)) + _njd.njd_set_pronunciation(self.njd) + _njd.njd_set_digit(self.njd) + _njd.njd_set_accent_phrase(self.njd) + _njd.njd_set_accent_type(self.njd) + _njd.njd_set_unvoiced_vowel(self.njd) + _njd.njd_set_long_vowel(self.njd) features = njd2feature(self.njd) # Note that this will release memory for njd feature @@ -212,23 +224,24 @@ cdef class OpenJTalk(object): return features + @_lock_manager() def make_label(self, features): """Make full-context label """ feature2njd(self.njd, features) - njd2jpcommon(self.jpcommon, self.njd) + with nogil: + njd2jpcommon(self.jpcommon, self.njd) - JPCommon_make_label(self.jpcommon) + JPCommon_make_label(self.jpcommon) - cdef int label_size = JPCommon_get_label_size(self.jpcommon) - cdef char** label_feature - label_feature = JPCommon_get_label_feature(self.jpcommon) + label_size = JPCommon_get_label_size(self.jpcommon) + label_feature = JPCommon_get_label_feature(self.jpcommon) labels = [] for i 
in range(label_size): - # This will create a copy of c string - # http://cython.readthedocs.io/en/latest/src/tutorial/strings.html - labels.append(label_feature[i]) + # This will create a copy of c string + # http://cython.readthedocs.io/en/latest/src/tutorial/strings.html + labels.append(label_feature[i]) # Note that this will release memory for label feature JPCommon_refresh(self.jpcommon) @@ -282,4 +295,7 @@ def mecab_dict_index(bytes dn_mecab, bytes path, bytes out_path): "utf-8", path ] - return _mecab_dict_index(10, argv) + cdef int ret + with nogil: + ret = _mecab_dict_index(10, argv) + return ret diff --git a/pyopenjtalk/openjtalk/jpcommon.pxd b/pyopenjtalk/openjtalk/jpcommon.pxd index 8e86bea..c78dd0f 100644 --- a/pyopenjtalk/openjtalk/jpcommon.pxd +++ b/pyopenjtalk/openjtalk/jpcommon.pxd @@ -4,26 +4,26 @@ from libc.stdio cimport FILE cdef extern from "jpcommon.h": cdef cppclass JPCommonNode: - char *pron - char *pos - char *ctype - char *cform - int acc - int chain_flag - void *prev - void *next + char *pron + char *pos + char *ctype + char *cform + int acc + int chain_flag + void *prev + void *next cdef cppclass JPCommon: - JPCommonNode *head - JPCommonNode *tail - void *label + JPCommonNode *head + JPCommonNode *tail + void *label - void JPCommon_initialize(JPCommon * jpcommon) + void JPCommon_initialize(JPCommon * jpcommon) nogil void JPCommon_push(JPCommon * jpcommon, JPCommonNode * node) - void JPCommon_make_label(JPCommon * jpcommon) - int JPCommon_get_label_size(JPCommon * jpcommon) - char **JPCommon_get_label_feature(JPCommon * jpcommon) + void JPCommon_make_label(JPCommon * jpcommon) nogil + int JPCommon_get_label_size(JPCommon * jpcommon) nogil + char **JPCommon_get_label_feature(JPCommon * jpcommon) nogil void JPCommon_print(JPCommon * jpcommon) void JPCommon_fprint(JPCommon * jpcommon, FILE * fp) void JPCommon_refresh(JPCommon * jpcommon) - void JPCommon_clear(JPCommon * jpcommon) + void JPCommon_clear(JPCommon * jpcommon) nogil diff --git 
a/pyopenjtalk/openjtalk/mecab.pxd b/pyopenjtalk/openjtalk/mecab.pxd index 1538e05..2c31912 100644 --- a/pyopenjtalk/openjtalk/mecab.pxd +++ b/pyopenjtalk/openjtalk/mecab.pxd @@ -8,15 +8,15 @@ cdef extern from "mecab.h": void *tagger void *lattice - cdef int Mecab_initialize(Mecab *m) - cdef int Mecab_load(Mecab *m, const char *dicdir) - cdef int Mecab_analysis(Mecab *m, const char *str) + cdef int Mecab_initialize(Mecab *m) nogil + cdef int Mecab_load(Mecab *m, const char *dicdir) nogil + cdef int Mecab_analysis(Mecab *m, const char *str) nogil cdef int Mecab_print(Mecab *m) - int Mecab_get_size(Mecab *m) - char **Mecab_get_feature(Mecab *m) - cdef int Mecab_refresh(Mecab *m) - cdef int Mecab_clear(Mecab *m) - cdef int mecab_dict_index(int argc, char **argv) + int Mecab_get_size(Mecab *m) nogil + char **Mecab_get_feature(Mecab *m) nogil + cdef int Mecab_refresh(Mecab *m) nogil + cdef int Mecab_clear(Mecab *m) nogil + cdef int mecab_dict_index(int argc, char **argv) nogil cdef extern from "mecab.h" namespace "MeCab": cdef cppclass Tagger: @@ -24,6 +24,6 @@ cdef extern from "mecab.h" namespace "MeCab": cdef cppclass Lattice: pass cdef cppclass Model: - Tagger *createTagger() - Lattice *createLattice() - cdef Model *createModel(int argc, char **argv) + Tagger *createTagger() nogil + Lattice *createLattice() nogil + cdef Model *createModel(int argc, char **argv) nogil diff --git a/pyopenjtalk/openjtalk/mecab2njd.pxd b/pyopenjtalk/openjtalk/mecab2njd.pxd index be57ccc..fdfc1b0 100644 --- a/pyopenjtalk/openjtalk/mecab2njd.pxd +++ b/pyopenjtalk/openjtalk/mecab2njd.pxd @@ -3,4 +3,4 @@ from .njd cimport NJD cdef extern from "mecab2njd.h": - void mecab2njd(NJD * njd, char **feature, int size); + void mecab2njd(NJD * njd, char **feature, int size) nogil diff --git a/pyopenjtalk/openjtalk/njd.pxd b/pyopenjtalk/openjtalk/njd.pxd index 38d3887..f745f8d 100644 --- a/pyopenjtalk/openjtalk/njd.pxd +++ b/pyopenjtalk/openjtalk/njd.pxd @@ -66,7 +66,7 @@ cdef extern from "njd.h": 
NJDNode *head NJDNode *tail - void NJD_initialize(NJD * njd) + void NJD_initialize(NJD * njd) nogil void NJD_load(NJD * njd, const char *str) void NJD_load_from_fp(NJD * njd, FILE * fp) int NJD_get_size(NJD * njd) @@ -76,22 +76,22 @@ cdef extern from "njd.h": void NJD_fprint(NJD * njd, FILE * fp) void NJD_sprint(NJD * njd, char *buff, const char *split_code) void NJD_refresh(NJD * njd) - void NJD_clear(NJD * wl) + void NJD_clear(NJD * wl) nogil cdef extern from "njd_set_accent_phrase.h": - void njd_set_accent_phrase(NJD * njd) + void njd_set_accent_phrase(NJD * njd) nogil cdef extern from "njd_set_accent_type.h": - void njd_set_accent_type(NJD * njd) + void njd_set_accent_type(NJD * njd) nogil cdef extern from "njd_set_digit.h": - void njd_set_digit(NJD * njd) + void njd_set_digit(NJD * njd) nogil cdef extern from "njd_set_long_vowel.h": - void njd_set_long_vowel(NJD * njd) + void njd_set_long_vowel(NJD * njd) nogil cdef extern from "njd_set_pronunciation.h": - void njd_set_pronunciation(NJD * njd) + void njd_set_pronunciation(NJD * njd) nogil cdef extern from "njd_set_unvoiced_vowel.h": - void njd_set_unvoiced_vowel(NJD * njd) + void njd_set_unvoiced_vowel(NJD * njd) nogil diff --git a/pyopenjtalk/openjtalk/njd2jpcommon.pxd b/pyopenjtalk/openjtalk/njd2jpcommon.pxd index 8309288..a84c859 100644 --- a/pyopenjtalk/openjtalk/njd2jpcommon.pxd +++ b/pyopenjtalk/openjtalk/njd2jpcommon.pxd @@ -4,4 +4,4 @@ from .jpcommon cimport JPCommon from .njd cimport NJD cdef extern from "njd2jpcommon.h": - void njd2jpcommon(JPCommon * jpcommon, NJD * njd) + void njd2jpcommon(JPCommon * jpcommon, NJD * njd) nogil diff --git a/pyopenjtalk/openjtalk/text2mecab.pxd b/pyopenjtalk/openjtalk/text2mecab.pxd index 6081757..1d26049 100644 --- a/pyopenjtalk/openjtalk/text2mecab.pxd +++ b/pyopenjtalk/openjtalk/text2mecab.pxd @@ -1,4 +1,4 @@ # distutils: language = c++ cdef extern from "text2mecab.h": - void text2mecab(char *output, const char *input) + void text2mecab(char *output, const char 
*input) nogil diff --git a/setup.py b/setup.py index dfac91e..c67d3c4 100644 --- a/setup.py +++ b/setup.py @@ -119,7 +119,7 @@ def check_cmake_in_path(): Extension( name="pyopenjtalk.openjtalk", sources=[join("pyopenjtalk", "openjtalk.pyx")] + all_src, - include_dirs=[np.get_include()] + include_dirs, + include_dirs=include_dirs, extra_compile_args=[], extra_link_args=[], language="c++", diff --git a/tests/test_openjtalk.py b/tests/test_openjtalk.py index adfc24e..5a27466 100644 --- a/tests/test_openjtalk.py +++ b/tests/test_openjtalk.py @@ -1,3 +1,4 @@ +from concurrent.futures import ThreadPoolExecutor from pathlib import Path import pyopenjtalk @@ -68,7 +69,10 @@ def test_g2p_kana(): for text, pron in [ ("今日もこんにちは", "キョーモコンニチワ"), ("いやあん", "イヤーン"), - ("パソコンのとりあえず知っておきたい使い方", "パソコンノトリアエズシッテオキタイツカイカタ"), + ( + "パソコンのとりあえず知っておきたい使い方", + "パソコンノトリアエズシッテオキタイツカイカタ", + ), ]: p = pyopenjtalk.g2p(text, kana=True) assert p == pron @@ -108,3 +112,27 @@ def test_userdic(): ]: p = pyopenjtalk.g2p(text) assert p == expected + + +def test_multithreading(): + ojt = pyopenjtalk.openjtalk.OpenJTalk(pyopenjtalk.OPEN_JTALK_DICT_DIR) + texts = [ + "今日もいい天気ですね", + "こんにちは", + "マルチスレッドプログラミング", + "テストです", + "Pythonはプログラミング言語です", + "日本語テキストを音声合成します", + ] * 4 + + # Test consistency between single and multi-threaded runs + # make sure no corruption happens in OJT internals + results_s = [ojt.run_frontend(text) for text in texts] + results_m = [] + with ThreadPoolExecutor() as e: + results_m = [i for i in e.map(ojt.run_frontend, texts)] + for s, m in zip(results_s, results_m): + assert len(s) == len(m) + for s_, m_ in zip(s, m): + # full context must exactly match + assert s_ == m_