diff --git a/pyopenjtalk/htsengine.pyx b/pyopenjtalk/htsengine.pyx index 2abf824..b8ed5e6 100644 --- a/pyopenjtalk/htsengine.pyx +++ b/pyopenjtalk/htsengine.pyx @@ -2,6 +2,9 @@ # cython: boundscheck=True, wraparound=True # cython: c_string_type=unicode, c_string_encoding=ascii +from contextlib import contextmanager +from threading import RLock + import numpy as np cimport numpy as np @@ -19,6 +22,17 @@ from .htsengine cimport ( HTS_Engine_get_generated_speech, HTS_Engine_get_nsamples ) +def _generate_lock_manager(): + lock = RLock() + + @contextmanager + def f(): + with lock: + yield + + return f + + cdef class HTSEngine(object): """HTSEngine @@ -26,6 +40,7 @@ cdef class HTSEngine(object): voice (bytes): File path of htsvoice. """ cdef HTS_Engine* engine + _lock_manager = _generate_lock_manager() def __cinit__(self, bytes voice=b"htsvoice/mei_normal.htsvoice"): self.engine = new HTS_Engine() @@ -36,6 +51,7 @@ cdef class HTSEngine(object): self.clear() raise RuntimeError("Failed to initalize HTS_Engine") + @_lock_manager() def load(self, bytes voice): cdef char* voices = voice cdef char ret @@ -43,15 +59,18 @@ cdef class HTSEngine(object): ret = HTS_Engine_load(self.engine, &voices, 1) return ret + @_lock_manager() def get_sampling_frequency(self): """Get sampling frequency """ return HTS_Engine_get_sampling_frequency(self.engine) + @_lock_manager() def get_fperiod(self): """Get frame period""" return HTS_Engine_get_fperiod(self.engine) + @_lock_manager() def set_speed(self, speed=1.0): """Set speed @@ -60,6 +79,7 @@ cdef class HTSEngine(object): """ HTS_Engine_set_speed(self.engine, speed) + @_lock_manager() def add_half_tone(self, half_tone=0.0): """Additional half tone in log-f0 @@ -68,6 +88,7 @@ cdef class HTSEngine(object): """ HTS_Engine_add_half_tone(self.engine, half_tone) + @_lock_manager() def synthesize(self, list labels): """Synthesize waveform from list of full-context labels @@ -82,6 +103,7 @@ cdef class HTSEngine(object): self.refresh() return x + @_lock_manager() def synthesize_from_strings(self, list labels): """Synthesize from strings""" cdef size_t num_lines = len(labels) @@ -96,6 +118,7 @@ cdef class HTSEngine(object): if ret != 1: raise RuntimeError("Failed to run synthesize_from_strings") + @_lock_manager() def get_generated_speech(self): """Get generated speech""" cdef size_t nsamples = HTS_Engine_get_nsamples(self.engine) @@ -107,13 +130,16 @@ cdef class HTSEngine(object): speech_view[index] = HTS_Engine_get_generated_speech(self.engine, index) return speech + @_lock_manager() def get_fullcontext_label_format(self): """Get full-context label format""" return (HTS_Engine_get_fullcontext_label_format(self.engine)).decode("utf-8") + @_lock_manager() def refresh(self): - HTS_Engine_refresh(self.engine) + HTS_Engine_refresh(self.engine) + @_lock_manager() def clear(self): HTS_Engine_clear(self.engine) diff --git a/pyopenjtalk/openjtalk.pyx b/pyopenjtalk/openjtalk.pyx index e120647..0113844 100644 --- a/pyopenjtalk/openjtalk.pyx +++ b/pyopenjtalk/openjtalk.pyx @@ -2,6 +2,9 @@ # cython: boundscheck=True, wraparound=True # cython: c_string_type=unicode, c_string_encoding=ascii +from contextlib import contextmanager +from threading import Lock + from libc.stdlib cimport calloc from libc.string cimport strlen @@ -143,6 +146,16 @@ cdef inline int Mecab_load_with_userdic(Mecab *m, char* dicdir, char* userdic) n return 1 +def _generate_lock_manager(): + lock = Lock() + + @contextmanager + def f(): + with lock: + yield + + return f + cdef class OpenJTalk(object): """OpenJTalk @@ -156,6 +169,7 @@ cdef class OpenJTalk(object): cdef Mecab* mecab cdef NJD* njd cdef JPCommon* jpcommon + _lock_manager = _generate_lock_manager() def __cinit__(self, bytes dn_mecab=b"/usr/local/dic", bytes userdic=b""): cdef char* _dn_mecab = dn_mecab @@ -183,6 +197,7 @@ cdef class OpenJTalk(object): cdef int _load(self, char* dn_mecab, char* userdic) noexcept nogil: return Mecab_load_with_userdic(self.mecab, dn_mecab, userdic) + @_lock_manager() def run_frontend(self, text): """Run OpenJTalk's text processing frontend """ @@ -209,6 +224,7 @@ cdef class OpenJTalk(object): return features + @_lock_manager() def make_label(self, features): """Make full-context label """