Skip to content

Commit

Permalink
Added scVI loader and model class
Browse files Browse the repository at this point in the history
  • Loading branch information
jesusCaraball0 committed Feb 14, 2025
1 parent 92a2c66 commit dd56ada
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 1 deletion.
29 changes: 29 additions & 0 deletions tdc/model_server/model_loaders/scvi_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
class scVILoader():
    """Download a pretrained scVI checkpoint into ``./scvi_model/model.pt``.

    Checkpoints are fetched over HTTPS from the ``cellxgene-contrib-public``
    S3 bucket, addressed by census version.
    """

    # Census version used when the requested one cannot be downloaded.
    DEFAULT_CENSUS_VERSION = "2024-07-01"
    URL_TEMPLATE = (
        "https://cellxgene-contrib-public.s3.us-west-2.amazonaws.com/"
        "models/scvi/{census_version}/homo_sapiens/model.pt")
    # Seconds to wait for the server before giving up on a request.
    REQUEST_TIMEOUT = 60

    def __init__(self):
        pass

    def load(self, census_version):
        """Download the checkpoint for ``census_version`` to disk.

        If the requested version is missing (HTTP 404) or the request fails,
        the reason is printed and the default version ``2024-07-01`` is
        downloaded instead.

        Args:
            census_version: Census version string used in the download URL,
                e.g. ``"2024-07-01"``.

        Returns:
            The relative path of the written file (``scvi_model/model.pt``).

        Raises:
            requests.RequestException: If the fallback download fails too.
        """
        import os

        import requests

        output_dir = os.path.join(os.getcwd(), 'scvi_model')
        output_path = os.path.join('scvi_model', 'model.pt')

        try:
            response = requests.get(
                self.URL_TEMPLATE.format(census_version=census_version),
                stream=True,
                timeout=self.REQUEST_TIMEOUT)
            if response.status_code == 404:
                response.close()
                raise LookupError(
                    'Census version not found, defaulting to version 2024-07-01'
                )
            # Any other HTTP error must not be saved to disk as a "model".
            response.raise_for_status()
        except (LookupError, requests.RequestException) as e:
            print(e)
            census_version = self.DEFAULT_CENSUS_VERSION
            response = requests.get(
                self.URL_TEMPLATE.format(census_version=census_version),
                stream=True,
                timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()

        # Create the directory only once a good response is in hand, and write
        # through a temporary file, so a failed download never leaves behind an
        # empty directory or a truncated model.pt.
        os.makedirs(output_dir, exist_ok=True)
        partial_path = output_path + ".part"
        with response, open(partial_path, "wb") as file:
            for chunk in response.iter_content(chunk_size=1 << 20):
                file.write(chunk)
        os.replace(partial_path, output_path)

        print(
            f'scVI version {census_version} downloaded to {output_path} in current directory'
        )
        return output_path
83 changes: 83 additions & 0 deletions tdc/model_server/models/scvi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import torch.nn as nn

class scVI(nn.Module):
    """Wrapper around a pretrained scVI model stored in ``./scvi_model``.

    Call ``load()`` once to obtain the reference model, then call the module
    (``forward``) on an AnnData-like object to get its latent representation.
    """

    def __init__(self):
        # Unused here; presumably kept so that a missing scvi-tools install
        # surfaces at construction time rather than on first use.
        import scvi as scvi_package  # noqa: F401

        super().__init__()
        self.model = None  # set by load()
        self.var_names = None  # var index of the reference data, set by load()

    def forward(self, adata):
        """Return the latent representation of ``adata`` from the query model.

        ``adata`` is modified in place. It is first prepared with
        ``prepare_data``; if that fails for any reason, it is coerced to the
        reference layout with ``force_data_match`` instead.

        Raises:
            RuntimeError: If ``load()`` has not been called yet.
        """
        if self.model is None:
            raise RuntimeError(
                "scVI model is not loaded; call load() before running the model"
            )

        try:
            self.prepare_data(adata)
        except Exception as error:
            # Best-effort fallback: report why preparation failed, then force
            # the query data into the reference layout.
            print("No var names found in SCVI reference vars")
            print(f"adata.var.index must include some of {self.var_names}")
            print(f"prepare_data failed with: {error!r}")

            self.force_data_match(adata)

        # getting variational autoencoder
        vae_q = self.model.load_query_data(adata, self.model)
        # Mark as trained so the latent representation can be requested
        # without fitting the query model first.
        vae_q.is_trained = True

        return vae_q.get_latent_representation()  # or get normalized expression

    def load(self):
        """Load the reference scVI model, downloading it first if needed.

        Returns:
            The loaded ``scvi.model.SCVI`` instance (also stored on
            ``self.model``).
        """
        import os

        import scvi as scvi_package

        from tdc.multi_pred.anndata_dataset import DataLoader
        from tdc.model_server.model_loaders import scvi_loader

        # Check for the checkpoint file itself: a bare directory may be left
        # over from an interrupted download.
        if not os.path.isfile(os.path.join("scvi_model", "model.pt")):
            loader = scvi_loader.scVILoader()
            loader.load("2024-07-01")  # can add a var for a new version

        adata = DataLoader("scvi_test_dataset",
                           "./data",
                           dataset_names=["scvi_test_dataset"],
                           no_convert=True).adata

        # Matching adata shape and var names with SCVI's adata
        self.force_data_match(adata)

        # instantiate SCVI model (not callable, just used to get VAE)
        self.model = scvi_package.model.SCVI.load('scvi_model', adata)
        self.var_names = adata.var.index

        print("loaded scVI:")
        print(f"{self.model}")

        return self.model

    def prepare_data(self, adata):
        """Prepare ``adata`` in place as query data for the reference model.

        Raises:
            ValueError: If ``adata.var.index`` shares no entry with the
                reference var names.
        """
        import numpy as np

        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not np.isin(adata.var.index, self.var_names).any():
            raise ValueError(
                "adata.var.index shares no entries with the scVI reference "
                "var names")

        # NOTE: tutorials usually also set adata.var["ensembl_id"],
        # adata.obs["n_counts"] and adata.obs["joinid"]; they are not set here.
        adata.obs["batch"] = "unassigned"
        self.model.prepare_query_anndata(adata, self.model)

    def force_data_match(self, adata):
        """Force ``adata`` in place to match the checkpoint's registry.

        Renames ``adata.var.index`` to the checkpoint's column names, pads
        ``adata.X`` on the right with zero columns up to that width, and sets
        ``adata.obs['batch']`` to the first registered batch category.

        Raises:
            ValueError: If ``adata.X`` has more columns than the checkpoint.
        """
        import numpy as np
        import torch

        # NOTE(review): torch.load unpickles the downloaded checkpoint; only
        # use checkpoints obtained from a trusted source.
        metadata = torch.load("scvi_model/model.pt",
                              map_location=torch.device('cpu'))
        field_registries = metadata["attr_dict"]["registry_"][
            "field_registries"]
        column_names = field_registries["X"]["state_registry"]["column_names"]

        # Target width comes from the checkpoint rather than a fixed constant.
        missing_columns = len(column_names) - adata.X.shape[1]
        if missing_columns < 0:
            raise ValueError(
                f"adata has {adata.X.shape[1]} variables but the scVI "
                f"checkpoint was registered with only {len(column_names)}")

        # setting indices that match
        adata.var.index = column_names

        # padding X so dimensions match
        additional_columns = np.zeros((adata.X.shape[0], missing_columns))
        adata.X = np.hstack([adata.X, additional_columns])

        # getting a batch name that matches
        adata.obs['batch'] = field_registries["batch"]["state_registry"][
            "categorical_mapping"][0]
7 changes: 6 additions & 1 deletion tdc/model_server/tdc_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
'CYP3A4_Veith-AttentiveFP',
]

model_hub = ["Geneformer", "scGPT"]
model_hub = ["Geneformer", "scGPT", "scVI"]


class tdc_hf_interface:
Expand Down Expand Up @@ -66,6 +66,11 @@ def load(self):
AutoModel.register(ScGPTConfig, ScGPTModel)
model = AutoModel.from_pretrained("tdc/scGPT")
return model
elif self.model_name == "scVI":
from .models.scvi import scVI
model = scVI()
model.load()
return model
raise Exception("Not implemented yet!")

def load_deeppurpose(self, save_path):
Expand Down
14 changes: 14 additions & 0 deletions tdc/test/test_model_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,20 @@ def testGeneformerTokenizer(self):
"Geneformer ran sucessfully. Find batch embedding example here:\n {}"
.format(out[0]))

def testscVI(self):
    """Load scVI through the TDC interface and run it on the test dataset."""
    from tdc.multi_pred.anndata_dataset import DataLoader
    from tdc import tdc_hf_interface

    adata = DataLoader("scvi_test_dataset",
                       "./data",
                       dataset_names=["scvi_test_dataset"],
                       no_convert=True).adata

    scvi = tdc_hf_interface("scVI")
    model = scvi.load()
    output = model(adata)

    # Without these checks the test passes even when the model yields nothing.
    assert output is not None, "scVI returned no output"
    assert len(output) > 0, "scVI returned an empty output"
    print(f"scVI ran successfully. here is an output: {output}")

def tearDown(self):
try:
print(os.getcwd())
Expand Down

0 comments on commit dd56ada

Please sign in to comment.