-
Notifications
You must be signed in to change notification settings - Fork 46
/
Copy pathcontrollable_talknet_cli.py
118 lines (92 loc) · 6.58 KB
/
controllable_talknet_cli.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# Copyright (C) 2023 HydrusBeta
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import argparse
import base64
import hashlib
import io
import os
import librosa
import soundfile
import controllable_talknet
# Directory (a 'results' folder under controllable_talknet.RUN_PATH) that holds generated audio files when the
# user does not pass an explicit --output path.
RESULTS_DIR = os.path.join(controllable_talknet.RUN_PATH, 'results')
def parse_arguments():
    """Build the command-line parser and parse the process arguments.

    Returns:
        An ``argparse.Namespace`` with the attributes ``text``, ``reference_audio``, ``pitch_factor``, ``output``,
        ``pitch_options``, ``character`` and ``custom_model``. ``--text`` is required, and exactly one of
        ``--character`` / ``--custom_model`` must be supplied; argparse exits with a usage message otherwise.
    """
    parser = argparse.ArgumentParser(prog='Controllable TalkNet',
                                     description="A text-to-speech program based on NVIDIA's implementation of "
                                                 "Talknet2, with some changes to support singing synthesis and higher "
                                                 "audio quality.")
    parser.add_argument('-t', '--text', type=str, required=True, help='The text you would like a pony to say.')
    parser.add_argument('-i', '--reference_audio', type=str, help='The reference audio file to use for guiding the pacing, pitch, and inflection of the generated voice.')
    parser.add_argument('-f', '--pitch_factor', type=int, default=0, help='An integer specifying how many semitones by which to shift the pitch of the input audio.')
    parser.add_argument('-o', '--output', type=str, help='The desired output filepath. Be sure to include the desired extension, such as .flac or .mp3. If this argument is not passed, then the output will be written as a flac file to a "results" subdirectory in the ControllableTalkNet directory')
    # nargs='*' instead of argparse.REMAINDER: REMAINDER consumed every argument after -p (including other options
    # such as -c) and bypassed the `choices` check. With '*', -p may appear anywhere on the command line and its
    # values are validated.
    parser.add_argument('-p', '--pitch_options', type=str, default=[], choices=['srec', 'pc'], nargs='*', help='One or both of the following values: "pc", which instructs Controllable TalkNet to auto-tune the output using the reference audio, and "srec", which instructs Controllable TalkNet to attempt to reduce metallic noise in the output.')
    # The voice comes either from a known character or from a user-supplied model, never both.
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('-c', '--character', type=str, choices=known_characters(), help='The name of the pony character whose voice you would like to generate.')
    group.add_argument('-m', '--custom_model', type=str, help='The name of a custom model to use for voice generation. You must place your custom model folder within the "models" folder. Pass the name of the custom model folder as this argument')
    return parser.parse_args()
def known_characters():
    """Return the 'label' of every option produced by controllable_talknet's dropdown initializer."""
    # init_dropdown is called through __wrapped__ with a placeholder argument; only the first of its two
    # return values (the option list) is needed here.
    options, _unused = controllable_talknet.init_dropdown.__wrapped__(None)
    labels = []
    for entry in options:
        labels.append(entry['label'])
    return labels
def amend_pitch_options_if_needed(args):
    """Add the pitch options implied by the other arguments and return the resulting list.

    * 'dra' (Disable Reference Audio) is added when no reference audio is supplied.
    * 'pf' (Pitch Factor) is added when a nonzero pitch shift is specified.

    Args:
        args: Parsed arguments exposing ``reference_audio``, ``pitch_factor`` and ``pitch_options``.

    Returns:
        The amended list, which is also stored back on ``args.pitch_options``. Duplicates are removed and the
        order is deterministic: the user's options first (in the order given), then any implied ones.
    """
    options = list(args.pitch_options or [])
    if not args.reference_audio:
        options.append('dra')
    if args.pitch_factor != 0:
        options.append('pf')
    # dict.fromkeys de-duplicates while keeping first-seen order. A set would order the strings by their
    # (per-process randomized) hashes, so the list - and anything derived from it, such as the output file
    # name - would change from one run to the next.
    args.pitch_options = list(dict.fromkeys(options))
    return args.pitch_options
def create_unique_file(args):
    """Return a path inside RESULTS_DIR whose file name is derived from the arguments.

    The name is the first 15 characters of the input text (so the user can recognize the file), followed by
    '...' if the text was truncated or '_' otherwise, a 20-character hex digest of all the arguments, and the
    '.flac' extension. When reference audio is in use, a digest of its samples is folded into the hash so that
    different reference recordings yield different names.

    Args:
        args: Parsed arguments exposing ``text``, ``reference_audio``, ``character``, ``custom_model``,
            ``pitch_factor`` and ``pitch_options``.

    Returns:
        The full path of the output file. The file itself is not created here.
    """
    # Characters that cannot appear in a file name on Windows; '/' is a path separator on POSIX too. Control
    # characters (including NUL) are handled separately below.
    unsafe_characters = set('<>:"/\\|?*')

    input_hash = ''
    if args.reference_audio and 'dra' not in args.pitch_options:
        input_data, _ = librosa.load(args.reference_audio, sr=None)
        input_hash = hashlib.sha256(input_data).hexdigest()[:20]

    # The options are sorted so the digest does not depend on the order in which they were supplied or stored.
    base_string = args.text + input_hash + str(args.character) + str(args.custom_model) + str(args.pitch_factor) + \
        ''.join(sorted(args.pitch_options))
    digest = hashlib.sha256(base_string.encode('utf-8')).hexdigest()[:20]

    # The prefix is raw user text, so replace anything that would make the path invalid or point it somewhere
    # other than RESULTS_DIR.
    prefix = ''.join('_' if (ch in unsafe_characters or ord(ch) < 32) else ch for ch in args.text[:15])
    unique_filename = prefix + ('...' if len(args.text) > 15 else '_') + digest + '.flac'
    return os.path.join(RESULTS_DIR, unique_filename)
def prepare_output_directory():
    """Ensure that RESULTS_DIR exists, creating it (and any missing parents) if necessary."""
    # exist_ok=True avoids the check-then-create race of testing os.path.exists first: a directory that
    # appears between the check and the makedirs call no longer raises.
    os.makedirs(RESULTS_DIR, exist_ok=True)
def generate_audio(args) -> tuple:
    """Run Controllable TalkNet for the given arguments and return the decoded audio.

    If reference audio was supplied, it is first passed to ``controllable_talknet.select_file`` and the
    resulting pitch data and wav name are forwarded to ``controllable_talknet.generate_audio``. Both functions
    are invoked through ``__wrapped__``.

    Args:
        args: Parsed arguments exposing ``reference_audio``, ``custom_model``, ``character``, ``text``,
            ``pitch_options`` and ``pitch_factor``.

    Returns:
        The pair returned by ``get_audio_from_src``, i.e. the result of ``librosa.load`` on the generated
        audio (presumably ``(samples, sample_rate)`` - confirm against librosa).
    """
    f0s, f0s_wo_silence, wav_name = None, None, None
    # Truthiness rather than `is not None`, to agree with amend_pitch_options_if_needed: an empty
    # reference_audio string is treated there as "no reference audio" (it adds 'dra'), so it must not be
    # handed to select_file here.
    if args.reference_audio:
        _, f0s, f0s_wo_silence, wav_name = controllable_talknet.select_file.__wrapped__(args.reference_audio, [''])
    drive_id = 'Custom' if args.custom_model else get_drive_id_from_character(args.character)
    src, _, _, _ = controllable_talknet.generate_audio.__wrapped__(0, drive_id, args.custom_model, args.text,
                                                                   args.pitch_options, args.pitch_factor, wav_name,
                                                                   f0s, f0s_wo_silence)
    return get_audio_from_src(src, encoding='ascii')
def get_drive_id_from_character(character):
    """Look up the id associated with a character label in controllable_talknet's dropdown options.

    The id is the part of each option's 'value' that precedes the first '|'. Returns None when no option
    carries the given label.
    """
    options, _unused = controllable_talknet.init_dropdown.__wrapped__(None)
    ids_by_label = {}
    for entry in options:
        label = entry['label']
        # Later entries with the same label overwrite earlier ones.
        ids_by_label[label] = entry['value'].split('|')[0]
    return ids_by_label.get(character)
def write_output_file(args, output_array, output_samplerate):
    """Write the audio to ``args.output``, or to an auto-generated path when no output was requested."""
    destination = args.output
    if destination is None:
        # No explicit path given: derive a unique file name from the arguments.
        destination = create_unique_file(args)
    soundfile.write(destination, output_array, output_samplerate)
def get_audio_from_src(src, encoding):
    """Decode the audio embedded in ``src`` and load it with librosa.

    ``src`` must contain exactly one comma; the text after it is treated as base64 data, encoded to bytes
    using ``encoding`` and decoded. Returns whatever ``librosa.load`` returns for the decoded bytes, loaded
    with ``sr=None``.
    """
    _header, encoded_text = src.split(',')
    decoded = base64.b64decode(encoded_text.encode(encoding))
    # librosa.load is given an in-memory file object rather than a path.
    return librosa.load(io.BytesIO(decoded), sr=None)
if __name__ == '__main__':
    # Script entry point: parse the arguments, add the implied pitch options, make sure the results directory
    # exists, synthesize the audio and write it out.
    args = parse_arguments()
    args.pitch_options = amend_pitch_options_if_needed(args)
    prepare_output_directory()
    waveform, sample_rate = generate_audio(args)
    write_output_file(args, waveform, sample_rate)