Skip to content

Commit

Permalink
ENH: Add lock.
Browse files Browse the repository at this point in the history
  • Loading branch information
sabonerune committed Nov 21, 2024
1 parent 3f20a06 commit 9bbebfb
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 1 deletion.
28 changes: 27 additions & 1 deletion pyopenjtalk/htsengine.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
# cython: boundscheck=True, wraparound=True
# cython: c_string_type=unicode, c_string_encoding=ascii

from contextlib import contextmanager
from threading import RLock

import numpy as np

cimport numpy as np
Expand All @@ -19,13 +22,25 @@ from .htsengine cimport (
HTS_Engine_get_generated_speech, HTS_Engine_get_nsamples
)

def _generate_lock_manager():
lock = RLock()

@contextmanager
def f():
with lock:
yield

return f


cdef class HTSEngine(object):
"""HTSEngine
Args:
voice (bytes): File path of htsvoice.
"""
cdef HTS_Engine* engine
_lock_manager = _generate_lock_manager()

def __cinit__(self, bytes voice=b"htsvoice/mei_normal.htsvoice"):
self.engine = new HTS_Engine()
Expand All @@ -36,22 +51,26 @@ cdef class HTSEngine(object):
self.clear()
raise RuntimeError("Failed to initalize HTS_Engine")

@_lock_manager()
def load(self, bytes voice):
cdef char* voices = voice
cdef char ret
with nogil:
ret = HTS_Engine_load(self.engine, &voices, 1)
return ret

@_lock_manager()
def get_sampling_frequency(self):
"""Get sampling frequency
"""
return HTS_Engine_get_sampling_frequency(self.engine)

@_lock_manager()
def get_fperiod(self):
"""Get frame period"""
return HTS_Engine_get_fperiod(self.engine)

@_lock_manager()
def set_speed(self, speed=1.0):
"""Set speed
Expand All @@ -60,6 +79,7 @@ cdef class HTSEngine(object):
"""
HTS_Engine_set_speed(self.engine, speed)

@_lock_manager()
def add_half_tone(self, half_tone=0.0):
"""Additional half tone in log-f0
Expand All @@ -68,6 +88,7 @@ cdef class HTSEngine(object):
"""
HTS_Engine_add_half_tone(self.engine, half_tone)

@_lock_manager()
def synthesize(self, list labels):
"""Synthesize waveform from list of full-context labels
Expand All @@ -82,6 +103,7 @@ cdef class HTSEngine(object):
self.refresh()
return x

@_lock_manager()
def synthesize_from_strings(self, list labels):
"""Synthesize from strings"""
cdef size_t num_lines = len(labels)
Expand All @@ -96,6 +118,7 @@ cdef class HTSEngine(object):
if ret != 1:
raise RuntimeError("Failed to run synthesize_from_strings")

@_lock_manager()
def get_generated_speech(self):
"""Get generated speech"""
cdef size_t nsamples = HTS_Engine_get_nsamples(self.engine)
Expand All @@ -107,13 +130,16 @@ cdef class HTSEngine(object):
speech_view[index] = HTS_Engine_get_generated_speech(self.engine, index)
return speech

@_lock_manager()
def get_fullcontext_label_format(self):
"""Get full-context label format"""
return (<bytes>HTS_Engine_get_fullcontext_label_format(self.engine)).decode("utf-8")

@_lock_manager()
def refresh(self):
HTS_Engine_refresh(self.engine)
HTS_Engine_refresh(self.engine)

@_lock_manager()
def clear(self):
HTS_Engine_clear(self.engine)

Expand Down
16 changes: 16 additions & 0 deletions pyopenjtalk/openjtalk.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
# cython: boundscheck=True, wraparound=True
# cython: c_string_type=unicode, c_string_encoding=ascii

from contextlib import contextmanager
from threading import Lock

from libc.stdlib cimport calloc
from libc.string cimport strlen

Expand Down Expand Up @@ -143,6 +146,16 @@ cdef inline int Mecab_load_with_userdic(Mecab *m, char* dicdir, char* userdic) n

return 1

def _generate_lock_manager():
lock = Lock()

@contextmanager
def f():
with lock:
yield

return f


cdef class OpenJTalk(object):
"""OpenJTalk
Expand All @@ -156,6 +169,7 @@ cdef class OpenJTalk(object):
cdef Mecab* mecab
cdef NJD* njd
cdef JPCommon* jpcommon
_lock_manager = _generate_lock_manager()

def __cinit__(self, bytes dn_mecab=b"/usr/local/dic", bytes userdic=b""):
cdef char* _dn_mecab = dn_mecab
Expand Down Expand Up @@ -183,6 +197,7 @@ cdef class OpenJTalk(object):
cdef int _load(self, char* dn_mecab, char* userdic) noexcept nogil:
return Mecab_load_with_userdic(self.mecab, dn_mecab, userdic)

@_lock_manager()
def run_frontend(self, text):
"""Run OpenJTalk's text processing frontend
"""
Expand All @@ -209,6 +224,7 @@ cdef class OpenJTalk(object):

return features

@_lock_manager()
def make_label(self, features):
"""Make full-context label
"""
Expand Down

0 comments on commit 9bbebfb

Please sign in to comment.