Skip to content

Commit

Permalink
fix: multiple small improvements to the wizard
Browse files Browse the repository at this point in the history
 - Move the wavs question after fully processing the filelist, including
   permission.
 - Catch the case where we don't have permission to write to the output
   directory, and report it with a friendly error message.
 - Make the ValidateWavsStep question use the same type of prompt as the other
   similar questions, for a more consistent UX.
 - If we accept a wavs dir / filelist combo with missing wav files, give a
   bright error message to make it clear that this will cause problems later.
 - Don't save the config if no dataset was created.

The ideas implemented here were refined in consultation with Samuel Larkin.
  • Loading branch information
joanise committed Jun 14, 2024
1 parent cab5124 commit 752b002
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 31 deletions.
53 changes: 39 additions & 14 deletions everyvoice/tests/test_wizard.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,13 @@ def test_output_path_step(self):
self.assertFalse(step.validate(tmpdirname))
os.unlink(dataset_file)

# Bad case 3: file under read-only directory
ro_dir = Path(tmpdirname) / "read-only"
# Permission modes are octal: 0o555 is r-xr-xr-x. (A hex 0x555 would be
# 0o2525, i.e. setgid plus r-x/-w-/r-x, which is not a plain read-only mode.)
ro_dir.mkdir(mode=0o555)
with capture_stdout() as out:
    # validate() must refuse the path because the project directory
    # cannot be created underneath a read-only parent.
    self.assertFalse(step.validate(str(ro_dir)))
self.assertIn("could not create", out.getvalue())

# Good case
with capture_stdout() as stdout:
with monkeypatch(step, "prompt", Say(tmpdirname)):
Expand All @@ -295,11 +302,16 @@ def test_output_path_step(self):

def test_more_data_step(self):
"""Exercise giving an invalid response and a yes response to more data."""
tour = Tour("testing", [basic.MoreDatasetsStep()])
step = tour.steps[0]
tour = Tour(
"testing",
[dataset.FilelistStep(state_subset="dataset_0"), basic.MoreDatasetsStep()],
)

step = tour.steps[1]
self.assertFalse(step.validate("foo"))
self.assertTrue(step.validate("yes"))
self.assertEqual(len(step.children), 0)

with patch_menu_prompt(0): # answer 0 is "no"
step.run()
self.assertEqual(len(step.children), 1)
Expand All @@ -309,6 +321,15 @@ def test_more_data_step(self):
step.run()
self.assertGreater(len(step.descendants), 10)

def test_no_data_to_save(self):
    """Check that saving the config is skipped when no dataset was created.

    Answering "no" to the more-datasets question in a tour without any
    dataset must add no child steps and print a "No dataset to save" notice.
    """
    tour = Tour("testing", [basic.MoreDatasetsStep()])
    more_datasets_step = tour.steps[0]
    with patch_menu_prompt(0):  # answer 0 is "no"
        with capture_stdout() as captured:
            more_datasets_step.run()
    self.assertEqual(len(more_datasets_step.children), 0)
    self.assertIn("No dataset to save", captured.getvalue())

def test_dataset_name(self):
step = dataset.DatasetNameStep()
with monkeypatch(builtins, "input", Say(("", "bad/name", "good-name"), True)):
Expand Down Expand Up @@ -368,10 +389,6 @@ def test_sample_rate_config(self):
def test_dataset_subtour(self):
tour = Tour("unit testing", steps=dataset.get_dataset_steps())

wavs_dir_step = find_step(SN.wavs_dir_step, tour.steps)
with monkeypatch(wavs_dir_step, "prompt", Say(str(self.data_dir))):
wavs_dir_step.run()

filelist = str(self.data_dir / "unit-test-case1.psv")
filelist_step = find_step(SN.filelist_step, tour.steps)
monkey = monkeypatch(filelist_step, "prompt", Say(filelist))
Expand Down Expand Up @@ -413,10 +430,14 @@ def test_dataset_subtour(self):
step.run()
self.assertEqual(step.state["filelist_headers"][2], "text")

wavs_dir_step = find_step(SN.wavs_dir_step, tour.steps)
with monkeypatch(wavs_dir_step, "prompt", Say(str(self.data_dir))):
wavs_dir_step.run()

validate_wavs_step = find_step(SN.validate_wavs_step, tour.steps)
with monkeypatch(builtins, "input", Say("no")), capture_stdout() as out:
with patch_menu_prompt(1), capture_stdout() as out:
validate_wavs_step.run()
self.assertEqual(step.state[SN.validate_wavs_step], "no")
self.assertEqual(step.state[SN.validate_wavs_step][:2], "No")
self.assertIn("Warning: 3 wav files were not found", out.getvalue())

text_representation_step = find_step(
Expand Down Expand Up @@ -605,7 +626,9 @@ def test_prompt(self):
)
self.assertEqual(answer, 1)

def monkey_run_tour(self, name: str, steps_and_answers: list[StepAndAnswer]):
def monkey_run_tour(
self, name: str, steps_and_answers: list[StepAndAnswer]
) -> tuple[Tour, str]:
"""Create and run a tour with the monkey answers given
Args:
Expand Down Expand Up @@ -677,10 +700,9 @@ def test_monkey_tour_1(self):

def test_monkey_tour_2(self):
data_dir = Path(__file__).parent / "data"
tour, log = self.monkey_run_tour(
tour, out = self.monkey_run_tour(
"monkey tour 2",
[
StepAndAnswer(dataset.WavsDirStep(), Say(str(data_dir))),
StepAndAnswer(
dataset.FilelistStep(),
Say(str(data_dir / "metadata.psv")),
Expand All @@ -699,9 +721,10 @@ def test_monkey_tour_2(self):
Say("no"),
children_answers=[RecursiveAnswers(Say("eng"))],
),
StepAndAnswer(dataset.WavsDirStep(), Say(str(data_dir))),
StepAndAnswer(
dataset.ValidateWavsStep(),
monkeypatch(builtins, "input", Say("Yes")),
patch_menu_prompt(0), # 0 is Yes
children_answers=[
RecursiveAnswers(Say(str(data_dir / "lj/wavs"))),
RecursiveAnswers(null_patch()),
Expand All @@ -717,8 +740,10 @@ def test_monkey_tour_2(self):
],
)

self.assertIn("Warning: 5 wav files were not found", log)
self.assertIn("Great! All audio files found in directory", log)
tree = str(RenderTree(tour.root))
self.assertIn("├── Validate Wavs Step", tree)
self.assertIn("│ └── Validate Wavs Step", tree)
self.assertIn("Great! All audio files found in directory", out)

# print(tour.state)
self.assertEqual(len(tour.state["filelist_data"]), 5)
Expand Down
24 changes: 16 additions & 8 deletions everyvoice/wizard/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,20 +148,26 @@ def validate(self, response) -> bool:
return False
assert self.state is not None, "OutputPathStep requires NameStep"
output_path = path / self.state.get(StepNames.name_step, "DEFAULT_NAME")
TRY_AGAIN = "Please choose another output directory or start again and choose a different project name."
if output_path.exists():
print(
f"Sorry, '{output_path}' already exists. Please choose another output directory or start again and choose a different project name."
)
print(f"Sorry, '{output_path}' already exists.", TRY_AGAIN)
return False

# We create the output directory in validate() instead of effect() so that
# failure can be reported to the user and the question asked again if necessary.
try:
output_path.mkdir(parents=True, exist_ok=True)
except OSError as e:
print(f"Sorry, could not create '{output_path}': {e}.", TRY_AGAIN)
return False

self.output_path = output_path
return True

def effect(self):
assert self.state is not None, "OutputPathStep requires NameStep"
output_path = Path(self.response) / self.state.get(
StepNames.name_step, "DEFAULT_NAME"
print(
f"The Configuration Wizard 🧙 will put your files here: '{self.output_path}'"
)
output_path.mkdir(parents=True, exist_ok=True)
print(f"The Configuration Wizard 🧙 will put your files here: '{output_path}'")


class ConfigFormatStep(Step):
Expand Down Expand Up @@ -450,6 +456,8 @@ def effect(self):
+ [MoreDatasetsStep()],
self,
)
elif len([key for key in self.state.keys() if key.startswith("dataset_")]) == 0:
print("No dataset to save, exiting.")
else:
self.tour.add_step(
ConfigFormatStep(name=StepNames.config_format_step), self
Expand Down
28 changes: 23 additions & 5 deletions everyvoice/wizard/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from unicodedata import normalize

import questionary
import rich
from rich.panel import Panel
from rich.style import Style
from tqdm import tqdm

from everyvoice.config.type_definitions import DatasetTextRepresentation
Expand Down Expand Up @@ -295,21 +298,36 @@ def wav_file_early_validation(self) -> int:
def prompt(self):
error_count = self.wav_file_early_validation()
if error_count:
return input("Do you want to pick a different wavs directory? ").lower()
return "ok"
return get_response_from_menu_prompt(
title="Do you want to pick a different wavs directory?",
choices=(
"Yes",
"No, I will fix my audio basenames or add missing audio files later.",
),
)
else:
return "OK"

def validate(self, response):
return response in ("y", "yes", "n", "no", "ok")
return response[:3] in ("OK", "Yes", "No,")

def effect(self):
if self.response[:1] == "y":
if self.response == "Yes":
self.tour.add_steps(
[
WavsDirStep(state_subset=self.state_subset),
ValidateWavsStep(state_subset=self.state_subset),
],
self,
)
elif self.response.startswith("No"):
rich.print(
Panel(
"Continuing despite missing audio files. Make sure you fix your filelist later or add missing audio files, otherwise entries in your filelist with missing audio files will be skipped during preprocessing and therefore be ignored during training.",
title="Missing audio files",
border_style=Style(color="#EF1010"),
)
)


class FilelistTextRepresentationStep(Step):
Expand Down Expand Up @@ -711,11 +729,11 @@ def effect(self):

def get_dataset_steps(dataset_index=0):
return [
WavsDirStep(state_subset=f"dataset_{dataset_index}"),
FilelistStep(state_subset=f"dataset_{dataset_index}"),
[
DatasetPermissionStep(state_subset=f"dataset_{dataset_index}"),
FilelistFormatStep(state_subset=f"dataset_{dataset_index}"),
WavsDirStep(state_subset=f"dataset_{dataset_index}"),
ValidateWavsStep(state_subset=f"dataset_{dataset_index}"),
FilelistTextRepresentationStep(state_subset=f"dataset_{dataset_index}"),
TextProcessingStep(state_subset=f"dataset_{dataset_index}"),
Expand Down
8 changes: 4 additions & 4 deletions everyvoice/wizard/prompts.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import sys
from typing import List, Tuple, Union
from typing import Union

import simple_term_menu
from questionary import Style
Expand Down Expand Up @@ -27,18 +27,18 @@

def get_response_from_menu_prompt(
prompt_text: str = "",
choices: Tuple[str, ...] = (),
choices: tuple[str, ...] = (),
title: str = "",
multi=False,
search=False,
return_indices=False,
) -> Union[int, str, List[int], List[str]]:
) -> Union[int, str, list[int], list[str]]:
"""Given some prompt text and a list of choices, create a simple terminal window
and return the index of the choice
Args:
prompt_text (str): rich prompt text to print before menu
choices (List[str]): choices to display
choices (list[str]): choices to display
Returns:
int: index of choice
Expand Down

0 comments on commit 752b002

Please sign in to comment.