Skip to content

Commit

Permalink
Merge pull request #87 from sabonerune/enh/add-mutex
Browse files Browse the repository at this point in the history
ENH: Add mutex for global instance
  • Loading branch information
r9y9 authored Dec 10, 2024
2 parents a3e2115 + 5374f0c commit 3c9894f
Showing 1 changed file with 76 additions and 55 deletions.
131 changes: 76 additions & 55 deletions pyopenjtalk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from __future__ import annotations

import atexit
import os
import sys
import tarfile
import tempfile
from contextlib import ExitStack
from collections.abc import Callable, Generator
from contextlib import ExitStack, contextmanager
from os.path import exists
from threading import Lock
from typing import TypeVar
from urllib.request import urlopen

if sys.version_info >= (3, 9):
Expand Down Expand Up @@ -44,14 +49,6 @@
)
).encode("utf-8")

# Global instance of OpenJTalk
_global_jtalk = None
# Global instance of HTSEngine
# mei_normal.voice is used as default
_global_htsengine = None
# Global instance of Marine
_global_marine = None


def _extract_dic():
from tqdm.auto import tqdm
Expand All @@ -78,6 +75,49 @@ def _lazy_init():
_extract_dic()


_T = TypeVar("_T")


def _global_instance_manager(
instance_factory: Callable[[], _T] | None = None, instance: _T | None = None
) -> Callable[[], Generator[_T, None, None]]:
assert instance_factory is not None or instance is not None
_instance = instance
mutex = Lock()

@contextmanager
def manager() -> Generator[_T, None, None]:
nonlocal _instance
with mutex:
if _instance is None:
_instance = instance_factory()
yield _instance

return manager


def _jtalk_factory() -> OpenJTalk:
_lazy_init()
return OpenJTalk(dn_mecab=OPEN_JTALK_DICT_DIR)


def _marine_factory():
try:
from marine.predict import Predictor
except ImportError:
raise ImportError("Please install marine by `pip install pyopenjtalk[marine]`")
return Predictor()


# Global instance of OpenJTalk
_global_jtalk = _global_instance_manager(_jtalk_factory)
# Global instance of HTSEngine
# mei_normal.voice is used as default
_global_htsengine = _global_instance_manager(lambda: HTSEngine(DEFAULT_HTS_VOICE))
# Global instance of Marine
_global_marine = _global_instance_manager(_marine_factory)


def g2p(*args, **kwargs):
"""Grapheme-to-phoeneme (G2P) conversion
Expand All @@ -93,11 +133,8 @@ def g2p(*args, **kwargs):
Returns:
str or list: G2P result in 1) str if join is True 2) list if join is False.
"""
global _global_jtalk
if _global_jtalk is None:
_lazy_init()
_global_jtalk = OpenJTalk(dn_mecab=OPEN_JTALK_DICT_DIR)
return _global_jtalk.g2p(*args, **kwargs)
with _global_jtalk() as jtalk:
return jtalk.g2p(*args, **kwargs)


def estimate_accent(njd_features):
Expand All @@ -111,21 +148,13 @@ def estimate_accent(njd_features):
Returns:
list: features for NJDNode with estimation results by marine.
"""
global _global_marine
if _global_marine is None:
try:
from marine.predict import Predictor
except BaseException:
raise ImportError(
"Please install marine by `pip install pyopenjtalk[marine]`"
)
_global_marine = Predictor()
from marine.utils.openjtalk_util import convert_njd_feature_to_marine_feature
with _global_marine() as marine:
from marine.utils.openjtalk_util import convert_njd_feature_to_marine_feature

marine_feature = convert_njd_feature_to_marine_feature(njd_features)
marine_results = _global_marine.predict(
[marine_feature], require_open_jtalk_format=True
)
marine_feature = convert_njd_feature_to_marine_feature(njd_features)
marine_results = marine.predict(
[marine_feature], require_open_jtalk_format=True
)
njd_features = merge_njd_marine_features(njd_features, marine_results)
return njd_features

Expand Down Expand Up @@ -164,13 +193,11 @@ def synthesize(labels, speed=1.0, half_tone=0.0):
if isinstance(labels, tuple) and len(labels) == 2:
labels = labels[1]

global _global_htsengine
if _global_htsengine is None:
_global_htsengine = HTSEngine(DEFAULT_HTS_VOICE)
sr = _global_htsengine.get_sampling_frequency()
_global_htsengine.set_speed(speed)
_global_htsengine.add_half_tone(half_tone)
return _global_htsengine.synthesize(labels), sr
with _global_htsengine() as htsengine:
sr = htsengine.get_sampling_frequency()
htsengine.set_speed(speed)
htsengine.add_half_tone(half_tone)
return htsengine.synthesize(labels), sr


def tts(text, speed=1.0, half_tone=0.0, run_marine=False):
Expand Down Expand Up @@ -202,11 +229,8 @@ def run_frontend(text):
Returns:
list: features for NJDNode.
"""
global _global_jtalk
if _global_jtalk is None:
_lazy_init()
_global_jtalk = OpenJTalk(dn_mecab=OPEN_JTALK_DICT_DIR)
return _global_jtalk.run_frontend(text)
with _global_jtalk() as jtalk:
return jtalk.run_frontend(text)


def make_label(njd_features):
Expand All @@ -218,11 +242,8 @@ def make_label(njd_features):
Returns:
list: full-context labels.
"""
global _global_jtalk
if _global_jtalk is None:
_lazy_init()
_global_jtalk = OpenJTalk(dn_mecab=OPEN_JTALK_DICT_DIR)
return _global_jtalk.make_label(njd_features)
with _global_jtalk() as jtalk:
return jtalk.make_label(njd_features)


def mecab_dict_index(path, out_path, dn_mecab=None):
Expand All @@ -233,12 +254,11 @@ def mecab_dict_index(path, out_path, dn_mecab=None):
out_path (str): path to output dictionary
dn_mecab (optional. str): path to mecab dictionary
"""
global _global_jtalk
if _global_jtalk is None:
_lazy_init()
if not exists(path):
raise FileNotFoundError("no such file or directory: %s" % path)
if dn_mecab is None:
with _global_jtalk(): # call _lazy_init()
pass
dn_mecab = OPEN_JTALK_DICT_DIR
r = _mecab_dict_index(dn_mecab, path.encode("utf-8"), out_path.encode("utf-8"))

Expand All @@ -257,10 +277,11 @@ def update_global_jtalk_with_user_dict(path):
path (str): path to user dictionary
"""
global _global_jtalk
if _global_jtalk is None:
_lazy_init()
if not exists(path):
raise FileNotFoundError("no such file or directory: %s" % path)
_global_jtalk = OpenJTalk(
dn_mecab=OPEN_JTALK_DICT_DIR, userdic=path.encode("utf-8")
)
with _global_jtalk():
if not exists(path):
raise FileNotFoundError("no such file or directory: %s" % path)
_global_jtalk = _global_instance_manager(
instance=OpenJTalk(
dn_mecab=OPEN_JTALK_DICT_DIR, userdic=path.encode("utf-8")
)
)

0 comments on commit 3c9894f

Please sign in to comment.