diff --git a/.github/workflows/test_linux.yml b/.github/workflows/test_linux.yml index bd48c9dd75..ffe9c7c668 100644 --- a/.github/workflows/test_linux.yml +++ b/.github/workflows/test_linux.yml @@ -53,7 +53,7 @@ jobs: DISPLAY: :42 COLUMNS: 120 run: | - coverage run -m pytest -v --color=yes + coverage run -m pytest -v --color=yes -m "not custom_dataloader" coverage report - uses: codecov/codecov-action@v4 diff --git a/.github/workflows/test_linux_custom_dataloader.yml b/.github/workflows/test_linux_custom_dataloader.yml new file mode 100644 index 0000000000..e388fd994f --- /dev/null +++ b/.github/workflows/test_linux_custom_dataloader.yml @@ -0,0 +1,89 @@ +name: test (custom dataloaders) + +on: + push: + branches: [main, "[0-9]+.[0-9]+.x"] + pull_request: + branches: [main, "[0-9]+.[0-9]+.x"] + types: [labeled, synchronize, opened] + schedule: + - cron: "0 10 * * *" # runs at 10:00 UTC (03:00 PST) every day + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + test: + # if PR has label "custom_dataloader" or "all tests" or if scheduled or manually triggered + if: >- + ( + contains(github.event.pull_request.labels.*.name, 'custom_dataloader') || + contains(github.event.pull_request.labels.*.name, 'all tests') || + contains(github.event_name, 'schedule') || + contains(github.event_name, 'workflow_dispatch') + ) + + runs-on: ${{ matrix.os }} + + defaults: + run: + shell: bash -e {0} # -e to fail on error + + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python: ["3.11"] + + name: integration + + env: + OS: ${{ matrix.os }} + PYTHON: ${{ matrix.python }} + + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + cache: "pip" + cache-dependency-path: "**/pyproject.toml" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip wheel uv + python -m uv pip install --system "scvi-tools[tests] @ ." + python -m pip install scdataloader + python -m pip install cellxgene-census + python -m pip install tiledbsoma + python -m pip install s3fs + python -m pip install torchdata==0.9.0 + python -m pip install psutil + python -m pip install lamindb + python -m pip install bionty==0.51.0 + python -m pip install biomart + + - name: Install Specific Branch of Repository + env: + GH_TOKEN: ${{ secrets.GH_TOKEN }} + run: | + git config --global url."https://${GH_TOKEN}:x-oauth-basic@github.com/".insteadOf "https://github.com/" + git clone --single-branch --branch ebezzi/census-scvi-datamodule https://github.com/ori-kron-wis/cellxgene-census.git + git clone --single-branch --branch main https://github.com/jkobject/scDataLoader.git + + - name: Run specific custom dataloader pytest + env: + MPLBACKEND: agg + PLATFORM: ${{ matrix.os }} + DISPLAY: :42 + COLUMNS: 120 + run: | + coverage run -m pytest tests/dataloaders/test_custom_dataloader.py -v --color=yes --custom-dataloader-tests + coverage report + + - uses: codecov/codecov-action@v4 + with: + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/CHANGELOG.md b/CHANGELOG.md index 6fa825537a..b031a70986 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ to [Semantic Versioning]. Full commit history is available in the representation learning in single-cell RNA sequencing data {pr}`3015`, {pr}`3091`. - Add {class}`scvi.external.RESOLVI` for bias correction in single-cell resolved spatial transcriptomics {pr}`3144`. +- Add support for using Lamin custom dataloaders with {class}`scvi.model.SCVI`, {pr}`2932`. #### Fixed diff --git a/cellxgene-census b/cellxgene-census new file mode 160000 index 0000000000..fac6581530 --- /dev/null +++ b/cellxgene-census @@ -0,0 +1 @@ +Subproject commit fac658153038767c31b4c1e7a0c3e258cc2b1262 diff --git a/docs/tutorials/notebooks b/docs/tutorials/notebooks index 943703f938..b692ae37ff 160000 --- a/docs/tutorials/notebooks +++ b/docs/tutorials/notebooks @@ -1 +1 @@ -Subproject commit 943703f938c43ddc681e01c013d704db37fa3193 +Subproject commit b692ae37ff436002aa63ed13bc957c710f9a0d07 diff --git a/pyproject.toml b/pyproject.toml index 9f61089de9..1039b805b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,9 +95,11 @@ regseq = ["biopython>=1.81", "genomepy"] scanpy = ["scanpy>=1.10", "scikit-misc"] # for convinient files sharing pooch = ["pooch"] +# for custom dataloders +dataloaders = ["lamindb","biomart","bionty","cellxgene_lamin"] optional = [ - "scvi-tools[autotune,aws,hub,pooch,regseq,scanpy]" + "scvi-tools[autotune,aws,hub,pooch,regseq,scanpy,dataloaders]" ] tutorials = [ "cell2location", diff --git a/src/scvi/data/_utils.py b/src/scvi/data/_utils.py index f35a3e9bdc..044e20f341 100644 --- a/src/scvi/data/_utils.py +++ b/src/scvi/data/_utils.py @@ -16,6 +16,7 @@ from torch import as_tensor, sparse_csc_tensor, sparse_csr_tensor from scvi import REGISTRY_KEYS, settings +from scvi.utils import attrdict from . import _constants @@ -150,6 +151,14 @@ def _set_data_in_registry( setattr(adata, attr_name, attribute) +def _get_summary_stats_from_registry(registry: dict) -> attrdict: + summary_stats = {} + for field_registry in registry[_constants._FIELD_REGISTRIES_KEY].values(): + field_summary_stats = field_registry[_constants._SUMMARY_STATS_KEY] + summary_stats.update(field_summary_stats) + return attrdict(summary_stats) + + def _verify_and_correct_data_format(adata: AnnData, attr_name: str, attr_key: str | None): """Check data format and correct if necessary. diff --git a/src/scvi/dataloaders/__init__.py b/src/scvi/dataloaders/__init__.py index 302055c3d5..8c07c58f3b 100644 --- a/src/scvi/dataloaders/__init__.py +++ b/src/scvi/dataloaders/__init__.py @@ -3,6 +3,7 @@ from ._ann_dataloader import AnnDataLoader from ._concat_dataloader import ConcatDataLoader +from ._custom_dataloders import MappedCollectionDataModule from ._data_splitting import ( DataSplitter, DeviceBackedDataSplitter, @@ -20,4 +21,5 @@ "DataSplitter", "SemiSupervisedDataSplitter", "BatchDistributedSampler", + "MappedCollectionDataModule", ] diff --git a/src/scvi/dataloaders/_custom_dataloders.py b/src/scvi/dataloaders/_custom_dataloders.py new file mode 100644 index 0000000000..42a89ca161 --- /dev/null +++ b/src/scvi/dataloaders/_custom_dataloders.py @@ -0,0 +1,194 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import psutil +from lightning.pytorch import LightningDataModule +from torch.utils.data import DataLoader + +import scvi + +if TYPE_CHECKING: + import lamindb as ln + import numpy as np + + +class MappedCollectionDataModule(LightningDataModule): + def __init__( + self, + collection: ln.Collection, + batch_key: str | None = None, + label_key: str | None = None, + batch_size: int = 128, + **kwargs, + ): + self._batch_size = batch_size + self._batch_key = batch_key + self._label_key = label_key + self._parallel = kwargs.pop("parallel", True) + # here we initialize MappedCollection to use in a pytorch DataLoader + self._dataset = collection.mapped( + obs_keys=self._batch_key, parallel=self._parallel, **kwargs + ) + # need by scvi and lightning.pytorch + self._log_hyperparams = False + self.allow_zero_length_dataloader_with_multiple_devices = False + + def close(self): + self._dataset.close() + + def setup(self, stage): + pass + + def train_dataloader(self): + return self._create_dataloader(shuffle=True) + + def inference_dataloader(self): + """Dataloader for inference with `on_before_batch_transfer` applied.""" + dataloader = self._create_dataloader(shuffle=False, batch_size=4096) + return self._InferenceDataloader(dataloader, self.on_before_batch_transfer) + + def _create_dataloader(self, shuffle, batch_size=None): + if self._parallel: + num_workers = psutil.cpu_count() - 1 + worker_init_fn = self._dataset.torch_worker_init_fn + else: + num_workers = 0 + worker_init_fn = None + if batch_size is None: + batch_size = self._batch_size + return DataLoader( + self._dataset, + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + worker_init_fn=worker_init_fn, + ) + + @property + def n_obs(self) -> int: + return self._dataset.n_obs + + @property + def var_names(self) -> int: + return self._dataset.var_joint + + @property + def n_vars(self) -> int: + return self._dataset.n_vars + + @property + def n_batch(self) -> int: + if self._batch_key is None: + return 1 + return len(self._dataset.encoders[self._batch_key]) + + @property + def n_labels(self) -> int: + if self._label_key is None: + return 1 + return len(self._dataset.encoders[self._label_key]) + + @property + def labels(self) -> np.ndarray: + return self._dataset[self._label_key] + + @property + def registry(self) -> dict: + return { + "scvi_version": scvi.__version__, + "model_name": "SCVI", + "setup_args": { + "layer": None, + "batch_key": self._batch_key, + "labels_key": self._label_key, + "size_factor_key": None, + "categorical_covariate_keys": None, + "continuous_covariate_keys": None, + }, + "field_registries": { + "X": { + "data_registry": {"attr_name": "X", "attr_key": None}, + "state_registry": { + "n_obs": self.n_obs, + "n_vars": self.n_vars, + "column_names": self.var_names, + }, + "summary_stats": {"n_vars": self.n_vars, "n_cells": self.n_obs}, + }, + "batch": { + "data_registry": {"attr_name": "obs", "attr_key": "_scvi_batch"}, + "state_registry": { + "categorical_mapping": self.batch_keys, + "original_key": self._batch_key, + }, + "summary_stats": {"n_batch": self.n_batch}, + }, + "labels": { + "data_registry": {"attr_name": "obs", "attr_key": "_scvi_labels"}, + "state_registry": { + "categorical_mapping": self.label_keys, + "original_key": self._label_key, + "unlabeled_category": "unlabeled", + }, + "summary_stats": {"n_labels": self.n_labels}, + }, + "size_factor": { + "data_registry": {}, + "state_registry": {}, + "summary_stats": {}, + }, + "extra_categorical_covs": { + "data_registry": {}, + "state_registry": {}, + "summary_stats": {"n_extra_categorical_covs": 0}, + }, + "extra_continuous_covs": { + "data_registry": {}, + "state_registry": {}, + "summary_stats": {"n_extra_continuous_covs": 0}, + }, + }, + "setup_method_name": "setup_anndata", + } + + @property + def batch_keys(self) -> int: + if self._batch_key is None: + return None + return self._dataset.encoders[self._batch_key] + + @property + def label_keys(self) -> int: + if self._label_key is None: + return None + return self._dataset.encoders[self._label_key] + + def on_before_batch_transfer( + self, + batch, + dataloader_idx, + ): + X_KEY: str = "X" + BATCH_KEY: str = "batch" + LABEL_KEY: str = "labels" + + return { + X_KEY: batch["X"].float(), + BATCH_KEY: batch[self._batch_key][:, None] if self._batch_key is not None else None, + LABEL_KEY: 0, + } + + class _InferenceDataloader: + """Wrapper to apply `on_before_batch_transfer` during iteration.""" + + def __init__(self, dataloader, transform_fn): + self.dataloader = dataloader + self.transform_fn = transform_fn + + def __iter__(self): + for batch in self.dataloader: + yield self.transform_fn(batch, dataloader_idx=None) + + def __len__(self): + return len(self.dataloader) diff --git a/src/scvi/dataloaders/_data_splitting.py b/src/scvi/dataloaders/_data_splitting.py index 9ea0146acb..4b4b14da63 100644 --- a/src/scvi/dataloaders/_data_splitting.py +++ b/src/scvi/dataloaders/_data_splitting.py @@ -386,7 +386,8 @@ class is :class:`~scvi.dataloaders.SemiSupervisedDataLoader`, def __init__( self, - adata_manager: AnnDataManager, + adata_manager: AnnDataManager | None = None, + datamodule: pl.LightningDataModule | None = None, train_size: float | None = None, validation_size: float | None = None, shuffle_set_split: bool = True, diff --git a/src/scvi/external/resolvi/_model.py b/src/scvi/external/resolvi/_model.py index 3a5dbffbd0..d38a4ba9c9 100644 --- a/src/scvi/external/resolvi/_model.py +++ b/src/scvi/external/resolvi/_model.py @@ -98,7 +98,8 @@ class RESOLVI( def __init__( self, - adata: AnnData, + adata: AnnData | None, + registry: dict | None = None, n_hidden: int = 32, n_hidden_encoder: int = 128, n_latent: int = 10, diff --git a/src/scvi/external/stereoscope/_model.py b/src/scvi/external/stereoscope/_model.py index 05e1ad0bf7..99e6639f1b 100644 --- a/src/scvi/external/stereoscope/_model.py +++ b/src/scvi/external/stereoscope/_model.py @@ -53,7 +53,8 @@ class RNAStereoscope(UnsupervisedTrainingMixin, BaseModelClass): def __init__( self, - sc_adata: AnnData, + sc_adata: AnnData | None = None, + registry: dict | None = None, **model_kwargs, ): super().__init__(sc_adata) diff --git a/src/scvi/external/stereoscope/_module.py b/src/scvi/external/stereoscope/_module.py index eefb2eb139..f74977d3ec 100644 --- a/src/scvi/external/stereoscope/_module.py +++ b/src/scvi/external/stereoscope/_module.py @@ -140,6 +140,7 @@ def __init__( n_spots: int, sc_params: tuple[np.ndarray], prior_weight: Literal["n_obs", "minibatch"] = "n_obs", + **model_kwargs, ): super().__init__() # unpack and copy parameters diff --git a/src/scvi/model/_amortizedlda.py b/src/scvi/model/_amortizedlda.py index 1817d96fc0..b2f7030914 100644 --- a/src/scvi/model/_amortizedlda.py +++ b/src/scvi/model/_amortizedlda.py @@ -61,7 +61,8 @@ class AmortizedLDA(PyroSviTrainMixin, BaseModelClass): def __init__( self, - adata: AnnData, + adata: AnnData | None = None, + registry: dict | None = None, n_topics: int = 20, n_hidden: int = 128, cell_topic_prior: float | Sequence[float] | None = None, diff --git a/src/scvi/model/_autozi.py b/src/scvi/model/_autozi.py index 08b4e35131..b5cb50e6b3 100644 --- a/src/scvi/model/_autozi.py +++ b/src/scvi/model/_autozi.py @@ -104,7 +104,8 @@ class AUTOZI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass): def __init__( self, - adata: AnnData, + adata: AnnData | None = None, + registry: dict | None = None, n_hidden: int = 128, n_latent: int = 10, n_layers: int = 1, diff --git a/src/scvi/model/_condscvi.py b/src/scvi/model/_condscvi.py index b5e49a711d..38730fb735 100644 --- a/src/scvi/model/_condscvi.py +++ b/src/scvi/model/_condscvi.py @@ -67,7 +67,8 @@ class CondSCVI(RNASeqMixin, VAEMixin, UnsupervisedTrainingMixin, BaseModelClass) def __init__( self, - adata: AnnData, + adata: AnnData | None = None, + registry: dict | None = None, n_hidden: int = 128, n_latent: int = 5, n_layers: int = 2, diff --git a/src/scvi/model/_jaxscvi.py b/src/scvi/model/_jaxscvi.py index 2a212aad53..e82847c6f5 100644 --- a/src/scvi/model/_jaxscvi.py +++ b/src/scvi/model/_jaxscvi.py @@ -59,7 +59,8 @@ class JaxSCVI(JaxTrainingMixin, BaseModelClass): def __init__( self, - adata: AnnData, + adata: AnnData | None = None, + registry: dict | None = None, n_hidden: int = 128, n_latent: int = 10, dropout_rate: float = 0.1, diff --git a/src/scvi/model/_linear_scvi.py b/src/scvi/model/_linear_scvi.py index 0cbefc3968..cb09e3ce80 100644 --- a/src/scvi/model/_linear_scvi.py +++ b/src/scvi/model/_linear_scvi.py @@ -78,7 +78,8 @@ class LinearSCVI(RNASeqMixin, VAEMixin, UnsupervisedTrainingMixin, BaseModelClas def __init__( self, - adata: AnnData, + adata: AnnData | None = None, + registry: dict | None = None, n_hidden: int = 128, n_latent: int = 10, n_layers: int = 1, diff --git a/src/scvi/model/_multivi.py b/src/scvi/model/_multivi.py index 61ffb04556..1450ea7a12 100644 --- a/src/scvi/model/_multivi.py +++ b/src/scvi/model/_multivi.py @@ -145,6 +145,7 @@ class MULTIVI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass, ArchesMixin): def __init__( self, adata: AnnOrMuData, + registry: dict | None = None, n_genes: int | None = None, n_regions: int | None = None, modality_weights: Literal["equal", "cell", "universal"] = "equal", diff --git a/src/scvi/model/_peakvi.py b/src/scvi/model/_peakvi.py index f10b39c771..2ec7e18258 100644 --- a/src/scvi/model/_peakvi.py +++ b/src/scvi/model/_peakvi.py @@ -92,7 +92,8 @@ class PEAKVI(ArchesMixin, VAEMixin, UnsupervisedTrainingMixin, BaseModelClass): def __init__( self, - adata: AnnData, + adata: AnnData | None = None, + registry: dict | None = None, n_hidden: int | None = None, n_latent: int | None = None, n_layers_encoder: int = 2, diff --git a/src/scvi/model/_scanvi.py b/src/scvi/model/_scanvi.py index 87a14dbd54..59f8034ac3 100644 --- a/src/scvi/model/_scanvi.py +++ b/src/scvi/model/_scanvi.py @@ -8,7 +8,9 @@ import numpy as np import pandas as pd import torch +from anndata import AnnData +import scvi from scvi import REGISTRY_KEYS, settings from scvi.data import AnnDataManager from scvi.data._constants import ( @@ -39,9 +41,14 @@ from typing import Literal from anndata import AnnData + from lightning import LightningDataModule from ._scvi import SCVI +_SCANVI_LATENT_QZM = "_scanvi_latent_qzm" +_SCANVI_LATENT_QZV = "_scanvi_latent_qzv" +_SCANVI_OBSERVED_LIB_SIZE = "_scanvi_observed_lib_size" + logger = logging.getLogger(__name__) @@ -75,6 +82,9 @@ class SCANVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseMinifiedModeModelClass): * ``'nb'`` - Negative binomial distribution * ``'zinb'`` - Zero-inflated negative binomial distribution * ``'poisson'`` - Poisson distribution + use_observed_lib_size + If ``True``, use the observed library size for RNA as the scaling factor in the mean of the + conditional distribution. linear_classifier If ``True``, uses a single linear layer for classification instead of a multi-layer perceptron. @@ -106,35 +116,45 @@ class SCANVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseMinifiedModeModelClass): def __init__( self, - adata: AnnData, + adata: AnnData | None, + registry: dict | None = None, n_hidden: int = 128, n_latent: int = 10, n_layers: int = 1, dropout_rate: float = 0.1, dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene", gene_likelihood: Literal["zinb", "nb", "poisson"] = "zinb", + use_observed_lib_size: bool = True, linear_classifier: bool = False, + datamodule: LightningDataModule | None = None, **model_kwargs, ): - super().__init__(adata) + super().__init__(adata, registry) scanvae_model_kwargs = dict(model_kwargs) - self._set_indices_and_labels() + self._set_indices_and_labels(datamodule) # ignores unlabeled catgegory n_labels = self.summary_stats.n_labels - 1 - n_cats_per_cov = ( - self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key - if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry - else None - ) + if adata is not None: + n_cats_per_cov = ( + self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key + if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry + else None + ) + else: + # custom datamodule + n_cats_per_cov = self.summary_stats[f"n_{REGISTRY_KEYS.CAT_COVS_KEY}"] + if n_cats_per_cov == 0: + n_cats_per_cov = None n_batch = self.summary_stats.n_batch - use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry + use_size_factor_key = self.registry_["setup_args"][f"{REGISTRY_KEYS.SIZE_FACTOR_KEY}_key"] library_log_means, library_log_vars = None, None if ( not use_size_factor_key and self.minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR + and not use_observed_lib_size ): library_log_means, library_log_vars = _init_library_size(self.adata_manager, n_batch) @@ -151,6 +171,7 @@ def __init__( dispersion=dispersion, gene_likelihood=gene_likelihood, use_size_factor_key=use_size_factor_key, + use_observed_lib_size=use_observed_lib_size, library_log_means=library_log_means, library_log_vars=library_log_vars, linear_classifier=linear_classifier, @@ -178,6 +199,7 @@ def from_scvi_model( unlabeled_category: str, labels_key: str | None = None, adata: AnnData | None = None, + registry: dict | None = None, **scanvi_kwargs, ): """Initialize scanVI model with weights from pretrained :class:`~scvi.model.SCVI` model. @@ -194,6 +216,8 @@ def from_scvi_model( Value used for unlabeled cells in `labels_key` used to setup AnnData with scvi. adata AnnData object that has been registered via :meth:`~scvi.model.SCANVI.setup_anndata`. + registry + Registry of the datamodule used to train scANVI model. scanvi_kwargs kwargs for scANVI model """ @@ -223,13 +247,15 @@ def from_scvi_model( if adata is None: adata = scvi_model.adata - else: + elif adata: if _is_minified(adata): raise ValueError("Please provide a non-minified `adata` to initialize scANVI.") # validate new anndata against old model scvi_model._validate_anndata(adata) + else: + adata = None - scvi_setup_args = deepcopy(scvi_model.adata_manager.registry[_SETUP_ARGS_KEY]) + scvi_setup_args = deepcopy(scvi_model.registry[_SETUP_ARGS_KEY]) scvi_labels_key = scvi_setup_args["labels_key"] if labels_key is None and scvi_labels_key is None: raise ValueError( @@ -237,35 +263,40 @@ def from_scvi_model( ) if scvi_labels_key is None: scvi_setup_args.update({"labels_key": labels_key}) - cls.setup_anndata( - adata, - unlabeled_category=unlabeled_category, - use_minified=False, - **scvi_setup_args, - ) - scanvi_model = cls(adata, **non_kwargs, **kwargs, **scanvi_kwargs) + if adata is not None: + cls.setup_anndata( + adata, + unlabeled_category=unlabeled_category, + use_minified=False, + **scvi_setup_args, + ) + + scanvi_model = cls(adata, scvi_model.registry, **non_kwargs, **kwargs, **scanvi_kwargs) scvi_state_dict = scvi_model.module.state_dict() scanvi_model.module.load_state_dict(scvi_state_dict, strict=False) scanvi_model.was_pretrained = True return scanvi_model - def _set_indices_and_labels(self): + def _set_indices_and_labels(self, datamodule=None): """Set indices for labeled and unlabeled cells.""" - labels_state_registry = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY) + labels_state_registry = self.get_state_registry(REGISTRY_KEYS.LABELS_KEY) self.original_label_key = labels_state_registry.original_key self.unlabeled_category_ = labels_state_registry.unlabeled_category - labels = get_anndata_attribute( - self.adata, - self.adata_manager.data_registry.labels.attr_name, - self.original_label_key, - ).ravel() + if datamodule is None: + self.labels_ = get_anndata_attribute( + self.adata, + self.adata_manager.data_registry.labels.attr_name, + self.original_label_key, + ).ravel() + else: + self.labels_ = datamodule.labels.ravel() self._label_mapping = labels_state_registry.categorical_mapping # set unlabeled and labeled indices - self._unlabeled_indices = np.argwhere(labels == self.unlabeled_category_).ravel() - self._labeled_indices = np.argwhere(labels != self.unlabeled_category_).ravel() + self._unlabeled_indices = np.argwhere(self.labels_ == self.unlabeled_category_).ravel() + self._labeled_indices = np.argwhere(self.labels_ != self.unlabeled_category_).ravel() self._code_to_label = dict(enumerate(self._label_mapping)) def predict( @@ -357,6 +388,7 @@ def train( devices: int | list[int] | str = "auto", datasplitter_kwargs: dict | None = None, plan_kwargs: dict | None = None, + datamodule: LightningDataModule | None = None, **trainer_kwargs, ): """Train the model. @@ -391,6 +423,10 @@ def train( plan_kwargs Keyword args for :class:`~scvi.train.SemiSupervisedTrainingPlan`. Keyword arguments passed to `train()` will overwrite values present in `plan_kwargs`, when appropriate. + datamodule + ``EXPERIMENTAL`` A :class:`~lightning.pytorch.core.LightningDataModule` instance to use + for training in place of the default :class:`~scvi.dataloaders.DataSplitter`. Can only + be passed in if the model was not initialized with :class:`~anndata.AnnData`. **trainer_kwargs Other keyword args for :class:`~scvi.train.Trainer`. """ @@ -406,17 +442,24 @@ def train( datasplitter_kwargs = datasplitter_kwargs or {} # if we have labeled cells, we want to subsample labels each epoch - sampler_callback = [SubSampleLabels()] if len(self._labeled_indices) != 0 else [] - - data_splitter = SemiSupervisedDataSplitter( - adata_manager=self.adata_manager, - train_size=train_size, - validation_size=validation_size, - shuffle_set_split=shuffle_set_split, - n_samples_per_label=n_samples_per_label, - batch_size=batch_size, - **datasplitter_kwargs, - ) + if datamodule is None: + sampler_callback = [SubSampleLabels()] if len(self._labeled_indices) != 0 else [] + # In the general case we enter here + datasplitter_kwargs = datasplitter_kwargs or {} + datamodule = SemiSupervisedDataSplitter( + adata_manager=self.adata_manager, + datamodule=datamodule, + train_size=train_size, + validation_size=validation_size, + shuffle_set_split=shuffle_set_split, + n_samples_per_label=n_samples_per_label, + batch_size=batch_size, + **datasplitter_kwargs, + ) + else: + # TODO fix in external dataloader? + sampler_callback = [] + training_plan = self._training_plan_cls(self.module, self.n_labels, **plan_kwargs) if "callbacks" in trainer_kwargs.keys(): trainer_kwargs["callbacks"] + [sampler_callback] @@ -426,7 +469,7 @@ def train( runner = TrainRunner( self, training_plan=training_plan, - data_splitter=data_splitter, + data_splitter=datamodule, max_epochs=max_epochs, accelerator=accelerator, devices=devices, @@ -475,9 +518,117 @@ def setup_anndata( NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys), ] # register new fields if the adata is minified - adata_minify_type = _get_adata_minify_type(adata) - if adata_minify_type is not None and use_minified: - anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type) - adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) - adata_manager.register_fields(adata, **kwargs) - cls.register_manager(adata_manager) + if adata: + adata_minify_type = _get_adata_minify_type(adata) + if adata_minify_type is not None and use_minified: + anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type) + adata_manager = AnnDataManager( + fields=anndata_fields, setup_method_args=setup_method_args + ) + adata_manager.register_fields(adata, **kwargs) + cls.register_manager(adata_manager) + + @classmethod + @setup_anndata_dsp.dedent + def setup_datamodule( + cls, + datamodule: LightningDataModule | None = None, + source_registry=None, + layer: str | None = None, + batch_key: list[str] | None = None, + labels_key: str | None = None, + size_factor_key: str | None = None, + categorical_covariate_keys: list[str] | None = None, + continuous_covariate_keys: list[str] | None = None, + **kwargs, + ): + """%(summary)s. + + Parameters + ---------- + %(param_datamodule)s + %(param_source_registry)s + %(param_layer)s + %(param_batch_key)s + %(param_size_factor_key)s + %(param_cat_cov_keys)s + %(param_cont_cov_keys)s + """ + if datamodule.__class__.__name__ == "CensusSCVIDataModule": + # CZI + batch_mapping = datamodule.datapipe.obs_encoders["batch"].classes_ + labels_mapping = datamodule.datapipe.obs_encoders["label"].classes_ + features_names = list( + datamodule.datapipe.var_query.coords[0] + if datamodule.datapipe.var_query is not None + else range(datamodule.n_vars) + ) + n_batch = datamodule.n_batch + n_label = datamodule.n_label + + else: + # Anndata -> CZI + # if we are here and datamodule is actually an AnnData object + # it means we init the custom dataloder model with anndata + batch_mapping = source_registry["field_registries"]["batch"]["state_registry"][ + "categorical_mapping" + ] + labels_mapping = source_registry["field_registries"]["label"]["state_registry"][ + "categorical_mapping" + ] + features_names = datamodule.var.soma_joinid.values + n_batch = source_registry["field_registries"]["batch"]["summary_stats"]["n_batch"] + n_label = 1 # need to change + + datamodule.registry = { + "scvi_version": scvi.__version__, + "model_name": "SCVI", + "setup_args": { + "layer": layer, + "batch_key": batch_key, + "labels_key": labels_key, + "size_factor_key": size_factor_key, + "categorical_covariate_keys": categorical_covariate_keys, + "continuous_covariate_keys": continuous_covariate_keys, + }, + "field_registries": { + "X": { + "data_registry": {"attr_name": "X", "attr_key": None}, + "state_registry": { + "n_obs": datamodule.n_obs, + "n_vars": datamodule.n_vars, + "column_names": [str(i) for i in features_names], + }, + "summary_stats": {"n_vars": datamodule.n_vars, "n_cells": datamodule.n_obs}, + }, + "batch": { + "data_registry": {"attr_name": "obs", "attr_key": "_scvi_batch"}, + "state_registry": { + "categorical_mapping": batch_mapping, + "original_key": "batch", + }, + "summary_stats": {"n_batch": n_batch}, + }, + "labels": { + "data_registry": {"attr_name": "obs", "attr_key": "_scvi_labels"}, + "state_registry": { + "categorical_mapping": labels_mapping, + "original_key": "label", + "unlabeled_category": datamodule.unlabeled_category, + }, + "summary_stats": {"n_labels": n_label}, + }, + "size_factor": {"data_registry": {}, "state_registry": {}, "summary_stats": {}}, + "extra_categorical_covs": { + "data_registry": {}, + "state_registry": {}, + "summary_stats": {"n_extra_categorical_covs": 0}, + }, + "extra_continuous_covs": { + "data_registry": {}, + "state_registry": {}, + "summary_stats": {"n_extra_continuous_covs": 0}, + }, + }, + "setup_method_name": "setup_datamodule", + } diff --git a/src/scvi/model/_scvi.py b/src/scvi/model/_scvi.py index 1eb23aa138..f89c3723c2 100644 --- a/src/scvi/model/_scvi.py +++ b/src/scvi/model/_scvi.py @@ -4,6 +4,9 @@ import warnings from typing import TYPE_CHECKING +import numpy as np + +import scvi from scvi import REGISTRY_KEYS, settings from scvi.data import AnnDataManager from scvi.data._constants import ADATA_MINIFY_TYPE @@ -26,6 +29,12 @@ from typing import Literal from anndata import AnnData + from lightning import LightningDataModule + + +_SCVI_LATENT_QZM = "_scvi_latent_qzm" +_SCVI_LATENT_QZV = "_scvi_latent_qzv" +_SCVI_OBSERVED_LIB_SIZE = "_scvi_observed_lib_size" logger = logging.getLogger(__name__) @@ -69,6 +78,9 @@ class SCVI( * ``'zinb'`` - Zero-inflated negative binomial distribution * ``'poisson'`` - Poisson distribution * ``'normal'`` - ``EXPERIMENTAL`` Normal distribution + use_observed_lib_size + If ``True``, use the observed library size for RNA as the scaling factor in the mean of the + conditional distribution. latent_distribution One of: @@ -106,17 +118,19 @@ class SCVI( def __init__( self, - adata: AnnData | None = None, + adata: AnnData | None, + registry: dict | None = None, n_hidden: int = 128, n_latent: int = 10, n_layers: int = 1, dropout_rate: float = 0.1, dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene", gene_likelihood: Literal["zinb", "nb", "poisson", "normal"] = "zinb", + use_observed_lib_size: bool = True, latent_distribution: Literal["normal", "ln"] = "normal", **kwargs, ): - super().__init__(adata) + super().__init__(adata, registry) self._module_kwargs = { "n_hidden": n_hidden, @@ -134,6 +148,7 @@ def __init__( f"dropout_rate: {dropout_rate}, dispersion: {dispersion}, " f"gene_likelihood: {gene_likelihood}, latent_distribution: {latent_distribution}." ) + self._module_init_on_train = False if self._module_init_on_train: self.module = None @@ -144,17 +159,29 @@ def __init__( stacklevel=settings.warnings_stacklevel, ) else: - n_cats_per_cov = ( - self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key - if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry - else None - ) + if adata is not None: + n_cats_per_cov = ( + self.adata_manager.get_state_registry( + REGISTRY_KEYS.CAT_COVS_KEY + ).n_cats_per_key + if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry + else None + ) + else: + # custom datamodule + n_cats_per_cov = self.summary_stats[f"n_{REGISTRY_KEYS.CAT_COVS_KEY}"] + if n_cats_per_cov == 0: + n_cats_per_cov = None + n_batch = self.summary_stats.n_batch - use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry + use_size_factor_key = self.registry_["setup_args"][ + f"{REGISTRY_KEYS.SIZE_FACTOR_KEY}_key" + ] library_log_means, library_log_vars = None, None if ( not use_size_factor_key and self.minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR + and not use_observed_lib_size ): library_log_means, library_log_vars = _init_library_size( self.adata_manager, n_batch @@ -171,6 +198,7 @@ def __init__( dropout_rate=dropout_rate, dispersion=dispersion, gene_likelihood=gene_likelihood, + use_observed_lib_size=use_observed_lib_size, latent_distribution=latent_distribution, use_size_factor_key=use_size_factor_key, library_log_means=library_log_means, @@ -222,3 +250,98 @@ def setup_anndata( adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) + + @classmethod + @setup_anndata_dsp.dedent + def setup_datamodule( + cls, + datamodule: LightningDataModule | None = None, + source_registry=None, + layer: str | None = None, + batch_key: list[str] | None = None, + labels_key: str | None = None, + size_factor_key: str | None = None, + categorical_covariate_keys: list[str] | None = None, + continuous_covariate_keys: list[str] | None = None, + **kwargs, + ): + """%(summary)s. + + Parameters + ---------- + %(param_datamodule)s + %(param_source_registry)s + %(param_layer)s + %(param_batch_key)s + %(param_labels_key)s + %(param_size_factor_key)s + %(param_cat_cov_keys)s + %(param_cont_cov_keys)s + """ + if datamodule.__class__.__name__ == "CensusSCVIDataModule": + # CZI + categorical_mapping = datamodule.datapipe.obs_encoders["batch"].classes_ + column_names = list( + datamodule.datapipe.var_query.coords[0] + if datamodule.datapipe.var_query is not None + else range(datamodule.n_vars) + ) + n_batch = datamodule.n_batch + else: + categorical_mapping = source_registry["field_registries"]["batch"]["state_registry"][ + "categorical_mapping" + ] + column_names = datamodule.var_names + n_batch = source_registry["field_registries"]["batch"]["summary_stats"]["n_batch"] + + datamodule.registry = { + "scvi_version": scvi.__version__, + "model_name": "SCVI", + "setup_args": { + "layer": layer, + "batch_key": batch_key, + "labels_key": labels_key, + "size_factor_key": size_factor_key, + "categorical_covariate_keys": categorical_covariate_keys, + "continuous_covariate_keys": continuous_covariate_keys, + }, + "field_registries": { + "X": { + "data_registry": {"attr_name": "X", "attr_key": None}, + "state_registry": { + "n_obs": datamodule.n_obs, + "n_vars": datamodule.n_vars, + "column_names": [str(i) for i in column_names], + }, + "summary_stats": {"n_vars": datamodule.n_vars, "n_cells": datamodule.n_obs}, + }, + "batch": { + "data_registry": {"attr_name": "obs", "attr_key": "_scvi_batch"}, + "state_registry": { + "categorical_mapping": categorical_mapping, + "original_key": "batch", + }, + "summary_stats": {"n_batch": n_batch}, + }, + "labels": { + "data_registry": {"attr_name": "obs", "attr_key": "_scvi_labels"}, + "state_registry": { + "categorical_mapping": np.array([0]), + "original_key": "_scvi_labels", + }, + "summary_stats": {"n_labels": 1}, + }, + "size_factor": {"data_registry": {}, "state_registry": {}, "summary_stats": {}}, + "extra_categorical_covs": { + "data_registry": {}, + "state_registry": {}, + "summary_stats": {"n_extra_categorical_covs": 0}, + }, + "extra_continuous_covs": { + "data_registry": {}, + "state_registry": {}, + "summary_stats": {"n_extra_continuous_covs": 0}, + }, + }, + "setup_method_name": "setup_datamodule", + } diff --git a/src/scvi/model/_totalvi.py b/src/scvi/model/_totalvi.py index d23c533da6..fab5e8bb31 100644 --- a/src/scvi/model/_totalvi.py +++ b/src/scvi/model/_totalvi.py @@ -124,6 +124,7 @@ class TOTALVI( def __init__( self, adata: AnnOrMuData, + registry: dict | None = None, n_latent: int = 20, gene_dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene", protein_dispersion: Literal["protein", "protein-batch", "protein-label"] = "protein", @@ -1224,7 +1225,8 @@ def get_protein_background_mean(self, adata, indices, batch_size): def setup_anndata( cls, adata: AnnData, - protein_expression_obsm_key: str, + registry: dict | None = None, + protein_expression_obsm_key: str | None = None, protein_names_uns_key: str | None = None, batch_key: str | None = None, layer: str | None = None, diff --git a/src/scvi/model/base/_archesmixin.py b/src/scvi/model/base/_archesmixin.py index 910809f8dc..841d4744b6 100644 --- a/src/scvi/model/base/_archesmixin.py +++ b/src/scvi/model/base/_archesmixin.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import logging import warnings from copy import deepcopy +from typing import TYPE_CHECKING import anndata import numpy as np @@ -13,12 +16,10 @@ from torch.distributions import transform_to from scvi import settings -from scvi._types import AnnOrMuData from scvi.data import _constants from scvi.data._constants import _MODEL_NAME_KEY, _SETUP_ARGS_KEY, _SETUP_METHOD_NAME from scvi.model._utils import parse_device_args from scvi.model.base._save_load import ( - _get_var_names, _initialize_model, _load_saved_files, _validate_var_names, @@ -26,7 +27,10 @@ from scvi.nn import FCLayers from scvi.utils._docstrings import devices_dsp -from ._base_model import BaseModelClass +if TYPE_CHECKING: + from scvi._types import AnnOrMuData + + from ._base_model import BaseModelClass logger = logging.getLogger(__name__) @@ -40,8 +44,9 @@ class ArchesMixin: @devices_dsp.dedent def load_query_data( cls, - adata: AnnOrMuData, - reference_model: str | BaseModelClass, + adata: AnnOrMuData = None, + reference_model: str | BaseModelClass = None, + registry: dict = None, inplace_subset_query_vars: bool = False, accelerator: str = "auto", device: int | str = "auto", @@ -84,6 +89,11 @@ def load_query_data( freeze_classifier Whether to freeze classifier completely. Only applies to `SCANVI`. """ + if reference_model is None: + raise ValueError("Please provide a reference model as string or loaded model.") + if adata is None and registry is None: + raise ValueError("Please provide either an AnnData or a registry dictionary.") + _, _, device = parse_device_args( accelerator=accelerator, devices=device, @@ -92,50 +102,51 @@ def load_query_data( ) attr_dict, var_names, load_state_dict, pyro_param_store = _get_loaded_data( - reference_model, device=device + reference_model, device=device, adata=adata ) - if isinstance(adata, MuData): - for modality in adata.mod: + if adata: + if isinstance(adata, MuData): + for modality in adata.mod: + if inplace_subset_query_vars: + logger.debug(f"Subsetting {modality} query vars to reference vars.") + adata[modality]._inplace_subset_var(var_names[modality]) + _validate_var_names(adata[modality], var_names[modality]) + + else: if inplace_subset_query_vars: - logger.debug(f"Subsetting {modality} query vars to reference vars.") - adata[modality]._inplace_subset_var(var_names[modality]) - _validate_var_names(adata[modality], var_names[modality]) + logger.debug("Subsetting query vars to reference vars.") + adata._inplace_subset_var(var_names) + _validate_var_names(adata, var_names) - else: if inplace_subset_query_vars: logger.debug("Subsetting query vars to reference vars.") adata._inplace_subset_var(var_names) _validate_var_names(adata, var_names) - if inplace_subset_query_vars: - logger.debug("Subsetting query vars to reference vars.") - adata._inplace_subset_var(var_names) - _validate_var_names(adata, var_names) + registry = attr_dict.pop("registry_") + if _MODEL_NAME_KEY in registry and registry[_MODEL_NAME_KEY] != cls.__name__: + raise ValueError("It appears you are loading a model from a different class.") - registry = attr_dict.pop("registry_") - if _MODEL_NAME_KEY in registry and registry[_MODEL_NAME_KEY] != cls.__name__: - raise ValueError("It appears you are loading a model from a different class.") + if _SETUP_ARGS_KEY not in registry: + raise ValueError( + "Saved model does not contain original setup inputs. " + "Cannot load the original setup." + ) - if _SETUP_ARGS_KEY not in registry: - raise ValueError( - "Saved model does not contain original setup inputs. " - "Cannot load the original setup." + setup_method = getattr(cls, registry[_SETUP_METHOD_NAME]) + setup_method( + adata, + source_registry=registry, + extend_categories=True, + allow_missing_labels=True, + **registry[_SETUP_ARGS_KEY], ) - setup_method = getattr(cls, registry[_SETUP_METHOD_NAME]) - setup_method( - adata, - source_registry=registry, - extend_categories=True, - allow_missing_labels=True, - **registry[_SETUP_ARGS_KEY], - ) + model = _initialize_model(cls, adata, registry, attr_dict) - model = _initialize_model(cls, adata, attr_dict) - adata_manager = model.get_anndata_manager(adata, required=True) + version_split = model.registry[_constants._SCVI_VERSION_KEY].split(".") - version_split = adata_manager.registry[_constants._SCVI_VERSION_KEY].split(".") if int(version_split[1]) < 8 and int(version_split[0]) == 0: warnings.warn( "Query integration should be performed using models trained with version >= 0.8", @@ -152,6 +163,12 @@ def load_query_data( load_ten = load_ten.to(new_ten.device) if new_ten.size() == load_ten.size(): continue + # new categoricals changed size + else: + dim_diff = new_ten.size()[-1] - load_ten.size()[-1] + fixed_ten = torch.cat([load_ten, new_ten[..., -dim_diff:]], dim=-1) + load_state_dict[key] = fixed_ten + # TODO VERIFY THIS! fixed_ten = load_ten.clone() for dim in range(len(new_ten.shape)): if new_ten.size(dim) != load_ten.size(dim): @@ -405,7 +422,7 @@ def requires_grad(key): par.requires_grad = False -def _get_loaded_data(reference_model, device=None): +def _get_loaded_data(reference_model, device=None, adata=None): if isinstance(reference_model, str): attr_dict, var_names, load_state_dict, _ = _load_saved_files( reference_model, load_adata=False, map_location=device @@ -414,7 +431,7 @@ def _get_loaded_data(reference_model, device=None): else: attr_dict = reference_model._get_user_attributes() attr_dict = {a[0]: a[1] for a in attr_dict if a[0][-1] == "_"} - var_names = _get_var_names(reference_model.adata) + var_names = reference_model.get_var_names() load_state_dict = deepcopy(reference_model.module.state_dict()) pyro_param_store = pyro.get_param_store().get_state() diff --git a/src/scvi/model/base/_base_model.py b/src/scvi/model/base/_base_model.py index 06c1b9a4c4..40db731c15 100644 --- a/src/scvi/model/base/_base_model.py +++ b/src/scvi/model/base/_base_model.py @@ -3,8 +3,10 @@ import inspect import logging import os +import sys import warnings from abc import ABCMeta, abstractmethod +from io import StringIO from typing import TYPE_CHECKING from uuid import uuid4 @@ -14,19 +16,29 @@ import torch from anndata import AnnData from mudata import MuData +from rich import box +from rich.console import Console from scvi import REGISTRY_KEYS, settings from scvi.data import AnnDataManager, fields from scvi.data._compat import registry_from_setup_dict from scvi.data._constants import ( _ADATA_MINIFY_TYPE_UNS_KEY, + _FIELD_REGISTRIES_KEY, _MODEL_NAME_KEY, _SCVI_UUID_KEY, + _SCVI_VERSION_KEY, _SETUP_ARGS_KEY, _SETUP_METHOD_NAME, + _STATE_REGISTRY_KEY, ADATA_MINIFY_TYPE, ) -from scvi.data._utils import _assign_adata_uuid, _check_if_view, _get_adata_minify_type +from scvi.data._utils import ( + _assign_adata_uuid, + _check_if_view, + _get_adata_minify_type, + _get_summary_stats_from_registry, +) from scvi.dataloaders import AnnDataLoader from scvi.model._utils import parse_device_args from scvi.model.base._constants import SAVE_KEYS @@ -40,9 +52,13 @@ from scvi.utils import attrdict, setup_anndata_dsp from scvi.utils._docstrings import devices_dsp +from . import _constants + if TYPE_CHECKING: from collections.abc import Sequence + import pandas as pd + from scvi._types import AnnOrMuData, MinifiedDataType logger = logging.getLogger(__name__) @@ -94,7 +110,7 @@ class BaseModelClass(metaclass=BaseModelMetaClass): _OBSERVED_LIB_SIZE_KEY = "observed_lib_size" _data_loader_cls = AnnDataLoader - def __init__(self, adata: AnnOrMuData | None = None): + def __init__(self, adata: AnnOrMuData | None = None, registry: object | None = None): # check if the given adata is minified and check if the model being created # supports minified-data mode (i.e. inherits from the abstract BaseMinifiedModeModelClass). # If not, raise an error to inform the user of the lack of minified-data functionality @@ -110,8 +126,19 @@ def __init__(self, adata: AnnOrMuData | None = None): self._adata_manager = self._get_most_recent_anndata_manager(adata, required=True) self._register_manager_for_instance(self.adata_manager) # Suffix registry instance variable with _ to include it when saving the model. - self.registry_ = self._adata_manager.registry - self.summary_stats = self._adata_manager.summary_stats + self.registry_ = self._adata_manager._registry + self.summary_stats = _get_summary_stats_from_registry(self.registry_) + elif registry is not None: + self._adata = None + self._adata_manager = None + # Suffix registry instance variable with _ to include it when saving the model. + self.registry_ = registry + self.summary_stats = _get_summary_stats_from_registry(registry) + elif self.__class__.__name__ == "GIMVI": + # note some models do accept empty registry/adata (e.g: gimvi) + pass + else: + raise ValueError("adata or registry must be provided.") self._module_init_on_train = adata is None self.is_trained_ = False @@ -122,10 +149,24 @@ def __init__(self, adata: AnnOrMuData | None = None): self.history_ = None @property - def adata(self) -> AnnOrMuData: + def adata(self) -> None | AnnOrMuData: """Data attached to model instance.""" return self._adata + @property + def registry(self) -> dict: + """Data attached to model instance.""" + return self.registry_ + + def get_var_names(self, legacy_mudata_format=False) -> dict: + """Variable names of input data.""" + from scvi.model.base._save_load import _get_var_names + + if self.adata: + return _get_var_names(self.adata, legacy_mudata_format=legacy_mudata_format) + else: + return self.registry[_FIELD_REGISTRIES_KEY]["X"][_STATE_REGISTRY_KEY]["column_names"] + @adata.setter def adata(self, adata: AnnOrMuData): if adata is None: @@ -247,6 +288,23 @@ def _register_manager_for_instance(self, adata_manager: AnnDataManager): instance_manager_store = self._per_instance_manager_store[self.id] instance_manager_store[adata_id] = adata_manager + def data_registry(self, registry_key: str) -> np.ndarray | pd.DataFrame: + """Returns the object in AnnData associated with the key in the data registry. + + Parameters + ---------- + registry_key + key of object to get from ``self.data_registry`` + + Returns + ------- + The requested data. + """ + if not self.adata: + raise ValueError("self.adata is None. Please register AnnData object to access data.") + else: + return self._adata_manager.get_from_registry(registry_key) + def deregister_manager(self, adata: AnnData | None = None): """Deregisters the :class:`~scvi.data.AnnDataManager` instance associated with `adata`. @@ -339,6 +397,9 @@ def get_anndata_manager( If True, errors on missing manager. Otherwise, returns None when manager is missing. """ cls = self.__class__ + if not adata: + return None + if _SCVI_UUID_KEY not in adata.uns: if required: raise ValueError( @@ -478,6 +539,13 @@ def _validate_anndata( return adata + def transfer_fields(self, adata: AnnOrMuData, **kwargs) -> AnnData: + """Transfer fields from a model to an AnnData object.""" + if self.adata: + return self.adata_manager.transfer_fields(adata, **kwargs) + else: + raise ValueError("Model need to be initialized with AnnData to transfer fields.") + def _check_if_trained(self, warn: bool = True, message: str = _UNTRAINED_WARNING_MESSAGE): """Check if the model is trained. @@ -540,7 +608,7 @@ def _get_user_attributes(self): def _get_init_params(self, locals): """Returns the model init signature with associated passed in values. - Ignores the initial AnnData. + Ignores the initial AnnData or Registry. """ init = self.__init__ sig = inspect.signature(init) @@ -551,7 +619,9 @@ def _get_init_params(self, locals): all_params = { k: v for (k, v) in all_params.items() - if not isinstance(v, AnnData) and not isinstance(v, MuData) + if not isinstance(v, AnnData) + and not isinstance(v, MuData) + and k not in ("adata", "registry") } # not very efficient but is explicit # separates variable params (**kwargs) from non variable params into two dicts @@ -606,8 +676,6 @@ def save( anndata_write_kwargs Kwargs for :meth:`~anndata.AnnData.write` """ - from scvi.model.base._save_load import _get_var_names - if not os.path.exists(dir_path) or overwrite: os.makedirs(dir_path, exist_ok=overwrite) else: @@ -635,7 +703,7 @@ def save( model_state_dict = self.module.state_dict() model_state_dict["pyro_param_store"] = pyro.get_param_store().get_state() - var_names = _get_var_names(self.adata, legacy_mudata_format=legacy_mudata_format) + var_names = self.get_var_names(legacy_mudata_format=legacy_mudata_format) # get all the user attributes user_attributes = self._get_user_attributes() @@ -674,6 +742,7 @@ def load( It is not necessary to run setup_anndata, as AnnData is validated against the saved `scvi` setup dictionary. If None, will check for and load anndata saved with the model. + If False, will load the model without AnnData. %(param_accelerator)s %(param_device)s prefix @@ -712,32 +781,32 @@ def load( ) adata = new_adata if new_adata is not None else adata - _validate_var_names(adata, var_names) - registry = attr_dict.pop("registry_") if _MODEL_NAME_KEY in registry and registry[_MODEL_NAME_KEY] != cls.__name__: raise ValueError("It appears you are loading a model from a different class.") - if _SETUP_ARGS_KEY not in registry: - raise ValueError( - "Saved model does not contain original setup inputs. " - "Cannot load the original setup." - ) - # Calling ``setup_anndata`` method with the original arguments passed into # the saved model. This enables simple backwards compatibility in the case of # newly introduced fields or parameters. - method_name = registry.get(_SETUP_METHOD_NAME, "setup_anndata") - getattr(cls, method_name)(adata, source_registry=registry, **registry[_SETUP_ARGS_KEY]) + if adata: + if _SETUP_ARGS_KEY not in registry: + raise ValueError( + "Saved model does not contain original setup inputs. " + "Cannot load the original setup." + ) + _validate_var_names(adata, var_names) + method_name = registry.get(_SETUP_METHOD_NAME, "setup_anndata") + getattr(cls, method_name)(adata, source_registry=registry, **registry[_SETUP_ARGS_KEY]) - model = _initialize_model(cls, adata, attr_dict) + model = _initialize_model(cls, adata, registry, attr_dict) pyro_param_store = model_state_dict.pop("pyro_param_store", None) model.module.on_load(model, pyro_param_store=pyro_param_store) model.module.load_state_dict(model_state_dict) model.to_device(device) model.module.eval() - model._validate_anndata(adata) + if adata: + model._validate_anndata(adata) return model @classmethod @@ -893,6 +962,149 @@ def view_anndata_setup( ) from err adata_manager.view_registry(hide_state_registries=hide_state_registries) + def view_setup_method_args(self) -> None: + """Prints setup kwargs used to produce a given registry. + + Parameters + ---------- + registry + Registry produced by an AnnDataManager. + """ + model_name = self.registry_[_MODEL_NAME_KEY] + setup_args = self.registry_[_SETUP_ARGS_KEY] + if model_name is not None and setup_args is not None: + rich.print(f"Setup via `{model_name}.setup_anndata` with arguments:") + rich.pretty.pprint(setup_args) + rich.print() + + def view_registry(self, hide_state_registries: bool = False) -> None: + """Prints summary of the registry. + + Parameters + ---------- + hide_state_registries + If True, prints a shortened summary without details of each state registry. + """ + version = self.registry_[_SCVI_VERSION_KEY] + rich.print(f"Anndata setup with scvi-tools version {version}.") + rich.print() + self.view_setup_method_args(self._registry) + + in_colab = "google.colab" in sys.modules + force_jupyter = None if not in_colab else True + console = rich.console.Console(force_jupyter=force_jupyter) + + ss = _get_summary_stats_from_registry(self._registry) + dr = self._get_data_registry_from_registry(self._registry) + console.print(self._view_summary_stats(ss)) + console.print(self._view_data_registry(dr)) + + if not hide_state_registries: + for field in self.fields: + state_registry = self.get_state_registry(field.registry_key) + t = field.view_state_registry(state_registry) + if t is not None: + console.print(t) + + def get_state_registry(self, registry_key: str) -> attrdict: + """Returns the state registry for the AnnDataField registered with this instance.""" + return attrdict(self.registry_[_FIELD_REGISTRIES_KEY][registry_key][_STATE_REGISTRY_KEY]) + + def get_setup_arg(self, setup_arg: str) -> attrdict: + """Returns the string provided to setup of a specific setup_arg.""" + return self.registry_[_SETUP_ARGS_KEY][setup_arg] + + @staticmethod + def _view_summary_stats( + summary_stats: attrdict, as_markdown: bool = False + ) -> rich.table.Table | str: + """Prints summary stats.""" + if not as_markdown: + t = rich.table.Table(title="Summary Statistics") + else: + t = rich.table.Table(box=box.MARKDOWN) + + t.add_column( + "Summary Stat Key", + justify="center", + style="dodger_blue1", + no_wrap=True, + overflow="fold", + ) + t.add_column( + "Value", + justify="center", + style="dark_violet", + no_wrap=True, + overflow="fold", + ) + for stat_key, count in summary_stats.items(): + t.add_row(stat_key, str(count)) + + if as_markdown: + console = Console(file=StringIO(), force_jupyter=False) + console.print(t) + return console.file.getvalue().strip() + + return t + + @staticmethod + def _view_data_registry( + data_registry: attrdict, as_markdown: bool = False + ) -> rich.table.Table | str: + """Prints data registry.""" + if not as_markdown: + t = rich.table.Table(title="Data Registry") + else: + t = rich.table.Table(box=box.MARKDOWN) + + t.add_column( + "Registry Key", + justify="center", + style="dodger_blue1", + no_wrap=True, + overflow="fold", + ) + t.add_column( + "scvi-tools Location", + justify="center", + style="dark_violet", + no_wrap=True, + overflow="fold", + ) + + for registry_key, data_loc in data_registry.items(): + mod_key = getattr(data_loc, _constants._DR_MOD_KEY, None) + attr_name = data_loc.attr_name + attr_key = data_loc.attr_key + scvi_data_str = "adata" + if mod_key is not None: + scvi_data_str += f".mod['{mod_key}']" + if attr_key is None: + scvi_data_str += f".{attr_name}" + else: + scvi_data_str += f".{attr_name}['{attr_key}']" + t.add_row(registry_key, scvi_data_str) + + if as_markdown: + console = Console(file=StringIO(), force_jupyter=False) + console.print(t) + return console.file.getvalue().strip() + + return t + + def update_setup_method_args(self, setup_method_args: dict): + """Update setup method args. + + Parameters + ---------- + setup_method_args + This is a bit of a misnomer, this is a dict representing kwargs + of the setup method that will be used to update the existing values + in the registry of this instance. + """ + self._registry[_SETUP_ARGS_KEY].update(setup_method_args) + class BaseMinifiedModeModelClass(BaseModelClass): """Abstract base class for scvi-tools models that can handle minified data.""" @@ -900,11 +1112,14 @@ class BaseMinifiedModeModelClass(BaseModelClass): @property def minified_data_type(self) -> MinifiedDataType | None: """The type of minified data associated with this model, if applicable.""" - return ( - self.adata_manager.get_from_registry(REGISTRY_KEYS.MINIFY_TYPE_KEY) - if REGISTRY_KEYS.MINIFY_TYPE_KEY in self.adata_manager.data_registry - else None - ) + if self.adata_manager: + return ( + self.adata_manager.get_from_registry(REGISTRY_KEYS.MINIFY_TYPE_KEY) + if REGISTRY_KEYS.MINIFY_TYPE_KEY in self.adata_manager.data_registry + else None + ) + else: + return None def minify_adata( self, diff --git a/src/scvi/model/base/_save_load.py b/src/scvi/model/base/_save_load.py index c990f4880c..736c260f80 100644 --- a/src/scvi/model/base/_save_load.py +++ b/src/scvi/model/base/_save_load.py @@ -102,7 +102,7 @@ def _load_saved_files( return attr_dict, var_names, model_state_dict, adata -def _initialize_model(cls, adata, attr_dict): +def _initialize_model(cls, adata, registry, attr_dict): """Helper to initialize a model.""" if "init_params_" not in attr_dict.keys(): raise ValueError( @@ -133,7 +133,10 @@ def _initialize_model(cls, adata, attr_dict): if "pretrained_model" in non_kwargs.keys(): non_kwargs.pop("pretrained_model") - model = cls(adata, **non_kwargs, **kwargs) + if not adata: + adata = None + + model = cls(adata, registry=registry, **non_kwargs, **kwargs) for attr, val in attr_dict.items(): setattr(model, attr, val) @@ -177,7 +180,9 @@ def _get_var_names( def _validate_var_names( - adata: AnnOrMuData, source_var_names: npt.NDArray | dict[str, npt.NDArray] + adata: AnnOrMuData | None, + source_var_names: npt.NDArray | dict[str, npt.NDArray], + load_var_names: npt.NDArray | dict[str, npt.NDArray] | None = None, ) -> None: """Validate that source and loaded variable names match. @@ -188,15 +193,19 @@ def _validate_var_names( source_var_names Variable names from a saved model file corresponding to the variable names used during training. + load_var_names + Variable names from the loaded registry. """ from numpy import array_equal - is_anndata = isinstance(adata, AnnData) source_per_mod_var_names = isinstance(source_var_names, dict) - load_var_names = _get_var_names( - adata, - legacy_mudata_format=(not is_anndata and not source_per_mod_var_names), - ) + + if load_var_names is None: + is_anndata = isinstance(adata, AnnData) + load_var_names = _get_var_names( + adata, + legacy_mudata_format=(not is_anndata and not source_per_mod_var_names), + ) if source_per_mod_var_names: valid_load_var_names = all( @@ -208,7 +217,7 @@ def _validate_var_names( if not valid_load_var_names: warnings.warn( - "`var_names` for the loaded `adata` does not match those of the `adata` used to " + "`var_names` for the loaded `model` does not match those used to " "train the model. For valid results, the former should match the latter.", UserWarning, stacklevel=settings.warnings_stacklevel, diff --git a/src/scvi/model/base/_training_mixin.py b/src/scvi/model/base/_training_mixin.py index ebace98445..e27727c029 100644 --- a/src/scvi/model/base/_training_mixin.py +++ b/src/scvi/model/base/_training_mixin.py @@ -84,27 +84,14 @@ def train( **kwargs Additional keyword arguments passed into :class:`~scvi.train.Trainer`. """ - if datamodule is not None and not self._module_init_on_train: - raise ValueError( - "Cannot pass in `datamodule` if the model was initialized with `adata`." - ) - elif datamodule is None and self._module_init_on_train: - raise ValueError( - "If the model was not initialized with `adata`, a `datamodule` must be passed in." - ) - if max_epochs is None: - if datamodule is None: + if self.adata is not None: max_epochs = get_max_epochs_heuristic(self.adata.n_obs) - elif hasattr(datamodule, "n_obs"): - max_epochs = get_max_epochs_heuristic(datamodule.n_obs) else: - raise ValueError( - "If `datamodule` does not have `n_obs` attribute, `max_epochs` must be " - "passed in." - ) + max_epochs = get_max_epochs_heuristic(self.summary_stats.n_obs) if datamodule is None: + # In the general case we enter here datasplitter_kwargs = datasplitter_kwargs or {} datamodule = self._data_splitter_cls( self.adata_manager, @@ -116,15 +103,6 @@ def train( load_sparse_tensor=load_sparse_tensor, **datasplitter_kwargs, ) - elif self.module is None: - self.module = self._module_cls( - datamodule.n_vars, - n_batch=datamodule.n_batch, - n_labels=getattr(datamodule, "n_labels", 1), - n_continuous_cov=getattr(datamodule, "n_continuous_cov", 0), - n_cats_per_cov=getattr(datamodule, "n_cats_per_cov", None), - **self._module_kwargs, - ) plan_kwargs = plan_kwargs or {} training_plan = self._training_plan_cls(self.module, **plan_kwargs) diff --git a/src/scvi/module/_scanvae.py b/src/scvi/module/_scanvae.py index 2028718d1c..be8cedd79b 100644 --- a/src/scvi/module/_scanvae.py +++ b/src/scvi/module/_scanvae.py @@ -66,6 +66,9 @@ class SCANVAE(VAE): * ``'nb'`` - Negative binomial distribution * ``'zinb'`` - Zero-inflated negative binomial distribution + use_observed_lib_size + If ``True``, use the observed library size for RNA as the scaling factor in the mean of the + conditional distribution. y_prior If None, initialized to uniform probability over cell types labels_groups @@ -102,6 +105,7 @@ def __init__( dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene", log_variational: bool = True, gene_likelihood: Literal["zinb", "nb"] = "zinb", + use_observed_lib_size: bool = True, y_prior: torch.Tensor | None = None, labels_groups: Sequence[int] = None, use_labels_groups: bool = False, @@ -123,6 +127,7 @@ def __init__( dispersion=dispersion, log_variational=log_variational, gene_likelihood=gene_likelihood, + use_observed_lib_size=use_observed_lib_size, use_batch_norm=use_batch_norm, use_layer_norm=use_layer_norm, **vae_kwargs, diff --git a/tests/conftest.py b/tests/conftest.py index 6ef9467efc..2c6d29d1e1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,6 +21,12 @@ def pytest_addoption(parser): default=False, help="Run tests that are desinged for multiGPU.", ) + parser.addoption( + "--custom-dataloader-tests", + action="store_true", + default=False, + help="Run tests that deals with custom dataloaders. This increases test time.", + ) parser.addoption( "--optional", action="store_true", @@ -72,6 +78,23 @@ def pytest_collection_modifyitems(config, items): elif run_internet and ("internet" not in item.keywords): item.add_marker(skip_non_internet) + run_custom_dataloader = config.getoption("--custom-dataloader-tests") + skip_custom_dataloader = pytest.mark.skip( + reason="need ---custom-dataloader-tests option to run" + ) + skip_non_custom_dataloader = pytest.mark.skip( + reason="test not having a pytest.mark.custom_dataloader decorator" + ) + for item in items: + # All tests marked with `pytest.mark.custom_dataloader` get skipped unless + # `--custom_dataloader-tests` passed + if not run_custom_dataloader and ("dataloader" in item.keywords): + item.add_marker(skip_custom_dataloader) + # Skip all tests not marked with `pytest.mark.custom_dataloader` + # if `--custom-dataloader-tests` passed + elif run_internet and ("dataloader" not in item.keywords): + item.add_marker(skip_non_custom_dataloader) + run_optional = config.getoption("--optional") skip_optional = pytest.mark.skip(reason="need --optional option to run") skip_non_optional = pytest.mark.skip(reason="test not having a pytest.mark.optional decorator") diff --git a/tests/dataloaders/test_custom_dataloader.py b/tests/dataloaders/test_custom_dataloader.py new file mode 100644 index 0000000000..4ea023a311 --- /dev/null +++ b/tests/dataloaders/test_custom_dataloader.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +import os +from pprint import pprint +from time import time + +import pytest + +import scvi +from scvi.dataloaders import MappedCollectionDataModule +from scvi.utils import dependencies + + +@pytest.mark.custom_dataloader +@dependencies("lamindb") +def test_lamindb_dataloader_scvi_scanvi(save_path: str): + os.system("lamin init --storage ./test-registries") + import lamindb as ln + + ln.setup.init(name="lamindb_instance_name", storage=save_path) + + # a test for mapped collection + collection = ln.Collection.get(name="covid_normal_lung") + artifacts = collection.artifacts.all() + artifacts.df() + + datamodule = MappedCollectionDataModule( + collection, batch_key="assay", batch_size=1024, join="inner" + ) + model = scvi.model.SCVI(adata=None, registry=datamodule.registry) + pprint(model.summary_stats) + pprint(model.module) + inference_dataloader = datamodule.inference_dataloader() + + # Using regular adata laoder + # adata = collection.load() # try to compare this in regular settings + # # setup large + # SCVI.setup_anndata( + # adata, + # batch_key="assay", + # ) + # model_reg = SCVI(adata) + # start_time = time() + # model_reg.train(max_epochs=10, batch_size=1024) + # time_reg = time() - start_time + # print(time_reg) + + start_time = time() + model.train(max_epochs=10, batch_size=1024, datamodule=datamodule) + time_lamin = time() - start_time + print(time_lamin) + + _ = model.get_elbo(dataloader=inference_dataloader) + _ = model.get_marginal_ll(dataloader=inference_dataloader) + _ = model.get_reconstruction_error(dataloader=inference_dataloader) + _ = model.get_latent_representation(dataloader=inference_dataloader) + + model.save("lamin_model", save_anndata=False, overwrite=True) + model_query = model.load_query_data( + adata=False, reference_model="lamin_model", registry=datamodule.registry + ) + model_query.train(max_epochs=1, datamodule=datamodule) + _ = model_query.get_elbo(dataloader=inference_dataloader) + _ = model_query.get_marginal_ll(dataloader=inference_dataloader) + _ = model_query.get_reconstruction_error(dataloader=inference_dataloader) + _ = model_query.get_latent_representation(dataloader=inference_dataloader) + + adata = collection.load(join="inner") + model_query_adata = model.load_query_data(adata=adata, reference_model="lamin_model") + adata = collection.load(join="inner") + model_query_adata = model.load_query_data(adata) + model_query_adata.train(max_epochs=1) + _ = model_query_adata.get_elbo() + _ = model_query_adata.get_marginal_ll() + _ = model_query_adata.get_reconstruction_error() + _ = model_query_adata.get_latent_representation() + _ = model_query_adata.get_latent_representation(dataloader=inference_dataloader) + + model.save("lamin_model", save_anndata=False, overwrite=True) + model.load("lamin_model", adata=False) + model.load_query_data(adata=False, reference_model="lamin_model", registry=datamodule.registry) + + model.load_query_data(adata=adata, reference_model="lamin_model") + model_adata = model.load("lamin_model", adata=adata) + model_adata.train(max_epochs=1) + model_adata.save("lamin_model_anndata", save_anndata=False, overwrite=True) + model_adata.load("lamin_model_anndata", adata=False) + model_adata.load_query_data( + adata=False, reference_model="lamin_model_anndata", registry=datamodule.registry + ) diff --git a/tests/hub/test_hub_metadata.py b/tests/hub/test_hub_metadata.py index bf8cac951b..8bf93fa18d 100644 --- a/tests/hub/test_hub_metadata.py +++ b/tests/hub/test_hub_metadata.py @@ -11,6 +11,7 @@ def prep_model(): scvi.model.SCVI.setup_anndata(adata) model = scvi.model.SCVI(adata) model.train(1) + # model.init_params_["non_kwargs"].pop("datamodule") # scvi hub not supporting customdatamodule return model @@ -90,6 +91,7 @@ def test_hub_modelcardhelper(request, save_path): "dispersion": "gene", "gene_likelihood": "zinb", "latent_distribution": "normal", + "use_observed_lib_size": True, }, } assert hmch.model_setup_anndata_args == { diff --git a/tests/model/test_pyro.py b/tests/model/test_pyro.py index 07756011e1..03a2d1174f 100644 --- a/tests/model/test_pyro.py +++ b/tests/model/test_pyro.py @@ -138,7 +138,8 @@ def list_obs_plate_vars(self): class BayesianRegressionModel(PyroSviTrainMixin, PyroSampleMixin, BaseModelClass): def __init__( self, - adata: AnnData, + adata: AnnData | None = None, + registry: dict | None = None, per_cell_weight=False, ): # in case any other model was created before that shares the same parameter names. diff --git a/tests/model/test_scvi.py b/tests/model/test_scvi.py index f094b05a1d..f6dc44b522 100644 --- a/tests/model/test_scvi.py +++ b/tests/model/test_scvi.py @@ -1122,32 +1122,9 @@ def test_scvi_no_anndata(n_batches: int = 3, n_latent: int = 5): datamodule.n_vars = adata.n_vars datamodule.n_batch = n_batches - model = SCVI(n_latent=5) - assert model._module_init_on_train - assert model.module is None - - # cannot infer default max_epochs without n_obs set in datamodule - with pytest.raises(ValueError): - model.train(datamodule=datamodule) - - # must pass in datamodule if not initialized with adata - with pytest.raises(ValueError): - model.train() - - model.train(max_epochs=1, datamodule=datamodule) - - # must set n_obs for defaulting max_epochs - datamodule.n_obs = 100_000_000 # large number for fewer default epochs - model.train(datamodule=datamodule) - - model = SCVI(adata, n_latent=5) - assert not model._module_init_on_train - assert model.module is not None - assert hasattr(model, "adata") - - # initialized with adata, cannot pass in datamodule - with pytest.raises(ValueError): - model.train(datamodule=datamodule) + with pytest.raises(TypeError) as excinfo: + SCVI(n_latent=n_latent) + assert str(excinfo.value) == "SCVI.__init__() missing 1 required positional argument: 'adata'" def test_scvi_no_anndata_with_external_indices(n_batches: int = 3, n_latent: int = 5): @@ -1170,32 +1147,9 @@ def test_scvi_no_anndata_with_external_indices(n_batches: int = 3, n_latent: int datamodule.n_vars = adata.n_vars datamodule.n_batch = n_batches - model = SCVI(n_latent=5) - assert model._module_init_on_train - assert model.module is None - - # cannot infer default max_epochs without n_obs set in datamodule - with pytest.raises(ValueError): - model.train(datamodule=datamodule) - - # must pass in datamodule if not initialized with adata - with pytest.raises(ValueError): - model.train() - - model.train(max_epochs=1, datamodule=datamodule) - - # must set n_obs for defaulting max_epochs - datamodule.n_obs = 100_000_000 # large number for fewer default epochs - model.train(datamodule=datamodule) - - model = SCVI(adata, n_latent=5) - assert not model._module_init_on_train - assert model.module is not None - assert hasattr(model, "adata") - - # initialized with adata, cannot pass in datamodule - with pytest.raises(ValueError): - model.train(datamodule=datamodule) + with pytest.raises(TypeError) as excinfo: + SCVI(n_latent=n_latent) + assert str(excinfo.value) == "SCVI.__init__() missing 1 required positional argument: 'adata'" @pytest.mark.parametrize("embedding_dim", [5, 10]) @@ -1261,6 +1215,23 @@ def test_scvi_inference_custom_dataloader(n_latent: int = 5): _ = model.get_latent_representation(dataloader=dataloader) +def test_scvi_train_custom_dataloader(n_latent: int = 5): + # ORI this function could help get started. + adata = synthetic_iid() + SCVI.setup_anndata(adata, batch_key="batch") + + model = SCVI(adata, n_latent=n_latent) + model.train(max_epochs=1) + dataloader = model._make_data_loader(adata) + # SCVI.setup_datamodule(dataloader) + # continue from here. Datamodule will always require to pass it into all downstream functions. + model.train(max_epochs=1, datamodule=dataloader) + _ = model.get_elbo(dataloader=dataloader) + _ = model.get_marginal_ll(dataloader=dataloader) + _ = model.get_reconstruction_error(dataloader=dataloader) + _ = model.get_latent_representation(dataloader=dataloader) + + def test_scvi_normal_likelihood(): import scanpy as sc