Skip to content

Commit

Permalink
feat(tn): remove_erhua (#149)
Browse files Browse the repository at this point in the history
  • Loading branch information
xingchensong authored Nov 13, 2023
1 parent 52f2504 commit bd48adb
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 4 deletions.
4 changes: 3 additions & 1 deletion tn/chinese/normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,14 @@ def __init__(self,
cache_dir=None,
overwrite_cache=False,
remove_interjections=True,
remove_erhua=True,
traditional_to_simple=True,
remove_puncts=False,
full_to_half=True,
tag_oov=False):
super().__init__(name='normalizer')
self.remove_interjections = remove_interjections
self.remove_erhua = remove_erhua
self.traditional_to_simple = traditional_to_simple
self.remove_puncts = remove_puncts
self.full_to_half = full_to_half
Expand Down Expand Up @@ -81,7 +83,7 @@ def build_verbalizer(self):
money = Money().verbalizer
sport = Sport().verbalizer
time = Time().verbalizer
whitelist = Whitelist().verbalizer
whitelist = Whitelist(remove_erhua=self.remove_erhua).verbalizer

verbalizer = (cardinal | char | date | fraction | math | measure
| money | sport | time | whitelist).optimize()
Expand Down
9 changes: 7 additions & 2 deletions tn/chinese/rules/whitelist.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@

class Whitelist(Processor):

def __init__(self):
def __init__(self, remove_erhua=True):
super().__init__(name='whitelist')
self.remove_erhua = remove_erhua
self.build_tagger()
self.build_verbalizer()

Expand All @@ -35,5 +36,9 @@ def build_tagger(self):

def build_verbalizer(self):
super().build_verbalizer()
verbalizer = self.delete_tokens(delete('erhua: "儿"'))
if self.remove_erhua:
verbalizer = self.delete_tokens(delete('erhua: "儿"'))
else:
verbalizer = self.delete_tokens(delete('erhua: \"') +
accep('儿') + delete('\"'))
self.verbalizer |= verbalizer
6 changes: 5 additions & 1 deletion tn/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ def main():
help='rebuild *.fst')
parser.add_argument('--remove_interjections', type=str,
default='True',
help='remove interjections like "啊" and "儿"')
help='remove interjections like "啊"')
parser.add_argument('--remove_erhua', type=str,
default='True',
help='remove "儿"')
parser.add_argument('--traditional_to_simple', type=str,
default='True',
help='i.e., "喆" -> "哲"')
Expand All @@ -48,6 +51,7 @@ def main():
normalizer = Normalizer(cache_dir=args.cache_dir,
overwrite_cache=args.overwrite_cache,
remove_interjections=str2bool(args.remove_interjections),
remove_erhua=str2bool(args.remove_erhua),
traditional_to_simple=str2bool(args.traditional_to_simple),
remove_puncts=str2bool(args.remove_puncts),
full_to_half=str2bool(args.full_to_half),
Expand Down

0 comments on commit bd48adb

Please sign in to comment.