Skip to content

Commit

Permalink
Embed data with scGPT and save output
Browse files Browse the repository at this point in the history
  • Loading branch information
lazappi committed Nov 5, 2024
1 parent 3cd66a7 commit 1241309
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 14 deletions.
5 changes: 4 additions & 1 deletion src/methods/scgpt/config.vsh.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ engines:
- type: python
pypi:
- gdown
- scgpt
- scgpt # Install from PyPI to get dependencies
- type: docker
# Force re-installing from GitHub to get bug fixes
run: pip install --upgrade --no-deps --force-reinstall git+https://github.com/bowang-lab/scGPT.git

runners:
- type: executable
Expand Down
44 changes: 31 additions & 13 deletions src/methods/scgpt/script.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import anndata as ad
import scgpt
import sys
import gdown
import tempfile
import torch
import anndata as ad

## VIASH START
# Note: this section is auto-generated by viash at runtime. To edit it, make changes
Expand Down Expand Up @@ -43,7 +44,7 @@

print(adata, flush=True)

print("\n>>> Downloading model...", flush=True)
print(f"\n>>> Downloading '{par['model']}' model...", flush=True)
model_drive_ids = {
"scGPT_human" : "1oWh_-ZRdhtoGQ2Fw24HP41FgLoomVo-y",
"scGPT_CP" : "1_GROJTzXiAV8HB4imruOTk6PEGuNOcgB"
Expand All @@ -54,20 +55,37 @@
gdown.download_folder(drive_path, output=model_dir.name, quiet = True)
print(f"Model directory: '{model_dir.name}'", flush=True)

print('Preprocess data', flush=True)
# ... preprocessing ...

print('Train model', flush=True)
# ... train model ...

print('Generate predictions', flush=True)
# ... generate predictions ...
print("\n>>> Embedding data...", flush=True)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device: '{device}'", flush=True)
embedded = scgpt.tasks.embed_data(
adata,
model_dir.name,
gene_col="feature_name",
batch_size=64,
use_fast_transformer=False, # Disable fast-attn as not installed
device = device,
return_new_adata = True
)

print("Write output AnnData to file", flush=True)
print("\n>>> Storing output...", flush=True)
output = ad.AnnData(

obs=adata.obs[[]],
var=adata.var[[]],
obsm={
"X_emb": embedded.X,
},
uns={
"dataset_id": adata.uns["dataset_id"],
"normalization_id": adata.uns["normalization_id"],
"method_id": meta["name"],
},
)
output.write_h5ad(par['output'], compression='gzip')
print(output)

print("\n>>> Writing output to file...", flush=True)
print(f"Output H5AD file: '{par['output']}'", flush=True)
output.write_h5ad(par["output"], compression="gzip")

print("\n>>> Cleaning up temporary directories...", flush=True)
model_dir.cleanup()
Expand Down

0 comments on commit 1241309

Please sign in to comment.