From b9fbec1fb9d11aa116c18dd0a3c5795f7dd10a91 Mon Sep 17 00:00:00 2001 From: Nathan Levy Date: Wed, 25 May 2022 05:14:11 -0700 Subject: [PATCH 01/12] branch --- scvi/module/_scanvae.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scvi/module/_scanvae.py b/scvi/module/_scanvae.py index bdebe692a5..abd7dc5106 100644 --- a/scvi/module/_scanvae.py +++ b/scvi/module/_scanvae.py @@ -314,13 +314,13 @@ def loss( else: kl_divergence_l = torch.tensor(0.0) #indeed si tu observes l il ne sera plus dans la var dist - print('The version is: ', self.n_version) + #print('The version is: ', self.n_version) #if is_labelled: if labelled_tensors is not None: - print("--------------------labelled_tensors is not None-------------------------") + #print("--------------------labelled_tensors is not None-------------------------") if self.n_version == 1: - print("Adding KLs to the loss...") + #print("Adding KLs to the loss...") loss = reconst_loss.mean()+loss_z1_weight.mean()+loss_z1_unweight.mean()+ kl_weight*(kl_divergence_z2.mean()+kl_divergence_l.mean()) # add kl terms here # else: # print("The loss is unchanged...") From f6c1e10a979547c54be1f131e44a852bee3c72bc Mon Sep 17 00:00:00 2001 From: Nathan Levy Date: Sun, 22 May 2022 04:49:22 -0700 Subject: [PATCH 02/12] fixed n_version --- scvi/model/_scanvi.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scvi/model/_scanvi.py b/scvi/model/_scanvi.py index a2ee76615b..3353ff57e9 100644 --- a/scvi/model/_scanvi.py +++ b/scvi/model/_scanvi.py @@ -95,6 +95,7 @@ def __init__( dropout_rate: float = 0.1, dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene", gene_likelihood: Literal["zinb", "nb", "poisson"] = "zinb", + #n_version = 0, **model_kwargs, ): super(SCANVI, self).__init__(adata) @@ -137,6 +138,7 @@ def __init__( use_size_factor_key=use_size_factor_key, library_log_means=library_log_means, library_log_vars=library_log_vars, + #n_version = n_version, **scanvae_model_kwargs, ) From 15ee46d64764a9a6fc7df146bdb9fa45a86db3d6 Mon Sep 17 00:00:00 2001 From: Nathan Levy Date: Sun, 22 May 2022 05:43:07 -0700 Subject: [PATCH 03/12] code with two versions --- scvi/module/_scanvae.py | 59 ++++++++++++++++++++++++++--------------- 1 file changed, 37 insertions(+), 22 deletions(-) diff --git a/scvi/module/_scanvae.py b/scvi/module/_scanvae.py index 6701a6d123..30dea01469 100644 --- a/scvi/module/_scanvae.py +++ b/scvi/module/_scanvae.py @@ -25,6 +25,7 @@ class SCANVAE(VAE): Parameters ---------- + n_version n_input Number of input genes n_batch @@ -91,6 +92,7 @@ def __init__( classifier_parameters: dict = dict(), use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "both", use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "none", + n_version = 0, # 0 denotes the old one without the fix **vae_kwargs ): super().__init__( @@ -115,6 +117,7 @@ def __init__( use_layer_norm_encoder = use_layer_norm == "encoder" or use_layer_norm == "both" use_layer_norm_decoder = use_layer_norm == "decoder" or use_layer_norm == "both" + self.n_version = n_version self.n_labels = n_labels # Classifier takes n_latent as input cls_parameters = { @@ -293,28 +296,39 @@ def loss( else: kl_divergence_l = 0.0 - if is_labelled: - loss = reconst_loss + loss_z1_weight + loss_z1_unweight + print('The version is: ', self.n_version) + + #if is_labelled: + if labelled_tensors is not None: + print("--------------------labelled_tensors is not None-------------------------") + if self.n_version == 1: + print("Adding KLs to the loss...") 
+ loss = reconst_loss + loss_z1_weight + loss_z1_unweight + kl_divergence_z2 + kl_divergence_l # add kl terms here + else: + print("The loss is unchanged...") + loss = reconst_loss + loss_z1_weight + loss_z1_unweight + kl_locals = { "kl_divergence_z2": kl_divergence_z2, "kl_divergence_l": kl_divergence_l, } - if labelled_tensors is not None: - classifier_loss = self.classification_loss(labelled_tensors) - loss += classifier_loss * classification_ratio - return LossRecorder( - loss, - reconst_loss, - kl_locals, - classification_loss=classifier_loss, - n_labelled_tensors=labelled_tensors[REGISTRY_KEYS.X_KEY].shape[0], - ) + #if labelled_tensors is not None: + #print("And labelled_tensors is not None") + classifier_loss = self.classification_loss(labelled_tensors) + loss += classifier_loss * classification_ratio return LossRecorder( loss, reconst_loss, kl_locals, - kl_global=torch.tensor(0.0), + classification_loss=classifier_loss, + n_labelled_tensors=labelled_tensors[REGISTRY_KEYS.X_KEY].shape[0], ) + #return LossRecorder( + # loss, + # reconst_loss, + # kl_locals, + # kl_global=torch.tensor(0.0), + #) probs = self.classifier(z1) reconst_loss += loss_z1_weight + ( @@ -332,13 +346,14 @@ def loss( loss = torch.mean(reconst_loss + kl_divergence * kl_weight) - if labelled_tensors is not None: - classifier_loss = self.classification_loss(labelled_tensors) - loss += classifier_loss * classification_ratio - return LossRecorder( - loss, - reconst_loss, - kl_divergence, - classification_loss=classifier_loss, - ) + # if labelled_tensors is not None: + # print("is_labelled=False and labelled_tensors is not None") + # classifier_loss = self.classification_loss(labelled_tensors) + # loss += classifier_loss * classification_ratio + # return LossRecorder( + # loss, + # reconst_loss, + # kl_divergence, + # classification_loss=classifier_loss, + # ) return LossRecorder(loss, reconst_loss, kl_divergence) From 418b00ef07af972696d5e2fa411446151a8c71aa Mon Sep 17 00:00:00 2001 From: Nathan Levy Date: Sun, 22 May 2022 16:21:22 +0300 Subject: [PATCH 04/12] commented SCANVAE file --- scvi/module/_scanvae.py | 91 +++++++++++++++++++++++++---------------- 1 file changed, 55 insertions(+), 36 deletions(-) diff --git a/scvi/module/_scanvae.py b/scvi/module/_scanvae.py index 30dea01469..3886a8e8e9 100644 --- a/scvi/module/_scanvae.py +++ b/scvi/module/_scanvae.py @@ -2,7 +2,7 @@ import numpy as np import torch -from torch.distributions import Categorical, Normal +from torch.distributions import Categorical, Normal #ok from torch.distributions import kl_divergence as kl from torch.nn import functional as F @@ -11,12 +11,12 @@ from scvi.module.base import LossRecorder, auto_move_data from scvi.nn import Decoder, Encoder -from ._classifier import Classifier +from ._classifier import Classifier #Basic fully-connected NN classifier. from ._utils import broadcast_labels from ._vae import VAE -class SCANVAE(VAE): +class SCANVAE(VAE): #inherits from VAE class (for instance inherits z_encoder) """ Single-cell annotation using variational inference. 
@@ -39,7 +39,7 @@ class SCANVAE(VAE): n_layers Number of hidden layers used for encoder and decoder NNs n_continuous_cov - Number of continuous covarites + Number of continuous covariates n_cats_per_cov Number of categories for each extra categorical covariate dropout_rate @@ -59,9 +59,9 @@ class SCANVAE(VAE): * ``'nb'`` - Negative binomial distribution * ``'zinb'`` - Zero-inflated negative binomial distribution y_prior - If None, initialized to uniform probability over cell types + If None, initialized to uniform probability over cell types OK labels_groups - Label group designations + Label group designations ?? --> hierarchie entre labels use_labels_groups Whether to use the label groups use_batch_norm @@ -72,6 +72,9 @@ class SCANVAE(VAE): Keyword args for :class:`~scvi.module.VAE` """ + + #--------------------------------INIT----------------------------------------------------------------------------------------------------------- + def __init__( self, n_input: int, @@ -80,14 +83,15 @@ def __init__( n_hidden: int = 128, n_latent: int = 10, n_layers: int = 1, - n_continuous_cov: int = 0, + n_continuous_cov: int = 0, #in the following, we assume only one categorical covariate with categories, which represents the common case of having multiple batches of data. + n_cats_per_cov: Optional[Iterable[int]] = None, dropout_rate: float = 0.1, dispersion: str = "gene", log_variational: bool = True, gene_likelihood: str = "zinb", y_prior=None, - labels_groups: Sequence[int] = None, + labels_groups: Sequence[int] = None, #?? use_labels_groups: bool = False, classifier_parameters: dict = dict(), use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "both", @@ -119,6 +123,8 @@ def __init__( self.n_version = n_version self.n_labels = n_labels + + # Classifier takes n_latent as input cls_parameters = { "n_layers": n_layers, @@ -126,15 +132,15 @@ def __init__( "dropout_rate": dropout_rate, } cls_parameters.update(classifier_parameters) - self.classifier = Classifier( - n_latent, + self.classifier = Classifier( #PROBABILISTIC CELL-TYPE ANNOTATION? n_hidden kept as default, classifies between n_labels + n_latent, #Number of input dimensions n_labels=n_labels, use_batch_norm=use_batch_norm_encoder, use_layer_norm=use_layer_norm_encoder, **cls_parameters ) - self.encoder_z2_z1 = Encoder( + self.encoder_z2_z1 = Encoder( #q(z2|z1,....) ??? n_latent, n_latent, n_cat_list=[self.n_labels], @@ -144,7 +150,7 @@ def __init__( use_batch_norm=use_batch_norm_encoder, use_layer_norm=use_layer_norm_encoder, ) - self.decoder_z1_z2 = Decoder( + self.decoder_z1_z2 = Decoder( # p(z1|z2,....) ???? 
n_latent, n_latent, n_cat_list=[self.n_labels], @@ -154,7 +160,7 @@ def __init__( use_layer_norm=use_layer_norm_decoder, ) - self.y_prior = torch.nn.Parameter( + self.y_prior = torch.nn.Parameter( #uniform probabilities for categorical distribution on the cell type HERE Y=C y_prior if y_prior is not None else (1 / n_labels) * torch.ones(1, n_labels), @@ -164,13 +170,14 @@ def __init__( self.labels_groups = ( np.array(labels_groups) if labels_groups is not None else None ) - if self.use_labels_groups: + if self.use_labels_groups: if labels_groups is None: raise ValueError("Specify label groups") unique_groups = np.unique(self.labels_groups) self.n_groups = len(unique_groups) if not (unique_groups == np.arange(self.n_groups)).all(): raise ValueError() + self.classifier_groups = Classifier( n_latent, n_hidden, self.n_groups, n_layers, dropout_rate ) @@ -187,9 +194,12 @@ def __init__( ] ) + + #---------------------------------------METHODS---------------------------------------------------------------------------------------------------------------------------- + @auto_move_data def classify(self, x, batch_index=None, cont_covs=None, cat_covs=None): - if self.log_variational: + if self.log_variational: #for numerical stability x = torch.log(1 + x) if cont_covs is not None and self.encode_covariates: @@ -201,9 +211,9 @@ def classify(self, x, batch_index=None, cont_covs=None, cat_covs=None): else: categorical_input = tuple() - qz_m, _, z = self.z_encoder(encoder_input, batch_index, *categorical_input) + qz_m, _, z = self.z_encoder(encoder_input, batch_index, *categorical_input) #q(z1|x) without the var qz_v # We classify using the inferred mean parameter of z_1 in the latent space - z = qz_m + z = qz_m if self.use_labels_groups: w_g = self.classifier_groups(z) unw_y = self.classifier(z) @@ -219,7 +229,7 @@ def classify(self, x, batch_index=None, cont_covs=None, cat_covs=None): return w_y @auto_move_data - def classification_loss(self, labelled_dataset): + def classification_loss(self, labelled_dataset): #add a classifiaction loss ON THE LABELLED ATA, following Kingma et al x = labelled_dataset[REGISTRY_KEYS.X_KEY] y = labelled_dataset[REGISTRY_KEYS.LABELS_KEY] batch_idx = labelled_dataset[REGISTRY_KEYS.BATCH_KEY] @@ -236,7 +246,7 @@ def classification_loss(self, labelled_dataset): self.classify( x, batch_index=batch_idx, cat_covs=cat_covs, cont_covs=cont_covs ), - y.view(-1).long(), + y.view(-1).long(), ) return classification_loss @@ -245,9 +255,9 @@ def loss( tensors, inference_outputs, generative_ouputs, - feed_labels=False, + feed_labels=False, #? ---> 2 dataloaders, for annotated and un annotated, don't feed labels for un annotated kl_weight=1, - labelled_tensors=None, + labelled_tensors=None, #?? -->scvanvi.py classification_ratio=None, ): px_r = generative_ouputs["px_r"] @@ -263,25 +273,33 @@ def loss( y = tensors[REGISTRY_KEYS.LABELS_KEY] else: y = None - is_labelled = False if y is None else True + + + is_labelled = False if y is None else True #important for ELBO # Enumerate choices of label - ys, z1s = broadcast_labels(y, z1, n_broadcast=self.n_labels) - qz2_m, qz2_v, z2 = self.encoder_z2_z1(z1s, ys) - pz1_m, pz1_v = self.decoder_z1_z2(z2, ys) + ys, z1s = broadcast_labels(y, z1, n_broadcast=self.n_labels) #one-hot encoding of the labels + #if z1 is of size (batch_size,latent), z1_s is of size (n_labels*batch_size,latent) + qz2_m, qz2_v, z2 = self.encoder_z2_z1(z1s, ys) #q(z2|z1,..) + pz1_m, pz1_v = self.decoder_z1_z2(z2, ys) #p(z1|z2,..) 
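# Illustrative sketch, not part of the patch: the shape bookkeeping behind the
# label-enumeration step above, using only the shapes stated in the comments
# (z1: (batch, latent) -> z1s: (n_labels * batch, latent) with matching one-hot ys).
# The label-major ordering and broadcast_labels' exact implementation are assumptions.
import torch

batch, latent, n_labels = 4, 10, 3
z1 = torch.randn(batch, latent)
z1s = z1.repeat(n_labels, 1)                              # (n_labels * batch, latent)
ys = torch.eye(n_labels).repeat_interleave(batch, dim=0)  # (n_labels * batch, n_labels), one-hot
assert z1s.shape == (n_labels * batch, latent) and ys.shape == (n_labels * batch, n_labels)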
- reconst_loss = self.get_reconstruction_loss(x, px_rate, px_r, px_dropout) + reconst_loss = self.get_reconstruction_loss(x, px_rate, px_r, px_dropout) #expectation of log(p),as in scvi # KL Divergence mean = torch.zeros_like(qz2_m) scale = torch.ones_like(qz2_v) kl_divergence_z2 = kl( - Normal(qz2_m, torch.sqrt(qz2_v)), Normal(mean, scale) + Normal(qz2_m, torch.sqrt(qz2_v)), Normal(mean, scale) #q(z2|z1,..)||p(z2) ).sum(dim=1) + loss_z1_unweight = -Normal(pz1_m, torch.sqrt(pz1_v)).log_prob(z1s).sum(dim=-1) - loss_z1_weight = Normal(qz1_m, torch.sqrt(qz1_v)).log_prob(z1).sum(dim=-1) - if not self.use_observed_lib_size: + #Sum of the log of the Normal probability density evaluated at value z1s. The sum is over the 10-dim latent space. + + + loss_z1_weight = Normal(qz1_m, torch.sqrt(qz1_v)).log_prob(z1).sum(dim=-1) #????ok + + if not self.use_observed_lib_size: #comme dans le user guide, si l_n is latent! ql_m = inference_outputs["ql_m"] ql_v = inference_outputs["ql_v"] ( @@ -291,10 +309,10 @@ def loss( kl_divergence_l = kl( Normal(ql_m, torch.sqrt(ql_v)), - Normal(local_library_log_means, torch.sqrt(local_library_log_vars)), + Normal(local_library_log_means, torch.sqrt(local_library_log_vars)), #ok ).sum(dim=1) else: - kl_divergence_l = 0.0 + kl_divergence_l = 0.0 #indeed si tu observes l il ne sera plus dans la var dist print('The version is: ', self.n_version) @@ -309,7 +327,7 @@ def loss( loss = reconst_loss + loss_z1_weight + loss_z1_unweight kl_locals = { - "kl_divergence_z2": kl_divergence_z2, + "kl_divergence_z2": kl_divergence_z2, #in scvi, this is added to the loss? "kl_divergence_l": kl_divergence_l, } #if labelled_tensors is not None: @@ -330,10 +348,11 @@ def loss( # kl_global=torch.tensor(0.0), #) - probs = self.classifier(z1) + # the ELBO in the case where C=Y is not observed + probs = self.classifier(z1) #outputs a vector of size n_labels suming to 1 reconst_loss += loss_z1_weight + ( (loss_z1_unweight).view(self.n_labels, -1).t() * probs - ).sum(dim=1) + ).sum(dim=1) #why loss_z1_weight is not in the sum? kl_divergence = (kl_divergence_z2.view(self.n_labels, -1).t() * probs).sum( dim=1 @@ -344,7 +363,7 @@ def loss( ) kl_divergence += kl_divergence_l - loss = torch.mean(reconst_loss + kl_divergence * kl_weight) + loss = torch.mean(reconst_loss + kl_divergence * kl_weight) #annealing to avoid posterior collapse!!! 
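# Illustrative sketch, not part of the patch: the reduction used above for the unlabelled
# case, i.e. an expectation of the per-label losses under the classifier's distribution.
# Losses of shape (n_labels * batch,) are reshaped to (batch, n_labels) and weighted by
# probs = classifier(z1); kl_weight is the annealing factor the trailing comment refers to
# (typically warmed up from 0 to 1 over training, exact schedule depends on the training plan).
import torch

batch, n_labels = 4, 3
per_label_loss = torch.randn(n_labels * batch).abs()         # stand-in for loss_z1_unweight
probs = torch.softmax(torch.randn(batch, n_labels), dim=-1)  # stand-in for self.classifier(z1)
expected_loss = (per_label_loss.view(n_labels, -1).t() * probs).sum(dim=1)  # shape (batch,)
assert expected_loss.shape == (batch,)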
# if labelled_tensors is not None: # print("is_labelled=False and labelled_tensors is not None") @@ -356,4 +375,4 @@ def loss( # kl_divergence, # classification_loss=classifier_loss, # ) - return LossRecorder(loss, reconst_loss, kl_divergence) + return LossRecorder(loss, reconst_loss, kl_divergence) From 2eaeb53cdca696ce5ab133bb66b44b4cc0be00c4 Mon Sep 17 00:00:00 2001 From: Nathan Levy Date: Mon, 23 May 2022 19:13:41 +0300 Subject: [PATCH 05/12] amended _scanvae --- scvi/model/_scanvi.py | 2 -- scvi/module/_scanvae.py | 4 ++-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/scvi/model/_scanvi.py b/scvi/model/_scanvi.py index 3353ff57e9..a2ee76615b 100644 --- a/scvi/model/_scanvi.py +++ b/scvi/model/_scanvi.py @@ -95,7 +95,6 @@ def __init__( dropout_rate: float = 0.1, dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene", gene_likelihood: Literal["zinb", "nb", "poisson"] = "zinb", - #n_version = 0, **model_kwargs, ): super(SCANVI, self).__init__(adata) @@ -138,7 +137,6 @@ def __init__( use_size_factor_key=use_size_factor_key, library_log_means=library_log_means, library_log_vars=library_log_vars, - #n_version = n_version, **scanvae_model_kwargs, ) diff --git a/scvi/module/_scanvae.py b/scvi/module/_scanvae.py index 3886a8e8e9..dd81e98b06 100644 --- a/scvi/module/_scanvae.py +++ b/scvi/module/_scanvae.py @@ -321,7 +321,7 @@ def loss( print("--------------------labelled_tensors is not None-------------------------") if self.n_version == 1: print("Adding KLs to the loss...") - loss = reconst_loss + loss_z1_weight + loss_z1_unweight + kl_divergence_z2 + kl_divergence_l # add kl terms here + loss = reconst_loss + loss_z1_weight + loss_z1_unweight + kl_divergence_z2 #+ kl_divergence_l # add kl terms here else: print("The loss is unchanged...") loss = reconst_loss + loss_z1_weight + loss_z1_unweight @@ -375,4 +375,4 @@ def loss( # kl_divergence, # classification_loss=classifier_loss, # ) - return LossRecorder(loss, reconst_loss, kl_divergence) + return LossRecorder(loss, reconst_loss, kl_divergence) From 505a628947b526867a29020e54cf6c68fb0ac657 Mon Sep 17 00:00:00 2001 From: Nathan Levy Date: Mon, 23 May 2022 21:09:04 +0300 Subject: [PATCH 06/12] test new version --- scvi/module/_scanvae.py | 70 +++++++++++++++++++---------------------- 1 file changed, 33 insertions(+), 37 deletions(-) diff --git a/scvi/module/_scanvae.py b/scvi/module/_scanvae.py index dd81e98b06..b4dcfe0825 100644 --- a/scvi/module/_scanvae.py +++ b/scvi/module/_scanvae.py @@ -312,7 +312,7 @@ def loss( Normal(local_library_log_means, torch.sqrt(local_library_log_vars)), #ok ).sum(dim=1) else: - kl_divergence_l = 0.0 #indeed si tu observes l il ne sera plus dans la var dist + kl_divergence_l = torch.tensor(0.0) #indeed si tu observes l il ne sera plus dans la var dist print('The version is: ', self.n_version) @@ -321,32 +321,26 @@ def loss( print("--------------------labelled_tensors is not None-------------------------") if self.n_version == 1: print("Adding KLs to the loss...") - loss = reconst_loss + loss_z1_weight + loss_z1_unweight + kl_divergence_z2 #+ kl_divergence_l # add kl terms here - else: - print("The loss is unchanged...") - loss = reconst_loss + loss_z1_weight + loss_z1_unweight - - kl_locals = { - "kl_divergence_z2": kl_divergence_z2, #in scvi, this is added to the loss? 
- "kl_divergence_l": kl_divergence_l, - } - #if labelled_tensors is not None: - #print("And labelled_tensors is not None") - classifier_loss = self.classification_loss(labelled_tensors) - loss += classifier_loss * classification_ratio - return LossRecorder( - loss, - reconst_loss, - kl_locals, - classification_loss=classifier_loss, - n_labelled_tensors=labelled_tensors[REGISTRY_KEYS.X_KEY].shape[0], - ) - #return LossRecorder( - # loss, - # reconst_loss, - # kl_locals, - # kl_global=torch.tensor(0.0), - #) + loss = reconst_loss.mean()+loss_z1_weight.mean()+loss_z1_unweight.mean()+kl_divergence_z2.mean()+kl_divergence_l.mean() # add kl terms here + # else: + # print("The loss is unchanged...") + # loss = reconst_loss.mean() + loss_z1_weight.mean() + loss_z1_unweight.mean() + + kl_locals = { + "kl_divergence_z2": kl_divergence_z2, #in scvi, this is added to the loss? + "kl_divergence_l": kl_divergence_l, + } + #if labelled_tensors is not None: + #print("And labelled_tensors is not None") + classifier_loss = self.classification_loss(labelled_tensors) + loss += classifier_loss * classification_ratio + return LossRecorder( + loss, + reconst_loss, + kl_locals, + classification_loss=classifier_loss, + n_labelled_tensors=labelled_tensors[REGISTRY_KEYS.X_KEY].shape[0], + ) # the ELBO in the case where C=Y is not observed probs = self.classifier(z1) #outputs a vector of size n_labels suming to 1 @@ -365,14 +359,16 @@ def loss( loss = torch.mean(reconst_loss + kl_divergence * kl_weight) #annealing to avoid posterior collapse!!! - # if labelled_tensors is not None: - # print("is_labelled=False and labelled_tensors is not None") - # classifier_loss = self.classification_loss(labelled_tensors) - # loss += classifier_loss * classification_ratio - # return LossRecorder( - # loss, - # reconst_loss, - # kl_divergence, - # classification_loss=classifier_loss, - # ) + if labelled_tensors is not None: + if self._version == 0: + classifier_loss = self.classification_loss(labelled_tensors) + loss += classifier_loss * classification_ratio + return LossRecorder( + loss, + reconst_loss, + kl_divergence, + classification_loss=classifier_loss, + ) + return LossRecorder(loss, reconst_loss, kl_divergence) + \ No newline at end of file From 3b432c4fff0cd738ecf1b66d15086e6442e50aeb Mon Sep 17 00:00:00 2001 From: Nathan Levy Date: Mon, 23 May 2022 23:17:15 +0300 Subject: [PATCH 07/12] add kl_weight --- scvi/module/_scanvae.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/scvi/module/_scanvae.py b/scvi/module/_scanvae.py index b4dcfe0825..bdebe692a5 100644 --- a/scvi/module/_scanvae.py +++ b/scvi/module/_scanvae.py @@ -321,7 +321,7 @@ def loss( print("--------------------labelled_tensors is not None-------------------------") if self.n_version == 1: print("Adding KLs to the loss...") - loss = reconst_loss.mean()+loss_z1_weight.mean()+loss_z1_unweight.mean()+kl_divergence_z2.mean()+kl_divergence_l.mean() # add kl terms here + loss = reconst_loss.mean()+loss_z1_weight.mean()+loss_z1_unweight.mean()+ kl_weight*(kl_divergence_z2.mean()+kl_divergence_l.mean()) # add kl terms here # else: # print("The loss is unchanged...") # loss = reconst_loss.mean() + loss_z1_weight.mean() + loss_z1_unweight.mean() @@ -371,4 +371,3 @@ def loss( ) return LossRecorder(loss, reconst_loss, kl_divergence) - \ No newline at end of file From 38e0a95988ea144f25f7f56280288b062cb75f10 Mon Sep 17 00:00:00 2001 From: Nathan Levy Date: Wed, 25 May 2022 05:14:11 -0700 Subject: [PATCH 08/12] branch --- scvi/module/_scanvae.py | 6 
+++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scvi/module/_scanvae.py b/scvi/module/_scanvae.py index bdebe692a5..abd7dc5106 100644 --- a/scvi/module/_scanvae.py +++ b/scvi/module/_scanvae.py @@ -314,13 +314,13 @@ def loss( else: kl_divergence_l = torch.tensor(0.0) #indeed si tu observes l il ne sera plus dans la var dist - print('The version is: ', self.n_version) + #print('The version is: ', self.n_version) #if is_labelled: if labelled_tensors is not None: - print("--------------------labelled_tensors is not None-------------------------") + #print("--------------------labelled_tensors is not None-------------------------") if self.n_version == 1: - print("Adding KLs to the loss...") + #print("Adding KLs to the loss...") loss = reconst_loss.mean()+loss_z1_weight.mean()+loss_z1_unweight.mean()+ kl_weight*(kl_divergence_z2.mean()+kl_divergence_l.mean()) # add kl terms here # else: # print("The loss is unchanged...") From 8a12c3c5a97977b21d70c0093c049a8feea30e37 Mon Sep 17 00:00:00 2001 From: NathanAzoL Date: Thu, 26 May 2022 16:56:14 -0700 Subject: [PATCH 09/12] cleaned file --- scvi/module/_scanvae.py | 42 ++++++++++++++++------------------------- 1 file changed, 16 insertions(+), 26 deletions(-) diff --git a/scvi/module/_scanvae.py b/scvi/module/_scanvae.py index abd7dc5106..0d90458493 100644 --- a/scvi/module/_scanvae.py +++ b/scvi/module/_scanvae.py @@ -26,6 +26,7 @@ class SCANVAE(VAE): #inherits from VAE class (for instance inherits z_encoder) Parameters ---------- n_version + n_input Number of input genes n_batch @@ -61,8 +62,8 @@ class SCANVAE(VAE): #inherits from VAE class (for instance inherits z_encoder) y_prior If None, initialized to uniform probability over cell types OK labels_groups - Label group designations ?? --> hierarchie entre labels - use_labels_groups + Label group designations + use_labels_groups Whether to use the label groups use_batch_norm Whether to use batch norm in layers @@ -73,7 +74,6 @@ class SCANVAE(VAE): #inherits from VAE class (for instance inherits z_encoder) """ - #--------------------------------INIT----------------------------------------------------------------------------------------------------------- def __init__( self, @@ -91,7 +91,7 @@ def __init__( log_variational: bool = True, gene_likelihood: str = "zinb", y_prior=None, - labels_groups: Sequence[int] = None, #?? + labels_groups: Sequence[int] = None, use_labels_groups: bool = False, classifier_parameters: dict = dict(), use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "both", @@ -140,7 +140,7 @@ def __init__( **cls_parameters ) - self.encoder_z2_z1 = Encoder( #q(z2|z1,....) ??? + self.encoder_z2_z1 = Encoder( #q(z2|z1,....) n_latent, n_latent, n_cat_list=[self.n_labels], @@ -150,7 +150,7 @@ def __init__( use_batch_norm=use_batch_norm_encoder, use_layer_norm=use_layer_norm_encoder, ) - self.decoder_z1_z2 = Decoder( # p(z1|z2,....) ???? + self.decoder_z1_z2 = Decoder( # p(z1|z2,....) n_latent, n_latent, n_cat_list=[self.n_labels], @@ -195,7 +195,6 @@ def __init__( ) - #---------------------------------------METHODS---------------------------------------------------------------------------------------------------------------------------- @auto_move_data def classify(self, x, batch_index=None, cont_covs=None, cat_covs=None): @@ -257,7 +256,7 @@ def loss( generative_ouputs, feed_labels=False, #? ---> 2 dataloaders, for annotated and un annotated, don't feed labels for un annotated kl_weight=1, - labelled_tensors=None, #?? 
-->scvanvi.py + labelled_tensors=None, classification_ratio=None, ): px_r = generative_ouputs["px_r"] @@ -275,7 +274,7 @@ def loss( y = None - is_labelled = False if y is None else True #important for ELBO + is_labelled = False if y is None else True # Enumerate choices of label ys, z1s = broadcast_labels(y, z1, n_broadcast=self.n_labels) #one-hot encoding of the labels @@ -294,12 +293,12 @@ def loss( ).sum(dim=1) loss_z1_unweight = -Normal(pz1_m, torch.sqrt(pz1_v)).log_prob(z1s).sum(dim=-1) - #Sum of the log of the Normal probability density evaluated at value z1s. The sum is over the 10-dim latent space. + #Sum of the log of the Normal probability density evaluated at value z1s. The sum is over the latent space. - loss_z1_weight = Normal(qz1_m, torch.sqrt(qz1_v)).log_prob(z1).sum(dim=-1) #????ok + loss_z1_weight = Normal(qz1_m, torch.sqrt(qz1_v)).log_prob(z1).sum(dim=-1) - if not self.use_observed_lib_size: #comme dans le user guide, si l_n is latent! + if not self.use_observed_lib_size: ql_m = inference_outputs["ql_m"] ql_v = inference_outputs["ql_v"] ( @@ -309,29 +308,20 @@ def loss( kl_divergence_l = kl( Normal(ql_m, torch.sqrt(ql_v)), - Normal(local_library_log_means, torch.sqrt(local_library_log_vars)), #ok + Normal(local_library_log_means, torch.sqrt(local_library_log_vars)), ).sum(dim=1) else: - kl_divergence_l = torch.tensor(0.0) #indeed si tu observes l il ne sera plus dans la var dist + kl_divergence_l = torch.tensor(0.0) - #print('The version is: ', self.n_version) - #if is_labelled: if labelled_tensors is not None: - #print("--------------------labelled_tensors is not None-------------------------") if self.n_version == 1: - #print("Adding KLs to the loss...") loss = reconst_loss.mean()+loss_z1_weight.mean()+loss_z1_unweight.mean()+ kl_weight*(kl_divergence_z2.mean()+kl_divergence_l.mean()) # add kl terms here - # else: - # print("The loss is unchanged...") - # loss = reconst_loss.mean() + loss_z1_weight.mean() + loss_z1_unweight.mean() kl_locals = { - "kl_divergence_z2": kl_divergence_z2, #in scvi, this is added to the loss? + "kl_divergence_z2": kl_divergence_z2, "kl_divergence_l": kl_divergence_l, } - #if labelled_tensors is not None: - #print("And labelled_tensors is not None") classifier_loss = self.classification_loss(labelled_tensors) loss += classifier_loss * classification_ratio return LossRecorder( @@ -346,7 +336,7 @@ def loss( probs = self.classifier(z1) #outputs a vector of size n_labels suming to 1 reconst_loss += loss_z1_weight + ( (loss_z1_unweight).view(self.n_labels, -1).t() * probs - ).sum(dim=1) #why loss_z1_weight is not in the sum? + ).sum(dim=1) kl_divergence = (kl_divergence_z2.view(self.n_labels, -1).t() * probs).sum( dim=1 @@ -357,7 +347,7 @@ def loss( ) kl_divergence += kl_divergence_l - loss = torch.mean(reconst_loss + kl_divergence * kl_weight) #annealing to avoid posterior collapse!!! 
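# Illustrative sketch, not the module code: the labelled-cell objective that this series
# puts behind n_version == 1, written as a standalone function over per-cell torch tensors.
# It mirrors the mean-reduced expression with the kl_weight factor earlier in this patch;
# the classification loss scaled by classification_ratio is still added on top, as in the diff.
def labelled_loss_v1(reconst_loss, loss_z1_weight, loss_z1_unweight,
                     kl_divergence_z2, kl_divergence_l, kl_weight=1.0):
    return (
        reconst_loss.mean()
        + loss_z1_weight.mean()
        + loss_z1_unweight.mean()
        + kl_weight * (kl_divergence_z2.mean() + kl_divergence_l.mean())
    )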
+ loss = torch.mean(reconst_loss + kl_divergence * kl_weight) if labelled_tensors is not None: if self._version == 0: From 397a6c6e1dfdc9d2f5571fdaad560e41520ed8f6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 27 May 2022 00:43:49 +0000 Subject: [PATCH 10/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- scvi/module/_scanvae.py | 100 +++++++++++++++++++++------------------- 1 file changed, 52 insertions(+), 48 deletions(-) diff --git a/scvi/module/_scanvae.py b/scvi/module/_scanvae.py index 37a74e4a07..924a58b30a 100644 --- a/scvi/module/_scanvae.py +++ b/scvi/module/_scanvae.py @@ -2,7 +2,7 @@ import numpy as np import torch -from torch.distributions import Categorical, Normal #ok +from torch.distributions import Categorical, Normal # ok from torch.distributions import kl_divergence as kl from torch.nn import functional as F @@ -11,12 +11,12 @@ from scvi.module.base import LossRecorder, auto_move_data from scvi.nn import Decoder, Encoder -from ._classifier import Classifier #Basic fully-connected NN classifier. +from ._classifier import Classifier # Basic fully-connected NN classifier. from ._utils import broadcast_labels from ._vae import VAE -class SCANVAE(VAE): #inherits from VAE class (for instance inherits z_encoder) +class SCANVAE(VAE): # inherits from VAE class (for instance inherits z_encoder) """ Single-cell annotation using variational inference. @@ -39,7 +39,7 @@ class SCANVAE(VAE): #inherits from VAE class (for instance inherits z_encoder) n_layers Number of hidden layers used for encoder and decoder NNs n_continuous_cov - Number of continuous covariates + Number of continuous covariates n_cats_per_cov Number of categories for each extra categorical covariate dropout_rate @@ -61,8 +61,8 @@ class SCANVAE(VAE): #inherits from VAE class (for instance inherits z_encoder) y_prior If None, initialized to uniform probability over cell types OK labels_groups - Label group designations - use_labels_groups + Label group designations + use_labels_groups Whether to use the label groups use_batch_norm Whether to use batch norm in layers @@ -72,8 +72,6 @@ class SCANVAE(VAE): #inherits from VAE class (for instance inherits z_encoder) Keyword args for :class:`~scvi.module.VAE` """ - - def __init__( self, n_input: int, @@ -82,20 +80,19 @@ def __init__( n_hidden: int = 128, n_latent: int = 10, n_layers: int = 1, - n_continuous_cov: int = 0, #in the following, we assume only one categorical covariate with categories, which represents the common case of having multiple batches of data. - + n_continuous_cov: int = 0, # in the following, we assume only one categorical covariate with categories, which represents the common case of having multiple batches of data. 
n_cats_per_cov: Optional[Iterable[int]] = None, dropout_rate: float = 0.1, dispersion: str = "gene", log_variational: bool = True, gene_likelihood: str = "zinb", y_prior=None, - labels_groups: Sequence[int] = None, + labels_groups: Sequence[int] = None, use_labels_groups: bool = False, classifier_parameters: dict = dict(), use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "both", use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "none", - n_version = 0, # 0 denotes the old one without the fix + n_version=0, # 0 denotes the old one without the fix **vae_kwargs ): super().__init__( @@ -123,7 +120,6 @@ def __init__( self.n_version = n_version self.n_labels = n_labels - # Classifier takes n_latent as input cls_parameters = { "n_layers": n_layers, @@ -131,15 +127,15 @@ def __init__( "dropout_rate": dropout_rate, } cls_parameters.update(classifier_parameters) - self.classifier = Classifier( #PROBABILISTIC CELL-TYPE ANNOTATION? n_hidden kept as default, classifies between n_labels - n_latent, #Number of input dimensions + self.classifier = Classifier( # PROBABILISTIC CELL-TYPE ANNOTATION? n_hidden kept as default, classifies between n_labels + n_latent, # Number of input dimensions n_labels=n_labels, use_batch_norm=use_batch_norm_encoder, use_layer_norm=use_layer_norm_encoder, **cls_parameters ) - self.encoder_z2_z1 = Encoder( #q(z2|z1,....) + self.encoder_z2_z1 = Encoder( # q(z2|z1,....) n_latent, n_latent, n_cat_list=[self.n_labels], @@ -149,7 +145,7 @@ def __init__( use_batch_norm=use_batch_norm_encoder, use_layer_norm=use_layer_norm_encoder, ) - self.decoder_z1_z2 = Decoder( # p(z1|z2,....) + self.decoder_z1_z2 = Decoder( # p(z1|z2,....) n_latent, n_latent, n_cat_list=[self.n_labels], @@ -159,7 +155,7 @@ def __init__( use_layer_norm=use_layer_norm_decoder, ) - self.y_prior = torch.nn.Parameter( #uniform probabilities for categorical distribution on the cell type HERE Y=C + self.y_prior = torch.nn.Parameter( # uniform probabilities for categorical distribution on the cell type HERE Y=C y_prior if y_prior is not None else (1 / n_labels) * torch.ones(1, n_labels), @@ -169,7 +165,7 @@ def __init__( self.labels_groups = ( np.array(labels_groups) if labels_groups is not None else None ) - if self.use_labels_groups: + if self.use_labels_groups: if labels_groups is None: raise ValueError("Specify label groups") unique_groups = np.unique(self.labels_groups) @@ -193,11 +189,9 @@ def __init__( ] ) - - @auto_move_data def classify(self, x, batch_index=None, cont_covs=None, cat_covs=None): - if self.log_variational: #for numerical stability + if self.log_variational: # for numerical stability x = torch.log(1 + x) if cont_covs is not None and self.encode_covariates: @@ -209,9 +203,11 @@ def classify(self, x, batch_index=None, cont_covs=None, cat_covs=None): else: categorical_input = tuple() - qz_m, _, z = self.z_encoder(encoder_input, batch_index, *categorical_input) #q(z1|x) without the var qz_v + qz_m, _, z = self.z_encoder( + encoder_input, batch_index, *categorical_input + ) # q(z1|x) without the var qz_v # We classify using the inferred mean parameter of z_1 in the latent space - z = qz_m + z = qz_m if self.use_labels_groups: w_g = self.classifier_groups(z) unw_y = self.classifier(z) @@ -227,7 +223,9 @@ def classify(self, x, batch_index=None, cont_covs=None, cat_covs=None): return w_y @auto_move_data - def classification_loss(self, labelled_dataset): #add a classifiaction loss ON THE LABELLED ATA, following Kingma et al + def classification_loss( + self, labelled_dataset + 
): # add a classifiaction loss ON THE LABELLED ATA, following Kingma et al x = labelled_dataset[REGISTRY_KEYS.X_KEY] y = labelled_dataset[REGISTRY_KEYS.LABELS_KEY] batch_idx = labelled_dataset[REGISTRY_KEYS.BATCH_KEY] @@ -244,7 +242,7 @@ def classification_loss(self, labelled_dataset): #add a classifiaction los self.classify( x, batch_index=batch_idx, cat_covs=cat_covs, cont_covs=cont_covs ), - y.view(-1).long(), + y.view(-1).long(), ) return classification_loss @@ -253,9 +251,9 @@ def loss( tensors, inference_outputs, generative_ouputs, - feed_labels=False, #? ---> 2 dataloaders, for annotated and un annotated, don't feed labels for un annotated + feed_labels=False, # ? ---> 2 dataloaders, for annotated and un annotated, don't feed labels for un annotated kl_weight=1, - labelled_tensors=None, + labelled_tensors=None, classification_ratio=None, ): px_r = generative_ouputs["px_r"] @@ -272,32 +270,34 @@ def loss( else: y = None - - is_labelled = False if y is None else True + is_labelled = False if y is None else True # Enumerate choices of label - ys, z1s = broadcast_labels(y, z1, n_broadcast=self.n_labels) #one-hot encoding of the labels - #if z1 is of size (batch_size,latent), z1_s is of size (n_labels*batch_size,latent) - qz2_m, qz2_v, z2 = self.encoder_z2_z1(z1s, ys) #q(z2|z1,..) - pz1_m, pz1_v = self.decoder_z1_z2(z2, ys) #p(z1|z2,..) + ys, z1s = broadcast_labels( + y, z1, n_broadcast=self.n_labels + ) # one-hot encoding of the labels + # if z1 is of size (batch_size,latent), z1_s is of size (n_labels*batch_size,latent) + qz2_m, qz2_v, z2 = self.encoder_z2_z1(z1s, ys) # q(z2|z1,..) + pz1_m, pz1_v = self.decoder_z1_z2(z2, ys) # p(z1|z2,..) - reconst_loss = self.get_reconstruction_loss(x, px_rate, px_r, px_dropout) #expectation of log(p),as in scvi + reconst_loss = self.get_reconstruction_loss( + x, px_rate, px_r, px_dropout + ) # expectation of log(p),as in scvi # KL Divergence mean = torch.zeros_like(qz2_m) scale = torch.ones_like(qz2_v) kl_divergence_z2 = kl( - Normal(qz2_m, torch.sqrt(qz2_v)), Normal(mean, scale) #q(z2|z1,..)||p(z2) + Normal(qz2_m, torch.sqrt(qz2_v)), Normal(mean, scale) # q(z2|z1,..)||p(z2) ).sum(dim=1) loss_z1_unweight = -Normal(pz1_m, torch.sqrt(pz1_v)).log_prob(z1s).sum(dim=-1) - #Sum of the log of the Normal probability density evaluated at value z1s. The sum is over the latent space. + # Sum of the log of the Normal probability density evaluated at value z1s. The sum is over the latent space. 
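# Quick self-contained check, not part of the patch, of the diagonal-Gaussian KL used above:
# per dimension, kl(N(m, v) || N(0, 1)) = 0.5 * (v + m^2 - 1) - 0.5 * log(v).
import torch
from torch.distributions import Normal
from torch.distributions import kl_divergence as kl

m, v = torch.randn(5, 10), torch.rand(5, 10) + 0.1
kl_torch = kl(Normal(m, torch.sqrt(v)), Normal(torch.zeros_like(m), torch.ones_like(v))).sum(dim=1)
kl_closed = (0.5 * (v + m ** 2 - 1) - 0.5 * torch.log(v)).sum(dim=1)
assert torch.allclose(kl_torch, kl_closed, atol=1e-5)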
+ loss_z1_weight = Normal(qz1_m, torch.sqrt(qz1_v)).log_prob(z1).sum(dim=-1) - loss_z1_weight = Normal(qz1_m, torch.sqrt(qz1_v)).log_prob(z1).sum(dim=-1) - - if not self.use_observed_lib_size: + if not self.use_observed_lib_size: ql_m = inference_outputs["ql_m"] ql_v = inference_outputs["ql_v"] ( @@ -307,18 +307,22 @@ def loss( kl_divergence_l = kl( Normal(ql_m, torch.sqrt(ql_v)), - Normal(local_library_log_means, torch.sqrt(local_library_log_vars)), + Normal(local_library_log_means, torch.sqrt(local_library_log_vars)), ).sum(dim=1) else: - kl_divergence_l = torch.tensor(0.0) - + kl_divergence_l = torch.tensor(0.0) if labelled_tensors is not None: if self.n_version == 1: - loss = reconst_loss.mean()+loss_z1_weight.mean()+loss_z1_unweight.mean()+ kl_weight*(kl_divergence_z2.mean()+kl_divergence_l.mean()) # add kl terms here + loss = ( + reconst_loss.mean() + + loss_z1_weight.mean() + + loss_z1_unweight.mean() + + kl_weight * (kl_divergence_z2.mean() + kl_divergence_l.mean()) + ) # add kl terms here kl_locals = { - "kl_divergence_z2": kl_divergence_z2, + "kl_divergence_z2": kl_divergence_z2, "kl_divergence_l": kl_divergence_l, } classifier_loss = self.classification_loss(labelled_tensors) @@ -332,10 +336,10 @@ def loss( ) # the ELBO in the case where C=Y is not observed - probs = self.classifier(z1) #outputs a vector of size n_labels suming to 1 + probs = self.classifier(z1) # outputs a vector of size n_labels suming to 1 reconst_loss += loss_z1_weight + ( (loss_z1_unweight).view(self.n_labels, -1).t() * probs - ).sum(dim=1) + ).sum(dim=1) kl_divergence = (kl_divergence_z2.view(self.n_labels, -1).t() * probs).sum( dim=1 @@ -346,7 +350,7 @@ def loss( ) kl_divergence += kl_divergence_l - loss = torch.mean(reconst_loss + kl_divergence * kl_weight) + loss = torch.mean(reconst_loss + kl_divergence * kl_weight) if labelled_tensors is not None: if self._version == 0: From 55d023876726ac45fdf988e635a6c54d0266747d Mon Sep 17 00:00:00 2001 From: NathanAzoL Date: Thu, 9 Jun 2022 14:41:57 -0700 Subject: [PATCH 11/12] print before lastrecorder --- scvi/module/_scanvae.py | 42 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 37 insertions(+), 5 deletions(-) diff --git a/scvi/module/_scanvae.py b/scvi/module/_scanvae.py index 37a74e4a07..d12025bee2 100644 --- a/scvi/module/_scanvae.py +++ b/scvi/module/_scanvae.py @@ -315,7 +315,12 @@ def loss( if labelled_tensors is not None: if self.n_version == 1: - loss = reconst_loss.mean()+loss_z1_weight.mean()+loss_z1_unweight.mean()+ kl_weight*(kl_divergence_z2.mean()+kl_divergence_l.mean()) # add kl terms here + + loss_z1_weight_mean = loss_z1_weight.mean() + loss_z1_unweight_mean = loss_z1_unweight.mean() + kl_divergence_z2_mean = kl_divergence_z2.mean() + kl_divergence_l_mean = kl_divergence_l.mean() + loss = reconst_loss.mean()+loss_z1_weight_mean+loss_z1_unweight_mean+ kl_weight*(kl_divergence_z2_mean+kl_divergence_l_mean) # add kl terms here kl_locals = { "kl_divergence_z2": kl_divergence_z2, @@ -327,6 +332,10 @@ def loss( loss, reconst_loss, kl_locals, + loss_z1_weight_mean = loss_z1_weight_mean, + loss_z1_unweight_mean = loss_z1_unweight_mean, + kl_divergence_z2_mean = kl_divergence_z2_mean, + kl_divergence_l_mean = kl_divergence_l_mean, classification_loss=classifier_loss, n_labelled_tensors=labelled_tensors[REGISTRY_KEYS.X_KEY].shape[0], ) @@ -340,23 +349,46 @@ def loss( kl_divergence = (kl_divergence_z2.view(self.n_labels, -1).t() * probs).sum( dim=1 ) - kl_divergence += kl( + kl_divergence_cat = kl( Categorical(probs=probs), 
Categorical(probs=self.y_prior.repeat(probs.size(0), 1)), ) + + kl_divergence += kl_divergence_cat + kl_divergence += kl_divergence_l - loss = torch.mean(reconst_loss + kl_divergence * kl_weight) + loss = torch.mean(reconst_loss + kl_divergence * kl_weight) + + loss_z1_weight_mean = loss_z1_weight.mean() + loss_z1_unweight_mean = loss_z1_unweight.mean() + kl_divergence_z2_mean = kl_divergence_z2.mean() + kl_divergence_l_mean = kl_divergence_l.mean() + kl_divergence_cat_mean = kl_divergence_cat.mean() + if labelled_tensors is not None: - if self._version == 0: + if self.n_version == 0: classifier_loss = self.classification_loss(labelled_tensors) loss += classifier_loss * classification_ratio return LossRecorder( loss, reconst_loss, kl_divergence, + loss_z1_weight_mean = loss_z1_weight_mean, + kl_divergence_cat_mean = kl_divergence_cat_mean, + loss_z1_unweight_mean = loss_z1_unweight_mean, + kl_divergence_z2_mean = kl_divergence_z2_mean, + kl_divergence_l_mean = kl_divergence_l_mean, classification_loss=classifier_loss, ) + print('Hi') + + return LossRecorder(loss, reconst_loss, kl_divergence, + loss_z1_weight_mean = loss_z1_weight_mean, + kl_divergence_cat_mean = kl_divergence_cat_mean, + loss_z1_unweight_mean = loss_z1_unweight_mean, + kl_divergence_z2_mean = kl_divergence_z2_mean, + kl_divergence_l_mean = kl_divergence_l_mean, +) - return LossRecorder(loss, reconst_loss, kl_divergence) From ee8f44e3c895460cbb2df278de7a3e683d331821 Mon Sep 17 00:00:00 2001 From: NathanAzoL Date: Tue, 13 Sep 2022 15:56:17 -0700 Subject: [PATCH 12/12] 12/06 version --- scvi/__init__.py | 3 +++ scvi/module/_scanvae.py | 2 ++ scvi/train/_trainingplans.py | 20 +++++++++++++++----- tests/models/test_models.py | 10 ++++++++-- 4 files changed, 28 insertions(+), 7 deletions(-) diff --git a/scvi/__init__.py b/scvi/__init__.py index a9fbb7ca10..6162562a91 100644 --- a/scvi/__init__.py +++ b/scvi/__init__.py @@ -11,6 +11,9 @@ # https://github.com/python-poetry/poetry/pull/2366#issuecomment-652418094 # https://github.com/python-poetry/poetry/issues/144#issuecomment-623927302 + +print('Modified version, 12/06') + try: import importlib.metadata as importlib_metadata except ModuleNotFoundError: diff --git a/scvi/module/_scanvae.py b/scvi/module/_scanvae.py index c7c913b7b8..fbb29b9605 100644 --- a/scvi/module/_scanvae.py +++ b/scvi/module/_scanvae.py @@ -382,6 +382,8 @@ def loss( classification_loss=classifier_loss, ) + # print('Hi') + return LossRecorder(loss, reconst_loss, kl_divergence, loss_z1_weight_mean = loss_z1_weight_mean, kl_divergence_cat_mean = kl_divergence_cat_mean, diff --git a/scvi/train/_trainingplans.py b/scvi/train/_trainingplans.py index a4db8a9f29..0cd5220733 100644 --- a/scvi/train/_trainingplans.py +++ b/scvi/train/_trainingplans.py @@ -558,6 +558,7 @@ def __init__( lr_factor: float = 0.6, lr_patience: int = 30, lr_threshold: float = 0.0, + mode="old", lr_scheduler_metric: Literal[ "elbo_validation", "reconstruction_loss_validation", "kl_local_validation" ] = "elbo_validation", @@ -577,14 +578,19 @@ def __init__( **loss_kwargs, ) self.loss_kwargs.update({"classification_ratio": classification_ratio}) - + self.mode = mode + def training_step(self, batch, batch_idx, optimizer_idx=0): # Potentially dangerous if batch is from a single dataloader with two keys - if len(batch) == 2: + if self.mode == "old": + cdt = len(batch) == 2 + else: + cdt = batch_idx % 2 == 0 + if cdt: full_dataset = batch[0] labelled_dataset = batch[1] else: - full_dataset = batch + full_dataset = batch[0] labelled_dataset = None 
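# Illustrative sketch, not the training-plan code itself: the batch-selection rule the new
# mode argument adds. "old" keeps the original len(batch) == 2 check; "new" assumes the
# labelled and unlabelled loaders alternate, so the labelled part is only taken on even
# batch_idx. Mirrors the diff above.
def split_batch(batch, batch_idx, mode="old"):
    has_labelled = len(batch) == 2 if mode == "old" else batch_idx % 2 == 0
    full_dataset = batch[0]
    labelled_dataset = batch[1] if has_labelled else None
    return full_dataset, labelled_dataset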
        if "kl_weight" in self.loss_kwargs:
@@ -607,11 +613,15 @@ def training_step(self, batch, batch_idx, optimizer_idx=0):
     def validation_step(self, batch, batch_idx, optimizer_idx=0):
         # Potentially dangerous if batch is from a single dataloader with two keys
-        if len(batch) == 2:
+        if self.mode == "old":
+            cdt = len(batch) == 2
+        else:
+            cdt = batch_idx % 2 == 0
+        if cdt:
             full_dataset = batch[0]
             labelled_dataset = batch[1]
         else:
-            full_dataset = batch
+            full_dataset = batch[0]
             labelled_dataset = None
         input_kwargs = dict(
diff --git a/tests/models/test_models.py b/tests/models/test_models.py
index 253f67fd62..4cabe90571 100644
--- a/tests/models/test_models.py
+++ b/tests/models/test_models.py
@@ -861,10 +861,16 @@ def test_scanvi(save_path):
         "label_0",
         batch_key="batch",
     )
-    model = SCANVI(adata, n_latent=10)
+    model = SCANVI(adata, n_latent=10, n_version=0)
     assert len(model._labeled_indices) == sum(adata.obs["labels"] != "label_0")
     assert len(model._unlabeled_indices) == sum(adata.obs["labels"] == "label_0")
-    model.train(1, train_size=0.5, check_val_every_n_epoch=1)
+    model.train(1, train_size=0.5, check_val_every_n_epoch=1, plan_kwargs=dict(mode="old"))
+
+    model = SCANVI(adata, n_latent=10, n_version = 1)
+    assert len(model._labeled_indices) == sum(adata.obs["labels"] != "label_0")
+    assert len(model._unlabeled_indices) == sum(adata.obs["labels"] == "label_0")
+    model.train(1, train_size=0.5, check_val_every_n_epoch=1, plan_kwargs=dict(mode="new"))
+
     logged_keys = model.history.keys()
     assert "elbo_validation" in logged_keys
     assert "reconstruction_loss_validation" in logged_keys
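Taken together, the series is exercised the way the updated test above does it: the loss variant is picked with n_version on the model and the dataloader behaviour with mode in the plan kwargs. The sketch below is hypothetical glue code; the synthetic dataset and the setup_anndata keyword names are assumptions, and only the n_version and plan_kwargs arguments come from these patches.

    import scvi
    from scvi.model import SCANVI

    adata = scvi.data.synthetic_iid()  # assumed toy AnnData; any data registered for SCANVI works
    SCANVI.setup_anndata(
        adata, labels_key="labels", unlabeled_category="label_0", batch_key="batch"
    )

    model = SCANVI(adata, n_latent=10, n_version=1)   # loss variant introduced by this series
    model.train(
        1, train_size=0.5, check_val_every_n_epoch=1,
        plan_kwargs=dict(mode="new"),                 # alternating labelled/unlabelled batches
    )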