diff --git a/tn/chinese/normalizer.py b/tn/chinese/normalizer.py index 18ab507..a4b7e1b 100644 --- a/tn/chinese/normalizer.py +++ b/tn/chinese/normalizer.py @@ -36,12 +36,14 @@ def __init__(self, cache_dir=None, overwrite_cache=False, remove_interjections=True, + remove_erhua=True, traditional_to_simple=True, remove_puncts=False, full_to_half=True, tag_oov=False): super().__init__(name='normalizer') self.remove_interjections = remove_interjections + self.remove_erhua = remove_erhua self.traditional_to_simple = traditional_to_simple self.remove_puncts = remove_puncts self.full_to_half = full_to_half @@ -81,7 +83,7 @@ def build_verbalizer(self): money = Money().verbalizer sport = Sport().verbalizer time = Time().verbalizer - whitelist = Whitelist().verbalizer + whitelist = Whitelist(remove_erhua=self.remove_erhua).verbalizer verbalizer = (cardinal | char | date | fraction | math | measure | money | sport | time | whitelist).optimize() diff --git a/tn/chinese/rules/whitelist.py b/tn/chinese/rules/whitelist.py index 5db6e06..c081fd0 100644 --- a/tn/chinese/rules/whitelist.py +++ b/tn/chinese/rules/whitelist.py @@ -20,8 +20,9 @@ class Whitelist(Processor): - def __init__(self): + def __init__(self, remove_erhua=True): super().__init__(name='whitelist') + self.remove_erhua = remove_erhua self.build_tagger() self.build_verbalizer() @@ -35,5 +36,9 @@ def build_tagger(self): def build_verbalizer(self): super().build_verbalizer() - verbalizer = self.delete_tokens(delete('erhua: "儿"')) + if self.remove_erhua: + verbalizer = self.delete_tokens(delete('erhua: "儿"')) + else: + verbalizer = self.delete_tokens(delete('erhua: \"') + + accep('儿') + delete('\"')) self.verbalizer |= verbalizer diff --git a/tn/main.py b/tn/main.py index 5dccd4f..0312044 100644 --- a/tn/main.py +++ b/tn/main.py @@ -30,7 +30,10 @@ def main(): help='rebuild *.fst') parser.add_argument('--remove_interjections', type=str, default='True', - help='remove interjections like "啊" and "儿"') + help='remove interjections like "啊"') + parser.add_argument('--remove_erhua', type=str, + default='True', + help='remove "儿"') parser.add_argument('--traditional_to_simple', type=str, default='True', help='i.e., "喆" -> "哲"') @@ -48,6 +51,7 @@ def main(): normalizer = Normalizer(cache_dir=args.cache_dir, overwrite_cache=args.overwrite_cache, remove_interjections=str2bool(args.remove_interjections), + remove_erhua=str2bool(args.remove_erhua), traditional_to_simple=str2bool(args.traditional_to_simple), remove_puncts=str2bool(args.remove_puncts), full_to_half=str2bool(args.full_to_half),