Scanvae fix #1

Open
wants to merge 15 commits into base: master
3 changes: 3 additions & 0 deletions scvi/__init__.py
@@ -11,6 +11,9 @@

# https://github.com/python-poetry/poetry/pull/2366#issuecomment-652418094
# https://github.com/python-poetry/poetry/issues/144#issuecomment-623927302

print('Modified version, 12/06')

try:
import importlib.metadata as importlib_metadata
except ModuleNotFoundError:
143 changes: 82 additions & 61 deletions scvi/module/_scanvae.py
@@ -2,7 +2,7 @@

import numpy as np
import torch
from torch.distributions import Categorical, Normal #ok
from torch.distributions import Categorical, Normal # ok
from torch.distributions import kl_divergence as kl
from torch.nn import functional as F

@@ -11,12 +11,12 @@
from scvi.module.base import LossRecorder, auto_move_data
from scvi.nn import Decoder, Encoder

from ._classifier import Classifier #Basic fully-connected NN classifier.
from ._classifier import Classifier # Basic fully-connected NN classifier.
from ._utils import broadcast_labels
from ._vae import VAE


class SCANVAE(VAE): #inherits from VAE class (for instance inherits z_encoder)
class SCANVAE(VAE): # inherits from VAE class (for instance inherits z_encoder)
"""
Single-cell annotation using variational inference.

@@ -39,7 +39,7 @@ class SCANVAE(VAE):
n_layers
Number of hidden layers used for encoder and decoder NNs
n_continuous_cov
Number of continuous covariates
n_cats_per_cov
Number of categories for each extra categorical covariate
dropout_rate
@@ -61,7 +61,7 @@ class SCANVAE(VAE):
y_prior
If None, initialized to uniform probability over cell types
labels_groups
Label group designations ?? --> hierarchy between labels
Label group designations
use_labels_groups
Whether to use the label groups
use_batch_norm
Expand All @@ -72,9 +72,6 @@ class SCANVAE(VAE): #inherits from VAE class (for instance inherits z_encoder)
Keyword args for :class:`~scvi.module.VAE`
"""


#--------------------------------INIT-----------------------------------------------------------------------------------------------------------

def __init__(
self,
n_input: int,
@@ -83,20 +80,19 @@ def __init__(
n_hidden: int = 128,
n_latent: int = 10,
n_layers: int = 1,
n_continuous_cov: int = 0, #in the following, we assume only one categorical covariate with categories, which represents the common case of having multiple batches of data.

n_continuous_cov: int = 0, # in the following, we assume only one categorical covariate with categories, which represents the common case of having multiple batches of data.
n_cats_per_cov: Optional[Iterable[int]] = None,
dropout_rate: float = 0.1,
dispersion: str = "gene",
log_variational: bool = True,
gene_likelihood: str = "zinb",
y_prior=None,
labels_groups: Sequence[int] = None, #??
labels_groups: Sequence[int] = None,
use_labels_groups: bool = False,
classifier_parameters: dict = dict(),
use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "both",
use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "none",
n_version = 0, # 0 denotes the old one without the fix
n_version=0, # 0 denotes the old one without the fix
**vae_kwargs
):
super().__init__(
@@ -124,23 +120,22 @@ def __init__(
self.n_version = n_version
self.n_labels = n_labels


# Classifier takes n_latent as input
cls_parameters = {
"n_layers": n_layers,
"n_hidden": n_hidden,
"dropout_rate": dropout_rate,
}
cls_parameters.update(classifier_parameters)
self.classifier = Classifier( #PROBABILISTIC CELL-TYPE ANNOTATION? n_hidden kept as default, classifies between n_labels
n_latent, #Number of input dimensions
self.classifier = Classifier( # PROBABILISTIC CELL-TYPE ANNOTATION? n_hidden kept as default, classifies between n_labels
n_latent, # Number of input dimensions
n_labels=n_labels,
use_batch_norm=use_batch_norm_encoder,
use_layer_norm=use_layer_norm_encoder,
**cls_parameters
)
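Per the import comment above, Classifier is a basic fully-connected classifier over the latent space. Conceptually it is something like the following (a minimal sketch under that description, not the actual scvi.module.Classifier implementation):

    import torch.nn as nn

    def make_latent_classifier(n_in, n_labels, n_hidden=128, dropout_rate=0.1):
        # latent mean -> one hidden layer -> per-label probabilities
        return nn.Sequential(
            nn.Linear(n_in, n_hidden),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(n_hidden, n_labels),
            nn.Softmax(dim=-1),
        )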

self.encoder_z2_z1 = Encoder( #q(z2|z1,....) ???
self.encoder_z2_z1 = Encoder( # q(z2|z1,....)
n_latent,
n_latent,
n_cat_list=[self.n_labels],
@@ -150,7 +145,7 @@ def __init__(
use_batch_norm=use_batch_norm_encoder,
use_layer_norm=use_layer_norm_encoder,
)
self.decoder_z1_z2 = Decoder( # p(z1|z2,....) ????
self.decoder_z1_z2 = Decoder( # p(z1|z2,....)
n_latent,
n_latent,
n_cat_list=[self.n_labels],
@@ -160,7 +155,7 @@ def __init__(
use_layer_norm=use_layer_norm_decoder,
)

self.y_prior = torch.nn.Parameter( #uniform probabilities for categorical distribution on the cell type HERE Y=C
self.y_prior = torch.nn.Parameter( # uniform probabilities for categorical distribution on the cell type HERE Y=C
y_prior
if y_prior is not None
else (1 / n_labels) * torch.ones(1, n_labels),
@@ -170,7 +165,7 @@ def __init__(
self.labels_groups = (
np.array(labels_groups) if labels_groups is not None else None
)
if self.use_labels_groups:
if labels_groups is None:
raise ValueError("Specify label groups")
unique_groups = np.unique(self.labels_groups)
@@ -194,12 +189,9 @@ def __init__(
]
)


#---------------------------------------METHODS----------------------------------------------------------------------------------------------------------------------------

@auto_move_data
def classify(self, x, batch_index=None, cont_covs=None, cat_covs=None):
if self.log_variational: #for numerical stability
if self.log_variational: # for numerical stability
x = torch.log(1 + x)
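Side note: torch.log1p computes the same quantity with better numerical behavior near zero, so a drop-in alternative (assuming non-negative counts) would be:

    x = torch.log1p(x)  # == log(1 + x), numerically stabler for small x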

if cont_covs is not None and self.encode_covariates:
@@ -211,9 +203,11 @@ def classify(self, x, batch_index=None, cont_covs=None, cat_covs=None):
else:
categorical_input = tuple()

qz_m, _, z = self.z_encoder(encoder_input, batch_index, *categorical_input) #q(z1|x) without the var qz_v
qz_m, _, z = self.z_encoder(
encoder_input, batch_index, *categorical_input
) # q(z1|x) without the var qz_v
# We classify using the inferred mean parameter of z_1 in the latent space
z = qz_m
if self.use_labels_groups:
w_g = self.classifier_groups(z)
unw_y = self.classifier(z)
@@ -229,7 +223,9 @@ def classify(self, x, batch_index=None, cont_covs=None, cat_covs=None):
return w_y
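The group-weighting step is collapsed in this view. For context, a sketch of what it typically computes, following the upstream scvi-tools pattern (groups_index is assumed to hold one index mask per label group): renormalize the fine-grained probabilities within each group, then scale by that group's probability:

    w_y = torch.zeros_like(unw_y)
    for i, group_index in enumerate(self.groups_index):
        unw_y_g = unw_y[:, group_index]
        # renormalize within the group, then weight by the group probability
        w_y[:, group_index] = unw_y_g / (unw_y_g.sum(dim=-1, keepdim=True) + 1e-8)
        w_y[:, group_index] *= w_g[:, [i]]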

@auto_move_data
def classification_loss(self, labelled_dataset): #add a classification loss ON THE LABELLED DATA, following Kingma et al.
def classification_loss(
self, labelled_dataset
): # add a classification loss ON THE LABELLED DATA, following Kingma et al.
x = labelled_dataset[REGISTRY_KEYS.X_KEY]
y = labelled_dataset[REGISTRY_KEYS.LABELS_KEY]
batch_idx = labelled_dataset[REGISTRY_KEYS.BATCH_KEY]
@@ -246,7 +242,7 @@ def classification_loss(self, labelled_dataset):
self.classify(
x, batch_index=batch_idx, cat_covs=cat_covs, cont_covs=cont_covs
),
y.view(-1).long(),
)
return classification_loss
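As the comment notes, this follows the semi-supervised objective of Kingma et al. (2014): a cross-entropy term on labelled data is added to the ELBO, scaled by a factor (classification_ratio here, alpha in the paper). Schematically (illustrative names, not the module's API):

    def semi_supervised_loss(elbo_unlabelled, elbo_labelled, ce_labelled, alpha):
        # both ELBO terms are minimized; alpha trades off the supervised signal
        return elbo_unlabelled + elbo_labelled + alpha * ce_labelled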

@@ -255,9 +251,9 @@ def loss(
tensors,
inference_outputs,
generative_ouputs,
feed_labels=False, #? ---> 2 dataloaders, for annotated and un annotated, don't feed labels for un annotated
feed_labels=False, # ? ---> 2 dataloaders, for annotated and unannotated data; don't feed labels for the unannotated one
kl_weight=1,
labelled_tensors=None, #?? -->scvanvi.py
labelled_tensors=None,
classification_ratio=None,
):
px_r = generative_ouputs["px_r"]
@@ -274,32 +270,34 @@ def loss(
else:
y = None


is_labelled = False if y is None else True #important for ELBO
is_labelled = False if y is None else True

# Enumerate choices of label
ys, z1s = broadcast_labels(y, z1, n_broadcast=self.n_labels) #one-hot encoding of the labels
#if z1 is of size (batch_size,latent), z1_s is of size (n_labels*batch_size,latent)
qz2_m, qz2_v, z2 = self.encoder_z2_z1(z1s, ys) #q(z2|z1,..)
pz1_m, pz1_v = self.decoder_z1_z2(z2, ys) #p(z1|z2,..)
ys, z1s = broadcast_labels(
y, z1, n_broadcast=self.n_labels
) # one-hot encoding of the labels
# if z1 is of size (batch_size,latent), z1_s is of size (n_labels*batch_size,latent)
qz2_m, qz2_v, z2 = self.encoder_z2_z1(z1s, ys) # q(z2|z1,..)
pz1_m, pz1_v = self.decoder_z1_z2(z2, ys) # p(z1|z2,..)
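As the comment above describes, broadcast_labels enumerates every possible label for every cell when y is unobserved: z1 of shape (batch, latent) is tiled to (n_labels * batch, latent) and paired with the corresponding per-label ids. A standalone illustration of the intended shapes (hypothetical helper, not the scvi one):

    import torch

    def enumerate_labels(z1, n_labels):
        # tile the latent codes once per label: (batch, d) -> (n_labels * batch, d)
        z1s = z1.repeat(n_labels, 1)
        # matching label ids, one contiguous block per label
        ys = torch.arange(n_labels).repeat_interleave(z1.shape[0]).unsqueeze(-1)
        return ys, z1s

    ys, z1s = enumerate_labels(torch.randn(4, 10), n_labels=3)
    assert z1s.shape == (12, 10) and ys.shape == (12, 1)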

reconst_loss = self.get_reconstruction_loss(x, px_rate, px_r, px_dropout) #expectation of log(p),as in scvi
reconst_loss = self.get_reconstruction_loss(
x, px_rate, px_r, px_dropout
) # expectation of log(p),as in scvi

# KL Divergence
mean = torch.zeros_like(qz2_m)
scale = torch.ones_like(qz2_v)

kl_divergence_z2 = kl(
Normal(qz2_m, torch.sqrt(qz2_v)), Normal(mean, scale) #q(z2|z1,..)||p(z2)
Normal(qz2_m, torch.sqrt(qz2_v)), Normal(mean, scale) # q(z2|z1,..)||p(z2)
).sum(dim=1)
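Here kl(Normal, Normal) evaluates the analytic KL between the diagonal Gaussian posterior q(z2|z1, c) and the standard normal prior; summed over the latent dimensions it reduces to the familiar closed form. A quick self-check sketch:

    import torch
    from torch.distributions import Normal
    from torch.distributions import kl_divergence as kl

    qm, qv = torch.randn(5, 10), torch.rand(5, 10) + 0.1  # toy posterior mean/variance
    analytic = kl(Normal(qm, qv.sqrt()), Normal(torch.zeros_like(qm), torch.ones_like(qv))).sum(-1)
    by_hand = 0.5 * (qv + qm**2 - 1 - qv.log()).sum(-1)  # 0.5 * sum(v + m^2 - 1 - log v)
    assert torch.allclose(analytic, by_hand, atol=1e-5)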

loss_z1_unweight = -Normal(pz1_m, torch.sqrt(pz1_v)).log_prob(z1s).sum(dim=-1)
#Sum of the log of the Normal probability density evaluated at value z1s. The sum is over the 10-dim latent space.
# Sum of the log of the Normal probability density evaluated at value z1s. The sum is over the latent space.

loss_z1_weight = Normal(qz1_m, torch.sqrt(qz1_v)).log_prob(z1).sum(dim=-1)

loss_z1_weight = Normal(qz1_m, torch.sqrt(qz1_v)).log_prob(z1).sum(dim=-1) #????ok

if not self.use_observed_lib_size: # as in the user guide, if l_n is latent!
if not self.use_observed_lib_size:
ql_m = inference_outputs["ql_m"]
ql_v = inference_outputs["ql_v"]
(
Expand All @@ -309,65 +307,88 @@ def loss(

kl_divergence_l = kl(
Normal(ql_m, torch.sqrt(ql_v)),
Normal(local_library_log_means, torch.sqrt(local_library_log_vars)), #ok
Normal(local_library_log_means, torch.sqrt(local_library_log_vars)),
).sum(dim=1)
else:
kl_divergence_l = torch.tensor(0.0) # indeed, if you observe l it is no longer in the variational distribution
kl_divergence_l = torch.tensor(0.0)

print('The version is: ', self.n_version)

#if is_labelled:
if labelled_tensors is not None:
print("--------------------labelled_tensors is not None-------------------------")
if self.n_version == 1:
print("Adding KLs to the loss...")
loss = reconst_loss.mean()+loss_z1_weight.mean()+loss_z1_unweight.mean()+ kl_weight*(kl_divergence_z2.mean()+kl_divergence_l.mean()) # add kl terms here
# else:
# print("The loss is unchanged...")
# loss = reconst_loss.mean() + loss_z1_weight.mean() + loss_z1_unweight.mean()

loss_z1_weight_mean = loss_z1_weight.mean()
loss_z1_unweight_mean = loss_z1_unweight.mean()
kl_divergence_z2_mean = kl_divergence_z2.mean()
kl_divergence_l_mean = kl_divergence_l.mean()
loss = reconst_loss.mean()+loss_z1_weight_mean+loss_z1_unweight_mean+ kl_weight*(kl_divergence_z2_mean+kl_divergence_l_mean) # add kl terms here

kl_locals = {
"kl_divergence_z2": kl_divergence_z2, #in scvi, this is added to the loss?
"kl_divergence_z2": kl_divergence_z2,
"kl_divergence_l": kl_divergence_l,
}
#if labelled_tensors is not None:
#print("And labelled_tensors is not None")
classifier_loss = self.classification_loss(labelled_tensors)
loss += classifier_loss * classification_ratio
return LossRecorder(
loss,
reconst_loss,
kl_locals,
loss_z1_weight_mean = loss_z1_weight_mean,
loss_z1_unweight_mean = loss_z1_unweight_mean,
kl_divergence_z2_mean = kl_divergence_z2_mean,
kl_divergence_l_mean = kl_divergence_l_mean,
classification_loss=classifier_loss,
n_labelled_tensors=labelled_tensors[REGISTRY_KEYS.X_KEY].shape[0],
)
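To summarize, as far as the diff shows, the fix described by n_version (0 denotes the old behavior) restores the annealed KL terms to the labelled-data loss, which the original objective (see the commented-out branch above) omitted. Schematically (illustrative helper, not the module's API):

    # schematic comparison of the two labelled-data losses (names illustrative)
    def labelled_loss(reconst, lz1_w, lz1_unw, kl_z2, kl_l, kl_weight, n_version):
        loss = reconst.mean() + lz1_w.mean() + lz1_unw.mean()
        if n_version == 1:  # the fix: restore the annealed KL terms
            loss = loss + kl_weight * (kl_z2.mean() + kl_l.mean())
        return loss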

# the ELBO in the case where C=Y is not observed
probs = self.classifier(z1) #outputs a vector of size n_labels summing to 1
probs = self.classifier(z1) # outputs a vector of size n_labels summing to 1
reconst_loss += loss_z1_weight + (
(loss_z1_unweight).view(self.n_labels, -1).t() * probs
).sum(dim=1) #why loss_z1_weight is not in the sum?
).sum(dim=1)

kl_divergence = (kl_divergence_z2.view(self.n_labels, -1).t() * probs).sum(
dim=1
)
kl_divergence += kl(
kl_divergence_cat = kl(
Categorical(probs=probs),
Categorical(probs=self.y_prior.repeat(probs.size(0), 1)),
)

kl_divergence += kl_divergence_cat

kl_divergence += kl_divergence_l

loss = torch.mean(reconst_loss + kl_divergence * kl_weight) #annealing to avoid posterior collapse!!!
loss = torch.mean(reconst_loss + kl_divergence * kl_weight)
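Since every per-label quantity was computed on the broadcast (n_labels * batch) batch, marginalizing over the unobserved label amounts to reshaping to (n_labels, batch) and averaging under the classifier probabilities, as done above for both the reconstruction term and the z2 KL. A shape-level sketch (illustrative names):

    # loss_per_label: (n_labels * batch,) stacked label-major by broadcast_labels
    # probs:          (batch, n_labels) from the classifier
    expected = (loss_per_label.view(n_labels, -1).t() * probs).sum(dim=1)  # (batch,)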

loss_z1_weight_mean = loss_z1_weight.mean()
loss_z1_unweight_mean = loss_z1_unweight.mean()
kl_divergence_z2_mean = kl_divergence_z2.mean()
kl_divergence_l_mean = kl_divergence_l.mean()
kl_divergence_cat_mean = kl_divergence_cat.mean()


if labelled_tensors is not None:
if self._version == 0:
if self.n_version == 0:
classifier_loss = self.classification_loss(labelled_tensors)
loss += classifier_loss * classification_ratio
return LossRecorder(
loss,
reconst_loss,
kl_divergence,
loss_z1_weight_mean = loss_z1_weight_mean,
kl_divergence_cat_mean = kl_divergence_cat_mean,
loss_z1_unweight_mean = loss_z1_unweight_mean,
kl_divergence_z2_mean = kl_divergence_z2_mean,
kl_divergence_l_mean = kl_divergence_l_mean,
classification_loss=classifier_loss,
)

# print('Hi')

return LossRecorder(loss, reconst_loss, kl_divergence,
loss_z1_weight_mean = loss_z1_weight_mean,
kl_divergence_cat_mean = kl_divergence_cat_mean,
loss_z1_unweight_mean = loss_z1_unweight_mean,
kl_divergence_z2_mean = kl_divergence_z2_mean,
kl_divergence_l_mean = kl_divergence_l_mean,
)

return LossRecorder(loss, reconst_loss, kl_divergence)