diff --git a/everyvoice/cli.py b/everyvoice/cli.py index cd04ecfc..2c2e0c74 100644 --- a/everyvoice/cli.py +++ b/everyvoice/cli.py @@ -121,9 +121,9 @@ class ModelTypes(str, Enum): """, ) def new_project(): - from everyvoice.wizard.main_tour import WIZARD_TOUR + from everyvoice.wizard.main_tour import get_main_wizard_tour - WIZARD_TOUR.run() + get_main_wizard_tour().run() # Add preprocess to root diff --git a/everyvoice/tests/test_wizard.py b/everyvoice/tests/test_wizard.py index 09a9d950..42f87b63 100644 --- a/everyvoice/tests/test_wizard.py +++ b/everyvoice/tests/test_wizard.py @@ -29,7 +29,7 @@ from everyvoice.wizard import StepNames as SN from everyvoice.wizard import Tour, basic, dataset, prompts from everyvoice.wizard.basic import ConfigFormatStep -from everyvoice.wizard.main_tour import WIZARD_TOUR +from everyvoice.wizard.main_tour import get_main_wizard_tour from everyvoice.wizard.utils import EnumDict CONTACT_INFO_STATE = State() @@ -71,6 +71,19 @@ def monkey(self): return self.answer_or_monkey +def find_step(name: Enum, steps: Sequence[Step | list[Step]]): + """Find a step with the given name in steps, of potentially variable depth""" + for s in steps: + if isinstance(s, list): + try: + return find_step(name, s) + except IndexError: + pass + elif s.name == name.value: + return s + raise IndexError(f"Step {name} not found.") # pragma: no cover + + class WizardTest(TestCase): """Basic test for the configuration wizard""" @@ -146,7 +159,7 @@ def validate(self, x): leaf[2].run() def test_main_tour(self): - tour = WIZARD_TOUR + tour = get_main_wizard_tour() self.assertGreater(len(tour.steps), 6) # TODO try to figure out how to actually run the tour in unit testing or # at least add more interesting assertions that just the fact that it's @@ -154,11 +167,12 @@ def test_main_tour(self): # self.monkey_run_tour() with a bunch of recursive answer would the thing to use here... def test_visualize(self): + tour = get_main_wizard_tour() with capture_stdout() as out: - WIZARD_TOUR.visualize() + tour.visualize() log = out.getvalue() - self.assertIn("└── Contact Name Step", log) - self.assertIn("└── Validate Wavs Step", log) + self.assertIn("── Contact Name Step", log) + self.assertIn("── Validate Wavs Step", log) def test_name_step(self): """Exercise providing a valid dataset name.""" @@ -217,6 +231,26 @@ def test_bad_contact_email_step(self): self.assertIn("There must be something after the @-sign", output) self.assertIn("An email address cannot end with a period", output) + def test_no_permissions(self): + """Exercise lacking permissions, then trying again""" + tour = get_main_wizard_tour() + permission_step = find_step(SN.dataset_permission_step, tour.steps) + self.assertGreater(len(permission_step.children), 8) + self.assertGreater(len(tour.root.descendants), 14) + self.assertIn("dataset_0", tour.state) + with patch_menu_prompt(0): # 0 is no, I don't have permission + permission_step.run() + self.assertEqual(permission_step.children, ()) + self.assertLess(len(tour.root.descendants), 10) + self.assertNotIn("dataset_0", tour.state) + + more_dataset_step = find_step(SN.more_datasets_step, tour.steps) + with patch_menu_prompt(1): # 1 is Yes, I have more data + more_dataset_step.run() + self.assertIn("dataset_1", tour.state) + self.assertGreater(len(more_dataset_step.descendants), 8) + self.assertGreater(len(tour.root.descendants), 14) + def test_output_path_step(self): """Exercise the OutputPathStep""" tour = Tour( @@ -250,22 +284,31 @@ def test_output_path_step(self): self.assertFalse(step.validate(tmpdirname)) os.unlink(dataset_file) + # Bad case 3: file under read-only directory + ro_dir = Path(tmpdirname) / "read-only" + ro_dir.mkdir(mode=0x555) + with capture_stdout() as out: + self.assertFalse(step.validate(str(ro_dir))) + self.assertIn("could not create", out.getvalue()) + # Good case with capture_stdout() as stdout: with monkeypatch(step, "prompt", Say(tmpdirname)): step.run() self.assertIn("will put your files", stdout.getvalue()) - output_dir = Path(tmpdirname) / "myname" - self.assertTrue(output_dir.exists()) - self.assertTrue(output_dir.is_dir()) def test_more_data_step(self): """Exercise giving an invalid response and a yes response to more data.""" - tour = Tour("testing", [basic.MoreDatasetsStep()]) - step = tour.steps[0] + tour = Tour( + "testing", + [dataset.FilelistStep(state_subset="dataset_0"), basic.MoreDatasetsStep()], + ) + + step = tour.steps[1] self.assertFalse(step.validate("foo")) self.assertTrue(step.validate("yes")) self.assertEqual(len(step.children), 0) + with patch_menu_prompt(0): # answer 0 is "no" step.run() self.assertEqual(len(step.children), 1) @@ -273,7 +316,16 @@ def test_more_data_step(self): with patch_menu_prompt(1): # answer 1 is "yes" step.run() - self.assertGreater(len(step.children), 5) + self.assertGreater(len(step.descendants), 10) + + def test_no_data_to_save(self): + """When the tour created no datasets at all, saving the config is skipped.""" + tour = Tour("testing", [basic.MoreDatasetsStep()]) + step = tour.steps[0] + with patch_menu_prompt(0), capture_stdout() as out: # answer 0 is "no" + step.run() + self.assertEqual(len(step.children), 0) + self.assertIn("No dataset to save", out.getvalue()) def test_dataset_name(self): step = dataset.DatasetNameStep() @@ -332,56 +384,57 @@ def test_sample_rate_config(self): self.assertEqual(step.response, 512) def test_dataset_subtour(self): - def find_step(name: Enum, steps: Sequence[Step]): - for s in steps: - if s.name == name.value: - return s - raise IndexError(f"Step {name} not found.") # pragma: no cover - tour = Tour("unit testing", steps=dataset.get_dataset_steps()) - wavs_dir_step = find_step(SN.wavs_dir_step, tour.steps) - with monkeypatch(wavs_dir_step, "prompt", Say(str(self.data_dir))): - wavs_dir_step.run() - filelist = str(self.data_dir / "unit-test-case1.psv") filelist_step = find_step(SN.filelist_step, tour.steps) monkey = monkeypatch(filelist_step, "prompt", Say(filelist)) with monkey: filelist_step.run() + + permission_step = find_step(SN.dataset_permission_step, tour.steps) + with patch_menu_prompt(1): # 1 is "yes, I have permission" + permission_step.run() + self.assertTrue( + permission_step.state[SN.dataset_permission_step].startswith("Yes") + ) + format_step = find_step(SN.filelist_format_step, tour.steps) with patch_menu_prompt(0): # 0 is "psv" format_step.run() - self.assertIsInstance(format_step.children[0], dataset.HasHeaderLineStep) - self.assertEqual( - format_step.children[0].name, SN.data_has_header_line_step.value - ) - self.assertIsInstance(format_step.children[1], dataset.HeaderStep) - self.assertEqual(format_step.children[1].name, SN.basename_header_step.value) - self.assertIsInstance(format_step.children[2], dataset.HeaderStep) - self.assertEqual(format_step.children[2].name, SN.text_header_step.value) + self.assertEqual(len(format_step.children), 3) step = format_step.children[0] + self.assertIsInstance(step, dataset.HasHeaderLineStep) + self.assertEqual(step.name, SN.data_has_header_line_step.value) with patch_menu_prompt(1): # 1 is "yes" step.run() self.assertEqual(step.state[SN.data_has_header_line_step.value], "yes") self.assertEqual(len(step.state["filelist_data_list"]), 4) step = format_step.children[1] + self.assertIsInstance(step, dataset.HeaderStep) + self.assertEqual(step.name, SN.basename_header_step.value) with patch_menu_prompt(1): # 1 is second column step.run() self.assertEqual(step.response, 1) self.assertEqual(step.state["filelist_headers"][1], "basename") - step = tour.steps[2].children[2] + step = format_step.children[2] + self.assertIsInstance(step, dataset.HeaderStep) + self.assertEqual(step.name, SN.text_header_step.value) with patch_menu_prompt(1): # 1 is second remaining column, i.e., third column step.run() self.assertEqual(step.state["filelist_headers"][2], "text") + wavs_dir_step = find_step(SN.wavs_dir_step, tour.steps) + with monkeypatch(wavs_dir_step, "prompt", Say(str(self.data_dir))): + wavs_dir_step.run() + validate_wavs_step = find_step(SN.validate_wavs_step, tour.steps) - with monkeypatch(builtins, "input", Say("no")), capture_stdout() as out: + with patch_menu_prompt(1), capture_stdout() as out: validate_wavs_step.run() - self.assertEqual(step.state[SN.validate_wavs_step], "no") + self.assertEqual(step.state[SN.validate_wavs_step][:2], "No") self.assertIn("Warning: 3 wav files were not found", out.getvalue()) text_representation_step = find_step( @@ -570,7 +623,9 @@ def test_prompt(self): ) self.assertEqual(answer, 1) - def monkey_run_tour(self, name: str, steps_and_answers: list[StepAndAnswer]): + def monkey_run_tour( + self, name: str, steps_and_answers: list[StepAndAnswer] + ) -> tuple[Tour, str]: """Create and run a tour with the monkey answers given Args: @@ -642,10 +697,9 @@ def test_monkey_tour_1(self): def test_monkey_tour_2(self): data_dir = Path(__file__).parent / "data" - tour, log = self.monkey_run_tour( + tour, out = self.monkey_run_tour( "monkey tour 2", [ - StepAndAnswer(dataset.WavsDirStep(), Say(str(data_dir))), StepAndAnswer( dataset.FilelistStep(), Say(str(data_dir / "metadata.psv")), @@ -664,9 +718,10 @@ def test_monkey_tour_2(self): Say("no"), children_answers=[RecursiveAnswers(Say("eng"))], ), + StepAndAnswer(dataset.WavsDirStep(), Say(str(data_dir))), StepAndAnswer( dataset.ValidateWavsStep(), - monkeypatch(builtins, "input", Say("Yes")), + patch_menu_prompt(0), # 0 is Yes children_answers=[ RecursiveAnswers(Say(str(data_dir / "lj/wavs"))), RecursiveAnswers(null_patch()), @@ -682,8 +737,10 @@ def test_monkey_tour_2(self): ], ) - self.assertIn("Warning: 5 wav files were not found", log) - self.assertIn("Great! All audio files found in directory", log) + tree = str(RenderTree(tour.root)) + self.assertIn("├── Validate Wavs Step", tree) + self.assertIn("│ └── Validate Wavs Step", tree) + self.assertIn("Great! All audio files found in directory", out) # print(tour.state) self.assertEqual(len(tour.state["filelist_data"]), 5) diff --git a/everyvoice/wizard/__init__.py b/everyvoice/wizard/__init__.py index fd91f6ec..3c13780c 100644 --- a/everyvoice/wizard/__init__.py +++ b/everyvoice/wizard/__init__.py @@ -1,6 +1,6 @@ import sys from enum import Enum -from typing import Optional +from typing import Optional, Sequence from anytree import NodeMixin, RenderTree @@ -117,19 +117,25 @@ def run(self): self.run() +class RootStep(Step): + """Dummy step sitting at the root of the tour""" + + DEFAULT_NAME = "Root" + + def run(self): + pass + + class Tour: def __init__(self, name: str, steps: list[Step], state: Optional[State] = None): - """Create the tour by setting each Step as the child of the previous Step.""" + """Create the tour by placing all steps under a dummy root node""" self.name = name self.state: State = state if state is not None else State() - for parent, child in zip(steps, steps[1:]): - child.parent = parent - self.determine_state(child, self.state) - child.tour = self self.steps = steps - self.root = steps[0] + self.root = RootStep() self.root.tour = self self.determine_state(self.root, self.state) + self.add_steps(steps, self.root) def determine_state(self, step: Step, state: State): if step.state_subset is not None: @@ -139,6 +145,20 @@ def determine_state(self, step: Step, state: State): else: step.state = state + def add_steps(self, steps: Sequence[Step | list[Step]], parent: Step): + """Insert steps in front of the other children of parent. + + Steps are added as direct children. + For sublists of steps, the first is a direct child, the rest are under it. + """ + for item in reversed(steps): + if isinstance(item, list): + step, *children = item + self.add_step(step, parent) + self.add_steps(children, step) + else: + self.add_step(item, parent) + def add_step(self, step: Step, parent: Step, child_index=0): self.determine_state(step, self.state) step.tour = self @@ -161,6 +181,7 @@ class StepNames(Enum): contact_name_step = "Contact Name Step" contact_email_step = "Contact Email Step" dataset_name_step = "Dataset Name Step" + dataset_permission_step = "Dataset Permission Step" output_step = "Output Path Step" wavs_dir_step = "Wavs Dir Step" filelist_step = "Filelist Step" diff --git a/everyvoice/wizard/basic.py b/everyvoice/wizard/basic.py index 167e3bb9..9607bf48 100644 --- a/everyvoice/wizard/basic.py +++ b/everyvoice/wizard/basic.py @@ -91,12 +91,24 @@ class ContactEmailStep(Step): def prompt(self): return input("Please provide a contact email address for your models. ") + def in_unit_testing(self): + """Skip checking deliverability when in unit testing. + + Checking deliverability can be slow where there is not web connection, as + is sometimes the case when running the unit tests, so skip it in that context. + """ + import inspect + + return any( + frame.filename.endswith("test_wizard.py") for frame in inspect.stack() + ) + def validate(self, response): try: # Check that the email address is valid. Turn on check_deliverability # for first-time validations like on account creation pages (but not # login pages). - validate_email(response, check_deliverability=True) + validate_email(response, check_deliverability=not self.in_unit_testing()) except EmailNotValidError as e: # The exception message is a human-readable explanation of why it's # not a valid (or deliverable) email address. @@ -138,18 +150,32 @@ def validate(self, response) -> bool: output_path = path / self.state.get(StepNames.name_step, "DEFAULT_NAME") if output_path.exists(): print( - f"Sorry, '{output_path}' already exists. Please choose another output directory or start again and choose a different project name." + f"Sorry, '{output_path}' already exists. " + "Please choose another output directory or start again and choose a different project name." + ) + return False + + # We create the output directory in validate() instead of effect() so that + # failure can be reported to the user and the question asked again if necessary. + try: + output_path.mkdir(parents=True, exist_ok=True) + # we created it just to test permission, but don't leave it lying around in + # case the wizard is interrupted or fails. We'll create it again when we save. + output_path.rmdir() + except OSError as e: + print( + f"Sorry, could not create '{output_path}': {e}. " + "Please choose another output directory." ) return False + + self.output_path = output_path return True def effect(self): - assert self.state is not None, "OutputPathStep requires NameStep" - output_path = Path(self.response) / self.state.get( - StepNames.name_step, "DEFAULT_NAME" + print( + f"The Configuration Wizard 🧙 will put your files here: '{self.output_path}'" ) - output_path.mkdir(parents=True, exist_ok=True) - print(f"The Configuration Wizard 🧙 will put your files here: '{output_path}'") class ConfigFormatStep(Step): @@ -432,11 +458,14 @@ def effect(self): ) + 1 ) - self.tour.add_step( - MoreDatasetsStep(name=StepNames.more_datasets_step), self + + self.tour.add_steps( + get_dataset_steps(dataset_index=new_dataset_index) + + [MoreDatasetsStep()], + self, ) - for step in reversed(get_dataset_steps(dataset_index=new_dataset_index)): - self.tour.add_step(step, self) + elif len([key for key in self.state.keys() if key.startswith("dataset_")]) == 0: + print("No dataset to save, exiting without saving any configuration.") else: self.tour.add_step( ConfigFormatStep(name=StepNames.config_format_step), self diff --git a/everyvoice/wizard/dataset.py b/everyvoice/wizard/dataset.py index 9ea0967f..84432236 100644 --- a/everyvoice/wizard/dataset.py +++ b/everyvoice/wizard/dataset.py @@ -5,6 +5,9 @@ from unicodedata import normalize import questionary +import rich +from rich.panel import Panel +from rich.style import Style from tqdm import tqdm from everyvoice.config.type_definitions import DatasetTextRepresentation @@ -51,6 +54,30 @@ def effect(self): ) +class DatasetPermissionStep(Step): + DEFAULT_NAME = StepNames.dataset_permission_step + choices = ( + "No, I don't have permission to use this data.", + "Yes, I do have permission to use this data.", + ) + + def prompt(self): + prompt_text = """Do you have permission to use this data to build a TTS model? It is unethical to build a TTS model of a speaker without their knowledge or permission and there can be serious consequences for doing so.""" + return get_response_from_menu_prompt( + prompt_text=prompt_text, + choices=self.choices, + ) + + def validate(self, response): + return response in self.choices + + def effect(self): + if self.state[StepNames.dataset_permission_step.value].startswith("No"): + print("OK, we'll ask you to choose another dataset then!") + self.children = [] + del self.root.state[self.state_subset] + + class WavsDirStep(Step): DEFAULT_NAME = StepNames.wavs_dir_step @@ -271,16 +298,36 @@ def wav_file_early_validation(self) -> int: def prompt(self): error_count = self.wav_file_early_validation() if error_count: - return input("Do you want to pick a different wavs directory? ").lower() - return "ok" + return get_response_from_menu_prompt( + title="Do you want to pick a different wavs directory?", + choices=( + "Yes", + "No, I will fix my audio basenames or add missing audio files later.", + ), + ) + else: + return "OK" def validate(self, response): - return response in ("y", "yes", "n", "no", "ok") + return response[:3] in ("OK", "Yes", "No,") def effect(self): - if self.response[:1] == "y": - self.tour.add_step(ValidateWavsStep(state_subset=self.state_subset), self) - self.tour.add_step(WavsDirStep(state_subset=self.state_subset), self) + if self.response == "Yes": + self.tour.add_steps( + [ + WavsDirStep(state_subset=self.state_subset), + ValidateWavsStep(state_subset=self.state_subset), + ], + self, + ) + elif self.response.startswith("No"): + rich.print( + Panel( + "Continuing despite missing audio files. Make sure you fix your filelist later or add missing audio files, otherwise entries in your filelist with missing audio files will be skipped during preprocessing and therefore be ignored during training.", + title="Missing audio files", + border_style=Style(color="#EF1010"), + ) + ) class FilelistTextRepresentationStep(Step): @@ -682,17 +729,20 @@ def effect(self): def get_dataset_steps(dataset_index=0): return [ - WavsDirStep(state_subset=f"dataset_{dataset_index}"), FilelistStep(state_subset=f"dataset_{dataset_index}"), - FilelistFormatStep(state_subset=f"dataset_{dataset_index}"), - ValidateWavsStep(state_subset=f"dataset_{dataset_index}"), - FilelistTextRepresentationStep(state_subset=f"dataset_{dataset_index}"), - TextProcessingStep(state_subset=f"dataset_{dataset_index}"), - HasSpeakerStep(state_subset=f"dataset_{dataset_index}"), - HasLanguageStep(state_subset=f"dataset_{dataset_index}"), - SymbolSetStep(state_subset=f"dataset_{dataset_index}"), - SoxEffectsStep(state_subset=f"dataset_{dataset_index}"), - DatasetNameStep(state_subset=f"dataset_{dataset_index}"), + [ + DatasetPermissionStep(state_subset=f"dataset_{dataset_index}"), + FilelistFormatStep(state_subset=f"dataset_{dataset_index}"), + WavsDirStep(state_subset=f"dataset_{dataset_index}"), + ValidateWavsStep(state_subset=f"dataset_{dataset_index}"), + FilelistTextRepresentationStep(state_subset=f"dataset_{dataset_index}"), + TextProcessingStep(state_subset=f"dataset_{dataset_index}"), + HasSpeakerStep(state_subset=f"dataset_{dataset_index}"), + HasLanguageStep(state_subset=f"dataset_{dataset_index}"), + SymbolSetStep(state_subset=f"dataset_{dataset_index}"), + SoxEffectsStep(state_subset=f"dataset_{dataset_index}"), + DatasetNameStep(state_subset=f"dataset_{dataset_index}"), + ], ] diff --git a/everyvoice/wizard/main_tour.py b/everyvoice/wizard/main_tour.py index 8422c09c..097b97d7 100644 --- a/everyvoice/wizard/main_tour.py +++ b/everyvoice/wizard/main_tour.py @@ -8,17 +8,21 @@ ) from everyvoice.wizard.dataset import get_dataset_steps -WIZARD_TOUR = Tour( - name="Basic Tour", - steps=[ - NameStep(name=StepNames.name_step), - ContactNameStep(name=StepNames.contact_name_step), - ContactEmailStep(name=StepNames.contact_email_step), - OutputPathStep(name=StepNames.output_step), - ] - + get_dataset_steps() - + [MoreDatasetsStep(name=StepNames.more_datasets_step)], -) + +def get_main_wizard_tour(): + """Get the main wizard tour""" + return Tour( + name="Basic Tour", + steps=[ + NameStep(name=StepNames.name_step), + ContactNameStep(name=StepNames.contact_name_step), + ContactEmailStep(name=StepNames.contact_email_step), + OutputPathStep(name=StepNames.output_step), + ] + + get_dataset_steps() + + [MoreDatasetsStep(name=StepNames.more_datasets_step)], + ) + if __name__ == "__main__": - WIZARD_TOUR.run() + get_main_wizard_tour().run() diff --git a/everyvoice/wizard/prompts.py b/everyvoice/wizard/prompts.py index f948eaf6..1a01cf16 100644 --- a/everyvoice/wizard/prompts.py +++ b/everyvoice/wizard/prompts.py @@ -1,5 +1,5 @@ import sys -from typing import List, Tuple, Union +from typing import Union import simple_term_menu from questionary import Style @@ -27,18 +27,18 @@ def get_response_from_menu_prompt( prompt_text: str = "", - choices: Tuple[str, ...] = (), + choices: tuple[str, ...] = (), title: str = "", multi=False, search=False, return_indices=False, -) -> Union[int, str, List[int], List[str]]: +) -> Union[int, str, list[int], list[str]]: """Given some prompt text and a list of choices, create a simple terminal window and return the index of the choice Args: prompt_text (str): rich prompt text to print before menu - choices (List[str]): choices to display + choices (list[str]): choices to display Returns: int: index of choice diff --git a/requirements.txt b/requirements.txt index 4f755e59..bb22c9ae 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ clipdetect>=0.1.4 deepdiff>=6.5.0 -anytree>=2.8.0 +anytree>=2.12.1 einops==0.5.0 g2p~=2.0.0 gradio>=4.32.1