Why is tc_loss in bTCVAE negative? #60
https://github.com/rtqichen/beta-tcvae/ calculates log_qz and log_prod_qzi in one way for minibatch weighted sampling and in another way for minibatch stratified sampling (# minibatch stratified sampling). In this codebase, shouldn't we also do the corresponding computation in the case of NOT is_mss, and the other one in the case of is_mss?
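For context, here is a sketch of the quantities being compared, written out from Chen et al. (2018) rather than taken from either codebase, so treat the exact notation as an assumption:

```latex
% KL decomposition that beta-TC-VAE penalises term by term (Chen et al., 2018):
\begin{align}
\mathbb{E}_{p(n)}\!\left[\operatorname{KL}\!\big(q(z \mid n)\,\|\,p(z)\big)\right]
  &= \underbrace{I_q(z; n)}_{\text{index-code MI}}
   + \underbrace{\operatorname{KL}\!\Big(q(z)\,\Big\|\,\textstyle\prod_j q(z_j)\Big)}_{\text{total correlation (TC)}}
   + \underbrace{\textstyle\sum_j \operatorname{KL}\!\big(q(z_j)\,\|\,p(z_j)\big)}_{\text{dimension-wise KL}}
\end{align}
% Minibatch weighted sampling (MWS) estimates, for a sample z_i, a minibatch of
% size M and a dataset of size N (MSS replaces the uniform 1/(NM) weights with
% the importance weights from log_importance_weight_matrix):
\begin{align}
\log q(z_i) &\approx \log \sum_{j=1}^{M} q(z_i \mid x_j) - \log(NM), \\
\log \prod_{d} q(z_{i,d}) &\approx \sum_{d} \Big( \log \sum_{j=1}^{M} q(z_{i,d} \mid x_j) - \log(NM) \Big).
\end{align}
```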
Thanks @UserName-AnkitSisodia! Did you test it with these changes?
Using some random matrices (code attached), I ran your code as well as Ricky Chen's code to compare what is happening for MWS and MSS. When I use your code with is_mss=True, I get a -ve tc_loss, and with is_mss=False, I get a -ve mi_loss and a -ve tc_loss. Then I changed the _get_log_pz_qz_prodzi_qzCx function in your code to make it similar to Ricky Chen's code. After that I get +ve losses for everything when is_mss=True, but then I get a -ve dw_kl_loss term.
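For anyone who wants to reproduce this kind of check, here is a minimal, self-contained sketch of such a random-matrix comparison. The two helper functions are illustrative re-implementations (they mirror the helpers named in the code quoted later in this thread, not the repo's own disvae versions), and the batch size, latent dimension, and dataset size are arbitrary assumptions.

```python
import math
import torch

def matrix_log_density_gaussian(x, mu, logvar):
    """log N(x_i[d] | mu_j[d], var_j[d]) for every pair (i, j); shape (B, B, D)."""
    x = x.unsqueeze(1)            # (B, 1, D)
    mu = mu.unsqueeze(0)          # (1, B, D)
    logvar = logvar.unsqueeze(0)  # (1, B, D)
    return -0.5 * (math.log(2 * math.pi) + logvar + (x - mu) ** 2 * torch.exp(-logvar))

def log_importance_weight_matrix(batch_size, dataset_size):
    """Importance weights for minibatch stratified sampling, following the pattern in
    rtqichen/beta-tcvae (an approximation of that helper, not a verbatim copy)."""
    N, M = dataset_size, batch_size - 1
    strat_weight = (N - M) / (N * M)
    W = torch.full((batch_size, batch_size), 1.0 / M)
    W.view(-1)[::M + 1] = 1.0 / N
    W.view(-1)[1::M + 1] = strat_weight
    W[M - 1, 0] = strat_weight
    return W.log()

torch.manual_seed(0)
B, D, N = 64, 10, 737280                            # batch size, latent dim, dSprites-sized dataset
mu, logvar = torch.randn(B, D), torch.randn(B, D)
z = mu + torch.randn(B, D) * (0.5 * logvar).exp()   # reparameterisation trick

mat_log_qz = matrix_log_density_gaussian(z, mu, logvar)                                        # (B, B, D)
log_q_zCx = (-0.5 * (math.log(2 * math.pi) + logvar + (z - mu) ** 2 * torch.exp(-logvar))).sum(1)  # log q(z_i | x_i)
log_pz = (-0.5 * (math.log(2 * math.pi) + z ** 2)).sum(1)                                      # standard normal prior

# Minibatch weighted sampling (MWS)
log_qz_mws = torch.logsumexp(mat_log_qz.sum(2), dim=1) - math.log(B * N)
log_prod_qzi_mws = (torch.logsumexp(mat_log_qz, dim=1) - math.log(B * N)).sum(1)

# Minibatch stratified sampling (MSS)
log_iw = log_importance_weight_matrix(B, N)
log_qz_mss = torch.logsumexp(log_iw + mat_log_qz.sum(2), dim=1)
log_prod_qzi_mss = torch.logsumexp(log_iw.unsqueeze(2) + mat_log_qz, dim=1).sum(1)

print("TC estimate (MWS):", (log_qz_mws - log_prod_qzi_mws).mean().item())
print("TC estimate (MSS):", (log_qz_mss - log_prod_qzi_mss).mean().item())
```

With random statistics instead of encoder outputs the absolute numbers are meaningless; the point is only to see how the two estimators diverge on identical inputs.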
Awesome, thanks for checking. A few comments: 1/ What do you mean by "+ve" and "-ve"? What is "ve"? 2/ Looking back at it, it seems that I actually had the correct code and then introduced the problem in a late-night push ( #43 ). Here's what I had before my changes:

def _get_log_pz_qz_prodzi_qzCx(latent_sample, latent_dist, n_data, is_mss=False):
    batch_size, hidden_dim = latent_sample.shape

    # calculate log q(z|x)
    log_q_zCx = log_density_gaussian(latent_sample, *latent_dist).sum(dim=1)

    # calculate log p(z)
    # mean and log var are 0
    zeros = torch.zeros_like(latent_sample)
    log_pz = log_density_gaussian(latent_sample, zeros, zeros).sum(1)

    if not is_mss:
        log_qz, log_prod_qzi = _minibatch_weighted_sampling(latent_dist,
                                                            latent_sample,
                                                            n_data)
    else:
        log_qz, log_prod_qzi = _minibatch_stratified_sampling(latent_dist,
                                                              latent_sample,
                                                              n_data)

    return log_pz, log_qz, log_prod_qzi, log_q_zCx
def _minibatch_weighted_sampling(latent_dist, latent_sample, data_size):
    """
    Estimates log q(z) and the log (product of marginals of q(z_j)) with minibatch
    weighted sampling.

    Parameters
    ----------
    latent_dist : tuple of torch.tensor
        sufficient statistics of the latent dimension. E.g. for gaussian
        (mean, log_var) each of shape : (batch_size, latent_dim).

    latent_sample : torch.Tensor
        sample from the latent dimension using the reparameterisation trick
        shape : (batch_size, latent_dim).

    data_size : int
        Number of data in the training set

    References
    ----------
    [1] Chen, Tian Qi, et al. "Isolating sources of disentanglement in variational
        autoencoders." Advances in Neural Information Processing Systems. 2018.
    """
    batch_size = latent_sample.size(0)

    mat_log_qz = matrix_log_density_gaussian(latent_sample, *latent_dist)

    log_prod_qzi = (torch.logsumexp(mat_log_qz, dim=1, keepdim=False) -
                    math.log(batch_size * data_size)).sum(dim=1)
    log_qz = torch.logsumexp(mat_log_qz.sum(2), dim=1, keepdim=False
                             ) - math.log(batch_size * data_size)

    return log_qz, log_prod_qzi
def _minibatch_stratified_sampling(latent_dist, latent_sample, data_size):
    """
    Estimates log q(z) and the log (product of marginals of q(z_j)) with minibatch
    stratified sampling.

    Parameters
    ----------
    latent_dist : tuple of torch.tensor
        sufficient statistics of the latent dimension. E.g. for gaussian
        (mean, log_var) each of shape : (batch_size, latent_dim).

    latent_sample : torch.Tensor
        sample from the latent dimension using the reparameterisation trick
        shape : (batch_size, latent_dim).

    data_size : int
        Number of data in the training set

    References
    ----------
    [1] Chen, Tian Qi, et al. "Isolating sources of disentanglement in variational
        autoencoders." Advances in Neural Information Processing Systems. 2018.
    """
    batch_size = latent_sample.size(0)

    mat_log_qz = matrix_log_density_gaussian(latent_sample, *latent_dist)

    log_iw_mat = log_importance_weight_matrix(batch_size, data_size).to(latent_sample.device)
    log_qz = torch.logsumexp(log_iw_mat + mat_log_qz.sum(2), dim=1, keepdim=False)
    log_prod_qzi = torch.logsumexp(log_iw_mat.view(batch_size, batch_size, 1) +
                                   mat_log_qz, dim=1, keepdim=False).sum(1)

    return log_qz, log_prod_qzi
which is (I believe) exactly what you tested.
Yes, this makes the code exactly the same. Once these changes are made, I get a negative dw_kl_loss term in the case of _minibatch_weighted_sampling. For _minibatch_stratified_sampling, I am getting all loss terms as positive. I tested on dSprites.
And qualitatively, do you see any differences?
I didn't test that yet. I was just trying to see from the math/code where I am getting the error.
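For readers following along, the loss terms discussed above are typically assembled from the four log-density estimates as in the sketch below. The reduction matches the term names used in this thread, but the exact implementation in losses.py should be checked rather than taken from this sketch.

```python
def btcvae_terms(log_pz, log_qz, log_prod_qzi, log_q_zCx):
    """Assemble the three penalties from per-sample log densities (torch tensors of shape (batch,))."""
    mi_loss = (log_q_zCx - log_qz).mean()        # index-code mutual information I[z; x]
    tc_loss = (log_qz - log_prod_qzi).mean()     # total correlation KL(q(z) || prod_j q(z_j))
    dw_kl_loss = (log_prod_qzi - log_pz).mean()  # dimension-wise KL: sum_j KL(q(z_j) || p(z_j))
    return mi_loss, tc_loss, dw_kl_loss

# e.g. with the MWS estimates from the random-matrix snippet earlier in the thread:
# mi, tc, dw_kl = btcvae_terms(log_pz, log_qz_mws, log_prod_qzi_mws, log_q_zCx)
```

Each of these terms is a mutual information or KL divergence, so the true quantities are non-negative; a negative value therefore points at bias in the minibatch estimator (or a mismatch between estimators) rather than at the quantity itself being negative.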
Has this issue been solved? Training on dSprites, I also get a negative tc loss.
I also got a negative tc loss with the dSprites data.
Hi, is this fixed? It seems the current version of the code here (disentangling-vae/disvae/models/losses.py, lines 536 to 542 at f045219)
is not implementing MWS when is_mss=False, like this one does.
I noticed that even when we change the loss to be exactly like #60 (comment), the behaviors of MI / TC / dw-KL are very different when using MWS versus MSS in my experiments; in particular, dw-KL with MWS is negative. Example loss values are in the log linked below. Is it normal that these two methods have very different values for these terms, and a negative estimate for some of the KL terms?
disentangling-vae/results/btcvae_dsprites/train_losses.log (line 5 at 535bbd2)