diff --git a/src/methods/scgpt/config.vsh.yaml b/src/methods/scgpt/config.vsh.yaml
new file mode 100644
index 00000000..a6a72835
--- /dev/null
+++ b/src/methods/scgpt/config.vsh.yaml
@@ -0,0 +1,59 @@
+__merge__: ../../api/comp_method.yaml
+
+name: scgpt
+label: scGPT
+summary: "A foundation model for single-cell biology"
+description: |
+  scGPT is a foundation model for single-cell biology based on a generative
+  pre-trained transformer and trained on a repository of over 33 million cells.
+  Here, we use zero-shot output from a pre-trained model to get an integrated
+  embedding for the batch integration task.
+references:
+  doi:
+    - 10.1038/s41592-024-02201-0
+links:
+  documentation: https://scgpt.readthedocs.io/en/latest/
+  repository: https://github.com/bowang-lab/scGPT
+
+info:
+  method_types: [embedding]
+  preferred_normalization: counts
+  variants:
+    scgpt_default:
+    scgpt_cp:
+      model: "scGPT_CP"
+
+arguments:
+  - name: --model
+    type: string
+    description: String giving the scGPT model to use
+    choices: ["scGPT_human", "scGPT_CP"]
+    default: "scGPT_human"
+  - name: --n_hvg
+    type: integer
+    default: 3000
+    description: Number of highly variable genes to use.
+
+resources:
+  - type: python_script
+    path: script.py
+  - path: /src/utils/read_anndata_partial.py
+
+engines:
+  - type: docker
+    image: openproblems/base_pytorch_nvidia:1.0.0
+    # TODO: Try to find working installation of flash attention (flash-attn<1.0.5)
+    setup:
+      - type: python
+        pypi:
+          - gdown
+          - scgpt # Install from PyPI to get dependencies
+      - type: docker
+        # Force re-installing from GitHub to get bug fixes
+        run: pip install --upgrade --no-deps --force-reinstall git+https://github.com/bowang-lab/scGPT.git
+
+runners:
+  - type: executable
+  - type: nextflow
+    directives:
+      label: [midtime, midmem, midcpu, gpu]
diff --git a/src/methods/scgpt/script.py b/src/methods/scgpt/script.py
new file mode 100644
index 00000000..37297958
--- /dev/null
+++ b/src/methods/scgpt/script.py
@@ -0,0 +1,92 @@
+import sys
+import tempfile
+
+import anndata as ad
+import gdown
+import scgpt
+import torch
+
+## VIASH START
+# Note: this section is auto-generated by viash at runtime. To edit it, make changes
+# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`.
+par = {
+    "input": "resources_test/.../input.h5ad",
+    "output": "output.h5ad",
+    "model": "scGPT_human",
+    "n_hvg": 3000,
+}
+meta = {"name": "scgpt"}
+## VIASH END
+
+print(f"====== scGPT version {scgpt.__version__} ======", flush=True)
+
+sys.path.append(meta["resources_dir"])
+from read_anndata_partial import read_anndata
+
+print("\n>>> Reading input files...", flush=True)
+print(f"Input H5AD file: '{par['input']}'", flush=True)
+adata = read_anndata(par["input"], X="layers/counts", obs="obs", var="var", uns="uns")
+
+if adata.uns["dataset_organism"] != "homo_sapiens":
+    raise ValueError(
+        f"scGPT can only be used with human data "
+        f"(dataset_organism == \"{adata.uns['dataset_organism']}\")"
+    )
+
+print(adata, flush=True)
+
+print("\n>>> Preprocessing data...", flush=True)
+if par["n_hvg"]:
+    print(f"Selecting top {par['n_hvg']} highly variable genes", flush=True)
+    idx = adata.var["hvg_score"].to_numpy().argsort()[::-1][: par["n_hvg"]]
+    adata = adata[:, idx].copy()
+
+print(adata, flush=True)
+
+print(f"\n>>> Downloading '{par['model']}' model...", flush=True)
+model_drive_ids = {
+    "scGPT_human": "1oWh_-ZRdhtoGQ2Fw24HP41FgLoomVo-y",
+    "scGPT_CP": "1_GROJTzXiAV8HB4imruOTk6PEGuNOcgB",
+}
+drive_path = f"https://drive.google.com/drive/folders/{model_drive_ids[par['model']]}"
+model_dir = tempfile.TemporaryDirectory()
+print(f"Downloading from '{drive_path}'", flush=True)
+gdown.download_folder(drive_path, output=model_dir.name, quiet=True)
+print(f"Model directory: '{model_dir.name}'", flush=True)
+
+print("\n>>> Embedding data...", flush=True)
+device = "cuda" if torch.cuda.is_available() else "cpu"
+print(f"Device: '{device}'", flush=True)
+embedded = scgpt.tasks.embed_data(
+    adata,
+    model_dir.name,
+    gene_col="feature_name",
+    batch_size=64,
+    use_fast_transformer=False,  # Disable fast-attn as not installed
+    device=device,
+    return_new_adata=True,
+)
+
+print("\n>>> Storing output...", flush=True)
+output = ad.AnnData(
+    obs=adata.obs[[]],
+    var=adata.var[[]],
+    obsm={
+        "X_emb": embedded.X,
+    },
+    uns={
+        "dataset_id": adata.uns["dataset_id"],
+        "normalization_id": adata.uns["normalization_id"],
+        "method_id": meta["name"],
+    },
+)
+print(output)
+
+print("\n>>> Writing output to file...", flush=True)
+print(f"Output H5AD file: '{par['output']}'", flush=True)
+output.write_h5ad(par["output"], compression="gzip")
+
+print("\n>>> Cleaning up temporary directories...", flush=True)
+model_dir.cleanup()
+
+print("\n>>> Done!", flush=True)
diff --git a/src/workflows/run_benchmark/config.vsh.yaml b/src/workflows/run_benchmark/config.vsh.yaml
index a4df6706..21917544 100644
--- a/src/workflows/run_benchmark/config.vsh.yaml
+++ b/src/workflows/run_benchmark/config.vsh.yaml
@@ -95,6 +95,7 @@ dependencies:
   - name: methods/scanorama
   - name: methods/scanvi
   - name: methods/scimilarity
+  - name: methods/scgpt
   - name: methods/scvi
   - name: methods/uce
   # metrics
diff --git a/src/workflows/run_benchmark/main.nf b/src/workflows/run_benchmark/main.nf
index b208df9a..be3cabd1 100644
--- a/src/workflows/run_benchmark/main.nf
+++ b/src/workflows/run_benchmark/main.nf
@@ -32,10 +32,11 @@ methods = [
   scimilarity.run(
     args: [model: file("s3://openproblems-work/cache/scimilarity-model_v1.1.tar.gz")]
   ),
+  scgpt,
   scvi,
   uce.run(
     args: [model: file("s3://openproblems-work/cache/uce-model-v5.zip")]
-  ),
+  )
 ]
 
 // construct list of metrics
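Note on the shared resource: script.py imports read_anndata from /src/utils/read_anndata_partial.py, which the config lists as a resource but which is not part of this diff. As a rough, illustrative sketch only (an assumption, not the actual helper), it is assumed to read just the requested slots of the H5AD file so that large unused layers stay on disk:

# Hypothetical sketch of read_anndata_partial.py; the real helper may differ.
import anndata as ad
import h5py
from anndata.experimental import read_elem

def read_anndata(file, **kwargs):
    """Read selected slots from an H5AD file.

    Example: read_anndata(path, X="layers/counts", obs="obs", var="var", uns="uns")
    maps the 'layers/counts' element onto adata.X and reads obs/var/uns as-is.
    """
    with h5py.File(file, "r") as f:
        slots = {attr: read_elem(f[key]) for attr, key in kwargs.items()}
    return ad.AnnData(**slots)

Called as in script.py above, this would return an AnnData with raw counts in X and the obs/var/uns slots populated, which is what the downstream scGPT embedding step expects.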