Why is tc_loss in bTCVAE negative? #60

Open
sisodia-a opened this issue Oct 4, 2020 · 12 comments

@sisodia-a
Author

The reference implementation at https://github.com/rtqichen/beta-tcvae/ computes, for minibatch weighted sampling (MWS):
logqz_prodmarginals = (logsumexp(_logqz, dim=1, keepdim=False) - math.log(batch_size * dataset_size)).sum(1)
logqz = (logsumexp(_logqz.sum(2), dim=1, keepdim=False) - math.log(batch_size * dataset_size))

and for minibatch stratified sampling (MSS):
logiw_matrix = Variable(self._log_importance_weight_matrix(batch_size, dataset_size).type_as(_logqz.data))
logqz = logsumexp(logiw_matrix + _logqz.sum(2), dim=1, keepdim=False)
logqz_prodmarginals = logsumexp(logiw_matrix.view(batch_size, batch_size, 1) + _logqz, dim=1, keepdim=False).sum(1)

So shouldn't this codebase do the same? For MWS (is_mss=False):

log_qz = (torch.logsumexp(mat_log_qz.sum(2), dim=1, keepdim=False) - math.log(batch_size * n_data))
log_prod_qzi = (torch.logsumexp(mat_log_qz, dim=1, keepdim=False) - math.log(batch_size * n_data)).sum(1)

and for MSS (is_mss=True):

log_iw_mat = log_importance_weight_matrix(batch_size, n_data).to(latent_sample.device)
log_qz = torch.logsumexp(log_iw_mat + mat_log_qz.sum(2), dim=1, keepdim=False)
log_prod_qzi = torch.logsumexp(log_iw_mat.view(batch_size, batch_size, 1) + mat_log_qz, dim=1, keepdim=False).sum(1)
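Both of Chen's estimators quoted above have the same shape: for each sample z_i, a logsumexp over the batch index j of a log-weight plus log q(z_i | x_j). The sketch below (function names are hypothetical, and the stratified importance-weight matrix itself is not reproduced) checks that plugging a constant log-weight of -log(batch_size * dataset_size) into the stratified form gives back the weighted-sampling formulas.

```python
import math

import torch


def weighted_estimates(mat_log_qz, log_w):
    """Generic form: mat_log_qz[i, j, d] = log q(z_i,d | x_j), log_w[i, j] = log weight."""
    batch_size = mat_log_qz.size(0)
    # log q(z_i): one weight per (i, j) pair, applied to the joint density over d
    log_qz = torch.logsumexp(log_w + mat_log_qz.sum(2), dim=1)
    # log prod_d q(z_i,d): marginalize each latent dimension separately, then sum over d
    log_prod_qzi = torch.logsumexp(
        log_w.view(batch_size, batch_size, 1) + mat_log_qz, dim=1
    ).sum(1)
    return log_qz, log_prod_qzi


def mws_estimates(mat_log_qz, dataset_size):
    """Minibatch weighted sampling, written like the quoted MWS lines."""
    batch_size = mat_log_qz.size(0)
    log_norm = math.log(batch_size * dataset_size)
    log_qz = torch.logsumexp(mat_log_qz.sum(2), dim=1) - log_norm
    log_prod_qzi = (torch.logsumexp(mat_log_qz, dim=1) - log_norm).sum(1)
    return log_qz, log_prod_qzi


torch.manual_seed(0)
batch_size, latent_dim, dataset_size = 4, 3, 1000
mat_log_qz = torch.randn(batch_size, batch_size, latent_dim)
# constant weight 1 / (batch_size * dataset_size) for every (i, j) pair
uniform_log_w = torch.full((batch_size, batch_size), -math.log(batch_size * dataset_size))

a = weighted_estimates(mat_log_qz, uniform_log_w)
b = mws_estimates(mat_log_qz, dataset_size)
print(torch.allclose(a[0], b[0]), torch.allclose(a[1], b[1]))  # True True
```

Because a constant inside a logsumexp factors out, MWS is the special case of the weighted form with uniform weights; MSS only changes which weights are used.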

@YannDubs
Owner

YannDubs commented Oct 7, 2020

Thanks @UserName-AnkitSisodia!
I think you might be right (I am taking a sum instead of marginalizing in the log space), but it's been a long time, so I'll have to double-check this weekend.

Did you test it with these changes?
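To make "marginalizing in the log space" concrete: q(z) is a mixture over data points, so log q(z) is the log of an average of densities, computed stably as a logsumexp minus the log of the count. Averaging the log-densities instead gives a different quantity that, by Jensen's inequality, is never larger. A toy sketch with made-up values for a single z:

```python
import math

import torch

# log q(z | x_j) for one fixed z and four made-up data points x_j
log_q_z_given_x = torch.tensor([-1.0, -2.0, -8.0, -0.5])
n = log_q_z_given_x.numel()

# Marginalizing: log((1/n) * sum_j q(z | x_j)), computed in log space
log_mixture = torch.logsumexp(log_q_z_given_x, dim=0) - math.log(n)

# Not a marginal: the plain average of the log-densities
mean_of_logs = log_q_z_given_x.mean()

print(float(log_mixture), float(mean_of_logs))  # about -1.28 and -2.875
```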

@sisodia-a
Author

Using some random matrices (code attached: temp.txt), I ran both your code and Ricky Chen's code to compare what is happening.

I found:

MWS:
log_qz != logqz_ricky
log_prod_qzi != logqz_prodmarginals_ricky

MSS:
logqz_prodmarginals_ricky_mss == log_prod_qzi_mss
logqz_ricky_mss != log_qz_mss

So with your code I get -ve tc_loss when is_mss=True, and both -ve mi_loss and -ve tc_loss when is_mss=False.
I ran it on the dSprites dataset with batch size 128.
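For reference, in the β-TCVAE decomposition (Chen et al., 2018) the three terms are batch averages of MI = log q(z|x) - log q(z), TC = log q(z) - log prod_j q(z_j), and dw-KL = log prod_j q(z_j) - log p(z). They telescope, so their sum is the batch average of log q(z|x) - log p(z) whatever estimator is used for the two middle quantities. The sketch below (hypothetical helper name, random stand-in values) shows one mechanism that produces negative MI and TC: if the MWS estimates omit the -log(batch_size * n_data) normalization, log q(z) is too large by log(B*N) and log prod_j q(z_j) by D*log(B*N), so MI shifts by -log(B*N), TC by -(D-1)*log(B*N) and dw-KL by +D*log(B*N), while the total is unchanged.

```python
import math

import torch


def decompose(log_q_zCx, log_qz, log_prod_qzi, log_pz):
    """Batch means of the (unweighted) MI, TC and dw-KL terms."""
    mi = (log_q_zCx - log_qz).mean()
    tc = (log_qz - log_prod_qzi).mean()
    dw_kl = (log_prod_qzi - log_pz).mean()
    return mi, tc, dw_kl


torch.manual_seed(0)
# batch size 128 as in this thread; dSprites has 737280 images; latent_dim is arbitrary
batch_size, latent_dim, n_data = 128, 10, 737280
# random stand-ins for the four per-sample log-density estimates
log_q_zCx, log_qz, log_prod_qzi, log_pz = torch.randn(4, batch_size, dtype=torch.float64)

mi, tc, dw_kl = decompose(log_q_zCx, log_qz, log_prod_qzi, log_pz)

# the same estimates without the -log(batch_size * n_data) normalization
log_norm = math.log(batch_size * n_data)
mi_u, tc_u, dw_kl_u = decompose(
    log_q_zCx, log_qz + log_norm, log_prod_qzi + latent_dim * log_norm, log_pz
)

# each pair matches up to floating-point rounding
print(float(mi_u - mi), -log_norm)
print(float(tc_u - tc), -(latent_dim - 1) * log_norm)
print(float(dw_kl_u - dw_kl), latent_dim * log_norm)
print(float(mi + tc + dw_kl), float(mi_u + tc_u + dw_kl_u))
```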

Then I changed the _get_log_pz_qz_prodzi_qzCx function in your code to match Ricky Chen's code:

import math

import torch

# log_density_gaussian, matrix_log_density_gaussian and log_importance_weight_matrix
# are this repo's helper functions


def _get_log_pz_qz_prodzi_qzCx(latent_sample, latent_dist, n_data, is_mss=False):
    batch_size, hidden_dim = latent_sample.shape

    # calculate log q(z|x)
    log_q_zCx = log_density_gaussian(latent_sample, *latent_dist).sum(dim=1)

    # calculate log p(z)
    # mean and log var is 0
    zeros = torch.zeros_like(latent_sample)
    log_pz = log_density_gaussian(latent_sample, zeros, zeros).sum(1)

    mat_log_qz = matrix_log_density_gaussian(latent_sample, *latent_dist)

    # default: minibatch weighted sampling (is_mss=False)
    log_qz = (torch.logsumexp(mat_log_qz.sum(2), dim=1, keepdim=False) - math.log(batch_size * n_data))  ## Ankit - modified
    log_prod_qzi = (torch.logsumexp(mat_log_qz, dim=1, keepdim=False) - math.log(batch_size * n_data)).sum(1)  ## Ankit - modified

    if is_mss:  ## Ankit - modified
        # use stratification  ## Ankit - modified
        log_iw_mat = log_importance_weight_matrix(batch_size, n_data).to(latent_sample.device)  ## Ankit - modified
        log_qz = torch.logsumexp(log_iw_mat + mat_log_qz.sum(2), dim=1, keepdim=False)  ## Ankit - modified
        log_prod_qzi = torch.logsumexp(log_iw_mat.view(batch_size, batch_size, 1) + mat_log_qz, dim=1, keepdim=False).sum(1)  ## Ankit - modified

    return log_pz, log_qz, log_prod_qzi, log_q_zCx

Then I get +ve losses for everything when is_mss=True but then I get -ve dw_kl_loss term.
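The modified function relies on the repo's log_density_gaussian and matrix_log_density_gaussian helpers. For anyone reproducing the comparison without the repo, here is a self-contained sketch of what such helpers compute, assuming latent_dist = (mean, log_var) as in the docstrings in this thread; the *_sketch functions are illustrative re-implementations, not the repo's code. Entry [i, j, d] of the matrix is log N(z_{i,d}; mu_{j,d}, exp(log_var_{j,d})), so the diagonal j = i recovers log q(z_i | x_i).

```python
import math

import torch


def log_density_gaussian_sketch(x, mu, logvar):
    """Elementwise log N(x; mu, exp(logvar))."""
    return -0.5 * (math.log(2 * math.pi) + logvar + (x - mu) ** 2 * torch.exp(-logvar))


def matrix_log_density_gaussian_sketch(z, mu, logvar):
    """out[i, j, d] = log N(z[i, d]; mu[j, d], exp(logvar[j, d])), shape (B, B, D)."""
    batch_size, dim = z.shape
    # broadcasting (B, 1, D) against (1, B, D) pairs every sample with every posterior
    return log_density_gaussian_sketch(
        z.view(batch_size, 1, dim), mu.view(1, batch_size, dim), logvar.view(1, batch_size, dim)
    )


torch.manual_seed(0)
mu, logvar = torch.randn(5, 3), torch.randn(5, 3)
z = mu + torch.randn(5, 3) * torch.exp(0.5 * logvar)  # reparameterised sample

mat = matrix_log_density_gaussian_sketch(z, mu, logvar)
diag = mat[torch.arange(5), torch.arange(5)]  # shape (5, 3): log q(z_i,d | x_i)
print(mat.shape)  # torch.Size([5, 5, 3])
print(torch.allclose(diag, log_density_gaussian_sketch(z, mu, logvar)))  # True
```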

@YannDubs
Owner

YannDubs commented Oct 7, 2020

Awesome, thanks for checking. A few comments:

1/ What do you mean by "+ve" and "-ve"? What is ve?

2/ Looking back at it, it seems that I actually had the correct code and then introduced the problem in a late-night push (#43).

Here's what I had before my changes:

def _get_log_pz_qz_prodzi_qzCx(latent_sample, latent_dist,n_data, is_mss=False):
    batch_size, hidden_dim = latent_sample.shape

    # calculate log q(z|x)
    log_q_zCx = log_density_gaussian(latent_sample, *latent_dist).sum(dim=1)

    # calculate log p(z)
    # mean and log var is 0
    zeros = torch.zeros_like(latent_sample)
    log_pz = log_density_gaussian(latent_sample, zeros, zeros).sum(1)

    if not is_mss:
        log_qz, log_prod_qzi = _minibatch_weighted_sampling(latent_dist,
                                                            latent_sample,
                                                            n_data)

    else:
        log_qz, log_prod_qzi = _minibatch_stratified_sampling(latent_dist,
                                                              latent_sample,
                                                              n_data)

    return log_pz, log_qz, log_prod_qzi, log_q_zCx


def _minibatch_weighted_sampling(latent_dist, latent_sample, data_size):
    """
    Estimates log q(z) and the log (product of marginals of q(z_j)) with minibatch
    weighted sampling.

    Parameters
    ----------
    latent_dist : tuple of torch.tensor
        sufficient statistics of the latent dimension. E.g. for gaussian
        (mean, log_var) each of shape : (batch_size, latent_dim).

    latent_sample: torch.Tensor
        sample from the latent dimension using the reparameterisation trick
        shape : (batch_size, latent_dim).

    data_size : int
        Number of data in the training set

    References 
    -----------
       [1] Chen, Tian Qi, et al. "Isolating sources of disentanglement in variational
       autoencoders." Advances in Neural Information Processing Systems. 2018.
    """
    batch_size = latent_sample.size(0)

    mat_log_qz = matrix_log_density_gaussian(latent_sample, *latent_dist)

    log_prod_qzi = (torch.logsumexp(mat_log_qz, dim=1, keepdim=False) -
                    math.log(batch_size * data_size)).sum(dim=1)
    log_qz = (torch.logsumexp(mat_log_qz.sum(2), dim=1, keepdim=False) -
              math.log(batch_size * data_size))

    return log_qz, log_prod_qzi


def _minibatch_stratified_sampling(latent_dist, latent_sample, data_size):
    """
    Estimates log q(z) and the log (product of marginals of q(z_j)) with minibatch
    stratified sampling.
    
    Parameters
    -----------
    latent_dist : tuple of torch.tensor
        sufficient statistics of the latent dimension. E.g. for gaussian
        (mean, log_var) each of shape : (batch_size, latent_dim).

    latent_sample: torch.Tensor
        sample from the latent dimension using the reparameterisation trick
        shape : (batch_size, latent_dim).

    data_size : int
        Number of data in the training set

    References 
    -----------
       [1] Chen, Tian Qi, et al. "Isolating sources of disentanglement in variational
       autoencoders." Advances in Neural Information Processing Systems. 2018.
    """
    batch_size = latent_sample.size(0)

    mat_log_qz = matrix_log_density_gaussian(latent_sample, *latent_dist)

    log_iw_mat = log_importance_weight_matrix(batch_size, data_size).to(latent_sample.device)
    log_qz = torch.logsumexp(log_iw_mat + mat_log_qz.sum(2), dim=1, keepdim=False)
    log_prod_qzi = torch.logsumexp(log_iw_mat.view(batch_size, batch_size, 1) +
                                   mat_log_qz, dim=1, keepdim=False).sum(1)

    return log_qz, log_prod_qzi

@YannDubs
Owner

YannDubs commented Oct 7, 2020

which is (I believe) exactly what you tested.

  • Does it also work for is_mss=False?

  • Just to be sure I understand: are you saying that with MSS this makes dw_kl_loss negative?

  • Did you see any impact on the qualitative samples when training a model this way?

@sisodia-a
Author

Yes, this makes the code exactly the same. With these changes I get a negative dw_kl_loss term with _minibatch_weighted_sampling, while with _minibatch_stratified_sampling all loss terms are positive. I tested on dSprites.

@YannDubs
Owner

YannDubs commented Oct 8, 2020

And qualitatively, do you see any differences?

@sisodia-a
Author

I haven't tested that yet. I was just trying to work out from the math/code where the error comes from.

@DianeBouchacourt

DianeBouchacourt commented Feb 25, 2021

Has this issue been solved? Training on dSprites, I also get a negative TC loss.

@shi-yu-wang

I also got a negative loss with the dSprites data.

@shi-yu-wang

tc loss

@sivannavis

sivannavis commented Oct 28, 2024

Hi, is this fixed? It seems that the current version of the code here

if is_mss:
    # use stratification
    log_iw_mat = log_importance_weight_matrix(batch_size, n_data).to(latent_sample.device)
    mat_log_qz = mat_log_qz + log_iw_mat.view(batch_size, batch_size, 1)

log_qz = torch.logsumexp(mat_log_qz.sum(2), dim=1, keepdim=False)
log_prod_qzi = torch.logsumexp(mat_log_qz, dim=1, keepdim=False).sum(1)

does not implement MWS when is_mss=False, unlike this one:

https://github.com/clementchadebec/benchmark_VAE/blob/6419e21558f2a6abc2da99944bddda846ded30f4/src/pythae/models/beta_tc_vae/beta_tc_vae_model.py#L137-L146
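A plausible reading of the MSS mismatch reported earlier in this thread (log_qz differed while log_prod_qzi matched), assuming the code compared there had the same shape as the snippet above: adding the (batch_size, batch_size) log-weight matrix to every latent dimension before .sum(2) weights the joint density by D * log w instead of log w, whereas the product-of-marginals term marginalizes each dimension on its own, so a per-dimension weight is what it needs. The sketch below uses an arbitrary row-normalized weight matrix as a stand-in; it is not the stratified importance-weight matrix.

```python
import torch

torch.manual_seed(0)
batch_size, latent_dim = 4, 3
mat_log_qz = torch.randn(batch_size, batch_size, latent_dim)
# arbitrary row-normalized log-weights, standing in for the importance-weight matrix
log_w = torch.log_softmax(torch.randn(batch_size, batch_size), dim=1)

# pattern from the quoted code: add log_w to every latent dimension, then sum over d
shifted = mat_log_qz + log_w.view(batch_size, batch_size, 1)
log_qz_quoted = torch.logsumexp(shifted.sum(2), dim=1)

# weighting the joint density once per (i, j) pair, as in Chen's MSS line for logqz
log_qz_once = torch.logsumexp(log_w + mat_log_qz.sum(2), dim=1)
# what the quoted pattern actually computes: latent_dim copies of log_w
log_qz_d_times = torch.logsumexp(latent_dim * log_w + mat_log_qz.sum(2), dim=1)

print(torch.allclose(log_qz_quoted, log_qz_d_times))  # True
print(torch.allclose(log_qz_quoted, log_qz_once))     # False
```

Since every log-weight is negative, the D-fold weighting always pulls log q(z) down, which would also bias TC = log q(z) - log prod_j q(z_j) downward.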

I noticed that even when the loss is changed to be exactly like #60 (comment), the MI / TC / dw-KL terms behave very differently under MWS and MSS in my experiments; in particular, dw-KL under MWS is negative.

Example loss values:
MSS: TC 0.4, MI 2, dw-KL 0.87
MWS: TC 132, MI 21, dw-KL -150

Is it normal for these two methods to give such different values for these terms, and a negative estimate for one of the KL terms?
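A quick consistency check on the example values above, assuming they are unweighted batch averages of the three terms: MI + TC + dw-KL telescopes to the batch average of log q(z|x) - log p(z), which does not involve the aggregate-posterior estimator at all, so the totals should be similar for MWS and MSS. They are (both about 3), which suggests the two estimators differ mainly in how they split the KL among the three terms rather than in the KL itself.

```python
# Example values (MI, TC, dw-KL) as reported in the comment above
mss = {"mi": 2.0, "tc": 0.4, "dw_kl": 0.87}
mws = {"mi": 21.0, "tc": 132.0, "dw_kl": -150.0}

mss_total = sum(mss.values())
mws_total = sum(mws.values())
print(round(mss_total, 2), mws_total)  # 3.27 3.0
```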
