Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new method for custom formatter handling #293

Open
wants to merge 3 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions TTS/tts/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,23 @@ def load_attention_mask_meta_data(metafile_path):
return meta_data


def add_formatter(name: str, formatter: Callable[[str, str, list[str] | None], list[dict]]):
"""Add a formatter to the datasets module. If the formatter already exists, raise an error.
Args:
name (str): The name of the formatter.
formatter (Callable): The formatter function.
Raises:
ValueError: If the formatter already exists.
Returns:
None
"""
thismodule = sys.modules[__name__]
if not hasattr(thismodule, name.lower()):
setattr(thismodule, name.lower(), formatter)
else:
raise ValueError(f"Formatter {name} already exists.")


def _get_formatter_by_name(name):
"""Returns the respective preprocessing function."""
thismodule = sys.modules[__name__]
Expand Down
8 changes: 6 additions & 2 deletions docs/source/datasets/formatting_your_dataset.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
Load a custom dataset with a custom formatter.

```python
from TTS.tts.datasets import load_tts_samples
from TTS.tts.datasets import load_tts_samples, add_formatter


# custom formatter implementation
Expand All @@ -119,8 +119,12 @@ def formatter(root_path, manifest_file, **kwargs): # pylint: disable=unused-arg
items.append({"text":text, "audio_file":wav_file, "speaker_name":speaker_name, "root_path": root_path})
return items

add_formatter("custom_formatter_name", formatter) # Use the custom formatter name in the dataset config
dataset_config = BaseDatasetConfig(
formatter="custom_formatter_name", meta_file_train="", language="en-us", path="dataset-path")
)
# load training samples
train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True, formatter=formatter)
train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
```

See `TTS.tts.datasets.TTSDataset`, a generic `Dataset` implementation for the `tts` models.
Expand Down
11 changes: 11 additions & 0 deletions tests/data_tests/test_dataset_formatters.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,14 @@ def test_common_voice_preprocessor(self): # pylint: disable=no-self-use

assert items[-1]["text"] == "Competition for limited resources has also resulted in some local conflicts."
assert items[-1]["audio_file"] == os.path.join(get_tests_input_path(), "clips", "common_voice_en_19737074.wav")

def test_custom_formatter_with_existing_name(self):
from TTS.tts.datasets import add_formatter

def custom_formatter(root_path, meta_file, ignored_speakers=None):
return []

add_formatter("custom_formatter", custom_formatter)

with self.assertRaises(ValueError):
add_formatter("custom_formatter", custom_formatter)
35 changes: 34 additions & 1 deletion tests/data_tests/test_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from tests import get_tests_data_path
from TTS.tts.configs.shared_configs import BaseDatasetConfig, BaseTTSConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.datasets import add_formatter, load_tts_samples
from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
Expand Down Expand Up @@ -251,3 +251,36 @@ def check_conditions(idx, linear_input, mel_input, stop_target, mel_lengths):
# check batch zero-frame conditions (zero-frame disabled)
# assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
# assert (mel_input * stop_target.unsqueeze(2)).sum() == 0


def test_custom_formatted_dataset_with_loader():
def custom_formatter(path, metafile, **kwargs):
with open(os.path.join(path, metafile)) as f:
data = f.readlines()
items = []
for line in data:
file_path, text = line.split("|", 1)
items.append({"text": text, "audio_file": file_path, "root_path": path, "speaker_name": "test"})
return items

def custom_formatter2(x, *args, **kwargs):
items = custom_formatter(x, *args, **kwargs)
[item.update({"audio_file": f"{item['audio_file']}.wav"}) for item in items]
return items

add_formatter("custom_formatter1", custom_formatter)
add_formatter("custom_formatter2", custom_formatter2)
dataset1 = BaseDatasetConfig(
formatter="custom_formatter1",
meta_file_train="metadata.csv",
path=c.data_path,
)
dataset2 = BaseDatasetConfig(
formatter="custom_formatter2",
meta_file_train="metadata.csv",
path=c.data_path,
)
dataset_configs = [dataset1, dataset2]
train_samples, eval_samples = load_tts_samples(dataset_configs, eval_split=True, eval_split_size=0.2)
assert len(train_samples) == 14
assert len(eval_samples) == 2
Loading