diff --git a/README.md b/README.md index 76b7369..6df5ee5 100644 --- a/README.md +++ b/README.md @@ -105,6 +105,38 @@ In [3]: pyopenjtalk.g2p("こんにちは", kana=True) Out[3]: 'コンニチワ' ``` +### Create/Apply user dictionary + +1. Create a CSV file (e.g. `user.csv`) and write custom words like below: + +```csv +GNU,,,1,名詞,一般,*,*,*,*,GNU,グヌー,グヌー,2/3,* +``` + +2. Call `mecab_dict_index` to compile the CSV file. + +```python +In [1]: import pyopenjtalk + +In [2]: pyopenjtalk.mecab_dict_index("user.csv", "user.dic") +reading user.csv ... 1 +emitting double-array: 100% |###########################################| + +done! +``` + +3. Call `update_global_jtalk_with_user_dict` to apply the user dictionary. + +```python +In [3]: pyopenjtalk.g2p("GNU") +Out[3]: 'j i i e n u y u u' + +In [4]: pyopenjtalk.update_global_jtalk_with_user_dict("user.dic") + +In [5]: pyopenjtalk.g2p("GNU") +Out[5]: 'g u n u u' +``` + ### About `run_marine` option After v0.3.0, the `run_marine` option has been available for estimating the Japanese accent with the DNN-based method (see [marine](https://github.com/6gsn/marine)). If you want to use the feature, please install pyopenjtalk as below; diff --git a/pyopenjtalk/__init__.py b/pyopenjtalk/__init__.py index 299ac29..72f5db5 100644 --- a/pyopenjtalk/__init__.py +++ b/pyopenjtalk/__init__.py @@ -19,6 +19,7 @@ from .htsengine import HTSEngine from .openjtalk import OpenJTalk +from .openjtalk import mecab_dict_index as _mecab_dict_index from .utils import merge_njd_marine_features # Dictionary directory @@ -224,3 +225,44 @@ def make_label(njd_features): _lazy_init() _global_jtalk = OpenJTalk(dn_mecab=OPEN_JTALK_DICT_DIR) return _global_jtalk.make_label(njd_features) + + +def mecab_dict_index(path, out_path, dn_mecab=None): + """Create user dictionary + + Args: + path (str): path to user csv + out_path (str): path to output dictionary + dn_mecab (optional. str): path to mecab dictionary + """ + global _global_jtalk + if _global_jtalk is None: + _lazy_init() + if not exists(path): + raise FileNotFoundError("no such file or directory: %s" % path) + if dn_mecab is None: + dn_mecab = OPEN_JTALK_DICT_DIR + r = _mecab_dict_index(dn_mecab, path.encode("utf-8"), out_path.encode("utf-8")) + + # NOTE: mecab load returns 1 if success, but mecab_dict_index return the opposite + # yeah it's confusing... + if r != 0: + raise RuntimeError("Failed to create user dictionary") + + +def update_global_jtalk_with_user_dict(path): + """Update global openjtalk instance with the user dictionary + + Note that this will change the global state of the openjtalk module. + + Args: + path (str): path to user dictionary + """ + global _global_jtalk + if _global_jtalk is None: + _lazy_init() + if not exists(path): + raise FileNotFoundError("no such file or directory: %s" % path) + _global_jtalk = OpenJTalk( + dn_mecab=OPEN_JTALK_DICT_DIR, userdic=path.encode("utf-8") + ) diff --git a/pyopenjtalk/openjtalk.pyx b/pyopenjtalk/openjtalk.pyx index 96505b8..291311e 100644 --- a/pyopenjtalk/openjtalk.pyx +++ b/pyopenjtalk/openjtalk.pyx @@ -9,9 +9,12 @@ np.import_array() cimport cython from libc.stdlib cimport calloc +from libc.string cimport strlen from .openjtalk.mecab cimport Mecab, Mecab_initialize, Mecab_load, Mecab_analysis from .openjtalk.mecab cimport Mecab_get_feature, Mecab_get_size, Mecab_refresh, Mecab_clear +from .openjtalk.mecab cimport createModel, Model, Tagger, Lattice +from .openjtalk.mecab cimport mecab_dict_index as _mecab_dict_index from .openjtalk.njd cimport NJD, NJD_initialize, NJD_refresh, NJD_print, NJD_clear from .openjtalk cimport njd as _njd from .openjtalk.jpcommon cimport JPCommon, JPCommon_initialize,JPCommon_make_label @@ -116,18 +119,52 @@ cdef feature2njd(_njd.NJD* njd, features): _njd.NJDNode_set_chain_flag(node, feature_node["chain_flag"]) _njd.NJD_push_node(njd, node) +# based on Mecab_load in impl. from mecab.cpp +cdef inline int Mecab_load_with_userdic(Mecab *m, char* dicdir, char* userdic): + if userdic == NULL or strlen(userdic) == 0: + return Mecab_load(m, dicdir) + + if m == NULL or dicdir == NULL or strlen(dicdir) == 0: + return 0 + + Mecab_clear(m) + + cdef (char*)[5] argv = ["mecab", "-d", dicdir, "-u", userdic] + cdef Model *model = createModel(5, argv) + + if model == NULL: + return 0 + m.model = model + + cdef Tagger *tagger = model.createTagger() + if tagger == NULL: + Mecab_clear(m) + return 0 + m.tagger = tagger + + cdef Lattice *lattice = model.createLattice() + if lattice == NULL: + Mecab_clear(m) + return 0 + m.lattice = lattice + + return 1 + cdef class OpenJTalk(object): """OpenJTalk Args: dn_mecab (bytes): Dictionaly path for MeCab. + userdic (bytes): Dictionary path for MeCab userdic. + This option is ignored when empty bytestring is given. + Default is empty. """ cdef Mecab* mecab cdef NJD* njd cdef JPCommon* jpcommon - def __cinit__(self, bytes dn_mecab=b"/usr/local/dic"): + def __cinit__(self, bytes dn_mecab=b"/usr/local/dic", bytes userdic=b""): self.mecab = new Mecab() self.njd = new NJD() self.jpcommon = new JPCommon() @@ -136,7 +173,7 @@ cdef class OpenJTalk(object): NJD_initialize(self.njd) JPCommon_initialize(self.jpcommon) - r = self._load(dn_mecab) + r = self._load(dn_mecab, userdic) if r != 1: self._clear() raise RuntimeError("Failed to initalize Mecab") @@ -147,8 +184,8 @@ cdef class OpenJTalk(object): NJD_clear(self.njd) JPCommon_clear(self.jpcommon) - def _load(self, bytes dn_mecab): - return Mecab_load(self.mecab, dn_mecab) + def _load(self, bytes dn_mecab, bytes userdic): + return Mecab_load_with_userdic(self.mecab, dn_mecab, userdic) def run_frontend(self, text): @@ -231,3 +268,18 @@ cdef class OpenJTalk(object): del self.mecab del self.njd del self.jpcommon + +def mecab_dict_index(bytes dn_mecab, bytes path, bytes out_path): + cdef (char*)[10] argv = [ + "mecab-dict-index", + "-d", + dn_mecab, + "-u", + out_path, + "-f", + "utf-8", + "-t", + "utf-8", + path + ] + return _mecab_dict_index(10, argv) diff --git a/pyopenjtalk/openjtalk/mecab.pxd b/pyopenjtalk/openjtalk/mecab.pxd index bd367c7..1538e05 100644 --- a/pyopenjtalk/openjtalk/mecab.pxd +++ b/pyopenjtalk/openjtalk/mecab.pxd @@ -16,3 +16,14 @@ cdef extern from "mecab.h": char **Mecab_get_feature(Mecab *m) cdef int Mecab_refresh(Mecab *m) cdef int Mecab_clear(Mecab *m) + cdef int mecab_dict_index(int argc, char **argv) + +cdef extern from "mecab.h" namespace "MeCab": + cdef cppclass Tagger: + pass + cdef cppclass Lattice: + pass + cdef cppclass Model: + Tagger *createTagger() + Lattice *createLattice() + cdef Model *createModel(int argc, char **argv) diff --git a/tests/test_data/.gitignore b/tests/test_data/.gitignore new file mode 100644 index 0000000..d6b7ef3 --- /dev/null +++ b/tests/test_data/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore diff --git a/tests/test_openjtalk.py b/tests/test_openjtalk.py index 0f43363..adfc24e 100644 --- a/tests/test_openjtalk.py +++ b/tests/test_openjtalk.py @@ -1,3 +1,5 @@ +from pathlib import Path + import pyopenjtalk @@ -80,3 +82,29 @@ def test_g2p_phone(): ]: p = pyopenjtalk.g2p(text, kana=False) assert p == pron + + +def test_userdic(): + for text, expected in [ + ("nnmn", "n a n a m i N"), + ("GNU", "g u n u u"), + ]: + p = pyopenjtalk.g2p(text) + assert p != expected + + user_csv = str(Path(__file__).parent / "test_data" / "user.csv") + user_dic = str(Path(__file__).parent / "test_data" / "user.dic") + + with open(user_csv, "w", encoding="utf-8") as f: + f.write("nnmn,,,1,名詞,一般,*,*,*,*,nnmn,ナナミン,ナナミン,1/4,*\n") + f.write("GNU,,,1,名詞,一般,*,*,*,*,GNU,グヌー,グヌー,2/3,*\n") + + pyopenjtalk.mecab_dict_index(f.name, user_dic) + pyopenjtalk.update_global_jtalk_with_user_dict(user_dic) + + for text, expected in [ + ("nnmn", "n a n a m i N"), + ("GNU", "g u n u u"), + ]: + p = pyopenjtalk.g2p(text) + assert p == expected