diff --git a/src/methods/scgpt/config.vsh.yaml b/src/methods/scgpt/config.vsh.yaml
index f067df32..a6a72835 100644
--- a/src/methods/scgpt/config.vsh.yaml
+++ b/src/methods/scgpt/config.vsh.yaml
@@ -47,7 +47,10 @@ engines:
       - type: python
         pypi:
           - gdown
-          - scgpt
+          - scgpt # Install from PyPI to get dependencies
+      - type: docker
+        # Force re-installing from GitHub to get bug fixes
+        run: pip install --upgrade --no-deps --force-reinstall git+https://github.com/bowang-lab/scGPT.git
 
 runners:
   - type: executable
diff --git a/src/methods/scgpt/script.py b/src/methods/scgpt/script.py
index 02df025b..1d041a30 100644
--- a/src/methods/scgpt/script.py
+++ b/src/methods/scgpt/script.py
@@ -1,8 +1,9 @@
-import anndata as ad
 import scgpt
 import sys
 import gdown
 import tempfile
+import torch
+import anndata as ad
 
 ## VIASH START
 # Note: this section is auto-generated by viash at runtime. To edit it, make changes
@@ -43,7 +44,7 @@
 print(adata, flush=True)
 
-print("\n>>> Downloading model...", flush=True)
+print(f"\n>>> Downloading '{par['model']}' model...", flush=True)
 model_drive_ids = {
     "scGPT_human" : "1oWh_-ZRdhtoGQ2Fw24HP41FgLoomVo-y",
     "scGPT_CP" : "1_GROJTzXiAV8HB4imruOTk6PEGuNOcgB"
 }
@@ -54,20 +55,37 @@
 gdown.download_folder(drive_path, output=model_dir.name, quiet = True)
 print(f"Model directory: '{model_dir.name}'", flush=True)
 
-print('Preprocess data', flush=True)
-# ... preprocessing ...
-
-print('Train model', flush=True)
-# ... train model ...
-
-print('Generate predictions', flush=True)
-# ... generate predictions ...
+print("\n>>> Embedding data...", flush=True)
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+print(f"Device: '{device}'", flush=True)
+embedded = scgpt.tasks.embed_data(
+    adata,
+    model_dir.name,
+    gene_col="feature_name",
+    batch_size=64,
+    use_fast_transformer=False, # Disable fast-attn as not installed
+    device = device,
+    return_new_adata = True
+)
 
-print("Write output AnnData to file", flush=True)
+print("\n>>> Storing output...", flush=True)
 output = ad.AnnData(
-
+    obs=adata.obs[[]],
+    var=adata.var[[]],
+    obsm={
+        "X_emb": embedded.X,
+    },
+    uns={
+        "dataset_id": adata.uns["dataset_id"],
+        "normalization_id": adata.uns["normalization_id"],
+        "method_id": meta["name"],
+    },
 )
-output.write_h5ad(par['output'], compression='gzip')
+print(output)
+
+print("\n>>> Writing output to file...", flush=True)
+print(f"Output H5AD file: '{par['output']}'", flush=True)
+output.write_h5ad(par["output"], compression="gzip")
 
 print("\n>>> Cleaning up temporary directories...", flush=True)
 model_dir.cleanup()