
Image generation related #11

Open
zhd61 opened this issue Dec 5, 2024 · 4 comments
Comments


zhd61 commented Dec 5, 2024

Hi, thank you for the explanation of image generation in #3, #4, and #6. I tried those modifications, but I cannot reproduce the paper's results.

Could you please share the full code for the image generation and the model weights?

Thank you.

@roymiles (Owner)

Hi, sorry for the late reply on this! I have been quite busy lately and was then away for Christmas.
I will check the original code and checkpoints when I am back at work.

In the meantime, it is worth mentioning that there is a lot of variance on the CIFAR datasets, especially for image generation (hyperparameters like batch size are very important). Are you able to replicate the KD-DLGAN results? Once you can, I can try to help further with the implementation.


zhd61 commented Jan 1, 2025

Hi, thank you for your response. I can replicate the KD-DLGAN results for CIFAR 10 and 100. Please suggest the next steps.


roymiles commented Jan 2, 2025

Here are some checkpoints I found for CIFAR-100 with 20% training data, with and without whitening:

>>> w = torch.load('whiten/state_dict_best.pth')
>>> w['best_FID']
13.714126143712015

>>> w = torch.load('no-normalisation/state_dict_best.pth')
>>> w['best_FID']
14.230454154115964

This is the main code snippet (tidied up a bit) from the KD-DLGAN train_fns.py script for the whitening experiment above:

# distillation losses (orthogonal projection)
D_real_out_reg = G.projector_real(D_real_out_reg)
D_fake_out_reg = G.projector_real(D_fake_out_reg)

# whitening of the latent features
G.dbn_fake.train()
G.dbn_real.train()
latent_real_bn = 1.0 * G.dbn_real(latent_real)
latent_fake_bn = 1.0 * G.dbn_fake(latent_fake)

# may need to try different weightings
weighting = 10
triplet_weight = 10

D_loss_real_reg = losses.TripletLoss_reverse(latent_real_bn, D_real_out_reg, D_fake_out_reg) * triplet_weight
D_loss_real_reg += F.smooth_l1_loss(latent_real, D_real_out_reg) * weighting

D_loss_fake_reg = losses.TripletLoss_reverse(latent_fake_bn, D_fake_out_reg, D_real_out_reg) * triplet_weight
D_loss_fake_reg += F.smooth_l1_loss(latent_fake, D_fake_out_reg) * weighting

# we do not need this loss
D_loss_outdistri_div = torch.tensor(0.0)

# consistency along the time dimension
if config['LC'] > 0 and state_dict['itr'] > ema_losses.start_itr:
    D_loss_LC = losses.lecam_reg(D_real, D_fake, ema_losses) * config['LC']
else:
    D_loss_LC = torch.tensor(0.)

# Compute components of D's loss, average them, and divide by
# the number of gradient accumulations
D_loss = D_loss_real + D_loss_fake + D_loss_LC + D_loss_fake_reg + D_loss_real_reg + D_loss_outdistri_div
D_loss = D_loss / float(config['num_D_accumulations'])
D_loss.backward()
counter += 1

# Accumulated D losses (note: every term accumulates with +=)
D_loss_real_total += (D_loss_real.item() / float(config['num_D_accumulations']))
D_loss_fake_reg_total += (D_loss_fake_reg.item() / float(config['num_D_accumulations']))
D_loss_real_reg_total += (D_loss_real_reg.item() / float(config['num_D_accumulations']))
D_loss_fake_total += (D_loss_fake.item() / float(config['num_D_accumulations']))
D_loss_outdistri_div_total += (D_loss_outdistri_div.item() / float(config['num_D_accumulations']))

D_real_total += (torch.mean(D_real).item() / float(config['num_D_accumulations']))
D_fake_total += (torch.mean(D_fake).item() / float(config['num_D_accumulations']))
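losses.TripletLoss_reverse is a repo-specific helper whose exact formulation is not shown in this thread. For orientation only, here is a minimal NumPy sketch of a standard triplet margin loss (pull the anchor toward the positive, push it from the negative by at least a margin); the margin value and the semantics of the "_reverse" variant are assumptions, not the repo's actual code:

```python
import numpy as np

def triplet_loss(anchor, positive, negative, margin=1.0):
    """Standard triplet margin loss over a batch of (batch, dim) embeddings."""
    d_pos = np.sum((anchor - positive) ** 2, axis=1)  # squared anchor-positive distance
    d_neg = np.sum((anchor - negative) ** 2, axis=1)  # squared anchor-negative distance
    return np.mean(np.maximum(d_pos - d_neg + margin, 0.0))

rng = np.random.default_rng(0)
a = rng.normal(size=(4, 8))
loss_easy = triplet_loss(a, a, a + 10.0)  # positive == anchor, negative far away -> 0.0
loss_hard = triplet_loss(a, a + 10.0, a)  # roles swapped -> large penalty
```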

Hope this helps
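The lecam_reg call in the snippet above appears to be the LeCam regularisation from "Regularizing Generative Adversarial Networks under Limited Data" (Tseng et al., 2021), which penalises discriminator outputs that drift far from exponential moving averages (EMAs) of the opposite branch. A minimal NumPy sketch of the commonly used formulation; the EMA bookkeeping and exact clipping in this repo are assumptions:

```python
import numpy as np

def lecam_reg(d_real, d_fake, ema_real, ema_fake):
    """LeCam regularisation: one-sided (ReLU-clipped) squared distance between
    current D outputs and the EMA of the opposite branch's outputs."""
    return (np.mean(np.maximum(d_real - ema_fake, 0.0) ** 2)
            + np.mean(np.maximum(ema_real - d_fake, 0.0) ** 2))

d_real = np.array([1.0, 2.0])
d_fake = np.array([-1.0, 0.0])
reg = lecam_reg(d_real, d_fake, ema_real=0.5, ema_fake=0.0)  # -> 3.75
```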


zhd61 commented Jan 3, 2025

Thanks

I see that the definitions of G.dbn_fake and G.dbn_real are missing. Could you please share that code from run_train.py?
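The dbn modules are still an open question at this point in the thread. "dbn" commonly stands for decorrelated batch normalisation, i.e. whitening the batch features so their covariance is (approximately) the identity. As a generic illustration of that idea only, not the repo's actual code, a NumPy sketch of ZCA whitening:

```python
import numpy as np

def zca_whiten(x, eps=1e-8):
    """ZCA-whiten a (batch, dim) matrix: centre it, then rotate/scale so the
    per-batch covariance becomes approximately the identity."""
    xc = x - x.mean(axis=0, keepdims=True)
    cov = xc.T @ xc / x.shape[0]
    eigvals, eigvecs = np.linalg.eigh(cov)
    # ZCA transform: inverse square root of the covariance
    w = eigvecs @ np.diag(1.0 / np.sqrt(eigvals + eps)) @ eigvecs.T
    return xc @ w

rng = np.random.default_rng(0)
x = rng.normal(size=(256, 8)) @ rng.normal(size=(8, 8))  # correlated features
xw = zca_whiten(x)
cov_w = xw.T @ xw / xw.shape[0]  # should be close to the 8x8 identity
```

(A trainable decorrelated batch norm layer would additionally keep running statistics and a learnable affine, like ordinary batch norm.)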
