diff --git a/.github/workflows/build_whl.yml b/.github/workflows/build_whl.yml new file mode 100644 index 0000000..14c450a --- /dev/null +++ b/.github/workflows/build_whl.yml @@ -0,0 +1,37 @@ +name: build wheel + +on: + workflow_dispatch: + +jobs: + build: + + runs-on: ${{ matrix.os }} + strategy: + matrix: + python-version: ["3.6", "3.7", "3.8", "3.9", "3.10"] + os: [ubuntu-latest, macos-latest, windows-latest] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Check out recursively + run: git submodule update --init --recursive + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install --upgrade setuptools + python -m pip install --upgrade wheel + pip install flake8 pytest + pip install cython numpy tqdm + - name: build_whl + run: | + python setup.py sdist bdist_wheel + - uses: actions/upload-artifact@v2 + with: + name: ${{ matrix.os }}-${{ matrix.python-version }} + path: dist/*.whl \ No newline at end of file diff --git a/.gitignore b/.gitignore index 4d3b0cb..8b2bc54 100644 --- a/.gitignore +++ b/.gitignore @@ -194,3 +194,4 @@ Temporary Items # Linux trash folder which might appear on any partition or disk .Trash-* +dic.tar.gz diff --git a/pyopenjtalk/__init__.py b/pyopenjtalk/__init__.py index a266104..2e0e3f5 100644 --- a/pyopenjtalk/__init__.py +++ b/pyopenjtalk/__init__.py @@ -1,24 +1,14 @@ import os -from os.path import exists import pkg_resources -import six -from tqdm.auto import tqdm - -if six.PY2: - from urllib import urlretrieve -else: - from urllib.request import urlretrieve - -import tarfile try: - from .version import __version__ # NOQA + from pyopenjtalk.version import __version__ # NOQA except ImportError: raise ImportError("BUG: version.py doesn't exist. Please file a bug report.") -from .htsengine import HTSEngine -from .openjtalk import OpenJTalk +from pyopenjtalk.htsengine import HTSEngine +from pyopenjtalk.openjtalk import OpenJTalk # Dictionary directory # defaults to the package directory where the dictionary will be automatically downloaded @@ -26,8 +16,6 @@ "OPEN_JTALK_DICT_DIR", pkg_resources.resource_filename(__name__, "open_jtalk_dic_utf_8-1.11"), ).encode("utf-8") -_dict_download_url = "https://github.com/r9y9/open_jtalk/releases/download/v1.11.1" -_DICT_URL = f"{_dict_download_url}/open_jtalk_dic_utf_8-1.11.tar.gz" # Default mei_normal.voice for HMM-based TTS DEFAULT_HTS_VOICE = pkg_resources.resource_filename( @@ -41,41 +29,6 @@ _global_htsengine = None -# https://github.com/tqdm/tqdm#hooks-and-callbacks -class _TqdmUpTo(tqdm): # type: ignore - def update_to(self, b=1, bsize=1, tsize=None): - if tsize is not None: - self.total = tsize - return self.update(b * bsize - self.n) - - -def _extract_dic(): - global OPEN_JTALK_DICT_DIR - filename = pkg_resources.resource_filename(__name__, "dic.tar.gz") - print('Downloading: "{}"'.format(_DICT_URL)) - with _TqdmUpTo( - unit="B", - unit_scale=True, - unit_divisor=1024, - miniters=1, - desc="dic.tar.gz", - ) as t: # all optional kwargs - urlretrieve(_DICT_URL, filename, reporthook=t.update_to) - t.total = t.n - print("Extracting tar file {}".format(filename)) - with tarfile.open(filename, mode="r|gz") as f: - f.extractall(path=pkg_resources.resource_filename(__name__, "")) - OPEN_JTALK_DICT_DIR = pkg_resources.resource_filename( - __name__, "open_jtalk_dic_utf_8-1.11" - ).encode("utf-8") - os.remove(filename) - - -def _lazy_init(): - if not exists(OPEN_JTALK_DICT_DIR): - _extract_dic() - - def g2p(*args, **kwargs): """Grapheme-to-phoeneme (G2P) conversion @@ -93,7 +46,6 @@ def g2p(*args, **kwargs): """ global _global_jtalk if _global_jtalk is None: - _lazy_init() _global_jtalk = OpenJTalk(dn_mecab=OPEN_JTALK_DICT_DIR) return _global_jtalk.g2p(*args, **kwargs) @@ -164,6 +116,5 @@ def run_frontend(text, verbose=0): """ global _global_jtalk if _global_jtalk is None: - _lazy_init() _global_jtalk = OpenJTalk(dn_mecab=OPEN_JTALK_DICT_DIR) return _global_jtalk.run_frontend(text, verbose) diff --git a/pyopenjtalk/htsengine.pyx b/pyopenjtalk/htsengine.pyx index a69400c..b30e813 100644 --- a/pyopenjtalk/htsengine.pyx +++ b/pyopenjtalk/htsengine.pyx @@ -1,25 +1,33 @@ # coding: utf-8 -# cython: boundscheck=True, wraparound=True -# cython: c_string_type=unicode, c_string_encoding=ascii +# cython: language_level=3 +# cython: boundscheck=False, wraparound=False +# cython: c_string_type=unicode, c_string_encoding=ascii, cdivision=True import numpy as np cimport numpy as np + np.import_array() cimport cython -from libc.stdlib cimport malloc, free - -from htsengine cimport HTS_Engine -from htsengine cimport ( - HTS_Engine_initialize, HTS_Engine_load, HTS_Engine_clear, HTS_Engine_refresh, - HTS_Engine_get_sampling_frequency, HTS_Engine_get_fperiod, - HTS_Engine_set_speed, HTS_Engine_add_half_tone, - HTS_Engine_synthesize_from_strings, - HTS_Engine_get_generated_speech, HTS_Engine_get_nsamples -) - -cdef class HTSEngine(object): +from cpython.mem cimport PyMem_Free, PyMem_Malloc +from cython.parallel cimport prange +from libc.stdint cimport uint8_t + +from pyopenjtalk.htsengine cimport (HTS_Engine, HTS_Engine_add_half_tone, + HTS_Engine_clear, HTS_Engine_get_fperiod, + HTS_Engine_get_generated_speech, + HTS_Engine_get_nsamples, + HTS_Engine_get_sampling_frequency, + HTS_Engine_initialize, HTS_Engine_load, + HTS_Engine_refresh, HTS_Engine_set_speed, + HTS_Engine_synthesize_from_strings) + + +@cython.final +@cython.no_gc +@cython.freelist(4) +cdef class HTSEngine: """HTSEngine Args: @@ -36,38 +44,48 @@ cdef class HTSEngine(object): self.clear() raise RuntimeError("Failed to initalize HTS_Engine") - def load(self, bytes voice): - cdef char* voices = voice - cdef char ret - ret = HTS_Engine_load(self.engine, &voices, 1) + cpdef inline char load(self, const uint8_t[::1] voice): + cdef: + char ret + const uint8_t *voice_ptr = &voice[0] + with nogil: + ret = HTS_Engine_load(self.engine, (&voice_ptr), 1) return ret - def get_sampling_frequency(self): + cpdef inline size_t get_sampling_frequency(self): """Get sampling frequency """ - return HTS_Engine_get_sampling_frequency(self.engine) + cdef size_t ret + with nogil: + ret = HTS_Engine_get_sampling_frequency(self.engine) + return ret - def get_fperiod(self): + cpdef inline size_t get_fperiod(self): """Get frame period""" - return HTS_Engine_get_fperiod(self.engine) + cdef size_t ret + with nogil: + ret = HTS_Engine_get_fperiod(self.engine) + return ret - def set_speed(self, speed=1.0): + cpdef inline void set_speed(self, double speed=1.0): """Set speed Args: speed (float): speed """ - HTS_Engine_set_speed(self.engine, speed) + with nogil: + HTS_Engine_set_speed(self.engine, speed) - def add_half_tone(self, half_tone=0.0): + cpdef inline void add_half_tone(self, double half_tone=0.0): """Additional half tone in log-f0 Args: half_tone (float): additional half tone """ - HTS_Engine_add_half_tone(self.engine, half_tone) + with nogil: + HTS_Engine_add_half_tone(self.engine, half_tone) - def synthesize(self, list labels): + cpdef inline np.ndarray[np.float64_t, ndim=1] synthesize(self, list labels): """Synthesize waveform from list of full-context labels Args: @@ -77,40 +95,49 @@ cdef class HTSEngine(object): np.ndarray: speech waveform """ self.synthesize_from_strings(labels) - x = self.get_generated_speech() + cdef np.ndarray[np.float64_t, ndim=1] x = self.get_generated_speech() self.refresh() return x - def synthesize_from_strings(self, list labels): + cpdef inline char synthesize_from_strings(self, list labels) except? 0: """Synthesize from strings""" cdef size_t num_lines = len(labels) - cdef char **lines = malloc((num_lines + 1) * sizeof(char*)) + cdef char **lines = PyMem_Malloc((num_lines + 1) * sizeof(char*)) + cdef int n for n in range(len(labels)): lines[n] = labels[n] - - cdef char ret = HTS_Engine_synthesize_from_strings(self.engine, lines, num_lines) - free(lines) + cdef char ret + with nogil: + ret = HTS_Engine_synthesize_from_strings(self.engine, lines, num_lines) + PyMem_Free(lines) # todo: use finally if ret != 1: raise RuntimeError("Failed to run synthesize_from_strings") + return ret - def get_generated_speech(self): + cpdef inline np.ndarray[np.float64_t, ndim=1] get_generated_speech(self): """Get generated speech""" cdef size_t nsamples = HTS_Engine_get_nsamples(self.engine) - cdef np.ndarray speech = np.zeros([nsamples], dtype=np.float64) - cdef size_t index - for index in range(nsamples): - speech[index] = HTS_Engine_get_generated_speech(self.engine, index) + cdef np.ndarray[np.float64_t, ndim=1] speech = np.zeros([nsamples], dtype=np.float64) + cdef double[::1] speech_view = speech + cdef int index + for index in prange(nsamples, nogil=True): + speech_view[index] = HTS_Engine_get_generated_speech(self.engine, index) return speech - def get_fullcontext_label_format(self): + cpdef inline str get_fullcontext_label_format(self): """Get full-context label format""" - return (HTS_Engine_get_fullcontext_label_format(self.engine)).decode("utf-8") - - def refresh(self): - HTS_Engine_refresh(self.engine) - - def clear(self): - HTS_Engine_clear(self.engine) + cdef const char* f + with nogil: + f = HTS_Engine_get_fullcontext_label_format(self.engine) + return (f).decode("utf-8") + + cpdef inline void refresh(self): + with nogil: + HTS_Engine_refresh(self.engine) + + cpdef inline void clear(self): + with nogil: + HTS_Engine_clear(self.engine) def __dealloc__(self): self.clear() diff --git a/pyopenjtalk/htsengine/__init__.pxd b/pyopenjtalk/htsengine/__init__.pxd index c24d959..e033ea5 100644 --- a/pyopenjtalk/htsengine/__init__.pxd +++ b/pyopenjtalk/htsengine/__init__.pxd @@ -1,7 +1,9 @@ # distutils: language = c++ -cdef extern from "HTS_engine.h": +# cython: language_level=3 + +cdef extern from "HTS_engine.h" nogil: cdef cppclass _HTS_Engine: pass ctypedef _HTS_Engine HTS_Engine diff --git a/pyopenjtalk/openjtalk.pyx b/pyopenjtalk/openjtalk.pyx index 46e6004..eda7ae1 100644 --- a/pyopenjtalk/openjtalk.pyx +++ b/pyopenjtalk/openjtalk.pyx @@ -1,97 +1,112 @@ # coding: utf-8 -# cython: boundscheck=True, wraparound=True -# cython: c_string_type=unicode, c_string_encoding=ascii +# cython: language_level=3 +# cython: boundscheck=False, wraparound=True +# cython: c_string_type=unicode, c_string_encoding=ascii, cdivision=True + +from libc.stdint cimport uint8_t import numpy as np cimport numpy as np + np.import_array() cimport cython - -from openjtalk.mecab cimport Mecab, Mecab_initialize, Mecab_load, Mecab_analysis -from openjtalk.mecab cimport Mecab_get_feature, Mecab_get_size, Mecab_refresh, Mecab_clear -from openjtalk.njd cimport NJD, NJD_initialize, NJD_refresh, NJD_print, NJD_clear -from openjtalk cimport njd as _njd -from openjtalk.jpcommon cimport JPCommon, JPCommon_initialize,JPCommon_make_label -from openjtalk.jpcommon cimport JPCommon_get_label_size, JPCommon_get_label_feature -from openjtalk.jpcommon cimport JPCommon_refresh, JPCommon_clear -from openjtalk cimport njd2jpcommon -from openjtalk.text2mecab cimport text2mecab -from openjtalk.mecab2njd cimport mecab2njd -from openjtalk.njd2jpcommon cimport njd2jpcommon - -cdef njd_node_get_string(_njd.NJDNode* node): +from cpython.bytes cimport PyBytes_AS_STRING + +from pyopenjtalk.openjtalk cimport njd as _njd +from pyopenjtalk.openjtalk cimport njd2jpcommon +from pyopenjtalk.openjtalk.jpcommon cimport (JPCommon, JPCommon_clear, + JPCommon_get_label_feature, + JPCommon_get_label_size, + JPCommon_initialize, + JPCommon_make_label, + JPCommon_refresh) +from pyopenjtalk.openjtalk.mecab cimport (Mecab, Mecab_analysis, Mecab_clear, + Mecab_get_feature, Mecab_get_size, + Mecab_initialize, Mecab_load, + Mecab_refresh) +from pyopenjtalk.openjtalk.mecab2njd cimport mecab2njd +from pyopenjtalk.openjtalk.njd cimport (NJD, NJD_clear, NJD_initialize, + NJD_print, NJD_refresh) +from pyopenjtalk.openjtalk.njd2jpcommon cimport njd2jpcommon +from pyopenjtalk.openjtalk.text2mecab cimport text2mecab + + +cdef inline str njd_node_get_string(_njd.NJDNode* node): return ((_njd.NJDNode_get_string(node))).decode("utf-8") -cdef njd_node_get_pos(_njd.NJDNode* node): +cdef inline str njd_node_get_pos(_njd.NJDNode* node): return ((_njd.NJDNode_get_pos(node))).decode("utf-8") -cdef njd_node_get_pos_group1(_njd.NJDNode* node): +cdef inline str njd_node_get_pos_group1(_njd.NJDNode* node): return ((_njd.NJDNode_get_pos_group1(node))).decode("utf-8") -cdef njd_node_get_pos_group2(_njd.NJDNode* node): +cdef inline str njd_node_get_pos_group2(_njd.NJDNode* node): return ((_njd.NJDNode_get_pos_group2(node))).decode("utf-8") -cdef njd_node_get_pos_group3(_njd.NJDNode* node): +cdef inline str njd_node_get_pos_group3(_njd.NJDNode* node): return ((_njd.NJDNode_get_pos_group3(node))).decode("utf-8") -cdef njd_node_get_ctype(_njd.NJDNode* node): +cdef inline str njd_node_get_ctype(_njd.NJDNode* node): return ((_njd.NJDNode_get_ctype(node))).decode("utf-8") -cdef njd_node_get_cform(_njd.NJDNode* node): +cdef inline str njd_node_get_cform(_njd.NJDNode* node): return ((_njd.NJDNode_get_cform(node))).decode("utf-8") -cdef njd_node_get_orig(_njd.NJDNode* node): +cdef inline str njd_node_get_orig(_njd.NJDNode* node): return ((_njd.NJDNode_get_orig(node))).decode("utf-8") -cdef njd_node_get_read(_njd.NJDNode* node): +cdef inline str njd_node_get_read(_njd.NJDNode* node): return ((_njd.NJDNode_get_read(node))).decode("utf-8") -cdef njd_node_get_pron(_njd.NJDNode* node): +cdef inline str njd_node_get_pron(_njd.NJDNode* node): return ((_njd.NJDNode_get_pron(node))).decode("utf-8") -cdef njd_node_get_acc(_njd.NJDNode* node): +cdef inline int njd_node_get_acc(_njd.NJDNode* node): return _njd.NJDNode_get_acc(node) -cdef njd_node_get_mora_size(_njd.NJDNode* node): +cdef inline int njd_node_get_mora_size(_njd.NJDNode* node): return _njd.NJDNode_get_mora_size(node) -cdef njd_node_get_chain_rule(_njd.NJDNode* node): +cdef inline str njd_node_get_chain_rule(_njd.NJDNode* node): return ((_njd.NJDNode_get_chain_rule(node))).decode("utf-8") -cdef njd_node_get_chain_flag(_njd.NJDNode* node): - return _njd.NJDNode_get_chain_flag(node) - - -cdef njd_node_print(_njd.NJDNode* node): - return "{},{},{},{},{},{},{},{},{},{},{}/{},{},{}".format( - njd_node_get_string(node), - njd_node_get_pos(node), - njd_node_get_pos_group1(node), - njd_node_get_pos_group2(node), - njd_node_get_pos_group3(node), - njd_node_get_ctype(node), - njd_node_get_cform(node), - njd_node_get_orig(node), - njd_node_get_read(node), - njd_node_get_pron(node), - njd_node_get_acc(node), - njd_node_get_mora_size(node), - njd_node_get_chain_rule(node), - njd_node_get_chain_flag(node) +cdef inline int njd_node_get_chain_flag(_njd.NJDNode* node): + return _njd.NJDNode_get_chain_flag(node) + + +cdef inline str njd_node_print(_njd.NJDNode* node): + return "{},{},{},{},{},{},{},{},{},{},{}/{},{},{}".format( + njd_node_get_string(node), + njd_node_get_pos(node), + njd_node_get_pos_group1(node), + njd_node_get_pos_group2(node), + njd_node_get_pos_group3(node), + njd_node_get_ctype(node), + njd_node_get_cform(node), + njd_node_get_orig(node), + njd_node_get_read(node), + njd_node_get_pron(node), + njd_node_get_acc(node), + njd_node_get_mora_size(node), + njd_node_get_chain_rule(node), + njd_node_get_chain_flag(node) ) -cdef njd_print(_njd.NJD* njd): +cdef list njd_print(_njd.NJD* njd): cdef _njd.NJDNode* node = njd.head njd_results = [] while node is not NULL: - njd_results.append(njd_node_print(node)) - node = node.next + njd_results.append(njd_node_print(node)) + node = node.next return njd_results -cdef class OpenJTalk(object): +@cython.no_gc +@cython.final +@cython.freelist(4) +cdef class OpenJTalk: """OpenJTalk Args: @@ -112,62 +127,72 @@ cdef class OpenJTalk(object): r = self._load(dn_mecab) if r != 1: - self._clear() - raise RuntimeError("Failed to initalize Mecab") + self._clear() + raise RuntimeError("Failed to initalize Mecab") + cpdef inline void _clear(self): + with nogil: + Mecab_clear(self.mecab) + NJD_clear(self.njd) + JPCommon_clear(self.jpcommon) - def _clear(self): - Mecab_clear(self.mecab) - NJD_clear(self.njd) - JPCommon_clear(self.jpcommon) + cpdef inline int _load(self, const uint8_t[::1] dn_mecab): + cdef int ret + with nogil: + ret = Mecab_load(self.mecab, &dn_mecab[0]) + return ret - def _load(self, bytes dn_mecab): - return Mecab_load(self.mecab, dn_mecab) - - def run_frontend(self, text, verbose=0): + cpdef inline tuple run_frontend(self, object text, int verbose=0): """Run OpenJTalk's text processing frontend """ if isinstance(text, str): - text = text.encode("utf-8") - cdef char buff[8192] - text2mecab(buff, text) - Mecab_analysis(self.mecab, buff) - mecab2njd(self.njd, Mecab_get_feature(self.mecab), Mecab_get_size(self.mecab)) - _njd.njd_set_pronunciation(self.njd) - _njd.njd_set_digit(self.njd) - _njd.njd_set_accent_phrase(self.njd) - _njd.njd_set_accent_type(self.njd) - _njd.njd_set_unvoiced_vowel(self.njd) - _njd.njd_set_long_vowel(self.njd) - njd2jpcommon(self.jpcommon, self.njd) - JPCommon_make_label(self.jpcommon) - - cdef int label_size = JPCommon_get_label_size(self.jpcommon) - cdef char** label_feature - label_feature = JPCommon_get_label_feature(self.jpcommon) + text = text.encode("utf-8") + cdef: + char buff[8192] + const char* text_ptr + int label_size + char** label_feature + text_ptr = PyBytes_AS_STRING(text) + with nogil: + text2mecab(buff, text_ptr) + Mecab_analysis(self.mecab, buff) + mecab2njd(self.njd, Mecab_get_feature(self.mecab), Mecab_get_size(self.mecab)) + _njd.njd_set_pronunciation(self.njd) + _njd.njd_set_digit(self.njd) + _njd.njd_set_accent_phrase(self.njd) + _njd.njd_set_accent_type(self.njd) + _njd.njd_set_unvoiced_vowel(self.njd) + _njd.njd_set_long_vowel(self.njd) + njd2jpcommon(self.jpcommon, self.njd) + JPCommon_make_label(self.jpcommon) + + label_size = JPCommon_get_label_size(self.jpcommon) + label_feature = JPCommon_get_label_feature(self.jpcommon) labels = [] + cdef int i for i in range(label_size): - # This will create a copy of c string - # http://cython.readthedocs.io/en/latest/src/tutorial/strings.html - labels.append(label_feature[i]) + # This will create a copy of c string + # http://cython.readthedocs.io/en/latest/src/tutorial/strings.html + labels.append(label_feature[i]) - njd_results = njd_print(self.njd) + cdef list njd_results = njd_print(self.njd) if verbose > 0: - NJD_print(self.njd) + NJD_print(self.njd) # Note that this will release memory for label feature - JPCommon_refresh(self.jpcommon) - NJD_refresh(self.njd) - Mecab_refresh(self.mecab) - + with nogil: + JPCommon_refresh(self.jpcommon) + NJD_refresh(self.njd) + Mecab_refresh(self.mecab) return njd_results, labels - def g2p(self, text, kana=False, join=True): + def g2p(self, object text, bint kana=False, bint join=True): """Grapheme-to-phoeneme (G2P) conversion """ + cdef list njd_results, labels njd_results, labels = self.run_frontend(text) if not kana: prons = list(map(lambda s: s.split("-")[1].split("+")[0], labels[1:-1])) diff --git a/pyopenjtalk/openjtalk/__init__.pxd b/pyopenjtalk/openjtalk/__init__.pxd index e69de29..019523c 100644 --- a/pyopenjtalk/openjtalk/__init__.pxd +++ b/pyopenjtalk/openjtalk/__init__.pxd @@ -0,0 +1 @@ +# cython: language_level=3 \ No newline at end of file diff --git a/pyopenjtalk/openjtalk/jpcommon.pxd b/pyopenjtalk/openjtalk/jpcommon.pxd index 8e86bea..3667d1e 100644 --- a/pyopenjtalk/openjtalk/jpcommon.pxd +++ b/pyopenjtalk/openjtalk/jpcommon.pxd @@ -1,22 +1,24 @@ # distutils: language = c++ +# cython: language_level=3 from libc.stdio cimport FILE -cdef extern from "jpcommon.h": + +cdef extern from "jpcommon.h" nogil: cdef cppclass JPCommonNode: - char *pron - char *pos - char *ctype - char *cform - int acc - int chain_flag - void *prev - void *next + char *pron + char *pos + char *ctype + char *cform + int acc + int chain_flag + void *prev + void *next cdef cppclass JPCommon: - JPCommonNode *head - JPCommonNode *tail - void *label + JPCommonNode *head + JPCommonNode *tail + void *label void JPCommon_initialize(JPCommon * jpcommon) void JPCommon_push(JPCommon * jpcommon, JPCommonNode * node) diff --git a/pyopenjtalk/openjtalk/mecab.pxd b/pyopenjtalk/openjtalk/mecab.pxd index bd367c7..2aa39e5 100644 --- a/pyopenjtalk/openjtalk/mecab.pxd +++ b/pyopenjtalk/openjtalk/mecab.pxd @@ -1,6 +1,7 @@ # distutils: language = c++ +# cython: language_level=3 -cdef extern from "mecab.h": +cdef extern from "mecab.h" nogil: cdef cppclass Mecab: char **feature int size diff --git a/pyopenjtalk/openjtalk/mecab2njd.pxd b/pyopenjtalk/openjtalk/mecab2njd.pxd index be57ccc..f42d3ec 100644 --- a/pyopenjtalk/openjtalk/mecab2njd.pxd +++ b/pyopenjtalk/openjtalk/mecab2njd.pxd @@ -1,6 +1,8 @@ # distutils: language = c++ +# cython: language_level=3 -from .njd cimport NJD +from pyopenjtalk.openjtalk.njd cimport NJD -cdef extern from "mecab2njd.h": + +cdef extern from "mecab2njd.h" nogil: void mecab2njd(NJD * njd, char **feature, int size); diff --git a/pyopenjtalk/openjtalk/njd.pxd b/pyopenjtalk/openjtalk/njd.pxd index 38d3887..ef2b8b3 100644 --- a/pyopenjtalk/openjtalk/njd.pxd +++ b/pyopenjtalk/openjtalk/njd.pxd @@ -1,8 +1,10 @@ # distutils: language = c++ +# cython: language_level=3 from libc.stdio cimport FILE -cdef extern from "njd.h": + +cdef extern from "njd.h" nogil: cdef cppclass NJDNode: char *string char *pos @@ -78,20 +80,20 @@ cdef extern from "njd.h": void NJD_refresh(NJD * njd) void NJD_clear(NJD * wl) -cdef extern from "njd_set_accent_phrase.h": +cdef extern from "njd_set_accent_phrase.h" nogil: void njd_set_accent_phrase(NJD * njd) -cdef extern from "njd_set_accent_type.h": +cdef extern from "njd_set_accent_type.h" nogil: void njd_set_accent_type(NJD * njd) -cdef extern from "njd_set_digit.h": +cdef extern from "njd_set_digit.h" nogil: void njd_set_digit(NJD * njd) -cdef extern from "njd_set_long_vowel.h": +cdef extern from "njd_set_long_vowel.h" nogil: void njd_set_long_vowel(NJD * njd) -cdef extern from "njd_set_pronunciation.h": +cdef extern from "njd_set_pronunciation.h" nogil: void njd_set_pronunciation(NJD * njd) -cdef extern from "njd_set_unvoiced_vowel.h": +cdef extern from "njd_set_unvoiced_vowel.h" nogil: void njd_set_unvoiced_vowel(NJD * njd) diff --git a/pyopenjtalk/openjtalk/njd2jpcommon.pxd b/pyopenjtalk/openjtalk/njd2jpcommon.pxd index 8309288..a803032 100644 --- a/pyopenjtalk/openjtalk/njd2jpcommon.pxd +++ b/pyopenjtalk/openjtalk/njd2jpcommon.pxd @@ -1,7 +1,9 @@ # distutils: language = c++ +# cython: language_level=3 -from .jpcommon cimport JPCommon -from .njd cimport NJD +from pyopenjtalk.openjtalk.jpcommon cimport JPCommon +from pyopenjtalk.openjtalk.njd cimport NJD -cdef extern from "njd2jpcommon.h": + +cdef extern from "njd2jpcommon.h" nogil: void njd2jpcommon(JPCommon * jpcommon, NJD * njd) diff --git a/pyopenjtalk/openjtalk/text2mecab.pxd b/pyopenjtalk/openjtalk/text2mecab.pxd index 6081757..0190f1a 100644 --- a/pyopenjtalk/openjtalk/text2mecab.pxd +++ b/pyopenjtalk/openjtalk/text2mecab.pxd @@ -1,4 +1,5 @@ # distutils: language = c++ +# cython: language_level=3 -cdef extern from "text2mecab.h": +cdef extern from "text2mecab.h" nogil: void text2mecab(char *output, const char *input) diff --git a/setup.py b/setup.py index c68db9c..ce6e3fe 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,9 @@ import os +import platform +import shutil import subprocess import sys +import tarfile from distutils.errors import DistutilsExecError from distutils.spawn import spawn from distutils.version import LooseVersion @@ -8,6 +11,7 @@ from itertools import chain from os.path import exists, join from subprocess import run +from urllib.request import urlretrieve import numpy as np import setuptools.command.build_py @@ -33,6 +37,10 @@ "/execution-charset:utf-8", ] +_dict_folder_name = "open_jtalk_dic_utf_8-1.11" +_dict_download_url = "https://github.com/r9y9/open_jtalk/releases/download/v1.11.1" +_DICT_URL = f"{_dict_download_url}/{_dict_folder_name}.tar.gz" + try: if not _CYTHON_INSTALLED: raise ImportError("No supported version of Cython installed.") @@ -70,6 +78,22 @@ def build_extensions(self): if not os.path.exists(join("pyopenjtalk", "openjtalk" + ext)): raise RuntimeError("Cython is required to generate C++ code") +# make openmp available +system = platform.system() +if system == "Windows": + extra_compile_args = [] + extra_link_args = ["/openmp"] +elif system == "Linux": + extra_compile_args = ["-fopenmp"] + extra_link_args = ["-fopenmp"] +elif system == "Darwin": + os.system("brew install libomp") + extra_compile_args = ["-Xpreprocessor", "-fopenmp"] + extra_link_args = ["-L/usr/local/lib", "-lomp"] +else: + extra_compile_args = ["-fopenmp"] + extra_link_args = ["-fopenmp"] + # Workaround for `distutils.spawn` problem on Windows python < 3.9 # See details: [bpo-39763: distutils.spawn now uses subprocess (GH-18743)] @@ -138,6 +162,24 @@ def escape_macros(macros): # open_jtalk sources src_top = join("lib", "open_jtalk", "src") + +# extract dic +filename = "dic.tar.gz" +print(f"Downloading: {_DICT_URL}") +urlretrieve(_DICT_URL, filename) +print("Download complete") + +print("Extracting tar file {}".format(filename)) +with tarfile.open(filename, mode="r|gz") as f: + f.extractall(path="./") +os.remove(filename) +print("Extract complete") +try: + shutil.copytree(f"./{_dict_folder_name}", f"./pyopenjtalk/{_dict_folder_name}") + sys.stdout.flush() +except FileExistsError: + pass + # generate config.h for mecab # NOTE: need to run cmake to generate config.h # we could do it on python side but it would be very tricky, @@ -180,8 +222,8 @@ def escape_macros(macros): name="pyopenjtalk.openjtalk", sources=[join("pyopenjtalk", "openjtalk" + ext)] + all_src, include_dirs=[np.get_include()] + include_dirs, - extra_compile_args=[], - extra_link_args=[], + extra_compile_args=extra_compile_args, + extra_link_args=extra_link_args, language="c++", define_macros=custom_define_macros( [ @@ -204,8 +246,8 @@ def escape_macros(macros): name="pyopenjtalk.htsengine", sources=[join("pyopenjtalk", "htsengine" + ext)] + all_htsengine_src, include_dirs=[np.get_include(), join(htsengine_src_top, "include")], - extra_compile_args=[], - extra_link_args=[], + extra_compile_args=extra_compile_args, + extra_link_args=extra_link_args, libraries=["winmm"] if platform_is_windows else [], language="c++", define_macros=custom_define_macros( @@ -272,14 +314,12 @@ def run(self): url="https://github.com/r9y9/pyopenjtalk", license="MIT", packages=find_packages(), - package_data={"": ["htsvoice/*"]}, + package_data={"": ["htsvoice/*", f"{_dict_folder_name}/*"]}, ext_modules=ext_modules, cmdclass=cmdclass, install_requires=[ "numpy >= 1.20.0", "cython >= " + min_cython_ver, - "six", - "tqdm", ], tests_require=["nose", "coverage"], extras_require={