Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: change shape of filelist list data instead of re-reading it #515

Merged
merged 7 commits into from
Jul 30, 2024
Merged
23 changes: 19 additions & 4 deletions everyvoice/config/text_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from everyvoice.config.shared_types import ConfigModel
from everyvoice.config.utils import PossiblySerializedCallable
from everyvoice.text.utils import normalize_text_helper
from everyvoice.utils import collapse_whitespace
from everyvoice.utils import collapse_whitespace, strip_text


class Punctuation(BaseModel):
Expand Down Expand Up @@ -62,6 +62,23 @@ def all_except_punctuation(self) -> set[str]:
"""Returns the set containing all characters."""
return set(w for _, v in self if not isinstance(v, Punctuation) for w in v)

@model_validator(mode="after")
def cannot_have_punctuation_in_symbol_set(self) -> "Symbols":
    """You cannot have the same symbol defined in punctuation as elsewhere.

    Raises:
        ValueError: raised if a symbol from punctuation is found elsewhere

    Returns:
        Symbols: The validated symbol set
    """
    # Hoist the lookup set out of the loop: all_except_punctuation is
    # accessed without parentheses, so it appears to be a computed
    # property that rebuilds its set on every access — TODO confirm.
    # Computing it once avoids O(len(punctuation) * len(symbols)) work.
    non_punctuation = self.all_except_punctuation
    for punctuation in self.punctuation.all:
        if punctuation in non_punctuation:
            raise ValueError(
                f"Sorry, the symbol '{punctuation}' occurs in both your declared punctuation and in your other symbol set. Please inspect your text configuration and either remove the symbol from the punctuation or other symbol set."
            )
    return self

@model_validator(mode="after")
def member_must_be_list_of_strings(self) -> "Symbols":
"""Except for `punctuation` & `pad`, all user defined member variables
Expand All @@ -81,9 +98,7 @@ def member_must_be_list_of_strings(self) -> "Symbols":
class TextConfig(ConfigModel):
symbols: Symbols = Field(default_factory=Symbols)
to_replace: Dict[str, str] = {} # Happens before cleaners
cleaners: list[PossiblySerializedCallable] = [
collapse_whitespace,
]
cleaners: list[PossiblySerializedCallable] = [collapse_whitespace, strip_text]

@model_validator(mode="after")
def clean_symbols(self) -> "TextConfig":
Expand Down
12 changes: 6 additions & 6 deletions everyvoice/model/e2e/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@

# The contact information only needs to be registered on the main config
class AlignerConfigNoContact(AlignerConfig):
contact: Optional[ContactInformation] = None
contact: Optional[ContactInformation] = None # type: ignore


class VocoderConfigNoContact(VocoderConfig):
contact: Optional[ContactInformation] = None
contact: Optional[ContactInformation] = None # type: ignore


class FeaturePredictionConfigNoContact(FeaturePredictionConfig):
contact: Optional[ContactInformation] = None
contact: Optional[ContactInformation] = None # type: ignore


class E2ETrainingConfig(BaseTrainingConfig):
Expand All @@ -36,17 +36,17 @@ class E2ETrainingConfig(BaseTrainingConfig):

class EveryVoiceConfig(BaseModelWithContact):
aligner: AlignerConfig | AlignerConfigNoContact = Field(
default_factory=AlignerConfigNoContact
default_factory=AlignerConfigNoContact # type: ignore
)
path_to_aligner_config_file: Optional[FilePath] = None

feature_prediction: FeaturePredictionConfig | FeaturePredictionConfigNoContact = (
Field(default_factory=FeaturePredictionConfigNoContact)
Field(default_factory=FeaturePredictionConfigNoContact) # type: ignore
)
path_to_feature_prediction_config_file: Optional[FilePath] = None

vocoder: VocoderConfig | VocoderConfigNoContact = Field(
default_factory=VocoderConfigNoContact
default_factory=VocoderConfigNoContact # type: ignore
)
path_to_vocoder_config_file: Optional[FilePath] = None

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
cleaners: [everyvoice.utils.lower, everyvoice.utils.collapse_whitespace, everyvoice.utils.nfc_normalize]
symbols:
dataset_0-symbols: [' ', '''', ',', '-', ., C, E, H, K, P, T, a, b, c, d, e, f,
dataset_0-symbols: [' ', C, E, H, K, P, T, a, b, c, d, e, f,
g, h, i, l, m, n, o, p, r, s, t, u, v, w, x, y]
pad: _
silence: [<SIL>]
Expand Down
1 change: 1 addition & 0 deletions everyvoice/tests/data/unit-test-case1.psv
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
columnA|non-standard-basename|non-standard-text|extra
blah|somefile|characters|irrelevant extra
boom|file2|CaSeD NFD: éàê NFC: éàê|blah
floop|file3| let us see if it collapses whitespace|blah
bam|banned_file|ZZZ|has banned symbol (Z)
8 changes: 6 additions & 2 deletions everyvoice/tests/test_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def test_token_sequence_to_text(self):

def test_hardcoded_symbols(self):
self.assertEqual(
self.base_text_processor.encode_text("\x80 "),
[0, 1],
self.base_text_processor.encode_text("\x80 \x80"),
[0, 1, 0],
"pad should be Unicode PAD symbol and index 0, whitespace should be index 1",
)

Expand All @@ -65,6 +65,10 @@ def test_cleaners_with_upper(self):
sequence = upper_text_processor.encode_text(text_upper)
self.assertEqual(upper_text_processor.decode_tokens(sequence, ""), text)

def test_no_duplicate_punctuation(self):
    """A symbol declared both as a letter and as punctuation must fail validation."""
    # ":" is in the default punctuation set, so declaring it as a letter
    # should trigger the duplicate-symbol validator on Symbols.
    letters_with_colon = [":"] + list(string.ascii_letters)
    with self.assertRaises(ValidationError):
        TextConfig(symbols=Symbols(letters=letters_with_colon))

def test_punctuation(self):
text = "hello! How are you? My name's: foo;."
upper_text_processor = TextProcessor(
Expand Down
104 changes: 97 additions & 7 deletions everyvoice/tests/test_wizard.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,80 @@ def test_sample_rate_config(self):
self.assertTrue(step.completed)
self.assertEqual(step.response, 512)

def test_whitespace_always_collapsed(self):
"""Whitespace in filelist text must come out collapsed even when the user
selects NO text-processing options in the wizard — i.e. whitespace
collapsing is applied unconditionally, not as an opt-in cleaner.
"""
# Build a fresh dataset sub-tour and feed it the unit-test filelist.
tour = Tour("unit testing", steps=dataset.get_dataset_steps())
filelist = str(self.data_dir / "unit-test-case1.psv")
filelist_step = find_step(SN.filelist_step, tour.steps)
monkey = monkeypatch(filelist_step, "prompt", Say(filelist))
with monkey:
filelist_step.run()

# Confirm dataset usage permission (menu option 1 = "Yes ...").
permission_step = find_step(SN.dataset_permission_step, tour.steps)
with patch_menu_prompt(1): # 1 is "yes, I have permission"
permission_step.run()
self.assertTrue(
permission_step.state[SN.dataset_permission_step].startswith("Yes")
)

# Choose the psv format, then answer its child prompts:
# header present, then which columns hold basename and text.
format_step = find_step(SN.filelist_format_step, tour.steps)
with patch_menu_prompt(0): # 0 is "psv"
format_step.run()
step = format_step.children[0]
with patch_menu_prompt(1): # 1 is "yes"
step.run()

step = format_step.children[1]
with patch_menu_prompt(1): # 1 is second column
step.run()

step = format_step.children[2]
with patch_menu_prompt(1): # 1 is second remaining column, i.e., third column
step.run()

# Text is plain characters (not phones).
text_representation_step = find_step(
SN.filelist_text_representation_step, tour.steps
)
with patch_menu_prompt(0): # 0 is "characters"
text_representation_step.run()
# No speaker column in the data, but the user knows the speaker name.
speaker_step = find_step(SN.data_has_speaker_value_step, tour.steps)
with patch_menu_prompt(0): # 0 is "no"
speaker_step.run()

know_speaker_step = speaker_step.children[0]
with patch_menu_prompt(1): # 1 is "yes"
know_speaker_step.run()

add_speaker_step = know_speaker_step.children[0]
with patch_input("default"):
add_speaker_step.run()

# No language column either; pick a language from the child menu.
language_step = find_step(SN.data_has_language_value_step, tour.steps)
with patch_menu_prompt(0): # 0 is "no"
language_step.run()

select_lang_step = language_step.children[0]
with capture_stdout(), capture_stderr():
with patch_menu_prompt(15): # some arbitrary language from the list
select_lang_step.run()

# Point the wizard at the test data dir for wav files and run validation
# (the wavs referenced by the filelist are not expected to exist here).
wavs_dir_step = find_step(SN.wavs_dir_step, tour.steps)
with monkeypatch(wavs_dir_step, "prompt", Say(str(self.data_dir))):
wavs_dir_step.run()

validate_wavs_step = find_step(SN.validate_wavs_step, tour.steps)
with patch_menu_prompt(1), capture_stdout():
validate_wavs_step.run()

# The key assertion: select NO processing options (empty selection) and
# verify the text column of row 3 still had its whitespace collapsed.
# NOTE(review): tqdm is monkeypatched to a pass-through so the step
# runs silently without a progress bar.
text_processing_step = find_step(SN.text_processing_step, tour.steps)
# 0 is lowercase, 1 is NFC Normalization, select none
with monkeypatch(dataset, "tqdm", lambda seq, desc: seq):
with patch_menu_prompt([]):
text_processing_step.run()
self.assertEqual(
text_processing_step.state["filelist_data_list"][3][2],
"let us see if it collapses whitespace",
)

def test_dataset_subtour(self):
tour = Tour("unit testing", steps=dataset.get_dataset_steps())

Expand Down Expand Up @@ -435,7 +509,7 @@ def test_dataset_subtour(self):
with patch_menu_prompt(1): # 1 is "yes"
step.run()
self.assertEqual(step.state[SN.data_has_header_line_step.value], "yes")
self.assertEqual(len(step.state["filelist_data_list"]), 4)
self.assertEqual(len(step.state["filelist_data_list"]), 5)

step = format_step.children[1]
self.assertIsInstance(step, dataset.HeaderStep)
Expand Down Expand Up @@ -503,7 +577,7 @@ def test_dataset_subtour(self):
with patch_menu_prompt(1), capture_stdout() as out:
validate_wavs_step.run()
self.assertEqual(step.state[SN.validate_wavs_step][:2], "No")
self.assertIn("Warning: 3 wav files were not found", out.getvalue())
self.assertIn("Warning: 4 wav files were not found", out.getvalue())

text_processing_step = find_step(SN.text_processing_step, tour.steps)
# 0 is lowercase, 1 is NFC Normalization, select both
Expand All @@ -513,11 +587,16 @@ def test_dataset_subtour(self):
# print(text_processing_step.state)
self.assertEqual(
text_processing_step.state["filelist_data_list"][2][2],
"cased \t nfd: éàê nfc: éàê", # the "nfd: éàê" bit here is now NFC
"cased nfd: éàê nfc: éàê", # the "nfd: éàê" bit here is now NFC
)

self.assertEqual(
text_processing_step.state["filelist_data_list"][3][2],
"let us see if it collapses whitespace",
)

# Make sure reloading the data as a dict stripped the header line
self.assertEqual(len(step.state["filelist_data"]), 3)
self.assertEqual(len(step.state["filelist_data"]), 4)

sox_effects_step = find_step(SN.sox_effects_step, tour.steps)
# 0 is resample to 22050 kHz, 2 is remove silence at start and end
Expand Down Expand Up @@ -547,11 +626,24 @@ def test_dataset_subtour(self):
)

symbol_set_step = find_step(SN.symbol_set_step, tour.steps)
self.assertEqual(len(symbol_set_step.state["filelist_data"]), 3)
self.assertEqual(len(symbol_set_step.state["filelist_data"]), 4)
with capture_stdout(), capture_stderr():
symbol_set_step.run()
self.assertEqual(len(symbol_set_step.state[SN.symbol_set_step.value]), 2)
self.assertIn("t͡s", symbol_set_step.state[SN.symbol_set_step.value]["phones"])
self.assertNotIn(
":", symbol_set_step.state[SN.symbol_set_step.value]["characters"]
)
self.assertNotIn(":", symbol_set_step.state[SN.symbol_set_step.value]["phones"])
# assert that symbols contain no duplicates
self.assertEqual(
len(set(symbol_set_step.state[SN.symbol_set_step.value]["characters"])),
len(symbol_set_step.state[SN.symbol_set_step.value]["characters"]),
)
self.assertEqual(
len(set(symbol_set_step.state[SN.symbol_set_step.value]["phones"])),
len(symbol_set_step.state[SN.symbol_set_step.value]["phones"]),
)

def test_empty_filelist(self):
tour = Tour(
Expand Down Expand Up @@ -1599,8 +1691,6 @@ def setUp(self):
SN.symbol_set_step.value: {
"characters": [
" ",
",",
".",
"A",
"D",
"E",
Expand Down
10 changes: 9 additions & 1 deletion everyvoice/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,14 +356,22 @@ def generic_dict_loader(
generic_csv_filelist_reader = partial(generic_dict_loader, delimiter=",")


def collapse_whitespace(text):
def collapse_whitespace(text: str) -> str:
    """Replace each run of whitespace matched by the module-level
    ``_whitespace_re`` pattern with a single space.

    Leading/trailing whitespace is kept (as a single space), not stripped —
    see ``strip_text`` for stripping.

    >>> collapse_whitespace(" asdf qwer ")
    ' asdf qwer '
    """
    return re.sub(_whitespace_re, " ", text)


def strip_text(text: str):
    """Remove leading and trailing whitespace from ``text``.

    Interior whitespace is left untouched.

    >>> strip_text(" asdf qwer ")
    'asdf qwer'
    """
    stripped = text.strip()
    return stripped


@contextmanager
def tqdm_joblib_context(tqdm_instance):
"""Context manager to make tqdm compatible with joblib.Parallel
Expand Down
Loading
Loading