diff --git a/__init__.py b/__init__.py new file mode 100644 index 000000000..2692e56ab --- /dev/null +++ b/__init__.py @@ -0,0 +1,4 @@ +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self diff --git a/inference.py b/inference.py index 96ba10672..e23252f3c 100644 --- a/inference.py +++ b/inference.py @@ -6,9 +6,9 @@ import json import torch from scipy.io.wavfile import write -from env import AttrDict -from meldataset import mel_spectrogram, MAX_WAV_VALUE, load_wav -from models import Generator +from .env import AttrDict +from .meldataset import mel_spectrogram, MAX_WAV_VALUE, load_wav +from .models import Generator h = None device = None diff --git a/inference_e2e.py b/inference_e2e.py index 0a8b6b323..6130f3115 100644 --- a/inference_e2e.py +++ b/inference_e2e.py @@ -7,9 +7,9 @@ import json import torch from scipy.io.wavfile import write -from env import AttrDict -from meldataset import MAX_WAV_VALUE -from models import Generator +from .env import AttrDict +from .meldataset import MAX_WAV_VALUE +from .models import Generator h = None device = None diff --git a/meldataset.py b/meldataset.py index 450292451..af5c591d4 100644 --- a/meldataset.py +++ b/meldataset.py @@ -74,11 +74,11 @@ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, def get_dataset_filelist(a): with open(a.input_training_file, 'r', encoding='utf-8') as fi: - training_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav') + training_files = [os.path.join(a.input_wavs_dir, x.split('|')[1] + '.wav') for x in fi.read().split('\n') if len(x) > 0] with open(a.input_validation_file, 'r', encoding='utf-8') as fi: - validation_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav') + validation_files = [os.path.join(a.input_wavs_dir, x.split('|')[1] + '.wav') for x in fi.read().split('\n') if len(x) > 0] return training_files, validation_files diff 
--git a/models.py b/models.py index da233d02d..d2a2ef420 100644 --- a/models.py +++ b/models.py @@ -3,7 +3,7 @@ import torch.nn as nn from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm -from utils import init_weights, get_padding +from .utils import init_weights, get_padding LRELU_SLOPE = 0.1 diff --git a/train.py b/train.py index 3b5509428..516ac3f63 100644 --- a/train.py +++ b/train.py @@ -12,11 +12,11 @@ import torch.multiprocessing as mp from torch.distributed import init_process_group from torch.nn.parallel import DistributedDataParallel -from env import AttrDict, build_env -from meldataset import MelDataset, mel_spectrogram, get_dataset_filelist -from models import Generator, MultiPeriodDiscriminator, MultiScaleDiscriminator, feature_loss, generator_loss,\ +from .env import AttrDict, build_env +from .meldataset import MelDataset, mel_spectrogram, get_dataset_filelist +from .models import Generator, MultiPeriodDiscriminator, MultiScaleDiscriminator, feature_loss, generator_loss,\ discriminator_loss -from utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint +from .utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint torch.backends.cudnn.benchmark = True @@ -196,7 +196,8 @@ def train(rank, a, h): y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax_for_loss) - val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item() + # val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item() + val_err_tot += F.l1_loss(y_mel, y_g_hat_mel[:,:,:y_mel.size(2)]).item() if j <= 4: if steps == 0: @@ -242,6 +243,7 @@ def main(): parser.add_argument('--summary_interval', default=100, type=int) parser.add_argument('--validation_interval', default=1000, type=int) parser.add_argument('--fine_tuning', default=False, type=bool) + parser.add_argument('--batch_size', default=8, type=int) a 
= parser.parse_args() @@ -255,6 +257,7 @@ def main(): torch.manual_seed(h.seed) if torch.cuda.is_available(): torch.cuda.manual_seed(h.seed) + h.batch_size = a.batch_size h.num_gpus = torch.cuda.device_count() h.batch_size = int(h.batch_size / h.num_gpus) print('Batch size per GPU :', h.batch_size)