Commit

Style scGPT script
lazappi committed Nov 5, 2024
1 parent c4dfc2f commit 5ca0720
Showing 1 changed file with 17 additions and 18 deletions.
35 changes: 17 additions & 18 deletions src/methods/scgpt/script.py
@@ -1,22 +1,21 @@
-import scgpt
 import sys
-import gdown
 import tempfile
-import torch
+
 import anndata as ad
+import gdown
+import scgpt
+import torch
 
 ## VIASH START
 # Note: this section is auto-generated by viash at runtime. To edit it, make changes
 # in config.vsh.yaml and then run `viash config inject config.vsh.yaml`.
 par = {
-    'input': 'resources_test/.../input.h5ad',
-    'output': 'output.h5ad',
-    "model" : "scGPT_human",
-    "n_hvg": 3000
-}
-meta = {
-    'name': 'scgpt'
+    "input": "resources_test/.../input.h5ad",
+    "output": "output.h5ad",
+    "model": "scGPT_human",
+    "n_hvg": 3000,
 }
+meta = {"name": "scgpt"}
 ## VIASH END
 
 print(f"====== scGPT version {scgpt.__version__} ======", flush=True)
@@ -39,33 +38,33 @@
print("\n>>> Preprocessing data...", flush=True)
if par["n_hvg"]:
print(f"Selecting top {par['n_hvg']} highly variable genes", flush=True)
idx = adata.var["hvg_score"].to_numpy().argsort()[::-1][:par["n_hvg"]]
idx = adata.var["hvg_score"].to_numpy().argsort()[::-1][: par["n_hvg"]]
adata = adata[:, idx].copy()

print(adata, flush=True)

print(f"\n>>> Downloading '{par['model']}' model...", flush=True)
model_drive_ids = {
"scGPT_human" : "1oWh_-ZRdhtoGQ2Fw24HP41FgLoomVo-y",
"scGPT_CP" : "1_GROJTzXiAV8HB4imruOTk6PEGuNOcgB"
"scGPT_human": "1oWh_-ZRdhtoGQ2Fw24HP41FgLoomVo-y",
"scGPT_CP": "1_GROJTzXiAV8HB4imruOTk6PEGuNOcgB",
}
drive_path = f"https://drive.google.com/drive/folders/{model_drive_ids[par['model']]}"
model_dir = tempfile.TemporaryDirectory()
print(f"Downloading from '{drive_path}'", flush=True)
gdown.download_folder(drive_path, output=model_dir.name, quiet = True)
gdown.download_folder(drive_path, output=model_dir.name, quiet=True)
print(f"Model directory: '{model_dir.name}'", flush=True)

print("\n>>> Embedding data...", flush=True)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: '{device}'", flush=True)
embedded = scgpt.tasks.embed_data(
adata,
model_dir.name,
gene_col="feature_name",
batch_size=64,
use_fast_transformer=False, # Disable fast-attn as not installed
device = device,
return_new_adata = True
use_fast_transformer=False, # Disable fast-attn as not installed
device=device,
return_new_adata=True,
)

print("\n>>> Storing output...", flush=True)
