Skip to content

Commit

Permalink
Merge pull request #72 from r9y9/userdic
Browse files Browse the repository at this point in the history
Add user dic support
  • Loading branch information
r9y9 authored Nov 18, 2023
2 parents 6928f45 + 6740742 commit 26fcdd9
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 4 deletions.
32 changes: 32 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,38 @@ In [3]: pyopenjtalk.g2p("こんにちは", kana=True)
Out[3]: 'コンニチワ'
```

### Create/Apply user dictionary

1. Create a CSV file (e.g. `user.csv`) and write custom words like below:

```csv
GNU,,,1,名詞,一般,*,*,*,*,GNU,グヌー,グヌー,2/3,*
```

2. Call `mecab_dict_index` to compile the CSV file.

```python
In [1]: import pyopenjtalk

In [2]: pyopenjtalk.mecab_dict_index("user.csv", "user.dic")
reading user.csv ... 1
emitting double-array: 100% |###########################################|

done!
```

3. Call `update_global_jtalk_with_user_dict` to apply the user dictionary.

```python
In [3]: pyopenjtalk.g2p("GNU")
Out[3]: 'j i i e n u y u u'

In [4]: pyopenjtalk.update_global_jtalk_with_user_dict("user.dic")

In [5]: pyopenjtalk.g2p("GNU")
Out[5]: 'g u n u u'
```

### About `run_marine` option

After v0.3.0, the `run_marine` option has been available for estimating the Japanese accent with the DNN-based method (see [marine](https://github.com/6gsn/marine)). If you want to use the feature, please install pyopenjtalk as below;
Expand Down
42 changes: 42 additions & 0 deletions pyopenjtalk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from .htsengine import HTSEngine
from .openjtalk import OpenJTalk
from .openjtalk import mecab_dict_index as _mecab_dict_index
from .utils import merge_njd_marine_features

# Dictionary directory
Expand Down Expand Up @@ -224,3 +225,44 @@ def make_label(njd_features):
_lazy_init()
_global_jtalk = OpenJTalk(dn_mecab=OPEN_JTALK_DICT_DIR)
return _global_jtalk.make_label(njd_features)


def mecab_dict_index(path, out_path, dn_mecab=None):
"""Create user dictionary
Args:
path (str): path to user csv
out_path (str): path to output dictionary
dn_mecab (optional. str): path to mecab dictionary
"""
global _global_jtalk
if _global_jtalk is None:
_lazy_init()
if not exists(path):
raise FileNotFoundError("no such file or directory: %s" % path)
if dn_mecab is None:
dn_mecab = OPEN_JTALK_DICT_DIR
r = _mecab_dict_index(dn_mecab, path.encode("utf-8"), out_path.encode("utf-8"))

# NOTE: mecab load returns 1 if success, but mecab_dict_index return the opposite
# yeah it's confusing...
if r != 0:
raise RuntimeError("Failed to create user dictionary")


def update_global_jtalk_with_user_dict(path):
"""Update global openjtalk instance with the user dictionary
Note that this will change the global state of the openjtalk module.
Args:
path (str): path to user dictionary
"""
global _global_jtalk
if _global_jtalk is None:
_lazy_init()
if not exists(path):
raise FileNotFoundError("no such file or directory: %s" % path)
_global_jtalk = OpenJTalk(
dn_mecab=OPEN_JTALK_DICT_DIR, userdic=path.encode("utf-8")
)
60 changes: 56 additions & 4 deletions pyopenjtalk/openjtalk.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@ np.import_array()

cimport cython
from libc.stdlib cimport calloc
from libc.string cimport strlen

from .openjtalk.mecab cimport Mecab, Mecab_initialize, Mecab_load, Mecab_analysis
from .openjtalk.mecab cimport Mecab_get_feature, Mecab_get_size, Mecab_refresh, Mecab_clear
from .openjtalk.mecab cimport createModel, Model, Tagger, Lattice
from .openjtalk.mecab cimport mecab_dict_index as _mecab_dict_index
from .openjtalk.njd cimport NJD, NJD_initialize, NJD_refresh, NJD_print, NJD_clear
from .openjtalk cimport njd as _njd
from .openjtalk.jpcommon cimport JPCommon, JPCommon_initialize,JPCommon_make_label
Expand Down Expand Up @@ -116,18 +119,52 @@ cdef feature2njd(_njd.NJD* njd, features):
_njd.NJDNode_set_chain_flag(node, feature_node["chain_flag"])
_njd.NJD_push_node(njd, node)

# based on Mecab_load in impl. from mecab.cpp
cdef inline int Mecab_load_with_userdic(Mecab *m, char* dicdir, char* userdic):
if userdic == NULL or strlen(userdic) == 0:
return Mecab_load(m, dicdir)

if m == NULL or dicdir == NULL or strlen(dicdir) == 0:
return 0

Mecab_clear(m)

cdef (char*)[5] argv = ["mecab", "-d", dicdir, "-u", userdic]
cdef Model *model = createModel(5, argv)

if model == NULL:
return 0
m.model = model

cdef Tagger *tagger = model.createTagger()
if tagger == NULL:
Mecab_clear(m)
return 0
m.tagger = tagger

cdef Lattice *lattice = model.createLattice()
if lattice == NULL:
Mecab_clear(m)
return 0
m.lattice = lattice

return 1


cdef class OpenJTalk(object):
"""OpenJTalk
Args:
dn_mecab (bytes): Dictionaly path for MeCab.
userdic (bytes): Dictionary path for MeCab userdic.
This option is ignored when empty bytestring is given.
Default is empty.
"""
cdef Mecab* mecab
cdef NJD* njd
cdef JPCommon* jpcommon

def __cinit__(self, bytes dn_mecab=b"/usr/local/dic"):
def __cinit__(self, bytes dn_mecab=b"/usr/local/dic", bytes userdic=b""):
self.mecab = new Mecab()
self.njd = new NJD()
self.jpcommon = new JPCommon()
Expand All @@ -136,7 +173,7 @@ cdef class OpenJTalk(object):
NJD_initialize(self.njd)
JPCommon_initialize(self.jpcommon)

r = self._load(dn_mecab)
r = self._load(dn_mecab, userdic)
if r != 1:
self._clear()
raise RuntimeError("Failed to initalize Mecab")
Expand All @@ -147,8 +184,8 @@ cdef class OpenJTalk(object):
NJD_clear(self.njd)
JPCommon_clear(self.jpcommon)

def _load(self, bytes dn_mecab):
return Mecab_load(self.mecab, dn_mecab)
def _load(self, bytes dn_mecab, bytes userdic):
return Mecab_load_with_userdic(self.mecab, dn_mecab, userdic)


def run_frontend(self, text):
Expand Down Expand Up @@ -231,3 +268,18 @@ cdef class OpenJTalk(object):
del self.mecab
del self.njd
del self.jpcommon

def mecab_dict_index(bytes dn_mecab, bytes path, bytes out_path):
cdef (char*)[10] argv = [
"mecab-dict-index",
"-d",
dn_mecab,
"-u",
out_path,
"-f",
"utf-8",
"-t",
"utf-8",
path
]
return _mecab_dict_index(10, argv)
11 changes: 11 additions & 0 deletions pyopenjtalk/openjtalk/mecab.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,14 @@ cdef extern from "mecab.h":
char **Mecab_get_feature(Mecab *m)
cdef int Mecab_refresh(Mecab *m)
cdef int Mecab_clear(Mecab *m)
cdef int mecab_dict_index(int argc, char **argv)

cdef extern from "mecab.h" namespace "MeCab":
cdef cppclass Tagger:
pass
cdef cppclass Lattice:
pass
cdef cppclass Model:
Tagger *createTagger()
Lattice *createLattice()
cdef Model *createModel(int argc, char **argv)
2 changes: 2 additions & 0 deletions tests/test_data/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*
!.gitignore
28 changes: 28 additions & 0 deletions tests/test_openjtalk.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pathlib import Path

import pyopenjtalk


Expand Down Expand Up @@ -80,3 +82,29 @@ def test_g2p_phone():
]:
p = pyopenjtalk.g2p(text, kana=False)
assert p == pron


def test_userdic():
for text, expected in [
("nnmn", "n a n a m i N"),
("GNU", "g u n u u"),
]:
p = pyopenjtalk.g2p(text)
assert p != expected

user_csv = str(Path(__file__).parent / "test_data" / "user.csv")
user_dic = str(Path(__file__).parent / "test_data" / "user.dic")

with open(user_csv, "w", encoding="utf-8") as f:
f.write("nnmn,,,1,名詞,一般,*,*,*,*,nnmn,ナナミン,ナナミン,1/4,*\n")
f.write("GNU,,,1,名詞,一般,*,*,*,*,GNU,グヌー,グヌー,2/3,*\n")

pyopenjtalk.mecab_dict_index(f.name, user_dic)
pyopenjtalk.update_global_jtalk_with_user_dict(user_dic)

for text, expected in [
("nnmn", "n a n a m i N"),
("GNU", "g u n u u"),
]:
p = pyopenjtalk.g2p(text)
assert p == expected

0 comments on commit 26fcdd9

Please sign in to comment.