Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sema 1d endpoint #8

Merged
merged 2 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
4 changes: 4 additions & 0 deletions apps/fastapi/src/app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ def assemble_cors_origins(cls, v: Union[str, List[str]]) -> Union[List[str], str
SUPERUSER_PASSWORD: str = Field(..., env="SUPERUSER_PASSWORD")
JWT_SECRET: str = Field(..., env="JWT_SECRET")
S3_BUCKET_NAME: str = Field(default="fastapi-csv", env="S3_BUCKET_NAME")
SAGEMAKER_ENDPOINT_NAME: str = Field(
default="huggingface-pytorch-inference-2024-10-16-20-16-41-824",
env="SAGEMAKER_ENDPOINT_NAME",
)

# Optional
HUGGINGFACE_ACCESS_TOKEN: str = Field(None, env="HUGGINGFACE_ACCESS_TOKEN")
Expand Down
15 changes: 15 additions & 0 deletions apps/fastapi/src/app/services/background_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
from fastapi import HTTPException
from supabase._async.client import AsyncClient

from app.core.config import settings
from app.crud.crud_conformational_b_prediction import crud_conformational_b_prediction
from app.crud.crud_job import crud_job
from app.crud.crud_linear_b_prediction import crud_linear_b_prediction
from app.crud.crud_mhc_i_prediction import crud_mhc_i_prediction
from app.crud.crud_mhc_ii_prediction import crud_mhc_ii_prediction
from app.services.pipeline import run
from app.services.postprocess import (
process_conformational_b_prediction,
process_linear_b_prediction,
Expand All @@ -35,6 +37,19 @@ async def process_and_update_prediction(
# Update job status to 'running'
await crud_job.update_status(db=db, id=job_id, status="running")

# Sample input
glp_1_receptor = "MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS"
sample = {"inputs": glp_1_receptor}

# Run prediction pipeline
predictions = run(
input_objects=[sample],
endpoint_name=settings.SAGEMAKER_ENDPOINT_NAME,
model_type="single",
)

logger.info(predictions)

# Process the prediction based on its type
if prediction_type == "conformational-b":
if is_structure_based:
Expand Down
8 changes: 3 additions & 5 deletions apps/fastapi/src/app/services/inference.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import json

import boto3
import pandas as pd

from app.core.config import settings

Expand All @@ -22,12 +21,11 @@ def get_predictions(requests, endpoint_name, model_name, model_type="mme"):
)
res = json.loads(response["Body"].read().decode("utf-8"))[0]
else:
# XGBoost by default does not allow application/json
response = sagemaker_runtime.invoke_endpoint(
ContentType="csv",
ContentType="application/json",
EndpointName=endpoint_name,
Body=pd.DataFrame(data=[request]).to_csv(index=False, header=False),
Body=json.dumps(request),
)
res = float(response["Body"].read().decode("utf-8"))
res = json.loads(response["Body"].read().decode("utf-8"))
responses.append(res)
return responses
19 changes: 11 additions & 8 deletions apps/fastapi/src/app/services/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
from typing import List
from typing import Any, List

from app.core.utils import get_endpoint
from app.schemas.input_object import InputObject
from app.services.inference import get_predictions
from app.services.postprocess import do_postprocessing
from app.services.preprocess import do_preprocess


def run(
input_objects: List[InputObject],
input_objects: List[Any],
endpoint_name: str,
model_name=None,
model_type="mme",
Expand All @@ -30,12 +27,18 @@ def run(
if model_name is None:
model_name = endpoint_name

preprocessed_df = do_preprocess(input_objects)
# Preprocess the inputs
preprocess = [obj for obj in input_objects]

# Run inference
model_responses = get_predictions(
preprocessed_df.values.tolist(),
preprocess,
endpoint_name,
model_name,
model_type=model_type,
)
predictions = do_postprocessing(model_responses)

# Postprocess the inferences
predictions = model_responses

return predictions
217 changes: 117 additions & 100 deletions notebooks/protein_language_modeling.ipynb

Large diffs are not rendered by default.

Loading
Loading