Skip to content

Commit

Permalink
refactor: move symbol sorter to utils
Browse files Browse the repository at this point in the history
  • Loading branch information
roedoejet committed Feb 12, 2025
1 parent 577ba61 commit 78dc32c
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 10 deletions.
14 changes: 4 additions & 10 deletions everyvoice/text/text_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
apply_cleaners_helper,
apply_to_replace_helper,
normalize_text_helper,
symbol_sorter,
)

PAD_SYMBOL = "\x80"
Expand Down Expand Up @@ -109,16 +110,9 @@ def __init__(self, config: TextConfig, punctuation_hash=DEFAULT_PUNCTUATION_HASH
# TODO: do I need to clean the symbols? How to do this if datasets have
# their own cleaners?
_hardcoded_internal_symbols = [self._pad_symbol, " "]
self.symbols = _hardcoded_internal_symbols + list(
sorted(
# Remove duplicates from symbol list, and apply longest
# characters first to apply multigraph symbols first
symbols - set(_hardcoded_internal_symbols),
key=lambda symbol: (
-len(symbol),
symbol,
), # reverse-length sort, then sort alphabetically
)
self.symbols = symbol_sorter(
list(symbols - set(_hardcoded_internal_symbols)),
hardcoded_initial_symbols=_hardcoded_internal_symbols,
)
self.to_replace = config.to_replace
self.missing_symbols: Counter[str] = Counter()
Expand Down
24 changes: 24 additions & 0 deletions everyvoice/text/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
from typing import Iterable, Optional

import grapheme
from ipatok import tokenise
Expand All @@ -7,6 +8,29 @@
from everyvoice.exceptions import ConfigError


def get_symbols_from_checkpoint_symbol_dict(symbols: dict) -> list[str]:
    """Flatten a nested symbol dictionary into a single list of symbols.

    The value stored under the ``"punctuation"`` key is treated as a mapping
    whose values are each a collection of symbols; every other key is treated
    as mapping directly to a collection of symbols. A missing
    ``"punctuation"`` key is treated as an empty mapping.

    Args:
        symbols: The symbol dictionary to flatten.
            NOTE(review): presumably the symbol section of a saved model
            checkpoint's text configuration -- confirm against callers.

    Returns:
        All punctuation symbols first (in the mapping's iteration order),
        followed by the symbols of every remaining key in the dictionary's
        iteration order. Duplicates are kept.
    """
    flattened: list[str] = []
    # Punctuation is nested one level deeper than the other categories, so it
    # is unpacked separately and placed at the front of the result.
    for punctuation_group in symbols.get("punctuation", {}).values():
        flattened.extend(punctuation_group)
    for category, members in symbols.items():
        if category == "punctuation":
            continue
        flattened.extend(members)
    return flattened


def symbol_sorter(
    symbols_for_sorting: Iterable[str],
    hardcoded_initial_symbols: Optional[Iterable[str]] = None,
    hardcoded_final_symbols: Optional[Iterable[str]] = None,
) -> list[str]:
    """Return symbols ordered longest-first, framed by fixed-position symbols.

    The symbols in ``symbols_for_sorting`` are sorted by descending length and
    then in ascending string order among symbols of equal length, so that
    multi-character symbols come before the shorter symbols they may contain.

    Args:
        symbols_for_sorting: The symbols to sort. Any iterable of strings is
            accepted (list, set, tuple, ...). Duplicates are NOT removed; pass
            a set if uniqueness is required.
        hardcoded_initial_symbols: Symbols placed, in the given order and
            without sorting, at the start of the result. ``None`` means none.
        hardcoded_final_symbols: Symbols placed, in the given order and
            without sorting, at the end of the result. ``None`` means none.

    Returns:
        A new list: initial symbols, then the sorted symbols, then the final
        symbols. None of the arguments is modified.
    """
    # Materialise the optional parts as lists so that any iterable (not only
    # a list) can be supplied without breaking concatenation.
    initial = (
        [] if hardcoded_initial_symbols is None else list(hardcoded_initial_symbols)
    )
    final = [] if hardcoded_final_symbols is None else list(hardcoded_final_symbols)
    # Reverse-length sort first, then plain string ordering as a tie-breaker.
    ordered = sorted(symbols_for_sorting, key=lambda symbol: (-len(symbol), symbol))
    return initial + ordered + final


def normalize_text_helper(
text: str,
to_replace: dict[str, str],
Expand Down

0 comments on commit 78dc32c

Please sign in to comment.