Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: remove lower and nfc_normalization from default cleaners #482

Merged
merged 3 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions everyvoice/config/text_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from everyvoice.config.shared_types import ConfigModel
from everyvoice.config.utils import PossiblySerializedCallable
from everyvoice.text.utils import normalize_text_helper
from everyvoice.utils import collapse_whitespace, lower, nfc_normalize
from everyvoice.utils import collapse_whitespace


class Punctuation(BaseModel):
Expand Down Expand Up @@ -82,9 +82,7 @@ class TextConfig(ConfigModel):
symbols: Symbols = Field(default_factory=Symbols)
to_replace: Dict[str, str] = {} # Happens before cleaners
cleaners: list[PossiblySerializedCallable] = [
lower,
collapse_whitespace,
nfc_normalize,
]

@model_validator(mode="after")
Expand Down
4 changes: 3 additions & 1 deletion everyvoice/tests/preprocessed_audio_fixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from everyvoice.model.e2e.config import FeaturePredictionConfig
from everyvoice.preprocessor import Preprocessor
from everyvoice.tests.basic_test_case import BasicTestCase
from everyvoice.utils import collapse_whitespace, lower, nfc_normalize


class PreprocessedAudioFixture:
Expand Down Expand Up @@ -34,6 +35,7 @@ class PreprocessedAudioFixture:
],
),
text=TextConfig(
cleaners=[collapse_whitespace, lower, nfc_normalize],
symbols=Symbols(
ascii_symbols=list(ascii_lowercase),
ipa=[
Expand All @@ -51,7 +53,7 @@ class PreprocessedAudioFixture:
"ʊ",
"ʒ",
],
)
),
),
contact=BasicTestCase.contact,
)
Expand Down
44 changes: 33 additions & 11 deletions everyvoice/tests/test_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@
from everyvoice.text.lookups import build_lookup, lookuptables_from_data
from everyvoice.text.phonemizer import AVAILABLE_G2P_ENGINES, get_g2p_engine
from everyvoice.text.text_processor import TextProcessor
from everyvoice.utils import generic_psv_filelist_reader
from everyvoice.utils import (
collapse_whitespace,
generic_psv_filelist_reader,
lower,
nfc_normalize,
)


class TextTest(BasicTestCase):
Expand All @@ -38,7 +43,7 @@ def test_text_to_sequence(self):
self.assertEqual(self.base_text_processor.decode_tokens(sequence, ""), text)

def test_token_sequence_to_text(self):
sequence = [25, 22, 29, 29, 32, 1, 40, 32, 35, 29, 21]
sequence = [51, 48, 55, 55, 58, 1, 66, 58, 61, 55, 47]
self.assertEqual(self.base_text_processor.encode_text("hello world"), sequence)

def test_hardcoded_symbols(self):
Expand All @@ -48,19 +53,31 @@ def test_hardcoded_symbols(self):
"pad should be Unicode PAD symbol and index 0, whitespace should be index 1",
)

def test_cleaners(self):
def test_cleaners_with_upper(self):
text = "hello world"
text_upper = "HELLO WORLD"
sequence = self.base_text_processor.encode_text(text_upper)
self.assertEqual(self.base_text_processor.decode_tokens(sequence, ""), text)
upper_text_processor = TextProcessor(
TextConfig(
cleaners=[collapse_whitespace, lower],
symbols=Symbols(letters=list(string.ascii_letters)),
),
)
sequence = upper_text_processor.encode_text(text_upper)
self.assertEqual(upper_text_processor.decode_tokens(sequence, ""), text)

def test_punctuation(self):
text = "hello! How are you? My name's: foo;."
tokens = self.base_text_processor.apply_tokenization(
self.base_text_processor.normalize_text(text)
upper_text_processor = TextProcessor(
TextConfig(
cleaners=[collapse_whitespace, lower],
symbols=Symbols(letters=list(string.ascii_letters)),
),
)
tokens = upper_text_processor.apply_tokenization(
upper_text_processor.normalize_text(text)
)
self.assertEqual(
self.base_text_processor.apply_punctuation_rules(tokens),
upper_text_processor.apply_punctuation_rules(tokens),
[
"h",
"e",
Expand Down Expand Up @@ -105,6 +122,7 @@ def test_phonological_features(self):
moh_config = FeaturePredictionConfig(
contact=self.contact,
text=TextConfig(
cleaners=[collapse_whitespace, lower, nfc_normalize],
symbols=Symbols(
letters=[
"ʌ̃̀ː",
Expand Down Expand Up @@ -153,7 +171,7 @@ def test_phonological_features(self):
"j",
"ʔ",
]
)
),
),
)
moh_text_processor = TextProcessor(moh_config.text)
Expand Down Expand Up @@ -202,10 +220,11 @@ def test_dipgrahs(self):
self.assertEqual(len(sequence), 1)

def test_normalization(self):
# This test doesn't really test very much, but just here to highlight that base cleaning involves NFC
# This test doesn't really test very much; it is just here to highlight that base cleaning doesn't involve NFC
accented_text_processor = TextProcessor(
TextConfig(
symbols=Symbols(letters=list(string.ascii_letters), accented=["é"])
cleaners=[nfc_normalize],
symbols=Symbols(letters=list(string.ascii_letters), accented=["é"]),
),
)
text = "he\u0301llo world"
Expand All @@ -215,6 +234,9 @@ def test_normalization(self):
accented_text_processor.decode_tokens(sequence, ""),
normalize("NFC", text),
)
self.assertNotEqual(
self.base_text_processor.apply_cleaners(text), normalize("NFC", text)
)

def test_missing_symbol(self):
text = "h3llo world"
Expand Down
2 changes: 2 additions & 0 deletions everyvoice/tests/test_wizard.py
Original file line number Diff line number Diff line change
Expand Up @@ -1220,6 +1220,7 @@ def test_absolute_wav_file_directory_and_local_experiment(self):
tmpdir = Path(tmpdir).absolute()
wavs_dir = tmpdir / "wavs/Common-Voice"
self.config.state["dataset_0"][SN.wavs_dir_step.value] = wavs_dir
self.config.state["dataset_0"][SN.text_processing_step] = (0,)
self.config.effect()
data_file = (
Path(self.config.state[SN.name_step.value])
Expand Down Expand Up @@ -1248,6 +1249,7 @@ def test_absolute_wav_file_directory_and_nested_experiment(self):
tmpdir = Path(tmpdir).absolute()
wavs_dir = tmpdir / "wavs/Common-Voice"
self.config.state["dataset_0"][SN.wavs_dir_step.value] = wavs_dir
self.config.state["dataset_0"][SN.text_processing_step] = tuple()
self.config.effect()
data_file = (
Path(self.config.state[SN.output_step.value])
Expand Down
6 changes: 4 additions & 2 deletions everyvoice/text/text_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,8 @@ def apply_cleaners(self, text: str) -> str:
Returns:
str: the replaced text

>>> tp = TextProcessor(TextConfig())
>>> from everyvoice.utils import collapse_whitespace, lower, nfc_normalize
>>> tp = TextProcessor(TextConfig(cleaners=[collapse_whitespace, lower, nfc_normalize]))
>>> tp.apply_cleaners('HELLO\u0301')
'helló'
"""
Expand All @@ -190,7 +191,8 @@ def normalize_text(
Returns:
str: normalized text ready to be tokenized

>>> tp = TextProcessor(TextConfig())
>>> from everyvoice.utils import collapse_whitespace, lower, nfc_normalize
>>> tp = TextProcessor(TextConfig(cleaners=[collapse_whitespace, lower, nfc_normalize]))
>>> tp.normalize_text('HELLO\u0301!')
'helló!'
"""
Expand Down
15 changes: 14 additions & 1 deletion everyvoice/wizard/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
Step,
StepNames,
)
from everyvoice.wizard.dataset import get_dataset_steps
from everyvoice.wizard.dataset import TextProcessingStep, get_dataset_steps
from everyvoice.wizard.prompts import (
CUSTOM_QUESTIONARY_STYLE,
get_response_from_menu_prompt,
Expand Down Expand Up @@ -211,8 +211,19 @@ def effect(self):
symbols = {}
multispeaker = False
multilingual = False
global_cleaners = (
[]
) # TODO: this should be fixed by https://github.com/roedoejet/EveryVoice/issues/359
for dataset in [key for key in self.state.keys() if key.startswith("dataset_")]:
dataset_state = self.state[dataset]
# Add Cleaners
# TODO: these should really be dataset-specific cleaners, not global cleaners
# so this should be fixed by https://github.com/roedoejet/EveryVoice/issues/359
if dataset_state.get(StepNames.text_processing_step):
global_cleaners += [
TextProcessingStep().process_lookup[x]["fn"]
for x in dataset_state[StepNames.text_processing_step]
]
# Gather Symbols for Text Configuration
# rename keys based on dataset name:
dataset_name = dataset_state[StepNames.dataset_name_step]
Expand Down Expand Up @@ -261,7 +272,9 @@ def effect(self):
permissions_obtained=True, # If you get this far, you've answered the Dataset Permission Attestation step correctly
)
)

text_config = TextConfig(symbols=Symbols(**symbols))
text_config.cleaners += global_cleaners
text_config_path = Path(f"{TEXT_CONFIG_FILENAME_PREFIX}.{self.response}")
write_dict_to_config(
json.loads(text_config.model_dump_json(exclude_none=False)),
Expand Down
21 changes: 13 additions & 8 deletions everyvoice/wizard/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import re
from pathlib import Path
from typing import Sequence
from unicodedata import normalize

import questionary
import rich
Expand All @@ -14,7 +13,13 @@

from everyvoice.config.type_definitions import DatasetTextRepresentation
from everyvoice.text.utils import guess_graphemes_in_text, guess_ipa_phones_in_text
from everyvoice.utils import generic_xsv_filelist_reader, read_festival, slugify
from everyvoice.utils import (
generic_xsv_filelist_reader,
lower,
nfc_normalize,
read_festival,
slugify,
)
from everyvoice.wizard import TEXT_CONFIG_FILENAME_PREFIX, Step, StepNames, Tour
from everyvoice.wizard.prompts import (
CUSTOM_QUESTIONARY_STYLE,
Expand Down Expand Up @@ -626,6 +631,10 @@ def get_iso_code(language):

class TextProcessingStep(Step):
DEFAULT_NAME = StepNames.text_processing_step
process_lookup = {
0: {"fn": lower, "desc": "lowercase"},
1: {"fn": nfc_normalize, "desc": "NFC Normalization"},
}

def prompt(self):
return get_response_from_menu_prompt(
Expand All @@ -644,21 +653,17 @@ def validate(self, response):

def effect(self):
# Apply the selected text processing processes
process_lookup = {
0: {"fn": lambda x: x.lower(), "desc": "lowercase"},
1: {"fn": lambda x: normalize("NFC", x), "desc": "NFC Normalization"},
}
if "symbols" not in self.state:
self.state["symbols"] = {}
if self.response:
text_index = self.state["filelist_headers"].index(
self.state[StepNames.filelist_text_representation_step]
)
for process in self.response:
process_fn = process_lookup[process]["fn"]
process_fn = self.process_lookup[process]["fn"]
for i in tqdm(
range(len(self.state["filelist_data_list"])),
desc=f"Applying {process_lookup[process]['desc']} to data",
desc=f"Applying {self.process_lookup[process]['desc']} to data",
):
self.state["filelist_data_list"][i][text_index] = process_fn(
self.state["filelist_data_list"][i][text_index]
Expand Down