Support needed for Multiple Discriminators Implementation #68

Open
MuruganR96 opened this issue Dec 7, 2022 · 2 comments
Labels
help wanted Extra attention is needed

Comments


MuruganR96 commented Dec 7, 2022

    Using multiple discriminators is effective: once the model converges, both the sound quality on unseen speakers and the similarity to the target speaker are better than with the original model.

Originally posted by @980202006 in #6 (comment)

MuruganR96 (author)

Hi @yl4579 @980202006, I have read #6 (comment).

Could you please guide me on where to modify the code to implement the multiple-discriminator feature?

Thanks


MuruganR96 commented Dec 10, 2022

@yl4579 @980202006 Please check my proposed approach to implementing multiple discriminators.

For example, with a batch size of 1, 60 speakers, and 3 discriminators, the discriminator IDs are 1, 2, and 3.

  1. In meldataset.py, return a discriminator_id from MelDataset.__getitem__ based on the target speaker_id (speakers 0-19 → 1, 20-39 → 2, 40-59 → 3).
  2. In models.py, build a set of discriminators, each of which only works on a subset of speakers (as suggested in "Some doubt about any to any voice conversion" #6 (comment)):
def build_model(args, F0_model, ASR_model):
    generator = Generator(args.dim_in, args.style_dim, args.max_conv_dim, w_hpf=args.w_hpf, F0_channel=args.F0_channel)
    mapping_network = MappingNetwork(args.latent_dim, args.style_dim, args.num_domains, hidden_dim=args.max_conv_dim)
    style_encoder = StyleEncoder(args.dim_in, args.style_dim, args.num_domains, args.max_conv_dim)

    # one discriminator per speaker group (IDs 1-3); each becomes its own key in nets
    discriminator1 = Discriminator(args.dim_in, args.num_domains, args.max_conv_dim, args.n_repeat)
    discriminator2 = Discriminator(args.dim_in, args.num_domains, args.max_conv_dim, args.n_repeat)
    discriminator3 = Discriminator(args.dim_in, args.num_domains, args.max_conv_dim, args.n_repeat)

    generator_ema = copy.deepcopy(generator)
    mapping_network_ema = copy.deepcopy(mapping_network)
    style_encoder_ema = copy.deepcopy(style_encoder)
        
    nets = Munch(generator=generator,
                 mapping_network=mapping_network,
                 style_encoder=style_encoder,
                 discriminator1=discriminator1,
                 discriminator2=discriminator2,
                 discriminator3=discriminator3,
                 f0_model=F0_model,
                 asr_model=ASR_model)
    
    nets_ema = Munch(generator=generator_ema,
                     mapping_network=mapping_network_ema,
                     style_encoder=style_encoder_ema)

    return nets, nets_ema
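Steps 1 and 2 hard-code three speaker ranges and three discriminator names. A minimal sketch that derives both from the speaker index instead, assuming 0-based speaker IDs and contiguous groups of (nearly) equal size. `discriminator_id_for_speaker` and `discriminator_key` are hypothetical helpers, not part of the existing code.

```python
def discriminator_id_for_speaker(speaker_id: int, num_speakers: int = 60,
                                 num_discriminators: int = 3) -> int:
    """Map a 0-based target speaker index to a 1-based discriminator ID.

    Proportional split: speaker_id * k // n always lies in [0, k), so the
    groups are contiguous and their sizes differ by at most one.
    With the defaults: 0-19 -> 1, 20-39 -> 2, 40-59 -> 3.
    """
    if not 0 <= speaker_id < num_speakers:
        raise ValueError(
            f"speaker_id {speaker_id} out of range [0, {num_speakers})")
    return speaker_id * num_discriminators // num_speakers + 1


def discriminator_key(discriminator_id: int) -> str:
    """Name of the matching entry in the nets Munch built by build_model."""
    return f"discriminator{discriminator_id}"
```

In MelDataset.__getitem__ this would presumably be called with the target (reference) speaker label, so that the ID travels with the sample into the batch.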
  3. In trainer.py, pass the discriminator ID to compute_d_loss:
    def _train_epoch(self):
        self.epochs += 1
        
        train_losses = defaultdict(list)
        _ = [self.model[k].train() for k in self.model]
        scaler = torch.cuda.amp.GradScaler() if (('cuda' in str(self.device)) and self.fp16_run) else None

        use_con_reg = (self.epochs >= self.args.con_reg_epoch)
        use_adv_cls = (self.epochs >= self.args.adv_cls_epoch)
        
        for train_steps_per_epoch, batch in enumerate(tqdm(self.train_dataloader, desc="[train]"), 1):

            ### load data
            batch = [b.to(self.device) for b in batch]
            # discriminator_id must also be added to the tuple returned by the Collater in meldataset.py
            x_real, y_org, x_ref, x_ref2, y_trg, z_trg, z_trg2, discriminator_id = batch
            
            # train the discriminator (by random reference)
            self.optimizer.zero_grad()
            if scaler is not None:
                with torch.cuda.amp.autocast():
                    d_loss, d_losses_latent = compute_d_loss(self.model, self.args.d_loss, x_real, y_org, y_trg, discriminator_id, z_trg=z_trg, use_adv_cls=use_adv_cls, use_con_reg=use_con_reg)
                scaler.scale(d_loss).backward()
            else:
                d_loss, d_losses_latent = compute_d_loss(self.model, self.args.d_loss, x_real, y_org, y_trg, discriminator_id, z_trg=z_trg, use_adv_cls=use_adv_cls, use_con_reg=use_con_reg)
                d_loss.backward()
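            # assumes the optimizer holds one entry per key in self.model, so there is no
            # 'discriminator' key anymore: step only the discriminator that was used
            self.optimizer.step(f'discriminator{int(discriminator_id)}', scaler=scaler)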

            # train the discriminator (by target reference)
            self.optimizer.zero_grad()
            if scaler is not None:
                with torch.cuda.amp.autocast():
                    d_loss, d_losses_ref = compute_d_loss(self.model, self.args.d_loss, x_real, y_org, y_trg, discriminator_id, x_ref=x_ref, use_adv_cls=use_adv_cls, use_con_reg=use_con_reg)
                scaler.scale(d_loss).backward()
            else:
                d_loss, d_losses_ref = compute_d_loss(self.model, self.args.d_loss, x_real, y_org, y_trg, discriminator_id, x_ref=x_ref, use_adv_cls=use_adv_cls, use_con_reg=use_con_reg)
                d_loss.backward()
            self.optimizer.step(f'discriminator{int(discriminator_id)}', scaler=scaler)
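The `if discriminator_id == 1` checks only work while discriminator_id holds a single value (batch size 1); with a larger batch, PyTorch raises an error because the truth value of a multi-element tensor is ambiguous. One possible workaround, sketched under the assumption that each sample should go to its own group's discriminator, is to split the batch by ID and run compute_d_loss once per group. `group_by_discriminator` is a hypothetical helper.

```python
from collections import defaultdict
from typing import Dict, List, Sequence


def group_by_discriminator(discriminator_ids: Sequence[int]) -> Dict[int, List[int]]:
    """Map each discriminator ID to the batch indices routed to it.

    Indices keep their original batch order, so they can be used directly
    to slice batch tensors, e.g. x_real[indices].
    """
    groups: Dict[int, List[int]] = defaultdict(list)
    for index, d_id in enumerate(discriminator_ids):
        # int() accepts Python ints and one-element tensors alike
        groups[int(d_id)].append(index)
    return dict(groups)


# Hypothetical use inside _train_epoch (not runnable outside the trainer):
# for d_id, idx in group_by_discriminator(discriminator_id.tolist()).items():
#     d_loss, _ = compute_d_loss(self.model, self.args.d_loss,
#                                x_real[idx], y_org[idx], y_trg[idx], d_id, ...)
#     weight = len(idx) / x_real.size(0)  # keep the summed loss a batch mean
```

Weighting each group's loss by its share of the batch is one design choice; it keeps the total comparable to the single-discriminator batch mean.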
  4. In losses.py, select the discriminator by ID inside compute_d_loss:
def compute_d_loss(nets, args, x_real, y_org, y_trg, discriminator_id, z_trg=None, x_ref=None, use_r1_reg=True, use_adv_cls=False, use_con_reg=False):
    args = Munch(args)

    assert (z_trg is None) != (x_ref is None)
    # with real audios
    x_real.requires_grad_()

    # comparing a tensor ID this way only works for batch size 1
    if discriminator_id == 1:
        out = nets.discriminator1(x_real, y_org)
    elif discriminator_id == 2:
        out = nets.discriminator2(x_real, y_org)
    else:
        out = nets.discriminator3(x_real, y_org)

    loss_real = adv_loss(out, 1)
    
    # R1 regularization (https://arxiv.org/abs/1801.04406v4)
    if use_r1_reg:
        loss_reg = r1_reg(out, x_real)
    else:
        loss_reg = torch.FloatTensor([0]).to(x_real.device)
    
    # consistency regularization (bCR-GAN: https://arxiv.org/abs/2002.04724)
    loss_con_reg = torch.FloatTensor([0]).to(x_real.device)
    if use_con_reg:
        t = build_transforms()

        if discriminator_id == 1:
            out_aug = nets.discriminator1(t(x_real).detach(), y_org)
        elif discriminator_id == 2:
            out_aug = nets.discriminator2(t(x_real).detach(), y_org)
        else:
            out_aug = nets.discriminator3(t(x_real).detach(), y_org)

        loss_con_reg += F.smooth_l1_loss(out, out_aug)
    
    # with fake audios
    with torch.no_grad():
        if z_trg is not None:
            s_trg = nets.mapping_network(z_trg, y_trg)
        else:  # x_ref is not None
            s_trg = nets.style_encoder(x_ref, y_trg)
            
        F0 = nets.f0_model.get_feature_GAN(x_real)
        x_fake = nets.generator(x_real, s_trg, masks=None, F0=F0)

    if discriminator_id == 1:
        out = nets.discriminator1(x_fake, y_trg)
    elif discriminator_id == 2:
        out = nets.discriminator2(x_fake, y_trg)
    else:
        out = nets.discriminator3(x_fake, y_trg)

    loss_fake = adv_loss(out, 0)
    if use_con_reg:

        if discriminator_id == 1:
            out_aug = nets.discriminator1(t(x_fake).detach(), y_trg)
        elif discriminator_id == 2:
            out_aug = nets.discriminator2(t(x_fake).detach(), y_trg)
        else:
            out_aug = nets.discriminator3(t(x_fake).detach(), y_trg)

        loss_con_reg += F.smooth_l1_loss(out, out_aug)
    
    # adversarial classifier loss
    if use_adv_cls:

        if discriminator_id == 1:
            out_de = nets.discriminator1.classifier(x_fake)
        elif discriminator_id == 2:
            out_de = nets.discriminator2.classifier(x_fake)
        else:
            out_de = nets.discriminator3.classifier(x_fake)

        loss_real_adv_cls = F.cross_entropy(out_de[y_org != y_trg], y_org[y_org != y_trg])
        
        if use_con_reg:

            if discriminator_id == 1:
                out_de_aug = nets.discriminator1.classifier(t(x_fake).detach())
            elif discriminator_id == 2:
                out_de_aug = nets.discriminator2.classifier(t(x_fake).detach())
            else:
                out_de_aug = nets.discriminator3.classifier(t(x_fake).detach())

            loss_con_reg += F.smooth_l1_loss(out_de, out_de_aug)
    else:
        loss_real_adv_cls = torch.zeros(1).mean()
        
    loss = loss_real + loss_fake + args.lambda_reg * loss_reg + \
            args.lambda_adv_cls * loss_real_adv_cls + \
            args.lambda_con_reg * loss_con_reg 

    return loss, Munch(real=loss_real.item(),
                       fake=loss_fake.item(),
                       reg=loss_reg.item(),
                       real_adv_cls=loss_real_adv_cls.item(),
                       con_reg=loss_con_reg.item())
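The if/elif chains repeat the same dispatch six times, and compute_g_loss presumably still calls nets.discriminator directly and would need the same change. A small lookup helper could replace every chain; this is a sketch, and `select_discriminator` is a hypothetical name. Unlike the chains above, whose `else` branch silently falls back to discriminator3 for any unexpected ID, it raises on an unknown ID.

```python
from typing import Any, Mapping


def select_discriminator(nets: Mapping[str, Any], discriminator_id: Any) -> Any:
    """Return nets['discriminator<ID>'] for an int or one-element tensor ID."""
    key = f"discriminator{int(discriminator_id)}"
    if key not in nets:
        available = sorted(k for k in nets if k.startswith("discriminator"))
        raise KeyError(f"unknown discriminator {key!r}; available: {available}")
    return nets[key]


# Hypothetical use inside compute_d_loss (and likewise compute_g_loss):
# D = select_discriminator(nets, discriminator_id)
# out = D(x_real, y_org)
# ...
# out_de = D.classifier(x_fake)
```

Because Munch subclasses dict, the helper works on the nets object returned by build_model without modification.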

@yl4579 @980202006

Thanks

MuruganR96 changed the title from "Support for Multiple Discriminators Implementation" to "Support needed for Multiple Discriminators Implementation" on Dec 21, 2022
yl4579 added the "help wanted" (Extra attention is needed) label on Jan 31, 2023

2 participants