Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Develop #146

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions __init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self
6 changes: 3 additions & 3 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import json
import torch
from scipy.io.wavfile import write
from env import AttrDict
from meldataset import mel_spectrogram, MAX_WAV_VALUE, load_wav
from models import Generator
from .env import AttrDict
from .meldataset import mel_spectrogram, MAX_WAV_VALUE, load_wav
from .models import Generator

h = None
device = None
Expand Down
6 changes: 3 additions & 3 deletions inference_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import json
import torch
from scipy.io.wavfile import write
from env import AttrDict
from meldataset import MAX_WAV_VALUE
from models import Generator
from .env import AttrDict
from .meldataset import MAX_WAV_VALUE
from .models import Generator

h = None
device = None
Expand Down
4 changes: 2 additions & 2 deletions meldataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin,

def get_dataset_filelist(a):
with open(a.input_training_file, 'r', encoding='utf-8') as fi:
training_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav')
training_files = [os.path.join(a.input_wavs_dir, x.split('|')[1] + '.wav')
for x in fi.read().split('\n') if len(x) > 0]

with open(a.input_validation_file, 'r', encoding='utf-8') as fi:
validation_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav')
validation_files = [os.path.join(a.input_wavs_dir, x.split('|')[1] + '.wav')
for x in fi.read().split('\n') if len(x) > 0]
return training_files, validation_files

Expand Down
2 changes: 1 addition & 1 deletion models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from utils import init_weights, get_padding
from .utils import init_weights, get_padding

LRELU_SLOPE = 0.1

Expand Down
13 changes: 8 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
import torch.multiprocessing as mp
from torch.distributed import init_process_group
from torch.nn.parallel import DistributedDataParallel
from env import AttrDict, build_env
from meldataset import MelDataset, mel_spectrogram, get_dataset_filelist
from models import Generator, MultiPeriodDiscriminator, MultiScaleDiscriminator, feature_loss, generator_loss,\
from .env import AttrDict, build_env
from .meldataset import MelDataset, mel_spectrogram, get_dataset_filelist
from .models import Generator, MultiPeriodDiscriminator, MultiScaleDiscriminator, feature_loss, generator_loss,\
discriminator_loss
from utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint
from .utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint

torch.backends.cudnn.benchmark = True

Expand Down Expand Up @@ -196,7 +196,8 @@ def train(rank, a, h):
y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,
h.hop_size, h.win_size,
h.fmin, h.fmax_for_loss)
val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item()
# val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item()
val_err_tot += F.l1_loss(y_mel, y_g_hat_mel[:,:,:y_mel.size(2)]).item()

if j <= 4:
if steps == 0:
Expand Down Expand Up @@ -242,6 +243,7 @@ def main():
parser.add_argument('--summary_interval', default=100, type=int)
parser.add_argument('--validation_interval', default=1000, type=int)
parser.add_argument('--fine_tuning', default=False, type=bool)
parser.add_argument('--batch_size', default=8, type=int)

a = parser.parse_args()

Expand All @@ -255,6 +257,7 @@ def main():
torch.manual_seed(h.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(h.seed)
h.batch_size = a.batch_size
h.num_gpus = torch.cuda.device_count()
h.batch_size = int(h.batch_size / h.num_gpus)
print('Batch size per GPU :', h.batch_size)
Expand Down