Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Optimizations and GIL release. #88

Merged
merged 4 commits into from
Dec 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 41 additions & 10 deletions pyopenjtalk/htsengine.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
# cython: boundscheck=True, wraparound=True
# cython: c_string_type=unicode, c_string_encoding=ascii

from contextlib import contextmanager
from threading import RLock

import numpy as np

cimport numpy as np
Expand All @@ -19,38 +22,55 @@ from .htsengine cimport (
HTS_Engine_get_generated_speech, HTS_Engine_get_nsamples
)

def _generate_lock_manager():
lock = RLock()

@contextmanager
def f():
with lock:
yield

return f


cdef class HTSEngine(object):
"""HTSEngine

Args:
voice (bytes): File path of htsvoice.
"""
cdef HTS_Engine* engine
_lock_manager = _generate_lock_manager()

def __cinit__(self, bytes voice=b"htsvoice/mei_normal.htsvoice"):
self.engine = new HTS_Engine()

HTS_Engine_initialize(self.engine)

if self.load(voice) != 1:
self.clear()
raise RuntimeError("Failed to initalize HTS_Engine")
self.clear()
raise RuntimeError("Failed to initalize HTS_Engine")

@_lock_manager()
def load(self, bytes voice):
cdef char* voices = voice
cdef char ret
ret = HTS_Engine_load(self.engine, &voices, 1)
with nogil:
ret = HTS_Engine_load(self.engine, &voices, 1)
return ret

@_lock_manager()
def get_sampling_frequency(self):
"""Get sampling frequency
"""
return HTS_Engine_get_sampling_frequency(self.engine)

@_lock_manager()
def get_fperiod(self):
"""Get frame period"""
return HTS_Engine_get_fperiod(self.engine)

@_lock_manager()
def set_speed(self, speed=1.0):
"""Set speed

Expand All @@ -59,6 +79,7 @@ cdef class HTSEngine(object):
"""
HTS_Engine_set_speed(self.engine, speed)

@_lock_manager()
def add_half_tone(self, half_tone=0.0):
"""Additional half tone in log-f0

Expand All @@ -67,6 +88,7 @@ cdef class HTSEngine(object):
"""
HTS_Engine_add_half_tone(self.engine, half_tone)

@_lock_manager()
def synthesize(self, list labels):
"""Synthesize waveform from list of full-context labels

Expand All @@ -81,34 +103,43 @@ cdef class HTSEngine(object):
self.refresh()
return x

@_lock_manager()
def synthesize_from_strings(self, list labels):
"""Synthesize from strings"""
cdef size_t num_lines = len(labels)
cdef char **lines = <char**> malloc((num_lines + 1) * sizeof(char*))
for n in range(len(labels)):
for n in range(num_lines):
lines[n] = <char*>labels[n]

cdef char ret = HTS_Engine_synthesize_from_strings(self.engine, lines, num_lines)
free(lines)
cdef char ret
with nogil:
ret = HTS_Engine_synthesize_from_strings(self.engine, lines, num_lines)
free(lines)
if ret != 1:
raise RuntimeError("Failed to run synthesize_from_strings")

@_lock_manager()
def get_generated_speech(self):
"""Get generated speech"""
cdef size_t nsamples = HTS_Engine_get_nsamples(self.engine)
cdef np.ndarray speech = np.zeros([nsamples], dtype=np.float64)
cdef np.ndarray speech = np.empty([nsamples], dtype=np.float64)
cdef double[:] speech_view = speech
cdef size_t index
for index in range(nsamples):
speech[index] = HTS_Engine_get_generated_speech(self.engine, index)
with (nogil, cython.boundscheck(False)):
for index in range(nsamples):
speech_view[index] = HTS_Engine_get_generated_speech(self.engine, index)
return speech

@_lock_manager()
def get_fullcontext_label_format(self):
"""Get full-context label format"""
return (<bytes>HTS_Engine_get_fullcontext_label_format(self.engine)).decode("utf-8")

@_lock_manager()
def refresh(self):
HTS_Engine_refresh(self.engine)
HTS_Engine_refresh(self.engine)

@_lock_manager()
def clear(self):
HTS_Engine_clear(self.engine)

Expand Down
8 changes: 4 additions & 4 deletions pyopenjtalk/htsengine/__init__.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@ cdef extern from "HTS_engine.h":
ctypedef _HTS_Engine HTS_Engine

void HTS_Engine_initialize(HTS_Engine * engine)
char HTS_Engine_load(HTS_Engine * engine, char **voices, size_t num_voices)
char HTS_Engine_load(HTS_Engine * engine, char **voices, size_t num_voices) nogil
size_t HTS_Engine_get_sampling_frequency(HTS_Engine * engine)
size_t HTS_Engine_get_fperiod(HTS_Engine * engine)
void HTS_Engine_refresh(HTS_Engine * engine)
void HTS_Engine_clear(HTS_Engine * engine)
const char *HTS_Engine_get_fullcontext_label_format(HTS_Engine * engine)
char HTS_Engine_synthesize_from_strings(HTS_Engine * engine, char **lines, size_t num_lines)
char HTS_Engine_synthesize_from_strings(HTS_Engine * engine, char **lines, size_t num_lines) nogil
char HTS_Engine_synthesize_from_fn(HTS_Engine * engine, const char *fn)
double HTS_Engine_get_generated_speech(HTS_Engine * engine, size_t index)
double HTS_Engine_get_generated_speech(HTS_Engine * engine, size_t index) nogil
size_t HTS_Engine_get_nsamples(HTS_Engine * engine)

void HTS_Engine_set_speed(HTS_Engine * engine, double f)
void HTS_Engine_add_half_tone(HTS_Engine * engine, double f)
void HTS_Engine_add_half_tone(HTS_Engine * engine, double f)
140 changes: 78 additions & 62 deletions pyopenjtalk/openjtalk.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,9 @@
# cython: boundscheck=True, wraparound=True
# cython: c_string_type=unicode, c_string_encoding=ascii

import numpy as np
from contextlib import contextmanager
from threading import Lock

cimport numpy as np
np.import_array()
Comment on lines -7 to -8
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my understanding, this is still recommended in Cython > 3.

https://cython.readthedocs.io/en/latest/src/tutorial/numpy.html

# It's necessary to call "import_array" if you use any part of the
# numpy PyArray_* API. From Cython 3, accessing attributes like
# ".shape" on a typed Numpy array use this API. Therefore we recommend
# always calling "import_array" whenever you "cimport numpy"
cnp.import_array()

we don't use shape on a typed numpy array though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

numpy is not used in openjtalk.pyx.
I think if you are not using numpy at all don't need to import it.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I missed the point. Then LGTM


cimport cython
from libc.stdlib cimport calloc
from libc.string cimport strlen

Expand All @@ -20,7 +17,6 @@ from .openjtalk cimport njd as _njd
from .openjtalk.jpcommon cimport JPCommon, JPCommon_initialize,JPCommon_make_label
from .openjtalk.jpcommon cimport JPCommon_get_label_size, JPCommon_get_label_feature
from .openjtalk.jpcommon cimport JPCommon_refresh, JPCommon_clear
from .openjtalk cimport njd2jpcommon
from .openjtalk.text2mecab cimport text2mecab
from .openjtalk.mecab2njd cimport mecab2njd
from .openjtalk.njd2jpcommon cimport njd2jpcommon
Expand Down Expand Up @@ -55,48 +51,48 @@ cdef njd_node_get_read(_njd.NJDNode* node):
cdef njd_node_get_pron(_njd.NJDNode* node):
return (<bytes>(_njd.NJDNode_get_pron(node))).decode("utf-8")

cdef njd_node_get_acc(_njd.NJDNode* node):
cdef int njd_node_get_acc(_njd.NJDNode* node) noexcept:
return _njd.NJDNode_get_acc(node)

cdef njd_node_get_mora_size(_njd.NJDNode* node):
cdef int njd_node_get_mora_size(_njd.NJDNode* node) noexcept:
return _njd.NJDNode_get_mora_size(node)

cdef njd_node_get_chain_rule(_njd.NJDNode* node):
return (<bytes>(_njd.NJDNode_get_chain_rule(node))).decode("utf-8")

cdef njd_node_get_chain_flag(_njd.NJDNode* node):
return _njd.NJDNode_get_chain_flag(node)
cdef int njd_node_get_chain_flag(_njd.NJDNode* node) noexcept:
return _njd.NJDNode_get_chain_flag(node)


cdef node2feature(_njd.NJDNode* node):
return {
"string": njd_node_get_string(node),
"pos": njd_node_get_pos(node),
"pos_group1": njd_node_get_pos_group1(node),
"pos_group2": njd_node_get_pos_group2(node),
"pos_group3": njd_node_get_pos_group3(node),
"ctype": njd_node_get_ctype(node),
"cform": njd_node_get_cform(node),
"orig": njd_node_get_orig(node),
"read": njd_node_get_read(node),
"pron": njd_node_get_pron(node),
"acc": njd_node_get_acc(node),
"mora_size": njd_node_get_mora_size(node),
"chain_rule": njd_node_get_chain_rule(node),
"chain_flag": njd_node_get_chain_flag(node),
}
return {
"string": njd_node_get_string(node),
"pos": njd_node_get_pos(node),
"pos_group1": njd_node_get_pos_group1(node),
"pos_group2": njd_node_get_pos_group2(node),
"pos_group3": njd_node_get_pos_group3(node),
"ctype": njd_node_get_ctype(node),
"cform": njd_node_get_cform(node),
"orig": njd_node_get_orig(node),
"read": njd_node_get_read(node),
"pron": njd_node_get_pron(node),
"acc": njd_node_get_acc(node),
"mora_size": njd_node_get_mora_size(node),
"chain_rule": njd_node_get_chain_rule(node),
"chain_flag": njd_node_get_chain_flag(node),
}


cdef njd2feature(_njd.NJD* njd):
cdef _njd.NJDNode* node = njd.head
features = []
while node is not NULL:
features.append(node2feature(node))
node = node.next
features.append(node2feature(node))
node = node.next
return features


cdef feature2njd(_njd.NJD* njd, features):
cdef void feature2njd(_njd.NJD* njd, features):
cdef _njd.NJDNode* node

for feature_node in features:
Expand All @@ -120,7 +116,7 @@ cdef feature2njd(_njd.NJD* njd, features):
_njd.NJD_push_node(njd, node)

# based on Mecab_load in impl. from mecab.cpp
cdef inline int Mecab_load_with_userdic(Mecab *m, char* dicdir, char* userdic):
cdef inline int Mecab_load_with_userdic(Mecab *m, char* dicdir, char* userdic) noexcept nogil:
if userdic == NULL or strlen(userdic) == 0:
return Mecab_load(m, dicdir)

Expand Down Expand Up @@ -150,6 +146,16 @@ cdef inline int Mecab_load_with_userdic(Mecab *m, char* dicdir, char* userdic):

return 1

def _generate_lock_manager():
lock = Lock()

@contextmanager
def f():
with lock:
yield

return f


cdef class OpenJTalk(object):
"""OpenJTalk
Expand All @@ -163,47 +169,53 @@ cdef class OpenJTalk(object):
cdef Mecab* mecab
cdef NJD* njd
cdef JPCommon* jpcommon
_lock_manager = _generate_lock_manager()

def __cinit__(self, bytes dn_mecab=b"/usr/local/dic", bytes userdic=b""):
cdef char* _dn_mecab = dn_mecab
cdef char* _userdic = userdic

self.mecab = new Mecab()
self.njd = new NJD()
self.jpcommon = new JPCommon()

Mecab_initialize(self.mecab)
NJD_initialize(self.njd)
JPCommon_initialize(self.jpcommon)

r = self._load(dn_mecab, userdic)
if r != 1:
self._clear()
raise RuntimeError("Failed to initalize Mecab")
with nogil:
Mecab_initialize(self.mecab)
NJD_initialize(self.njd)
JPCommon_initialize(self.jpcommon)

r = self._load(_dn_mecab, _userdic)
if r != 1:
self._clear()
raise RuntimeError("Failed to initalize Mecab")

def _clear(self):
Mecab_clear(self.mecab)
NJD_clear(self.njd)
JPCommon_clear(self.jpcommon)
cdef void _clear(self) noexcept nogil:
Mecab_clear(self.mecab)
NJD_clear(self.njd)
JPCommon_clear(self.jpcommon)

def _load(self, bytes dn_mecab, bytes userdic):
cdef int _load(self, char* dn_mecab, char* userdic) noexcept nogil:
return Mecab_load_with_userdic(self.mecab, dn_mecab, userdic)


@_lock_manager()
def run_frontend(self, text):
"""Run OpenJTalk's text processing frontend
"""
cdef char buff[8192]

if isinstance(text, str):
text = text.encode("utf-8")
text2mecab(buff, text)
Mecab_analysis(self.mecab, buff)
mecab2njd(self.njd, Mecab_get_feature(self.mecab), Mecab_get_size(self.mecab))
_njd.njd_set_pronunciation(self.njd)
_njd.njd_set_digit(self.njd)
_njd.njd_set_accent_phrase(self.njd)
_njd.njd_set_accent_type(self.njd)
_njd.njd_set_unvoiced_vowel(self.njd)
_njd.njd_set_long_vowel(self.njd)
cdef const char* _text = text
with nogil:
text2mecab(buff, _text)
Mecab_analysis(self.mecab, buff)
mecab2njd(self.njd, Mecab_get_feature(self.mecab), Mecab_get_size(self.mecab))
_njd.njd_set_pronunciation(self.njd)
_njd.njd_set_digit(self.njd)
_njd.njd_set_accent_phrase(self.njd)
_njd.njd_set_accent_type(self.njd)
_njd.njd_set_unvoiced_vowel(self.njd)
_njd.njd_set_long_vowel(self.njd)
features = njd2feature(self.njd)

# Note that this will release memory for njd feature
Expand All @@ -212,23 +224,24 @@ cdef class OpenJTalk(object):

return features

@_lock_manager()
def make_label(self, features):
"""Make full-context label
"""
feature2njd(self.njd, features)
njd2jpcommon(self.jpcommon, self.njd)
with nogil:
njd2jpcommon(self.jpcommon, self.njd)

JPCommon_make_label(self.jpcommon)
JPCommon_make_label(self.jpcommon)

cdef int label_size = JPCommon_get_label_size(self.jpcommon)
cdef char** label_feature
label_feature = JPCommon_get_label_feature(self.jpcommon)
label_size = JPCommon_get_label_size(self.jpcommon)
label_feature = JPCommon_get_label_feature(self.jpcommon)

labels = []
for i in range(label_size):
# This will create a copy of c string
# http://cython.readthedocs.io/en/latest/src/tutorial/strings.html
labels.append(<unicode>label_feature[i])
# This will create a copy of c string
# http://cython.readthedocs.io/en/latest/src/tutorial/strings.html
labels.append(<unicode>label_feature[i])

# Note that this will release memory for label feature
JPCommon_refresh(self.jpcommon)
Expand Down Expand Up @@ -282,4 +295,7 @@ def mecab_dict_index(bytes dn_mecab, bytes path, bytes out_path):
"utf-8",
path
]
return _mecab_dict_index(10, argv)
cdef int ret
with nogil:
ret = _mecab_dict_index(10, argv)
return ret
Loading
Loading