Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up by releasing gil and using openmp with zerocopy view #36

Open
wants to merge 23 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions .github/workflows/build_whl.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
name: build wheel

on:
workflow_dispatch:

jobs:
build:

runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: ["3.6", "3.7", "3.8", "3.9", "3.10"]
os: [ubuntu-latest, macos-latest, windows-latest]
fail-fast: false

steps:
- uses: actions/checkout@v2
- name: Check out recursively
run: git submodule update --init --recursive
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install --upgrade setuptools
python -m pip install --upgrade wheel
pip install flake8 pytest
pip install cython numpy tqdm
- name: build_whl
run: |
python setup.py sdist bdist_wheel
- uses: actions/upload-artifact@v2
with:
name: ${{ matrix.os }}-${{ matrix.python-version }}
path: dist/*.whl
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,4 @@ Temporary Items

# Linux trash folder which might appear on any partition or disk
.Trash-*
dic.tar.gz
55 changes: 3 additions & 52 deletions pyopenjtalk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,21 @@
import os
from os.path import exists

import pkg_resources
import six
from tqdm.auto import tqdm

if six.PY2:
from urllib import urlretrieve
else:
from urllib.request import urlretrieve

import tarfile

try:
from .version import __version__ # NOQA
from pyopenjtalk.version import __version__ # NOQA
except ImportError:
raise ImportError("BUG: version.py doesn't exist. Please file a bug report.")

from .htsengine import HTSEngine
from .openjtalk import OpenJTalk
from pyopenjtalk.htsengine import HTSEngine
from pyopenjtalk.openjtalk import OpenJTalk

# Dictionary directory
# defaults to the package directory where the dictionary will be automatically downloaded
OPEN_JTALK_DICT_DIR = os.environ.get(
"OPEN_JTALK_DICT_DIR",
pkg_resources.resource_filename(__name__, "open_jtalk_dic_utf_8-1.11"),
).encode("utf-8")
_dict_download_url = "https://github.com/r9y9/open_jtalk/releases/download/v1.11.1"
_DICT_URL = f"{_dict_download_url}/open_jtalk_dic_utf_8-1.11.tar.gz"

# Default mei_normal.voice for HMM-based TTS
DEFAULT_HTS_VOICE = pkg_resources.resource_filename(
Expand All @@ -41,41 +29,6 @@
_global_htsengine = None


# https://github.com/tqdm/tqdm#hooks-and-callbacks
class _TqdmUpTo(tqdm): # type: ignore
def update_to(self, b=1, bsize=1, tsize=None):
if tsize is not None:
self.total = tsize
return self.update(b * bsize - self.n)
Comment on lines -45 to -49
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is the tqdm-related code removed? If it doesn't cause any problem, could you revert it back?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is still in steup.py and this removing is part of the

Consider removing the lazy downloader in pyopenjtalk/__init__.py

I said yesterday.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant why the following code:

with TqdmUpTo(...):
  urlretrieve(...)

is reduced to

urlretrieve(...)

I think the former was okay as is.

Copy link

@fumiama fumiama Aug 15, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, I just noticed that this code was removed after my changes by @synodriver . In my opinion, since this tqdm is only appears in setup time, whether removing or remaining it is okey. So what is your purpose of removing this?😂

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's just a progress bar and not necessary. Besides, I noticed some wired issue about downloading when runing ci with tqdm. I have no idea what's happening but to remove it. Then the ci works fine.

Copy link

@fumiama fumiama Aug 16, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That may be an acceptable reason of removing this. Or we can pass an env variable to disable it in CI and enable it by default when user want to compile it natively?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have never had issues with tqdm on CI so far. Is the problem reproducible? If so, it's okay to remove it.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have run it several times on my fork but it always fail to build with tqdm installed. It's just like... the stdout is hiden, making it difficult to debug. And without tqdm it works fine. BTW, why does the progress bar so important for this project?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Umm, OK. The progress bar was just useful for knowing the estimated time to finish the downloading process.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But with prebuild-wheels, there is no need for users to download. You can check the files I mentioned here, the size of them shows that they already contains data files in the wheel.



def _extract_dic():
global OPEN_JTALK_DICT_DIR
filename = pkg_resources.resource_filename(__name__, "dic.tar.gz")
print('Downloading: "{}"'.format(_DICT_URL))
with _TqdmUpTo(
unit="B",
unit_scale=True,
unit_divisor=1024,
miniters=1,
desc="dic.tar.gz",
) as t: # all optional kwargs
urlretrieve(_DICT_URL, filename, reporthook=t.update_to)
t.total = t.n
print("Extracting tar file {}".format(filename))
with tarfile.open(filename, mode="r|gz") as f:
f.extractall(path=pkg_resources.resource_filename(__name__, ""))
OPEN_JTALK_DICT_DIR = pkg_resources.resource_filename(
__name__, "open_jtalk_dic_utf_8-1.11"
).encode("utf-8")
os.remove(filename)


def _lazy_init():
if not exists(OPEN_JTALK_DICT_DIR):
_extract_dic()


def g2p(*args, **kwargs):
"""Grapheme-to-phoeneme (G2P) conversion

Expand All @@ -93,7 +46,6 @@ def g2p(*args, **kwargs):
"""
global _global_jtalk
if _global_jtalk is None:
_lazy_init()
_global_jtalk = OpenJTalk(dn_mecab=OPEN_JTALK_DICT_DIR)
return _global_jtalk.g2p(*args, **kwargs)

Expand Down Expand Up @@ -164,6 +116,5 @@ def run_frontend(text, verbose=0):
"""
global _global_jtalk
if _global_jtalk is None:
_lazy_init()
_global_jtalk = OpenJTalk(dn_mecab=OPEN_JTALK_DICT_DIR)
return _global_jtalk.run_frontend(text, verbose)
119 changes: 73 additions & 46 deletions pyopenjtalk/htsengine.pyx
Original file line number Diff line number Diff line change
@@ -1,25 +1,33 @@
# coding: utf-8
# cython: boundscheck=True, wraparound=True
# cython: c_string_type=unicode, c_string_encoding=ascii
# cython: language_level=3
# cython: boundscheck=False, wraparound=False
# cython: c_string_type=unicode, c_string_encoding=ascii, cdivision=True

import numpy as np

cimport numpy as np

np.import_array()

cimport cython
from libc.stdlib cimport malloc, free

from htsengine cimport HTS_Engine
from htsengine cimport (
HTS_Engine_initialize, HTS_Engine_load, HTS_Engine_clear, HTS_Engine_refresh,
HTS_Engine_get_sampling_frequency, HTS_Engine_get_fperiod,
HTS_Engine_set_speed, HTS_Engine_add_half_tone,
HTS_Engine_synthesize_from_strings,
HTS_Engine_get_generated_speech, HTS_Engine_get_nsamples
)

cdef class HTSEngine(object):
from cpython.mem cimport PyMem_Free, PyMem_Malloc
from cython.parallel cimport prange
from libc.stdint cimport uint8_t

from pyopenjtalk.htsengine cimport (HTS_Engine, HTS_Engine_add_half_tone,
HTS_Engine_clear, HTS_Engine_get_fperiod,
HTS_Engine_get_generated_speech,
HTS_Engine_get_nsamples,
HTS_Engine_get_sampling_frequency,
HTS_Engine_initialize, HTS_Engine_load,
HTS_Engine_refresh, HTS_Engine_set_speed,
HTS_Engine_synthesize_from_strings)


@cython.final
@cython.no_gc
@cython.freelist(4)
cdef class HTSEngine:
"""HTSEngine

Args:
Expand All @@ -36,38 +44,48 @@ cdef class HTSEngine(object):
self.clear()
raise RuntimeError("Failed to initalize HTS_Engine")

def load(self, bytes voice):
cdef char* voices = voice
cdef char ret
ret = HTS_Engine_load(self.engine, &voices, 1)
cpdef inline char load(self, const uint8_t[::1] voice):
cdef:
char ret
const uint8_t *voice_ptr = &voice[0]
with nogil:
ret = HTS_Engine_load(self.engine, <char**>(&voice_ptr), 1)
return ret

def get_sampling_frequency(self):
cpdef inline size_t get_sampling_frequency(self):
"""Get sampling frequency
"""
return HTS_Engine_get_sampling_frequency(self.engine)
cdef size_t ret
with nogil:
ret = HTS_Engine_get_sampling_frequency(self.engine)
return ret

def get_fperiod(self):
cpdef inline size_t get_fperiod(self):
"""Get frame period"""
return HTS_Engine_get_fperiod(self.engine)
cdef size_t ret
with nogil:
ret = HTS_Engine_get_fperiod(self.engine)
return ret

def set_speed(self, speed=1.0):
cpdef inline void set_speed(self, double speed=1.0):
"""Set speed

Args:
speed (float): speed
"""
HTS_Engine_set_speed(self.engine, speed)
with nogil:
HTS_Engine_set_speed(self.engine, speed)

def add_half_tone(self, half_tone=0.0):
cpdef inline void add_half_tone(self, double half_tone=0.0):
"""Additional half tone in log-f0

Args:
half_tone (float): additional half tone
"""
HTS_Engine_add_half_tone(self.engine, half_tone)
with nogil:
HTS_Engine_add_half_tone(self.engine, half_tone)

def synthesize(self, list labels):
cpdef inline np.ndarray[np.float64_t, ndim=1] synthesize(self, list labels):
"""Synthesize waveform from list of full-context labels

Args:
Expand All @@ -77,40 +95,49 @@ cdef class HTSEngine(object):
np.ndarray: speech waveform
"""
self.synthesize_from_strings(labels)
x = self.get_generated_speech()
cdef np.ndarray[np.float64_t, ndim=1] x = self.get_generated_speech()
self.refresh()
return x

def synthesize_from_strings(self, list labels):
cpdef inline char synthesize_from_strings(self, list labels) except? 0:
"""Synthesize from strings"""
cdef size_t num_lines = len(labels)
cdef char **lines = <char**> malloc((num_lines + 1) * sizeof(char*))
cdef char **lines = <char**> PyMem_Malloc((num_lines + 1) * sizeof(char*))
cdef int n
for n in range(len(labels)):
lines[n] = <char*>labels[n]

cdef char ret = HTS_Engine_synthesize_from_strings(self.engine, lines, num_lines)
free(lines)
cdef char ret
with nogil:
ret = HTS_Engine_synthesize_from_strings(self.engine, lines, num_lines)
PyMem_Free(lines) # todo: use finally
if ret != 1:
raise RuntimeError("Failed to run synthesize_from_strings")
return ret

def get_generated_speech(self):
cpdef inline np.ndarray[np.float64_t, ndim=1] get_generated_speech(self):
"""Get generated speech"""
cdef size_t nsamples = HTS_Engine_get_nsamples(self.engine)
cdef np.ndarray speech = np.zeros([nsamples], dtype=np.float64)
cdef size_t index
for index in range(nsamples):
speech[index] = HTS_Engine_get_generated_speech(self.engine, index)
cdef np.ndarray[np.float64_t, ndim=1] speech = np.zeros([nsamples], dtype=np.float64)
cdef double[::1] speech_view = speech
cdef int index
for index in prange(nsamples, nogil=True):
speech_view[index] = HTS_Engine_get_generated_speech(self.engine, <size_t>index)
return speech

def get_fullcontext_label_format(self):
cpdef inline str get_fullcontext_label_format(self):
"""Get full-context label format"""
return (<bytes>HTS_Engine_get_fullcontext_label_format(self.engine)).decode("utf-8")

def refresh(self):
HTS_Engine_refresh(self.engine)

def clear(self):
HTS_Engine_clear(self.engine)
cdef const char* f
with nogil:
f = HTS_Engine_get_fullcontext_label_format(self.engine)
return (<bytes>f).decode("utf-8")

cpdef inline void refresh(self):
with nogil:
HTS_Engine_refresh(self.engine)

cpdef inline void clear(self):
with nogil:
HTS_Engine_clear(self.engine)

def __dealloc__(self):
self.clear()
Expand Down
4 changes: 3 additions & 1 deletion pyopenjtalk/htsengine/__init__.pxd
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# distutils: language = c++


cdef extern from "HTS_engine.h":
# cython: language_level=3

cdef extern from "HTS_engine.h" nogil:
cdef cppclass _HTS_Engine:
pass
ctypedef _HTS_Engine HTS_Engine
Expand Down
Loading