Skip to content

Commit

Permalink
Merge pull request #88 from sabonerune/enh/release-gil
Browse files Browse the repository at this point in the history
ENH: Optimizations and GIL release.
  • Loading branch information
r9y9 authored Dec 25, 2024
2 parents 3a91d16 + 73173e8 commit 0cedad1
Show file tree
Hide file tree
Showing 11 changed files with 191 additions and 116 deletions.
51 changes: 41 additions & 10 deletions pyopenjtalk/htsengine.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
# cython: boundscheck=True, wraparound=True
# cython: c_string_type=unicode, c_string_encoding=ascii

from contextlib import contextmanager
from threading import RLock

import numpy as np

cimport numpy as np
Expand All @@ -19,38 +22,55 @@ from .htsengine cimport (
HTS_Engine_get_generated_speech, HTS_Engine_get_nsamples
)

def _generate_lock_manager():
lock = RLock()

@contextmanager
def f():
with lock:
yield

return f


cdef class HTSEngine(object):
"""HTSEngine
Args:
voice (bytes): File path of htsvoice.
"""
cdef HTS_Engine* engine
_lock_manager = _generate_lock_manager()

def __cinit__(self, bytes voice=b"htsvoice/mei_normal.htsvoice"):
self.engine = new HTS_Engine()

HTS_Engine_initialize(self.engine)

if self.load(voice) != 1:
self.clear()
raise RuntimeError("Failed to initalize HTS_Engine")
self.clear()
raise RuntimeError("Failed to initalize HTS_Engine")

@_lock_manager()
def load(self, bytes voice):
cdef char* voices = voice
cdef char ret
ret = HTS_Engine_load(self.engine, &voices, 1)
with nogil:
ret = HTS_Engine_load(self.engine, &voices, 1)
return ret

@_lock_manager()
def get_sampling_frequency(self):
"""Get sampling frequency
"""
return HTS_Engine_get_sampling_frequency(self.engine)

@_lock_manager()
def get_fperiod(self):
"""Get frame period"""
return HTS_Engine_get_fperiod(self.engine)

@_lock_manager()
def set_speed(self, speed=1.0):
"""Set speed
Expand All @@ -59,6 +79,7 @@ cdef class HTSEngine(object):
"""
HTS_Engine_set_speed(self.engine, speed)

@_lock_manager()
def add_half_tone(self, half_tone=0.0):
"""Additional half tone in log-f0
Expand All @@ -67,6 +88,7 @@ cdef class HTSEngine(object):
"""
HTS_Engine_add_half_tone(self.engine, half_tone)

@_lock_manager()
def synthesize(self, list labels):
"""Synthesize waveform from list of full-context labels
Expand All @@ -81,34 +103,43 @@ cdef class HTSEngine(object):
self.refresh()
return x

@_lock_manager()
def synthesize_from_strings(self, list labels):
"""Synthesize from strings"""
cdef size_t num_lines = len(labels)
cdef char **lines = <char**> malloc((num_lines + 1) * sizeof(char*))
for n in range(len(labels)):
for n in range(num_lines):
lines[n] = <char*>labels[n]

cdef char ret = HTS_Engine_synthesize_from_strings(self.engine, lines, num_lines)
free(lines)
cdef char ret
with nogil:
ret = HTS_Engine_synthesize_from_strings(self.engine, lines, num_lines)
free(lines)
if ret != 1:
raise RuntimeError("Failed to run synthesize_from_strings")

@_lock_manager()
def get_generated_speech(self):
"""Get generated speech"""
cdef size_t nsamples = HTS_Engine_get_nsamples(self.engine)
cdef np.ndarray speech = np.zeros([nsamples], dtype=np.float64)
cdef np.ndarray speech = np.empty([nsamples], dtype=np.float64)
cdef double[:] speech_view = speech
cdef size_t index
for index in range(nsamples):
speech[index] = HTS_Engine_get_generated_speech(self.engine, index)
with (nogil, cython.boundscheck(False)):
for index in range(nsamples):
speech_view[index] = HTS_Engine_get_generated_speech(self.engine, index)
return speech

@_lock_manager()
def get_fullcontext_label_format(self):
"""Get full-context label format"""
return (<bytes>HTS_Engine_get_fullcontext_label_format(self.engine)).decode("utf-8")

@_lock_manager()
def refresh(self):
HTS_Engine_refresh(self.engine)
HTS_Engine_refresh(self.engine)

@_lock_manager()
def clear(self):
HTS_Engine_clear(self.engine)

Expand Down
8 changes: 4 additions & 4 deletions pyopenjtalk/htsengine/__init__.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@ cdef extern from "HTS_engine.h":
ctypedef _HTS_Engine HTS_Engine

void HTS_Engine_initialize(HTS_Engine * engine)
char HTS_Engine_load(HTS_Engine * engine, char **voices, size_t num_voices)
char HTS_Engine_load(HTS_Engine * engine, char **voices, size_t num_voices) nogil
size_t HTS_Engine_get_sampling_frequency(HTS_Engine * engine)
size_t HTS_Engine_get_fperiod(HTS_Engine * engine)
void HTS_Engine_refresh(HTS_Engine * engine)
void HTS_Engine_clear(HTS_Engine * engine)
const char *HTS_Engine_get_fullcontext_label_format(HTS_Engine * engine)
char HTS_Engine_synthesize_from_strings(HTS_Engine * engine, char **lines, size_t num_lines)
char HTS_Engine_synthesize_from_strings(HTS_Engine * engine, char **lines, size_t num_lines) nogil
char HTS_Engine_synthesize_from_fn(HTS_Engine * engine, const char *fn)
double HTS_Engine_get_generated_speech(HTS_Engine * engine, size_t index)
double HTS_Engine_get_generated_speech(HTS_Engine * engine, size_t index) nogil
size_t HTS_Engine_get_nsamples(HTS_Engine * engine)

void HTS_Engine_set_speed(HTS_Engine * engine, double f)
void HTS_Engine_add_half_tone(HTS_Engine * engine, double f)
void HTS_Engine_add_half_tone(HTS_Engine * engine, double f)
140 changes: 78 additions & 62 deletions pyopenjtalk/openjtalk.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,9 @@
# cython: boundscheck=True, wraparound=True
# cython: c_string_type=unicode, c_string_encoding=ascii

import numpy as np
from contextlib import contextmanager
from threading import Lock

cimport numpy as np
np.import_array()

cimport cython
from libc.stdlib cimport calloc
from libc.string cimport strlen

Expand All @@ -20,7 +17,6 @@ from .openjtalk cimport njd as _njd
from .openjtalk.jpcommon cimport JPCommon, JPCommon_initialize,JPCommon_make_label
from .openjtalk.jpcommon cimport JPCommon_get_label_size, JPCommon_get_label_feature
from .openjtalk.jpcommon cimport JPCommon_refresh, JPCommon_clear
from .openjtalk cimport njd2jpcommon
from .openjtalk.text2mecab cimport text2mecab
from .openjtalk.mecab2njd cimport mecab2njd
from .openjtalk.njd2jpcommon cimport njd2jpcommon
Expand Down Expand Up @@ -55,48 +51,48 @@ cdef njd_node_get_read(_njd.NJDNode* node):
cdef njd_node_get_pron(_njd.NJDNode* node):
return (<bytes>(_njd.NJDNode_get_pron(node))).decode("utf-8")

cdef njd_node_get_acc(_njd.NJDNode* node):
cdef int njd_node_get_acc(_njd.NJDNode* node) noexcept:
return _njd.NJDNode_get_acc(node)

cdef njd_node_get_mora_size(_njd.NJDNode* node):
cdef int njd_node_get_mora_size(_njd.NJDNode* node) noexcept:
return _njd.NJDNode_get_mora_size(node)

cdef njd_node_get_chain_rule(_njd.NJDNode* node):
return (<bytes>(_njd.NJDNode_get_chain_rule(node))).decode("utf-8")

cdef njd_node_get_chain_flag(_njd.NJDNode* node):
return _njd.NJDNode_get_chain_flag(node)
cdef int njd_node_get_chain_flag(_njd.NJDNode* node) noexcept:
return _njd.NJDNode_get_chain_flag(node)


cdef node2feature(_njd.NJDNode* node):
return {
"string": njd_node_get_string(node),
"pos": njd_node_get_pos(node),
"pos_group1": njd_node_get_pos_group1(node),
"pos_group2": njd_node_get_pos_group2(node),
"pos_group3": njd_node_get_pos_group3(node),
"ctype": njd_node_get_ctype(node),
"cform": njd_node_get_cform(node),
"orig": njd_node_get_orig(node),
"read": njd_node_get_read(node),
"pron": njd_node_get_pron(node),
"acc": njd_node_get_acc(node),
"mora_size": njd_node_get_mora_size(node),
"chain_rule": njd_node_get_chain_rule(node),
"chain_flag": njd_node_get_chain_flag(node),
}
return {
"string": njd_node_get_string(node),
"pos": njd_node_get_pos(node),
"pos_group1": njd_node_get_pos_group1(node),
"pos_group2": njd_node_get_pos_group2(node),
"pos_group3": njd_node_get_pos_group3(node),
"ctype": njd_node_get_ctype(node),
"cform": njd_node_get_cform(node),
"orig": njd_node_get_orig(node),
"read": njd_node_get_read(node),
"pron": njd_node_get_pron(node),
"acc": njd_node_get_acc(node),
"mora_size": njd_node_get_mora_size(node),
"chain_rule": njd_node_get_chain_rule(node),
"chain_flag": njd_node_get_chain_flag(node),
}


cdef njd2feature(_njd.NJD* njd):
cdef _njd.NJDNode* node = njd.head
features = []
while node is not NULL:
features.append(node2feature(node))
node = node.next
features.append(node2feature(node))
node = node.next
return features


cdef feature2njd(_njd.NJD* njd, features):
cdef void feature2njd(_njd.NJD* njd, features):
cdef _njd.NJDNode* node

for feature_node in features:
Expand All @@ -120,7 +116,7 @@ cdef feature2njd(_njd.NJD* njd, features):
_njd.NJD_push_node(njd, node)

# based on Mecab_load in impl. from mecab.cpp
cdef inline int Mecab_load_with_userdic(Mecab *m, char* dicdir, char* userdic):
cdef inline int Mecab_load_with_userdic(Mecab *m, char* dicdir, char* userdic) noexcept nogil:
if userdic == NULL or strlen(userdic) == 0:
return Mecab_load(m, dicdir)

Expand Down Expand Up @@ -150,6 +146,16 @@ cdef inline int Mecab_load_with_userdic(Mecab *m, char* dicdir, char* userdic):

return 1

def _generate_lock_manager():
lock = Lock()

@contextmanager
def f():
with lock:
yield

return f


cdef class OpenJTalk(object):
"""OpenJTalk
Expand All @@ -163,47 +169,53 @@ cdef class OpenJTalk(object):
cdef Mecab* mecab
cdef NJD* njd
cdef JPCommon* jpcommon
_lock_manager = _generate_lock_manager()

def __cinit__(self, bytes dn_mecab=b"/usr/local/dic", bytes userdic=b""):
cdef char* _dn_mecab = dn_mecab
cdef char* _userdic = userdic

self.mecab = new Mecab()
self.njd = new NJD()
self.jpcommon = new JPCommon()

Mecab_initialize(self.mecab)
NJD_initialize(self.njd)
JPCommon_initialize(self.jpcommon)

r = self._load(dn_mecab, userdic)
if r != 1:
self._clear()
raise RuntimeError("Failed to initalize Mecab")
with nogil:
Mecab_initialize(self.mecab)
NJD_initialize(self.njd)
JPCommon_initialize(self.jpcommon)

r = self._load(_dn_mecab, _userdic)
if r != 1:
self._clear()
raise RuntimeError("Failed to initalize Mecab")

def _clear(self):
Mecab_clear(self.mecab)
NJD_clear(self.njd)
JPCommon_clear(self.jpcommon)
cdef void _clear(self) noexcept nogil:
Mecab_clear(self.mecab)
NJD_clear(self.njd)
JPCommon_clear(self.jpcommon)

def _load(self, bytes dn_mecab, bytes userdic):
cdef int _load(self, char* dn_mecab, char* userdic) noexcept nogil:
return Mecab_load_with_userdic(self.mecab, dn_mecab, userdic)


@_lock_manager()
def run_frontend(self, text):
"""Run OpenJTalk's text processing frontend
"""
cdef char buff[8192]

if isinstance(text, str):
text = text.encode("utf-8")
text2mecab(buff, text)
Mecab_analysis(self.mecab, buff)
mecab2njd(self.njd, Mecab_get_feature(self.mecab), Mecab_get_size(self.mecab))
_njd.njd_set_pronunciation(self.njd)
_njd.njd_set_digit(self.njd)
_njd.njd_set_accent_phrase(self.njd)
_njd.njd_set_accent_type(self.njd)
_njd.njd_set_unvoiced_vowel(self.njd)
_njd.njd_set_long_vowel(self.njd)
cdef const char* _text = text
with nogil:
text2mecab(buff, _text)
Mecab_analysis(self.mecab, buff)
mecab2njd(self.njd, Mecab_get_feature(self.mecab), Mecab_get_size(self.mecab))
_njd.njd_set_pronunciation(self.njd)
_njd.njd_set_digit(self.njd)
_njd.njd_set_accent_phrase(self.njd)
_njd.njd_set_accent_type(self.njd)
_njd.njd_set_unvoiced_vowel(self.njd)
_njd.njd_set_long_vowel(self.njd)
features = njd2feature(self.njd)

# Note that this will release memory for njd feature
Expand All @@ -212,23 +224,24 @@ cdef class OpenJTalk(object):

return features

@_lock_manager()
def make_label(self, features):
"""Make full-context label
"""
feature2njd(self.njd, features)
njd2jpcommon(self.jpcommon, self.njd)
with nogil:
njd2jpcommon(self.jpcommon, self.njd)

JPCommon_make_label(self.jpcommon)
JPCommon_make_label(self.jpcommon)

cdef int label_size = JPCommon_get_label_size(self.jpcommon)
cdef char** label_feature
label_feature = JPCommon_get_label_feature(self.jpcommon)
label_size = JPCommon_get_label_size(self.jpcommon)
label_feature = JPCommon_get_label_feature(self.jpcommon)

labels = []
for i in range(label_size):
# This will create a copy of c string
# http://cython.readthedocs.io/en/latest/src/tutorial/strings.html
labels.append(<unicode>label_feature[i])
# This will create a copy of c string
# http://cython.readthedocs.io/en/latest/src/tutorial/strings.html
labels.append(<unicode>label_feature[i])

# Note that this will release memory for label feature
JPCommon_refresh(self.jpcommon)
Expand Down Expand Up @@ -282,4 +295,7 @@ def mecab_dict_index(bytes dn_mecab, bytes path, bytes out_path):
"utf-8",
path
]
return _mecab_dict_index(10, argv)
cdef int ret
with nogil:
ret = _mecab_dict_index(10, argv)
return ret
Loading

0 comments on commit 0cedad1

Please sign in to comment.