Skip to content

Commit

Permalink
Add model argument and download model files
Browse files Browse the repository at this point in the history
  • Loading branch information
lazappi committed Nov 5, 2024
1 parent b015193 commit 3cd66a7
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 4 deletions.
11 changes: 11 additions & 0 deletions src/methods/scgpt/config.vsh.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,19 @@ links:
repository: https://github.com/bowang-lab/scGPT

info:
method_types: [embedding]
preferred_normalization: counts
variants:
scgpt_default:
scgpt_cp:
model: "scGPT_CP"

arguments:
- name: --model
type: string
description: String giving the scGPT model to use
choices: ["scGPT_human", "scGPT_CP"]
default: "scGPT_human"
- name: --n_hvg
type: integer
default: 3000
Expand All @@ -36,6 +46,7 @@ engines:
setup:
- type: python
pypi:
- gdown
- scgpt

runners:
Expand Down
27 changes: 23 additions & 4 deletions src/methods/scgpt/script.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import anndata as ad
import scgpt
import sys
import gdown
import tempfile

## VIASH START
# Note: this section is auto-generated by viash at runtime. To edit it, make changes
# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`.
par = {
'input': 'resources_test/.../input.h5ad',
'output': 'output.h5ad'
'output': 'output.h5ad',
"model" : "scGPT_human",
"n_hvg": 3000
}
meta = {
Expand All @@ -30,15 +33,26 @@
f"(dataset_organism == \"{adata.uns['dataset_organism']}\")"
)

print(adata)
print(adata, flush=True)

print("\n>>> Preprocessing data...")
print("\n>>> Preprocessing data...", flush=True)
if par["n_hvg"]:
print(f"Selecting top {par['n_hvg']} highly variable genes", flush=True)
idx = adata.var["hvg_score"].to_numpy().argsort()[::-1][:par["n_hvg"]]
adata = adata[:, idx].copy()

print(adata)
print(adata, flush=True)

print("\n>>> Downloading model...", flush=True)
model_drive_ids = {
"scGPT_human" : "1oWh_-ZRdhtoGQ2Fw24HP41FgLoomVo-y",
"scGPT_CP" : "1_GROJTzXiAV8HB4imruOTk6PEGuNOcgB"
}
drive_path = f"https://drive.google.com/drive/folders/{model_drive_ids[par['model']]}"
model_dir = tempfile.TemporaryDirectory()
print(f"Downloading from '{drive_path}'", flush=True)
gdown.download_folder(drive_path, output=model_dir.name, quiet = True)
print(f"Model directory: '{model_dir.name}'", flush=True)

print('Preprocess data', flush=True)
# ... preprocessing ...
Expand All @@ -54,3 +68,8 @@

)
output.write_h5ad(par['output'], compression='gzip')

print("\n>>> Cleaning up temporary directories...", flush=True)
model_dir.cleanup()

print("\n>>> Done!", flush=True)

0 comments on commit 3cd66a7

Please sign in to comment.